Skip to content

Commit 81ae18a

Browse files
authored
Use the groestl-style T-table lookup for fast kupyna hashing (#778)
A lot of the galois-field ops were running on each and every column load. I didn't do anything clever or original, aside from understanding how the groestl implementation worked, and shamelessly copying as much from that as possible. Once that cleaned out a big pile of intermediate functions, I picked some low-hanging fruit with the new `const fn compute_src_cols`. Moving all the field xor ops into a groestl-esque T-table yields a significant 8x speedup. One weird issue I hit was on the long hashes: LLVM stops automatically unrolling the loops once they're longer than 8 iterations, which caused a major dropoff in throughput. I could've manually unrolled the whole thing, but chunking the loop into two separate 8-iteration loops (for the 16-column state) gives me equivalent compiled code, as LLVM will unroll both, even though it's a bit silly to read. Benchmarks before this PR: ``` running 8 tests test kupyna256_10 ... bench: 276.20 ns/iter (+/- 5.70) = 36 MB/s test kupyna256_100 ... bench: 2,761.53 ns/iter (+/- 44.17) = 36 MB/s test kupyna256_1000 ... bench: 27,618.72 ns/iter (+/- 235.44) = 36 MB/s test kupyna256_10000 ... bench: 274,159.70 ns/iter (+/- 3,121.12) = 36 MB/s test kupyna512_10 ... bench: 391.96 ns/iter (+/- 21.09) = 25 MB/s test kupyna512_100 ... bench: 3,915.42 ns/iter (+/- 142.03) = 25 MB/s test kupyna512_1000 ... bench: 39,039.71 ns/iter (+/- 182.78) = 25 MB/s test kupyna512_10000 ... bench: 389,250.45 ns/iter (+/- 1,132.40) = 25 MB/s ``` After swapping the galois-field xor, subbytes loading, and shifting with a quick lookup in the T-table: ``` running 8 tests test kupyna256_10 ... bench: 38.17 ns/iter (+/- 3.66) = 263 MB/s test kupyna256_100 ... bench: 370.85 ns/iter (+/- 5.36) = 270 MB/s test kupyna256_1000 ... bench: 3,568.39 ns/iter (+/- 901.09) = 280 MB/s test kupyna256_10000 ... bench: 36,356.67 ns/iter (+/- 1,409.35) = 275 MB/s test kupyna512_10 ... bench: 48.18 ns/iter (+/- 0.31) = 208 MB/s test kupyna512_100 ... 
bench: 462.74 ns/iter (+/- 21.66) = 216 MB/s test kupyna512_1000 ... bench: 4,738.49 ns/iter (+/- 250.94) = 211 MB/s test kupyna512_10000 ... bench: 47,398.92 ns/iter (+/- 1,137.88) = 210 MB/s ``` Then with the `const fn compute_src_cols` precomputing the source columns at compile time instead of computing them at runtime: ``` running 8 tests test kupyna256_10 ... bench: 33.61 ns/iter (+/- 0.35) = 303 MB/s test kupyna256_100 ... bench: 338.50 ns/iter (+/- 24.91) = 295 MB/s test kupyna256_1000 ... bench: 3,369.35 ns/iter (+/- 85.24) = 296 MB/s test kupyna256_10000 ... bench: 32,791.09 ns/iter (+/- 315.57) = 304 MB/s test kupyna512_10 ... bench: 45.69 ns/iter (+/- 2.19) = 222 MB/s test kupyna512_100 ... bench: 442.11 ns/iter (+/- 26.73) = 226 MB/s test kupyna512_1000 ... bench: 4,370.52 ns/iter (+/- 55.67) = 228 MB/s test kupyna512_10000 ... bench: 43,252.35 ns/iter (+/- 252.63) = 231 MB/s ``` Overall a ~8.5x speedup for the 256 hashes, and >9x for the 512!
1 parent a478ace commit 81ae18a

5 files changed

Lines changed: 196 additions & 155 deletions

File tree

kupyna/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pub mod block_api;
1515
mod consts;
1616
mod long;
1717
mod short;
18+
mod table;
1819
pub(crate) mod utils;
1920

2021
use digest::consts::{U28, U32, U48, U64};

kupyna/src/long.rs

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,88 @@
1-
use crate::utils::{
2-
add_constant_plus, add_constant_xor, apply_s_box, mix_columns, read_u64s_be, xor,
3-
};
4-
use core::array;
1+
use crate::table::TABLE;
2+
use crate::utils::{read_u64s_be, xor};
53

64
pub(crate) const COLS: usize = 16;
7-
const ROUNDS: u64 = 14;
5+
const ROUNDS: usize = 14;
6+
7+
// Bit shift amounts to extract each byte from a u64
8+
const BYTE_SHIFTS: [usize; 8] = [56, 48, 40, 32, 24, 16, 8, 0];
9+
10+
// ShiftRows offsets for long variant: rows 0-6 shift by index, row 7 shifts by 11
11+
const SHIFTS: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 11];
12+
13+
// Precomputed source columns: SRC_COLS[col][row] = (col + COLS - SHIFTS[row]) % COLS
14+
const fn compute_src_cols() -> [[usize; 8]; COLS] {
15+
let mut result = [[0; 8]; COLS];
16+
let mut col = 0;
17+
while col < COLS {
18+
let mut row = 0;
19+
while row < 8 {
20+
result[col][row] = (col + COLS - SHIFTS[row]) % COLS;
21+
row += 1;
22+
}
23+
col += 1;
24+
}
25+
result
26+
}
27+
const SRC_COLS: [[usize; 8]; COLS] = compute_src_cols();
828

929
pub(crate) fn compress(prev_vector: &mut [u64; COLS], message_block: &[u8; 128]) {
10-
// Convert message block from u8 to u64 (column-major order as per paper)
1130
let message_u64 = read_u64s_be::<128, COLS>(message_block);
1231
let m_xor_p = xor(*prev_vector, message_u64);
1332
let t_xor_mp = t_xor_l(m_xor_p);
1433
let t_plus_m = t_plus_l(message_u64);
1534
*prev_vector = xor(xor(t_xor_mp, t_plus_m), *prev_vector);
1635
}
1736

18-
fn t_plus_l(state: [u64; COLS]) -> [u64; COLS] {
19-
let mut state = state;
20-
for nu in 0..ROUNDS {
21-
add_constant_plus(&mut state, nu as usize);
22-
apply_s_box(&mut state);
23-
state = rotate_rows(state);
24-
mix_columns(&mut state);
37+
/// Compute one output column using T-table lookups
38+
#[inline(always)]
39+
fn column(x: &[u64; COLS], col: usize) -> u64 {
40+
let mut t = 0u64;
41+
for row in 0..8 {
42+
let byte = ((x[SRC_COLS[col][row]] >> BYTE_SHIFTS[row]) & 0xFF) as usize;
43+
t ^= TABLE[row][byte];
2544
}
26-
state
45+
t
2746
}
2847

29-
fn rotate_rows(state: [u64; COLS]) -> [u64; COLS] {
30-
//shift amounts for each row (0-6: row index, 7: special case = 11)
31-
const SHIFTS: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 11];
32-
33-
array::from_fn(|col| {
34-
let rotated_bytes = array::from_fn(|row| {
35-
let shift = SHIFTS[row];
36-
let src_col = (col + COLS - shift) % COLS;
37-
let src_bytes = state[src_col].to_be_bytes();
38-
src_bytes[row]
39-
});
40-
u64::from_be_bytes(rotated_bytes)
41-
})
48+
fn t_plus_l(mut state: [u64; COLS]) -> [u64; COLS] {
49+
for round in 0..ROUNDS {
50+
// AddConstantPlus
51+
for (i, word) in state.iter_mut().enumerate() {
52+
*word = word
53+
.swap_bytes()
54+
.wrapping_add(
55+
0x00F0F0F0F0F0F0F3u64 ^ (((((COLS - i - 1) * 0x10) ^ round) as u64) << 56),
56+
)
57+
.swap_bytes();
58+
}
59+
// Fused SubBytes + ShiftRows + MixColumns via T-tables
60+
let prev = state;
61+
for (col, slot) in state[..8].iter_mut().enumerate() {
62+
*slot = column(&prev, col);
63+
}
64+
for (col, slot) in state[8..].iter_mut().enumerate() {
65+
*slot = column(&prev, col + 8);
66+
}
67+
}
68+
state
4269
}
4370

44-
pub(crate) fn t_xor_l(state: [u64; COLS]) -> [u64; COLS] {
45-
let mut state = state;
46-
for nu in 0..ROUNDS {
47-
add_constant_xor(&mut state, nu as usize);
48-
apply_s_box(&mut state);
49-
state = rotate_rows(state);
50-
mix_columns(&mut state);
71+
pub(crate) fn t_xor_l(mut state: [u64; COLS]) -> [u64; COLS] {
72+
for round in 0..ROUNDS {
73+
// AddConstantXor
74+
for (i, word) in state.iter_mut().enumerate() {
75+
let constant = ((i * 0x10) ^ round) as u64;
76+
*word ^= constant << 56;
77+
}
78+
// Fused SubBytes + ShiftRows + MixColumns via T-tables
79+
let prev = state;
80+
for (col, slot) in state[..8].iter_mut().enumerate() {
81+
*slot = column(&prev, col);
82+
}
83+
for (col, slot) in state[8..].iter_mut().enumerate() {
84+
*slot = column(&prev, col + 8);
85+
}
5186
}
5287
state
5388
}

kupyna/src/short.rs

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,80 @@
1-
use crate::utils::{
2-
add_constant_plus, add_constant_xor, apply_s_box, mix_columns, read_u64s_be, xor,
3-
};
4-
use core::array;
1+
use crate::table::TABLE;
2+
use crate::utils::{read_u64s_be, xor};
53

64
pub(crate) const COLS: usize = 8;
7-
const ROUNDS: u64 = 10;
5+
const ROUNDS: usize = 10;
6+
7+
// Bit shift amounts to extract each byte from a u64
8+
const BYTE_SHIFTS: [usize; 8] = [56, 48, 40, 32, 24, 16, 8, 0];
9+
10+
// Precomputed source columns: SRC_COLS[col][row] = (col + COLS - row) % COLS
11+
// ShiftRows for short variant: row i shifts by i positions
12+
const fn compute_src_cols() -> [[usize; 8]; COLS] {
13+
let mut result = [[0; 8]; COLS];
14+
let mut col = 0;
15+
while col < COLS {
16+
let mut row = 0;
17+
while row < 8 {
18+
result[col][row] = (col + COLS - row) % COLS;
19+
row += 1;
20+
}
21+
col += 1;
22+
}
23+
result
24+
}
25+
const SRC_COLS: [[usize; 8]; COLS] = compute_src_cols();
826

927
pub(crate) fn compress(prev_vector: &mut [u64; COLS], message_block: &[u8; 64]) {
10-
// Convert message block from u8 to u64 (column-major order as per paper)
1128
let message_u64 = read_u64s_be::<64, COLS>(message_block);
1229
let m_xor_p = xor(*prev_vector, message_u64);
1330
let t_xor_mp = t_xor_l(m_xor_p);
1431
let t_plus_m = t_plus_l(message_u64);
1532
*prev_vector = xor(xor(t_xor_mp, t_plus_m), *prev_vector);
1633
}
1734

18-
fn t_plus_l(state: [u64; COLS]) -> [u64; COLS] {
19-
let mut state = state;
20-
for nu in 0..ROUNDS {
21-
add_constant_plus(&mut state, nu as usize);
22-
apply_s_box(&mut state);
23-
state = rotate_rows(state);
24-
mix_columns(&mut state);
35+
/// Compute one output column using T-table lookups
36+
#[inline(always)]
37+
fn column(x: &[u64; COLS], col: usize) -> u64 {
38+
let mut t = 0u64;
39+
for row in 0..8 {
40+
let byte = ((x[SRC_COLS[col][row]] >> BYTE_SHIFTS[row]) & 0xFF) as usize;
41+
t ^= TABLE[row][byte];
2542
}
26-
state
43+
t
2744
}
2845

29-
fn rotate_rows(state: [u64; COLS]) -> [u64; COLS] {
30-
//shift amounts for each row (0-6: row index, 7: special case)
31-
const SHIFTS: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
32-
33-
array::from_fn(|col| {
34-
let rotated_bytes = array::from_fn(|row| {
35-
let shift = SHIFTS[row];
36-
let src_col = (col + COLS - shift) % COLS;
37-
let src_bytes = state[src_col].to_be_bytes();
38-
src_bytes[row]
39-
});
40-
u64::from_be_bytes(rotated_bytes)
41-
})
46+
fn t_plus_l(mut state: [u64; COLS]) -> [u64; COLS] {
47+
for round in 0..ROUNDS {
48+
// AddConstantPlus
49+
for (i, word) in state.iter_mut().enumerate() {
50+
*word = word
51+
.swap_bytes()
52+
.wrapping_add(
53+
0x00F0F0F0F0F0F0F3u64 ^ (((((COLS - i - 1) * 0x10) ^ round) as u64) << 56),
54+
)
55+
.swap_bytes();
56+
}
57+
// Fused SubBytes + ShiftRows + MixColumns via T-tables
58+
let prev = state;
59+
for (col, slot) in state.iter_mut().enumerate() {
60+
*slot = column(&prev, col);
61+
}
62+
}
63+
state
4264
}
4365

44-
pub(crate) fn t_xor_l(state: [u64; COLS]) -> [u64; COLS] {
45-
let mut state = state;
46-
for nu in 0..ROUNDS {
47-
add_constant_xor(&mut state, nu as usize);
48-
apply_s_box(&mut state);
49-
state = rotate_rows(state);
50-
mix_columns(&mut state);
66+
pub(crate) fn t_xor_l(mut state: [u64; COLS]) -> [u64; COLS] {
67+
for round in 0..ROUNDS {
68+
// AddConstantXor
69+
for (i, word) in state.iter_mut().enumerate() {
70+
let constant = ((i * 0x10) ^ round) as u64;
71+
*word ^= constant << 56;
72+
}
73+
// Fused SubBytes + ShiftRows + MixColumns via T-tables
74+
let prev = state;
75+
for (col, slot) in state.iter_mut().enumerate() {
76+
*slot = column(&prev, col);
77+
}
5178
}
5279
state
5380
}

kupyna/src/table.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use crate::consts::{MDS_MATRIX, SBOXES};
2+
3+
/// GF(2^8) multiplication with reduction polynomial x^8 + x^4 + x^3 + x^2 + 1
4+
const fn gf_multiply(mut x: u8, mut y: u8) -> u8 {
5+
const REDUCTION_POLYNOMIAL: u8 = 0x1d; // x^4 + x^3 + x^2 + 1
6+
7+
let mut r = 0u8;
8+
let mut i = 0;
9+
while i < 8 {
10+
if y & 1 == 1 {
11+
r ^= x;
12+
}
13+
let hbit = x & 0x80;
14+
x <<= 1;
15+
if hbit != 0 {
16+
x ^= REDUCTION_POLYNOMIAL;
17+
}
18+
y >>= 1;
19+
i += 1;
20+
}
21+
r
22+
}
23+
24+
/// Generate T-tables that fuse SubBytes + MixColumns
25+
///
26+
/// TABLE[row][byte] gives the contribution to an output column when input byte
27+
/// at position `row` has value `byte`, after applying S-box and MDS multiplication.
28+
const fn generate_t_table() -> [[u64; 256]; 8] {
29+
let mut table = [[0u64; 256]; 8];
30+
31+
let mut row = 0;
32+
while row < 8 {
33+
let mut byte = 0;
34+
while byte < 256 {
35+
// Apply S-box for this row position (S-boxes cycle with period 4)
36+
let s = SBOXES[row % 4][byte];
37+
38+
// Compute contribution to each output row via MDS multiplication
39+
let mut out = [0u8; 8];
40+
let mut out_row = 0;
41+
while out_row < 8 {
42+
// Extract MDS coefficient: MDS_MATRIX[out_row] byte at position `row`
43+
let mds_coef = (MDS_MATRIX[out_row] >> (8 * (7 - row))) as u8;
44+
out[out_row] = gf_multiply(mds_coef, s);
45+
out_row += 1;
46+
}
47+
48+
// Pack into u64 (big-endian)
49+
table[row][byte] = ((out[0] as u64) << 56)
50+
| ((out[1] as u64) << 48)
51+
| ((out[2] as u64) << 40)
52+
| ((out[3] as u64) << 32)
53+
| ((out[4] as u64) << 24)
54+
| ((out[5] as u64) << 16)
55+
| ((out[6] as u64) << 8)
56+
| (out[7] as u64);
57+
58+
byte += 1;
59+
}
60+
row += 1;
61+
}
62+
table
63+
}
64+
65+
pub(crate) static TABLE: [[u64; 256]; 8] = generate_t_table();

0 commit comments

Comments
 (0)