1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
use std::{io, mem};
use std::pin::Pin;
use std::task::{Context, Poll};

use axum::response::{Response, IntoResponse};
use bytes::{Bytes, BytesMut};
use http_body::{Body, SizeHint, Frame};
use futures::Stream;
use pin_project::pin_project;
use tokio::io::ReadBuf;

use crate::RangeBody;

/// Capacity requested (64 KiB) for each read buffer created by `allocate_buffer`;
/// this bounds the size of every chunk a `RangedStream` yields.
const IO_BUFFER_SIZE: usize = 64 * 1024;

/// Response body stream. Implements [`Stream`], [`Body`], and [`IntoResponse`].
///
/// Produces `length` bytes read from `body`, starting at the byte offset given
/// to `RangedStream::new`. The seek is issued lazily on the first poll.
#[pin_project]
pub struct RangedStream<B> {
    // Progress of the seek-then-read state machine driven by `poll_next`.
    state: StreamState,
    // Total number of bytes in the range this stream produces.
    length: u64,
    // Underlying seekable source; structurally pinned because `poll_next`
    // drives it through `Pin<&mut B>` (`start_seek`, `poll_complete`, `poll_read`).
    #[pin]
    body: B,
}

impl<B> RangedStream<B>
where
    B: RangeBody + Send + 'static,
{
    /// Builds a stream yielding `length` bytes of `body`, beginning at byte
    /// offset `start`.
    ///
    /// No I/O is performed here: the stream starts in the `Seek` state and the
    /// seek is only requested when the stream is first polled.
    pub(crate) fn new(body: B, start: u64, length: u64) -> Self {
        let initial_state = StreamState::Seek { start };
        Self {
            state: initial_state,
            length,
            body,
        }
    }
}

/// State machine of a [`RangedStream`]: seek not yet issued, seek in flight,
/// then reading the range chunk by chunk.
#[derive(Debug)]
enum StreamState {
    /// The seek to byte offset `start` has not been requested yet.
    Seek { start: u64 },
    /// `start_seek` succeeded and `poll_complete` is pending; `remaining` is the
    /// number of bytes to read once the seek finishes.
    Seeking { remaining: u64 },
    /// The seek finished; `buffer` receives the next read and `remaining`
    /// counts the bytes of the range not yet yielded.
    Reading { buffer: BytesMut, remaining: u64 },
}

impl<B> IntoResponse for RangedStream<B>
where
    B: RangeBody + Send + 'static,
{
    /// Wraps this stream in an [`axum::body::Body`] and returns it inside a new
    /// `Response`. No status or headers are set here; callers that need
    /// range-specific headers must add them themselves.
    fn into_response(self) -> Response {
        let streaming_body = axum::body::Body::new(self);
        Response::new(streaming_body)
    }
}

impl<B: RangeBody> Body for RangedStream<B> {
    type Data = Bytes;
    type Error = io::Error;

    fn size_hint(&self) -> SizeHint {
        SizeHint::with_exact(self.length)
    }

    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>)
        -> Poll<Option<io::Result<Frame<Bytes>>>>
    {
        self.poll_next(cx).map(|item| item.map(|result| result.map(Frame::data)))
    }
}

impl<B: RangeBody> Stream for RangedStream<B> {
    type Item = io::Result<Bytes>;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Option<io::Result<Bytes>>> {
        let mut this = self.project();

        if let StreamState::Seek { start } = *this.state {
            match this.body.as_mut().start_seek(start) {
                Err(e) => { return Poll::Ready(Some(Err(e))); }
                Ok(()) => {
                    let remaining = *this.length;
                    *this.state = StreamState::Seeking { remaining };
                }
            }
        }

        if let StreamState::Seeking { remaining } = *this.state {
            match this.body.as_mut().poll_complete(cx) {
                Poll::Pending => { return Poll::Pending; }
                Poll::Ready(Err(e)) => { return Poll::Ready(Some(Err(e))); }
                Poll::Ready(Ok(())) => {
                    let buffer = allocate_buffer();
                    *this.state = StreamState::Reading { buffer, remaining };
                }
            }
        }

        if let StreamState::Reading { buffer, remaining } = this.state {
            let uninit = buffer.spare_capacity_mut();

            // calculate max number of bytes to read in this iteration, the
            // smaller of the buffer size and the number of bytes remaining
            let nbytes = std::cmp::min(
                uninit.len(),
                usize::try_from(*remaining).unwrap_or(usize::MAX),
            );

            let mut read_buf = ReadBuf::uninit(&mut uninit[0..nbytes]);

            match this.body.as_mut().poll_read(cx, &mut read_buf) {
                Poll::Pending => { return Poll::Pending; }
                Poll::Ready(Err(e)) => { return Poll::Ready(Some(Err(e))); }
                Poll::Ready(Ok(())) => {
                    match read_buf.filled().len() {
                        0 => { return Poll::Ready(None); }
                        n => {
                            // SAFETY: poll_read has filled the buffer with `n`
                            // additional bytes. `buffer.len` should always be
                            // 0 here, but include it for rigorous correctness
                            unsafe { buffer.set_len(buffer.len() + n); }

                            // replace state buffer and take this one to return
                            let chunk = mem::replace(buffer, allocate_buffer());

                            // subtract the number of bytes we just read from
                            // state.remaining, this usize->u64 conversion is
                            // guaranteed to always succeed, because n cannot be
                            // larger than remaining due to the cmp::min above
                            *remaining -= u64::try_from(n).unwrap();

                            // return this chunk
                            return Poll::Ready(Some(Ok(chunk.freeze())));
                        }
                    }
                }
            }
        }

        unreachable!();
    }
}

/// Creates an empty read buffer with room for `IO_BUFFER_SIZE` bytes.
fn allocate_buffer() -> BytesMut {
    let capacity = IO_BUFFER_SIZE;
    BytesMut::with_capacity(capacity)
}