solution
sol_3159247_1781243189954681829_6
01
source
Submitted source
15601 bytes
show source
// Adapted from ggml-org/llama.cpp ggml/src/ggml-cpu/arch/x86/repack.cpp.
// Upstream license: MIT, copyright (c) 2023-2026 The ggml authors.
#include <cstddef>
#include <cstdint>
#include <cstring>
#if defined(__AVX2__) && defined(__F16C__)
#include <immintrin.h>
#endif
// Number of quantized values per block: block_q8_K stores QK_K int8 quants,
// and the kernels below process n / QK_K blocks per row.
static constexpr int QK_K = 256;
// Raw 16-bit storage of a half-precision float (1 sign, 5 exponent and
// 10 mantissa bits, as decoded by fp16_to_fp32).
using ggml_half = std::uint16_t;
// One packed 4-bit weight block feeding eight output lanes at once.
// The kernels in this file read it as follows:
//   d[j]    - half-precision scale for output lane j
//   dmin[j] - half-precision scale applied to lane j's minimum term
//   scales  - eight 12-byte groups; each group is unpacked into eight 6-bit
//             scales followed by eight 6-bit mins (one per lane)
//   qs      - two 4-bit quants per byte (low nibble and high nibble), laid
//             out in 8-byte runs per lane
struct block_q4_Kx8 {
ggml_half d[8];
ggml_half dmin[8];
std::uint8_t scales[96];
std::uint8_t qs[1024];
};
// One 8-bit activation block of QK_K values.
//   d     - float scale multiplied into every accumulated dot product
//   qs    - QK_K signed 8-bit quants
//   bsums - QK_K / 16 signed 16-bit entries; the kernels add adjacent pairs
//           and multiply them with the weight mins
//           (presumably per-16-element sums of qs - confirm against the
//           code that produces these blocks)
struct block_q8_K {
float d;
std::int8_t qs[QK_K];
std::int16_t bsums[QK_K / 16];
};
// Lock the byte layouts so the raw-offset and SIMD loads below stay valid:
//   block_q4_Kx8: 8*2 + 8*2 + 96 + 1024 = 1152 bytes (no padding)
//   block_q8_K:   4 + 256 + 16*2       =  292 bytes (no padding)
static_assert(sizeof(block_q4_Kx8) == 1152);
static_assert(sizeof(block_q8_K) == 292);
// Widens a raw half-precision bit pattern to a 32-bit float without relying
// on hardware conversion instructions.
//
// h: 16 bits laid out as 1 sign bit, 5 exponent bits, 10 mantissa bits.
// Returns the float with the same value. Signed zeros, subnormals,
// infinities and NaNs (including the mantissa payload) are all carried over.
static float fp16_to_fp32(ggml_half h) {
    const std::uint32_t raw = h;
    const std::uint32_t sign_bit = (raw & 0x8000u) << 16;
    const std::uint32_t exp_field = (raw >> 10) & 0x1fu;
    const std::uint32_t frac_field = raw & 0x03ffu;

    std::uint32_t encoded = 0;
    if (exp_field == 31) {
        // Infinity or NaN: all-ones float exponent, mantissa moved to the top bits.
        encoded = sign_bit | 0x7f800000u | (frac_field << 13);
    } else if (exp_field != 0) {
        // Normal number: rebias the exponent from 15 to 127.
        encoded = sign_bit | ((exp_field + (127 - 15)) << 23) | (frac_field << 13);
    } else if (frac_field == 0) {
        // Signed zero.
        encoded = sign_bit;
    } else {
        // Subnormal half: shift the mantissa left until the implicit leading
        // one reaches bit 10; every shift lowers the float exponent by one.
        std::uint32_t normalized = frac_field;
        std::uint32_t shifts = 0;
        do {
            normalized <<= 1;
            ++shifts;
        } while ((normalized & 0x0400u) == 0);
        const std::uint32_t exponent = (127 - 15) + 1 - shifts;
        encoded = sign_bit | (exponent << 23) | ((normalized & 0x03ffu) << 13);
    }

    float result;
    std::memcpy(&result, &encoded, sizeof(result));
    return result;
}
static void gemv_q4_K_8x8_q8_K_generic(
int n,
float * out,
const block_q4_Kx8 * q4,
const block_q8_K * q8
) {
const int nb = n / QK_K;
static const std::uint32_t kmask1 = 0x3f3f3f3f;
static const std::uint32_t kmask2 = 0x0f0f0f0f;
static const std::uint32_t kmask3 = 0x03030303;
float sumf[8] = {};
float sum_minf[8] = {};
std::uint32_t utmp[32];
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
std::memcpy(utmp + sb * 4, q4[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] =
((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const std::uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] =
(utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < QK_K / 16; k++) {
auto * scales_0 = reinterpret_cast<std::uint8_t *>(utmp) + (k / 4) * 32;
auto * scales_1 = reinterpret_cast<std::uint8_t *>(utmp) + (k / 4) * 32 + 16;
for (int j = 0; j < 8; j++) {
int sumi = 0;
for (int i = 0; i < 8; ++i) {
const int idx = k * 64 + j * 8 + i;
const int v0 = int(q4[l].qs[idx] & 0x0f);
const int v1 = int(q4[l].qs[idx] >> 4);
const int q8_base = (k >> 2) * 64 + (k % 4) * 8 + i;
sumi += v0 * int(q8[l].qs[q8_base]) * int(scales_0[j]);
sumi += v1 * int(q8[l].qs[q8_base + 32]) * int(scales_1[j]);
}
sumf[j] += float(sumi) * fp16_to_fp32(q4[l].d[j]) * q8[l].d;
}
}
for (int sb = 0; sb < 8; sb++) {
auto * mins = reinterpret_cast<std::uint8_t *>(utmp) + 8 + sb * 16;
const int q8sum = q8[l].bsums[sb * 2] + q8[l].bsums[sb * 2 + 1];
for (int j = 0; j < 8; j++) {
sum_minf[j] +=
float(mins[j] * q8sum) * fp16_to_fp32(q4[l].dmin[j]) * q8[l].d;
}
}
}
for (int j = 0; j < 8; j++) {
out[j] = sumf[j] - sum_minf[j];
}
}
#if defined(__AVX2__) && defined(__F16C__)
// Reads eight consecutive half-precision values starting at `x` (no alignment
// required) and converts them to eight single-precision floats.
static inline __m256 f32cx8_load(const ggml_half * x) {
    const __m128i packed_halves = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
    return _mm256_cvtph_ps(packed_halves);
}
// Same as an eight-wide half-to-float load, except that the 16 loaded bytes
// are first permuted with the byte-shuffle control `mask`, so the caller
// chooses the order in which the eight values appear in the result.
static inline __m256 f32cx8_rearrange_load(const ggml_half * x, __m128i mask) {
    const __m128i packed_halves = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
    const __m128i reordered_halves = _mm_shuffle_epi8(packed_halves, mask);
    return _mm256_cvtph_ps(reordered_halves);
}
// AVX2 + F16C kernel for one 8-lane output tile.
//
// n   - number of elements per row; only the n / QK_K whole blocks are used
// out - receives eight floats through an unaligned 256-bit store
// q4  - n / QK_K packed 4-bit weight blocks
// q8  - n / QK_K 8-bit activation blocks
//
// Integer partial sums are kept per block and folded into two float
// accumulators (dot products and mins correction); the stored result is
// their difference.
// NOTE(review): presumably meant to agree with gemv_q4_K_8x8_q8_K_generic up
// to float rounding (floats are accumulated once per block here) - confirm.
static void gemv_q4_K_8x8_q8_K_inner(
int n,
float * out,
const block_q4_Kx8 * q4,
const block_q8_K * q8
) {
const int nb = n / QK_K;
// Byte-wise masks used to unpack the 6-bit scales/mins from 12 packed bytes.
static const std::uint32_t kmask1 = 0x3f3f3f3f;
static const std::uint32_t kmask2 = 0x0f0f0f0f;
static const std::uint32_t kmask3 = 0x03030303;
// Byte shuffle that reorders the eight 16-bit lane scales to 0,4,1,5,2,6,3,7,
// the lane order in which the dot-product accumulator is produced below.
const __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
// Byte shuffle that duplicates each of the eight scale bytes, again in lane
// order 0,4,1,5,2,6,3,7, so one scale lines up with each pair of 16-bit sums.
const __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
// Gathers elements 0,2,4,6,1,3,5,7, i.e. restores lane order 0..7 at the end.
const __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
// Low-nibble mask for every byte.
const __m256i m4b = _mm256_set1_epi8(0x0f);
// Float accumulators: acc_row is in shuffled lane order, acc_min_rows in natural order.
__m256 acc_row = _mm256_setzero_ps();
__m256 acc_min_rows = _mm256_setzero_ps();
for (int b = 0; b < nb; b++) {
// Per-block float scales: activation scale broadcast, weight scales in
// shuffled lane order, weight min-scales in natural lane order.
const __m256 row_scale_f32 = _mm256_set1_ps(q8[b].d);
const __m256 col_scale_f32 = f32cx8_rearrange_load(q4[b].d, deltamask);
const __m256 col_dmin_f32 = f32cx8_load(q4[b].dmin);
// 32-bit integer accumulators for this block.
__m256i iacc_b = _mm256_setzero_si256();
__m256i iacc_min_b = _mm256_setzero_si256();
// Add adjacent bsums pairwise (16 entries -> 8 sums) and copy the eight
// sums into both 128-bit halves.
const __m256i q8sums = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q8[b].bsums));
__m256i q8s =
_mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
// Each iteration consumes 256 bytes of weight quants, 24 bytes of packed
// scales/mins and 64 activation quants.
for (int sb = 0; sb < QK_K / 64; sb++) {
const std::uint8_t * q = q4[b].qs + sb * 256;
// Raw weight bytes: every 64 bytes hold an 8-byte run for each of the
// eight lanes; "0123" is the first 32 bytes, "4567" the next 32.
const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q));
const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 32));
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 64));
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 96));
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 128));
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 160));
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 192));
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 224));
// Low nibbles (suffix _0x): paired with the first 32 activations.
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
// High nibbles (suffix _1x): paired with the second 32 activations. The
// shift is 16-bit wide, so the mask removes bits shifted in from the
// neighbouring byte.
const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
// Unpack two consecutive 12-byte groups. After unpacking, words 0-1 hold
// the eight lane scales and words 2-3 the eight lane mins.
std::uint32_t utmp_0[4], utmp_1[4];
std::memcpy(utmp_0, q4[b].scales + 24 * sb, 12);
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
const std::uint32_t uaux_0 = utmp_0[1] & kmask1;
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
utmp_0[2] = uaux_0;
utmp_0[0] &= kmask1;
std::memcpy(utmp_1, q4[b].scales + 12 + sb * 24, 12);
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
const std::uint32_t uaux_1 = utmp_1[1] & kmask1;
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
utmp_1[2] = uaux_1;
utmp_1[0] &= kmask1;
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
// Scales widened to 16 bits, each duplicated, in lane order 0,4,1,5,2,6,3,7.
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(mins_and_scales_0, scalemask));
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(mins_and_scales_1, scalemask));
// Shuffle 78 swaps the 64-bit halves so the mins land in the low half;
// unpacklo then interleaves the mins of both groups lane by lane
// (natural lane order 0..7), widened to 16 bits.
const __m256i mins_01 = _mm256_cvtepu8_epi16(
_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
// Load the 64 activations for this iteration in four 16-byte pieces and
// copy each piece into both 128-bit halves.
const std::int8_t * y = q8[b].qs + sb * 64;
__m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y)));
__m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 16)));
__m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 32)));
__m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 48)));
lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
__m256i iacc_0 = _mm256_setzero_si256();
__m256i iacc_1 = _mm256_setzero_si256();
// Low-nibble dot products. Shuffle 177 swaps adjacent 32-bit elements and
// blend 170 takes the odd elements from its second operand, so each
// blended vector holds four bytes per lane in order 0,4,1,5,2,6,3,7.
// The activation shuffle (0/85/170/255) broadcasts the matching four
// activation bytes. maddubs multiplies unsigned weights by signed
// activations and adds byte pairs into 16-bit sums; with nibbles <= 15
// the eight accumulated terms stay within the int16 range.
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00, _mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177), rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01, _mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177), rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02, _mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177), rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03, _mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177), rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
// Apply the lane scales and fold each pair of 16-bit sums into one 32-bit sum per lane.
iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
// Same sequence for the high nibbles against the second 32 activations.
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10, _mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177), rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11, _mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177), rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12, _mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177), rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13, _mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177), rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
const __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
// Broadcast the two pair-sums belonging to this iteration and multiply
// them with the interleaved mins: each 32-bit lane becomes
// sum0 * min0[lane] + sum1 * min1[lane].
const __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
const __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
// Shift the next two pair-sums into position for the following iteration.
q8s = _mm256_bsrli_epi128(q8s, 4);
iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
}
// Convert this block's integer sums to float and apply weight scale * activation scale.
const __m256 scale = _mm256_mul_ps(col_scale_f32, row_scale_f32);
const __m256 min_scale = _mm256_mul_ps(col_dmin_f32, row_scale_f32);
// Fused multiply-add when the compiler targets FMA, separate multiply and add otherwise.
#if defined(__FMA__)
acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_b), scale, acc_row);
acc_min_rows = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_min_b), min_scale, acc_min_rows);
#else
acc_row = _mm256_add_ps(acc_row, _mm256_mul_ps(_mm256_cvtepi32_ps(iacc_b), scale));
acc_min_rows = _mm256_add_ps(acc_min_rows, _mm256_mul_ps(_mm256_cvtepi32_ps(iacc_min_b), min_scale));
#endif
}
// Put the dot-product lanes back into natural order before subtracting the
// (already naturally ordered) mins correction.
acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
_mm256_storeu_ps(out, _mm256_sub_ps(acc_row, acc_min_rows));
}
#else
// Build without AVX2/F16C: the per-tile kernel simply forwards all of its
// arguments, unchanged, to the portable scalar implementation.
static void gemv_q4_K_8x8_q8_K_inner(int n, float * out, const block_q4_Kx8 * q4, const block_q8_K * q8) {
    gemv_q4_K_8x8_q8_K_generic(n, out, q4, q8);
}
#endif
extern "C" void ggml_gemv_q4_K_8x8_q8_K(
int n,
float * s,
std::size_t bs,
const void * vx,
const void * vy,
int nr,
int nc
) {
(void) bs;
const int nb = n / QK_K;
const auto * b_ptr_start = static_cast<const block_q4_Kx8 *>(vx);
const auto * a_ptr_start = static_cast<const block_q8_K *>(vy);
for (int y = 0; y < nr; y++) {
const block_q8_K * a_ptr = a_ptr_start + y * nb;
for (int x = 0; x < nc / 8; x++) {
const block_q4_Kx8 * b_ptr = b_ptr_start + x * nb;
gemv_q4_K_8x8_q8_K_inner(n, s + y * nc + x * 8, b_ptr, a_ptr);
}
}
}
02
jobs
Systems
02 jobs
03
counters
Performance counters
31 counters
cycles: 96,727,316
branch_instructions: 5,591,715
branch_misses: 29,397
cycle_activity.stalls_l1d_miss: 1,700,725
cycle_activity.stalls_l2_miss: 1,234,347
cycle_activity.stalls_l3_miss: 562,498
cycle_activity.stalls_total: 5,362,011
dtlb_load_misses.walk_completed: 1,746
exe_activity.bound_on_loads: 2,787,826
exe_activity.bound_on_stores: 768,416
instructions: 406,569,861
machine_clears: 5,922
mem_inst_retired.split_loads: 8,639,062
mem_load_retired.l1_miss: 763,987
mem_load_retired.l2_miss: 82,892
mem_load_retired.l3_miss: 5,780
tma_backend_bound: 136,132,272
tma_bad_speculation: 5,284,299
tma_branch_mispredict_slots: 4,614,400
tma_frontend_bound: 19,456,547
tma_memory_bound: 17,730,294
tma_retiring: 420,635,222
tma_slots: 581,475,846
uops_dispatched.port_0: 79,080,783
uops_dispatched.port_1: 82,910,402
uops_dispatched.port_2_3_10: 83,826,059
uops_dispatched.port_4_9: 19,087,336
uops_dispatched.port_5_11: 106,808,253
uops_dispatched.port_6: 36,027,127
uops_dispatched.port_7_8: 12,286,170
uops_retired.ms: 0
04
top down
Top-down analysis
Raptor Cove P-core
05
profile
load profile
03
counters
Performance counters
26 counters
cycles: 209,774,149
branch_instructions: 5,594,442
branch_misses: 30,320
dtlb_load_misses.walk_completed: 14,214
instructions: 406,594,128
mem_bound_stalls.load_dram_hit: 404,416
mem_bound_stalls.load_l2_hit: 12,752,247
mem_bound_stalls.load_llc_hit: 1,332,332
mem_inst_retired.split_loads: 236,775
mem_load_retired.l1_miss: 7,636,611
mem_load_retired.l2_miss: 59,306
mem_load_retired.l3_miss: 4,704
tma_backend_bound: 337,412,854
tma_backend_bound_alloc_restrictions: 920,949
tma_backend_bound_non_memory_scheduler: 306,437,245
tma_backend_bound_register: 8,748,115
tma_backend_bound_reorder_buffer: 3,325,220
tma_backend_bound_serialization: 10,187,636
tma_bad_speculation: 8,067,561
tma_bad_speculation_branch_mispredict: 7,341,605
tma_bad_speculation_machine_clears: 725,956
tma_frontend_bandwidth: 20,634,571
tma_frontend_bound: 24,518,511
tma_frontend_latency: 3,883,940
tma_memory_bound: 6,632,675
tma_retiring: 684,009,643
04
top down
Top-down analysis
Gracemont E-core
05
profile
load profile