Skip to content

Commit 41fb7b3

Browse files
committed
🚚 move HasArrayNamespace
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 6d54ce7 commit 41fb7b3

File tree

9 files changed

+112
-35
lines changed

9 files changed

+112
-35
lines changed

‎.github/workflows/ci.yml

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,15 @@ jobs:
8888
python-version: "3.11"
8989
activate-environment: true
9090

91-
- name: get major numpy version
92-
id: numpy-major
91+
- name: get major.minor numpy version
92+
id: numpy-version
9393
run: |
94-
version=$(echo ${{ matrix.numpy-version }} | cut -c 1)
95-
echo "::set-output name=version::$version"
94+
version="${{ matrix.numpy-version }}"
95+
major=$(echo "$version" | cut -d. -f1)
96+
minor=$(echo "$version" | cut -d. -f2)
97+
98+
echo "major=$major" >> $GITHUB_OUTPUT
99+
echo "minor=$minor" >> $GITHUB_OUTPUT
96100
97101
- name: install deps
98102
run: |
@@ -101,10 +105,29 @@ jobs:
101105
102106
# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
103107
- name: mypy
104-
run: >
105-
uv run --no-sync --active
106-
mypy --tb --no-incremental --cache-dir=/dev/null
107-
tests/integration/test_numpy${{ steps.numpy-major.outputs.version }}.pyi
108+
run: |
109+
major="${{ steps.numpy-version.outputs.major }}"
110+
minor="${{ steps.numpy-version.outputs.minor }}"
111+
112+
# Directory containing versioned test files
113+
prefix="tests/integration"
114+
files=""
115+
116+
# Find all test files matching the current major version
117+
for path in $(find "$prefix" -name "test_numpy${major}p*.pyi"); do
118+
# Extract file name
119+
fname=$(basename "$path")
120+
# Parse the minor version from the filename
121+
fminor=$(echo "$fname" | sed -E "s/test_numpy${major}p([0-9]+)\.pyi/\1/")
122+
# Include files where minor version ≤ NumPy's minor
123+
if [ "$fminor" -le "$minor" ]; then
124+
files="$files $path"
125+
fi
126+
done
127+
128+
uv run --no-sync --active \
129+
mypy --tb --no-incremental --cache-dir=/dev/null \
130+
$files
108131
109132
# TODO: (based)pyright
110133

‎pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ ignore = [
124124
"FBT", # flake8-boolean-trap
125125
"FIX", # flake8-fixme
126126
"ISC001", # Conflicts with formatter
127+
"PYI041", # Use `float` instead of `int | float`
127128
]
128129

129130
[tool.ruff.lint.pylint]

‎src/array_api_typing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@
66
"__version_tuple__",
77
)
88

9-
from ._namespace import HasArrayNamespace
9+
from ._array import HasArrayNamespace
1010
from ._version import version as __version__, version_tuple as __version_tuple__

‎src/array_api_typing/_namespace.py renamed to ‎src/array_api_typing/_array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44
from typing import Literal, Protocol
55
from typing_extensions import TypeVar
66

7-
T_co = TypeVar("T_co", covariant=True, default=ModuleType)
7+
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
88

99

10-
class HasArrayNamespace(Protocol[T_co]):
10+
class HasArrayNamespace(Protocol[NamespaceT_co]):
1111
"""Protocol for classes that have an `__array_namespace__` method.
1212
13+
This `Protocol` is intended for use in static typing to ensure that an
14+
object has an `__array_namespace__` method that returns a namespace for
15+
array operations. This `Protocol` should not be used at runtime, for type
16+
checking or as a base class.
17+
1318
Example:
1419
>>> import array_api_typing as xpt
1520
>>>
@@ -27,4 +32,4 @@ class HasArrayNamespace(Protocol[T_co]):
2732

2833
def __array_namespace__(
2934
self, /, *, api_version: Literal["2021.12"] | None = None
30-
) -> T_co: ...
35+
) -> NamespaceT_co: ...

‎tests/integration/test_numpy1.pyi

Lines changed: 0 additions & 12 deletions
This file was deleted.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from types import ModuleType
4+
from typing import TypeAlias
5+
6+
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
7+
8+
import array_api_typing as xpt
9+
10+
# Define NDArrays against which we can test the protocols
11+
nparr = np.eye(2)
12+
nparr_i32 = np.asarray([1], dtype=np.int32)
13+
nparr_f32 = np.asarray([1.0], dtype=np.float32)
14+
nparr_b = np.asarray([True], dtype=np.bool)
15+
16+
# =========================================================
17+
# `xpt.HasArrayNamespace`
18+
19+
_: xpt.HasArrayNamespace[ModuleType] = nparr
20+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
21+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
22+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
23+
24+
# Check `__array_namespace__` method
25+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
26+
ns: ModuleType = a_ns.__array_namespace__()
27+
28+
# Incorrect values are caught when using `__array_namespace__` and
29+
# backpropagated to the type of `a_ns`
30+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught

‎tests/integration/test_numpy2.pyi

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# mypy: disable-error-code="no-redef"
2+
3+
from types import ModuleType
4+
from typing import Any, TypeAlias
5+
6+
import numpy as np
7+
import numpy.typing as npt
8+
9+
import array_api_typing as xpt
10+
11+
# DType aliases
12+
F32: TypeAlias = np.float32
13+
I32: TypeAlias = np.int32
14+
15+
# Define NDArrays against which we can test the protocols
16+
nparr: npt.NDArray[Any]
17+
nparr_i32: npt.NDArray[I32]
18+
nparr_f32: npt.NDArray[F32]
19+
nparr_b: npt.NDArray[np.bool_]
20+
21+
# =========================================================
22+
# `xpt.HasArrayNamespace`
23+
24+
# Check assignment
25+
_: xpt.HasArrayNamespace[ModuleType] = nparr
26+
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
27+
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
28+
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
29+
30+
# Check `__array_namespace__` method
31+
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
32+
ns: ModuleType = a_ns.__array_namespace__()
33+
34+
# Incorrect values are caught when using `__array_namespace__` and
35+
# backpropagated to the type of `a_ns`
36+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from test_numpy2p0 import nparr
2+
3+
import array_api_typing as xpt
4+
5+
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]

0 commit comments

Comments
 (0)