8template <
typename DSM>
class csr_row_segment_iterator;
14 using value_type =
typename DSM::value_type;
15 using index_type =
typename DSM::index_type;
16 using elem_type =
typename DSM::elem_type;
20 operator value_type()
const {
return iterator_.get(); }
21 operator std::pair<std::pair<index_type, index_type>, elem_type>()
const {
22 return iterator_.get();
25 template <std::
size_t Index>
auto get()
const noexcept {
26 if constexpr (Index == 0) {
27 return iterator_.get_index();
29 if constexpr (Index == 1) {
30 return iterator_.get_value();
35 *
this = value_type(other);
38 auto operator&()
const {
return iterator_; }
46 using value_type =
typename DSM::value_type;
47 using index_type =
typename DSM::index_type;
48 using elem_type =
typename DSM::elem_type;
49 using difference_type =
typename DSM::difference_type;
55 segment_index_ = segment_index;
63 assert(dsm_ == other.dsm_);
64 return segment_index_ == other.segment_index_
65 ? index_ <=> other.index_
66 : segment_index_ <=> other.segment_index_;
71 return (*this <=> other) == 0;
75 auto &operator+=(difference_type n) {
76 assert(dsm_ !=
nullptr);
77 assert(n >= 0 ||
static_cast<difference_type
>(index_) >= -n);
82 auto &operator-=(difference_type n) {
return *
this += (-n); }
86 assert(dsm_ !=
nullptr && dsm_ == other.dsm_);
87 assert(index_ >= other.index_);
88 return index_ - other.index_;
102 auto operator++(
int) {
107 auto operator--(
int) {
113 auto operator+(difference_type n)
const {
118 auto operator-(difference_type n)
const {
125 friend auto operator+(difference_type n,
131 auto operator*()
const {
132 assert(dsm_ !=
nullptr);
135 auto operator[](difference_type n)
const {
136 assert(dsm_ !=
nullptr);
140 void get(value_type *dst, std::size_t size)
const {
141 auto elems =
new elem_type[size];
143 get_value(elems, size);
144 get_index(indexes, size);
145 for (std::size_t i = 0; i < size; i++) {
146 *(dst + i) = {indexes[i], elems[i]};
150 value_type get()
const {
156 void get_value(elem_type *dst, std::size_t size)
const {
157 assert(dsm_ !=
nullptr);
158 assert(segment_index_ * dsm_->segment_size_ + index_ < dsm_->nnz_);
159 dsm_->vals_backend_.getmem(dst, index_ *
sizeof(elem_type),
160 size *
sizeof(elem_type), segment_index_);
163 elem_type get_value()
const {
170 assert(dsm_ !=
nullptr);
171 assert(segment_index_ * dsm_->segment_size_ + index_ < dsm_->nnz_);
172 index_type *col_data;
173 if (rank() == dsm_->cols_backend_.getrank()) {
174 col_data = dsm_->cols_data_ + index_;
176 col_data =
new index_type[size];
177 dsm_->cols_backend_.getmem(col_data, index_ *
sizeof(index_type),
178 size *
sizeof(index_type), segment_index_);
181 std::size_t rows_length = dsm_->segment_size_;
182 rows =
new index_type[rows_length];
183 (dsm_->rows_data_->segments()[segment_index_].begin())
184 .get(rows, rows_length);
186 auto position = dsm_->val_offsets_[segment_index_] + index_;
187 auto rows_iter = rows + 1;
188 index_type *cols_iter = col_data;
190 std::size_t current_row = dsm_->segment_size_ * segment_index_;
191 std::size_t last_row =
192 std::min(current_row + rows_length - 1, dsm_->shape_[0] - 1);
194 for (
int i = 0; i < size; i++) {
195 while (current_row < last_row && *rows_iter <= position + i) {
199 iter->first = current_row;
200 iter->second = *cols_iter;
204 if (rank() != dsm_->cols_backend_.getrank()) {
217 assert(dsm_ !=
nullptr);
218 return segment_index_;
221 auto segments()
const {
222 assert(dsm_ !=
nullptr);
223 return dr::__detail::drop_segments(dsm_->segments(), segment_index_,
228 const auto my_process_segment_index = dsm_->vals_backend_.getrank();
229 assert(my_process_segment_index == segment_index_);
230 if (dsm_->local_view ==
nullptr) {
231 throw std::runtime_error(
"Requesting not existing local segment");
233 return dsm_->local_view->begin();
240 std::size_t segment_index_ = 0;
241 std::size_t index_ = 0;
249 using difference_type = std::ptrdiff_t;
252 std::size_t reserved) {
254 segment_index_ = segment_index;
256 reserved_ = reserved;
257 assert(dsm_ !=
nullptr);
261 assert(dsm_ !=
nullptr);
265 auto begin()
const {
return iterator(dsm_, segment_index_, 0); }
266 auto end()
const {
return begin() + size(); }
267 auto reserved()
const {
return reserved_; }
269 auto operator[](difference_type n)
const {
return *(begin() + n); }
271 bool is_local()
const {
return segment_index_ == default_comm().rank(); }
275 std::size_t segment_index_;
277 std::size_t reserved_;
283template <
typename DSM>
284struct tuple_size<dr::mp::csr_row_segment_reference<DSM>>
285 : std::integral_constant<std::size_t, 2> {};
287template <std::
size_t Index,
typename DSM>
288struct tuple_element<Index, dr::mp::csr_row_segment_reference<DSM>>
289 : tuple_element<Index, std::tuple<dr::index<typename DSM::index_type>,
290 typename DSM::elem_type>> {};
Definition: csr_row_segment.hpp:44
Definition: csr_row_segment.hpp:10
Definition: csr_row_segment.hpp:244