#pragma once

#include <algorithm>
#include <cassert>
#include <climits>
#include <memory>
#include <tuple>
#include <vector>

#include <dr/detail/matrix_entry.hpp>
#include <dr/detail/multiply_view.hpp>
#include <dr/mp/containers/matrix_formats/csr_row_segment.hpp>
#include <dr/views/csr_matrix_view.hpp>
namespace dr::mp {

template <typename T, typename I, class BackendT = MpiBackend>
class csr_row_distribution {
  using view_tuple = std::tuple<std::size_t, std::size_t, std::size_t, I *>;

public:
  using value_type = dr::matrix_entry<T, I>;
  using segment_type = csr_row_segment<csr_row_distribution>;
  using elem_type = T;
  using index_type = I;
  using difference_type = std::ptrdiff_t;
  csr_row_distribution(dr::views::csr_matrix_view<T, I> csr_view,
                       distribution dist = distribution(),
                       std::size_t root = 0) {
    init(csr_view, dist, root);
  }
  ~csr_row_distribution() {
    fence();
    if (vals_data_ != nullptr) {
      vals_backend_.deallocate(vals_data_, vals_size_ * sizeof(elem_type));
      cols_backend_.deallocate(cols_data_, vals_size_ * sizeof(index_type));
      alloc.deallocate(view_helper_const, 1);
    }
  }
  std::size_t get_id_in_segment(std::size_t offset) const {
    assert(offset < nnz_);
    auto pos_iter =
        std::upper_bound(val_offsets_.begin(), val_offsets_.end(), offset) - 1;
    return offset - *pos_iter;
  }
  std::size_t get_segment_from_offset(std::size_t offset) const {
    assert(offset < nnz_);
    auto pos_iter =
        std::upper_bound(val_offsets_.begin(), val_offsets_.end(), offset);
    return rng::distance(val_offsets_.begin(), pos_iter) - 1;
  }
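  // val_offsets_[i] holds the global index of the first nonzero owned by
  // rank i, so both lookups above are binary searches. For example, with
  // val_offsets_ = {0, 5, 9}, global offset 7 maps to segment 1, local id 2.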
  auto segments() const { return rng::views::all(segments_); }
  auto nnz() const { return nnz_; }
  auto shape() const { return shape_; }
  void fence() const {
    vals_backend_.fence();
    cols_backend_.fence();
  }
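  // local_gemv computes this rank's part of a sparse-times-dense product:
  // `vals` is a dense, column-major input with vals_width columns of length
  // shape_[1]; partial results land in `res`, column-major with a column
  // stride of segment_size_.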
  template <typename C>
  auto local_gemv(C &res, T *vals, std::size_t vals_width) const {
    auto rank = cols_backend_.getrank();
    if (shape_[0] <= segment_size_ * rank) {
      return; // this rank owns no rows of the matrix
    }
    auto size = std::min(segment_size_, shape_[0] - segment_size_ * rank);
    auto vals_len = shape_[1];
    if (dr::mp::use_sycl()) {
      auto local_vals = vals_data_;
      auto local_cols = cols_data_;
      auto offset = val_offsets_[rank];
      auto real_segment_size = std::min(nnz_ - offset, val_sizes_[rank]);
      auto rows_data = dr::__detail::direct_iterator(
          dr::mp::local_segment(*rows_data_).begin());
      auto res_col_len = segment_size_;
      // One work-group per (row, dense-column) pair; halve the group size
      // until the global range fits in an int.
      std::size_t wg = 32;
      while (vals_width * size * wg > INT_MAX) {
        wg /= 2;
      }
      assert(wg > 0);
      dr::mp::sycl_queue()
          .submit([&](auto &&h) {
            h.parallel_for(
                sycl::nd_range<1>(vals_width * size * wg, wg),
                [=](auto item) {
                  auto input_j = item.get_group(0) / size;
                  auto idx = item.get_group(0) % size;
                  auto local_id = item.get_local_id(0);
                  auto group_size = item.get_local_range(0);
                  std::size_t lower_bound = 0;
                  if (rows_data[idx] > offset) {
                    lower_bound = rows_data[idx] - offset;
                  }
                  std::size_t upper_bound = real_segment_size;
                  if (idx < size - 1) {
                    upper_bound = rows_data[idx + 1] - offset;
                  }
                  T sum = 0;
                  // Work-items of the group stride over the nonzeros of
                  // row idx.
                  for (auto i = lower_bound + local_id; i < upper_bound;
                       i += group_size) {
                    auto col = local_cols[i];
                    auto dense_val = vals[col + input_j * vals_len];
                    auto matrix_val = local_vals[i];
                    sum += matrix_val * dense_val;
                  }
                  sycl::atomic_ref<T, sycl::memory_order::relaxed,
                                   sycl::memory_scope::device>
                      c_ref(res[idx + input_j * res_col_len]);
                  c_ref += sum;
                });
          })
          .wait();
    } else {
      auto local_rows = dr::mp::local_segment(*rows_data_);
      auto val_count = val_sizes_[rank];
      std::size_t row_i = 0;
      auto position = val_offsets_[rank];
      auto current_row_position = local_rows[1];

      for (std::size_t i = 0; i < val_count; i++) {
        // Advance row_i until it owns the nonzero at global position
        // position + i; guard the lookahead so the last row does not read
        // past the local segment.
        while (row_i + 1 < size && position + i >= current_row_position) {
          row_i++;
          if (row_i + 1 < size) {
            current_row_position = local_rows[row_i + 1];
          }
        }
        for (std::size_t j = 0; j < vals_width; j++) {
          res[row_i + j * segment_size_] +=
              vals_data_[i] * vals[cols_data_[i] + j * vals_len];
        }
      }
    }
  }
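  // local_gemv_and_collect zero-fills a per-rank buffer, runs local_gemv on
  // it, and gathers the partial products on `root` via gather_gemv_vector.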
  template <typename C>
  auto local_gemv_and_collect(std::size_t root, C &res, T *&vals,
                              std::size_t vals_width) const {
    assert(res.size() == shape_.first * vals_width);
    __detail::allocator<T> alloc;
    auto res_alloc = alloc.allocate(segment_size_ * vals_width);
    if (use_sycl()) {
      sycl_queue().fill(res_alloc, T(0), segment_size_ * vals_width).wait();
    } else {
      std::fill(res_alloc, res_alloc + segment_size_ * vals_width, T(0));
    }

    local_gemv(res_alloc, vals, vals_width);

    gather_gemv_vector(root, res, res_alloc, vals_width);
    fence();
    alloc.deallocate(res_alloc, segment_size_ * vals_width);
  }
private:
  friend csr_row_segment_iterator<csr_row_distribution>;

  template <typename C, typename A>
  void gather_gemv_vector(std::size_t root, C &res, A &partial_res,
                          std::size_t vals_width) const {
    auto communicator = default_comm();
    __detail::allocator<T> alloc;
    if (communicator.rank() == root) {
      // Root gathers every rank's partial result into one scratch buffer,
      // rank-major, each contribution segment_size_ * vals_width elements.
      auto scratch = alloc.allocate(segment_size_ * communicator.size() *
                                    vals_width);
      communicator.gather(partial_res, scratch, segment_size_ * vals_width,
                          root);
      T *temp = nullptr;
      if (use_sycl()) {
        temp = new T[res.size()];
      }
      for (std::size_t j = 0; j < communicator.size(); j++) {
        if (j * segment_size_ >= shape_.first) {
          break; // ranks past the last row contribute nothing
        }
        auto comm_segment_size =
            std::min(segment_size_, shape_.first - j * segment_size_);
        for (std::size_t i = 0; i < vals_width; i++) {
          auto piece_start =
              scratch + j * vals_width * segment_size_ + i * segment_size_;
          if (use_sycl()) {
            __detail::sycl_copy(piece_start,
                                temp + shape_.first * i + j * segment_size_,
                                comm_segment_size);
          } else {
            std::copy(piece_start, piece_start + comm_segment_size,
                      res.begin() + shape_.first * i + j * segment_size_);
          }
        }
      }
      if (use_sycl()) {
        std::copy(temp, temp + res.size(), res.begin());
        delete[] temp;
      }
      alloc.deallocate(scratch,
                       segment_size_ * communicator.size() * vals_width);
    } else {
      communicator.gather(partial_res, static_cast<T *>(nullptr),
                          segment_size_ * vals_width, root);
    }
  }
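  // init() runs on every rank: (1) broadcast nnz and shape from root,
  // (2) replicate the row pointers into a distributed_vector, (3) compute
  // per-rank value offsets and sizes, (4) push each rank's slice of values
  // and column indices through the backend, (5) build the local element view.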
  void init(dr::views::csr_matrix_view<T, I> csr_view, auto dist,
            std::size_t root) {
    distribution_ = dist;
    auto rank = vals_backend_.getrank();

    std::size_t initial_data[3];
    if (rank == root) {
      initial_data[0] = csr_view.size();
      initial_data[1] = csr_view.shape().first;
      initial_data[2] = csr_view.shape().second;
    }
    default_comm().bcast(initial_data, sizeof(std::size_t) * 3, root);

    nnz_ = initial_data[0];
    shape_ = {initial_data[1], initial_data[2]};

    rows_data_ = std::make_shared<distributed_vector<I>>(shape_.first);

    dr::mp::copy(root,
                 std::ranges::subrange(csr_view.rowptr_data(),
                                       csr_view.rowptr_data() + shape_.first),
                 rows_data_->begin());
    auto row_info_size = default_comm().size() * 2;
    std::size_t *val_information = new std::size_t[row_info_size];
    val_offsets_.reserve(row_info_size);
    val_sizes_.reserve(row_info_size);
    if (rank == root) {
      // Root derives each rank's first nonzero (offset) and nonzero count
      // (size) from the row pointers, then broadcasts both arrays at once.
      for (std::size_t i = 0; i < default_comm().size(); i++) {
        auto first_index = rows_data_->get_segment_offset(i);
        if (first_index > shape_.first) {
          // Rank i starts past the last row: empty segment.
          val_offsets_.push_back(nnz_);
          val_sizes_.push_back(0);
          val_information[i] = nnz_;
          val_information[i + default_comm().size()] = 0;
          continue;
        }
        std::size_t lower_limit = csr_view.rowptr_data()[first_index];
        std::size_t higher_limit = nnz_;
        if (rows_data_->get_segment_offset(i + 1) < shape_.first) {
          auto last_index = rows_data_->get_segment_offset(i + 1);
          higher_limit = csr_view.rowptr_data()[last_index];
        }
        val_offsets_.push_back(lower_limit);
        val_sizes_.push_back(higher_limit - lower_limit);
        val_information[i] = lower_limit;
        val_information[i + default_comm().size()] = higher_limit - lower_limit;
      }
    }
    default_comm().bcast(val_information, sizeof(std::size_t) * row_info_size,
                         root);
    if (rank != root) {
      for (std::size_t i = 0; i < default_comm().size(); i++) {
        val_offsets_.push_back(val_information[i]);
        val_sizes_.push_back(val_information[default_comm().size() + i]);
      }
    }
    delete[] val_information;
    vals_size_ = std::max(val_sizes_[rank], static_cast<std::size_t>(1));

    cols_data_ =
        static_cast<I *>(cols_backend_.allocate(vals_size_ * sizeof(I)));
    vals_data_ =
        static_cast<T *>(vals_backend_.allocate(vals_size_ * sizeof(T)));

    fence();
    if (rank == root) {
      for (std::size_t i = 0; i < default_comm().size(); i++) {
        auto lower_limit = val_offsets_[i];
        auto row_size = val_sizes_[i];
        if (row_size > 0) {
          // One-sided put of rank i's slice of the values and column
          // indices into its backend window.
          vals_backend_.putmem(csr_view.values_data() + lower_limit, 0,
                               row_size * sizeof(T), i);
          cols_backend_.putmem(csr_view.colind_data() + lower_limit, 0,
                               row_size * sizeof(I), i);
        }
      }
    }
    fence();
    std::size_t segment_index = 0;
    segment_size_ = rows_data_->segment_size();
    for (std::size_t i = 0; i < default_comm().size(); i++) {
      segments_.emplace_back(this, segment_index++, val_sizes_[i],
                             std::max(val_sizes_[i],
                                      static_cast<std::size_t>(1)));
    }
    auto local_rows = static_cast<I *>(nullptr);
    if (rows_data_->segments().size() > rank) {
      local_rows = rows_data_->segments()[rank].begin().local();
    }
    auto offset = val_offsets_[rank];
    auto real_row_size =
        std::min(rows_data_->segment_size(),
                 shape_.first - rows_data_->segment_size() * rank);
    auto my_tuple = std::make_tuple(real_row_size, segment_size_ * rank, offset,
                                    local_rows);
    view_helper_const = alloc.allocate(1);

    if (use_sycl()) {
      sycl_queue()
          .memcpy(view_helper_const, &my_tuple, sizeof(view_tuple))
          .wait();
    } else {
      view_helper_const[0] = my_tuple;
    }

    if (rows_data_->segments().size() > rank) {
      local_view = std::make_shared<view_type>(get_elem_view(
          vals_size_, view_helper_const, cols_data_, vals_data_, rank));
    }
    fence();
  }
  static auto get_elem_view(std::size_t vals_size, view_tuple *helper_tuple,
                            index_type *local_cols, elem_type *local_vals,
                            std::size_t rank) {
    auto local_vals_range = rng::subrange(local_vals, local_vals + vals_size);
    auto local_cols_range = rng::subrange(local_cols, local_cols + vals_size);
    auto zipped_results = rng::views::zip(local_vals_range, local_cols_range);
    auto enumerated_zipped = rng::views::enumerate(zipped_results);
    // multiply_view repeats the helper tuple once per element: the lifetime
    // of the data it points to has to be managed outside the transform
    // function below.
    auto multiply_range = dr::__detail::multiply_view(
        rng::subrange(helper_tuple, helper_tuple + 1), vals_size);
    auto enumerated_with_data =
        rng::views::zip(enumerated_zipped, multiply_range);

    auto transformer = [=](auto x) {
      auto [entry, tuple] = x;
      auto [row_size, row_offset, offset, local_rows] = tuple;
      auto [index, pair] = entry;
      auto [val, column] = pair;
      // The owning row is the last row pointer not greater than the
      // nonzero's global position (offset + index).
      auto row =
          rng::distance(local_rows,
                        std::upper_bound(local_rows, local_rows + row_size,
                                         offset + index) -
                            1) +
          row_offset;
      return value_type({static_cast<index_type>(row), column}, val);
    };
    return rng::transform_view(enumerated_with_data, std::move(transformer));
  }
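  // Example: with local_rows = {0, 2, 5} and offset = 0, the nonzero at
  // index 3 is resolved by upper_bound to the row starting at pointer 2,
  // i.e. local row 1.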
  using view_type = decltype(get_elem_view(0, nullptr, nullptr, nullptr, 0));

  view_tuple *view_helper_const = nullptr;
  std::shared_ptr<view_type> local_view;

  std::size_t segment_size_ = 0;
  std::size_t vals_size_ = 0;
  std::vector<std::size_t> val_offsets_;
  std::vector<std::size_t> val_sizes_;

  index_type *cols_data_ = nullptr;
  BackendT cols_backend_;

  elem_type *vals_data_ = nullptr;
  BackendT vals_backend_;

  distribution distribution_;
  dr::index<> shape_;
  std::size_t nnz_ = 0;
  std::vector<segment_type> segments_;
  std::shared_ptr<distributed_vector<I>> rows_data_ = nullptr;
  __detail::allocator<view_tuple> alloc;
};

} // namespace dr::mp
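// Usage sketch (illustrative, not part of this header; the loader and the
// exact runtime setup calls are assumptions about the surrounding
// application, only csr_row_distribution's own interface comes from above):
//
//   dr::mp::init();                                 // backend/runtime setup
//   auto csr = load_csr<double, long>("a.mtx");     // hypothetical loader
//   dr::mp::csr_row_distribution<double, long> a(csr);
//   std::vector<double> x(a.shape()[1], 1.0);
//   std::vector<double> y(a.shape()[0], 0.0);
//   double *xp = x.data();
//   a.local_gemv_and_collect(0, y, xp, 1);          // y = A * x on rank 0
//   dr::mp::finalize();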