// tower_http/set_header/request.rs

//! Set a header on the request.
//!
//! The header value to be set may be provided as a fixed value when the
//! middleware is constructed, or computed dynamically from each request by a
//! closure. See the [`MakeHeaderValue`] trait for details.
//!
//! # Example
//!
//! Setting a header from a fixed value provided when the middleware is constructed:
//!
//! ```
//! use http::{Request, Response, header::{self, HeaderValue}};
//! use tower::{Service, ServiceExt, ServiceBuilder};
//! use tower_http::set_header::SetRequestHeaderLayer;
//! use http_body_util::Full;
//! use bytes::Bytes;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # let http_client = tower::service_fn(|_: Request<Full<Bytes>>| async move {
//! #     Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default()))
//! # });
//! #
//! let mut svc = ServiceBuilder::new()
//!     .layer(
//!         // Layer that sets `User-Agent: my very cool app` on requests.
//!         //
//!         // `if_not_present` will only insert the header if the request does not
//!         // already have a value for it.
//!         SetRequestHeaderLayer::if_not_present(
//!             header::USER_AGENT,
//!             HeaderValue::from_static("my very cool app"),
//!         )
//!     )
//!     .service(http_client);
//!
//! let request = Request::new(Full::default());
//!
//! let response = svc.ready().await?.call(request).await?;
//! #
//! # Ok(())
//! # }
//! ```
//!
//! Setting a header based on a value determined dynamically from the request:
//!
//! ```
//! use http::{Request, Response, header::{self, HeaderValue}};
//! use tower::{Service, ServiceExt, ServiceBuilder};
//! use tower_http::set_header::SetRequestHeaderLayer;
//! use bytes::Bytes;
//! use http_body_util::Full;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # let http_client = tower::service_fn(|_: Request<Full<Bytes>>| async move {
//! #     Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default()))
//! # });
//! fn date_header_value() -> HeaderValue {
//!     // ...
//!     # HeaderValue::from_static("now")
//! }
//!
//! let mut svc = ServiceBuilder::new()
//!     .layer(
//!         // Layer that sets `Date` to the current date and time.
//!         //
//!         // `overriding` will insert the header and replace any values the
//!         // request already has for it.
//!         SetRequestHeaderLayer::overriding(
//!             header::DATE,
//!             |request: &Request<Full<Bytes>>| {
//!                 Some(date_header_value())
//!             }
//!         )
//!     )
//!     .service(http_client);
//!
//! let request = Request::new(Full::default());
//!
//! let response = svc.ready().await?.call(request).await?;
//! #
//! # Ok(())
//! # }
//! ```
86
87use super::{InsertHeaderMode, MakeHeaderValue};
88use http::{header::HeaderName, Request, Response};
89use std::{
90    fmt,
91    task::{Context, Poll},
92};
93use tower_layer::Layer;
94use tower_service::Service;
95
96/// Layer that applies [`SetRequestHeader`] which adds a request header.
97///
98/// See [`SetRequestHeader`] for more details.
99pub struct SetRequestHeaderLayer<M> {
100    header_name: HeaderName,
101    make: M,
102    mode: InsertHeaderMode,
103}
104
105impl<M> fmt::Debug for SetRequestHeaderLayer<M> {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        f.debug_struct("SetRequestHeaderLayer")
108            .field("header_name", &self.header_name)
109            .field("mode", &self.mode)
110            .field("make", &std::any::type_name::<M>())
111            .finish()
112    }
113}
114
115impl<M> SetRequestHeaderLayer<M> {
116    /// Create a new [`SetRequestHeaderLayer`].
117    ///
118    /// If a previous value exists for the same header, it is removed and replaced with the new
119    /// header value.
120    pub fn overriding(header_name: HeaderName, make: M) -> Self {
121        Self::new(header_name, make, InsertHeaderMode::Override)
122    }
123
124    /// Create a new [`SetRequestHeaderLayer`].
125    ///
126    /// The new header is always added, preserving any existing values. If previous values exist,
127    /// the header will have multiple values.
128    pub fn appending(header_name: HeaderName, make: M) -> Self {
129        Self::new(header_name, make, InsertHeaderMode::Append)
130    }
131
132    /// Create a new [`SetRequestHeaderLayer`].
133    ///
134    /// If a previous value exists for the header, the new value is not inserted.
135    pub fn if_not_present(header_name: HeaderName, make: M) -> Self {
136        Self::new(header_name, make, InsertHeaderMode::IfNotPresent)
137    }
138
139    fn new(header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
140        Self {
141            make,
142            header_name,
143            mode,
144        }
145    }
146}
147
148impl<S, M> Layer<S> for SetRequestHeaderLayer<M>
149where
150    M: Clone,
151{
152    type Service = SetRequestHeader<S, M>;
153
154    fn layer(&self, inner: S) -> Self::Service {
155        SetRequestHeader {
156            inner,
157            header_name: self.header_name.clone(),
158            make: self.make.clone(),
159            mode: self.mode,
160        }
161    }
162}
163
164impl<M> Clone for SetRequestHeaderLayer<M>
165where
166    M: Clone,
167{
168    fn clone(&self) -> Self {
169        Self {
170            make: self.make.clone(),
171            header_name: self.header_name.clone(),
172            mode: self.mode,
173        }
174    }
175}
176
177/// Middleware that sets a header on the request.
178#[derive(Clone)]
179pub struct SetRequestHeader<S, M> {
180    inner: S,
181    header_name: HeaderName,
182    make: M,
183    mode: InsertHeaderMode,
184}
185
186impl<S, M> SetRequestHeader<S, M> {
187    /// Create a new [`SetRequestHeader`].
188    ///
189    /// If a previous value exists for the same header, it is removed and replaced with the new
190    /// header value.
191    pub fn overriding(inner: S, header_name: HeaderName, make: M) -> Self {
192        Self::new(inner, header_name, make, InsertHeaderMode::Override)
193    }
194
195    /// Create a new [`SetRequestHeader`].
196    ///
197    /// The new header is always added, preserving any existing values. If previous values exist,
198    /// the header will have multiple values.
199    pub fn appending(inner: S, header_name: HeaderName, make: M) -> Self {
200        Self::new(inner, header_name, make, InsertHeaderMode::Append)
201    }
202
203    /// Create a new [`SetRequestHeader`].
204    ///
205    /// If a previous value exists for the header, the new value is not inserted.
206    pub fn if_not_present(inner: S, header_name: HeaderName, make: M) -> Self {
207        Self::new(inner, header_name, make, InsertHeaderMode::IfNotPresent)
208    }
209
210    fn new(inner: S, header_name: HeaderName, make: M, mode: InsertHeaderMode) -> Self {
211        Self {
212            inner,
213            header_name,
214            make,
215            mode,
216        }
217    }
218
219    define_inner_service_accessors!();
220}
221
222impl<S, M> fmt::Debug for SetRequestHeader<S, M>
223where
224    S: fmt::Debug,
225{
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        f.debug_struct("SetRequestHeader")
228            .field("inner", &self.inner)
229            .field("header_name", &self.header_name)
230            .field("mode", &self.mode)
231            .field("make", &std::any::type_name::<M>())
232            .finish()
233    }
234}
235
236impl<ReqBody, ResBody, S, M> Service<Request<ReqBody>> for SetRequestHeader<S, M>
237where
238    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
239    M: MakeHeaderValue<Request<ReqBody>>,
240{
241    type Response = S::Response;
242    type Error = S::Error;
243    type Future = S::Future;
244
245    #[inline]
246    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
247        self.inner.poll_ready(cx)
248    }
249
250    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
251        self.mode.apply(&self.header_name, &mut req, &mut self.make);
252        self.inner.call(req)
253    }
254}