1use prost::Message as _;
6
7#[cfg(feature = "uniffi")]
8uniffi::setup_scaffolding!();
9
10include!(concat!(env!("OUT_DIR"), "/_include.rs"));
11pub use messages::*;
12
13mod timestamp;
14
15#[derive(Debug, thiserror::Error)]
16enum ProtobufError {
17 #[error(transparent)]
18 ProtobufDecode(#[from] prost::DecodeError),
19 #[error(transparent)]
20 ProtobufEncode(#[from] prost::EncodeError),
21}
22
23#[derive(Debug, Clone, thiserror::Error)]
24pub enum ProtocolError {
25 #[error("protobuf encode error: {0}")]
26 ProtobufEncode(#[from] prost::EncodeError),
27 #[error("protobuf decode error: {0}")]
28 ProtobufDecode(#[from] prost::DecodeError),
29 #[error("short buffer, need {0} bytes")]
30 ShortBuffer(usize),
31 #[error("invalid message")]
32 InvalidMessage,
33 #[error("invalid message type: {0} (len={1})")]
34 InvalidMessageType(u32, usize),
35}
36
/// Largest frame, in bytes, that [`encode_message`] will produce or [`decode_message`] will
/// accept. The limit covers the whole frame: length prefix, message type, and body.
pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024;

/// ALPN protocol identifier for this version of the wire protocol.
pub const ALPN_PROTOCOL_VERSION: &[u8] = b"mm00";
43
/// Generates [`MessageType`] and its dispatch plumbing from a `number => Message` table.
///
/// For each `num => Variant` entry this emits:
/// - a `MessageType::Variant(Variant)` enum variant with discriminant `num`;
/// - an `impl From<Variant> for MessageType`;
/// - `message_type` / `encoded_len` / `encode` / `decode` dispatch arms;
/// - a `Display` arm rendering `"num:Variant"`.
///
/// The entry list must end with a trailing comma.
macro_rules! message_types {
    ($($num:expr => $variant:ident),*,) => {
        /// A protocol message, tagged with the message type number used in the frame header.
        #[repr(u32)]
        #[derive(Clone, Debug, PartialEq)]
        pub enum MessageType {
            $($variant($variant) = $num),*
        }

        $(
            impl From<$variant> for MessageType {
                fn from(inner: $variant) -> Self {
                    MessageType::$variant(inner)
                }
            }
        )*

        impl MessageType {
            /// Returns the message type number written into the frame header.
            fn message_type(&self) -> u32 {
                match self {
                    $(MessageType::$variant(_) => $num),*
                }
            }

            /// Returns the protobuf-encoded length of the message body (header excluded).
            fn encoded_len(&self) -> usize {
                match self {
                    $(MessageType::$variant(inner) => inner.encoded_len()),*
                }
            }

            /// Protobuf-encodes the message body into `buf`.
            fn encode<B>(&self, buf: &mut B) -> Result<(), ProtocolError>
            where
                B: bytes::BufMut,
            {
                match self {
                    $(MessageType::$variant(inner) => inner.encode(buf).map_err(ProtocolError::from)),*
                }
            }

            /// Decodes a body whose header named `msg_type` from `buf`.
            ///
            /// `total_len` is only reported back through
            /// [`ProtocolError::InvalidMessageType`] when `msg_type` is not in the table.
            fn decode<B: bytes::Buf>(
                msg_type: u32,
                total_len: usize,
                buf: B,
            ) -> Result<Self, ProtocolError> {
                let decoded = match msg_type {
                    $($num => MessageType::$variant($variant::decode(buf)?)),*,
                    _ => return Err(ProtocolError::InvalidMessageType(msg_type, total_len)),
                };
                Ok(decoded)
            }
        }

        impl std::fmt::Display for MessageType {
            /// Formats as `"<type number>:<variant name>"`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let name = match self {
                    $(MessageType::$variant(_) => stringify!($variant)),*
                };
                write!(f, "{}:{}", self.message_type(), name)
            }
        }
    };
}
101
// Wire message table. Each entry maps the numeric message type written into (and read from)
// the frame header to the protobuf message that carries the body. Changing an existing number
// changes the on-wire encoding of that message. Type 0 is rejected by `decode_message`, so it
// must not be assigned.
message_types! {
    1 => Error,
    // Application and session management (grouping inferred from the names).
    11 => ListApplications,
    12 => ApplicationList,
    13 => LaunchSession,
    14 => SessionLaunched,
    15 => UpdateSession,
    16 => SessionUpdated,
    17 => ListSessions,
    18 => SessionList,
    19 => EndSession,
    20 => SessionEnded,
    21 => FetchApplicationImage,
    22 => ApplicationImage,
    // Attachment lifecycle.
    30 => Attach,
    31 => Attached,
    32 => KeepAlive,
    33 => SessionParametersChanged,
    35 => Detach,
    // Media streams.
    51 => VideoChunk,
    52 => RequestVideoRefresh,
    56 => AudioChunk,
    // Keyboard and pointer input / cursor state.
    60 => KeyboardInput,
    61 => PointerEntered,
    62 => PointerLeft,
    63 => PointerMotion,
    64 => PointerInput,
    65 => PointerScroll,
    66 => UpdateCursor,
    67 => LockPointer,
    68 => ReleasePointer,
    69 => RelativePointerMotion,
    // Gamepads.
    70 => GamepadAvailable,
    71 => GamepadUnavailable,
    72 => GamepadMotion,
    73 => GamepadInput,
}
139
140pub fn decode_message(buf: &[u8]) -> Result<(MessageType, usize), ProtocolError> {
144 if buf.len() < 10 {
145 return Err(ProtocolError::ShortBuffer(10));
146 }
147
148 let (msg_type, data_off, total_len) = {
149 let mut hdr = octets::Octets::with_slice(&buf[..10]);
150
151 let remaining = get_varint32(&mut hdr)? as usize;
152 let prefix_off = hdr.off();
153
154 let msg_type = get_varint32(&mut hdr)?;
155 let off = hdr.off();
156
157 (msg_type, off, prefix_off + remaining)
158 };
159
160 if msg_type == 0 || total_len == 0 || total_len > MAX_MESSAGE_SIZE || data_off > total_len {
161 return Err(ProtocolError::InvalidMessage);
162 } else if data_off > buf.len() || total_len > buf.len() {
163 return Err(ProtocolError::ShortBuffer(total_len));
164 }
165
166 let padded_len = total_len.max(10);
167 let msg = MessageType::decode(msg_type, padded_len, &buf[data_off..total_len])?;
168 Ok((msg, padded_len))
169}
170
171pub fn encode_message(msg: &MessageType, buf: &mut [u8]) -> Result<usize, ProtocolError> {
175 let msg_type = msg.message_type();
176 let msg_len =
177 u32::try_from(msg.encoded_len()).map_err(|_| ProtocolError::InvalidMessage)? as usize;
178
179 let header_len = encode_header(msg_type, msg_len, buf)?;
180 let total_len = header_len + msg_len;
181
182 let mut msg_buf = &mut buf[header_len..];
183 msg.encode(&mut msg_buf)?;
184
185 if total_len < 10 {
186 buf[total_len..].fill(0);
187 Ok(10)
188 } else {
189 Ok(total_len)
190 }
191}
192
193fn encode_header(msg_type: u32, msg_len: usize, buf: &mut [u8]) -> Result<usize, ProtocolError> {
194 let msg_type_len = octets::varint_len(msg_type as u64);
195 let prefix_len = octets::varint_len((msg_type_len + msg_len) as u64);
196 let total_len = prefix_len + msg_type_len + msg_len;
197
198 if total_len > MAX_MESSAGE_SIZE {
199 return Err(ProtocolError::InvalidMessage);
200 } else if total_len > buf.len() || buf.len() < 10 {
201 return Err(ProtocolError::ShortBuffer(std::cmp::max(total_len, 10)));
202 }
203
204 let off = {
205 let mut hdr = octets::OctetsMut::with_slice(buf);
206 hdr.put_varint((msg_type_len + msg_len) as u64).unwrap();
207 hdr.put_varint(msg_type as u64).unwrap();
208 hdr.off()
209 };
210
211 Ok(off)
212}
213
214fn get_varint32(buf: &mut octets::Octets) -> Result<u32, ProtocolError> {
216 let x = match buf.get_varint() {
217 Ok(x) => x,
218 Err(_) => return Err(ProtocolError::InvalidMessage),
219 };
220
221 u32::try_from(x).map_err(|_| ProtocolError::InvalidMessage)
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates a zeroed scratch buffer of `len` bytes on the heap.
    ///
    /// A `[0; MAX_MESSAGE_SIZE]` local would put 1 MiB on the test thread's stack. That is half
    /// of std's current default spawned-thread stack and overflows under a smaller
    /// `RUST_MIN_STACK`.
    fn scratch(len: usize) -> Vec<u8> {
        vec![0; len]
    }

    /// Encodes a message into a max-size buffer, decodes it back, and checks both the message
    /// and the reported frame length round-trip.
    macro_rules! test_roundtrip {
        ($name:ident : $value:expr) => {
            #[test]
            fn $name() {
                let msg: MessageType = $value.into();
                let mut buf = scratch(MAX_MESSAGE_SIZE);
                let len = encode_message(&msg, &mut buf).unwrap();
                let (decoded_msg, decoded_len) = decode_message(&buf).unwrap();
                assert_eq!(msg, decoded_msg);
                assert_eq!(len, decoded_len);
            }
        };
    }

    test_roundtrip!(test_roundtrip_detach: Detach {});

    test_roundtrip!(test_roundtrip_error: Error {
        err_code: 1,
        error_text: "test".to_string(),
    });

    test_roundtrip!(test_roundtrip_smallframe: VideoChunk {
        attachment_id: 0,
        session_id: 1,
        stream_seq: 1,
        seq: 2,
        chunk: 3,
        num_chunks: 4,
        data: bytes::Bytes::from(vec![9; 52]),
        timestamp: 1234,
        ..Default::default()
    });

    test_roundtrip!(test_roundtrip_frame: VideoChunk {
        attachment_id: 0,
        session_id: 1,
        stream_seq: 1,
        seq: 2,
        chunk: 3,
        num_chunks: 4,
        data: bytes::Bytes::from(vec![9; 1200]),
        timestamp: 1234,
        hierarchical_layer: 0,
        ..Default::default()
    });

    #[test]
    fn invalid_message_type() {
        let msg_type = 999;

        let msg_buf = [100_u8; 322];
        let msg_len = msg_buf.len();

        // Hand-build a frame whose header names an unknown message type.
        let mut buf = scratch(MAX_MESSAGE_SIZE);
        let header_len =
            encode_header(msg_type, msg_len, &mut buf).expect("failed to encode fake message");
        let total_len = header_len + msg_len;
        buf[header_len..total_len].copy_from_slice(&msg_buf);

        match decode_message(&buf) {
            Err(ProtocolError::InvalidMessageType(t, len)) => {
                assert_eq!(t, 999);
                assert_eq!(len, total_len);
            }
            v => panic!("expected InvalidMessageType, got {:?}", v),
        }
    }

    #[test]
    fn small_message_is_zero_padded_to_ten_bytes() {
        let msg: MessageType = Detach {}.into();
        let mut buf = vec![0xff_u8; 64];

        let len = encode_message(&msg, &mut buf).unwrap();
        assert_eq!(len, 10);
        // Length prefix 1 (just the type byte), then type 35, then zero padding.
        assert_eq!(&buf[..2], &[1_u8, 35]);
        assert!(buf[2..10].iter().all(|&b| b == 0));

        let (decoded, decoded_len) = decode_message(&buf[..len]).unwrap();
        assert_eq!(decoded, msg);
        assert_eq!(decoded_len, 10);
    }

    #[test]
    fn encode_into_short_buffer() {
        let msg: MessageType = Detach {}.into();
        let mut buf = [0_u8; 9];
        assert!(matches!(
            encode_message(&msg, &mut buf),
            Err(ProtocolError::ShortBuffer(10))
        ));
    }

    #[test]
    fn decode_from_short_buffer() {
        assert!(matches!(
            decode_message(&[0_u8; 9]),
            Err(ProtocolError::ShortBuffer(10))
        ));
    }

    #[test]
    fn decode_truncated_message() {
        let msg: MessageType = Error {
            err_code: 1,
            error_text: "a longer error text".to_string(),
        }
        .into();
        let mut buf = scratch(128);
        let len = encode_message(&msg, &mut buf).unwrap();
        assert!(len > 10);

        match decode_message(&buf[..len - 1]) {
            Err(ProtocolError::ShortBuffer(need)) => assert_eq!(need, len),
            v => panic!("expected ShortBuffer, got {:?}", v),
        }
    }

    #[test]
    fn zero_message_type_is_invalid() {
        let mut buf = [0_u8; 10];
        buf[0] = 1; // length prefix: one byte follows (the type byte, which is 0)
        assert!(matches!(
            decode_message(&buf),
            Err(ProtocolError::InvalidMessage)
        ));
    }

    #[test]
    fn oversized_length_is_invalid() {
        // 4-byte varint (top bits 0b10) announcing a 2 MiB frame, followed by message type 1.
        let mut buf = [0_u8; 10];
        buf[..4].copy_from_slice(&[0x80, 0x20, 0x00, 0x00]);
        buf[4] = 1;
        assert!(matches!(
            decode_message(&buf),
            Err(ProtocolError::InvalidMessage)
        ));
    }
}