#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <vector>

#include "example_utils.hpp"
// Quantization scheme selector passed to quantize().
//  - DYNAMIC: used by dynamic_q10n_matmul, whose scale/zero point are
//    computed from the data at run time via compute_q10n_params.
//  - STATIC:  used by static_q10n_matmul, whose scale/zero point are
//    supplied by the caller as fixed arguments.
enum class q10n_scheme_t : int {
    DYNAMIC,
    STATIC,
};
namespace {
/// Fills `v` with pseudo-random floats drawn uniformly from
/// [min_value, max_value).
///
/// The std::mt19937 engine is default-constructed, so it always starts from
/// the same default seed: repeated runs produce identical data, which keeps
/// the accuracy checks in this example reproducible.
void init_vector(std::vector<float> &v, float min_value, float max_value) {
    std::mt19937 gen;
    std::uniform_real_distribution<float> u(min_value, max_value);
    for (auto &e : v)
        e = u(gen);
}

/// Stores the smallest and largest element of `v` (converted to float) in
/// `min_value` and `max_value`.
///
/// @throws std::invalid_argument if `v` is empty (there is no v[0] to seed
///         the search with).
template <typename T>
void find_min_max(const std::vector<T> &v, float &min_value, float &max_value) {
    if (v.empty())
        throw std::invalid_argument("find_min_max: input vector is empty");
    min_value = max_value = static_cast<float>(v[0]);
    for (const auto &e : v) {
        min_value = std::min<float>(min_value, e);
        max_value = std::max<float>(max_value, e);
    }
}

/// Computes quantization parameters mapping the values of `v` onto integer
/// type T, then prints them. A quantized value q dequantizes as
/// scale * (q - zp) (the same formula compare_vectors uses).
///
/// - uint8_t (asymmetric): scale = (max - min) / (max_int - min_int), and the
///   zero point is chosen so that max(v) lands on max_int.
/// - int8_t (symmetric): the zero point is fixed at 0, so the scale is
///   derived from max(|min|, |max|). That keeps both ends of the data inside
///   [min_int, max_int] even when the data is not centred on zero.
///
/// @param message  Label printed in the report.
/// @param v        Values to be quantized; must not be empty.
/// @param scale    Receives the scale (always > 0).
/// @param zp       Receives the zero point (0 for int8_t).
/// @throws std::invalid_argument if `v` is empty.
template <typename T>
void compute_q10n_params(const char *message, const std::vector<float> &v,
        float &scale, int32_t &zp) {
    static_assert(std::is_same<T, int8_t>::value
                    || std::is_same<T, uint8_t>::value,
            "compute_q10n_params supports only int8_t and uint8_t");
    constexpr bool is_s8 = std::is_same<T, int8_t>::value;

    // Leave one level of headroom at each end of the integer range.
    float max_int = (float)std::numeric_limits<T>::max() - 1;
    const float min_int = (float)std::numeric_limits<T>::lowest() + 1;
#ifndef OMIT_WORKAROUND_FOR_SKX
    // Use only the lower half of the uint8 range unless the build opts out.
    // NOTE(review): presumably the u8 x s8 intermediate-saturation workaround
    // for CPUs without VNNI (e.g. SKX) from the oneDNN int8 guide — confirm.
    if (std::is_same<T, uint8_t>::value) max_int /= 2;
#endif

    float min_val = 0.f, max_val = 0.f;
    find_min_max(v, min_val, max_val);

    if (is_s8) {
        // zp == 0 means the representable interval is centred on zero, so it
        // has to be wide enough for the larger magnitude on either side.
        const float abs_max
                = std::max(std::fabs(min_val), std::fabs(max_val));
        scale = 2.f * abs_max / (max_int - min_int);
    } else {
        scale = (max_val - min_val) / (max_int - min_int);
    }
    // A constant (or all-zero) input has an empty range. Fall back to a unit
    // scale so the divisions below (and by callers) stay finite.
    if (!(scale > 0.f)) scale = 1.f;

    zp = is_s8 ? 0 : (int32_t)(max_int - max_val / scale);

    printf("\tComputing q10n params for %s\n"
           "\t\tData type: %s\n"
           "\t\tScale:%.3g (inverse scale:%.3g)\n"
           "\t\tZero point:%d\n\n",
            message, is_s8 ? "int8_t" : "uint8_t", scale, 1 / scale, zp);
}

/// Compares a float reference `v1` against the quantized vector `v2`,
/// dequantized element-wise as scale_v2 * (v2[n] - zp_v2), using l2 norms.
/// Prints a short report.
///
/// @returns 0 if ||v1 - dequant(v2)||_2 <= threshold * ||v1||_2, else 1.
/// @throws std::invalid_argument if the two vectors differ in length.
int compare_vectors(const std::vector<float> &v1,
        const std::vector<uint8_t> &v2, float scale_v2, int32_t zp_v2,
        float threshold) {
    if (v1.size() != v2.size())
        throw std::invalid_argument("compare_vectors: size mismatch");

    // Accumulate in double so the squared sums do not lose precision.
    double v1_l2 = 0, diff_l2 = 0;
    for (size_t n = 0; n < v1.size(); ++n) {
        const float v2_n = scale_v2 * (v2[n] - zp_v2);
        const double diff = v1[n] - v2_n;
        v1_l2 += static_cast<double>(v1[n]) * v1[n];
        diff_l2 += diff * diff;
    }
    v1_l2 = std::sqrt(v1_l2);
    diff_l2 = std::sqrt(diff_l2);
    const bool ok = diff_l2 <= threshold * v1_l2;

    printf("\tComparison (using l2-norms)\n"
           "\t\tReference matrix:%g\n\t\tError:%g\n\t\tRelative error:%g\n"
           "\nAccuracy check: %s\n\n",
            v1_l2, diff_l2, diff_l2 / v1_l2, ok ? "OK" : "FAILED");
    return ok ? 0 : 1;
}
} // namespace
// Intended to quantize the f32 data X_f32 into the integer memory X_int_m
// using scale_X / zp_X; `q10n_scheme` selects the DYNAMIC or STATIC path.
//
// NOTE(review): this body appears incomplete as shown:
//  - `dims`, `eng`, `s` and `dt` are not declared in this function or
//    anywhere else visible in this file; unless example_utils.hpp provides
//    them, this does not compile. Presumably `dims`/`eng` were meant to come
//    from X_int_m's descriptor and engine, and `s` from a local stream — confirm.
//  - Nothing visible here writes X_int_m: no primitive is executed,
//    `scale_X_m` / `zp_X_m` are created but never used, and the STATIC
//    branch is empty.
//  - `inv_scale_X` (1 / scale_X) is what gets wrapped as the f32 scale;
//    whether the consuming primitive expects the scale or its inverse depends
//    on the oneDNN version's scaling semantics — verify.
void quantize(q10n_scheme_t q10n_scheme, const std::vector<float> &X_f32,
float scale_X, int32_t zp_X,
memory &X_int_m) {
float inv_scale_X = 1.f / scale_X;
const bool is_dynamic_q10n = q10n_scheme == q10n_scheme_t::DYNAMIC;
// Row-major 2-D f32 view of X_f32 (strides {dims[1], 1}).
memory::desc x_f32_md({dims[0], dims[1]}, dt::f32, {dims[1], 1});
memory X_f32_m(x_f32_md, eng, (
void *)X_f32.data());
if (is_dynamic_q10n) {
// Dynamic path: wrap the scale and zero point as 1-element memories
// (presumably to be passed as run-time arguments to a primitive).
memory scale_X_m({{1}, dt::f32, {1}}, eng, &inv_scale_X);
memory zp_X_m({{1}, dt::s32, {1}}, eng, &zp_X);
} else {
}
s.wait();
}
// Reference f32 matmul: wraps A_f32, B_f32 and C_f32 in memory objects and
// executes `matmul_p` on stream `s`, then waits for the stream.
// Callers size A as M*K, B as K*N and C as M*N.
//
// NOTE(review): `a_md`, `b_md`, `c_md`, `matmul_p`, `eng` and `s` are not
// declared anywhere visible, and M, N, K are unused in the visible body.
// Presumably the descriptors and primitive were meant to be built from
// M/N/K here — confirm.
void f32_matmul_compute(int64_t M, int64_t N, int64_t K,
const std::vector<float> &A_f32, const std::vector<float> &B_f32,
std::vector<float> &C_f32) {
memory A_f32_m(a_md, eng, (
void *)A_f32.data());
memory B_f32_m(b_md, eng, (
void *)B_f32.data());
memory C_f32_m(c_md, eng, (
void *)C_f32.data());
// NOTE(review): the argument list of this execute() call is cut off; the
// next statement begins before the call is closed, so this does not
// compile as shown.
matmul_p.execute(s,
s.wait();
}
// Dynamic-quantization matmul. In order, it:
//  1. derives uint8 parameters for A and int8 parameters for B from the data
//     (compute_q10n_params);
//  2. quantizes A and B with quantize(DYNAMIC, ...);
//  3. runs `matmul_p` with an f32 result buffer C_f32;
//  4. derives uint8 parameters for C from that f32 result and quantizes it
//     into C_u8.
// On return, scale_C / zp_C hold C's quantization parameters.
//
// NOTE(review): `a_u8_md`, `b_s8_md`, `c_f32_md`, `c_u8_md`, `matmul_p`,
// `eng` and `s` are not declared anywhere visible in this file.
void dynamic_q10n_matmul(int64_t M, int64_t N, int64_t K,
const std::vector<float> &A_f32, const std::vector<float> &B_f32,
std::vector<uint8_t> &C_u8, float &scale_C, int32_t &zp_C) {
float scale_A, scale_B;
int32_t zp_A, zp_B;
compute_q10n_params<uint8_t>("A", A_f32, scale_A, zp_A);
compute_q10n_params<int8_t>("B", B_f32, scale_B, zp_B);
assert(zp_B == 0 && "for int8 q10n we assume zero point = 0");
std::vector<uint8_t> A_u8(M * K, 0);
memory A_u8_m(a_u8_md, eng, (
void *)A_u8.data());
quantize(q10n_scheme_t::DYNAMIC, A_f32, scale_A, zp_A, A_u8_m);
// NOTE(review): B's parameters are computed for int8_t, but this buffer is
// declared std::vector<uint8_t>; presumably it should be int8_t — confirm.
std::vector<uint8_t> B_s8(K * N, 0);
memory B_s8_m(b_s8_md, eng, (
void *)B_s8.data());
quantize(q10n_scheme_t::DYNAMIC, B_f32, scale_B, 0, B_s8_m);
std::vector<float> C_f32(M * N, 0);
memory C_f32_m(c_f32_md, eng, (
void *)C_f32.data());
{
// Combined scale of the int8 product (scale_A * scale_B).
float output_scale = scale_A * scale_B;
// NOTE(review): execute()'s argument list is cut off here (the scope
// closes before the call does), so output_scale is never consumed and
// this does not compile as shown.
matmul_p.execute(s,
}
compute_q10n_params<uint8_t>("C", C_f32, scale_C, zp_C);
memory C_u8_m(c_u8_md, eng, (
void *)C_u8.data());
quantize(q10n_scheme_t::DYNAMIC, C_f32, scale_C, zp_C, C_u8_m);
}
// Static-quantization matmul. A and B are quantized with caller-supplied
// parameters (quantize(STATIC, ...)), and the product is meant to be written
// directly into the uint8 buffer C_u8 with the requantization factor
// scale_A * scale_B / scale_C.
//
// NOTE(review): `a_u8_md`, `b_s8_md`, `c_u8_md`, `matmul_p`, `eng` and `s`
// are not declared anywhere visible, and `zp_C` is unused in the visible body.
void static_q10n_matmul(int64_t M, int64_t N, int64_t K,
const std::vector<float> &A_f32, const std::vector<float> &B_f32,
float scale_A, int32_t zp_A, float scale_B, float scale_C, int32_t zp_C,
std::vector<uint8_t> &C_u8) {
std::vector<uint8_t> A_u8(M * K, 0);
memory A_u8_m(a_u8_md, eng, (
void *)A_u8.data());
quantize(q10n_scheme_t::STATIC, A_f32, scale_A, zp_A, A_u8_m);
// NOTE(review): named/used as s8, but declared std::vector<uint8_t>;
// presumably should be int8_t — confirm.
std::vector<uint8_t> B_s8(K * N, 0);
memory B_s8_m(b_s8_md, eng, (
void *)B_s8.data());
quantize(q10n_scheme_t::STATIC, B_f32, scale_B, 0, B_s8_m);
{
memory C_u8_m(c_u8_md, eng, (
void *)C_u8.data());
// NOTE(review): the next line is a dangling argument fragment (its call
// head is missing), and the execute() call after it is cut off; this
// scope does not compile as shown.
0, {scale_A * scale_B / scale_C});
matmul_p.execute(s,
}
}
/// Runs the f32 reference matmul, then the dynamically- and statically-
/// quantized matmuls, and checks each quantized result against the f32
/// reference with compare_vectors().
///
/// @throws std::logic_error if a quantized result exceeds its relative
///         l2-error threshold.
void compare_f32_and_quantized_matmuls() {
    const int64_t M = 10, N = 20, K = 30;
    const float param_A_min_val = -2.f;
    const float param_A_max_val = 1.4f;
    const float param_B_min_val = -1.f;
    const float param_B_max_val = -param_B_min_val;
    const float threshold_dynamic_q10n = 3 * 1e-2f;
    const float threshold_static_q10n = 4 * 1e-2f;

    std::vector<float> A_f32(M * K), B_f32(K * N), C_f32(M * N, 0);
    init_vector(A_f32, param_A_min_val, param_A_max_val);
    init_vector(B_f32, param_B_min_val, param_B_max_val);
    f32_matmul_compute(M, N, K, A_f32, B_f32, C_f32);

    {
        printf("# DYNAMIC quantization\n\n");
        std::vector<uint8_t> C_u8_dynamic_q10n(M * N, 0);
        float scale_C_dynamic_q10n;
        // int32_t (not int) so it binds to dynamic_q10n_matmul's int32_t&
        // out-parameter on every platform.
        int32_t zp_C_dynamic_q10n;
        dynamic_q10n_matmul(M, N, K, A_f32, B_f32, C_u8_dynamic_q10n,
                scale_C_dynamic_q10n, zp_C_dynamic_q10n);
        int rc = compare_vectors(C_f32, C_u8_dynamic_q10n, scale_C_dynamic_q10n,
                zp_C_dynamic_q10n, threshold_dynamic_q10n);
        if (rc) throw std::logic_error("Dynamic quantization accuracy failed.");
    }

    {
        printf("# STATIC quantization\n\n");
        std::vector<uint8_t> C_u8_static_q10n(M * N, 0);
        // A and B parameters are fixed from the known input ranges above
        // instead of being measured from the data.
        const float scale_A_static_q10n
                = (param_A_max_val - param_A_min_val) / 128;
        const int32_t zp_A_static_q10n
                = (int32_t)(128 - param_A_max_val / scale_A_static_q10n);
        const float scale_B_static_q10n
                = (param_B_max_val - param_B_min_val) / 256;
        float scale_C_static_q10n;
        // int32_t to match compute_q10n_params' int32_t& out-parameter.
        int32_t zp_C_static_q10n;
        const char *warn_message
                = "C"
                  "\n\t*******************************************************"
                  "\n\t* NOTE: These computation do not happen in real world *"
                  "\n\t* applications and used here solely to simplify *"
                  "\n\t* the example. *"
                  "\n\t* Please refer to the example source code for *"
                  "\n\t* more information. *"
                  "\n\t*******************************************************";
        // C's parameters are derived from the f32 reference result purely to
        // keep the example short (see warn_message above).
        compute_q10n_params<uint8_t>(
                warn_message, C_f32, scale_C_static_q10n, zp_C_static_q10n);
        static_q10n_matmul(M, N, K, A_f32, B_f32, scale_A_static_q10n,
                zp_A_static_q10n, scale_B_static_q10n, scale_C_static_q10n,
                zp_C_static_q10n, C_u8_static_q10n);
        int rc = compare_vectors(C_f32, C_u8_static_q10n, scale_C_static_q10n,
                zp_C_static_q10n, threshold_static_q10n);
        if (rc) throw std::logic_error("Static quantization accuracy failed.");
    }
}
// Program entry point. It returns the result of handle_example_errors (not
// defined in this file, presumably from example_utils.hpp), which presumably
// runs compare_f32_and_quantized_matmuls and maps exceptions to an exit
// code — confirm.
//
// NOTE(review): the handle_example_errors( argument list is cut off and
// never closed before the function's closing brace, so this does not
// compile as shown.
int main(int argc, char **argv) {
return handle_example_errors(
}