karyon_p2p/
tls_config.rs

1use std::sync::Arc;
2
3#[cfg(feature = "smol")]
4use futures_rustls::rustls;
5
6#[cfg(feature = "tokio")]
7use tokio_rustls::rustls;
8
9use rustls::{
10    client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11    crypto::{
12        aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group},
13        CryptoProvider, SupportedKxGroup,
14    },
15    server::danger::{ClientCertVerified, ClientCertVerifier},
16    CertificateError, DigitallySignedStruct, DistinguishedName,
17    Error::InvalidCertificate,
18    SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion,
19};
20
21use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
22
23use log::error;
24use x509_parser::{certificate::X509Certificate, parse_x509_certificate};
25
26use karyon_core::crypto::{KeyPair, KeyPairType, PublicKey};
27
28use crate::{PeerID, Result};
29
30// NOTE: This code needs a comprehensive audit.
31
32static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version::TLS13];
33static CIPHER_SUITES: &[SupportedCipherSuite] = &[TLS13_CHACHA20_POLY1305_SHA256];
34static KX_GROUPS: &[&dyn SupportedKxGroup] = &[kx_group::X25519];
35static SIGNATURE_SCHEMES: &[SignatureScheme] = &[SignatureScheme::ED25519];
36
37const BAD_SIGNATURE_ERR: rustls::Error = InvalidCertificate(CertificateError::BadSignature);
38const BAD_ENCODING_ERR: rustls::Error = InvalidCertificate(CertificateError::BadEncoding);
39
40/// Returns a TLS client configuration.
41pub fn tls_client_config(
42    key_pair: &KeyPair,
43    peer_id: Option<PeerID>,
44) -> Result<rustls::ClientConfig> {
45    let (cert, private_key) = generate_cert(key_pair)?;
46    let server_verifier = SrvrCertVerifier { peer_id };
47
48    let client_config = rustls::ClientConfig::builder_with_provider(
49        CryptoProvider {
50            kx_groups: KX_GROUPS.to_vec(),
51            cipher_suites: CIPHER_SUITES.to_vec(),
52            ..aws_lc_rs::default_provider()
53        }
54        .into(),
55    )
56    .with_protocol_versions(PROTOCOL_VERSIONS)?
57    .dangerous()
58    .with_custom_certificate_verifier(Arc::new(server_verifier))
59    .with_client_auth_cert(vec![cert], private_key)?;
60
61    Ok(client_config)
62}
63
64/// Returns a TLS server configuration.
65pub fn tls_server_config(key_pair: &KeyPair) -> Result<rustls::ServerConfig> {
66    let (cert, private_key) = generate_cert(key_pair)?;
67    let client_verifier = CliCertVerifier {};
68    let server_config = rustls::ServerConfig::builder_with_provider(
69        CryptoProvider {
70            kx_groups: KX_GROUPS.to_vec(),
71            cipher_suites: CIPHER_SUITES.to_vec(),
72            ..aws_lc_rs::default_provider()
73        }
74        .into(),
75    )
76    .with_protocol_versions(PROTOCOL_VERSIONS)?
77    .with_client_cert_verifier(Arc::new(client_verifier))
78    .with_single_cert(vec![cert], private_key)?;
79
80    Ok(server_config)
81}
82
83/// Generates a certificate and returns both the certificate and the private key.
84fn generate_cert<'a>(key_pair: &KeyPair) -> Result<(CertificateDer<'a>, PrivateKeyDer<'a>)> {
85    let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519)?;
86    let private_key = PrivateKeyDer::Pkcs8(cert_key_pair.serialize_der().into());
87
88    // Add a custom extension to the certificate:
89    //   - Sign the certificate's public key with the provided key pair's private key
90    //   - Append both the computed signature and the key pair's public key to the extension
91    let signature = key_pair.sign(&cert_key_pair.public_key_der());
92    let ext_content = yasna::encode_der(&(key_pair.public().as_bytes().to_vec(), signature));
93    // XXX: Not sure about the oid number ???
94    let mut ext = rcgen::CustomExtension::from_oid_content(&[0, 0, 0, 0], ext_content);
95    ext.set_criticality(true);
96
97    let mut params = rcgen::CertificateParams::new(vec![])?;
98    params.custom_extensions.push(ext);
99
100    let cert = CertificateDer::from(params.self_signed(&cert_key_pair)?);
101    Ok((cert, private_key))
102}
103
104/// Verifies the given certification.
105fn verify_cert(end_entity: &CertificateDer<'_>) -> std::result::Result<PeerID, rustls::Error> {
106    // Parse the certificate.
107    let cert = parse_cert(end_entity)?;
108
109    match cert.extensions().first() {
110        Some(ext) => {
111            // Extract the peer id (public key) and the signature from the extension.
112            let (public_key, signature): (Vec<u8>, Vec<u8>) =
113                yasna::decode_der(ext.value).map_err(|_| BAD_ENCODING_ERR)?;
114
115            // Use the peer id (public key) to verify the extracted signature.
116            let public_key = PublicKey::from_bytes(&KeyPairType::Ed25519, &public_key)
117                .map_err(|_| BAD_ENCODING_ERR)?;
118            public_key
119                .verify(cert.public_key().raw, &signature)
120                .map_err(|_| BAD_SIGNATURE_ERR)?;
121
122            // Verify the certificate signature.
123            verify_cert_signature(
124                &cert,
125                cert.tbs_certificate.as_ref(),
126                cert.signature_value.as_ref(),
127            )?;
128
129            PeerID::try_from(public_key).map_err(|_| BAD_ENCODING_ERR)
130        }
131        None => Err(BAD_ENCODING_ERR),
132    }
133}
134
135/// Parses the given x509 certificate.
136fn parse_cert<'a>(
137    end_entity: &'a CertificateDer<'a>,
138) -> std::result::Result<X509Certificate<'a>, rustls::Error> {
139    let (_, cert) = parse_x509_certificate(end_entity.as_ref()).map_err(|_| BAD_ENCODING_ERR)?;
140
141    if !cert.validity().is_valid() {
142        return Err(InvalidCertificate(CertificateError::NotValidYet));
143    }
144
145    Ok(cert)
146}
147
148/// Verifies the signature of the given certificate.
149fn verify_cert_signature(
150    cert: &X509Certificate,
151    message: &[u8],
152    signature: &[u8],
153) -> std::result::Result<(), rustls::Error> {
154    let public_key = PublicKey::from_bytes(
155        &KeyPairType::Ed25519,
156        cert.tbs_certificate.subject_pki.subject_public_key.as_ref(),
157    )
158    .map_err(|_| BAD_ENCODING_ERR)?;
159
160    public_key
161        .verify(message, signature)
162        .map_err(|_| BAD_SIGNATURE_ERR)
163}
164
165#[derive(Debug)]
166struct SrvrCertVerifier {
167    peer_id: Option<PeerID>,
168}
169
170impl ServerCertVerifier for SrvrCertVerifier {
171    fn verify_server_cert(
172        &self,
173        end_entity: &CertificateDer<'_>,
174        _intermediates: &[CertificateDer<'_>],
175        _server_name: &ServerName,
176        _ocsp_response: &[u8],
177        _now: UnixTime,
178    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
179        let peer_id = match verify_cert(end_entity) {
180            Ok(pid) => pid,
181            Err(err) => {
182                error!("Failed to verify cert: {err}");
183                return Err(err);
184            }
185        };
186
187        // Verify that the peer id in the certificate's extension matches the
188        // one the client intends to connect to.
189        // Both should be equal for establishing a fully secure connection.
190        if let Some(pid) = &self.peer_id {
191            if pid != &peer_id {
192                return Err(InvalidCertificate(
193                    CertificateError::ApplicationVerificationFailure,
194                ));
195            }
196        }
197
198        Ok(ServerCertVerified::assertion())
199    }
200
201    fn verify_tls12_signature(
202        &self,
203        _message: &[u8],
204        _cert: &CertificateDer<'_>,
205        _dss: &DigitallySignedStruct,
206    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
207        unreachable!("ONLY SUPPORT tls 13 VERSION")
208    }
209
210    fn verify_tls13_signature(
211        &self,
212        message: &[u8],
213        cert: &CertificateDer<'_>,
214        dss: &DigitallySignedStruct,
215    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
216        let cert = parse_cert(cert)?;
217        verify_cert_signature(&cert, message, dss.signature())?;
218        Ok(HandshakeSignatureValid::assertion())
219    }
220
221    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
222        SIGNATURE_SCHEMES.to_vec()
223    }
224}
225
226#[derive(Debug)]
227struct CliCertVerifier {}
228impl ClientCertVerifier for CliCertVerifier {
229    fn verify_client_cert(
230        &self,
231        end_entity: &CertificateDer<'_>,
232        _intermediates: &[CertificateDer<'_>],
233        _now: UnixTime,
234    ) -> std::result::Result<ClientCertVerified, rustls::Error> {
235        if let Err(err) = verify_cert(end_entity) {
236            error!("Failed to verify cert: {err}");
237            return Err(err);
238        };
239        Ok(ClientCertVerified::assertion())
240    }
241
242    fn root_hint_subjects(&self) -> &[DistinguishedName] {
243        &[]
244    }
245
246    fn verify_tls12_signature(
247        &self,
248        _message: &[u8],
249        _cert: &CertificateDer<'_>,
250        _dss: &DigitallySignedStruct,
251    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
252        unreachable!("ONLY SUPPORT tls 13 VERSION")
253    }
254
255    fn verify_tls13_signature(
256        &self,
257        message: &[u8],
258        cert: &CertificateDer<'_>,
259        dss: &DigitallySignedStruct,
260    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
261        let cert = parse_cert(cert)?;
262        verify_cert_signature(&cert, message, dss.signature())?;
263        Ok(HandshakeSignatureValid::assertion())
264    }
265
266    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
267        SIGNATURE_SCHEMES.to_vec()
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a certificate for a new random Ed25519 identity.
    fn fresh_cert() -> (KeyPair, CertificateDer<'static>) {
        let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
        let (cert, _) = generate_cert(&key_pair).unwrap();
        (key_pair, cert)
    }

    #[test]
    fn verify_generated_certificate() {
        let (key_pair, cert) = fresh_cert();

        let peer_id = verify_cert(&cert).expect("a freshly generated certificate must verify");
        assert_eq!(peer_id, PeerID::try_from(key_pair.public()).unwrap());
        assert_eq!(peer_id.0, key_pair.public().as_bytes());
    }

    #[test]
    fn reject_tampered_certificate() {
        let (_, cert) = fresh_cert();

        // The final DER byte lies inside the certificate's signature value, so
        // flipping a bit there must break the self-signature check.
        let original: &[u8] = cert.as_ref();
        let mut bytes = original.to_vec();
        if let Some(last) = bytes.last_mut() {
            *last ^= 0x01;
        }
        let tampered = CertificateDer::from(bytes);

        assert!(verify_cert(&tampered).is_err());
    }

    #[test]
    fn reject_malformed_certificate() {
        let garbage = CertificateDer::from(vec![0u8; 16]);
        assert!(verify_cert(&garbage).is_err());
    }

    #[test]
    fn server_verifier_checks_expected_peer_id() {
        let (key_pair, cert) = fresh_cert();
        let other_key = KeyPair::generate(&KeyPairType::Ed25519);
        let server_name = ServerName::try_from("localhost").unwrap();

        let expected = SrvrCertVerifier {
            peer_id: Some(PeerID::try_from(key_pair.public()).unwrap()),
        };
        assert!(expected
            .verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now())
            .is_ok());

        let unexpected = SrvrCertVerifier {
            peer_id: Some(PeerID::try_from(other_key.public()).unwrap()),
        };
        assert!(unexpected
            .verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now())
            .is_err());
    }
}