// Framed reader for the Nix daemon protocol (see `NixFramedReader` below).
use std::{
io::Result,
pin::Pin,
task::{ready, Poll},
};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, ReadBuf};
/// State machine for [`NixFramedReader`].
///
/// The reader cycles linearly through the states: it first collects a frame's
/// size, then consumes that many payload bytes, then starts over with the
/// next frame's size.
#[derive(Debug)]
enum NixFramedReaderState {
/// The reader always starts in this state.
///
/// Before each payload, the client first sends the payload's size as a
/// little-endian u64 (8 bytes). It is likely that the whole u64 arrives in a
/// single read, but it may also arrive in smaller chunks, so `buf`
/// accumulates the bytes and `filled` records how many of them have been
/// received so far.
///
/// Once all 8 bytes are present we transition to
/// [`NixFramedReaderState::ReadingPayload`] if the decoded size is not zero.
/// Otherwise we reset `filled` to 0 so that the next read starts collecting
/// the next size value.
ReadingSize { buf: [u8; 8], filled: usize },
/// This is where we read the actual payload that is sent to us.
///
/// Once we've read the expected number of bytes, we go back to the
/// [`NixFramedReaderState::ReadingSize`] state.
ReadingPayload {
/// The number of payload bytes still expected for the current frame,
/// starting from the size decoded in the previous state and decreasing
/// as bytes are read.
remaining: u64,
},
}
pin_project! {
/// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
///
/// Wraps an inner reader that yields length-prefixed frames and, through its
/// `AsyncRead` implementation, exposes only the frame payloads to the caller.
///
/// See serialization.md#framed and [`NixFramedReaderState`] for details.
pub struct NixFramedReader<R> {
// The underlying reader the frames are read from. Marked `#[pin]` so the
// projection hands out a pinned reference that can be polled.
#[pin]
reader: R,
// Current position in the size/payload cycle.
state: NixFramedReaderState,
}
}
impl<R> NixFramedReader<R> {
    /// Wraps `reader` in a framed reader.
    ///
    /// The returned reader starts out expecting the 8-byte size of the first
    /// frame, with none of those bytes received yet.
    pub fn new(reader: R) -> Self {
        let state = NixFramedReaderState::ReadingSize {
            buf: [0; 8],
            filled: 0,
        };
        Self { reader, state }
    }
}
impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
    /// Reads payload bytes of the framed stream into `read_buf`.
    ///
    /// Frame size headers are consumed internally and never handed to the
    /// caller. A single call delivers bytes from at most one frame.
    ///
    /// End of stream (returning `Ok(())` without filling any bytes) is
    /// reported when:
    /// - a frame with size zero is received, or
    /// - the inner reader reaches EOF exactly on a frame boundary, i.e.
    ///   before any byte of the next size header has been received.
    ///
    /// # Errors
    ///
    /// Propagates any error of the inner reader. Returns
    /// [`std::io::ErrorKind::UnexpectedEof`] if the inner reader reaches EOF in
    /// the middle of a size header or in the middle of a frame payload, since
    /// the stream is truncated in that case.
    ///
    /// # Panics
    ///
    /// Panics if the inner reader replaces the `ReadBuf` it is given instead
    /// of writing into it.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        read_buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let mut this = self.project();
        // Each iteration performs one step of the state machine. We loop
        // (rather than recurse) whenever a step made progress without
        // producing payload bytes for the caller.
        loop {
            match this.state {
                NixFramedReaderState::ReadingSize { buf, filled } => {
                    if *filled < buf.len() {
                        let already_filled = *filled;
                        let mut size_buf = ReadBuf::new(&mut buf[..]);
                        // Skip over the header bytes received by earlier reads.
                        size_buf.advance(already_filled);
                        ready!(this.reader.as_mut().poll_read(cx, &mut size_buf))?;
                        let bytes_read = size_buf.filled().len() - already_filled;
                        if bytes_read == 0 {
                            // `size_buf` always has spare capacity here, so a
                            // zero-length read means the inner reader hit EOF.
                            return Poll::Ready(if already_filled == 0 {
                                // EOF on a frame boundary: clean end of stream.
                                Ok(())
                            } else {
                                Err(std::io::Error::new(
                                    std::io::ErrorKind::UnexpectedEof,
                                    "unexpected EOF while reading frame size",
                                ))
                            });
                        }
                        *filled += bytes_read;
                        // The header may still be incomplete; go around again.
                        continue;
                    }
                    let size = u64::from_le_bytes(*buf);
                    if size == 0 {
                        // A zero-sized frame signals EOF to the caller. Reset
                        // the header so a later read starts with the next size.
                        *filled = 0;
                        return Poll::Ready(Ok(()));
                    }
                    *this.state = NixFramedReaderState::ReadingPayload { remaining: size };
                }
                NixFramedReaderState::ReadingPayload { remaining } => {
                    if *remaining == 0 {
                        *this.state = NixFramedReaderState::ReadingSize {
                            buf: [0; 8],
                            filled: 0,
                        };
                        continue;
                    }
                    // Never try to read more than usize::MAX at once; usize is
                    // only 4 bytes on 32-bit platforms. The value is clamped
                    // first, so the cast cannot truncate.
                    let limit = (*remaining).min(usize::MAX as u64) as usize;
                    // Expose at most the rest of the current frame to the inner
                    // reader so it can never read into the next size header.
                    // `take` also clamps to the space available in `read_buf`.
                    let mut frame_buf = read_buf.take(limit);
                    let room = frame_buf.remaining();
                    let start = frame_buf.filled().as_ptr();
                    ready!(this.reader.as_mut().poll_read(cx, &mut frame_buf))?;
                    // The `assume_init` below is only sound if the bytes were
                    // written into the memory we handed out.
                    assert_eq!(
                        start,
                        frame_buf.filled().as_ptr(),
                        "inner reader swapped the read buffer"
                    );
                    let bytes_read = frame_buf.filled().len();
                    if bytes_read == 0 && room > 0 {
                        // There was room to read into, so this is EOF in the
                        // middle of a frame: the stream is truncated.
                        return Poll::Ready(Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "unexpected EOF while reading frame payload",
                        )));
                    }
                    // SAFETY: `frame_buf` covered the start of `read_buf`'s
                    // unfilled region (checked by the assert above) and the
                    // inner reader filled, and thereby initialized, its first
                    // `bytes_read` bytes.
                    unsafe { read_buf.assume_init(bytes_read) };
                    read_buf.advance(bytes_read);
                    // `bytes_read <= limit <= *remaining`, so this cannot underflow.
                    *remaining -= bytes_read as u64;
                    if *remaining == 0 {
                        *this.state = NixFramedReaderState::ReadingSize {
                            buf: [0; 8],
                            filled: 0,
                        };
                    }
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }
}
#[cfg(test)]
mod nix_framed_tests {
    use std::time::Duration;
    use tokio::io::AsyncReadExt;
    use tokio_test::io::Builder;
    use crate::nix_daemon::framing::NixFramedReader;

    /// Reads the scripted `mock` through a [`NixFramedReader`] until it
    /// reports end of stream and returns everything that was read.
    async fn drain_to_string(mut mock: tokio_test::io::Mock) -> String {
        let mut text = String::new();
        NixFramedReader::new(&mut mock)
            .read_to_string(&mut text)
            .await
            .expect("Could not read into result");
        text
    }

    /// Two frames, after which the mock simply runs out of data.
    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mock = Builder::new()
            // First frame: the length prefix...
            .read(&5u64.to_le_bytes())
            // ...directly followed by the payload.
            .read(b"hello")
            .wait(Duration::ZERO)
            // Second frame, delivered separately from the first.
            .read(&6u64.to_le_bytes())
            .read(b" world")
            .build();

        assert_eq!("hello world", drain_to_string(mock).await);
    }

    /// Two frames, followed by an explicit zero-sized frame.
    #[tokio::test]
    async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() {
        let mock = Builder::new()
            // First frame: the length prefix...
            .read(&5u64.to_le_bytes())
            // ...directly followed by the payload.
            .read(b"hello")
            .wait(Duration::ZERO)
            // Second frame, delivered separately from the first.
            .read(&6u64.to_le_bytes())
            .read(b" world")
            // Zero-sized frame.
            .read(&0u64.to_le_bytes())
            .build();

        assert_eq!("hello world", drain_to_string(mock).await);
    }
}