Surface brightness model fitting¶
In this notebook, I show how to use numpyro and haiku to perform an MCMC fit of the surface brightness of a cluster.
import numpyro
numpyro.enable_x64()  # double precision makes the NUTS sampler more stable
numpyro.set_host_device_count(4)  # expose 4 host devices so that 4 chains can run in parallel
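The Cluster object gathers everything the fit needs. Judging from the constructor arguments (an inference from the names rather than from the xsb_fluc documentation), these are the count image, the exposure map, the background map, a region file masking point sources, and a hydrogen column density map, plus the cluster centre, its R500 and its redshift.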
import astropy.units as u
from xsb_fluc.data.cluster import Cluster
cluster = Cluster(
imglink='data/A2142/mosaic_a2142.fits.gz',
explink='data/A2142/mosaic_a2142_expo.fits.gz',
bkglink='data/A2142/mosaic_a2142_bkg.fits.gz',
reglink='data/A2142/src_ps.reg',
nhlink='data/A2142/A2142_nh.fits',
ra=239.58615,
dec=27.229433,
r_500=1.403*u.Mpc,
redshift=0.09,
)
WARNING: FITSFixedWarning: RADECSYS= 'FK5 ' / Stellar reference frame the RADECSYS keyword is deprecated, use RADESYSa. [astropy.wcs.wcs]
WARNING: FITSFixedWarning: EQUINOX = '2000.0 ' / Coordinate system equinox a floating-point value was expected. [astropy.wcs.wcs]
Build the numpyro model¶
I'll demonstrate how to fit an elliptic model for the surface brightness, using haiku to define the model and its parameters and numpyro to perform MCMC sampling with the NUTS sampler. In the following cell, I reduce the cluster using the Voronoi tessellation derived in the associated notebook, and will infer the number of counts in each of the bins. The MockXrayCountsBetaModel is a built-in of the xsb_fluc package and is written as a haiku Module, so it has to be transformed.
import haiku as hk
from xsb_fluc.simulation.mock_image import MockXrayCountsBetaModel
cluster_voronoi = cluster.voronoi('data/A2142/voronoi.txt')
images_simulator = hk.without_apply_rng(hk.transform(lambda: MockXrayCountsBetaModel(cluster_voronoi)()))
Once transformed, we can see the required parameters by using the .init method. Here we display them to get an idea of the shape that haiku accepts for the parameters.
images_simulator.init(None)
{'mock_xray_counts_beta_model/~/ellipse_radius': {'angle': Array(0., dtype=float32),
  'eccentricity': Array(0., dtype=float32),
  'x_c': Array(0., dtype=float32),
  'y_c': Array(0., dtype=float32)},
 'mock_xray_counts_beta_model/~/xray_surface_brightness_beta_model': {'log_bkg': Array(-5., dtype=float32),
  'log_e_0': Array(-4., dtype=float32),
  'log_r_c': Array(-1., dtype=float32),
  'beta': Array(0.6666667, dtype=float32)}}
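Each leaf of this dictionary can now be paired with a prior: a numpyro distribution for the parameters we want to sample, or a fixed array for those we do not. Here the background log_bkg is pinned to -100, which amounts to a negligible background.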
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
prior_distributions = {
    'mock_xray_counts_beta_model/~/ellipse_radius': {
        'angle': dist.Uniform(0., jnp.pi/2),
        'eccentricity': dist.Uniform(0, 0.99),
        'x_c': dist.Normal(0, 1),
        'y_c': dist.Normal(0, 1)
    },
    'mock_xray_counts_beta_model/~/xray_surface_brightness_beta_model': {
        'log_bkg': jnp.asarray(-100.),  # fixed value, not sampled
        'log_e_0': dist.Uniform(-6, 0),
        'log_r_c': dist.Uniform(-3, 0),
        'beta': dist.Uniform(0, 5)
    }
}
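These priors plug directly into a numpyro model: we traverse the haiku dictionary, replace every Distribution leaf with a numpyro.sample statement (fixed values are kept as-is), simulate the expected counts, and compare them to the observed counts through a Poisson likelihood.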
def numpyro_model(observed_cluster=None):
    # Inform numpyro that we want to draw the parameters from the prior distributions
    samples = hk.data_structures.to_haiku_dict(prior_distributions)
    for module, parameter, prior in hk.data_structures.traverse(prior_distributions):
        samples[module][parameter] = numpyro.sample(parameter, prior) if isinstance(prior, dist.Distribution) else prior

    # Compute the expected counts using the samples drawn from the prior distributions
    images_simulator = hk.without_apply_rng(hk.transform(lambda: MockXrayCountsBetaModel(observed_cluster)()))
    expected_counts = images_simulator.apply(samples)

    # Compare them to the actually observed counts in each pixel
    numpyro.sample('likelihood', dist.Poisson(expected_counts), obs=observed_cluster.img)
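Before launching the sampler, a quick sanity check can save time. The snippet below (a minimal sketch reusing only objects defined in the cells above) evaluates the transformed simulator once with its default parameters and verifies that the output matches the shape of the binned image.
# Smoke test: one forward evaluation with the default haiku parameters
params = images_simulator.init(None)
expected = images_simulator.apply(params)
print(expected.shape, cluster_voronoi.img.shape)  # the two shapes should match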
Run the MCMC¶
Now that the model is built, we can leverage numpyro to perform efficient sampling. NUTS is an adaptive sampler that excels at generating uncorrelated samples quickly, in the sense that it needs comparatively few steps to converge and to produce high-quality posterior samples. In the following cell, I build a NUTS kernel and run 4 chains in parallel, each with 10 000 warm-up steps and 1 000 posterior samples.
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS
kernel = NUTS(numpyro_model, max_tree_depth=10)
mcmc = MCMC(kernel, num_chains=4, num_warmup=10000, num_samples=1000)
mcmc.run(PRNGKey(0), observed_cluster=cluster_voronoi)
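Once the run is finished, numpyro can print quick diagnostics on its own, before we move to a dedicated analysis library:
mcmc.print_summary()  # posterior mean, standard deviation, effective sample size and r_hat per parameter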
Analyse the results¶
Now that the MCMC has run, we can check the convergence and analyse the results using arviz (imo the best library for analysing MCMC results). We can build an InferenceData object from our previous MCMC run and check some summary statistics. I would like to focus on the $\hat{R}$ statistic, which is a good criterion to assess the convergence of a chain, conventionally achieved if $\hat{R} < 1.01$. This is true for every parameter of our model.
import arviz as az
import matplotlib.pyplot as plt
inference_data = az.from_numpyro(mcmc)
az.summary(inference_data)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| angle | 0.852 | 0.003 | 0.847 | 0.857 | 0.0 | 0.0 | 2749.0 | 2182.0 | 1.0 |
| beta | 0.547 | 0.001 | 0.546 | 0.548 | 0.0 | 0.0 | 1485.0 | 1687.0 | 1.0 |
| eccentricity | 0.753 | 0.001 | 0.751 | 0.756 | 0.0 | 0.0 | 3020.0 | 2657.0 | 1.0 |
| log_e_0 | -4.924 | 0.002 | -4.927 | -4.920 | 0.0 | 0.0 | 2066.0 | 2317.0 | 1.0 |
| log_r_c | -1.157 | 0.002 | -1.161 | -1.153 | 0.0 | 0.0 | 1386.0 | 1480.0 | 1.0 |
| x_c | 0.008 | 0.000 | 0.008 | 0.009 | 0.0 | 0.0 | 4577.0 | 2895.0 | 1.0 |
| y_c | 0.013 | 0.000 | 0.012 | 0.013 | 0.0 | 0.0 | 4514.0 | 3075.0 | 1.0 |
The cell below draws a trace plot, which shows the evolution of each chain over the course of the sampling. Convergence is clearly achieved here: the samples are uncorrelated and the marginal distributions of the chains agree.
with az.style.context("arviz-darkgrid", after_reset=True):
az.plot_trace(inference_data, compact=False)
plt.show();
We can also investigate the correlations between the posterior samples, which is the best view we can get of our posterior distribution. This is often referred to as a pair plot or a corner plot. One should take a look at ChainConsumer to get pretty, paper-ready corner plots; a minimal sketch follows the next cell.
with az.style.context("arviz-darkgrid", after_reset=True):
az.plot_pair(inference_data, kind='kde')
plt.show();
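As a side note, here is a minimal sketch of how the same samples could be fed to ChainConsumer to get a publication-quality corner plot. It assumes the pre-1.0 add_chain(array, parameters=...) API; recent ChainConsumer releases organise chains differently, so adapt this to the version you have installed.
import numpy as np
from chainconsumer import ChainConsumer

posterior = mcmc.get_samples()  # dict of 1D arrays, one entry per parameter
names = list(posterior.keys())

c = ChainConsumer()
c.add_chain(np.column_stack([np.asarray(posterior[n]) for n in names]), parameters=names)
fig = c.plotter.plot()  # corner plot of the posterior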
Posterior predictive¶
With these posterior samples, we can get a good idea of the best image we can produce, and even compute fluctuation maps! To do so, we just need to inject the parameters into our previously defined model. In general, the median of the posterior distribution coincides with the best-fit parameters, so we use it to build our posterior predictive and fluctuation map. The cells below save the samples for later use, then do the bookkeeping needed to build a haiku-friendly dictionary of parameters.
import json
import numpy as np
from jax.tree_util import tree_map
# Save the samples for later use
with open('data/A2142/posterior_parameters.json', 'w') as file:
    json.dump(tree_map(lambda x: list(np.asarray(x)), mcmc.get_samples()), file)
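For completeness, these samples can be reloaded in a later session with plain json; a small sketch, using the file path from the cell above:
with open('data/A2142/posterior_parameters.json', 'r') as file:
    reloaded_samples = {name: jnp.asarray(values) for name, values in json.load(file).items()}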
posterior_parameters = tree_map(lambda x: jnp.median(x), mcmc.get_samples())
hk_posterior_parameters = hk.data_structures.to_haiku_dict(prior_distributions)

for module, parameter, prior in hk.data_structures.traverse(prior_distributions):
    if parameter in posterior_parameters:
        hk_posterior_parameters[module][parameter] = posterior_parameters[parameter]
hk_posterior_parameters
{'mock_xray_counts_beta_model/~/ellipse_radius': {'angle': Array(0.85220802, dtype=float64),
  'eccentricity': Array(0.75334761, dtype=float64),
  'x_c': Array(0.00841448, dtype=float64),
  'y_c': Array(0.01253357, dtype=float64)},
 'mock_xray_counts_beta_model/~/xray_surface_brightness_beta_model': {'log_bkg': Array(-100., dtype=float64, weak_type=True),
  'log_e_0': Array(-4.92391862, dtype=float64),
  'log_r_c': Array(-1.15687444, dtype=float64),
  'beta': Array(0.54693521, dtype=float64)}}
Once the parameters are properly formatted, we can create a new simulator function suited to the visualization we want (as a reminder, we performed the fit on the Voronoi tessellation, so plotting it directly would be ugly). The next cell shows how to do this.
import cmasher as cmr
import numpy as np
from matplotlib.colors import LogNorm, SymLogNorm
cluster_to_plot = cluster.reduce_to_r500(0.75)
images_simulator_full = hk.without_apply_rng(hk.transform(lambda: MockXrayCountsBetaModel(cluster_to_plot)()))
best_fit_image = images_simulator_full.apply(hk_posterior_parameters)
Finally, we can plot the true image, the best-fit model, and the fluctuation map.
fig, axs = plt.subplots(
    figsize=(12, 5),
    nrows=1,
    ncols=3,
    subplot_kw={'projection': cluster.wcs}
)
mask = cluster_to_plot.exp > 0  # keep only pixels with non-zero exposure
xsb_fluc = (cluster_to_plot.img - best_fit_image)/(2*cluster_to_plot.exp)  # fluctuation map shown in the third panel
img_norm = LogNorm(vmin=0.5, vmax=200)
map_img = axs[0].imshow(np.where(mask, cluster_to_plot.img, np.nan), norm=img_norm, cmap=cmr.cosmic)
map_fit = axs[1].imshow(np.where(mask, best_fit_image, np.nan), norm=img_norm, cmap=cmr.cosmic)
map_fluc = axs[2].imshow(np.where(mask, xsb_fluc, np.nan), cmap=cmr.guppy, norm=SymLogNorm(vmin=-5e-6, vmax=5e-6, linthresh=1e-7))
plt.colorbar(map_img, ax=axs[0], location='bottom', label='Counts (True image)')
plt.colorbar(map_fit, ax=axs[1], location='bottom', label='Counts (Fitted image)')
plt.colorbar(map_fluc, ax=axs[2], location='bottom', label='Fluctuations')
plt.show();