#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>
#include "example_utils.hpp"
namespace {

// Overwrites every element of `v` with a pseudo-random float produced by
// std::uniform_real_distribution<float>(0, 1).
// The generator is a default-constructed std::mt19937 created on each call,
// so repeated calls always yield the same sequence of values.
void init_vector(std::vector<float> &v) {
    std::mt19937 rng;
    std::uniform_real_distribution<float> unit_dist(0, 1);
    for (auto it = v.begin(); it != v.end(); ++it) {
        *it = unit_dist(rng);
    }
}

// Overwrites every element of `v` with a pseudo-random byte in [0, 255].
// Values are drawn as unsigned int (the distribution is instantiated for
// unsigned int rather than uint8_t) and then narrowed explicitly.
// Like the float overload, a freshly default-constructed std::mt19937 is
// used, which makes the output identical on every call.
void init_vector(std::vector<uint8_t> &v) {
    std::mt19937 rng;
    std::uniform_int_distribution<unsigned int> byte_dist(0, 255);
    for (auto it = v.begin(); it != v.end(); ++it) {
        const unsigned int drawn = byte_dist(rng);
        *it = static_cast<uint8_t>(drawn);
    }
}

} // namespace
// Upper bound of the `run` loop in infer(): how many times the per-run work
// is repeated before the stream is waited on. Defaults to a single run.
int number_of_runs = 1;
// Takes two 64-bit sizes K and N plus the engine `eng` by const reference.
// NOTE(review): presumably the parameter list of matmul_pd_create, which is
// called as matmul_pd_create(K, N, eng) further down — confirm.
int64_t K, int64_t N,
const engine &eng) {
// Register the post-ops object `po` on the attribute object `attr`.
attr.set_post_ops(po);
}
// Fills the library memory objects with input data.
// NOTE(review): presumably the body of prepare_input, which infer() calls
// with (A_u8_mem, scale_f32_mem, zp_A_mem, zp_C_mem) — confirm. M, K and N
// are not declared in the visible lines.
// Host-side source buffer: M * K bytes filled by init_vector.
std::vector<uint8_t> A_u8(M * K);
init_vector(A_u8);
// Host-side scales: N floats filled by init_vector.
std::vector<float> scales_f32(N);
init_vector(scales_f32);
// Zero points written below: 128 for A and 40 for C.
int32_t zp_A = 128, zp_C = 40;
// Copy the host buffers and both zero points into their memory objects.
write_to_dnnl_memory(A_u8.data(), A_u8_mem);
write_to_dnnl_memory(&zp_A, zp_A_mem);
write_to_dnnl_memory(&zp_C, zp_C_mem);
write_to_dnnl_memory(scales_f32.data(), scale_f32_mem);
}
// Smoke test on the result: reads the output bytes and the C zero point back
// to the host and throws std::logic_error if any output byte is below the
// zero point.
// NOTE(review): presumably the body of sanity_check, which infer() calls with
// (C_u8_mem, zp_C_mem) — confirm. M and N are not declared in the visible
// lines.
int32_t zp_C = 0;
std::vector<uint8_t> C_u8(M * N);
// Pull the M * N output bytes and the zero point out of the memory objects.
read_from_dnnl_memory(C_u8.data(), C_u8_mem);
read_from_dnnl_memory(&zp_C, zp_C_mem);
// Every value must be >= zp_C; the error text attributes this to ReLU
// having been applied.
for (int64_t i = 0; i < M * N; ++i)
if (C_u8[i] < zp_C)
throw std::logic_error(
"Smoke check failed."
"\n\tQuantized value is smaller than the zero point,"
"\n\twhich should not happen since ReLU was applied.");
}
// Runs one inference pass for the given problem sizes: fills the inputs,
// repeats the per-run work `number_of_runs` times, waits on `s`, then checks
// the output with sanity_check().
// NOTE(review): the rest of the parameter list, the declarations of the
// memory objects and of `s`, and the call that consumes the argument map
// below are not visible here — confirm against the full definition.
void infer(
const matmul &matmul_p, int64_t M, int64_t N, int64_t K,
// Populate the source, scales and zero-point memory objects.
prepare_input(A_u8_mem, scale_f32_mem, zp_A_mem, zp_C_mem);
// Argument map: source, weights and destination memories, the output scales,
// and the zero points for source and destination.
for (int run = 0; run < number_of_runs; ++run)
{{DNNL_ARG_SRC, A_u8_mem}, {DNNL_ARG_WEIGHTS, B_s8_mem},
{DNNL_ARG_DST, C_u8_mem},
{DNNL_ARG_ATTR_OUTPUT_SCALES, scale_f32_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_A_mem},
{DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_C_mem}});
// Block on `s` before the result is inspected.
s.wait();
sanity_check(C_u8_mem, zp_C_mem);
}
// Example driver: fixes K and N, builds the matmul primitive descriptor,
// prepares the weights, then runs infer() for two values of M.
// NOTE(review): presumably the body of inference_int8_matmul, which main()
// passes to handle_example_errors — confirm. `eng`, `s`, `matmul_p` and
// `B_f32_mem` are not declared in the visible lines.
const int64_t K = 96;
const int64_t N = 1000;
auto matmul_pd = matmul_pd_create(K, N, eng);
// Host-side weights: K * N floats filled by init_vector.
std::vector<float> B_f32(K * N);
init_vector(B_f32);
// Weights memory created from the descriptor the primitive descriptor reports.
memory B_s8_mem(matmul_pd.weights_desc(), eng);
{
// The f32 weights are described as a {K, N} tensor with format tag `ab`.
// NOTE(review): the step that transfers B_f32_mem into B_s8_mem is not
// visible here — confirm.
{{K, N}, memory::data_type::f32, memory::format_tag::ab}, eng);
write_to_dnnl_memory(B_f32.data(), B_f32_mem);
s.wait();
}
// Run inference with M = 1 and M = 100, reusing the same weights memory.
for (int64_t M : {1, 100})
infer(matmul_p, M, N, K, B_s8_mem, eng);
}
// Program entry point: hands inference_int8_matmul and the engine kind to
// handle_example_errors and returns its result as the exit status.
// NOTE(review): `engine_kind` is not declared in the visible lines and
// argc/argv are unused here — confirm where the engine kind comes from.
int main(int argc, char **argv) {
return handle_example_errors(inference_int8_matmul, engine_kind);
}