Distributed Ranges
Loading...
Searching...
No Matches
fill.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <memory>
8#include <type_traits>
9
10#include <sycl/sycl.hpp>
11
12#include <dr/concepts/concepts.hpp>
13#include <dr/detail/segments_tools.hpp>
14#include <dr/sp/detail.hpp>
15#include <dr/sp/device_ptr.hpp>
16#include <dr/sp/util.hpp>
17
18namespace dr::sp {
19
20template <std::contiguous_iterator Iter>
21 requires(!std::is_const_v<std::iter_value_t<Iter>> &&
22 std::is_trivially_copyable_v<std::iter_value_t<Iter>>)
23sycl::event fill_async(Iter first, Iter last,
24 const std::iter_value_t<Iter> &value) {
25 auto &&q = __detail::get_queue_for_pointer(first);
26 std::iter_value_t<Iter> *arr = std::to_address(first);
27 // not using q.fill because of CMPLRLLVM-46438
28 return dr::__detail::parallel_for(q, sycl::range<>(last - first),
29 [=](auto idx) { arr[idx] = value; });
30}
31
32template <std::contiguous_iterator Iter>
33 requires(!std::is_const_v<std::iter_value_t<Iter>>)
34void fill(Iter first, Iter last, const std::iter_value_t<Iter> &value) {
35 fill_async(first, last, value).wait();
36}
37
38template <typename T, typename U>
39 requires(std::indirectly_writable<device_ptr<T>, U>)
40sycl::event fill_async(device_ptr<T> first, device_ptr<T> last,
41 const U &value) {
42 auto &&q = __detail::get_queue_for_pointer(first);
43 auto *arr = first.get_raw_pointer();
44 // not using q.fill because of CMPLRLLVM-46438
45 return dr::__detail::parallel_for(q, sycl::range<>(last - first),
46 [=](auto idx) { arr[idx] = value; });
47}
48
49template <typename T, typename U>
50 requires(std::indirectly_writable<device_ptr<T>, U>)
51void fill(device_ptr<T> first, device_ptr<T> last, const U &value) {
52 fill_async(first, last, value).wait();
53}
54
55template <typename T, dr::remote_contiguous_range R>
56sycl::event fill_async(R &&r, const T &value) {
57 auto &&q = __detail::queue(dr::ranges::rank(r));
58 auto *arr = std::to_address(rng::begin(dr::ranges::local(r)));
59 // not using q.fill because of CMPLRLLVM-46438
60 return dr::__detail::parallel_for(q, sycl::range<>(rng::distance(r)),
61 [=](auto idx) { arr[idx] = value; });
62}
63
64template <typename T, dr::remote_contiguous_range R>
65auto fill(R &&r, const T &value) {
66 fill_async(r, value).wait();
67 return rng::end(r);
68}
69
70template <typename T, dr::distributed_contiguous_range DR>
71sycl::event fill_async(DR &&r, const T &value) {
72 std::vector<sycl::event> events;
73
74 for (auto &&segment : dr::ranges::segments(r)) {
75 auto e = dr::sp::fill_async(segment, value);
76 events.push_back(e);
77 }
78
79 return dr::sp::__detail::combine_events(events);
80}
81
82template <typename T, dr::distributed_contiguous_range DR>
83auto fill(DR &&r, const T &value) {
84 fill_async(r, value).wait();
85 return rng::end(r);
86}
87
88template <typename T, dr::distributed_iterator Iter>
89auto fill(Iter first, Iter last, const T &value) {
90 fill_async(rng::subrange(first, last), value).wait();
91 return last;
92}
93
94} // namespace dr::sp