diff --git a/src/debug.rs b/src/debug.rs index fce67e9..7e7007b 100644 --- a/src/debug.rs +++ b/src/debug.rs @@ -12,6 +12,7 @@ macro_rules! trace { pub(crate) use trace; +#[allow(dead_code)] pub fn hex_preview(data: &[u8], len: usize) -> String { let preview: Vec = data .iter() diff --git a/src/lib.rs b/src/lib.rs index 3bf1e51..35a0764 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ #![forbid(unsafe_code)] pub mod ake; +pub(crate) mod debug; pub mod envelope; pub mod error; pub mod kdf; diff --git a/src/oprf/fast_oprf.rs b/src/oprf/fast_oprf.rs index 815707f..60211c6 100644 --- a/src/oprf/fast_oprf.rs +++ b/src/oprf/fast_oprf.rs @@ -284,41 +284,21 @@ impl ReconciliationHelper { for i in 0..RING_N { let w = client_value.coeffs[i].rem_euclid(Q); let server_quadrant = self.quadrants[i]; + let client_quadrant = ((w / q4) % 4) as u8; - // Server's bit: quadrants 0,1 → 0; quadrants 2,3 → 1 let server_bit = server_quadrant / 2; + let client_bit = if w >= q2 { 1u8 } else { 0u8 }; - // Client's naive bit - let _client_naive_bit = if w >= q2 { 1u8 } else { 0u8 }; + let is_adjacent = client_quadrant == server_quadrant + || (client_quadrant + 1) % 4 == server_quadrant + || (server_quadrant + 1) % 4 == client_quadrant; - // Check if client is in a "danger zone" near the Q/2 boundary - // Danger zones: [Q/4, Q/2) and [3Q/4, Q) - where small errors could flip the bit - let client_quadrant = (w / q4) as u8; - - // Reconciliation: trust the server's quadrant hint - // If error is < Q/4, client's value is within one quadrant of server's - // So we can use server's quadrant to determine the correct bit - let agreed_bit = if client_quadrant == server_quadrant { - // Same quadrant: both agree - server_bit - } else if (client_quadrant + 1) % 4 == server_quadrant - || (server_quadrant + 1) % 4 == client_quadrant - { - // Adjacent quadrants: use server's hint (error pushed us across boundary) - server_bit - } else { - // Far quadrants (error > Q/4): this shouldn't happen with proper parameters - // Fall back to server's hint - server_bit - }; - - bits[i] = agreed_bit; + bits[i] = if is_adjacent { server_bit } else { client_bit }; } bits } - /// Compute the server's bits directly (for testing) pub fn server_bits(elem: &RingElement) -> [u8; RING_N] { let mut bits = [0u8; RING_N]; let q2 = Q / 2; @@ -417,13 +397,11 @@ impl PublicParams { } impl ServerKey { - /// Generate a new server key pub fn generate(pp: &PublicParams, seed: &[u8]) -> Self { - println!("[ServerKey] Generating from seed"); + trace!("[ServerKey] Generating from seed"); - // Sample small secret k let k = RingElement::sample_small(&[seed, b"-key"].concat(), ERROR_BOUND); - println!( + trace!( "[ServerKey] k L∞ norm: {} (bound: {})", k.linf_norm(), ERROR_BOUND @@ -433,17 +411,15 @@ impl ServerKey { "Server key exceeds error bound!" ); - // Sample small error let e_k = RingElement::sample_small(&[seed, b"-error"].concat(), ERROR_BOUND); - println!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm()); + trace!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm()); debug_assert!( e_k.linf_norm() <= ERROR_BOUND, "Server error exceeds error bound!" ); - // B = A*k + e_k let b = pp.a.mul(&k).add(&e_k); - println!("[ServerKey] B L∞ norm: {}", b.linf_norm()); + trace!("[ServerKey] B L∞ norm: {}", b.linf_norm()); Self { k, @@ -459,13 +435,11 @@ impl ServerKey { } } -/// Client blinds the password for oblivious evaluation pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, BlindedInput) { - println!("[Client] Blinding password of {} bytes", password.len()); + trace!("[Client] Blinding password of {} bytes", password.len()); - // Derive small s from password (deterministic!) let s = RingElement::sample_small(password, ERROR_BOUND); - println!( + trace!( "[Client] s L∞ norm: {} (bound: {})", s.linf_norm(), ERROR_BOUND @@ -475,88 +449,76 @@ pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, Blinded "Client secret exceeds error bound!" ); - // Derive small e from password (deterministic!) let e = RingElement::sample_small(&[password, b"-client-error"].concat(), ERROR_BOUND); - println!("[Client] e L∞ norm: {}", e.linf_norm()); + trace!("[Client] e L∞ norm: {}", e.linf_norm()); debug_assert!( e.linf_norm() <= ERROR_BOUND, "Client error exceeds error bound!" ); - // C = A*s + e let c = pp.a.mul(&s).add(&e); - println!("[Client] C L∞ norm: {}", c.linf_norm()); + trace!("[Client] C L∞ norm: {}", c.linf_norm()); - let state = ClientState { s }; - - let blinded = BlindedInput { c }; - - (state, blinded) + (ClientState { s }, BlindedInput { c }) } -/// Server evaluates the OPRF on blinded input pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse { - println!("[Server] Evaluating on blinded input"); - println!("[Server] C L∞ norm: {}", blinded.c.linf_norm()); + trace!("[Server] Evaluating on blinded input"); + trace!("[Server] C L∞ norm: {}", blinded.c.linf_norm()); - // V = k * C let v = key.k.mul(&blinded.c); - println!("[Server] V L∞ norm: {}", v.linf_norm()); + trace!("[Server] V L∞ norm: {}", v.linf_norm()); - // Compute reconciliation helper let helper = ReconciliationHelper::from_ring(&v); - println!("[Server] Generated reconciliation helper"); + trace!("[Server] Generated reconciliation helper"); ServerResponse { v, helper } } -/// Client finalizes to get OPRF output pub fn client_finalize( state: &ClientState, server_public: &RingElement, response: &ServerResponse, ) -> OprfOutput { - println!("[Client] Finalizing OPRF output"); + trace!("[Client] Finalizing OPRF output"); - // W = s * B = s * (A*k + e_k) = s*A*k + s*e_k let w = state.s.mul(server_public); - println!("[Client] W L∞ norm: {}", w.linf_norm()); + trace!("[Client] W L∞ norm: {}", w.linf_norm()); - // The difference V - W should be small: - // V = k * C = k * (A*s + e) = k*A*s + k*e - // W = s * B = s * (A*k + e_k) = s*A*k + s*e_k - // V - W = k*e - s*e_k - // Since k, e, s, e_k are all small, the difference is small! - let diff = response.v.sub(&w); - println!( - "[Client] V - W L∞ norm: {} (should be small, ~{} max)", - diff.linf_norm(), - ERROR_BOUND * ERROR_BOUND * RING_N as i32 - ); + #[cfg(feature = "debug-trace")] + { + let diff = response.v.sub(&w); + trace!( + "[Client] V - W L∞ norm: {} (should be small, ~{} max)", + diff.linf_norm(), + ERROR_BOUND * ERROR_BOUND * RING_N as i32 + ); + } let bits = response.helper.extract_bits(&w); - // Count how many bits match direct rounding - let v_bits = response.v.round_to_binary(); - let matching: usize = bits - .iter() - .zip(v_bits.iter()) - .filter(|(a, b)| a == b) - .count(); - println!( - "[Client] Reconciliation accuracy: {}/{} ({:.1}%)", - matching, - RING_N, - matching as f64 / RING_N as f64 * 100.0 - ); + #[cfg(feature = "debug-trace")] + { + let v_bits = response.v.round_to_binary(); + let matching: usize = bits + .iter() + .zip(v_bits.iter()) + .filter(|(a, b)| a == b) + .count(); + trace!( + "[Client] Reconciliation accuracy: {}/{} ({:.1}%)", + matching, + RING_N, + matching as f64 / RING_N as f64 * 100.0 + ); + } - // Hash the reconciled bits to get final output let mut hasher = Sha3_256::new(); hasher.update(b"FastOPRF-Output-v1"); hasher.update(&bits); let hash: [u8; 32] = hasher.finalize().into(); - println!("[Client] Output: {:02x?}...", &hash[..4]); + trace!("[Client] Output: {:02x?}...", &hash[..4]); OprfOutput { value: hash } }