cpu.mode fastest code on the internet
solution

sol_3159247_1781243189660930107_0

C++ | llama.cpp reference | 2 runs
01 source
Submitted source: 16688 bytes
Compiler: clang++ | Flags: -O3 -march=native -std=c++20
show source
// Adapted from ggml-org/llama.cpp ggml-cpu Q4_Kx8 GEMV code.
// Upstream license: MIT, copyright (c) 2023-2026 The ggml authors.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#if defined(__AVX2__)
#include <immintrin.h>
#endif

// Number of quantized values per super-block; block_q8_K::qs and the per-row block counts
// used by the kernels are sized in units of QK_K.
static constexpr int QK_K = 256;

// One Q8_K super-block as stored in the input file (layout is read in place via
// reinterpret_cast, so it must stay byte-exact; sizeof == 292 is asserted below).
//   d     - float scale the kernels multiply into every result of this super-block
//   qs    - QK_K signed 8-bit quantized values
//   bsums - QK_K / 16 int16 entries; the kernels use bsums[2*sb] + bsums[2*sb+1] as the
//           per-sub-block sum for the min correction (presumably each entry is the sum of
//           16 consecutive qs values - TODO confirm against the producer)
struct block_q8_K {
    float d;
    std::int8_t qs[QK_K];
    std::int16_t bsums[QK_K / 16];
};

// Eight Q4_K super-blocks interleaved together (index j = 0..7, one per output value out[j]
// of the GEMV kernels). Read in place from the input file; sizeof == 1152 is asserted below.
//   d, dmin - fp16 bit patterns (decoded with half_to_float / F16C): per-j super-block scale
//             and min scale
//   scales  - 8 sub-blocks x 12 bytes of packed 6-bit scales and mins (unpacked by the kernels)
//   qs      - 4-bit values, two per byte; within each 64-byte group k, entry j owns bytes
//             k*64 + j*8 .. k*64 + j*8 + 7 (low nibble and high nibble are separate values)
struct block_q4_Kx8 {
    std::uint16_t d[8];
    std::uint16_t dmin[8];
    std::uint8_t scales[96];
    std::uint8_t qs[1024];
};

static_assert(sizeof(block_q8_K) == 292);
static_assert(sizeof(block_q4_Kx8) == 1152);

// Converts an IEEE 754 binary16 bit pattern to float. Handles signed zeros, subnormals
// (renormalized, since every half subnormal is a normal float), infinities and NaNs
// (the half mantissa is carried into the top mantissa bits of the float).
static float half_to_float(std::uint16_t bits) {
    const std::uint32_t h = bits;
    const std::uint32_t sign_bit = (h & 0x8000u) << 16;
    const int biased_exp = static_cast<int>((h >> 10) & 0x1fu);
    std::uint32_t mantissa = h & 0x03ffu;

    std::uint32_t word = sign_bit;
    if (biased_exp == 31) {
        // Inf / NaN: all-ones float exponent, mantissa carried over.
        word |= 0x7f800000u | (mantissa << 13);
    } else if (biased_exp != 0) {
        // Normal number: rebias the exponent from 15 to 127.
        word |= (static_cast<std::uint32_t>(biased_exp + (127 - 15)) << 23) | (mantissa << 13);
    } else if (mantissa != 0) {
        // Subnormal half: shift until the implicit leading bit (bit 10) appears,
        // lowering the exponent once per shift.
        int exponent = -14;
        do {
            mantissa <<= 1;
            --exponent;
        } while ((mantissa & 0x0400u) == 0);
        word |= (static_cast<std::uint32_t>(exponent + 127) << 23) | ((mantissa & 0x03ffu) << 13);
    }
    // Remaining case (exponent and mantissa both zero): signed zero, i.e. just the sign bit.

    float value;
    std::memcpy(&value, &word, sizeof(value));
    return value;
}

// Decodes a little-endian 32-bit unsigned integer from the 4 bytes at p (no alignment
// requirement, independent of host byte order).
static std::uint32_t read_u32(const unsigned char * p) {
    std::uint32_t word = 0;
    // Walk from the most significant byte (p[3]) down so each step is a shift-and-or.
    for (int i = 3; i >= 0; --i) {
        word = (word << 8) | p[i];
    }
    return word;
}

// Read-only view of the program input. Either `data` is an mmap of stdin (mapped == true;
// unmapped in the destructor), or it points into the heap copy held in `fallback`.
//
// Because the object owns a mapping and `data` may alias its own member, it is move-only.
// A memberwise copy would leave two objects munmapping the same region. It could also leave
// `data` pointing at the source's (soon destroyed) vector. Moving transfers ownership. A
// vector's move constructor keeps its heap buffer, so `data` stays valid in the new object.
struct InputView {
    const unsigned char * data = nullptr;
    std::size_t size = 0;
    std::vector<unsigned char> fallback;
    bool mapped = false;

    InputView() = default;
    InputView(const InputView &) = delete;
    InputView & operator=(const InputView &) = delete;
    InputView & operator=(InputView &&) = delete;

    InputView(InputView && other) noexcept
        : data(std::exchange(other.data, nullptr)),
          size(std::exchange(other.size, 0)),
          fallback(std::move(other.fallback)),
          mapped(std::exchange(other.mapped, false)) {}

    ~InputView() {
        if (mapped && data != MAP_FAILED) {
            munmap(const_cast<unsigned char *>(data), size);
        }
    }
};

// Returns a read-only view of all of stdin.
// A non-empty regular file is memory-mapped (PROT_READ, MAP_PRIVATE), so there is no copy.
// Anything else (pipes, terminals, or a failed mmap) is read with fread into
// input.fallback, which then backs input.data. A read error simply ends the data early;
// callers must validate input.size.
static InputView read_input() {
    InputView input;
    struct stat st {};
    if (fstat(STDIN_FILENO, &st) == 0 && S_ISREG(st.st_mode) && st.st_size > 0) {
        void * mapped = mmap(nullptr, static_cast<std::size_t>(st.st_size), PROT_READ, MAP_PRIVATE, STDIN_FILENO, 0);
        if (mapped != MAP_FAILED) {
            input.data = static_cast<const unsigned char *>(mapped);
            input.size = static_cast<std::size_t>(st.st_size);
            input.mapped = true;
            return input;
        }
    }

    // Read straight into the heap buffer in 1 MiB steps. This avoids a zero-filled 1 MiB
    // stack array and a second copy out of it. vector growth stays geometric, so total
    // copying is linear.
    constexpr std::size_t chunk = std::size_t{1} << 20;
    std::size_t used = 0;
    while (true) {
        input.fallback.resize(used + chunk);
        const std::size_t got = std::fread(input.fallback.data() + used, 1, chunk, stdin);
        used += got;
        if (got < chunk) {
            break;  // EOF or read error
        }
    }
    input.fallback.resize(used);
    input.data = input.fallback.data();
    input.size = input.fallback.size();
    return input;
}

static void gemv_q4_K_8x8_q8_K_generic(int n, float * out, const block_q4_Kx8 * q4, const block_q8_K * q8) {
    const int nb = n / QK_K;
    static const std::uint32_t kmask1 = 0x3f3f3f3f;
    static const std::uint32_t kmask2 = 0x0f0f0f0f;
    static const std::uint32_t kmask3 = 0x03030303;
    float sumf[8] = {};
    float sum_minf[8] = {};
    std::uint32_t utmp[32];

    for (int l = 0; l < nb; ++l) {
        for (int sb = 0; sb < 8; ++sb) {
            std::memcpy(utmp + sb * 4, q4[l].scales + sb * 12, 12);
            utmp[sb * 4 + 3] =
                ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
            const std::uint32_t uaux = utmp[sb * 4 + 1] & kmask1;
            utmp[sb * 4 + 1] =
                (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4] >> 6) & kmask3) << 4);
            utmp[sb * 4 + 2] = uaux;
            utmp[sb * 4] &= kmask1;
        }

        for (int k = 0; k < QK_K / 16; ++k) {
            std::uint8_t * scales_0 = reinterpret_cast<std::uint8_t *>(utmp) + (k / 4) * 32;
            std::uint8_t * scales_1 = reinterpret_cast<std::uint8_t *>(utmp) + (k / 4) * 32 + 16;
            for (int j = 0; j < 8; ++j) {
                int sumi = 0;
                for (int i = 0; i < 8; ++i) {
                    const int q = q4[l].qs[k * 64 + j * 8 + i];
                    const int v0 = q & 0x0f;
                    const int v1 = q >> 4;
                    const int y0 = q8[l].qs[(k >> 2) * 64 + (k & 3) * 8 + i];
                    const int y1 = q8[l].qs[(k >> 2) * 64 + (k & 3) * 8 + i + 32];
                    sumi += v0 * y0 * scales_0[j] + v1 * y1 * scales_1[j];
                }
                sumf[j] += static_cast<float>(sumi) * half_to_float(q4[l].d[j]) * q8[l].d;
            }
        }

        for (int sb = 0; sb < 8; ++sb) {
            std::uint8_t * mins = reinterpret_cast<std::uint8_t *>(utmp) + 8 + sb * 16;
            const int bsum = q8[l].bsums[sb * 2] + q8[l].bsums[sb * 2 + 1];
            for (int j = 0; j < 8; ++j) {
                sum_minf[j] += mins[j] * bsum * half_to_float(q4[l].dmin[j]) * q8[l].d;
            }
        }
    }

    for (int j = 0; j < 8; ++j) {
        out[j] = sumf[j] - sum_minf[j];
    }
}

#if defined(__AVX2__)
// Loads eight consecutive IEEE binary16 values from x (no alignment required) and widens
// them to eight floats.
static inline __m256 f32cx8_load(const std::uint16_t * x) {
    const __m128i halves = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
    return _mm256_cvtph_ps(halves);
}

// Like f32cx8_load, but first permutes the 16 raw bytes with the pshufb control `mask`.
// This lets the caller reorder the eight fp16 lanes before they are widened to floats.
static inline __m256 f32cx8_rearrange_load(const std::uint16_t * x, __m128i mask) {
    const __m128i halves = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
    const __m128i reordered = _mm_shuffle_epi8(halves, mask);
    return _mm256_cvtph_ps(reordered);
}

// AVX2 + F16C Q4_Kx8 x Q8_K GEMV, with the same contract as gemv_q4_K_8x8_q8_K_generic.
// It reads n / QK_K super-blocks from q4 and q8 and writes out[0..7] with one unaligned
// 256-bit store. When F16C is not enabled at compile time it forwards to the portable kernel.
//
// Lane layout: the integer accumulators (iacc_0/iacc_1/iacc_sb/iacc_b) and acc_row hold the
// eight interleaved entries in lane order 0,4,1,5,2,6,3,7. That order comes from the
// blend/shuffle pairing of the 0123 and 4567 halves below. deltamask and scalemask are
// chosen to match it, and finalpermutemask undoes it before the store. The min-correction
// path (mins_01, iacc_min_b, acc_min_rows) stays in natural order 0..7.
//
// NOTE(review): float rounding differs from the generic kernel. Here each super-block's
// integer total is scaled once by d_col * d_row, while the generic kernel scales every k
// step. The two paths are therefore not expected to be bit-identical.
static void gemv_q4_K_8x8_q8_K(int n, float * out, const block_q4_Kx8 * q4, const block_q8_K * q8) {
#if defined(__F16C__)
    const int nb = n / QK_K;
    static const std::uint32_t kmask1 = 0x3f3f3f3f;
    static const std::uint32_t kmask2 = 0x0f0f0f0f;
    static const std::uint32_t kmask3 = 0x03030303;
    // pshufb control: reorders the eight fp16 d[] values to lane order 0,4,1,5,2,6,3,7.
    const __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
    // pshufb control: picks the scale bytes of entries 0,4,1,5,2,6,3,7, each duplicated.
    // After widening to int16, _mm256_madd_epi16 then multiplies both int16 partial sums of
    // a 32-bit lane by the same scale.
    const __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
    // Gathers lanes 0,2,4,6,1,3,5,7, turning order 0,4,1,5,2,6,3,7 back into 0..7.
    const __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
    const __m256i m4b = _mm256_set1_epi8(0x0f);
    __m256 acc_row = _mm256_setzero_ps();
    __m256 acc_min_rows = _mm256_setzero_ps();

    for (int b = 0; b < nb; ++b) {
        const __m256 row_scale_f32 = _mm256_set1_ps(q8[b].d);
        const __m256 col_scale_f32 = f32cx8_rearrange_load(q4[b].d, deltamask);
        const __m256 col_dmin_f32 = f32cx8_load(q4[b].dmin);
        __m256i iacc_b = _mm256_setzero_si256();
        __m256i iacc_min_b = _mm256_setzero_si256();
        // Sum adjacent bsums pairwise. The result is the same per-sub-block value
        // bsums[2*sb] + bsums[2*sb+1] the generic kernel uses, for 8 sub-blocks. Then copy
        // that 128-bit half into both halves of q8s.
        const __m256i q8sums = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q8[b].bsums));
        __m256i q8s =
            _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
        q8s = _mm256_permute2f128_si256(q8s, q8s, 0);

        // Each iteration handles sub-block 2*sb (low nibbles, against q8 qs[sb*64 .. +31])
        // and sub-block 2*sb+1 (high nibbles, against qs[sb*64+32 .. +63]).
        for (int sb = 0; sb < QK_K / 64; ++sb) {
            // 256 packed bytes = four 64-byte groups. In each group, entry j owns bytes
            // j*8 .. j*8+7, so *_0123_* holds entries 0-3 and *_4567_* holds entries 4-7.
            const std::uint8_t * q = q4[b].qs + sb * 256;
            const __m256i raw_0123_0 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q));
            const __m256i raw_4567_0 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 32));
            const __m256i raw_0123_1 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 64));
            const __m256i raw_4567_1 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 96));
            const __m256i raw_0123_2 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 128));
            const __m256i raw_4567_2 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 160));
            const __m256i raw_0123_3 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 192));
            const __m256i raw_4567_3 = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(q + 224));

            // Low nibbles (values 0..15) for sub-block 2*sb.
            const __m256i v0123_00 = _mm256_and_si256(raw_0123_0, m4b);
            const __m256i v4567_00 = _mm256_and_si256(raw_4567_0, m4b);
            const __m256i v0123_01 = _mm256_and_si256(raw_0123_1, m4b);
            const __m256i v4567_01 = _mm256_and_si256(raw_4567_1, m4b);
            const __m256i v0123_02 = _mm256_and_si256(raw_0123_2, m4b);
            const __m256i v4567_02 = _mm256_and_si256(raw_4567_2, m4b);
            const __m256i v0123_03 = _mm256_and_si256(raw_0123_3, m4b);
            const __m256i v4567_03 = _mm256_and_si256(raw_4567_3, m4b);

            // High nibbles (values 0..15) for sub-block 2*sb+1.
            const __m256i v0123_10 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_0, 4), m4b);
            const __m256i v4567_10 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_0, 4), m4b);
            const __m256i v0123_11 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_1, 4), m4b);
            const __m256i v4567_11 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_1, 4), m4b);
            const __m256i v0123_12 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_2, 4), m4b);
            const __m256i v4567_12 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_2, 4), m4b);
            const __m256i v0123_13 = _mm256_and_si256(_mm256_srli_epi16(raw_0123_3, 4), m4b);
            const __m256i v4567_13 = _mm256_and_si256(_mm256_srli_epi16(raw_4567_3, 4), m4b);

            // Unpack the 12-byte packed 6-bit scales/mins of sub-blocks 2*sb (utmp_0) and
            // 2*sb+1 (utmp_1), using the same bit manipulation as the generic kernel.
            // Afterwards bytes 0-7 hold the eight scales and bytes 8-15 the eight mins.
            std::uint32_t utmp_0[4], utmp_1[4];
            std::memcpy(utmp_0, q4[b].scales + 24 * sb, 12);
            utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
            const std::uint32_t uaux_0 = utmp_0[1] & kmask1;
            utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
            utmp_0[2] = uaux_0;
            utmp_0[0] &= kmask1;

            std::memcpy(utmp_1, q4[b].scales + 12 + sb * 24, 12);
            utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
            const std::uint32_t uaux_1 = utmp_1[1] & kmask1;
            utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
            utmp_1[2] = uaux_1;
            utmp_1[0] &= kmask1;

            // scales_0/1: int16 scales in lane order 0,4,1,5,... (see scalemask).
            // mins_01: _mm_shuffle_epi32(x, 78) swaps 64-bit halves, moving the mins into the
            // low 8 bytes. unpacklo then interleaves the mins of the two sub-blocks per entry,
            // so madd with a (bsum_2sb, bsum_2sb+1) pair yields one value per entry in
            // natural order.
            const __m128i mins_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
            const __m128i mins_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
            const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(mins_scales_0, scalemask));
            const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_shuffle_epi8(mins_scales_1, scalemask));
            const __m256i mins_01 = _mm256_cvtepu8_epi16(
                _mm_unpacklo_epi8(_mm_shuffle_epi32(mins_scales_0, 78), _mm_shuffle_epi32(mins_scales_1, 78)));

            // The 64 q8 values for this sub-block pair, in 16-byte pieces, each broadcast to
            // both 128-bit halves.
            const std::int8_t * y = q8[b].qs + sb * 64;
            __m256i lhs_00 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y)));
            __m256i lhs_01 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 16)));
            __m256i lhs_10 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 32)));
            __m256i lhs_11 = _mm256_castsi128_si256(_mm_loadu_si128(reinterpret_cast<const __m128i *>(y + 48)));
            lhs_00 = _mm256_permute2f128_si256(lhs_00, lhs_00, 0);
            lhs_01 = _mm256_permute2f128_si256(lhs_01, lhs_01, 0);
            lhs_10 = _mm256_permute2f128_si256(lhs_10, lhs_10, 0);
            lhs_11 = _mm256_permute2f128_si256(lhs_11, lhs_11, 0);

            __m256i iacc_0 = _mm256_setzero_si256();
            __m256i iacc_1 = _mm256_setzero_si256();

            // blend(.., 170) of one half with shuffle(.., 177) of the other pairs a dword of
            // entry j (0-3) with the same dword of entry j+4. That gives lane order
            // 0,4,1,5,2,6,3,7. shuffle(lhs, 0/85/170/255) broadcasts q8 bytes 0-3/4-7/8-11/12-15
            // of the 16-byte piece. maddubs multiplies unsigned nibbles by signed q8 bytes.
            // The int16 sums cannot overflow: 8 maddubs x 2 products x 15 x 128 = 30720 <= 32767.
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_00, _mm256_shuffle_epi32(v4567_00, 177), 170), _mm256_shuffle_epi32(lhs_00, 0)));
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_00, 177), v4567_00, 170), _mm256_shuffle_epi32(lhs_00, 85)));
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_01, _mm256_shuffle_epi32(v4567_01, 177), 170), _mm256_shuffle_epi32(lhs_00, 170)));
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_01, 177), v4567_01, 170), _mm256_shuffle_epi32(lhs_00, 255)));
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_02, _mm256_shuffle_epi32(v4567_02, 177), 170), _mm256_shuffle_epi32(lhs_01, 0)));
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_02, 177), v4567_02, 170), _mm256_shuffle_epi32(lhs_01, 85)));
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_03, _mm256_shuffle_epi32(v4567_03, 177), 170), _mm256_shuffle_epi32(lhs_01, 170)));
            iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_03, 177), v4567_03, 170), _mm256_shuffle_epi32(lhs_01, 255)));
            // Widen to int32. Each lane is the entry's 6-bit scale times its two int16 partials.
            iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);

            // Same for the high nibbles against q8 values 32..63 of this pair.
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_10, _mm256_shuffle_epi32(v4567_10, 177), 170), _mm256_shuffle_epi32(lhs_10, 0)));
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_10, 177), v4567_10, 170), _mm256_shuffle_epi32(lhs_10, 85)));
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_11, _mm256_shuffle_epi32(v4567_11, 177), 170), _mm256_shuffle_epi32(lhs_10, 170)));
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_11, 177), v4567_11, 170), _mm256_shuffle_epi32(lhs_10, 255)));
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_12, _mm256_shuffle_epi32(v4567_12, 177), 170), _mm256_shuffle_epi32(lhs_11, 0)));
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_12, 177), v4567_12, 170), _mm256_shuffle_epi32(lhs_11, 85)));
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(v0123_13, _mm256_shuffle_epi32(v4567_13, 177), 170), _mm256_shuffle_epi32(lhs_11, 170)));
            iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(v0123_13, 177), v4567_13, 170), _mm256_shuffle_epi32(lhs_11, 255)));
            iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);

            // Min correction for this pair: broadcast the (bsum_2sb, bsum_2sb+1) int16 pair
            // (dword 0 of q8s), multiply-add with mins_01, then shift q8s so the next pair
            // sits in dword 0.
            const __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
            const __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
            const __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
            q8s = _mm256_bsrli_epi128(q8s, 4);
            iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
            iacc_min_b = _mm256_add_epi32(iacc_min_b, iacc_min_sb);
        }

        // Scale the super-block integer totals by d_col * d_row, and by dmin_col * d_row for
        // the min correction.
        acc_row = _mm256_add_ps(
            acc_row,
            _mm256_mul_ps(_mm256_cvtepi32_ps(iacc_b), _mm256_mul_ps(col_scale_f32, row_scale_f32)));
        acc_min_rows = _mm256_add_ps(
            acc_min_rows,
            _mm256_mul_ps(_mm256_cvtepi32_ps(iacc_min_b), _mm256_mul_ps(col_dmin_f32, row_scale_f32)));
    }

    // Restore natural order 0..7, subtract the min correction, store out[0..7].
    acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
    _mm256_storeu_ps(out, _mm256_sub_ps(acc_row, acc_min_rows));
#else
    gemv_q4_K_8x8_q8_K_generic(n, out, q4, q8);
#endif
}
#else
// Non-AVX2 build: the dispatcher is a direct forward to the portable kernel
// (returning a void expression is valid C++ and keeps the forward a single statement).
static void gemv_q4_K_8x8_q8_K(int n, float * out, const block_q4_Kx8 * q4, const block_q8_K * q8) {
    return gemv_q4_K_8x8_q8_K_generic(n, out, q4, q8);
}
#endif

int main() {
    InputView input = read_input();

    if (input.size < 24 || std::memcmp(input.data, "CMQ4KX01", 8) != 0) {
        return 1;
    }

    const std::uint32_t case_count = read_u32(input.data + 8);
    const std::uint32_t rows_per_case = read_u32(input.data + 12);
    const std::uint32_t blocks_per_row = read_u32(input.data + 16);
    const std::uint32_t n = read_u32(input.data + 20);
    const unsigned char * ptr = input.data + 24;
    double total = 0.0;

    for (std::uint32_t case_index = 0; case_index < case_count; ++case_index) {
        const block_q8_K * q8 = reinterpret_cast<const block_q8_K *>(ptr);
        ptr += static_cast<std::size_t>(blocks_per_row) * sizeof(block_q8_K);

        for (std::uint32_t row_group = 0; row_group < rows_per_case / 8; ++row_group) {
            const block_q4_Kx8 * q4 = reinterpret_cast<const block_q4_Kx8 *>(ptr);
            ptr += static_cast<std::size_t>(blocks_per_row) * sizeof(block_q4_Kx8);

            float rows[8];
            gemv_q4_K_8x8_q8_K(static_cast<int>(n), rows, q4, q8);
            for (float row : rows) {
                total += static_cast<double>(row);
            }
        }
    }

    std::printf("%.2f", total);
}
02 jobs
Systems: 2 jobs
03 counters
Performance counters: 31 counters
cycles
16,115,654
Show more
branch_instructions
950,040
branch_misses
15,532
cycle_activity.stalls_l1d_miss
1,855,496
cycle_activity.stalls_l2_miss
872,962
cycle_activity.stalls_l3_miss
724,963
cycle_activity.stalls_total
2,732,983
dtlb_load_misses.walk_completed
852
exe_activity.bound_on_loads
2,106,856
exe_activity.bound_on_stores
20,235
instructions
52,380,658
machine_clears
817
mem_inst_retired.split_loads
1,073,625
mem_load_retired.l1_miss
126,374
mem_load_retired.l2_miss
27,829
mem_load_retired.l3_miss
22,381
tma_backend_bound
34,666,962
tma_bad_speculation
2,269,165
tma_branch_mispredict_slots
2,152,605
tma_frontend_bound
5,998,816
tma_memory_bound
10,443,132
tma_retiring
54,249,589
tma_slots
97,165,086
uops_dispatched.port_0
10,140,284
uops_dispatched.port_1
10,699,023
uops_dispatched.port_2_3_10
11,119,520
uops_dispatched.port_4_9
2,650,194
uops_dispatched.port_5_11
14,219,225
uops_dispatched.port_6
4,875,922
uops_dispatched.port_7_8
1,725,806
uops_retired.ms
0
04 top down
Top-down analysis: Raptor Cove P-core
05 profile
load profile