1
1
from collections .abc import Awaitable , Callable , Generator
2
2
from functools import wraps
3
- from typing import NewType , ParamSpec , Protocol , TypeVar , cast , final
4
-
3
+ from typing import Literal , NewType , ParamSpec , Protocol , TypeVar , cast , final
4
+ # Always import asyncio
5
+ import asyncio
5
6
6
7
class AsyncLock (Protocol ):
7
8
"""A protocol for an asynchronous lock."""
@@ -13,26 +14,17 @@ async def __aenter__(self) -> None: ...
13
14
async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None : ...
14
15
15
16
16
- # Import both libraries if available
17
- import asyncio # noqa: WPS433
18
- from enum import Enum , auto
19
-
20
-
21
- class AsyncContext (Enum ):
22
- """Enum representing different async context types."""
23
-
24
- ASYNCIO = auto ()
25
- TRIO = auto ()
26
- UNKNOWN = auto ()
17
+ # Define context types as literals
18
+ AsyncContext = Literal ["asyncio" , "trio" , "unknown" ]
27
19
28
20
29
21
# Check for anyio and trio availability
30
22
try :
31
- import anyio # noqa: WPS433
23
+ import anyio # pragma: no qa
32
24
33
25
has_anyio = True
34
26
try :
35
- import trio # noqa: WPS433
27
+ import trio # pragma: no qa
36
28
37
29
has_trio = True
38
30
except ImportError : # pragma: no cover
@@ -42,27 +34,40 @@ class AsyncContext(Enum):
42
34
has_trio = False
43
35
44
36
37
+ def _is_in_trio_context () -> bool :
38
+ """Check if we're in a trio context.
39
+
40
+ Returns:
41
+ bool: True if we're in a trio context
42
+ """
43
+ if not has_trio :
44
+ return False
45
+
46
+ # Import trio here since we already checked it's available
47
+ import trio
48
+
49
+ try :
50
+ # Will raise RuntimeError if not in trio context
51
+ trio .lowlevel .current_task ()
52
+ except (RuntimeError , AttributeError ):
53
+ return False
54
+ return True
55
+
56
+
45
57
def detect_async_context () -> AsyncContext :
46
58
"""Detect which async context we're currently running in.
47
59
48
60
Returns:
49
61
AsyncContext: The current async context type
50
62
"""
51
63
if not has_anyio : # pragma: no cover
52
- return AsyncContext .ASYNCIO
53
-
54
- if has_trio :
55
- try :
56
- # Check if we're in a trio context
57
- # Will raise RuntimeError if not in trio context
58
- trio .lowlevel .current_task ()
59
- return AsyncContext .TRIO
60
- except (RuntimeError , AttributeError ):
61
- # Not in a trio context or trio API changed
62
- pass
64
+ return "asyncio"
65
+
66
+ if _is_in_trio_context ():
67
+ return "trio"
63
68
64
69
# Default to asyncio
65
- return AsyncContext . ASYNCIO
70
+ return "asyncio"
66
71
67
72
68
73
_ValueType = TypeVar ('_ValueType' )
@@ -121,7 +126,7 @@ def __init__(self, coro: Awaitable[_ValueType]) -> None:
121
126
"""We need just an awaitable to work with."""
122
127
self ._coro = coro
123
128
self ._cache : _ValueType | _Sentinel = _sentinel
124
- self ._lock = None # Will be created lazily based on the backend
129
+ self ._lock : AsyncLock | None = None # Will be created lazily based on the backend
125
130
126
131
def __await__ (self ) -> Generator [None , None , _ValueType ]:
127
132
"""
@@ -171,10 +176,11 @@ def _create_lock(self) -> AsyncLock:
171
176
"""Create the appropriate lock based on the current async context."""
172
177
context = detect_async_context ()
173
178
174
- if context == AsyncContext .TRIO and has_anyio :
179
+ if context == "trio" and has_anyio :
180
+ import anyio
175
181
return anyio .Lock ()
176
182
177
- # For ASYNCIO or UNKNOWN contexts
183
+ # For asyncio or unknown contexts
178
184
return asyncio .Lock ()
179
185
180
186
async def _awaitable (self ) -> _ValueType :
@@ -222,4 +228,4 @@ def decorator(
222
228
) -> _AwaitableT :
223
229
return ReAwaitable (coro (* args , ** kwargs )) # type: ignore[return-value]
224
230
225
- return decorator
231
+ return decorator
0 commit comments