# Source code for dataiter.aggregate

# -*- coding: utf-8 -*-

# Copyright (c) 2022 Osmo Salomaa
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

import dataiter
import functools
import numpy as np
import statistics

from collections import Counter
from dataiter import deco
from dataiter import Vector

    assert dataiter.USE_NUMBA
    from numba import njit
    from numba import types
    from numba.extending import overload
except Exception:
    # We need the decorators used to exist to avoid import time errors,
    # but actual calls to the below shouldn't be made (see 'select').
    def dummy_jit(*args, **kwargs):
        def outer_wrapper(function):
            def inner_wrapper(*args, **kwargs):
                print("Using dummy jit, this shouldn't happen")
                return function(*args, **kwargs)
            return inner_wrapper
        return outer_wrapper
    njit = overload = dummy_jit

def composite(function):
    """
    Decorate an aggregation function to validate its first argument.

    The decorated function only accepts a :class:`Vector` or a ``str``
    (column name) as its first argument `x`; anything else raises
    ``TypeError`` before `function` is called.

    ``functools.wraps`` preserves the wrapped function's name and
    docstring, so that ``help()`` and generated API docs show the real
    documentation instead of that of the internal wrapper.
    """
    @functools.wraps(function)
    def wrapper(x, *args, **kwargs):
        if not isinstance(x, (Vector, str)):
            raise TypeError("Expected Vector or str")
        return function(x, *args, **kwargs)
    return wrapper

@composite
def all(x):
    """
    Return whether all elements of `x` evaluate to ``True``.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Elements are converted to boolean before testing.

    Uses ``numpy.all``, see the NumPy documentation for details:
    https://numpy.org/doc/stable/reference/generated/numpy.all.html

    >>> di.all(di.Vector([True, False]))
    False
    >>> di.all(di.Vector([True, True]))
    True
    >>> di.all("x")
    <function all.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        return np.all(x.as_boolean()).item()

    def aggregate(data):
        # nrequired=0: np.all is applied to every group, even empty ones.
        apply = select((generic, generic_numba), data, x)(np.all)
        aggregate.default = True
        column = data[x].as_boolean()
        return apply(column, data._group_,
                     drop_na=False, default=True, nrequired=0)

    aggregate.group_aware = True
    return aggregate
@composite
def any(x):
    """
    Return whether any element of `x` evaluates to ``True``.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Elements are converted to boolean before testing.

    Uses ``numpy.any``, see the NumPy documentation for details:
    https://numpy.org/doc/stable/reference/generated/numpy.any.html

    >>> di.any(di.Vector([False, False]))
    False
    >>> di.any(di.Vector([True, False]))
    True
    >>> di.any("x")
    <function any.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        return np.any(x.as_boolean()).item()

    def aggregate(data):
        # nrequired=0: np.any is applied to every group, even empty ones.
        apply = select((generic, generic_numba), data, x)(np.any)
        aggregate.default = False
        column = data[x].as_boolean()
        return apply(column, data._group_,
                     drop_na=False, default=False, nrequired=0)

    aggregate.group_aware = True
    return aggregate
# @composite is deliberately not applied here, because count must also
# accept being called with no argument at all (x defaults to "").
def count(x="", *, drop_na=False):
    """
    Return the amount of elements in `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Since all columns in a data frame should have the same amount of
    elements (i.e. rows), you can just leave the `x` argument at its
    default blank string, which will give you that row count.

    >>> di.count(di.Vector([1, 2, 3]))
    3
    >>> di.count()
    <function count.<locals>.aggregate at 0x...>
    """
    if isinstance(x, str):
        def aggregate(data):
            # A blank x means plain row counting; use the group column,
            # which the aggregation reads anyway, as the column to count.
            name = x or "_group_"
            f = (generic, generic_numba)
            f = select(f, data, name)(len)
            aggregate.default = 0
            # Coerce to a real bool: with x == "" the short-circuit would
            # otherwise pass the empty string "" on as drop_na, possibly
            # into a Numba-compiled kernel.
            drop = bool(drop_na and x and data[x].is_na().any())
            return f(data[name], data._group_, drop_na=drop,
                     default=0, nrequired=0)
        aggregate.group_aware = True
        return aggregate
    x = handle_na(x, drop_na)
    return len(x)
@composite
def count_unique(x, *, drop_na=False):
    """
    Return the amount of unique elements in `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are included in the count unless `drop_na` is true.

    >>> di.count_unique(di.Vector([1, 2, 2, 3, 3, 3]))
    3
    >>> di.count_unique("x")
    <function count_unique.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        return len(set(handle_na(x, drop_na)))

    def aggregate(data):
        implementations = (count_unique_apply, count_unique_apply_numba)
        apply = select(implementations, data, x)
        aggregate.default = 0
        column = data[x]
        # Skip the NA filtering pass when the column has nothing to drop.
        must_drop = drop_na and column.is_na().any()
        return apply(column, data._group_, drop_na=must_drop)

    aggregate.group_aware = True
    return aggregate
@deco.listify
def count_unique_apply(x, group, drop_na):
    """
    Yield the number of unique elements in each contiguous group of `x`.

    Groups are formed by :func:`yield_groups` from the labels in `group`;
    with `drop_na` true, missing values are removed first.
    """
    for xg in yield_groups(x, group, drop_na):
        yield len(set(xg))


@njit(cache=dataiter.USE_NUMBA_CACHE)
def count_unique_apply_numba(x, group, drop_na):
    """
    Numba-compiled counterpart of `count_unique_apply`.

    NOTE(review): uniqueness here comes from ``np.unique`` rather than
    ``set``; for NaN elements (only kept when drop_na is false) the two may
    count differently depending on the NumPy version - confirm if relevant.
    """
    out = []
    for xg in yield_groups_numba(x, group, drop_na):
        out.append(len(np.unique(xg)))
    return out
def first(x, *, drop_na=False):
    """
    Return the first element of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    This is a shorthand for :func:`nth` with index 0.

    >>> di.first(di.Vector([1, 2, 3]))
    1
    >>> di.first("x")
    <function nth.<locals>.aggregate at 0x...>
    """
    # nth does the input validation and handles both Vector and str.
    return nth(x, index=0, drop_na=drop_na)
@functools.lru_cache(256)
def generic(function, **kwargs):
    """
    Return a group-wise aggregator that applies `function` to each group.

    The returned callable yields ``function(xg, **kwargs)`` for every group
    `xg` with at least `nrequired` elements and `default` for the others.
    Aggregators are cached per (function, kwargs), so `kwargs` must be
    hashable.
    """
    @deco.listify
    def aggregate(x, group, drop_na, default, nrequired):
        for xg in yield_groups(x, group, drop_na):
            yield function(xg, **kwargs) if len(xg) >= nrequired else default
    return aggregate


@functools.lru_cache(256)
def generic_numba(function):
    """
    Return a Numba-compiled counterpart of :func:`generic`.

    No extra keyword arguments can be passed to `function`, which must be
    callable from Numba-compiled code. Aggregators are cached per function.
    """
    @njit(cache=dataiter.USE_NUMBA_CACHE)
    def aggregate(x, group, drop_na, default, nrequired):
        out = []
        for xg in yield_groups_numba(x, group, drop_na):
            out.append(function(xg) if len(xg) >= nrequired else default)
        return out
    return aggregate


def handle_na(x, drop_na):
    """Return `x` without its missing values if `drop_na`, else `x` as is."""
    return x[~x.is_na()] if drop_na else x


def is_na_item_numba(x):
    """
    Return whether scalar `x` is missing, for use inside Numba code.

    This pure-Python stub is not meant to be run; inside Numba-compiled
    code, calls to it use the implementation registered with ``@overload``
    right below.
    """
    raise NotImplementedError


@overload(is_na_item_numba)
def is_na_item_numba_overload(x):
    # "Called at compile-time with the types of the function's runtime arguments."
    # https://numba.readthedocs.io/en/stable/extending/high-level.html
    #
    # Returns the implementation to compile for the given type of x:
    # NaN floats, NaT datetimes and blank strings count as missing,
    # values of any other type are never treated as missing.
    if isinstance(x, types.Float):
        return lambda x: np.isnan(x)
    if isinstance(x, types.NPDatetime):
        return lambda x: np.isnat(x)
    if isinstance(x, types.UnicodeType):
        return lambda x: x == ""
    return lambda x: False


@njit(cache=dataiter.USE_NUMBA_CACHE)
def is_na_numba(x):
    """Return a boolean array marking the missing elements of `x`."""
    na = np.full(len(x), False)
    for i in range(len(x)):
        na[i] = is_na_item_numba(x[i])
    return na
@composite
def last(x, *, drop_na=False):
    """
    Return the last element of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    This is a shorthand for :func:`nth` with index -1.

    >>> di.last(di.Vector([1, 2, 3]))
    3
    >>> di.last("x")
    <function nth.<locals>.aggregate at 0x...>
    """
    # A negative index counts from the end, so -1 picks the last element.
    return nth(x, index=-1, drop_na=drop_na)
@composite
def max(x, *, drop_na=True):
    """
    Return the maximum of elements in `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false; if no elements
    remain, the missing value of `x` is returned.

    >>> di.max(di.Vector([4, 5, 6]))
    6
    >>> di.max("x")
    <function max.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        if len(kept) < 1:
            return kept.na_value
        return np.amax(kept).item()

    def aggregate(data):
        column = data[x]
        apply = select((generic, generic_numba), data, x)(np.amax)
        aggregate.default = column.na_value
        must_drop = drop_na and column.is_na().any()
        # Groups left without elements yield None (nrequired=1).
        return apply(column, data._group_, drop_na=must_drop,
                     default=None, nrequired=1)

    aggregate.group_aware = True
    return aggregate
@composite
def mean(x, *, drop_na=True):
    """
    Return the arithmetic mean of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false; if no elements
    remain, NaN is returned.

    Uses ``numpy.mean``, see the NumPy documentation for details:
    https://numpy.org/doc/stable/reference/generated/numpy.mean.html

    >>> di.mean(di.Vector([1, 2, 10]))
    4.333333333333333
    >>> di.mean("x")
    <function mean.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        if len(kept) < 1:
            return np.nan
        return np.mean(kept).item()

    def aggregate(data):
        apply = select((generic, generic_numba), data, x)(np.mean)
        aggregate.default = np.nan
        column = data[x]
        must_drop = drop_na and column.is_na().any()
        # Groups left without elements yield NaN (nrequired=1).
        return apply(column, data._group_, drop_na=must_drop,
                     default=np.nan, nrequired=1)

    aggregate.group_aware = True
    return aggregate
@composite
def median(x, *, drop_na=True):
    """
    Return the median of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false; if no elements
    remain, NaN is returned.

    Uses ``numpy.median``, see the NumPy documentation for details:
    https://numpy.org/doc/stable/reference/generated/numpy.median.html

    >>> di.median(di.Vector([5, 1, 2]))
    2.0
    >>> di.median("x")
    <function median.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        if len(kept) < 1:
            return np.nan
        return np.median(kept).item()

    def aggregate(data):
        apply = select((generic, generic_numba), data, x)(np.median)
        aggregate.default = np.nan
        column = data[x]
        must_drop = drop_na and column.is_na().any()
        # Groups left without elements yield NaN (nrequired=1).
        return apply(column, data._group_, drop_na=must_drop,
                     default=np.nan, nrequired=1)

    aggregate.group_aware = True
    return aggregate
@composite
def min(x, *, drop_na=True):
    """
    Return the minimum of elements in `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false; if no elements
    remain, the missing value of `x` is returned.

    >>> di.min(di.Vector([4, 5, 6]))
    4
    >>> di.min("x")
    <function min.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        if len(kept) < 1:
            return kept.na_value
        return np.amin(kept).item()

    def aggregate(data):
        column = data[x]
        apply = select((generic, generic_numba), data, x)(np.amin)
        aggregate.default = column.na_value
        must_drop = drop_na and column.is_na().any()
        # Groups left without elements yield None (nrequired=1).
        return apply(column, data._group_, drop_na=must_drop,
                     default=None, nrequired=1)

    aggregate.group_aware = True
    return aggregate
@composite
def mode(x, *, drop_na=True):
    """
    Return the most common value in `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false; if no elements
    remain, the missing value of `x` is returned. Ties are resolved in
    favor of the value encountered first.

    >>> di.mode(di.Vector([1, 2, 2, 3, 3, 3]))
    3
    >>> di.mode("x")
    <function mode.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        return kept.na_value if len(kept) < 1 else mode1(kept)

    def aggregate(data):
        column = data[x]
        apply = select((mode_apply, mode_apply_numba), data, x)
        aggregate.default = column.na_value
        must_drop = drop_na and column.is_na().any()
        return apply(column, data._group_, drop_na=must_drop)

    aggregate.group_aware = True
    return aggregate
@deco.listify
def mode_apply(x, group, drop_na):
    """Yield the mode of each contiguous group of `x`, ``None`` if empty."""
    for xg in yield_groups(x, group, drop_na):
        yield mode1(xg) if len(xg) >= 1 else None


@njit(cache=dataiter.USE_NUMBA_CACHE)
def mode_apply_numba(x, group, drop_na):
    """
    Numba-compiled counterpart of `mode_apply`.

    Occurrences are counted with a pairwise comparison, quadratic in the
    group size. On ties ``np.argmax`` returns the first maximum, i.e. the
    tied value that occurs first in the group wins.
    """
    out = []
    for xg in yield_groups_numba(x, group, drop_na):
        if len(xg) > 0:
            # ng[i] = number of elements in the group equal to xg[i].
            ng = np.full(len(xg), 0)
            for i in range(len(xg)):
                for j in range(len(xg)):
                    if xg[j] == xg[i]:
                        ng[i] += 1
            out.append(xg[np.argmax(ng)])
        else:
            out.append(None)
    return out


def mode1(x):
    """
    Return the single most common element of `x`.

    Expects a non-empty `x` (the callers in this module check that).
    Ties are resolved in favor of the element encountered first.
    """
    try:
        return statistics.mode(x)
    except statistics.StatisticsError:
        # Python < 3.8 with several elements tied for mode.
        # Return the first encountered of the tied elements.
        return Counter(x).most_common(1)[0][0]
@composite
def nth(x, index, *, drop_na=False):
    """
    Return the element of `x` at `index` (zero-based).

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    A negative `index` counts from the end. If there is no element at
    `index`, the missing value of `x` is returned.

    >>> di.nth(di.Vector([1, 2, 3]), 1)
    2
    >>> di.nth("x", 1)
    <function nth.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        try:
            return kept[index].item()
        except IndexError:
            return kept.na_value

    def aggregate(data):
        column = data[x]
        apply = select((nth_apply, nth_apply_numba), data, x)
        aggregate.default = column.na_value
        must_drop = drop_na and column.is_na().any()
        return apply(column, data._group_, index, drop_na=must_drop)

    aggregate.group_aware = True
    return aggregate
@deco.listify
def nth_apply(x, group, index, drop_na):
    """
    Yield the element at `index` of each contiguous group of `x`.

    Groups that have no element at `index` yield ``None``.
    """
    for xg in yield_groups(x, group, drop_na):
        try:
            yield xg[index]
        except IndexError:
            yield None


@njit(cache=dataiter.USE_NUMBA_CACHE)
def nth_apply_numba(x, group, index, drop_na):
    """Numba-compiled counterpart of `nth_apply`."""
    out = []
    for xg in yield_groups_numba(x, group, drop_na):
        # Explicit bounds check, accepting negative indices from the end
        # like Python does, instead of catching IndexError as nth_apply does.
        if 0 <= index < len(xg) or -len(xg) <= index < 0:
            out.append(xg[index])
        else:
            out.append(None)
    return out
@composite
def quantile(x, q, *, drop_na=True):
    """
    Return the `qth` quantile of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Values are converted to float first. Missing values are ignored unless
    `drop_na` is false; if no elements remain, NaN is returned.

    Uses ``numpy.quantile``, see the NumPy documentation for details:
    https://numpy.org/doc/stable/reference/generated/numpy.quantile.html

    >>> di.quantile(di.Vector([1, 5, 6]), 0.5)
    5.0
    >>> di.quantile("x", 0.5)
    <function quantile.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        if len(kept) < 1:
            return np.nan
        return np.quantile(kept.as_float(), q).item()

    def aggregate(data):
        column = data[x]
        apply = select((quantile_apply, quantile_apply_numba), data, x)
        aggregate.default = np.nan
        must_drop = drop_na and column.is_na().any()
        return apply(column.as_float(), data._group_, q, drop_na=must_drop)

    aggregate.group_aware = True
    return aggregate
@deco.listify
def quantile_apply(x, group, q, drop_na):
    """Yield the `q`th quantile of each contiguous group of `x`, NaN if empty."""
    for xg in yield_groups(x, group, drop_na):
        yield np.quantile(xg, q) if len(xg) >= 1 else np.nan


@njit(cache=dataiter.USE_NUMBA_CACHE)
def quantile_apply_numba(x, group, q, drop_na):
    """Numba-compiled counterpart of `quantile_apply`."""
    out = []
    for xg in yield_groups_numba(x, group, drop_na):
        out.append(np.quantile(xg, q) if len(xg) >= 1 else np.nan)
    return out


def select(functions, data, name):
    """
    Choose between a pure-Python and a Numba implementation.

    `functions` is a pair ``(python_impl, numba_impl)``. The result of
    :func:`use_numba` for column `name` of `data` is used as the index into
    it, so a false result picks the first and a true result the second.
    """
    return functions[use_numba(data[name])]
@composite
def std(x, *, ddof=0, drop_na=True):
    """
    Return the standard deviation of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false; if fewer than
    two elements remain, NaN is returned.

    Uses ``numpy.std``, see the NumPy documentation for details:
    https://numpy.org/doc/stable/reference/generated/numpy.std.html

    >>> di.std(di.Vector([3, 6, 7]))
    1.69967...
    >>> di.std("x")
    <function std.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        if len(kept) < 2:
            return np.nan
        return np.std(kept, ddof=ddof).item()

    def aggregate(data):
        # Numba's np.std takes no ddof argument, so only the default
        # ddof=0 may use the Numba implementation.
        apply = (select((generic, generic_numba), data, x)(np.std)
                 if ddof == 0 else generic(np.std, ddof=ddof))
        aggregate.default = np.nan
        column = data[x]
        must_drop = drop_na and column.is_na().any()
        return apply(column, data._group_, drop_na=must_drop,
                     default=np.nan, nrequired=2)

    aggregate.group_aware = True
    return aggregate
@composite
def sum(x, *, drop_na=True):
    """
    Return the sum of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false.

    >>> di.sum(di.Vector([1, 2, 3]))
    6
    >>> di.sum("x")
    <function sum.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        return np.sum(handle_na(x, drop_na)).item()

    def aggregate(data):
        apply = select((generic, generic_numba), data, x)(np.sum)
        aggregate.default = 0
        column = data[x]
        must_drop = drop_na and column.is_na().any()
        # nrequired=0: np.sum is applied to every group, even empty ones.
        return apply(column, data._group_, drop_na=must_drop,
                     default=0, nrequired=0)

    aggregate.group_aware = True
    return aggregate
def use_numba(x):
    """
    Return whether the Numba implementations should be used for `x`.

    Numba is used only when enabled via ``dataiter.USE_NUMBA`` and `x` has
    a boolean, datetime, floating-point or integer dtype. The result is
    always a real ``bool``, since :func:`select` uses it as a tuple index
    (a falsy non-bool flag such as ``None`` would otherwise break that).
    """
    # Numba can't handle all dtypes, use conditionally.
    # Strings are supported, but performance is bad.
    return bool(dataiter.USE_NUMBA) and (
        np.issubdtype(x.dtype, np.bool_) or
        np.issubdtype(x.dtype, np.datetime64) or
        np.issubdtype(x.dtype, np.floating) or
        np.issubdtype(x.dtype, np.integer))
@composite
def var(x, *, ddof=0, drop_na=True):
    """
    Return the variance of `x`.

    If `x` is a string, return a function usable with
    :meth:`.DataFrame.aggregate` that operates group-wise on column `x`.
    Missing values are ignored unless `drop_na` is false; if fewer than
    two elements remain, NaN is returned.

    Uses ``numpy.var``, see the NumPy documentation for details:
    https://numpy.org/doc/stable/reference/generated/numpy.var.html

    >>> di.var(di.Vector([3, 6, 7]))
    2.888...
    >>> di.var("x")
    <function var.<locals>.aggregate at 0x...>
    """
    if not isinstance(x, str):
        kept = handle_na(x, drop_na)
        if len(kept) < 2:
            return np.nan
        return np.var(kept, ddof=ddof).item()

    def aggregate(data):
        # Numba's np.var takes no ddof argument, so only the default
        # ddof=0 may use the Numba implementation.
        apply = (select((generic, generic_numba), data, x)(np.var)
                 if ddof == 0 else generic(np.var, ddof=ddof))
        aggregate.default = np.nan
        column = data[x]
        must_drop = drop_na and column.is_na().any()
        return apply(column, data._group_, drop_na=must_drop,
                     default=np.nan, nrequired=2)

    aggregate.group_aware = True
    return aggregate
def yield_groups(x, group, drop_na):
    """
    Yield the slices of `x` whose consecutive labels in `group` are equal.

    `group` is read at the same positions as `x`, so it must be aligned
    with it. If `drop_na` is true, missing values are removed from each
    slice before it is yielded. An empty `x` yields nothing.
    """
    # Groups must be contiguous for this to work!
    i = 0
    n = len(x)
    # j runs one past the last index, so that the final group is emitted too.
    for j in range(1, n + 1):
        # Keep extending the current group while the label stays the same.
        if j < n and group[j] == group[i]:
            continue
        xij = x[i:j]
        if drop_na:
            xij = xij[~xij.is_na()]
        yield xij
        i = j


@njit(cache=dataiter.USE_NUMBA_CACHE)
def yield_groups_numba(x, group, drop_na):
    """
    Numba-compiled counterpart of `yield_groups`.

    Despite the name, this returns a list of the group slices rather than
    being a generator.
    """
    # Groups must be contiguous for this to work!
    i = 0
    n = len(x)
    out = []
    for j in range(1, n + 1):
        if j < n and group[j] == group[i]:
            continue
        xij = x[i:j]
        if drop_na:
            xij = xij[~is_na_numba(xij)]
        out.append(xij)
        i = j
    return out