solution
sol_358578_1778384723756220974_3
01
source
Submitted source
12242 bytes
show source
use std::arch::x86_64::*;
use std::io::{Read, Write};
/// Operand length at or below which `karatsuba_into` stops splitting and hands the
/// product to the direct (schoolbook) convolution instead.
const KARATSUBA_THRESHOLD: usize = 1 << 7;
/// Reads `n`, `m`, then `n` coefficients of `a` and `m` coefficients of `b` from stdin,
/// and prints the `n + m - 1` coefficients of their carry-less product, each reduced to
/// 64 bits by `reduce`, space-separated on a single line.
///
/// If either length is zero the product has no coefficients and only a newline is printed.
fn main() {
    let mut input = Vec::with_capacity(1 << 20);
    std::io::stdin().lock().read_to_end(&mut input).unwrap();
    // Trailing separators give `read_u64_fast`'s 8-bytes-at-a-time path a full window
    // near the end of the input.
    input.extend_from_slice(&[b' '; 32]);

    let mut p = 0usize;
    let n = read_u64_fast(&input, &mut p) as usize;
    let m = read_u64_fast(&input, &mut p) as usize;
    let a: Vec<u64> = (0..n).map(|_| read_u64_fast(&input, &mut p)).collect();
    let b: Vec<u64> = (0..m).map(|_| read_u64_fast(&input, &mut p)).collect();

    let mut stdout = std::io::stdout().lock();
    if n == 0 || m == 0 {
        // `n + m - 1` would underflow; an empty product is printed as an empty line.
        stdout.write_all(b"\n").unwrap();
        return;
    }

    // `karatsuba_top` is compiled for these CPU features; calling it on a CPU without
    // them is undefined behaviour, so fail loudly instead.
    assert!(
        is_x86_feature_detected!("pclmulqdq")
            && is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("vpclmulqdq"),
        "this program requires a CPU with PCLMULQDQ, AVX2 and VPCLMULQDQ"
    );

    let out_len = n + m - 1;
    // SAFETY: the required CPU features were verified just above.
    let c: Vec<__m128i> = unsafe { karatsuba_top(&a, &b) };

    let mut out = Vec::with_capacity(out_len * 21 + 16);
    let mut tmp = [0u8; 24];
    for (k, &cell) in c[..out_len].iter().enumerate() {
        if k > 0 {
            out.push(b' ');
        }
        // SAFETY: `reduce` only reinterprets the 128-bit value as two u64 words.
        let v = unsafe { reduce(cell) };
        out.extend_from_slice(u64_to_str(v, &mut tmp));
    }
    out.push(b'\n');
    stdout.write_all(&out).unwrap();
}
/// Converts eight ASCII decimal digits, packed little-endian into `chunk` (first
/// character in the lowest byte), into their numeric value.
///
/// The caller is responsible for checking that all eight bytes are digits.
#[inline(always)]
fn parse_8_digits_swar(chunk: u64) -> u64 {
    const EVEN_BYTES: u64 = 0x000F_000F_000F_000F;
    const EVEN_PAIRS: u64 = 0x0000_00FF_0000_00FF;
    const LOW_QUAD: u64 = 0x0000_0000_0000_FFFF;

    let digits = chunk.wrapping_sub(0x3030_3030_3030_3030);
    // Each 16-bit lane becomes 10 * (earlier digit) + (later digit).
    let pairs = (digits & EVEN_BYTES) * 10 + ((digits >> 8) & EVEN_BYTES);
    // Each 32-bit lane becomes 100 * (earlier pair) + (later pair).
    let quads = (pairs & EVEN_PAIRS) * 100 + ((pairs >> 16) & EVEN_PAIRS);
    // Finally 10000 * (earlier four digits) + (later four digits).
    (quads & LOW_QUAD) * 10000 + ((quads >> 32) & LOW_QUAD)
}
/// Parses the next unsigned decimal number from `buf`, starting at `*idx`.
///
/// Bytes that are not ASCII digits are skipped first. Digits are consumed eight at a time
/// while a full 8-byte window of digits is available, then one at a time. Values wider than
/// 64 bits wrap (arithmetic is modulo 2^64). Returns 0 if no digit is found. On return
/// `*idx` points just past the last digit consumed (or at `buf.len()`).
#[inline]
fn read_u64_fast(buf: &[u8], idx: &mut usize) -> u64 {
    const ASCII_ZEROS: u64 = 0x3030_3030_3030_3030;
    const TOP_BITS: u64 = 0x8080_8080_8080_8080;
    const ABOVE_NINE: u64 = 0x7676_7676_7676_7676;
    let is_digit = |c: u8| c.wrapping_sub(b'0') <= 9;

    let mut pos = *idx;
    while pos < buf.len() && !is_digit(buf[pos]) {
        pos += 1;
    }

    let mut value = 0u64;
    while let Some(window) = buf.get(pos..pos + 8) {
        let word = u64::from_le_bytes(window.try_into().unwrap());
        let offset = word.wrapping_sub(ASCII_ZEROS);
        // A byte below '0' sets its top bit in `offset`; a byte above '9' sets it after
        // adding 0x76. Either means the window is not eight digits.
        if (offset | offset.wrapping_add(ABOVE_NINE)) & TOP_BITS != 0 {
            break;
        }
        value = value
            .wrapping_mul(100_000_000)
            .wrapping_add(parse_8_digits_swar(word));
        pos += 8;
    }

    while let Some(&c) = buf.get(pos) {
        if !is_digit(c) {
            break;
        }
        value = value.wrapping_mul(10).wrapping_add(u64::from(c - b'0'));
        pos += 1;
    }

    *idx = pos;
    value
}
/// Formats `v` as decimal ASCII and returns the digits as a slice of `buf`.
///
/// Zero is written to `buf[0]`. Any other value is written right-aligned so that it ends
/// at `buf[23]`; 24 bytes comfortably hold the 20 digits of `u64::MAX`.
fn u64_to_str(v: u64, buf: &mut [u8; 24]) -> &[u8] {
    if v == 0 {
        buf[0] = b'0';
        return &buf[..1];
    }
    let mut rest = v;
    let mut start = buf.len();
    for slot in buf.iter_mut().rev() {
        if rest == 0 {
            break;
        }
        *slot = b'0' + (rest % 10) as u8;
        rest /= 10;
        start -= 1;
    }
    &buf[start..]
}
#[target_feature(enable = "vpclmulqdq,avx2,pclmulqdq,sse2")]
unsafe fn karatsuba_top(a: &[u64], b: &[u64]) -> Vec<__m128i> {
unsafe {
let n = a.len();
let m = b.len();
let out_len = n + m - 1;
let mut result: Vec<__m128i> = vec![_mm_setzero_si128(); out_len];
// Compute total scratch needs.
// For Karatsuba on size N: needs a_xor, b_xor of size N/2, plus result of size N-1 for m_xx,
// plus same for sub-calls.
// Pre-allocate generously.
let scratch_u64_size = a_xor_total(n);
let scratch_m128_size = scratch_total(n);
let mut scratch_u64: Vec<u64> = vec![0u64; scratch_u64_size];
let mut scratch_m128: Vec<__m128i> = vec![_mm_setzero_si128(); scratch_m128_size];
karatsuba_into(
a,
b,
&mut result,
&mut scratch_u64,
&mut scratch_m128,
);
result
}
}
/// Number of `u64` scratch words `karatsuba_into` needs for operands of length `n`.
///
/// Every splitting level stores `a_lo ^ a_hi` and `b_lo ^ b_hi` (`n / 2` words each,
/// `n` in total). The sub-products at one level run one after another and share the
/// space below it, so the total is the sum along the chain of halvings. It stops at
/// the first size that is small enough or odd, where nothing is needed.
fn a_xor_total(n: usize) -> usize {
    let mut size = n;
    let mut words = 0;
    while size > KARATSUBA_THRESHOLD && size % 2 == 0 {
        // 2 * (size / 2) == size because size is even here.
        words += size;
        size /= 2;
    }
    words
}
/// Number of `__m128i` scratch cells to allocate for `karatsuba_into` on length-`n`
/// operands.
///
/// Every splitting level needs `2 * (n / 2) - 1` cells for the middle product
/// `m_xx`. The sub-calls at one level run sequentially and share the space below it. At
/// the first size that is small enough or odd, `2 * size - 1` cells are added. The
/// direct branch of `karatsuba_into` does not use this buffer, so that is only headroom.
/// Returns 0 for `n == 0`; previously `2 * n - 1` underflowed there.
fn scratch_total(n: usize) -> usize {
    let mut size = n;
    let mut cells = 0;
    while size > KARATSUBA_THRESHOLD && size % 2 == 0 {
        let half = size / 2;
        cells += 2 * half - 1;
        size = half;
    }
    cells + (2 * size).saturating_sub(1)
}
/// Carry-less Karatsuba multiplication of `a` and `b`, written into `out`.
///
/// On return, `out[k]` for `k < a.len() + b.len() - 1` holds the unreduced 128-bit
/// XOR-sum of the carry-less products `a[i] * b[j]` over all `i + j == k`.
///
/// The three-way split recurses only when both operands have the same even length
/// greater than `KARATSUBA_THRESHOLD`. Every other case is delegated to
/// `convolve_direct_into`.
///
/// Buffers (`n = a.len()`):
/// * `out` must hold at least `a.len() + b.len() - 1` cells; this function panics otherwise.
/// * `scratch_u64` must provide `n` words at every splitting level, for `a_lo ^ a_hi`
///   and `b_lo ^ b_hi` (`a_xor_total(n)` in total).
/// * `scratch_m128` must provide `n - 1` cells at every splitting level, for the
///   middle product (`scratch_total(n)` covers it).
///
/// Undersized scratch buffers make the `split_at_mut` calls panic.
///
/// # Safety
/// The running CPU must support the features enabled by `#[target_feature]`.
#[target_feature(enable = "vpclmulqdq,avx2,pclmulqdq,sse2")]
unsafe fn karatsuba_into(
    a: &[u64],
    b: &[u64],
    out: &mut [__m128i],
    scratch_u64: &mut [u64],
    scratch_m128: &mut [__m128i],
) {
    unsafe {
        let n = a.len();
        let m = b.len();
        if n == 0 || m == 0 {
            // An empty operand has an empty product; `n + m - 1` would also underflow.
            return;
        }
        // The raw-pointer loops here and in `convolve_direct_into` write up to
        // `n + m - 1` cells of `out`.
        assert!(
            out.len() >= n + m - 1,
            "karatsuba_into: output buffer too short"
        );
        if n != m || n <= KARATSUBA_THRESHOLD || (n & 1) != 0 {
            convolve_direct_into(a, b, out);
            return;
        }
        let h = n / 2;
        let (a_lo, a_hi) = a.split_at(h);
        let (b_lo, b_hi) = b.split_at(h);

        // axbx[..h] = a_lo ^ a_hi and axbx[h..2h] = b_lo ^ b_hi. The rest of the u64
        // scratch is handed to the (sequential) recursive calls.
        let (axbx, scratch_u64_rest) = scratch_u64.split_at_mut(2 * h);
        {
            let chunks = h / 4;
            for k in 0..chunks {
                let p = k * 4;
                let av = _mm256_xor_si256(
                    _mm256_loadu_si256(a_lo.as_ptr().add(p) as *const __m256i),
                    _mm256_loadu_si256(a_hi.as_ptr().add(p) as *const __m256i),
                );
                _mm256_storeu_si256(axbx.as_mut_ptr().add(p) as *mut __m256i, av);
                let bv = _mm256_xor_si256(
                    _mm256_loadu_si256(b_lo.as_ptr().add(p) as *const __m256i),
                    _mm256_loadu_si256(b_hi.as_ptr().add(p) as *const __m256i),
                );
                _mm256_storeu_si256(axbx.as_mut_ptr().add(h + p) as *mut __m256i, bv);
            }
            for k in chunks * 4..h {
                axbx[k] = a_lo[k] ^ a_hi[k];
                axbx[h + k] = b_lo[k] ^ b_hi[k];
            }
        }

        let len_sub = 2 * h - 1;
        // m_xx_buf holds the middle product; the remainder is scratch for the sub-calls.
        let (m_xx_buf, m_xx_sub_scratch) = scratch_m128.split_at_mut(len_sub);
        // m_ll = a_lo * b_lo -> out[0..len_sub]
        karatsuba_into(
            a_lo,
            b_lo,
            &mut out[..len_sub],
            scratch_u64_rest,
            m_xx_sub_scratch,
        );
        // m_hh = a_hi * b_hi -> out[2h..2h + len_sub]
        karatsuba_into(
            a_hi,
            b_hi,
            &mut out[2 * h..2 * h + len_sub],
            scratch_u64_rest,
            m_xx_sub_scratch,
        );
        // m_xx = (a_lo ^ a_hi) * (b_lo ^ b_hi) -> m_xx_buf
        let (a_xor_slice, b_xor_slice) = axbx.split_at(h);
        karatsuba_into(
            a_xor_slice,
            b_xor_slice,
            m_xx_buf,
            scratch_u64_rest,
            m_xx_sub_scratch,
        );
        // out[2h - 1] lies between the m_ll and m_hh ranges and is written by neither call.
        out[len_sub] = _mm_setzero_si128();

        // mid = m_xx ^ m_ll ^ m_hh (over GF(2) this is a_lo*b_hi ^ a_hi*b_lo), in m_xx_buf.
        {
            let chunks = len_sub / 2;
            let mp = m_xx_buf.as_mut_ptr();
            let lp = out.as_ptr(); // m_ll at offset 0
            let hp = out.as_ptr().add(2 * h); // m_hh at offset 2h
            for k in 0..chunks {
                let q = k * 2;
                let mv = _mm256_loadu_si256(mp.add(q) as *const __m256i);
                let lv = _mm256_loadu_si256(lp.add(q) as *const __m256i);
                let hv = _mm256_loadu_si256(hp.add(q) as *const __m256i);
                let r = _mm256_xor_si256(_mm256_xor_si256(mv, lv), hv);
                _mm256_storeu_si256(mp.add(q) as *mut __m256i, r);
            }
            for k in chunks * 2..len_sub {
                m_xx_buf[k] =
                    _mm_xor_si128(_mm_xor_si128(m_xx_buf[k], *lp.add(k)), *hp.add(k));
            }
        }

        // out[h..h + len_sub] ^= mid
        {
            let chunks = len_sub / 2;
            let op = out.as_mut_ptr().add(h);
            let mp = m_xx_buf.as_ptr();
            for k in 0..chunks {
                let q = k * 2;
                let ov = _mm256_loadu_si256(op.add(q) as *const __m256i);
                let mv = _mm256_loadu_si256(mp.add(q) as *const __m256i);
                _mm256_storeu_si256(op.add(q) as *mut __m256i, _mm256_xor_si256(ov, mv));
            }
            for k in chunks * 2..len_sub {
                let cell = op.add(k);
                *cell = _mm_xor_si128(*cell, *mp.add(k));
            }
        }
    }
}
/// Direct (schoolbook) carry-less convolution of `a` and `b_in`, written into `out`.
///
/// For every `k < a.len() + b_in.len() - 1`, `out[k]` is set to the XOR of the
/// 128-bit carry-less products `a[i] * b_in[j]` over `i + j == k`. Each loop iteration
/// produces two adjacent outputs, one in each 128-bit lane of a 256-bit VPCLMULQDQ.
/// Does nothing if either input is empty.
///
/// # Panics
/// If `out` is shorter than `a.len() + b_in.len() - 1`.
///
/// # Safety
/// The running CPU must support the features enabled by `#[target_feature]`.
#[target_feature(enable = "vpclmulqdq,avx2,pclmulqdq,sse2")]
unsafe fn convolve_direct_into(a: &[u64], b_in: &[u64], out: &mut [__m128i]) {
    unsafe {
        let n = a.len();
        let m = b_in.len();
        if n == 0 || m == 0 {
            // Empty product. Without this guard, n == 0 would still store one cell.
            return;
        }
        let out_len = n + m - 1;
        // All stores below go through raw pointers, so check the bound once up front.
        assert!(
            out.len() >= out_len,
            "convolve_direct_into: output buffer too short"
        );

        // `b_in` widened to one coefficient per __m128i, with zero cells on both sides.
        // The 2-cell window loaded for the output pair (k0, k0 + 1) then stays inside
        // `bz`, and out-of-range terms contribute zero.
        const PAD_FRONT: usize = 2;
        const PAD_BACK: usize = 4;
        let bz_len = PAD_FRONT + m + PAD_BACK;
        let mut bz: Vec<__m128i> = vec![_mm_setzero_si128(); bz_len];
        for (cell, &coef) in bz[PAD_FRONT..PAD_FRONT + m].iter_mut().zip(b_in) {
            *cell = _mm_cvtsi64_si128(coef as i64);
        }
        let bz_ptr = bz.as_ptr();
        let aptr = a.as_ptr();
        let optr = out.as_mut_ptr();

        let num_blocks = (out_len + 1) / 2;
        for g in 0..num_blocks {
            let k0 = 2 * g;
            let mut acc = _mm256_setzero_si256();
            // Indices i of `a` that contribute to output k0 or k0 + 1.
            let i_lo = (k0 + 1).saturating_sub(m);
            let i_hi = n.min(k0 + 2);
            // `bp` addresses the bz cell of b[k0 - i]. The 256-bit window covers
            // b[k0 - i] (lane 0, output k0) and b[k0 + 1 - i] (lane 1, output k0 + 1).
            let s0 = PAD_FRONT + k0 - i_lo;
            let mut bp = bz_ptr.add(s0);
            let mut ap = aptr.add(i_lo);
            let a_end = aptr.add(i_hi);
            // `bz` only guarantees 16-byte alignment (element type __m128i) and `bp` moves
            // in 16-byte steps. The 256-bit windows must therefore use unaligned loads;
            // a typed `*const __m256i` dereference would assume 32-byte alignment.
            let chunks = (i_hi - i_lo) / 4;
            for _ in 0..chunks {
                let ai0 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap as i64));
                let ai1 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap.add(1) as i64));
                let ai2 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap.add(2) as i64));
                let ai3 = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap.add(3) as i64));
                let p0 = _mm256_clmulepi64_epi128(
                    ai0,
                    _mm256_loadu_si256(bp as *const __m256i),
                    0x00,
                );
                let p1 = _mm256_clmulepi64_epi128(
                    ai1,
                    _mm256_loadu_si256(bp.sub(1) as *const __m256i),
                    0x00,
                );
                let p2 = _mm256_clmulepi64_epi128(
                    ai2,
                    _mm256_loadu_si256(bp.sub(2) as *const __m256i),
                    0x00,
                );
                let p3 = _mm256_clmulepi64_epi128(
                    ai3,
                    _mm256_loadu_si256(bp.sub(3) as *const __m256i),
                    0x00,
                );
                let s01 = _mm256_xor_si256(p0, p1);
                let s23 = _mm256_xor_si256(p2, p3);
                acc = _mm256_xor_si256(acc, _mm256_xor_si256(s01, s23));
                ap = ap.add(4);
                bp = bp.sub(4);
            }
            while ap < a_end {
                let ai = _mm256_broadcastsi128_si256(_mm_cvtsi64_si128(*ap as i64));
                let prod = _mm256_clmulepi64_epi128(
                    ai,
                    _mm256_loadu_si256(bp as *const __m256i),
                    0x00,
                );
                acc = _mm256_xor_si256(acc, prod);
                ap = ap.add(1);
                bp = bp.sub(1);
            }
            // Store both outputs, except for the final block of an odd-length result,
            // where only out[k0] exists.
            if k0 + 1 < out_len {
                _mm256_storeu_si256(optr.add(k0) as *mut __m256i, acc);
            } else {
                _mm_storeu_si128(optr.add(k0), _mm256_castsi256_si128(acc));
            }
        }
    }
}
/// Reduces a 128-bit carry-less product modulo x^64 + x^4 + x^3 + x + 1.
///
/// The high word is folded down using x^64 ≡ x^4 + x^3 + x + 1. The folding shifts push
/// at most four bits past bit 63; those bits are folded once more, which always fits in
/// 64 bits.
#[inline(always)]
unsafe fn reduce(p: __m128i) -> u64 {
    let [lo, hi]: [u64; 2] = std::mem::transmute(p);
    let fold = |w: u64| w ^ (w << 1) ^ (w << 3) ^ (w << 4);
    // Bits of `hi` that the shifts by 1, 3 and 4 move beyond the top of the word.
    let spill = (hi >> 60) ^ (hi >> 61) ^ (hi >> 63);
    lo ^ fold(hi) ^ fold(spill)
}
02
jobs
Systems
02 jobs
03
counters
Performance counters
18 counters
cycles: 9,006,680
branch_instructions: 2,302,725
branch_misses: 21,575
dtlb_load_misses.walk_completed: 398
instructions: 27,630,097
mem_load_retired.l1_miss: 31,001
mem_load_retired.l2_miss: 1,269
mem_load_retired.l3_miss: 440
tma_backend_bound: 13,868,508
tma_bad_speculation: 3,090,909
tma_retiring: 30,837,733
uops_dispatched.port_0: 3,289,878
uops_dispatched.port_1: 3,206,296
uops_dispatched.port_2_3_10: 9,744,810
uops_dispatched.port_4_9: 1,427,629
uops_dispatched.port_5_11: 7,535,477
uops_dispatched.port_6: 4,030,608
uops_dispatched.port_7_8: 1,379,614
04
top down
Top-down analysis
Raptor Cove P-core
05
profile
Profile
not available
No profile is available for this job.
03
counters
Performance counters
17 counters
cycles: 13,871,220
branch_instructions: 2,305,637
branch_misses: 21,707
dtlb_load_misses.walk_completed: 3,151
instructions: 27,644,791
mem_bound_stalls.load_dram_hit: 104,538
mem_bound_stalls.load_l2_hit: 199,480
mem_bound_stalls.load_llc_hit: 32,135
mem_inst_retired.split_loads: 4,818
mem_load_retired.l1_miss: 228,385
mem_load_retired.l2_miss: 1,762
mem_load_retired.l3_miss: 868
tma_backend_bound: 19,389,137
tma_bad_speculation: 3,594,626
tma_frontend_bound: 5,672,840
tma_memory_bound: 2,596,235
tma_retiring: 40,742,603
04
top down
Top-down analysis
Gracemont E-core
05
profile
load profile