Distributed Ranges
distributed_vector.hpp
// SPDX-FileCopyrightText: Intel Corporation
//
// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <span>
#include <type_traits>
#include <utility>
#include <vector>

#include <sycl/sycl.hpp>

#include <dr/detail/segments_tools.hpp>
#include <dr/sp/allocators.hpp>
#include <dr/sp/device_ptr.hpp>
#include <dr/sp/device_vector.hpp>
#include <dr/sp/vector.hpp>
16
17namespace dr::sp {
18
19template <typename T, typename L> class distributed_vector_accessor {
20public:
21 using element_type = T;
22 using value_type = std::remove_cv_t<T>;
23
24 using segment_type = L;
25 using const_segment_type = std::add_const_t<L>;
26 using nonconst_segment_type = std::remove_const_t<L>;
27
28 using size_type = std::size_t;
29 using difference_type = std::ptrdiff_t;
30
31 // using pointer = typename segment_type::pointer;
32 using reference = rng::range_reference_t<segment_type>;
33
34 using iterator_category = std::random_access_iterator_tag;
35
39
40 constexpr distributed_vector_accessor() noexcept = default;
41 constexpr ~distributed_vector_accessor() noexcept = default;
43 const distributed_vector_accessor &) noexcept = default;
45 operator=(const distributed_vector_accessor &) noexcept = default;
46
47 constexpr distributed_vector_accessor(std::span<segment_type> segments,
48 size_type segment_id, size_type idx,
49 size_type segment_size) noexcept
50 : segments_(segments), segment_id_(segment_id), idx_(idx),
51 segment_size_(segment_size) {}
52
54 operator+=(difference_type offset) noexcept {
55 if (offset > 0) {
56 idx_ += offset;
57 if (idx_ >= segment_size_) {
58 segment_id_ += idx_ / segment_size_;
59 idx_ = idx_ % segment_size_;
60 }
61 }
62
63 if (offset < 0) {
64 size_type new_global_idx = get_global_idx() + offset;
65 segment_id_ = new_global_idx / segment_size_;
66 idx_ = new_global_idx % segment_size_;
67 }
68 return *this;
69 }
70
71 constexpr bool operator==(const iterator_accessor &other) const noexcept {
72 return segment_id_ == other.segment_id_ && idx_ == other.idx_;
73 }
74
75 constexpr difference_type
76 operator-(const iterator_accessor &other) const noexcept {
77 return difference_type(get_global_idx()) - other.get_global_idx();
78 }
79
80 constexpr bool operator<(const iterator_accessor &other) const noexcept {
81 if (segment_id_ < other.segment_id_) {
82 return true;
83 } else if (segment_id_ == other.segment_id_) {
84 return idx_ < other.idx_;
85 } else {
86 return false;
87 }
88 }
89
90 constexpr reference operator*() const noexcept {
91 return segments_[segment_id_][idx_];
92 }
93
94 auto segments() const noexcept {
95 return dr::__detail::drop_segments(segments_, segment_id_, idx_);
96 }
97
98private:
99 size_type get_global_idx() const noexcept {
100 return segment_id_ * segment_size_ + idx_;
101 }
102
103 std::span<segment_type> segments_;
104 size_type segment_id_ = 0;
105 size_type idx_ = 0;
106 size_type segment_size_ = 0;
107};
108
109template <typename T, typename L>
112
113// TODO: support teams, distributions
114
116template <typename T, typename Allocator = dr::sp::device_allocator<T>>
118public:
120 using const_segment_type =
121 std::add_const_t<dr::sp::device_vector<T, Allocator>>;
122
123 using value_type = T;
124 using size_type = std::size_t;
125 using difference_type = std::ptrdiff_t;
126
127 using pointer = decltype(std::declval<segment_type>().data());
128 using const_pointer =
129 decltype(std::declval<std::add_const_t<segment_type>>().data());
130
131 using reference = std::iter_reference_t<pointer>;
132 using const_reference = std::iter_reference_t<const_pointer>;
133
135 using const_iterator =
137 using allocator_type = Allocator;
138
139 distributed_vector(std::size_t count = 0) {
140 assert(dr::sp::devices().size() > 0);
141 size_ = count;
142 segment_size_ =
143 (count + dr::sp::devices().size() - 1) / dr::sp::devices().size();
144 capacity_ = segment_size_ * dr::sp::devices().size();
145
146 std::size_t rank = 0;
147 for (auto &&device : dr::sp::devices()) {
148 segments_.emplace_back(segment_type(
149 segment_size_, Allocator(dr::sp::context(), device), rank++));
150 }
151 }
152
153 distributed_vector(std::size_t count, const T &value)
154 : distributed_vector(count) {
155 dr::sp::fill(*this, value);
156 }
157
158 distributed_vector(std::initializer_list<T> init)
159 : distributed_vector(init.size()) {
160 dr::sp::copy(rng::begin(init), rng::end(init), begin());
161 }
162
163 reference operator[](size_type pos) {
164 size_type segment_id = pos / segment_size_;
165 size_type local_id = pos % segment_size_;
166 return *(segments_[segment_id].begin() + local_id);
167 }
168
169 const_reference operator[](size_type pos) const {
170 size_type segment_id = pos / segment_size_;
171 size_type local_id = pos % segment_size_;
172 return *(segments_[segment_id].begin() + local_id);
173 }
174
175 size_type size() const noexcept { return size_; }
176
177 auto segments() { return dr::__detail::take_segments(segments_, size()); }
178
179 auto segments() const {
180 return dr::__detail::take_segments(segments_, size());
181 }
182
183 iterator begin() { return iterator(segments_, 0, 0, segment_size_); }
184
185 const_iterator begin() const {
186 return const_iterator(segments_, 0, 0, segment_size_);
187 }
188
189 iterator end() {
190 return size_ ? iterator(segments_, size() / segment_size_,
191 size() % segment_size_, segment_size_)
192 : begin();
193 }
194
195 const_iterator end() const {
196 return size_ ? const_iterator(segments_, size() / segment_size_,
197 size() % segment_size_, segment_size_)
198 : begin();
199 }
200
201 void resize(size_type count, const value_type &value) {
202 distributed_vector<T, Allocator> other(count, value);
203 std::size_t copy_size = std::min(other.size(), size());
204 dr::sp::copy(begin(), begin() + copy_size, other.begin());
205 *this = std::move(other);
206 }
207
208 void resize(size_type count) { resize(count, value_type{}); }
209
210private:
211 std::vector<segment_type> segments_;
212 std::size_t capacity_ = 0;
213 std::size_t size_ = 0;
214 std::size_t segment_size_ = 0;
215};
216
217} // namespace dr::sp
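Cross-references:
- iterator_adaptor — defined at iterator_adaptor.hpp:23
- dr::sp::device_vector — defined at device_vector.hpp:13
- dr::sp::distributed_vector_accessor — defined at distributed_vector.hpp:19
- dr::sp::distributed_vector ("distributed vector") — defined at distributed_vector.hpp:117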