Distributed Ranges
Loading...
Searching...
No Matches
csr_matrix_view.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <dr/detail/index.hpp>
8#include <dr/detail/matrix_entry.hpp>
9#include <iterator>
10
11namespace dr::views {
12
13template <typename T, typename I, typename TIter, typename IIter>
15public:
16 using size_type = std::size_t;
17 using difference_type = std::ptrdiff_t;
18
19 using scalar_type = std::iter_value_t<TIter>;
20 using scalar_reference = std::iter_reference_t<TIter>;
21
22 using index_type = I;
23
25
27
28 using iterator_category = std::random_access_iterator_tag;
29
33
34 using key_type = dr::index<I>;
35
36 constexpr csr_matrix_view_accessor() noexcept = default;
37 constexpr ~csr_matrix_view_accessor() noexcept = default;
39 const csr_matrix_view_accessor &) noexcept = default;
40 constexpr csr_matrix_view_accessor &
41 operator=(const csr_matrix_view_accessor &) noexcept = default;
42
43 constexpr csr_matrix_view_accessor(TIter values, IIter rowptr, IIter colind,
44 size_type idx, index_type row,
45 size_type row_dim) noexcept
46 : values_(values), rowptr_(rowptr), colind_(colind), idx_(idx), row_(row),
47 row_dim_(row_dim), idx_offset_(key_type{0, 0}) {
48 fast_forward_row();
49 }
50
51 constexpr csr_matrix_view_accessor(TIter values, IIter rowptr, IIter colind,
52 size_type idx, index_type row,
53 size_type row_dim,
54 key_type idx_offset) noexcept
55 : values_(values), rowptr_(rowptr), colind_(colind), idx_(idx), row_(row),
56 row_dim_(row_dim), idx_offset_(idx_offset) {
57 fast_forward_row();
58 }
59
60 // Given that `idx_` has just been advanced to an element
61 // possibly in a new row, advance `row_` to find the new row.
62 // That is:
63 // Advance `row_` until idx_ >= rowptr_[row_] && idx_ < rowptr_[row_+1]
64 void fast_forward_row() noexcept {
65 while (row_ < row_dim_ - 1 && idx_ >= rowptr_[row_ + 1]) {
66 row_++;
67 }
68 }
69
70 // Given that `idx_` has just been retreated to an element
71 // possibly in a previous row, retreat `row_` to find the new row.
72 // That is:
73 // Retreat `row_` until idx_ >= rowptr_[row_] && idx_ < rowptr_[row_+1]
74 void fast_backward_row() noexcept {
75 while (idx_ < rowptr_[row_]) {
76 row_--;
77 }
78 }
79
80 constexpr csr_matrix_view_accessor &
81 operator+=(difference_type offset) noexcept {
82 idx_ += offset;
83 if (offset < 0) {
84 fast_backward_row();
85 } else {
86 fast_forward_row();
87 }
88 return *this;
89 }
90
91 constexpr bool operator==(const iterator_accessor &other) const noexcept {
92 return idx_ == other.idx_;
93 }
94
95 constexpr difference_type
96 operator-(const iterator_accessor &other) const noexcept {
97 return difference_type(idx_) - difference_type(other.idx_);
98 }
99
100 constexpr bool operator<(const iterator_accessor &other) const noexcept {
101 return idx_ < other.idx_;
102 }
103
104 constexpr reference operator*() const noexcept {
105 return reference(
106 key_type(row_ + idx_offset_[0], colind_[idx_] + idx_offset_[1]),
107 values_[idx_]);
108 }
109
110private:
111 TIter values_;
112 IIter rowptr_;
113 IIter colind_;
114 size_type idx_;
115 index_type row_;
116 size_type row_dim_;
117 key_type idx_offset_;
118};
119
120template <typename T, typename I, typename TIter, typename IIter>
123
124template <typename T, typename I, typename TIter = T *, typename IIter = I *>
126 : public rng::view_interface<csr_matrix_view<T, I, TIter, IIter>> {
127public:
128 using size_type = std::size_t;
129 using difference_type = std::ptrdiff_t;
130
131 using scalar_reference = std::iter_reference_t<TIter>;
133
134 using scalar_type = T;
135 using index_type = I;
136
137 using key_type = dr::index<I>;
138 using map_type = T;
139
141 csr_matrix_view() = default;
142 csr_matrix_view(TIter values, IIter rowptr, IIter colind, key_type shape,
143 size_type nnz, size_type rank)
144 : values_(values), rowptr_(rowptr), colind_(colind), shape_(shape),
145 nnz_(nnz), rank_(rank), idx_offset_(key_type{0, 0}) {}
146
147 csr_matrix_view(TIter values, IIter rowptr, IIter colind, key_type shape,
148 size_type nnz, size_type rank, key_type idx_offset)
149 : values_(values), rowptr_(rowptr), colind_(colind), shape_(shape),
150 nnz_(nnz), rank_(rank), idx_offset_(idx_offset) {}
151
152 key_type shape() const noexcept { return shape_; }
153
154 size_type size() const noexcept { return nnz_; }
155
156 std::size_t rank() const { return rank_; }
157
158 iterator begin() const {
159 return iterator(values_, rowptr_, colind_, 0, 0, shape()[0], idx_offset_);
160 }
161
162 iterator end() const {
163 return iterator(values_, rowptr_, colind_, nnz_, shape()[0], shape()[0],
164 idx_offset_);
165 }
166
167 auto row(I row_index) const {
168 I first = rowptr_[row_index];
169 I last = rowptr_[row_index + 1];
170
171 TIter values = values_;
172 IIter colind = colind_;
173
174 auto row_elements = rng::views::iota(first, last);
175
176 return row_elements | rng::views::transform([=](auto idx) {
177 return reference(key_type(row_index, colind[idx]), values[idx]);
178 });
179 }
180
181 auto submatrix(key_type rows, key_type columns) const {
182 return rng::views::iota(rows[0], rows[1]) |
183 rng::views::transform([=, *this](auto &&row_index) {
184 return row(row_index) | rng::views::drop_while([=](auto &&e) {
185 auto &&[index, v] = e;
186 return index[1] < columns[0];
187 }) |
188 rng::views::take_while([=](auto &&e) {
189 auto &&[index, v] = e;
190 return index[1] < columns[1];
191 }) |
192 rng::views::transform([=](auto &&elem) {
193 auto &&[index, v] = elem;
194 auto &&[i, j] = index;
195 return reference(key_type(i - rows[0], j - columns[0]),
196 v);
197 });
198 }) |
199 rng::views::join;
200 }
201
202 auto values_data() const { return values_; }
203
204 auto rowptr_data() const { return rowptr_; }
205
206 auto colind_data() const { return colind_; }
207
208private:
209 TIter values_;
210 IIter rowptr_;
211 IIter colind_;
212
213 key_type shape_;
214 size_type nnz_;
215
216 size_type rank_;
217 key_type idx_offset_;
218};
219
220template <typename TIter, typename IIter, typename... Args>
221csr_matrix_view(TIter, IIter, IIter, Args &&...)
222 -> csr_matrix_view<std::iter_value_t<TIter>, std::iter_value_t<IIter>,
223 TIter, IIter>;
224
225} // namespace dr::views
Definition: iterator_adaptor.hpp:23
Definition: matrix_entry.hpp:20
Definition: matrix_entry.hpp:115
Definition: csr_matrix_view.hpp:14
Definition: csr_matrix_view.hpp:126