.. index:: pair: page; Layer Normalization
.. _doxid-dev_guide_layer_normalization:

Layer Normalization
===================

:ref:`API Reference`

General
~~~~~~~

The layer normalization primitive performs a forward or backward layer normalization operation on a 2-5D data tensor.

Forward
-------

The layer normalization operation performs normalization over the last logical axis of the data tensor and is defined by the following formulas. We show formulas only for 3D data, which are straightforward to generalize to cases of higher dimensions. Variable names follow the standard :ref:`Naming Conventions`.

.. math::

    \dst(t, n, c) = \gamma(c) \cdot \frac{\src(t, n, c) - \mu(t, n)} {\sqrt{\sigma^2(t, n) + \varepsilon}} + \beta(c),

where

* :math:`\gamma(c), \beta(c)` are optional scale and shift for a channel (see :ref:`dnnl_use_scale`, :ref:`dnnl_use_shift` flags),

* :math:`\mu(t, n), \sigma^2(t, n)` are mean and variance (see :ref:`dnnl_use_global_stats` flag), and

* :math:`\varepsilon` is a constant to improve numerical stability.

Mean and variance are computed at runtime or provided by a user. When mean and variance are computed at runtime, the following formulas are used:

* :math:`\mu(t, n) = \frac{1}{C} \sum\limits_{c} \src(t, n, c)`,

* :math:`\sigma^2(t, n) = \frac{1}{C} \sum\limits_{c} (\src(t, n, c) - \mu(t, n))^2`.

The :math:`\gamma(c)` and :math:`\beta(c)` tensors are considered learnable.

Difference Between Forward Training and Forward Inference
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++

* If mean and variance are computed at runtime (i.e., :ref:`dnnl_use_global_stats` is not set), they become outputs for the propagation kind :ref:`dnnl_forward_training` (because they would be required during the backward propagation). Data layout for mean and variance must be specified during creation of the layer normalization primitive descriptor by passing the memory descriptor for statistics (e.g., by passing ``stat_desc`` in :ref:`dnnl::layer_normalization_forward::primitive_desc()`). Mean and variance are not exposed for the propagation kind :ref:`dnnl_forward_inference`.

Backward
--------

The backward propagation computes :math:`\diffsrc(t, n, c)`, :math:`\diffgamma(c)^*`, and :math:`\diffbeta(c)^*` based on :math:`\diffdst(t, n, c)`, :math:`\src(t, n, c)`, :math:`\mu(t, n)`, :math:`\sigma^2(t, n)`, :math:`\gamma(c)^*`, and :math:`\beta(c)^*`.

The tensors marked with an asterisk are used only when the primitive is configured to use :math:`\gamma(c)` and :math:`\beta(c)` (i.e., :ref:`dnnl_use_scale` or :ref:`dnnl_use_shift` are set).

Execution Arguments
~~~~~~~~~~~~~~~~~~~

Depending on the :ref:`flags` and :ref:`propagation kind`, the layer normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.


.. list-table::
   :header-rows: 1

   * - Flags
     - :ref:`dnnl_forward_inference`
     - :ref:`dnnl_forward_training`
     - :ref:`dnnl_backward`
     - :ref:`dnnl_backward_data`
   * - :ref:`dnnl_normalization_flags_none`
     - *Inputs*: :math:`\src`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`

       *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\diffsrc`
     - Same as for :ref:`dnnl_backward`
   * - :ref:`dnnl_use_global_stats`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`

       *Outputs*: :math:`\diffsrc`
     - Same as for :ref:`dnnl_backward`
   * - :ref:`dnnl_use_scale`
     - *Inputs*: :math:`\src`, :math:`\gamma`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\gamma`

       *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`

       *Outputs*: :math:`\diffsrc`, :math:`\diffgamma`
     - Not supported
   * - :ref:`dnnl_use_shift`
     - *Inputs*: :math:`\src`, :math:`\beta`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\beta`

       *Outputs*: :math:`\dst`, :math:`\mu`, :math:`\sigma^2`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\beta`

       *Outputs*: :math:`\diffsrc`, :math:`\diffbeta`
     - Not supported
   * - :ref:`dnnl_use_global_stats` | :ref:`dnnl_use_scale` | :ref:`dnnl_use_shift`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`

       *Outputs*: :math:`\dst`
     - *Inputs*: :math:`\diffdst`, :math:`\src`, :math:`\mu`, :math:`\sigma^2`, :math:`\gamma`, :math:`\beta`

       *Outputs*: :math:`\diffsrc`, :math:`\diffgamma`, :math:`\diffbeta`
     - Not supported

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

============================  ====================================
Primitive input/output        Execution argument index
============================  ====================================
:math:`\src`                  DNNL_ARG_SRC
:math:`\gamma`                DNNL_ARG_SCALE
:math:`\beta`                 DNNL_ARG_SHIFT
mean (:math:`\mu`)            DNNL_ARG_MEAN
variance (:math:`\sigma^2`)   DNNL_ARG_VARIANCE
:math:`\dst`                  DNNL_ARG_DST
:math:`\diffdst`              DNNL_ARG_DIFF_DST
:math:`\diffsrc`              DNNL_ARG_DIFF_SRC
:math:`\diffgamma`            DNNL_ARG_DIFF_SCALE
:math:`\diffbeta`             DNNL_ARG_DIFF_SHIFT
:math:`\src` scale            DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC
:math:`\dst` scale            DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST
============================  ====================================

Implementation Details
~~~~~~~~~~~~~~~~~~~~~~

General Notes
-------------

#. The different flavors of the primitive are partially controlled by the ``flags`` parameter that is passed to the primitive descriptor creation function (e.g., :ref:`dnnl::layer_normalization_forward::primitive_desc()`). Multiple flags can be set using the bitwise OR operator (``|``).

#. For forward propagation, the mean and variance might be either computed at runtime (in which case they are outputs of the primitive) or provided by a user (in which case they are inputs). In the latter case, a user must set the :ref:`dnnl_use_global_stats` flag. For the backward propagation, the mean and variance are always input parameters.

#. Both forward and backward propagation support in-place operations, meaning that :math:`\src` can be used as input and output for forward propagation, and :math:`\diffdst` can be used as input and output for backward propagation. In case of an in-place operation, the original data will be overwritten. This support is limited to cases when the data types of :math:`\src` and :math:`\dst`, or of :math:`\diffsrc` and :math:`\diffdst`, are identical. Note, however, that backward propagation requires the original :math:`\src`, hence the corresponding forward propagation should not be performed in-place.

Post-ops and Attributes
-----------------------

Attributes enable you to modify the behavior of the layer normalization primitive.
The following attributes are supported by the layer normalization primitive:

===========  =========  =============  =============================================================  ====================================================================================
Propagation  Type       Operation      Description                                                    Restrictions
===========  =========  =============  =============================================================  ====================================================================================
forward      attribute  :ref:`Scales`  Scales the corresponding tensor by the given scale factor(s).  Supported only for int8 layer normalization; only one scale per tensor is supported.
===========  =========  =============  =============================================================  ====================================================================================

Data Type Support
-----------------

The operation supports the following combinations of data types:

===========  ===========================  ===========================
Propagation  Source                       Destination
===========  ===========================  ===========================
forward      f32, bf16, f16, u8, s8, f64  f32, bf16, f16, u8, s8, f64
backward     f32, bf16, f16, f64          f32, bf16, f16, f64
===========  ===========================  ===========================

The mean, variance, scale, and shift data types are always f32 and are independent of the source and destination data types.

Data Representation
-------------------

Mean and Variance
+++++++++++++++++

The mean (:math:`\mu`) and variance (:math:`\sigma^2`) are separate tensors with number of dimensions equal to (:math:`data\_ndims - 1`) and size :math:`(data\_dim[0], data\_dim[1], ..., data\_dim[ndims - 2])`.

The corresponding memory object can have an arbitrary memory format.
Unless mean and variance are computed at runtime and not exposed (i.e., propagation kind is :ref:`dnnl_forward_inference` and :ref:`dnnl_use_global_stats` is not set), the user should provide a memory descriptor for statistics when creating the layer normalization primitive descriptor. For best performance, it is advised to use the memory format that follows the data memory format; i.e., if the data format is :ref:`dnnl_tnc`, the best performance can be expected for statistics with the :ref:`dnnl_tn` format and suboptimal performance for statistics with the :ref:`dnnl_nt` format.

Scale and Shift
+++++++++++++++

If :ref:`dnnl_use_scale` or :ref:`dnnl_use_shift` are used, the scale (:math:`\gamma`) and shift (:math:`\beta`) are separate 1D tensors of shape :math:`C`.

Because these tensors are one-dimensional, the format of the corresponding memory object must be :ref:`dnnl_x` (:ref:`dnnl_a`).

Source, Destination, and Their Gradients
++++++++++++++++++++++++++++++++++++++++

The layer normalization primitive works with an arbitrary data tensor; however, it was designed for RNN data tensors (i.e., :ref:`dnnl_nc`, :ref:`dnnl_tnc`, :ref:`dnnl_ldnc`). Unlike CNN data tensors, RNN data tensors have a single feature dimension. Layer normalization performs normalization over the last logical dimension (feature dimension for RNN tensors) across non-feature dimensions.
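
To make the normalization axis concrete, the following is a minimal reference sketch in plain C++ (it does not call oneDNN) of the forward pass for a dense :math:`T \times N \times C` tensor in ``tnc`` layout, with statistics stored in ``tn`` order. The function name ``layer_norm_forward_ref``, the ``LayerNormStats`` struct, and the convention that an empty ``gamma`` or ``beta`` stands for an unset scale or shift flag are assumptions made for this illustration, not part of the library API.

.. code-block:: cpp

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Hypothetical container for the statistics of the forward pass:
    // one mean and one variance per (t, n) pair, stored in tn order.
    struct LayerNormStats {
        std::vector<float> mean;
        std::vector<float> variance;
    };

    // Reference forward layer normalization over the last logical axis of a
    // dense T x N x C tensor in tnc layout (c is innermost and contiguous).
    // An empty gamma or beta means "no scale" or "no shift" (an assumption
    // of this sketch that mirrors not setting the corresponding flag).
    LayerNormStats layer_norm_forward_ref(const std::vector<float> &src,
            std::vector<float> &dst, std::size_t T, std::size_t N,
            std::size_t C, const std::vector<float> &gamma,
            const std::vector<float> &beta, float epsilon) {
        const std::size_t rows = T * N;
        LayerNormStats stats;
        stats.mean.resize(rows);
        stats.variance.resize(rows);
        dst.resize(src.size());

        for (std::size_t r = 0; r < rows; ++r) {
            const float *x = src.data() + r * C;
            float *y = dst.data() + r * C;

            // mu(t, n) = 1/C * sum_c src(t, n, c)
            float mean = 0.f;
            for (std::size_t c = 0; c < C; ++c)
                mean += x[c];
            mean /= static_cast<float>(C);

            // sigma^2(t, n) = 1/C * sum_c (src(t, n, c) - mu(t, n))^2
            float var = 0.f;
            for (std::size_t c = 0; c < C; ++c)
                var += (x[c] - mean) * (x[c] - mean);
            var /= static_cast<float>(C);

            // dst = gamma * (src - mu) / sqrt(sigma^2 + eps) + beta
            const float inv_std = 1.f / std::sqrt(var + epsilon);
            for (std::size_t c = 0; c < C; ++c) {
                const float g = gamma.empty() ? 1.f : gamma[c];
                const float b = beta.empty() ? 0.f : beta[c];
                y[c] = g * (x[c] - mean) * inv_std + b;
            }

            stats.mean[r] = mean;
            stats.variance[r] = var;
        }
        return stats;
    }

Because each row of :math:`C` values is contiguous and is normalized independently, the same loop would apply to ``nc`` or ``ldnc`` data by treating all leading dimensions as a single batch of rows.
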

The layer normalization primitive is optimized for the following memory formats:

==============  ======================================================================
Logical tensor  Implementations optimized for memory formats
==============  ======================================================================
NC              :ref:`dnnl_nc` (:ref:`dnnl_ab`)
TNC             :ref:`dnnl_tnc` (:ref:`dnnl_abc`), :ref:`dnnl_ntc` (:ref:`dnnl_bac`)
LDNC            :ref:`dnnl_ldnc` (:ref:`dnnl_abcd`)
==============  ======================================================================

Implementation Limitations
~~~~~~~~~~~~~~~~~~~~~~~~~~

#. Refer to :ref:`Data Types` for limitations related to data type support.

#. GPU

   * Only tensors of 6 or fewer dimensions are supported.

   * Different data types for source and destination are not supported.

   * Integer data types for source and destination are not supported.

Performance Tips
~~~~~~~~~~~~~~~~

#. For data tensors :math:`\src`, :math:`\dst`, :math:`\diffsrc`, and :math:`\diffdst`, use memory formats for which the last logical axis is the last in the physical memory layout.

#. For mean and variance, use the memory format that follows the data memory format; i.e., if the data format is :ref:`dnnl_tnc`, the best performance can be expected for statistics with the :ref:`dnnl_tn` format and suboptimal performance for statistics with the :ref:`dnnl_nt` format.

#. For backward propagation, use the same memory format for :math:`\src`, :math:`\diffdst`, and :math:`\diffsrc`. Different formats are functionally supported but lead to highly suboptimal performance.

#. Use in-place operations whenever possible (see caveats in General Notes).

Example
~~~~~~~

:ref:`Layer Normalization Primitive Example`

This C++ API example demonstrates how to create and execute a :ref:`Layer normalization` primitive in forward propagation mode.

Key optimizations included in this example:

* In-place primitive execution;

* Creation of memory objects using the primitive descriptor.
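
As a quick numerical check of the forward formulas (a hand-worked illustration, not taken from the referenced example), consider a single row with :math:`C = 4` and :math:`\src = (1, 2, 3, 4)`, no scale or shift (that is, :math:`\gamma = 1` and :math:`\beta = 0`), and, purely for readability, :math:`\varepsilon = 0`:

.. math::

    \mu = \frac{1 + 2 + 3 + 4}{4} = 2.5, \qquad \sigma^2 = \frac{(-1.5)^2 + (-0.5)^2 + 0.5^2 + 1.5^2}{4} = 1.25,

.. math::

    \dst = \frac{(1, 2, 3, 4) - 2.5}{\sqrt{1.25}} \approx (-1.342, -0.447, 0.447, 1.342).

The normalized row has zero mean and unit variance; the optional :math:`\gamma` and :math:`\beta` would then rescale and shift it per channel.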