Distributed Ranges
Loading...
Searching...
No Matches
local_gemv.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
#include <memory>
#include <vector>

#include <dr/detail/ranges_shim.hpp>
#include <dr/sp/containers/sparse_matrix.hpp>
#include <dr/sp/util.hpp>
10
11#ifdef USE_MKL
12#include <oneapi/mkl.hpp>
13#endif
14
15namespace dr::sp {
16
17namespace __detail {
18
19template <typename T, typename I, std::random_access_iterator Iter,
20 typename... Args>
21 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
22auto custom_gemv(sycl::queue &q, dr::views::csr_matrix_view<T, I, Args...> a,
23 Iter b, Iter c,
24 const std::vector<sycl::event> &dependencies = {}) {
25 std::size_t wg = 32;
26
27 auto event = q.submit([&](auto &&h) {
28 h.depends_on(dependencies);
29 h.parallel_for(sycl::nd_range<1>(a.shape()[0] * wg, wg), [=](auto item) {
30 auto row_index = item.get_group(0);
31 auto local_id = item.get_local_id();
32 auto group_size = item.get_local_range(0);
33
34 auto row = a.row(row_index);
35
36 for (std::size_t idx = local_id; idx < row.size(); idx += group_size) {
37 auto &&[index, a_v] = row[idx];
38 auto &&[i, k] = index;
39
40 auto &&b_v = *(b + k);
41 auto &&c_v = *(c + i);
42
43 sycl::atomic_ref<T, sycl::memory_order::relaxed,
44 sycl::memory_scope::device>
45 c_ref(c_v);
46
47 c_ref += a_v * b_v;
48 }
49 });
50 });
51 return event;
52}
53
54#ifdef USE_MKL
55
56template <typename T, typename I, std::random_access_iterator Iter,
57 typename... Args>
58 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
59auto mkl_gemv(sycl::queue &q, dr::views::csr_matrix_view<T, I, Args...> a,
60 Iter b, Iter c,
61 const std::vector<sycl::event> &dependencies = {}) {
62
63 oneapi::mkl::sparse::matrix_handle_t a_handle;
64 oneapi::mkl::sparse::init_matrix_handle(&a_handle);
65
66 auto rowptr = dr::sp::__detail::local(a.rowptr_data());
67 auto colind = dr::sp::__detail::local(a.colind_data());
68 auto values = dr::sp::__detail::local(a.values_data());
69
70 oneapi::mkl::sparse::set_csr_data(q, a_handle, a.shape()[0], a.shape()[1],
71 oneapi::mkl::index_base::zero, rowptr,
72 colind, values);
73
74 auto event =
75 oneapi::mkl::sparse::gemv(q, oneapi::mkl::transpose::nontrans, T(1),
76 a_handle, b, T(1), c, dependencies);
77 return event;
78}
79
80template <typename T, typename I, std::random_access_iterator Iter,
81 typename... Args>
82 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
83auto local_gemv(sycl::queue &q, dr::views::csr_matrix_view<T, I, Args...> a,
84 Iter b, Iter c,
85 const std::vector<sycl::event> &dependencies = {}) {
86 return mkl_gemv(q, a, b, c, dependencies);
87}
88
89#else
90
91template <typename T, typename I, std::random_access_iterator Iter,
92 typename... Args>
93 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
94auto local_gemv(sycl::queue &q, dr::views::csr_matrix_view<T, I, Args...> a,
95 Iter b, Iter c,
96 const std::vector<sycl::event> &dependencies = {}) {
97 return custom_gemv(q, a, b, c, dependencies);
98}
99
100#endif
101
102} // namespace __detail
103
104} // namespace dr::sp
Definition: csr_matrix_view.hpp:126