nix_compat/wire/bytes/reader/
mod.rs1use std::{
2 future::Future,
3 io,
4 num::NonZeroU64,
5 ops::RangeBounds,
6 pin::Pin,
7 task::{self, ready, Poll},
8};
9use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf};
10
11use trailer::{read_trailer, ReadTrailer, Trailer};
12
13#[doc(hidden)]
14pub use self::trailer::Pad;
15pub(crate) use self::trailer::Tag;
16mod trailer;
17
18#[derive(Debug)]
32#[allow(private_bounds)]
33pub struct BytesReader<R, T: Tag = Pad> {
34 state: State<R, T>,
35}
36
/// Splits a non-zero payload length into `(body_len, tail_len)`.
///
/// `tail_len` is always in `1..=8`: the payload's final (possibly partial)
/// 8-byte word, which is handed to `read_trailer` instead of being streamed.
/// `body_len` is everything before it and is always a multiple of 8 (possibly 0).
/// The two values always sum to `user_len`.
#[inline(always)]
fn split_user_len(user_len: NonZeroU64) -> (u64, u8) {
    let last_index = user_len.get() - 1;
    let offset_in_word = last_index % 8;

    let body_len = last_index - offset_in_word;
    let tail_len = offset_in_word as u8 + 1;

    (body_len, tail_len)
}
47
48#[derive(Debug)]
49enum State<R, T: Tag> {
50 Body {
53 reader: Option<R>,
54 consumed: u64,
55 user_len: NonZeroU64,
57 },
58 ReadTrailer(ReadTrailer<R, T>),
60 ReleaseTrailer { consumed: u8, data: Trailer },
63}
64
65impl<R> BytesReader<R>
66where
67 R: AsyncRead + Unpin,
68{
69 pub async fn new<S: RangeBounds<u64>>(reader: R, allowed_size: S) -> io::Result<Self> {
71 BytesReader::new_internal(reader, allowed_size).await
72 }
73}
74
75#[allow(private_bounds)]
76impl<R, T: Tag> BytesReader<R, T>
77where
78 R: AsyncRead + Unpin,
79{
80 pub(crate) async fn new_internal<S: RangeBounds<u64>>(
82 mut reader: R,
83 allowed_size: S,
84 ) -> io::Result<Self> {
85 let size = reader.read_u64_le().await?;
86
87 if !allowed_size.contains(&size) {
88 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size"));
89 }
90
91 Ok(Self {
92 state: match NonZeroU64::new(size) {
93 Some(size) => State::Body {
94 reader: Some(reader),
95 consumed: 0,
96 user_len: size,
97 },
98 None => State::ReleaseTrailer {
99 consumed: 0,
100 data: read_trailer::<R, T>(reader, 0).await?,
101 },
102 },
103 })
104 }
105
106 pub fn is_empty(&self) -> bool {
108 self.len() == 0
109 }
110
111 pub fn len(&self) -> u64 {
113 match self.state {
114 State::Body {
115 consumed, user_len, ..
116 } => user_len.get() - consumed,
117 State::ReadTrailer(ref fut) => fut.len() as u64,
118 State::ReleaseTrailer { consumed, ref data } => data.len() as u64 - consumed as u64,
119 }
120 }
121}
122
123#[allow(private_bounds)]
124impl<R: AsyncRead + Unpin, T: Tag> AsyncRead for BytesReader<R, T> {
125 fn poll_read(
126 mut self: Pin<&mut Self>,
127 cx: &mut task::Context,
128 buf: &mut ReadBuf,
129 ) -> Poll<io::Result<()>> {
130 let this = &mut self.state;
131
132 loop {
133 match this {
134 State::Body {
135 reader,
136 consumed,
137 user_len,
138 } => {
139 let (body_len, tail_len) = split_user_len(*user_len);
140 let remaining = body_len - *consumed;
141
142 let reader = if remaining == 0 {
143 let reader = reader.take().unwrap();
144 *this = State::ReadTrailer(read_trailer(reader, tail_len));
145 continue;
146 } else {
147 Pin::new(reader.as_mut().unwrap())
148 };
149
150 let mut bytes_read = 0;
151 ready!(with_limited(buf, remaining, |buf| {
152 let ret = reader.poll_read(cx, buf);
153 bytes_read = buf.filled().len();
154 ret
155 }))?;
156
157 *consumed += bytes_read as u64;
158
159 return if bytes_read != 0 {
160 Ok(())
161 } else {
162 Err(io::ErrorKind::UnexpectedEof.into())
163 }
164 .into();
165 }
166 State::ReadTrailer(fut) => {
167 *this = State::ReleaseTrailer {
168 consumed: 0,
169 data: ready!(Pin::new(fut).poll(cx))?,
170 };
171 }
172 State::ReleaseTrailer { consumed, data } => {
173 let data = &data[*consumed as usize..];
174 let data = &data[..usize::min(data.len(), buf.remaining())];
175
176 buf.put_slice(data);
177 *consumed += data.len() as u8;
178
179 return Ok(()).into();
180 }
181 }
182 }
183 }
184}
185
186#[allow(private_bounds)]
187impl<R: AsyncBufRead + Unpin, T: Tag> AsyncBufRead for BytesReader<R, T> {
188 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<io::Result<&[u8]>> {
189 let this = &mut self.get_mut().state;
190
191 loop {
192 match this {
193 State::Body {
197 reader,
198 consumed,
199 user_len,
200 } if {
201 let (body_len, _) = split_user_len(*user_len);
202 let remaining = body_len - *consumed;
203
204 remaining == 0
205 } =>
206 {
207 let reader = reader.take().unwrap();
208 let (_, tail_len) = split_user_len(*user_len);
209
210 *this = State::ReadTrailer(read_trailer(reader, tail_len));
211 }
212 State::Body {
213 reader,
214 consumed,
215 user_len,
216 } => {
217 let (body_len, _) = split_user_len(*user_len);
218 let remaining = body_len - *consumed;
219
220 let reader = Pin::new(reader.as_mut().unwrap());
221
222 match ready!(reader.poll_fill_buf(cx))? {
223 &[] => {
224 return Err(io::ErrorKind::UnexpectedEof.into()).into();
225 }
226 mut buf => {
227 if buf.len() as u64 > remaining {
228 buf = &buf[..remaining as usize];
229 }
230
231 return Ok(buf).into();
232 }
233 }
234 }
235 State::ReadTrailer(fut) => {
236 *this = State::ReleaseTrailer {
237 consumed: 0,
238 data: ready!(Pin::new(fut).poll(cx))?,
239 };
240 }
241 State::ReleaseTrailer { consumed, data } => {
242 return Ok(&data[*consumed as usize..]).into();
243 }
244 }
245 }
246 }
247
248 fn consume(mut self: Pin<&mut Self>, amt: usize) {
249 match &mut self.state {
250 State::Body {
251 reader,
252 consumed,
253 user_len,
254 } => {
255 let reader = Pin::new(reader.as_mut().unwrap());
256 let (body_len, _) = split_user_len(*user_len);
257
258 *consumed = consumed
259 .checked_add(amt as u64)
260 .filter(|&consumed| consumed <= body_len)
261 .expect("consumed out of bounds");
262
263 reader.consume(amt);
264 }
265 State::ReadTrailer(_) => unreachable!(),
266 State::ReleaseTrailer { consumed, data } => {
267 *consumed = amt
268 .checked_add(*consumed as usize)
269 .filter(|&consumed| consumed <= data.len())
270 .expect("consumed out of bounds") as u8;
271 }
272 }
273 }
274}
275
276fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R {
279 let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX));
280 let ptr = nbuf.initialized().as_ptr();
281 let ret = f(&mut nbuf);
282
283 unsafe {
289 assert_eq!(nbuf.initialized().as_ptr(), ptr);
291
292 let n = nbuf.filled().len();
293 buf.assume_init(n);
294 buf.advance(n);
295 }
296
297 ret
298}
299
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;
    use std::time::Duration;

    use crate::wire::bytes::{padding_len, write_bytes};
    use hex_literal::hex;
    use rstest::rstest;
    use tokio::io::{AsyncReadExt, BufReader};
    use tokio_test::io::Builder;

    use super::*;

    /// Upper bound (exclusive) of the allowed size used by most tests.
    const MAX_LEN: u64 = 1024;

    /// A large payload: the byte sequence `0..=254` repeated 4096 times.
    pub static LARGE_PAYLOAD: LazyLock<Vec<u8>> =
        LazyLock::new(|| (0..255).collect::<Vec<u8>>().repeat(4 * 1024));

    /// Encodes `payload` with `write_bytes`, i.e. size prefix, payload and padding.
    async fn produce_packet_bytes(payload: &[u8]) -> Vec<u8> {
        let mut exp = vec![];
        write_bytes(&mut exp, payload).await.unwrap();
        exp
    }

    /// Reading a full packet via `AsyncRead` yields the original payload.
    #[rstest]
    #[case::empty(&[])]
    #[case::size_1b(&[0xff])]
    #[case::size_8b(&hex!("0001020304050607"))]
    #[case::size_9b(&hex!("000102030405060708"))]
    #[case::size_1m(LARGE_PAYLOAD.as_slice())]
    #[tokio::test]
    async fn read_payload_correct(#[case] payload: &[u8]) {
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await)
            .build();

        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64)
            .await
            .unwrap();
        let mut buf = Vec::new();
        r.read_to_end(&mut buf).await.expect("must succeed");

        assert_eq!(payload, &buf[..]);
    }

    /// Reading a full packet via `AsyncBufRead` yields the original payload.
    #[rstest]
    #[case::empty(&[])]
    #[case::size_1b(&[0xff])]
    #[case::size_8b(&hex!("0001020304050607"))]
    #[case::size_9b(&hex!("000102030405060708"))]
    #[case::size_1m(LARGE_PAYLOAD.as_slice())]
    #[tokio::test]
    async fn read_payload_correct_readbuf(#[case] payload: &[u8]) {
        let mut mock = BufReader::new(
            Builder::new()
                .read(&produce_packet_bytes(payload).await)
                .build(),
        );

        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64)
            .await
            .unwrap();

        let mut buf = Vec::new();
        tokio::io::copy_buf(&mut r, &mut buf)
            .await
            .expect("copy_buf must succeed");

        assert_eq!(payload, &buf[..]);
    }

    /// A size prefix above the allowed range is rejected with `InvalidData`.
    #[tokio::test]
    async fn read_bigger_than_allowed_fail() {
        let payload = LARGE_PAYLOAD.as_slice();
        // Only the size prefix is supplied; it must be rejected before the payload is read.
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[0..8])
            .build();

        assert_eq!(
            BytesReader::new(&mut mock, ..2048)
                .await
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
    }

    /// A size prefix below the allowed range is rejected with `InvalidData`.
    #[tokio::test]
    async fn read_smaller_than_allowed_fail() {
        let payload = &[0x00, 0x01, 0x02];
        // Only the size prefix is supplied; it must be rejected before the payload is read.
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[0..8])
            .build();

        assert_eq!(
            BytesReader::new(&mut mock, 1024..2048)
                .await
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
    }

    /// For an empty payload, the trailer is read during construction.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_trailer_immediately() {
        use crate::nar::wire::PadPar;

        let mut mock = Builder::new()
            .read(&[0; 8])
            .read(&PadPar::PATTERN[8..])
            .build();

        BytesReader::<_, PadPar>::new_internal(&mut mock, ..)
            .await
            .unwrap();

        // The tokio_test mock panics on drop if any expected read was not performed.
    }

    /// Reading exactly the payload length also consumes the full trailer.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn read_exact_trailer() {
        use crate::nar::wire::PadPar;

        let mut mock = Builder::new()
            .read(&16u64.to_le_bytes())
            .read(&[0x55; 16])
            .read(&PadPar::PATTERN[8..])
            .build();

        let mut reader = BytesReader::<_, PadPar>::new_internal(&mut mock, ..)
            .await
            .unwrap();

        let mut buf = [0; 16];
        reader.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, [0x55; 16]);

        // The tokio_test mock panics on drop if any expected read was not performed.
    }

    /// Non-zero padding bytes cause the read to fail.
    #[tokio::test]
    async fn read_fail_if_nonzero_padding() {
        let payload = &[0x00, 0x01, 0x02];
        let mut packet_bytes = produce_packet_bytes(payload).await;
        // Byte 12 lies in the padding: 8 size bytes + 3 payload bytes, padded up to 16.
        packet_bytes[12] = 0xff;

        let mut mock = Builder::new().read(&packet_bytes).build();
        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
        let mut buf = Vec::new();

        r.read_to_end(&mut buf).await.expect_err("must fail");
    }

    /// EOF in the middle of the size prefix surfaces as `UnexpectedEof`.
    #[tokio::test]
    async fn read_9b_eof_during_size() {
        let payload = &hex!("FF0102030405060708");
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[..4])
            .build();

        assert_eq!(
            BytesReader::new(&mut mock, ..MAX_LEN)
                .await
                .expect_err("must fail")
                .kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    /// EOF in the middle of the payload surfaces as `UnexpectedEof`.
    #[tokio::test]
    async fn read_9b_eof_during_payload() {
        let payload = &hex!("FF0102030405060708");
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[..8 + 4])
            .build();

        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
        let mut buf = [0; 9];

        r.read_exact(&mut buf[..4]).await.expect("must succeed");

        assert_eq!(
            r.read_exact(&mut buf[4..=4])
                .await
                .expect_err("must fail")
                .kind(),
            std::io::ErrorKind::UnexpectedEof
        );
    }

    /// EOF anywhere in the tail/padding surfaces as `UnexpectedEof` once the
    /// tail byte is requested.
    #[rstest]
    #[case::before_padding(8 + 9)]
    #[case::during_padding(8 + 9 + 2)]
    #[case::after_padding(8 + 9 + padding_len(9) as usize - 1)]
    #[tokio::test]
    async fn read_9b_eof_after_payload(#[case] offset: usize) {
        let payload = &hex!("FF0102030405060708");
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[..offset])
            .build();

        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();

        // The first 8 bytes are the body and are available regardless of the cut.
        assert_eq!(r.read_exact(&mut [0; 8]).await.unwrap(), 8);

        // The 9th byte is read as part of the trailer, which is truncated.
        assert_eq!(
            r.read_exact(&mut [0]).await.unwrap_err().kind(),
            std::io::ErrorKind::UnexpectedEof
        );
    }

    /// Errors from the inner reader are propagated unchanged (`AsyncRead`).
    #[rstest]
    #[case::during_size(4)]
    #[case::before_payload(8)]
    #[case::during_payload(8 + 4)]
    #[case::before_padding(8 + 9)]
    #[case::during_padding(8 + 9 + 2)]
    #[tokio::test]
    async fn propagate_error_from_reader(#[case] offset: usize) {
        let payload = &hex!("FF0102030405060708");
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[..offset])
            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
            .build();

        // Construction itself may already hit the error (e.g. during the size prefix).
        let err: io::Error = async {
            let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await?;
            let mut buf = Vec::new();

            r.read_to_end(&mut buf).await?;

            Ok(())
        }
        .await
        .expect_err("must fail");

        assert_eq!(
            err.kind(),
            std::io::ErrorKind::Other,
            "error kind must match"
        );

        assert_eq!(
            err.into_inner().unwrap().to_string(),
            "foo",
            "error payload must contain foo"
        );
    }

    /// Errors from the inner reader are propagated unchanged (`AsyncBufRead`).
    #[rstest]
    #[case::during_size(4)]
    #[case::before_payload(8)]
    #[case::during_payload(8 + 4)]
    #[case::before_padding(8 + 9)]
    #[case::during_padding(8 + 9 + 2)]
    #[tokio::test]
    async fn propagate_error_from_reader_buffered(#[case] offset: usize) {
        let payload = &hex!("FF0102030405060708");
        let mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[..offset])
            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
            .build();
        let mut mock = BufReader::new(mock);

        // Construction itself may already hit the error (e.g. during the size prefix).
        let err: io::Error = async {
            let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await?;
            let mut buf = Vec::new();

            tokio::io::copy_buf(&mut r, &mut buf).await?;

            Ok(())
        }
        .await
        .expect_err("must fail");

        assert_eq!(
            err.kind(),
            std::io::ErrorKind::Other,
            "error kind must match"
        );

        assert_eq!(
            err.into_inner().unwrap().to_string(),
            "foo",
            "error payload must contain foo"
        );
    }

    /// An inner-reader error queued after the complete packet is never observed.
    #[tokio::test]
    async fn no_error_after_eof() {
        let payload = &hex!("FF0102030405060708");
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await)
            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
            .build();

        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
        let mut buf = Vec::new();

        r.read_to_end(&mut buf).await.expect("must succeed");
        assert_eq!(buf.as_slice(), payload);
    }

    /// Same as `no_error_after_eof`, but through `AsyncBufRead`.
    #[tokio::test]
    async fn no_error_after_eof_buffered() {
        let payload = &hex!("FF0102030405060708");
        let mock = Builder::new()
            .read(&produce_packet_bytes(payload).await)
            .read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
            .build();
        let mut mock = BufReader::new(mock);

        let mut r = BytesReader::new(&mut mock, ..MAX_LEN).await.unwrap();
        let mut buf = Vec::new();

        tokio::io::copy_buf(&mut r, &mut buf)
            .await
            .expect("must succeed");
        assert_eq!(buf.as_slice(), payload);
    }

    /// A `Pending` from the inner reader at various offsets does not corrupt the payload.
    #[rstest]
    #[case::beginning(0)]
    #[case::before_payload(8)]
    #[case::during_payload(8 + 4)]
    #[case::before_padding(8 + 9)]
    #[case::during_padding(8 + 9 + 2)]
    #[tokio::test]
    async fn read_payload_correct_pending(#[case] offset: usize) {
        let payload = &hex!("FF0102030405060708");
        let mut mock = Builder::new()
            .read(&produce_packet_bytes(payload).await[..offset])
            .wait(Duration::from_nanos(0))
            .read(&produce_packet_bytes(payload).await[offset..])
            .build();

        let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64)
            .await
            .unwrap();
        let mut buf = Vec::new();
        r.read_to_end(&mut buf).await.expect("must succeed");

        assert_eq!(payload, &buf[..]);
    }
}