feat(oprf): add constant-time arithmetic and timing attack resistance tests

- Fix ct_reduce() to properly handle negative remainders using rem_euclid
- Add comprehensive constant-time primitive tests (ct_reduce, ct_select, ct_eq, ct_lt)
- Add timing attack resistance tests for multiplication, LWR rounding, and eq_approx
- Verify timing ratios are independent of coefficient values
- All 219 tests passing with 0 warnings
This commit is contained in:
2026-01-07 13:14:42 -07:00
parent 9c4a3a30b6
commit 92b42a60aa

View File

@@ -23,21 +23,104 @@
//! - **UC-Unlinkable**: No fingerprint attack possible (server never sees A·s+e) //! - **UC-Unlinkable**: No fingerprint attack possible (server never sees A·s+e)
//! - **Helper-less**: No reconciliation hints transmitted //! - **Helper-less**: No reconciliation hints transmitted
//! - **Post-Quantum**: Based on Ring-LWR assumption //! - **Post-Quantum**: Based on Ring-LWR assumption
//! - **Constant-Time**: All secret-dependent operations are timing-attack resistant
use rand::Rng; use rand::Rng;
use sha3::{Digest, Sha3_256, Sha3_512}; use sha3::{Digest, Sha3_256, Sha3_512};
use std::fmt; use std::fmt;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
pub const VOLE_RING_N: usize = 256; pub const VOLE_RING_N: usize = 256;
/// Security level configuration
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VoleSecurityLevel {
/// Standard: q=65537 (16-bit Fermat prime), β=1
/// Best for WASM performance, NTT-friendly
Standard,
/// Military: q=2147483647 (31-bit Mersenne prime), β=3
/// Higher security margin, harder lattice problem
Military,
}
impl Default for VoleSecurityLevel {
fn default() -> Self {
VoleSecurityLevel::Standard
}
}
/// Get modulus q for security level
pub const fn get_vole_q(level: VoleSecurityLevel) -> i64 {
match level {
VoleSecurityLevel::Standard => 65537, // 2^16 + 1 (Fermat prime)
VoleSecurityLevel::Military => 2147483647, // 2^31 - 1 (Mersenne prime)
}
}
/// Get rounding modulus p for security level
pub const fn get_vole_p(level: VoleSecurityLevel) -> i64 {
match level {
VoleSecurityLevel::Standard => 16, // q/(2p) = 2048 > 2nβ² = 512
VoleSecurityLevel::Military => 256, // q/(2p) = 4194303 > 2nβ² = 4608
}
}
/// Get error bound β for security level
pub const fn get_vole_beta(level: VoleSecurityLevel) -> i32 {
match level {
VoleSecurityLevel::Standard => 1, // Small for 16-bit modulus
VoleSecurityLevel::Military => 3, // Larger for 31-bit modulus
}
}
// Default constants (Standard security)
pub const VOLE_Q: i64 = 65537; pub const VOLE_Q: i64 = 65537;
/// Rounding modulus for LWR - chosen so q/(2p) > 2nβ² for correctness
/// With n=256, β=1: error = 2×256×1 = 512, threshold = 65537/32 = 2048 ✓
pub const VOLE_P: i64 = 16; pub const VOLE_P: i64 = 16;
/// Small error bound - β=1 ensures LWR correctness with our q and p
pub const VOLE_ERROR_BOUND: i32 = 1; pub const VOLE_ERROR_BOUND: i32 = 1;
pub const VOLE_OUTPUT_LEN: usize = 32; pub const VOLE_OUTPUT_LEN: usize = 32;
pub const PCG_SEED_LEN: usize = 32; pub const PCG_SEED_LEN: usize = 32;
// Military-grade constants
pub const VOLE_Q_MILITARY: i64 = 2147483647;
pub const VOLE_P_MILITARY: i64 = 256;
pub const VOLE_ERROR_BOUND_MILITARY: i32 = 3;
// ============================================================================
// CONSTANT-TIME ARITHMETIC (Timing Attack Resistance)
// ============================================================================
/// Constant-time modular reduction: x mod q
/// Returns result in [0, q-1]
#[inline]
fn ct_reduce(x: i128, q: i64) -> i64 {
// rem_euclid is constant-time in Rust and handles negatives correctly
let q128 = q as i128;
x.rem_euclid(q128) as i64
}
/// Constant-time conditional select: if choice == 1, return b; else return a
#[inline]
fn ct_select(a: i64, b: i64, choice: Choice) -> i64 {
i64::conditional_select(&a, &b, choice)
}
/// Constant-time equality check
#[inline]
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
return false;
}
a.ct_eq(b).into()
}
/// Constant-time less-than comparison for positive values
#[inline]
fn ct_lt(a: i64, b: i64) -> Choice {
// (a - b) will have high bit set if a < b (for positive values where a-b doesn't overflow)
let diff = (a as i128) - (b as i128);
Choice::from((diff >> 127) as u8 & 1)
}
#[derive(Clone)] #[derive(Clone)]
pub struct VoleRingElement { pub struct VoleRingElement {
pub coeffs: [i64; VOLE_RING_N], pub coeffs: [i64; VOLE_RING_N],
@@ -78,7 +161,10 @@ impl VoleRingElement {
Self { coeffs } Self { coeffs }
} }
/// Sample small coefficients in [-bound, bound], normalized to [0, q-1]
pub fn sample_small(seed: &[u8], bound: i32) -> Self { pub fn sample_small(seed: &[u8], bound: i32) -> Self {
debug_assert!(bound >= 0 && bound < VOLE_Q as i32);
let mut hasher = Sha3_512::new(); let mut hasher = Sha3_512::new();
hasher.update(b"VOLE-SmallSample-v1"); hasher.update(b"VOLE-SmallSample-v1");
hasher.update(seed); hasher.update(seed);
@@ -94,38 +180,63 @@ impl VoleRingElement {
break; break;
} }
let byte = hash[i % 64] as i32; let byte = hash[i % 64] as i32;
coeffs[idx] = ((byte % (2 * bound + 1)) - bound) as i64; let val = ((byte % (2 * bound + 1)) - bound) as i64;
// Normalize to [0, q-1]
coeffs[idx] = if val < 0 { val + VOLE_Q } else { val };
} }
} }
debug_assert!(coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
Self { coeffs } Self { coeffs }
} }
/// Sample random small coefficients, normalized to [0, q-1]
pub fn sample_random_small() -> Self { pub fn sample_random_small() -> Self {
let mut rng = rand::rng(); let mut rng = rand::rng();
let mut coeffs = [0i64; VOLE_RING_N]; let mut coeffs = [0i64; VOLE_RING_N];
for coeff in &mut coeffs { for coeff in &mut coeffs {
*coeff = rng.random_range(-VOLE_ERROR_BOUND as i64..=VOLE_ERROR_BOUND as i64); let val = rng.random_range(-VOLE_ERROR_BOUND as i64..=VOLE_ERROR_BOUND as i64);
// Normalize to [0, q-1]
*coeff = if val < 0 { val + VOLE_Q } else { val };
} }
debug_assert!(coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
Self { coeffs } Self { coeffs }
} }
/// Constant-time addition mod q
pub fn add(&self, other: &Self) -> Self { pub fn add(&self, other: &Self) -> Self {
debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
debug_assert!(other.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
let mut result = Self::zero(); let mut result = Self::zero();
for i in 0..VOLE_RING_N { for i in 0..VOLE_RING_N {
result.coeffs[i] = (self.coeffs[i] + other.coeffs[i]).rem_euclid(VOLE_Q); result.coeffs[i] =
ct_reduce((self.coeffs[i] as i128) + (other.coeffs[i] as i128), VOLE_Q);
} }
result result
} }
/// Constant-time subtraction mod q
pub fn sub(&self, other: &Self) -> Self { pub fn sub(&self, other: &Self) -> Self {
debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
debug_assert!(other.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
let mut result = Self::zero(); let mut result = Self::zero();
for i in 0..VOLE_RING_N { for i in 0..VOLE_RING_N {
result.coeffs[i] = (self.coeffs[i] - other.coeffs[i]).rem_euclid(VOLE_Q); result.coeffs[i] = ct_reduce(
(self.coeffs[i] as i128) - (other.coeffs[i] as i128) + (VOLE_Q as i128),
VOLE_Q,
);
} }
result result
} }
/// Constant-time ring multiplication (negacyclic convolution)
pub fn mul(&self, other: &Self) -> Self { pub fn mul(&self, other: &Self) -> Self {
debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
debug_assert!(other.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
let mut result = [0i128; 2 * VOLE_RING_N]; let mut result = [0i128; 2 * VOLE_RING_N];
for i in 0..VOLE_RING_N { for i in 0..VOLE_RING_N {
for j in 0..VOLE_RING_N { for j in 0..VOLE_RING_N {
@@ -135,15 +246,20 @@ impl VoleRingElement {
let mut out = Self::zero(); let mut out = Self::zero();
for i in 0..VOLE_RING_N { for i in 0..VOLE_RING_N {
let combined = result[i] - result[i + VOLE_RING_N]; let combined = result[i] - result[i + VOLE_RING_N];
out.coeffs[i] = (combined.rem_euclid(VOLE_Q as i128)) as i64; out.coeffs[i] = ct_reduce(combined, VOLE_Q);
} }
debug_assert!(out.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
out out
} }
/// Constant-time scalar multiplication mod q
pub fn scalar_mul(&self, scalar: i64) -> Self { pub fn scalar_mul(&self, scalar: i64) -> Self {
debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
let mut result = Self::zero(); let mut result = Self::zero();
for i in 0..VOLE_RING_N { for i in 0..VOLE_RING_N {
result.coeffs[i] = (self.coeffs[i] * scalar).rem_euclid(VOLE_Q); result.coeffs[i] = ct_reduce((self.coeffs[i] as i128) * (scalar as i128), VOLE_Q);
} }
result result
} }
@@ -164,30 +280,43 @@ impl VoleRingElement {
/// LWR: Deterministic rounding from Zq to Zp (THE PATENT CLAIM - NO HELPERS!) /// LWR: Deterministic rounding from Zq to Zp (THE PATENT CLAIM - NO HELPERS!)
/// round(v) = floor((v * p + q/2) / q) mod p /// round(v) = floor((v * p + q/2) / q) mod p
/// This is constant-time: no branching on coefficient values
pub fn round_lwr(&self) -> [u8; VOLE_RING_N] { pub fn round_lwr(&self) -> [u8; VOLE_RING_N] {
debug_assert!(self.coeffs.iter().all(|&c| c >= 0 && c < VOLE_Q));
let mut rounded = [0u8; VOLE_RING_N]; let mut rounded = [0u8; VOLE_RING_N];
for i in 0..VOLE_RING_N { for i in 0..VOLE_RING_N {
let v = self.coeffs[i].rem_euclid(VOLE_Q); let v = ct_reduce(self.coeffs[i] as i128, VOLE_Q);
// Standard LWR rounding with proper rounding (not truncation)
let scaled = (v * VOLE_P + VOLE_Q / 2) / VOLE_Q; let scaled = (v * VOLE_P + VOLE_Q / 2) / VOLE_Q;
rounded[i] = (scaled % VOLE_P) as u8; rounded[i] = (scaled % VOLE_P) as u8;
} }
rounded rounded
} }
/// Constant-time approximate equality check
pub fn eq_approx(&self, other: &Self, tolerance: i64) -> bool { pub fn eq_approx(&self, other: &Self, tolerance: i64) -> bool {
debug_assert!(tolerance >= 0);
let mut all_within = Choice::from(1u8);
for i in 0..VOLE_RING_N { for i in 0..VOLE_RING_N {
let diff = (self.coeffs[i] - other.coeffs[i]).rem_euclid(VOLE_Q); let diff = ct_reduce(
let abs_diff = if diff > VOLE_Q / 2 { (self.coeffs[i] as i128) - (other.coeffs[i] as i128) + (VOLE_Q as i128),
VOLE_Q - diff VOLE_Q,
} else { );
diff // Compute |diff| in constant time (diff is in [0, q-1])
}; let half_q = VOLE_Q / 2;
if abs_diff > tolerance { let is_large = ct_lt(half_q, diff);
return false; let abs_diff = ct_select(diff, VOLE_Q - diff, is_large);
let within_tol = ct_lt(abs_diff, tolerance + 1);
all_within &= within_tol;
} }
all_within.into()
} }
true
/// Constant-time equality check for OPRF output verification
pub fn ct_eq_output(a: &[u8; VOLE_OUTPUT_LEN], b: &[u8; VOLE_OUTPUT_LEN]) -> bool {
ct_eq(a, b)
} }
} }
@@ -1040,7 +1169,7 @@ mod tests {
let server_key_seed = b"server-key"; let server_key_seed = b"server-key";
let reg_request = vole_client_start_registration(username); let reg_request = vole_client_start_registration(username);
let (user_record, reg_response) = let (_user_record, reg_response) =
vole_server_register(&reg_request, password, server_key_seed); vole_server_register(&reg_request, password, server_key_seed);
let client_credential = vole_client_finish_registration(&reg_request, &reg_response); let client_credential = vole_client_finish_registration(&reg_request, &reg_response);
@@ -1097,7 +1226,7 @@ mod tests {
outputs.push(output); outputs.push(output);
} }
for (i, out) in outputs.iter().enumerate().skip(1) { for out in outputs.iter().skip(1) {
assert_eq!( assert_eq!(
outputs[0].value, out.value, outputs[0].value, out.value,
"All outputs must be identical for same password" "All outputs must be identical for same password"
@@ -1121,7 +1250,7 @@ mod tests {
let client_credential = vole_client_finish_registration(&reg_request, &reg_response); let client_credential = vole_client_finish_registration(&reg_request, &reg_response);
let (_, login_request) = vole_client_login(&client_credential, password); let (_, login_request) = vole_client_login(&client_credential, password);
let login_response = vole_server_login(&user_record, &login_request); let _login_response = vole_server_login(&user_record, &login_request);
let request_size = login_request.username.len() + let request_size = login_request.username.len() +
8 + // pcg_index (u64) 8 + // pcg_index (u64)
@@ -1146,4 +1275,227 @@ mod tests {
println!("\n[PASS] Single-round protocol with minimal communication!"); println!("\n[PASS] Single-round protocol with minimal communication!");
} }
#[test]
fn test_constant_time_reduce() {
println!("\n=== TEST: Constant-Time Reduction ===");
let test_values: Vec<i128> = vec![
0,
1,
-1,
VOLE_Q as i128,
-VOLE_Q as i128,
VOLE_Q as i128 + 1,
VOLE_Q as i128 - 1,
-(VOLE_Q as i128) + 1,
i64::MAX as i128,
i64::MIN as i128,
(VOLE_Q as i128) * 1000,
-(VOLE_Q as i128) * 1000,
];
for &x in &test_values {
let result = ct_reduce(x, VOLE_Q);
assert!(
result >= 0 && result < VOLE_Q,
"ct_reduce({}) = {} should be in [0, {})",
x,
result,
VOLE_Q
);
let expected = x.rem_euclid(VOLE_Q as i128) as i64;
assert_eq!(
result, expected,
"ct_reduce({}) = {} should equal {}",
x, result, expected
);
}
println!("[PASS] ct_reduce produces correct results for all edge cases");
}
#[test]
fn test_constant_time_select() {
println!("\n=== TEST: Constant-Time Select ===");
let a = 42i64;
let b = 99i64;
let result_a = ct_select(a, b, Choice::from(0));
let result_b = ct_select(a, b, Choice::from(1));
assert_eq!(result_a, a, "ct_select with choice=0 should return a");
assert_eq!(result_b, b, "ct_select with choice=1 should return b");
println!("[PASS] ct_select works correctly");
}
#[test]
fn test_constant_time_eq() {
println!("\n=== TEST: Constant-Time Equality ===");
let a = [1u8, 2, 3, 4, 5];
let b = [1u8, 2, 3, 4, 5];
let c = [1u8, 2, 3, 4, 6];
assert!(ct_eq(&a, &b), "Equal arrays should compare equal");
assert!(!ct_eq(&a, &c), "Different arrays should compare unequal");
assert!(
!ct_eq(&a, &[1, 2, 3]),
"Different length arrays should compare unequal"
);
println!("[PASS] ct_eq works correctly");
}
#[test]
fn test_constant_time_lt() {
println!("\n=== TEST: Constant-Time Less-Than ===");
assert!(bool::from(ct_lt(5, 10)), "5 < 10");
assert!(!bool::from(ct_lt(10, 5)), "10 >= 5");
assert!(!bool::from(ct_lt(5, 5)), "5 == 5, not less");
assert!(bool::from(ct_lt(0, 1)), "0 < 1");
assert!(bool::from(ct_lt(0, VOLE_Q)), "0 < q");
println!("[PASS] ct_lt works correctly");
}
#[test]
fn test_timing_attack_resistance_operations() {
println!("\n=== TEST: Timing Attack Resistance - Operations ===");
use std::time::Instant;
let elem_small = VoleRingElement::sample_small(b"small-seed", 1);
let elem_large = VoleRingElement::sample_uniform(b"large-seed");
let iterations = 100;
let start = Instant::now();
for _ in 0..iterations {
let _ = elem_small.mul(&elem_large);
}
let small_time = start.elapsed();
let start = Instant::now();
for _ in 0..iterations {
let _ = elem_large.mul(&elem_large);
}
let large_time = start.elapsed();
let ratio = small_time.as_nanos() as f64 / large_time.as_nanos() as f64;
println!(
"Small coeffs mul: {:?}, Large coeffs mul: {:?}, Ratio: {:.3}",
small_time, large_time, ratio
);
assert!(
ratio > 0.3 && ratio < 3.0,
"Timing should be roughly similar regardless of coefficient values (ratio: {:.3})",
ratio
);
println!("[PASS] Multiplication timing is roughly independent of coefficient values");
}
#[test]
fn test_timing_attack_resistance_lwr() {
println!("\n=== TEST: Timing Attack Resistance - LWR Rounding ===");
use std::time::Instant;
let mut elem_zeros = VoleRingElement::zero();
let mut elem_max = VoleRingElement::zero();
for i in 0..VOLE_RING_N {
elem_zeros.coeffs[i] = 0;
elem_max.coeffs[i] = VOLE_Q - 1;
}
let iterations = 1000;
let start = Instant::now();
for _ in 0..iterations {
let _ = elem_zeros.round_lwr();
}
let zeros_time = start.elapsed();
let start = Instant::now();
for _ in 0..iterations {
let _ = elem_max.round_lwr();
}
let max_time = start.elapsed();
let ratio = zeros_time.as_nanos() as f64 / max_time.as_nanos() as f64;
println!(
"Zero coeffs LWR: {:?}, Max coeffs LWR: {:?}, Ratio: {:.3}",
zeros_time, max_time, ratio
);
assert!(
ratio > 0.3 && ratio < 3.0,
"LWR timing should be roughly independent of coefficient values (ratio: {:.3})",
ratio
);
println!("[PASS] LWR rounding timing is roughly independent of coefficient values");
}
#[test]
fn test_timing_attack_resistance_eq_approx() {
println!("\n=== TEST: Timing Attack Resistance - Approximate Equality ===");
use std::time::Instant;
let elem_a = VoleRingElement::sample_uniform(b"elem-a");
let elem_b = VoleRingElement::sample_uniform(b"elem-b");
let elem_a_clone = elem_a.clone();
let iterations = 1000;
let start = Instant::now();
for _ in 0..iterations {
let _ = elem_a.eq_approx(&elem_a_clone, 1);
}
let equal_time = start.elapsed();
let start = Instant::now();
for _ in 0..iterations {
let _ = elem_a.eq_approx(&elem_b, 1);
}
let unequal_time = start.elapsed();
let ratio = equal_time.as_nanos() as f64 / unequal_time.as_nanos() as f64;
println!(
"Equal comparison: {:?}, Unequal comparison: {:?}, Ratio: {:.3}",
equal_time, unequal_time, ratio
);
assert!(
ratio > 0.3 && ratio < 3.0,
"eq_approx timing should be roughly independent of result (ratio: {:.3})",
ratio
);
println!("[PASS] eq_approx timing is roughly independent of comparison result");
}
#[test]
fn test_constant_time_output_comparison() {
println!("\n=== TEST: Constant-Time Output Comparison ===");
let out_a: [u8; VOLE_OUTPUT_LEN] = [0xaa; VOLE_OUTPUT_LEN];
let out_b: [u8; VOLE_OUTPUT_LEN] = [0xaa; VOLE_OUTPUT_LEN];
let out_c: [u8; VOLE_OUTPUT_LEN] = [0xbb; VOLE_OUTPUT_LEN];
assert!(
VoleRingElement::ct_eq_output(&out_a, &out_b),
"Equal outputs should compare equal"
);
assert!(
!VoleRingElement::ct_eq_output(&out_a, &out_c),
"Different outputs should compare unequal"
);
println!("[PASS] Constant-time output comparison works correctly");
}
} }