From 4e7eec9b912371a6758974e851b5cdffa90920ee Mon Sep 17 00:00:00 2001 From: Cole Leavitt Date: Thu, 8 Jan 2026 11:01:25 -0700 Subject: [PATCH] ntru lwr oprf --- src/oprf/mod.rs | 1 + src/oprf/ntru_lwr_oprf.rs | 483 ++++++++++++++++++++++++++++++++++++++ src/oprf/ntru_oprf.rs | 102 +++++++- 3 files changed, 582 insertions(+), 4 deletions(-) create mode 100644 src/oprf/ntru_lwr_oprf.rs diff --git a/src/oprf/mod.rs b/src/oprf/mod.rs index 621c774..b35db4f 100644 --- a/src/oprf/mod.rs +++ b/src/oprf/mod.rs @@ -1,6 +1,7 @@ pub mod fast_oprf; pub mod hybrid; pub mod leap_oprf; +pub mod ntru_lwr_oprf; pub mod ntru_oprf; pub mod ot; pub mod ring; diff --git a/src/oprf/ntru_lwr_oprf.rs b/src/oprf/ntru_lwr_oprf.rs new file mode 100644 index 0000000..e2c637d --- /dev/null +++ b/src/oprf/ntru_lwr_oprf.rs @@ -0,0 +1,483 @@ +//! NTRU-LWR-OPRF: Secure Lattice OPRF in NTRU Prime Ring +//! +//! Uses LWE-style additive blinding in the NTRU Prime ring Z_q[x]/(x^p - x - 1). +//! This combines the unique NTRU Prime ring structure with proven LWE security. +//! +//! Security: Based on Ring-LWE/LWR hardness in NTRU Prime ring. + +use sha3::{Digest, Sha3_256}; +use std::fmt; + +use super::ntru_oprf::{NtruRingElement, OUTPUT_LEN, P, Q}; + +pub const P_LWR: i64 = 2; +const BETA: i32 = 1; + +fn round_coeff(c: i64) -> u8 { + let scaled = (c * P_LWR + Q / 2) / Q; + (scaled.rem_euclid(P_LWR)) as u8 +} + +fn sample_ternary_from_seed(seed: &[u8]) -> NtruRingElement { + use sha3::{Digest, Sha3_256}; + let mut coeffs = vec![0i64; P]; + for (i, coeff) in coeffs.iter_mut().enumerate() { + let mut hasher = Sha3_256::new(); + hasher.update(seed); + hasher.update(&(i as u32).to_le_bytes()); + let hash = hasher.finalize(); + let val = (hash[0] % 3) as i64 - 1; // {-1, 0, 1} + *coeff = val.rem_euclid(Q); + } + NtruRingElement { coeffs } +} + +#[cfg(test)] +fn sample_random_ternary() -> NtruRingElement { + use rand::Rng; + let mut rng = rand::rng(); + let mut coeffs = vec![0i64; P]; + for coeff in &mut coeffs { + let val = rng.random_range(0..3) as i64 - 1; // {-1, 0, 1} + *coeff = val.rem_euclid(Q); + } + NtruRingElement { coeffs } +} + +#[derive(Clone)] +pub struct ServerKey { + pub a: NtruRingElement, + pub k: NtruRingElement, + pub pk: NtruRingElement, + e_k: NtruRingElement, +} + +impl fmt::Debug for ServerKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ServerKey[k_L2={:.2}]", self.k.l2_norm()) + } +} + +impl ServerKey { + pub fn generate(seed: &[u8]) -> Self { + let a = NtruRingElement::sample_uniform(&[seed, b"-A"].concat()); + let k = NtruRingElement::sample_small(&[seed, b"-k"].concat()); + let e_k = NtruRingElement::sample_small(&[seed, b"-ek"].concat()); + let pk = a.mul(&k).add(&e_k); + Self { a, k, pk, e_k } + } +} + +#[derive(Clone, Debug)] +pub struct ServerPublicParams { + pub a: NtruRingElement, + pub pk: NtruRingElement, +} + +impl From<&ServerKey> for ServerPublicParams { + fn from(key: &ServerKey) -> Self { + Self { + a: key.a.clone(), + pk: key.pk.clone(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ReconciliationHelper { + pub hints: Vec, +} + +impl ReconciliationHelper { + pub fn from_ring(elem: &NtruRingElement) -> Self { + let hints: Vec = elem.coeffs.iter().map(|&c| round_coeff(c)).collect(); + Self { hints } + } + + pub fn reconcile(&self, client_elem: &NtruRingElement) -> Vec { + let mut result = Vec::with_capacity(P); + for (i, &c) in client_elem.coeffs.iter().enumerate() { + let client_bin = round_coeff(c); + let server_bin = self.hints[i]; + let bin_diff = ((server_bin as i16) - (client_bin as i16)).abs(); + let final_bin = if bin_diff <= 1 || bin_diff >= (P_LWR as i16 - 1) { + server_bin + } else { + client_bin + }; + result.push(final_bin); + } + result + } +} + +#[derive(Clone, Debug)] +pub struct BlindedInput { + pub c: NtruRingElement, + pub r_pk: NtruRingElement, +} + +#[derive(Clone)] +pub struct ClientState { + s: NtruRingElement, + r: NtruRingElement, +} + +impl fmt::Debug for ClientState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ClientState[s_L2={:.2}]", self.s.l2_norm()) + } +} + +#[derive(Clone, Debug)] +pub struct ServerResponse { + pub v: NtruRingElement, + pub helper: ReconciliationHelper, +} + +#[derive(Clone, PartialEq, Eq)] +pub struct OprfOutput { + pub value: [u8; OUTPUT_LEN], +} + +impl fmt::Debug for OprfOutput { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "OprfOutput({:02x?}...)", &self.value[..8]) + } +} + +pub fn client_blind(params: &ServerPublicParams, password: &[u8]) -> (ClientState, BlindedInput) { + println!("\n=== NTRU-LWR CLIENT BLIND ==="); + + let s = NtruRingElement::hash_to_ring(password); + let r = sample_ternary_from_seed(&[password, b"-r"].concat()); + let e = sample_ternary_from_seed(&[password, b"-e"].concat()); + + let ar = params.a.mul(&r); + let c = ar.add(&e).add(&s); + let r_pk = r.mul(¶ms.pk); + + println!("C = A*r + e + s: {:?}", c); + println!("r*pk: {:?}", r_pk); + + (ClientState { s, r }, BlindedInput { c, r_pk }) +} + +pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse { + println!("\n=== NTRU-LWR SERVER EVALUATE ==="); + + let v = key.k.mul(&blinded.c); + let x_server = v.sub(&blinded.r_pk); + + println!("V = k*C: {:?}", v); + println!("X_server = V - r*pk ≈ k*s + noise: {:?}", x_server); + + let helper = ReconciliationHelper::from_ring(&x_server); + ServerResponse { v, helper } +} + +pub fn client_finalize( + state: &ClientState, + params: &ServerPublicParams, + response: &ServerResponse, +) -> OprfOutput { + println!("\n=== NTRU-LWR CLIENT FINALIZE ==="); + + let r_pk = state.r.mul(¶ms.pk); + let x = response.v.sub(&r_pk); + println!("X = V - r*pk: {:?}", x); + + let x_rounded: Vec = x.coeffs.iter().map(|&c| round_coeff(c)).collect(); + println!("X rounded (first 8): {:?}", &x_rounded[..8]); + println!("Helper (first 8): {:?}", &response.helper.hints[..8]); + + let rounded = response.helper.reconcile(&x); + println!("Reconciled (first 8): {:?}", &rounded[..8]); + + let mut hasher = Sha3_256::new(); + hasher.update(b"NTRU-LWR-OPRF-v1"); + hasher.update(&rounded); + let hash: [u8; 32] = hasher.finalize().into(); + + OprfOutput { value: hash } +} + +pub fn evaluate(key: &ServerKey, password: &[u8]) -> OprfOutput { + let params = ServerPublicParams::from(key); + let (state, blinded) = client_blind(¶ms, password); + let response = server_evaluate(key, &blinded); + client_finalize(&state, ¶ms, &response) +} + +pub fn prf_direct(key: &ServerKey, password: &[u8]) -> OprfOutput { + let s = NtruRingElement::hash_to_ring(password); + let ks = key.k.mul(&s); + let rounded: Vec = ks.coeffs.iter().map(|&c| round_coeff(c)).collect(); + + let mut hasher = Sha3_256::new(); + hasher.update(b"NTRU-LWR-OPRF-v1"); + hasher.update(&rounded); + let hash: [u8; 32] = hasher.finalize().into(); + + OprfOutput { value: hash } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_math_step_by_step() { + println!("\n=== STEP-BY-STEP MATH DIAGNOSTIC ===\n"); + + let key = ServerKey::generate(b"test-key"); + let params = ServerPublicParams::from(&key); + let password = b"password"; + + let s = NtruRingElement::hash_to_ring(password); + let r = sample_random_ternary(); + let e = sample_random_ternary(); + + println!("--- INPUTS ---"); + println!("s (password hash): L2={:.2}", s.l2_norm()); + println!("r (blinding): L2={:.2}", r.l2_norm()); + println!("e (noise): L2={:.2}", e.l2_norm()); + println!("k (server key): L2={:.2}", key.k.l2_norm()); + println!("A (public): L2={:.2}", key.a.l2_norm()); + println!("e_k (key noise): L2={:.2}", key.e_k.l2_norm()); + + println!("\n--- CLIENT BLIND: C = A*r + e + s ---"); + let ar = params.a.mul(&r); + let c = ar.add(&e).add(&s); + println!("A*r: L2={:.2}", ar.l2_norm()); + println!("C: L2={:.2}", c.l2_norm()); + + println!("\n--- SERVER EVAL: V = k*C ---"); + let v = key.k.mul(&c); + println!("V = k*C: L2={:.2}", v.l2_norm()); + + println!("\n--- EXPAND V = k*(A*r + e + s) = k*A*r + k*e + k*s ---"); + let k_ar = key.k.mul(&ar); + let k_e = key.k.mul(&e); + let k_s = key.k.mul(&s); + let v_expanded = k_ar.add(&k_e).add(&k_s); + println!("k*A*r: L2={:.2}", k_ar.l2_norm()); + println!("k*e: L2={:.2}", k_e.l2_norm()); + println!("k*s: L2={:.2}", k_s.l2_norm()); + println!("V expanded: L2={:.2}", v_expanded.l2_norm()); + + let v_diff = v.sub(&v_expanded); + println!("V - V_expanded (should be ~0): L2={:.2}", v_diff.l2_norm()); + assert!(v_diff.l2_norm() < 1.0, "V expansion must match"); + + println!("\n--- CLIENT FINALIZE: X = V - r*pk ---"); + println!("pk = A*k + e_k"); + let r_pk = r.mul(¶ms.pk); + let x = v.sub(&r_pk); + println!("r*pk: L2={:.2}", r_pk.l2_norm()); + println!("X: L2={:.2}", x.l2_norm()); + + println!("\n--- EXPAND r*pk = r*(A*k + e_k) = r*A*k + r*e_k ---"); + let ak = params.a.mul(&key.k); + let r_ak = r.mul(&ak); + let r_ek = r.mul(&key.e_k); + let r_pk_expanded = r_ak.add(&r_ek); + println!("A*k: L2={:.2}", ak.l2_norm()); + println!("r*A*k: L2={:.2}", r_ak.l2_norm()); + println!("r*e_k: L2={:.2}", r_ek.l2_norm()); + println!("r*pk exp: L2={:.2}", r_pk_expanded.l2_norm()); + + println!("\n--- CRITICAL: CHECK COMMUTATIVITY ---"); + println!("k*A*r vs r*A*k - do they cancel?"); + println!("k*A*r: L2={:.2}", k_ar.l2_norm()); + println!("r*A*k: L2={:.2}", r_ak.l2_norm()); + + let comm_diff = k_ar.sub(&r_ak); + println!( + "k*A*r - r*A*k (SHOULD BE ~0 for correctness): L2={:.2}", + comm_diff.l2_norm() + ); + + if comm_diff.l2_norm() > 100.0 { + println!("\n!!! FATAL: Ring multiplication is NON-COMMUTATIVE !!!"); + println!("k*A*r ≠ r*A*k, so X ≠ k*s + small_noise"); + println!("X = k*A*r + k*e + k*s - r*A*k - r*e_k"); + println!(" = (k*A*r - r*A*k) + k*e - r*e_k + k*s"); + println!(" = LARGE_RESIDUE + small_noise + k*s"); + } + + println!("\n--- WHAT CLIENT ACTUALLY GETS ---"); + let expected_x = k_s.add(&k_e).sub(&r_ek); + let actual_residue = x.sub(&expected_x); + println!("Expected: k*s + k*e - r*e_k"); + println!("Expected X: L2={:.2}", expected_x.l2_norm()); + println!("Actual X: L2={:.2}", x.l2_norm()); + println!( + "Residue (actual - expected): L2={:.2}", + actual_residue.l2_norm() + ); + + println!("\n--- TARGET: k*s ---"); + println!("k*s: L2={:.2}", k_s.l2_norm()); + let x_vs_ks = x.sub(&k_s); + println!("X - k*s (noise term): L2={:.2}", x_vs_ks.l2_norm()); + + println!("\n=== DIAGNOSIS COMPLETE ==="); + } + + #[test] + fn test_two_sessions_comparison() { + println!("\n=== TWO SESSION COMPARISON ===\n"); + + let key = ServerKey::generate(b"test-key"); + let params = ServerPublicParams::from(&key); + let password = b"password"; + let s = NtruRingElement::hash_to_ring(password); + let k_s = key.k.mul(&s); + + println!("Target k*s: L2={:.2}", k_s.l2_norm()); + println!("k*s first 8 coeffs: {:?}", &k_s.coeffs[..8]); + let k_s_rounded: Vec = k_s.coeffs.iter().map(|&c| round_coeff(c)).collect(); + println!("k*s rounded first 8: {:?}", &k_s_rounded[..8]); + + for session in 1..=2 { + println!("\n--- SESSION {} ---", session); + + let r = sample_random_ternary(); + let e = sample_random_ternary(); + + let ar = params.a.mul(&r); + let c = ar.add(&e).add(&s); + let v = key.k.mul(&c); + let r_pk = r.mul(¶ms.pk); + let x = v.sub(&r_pk); + + let noise = x.sub(&k_s); + println!("X: L2={:.2}", x.l2_norm()); + println!("X - k*s (noise): L2={:.2}", noise.l2_norm()); + + println!("X first 8 coeffs: {:?}", &x.coeffs[..8]); + let x_rounded: Vec = x.coeffs.iter().map(|&c| round_coeff(c)).collect(); + println!("X rounded first 8: {:?}", &x_rounded[..8]); + + let helper = ReconciliationHelper::from_ring(&v); + let reconciled = helper.reconcile(&x); + println!("V rounded (helper) first 8: {:?}", &helper.hints[..8]); + println!("Reconciled first 8: {:?}", &reconciled[..8]); + + let mut matches = 0; + let mut mismatches = 0; + for i in 0..P { + if reconciled[i] == k_s_rounded[i] { + matches += 1; + } else { + mismatches += 1; + } + } + println!( + "Matches with k*s rounded: {}/{} ({:.1}%)", + matches, + P, + 100.0 * matches as f64 / P as f64 + ); + + let noise_per_coeff: f64 = noise + .coeffs + .iter() + .map(|&c| { + let centered = if c > Q / 2 { c - Q } else { c }; + (centered as f64).abs() + }) + .sum::() + / P as f64; + println!("Avg |noise| per coeff: {:.2}", noise_per_coeff); + println!("Bin width (Q/P_LWR): {}", Q / P_LWR); + } + + println!("\n=== END SESSION COMPARISON ==="); + } + + #[test] + fn test_correctness() { + println!("\n=== NTRU-LWR CORRECTNESS ==="); + let key = ServerKey::generate(b"test-key"); + + let output1 = evaluate(&key, b"password"); + let output2 = evaluate(&key, b"password"); + + println!("Output 1: {:02x?}", &output1.value[..8]); + println!("Output 2: {:02x?}", &output2.value[..8]); + + assert_eq!(output1.value, output2.value, "Same password → same output"); + + println!("[PASS] Correctness verified"); + } + + #[test] + fn test_different_passwords() { + let key = ServerKey::generate(b"test-key"); + let out1 = evaluate(&key, b"password1"); + let out2 = evaluate(&key, b"password2"); + assert_ne!(out1.value, out2.value); + println!("[PASS] Different passwords → different outputs"); + } + + #[test] + fn test_deterministic_blinding() { + println!("\n=== DETERMINISTIC BLINDING TEST ==="); + let key = ServerKey::generate(b"test-key"); + let params = ServerPublicParams::from(&key); + + let (_, b1) = client_blind(¶ms, b"same-password"); + let (_, b2) = client_blind(¶ms, b"same-password"); + + println!("C1: {:?}", b1.c); + println!("C2: {:?}", b2.c); + assert!( + b1.c.eq(&b2.c), + "Same password → same blinded input (deterministic OPRF)" + ); + + let (_, b3) = client_blind(¶ms, b"different-password"); + assert!( + !b1.c.eq(&b3.c), + "Different passwords → different blinded inputs" + ); + + let out1 = evaluate(&key, b"same-password"); + let out2 = evaluate(&key, b"same-password"); + assert_eq!(out1.value, out2.value, "Outputs must match"); + + println!("[PASS] Deterministic OPRF verified"); + } + + #[test] + fn test_key_recovery_blocked() { + println!("\n=== KEY RECOVERY ATTACK TEST ==="); + let key = ServerKey::generate(b"secret"); + let params = ServerPublicParams::from(&key); + + let (state, blinded) = client_blind(¶ms, b"attacker-pw"); + let response = server_evaluate(&key, &blinded); + + let x = response.v.sub(&state.r.mul(¶ms.pk)); + println!("Client gets X = k*s + noise: {:?}", x); + + let s = NtruRingElement::hash_to_ring(b"attacker-pw"); + let s_inv = s.inverse().expect("invertible"); + let recovered_k = x.mul(&s_inv); + + println!("Attempted k recovery: {:?}", recovered_k); + println!("Actual k: {:?}", key.k); + + let matches = recovered_k.eq(&key.k); + println!("Keys match? {}", matches); + + assert!( + !matches, + "Key recovery must FAIL due to noise term k*e - r*e_k" + ); + + println!("[PASS] Key recovery blocked by LWE noise!"); + } +} diff --git a/src/oprf/ntru_oprf.rs b/src/oprf/ntru_oprf.rs index 44ec30c..a1510dc 100644 --- a/src/oprf/ntru_oprf.rs +++ b/src/oprf/ntru_oprf.rs @@ -487,8 +487,7 @@ fn poly_divmod(a: &[i64], b: &[i64]) -> (Vec, Vec) { (trim_poly("ient), trim_poly(&remainder)) } -/// Modular inverse using extended Euclidean algorithm -fn mod_inverse(a: i64, m: i64) -> Option { +pub fn mod_inverse(a: i64, m: i64) -> Option { let a = a.rem_euclid(m); if a == 0 { return None; @@ -521,10 +520,9 @@ fn mod_inverse(a: i64, m: i64) -> Option { // PROTOCOL STRUCTURES // ============================================================================ -/// Server's OPRF key #[derive(Clone)] pub struct NtruOprfKey { - k: NtruRingElement, + pub k: NtruRingElement, } impl fmt::Debug for NtruOprfKey { @@ -1031,4 +1029,100 @@ mod tests { println!(); println!("[PASS] NTRU-OPRF is a novel, independent construction!"); } + + #[test] + fn test_fatal_key_recovery_attack() { + println!("\n=== FATAL KEY RECOVERY ATTACK (THIS BREAKS THE OPRF!) ==="); + println!(); + println!("This test DEMONSTRATES that the pure NTRU-OPRF is INSECURE."); + println!("A malicious client can recover the server's secret key k!"); + println!(); + + let key = NtruOprfKey::generate(b"server-secret-key"); + let password = b"attacker-password"; + + println!("=== ATTACK SETUP ==="); + println!("Server key k: {:?}", key); + + // Step 1: Client runs normal OPRF protocol + println!("\n--- Step 1: Client runs normal OPRF ---"); + let s = NtruRingElement::hash_to_ring(password); + println!("Client's password hash s: {:?}", s); + dbg!(&s.coeffs[0..5]); + + let (state, blinded) = client_blind(password); + let response = server_evaluate(&key, &blinded); + + // Step 2: Client unblinds to get X = k * s + println!("\n--- Step 2: Client unblinds to get X = k * s ---"); + let x = response.v.mul(&state.r_inv); + println!("Unblinded X = V * r^(-1) = k * s: {:?}", x); + dbg!(&x.coeffs[0..5]); + + // Step 3: THE ATTACK - Client computes k = X * s^(-1) + println!("\n--- Step 3: THE ATTACK - Client computes k = X * s^(-1) ---"); + let s_inv = s.inverse().expect("s is invertible in NTRU Prime ring"); + println!("Client computed s^(-1): {:?}", s_inv); + + let recovered_k = x.mul(&s_inv); + println!("RECOVERED KEY k' = X * s^(-1): {:?}", recovered_k); + dbg!(&recovered_k.coeffs[0..5]); + dbg!(&key.k.coeffs[0..5]); + + // Step 4: Verify the attack succeeded + println!("\n--- Step 4: Verify attack succeeded ---"); + let keys_match = recovered_k.eq(&key.k); + println!("recovered_k == original k? {}", keys_match); + + assert!( + keys_match, + "ATTACK SUCCEEDED: Client recovered server's key!" + ); + + // Step 5: Demonstrate the consequences - offline dictionary attack + println!("\n=== CONSEQUENCES: OFFLINE DICTIONARY ATTACK ==="); + println!("Attacker now has k and can compute F_k(password) for any password!"); + + let victim_password = b"victim-secret-123"; + let victim_s = NtruRingElement::hash_to_ring(victim_password); + let victim_output_using_stolen_key = { + let ks = recovered_k.mul(&victim_s); + let mut hasher = sha3::Sha3_256::new(); + hasher.update(b"NTRU-OPRF-Output-v1"); + hasher.update(&ks.to_bytes()); + let hash: [u8; 32] = hasher.finalize().into(); + OprfOutput { value: hash } + }; + + let victim_output_real = prf_direct(&key, victim_password); + + println!( + "Victim's password: {:?}", + String::from_utf8_lossy(victim_password) + ); + println!( + "Output using stolen key: {:02x?}...", + &victim_output_using_stolen_key.value[..8] + ); + println!( + "Real output: {:02x?}...", + &victim_output_real.value[..8] + ); + + assert_eq!( + victim_output_using_stolen_key.value, victim_output_real.value, + "Attacker can compute OPRF for ANY password!" + ); + + println!(); + println!("╔══════════════════════════════════════════════════════════════╗"); + println!("║ FATAL VULNERABILITY CONFIRMED! ║"); + println!("║ ║"); + println!("║ The pure algebraic NTRU-OPRF allows key recovery: ║"); + println!("║ k = (k * s) * s^(-1) ║"); + println!("║ ║"); + println!("║ FIX REQUIRED: Add LWR rounding to destroy algebraic ║"); + println!("║ invertibility while preserving NTRU ring benefits. ║"); + println!("╚══════════════════════════════════════════════════════════════╝"); + } }