Distributed Ranges
csr_eq_distribution.hpp
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause
#pragma once
#include <dr/detail/matrix_entry.hpp>
#include <dr/detail/multiply_view.hpp>
#include <dr/mp/containers/matrix_formats/csr_eq_segment.hpp>
#include <dr/views/csr_matrix_view.hpp>

namespace dr::mp {

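// Distributes a CSR matrix so that every rank owns an (almost) equal number
// of nonzeros: the values and column indices are split into equal-size
// segments of a distributed_vector, and each rank additionally mirrors the
// slice of the row-pointer array covering its nonzeros. A row whose nonzeros
// straddle a segment boundary is shared by neighboring ranks.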
template <typename T, typename I, class BackendT = MpiBackend>
class csr_eq_distribution {
  using view_tuple = std::tuple<std::size_t, std::size_t, std::size_t, I *>;

public:
  using value_type = dr::matrix_entry<T, I>;
  using segment_type = csr_eq_segment<csr_eq_distribution>;
  using elem_type = T;
  using index_type = I;
  using difference_type = std::ptrdiff_t;

  csr_eq_distribution(const csr_eq_distribution &) = delete;
  csr_eq_distribution &operator=(const csr_eq_distribution &) = delete;
  csr_eq_distribution(csr_eq_distribution &&) { assert(false); }

  csr_eq_distribution(dr::views::csr_matrix_view<T, I> csr_view,
                      distribution dist = distribution(),
                      std::size_t root = 0) {
    init(csr_view, dist, root);
  }

  ~csr_eq_distribution() {
    if (!finalized()) {
      fence();
      if (rows_data_ != nullptr) {
        rows_backend_.deallocate(rows_data_, row_size_ * sizeof(index_type));
        tuple_alloc.deallocate(view_helper_const, 1);
      }
    }
  }
  std::size_t get_id_in_segment(std::size_t offset) const {
    return offset % segment_size_;
  }
  std::size_t get_segment_from_offset(std::size_t offset) const {
    return offset / segment_size_;
  }
  auto segments() const { return rng::views::all(segments_); }
  auto nnz() const { return nnz_; }
  auto shape() const { return shape_; }
  void fence() const { rows_backend_.fence(); }

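  // Multiplies this rank's nonzero segment by vals_width dense input columns
  // (each of length shape_[1]) and accumulates into res, which holds one
  // column of length row_sizes_[rank] per input column.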
  template <typename C>
  auto local_gemv(C &res, T *vals, std::size_t vals_width) const {
    auto rank = rows_backend_.getrank();
    if (nnz_ <= segment_size_ * rank) {
      return;
    }
    auto vals_len = shape_[1];
    auto size = row_sizes_[rank];
    auto res_col_len = row_sizes_[default_comm().rank()];
    if (dr::mp::use_sycl()) {
      auto localVals = dr::__detail::direct_iterator(
          dr::mp::local_segment(*vals_data_).begin());
      auto localCols = dr::__detail::direct_iterator(
          dr::mp::local_segment(*cols_data_).begin());
      auto offset = rank * segment_size_;
      auto real_segment_size =
          std::min(nnz_ - rank * segment_size_, segment_size_);
      auto local_data = rows_data_;
      auto division = std::max(1ul, real_segment_size / 50);
      auto one_computation_size = (real_segment_size + division - 1) / division;
      auto row_size = row_size_;
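      // Split the local segment into roughly 50-element chunks, one SYCL
      // work-item each. A work-item keeps a running sum per input column and
      // flushes it with a relaxed atomic add whenever it crosses a row
      // boundary (and once at the end), since neighboring chunks may touch
      // the same row.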
      dr::__detail::parallel_for_workaround(
          dr::mp::sycl_queue(), sycl::range<1>{division},
          [=](auto idx) {
            std::size_t lower_bound = one_computation_size * idx;
            std::size_t upper_bound =
                std::min(one_computation_size * (idx + 1), real_segment_size);
            std::size_t position = lower_bound + offset;
            std::size_t first_row = rng::distance(
                local_data,
                std::upper_bound(local_data, local_data + row_size, position) -
                    1);
            for (auto j = 0; j < vals_width; j++) {
              auto row = first_row;
              T sum = 0;

              for (auto i = lower_bound; i < upper_bound; i++) {
                while (row + 1 < row_size &&
                       local_data[row + 1] <= offset + i) {
                  sycl::atomic_ref<T, sycl::memory_order::relaxed,
                                   sycl::memory_scope::device>
                      c_ref(res[row + j * res_col_len]);
                  c_ref += sum;
                  row++;
                  sum = 0;
                }
                auto colNum = localCols[i] + j * vals_len;
                auto matrixVal = localVals[i];
                auto vectorVal = vals[colNum];

                sum += matrixVal * vectorVal;
              }
              sycl::atomic_ref<T, sycl::memory_order::relaxed,
                               sycl::memory_scope::device>
                  c_ref(res[row + j * res_col_len]);
              c_ref += sum;
            }
          })
          .wait();
    } else {
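      // Serial CPU path: walk this rank's nonzeros once, advancing the
      // current row whenever the global nonzero position reaches the next
      // row-pointer entry.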
      auto row_i = -1;
      auto position = segment_size_ * rank;
      auto elem_count = std::min(segment_size_, nnz_ - segment_size_ * rank);
      auto current_row_position = rows_data_[0];
      auto local_vals = dr::mp::local_segment(*vals_data_);
      auto local_cols = dr::mp::local_segment(*cols_data_);

      for (int i = 0; i < elem_count; i++) {
        while (row_i + 1 < size && position + i >= current_row_position) {
          row_i++;
          current_row_position = rows_data_[row_i + 1];
        }
        for (int j = 0; j < vals_width; j++) {
          res[row_i + j * res_col_len] +=
              local_vals[i] * vals[local_cols[i] + j * vals_len];
        }
      }
    }
  }

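  // Full distributed gemv: every rank multiplies its local nonzeros into a
  // zero-initialized partial buffer, then the per-row partial sums are
  // gathered and combined into res on root.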
  template <typename C>
  auto local_gemv_and_collect(std::size_t root, C &res, T *vals,
                              std::size_t vals_width) const {
    assert(res.size() == shape_.first * vals_width);
    __detail::allocator<T> alloc;
    auto res_alloc =
        alloc.allocate(row_sizes_[default_comm().rank()] * vals_width);
    if (use_sycl()) {
      sycl_queue()
          .fill(res_alloc, 0, row_sizes_[default_comm().rank()] * vals_width)
          .wait();
    } else {
      std::fill(res_alloc,
                res_alloc + row_sizes_[default_comm().rank()] * vals_width, 0);
    }

    local_gemv(res_alloc, vals, vals_width);
    gather_gemv_vector(root, res, res_alloc, vals_width);
    fence();
    alloc.deallocate(res_alloc, row_sizes_[default_comm().rank()] * vals_width);
  }

private:
  friend csr_eq_segment_iterator<csr_eq_distribution>;

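  // Collectively gathers the per-rank partial results into res on root;
  // non-root ranks only contribute their buffer to gatherv.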
  template <typename C, typename A>
  void gather_gemv_vector(std::size_t root, C &res, A &partial_res,
                          std::size_t vals_width) const {
    auto communicator = default_comm();
    __detail::allocator<T> alloc;
    long long *counts = new long long[communicator.size()];
    for (auto i = 0; i < communicator.size(); i++) {
      counts[i] = row_sizes_[i] * sizeof(T) * vals_width;
    }

    if (communicator.rank() == root) {
      long *offsets = new long[communicator.size()];
      offsets[0] = 0;
      for (auto i = 0; i < communicator.size() - 1; i++) {
        offsets[i + 1] = offsets[i] + counts[i];
      }
      auto gathered_res = alloc.allocate(max_row_size_ * vals_width);
      communicator.gatherv(partial_res, counts, offsets, gathered_res, root);
      T *gathered_res_host;

      if (use_sycl()) {
        gathered_res_host = new T[max_row_size_ * vals_width];
        __detail::sycl_copy(gathered_res, gathered_res_host,
                            max_row_size_ * vals_width);
      } else {
        gathered_res_host = gathered_res;
      }
      rng::fill(res, 0);

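      // Stitch the per-rank pieces together. The first row of a piece may be
      // the tail of the previous rank's last row, so it is accumulated with
      // +=; the remaining rows are disjoint and are simply copied.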
      for (auto k = 0; k < vals_width; k++) {
        auto current_offset = 0;
        for (auto i = 0; i < communicator.size(); i++) {
          auto first_row = row_offsets_[i];
          auto last_row = row_offsets_[i] + row_sizes_[i];
          auto row_size = row_sizes_[i];
          if (first_row < last_row) {
            res[first_row + k * shape_[0]] +=
                gathered_res_host[vals_width * current_offset + k * row_size];
          }
          if (first_row < last_row - 1) {
            auto piece_start = gathered_res_host +
                               vals_width * current_offset + k * row_size + 1;
            std::copy(piece_start, piece_start + last_row - first_row - 1,
                      res.begin() + first_row + k * shape_[0] + 1);
          }
          current_offset += row_sizes_[i];
        }
      }

      if (use_sycl()) {
        delete[] gathered_res_host;
      }
      delete[] offsets;
      alloc.deallocate(gathered_res, max_row_size_ * vals_width);
    } else {
      communicator.gatherv(partial_res, counts, nullptr, nullptr, root);
    }
    delete[] counts;
  }

  std::size_t get_row_size(std::size_t rank) { return row_sizes_[rank]; }

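  // Builds the distribution: root broadcasts the matrix metadata, the values
  // and column indices are copied into equally segmented distributed vectors,
  // per-rank row ranges are computed and broadcast, and each rank receives
  // the row-pointer slice covering its nonzero segment.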
  void init(dr::views::csr_matrix_view<T, I> csr_view, auto dist,
            std::size_t root) {
    distribution_ = dist;
    auto rank = rows_backend_.getrank();

    std::size_t initial_data[3];
    if (root == rank) {
      initial_data[0] = csr_view.size();
      initial_data[1] = csr_view.shape().first;
      initial_data[2] = csr_view.shape().second;
      default_comm().bcast(initial_data, sizeof(std::size_t) * 3, root);
    } else {
      default_comm().bcast(initial_data, sizeof(std::size_t) * 3, root);
    }

    nnz_ = initial_data[0];
    shape_ = {initial_data[1], initial_data[2]};
    vals_data_ = std::make_shared<distributed_vector<T>>(nnz_);
    cols_data_ = std::make_shared<distributed_vector<I>>(nnz_);
    dr::mp::copy(root,
                 std::ranges::subrange(csr_view.values_data(),
                                       csr_view.values_data() + nnz_),
                 vals_data_->begin());
    dr::mp::copy(root,
                 std::ranges::subrange(csr_view.colind_data(),
                                       csr_view.colind_data() + nnz_),
                 cols_data_->begin());

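    // Locate, for every rank, the row range spanned by its nonzero segment by
    // binary searching the row-pointer array, and broadcast all offsets and
    // sizes in a single message. max_row_size_ accumulates the total number
    // of gathered rows (a row shared by two ranks is counted twice).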
    auto row_info_size = default_comm().size() * 2 + 1;
    std::size_t *row_information = new std::size_t[row_info_size];
    row_offsets_.reserve(default_comm().size());
    row_sizes_.reserve(default_comm().size());
    if (root == default_comm().rank()) {
      for (int i = 0; i < default_comm().size(); i++) {
        auto first_index = vals_data_->get_segment_offset(i);
        auto last_index = vals_data_->get_segment_offset(i + 1) - 1;
        auto lower_limit =
            rng::distance(csr_view.rowptr_data(),
                          std::upper_bound(csr_view.rowptr_data(),
                                           csr_view.rowptr_data() + shape_[0],
                                           first_index)) -
            1;
        auto higher_limit = rng::distance(
            csr_view.rowptr_data(),
            std::upper_bound(csr_view.rowptr_data(),
                             csr_view.rowptr_data() + shape_[0], last_index));
        row_offsets_.push_back(lower_limit);
        row_sizes_.push_back(higher_limit - lower_limit);
        row_information[i] = lower_limit;
        row_information[default_comm().size() + i] = higher_limit - lower_limit;
        max_row_size_ = max_row_size_ + row_sizes_.back();
      }
      row_information[default_comm().size() * 2] = max_row_size_;
      default_comm().bcast(row_information,
                           sizeof(std::size_t) * row_info_size, root);
    } else {
      default_comm().bcast(row_information,
                           sizeof(std::size_t) * row_info_size, root);
      for (int i = 0; i < default_comm().size(); i++) {
        row_offsets_.push_back(row_information[i]);
        row_sizes_.push_back(row_information[default_comm().size() + i]);
      }
      max_row_size_ = row_information[default_comm().size() * 2];
    }
    delete[] row_information;
    row_size_ = std::max(row_sizes_[rank], static_cast<std::size_t>(1));
    rows_data_ =
        static_cast<I *>(rows_backend_.allocate(row_size_ * sizeof(I)));

    fence();
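    // Root writes each rank's slice of the row-pointer array straight into
    // that rank's window using one-sided putmem.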
    if (rank == root) {
      for (std::size_t i = 0; i < default_comm().size(); i++) {
        auto lower_limit = row_offsets_[i];
        auto row_size = row_sizes_[i];
        if (row_size > 0) {
          rows_backend_.putmem(csr_view.rowptr_data() + lower_limit, 0,
                               row_size * sizeof(I), i);
        }
      }
    }

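    // Create the equal-size segment descriptors and publish the helper tuple
    // (local row count, first row, segment offset, local row pointers) that
    // the device-side element view uses to map a nonzero back to its row.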
    std::size_t segment_index = 0;
    segment_size_ = vals_data_->segment_size();
    assert(segment_size_ == cols_data_->segment_size());
    for (std::size_t i = 0; i < nnz_; i += segment_size_) {
      segments_.emplace_back(this, segment_index++,
                             std::min(segment_size_, nnz_ - i), segment_size_);
    }
    auto local_rows = rows_data_;
    auto real_val_size = std::min(vals_data_->segment_size(),
                                  nnz_ - vals_data_->segment_size() * rank);
    auto my_tuple = std::make_tuple(row_size_, row_offsets_[rank],
                                    segment_size_ * rank, local_rows);
    view_helper_const = tuple_alloc.allocate(1);

    if (use_sycl()) {
      sycl_queue()
          .memcpy(view_helper_const, &my_tuple, sizeof(view_tuple))
          .wait();
    } else {
      view_helper_const[0] = my_tuple;
    }

    auto local_cols = static_cast<I *>(nullptr);
    auto local_vals = static_cast<T *>(nullptr);
    if (cols_data_->segments().size() > rank) {
      local_cols = cols_data_->segments()[rank].begin().local();
      local_vals = vals_data_->segments()[rank].begin().local();
      local_view = std::make_shared<view_type>(get_elem_view(
          real_val_size, view_helper_const, local_cols, local_vals, rank));
    }
    fence();
  }

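  // Builds the device-friendly local view: each local nonzero is turned into
  // a matrix_entry whose global row is recovered by binary searching the
  // local row-pointer slice from the helper tuple.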
  static auto get_elem_view(std::size_t vals_size, view_tuple *helper_tuple,
                            index_type *local_cols, elem_type *local_vals,
                            std::size_t rank) {
    auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size);
    auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size);
    auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
    auto enumerated_zipped = rng::views::enumerate(zipped_results);
    // We need multiply_view here because a lambda with by-value captures is
    // not copied correctly into the SYCL kernel environment; instead, the
    // helper tuple is paired with every element of the zipped range.
    auto multiply_range = dr::__detail::multiply_view(
        rng::subrange(helper_tuple, helper_tuple + 1), vals_size);
    auto enumerated_with_data =
        rng::views::zip(enumerated_zipped, multiply_range);

    auto transformer = [=](auto x) {
      auto [entry, tuple] = x;
      auto [row_size, row_offset, offset, local_rows] = tuple;
      auto [index, pair] = entry;
      auto [val, column] = pair;
      auto row =
          rng::distance(local_rows,
                        std::upper_bound(local_rows, local_rows + row_size,
                                         offset + index) -
                            1) +
          row_offset;
      dr::index<index_type> index_obj(row, column);
      value_type entry_obj(index_obj, val);
      return entry_obj;
    };
    return rng::transform_view(enumerated_with_data, std::move(transformer));
  }

  using view_type = decltype(get_elem_view(0, nullptr, nullptr, nullptr, 0));

  dr::__detail::allocator<view_tuple> tuple_alloc;
  view_tuple *view_helper_const;
  std::shared_ptr<view_type> local_view = nullptr;

  std::size_t segment_size_ = 0;
  std::size_t row_size_ = 0;
  std::size_t max_row_size_ = 0;
  std::vector<std::size_t> row_offsets_;
  std::vector<std::size_t> row_sizes_;

  index_type *rows_data_ = nullptr;
  BackendT rows_backend_;

  distribution distribution_;
  dr::index<size_t> shape_;
  std::size_t nnz_;
  std::vector<segment_type> segments_;
  std::shared_ptr<distributed_vector<T>> vals_data_;
  std::shared_ptr<distributed_vector<I>> cols_data_;
};
} // namespace dr::mp