Skip to main content

tower/buffer/
service.rs

1use super::{
2    future::ResponseFuture,
3    message::Message,
4    worker::{Handle, Worker},
5};
6
7use std::{
8    future::Future,
9    task::{Context, Poll},
10};
11use tokio::sync::{mpsc, oneshot};
12use tokio_util::sync::PollSender;
13use tower_service::Service;
14
15/// Adds an mpsc buffer in front of an inner service.
16///
17/// See the module documentation for more details.
18#[derive(Debug)]
19pub struct Buffer<Req, F> {
20    tx: PollSender<Message<Req, F>>,
21    handle: Handle,
22}
23
24impl<Req, F> Buffer<Req, F>
25where
26    F: 'static,
27{
28    /// Creates a new [`Buffer`] wrapping `service`.
29    ///
30    /// `bound` gives the maximal number of requests that can be queued for the service before
31    /// backpressure is applied to callers.
32    ///
33    /// The default Tokio executor is used to run the given service, which means that this method
34    /// must be called while on the Tokio runtime.
35    ///
36    /// # Panics
37    ///
38    /// Panics if `bound` is zero.
39    ///
40    /// # A note on choosing a `bound`
41    ///
42    /// When [`Buffer`]'s implementation of [`poll_ready`] returns [`Poll::Ready`], it reserves a
43    /// slot in the channel for the forthcoming [`call`]. However, if this call doesn't arrive,
44    /// this reserved slot may be held up for a long time. As a result, it's advisable to set
45    /// `bound` to be at least the maximum number of concurrent requests the [`Buffer`] will see.
46    /// If you do not, all the slots in the buffer may be held up by futures that have just called
47    /// [`poll_ready`] but will not issue a [`call`], which prevents other senders from issuing new
48    /// requests.
49    ///
50    /// [`Poll::Ready`]: std::task::Poll::Ready
51    /// [`call`]: crate::Service::call
52    /// [`poll_ready`]: crate::Service::poll_ready
53    pub fn new<S>(service: S, bound: usize) -> Self
54    where
55        S: Service<Req, Future = F> + Send + 'static,
56        F: Send,
57        S::Error: Into<crate::BoxError> + Send + Sync,
58        Req: Send + 'static,
59    {
60        let (service, worker) = Self::pair(service, bound);
61        tokio::spawn(worker);
62        service
63    }
64
65    /// Creates a new [`Buffer`] wrapping `service`, but returns the background worker.
66    ///
67    /// This is useful if you do not want to spawn directly onto the tokio runtime
68    /// but instead want to use your own executor. This will return the [`Buffer`] and
69    /// the background `Worker` that you can then spawn.
70    ///
71    /// # Panics
72    ///
73    /// Panics if `bound` is zero.
74    pub fn pair<S>(service: S, bound: usize) -> (Self, Worker<S, Req>)
75    where
76        S: Service<Req, Future = F> + Send + 'static,
77        F: Send,
78        S::Error: Into<crate::BoxError> + Send + Sync,
79        Req: Send + 'static,
80    {
81        assert!(bound > 0, "buffer bound must be greater than zero");
82        let (tx, rx) = mpsc::channel(bound);
83        let (handle, worker) = Worker::new(service, rx);
84        let buffer = Self {
85            tx: PollSender::new(tx),
86            handle,
87        };
88        (buffer, worker)
89    }
90
91    fn get_worker_error(&self) -> crate::BoxError {
92        self.handle.get_error_on_closed()
93    }
94}
95
96impl<Req, Rsp, F, E> Service<Req> for Buffer<Req, F>
97where
98    F: Future<Output = Result<Rsp, E>> + Send + 'static,
99    E: Into<crate::BoxError>,
100    Req: Send + 'static,
101{
102    type Response = Rsp;
103    type Error = crate::BoxError;
104    type Future = ResponseFuture<F>;
105
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        // First, check if the worker is still alive.
108        if self.tx.is_closed() {
109            // If the inner service has errored, then we error here.
110            return Poll::Ready(Err(self.get_worker_error()));
111        }
112
113        // Poll the sender to acquire a permit.
114        self.tx
115            .poll_reserve(cx)
116            .map_err(|_| self.get_worker_error())
117    }
118
119    fn call(&mut self, request: Req) -> Self::Future {
120        tracing::trace!("sending request to buffer worker");
121
122        // get the current Span so that we can explicitly propagate it to the worker
123        // if we didn't do this, events on the worker related to this span wouldn't be counted
124        // towards that span since the worker would have no way of entering it.
125        let span = tracing::Span::current();
126
127        // If we've made it here, then a channel permit has already been
128        // acquired, so we can freely allocate a oneshot.
129        let (tx, rx) = oneshot::channel();
130
131        match self.tx.send_item(Message { request, span, tx }) {
132            Ok(_) => ResponseFuture::new(rx),
133            // If the channel is closed, propagate the error from the worker.
134            Err(_) => {
135                tracing::trace!("buffer channel closed");
136                ResponseFuture::failed(self.get_worker_error())
137            }
138        }
139    }
140}
141
142impl<Req, F> Clone for Buffer<Req, F>
143where
144    Req: Send + 'static,
145    F: Send + 'static,
146{
147    fn clone(&self) -> Self {
148        Self {
149            handle: self.handle.clone(),
150            tx: self.tx.clone(),
151        }
152    }
153}