Distributed Ranges
Loading...
Searching...
No Matches
distributed_span.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <vector>
8
9#include <dr/concepts/concepts.hpp>
10#include <dr/detail/iterator_adaptor.hpp>
11#include <dr/detail/ranges.hpp>
12#include <dr/detail/segments_tools.hpp>
13#include <dr/sp/device_span.hpp>
14
15namespace dr::sp {
16
17template <typename T, typename L> class distributed_span_accessor {
18public:
19 using element_type = T;
20 using value_type = std::remove_cv_t<T>;
21
22 using segment_type = L;
23
24 using size_type = rng::range_size_t<segment_type>;
25 using difference_type = rng::range_difference_t<segment_type>;
26
27 // using pointer = typename segment_type::pointer;
28 using reference = rng::range_reference_t<segment_type>;
29
30 using iterator_category = std::random_access_iterator_tag;
31
35
36 constexpr distributed_span_accessor() noexcept = default;
37 constexpr ~distributed_span_accessor() noexcept = default;
39 const distributed_span_accessor &) noexcept = default;
41 operator=(const distributed_span_accessor &) noexcept = default;
42
43 constexpr distributed_span_accessor(std::span<segment_type> segments,
44 size_type segment_id,
45 size_type idx) noexcept
46 : segments_(segments), segment_id_(segment_id), idx_(idx) {}
47
49 operator+=(difference_type offset) noexcept {
50
51 while (offset > 0) {
52 difference_type current_offset =
53 std::min(offset, difference_type(segments_[segment_id_].size()) -
54 difference_type(idx_));
55 idx_ += current_offset;
56 offset -= current_offset;
57
58 if (idx_ >= segments_[segment_id_].size()) {
59 segment_id_++;
60 idx_ = 0;
61 }
62 }
63
64 while (offset < 0) {
65 difference_type current_offset =
66 std::min(-offset, difference_type(idx_) + 1);
67
68 difference_type new_idx = difference_type(idx_) - current_offset;
69
70 if (new_idx < 0) {
71 segment_id_--;
72 new_idx = segments_[segment_id_].size() - 1;
73 }
74
75 idx_ = new_idx;
76 }
77
78 return *this;
79 }
80
81 constexpr bool operator==(const iterator_accessor &other) const noexcept {
82 return segment_id_ == other.segment_id_ && idx_ == other.idx_;
83 }
84
85 constexpr difference_type
86 operator-(const iterator_accessor &other) const noexcept {
87 return difference_type(get_global_idx()) - other.get_global_idx();
88 }
89
90 constexpr bool operator<(const iterator_accessor &other) const noexcept {
91 if (segment_id_ < other.segment_id_) {
92 return true;
93 } else if (segment_id_ == other.segment_id_) {
94 return idx_ < other.idx_;
95 } else {
96 return false;
97 }
98 }
99
100 constexpr reference operator*() const noexcept {
101 return segments_[segment_id_][idx_];
102 }
103
104 auto segments() const noexcept {
105 return dr::__detail::drop_segments(segments_, segment_id_, idx_);
106 }
107
108private:
109 size_type get_global_idx() const noexcept {
110 size_type cumulative_size = 0;
111 for (std::size_t i = 0; i < segment_id_; i++) {
112 cumulative_size += segments_[i].size();
113 }
114 return cumulative_size + idx_;
115 }
116
117 std::span<segment_type> segments_;
118 size_type segment_id_ = 0;
119 size_type idx_ = 0;
120};
121
122template <typename T, typename L>
125
126template <typename T, typename L>
127class distributed_span : public rng::view_interface<distributed_span<T, L>> {
128public:
129 using element_type = T;
130 using value_type = std::remove_cv_t<T>;
131
133
134 using size_type = rng::range_size_t<segment_type>;
135 using difference_type = rng::range_difference_t<segment_type>;
136
137 // using pointer = typename segment_type::pointer;
138 using reference = rng::range_reference_t<segment_type>;
139
140 // Note: creating the "global view" will be trivial once #44178 is resolved.
141 // (https://github.com/llvm/llvm-project/issues/44178)
142 // The "global view" is simply all of the segmented views joined together.
143 // However, this code does not currently compile due to a bug in Clang,
144 // so I am currently implementing my own global iterator manually.
145 // using joined_view_type =
146 // rng::join_view<rng::ref_view<std::vector<segment_type>>>;
147 // using iterator = rng::iterator_t<joined_view_type>;
148
150
151 constexpr distributed_span() noexcept = default;
152 constexpr distributed_span(const distributed_span &) noexcept = default;
153 constexpr distributed_span &
154 operator=(const distributed_span &) noexcept = default;
155
156 template <rng::input_range R>
158 constexpr distributed_span(R &&segments) {
159 for (auto &&segment : segments) {
160 std::size_t size = rng::size(segment);
161 segments_.push_back(
162 segment_type(rng::begin(segment), size, dr::ranges::rank(segment)));
163 size_ += size;
164 }
165 }
166
167 template <dr::distributed_range R> constexpr distributed_span(R &&r) {
168 for (auto &&segment : dr::ranges::segments(std::forward<R>(r))) {
169 std::size_t size = rng::size(segment);
170 segments_.push_back(
171 segment_type(rng::begin(segment), size, dr::ranges::rank(segment)));
172 size_ += size;
173 }
174 }
175
176 constexpr size_type size() const noexcept { return size_; }
177
178 constexpr size_type size_bytes() const noexcept {
179 return size() * sizeof(element_type);
180 }
181
182 constexpr reference operator[](size_type idx) const {
183 // TODO: optimize this
184 std::size_t span_id = 0;
185 for (std::size_t span_id = 0; idx >= segments()[span_id].size();
186 span_id++) {
187 idx -= segments()[span_id].size();
188 }
189 return segments()[span_id][idx];
190 }
191
192 [[nodiscard]] constexpr bool empty() const noexcept { return size() == 0; }
193
194 constexpr distributed_span
195 subspan(size_type Offset, size_type Count = std::dynamic_extent) const {
196 Count = std::min(Count, size() - Offset);
197
198 std::vector<segment_type> new_segments;
199
200 // Forward to segment_id that contains global index `Offset`.
201 std::size_t segment_id = 0;
202 for (segment_id = 0; Offset >= segments()[segment_id].size();
203 segment_id++) {
204 Offset -= segments()[segment_id].size();
205 }
206
207 // Our Offset begins at `segment_id, Offset`
208
209 while (Count > 0) {
210 std::size_t local_count =
211 std::min(Count, segments()[segment_id].size() - Offset);
212 auto new_segment = segments()[segment_id].subspan(Offset, local_count);
213 new_segments.push_back(new_segment);
214 Count -= local_count;
215 Offset = 0;
216 segment_id++;
217 }
218
219 return distributed_span(new_segments);
220 }
221
222 constexpr distributed_span first(size_type Count) const {
223 return subspan(0, Count);
224 }
225
226 constexpr distributed_span last(size_type Count) const {
227 return subspan(size() - Count, Count);
228 }
229
230 iterator begin() { return iterator(segments(), 0, 0); }
231
232 iterator end() { return iterator(segments(), segments().size(), 0); }
233
234 constexpr reference front() { return segments().front().front(); }
235
236 constexpr reference back() { return segments().back().back(); }
237
238 std::span<segment_type> segments() { return segments_; }
239
240 std::span<const segment_type> segments() const { return segments_; }
241
242private:
243 std::size_t size_ = 0;
244 std::vector<segment_type> segments_;
245};
246
247template <rng::input_range R>
248distributed_span(R &&segments)
250 rng::iterator_t<rng::range_value_t<R>>>;
251
252template <dr::distributed_contiguous_range R>
254 rng::range_value_t<R>,
255 rng::iterator_t<rng::range_value_t<decltype(dr::ranges::segments(r))>>>;
256
257} // namespace dr::sp
Definition: iterator_adaptor.hpp:23
Definition: device_span.hpp:44
Definition: distributed_span.hpp:17
Definition: distributed_span.hpp:127
Definition: concepts.hpp:16