mm_client_common/
lib.rs

1// Copyright 2024 Colin Marc <hi@colinmarc.com>
2//
3// SPDX-License-Identifier: MIT
4
5use std::{
6    collections::{HashMap, HashSet},
7    sync::Arc,
8    time,
9};
10
11use async_mutex::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
12use futures::{channel::oneshot, executor::block_on};
13use mm_protocol as protocol;
14use tracing::{debug, error};
15
16mod attachment;
17mod conn;
18mod logging;
19mod packet;
20mod session;
21mod stats;
22mod validation;
23
24pub mod codec;
25pub mod display_params;
26pub mod input;
27pub mod pixel_scale;
28
29pub use attachment::*;
30pub use logging::*;
31pub use packet::*;
32pub use session::*;
33
34uniffi::setup_scaffolding!();
35
36pub use protocol::error::ErrorCode;
37
/// Errors returned by [`Client`] operations.
#[derive(Debug, Clone, thiserror::Error, uniffi::Error)]
#[uniffi(flat_error)]
pub enum ClientError {
    /// An error from the protocol layer (`mm_protocol`).
    #[error("protocol error")]
    ProtocolError(#[from] protocol::ProtocolError),
    /// The server answered a request with an `Error` message.
    #[error("{}: {}", .0.err_code().as_str_name(), .0.error_text)]
    ServerError(protocol::Error),
    /// No response arrived before the request's deadline.
    #[error("request timed out")]
    RequestTimeout,
    /// An error from the connection thread (including `ConnError::Idle`,
    /// after which the client reconnects on the next request).
    #[error("connection error")]
    ConnectionError(#[from] conn::ConnError),
    /// The response sender was dropped before a response was delivered,
    /// e.g. because the stream closed or the reactor shut down.
    #[error("stream closed before request could be received")]
    Canceled(#[from] oneshot::Canceled),
    /// The server replied with a message type the caller did not expect.
    #[error("received unexpected message: {0}")]
    UnexpectedMessage(protocol::MessageType),
    /// A message from the server failed validation.
    #[error("message validation failed")]
    ValidationFailed(#[from] validation::ValidationError),
    /// The client has no usable connection and cannot reconnect.
    #[error("client defunct")]
    Defunct,
    /// The attachment has ended (presumably raised by the attachment module;
    /// not constructed in this file).
    #[error("attachment ended")]
    Detached,
}
60
61/// A handle for the QUIC connection thread, used to push outgoing messages.
62struct ConnHandle {
63    thread_handle: std::thread::JoinHandle<Result<(), conn::ConnError>>,
64    waker: Arc<mio::Waker>,
65    outgoing: flume::Sender<conn::OutgoingMessage>,
66    roundtrips: flume::Sender<(u64, Roundtrip)>,
67    attachments: flume::Sender<(u64, AttachmentState)>,
68    shutdown: oneshot::Sender<()>,
69}
70
71impl ConnHandle {
72    /// Signals the connection thread that it should close.
73    fn close(self) -> Result<(), Option<conn::ConnError>> {
74        let _ = self.shutdown.send(());
75        self.waker.wake().map_err(conn::ConnError::from)?;
76
77        if !self.thread_handle.is_finished() {
78            return Ok(());
79        }
80
81        match self.thread_handle.join() {
82            Ok(Ok(_)) => Ok(()),
83            Ok(Err(e)) => Err(Some(e)),
84            // The connection thread panicked.
85            Err(_) => {
86                error!("connection thread panicked");
87                Err(None)
88            }
89        }
90    }
91}
92
93/// Stores the current connection state.
94enum ClientState {
95    Connected(ConnHandle),
96    Defunct(ClientError),
97}
98
99struct Roundtrip {
100    tx: oneshot::Sender<Result<protocol::MessageType, ClientError>>,
101    deadline: Option<time::Instant>,
102}
103
104/// Client state inside the mutex.
105struct InnerClient {
106    next_stream_id: u64,
107    state: ClientState,
108}
109
110impl InnerClient {
111    fn next_stream_id(&mut self) -> u64 {
112        let sid = self.next_stream_id;
113        self.next_stream_id += 4;
114
115        sid
116    }
117
118    fn close(&mut self) -> Result<(), ClientError> {
119        if let ClientState::Defunct(err) = &self.state {
120            return Err(err.clone());
121        }
122
123        let ClientState::Connected(conn) =
124            std::mem::replace(&mut self.state, ClientState::Defunct(ClientError::Defunct))
125        else {
126            unreachable!();
127        };
128
129        //Shut down the connection thread.
130        let close_err = conn.close();
131        if let Err(Some(e)) = &close_err {
132            error!("connection error: {e:?}");
133            self.state = ClientState::Defunct(e.clone().into());
134        }
135
136        match close_err {
137            Ok(_) => Ok(()),
138            Err(Some(e)) => Err(e.into()),
139            Err(None) => Err(ClientError::Defunct),
140        }
141    }
142}
143
144#[derive(uniffi::Object)]
145pub struct Client {
146    name: String,
147    addr: String,
148    connect_timeout: time::Duration,
149    inner: Arc<AsyncMutex<InnerClient>>,
150    stats: Arc<stats::StatsCollector>,
151}
152
153impl Client {
154    async fn reconnect(&self) -> Result<AsyncMutexGuard<InnerClient>, ClientError> {
155        let inner_clone = self.inner.clone();
156        let mut guard = self.inner.lock().await;
157
158        match &guard.state {
159            ClientState::Connected(_) => (),
160            ClientState::Defunct(ClientError::ConnectionError(conn::ConnError::Idle)) => {
161                // Reconnect after an idle timeout.
162                let conn = match spawn_conn(
163                    &self.addr,
164                    inner_clone,
165                    self.stats.clone(),
166                    self.connect_timeout,
167                )
168                .await
169                {
170                    Ok(conn) => conn,
171                    Err(e) => {
172                        error!("connection failed: {e:#}");
173                        return Err(e);
174                    }
175                };
176
177                guard.state = ClientState::Connected(conn);
178
179                debug!("reconnected after idle timeout");
180            }
181            ClientState::Defunct(e) => {
182                return Err(e.clone());
183            }
184        }
185
186        Ok(guard)
187    }
188
189    async fn initiate_stream(
190        &self,
191        msg: impl Into<protocol::MessageType>,
192        fin: bool,
193        timeout: Option<time::Duration>,
194    ) -> Result<(u64, protocol::MessageType), ClientError> {
195        let mut guard = self.reconnect().await?;
196
197        let sid = guard.next_stream_id();
198        let (oneshot_tx, oneshot_rx) = oneshot::channel();
199
200        let ConnHandle {
201            waker,
202            outgoing,
203            roundtrips,
204            ..
205        } = match &guard.state {
206            ClientState::Connected(conn) => conn,
207            ClientState::Defunct(err) => return Err(err.clone()),
208        };
209
210        if outgoing
211            .send(conn::OutgoingMessage {
212                sid,
213                msg: msg.into(),
214                fin,
215            })
216            .is_err()
217        {
218            match guard.close() {
219                Ok(_) => return Err(ClientError::Defunct),
220                Err(e) => return Err(e),
221            }
222        }
223
224        let deadline = timeout.map(|d| time::Instant::now() + d);
225        if roundtrips
226            .send_async((
227                sid,
228                Roundtrip {
229                    tx: oneshot_tx,
230                    deadline,
231                },
232            ))
233            .await
234            .is_err()
235        {
236            match guard.close() {
237                Ok(_) => return Err(ClientError::Defunct),
238                Err(e) => return Err(e),
239            }
240        };
241
242        waker.wake().map_err(conn::ConnError::from)?;
243
244        // We don't want to hold the mutex while waiting for a response.
245        drop(guard);
246
247        let res = oneshot_rx.await??;
248        Ok((sid, res))
249    }
250
251    async fn roundtrip(
252        &self,
253        msg: impl Into<protocol::MessageType>,
254        timeout: time::Duration,
255    ) -> Result<protocol::MessageType, ClientError> {
256        let (_, msg) = self.initiate_stream(msg, false, Some(timeout)).await?;
257        Ok(msg)
258    }
259}
260
261#[uniffi::export]
262impl Client {
263    #[uniffi::constructor]
264    pub async fn new(
265        addr: &str,
266        client_name: &str,
267        connect_timeout: time::Duration,
268    ) -> Result<Self, ClientError> {
269        let inner = Arc::new(AsyncMutex::new(InnerClient {
270            next_stream_id: 0,
271            state: ClientState::Defunct(ClientError::Defunct),
272        }));
273
274        let stats = Arc::new(stats::StatsCollector::default());
275        let conn = spawn_conn(addr, inner.clone(), stats.clone(), connect_timeout).await?;
276        inner.lock().await.state = ClientState::Connected(conn);
277
278        Ok(Self {
279            name: client_name.to_owned(),
280            addr: addr.to_owned(),
281            connect_timeout,
282            inner,
283            stats,
284        })
285    }
286
287    pub fn stats(&self) -> stats::ClientStats {
288        self.stats.snapshot()
289    }
290
291    pub async fn list_applications(
292        &self,
293        timeout: time::Duration,
294    ) -> Result<Vec<Application>, ClientError> {
295        let res = match self
296            .roundtrip(protocol::ListApplications {}, timeout)
297            .await?
298        {
299            protocol::MessageType::ApplicationList(res) => res,
300            protocol::MessageType::Error(e) => return Err(ClientError::ServerError(e)),
301            msg => return Err(ClientError::UnexpectedMessage(msg)),
302        };
303
304        Ok(res
305            .list
306            .into_iter()
307            .map(Application::try_from)
308            .collect::<Result<Vec<_>, validation::ValidationError>>()?)
309    }
310
311    pub async fn fetch_application_image(
312        &self,
313        application_id: String,
314        format: session::ApplicationImageFormat,
315        timeout: time::Duration,
316    ) -> Result<Vec<u8>, ClientError> {
317        let fetch = protocol::FetchApplicationImage {
318            format: format.into(),
319            application_id,
320        };
321
322        match self.roundtrip(fetch, timeout).await? {
323            protocol::MessageType::ApplicationImage(res) => Ok(res.image_data.into()),
324            protocol::MessageType::Error(e) => Err(ClientError::ServerError(e)),
325            msg => Err(ClientError::UnexpectedMessage(msg)),
326        }
327    }
328
329    pub async fn list_sessions(
330        &self,
331        timeout: time::Duration,
332    ) -> Result<Vec<Session>, ClientError> {
333        let res = match self.roundtrip(protocol::ListSessions {}, timeout).await? {
334            protocol::MessageType::SessionList(res) => res,
335            protocol::MessageType::Error(e) => return Err(ClientError::ServerError(e)),
336            msg => return Err(ClientError::UnexpectedMessage(msg)),
337        };
338
339        Ok(res
340            .list
341            .into_iter()
342            .map(Session::try_from)
343            .collect::<Result<Vec<_>, validation::ValidationError>>()?)
344    }
345
346    pub async fn launch_session(
347        &self,
348        application_id: String,
349        display_params: display_params::DisplayParams,
350        permanent_gamepads: Vec<input::Gamepad>,
351        timeout: time::Duration,
352    ) -> Result<Session, ClientError> {
353        let msg = protocol::LaunchSession {
354            application_id: application_id.clone(),
355            display_params: Some(display_params.clone().into()),
356            permanent_gamepads: permanent_gamepads.iter().map(|pad| (*pad).into()).collect(),
357        };
358
359        let res = match self.roundtrip(msg, timeout).await? {
360            protocol::MessageType::SessionLaunched(msg) => msg,
361            protocol::MessageType::Error(e) => return Err(ClientError::ServerError(e)),
362            msg => return Err(ClientError::UnexpectedMessage(msg)),
363        };
364
365        Ok(Session {
366            id: res.id,
367            start: time::SystemTime::now(),
368            application_id,
369            display_params,
370        })
371    }
372
373    pub async fn end_session(&self, id: u64, timeout: time::Duration) -> Result<(), ClientError> {
374        let msg = protocol::EndSession { session_id: id };
375        match self.roundtrip(msg, timeout).await? {
376            protocol::MessageType::SessionEnded(_) => Ok(()),
377            protocol::MessageType::Error(e) => Err(ClientError::ServerError(e)),
378            msg => Err(ClientError::UnexpectedMessage(msg)),
379        }
380    }
381
382    pub async fn update_session_display_params(
383        &self,
384        id: u64,
385        params: display_params::DisplayParams,
386        timeout: time::Duration,
387    ) -> Result<(), ClientError> {
388        let msg = protocol::UpdateSession {
389            session_id: id,
390            display_params: Some(params.into()),
391        };
392
393        match self.roundtrip(msg, timeout).await? {
394            protocol::MessageType::SessionUpdated(_) => Ok(()),
395            protocol::MessageType::Error(e) => Err(ClientError::ServerError(e)),
396            msg => Err(ClientError::UnexpectedMessage(msg)),
397        }
398    }
399
400    /// Attach to a session. The timeout parameter is used for the duration of
401    /// the initial request, i.e. until an Attached message is returned by the
402    /// server.
403    pub async fn attach_session(
404        &self,
405        session_id: u64,
406        config: AttachmentConfig,
407        delegate: Arc<dyn AttachmentDelegate>,
408        timeout: time::Duration,
409    ) -> Result<Attachment, ClientError> {
410        // Send an attach message using the roundtrip mechanism, but the leave
411        // the stream open.
412        let channel_conf = if config.channels.is_empty() {
413            None
414        } else {
415            Some(protocol::AudioChannels {
416                channels: config.channels.iter().copied().map(Into::into).collect(),
417            })
418        };
419
420        let attach = protocol::Attach {
421            session_id,
422            client_name: self.name.clone(),
423            attachment_type: protocol::AttachmentType::Operator.into(),
424            video_codec: config.video_codec.unwrap_or_default().into(),
425            streaming_resolution: Some(protocol::Size {
426                width: config.width,
427                height: config.height,
428            }),
429            video_profile: config.video_profile.unwrap_or_default().into(),
430            quality_preset: config.quality_preset.unwrap_or_default(),
431
432            audio_codec: config.audio_codec.unwrap_or_default().into(),
433            sample_rate_hz: config.sample_rate.unwrap_or_default(),
434            channels: channel_conf,
435        };
436
437        let (sid, res) = self.initiate_stream(attach, false, Some(timeout)).await?;
438
439        let attached = match res {
440            protocol::MessageType::Attached(att) => att,
441            protocol::MessageType::Error(e) => return Err(ClientError::ServerError(e)),
442            msg => return Err(ClientError::UnexpectedMessage(msg)),
443        };
444
445        Attachment::new(
446            sid,
447            self.inner.clone(),
448            attached,
449            delegate,
450            config.video_stream_seq_offset,
451        )
452        .await
453    }
454}
455
456async fn spawn_conn(
457    addr: &str,
458    client: Arc<AsyncMutex<InnerClient>>,
459    stats: Arc<stats::StatsCollector>,
460    connect_timeout: time::Duration,
461) -> Result<ConnHandle, ClientError> {
462    let (incoming_tx, incoming_rx) = flume::unbounded();
463    let (outgoing_tx, outgoing_rx) = flume::unbounded();
464    let (ready_tx, ready_rx) = oneshot::channel();
465    let (shutdown_tx, shutdown_rx) = oneshot::channel();
466
467    // Rendezvous channels for synchronized state.
468    let (roundtrips_tx, roundtrips_rx) = flume::bounded(0);
469    let (attachments_tx, attachments_rx) = flume::bounded(0);
470
471    let mut conn = conn::Conn::new(addr, incoming_tx, outgoing_rx, ready_tx, shutdown_rx, stats)?;
472    let waker = conn.waker();
473
474    // Spawn a polling loop for the quic connection.
475    let thread_handle = std::thread::Builder::new()
476        .name("QUIC conn".to_string())
477        .spawn(move || conn.run(connect_timeout))
478        .unwrap();
479
480    // Spawn a second thread to fulfill request/response futures and drive
481    // the attachment delegates.
482
483    let _ = std::thread::Builder::new()
484        .name("mmclient reactor".to_string())
485        .spawn(move || conn_reactor(incoming_rx, roundtrips_rx, attachments_rx, client))
486        .unwrap();
487
488    if ready_rx.await.is_err() {
489        // An error occured while spinning up.
490        match thread_handle.join() {
491            Ok(Ok(_)) | Err(_) => return Err(ClientError::Defunct),
492            Ok(Err(e)) => return Err(e.into()),
493        }
494    }
495
496    Ok(ConnHandle {
497        thread_handle,
498        waker,
499        outgoing: outgoing_tx,
500        shutdown: shutdown_tx,
501        roundtrips: roundtrips_tx,
502        attachments: attachments_tx,
503    })
504}
505
506#[derive(Default)]
507struct InFlight {
508    roundtrips: HashMap<u64, Roundtrip>,
509    attachments: HashMap<u64, AttachmentState>,
510    prev_attachments: HashSet<u64>, // By attachment ID.
511}
512
513fn conn_reactor(
514    incoming: flume::Receiver<conn::ConnEvent>,
515    roundtrips: flume::Receiver<(u64, Roundtrip)>,
516    attachments: flume::Receiver<(u64, AttachmentState)>,
517    client: Arc<AsyncMutex<InnerClient>>,
518) {
519    let mut in_flight = InFlight::default();
520    let mut tick = time::Instant::now() + time::Duration::from_secs(1);
521
522    loop {
523        // Perform some cleanup once per second.
524        let now = time::Instant::now();
525        if now > tick {
526            tick = now + time::Duration::from_secs(1);
527
528            // Check roundtrip deadlines.
529            let mut timed_out = Vec::new();
530            for (sid, Roundtrip { deadline, .. }) in in_flight.roundtrips.iter() {
531                if deadline.is_some_and(|dl| now >= dl) {
532                    timed_out.push(*sid);
533                }
534            }
535
536            // Fulfill the futures with an error.
537            for id in &timed_out {
538                let Roundtrip { tx, .. } = in_flight.roundtrips.remove(id).unwrap();
539                let _ = tx.send(Err(ClientError::RequestTimeout));
540            }
541        }
542
543        enum SelectResult {
544            RecvError,
545            InsertRoundtrip(u64, Roundtrip),
546            InsertAttachment(u64, AttachmentState),
547            Incoming(conn::ConnEvent),
548        }
549
550        let res = flume::select::Selector::new()
551            .recv(&roundtrips, |ev| {
552                if let Ok((sid, rt)) = ev {
553                    SelectResult::InsertRoundtrip(sid, rt)
554                } else {
555                    SelectResult::RecvError
556                }
557            })
558            .recv(&attachments, |ev| {
559                if let Ok((sid, att)) = ev {
560                    SelectResult::InsertAttachment(sid, att)
561                } else {
562                    SelectResult::RecvError
563                }
564            })
565            .recv(&incoming, |ev| {
566                if let Ok(ev) = ev {
567                    SelectResult::Incoming(ev)
568                } else {
569                    SelectResult::RecvError
570                }
571            })
572            .wait_deadline(tick);
573
574        match res {
575            Err(flume::select::SelectError::Timeout) => continue,
576            Ok(SelectResult::RecvError) => break,
577            Ok(SelectResult::InsertRoundtrip(sid, rt)) => {
578                in_flight.roundtrips.insert(sid, rt);
579            }
580            Ok(SelectResult::InsertAttachment(sid, att)) => {
581                in_flight.attachments.insert(sid, att);
582            }
583            Ok(SelectResult::Incoming(ev)) => conn_reactor_handle_incoming(&mut in_flight, ev),
584        };
585    }
586
587    // The client is probably already closed, but we should make sure, since
588    // this thread is the only one notified if the connection thread died.
589    let mut guard = block_on(client.lock());
590    let stream_err = match guard.close() {
591        Err(e) => Some(e.clone()),
592        Ok(_) => None,
593    };
594
595    for (_, att) in in_flight.attachments.drain() {
596        att.handle_close(stream_err.clone());
597    }
598
599    in_flight.roundtrips.clear(); // Cancels the futures.
600}
601
602fn conn_reactor_handle_incoming(in_flight: &mut InFlight, ev: conn::ConnEvent) {
603    match ev {
604        conn::ConnEvent::StreamMessage(sid, msg) => {
605            if let Some(attachment) = in_flight.attachments.get_mut(&sid) {
606                attachment.handle_message(msg);
607                return;
608            }
609
610            if let Some(Roundtrip { tx, .. }) = in_flight.roundtrips.remove(&sid) {
611                let _ = tx.send(Ok(msg));
612            }
613        }
614        conn::ConnEvent::Datagram(msg) => {
615            let (session_id, attachment_id) = match &msg {
616                protocol::MessageType::VideoChunk(chunk) => (chunk.session_id, chunk.attachment_id),
617                protocol::MessageType::AudioChunk(chunk) => (chunk.session_id, chunk.attachment_id),
618                msg => {
619                    error!("unexpected {} as datagram", msg);
620                    return;
621                }
622            };
623
624            // Find the relevant attachment. The session ID and attachment
625            // may be omitted if there's only one attachment.
626            let attachment = match (session_id, attachment_id) {
627                (0, 0) if in_flight.attachments.len() == 1 => {
628                    in_flight.attachments.iter_mut().next()
629                }
630                (0, _) | (_, 0) => None, // This is invalid.
631                (s, a) => in_flight
632                    .attachments
633                    .iter_mut()
634                    .find(|(_, att)| att.session_id == s && att.attachment_id == a),
635            };
636
637            if let Some((_, attachment)) = attachment {
638                attachment.handle_message(msg);
639            } else if !in_flight.prev_attachments.contains(&attachment_id) {
640                error!(
641                    session_id,
642                    attachment_id, "failed to match datagram to attachment"
643                );
644            }
645        }
646        conn::ConnEvent::StreamClosed(sid) => {
647            in_flight.roundtrips.remove(&sid);
648            if let Some(attachment) = in_flight.attachments.remove(&sid) {
649                in_flight.prev_attachments.insert(attachment.attachment_id);
650                attachment.handle_close(None);
651            }
652        }
653    }
654}