cpu.mode fastest code on the internet
solution

sol_1111340_1778574551675220261_15

Rust s7nfo 2 runs public
01 source
Submitted source 8000 bytes
Compiler rustc Flags --edition=2024 -O -C target-cpu=native -C target-feature=+crt-static
show source
#![no_main]

use std::arch::x86_64::*;
use std::ffi::c_void;
use std::os::raw::c_int;

/// Number of quantized elements per block; the dot-product kernel processes `n / QK_K` blocks.
const QK_K: usize = 256;
/// Byte stride between consecutive Q4 blocks. The kernel reads a 16-byte header
/// (two f16 values + 12 packed scale bytes) followed by 4 x 32 bytes of nibbles: 16 + 128 = 144.
const Q4_BLOCK_BYTES: usize = 144;
/// Byte stride between consecutive Q8 blocks. The kernel reads a 4-byte f32 scale,
/// `QK_K` quant bytes, and 32 bytes of 16-bit sums: 4 + 256 + 32 = 292.
const Q8_BLOCK_BYTES: usize = 292;

// File descriptors and flag values handed to the raw `mmap`/`read`/`write` calls.
// NOTE(review): the numeric values look like the Linux x86_64 ABI constants
// (e.g. MAP_ANONYMOUS = 0x20, MAP_POPULATE = 0x8000) — confirm before building for another target.
const STDIN_FD: c_int = 0;
const STDOUT_FD: c_int = 1;
const PROT_READ: c_int = 1;
const PROT_WRITE: c_int = 2;
const MAP_PRIVATE: c_int = 2;
const MAP_ANONYMOUS: c_int = 0x20;
const MAP_POPULATE: c_int = 0x8000;

// Hand-written declarations of the C functions this program calls directly for all of
// its I/O and for process exit.
// NOTE(review): the signatures presumably mirror the libc functions of the same names
// on x86_64 Linux (64-bit file offsets, `ssize_t` as `isize`) — confirm if the target changes.
unsafe extern "C" {
    // Used both to map stdin (when seekable) and to obtain an anonymous read buffer.
    fn mmap(addr: *mut c_void, length: usize, prot: c_int, flags: c_int,
            fd: c_int, offset: i64) -> *mut c_void;
    // Fills the anonymous buffer when stdin cannot be mapped.
    fn read(fd: c_int, buf: *mut c_void, count: usize) -> isize;
    // Emits the formatted result on stdout.
    fn write(fd: c_int, buf: *const c_void, count: usize) -> isize;
    // Probes the size of stdin and rewinds it.
    fn lseek(fd: c_int, offset: i64, whence: c_int) -> i64;
    // Terminates the process; declared as never returning.
    fn _exit(status: c_int) -> !;
}

// C-ABI entry point, exported unmangled because the crate is `#![no_main]`.
// Both arguments are ignored; all work happens in `run`.
#[unsafe(no_mangle)]
pub extern "C" fn main(_argc: c_int, _argv: *const *const u8) -> c_int {
    // NOTE(review): `run` is compiled with avx2/fma/f16c enabled and nothing here checks
    // that the CPU supports them — presumably guaranteed by building with
    // `-C target-cpu=native` on the machine that runs the binary; confirm.
    unsafe { run() }
}

/// Decodes a little-endian `u32` from the four bytes starting at `p`.
///
/// The byte order is fixed by `u32::from_le_bytes`, so the result does not depend on
/// the host's endianness.
///
/// # Safety
/// `p` must be valid for reading 4 bytes. No alignment is required.
#[inline(always)]
unsafe fn read_u32_le(p: *const u8) -> u32 {
    // SAFETY: the caller guarantees 4 readable bytes; `[u8; 4]` has alignment 1, so a
    // plain `read` is valid at any address.
    unsafe { u32::from_le_bytes(p.cast::<[u8; 4]>().read()) }
}

unsafe fn slurp_stdin() -> (*const u8, usize) {
    unsafe {
        let size = lseek(STDIN_FD, 0, 2);
        if size > 0 {
            let _ = lseek(STDIN_FD, 0, 0);
            let p = mmap(core::ptr::null_mut(), size as usize, PROT_READ,
                         MAP_PRIVATE | MAP_POPULATE, STDIN_FD, 0);
            if p as isize != -1 {
                return (p as *const u8, size as usize);
            }
        }
        // Fallback for non-seekable stdin.
        let cap: usize = 256 * 1024 * 1024;
        let p = mmap(core::ptr::null_mut(), cap, PROT_READ | PROT_WRITE,
                     MAP_PRIVATE | MAP_ANONYMOUS, -1, 0) as *mut u8;
        let mut len = 0usize;
        loop {
            let n = read(STDIN_FD, p.add(len) as *mut c_void, cap - len);
            if n <= 0 { break; }
            len += n as usize;
            if len == cap { break; }
        }
        (p as *const u8, len)
    }
}

// Byte-shuffle masks: eight rows of 32 bytes each, built at compile time.
// Row r (bytes r*32 .. r*32+32) repeats the index pair {2r, 2r+1} sixteen times:
//   r=0 -> {0,1, 0,1, ...}
//   r=1 -> {2,3, 2,3, ...}
//   ...
//   r=7 -> {14,15, 14,15, ...}
static K_SCALE_SHUFFLE_K4: [u8; 256] = {
    let mut table = [0u8; 256];
    let mut k = 0;
    while k < 256 {
        // k / 32 selects the row; the low bit of k selects the first or second
        // byte of that row's pair.
        table[k] = (2 * (k / 32) + (k & 1)) as u8;
        k += 1;
    }
    table
};

/// Loads row `i` (32 bytes) of `K_SCALE_SHUFFLE_K4` as a 256-bit shuffle mask.
///
/// # Safety
/// `i` must be in `0..8`; larger values read past the end of the 256-byte table.
/// The load uses an AVX intrinsic, so the CPU must support it.
#[inline(always)]
unsafe fn get_scale_shuffle_k4(i: usize) -> __m256i {
    let table = K_SCALE_SHUFFLE_K4.as_ptr();
    unsafe {
        let row = table.add(i * 32);
        // Unaligned load: the static is only byte-aligned.
        _mm256_loadu_si256(row as *const __m256i)
    }
}

/// Returns the sum of the eight `f32` lanes of `x`.
///
/// The reduction order is fixed: upper half + lower half, then lanes {2,3} onto
/// lanes {0,1}, then lane 1 onto lane 0.
///
/// # Safety
/// Uses AVX/SSE3 intrinsics; the CPU must support them.
#[inline(always)]
unsafe fn hsum_float_8(x: __m256) -> f32 {
    unsafe {
        // 8 lanes -> 4 lanes.
        let upper = _mm256_extractf128_ps::<1>(x);
        let lower = _mm256_castps256_ps128(x);
        let quad = _mm_add_ps(upper, lower);
        // 4 lanes -> 2 lanes (result of interest in lanes 0 and 1).
        let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
        // 2 lanes -> 1 lane: bring lane 1 down to lane 0 and add.
        let lane1 = _mm_movehdup_ps(pair);
        _mm_cvtss_f32(_mm_add_ss(pair, lane1))
    }
}

// Per-byte bit masks applied to four packed bytes at a time while unpacking the
// 12-byte scale/min field of a Q4 block.
/// Keeps the low 6 bits of every byte.
const KMASK1: u32 = 0x3f3f3f3f;
/// Keeps the low 4 bits of every byte.
const KMASK2: u32 = 0x0f0f0f0f;
/// Keeps the low 2 bits of every byte.
const KMASK3: u32 = 0x03030303;

/// Dot product of one row of 4-bit quantized blocks (`x`) with one row of 8-bit
/// quantized blocks (`y`), returned as `f32`.
///
/// * `n` — element count; `n / QK_K` block pairs are processed (any remainder is ignored).
/// * `x` — start of the Q4 row; blocks are `Q4_BLOCK_BYTES` apart. Per block this
///   function reads: bytes 0..4 as two f16 values (`d`, `dmin`), bytes 4..16 as packed
///   6-bit scales/mins, bytes 16..144 as 128 bytes of packed nibbles.
/// * `y` — start of the Q8 row; blocks are `Q8_BLOCK_BYTES` apart. Per block this
///   function reads: bytes 0..4 as a little-endian f32 scale, bytes 4..260 as signed
///   8-bit quants, bytes 260..292 as sixteen 16-bit sums.
///
/// # Safety
/// Both pointers must be readable for `n / QK_K` whole blocks, and the CPU must
/// support AVX2, FMA and F16C.
#[target_feature(enable = "avx2,fma,f16c")]
unsafe fn vec_dot_q4_k_q8_k(n: usize, x: *const u8, y: *const u8) -> f32 {
    unsafe {
        let nb = n / QK_K;
        // Mask selecting the low nibble of every byte.
        let m4 = _mm256_set1_epi8(0xf);

        // `acc` collects the scale * quant products, `acc_m` the (negated) min terms.
        let mut acc = _mm256_setzero_ps();
        let mut acc_m = _mm_setzero_ps();

        for ib in 0..nb {
            let xb = x.add(ib * Q4_BLOCK_BYTES);
            let yb = y.add(ib * Q8_BLOCK_BYTES);

            // d_x, dmin_x (half), y_d (float)
            // Load d & dmin halves as a 4-byte u32 and run F16C on both at once.
            let dd_bits = (xb as *const u32).read_unaligned();
            let dd = _mm_cvtph_ps(_mm_cvtsi32_si128(dd_bits as i32));
            let y_d = f32::from_le_bytes([*yb, *yb.add(1), *yb.add(2), *yb.add(3)]);
            // Lane 0 of `dd` is d; combined with the Q8 scale it scales the integer sum.
            let d = _mm_cvtss_f32(dd) * y_d;
            // Lane 1 of `dd` is dmin; negated so the min contribution is subtracted.
            let dmin_f = _mm_cvtss_f32(_mm_shuffle_ps::<0b01_01_01_01>(dd, dd));
            let dmin = -dmin_f * y_d;

            // scales: derive utmp[0..4] (Q4_K's scales/mins layout).
            // After the shuffling below every byte of utmp holds one 6-bit value:
            // utmp[0..2] are the eight scales, utmp[2..4] the eight mins.
            let mut utmp = [0u32; 4];
            let s = xb.add(4) as *const u8;
            // copy 12 bytes into utmp[0..3]
            core::ptr::copy_nonoverlapping(s, utmp.as_mut_ptr() as *mut u8, 12);
            // Statement order matters: each line reads words that later lines overwrite.
            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            let q4_ptr = xb.add(16);
            let q8_ptr = yb.add(4);

            // Zero-extend the 16 bytes to 16-bit lanes: low 128 bits = scales, high 128 bits = mins.
            let mins_and_scales = _mm256_cvtepu8_epi16(
                _mm_set_epi32(utmp[3] as i32, utmp[2] as i32, utmp[1] as i32, utmp[0] as i32));

            // bsums at offset (4 + QK_K) inside block_q8_K.
            let q8sums = _mm256_loadu_si256(yb.add(4 + QK_K) as *const __m256i);
            // Add adjacent pairs of the sixteen sums, giving eight sums — one per min.
            let q8s = _mm_hadd_epi16(_mm256_castsi256_si128(q8sums),
                                     _mm256_extracti128_si256::<1>(q8sums));
            // mins * sums, pairwise-added into four 32-bit lanes, then accumulated scaled by -dmin * y_d.
            let prod = _mm_madd_epi16(_mm256_extracti128_si256::<1>(mins_and_scales), q8s);
            acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);

            // Duplicate the eight scales into both 128-bit halves so the per-lane byte
            // shuffle below can broadcast any one of them across the whole register.
            let sc128 = _mm256_castsi256_si128(mins_and_scales);
            let scales = _mm256_insertf128_si256::<1>(_mm256_castsi128_si256(sc128), sc128);

            let mut sumi = _mm256_setzero_si256();
            let mut q4 = q4_ptr;
            let mut q8 = q8_ptr;

            // Each iteration consumes 32 Q4 bytes (64 nibbles) and 64 Q8 bytes.
            for j in 0..(QK_K / 64) {
                // Broadcast scale 2j (for the low nibbles) and scale 2j+1 (for the high nibbles).
                let scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j));
                let scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2 * j + 1));

                let q4bits = _mm256_loadu_si256(q4 as *const __m256i);
                q4 = q4.add(32);
                let q4l = _mm256_and_si256(q4bits, m4);
                let q4h = _mm256_and_si256(_mm256_srli_epi16::<4>(q4bits), m4);

                // Low nibbles pair with the first 32 Q8 bytes: unsigned x signed byte
                // products summed pairwise to 16 bit, then scaled and summed pairwise to 32 bit.
                let q8l = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let mut p16l = _mm256_maddubs_epi16(q4l, q8l);
                p16l = _mm256_madd_epi16(scale_l, p16l);

                // High nibbles pair with the next 32 Q8 bytes.
                let q8h = _mm256_loadu_si256(q8 as *const __m256i);
                q8 = q8.add(32);
                let mut p16h = _mm256_maddubs_epi16(q4h, q8h);
                p16h = _mm256_madd_epi16(scale_h, p16h);
                let sumj = _mm256_add_epi32(p16l, p16h);

                sumi = _mm256_add_epi32(sumi, sumj);
            }

            // Convert the block's integer sums to float and accumulate scaled by d * y_d.
            let vd = _mm256_set1_ps(d);
            acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
        }

        // Reduce the four min-term lanes to lane 0.
        acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
        acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));

        hsum_float_8(acc) + _mm_cvtss_f32(acc_m)
    }
}

/// Reads the benchmark input from stdin, evaluates every row dot product, and prints
/// the grand total.
///
/// Input layout as consumed here (integers are little-endian `u32`):
///
/// | offset | content                                                        |
/// |--------|----------------------------------------------------------------|
/// | 0      | 8-byte magic `CMQ4K001`                                        |
/// | 8      | `case_count`                                                   |
/// | 12     | `rows_per_case`                                                |
/// | 16     | `blocks_per_row`                                               |
/// | 20     | `n`, the element count handed to the kernel                    |
/// | 24     | per case: one Q8 row, then `rows_per_case` Q4 rows             |
///
/// The header is validated against the number of bytes actually read before any
/// payload byte is touched, so a truncated or inconsistent input ends the process
/// with status 1 instead of reading out of bounds.
///
/// The sum of all dot products is accumulated in `f64`, written to stdout with two
/// decimals (no trailing newline), and the process terminates through `_exit`:
/// status 0 on success, 1 on malformed input or a failed write. This function never
/// returns to its caller.
///
/// # Safety
/// The CPU must support AVX2, FMA and F16C.
#[target_feature(enable = "avx2,fma,f16c")]
unsafe fn run() -> c_int {
    const HEADER_BYTES: usize = 24;
    const MAGIC: &[u8; 8] = b"CMQ4K001";
    // Elements per quantized block. Must stay equal to the file-level `QK_K`, which
    // the kernel uses to turn `n` into a block count.
    const ELEMS_PER_BLOCK: usize = 256;

    unsafe {
        let (input, input_len) = slurp_stdin();
        if input_len < HEADER_BYTES { _exit(1); }
        // Magic check
        for i in 0..MAGIC.len() {
            if *input.add(i) != MAGIC[i] { _exit(1); }
        }

        let case_count    = read_u32_le(input.add(8))  as usize;
        let rows_per_case = read_u32_le(input.add(12)) as usize;
        let blocks_per_row= read_u32_le(input.add(16)) as usize;
        let n             = read_u32_le(input.add(20)) as usize;

        // A u32 count times a block size cannot overflow the 64-bit usize of x86_64.
        let row_bytes = blocks_per_row * Q4_BLOCK_BYTES;
        let q8_bytes = blocks_per_row * Q8_BLOCK_BYTES;

        // Total payload = case_count * (one Q8 row + rows_per_case Q4 rows). These
        // products can overflow, so they are computed with checked arithmetic.
        let payload_bytes = rows_per_case
            .checked_mul(row_bytes)
            .and_then(|q4_bytes| q4_bytes.checked_add(q8_bytes))
            .and_then(|case_bytes| case_bytes.checked_mul(case_count));
        let payload_fits =
            matches!(payload_bytes, Some(bytes) if bytes <= input_len - HEADER_BYTES);
        // The kernel walks n / ELEMS_PER_BLOCK blocks per row; more than
        // blocks_per_row would run past the end of the row.
        if !payload_fits || n / ELEMS_PER_BLOCK > blocks_per_row { _exit(1); }

        let mut total = 0.0_f64;
        let mut off = input.add(HEADER_BYTES);
        for _c in 0..case_count {
            // Each case: a single Q8 row shared by all of the case's Q4 rows.
            let q8 = off;
            off = off.add(q8_bytes);
            let q4 = off;
            off = off.add(rows_per_case * row_bytes);
            for row in 0..rows_per_case {
                let s = vec_dot_q4_k_q8_k(n, q4.add(row * row_bytes), q8);
                total += s as f64;
            }
        }

        let out = format!("{total:.2}");
        // write may accept fewer bytes than requested; keep going until all are sent.
        let mut sent = 0usize;
        while sent < out.len() {
            let written = write(STDOUT_FD, out.as_ptr().add(sent) as *const c_void,
                                out.len() - sent);
            if written <= 0 { _exit(1); }
            sent += written as usize;
        }
        _exit(0);
    }
}
02 jobs
Systems 02 jobs
03 counters
Performance counters 23 counters
cycles
15,511,728
Show more
branch_instructions
665,739
branch_misses
9,163
cycle_activity.stalls_l1d_miss
396,184
cycle_activity.stalls_l2_miss
179,556
cycle_activity.stalls_l3_miss
150,643
dtlb_load_misses.walk_completed
86
instructions
51,960,538
mem_inst_retired.split_loads
3,080,290
mem_load_retired.l1_miss
25,537
mem_load_retired.l2_miss
3,712
mem_load_retired.l3_miss
2,963
tma_backend_bound
25,114,254
tma_bad_speculation
1,046,573
tma_memory_bound
2,286,108
tma_retiring
59,455,846
uops_dispatched.port_0
12,845,547
uops_dispatched.port_1
13,217,466
uops_dispatched.port_2_3_10
10,216,658
uops_dispatched.port_4_9
158,601
uops_dispatched.port_5_11
16,213,061
uops_dispatched.port_6
4,895,894
uops_dispatched.port_7_8
153,868
04 top down
Top-down analysis Raptor Cove P-core
05 profile
load profile