cpu.mode fastest code on the internet
solution

sol_358578_1778384723756220974_3

Rust s7nfo 2 runs public
01 source
Submitted source 12242 bytes
Compiler rustc Flags --edition=2024 -C opt-level=3 -C panic=abort -C codegen-units=1 -C target-cpu=native -C target-feature=+crt-static
show source
use std::arch::x86_64::*;
use std::io::{Read, Write};

/// Operand length at or below which `karatsuba_into` stops splitting and hands the
/// multiplication to `convolve_direct_into`. `a_xor_total` and `scratch_total` use the
/// same cut-off when sizing the recursion's scratch buffers.
const KARATSUBA_THRESHOLD: usize = 128;

/// Reads two coefficient lists from stdin, multiplies them as polynomials whose
/// coefficients are combined with carry-less arithmetic, and prints the reduced
/// 64-bit product coefficients.
///
/// Input: `n`, `m`, then `n` values of `a` and `m` values of `b`. All are unsigned decimal
/// integers separated by any non-digit bytes.
///
/// Output: the `n + m - 1` product coefficients, each passed through [`reduce`],
/// space-separated on one line and terminated by `\n`. When either list is empty, no SIMD
/// work is done and `max(n + m, 1) - 1` zeros are printed (an empty line if both are empty).
fn main() {
    let mut input = Vec::with_capacity(1 << 20);
    std::io::stdin()
        .lock()
        .read_to_end(&mut input)
        .expect("failed to read stdin");
    // Trailing whitespace terminates the last number and lets it use the 8-byte fast path.
    input.extend_from_slice(&[b' '; 32]);

    let mut pos = 0usize;
    let n = read_u64_fast(&input, &mut pos) as usize;
    let m = read_u64_fast(&input, &mut pos) as usize;
    let a: Vec<u64> = (0..n).map(|_| read_u64_fast(&input, &mut pos)).collect();
    let b: Vec<u64> = (0..m).map(|_| read_u64_fast(&input, &mut pos)).collect();

    // `n + m - 1` wraps when both lists are empty. An empty factor is therefore answered
    // here without calling the SIMD kernel: every product coefficient is zero.
    let out_len = (n + m).saturating_sub(1);
    let c: Vec<__m128i> = if n == 0 || m == 0 {
        vec![unsafe { _mm_setzero_si128() }; out_len]
    } else {
        // SAFETY: assumes the host supports vpclmulqdq/avx2/pclmulqdq. The submission is
        // compiled with `-C target-cpu=native`, and `karatsuba_top` is built for those features.
        unsafe { karatsuba_top(&a, &b) }
    };

    // Each value takes at most 20 digits plus one separator.
    let mut out = Vec::with_capacity(out_len * 21 + 16);
    let mut digits = [0u8; 24];
    for (k, &coeff) in c.iter().enumerate() {
        if k > 0 {
            out.push(b' ');
        }
        // SAFETY: `coeff` is a valid 128-bit value produced above.
        let v = unsafe { reduce(coeff) };
        out.extend_from_slice(u64_to_str(v, &mut digits));
    }
    out.push(b'\n');
    std::io::stdout()
        .lock()
        .write_all(&out)
        .expect("failed to write stdout");
}

/// Converts eight ASCII digits into their decimal value. The digits are packed
/// little-endian into `chunk`, so the first character is in the lowest byte.
///
/// The bytes are assumed to already be `b'0'..=b'9'`; nothing is validated here.
///
/// The work is done in three SWAR rounds. Each round merges neighbouring lanes and doubles
/// the lane width: 8 x 1 digit -> 4 x 2 digits (16-bit lanes) -> 2 x 4 digits (32-bit lanes)
/// -> 1 x 8 digits.
#[inline(always)]
fn parse_8_digits_swar(chunk: u64) -> u64 {
    const NIBBLE_PER_U16: u64 = 0x000F_000F_000F_000F;
    const BYTE_PER_U32: u64 = 0x0000_00FF_0000_00FF;
    const LOW_U16: u64 = 0x0000_0000_0000_FFFF;

    // Remove the ASCII bias so every byte holds its digit value.
    let digits = chunk.wrapping_sub(0x3030_3030_3030_3030);
    // In every lane the lower half holds the more significant part, because it came first.
    let pairs = (digits & NIBBLE_PER_U16) * 10 + ((digits >> 8) & NIBBLE_PER_U16);
    let quads = (pairs & BYTE_PER_U32) * 100 + ((pairs >> 16) & BYTE_PER_U32);
    (quads & LOW_U16) * 10000 + ((quads >> 32) & LOW_U16)
}

/// Scans `buf` from `*idx` for the next run of ASCII digits and returns its value.
///
/// Leading non-digit bytes are skipped. While eight consecutive digits are available, they
/// are validated and converted together with `parse_8_digits_swar`. Any remaining digits are
/// consumed one byte at a time.
///
/// On return, `*idx` is the index just past the last digit. If no digit was found, `*idx` is
/// `buf.len()` and the result is 0. Values wider than `u64` wrap silently.
#[inline]
fn read_u64_fast(buf: &[u8], idx: &mut usize) -> u64 {
    let is_digit = |byte: u8| byte.wrapping_sub(b'0') <= 9;

    let mut pos = *idx;
    while pos < buf.len() && !is_digit(buf[pos]) {
        pos += 1;
    }

    let mut value = 0u64;
    while let Some(window) = buf.get(pos..pos + 8) {
        let word = u64::from_le_bytes(window.try_into().unwrap());
        let offset = word.wrapping_sub(0x3030_3030_3030_3030);
        let overshoot = offset.wrapping_add(0x7676_7676_7676_7676);
        // Bytes below b'0' set bit 7 in `offset`; bytes above b'9' set it in `overshoot`.
        // Borrows and carries only leak upward from a byte that is already invalid. So an
        // all-digit window never trips this test, and a window with any non-digit always does.
        if (offset | overshoot) & 0x8080_8080_8080_8080 != 0 {
            break;
        }
        value = value
            .wrapping_mul(100_000_000)
            .wrapping_add(parse_8_digits_swar(word));
        pos += 8;
    }

    while let Some(digit) = buf
        .get(pos)
        .map(|&c| c.wrapping_sub(b'0'))
        .filter(|&d| d <= 9)
    {
        value = value.wrapping_mul(10).wrapping_add(u64::from(digit));
        pos += 1;
    }

    *idx = pos;
    value
}

/// Formats `v` as decimal ASCII into `buf` and returns the slice holding the digits.
///
/// Non-zero values are written right-aligned at the end of `buf`. Zero is written as a
/// single `b'0'` at `buf[0]`.
fn u64_to_str(v: u64, buf: &mut [u8; 24]) -> &[u8] {
    if v == 0 {
        buf[0] = b'0';
        &buf[..1]
    } else {
        let mut rest = v;
        let mut written = 0usize;
        // Fill from the back, least significant digit first.
        for slot in buf.iter_mut().rev() {
            if rest == 0 {
                break;
            }
            *slot = b'0' + (rest % 10) as u8;
            rest /= 10;
            written += 1;
        }
        &buf[buf.len() - written..]
    }
}

/// Multiplies the coefficient sequences `a` and `b` and returns the `a.len() + b.len() - 1`
/// product coefficients as unreduced 128-bit values (not yet folded back to 64 bits).
///
/// Allocates the output and the scratch buffers that the Karatsuba recursion in
/// `karatsuba_into` needs, then runs it. Two empty inputs yield an empty vector.
///
/// # Safety
/// The CPU must support `vpclmulqdq`, `avx2`, `pclmulqdq` and `sse2`.
#[target_feature(enable = "vpclmulqdq,avx2,pclmulqdq,sse2")]
unsafe fn karatsuba_top(a: &[u64], b: &[u64]) -> Vec<__m128i> {
    unsafe {
        let n = a.len();
        let m = b.len();
        // `n + m - 1` would wrap for two empty inputs; the product has no coefficients then.
        let Some(out_len) = (n + m).checked_sub(1) else {
            return Vec::new();
        };
        let mut result: Vec<__m128i> = vec![_mm_setzero_si128(); out_len];

        // `karatsuba_into` only splits, and so only touches scratch, when n == m. Unequal
        // lengths go straight to the direct convolution, so no scratch is allocated for them.
        // This also keeps `scratch_total` away from n == 0.
        let (scratch_u64_size, scratch_m128_size) = if n == m {
            (a_xor_total(n), scratch_total(n))
        } else {
            (0, 0)
        };
        let mut scratch_u64: Vec<u64> = vec![0u64; scratch_u64_size];
        let mut scratch_m128: Vec<__m128i> = vec![_mm_setzero_si128(); scratch_m128_size];

        karatsuba_into(a, b, &mut result, &mut scratch_u64, &mut scratch_m128);
        result
    }
}

/// Number of `u64` scratch elements `karatsuba_into` needs when both operands have
/// length `n`.
///
/// Every level that splits (length even and above `KARATSUBA_THRESHOLD`) stores
/// `a_lo ^ a_hi` and `b_lo ^ b_hi`. That is `n / 2` elements each, so `n` in total. The
/// sub-calls of a level run one after another and share the remaining space, so only one
/// chain of levels is live and the per-level sizes simply add up.
fn a_xor_total(n: usize) -> usize {
    let mut len = n;
    let mut total = 0usize;
    while len > KARATSUBA_THRESHOLD && len % 2 == 0 {
        total += len;
        len /= 2;
    }
    total
}

/// Number of `__m128i` scratch cells `karatsuba_into` needs when both operands have
/// length `n`.
///
/// Every level that splits (length even and above `KARATSUBA_THRESHOLD`) needs a
/// `2h - 1`-cell buffer for the middle product `m_xx`, where `h` is half the length. The
/// three sub-calls run one after another and share the rest, so the per-level sizes add up.
///
/// Lengths that are not split go to the direct convolution, which uses no scratch, so they
/// contribute 0. The previous `2 * n - 1` base case was never used and underflowed for
/// `n == 0`.
fn scratch_total(n: usize) -> usize {
    let mut len = n;
    let mut total = 0usize;
    while len > KARATSUBA_THRESHOLD && len % 2 == 0 {
        let h = len / 2;
        total += 2 * h - 1;
        len = h;
    }
    total
}

/// Recursive Karatsuba step. Writes the `a.len() + b.len() - 1` unreduced 128-bit
/// coefficients of the product of `a` and `b` into `out`.
///
/// Coefficient products are carry-less multiplications and sums are XOR. The middle term is
/// therefore `m_xx ^ m_ll ^ m_hh` with `m_xx = (a_lo ^ a_hi) * (b_lo ^ b_hi)`, and no
/// subtraction is needed.
///
/// The split is only taken when `a.len() == b.len()`, that length is even, and it exceeds
/// `KARATSUBA_THRESHOLD`. Every other case is handed to `convolve_direct_into`.
///
/// Layout on the split path, with n = 2h:
/// - `m_ll` is written to `out[0..2h-1]`.
/// - `m_hh` is written to `out[2h..4h-1]`.
/// - `out[2h-1]` is zeroed.
/// - The middle term is XORed into `out[h..3h-1]`.
///
/// # Safety
/// - The CPU must support the enabled target features.
/// - `out` must be at least `a.len() + b.len() - 1` cells long.
/// - On the split path, `scratch_u64` needs `2h` elements for this level, and
///   `scratch_m128` needs `2h - 1` cells for this level. Each also needs whatever the
///   `h`-sized sub-calls require; the three sub-calls run one after another and share that
///   remainder. `karatsuba_top` sizes these buffers with `a_xor_total` / `scratch_total`.
#[target_feature(enable = "vpclmulqdq,avx2,pclmulqdq,sse2")]
unsafe fn karatsuba_into(
    a: &[u64],
    b: &[u64],
    out: &mut [__m128i],
    scratch_u64: &mut [u64],
    scratch_m128: &mut [__m128i],
) {
    unsafe {
        let n = a.len();
        let m = b.len();
        let out_len = n + m - 1;

        // Base case: unequal lengths, odd length, or small enough for the direct method.
        if n != m || n <= KARATSUBA_THRESHOLD || (n & 1) != 0 {
            // Direct convolve into out.
            convolve_direct_into(a, b, out);
            return;
        }

        let h = n / 2;
        let a_lo = &a[..h];
        let a_hi = &a[h..];
        let b_lo = &b[..h];
        let b_hi = &b[h..];

        // Use scratch_u64[0..h] for a_xor, [h..2h] for b_xor.
        let (axbx, scratch_u64_rest) = scratch_u64.split_at_mut(2 * h);
        {
            // Four u64 lanes per 256-bit XOR; the scalar loop below handles the h % 4 tail.
            let chunks = h / 4;
            for k in 0..chunks {
                let p = k * 4;
                let av = _mm256_xor_si256(
                    _mm256_loadu_si256(a_lo.as_ptr().add(p) as *const __m256i),
                    _mm256_loadu_si256(a_hi.as_ptr().add(p) as *const __m256i),
                );
                _mm256_storeu_si256(axbx.as_mut_ptr().add(p) as *mut __m256i, av);
                let bv = _mm256_xor_si256(
                    _mm256_loadu_si256(b_lo.as_ptr().add(p) as *const __m256i),
                    _mm256_loadu_si256(b_hi.as_ptr().add(p) as *const __m256i),
                );
                _mm256_storeu_si256(axbx.as_mut_ptr().add(h + p) as *mut __m256i, bv);
            }
            for k in chunks * 4..h {
                axbx[k] = a_lo[k] ^ a_hi[k];
                axbx[h + k] = b_lo[k] ^ b_hi[k];
            }
        }
        // Each half-size product (h x h coefficients) has 2h - 1 coefficients.
        let len_sub = 2 * h - 1;

        // m_xx_buf at start of scratch_m128, sub-call scratch follows.
        let (m_xx_buf, m_xx_sub_scratch) = scratch_m128.split_at_mut(len_sub);

        // The three sub-products are computed one after another, so they all reuse
        // scratch_u64_rest and m_xx_sub_scratch.
        // First compute m_ll (writes to out[0..len_sub]).
        karatsuba_into(
            a_lo,
            b_lo,
            &mut out[..len_sub],
            scratch_u64_rest,
            m_xx_sub_scratch,
        );
        // Then m_hh into out[2h..2h+len_sub]
        karatsuba_into(
            a_hi,
            b_hi,
            &mut out[2 * h..2 * h + len_sub],
            scratch_u64_rest,
            m_xx_sub_scratch,
        );
        // Compute m_xx = (a_lo ^ a_hi) * (b_lo ^ b_hi) into m_xx_buf
        let (a_xor_slice, b_xor_slice) = axbx.split_at(h);
        karatsuba_into(
            a_xor_slice,
            b_xor_slice,
            m_xx_buf,
            scratch_u64_rest,
            m_xx_sub_scratch,
        );

        // out[len_sub] (i.e. out[2h-1]) lies between the m_ll and m_hh ranges and was not
        // written above. Clear it, because the middle term is XORed into it below.
        out[len_sub] = _mm_setzero_si128();

        // Compute mid = m_xx ^ out[..len_sub] (= m_ll) ^ out[2h..2h+len_sub] (= m_hh) into m_xx_buf.
        {
            // Two 128-bit cells per 256-bit op; the scalar loop handles an odd leftover cell.
            let chunks = len_sub / 2;
            let mp = m_xx_buf.as_mut_ptr();
            let lp = out.as_ptr(); // m_ll at offset 0
            let hp = out.as_ptr().add(2 * h); // m_hh at offset 2h
            for k in 0..chunks {
                let q = k * 2;
                let mv = _mm256_loadu_si256(mp.add(q) as *const __m256i);
                let lv = _mm256_loadu_si256(lp.add(q) as *const __m256i);
                let hv = _mm256_loadu_si256(hp.add(q) as *const __m256i);
                let r = _mm256_xor_si256(_mm256_xor_si256(mv, lv), hv);
                _mm256_storeu_si256(mp.add(q) as *mut __m256i, r);
            }
            for k in chunks * 2..len_sub {
                m_xx_buf[k] =
                    _mm_xor_si128(_mm_xor_si128(m_xx_buf[k], *lp.add(k)), *hp.add(k));
            }
        }

        // Add mid to out[h..h+len_sub] (it overlaps the top of m_ll, the gap cell, and the
        // bottom of m_hh).
        {
            let chunks = len_sub / 2;
            let op = out.as_mut_ptr().add(h);
            let mp = m_xx_buf.as_ptr();
            for k in 0..chunks {
                let q = k * 2;
                let ov = _mm256_loadu_si256(op.add(q) as *const __m256i);
                let mv = _mm256_loadu_si256(mp.add(q) as *const __m256i);
                _mm256_storeu_si256(op.add(q) as *mut __m256i, _mm256_xor_si256(ov, mv));
            }
            for k in chunks * 2..len_sub {
                let cell = op.add(k);
                *cell = _mm_xor_si128(*cell, *mp.add(k));
            }
        }

        // out_len is not otherwise used on this path; this only silences the unused-variable lint.
        let _ = out_len;
    }
}

/// Direct (schoolbook) carry-less convolution. For every `k < a.len() + b_in.len() - 1`,
/// writes `out[k]` = XOR over `i + j == k` of `clmul(a[i], b_in[j])`, as an unreduced
/// 128-bit value.
///
/// Two output coefficients are produced per 256-bit step: the low 128-bit lane accumulates
/// `out[k0]` and the high lane accumulates `out[k0 + 1]`. `b_in` is copied into a buffer of
/// 128-bit cells padded with zeros on both sides. That way the lane at either end of the
/// summation range multiplies by zero instead of reading out of bounds.
///
/// # Safety
/// The CPU must support the enabled target features.
///
/// # Panics
/// If `out` is shorter than `a.len() + b_in.len() - 1`.
#[target_feature(enable = "vpclmulqdq,avx2,pclmulqdq,sse2")]
unsafe fn convolve_direct_into(a: &[u64], b_in: &[u64], out: &mut [__m128i]) {
    unsafe {
        let n = a.len();
        let m = b_in.len();
        // Saturating: two empty inputs produce no output instead of a wrapped length.
        let out_len = (n + m).saturating_sub(1);
        // Every store below goes through a raw pointer, so check the destination once here.
        assert!(
            out.len() >= out_len,
            "convolve_direct_into: output slice too short"
        );

        // One 128-bit cell per b coefficient (value in the low qword), with zero padding.
        const PAD_FRONT: usize = 2;
        const PAD_BACK: usize = 4;
        let bz_len = PAD_FRONT + m + PAD_BACK;
        let mut bz: Vec<__m128i> = vec![_mm_setzero_si128(); bz_len];
        for (cell, &coeff) in bz[PAD_FRONT..PAD_FRONT + m].iter_mut().zip(b_in) {
            *cell = _mm_cvtsi64_si128(coeff as i64);
        }

        let bz_ptr = bz.as_ptr();
        let aptr = a.as_ptr();
        let optr = out.as_mut_ptr();

        let num_blocks = out_len.div_ceil(2);
        for g in 0..num_blocks {
            let k0 = 2 * g;
            let mut acc = _mm256_setzero_si256();

            // Terms are a[i] * b[k0 - i] (low lane) and a[i] * b[k0 + 1 - i] (high lane).
            // i_lo..i_hi covers both ranges; the lane whose b index falls outside 0..m reads a
            // zero pad cell.
            let i_lo = (k0 + 1).saturating_sub(m);
            let i_hi = n.min(k0 + 2);

            // bp points at the cell pair (b[k0 - i], b[k0 + 1 - i]) for the current i.
            let s0 = PAD_FRONT + k0 - i_lo;
            let mut bp = bz_ptr.add(s0);
            let mut ap = aptr.add(i_lo);
            let a_end = aptr.add(i_hi);

            // `bz` is only 16-byte aligned and `bp` moves in 16-byte steps, so the 256-bit
            // reads must be unaligned loads. Dereferencing `*const __m256i` here would be UB.
            let chunks = (i_hi - i_lo) / 4;
            for _ in 0..chunks {
                let ai0 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap as i64));
                let ai1 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap.add(1) as i64));
                let ai2 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap.add(2) as i64));
                let ai3 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap.add(3) as i64));
                let b0 = _mm256_loadu_si256(bp as *const __m256i);
                let b1 = _mm256_loadu_si256(bp.sub(1) as *const __m256i);
                let b2 = _mm256_loadu_si256(bp.sub(2) as *const __m256i);
                let b3 = _mm256_loadu_si256(bp.sub(3) as *const __m256i);
                let p0 = _mm256_clmulepi64_epi128(ai0, b0, 0x00);
                let p1 = _mm256_clmulepi64_epi128(ai1, b1, 0x00);
                let p2 = _mm256_clmulepi64_epi128(ai2, b2, 0x00);
                let p3 = _mm256_clmulepi64_epi128(ai3, b3, 0x00);
                let s01 = _mm256_xor_si256(p0, p1);
                let s23 = _mm256_xor_si256(p2, p3);
                acc = _mm256_xor_si256(acc, _mm256_xor_si256(s01, s23));
                ap = ap.add(4);
                bp = bp.sub(4);
            }

            while ap < a_end {
                let ai = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap as i64));
                let bv = _mm256_loadu_si256(bp as *const __m256i);
                acc = _mm256_xor_si256(acc, _mm256_clmulepi64_epi128(ai, bv, 0x00));
                ap = ap.add(1);
                bp = bp.sub(1);
            }

            // When out_len is odd, the last block has only out[k0] in range.
            if k0 + 1 < out_len {
                _mm256_storeu_si256(optr.add(k0) as *mut __m256i, acc);
            } else {
                _mm_storeu_si128(optr.add(k0), _mm256_castsi256_si128(acc));
            }
        }
    }
}

/// Folds an unreduced 128-bit carry-less product down to 64 bits.
///
/// The high qword stands for `hi * x^64`. It is replaced by `hi * (x^4 + x^3 + x + 1)`
/// (shifts by 0, 1, 3 and 4), i.e. reduction modulo `x^64 + x^4 + x^3 + x + 1`. The at most
/// four bits that spill past bit 63 are folded once more the same way; the second fold
/// stays below 2^8, so it cannot spill again.
#[inline(always)]
unsafe fn reduce(p: __m128i) -> u64 {
    let [lo, hi]: [u64; 2] = unsafe { std::mem::transmute(p) };
    let fold = |v: u64| v ^ (v << 1) ^ (v << 3) ^ (v << 4);
    // Bits of `hi` shifted out of the 64-bit word by the <<1, <<3 and <<4 terms.
    let spill = (hi >> 63) ^ (hi >> 61) ^ (hi >> 60);
    lo ^ fold(hi) ^ fold(spill)
}
02 jobs
Systems 02 jobs
03 counters
Performance counters 18 counters
cycles
9,006,680
Show more
branch_instructions
2,302,725
branch_misses
21,575
dtlb_load_misses.walk_completed
398
instructions
27,630,097
mem_load_retired.l1_miss
31,001
mem_load_retired.l2_miss
1,269
mem_load_retired.l3_miss
440
tma_backend_bound
13,868,508
tma_bad_speculation
3,090,909
tma_retiring
30,837,733
uops_dispatched.port_0
3,289,878
uops_dispatched.port_1
3,206,296
uops_dispatched.port_2_3_10
9,744,810
uops_dispatched.port_4_9
1,427,629
uops_dispatched.port_5_11
7,535,477
uops_dispatched.port_6
4,030,608
uops_dispatched.port_7_8
1,379,614
04 top down
Top-down analysis Raptor Cove P-core
05 profile
Profile not available
No profile is available for this job.