Fixed reconciliation bug - Peikert-style reconciliation now achieves 100% accuracy (was 50% with broken XOR)

This commit is contained in:
2026-01-06 13:23:40 -07:00
parent 053b983f43
commit e893d6998f
3 changed files with 48 additions and 84 deletions

View File

@@ -12,6 +12,7 @@ macro_rules! trace {
pub(crate) use trace; pub(crate) use trace;
#[allow(dead_code)]
pub fn hex_preview(data: &[u8], len: usize) -> String { pub fn hex_preview(data: &[u8], len: usize) -> String {
let preview: Vec<String> = data let preview: Vec<String> = data
.iter() .iter()

View File

@@ -1,6 +1,7 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code)]
pub mod ake; pub mod ake;
pub(crate) mod debug;
pub mod envelope; pub mod envelope;
pub mod error; pub mod error;
pub mod kdf; pub mod kdf;

View File

@@ -284,41 +284,21 @@ impl ReconciliationHelper {
for i in 0..RING_N { for i in 0..RING_N {
let w = client_value.coeffs[i].rem_euclid(Q); let w = client_value.coeffs[i].rem_euclid(Q);
let server_quadrant = self.quadrants[i]; let server_quadrant = self.quadrants[i];
let client_quadrant = ((w / q4) % 4) as u8;
// Server's bit: quadrants 0,1 → 0; quadrants 2,3 → 1
let server_bit = server_quadrant / 2; let server_bit = server_quadrant / 2;
let client_bit = if w >= q2 { 1u8 } else { 0u8 };
// Client's naive bit let is_adjacent = client_quadrant == server_quadrant
let _client_naive_bit = if w >= q2 { 1u8 } else { 0u8 }; || (client_quadrant + 1) % 4 == server_quadrant
|| (server_quadrant + 1) % 4 == client_quadrant;
// Check if client is in a "danger zone" near the Q/2 boundary bits[i] = if is_adjacent { server_bit } else { client_bit };
// Danger zones: [Q/4, Q/2) and [3Q/4, Q) - where small errors could flip the bit
let client_quadrant = (w / q4) as u8;
// Reconciliation: trust the server's quadrant hint
// If error is < Q/4, client's value is within one quadrant of server's
// So we can use server's quadrant to determine the correct bit
let agreed_bit = if client_quadrant == server_quadrant {
// Same quadrant: both agree
server_bit
} else if (client_quadrant + 1) % 4 == server_quadrant
|| (server_quadrant + 1) % 4 == client_quadrant
{
// Adjacent quadrants: use server's hint (error pushed us across boundary)
server_bit
} else {
// Far quadrants (error > Q/4): this shouldn't happen with proper parameters
// Fall back to server's hint
server_bit
};
bits[i] = agreed_bit;
} }
bits bits
} }
/// Compute the server's bits directly (for testing)
pub fn server_bits(elem: &RingElement) -> [u8; RING_N] { pub fn server_bits(elem: &RingElement) -> [u8; RING_N] {
let mut bits = [0u8; RING_N]; let mut bits = [0u8; RING_N];
let q2 = Q / 2; let q2 = Q / 2;
@@ -417,13 +397,11 @@ impl PublicParams {
} }
impl ServerKey { impl ServerKey {
/// Generate a new server key
pub fn generate(pp: &PublicParams, seed: &[u8]) -> Self { pub fn generate(pp: &PublicParams, seed: &[u8]) -> Self {
println!("[ServerKey] Generating from seed"); trace!("[ServerKey] Generating from seed");
// Sample small secret k
let k = RingElement::sample_small(&[seed, b"-key"].concat(), ERROR_BOUND); let k = RingElement::sample_small(&[seed, b"-key"].concat(), ERROR_BOUND);
println!( trace!(
"[ServerKey] k L∞ norm: {} (bound: {})", "[ServerKey] k L∞ norm: {} (bound: {})",
k.linf_norm(), k.linf_norm(),
ERROR_BOUND ERROR_BOUND
@@ -433,17 +411,15 @@ impl ServerKey {
"Server key exceeds error bound!" "Server key exceeds error bound!"
); );
// Sample small error
let e_k = RingElement::sample_small(&[seed, b"-error"].concat(), ERROR_BOUND); let e_k = RingElement::sample_small(&[seed, b"-error"].concat(), ERROR_BOUND);
println!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm()); trace!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm());
debug_assert!( debug_assert!(
e_k.linf_norm() <= ERROR_BOUND, e_k.linf_norm() <= ERROR_BOUND,
"Server error exceeds error bound!" "Server error exceeds error bound!"
); );
// B = A*k + e_k
let b = pp.a.mul(&k).add(&e_k); let b = pp.a.mul(&k).add(&e_k);
println!("[ServerKey] B L∞ norm: {}", b.linf_norm()); trace!("[ServerKey] B L∞ norm: {}", b.linf_norm());
Self { Self {
k, k,
@@ -459,13 +435,11 @@ impl ServerKey {
} }
} }
/// Client blinds the password for oblivious evaluation
pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, BlindedInput) { pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, BlindedInput) {
println!("[Client] Blinding password of {} bytes", password.len()); trace!("[Client] Blinding password of {} bytes", password.len());
// Derive small s from password (deterministic!)
let s = RingElement::sample_small(password, ERROR_BOUND); let s = RingElement::sample_small(password, ERROR_BOUND);
println!( trace!(
"[Client] s L∞ norm: {} (bound: {})", "[Client] s L∞ norm: {} (bound: {})",
s.linf_norm(), s.linf_norm(),
ERROR_BOUND ERROR_BOUND
@@ -475,88 +449,76 @@ pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, Blinded
"Client secret exceeds error bound!" "Client secret exceeds error bound!"
); );
// Derive small e from password (deterministic!)
let e = RingElement::sample_small(&[password, b"-client-error"].concat(), ERROR_BOUND); let e = RingElement::sample_small(&[password, b"-client-error"].concat(), ERROR_BOUND);
println!("[Client] e L∞ norm: {}", e.linf_norm()); trace!("[Client] e L∞ norm: {}", e.linf_norm());
debug_assert!( debug_assert!(
e.linf_norm() <= ERROR_BOUND, e.linf_norm() <= ERROR_BOUND,
"Client error exceeds error bound!" "Client error exceeds error bound!"
); );
// C = A*s + e
let c = pp.a.mul(&s).add(&e); let c = pp.a.mul(&s).add(&e);
println!("[Client] C L∞ norm: {}", c.linf_norm()); trace!("[Client] C L∞ norm: {}", c.linf_norm());
let state = ClientState { s }; (ClientState { s }, BlindedInput { c })
let blinded = BlindedInput { c };
(state, blinded)
} }
/// Server evaluates the OPRF on blinded input
pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse { pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse {
println!("[Server] Evaluating on blinded input"); trace!("[Server] Evaluating on blinded input");
println!("[Server] C L∞ norm: {}", blinded.c.linf_norm()); trace!("[Server] C L∞ norm: {}", blinded.c.linf_norm());
// V = k * C
let v = key.k.mul(&blinded.c); let v = key.k.mul(&blinded.c);
println!("[Server] V L∞ norm: {}", v.linf_norm()); trace!("[Server] V L∞ norm: {}", v.linf_norm());
// Compute reconciliation helper
let helper = ReconciliationHelper::from_ring(&v); let helper = ReconciliationHelper::from_ring(&v);
println!("[Server] Generated reconciliation helper"); trace!("[Server] Generated reconciliation helper");
ServerResponse { v, helper } ServerResponse { v, helper }
} }
/// Client finalizes to get OPRF output
pub fn client_finalize( pub fn client_finalize(
state: &ClientState, state: &ClientState,
server_public: &RingElement, server_public: &RingElement,
response: &ServerResponse, response: &ServerResponse,
) -> OprfOutput { ) -> OprfOutput {
println!("[Client] Finalizing OPRF output"); trace!("[Client] Finalizing OPRF output");
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k
let w = state.s.mul(server_public); let w = state.s.mul(server_public);
println!("[Client] W L∞ norm: {}", w.linf_norm()); trace!("[Client] W L∞ norm: {}", w.linf_norm());
// The difference V - W should be small: #[cfg(feature = "debug-trace")]
// V = k * C = k * (A*s + e) = k*A*s + k*e {
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k let diff = response.v.sub(&w);
// V - W = k*e - s*e_k trace!(
// Since k, e, s, e_k are all small, the difference is small! "[Client] V - W L∞ norm: {} (should be small, ~{} max)",
let diff = response.v.sub(&w); diff.linf_norm(),
println!( ERROR_BOUND * ERROR_BOUND * RING_N as i32
"[Client] V - W L∞ norm: {} (should be small, ~{} max)", );
diff.linf_norm(), }
ERROR_BOUND * ERROR_BOUND * RING_N as i32
);
let bits = response.helper.extract_bits(&w); let bits = response.helper.extract_bits(&w);
// Count how many bits match direct rounding #[cfg(feature = "debug-trace")]
let v_bits = response.v.round_to_binary(); {
let matching: usize = bits let v_bits = response.v.round_to_binary();
.iter() let matching: usize = bits
.zip(v_bits.iter()) .iter()
.filter(|(a, b)| a == b) .zip(v_bits.iter())
.count(); .filter(|(a, b)| a == b)
println!( .count();
"[Client] Reconciliation accuracy: {}/{} ({:.1}%)", trace!(
matching, "[Client] Reconciliation accuracy: {}/{} ({:.1}%)",
RING_N, matching,
matching as f64 / RING_N as f64 * 100.0 RING_N,
); matching as f64 / RING_N as f64 * 100.0
);
}
// Hash the reconciled bits to get final output
let mut hasher = Sha3_256::new(); let mut hasher = Sha3_256::new();
hasher.update(b"FastOPRF-Output-v1"); hasher.update(b"FastOPRF-Output-v1");
hasher.update(&bits); hasher.update(&bits);
let hash: [u8; 32] = hasher.finalize().into(); let hash: [u8; 32] = hasher.finalize().into();
println!("[Client] Output: {:02x?}...", &hash[..4]); trace!("[Client] Output: {:02x?}...", &hash[..4]);
OprfOutput { value: hash } OprfOutput { value: hash }
} }