Distributed Ranges
Loading...
Searching...
No Matches
gemv.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6#include <dr/mp/allocator.hpp>
7#include <dr/mp/containers/broadcasted_slim_matrix.hpp>
8#include <dr/mp/containers/broadcasted_vector.hpp>
9#include <dr/mp/containers/distributed_sparse_matrix.hpp>
10#include <dr/mp/global.hpp>
11#include <fmt/core.h>
12#include <ranges>
13
14namespace dr::mp {
15
16template <typename T, typename I, rng::output_range<T> C, typename Alloc,
17 typename Backend, typename MatDistr>
18 requires(vector_multiplicable<MatDistr>)
19void gemv(int root, C &res,
20 distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
21 broadcasted_vector<T, Alloc> b) {
22 if (default_comm().rank() == root) {
23 assert(a.shape().first == res.size());
24 assert(a.shape().second == b.size());
25 }
26 a.local_gemv_and_collect(root, res, b.broadcasted_data(), 1);
27}
28
29template <typename T, typename I, rng::output_range<T> C, typename Alloc,
30 typename Backend, typename MatDistr>
31 requires(vector_multiplicable<MatDistr>)
32void gemv(int root, C &res,
33 distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
34 broadcasted_slim_matrix<T, Alloc> b) {
35 if (default_comm().rank() == root) {
36 assert(a.shape().first * b.width() == res.size());
37 }
38 a.local_gemv_and_collect(root, res, b.broadcasted_data(), b.width());
39}
40
41template <typename T, typename I, rng::output_range<T> C, typename Alloc,
42 typename Backend, typename MatDistr>
43 requires(vector_multiplicable<MatDistr>)
44void gemv(C &res, distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
45 broadcasted_vector<T, Alloc> b) {
46 std::vector<T> workspace(res.size());
47 gemv(0, workspace, a, b);
48 auto tmp = new T[res.size()];
49 if (default_comm().rank() == 0) {
50 std::copy(workspace.begin(), workspace.end(), tmp);
51 }
52 default_comm().bcast(tmp, sizeof(T) * res.size(), 0);
53 std::copy(tmp, tmp + res.size(), res.begin());
54 delete[] tmp;
55}
56
57template <typename T, typename I, rng::output_range<T> C, typename Alloc,
58 typename Backend, typename MatDistr>
59 requires(vector_multiplicable<MatDistr>)
60void gemv(C &res, distributed_sparse_matrix<T, I, Backend, MatDistr> &a,
61 broadcasted_slim_matrix<T, Alloc> b) {
62 std::vector<T> workspace(res.size());
63 gemv(0, workspace, a, b);
64 auto tmp = new T[res.size()];
65 if (default_comm().rank() == 0) {
66 std::copy(workspace.begin(), workspace.end(), tmp);
67 }
68 default_comm().bcast(tmp, sizeof(T) * res.size(), 0);
69 std::copy(tmp, tmp + res.size(), res.begin());
70 delete[] tmp;
71}
72
73} // namespace dr::mp