nix_compat/nix_daemon/framing/framed_read.rs

1use std::{
2    io::Result,
3    pin::Pin,
4    task::{ready, Poll},
5};
6
7use pin_project_lite::pin_project;
8use tokio::io::{AsyncRead, ReadBuf};
9
10/// State machine for [`NixFramedReader`].
11///
12/// As the reader progresses it linearly cycles through the states.
13#[derive(Debug)]
14enum NixFramedReaderState {
15    /// The reader always starts in this state.
16    ///
17    /// Before the payload, the client first sends its size.
18    /// The size is a u64 which is 8 bytes long, while it's likely that we will receive
19    /// the whole u64 in one read, it's possible that it will arrive in smaller chunks.
20    /// So in this state we read up to 8 bytes and transition to
21    /// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero,
22    /// otherwise we reset filled to 0, and read the next size value.
23    ReadingSize { buf: [u8; 8], filled: usize },
24    /// This is where we read the actual payload that is sent to us.
25    ///
26    /// Once we've read the expected number of bytes, we go back to the
27    /// [`NixFramedReaderState::ReadingSize`] state.
28    ReadingPayload {
29        /// Represents the remaining number of bytes we expect to read based on the value
30        /// read in the previous state.
31        remaining: u64,
32    },
33}
34
pin_project! {
    /// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
    ///
    /// The framed stream is a sequence of frames, each made of a little-endian u64
    /// size followed by that many payload bytes. Reading from this type yields only
    /// the payload bytes; the size headers are consumed internally. A zero-sized
    /// frame is reported to the caller as end of stream.
    ///
    /// See serialization.md#framed and [`NixFramedReaderState`] for details.
    pub struct NixFramedReader<R> {
        // The underlying stream carrying the framed data.
        #[pin]
        reader: R,
        // Position within the current frame (size header or payload).
        state: NixFramedReaderState,
    }
}
45
46impl<R> NixFramedReader<R> {
47    pub fn new(reader: R) -> Self {
48        Self {
49            reader,
50            state: NixFramedReaderState::ReadingSize {
51                buf: [0; 8],
52                filled: 0,
53            },
54        }
55    }
56}
57
58impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
59    fn poll_read(
60        mut self: Pin<&mut Self>,
61        cx: &mut std::task::Context<'_>,
62        read_buf: &mut ReadBuf<'_>,
63    ) -> Poll<Result<()>> {
64        let mut this = self.as_mut().project();
65        match this.state {
66            NixFramedReaderState::ReadingSize { buf, filled } => {
67                if *filled < buf.len() {
68                    let mut size_buf = ReadBuf::new(buf);
69                    size_buf.advance(*filled);
70
71                    ready!(this.reader.poll_read(cx, &mut size_buf))?;
72                    let bytes_read = size_buf.filled().len() - *filled;
73                    if bytes_read == 0 {
74                        // oef
75                        return Poll::Ready(Ok(()));
76                    }
77                    *filled += bytes_read;
78                    // Schedule ourselves to run again.
79                    return self.poll_read(cx, read_buf);
80                }
81                let size = u64::from_le_bytes(*buf);
82                if size == 0 {
83                    // eof
84                    *filled = 0;
85                    return Poll::Ready(Ok(()));
86                }
87                *this.state = NixFramedReaderState::ReadingPayload { remaining: size };
88                self.poll_read(cx, read_buf)
89            }
90            NixFramedReaderState::ReadingPayload { remaining } => {
91                // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
92                let safe_remaining = if *remaining <= usize::MAX as u64 {
93                    *remaining as usize
94                } else {
95                    usize::MAX
96                };
97                if safe_remaining > 0 {
98                    // The buffer is no larger than the amount of data that we expect.
99                    // Otherwise we will trim the buffer below and come back here.
100                    if read_buf.remaining() <= safe_remaining {
101                        let filled_before = read_buf.filled().len();
102
103                        ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
104                        let bytes_read = read_buf.filled().len() - filled_before;
105
106                        *remaining -= bytes_read as u64;
107                        if *remaining == 0 {
108                            *this.state = NixFramedReaderState::ReadingSize {
109                                buf: [0; 8],
110                                filled: 0,
111                            };
112                        }
113                        return Poll::Ready(Ok(()));
114                    }
115                    // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
116                    // internal bookkeeping simpler.
117                    let mut smaller_buf = read_buf.take(safe_remaining);
118                    ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
119
120                    let bytes_read = smaller_buf.filled().len();
121
122                    // SAFETY: we just read this number of bytes into read_buf's backing slice above.
123                    unsafe { read_buf.assume_init(bytes_read) };
124                    read_buf.advance(bytes_read);
125                    return Poll::Ready(Ok(()));
126                }
127                *this.state = NixFramedReaderState::ReadingSize {
128                    buf: [0; 8],
129                    filled: 0,
130                };
131                self.poll_read(cx, read_buf)
132            }
133        }
134    }
135}
136
#[cfg(test)]
mod nix_framed_tests {
    use std::time::Duration;

    use tokio::io::AsyncReadExt;
    use tokio_test::io::Builder;

    use crate::nix_daemon::framing::NixFramedReader;

    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            // The client sends the payload length
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() {
        let mut mock = Builder::new()
            // The client sends the payload length
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .read(&0u64.to_le_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
    }

    /// Exercises the resumable paths of the reader: a size header that arrives in
    /// two pieces, and a payload drained through a buffer smaller than the frame.
    #[tokio::test]
    async fn read_split_size_header_with_small_buffer() {
        let size = 5u64.to_le_bytes();
        let mut mock = Builder::new()
            // The length arrives split across two reads
            .read(&size[..3])
            .read(&size[3..])
            .read("hello".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        // Deliberately smaller than the 5-byte frame.
        let mut chunk = [0u8; 2];
        let mut result = Vec::new();
        loop {
            let n = reader
                .read(&mut chunk)
                .await
                .expect("Could not read into chunk");
            if n == 0 {
                break;
            }
            result.extend_from_slice(&chunk[..n]);
        }
        assert_eq!(result, b"hello");
    }
}