Distributed Ranges
Loading...
Searching...
No Matches
subrange.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/detail/ranges_shim.hpp>
8
9namespace dr::mp {
10
11template <typename DM> class subrange_iterator {
12public:
13 using value_type = typename DM::value_type;
14 using difference_type = typename DM::difference_type;
15
17
18 subrange_iterator(DM *dm, std::pair<std::size_t, std::size_t> row_rng,
19 std::pair<std::size_t, std::size_t> col_rng,
20 difference_type index = 0) noexcept {
21 dm_ = dm;
22 row_rng_ = row_rng;
23 col_rng_ = col_rng;
24 index_ = index;
25 }
26
27 value_type &operator*() const {
28
29 int offset = dm_->halo_bounds().prev + find_dm_offset(index_) -
30 default_comm().rank() * dm_->segment_size();
31
32 assert(offset >= 0);
33 assert(offset < (int)dm_->data_size());
34 return *(dm_->data() + offset);
35 }
36
37 dm_row<value_type> operator[](int n) {
38 std::size_t rowsize = col_rng_.second - col_rng_.first;
39
40 int offset = dm_->halo_bounds().prev +
41 find_dm_offset((int)(index_ + n * rowsize)) -
42 default_comm().rank() * dm_->segment_size();
43
44 assert(offset >= 0);
45 assert(offset < (int)dm_->data_size());
46
47 signed long idx = default_comm().rank() * dm_->segment_shape()[0]; // ??
48 value_type *ptr = dm_->data() + offset;
49 const dv_segment<DM> *segment = &(dm_->segments()[0]); // comm rank ??
50
51 return dm_row<value_type>(idx, ptr, rowsize, segment);
52 }
53
54 value_type &operator[](std::pair<int, int> p) {
55 int offset = dm_->distribution_.halo().prev + find_dm_offset(index_) -
56 default_comm().rank() * dm_->segment_size() +
57 dm_->shape()[1] * p.first + p.second;
58
59 assert(offset >= 0);
60 assert(offset < (int)dm_->data_size());
61 return *(dm_->data() + offset);
62 }
63
64 // friend operators fulfill rng::detail::weakly_equality_comparable_with_
65 friend bool operator==(subrange_iterator &first, subrange_iterator &second) {
66 return first.index_ == second.index_;
67 }
68 friend bool operator!=(subrange_iterator &first, subrange_iterator &second) {
69 return first.index_ != second.index_;
70 }
71 friend bool operator==(subrange_iterator first, subrange_iterator second) {
72 return first.index_ == second.index_;
73 }
74 friend bool operator!=(subrange_iterator first, subrange_iterator second) {
75 return first.index_ != second.index_;
76 }
77 auto operator<=>(const subrange_iterator &other) const noexcept {
78 return this->index_ <=> other.index_;
79 }
80
81 // Only these arithmetic manipulate internal state
82 auto &operator-=(difference_type n) {
83 index_ -= n;
84 return *this;
85 }
86 auto &operator+=(difference_type n) {
87 index_ += n;
88 return *this;
89 }
90
91 difference_type operator-(const subrange_iterator &other) const noexcept {
92 return index_ - other.index_;
93 }
94 // prefix
95 auto &operator++() {
96 index_ += 1;
97 return *this;
98 }
99 auto &operator--() {
100 index_ -= 1;
101 return *this;
102 }
103
104 // postfix
105 auto operator++(int) {
106 auto prev = *this;
107 index_ += 1;
108 return prev;
109 }
110 auto operator--(int) {
111 auto prev = *this;
112 index_ -= 1;
113 return prev;
114 }
115
116 auto operator+(difference_type n) const {
117 return subrange_iterator(dm_, row_rng_, col_rng_, index_ + n);
118 }
119 auto operator-(difference_type n) const {
120 return subrange_iterator(dm_, row_rng_, col_rng_, index_ - n);
121 }
122
123 // When *this is not first in the expression
124 friend auto operator+(difference_type n, const subrange_iterator &other) {
125 return other + n;
126 }
127
128 auto &halo() { return dm_->halo(); }
129 auto segments() { return dm_->segments(); }
130
131 bool is_local() { return dm_->is_local_cell(find_dm_offset(index_)); }
132
133 // for debug purposes
134 std::size_t find_dm_offset() const { return find_dm_offset(index_); }
135
136private:
137 /*
138 * converts index within subrange (viewed as linear contiguous space)
139 * into index within physical segment in dm
140 */
141 std::size_t find_dm_offset(int index) const {
142 int ind_rows, ind_cols;
143 int offset = 0;
144
145 ind_rows = index / (col_rng_.second - col_rng_.first);
146 ind_cols = index % (col_rng_.second - col_rng_.first);
147
148 if (ind_cols < 0) {
149 ind_rows -= 1;
150 ind_cols += (col_rng_.second - col_rng_.first);
151 }
152
153 offset += row_rng_.first * dm_->shape()[1] + col_rng_.first;
154 offset += (int)(ind_rows * dm_->shape()[1] + ind_cols);
155
156 return offset;
157 };
158
159private:
160 DM *dm_ = nullptr;
161 std::pair<int, int> row_rng_ = std::pair<int, int>(0, 0);
162 std::pair<int, int> col_rng_ = std::pair<int, int>(0, 0);
163
164 std::size_t index_ = 0;
165}; // class subrange_iterator
166
167template <typename DM>
168class subrange : public rng::view_interface<subrange<DM>> {
169public:
171 using value_type = typename DM::value_type;
172
173 subrange(DM &dm, std::pair<std::size_t, std::size_t> row_rng,
174 std::pair<std::size_t, std::size_t> col_rng) {
175 dm_ = &dm;
176 row_rng_ = row_rng;
177 col_rng_ = col_rng;
178
179 subrng_size_ =
180 (col_rng.second - col_rng.first) * (row_rng.second - row_rng.first);
181 }
182
183 iterator begin() const { return iterator(dm_, row_rng_, col_rng_); }
184 iterator end() const { return begin() + subrng_size_; }
185
186 auto size() { return subrng_size_; }
187
188 auto &halo() const { return dm_->halo(); }
189 auto segments() const { return dm_->segments(); }
190
191private:
192 DM *dm_;
193 std::pair<std::size_t, std::size_t> row_rng_;
194 std::pair<std::size_t, std::size_t> col_rng_;
195
196 std::size_t subrng_size_ = 0;
197
198}; // class subrange
199
200} // namespace dr::mp
Definition: index.hpp:34
Definition: segment.hpp:214
Definition: subrange.hpp:11
Definition: subrange.hpp:168