How to write custom iterator
Iterator categories
C++ currently defines six iterator categories:
Name | Description | Example |
---|---|---|
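Input Iterator | Single-pass, forward-only traversal; read-only | std::istream_iterator |
Output Iterator | Single-pass, forward-only traversal; write-only | std::ostream_iterator / std::front_inserter / std::back_inserter / std::inserter |
Forward Iterator | Multi-pass forward traversal; readable and writable | std::forward_list<T>::begin |
Bidirectional Iterator | Multi-pass forward and backward traversal; readable and writable | std::list<T>::begin |
Random Access Iterator | Everything a Bidirectional Iterator offers, plus arbitrary jumps | std::vector<T>::begin / std::array<T>::begin / std::deque<T>::begin |
Contiguous Iterator | Everything a Random Access Iterator offers, plus the requirement that adjacent elements are contiguous in memory | |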
Here, "arbitrary jumps" means support for the following operations:
iter + n, iter - n, iter += n, iter -= n, and iter1 - iter2
iter1 > iter2, iter1 < iter2, iter1 >= iter2, and iter1 <= iter2
A Bidirectional Iterator can achieve something like iter + n by applying ++ or -- repeatedly, but for a Random Access Iterator the operation is O(1): skipping N elements takes the same time as skipping one.
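To make the difference concrete, here is a minimal sketch. The helper names nth, nth_random_access and check_nth are made up for illustration. std::next accepts any forward-or-better iterator and only jumps in one step when the iterator is random access; the second helper uses operator+ directly, so it only compiles for random access iterators.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <vector>

// Works for any forward-or-better iterator. For a std::list iterator this
// walks n single steps; for a std::vector iterator it is one O(1) jump.
template <typename Iter>
Iter nth(Iter first, typename std::iterator_traits<Iter>::difference_type n) {
    return std::next(first, n);
}

// Uses operator+ directly, so it only compiles for random access iterators.
template <typename Iter>
Iter nth_random_access(Iter first,
                       typename std::iterator_traits<Iter>::difference_type n) {
    return first + n;
}

// Small self-check tying the two helpers together.
inline void check_nth() {
    std::list<int> l{10, 20, 30, 40};
    std::vector<int> v{10, 20, 30, 40};
    assert(*nth(l.begin(), 2) == 30);                // two single steps
    assert(*nth(v.begin(), 2) == 30);                // one jump
    assert(*nth_random_access(v.begin(), 2) == 30);  // one jump
    // nth_random_access(l.begin(), 2);  // would not compile: no operator+
}
```

Both calls return the same element; only the cost of getting there differs.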
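Note that the iterator categories themselves are not derived from one another through inheritance; the relationship between them is purely one of logical containment (the iterator tags introduced below, however, do express it through inheritance). Each category requires its own set of operations.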
Iterator in C++
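Up to and including C++17, an iterator's category is specified through tag dispatch; from C++20 onwards, concepts are used to constrain it. The standard tags are declared as follows (contiguous_iterator_tag was added in C++20):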
struct input_iterator_tag { };
struct output_iterator_tag { };
struct forward_iterator_tag : public input_iterator_tag { };
struct bidirectional_iterator_tag : public forward_iterator_tag { };
struct random_access_iterator_tag : public bidirectional_iterator_tag { };
struct contiguous_iterator_tag: public random_access_iterator_tag { };
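Because the tags inherit from one another, a tag object converts implicitly to any of its base tags, and overload resolution prefers the closest base. The sketch below illustrates this with two hypothetical functions, describe and describe_iterator, which are not part of the standard library.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <type_traits>
#include <vector>

// The inheritance chain can be checked at compile time.
static_assert(std::is_base_of_v<std::input_iterator_tag, std::forward_iterator_tag>);
static_assert(std::is_base_of_v<std::forward_iterator_tag, std::random_access_iterator_tag>);
// output_iterator_tag stands apart from the chain.
static_assert(!std::is_base_of_v<std::output_iterator_tag, std::forward_iterator_tag>);

// Hypothetical overloads: there is none for the bidirectional or random
// access tags, so those fall back to the closest base, forward_iterator_tag.
inline std::string describe(std::input_iterator_tag) { return "input only"; }
inline std::string describe(std::forward_iterator_tag) { return "forward or better"; }

template <typename Iter>
std::string describe_iterator(Iter) {
    return describe(typename std::iterator_traits<Iter>::iterator_category{});
}

inline void check_describe() {
    std::vector<int> v{1};
    std::list<int> l{1};
    assert(describe_iterator(v.begin()) == "forward or better");
    assert(describe_iterator(l.begin()) == "forward or better");
    assert(describe(std::input_iterator_tag{}) == "input only");
}
```

This fallback is what lets an algorithm written for forward iterators accept a std::vector iterator without a dedicated overload.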
An example from cppreference:
#include <iostream>
#include <vector>
#include <list>
#include <iterator>
// Using concepts (tag checking is part of the concepts themselves)
template<std::bidirectional_iterator BDIter>
void alg(BDIter, BDIter)
{
std::cout << "1. alg() \t called for bidirectional iterator\n";
}
template<std::random_access_iterator RAIter>
void alg(RAIter, RAIter)
{
std::cout << "2. alg() \t called for random-access iterator\n";
}
// Legacy, using tag dispatch
namespace legacy {
// quite often implementation details are hidden in a dedicated namespace
namespace implementation_details {
template<class BDIter>
void alg(BDIter, BDIter, std::bidirectional_iterator_tag)
{
std::cout << "3. legacy::alg() called for bidirectional iterator\n";
}
template<class RAIter>
void alg(RAIter, RAIter, std::random_access_iterator_tag)
{
std::cout << "4. legacy::alg() called for random-access iterator\n";
}
} // namespace implementation_details
template<class Iter>
void alg(Iter first, Iter last)
{
implementation_details::alg(first, last,
typename std::iterator_traits<Iter>::iterator_category());
}
} // namespace legacy
int main()
{
std::list<int> l;
alg(l.begin(), l.end()); // 1.
legacy::alg(l.begin(), l.end()); // 3.
std::vector<int> v;
alg(v.begin(), v.end()); // 2.
legacy::alg(v.begin(), v.end()); // 4.
// std::istreambuf_iterator<char> i1(std::cin), i2;
// alg(i1, i2); // compile error: no matching function for call
// legacy::alg(i1, i2); // compile error: no matching function for call
}
This post still uses tag dispatch for its explanations. An iterator's tag can be queried and compared through iterator_traits from <iterator>, for example:
#include <forward_list>
#include <iostream>
#include <iterator>
#include <type_traits>
#include <vector>

// true if the container's iterator is tagged exactly as a forward iterator
template <typename T>
constexpr bool test = std::is_same<
    typename std::iterator_traits<typename T::iterator>::iterator_category,
    std::forward_iterator_tag>::value;

int main() {
    // test whether std::forward_list's iterator is a forward iterator
    std::cout << std::boolalpha << test<std::forward_list<int>> << std::endl;
    // what about std::vector? Its tag is random_access_iterator_tag, so this prints false
    std::cout << std::boolalpha << test<std::vector<int>> << '\n';
}
Ways to provide custom iterator
Iterators from Nested Containers
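If a class uses an STL container internally and wants to expose that container's iterators, it can use the container's iterator types directly. For example, a course keeps a vector of all the students enrolled in it: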
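#include <algorithm>  // std::copy
#include <iostream>
#include <iterator>   // std::ostream_iterator, iterator tags
#include <utility>    // std::move
#include <vector>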
class Course {
private:
using StudentList = std::vector<int>;
StudentList students;
public:
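// informational only: algorithms read the category from std::iterator_traits<iterator>, not from Course
using iterator_category = std::random_access_iterator_tag;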
using iterator = StudentList::iterator;
using const_iterator = StudentList::const_iterator;
iterator begin() { return students.begin(); }
iterator end() { return students.end(); }
Course(StudentList&& student_list) : students(std::move(student_list)) {}
};
int main() {
Course course({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
for (const auto& student : course) {
std::cout << student << '\n';
}
std::ostream_iterator<int> outIter(std::cout, "\n");
std::copy(course.begin(), course.end(), outIter);
for (auto iter = course.begin(); iter < course.end(); iter += 2) {
std::cout << *iter << '\n';
}
return 0;
}
This is simply delegation.
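The class above declares const_iterator but never returns one, so a const Course cannot be iterated. A possible extension, sketched here under the assumption that read-only access should also be delegated, adds const overloads of begin/end plus cbegin/cend. The function sum_of_ids and the check function are hypothetical and exist only to exercise the const path.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

class Course {
    using StudentList = std::vector<int>;
    StudentList students_;

public:
    using iterator = StudentList::iterator;
    using const_iterator = StudentList::const_iterator;

    explicit Course(StudentList students) : students_(std::move(students)) {}

    // Mutable access delegates to the vector's iterator.
    iterator begin() { return students_.begin(); }
    iterator end() { return students_.end(); }
    // Read-only access delegates to the vector's const_iterator.
    const_iterator begin() const { return students_.begin(); }
    const_iterator end() const { return students_.end(); }
    const_iterator cbegin() const { return students_.cbegin(); }
    const_iterator cend() const { return students_.cend(); }
};

// Takes the course by const reference, so only the const overloads are used.
inline int sum_of_ids(const Course& course) {
    return std::accumulate(course.begin(), course.end(), 0);
}

inline void check_const_course() {
    const Course course(std::vector<int>{1, 2, 3, 4});
    assert(sum_of_ids(course) == 10);
    int count = 0;
    for (int id : course) count += (id > 0);  // range-for works on a const object too
    assert(count == 4);
}
```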
Iterators from Pointers
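What if the student list is stored in a C-style array instead, such as int students[10];?
With std::array the approach is the same as above: use std::array's iterators directly.
In this case a raw pointer can serve as the iterator; in fact, most iterator classes are built on top of pointers.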
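#include <algorithm>         // std::copy
#include <cassert>
#include <cstddef>           // size_t
#include <initializer_list>
#include <iostream>
#include <iterator>          // std::ostream_iterator, iterator tags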
class Course {
private:
int students[10];
public:
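// informational only: algorithms read the category from std::iterator_traits<int*>, not from Course
using iterator_category = std::random_access_iterator_tag;
using iterator = int*;
using const_iterator = const int*;
iterator begin() { return students; }
iterator end() { return students + 10; }  // one past the last slot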
Course(std::initializer_list<int> student_list) {
assert(student_list.size() <= 10);
size_t i = 0;
for (const auto& student : student_list) {
students[i++] = student;
}
}
};
int main() {
Course course({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
for (const auto& student : course) {
std::cout << student << '\n';
}
std::ostream_iterator<int> outIter(std::cout, "\n");
std::copy(course.begin(), course.end(), outIter);
for (auto iter = course.begin(); iter < course.end(); iter += 2) {
std::cout << *iter << '\n';
}
return 0;
}
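Of course, the problem with the code above is that the iterator has no way of knowing how many elements the array actually holds: if the course is initialized with {1,2,3,4,5}, the remaining five slots hold indeterminate values.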
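One way around this, sketched here as an assumption rather than as the only fix, is to track how many slots are in use and let end() point one past the last used slot. The class name SizedCourse, the members kCapacity and size_, and the check function are introduced purely for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <numeric>

class SizedCourse {
    static constexpr std::size_t kCapacity = 10;
    int students_[kCapacity] = {};  // value-initialized, so no indeterminate slots
    std::size_t size_ = 0;          // number of slots actually in use

public:
    using iterator = int*;
    using const_iterator = const int*;

    SizedCourse(std::initializer_list<int> ids) {
        assert(ids.size() <= kCapacity);
        for (int id : ids) students_[size_++] = id;
    }

    iterator begin() { return students_; }
    iterator end() { return students_ + size_; }  // one past the last used slot
    const_iterator begin() const { return students_; }
    const_iterator end() const { return students_ + size_; }
    std::size_t size() const { return size_; }
};

inline void check_sized_course() {
    SizedCourse course{1, 2, 3, 4, 5};
    assert(course.size() == 5);
    assert(course.end() - course.begin() == 5);
    assert(std::accumulate(course.begin(), course.end(), 0) == 15);
}
```

Iteration now stops after the fifth element instead of running over unused slots.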
What is iterator_category
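Both examples above declare iterator_category. User code rarely mentions iterator_category explicitly; the standard library reads it through std::iterator_traits of the iterator type (std::vector's iterators and raw pointers already provide it, so the alias inside Course is informational only). It matters in two situations:
- Algorithms pick an implementation based on the iterator category. If the category is not specified correctly, a suboptimal algorithm may be chosen.
- Some STL algorithms place requirements on the iterator category: std::fill requires a Forward Iterator, while std::reverse requires a Bidirectional Iterator. If the category is not specified correctly, compilation may fail.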
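The second point can be turned into an explicit compile-time check. The sketch below assumes a hypothetical trait has_category_v and a hypothetical wrapper checked_reverse; neither is a standard facility. They show how a category requirement can be stated up front so that a violation produces a readable message.

```cpp
#include <algorithm>
#include <cassert>
#include <forward_list>
#include <iterator>
#include <list>
#include <type_traits>

// Hypothetical helper: true if Iter's category is Tag or derives from Tag.
template <typename Iter, typename Tag>
inline constexpr bool has_category_v =
    std::is_base_of_v<Tag, typename std::iterator_traits<Iter>::iterator_category>;

static_assert(has_category_v<std::list<int>::iterator, std::bidirectional_iterator_tag>);
static_assert(!has_category_v<std::forward_list<int>::iterator, std::bidirectional_iterator_tag>);

// Hypothetical wrapper: states the requirement of std::reverse explicitly.
template <typename Iter>
void checked_reverse(Iter first, Iter last) {
    static_assert(has_category_v<Iter, std::bidirectional_iterator_tag>,
                  "checked_reverse needs at least a bidirectional iterator");
    std::reverse(first, last);
}

inline void check_checked_reverse() {
    std::list<int> l{1, 2, 3};
    checked_reverse(l.begin(), l.end());
    assert(l.front() == 3 && l.back() == 1);
    // std::forward_list<int> f{1, 2, 3};
    // checked_reverse(f.begin(), f.end());  // would trip the static_assert
}
```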
As an example, moving a Random Access Iterator forward by n positions can be done in one step, because the corresponding operators are defined:
template <typename Iter>
void advance_helper(Iter& p, int n, std::random_access_iterator_tag) {
    p += n;  // a single O(1) jump; p is taken by reference so the caller sees the move
}
A Forward Iterator, on the other hand, can only step forward one element at a time, n times:
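template <typename Iter>
void advance_helper(Iter& p, int n, std::forward_iterator_tag) {
    // a forward iterator has no operator--, so only n >= 0 is handled here;
    // moving backwards would need a separate overload for std::bidirectional_iterator_tag
    while (n-- > 0) ++p;
}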
We can then define a single advance that calls the matching helper based on the tag:
template <typename Iter>
void advance(Iter& p, int n) {  // picks the optimal algorithm for Iter
    advance_helper(p, n, typename std::iterator_traits<Iter>::iterator_category{});
}
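For completeness, here is a self-contained sketch of the same dispatch with a third overload for bidirectional iterators, which is where negative distances belong. The namespace sketch and the names my_advance and advance_impl are made up to avoid clashing with std::advance; this is an illustration, not how any particular standard library implements it.

```cpp
#include <cassert>
#include <forward_list>
#include <iterator>
#include <list>
#include <vector>

namespace sketch {

// Forward iterators: step n times; negative n is not supported.
template <typename Iter, typename Dist>
void advance_impl(Iter& it, Dist n, std::forward_iterator_tag) {
    assert(n >= 0);
    while (n-- > 0) ++it;
}

// Bidirectional iterators: step in either direction.
template <typename Iter, typename Dist>
void advance_impl(Iter& it, Dist n, std::bidirectional_iterator_tag) {
    if (n >= 0) {
        while (n-- > 0) ++it;
    } else {
        while (n++ < 0) --it;
    }
}

// Random access iterators: one O(1) jump.
template <typename Iter, typename Dist>
void advance_impl(Iter& it, Dist n, std::random_access_iterator_tag) {
    it += n;
}

template <typename Iter, typename Dist>
void my_advance(Iter& it, Dist n) {
    advance_impl(it, n, typename std::iterator_traits<Iter>::iterator_category{});
}

}  // namespace sketch

inline void check_my_advance() {
    std::forward_list<int> f{1, 2, 3, 4};
    auto fi = f.begin();
    sketch::my_advance(fi, 2);
    assert(*fi == 3);

    std::list<int> l{1, 2, 3, 4};
    auto li = l.end();
    sketch::my_advance(li, -2);  // backwards, via the bidirectional overload
    assert(*li == 3);

    std::vector<int> v{1, 2, 3, 4};
    auto vi = v.begin();
    sketch::my_advance(vi, 3);
    assert(*vi == 4);
}
```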
Besides iterator_category, all other iterator-related properties are also defined in std::iterator_traits (https://en.cppreference.com/w/cpp/iterator/iterator_traits). In simplified form:
template<typename Iter> struct iterator_traits {
using value_type = typename Iter::value_type;
using difference_type = typename Iter::difference_type; // type used by distance, e.g. (iter1 - iter2)
using pointer = typename Iter::pointer; // pointer type
using reference = typename Iter::reference; // reference type
using iterator_category = typename Iter::iterator_category; // tag
};
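The listing above is simplified: a raw pointer has no nested value_type, so the standard library also provides a specialization of iterator_traits for pointers. The sketch below checks what that specialization reports and uses the traits in a hypothetical generic function, sum_range, that works for pointers and class-type iterators alike.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <vector>

// What the pointer specialization reports for int*.
using PtrTraits = std::iterator_traits<int*>;
static_assert(std::is_same_v<PtrTraits::value_type, int>);
static_assert(std::is_same_v<PtrTraits::difference_type, std::ptrdiff_t>);
static_assert(std::is_same_v<PtrTraits::pointer, int*>);
static_assert(std::is_same_v<PtrTraits::reference, int&>);
static_assert(std::is_base_of_v<std::random_access_iterator_tag,
                                PtrTraits::iterator_category>);
// For a pointer to const, value_type drops the const.
static_assert(std::is_same_v<std::iterator_traits<const int*>::value_type, int>);

// Hypothetical generic helper: the accumulator type comes from the traits,
// so the same code works for int* and for std::vector<int>::iterator.
template <typename Iter>
typename std::iterator_traits<Iter>::value_type sum_range(Iter first, Iter last) {
    typename std::iterator_traits<Iter>::value_type total{};
    for (; first != last; ++first) total += *first;
    return total;
}

inline void check_sum_range() {
    int raw[] = {1, 2, 3};
    std::vector<int> v{4, 5};
    assert(sum_range(raw, raw + 3) == 6);
    assert(sum_range(v.begin(), v.end()) == 9);
}
```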
Let's look at these properties in a concrete example.
Example of custom forward iterator
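The following implements a forward iterator for MultiRowNum (multiple rows of data, where each row may hold zero or more numbers). For example, given the following input: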
1
2 3
4 5 6
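the output should be 1,2,3,4,5,6 (empty rows in between are skipped).
Here we implement only the operators a forward iterator needs, namely ->, *, ++, == and != (with iterator_category set to forward_iterator_tag). The iterator yields int values as it iterates.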
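#include <algorithm>    // std::count
#include <cassert>
#include <cstddef>      // std::ptrdiff_t, size_t
#include <iostream>
#include <iterator>     // iterator tags, std::iterator_traits, std::distance
#include <memory>       // std::addressof
#include <type_traits>  // std::is_same_v
#include <utility>      // std::move
#include <vector>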
class MultiRowNum {
using Row = std::vector<int>;
using MultiRow = std::vector<Row>;
public:
class Iterator {
public:
// assign the properties of the iterator
using iterator_category = std::forward_iterator_tag;
using value_type = int;
// used for some algorithm, e.g. std::count, std::distance
using difference_type = std::ptrdiff_t;
using pointer = value_type*;
using reference = value_type&;
Iterator(MultiRow* rows, MultiRow::iterator iter) : rows_(rows), iter_(iter) {
// skip empty lines at the first few rows
while (iter_ != rows_->end() && idx_ == iter_->size()) {
++iter_;
idx_ = 0;
}
}
// the address of which iterator points to
pointer operator->() const {
return &(*iter_)[idx_];
}
// the value of which iterator points to
reference operator*() const {
return (*iter_)[idx_];
}
// prefix increment
Iterator& operator++() {
moveToNextNum();
return *this;
}
// postfix increment
Iterator operator++(int) {
Iterator ret = *this;
moveToNextNum();
return ret;
}
bool operator==(const Iterator& other) const {
return iter_ == other.iter_ && idx_ == other.idx_;
};
bool operator!=(const Iterator& other) const {
return !(*this == other);
};
private:
void moveToNextNum() {
if (++idx_ < iter_->size()) {
return;
} else {
idx_ = 0;
do {
++iter_;
} while (iter_ != rows_->end() && idx_ == iter_->size());
}
}
MultiRow* rows_;
MultiRow::iterator iter_;
size_t idx_{0};
};
Iterator begin() {
return Iterator(&data_, data_.begin());
}
Iterator end() {
return Iterator(&data_, data_.end());
}
MultiRowNum(MultiRow&& data) : data_(std::move(data)) {}
private:
MultiRow data_;
};
int main() {
{
MultiRowNum nums({ {1}, {}, {2, 3}, {}, {4, 5, 6}, {} });
for (auto iter = nums.begin(); iter != nums.end(); iter++) {
std::cout << *iter << ' ';
}
std::cout << "\n";
}
{
MultiRowNum nums({ {}, {1}, {}, {2, 3}, {}, {4, 5, 6}, {} });
for (const auto& num : nums) {
std::cout << num << ' ';
}
std::cout << "\n";
}
{
MultiRowNum nums({ {1}, {2, 3}, {4, 5, 6} });
auto iter = nums.begin();
assert(*(iter.operator->()) == *iter);
assert(*(iter.operator->()) == iter.operator*());
assert(iter.operator->() == &(iter.operator*()));
assert(iter.operator->() == std::addressof(iter.operator*()));
assert(std::count(nums.begin(), nums.end(), 1) == 1);
assert(std::distance(nums.begin(), nums.end()) == 6);
}
static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::iterator_category, std::forward_iterator_tag>);
static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::value_type, int>);
static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::difference_type, std::ptrdiff_t>);
static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::pointer, int*>);
static_assert(std::is_same_v<std::iterator_traits<MultiRowNum::Iterator>::reference, int&>);
}
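The code is straightforward. The one thing to note is that the iterator skips empty rows, so both construction and ++ may need to advance several times before the iterator points at a valid position.
Also, if difference_type is not specified, algorithms such as std::count or std::distance fail to compile: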
forward_iterator.cpp:111:12: error: no matching function for call to 'count'
assert(std::count(nums.begin(), nums.end(), 1) == 1);
^~~~~~~~~~
/Library/Developer/CommandLineTools/SDKs/MacOSX12.1.sdk/usr/include/assert.h:99:25: note: expanded from macro 'assert'
(__builtin_expect(!(e), 0) ? __assert_rtn(__func__, __ASSERT_FILE_NAME, __LINE__, #e) : (void)0)
^
/Library/Developer/CommandLineTools/SDKs/MacOSX12.1.sdk/usr/include/c++/v1/algorithm:1185:1: note: candidate template ignored: substitution failure [with _InputIterator = MultiRowNum::Iterator, _Tp = int]: no type named 'difference_type' in 'std::iterator_traits<MultiRowNum::Iterator>'
count(_InputIterator __first, _InputIterator __last, const _Tp& __value_)
^
/Library/Developer/CommandLineTools/SDKs/MacOSX12.1.sdk/usr/include/c++/v1/__bit_reference:313:1: note: candidate template ignored: could not match '__bit_iterator<type-parameter-0-0, _IsConst, 0>' against 'MultiRowNum::Iterator'
count(__bit_iterator<_Cp, _IsConst> __first, __bit_iterator<_Cp, _IsConst> __last, const _Tp& __value_)
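The note in that output ("no type named 'difference_type' in 'std::iterator_traits<...>'") reflects how iterator_traits behaves since C++17: when the iterator does not declare the full set of nested types, iterator_traits for it has no members at all (C++20 relaxes this slightly by making pointer optional). The sketch below demonstrates this with two hypothetical stand-in types and a hypothetical detector, has_iterator_traits.

```cpp
#include <cstddef>
#include <iterator>
#include <type_traits>

// Stand-in for an iterator that forgot difference_type (hypothetical type).
struct MissingDifferenceType {
    using iterator_category = std::forward_iterator_tag;
    using value_type = int;
    using pointer = int*;
    using reference = int&;
};

// Same declarations plus difference_type (hypothetical type).
struct AllFiveTypes : MissingDifferenceType {
    using difference_type = std::ptrdiff_t;
};

// Hypothetical detector: true if std::iterator_traits<Iter> exposes members.
template <typename Iter, typename = void>
struct has_iterator_traits : std::false_type {};

template <typename Iter>
struct has_iterator_traits<
    Iter, std::void_t<typename std::iterator_traits<Iter>::iterator_category>>
    : std::true_type {};

// One missing alias empties the whole traits class, which is why even
// iterator_category becomes unavailable to the algorithms.
static_assert(!has_iterator_traits<MissingDifferenceType>::value);
static_assert(has_iterator_traits<AllFiveTypes>::value);
// Pointers always have traits through the pointer specialization.
static_assert(has_iterator_traits<int*>::value);
```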
Example of random access iterator
Below is an iterator for a class similar to folly's Range: for example, Range<15, 25> yields every number in [15, 25]. This time we implement a Random Access Iterator, which requires more operators.
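#include <algorithm>    // std::find, std::count
#include <cassert>
#include <cstddef>
#include <iostream>
#include <iterator>     // iterator tags, std::iterator_traits, std::prev, std::distance
#include <type_traits>  // std::is_same_v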
template <long FROM, long TO>
class Range {
public:
class Iterator {
long num_ = FROM;
public:
// assign the properties of the iterator
using iterator_category = std::random_access_iterator_tag;
// used for some algorithm, e.g. std::count, std::distance
using difference_type = long;
using value_type = long;
using pointer = const value_type*;
using reference = const value_type&;
explicit Iterator(long num) : num_(num) {}
pointer operator->() const {
return &num_;
}
reference operator*() const {
return num_;
}
Iterator& operator++() {
num_ = TO >= FROM ? num_ + 1 : num_ - 1;
return *this;
}
Iterator& operator--() {
num_ = TO >= FROM ? num_ - 1 : num_ + 1;
return *this;
}
Iterator operator++(int) {
Iterator retval = *this;
++(*this);
return retval;
}
Iterator operator--(int) {
Iterator retval = *this;
--(*this);
return retval;
}
bool operator==(const Iterator& rhs) const {
return num_ == rhs.num_;
}
bool operator!=(const Iterator& rhs) const {
return !(*this == rhs);
}
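// "less than" means "earlier in iteration order", which is reversed for a descending range
bool operator<(const Iterator& rhs) const {
return TO >= FROM ? num_ < rhs.num_ : num_ > rhs.num_;
}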
bool operator<=(const Iterator& rhs) const {
return !(rhs < *this);
}
bool operator>(const Iterator& rhs) const {
return rhs < *this;
}
bool operator>=(const Iterator& rhs) const {
return !(*this < rhs);
}
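// the arithmetic operators follow the iteration direction, just like ++ and --
Iterator operator+(const difference_type diff) const {
return Iterator(TO >= FROM ? num_ + diff : num_ - diff);
}
Iterator& operator+=(const difference_type diff) {
num_ = TO >= FROM ? num_ + diff : num_ - diff;
return *this;
}
Iterator operator-(const difference_type diff) const {
return Iterator(TO >= FROM ? num_ - diff : num_ + diff);
}
Iterator& operator-=(const difference_type diff) {
num_ = TO >= FROM ? num_ - diff : num_ + diff;
return *this;
}
// number of increments needed to get from rhs to *this
difference_type operator-(const Iterator& rhs) const {
return TO >= FROM ? num_ - rhs.num_ : rhs.num_ - num_;
}
value_type operator[](const difference_type idx) const {
return TO >= FROM ? num_ + idx : num_ - idx;
}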
};
Iterator begin() {
return Iterator(FROM);
}
Iterator end() {
return Iterator(TO >= FROM ? TO + 1 : TO - 1);
}
};
int main() {
{
auto range = Range<15, 25>();
auto iter = std::find(range.begin(), range.end(), 18);
std::cout << *iter << '\n';
}
{
for (auto l : Range<5, 3>()) {
std::cout << l << ' ';
}
std::cout << '\n';
}
{
auto rng = Range<1, 10>();
auto iter = rng.begin();
assert(*(iter.operator->()) == iter.operator*());
assert(*(iter.operator->()) == *iter);
assert(iter[5] == 1 + 5);
iter += 10;
assert(iter == rng.end());
iter = std::prev(iter, 10);
assert(iter == rng.begin());
assert(std::count(rng.begin(), rng.end(), 8) == 1);
assert(std::distance(rng.begin(), rng.end()) == 10);
std::cout << "*(iter++): " << *(iter++) << '\n';
std::cout << "*(iter--): " << *(iter--) << '\n';
std::cout << "*(++iter): " << *(++iter) << '\n';
std::cout << "*(--iter): " << *(--iter) << '\n';
std::cout << std::boolalpha
<< "rng.begin() < rng.begin() + 1: " << (rng.begin() < rng.begin() + 1) << '\n';
std::cout << std::boolalpha
<< "rng.begin() + 10 <= rng.end(): " << (rng.begin() + 10 <= rng.end()) << '\n';
std::cout << std::boolalpha << "rng.end() > rng.end() - 1: " << (rng.end() > rng.end() - 1)
<< '\n';
std::cout << std::boolalpha
<< "rng.end() - 10 >= rng.begin(): " << (rng.end() - 10 >= rng.begin()) << '\n';
}
static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::iterator_category, std::random_access_iterator_tag>);
static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::value_type, long>);
static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::difference_type, long>);
static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::pointer, const long*>);
static_assert(std::is_same_v<std::iterator_traits<Range<1, 10>::Iterator>::reference, const long&>);
}
iterator_traits may get a post of its own later; that's all for this time.