
ENH: Better type-annotations for scalars & Jax arrays #107

@timmens

Description


Problem

We frequently use the type annotation Scalar = int | float | jax.Array in the codebase. The jax.Array member is included because certain functions receive floats and return zero-dimensional jax.Arrays.

This is not ideal, since

  1. We would like the type-checker to be able to tell the difference between scalars and arrays with at least one dimension, and
  2. It can be confusing for people reading the code.

A second issue is that Array types carry no dtype information, which would be helpful in many cases (e.g., for grids, to distinguish between indices and values). See here.
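To make the dtype point concrete, here is a sketch using only jax.numpy; value_grid, index_grid and the take helper are hypothetical. Both grids share the annotation jax.Array, so nothing in the signature says which argument must hold integer indices.

```python
import jax
import jax.numpy as jnp

# Hypothetical grids: one holds values, the other holds indices into it.
value_grid: jax.Array = jnp.linspace(0.0, 10.0, 5)
index_grid: jax.Array = jnp.arange(5)


def take(values: jax.Array, indices: jax.Array) -> jax.Array:
    # Both parameters share the same annotation, so swapping the arguments
    # is invisible to a type checker even though only one order makes sense.
    return values[indices]


# The dtypes differ, but the annotation jax.Array records none of this.
print(jnp.issubdtype(value_grid.dtype, jnp.floating))  # True
print(jnp.issubdtype(index_grid.dtype, jnp.integer))  # True
print(take(value_grid, index_grid[:2]))
```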

To-dos:

  • Create a list of the code blocks that actually require Scalar to include jax.Array
  • Add dtype information to jax.Array annotations

Solution ideas

jaxtyping

We could then write:

from jaxtyping import Array, Float, Int

# "" is the shape string for a zero-dimensional array
Scalar = int | float | Float[Array, ""] | Int[Array, ""]

This should prevent using Scalar where we really need an array with at least one dimension.

jaxtyping would also allow annotating Arrays with dtype information.

Metadata


Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
