Distributed Ranges
Loading...
Searching...
No Matches
coo_matrix.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
#include <algorithm>
#include <cstddef>
#include <dr/detail/matrix_entry.hpp>
#include <memory>
#include <utility>
#include <vector>
10
11namespace dr {
12
13namespace __detail {
14
15template <typename T, typename I, typename Allocator = std::allocator<T>>
17public:
19 using scalar_type = T;
20 using index_type = I;
21 using size_type = std::size_t;
22 using difference_type = std::ptrdiff_t;
23
24 using allocator_type = Allocator;
25
26 using key_type = dr::index<I>;
27 using map_type = T;
28
29 using backend_allocator_type = typename std::allocator_traits<
30 allocator_type>::template rebind_alloc<value_type>;
31 using backend_type = std::vector<value_type, backend_allocator_type>;
32
33 using iterator = typename backend_type::iterator;
34 using const_iterator = typename backend_type::const_iterator;
35
38
39 using scalar_reference = T &;
40
41 coo_matrix(dr::index<I> shape) : shape_(shape) {}
42
43 dr::index<I> shape() const noexcept { return shape_; }
44
45 size_type size() const noexcept { return tuples_.size(); }
46
47 void reserve(size_type new_cap) { tuples_.reserve(new_cap); }
48
49 iterator begin() noexcept { return tuples_.begin(); }
50
51 const_iterator begin() const noexcept { return tuples_.begin(); }
52
53 iterator end() noexcept { return tuples_.end(); }
54
55 const_iterator end() const noexcept { return tuples_.end(); }
56
/// Insert every entry of [first, last) through the single-entry `insert`.
/// An entry whose index is already stored is skipped; this includes an index
/// that appeared earlier in the same range.
template <typename It> void insert(It first, It last) {
  while (first != last) {
    insert(*first);
    ++first;
  }
}
62
63 template <typename InputIt> void push_back(InputIt first, InputIt last) {
64 for (auto iter = first; iter != last; ++iter) {
65 push_back(*iter);
66 }
67 }
68
69 void push_back(const value_type &value) { tuples_.push_back(value); }
70
71 template <typename InputIt> void assign_tuples(InputIt first, InputIt last) {
72 tuples_.assign(first, last);
73 }
74
75 std::pair<iterator, bool> insert(value_type &&value) {
76 auto &&[insert_index, insert_value] = value;
77 for (auto iter = begin(); iter != end(); ++iter) {
78 auto &&[index, v] = *iter;
79 if (index == insert_index) {
80 return {iter, false};
81 }
82 }
83 tuples_.push_back(value);
84 return {--tuples_.end(), true};
85 }
86
87 std::pair<iterator, bool> insert(const value_type &value) {
88 auto &&[insert_index, insert_value] = value;
89 for (auto iter = begin(); iter != end(); ++iter) {
90 auto &&[index, v] = *iter;
91 if (index == insert_index) {
92 return {iter, false};
93 }
94 }
95 tuples_.push_back(value);
96 return {--tuples_.end(), true};
97 }
98
99 template <class M>
100 std::pair<iterator, bool> insert_or_assign(key_type k, M &&obj) {
101 for (auto iter = begin(); iter != end(); ++iter) {
102 auto &&[index, v] = *iter;
103 if (index == k) {
104 v = std::forward<M>(obj);
105 return {iter, false};
106 }
107 }
108 tuples_.push_back({k, std::forward<M>(obj)});
109 return {--tuples_.end(), true};
110 }
111
112 iterator find(key_type key) noexcept {
113 return std::ranges::find_if(begin(), end(), [&](auto &&v) {
114 auto &&[i, v_] = v;
115 return i == key;
116 });
117 }
118
119 const_iterator find(key_type key) const noexcept {
120 return std::ranges::find_if(begin(), end(), [&](auto &&v) {
121 auto &&[i, v_] = v;
122 return i == key;
123 });
124 }
125
126 void reshape(dr::index<I> shape) {
127 bool all_inside = true;
128 for (auto &&[index, v] : *this) {
129 auto &&[i, j] = index;
130 if (!(i < shape[0] && j < shape[1])) {
131 all_inside = false;
132 break;
133 }
134 }
135
136 if (all_inside) {
137 shape_ = shape;
138 return;
139 } else {
140 coo_matrix<T, I> new_tuples(shape);
141 for (auto &&[index, v] : *this) {
142 auto &&[i, j] = index;
143 if (i < shape[0] && j < shape[1]) {
144 new_tuples.insert({index, v});
145 }
146 }
147 shape_ = shape;
148 assign_tuples(new_tuples.begin(), new_tuples.end());
149 }
150 }
151
152 coo_matrix() = default;
153 ~coo_matrix() = default;
154 coo_matrix(const coo_matrix &) = default;
155 coo_matrix(coo_matrix &&) = default;
156 coo_matrix &operator=(const coo_matrix &) = default;
157 coo_matrix &operator=(coo_matrix &&) = default;
158
159 std::size_t nbytes() const noexcept {
160 return tuples_.size() * sizeof(value_type);
161 }
162
163private:
164 dr::index<I> shape_;
165 backend_type tuples_;
166};
167
168} // namespace __detail
169
170} // namespace dr
Definition: coo_matrix.hpp:16
Definition: matrix_entry.hpp:20
Definition: matrix_entry.hpp:115