Distributed Ranges
Loading...
Searching...
No Matches
views.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/detail/ranges_shim.hpp>
8#include <dr/views/iota.hpp>
9#include <dr/views/transform.hpp>
10
11namespace dr::mp {
12
13// Select segments local to this rank and convert the iterators in the
14// segment to local
15template <typename R> auto local_segments(R &&dr) {
16 auto is_local = [](const auto &segment) {
17 return dr::ranges::rank(segment) == default_comm().rank();
18 };
19 // Convert from remote iter to local iter
20 auto local_iter = [](const auto &segment) {
21 auto b = dr::ranges::local(rng::begin(segment));
22 return rng::subrange(b, b + rng::distance(segment));
23 };
24 return dr::ranges::segments(std::forward<R>(dr)) |
25 rng::views::filter(is_local) | rng::views::transform(local_iter);
26}
27
28template <typename R> auto local_segments_with_idx(R &&dr) {
29 auto is_local = [](const auto &segment_with_idx) {
30 return dr::ranges::rank(std::get<1>(segment_with_idx)) ==
31 default_comm().rank();
32 };
33 // Convert from remote iter to local iter
34 auto local_iter = [](const auto &segment_with_idx) {
35 auto &&[idx, segment] = segment_with_idx;
36 auto b = dr::ranges::local(rng::begin(segment));
37 return std::tuple(idx, rng::subrange(b, b + rng::distance(segment)));
38 };
39 return dr::ranges::segments(std::forward<R>(dr)) | rng::views::enumerate |
40 rng::views::filter(is_local) | rng::views::transform(local_iter);
41}
42
43template <dr::distributed_contiguous_range R> auto local_segment(R &&r) {
44 auto segments = dr::mp::local_segments(std::forward<R>(r));
45
46 if (rng::empty(segments)) {
47 return rng::range_value_t<decltype(segments)>{};
48 }
49
50 // Should be error, not assert. Or we could join all the segments
51 assert(rng::distance(segments) == 1);
52 return *rng::begin(segments);
53}
54
55template <typename R> auto local_mdspans(R &&dr) {
56 return dr::ranges::segments(std::forward<R>(dr))
57 // Select the local segments
58 | rng::views::filter([](auto s) {
59 return dr::ranges::rank(s) == default_comm().rank();
60 })
61 // Extract the mdspan
62 | rng::views::transform([](auto s) { return s.mdspan(); });
63}
64
65template <dr::distributed_contiguous_range R> auto local_mdspan(R &&r) {
66 auto mdspans = dr::mp::local_mdspans(std::forward<R>(r));
67
68 if (rng::empty(mdspans)) {
69 return rng::range_value_t<decltype(mdspans)>{};
70 }
71
72 // Should be error, not assert. Or we could join all the segments
73 assert(rng::distance(mdspans) == 1);
74 return *rng::begin(mdspans);
75}
76
77} // namespace dr::mp
78
79namespace dr::mp::views {
80
81inline constexpr auto all = rng::views::all;
82inline constexpr auto counted = rng::views::counted;
83inline constexpr auto drop = rng::views::drop;
84inline constexpr auto iota = dr::views::iota;
85inline constexpr auto take = rng::views::take;
86inline constexpr auto transform = dr::views::transform;
87
88} // namespace dr::mp::views