Skip to content

More stable algorithm for variance, standard deviation #456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 32 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0f29529
update to nanvar to use more stable algorithm if engine is flox
jemmajeffree Jul 18, 2025
1fbf5f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2025
322f511
[revert] only nanvar test
dcherian Jul 18, 2025
adab8e6
Some mods
dcherian Jul 18, 2025
93cd9b3
Update flox/aggregations.py to neater tuple unpacking
jemmajeffree Jul 21, 2025
2be4f74
Change np.all to all in flox/aggregate_flox.py
jemmajeffree Jul 21, 2025
edb655d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2025
dd2e4b6
delete some resolved comments
jemmajeffree Jul 21, 2025
936ed1d
Remove answered questions in comments
jemmajeffree Jul 21, 2025
1968870
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2025
d036ebc
Merge branch 'main' into var_algorithm
jemmajeffree Jul 21, 2025
12bcb0f
Remove more unnecessary comments
jemmajeffree Jul 21, 2025
6f5bece
Merge branch 'var_algorithm' of github.com:jemmajeffree/flox into var…
jemmajeffree Jul 21, 2025
b1f7b5d
Remove _version.py
jemmajeffree Jul 21, 2025
cd9a8b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2025
27448e4
Add preliminary test for std/var precision
jemmajeffree Jul 31, 2025
10214cc
Merge branch 'var_algorithm' of github.com:jemmajeffree/flox into var…
jemmajeffree Jul 31, 2025
a81b1a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2025
004fddc
Correct comment
jemmajeffree Jul 31, 2025
4491ce9
fix merge conflicts
jemmajeffree Jul 31, 2025
c3a6d88
Update flox/aggregate_flox.py
jemmajeffree Aug 5, 2025
4dcd7c2
Replace some list comprehension with tuple
jemmajeffree Aug 5, 2025
c101a2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2025
98e1b4e
Fixes
dcherian Aug 5, 2025
d0d09df
minor edit for neater test reports.
dcherian Aug 5, 2025
1139a9c
Fix another list/tuple comprehension
jemmajeffree Aug 5, 2025
569629c
implement np.full
jemmajeffree Aug 5, 2025
50ad095
Implement np.full and empty chunks in var_chunk
jemmajeffree Aug 6, 2025
f88e231
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2025
77526fd
update comment
jemmajeffree Aug 6, 2025
0f5d587
Fix merge conflict
jemmajeffree Aug 6, 2025
31f30c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions flox/aggregate_flox.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,115 @@
from functools import partial
from typing import Self

import numpy as np

from . import xrdtypes as dtypes
from .xrutils import is_scalar, isnull, notnull

MULTIARRAY_HANDLED_FUNCTIONS = {}


class MultiArray:
arrays: tuple[np.ndarray, ...]

def __init__(self, arrays):
self.arrays = arrays # something else needed here to be more careful about types (not sure what)
# Do we want to co-erce arrays into a tuple and make sure it's immutable? Do we want it to be immutable?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is fine as-is

assert all(arrays[0].shape == a.shape for a in arrays), "Expect all arrays to have the same shape"

def astype(self, dt, **kwargs):
return MultiArray(tuple(array.astype(dt, **kwargs) for array in self.arrays))

def reshape(self, shape, **kwargs):
return MultiArray(tuple(array.reshape(shape, **kwargs) for array in self.arrays))

def squeeze(self, axis=None):
return MultiArray(tuple(array.squeeze(axis) for array in self.arrays))

def __array_function__(self, func, types, args, kwargs):
if func not in MULTIARRAY_HANDLED_FUNCTIONS:
return NotImplemented
# Note: this allows subclasses that don't override
# __array_function__ to handle MyArray objects
# if not all(issubclass(t, MyArray) for t in types): # I can't see this being relevant at all for this code, but maybe it's safer to leave it in?
# return NotImplemented
return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs)

# Shape is needed, seems likely that the other two might be
# Making some strong assumptions here that all the arrays are the same shape, and I don't really like this
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this data structure isn't useful in general, and is only working around some limitations in the design where we need to pass in multiple intermediates to the combine function. So there will be some ugliness. You have good instincts.

@property
def dtype(self) -> np.dtype:
return self.arrays[0].dtype

@property
def shape(self) -> tuple[int, ...]:
return self.arrays[0].shape

@property
def ndim(self) -> int:
return self.arrays[0].ndim

def __getitem__(self, key) -> Self:
return type(self)([array[key] for array in self.arrays])


def implements(numpy_function):
    """Decorator registering a MultiArray implementation of ``numpy_function``.

    The wrapped function is stored in MULTIARRAY_HANDLED_FUNCTIONS, which
    MultiArray.__array_function__ consults at dispatch time.
    """

    def register(impl):
        MULTIARRAY_HANDLED_FUNCTIONS[numpy_function] = impl
        return impl

    return register


@implements(np.expand_dims)
def expand_dims_MultiArray(multiarray, axis):
    """Insert a new axis into every member array (member-wise np.expand_dims)."""
    expanded = tuple(np.expand_dims(member, axis) for member in multiarray.arrays)
    return MultiArray(expanded)


@implements(np.concatenate)
def concatenate_MultiArray(multiarrays, axis):
    """Concatenate a sequence of MultiArrays member-wise along ``axis``.

    The i-th member arrays of each input are concatenated together, so every
    input must carry the same number of member arrays.
    """
    n_arrays = len(multiarrays[0].arrays)
    for ma in multiarrays[1:]:
        # Concatenating MultiArrays with different numbers of member arrays
        # has no meaningful interpretation.
        if len(ma.arrays) != n_arrays:
            raise NotImplementedError
    # NOTE(review): member shapes are only validated inside MultiArray's
    # constructor; mismatched shapes across inputs surface as np.concatenate
    # errors here rather than being checked up front.
    return MultiArray(
        tuple(
            np.concatenate(tuple(ma.arrays[i] for ma in multiarrays), axis)
            # BUG FIX: iterate over the number of member arrays, not the ndim
            # of the first input — the two are not in general equal.
            for i in range(n_arrays)
        )
    )


@implements(np.transpose)
def transpose_MultiArray(multiarray, axes):
    """Permute the axes of every member array (member-wise np.transpose)."""
    transposed = tuple(np.transpose(member, axes) for member in multiarray.arrays)
    return MultiArray(transposed)


@implements(np.full)
def full_MultiArray(shape, fill_values, *args, **kwargs):
    """Build a MultiArray of np.full arrays, one per entry of ``fill_values``.

    ``shape`` and every remaining argument are shared across member arrays;
    only the fill value varies.  ``*args``/``**kwargs`` are forwarded
    verbatim to np.full so future numpy signature changes keep working.
    """
    members = tuple(np.full(shape, value, *args, **kwargs) for value in fill_values)
    return MultiArray(members)


def _prepare_for_flox(group_idx, array):
"""
Expand Down
131 changes: 122 additions & 9 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import pandas as pd
import toolz as tlz
from numpy.typing import ArrayLike, DTypeLike

from . import aggregate_flox, aggregate_npg, xrutils
Expand Down Expand Up @@ -343,12 +344,113 @@ def _mean_finalize(sum_, count):
)


def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=None, dtype=None):
    """Per-chunk groupby step for the numerically stable variance algorithm.

    Computes three intermediates per group — the count of non-NaN values,
    their sum, and the sum of squared deviations about the chunk-local group
    mean — and bundles them into a single MultiArray so the combine step can
    merge chunks without catastrophic cancellation.
    """
    from .aggregate_flox import MultiArray

    # fill_value arrives as a 3-tuple (one entry per intermediate); unpack
    # the matching element for each generic_aggregate call.
    counts = generic_aggregate(
        group_idx,
        array,
        func="nanlen",
        engine=engine,
        axis=axis,
        size=size,
        fill_value=fill_value[2],
        dtype=dtype,
    )

    sums = generic_aggregate(
        group_idx,
        array,
        func="nansum",
        engine=engine,
        axis=axis,
        size=size,
        fill_value=fill_value[1],
        dtype=dtype,
    )

    # Sum of squared deviations about each group's chunk-local mean: the
    # leading term of the variance numerator. Counts and sums above feed the
    # cross-chunk adjustment terms in the combine step.
    means = sums / counts
    ssdev = generic_aggregate(
        group_idx,
        (array - means[..., group_idx]) ** 2,
        func="nansum",
        engine=engine,
        axis=axis,
        size=size,
        fill_value=fill_value[0],
        dtype=dtype,
    )

    return MultiArray((ssdev, sums, counts))


def _var_combine(array, axis, keepdims=True):
    """Merge per-chunk variance intermediates along ``axis``.

    ``array`` is a MultiArray holding (sum of squared deviations, sum of
    values, count) per chunk.  Chunks are folded together cumulatively; each
    merge adds an adjustment term to the squared deviations to account for
    the merged chunks having different means (pairwise variance-combination
    update).
    """

    def clip_last(array, n=1):
        """Return array except the last n elements along axis.
        Purely included to tidy up the adj_terms line.
        """
        assert n > 0, "Clipping nothing off the end isn't implemented"
        not_last = [slice(None, None) for i in range(array.ndim)]
        not_last[axis[0]] = slice(None, -n)
        return array[*not_last]

    def clip_first(array, n=1):
        """Return array except the first n elements along axis.
        Purely included to tidy up the adj_terms line.
        """
        not_first = [slice(None, None) for i in range(array.ndim)]
        not_first[axis[0]] = slice(n, None)
        return array[*not_first]

    assert len(axis) == 1, "Assuming that the combine function is only in one direction at once"

    sum_deviations, sum_X, sum_len = array.arrays

    # Running totals let every prefix-of-chunks be merged with the next chunk
    # in one vectorized step.
    cumsum_X = np.cumsum(sum_X, axis=axis[0])
    cumsum_len = np.cumsum(sum_len, axis=axis[0])

    # When one or both chunks being merged are empty, the adjustment term
    # should be zero but its denominator is zero too.  Add 1 to the
    # denominator in exactly those spots and rely on the zeros in the
    # numerator to keep the term zero (verified by the assert below).
    zero_denominator = (clip_last(cumsum_len) == 0) | (clip_first(sum_len) == 0)

    # Adjustment terms that tweak the sum of squared deviations because not
    # every chunk has the same mean.
    adj_terms = (
        clip_last(cumsum_len) * clip_first(sum_X) - clip_first(sum_len) * clip_last(cumsum_X)
    ) ** 2 / (
        clip_last(cumsum_len) * clip_first(sum_len) * (clip_last(cumsum_len) + clip_first(sum_len))
        + zero_denominator.astype(int)
    )

    assert np.all((adj_terms * zero_denominator) == 0), (
        "Instances where we add something to the denominator must come out to zero"
    )

    # NOTE(review): constructing aggregate_flox.MultiArray here is an
    # acknowledged layering wart; the class probably belongs somewhere more
    # neutral.
    return aggregate_flox.MultiArray(
        (
            np.sum(sum_deviations, axis=axis, keepdims=keepdims)
            + np.sum(adj_terms, axis=axis, keepdims=keepdims),  # sum of squared deviations
            np.sum(sum_X, axis=axis, keepdims=keepdims),  # sum of array items
            np.sum(sum_len, axis=axis, keepdims=keepdims),  # sum of array lengths
        )
    )


# TODO: fix this for complex numbers
def _var_finalize(sumsq, sum_, count, ddof=0):
with np.errstate(invalid="ignore", divide="ignore"):
result = (sumsq - (sum_**2 / count)) / (count - ddof)
result[count <= ddof] = np.nan
return result
# def _var_finalize(sumsq, sum_, count, ddof=0):
# with np.errstate(invalid="ignore", divide="ignore"):
# result = (sumsq - (sum_**2 / count)) / (count - ddof)
# result[count <= ddof] = np.nan
# return result


def _var_finalize(multiarray, ddof=0):
return multiarray.arrays[0] / (multiarray.arrays[2] - ddof)


def _std_finalize(sumsq, sum_, count, ddof=0):
Expand All @@ -366,14 +468,25 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
dtypes=(None, None, np.intp),
final_dtype=np.floating,
)
# "nanvar" uses the numerically stable chunk/combine pair
# (var_chunk/_var_combine), which carries its three intermediates
# (sum of squared deviations, sum, count) together in one MultiArray;
# the superseded sum-of-squares configuration has been removed.
nanvar = Aggregation(
    "nanvar",
    chunk=var_chunk,
    # Non-chunked numpy path: run the chunk step, then finalize.
    numpy=tlz.compose(_var_finalize, var_chunk),
    combine=(_var_combine,),
    finalize=_var_finalize,
    # One fill value per intermediate, wrapped in the single-intermediate
    # tuple the Aggregation machinery expects.
    fill_value=((0, 0, 0),),
    final_fill_value=np.nan,
    dtypes=(None,),
    final_dtype=np.floating,
)
std = Aggregation(
Expand Down
10 changes: 9 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_initialize_aggregation,
generic_aggregate,
quantile_new_dims_func,
var_chunk,
)
from .cache import memoize
from .lib import ArrayLayer, dask_array_type, sparse_array_type
Expand Down Expand Up @@ -1288,7 +1289,8 @@ def chunk_reduce(
# optimize that out.
previous_reduction: T_Func = ""
for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
if empty:
# UGLY! but this is because the `var` breaks our design assumptions
if empty and reduction is not var_chunk:
Copy link
Collaborator

@dcherian dcherian Jul 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code path is an "optimization" for chunks that don't contain any valid groups. so group_idx is all -1.
We will need to override full in MultiArray. Look up what the like kwarg does here, it dispatches to the appropriate array type.


The next issue will be that fill_value is a scalar like np.nan but that doesn't work for all our intermediates (e.g. the "count").

  1. My first thought is that MultiArray will need to track a default fill_value per array. For var, this can be initialized to (None, None, 0). If None we use the fill_value passed in; else the default.
  2. The other way would be to hardcode some behaviour in _initialize_aggregation so that agg.fill_value["intermediate"] = ( (fill_value, fill_value, 0), ), and then multi-array can receive that tuple and do the "right thing".

The other place this will matter is in reindex_numpy, which is executed at the combine step. I suspect the second tuple approach is the best.

This bit is hairy, and ill-defined. Let me know if you want me to work through it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm partway through implementing something to work here.

  • How do I trigger this code pathway without brute force overwriting if empty: with if True:
  • When np.full is called, like is a np array not a MultiArray, because it's (I think) the chunk data and bypassing var_chunk (could also be an artefact of the if True override above?). In a pinch, I guess I could add an elif that catches the empty and reduction is var_chunk and co-erce that into a MultiArray, but it's also ugly so I'm hoping you might have better ideas

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking some more, I may have misinterpreted what fill_value is used for. When is it needed for intermediates?

result = np.full(shape=final_array_shape, fill_value=fv, like=array)
elif is_nanlen(reduction) and is_nanlen(previous_reduction):
result = results["intermediates"][-1]
Expand All @@ -1297,6 +1299,12 @@ def chunk_reduce(
kw_func = dict(size=size, dtype=dt, fill_value=fv)
kw_func.update(kw)

# UGLY! but this is because the `var` breaks our design assumptions
if reduction is var_chunk or (
isinstance(reduction, tlz.functoolz.Compose) and reduction.first is var_chunk
):
kw_func.update(engine=engine)

if callable(reduction):
# passing a custom reduction for npg to apply per-group is really slow!
# So this `reduction` has to do the groupby-aggregation
Expand Down
35 changes: 34 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def gen_array_by(size, func):
@pytest.mark.parametrize("size", [(1, 12), (12,), (12, 9)])
@pytest.mark.parametrize("nby", [1, 2, 3])
@pytest.mark.parametrize("add_nan_by", [True, False])
@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize("func", ["nanvar"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we will revert before merging, but this is the test we need to make work first. It runs a number of complex cases.

def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engine):
if ("arg" in func and engine in ["flox", "numbagg"]) or (func in BLOCKWISE_FUNCS and chunks != -1):
pytest.skip()
Expand Down Expand Up @@ -2240,3 +2240,36 @@ def test_sparse_nan_fill_value_reductions(chunks, fill_value, shape, func):
expected = np.expand_dims(npfunc(numpy_array, axis=-1), axis=-1)
actual, *_ = groupby_reduce(array, by, func=func, axis=-1)
assert_equal(actual, expected)


@pytest.mark.parametrize(
    "func", ("nanvar", "var")
)  # Expect to expand to other functions once written: "nanvar" has the updated chunk/combine functions, "var" still uses the old algorithm for now
@pytest.mark.parametrize("engine", ("flox",))  # Expect to expand to other engines once written
@pytest.mark.parametrize(
    "exponent", (10, 12)
)  # The old algorithm should fail by 10**8; the current one should survive 10**12
def test_std_var_precision(func, exponent, engine):
    """Adding a large constant offset must not degrade grouped variance accuracy."""
    n = 1000
    shift = 10**exponent
    data = np.linspace(-1, 1, n)  # zero-mean data
    by = np.arange(n) % 2  # two interleaved groups; worth parametrizing later

    # The same reduction, without and with the large additive offset.
    baseline, _ = groupby_reduce(data, by, engine=engine, func=func)
    shifted, _ = groupby_reduce(data + shift, by, engine=engine, func=func)

    expected = np.concatenate(
        [np.nanvar(data[::2], keepdims=True), np.nanvar(data[1::2], keepdims=True)]
    )
    expected_shifted = np.concatenate(
        [
            np.nanvar(data[::2] + shift, keepdims=True),
            np.nanvar(data[1::2] + shift, keepdims=True),
        ]
    )

    tol = {"rtol": 1e-8, "atol": 1e-10}  # tolerance choice is somewhat arbitrary

    assert_equal(expected, baseline, tol)
    assert_equal(expected_shifted, shifted, tol)
    # Failure threshold in external testing depended on dask chunksize;
    # may deserve closer exploration.
    assert_equal(baseline, shifted, tol)
Loading