#include <dr/detail/matrix_entry.hpp>
#include <dr/detail/multiply_view.hpp>
#include <dr/mp/containers/matrix_formats/csr_eq_segment.hpp>
#include <dr/views/csr_matrix_view.hpp>
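// csr_eq_distribution splits the nonzeros of a CSR matrix into equally sized
// segments, one per rank, regardless of row boundaries. Values and column
// indices live in distributed vectors; each rank additionally caches the
// rowptr entries of the rows overlapping its segment in a backend window
// (rows_data_), so a row that spans several ranks is handled cooperatively.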
template <typename T, typename I, class BackendT = MpiBackend>
class csr_eq_distribution {
  using view_tuple = std::tuple<std::size_t, std::size_t, std::size_t, I *>;

public:
  using value_type = dr::matrix_entry<T, I>;
  using segment_type = csr_eq_segment<csr_eq_distribution>;
  using elem_type = T;
  using index_type = I;
  using difference_type = std::ptrdiff_t;

  csr_eq_distribution(dr::views::csr_matrix_view<T, I> csr_view,
                      distribution dist = distribution(),
                      std::size_t root = 0) {
    init(csr_view, dist, root);
  }
  ~csr_eq_distribution() {
    if (rows_data_ != nullptr) {
      rows_backend_.deallocate(rows_data_, row_size_ * sizeof(index_type));
      tuple_alloc.deallocate(view_helper_const, 1);
    }
  }
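  // A global nonzero offset maps to (segment, position-in-segment) by plain
  // division: with segment_size_ == 64, offset 150 lives in segment
  // 150 / 64 == 2 at position 150 % 64 == 22.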
  std::size_t get_id_in_segment(std::size_t offset) const {
    return offset % segment_size_;
  }
  std::size_t get_segment_from_offset(std::size_t offset) const {
    return offset / segment_size_;
  }
  auto segments() const { return rng::views::all(segments_); }
  auto nnz() const { return nnz_; }
  auto shape() const { return shape_; }
  void fence() const { rows_backend_.fence(); }
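  // Multiplies the local slice of the sparse matrix by a dense input `vals`
  // with vals_width columns (stored column after column) and accumulates into
  // `res`, which holds this rank's rows in the same layout. On SYCL devices
  // the slice is further split into chunks of roughly 50 nonzeros processed
  // in parallel; partial row sums are flushed with atomics because a row may
  // straddle chunk (and rank) boundaries.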
  template <typename C>
  auto local_gemv(C &res, T *vals, std::size_t vals_width) const {
    auto rank = rows_backend_.getrank();
    if (nnz_ <= segment_size_ * rank) {
      return; // this rank holds no nonzeros
    }
    auto vals_len = shape_[1];
    auto size = row_sizes_[rank];
    auto res_col_len = row_sizes_[default_comm().rank()];
    if (dr::mp::use_sycl()) {
      auto localVals = dr::__detail::direct_iterator(
          dr::mp::local_segment(*vals_data_).begin());
      auto localCols = dr::__detail::direct_iterator(
          dr::mp::local_segment(*cols_data_).begin());
      auto offset = rank * segment_size_;
      auto real_segment_size =
          std::min(nnz_ - rank * segment_size_, segment_size_);
      auto local_data = rows_data_;
      auto division = std::max(1ul, real_segment_size / 50);
      auto one_computation_size = (real_segment_size + division - 1) / division;
      auto row_size = row_size_;
      dr::__detail::parallel_for_workaround(
          dr::mp::sycl_queue(), sycl::range<1>{division},
          [=](auto idx) {
            std::size_t lower_bound = one_computation_size * idx;
            std::size_t upper_bound =
                std::min(one_computation_size * (idx + 1), real_segment_size);
            std::size_t position = lower_bound + offset;
            // first row touched by this chunk
            std::size_t first_row = rng::distance(
                local_data,
                std::upper_bound(local_data, local_data + row_size, position) -
                    1);
            for (auto j = 0; j < vals_width; j++) {
              auto row = first_row;
              T sum = 0;
              for (auto i = lower_bound; i < upper_bound; i++) {
                while (row + 1 < row_size &&
                       local_data[row + 1] <= offset + i) {
                  // flush the finished row before advancing to the next one
                  sycl::atomic_ref<T, sycl::memory_order::relaxed,
                                   sycl::memory_scope::device>
                      c_ref(res[row + j * res_col_len]);
                  c_ref += sum;
                  sum = 0;
                  row++;
                }
                auto colNum = localCols[i] + j * vals_len;
                auto matrixVal = vals[colNum];
                auto vectorVal = localVals[i];
                sum += matrixVal * vectorVal;
              }
              sycl::atomic_ref<T, sycl::memory_order::relaxed,
                               sycl::memory_scope::device>
                  c_ref(res[row + j * res_col_len]);
              c_ref += sum;
            }
          })
          .wait();
    } else {
      auto row_i = -1;
      auto position = segment_size_ * rank;
      auto elem_count = std::min(segment_size_, nnz_ - segment_size_ * rank);
      auto current_row_position = rows_data_[0];
      auto local_vals = dr::mp::local_segment(*vals_data_);
      auto local_cols = dr::mp::local_segment(*cols_data_);

      for (int i = 0; i < elem_count; i++) {
        while (row_i + 1 < size && position + i >= current_row_position) {
          row_i++;
          current_row_position = rows_data_[row_i + 1];
        }
        for (int j = 0; j < vals_width; j++) {
          res[row_i + j * res_col_len] +=
              local_vals[i] * vals[local_cols[i] + j * vals_len];
        }
      }
    }
  }
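  // Runs the local SpMV into a zero-initialized scratch buffer, then gathers
  // the per-rank partial results on `root` and combines them into `res`.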
  template <typename C>
  auto local_gemv_and_collect(std::size_t root, C &res, T *vals,
                              std::size_t vals_width) const {
    assert(res.size() == shape_.first * vals_width);
    __detail::allocator<T> alloc;
    auto res_alloc =
        alloc.allocate(row_sizes_[default_comm().rank()] * vals_width);
    if (use_sycl()) {
      sycl_queue()
          .fill(res_alloc, 0, row_sizes_[default_comm().rank()] * vals_width)
          .wait();
    } else {
      std::fill(res_alloc,
                res_alloc + row_sizes_[default_comm().rank()] * vals_width, 0);
    }

    local_gemv(res_alloc, vals, vals_width);
    gather_gemv_vector(root, res, res_alloc, vals_width);
    alloc.deallocate(res_alloc, row_sizes_[default_comm().rank()] * vals_width);
  }
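  // Gathers per-rank partial results on `root`. Because a matrix row can span
  // two ranks, the first row of every rank's block is added into `res`, while
  // the remaining rows are simply copied.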
  template <typename C, typename A>
  void gather_gemv_vector(std::size_t root, C &res, A &partial_res,
                          std::size_t vals_width) const {
    auto communicator = default_comm();
    __detail::allocator<T> alloc;
    auto counts = new std::size_t[communicator.size()];
    for (auto i = 0; i < communicator.size(); i++) {
      counts[i] = row_sizes_[i] * sizeof(T) * vals_width;
    }

    if (communicator.rank() == root) {
      auto offsets = new std::size_t[communicator.size()];
      offsets[0] = 0;
      for (auto i = 0; i < communicator.size() - 1; i++) {
        offsets[i + 1] = offsets[i] + counts[i];
      }
      auto gathered_res = alloc.allocate(max_row_size_ * vals_width);
      communicator.gatherv(partial_res, counts, offsets, gathered_res, root);
      T *gathered_res_host;
      if (use_sycl()) {
        gathered_res_host = new T[max_row_size_ * vals_width];
        __detail::sycl_copy(gathered_res, gathered_res_host,
                            max_row_size_ * vals_width);
      } else {
        gathered_res_host = gathered_res;
      }
      for (auto k = 0; k < vals_width; k++) {
        auto current_offset = 0;
        for (auto i = 0; i < communicator.size(); i++) {
          auto first_row = row_offsets_[i];
          auto last_row = row_offsets_[i] + row_sizes_[i];
          auto row_size = row_sizes_[i];
          if (first_row < last_row) {
            // the first row may continue a row started on the previous rank
            res[first_row + k * shape_[0]] +=
                gathered_res_host[vals_width * current_offset + k * row_size];
          }
          if (first_row < last_row - 1) {
            auto piece_start = gathered_res_host + vals_width * current_offset +
                               k * row_size + 1;
            std::copy(piece_start, piece_start + last_row - first_row - 1,
                      res.begin() + first_row + k * shape_[0] + 1);
          }
          current_offset += row_sizes_[i];
        }
      }
      if (use_sycl()) {
        delete[] gathered_res_host;
      }
      delete[] offsets;
      alloc.deallocate(gathered_res, max_row_size_ * vals_width);
    } else {
      communicator.gatherv(partial_res, counts, nullptr, nullptr, root);
    }
    delete[] counts;
  }
  std::size_t get_row_size(std::size_t rank) { return row_sizes_[rank]; }
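  // Distributes csr_view (valid on `root` only): broadcasts the matrix
  // metadata, copies values and column indices into distributed vectors,
  // computes which rows overlap each rank's segment, and pushes the relevant
  // rowptr entries into each rank's backend window.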
  void init(dr::views::csr_matrix_view<T, I> csr_view, auto dist,
            std::size_t root) {
    distribution_ = dist;
    auto rank = rows_backend_.getrank();

    std::size_t initial_data[3];
    if (root == rank) {
      initial_data[0] = csr_view.size();
      initial_data[1] = csr_view.shape().first;
      initial_data[2] = csr_view.shape().second;
      default_comm().bcast(initial_data, sizeof(std::size_t) * 3, root);
    } else {
      default_comm().bcast(initial_data, sizeof(std::size_t) * 3, root);
    }

    nnz_ = initial_data[0];
    shape_ = {initial_data[1], initial_data[2]};
    vals_data_ = std::make_shared<distributed_vector<T>>(nnz_);
    cols_data_ = std::make_shared<distributed_vector<I>>(nnz_);
    dr::mp::copy(root,
                 std::ranges::subrange(csr_view.values_data(),
                                       csr_view.values_data() + nnz_),
                 vals_data_->begin());
    dr::mp::copy(root,
                 std::ranges::subrange(csr_view.colind_data(),
                                       csr_view.colind_data() + nnz_),
                 cols_data_->begin());
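    // Determine, for every rank, the range of rows overlapping its nonzero
    // segment: a binary search over rowptr finds the rows containing the
    // segment's first and last nonzero.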
    auto row_info_size = default_comm().size() * 2 + 1;
    std::size_t *row_information = new std::size_t[row_info_size];
    row_offsets_.reserve(default_comm().size());
    row_sizes_.reserve(default_comm().size());
    if (root == default_comm().rank()) {
      for (int i = 0; i < default_comm().size(); i++) {
        auto first_index = vals_data_->get_segment_offset(i);
        auto last_index = vals_data_->get_segment_offset(i + 1) - 1;
        auto lower_limit =
            rng::distance(csr_view.rowptr_data(),
                          std::upper_bound(csr_view.rowptr_data(),
                                           csr_view.rowptr_data() + shape_[0],
                                           first_index)) -
            1;
        auto higher_limit = rng::distance(
            csr_view.rowptr_data(),
            std::upper_bound(csr_view.rowptr_data(),
                             csr_view.rowptr_data() + shape_[0], last_index));
        row_offsets_.push_back(lower_limit);
        row_sizes_.push_back(higher_limit - lower_limit);
        row_information[i] = lower_limit;
        row_information[default_comm().size() + i] = higher_limit - lower_limit;
        max_row_size_ = max_row_size_ + row_sizes_.back();
      }
      row_information[default_comm().size() * 2] = max_row_size_;
      default_comm().bcast(row_information,
                           sizeof(std::size_t) * row_info_size, root);
    } else {
      default_comm().bcast(row_information,
                           sizeof(std::size_t) * row_info_size, root);
      for (int i = 0; i < default_comm().size(); i++) {
        row_offsets_.push_back(row_information[i]);
        row_sizes_.push_back(row_information[default_comm().size() + i]);
      }
      max_row_size_ = row_information[default_comm().size() * 2];
    }
    delete[] row_information;
    row_size_ = std::max(row_sizes_[rank], static_cast<std::size_t>(1));
    rows_data_ =
        static_cast<I *>(rows_backend_.allocate(row_size_ * sizeof(I)));
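    // Push each rank's slice of rowptr into its backend window so every rank
    // can translate local nonzero positions back to global row indices.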
    fence();
    if (root == default_comm().rank()) {
      for (std::size_t i = 0; i < default_comm().size(); i++) {
        auto lower_limit = row_offsets_[i];
        auto row_size = row_sizes_[i];
        if (row_size > 0) {
          rows_backend_.putmem(csr_view.rowptr_data() + lower_limit, 0,
                               row_size * sizeof(I), i);
        }
      }
    }
    fence();

    std::size_t segment_index = 0;
    segment_size_ = vals_data_->segment_size();
    assert(segment_size_ == cols_data_->segment_size());
    for (std::size_t i = 0; i < nnz_; i += segment_size_) {
      segments_.emplace_back(this, segment_index++,
                             std::min(segment_size_, nnz_ - i), segment_size_);
    }
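    // Stage the view helper tuple (row count, row offset, nonzero offset,
    // local rowptr pointer) in allocator memory so device-side views can
    // read it.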
    auto local_rows = rows_data_;
    auto real_val_size = std::min(vals_data_->segment_size(),
                                  nnz_ - vals_data_->segment_size() * rank);
    auto my_tuple = std::make_tuple(row_size_, row_offsets_[rank],
                                    segment_size_ * rank, local_rows);
    view_helper_const = tuple_alloc.allocate(1);

    if (use_sycl()) {
      sycl_queue()
          .memcpy(view_helper_const, &my_tuple, sizeof(view_tuple))
          .wait();
    } else {
      view_helper_const[0] = my_tuple;
    }

    auto local_cols = static_cast<I *>(nullptr);
    auto local_vals = static_cast<T *>(nullptr);
    if (cols_data_->segments().size() > rank) {
      local_cols = cols_data_->segments()[rank].begin().local();
      local_vals = vals_data_->segments()[rank].begin().local();
      local_view = std::make_shared<view_type>(get_elem_view(
          real_val_size, view_helper_const, local_cols, local_vals, rank));
    }
    fence();
  }
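  // Builds a lazy range of matrix_entry objects over the local values and
  // column indices. Row indices are recovered on demand by binary-searching
  // the local rowptr slice, so no per-element row array is materialized.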
  static auto get_elem_view(std::size_t vals_size, view_tuple *helper_tuple,
                            index_type *local_cols, elem_type *local_vals,
                            std::size_t rank) {
    auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size);
    auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size);
    auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
    auto enumerated_zipped = rng::views::enumerate(zipped_results);
    // The helper tuple is repeated vals_size times with multiply_view because
    // the lifetime of a view is limited by the lifetime of its elements.
    auto multiply_range = dr::__detail::multiply_view(
        rng::subrange(helper_tuple, helper_tuple + 1), vals_size);
    auto enumerted_with_data =
        rng::views::zip(enumerated_zipped, multiply_range);

    auto transformer = [=](auto x) {
      auto [entry, tuple] = x;
      auto [row_size, row_offset, offset, local_rows] = tuple;
      auto [index, pair] = entry;
      auto [val, column] = pair;
      auto row =
          rng::distance(local_rows,
                        std::upper_bound(local_rows, local_rows + row_size,
                                         offset + index) -
                            1) +
          row_offset;
      return value_type({static_cast<I>(row), column}, val);
    };
    return rng::transform_view(enumerted_with_data, std::move(transformer));
  }
  using view_type = decltype(get_elem_view(0, nullptr, nullptr, nullptr, 0));

  dr::__detail::allocator<view_tuple> tuple_alloc;
  view_tuple *view_helper_const;
  std::shared_ptr<view_type> local_view = nullptr;

  std::size_t segment_size_ = 0;
  std::size_t row_size_ = 0;
  std::size_t max_row_size_ = 0;
  std::vector<std::size_t> row_offsets_;
  std::vector<std::size_t> row_sizes_;

  index_type *rows_data_ = nullptr;
  BackendT rows_backend_;

  distribution distribution_;
  dr::index<size_t> shape_;
  std::size_t nnz_;
  std::vector<segment_type> segments_;
  std::shared_ptr<distributed_vector<T>> vals_data_;
  std::shared_ptr<distributed_vector<I>> cols_data_;
};
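// A minimal usage sketch (assuming the dr::mp namespace implied by the header
// path, an initialized runtime, and a csr_matrix_view `csr` valid on rank 0;
// `b` and `res` are hypothetical dense buffers):
//
//   dr::mp::csr_eq_distribution<double, long> mat(csr);
//   std::vector<double> res(mat.shape()[0]);         // one output column
//   mat.local_gemv_and_collect(0, res, b.data(), 1); // res = csr * b on root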