Status: Open. Label: enhancement (New feature or request)
Description
Problem
We frequently use the type annotation `Scalar = int | float | jax.Array` in the code base. `jax.Array` is included because certain functions receive floats and return zero-dimensional `jax.Array`s.
This is not ideal, because:
- we would like the type checker to be able to distinguish scalars from arrays with at least one dimension, and
- it can be confusing for people reading the code.
A second issue is that `Array` types do not include dtype information, which would be helpful in many cases (e.g., for grids, to distinguish between indices and values). See here.
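A minimal sketch of the dtype problem, assuming only `jax.numpy` (the helper `lookup` and the grids are hypothetical). Both arguments are annotated as plain `jax.Array`, so the signature does not say which one must hold integer indices and which holds float values:

```python
import jax
import jax.numpy as jnp


def lookup(values: jax.Array, indices: jax.Array) -> jax.Array:
    """Hypothetical helper: gather grid `values` at integer `indices`.

    Both parameters have the same annotation, so swapping the arguments
    is not flagged by a type checker; the dtype requirement is implicit.
    """
    return values[indices]


grid_values = jnp.linspace(0.0, 1.0, 5)  # floating-point values on a grid
grid_indices = jnp.array([0, 2, 4])      # integer positions into the grid

print(grid_values.dtype, grid_indices.dtype)
print(lookup(grid_values, grid_indices))
```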
To-dos:
- Create a list of the code blocks that actually require `Scalar` to contain `jax.Array`.
- Add dtype information to `jax.Array` annotations.
Solution ideas
jaxtyping
With jaxtyping, we could then write:

    from jaxtyping import Array, Float, Int
    Scalar = int | float | Float[Array, ""] | Int[Array, ""]
Would this prevent `Scalar` from being used where we really need an array with one or more dimensions?
jaxtyping would also allow annotating arrays with dtype information (e.g., `Int[Array, "n"]` for indices and `Float[Array, "n"]` for values).