use super::compression::{
compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride,
};
use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
use crate::Status;
use bytes::{BufMut, Bytes, BytesMut};
use http::HeaderMap;
use http_body::{Body, Frame};
use pin_project::pin_project;
use std::{
pin::Pin,
task::{ready, Context, Poll},
};
use tokio_stream::{adapters::Fuse, Stream, StreamExt};
/// Stream adapter that turns a stream of messages into chunks of encoded gRPC frames.
///
/// Each message pulled from `source` is encoded by `encoder` into a length-prefixed frame
/// and appended to `buf`. The accumulated bytes are yielded as a single `Bytes` chunk once
/// `buf` reaches the encoder's yield threshold, or when `source` is pending or exhausted.
#[pin_project(project = EncodedBytesProj)]
#[derive(Debug)]
struct EncodedBytes<T, U> {
    /// Input message stream, fused so it keeps returning `None` once exhausted.
    #[pin]
    source: Fuse<U>,
    /// Serializes individual messages into the frame payload.
    encoder: T,
    /// Compression applied to every frame; `None` means frames are sent uncompressed.
    compression_encoding: Option<CompressionEncoding>,
    /// Per-message size limit; when `None` the default send limit is applied.
    max_message_size: Option<usize>,
    /// Output buffer accumulating complete frames until they are yielded.
    buf: BytesMut,
    /// Scratch buffer holding a serialized message before it is compressed into `buf`.
    uncompression_buf: BytesMut,
    /// Status deferred until already-buffered frames have been yielded; it is returned
    /// on the next poll.
    error: Option<Status>,
}
impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
    /// Creates a frame encoder over `source`.
    ///
    /// A `compression_override` of `Disable` turns compression off regardless of
    /// `compression_encoding`. Both the output buffer and (only when compression is
    /// active) the uncompressed scratch buffer are pre-sized from the encoder's
    /// `buffer_size` setting.
    fn new(
        encoder: T,
        source: U,
        compression_encoding: Option<CompressionEncoding>,
        compression_override: SingleMessageCompressionOverride,
        max_message_size: Option<usize>,
    ) -> Self {
        let capacity = encoder.buffer_settings().buffer_size;

        // A per-message "Disable" override wins over the negotiated encoding.
        let compression_encoding = compression_encoding
            .filter(|_| compression_override != SingleMessageCompressionOverride::Disable);

        // The scratch buffer is only used when compressing, so avoid allocating otherwise.
        let uncompression_buf = match compression_encoding {
            Some(_) => BytesMut::with_capacity(capacity),
            None => BytesMut::new(),
        };

        Self {
            source: source.fuse(),
            encoder,
            compression_encoding,
            max_message_size,
            buf: BytesMut::with_capacity(capacity),
            uncompression_buf,
            error: None,
        }
    }
}
impl<T, U> Stream for EncodedBytes<T, U>
where
    T: Encoder<Error = Status>,
    U: Stream<Item = Result<T::Item, Status>>,
{
    type Item = Result<Bytes, Status>;

    /// Pulls messages from the source, encodes them into gRPC frames and yields them in
    /// batches.
    ///
    /// Buffered frames are yielded when the buffer reaches the yield threshold, or when the
    /// source is pending or finished. When an error occurs (from the source or while encoding)
    /// and frames are already buffered, those frames are yielded first. The error is stashed
    /// and returned on the following poll, so earlier messages are never lost behind a later
    /// failure.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let EncodedBytesProj {
            mut source,
            encoder,
            compression_encoding,
            max_message_size,
            buf,
            uncompression_buf,
            error,
        } = self.project();
        let buffer_settings = encoder.buffer_settings();

        // Report an error deferred by a previous call, now that its preceding data was sent.
        if let Some(status) = error.take() {
            return Poll::Ready(Some(Err(status)));
        }

        loop {
            match source.as_mut().poll_next(cx) {
                Poll::Pending if buf.is_empty() => {
                    return Poll::Pending;
                }
                Poll::Ready(None) if buf.is_empty() => {
                    return Poll::Ready(None);
                }
                Poll::Pending | Poll::Ready(None) => {
                    return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
                }
                Poll::Ready(Some(Ok(item))) => {
                    let frame_start = buf.len();
                    if let Err(status) = encode_item(
                        encoder,
                        buf,
                        uncompression_buf,
                        *compression_encoding,
                        *max_message_size,
                        buffer_settings,
                        item,
                    ) {
                        // Drop anything the failed item may have written (placeholder
                        // header, partial payload) so only complete frames are ever emitted.
                        buf.truncate(frame_start);
                        if buf.is_empty() {
                            return Poll::Ready(Some(Err(status)));
                        }
                        // Flush the complete frames encoded before the failure; report the
                        // error on the next poll (same policy as source errors below).
                        *error = Some(status);
                        return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
                    }

                    if buf.len() >= buffer_settings.yield_threshold {
                        return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
                    }
                }
                Poll::Ready(Some(Err(status))) => {
                    if buf.is_empty() {
                        return Poll::Ready(Some(Err(status)));
                    }
                    *error = Some(status);
                    return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
                }
            }
        }
    }
}
/// Appends `item` to `buf` as one gRPC length-prefixed frame.
///
/// The frame is a `HEADER_SIZE`-byte header followed by the payload. `finish_encoding`
/// writes the header once the payload length is known. When `compression_encoding` is
/// `Some`, the message is serialized into `uncompression_buf` and then compressed into
/// `buf`. Otherwise it is serialized directly into `buf`.
///
/// # Errors
///
/// Returns `Status::internal` if encoding or compression fails, and the error from
/// `finish_encoding` if the payload exceeds the size limit. On any error, `buf` is
/// truncated back to its original length, so no partial frame is left behind.
fn encode_item<T>(
    encoder: &mut T,
    buf: &mut BytesMut,
    uncompression_buf: &mut BytesMut,
    compression_encoding: Option<CompressionEncoding>,
    max_message_size: Option<usize>,
    buffer_settings: BufferSettings,
    item: T::Item,
) -> Result<(), Status>
where
    T: Encoder<Error = Status>,
{
    let offset = buf.len();

    // Reserve the header with zero-initialized bytes; the real values are filled in below
    // once the payload length is known. This replaces `advance_mut` over uninitialized
    // memory, which violated that method's safety contract.
    buf.resize(offset + HEADER_SIZE, 0);

    let result = encode_payload(
        encoder,
        buf,
        uncompression_buf,
        compression_encoding,
        buffer_settings,
        item,
    )
    .and_then(|()| finish_encoding(compression_encoding, max_message_size, &mut buf[offset..]));

    if result.is_err() {
        // Roll back the placeholder header and any partial payload.
        buf.truncate(offset);
    }
    result
}

/// Serializes `item` and appends the payload bytes to `buf`. The serialized bytes are
/// compressed first when `compression_encoding` is set.
fn encode_payload<T>(
    encoder: &mut T,
    buf: &mut BytesMut,
    uncompression_buf: &mut BytesMut,
    compression_encoding: Option<CompressionEncoding>,
    buffer_settings: BufferSettings,
    item: T::Item,
) -> Result<(), Status>
where
    T: Encoder<Error = Status>,
{
    if let Some(encoding) = compression_encoding {
        uncompression_buf.clear();

        encoder
            .encode(item, &mut EncodeBuf::new(uncompression_buf))
            .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;

        let uncompressed_len = uncompression_buf.len();

        compress(
            CompressionSettings {
                encoding,
                buffer_growth_interval: buffer_settings.buffer_size,
            },
            uncompression_buf,
            buf,
            uncompressed_len,
        )
        .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
    } else {
        encoder
            .encode(item, &mut EncodeBuf::new(buf))
            .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
    }
    Ok(())
}
/// Validates the payload size and writes the gRPC frame header into `buf`.
///
/// `buf` holds one frame: `HEADER_SIZE` placeholder bytes followed by the (possibly
/// compressed) payload. The header is a one-byte compressed flag followed by the payload
/// length as a big-endian `u32`.
///
/// # Errors
///
/// - `Status::out_of_range` if the payload is larger than `max_message_size`
///   (or `DEFAULT_MAX_SEND_MESSAGE_SIZE` when no limit is configured).
/// - `Status::resource_exhausted` if the payload length does not fit in a `u32`.
fn finish_encoding(
    compression_encoding: Option<CompressionEncoding>,
    max_message_size: Option<usize>,
    buf: &mut [u8],
) -> Result<(), Status> {
    let len = buf.len() - HEADER_SIZE;

    let limit = match max_message_size {
        Some(configured) => configured,
        None => DEFAULT_MAX_SEND_MESSAGE_SIZE,
    };
    if len > limit {
        return Err(Status::out_of_range(format!(
            "Error, encoded message length too large: found {} bytes, the limit is: {} bytes",
            len, limit
        )));
    }

    // The wire format stores the length in four bytes.
    let wire_len = match u32::try_from(len) {
        Ok(wire_len) => wire_len,
        Err(_) => {
            return Err(Status::resource_exhausted(format!(
                "Cannot return body with more than 4GB of data but got {len} bytes"
            )));
        }
    };

    let mut header = &mut buf[..HEADER_SIZE];
    header.put_u8(u8::from(compression_encoding.is_some()));
    header.put_u32(wire_len);
    Ok(())
}
/// Which side of the RPC an [`EncodeBody`] is produced for.
///
/// Server bodies finish with trailers carrying the final status; client bodies emit no
/// trailers and surface stream errors directly as body errors.
#[derive(Debug)]
enum Role {
    /// Request body sent by a client.
    Client,
    /// Response body sent by a server.
    Server,
}
/// HTTP body that encodes a stream of messages into gRPC length-prefixed frames.
///
/// On the server side the body ends with a trailers frame holding the RPC status.
#[pin_project]
#[derive(Debug)]
pub struct EncodeBody<T, U> {
    /// Frame encoder producing the body's data chunks.
    #[pin]
    inner: EncodedBytes<T, U>,
    /// Role and end-of-stream bookkeeping.
    state: EncodeState,
}
/// Non-pinned bookkeeping for [`EncodeBody`].
#[derive(Debug)]
struct EncodeState {
    /// Status reported in the final server trailers; `None` means an OK status is sent.
    /// NOTE(review): nothing in this file assigns `Some` here; confirm whether it is still needed.
    error: Option<Status>,
    /// Whether this body belongs to a client or a server.
    role: Role,
    /// Set once a server body has emitted its trailers; never set for client bodies.
    is_end_stream: bool,
}
impl<T: Encoder, U: Stream> EncodeBody<T, U> {
pub fn new_client(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
) -> Self {
Self {
inner: EncodedBytes::new(
encoder,
source,
compression_encoding,
SingleMessageCompressionOverride::default(),
max_message_size,
),
state: EncodeState {
error: None,
role: Role::Client,
is_end_stream: false,
},
}
}
pub fn new_server(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> Self {
Self {
inner: EncodedBytes::new(
encoder,
source,
compression_encoding,
compression_override,
max_message_size,
),
state: EncodeState {
error: None,
role: Role::Server,
is_end_stream: false,
},
}
}
}
impl EncodeState {
fn trailers(&mut self) -> Option<Result<HeaderMap, Status>> {
match self.role {
Role::Client => None,
Role::Server => {
if self.is_end_stream {
return None;
}
self.is_end_stream = true;
let status = if let Some(status) = self.error.take() {
status
} else {
Status::ok("")
};
Some(status.to_header_map())
}
}
}
}
impl<T, U> Body for EncodeBody<T, U>
where
    T: Encoder<Error = Status>,
    U: Stream<Item = Result<T::Item, Status>>,
{
    type Data = Bytes;
    type Error = Status;

    /// True once a server body has emitted its trailers; always false for client bodies.
    fn is_end_stream(&self) -> bool {
        self.state.is_end_stream
    }

    /// Maps encoded chunks to data frames.
    ///
    /// A stream error becomes a body error for clients, and an error-status trailers frame
    /// (ending the stream) for servers. When the stream is exhausted, servers send their
    /// final trailers.
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.project();
        let frame = match ready!(this.inner.poll_next(cx)) {
            None => this.state.trailers().map(|res| res.map(Frame::trailers)),
            Some(Ok(bytes)) => Some(Ok(Frame::data(bytes))),
            Some(Err(status)) => match this.state.role {
                Role::Client => Some(Err(status)),
                Role::Server => {
                    this.state.is_end_stream = true;
                    Some(status.to_header_map().map(Frame::trailers))
                }
            },
        };
        Poll::Ready(frame)
    }
}