Distributed Ranges
Loading...
Searching...
No Matches
zip_view.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <oneapi/dpl/iterator>
8
9#include <dr/detail/iterator_adaptor.hpp>
10#include <dr/detail/owning_view.hpp>
11#include <dr/detail/ranges_shim.hpp>
12#include <dr/detail/view_detectors.hpp>
13#include <dr/sp/device_span.hpp>
14
namespace dr {

/// Type trait reporting whether `T` is an owning view.
///
/// Only the primary template is active, so the trait currently evaluates to
/// `false` for every `T`; the partial specialization for `rng::owning_view`
/// is left commented out below.
template <typename T> struct is_owning_view : std::false_type {};
// template <rng::range R>
// struct is_owning_view<rng::owning_view<R>> : std::true_type {};

/// Shorthand for `is_owning_view<T>::value`.
template <typename T>
inline constexpr bool is_owning_view_v = is_owning_view<T>::value;

} // namespace dr
25
26namespace dr::sp {
27
namespace __detail {

/// Chooses the product type used to bundle per-iterator values:
/// `std::tuple<Ts...>` in general ...
template <typename... Ts> struct tuple_or_pair {
  using type = std::tuple<Ts...>;
};

/// ... and `std::pair<First, Second>` when there are exactly two types.
template <typename First, typename Second>
struct tuple_or_pair<First, Second> {
  using type = std::pair<First, Second>;
};

/// Alias for `tuple_or_pair<Ts...>::type`.
template <typename... Ts>
using tuple_or_pair_t = typename tuple_or_pair<Ts...>::type;

} // namespace __detail
42
// zip_accessor: accessor over a pack of random-access iterators that are
// advanced together. Dereferencing yields `reference`, which is a std::pair
// of the iterators' reference types when there are exactly two iterators and
// a std::tuple of them otherwise (see __detail::tuple_or_pair_t).
// NOTE(review): `iterator_accessor`, the parameter type of the comparison
// operators below, is declared on source lines 53-55, which are missing from
// this listing. It is presumably an alias of `zip_accessor` itself (the
// operators read the private `other.iterators_`) — confirm against the
// original file.
43template <rng::random_access_iterator... Iters> class zip_accessor {
44public:
45 using element_type = __detail::tuple_or_pair_t<std::iter_value_t<Iters>...>;
46 using value_type = element_type;
47 using size_type = std::size_t;
48 using difference_type = std::ptrdiff_t;
49 using reference = __detail::tuple_or_pair_t<std::iter_reference_t<Iters>...>;
50
51 using iterator_category = std::random_access_iterator_tag;
52
56
57 constexpr zip_accessor() noexcept = default;
58 constexpr ~zip_accessor() noexcept = default;
59 constexpr zip_accessor(const zip_accessor &) noexcept = default;
60 constexpr zip_accessor &operator=(const zip_accessor &) noexcept = default;
61
// Stores a copy of each iterator, in argument order.
62 constexpr zip_accessor(Iters... iters) : iterators_(iters...) {}
63
// Advances every stored iterator by the same `offset`.
64 zip_accessor &operator+=(difference_type offset) {
65 auto increment = [&](auto &&iter) { iter += offset; };
66 iterators_apply_impl_<0>(increment);
67 return *this;
68 }
69
// ==, - and < inspect only the first iterator. operator+= moves all
// iterators by the same offset, so the first one is used as representative.
70 constexpr bool operator==(const iterator_accessor &other) const noexcept {
71 return std::get<0>(iterators_) == std::get<0>(other.iterators_);
72 }
73
74 constexpr difference_type
75 operator-(const iterator_accessor &other) const noexcept {
76 return std::get<0>(iterators_) - std::get<0>(other.iterators_);
77 }
78
79 constexpr bool operator<(const iterator_accessor &other) const noexcept {
80 return std::get<0>(iterators_) < std::get<0>(other.iterators_);
81 }
82
// Dereferences every stored iterator and packs the results into `reference`.
83 constexpr reference operator*() const noexcept {
84 return get_impl_(std::make_index_sequence<sizeof...(Iters)>{});
85 }
86
87private:
// Builds `reference` from `*iterator` for each index in Ints.
88 template <std::size_t... Ints>
89 reference get_impl_(std::index_sequence<Ints...>) const noexcept {
90 return reference(*std::get<Ints>(iterators_)...);
91 }
92
// Calls `fn` on the stored iterators from index I to the last one, in
// order, using compile-time recursion.
93 template <std::size_t I, typename Fn> void iterators_apply_impl_(Fn &&fn) {
94 fn(std::get<I>(iterators_));
95 if constexpr (I + 1 < sizeof...(Iters)) {
96 iterators_apply_impl_<I + 1>(fn);
97 }
98 }
99
100 std::tuple<Iters...> iterators_;
101};
102
// NOTE(review): source line 104 is missing from this listing. The template
// header below presumably introduces an alias template that combines
// `zip_accessor` with the iterator adaptor from
// dr/detail/iterator_adaptor.hpp (included above); zip_view::begin_impl_
// appears to depend on it. Confirm against the original file.
103template <rng::random_access_iterator... Iters>
105
107template <rng::random_access_range... Rs>
108class zip_view : public rng::view_interface<zip_view<Rs...>> {
109public:
110 using size_type = std::size_t;
111 using difference_type = std::ptrdiff_t;
112
113 zip_view(Rs... rs) : views_(rng::views::all(std::forward<Rs>(rs))...) {
114 std::array<std::size_t, sizeof...(Rs)> sizes = {
115 std::size_t(rng::distance(rs))...};
116
117 // TODO: support zipped views with some ranges shorter than others
118 size_ = sizes[0];
119
120 for (auto &&size : sizes) {
121 size_ = std::min(size_, size);
122 }
123 }
124
125 std::size_t size() const noexcept { return size_; }
126
127 auto begin() const {
128 return begin_impl_(std::make_index_sequence<sizeof...(Rs)>{});
129 }
130
131 auto end() const { return begin() + size(); }
132
133 auto operator[](std::size_t idx) const { return *(begin() + idx); }
134
135 static constexpr bool num_views = sizeof...(Rs);
136
137 template <std::size_t I> decltype(auto) get_view() const {
138 auto &&view = std::get<I>(views_);
139
140 if constexpr (dr::is_ref_view_v<std::remove_cvref_t<decltype(view)>> ||
141 dr::is_owning_view_v<std::remove_cvref_t<decltype(view)>>) {
142 return view.base();
143 } else {
144 return view;
145 }
146 }
147
148 // If there is at least one distributed range, expose segments
149 // of overlapping remote ranges.
150 auto segments() const
151 requires(dr::distributed_range<Rs> || ...)
152 {
153 std::array<std::size_t, sizeof...(Rs)> segment_ids;
154 std::array<std::size_t, sizeof...(Rs)> local_idx;
155 segment_ids.fill(0);
156 local_idx.fill(0);
157
158 std::size_t cumulative_size = 0;
159
160 using segment_view_type = decltype(get_zipped_view_impl_(
161 segment_ids, local_idx, 0, std::make_index_sequence<sizeof...(Rs)>{}));
162 std::vector<segment_view_type> segment_views;
163
164 while (cumulative_size < size()) {
165 auto size = get_next_segment_size(segment_ids, local_idx);
166
167 cumulative_size += size;
168
169 // Create zipped segment with
170 // zip_view(segments()[Is].subspan(local_idx[Is], size)...) And some rank
171 // (e.g. get_view<0>.rank())
172 auto segment_view =
173 get_zipped_view_impl_(segment_ids, local_idx, size,
174 std::make_index_sequence<sizeof...(Rs)>{});
175
176 segment_views.push_back(std::move(segment_view));
177
178 increment_local_idx(segment_ids, local_idx, size);
179 }
180
181 return dr::__detail::owning_view(std::move(segment_views));
182 }
183
184 // Return a range corresponding to each segment in `segments()`,
185 // but with a tuple of the constituent ranges instead of a
186 // `zip_view` of the ranges.
187 auto zipped_segments() const
188 requires(dr::distributed_range<Rs> || ...)
189 {
190 std::array<std::size_t, sizeof...(Rs)> segment_ids;
191 std::array<std::size_t, sizeof...(Rs)> local_idx;
192 segment_ids.fill(0);
193 local_idx.fill(0);
194
195 std::size_t cumulative_size = 0;
196
197 using segment_view_type = decltype(get_zipped_segments_impl_(
198 segment_ids, local_idx, 0, std::make_index_sequence<sizeof...(Rs)>{}));
199 std::vector<segment_view_type> segment_views;
200
201 while (cumulative_size < size()) {
202 auto size = get_next_segment_size(segment_ids, local_idx);
203
204 cumulative_size += size;
205
206 // Get zipped segments with
207 // std::tuple(segments()[Is].subspan(local_idx[Is], size)...)
208 auto segment_view =
209 get_zipped_segments_impl_(segment_ids, local_idx, size,
210 std::make_index_sequence<sizeof...(Rs)>{});
211
212 segment_views.push_back(std::move(segment_view));
213
214 increment_local_idx(segment_ids, local_idx, size);
215 }
216
217 return dr::__detail::owning_view(std::move(segment_views));
218 }
219
220 auto local() const noexcept
221 requires(!(dr::distributed_range<Rs> || ...))
222 {
223 return local_impl_(std::make_index_sequence<sizeof...(Rs)>());
224 }
225
226 // If:
227 // - There is at least one remote range in the zip
228 // - There are no distributed ranges in the zip
229 // Expose a rank.
// Returns the rank of the first range in the pack that models
// dr::remote_range (see get_rank_impl_).
// NOTE(review): the rest of the requires-clause (source line 232) is missing
// from this listing. Going by the comment above, it presumably reads
// `!(dr::distributed_range<Rs> || ...))` — confirm against the original.
230 std::size_t rank() const
231 requires((dr::remote_range<Rs> || ...) &&
233 {
234 return get_rank_impl_<0, Rs...>();
235 }
236
237private:
238 template <std::size_t... Ints>
239 auto local_impl_(std::index_sequence<Ints...>) const noexcept {
240 return rng::views::zip(__detail::local(std::get<Ints>(views_))...);
241 }
242
243 template <std::size_t I, typename R> std::size_t get_rank_impl_() const {
244 static_assert(I < sizeof...(Rs));
245 return dr::ranges::rank(get_view<I>());
246 }
247
248 template <std::size_t I, typename R, typename... Rs_>
249 requires(sizeof...(Rs_) > 0)
250 std::size_t get_rank_impl_() const {
251 static_assert(I < sizeof...(Rs));
252 if constexpr (dr::remote_range<R>) {
253 return dr::ranges::rank(get_view<I>());
254 } else {
255 return get_rank_impl_<I + 1, Rs_...>();
256 }
257 }
258
259 template <typename T> auto create_view_impl_(T &&t) const {
260 if constexpr (dr::remote_range<T>) {
261 return dr::sp::device_span(std::forward<T>(t));
262 } else {
263 return dr::sp::span(std::forward<T>(t));
264 }
265 }
266
267 template <std::size_t... Is>
268 auto get_zipped_view_impl_(auto &&segment_ids, auto &&local_idx,
269 std::size_t size,
270 std::index_sequence<Is...>) const {
271 return zip_view<decltype(create_view_impl_(
272 segment_or_orig_(get_view<Is>(),
273 segment_ids[Is]))
274 .subspan(local_idx[Is], size))...>(
275 create_view_impl_(segment_or_orig_(get_view<Is>(), segment_ids[Is]))
276 .subspan(local_idx[Is], size)...);
277 }
278
279 template <std::size_t... Is>
280 auto get_zipped_segments_impl_(auto &&segment_ids, auto &&local_idx,
281 std::size_t size,
282 std::index_sequence<Is...>) const {
283 return std::tuple(
284 create_view_impl_(segment_or_orig_(get_view<Is>(), segment_ids[Is]))
285 .subspan(local_idx[Is], size)...);
286 }
287
288 template <std::size_t I = 0>
289 void increment_local_idx(auto &&segment_ids, auto &&local_idx,
290 std::size_t size) const {
291 local_idx[I] += size;
292
293 if (local_idx[I] >=
294 rng::distance(segment_or_orig_(get_view<I>(), segment_ids[I]))) {
295 local_idx[I] = 0;
296 segment_ids[I]++;
297 }
298
299 if constexpr (I + 1 < sizeof...(Rs)) {
300 increment_local_idx<I + 1>(segment_ids, local_idx, size);
301 }
302 }
303
// Constructs begin() from rng::begin of every stored view.
// NOTE(review): source line 306, the start of the return statement that names
// the iterator type being constructed, is missing from this listing. It is
// presumably the zip iterator built over zip_accessor — confirm against the
// original file.
304 template <std::size_t... Is>
305 auto begin_impl_(std::index_sequence<Is...>) const {
307 rng::begin(std::get<Is>(views_))...);
308 }
309
310 template <dr::distributed_range T>
311 decltype(auto) segment_or_orig_(T &&t, std::size_t idx) const {
312 return dr::ranges::segments(t)[idx];
313 }
314
315 template <typename T>
316 decltype(auto) segment_or_orig_(T &&t, std::size_t idx) const {
317 return t;
318 }
319
320 template <std::size_t... Is>
321 std::size_t get_next_segment_size_impl_(auto &&segment_ids, auto &&local_idx,
322 std::index_sequence<Is...>) const {
323 return std::min({std::size_t(rng::distance(
324 segment_or_orig_(get_view<Is>(), segment_ids[Is]))) -
325 local_idx[Is]...});
326 }
327
328 std::size_t get_next_segment_size(auto &&segment_ids,
329 auto &&local_idx) const {
330 return get_next_segment_size_impl_(
331 segment_ids, local_idx, std::make_index_sequence<sizeof...(Rs)>{});
332 }
333
334 std::tuple<rng::views::all_t<Rs>...> views_;
335 std::size_t size_;
336};
337
338template <typename... Rs> zip_view(Rs &&...rs) -> zip_view<Rs...>;
339
340namespace views {
341
343template <rng::random_access_range... Rs> auto zip(Rs &&...rs) {
344 return dr::sp::zip_view(std::forward<Rs>(rs)...);
345}
346
347} // namespace views
348
349} // namespace dr::sp
Definition: owning_view.hpp:18
Definition: iterator_adaptor.hpp:23
Definition: device_span.hpp:44
Definition: span.hpp:14
Definition: zip_view.hpp:43
zip
Definition: zip_view.hpp:108
Definition: concepts.hpp:20
Definition: concepts.hpp:16
Definition: zip_view.hpp:17
Definition: zip_view.hpp:30