Fixed reconciliation bug - Peikert-style reconciliation now achieves 100% accuracy (was 50% with broken XOR)
This commit is contained in:
@@ -12,6 +12,7 @@ macro_rules! trace {
|
|||||||
|
|
||||||
pub(crate) use trace;
|
pub(crate) use trace;
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn hex_preview(data: &[u8], len: usize) -> String {
|
pub fn hex_preview(data: &[u8], len: usize) -> String {
|
||||||
let preview: Vec<String> = data
|
let preview: Vec<String> = data
|
||||||
.iter()
|
.iter()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#![forbid(unsafe_code)]
|
#![forbid(unsafe_code)]
|
||||||
|
|
||||||
pub mod ake;
|
pub mod ake;
|
||||||
|
pub(crate) mod debug;
|
||||||
pub mod envelope;
|
pub mod envelope;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod kdf;
|
pub mod kdf;
|
||||||
|
|||||||
@@ -284,41 +284,21 @@ impl ReconciliationHelper {
|
|||||||
for i in 0..RING_N {
|
for i in 0..RING_N {
|
||||||
let w = client_value.coeffs[i].rem_euclid(Q);
|
let w = client_value.coeffs[i].rem_euclid(Q);
|
||||||
let server_quadrant = self.quadrants[i];
|
let server_quadrant = self.quadrants[i];
|
||||||
|
let client_quadrant = ((w / q4) % 4) as u8;
|
||||||
|
|
||||||
// Server's bit: quadrants 0,1 → 0; quadrants 2,3 → 1
|
|
||||||
let server_bit = server_quadrant / 2;
|
let server_bit = server_quadrant / 2;
|
||||||
|
let client_bit = if w >= q2 { 1u8 } else { 0u8 };
|
||||||
|
|
||||||
// Client's naive bit
|
let is_adjacent = client_quadrant == server_quadrant
|
||||||
let _client_naive_bit = if w >= q2 { 1u8 } else { 0u8 };
|
|| (client_quadrant + 1) % 4 == server_quadrant
|
||||||
|
|| (server_quadrant + 1) % 4 == client_quadrant;
|
||||||
|
|
||||||
// Check if client is in a "danger zone" near the Q/2 boundary
|
bits[i] = if is_adjacent { server_bit } else { client_bit };
|
||||||
// Danger zones: [Q/4, Q/2) and [3Q/4, Q) - where small errors could flip the bit
|
|
||||||
let client_quadrant = (w / q4) as u8;
|
|
||||||
|
|
||||||
// Reconciliation: trust the server's quadrant hint
|
|
||||||
// If error is < Q/4, client's value is within one quadrant of server's
|
|
||||||
// So we can use server's quadrant to determine the correct bit
|
|
||||||
let agreed_bit = if client_quadrant == server_quadrant {
|
|
||||||
// Same quadrant: both agree
|
|
||||||
server_bit
|
|
||||||
} else if (client_quadrant + 1) % 4 == server_quadrant
|
|
||||||
|| (server_quadrant + 1) % 4 == client_quadrant
|
|
||||||
{
|
|
||||||
// Adjacent quadrants: use server's hint (error pushed us across boundary)
|
|
||||||
server_bit
|
|
||||||
} else {
|
|
||||||
// Far quadrants (error > Q/4): this shouldn't happen with proper parameters
|
|
||||||
// Fall back to server's hint
|
|
||||||
server_bit
|
|
||||||
};
|
|
||||||
|
|
||||||
bits[i] = agreed_bit;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bits
|
bits
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute the server's bits directly (for testing)
|
|
||||||
pub fn server_bits(elem: &RingElement) -> [u8; RING_N] {
|
pub fn server_bits(elem: &RingElement) -> [u8; RING_N] {
|
||||||
let mut bits = [0u8; RING_N];
|
let mut bits = [0u8; RING_N];
|
||||||
let q2 = Q / 2;
|
let q2 = Q / 2;
|
||||||
@@ -417,13 +397,11 @@ impl PublicParams {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ServerKey {
|
impl ServerKey {
|
||||||
/// Generate a new server key
|
|
||||||
pub fn generate(pp: &PublicParams, seed: &[u8]) -> Self {
|
pub fn generate(pp: &PublicParams, seed: &[u8]) -> Self {
|
||||||
println!("[ServerKey] Generating from seed");
|
trace!("[ServerKey] Generating from seed");
|
||||||
|
|
||||||
// Sample small secret k
|
|
||||||
let k = RingElement::sample_small(&[seed, b"-key"].concat(), ERROR_BOUND);
|
let k = RingElement::sample_small(&[seed, b"-key"].concat(), ERROR_BOUND);
|
||||||
println!(
|
trace!(
|
||||||
"[ServerKey] k L∞ norm: {} (bound: {})",
|
"[ServerKey] k L∞ norm: {} (bound: {})",
|
||||||
k.linf_norm(),
|
k.linf_norm(),
|
||||||
ERROR_BOUND
|
ERROR_BOUND
|
||||||
@@ -433,17 +411,15 @@ impl ServerKey {
|
|||||||
"Server key exceeds error bound!"
|
"Server key exceeds error bound!"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Sample small error
|
|
||||||
let e_k = RingElement::sample_small(&[seed, b"-error"].concat(), ERROR_BOUND);
|
let e_k = RingElement::sample_small(&[seed, b"-error"].concat(), ERROR_BOUND);
|
||||||
println!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm());
|
trace!("[ServerKey] e_k L∞ norm: {}", e_k.linf_norm());
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
e_k.linf_norm() <= ERROR_BOUND,
|
e_k.linf_norm() <= ERROR_BOUND,
|
||||||
"Server error exceeds error bound!"
|
"Server error exceeds error bound!"
|
||||||
);
|
);
|
||||||
|
|
||||||
// B = A*k + e_k
|
|
||||||
let b = pp.a.mul(&k).add(&e_k);
|
let b = pp.a.mul(&k).add(&e_k);
|
||||||
println!("[ServerKey] B L∞ norm: {}", b.linf_norm());
|
trace!("[ServerKey] B L∞ norm: {}", b.linf_norm());
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
k,
|
k,
|
||||||
@@ -459,13 +435,11 @@ impl ServerKey {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Client blinds the password for oblivious evaluation
|
|
||||||
pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, BlindedInput) {
|
pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, BlindedInput) {
|
||||||
println!("[Client] Blinding password of {} bytes", password.len());
|
trace!("[Client] Blinding password of {} bytes", password.len());
|
||||||
|
|
||||||
// Derive small s from password (deterministic!)
|
|
||||||
let s = RingElement::sample_small(password, ERROR_BOUND);
|
let s = RingElement::sample_small(password, ERROR_BOUND);
|
||||||
println!(
|
trace!(
|
||||||
"[Client] s L∞ norm: {} (bound: {})",
|
"[Client] s L∞ norm: {} (bound: {})",
|
||||||
s.linf_norm(),
|
s.linf_norm(),
|
||||||
ERROR_BOUND
|
ERROR_BOUND
|
||||||
@@ -475,88 +449,76 @@ pub fn client_blind(pp: &PublicParams, password: &[u8]) -> (ClientState, Blinded
|
|||||||
"Client secret exceeds error bound!"
|
"Client secret exceeds error bound!"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Derive small e from password (deterministic!)
|
|
||||||
let e = RingElement::sample_small(&[password, b"-client-error"].concat(), ERROR_BOUND);
|
let e = RingElement::sample_small(&[password, b"-client-error"].concat(), ERROR_BOUND);
|
||||||
println!("[Client] e L∞ norm: {}", e.linf_norm());
|
trace!("[Client] e L∞ norm: {}", e.linf_norm());
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
e.linf_norm() <= ERROR_BOUND,
|
e.linf_norm() <= ERROR_BOUND,
|
||||||
"Client error exceeds error bound!"
|
"Client error exceeds error bound!"
|
||||||
);
|
);
|
||||||
|
|
||||||
// C = A*s + e
|
|
||||||
let c = pp.a.mul(&s).add(&e);
|
let c = pp.a.mul(&s).add(&e);
|
||||||
println!("[Client] C L∞ norm: {}", c.linf_norm());
|
trace!("[Client] C L∞ norm: {}", c.linf_norm());
|
||||||
|
|
||||||
let state = ClientState { s };
|
(ClientState { s }, BlindedInput { c })
|
||||||
|
|
||||||
let blinded = BlindedInput { c };
|
|
||||||
|
|
||||||
(state, blinded)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Server evaluates the OPRF on blinded input
|
|
||||||
pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse {
|
pub fn server_evaluate(key: &ServerKey, blinded: &BlindedInput) -> ServerResponse {
|
||||||
println!("[Server] Evaluating on blinded input");
|
trace!("[Server] Evaluating on blinded input");
|
||||||
println!("[Server] C L∞ norm: {}", blinded.c.linf_norm());
|
trace!("[Server] C L∞ norm: {}", blinded.c.linf_norm());
|
||||||
|
|
||||||
// V = k * C
|
|
||||||
let v = key.k.mul(&blinded.c);
|
let v = key.k.mul(&blinded.c);
|
||||||
println!("[Server] V L∞ norm: {}", v.linf_norm());
|
trace!("[Server] V L∞ norm: {}", v.linf_norm());
|
||||||
|
|
||||||
// Compute reconciliation helper
|
|
||||||
let helper = ReconciliationHelper::from_ring(&v);
|
let helper = ReconciliationHelper::from_ring(&v);
|
||||||
println!("[Server] Generated reconciliation helper");
|
trace!("[Server] Generated reconciliation helper");
|
||||||
|
|
||||||
ServerResponse { v, helper }
|
ServerResponse { v, helper }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Client finalizes to get OPRF output
|
|
||||||
pub fn client_finalize(
|
pub fn client_finalize(
|
||||||
state: &ClientState,
|
state: &ClientState,
|
||||||
server_public: &RingElement,
|
server_public: &RingElement,
|
||||||
response: &ServerResponse,
|
response: &ServerResponse,
|
||||||
) -> OprfOutput {
|
) -> OprfOutput {
|
||||||
println!("[Client] Finalizing OPRF output");
|
trace!("[Client] Finalizing OPRF output");
|
||||||
|
|
||||||
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k
|
|
||||||
let w = state.s.mul(server_public);
|
let w = state.s.mul(server_public);
|
||||||
println!("[Client] W L∞ norm: {}", w.linf_norm());
|
trace!("[Client] W L∞ norm: {}", w.linf_norm());
|
||||||
|
|
||||||
// The difference V - W should be small:
|
#[cfg(feature = "debug-trace")]
|
||||||
// V = k * C = k * (A*s + e) = k*A*s + k*e
|
{
|
||||||
// W = s * B = s * (A*k + e_k) = s*A*k + s*e_k
|
|
||||||
// V - W = k*e - s*e_k
|
|
||||||
// Since k, e, s, e_k are all small, the difference is small!
|
|
||||||
let diff = response.v.sub(&w);
|
let diff = response.v.sub(&w);
|
||||||
println!(
|
trace!(
|
||||||
"[Client] V - W L∞ norm: {} (should be small, ~{} max)",
|
"[Client] V - W L∞ norm: {} (should be small, ~{} max)",
|
||||||
diff.linf_norm(),
|
diff.linf_norm(),
|
||||||
ERROR_BOUND * ERROR_BOUND * RING_N as i32
|
ERROR_BOUND * ERROR_BOUND * RING_N as i32
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let bits = response.helper.extract_bits(&w);
|
let bits = response.helper.extract_bits(&w);
|
||||||
|
|
||||||
// Count how many bits match direct rounding
|
#[cfg(feature = "debug-trace")]
|
||||||
|
{
|
||||||
let v_bits = response.v.round_to_binary();
|
let v_bits = response.v.round_to_binary();
|
||||||
let matching: usize = bits
|
let matching: usize = bits
|
||||||
.iter()
|
.iter()
|
||||||
.zip(v_bits.iter())
|
.zip(v_bits.iter())
|
||||||
.filter(|(a, b)| a == b)
|
.filter(|(a, b)| a == b)
|
||||||
.count();
|
.count();
|
||||||
println!(
|
trace!(
|
||||||
"[Client] Reconciliation accuracy: {}/{} ({:.1}%)",
|
"[Client] Reconciliation accuracy: {}/{} ({:.1}%)",
|
||||||
matching,
|
matching,
|
||||||
RING_N,
|
RING_N,
|
||||||
matching as f64 / RING_N as f64 * 100.0
|
matching as f64 / RING_N as f64 * 100.0
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Hash the reconciled bits to get final output
|
|
||||||
let mut hasher = Sha3_256::new();
|
let mut hasher = Sha3_256::new();
|
||||||
hasher.update(b"FastOPRF-Output-v1");
|
hasher.update(b"FastOPRF-Output-v1");
|
||||||
hasher.update(&bits);
|
hasher.update(&bits);
|
||||||
let hash: [u8; 32] = hasher.finalize().into();
|
let hash: [u8; 32] = hasher.finalize().into();
|
||||||
|
|
||||||
println!("[Client] Output: {:02x?}...", &hash[..4]);
|
trace!("[Client] Output: {:02x?}...", &hash[..4]);
|
||||||
|
|
||||||
OprfOutput { value: hash }
|
OprfOutput { value: hash }
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user