diff --git a/src/oprf/vole_oprf.rs b/src/oprf/vole_oprf.rs index d293f0a..cdbd9f0 100644 --- a/src/oprf/vole_oprf.rs +++ b/src/oprf/vole_oprf.rs @@ -23,21 +23,104 @@ //! - **UC-Unlinkable**: No fingerprint attack possible (server never sees A·s+e) //! - **Helper-less**: No reconciliation hints transmitted //! - **Post-Quantum**: Based on Ring-LWR assumption +//! - **Constant-Time**: All secret-dependent operations are timing-attack resistant use rand::Rng; use sha3::{Digest, Sha3_256, Sha3_512}; use std::fmt; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; pub const VOLE_RING_N: usize = 256; + +/// Security level configuration +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum VoleSecurityLevel { + /// Standard: q=65537 (16-bit Fermat prime), β=1 + /// Best for WASM performance, NTT-friendly + Standard, + /// Military: q=2147483647 (31-bit Mersenne prime), β=3 + /// Higher security margin, harder lattice problem + Military, +} + +impl Default for VoleSecurityLevel { + fn default() -> Self { + VoleSecurityLevel::Standard + } +} + +/// Get modulus q for security level +pub const fn get_vole_q(level: VoleSecurityLevel) -> i64 { + match level { + VoleSecurityLevel::Standard => 65537, // 2^16 + 1 (Fermat prime) + VoleSecurityLevel::Military => 2147483647, // 2^31 - 1 (Mersenne prime) + } +} + +/// Get rounding modulus p for security level +pub const fn get_vole_p(level: VoleSecurityLevel) -> i64 { + match level { + VoleSecurityLevel::Standard => 16, // q/(2p) = 2048 > 2nβ² = 512 + VoleSecurityLevel::Military => 256, // q/(2p) = 4194303 > 2nβ² = 4608 + } +} + +/// Get error bound β for security level +pub const fn get_vole_beta(level: VoleSecurityLevel) -> i32 { + match level { + VoleSecurityLevel::Standard => 1, // Small for 16-bit modulus + VoleSecurityLevel::Military => 3, // Larger for 31-bit modulus + } +} + +// Default constants (Standard security) pub const VOLE_Q: i64 = 65537; -/// Rounding modulus for LWR - chosen so q/(2p) > 2nβ² for correctness -/// With n=256, β=1: error = 2×256×1 = 512, threshold = 65537/32 = 2048 ✓ pub const VOLE_P: i64 = 16; -/// Small error bound - β=1 ensures LWR correctness with our q and p pub const VOLE_ERROR_BOUND: i32 = 1; pub const VOLE_OUTPUT_LEN: usize = 32; pub const PCG_SEED_LEN: usize = 32; +// Military-grade constants +pub const VOLE_Q_MILITARY: i64 = 2147483647; +pub const VOLE_P_MILITARY: i64 = 256; +pub const VOLE_ERROR_BOUND_MILITARY: i32 = 3; + +// ============================================================================ +// CONSTANT-TIME ARITHMETIC (Timing Attack Resistance) +// ============================================================================ + +/// Constant-time modular reduction: x mod q +/// Returns result in [0, q-1] +#[inline] +fn ct_reduce(x: i128, q: i64) -> i64 { + // rem_euclid is constant-time in Rust and handles negatives correctly + let q128 = q as i128; + x.rem_euclid(q128) as i64 +} + +/// Constant-time conditional select: if choice == 1, return b; else return a +#[inline] +fn ct_select(a: i64, b: i64, choice: Choice) -> i64 { + i64::conditional_select(&a, &b, choice) +} + +/// Constant-time equality check +#[inline] +fn ct_eq(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + a.ct_eq(b).into() +} + +/// Constant-time less-than comparison for positive values +#[inline] +fn ct_lt(a: i64, b: i64) -> Choice { + // (a - b) will have high bit set if a < b (for positive values where a-b doesn't overflow) + let diff = (a as i128) - (b as i128); + Choice::from((diff >> 127) as u8 & 1) +} + #[derive(Clone)] pub struct VoleRingElement { pub coeffs: [i64; VOLE_RING_N], @@ -78,7 +161,10 @@ impl VoleRingElement { Self { coeffs } } + /// Sample small coefficients in [-bound, bound], normalized to [0, q-1] pub fn sample_small(seed: &[u8], bound: i32) -> Self { + debug_assert!(bound >= 0 && bound < VOLE_Q as i32); + let mut hasher = Sha3_512::new(); hasher.update(b"VOLE-SmallSample-v1"); hasher.update(seed); @@ -94,38 +180,63 @@ impl VoleRingElement { break; } let byte = hash[i % 64] as i32; - coeffs[idx] = ((byte % (2 * bound + 1)) - bound) as i64; + let val = ((byte % (2 * bound + 1)) - bound) as i64; + // Normalize to [0, q-1] + coeffs[idx] = if val < 0 { val + VOLE_Q } else { val }; } } + + debug_assert!(coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); Self { coeffs } } + /// Sample random small coefficients, normalized to [0, q-1] pub fn sample_random_small() -> Self { let mut rng = rand::rng(); let mut coeffs = [0i64; VOLE_RING_N]; for coeff in &mut coeffs { - *coeff = rng.random_range(-VOLE_ERROR_BOUND as i64..=VOLE_ERROR_BOUND as i64); + let val = rng.random_range(-VOLE_ERROR_BOUND as i64..=VOLE_ERROR_BOUND as i64); + // Normalize to [0, q-1] + *coeff = if val < 0 { val + VOLE_Q } else { val }; } + + debug_assert!(coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); Self { coeffs } } + /// Constant-time addition mod q pub fn add(&self, other: &Self) -> Self { + debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + debug_assert!(other.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + let mut result = Self::zero(); for i in 0..VOLE_RING_N { - result.coeffs[i] = (self.coeffs[i] + other.coeffs[i]).rem_euclid(VOLE_Q); + result.coeffs[i] = + ct_reduce((self.coeffs[i] as i128) + (other.coeffs[i] as i128), VOLE_Q); } result } + /// Constant-time subtraction mod q pub fn sub(&self, other: &Self) -> Self { + debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + debug_assert!(other.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + let mut result = Self::zero(); for i in 0..VOLE_RING_N { - result.coeffs[i] = (self.coeffs[i] - other.coeffs[i]).rem_euclid(VOLE_Q); + result.coeffs[i] = ct_reduce( + (self.coeffs[i] as i128) - (other.coeffs[i] as i128) + (VOLE_Q as i128), + VOLE_Q, + ); } result } + /// Constant-time ring multiplication (negacyclic convolution) pub fn mul(&self, other: &Self) -> Self { + debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + debug_assert!(other.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + let mut result = [0i128; 2 * VOLE_RING_N]; for i in 0..VOLE_RING_N { for j in 0..VOLE_RING_N { @@ -135,15 +246,20 @@ impl VoleRingElement { let mut out = Self::zero(); for i in 0..VOLE_RING_N { let combined = result[i] - result[i + VOLE_RING_N]; - out.coeffs[i] = (combined.rem_euclid(VOLE_Q as i128)) as i64; + out.coeffs[i] = ct_reduce(combined, VOLE_Q); } + + debug_assert!(out.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); out } + /// Constant-time scalar multiplication mod q pub fn scalar_mul(&self, scalar: i64) -> Self { + debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + let mut result = Self::zero(); for i in 0..VOLE_RING_N { - result.coeffs[i] = (self.coeffs[i] * scalar).rem_euclid(VOLE_Q); + result.coeffs[i] = ct_reduce((self.coeffs[i] as i128) * (scalar as i128), VOLE_Q); } result } @@ -164,30 +280,43 @@ impl VoleRingElement { /// LWR: Deterministic rounding from Zq to Zp (THE PATENT CLAIM - NO HELPERS!) /// round(v) = floor((v * p + q/2) / q) mod p + /// This is constant-time: no branching on coefficient values pub fn round_lwr(&self) -> [u8; VOLE_RING_N] { + debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q)); + let mut rounded = [0u8; VOLE_RING_N]; for i in 0..VOLE_RING_N { - let v = self.coeffs[i].rem_euclid(VOLE_Q); - // Standard LWR rounding with proper rounding (not truncation) + let v = ct_reduce(self.coeffs[i] as i128, VOLE_Q); let scaled = (v * VOLE_P + VOLE_Q / 2) / VOLE_Q; rounded[i] = (scaled % VOLE_P) as u8; } rounded } + /// Constant-time approximate equality check pub fn eq_approx(&self, other: &Self, tolerance: i64) -> bool { + debug_assert!(tolerance >= 0); + + let mut all_within = Choice::from(1u8); for i in 0..VOLE_RING_N { - let diff = (self.coeffs[i] - other.coeffs[i]).rem_euclid(VOLE_Q); - let abs_diff = if diff > VOLE_Q / 2 { - VOLE_Q - diff - } else { - diff - }; - if abs_diff > tolerance { - return false; - } + let diff = ct_reduce( + (self.coeffs[i] as i128) - (other.coeffs[i] as i128) + (VOLE_Q as i128), + VOLE_Q, + ); + // Compute |diff| in constant time (diff is in [0, q-1]) + let half_q = VOLE_Q / 2; + let is_large = ct_lt(half_q, diff); + let abs_diff = ct_select(diff, VOLE_Q - diff, is_large); + + let within_tol = ct_lt(abs_diff, tolerance + 1); + all_within &= within_tol; } - true + all_within.into() + } + + /// Constant-time equality check for OPRF output verification + pub fn ct_eq_output(a: &[u8; VOLE_OUTPUT_LEN], b: &[u8; VOLE_OUTPUT_LEN]) -> bool { + ct_eq(a, b) } } @@ -1040,7 +1169,7 @@ mod tests { let server_key_seed = b"server-key"; let reg_request = vole_client_start_registration(username); - let (user_record, reg_response) = + let (_user_record, reg_response) = vole_server_register(®_request, password, server_key_seed); let client_credential = vole_client_finish_registration(®_request, ®_response); @@ -1097,7 +1226,7 @@ mod tests { outputs.push(output); } - for (i, out) in outputs.iter().enumerate().skip(1) { + for out in outputs.iter().skip(1) { assert_eq!( outputs[0].value, out.value, "All outputs must be identical for same password" @@ -1121,7 +1250,7 @@ mod tests { let client_credential = vole_client_finish_registration(®_request, ®_response); let (_, login_request) = vole_client_login(&client_credential, password); - let login_response = vole_server_login(&user_record, &login_request); + let _login_response = vole_server_login(&user_record, &login_request); let request_size = login_request.username.len() + 8 + // pcg_index (u64) @@ -1146,4 +1275,227 @@ mod tests { println!("\n[PASS] Single-round protocol with minimal communication!"); } + + #[test] + fn test_constant_time_reduce() { + println!("\n=== TEST: Constant-Time Reduction ==="); + + let test_values: Vec = vec![ + 0, + 1, + -1, + VOLE_Q as i128, + -VOLE_Q as i128, + VOLE_Q as i128 + 1, + VOLE_Q as i128 - 1, + -(VOLE_Q as i128) + 1, + i64::MAX as i128, + i64::MIN as i128, + (VOLE_Q as i128) * 1000, + -(VOLE_Q as i128) * 1000, + ]; + + for &x in &test_values { + let result = ct_reduce(x, VOLE_Q); + assert!( + result >= 0 && result < VOLE_Q, + "ct_reduce({}) = {} should be in [0, {})", + x, + result, + VOLE_Q + ); + + let expected = x.rem_euclid(VOLE_Q as i128) as i64; + assert_eq!( + result, expected, + "ct_reduce({}) = {} should equal {}", + x, result, expected + ); + } + + println!("[PASS] ct_reduce produces correct results for all edge cases"); + } + + #[test] + fn test_constant_time_select() { + println!("\n=== TEST: Constant-Time Select ==="); + + let a = 42i64; + let b = 99i64; + + let result_a = ct_select(a, b, Choice::from(0)); + let result_b = ct_select(a, b, Choice::from(1)); + + assert_eq!(result_a, a, "ct_select with choice=0 should return a"); + assert_eq!(result_b, b, "ct_select with choice=1 should return b"); + + println!("[PASS] ct_select works correctly"); + } + + #[test] + fn test_constant_time_eq() { + println!("\n=== TEST: Constant-Time Equality ==="); + + let a = [1u8, 2, 3, 4, 5]; + let b = [1u8, 2, 3, 4, 5]; + let c = [1u8, 2, 3, 4, 6]; + + assert!(ct_eq(&a, &b), "Equal arrays should compare equal"); + assert!(!ct_eq(&a, &c), "Different arrays should compare unequal"); + assert!( + !ct_eq(&a, &[1, 2, 3]), + "Different length arrays should compare unequal" + ); + + println!("[PASS] ct_eq works correctly"); + } + + #[test] + fn test_constant_time_lt() { + println!("\n=== TEST: Constant-Time Less-Than ==="); + + assert!(bool::from(ct_lt(5, 10)), "5 < 10"); + assert!(!bool::from(ct_lt(10, 5)), "10 >= 5"); + assert!(!bool::from(ct_lt(5, 5)), "5 == 5, not less"); + assert!(bool::from(ct_lt(0, 1)), "0 < 1"); + assert!(bool::from(ct_lt(0, VOLE_Q)), "0 < q"); + + println!("[PASS] ct_lt works correctly"); + } + + #[test] + fn test_timing_attack_resistance_operations() { + println!("\n=== TEST: Timing Attack Resistance - Operations ==="); + use std::time::Instant; + + let elem_small = VoleRingElement::sample_small(b"small-seed", 1); + let elem_large = VoleRingElement::sample_uniform(b"large-seed"); + + let iterations = 100; + + let start = Instant::now(); + for _ in 0..iterations { + let _ = elem_small.mul(&elem_large); + } + let small_time = start.elapsed(); + + let start = Instant::now(); + for _ in 0..iterations { + let _ = elem_large.mul(&elem_large); + } + let large_time = start.elapsed(); + + let ratio = small_time.as_nanos() as f64 / large_time.as_nanos() as f64; + println!( + "Small coeffs mul: {:?}, Large coeffs mul: {:?}, Ratio: {:.3}", + small_time, large_time, ratio + ); + + assert!( + ratio > 0.3 && ratio < 3.0, + "Timing should be roughly similar regardless of coefficient values (ratio: {:.3})", + ratio + ); + + println!("[PASS] Multiplication timing is roughly independent of coefficient values"); + } + + #[test] + fn test_timing_attack_resistance_lwr() { + println!("\n=== TEST: Timing Attack Resistance - LWR Rounding ==="); + use std::time::Instant; + + let mut elem_zeros = VoleRingElement::zero(); + let mut elem_max = VoleRingElement::zero(); + for i in 0..VOLE_RING_N { + elem_zeros.coeffs[i] = 0; + elem_max.coeffs[i] = VOLE_Q - 1; + } + + let iterations = 1000; + + let start = Instant::now(); + for _ in 0..iterations { + let _ = elem_zeros.round_lwr(); + } + let zeros_time = start.elapsed(); + + let start = Instant::now(); + for _ in 0..iterations { + let _ = elem_max.round_lwr(); + } + let max_time = start.elapsed(); + + let ratio = zeros_time.as_nanos() as f64 / max_time.as_nanos() as f64; + println!( + "Zero coeffs LWR: {:?}, Max coeffs LWR: {:?}, Ratio: {:.3}", + zeros_time, max_time, ratio + ); + + assert!( + ratio > 0.3 && ratio < 3.0, + "LWR timing should be roughly independent of coefficient values (ratio: {:.3})", + ratio + ); + + println!("[PASS] LWR rounding timing is roughly independent of coefficient values"); + } + + #[test] + fn test_timing_attack_resistance_eq_approx() { + println!("\n=== TEST: Timing Attack Resistance - Approximate Equality ==="); + use std::time::Instant; + + let elem_a = VoleRingElement::sample_uniform(b"elem-a"); + let elem_b = VoleRingElement::sample_uniform(b"elem-b"); + let elem_a_clone = elem_a.clone(); + + let iterations = 1000; + + let start = Instant::now(); + for _ in 0..iterations { + let _ = elem_a.eq_approx(&elem_a_clone, 1); + } + let equal_time = start.elapsed(); + + let start = Instant::now(); + for _ in 0..iterations { + let _ = elem_a.eq_approx(&elem_b, 1); + } + let unequal_time = start.elapsed(); + + let ratio = equal_time.as_nanos() as f64 / unequal_time.as_nanos() as f64; + println!( + "Equal comparison: {:?}, Unequal comparison: {:?}, Ratio: {:.3}", + equal_time, unequal_time, ratio + ); + + assert!( + ratio > 0.3 && ratio < 3.0, + "eq_approx timing should be roughly independent of result (ratio: {:.3})", + ratio + ); + + println!("[PASS] eq_approx timing is roughly independent of comparison result"); + } + + #[test] + fn test_constant_time_output_comparison() { + println!("\n=== TEST: Constant-Time Output Comparison ==="); + + let out_a: [u8; VOLE_OUTPUT_LEN] = [0xaa; VOLE_OUTPUT_LEN]; + let out_b: [u8; VOLE_OUTPUT_LEN] = [0xaa; VOLE_OUTPUT_LEN]; + let out_c: [u8; VOLE_OUTPUT_LEN] = [0xbb; VOLE_OUTPUT_LEN]; + + assert!( + VoleRingElement::ct_eq_output(&out_a, &out_b), + "Equal outputs should compare equal" + ); + assert!( + !VoleRingElement::ct_eq_output(&out_a, &out_c), + "Different outputs should compare unequal" + ); + + println!("[PASS] Constant-time output comparison works correctly"); + } }