9#include <sycl/sycl.hpp>
11#include <dr/detail/segments_tools.hpp>
12#include <dr/sp/allocators.hpp>
13#include <dr/sp/device_ptr.hpp>
14#include <dr/sp/device_vector.hpp>
15#include <dr/sp/vector.hpp>
21 using element_type = T;
22 using value_type = std::remove_cv_t<T>;
24 using segment_type = L;
25 using const_segment_type = std::add_const_t<L>;
26 using nonconst_segment_type = std::remove_const_t<L>;
28 using size_type = std::size_t;
29 using difference_type = std::ptrdiff_t;
32 using reference = rng::range_reference_t<segment_type>;
34 using iterator_category = std::random_access_iterator_tag;
48 size_type segment_id, size_type idx,
49 size_type segment_size) noexcept
50 : segments_(segments), segment_id_(segment_id), idx_(idx),
51 segment_size_(segment_size) {}
54 operator+=(difference_type offset)
noexcept {
57 if (idx_ >= segment_size_) {
58 segment_id_ += idx_ / segment_size_;
59 idx_ = idx_ % segment_size_;
64 size_type new_global_idx = get_global_idx() + offset;
65 segment_id_ = new_global_idx / segment_size_;
66 idx_ = new_global_idx % segment_size_;
72 return segment_id_ == other.segment_id_ && idx_ == other.idx_;
75 constexpr difference_type
77 return difference_type(get_global_idx()) - other.get_global_idx();
81 if (segment_id_ < other.segment_id_) {
83 }
else if (segment_id_ == other.segment_id_) {
84 return idx_ < other.idx_;
90 constexpr reference operator*()
const noexcept {
91 return segments_[segment_id_][idx_];
94 auto segments()
const noexcept {
95 return dr::__detail::drop_segments(segments_, segment_id_, idx_);
99 size_type get_global_idx()
const noexcept {
100 return segment_id_ * segment_size_ + idx_;
103 std::span<segment_type> segments_;
104 size_type segment_id_ = 0;
106 size_type segment_size_ = 0;
109template <
typename T,
typename L>
116template <
typename T,
typename Allocator = dr::sp::device_allocator<T>>
120 using const_segment_type =
121 std::add_const_t<dr::sp::device_vector<T, Allocator>>;
123 using value_type = T;
124 using size_type = std::size_t;
125 using difference_type = std::ptrdiff_t;
127 using pointer =
decltype(std::declval<segment_type>().data());
128 using const_pointer =
129 decltype(std::declval<std::add_const_t<segment_type>>().data());
131 using reference = std::iter_reference_t<pointer>;
132 using const_reference = std::iter_reference_t<const_pointer>;
137 using allocator_type = Allocator;
140 assert(dr::sp::devices().size() > 0);
143 (count + dr::sp::devices().size() - 1) / dr::sp::devices().size();
144 capacity_ = segment_size_ * dr::sp::devices().size();
146 std::size_t rank = 0;
147 for (
auto &&device : dr::sp::devices()) {
149 segment_size_, Allocator(dr::sp::context(), device), rank++));
155 dr::sp::fill(*
this, value);
160 dr::sp::copy(rng::begin(init), rng::end(init), begin());
163 reference operator[](size_type pos) {
164 size_type segment_id = pos / segment_size_;
165 size_type local_id = pos % segment_size_;
166 return *(segments_[segment_id].begin() + local_id);
169 const_reference operator[](size_type pos)
const {
170 size_type segment_id = pos / segment_size_;
171 size_type local_id = pos % segment_size_;
172 return *(segments_[segment_id].begin() + local_id);
175 size_type size()
const noexcept {
return size_; }
177 auto segments() {
return dr::__detail::take_segments(segments_, size()); }
179 auto segments()
const {
180 return dr::__detail::take_segments(segments_, size());
190 return size_ ?
iterator(segments_, size() / segment_size_,
191 size() % segment_size_, segment_size_)
197 size() % segment_size_, segment_size_)
201 void resize(size_type count,
const value_type &value) {
203 std::size_t copy_size = std::min(other.size(), size());
204 dr::sp::copy(begin(), begin() + copy_size, other.begin());
205 *
this = std::move(other);
208 void resize(size_type count) { resize(count, value_type{}); }
211 std::vector<segment_type> segments_;
212 std::size_t capacity_ = 0;
213 std::size_t size_ = 0;
214 std::size_t segment_size_ = 0;
Definition: iterator_adaptor.hpp:23
Definition: device_vector.hpp:13
Definition: distributed_vector.hpp:19
distributed vector
Definition: distributed_vector.hpp:117