1use std::{future::Future, ops::DerefMut, sync::Arc};
2
3use bytes::Bytes;
4use tokio::{
5 io::{split, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
6 sync::Mutex,
7};
8use tracing::{debug, warn};
9
10use super::{
11 framing::{NixFramedReader, StderrReadFramedReader},
12 types::{AddToStoreNarRequest, QueryValidPaths},
13 worker_protocol::{server_handshake_client, ClientSettings, Operation, Trust, STDERR_LAST},
14 NixDaemonIO,
15};
16
17use crate::{
18 store_path::StorePath,
19 wire::{
20 de::{NixRead, NixReader},
21 ser::{NixSerialize, NixWrite, NixWriter, NixWriterBuilder},
22 ProtocolVersion,
23 },
24};
25
26use crate::{nix_daemon::types::NixError, worker_protocol::STDERR_ERROR};
27
28#[allow(dead_code)]
39pub struct NixDaemon<IO, R, W> {
40 io: Arc<IO>,
41 protocol_version: ProtocolVersion,
42 client_settings: ClientSettings,
43 reader: NixReader<R>,
44 writer: Arc<Mutex<NixWriter<W>>>,
45}
46
47impl<IO, R, W> NixDaemon<IO, R, W>
48where
49 IO: NixDaemonIO + Sync + Send,
50{
51 pub fn new(
52 io: Arc<IO>,
53 protocol_version: ProtocolVersion,
54 client_settings: ClientSettings,
55 reader: NixReader<R>,
56 writer: NixWriter<W>,
57 ) -> Self {
58 Self {
59 io,
60 protocol_version,
61 client_settings,
62 reader,
63 writer: Arc::new(Mutex::new(writer)),
64 }
65 }
66}
67
68impl<IO, RW> NixDaemon<IO, ReadHalf<RW>, WriteHalf<RW>>
69where
70 RW: AsyncReadExt + AsyncWriteExt + Send + Unpin + 'static,
71 IO: NixDaemonIO + Sync + Send,
72{
73 pub async fn initialize(io: Arc<IO>, mut connection: RW) -> Result<Self, std::io::Error>
80 where
81 RW: AsyncReadExt + AsyncWriteExt + Send + Unpin,
82 {
83 let protocol_version =
84 server_handshake_client(&mut connection, "2.18.2", Trust::Trusted).await?;
85
86 connection.write_u64_le(STDERR_LAST).await?;
87 let (reader, writer) = split(connection);
88 let mut reader = NixReader::builder()
89 .set_version(protocol_version)
90 .build(reader);
91 let mut writer = NixWriterBuilder::default()
92 .set_version(protocol_version)
93 .build(writer);
94
95 let operation: Operation = reader.read_value().await?;
97 if operation != Operation::SetOptions {
98 return Err(std::io::Error::other(
99 "Expected SetOptions operation, but got {operation}",
100 ));
101 }
102 let client_settings: ClientSettings = reader.read_value().await?;
103 writer.write_number(STDERR_LAST).await?;
104 writer.flush().await?;
105
106 Ok(Self::new(
107 io,
108 protocol_version,
109 client_settings,
110 reader,
111 writer,
112 ))
113 }
114
115 pub async fn handle_client(&mut self) -> Result<(), std::io::Error> {
117 let io = self.io.clone();
118 loop {
119 let op_code = self.reader.read_number().await?;
120 match TryInto::<Operation>::try_into(op_code) {
121 Ok(operation) => match operation {
123 Operation::IsValidPath => {
124 let path: StorePath<String> = self.reader.read_value().await?;
125 Self::handle(&self.writer, io.is_valid_path(&path)).await?
126 }
127 Operation::SetOptions => {
132 self.client_settings = self.reader.read_value().await?;
133 Self::handle(&self.writer, async { Ok(()) }).await?
134 }
135 Operation::QueryPathInfo => {
136 let path: StorePath<String> = self.reader.read_value().await?;
137 Self::handle(&self.writer, io.query_path_info(&path)).await?
138 }
139 Operation::QueryPathFromHashPart => {
140 let hash: Bytes = self.reader.read_value().await?;
141 Self::handle(&self.writer, io.query_path_from_hash_part(&hash)).await?
142 }
143 Operation::QueryValidPaths => {
144 let query: QueryValidPaths = self.reader.read_value().await?;
145 Self::handle(&self.writer, io.query_valid_paths(&query)).await?
146 }
147 Operation::QueryValidDerivers => {
148 let path: StorePath<String> = self.reader.read_value().await?;
149 Self::handle(&self.writer, io.query_valid_derivers(&path)).await?
150 }
151 Operation::QueryReferrers | Operation::QueryRealisation => {
158 let _: String = self.reader.read_value().await?;
159 Self::handle(&self.writer, async move {
160 warn!(
161 ?operation,
162 "This operation is not implemented. Returning empty result..."
163 );
164 Ok(Vec::<StorePath<String>>::new())
165 })
166 .await?
167 }
168 Operation::AddToStoreNar => {
169 let request: AddToStoreNarRequest = self.reader.read_value().await?;
170 let minor_version = self.protocol_version.minor();
171 match minor_version {
172 ..21 => {
173 Self::handle(
176 &self.writer,
177 self.io.add_to_store_nar(request, &mut self.reader),
178 )
179 .await?
180 }
181 21..23 => {
182 Self::handle(&self.writer, async {
184 let mut writer = self.writer.lock().await;
185 let mut reader = StderrReadFramedReader::new(
186 &mut self.reader,
187 writer.deref_mut(),
188 );
189 self.io.add_to_store_nar(request, &mut reader).await
190 })
191 .await?
192 }
193 23.. => {
194 let mut framed = NixFramedReader::new(&mut self.reader);
196 Self::handle(&self.writer, async {
197 self.io.add_to_store_nar(request, &mut framed).await
198 })
199 .await?
200 }
201 }
202 }
203 _ => {
204 return Err(std::io::Error::other(format!(
205 "Operation {operation:?} is not implemented"
206 )));
207 }
208 },
209 _ => {
210 return Err(std::io::Error::other(format!(
211 "Unknown operation code received: {op_code}"
212 )));
213 }
214 }
215 }
216 }
217
218 async fn handle<T>(
228 writer: &Arc<Mutex<NixWriter<WriteHalf<RW>>>>,
229 future: impl Future<Output = std::io::Result<T>>,
230 ) -> Result<(), std::io::Error>
231 where
232 T: NixSerialize + Send,
233 {
234 let result = future.await;
235 let mut writer = writer.lock().await;
236
237 match result {
238 Ok(r) => {
239 writer.write_number(STDERR_LAST).await?;
242 writer.write_value(&r).await?;
243 writer.flush().await
244 }
245 Err(e) => {
246 debug!(err = ?e, "IO error");
247 writer.write_number(STDERR_ERROR).await?;
248 writer.write_value(&NixError::new(format!("{e:?}"))).await?;
249 writer.flush().await
250 }
251 }
252 }
253}
254
255#[cfg(test)]
256mod tests {
257 use super::*;
258 use std::{io::ErrorKind, sync::Arc};
259
260 use mockall::predicate;
261 use tokio::io::AsyncWriteExt;
262
263 use crate::{
264 nix_daemon::MockNixDaemonIO,
265 wire::ProtocolVersion,
266 worker_protocol::{ClientSettings, WORKER_MAGIC_1, WORKER_MAGIC_2},
267 };
268
269 #[tokio::test]
270 async fn test_daemon_initialization() {
271 let mut builder = tokio_test::io::Builder::new();
272 let test_conn = builder
273 .read(&WORKER_MAGIC_1.to_le_bytes())
274 .write(&WORKER_MAGIC_2.to_le_bytes())
275 .write(&[37, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
277 .read(&[35, 1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
279 .read(&[0; 8])
281 .read(&[0; 8])
283 .write(&[0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00])
285 .write(&[50, 46, 49, 56, 46, 50, 0, 0])
287 .write(&[1, 0, 0, 0, 0, 0, 0, 0])
289 .write(&[115, 116, 108, 97, 0, 0, 0, 0]);
291
292 let mut bytes = Vec::new();
293 let mut writer = NixWriter::new(&mut bytes);
294 writer
295 .write_value(&ClientSettings::default())
296 .await
297 .unwrap();
298 writer.flush().await.unwrap();
299
300 let test_conn = test_conn
301 .read(&[19, 0, 0, 0, 0, 0, 0, 0])
303 .read(&bytes)
304 .write(&[115, 116, 108, 97, 0, 0, 0, 0])
306 .build();
307
308 let mock = MockNixDaemonIO::new();
309 let daemon = NixDaemon::initialize(Arc::new(mock), test_conn)
310 .await
311 .unwrap();
312 assert_eq!(daemon.client_settings, ClientSettings::default());
313 assert_eq!(daemon.protocol_version, ProtocolVersion::from_parts(1, 35));
314 }
315
316 async fn serialize<T>(req: &T, protocol_version: ProtocolVersion) -> Vec<u8>
317 where
318 T: NixSerialize + Send,
319 {
320 let mut result: Vec<u8> = Vec::new();
321 let mut w = NixWriter::builder()
322 .set_version(protocol_version)
323 .build(&mut result);
324 w.write_value(req).await.unwrap();
325 w.flush().await.unwrap();
326 result
327 }
328
329 async fn respond<T>(
330 resp: &Result<T, std::io::Error>,
331 protocol_version: ProtocolVersion,
332 ) -> Vec<u8>
333 where
334 T: NixSerialize + Send,
335 {
336 let mut result: Vec<u8> = Vec::new();
337 let mut w = NixWriter::builder()
338 .set_version(protocol_version)
339 .build(&mut result);
340 match resp {
341 Ok(value) => {
342 w.write_value(&STDERR_LAST).await.unwrap();
343 w.write_value(value).await.unwrap();
344 }
345 Err(e) => {
346 w.write_value(&STDERR_ERROR).await.unwrap();
347 w.write_value(&NixError::new(format!("{e:?}")))
348 .await
349 .unwrap();
350 }
351 }
352 w.flush().await.unwrap();
353 result
354 }
355
356 #[tokio::test]
357 async fn test_handle_is_valid_path_ok() {
358 let version = ProtocolVersion::from_parts(1, 37);
359 let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
360 let mut mock = MockNixDaemonIO::new();
361 let (reader, writer) = split(io);
362 let path: StorePath<String> = StorePath::<String>::from_absolute_path(
363 "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
364 )
365 .unwrap();
366 mock.expect_is_valid_path()
367 .with(predicate::eq(path.clone()))
368 .times(1)
369 .returning(|_| Box::pin(async { Ok(true) }));
370
371 handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
372 handle.read(&serialize(&path, version).await);
373 handle.write(&respond(&Ok(true), version).await);
374 drop(handle);
375
376 let mut daemon = NixDaemon::new(
377 Arc::new(mock),
378 version,
379 ClientSettings::default(),
380 NixReader::new(reader),
381 NixWriter::new(writer),
382 );
383 assert_eq!(
384 ErrorKind::UnexpectedEof,
385 daemon
386 .handle_client()
387 .await
388 .expect_err("Expecting eof")
389 .kind()
390 );
391 }
392
393 #[tokio::test]
394 async fn test_handle_is_valid_path_err() {
395 let version = ProtocolVersion::from_parts(1, 37);
396 let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
397 let mut mock = MockNixDaemonIO::new();
398 let (reader, writer) = split(io);
399 let path: StorePath<String> = StorePath::<String>::from_absolute_path(
400 "/nix/store/33l4p0pn0mybmqzaxfkpppyh7vx1c74p-hello-2.12.1".as_bytes(),
401 )
402 .unwrap();
403 mock.expect_is_valid_path()
404 .with(predicate::eq(path.clone()))
405 .times(1)
406 .returning(|_| Box::pin(async { Err(std::io::Error::other("hello")) }));
407
408 handle.read(&Into::<u64>::into(Operation::IsValidPath).to_le_bytes());
409 handle.read(&serialize(&path, version).await);
410 handle.write(&respond::<bool>(&Err(std::io::Error::other("hello")), version).await);
411 drop(handle);
412
413 let mut daemon = NixDaemon::new(
414 Arc::new(mock),
415 version,
416 ClientSettings::default(),
417 NixReader::new(reader),
418 NixWriter::new(writer),
419 );
420 assert_eq!(
421 ErrorKind::UnexpectedEof,
422 daemon
423 .handle_client()
424 .await
425 .expect_err("Expecting eof")
426 .kind()
427 );
428 }
429}