Tree Hash Traversal
This is the hot kernel from Anthropic's original performance take-home, without the Python VM or simulated VLIW machine.
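The input is a binary stream of little-endian u32 words. It starts with a seven-word header (rounds, node_count, batch_size, tree_height, forest_values_offset, indices_offset, values_offset), followed by the forest values, the current node indices, and the current input values. The reference implementation treats each offset as a word index into the whole input, header included.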
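To try the kernel on a small hand-made case, an input in this layout can be assembled with a few lines of Rust. The sketch below is illustrative only: `build_input` is a hypothetical helper, and it assumes the three arrays are packed back to back directly after the header, which the format permits but does not require.

```rust
/// Hypothetical helper (not part of the reference code): packs a header and
/// the three arrays into a little-endian u32 byte stream.
/// Assumption: sections are laid out contiguously right after the header.
/// Offsets are word indices into the whole stream, header included.
fn build_input(
    rounds: u32,
    tree_height: u32,
    forest: &[u32],
    indices: &[u32],
    values: &[u32],
) -> Vec<u8> {
    assert_eq!(indices.len(), values.len(), "one index and one value per lane");
    const HEADER_WORDS: u32 = 7;
    let node_count = forest.len() as u32;
    let batch_size = indices.len() as u32;
    let forest_values_offset = HEADER_WORDS;
    let indices_offset = forest_values_offset + node_count;
    let values_offset = indices_offset + batch_size;
    // Header order follows the field list in the description.
    let header = [
        rounds,
        node_count,
        batch_size,
        tree_height,
        forest_values_offset,
        indices_offset,
        values_offset,
    ];
    let mut bytes = Vec::with_capacity(4 * (header.len() + forest.len() + 2 * indices.len()));
    for word in header.iter().chain(forest).chain(indices).chain(values) {
        bytes.extend_from_slice(&word.to_le_bytes());
    }
    bytes
}
```

For example, `build_input(2, 1, &[10, 20, 30], &[0, 0], &[5, 6])` yields a 56-byte stream whose header places the forest at word 7, the indices at word 10, and the values at word 12.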
The benchmark uses tree_height = 17, rounds = 128, and batch_size = 262,144. Values are 32-bit words and arithmetic wraps modulo 2^32.
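As a small illustration of what wrapping modulo 2^32 means in practice, the sketch below computes one add-and-shift step twice: once with Rust's wrapping u32 operations and once in u64 with an explicit reduction. The function names are hypothetical; the constant and shift amount are borrowed from the first stage of `hash32` in the reference code.

```rust
/// One add-and-shift step using wrapping u32 arithmetic.
/// `x << 12` on a u32 silently discards the bits shifted past bit 31.
fn stage_wrapping(x: u32) -> u32 {
    x.wrapping_add(0x7ed5_5d16).wrapping_add(x << 12)
}

/// The same step computed in u64 and reduced modulo 2^32 explicitly.
/// No intermediate can overflow u64: the largest term is below 2^44.
fn stage_wide(x: u32) -> u32 {
    let x = u64::from(x);
    ((x + 0x7ed5_5d16 + (x << 12)) % (1u64 << 32)) as u32
}
```

Both functions agree for every u32 input, which is why the reference code can stay in u32 and rely on `wrapping_add` instead of widening.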
For each round and each lane, compute value = hash32(value xor forest[index]), then descend to child 2*index+1 if the new value is even and to child 2*index+2 otherwise. If the child index is at or beyond node_count, that is, past the end of the perfect binary tree, reset index to 0.
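The child-selection and wrap rule can be isolated from the hash for a quick sanity check. The following is a minimal sketch; `next_index` is a hypothetical name, and it takes the already hashed value as input rather than calling `hash32` itself.

```rust
/// Child selection and wrap rule for one lane.
/// `hashed_value` is assumed to be the result of hash32(value ^ forest[index]).
fn next_index(index: usize, hashed_value: u32, node_count: usize) -> usize {
    // Even values go to the left child, odd values to the right child.
    let child = 2 * index + if hashed_value % 2 == 0 { 1 } else { 2 };
    // A child at or beyond node_count does not exist, so restart at the root.
    if child >= node_count {
        0
    } else {
        child
    }
}
```

In a seven-node tree, the root (index 0) moves to node 1 on an even value and node 2 on an odd value, while any leaf (indices 3 to 6) wraps back to 0 regardless of parity.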
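Print the u64 checksum of the final (index, value) batch state in decimal. The reference implementation folds the lanes in order and prints the result without a trailing newline.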
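The checksum fold can be sketched as a standalone function over two slices. This mirrors the constants and bit layout used by the reference code; the name `checksum` and the slice-based signature are assumptions made for illustration. The two constants coincide with the 64-bit FNV-1a offset basis and prime, although here each step consumes one 64-bit pair rather than one byte.

```rust
/// Folds the final batch state into a u64, lane by lane.
/// Each lane contributes one word: value in the high 32 bits, index in the low 32 bits.
fn checksum(indices: &[u32], values: &[u32]) -> u64 {
    let mut acc = 0xcbf2_9ce4_8422_2325_u64;
    for (&index, &value) in indices.iter().zip(values) {
        let pair = (u64::from(value) << 32) | u64::from(index);
        acc ^= pair;
        acc = acc.wrapping_mul(0x0000_0100_0000_01b3);
    }
    acc
}
```

Because each step xors and then multiplies, the result depends on lane order, so an optimized implementation that processes lanes out of order still has to fold them in their original order.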
Source: https://github.com/anthropics/original_performance_takehome.
use std::io::Read;

fn main() {
    // Read all of stdin and decode it as little-endian u32 words.
    let mut input = Vec::new();
    std::io::stdin().read_to_end(&mut input).unwrap();
    let mut words = Vec::with_capacity(input.len() / 4);
    for chunk in input.chunks_exact(4) {
        words.push(u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
    }

    // Header fields. tree_height (words[3]) is not needed by the kernel.
    let rounds = words[0] as usize;
    let node_count = words[1] as usize;
    let batch_size = words[2] as usize;
    let forest_values_offset = words[4] as usize;
    let indices_offset = words[5] as usize;
    let values_offset = words[6] as usize;

    // Traversal: every lane takes one hash-and-descend step per round.
    for _ in 0..rounds {
        for i in 0..batch_size {
            let index = words[indices_offset + i] as usize;
            let value = hash32(words[values_offset + i] ^ words[forest_values_offset + index]);
            let mut next = index * 2 + if value & 1 == 0 { 1 } else { 2 };
            if next >= node_count {
                next = 0;
            }
            words[values_offset + i] = value;
            words[indices_offset + i] = next as u32;
        }
    }

    // Checksum: fold each lane as (value << 32) | index, in lane order.
    let mut acc = 0xcbf2_9ce4_8422_2325_u64;
    for i in 0..batch_size {
        let pair = ((words[values_offset + i] as u64) << 32) | u64::from(words[indices_offset + i]);
        acc ^= pair;
        acc = acc.wrapping_mul(0x0000_0100_0000_01b3);
    }
    print!("{acc}");
}
fn hash32(mut a: u32) -> u32 {
    let x = a;
    a = x.wrapping_add(0x7ed5_5d16).wrapping_add(x << 12);
    let x = a;
    a = (x ^ 0xc761_c23c) ^ (x >> 19);
    let x = a;
    a = x.wrapping_add(0x1656_67b1).wrapping_add(x << 5);
    let x = a;
    a = x.wrapping_add(0xd3a2_646c) ^ (x << 9);
    let x = a;
    a = x.wrapping_add(0xfd70_46c5).wrapping_add(x << 3);
    let x = a;
    (x ^ 0xb55a_4f09) ^ (x >> 16)
}