Skip to content

Commit a9b0c38

Browse files
committed
Fix flake8 issues in reawaitable.py by using Literal instead of Enum
1 parent a396766 commit a9b0c38

File tree

1 file changed

+37
-31
lines changed

1 file changed

+37
-31
lines changed

returns/primitives/reawaitable.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections.abc import Awaitable, Callable, Generator
22
from functools import wraps
3-
from typing import NewType, ParamSpec, Protocol, TypeVar, cast, final
4-
3+
from typing import Literal, NewType, ParamSpec, Protocol, TypeVar, cast, final
4+
# Always import asyncio
5+
import asyncio
56

67
class AsyncLock(Protocol):
78
"""A protocol for an asynchronous lock."""
@@ -13,26 +14,17 @@ async def __aenter__(self) -> None: ...
1314
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: ...
1415

1516

16-
# Import both libraries if available
17-
import asyncio # noqa: WPS433
18-
from enum import Enum, auto
19-
20-
21-
class AsyncContext(Enum):
22-
"""Enum representing different async context types."""
23-
24-
ASYNCIO = auto()
25-
TRIO = auto()
26-
UNKNOWN = auto()
17+
# Define context types as literals
18+
AsyncContext = Literal["asyncio", "trio", "unknown"]
2719

2820

2921
# Check for anyio and trio availability
3022
try:
31-
import anyio # noqa: WPS433
23+
import anyio # pragma: no qa
3224

3325
has_anyio = True
3426
try:
35-
import trio # noqa: WPS433
27+
import trio # pragma: no qa
3628

3729
has_trio = True
3830
except ImportError: # pragma: no cover
@@ -42,27 +34,40 @@ class AsyncContext(Enum):
4234
has_trio = False
4335

4436

37+
def _is_in_trio_context() -> bool:
38+
"""Check if we're in a trio context.
39+
40+
Returns:
41+
bool: True if we're in a trio context
42+
"""
43+
if not has_trio:
44+
return False
45+
46+
# Import trio here since we already checked it's available
47+
import trio
48+
49+
try:
50+
# Will raise RuntimeError if not in trio context
51+
trio.lowlevel.current_task()
52+
except (RuntimeError, AttributeError):
53+
return False
54+
return True
55+
56+
4557
def detect_async_context() -> AsyncContext:
4658
"""Detect which async context we're currently running in.
4759
4860
Returns:
4961
AsyncContext: The current async context type
5062
"""
5163
if not has_anyio: # pragma: no cover
52-
return AsyncContext.ASYNCIO
53-
54-
if has_trio:
55-
try:
56-
# Check if we're in a trio context
57-
# Will raise RuntimeError if not in trio context
58-
trio.lowlevel.current_task()
59-
return AsyncContext.TRIO
60-
except (RuntimeError, AttributeError):
61-
# Not in a trio context or trio API changed
62-
pass
64+
return "asyncio"
65+
66+
if _is_in_trio_context():
67+
return "trio"
6368

6469
# Default to asyncio
65-
return AsyncContext.ASYNCIO
70+
return "asyncio"
6671

6772

6873
_ValueType = TypeVar('_ValueType')
@@ -121,7 +126,7 @@ def __init__(self, coro: Awaitable[_ValueType]) -> None:
121126
"""We need just an awaitable to work with."""
122127
self._coro = coro
123128
self._cache: _ValueType | _Sentinel = _sentinel
124-
self._lock = None # Will be created lazily based on the backend
129+
self._lock: AsyncLock | None = None # Will be created lazily based on the backend
125130

126131
def __await__(self) -> Generator[None, None, _ValueType]:
127132
"""
@@ -171,10 +176,11 @@ def _create_lock(self) -> AsyncLock:
171176
"""Create the appropriate lock based on the current async context."""
172177
context = detect_async_context()
173178

174-
if context == AsyncContext.TRIO and has_anyio:
179+
if context == "trio" and has_anyio:
180+
import anyio
175181
return anyio.Lock()
176182

177-
# For ASYNCIO or UNKNOWN contexts
183+
# For asyncio or unknown contexts
178184
return asyncio.Lock()
179185

180186
async def _awaitable(self) -> _ValueType:
@@ -222,4 +228,4 @@ def decorator(
222228
) -> _AwaitableT:
223229
return ReAwaitable(coro(*args, **kwargs)) # type: ignore[return-value]
224230

225-
return decorator
231+
return decorator

0 commit comments

Comments
 (0)