Distributed Ranges
Loading...
Searching...
No Matches
zip.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <cstddef>
8#include <iterator>
9#include <tuple>
10#include <type_traits>
11#include <utility>
12
13#include <dr/detail/ranges_shim.hpp>
14#include <dr/mp/alignment.hpp>
15#include <dr/mp/views/segmented.hpp>
16
17namespace dr::mp::__detail {
18
19template <typename R>
20concept zipable = rng::random_access_range<R> && rng::common_range<R>;
21
22} // namespace dr::mp::__detail
23
24namespace dr::mp {
25
26template <__detail::zipable... Rs> class zip_view;
27
28namespace views {
29
30template <typename... Rs> auto zip(Rs &&...rs) {
31 return zip_view(std::forward<Rs>(rs)...);
32}
33
34} // namespace views
35
36namespace __detail {
37
38template <typename T>
41
42template <typename T, typename... Rest>
43inline auto select_segments(T &&t, Rest &&...rest) {
44 if constexpr (is_distributed<T>) {
45 return dr::ranges::segments(std::forward<T>(t));
46 } else {
47 return select_segments(std::forward<Rest>(rest)...);
48 }
49}
50
51template <typename T, typename Seg> inline auto tpl_segments(T &&t, Seg &&tpl) {
52 if constexpr (is_distributed<T>) {
53 return dr::ranges::segments(std::forward<T>(t));
54 } else if constexpr (rng::forward_range<T>) {
55 return views::segmented(std::forward<T>(t), std::forward<Seg>(tpl));
56 } else if constexpr (rng::forward_iterator<T>) {
57 return views::segmented(rng::subrange(std::forward<T>(t), T{}),
58 std::forward<Seg>(tpl));
59 }
60}
61
62template <typename Base> auto base_to_segments(Base &&base) {
63 // Given segments, return elementwise zip
64 auto zip_segments = [](auto &&...segments) {
65 return views::zip(segments...);
66 };
67
68 // Given a tuple of segments, return a single segment by doing
69 // elementwise zip
70 auto zip_segment_tuple = [zip_segments](auto &&v) {
71 return std::apply(zip_segments, v);
72 };
73
74 // Given base ranges, return segments
75 auto bases_to_segments = [zip_segment_tuple](auto &&...bases) {
76 bool is_aligned = aligned(bases...);
77 auto tpl = select_segments(bases...);
78 return rng::views::zip(tpl_segments(bases, tpl)...) |
79 rng::views::transform(zip_segment_tuple) |
80 rng::views::filter([is_aligned](auto &&v) { return is_aligned; });
81 };
82
83 return std::apply(bases_to_segments, base);
84}
85
86} // namespace __detail
87
88template <std::random_access_iterator RngIter,
89 std::random_access_iterator... BaseIters>
91public:
92 using value_type = rng::iter_value_t<RngIter>;
93 using difference_type = rng::iter_difference_t<RngIter>;
94
95 using iterator_category = std::random_access_iterator_tag;
96
97 zip_iterator() {}
98 zip_iterator(RngIter rng_iter, BaseIters... base_iters)
99 : rng_iter_(rng_iter), base_(base_iters...) {}
100
101 auto operator+(difference_type n) const {
102 auto iter(*this);
103 iter.rng_iter_ += n;
104 iter.offset_ += n;
105 return iter;
106 }
107 friend auto operator+(difference_type n, const zip_iterator &other) {
108 return other + n;
109 }
110 auto operator-(difference_type n) const {
111 auto iter(*this);
112 iter.rng_iter_ -= n;
113 iter.offset_ -= n;
114 return iter;
115 }
116 auto operator-(zip_iterator other) const {
117 return rng_iter_ - other.rng_iter_;
118 }
119
120 auto &operator+=(difference_type n) {
121 rng_iter_ += n;
122 offset_ += n;
123 return *this;
124 }
125 auto &operator-=(difference_type n) {
126 rng_iter_ -= n;
127 offset_ -= n;
128 return *this;
129 }
130 auto &operator++() {
131 rng_iter_++;
132 offset_++;
133 return *this;
134 }
135 auto operator++(int) {
136 auto iter(*this);
137 rng_iter_++;
138 offset_++;
139 return iter;
140 }
141 auto &operator--() {
142 rng_iter_--;
143 offset_--;
144 return *this;
145 }
146 auto operator--(int) {
147 auto iter(*this);
148 rng_iter_--;
149 offset_--;
150 return iter;
151 }
152
153 bool operator==(zip_iterator other) const {
154 return rng_iter_ == other.rng_iter_;
155 }
156 auto operator<=>(zip_iterator other) const {
157 return offset_ <=> other.offset_;
158 }
159
160 // Underlying zip_iterator does not return a reference
161 auto operator*() const { return *rng_iter_; }
162 auto operator[](difference_type n) const { return rng_iter_[n]; }
163
164 //
165 // Distributed Ranges support
166 //
167 auto segments() const
168 requires(distributed_iterator<BaseIters> || ...)
169 {
170 return dr::__detail::drop_segments(__detail::base_to_segments(base_),
171 offset_);
172 }
173
174 auto rank() const
175 requires(remote_iterator<BaseIters> || ...)
176 {
177 return dr::ranges::rank(std::get<0>(base_));
178 }
179
180 auto local() const
182 {
183 // Create a temporary zip_view and return the iterator. This code
184 // assumes the iterator is valid even if the underlying zip_view
185 // is destroyed.
186 auto zip = [this]<typename... Iters>(Iters &&...iters) {
187 return rng::begin(rng::views::zip(
188 rng::subrange(base_local(std::forward<Iters>(iters)) + this->offset_,
189 decltype(base_local(iters)){})...));
190 };
191
192 return std::apply(zip, base_);
193 }
194
195private:
196 // If it is not a remote iterator, assume it is a local iterator
197 auto static base_local(auto iter) { return iter; }
198
199 auto static base_local(dr::ranges::__detail::has_local auto iter) {
200 return dr::ranges::local(iter);
201 }
202
203 RngIter rng_iter_;
204 std::tuple<BaseIters...> base_;
205 difference_type offset_ = 0;
206};
207
208template <__detail::zipable... Rs>
209class zip_view : public rng::view_interface<zip_view<Rs...>> {
210private:
211 using rng_zip = rng::zip_view<Rs...>;
212 using rng_zip_iterator = rng::iterator_t<rng_zip>;
213 using difference_type = std::iter_difference_t<rng_zip_iterator>;
214
215public:
216 zip_view(Rs... rs)
217 : rng_zip_(rng::views::all(rs)...), base_(rng::views::all(rs)...) {}
218
219 auto begin() const {
220 auto make_begin = [this](auto &&...bases) {
221 return zip_iterator(rng::begin(this->rng_zip_), rng::begin(bases)...);
222 };
223 return std::apply(make_begin, base_);
224 }
225 auto end() const
226 requires(rng::common_range<rng_zip>)
227 {
228 auto make_end = [this](auto &&...bases) {
229 return zip_iterator(rng::end(this->rng_zip_), rng::end(bases)...);
230 };
231 return std::apply(make_end, base_);
232 }
233 auto size() const { return rng::size(rng_zip_); }
234
235 auto operator[](difference_type n) const { return rng_zip_[n]; }
236
237 auto base() const { return base_; }
238
239 //
240 // Distributed Ranges support
241 //
242 auto segments() const
243 requires(distributed_range<Rs> || ...)
244 {
245 return __detail::base_to_segments(base_);
246 }
247
248 auto rank() const
249 requires(remote_range<Rs> || ...)
250 {
251 return dr::ranges::rank(std::get<0>(base_));
252 }
253
254 auto local() const
255 requires(remote_range<Rs> || ...)
256 {
257 auto zip = []<typename... Vs>(Vs &&...bases) {
258 return rng::views::zip(dr::ranges::local(std::forward<Vs>(bases))...);
259 };
260
261 return std::apply(zip, base_);
262 }
263
264private:
265 rng_zip rng_zip_;
266 std::tuple<rng::views::all_t<Rs>...> base_;
267};
268
269template <typename... Rs>
271
272} // namespace dr::mp
Definition: zip.hpp:90
Definition: zip.hpp:209
Definition: concepts.hpp:31
Definition: concepts.hpp:20
Definition: zip.hpp:39
Definition: zip.hpp:20
Definition: ranges.hpp:242
Definition: concepts.hpp:12
Definition: concepts.hpp:16