// redis/aio/runtime.rs

1use std::{io, sync::Arc, time::Duration};
2
3use futures_util::Future;
4
5#[cfg(any(
6    all(
7        feature = "tokio-comp",
8        any(feature = "async-std-comp", feature = "smol-comp")
9    ),
10    all(
11        feature = "smol-comp",
12        any(feature = "async-std-comp", feature = "tokio-comp")
13    ),
14    all(
15        feature = "async-std-comp",
16        any(feature = "tokio-comp", feature = "smol-comp")
17    )
18))]
19use std::sync::OnceLock;
20
21#[cfg(feature = "async-std-comp")]
22use super::async_std as crate_async_std;
23#[cfg(feature = "smol-comp")]
24use super::smol as crate_smol;
25#[cfg(feature = "tokio-comp")]
26use super::tokio as crate_tokio;
27use super::RedisRuntime;
28use crate::types::RedisError;
29#[cfg(feature = "smol-comp")]
30use smol_timeout::TimeoutExt;
31
/// The async runtimes this module can spawn tasks on, time out futures with, and sleep on.
///
/// A variant only exists when its matching `*-comp` feature is enabled.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Runtime {
    /// Tokio, available with the `tokio-comp` feature.
    #[cfg(feature = "tokio-comp")]
    Tokio,
    /// async-std, available with the `async-std-comp` feature.
    #[cfg(feature = "async-std-comp")]
    AsyncStd,
    /// smol, available with the `smol-comp` feature.
    #[cfg(feature = "smol-comp")]
    Smol,
}
41
/// A handle to a `()`-returning task, tagged with the runtime that spawned it.
pub(crate) enum TaskHandle {
    /// Join handle returned by Tokio.
    #[cfg(feature = "tokio-comp")]
    Tokio(tokio::task::JoinHandle<()>),
    /// Join handle returned by async-std.
    #[cfg(feature = "async-std-comp")]
    AsyncStd(async_std::task::JoinHandle<()>),
    /// Task object returned by smol.
    #[cfg(feature = "smol-comp")]
    Smol(smol::Task<()>),
}
50
51pub(crate) struct HandleContainer(Option<TaskHandle>);
52
53impl HandleContainer {
54    pub(crate) fn new(handle: TaskHandle) -> Self {
55        Self(Some(handle))
56    }
57}
58
59impl Drop for HandleContainer {
60    fn drop(&mut self) {
61        match self.0.take() {
62            None => {}
63            #[cfg(feature = "tokio-comp")]
64            Some(TaskHandle::Tokio(handle)) => handle.abort(),
65            #[cfg(feature = "async-std-comp")]
66            Some(TaskHandle::AsyncStd(handle)) => {
67                // schedule for cancellation without waiting for result.
68                // TODO - can we cancel the task without awaiting its completion?
69                Runtime::locate().spawn(async move { handle.cancel().await.unwrap_or_default() });
70            }
71            #[cfg(feature = "smol-comp")]
72            Some(TaskHandle::Smol(task)) => drop(task),
73        }
74    }
75}
76
77#[derive(Clone)]
78// we allow dead code here because the container isn't used directly, only in the derived drop.
79#[allow(dead_code)]
80pub(crate) struct SharedHandleContainer(Arc<HandleContainer>);
81
82impl SharedHandleContainer {
83    pub(crate) fn new(handle: TaskHandle) -> Self {
84        Self(Arc::new(HandleContainer::new(handle)))
85    }
86}
87
/// The runtime explicitly selected by one of the `prefer_*` functions.
///
/// This only exists when at least two runtime features are compiled in. Being a `OnceLock`,
/// it can be set at most once for the lifetime of the process.
#[cfg(any(
    all(feature = "tokio-comp", feature = "async-std-comp"),
    all(feature = "tokio-comp", feature = "smol-comp"),
    all(feature = "async-std-comp", feature = "smol-comp")
))]
static CHOSEN_RUNTIME: OnceLock<Runtime> = OnceLock::new();
103
/// Records `runtime` as the process-wide runtime preference.
///
/// # Errors
///
/// Returns a `ClientError` if a preference was already recorded. The earlier preference
/// is kept.
#[cfg(any(
    all(feature = "tokio-comp", feature = "async-std-comp"),
    all(feature = "tokio-comp", feature = "smol-comp"),
    all(feature = "async-std-comp", feature = "smol-comp")
))]
fn set_runtime(runtime: Runtime) -> Result<(), RedisError> {
    const PREFER_RUNTIME_ERROR: &str =
        "Another runtime preference was already set. Please call this function before any other runtime preference is set.";

    match CHOSEN_RUNTIME.set(runtime) {
        Ok(()) => Ok(()),
        // `OnceLock::set` hands the rejected value back; it is simply discarded.
        Err(_rejected) => Err(RedisError::from((
            crate::ErrorKind::ClientError,
            PREFER_RUNTIME_ERROR,
        ))),
    }
}
126
/// Marks smol as the preferred runtime, overriding automatic runtime detection.
///
/// Only one preference can be recorded; the first successful `prefer_*` call wins.
/// Use this when the application runs on a single runtime but this crate is compiled with
/// several runtime features enabled, a setup that is best avoided.
///
/// # Errors
///
/// Returns `Err` if a runtime preference was already set; that preference is not changed.
#[cfg(any(
    all(feature = "smol-comp", feature = "async-std-comp"),
    all(feature = "smol-comp", feature = "tokio-comp")
))]
pub fn prefer_smol() -> Result<(), RedisError> {
    let preferred = Runtime::Smol;
    set_runtime(preferred)
}
139
/// Marks async-std as the preferred runtime, overriding automatic runtime detection.
///
/// Applications that run on smol should call `prefer_smol` instead when it is available.
/// Only one preference can be recorded; the first successful `prefer_*` call wins.
/// Use this when the application runs on a single runtime but this crate is compiled with
/// several runtime features enabled, a setup that is best avoided.
///
/// # Errors
///
/// Returns `Err` if a runtime preference was already set; that preference is not changed.
#[cfg(any(
    all(feature = "async-std-comp", feature = "tokio-comp"),
    all(feature = "async-std-comp", feature = "smol-comp")
))]
pub fn prefer_async_std() -> Result<(), RedisError> {
    let preferred = Runtime::AsyncStd;
    set_runtime(preferred)
}
152
/// Marks Tokio as the preferred runtime, overriding automatic runtime detection.
///
/// Only one preference can be recorded; the first successful `prefer_*` call wins.
/// Use this when the application runs on a single runtime but this crate is compiled with
/// several runtime features enabled, a setup that is best avoided.
///
/// # Errors
///
/// Returns `Err` if a runtime preference was already set; that preference is not changed.
#[cfg(any(
    all(feature = "tokio-comp", feature = "async-std-comp"),
    all(feature = "tokio-comp", feature = "smol-comp")
))]
pub fn prefer_tokio() -> Result<(), RedisError> {
    let preferred = Runtime::Tokio;
    set_runtime(preferred)
}
165
166impl Runtime {
167    pub(crate) fn locate() -> Self {
168        #[cfg(any(
169            all(
170                feature = "tokio-comp",
171                any(feature = "async-std-comp", feature = "smol-comp")
172            ),
173            all(
174                feature = "smol-comp",
175                any(feature = "async-std-comp", feature = "tokio-comp")
176            ),
177            all(
178                feature = "async-std-comp",
179                any(feature = "tokio-comp", feature = "smol-comp")
180            )
181        ))]
182        if let Some(runtime) = CHOSEN_RUNTIME.get() {
183            return *runtime;
184        }
185
186        #[cfg(all(
187            feature = "tokio-comp",
188            not(feature = "async-std-comp"),
189            not(feature = "smol-comp")
190        ))]
191        {
192            Runtime::Tokio
193        }
194
195        #[cfg(all(
196            not(feature = "tokio-comp"),
197            not(feature = "smol-comp"),
198            feature = "async-std-comp"
199        ))]
200        {
201            Runtime::AsyncStd
202        }
203
204        #[cfg(all(
205            not(feature = "tokio-comp"),
206            feature = "smol-comp",
207            not(feature = "async-std-comp")
208        ))]
209        {
210            Runtime::Smol
211        }
212
213        cfg_if::cfg_if! {
214        if #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))] {
215            if ::tokio::runtime::Handle::try_current().is_ok() {
216                Runtime::Tokio
217            } else {
218                Runtime::AsyncStd
219            }
220        } else if #[cfg(all(feature = "tokio-comp", feature = "smol-comp"))] {
221            if ::tokio::runtime::Handle::try_current().is_ok() {
222                Runtime::Tokio
223            } else {
224                Runtime::Smol
225            }
226        } else if #[cfg(all(feature = "smol-comp", feature = "async-std-comp"))]
227            {
228                Runtime::AsyncStd
229            }
230        }
231
232        #[cfg(all(
233            not(feature = "tokio-comp"),
234            not(feature = "async-std-comp"),
235            not(feature = "smol-comp")
236        ))]
237        {
238            compile_error!(
239                "tokio-comp, async-std-comp, or smol-comp features required for aio feature"
240            )
241        }
242    }
243
244    #[allow(dead_code)]
245    pub(crate) fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
246        match self {
247            #[cfg(feature = "tokio-comp")]
248            Runtime::Tokio => crate_tokio::Tokio::spawn(f),
249            #[cfg(feature = "async-std-comp")]
250            Runtime::AsyncStd => crate_async_std::AsyncStd::spawn(f),
251            #[cfg(feature = "smol-comp")]
252            Runtime::Smol => crate_smol::Smol::spawn(f),
253        }
254    }
255
256    pub(crate) async fn timeout<F: Future>(
257        &self,
258        duration: Duration,
259        future: F,
260    ) -> Result<F::Output, Elapsed> {
261        match self {
262            #[cfg(feature = "tokio-comp")]
263            Runtime::Tokio => tokio::time::timeout(duration, future)
264                .await
265                .map_err(|_| Elapsed(())),
266            #[cfg(feature = "async-std-comp")]
267            Runtime::AsyncStd => async_std::future::timeout(duration, future)
268                .await
269                .map_err(|_| Elapsed(())),
270            #[cfg(feature = "smol-comp")]
271            Runtime::Smol => future.timeout(duration).await.ok_or(Elapsed(())),
272        }
273    }
274
275    #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
276    pub(crate) async fn sleep(&self, duration: Duration) {
277        match self {
278            #[cfg(feature = "tokio-comp")]
279            Runtime::Tokio => {
280                tokio::time::sleep(duration).await;
281            }
282            #[cfg(feature = "async-std-comp")]
283            Runtime::AsyncStd => {
284                async_std::task::sleep(duration).await;
285            }
286            #[cfg(feature = "smol-comp")]
287            Runtime::Smol => {
288                smol::Timer::after(duration).await;
289            }
290        }
291    }
292
293    #[cfg(feature = "cluster-async")]
294    pub(crate) async fn locate_and_sleep(duration: Duration) {
295        Self::locate().sleep(duration).await
296    }
297}
298
/// Marker error produced by `Runtime::timeout` when the deadline passes before the
/// wrapped future completes.
///
/// The private unit field keeps the type from being constructed outside this module.
#[derive(Debug)]
pub(crate) struct Elapsed(());
301
302impl From<Elapsed> for RedisError {
303    fn from(_: Elapsed) -> Self {
304        io::Error::from(io::ErrorKind::TimedOut).into()
305    }
306}