
Suggested change to std/var preprocessing to improve precision #422

@jemmajeffree

Description

Hi,
I've noticed that in some rare cases, groupby with flox can return quite noisy standard deviations. When the mean of an array is much larger than its standard deviation (such as deep ocean salinity, raised here), flox returns noisier values on dask arrays than on loaded numpy arrays. In extreme cases, the standard deviation of a dask array can contain NaNs from square-rooting negative variances.
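
For context, a minimal illustration in plain numpy (not flox code) of the cancellation I believe is at work: the one-pass sum-of-squares variance formula loses essentially all precision once the mean dwarfs the spread, while numpy's two-pass np.var stays stable.

import numpy as np

np.random.seed(1)
x = np.random.uniform(0, 1, 100) / 100 + 1000000

# two-pass formula (what np.var uses): stable, roughly 8e-06 here
print(np.var(x))

# one-pass formula: x**2 ~ 1e12, so float64 keeps almost no digits for
# the ~1e-05 difference; the result is garbage and can even go negative,
# which is where the NaNs appear after the square root
print(np.mean(x**2) - np.mean(x) ** 2)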

I'm guessing it's the same idea as #386, in which case @dcherian has thought about this for much longer than I have. I've looked through the code a little, and could easily have missed something about how this interacts with neighbouring functions, but my thoughts on the likely problem and how it might be addressed are below.

Minimal complete verifiable example:

import numpy as np
import xarray as xr

l = 12000
np.random.seed(1)
# huge mean with relatively small variability
test_data = xr.DataArray(
    np.random.uniform(0, 1, l) / 100 + 1000000, dims=('time',)
).assign_coords({'month': xr.DataArray(np.arange(l) % 12, dims=('time',))})

# with numpy arrays returns reasonable and consistent values
test_data.groupby('month').std('time')
# array([0.00283648, 0.00281895, 0.00287791, 0.00287652, 0.00287337,
#        0.00287037, 0.00289802, 0.00289441, 0.00285839, 0.00296478,
#        0.00284787, 0.00292089])

# using lazy computation/dask
dask_test_data = test_data.chunk({'time':100})
dask_test_data.groupby('month').std('time').load()
# array([0.01118034, 0.01118034, 0.01118034, 0.01581139, 0.        ,
#        0.01581139, 0.01118034, 0.01118034, 0.01118034,        nan,
#               nan, 0.        ])

A functional workaround is to subtract the mean before calculating standard deviation:

(
    dask_test_data.groupby('month') - dask_test_data.groupby('month').mean('time')
).groupby('month').std('time').load()

My understanding is that the difference comes from aggregate_npg.py improving precision by subtracting the first non-NaN element of each group, a preprocessing step that the chunked path in aggregations.py skips. Subtracting the first element is probably not quite as stable as subtracting the mean, but when the standard deviation is small the first element should be very close to the mean anyway, and it might be faster.
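
To illustrate why the shift helps (again just a plain-numpy sketch): variance is shift-invariant, Var(x - c) = Var(x) for any constant c, so subtracting any value close to the data, the first element included, keeps the sum-of-squares formula well away from cancellation.

import numpy as np

np.random.seed(1)
x = np.random.uniform(0, 1, 100) / 100 + 1000000

def onepass_var(a):
    # the sum-of-squares formula used by the chunked map-reduce path
    return np.sum(a**2) / a.size - (np.sum(a) / a.size) ** 2

print(onepass_var(x))         # garbage: catastrophic cancellation
print(onepass_var(x - x[0]))  # agrees with np.var(x), since Var(x - c) == Var(x)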

To improve precision and match the numpy-engine behaviour in aggregate_npg.py, I'd suggest that the flox implementation of nanstd, nanvar, std and var for dask arrays could have a preprocessor that looks something like this:

def var_std_preprocess(array, axis):  # not sure of naming conventions, sorry
    """Subtract the first value along ``axis`` from the whole array,
    to improve the numerical precision of nanstd, nanvar, std and var.

    Adapted from argreduce_preprocess and _var_std_wrapper in aggregate_npg.py.
    """
    # Copied from argreduce_preprocess, but maybe this import
    # shouldn't live inside the function?
    import dask.array

    # NEXT LINE IS PSEUDOCODE; I'm not entirely sure how to apply it lazily.
    # If it costs nothing speed-wise, subtracting the mean would probably be
    # better still; happy to run timing tests on either.
    first_elements = nanfirst(array, axis)

    def subtract_first(array_, first_elements_):
        return array_ - first_elements_

    return dask.array.map_blocks(
        subtract_first,
        array,
        first_elements,
        dtype=array.dtype,
        meta=array._meta,
        name="groupby-var_std-preprocess",
    )

and could be included in the Aggregation definition like so:

nanstd = Aggregation(
    "nanstd",
    preprocess=var_std_preprocess,  # UPDATED LINE
    chunk=("nansum_of_squares", "nansum", "nanlen"),
    combine=("sum", "sum", "sum"),
    finalize=_std_finalize,
    fill_value=0,
    final_fill_value=np.nan,
    dtypes=(None, None, np.intp),
    final_dtype=np.floating,
)

It seems to work if first_elements is naively array[0] in the one-dimensional, no-NaNs case, but I'm not sure how to generalise it and apply nanfirst without the usual layers/wrappers. (aggregate_npg.py uses first = _get_aggregate(engine).aggregate(group_idx, array, func="nanfirst", axis=axis), but I don't think that syntax translates to the flox/dask implementation.) If you can give me a few tips or examples to work from, I'm happy to try to implement this behaviour.
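
For what it's worth, here is one way first_elements could be computed lazily with plain dask slicing. This is only a sketch under assumptions: the helper name is made up, axis is assumed to be a single integer as in the pseudocode above, and it is not nan-aware, so it doesn't resolve the nanfirst question.

import dask.array as da
import numpy as np

def first_along_axis(array, axis):
    # hypothetical helper: lazily take the leading slice along `axis`,
    # keeping the dimension (slice(0, 1), not 0) so the result broadcasts
    # against `array` in the subtraction
    index = [slice(None)] * array.ndim
    index[axis] = slice(0, 1)
    return array[tuple(index)]

# NOT nan-aware: a NaN in the first position would turn the whole shifted
# array into NaNs, so something like nanfirst is still needed for data
# with missing values
x = da.from_array(np.arange(12.0).reshape(3, 4), chunks=2)
print((x - first_along_axis(x, axis=1)).compute())

With a helper like this, the subtraction could perhaps be written as array - first_along_axis(array, axis) and left to dask's elementwise broadcasting, rather than going through map_blocks.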

Happy also to discuss alternatives, or to provide a pull request if that's easier to work with.

This is also my first time reading through the flox code in detail (it's really nicely written and documented, by the way; it was lovely to read), and one of the first times I've interacted with public GitHub repos, so I'd appreciate any feedback or corrections on what's useful to provide when describing issues.
