1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
use std::{
    io::Result,
    pin::Pin,
    task::{ready, Poll},
};

use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, ReadBuf};

/// State machine for [`NixFramedReader`].
///
/// As the reader progresses it linearly cycles through the states.
#[derive(Debug)]
enum NixFramedReaderState {
    /// The reader always starts in this state.
    ///
    /// Before each payload the client first sends the payload's size as a little-endian
    /// `u64` (8 bytes). The whole size will usually arrive in a single read, but it may also
    /// arrive in smaller chunks, so this state accumulates up to 8 bytes in `buf`, with
    /// `filled` counting how many of them have been received so far.
    ///
    /// Once all 8 bytes are present:
    /// - a non-zero size transitions to [`NixFramedReaderState::ReadingPayload`];
    /// - a zero size marks the end of the framed data: the reader reports EOF (a read that
    ///   fills zero bytes) and resets `filled` to 0, so a later read starts on a fresh size.
    ReadingSize { buf: [u8; 8], filled: usize },
    /// This is where we read the actual payload that is sent to us.
    ///
    /// Payload bytes are passed through to the caller's buffer. Once the expected number of
    /// bytes has been read, we go back to the [`NixFramedReaderState::ReadingSize`] state.
    ReadingPayload {
        /// Number of payload bytes still expected in the current frame, initialised from the
        /// size read in [`NixFramedReaderState::ReadingSize`].
        remaining: u64,
    },
}

pin_project! {
    /// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
    ///
    /// The wrapped reader carries a sequence of frames, each consisting of a little-endian
    /// `u64` size followed by that many payload bytes. This reader consumes the size headers
    /// and exposes only the concatenated payload bytes through [`AsyncRead`]. A zero-sized
    /// frame is reported to the caller as EOF.
    ///
    /// See serialization.md#framed and [`NixFramedReaderState`] for details.
    pub struct NixFramedReader<R> {
        // The underlying byte stream carrying the framed data.
        #[pin]
        reader: R,
        // Where we currently are within the framing protocol (size header vs. payload).
        state: NixFramedReaderState,
    }
}

impl<R> NixFramedReader<R> {
    /// Wraps `reader`, whose contents are expected to be a sequence of Nix frames.
    ///
    /// The returned reader begins by expecting the size header of the first frame.
    pub fn new(reader: R) -> Self {
        // No header bytes have been received yet.
        let initial_state = NixFramedReaderState::ReadingSize {
            filled: 0,
            buf: [0u8; 8],
        };
        Self {
            state: initial_state,
            reader,
        }
    }
}

impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
    /// Reads payload bytes from the framed stream into `read_buf`, transparently consuming
    /// the size header that precedes each frame.
    ///
    /// At most the remainder of the current frame is read per call. The call completes
    /// without filling any bytes (i.e. signals EOF) when the underlying stream ends cleanly
    /// on a frame boundary, or when a zero-sized frame is received. In the latter case the
    /// reader resets so that a subsequent call starts reading a new frame.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying reader. Returns an error of kind
    /// [`std::io::ErrorKind::UnexpectedEof`] if the underlying reader reaches EOF in the
    /// middle of a size header or a payload, instead of silently truncating the frame.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        read_buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let mut this = self.project();
        // Header bytes and state transitions are handled in this loop, until payload bytes
        // are produced, EOF is reached, or the inner reader returns `Pending`.
        loop {
            match this.state {
                NixFramedReaderState::ReadingSize { buf, filled } => {
                    if *filled < buf.len() {
                        let mut size_buf = ReadBuf::new(buf);
                        size_buf.advance(*filled);

                        ready!(this.reader.as_mut().poll_read(cx, &mut size_buf))?;
                        let bytes_read = size_buf.filled().len() - *filled;
                        if bytes_read == 0 {
                            if *filled == 0 {
                                // Clean EOF between frames.
                                return Poll::Ready(Ok(()));
                            }
                            return Poll::Ready(Err(std::io::Error::new(
                                std::io::ErrorKind::UnexpectedEof,
                                "unexpected EOF while reading frame size",
                            )));
                        }
                        *filled += bytes_read;
                        // The size may still be incomplete; go around again.
                        continue;
                    }
                    let size = u64::from_le_bytes(*buf);
                    if size == 0 {
                        // A zero-sized frame terminates the framed data: report EOF and be
                        // ready to read a fresh size on the next call.
                        *filled = 0;
                        return Poll::Ready(Ok(()));
                    }
                    *this.state = NixFramedReaderState::ReadingPayload { remaining: size };
                }
                NixFramedReaderState::ReadingPayload { remaining } => {
                    if *remaining == 0 {
                        // Not normally reachable (we leave this state as soon as the frame is
                        // exhausted), but handle it by starting on the next size header.
                        *this.state = NixFramedReaderState::ReadingSize {
                            buf: [0; 8],
                            filled: 0,
                        };
                        continue;
                    }

                    // Never let the inner reader produce more than what is left of the
                    // current frame. On 32-bit platforms `remaining` may exceed usize::MAX.
                    let limit = usize::try_from(*remaining).unwrap_or(usize::MAX);
                    // A zero-byte read only means EOF if there was room to read into.
                    let had_room = read_buf.remaining() > 0;
                    let filled_before = read_buf.filled().len();

                    if read_buf.remaining() <= limit {
                        // The caller's buffer cannot overshoot the frame: read directly.
                        ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
                    } else {
                        // Restrict the caller's buffer to the rest of the frame.
                        let mut limited = read_buf.take(limit);
                        let limited_ptr = limited.filled().as_ptr();
                        ready!(this.reader.as_mut().poll_read(cx, &mut limited))?;
                        // A reader could replace `limited` with a ReadBuf over a different
                        // buffer; the `assume_init` below is only sound if it did not.
                        assert_eq!(
                            limited.filled().as_ptr(),
                            limited_ptr,
                            "inner reader replaced the ReadBuf it was given"
                        );
                        let n = limited.filled().len();

                        // SAFETY: the inner reader filled (and thereby initialised) `n` bytes
                        // at the start of `read_buf`'s unfilled region; the assertion above
                        // guarantees those bytes live in `read_buf`'s backing slice.
                        unsafe { read_buf.assume_init(n) };
                        read_buf.advance(n);
                    }

                    let bytes_read = read_buf.filled().len() - filled_before;
                    if bytes_read == 0 && had_room {
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "unexpected EOF while reading frame payload",
                        )));
                    }

                    // `bytes_read <= limit <= *remaining`, so this cannot underflow.
                    *remaining -= bytes_read as u64;
                    if *remaining == 0 {
                        *this.state = NixFramedReaderState::ReadingSize {
                            buf: [0; 8],
                            filled: 0,
                        };
                    }
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }
}

#[cfg(test)]
mod nix_framed_tests {
    use std::time::Duration;

    use tokio::io::AsyncReadExt;
    use tokio_test::io::Builder;

    use crate::nix_daemon::framing::NixFramedReader;

    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() {
        let mut mock = Builder::new()
            // The client sends len
            .read(&5u64.to_le_bytes())
            // Immediately followed by the bytes
            .read("hello".as_bytes())
            .wait(Duration::ZERO)
            // Send more data separately
            .read(&6u64.to_le_bytes())
            .read(" world".as_bytes())
            // A zero-sized frame terminates the framed data
            .read(&0u64.to_le_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello world", result);
    }

    #[tokio::test]
    async fn read_frame_whose_size_arrives_in_chunks() {
        let size = 5u64.to_le_bytes();
        let mut mock = Builder::new()
            // The 8-byte size header is split over three reads
            .read(&size[..1])
            .read(&size[1..4])
            .read(&size[4..])
            .read("hello".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        let mut result = String::new();
        reader
            .read_to_string(&mut result)
            .await
            .expect("Could not read into result");
        assert_eq!("hello", result);
    }

    #[tokio::test]
    async fn read_frame_into_buffer_smaller_than_payload() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read("hello".as_bytes())
            .build();

        let mut reader = NixFramedReader::new(&mut mock);
        // A 2-byte buffer forces the payload to be delivered over several reads,
        // the last of which has more room than the frame has bytes left.
        let mut chunk = [0u8; 2];
        let mut result = Vec::new();
        loop {
            let n = reader.read(&mut chunk).await.expect("Could not read chunk");
            if n == 0 {
                break;
            }
            result.extend_from_slice(&chunk[..n]);
        }
        assert_eq!(b"hello".to_vec(), result);
    }
}