Skip to content

Commit 5ad0d5c

Browse files
committed
✨HasDType
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 3899a12 commit 5ad0d5c

File tree

4 files changed

+65
-7
lines changed

4 files changed

+65
-7
lines changed

src/array_api_typing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
__all__ = (
44
"Array",
55
"HasArrayNamespace",
6+
"HasDType",
67
"__version__",
78
"__version_tuple__",
89
)
910

10-
from ._array import Array, HasArrayNamespace
11+
from ._array import Array, HasArrayNamespace, HasDType
1112
from ._version import version as __version__, version_tuple as __version_tuple__

src/array_api_typing/_array.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import TypeVar
99

1010
NamespaceT_co = TypeVar("NamespaceT_co", covariant=True, default=ModuleType)
11+
DTypeT_co = TypeVar("DTypeT_co", covariant=True)
1112

1213

1314
class HasArrayNamespace(Protocol[NamespaceT_co]):
@@ -38,8 +39,32 @@ def __array_namespace__(
3839
) -> NamespaceT_co: ...
3940

4041

42+
class HasDType(Protocol[DTypeT_co]):
43+
"""Protocol for array classes that have a data type attribute."""
44+
45+
@property
46+
def dtype(self, /) -> DTypeT_co:
47+
"""Data type of the array elements."""
48+
...
49+
50+
4151
class Array(
42-
HasArrayNamespace[NamespaceT_co],
43-
Protocol[NamespaceT_co],
52+
# ------ Attributes -------
53+
HasDType[DTypeT_co],
54+
# -------------------------
55+
Protocol[DTypeT_co, NamespaceT_co],
4456
):
45-
"""Array API specification for array object attributes and methods."""
57+
"""Array API specification for array object attributes and methods.
58+
59+
The type is: ``Array[+DTypeT, +NamespaceT = ModuleType] = Array[DTypeT,
60+
NamespaceT]`` where:
61+
62+
- `DTypeT` is the data type of the array elements.
63+
- `NamespaceT` is the type of the array namespace. It defaults to
64+
`ModuleType`, which is the most common form of array namespace (e.g.,
65+
`numpy`, `cupy`, etc.). However, it can be any type, e.g. a
66+
`types.SimpleNamespace`, to allow for wrapper libraries to
67+
semi-dynamically define their own array namespaces based on the wrapped
68+
array type.
69+
70+
"""

tests/integration/test_numpy1p0.pyi

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# mypy: disable-error-code="no-redef"
22

33
from types import ModuleType
4-
from typing import TypeAlias
4+
from typing import Any
55

66
import numpy.array_api as np # type: ignore[import-not-found, unused-ignore]
77

@@ -29,8 +29,23 @@ ns: ModuleType = a_ns.__array_namespace__()
2929
# backpropagated to the type of `a_ns`
3030
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3131

32+
# =========================================================
33+
# `xpt.HasDType`
34+
35+
# Check DTypeT_co assignment
36+
_: xpt.HasDType[Any] = nparr
37+
_: xpt.HasDType[np.dtype[np.int32]] = nparr_i32
38+
_: xpt.HasDType[np.dtype[np.float32]] = nparr_f32
39+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
40+
3241
# =========================================================
3342
# `xpt.Array`
3443

3544
# Check NamespaceT_co assignment
36-
a_ns: xpt.Array[ModuleType] = nparr
45+
a_ns: xpt.Array[Any, ModuleType] = nparr
46+
47+
# Check DTypeT_co assignment
48+
_: xpt.Array[Any] = nparr
49+
_: xpt.Array[np.dtype[np.int32]] = nparr_i32
50+
_: xpt.Array[np.dtype[np.float32]] = nparr_f32
51+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

tests/integration/test_numpy2p0.pyi

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import numpy.typing as npt
99
import array_api_typing as xpt
1010

1111
# DType aliases
12+
F: TypeAlias = np.floating[Any]
1213
F32: TypeAlias = np.float32
14+
I: TypeAlias = np.integer[Any]
1315
I32: TypeAlias = np.int32
1416

1517
# Define NDArrays against which we can test the protocols
@@ -35,8 +37,23 @@ ns: ModuleType = a_ns.__array_namespace__()
3537
# backpropagated to the type of `a_ns`
3638
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught
3739

40+
# =========================================================
41+
# `xpt.HasDType`
42+
43+
# Check DTypeT_co assignment
44+
_: xpt.HasDType[Any] = nparr
45+
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
46+
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
47+
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
48+
3849
# =========================================================
3950
# `xpt.Array`
4051

4152
# Check NamespaceT_co assignment
42-
a_ns: xpt.Array[ModuleType] = nparr
53+
a_ns: xpt.Array[Any, ModuleType] = nparr
54+
55+
# Check DTypeT_co assignment
56+
_: xpt.Array[Any] = nparr
57+
_: xpt.Array[np.dtype[I32]] = nparr_i32
58+
_: xpt.Array[np.dtype[F32]] = nparr_f32
59+
_: xpt.Array[np.dtype[np.bool_]] = nparr_b

0 commit comments

Comments
 (0)