solution
sol_3159247_1781243189660930107_0
01
source
Submitted source
16688 bytes
show source
// Adapted from ggml-org/llama.cpp ggml-cpu Q4_Kx8 GEMV code.
// Upstream license: MIT, copyright (c) 2023-2026 The ggml authors.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <vector>
#if defined(__AVX2__)
#include <immintrin.h>
#endif
// Number of quantized values per block: sizes block_q8_K::qs and is the divisor
// the kernels use to turn an element count n into a block count.
static constexpr int QK_K = 256;
// One quantized activation block as consumed by the GEMV kernels in this file.
// The layout is read straight out of the input buffer, so field order and sizes matter.
struct block_q8_K {
float d; // block scale; the kernels multiply every integer sum of this block by it
std::int8_t qs[QK_K]; // signed 8-bit quantized values
std::int16_t bsums[QK_K / 16]; // 16 partial sums; the kernels add adjacent pairs and weight them by the Q4 mins (presumably sums over 16 consecutive qs values - confirm against the producer)
};
// One weight block covering eight output columns at once: the kernels emit
// eight results per call and index d/dmin by column.
struct block_q4_Kx8 {
std::uint16_t d[8]; // per-column scale, stored as IEEE half-precision bit patterns
std::uint16_t dmin[8]; // per-column scale for the mins, also half-precision bits
std::uint8_t scales[96]; // eight 12-byte groups; each unpacks to 6-bit scales and 6-bit mins
std::uint8_t qs[1024]; // 4-bit quants, two per byte (low nibble and high nibble)
};
// main() walks the input buffer in multiples of these sizes, so any padding
// change in the structs above would misparse the file; fail the build instead.
static_assert(sizeof(block_q8_K) == 292);
static_assert(sizeof(block_q4_Kx8) == 1152);
// Expands an IEEE 754 binary16 bit pattern into the float holding the same value.
// Zeros, subnormals, infinities and NaNs are all handled; the sign bit and any
// NaN payload bits are carried over.
static float half_to_float(std::uint16_t bits) {
    const std::uint32_t sign_bit = static_cast<std::uint32_t>(bits & 0x8000u) << 16;
    const int biased_exp = (bits >> 10) & 0x1f;
    std::uint32_t mantissa = bits & 0x03ffu;

    std::uint32_t encoded = sign_bit;
    if (biased_exp == 31) {
        // Infinity or NaN: all-ones float exponent, mantissa bits moved to the top.
        encoded |= 0x7f800000u | (mantissa << 13);
    } else if (biased_exp != 0) {
        // Normal number: rebias the exponent from 15 to 127 (+112).
        encoded |= (static_cast<std::uint32_t>(biased_exp + 112) << 23) | (mantissa << 13);
    } else if (mantissa != 0) {
        // Subnormal half: shift until the implicit leading one appears, lowering
        // the exponent once per shift. The mantissa starts below 0x400, so at
        // least one shift is always required.
        int unbiased_exp = -14;
        do {
            mantissa <<= 1;
            --unbiased_exp;
        } while ((mantissa & 0x0400u) == 0);
        mantissa &= 0x03ffu;
        encoded |= (static_cast<std::uint32_t>(unbiased_exp + 127) << 23) | (mantissa << 13);
    }
    // Otherwise the input is a signed zero and only the sign bit is set.

    float result;
    std::memcpy(&result, &encoded, sizeof(result));
    return result;
}
// Decodes a 32-bit little-endian unsigned integer from the four bytes at p.
static std::uint32_t read_u32(const unsigned char * p) {
    std::uint32_t value = 0;
    // Fold in the most significant byte first so p[0] lands in the low 8 bits.
    for (int i = 3; i >= 0; --i) {
        value = (value << 8) | p[i];
    }
    return value;
}
// Read-only view over the whole program input.
//
// `data`/`size` describe the bytes. They are backed either by a memory mapping
// (`mapped == true`, released with munmap in the destructor) or by the
// `fallback` vector owned by this object.
//
// The type is move-only. A copy would share the mapping (leading to a double
// munmap) or keep `data` pointing into the other object's vector, so copying is
// disabled and moves transfer ownership while keeping `data` valid.
struct InputView {
    const unsigned char * data = nullptr;
    std::size_t size = 0;
    std::vector<unsigned char> fallback;
    bool mapped = false;

    InputView() = default;
    InputView(const InputView &) = delete;
    InputView & operator=(const InputView &) = delete;

    InputView(InputView && other) noexcept
        : data(other.data), size(other.size), mapped(other.mapped) {
        // Swapping vectors hands over the heap buffer without relocating it,
        // so a `data` pointer into the fallback storage stays valid.
        fallback.swap(other.fallback);
        other.forget();
    }

    InputView & operator=(InputView && other) noexcept {
        if (this != &other) {
            unmap();
            data = other.data;
            size = other.size;
            mapped = other.mapped;
            // Our previous buffer ends up in `other` and is freed with it.
            fallback.swap(other.fallback);
            other.forget();
        }
        return *this;
    }

    ~InputView() {
        unmap();
    }

private:
    // Releases the mapping if this object owns one.
    void unmap() noexcept {
        if (mapped && data != nullptr && data != MAP_FAILED) {
            munmap(const_cast<unsigned char *>(data), size);
        }
    }

    // Puts a moved-from object into the empty, non-owning state.
    void forget() noexcept {
        data = nullptr;
        size = 0;
        mapped = false;
    }
};
// Loads all of standard input.
//
// When fstat reports a positive size for stdin, the input is mapped read-only
// from offset 0. If that is not possible (for example stdin is a pipe, or mmap
// fails) the stream is read in 1 MiB chunks into the view's fallback buffer.
static InputView read_input() {
    InputView view;

    struct stat info {};
    const bool size_known = fstat(STDIN_FILENO, &info) == 0 && info.st_size > 0;
    if (size_known) {
        const std::size_t length = static_cast<std::size_t>(info.st_size);
        void * const region = mmap(nullptr, length, PROT_READ, MAP_PRIVATE, STDIN_FILENO, 0);
        if (region != MAP_FAILED) {
            view.mapped = true;
            view.size = length;
            view.data = static_cast<const unsigned char *>(region);
            return view;
        }
    }

    // Streaming path: a short read signals end of input (or an error).
    std::array<unsigned char, 1 << 20> chunk{};
    std::size_t got = 0;
    do {
        got = std::fread(chunk.data(), 1, chunk.size(), stdin);
        view.fallback.insert(view.fallback.end(), chunk.data(), chunk.data() + got);
    } while (got == chunk.size());

    view.data = view.fallback.data();
    view.size = view.fallback.size();
    return view;
}
static void gemv_q4_K_8x8_q8_K_generic(int n, float * out, const block_q4_Kx8 * q4, const block_q8_K * q8) {
const int nb = n / QK_K;
static const std::uint32_t kmask1 = 0x3f3f3f3f;
static const std::uint32_t kmask2 = 0x0f0f0f0f;
static const std::uint32_t kmask3 = 0x03030303;
float sumf[8] = {};
float sum_minf[8] = {};
std::uint32_t utmp[32];
for (int l = 0; l < nb; ++l) {
for (int sb = 0; sb < 8; ++sb) {
std::memcpy(utmp + sb * 4, q4[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] =
((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const std::uint32_t uaux = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] =
(utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux;
utmp[sb * 4] &= kmask1;
}
for (int k = 0; k < QK_K / 16; ++k) {
std::uint8_t * scales_0 = reinterpret_cast<std::uint8_t *>(utmp) + (k / 4) * 32;
std::uint8_t * scales_1 = reinterpret_cast<std::uint8_t *>(utmp) + (k / 4) * 32 + 16;
for (int j = 0; j < 8; ++j) {
int sumi = 0;
for (int i = 0; i < 8; ++i) {
const int q = q4[l].qs[k * 64 + j * 8 + i];
const int v0 = q & 0x0f;
const int v1 = q >> 4;
const int y0 = q8[l].qs[(k >> 2) * 64 + (k & 3) * 8 + i];
const int y1 = q8[l].qs[(k >> 2) * 64 + (k & 3) * 8 + i + 32];
sumi += v0 * y0 * scales_0[j] + v1 * y1 * scales_1[j];
}
sumf[j] += static_cast<float>(sumi) * half_to_float(q4[l].d[j]) * q8[l].d;
}
}
for (int sb = 0; sb < 8; ++sb) {
std::uint8_t * mins = reinterpret_cast<std::uint8_t *>(utmp) + 8 + sb * 16;
const int bsum = q8[l].bsums[sb * 2] + q8[l].bsums[sb * 2 + 1];
for (int j = 0; j < 8; ++j) {
sum_minf[j] += mins[j] * bsum * half_to_float(q4[l].dmin[j]) * q8[l].d;
}
}
}
for (int j = 0; j < 8; ++j) {
out[j] = sumf[j] - sum_minf[j];
}
}
#if defined(__AVX2__)
// Loads eight consecutive half-precision values (unaligned) and widens them to
// eight floats, keeping their order.
static inline __m256 f32cx8_load(const std::uint16_t * x) {
    const __m128i halves = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
    return _mm256_cvtph_ps(halves);
}
// Loads eight half-precision values (unaligned), reorders their bytes with the
// given byte-shuffle mask, then widens the result to eight floats.
static inline __m256 f32cx8_rearrange_load(const std::uint16_t * x, __m128i mask) {
    const __m128i halves = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
    const __m128i reordered = _mm_shuffle_epi8(halves, mask);
    return _mm256_cvtph_ps(reordered);
}
// AVX2 kernel: dot product of one Q8_K activation row with eight interleaved
// Q4_K weight columns.
//   n    element count; n / QK_K blocks are read from both q4 and q8
//   out  receives the eight results (unaligned store of 8 floats)
//   q4   interleaved weight blocks
//   q8   activation blocks
// When F16C is not enabled at compile time the call is forwarded to the scalar kernel.
static void gemv_q4_K_8x8_q8_K(int n, float * out, const block_q4_Kx8 * q4, const block_q8_K * q8) {
#if defined(__F16C__)
const int nb = n / QK_K;
// Masks for unpacking the 12-byte scale groups into 6-bit scales and mins.
static const std::uint32_t kmask1 = 0x3f3f3f3f;
static const std::uint32_t kmask2 = 0x0f0f0f0f;
static const std::uint32_t kmask3 = 0x03030303;
// Byte shuffle that reorders the eight 16-bit column scales to 0,4,1,5,2,6,3,7,
// the lane order in which the integer accumulators come out below.
const __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
// Byte shuffle that duplicates the first eight scale bytes pairwise in the order
// 0,0,4,4,1,1,5,5,2,2,6,6,3,3,7,7 so they line up with the 16-bit products.
const __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
// Gathers lanes 0,2,4,6,1,3,5,7, which undoes the 0,4,1,5,2,6,3,7 ordering above.
const __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
// Keeps the low nibble of every byte.
const __m256i m4b = _mm256_set1_epi8(0x0f);
// Float accumulators across all blocks: quant dot products and mins contribution.
__m256 acc_row = _mm256_setzero_ps();
__m256 acc_min_rows = _mm256_setzero_ps();
for (int b = 0; b < nb; ++b) {
// Per-block scales: activation scale broadcast, column scales (reordered) and column min scales.
const __m256 row_scale_f32 = _mm256_set1_ps(q8[b].d);
const __m256 col_scale_f32 = f32cx8_rearrange_load(q4[b].d, deltamask);
const __m256 col_dmin_f32 = f32cx8_load(q4[b].dmin);
// Integer accumulators for this block.
__m256i iacc_b = _mm256_setzero_si256();
__m256i iacc_min_b = _mm256_setzero_si256();
// Add adjacent bsums pairs (16 values -> 8) and copy the result into both 128-bit halves.
const __m256i q8sums = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q8[b].bsums));
__m256i q8s =
_mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
// Each iteration consumes 256 packed weight bytes, 64 activations and two 12-byte scale groups.
for (int sb = 0; sb < QK_K / 64; ++sb) {
const std::uint8_t * q = q4[b].qs + sb * 256;
// Eight 32-byte loads; names carry the column group (0123 / 4567) and the 64-byte chunk index.
const __m256i raw_0123_0 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q));
const __m256i raw_4567_0 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 32));
const __m256i raw_0123_1 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 64));
const __m256i raw_4567_1 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 96));
const __m256i raw_0123_2 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 128));
const __m256i raw_4567_2 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 160));
const __m256i raw_0123_3 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 192));
const __m256i raw_4567_3 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 224));
// Low nibbles (paired with the first 32 activations of this sub-block).
const __m256i v0123_00 = _mm256_and_si256(raw_0123_0, m4b);
const __m256i v4567_00 = _mm256_and_si256(raw_4567_0, m4b);
const __m256i v0123_01 = _mm256_and_si256(raw_0123_1, m4b);
const __m256i v4567_01 = _mm256_and_si256(raw_4567_1, m4b);
const __m256i v0123_02 = _mm256_and_si256(raw_0123_2, m4b);
const __m256i v4567_02 = _mm256_and_si256(raw_4567_2, m4b);
const __m256i v0123_03 = _mm256_and_si256(raw_0123_3, m4b);
const __m256i v4567_03 = _mm256_and_si256(raw_4567_3, m4b);
// High nibbles (paired with the second 32 activations); the 16-bit shift is followed by a byte mask.
const __m256i v0123_10 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_0, 4), m4b);
const __m256i v4567_10 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_0, 4), m4b);
const __m256i v0123_11 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_1, 4), m4b);
const __m256i v4567_11 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_1, 4), m4b);
const __m256i v0123_12 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_2, 4), m4b);
const __m256i v4567_12 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_2, 4), m4b);
const __m256i v0123_13 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_3, 4), m4b);
const __m256i v4567_13 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_3, 4), m4b);
// Unpack the two 12-byte scale groups of this sub-block: words 0-1 become the
// eight 6-bit scales, words 2-3 the eight 6-bit mins (same bit layout as the scalar kernel).
std::uint32_t utmp_0[4], utmp_1[4];
std::memcpy(utmp_0, q4[b].scales + 24 * sb, 12);
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
const std::uint32_t uaux_0 = utmp_0[1] & kmask1;
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
utmp_0[2] = uaux_0;
utmp_0[0] &= kmask1;
std::memcpy(utmp_1, q4[b].scales + 12 + sb * 24, 12);
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
const std::uint32_t uaux_1 = utmp_1[1] & kmask1;
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
utmp_1[2] = uaux_1;
utmp_1[0] &= kmask1;
const __m128i mins_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
const __m128i mins_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
// Scales widened to 16 bit, duplicated and ordered per scalemask.
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(mins_scales_0, scalemask));
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(mins_scales_1, scalemask));
// Shuffle 78 swaps the 64-bit halves so the mins sit in the low half; the mins of
// both groups are then interleaved byte-wise and widened to 16 bit.
const __m256i mins_01 = _mm256_cvtepu8_epi16(
_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_scales_0, 78), _mm_shuffle_epi32(mins_scales_1, 78)));
// Load the 64 activations of this sub-block in four 16-byte pieces and
// replicate each piece into both 128-bit halves.
const std::int8_t * y = q8[b].qs + sb * 64;
__m256i lhs_00 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y)));
__m256i lhs_01 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 16)));
__m256i lhs_10 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 32)));
__m256i lhs_11 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 48)));
lhs_00 = _mm256_permute2f128_si256(lhs_00, lhs_00, 0);
lhs_01 = _mm256_permute2f128_si256(lhs_01, lhs_01, 0);
lhs_10 = _mm256_permute2f128_si256(lhs_10, lhs_10, 0);
lhs_11 = _mm256_permute2f128_si256(lhs_11, lhs_11, 0);
__m256i iacc_0 = _mm256_setzero_si256();
__m256i iacc_1 = _mm256_setzero_si256();
// Low-nibble products. Shuffle 177 swaps adjacent 32-bit elements and blend 170 takes the
// odd elements from its second operand, so each operand mixes the 0123 and 4567 vectors;
// the activation shuffles 0/85/170/255 broadcast 32-bit element 0/1/2/3 within each half.
// maddubs multiplies unsigned nibbles by signed activations and adds adjacent pairs into 16-bit sums.
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_00, _mm256_shuffle_epi32(v4567_00, 177), 170), _mm256_shuffle_epi32(lhs_00, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_00, 177), v4567_00, 170), _mm256_shuffle_epi32(lhs_00, 85)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_01, _mm256_shuffle_epi32(v4567_01, 177), 170), _mm256_shuffle_epi32(lhs_00, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_01, 177), v4567_01, 170), _mm256_shuffle_epi32(lhs_00, 255)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_02, _mm256_shuffle_epi32(v4567_02, 177), 170), _mm256_shuffle_epi32(lhs_01, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_02, 177), v4567_02, 170), _mm256_shuffle_epi32(lhs_01, 85)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_03, _mm256_shuffle_epi32(v4567_03, 177), 170), _mm256_shuffle_epi32(lhs_01, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_03, 177), v4567_03, 170), _mm256_shuffle_epi32(lhs_01, 255)));
// Apply the first group's scales and widen the pair sums to 32 bit.
iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
// High-nibble products, same pattern against the second 32 activations.
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_10, _mm256_shuffle_epi32(v4567_10, 177), 170), _mm256_shuffle_epi32(lhs_10, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_10, 177), v4567_10, 170), _mm256_shuffle_epi32(lhs_10, 85)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_11, _mm256_shuffle_epi32(v4567_11, 177), 170), _mm256_shuffle_epi32(lhs_10, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_11, 177), v4567_11, 170), _mm256_shuffle_epi32(lhs_10, 255)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_12, _mm256_shuffle_epi32(v4567_12, 177), 170), _mm256_shuffle_epi32(lhs_11, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_12, 177), v4567_12, 170), _mm256_shuffle_epi32(lhs_11, 85)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_13, _mm256_shuffle_epi32(v4567_13, 177), 170), _mm256_shuffle_epi32(lhs_11, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_13, 177), v4567_13, 170), _mm256_shuffle_epi32(lhs_11, 255)));
// Apply the second group's scales.
iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
const __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
// Broadcast the two pair sums belonging to this sub-block and multiply-add them with the mins.
const __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
const __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
// Drop the 4 bytes just used so the next iteration's broadcast picks up the next two pair sums.
q8s = _mm256_bsrli_epi128(q8s, 4);
iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
}
// Convert this block's integer sums to float and apply column scale * activation scale.
acc_row = _mm256_add_ps(
acc_row,
_mm256_mul_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32)));
acc_min_rows = _mm256_add_ps(
acc_min_rows,
_mm256_mul_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32)));
}
// Only acc_row is permuted back before the mins are subtracted; acc_min_rows was scaled
// with the unshuffled dmin load.
acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
_mm256_storeu_ps(out, _mm256_sub_ps(acc_row, acc_min_rows));
#else
// No F16C at compile time: use the scalar kernel.
gemv_q4_K_8x8_q8_K_generic(n, out, q4, q8);
#endif
}
#else
// Build without AVX2: the kernel entry point simply forwards to the scalar
// implementation with the same arguments.
static void gemv_q4_K_8x8_q8_K(int n, float * out, const block_q4_Kx8 * q4, const block_q8_K * q8) {
    return gemv_q4_K_8x8_q8_K_generic(n, out, q4, q8);
}
#endif
int main() {
InputView input = read_input();
if (input.size < 24 || std::memcmp(input.data, "CMQ4KX01", 8) != 0) {
return 1;
}
const std::uint32_t case_count = read_u32(input.data + 8);
const std::uint32_t rows_per_case = read_u32(input.data + 12);
const std::uint32_t blocks_per_row = read_u32(input.data + 16);
const std::uint32_t n = read_u32(input.data + 20);
const unsigned char * ptr = input.data + 24;
double total = 0.0;
for (std::uint32_t case_index = 0; case_index < case_count; ++case_index) {
const block_q8_K * q8 = reinterpret_cast<const block_q8_K *>(ptr);
ptr += static_cast<std::size_t>(blocks_per_row) * sizeof(block_q8_K);
for (std::uint32_t row_group = 0; row_group < rows_per_case / 8; ++row_group) {
const block_q4_Kx8 * q4 = reinterpret_cast<const block_q4_Kx8 *>(ptr);
ptr += static_cast<std::size_t>(blocks_per_row) * sizeof(block_q4_Kx8);
float rows[8];
gemv_q4_K_8x8_q8_K(static_cast<int>(n), rows, q4, q8);
for (float row : rows) {
total += static_cast<double>(row);
}
}
}
std::printf("%.2f", total);
}
02
jobs
Systems
02 jobs
03
counters
Performance counters
31 counters
cycles: 16,115,654
branch_instructions: 950,040
branch_misses: 15,532
cycle_activity.stalls_l1d_miss: 1,855,496
cycle_activity.stalls_l2_miss: 872,962
cycle_activity.stalls_l3_miss: 724,963
cycle_activity.stalls_total: 2,732,983
dtlb_load_misses.walk_completed: 852
exe_activity.bound_on_loads: 2,106,856
exe_activity.bound_on_stores: 20,235
instructions: 52,380,658
machine_clears: 817
mem_inst_retired.split_loads: 1,073,625
mem_load_retired.l1_miss: 126,374
mem_load_retired.l2_miss: 27,829
mem_load_retired.l3_miss: 22,381
tma_backend_bound: 34,666,962
tma_bad_speculation: 2,269,165
tma_branch_mispredict_slots: 2,152,605
tma_frontend_bound: 5,998,816
tma_memory_bound: 10,443,132
tma_retiring: 54,249,589
tma_slots: 97,165,086
uops_dispatched.port_0: 10,140,284
uops_dispatched.port_1: 10,699,023
uops_dispatched.port_2_3_10: 11,119,520
uops_dispatched.port_4_9: 2,650,194
uops_dispatched.port_5_11: 14,219,225
uops_dispatched.port_6: 4,875,922
uops_dispatched.port_7_8: 1,725,806
uops_retired.ms: 0
04
top down
Top-down analysis
Raptor Cove P-core
05
profile
load profile
03
counters
Performance counters
26 counters
cycles: 29,471,974
branch_instructions: 948,537
branch_misses: 17,764
dtlb_load_misses.walk_completed: 2,363
instructions: 52,375,376
mem_bound_stalls.load_dram_hit: 1,050,353
mem_bound_stalls.load_l2_hit: 3,645,543
mem_bound_stalls.load_llc_hit: 135,671
mem_inst_retired.split_loads: 1,073,689
mem_load_retired.l1_miss: 1,911,281
mem_load_retired.l2_miss: 13,803
mem_load_retired.l3_miss: 10,033
tma_backend_bound: 55,546,246
tma_backend_bound_alloc_restrictions: 168,454
tma_backend_bound_non_memory_scheduler: 43,275,905
tma_backend_bound_register: 3,485,100
tma_backend_bound_reorder_buffer: 1,292,710
tma_backend_bound_serialization: 1,436,406
tma_bad_speculation: 4,239,564
tma_bad_speculation_branch_mispredict: 4,144,186
tma_bad_speculation_machine_clears: 95,378
tma_frontend_bandwidth: 3,603,844
tma_frontend_bound: 6,015,825
tma_frontend_latency: 2,411,981
tma_memory_bound: 401,205
tma_retiring: 87,213,146
04
top down
Top-down analysis
Gracemont E-core
05
profile
load profile