Distributed Ranges
Loading...
Searching...
No Matches
csr_matrix_base.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
#include <algorithm>
#include <cstddef>
#include <dr/detail/matrix_entry.hpp>
#include <future>
#include <memory>
#include <utility>
#include <vector>
11
12namespace dr {
13
14namespace __detail {
15
16template <typename T, typename I, typename Allocator = std::allocator<T>>
18public:
19 using value_type = std::pair<T, I>;
20 using scalar_type = T;
21 using index_type = I;
22 using size_type = std::size_t;
23 using difference_type = std::ptrdiff_t;
24
25 using allocator_type = Allocator;
26
27 using key_type = dr::index<I>;
28 using map_type = T;
29
30 using backend_allocator_type = typename std::allocator_traits<
31 allocator_type>::template rebind_alloc<value_type>;
32 using aggregator_allocator_type = typename std::allocator_traits<
33 allocator_type>::template rebind_alloc<std::vector<value_type>>;
34 using row_type = std::vector<value_type, backend_allocator_type>;
35 using backend_type = std::vector<row_type, aggregator_allocator_type>;
36
37 using iterator = typename backend_type::iterator;
38 using const_iterator = typename backend_type::const_iterator;
39
40 csr_matrix_base(dr::index<I> shape, std::size_t nnz) : shape_(shape) {
41 auto average_size = nnz / shape.first / 2;
42 for (std::size_t i = 0; i < shape.first; i++) {
43 tuples_.push_back(row_type());
44 tuples_.back().reserve(average_size);
45 }
46 }
47
48 dr::index<I> shape() const noexcept { return shape_; }
49
50 size_type size() const noexcept { return size_; }
51
52 iterator begin() noexcept { return tuples_.begin(); }
53
54 const_iterator begin() const noexcept { return tuples_.begin(); }
55
56 iterator end() noexcept { return tuples_.end(); }
57
58 const_iterator end() const noexcept { return tuples_.end(); }
59
/// Appends every element of the range [first, last).
///
/// Each element is unpacked with a structured binding into `(row, value)` and
/// forwarded to `push_back(row, value)`, so the size counter is updated the
/// same way as for single insertions.
/// NOTE(review): assumes the iterator yields pair-like elements of the form
/// (index_type, value_type) — confirm against callers.
template <typename InputIt> void push_back(InputIt first, InputIt last) {
  for (auto iter = first; iter != last; ++iter) {
    // The element-wise overload takes (row, value); a single-argument
    // push_back(*iter) has no matching overload and cannot compile.
    const auto &[row, value] = *iter;
    push_back(row, value);
  }
}
65
66 void push_back(index_type row, const value_type &value) {
67 tuples_[row].push_back(value);
68 size_++;
69 }
70
71 void sort() {
72 auto comparator = [](auto &one, auto &two) {
73 return one.second < two.second;
74 };
75 for (auto &elem : tuples_) {
76 std::sort(elem.begin(), elem.end(), comparator);
77 }
78 }
79
80 csr_matrix_base() = default;
81 ~csr_matrix_base() = default;
82 csr_matrix_base(const csr_matrix_base &) = default;
83 csr_matrix_base(csr_matrix_base &&) = default;
84 csr_matrix_base &operator=(const csr_matrix_base &) = default;
85 csr_matrix_base &operator=(csr_matrix_base &&) = default;
86
87private:
88 std::size_t size_ = 0;
89 dr::index<I> shape_;
90 backend_type tuples_;
91};
92
93} // namespace __detail
94
95} // namespace dr
Definition: csr_matrix_base.hpp:17