Distributed Ranges
Loading...
Searching...
No Matches
segment.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7namespace dr::mp {
8
9template <typename DV> class dv_segment_iterator;
10
11template <typename DV> class dv_segment_reference {
13
14public:
15 using value_type = typename DV::value_type;
16
17 dv_segment_reference(const iterator it) : iterator_(it) {}
18
19 operator value_type() const { return iterator_.get(); }
20 auto operator=(const value_type &value) const {
21 iterator_.put(value);
22 return *this;
23 }
24 auto operator=(const dv_segment_reference &other) const {
25 *this = value_type(other);
26 return *this;
27 }
28 auto operator&() const { return iterator_; }
29
30private:
31 const iterator iterator_;
32}; // dv_segment_reference
33
34template <typename DV> class dv_segment_iterator {
35public:
36 using value_type = typename DV::value_type;
37 using size_type = typename DV::size_type;
38 using difference_type = typename DV::difference_type;
39
40 dv_segment_iterator() = default;
41 dv_segment_iterator(DV *dv, std::size_t segment_index, std::size_t index) {
42 dv_ = dv;
43 segment_index_ = segment_index;
44 index_ = index;
45 }
46
47 auto operator<=>(const dv_segment_iterator &other) const noexcept {
48 // assertion below checks against compare dereferenceable iterator to a
49 // singular iterator and against attempt to compare iterators from different
50 // sequences like _Safe_iterator<gnu_cxx::normal_iterator> does
51 assert(dv_ == other.dv_);
52 return segment_index_ == other.segment_index_
53 ? index_ <=> other.index_
54 : segment_index_ <=> other.segment_index_;
55 }
56
57 // Comparison
58 bool operator==(const dv_segment_iterator &other) const noexcept {
59 return (*this <=> other) == 0;
60 }
61
62 // Only this arithmetic manipulate internal state
63 auto &operator+=(difference_type n) {
64 assert(dv_ != nullptr);
65 assert(n >= 0 || static_cast<difference_type>(index_) >= -n);
66 index_ += n;
67 return *this;
68 }
69
70 auto &operator-=(difference_type n) { return *this += (-n); }
71
72 difference_type operator-(const dv_segment_iterator &other) const noexcept {
73 assert(dv_ != nullptr && dv_ == other.dv_);
74 assert(index_ >= other.index_);
75 return index_ - other.index_;
76 }
77
78 // prefix
79 auto &operator++() {
80 *this += 1;
81 return *this;
82 }
83 auto &operator--() {
84 *this -= 1;
85 return *this;
86 }
87
88 // postfix
89 auto operator++(int) {
90 auto prev = *this;
91 *this += 1;
92 return prev;
93 }
94 auto operator--(int) {
95 auto prev = *this;
96 *this -= 1;
97 return prev;
98 }
99
100 auto operator+(difference_type n) const {
101 auto p = *this;
102 p += n;
103 return p;
104 }
105 auto operator-(difference_type n) const {
106 auto p = *this;
107 p -= n;
108 return p;
109 }
110
111 // When *this is not first in the expression
112 friend auto operator+(difference_type n, const dv_segment_iterator &other) {
113 return other + n;
114 }
115
116 // dereference
117 auto operator*() const {
118 assert(dv_ != nullptr);
119 return dv_segment_reference<DV>{*this};
120 }
121 auto operator[](difference_type n) const {
122 assert(dv_ != nullptr);
123 return *(*this + n);
124 }
125
126 void get(value_type *dst, std::size_t size) const {
127 assert(dv_ != nullptr);
128 assert(segment_index_ * dv_->segment_size_ + index_ < dv_->size());
129 auto segment_offset = index_ + dv_->distribution_.halo().prev;
130 dv_->backend.getmem(dst, segment_offset * sizeof(value_type),
131 size * sizeof(value_type), segment_index_);
132 }
133
134 value_type get() const {
135 value_type val;
136 get(&val, 1);
137 return val;
138 }
139
140 void put(const value_type *dst, std::size_t size) const {
141 assert(dv_ != nullptr);
142 assert(segment_index_ * dv_->segment_size_ + index_ < dv_->size());
143 auto segment_offset = index_ + dv_->distribution_.halo().prev;
144 dr::drlog.debug("dv put:: ({}:{}:{})\n", segment_index_, segment_offset,
145 size);
146 dv_->backend.putmem(dst, segment_offset * sizeof(value_type),
147 size * sizeof(value_type), segment_index_);
148 }
149
150 void put(const value_type &value) const { put(&value, 1); }
151
152 auto rank() const {
153 assert(dv_ != nullptr);
154 return segment_index_;
155 }
156
157 auto local() const {
158#ifndef SYCL_LANGUAGE_VERSION
159 assert(dv_ != nullptr);
160#endif
161 const auto my_process_segment_index = dv_->backend.getrank();
162
163 if (my_process_segment_index == segment_index_)
164 return dv_->data_ + index_ + dv_->distribution_.halo().prev;
165#ifndef SYCL_LANGUAGE_VERSION
166 assert(!dv_->distribution_.halo().periodic); // not implemented
167#endif
168 // sliding view needs local iterators that point to the halo
169 if (my_process_segment_index + 1 == segment_index_) {
170#ifndef SYCL_LANGUAGE_VERSION
171 assert(index_ <= dv_->distribution_.halo()
172 .next); // <= instead of < to cover end() case
173#endif
174 return dv_->data_ + dv_->distribution_.halo().prev + index_ +
175 dv_->segment_size_;
176 }
177
178 if (my_process_segment_index == segment_index_ + 1) {
179#ifndef SYCL_LANGUAGE_VERSION
180 assert(dv_->segment_size_ - index_ <= dv_->distribution_.halo().prev);
181#endif
182 return dv_->data_ + dv_->distribution_.halo().prev + index_ -
183 dv_->segment_size_;
184 }
185
186#ifndef SYCL_LANGUAGE_VERSION
187 assert(false); // trying to read non-owned memory
188#endif
189 return static_cast<decltype(dv_->data_)>(nullptr);
190 }
191
192 auto segments() const {
193 assert(dv_ != nullptr);
194 return dr::__detail::drop_segments(dv_->segments(), segment_index_, index_);
195 }
196
197 auto &halo() const {
198 assert(dv_ != nullptr);
199 return dv_->halo();
200 }
201 auto halo_bounds() const {
202 assert(dv_ != nullptr);
203 return dv_->distribution_.halo();
204 }
205
206private:
207 // all fields need to be initialized by default ctor so every default
208 // constructed iter is equal to any other default constructed iter
209 DV *dv_ = nullptr;
210 std::size_t segment_index_ = 0;
211 std::size_t index_ = 0;
212}; // dv_segment_iterator
213
214template <typename DV> class dv_segment {
215private:
217
218public:
219 using difference_type = std::ptrdiff_t;
220 dv_segment() = default;
221 dv_segment(DV *dv, std::size_t segment_index, std::size_t size,
222 std::size_t reserved) {
223 dv_ = dv;
224 segment_index_ = segment_index;
225 size_ = size;
226 reserved_ = reserved;
227 assert(dv_ != nullptr);
228 }
229
230 auto size() const {
231 assert(dv_ != nullptr);
232 return size_;
233 }
234
235 auto begin() const { return iterator(dv_, segment_index_, 0); }
236 auto end() const { return begin() + size(); }
237 auto reserved() const { return reserved_; }
238
239 auto operator[](difference_type n) const { return *(begin() + n); }
240
241 bool is_local() const { return segment_index_ == default_comm().rank(); }
242
243private:
244 DV *dv_ = nullptr;
245 std::size_t segment_index_;
246 std::size_t size_;
247 std::size_t reserved_;
248}; // dv_segment
249
//
// Many views preserve the distributed_vector segment iterator, which can
// supply the halo.
//
254template <typename DR>
255concept has_halo_method = dr::distributed_range<DR> && requires(DR &&dr) {
256 { rng::begin(dr::ranges::segments(dr)[0]).halo() };
257};
258
259auto &halo(has_halo_method auto &&dr) {
260 return rng::begin(dr::ranges::segments(dr)[0]).halo();
261}
262
263} // namespace dr::mp
Definition: index.hpp:34
Definition: segment.hpp:34
Definition: segment.hpp:11
Definition: segment.hpp:214
Definition: concepts.hpp:20
Definition: segment.hpp:255
Definition: halo.hpp:362