karyon_p2p/protocols/handshake.rs

1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use log::trace;
5
6use karyon_core::{async_util::timeout, util::decode};
7
8use crate::{
9    message::{NetMsg, NetMsgCmd, VerAckMsg, VerMsg},
10    peer::Peer,
11    protocol::{InitProtocol, ProtocolID},
12    version::{version_match, VersionInt},
13    Error, PeerID, Result, Version,
14};
15
16pub struct HandshakeProtocol {
17    peer: Arc<Peer>,
18    protocols: HashMap<ProtocolID, Version>,
19}
20
21#[async_trait]
22impl InitProtocol for HandshakeProtocol {
23    type T = Result<PeerID>;
24    /// Initiate a handshake with a connection.
25    async fn init(self: Arc<Self>) -> Self::T {
26        trace!("Init Handshake: {}", self.peer.remote_endpoint());
27
28        if !self.peer.is_inbound() {
29            self.send_vermsg().await?;
30        }
31
32        let t = Duration::from_secs(self.peer.config().handshake_timeout);
33        let msg: NetMsg = timeout(t, self.peer.conn.recv_inner()).await??;
34        match msg.header.command {
35            NetMsgCmd::Version => {
36                let result = self.validate_version_msg(&msg).await;
37                match result {
38                    Ok(_) => {
39                        self.send_verack(true).await?;
40                    }
41                    Err(Error::IncompatibleVersion(_)) | Err(Error::UnsupportedProtocol(_)) => {
42                        self.send_verack(false).await?;
43                    }
44                    _ => {}
45                };
46                result
47            }
48            NetMsgCmd::Verack => self.validate_verack_msg(&msg).await,
49            cmd => Err(Error::InvalidMsg(format!("unexpected msg found {:?}", cmd))),
50        }
51    }
52}
53
54impl HandshakeProtocol {
55    pub fn new(peer: Arc<Peer>, protocols: HashMap<ProtocolID, Version>) -> Arc<Self> {
56        Arc::new(Self { peer, protocols })
57    }
58
59    /// Sends a Version message
60    async fn send_vermsg(&self) -> Result<()> {
61        let protocols = self
62            .protocols
63            .clone()
64            .into_iter()
65            .map(|p| (p.0, p.1.v))
66            .collect();
67
68        let vermsg = VerMsg {
69            peer_id: self.peer.own_id().clone(),
70            protocols,
71            version: self.peer.config().version.v.clone(),
72        };
73
74        trace!("Send VerMsg");
75        self.peer
76            .conn
77            .send_inner(NetMsg::new(NetMsgCmd::Version, &vermsg)?)
78            .await?;
79        Ok(())
80    }
81
82    /// Sends a Verack message
83    async fn send_verack(&self, ack: bool) -> Result<()> {
84        let verack = VerAckMsg {
85            peer_id: self.peer.own_id().clone(),
86            ack,
87        };
88
89        trace!("Send VerAckMsg {:?}", verack);
90        self.peer
91            .conn
92            .send_inner(NetMsg::new(NetMsgCmd::Verack, &verack)?)
93            .await?;
94        Ok(())
95    }
96
97    /// Validates the given version msg
98    async fn validate_version_msg(&self, msg: &NetMsg) -> Result<PeerID> {
99        let (vermsg, _) = decode::<VerMsg>(&msg.payload)?;
100
101        if !version_match(&self.peer.config().version.req, &vermsg.version) {
102            return Err(Error::IncompatibleVersion("system: {}".into()));
103        }
104
105        self.protocols_match(&vermsg.protocols).await?;
106
107        trace!("Received VerMsg from: {}", vermsg.peer_id);
108        Ok(vermsg.peer_id)
109    }
110
111    /// Validates the given verack msg
112    async fn validate_verack_msg(&self, msg: &NetMsg) -> Result<PeerID> {
113        let (verack, _) = decode::<VerAckMsg>(&msg.payload)?;
114
115        if !verack.ack {
116            return Err(Error::IncompatiblePeer);
117        }
118
119        trace!("Received VerAckMsg from: {}", verack.peer_id);
120        Ok(verack.peer_id)
121    }
122
123    /// Check if the new connection has compatible protocols.
124    async fn protocols_match(&self, protocols: &HashMap<ProtocolID, VersionInt>) -> Result<()> {
125        for (n, pv) in protocols.iter() {
126            match self.protocols.get(n) {
127                Some(v) => {
128                    if !version_match(&v.req, pv) {
129                        return Err(Error::IncompatibleVersion(format!("{n} protocol: {pv}")));
130                    }
131                }
132                None => {
133                    return Err(Error::UnsupportedProtocol(n.to_string()));
134                }
135            }
136        }
137        Ok(())
138    }
139}