Reverse modeling of Methane emission spectrum using precomputed spectrum grids

The opacity calculators in ExoJAX are fully auto-differentiable. However, in some cases, precomputing the spectra on a grid and interpolating the grid model is useful for performing rapid reverse modeling. Here, we demonstrate a grid-based retrieval using ExoJAX.

from jax.config import config
config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt

from jax import random
import jax.numpy as jnp
from jax import vmap

import pandas as pd
import pkg_resources

from exojax.spec.atmrt import ArtEmisPure
from exojax.spec.api import MdbExomol
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.contdb import CdbCIA
from exojax.spec.opacont import OpaCIA
from exojax.spec.response import ipgauss_sampling
from exojax.spec.spin_rotation import convolve_rigid_rotation
from exojax.spec import molinfo
from exojax.spec.unitconvert import nu2wav
from exojax.utils.grids import velocity_grid
from exojax.utils.astrofunc import gravity_jupiter
from exojax.utils.grids import wavenumber_grid
from exojax.utils.instfunc import resolution_to_gaussian_std
# Load the sample CH4 emission spectrum from the ExoJAX package data.
# The file-name constant is not imported at the top of this file, so it is
# brought into scope here to avoid a NameError.
from exojax.test.data import SAMPLE_SPECTRA_CH4_NEW

filename = pkg_resources.resource_filename(
    'exojax', 'data/testdata/' + SAMPLE_SPECTRA_CH4_NEW)
dat = pd.read_csv(filename, delimiter=",", names=("wavenumber", "flux"))
nusd = dat['wavenumber'].values
flux = dat['flux'].values
wavd = nu2wav(nusd)

# Build the mock observation: scale the flux by `norm` and add Gaussian
# noise with standard deviation `sigmain` (the same value is used as the
# likelihood width in the PPL model).
sigmain = 0.05
norm = 20000
nflux = flux / norm + np.random.normal(0, sigmain, len(wavd))

plt.plot(wavd, nflux)

We make the grid model using ArtEmisPure and MdbExomol.

# set wavenumber grid for the model
Nx = 7500
# The grid covers the observed wavelength range padded by 10 on each side.
# NOTE(review): unit="AA" assumes wavd is in Angstrom (as the plot label
# later in this file states) -- confirm against nu2wav's default unit.
nu_grid, wav, res = wavenumber_grid(np.min(wavd) - 10.0,
                                    np.max(wavd) + 10.0,
                                    Nx,
                                    unit="AA",
                                    xsmode="premodit")

# Temperature range shared by the atmosphere model and the PreMODIT opacity.
Tlow = 400.0
Thigh = 1500.0
art = ArtEmisPure(nu_grid, pressure_top=1.e-8, pressure_btm=1.e2, nlayer=100)
art.change_temperature_range(Tlow, Thigh)
Mp = 33.2
# Instrumental resolution converted to the Gaussian width used by
# ipgauss_sampling in the PPL model.
Rinst = 100000.
beta_inst = resolution_to_gaussian_std(Rinst)

## CH4 setting (PREMODIT)
# Restrict the line list to the model wavenumber grid.
# NOTE(review): gpu_transfer=False is assumed because PreMODIT builds its own
# line-shape density from the line list -- confirm for the ExoJAX version used.
mdb = MdbExomol('.database/CH4/12C-1H4/YT10to10/',
                nurange=nu_grid,
                gpu_transfer=False)
print('# of lines = ', len(mdb.nu_lines))
# diffmode is passed to OpaPremodit and also tags the output figure names.
diffmode = 1
opa = OpaPremodit(mdb=mdb,
                  nu_grid=nu_grid,
                  diffmode=diffmode,
                  auto_trange=[Tlow, Thigh])

## CIA setting
# H2 volume mixing ratio derived from its mass mixing ratio and the
# mean molecular weight of the atmosphere.
mmw = 2.33  # mean molecular weight
mmrH2 = 0.74
molmassH2 = molinfo.molmass_isotope('H2')
vmrH2 = mmrH2 * mmw / molmassH2  # VMR

# H2-H2 CIA database and its opacity calculator, both on the model grid.
cdbH2H2 = CdbCIA('.database/H2-H2_2011.cia', nu_grid)
opcia = OpaCIA(cdb=cdbH2H2, nu_grid=nu_grid)

# settings before HMC
# Velocity grid used by both the rotational and instrumental convolutions.
vsini_max = 100.0
vr_array = velocity_grid(res, vsini_max)

# given gravity, temperature exponent, MMR
# Reuse the planet mass Mp defined above instead of repeating the literal.
g = gravity_jupiter(Rp=0.88, Mp=Mp)
alpha = 0.1
MMR_CH4 = 0.0059
xsmode =  premodit
xsmode assumes ESLOG in wavenumber space: mode=premodit
HITRAN exact name= (12C)(1H)4
HITRAN exact name= (12C)(1H)4
Background atmosphere:  H2
Reading .database/CH4/12C-1H4/YT10to10/12C-1H4__YT10to10__06000-06100.trans.bz2
Reading .database/CH4/12C-1H4/YT10to10/12C-1H4__YT10to10__06100-06200.trans.bz2
.broad is used.
Broadening code level= a1
default broadening parameters are used for  23  J lower states in  40  states
# of lines =  80505310
OpaPremodit: params automatically set.
Robust range: 397.7740728313057 - 1635.1214022614588 K
Tref changed: 296.0K->433.7496941348324K
Premodit: Twt= 1021.1189195562007 K Tref= 433.7496941348324 K
Making LSD:|####################| 100%
Making LSD:|####################| 100%

Because we would like to infer T0 and the rotational broadenings and so on, we define the raw spectrum model as a function of T0.

def raw_spectrum_model(T0):
    """Compute the normalized raw emission spectrum for a given T0.

    The spectrum includes CH4 line opacity and H2-H2 CIA. It is "raw" in the
    sense that no rotational or instrumental broadening is applied here.

    Args:
        T0: temperature parameter passed to ``art.powerlaw_temperature``
            together with the module-level exponent ``alpha``.

    Returns:
        The output of ``art.run`` divided by the module-level ``norm``
        (presumably the emission spectrum sampled on ``nu_grid`` -- confirm).
    """
    #T-P model
    Tarr = art.powerlaw_temperature(T0, alpha)

    # CH4 line opacity
    xsmatrix = opa.xsmatrix(Tarr, art.pressure)
    mmr_arr = art.constant_mmr_profile(MMR_CH4)
    dtaumCH4 = art.opacity_profile_lines(xsmatrix, mmr_arr, opa.mdb.molmass, g)

    # H2-H2 CIA opacity
    logacia_matrix = opcia.logacia_matrix(Tarr)
    dtaucH2H2 = art.opacity_profile_cia(logacia_matrix, Tarr, vmrH2, vmrH2,
                                        mmw, g)

    # Radiative transfer over the summed opacity, scaled by the same `norm`
    # that was applied to the observed flux.
    dtau = dtaumCH4 + dtaucH2H2
    F0 = art.run(dtau, Tarr) / norm
    return F0

Then, we make a grid model of emission spectra as a function of T0. The spectrum at an arbitrary T0 is generated by interpolating the grid, i.e. with jnp.interp. Because the spectrum has a wavenumber dimension, we need to apply ‘vmap’ to jnp.interp so that the interpolation is performed at every wavenumber.

# compute F0 grid given T0 grid
import tqdm

Ngrid = 200  # grid spacing is (1400 - 1200) / (Ngrid - 1), i.e. about 1 K
T0_grid = jnp.linspace(1200, 1400, Ngrid)

# Evaluate the raw spectrum at every grid temperature and keep each result;
# without collecting them the grid would stay empty.
F0_grid = [
    raw_spectrum_model(T0)
    for T0 in tqdm.tqdm(T0_grid, desc="computing grid")
]
# Stacking gives one row per T0; transpose so that axis 0 runs over the
# spectrum's own axis and axis 1 over T0_grid.
F0_grid = jnp.array(F0_grid).T

# Interpolate in T0 independently for each row of F0_grid.
vmapinterp = vmap(jnp.interp, (None, None, 0))
#PPL import
import arviz
from numpyro.diagnostics import hpdi
from numpyro.infer import Predictive
from numpyro.infer import MCMC, NUTS
import numpyro
import numpyro.distributions as dist

Define a model for PPL.

def model_c(nu1, y1):
    """NumPyro model of the observed spectrum using the precomputed grid.

    Samples a scale factor, radial velocity, T0 and vsini, interpolates the
    grid at T0, applies rotational broadening and the instrumental response,
    and compares the result with the data under a Gaussian likelihood of
    fixed width ``sigmain``.

    Args:
        nu1: wavenumber points at which the model is sampled.
        y1: observed flux at ``nu1``, or ``None`` to draw predictions.
    """
    # Priors (sample-site names and order are unchanged).
    scale = numpyro.sample('A', dist.Uniform(0.5, 2.0))
    radial_velocity = numpyro.sample('RV', dist.Uniform(5.0, 15.0))
    # NOTE(review): this prior spans 1100-1300 while T0_grid spans 1200-1400.
    # jnp.interp clamps to the edge value outside the grid, so the model is
    # constant in T0 below 1200 -- confirm the intended ranges.
    temperature = numpyro.sample('T0', dist.Uniform(1100.0, 1300.0))
    rotation = numpyro.sample('vsini', dist.Uniform(15.0, 25.0))

    # Grid interpolation -> rotational broadening -> instrumental sampling.
    raw_flux = scale * vmapinterp(temperature, T0_grid, F0_grid)
    broadened_flux = convolve_rigid_rotation(raw_flux,
                                             vr_array,
                                             rotation,
                                             u1=0.0,
                                             u2=0.0)
    expected = ipgauss_sampling(nu1, nu_grid, broadened_flux, beta_inst,
                                radial_velocity, vr_array)
    numpyro.sample('y1', dist.Normal(expected, sigmain), obs=y1)

Run HMC-NUTS! It took less than 2 minutes on my laptop (RTX 3080).

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
# Reverse-mode differentiation is used here; the forward-mode alternative is
# kept below for reference.
#kernel = NUTS(model_c, forward_mode_differentiation=True)
kernel = NUTS(model_c, forward_mode_differentiation=False)

# Build the sampler, run it on the observed data and print the summary table.
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key_, nu1=nusd, y1=nflux)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
         A      1.01      0.00      1.01      1.00      1.01    585.55      1.00
        RV      9.43      0.41      9.43      8.81     10.18    614.05      1.00
        T0   1154.83     27.51   1156.24   1115.35   1199.99    383.16      1.00
     vsini     20.42      0.70     20.37     19.31     21.52    528.74      1.00

# Posterior predictive draws of the observed site 'y1'.
posterior_sample = mcmc.get_samples()
pred = Predictive(model_c, posterior_sample, return_sites=['y1'])
predictions = pred(rng_key_, nu1=nusd, y1=None)
median_mu1 = jnp.median(predictions['y1'], axis=0)
# hpdi returns the lower and upper bounds of the 90% interval.
hpdi_mu1 = hpdi(predictions['y1'], 0.9)

# Median prediction, data points and 90% credible band.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 6.0))
ax.plot(wavd[::-1], median_mu1, color='C0')
ax.plot(wavd[::-1], nflux, '+', color='black', label='data')
ax.fill_between(wavd[::-1],
                hpdi_mu1[0],
                hpdi_mu1[1],
                alpha=0.3,
                interpolate=True,
                color='C0',
                label='90% area')
# Raw string: '\A' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'wavelength ($\AA$)', fontsize=16)
plt.legend(fontsize=16)
plt.savefig("pred_diffmode" + str(diffmode) + ".png")

# Pair (corner) plot of the sampled parameters.
pararr = ['A', 'T0', 'vsini', 'RV']
arviz.plot_pair(arviz.from_numpyro(mcmc),
                var_names=pararr,
                kind='kde',
                divergences=False,
                marginals=True)
plt.savefig("corner_diffmode" + str(diffmode) + ".png")