7#include <oneapi/dpl/iterator>
9#include <dr/detail/iterator_adaptor.hpp>
10#include <dr/detail/owning_view.hpp>
11#include <dr/detail/ranges_shim.hpp>
12#include <dr/detail/view_detectors.hpp>
13#include <dr/sp/device_span.hpp>
31 using type = std::tuple<Args...>;
35 using type = std::pair<T, U>;
38template <
typename... Args>
43template <rng::random_access_iterator... Iters>
class zip_accessor {
45 using element_type = __detail::tuple_or_pair_t<std::iter_value_t<Iters>...>;
46 using value_type = element_type;
47 using size_type = std::size_t;
48 using difference_type = std::ptrdiff_t;
49 using reference = __detail::tuple_or_pair_t<std::iter_reference_t<Iters>...>;
51 using iterator_category = std::random_access_iterator_tag;
62 constexpr zip_accessor(Iters... iters) : iterators_(iters...) {}
65 auto increment = [&](
auto &&iter) { iter += offset; };
66 iterators_apply_impl_<0>(increment);
71 return std::get<0>(iterators_) == std::get<0>(other.iterators_);
74 constexpr difference_type
76 return std::get<0>(iterators_) - std::get<0>(other.iterators_);
80 return std::get<0>(iterators_) < std::get<0>(other.iterators_);
83 constexpr reference operator*()
const noexcept {
84 return get_impl_(std::make_index_sequence<
sizeof...(Iters)>{});
88 template <std::size_t... Ints>
89 reference get_impl_(std::index_sequence<Ints...>)
const noexcept {
90 return reference(*std::get<Ints>(iterators_)...);
93 template <std::
size_t I,
typename Fn>
void iterators_apply_impl_(Fn &&fn) {
94 fn(std::get<I>(iterators_));
95 if constexpr (I + 1 <
sizeof...(Iters)) {
96 iterators_apply_impl_<I + 1>(fn);
100 std::tuple<Iters...> iterators_;
103template <rng::random_access_iterator... Iters>
107template <rng::random_access_range... Rs>
108class zip_view :
public rng::view_interface<zip_view<Rs...>> {
110 using size_type = std::size_t;
111 using difference_type = std::ptrdiff_t;
113 zip_view(Rs... rs) : views_(rng::views::all(std::forward<Rs>(rs))...) {
114 std::array<std::size_t,
sizeof...(Rs)> sizes = {
115 std::size_t(rng::distance(rs))...};
120 for (
auto &&size : sizes) {
121 size_ = std::min(size_, size);
125 std::size_t size()
const noexcept {
return size_; }
128 return begin_impl_(std::make_index_sequence<
sizeof...(Rs)>{});
131 auto end()
const {
return begin() + size(); }
133 auto operator[](std::size_t idx)
const {
return *(begin() + idx); }
135 static constexpr bool num_views =
sizeof...(Rs);
137 template <std::
size_t I>
decltype(
auto) get_view()
const {
138 auto &&view = std::get<I>(views_);
140 if constexpr (dr::is_ref_view_v<std::remove_cvref_t<
decltype(view)>> ||
141 dr::is_owning_view_v<std::remove_cvref_t<
decltype(view)>>) {
150 auto segments()
const
153 std::array<std::size_t,
sizeof...(Rs)> segment_ids;
154 std::array<std::size_t,
sizeof...(Rs)> local_idx;
158 std::size_t cumulative_size = 0;
160 using segment_view_type =
decltype(get_zipped_view_impl_(
161 segment_ids, local_idx, 0, std::make_index_sequence<
sizeof...(Rs)>{}));
162 std::vector<segment_view_type> segment_views;
164 while (cumulative_size < size()) {
165 auto size = get_next_segment_size(segment_ids, local_idx);
167 cumulative_size += size;
173 get_zipped_view_impl_(segment_ids, local_idx, size,
174 std::make_index_sequence<
sizeof...(Rs)>{});
176 segment_views.push_back(std::move(segment_view));
178 increment_local_idx(segment_ids, local_idx, size);
187 auto zipped_segments()
const
190 std::array<std::size_t,
sizeof...(Rs)> segment_ids;
191 std::array<std::size_t,
sizeof...(Rs)> local_idx;
195 std::size_t cumulative_size = 0;
197 using segment_view_type =
decltype(get_zipped_segments_impl_(
198 segment_ids, local_idx, 0, std::make_index_sequence<
sizeof...(Rs)>{}));
199 std::vector<segment_view_type> segment_views;
201 while (cumulative_size < size()) {
202 auto size = get_next_segment_size(segment_ids, local_idx);
204 cumulative_size += size;
209 get_zipped_segments_impl_(segment_ids, local_idx, size,
210 std::make_index_sequence<
sizeof...(Rs)>{});
212 segment_views.push_back(std::move(segment_view));
214 increment_local_idx(segment_ids, local_idx, size);
220 auto local()
const noexcept
223 return local_impl_(std::make_index_sequence<
sizeof...(Rs)>());
230 std::size_t rank()
const
234 return get_rank_impl_<0, Rs...>();
238 template <std::size_t... Ints>
239 auto local_impl_(std::index_sequence<Ints...>)
const noexcept {
240 return rng::views::zip(__detail::local(std::get<Ints>(views_))...);
243 template <std::
size_t I,
typename R> std::size_t get_rank_impl_()
const {
244 static_assert(I <
sizeof...(Rs));
245 return dr::ranges::rank(get_view<I>());
248 template <std::size_t I,
typename R,
typename... Rs_>
249 requires(
sizeof...(Rs_) > 0)
250 std::size_t get_rank_impl_()
const {
251 static_assert(I <
sizeof...(Rs));
253 return dr::ranges::rank(get_view<I>());
255 return get_rank_impl_<I + 1, Rs_...>();
259 template <
typename T>
auto create_view_impl_(T &&t)
const {
267 template <std::size_t... Is>
268 auto get_zipped_view_impl_(
auto &&segment_ids,
auto &&local_idx,
270 std::index_sequence<Is...>)
const {
271 return zip_view<
decltype(create_view_impl_(
272 segment_or_orig_(get_view<Is>(),
274 .subspan(local_idx[Is], size))...>(
275 create_view_impl_(segment_or_orig_(get_view<Is>(), segment_ids[Is]))
276 .subspan(local_idx[Is], size)...);
279 template <std::size_t... Is>
280 auto get_zipped_segments_impl_(
auto &&segment_ids,
auto &&local_idx,
282 std::index_sequence<Is...>)
const {
284 create_view_impl_(segment_or_orig_(get_view<Is>(), segment_ids[Is]))
285 .subspan(local_idx[Is], size)...);
288 template <std::
size_t I = 0>
289 void increment_local_idx(
auto &&segment_ids,
auto &&local_idx,
290 std::size_t size)
const {
291 local_idx[I] += size;
294 rng::distance(segment_or_orig_(get_view<I>(), segment_ids[I]))) {
299 if constexpr (I + 1 <
sizeof...(Rs)) {
300 increment_local_idx<I + 1>(segment_ids, local_idx, size);
304 template <std::size_t... Is>
305 auto begin_impl_(std::index_sequence<Is...>)
const {
307 rng::begin(std::get<Is>(views_))...);
310 template <dr::distributed_range T>
311 decltype(
auto) segment_or_orig_(T &&t, std::size_t idx)
const {
312 return dr::ranges::segments(t)[idx];
315 template <
typename T>
316 decltype(
auto) segment_or_orig_(T &&t, std::size_t idx)
const {
320 template <std::size_t... Is>
321 std::size_t get_next_segment_size_impl_(
auto &&segment_ids,
auto &&local_idx,
322 std::index_sequence<Is...>)
const {
323 return std::min({std::size_t(rng::distance(
324 segment_or_orig_(get_view<Is>(), segment_ids[Is]))) -
328 std::size_t get_next_segment_size(
auto &&segment_ids,
329 auto &&local_idx)
const {
330 return get_next_segment_size_impl_(
331 segment_ids, local_idx, std::make_index_sequence<
sizeof...(Rs)>{});
334 std::tuple<rng::views::all_t<Rs>...> views_;
343template <rng::random_access_range... Rs>
auto zip(Rs &&...rs) {
Definition: owning_view.hpp:18
Definition: iterator_adaptor.hpp:23
Definition: device_span.hpp:44
Definition: zip_view.hpp:43
zip
Definition: zip_view.hpp:108
Definition: concepts.hpp:20
Definition: concepts.hpp:16
Definition: zip_view.hpp:17
Definition: zip_view.hpp:30