diff --git a/tower/src/builder/mod.rs b/tower/src/builder/mod.rs index 1906bfc68..12d7b3030 100644 --- a/tower/src/builder/mod.rs +++ b/tower/src/builder/mod.rs @@ -155,7 +155,7 @@ impl ServiceBuilder { pub fn option_layer( self, layer: Option, - ) -> ServiceBuilder, L>> { + ) -> ServiceBuilder, L>> { self.layer(crate::util::option_layer(layer)) } diff --git a/tower/src/util/mod.rs b/tower/src/util/mod.rs index a0622f334..9fe5501c0 100644 --- a/tower/src/util/mod.rs +++ b/tower/src/util/mod.rs @@ -16,6 +16,7 @@ mod map_result; mod map_future; mod oneshot; mod optional; +mod optional_layer; mod ready; mod service_fn; mod then; @@ -38,6 +39,7 @@ pub use self::{ map_result::{MapResult, MapResultLayer}, oneshot::Oneshot, optional::Optional, + optional_layer::{OptionLayer, OptionService}, ready::{Ready, ReadyOneshot}, service_fn::{service_fn, ServiceFn}, then::{Then, ThenLayer}, @@ -46,8 +48,6 @@ pub use self::{ pub use self::call_all::{CallAll, CallAllUnordered}; use std::future::Future; -use crate::layer::util::Identity; - #[cfg(feature = "buffer")] use crate::buffer::Buffer; @@ -69,6 +69,7 @@ pub mod future { pub use super::map_response::MapResponseFuture; pub use super::map_result::MapResultFuture; pub use super::optional::future as optional; + pub use super::optional_layer::ResponseFuture as OptionResponseFuture; pub use super::then::ThenFuture; } @@ -1077,6 +1078,10 @@ impl ServiceExt for T where T: tower_service::Servi /// Convert an `Option` into a [`Layer`]. /// +/// The returned [`OptionLayer`] unifies the error types of the layered and +/// unlayered branches to [`BoxError`], so the optional layer is allowed to +/// change the error type. +/// /// ``` /// # use std::time::Duration; /// # use tower::Service; @@ -1095,10 +1100,8 @@ impl ServiceExt for T where T: tower_service::Servi /// ``` /// /// [`Layer`]: crate::layer::Layer -pub fn option_layer(layer: Option) -> Either { - if let Some(layer) = layer { - Either::Left(layer) - } else { - Either::Right(Identity::new()) - } +/// [`OptionLayer`]: crate::util::OptionLayer +/// [`BoxError`]: crate::BoxError +pub fn option_layer(layer: Option) -> OptionLayer { + OptionLayer::new(layer) } diff --git a/tower/src/util/optional_layer.rs b/tower/src/util/optional_layer.rs new file mode 100644 index 000000000..3a2835ae3 --- /dev/null +++ b/tower/src/util/optional_layer.rs @@ -0,0 +1,143 @@ +//! A [`Layer`] that is enabled or disabled by an [`Option`]. +//! +//! See [`OptionLayer`] and [`option_layer`] for more details. +//! +//! [`option_layer`]: crate::util::option_layer + +use crate::BoxError; +use pin_project_lite::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// A [`Layer`] that optionally applies an inner layer `L`. +/// +/// This is the layer produced by [`option_layer`]. When the inner layer is +/// present, the resulting service is the layered service; when it is absent, +/// the resulting service is the unmodified service. +/// +/// Unlike branching with [`Either`] directly, [`OptionLayer`] unifies the error +/// types of the two branches to [`BoxError`]. This means the optional layer is +/// allowed to change the error type (as [`TimeoutLayer`] does, for example) +/// without the two branches needing to share an error type. +/// +/// [`option_layer`]: crate::util::option_layer +/// [`Either`]: crate::util::Either +/// [`BoxError`]: crate::BoxError +/// [`TimeoutLayer`]: crate::timeout::TimeoutLayer +#[derive(Clone, Copy, Debug)] +pub struct OptionLayer { + layer: Option, +} + +impl OptionLayer { + /// Create a new [`OptionLayer`] wrapping the given optional layer. + pub const fn new(layer: Option) -> Self { + OptionLayer { layer } + } +} + +impl From> for OptionLayer { + fn from(layer: Option) -> Self { + OptionLayer::new(layer) + } +} + +impl Layer for OptionLayer +where + L: Layer, +{ + type Service = OptionService; + + fn layer(&self, inner: S) -> Self::Service { + match &self.layer { + Some(layer) => OptionService::Some(layer.layer(inner)), + None => OptionService::None(inner), + } + } +} + +/// The [`Service`] produced by [`OptionLayer`]. +/// +/// Its error type is [`BoxError`], erasing any difference between the layered +/// and unlayered branches' error types. +/// +/// [`BoxError`]: crate::BoxError +#[derive(Clone, Copy, Debug)] +pub enum OptionService { + /// The inner layer was present; the layered service. + Some(A), + /// The inner layer was absent; the unmodified service. + None(B), +} + +impl Service for OptionService +where + A: Service, + A::Error: Into, + B: Service, + B::Error: Into, +{ + type Response = A::Response; + type Error = BoxError; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + OptionService::Some(service) => service.poll_ready(cx).map_err(Into::into), + OptionService::None(service) => service.poll_ready(cx).map_err(Into::into), + } + } + + fn call(&mut self, request: Request) -> Self::Future { + match self { + OptionService::Some(service) => ResponseFuture { + kind: Kind::Some { + inner: service.call(request), + }, + }, + OptionService::None(service) => ResponseFuture { + kind: Kind::None { + inner: service.call(request), + }, + }, + } + } +} + +pin_project! { + /// Response future for [`OptionService`]. + pub struct ResponseFuture { + #[pin] + kind: Kind, + } +} + +pin_project! { + #[project = KindProj] + enum Kind { + Some { #[pin] inner: A }, + None { #[pin] inner: B }, + } +} + +impl Future for ResponseFuture +where + A: Future>, + AE: Into, + B: Future>, + BE: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().kind.project() { + KindProj::Some { inner } => inner.poll(cx).map_err(Into::into), + KindProj::None { inner } => inner.poll(cx).map_err(Into::into), + } + } +} diff --git a/tower/tests/util/main.rs b/tower/tests/util/main.rs index 18b7813ff..7e79c65e9 100644 --- a/tower/tests/util/main.rs +++ b/tower/tests/util/main.rs @@ -3,6 +3,7 @@ mod call_all; mod oneshot; +mod option_layer; mod service_fn; #[path = "../support.rs"] pub(crate) mod support; diff --git a/tower/tests/util/option_layer.rs b/tower/tests/util/option_layer.rs new file mode 100644 index 000000000..d462249fc --- /dev/null +++ b/tower/tests/util/option_layer.rs @@ -0,0 +1,25 @@ +use std::convert::Infallible; + +use tower::util::{option_layer, MapErrLayer}; +use tower::{Layer, Service, ServiceExt}; + +// Regression test for #665: `option_layer` previously returned an +// `Either` whose service required both branches to share an error +// type. When the optional layer changed the error type, the result did not +// implement `Service` at all. `option_layer` now unifies both branches' errors +// to `BoxError`. +#[tokio::test] +async fn option_layer_unifies_branch_errors() { + let inner = tower::service_fn(|()| async { Ok::<_, Infallible>(()) }); + + // The optional layer changes the error type (`Infallible` -> `String`), + // which differs from the unlayered branch's error type (`Infallible`). + let layer = option_layer(Some(MapErrLayer::new(|e: Infallible| -> String { + match e {} + }))); + + let mut svc = layer.layer(inner); + + let response = svc.ready().await.unwrap().call(()).await; + assert!(response.is_ok()); +}