7#include <dr/sp/views/dense_matrix_view.hpp>
10#include <oneapi/mkl.hpp>
// custom_gemm: portable SYCL fallback GEMM on dense_matrix_view operands.
// The kernel accumulates with `c_ref += ...`, so the product a * b is added to
// c's existing contents. This matches mkl_gemm, which also uses beta = T(1).
//
// @param q            queue the kernel is submitted to.
// @param a, b, c      operands with c = (M x N), a = (M x K), b = (K x N);
//                     the asserts below check that these shapes agree.
// @param dependencies presumably the events this GEMM should run after
//                     (see the NOTE on the parallel_for below).
//
// NOTE(review): the leading integers on these lines look like the original
// file's line numbers fused into the text. Several lines also appear to be
// missing: the template header, the definitions of c_p/a_p/b_p and i/k/j, the
// return, and the closing braces. The comments below hedge on anything those
// lines would define.
18auto custom_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
19 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
20 const std::vector<sycl::event> &dependencies = {}) {
21 assert(c.shape()[0] == a.shape()[0]);
22 assert(c.shape()[1] == b.shape()[1]);
23 assert(a.shape()[1] == b.shape()[0]);
// M = rows of c, N = columns of c, K = shared inner dimension.
25 std::size_t M = c.shape()[0];
26 std::size_t N = c.shape()[1];
27 std::size_t K = a.shape()[1];
// Launches one work item per index of the M x K x N range.
// NOTE(review): `dependencies` is not passed to this parallel_for call. Unless
// the missing lines wait on those events, the kernel is not ordered after
// them -- confirm.
33 auto e = q.parallel_for(sycl::range<3>{M, K, N}, [=](
auto idx) {
// This assumes i, k, j are taken from idx in range order; their definitions
// are not shown. Under that assumption, work items that differ only in k
// update the same c element, so the update goes through a relaxed,
// device-scope atomic_ref.
38 sycl::atomic_ref<T, sycl::memory_order::relaxed, sycl::memory_scope::device>
39 c_ref(c_p[i * N + j]);
// NOTE(review): the row strides here are N and K, which assumes every view's
// leading dimension equals its column count. mkl_gemm passes a.ld(), b.ld()
// and c.ld() instead. A strided sub-view would be indexed incorrectly here --
// confirm whether such views can reach this path.
41 c_ref += a_p[i * K + k] * b_p[k * N + j];
// mkl_gemm: GEMM through oneMKL's row-major BLAS gemm on the views' data().
// Both operands are passed untransposed, with alpha = T(1) and beta = T(1).
// The call therefore computes c = a * b + c: it accumulates into c, like
// custom_gemm.
//
// @param q            queue handed to oneMKL.
// @param a, b, c      operands with rows(c) == rows(a), cols(c) == cols(b)
//                     and cols(a) == rows(b); the asserts below check this.
// @param dependencies forwarded unchanged to the oneMKL gemm call.
//
// NOTE(review): the leading integers look like the original file's line
// numbers fused into the text. The template header, the (presumable) return
// of `event` and the closing brace are not in this listing.
49auto mkl_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
50 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
51 const std::vector<sycl::event> &dependencies = {}) {
52 assert(c.shape()[0] == a.shape()[0]);
53 assert(c.shape()[1] == b.shape()[1]);
54 assert(a.shape()[1] == b.shape()[0]);
// m = rows of c, n = cols of c, k = cols of a. Each leading dimension comes
// from the view's ld() rather than its column count.
56 auto event = oneapi::mkl::blas::row_major::gemm(
57 q, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
58 c.shape()[0], c.shape()[1], a.shape()[1], T(1), a.data(), a.ld(),
59 b.data(), b.ld(), T(1), c.data(), c.ld(), dependencies);
// local_gemm (oneMKL variant): forwards all arguments unchanged to mkl_gemm.
// The effect is to accumulate a * b into c.
// NOTE(review): the visible signature is identical to the local_gemm overload
// that forwards to custom_gemm. Whatever selects between the two (a template
// constraint or a preprocessor guard) is on lines missing from this listing,
// as are the template header and the closing brace -- confirm.
65auto local_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
66 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
67 const std::vector<sycl::event> &dependencies = {}) {
68 return mkl_gemm(q, a, b, c, dependencies);
// local_gemm (portable variant): forwards all arguments unchanged to
// custom_gemm, the SYCL atomic_ref kernel.
// NOTE(review): the visible signature is identical to the local_gemm overload
// that forwards to mkl_gemm. Whatever selects between the two (a template
// constraint or a preprocessor guard) is on lines missing from this listing,
// as are the template header and the closing brace -- confirm.
74auto local_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
75 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
76 const std::vector<sycl::event> &dependencies = {}) {
77 return custom_gemm(q, a, b, c, dependencies);