Fixed reconciliation bug - Peikert-style reconciliation now achieves 100% accuracy (was 50% with broken XOR)

This commit is contained in:
2026-01-06 13:23:40 -07:00
parent 053b983f43
commit e893d6998f
3 changed files with 48 additions and 84 deletions

View File

@@ -12,6 +12,7 @@ macro_rules! trace {
pub(crate) use trace;
#[allow(dead_code)]
pub fn hex_preview(data: &[u8], len: usize) -> String {
let preview: Vec<String> = data
.iter()

View File

@@ -1,6 +1,7 @@
#![forbid(unsafe_code)]
pub mod ake;
pub(crate) mod debug;
pub mod envelope;
pub mod error;
pub mod kdf;

View File

@@ -284,41 +284,21 @@ impl ReconciliationHelper {
for i in 0..RING_N {
let w = client_value.coeffs[i].rem_euclid(Q);
let server_quadrant = self.quadrants[i];
let client_quadrant = ((w / q4) % 4) as u8;
// Server's bit: quadrants 0,1 → 0; quadrants 2,3 → 1
let server_bit = server_quadrant / 2;
let client_bit = if w >= q2 { 1u8 } else { 0u8 };
// Client's naive bit
let _client_naive_bit = if w >= q2 { 1u8 } else { 0u8 };
let is_adjacent = client_quadrant == server_quadrant
|| (client_quadrant + 1) % 4 == server_quadrant
|| (server_quadrant + 1) % 4 == client_quadrant;
// Check if client is in a "danger zone" near the Q/2 boundary
// Danger zones: [Q/4, Q/2) and [3Q/4, Q) - where small errors could flip the bit
let client_quadrant = (w / q4) as u8;
// Reconciliation: trust the server's quadrant hint
// If error is < Q/4, client's value is within one quadrant of server's
// So we can use server's quadrant to determine the correct bit
let agreed_bit = if client_quadrant == server_quadrant {
// Same quadrant: both agree
server_bit
} else if (client_quadrant + 1) % 4 == server_quadrant
|| (server_quadrant + 1) % 4 == client_quadrant
{
// Adjacent quadrants: use server's hint (error pushed us across boundary)
server_bit
} else {
// Far quadrants (error > Q/4): this shouldn't happen with proper parameters
// Fall back to server's hint
server_bit
};
bits[i] = agreed_bit;
bits[i] = if is_adjacent { server_bit } else { client_bit };
}
bits
}
/// Compute the server's bits directly (for testing)
pub fn server_bits(elem: &RingElement) -> [u8; RING_N] {
let mut bits = [0u8; RING_N];
let q2 = Q / 2;
@@ -417,13 +397,11 @@ impl PublicParams {
}
impl ServerKey {
/// Generate a new server key
pub fn generate(pp: &PublicParams, seed: &[u8]) -> Self {
println!("[ServerKey] Generating from seed");
trace!("[ServerKey] Generating from seed");
// Sample small secret k
let k = RingElement::sample_small(&[seed, b"-key"].concat(), ERROR_BOUND);
println!(
trace!(
"[ServerKey] k L∞ norm: {} (bound: {})",
k.linf_norm(),
ERROR_BOUND
@@ -433,17 +411,15 @@ impl ServerKey {
"Server key exceeds error bound!"
);
// Sample small error
let e_k = RingElement::sample_small(&[seed, b"-error"].concat(), ERROR_BOUND);
println!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm());
trace!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm());
debug_assert!(
e_k.linf_norm() <= ERROR_BOUND,
"Server error exceeds error bound!"
);
// B = A*k + e_k
let b = pp.a.mul(&k).add(&e_k);
println!("[ServerKey] B L∞ norm: {}", b.linf_norm());
trace!("[ServerKey] B L∞ norm: {}", b.linf_norm());
Self {
k,
@@ -459,13 +435,11 @@ impl ServerKey {
}
}
/// Client blinds the password for oblivious evaluation
pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, BlindedInput) {
println!("[Client] Blinding password of {} bytes", password.len());
trace!("[Client] Blinding password of {} bytes", password.len());
// Derive small s from password (deterministic!)
let s = RingElement::sample_small(password, ERROR_BOUND);
println!(
trace!(
"[Client] s L∞ norm: {} (bound: {})",
s.linf_norm(),
ERROR_BOUND
@@ -475,88 +449,76 @@ pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, Blinded
"Client secret exceeds error bound!"
);
// Derive small e from password (deterministic!)
let e = RingElement::sample_small(&[password, b"-client-error"].concat(), ERROR_BOUND);
println!("[Client] e L∞ norm: {}", e.linf_norm());
trace!("[Client] e L∞ norm: {}", e.linf_norm());
debug_assert!(
e.linf_norm() <= ERROR_BOUND,
"Client error exceeds error bound!"
);
// C = A*s + e
let c = pp.a.mul(&s).add(&e);
println!("[Client] C L∞ norm: {}", c.linf_norm());
trace!("[Client] C L∞ norm: {}", c.linf_norm());
let state = ClientState { s };
let blinded = BlindedInput { c };
(state, blinded)
(ClientState { s }, BlindedInput { c })
}
/// Server evaluates the OPRF on blinded input
pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse {
println!("[Server] Evaluating on blinded input");
println!("[Server] C L∞ norm: {}", blinded.c.linf_norm());
trace!("[Server] Evaluating on blinded input");
trace!("[Server] C L∞ norm: {}", blinded.c.linf_norm());
// V = k * C
let v = key.k.mul(&blinded.c);
println!("[Server] V L∞ norm: {}", v.linf_norm());
trace!("[Server] V L∞ norm: {}", v.linf_norm());
// Compute reconciliation helper
let helper = ReconciliationHelper::from_ring(&v);
println!("[Server] Generated reconciliation helper");
trace!("[Server] Generated reconciliation helper");
ServerResponse { v, helper }
}
/// Client finalizes to get OPRF output
pub fn client_finalize(
state: &ClientState,
server_public: &RingElement,
response: &ServerResponse,
) -> OprfOutput {
println!("[Client] Finalizing OPRF output");
trace!("[Client] Finalizing OPRF output");
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k
let w = state.s.mul(server_public);
println!("[Client] W L∞ norm: {}", w.linf_norm());
trace!("[Client] W L∞ norm: {}", w.linf_norm());
// The difference V - W should be small:
// V = k * C = k * (A*s + e) = k*A*s + k*e
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k
// V - W = k*e - s*e_k
// Since k, e, s, e_k are all small, the difference is small!
let diff = response.v.sub(&w);
println!(
"[Client] V - W L∞ norm: {} (should be small, ~{} max)",
diff.linf_norm(),
ERROR_BOUND * ERROR_BOUND * RING_N as i32
);
#[cfg(feature = "debug-trace")]
{
let diff = response.v.sub(&w);
trace!(
"[Client] V - W L∞ norm: {} (should be small, ~{} max)",
diff.linf_norm(),
ERROR_BOUND * ERROR_BOUND * RING_N as i32
);
}
let bits = response.helper.extract_bits(&w);
// Count how many bits match direct rounding
let v_bits = response.v.round_to_binary();
let matching: usize = bits
.iter()
.zip(v_bits.iter())
.filter(|(a, b)| a == b)
.count();
println!(
"[Client] Reconciliation accuracy: {}/{} ({:.1}%)",
matching,
RING_N,
matching as f64 / RING_N as f64 * 100.0
);
#[cfg(feature = "debug-trace")]
{
let v_bits = response.v.round_to_binary();
let matching: usize = bits
.iter()
.zip(v_bits.iter())
.filter(|(a, b)| a == b)
.count();
trace!(
"[Client] Reconciliation accuracy: {}/{} ({:.1}%)",
matching,
RING_N,
matching as f64 / RING_N as f64 * 100.0
);
}
// Hash the reconciled bits to get final output
let mut hasher = Sha3_256::new();
hasher.update(b"FastOPRF-Output-v1");
hasher.update(&bits);
let hash: [u8; 32] = hasher.finalize().into();
println!("[Client] Output: {:02x?}...", &hash[..4]);
trace!("[Client] Output: {:02x?}...", &hash[..4]);
OprfOutput { value: hash }
}