Skip to main content

karyon_net/codec/
length_codec.rs

1use bincode::config;
2
3use crate::{codec::Codec, ByteBuffer, Error, Result};
4
/// Default upper bound, in bytes, on a single frame's payload (4 MiB).
const MAX_BUFFER_SIZE: usize = 4 << 20;

/// Number of bytes occupied by the `u32` length prefix that precedes every frame.
const MSG_LENGTH_SIZE: usize = std::mem::size_of::<u32>();
9
/// A codec that frames each message with a fixed-size `u32` length prefix.
///
/// Payloads longer than `max_size` bytes are rejected, both when encoding
/// and when a decoded prefix announces such a length.
#[derive(Clone, Debug)]
pub struct LengthCodec {
    /// Maximum accepted payload size in bytes (the length prefix is not counted).
    max_size: usize,
}
14
15impl LengthCodec {
16    pub fn new(max_size: usize) -> Self {
17        Self { max_size }
18    }
19}
20
21impl Default for LengthCodec {
22    fn default() -> Self {
23        Self {
24            max_size: MAX_BUFFER_SIZE,
25        }
26    }
27}
28
29fn bincode_config() -> impl config::Config {
30    config::standard().with_fixed_int_encoding()
31}
32
33impl Codec<ByteBuffer> for LengthCodec {
34    type Message = Vec<u8>;
35    type Error = Error;
36
37    fn encode(&self, src: &Vec<u8>, dst: &mut ByteBuffer) -> Result<usize> {
38        if src.len() > self.max_size {
39            return Err(Error::BufferFull(format!(
40                "Buffer size {} exceeds maximum {}",
41                src.len(),
42                self.max_size
43            )));
44        }
45
46        let length_buf = &mut [0u8; MSG_LENGTH_SIZE];
47        bincode::encode_into_slice(src.len() as u32, length_buf, bincode_config())?;
48        dst.extend_from_slice(length_buf);
49        dst.extend_from_slice(src);
50
51        Ok(dst.len())
52    }
53
54    fn decode(&self, src: &mut ByteBuffer) -> Result<Option<(usize, Vec<u8>)>> {
55        if src.len() < MSG_LENGTH_SIZE {
56            return Ok(None);
57        }
58
59        let mut length = [0u8; MSG_LENGTH_SIZE];
60        length.copy_from_slice(&src.as_ref()[..MSG_LENGTH_SIZE]);
61        let (length, _) = bincode::decode_from_slice::<u32, _>(&length, bincode_config())?;
62        let length = length as usize;
63
64        if length > self.max_size {
65            return Err(Error::BufferFull(format!(
66                "Frame length {} exceeds maximum {}",
67                length, self.max_size
68            )));
69        }
70
71        if src.len() - MSG_LENGTH_SIZE >= length {
72            Ok(Some((
73                length + MSG_LENGTH_SIZE,
74                src.as_ref()[MSG_LENGTH_SIZE..length + MSG_LENGTH_SIZE].to_vec(),
75            )))
76        } else {
77            Ok(None)
78        }
79    }
80}