Skip to content

Commit 8734b65

Browse files
committed
✨Array class
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 41fb7b3 commit 8734b65

File tree

4 files changed

+25
-2
lines changed

4 files changed

+25
-2
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Static typing support for the array API standard."""
22

33
__all__ = (
4+
"Array",
45
"HasArrayNamespace",
56
"__version__",
67
"__version_tuple__",
78
)
89

9-
from ._array import HasArrayNamespace
10+
from ._array import Array, HasArrayNamespace
1011
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
__all__ = ("HasArrayNamespace",)
1+
__all__ = (
2+
"Array",
3+
"HasArrayNamespace",
4+
)
25

36
from types import ModuleType
47
from typing import Literal, Protocol
@@ -33,3 +36,10 @@ class HasArrayNamespace(Protocol[NamespaceT_co]):
3336
def __array_namespace__(
3437
self, /, *, api_version: Literal["2021.12"] | None = None
3538
) -> NamespaceT_co: ...
39+
40+
41+
class Array(
42+
HasArrayNamespace[NamespaceT_co],
43+
Protocol[NamespaceT_co],
44+
):
45+
"""Array API specification for array object attributes and methods."""

tests/integration/test_numpy1p0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,9 @@ ns: ModuleType = a_ns.__array_namespace__()
2828
# Incorrect values are caught when using `__array_namespace__` and
2929
# backpropagated to the type of `a_ns`
3030
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
31+
32+
# =========================================================
33+
# `xpt.Array`
34+
35+
# Check NamespaceT_co assignment
36+
a_ns: xpt.Array[ModuleType] = nparr

tests/integration/test_numpy2p0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,9 @@ ns: ModuleType = a_ns.__array_namespace__()
3434
# Incorrect values are caught when using `__array_namespace__` and
3535
# backpropagated to the type of `a_ns`
3636
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
37+
38+
# =========================================================
39+
# `xpt.Array`
40+
41+
# Check NamespaceT_co assignment
42+
a_ns: xpt.Array[ModuleType] = nparr

0 commit comments

Comments
 (0)