use crate::pb::health_server::{Health, HealthServer};
use crate::pb::{HealthCheckRequest, HealthCheckResponse};
use crate::ServingStatus;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::{watch, RwLock};
use tokio_stream::Stream;
#[cfg(feature = "transport")]
use tonic::server::NamedService;
use tonic::{Request, Response, Status};
pub fn health_reporter() -> (HealthReporter, HealthServer<impl Health>) {
let reporter = HealthReporter::new();
let service = HealthService::new(reporter.statuses.clone());
let server = HealthServer::new(service);
(reporter, server)
}
/// Status channel for a single service.
///
/// The receiver half is stored next to the sender for two reasons: it keeps the channel open for
/// as long as the service is registered, and it is the handle that new `watch` calls clone.
type StatusPair = (watch::Sender<ServingStatus>, watch::Receiver<ServingStatus>);
/// Handle used to update the serving status of services exposed by the [`HealthServer`]
/// returned from [`health_reporter`].
#[derive(Clone, Debug)]
pub struct HealthReporter {
// Service name -> status channel, shared with `HealthService` through the `Arc`.
// Cloning the reporter clones the `Arc`, so every clone updates the same table.
statuses: Arc<RwLock<HashMap<String, StatusPair>>>,
}
impl HealthReporter {
    /// Builds a reporter whose table holds only the empty service name `""`, set to `Serving`.
    fn new() -> Self {
        let mut statuses = HashMap::new();
        statuses.insert(String::new(), watch::channel(ServingStatus::Serving));
        Self {
            statuses: Arc::new(RwLock::new(statuses)),
        }
    }

    /// Marks the service named by `S`'s [`NamedService::NAME`] as `Serving`.
    #[cfg(feature = "transport")]
    pub async fn set_serving<S>(&mut self)
    where
        S: NamedService,
    {
        self.set_service_status(S::NAME, ServingStatus::Serving)
            .await;
    }

    /// Marks the service named by `S`'s [`NamedService::NAME`] as `NotServing`.
    #[cfg(feature = "transport")]
    pub async fn set_not_serving<S>(&mut self)
    where
        S: NamedService,
    {
        self.set_service_status(S::NAME, ServingStatus::NotServing)
            .await;
    }

    /// Sets the status of `service_name`, registering the service if it is not yet known.
    ///
    /// For an already-registered service, the new value is sent on its existing channel, so any
    /// open `watch` streams for that service are notified.
    ///
    /// # Panics
    ///
    /// Panics if the service's channel is closed. That is not expected to happen, because the
    /// table keeps a receiver alive next to every sender.
    pub async fn set_service_status<S>(&mut self, service_name: S, status: ServingStatus)
    where
        S: AsRef<str>,
    {
        let name = service_name.as_ref();
        let mut table = self.statuses.write().await;
        if let Some((sender, _)) = table.get(name) {
            sender.send(status).expect("channel should not be closed");
        } else {
            table.insert(name.to_string(), watch::channel(status));
        }
    }

    /// Unregisters `service_name`. Unknown names are ignored.
    ///
    /// Removing the entry drops its sender. Open `watch` streams for the service then see their
    /// `changed()` fail, and they terminate.
    pub async fn clear_service_status(&mut self, service_name: &str) {
        let _ = self.statuses.write().await.remove(service_name);
    }
}
/// gRPC `Health` service implementation that answers from the status table shared with a
/// [`HealthReporter`].
#[derive(Debug)]
pub struct HealthService {
// Same table the reporter writes to; this side only reads it.
statuses: Arc<RwLock<HashMap<String, StatusPair>>>,
}
impl HealthService {
fn new(services: Arc<RwLock<HashMap<String, StatusPair>>>) -> Self {
HealthService { statuses: services }
}
async fn service_health(&self, service_name: &str) -> Option<ServingStatus> {
let reader = self.statuses.read().await;
reader.get(service_name).map(|p| *p.1.borrow())
}
}
#[tonic::async_trait]
impl Health for HealthService {
async fn check(
&self,
request: Request<HealthCheckRequest>,
) -> Result<Response<HealthCheckResponse>, Status> {
let service_name = request.get_ref().service.as_str();
let status = self.service_health(service_name).await;
match status {
None => Err(Status::not_found("service not registered")),
Some(status) => Ok(Response::new(HealthCheckResponse {
status: crate::pb::health_check_response::ServingStatus::from(status) as i32,
})),
}
}
type WatchStream =
Pin<Box<dyn Stream<Item = Result<HealthCheckResponse, Status>> + Send + 'static>>;
async fn watch(
&self,
request: Request<HealthCheckRequest>,
) -> Result<Response<Self::WatchStream>, Status> {
let service_name = request.get_ref().service.as_str();
let mut status_rx = match self.statuses.read().await.get(service_name) {
None => return Err(Status::not_found("service not registered")),
Some(pair) => pair.1.clone(),
};
let output = async_stream::try_stream! {
let status = crate::pb::health_check_response::ServingStatus::from(*status_rx.borrow()) as i32;
yield HealthCheckResponse { status };
#[allow(clippy::redundant_pattern_matching)]
while let Ok(_) = status_rx.changed().await {
let status = crate::pb::health_check_response::ServingStatus::from(*status_rx.borrow()) as i32;
yield HealthCheckResponse { status };
}
};
Ok(Response::new(Box::pin(output) as Self::WatchStream))
}
}
#[cfg(test)]
mod tests {
use crate::pb::health_server::Health;
use crate::pb::HealthCheckRequest;
use crate::server::{HealthReporter, HealthService};
use crate::ServingStatus;
use tokio::sync::watch;
use tokio_stream::StreamExt;
use tonic::{Code, Request, Status};
fn assert_serving_status(wire: i32, expected: ServingStatus) {
let expected = crate::pb::health_check_response::ServingStatus::from(expected) as i32;
assert_eq!(wire, expected);
}
fn assert_grpc_status(wire: Option<Status>, expected: Code) {
let wire = wire.expect("status is not None").code();
assert_eq!(wire, expected);
}
async fn make_test_service() -> (HealthReporter, HealthService) {
let health_reporter = HealthReporter::new();
{
let mut statuses = health_reporter.statuses.write().await;
statuses.insert(
"TestService".to_string(),
watch::channel(ServingStatus::Unknown),
);
}
let health_service = HealthService::new(health_reporter.statuses.clone());
(health_reporter, health_service)
}
#[tokio::test]
async fn test_service_check() {
let (mut reporter, service) = make_test_service().await;
let resp = service
.check(Request::new(HealthCheckRequest {
service: "".to_string(),
}))
.await;
assert!(resp.is_ok());
let resp = resp.unwrap().into_inner();
assert_serving_status(resp.status, ServingStatus::Serving);
let resp = service
.check(Request::new(HealthCheckRequest {
service: "Unregistered".to_string(),
}))
.await;
assert!(resp.is_err());
assert_grpc_status(resp.err(), Code::NotFound);
let resp = service
.check(Request::new(HealthCheckRequest {
service: "TestService".to_string(),
}))
.await;
assert!(resp.is_ok());
let resp = resp.unwrap().into_inner();
assert_serving_status(resp.status, ServingStatus::Unknown);
reporter
.set_service_status("TestService", ServingStatus::Serving)
.await;
let resp = service
.check(Request::new(HealthCheckRequest {
service: "TestService".to_string(),
}))
.await;
assert!(resp.is_ok());
let resp = resp.unwrap().into_inner();
assert_serving_status(resp.status, ServingStatus::Serving);
}
#[tokio::test]
async fn test_service_watch() {
let (mut reporter, service) = make_test_service().await;
let resp = service
.watch(Request::new(HealthCheckRequest {
service: "".to_string(),
}))
.await;
assert!(resp.is_ok());
let mut resp = resp.unwrap().into_inner();
let item = resp
.next()
.await
.expect("streamed response is Some")
.expect("response is ok");
assert_serving_status(item.status, ServingStatus::Serving);
let resp = service
.watch(Request::new(HealthCheckRequest {
service: "Unregistered".to_string(),
}))
.await;
assert!(resp.is_err());
assert_grpc_status(resp.err(), Code::NotFound);
let resp = service
.watch(Request::new(HealthCheckRequest {
service: "TestService".to_string(),
}))
.await;
assert!(resp.is_ok());
let mut resp = resp.unwrap().into_inner();
let item = resp
.next()
.await
.expect("streamed response is Some")
.expect("response is ok");
assert_serving_status(item.status, ServingStatus::Unknown);
reporter
.set_service_status("TestService", ServingStatus::NotServing)
.await;
let item = resp
.next()
.await
.expect("streamed response is Some")
.expect("response is ok");
assert_serving_status(item.status, ServingStatus::NotServing);
reporter
.set_service_status("TestService", ServingStatus::Serving)
.await;
let item = resp
.next()
.await
.expect("streamed response is Some")
.expect("response is ok");
assert_serving_status(item.status, ServingStatus::Serving);
reporter.clear_service_status("TestService").await;
let item = resp.next().await;
assert!(item.is_none());
}
}