#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <iostream>
#include <thread>
#include <vector>

#include <dr/sp/algorithms/matrix/local_gemm.hpp>
#include <dr/sp/containers/distributed_dense_matrix.hpp>
13void gemm(distributed_dense_matrix<T> &a, distributed_dense_matrix<T> &b,
14 distributed_dense_matrix<T> &c) {
15 gemm_buffered(a, b, c);
19void gemm_inplace(distributed_dense_matrix<T> &a,
20 distributed_dense_matrix<T> &b,
21 distributed_dense_matrix<T> &c) {
23 assert(c.shape()[0] == a.shape()[0]);
24 assert(c.shape()[1] == b.shape()[1]);
25 assert(a.shape()[1] == b.shape()[0]);
29 assert(c.grid_shape()[0] == a.grid_shape()[0]);
30 assert(c.grid_shape()[1] == b.grid_shape()[1]);
31 assert(a.grid_shape()[1] == b.grid_shape()[0]);
33 std::vector<sycl::event> events;
34 events.reserve(c.grid_shape()[0] * c.grid_shape()[1] * a.grid_shape()[1]);
36 for (std::size_t i = 0; i < c.grid_shape()[0]; i++) {
37 for (std::size_t j = 0; j < c.grid_shape()[1]; j++) {
39 auto &&c_tile = c.tile({i, j});
41 std::vector<sycl::event> local_events;
42 local_events.reserve(a.grid_shape()[1]);
44 std::size_t k_offset = i + j;
45 for (std::size_t k_ = 0; k_ < a.grid_shape()[1]; k_++) {
46 std::size_t k = (k_ + k_offset) % a.grid_shape()[1];
48 auto &&a_tile = a.tile({i, k});
49 auto &&b_tile = b.tile({k, j});
51 auto &&q = __detail::queue(dr::ranges::rank(c_tile));
53 auto e = __detail::local_gemm(q, __detail::local(a_tile),
54 __detail::local(b_tile),
55 __detail::local(c_tile), local_events);
57 local_events.push_back(e);
60 for (
auto &&e : local_events) {
66 __detail::wait(events);
70void gemm_buffered(distributed_dense_matrix<T> &a,
71 distributed_dense_matrix<T> &b,
72 distributed_dense_matrix<T> &c) {
74 assert(c.shape()[0] == a.shape()[0]);
75 assert(c.shape()[1] == b.shape()[1]);
76 assert(a.shape()[1] == b.shape()[0]);
80 assert(c.grid_shape()[0] == a.grid_shape()[0]);
81 assert(c.grid_shape()[1] == b.grid_shape()[1]);
82 assert(a.grid_shape()[1] == b.grid_shape()[0]);
84 std::vector<std::thread> threads;
86 std::atomic<double> communication = 0;
87 std::atomic<double> compute = 0;
89 for (std::size_t i = 0; i < c.grid_shape()[0]; i++) {
90 for (std::size_t j = 0; j < c.grid_shape()[1]; j++) {
91 auto c_local = c.tile({i, j});
93 threads.emplace_back([c_local, i, j, &a, &b, &communication, &compute] {
94 auto &&q = __detail::queue(dr::ranges::rank(c_local));
96 std::size_t a_elem = a.tile_shape()[0] * a.tile_shape()[1];
97 std::size_t b_elem = b.tile_shape()[0] * b.tile_shape()[1];
98 std::size_t buffer_size = std::max(a_elem, b_elem);
103 auto &&allocator = buffered_allocator;
105 std::size_t k_offset = i + j;
107 for (std::size_t k_ = 0; k_ < a.grid_shape()[1]; k_++) {
108 std::size_t k = (k_ + k_offset) % a.grid_shape()[1];
110 auto begin = std::chrono::high_resolution_clock::now();
111 auto a_tile = a.get_tile({i, k}, allocator);
112 auto b_tile = b.get_tile({k, j}, allocator);
113 auto end = std::chrono::high_resolution_clock::now();
114 double duration = std::chrono::duration<double>(end - begin).count();
115 communication += duration;
120 begin = std::chrono::high_resolution_clock::now();
121 __detail::local_gemm(q, __detail::local(a_local),
122 __detail::local(b_local),
123 __detail::local(c_local))
125 end = std::chrono::high_resolution_clock::now();
126 duration = std::chrono::duration<double>(end - begin).count();
133 for (
auto &&t : threads) {
137 bool debug_print =
false;
140 std::cout <<
"communication total: " << (double)communication << std::endl;
141 std::cout <<
"compute total: " << (double)compute << std::endl;
146void gemm_buffered_async(distributed_dense_matrix<T> &a,
147 distributed_dense_matrix<T> &b,
148 distributed_dense_matrix<T> &c) {
150 assert(c.shape()[0] == a.shape()[0]);
151 assert(c.shape()[1] == b.shape()[1]);
152 assert(a.shape()[1] == b.shape()[0]);
156 assert(c.grid_shape()[0] == a.grid_shape()[0]);
157 assert(c.grid_shape()[1] == b.grid_shape()[1]);
158 assert(a.grid_shape()[1] == b.grid_shape()[0]);
160 std::vector<std::thread> threads;
162 std::atomic<double> issue = 0;
163 std::atomic<double> sync = 0;
164 std::atomic<double> compute = 0;
166 for (std::size_t i = 0; i < c.grid_shape()[0]; i++) {
167 for (std::size_t j = 0; j < c.grid_shape()[1]; j++) {
168 auto c_local = c.tile({i, j});
170 threads.emplace_back([c_local, i, j, &a, &b, &issue, &sync, &compute] {
171 auto &&q = __detail::queue(dr::ranges::rank(c_local));
173 std::size_t a_elem = a.tile_shape()[0] * a.tile_shape()[1];
174 std::size_t b_elem = b.tile_shape()[0] * b.tile_shape()[1];
175 std::size_t buffer_size = std::max(a_elem, b_elem);
180 auto &&allocator = buffered_allocator;
182 std::size_t k_offset = i + j;
184 auto begin = std::chrono::high_resolution_clock::now();
186 a.get_tile_async({i, k_offset % a.grid_shape()[1]}, allocator);
189 b.get_tile_async({k_offset % a.grid_shape()[1], j}, allocator);
191 auto end = std::chrono::high_resolution_clock::now();
192 double duration = std::chrono::duration<double>(end - begin).count();
195 for (std::size_t k_ = 0; k_ < a.grid_shape()[1]; k_++) {
196 std::size_t k = (k_ + k_offset) % a.grid_shape()[1];
198 auto begin = std::chrono::high_resolution_clock::now();
199 auto a_tile = a_f.get();
200 auto b_tile = b_f.get();
201 auto end = std::chrono::high_resolution_clock::now();
202 double duration = std::chrono::duration<double>(end - begin).count();
208 if (k_ + 1 < a.grid_shape()[1]) {
209 begin = std::chrono::high_resolution_clock::now();
210 a_f = a.get_tile_async({i, (k + 1) % a.grid_shape()[1]}, allocator);
212 b_f = b.get_tile_async({(k + 1) % a.grid_shape()[1], j}, allocator);
214 end = std::chrono::high_resolution_clock::now();
215 duration = std::chrono::duration<double>(end - begin).count();
219 begin = std::chrono::high_resolution_clock::now();
220 __detail::local_gemm(q, __detail::local(a_local),
221 __detail::local(b_local),
222 __detail::local(c_local))
224 end = std::chrono::high_resolution_clock::now();
225 duration = std::chrono::duration<double>(end - begin).count();
232 for (
auto &&t : threads) {
236 bool debug_print =
false;
239 std::cout <<
"sync total: " << (double)sync << std::endl;
240 std::cout <<
"issue total: " << (double)issue << std::endl;
241 std::cout <<
"compute total: " << (double)compute << std::endl;
// NOTE(review): the three lines below are cross-reference residue from a
// documentation export, not source code.  Presumably they pointed at
// buffered_allocator / device_allocator (allocators.hpp:74 / :20) and
// dense_matrix_view (dense_matrix_view.hpp:21) — verify and remove.
// Definition: allocators.hpp:74
// Definition: dense_matrix_view.hpp:21
// Definition: allocators.hpp:20