Fitting a spectrum model using Gradient Descent Based Optimization.

last update: July 2nd Hajime Kawahara

The ability of the gradient-based optimizations is s one of the major advantages of ExoJAX. Here we demonstrate how to optimize the model using jaxopt package.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp

Use 64-bit.

from jax.config import config

config.update("jax_enable_x64", True)

Here, we use a mock CH4 spectrum precomputed by ExoJAX. Also, we normalize it and add some noise.

import pkg_resources
from exojax.spec.unitconvert import nu2wav
from exojax.test.data import SAMPLE_SPECTRA_CH4_NEW

# loading the data
filename = pkg_resources.resource_filename(
    'exojax', 'data/testdata/' + SAMPLE_SPECTRA_CH4_NEW)
dat = pd.read_csv(filename, delimiter=",", names=("wavenumber", "flux"))
nusd = dat['wavenumber'].values
flux = dat['flux'].values
wavd = nu2wav(nusd)

sigmain = 0.05
norm = 20000
nflux = flux / norm + np.random.normal(0, sigmain, len(wavd))


plt.plot(wavd[::-1],nflux,alpha=0.5,color="gray")
plt.plot(wavd[::-1],flux/norm,alpha=1,color="gray")

plt.xlabel("wavelength $\AA$")
plt.show()
tutorials/optimize_spectrum_JAXopt_files/optimize_spectrum_JAXopt_7_0.png

Let’s make a model, which should be include CH4, CIA (H2-H2), spin rotation, and response… So, import everthing we need. We use PreMODIT as opa.

from exojax.utils.grids import wavenumber_grid
from exojax.spec.atmrt import ArtEmisPure
from exojax.spec.api import MdbExomol
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.contdb import CdbCIA
from exojax.spec.opacont import OpaCIA
from exojax.spec.specop import SopRotation
from exojax.spec.specop import SopInstProfile
from exojax.utils.instfunc import resolution_to_gaussian_std
/home/kawahara/exojax/src/exojax/spec/dtau_mmwl.py:14: FutureWarning: dtau_mmwl might be removed in future.
  warnings.warn("dtau_mmwl might be removed in future.", FutureWarning)

Again recall this figure.

from IPython.display import Image
Image("../exojax.png")
tutorials/optimize_spectrum_JAXopt_files/optimize_spectrum_JAXopt_11_0.png

Here we will infer here Rp, RV, MMR_CO, T0, alpha, and Vsini.

First, set the model wavenumber grids, which should cover the observational range, and the instrumental setting, and Atmospheric RT (layer) setting, art.

Nx = 1500
nu_grid, wav, res = wavenumber_grid(np.min(wavd) - 5.0,
                                np.max(wavd) + 5.0,
                                Nx,
                                unit="AA",
                                xsmode="premodit")

#Atmospheric setting by "art"
Tlow = 400.0
Thigh = 1500.0
art = ArtEmisPure(nu_grid, pressure_top=1.e-8, pressure_btm=1.e2, nlayer=100)
art.change_temperature_range(Tlow, Thigh)
Mp = 33.2

#instrumental setting
Rinst = 100000.
beta_inst = resolution_to_gaussian_std(Rinst)
xsmode =  premodit
xsmode assumes ESLOG in wavenumber space: mode=premodit
/home/kawahara/exojax/src/exojax/utils/grids.py:126: UserWarning: Resolution may be too small. R=129859.29489937567
  warnings.warn('Resolution may be too small. R=' + str(resolution),

Loading the databases, mdb for ExoMol/CH4 and cdb for CIA. Also, define opa for both databases. It takes ~ a few minites to initialize OpaPremodit (if you do not have the database, it takes more for downloading for the first time). Have a coffee and wait.

### CH4 setting (PREMODIT)
mdb = MdbExomol('.database/CH4/12C-1H4/YT10to10/',
                nurange=nu_grid,
                gpu_transfer=False)
print('N=', len(mdb.nu_lines))
diffmode = 0
opa = OpaPremodit(mdb=mdb,
                  nu_grid=nu_grid,
                  diffmode=diffmode,
                  auto_trange=[Tlow, Thigh],
                  dit_grid_resolution=0.2)

## CIA setting
from exojax.spec import molinfo

cdbH2H2 = CdbCIA('.database/H2-H2_2011.cia', nu_grid)
opcia = OpaCIA(cdb=cdbH2H2, nu_grid=nu_grid)
mmw = 2.33  # mean molecular weight
mmrH2 = 0.74
molmassH2 = molinfo.molmass_isotope('H2')
vmrH2 = (mmrH2 * mmw / molmassH2)  # VMR
/home/kawahara/exojax/src/exojax/utils/molname.py:133: FutureWarning: e2s will be replaced to exact_molname_exomol_to_simple_molname.
  warnings.warn(
/home/kawahara/exojax/src/exojax/utils/molname.py:49: UserWarning: No isotope number identified.
  warnings.warn("No isotope number identified.",UserWarning)
/home/kawahara/exojax/src/exojax/utils/molname.py:49: UserWarning: No isotope number identified.
  warnings.warn("No isotope number identified.",UserWarning)
/home/kawahara/exojax/src/exojax/spec/molinfo.py:28: UserWarning: exact molecule name is not Exomol nor HITRAN form.
  warnings.warn("exact molecule name is not Exomol nor HITRAN form.")
/home/kawahara/exojax/src/exojax/spec/molinfo.py:29: UserWarning: No molmass available
  warnings.warn("No molmass available", UserWarning)
HITRAN exact name= (12C)(1H)4
HITRAN exact name= (12C)(1H)4
Background atmosphere:  H2
Reading .database/CH4/12C-1H4/YT10to10/12C-1H4__YT10to10__06000-06100.trans.bz2
Reading .database/CH4/12C-1H4/YT10to10/12C-1H4__YT10to10__06100-06200.trans.bz2
.broad is used.
Broadening code level= a1
default broadening parameters are used for  23  J lower states in  40  states
N= 76483758
OpaPremodit: params automatically set.
Robust range: 397.77407283130566 - 1689.7679243628259 K
Tref changed: 296.0K->1153.6267095763965K
Tref_broadening is set to  774.5966692414833 K
# of reference width grid :  3
# of temperature exponent grid : 2
uniqidx: 100%|██████████| 2/2 [00:03<00:00,  1.67s/it]
Premodit: Twt= 461.3329793405918 K Tref= 1153.6267095763965 K
Making LSD:|####################| 100%
H2-H2

We have only 76,483,758 CH4 lines.

print(len(mdb.nu_lines))
76483758

Setting spectral operators.

from exojax.utils.astrofunc import gravity_jupiter
sop_rot = SopRotation(nu_grid,res,vsini_max=100.0)
sop_inst = SopInstProfile(nu_grid,res,vrmax=100.0)
/home/kawahara/exojax/src/exojax/utils/grids.py:126: UserWarning: Resolution may be too small. R=129859.29489937567
  warnings.warn('Resolution may be too small. R=' + str(resolution),
/home/kawahara/exojax/src/exojax/utils/grids.py:126: UserWarning: Resolution may be too small. R=129859.29489937567
  warnings.warn('Resolution may be too small. R=' + str(resolution),

Now we write the model, which is used in HMC-NUTS.

#response and rotation settings


def model_c(params,boost):
    Rp,RV,MMR_CH4,T0,alpha,vsini,RV=params*boost

    Tarr = art.powerlaw_temperature(T0, alpha)
    g = gravity_jupiter(Rp=Rp, Mp=Mp)  # gravity in the unit of Jupiter
    #molecule
    xsmatrix = opa.xsmatrix(Tarr, art.pressure)
    mmr_arr = art.constant_mmr_profile(MMR_CH4)
    dtaumCH4 = art.opacity_profile_lines(xsmatrix, mmr_arr, opa.mdb.molmass, g)
    #continuum
    logacia_matrix = opcia.logacia_matrix(Tarr)
    dtaucH2H2 = art.opacity_profile_cia(logacia_matrix, Tarr, vmrH2, vmrH2,
                                        mmw, g)
    #total tau
    dtau = dtaumCH4 + dtaucH2H2
    F0 = art.run(dtau, Tarr) / norm
    Frot = sop_rot.rigid_rotation(F0, vsini, u1=0.0, u2=0.0)
    Frot_inst = sop_inst.ipgauss(Frot, beta_inst)
    mu = sop_inst.sampling(Frot_inst, RV, nusd)
    return mu

Here, we use JAXopt as an optimizer. JAXopt is not automatically installed. If you need install it by pip:

pip install jaxopt

import jaxopt

We use a GradientDescent as an optimizer. Let’s normalize the parameters.

#Rp,RV,MMR_CH4,T0,alpha,vsini, RV
boost=np.array([1.0,10.0,0.1,1000.0,1.e-3,10.0,10.0])
initpar=np.array([0.8,9.0,0.01,1200.0,0.1,17.0,0.0])/boost
f = model_c(initpar,boost)
plt.plot(wavd[::-1],f)
plt.plot(wavd[::-1],nflux,alpha=0.5,color="gray")
[<matplotlib.lines.Line2D at 0x7f234d98e0d0>]
tutorials/optimize_spectrum_JAXopt_files/optimize_spectrum_JAXopt_27_1.png

Define the objective function by a L2 norm.

def objective(params):
    f=nflux-model_c(params,boost)
    g=jnp.dot(f,f)
    return g

Then, run the gradient descent.

gd = jaxopt.GradientDescent(fun=objective, maxiter=1000, stepsize=1.e-4)
res = gd.run(init_params=initpar)
params, state = res

The best-fit parameters

params*boost
DeviceArray([3.32046971e+00, 9.00000000e+00, 1.25542304e-01,
             2.10939481e+03, 1.00095859e-01, 1.93251005e+01,
             1.14472806e+01], dtype=float64)

Plot the results. Good but a bit poor compared with the input… O.K. I prefer ADAM to GD let’s try next.

model=model_c(params,boost)
inmodel=model_c(initpar,boost)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20,6.0))
ax.plot(wavd[::-1],model,color="C1",lw=3,label="fitted")
ax.plot(wavd[::-1],flux/norm,alpha=1,color="black",label="input")

#ax.plot(wavd[::-1],inmodel,color="gray",lw=3,label="initial parameter")
ax.plot(wavd[::-1],nflux,"+",color="black",label="data")
plt.xlabel("wavelength ($\AA$)",fontsize=16)
plt.legend(fontsize=16)
plt.tick_params(labelsize=16)
plt.savefig("gradient_descent_jaxopt.png")
tutorials/optimize_spectrum_JAXopt_files/optimize_spectrum_JAXopt_35_0.png

BTW, We can do the optimization one by one update. It’s useful when you wanna visualize the optimization process.

import tqdm
gd = jaxopt.GradientDescent(fun=objective, stepsize=1.e-4)
state = gd.init_state(initpar)
params=np.copy(initpar)

params_gd=[]
Nit=300
for _ in  tqdm.tqdm(range(Nit)):
    params,state=gd.update(params,state)
    params_gd.append(params)
100%|██████████| 300/300 [00:40<00:00,  7.41it/s]

Using ADAM optimizer

You might use ADAM, instead of a simple GD. Yes, you can.

from jaxopt import OptaxSolver
import optax
import tqdm
adam = OptaxSolver(opt=optax.adam(2.e-2), fun=objective)
state = adam.init_state(initpar)
params_a=np.copy(initpar)

params_adam=[]
Nit=300
for _ in  tqdm.tqdm(range(Nit)):
    params_a,state=adam.update(params_a,state)
    params_adam.append(params_a)
100%|██████████| 300/300 [00:20<00:00, 14.31it/s]
model_adam=model_c(params_a,boost)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20,6.0))
ax.plot(wavd[::-1],model,color="C1",lw=4,label="GD")
ax.plot(wavd[::-1],model_adam,color="C2",lw=4,ls="dashed",label="ADAM")
ax.plot(wavd[::-1],flux/norm,alpha=1,color="black",label="input")

#ax.plot(wavd[::-1],inmodel,color="gray",lw=3,label="initial parameter")
ax.plot(wavd[::-1],nflux,"+",color="black",label="data")
plt.xlabel("wavelength ($\AA$)",fontsize=16)
plt.legend(fontsize=16)
plt.tick_params(labelsize=16)
plt.savefig("gradient_descent_jaxopt.png")
tutorials/optimize_spectrum_JAXopt_files/optimize_spectrum_JAXopt_42_0.png

ADAM is faster and better than GD? I love ADAM.

# if you wanna optimize at once, run the following:
# res = solver.run(init_params=initpar)
# params, state = res

make a movie

Make the movie directory (mkdir movie), and let’s make squential png files.

inmodel=model_c(initpar,boost)
for i in tqdm.tqdm(range(Nit)):
    spec_gd=model_c(params_gd[i],boost)
    spec_adam=model_c(params_adam[i],boost)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20,6.0))
    ax.plot(wavd[::-1],spec_gd,color="C0",lw=3,label="GD")
    ax.plot(wavd[::-1],spec_adam,color="C1",lw=3,label="ADAM")
    ax.plot(wavd[::-1],inmodel,color="gray",label="initial parameter")
    ax.plot(wavd[::-1],nflux,"+",color="black",label="data")
    plt.xlabel("wavelength ($\AA$)",fontsize=16)
    plt.tick_params(labelsize=16)
    plt.ylim(0.0,0.6)
    plt.legend(loc="lower left",fontsize=14)
    plt.savefig("movie/gradient_descent_jaxopt"+str(i).zfill(4)+".png")
    plt.close()
100%|██████████| 300/300 [00:57<00:00,  5.19it/s]
#for instance, make a movie by
# > ffmpeg -r 30 -i gradient_descent_jaxopt%04d.png -vcodec libx264 -pix_fmt yuv420p -r 60 outx.mp4