Llama Q4_Kx8 GEMV
Same dot-product totals as llama_q4k_dot, but the Q4_K rows are delivered pre-packed in llama.cpp's block_q4_Kx8 layout (eight interleaved rows per packed block), so a GEMV-shaped kernel can be benchmarked without paying the repack cost.
The input starts with magic bytes CMQ4KX01, then little-endian u32 values: case_count, rows_per_case, blocks_per_row, and n. rows_per_case is guaranteed to be a multiple of 8.
Each case stores block_q8_K[blocks_per_row], followed by (rows_per_case / 8) * blocks_per_row packed block_q4_Kx8 structures, ordered row-group-major (all blocks_per_row blocks for rows 0..7, then for rows 8..15, ...).
block_q4_Kx8 is 1152 bytes: 8 u16 d values and 8 u16 dmin values (each u16 holds a half-precision float bit pattern), 96 bytes of repacked 6-bit scales/mins, and 1024 bytes of 4-bit quants packed in 8-byte chunks round-robin across the 8 source blocks.
Source format: https://github.com/ggml-org/llama.cpp ggml/src/ggml-cpu/arch/x86/repack.cpp.
use std::io::Read;
/// Number of quantized values covered by one block (one Q4_K block pairs with one Q8_K block).
const QK_K: usize = 256;
/// Size in bytes of one plain Q4_K block as read by this file:
/// 2-byte d, 2-byte dmin, 12 bytes of packed scales/mins, then QK_K / 2 bytes of 4-bit quants.
const Q4_BLOCK_BYTES: usize = 144;
/// Size in bytes of one Q8_K block as read by this file:
/// 4-byte f32 d, QK_K signed 8-bit quants, then QK_K / 16 little-endian i16 partial sums.
const Q8_BLOCK_BYTES: usize = 292;
/// Size in bytes of one packed block_q4_Kx8 (eight interleaved Q4_K blocks):
/// 16 bytes of d, 16 bytes of dmin, 96 bytes of scales/mins, 1024 bytes of quants.
const Q4_KX8_BYTES: usize = 1152;
/// Reads a `CMQ4KX01` benchmark file from stdin and prints the summed dot products.
///
/// After the 8-byte magic come four little-endian u32 header fields: `case_count`,
/// `rows_per_case`, `blocks_per_row` and `n`. Each case holds `blocks_per_row` Q8_K
/// blocks followed by `rows_per_case / 8` packed row groups. Every row group is
/// unpacked into eight plain Q4_K rows, each row is dotted against the case's Q8_K
/// data, and the f32 results are accumulated in an f64 total. The total is printed
/// with two decimals and no trailing newline.
///
/// Exits with status 1 when the input is shorter than the header or the magic does
/// not match. Panics if `rows_per_case` is not a multiple of 8 or the payload is
/// shorter than the header promises.
fn main() {
    let mut data = Vec::new();
    std::io::stdin().read_to_end(&mut data).unwrap();
    if data.len() < 24 || !data.starts_with(b"CMQ4KX01") {
        std::process::exit(1);
    }

    // Header fields are consecutive u32 values right after the magic.
    let header_field = |slot: usize| read_u32(&data[8 + 4 * slot..]) as usize;
    let case_count = header_field(0);
    let rows_per_case = header_field(1);
    let blocks_per_row = header_field(2);
    let n = header_field(3);
    assert!(rows_per_case % 8 == 0);

    let q8_len = blocks_per_row * Q8_BLOCK_BYTES;
    let group_len = blocks_per_row * Q4_KX8_BYTES;
    let groups_per_case = rows_per_case / 8;

    // One reusable buffer holding a single unpacked Q4_K row.
    let mut row_buf = vec![0_u8; blocks_per_row * Q4_BLOCK_BYTES];
    let mut cursor = 24;
    let mut total = 0.0_f64;

    for _case in 0..case_count {
        let activations = &data[cursor..cursor + q8_len];
        cursor += q8_len;

        for _group in 0..groups_per_case {
            let packed_group = &data[cursor..cursor + group_len];
            cursor += group_len;

            // Reference path: rebuild each of the eight rows as plain Q4_K and dot it
            // against the Q8_K data. A tuned kernel would read block_q4_Kx8 directly.
            for row in 0..8 {
                let sources = packed_group.chunks_exact(Q4_KX8_BYTES);
                let targets = row_buf.chunks_exact_mut(Q4_BLOCK_BYTES);
                for (src, dst) in sources.zip(targets) {
                    unpack_q4_kx8_row(src, row, dst);
                }
                total += f64::from(vec_dot_q4_k_q8_k(n, &row_buf, activations));
            }
        }
    }

    print!("{total:.2}");
}
/// Extracts row `row` (0..8) of one packed block_q4_Kx8 into `out` as a plain Q4_K block.
///
/// `packed` must hold at least 1152 bytes and `out` at least 144 bytes. The output
/// layout is 2-byte d, 2-byte dmin, 12 bytes of packed scales/mins, then 128 quant bytes.
fn unpack_q4_kx8_row(packed: &[u8], row: usize, out: &mut [u8]) {
    // Header: eight d values occupy bytes 0..16 and eight dmin values bytes 16..32,
    // two bytes per row.
    let d_at = 2 * row;
    let dmin_at = 16 + 2 * row;
    out[..2].copy_from_slice(&packed[d_at..d_at + 2]);
    out[2..4].copy_from_slice(&packed[dmin_at..dmin_at + 2]);

    // Scales: eight 12-byte records, one per sub-block. Within a record, bytes 0..4
    // carry the scales of rows 0..4 in their low six bits and the top two bits of the
    // scales of rows 4..8 in their high bits; bytes 4..8 do the same for mins; bytes
    // 8..12 carry the low nibbles (scale low, min high) for rows 4..8.
    let repacked = &packed[32..128];
    let mut scale_values = [0_u8; 8];
    let mut min_values = [0_u8; 8];
    for (sub, record) in repacked.chunks_exact(12).enumerate() {
        let (scale, min) = if row < 4 {
            (record[row] & 63, record[4 + row] & 63)
        } else {
            let k = row - 4;
            let nibbles = record[8 + k];
            (
                (nibbles & 0x0f) | ((record[k] >> 6) << 4),
                (nibbles >> 4) | ((record[4 + k] >> 6) << 4),
            )
        };
        scale_values[sub] = scale;
        min_values[sub] = min;
    }
    out[4..16].copy_from_slice(&pack_q4_k_scales(scale_values, min_values));

    // Quants: 8-byte chunks are dealt round-robin to the eight rows, so this row owns
    // source chunks row, row + 8, row + 16, ... in order.
    let qs_in = &packed[128..128 + 1024];
    let qs_out = &mut out[16..16 + QK_K / 2];
    let owned_chunks = qs_in.chunks_exact(8).skip(row).step_by(8);
    for (dst, src) in qs_out.chunks_exact_mut(8).zip(owned_chunks) {
        dst.copy_from_slice(src);
    }
}
/// Packs eight 6-bit scales and eight 6-bit mins into the 12-byte Q4_K scale field.
///
/// Inputs are masked to six bits. Bytes 0..4 hold scales 0..4 in their low six bits
/// and the top two bits of scales 4..8 in their high bits; bytes 4..8 do the same for
/// the mins; bytes 8..12 hold the low nibble of scales 4..8 (low half) and of mins
/// 4..8 (high half).
fn pack_q4_k_scales(scales: [u8; 8], mins: [u8; 8]) -> [u8; 12] {
    let mut packed = [0_u8; 12];
    for slot in 0..4 {
        let scale_lo = scales[slot] & 63;
        let min_lo = mins[slot] & 63;
        let scale_hi = scales[slot + 4] & 63;
        let min_hi = mins[slot + 4] & 63;

        packed[slot] = scale_lo | ((scale_hi >> 4) << 6);
        packed[slot + 4] = min_lo | ((min_hi >> 4) << 6);
        packed[slot + 8] = (scale_hi & 0x0f) | ((min_hi & 0x0f) << 4);
    }
    packed
}
/// Scalar dot product of a Q4_K row (`vx`) with a Q8_K row (`vy`) over `n` elements.
///
/// Processes `n / QK_K` block pairs. For each pair the 4-bit quants are expanded,
/// multiplied by the Q8_K quants and the per-sub-block scales into eight integer
/// lanes, and scaled by `d(Q4_K) * d(Q8_K)`. The mins contribution, computed from the
/// Q8_K partial sums and scaled by `dmin(Q4_K) * d(Q8_K)`, is subtracted. Panics if
/// either slice is shorter than the blocks being read.
fn vec_dot_q4_k_q8_k(n: usize, vx: &[u8], vy: &[u8]) -> f32 {
    let block_count = n / QK_K;
    let mut lane_sums = [0.0_f32; 8];
    let mut result = 0.0_f32;

    for block in 0..block_count {
        let q4 = &vx[block * Q4_BLOCK_BYTES..][..Q4_BLOCK_BYTES];
        let q8 = &vy[block * Q8_BLOCK_BYTES..][..Q8_BLOCK_BYTES];

        // Expand the nibbles: every 32 packed bytes yield 32 low nibbles followed by
        // 32 high nibbles.
        let mut nibbles = [0_i8; QK_K];
        for (packed32, expanded) in q4[16..].chunks_exact(32).zip(nibbles.chunks_exact_mut(64)) {
            let (low, high) = expanded.split_at_mut(32);
            for ((&byte, lo), hi) in packed32.iter().zip(low.iter_mut()).zip(high.iter_mut()) {
                *lo = (byte & 0x0f) as i8;
                *hi = (byte >> 4) as i8;
            }
        }

        let (scales, mins) = unpack_scales(&q4[4..16]);

        // Q8_K layout: 4-byte d, QK_K signed quants, then sixteen i16 partial sums.
        let (_, q8_payload) = q8.split_at(4);
        let (q8_quants, q8_bsums) = q8_payload.split_at(QK_K);

        // Each min covers two consecutive partial sums.
        let min_dot: i32 = q8_bsums
            .chunks_exact(2)
            .enumerate()
            .map(|(j, pair)| {
                i32::from(i16::from_le_bytes([pair[0], pair[1]])) * i32::from(mins[j / 2])
            })
            .sum();

        let mut lane_acc = [0_i32; 8];
        for (sub, scale) in scales.iter().enumerate() {
            let weight = i32::from(*scale);
            let ys = &q8_quants[sub * 32..][..32];
            let qs = &nibbles[sub * 32..][..32];
            for (lane, (&y, &q)) in ys.iter().zip(qs.iter()).enumerate() {
                lane_acc[lane % 8] += weight * i32::from(y as i8) * i32::from(q);
            }
        }

        let q8_d = f32::from_le_bytes([q8[0], q8[1], q8[2], q8[3]]);
        let d = f16_to_f32(u16::from_le_bytes([q4[0], q4[1]])) * q8_d;
        for (sum, acc) in lane_sums.iter_mut().zip(lane_acc.iter()) {
            *sum += d * *acc as f32;
        }
        let dmin = f16_to_f32(u16::from_le_bytes([q4[2], q4[3]])) * q8_d;
        result -= dmin * min_dot as f32;
    }

    // Fold the eight lanes in index order so the float rounding sequence is fixed.
    lane_sums.iter().fold(result, |acc, lane| acc + *lane)
}
/// Splits the 12-byte Q4_K scale field into eight 6-bit scales and eight 6-bit mins.
///
/// Entries 0..4 are the low six bits of bytes 0..4 (scales) and 4..8 (mins). Entries
/// 4..8 combine a nibble from bytes 8..12 (low nibble for the scale, high nibble for
/// the min) with the top two bits of the matching byte in 0..4 or 4..8. Panics if
/// `packed` holds fewer than 12 bytes.
fn unpack_scales(packed: &[u8]) -> ([u8; 8], [u8; 8]) {
    let mut scales = [0_u8; 8];
    let mut mins = [0_u8; 8];
    for slot in 0..4 {
        let nibbles = packed[slot + 8];
        let scale_byte = packed[slot];
        let min_byte = packed[slot + 4];

        scales[slot] = scale_byte & 63;
        mins[slot] = min_byte & 63;
        scales[slot + 4] = (nibbles & 0x0f) | ((scale_byte >> 6) << 4);
        mins[slot + 4] = (nibbles >> 4) | ((min_byte >> 6) << 4);
    }
    (scales, mins)
}
/// Decodes the first four bytes of `bytes` as a little-endian u32.
///
/// Any bytes after the fourth are ignored; panics if fewer than four are supplied.
fn read_u32(bytes: &[u8]) -> u32 {
    let mut word = [0_u8; 4];
    word.copy_from_slice(&bytes[..4]);
    u32::from_le_bytes(word)
}
/// Widens an IEEE 754 half-precision bit pattern to an `f32`.
///
/// Handles signed zero, subnormals (renormalized into the f32 exponent range),
/// infinities and NaNs (payload bits shifted into the f32 mantissa), and normal values.
fn f16_to_f32(bits: u16) -> f32 {
    let half = u32::from(bits);
    let sign = (half & 0x8000) << 16;
    let exponent = (half >> 10) & 0x1f;
    let mantissa = half & 0x03ff;

    let magnitude = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the leading one up to bit 10, lowering the exponent by one
        // per shift. `leading_zeros() - 21` is exactly that shift count for a non-zero
        // 10-bit mantissa; the unshifted exponent is -14, i.e. 113 once biased for f32.
        (0, _) => {
            let shift = mantissa.leading_zeros() - 21;
            let normalized = (mantissa << shift) & 0x03ff;
            ((113 - shift) << 23) | (normalized << 13)
        }
        // Infinity or NaN.
        (31, _) => 0x7f80_0000 | (mantissa << 13),
        // Normal number: rebias the exponent from 15 to 127.
        _ => ((exponent + 112) << 23) | (mantissa << 13),
    };
    f32::from_bits(sign | magnitude)
}