Distributed Ranges
Loading...
Searching...
No Matches
distributed_dense_matrix.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <memory>
8
9#include <dr/detail/index.hpp>
10#include <dr/detail/matrix_entry.hpp>
11#include <dr/detail/owning_view.hpp>
12#include <dr/sp/containers/matrix_partition.hpp>
13#include <dr/sp/containers/sequential/dense_matrix.hpp>
14#include <dr/sp/device_vector.hpp>
15#include <dr/sp/future.hpp>
16#include <dr/sp/views/dense_matrix_view.hpp>
17
18namespace dr::sp {
19
20template <typename T, typename L> class distributed_dense_matrix_accessor {
21public:
22 using size_type = std::size_t;
23 using difference_type = std::ptrdiff_t;
24
25 using scalar_value_type = rng::range_value_t<L>;
26 using scalar_reference = rng::range_reference_t<L>;
27
29
31
32 using iterator_category = std::random_access_iterator_tag;
33
37
38 using tile_type = L;
39
40 using key_type = dr::index<>;
41
42 constexpr distributed_dense_matrix_accessor() noexcept = default;
43 constexpr ~distributed_dense_matrix_accessor() noexcept = default;
45 const distributed_dense_matrix_accessor &) noexcept = default;
47 operator=(const distributed_dense_matrix_accessor &) noexcept = default;
48
50 std::span<tile_type> tiles, key_type grid_idx, key_type tile_idx,
51 key_type grid_shape, key_type tile_shape, key_type matrix_shape) noexcept
52 : grid_idx_(grid_idx), tile_idx_(tile_idx), grid_shape_(grid_shape),
53 tile_shape_(tile_shape), matrix_shape_(matrix_shape), tiles_(tiles) {}
54
56 operator+=(difference_type offset) noexcept {
57 std::size_t new_global_idx_ = get_global_idx_() + offset;
58 key_type new_global_idx = {new_global_idx_ / matrix_shape_[1],
59 new_global_idx_ % matrix_shape_[1]};
60 key_type new_grid_idx = {new_global_idx[0] / tile_shape_[0],
61 new_global_idx[1] / tile_shape_[1]};
62
63 key_type new_tile_idx = {new_global_idx[0] % tile_shape_[0],
64 new_global_idx[1] % tile_shape_[1]};
65
66 grid_idx_ = new_grid_idx;
67 tile_idx_ = new_tile_idx;
68 return *this;
69 }
70
71 constexpr bool operator==(const iterator_accessor &other) const noexcept {
72 return grid_idx_ == other.grid_idx_ && tile_idx_ == other.tile_idx_;
73 }
74
75 constexpr difference_type
76 operator-(const iterator_accessor &other) const noexcept {
77 return difference_type(get_global_idx_()) - other.get_global_idx_();
78 }
79
80 constexpr bool operator<(const iterator_accessor &other) const noexcept {
81 if (get_grid_idx() < other.get_grid_idx()) {
82 return true;
83 } else if (get_grid_idx() == other.get_grid_idx()) {
84 return get_local_idx() < other.get_local_idx();
85 } else {
86 return false;
87 }
88 }
89
90 constexpr reference operator*() const noexcept {
91 auto &&tile = tiles_[get_grid_idx()];
92 auto &&value = tile[get_local_idx()];
93 key_type idx = {tile_idx_[0] + grid_idx_[0] * tile_shape_[0],
94 tile_idx_[1] + grid_idx_[1] * tile_shape_[1]};
95 return reference(idx, value);
96 }
97
98private:
99 size_type get_global_idx_() const noexcept {
100 auto gidx = get_global_idx();
101 return gidx[0] * matrix_shape_[1] + gidx[1];
102 }
103
104 key_type get_global_idx() const noexcept {
105 return {grid_idx_[0] * tile_shape_[0] + tile_idx_[0],
106 grid_idx_[1] * tile_shape_[1] + tile_idx_[1]};
107 }
108
109 size_type get_grid_idx() const noexcept {
110 return grid_idx_[0] * grid_shape_[1] + grid_idx_[1];
111 }
112
113 size_type get_local_idx() const noexcept {
114 return tile_idx_[0] * tile_shape_[1] + tile_idx_[1];
115 }
116
117 size_type get_tile_size() const noexcept {
118 return tile_shape_[0] * tile_shape_[1];
119 }
120
121private:
122 key_type grid_idx_;
123 key_type tile_idx_;
124
125 key_type grid_shape_;
126 key_type tile_shape_;
127 key_type matrix_shape_;
128
129 std::span<tile_type> tiles_;
130};
131
132template <typename T, typename L>
135
136template <typename T> class distributed_dense_matrix {
137public:
138 using size_type = std::size_t;
139 using difference_type = std::ptrdiff_t;
140
142
143 using scalar_reference = rng::range_reference_t<
145 using const_scalar_reference = rng::range_reference_t<
147
150
151 using key_type = dr::index<>;
152
155
157 : shape_(shape), partition_(new dr::sp::block_cyclic()) {
158 init_();
159 }
160
161 distributed_dense_matrix(key_type shape, const matrix_partition &partition)
162 : shape_(shape), partition_(partition.clone()) {
163 init_();
164 }
165
166 size_type size() const noexcept { return shape()[0] * shape()[1]; }
167
168 key_type shape() const noexcept { return shape_; }
169
170 scalar_reference operator[](key_type index) {
171 std::size_t tile_i = index[0] / tile_shape_[0];
172 std::size_t tile_j = index[1] / tile_shape_[1];
173
174 std::size_t local_i = index[0] % tile_shape_[0];
175 std::size_t local_j = index[1] % tile_shape_[1];
176
177 auto &&tile = tiles_[tile_i * grid_shape_[1] + tile_j];
178
179 return tile[local_i * tile_shape_[1] + local_j];
180 }
181
182 const_scalar_reference operator[](key_type index) const {
183 std::size_t tile_i = index[0] / tile_shape_[0];
184 std::size_t tile_j = index[1] / tile_shape_[1];
185
186 std::size_t local_i = index[0] % tile_shape_[0];
187 std::size_t local_j = index[1] % tile_shape_[1];
188
189 auto &&tile = tiles_[tile_i * grid_shape_[1] + tile_j];
190
191 return tile[local_i * tile_shape_[1] + local_j];
192 }
193
194 iterator begin() {
195 return iterator(tiles_, key_type({0, 0}), key_type({0, 0}), grid_shape_,
196 tile_shape_, shape_);
197 }
198
199 iterator end() { return begin() + shape()[0] * shape()[1]; }
200
201 key_type tile_shape() const noexcept { return tile_shape_; }
202
203 key_type grid_shape() const noexcept { return grid_shape_; }
204
205 auto tile(key_type tile_index) {
206 auto &&[i, j] = tile_index;
207 auto iter = tiles_[i * grid_shape()[1] + j].begin();
208
209 std::size_t tm =
210 std::min(tile_shape()[0], shape()[0] - i * tile_shape()[0]);
211 std::size_t tn =
212 std::min(tile_shape()[1], shape()[1] - j * tile_shape()[1]);
213
214 return dense_matrix_view<
215 T,
216 rng::iterator_t<dr::sp::device_vector<T, dr::sp::device_allocator<T>>>>(
217 iter, key_type{tm, tn}, tile_shape()[1],
218 tiles_[i * grid_shape()[1] + j].rank());
219 }
220
221 std::vector<dense_matrix_view<T, rng::iterator_t<dr::sp::device_vector<
223 tiles() {
224 std::vector<dense_matrix_view<T, rng::iterator_t<dr::sp::device_vector<
226 views_;
227
228 for (std::size_t i = 0; i < grid_shape_[0]; i++) {
229 for (std::size_t j = 0; j < grid_shape_[1]; j++) {
230 auto iter = tiles_[i * grid_shape_[1] + j].begin();
231
232 std::size_t tm =
233 std::min(tile_shape_[0], shape()[0] - i * tile_shape_[0]);
234 std::size_t tn =
235 std::min(tile_shape_[1], shape()[1] - j * tile_shape_[1]);
236
237 views_.emplace_back(iter, key_type{tm, tn}, tile_shape_[1],
238 tiles_[i * grid_shape_[1] + j].rank());
239 }
240 }
241 return views_;
242 }
243
244 template <typename Allocator = std::allocator<T>>
245 auto get_tile(key_type tile_index, const Allocator &alloc = Allocator{}) {
246 std::size_t nrows = get_tile_shape_(tile_index)[0];
247 std::size_t ld = tile_shape_[1];
248 std::size_t tile_size = nrows * ld;
249 dense_matrix<T, Allocator> local_tile(get_tile_shape_(tile_index), ld,
250 alloc);
251 auto remote_tile = tile(tile_index);
252 sp::copy(remote_tile.data(), remote_tile.data() + tile_size,
253 local_tile.data());
254 return local_tile;
255 }
256
257 template <typename Allocator = std::allocator<T>>
258 auto get_tile_async(key_type tile_index,
259 const Allocator &alloc = Allocator{}) {
260 std::size_t nrows = get_tile_shape_(tile_index)[0];
261 std::size_t ld = tile_shape_[1];
262 std::size_t tile_size = nrows * ld;
263 dense_matrix<T, Allocator> local_tile(get_tile_shape_(tile_index), ld,
264 alloc);
265 auto remote_tile = tile(tile_index);
266 auto event = sp::copy_async(
267 remote_tile.data(), remote_tile.data() + tile_size, local_tile.data());
268 // TODO: use/enhance the existing future in parallel_backend_sycl_utils.h
269 return future(std::move(local_tile), {event});
270 }
271
272 auto segments() {
273 std::vector<dense_matrix_view<T, rng::iterator_t<dr::sp::device_vector<
275 views_;
276
277 for (std::size_t i = 0; i < grid_shape_[0]; i++) {
278 for (std::size_t j = 0; j < grid_shape_[1]; j++) {
279 auto iter = tiles_[i * grid_shape_[1] + j].begin();
280
281 std::size_t tm =
282 std::min(tile_shape_[0], shape()[0] - i * tile_shape_[0]);
283 std::size_t tn =
284 std::min(tile_shape_[1], shape()[1] - j * tile_shape_[1]);
285
286 std::size_t m_offset = i * tile_shape_[0];
287 std::size_t n_offset = j * tile_shape_[1];
288
289 views_.emplace_back(iter, key_type{tm, tn},
290 key_type{m_offset, n_offset}, tile_shape_[1],
291 tiles_[i * grid_shape_[1] + j].rank());
292 }
293 }
294 return dr::__detail::owning_view(std::move(views_));
295 }
296
297private:
298 void init_() {
299 grid_shape_ = partition_->grid_shape(shape());
300 tile_shape_ = partition_->tile_shape(shape());
301
302 tiles_.reserve(grid_shape_[0] * grid_shape_[1]);
303
304 for (std::size_t i = 0; i < grid_shape_[0]; i++) {
305 for (std::size_t j = 0; j < grid_shape_[1]; j++) {
306 std::size_t rank = partition_->tile_rank(shape(), {i, j});
307
308 auto device = dr::sp::devices()[rank];
309 dr::sp::device_allocator<T> alloc(dr::sp::context(), device);
310
311 std::size_t tile_size = tile_shape_[0] * tile_shape_[1];
312
313 tiles_.emplace_back(tile_size, alloc, rank);
314 }
315 }
316 }
317
318 key_type get_tile_shape_(key_type tile_index) {
319 auto &&[i, j] = tile_index;
320 std::size_t tm = std::min(tile_shape_[0], shape()[0] - i * tile_shape_[0]);
321 std::size_t tn = std::min(tile_shape_[1], shape()[1] - j * tile_shape_[1]);
322 return key_type{tm, tn};
323 }
324
325private:
326 key_type shape_;
327 key_type grid_shape_;
328 key_type tile_shape_;
329 std::unique_ptr<dr::sp::matrix_partition> partition_;
330
331 std::vector<dr::sp::device_vector<T, dr::sp::device_allocator<T>>> tiles_;
332};
333
334} // namespace dr::sp
Definition: owning_view.hpp:18
Definition: index.hpp:34
Definition: iterator_adaptor.hpp:23
Definition: matrix_entry.hpp:20
Definition: matrix_entry.hpp:115
Definition: matrix_partition.hpp:34
Definition: dense_matrix_view.hpp:21
Definition: dense_matrix.hpp:19
Definition: allocators.hpp:20
Definition: device_vector.hpp:13
Definition: distributed_dense_matrix.hpp:20
Definition: distributed_dense_matrix.hpp:136
Definition: future.hpp:14
Definition: matrix_partition.hpp:23