aboutsummaryrefslogtreecommitdiffstats
path: root/cms-socket/src/msg.rs
blob: b48699218ac4bbccbc668014110af91fb20215e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// -*- coding: utf-8 -*-
//
// Simple CMS
//
// Copyright (C) 2011-2024 Michael Büsch <m@bues.ch>
//
// Licensed under the Apache License version 2.0
// or the MIT license, at your option.
// SPDX-License-Identifier: Apache-2.0 OR MIT

use anyhow as ah;
use bincode::Options as _;
use serde::{Deserialize, Serialize};

/// Size in bytes of a serialized [`MsgHdr`]: two fixed-width `u32` fields.
pub const MSG_HDR_LEN: usize = 8;
/// Size limit in bytes handed to bincode by [`bincode_config`] (64 MiB).
pub const MAX_RX_BUF: usize = 64 * 1024 * 1024;

/// Outcome of trying to decode one message from a possibly incomplete buffer.
#[derive(Debug, Clone)]
pub enum DeserializeResult<T> {
    /// A complete message was decoded.
    Ok(T),
    /// The buffer is still too short. As produced by `impl_msg_serde!`, the
    /// value is the number of additional bytes required before decoding can
    /// proceed.
    Pending(usize),
}

/// Framed binary encoding of a message type `T`.
///
/// The `impl_msg_serde!` macro in this module generates implementations that
/// emit a [`MsgHdr`] followed by the bincode-encoded message.
pub trait MsgSerde<T> {
    /// Encodes `self` into a self-contained byte frame.
    fn msg_serialize(&self) -> ah::Result<Vec<u8>>;

    /// Tries to decode one message from the beginning of `buf`.
    ///
    /// Returns [`DeserializeResult::Pending`] when `buf` does not yet contain
    /// a complete frame.
    fn try_msg_deserialize(buf: &[u8]) -> ah::Result<DeserializeResult<T>>;
}

/// Returns the bincode options shared by all message (de)serialization.
///
/// Settings:
/// - a size limit of [`MAX_RX_BUF`] bytes;
/// - native byte order, so the encoding is not portable across hosts of
///   differing endianness;
/// - fixed-width integer encoding;
/// - rejection of trailing bytes on decode.
#[inline]
pub fn bincode_config() -> impl bincode::Options {
    let limit = u64::try_from(MAX_RX_BUF).unwrap();
    bincode::DefaultOptions::new()
        .with_limit(limit)
        .with_native_endian()
        .with_fixint_encoding()
        .reject_trailing_bytes()
}

/// Generic fixed-size header that precedes every serialized message payload.
///
/// The field order is part of the wire format (bincode encodes fields in
/// declaration order), so do not reorder the fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MsgHdr {
    /// Identifies the message type; compared against the expected value on decode.
    magic: u32,
    /// Length in bytes of the payload that follows this header.
    payload_len: u32,
}

impl MsgHdr {
    /// Creates a header tagged with `magic` for a payload of `payload_len` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `payload_len` does not fit into a `u32`.
    #[inline]
    pub fn new(magic: u32, payload_len: usize) -> Self {
        let payload_len = u32::try_from(payload_len).expect("MsgHdr: Payload length too long");
        Self { magic, payload_len }
    }

    /// Returns the magic code stored in this header.
    #[inline]
    pub fn magic(&self) -> u32 {
        self.magic
    }

    /// Returns the serialized size of a header in bytes, i.e. [`MSG_HDR_LEN`].
    ///
    /// In debug builds this also checks that the constant agrees with the size
    /// [`bincode_config`] actually produces for a header.
    #[inline]
    pub fn len() -> usize {
        debug_assert_eq!(MSG_HDR_LEN, {
            let probe = MsgHdr {
                magic: 0,
                payload_len: 0,
            };
            let encoded = bincode_config().serialized_size(&probe).unwrap();
            usize::try_from(encoded).unwrap()
        });
        MSG_HDR_LEN
    }

    /// Returns the payload length recorded in this header, in bytes.
    #[inline]
    pub fn payload_len(&self) -> usize {
        usize::try_from(self.payload_len).unwrap()
    }
}

/// Implements [`MsgSerde`] for a message type.
///
/// - `$struct` is the message type. It must implement serde's `Serialize` and
///   `Deserialize`.
/// - `$magic` is an integer literal used as the `u32` magic code identifying
///   that type on the wire.
///
/// Wire format: a bincode-encoded [`MsgHdr`] (magic and payload length)
/// followed by the bincode-encoded message.
///
/// The expansion names the `anyhow` and `bincode` crates directly (not via
/// `$crate`), so the invoking crate must depend on both.
#[macro_export]
macro_rules! impl_msg_serde {
    ($struct:ty, $magic:literal) => {
        impl $crate::MsgSerde<$struct> for $struct {
            fn msg_serialize(&self) -> anyhow::Result<Vec<u8>> {
                use bincode::Options as _;
                use $crate::{bincode_config, MsgHdr};

                // Encode the payload first so the header can record its length.
                let mut payload = bincode_config().serialize(self)?;
                let mut ret = bincode_config().serialize(&MsgHdr::new($magic, payload.len()))?;
                ret.append(&mut payload);
                Ok(ret)
            }

            fn try_msg_deserialize(
                buf: &[u8],
            ) -> anyhow::Result<$crate::DeserializeResult<$struct>> {
                use anyhow::Context as _;
                use bincode::Options as _;
                use $crate::{bincode_config, MsgHdr};

                // Wait until the complete fixed-size header is available.
                let hdr_len = MsgHdr::len();
                if buf.len() < hdr_len {
                    return Ok($crate::DeserializeResult::Pending(hdr_len - buf.len()));
                }

                let hdr: MsgHdr = bincode_config()
                    .deserialize(&buf[..hdr_len])
                    .context("Deserialize MsgHdr")?;
                if hdr.magic() != $magic {
                    return Err(anyhow::format_err!("Deserialize: Invalid magic code."));
                }

                // Wait until the header plus the announced payload is available.
                // NOTE(review): payload_len is taken from the received header and
                // may be as large as u32::MAX; whoever buffers `Pending` bytes
                // should enforce its own cap (e.g. MAX_RX_BUF).
                let full_len = hdr_len
                    .checked_add(hdr.payload_len())
                    .context("Msg length overflow")?;
                if buf.len() < full_len {
                    return Ok($crate::DeserializeResult::Pending(full_len - buf.len()));
                }

                // Only buf[hdr_len..full_len] is decoded; later bytes are not examined.
                let msg = bincode_config()
                    .deserialize(&buf[hdr_len..full_len])
                    .context("Deserialize Msg")?;
                Ok($crate::DeserializeResult::Ok(msg))
            }
        }
    };
}

// vim: ts=4 sw=4 expandtab
bues.ch cgit interface