1use std::{io, sync::Arc, time::Duration};
2
3use futures_util::Future;
4
5#[cfg(any(
6 all(
7 feature = "tokio-comp",
8 any(feature = "async-std-comp", feature = "smol-comp")
9 ),
10 all(
11 feature = "smol-comp",
12 any(feature = "async-std-comp", feature = "tokio-comp")
13 ),
14 all(
15 feature = "async-std-comp",
16 any(feature = "tokio-comp", feature = "smol-comp")
17 )
18))]
19use std::sync::OnceLock;
20
21#[cfg(feature = "async-std-comp")]
22use super::async_std as crate_async_std;
23#[cfg(feature = "smol-comp")]
24use super::smol as crate_smol;
25#[cfg(feature = "tokio-comp")]
26use super::tokio as crate_tokio;
27use super::RedisRuntime;
28use crate::types::RedisError;
29#[cfg(feature = "smol-comp")]
30use smol_timeout::TimeoutExt;
31
/// The async runtimes this crate can drive.
///
/// Each variant is compiled in only when its matching `*-comp` feature is
/// enabled, so the set of variants depends on the build configuration.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Runtime {
    /// Tokio, available with the `tokio-comp` feature.
    #[cfg(feature = "tokio-comp")]
    Tokio,
    /// async-std, available with the `async-std-comp` feature.
    #[cfg(feature = "async-std-comp")]
    AsyncStd,
    /// Smol, available with the `smol-comp` feature.
    #[cfg(feature = "smol-comp")]
    Smol,
}
41
/// A spawned task, tagged with the runtime whose join handle it carries.
///
/// Each variant is compiled in only when its matching `*-comp` feature is
/// enabled.
pub(crate) enum TaskHandle {
    /// A Tokio join handle.
    #[cfg(feature = "tokio-comp")]
    Tokio(::tokio::task::JoinHandle<()>),
    /// An async-std join handle.
    #[cfg(feature = "async-std-comp")]
    AsyncStd(::async_std::task::JoinHandle<()>),
    /// A smol task.
    #[cfg(feature = "smol-comp")]
    Smol(::smol::Task<()>),
}
50
51pub(crate) struct HandleContainer(Option<TaskHandle>);
52
53impl HandleContainer {
54 pub(crate) fn new(handle: TaskHandle) -> Self {
55 Self(Some(handle))
56 }
57}
58
59impl Drop for HandleContainer {
60 fn drop(&mut self) {
61 match self.0.take() {
62 None => {}
63 #[cfg(feature = "tokio-comp")]
64 Some(TaskHandle::Tokio(handle)) => handle.abort(),
65 #[cfg(feature = "async-std-comp")]
66 Some(TaskHandle::AsyncStd(handle)) => {
67 Runtime::locate().spawn(async move { handle.cancel().await.unwrap_or_default() });
70 }
71 #[cfg(feature = "smol-comp")]
72 Some(TaskHandle::Smol(task)) => drop(task),
73 }
74 }
75}
76
77#[derive(Clone)]
78#[allow(dead_code)]
80pub(crate) struct SharedHandleContainer(Arc<HandleContainer>);
81
82impl SharedHandleContainer {
83 pub(crate) fn new(handle: TaskHandle) -> Self {
84 Self(Arc::new(HandleContainer::new(handle)))
85 }
86}
87
/// Runtime explicitly chosen through one of the `prefer_*` functions.
///
/// Only present when at least two runtime features are compiled in, i.e.
/// tokio together with another runtime, or async-std together with smol.
/// Being a `OnceLock`, it can be written at most once.
#[cfg(any(
    all(
        feature = "tokio-comp",
        any(feature = "async-std-comp", feature = "smol-comp")
    ),
    all(feature = "async-std-comp", feature = "smol-comp")
))]
static CHOSEN_RUNTIME: OnceLock<Runtime> = OnceLock::<Runtime>::new();
103
/// Records `runtime` as the value `Runtime::locate` returns from now on.
///
/// # Errors
///
/// Returns a `ClientError` if a preference was already recorded. The earlier
/// preference is left in place.
#[cfg(any(
    all(
        feature = "tokio-comp",
        any(feature = "async-std-comp", feature = "smol-comp")
    ),
    all(feature = "async-std-comp", feature = "smol-comp")
))]
fn set_runtime(runtime: Runtime) -> Result<(), RedisError> {
    const PREFER_RUNTIME_ERROR: &str =
        "Another runtime preference was already set. Please call this function before any other runtime preference is set.";

    match CHOSEN_RUNTIME.set(runtime) {
        Ok(()) => Ok(()),
        Err(_already_set) => Err(RedisError::from((
            crate::ErrorKind::ClientError,
            PREFER_RUNTIME_ERROR,
        ))),
    }
}
126
/// Selects Smol as the runtime used by this crate's async code.
///
/// Only available when `smol-comp` is enabled together with at least one
/// other runtime feature. A runtime preference can be recorded only once per
/// process, so call this before any other `prefer_*` function.
///
/// # Errors
///
/// Returns an error of kind `ClientError` if a runtime preference was already
/// set. The earlier preference stays in effect.
#[cfg(all(
    any(feature = "tokio-comp", feature = "async-std-comp"),
    feature = "smol-comp"
))]
pub fn prefer_smol() -> Result<(), RedisError> {
    let preferred = Runtime::Smol;
    set_runtime(preferred)
}
139
/// Selects async-std as the runtime used by this crate's async code.
///
/// Only available when `async-std-comp` is enabled together with at least one
/// other runtime feature. A runtime preference can be recorded only once per
/// process, so call this before any other `prefer_*` function.
///
/// # Errors
///
/// Returns an error of kind `ClientError` if a runtime preference was already
/// set. The earlier preference stays in effect.
#[cfg(all(
    any(feature = "smol-comp", feature = "tokio-comp"),
    feature = "async-std-comp"
))]
pub fn prefer_async_std() -> Result<(), RedisError> {
    let preferred = Runtime::AsyncStd;
    set_runtime(preferred)
}
152
/// Selects Tokio as the runtime used by this crate's async code.
///
/// Only available when `tokio-comp` is enabled together with at least one
/// other runtime feature. A runtime preference can be recorded only once per
/// process, so call this before any other `prefer_*` function.
///
/// # Errors
///
/// Returns an error of kind `ClientError` if a runtime preference was already
/// set. The earlier preference stays in effect.
#[cfg(all(
    any(feature = "smol-comp", feature = "async-std-comp"),
    feature = "tokio-comp"
))]
pub fn prefer_tokio() -> Result<(), RedisError> {
    let preferred = Runtime::Tokio;
    set_runtime(preferred)
}
165
/// Runtime selection plus the runtime-dependent primitives the async code
/// needs: spawning, timeouts and sleeping.
impl Runtime {
    /// Returns the runtime to use for the current call.
    ///
    /// Resolution order:
    /// 1. When at least two runtime features are compiled in and a preference
    ///    was recorded via a `prefer_*` function, that preference is returned.
    /// 2. When exactly one runtime feature is enabled, that runtime is returned.
    /// 3. With `tokio-comp` plus another runtime, Tokio is returned when
    ///    `tokio::runtime::Handle::try_current()` succeeds. Otherwise the other
    ///    runtime is returned. With all three features enabled, that fallback
    ///    is async-std, never smol.
    /// 4. With only `smol-comp` and `async-std-comp`, async-std is returned
    ///    without any runtime detection.
    pub(crate) fn locate() -> Self {
        // The explicit preference only exists when two or more runtime
        // features are enabled (same predicate as `CHOSEN_RUNTIME`).
        #[cfg(any(
            all(
                feature = "tokio-comp",
                any(feature = "async-std-comp", feature = "smol-comp")
            ),
            all(
                feature = "smol-comp",
                any(feature = "async-std-comp", feature = "tokio-comp")
            ),
            all(
                feature = "async-std-comp",
                any(feature = "tokio-comp", feature = "smol-comp")
            )
        ))]
        if let Some(runtime) = CHOSEN_RUNTIME.get() {
            return *runtime;
        }

        // Single-runtime builds: at most one of the next three blocks is
        // compiled in, and it supplies the return value.
        #[cfg(all(
            feature = "tokio-comp",
            not(feature = "async-std-comp"),
            not(feature = "smol-comp")
        ))]
        {
            Runtime::Tokio
        }

        #[cfg(all(
            not(feature = "tokio-comp"),
            not(feature = "smol-comp"),
            feature = "async-std-comp"
        ))]
        {
            Runtime::AsyncStd
        }

        #[cfg(all(
            not(feature = "tokio-comp"),
            feature = "smol-comp",
            not(feature = "async-std-comp")
        ))]
        {
            Runtime::Smol
        }

        // Multi-runtime builds without an explicit preference. The first
        // matching arm wins, so with all three features the tokio/async-std
        // arm is used. Tokio is picked only when called from inside a Tokio
        // runtime context.
        cfg_if::cfg_if! {
        if #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))] {
            if ::tokio::runtime::Handle::try_current().is_ok() {
                Runtime::Tokio
            } else {
                Runtime::AsyncStd
            }
        } else if #[cfg(all(feature = "tokio-comp", feature = "smol-comp"))] {
            if ::tokio::runtime::Handle::try_current().is_ok() {
                Runtime::Tokio
            } else {
                Runtime::Smol
            }
        } else if #[cfg(all(feature = "smol-comp", feature = "async-std-comp"))]
        {
            Runtime::AsyncStd
        }
        }

        // No runtime feature enabled at all: fail the build with a clear message.
        #[cfg(all(
            not(feature = "tokio-comp"),
            not(feature = "async-std-comp"),
            not(feature = "smol-comp")
        ))]
        {
            compile_error!(
                "tokio-comp, async-std-comp, or smol-comp features required for aio feature"
            )
        }
    }

    /// Spawns `f` on this runtime by delegating to the runtime-specific
    /// `spawn` and returns the resulting [`TaskHandle`].
    // NOTE(review): `dead_code` is presumably allowed because some feature
    // combinations never call this. Confirm.
    #[allow(dead_code)]
    pub(crate) fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
        match self {
            #[cfg(feature = "tokio-comp")]
            Runtime::Tokio => crate_tokio::Tokio::spawn(f),
            #[cfg(feature = "async-std-comp")]
            Runtime::AsyncStd => crate_async_std::AsyncStd::spawn(f),
            #[cfg(feature = "smol-comp")]
            Runtime::Smol => crate_smol::Smol::spawn(f),
        }
    }

    /// Awaits `future`, giving up after `duration`.
    ///
    /// Returns `Ok` with the future's output if it completes in time, or
    /// `Err(Elapsed)` if the deadline passes first. The timer comes from the
    /// selected runtime: `tokio::time::timeout`, `async_std::future::timeout`,
    /// or smol_timeout's `TimeoutExt::timeout`.
    pub(crate) async fn timeout<F: Future>(
        &self,
        duration: Duration,
        future: F,
    ) -> Result<F::Output, Elapsed> {
        match self {
            #[cfg(feature = "tokio-comp")]
            Runtime::Tokio => tokio::time::timeout(duration, future)
                .await
                .map_err(|_| Elapsed(())),
            #[cfg(feature = "async-std-comp")]
            Runtime::AsyncStd => async_std::future::timeout(duration, future)
                .await
                .map_err(|_| Elapsed(())),
            #[cfg(feature = "smol-comp")]
            Runtime::Smol => future.timeout(duration).await.ok_or(Elapsed(())),
        }
    }

    /// Suspends the current task for `duration` using this runtime's timer.
    ///
    /// Only compiled with the `connection-manager` or `cluster-async` features.
    #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
    pub(crate) async fn sleep(&self, duration: Duration) {
        match self {
            #[cfg(feature = "tokio-comp")]
            Runtime::Tokio => {
                tokio::time::sleep(duration).await;
            }
            #[cfg(feature = "async-std-comp")]
            Runtime::AsyncStd => {
                async_std::task::sleep(duration).await;
            }
            #[cfg(feature = "smol-comp")]
            Runtime::Smol => {
                smol::Timer::after(duration).await;
            }
        }
    }

    /// Sleeps for `duration` on whichever runtime [`Runtime::locate`] selects.
    #[cfg(feature = "cluster-async")]
    pub(crate) async fn locate_and_sleep(duration: Duration) {
        Self::locate().sleep(duration).await
    }
}
298
/// Error produced by `Runtime::timeout` when the deadline passes before the
/// future completes.
#[derive(Debug)]
pub(crate) struct Elapsed(
    // Private unit field: only this module can construct the value.
    (),
);
301
302impl From<Elapsed> for RedisError {
303 fn from(_: Elapsed) -> Self {
304 io::Error::from(io::ErrorKind::TimedOut).into()
305 }
306}