BRGeMM ukernel

Overview

BRGeMM ukernel routines. More…

// typedefs

typedef struct dnnl_brgemm* dnnl_brgemm_t;
typedef const struct dnnl_brgemm* const_dnnl_brgemm_t;
typedef struct dnnl_transform* dnnl_transform_t;
typedef const struct dnnl_transform* const_dnnl_transform_t;

// structs

struct dnnl::ukernel::brgemm;
struct dnnl_brgemm;
struct dnnl_transform;

// global functions

dnnl_status_t DNNL_API dnnl_brgemm_create(
    dnnl_brgemm_t* brgemm,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    dnnl_dim_t batch_size,
    dnnl_dim_t lda,
    dnnl_dim_t ldb,
    dnnl_dim_t ldc,
    dnnl_data_type_t a_dt,
    dnnl_data_type_t b_dt,
    dnnl_data_type_t c_dt
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_add_C(dnnl_brgemm_t brgemm, int add_C);

dnnl_status_t DNNL_API dnnl_brgemm_set_post_ops(
    dnnl_brgemm_t brgemm,
    dnnl_dim_t ldd,
    dnnl_data_type_t d_dt,
    const_dnnl_post_ops_t post_ops
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_A_scales(
    dnnl_brgemm_t brgemm,
    int a_scale_mask
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_B_scales(
    dnnl_brgemm_t brgemm,
    int b_scale_mask
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_D_scales(
    dnnl_brgemm_t brgemm,
    int d_scale_mask
    );

dnnl_status_t DNNL_API dnnl_brgemm_finalize(dnnl_brgemm_t brgemm);

dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type(
    const_dnnl_brgemm_t brgemm,
    dnnl_pack_type_t* pack_type
    );

dnnl_status_t DNNL_API dnnl_brgemm_get_scratchpad_size(
    const_dnnl_brgemm_t brgemm,
    size_t* size
    );

dnnl_status_t DNNL_API dnnl_brgemm_is_execute_postops_valid(
    const_dnnl_brgemm_t brgemm,
    int* valid
    );

dnnl_status_t DNNL_API dnnl_brgemm_set_hw_context(const_dnnl_brgemm_t brgemm);
dnnl_status_t DNNL_API dnnl_brgemm_release_hw_context();
dnnl_status_t DNNL_API dnnl_brgemm_generate(dnnl_brgemm_t brgemm);

dnnl_status_t DNNL_API dnnl_brgemm_execute(
    const_dnnl_brgemm_t brgemm,
    const void* A_ptr,
    const void* B_ptr,
    const dnnl_dim_t* A_B_offsets,
    void* C_ptr,
    void* scratchpad_ptr
    );

dnnl_status_t DNNL_API dnnl_brgemm_execute_postops(
    const_dnnl_brgemm_t brgemm,
    const void* A,
    const void* B,
    const dnnl_dim_t* A_B_offsets,
    const void* C_ptr,
    void* D_ptr,
    void* scratchpad_ptr,
    const_dnnl_ukernel_attr_params_t attr_params
    );

dnnl_status_t DNNL_API dnnl_brgemm_destroy(dnnl_brgemm_t brgemm);

dnnl_status_t DNNL_API dnnl_transform_create(
    dnnl_transform_t* transform,
    dnnl_dim_t K,
    dnnl_dim_t N,
    dnnl_pack_type_t in_pack_type,
    dnnl_dim_t in_ld,
    dnnl_dim_t out_ld,
    dnnl_data_type_t in_dt,
    dnnl_data_type_t out_dt
    );

dnnl_status_t DNNL_API dnnl_transform_generate(dnnl_transform_t transform);

dnnl_status_t DNNL_API dnnl_transform_execute(
    const_dnnl_transform_t transform,
    const void* in_ptr,
    void* out_ptr
    );

dnnl_status_t DNNL_API dnnl_transform_destroy(dnnl_transform_t transform);

Detailed Documentation

BRGeMM ukernel routines.

Typedefs

typedef struct dnnl_brgemm* dnnl_brgemm_t

A brgemm ukernel handle.

typedef const struct dnnl_brgemm* const_dnnl_brgemm_t

A constant brgemm ukernel handle.

typedef struct dnnl_transform* dnnl_transform_t

A transform routine handle.

typedef const struct dnnl_transform* const_dnnl_transform_t

A constant transform routine handle.

Global Functions

dnnl_status_t DNNL_API dnnl_brgemm_create(
    dnnl_brgemm_t* brgemm,
    dnnl_dim_t M,
    dnnl_dim_t N,
    dnnl_dim_t K,
    dnnl_dim_t batch_size,
    dnnl_dim_t lda,
    dnnl_dim_t ldb,
    dnnl_dim_t ldc,
    dnnl_data_type_t a_dt,
    dnnl_data_type_t b_dt,
    dnnl_data_type_t c_dt
    )

Creates a BRGeMM ukernel object.

Operates by the following formula: C = [A x B].

Parameters:

brgemm

Output BRGeMM ukernel object.

M

Dimension M of tensor A.

N

Dimension N of tensor B.

K

Dimension K of tensors A and B.

batch_size

Number of batches to process.

lda

Leading dimension of tensor A.

ldb

Leading dimension of tensor B.

ldc

Leading dimension of tensor C.

a_dt

Data type of tensor A.

b_dt

Data type of tensor B.

c_dt

Data type of tensor C. Must be dnnl_f32.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_set_add_C(dnnl_brgemm_t brgemm, int add_C)

Sets whether an intermediate result is added to the output tensor C instead of overwriting it: C += [A x B].

Parameters:

brgemm

BRGeMM ukernel object.

add_C

Value to indicate addition. Can be 0 to skip addition, and 1 to apply addition.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_set_post_ops(
    dnnl_brgemm_t brgemm,
    dnnl_dim_t ldd,
    dnnl_data_type_t d_dt,
    const_dnnl_post_ops_t post_ops
    )

Sets post-operations to a BRGeMM ukernel object: D = post-operations(C).

Post-operations apply if one of the following holds:

  • Non-empty attributes are specified.

  • Output data type d_dt is different from accumulation data type c_dt.

If any of these conditions holds, the final call of the accumulation chain must be dnnl_brgemm_execute_postops; otherwise, it must be dnnl_brgemm_execute.

Parameters:

brgemm

BRGeMM ukernel object.

ldd

Leading dimension of tensor D.

d_dt

Data type of tensor D.

post_ops

Primitive post operations attribute to extend the kernel operations.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_set_A_scales(
    dnnl_brgemm_t brgemm,
    int a_scale_mask
    )

Sets tensor A scales mask to a BRGeMM ukernel object.

For the quantization flavor, tensor A scales are applied to the accumulation buffer once C is ready.

Parameters:

brgemm

BRGeMM ukernel object.

a_scale_mask

Tensor A scale mask. Can be 0 only.

dnnl_status_t DNNL_API dnnl_brgemm_set_B_scales(
    dnnl_brgemm_t brgemm,
    int b_scale_mask
    )

Sets tensor B scales mask to a BRGeMM ukernel object.

For the quantization flavor, tensor B scales are applied to the accumulation buffer once C is ready.

Parameters:

brgemm

BRGeMM ukernel object.

b_scale_mask

Tensor B scale mask. Can be 0 or 2 only.

dnnl_status_t DNNL_API dnnl_brgemm_set_D_scales(
    dnnl_brgemm_t brgemm,
    int d_scale_mask
    )

Sets tensor D scales mask to a BRGeMM ukernel object.

For the quantization flavor, tensor D scales are applied after all post-ops are applied.

Parameters:

brgemm

BRGeMM ukernel object.

d_scale_mask

Tensor D scale mask. Can be 0 only.

dnnl_status_t DNNL_API dnnl_brgemm_finalize(dnnl_brgemm_t brgemm)

Finalizes initialization of a BRGeMM ukernel object.

This step is mandatory to query information from the object.

Parameters:

brgemm

BRGeMM ukernel object.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type(
    const_dnnl_brgemm_t brgemm,
    dnnl_pack_type_t* pack_type
    )

Returns the packing type expected by a tensor B of a BRGeMM ukernel object.

Parameters:

brgemm

BRGeMM ukernel object.

pack_type

Output packing type. Can be dnnl_brgemm_no_pack if packing is not expected, and dnnl_brgemm_pack_32 otherwise.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_get_scratchpad_size(
    const_dnnl_brgemm_t brgemm,
    size_t* size
    )

Returns the size of a scratchpad memory needed for the BRGeMM ukernel object.

Parameters:

brgemm

BRGeMM ukernel object.

size

Output size of a buffer required for the BRGeMM ukernel object.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_is_execute_postops_valid(
    const_dnnl_brgemm_t brgemm,
    int* valid
    )

Returns the flag indicating whether the call to dnnl_brgemm_execute_postops is valid.

Parameters:

brgemm

BRGeMM ukernel object.

valid

The flag indicating whether dnnl_brgemm_execute_postops is valid for a given ukernel object: 1 if valid, and 0 otherwise.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_set_hw_context(const_dnnl_brgemm_t brgemm)

Initializes the hardware-specific context.

If no initialization is required, returns the success status.

Parameters:

brgemm

BRGeMM ukernel object.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_release_hw_context()

Releases the hardware-specific context.

Must be used after all the execution calls to BRGeMM ukernel objects.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_generate(dnnl_brgemm_t brgemm)

Generates an executable part of a BRGeMM ukernel object.

Parameters:

brgemm

BRGeMM ukernel object.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_execute(
    const_dnnl_brgemm_t brgemm,
    const void* A_ptr,
    const void* B_ptr,
    const dnnl_dim_t* A_B_offsets,
    void* C_ptr,
    void* scratchpad_ptr
    )

Executes a BRGeMM ukernel object.

Parameters:

brgemm

BRGeMM ukernel object.

A_ptr

Base pointer to a tensor A.

B_ptr

Base pointer to a tensor B.

A_B_offsets

Pointer to the set of tensor A and tensor B offsets for each batch; the set must be contiguous in memory. A single batch should supply offsets for both tensors A and B simultaneously. The number of batches must coincide with the batch_size value passed at the creation stage.

C_ptr

Pointer to a tensor C (accumulation buffer).

scratchpad_ptr

Pointer to a scratchpad buffer.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_execute_postops(
    const_dnnl_brgemm_t brgemm,
    const void* A,
    const void* B,
    const dnnl_dim_t* A_B_offsets,
    const void* C_ptr,
    void* D_ptr,
    void* scratchpad_ptr,
    const_dnnl_ukernel_attr_params_t attr_params
    )

Executes a BRGeMM ukernel object with post operations.

Parameters:

brgemm

BRGeMM ukernel object.

A

Base pointer to a tensor A.

B

Base pointer to a tensor B.

A_B_offsets

Pointer to a set of tensor A and tensor B offsets for each batch. A set must be contiguous in memory. A single batch should supply offsets for both tensors A and B simultaneously. The number of batches must coincide with the batch_size value passed at the creation stage.

C_ptr

Pointer to a tensor C (accumulation buffer).

D_ptr

Pointer to a tensor D (output buffer).

scratchpad_ptr

Pointer to a scratchpad buffer.

attr_params

Ukernel attributes memory storage.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_brgemm_destroy(dnnl_brgemm_t brgemm)

Destroys a BRGeMM ukernel object.

Parameters:

brgemm

BRGeMM ukernel object to destroy.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_transform_create(
    dnnl_transform_t* transform,
    dnnl_dim_t K,
    dnnl_dim_t N,
    dnnl_pack_type_t in_pack_type,
    dnnl_dim_t in_ld,
    dnnl_dim_t out_ld,
    dnnl_data_type_t in_dt,
    dnnl_data_type_t out_dt
    )

Creates a transform object.

Parameters:

transform

Output transform object.

K

Dimension K.

N

Dimension N.

in_pack_type

Input packing type. Must be either dnnl_pack_type_no_trans or dnnl_pack_type_trans.

in_ld

Input leading dimension.

out_ld

Output leading dimension. When packing data, it specifies the block size along the N dimension.

in_dt

Input data type.

out_dt

Output data type.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_transform_generate(dnnl_transform_t transform)

Generates an executable part of a transform object.

Parameters:

transform

Transform object.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_transform_execute(
    const_dnnl_transform_t transform,
    const void* in_ptr,
    void* out_ptr
    )

Executes a transform object.

Parameters:

transform

Transform object.

in_ptr

Pointer to an input buffer.

out_ptr

Pointer to an output buffer.

Returns:

dnnl_success on success and a status describing the error otherwise.

dnnl_status_t DNNL_API dnnl_transform_destroy(dnnl_transform_t transform)

Destroys a transform object.

Parameters:

transform

Transform object to destroy.

Returns:

dnnl_success on success and a status describing the error otherwise.