struct dnnl::ukernel::brgemm
Overview
BRGeMM ukernel.
#include <dnnl_ukernel.hpp>

struct brgemm: public dnnl::handle {
    // construction

    brgemm();

    brgemm(
        memory::dim M,
        memory::dim N,
        memory::dim K,
        memory::dim batch_size,
        memory::dim lda,
        memory::dim ldb,
        memory::dim ldc,
        memory::data_type a_dt,
        memory::data_type b_dt,
        memory::data_type c_dt,
        bool allow_empty = false
    );

    // methods

    void set_add_C(bool add_C);

    void set_post_ops(
        memory::dim ldd,
        memory::data_type d_dt,
        const post_ops& po = default_post_ops()
    );

    void set_A_scales(int a_scale_mask);
    void set_B_scales(int b_scale_mask);
    void set_D_scales(int d_scale_mask);
    void finalize();
    pack_type get_B_pack_type() const;
    size_t get_scratchpad_size() const;
    void set_hw_context() const;
    void generate();

    void execute(
        const void* A,
        const void* B,
        const std::vector<std::pair<memory::dim, memory::dim>>& A_B_offsets,
        void* C,
        void* scratchpad
    ) const;

    void execute(
        const void* A,
        const void* B,
        const std::vector<std::pair<memory::dim, memory::dim>>& A_B_offsets,
        void* C,
        void* D,
        void* scratchpad,
        const attr_params& params = default_attr_params()
    ) const;

    static void release_hw_context();
    static const post_ops& default_post_ops();
    static const attr_params& default_attr_params();
};
Inherited Members
public:
    // methods

    handle<T, traits>& operator = (const handle<T, traits>&);
    handle<T, traits>& operator = (handle<T, traits>&&);
    void reset(T t, bool weak = false);
    T get(bool allow_empty = false) const;
    operator T () const;
    operator bool () const;
    bool operator == (const handle<T, traits>& other) const;
    bool operator != (const handle& other) const;
Detailed Documentation
BRGeMM (batch-reduce general matrix multiplication) ukernel.
Construction
brgemm()
Default constructor. Produces an empty object.
brgemm( memory::dim M, memory::dim N, memory::dim K, memory::dim batch_size, memory::dim lda, memory::dim ldb, memory::dim ldc, memory::data_type a_dt, memory::data_type b_dt, memory::data_type c_dt, bool allow_empty = false )
Constructs a BRGeMM ukernel object.
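Operates by the following formula: C = [A x B], where [A x B] denotes the batch-reduced product, that is, the sum of A_i x B_i over the batch_size pairs of A and B blocks selected at execution time.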
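Written out element by element, C = [A x B] can be read as the double sum below. This is an illustration under stated assumptions, not a specification of the kernel's memory layout: A, B, and C are taken as row-major and addressed through lda, ldb, and ldc, and (a_i, b_i) is the i-th pair in A_B_offsets, counted in elements.

```latex
C[m \cdot \mathit{ldc} + n] \;=\; \sum_{i=0}^{\mathit{batch\_size}-1} \sum_{k=0}^{K-1}
  A[a_i + m \cdot \mathit{lda} + k] \cdot B[b_i + k \cdot \mathit{ldb} + n],
\qquad 0 \le m < M,\; 0 \le n < N
```

Each batch contributes one M x K by K x N product, and all batch_size products are summed into the same M x N tile of C.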
Parameters:
M | Dimension M of tensor A (the number of rows of A and C).
N | Dimension N of tensor B (the number of columns of B and C).
K | Dimension K of tensors A and B (the reduction depth of each batch).
batch_size | Number of batches to process, that is, the number of A and B block pairs reduced into C.
lda | Leading dimension of tensor A.
ldb | Leading dimension of tensor B.
ldc | Leading dimension of tensor C.
a_dt | Data type of tensor A.
b_dt | Data type of tensor B.
c_dt | Data type of tensor C.
allow_empty | Flag that allows construction to fail without throwing an exception; in that case an empty object is produced. Optional; defaults to false.
Methods
void set_add_C(bool add_C)
Sets whether the intermediate result is added to the output tensor C instead of overwriting it: C += [A x B].
Parameters:
add_C | If true, the result is accumulated into C (C += [A x B]); if false, C is overwritten (C = [A x B]).
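The difference between the two modes can be sketched with a hypothetical helper that stores one freshly computed M x N block into C. This is not the ukernel's code; it assumes a row-major f32 block and that `store_block` is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper illustrating set_add_C semantics: writes a computed
// block `acc` (M x N, row-major, leading dimension N) into C, which has
// leading dimension ldc. With add_C == false the block overwrites C; with
// add_C == true it is added to what C already holds.
void store_block(const std::vector<float>& acc, std::vector<float>& C,
                 std::size_t M, std::size_t N, std::size_t ldc, bool add_C) {
    assert(ldc >= N);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            float& dst = C[m * ldc + n];
            const float v = acc[m * N + n];
            dst = add_C ? dst + v : v;  // C += [A x B] versus C = [A x B]
        }
    }
}
```

For example, storing {1, 2, 3, 4} into a 2 x 2 C first with add_C = false and then again with add_C = true leaves C holding {2, 4, 6, 8}.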
void set_post_ops( memory::dim ldd, memory::data_type d_dt, const post_ops& po = default_post_ops() )
Sets post-operations for a BRGeMM ukernel object: D = post-operations(C).
Post-operations are applied if at least one of the following holds:
- Non-empty post-operations are specified.
- The output data type d_dt differs from the accumulation data type c_dt.
Parameters:
ldd | Leading dimension of tensor D.
d_dt | Data type of tensor D.
po | Primitive post-operation attributes to extend the kernel operations.
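The activation rule above, together with the separate leading dimension ldd, can be sketched as follows. This is a hypothetical reference rather than the library's implementation; it assumes an f32 C, an s8 D, a single optional ReLU as the only post-operation, round-to-nearest conversion with saturation, and row-major storage. The names `dt`, `post_ops_needed`, and `apply_post_ops_ref` are illustrative.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for memory::data_type, reduced to two types.
enum class dt { f32, s8 };

// Mirrors the documented condition: post-operations run when the post-op
// chain is non-empty or when D's data type differs from C's.
bool post_ops_needed(bool has_post_ops, dt d_dt, dt c_dt) {
    return has_post_ops || d_dt != c_dt;
}

// Reads the f32 accumulator C (leading dimension ldc), optionally applies a
// ReLU, and writes an s8 tensor D with its own leading dimension ldd,
// rounding to nearest (ties to even under the default rounding mode) and
// saturating to [-128, 127].
void apply_post_ops_ref(const std::vector<float>& C, std::size_t ldc,
                        std::vector<std::int8_t>& D, std::size_t ldd,
                        std::size_t M, std::size_t N, bool relu) {
    assert(ldc >= N && ldd >= N);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            float v = C[m * ldc + n];
            if (relu) v = std::max(v, 0.0f);
            const float r = std::clamp(std::nearbyint(v), -128.0f, 127.0f);
            D[m * ldd + n] = static_cast<std::int8_t>(r);
        }
    }
}
```

Keeping ldd separate from ldc lets D be written into a differently strided buffer (for example, a slice of a larger output) while C stays a compact accumulation tile.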
void set_A_scales(int a_scale_mask)
Sets the tensor A scales mask for a BRGeMM ukernel object.
For the quantization flavor, tensor A scales are applied to the accumulation buffer once C is ready.
Parameters:
a_scale_mask | Tensor A scale mask. Can be
void set_B_scales(int b_scale_mask)
Sets the tensor B scales mask for a BRGeMM ukernel object.
For the quantization flavor, tensor B scales are applied to the accumulation buffer once C is ready.
Parameters:
b_scale_mask | Tensor B scale mask. Can be
void set_D_scales(int d_scale_mask)
Sets the tensor D scales mask for a BRGeMM ukernel object.
For the quantization flavor, tensor D scales are applied after all post-ops.
Parameters:
d_scale_mask | Tensor D scale mask. Can be
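The order described for the three scale setters can be sketched for a single output element. This is a hypothetical reference, not the kernel's implementation, and several details are assumptions not stated on this page: every mask is 0 (one scale value for the whole tensor), the accumulator is int32, ReLU is the only post-op, D is s8, and the D scale is applied by division, assumed to follow oneDNN's convention for destination scales in primitives. `quantize_element` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical per-element view of the quantized BRGeMM flow, assuming
// mask 0 (a single scale value per tensor) for A, B, and D.
std::int8_t quantize_element(std::int32_t acc, float a_scale, float b_scale,
                             float d_scale, bool relu) {
    assert(d_scale != 0.0f);
    // 1. A and B scales apply to the accumulation buffer once C is ready.
    float v = static_cast<float>(acc) * a_scale * b_scale;
    // 2. Post-operations (here only an optional ReLU).
    if (relu) v = std::max(v, 0.0f);
    // 3. D scales apply after all post-ops (assumed: division).
    v /= d_scale;
    // 4. Round to nearest and saturate to the s8 range of D.
    return static_cast<std::int8_t>(
        std::clamp(std::nearbyint(v), -128.0f, 127.0f));
}
```

With acc = 100, A and B scales of 0.5 and 0.25, and a D scale of 0.5, the element becomes 100 x 0.5 x 0.25 / 0.5 = 25.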
void finalize()
Finalizes initialization of a BRGeMM ukernel object.
This step must be performed prior to querying information from the object.
pack_type get_B_pack_type() const
Returns the packing type that a BRGeMM ukernel object expects for tensor B.
size_t get_scratchpad_size() const
Returns the size of the scratchpad memory required by the BRGeMM ukernel object.
void set_hw_context() const
Initializes the hardware-specific context.
Affects the global state for all BRGeMM ukernel objects. If no initialization is required, the call simply returns.
void generate()
Generates the executable part of a BRGeMM ukernel object.
void execute( const void* A, const void* B, const std::vector<std::pair<memory::dim, memory::dim>>& A_B_offsets, void* C, void* scratchpad ) const
Executes a BRGeMM ukernel object.
Parameters:
A | Base pointer to tensor A.
B | Base pointer to tensor B.
A_B_offsets | Vector of pairs of tensor A and tensor B offsets, one pair per batch. The number of batches must coincide with the batch_size value passed at construction.
C | Pointer to tensor C (accumulation buffer).
scratchpad | Pointer to a scratchpad buffer.
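To show how A_B_offsets drives the batch reduction, the sketch below is a plain C++ reference, not the oneDNN kernel. It assumes row-major f32 tensors and offsets counted in elements (the units the real kernel expects are not stated on this page, and the real API uses memory::dim), and the names `brgemm_ref` and `split_k_offsets` are hypothetical. Splitting a GEMM of depth batch_size x K into batch_size slices of depth K and reducing them reproduces the unsplit product.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// (A offset, B offset) per batch, in elements (memory::dim in the real API).
using Offsets = std::vector<std::pair<std::size_t, std::size_t>>;

// Reference batch-reduce GEMM: C (+)= sum_i A_i x B_i, where A_i starts at
// A + offsets[i].first (M x K, leading dimension lda) and B_i starts at
// B + offsets[i].second (K x N, leading dimension ldb).
void brgemm_ref(const float* A, const float* B, const Offsets& offsets,
                float* C, std::size_t M, std::size_t N, std::size_t K,
                std::size_t lda, std::size_t ldb, std::size_t ldc, bool add_C) {
    assert(lda >= K && ldb >= N && ldc >= N);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            float acc = add_C ? C[m * ldc + n] : 0.0f;
            for (const auto& [a_off, b_off] : offsets)
                for (std::size_t k = 0; k < K; ++k)
                    acc += A[a_off + m * lda + k] * B[b_off + k * ldb + n];
            C[m * ldc + n] = acc;
        }
    }
}

// Offsets that split a reduction of depth batch_size * K into batch_size
// slices: slice i uses columns [i*K, (i+1)*K) of A and the same rows of B.
Offsets split_k_offsets(std::size_t batch_size, std::size_t K, std::size_t ldb) {
    Offsets offsets;
    for (std::size_t i = 0; i < batch_size; ++i)
        offsets.emplace_back(i * K, i * K * ldb);
    return offsets;
}
```

For a 2 x 4 by 4 x 2 product, two batches of depth 2 give the same C as one batch of depth 4. Note that lda stays 4 (the row stride of the full A) while K is the per-batch depth.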
void execute( const void* A, const void* B, const std::vector<std::pair<memory::dim, memory::dim>>& A_B_offsets, void* C, void* D, void* scratchpad, const attr_params& params = default_attr_params() ) const
Executes a BRGeMM ukernel object with post-operations.
Parameters:
A | Base pointer to tensor A.
B | Base pointer to tensor B.
A_B_offsets | Vector of pairs of tensor A and tensor B offsets, one pair per batch. The number of batches must coincide with the batch_size value passed at construction.
C | Pointer to tensor C (accumulation buffer).
D | Pointer to tensor D (output buffer).
scratchpad | Pointer to a scratchpad buffer.
params | Post-op memory arguments. Must be passed if binary post-ops or scales were set.
static void release_hw_context()
Releases the hardware-specific context.
Affects the global state for all BRGeMM ukernel objects. Must be called after all execution calls to BRGeMM ukernel objects have completed.
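The call order implied by the methods above can be summarized with a hypothetical state tracker. `UkernelLifecycle` is not part of oneDNN and performs no computation; it only records which steps have happened so that out-of-order use is detected. The order it enforces (configure, finalize, query, generate, set_hw_context, execute, release_hw_context) is a reading of this page, and the rule that setters belong before finalize is an assumption.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical tracker mirroring the BRGeMM ukernel call order.
class UkernelLifecycle {
public:
    // set_add_C / set_post_ops / set_*_scales (assumed: before finalize).
    void configure() const { require(!finalized_, "configure after finalize"); }
    void finalize() { finalized_ = true; }
    // get_B_pack_type / get_scratchpad_size require a finalized object.
    void query() const { require(finalized_, "query before finalize"); }
    void generate() {
        require(finalized_, "generate before finalize");
        generated_ = true;
    }
    void set_hw_context() { hw_context_ = true; }
    void execute() const {
        require(generated_, "execute before generate");
        require(hw_context_, "execute without hardware context");
    }
    // Called once all executions are done.
    void release_hw_context() { hw_context_ = false; }

private:
    static void require(bool ok, const std::string& what) {
        if (!ok) throw std::logic_error(what);
    }
    bool finalized_ = false;
    bool generated_ = false;
    bool hw_context_ = false;
};
```

Because the hardware context is global rather than per object, a tracker like this would in practice be shared by all kernels on a thread; the single-object version keeps the sketch short.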
static const post_ops& default_post_ops()
Returns a constant reference to a static instance of a default-constructed primitive post-operations attribute.
static const attr_params& default_attr_params()
Returns a constant reference to a static instance of default-constructed ukernel attribute parameters.