use std::{future::Future, ops::DerefMut, sync::Arc};
use bytes::Bytes;
use tokio::{
io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
sync::Mutex,
};
use tracing::{debug, warn};
use super::{
framing::{NixFramedReader, StderrReadFramedReader},
types::{AddToStoreNarRequest, QueryValidPaths},
worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST},
NixDaemonIO,
};
use crate::{
store_path::StorePath,
wire::{
de::{NixRead, NixReader},
ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder},
ProtocolVersion,
},
};
use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR};
#[allow(dead_code)]
pub struct NixDaemon<IO, R, W> {
io: Arc<IO>,
protocol_version: ProtocolVersion,
client_settings: ClientSettings,
reader: NixReader<R>,
writer: Arc<Mutex<NixWriter<W>>>,
}
impl<IO, R, W> NixDaemon<IO, R, W>
where
IO: NixDaemonIO + Sync + Send,
{
pub fn new(
io: Arc<IO>,
protocol_version: ProtocolVersion,
client_settings: ClientSettings,
reader: NixReader<R>,
writer: NixWriter<W>,
) -> Self {
Self {
io,
protocol_version,
client_settings,
reader,
writer: Arc::new(Mutex::new(writer)),
}
}
}
impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>>
where
RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
IO: NixDaemonIO + Sync + Send,
{
pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error>
where
RW: AsyncReadExt + AsyncWriteExt + Send + Unpin,
{
let protocol_version =
server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?;
connection.write_u64_le(STDERR_LAST).await?;
let (reader, writer) = split(connection);
let mut reader = NixReader::builder()
.set_version(protocol_version)
.build(reader);
let mut writer = NixWriterBuilder::default()
.set_version(protocol_version)
.build(writer);
let operation: Operation = reader.read_value().await?;
if operation != Operation::SetOptions {
return Err(std::io::Error::other(
"Expected SetOptions operation, but got {operation}",
));
}
let client_settings: ClientSettings = reader.read_value().await?;
writer.write_number(STDERR_LAST).await?;
writer.flush().await?;
Ok(Self::new(
io,
protocol_version,
client_settings,
reader,
writer,
))
}
pub async fn handle_client(&mut self) -> Result<(), std::io::Error> {
let io = self.io.clone();
loop {
let op_code = self.reader.read_number().await?;
match TryInto::<Operation>::try_into(op_code) {
Ok(operation) => match operation {
Operation::IsValidPath => {
let path: StorePath<String> = self.reader.read_value().await?;
Self::handle(&self.writer, io.is_valid_path(&path)).await?
}
Operation::SetOptions => {
self.client_settings = self.reader.read_value().await?;
Self::handle(&self.writer, async { Ok(()) }).await?
}
Operation::QueryPathInfo => {
let path: StorePath<String> = self.reader.read_value().await?;
Self::handle(&self.writer, io.query_path_info(&path)).await?
}
Operation::QueryPathFromHashPart => {
let hash: Bytes = self.reader.read_value().await?;
Self::handle(&self.writer, io.query_path_from_hash_part(&hash)).await?
}
Operation::QueryValidPaths => {
let query: QueryValidPaths = self.reader.read_value().await?;
Self::handle(&self.writer, io.query_valid_paths(&query)).await?
}
Operation::QueryValidDerivers => {
let path: StorePath<String> = self.reader.read_value().await?;
Self::handle(&self.writer, io.query_valid_derivers(&path)).await?
}
Operation::QueryReferrers | Operation::QueryRealisation => {
let _: String = self.reader.read_value().await?;
Self::handle(&self.writer, async move {
warn!(
?operation,
"This operation is not implemented. Returning empty result..."
);
Ok(Vec::<StorePath<String>>::new())
})
.await?
}
Operation::AddToStoreNar => {
let request: AddToStoreNarRequest = self.reader.read_value().await?;
let minor_version = self.protocol_version.minor();
match minor_version {
..21 => {
Self::handle(
&self.writer,
self.io.add_to_store_nar(request, &mut self.reader),
)
.await?
}
21..23 => {
Self::handle(&self.writer, async {
let mut writer = self.writer.lock().await;
let mut reader = StderrReadFramedReader::new(
&mut self.reader,
writer.deref_mut(),
);
self.io.add_to_store_nar(request, &mut reader).await
})
.await?
}
23.. => {
let mut framed = NixFramedReader::new(&mut self.reader);
Self::handle(&self.writer, async {
self.io.add_to_store_nar(request, &mut framed).await
})
.await?
}
}
}
_ => {
return Err(std::io::Error::other(format!(
"Operation {operation:?} is not implemented"
)));
}
},
_ => {
return Err(std::io::Error::other(format!(
"Unknown operation code received: {op_code}"
)));
}
}
}
}
async fn handle<T>(
writer: &Arc<Mutex<NixWriter<WriteHalf<RW>>>>,
future: impl Future<Output = std::io::Result<T>>,
) -> Result<(), std::io::Error>
where
T: NixSerialize + Send,
{
let result = future.await;
let mut writer = writer.lock().await;
match result {
Ok(r) => {
writer.write_number(STDERR_LAST).await?;
writer.write_value(&r).await?;
writer.flush().await
}
Err(e) => {
debug!(err = ?e, "IO error");
writer.write_number(STDERR_ERROR).await?;
writer.write_value(&NixError::new(format!("{e:?}"))).await?;
writer.flush().await
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{io::ErrorKind, sync::Arc};
use mockall::predicate;
use tokio::io::AsyncWriteExt;
use crate::{
nix_daemon::MockNixDaemonIO,
wire::ProtocolVersion,
worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
};
#[tokio::test]
async fn test_daemon_initialization() {
let mut builder = tokio_test::io::Builder::new();
let test_conn = builder
.read(&WORKER_MAGIC_1.to_le_bytes())
.write(&WORKER_MAGIC_2.to_le_bytes())
.write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
.read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
.read(&[0; 8])
.read(&[0; 8])
.write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
.write(&[50, 46, 49, 56, 46, 50, 0, 0])
.write(&[1, 0, 0, 0, 0, 0, 0, 0])
.write(&[115, 116, 108, 97, 0, 0, 0, 0]);
let mut bytes = Vec::new();
let mut writer = NixWriter::new(&mut bytes);
writer
.write_value(&ClientSettings::default())
.await
.unwrap();
writer.flush().await.unwrap();
let test_conn = test_conn
.read(&[19, 0, 0, 0, 0, 0, 0, 0])
.read(&bytes)
.write(&[115, 116, 108, 97, 0, 0, 0, 0])
.build();
let mock = MockNixDaemonIO::new();
let daemon = NixDaemon::initialize(Arc::new(mock), test_conn)
.await
.unwrap();
assert_eq!(daemon.client_settings, ClientSettings::default());
assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
}
async fn serialize<T>(req: &T, protocol_version: ProtocolVersion) -> Vec<u8>
where
T: NixSerialize + Send,
{
let mut result: Vec<u8> = Vec::new();
let mut w = NixWriter::builder()
.set_version(protocol_version)
.build(&mut result);
w.write_value(req).await.unwrap();
w.flush().await.unwrap();
result
}
async fn respond<T>(
resp: &Result<T, std::io::Error>,
protocol_version: ProtocolVersion,
) -> Vec<u8>
where
T: NixSerialize + Send,
{
let mut result: Vec<u8> = Vec::new();
let mut w = NixWriter::builder()
.set_version(protocol_version)
.build(&mut result);
match resp {
Ok(value) => {
w.write_value(&STDERR_LAST).await.unwrap();
w.write_value(value).await.unwrap();
}
Err(e) => {
w.write_value(&STDERR_ERROR).await.unwrap();
w.write_value(&NixError::new(format!("{:?}", e)))
.await
.unwrap();
}
}
w.flush().await.unwrap();
result
}
#[tokio::test]
async fn test_handle_is_valid_path_ok() {
let version = ProtocolVersion::from_parts(1, 37);
let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
let mut mock = MockNixDaemonIO::new();
let (reader, writer) = split(io);
let path: StorePath<String> = StorePath::<String>::from_absolute_path(
"/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
)
.unwrap();
mock.expect_is_valid_path()
.with(predicate::eq(path.clone()))
.times(1)
.returning(|_| Box::pin(async { Ok(true) }));
handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
handle.read(&serialize(&path, version).await);
handle.write(&respond(&Ok(true), version).await);
drop(handle);
let mut daemon = NixDaemon::new(
Arc::new(mock),
version,
ClientSettings::default(),
NixReader::new(reader),
NixWriter::new(writer),
);
assert_eq!(
ErrorKind::UnexpectedEof,
daemon
.handle_client()
.await
.expect_err("Expecting eof")
.kind()
);
}
#[tokio::test]
async fn test_handle_is_valid_path_err() {
let version = ProtocolVersion::from_parts(1, 37);
let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
let mut mock = MockNixDaemonIO::new();
let (reader, writer) = split(io);
let path: StorePath<String> = StorePath::<String>::from_absolute_path(
"/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
)
.unwrap();
mock.expect_is_valid_path()
.with(predicate::eq(path.clone()))
.times(1)
.returning(|_| Box::pin(async { Err(std::io::Error::other("hello")) }));
handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
handle.read(&serialize(&path, version).await);
handle.write(&respond::<bool>(&Err(std::io::Error::other("hello")), version).await);
drop(handle);
let mut daemon = NixDaemon::new(
Arc::new(mock),
version,
ClientSettings::default(),
NixReader::new(reader),
NixWriter::new(writer),
);
assert_eq!(
ErrorKind::UnexpectedEof,
daemon
.handle_client()
.await
.expect_err("Expecting eof")
.kind()
);
}
}