Files
opaque-lattice/src/ake/kyber.rs
2026-01-08 09:50:51 -07:00

309 lines
9.9 KiB
Rust

#[cfg(feature = "native")]
mod native {
    //! Kyber-768 backend built on `pqcrypto_kyber::kyber768`.

    use pqcrypto_kyber::kyber768;
    use pqcrypto_traits::kem::{Ciphertext, PublicKey, SecretKey, SharedSecret};
    use zeroize::{Zeroize, ZeroizeOnDrop};

    use crate::error::{OpaqueError, Result};
    use crate::types::{KYBER_CT_LEN, KYBER_PK_LEN, KYBER_SK_LEN, KYBER_SS_LEN};

    /// Returns `InvalidKeyLength` unless `bytes` is exactly `expected` bytes long.
    fn ensure_len(bytes: &[u8], expected: usize) -> Result<()> {
        if bytes.len() == expected {
            Ok(())
        } else {
            Err(OpaqueError::InvalidKeyLength {
                expected,
                got: bytes.len(),
            })
        }
    }

    /// Overwrites `slot` with `blank` and routes the slot through
    /// [`core::hint::black_box`] so the store is not discarded as a dead write
    /// just before the owning value is dropped.
    ///
    /// The `pqcrypto_traits` accessors used in this module are read-only
    /// (`as_bytes`), so the key cannot be wiped byte-by-byte in place; replacing
    /// the whole value is the safe-code alternative. `black_box` is documented
    /// as best-effort, so treat this as hardening rather than a hard guarantee.
    fn overwrite_secret<T>(slot: &mut T, blank: T) {
        *slot = blank;
        core::hint::black_box(slot);
    }

    /// Kyber-768 encapsulation (public) key.
    #[derive(Clone)]
    pub struct KyberPublicKey(pub(crate) kyber768::PublicKey);

    impl KyberPublicKey {
        /// Parses an encoded public key.
        ///
        /// # Errors
        ///
        /// Returns [`OpaqueError::InvalidKeyLength`] if `bytes` is not
        /// `KYBER_PK_LEN` bytes long, or [`OpaqueError::Deserialization`] if the
        /// backend rejects the encoding.
        pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
            ensure_len(bytes, KYBER_PK_LEN)?;
            kyber768::PublicKey::from_bytes(bytes)
                .map(Self)
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber public key".into()))
        }

        /// Returns an owned copy of the encoded key.
        #[must_use]
        pub fn as_bytes(&self) -> Vec<u8> {
            self.0.as_bytes().to_vec()
        }
    }

    /// Kyber-768 decapsulation (secret) key.
    ///
    /// The wrapped key is overwritten with an all-zero key on drop and on an
    /// explicit [`Zeroize::zeroize`] call. Copies that already escaped (for
    /// example the `Vec` returned by [`KyberSecretKey::as_bytes`], or values
    /// produced by `Clone`) are not affected by wiping this instance.
    #[derive(Clone)]
    pub struct KyberSecretKey {
        pub(crate) inner: kyber768::SecretKey,
    }

    impl KyberSecretKey {
        /// Parses an encoded secret key.
        ///
        /// # Errors
        ///
        /// Returns [`OpaqueError::InvalidKeyLength`] if `bytes` is not
        /// `KYBER_SK_LEN` bytes long, or [`OpaqueError::Deserialization`] if the
        /// backend rejects the encoding.
        pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
            ensure_len(bytes, KYBER_SK_LEN)?;
            kyber768::SecretKey::from_bytes(bytes)
                .map(|sk| Self { inner: sk })
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber secret key".into()))
        }

        /// Returns an owned copy of the encoded key. The caller is responsible
        /// for wiping the returned buffer.
        #[must_use]
        pub fn as_bytes(&self) -> Vec<u8> {
            self.inner.as_bytes().to_vec()
        }
    }

    impl Zeroize for KyberSecretKey {
        fn zeroize(&mut self) {
            // NOTE(review): `from_bytes` is expected to accept any buffer of the
            // right length — confirm for the pinned pqcrypto version. If it ever
            // refuses, skip the wipe instead of panicking (this runs in `Drop`).
            if let Ok(blank) = kyber768::SecretKey::from_bytes(&[0u8; KYBER_SK_LEN]) {
                overwrite_secret(&mut self.inner, blank);
            }
        }
    }

    impl Drop for KyberSecretKey {
        fn drop(&mut self) {
            self.zeroize();
        }
    }

    impl ZeroizeOnDrop for KyberSecretKey {}

    /// Kyber-768 ciphertext produced by [`encapsulate`].
    #[derive(Clone)]
    pub struct KyberCiphertext(pub(crate) kyber768::Ciphertext);

    impl KyberCiphertext {
        /// Parses an encoded ciphertext.
        ///
        /// # Errors
        ///
        /// Returns [`OpaqueError::InvalidKeyLength`] if `bytes` is not
        /// `KYBER_CT_LEN` bytes long, or [`OpaqueError::Deserialization`] if the
        /// backend rejects the encoding.
        pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
            ensure_len(bytes, KYBER_CT_LEN)?;
            kyber768::Ciphertext::from_bytes(bytes)
                .map(Self)
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber ciphertext".into()))
        }

        /// Returns an owned copy of the encoded ciphertext.
        #[must_use]
        pub fn as_bytes(&self) -> Vec<u8> {
            self.0.as_bytes().to_vec()
        }
    }

    /// Shared secret agreed by [`encapsulate`] / [`decapsulate`].
    ///
    /// The wrapped secret is overwritten with zeros on drop and on an explicit
    /// [`Zeroize::zeroize`] call. Arrays returned by
    /// [`KyberSharedSecret::to_array`] are independent copies and are not wiped.
    #[derive(Clone)]
    pub struct KyberSharedSecret {
        pub(crate) inner: kyber768::SharedSecret,
    }

    impl KyberSharedSecret {
        /// Borrows the raw shared-secret bytes.
        #[must_use]
        pub fn as_bytes(&self) -> &[u8] {
            self.inner.as_bytes()
        }

        /// Copies the shared secret into a fixed-size array.
        ///
        /// # Panics
        ///
        /// Panics if the backend's shared secret is not `KYBER_SS_LEN` bytes long.
        #[must_use]
        pub fn to_array(&self) -> [u8; KYBER_SS_LEN] {
            let mut arr = [0u8; KYBER_SS_LEN];
            arr.copy_from_slice(self.inner.as_bytes());
            arr
        }
    }

    impl Zeroize for KyberSharedSecret {
        fn zeroize(&mut self) {
            // Same best-effort policy as `KyberSecretKey`: never panic in `Drop`.
            if let Ok(blank) = kyber768::SharedSecret::from_bytes(&[0u8; KYBER_SS_LEN]) {
                overwrite_secret(&mut self.inner, blank);
            }
        }
    }

    impl Drop for KyberSharedSecret {
        fn drop(&mut self) {
            self.zeroize();
        }
    }

    impl ZeroizeOnDrop for KyberSharedSecret {}

    /// Generates a fresh Kyber-768 key pair.
    pub fn generate_keypair() -> (KyberPublicKey, KyberSecretKey) {
        let (pk, sk) = kyber768::keypair();
        (KyberPublicKey(pk), KyberSecretKey { inner: sk })
    }

    /// Encapsulates a fresh shared secret to `pk`.
    ///
    /// # Errors
    ///
    /// This backend never returns an error; the `Result` keeps the signature
    /// identical to the `wasm` backend, whose encapsulation is fallible.
    pub fn encapsulate(pk: &KyberPublicKey) -> Result<(KyberSharedSecret, KyberCiphertext)> {
        let (ss, ct) = kyber768::encapsulate(&pk.0);
        Ok((KyberSharedSecret { inner: ss }, KyberCiphertext(ct)))
    }

    /// Recovers the shared secret carried by `ct` using `sk`.
    ///
    /// # Errors
    ///
    /// This backend never returns an error; the `Result` keeps the signature
    /// identical to the `wasm` backend, whose decapsulation is fallible.
    pub fn decapsulate(ct: &KyberCiphertext, sk: &KyberSecretKey) -> Result<KyberSharedSecret> {
        let ss = kyber768::decapsulate(&ct.0, &sk.inner);
        Ok(KyberSharedSecret { inner: ss })
    }
}
#[cfg(feature = "wasm")]
mod wasm {
    //! ML-KEM-768 backend built on the `fips203` crate.

    use fips203::ml_kem_768;
    use fips203::traits::{Decaps, Encaps, KeyGen, SerDes};
    use zeroize::{Zeroize, ZeroizeOnDrop};

    use crate::error::{OpaqueError, Result};
    use crate::types::{KYBER_CT_LEN, KYBER_PK_LEN, KYBER_SK_LEN, KYBER_SS_LEN};

    /// Returns `InvalidKeyLength` unless `bytes` is exactly `expected` bytes long.
    fn ensure_len(bytes: &[u8], expected: usize) -> Result<()> {
        if bytes.len() == expected {
            Ok(())
        } else {
            Err(OpaqueError::InvalidKeyLength {
                expected,
                got: bytes.len(),
            })
        }
    }

    /// ML-KEM-768 encapsulation (public) key.
    #[derive(Clone)]
    pub struct KyberPublicKey(pub(crate) ml_kem_768::EncapsKey);

    impl KyberPublicKey {
        /// Parses an encoded public key.
        ///
        /// # Errors
        ///
        /// Returns [`OpaqueError::InvalidKeyLength`] if `bytes` is not
        /// `KYBER_PK_LEN` bytes long, or [`OpaqueError::Deserialization`] if
        /// `fips203` rejects the encoding.
        pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
            ensure_len(bytes, KYBER_PK_LEN)?;
            let arr: [u8; KYBER_PK_LEN] = bytes
                .try_into()
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber public key".into()))?;
            ml_kem_768::EncapsKey::try_from_bytes(arr)
                .map(Self)
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber public key".into()))
        }

        /// Returns an owned copy of the encoded key.
        #[must_use]
        pub fn as_bytes(&self) -> Vec<u8> {
            self.0.clone().into_bytes().to_vec()
        }
    }

    /// ML-KEM-768 decapsulation (secret) key.
    ///
    /// NOTE(review): `inner` is marked `#[zeroize(skip)]`, so the derived
    /// `Zeroize`/`ZeroizeOnDrop` impls on this wrapper do not touch it. Wiping
    /// therefore depends on `ml_kem_768::DecapsKey` clearing itself on drop.
    /// Confirm this against the pinned `fips203` version.
    #[derive(Clone, Zeroize, ZeroizeOnDrop)]
    pub struct KyberSecretKey {
        #[zeroize(skip)]
        pub(crate) inner: ml_kem_768::DecapsKey,
    }

    impl KyberSecretKey {
        /// Parses an encoded secret key.
        ///
        /// # Errors
        ///
        /// Returns [`OpaqueError::InvalidKeyLength`] if `bytes` is not
        /// `KYBER_SK_LEN` bytes long, or [`OpaqueError::Deserialization`] if
        /// `fips203` rejects the encoding.
        pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
            ensure_len(bytes, KYBER_SK_LEN)?;
            let mut arr: [u8; KYBER_SK_LEN] = bytes
                .try_into()
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber secret key".into()))?;
            let parsed = ml_kem_768::DecapsKey::try_from_bytes(arr);
            // Arrays are `Copy`, so passing `arr` by value left a copy of the
            // secret key in this frame; wipe it before returning.
            arr.zeroize();
            parsed
                .map(|sk| Self { inner: sk })
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber secret key".into()))
        }

        /// Returns an owned copy of the encoded key. The caller is responsible
        /// for wiping the returned buffer.
        #[must_use]
        pub fn as_bytes(&self) -> Vec<u8> {
            let mut raw = self.inner.clone().into_bytes();
            let encoded = raw.to_vec();
            // Wipe the intermediate fixed-size copy; only `encoded` escapes.
            raw.zeroize();
            encoded
        }
    }

    /// ML-KEM-768 ciphertext produced by [`encapsulate`].
    #[derive(Clone)]
    pub struct KyberCiphertext(pub(crate) ml_kem_768::CipherText);

    impl KyberCiphertext {
        /// Parses an encoded ciphertext.
        ///
        /// # Errors
        ///
        /// Returns [`OpaqueError::InvalidKeyLength`] if `bytes` is not
        /// `KYBER_CT_LEN` bytes long, or [`OpaqueError::Deserialization`] if
        /// `fips203` rejects the encoding.
        pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
            ensure_len(bytes, KYBER_CT_LEN)?;
            let arr: [u8; KYBER_CT_LEN] = bytes
                .try_into()
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber ciphertext".into()))?;
            ml_kem_768::CipherText::try_from_bytes(arr)
                .map(Self)
                .map_err(|_| OpaqueError::Deserialization("Invalid Kyber ciphertext".into()))
        }

        /// Returns an owned copy of the encoded ciphertext.
        #[must_use]
        pub fn as_bytes(&self) -> Vec<u8> {
            self.0.clone().into_bytes().to_vec()
        }
    }

    /// Shared secret agreed by [`encapsulate`] / [`decapsulate`], stored as a
    /// plain byte array that is zeroized on drop.
    #[derive(Clone, Zeroize, ZeroizeOnDrop)]
    pub struct KyberSharedSecret {
        pub(crate) inner: [u8; KYBER_SS_LEN],
    }

    impl KyberSharedSecret {
        /// Borrows the raw shared-secret bytes.
        #[must_use]
        pub fn as_bytes(&self) -> &[u8] {
            &self.inner
        }

        /// Returns a copy of the shared secret. The copy is not wiped
        /// automatically.
        #[must_use]
        pub fn to_array(&self) -> [u8; KYBER_SS_LEN] {
            self.inner
        }
    }

    /// Generates a fresh ML-KEM-768 key pair.
    ///
    /// # Panics
    ///
    /// Panics if `fips203` key generation reports an error.
    pub fn generate_keypair() -> (KyberPublicKey, KyberSecretKey) {
        let (ek, dk) = ml_kem_768::KG::try_keygen().expect("keygen should not fail with good RNG");
        (KyberPublicKey(ek), KyberSecretKey { inner: dk })
    }

    /// Encapsulates a fresh shared secret to `pk`.
    ///
    /// # Errors
    ///
    /// Returns [`OpaqueError::EncapsulationFailed`] if `fips203` reports an error.
    pub fn encapsulate(pk: &KyberPublicKey) -> Result<(KyberSharedSecret, KyberCiphertext)> {
        let (ssk, ct) = pk.0.try_encaps().map_err(|_| OpaqueError::EncapsulationFailed)?;
        let ss_bytes: [u8; KYBER_SS_LEN] = ssk.into_bytes().into();
        Ok((KyberSharedSecret { inner: ss_bytes }, KyberCiphertext(ct)))
    }

    /// Recovers the shared secret carried by `ct` using `sk`.
    ///
    /// # Errors
    ///
    /// Returns [`OpaqueError::DecapsulationFailed`] if `fips203` reports an error.
    pub fn decapsulate(ct: &KyberCiphertext, sk: &KyberSecretKey) -> Result<KyberSharedSecret> {
        let ssk = sk
            .inner
            .try_decaps(&ct.0)
            .map_err(|_| OpaqueError::DecapsulationFailed)?;
        let ss_bytes: [u8; KYBER_SS_LEN] = ssk.into_bytes().into();
        Ok(KyberSharedSecret { inner: ss_bytes })
    }
}
// Both backend modules define the same public names (`KyberPublicKey`,
// `KyberSecretKey`, `KyberCiphertext`, `KyberSharedSecret`, `generate_keypair`,
// `encapsulate`, `decapsulate`), so enabling both features is rejected outright.
#[cfg(all(feature = "native", feature = "wasm"))]
compile_error!("Features 'native' and 'wasm' are mutually exclusive. Enable only one.");

// Re-export exactly one backend so callers use the same paths on every target.
// When neither feature is enabled, nothing is re-exported from here.
#[cfg(all(not(feature = "wasm"), feature = "native"))]
pub use self::native::*;

#[cfg(all(not(feature = "native"), feature = "wasm"))]
pub use self::wasm::*;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::OpaqueError;
    use crate::types::{KYBER_CT_LEN, KYBER_PK_LEN, KYBER_SK_LEN, KYBER_SS_LEN};

    #[test]
    fn test_keypair_generation() {
        let (pk, sk) = generate_keypair();
        assert_eq!(pk.as_bytes().len(), KYBER_PK_LEN);
        assert_eq!(sk.as_bytes().len(), KYBER_SK_LEN);
    }

    #[test]
    fn test_encapsulate_decapsulate() {
        let (pk, sk) = generate_keypair();
        let (ss1, ct) = encapsulate(&pk).expect("encapsulation should succeed");
        let ss2 = decapsulate(&ct, &sk).expect("decapsulation should succeed");
        assert_eq!(ss1.as_bytes(), ss2.as_bytes());
        assert_eq!(ss1.as_bytes().len(), KYBER_SS_LEN);
        assert_eq!(ct.as_bytes().len(), KYBER_CT_LEN);
        // The owned array view must agree with the borrowed view.
        assert_eq!(&ss1.to_array()[..], ss1.as_bytes());
    }

    #[test]
    fn test_decapsulate_with_wrong_key_yields_different_secret() {
        let (pk, _) = generate_keypair();
        let (_, other_sk) = generate_keypair();
        let (ss, ct) = encapsulate(&pk).expect("encapsulation should succeed");
        let wrong = decapsulate(&ct, &other_sk)
            .expect("decapsulation of a well-formed ciphertext should succeed");
        assert_ne!(ss.as_bytes(), wrong.as_bytes());
    }

    #[test]
    fn test_public_key_serialization() {
        let (pk, _) = generate_keypair();
        let bytes = pk.as_bytes();
        let pk2 = KyberPublicKey::from_bytes(&bytes).expect("deserialization should succeed");
        assert_eq!(pk.as_bytes(), pk2.as_bytes());
    }

    #[test]
    fn test_secret_key_serialization() {
        let (_, sk) = generate_keypair();
        let bytes = sk.as_bytes();
        let sk2 = KyberSecretKey::from_bytes(&bytes).expect("deserialization should succeed");
        assert_eq!(sk.as_bytes(), sk2.as_bytes());
    }

    #[test]
    fn test_ciphertext_serialization() {
        let (pk, _) = generate_keypair();
        let (_, ct) = encapsulate(&pk).expect("encapsulation should succeed");
        let bytes = ct.as_bytes();
        let ct2 = KyberCiphertext::from_bytes(&bytes).expect("deserialization should succeed");
        assert_eq!(ct.as_bytes(), ct2.as_bytes());
    }

    #[test]
    fn test_invalid_key_length() {
        // Length failures must surface as `InvalidKeyLength` carrying both sizes.
        assert!(matches!(
            KyberPublicKey::from_bytes(&[0u8; 100]),
            Err(OpaqueError::InvalidKeyLength { expected, got: 100 }) if expected == KYBER_PK_LEN
        ));
        assert!(matches!(
            KyberSecretKey::from_bytes(&[0u8; 100]),
            Err(OpaqueError::InvalidKeyLength { expected, got: 100 }) if expected == KYBER_SK_LEN
        ));
    }

    #[test]
    fn test_invalid_ciphertext_length() {
        assert!(matches!(
            KyberCiphertext::from_bytes(&[0u8; 100]),
            Err(OpaqueError::InvalidKeyLength { expected, got: 100 }) if expected == KYBER_CT_LEN
        ));
        assert!(KyberCiphertext::from_bytes(&[]).is_err());
    }
}