Distributed Ranges
Loading...
Searching...
No Matches
csr_row_segment.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
#pragma once

#include <algorithm>
#include <cassert>
#include <compare>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
6
7namespace dr::mp {
template <typename DSM> class csr_row_segment_iterator;

/// Read-only proxy returned by dereferencing a csr_row_segment_iterator.
///
/// Holds a copy of the iterator and materialises the entry on demand by
/// calling the iterator's get()/get_index()/get_value(). Supports structured
/// bindings (see the std::tuple_size / std::tuple_element specializations):
/// element 0 is the entry's index, element 1 its stored value.
template <typename DSM> class csr_row_segment_reference {
  using iterator = csr_row_segment_iterator<DSM>;

public:
  using value_type = typename DSM::value_type;
  using index_type = typename DSM::index_type;
  using elem_type = typename DSM::elem_type;

  csr_row_segment_reference(const iterator it) : iterator_(it) {}

  /// Reads the whole entry through the iterator.
  operator value_type() const { return iterator_.get(); }
  /// Reads the entry as ((row, column), value); relies on value_type being
  /// convertible to this pair type.
  operator std::pair<std::pair<index_type, index_type>, elem_type>() const {
    return iterator_.get();
  }

  /// Structured-binding accessor: Index 0 -> iterator_.get_index(),
  /// Index 1 -> iterator_.get_value().
  /// NOTE(review): the iterator's reads allocate / may go remote; because
  /// this is noexcept, any exception escaping them calls std::terminate.
  template <std::size_t Index> auto get() const noexcept {
    static_assert(Index < 2,
                  "csr_row_segment_reference has exactly two elements");
    if constexpr (Index == 0) {
      return iterator_.get_index();
    } else {
      return iterator_.get_value();
    }
  }

  // Write-through assignment is not supported: the iterator offers no way
  // to store an entry. The previous body (`*this = value_type(other)`) had
  // no viable overload (the only converting constructor takes an iterator),
  // so it failed to compile whenever it was used. Deleting it makes the
  // "not assignable" property explicit and SFINAE-visible.
  csr_row_segment_reference &
  operator=(const csr_row_segment_reference &other) const = delete;

  /// "Address" of the proxied element is the iterator itself.
  auto operator&() const { return iterator_; }

private:
  const iterator iterator_;
}; // csr_row_segment_reference
43
44template <typename DSM> class csr_row_segment_iterator {
45public:
46 using value_type = typename DSM::value_type;
47 using index_type = typename DSM::index_type;
48 using elem_type = typename DSM::elem_type;
49 using difference_type = typename DSM::difference_type;
50
51 csr_row_segment_iterator() = default;
52 csr_row_segment_iterator(DSM *dsm, std::size_t segment_index,
53 std::size_t index) {
54 dsm_ = dsm;
55 segment_index_ = segment_index;
56 index_ = index;
57 }
58
59 auto operator<=>(const csr_row_segment_iterator &other) const noexcept {
60 // assertion below checks against compare dereferenceable iterator to a
61 // singular iterator and against attempt to compare iterators from different
62 // sequences like _Safe_iterator<gnu_cxx::normal_iterator> does
63 assert(dsm_ == other.dsm_);
64 return segment_index_ == other.segment_index_
65 ? index_ <=> other.index_
66 : segment_index_ <=> other.segment_index_;
67 }
68
69 // Comparison
70 bool operator==(const csr_row_segment_iterator &other) const noexcept {
71 return (*this <=> other) == 0;
72 }
73
74 // Only this arithmetic manipulate internal state
75 auto &operator+=(difference_type n) {
76 assert(dsm_ != nullptr);
77 assert(n >= 0 || static_cast<difference_type>(index_) >= -n);
78 index_ += n;
79 return *this;
80 }
81
82 auto &operator-=(difference_type n) { return *this += (-n); }
83
84 difference_type
85 operator-(const csr_row_segment_iterator &other) const noexcept {
86 assert(dsm_ != nullptr && dsm_ == other.dsm_);
87 assert(index_ >= other.index_);
88 return index_ - other.index_;
89 }
90
91 // prefix
92 auto &operator++() {
93 *this += 1;
94 return *this;
95 }
96 auto &operator--() {
97 *this -= 1;
98 return *this;
99 }
100
101 // postfix
102 auto operator++(int) {
103 auto prev = *this;
104 *this += 1;
105 return prev;
106 }
107 auto operator--(int) {
108 auto prev = *this;
109 *this -= 1;
110 return prev;
111 }
112
113 auto operator+(difference_type n) const {
114 auto p = *this;
115 p += n;
116 return p;
117 }
118 auto operator-(difference_type n) const {
119 auto p = *this;
120 p -= n;
121 return p;
122 }
123
124 // When *this is not first in the expression
125 friend auto operator+(difference_type n,
126 const csr_row_segment_iterator &other) {
127 return other + n;
128 }
129
130 // dereference
131 auto operator*() const {
132 assert(dsm_ != nullptr);
133 return csr_row_segment_reference<DSM>{*this};
134 }
135 auto operator[](difference_type n) const {
136 assert(dsm_ != nullptr);
137 return *(*this + n);
138 }
139
140 void get(value_type *dst, std::size_t size) const {
141 auto elems = new elem_type[size];
142 auto indexes = new dr::index<index_type>[size];
143 get_value(elems, size);
144 get_index(indexes, size);
145 for (std::size_t i = 0; i < size; i++) {
146 *(dst + i) = {indexes[i], elems[i]};
147 }
148 }
149
150 value_type get() const {
151 value_type val;
152 get(&val, 1);
153 return val;
154 }
155
156 void get_value(elem_type *dst, std::size_t size) const {
157 assert(dsm_ != nullptr);
158 assert(segment_index_ * dsm_->segment_size_ + index_ < dsm_->nnz_);
159 dsm_->vals_backend_.getmem(dst, index_ * sizeof(elem_type),
160 size * sizeof(elem_type), segment_index_);
161 }
162
163 elem_type get_value() const {
164 elem_type val;
165 get_value(&val, 1);
166 return val;
167 }
168
169 void get_index(dr::index<index_type> *dst, std::size_t size) const {
170 assert(dsm_ != nullptr);
171 assert(segment_index_ * dsm_->segment_size_ + index_ < dsm_->nnz_);
172 index_type *col_data;
173 if (rank() == dsm_->cols_backend_.getrank()) {
174 col_data = dsm_->cols_data_ + index_;
175 } else {
176 col_data = new index_type[size];
177 dsm_->cols_backend_.getmem(col_data, index_ * sizeof(index_type),
178 size * sizeof(index_type), segment_index_);
179 }
180 index_type *rows;
181 std::size_t rows_length = dsm_->segment_size_;
182 rows = new index_type[rows_length];
183 (dsm_->rows_data_->segments()[segment_index_].begin())
184 .get(rows, rows_length);
185
186 auto position = dsm_->val_offsets_[segment_index_] + index_;
187 auto rows_iter = rows + 1;
188 index_type *cols_iter = col_data;
189 auto iter = dst;
190 std::size_t current_row = dsm_->segment_size_ * segment_index_;
191 std::size_t last_row =
192 std::min(current_row + rows_length - 1, dsm_->shape_[0] - 1);
193
194 for (int i = 0; i < size; i++) {
195 while (current_row < last_row && *rows_iter <= position + i) {
196 rows_iter++;
197 current_row++;
198 }
199 iter->first = current_row;
200 iter->second = *cols_iter;
201 cols_iter++;
202 iter++;
203 }
204 if (rank() != dsm_->cols_backend_.getrank()) {
205 delete[] col_data;
206 }
207 delete[] rows;
208 }
209
210 dr::index<index_type> get_index() const {
212 get_index(&val, 1);
213 return val;
214 }
215
216 auto rank() const {
217 assert(dsm_ != nullptr);
218 return segment_index_;
219 }
220
221 auto segments() const {
222 assert(dsm_ != nullptr);
223 return dr::__detail::drop_segments(dsm_->segments(), segment_index_,
224 index_);
225 }
226
227 auto local() const {
228 const auto my_process_segment_index = dsm_->vals_backend_.getrank();
229 assert(my_process_segment_index == segment_index_);
230 if (dsm_->local_view == nullptr) {
231 throw std::runtime_error("Requesting not existing local segment");
232 }
233 return dsm_->local_view->begin();
234 }
235
236private:
237 // all fields need to be initialized by default ctor so every default
238 // constructed iter is equal to any other default constructed iter
239 DSM *dsm_ = nullptr;
240 std::size_t segment_index_ = 0;
241 std::size_t index_ = 0;
242}; // csr_row_segment_iterator
243
244template <typename DSM> class csr_row_segment {
245private:
247
248public:
249 using difference_type = std::ptrdiff_t;
250 csr_row_segment() = default;
251 csr_row_segment(DSM *dsm, std::size_t segment_index, std::size_t size,
252 std::size_t reserved) {
253 dsm_ = dsm;
254 segment_index_ = segment_index;
255 size_ = size;
256 reserved_ = reserved;
257 assert(dsm_ != nullptr);
258 }
259
260 auto size() const {
261 assert(dsm_ != nullptr);
262 return size_;
263 }
264
265 auto begin() const { return iterator(dsm_, segment_index_, 0); }
266 auto end() const { return begin() + size(); }
267 auto reserved() const { return reserved_; }
268
269 auto operator[](difference_type n) const { return *(begin() + n); }
270
271 bool is_local() const { return segment_index_ == default_comm().rank(); }
272
273private:
274 DSM *dsm_ = nullptr;
275 std::size_t segment_index_;
276 std::size_t size_;
277 std::size_t reserved_;
278}; // csr_row_segment
279
280} // namespace dr::mp
281
282namespace std {
283template <typename DSM>
284struct tuple_size<dr::mp::csr_row_segment_reference<DSM>>
285 : std::integral_constant<std::size_t, 2> {};
286
287template <std::size_t Index, typename DSM>
288struct tuple_element<Index, dr::mp::csr_row_segment_reference<DSM>>
289 : tuple_element<Index, std::tuple<dr::index<typename DSM::index_type>,
290 typename DSM::elem_type>> {};
291
292} // namespace std
Definition: index.hpp:34
Definition: csr_row_segment.hpp:44
Definition: csr_row_segment.hpp:10
Definition: csr_row_segment.hpp:244