Skip to content

Commit 65256af

Browse files
committed
✨Array class
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 255f453 commit 65256af

File tree

4 files changed

+42
-3
lines changed

4 files changed

+42
-3
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Static typing support for the array API standard."""
22

33
__all__ = (
4+
"Array",
45
"HasArrayNamespace",
56
"__version__",
67
"__version_tuple__",
78
)
89

9-
from ._array import HasArrayNamespace
10+
from ._array import Array, HasArrayNamespace
1011
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
__all__ = ("HasArrayNamespace",)
1+
__all__ = (
2+
"Array",
3+
"HasArrayNamespace",
4+
)
25

36
from types import ModuleType
47
from typing import Literal, Protocol
@@ -12,7 +15,7 @@ class HasArrayNamespace(Protocol[NamespaceT_co]):
1215
1316
This `Protocol` is intended for use in static typing to ensure that an
1417
object has an `__array_namespace__` method that returns a namespace for
15-
array operations. This `Protocol` should not be used at runtime, for type
18+
array operations. This `Protocol` should not be used at runtime for type
1619
checking or as a base class.
1720
1821
Example:
@@ -52,3 +55,26 @@ def __array_namespace__(
5255
5356
"""
5457
...
58+
59+
60+
class Array(
61+
HasArrayNamespace[NamespaceT_co],
62+
# -------------------------
63+
Protocol[NamespaceT_co],
64+
):
65+
"""Array API specification for array object attributes and methods.
66+
67+
The type is: ``Array[+NamespaceT = ModuleType] = Array[NamespaceT]`` where:
68+
69+
- `NamespaceT` is the type of the array namespace. It defaults to
70+
`ModuleType`, which is the most common form of array namespace (e.g.,
71+
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
72+
`types.SimpleNamespace`, to allow for wrapper libraries to
73+
semi-dynamically define their own array namespaces based on the wrapped
74+
array type.
75+
76+
This type is intended for use in static typing to ensure that an object has
77+
the attributes and methods defined in the array API specification. It should
78+
not be used at runtime for type checking or as a base class.
79+
80+
"""

tests/integration/test_numpy1p0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,9 @@ ns: ModuleType = a_ns.__array_namespace__()
2727
# Incorrect values are caught when using `__array_namespace__` and
2828
# backpropagated to the type of `a_ns`
2929
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
30+
31+
# =========================================================
32+
# `xpt.Array`
33+
34+
# Check NamespaceT_co assignment
35+
a_ns: xpt.Array[ModuleType] = nparr

tests/integration/test_numpy2p0.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,9 @@ ns: ModuleType = a_ns.__array_namespace__()
3434
# Incorrect values are caught when using `__array_namespace__` and
3535
# backpropagated to the type of `a_ns`
3636
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
37+
38+
# =========================================================
39+
# `xpt.Array`
40+
41+
# Check NamespaceT_co assignment
42+
a_ns: xpt.Array[ModuleType] = nparr

0 commit comments

Comments
 (0)