Distributed Ranges
Loading...
Searching...
No Matches
init.hpp
1// SPDX-FileCopyrightText: Intel Corporation
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#pragma once
6
7#include <cassert>
8#include <memory>
9#include <span>
10#include <sycl/sycl.hpp>
11#include <type_traits>
12#include <vector>
13
14#include <dr/sp/algorithms/execution_policy.hpp>
15#include <dr/sp/util.hpp>
16#include <oneapi/dpl/execution>
17
18namespace dr::sp {
19
20namespace __detail {
21
22inline sycl::context *global_context_;
23
24inline std::vector<sycl::device> devices_;
25
26inline std::vector<sycl::queue> queues_;
27
28inline std::vector<oneapi::dpl::execution::device_policy<>> dpl_policies_;
29
30inline std::size_t ngpus_;
31
32inline sycl::context &global_context() { return *global_context_; }
33
34inline std::size_t ngpus() { return ngpus_; }
35
36inline std::span<sycl::device> global_devices() { return devices_; }
37
38} // namespace __detail
39
40inline sycl::context &context() { return __detail::global_context(); }
41
42inline std::span<sycl::device> devices() { return __detail::global_devices(); }
43
44inline std::size_t nprocs() { return __detail::ngpus(); }
45
46inline device_policy par_unseq;
47
48template <rng::range R>
49inline void init(R &&devices)
50 requires(
51 std::is_same_v<sycl::device, std::remove_cvref_t<rng::range_value_t<R>>>)
52{
53 __detail::devices_.assign(rng::begin(devices), rng::end(devices));
54 __detail::global_context_ = new sycl::context(__detail::devices_);
55 __detail::ngpus_ = rng::size(__detail::devices_);
56
57 for (auto &&device : __detail::devices_) {
58 sycl::queue q(*__detail::global_context_, device);
59 __detail::queues_.push_back(q);
60
61 __detail::dpl_policies_.emplace_back(__detail::queues_.back());
62 }
63
64 par_unseq = device_policy(__detail::devices_);
65}
66
67template <__detail::sycl_device_selector Selector>
68inline void init(Selector &&selector) {
69 auto devices = get_numa_devices(selector);
70 init(devices);
71}
72
73inline void init() { init(sycl::default_selector_v); }
74
75inline void finalize() {
76 __detail::dpl_policies_.clear();
77 __detail::queues_.clear();
78 __detail::devices_.clear();
79 delete __detail::global_context_;
80}
81
82namespace __detail {
83
84inline sycl::queue &queue(std::size_t rank) { return queues_[rank]; }
85
86// Retrieve global queues because of CMPLRLLVM-47008
87inline sycl::queue &queue(const sycl::device &device) {
88 for (std::size_t rank = 0; rank < sp::nprocs(); rank++) {
89 if (sp::devices()[rank] == device) {
90 return queue(rank);
91 }
92 }
93 assert(false);
94 // Reaches here with -DNDEBUG
95 return queue(0);
96}
97
98inline sycl::queue &default_queue() { return queue(0); }
99
100inline auto &dpl_policy(std::size_t rank) {
101 return __detail::dpl_policies_[rank];
102}
103
104} // namespace __detail
105
106} // namespace dr::sp