Distributed Ranges
Loading...
Searching...
No Matches
local_gemm.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/sp/views/dense_matrix_view.hpp>
8
9#ifdef USE_MKL
10#include <oneapi/mkl.hpp>
11#endif
12
13namespace dr::sp {
14
15namespace __detail {
16
17template <typename T>
18auto custom_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
19 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
20 const std::vector<sycl::event> &dependencies = {}) {
21 assert(c.shape()[0] == a.shape()[0]);
22 assert(c.shape()[1] == b.shape()[1]);
23 assert(a.shape()[1] == b.shape()[0]);
24
25 std::size_t M = c.shape()[0];
26 std::size_t N = c.shape()[1];
27 std::size_t K = a.shape()[1];
28
29 auto a_p = a.data();
30 auto b_p = b.data();
31 auto c_p = c.data();
32
33 auto e = q.parallel_for(sycl::range<3>{M, K, N}, [=](auto idx) {
34 auto i = idx[0];
35 auto k = idx[1];
36 auto j = idx[2];
37
38 sycl::atomic_ref<T, sycl::memory_order::relaxed, sycl::memory_scope::device>
39 c_ref(c_p[i * N + j]);
40
41 c_ref += a_p[i * K + k] * b_p[k * N + j];
42 });
43 return e;
44}
45
46#ifdef USE_MKL
47
48template <typename T>
49auto mkl_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
50 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
51 const std::vector<sycl::event> &dependencies = {}) {
52 assert(c.shape()[0] == a.shape()[0]);
53 assert(c.shape()[1] == b.shape()[1]);
54 assert(a.shape()[1] == b.shape()[0]);
55
56 auto event = oneapi::mkl::blas::row_major::gemm(
57 q, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
58 c.shape()[0], c.shape()[1], a.shape()[1], T(1), a.data(), a.ld(),
59 b.data(), b.ld(), T(1), c.data(), c.ld(), dependencies);
60
61 return event;
62}
63
64template <typename T>
65auto local_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
66 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
67 const std::vector<sycl::event> &dependencies = {}) {
68 return mkl_gemm(q, a, b, c, dependencies);
69}
70
71#else
72
73template <typename T>
74auto local_gemm(sycl::queue &q, sp::dense_matrix_view<T> a,
75 sp::dense_matrix_view<T> b, sp::dense_matrix_view<T> c,
76 const std::vector<sycl::event> &dependencies = {}) {
77 return custom_gemm(q, a, b, c, dependencies);
78}
79
80#endif
81
82} // namespace __detail
83
84} // namespace dr::sp