Source code for galtab.galtab

from copy import copy
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd

# from astropy import cosmology
# import halotools.empirical_models as htem
import halotools.mock_observables as htmo
import halotools.sim_manager as htsm

from . import _galaxy_tabulator as gt
from . import moments


[docs] class GalaxyTabulator: """ This object populates placeholder galaxies according to a fiducial halo occupation model. Then, given any occupation model, each placeholder which will be assigned a "weight" according to the mean occupation of """
[docs] def __init__(self, halocat, fiducial_model, n_mc=10, min_quant=1e-4, max_weight=0.05, sample_fraction=1.0, num_ptcl_requirement=None, seed=None, cosmo=None, sat_quant_instead_of_max_weight=False): """ Parameters ---------- halocat : htsm.CachedHaloCatalog Catalog containing the halos to populate galaxies on top of fiducial_model : htem.ModelFactory Used to calculate the number of placeholder galaxies per halo n_mc : int (default = 10) Number of Monte-Carlo realizations to combine min_quant : float (default = 0.001) The minimum quantile of galaxies as a function of the model's prim_haloprop to populate at least one placeholder in such a halo max_weight : float (default = 0.01) The quantile of the Poisson distribution centered around <N_sat> to use as the number of satellite placeholders per halo sample_fraction : float (default = 1.0) The fraction of galaxies to keep - useful for modeling surveys whose targeting fraction is <100%, but not spatially correlated num_ptcl_requirement : Optional[int] Passed to the initial call of model.populate_mock() seed : Optional[int] Monte-Carlo realization seeds; also passed to model.populate_mock() cosmo : Optional[astropy.Cosmology object] Used for redshift-space distortions; default taken from halocat sat_quant_instead_of_max_weight : Optional[bool] If true, max_weight is interpretted as 1 - sat_quant, where sat_quant specifies the number of placeholder satellites by the quantile of the (Poisson) occupation distribution of each halo Examples -------- Choose HOD model and load halos >>> hod = halotools.empirical_models.PrebuiltHodModelFactory("zheng07") >>> halocat = halotools.sim_manager.CachedHaloCatalog(simname="bolplanck") Instantiate the tabulators >>> gtab = GalaxyTabulator(halocat, hod) >>> cictab = CICTabulator( ... gtab, proj_search_radius=2.0, ... cylinder_half_length=10.0, bin_edges=np.arange(-0.5, 16)) Update HOD parameters to your liking and perform CiC prediction >>> hod.param_dict.update({}) >>> cictab.predict(hod) """ self.halocat = copy(halocat) self.fiducial_model = copy(fiducial_model) self.n_mc = n_mc self.min_quant = min_quant self.max_weight = max_weight self.sample_fraction = sample_fraction self.sat_quant_instead_of_max_weight = sat_quant_instead_of_max_weight if num_ptcl_requirement is None: self.num_ptcl_requirement = htsm.sim_defaults.Num_ptcl_requirement else: self.num_ptcl_requirement = num_ptcl_requirement self.seed = seed if cosmo is None: if hasattr(halocat, "cosmology"): cosmo = halocat.cosmology else: raise ValueError( "There is no halocat.cosmology - please specify it manually") self.cosmo = cosmo self.predictor = None self.weights = None # Populate_mock is my lazy way of making sure the halo_table # doesn't include subhalos if the model doesn't need them fiducial_model.populate_mock( halocat, seed=seed, Num_ptcl_requirement=self.num_ptcl_requirement) self.halo_table = fiducial_model.mock.halo_table self.galaxies, self._placeholder_model, self.halocat = \ self.populate_placeholders() self.halo_table = self._placeholder_model.mock.halo_table self.halo_inds = self.tabulate_halo_inds() # Unassign the halotools models to preserve pickle-ability self.fiducial_model = None self._placeholder_model = None
def populate_placeholders(self): return gt.make_placeholder_model(self) def calc_weights(self, model): self.weights = gt.calc_weights( self.halo_table, self.galaxies, self.halo_inds, model) * self.sample_fraction return self.weights
[docs] def tabulate_cic(self, **kwargs): # TODO: Replace kwargs with actual arguments for tab-complete-ability self.predictor = CICTabulator(self, **kwargs) return self.predictor
def predict(self, model, *args, **kwargs): if self.predictor is None: raise RuntimeError("You must tabulate a statistic " "before predicting it") return self.predictor.predict(model, *args, **kwargs) def tabulate_halo_inds(self): gal_df = pd.DataFrame(dict(halo_id=self.galaxies["halo_id"])) halo_df = pd.DataFrame(dict(halo_id=self.halo_table["halo_id"], halo_ind=np.arange(len(self.halo_table)))) return pd.merge(gal_df, halo_df, on="halo_id", how="left")["halo_ind"].values
[docs] class CICTabulator:
[docs] def __init__(self, galtabulator, proj_search_radius, cylinder_half_length, k_vals=None, bin_edges=None, sample1_selector=None, sample2_selector=None, analytic_moments=True, sort_tabulated_indices=False, max_ncic=int(1e5), seed=None, **kwargs): """ Initialize a CICTabulator This object tabulates the cylinder counts of each placeholder galaxy to quickly, deterministically, and differentiably predict dP(N_CIC)/dN_CIC for any given occupation model. Parameters ---------- galtabulator : GalaxyTabulator Object containing the tabulated placeholder galaxies proj_search_radius : float Perpendicular radius of circle in which to count pairs cylinder_half_length : float Half length of cylinder in which to count pairs k_vals : Optional[np.ndarray] Array of moment numbers (i.e. [1, 2] for mean, std) bin_edges : Optional[np.ndarray] Bin edges of P(Ncic); ignored if k_vals are provided sample1_selector : Optional[callable] Not implemented sample2_selector : Optional[callable] Not implemented analytic_moments : Optional[bool] Less noisy than Monte-Carlo approximation and quicker if only calculating a few k_vals sort_tabulated_indices : Optional[bool] This will cause longer tabulation time, but shorter subsequent prediction calls max_ncic : Optional[int] If any galaxies have more Ncic than this, return Ncic=0 seed : Optional[int] Random seed to produce reproducible results, even if not using analytic_moments=True Note: Remaining keyword arguments are passed to halotools' counts-in- cylinders function. """ assert sample1_selector is None, "Parameter not implemented" assert sample2_selector is None, "Parameter not implemented" self.galtabulator = galtabulator self.k_vals = k_vals self.kmax = None if k_vals is not None: self.k_vals = np.array(k_vals, dtype=int) self.kmax = np.max(k_vals, initial=0) assert np.min(k_vals, initial=1) > 0, "Lowest allowed k_val is 1" assert np.all(self.k_vals == k_vals), "k_vals must be list of ints" self.bin_edges = bin_edges self.sample1_selector = sample1_selector self.sample2_selector = sample2_selector self.analytic_moments = analytic_moments self.sort_tabulated_indices = sort_tabulated_indices self.max_ncic = max_ncic self.rs = np.random.RandomState(seed) self.sample1_inds = slice(None) self.sample2_inds = slice(None) if sample1_selector is not None: self.sample1_inds = np.where( sample1_selector(galtabulator.galaxies))[0] if sample2_selector is not None: self.sample2_inds = np.where( sample2_selector(galtabulator.galaxies))[0] self.sample1 = galtabulator.galaxies[self.sample1_inds] self.sample2 = galtabulator.galaxies[self.sample2_inds] self.n = len(galtabulator.galaxies) self.n1 = len(self.sample1) self.n2 = len(self.sample2) self.n_mc = galtabulator.n_mc self.mc_rands = self.seed_monte_carlo() self.indices = None self.proj_search_radius = proj_search_radius self.cylinder_half_length = cylinder_half_length self.tabulate(proj_search_radius=proj_search_radius, cylinder_half_length=cylinder_half_length, **kwargs)
def seed_monte_carlo(self): self.mc_rands = self.rs.random((self.n, self.n_mc)) return self.mc_rands def tabulate(self, **kwargs): if "sample1" in kwargs: raise ValueError("Cannot overwrite 'sample1' kwarg") if "sample2" in kwargs: raise ValueError("Cannot overwrite 'sample2' kwarg") if "period" in kwargs: raise ValueError("Cannot overwrite 'period' kwarg") if "return_indexes" in kwargs: raise ValueError("Cannot overwrite 'return_indexes' kwarg") kwargs["sample1"] = np.array([self.sample1[f"obs_{x}"] for x in "xyz"], dtype=np.float64).T kwargs["sample2"] = np.array([self.sample2[f"obs_{x}"] for x in "xyz"], dtype=np.float64).T kwargs["period"] = self.galtabulator.halocat.Lbox kwargs["return_indexes"] = True self.indices = htmo.counts_in_cylinders(**kwargs)[1] # Remove self-counting! # TODO: Is there a way that allows sample1 and sample2 to differ? self.indices = self.indices[self.indices["i1"] != self.indices["i2"]] if self.sort_tabulated_indices: self.indices.sort(order="i1")
[docs] def predict(self, model, return_number_densities=False, n_mc=None, reseed_mc=False, warn_p_over_1=True): """ Perform tabulation-accelerated prediction Parameters ---------- model : halotools.empirical_models.ModelFactory Halotools model we want to evaluate CiC for return_number_densities : bool [Optional] Return number densities of both samples n1 and n2 n_mc : int [Optional] Number of Monte-Carlo realizations (ignored in analytic moments) reseed_mc : bool [Optional] Reseed the CICTabulator's Monte Carlo realization generator warn_p_over_1 : bool | str [Optional] If true (default), print a warning if any placeholder weights > 1 If string starting with "return", return warn_status, don't print Returns ------- cic : np.ndarray Values of P(Ncic) or moments specified by k_vals [n1] : float Number density of sample1. Only returned if return_number_densities [n2] : float n2 always = n1 for now. Only returned if return_number_densities [warn_status] : dict Dictionary specifying the warning status. Only returned if warn_p_over_1 is a string starting with "return" """ result = self.calc_cic( model, return_number_densities=return_number_densities, n_mc=n_mc, reseed_mc=reseed_mc, warn_p_over_1=warn_p_over_1) n1 = n2 = None if return_number_densities: cic, n1, n2, warn_status = result else: cic, warn_status = result if self.analytic_moments and self.k_vals is not None: pass elif self.k_vals is not None: cic = moments.moments_from_samples( jnp.arange(len(cic)), self.k_vals, weights=cic) elif self.bin_edges is not None: hist = jnp.histogram( jnp.arange(len(cic)), self.bin_edges, weights=cic)[0] cic = hist / hist.sum() / jnp.diff(self.bin_edges) return_warn_status = (hasattr(warn_p_over_1, "lower") and warn_p_over_1.lower().startswith("return")) if return_number_densities: if return_warn_status: return cic, n1, n2, warn_status else: return cic, n1, n2 elif return_warn_status: return cic, warn_status else: return cic
def calc_cic(self, model, return_number_densities=False, n_mc=None, reseed_mc=False, warn_p_over_1=True): if reseed_mc: self.seed_monte_carlo() if n_mc is None: n_mc = self.n_mc mc_rands = self.mc_rands else: mc_rands = self.mc_rands[:, :n_mc] n1 = self.n1 indices = self.indices weights = self.galtabulator.calc_weights(model) previous_ints = jnp.ceil(weights) - 1 # previous_ints[previous_ints < 0] = 0 previous_ints = jnp.where(previous_ints < 0, 0, previous_ints) warn_raised = False n_bad_weights, mean_bad_weight = None, None if warn_p_over_1 and jnp.any(previous_ints): warn_raised = True bad_weights = weights[previous_ints != 0] n_bad_weights = len(bad_weights) mean_bad_weight = bad_weights.mean() if hasattr(warn_p_over_1, "lower") and warn_p_over_1.lower( ).startswith("return"): pass else: jax.debug.print( "WARNING: There are {x} placeholders with weight>1, averaging: {y}", x=n_bad_weights, y=mean_bad_weight) if self.analytic_moments and self.kmax is not None: # It was spending ~99% of computational time in np.add.at # before replacing np.add.at with jax at[].add() pb_cumulants = [] for k in range(1, self.kmax + 1): # Bernoulli cumulant bc = moments.bernoulli_cumulant(weights, k) # Sum of Bernoulli cumulants --> Poisson Binomial cumulants pc = moments.jit_sum_at( bc, indices["i2"], indices["i1"], len_out=n1, ind_out_is_sorted=self.sort_tabulated_indices) # TODO: replace bc with bc[self.sample2_inds] ??? pb_cumulants.append(pc) pb_raw_moments = moments.raw_moments_from_cumulants(pb_cumulants) avg_raw_moments = jnp.average( jnp.array(pb_raw_moments), weights=weights, axis=1) cic = moments.standardized_moments_from_raw_moments( avg_raw_moments) cic = jnp.array([cic[k-1] if k > 0 else 1 for k in self.k_vals]) else: next_int_probs = weights - previous_ints.astype(jnp.float32) mc_num_arrays = previous_ints[:, None] + (mc_rands < next_int_probs[:, None]) ncic_arrays = moments.jit_sum_at( mc_num_arrays, indices["i2"], indices["i1"], len_out=(n1, n_mc), ind_out_is_sorted=self.sort_tabulated_indices) # replace mc_num_arrays with mc_num_arrays[self.sample2_inds] ? numbins = int(jnp.max(ncic_arrays) + 1) if numbins < self.max_ncic: ncic_arrays1 = ncic_arrays.astype(int) + ( numbins * jnp.arange(n_mc)) ncic_hists = jnp.bincount( ncic_arrays1.ravel(), weights=jnp.repeat(weights, n_mc), minlength=numbins * n_mc).reshape(n_mc, -1) p_ncic_configs = ncic_hists / jnp.sum(ncic_hists, axis=1)[:, None] cic = jnp.nanmean(p_ncic_configs, axis=0) else: cic = jnp.array([jnp.nan]) warn_status = {"warn_raised": warn_raised, "n_bad_weights": n_bad_weights, "mean_bad_weight": mean_bad_weight} if return_number_densities: vol = jnp.product(self.galtabulator.halocat.Lbox) weights1 = weights # [self.sample1_inds] n_density1 = jnp.sum(weights1) / vol # weights2 = weights[self.sample2_inds] n_density2 = n_density1 # np.sum(weights2) / vol return cic, n_density1, n_density2, warn_status else: return cic, warn_status def save(self, filename): with open(filename, "wb") as f: pickle.dump(self, f, protocol=4) @classmethod def load(cls, filename): with open(filename, "rb") as f: obj = pickle.load(f) assert isinstance(obj, cls) return obj
GalaxyTabulator.tabulate_cic.__doc__ = CICTabulator.__init__.__doc__