Llama Q4_K Dot Product
Input is binary data containing packed llama.cpp-style Q4_K weight rows and Q8_K activation vectors.
For each case, compute the dot product of every Q4_K row with that case's Q8_K vector, following the structure of the ggml_vec_dot_q4_K_q8_K kernel. Sum all dot products and print the total with exactly 2 digits after the decimal point.
The input starts with the magic bytes CMQ4K001, followed by four little-endian u32 values: case_count, rows_per_case, blocks_per_row, and n (the element count passed to the kernel, which processes n / 256 blocks per row). Each case then stores block_q8_K[blocks_per_row], followed by block_q4_K[rows_per_case * blocks_per_row].
The packed block layouts match llama.cpp: block_q4_K is 144 bytes for 256 weights, and block_q8_K is 292 bytes for 256 activation values.
The starter keeps the hot kernel separate from the file wrapper so the C++ kernel body can be moved into llama.cpp with minimal changes.
Source format: https://github.com/ggml-org/llama.cpp.
use std::io::Read;
/// Number of quantized values held by one block (the same for Q4_K and Q8_K blocks).
const QK_K: usize = 256;
/// Size in bytes of one packed Q4_K block as read by the kernel:
/// f16 scale (2) + f16 min scale (2) + packed 6-bit scales/mins (12) + 4-bit quants (128).
const Q4_BLOCK_BYTES: usize = 144;
/// Size in bytes of one packed Q8_K block as read by the kernel:
/// f32 scale (4) + 256 signed 8-bit quants (256) + 16 little-endian i16 group sums (32).
const Q8_BLOCK_BYTES: usize = 292;
/// Reads a `CMQ4K001` container from standard input and prints the sum of all
/// row dot products with exactly two digits after the decimal point.
///
/// Exits with status 1 (printing nothing) when standard input cannot be read or
/// when the container is malformed: bad magic, short header, truncated case
/// data, or header counts whose byte sizes do not fit in `usize`.
fn main() {
    let mut input = Vec::new();
    if std::io::stdin().read_to_end(&mut input).is_err() {
        std::process::exit(1);
    }
    match sum_all_dot_products(&input) {
        Some(total) => print!("{total:.2}"),
        None => std::process::exit(1),
    }
}

/// Parses the container in `input` and returns the sum of every Q4_K row dotted
/// with its case's Q8_K vector, or `None` if the container is malformed.
///
/// Layout: 8 magic bytes, then little-endian `u32` values `case_count`,
/// `rows_per_case`, `blocks_per_row` and `n`; each case stores `blocks_per_row`
/// Q8_K blocks followed by `rows_per_case * blocks_per_row` Q4_K blocks.
/// Bytes after the last case are ignored.
///
/// NOTE(review): `n` is forwarded to the kernel unvalidated; the kernel indexes
/// `n / 256` blocks, so an `n` larger than the row can still panic there.
fn sum_all_dot_products(input: &[u8]) -> Option<f64> {
    if input.len() < 24 || &input[..8] != b"CMQ4K001" {
        return None;
    }
    let case_count = read_u32(&input[8..]) as usize;
    let rows_per_case = read_u32(&input[12..]) as usize;
    let blocks_per_row = read_u32(&input[16..]) as usize;
    let n = read_u32(&input[20..]) as usize;

    // The counts come from untrusted input, so every size is computed with
    // overflow checks before it is used to slice the buffer.
    let q8_bytes = blocks_per_row.checked_mul(Q8_BLOCK_BYTES)?;
    let row_bytes = blocks_per_row.checked_mul(Q4_BLOCK_BYTES)?;
    let q4_bytes = rows_per_case.checked_mul(row_bytes)?;

    let mut remaining = &input[24..];
    let mut total = 0.0_f64;
    for _ in 0..case_count {
        // `get` turns a truncated file into `None` instead of a panic.
        let q8 = remaining.get(..q8_bytes)?;
        let q4 = remaining.get(q8_bytes..)?.get(..q4_bytes)?;
        // Both lengths were just proven to fit inside `remaining`.
        remaining = &remaining[q8_bytes + q4_bytes..];

        for row in 0..rows_per_case {
            let row_start = row * row_bytes;
            let weights = &q4[row_start..row_start + row_bytes];
            total += f64::from(ggml_vec_dot_q4_k_q8_k(n, weights, q8));
        }
    }
    Some(total)
}
/// Dot product of one Q4_K-quantized weight row with one Q8_K-quantized
/// activation vector.
///
/// * `n`  - number of values in the row; the first `n / QK_K` blocks are used.
/// * `vx` - packed Q4_K blocks, `Q4_BLOCK_BYTES` bytes each.
/// * `vy` - packed Q8_K blocks, `Q8_BLOCK_BYTES` bytes each.
///
/// # Panics
///
/// Panics if either slice holds fewer than `n / QK_K` whole blocks.
fn ggml_vec_dot_q4_k_q8_k(n: usize, vx: &[u8], vy: &[u8]) -> f32 {
    let block_count = n / QK_K;
    // Eight partial sums are kept so the result is accumulated in the same
    // lane-wise order for every block, then reduced once at the end.
    let mut lane_totals = [0.0_f32; 8];
    let mut min_correction = 0.0_f32;

    for block in 0..block_count {
        let q4_block = &vx[block * Q4_BLOCK_BYTES..][..Q4_BLOCK_BYTES];
        let q8_block = &vy[block * Q8_BLOCK_BYTES..][..Q8_BLOCK_BYTES];

        // Q4_K: [d: f16][dmin: f16][12 bytes of scales/mins][128 bytes of nibbles].
        let (q4_header, q4_nibbles) = q4_block.split_at(16);
        // Q8_K: [d: f32][256 x i8 quants][16 x i16 group sums].
        let (q8_scale_bytes, q8_payload) = q8_block.split_at(4);
        let (q8_quants, q8_bsums) = q8_payload.split_at(QK_K);

        // Each 32-byte run expands to 64 weights: all low nibbles first,
        // followed by all high nibbles.
        let mut weights = [0_i8; QK_K];
        for (packed, unpacked) in q4_nibbles.chunks_exact(32).zip(weights.chunks_exact_mut(64)) {
            let (low_half, high_half) = unpacked.split_at_mut(32);
            for ((&byte, low), high) in packed.iter().zip(low_half).zip(high_half) {
                *low = (byte & 0x0f) as i8;
                *high = (byte >> 4) as i8;
            }
        }

        let (scales, mins) = unpack_scales(&q4_header[4..]);

        // Every 32-value sub-block owns one min; the Q8 sums cover 16 values
        // each, so two consecutive sums share a min.
        let min_dot: i32 = q8_bsums
            .chunks_exact(2)
            .enumerate()
            .map(|(j, pair)| {
                i32::from(i16::from_le_bytes([pair[0], pair[1]])) * i32::from(mins[j / 2])
            })
            .sum();

        let mut lane_dots = [0_i32; 8];
        for ((&scale, activations), sub_weights) in scales
            .iter()
            .zip(q8_quants.chunks_exact(32))
            .zip(weights.chunks_exact(32))
        {
            let scale = i32::from(scale);
            for (pos, (&activation, &weight)) in activations.iter().zip(sub_weights).enumerate() {
                lane_dots[pos % 8] += scale * i32::from(activation as i8) * i32::from(weight);
            }
        }

        let activation_scale = f32::from_le_bytes([
            q8_scale_bytes[0],
            q8_scale_bytes[1],
            q8_scale_bytes[2],
            q8_scale_bytes[3],
        ]);
        let d = f16_to_f32(u16::from_le_bytes([q4_header[0], q4_header[1]])) * activation_scale;
        for (lane_total, lane_dot) in lane_totals.iter_mut().zip(lane_dots) {
            *lane_total += d * lane_dot as f32;
        }
        let dmin = f16_to_f32(u16::from_le_bytes([q4_header[2], q4_header[3]])) * activation_scale;
        min_correction -= dmin * min_dot as f32;
    }

    // Start from the (negative) min correction and add the lanes in order.
    lane_totals
        .iter()
        .fold(min_correction, |acc, &lane| acc + lane)
}
/// Unpacks the 12-byte packed scale field of a Q4_K block into eight 6-bit
/// scales and eight 6-bit mins, returned as `(scales, mins)`.
///
/// Bytes 0..4 carry scales 0..4 in their low 6 bits and bytes 4..8 carry mins
/// 0..4 likewise. Entries 4..8 are split: their low 4 bits live in bytes 8..12
/// (scale in the low nibble, min in the high nibble) and their top 2 bits are
/// the top 2 bits of bytes 0..4 (scales) and 4..8 (mins).
///
/// # Panics
///
/// Panics if `packed` is shorter than 12 bytes.
fn unpack_scales(packed: &[u8]) -> ([u8; 8], [u8; 8]) {
    let mut scales = [0_u8; 8];
    let mut mins = [0_u8; 8];
    for j in 0..4 {
        let scale_byte = packed[j];
        let min_byte = packed[j + 4];
        let shared_byte = packed[j + 8];

        scales[j] = scale_byte & 0x3f;
        mins[j] = min_byte & 0x3f;
        scales[j + 4] = (shared_byte & 0x0f) | ((scale_byte >> 6) << 4);
        mins[j + 4] = (shared_byte >> 4) | ((min_byte >> 6) << 4);
    }
    (scales, mins)
}
/// Decodes the first four bytes of `bytes` as a little-endian `u32`.
///
/// # Panics
///
/// Panics if `bytes` holds fewer than four bytes.
fn read_u32(bytes: &[u8]) -> u32 {
    // Walk from the most significant byte (index 3) down to index 0.
    bytes[..4]
        .iter()
        .rev()
        .fold(0_u32, |value, &byte| (value << 8) | u32::from(byte))
}
/// Converts an IEEE 754 half-precision value, given as its raw bits, to `f32`.
///
/// Zeros, subnormals, infinities and NaNs are all handled; NaN payload bits are
/// carried over into the upper mantissa bits of the result.
fn f16_to_f32(bits: u16) -> f32 {
    let half = u32::from(bits);
    let sign_bit = (half & 0x8000) << 16;
    let exp_field = (half >> 10) & 0x1f;
    let mantissa = half & 0x03ff;

    let magnitude = match (exp_field, mantissa) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the mantissa until its leading 1 reaches bit 10
        // (the implicit bit), lowering the exponent once per shift.
        (0, m) => {
            let shift = m.leading_zeros() - 21;
            // Unbiased exponent is -14 - shift; re-biased for f32: 113 - shift.
            ((113 - shift) << 23) | (((m << shift) & 0x03ff) << 13)
        }
        // Infinity or NaN.
        (0x1f, m) => 0x7f80_0000 | (m << 13),
        // Normal number: re-bias the exponent from 15 to 127.
        (e, m) => ((e + 112) << 23) | (m << 13),
    };
    f32::from_bits(sign_bit | magnitude)
}