nix_compat/nix_daemon/framing/
framed_read.rs1use std::{
2 io::Result,
3 pin::Pin,
4 task::{ready, Poll},
5};
6
7use pin_project_lite::pin_project;
8use tokio::io::{AsyncRead, ReadBuf};
9
/// Progress of a [`NixFramedReader`] through a framed byte stream.
///
/// The stream is a sequence of frames, each made of an 8-byte little-endian
/// length followed by that many payload bytes. The reader alternates between
/// the two variants below as it walks through the frames.
#[derive(Debug)]
enum NixFramedReaderState {
    /// Collecting the 8-byte length prefix of the next frame.
    ///
    /// The prefix may arrive split across several reads of the underlying
    /// reader, so partial progress is kept here between polls.
    ReadingSize {
        /// Storage for the length prefix as it arrives.
        buf: [u8; 8],
        /// How many leading bytes of `buf` have been received so far.
        filled: usize,
    },
    /// Forwarding the current frame's payload to the caller.
    ReadingPayload {
        /// Payload bytes of the current frame not yet handed to the caller.
        remaining: u64,
    },
}
34
35pin_project! {
36 pub struct NixFramedReader<R> {
40 #[pin]
41 reader: R,
42 state: NixFramedReaderState,
43 }
44}
45
46impl<R> NixFramedReader<R> {
47 pub fn new(reader: R) -> Self {
48 Self {
49 reader,
50 state: NixFramedReaderState::ReadingSize {
51 buf: [0; 8],
52 filled: 0,
53 },
54 }
55 }
56}
57
58impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
59 fn poll_read(
60 mut self: Pin<&mut Self>,
61 cx: &mut std::task::Context<'_>,
62 read_buf: &mut ReadBuf<'_>,
63 ) -> Poll<Result<()>> {
64 let mut this = self.as_mut().project();
65 match this.state {
66 NixFramedReaderState::ReadingSize { buf, filled } => {
67 if *filled < buf.len() {
68 let mut size_buf = ReadBuf::new(buf);
69 size_buf.advance(*filled);
70
71 ready!(this.reader.poll_read(cx, &mut size_buf))?;
72 let bytes_read = size_buf.filled().len() - *filled;
73 if bytes_read == 0 {
74 return Poll::Ready(Ok(()));
76 }
77 *filled += bytes_read;
78 return self.poll_read(cx, read_buf);
80 }
81 let size = u64::from_le_bytes(*buf);
82 if size == 0 {
83 *filled = 0;
85 return Poll::Ready(Ok(()));
86 }
87 *this.state = NixFramedReaderState::ReadingPayload { remaining: size };
88 self.poll_read(cx, read_buf)
89 }
90 NixFramedReaderState::ReadingPayload { remaining } => {
91 let safe_remaining = if *remaining <= usize::MAX as u64 {
93 *remaining as usize
94 } else {
95 usize::MAX
96 };
97 if safe_remaining > 0 {
98 if read_buf.remaining() <= safe_remaining {
101 let filled_before = read_buf.filled().len();
102
103 ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
104 let bytes_read = read_buf.filled().len() - filled_before;
105
106 *remaining -= bytes_read as u64;
107 if *remaining == 0 {
108 *this.state = NixFramedReaderState::ReadingSize {
109 buf: [0; 8],
110 filled: 0,
111 };
112 }
113 return Poll::Ready(Ok(()));
114 }
115 let mut smaller_buf = read_buf.take(safe_remaining);
118 ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
119
120 let bytes_read = smaller_buf.filled().len();
121
122 unsafe { read_buf.assume_init(bytes_read) };
124 read_buf.advance(bytes_read);
125 return Poll::Ready(Ok(()));
126 }
127 *this.state = NixFramedReaderState::ReadingSize {
128 buf: [0; 8],
129 filled: 0,
130 };
131 self.poll_read(cx, read_buf)
132 }
133 }
134 }
135}
136
#[cfg(test)]
mod nix_framed_tests {
    use std::time::Duration;

    use tokio::io::AsyncReadExt;
    use tokio_test::io::Builder;

    use crate::nix_daemon::framing::NixFramedReader;

    /// Wraps `mock` in a `NixFramedReader` and drains it into a `String`.
    async fn read_all(mock: &mut tokio_test::io::Mock) -> String {
        let mut reader = NixFramedReader::new(mock);
        let mut collected = String::new();
        reader
            .read_to_string(&mut collected)
            .await
            .expect("Could not read into result");
        collected
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read(b"hello")
            .wait(Duration::ZERO)
            .read(&6u64.to_le_bytes())
            .read(b" world")
            .build();

        assert_eq!("hello world", read_all(&mut mock).await);
    }

    #[tokio::test]
    async fn read_hello_world_in_two_frames_followed_by_zero_sized_frame() {
        let mut mock = Builder::new()
            .read(&5u64.to_le_bytes())
            .read(b"hello")
            .wait(Duration::ZERO)
            .read(&6u64.to_le_bytes())
            .read(b" world")
            .read(&0u64.to_le_bytes())
            .build();

        assert_eq!("hello world", read_all(&mut mock).await);
    }
}