use std::collections::HashSet;
use super::{Directory, DirectoryPutter, DirectoryService};
use crate::composition::{CompositionContext, ServiceBuilder};
use crate::proto::{self, get_directory_request::ByWhat};
use crate::{B3Digest, DirectoryError, Error};
use async_stream::try_stream;
use futures::stream::BoxStream;
use std::sync::Arc;
use tokio::spawn;
use tokio::sync::mpsc::UnboundedSender;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{async_trait, Code, Status};
use tracing::{instrument, warn, Instrument as _};
/// [DirectoryService] implementation that forwards every request to a remote
/// directory service through the generated tonic gRPC client.
#[derive(Clone)]
pub struct GRPCDirectoryService<T> {
    /// Generated gRPC client; this file clones it before each request.
    grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
}
impl<T> GRPCDirectoryService<T> {
pub fn from_client(
grpc_client: proto::directory_service_client::DirectoryServiceClient<T>,
) -> Self {
Self { grpc_client }
}
}
#[async_trait]
impl<T> DirectoryService for GRPCDirectoryService<T>
where
T: tonic::client::GrpcService<tonic::body::BoxBody> + Send + Sync + Clone + 'static,
T::ResponseBody: tonic::codegen::Body<Data = tonic::codegen::Bytes> + Send + 'static,
<T::ResponseBody as tonic::codegen::Body>::Error: Into<tonic::codegen::StdError> + Send,
T::Future: Send,
{
#[instrument(level = "trace", skip_all, fields(directory.digest = %digest))]
async fn get(&self, digest: &B3Digest) -> Result<Option<Directory>, crate::Error> {
let mut grpc_client = self.grpc_client.clone();
let digest_cpy = digest.clone();
let message = async move {
let mut s = grpc_client
.get(proto::GetDirectoryRequest {
recursive: false,
by_what: Some(ByWhat::Digest(digest_cpy.into())),
})
.await?
.into_inner();
s.message().await
};
let digest = digest.clone();
match message.await {
Ok(Some(directory)) => {
let actual_digest = directory.digest();
if actual_digest != digest {
Err(crate::Error::StorageError(format!(
"requested directory with digest {}, but got {}",
digest, actual_digest
)))
} else {
Ok(Some(directory.try_into().map_err(|_| {
Error::StorageError("invalid root digest length in response".to_string())
})?))
}
}
Ok(None) => Ok(None),
Err(e) if e.code() == Code::NotFound => Ok(None),
Err(e) => Err(crate::Error::StorageError(e.to_string())),
}
}
#[instrument(level = "trace", skip_all, fields(directory.digest = %directory.digest()))]
async fn put(&self, directory: Directory) -> Result<B3Digest, crate::Error> {
let resp = self
.grpc_client
.clone()
.put(tokio_stream::once(proto::Directory::from(directory)))
.await;
match resp {
Ok(put_directory_resp) => Ok(put_directory_resp
.into_inner()
.root_digest
.try_into()
.map_err(|_| {
Error::StorageError("invalid root digest length in response".to_string())
})?),
Err(e) => Err(crate::Error::StorageError(e.to_string())),
}
}
#[instrument(level = "trace", skip_all, fields(directory.digest = %root_directory_digest))]
fn get_recursive(
&self,
root_directory_digest: &B3Digest,
) -> BoxStream<'static, Result<Directory, Error>> {
let mut grpc_client = self.grpc_client.clone();
let root_directory_digest = root_directory_digest.clone();
let stream = try_stream! {
let mut stream = grpc_client
.get(proto::GetDirectoryRequest {
recursive: true,
by_what: Some(ByWhat::Digest(root_directory_digest.clone().into())),
})
.await
.map_err(|e| crate::Error::StorageError(e.to_string()))?
.into_inner();
let mut received_directory_digests: HashSet<B3Digest> = HashSet::new();
let mut expected_directory_digests: HashSet<B3Digest> = HashSet::from([root_directory_digest.clone()]);
loop {
match stream.message().await {
Ok(Some(directory)) => {
let directory_digest = directory.digest();
let was_expected = expected_directory_digests.remove(&directory_digest);
if !was_expected {
Err(crate::Error::StorageError(format!(
"received unexpected directory {}",
directory_digest
)))?;
}
received_directory_digests.insert(directory_digest);
for child_directory in &directory.directories {
let child_directory_digest =
child_directory.digest.clone().try_into().unwrap();
expected_directory_digests
.insert(child_directory_digest);
}
let directory = directory.try_into()
.map_err(|e: DirectoryError| Error::StorageError(e.to_string()))?;
yield directory;
},
Ok(None) if expected_directory_digests.len() == 1 && expected_directory_digests.contains(&root_directory_digest) => {
return
}
Ok(None) => {
let diff_len = expected_directory_digests
.difference(&received_directory_digests)
.count();
if diff_len != 0 {
Err(crate::Error::StorageError(format!(
"still expected {} directories, but got premature end of stream",
diff_len
)))?
} else {
return
}
},
Err(e) => {
Err(crate::Error::StorageError(e.to_string()))?;
},
}
}
};
Box::pin(stream)
}
#[instrument(skip_all)]
fn put_multiple_start(&self) -> Box<(dyn DirectoryPutter + 'static)>
where
Self: Clone,
{
let mut grpc_client = self.grpc_client.clone();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let task: JoinHandle<Result<proto::PutDirectoryResponse, Status>> = spawn(
async move {
let s = grpc_client
.put(UnboundedReceiverStream::new(rx))
.await?
.into_inner();
Ok(s)
} .in_current_span(),
);
Box::new(GRPCPutter {
rq: Some((task, tx)),
})
}
}
/// Serde-deserializable configuration for a [GRPCDirectoryService].
///
/// Unknown fields are rejected during deserialization.
#[derive(Debug, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct GRPCDirectoryServiceConfig {
    /// Endpoint URL as a string; parsed into a `url::Url` when the service is built.
    url: String,
}
/// Builds a config from a URL by storing its string form verbatim.
///
/// The visible implementation never returns an error.
impl TryFrom<url::Url> for GRPCDirectoryServiceConfig {
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn try_from(url: url::Url) -> Result<Self, Self::Error> {
        let url = url.to_string();
        Ok(Self { url })
    }
}
#[async_trait]
impl ServiceBuilder for GRPCDirectoryServiceConfig {
type Output = dyn DirectoryService;
async fn build<'a>(
&'a self,
_instance_name: &str,
_context: &CompositionContext,
) -> Result<Arc<dyn DirectoryService>, Box<dyn std::error::Error + Send + Sync + 'static>> {
let client = proto::directory_service_client::DirectoryServiceClient::new(
crate::tonic::channel_from_url(&self.url.parse()?).await?,
);
Ok(Arc::new(GRPCDirectoryService::from_client(client)))
}
}
/// State of an open [GRPCPutter]: the spawned task driving the `put` RPC,
/// and the sender feeding that RPC's request stream.
type GRPCPutterState = (
    JoinHandle<Result<proto::PutDirectoryResponse, Status>>,
    UnboundedSender<proto::Directory>,
);

/// [DirectoryPutter] that streams directories to a remote directory service
/// over a single gRPC `put` call.
pub struct GRPCPutter {
    /// `Some` while open; taken (leaving `None`) when the putter is closed.
    rq: Option<GRPCPutterState>,
}
#[async_trait]
impl DirectoryPutter for GRPCPutter {
#[instrument(level = "trace", skip_all, fields(directory.digest=%directory.digest()), err)]
async fn put(&mut self, directory: Directory) -> Result<(), crate::Error> {
match self.rq {
Some((_, ref directory_sender)) => {
if directory_sender.send(directory.into()).is_err() {
self.close().await?;
}
Ok(())
}
None => Err(Error::StorageError(
"DirectoryPutter already closed".to_string(),
)),
}
}
#[instrument(level = "trace", skip_all, ret, err)]
async fn close(&mut self) -> Result<B3Digest, crate::Error> {
match std::mem::take(&mut self.rq) {
None => Err(Error::StorageError("already closed".to_string())),
Some((task, directory_sender)) => {
drop(directory_sender);
let root_digest = task
.await?
.map_err(|e| Error::StorageError(e.to_string()))?
.root_digest;
root_digest.try_into().map_err(|_| {
Error::StorageError("invalid root digest length in response".to_string())
})
}
}
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio_retry::{strategy::ExponentialBackoff, Retry};
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::{
        directoryservice::{DirectoryService, GRPCDirectoryService, MemoryDirectoryService},
        fixtures,
        proto::{directory_service_client::DirectoryServiceClient, GRPCDirectoryServiceWrapper},
    };

    /// Serves a default-constructed `MemoryDirectoryService` over a unix
    /// socket, connects a [GRPCDirectoryService] to it, and checks that
    /// looking up `DIRECTORY_A` yields `None` rather than an error.
    #[tokio::test]
    async fn test_valid_unix_path_ping_pong() {
        let tmpdir = TempDir::new().unwrap();
        let socket_path = tmpdir.path().join("daemon");

        // Background server task; it owns its own copy of the socket path.
        let server_socket_path = socket_path.clone();
        tokio::spawn(async move {
            let listener = UnixListener::bind(server_socket_path).unwrap();
            let incoming = UnixListenerStream::new(listener);
            let backend = Box::<MemoryDirectoryService>::default() as Box<dyn DirectoryService>;
            let mut builder = tonic::transport::Server::builder();
            builder
                .add_service(
                    crate::proto::directory_service_server::DirectoryServiceServer::new(
                        GRPCDirectoryServiceWrapper::new(backend),
                    ),
                )
                .serve_with_incoming(incoming)
                .await
        });

        // The socket is bound asynchronously; poll with backoff until it exists.
        Retry::spawn(
            ExponentialBackoff::from_millis(20).max_delay(Duration::from_secs(10)),
            || async {
                if !socket_path.exists() {
                    return Err(());
                }
                Ok(())
            },
        )
        .await
        .expect("failed to wait for socket");

        let url = url::Url::parse(&format!(
            "grpc+unix://{}?wait-connect=1",
            socket_path.display()
        ))
        .expect("must parse");
        let channel = crate::tonic::channel_from_url(&url)
            .await
            .expect("must succeed");
        let grpc_client = GRPCDirectoryService::from_client(DirectoryServiceClient::new(channel));

        let maybe_directory = grpc_client
            .get(&fixtures::DIRECTORY_A.digest())
            .await
            .expect("must not fail");
        assert!(maybe_directory.is_none())
    }
}