Distributed Ranges
gemv.hpp
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause

#pragma once

#include <dr/detail/index.hpp>
#include <dr/detail/ranges_shim.hpp>

#include <dr/sp/algorithms/matrix/local_gemv.hpp>
#include <dr/sp/containers/duplicated_vector.hpp>
#include <dr/sp/containers/sparse_matrix.hpp>
#include <dr/sp/device_vector.hpp>
#include <dr/sp/distributed_span.hpp>
#include <dr/sp/util.hpp>

namespace dr::sp {

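// Sparse matrix-vector product c += a * b for a matrix distributed as a
// single column of row-block tiles (a.grid_shape() == {p, 1}).  A full copy
// of b is allocated on and copied to every device, and each tile is processed
// by one flat SYCL kernel that walks the tile's nonzeros and accumulates into
// c with device-scope atomics.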
template <dr::distributed_range C, typename T, typename I,
          dr::distributed_range B>
void flat_gemv(C &&c, dr::sp::sparse_matrix<T, I> &a, B &&b) {
  assert(a.shape()[0] == c.size());
  assert(a.shape()[1] == b.size());
  assert(a.grid_shape()[0] == c.segments().size());
  assert(a.grid_shape()[1] == 1);

  auto &&devices = dr::sp::devices();

  using b_scalar_type = rng::range_value_t<B>;

  using local_vector_type =
      dr::sp::device_vector<b_scalar_type,
                            dr::sp::device_allocator<b_scalar_type>>;

  std::vector<local_vector_type> local_b;
  std::vector<sycl::event> copy_events;
  std::vector<sycl::event> comp_events;

  for (std::size_t i = 0; i < devices.size(); i++) {
    dr::sp::device_allocator<T> allocator(dr::sp::context(), devices[i]);
    local_b.push_back(local_vector_type(b.size(), allocator, i));
  }

  for (auto &&l_b : local_b) {
    auto event =
        dr::sp::copy_async(b.begin(), b.end(), dr::ranges::local(l_b.begin()));
    copy_events.push_back(event);
  }

  for (std::size_t i = 0; i < a.grid_shape()[0]; i++) {
    auto a_tile = a.tile(dr::index<I>(i, 0));

    auto a_iter = a_tile.begin();
    auto b_iter = dr::ranges::local(local_b[i].begin());
    auto c_iter = dr::ranges::local(c.segments()[i].begin());

    auto &&q = __detail::queue(a_tile.rank());

    auto event = q.submit([&](auto &&h) {
      h.depends_on(copy_events[a_tile.rank()]);
      h.parallel_for(a_tile.size(), [=](auto idx) {
        auto &&[index, a_v] = *(a_iter + idx);
        auto &&[i, k] = index;
        auto &&b_v = *(b_iter + k);
        auto &&c_v = *(c_iter + i);
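        // Multiple work-items can touch the same row of c (one per nonzero),
        // so accumulate through a device-scope atomic.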
        sycl::atomic_ref<T, sycl::memory_order::relaxed,
                         sycl::memory_scope::device>
            c_ref(c_v);
        c_ref += a_v * b_v;
      });
    });
    comp_events.push_back(event);
  }

  __detail::wait(comp_events);
}

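// Sparse matrix-vector product c += a * b using a caller-provided
// duplicated_vector as scratch space: b is copied into each process's local
// copy of the scratch vector, and every row-block tile is handled by
// __detail::local_gemv on the queue that owns the tile.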
template <dr::distributed_range C, typename T, typename I,
          dr::distributed_range B>
void gemv(C &&c, dr::sp::sparse_matrix<T, I> &a, B &&b,
          sp::duplicated_vector<rng::range_value_t<B>> &scratch) {
  assert(a.shape()[0] == c.size());
  assert(a.shape()[1] == b.size());
  assert(a.grid_shape()[0] == c.segments().size());
  assert(a.grid_shape()[1] == 1);

  auto &&b_duplicated = scratch;

  std::vector<sycl::event> copy_events;
  std::vector<sycl::event> comp_events;
  copy_events.reserve(sp::nprocs());
  comp_events.reserve(a.grid_shape()[0]);

  for (std::size_t i = 0; i < sp::nprocs(); i++) {
    auto &&l_b = b_duplicated.local_vector(i);
    auto event = dr::sp::copy_async(b.begin(), b.end(), l_b.begin());
    copy_events.push_back(event);
  }

  for (std::size_t i = 0; i < a.grid_shape()[0]; i++) {
    auto a_tile = a.tile(dr::index<I>(i, 0));

    auto b_iter =
        dr::ranges::local(b_duplicated.local_vector(a_tile.rank()).begin());
    auto c_iter = dr::ranges::local(c.segments()[i].begin());

    assert(c.segments()[i].size() == a_tile.shape()[0]);
    auto &&q = __detail::queue(a_tile.rank());

    auto event = __detail::local_gemv(q, a_tile, b_iter, c_iter,
                                      {copy_events[a_tile.rank()]});
    comp_events.push_back(event);
  }

  __detail::wait(comp_events);
}

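// Convenience overload that allocates the duplicated_vector scratch buffer
// internally and forwards to the scratch-taking gemv above.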
template <dr::distributed_range C, typename T, typename I,
          dr::distributed_range B>
void gemv(C &&c, dr::sp::sparse_matrix<T, I> &a, B &&b) {
  dr::sp::duplicated_vector<rng::range_value_t<B>> b_duplicated(b.size());

  gemv(c, a, b, b_duplicated);
}

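// gemv for a matrix distributed over a two-dimensional grid of tiles whose
// column partitioning matches the segments of b.  Each tile is multiplied
// directly against the corresponding segment of b, which may live on another
// device, so no copies of b are made.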
template <dr::distributed_range C, typename T, typename I,
          dr::distributed_range B>
void gemv_square(C &&c, dr::sp::sparse_matrix<T, I> &a, B &&b) {
  assert(a.shape()[0] == c.size());
  assert(a.shape()[1] == b.size());
  assert(a.grid_shape()[0] == c.segments().size());
  assert(a.grid_shape()[1] == b.segments().size());

  std::vector<sycl::event> events;

  for (std::size_t i = 0; i < a.grid_shape()[0]; i++) {
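    // Stagger which column block each row block starts with so that the
    // devices do not all read the same segment of b at the same time.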
    std::size_t k_offset = i;
    for (std::size_t k_ = 0; k_ < a.grid_shape()[1]; k_++) {
      std::size_t k = (k_ + k_offset) % a.grid_shape()[1];
      auto a_tile = a.tile(dr::index<I>(i, k));
      auto b_segment = b.segments()[k];
      auto c_segment = c.segments()[i];

      auto b_iter = dr::ranges::local(b_segment.begin());
      auto c_iter = dr::ranges::local(c_segment.begin());

      auto &&q = __detail::queue(a_tile.rank());

      auto event = __detail::custom_gemv(q, a_tile, b_iter, c_iter);
      events.push_back(event);
    }
  }

  __detail::wait(events);
}

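// Variant of gemv_square that first copies each required segment of b into a
// buffer local to the device that owns the tile, then runs the per-tile
// kernel against that local copy instead of reading b remotely.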
template <dr::distributed_range C, typename T, typename I,
          dr::distributed_range B>
void gemv_square_copy(C &&c, dr::sp::sparse_matrix<T, I> &a, B &&b) {
  assert(a.shape()[0] == c.size());
  assert(a.shape()[1] == b.size());
  assert(a.grid_shape()[0] == c.segments().size());
  assert(a.grid_shape()[1] == b.segments().size());

  auto &&devices = dr::sp::devices();

  using b_scalar_type = rng::range_value_t<B>;

  using local_vector_type =
      dr::sp::device_vector<b_scalar_type,
                            dr::sp::device_allocator<b_scalar_type>>;

  std::vector<local_vector_type> local_b;
  std::vector<sycl::event> events;

  local_b.reserve(a.grid_shape()[0]);

  for (std::size_t i = 0; i < a.grid_shape()[0]; i++) {
    dr::sp::device_allocator<T> allocator(
        dr::sp::context(), devices[a.tile(dr::index<I>(i, 0)).rank()]);
    local_b.emplace_back(b.size(), allocator,
                         a.tile(dr::index<I>(i, 0)).rank());
  }

  for (std::size_t i = 0; i < a.grid_shape()[0]; i++) {
    std::size_t k_offset = i;
    for (std::size_t k_ = 0; k_ < a.grid_shape()[1]; k_++) {
      std::size_t k = (k_ + k_offset) % a.grid_shape()[1];
      auto a_tile = a.tile({i, k});
      auto b_iter = local_b[i].begin() + (k * a.tile_shape()[1]);
      auto c_iter = c.segments()[i].begin();

      auto &&b_segment = b.segments()[k];
      auto &&q = __detail::queue(a_tile.rank());

      auto ce =
          dr::sp::copy_async(q, b_segment.begin(), b_segment.end(), b_iter);

      auto event = __detail::custom_gemv(q, a_tile, b_iter.local(),
                                         c_iter.local(), {ce});

      events.push_back(event);
    }
  }

  __detail::wait(events);
}

} // namespace dr::sp
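
A rough usage sketch follows, showing how the convenience gemv overload above might be called. It assumes the dr/sp.hpp umbrella header and the dr::sp::init, dr::sp::mmread, dr::sp::distributed_vector, and dr::sp::fill facilities found elsewhere in the distributed-ranges sp backend; their exact signatures, and the matrix file name, are illustrative assumptions rather than part of this header.

// Usage sketch (not part of gemv.hpp): multiply a distributed sparse matrix
// by a distributed vector with the convenience gemv overload.
// Assumptions: dr::sp::init, dr::sp::mmread, dr::sp::distributed_vector, and
// dr::sp::fill behave as in the rest of the sp backend.
#include <dr/sp.hpp>

int main() {
  // Select the available devices for the sp runtime (assumed init call).
  dr::sp::init(sycl::default_selector_v);

  // Load an m x n sparse matrix from a Matrix Market file (assumed reader).
  auto a = dr::sp::mmread<float, int>("matrix.mtx");

  dr::sp::distributed_vector<float> b(a.shape()[1]);
  dr::sp::distributed_vector<float> c(a.shape()[0]);
  dr::sp::fill(b, 1.0f);
  dr::sp::fill(c, 0.0f);

  // c += a * b; scratch storage for b is allocated internally.
  dr::sp::gemv(c, a, b);

  return 0;
}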