karyon_p2p/
tls_config.rs

1use std::sync::Arc;
2
3#[cfg(feature = "smol")]
4use futures_rustls::rustls;
5
6#[cfg(feature = "tokio")]
7use tokio_rustls::rustls;
8
9use rustls::{
10    client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11    crypto::{
12        aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group},
13        CryptoProvider, SupportedKxGroup,
14    },
15    server::danger::{ClientCertVerified, ClientCertVerifier},
16    CertificateError, DigitallySignedStruct, DistinguishedName,
17    Error::InvalidCertificate,
18    SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion,
19};
20
21use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
22
23use log::error;
24use rcgen::PublicKeyData;
25use x509_parser::{certificate::X509Certificate, parse_x509_certificate};
26
27use karyon_core::crypto::{KeyPair, KeyPairType, PublicKey};
28
29use crate::{PeerID, Result};
30
31// NOTE: This code needs a comprehensive audit.
32
33static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version::TLS13];
34static CIPHER_SUITES: &[SupportedCipherSuite] = &[TLS13_CHACHA20_POLY1305_SHA256];
35static KX_GROUPS: &[&dyn SupportedKxGroup] = &[kx_group::X25519];
36static SIGNATURE_SCHEMES: &[SignatureScheme] = &[SignatureScheme::ED25519];
37
38const BAD_SIGNATURE_ERR: rustls::Error = InvalidCertificate(CertificateError::BadSignature);
39const BAD_ENCODING_ERR: rustls::Error = InvalidCertificate(CertificateError::BadEncoding);
40
41/// Returns a TLS client configuration.
42pub fn tls_client_config(
43    key_pair: &KeyPair,
44    peer_id: Option<PeerID>,
45) -> Result<rustls::ClientConfig> {
46    let (cert, private_key) = generate_cert(key_pair)?;
47    let server_verifier = SrvrCertVerifier { peer_id };
48
49    let client_config = rustls::ClientConfig::builder_with_provider(
50        CryptoProvider {
51            kx_groups: KX_GROUPS.to_vec(),
52            cipher_suites: CIPHER_SUITES.to_vec(),
53            ..aws_lc_rs::default_provider()
54        }
55        .into(),
56    )
57    .with_protocol_versions(PROTOCOL_VERSIONS)?
58    .dangerous()
59    .with_custom_certificate_verifier(Arc::new(server_verifier))
60    .with_client_auth_cert(vec![cert], private_key)?;
61
62    Ok(client_config)
63}
64
65/// Returns a TLS server configuration.
66pub fn tls_server_config(key_pair: &KeyPair) -> Result<rustls::ServerConfig> {
67    let (cert, private_key) = generate_cert(key_pair)?;
68    let client_verifier = CliCertVerifier {};
69    let server_config = rustls::ServerConfig::builder_with_provider(
70        CryptoProvider {
71            kx_groups: KX_GROUPS.to_vec(),
72            cipher_suites: CIPHER_SUITES.to_vec(),
73            ..aws_lc_rs::default_provider()
74        }
75        .into(),
76    )
77    .with_protocol_versions(PROTOCOL_VERSIONS)?
78    .with_client_cert_verifier(Arc::new(client_verifier))
79    .with_single_cert(vec![cert], private_key)?;
80
81    Ok(server_config)
82}
83
84/// Generates a certificate and returns both the certificate and the private key.
85fn generate_cert<'a>(key_pair: &KeyPair) -> Result<(CertificateDer<'a>, PrivateKeyDer<'a>)> {
86    let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519)?;
87    let private_key = PrivateKeyDer::Pkcs8(cert_key_pair.serialize_der().into());
88
89    // Add a custom extension to the certificate:
90    //   - Sign the certificate's public key with the provided key pair's private key
91    //   - Append both the computed signature and the key pair's public key to the extension
92    let signature = key_pair.sign(&cert_key_pair.subject_public_key_info());
93    let ext_content = yasna::encode_der(&(key_pair.public().as_bytes().to_vec(), signature));
94    // XXX: Not sure about the oid number ???
95    let mut ext = rcgen::CustomExtension::from_oid_content(&[0, 0, 0, 0], ext_content);
96    // XXX: Non-critical because rustls rejects unknown critical extensions
97    // before our custom verifiers can run. The extension is still validated
98    // by verify_cert() which requires it to be present and checks the signature.
99    ext.set_criticality(false);
100
101    let mut params = rcgen::CertificateParams::new(vec![])?;
102    params.custom_extensions.push(ext);
103
104    let cert = CertificateDer::from(params.self_signed(&cert_key_pair)?);
105    Ok((cert, private_key))
106}
107
108/// Verifies the given certification.
109fn verify_cert(end_entity: &CertificateDer<'_>) -> std::result::Result<PeerID, rustls::Error> {
110    // Parse the certificate.
111    let cert = parse_cert(end_entity)?;
112
113    match cert.extensions().first() {
114        Some(ext) => {
115            // Extract the peer id (public key) and the signature from the extension.
116            let (public_key, signature): (Vec<u8>, Vec<u8>) =
117                yasna::decode_der(ext.value).map_err(|_| BAD_ENCODING_ERR)?;
118
119            // Use the peer id (public key) to verify the extracted signature.
120            let public_key = PublicKey::from_bytes(&KeyPairType::Ed25519, &public_key)
121                .map_err(|_| BAD_ENCODING_ERR)?;
122            public_key
123                .verify(cert.public_key().raw, &signature)
124                .map_err(|_| BAD_SIGNATURE_ERR)?;
125
126            // Verify the certificate signature.
127            verify_cert_signature(
128                &cert,
129                cert.tbs_certificate.as_ref(),
130                cert.signature_value.as_ref(),
131            )?;
132
133            PeerID::try_from(public_key).map_err(|_| BAD_ENCODING_ERR)
134        }
135        None => Err(BAD_ENCODING_ERR),
136    }
137}
138
139/// Parses the given x509 certificate.
140fn parse_cert<'a>(
141    end_entity: &'a CertificateDer<'a>,
142) -> std::result::Result<X509Certificate<'a>, rustls::Error> {
143    let (_, cert) = parse_x509_certificate(end_entity.as_ref()).map_err(|_| BAD_ENCODING_ERR)?;
144
145    if !cert.validity().is_valid() {
146        return Err(InvalidCertificate(CertificateError::NotValidYet));
147    }
148
149    Ok(cert)
150}
151
152/// Verifies the signature of the given certificate.
153fn verify_cert_signature(
154    cert: &X509Certificate,
155    message: &[u8],
156    signature: &[u8],
157) -> std::result::Result<(), rustls::Error> {
158    let public_key = PublicKey::from_bytes(
159        &KeyPairType::Ed25519,
160        cert.tbs_certificate.subject_pki.subject_public_key.as_ref(),
161    )
162    .map_err(|_| BAD_ENCODING_ERR)?;
163
164    public_key
165        .verify(message, signature)
166        .map_err(|_| BAD_SIGNATURE_ERR)
167}
168
169#[derive(Debug)]
170struct SrvrCertVerifier {
171    peer_id: Option<PeerID>,
172}
173
174impl ServerCertVerifier for SrvrCertVerifier {
175    fn verify_server_cert(
176        &self,
177        end_entity: &CertificateDer<'_>,
178        _intermediates: &[CertificateDer<'_>],
179        _server_name: &ServerName,
180        _ocsp_response: &[u8],
181        _now: UnixTime,
182    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
183        let peer_id = match verify_cert(end_entity) {
184            Ok(pid) => pid,
185            Err(err) => {
186                error!("Failed to verify cert: {err}");
187                return Err(err);
188            }
189        };
190
191        // Verify that the peer id in the certificate's extension matches the
192        // one the client intends to connect to.
193        // Both should be equal for establishing a fully secure connection.
194        if let Some(pid) = &self.peer_id {
195            if pid != &peer_id {
196                return Err(InvalidCertificate(
197                    CertificateError::ApplicationVerificationFailure,
198                ));
199            }
200        }
201
202        Ok(ServerCertVerified::assertion())
203    }
204
205    fn verify_tls12_signature(
206        &self,
207        _message: &[u8],
208        _cert: &CertificateDer<'_>,
209        _dss: &DigitallySignedStruct,
210    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
211        unreachable!("ONLY SUPPORT tls 13 VERSION")
212    }
213
214    fn verify_tls13_signature(
215        &self,
216        message: &[u8],
217        cert: &CertificateDer<'_>,
218        dss: &DigitallySignedStruct,
219    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
220        let cert = parse_cert(cert)?;
221        verify_cert_signature(&cert, message, dss.signature())?;
222        Ok(HandshakeSignatureValid::assertion())
223    }
224
225    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
226        SIGNATURE_SCHEMES.to_vec()
227    }
228}
229
230#[derive(Debug)]
231struct CliCertVerifier {}
232impl ClientCertVerifier for CliCertVerifier {
233    fn verify_client_cert(
234        &self,
235        end_entity: &CertificateDer<'_>,
236        _intermediates: &[CertificateDer<'_>],
237        _now: UnixTime,
238    ) -> std::result::Result<ClientCertVerified, rustls::Error> {
239        if let Err(err) = verify_cert(end_entity) {
240            error!("Failed to verify cert: {err}");
241            return Err(err);
242        };
243        Ok(ClientCertVerified::assertion())
244    }
245
246    fn root_hint_subjects(&self) -> &[DistinguishedName] {
247        &[]
248    }
249
250    fn verify_tls12_signature(
251        &self,
252        _message: &[u8],
253        _cert: &CertificateDer<'_>,
254        _dss: &DigitallySignedStruct,
255    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
256        unreachable!("ONLY SUPPORT tls 13 VERSION")
257    }
258
259    fn verify_tls13_signature(
260        &self,
261        message: &[u8],
262        cert: &CertificateDer<'_>,
263        dss: &DigitallySignedStruct,
264    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
265        let cert = parse_cert(cert)?;
266        verify_cert_signature(&cert, message, dss.signature())?;
267        Ok(HandshakeSignatureValid::assertion())
268    }
269
270    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
271        SIGNATURE_SCHEMES.to_vec()
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn verify_generated_certificate() {
        let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
        let (cert, _) = generate_cert(&key_pair).unwrap();

        let result = verify_cert(&cert);
        assert!(result.is_ok());
        let peer_id = result.unwrap();
        assert_eq!(peer_id, PeerID::try_from(key_pair.public()).unwrap());
        assert_eq!(peer_id.0, key_pair.public().as_bytes());
    }

    #[test]
    fn reject_tampered_certificate_signature() {
        let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
        let (cert, _) = generate_cert(&key_pair).unwrap();

        // The certificate's own signature is the last field of the DER encoding,
        // so flipping the final byte must invalidate it.
        let mut bytes = cert.as_ref().to_vec();
        *bytes.last_mut().unwrap() ^= 0x01;
        let tampered = CertificateDer::from(bytes);

        assert_eq!(verify_cert(&tampered), Err(BAD_SIGNATURE_ERR));
    }

    #[test]
    fn reject_certificate_without_identity_extension() {
        let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519).unwrap();
        let params = rcgen::CertificateParams::new(Vec::<String>::new()).unwrap();
        let cert = CertificateDer::from(params.self_signed(&cert_key_pair).unwrap());

        assert_eq!(verify_cert(&cert), Err(BAD_ENCODING_ERR));
    }

    #[test]
    fn server_verifier_rejects_unexpected_peer_id() {
        let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
        let other = KeyPair::generate(&KeyPairType::Ed25519);
        let (cert, _) = generate_cert(&key_pair).unwrap();

        let verifier = SrvrCertVerifier {
            peer_id: Some(PeerID::try_from(other.public()).unwrap()),
        };
        let server_name = ServerName::try_from("localhost").unwrap();
        let result = verifier.verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now());

        assert_eq!(
            result.err(),
            Some(InvalidCertificate(
                CertificateError::ApplicationVerificationFailure
            ))
        );
    }
}