7#include <dr/detail/ranges_shim.hpp>
8#include <dr/sp/containers/sparse_matrix.hpp>
9#include <dr/sp/util.hpp>
12#include <oneapi/mkl.hpp>
// Hand-written SYCL sparse matrix-vector kernel. This is presumably
// `custom_gemv`, the function that the fallback wrapper at the end of this
// listing forwards to.
//
// NOTE(review): this is a truncated source dump. The embedded original line
// numbers jump (e.g. 19 -> 21 -> 24 -> 27 and 44 -> 56). As a result, none of
// the following are visible here:
//  - the function name;
//  - the parameters other than `dependencies` (presumably q, a, b, c);
//  - the definition of `wg`;
//  - the tail of the kernel (the statement using the atomic_ref and the
//    return).
// Restore them from the original header before building.
//
// Visible contract:
//  - The requires-clause demands that the value type of `Iter` is exactly `T`.
//  - `dependencies`: events the submitted command group waits on
//    (h.depends_on). It defaults to an empty list.
//  - Launch shape: global size a.shape()[0] * wg and local size wg. That is
//    one work-group of `wg` work-items per row of `a`.
19template <
typename T,
typename I, std::random_access_iterator Iter,
21 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
24 const std::vector<sycl::event> &dependencies = {}) {
27 auto event = q.submit([&](
auto &&h) {
28 h.depends_on(dependencies);
29 h.parallel_for(sycl::nd_range<1>(a.shape()[0] * wg, wg), [=](
auto item) {
// Each work-group processes the row whose index equals its group id.
30 auto row_index = item.get_group(0);
31 auto local_id = item.get_local_id();
32 auto group_size = item.get_local_range(0);
34 auto row = a.row(row_index);
// Work-items of the group stride over the row's stored entries by the
// work-group size.
36 for (std::size_t idx = local_id; idx < row.size(); idx += group_size) {
// Each entry yields its (i, k) index pair and its matrix value a_v.
37 auto &&[index, a_v] = row[idx];
38 auto &&[i, k] = index;
// b_v is b[k]; c_v is a reference to the output element c[i].
40 auto &&b_v = *(b + k);
41 auto &&c_v = *(c + i);
// Relaxed, device-scope atomic reference type. Presumably it wraps c_v so
// that several work-items of the group can update c[i] concurrently. The
// statement that declares and uses it is missing from this listing; confirm.
43 sycl::atomic_ref<T, sycl::memory_order::relaxed,
44 sycl::memory_scope::device>
// oneMKL-backed sparse matrix-vector product. This is presumably `mkl_gemv`,
// the function that the MKL wrapper later in this listing forwards to.
//
// Steps:
//  1. Create a oneMKL sparse matrix handle.
//  2. Attach the CSR arrays of `a` to it: zero-based indices,
//     a.shape()[0] rows by a.shape()[1] columns.
//  3. Call oneapi::mkl::sparse::gemv with op = nontrans, alpha = T(1) and
//     beta = T(1), once `dependencies` have completed.
// Because beta is 1, the product A * b is added to the existing contents of
// c (c = A * b + c) rather than overwriting c.
//
// NOTE(review): this is a truncated listing. The following are missing:
//  - the function name and most parameters (original lines 57, 59-60);
//  - the remaining set_csr_data arguments (lines 72-74);
//  - everything after the gemv call (lines 77+), including how the gemv event
//    is used or returned.
// Confirm that the original releases `a_handle`
// (oneapi::mkl::sparse::release_matrix_handle) after gemv finishes.
56template <
typename T,
typename I, std::random_access_iterator Iter,
58 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
61 const std::vector<sycl::event> &dependencies = {}) {
63 oneapi::mkl::sparse::matrix_handle_t a_handle;
64 oneapi::mkl::sparse::init_matrix_handle(&a_handle);
// Presumably `local` yields plain device pointers that oneMKL can consume.
// Verify against dr::sp::__detail::local.
66 auto rowptr = dr::sp::__detail::local(a.rowptr_data());
67 auto colind = dr::sp::__detail::local(a.colind_data());
68 auto values = dr::sp::__detail::local(a.values_data());
70 oneapi::mkl::sparse::set_csr_data(q, a_handle, a.shape()[0], a.shape()[1],
71 oneapi::mkl::index_base::zero, rowptr,
// y = alpha * A * x + beta * y with alpha = beta = 1, i.e. accumulate into c.
75 oneapi::mkl::sparse::gemv(q, oneapi::mkl::transpose::nontrans, T(1),
76 a_handle, b, T(1), c, dependencies);
// Sparse GEMV entry point that forwards its arguments unchanged to mkl_gemv
// and returns whatever mkl_gemv returns.
//
// NOTE(review): the function name, the middle parameters (original lines 81
// and 83-84) and the closing lines are missing from this listing. A second
// definition of the same shape forwards to custom_gemv instead. Presumably
// the two are selected by a preprocessor condition (e.g. MKL availability)
// that is not visible here; confirm.
80template <
typename T,
typename I, std::random_access_iterator Iter,
82 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
85 const std::vector<sycl::event> &dependencies = {}) {
86 return mkl_gemv(q, a, b, c, dependencies);
// Sparse GEMV entry point that forwards its arguments unchanged to the
// hand-written SYCL kernel custom_gemv and returns its result.
//
// NOTE(review): the function name, the middle parameters (original lines 92
// and 94-95) and the closing lines are missing from this listing. An
// otherwise identical definition forwards to mkl_gemv. Presumably this one is
// the non-MKL fallback chosen by a preprocessor condition not visible here;
// confirm.
91template <
typename T,
typename I, std::random_access_iterator Iter,
93 requires(std::is_same_v<std::iter_value_t<Iter>, T>)
96 const std::vector<sycl::event> &dependencies = {}) {
97 return custom_gemv(q, a, b, c, dependencies);
Definition: csr_matrix_view.hpp:126