#include <algorithm>
#include <cassert>
#include <vector>

#include <dr/mp/allocator.hpp>
#include <dr/mp/containers/broadcasted_slim_matrix.hpp>
#include <dr/mp/containers/broadcasted_vector.hpp>
#include <dr/mp/containers/distributed_sparse_matrix.hpp>
#include <dr/mp/global.hpp>
16template <
typename T,
typename I, rng::output_range<T> C,
typename Alloc,
17 typename Backend,
typename MatDistr>
18 requires(vector_multiplicable<MatDistr>)
19void gemv(
int root, C &res,
20 distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
21 broadcasted_vector<T, Alloc> b) {
22 if (default_comm().rank() == root) {
23 assert(a.shape().first == res.size());
24 assert(a.shape().second == b.size());
26 a.local_gemv_and_collect(root, res, b.broadcasted_data(), 1);
29template <
typename T,
typename I, rng::output_range<T> C,
typename Alloc,
30 typename Backend,
typename MatDistr>
31 requires(vector_multiplicable<MatDistr>)
32void gemv(
int root, C &res,
33 distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
34 broadcasted_slim_matrix<T, Alloc> b) {
35 if (default_comm().rank() == root) {
36 assert(a.shape().first * b.width() == res.size());
38 a.local_gemv_and_collect(root, res, b.broadcasted_data(), b.width());
41template <
typename T,
typename I, rng::output_range<T> C,
typename Alloc,
42 typename Backend,
typename MatDistr>
43 requires(vector_multiplicable<MatDistr>)
44void gemv(C &res, distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
45 broadcasted_vector<T, Alloc> b) {
46 std::vector<T> workspace(res.size());
47 gemv(0, workspace, a, b);
48 auto tmp =
new T[res.size()];
49 if (default_comm().rank() == 0) {
50 std::copy(workspace.begin(), workspace.end(), tmp);
52 default_comm().bcast(tmp,
sizeof(T) * res.size(), 0);
53 std::copy(tmp, tmp + res.size(), res.begin());
// Sparse matrix times broadcasted slim matrix, result delivered to every rank:
// the rooted gemv overload is run with root 0 into `workspace`, rank 0 copies
// it into `tmp`, `tmp` is broadcast from rank 0, and every rank copies `tmp`
// into `res`.
// NOTE(review): the end of this definition is not visible in this view.
57template <
typename T,
typename I, rng::output_range<T> C,
typename Alloc,
58 typename Backend,
typename MatDistr>
59 requires(vector_multiplicable<MatDistr>)
60void gemv(C &res, distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
61 broadcasted_slim_matrix<T, Alloc> b) {
// Scratch buffer sized like `res`; the rooted overload is called with root 0.
62 std::vector<T> workspace(res.size());
63 gemv(0, workspace, a, b);
// NOTE(review): `tmp` is allocated with new[] and no matching delete[] is
// visible here. If none follows, this leaks on every call. Consider
// broadcasting `workspace.data()` directly instead of a second buffer.
64 auto tmp =
new T[res.size()];
65 if (default_comm().rank() == 0) {
66 std::copy(workspace.begin(), workspace.end(), tmp);
// Broadcast length is in bytes (sizeof(T) * count); assumes T is trivially
// copyable.
68 default_comm().bcast(tmp,
sizeof(T) * res.size(), 0);
69 std::copy(tmp, tmp + res.size(), res.begin());