Distributed Ranges
csr_row_distribution.hpp
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause
#pragma once
#include <dr/detail/matrix_entry.hpp>
#include <dr/detail/multiply_view.hpp>
#include <dr/mp/containers/matrix_formats/csr_row_segment.hpp>
#include <dr/views/csr_matrix_view.hpp>
#include <fmt/core.h>

namespace dr::mp {
template <typename T, typename I, class BackendT = MpiBackend>
class csr_row_distribution {
  using view_tuple = std::tuple<std::size_t, std::size_t, std::size_t, I *>;

public:
  using value_type = dr::matrix_entry<T, I>;
  using segment_type = csr_row_segment<csr_row_distribution>;
  using elem_type = T;
  using index_type = I;
  using difference_type = std::ptrdiff_t;

  csr_row_distribution(const csr_row_distribution &) = delete;
  csr_row_distribution &operator=(const csr_row_distribution &) = delete;
  csr_row_distribution(csr_row_distribution &&) { assert(false); }

  csr_row_distribution(dr::views::csr_matrix_view<T, I> csr_view,
                       distribution dist = distribution(),
                       std::size_t root = 0) {
    init(csr_view, dist, root);
  }

  ~csr_row_distribution() {
    if (!finalized()) {
      fence();
      if (vals_data_ != nullptr) {
        vals_backend_.deallocate(vals_data_, vals_size_ * sizeof(elem_type));
        cols_backend_.deallocate(cols_data_, vals_size_ * sizeof(index_type));
        alloc.deallocate(view_helper_const, 1);
      }
    }
  }
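  // val_offsets_[r] is the global index of the first nonzero stored on rank r,
  // so these helpers translate a global nonzero index into a segment id and an
  // offset inside that segment.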
  std::size_t get_id_in_segment(std::size_t offset) const {
    assert(offset < nnz_);
    auto pos_iter =
        std::upper_bound(val_offsets_.begin(), val_offsets_.end(), offset) - 1;
    return offset - *pos_iter;
  }
  std::size_t get_segment_from_offset(std::size_t offset) const {
    assert(offset < nnz_);
    auto pos_iter =
        std::upper_bound(val_offsets_.begin(), val_offsets_.end(), offset);
    return rng::distance(val_offsets_.begin(), pos_iter) - 1;
  }
  auto segments() const { return rng::views::all(segments_); }
  auto nnz() const { return nnz_; }
  auto shape() const { return shape_; }
  void fence() const {
    vals_backend_.fence();
    cols_backend_.fence();
  }
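  // Multiplies this rank's block of rows by the dense operand `vals`
  // (vals_width columns, stored column by column with a leading dimension of
  // shape_[1]) and accumulates into `res`. On SYCL devices one work-group is
  // launched per (local row, output column) pair; each work-item sums a
  // strided slice of the row's nonzeros and adds its partial sum atomically.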
  template <typename C>
  auto local_gemv(C &res, T *vals, std::size_t vals_width) const {
    auto rank = cols_backend_.getrank();
    if (shape_[0] <= segment_size_ * rank)
      return;
    auto size = std::min(segment_size_, shape_[0] - segment_size_ * rank);
    auto vals_len = shape_[1];
    if (dr::mp::use_sycl()) {
      auto local_vals = vals_data_;
      auto local_cols = cols_data_;
      auto offset = val_offsets_[rank];
      auto real_segment_size = std::min(nnz_ - offset, val_sizes_[rank]);
      auto rows_data = dr::__detail::direct_iterator(
          dr::mp::local_segment(*rows_data_).begin());
      auto res_col_len = segment_size_;
      std::size_t wg = 32;
      while (vals_width * size * wg > INT_MAX) {
        // this check is necessary, because sycl does not permit ranges
        // exceeding integer limit
        wg /= 2;
      }
      assert(wg > 0);
      dr::mp::sycl_queue()
          .submit([&](auto &&h) {
            h.parallel_for(
                sycl::nd_range<1>(vals_width * size * wg, wg), [=](auto item) {
                  auto input_j = item.get_group(0) / size;
                  auto idx = item.get_group(0) % size;
                  auto local_id = item.get_local_id();
                  auto group_size = item.get_local_range(0);
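                  // The nonzeros of local row idx occupy
                  // [rowptr[idx], rowptr[idx + 1]) in global numbering;
                  // subtracting `offset` clamps that range to this rank's
                  // slice of the value and column arrays.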
                  std::size_t lower_bound = 0;
                  if (rows_data[idx] > offset) {
                    lower_bound = rows_data[idx] - offset;
                  }
                  std::size_t upper_bound = real_segment_size;
                  if (idx < size - 1) {
                    upper_bound = rows_data[idx + 1] - offset;
                  }
                  T sum = 0;
                  for (auto i = lower_bound + local_id; i < upper_bound;
                       i += group_size) {
                    auto colNum = local_cols[i];
                    auto matrixVal = vals[colNum + input_j * vals_len];
                    auto vectorVal = local_vals[i];
                    sum += matrixVal * vectorVal;
                  }

                  sycl::atomic_ref<T, sycl::memory_order::relaxed,
                                   sycl::memory_scope::device>
                      c_ref(res[idx + input_j * res_col_len]);
                  c_ref += sum;
                });
          })
          .wait();
    } else {
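      // Serial fallback: walk the local nonzeros once, advancing the current
      // row whenever the running value index passes the next rowptr entry.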
      auto local_rows = dr::mp::local_segment(*rows_data_);
      auto val_count = val_sizes_[rank];
      auto row_i = 0;
      auto position = val_offsets_[rank];
      auto current_row_position = local_rows[1];

      for (int i = 0; i < val_count; i++) {
        while (row_i + 1 < size && position + i >= current_row_position) {
          row_i++;
          current_row_position = local_rows[row_i + 1];
        }
        for (auto j = 0; j < vals_width; j++) {
          res[row_i + j * segment_size_] +=
              vals_data_[i] * vals[cols_data_[i] + j * vals_len];
        }
      }
    }
  }

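  // Runs local_gemv into a zero-initialized scratch buffer of one full
  // segment per output column, then gathers every rank's partial result on
  // `root` and rearranges it into `res`.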
  template <typename C>
  auto local_gemv_and_collect(std::size_t root, C &res, T *&vals,
                              std::size_t vals_width) const {
    assert(res.size() == shape_.first * vals_width);
    __detail::allocator<T> alloc;
    auto res_alloc = alloc.allocate(segment_size_ * vals_width);
    if (use_sycl()) {
      sycl_queue().fill(res_alloc, 0, segment_size_ * vals_width).wait();
    } else {
      std::fill(res_alloc, res_alloc + segment_size_ * vals_width, 0);
    }

    local_gemv(res_alloc, vals, vals_width);

    gather_gemv_vector(root, res, res_alloc, vals_width);
    fence();
    alloc.deallocate(res_alloc, segment_size_ * vals_width);
  }

private:

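  // Gathers the fixed-size partial results (segment_size_ * vals_width
  // elements per rank) on `root` and copies each rank's rows into the right
  // place of `res`, one output column at a time. Non-root ranks only
  // contribute their buffer to the gather.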
  template <typename C, typename A>
  void gather_gemv_vector(std::size_t root, C &res, A &partial_res,
                          std::size_t vals_width) const {
    auto communicator = default_comm();
    __detail::allocator<T> alloc;

    if (communicator.rank() == root) {
      auto scratch =
          alloc.allocate(segment_size_ * communicator.size() * vals_width);
      communicator.gather(partial_res, scratch, segment_size_ * vals_width,
                          root);
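      // When SYCL is enabled the gathered scratch buffer lives in
      // device-accessible memory, so stage the copies through a host buffer
      // before writing into `res`.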
      T *temp = nullptr;
      if (use_sycl()) {
        temp = new T[res.size()];
      }
      for (auto j = 0; j < communicator.size(); j++) {
        if (j * segment_size_ >= shape_.first) {
          break;
        }
        auto comm_segment_size =
            std::min(segment_size_, shape_.first - j * segment_size_);

        for (auto i = 0; i < vals_width; i++) {
          auto piece_start =
              scratch + j * vals_width * segment_size_ + i * segment_size_;

          if (use_sycl()) {
            __detail::sycl_copy(piece_start,
                                temp + shape_.first * i + j * segment_size_,
                                comm_segment_size);
          } else {
            std::copy(piece_start, piece_start + comm_segment_size,
                      res.begin() + shape_.first * i + j * segment_size_);
          }
        }
      }
      if (use_sycl()) {
        std::copy(temp, temp + res.size(), res.begin());
        delete[] temp;
      }
      alloc.deallocate(scratch,
                       segment_size_ * communicator.size() * vals_width);
    } else {
      communicator.gather(partial_res, static_cast<T *>(nullptr),
                          segment_size_ * vals_width, root);
    }
  }
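  // Broadcasts the matrix shape and nnz from root, distributes the rowptr
  // array as a distributed_vector, computes each rank's value offset and
  // count from the rowptr boundaries, and finally pushes each rank's slice of
  // the value and column-index arrays into its window with one-sided puts.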
  void init(dr::views::csr_matrix_view<T, I> csr_view, auto dist,
            std::size_t root) {
    distribution_ = dist;
    auto rank = vals_backend_.getrank();

    std::size_t initial_data[3];
    if (root == rank) {
      initial_data[0] = csr_view.size();
      initial_data[1] = csr_view.shape().first;
      initial_data[2] = csr_view.shape().second;
      default_comm().bcast(initial_data, sizeof(std::size_t) * 3, root);
    } else {
      default_comm().bcast(initial_data, sizeof(std::size_t) * 3, root);
    }

    nnz_ = initial_data[0];
    shape_ = {initial_data[1], initial_data[2]};

    rows_data_ = std::make_shared<distributed_vector<I>>(shape_.first);

    dr::mp::copy(root,
                 std::ranges::subrange(csr_view.rowptr_data(),
                                       csr_view.rowptr_data() + shape_.first),
                 rows_data_->begin());
    auto row_info_size = default_comm().size() * 2;
    std::size_t *val_information = new std::size_t[row_info_size];
    val_offsets_.reserve(row_info_size);
    val_sizes_.reserve(row_info_size);
    if (rank == root) {
      for (int i = 0; i < default_comm().size(); i++) {
        auto first_index = rows_data_->get_segment_offset(i);
        if (first_index > shape_.first) {
          val_offsets_.push_back(nnz_);
          val_sizes_.push_back(0);
          continue;
        }
        std::size_t lower_limit = csr_view.rowptr_data()[first_index];
        std::size_t higher_limit = nnz_;
        if (rows_data_->get_segment_offset(i + 1) < shape_.first) {
          auto last_index = rows_data_->get_segment_offset(i + 1);
          higher_limit = csr_view.rowptr_data()[last_index];
        }
        val_offsets_.push_back(lower_limit);
        val_sizes_.push_back(higher_limit - lower_limit);
        val_information[i] = lower_limit;
        val_information[i + default_comm().size()] = higher_limit - lower_limit;
      }
      default_comm().bcast(val_information, sizeof(std::size_t) * row_info_size,
                           root);
    } else {
      default_comm().bcast(val_information, sizeof(std::size_t) * row_info_size,
                           root);
      for (int i = 0; i < default_comm().size(); i++) {
        val_offsets_.push_back(val_information[i]);
        val_sizes_.push_back(val_information[default_comm().size() + i]);
      }
    }
    delete[] val_information;
    vals_size_ = std::max(val_sizes_[rank], static_cast<std::size_t>(1));

    cols_data_ =
        static_cast<I *>(cols_backend_.allocate(vals_size_ * sizeof(I)));
    vals_data_ =
        static_cast<T *>(vals_backend_.allocate(vals_size_ * sizeof(T)));

    fence();
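    // Root writes each rank's slice of the value and column-index arrays
    // directly into that rank's backend window.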
    if (rank == root) {
      for (std::size_t i = 0; i < default_comm().size(); i++) {
        auto lower_limit = val_offsets_[i];
        auto row_size = val_sizes_[i];
        if (row_size > 0) {
          vals_backend_.putmem(csr_view.values_data() + lower_limit, 0,
                               row_size * sizeof(T), i);
          cols_backend_.putmem(csr_view.colind_data() + lower_limit, 0,
                               row_size * sizeof(I), i);
        }
      }
    }

    std::size_t segment_index = 0;
    segment_size_ = rows_data_->segment_size();
    for (std::size_t i = 0; i < default_comm().size(); i++) {
      segments_.emplace_back(
          this, segment_index++, val_sizes_[i],
          std::max(val_sizes_[i], static_cast<std::size_t>(1)));
    }

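    // Pack everything the device-side view needs to recover row indices
    // (local row count, global row offset, value offset, local rowptr slice)
    // into a single tuple in allocator-managed memory.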
    auto local_rows = static_cast<I *>(nullptr);
    if (rows_data_->segments().size() > rank) {
      local_rows = rows_data_->segments()[rank].begin().local();
    }
    auto offset = val_offsets_[rank];
    auto real_row_size =
        std::min(rows_data_->segment_size(),
                 shape_.first - rows_data_->segment_size() * rank);
    auto my_tuple = std::make_tuple(real_row_size, segment_size_ * rank, offset,
                                    local_rows);
    view_helper_const = alloc.allocate(1);

    if (use_sycl()) {
      sycl_queue()
          .memcpy(view_helper_const, &my_tuple, sizeof(view_tuple))
          .wait();
    } else {
      view_helper_const[0] = my_tuple;
    }

    if (rows_data_->segments().size() > rank) {
      local_view = std::make_shared<view_type>(get_elem_view(
          vals_size_, view_helper_const, cols_data_, vals_data_, rank));
    }
    fence();
  }

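  // Builds a random-access range of matrix_entry objects for the local data:
  // values and column indices are zipped and enumerated, and the row index of
  // each element is recovered by an upper_bound over the local rowptr slice
  // carried in `helper_tuple`.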
  static auto get_elem_view(std::size_t vals_size, view_tuple *helper_tuple,
                            index_type *local_cols, elem_type *local_vals,
                            std::size_t rank) {
    auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size);
    auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size);
    auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
    auto enumerated_zipped = rng::views::enumerate(zipped_results);
    // we need to use multiply_view here, because a lambda is not properly
    // copied to the SYCL environment when it captures local variables
    auto multiply_range = dr::__detail::multiply_view(
        rng::subrange(helper_tuple, helper_tuple + 1), vals_size);
    auto enumerated_with_data =
        rng::views::zip(enumerated_zipped, multiply_range);

    auto transformer = [=](auto x) {
      auto [entry, tuple] = x;
      auto [row_size, row_offset, offset, local_rows] = tuple;
      auto [index, pair] = entry;
      auto [val, column] = pair;
      auto row =
          rng::distance(local_rows,
                        std::upper_bound(local_rows, local_rows + row_size,
                                         offset + index) -
                            1) +
          row_offset;
      dr::index<index_type> index_obj(row, column);
      value_type entry_obj(index_obj, val);
      return entry_obj;
    };
    return rng::transform_view(enumerated_with_data, std::move(transformer));
  }

  using view_type = decltype(get_elem_view(0, nullptr, nullptr, nullptr, 0));

  __detail::allocator<view_tuple> alloc;
  view_tuple *view_helper_const;
  std::shared_ptr<view_type> local_view;

  std::size_t segment_size_ = 0;
  std::size_t vals_size_ = 0;
  std::vector<std::size_t> val_offsets_;
  std::vector<std::size_t> val_sizes_;

  index_type *cols_data_ = nullptr;
  BackendT cols_backend_;

  elem_type *vals_data_ = nullptr;
  BackendT vals_backend_;

  distribution distribution_;
  dr::index<size_t> shape_;
  std::size_t nnz_;
  std::vector<segment_type> segments_;
  std::shared_ptr<distributed_vector<I>> rows_data_ = nullptr;
};
} // namespace dr::mp