Fixed reconciliation bug - Peikert-style reconciliation now achieves 100% accuracy (was 50% with broken XOR)
This commit is contained in:
@@ -12,6 +12,7 @@ macro_rules! trace {
|
||||
|
||||
pub(crate) use trace;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn hex_preview(data: &[u8], len: usize) -> String {
|
||||
let preview: Vec<String> = data
|
||||
.iter()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
pub mod ake;
|
||||
pub(crate) mod debug;
|
||||
pub mod envelope;
|
||||
pub mod error;
|
||||
pub mod kdf;
|
||||
|
||||
@@ -284,41 +284,21 @@ impl ReconciliationHelper {
|
||||
for i in 0..RING_N {
|
||||
let w = client_value.coeffs[i].rem_euclid(Q);
|
||||
let server_quadrant = self.quadrants[i];
|
||||
let client_quadrant = ((w / q4) % 4) as u8;
|
||||
|
||||
// Server's bit: quadrants 0,1 → 0; quadrants 2,3 → 1
|
||||
let server_bit = server_quadrant / 2;
|
||||
let client_bit = if w >= q2 { 1u8 } else { 0u8 };
|
||||
|
||||
// Client's naive bit
|
||||
let _client_naive_bit = if w >= q2 { 1u8 } else { 0u8 };
|
||||
let is_adjacent = client_quadrant == server_quadrant
|
||||
|| (client_quadrant + 1) % 4 == server_quadrant
|
||||
|| (server_quadrant + 1) % 4 == client_quadrant;
|
||||
|
||||
// Check if client is in a "danger zone" near the Q/2 boundary
|
||||
// Danger zones: [Q/4, Q/2) and [3Q/4, Q) - where small errors could flip the bit
|
||||
let client_quadrant = (w / q4) as u8;
|
||||
|
||||
// Reconciliation: trust the server's quadrant hint
|
||||
// If error is < Q/4, client's value is within one quadrant of server's
|
||||
// So we can use server's quadrant to determine the correct bit
|
||||
let agreed_bit = if client_quadrant == server_quadrant {
|
||||
// Same quadrant: both agree
|
||||
server_bit
|
||||
} else if (client_quadrant + 1) % 4 == server_quadrant
|
||||
|| (server_quadrant + 1) % 4 == client_quadrant
|
||||
{
|
||||
// Adjacent quadrants: use server's hint (error pushed us across boundary)
|
||||
server_bit
|
||||
} else {
|
||||
// Far quadrants (error > Q/4): this shouldn't happen with proper parameters
|
||||
// Fall back to server's hint
|
||||
server_bit
|
||||
};
|
||||
|
||||
bits[i] = agreed_bit;
|
||||
bits[i] = if is_adjacent { server_bit } else { client_bit };
|
||||
}
|
||||
|
||||
bits
|
||||
}
|
||||
|
||||
/// Compute the server's bits directly (for testing)
|
||||
pub fn server_bits(elem: &RingElement) -> [u8; RING_N] {
|
||||
let mut bits = [0u8; RING_N];
|
||||
let q2 = Q / 2;
|
||||
@@ -417,13 +397,11 @@ impl PublicParams {
|
||||
}
|
||||
|
||||
impl ServerKey {
|
||||
/// Generate a new server key
|
||||
pub fn generate(pp: &PublicParams, seed: &[u8]) -> Self {
|
||||
println!("[ServerKey] Generating from seed");
|
||||
trace!("[ServerKey] Generating from seed");
|
||||
|
||||
// Sample small secret k
|
||||
let k = RingElement::sample_small(&[seed, b"-key"].concat(), ERROR_BOUND);
|
||||
println!(
|
||||
trace!(
|
||||
"[ServerKey] k L∞ norm: {} (bound: {})",
|
||||
k.linf_norm(),
|
||||
ERROR_BOUND
|
||||
@@ -433,17 +411,15 @@ impl ServerKey {
|
||||
"Server key exceeds error bound!"
|
||||
);
|
||||
|
||||
// Sample small error
|
||||
let e_k = RingElement::sample_small(&[seed, b"-error"].concat(), ERROR_BOUND);
|
||||
println!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm());
|
||||
trace!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm());
|
||||
debug_assert!(
|
||||
e_k.linf_norm() <= ERROR_BOUND,
|
||||
"Server error exceeds error bound!"
|
||||
);
|
||||
|
||||
// B = A*k + e_k
|
||||
let b = pp.a.mul(&k).add(&e_k);
|
||||
println!("[ServerKey] B L∞ norm: {}", b.linf_norm());
|
||||
trace!("[ServerKey] B L∞ norm: {}", b.linf_norm());
|
||||
|
||||
Self {
|
||||
k,
|
||||
@@ -459,13 +435,11 @@ impl ServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
/// Client blinds the password for oblivious evaluation
|
||||
pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, BlindedInput) {
|
||||
println!("[Client] Blinding password of {} bytes", password.len());
|
||||
trace!("[Client] Blinding password of {} bytes", password.len());
|
||||
|
||||
// Derive small s from password (deterministic!)
|
||||
let s = RingElement::sample_small(password, ERROR_BOUND);
|
||||
println!(
|
||||
trace!(
|
||||
"[Client] s L∞ norm: {} (bound: {})",
|
||||
s.linf_norm(),
|
||||
ERROR_BOUND
|
||||
@@ -475,88 +449,76 @@ pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, Blinded
|
||||
"Client secret exceeds error bound!"
|
||||
);
|
||||
|
||||
// Derive small e from password (deterministic!)
|
||||
let e = RingElement::sample_small(&[password, b"-client-error"].concat(), ERROR_BOUND);
|
||||
println!("[Client] e L∞ norm: {}", e.linf_norm());
|
||||
trace!("[Client] e L∞ norm: {}", e.linf_norm());
|
||||
debug_assert!(
|
||||
e.linf_norm() <= ERROR_BOUND,
|
||||
"Client error exceeds error bound!"
|
||||
);
|
||||
|
||||
// C = A*s + e
|
||||
let c = pp.a.mul(&s).add(&e);
|
||||
println!("[Client] C L∞ norm: {}", c.linf_norm());
|
||||
trace!("[Client] C L∞ norm: {}", c.linf_norm());
|
||||
|
||||
let state = ClientState { s };
|
||||
|
||||
let blinded = BlindedInput { c };
|
||||
|
||||
(state, blinded)
|
||||
(ClientState { s }, BlindedInput { c })
|
||||
}
|
||||
|
||||
/// Server evaluates the OPRF on blinded input
|
||||
pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse {
|
||||
println!("[Server] Evaluating on blinded input");
|
||||
println!("[Server] C L∞ norm: {}", blinded.c.linf_norm());
|
||||
trace!("[Server] Evaluating on blinded input");
|
||||
trace!("[Server] C L∞ norm: {}", blinded.c.linf_norm());
|
||||
|
||||
// V = k * C
|
||||
let v = key.k.mul(&blinded.c);
|
||||
println!("[Server] V L∞ norm: {}", v.linf_norm());
|
||||
trace!("[Server] V L∞ norm: {}", v.linf_norm());
|
||||
|
||||
// Compute reconciliation helper
|
||||
let helper = ReconciliationHelper::from_ring(&v);
|
||||
println!("[Server] Generated reconciliation helper");
|
||||
trace!("[Server] Generated reconciliation helper");
|
||||
|
||||
ServerResponse { v, helper }
|
||||
}
|
||||
|
||||
/// Client finalizes to get OPRF output
|
||||
pub fn client_finalize(
|
||||
state: &ClientState,
|
||||
server_public: &RingElement,
|
||||
response: &ServerResponse,
|
||||
) -> OprfOutput {
|
||||
println!("[Client] Finalizing OPRF output");
|
||||
trace!("[Client] Finalizing OPRF output");
|
||||
|
||||
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k
|
||||
let w = state.s.mul(server_public);
|
||||
println!("[Client] W L∞ norm: {}", w.linf_norm());
|
||||
trace!("[Client] W L∞ norm: {}", w.linf_norm());
|
||||
|
||||
// The difference V - W should be small:
|
||||
// V = k * C = k * (A*s + e) = k*A*s + k*e
|
||||
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k
|
||||
// V - W = k*e - s*e_k
|
||||
// Since k, e, s, e_k are all small, the difference is small!
|
||||
#[cfg(feature = "debug-trace")]
|
||||
{
|
||||
let diff = response.v.sub(&w);
|
||||
println!(
|
||||
trace!(
|
||||
"[Client] V - W L∞ norm: {} (should be small, ~{} max)",
|
||||
diff.linf_norm(),
|
||||
ERROR_BOUND * ERROR_BOUND * RING_N as i32
|
||||
);
|
||||
}
|
||||
|
||||
let bits = response.helper.extract_bits(&w);
|
||||
|
||||
// Count how many bits match direct rounding
|
||||
#[cfg(feature = "debug-trace")]
|
||||
{
|
||||
let v_bits = response.v.round_to_binary();
|
||||
let matching: usize = bits
|
||||
.iter()
|
||||
.zip(v_bits.iter())
|
||||
.filter(|(a, b)| a == b)
|
||||
.count();
|
||||
println!(
|
||||
trace!(
|
||||
"[Client] Reconciliation accuracy: {}/{} ({:.1}%)",
|
||||
matching,
|
||||
RING_N,
|
||||
matching as f64 / RING_N as f64 * 100.0
|
||||
);
|
||||
}
|
||||
|
||||
// Hash the reconciled bits to get final output
|
||||
let mut hasher = Sha3_256::new();
|
||||
hasher.update(b"FastOPRF-Output-v1");
|
||||
hasher.update(&bits);
|
||||
let hash: [u8; 32] = hasher.finalize().into();
|
||||
|
||||
println!("[Client] Output: {:02x?}...", &hash[..4]);
|
||||
trace!("[Client] Output: {:02x?}...", &hash[..4]);
|
||||
|
||||
OprfOutput { value: hash }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user