karyon_net/codec/
length_codec.rs

1use karyon_core::util::{decode, encode_into_slice};
2
3use crate::{
4    codec::{ByteBuffer, Codec, Decoder, Encoder},
5    Error, Result,
6};
7
/// Payload size limit used by `LengthCodec::default()`: 4 MiB.
const MAX_BUFFER_SIZE: usize = 4 << 20;

/// Width, in bytes, of the length prefix written before every payload.
const MSG_LENGTH_SIZE: usize = std::mem::size_of::<u32>();
12
/// Codec that frames each message with a fixed-width length prefix.
///
/// `max_size` is the payload-size limit (in bytes) the codec checks messages
/// against. Use [`LengthCodec::new`] to pick a limit, or `Default` for 4 MiB.
#[derive(Debug, Clone)]
pub struct LengthCodec {
    max_size: usize,
}

impl LengthCodec {
    /// Creates a codec whose payload-size limit is `max_size` bytes.
    pub fn new(max_size: usize) -> Self {
        Self { max_size }
    }
}
23
24impl Default for LengthCodec {
25    fn default() -> Self {
26        Self {
27            max_size: MAX_BUFFER_SIZE,
28        }
29    }
30}
31
32impl Codec for LengthCodec {
33    type Message = Vec<u8>;
34    type Error = Error;
35}
36
37impl Encoder for LengthCodec {
38    type EnMessage = Vec<u8>;
39    type EnError = Error;
40    fn encode(&self, src: &Self::EnMessage, dst: &mut ByteBuffer) -> Result<usize> {
41        if src.len() > self.max_size {
42            return Err(Error::BufferFull(format!(
43                "Buffer size {} exceeds maximum {}",
44                src.len(),
45                self.max_size
46            )));
47        }
48
49        let length_buf = &mut [0u8; MSG_LENGTH_SIZE];
50
51        encode_into_slice(&(src.len() as u32), length_buf)?;
52        dst.extend_from_slice(length_buf);
53        dst.extend_from_slice(src);
54
55        Ok(dst.len())
56    }
57}
58
59impl Decoder for LengthCodec {
60    type DeMessage = Vec<u8>;
61    type DeError = Error;
62    fn decode(&self, src: &mut ByteBuffer) -> Result<Option<(usize, Self::DeMessage)>> {
63        if src.len() < MSG_LENGTH_SIZE {
64            return Ok(None);
65        }
66
67        if src.as_ref()[..MSG_LENGTH_SIZE].len() > self.max_size {
68            return Err(Error::BufferFull(format!(
69                "Buffer size {} exceeds maximum {}",
70                src.len(),
71                self.max_size
72            )));
73        }
74
75        let mut length = [0u8; MSG_LENGTH_SIZE];
76        length.copy_from_slice(&src.as_ref()[..MSG_LENGTH_SIZE]);
77        let (length, _) = decode::<u32>(&length)?;
78        let length = length as usize;
79
80        if src.len() - MSG_LENGTH_SIZE >= length {
81            Ok(Some((
82                length + MSG_LENGTH_SIZE,
83                src.as_ref()[MSG_LENGTH_SIZE..length + MSG_LENGTH_SIZE].to_vec(),
84            )))
85        } else {
86            Ok(None)
87        }
88    }
89}