Source code for galtab.jaxhalotools.empirical_models

import jax
import jax.scipy
from jax import numpy as jnp

import halotools.empirical_models as htem


[docs] class JaxZheng07Cens(htem.Zheng07Cens): def mean_occupation(self, **kwargs): # Retrieve the array storing the mass-like variable if 'table' in list(kwargs.keys()): mass = kwargs['table'][self.prim_haloprop_key] elif 'prim_haloprop' in list(kwargs.keys()): mass = jnp.atleast_1d(kwargs['prim_haloprop']) else: msg = ("\nYou must pass either a ``table`` or ``prim_haloprop`` argument \n" "to the ``mean_occupation`` function of the ``Zheng07Cens`` class.\n") raise htem.HalotoolsError(msg) return zheng07_cenocc(mass, self.param_dict["logMmin"], self.param_dict["sigma_logM"])
[docs] class JaxZheng07Sats(htem.Zheng07Sats): def mean_occupation(self, **kwargs): if self.modulate_with_cenocc: for key, value in list(self.param_dict.items()): if key in self.central_occupation_model.param_dict: self.central_occupation_model.param_dict[key] = value # Retrieve the array storing the mass-like variable if 'table' in list(kwargs.keys()): mass = kwargs['table'][self.prim_haloprop_key] elif 'prim_haloprop' in list(kwargs.keys()): mass = jnp.atleast_1d(kwargs['prim_haloprop']) else: msg = ("\nYou must pass either a ``table`` or ``prim_haloprop`` argument \n" "to the ``mean_occupation`` function of the ``Zheng07Sats`` class.\n") raise htem.HalotoolsError(msg) mean_nsat = zheng07_satocc(mass, self.param_dict["logM0"], self.param_dict["logM1"], self.param_dict["alpha"]) # If a central occupation model was passed to the constructor, # multiply mean_nsat by an overall factor of mean_ncen if self.modulate_with_cenocc: # compatible with AB models mean_ncen = getattr(self.central_occupation_model, "baseline_mean_occupation", self.central_occupation_model.mean_occupation)(**kwargs) mean_nsat *= mean_ncen return mean_nsat
[docs] def vectorized_cond(pred, true_fun, false_fun, operand, safe_operand_value=0): # Taken from https://github.com/google/jax/issues/1052 # ==================================================== # true_fun and false_fun must act elementwise (i.e. be vectorized) true_op = jnp.where(pred, operand, safe_operand_value) false_op = jnp.where(pred, safe_operand_value, operand) return jnp.where(pred, true_fun(true_op), false_fun(false_op))
@jax.jit def zheng07_cenocc(mass, logmmin, sigma_logm): logm = jnp.log10(mass) return 0.5 * (1.0 + jax.scipy.special.erf((logm - logmmin) / sigma_logm)) @jax.jit def zheng07_satocc(mass, logm0, logm1, alpha): m0 = 10. ** logm0 m1 = 10. ** logm1 is_nonzero = mass > m0 def nonzero_func(x): return ((x - m0) / m1) ** alpha def zero_func(x): return 0 mean_nsat = vectorized_cond(is_nonzero, nonzero_func, zero_func, mass, safe_operand_value=m0 + m1) return mean_nsat