Llama.cpp's ggml_gemv_q4_K_8x8_q8_K
Implement llama.cpp's ggml_gemv_q4_K_8x8_q8_K function. The judge builds your source as libsolution.so and calls the exported function from a verifier executable.
This kernel multiplies a repacked Q4_K weight matrix by a Q8_K activation vector, producing floating-point dot-product results for one or more output columns. It is one of llama.cpp's CPU decode hot paths for quantized models: the weights are stored as 4-bit grouped blocks with scales, the runtime activations are quantized to Q8_K blocks, and the function has to unpack, scale, multiply, and accumulate those blocks fast enough to matter for real token generation.
The required symbol is extern "C" void ggml_gemv_q4_K_8x8_q8_K(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc).
The verifier supplies block_q4_Kx8 and block_q8_K buffers in llama.cpp layout and compares the printed aggregate against the scalar reference with a small floating-point tolerance.
use std::ffi::c_void;
/// Number of quantized values in one super-block. It sizes `BlockQ8K::qs` and is the
/// divisor that turns the element count `n` into a number of blocks per row.
const QK_K: usize = 256;
/// Weight data for one super-block position of eight interleaved output columns.
///
/// `#[repr(C)]` matters here: the kernel entry point reinterprets the caller's raw
/// weight buffer as a slice of this type, so field order and sizes must not change.
#[repr(C)]
struct BlockQ4Kx8 {
/// Per-column block scale as raw half-precision bits; decoded with `f16_to_f32` and
/// multiplied into the integer dot product.
d: [u16; 8],
/// Per-column scale for the mins, also raw half-precision bits; decoded with
/// `f16_to_f32` and multiplied into the min-correction term.
dmin: [u16; 8],
/// Packed 6-bit scales and mins, read by the kernel as eight 12-byte groups.
scales: [u8; 96],
/// 4-bit quants, two per byte (low nibble and high nibble), interleaved across the
/// eight columns in runs of eight bytes.
qs: [u8; 1024],
}
/// One activation super-block holding `QK_K` signed 8-bit quants.
///
/// `#[repr(C)]` matters here: the kernel entry point reinterprets the caller's raw
/// activation buffer as a slice of this type.
#[repr(C)]
struct BlockQ8K {
/// Block scale; multiplied into every partial result computed from this block.
d: f32,
/// Quantized activation values.
qs: [i8; QK_K],
/// One entry per 16 quants; the kernel adds adjacent pairs and uses the total in the
/// min-correction term. Presumably each entry is the sum of the matching 16 `qs`
/// values — TODO confirm against the code that fills these blocks.
bsums: [i16; QK_K / 16],
}
/// C ABI entry point: multiplies repacked Q4_K weights by Q8_K activations.
///
/// * `n`  – number of elements per row; `n / QK_K` blocks are processed per row.
/// * `s`  – output buffer of `nr * nc` floats, written row-major with a stride of `nc`.
/// * `_bs` – accepted for signature compatibility and not used.
/// * `vx` – `(nc / 8) * (n / QK_K)` contiguous `BlockQ4Kx8` values.
/// * `vy` – `nr * (n / QK_K)` contiguous `BlockQ8K` values.
/// * `nr` – number of activation rows.
/// * `nc` – number of output columns; only full groups of eight are computed.
///
/// The call is a no-op when any dimension is negative, when there is no output to
/// write, or when a buffer that would actually be accessed is null.
///
/// # Safety
///
/// For every buffer with a non-zero required length (see above) the caller must pass a
/// pointer that is valid for that many elements, suitably aligned for the element type
/// (`f32`, `BlockQ4Kx8`, `BlockQ8K`), and whose total size does not exceed `isize::MAX`
/// bytes. `s` must not overlap `vx` or `vy`, and no other code may access `s` for the
/// duration of the call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn ggml_gemv_q4_K_8x8_q8_K(
    n: i32,
    s: *mut f32,
    _bs: usize,
    vx: *const c_void,
    vy: *const c_void,
    nr: i32,
    nc: i32,
) {
    // A negative count would wrap to an enormous `usize` in the casts below and turn
    // the slice constructors into undefined behaviour, so treat it as "nothing to do".
    if n < 0 || nr < 0 || nc < 0 {
        return;
    }
    let nb = n as usize / QK_K;
    let nr = nr as usize;
    let nc = nc as usize;

    let out_len = nr * nc;
    if out_len == 0 || s.is_null() {
        return;
    }
    let q4_len = (nc / 8) * nb;
    let q8_len = nr * nb;

    // `from_raw_parts` rejects null even for empty slices, so empty inputs are built
    // without touching the caller's pointer at all.
    let q4: &[BlockQ4Kx8] = if q4_len == 0 {
        &[]
    } else if vx.is_null() {
        return
    } else {
        // SAFETY: `vx` is non-null and, per this function's contract, valid and aligned
        // for `q4_len` weight blocks.
        unsafe { std::slice::from_raw_parts(vx.cast::<BlockQ4Kx8>(), q4_len) }
    };
    let q8: &[BlockQ8K] = if q8_len == 0 {
        &[]
    } else if vy.is_null() {
        return
    } else {
        // SAFETY: `vy` is non-null and, per this function's contract, valid and aligned
        // for `q8_len` activation blocks.
        unsafe { std::slice::from_raw_parts(vy.cast::<BlockQ8K>(), q8_len) }
    };
    // SAFETY: `s` is non-null and, per this function's contract, valid, aligned and
    // exclusively ours for `out_len` floats.
    let out = unsafe { std::slice::from_raw_parts_mut(s, out_len) };

    reference_gemv(nb, nr, nc, q4, q8, out);
}
fn reference_gemv(
nb: usize,
nr: usize,
nc: usize,
q4: &[BlockQ4Kx8],
q8: &[BlockQ8K],
out: &mut [f32],
) {
let kmask1 = 0x3f3f3f3f_u32;
let kmask2 = 0x0f0f0f0f_u32;
let kmask3 = 0x03030303_u32;
for y in 0..nr {
let a_ptr = &q8[y * nb..];
for x in 0..nc / 8 {
let b_ptr = &q4[x * nb..];
let mut sumf = [0.0_f32; 8];
let mut sum_minf = [0.0_f32; 8];
let mut utmp = [0_u32; 32];
for l in 0..nb {
for sb in 0..8 {
let src = sb * 12;
let dst = sb * 4;
utmp[dst] = read_u32(&b_ptr[l].scales[src..src + 4]);
utmp[dst + 1] = read_u32(&b_ptr[l].scales[src + 4..src + 8]);
utmp[dst + 2] = read_u32(&b_ptr[l].scales[src + 8..src + 12]);
utmp[dst + 3] =
((utmp[dst + 2] >> 4) & kmask2) | (((utmp[dst + 1] >> 6) & kmask3) << 4);
let uaux = utmp[dst + 1] & kmask1;
utmp[dst + 1] =
(utmp[dst + 2] & kmask2) | (((utmp[dst] >> 6) & kmask3) << 4);
utmp[dst + 2] = uaux;
utmp[dst] &= kmask1;
}
for k in 0..QK_K / 16 {
let scales_0 = (k / 4) * 32;
let scales_1 = scales_0 + 16;
for j in 0..8 {
let mut sumi = 0_i32;
for i in 0..8 {
let idx = k * 64 + j * 8 + i;
let v0 = i32::from(b_ptr[l].qs[idx] & 0x0f);
let v1 = i32::from(b_ptr[l].qs[idx] >> 4);
let q8_base = (k >> 2) * 64 + (k & 3) * 8 + i;
sumi += v0
* i32::from(a_ptr[l].qs[q8_base])
* i32::from(utmp_byte(&utmp, scales_0 + j));
sumi += v1
* i32::from(a_ptr[l].qs[q8_base + 32])
* i32::from(utmp_byte(&utmp, scales_1 + j));
}
sumf[j] +=
sumi as f32 * f16_to_f32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
for sb in 0..8 {
let mins = 8 + sb * 16;
let q8sum = i32::from(a_ptr[l].bsums[sb * 2])
+ i32::from(a_ptr[l].bsums[sb * 2 + 1]);
for j in 0..8 {
sum_minf[j] += f32::from(utmp_byte(&utmp, mins + j))
* q8sum as f32
* f16_to_f32(b_ptr[l].dmin[j])
* a_ptr[l].d;
}
}
}
for j in 0..8 {
out[y * nc + x * 8 + j] = sumf[j] - sum_minf[j];
}
}
}
}
/// Converts raw IEEE 754 half-precision bits to the `f32` with the same value.
///
/// Zeros keep their sign, subnormals are renormalised, and infinities and NaNs keep
/// their fraction bits (shifted into the wider fraction field).
fn f16_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits & 0x8000) << 16;
    let exp = u32::from((bits >> 10) & 0x1f);
    let frac = u32::from(bits & 0x03ff);

    let magnitude = match (exp, frac) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: move the leading one up to bit 10 (the implicit-one position)
        // and lower the exponent by the same number of steps.
        (0, _) => {
            let shift = frac.leading_zeros() - 21;
            let mant = (frac << shift) & 0x03ff;
            ((113 - shift) << 23) | (mant << 13)
        }
        // Infinity or NaN.
        (31, _) => 0x7f80_0000 | (frac << 13),
        // Normal number: rebias the exponent from 15 to 127.
        _ => ((exp + 112) << 23) | (frac << 13),
    };
    f32::from_bits(sign | magnitude)
}
/// Decodes a little-endian `u32` from a slice of exactly four bytes.
///
/// # Panics
///
/// Panics if `bytes` does not contain exactly four bytes.
fn read_u32(bytes: &[u8]) -> u32 {
    let quad: [u8; 4] = bytes.try_into().expect("slice length");
    u32::from_le_bytes(quad)
}
/// Returns byte `byte_index` of `words` when the words are viewed as 128 consecutive
/// bytes, each word contributing its least-significant byte first.
///
/// # Panics
///
/// Panics if `byte_index` is 128 or greater.
fn utmp_byte(words: &[u32; 32], byte_index: usize) -> u8 {
    let word = words[byte_index / 4];
    word.to_le_bytes()[byte_index % 4]
}