Skip to main content

karyon_p2p/
tls_config.rs

1use std::sync::Arc;
2
3#[cfg(feature = "smol")]
4use futures_rustls::rustls;
5
6#[cfg(feature = "tokio")]
7use tokio_rustls::rustls;
8
9use rustls::{
10    client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11    crypto::{
12        aws_lc_rs::{self, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, kx_group},
13        CryptoProvider, SupportedKxGroup,
14    },
15    server::danger::{ClientCertVerified, ClientCertVerifier},
16    CertificateError, DigitallySignedStruct, DistinguishedName,
17    Error::InvalidCertificate,
18    SignatureScheme, SupportedCipherSuite, SupportedProtocolVersion,
19};
20
21use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
22
23use log::error;
24use rcgen::PublicKeyData;
25use x509_parser::{asn1_rs::Oid, certificate::X509Certificate, parse_x509_certificate};
26
27use karyon_core::crypto::{KeyPair, KeyPairType, PublicKey};
28
29use crate::{PeerID, Result};
30
31// NOTE: This code needs a comprehensive audit.
32
33static PROTOCOL_VERSIONS: &[&SupportedProtocolVersion] = &[&rustls::version::TLS13];
34static CIPHER_SUITES: &[SupportedCipherSuite] = &[TLS13_CHACHA20_POLY1305_SHA256];
35static KX_GROUPS: &[&dyn SupportedKxGroup] = &[kx_group::X25519];
36static SIGNATURE_SCHEMES: &[SignatureScheme] = &[SignatureScheme::ED25519];
37
/// Arcs of the OID identifying the karyon peer-id certificate extension
/// (currently `0.0.0.0`).
// TODO: not standards-conformant. Replace with a registered one.
const PEER_ID_EXT_OID: &[u64] = &[0; 4];
41
42const BAD_SIGNATURE_ERR: rustls::Error = InvalidCertificate(CertificateError::BadSignature);
43const BAD_ENCODING_ERR: rustls::Error = InvalidCertificate(CertificateError::BadEncoding);
44
45/// Returns a TLS client configuration.
46pub fn tls_client_config(
47    key_pair: &KeyPair,
48    peer_id: Option<PeerID>,
49) -> Result<rustls::ClientConfig> {
50    let (cert, private_key) = generate_cert(key_pair)?;
51    let server_verifier = SrvrCertVerifier { peer_id };
52
53    let client_config = rustls::ClientConfig::builder_with_provider(
54        CryptoProvider {
55            kx_groups: KX_GROUPS.to_vec(),
56            cipher_suites: CIPHER_SUITES.to_vec(),
57            ..aws_lc_rs::default_provider()
58        }
59        .into(),
60    )
61    .with_protocol_versions(PROTOCOL_VERSIONS)?
62    .dangerous()
63    .with_custom_certificate_verifier(Arc::new(server_verifier))
64    .with_client_auth_cert(vec![cert], private_key)?;
65
66    Ok(client_config)
67}
68
69/// Returns a TLS server configuration.
70pub fn tls_server_config(key_pair: &KeyPair) -> Result<rustls::ServerConfig> {
71    let (cert, private_key) = generate_cert(key_pair)?;
72    let client_verifier = CliCertVerifier {};
73    let server_config = rustls::ServerConfig::builder_with_provider(
74        CryptoProvider {
75            kx_groups: KX_GROUPS.to_vec(),
76            cipher_suites: CIPHER_SUITES.to_vec(),
77            ..aws_lc_rs::default_provider()
78        }
79        .into(),
80    )
81    .with_protocol_versions(PROTOCOL_VERSIONS)?
82    .with_client_cert_verifier(Arc::new(client_verifier))
83    .with_single_cert(vec![cert], private_key)?;
84
85    Ok(server_config)
86}
87
88/// Generates a certificate and returns both the certificate and the private key.
89fn generate_cert<'a>(key_pair: &KeyPair) -> Result<(CertificateDer<'a>, PrivateKeyDer<'a>)> {
90    let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519)?;
91    let private_key = PrivateKeyDer::Pkcs8(cert_key_pair.serialize_der().into());
92
93    // Add a custom extension to the certificate:
94    //   - Sign the certificate's public key with the provided key pair's private key
95    //   - Append both the computed signature and the key pair's public key to the extension
96    let signature = key_pair.sign(&cert_key_pair.subject_public_key_info());
97    let ext_content = yasna::encode_der(&(key_pair.public().as_bytes().to_vec(), signature));
98    let mut ext = rcgen::CustomExtension::from_oid_content(PEER_ID_EXT_OID, ext_content);
99    // XXX: Non-critical because rustls rejects unknown critical extensions
100    // before our custom verifiers can run. The extension is still validated
101    // by verify_cert() which requires it to be present and checks the signature.
102    ext.set_criticality(false);
103
104    let mut params = rcgen::CertificateParams::new(vec![])?;
105    params.custom_extensions.push(ext);
106
107    let cert = CertificateDer::from(params.self_signed(&cert_key_pair)?);
108    Ok((cert, private_key))
109}
110
111/// Derive the peer id from a peer's certificate chain post-handshake.
112/// Reuses `verify_cert` so the validation rules stay in one place.
113/// Returns `None` if the chain is empty or the cert lacks our extension.
114pub(crate) fn peer_id_from_certs(certs: &[CertificateDer<'_>]) -> Option<PeerID> {
115    let end_entity = certs.first()?;
116    verify_cert(end_entity).ok()
117}
118
119/// Verifies the given certification.
120fn verify_cert(end_entity: &CertificateDer<'_>) -> std::result::Result<PeerID, rustls::Error> {
121    // Parse the certificate.
122    let cert = parse_cert(end_entity)?;
123
124    // Find our custom extension by OID, not by position.
125    let want_oid = Oid::from(PEER_ID_EXT_OID).map_err(|_| BAD_ENCODING_ERR)?;
126    let ext = cert
127        .extensions()
128        .iter()
129        .find(|e| e.oid == want_oid)
130        .ok_or(BAD_ENCODING_ERR)?;
131
132    // Extract the peer id (public key) and the signature from the extension.
133    let (public_key, signature): (Vec<u8>, Vec<u8>) =
134        yasna::decode_der(ext.value).map_err(|_| BAD_ENCODING_ERR)?;
135
136    // Use the peer id (public key) to verify the extracted signature.
137    let public_key =
138        PublicKey::from_bytes(&KeyPairType::Ed25519, &public_key).map_err(|_| BAD_ENCODING_ERR)?;
139    public_key
140        .verify(cert.public_key().raw, &signature)
141        .map_err(|_| BAD_SIGNATURE_ERR)?;
142
143    // Verify the certificate signature.
144    verify_cert_signature(
145        &cert,
146        cert.tbs_certificate.as_ref(),
147        cert.signature_value.as_ref(),
148    )?;
149
150    PeerID::try_from(public_key).map_err(|_| BAD_ENCODING_ERR)
151}
152
153/// Parses the given x509 certificate.
154fn parse_cert<'a>(
155    end_entity: &'a CertificateDer<'a>,
156) -> std::result::Result<X509Certificate<'a>, rustls::Error> {
157    let (_, cert) = parse_x509_certificate(end_entity.as_ref()).map_err(|_| BAD_ENCODING_ERR)?;
158
159    if !cert.validity().is_valid() {
160        return Err(InvalidCertificate(CertificateError::NotValidYet));
161    }
162
163    Ok(cert)
164}
165
166/// Verifies the signature of the given certificate.
167fn verify_cert_signature(
168    cert: &X509Certificate,
169    message: &[u8],
170    signature: &[u8],
171) -> std::result::Result<(), rustls::Error> {
172    let public_key = PublicKey::from_bytes(
173        &KeyPairType::Ed25519,
174        cert.tbs_certificate.subject_pki.subject_public_key.as_ref(),
175    )
176    .map_err(|_| BAD_ENCODING_ERR)?;
177
178    public_key
179        .verify(message, signature)
180        .map_err(|_| BAD_SIGNATURE_ERR)
181}
182
183#[derive(Debug)]
184struct SrvrCertVerifier {
185    peer_id: Option<PeerID>,
186}
187
188impl ServerCertVerifier for SrvrCertVerifier {
189    fn verify_server_cert(
190        &self,
191        end_entity: &CertificateDer<'_>,
192        _intermediates: &[CertificateDer<'_>],
193        _server_name: &ServerName,
194        _ocsp_response: &[u8],
195        _now: UnixTime,
196    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
197        let peer_id = match verify_cert(end_entity) {
198            Ok(pid) => pid,
199            Err(err) => {
200                error!("Failed to verify cert: {err}");
201                return Err(err);
202            }
203        };
204
205        // Verify that the peer id in the certificate's extension matches the
206        // one the client intends to connect to.
207        // Both should be equal for establishing a fully secure connection.
208        if let Some(pid) = &self.peer_id {
209            if pid != &peer_id {
210                return Err(InvalidCertificate(
211                    CertificateError::ApplicationVerificationFailure,
212                ));
213            }
214        }
215
216        Ok(ServerCertVerified::assertion())
217    }
218
219    fn verify_tls12_signature(
220        &self,
221        _message: &[u8],
222        _cert: &CertificateDer<'_>,
223        _dss: &DigitallySignedStruct,
224    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
225        unreachable!("ONLY SUPPORT tls 13 VERSION")
226    }
227
228    fn verify_tls13_signature(
229        &self,
230        message: &[u8],
231        cert: &CertificateDer<'_>,
232        dss: &DigitallySignedStruct,
233    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
234        let cert = parse_cert(cert)?;
235        verify_cert_signature(&cert, message, dss.signature())?;
236        Ok(HandshakeSignatureValid::assertion())
237    }
238
239    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
240        SIGNATURE_SCHEMES.to_vec()
241    }
242}
243
244#[derive(Debug)]
245struct CliCertVerifier {}
246impl ClientCertVerifier for CliCertVerifier {
247    fn verify_client_cert(
248        &self,
249        end_entity: &CertificateDer<'_>,
250        _intermediates: &[CertificateDer<'_>],
251        _now: UnixTime,
252    ) -> std::result::Result<ClientCertVerified, rustls::Error> {
253        if let Err(err) = verify_cert(end_entity) {
254            error!("Failed to verify cert: {err}");
255            return Err(err);
256        };
257        Ok(ClientCertVerified::assertion())
258    }
259
260    fn root_hint_subjects(&self) -> &[DistinguishedName] {
261        &[]
262    }
263
264    fn verify_tls12_signature(
265        &self,
266        _message: &[u8],
267        _cert: &CertificateDer<'_>,
268        _dss: &DigitallySignedStruct,
269    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
270        unreachable!("ONLY SUPPORT tls 13 VERSION")
271    }
272
273    fn verify_tls13_signature(
274        &self,
275        message: &[u8],
276        cert: &CertificateDer<'_>,
277        dss: &DigitallySignedStruct,
278    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
279        let cert = parse_cert(cert)?;
280        verify_cert_signature(&cert, message, dss.signature())?;
281        Ok(HandshakeSignatureValid::assertion())
282    }
283
284    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
285        SIGNATURE_SCHEMES.to_vec()
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn verify_generated_certificate() {
        let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
        let (cert, _) = generate_cert(&key_pair).unwrap();

        let result = verify_cert(&cert);
        assert!(result.is_ok());
        let peer_id = result.unwrap();
        assert_eq!(peer_id, PeerID::try_from(key_pair.public()).unwrap());
        assert_eq!(peer_id.0, key_pair.public().as_bytes());
    }

    #[test]
    fn peer_id_from_generated_chain() {
        let key_pair = KeyPair::generate(&KeyPairType::Ed25519);
        let (cert, _) = generate_cert(&key_pair).unwrap();

        assert_eq!(
            peer_id_from_certs(&[cert]),
            Some(PeerID::try_from(key_pair.public()).unwrap())
        );
    }

    #[test]
    fn peer_id_from_empty_chain_is_none() {
        assert!(peer_id_from_certs(&[]).is_none());
    }

    #[test]
    fn reject_certificate_without_peer_id_extension() {
        let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519).unwrap();
        let params = rcgen::CertificateParams::new(vec![]).unwrap();
        let cert = CertificateDer::from(params.self_signed(&cert_key_pair).unwrap());

        assert_eq!(verify_cert(&cert).err(), Some(BAD_ENCODING_ERR));
        assert!(peer_id_from_certs(&[cert]).is_none());
    }

    #[test]
    fn reject_extension_signed_by_another_key() {
        // The extension claims `claimed`'s identity, but the proof over the
        // certificate key is produced by `signer`.
        let signer = KeyPair::generate(&KeyPairType::Ed25519);
        let claimed = KeyPair::generate(&KeyPairType::Ed25519);

        let cert_key_pair = rcgen::KeyPair::generate_for(&rcgen::PKCS_ED25519).unwrap();
        let signature = signer.sign(&cert_key_pair.subject_public_key_info());
        let ext_content =
            yasna::encode_der(&(claimed.public().as_bytes().to_vec(), signature));
        let ext = rcgen::CustomExtension::from_oid_content(PEER_ID_EXT_OID, ext_content);
        let mut params = rcgen::CertificateParams::new(vec![]).unwrap();
        params.custom_extensions.push(ext);
        let cert = CertificateDer::from(params.self_signed(&cert_key_pair).unwrap());

        assert_eq!(verify_cert(&cert).err(), Some(BAD_SIGNATURE_ERR));
    }
}
304}