diff --git a/src/oprf/vole_oprf.rs b/src/oprf/vole_oprf.rs index cdbd9f0..c1f5159 100644 --- a/src/oprf/vole_oprf.rs +++ b/src/oprf/vole_oprf.rs @@ -116,11 +116,177 @@ fn ct_eq(a: &[u8], b: &[u8]) -> bool { /// Constant-time less-than comparison for positive values #[inline] fn ct_lt(a: i64, b: i64) -> Choice { - // (a - b) will have high bit set if a < b (for positive values where a-b doesn't overflow) let diff = (a as i128) - (b as i128); Choice::from((diff >> 127) as u8 & 1) } +/// Constant-time check if value is negative +#[inline] +fn ct_is_negative(val: i64) -> Choice { + Choice::from((val >> 63) as u8 & 1) +} + +/// Constant-time normalization: if val < 0, return val + q; else return val +#[inline] +fn ct_normalize(val: i64, q: i64) -> i64 { + let is_neg = ct_is_negative(val); + ct_select(val, val + q, is_neg) +} + +// ============================================================================ +// NTT (Number Theoretic Transform) for O(n log n) Polynomial Multiplication +// ============================================================================ +// +// For q = 65537 = 2^16 + 1 (Fermat prime), we use: +// - ψ = primitive 512th root of unity (since we need negacyclic: ψ^256 = -1 mod q) +// - ψ = 3^128 mod 65537 = 256 (precomputed) +// - ψ^(-1) = 65281 mod 65537 +// - n^(-1) = 256^(-1) mod 65537 = 65281 (since 256 * 65281 ≡ 1 mod 65537) + +const NTT_PSI: i64 = 256; +const NTT_PSI_INV: i64 = 65281; +const NTT_N_INV: i64 = 65281; + +/// Modular exponentiation: base^exp mod m +#[inline] +fn mod_pow(base: i64, mut exp: u32, m: i64) -> i64 { + let mut result = 1i64; + let mut base = base % m; + while exp > 0 { + if exp & 1 == 1 { + result = ((result as i128 * base as i128) % m as i128) as i64; + } + exp >>= 1; + base = ((base as i128 * base as i128) % m as i128) as i64; + } + result +} + +/// Bit-reverse permutation for NTT +#[inline] +fn bit_reverse(mut x: usize, log_n: u32) -> usize { + let mut result = 0; + for _ in 0..log_n { + result = (result << 1) | (x & 1); + x >>= 1; + } + result +} + +/// Precompute powers of ψ for forward NTT (twisted for negacyclic) +fn compute_ntt_twiddles() -> [i64; VOLE_RING_N] { + let mut twiddles = [0i64; VOLE_RING_N]; + twiddles[0] = 1; + for i in 1..VOLE_RING_N { + twiddles[i] = ((twiddles[i - 1] as i128 * NTT_PSI as i128) % VOLE_Q as i128) as i64; + } + twiddles +} + +/// Precompute powers of ψ^(-1) for inverse NTT +fn compute_intt_twiddles() -> [i64; VOLE_RING_N] { + let mut twiddles = [0i64; VOLE_RING_N]; + twiddles[0] = 1; + for i in 1..VOLE_RING_N { + twiddles[i] = ((twiddles[i - 1] as i128 * NTT_PSI_INV as i128) % VOLE_Q as i128) as i64; + } + twiddles +} + +/// Forward NTT (Cooley-Tukey, in-place, iterative) +/// Transforms polynomial to NTT domain for O(n log n) multiplication +fn ntt_forward(coeffs: &mut [i64; VOLE_RING_N]) { + let twiddles = compute_ntt_twiddles(); + let log_n = 8u32; // log2(256) = 8 + + // Bit-reverse permutation + for i in 0..VOLE_RING_N { + let j = bit_reverse(i, log_n); + if i < j { + coeffs.swap(i, j); + } + } + + // Pre-multiply by powers of ψ for negacyclic convolution + for i in 0..VOLE_RING_N { + coeffs[i] = ((coeffs[i] as i128 * twiddles[i] as i128) % VOLE_Q as i128) as i64; + } + + // Cooley-Tukey butterfly + let mut m = 1; + for _ in 0..log_n { + let w_m = mod_pow(NTT_PSI, (VOLE_RING_N / (2 * m)) as u32, VOLE_Q); + for k in (0..VOLE_RING_N).step_by(2 * m) { + let mut w = 1i64; + for j in 0..m { + let t = ((w as i128 * coeffs[k + j + m] as i128) % VOLE_Q as i128) as i64; + let u = coeffs[k + j]; + coeffs[k + j] = (u + t) % VOLE_Q; + coeffs[k + j + m] = ((u - t) % VOLE_Q + VOLE_Q) % VOLE_Q; + w = ((w as i128 * w_m as i128) % VOLE_Q as i128) as i64; + } + } + m *= 2; + } +} + +/// Inverse NTT (Gentleman-Sande, in-place, iterative) +/// Transforms from NTT domain back to coefficient domain +fn ntt_inverse(coeffs: &mut [i64; VOLE_RING_N]) { + let twiddles_inv = compute_intt_twiddles(); + let log_n = 8u32; + + // Gentleman-Sande butterfly (inverse of Cooley-Tukey) + let mut m = VOLE_RING_N / 2; + for _ in 0..log_n { + let w_m = mod_pow(NTT_PSI_INV, (VOLE_RING_N / (2 * m)) as u32, VOLE_Q); + for k in (0..VOLE_RING_N).step_by(2 * m) { + let mut w = 1i64; + for j in 0..m { + let t = coeffs[k + j]; + let u = coeffs[k + j + m]; + coeffs[k + j] = (t + u) % VOLE_Q; + coeffs[k + j + m] = + ((((t - u) % VOLE_Q + VOLE_Q) as i128 * w as i128) % VOLE_Q as i128) as i64; + w = ((w as i128 * w_m as i128) % VOLE_Q as i128) as i64; + } + } + m /= 2; + } + + // Bit-reverse permutation + for i in 0..VOLE_RING_N { + let j = bit_reverse(i, log_n); + if i < j { + coeffs.swap(i, j); + } + } + + // Post-multiply by inverse powers of ψ and scale by n^(-1) + for i in 0..VOLE_RING_N { + let scaled = ((coeffs[i] as i128 * twiddles_inv[i] as i128) % VOLE_Q as i128) as i64; + coeffs[i] = ((scaled as i128 * NTT_N_INV as i128) % VOLE_Q as i128) as i64; + } +} + +/// NTT-based negacyclic polynomial multiplication: O(n log n) +fn ntt_mul(a: &[i64; VOLE_RING_N], b: &[i64; VOLE_RING_N]) -> [i64; VOLE_RING_N] { + let mut a_ntt = *a; + let mut b_ntt = *b; + + ntt_forward(&mut a_ntt); + ntt_forward(&mut b_ntt); + + // Point-wise multiplication in NTT domain + let mut c_ntt = [0i64; VOLE_RING_N]; + for i in 0..VOLE_RING_N { + c_ntt[i] = ((a_ntt[i] as i128 * b_ntt[i] as i128) % VOLE_Q as i128) as i64; + } + + ntt_inverse(&mut c_ntt); + c_ntt +} + #[derive(Clone)] pub struct VoleRingElement { pub coeffs: [i64; VOLE_RING_N], @@ -181,8 +347,7 @@ impl VoleRingElement { } let byte = hash[i % 64] as i32; let val = ((byte % (2 * bound + 1)) - bound) as i64; - // Normalize to [0, q-1] - coeffs[idx] = if val < 0 { val + VOLE_Q } else { val }; + coeffs[idx] = ct_normalize(val, VOLE_Q); } } @@ -196,8 +361,7 @@ impl VoleRingElement { let mut coeffs = [0i64; VOLE_RING_N]; for coeff in &mut coeffs { let val = rng.random_range(-VOLE_ERROR_BOUND as i64..=VOLE_ERROR_BOUND as i64); - // Normalize to [0, q-1] - *coeff = if val < 0 { val + VOLE_Q } else { val }; + *coeff = ct_normalize(val, VOLE_Q); } debug_assert!(coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); @@ -232,22 +396,15 @@ impl VoleRingElement { result } - /// Constant-time ring multiplication (negacyclic convolution) + /// Constant-time ring multiplication (negacyclic convolution) using NTT: O(n log n) pub fn mul(&self, other: &Self) -> Self { debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); debug_assert!(other.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); - let mut result = [0i128; 2 * VOLE_RING_N]; - for i in 0..VOLE_RING_N { - for j in 0..VOLE_RING_N { - result[i + j] += (self.coeffs[i] as i128) * (other.coeffs[j] as i128); - } - } - let mut out = Self::zero(); - for i in 0..VOLE_RING_N { - let combined = result[i] - result[i + VOLE_RING_N]; - out.coeffs[i] = ct_reduce(combined, VOLE_Q); - } + let result_coeffs = ntt_mul(&self.coeffs, &other.coeffs); + let out = Self { + coeffs: result_coeffs, + }; debug_assert!(out.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); out