More stable algorithm for variance, standard deviation #456
Conversation
```python
def clip_last(array):
    """Return array except the last element along axis

    Purely included to tidy up the adj_terms line
    """
    not_last = [slice(None, None) for i in range(array.ndim)]
    not_last[axis[0]] = slice(None, -1)
    return array[*not_last]
```
Use `array[..., :-1]` and `array[1:, ...]` instead of these helper functions.
Can you guarantee that `axis` will always be -2? If so, `array[..., :-1, :]` and `array[..., 1:, :]` would work, but I wasn't sure if the assumption was valid or if it had to generalise to other values of `axis`.
oh I see the issue, this is fine then.
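To illustrate the generalisation question above: a slicer for an arbitrary axis can be built from `slice` objects, and for a fixed axis it reduces to the ellipsis form suggested earlier. A sketch (this variant of `clip_last` takes `axis` as an explicit argument, unlike the helper under review):

```python
import numpy as np

def clip_last(array, axis=-2):
    """Drop the last element along `axis`; works for any axis."""
    index = [slice(None)] * array.ndim
    index[axis] = slice(None, -1)
    return array[tuple(index)]

a = np.arange(24).reshape(2, 3, 4)
# For the fixed axis=-2 case, the ellipsis form is equivalent:
assert np.array_equal(clip_last(a, axis=-2), a[..., :-1, :])
```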
flox/aggregations.py
Outdated
```diff
 nanvar = Aggregation(
     "nanvar",
-    chunk=("nansum_of_squares", "nansum", "nanlen"),
-    combine=("sum", "sum", "sum"),
+    chunk=("var_chunk"),
```
```diff
-    chunk=("var_chunk"),
+    chunk=var_chunk,
```
Since you want to refer to the actual function.
```python
        np.sum(sum_X, axis=axis, keepdims=keepdims),  # sum of array items
        np.sum(sum_len, axis=axis, keepdims=keepdims),  # sum of array lengths
    )
)  # I'm not even pretending calling this class from there is a good idea, I think it wants to be somewhere else though
```
This is great! The combine operates repeatedly, so it should return the same thing it accepts: a MultiArray in this case. The aggregate or "finalize" step, on the other hand, only applies once; it takes in the intermediate type and returns the output type.
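The contract described here can be sketched abstractly (hypothetical names, not flox's real signatures): combine maps intermediates to intermediates and may be applied any number of times, while finalize runs once to produce the output type.

```python
from functools import reduce

def combine(a, b):
    # intermediate x intermediate -> intermediate (same type in, same type out),
    # so it can be applied repeatedly in any grouping of chunks
    return (a[0] + b[0], a[1] + b[1])  # e.g. (running sum, running count)

def finalize(intermediate):
    # intermediate -> final output, applied exactly once at the end
    total, count = intermediate
    return total / count

chunks = [(6.0, 3), (4.0, 2)]  # per-chunk (sum, count) intermediates
assert finalize(reduce(combine, chunks)) == 2.0
```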
```python
        new_arrays = []  # I really don't like doing this as a list
        for array in self.arrays:  # Do we care about trying to avoid for loops here? three separate lines would be faster, but harder to read
            new_arrays.append(array.astype(dt, **kwargs))
        return MultiArray(new_arrays)
```
This is fine, though you could:

```diff
-        new_arrays = []  # I really don't like doing this as a list
-        for array in self.arrays:  # Do we care about trying to avoid for loops here? three separate lines would be faster, but harder to read
-            new_arrays.append(array.astype(dt, **kwargs))
-        return MultiArray(new_arrays)
+        return MultiArray(tuple(array.astype(dt, **kwargs) for array in self.arrays))
```
```python
    def __init__(self, arrays):
        self.arrays = arrays  # something else needed here to be more careful about types (not sure what)
        # Do we want to coerce arrays into a tuple and make sure it's immutable? Do we want it to be immutable?
```
this is fine as-is
```python
        return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs)

    # Shape is needed, seems likely that the other two might be
    # Making some strong assumptions here that all the arrays are the same shape, and I don't really like this
```
yeah this data structure isn't useful in general, and is only working around some limitations in the design where we need to pass in multiple intermediates to the combine function. So there will be some ugliness. You have good instincts.
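A minimal sketch of the kind of container being discussed (simplified and hypothetical; the MultiArray in this PR carries more machinery, e.g. `__array_function__` dispatch):

```python
import numpy as np

class MultiArray:
    """Bundle several same-shaped intermediate arrays so they move through
    the chunk/combine steps as a single object. Sketch only."""

    def __init__(self, arrays):
        self.arrays = tuple(arrays)  # coerce to tuple for immutability
        # strong assumption from the comment above: all arrays share a shape
        assert len({a.shape for a in self.arrays}) == 1, "shapes must match"

    @property
    def shape(self):
        return self.arrays[0].shape

    def astype(self, dt, **kwargs):
        return MultiArray(a.astype(dt, **kwargs) for a in self.arrays)

ma = MultiArray([np.zeros((2, 3)), np.ones((2, 3)), np.full((2, 3), 2)])
assert ma.shape == (2, 3)
assert ma.astype(np.float32).arrays[1].dtype == np.float32
```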
flox/aggregate_flox.py
Outdated
```python
    sum_squared_deviations = sum(
        group_idx,
        (array - array_means[..., group_idx]) ** 2,
```
👏 👏🏾
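The step being applauded, subtracting each group's mean before squaring, can be sketched with plain NumPy, using `np.bincount` as a stand-in for flox's grouped sum (names are illustrative):

```python
import numpy as np

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, 3.0, 2.0, 4.0, 6.0])

counts = np.bincount(group_idx)
means = np.bincount(group_idx, weights=array) / counts
# sum of squared deviations per group, each element centred on its own
# group's mean -- numerically better than sum(x**2) - n*mean**2
ssd = np.bincount(group_idx, weights=(array - means[group_idx]) ** 2)
variance = ssd / counts

assert np.allclose(variance, [1.0, np.var([2.0, 4.0, 6.0])])
```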
```diff
@@ -235,7 +235,7 @@ def gen_array_by(size, func):
 @pytest.mark.parametrize("size", [(1, 12), (12,), (12, 9)])
 @pytest.mark.parametrize("nby", [1, 2, 3])
 @pytest.mark.parametrize("add_nan_by", [True, False])
-@pytest.mark.parametrize("func", ALL_FUNCS)
+@pytest.mark.parametrize("func", ["nanvar"])
```
we will revert before merging, but this is the test we need to make work first. It runs a number of complex cases.
```diff
@@ -343,12 +343,106 @@ def _mean_finalize(sum_, count):
     )


+def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=None, dtype=None):
```
I moved this here so that we can generalize to "all" engines. It has some ugliness (notice that it now takes the `engine` kwarg).
```python
    array_sums = generic_aggregate(
        group_idx,
        array,
        func="nansum",
```
This will need to be "sum" for "var".
My first thought is to pass some kind of "are NaNs okay" boolean through to var_chunk and var_combine. Is this what xarray's `skipna` does? Or I think I've seen it done as a string, "propagate" or "ignore"? And then to call var_chunk and var_combine as a partial.
Yes, the way I do this in flox is to create `var_chunk = partial(_var_chunk, skipna=False)` and `_nanvar_chunk = partial(_var_chunk, skipna=True)`. You can stick this in the Aggregation constructor, I think.
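The partial-based pattern described above might look like this (a sketch with a simplified `_var_chunk`; the real signature in flox takes more parameters):

```python
from functools import partial
import numpy as np

def _var_chunk(group_idx, array, *, skipna, axis=-1):
    # Illustrative only: pick the NaN-aware or plain reducer up front,
    # so the same body serves both "var" and "nanvar".
    sum_ = np.nansum if skipna else np.sum
    return sum_(array, axis=axis)

var_chunk = partial(_var_chunk, skipna=False)
nanvar_chunk = partial(_var_chunk, skipna=True)

data = np.array([1.0, np.nan, 3.0])
assert np.isnan(var_chunk(None, data))   # plain sum propagates NaN
assert nanvar_chunk(None, data) == 4.0   # nansum ignores it
```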
```diff
@@ -1251,7 +1252,8 @@ def chunk_reduce(
     # optimize that out.
     previous_reduction: T_Func = ""
     for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
-        if empty:
+        # UGLY! but this is because the `var` breaks our design assumptions
+        if empty and reduction is not var_chunk:
```
This code path is an "optimization" for chunks that don't contain any valid groups, so `group_idx` is all -1.

We will need to override `full` in MultiArray. Look up what the `like` kwarg does here; it dispatches to the appropriate array type.

The next issue will be that fill_value is a scalar like `np.nan`, but that doesn't work for all our intermediates (e.g. the "count").
- My first thought is that `MultiArray` will need to track a default fill_value per array. For `var`, this can be initialized to `(None, None, 0)`. If `None`, we use the `fill_value` passed in; else the default.
- The other way would be to hardcode some behaviour in `_initialize_aggregation` so that `agg.fill_value["intermediate"] = ((fill_value, fill_value, 0),)`, and then multi-array can receive that tuple and do the "right thing".
The other place this will matter is in `reindex_numpy`, which is executed at the combine step. I suspect the second tuple approach is the best.

This bit is hairy, and ill-defined. Let me know if you want me to work through it.
This is great progress! Now we reach some much harder parts. I pushed a commit to show where I think the "chunk" function should go and left a few comments. I think the next steps should be to
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
for more information, see https://pre-commit.ci
Updated the algorithm for nanvar to use an adapted version of the Schubert and Gertz (2018) approach mentioned in #386, following discussion in #422.
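For context, the core of the Schubert and Gertz (2018) approach is a numerically stable pairwise merge of per-chunk statistics: count, mean, and M2 (the sum of squared deviations from the chunk's own mean). A minimal sketch of the merge rule (illustrative names, scalar groups only):

```python
import numpy as np

def merge_stats(n_a, mean_a, M2_a, n_b, mean_b, M2_b):
    """Combine (count, mean, M2) from two chunks without the catastrophic
    cancellation of the sum-of-squares formula; variance = M2 / n."""
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    M2 = M2_a + M2_b + delta**2 * n_a * n_b / n
    return n, mean, M2

data = np.array([1.0, 2.0, 4.0, 8.0])
left, right = data[:2], data[2:]
stats = merge_stats(
    left.size, left.mean(), ((left - left.mean()) ** 2).sum(),
    right.size, right.mean(), ((right - right.mean()) ** 2).sum(),
)
assert np.isclose(stats[2] / stats[0], data.var())
```

Because the merge is associative, it can be applied repeatedly at the combine step in whatever order the chunk tree dictates.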