9#include <sycl/sycl.hpp>
11#include <dr/sp/device_ptr.hpp>
16using shared_allocator = sycl::usm_allocator<T, sycl::usm::alloc::shared>;
18template <
typename T, std::
size_t Alignment = 0>
19 requires(std::is_trivially_copyable_v<T>)
27 using size_type = std::size_t;
28 using difference_type = std::ptrdiff_t;
32 : device_(other.get_device()), context_(other.get_context()) {}
35 : device_(q.get_device()), context_(q.get_context()) {}
36 device_allocator(
const sycl::context &ctxt,
const sycl::device &dev) noexcept
37 : device_(dev), context_(ctxt) {}
43 using is_always_equal = std::false_type;
45 pointer allocate(std::size_t size) {
46 if constexpr (Alignment == 0) {
47 return pointer(sycl::malloc_device<T>(size, device_, context_));
50 sycl::aligned_alloc_device<T>(Alignment, size, device_, context_));
54 void deallocate(
pointer ptr, std::size_t n) {
55 sycl::free(ptr.get_raw_pointer(), context_);
61 template <
typename U>
struct rebind {
65 sycl::device get_device() const noexcept {
return device_; }
67 sycl::context get_context() const noexcept {
return context_; }
71 sycl::context context_;
76 using value_type =
typename std::allocator_traits<Allocator>::value_type;
77 using pointer =
typename std::allocator_traits<Allocator>::pointer;
79 typename std::allocator_traits<Allocator>::const_pointer;
80 using size_type =
typename std::allocator_traits<Allocator>::size_type;
81 using difference_type =
82 typename std::allocator_traits<Allocator>::difference_type;
85 std::size_t n_buffers)
86 : alloc_(alloc), buffer_size_(buffer_size),
87 free_buffers_(
new std::vector<pointer>()),
88 buffers_(
new std::vector<pointer>()) {
89 for (std::size_t i = 0; i < n_buffers; i++) {
90 buffers_->push_back(alloc_.allocate(buffer_size_));
92 free_buffers_->assign(buffers_->begin(), buffers_->end());
96 if (buffers_.use_count() == 1) {
97 for (
auto &&buffer : *buffers_) {
98 alloc_.deallocate(buffer, buffer_size_);
103 using is_always_equal = std::false_type;
105 pointer allocate(std::size_t size) {
106 if (size > buffer_size_ || free_buffers_->empty()) {
107 throw std::bad_alloc();
109 pointer buffer = free_buffers_->back();
110 free_buffers_->pop_back();
115 void deallocate(pointer ptr, std::size_t n) { free_buffers_->push_back(ptr); }
122 std::size_t buffer_size_;
123 std::shared_ptr<std::vector<pointer>> free_buffers_;
124 std::shared_ptr<std::vector<pointer>> buffers_;
Definition: allocators.hpp:74
Definition: allocators.hpp:20
Definition: device_ptr.hpp:17
Definition: device_ref.hpp:15
Definition: allocators.hpp:61