"""Source code for galtab.moments: moment and cumulant utilities."""

# Functions which calculate "standardized" moments.
# For simplicity, plain (biased) estimators are used rather than unbiased ones.
import math
from functools import partial

import numpy as np
import jax
from jax import numpy as jnp


@partial(jax.jit, static_argnums=(3, 4))
def jit_sum_at(arr_in, ind_in, ind_out, len_out=None, ind_out_is_sorted=False):
    """
    Scatter-add ``arr_in[ind_in]`` into a fresh zero array at ``ind_out``.

    JIT-compiled counterpart of ``numpy_sum_at``.

    Parameters
    ----------
    arr_in : jnp.ndarray
        Source values.
    ind_in : array-like of int
        Indices into ``arr_in`` selecting the values to accumulate.
    ind_out : array-like of int
        Destination position for each selected value; repeated positions
        are summed (``.at[...].add`` semantics).
    len_out : int, optional
        Length of the output array (static under ``jit``). Defaults to
        ``len(arr_in)``.
    ind_out_is_sorted : bool, optional
        Forwarded to ``.add`` as ``indices_are_sorted`` (static under
        ``jit``).

    Returns
    -------
    jnp.ndarray
        Array of length ``len_out`` with the dtype of ``arr_in``.
    """
    size = len(arr_in) if len_out is None else len_out
    selected = arr_in[ind_in]
    return jnp.zeros(size, dtype=arr_in.dtype).at[ind_out].add(
        selected, indices_are_sorted=ind_out_is_sorted)
def numpy_sum_at(arr_in, ind_in, ind_out, len_out=None):
    """
    NumPy (non-JIT) version of ``jit_sum_at``.

    Parameters
    ----------
    arr_in : np.ndarray
        Source values.
    ind_in : array-like of int
        Indices into ``arr_in`` selecting the values to accumulate.
    ind_out : array-like of int
        Destination position for each selected value.
    len_out : int, optional
        Length of the output array. Defaults to ``len(arr_in)``.

    Returns
    -------
    np.ndarray
        Array of length ``len_out`` with the dtype of ``arr_in``.
    """
    size = len(arr_in) if len_out is None else len_out
    accumulator = np.zeros(size, dtype=arr_in.dtype)
    # np.add.at is unbuffered, so repeated positions in ind_out accumulate
    np.add.at(accumulator, ind_out, arr_in[ind_in])
    return accumulator
def moments_from_samples(samples, k_vals, weights=None):
    """
    Calculate the requested moments of a (optionally weighted) sample.

    1st moment is the mean, 2nd moment is the standard deviation, and for
    k=3+ we compute standardized moments, e.g. skewness, kurtosis, etc.
    Plain (biased) estimators are used: every moment is a weighted average
    over the sample, with the weights normalized to sum to one.

    Parameters
    ----------
    samples : array-like
        Sampled values from which to calculate moments (lists, NumPy and
        JAX arrays are all accepted).
    k_vals : list[int]
        Moments to be calculated, in the order they should be returned.
    weights : array-like, optional
        Weights for each sample; they need not be normalized. If None, all
        samples are equally weighted.

    Returns
    -------
    jnp.array
        Calculated moments, one per entry of ``k_vals`` (empty if
        ``k_vals`` is empty).
    """
    if len(k_vals) == 0:
        return jnp.array([])
    # Convert up front: JAX functions/operators reject plain Python lists
    samples = jnp.asarray(samples)
    kmax = np.max(k_vals)
    weights = jnp.ones_like(samples) if weights is None else jnp.asarray(weights)
    weights = weights / jnp.sum(weights)

    # Only compute the mean / standard deviation when some requested moment
    # needs them; otherwise they are NaN placeholders that are never returned
    mu1 = jnp.sum(samples * weights) if kmax > 0 else jnp.nan
    deviations = samples - mu1
    mu2 = jnp.sum(weights * deviations**2) ** 0.5 if kmax > 1 else jnp.nan

    moments = []
    for k in k_vals:
        if k == 1:
            moments.append(mu1)
        elif k == 2:
            moments.append(mu2)
        else:
            # Standardized moment: kth central moment / std**k
            moments.append(jnp.sum(weights * deviations**k) / mu2**k)
    return jnp.array(moments)
def moments_from_binned_pmf(bin_edges, pmf, k_vals):
    """
    Calculate moments of integers weighted by a binned PMF.

    Every integer lying between ``bin_edges[i]`` and ``bin_edges[i+1]`` is
    assigned the weight ``pmf[i]``, and the moments of that weighted set of
    integers are computed with ``moments_from_samples`` (which normalizes
    the weights, so ``pmf`` need not sum to one).

    NOTE(review): a bin containing several integers gives each of them the
    full ``pmf[i]`` (not ``pmf[i]`` divided among them) -- confirm this is
    the intended convention for wide bins.

    Parameters
    ----------
    bin_edges : array-like
        Non-integer bin edges, ``len(pmf) + 1`` of them. They are sorted
        before use.
    pmf : array-like
        Weight of each bin.
    k_vals : list[int]
        Moments to be calculated (see ``moments_from_samples``).

    Returns
    -------
    jnp.array
        Calculated moments.

    Raises
    ------
    AssertionError
        If any bin edge is integer-valued.
    ValueError
        If ``pmf`` is not 1-D with ``len(bin_edges) - 1`` entries.
    """
    bin_edges = np.sort(bin_edges)
    # NOTE(review): only the edges are sorted, not pmf; callers presumably
    # already pass ascending edges -- confirm.
    ceil_edges = np.ceil(bin_edges).astype(int)
    assert not np.any(bin_edges == ceil_edges), \
        "bin_edges are not allowed to be integers"
    ints = np.arange(ceil_edges[0], ceil_edges[-1])
    # Number of integers contained in each bin
    lengths = np.diff(ceil_edges)

    pmf = jnp.asarray(pmf)
    if pmf.ndim != 1 or pmf.shape[0] != len(lengths):
        raise ValueError(
            f"pmf must be 1-D with len(bin_edges) - 1 = {len(lengths)} "
            f"entries; got shape {pmf.shape}")
    # Give every integer in bin i the weight pmf[i]. jnp.repeat is linear
    # time (unlike concatenating Python lists) and keeps pmf JAX-traceable.
    pdf = jnp.repeat(pmf, lengths)
    return moments_from_samples(ints, k_vals, weights=pdf)
def standardized_moments_from_cumulants(kappa):
    """
    Convert cumulants into standardized moments.

    Parameters
    ----------
    kappa : sequence
        Cumulants ``[kappa_1, kappa_2, ..., kappa_n]``, starting at the 1st.

    Returns
    -------
    list
        ``[mean, std, mu_3 / std**3, ..., mu_n / std**n]`` (same length as
        ``kappa``), as produced by ``standardized_moments_from_raw_moments``.
    """
    m = raw_moments_from_cumulants(kappa)
    return standardized_moments_from_raw_moments(m)


def standardized_moments_from_raw_moments(m):
    """
    Convert raw moments ``E[X^k]`` into standardized moments.

    Parameters
    ----------
    m : sequence
        Raw moments ``[m_1, m_2, ..., m_n]``, starting at the 1st (the 0th
        raw moment, 1, is implied).

    Returns
    -------
    list
        ``[mean, std, mu_3 / std**3, ..., mu_n / std**n]`` where ``mu_k`` is
        the kth central moment.
    """
    # Prepend the 0th raw moment, which is always 1
    m = [1, *m]
    # The 1st central moment is always zero, so it is not calculated; the
    # 1st output entry is instead the mean (the same as the 1st raw moment)
    mu = [m[1]]
    for k in range(2, len(m)):
        mu_k = jnp.zeros_like(m[1])
        # For scalars, accumulate in a plain Python number instead of a JAX
        # scalar. NOTE(review): presumably to keep float64 precision in the
        # cancellation-prone sum below; .tolist() may not work on JAX
        # tracers (e.g. under jax.jit) -- confirm.
        if not np.ndim(mu_k):
            mu_k = mu_k.tolist()
        # Central moment via binomial expansion:
        # mu_k = sum_i C(k, i) * (-1)^(k-i) * m_i * m_1^(k-i)
        for i in range(k+1):
            mu_k += math.comb(k, i) * pow(-1, k-i) * m[i] * pow(m[1], k-i)
        # Standardized 2nd moment is the standard deviation
        if k == 2:
            mu_k = jnp.sqrt(mu_k)
        # Standardized k(>2)th moment is central moment / kth power of std dev
        else:
            assert k > 2
            mu_k = mu_k / jnp.power(mu[1], k)
        mu.append(mu_k)
    return mu


def raw_moments_from_cumulants(kappa):
    """
    Convert cumulants into raw moments ``E[X^n]``.

    Uses the recursion
    ``m_n = kappa_n + sum_{i=1}^{n-1} C(n-1, i-1) * kappa_i * m_{n-i}``.

    Parameters
    ----------
    kappa : sequence
        Cumulants ``[kappa_1, ..., kappa_n]``, starting at the 1st.

    Returns
    -------
    list
        Raw moments ``[m_1, ..., m_n]``.
    """
    # Prepend the 0th cumulant as NaN so that kappa[n] is the nth cumulant
    kappa = [jnp.nan, *kappa]
    # Placeholder for the 0th raw moment so that m[n] is the nth raw moment.
    # It is never read: the i = n term of the recursion (kappa_n * m_0, with
    # m_0 = 1) is supplied by initializing m_n to kappa_n below.
    m = [jnp.nan]
    for n in range(1, len(kappa)):
        m_n = kappa[n]
        # Copy array-valued cumulants into a fresh JAX array; with a NumPy
        # array the += below would modify the caller's kappa in place
        if np.ndim(m_n):
            m_n = jnp.array(m_n, copy=True)
        for i in range(1, n):
            m_n += math.comb(n - 1, i - 1) * kappa[i]*m[n-i]
        m.append(m_n)
    return m[1:]


@partial(jax.jit, static_argnums=(1,))
def bernoulli_cumulant(p, k):
    """
    Return the kth cumulant of a Bernoulli(p) distribution, for 1 <= k <= 10.

    The closed-form polynomials were generated with
    ``BernoulliCumulantGenerator``.

    Parameters
    ----------
    p : float or array
        Success probability.
    k : int
        Cumulant order (static under ``jax.jit``).

    Raises
    ------
    AssertionError
        If ``k < 1``.
    NotImplementedError
        If ``k > 10``.
    """
    assert k > 0, "Start at the 1st cumulant"
    if k == 1:
        ans = p
    elif k == 2:
        ans = p * (1 - p)
    elif k == 3:
        ans = p * (p - 1) * (2 * p - 1)
    elif k == 4:
        ans = p * (-6 * p ** 3 + 12 * p ** 2 - 7 * p + 1)
    elif k == 5:
        ans = p * (24 * p ** 4 - 60 * p ** 3 + 50 * p ** 2 - 15 * p + 1)
    elif k == 6:
        ans = p * (-120 * p ** 5 + 360 * p ** 4 - 390 * p ** 3
                   + 180 * p ** 2 - 31 * p + 1)
    elif k == 7:
        ans = p * (720 * p ** 6 - 2520 * p ** 5 + 3360 * p ** 4
                   - 2100 * p ** 3 + 602 * p ** 2 - 63 * p + 1)
    elif k == 8:
        ans = p * (-5040 * p ** 7 + 20160 * p ** 6 - 31920 * p ** 5
                   + 25200 * p ** 4 - 10206 * p ** 3 + 1932 * p ** 2
                   - 127 * p + 1)
    elif k == 9:
        ans = p * (40320 * p ** 8 - 181440 * p ** 7 + 332640 * p ** 6
                   - 317520 * p ** 5 + 166824 * p ** 4 - 46620 * p ** 3
                   + 6050 * p ** 2 - 255 * p + 1)
    elif k == 10:
        ans = p * (-362880 * p ** 9 + 1814400 * p ** 8 - 3780000 * p ** 7
                   + 4233600 * p ** 6 - 2739240 * p ** 5 + 1020600 * p ** 4
                   - 204630 * p ** 3 + 18660 * p ** 2 - 511 * p + 1)
    else:
        # The snippet uses escaped "\\n" so the printed suggestion is
        # copy-pasteable code rather than being split by real newlines
        raise NotImplementedError(
            f"Analytic expression for k>10 (requested k={k}) not yet "
            "implemented.\n"
            "Run the BernoulliCumulantGenerator to implement higher "
            "cumulants:\n\n"
            "# Example (going up to k=11):\n"
            "gen = BernoulliCumulantGenerator()\n"
            "gen.generate(kmax=11)\n"
            'print(*gen.get_cumulants(), sep="\\n\\n")\n'
        )
    return ans
class BernoulliCumulantGenerator:
    """
    Symbolically derive cumulants of a Bernoulli(p) distribution with SymPy.

    The cumulant generating function is ``K(t) = log(1 - p + p*exp(t))``;
    the kth cumulant is the kth derivative of ``K`` with respect to ``t``,
    evaluated at ``t = 0``. Derivatives are computed on demand and cached in
    ``known_derivatives``.

    Example
    -------
    >>> gen = BernoulliCumulantGenerator()  # doctest: +SKIP
    >>> gen.generate(kmax=11)  # doctest: +SKIP
    >>> print(*gen.get_cumulants(), sep="\\n\\n")  # doctest: +SKIP
    """

    def __init__(self):
        import sympy
        p, t = sympy.symbols("p t")
        self.p, self.t = p, t
        # Cumulant generating function of a Bernoulli(p) random variable
        self.generator = sympy.log(1 - p + p*sympy.exp(t))
        # known_derivatives[k-1] holds the simplified kth derivative of the
        # generator with respect to t. sympy.diff evaluates the derivative
        # explicitly (an unevaluated sympy.Derivative would need .doit()).
        self.known_derivatives = [sympy.diff(self.generator, t).simplify()]

    def calc_next_deriv(self):
        """Differentiate the highest known derivative once more and cache it."""
        import sympy
        deriv = sympy.diff(self.known_derivatives[-1], self.t).simplify()
        self.known_derivatives.append(deriv)

    def generate(self, kmax):
        """Ensure derivatives (hence cumulants) up to order ``kmax`` exist."""
        # range() of a non-positive count is empty: nothing left to do
        for _ in range(kmax - len(self.known_derivatives)):
            self.calc_next_deriv()

    def get_cumulants(self):
        """Return the known cumulants [kappa_1, kappa_2, ...] as expressions in p."""
        return [x.subs(self.t, 0).simplify() for x in self.known_derivatives]
r"""
Precomputed Bernoulli cumulants
===============================
gen = BernoulliCumulantGenerator()
gen.generate(kmax=10)
print(*gen.get_cumulants(), sep="\n\n")
---------------------------------------
p

p*(1 - p)

p*(p - 1)*(2*p - 1)

p*(-6*p**3 + 12*p**2 - 7*p + 1)

p*(24*p**4 - 60*p**3 + 50*p**2 - 15*p + 1)

p*(-120*p**5 + 360*p**4 - 390*p**3 + 180*p**2 - 31*p + 1)

p*(720*p**6 - 2520*p**5 + 3360*p**4 - 2100*p**3 + 602*p**2 - 63*p + 1)

p*(-5040*p**7 + 20160*p**6 - 31920*p**5 + 25200*p**4 - 10206*p**3 + 1932*p**2 - 127*p + 1)

p*(40320*p**8 - 181440*p**7 + 332640*p**6 - 317520*p**5 + 166824*p**4 - 46620*p**3 + 6050*p**2 - 255*p + 1)

p*(-362880*p**9 + 1814400*p**8 - 3780000*p**7 + 4233600*p**6 - 2739240*p**5 + 1020600*p**4 - 204630*p**3 + 18660*p**2 - 511*p + 1)
"""