How to write a custom iterator

Iterator categories

C++ currently defines six iterator categories:

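| Name | Description | Example |
|---|---|---|
| Input Iterator | Forward only, single pass; read-only | std::istream_iterator |
| Output Iterator | Forward only, single pass; write-only | std::ostream_iterator / std::front_inserter / std::back_inserter / std::inserter |
| Forward Iterator | Forward, multi-pass; readable and writable | std::forward_list<T>::begin |
| Bidirectional Iterator | Forward and backward, multi-pass; readable and writable | std::list<T>::begin |
| Random Access Iterator | Everything a Bidirectional Iterator offers, plus arbitrary jumps | std::vector<T>::begin / std::array<T>::begin / std::deque<T>::begin |
| Contiguous Iterator | Everything a Random Access Iterator offers, plus the requirement that adjacent elements are contiguous in memory | |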

Here, "arbitrary jumps" means that the following operations are supported:

  • iter + n, iter - n, iter += n, iter -= n, and iter1 - iter2
  • iter1 > iter2, iter1 < iter2, iter1 >= iter2, and iter1 <= iter2

A Bidirectional Iterator can achieve the effect of iter + n by applying ++ or -- repeatedly, but on a Random Access Iterator these operations run in O(1) time: skipping N elements costs the same as skipping one.
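As a small illustration of that difference, the sketch below fetches the fourth element of a std::vector and of a std::list. The helper names fourth_of_vector and fourth_of_list are made up for this example.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <vector>

// Random access: a single O(1) jump with operator+.
inline int fourth_of_vector(const std::vector<int>& v) {
    return *(v.begin() + 3);
}

// Bidirectional: std::next has to apply ++ three times, so the cost grows with n.
// Writing l.begin() + 3 here would not compile, because std::list iterators
// do not provide operator+.
inline int fourth_of_list(const std::list<int>& l) {
    return *std::next(l.begin(), 3);
}
```

Both helpers assume the container holds at least four elements.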

[Figure: the iterator categories]

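Note that the iterator categories are not derived from one another through inheritance; the relationship is purely one of logical inclusion (the iterator tags discussed below, however, do express it through inheritance). The operations each category has to implement are as follows: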

[Figure: operations required by each iterator category]

Iterator in C++

Up to and including C++17, an iterator's category is communicated through tag dispatch; from C++20 onward, concepts can be used to constrain it instead.

struct input_iterator_tag { };
struct output_iterator_tag { };
struct forward_iterator_tag : public input_iterator_tag { };
struct bidirectional_iterator_tag : public forward_iterator_tag { };
struct random_access_iterator_tag : public bidirectional_iterator_tag { };
struct contiguous_iterator_tag: public random_access_iterator_tag { };
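The inheritance between the tags is what makes tag dispatch fall back gracefully: a tag converts implicitly to any tag it derives from, and overload resolution prefers the closest base. The sketch below checks this with the standard tags; the overload set named which is hypothetical and exists only for this illustration.

```cpp
#include <cassert>
#include <iterator>
#include <string>
#include <type_traits>

// A "stronger" tag derives from every "weaker" tag below it.
static_assert(std::is_base_of_v<std::input_iterator_tag, std::forward_iterator_tag>);
static_assert(std::is_base_of_v<std::forward_iterator_tag, std::random_access_iterator_tag>);
// output_iterator_tag stands apart from the rest of the hierarchy.
static_assert(!std::is_base_of_v<std::output_iterator_tag, std::forward_iterator_tag>);

// Hypothetical overload set: only a forward and a bidirectional version exist.
inline std::string which(std::forward_iterator_tag) { return "forward"; }
inline std::string which(std::bidirectional_iterator_tag) { return "bidirectional"; }
```

Passing std::random_access_iterator_tag{} selects the bidirectional overload, because that is the nearest base class, while passing std::input_iterator_tag{} does not compile at all.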

An example from cppreference:

#include <iostream>
#include <vector>
#include <list>
#include <iterator>

// Using concepts (tag checking is part of the concepts themselves)

template<std::bidirectional_iterator BDIter>
void alg(BDIter, BDIter)
{
    std::cout << "1. alg() \t called for bidirectional iterator\n";
}

template<std::random_access_iterator RAIter>
void alg(RAIter, RAIter)
{
    std::cout << "2. alg() \t called for random-access iterator\n";
}

// Legacy, using tag dispatch

namespace legacy {

// quite often implementation details are hidden in a dedicated namespace
namespace implementation_details {
template<class BDIter>
void alg(BDIter, BDIter, std::bidirectional_iterator_tag)
{
    std::cout << "3. legacy::alg() called for bidirectional iterator\n";
}

template<class RAIter>
void alg(RAIter, RAIter, std::random_access_iterator_tag)
{
    std::cout << "4. legacy::alg() called for random-access iterator\n";
}
} // namespace implementation_details

template<class Iter>
void alg(Iter first, Iter last)
{
    implementation_details::alg(first, last,
        typename std::iterator_traits<Iter>::iterator_category());
}

} // namespace legacy

int main()
{
    std::list<int> l;
    alg(l.begin(), l.end()); // 1.
    legacy::alg(l.begin(), l.end()); // 3.

    std::vector<int> v;
    alg(v.begin(), v.end()); // 2.
    legacy::alg(v.begin(), v.end()); // 4.

    // std::istreambuf_iterator<char> i1(std::cin), i2;
    // alg(i1, i2);         // compile error: no matching function for call
    // legacy::alg(i1, i2); // compile error: no matching function for call
}

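This article sticks with tag dispatch for its explanations. An iterator's tag can be obtained through iterator_traits, provided by <iterator>, and compared against a known tag, for example: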

#include <forward_list>
#include <iostream>
#include <iterator>
#include <type_traits>
#include <vector>

template<typename T>
constexpr bool test = std::is_same <
                typename std::iterator_traits<typename T::iterator>::iterator_category,
                std::forward_iterator_tag
            >::value;

int main() {
    // test whether std::forward_list's iterator is a forward_iterator
    std::cout << std::boolalpha << test<std::forward_list<int>> << std::endl;   // true
    // what about std::vector? (its tag is random_access_iterator_tag)
    std::cout << std::boolalpha << test<std::vector<int>> << std::endl;         // false
}
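An exact std::is_same comparison answers "is the tag precisely this one?". When the question is rather "is this iterator at least a forward iterator?", the tag inheritance can be used instead. The following is a minimal sketch of that idea; the name is_at_least_forward is an assumption of this example, not a standard facility.

```cpp
#include <cassert>
#include <forward_list>
#include <iterator>
#include <type_traits>
#include <vector>

// True when Iter's category is forward_iterator_tag or any tag derived from it.
template <typename Iter>
constexpr bool is_at_least_forward = std::is_base_of_v<
    std::forward_iterator_tag,
    typename std::iterator_traits<Iter>::iterator_category>;

static_assert(is_at_least_forward<std::forward_list<int>::iterator>);
static_assert(is_at_least_forward<std::vector<int>::iterator>);
static_assert(is_at_least_forward<int*>);
// An insert iterator is an output iterator, so it does not qualify.
static_assert(!is_at_least_forward<std::back_insert_iterator<std::vector<int>>>);
```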

Ways to provide a custom iterator

Iterators from Nested Containers

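If one of our classes holds an STL container internally and wants to expose that container's iterators, it can simply reuse the container's own iterator types. For example, a course keeps a vector holding the list of all students enrolled in it: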

#include <algorithm>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

class Course {
private:
    using StudentList = std::vector<int>;
    StudentList students;

public:
    using iterator_category = std::random_access_iterator_tag;
    using iterator = StudentList::iterator;
    using const_iterator = StudentList::const_iterator;
    iterator begin() { return students.begin(); }
    iterator end() { return students.end(); }

    Course(StudentList&& student_list) : students(std::move(student_list)) {}
};

int main() {
    Course course({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});

    for (const auto& student : course) {
        std::cout << student << '\n';
    }

    std::ostream_iterator<int> outIter(std::cout, "\n");
    std::copy(course.begin(), course.end(), outIter);

    for (auto iter = course.begin(); iter < course.end(); iter += 2) {
        std::cout << *iter << '\n';
    }
    return 0;
}

This is essentially delegation.
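The class above declares const_iterator but only offers non-const begin() and end(), so a const Course cannot be iterated. One possible way to extend the delegation, sketched here under the assumption that read-only access is wanted, is to add const overloads that forward to the container as well. The helper sum_of_ids is hypothetical and only demonstrates iteration over a const object.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

class Course {
    using StudentList = std::vector<int>;
    StudentList students;

public:
    using iterator = StudentList::iterator;
    using const_iterator = StudentList::const_iterator;

    explicit Course(StudentList student_list) : students(std::move(student_list)) {}

    iterator begin() { return students.begin(); }
    iterator end() { return students.end(); }
    // const overloads delegate to the container's const_iterator
    const_iterator begin() const { return students.begin(); }
    const_iterator end() const { return students.end(); }
};

// Iterates a const Course, which only works because of the const overloads.
inline int sum_of_ids(const Course& course) {
    return std::accumulate(course.begin(), course.end(), 0);
}
```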

Iterators from Pointers

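What if the student list is stored in a C-style array instead, such as int students[10];?

(If it were a std::array, the approach would be the same as above: use std::array's iterators directly.)

In this case a raw pointer can serve as the iterator; in fact, most iterator classes are built on top of pointers.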

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <iterator>

class Course {
private:
    int students[10];

public:
    using iterator_category = std::random_access_iterator_tag;
    using iterator = int*;
    using const_iterator = const int*;
    iterator begin() { return &students[0]; }
    iterator end() { return &students[10]; }

    Course(std::initializer_list<int> student_list) {
        assert(student_list.size() <= 10);
        size_t i = 0;
        for (const auto& student : student_list) {
            students[i++] = student;
        }
    }
};

int main() {
    Course course({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});

    for (const auto& student : course) {
        std::cout << student << '\n';
    }

    std::ostream_iterator<int> outIter(std::cout, "\n");
    std::copy(course.begin(), course.end(), outIter);

    for (auto iter = course.begin(); iter < course.end(); iter += 2) {
        std::cout << *iter << '\n';
    }
    return 0;
}

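Of course, the code above has a problem: the iterator has no way of knowing how many elements the array actually holds. If the course is initialized with {1,2,3,4,5}, the remaining slots contain indeterminate values.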
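One way to address this, shown here as a sketch rather than as the only fix, is to remember how many slots were filled and let end() point just past the last filled one. The member name size_ is an assumption of this example.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>

class Course {
    int students[10] = {};   // zero-initialized, so no slot is ever indeterminate
    std::size_t size_ = 0;   // number of slots actually filled

public:
    using iterator = int*;
    using const_iterator = const int*;

    Course(std::initializer_list<int> student_list) {
        assert(student_list.size() <= 10);
        for (int student : student_list) {
            students[size_++] = student;
        }
    }

    iterator begin() { return students; }
    iterator end() { return students + size_; }   // one past the last filled slot
    const_iterator begin() const { return students; }
    const_iterator end() const { return students + size_; }
};
```

With this change, a range-based for loop over Course{1, 2, 3, 4, 5} visits exactly five elements.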

What is iterator_category

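Both examples above declare iterator_category. Note that user code never mentions iterator_category directly; the standard library reads it through std::iterator_traits, on the iterator type rather than on the container. For std::vector<int>::iterator and int* the traits already report random_access_iterator_tag, so the alias inside Course only serves as documentation. The category matters in two situations:

  • Algorithms pick an implementation according to the iterator category. If the category is not specified correctly, a suboptimal algorithm may be used.
  • Some STL algorithms place requirements on the iterator category: std::fill, for example, requires a Forward Iterator, while std::reverse requires a Bidirectional Iterator. If the category is not specified correctly, compilation may fail.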
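The second point can be seen with the standard containers themselves. In the sketch below, std::reverse accepts the bidirectional iterators of std::list, while the forward iterators of std::forward_list are rejected and the container's own reverse() member is used instead. The overloaded helper reversed_copy is a name invented for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <forward_list>
#include <list>

inline std::list<int> reversed_copy(std::list<int> l) {
    std::reverse(l.begin(), l.end());   // fine: std::list iterators are bidirectional
    return l;
}

inline std::forward_list<int> reversed_copy(std::forward_list<int> fl) {
    // std::reverse(fl.begin(), fl.end());  // rejected: forward iterators have no operator--
    fl.reverse();                           // the member function works instead
    return fl;
}
```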

As an example, consider moving a Random Access Iterator forward by n positions. Because the relevant operators are defined, this can be done directly:

template<typename Iter>
void advance_helper(Iter& p, int n, std::random_access_iterator_tag) {
    p += n;   // a single O(1) jump; p is taken by reference so the caller sees the change
}

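For a Forward Iterator, on the other hand, the only option is to step forward one element at a time, n times (n cannot be negative, since a forward iterator has no operator--):

template<typename Iter>
void advance_helper(Iter& p, int n, std::forward_iterator_tag) {
    // only ++ is available, so nothing happens for n <= 0
    while (n-- > 0) ++p;
}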

We can then define a single advance that calls the matching helper according to the tag.

template<typename Iter>
void advance(Iter& p, int n) {   // use the optimal algorithm
    advance_helper(p, n, typename std::iterator_traits<Iter>::iterator_category{});
}
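Putting the pieces together, a complete version of this dispatch might look like the sketch below. It is placed in a namespace called sketch (an arbitrary name) so that it cannot be confused with std::advance, it takes the iterator by reference so the caller observes the move, and it adds a bidirectional overload to handle negative n.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <vector>

namespace sketch {

// Input and forward iterators: only ++ is available, so n must be non-negative.
template <typename Iter>
void advance_helper(Iter& p, int n, std::input_iterator_tag) {
    while (n-- > 0) ++p;
}

// Bidirectional iterators: one step at a time, in either direction.
template <typename Iter>
void advance_helper(Iter& p, int n, std::bidirectional_iterator_tag) {
    for (; n > 0; --n) ++p;
    for (; n < 0; ++n) --p;
}

// Random access iterators: a single O(1) jump.
template <typename Iter>
void advance_helper(Iter& p, int n, std::random_access_iterator_tag) {
    p += n;
}

// Selects the best helper at compile time from the iterator's category tag.
template <typename Iter>
void advance(Iter& p, int n) {
    advance_helper(p, n, typename std::iterator_traits<Iter>::iterator_category{});
}

}  // namespace sketch
```

A std::list iterator ends up in the bidirectional helper and a std::vector iterator in the random access one; no runtime branching on the category is involved.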

Besides iterator_category, all the properties associated with an iterator are defined in iterator_traits (https://en.cppreference.com/w/cpp/iterator/iterator_traits).

template<typename Iter> struct iterator_traits {
    using value_type = typename Iter::value_type;
    using difference_type = typename Iter::difference_type;         // type used by distance, e.g. (iter1 - iter2)
    using pointer = typename Iter::pointer;                         // pointer type
    using reference = typename Iter::reference;                     // reference type
    using iterator_category = typename Iter::iterator_category;     // tag
};
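The primary template shown above reads nested aliases from Iter, which a raw pointer cannot have. The standard library covers that case with a partial specialization for pointers, which is why int* works as an iterator without any extra code. The sketch below checks what the traits report for int*; the function sum_range is a hypothetical generic algorithm that names the element type through the traits only.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <type_traits>

// Raw pointers have no nested aliases, yet std::iterator_traits still answers for them.
using IntPtrTraits = std::iterator_traits<int*>;
static_assert(std::is_same_v<IntPtrTraits::value_type, int>);
static_assert(std::is_same_v<IntPtrTraits::difference_type, std::ptrdiff_t>);
static_assert(std::is_same_v<IntPtrTraits::pointer, int*>);
static_assert(std::is_same_v<IntPtrTraits::reference, int&>);
static_assert(std::is_same_v<IntPtrTraits::iterator_category, std::random_access_iterator_tag>);

// Generic code can name the element type without knowing what Iter is.
template <typename Iter>
typename std::iterator_traits<Iter>::value_type sum_range(Iter first, Iter last) {
    typename std::iterator_traits<Iter>::value_type total{};
    for (; first != last; ++first) total += *first;
    return total;
}
```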

Let's look at these properties through concrete examples.

Example of custom forward iterator

The following implements a forward iterator for MultiRowNum (multiple rows of data, where each row may hold zero or more numbers). For example, given the input

1

2 3

4 5 6

the output should be 1,2,3,4,5,6 (the empty rows in between are skipped).

Here we implement only the operators a forward iterator needs: ->, *, ++, == and != (iterator_category is forward_iterator_tag). The iterator yields int values as it advances.

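#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>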

class MultiRowNum {
  using Row = std::vector<int>;
  using MultiRow = std::vector<Row>;

 public:
  class Iterator {
   public:
    // assign the properties of the iterator
    using iterator_category = std::forward_iterator_tag;
    using value_type = int;
    // used for some algorithm, e.g. std::count, std::distance
    using difference_type = std::ptrdiff_t;
    using pointer = value_type*;
    using reference = value_type&;

    Iterator(MultiRow* rows, MultiRow::iterator iter) : rows_(rows), iter_(iter) {
      // skip empty lines at the first few rows
      while (iter_ != rows_->end() && idx_ == iter_->size()) {
        ++iter_;
        idx_ = 0;
      }
    }

    // address of the element the iterator points to
    pointer operator->() const {
      return &(*iter_)[idx_];
    }

    // value of the element the iterator points to
    reference operator*() const {
      return (*iter_)[idx_];
    }

    // prefix increment
    Iterator& operator++() {
      moveToNextNum();
      return *this;
    }

    // postfix increment
    Iterator operator++(int) {
      Iterator ret = *this;
      moveToNextNum();
      return ret;
    }

    bool operator==(const Iterator& other) const {
      return iter_ == other.iter_ && idx_ == other.idx_;
    };

    bool operator!=(const Iterator& other) const {
      return !(*this == other);
    };

   private:
    void moveToNextNum() {
      if (++idx_ < iter_->size()) {
        return;
      } else {
        idx_ = 0;
        do {
          ++iter_;
        } while (iter_ != rows_->end() && idx_ == iter_->size());
      }
    }

    MultiRow* rows_;
    MultiRow::iterator iter_;
    size_t idx_{0};
  };

  Iterator begin() {
    return Iterator(&data_, data_.begin());
  }

  Iterator end() {
    return Iterator(&data_, data_.end());
  }

  MultiRowNum(MultiRow&& data) : data_(std::move(data)) {}

 private:
  MultiRow data_;
};

int main() {
  {
    MultiRowNum nums({ {1}, {}, {2, 3}, {}, {4, 5, 6}, {} });
    for (auto iter = nums.begin(); iter != nums.end(); iter++) {
      std::cout << *iter << ' ';
    }
    std::cout << "\n";
  }
  {
    MultiRowNum nums({ {}, {1}, {}, {2, 3}, {}, {4, 5, 6}, {} });
    for (const auto& num : nums) {
      std::cout << num << ' ';
    }
    std::cout << "\n";
  }
  {
    MultiRowNum nums({ {1}, {2, 3}, {4, 5, 6} });
    auto iter = nums.begin();
    assert(*(iter.operator->()) == *iter);
    assert(*(iter.operator->()) == iter.operator*());
    assert(iter.operator->() == &(iter.operator*()));
    assert(iter.operator->() == std::addressof(iter.operator*()));
    assert(std::count(nums.begin(), nums.end(), 1) == 1);
    assert(std::distance(nums.begin(), nums.end()) == 6);
  }
  static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::iterator_category, std::forward_iterator_tag>);
  static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::value_type, int>);
  static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::difference_type, std::ptrdiff_t>);
  static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::pointer, int*>);
  static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::reference, int&>);
}

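The code is straightforward. The one thing to note is that the iterator skips empty rows, so during construction and in ++ it may have to move several times before it points at a valid position.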

In addition, if difference_type is not specified, algorithms such as std::count or std::distance fail to compile:

forward_iterator.cpp:111:12: error: no matching function for call to 'count'
    assert(std::count(nums.begin(), nums.end(), 1) == 1);
           ^~~~~~~~~~
/Library/Developer/CommandLineTools/SDKs/MacOSX12.1.sdk/usr/include/assert.h:99:25: note: expanded from macro 'assert'
    (__builtin_expect(!(e), 0) ? __assert_rtn(__func__, __ASSERT_FILE_NAME, __LINE__, #e) : (void)0)
                        ^
/Library/Developer/CommandLineTools/SDKs/MacOSX12.1.sdk/usr/include/c++/v1/algorithm:1185:1: note: candidate template ignored: substitution failure [with _InputIterator = MultiRowNum::Iterator, _Tp = int]: no type named 'difference_type' in 'std::iterator_traits<MultiRowNum::Iterator>'
count(_InputIterator __first, _InputIterator __last, const _Tp& __value_)
^
/Library/Developer/CommandLineTools/SDKs/MacOSX12.1.sdk/usr/include/c++/v1/__bit_reference:313:1: note: candidate template ignored: could not match '__bit_iterator<type-parameter-0-0, _IsConst, 0>' against 'MultiRowNum::Iterator'
count(__bit_iterator<_Cp, _IsConst> __first, __bit_iterator<_Cp, _IsConst> __last, const _Tp& __value_)

Example of random access iterator

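Below is an iterator for a class similar to folly's Range: Range<15, 25>, for example, yields every number in [15, 25]. This time we implement a Random Access Iterator, which requires more operators. Note that only ++ and -- take the direction of the range into account; the arithmetic and comparison operators assume an ascending range (FROM <= TO).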

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <type_traits>

template <long FROM, long TO>
class Range {
 public:
  class Iterator {
    long num_ = FROM;

   public:
    // assign the properties of the iterator
    using iterator_category = std::random_access_iterator_tag;
    // used for some algorithm, e.g. std::count, std::distance
    using difference_type = long;
    using value_type = long;
    using pointer = const value_type*;
    using reference = const value_type&;

    explicit Iterator(long num) : num_(num) {}

    pointer operator->() const {
      return &num_;
    }

    reference operator*() const {
      return num_;
    }

    Iterator& operator++() {
      num_ = TO >= FROM ? num_ + 1 : num_ - 1;
      return *this;
    }

    Iterator& operator--() {
      num_ = TO >= FROM ? num_ - 1 : num_ + 1;
      return *this;
    }

    Iterator operator++(int) {
      Iterator retval = *this;
      ++(*this);
      return retval;
    }

    Iterator operator--(int) {
      Iterator retval = *this;
      --(*this);
      return retval;
    }

    bool operator==(const Iterator& rhs) const {
      return num_ == rhs.num_;
    }

    bool operator!=(const Iterator& rhs) const {
      return !(*this == rhs);
    }

    bool operator<(const Iterator& rhs) const {
      return num_ < *rhs;
    }

    bool operator<=(const Iterator& rhs) const {
      return !(rhs < *this);
    }

    bool operator>(const Iterator& rhs) const {
      return rhs < *this;
    }

    bool operator>=(const Iterator& rhs) const {
      return !(*this < rhs);
    }

    Iterator operator+(const difference_type diff) const {
      return Iterator(num_ + diff);
    }

    Iterator& operator+=(const difference_type diff) {
      num_ += diff;
      return *this;
    }

    Iterator operator-(const difference_type diff) const {
      return Iterator(num_ - diff);
    }

    Iterator& operator-=(const difference_type diff) {
      num_ -= diff;
      return *this;
    }

    difference_type operator-(const Iterator& rhs) const {
      return num_ - *rhs;
    }

    // subscript takes a (possibly negative) offset and does not modify the iterator
    value_type operator[](difference_type idx) const {
      return num_ + idx;
    }
  };

  Iterator begin() {
    return Iterator(FROM);
  }
  Iterator end() {
    return Iterator(TO >= FROM ? TO + 1 : TO - 1);
  }
};

int main() {
  {
    auto range = Range<15, 25>();
    auto iter = std::find(range.begin(), range.end(), 18);
    std::cout << *iter << '\n';
  }
  {
    for (auto l : Range<5, 3>()) {
      std::cout << l << ' ';
    }
    std::cout << '\n';
  }
  {
    auto rng = Range<1, 10>();
    auto iter = rng.begin();
    assert(*(iter.operator->()) == iter.operator*());
    assert(*(iter.operator->()) == *iter);
    assert(iter[5] == 1 + 5);
    iter += 10;
    assert(iter == rng.end());
    iter = std::prev(iter, 10);
    assert(iter == rng.begin());
    assert(std::count(rng.begin(), rng.end(), 8) == 1);
    assert(std::distance(rng.begin(), rng.end()) == 10);
    std::cout << "*(iter++): " << *(iter++) << '\n';
    std::cout << "*(iter--): " << *(iter--) << '\n';
    std::cout << "*(++iter): " << *(++iter) << '\n';
    std::cout << "*(--iter): " << *(--iter) << '\n';
    std::cout << std::boolalpha
              << "rng.begin() < rng.begin() + 1: " << (rng.begin() < rng.begin() + 1) << '\n';
    std::cout << std::boolalpha
              << "rng.begin() + 10 <= rng.end(): " << (rng.begin() + 10 <= rng.end()) << '\n';
    std::cout << std::boolalpha << "rng.end() > rng.end() - 1: " << (rng.end() > rng.end() - 1)
              << '\n';
    std::cout << std::boolalpha
              << "rng.end() - 10 >= rng.begin(): " << (rng.end() - 10 >= rng.begin()) << '\n';
  }
  static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::iterator_category, std::random_access_iterator_tag>);
  static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::value_type, long>);
  static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::difference_type, long>);
  static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::pointer, const long*>);
  static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::reference, const long&>);
}
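The payoff of declaring random_access_iterator_tag is that standard algorithms can jump instead of stepping. As an illustration, the sketch below runs std::lower_bound over a range of numbers that is never stored anywhere. CountingIter is a hypothetical, ascending-only relative of the Range iterator, and ceil_sqrt is a helper invented for this example; dereferencing returns the number by value, which is a simplification.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>

// Hypothetical ascending-only counting iterator; it stores nothing but the number.
class CountingIter {
    long num_ = 0;

public:
    using iterator_category = std::random_access_iterator_tag;
    using value_type = long;
    using difference_type = std::ptrdiff_t;
    using pointer = const long*;
    using reference = long;  // dereferencing returns the number by value

    CountingIter() = default;
    explicit CountingIter(long num) : num_(num) {}

    reference operator*() const { return num_; }
    reference operator[](difference_type n) const { return num_ + n; }

    CountingIter& operator++() { ++num_; return *this; }
    CountingIter operator++(int) { CountingIter old = *this; ++num_; return old; }
    CountingIter& operator--() { --num_; return *this; }
    CountingIter operator--(int) { CountingIter old = *this; --num_; return old; }
    CountingIter& operator+=(difference_type n) { num_ += n; return *this; }
    CountingIter& operator-=(difference_type n) { num_ -= n; return *this; }

    friend CountingIter operator+(CountingIter it, difference_type n) { return it += n; }
    friend CountingIter operator+(difference_type n, CountingIter it) { return it += n; }
    friend CountingIter operator-(CountingIter it, difference_type n) { return it -= n; }
    friend difference_type operator-(CountingIter a, CountingIter b) { return a.num_ - b.num_; }

    friend bool operator==(CountingIter a, CountingIter b) { return a.num_ == b.num_; }
    friend bool operator!=(CountingIter a, CountingIter b) { return a.num_ != b.num_; }
    friend bool operator<(CountingIter a, CountingIter b) { return a.num_ < b.num_; }
    friend bool operator>(CountingIter a, CountingIter b) { return a.num_ > b.num_; }
    friend bool operator<=(CountingIter a, CountingIter b) { return a.num_ <= b.num_; }
    friend bool operator>=(CountingIter a, CountingIter b) { return a.num_ >= b.num_; }
};

// Smallest x in [0, limit) with x * x >= target, or limit when no such x exists.
inline long ceil_sqrt(long target, long limit) {
    const CountingIter it = std::lower_bound(
        CountingIter(0), CountingIter(limit), target,
        [](long x, long t) { return x * x < t; });
    return *it;
}
```

Because the iterator is random access, the binary search needs only a logarithmic number of jumps; with a forward-only tag the same call would still compile but would advance element by element.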

iterator_traits may get a follow-up post of its own; that's all for this time.
