// cpu_matmul_quantization.cpp
//
// Annotated version: MatMul Tutorial: Quantization

/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/


#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"

#include "example_utils.hpp"

using namespace dnnl;

namespace {

// Fill `v` with pseudo-random values drawn from a uniform distribution over
// [min_value, max_value).
//
// A fresh std::mt19937 seeded with `seed` is created on every call, so the
// result is deterministic. With the default seed, separate calls replay the
// same underlying random sequence; pass a different seed for independent data.
void init_vector(std::vector<float> &v, float min_value, float max_value,
        std::mt19937::result_type seed = std::mt19937::default_seed) {
    std::mt19937 gen(seed);
    std::uniform_real_distribution<float> u(min_value, max_value);

    for (auto &e : v)
        e = u(gen);
}

// Compute the minimum and maximum element of `v` (compared as float).
// Throws std::invalid_argument if `v` is empty (there is no element to seed
// the search with).
template <typename T>
void find_min_max(const std::vector<T> &v, float &min_value, float &max_value) {
    if (v.empty()) throw std::invalid_argument("find_min_max: empty input");

    min_value = max_value = v[0];
    for (const auto &e : v) {
        min_value = std::min<float>(min_value, e);
        max_value = std::max<float>(max_value, e);
    }
}

// Compute quantization parameters (scale and zero point) that map the value
// range of `v` onto the integer type T (int8_t or uint8_t), and print them.
//
// - int8_t: symmetric quantization, zero point fixed at 0; the scale maps the
//   largest magnitude max(|min|, |max|) onto the (shrunk) max integer value.
// - uint8_t: asymmetric quantization; the scale maps [min, max] onto the
//   (shrunk) integer range and the zero point is chosen so that max maps
//   (up to truncation) onto the max integer value.
//
// For degenerate input (all values equal / all zero) the scale falls back to
// 1 so that neither the zero point nor the printed inverse scale divide by 0.
// Throws std::invalid_argument if `v` is empty.
template <typename T>
void compute_q10n_params(const char *message, const std::vector<float> &v,
        float &scale, int32_t &zp) {
    static_assert(std::is_same<T, int8_t>::value
                    || std::is_same<T, uint8_t>::value,
            "compute_q10n_params supports only int8_t and uint8_t");

    // Find property of T integer type
    // Simple trick to improve accuracy: shrink the range a little bit
    float max_int = static_cast<float>(std::numeric_limits<T>::max()) - 1;
    float min_int = static_cast<float>(std::numeric_limits<T>::lowest()) + 1;

#ifndef OMIT_WORKAROUND_FOR_SKX
    // Read more in CPU / Section 1 here:
    // https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html
    if (std::is_same<T, uint8_t>::value) max_int /= 2;
#endif

    // Find min and max value in array
    float min_val = 0.f, max_val = 0.f;
    find_min_max(v, min_val, max_val);

    const bool is_symmetric = std::is_same<T, int8_t>::value;

    // Compute appropriate scale
    if (is_symmetric) {
        // The zero point is 0, so the representable range is centered at 0:
        // cover the largest magnitude on either side. Using
        // (max - min) / (max_int - min_int) here would saturate any data
        // that is not centered around zero.
        scale = std::max(std::fabs(min_val), std::fabs(max_val)) / max_int;
    } else {
        scale = (max_val - min_val) / (max_int - min_int);
    }

    // Degenerate range (constant or all-zero data): a zero scale would make
    // max_val / scale below (and 1 / scale in the printout) blow up.
    if (!(scale > 0.f)) scale = 1.f;

    // Compute appropriate offset
    if (is_symmetric)
        zp = 0;
    else
        zp = static_cast<int32_t>(max_int - max_val / scale);

    printf("\tComputing q10n params for %s\n"
           "\t\tData type: %s\n"
           "\t\tScale:%.3g (inverse scale:%.3g)\n"
           "\t\tZero point:%d\n\n",
            message, std::is_same<T, int8_t>::value ? "int8_t" : "uint8_t",
            scale, 1 / scale, zp);
}

// Compare a float reference `v1` with the dequantized values of `v2`
// (scale_v2 * (v2[n] - zp_v2)) using l2 norms, print the result, and return
// 0 when ||v1 - deq(v2)|| <= threshold * ||v1||, 1 otherwise.
// Vectors of different length are reported as a failure (1) instead of
// reading out of bounds.
int compare_vectors(const std::vector<float> &v1,
        const std::vector<uint8_t> &v2, float scale_v2, int32_t zp_v2,
        float threshold) {
    if (v1.size() != v2.size()) {
        printf("\tComparison (using l2-norms)\n"
               "\t\tSize mismatch: %zu vs %zu\n"
               "\nAccuracy check: FAILED\n\n",
                v1.size(), v2.size());
        return 1;
    }

    // Accumulate in double to limit round-off in the sums of squares
    double v1_l2 = 0, diff_l2 = 0;
    for (size_t n = 0; n < v1.size(); ++n) {
        const float v2_n = scale_v2 * (v2[n] - zp_v2); // deq10n v2
        const double ref_n = v1[n];
        const double diff = ref_n - v2_n;
        v1_l2 += ref_n * ref_n;
        diff_l2 += diff * diff;
    }

    v1_l2 = std::sqrt(v1_l2);
    diff_l2 = std::sqrt(diff_l2);
    bool ok = diff_l2 <= threshold * v1_l2;

    printf("\tComparison (using l2-norms)\n"
           "\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative error:%g\n"
           "\nAccuracy check: %s\n\n",
            v1_l2, diff_l2, diff_l2 / v1_l2, ok ? "OK" : "FAILED");

    return ok ? 0 : 1;
}

} // namespace

// Single CPU engine (device index 0) shared by every function below; keeping
// it global keeps the example short.
engine eng {engine::kind::cpu, 0};

// Quantize float data into X_int_m oneDNN memory using the q10n parameters
//
// Inputs:
// - X_f32 -- source f32 matrix, dense row-major, holding exactly
//            dims[0] * dims[1] elements of X_int_m's descriptor
// - scale_X, zp_X -- quantization parameters; they are attached to the
//                    destination of a reorder and supplied at execution time
//                    (DNNL_ARG_ATTR_SCALES / DNNL_ARG_ATTR_ZERO_POINTS)
// Outputs:
// - X_int_m -- prepared 2D oneDNN memory that receives the quantized values
//
// Throws std::invalid_argument if X_int_m is not 2D or if the size of X_f32
// does not match it. The function blocks until the reorder has finished.
void quantize(const std::vector<float> &X_f32, float scale_X, int32_t zp_X,
        memory &X_int_m) {
    using dt = memory::data_type;

    stream s(eng);

    memory::desc x_int_md = X_int_m.get_desc();
    const auto &dims = x_int_md.get_dims();

    // The f32 source is described below as a dense row-major 2D matrix backed
    // by X_f32; reject inputs for which the reorder would read past its end.
    if (dims.size() != 2)
        throw std::invalid_argument("quantize: destination must be 2D");
    if (static_cast<int64_t>(X_f32.size()) != dims[0] * dims[1])
        throw std::invalid_argument("quantize: source size mismatch");

    memory::desc x_f32_md({dims[0], dims[1]}, dt::f32, {dims[1], 1});
    memory X_f32_m(x_f32_md, eng, (void *)X_f32.data());

    // Scale and zero point are runtime arguments (mask 0: one common value)
    primitive_attr q10n_attr;
    q10n_attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 0);
    q10n_attr.set_zero_points_mask(DNNL_ARG_DST, /* mask */ 0);

    reorder::primitive_desc q10n_pd(eng, x_f32_md, eng, x_int_md, q10n_attr);
    memory dst_scale_X_m({{1}, dt::f32, {1}}, eng, &scale_X);
    memory zp_X_m({{1}, dt::s32, {1}}, eng, &zp_X);
    reorder(q10n_pd).execute(s,
            {{DNNL_ARG_SRC, X_f32_m}, {DNNL_ARG_DST, X_int_m},
                    {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_X_m},
                    {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_X_m}});

    // scale_X / zp_X are locals wrapped by the memories above: wait before
    // they go out of scope
    s.wait();
}

// Floating point MatMul
// Inputs:
// - Shape: M, N, K
// - Matrices A and B
// Outputs:
// - Matrix C
void f32_matmul_compute(int64_t M, int64_t N, int64_t K,
        const std::vector<float> &A_f32, const std::vector<float> &B_f32,
        std::vector<float> &C_f32) {
    // Initialize memory descriptors that describes matrices in Row-Major format
    memory::desc a_md({M, K}, memory::data_type::f32, {K, 1});
    memory::desc b_md({K, N}, memory::data_type::f32, {N, 1});
    memory::desc c_md({M, N}, memory::data_type::f32, {N, 1});

    // Wrap raw pointers into oneDNN memory objects
    memory A_f32_m(a_md, eng, (void *)A_f32.data());
    memory B_f32_m(b_md, eng, (void *)B_f32.data());
    memory C_f32_m(c_md, eng, (void *)C_f32.data());

    // Create a MatMul primitive
    matmul::primitive_desc matmul_pd(eng, a_md, b_md, c_md);
    matmul matmul_p(matmul_pd);

    stream s(eng);
    matmul_p.execute(s,
            {{DNNL_ARG_SRC, A_f32_m}, {DNNL_ARG_WEIGHTS, B_f32_m},
                    {DNNL_ARG_DST, C_f32_m}});
    s.wait();
}

// Reduced precision MatMul with **dynamic** quantization
// Inputs:
// - Shape: M, N, K
// - Matrices A and B in float (would be quantized inside the function)
// Outputs:
// - Matrix C in uint8_t
// - Quantization parameters: scale_C and zp_C
void dynamic_q10n_matmul(int64_t M, int64_t N, int64_t K,
        const std::vector<float> &A_f32, const std::vector<float> &B_f32,
        std::vector<uint8_t> &C_u8, float &scale_C, int32_t &zp_C) {
    stream s(eng);

    float scale_A, scale_B;
    int32_t zp_A, zp_B;

    // We compute q10n parameters here, but in the real world applications for
    // inputs these parameters are transferred from the previous layers
    compute_q10n_params<uint8_t>("A", A_f32, scale_A, zp_A);
    compute_q10n_params<int8_t>("B", B_f32, scale_B, zp_B);
    assert(zp_B == 0 && "for int8 q10n we assume zero point = 0");

    // Quantize matrix A_u8 using reorder primitive
    std::vector<uint8_t> A_u8(M * K, 0);
    memory::desc a_u8_md({M, K}, memory::data_type::u8, {K, 1});
    memory A_u8_m(a_u8_md, eng, (void *)A_u8.data());
    quantize(A_f32, scale_A, zp_A, A_u8_m);

    // Quantize matrix B_s8 using reorder primitive
    std::vector<uint8_t> B_s8(K * N, 0);
    memory::desc b_s8_md({K, N}, memory::data_type::s8, {N, 1});
    memory B_s8_m(b_s8_md, eng, (void *)B_s8.data());
    quantize(B_f32, scale_B, 0, B_s8_m);

    // Compute C_f32. We cannot directly compute C_u8 since we don't know the
    // appropriate quantization parameters.
    //
    // Note: typically the computed data type in this case is int32_t and not
    //       float. But for brevity we are going to embed the scale_A and
    //       scale_B directly in this quantized MatMul, and hence will get the
    //       intermediate computation in floating point anyways, so there is
    //       no sense to convert the result to int32_t.
    //       In theory, we could postpone using the scale_A and scale_B, compute
    //       the exact C_s32 := (A_u8 - zp_A) * B_s8, and then find the
    //       appropriate quantization parameters for matrix C.
    //       Let it be an exercise :)

    std::vector<float> C_f32(M * N, 0);
    memory::desc c_f32_md({M, N}, memory::data_type::f32, {N, 1});
    memory C_f32_m(c_f32_md, eng, (void *)C_f32.data());

    // Create and compute a reduced precision MatMul primitive
    {
        primitive_attr matmul_attr;
        matmul_attr.set_scales_mask(DNNL_ARG_SRC, /* mask */ 0);
        matmul_attr.set_scales_mask(DNNL_ARG_WEIGHTS, /* mask */ 0);
        matmul_attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask */ 0);

        matmul::primitive_desc matmul_pd(
                eng, a_u8_md, b_s8_md, c_f32_md, matmul_attr);
        matmul matmul_p(matmul_pd);

        memory scales_A_m({{1}, memory::data_type::f32, {1}}, eng, &scale_A);
        memory scales_B_m({{1}, memory::data_type::f32, {1}}, eng, &scale_B);
        memory zp_A_m({{1}, memory::data_type::s32, {1}}, eng, &zp_A);

        matmul_p.execute(s,
                {{DNNL_ARG_SRC, A_u8_m}, {DNNL_ARG_WEIGHTS, B_s8_m},
                        {DNNL_ARG_DST, C_f32_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, scales_A_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scales_B_m},
                        {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_A_m}});
    }

    // Find quantization parameters for matrix C
    compute_q10n_params<uint8_t>("C", C_f32, scale_C, zp_C);

    // Finally quantize the matrix C
    memory::desc c_u8_md({M, N}, memory::data_type::u8, {N, 1});
    memory C_u8_m(c_u8_md, eng, (void *)C_u8.data());
    quantize(C_f32, scale_C, zp_C, C_u8_m);
}

// Run the f32 reference MatMul and the dynamically quantized MatMul on the
// same random inputs and check that the dequantized result is close enough.
// Throws std::logic_error if the accuracy check fails.
void compare_f32_and_quantized_matmuls() {
    // MatMul parameters
    const int64_t M = 10, N = 20, K = 30;

    // Data distribution for matrices A and B
    const float param_A_min_val = -2.f;
    const float param_A_max_val = 1.4f;

    const float param_B_min_val = -1.f;
    const float param_B_max_val = -param_B_min_val; // B is centered around 0

    // Threshold: maximum allowed relative l2 error of the dynamically
    // quantized result with respect to the f32 reference
    const float threshold_dynamic_q10n = 3 * 1e-2f;

    // Prepare matrices
    std::vector<float> A_f32(M * K), B_f32(K * N), C_f32(M * N, 0);
    init_vector(A_f32, param_A_min_val, param_A_max_val);
    init_vector(B_f32, param_B_min_val, param_B_max_val);

    // Compute _true_ f32 result
    f32_matmul_compute(M, N, K, A_f32, B_f32, C_f32);

    std::vector<uint8_t> C_u8_dynamic_q10n(M * N, 0);

    // Q10n parameters we don't know yet; filled in by dynamic_q10n_matmul().
    // zp must be int32_t to bind to the int32_t& output parameter.
    float scale_C_dynamic_q10n = 0.f;
    int32_t zp_C_dynamic_q10n = 0;

    dynamic_q10n_matmul(M, N, K, A_f32, B_f32, C_u8_dynamic_q10n,
            scale_C_dynamic_q10n, zp_C_dynamic_q10n);

    // Compare _true_ f32 result with dynamic q10n
    int rc = compare_vectors(C_f32, C_u8_dynamic_q10n, scale_C_dynamic_q10n,
            zp_C_dynamic_q10n, threshold_dynamic_q10n);
    if (rc) throw std::logic_error("Dynamic quantization accuracy failed.");
}

// Entry point: runs the f32 vs. dynamically quantized MatMul comparison on the
// CPU engine kind through the shared example error handler and returns its
// result as the exit code. Command-line arguments are not inspected.
int main(int argc, char **argv) {
    const auto example_fn = compare_f32_and_quantized_matmuls;
    return handle_example_errors({engine::kind::cpu}, example_fn);
}