Quickstart Tutorial

galtab is a general approach for calculating the expectation value of counts-in-cells statistics for a given halo catalog and HOD model. It pretabulates placeholder galaxies inside each halo to yield rapid, deterministic results, which is ideal for MCMC likelihood evaluations.

This tutorial will demonstrate some basic Counts-in-Cylinders (CiC) calculations using the intended galtab workflow.

To cite galtab, learn more implementation details, and explore an example science use case, check out https://arxiv.org/abs/2309.08675.

Prerequisites

All of the following are pip-installable:

  • galtab

    • numpy

    • jax

    • astropy

    • halotools

  • matplotlib

  • jupyterlab

After installing the above and downloading the bolplanck z=0 halotools catalog, you should be able to run the following cell. In this cell, we:

  • set our cosmology and CiC parameters

  • choose an HOD model

  • load the simulation data

[1]:
import numpy as np
import matplotlib.pyplot as plt
import jax
from jax import numpy as jnp

from astropy import cosmology
import halotools.empirical_models as htem
import halotools.sim_manager as htsm
import halotools.mock_observables as htmo

import galtab

# Set our CiC parameters (all lengths are in Mpc/h)
proj_search_radius = 2.0
cylinder_half_length = 10.0
cic_edges = np.arange(-0.5, 16)

# Set our cosmology and HOD model
cosmo = cosmology.Planck13
hod = htem.PrebuiltHodModelFactory("zheng07", threshold=-21)

# Load Bolshoi-Planck simulation halos at z=0
halocat = htsm.CachedHaloCatalog(simname="bolplanck", redshift=0)
halocat.halo_table[:5]
[1]:
Table length=5
Columns (54): halo_vmax_firstacc, halo_dmvir_dt_tdyn, halo_macc, halo_scale_factor, halo_vmax_mpeak, halo_m_pe_behroozi, halo_delta_vmax_behroozi17, halo_xoff, halo_spin, halo_tidal_force, halo_scale_factor_firstacc, halo_c_to_a, halo_mvir_firstacc, halo_scale_factor_last_mm, halo_tidal_id, halo_scale_factor_mpeak, halo_pid, halo_m500c, halo_id, halo_halfmass_scale_factor, halo_upid, halo_t_by_u, halo_rvir, halo_vpeak, halo_dmvir_dt_100myr, halo_mpeak, halo_m_pe_diemer, halo_jx, halo_jy, halo_jz, halo_m2500c, halo_mvir, halo_voff, halo_axisA_z, halo_axisA_x, halo_axisA_y, halo_y, halo_b_to_a, halo_x, halo_z, halo_m200b, halo_vacc, halo_scale_factor_lastacc, halo_vmax, halo_m200c, halo_vx, halo_vy, halo_vz, halo_dmvir_dt_inst, halo_tidal_force_tdyn, halo_rs, halo_nfw_conc, halo_hostid, halo_mvir_host_halo
Dtypes: float32 for every column except halo_tidal_id, halo_pid, halo_id, halo_upid, and halo_hostid, which are int64.
(Row values omitted.)

Calculate CiC the standard way with halotools

  • Populate the halocat with galaxies probabilistically from the HOD model

  • Compute the number of neighbors within a cylinder around each galaxy

  • Tally up a histogram of the neighbor counts for a given set of CiC bins

[2]:
# Choose your HOD parameters (in this case, we will keep them the same)
hod.param_dict.update({})

# Populate model galaxies and get their Cartesian coordinates
hod.populate_mock(halocat)
galaxies = hod.mock.galaxy_table
xyz = htmo.return_xyz_formatted_array(
    galaxies["x"], galaxies["y"], galaxies["z"], velocity=galaxies["vz"],
    velocity_distortion_dimension="z", period=halocat.Lbox, cosmology=cosmo
)

# Compute CiC (self-counting subtracted by the `-1`)
cic_counts = htmo.counts_in_cylinders(
    xyz, xyz, proj_search_radius, cylinder_half_length) - 1
cic_halotools = np.histogram(cic_counts, bins=cic_edges, density=True)[0]
cic_halotools
[2]:
array([0.3675878 , 0.25521013, 0.15238918, 0.08963731, 0.05152562,
       0.03218192, 0.0193437 , 0.01162925, 0.00420265, 0.00529649,
       0.00362694, 0.00270581, 0.00166955, 0.00149683, 0.00080599,
       0.00069085])
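
To make the neighbor-counting and histogram steps concrete, here is a minimal brute-force sketch in plain NumPy. It is not how halotools implements counts_in_cylinders; the function name counts_in_cylinders_bruteforce, the toy box size, and the random points are all assumptions for illustration. It assumes cylinders aligned with the z-axis, a periodic box, and an inclusive boundary.

```python
import numpy as np

def counts_in_cylinders_bruteforce(xyz, proj_radius, half_length, lbox):
    """O(N^2) neighbor counts in z-aligned cylinders (hypothetical helper).

    Each point counts itself, so subtract 1 to remove self-counting.
    """
    # Pairwise absolute separations along each axis, shape (N, N, 3)
    delta = np.abs(xyz[:, None, :] - xyz[None, :, :])
    # Minimum-image convention for a periodic box of side lbox
    delta = np.minimum(delta, lbox - delta)
    # Inside the projected (x, y) radius and within the half-length along z
    in_radius = delta[..., 0] ** 2 + delta[..., 1] ** 2 <= proj_radius ** 2
    in_length = delta[..., 2] <= half_length
    return (in_radius & in_length).sum(axis=1)

# Toy data: 500 uniformly random points in a 50 Mpc/h periodic box
rng = np.random.default_rng(42)
lbox = 50.0
xyz = rng.uniform(0, lbox, size=(500, 3))

# Same CiC parameters and bins as in the tutorial cells
counts = counts_in_cylinders_bruteforce(xyz, 2.0, 10.0, lbox) - 1
cic_edges = np.arange(-0.5, 16)
pdf = np.histogram(counts, bins=cic_edges, density=True)[0]
print(pdf)
```

The pairwise array grows as N², so this is only practical for small samples; it is meant to show the geometry, not to replace the halotools routine.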

Now let’s do it the galtab way

[3]:
# Give the Tabulator the halo catalog and a fiducial HOD model
gtab = galtab.GalaxyTabulator(halocat, hod)

# Prepare the CICTabulator to make predictions
cictab = galtab.CICTabulator(gtab, proj_search_radius, cylinder_half_length,
                            bin_edges=cic_edges)

# Choose your HOD parameters (in this case, we will keep them the same)
hod.param_dict.update({})

# Predict CiC for this model
cic_galtab = cictab.predict(hod)
cic_galtab
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[3]:
DeviceArray([0.35819536, 0.2572313 , 0.15485156, 0.09022042, 0.05297909,
             0.03188442, 0.01996169, 0.0124573 , 0.00806456, 0.00497749,
             0.00318336, 0.00215645, 0.0015641 , 0.00095675, 0.00078415,
             0.00053201], dtype=float32)

Plot the galtab vs. halotools comparison

  • galtab predicts the CiC expectation value (smooth + deterministic)

  • halotools draws a CiC realization (noisy + stochastic)

[4]:
cic_cens = 0.5 * (cic_edges[:-1] + cic_edges[1:])
plt.semilogy(cic_cens, cic_galtab, label="galtab", lw=3)
plt.semilogy(cic_cens, cic_halotools, label="halotools", lw=3, ls="--")
plt.legend(frameon=False)
plt.xlabel("$N_{\\rm CiC}$")
plt.ylabel("$P(N_{\\rm CiC})$")
plt.show()
[Figure: P(N_CiC) versus N_CiC on a logarithmic y-axis, comparing galtab (solid line) with halotools (dashed line)]
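
The difference between an expectation value and a single realization can be illustrated with a toy example that is independent of galtab. The sketch below is not galtab's algorithm. It assumes each of 12 hypothetical placeholder neighbors is present with its own independent probability (the probs array is made up), and compares the exact count distribution with one estimated from random draws.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical occupation probabilities for 12 placeholder neighbors
probs = rng.uniform(0.05, 0.6, size=12)

def exact_count_pmf(probs):
    """Exact P(N) for a sum of independent Bernoulli variables."""
    pmf = np.array([1.0])
    for p in probs:
        # Each placeholder is either absent (1 - p) or present (p)
        pmf = np.convolve(pmf, [1.0 - p, p])
    return pmf

def sampled_count_pmf(probs, n_draws, rng):
    """Noisy estimate of P(N) from n_draws random realizations."""
    present = rng.random((n_draws, probs.size)) < probs
    counts = present.sum(axis=1)
    return np.bincount(counts, minlength=probs.size + 1) / n_draws

pmf_exact = exact_count_pmf(probs)                 # same result on every call
pmf_noisy = sampled_count_pmf(probs, 20000, rng)   # changes with the random draws
print(np.abs(pmf_noisy - pmf_exact).max())
```

The exact distribution is a smooth, repeatable function of the probabilities, which is the property that makes a deterministic prediction convenient inside a likelihood evaluation.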

In Development: Differentiate CiC w.r.t. the HOD parameter \(\log M_{\rm min}\)

galtab is implemented in JAX, so it is portable to GPU and differentiable (in principle), assuming your HOD model is compatible with JAX. Unfortunately, this requires a few modifications to halotools models. For example, let’s use the JaxZheng07Cens and JaxZheng07Sats models, originally implemented for the JaxTabCorr project.

We can construct a composite HOD model with our JAX-compatible mean occupation functions, which we call hod_jax. This model allows us to differentiate cictab.predict with jax.grad.

Note: You shouldn’t try using jax.jit directly on cictab.predict, since it contains some I/O lines that can’t be compiled. Rest assured that the primary expensive computations will automatically compile and run on the GPU if available.
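
As a standalone illustration of what jax.grad does with a JAX-compatible mean occupation function, the sketch below differentiates a Zheng07-style central occupation, averaged over a toy halo sample, with respect to logMmin. It does not use galtab or halotools. The names mean_ncen and log_mass, the evenly spaced halo masses, and the value sigma_logM = 0.39 are assumptions for illustration. The result is compared against a central finite difference as a sanity check.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

# Hypothetical halo sample: log10 masses evenly spaced between 11 and 15
log_mass = jnp.linspace(11.0, 15.0, 401)

def mean_ncen(logMmin, sigma_logM=0.39):
    """Zheng07-style central occupation, averaged over the toy halo sample."""
    ncen = 0.5 * (1.0 + erf((log_mass - logMmin) / sigma_logM))
    return jnp.mean(ncen)

# Automatic derivative with respect to the first argument, logMmin
grad_ncen = jax.grad(mean_ncen)

logMmin = 12.79
eps = 1e-2
autodiff = float(grad_ncen(logMmin))
# Central finite difference for comparison
finite_diff = float(
    (mean_ncen(logMmin + eps) - mean_ncen(logMmin - eps)) / (2 * eps)
)
print(autodiff, finite_diff)
```

Raising logMmin removes centrals from the sample, so the derivative should be negative, and the two estimates should agree closely for a smooth function like this one.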

[5]:
from galtab.jaxhalotools import JaxZheng07Cens, JaxZheng07Sats

# Create JAX-compatible composite HOD model
hod_jax = htem.HodModelFactory(
    centrals_occupation=JaxZheng07Cens(threshold=-21),
    satellites_occupation=JaxZheng07Sats(threshold=-21),
    centrals_profile=htem.TrivialPhaseSpace(),
    satellites_profile=htem.NFWPhaseSpace()
)

# Define a function that predicts P(N_cic = 1)
def calc_cic1(logMmin=12.79):
    hod_jax.param_dict.update({"logMmin": logMmin})
    return cictab.predict(hod_jax, warn_p_over_1=False)[1]

# Define the derivative of calc_cic1
diff_cic1 = jax.grad(calc_cic1)

# Note that we shouldn't make logMmin too much lower than that of our fiducial
# model. If desired, make more conservative choices for the fiducial parameters.
# i.e., low logMmin / logM1 / logM0 values and large sigma_logM values
for logmmin in np.linspace(11.0, 15.0, 20):
    value = calc_cic1(logmmin)
    derivative = diff_cic1(logmmin)

    plt.plot(logmmin, value, "bo")
    plt.quiver(logmmin, value, 1, derivative, angles="xy")

plt.xlabel("$\\log M_{\\rm min}$")
plt.ylabel("$P(N_{\\rm CiC} = 1)$")
plt.show()
[Figure: P(N_CiC = 1) versus log M_min, with blue points for the predicted values and arrows for the jax.grad derivatives]

jax.grad (the arrows in the above plot) isn’t working yet

  • I actually wasn’t expecting the above to work perfectly, because it’s using the Monte-Carlo mode, which isn’t perfectly continuous

  • But analytic mode moment derivatives aren’t working either…

  • TODO: Figure out what’s going wrong