Distributed Ranges
Loading...
Searching...
No Matches
csr_eq_segment.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7namespace dr::mp {
8
// Forward declaration: csr_eq_segment_reference stores an iterator by value,
// while the iterator's operator* returns a csr_eq_segment_reference.
template <typename DSM> class csr_eq_segment_iterator;

/// Proxy object returned when a csr_eq_segment_iterator is dereferenced.
///
/// The proxy keeps a copy of the iterator it was created from and forwards
/// every read and write to that iterator:
///  - converting to `value_type` (or to the nested pair type) calls the
///    iterator's `get()`;
///  - `get<0>()` / `get<1>()` call the iterator's `get_index()` /
///    `get_value()`, which is what the `std::tuple_size` / `std::tuple_element`
///    specializations for this class rely on;
///  - assignment calls the iterator's `put()`.
///
/// @tparam DSM container type providing `value_type`, `index_type` and
///             `elem_type`.
template <typename DSM> class csr_eq_segment_reference {
  // Iterator type this proxy wraps.
  using iterator = csr_eq_segment_iterator<DSM>;

public:
  using value_type = typename DSM::value_type;
  using index_type = typename DSM::index_type;
  using elem_type = typename DSM::elem_type;

  /// Wraps a copy of `it`; intentionally implicit so an iterator can be
  /// brace-converted into its reference proxy.
  csr_eq_segment_reference(const iterator it) : iterator_(it) {}

  /// Reads the referenced entry through the iterator.
  operator value_type() const { return iterator_.get(); }

  /// Reads the referenced entry and converts it to a
  /// ((index, index), element) pair.
  operator std::pair<std::pair<index_type, index_type>, elem_type>() const {
    return iterator_.get();
  }

  /// Tuple-like access: `Index == 0` yields the entry's index,
  /// `Index == 1` yields the entry's value. For any other `Index` no branch
  /// is taken and the deduced return type is `void`.
  template <std::size_t Index> auto get() const noexcept {
    if constexpr (Index == 0) {
      return iterator_.get_index();
    }
    if constexpr (Index == 1) {
      return iterator_.get_value();
    }
  }

  /// Writes `value` through the iterator and returns a copy of this proxy.
  /// The operator is `const` because the proxy itself is not modified; only
  /// the storage the iterator points at is.
  auto operator=(const value_type &value) const {
    iterator_.put(value);
    return *this;
  }

  /// Copies the entry referenced by `other` into the entry referenced by
  /// `*this` (read through `other`, then write through `*this`).
  auto operator=(const csr_eq_segment_reference &other) const {
    *this = value_type(other);
    return *this;
  }

  /// Address-of yields the underlying iterator rather than a raw pointer.
  auto operator&() const { return iterator_; }

private:
  const iterator iterator_;
}; // csr_eq_segment_reference
48
49template <typename DSM> class csr_eq_segment_iterator {
50public:
51 using value_type = typename DSM::value_type;
52 using index_type = typename DSM::index_type;
53 using elem_type = typename DSM::elem_type;
54 using difference_type = typename DSM::difference_type;
55
56 csr_eq_segment_iterator() = default;
57 csr_eq_segment_iterator(DSM *dsm, std::size_t segment_index,
58 std::size_t index) {
59 dsm_ = dsm;
60 segment_index_ = segment_index;
61 index_ = index;
62 }
63
64 auto operator<=>(const csr_eq_segment_iterator &other) const noexcept {
65 // assertion below checks against compare dereferenceable iterator to a
66 // singular iterator and against attempt to compare iterators from different
67 // sequences like _Safe_iterator<gnu_cxx::normal_iterator> does
68 assert(dsm_ == other.dsm_);
69 return segment_index_ == other.segment_index_
70 ? index_ <=> other.index_
71 : segment_index_ <=> other.segment_index_;
72 }
73
74 // Comparison
75 bool operator==(const csr_eq_segment_iterator &other) const noexcept {
76 return (*this <=> other) == 0;
77 }
78
79 // Only this arithmetic manipulate internal state
80 auto &operator+=(difference_type n) {
81 assert(dsm_ != nullptr);
82 assert(n >= 0 || static_cast<difference_type>(index_) >= -n);
83 index_ += n;
84 return *this;
85 }
86
87 auto &operator-=(difference_type n) { return *this += (-n); }
88
89 difference_type
90 operator-(const csr_eq_segment_iterator &other) const noexcept {
91 assert(dsm_ != nullptr && dsm_ == other.dsm_);
92 assert(index_ >= other.index_);
93 return index_ - other.index_;
94 }
95
96 // prefix
97 auto &operator++() {
98 *this += 1;
99 return *this;
100 }
101 auto &operator--() {
102 *this -= 1;
103 return *this;
104 }
105
106 // postfix
107 auto operator++(int) {
108 auto prev = *this;
109 *this += 1;
110 return prev;
111 }
112 auto operator--(int) {
113 auto prev = *this;
114 *this -= 1;
115 return prev;
116 }
117
118 auto operator+(difference_type n) const {
119 auto p = *this;
120 p += n;
121 return p;
122 }
123 auto operator-(difference_type n) const {
124 auto p = *this;
125 p -= n;
126 return p;
127 }
128
129 // When *this is not first in the expression
130 friend auto operator+(difference_type n,
131 const csr_eq_segment_iterator &other) {
132 return other + n;
133 }
134
135 // dereference
136 auto operator*() const {
137 assert(dsm_ != nullptr);
138 return csr_eq_segment_reference<DSM>{*this};
139 }
140 auto operator[](difference_type n) const {
141 assert(dsm_ != nullptr);
142 return *(*this + n);
143 }
144
145 void get(value_type *dst, std::size_t size) const {
146 auto elems = new elem_type[size];
147 auto indexes = new dr::index<index_type>[size];
148 get_value(elems, size);
149 get_index(indexes, size);
150 for (std::size_t i = 0; i < size; i++) {
151 *(dst + i) = {indexes[i], elems[i]};
152 }
153 }
154
155 value_type get() const {
156 value_type val;
157 get(&val, 1);
158 return val;
159 }
160
161 void get_value(elem_type *dst, std::size_t size) const {
162 assert(dsm_ != nullptr);
163 assert(segment_index_ * dsm_->segment_size_ + index_ < dsm_->nnz_);
164 (dsm_->vals_data_->segments()[segment_index_].begin() + index_)
165 .get(dst, size);
166 }
167
168 elem_type get_value() const {
169 elem_type val;
170 get_value(&val, 1);
171 return val;
172 }
173
174 void get_index(dr::index<index_type> *dst, std::size_t size) const {
175 assert(dsm_ != nullptr);
176 assert(segment_index_ * dsm_->segment_size_ + index_ < dsm_->nnz_);
177 auto col_data = new index_type[size];
178 (dsm_->cols_data_->segments()[segment_index_].begin() + index_)
179 .get(col_data, size);
180 index_type *rows;
181 std::size_t rows_length = dsm_->get_row_size(segment_index_);
182
183 if (rank() == dsm_->rows_backend_.getrank()) {
184 rows = dsm_->rows_data_;
185 } else {
186 rows = new index_type[rows_length];
187 dsm_->rows_backend_.getmem(rows, 0, rows_length * sizeof(index_type),
188 segment_index_);
189 }
190 auto position =
191 dsm_->cols_data_->get_segment_offset(segment_index_) + index_;
192 auto rows_iter = rows + 1;
193 auto cols_iter = col_data;
194 auto iter = dst;
195 std::size_t current_row = dsm_->row_offsets_[segment_index_];
196 std::size_t last_row = current_row + rows_length - 1;
197 for (int i = 0; i < size; i++) {
198 while (current_row < last_row && *rows_iter <= position + i) {
199 rows_iter++;
200 current_row++;
201 }
202 iter->first = current_row;
203 iter->second = *cols_iter;
204 cols_iter++;
205 iter++;
206 }
207 if (rank() != dsm_->rows_backend_.getrank()) {
208 delete[] rows;
209 }
210 delete[] col_data;
211 }
212
213 dr::index<index_type> get_index() const {
215 get_index(&val, 1);
216 return val;
217 }
218
219 void put(const value_type *dst, std::size_t size) const {
220 assert(dsm_ != nullptr);
221 assert(segment_index_ * dsm_->segment_size_ + index_ < dsm_->nnz_);
222 (dsm_->vals_data_->segments()[segment_index_].begin() + index_)
223 .put(dst, size);
224 }
225
226 void put(const value_type &value) const { put(&value, 1); }
227
228 auto rank() const {
229 assert(dsm_ != nullptr);
230 return segment_index_;
231 }
232
233 auto segments() const {
234 assert(dsm_ != nullptr);
235 return dr::__detail::drop_segments(dsm_->segments(), segment_index_,
236 index_);
237 }
238
239 auto local() const {
240 const auto my_process_segment_index = dsm_->rows_backend_.getrank();
241 assert(my_process_segment_index == segment_index_);
242 if (dsm_->local_view == nullptr) {
243 throw std::runtime_error("Requesting not existing local segment");
244 }
245 return dsm_->local_view->begin();
246 }
247
248private:
249 // all fields need to be initialized by default ctor so every default
250 // constructed iter is equal to any other default constructed iter
251 DSM *dsm_ = nullptr;
252 std::size_t segment_index_ = 0;
253 std::size_t index_ = 0;
254}; // csr_eq_segment_iterator
255
256template <typename DSM> class csr_eq_segment {
257private:
259
260public:
261 using difference_type = std::ptrdiff_t;
262 csr_eq_segment() = default;
263 csr_eq_segment(DSM *dsm, std::size_t segment_index, std::size_t size,
264 std::size_t reserved) {
265 dsm_ = dsm;
266 segment_index_ = segment_index;
267 size_ = size;
268 reserved_ = reserved;
269 assert(dsm_ != nullptr);
270 }
271
272 auto size() const {
273 assert(dsm_ != nullptr);
274 return size_;
275 }
276
277 auto begin() const { return iterator(dsm_, segment_index_, 0); }
278 auto end() const { return begin() + size(); }
279 auto reserved() const { return reserved_; }
280
281 auto operator[](difference_type n) const { return *(begin() + n); }
282
283 bool is_local() const { return segment_index_ == default_comm().rank(); }
284
285private:
286 DSM *dsm_ = nullptr;
287 std::size_t segment_index_;
288 std::size_t size_;
289 std::size_t reserved_;
290}; // csr_eq_segment
291
292} // namespace dr::mp
293
294namespace std {
295template <typename DSM>
296struct tuple_size<dr::mp::csr_eq_segment_reference<DSM>>
297 : std::integral_constant<std::size_t, 2> {};
298
299template <std::size_t Index, typename DSM>
300struct tuple_element<Index, dr::mp::csr_eq_segment_reference<DSM>>
301 : tuple_element<Index, std::tuple<dr::index<typename DSM::index_type>,
302 typename DSM::elem_type>> {};
303
304} // namespace std
Definition: index.hpp:34
Definition: csr_eq_segment.hpp:49
Definition: csr_eq_segment.hpp:11
Definition: csr_eq_segment.hpp:256