{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quickstart Tutorial" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`galtab` is a general approach for calculating the expectation value of\n", "counts-in-cells statistics for a given halo catalog and HOD model. It pretabulates\n", "placeholder galaxies inside each halo to yield rapid, deterministic results,\n", "which is ideal for MCMC likelihood evaluations.\n", "\n", "This [tutorial](https://github.com/AlanPearl/galtab/blob/main/docs/source/notebooks/intro.ipynb)\n", "will demonstrate some basic Counts-in-Cylinders (CiC) calculations\n", "using the intended `galtab` workflow.\n", "\n", "To cite `galtab`, learn more implementation details, and explore an example science\n", "use case, check out https://arxiv.org/abs/2309.08675." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prerequisites\n", "\n", "All of the following are `pip` installable:\n", "\n", "- `galtab`\n", "  - `numpy`\n", "  - `jax`\n", "  - `astropy`\n", "  - `halotools`\n", "- `matplotlib`\n", "- `jupyterlab`\n", "\n", "After installing the above *and downloading the bolplanck z=0 halotools catalog*,\n", "you should be able to run the following cell. 
In this cell, we:\n", "\n", "- set our cosmology and CiC parameters\n", "- choose an HOD model\n", "- load the simulation data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import jax\n", "\n", "from astropy import cosmology\n", "import halotools.empirical_models as htem\n", "import halotools.sim_manager as htsm\n", "import halotools.mock_observables as htmo\n", "\n", "import galtab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download an example halotools catalog\n", "htsm.DownloadManager().download_processed_halo_table(\n", "    'bolplanck', 'rockstar', 0.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set our CiC parameters (all lengths are in Mpc/h)\n", "proj_search_radius = 2.0\n", "cylinder_half_length = 10.0\n", "cic_edges = np.arange(-0.5, 16)\n", "\n", "# Set our cosmology and HOD model\n", "cosmo = cosmology.Planck13\n", "hod = htem.PrebuiltHodModelFactory(\"zheng07\", threshold=-21)\n", "\n", "# Load Bolshoi-Planck simulation halos at z=0\n", "halocat = htsm.CachedHaloCatalog(simname=\"bolplanck\", redshift=0)\n", "halocat.halo_table[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Calculate CiC the standard way with `halotools`\n", "\n", "- Populate the halocat with galaxies probabilistically from the HOD model\n", "- Compute the number of neighbors within a cylinder around each galaxy\n", "- Tally up a histogram of the neighbor counts for a given set of CiC bins" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Choose your HOD parameters (in this case, we will keep them the same)\n", "hod.param_dict.update({})\n", "\n", "# Populate model galaxies and get their Cartesian coordinates\n", "hod.populate_mock(halocat, seed=0)\n", "galaxies = hod.mock.galaxy_table\n", "xyz = 
htmo.return_xyz_formatted_array(\n", "    galaxies[\"x\"], galaxies[\"y\"], galaxies[\"z\"], velocity=galaxies[\"vz\"],\n", "    velocity_distortion_dimension=\"z\", period=halocat.Lbox, cosmology=cosmo\n", ")\n", "\n", "# Compute CiC (the `-1` removes each galaxy's count of itself)\n", "cic_counts = htmo.counts_in_cylinders(\n", "    xyz, xyz, proj_search_radius, cylinder_half_length,\n", "    period=halocat.Lbox) - 1\n", "cic_halotools = np.histogram(cic_counts, bins=cic_edges, density=True)[0]\n", "cic_halotools" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Now let's do it the `galtab` way" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Give the Tabulator the halo catalog and a fiducial HOD model\n", "gtab = galtab.GalaxyTabulator(halocat, hod)\n", "\n", "# Prepare the CICTabulator to make predictions\n", "cictab = galtab.CICTabulator(gtab, proj_search_radius, cylinder_half_length,\n", "                             bin_edges=cic_edges)\n", "\n", "# Choose your HOD parameters (in this case, we will keep them the same)\n", "hod.param_dict.update({})\n", "\n", "# Predict CiC for this model\n", "cic_galtab = cictab.predict(hod)\n", "cic_galtab" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optionally, write the CiC tabulation to disk for later use:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Pickle the CICTabulator - creates a large file named `cictab.pickle`:\n", "cictab.save(\"cictab.pickle\")\n", "\n", "# And load it back with:\n", "cictab = galtab.CICTabulator.load(\"cictab.pickle\")\n", "gtab = cictab.galtabulator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the `galtab` vs. 
`halotools` comparison\n", "\n", "- `galtab` predicts the CiC expectation value (smooth + deterministic)\n", "- `halotools` draws a CiC realization (noisy + stochastic)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cic_cens = 0.5 * (cic_edges[:-1] + cic_edges[1:])\n", "plt.semilogy(cic_cens, cic_galtab, label=\"galtab\", lw=3)\n", "plt.semilogy(cic_cens, cic_halotools, label=\"halotools\", lw=3, ls=\"--\")\n", "plt.legend(frameon=False)\n", "plt.xlabel(\"$N_{\\\\rm CiC}$\")\n", "plt.ylabel(\"$P(N_{\\\\rm CiC})$\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## **In Development:** Differentiate CiC w.r.t. the HOD parameter $\\log M_{\\rm min}$\n", "\n", "`galtab` is implemented in JAX, so it is portable to GPU and differentiable\n", "(in principle), assuming your HOD model is compatible with JAX. Unfortunately,\n", "this requires a few modifications to `halotools` models. For example, let's\n", "use the `JaxZheng07Cens` and `JaxZheng07Sats` models, originally implemented\n", "for the [JaxTabCorr](https://github.com/AlanPearl/JaxTabCorr) project." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can construct a composite HOD model with our JAX-compatible mean\n", "occupation functions, which we call `hod_jax`. This model allows us to\n", "differentiate `cictab.predict` with `jax.grad`.\n", "\n", "*Note:* You shouldn't try using `jax.jit` directly on `cictab.predict`, since it\n", "contains some lines of code that can't be compiled. Rest assured that the primary\n", "expensive computations will automatically compile and run on the GPU if available." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from galtab.jaxhalotools import JaxZheng07Cens, JaxZheng07Sats\n", "\n", "# Create JAX-compatible composite HOD model\n", "def make_hod_jax():\n", "    return htem.HodModelFactory(\n", "        centrals_occupation=JaxZheng07Cens(threshold=-21),\n", "        satellites_occupation=JaxZheng07Sats(threshold=-21),\n", "        centrals_profile=htem.TrivialPhaseSpace(),\n", "        satellites_profile=htem.NFWPhaseSpace()\n", "    )\n", "\n", "# Define a function that predicts P(N_cic = 1)\n", "def calc_cic1(logMmin=12.79):\n", "    hod_jax = make_hod_jax()\n", "    hod_jax.param_dict.update({\"logMmin\": logMmin})\n", "    return cictab.predict(hod_jax, warn_p_over_1=False)[1]\n", "\n", "# Define the derivative of calc_cic1\n", "diff_cic1 = jax.grad(calc_cic1)\n", "\n", "# Note that we shouldn't make logMmin too much lower than that of our fiducial\n", "# model. If desired, make more conservative choices for the fiducial parameters,\n", "# i.e., low logMmin / logM1 / logM0 values and large sigma_logM values\n", "for logmmin in np.linspace(11.0, 15.0, 20):\n", "    value = calc_cic1(logmmin)\n", "    derivative = diff_cic1(logmmin)\n", "\n", "    plt.plot(logmmin, value, \"bo\")\n", "    plt.quiver(logmmin, value, 1, derivative, angles=\"xy\")\n", "\n", "plt.xlabel(\"$\\\\log M_{\\\\rm min}$\")\n", "plt.ylabel(\"$P(N_{\\\\rm CiC} = 1)$\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `jax.grad` (the arrows in the above plot) isn't working *yet*...\n", "\n", "- I actually wasn't expecting the above to work perfectly, because it's using the Monte-Carlo mode, which isn't perfectly continuous\n", "- But analytic mode moment derivatives aren't working either...\n", "- TODO: Figure out what's going wrong" ] } ], "metadata": { "kernelspec": { "display_name": "311np2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": 
".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }