Get Started (GPU memory-efficient version)
Last update: December 12th (2024) Hajime Kawahara
This is a device-memory-efficient version of “Get Started”.
Note: Batch execution of this notebook
(jupyter nbconvert --to script get_started_opart.ipynb; python get_started_opart.py)
was successfully performed on a laptop equipped with an RTX 3080 (8 GB
device memory). The device memory usage was approximately 2.4 GB.
First, we recommend 64-bit precision if you do not want to worry about numerical errors. Use jax.config to enable 64-bit. (Note, however, that 32-bit is sufficient in most cases. Consider using 32-bit, which is faster and uses less device memory, for your real use case.)
# To monitor device memory usage, you can use jax_smi
#from jax_smi import initialise_tracking
#initialise_tracking()
from jax import config
config.update("jax_enable_x64", True)
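As a small illustration of the trade-off mentioned above, the following sketch compares the dtype and size of the same array in 64-bit and 32-bit form. The array itself is arbitrary and chosen only for demonstration.

```python
import jax.numpy as jnp
from jax import config

# Enable 64-bit mode before creating any arrays.
config.update("jax_enable_x64", True)

# With x64 enabled, floating-point arrays default to float64.
x64 = jnp.linspace(0.0, 1.0, 5)

# An explicit down-cast halves the memory footprint of the array.
x32 = x64.astype(jnp.float32)

print(x64.dtype, x64.nbytes)
print(x32.dtype, x32.nbytes)
```

The same factor of two applies to large arrays on the device, which is why 32-bit is worth considering when memory is tight.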
One approach to reducing device memory usage is to calculate the opacity
layer by layer and advance the radiative transfer by one layer at a
time. To achieve this, it is necessary to integrate the opacity
calculator (opa) and the radiative transfer (art), leading to the use of
the opart class (opa + art). Here, we demonstrate the calculation of a
pure absorption emission spectrum using opart.
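To give a rough sense of why the layer-by-layer approach helps, the sketch below estimates the size of a single 64-bit array that holds values for all layers at once versus one layer at a time. The sizes (200 layers, 150,000 wavenumber points) match the values used later in this notebook; the helper function is hypothetical. Real calculations involve many such intermediate arrays, so this is an estimate per array, not of total usage.

```python
def array_megabytes(n_layers, n_nu, bytes_per_value=8):
    """Size in MB of one (n_layers, n_nu) array of floating-point values."""
    return n_layers * n_nu * bytes_per_value / 1.0e6

# Illustrative sizes taken from the example in this notebook.
n_layers, n_nu = 200, 150_000

all_layers_mb = array_megabytes(n_layers, n_nu)  # every layer held at once
one_layer_mb = array_megabytes(1, n_nu)          # a single layer at a time

print(f"all layers at once:  {all_layers_mb:.1f} MB per array")
print(f"one layer at a time: {one_layer_mb:.1f} MB per array")
```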
1. Computing an Emission Spectrum using opart
The user needs to define a class, OpaLayer, that specifies how to
calculate opacity for each layer. The OpaLayer class must define at
least an __init__ method and a __call__ method. Additionally,
self.nu_grid must be defined. The __call__ method should take the
parameters of a layer as input and return the optical depth (delta tau)
for that layer.
from exojax.spec.opacalc import OpaPremodit
from exojax.spec.layeropacity import single_layer_optical_depth
from exojax.utils.grids import wavenumber_grid
from exojax.spec.api import MdbExomol
from exojax.utils.astrofunc import gravity_jupiter
class OpaLayer:
    # user-defined class; it must define self.nu_grid
    def __init__(self, Nnus=150000):
        self.nu_grid, self.wav, self.resolution = wavenumber_grid(
            1950.0, 2250.0, Nnus, unit="cm-1", xsmode="premodit"
        )
        self.mdb_co = MdbExomol(".database/CO/12C-16O/Li2015", nurange=self.nu_grid)
        self.opa_co = OpaPremodit(
            self.mdb_co,
            self.nu_grid,
            auto_trange=[500.0, 1500.0],
            dit_grid_resolution=1.0,
            allow_32bit=True
        )
        self.gravity = gravity_jupiter(1.0, 10.0)
    def __call__(self, params):
        # params holds the parameters of a single layer
        temperature, pressure, dP, mixing_ratio = params
        xsv_co = self.opa_co.xsvector(temperature, pressure)
        dtau_co = single_layer_optical_depth(
            xsv_co, dP, mixing_ratio, self.mdb_co.molmass, self.gravity
        )
        return dtau_co
/home/kawahara/exojax/src/exojax/spec/dtau_mmwl.py:13: FutureWarning: dtau_mmwl might be removed in future.
warnings.warn("dtau_mmwl might be removed in future.", FutureWarning)
Do not put @partial(jit, static_argnums=(0,)) on __call__. This is not
necessary and makes the code significantly slower.
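The interface described above can be illustrated without any molecular database. The following is a minimal sketch of a hypothetical ToyOpaLayer class that follows the same contract (an __init__ method, a self.nu_grid attribute, and a __call__ method that maps the parameters of one layer to delta tau). The single line profile and the scaling are made up for illustration and are not a physical opacity model.

```python
import jax.numpy as jnp


class ToyOpaLayer:
    """Hypothetical stand-in that follows the same interface as OpaLayer."""

    def __init__(self, n_nu=1000):
        # nu_grid must be defined as an attribute.
        self.nu_grid = jnp.linspace(1950.0, 2250.0, n_nu)
        self.line_center = 2100.0  # cm-1, made-up single line
        self.line_width = 5.0      # cm-1, made-up width

    def __call__(self, params):
        # Parameters of ONE layer, in the same order as in OpaLayer.
        temperature, pressure, dP, mixing_ratio = params
        # Lorentzian-shaped profile on the wavenumber grid.
        profile = self.line_width**2 / (
            (self.nu_grid - self.line_center) ** 2 + self.line_width**2
        )
        # Toy scaling only: returns one delta tau value per grid point.
        return mixing_ratio * dP * (temperature / 1000.0) * profile


toy = ToyOpaLayer()
dtau = toy((900.0, 1.0, 0.1, 0.01))
print(dtau.shape)
```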
Next, the user will utilize the OpaLayer class in the Opart class. Here,
since the goal is to calculate pure absorption emission, the
OpartEmisPure class will be used. (Remember that if opa and art were
separated, the ArtEmisPure class would have been used instead.)
from exojax.spec.opart import OpartEmisPure
opalayer = OpaLayer(Nnus=150000)
opart = OpartEmisPure(opalayer, pressure_top=1.0e-5, pressure_btm=1.0e1, nlayer=200, nstream=8)
opart.change_temperature_range(400.0, 1500.0)
xsmode = premodit
xsmode assumes ESLOG in wavenumber space: xsmode=premodit
======================================================================
The wavenumber grid should be in ascending order.
The users can specify the order of the wavelength grid by themselves.
Your wavelength grid is in * descending * order
======================================================================
HITRAN exact name= (12C)(16O)
radis engine = vaex
/home/kawahara/exojax/src/exojax/utils/molname.py:197: FutureWarning: e2s will be replaced to exact_molname_exomol_to_simple_molname.
warnings.warn(
/home/kawahara/exojax/src/exojax/utils/molname.py:91: FutureWarning: exojax.utils.molname.exact_molname_exomol_to_simple_molname will be replaced to radis.api.exomolapi.exact_molname_exomol_to_simple_molname.
warnings.warn(
/home/kawahara/exojax/src/exojax/utils/molname.py:91: FutureWarning: exojax.utils.molname.exact_molname_exomol_to_simple_molname will be replaced to radis.api.exomolapi.exact_molname_exomol_to_simple_molname.
warnings.warn(
Molecule: CO
Isotopologue: 12C-16O
Background atmosphere: H2
ExoMol database: None
Local folder: .database/CO/12C-16O/Li2015
Transition files:
=> File 12C-16O__Li2015.trans
Broadening code level: a0
/home/kawahara/exojax/src/radis/radis/api/exomolapi.py:685: AccuracyWarning: The default broadening parameter (alpha = 0.07 cm^-1 and n = 0.5) are used for J'' > 80 up to J'' = 152
warnings.warn(
/home/kawahara/exojax/src/exojax/spec/opacalc.py:215: UserWarning: dit_grid_resolution is not None. Ignoring broadening_parameter_resolution.
warnings.warn(
OpaPremodit: params automatically set.
default elower grid trange (degt) file version: 2
Robust range: 485.7803992045456 - 1514.171191195336 K
OpaPremodit: Tref_broadening is set to 866.0254037844389 K
# of reference width grid : 2
# of temperature exponent grid : 2
uniqidx: 0it [00:00, ?it/s]
Premodit: Twt= 1108.7151960064205 K Tref= 570.4914318566549 K
Making LSD:|####################| 100%
Here, somewhat abruptly, we define a function to update a layer. This
function simply calls update_layer within opart and returns its output
along with None. You might wonder why you need to define such a function
yourself. To get a bit technical, this function is used with
jax.lax.scan when updating layers. However, if it is defined inside a
class, XLA will recompile every time the parameters change, leading to a
performance slowdown. For this reason, in the current implementation,
users are required to define this function outside the class. This
implementation may be revisited and revised in the future.
def layer_update_function(carry_tauflux, params):
    carry_tauflux = opart.update_layer(carry_tauflux, params)
    return carry_tauflux, None
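The (carry, None) return value follows the calling convention of jax.lax.scan: the function receives the running state and the parameters of one layer, and returns the updated state plus an optional per-step output. The sketch below shows that pattern with a made-up update rule and made-up per-layer parameters (source, strength). It is not the rule implemented in opart.update_layer, only an illustration of how the state is carried from layer to layer.

```python
import jax
import jax.numpy as jnp

n_layers, n_nu = 50, 16
nu = jnp.linspace(0.0, 1.0, n_nu)


def toy_dtau(params):
    # Made-up optical depth of one layer on the spectral grid.
    _, strength = params
    return strength * (1.0 + nu)


def toy_update(carry, params):
    # carry holds the running optical depth and flux, both of shape (n_nu,).
    tau, flux = carry
    source, _ = params
    dtau = toy_dtau(params)
    # Contribution of this layer, attenuated by the layers above it.
    flux = flux + source * (jnp.exp(-tau) - jnp.exp(-(tau + dtau)))
    return (tau + dtau, flux), None  # no per-step output is stored


source = jnp.linspace(1.0, 2.0, n_layers)  # one value per layer
strength = jnp.full(n_layers, 0.05)        # one value per layer
init = (jnp.zeros(n_nu), jnp.zeros(n_nu))

# scan slices the leading axis of each parameter array, one layer per step.
(tau_total, flux), _ = jax.lax.scan(toy_update, init, (source, strength))
print(flux.shape)
```

Because only the carry is kept between steps, the full (n_layers, n_nu) array of optical depths never needs to be held in memory at once.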
Now, let’s define the temperature and mixing ratio profiles (in the same
way as for art) and calculate the flux. Define the layer_parameter
input, which is a list of parameters for all layers. The temperature
profile must be specified as the first element (index 0). For the
remaining elements, arrange them in the same order as used in the
user-defined OpaLayer.
temperature = opart.clip_temperature(opart.powerlaw_temperature(900.0, 0.1))
mixing_ratio = opart.constant_mmr_profile(0.01)
layer_params = [temperature, opart.pressure, opart.dParr, mixing_ratio]
flux = opart(layer_params, layer_update_function)
The spectrum has now been calculated. Let’s plot it. In this example, we calculate 150,000 wavenumber grid points across 200 layers. Even a GPU with only 8 GB of device memory, such as an RTX 2080, should be sufficient to perform the computation.
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111)
plt.plot(opalayer.nu_grid, flux)
plt.show()
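For readers who want a feel for the temperature profile used above, here is a rough sketch. It assumes that the power-law profile has the form T = T0 * P**alpha with P in bar, and that clipping simply limits T to the range set by change_temperature_range; both are assumptions about the behavior of opart.powerlaw_temperature and opart.clip_temperature, not their actual implementation. The pressure grid and the function names ending in _sketch are likewise hypothetical.

```python
import jax.numpy as jnp


def powerlaw_temperature_sketch(pressure, T0, alpha):
    # Assumed form: T0 is the temperature at 1 bar.
    return T0 * pressure**alpha


def clip_temperature_sketch(temperature, t_low, t_high):
    # Assumed behavior: limit the profile to the allowed range.
    return jnp.clip(temperature, t_low, t_high)


# Hypothetical log-spaced pressure grid from 1e-5 to 1e1 bar, 200 layers.
pressure = jnp.logspace(-5.0, 1.0, 200)

temperature = clip_temperature_sketch(
    powerlaw_temperature_sketch(pressure, 900.0, 0.1), 400.0, 1500.0
)
print(float(temperature.min()), float(temperature.max()))
```

Under these assumptions the upper layers fall below 400 K before clipping, so the clipped profile is flat at the top of the atmosphere and follows the power law deeper down.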
2. Optimization of opart using forward differentiation
Next, we will perform gradient-based optimization using opart.
First, let’s generate mock data.
import numpy as np
import matplotlib.pyplot as plt
mock_spectrum = flux + np.random.normal(0.0, 1000.0, len(opalayer.nu_grid))
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111)
plt.plot(opalayer.nu_grid, mock_spectrum, ".", alpha=0.1)
#plt.plot(opalayer.nu_grid, flux, lw=1, color="red")
plt.show()
Next, define the objective function.
import jax.numpy as jnp
def objective_fluxt(T0, alpha):
    temperature = opart.clip_temperature(opart.powerlaw_temperature(T0, alpha))
    mixing_ratio = opart.constant_mmr_profile(0.01)
    layer_params = [temperature, opart.pressure, opart.dParr, mixing_ratio]
    flux = opart(layer_params, layer_update_function)
    res = flux - mock_spectrum
    return jnp.dot(res, res)*1.0e-12
In this example, we will optimize two parameters of the temperature
profile (T0 and the power-law index alpha). For gradient-based
optimization, we need to compute gradients. Typically, gradients are
calculated using jax.grad, which employs reverse-mode differentiation.
However, this approach consumes a significant amount of memory because
the intermediate results of the forward computation must be stored.
Instead, we use forward-mode differentiation. For this purpose, we
utilize jax.jvp (Jacobian-Vector Product).
from jax import jvp
fac = 1.e4
def dfluxt_fwd(params):
    T = params[0]*fac
    alpha = params[1]
    return jnp.array([jvp(objective_fluxt, (T,alpha), (1.0,0.0))[1], jvp(objective_fluxt, (T,alpha), (0.0,1.0))[1]])
#dfluxt_fwd([900.0/fac, 0.1])
Array([1.02126306e-06, 5.78234528e-02], dtype=float64)
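To see that one jax.jvp call per parameter reproduces the gradient, the following sketch compares forward-mode and reverse-mode results on a small toy objective. The function toy_objective and its numbers are made up for this check; only jax.jvp and jax.grad are real JAX calls.

```python
import jax
import jax.numpy as jnp


def toy_objective(t0, alpha):
    # Small stand-in for a sum of squared residuals.
    x = jnp.linspace(0.1, 1.0, 8)
    model = t0 * x**alpha
    res = model - 2.0 * x**0.5
    return jnp.dot(res, res)


def grad_forward(t0, alpha):
    # One JVP per parameter: each tangent picks out one partial derivative.
    primals = (t0, alpha)
    _, d_t0 = jax.jvp(toy_objective, primals, (1.0, 0.0))
    _, d_alpha = jax.jvp(toy_objective, primals, (0.0, 1.0))
    return jnp.array([d_t0, d_alpha])


g_fwd = grad_forward(1.5, 0.3)
# Reverse-mode gradient of the same function, for comparison.
g_rev = jnp.array(jax.grad(toy_objective, argnums=(0, 1))(1.5, 0.3))
print(g_fwd, g_rev)
```

The cost of the forward-mode approach grows with the number of parameters (one pass each), which is acceptable here because only two parameters are optimized.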
Let’s plot the objective function as a function of T.
import tqdm
obj = []
derivative = []
tlist = np.linspace(800.0, 1000.0, 50)/fac
for t in tqdm.tqdm(tlist):
    value = objective_fluxt(t*fac, 0.1)
    df = dfluxt_fwd([t, 0.1])
    obj.append(value)
    derivative.append(df[0])
100%|██████████| 50/50 [02:21<00:00, 2.83s/it]
fig = plt.figure()
ax = fig.add_subplot(211)
plt.plot(tlist, obj)
plt.yscale("log")
plt.ylabel("objective function")
ax = fig.add_subplot(212)
plt.plot(tlist, derivative)
plt.axhline(0.0, color="red", linestyle="--")
plt.ylabel("dflux/dT")
plt.show()
Let’s perform optimization using the gradient (JVP) with optax’s AdamW
optimizer (you can, of course, use Adam or other optimizers if
preferred).
params = jnp.array([800.0/fac, 0.08])
objective_fluxt(params[0], params[1])
Array(16.67315012, dtype=float64)
import optax
solver = optax.adamw(learning_rate=0.01)
params = jnp.array([800.0/fac, 0.08])
opt_state = solver.init(params)
for i in range(100):
    grad = dfluxt_fwd(params)
    updates, opt_state = solver.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    if np.mod(i,10)==0:
        print('Objective function: {:.2E}'.format(objective_fluxt(params[0]*fac, params[1])), "T0: ", params[0]*fac, "alpha: ", params[1])
Objective function: 5.00E-01 T0: 899.9991677408548 alpha: 0.06999992001951152
Objective function: 4.65E-01 T0: 939.9491536114997 alpha: 0.11398088779982679
Objective function: 2.76E-01 T0: 920.1038684443595 alpha: 0.10279728937252419
Objective function: 2.38E-01 T0: 907.3925864908261 alpha: 0.09310951818555954
Objective function: 2.11E-01 T0: 907.6350930283961 alpha: 0.10196393969349522
Objective function: 2.01E-01 T0: 901.234486639137 alpha: 0.10026516936502246
Objective function: 2.02E-01 T0: 897.0216029459805 alpha: 0.09890739924767429
Objective function: 2.01E-01 T0: 899.6532752609976 alpha: 0.10074276776874401
Objective function: 2.01E-01 T0: 900.5988168169558 alpha: 0.09939127750890416
Objective function: 2.01E-01 T0: 899.764225314858 alpha: 0.10001049531047529
Let’s compare the model using the best-fit values with the mock data.
def fluxt(T0, alpha):
    temperature = opart.clip_temperature(opart.powerlaw_temperature(T0, alpha))
    mixing_ratio = opart.constant_mmr_profile(0.01)
    layer_params = [temperature, opart.pressure, opart.dParr, mixing_ratio]
    flux = opart(layer_params, layer_update_function)
    return flux
import numpy as np
mock_spectrum = flux + np.random.normal(0.0, 1000.0, len(opalayer.nu_grid))
fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(211)
plt.plot(opalayer.nu_grid, mock_spectrum, ".", alpha=0.1)
plt.plot(opalayer.nu_grid, fluxt(params[0]*fac, params[1]), lw=1, color="red")
ax = fig.add_subplot(212)
plt.plot(opalayer.nu_grid, mock_spectrum-fluxt(params[0]*fac, params[1]), ".", alpha=0.1)
plt.ylabel("Residual")
plt.show()
In this way, gradient optimization can be performed in a device memory-efficient manner using forward differentiation.
3. HMC-NUTS using forward differentiation
Forward differentiation must also be used in HMC-NUTS. In NumPyro’s
NUTS, this can be achieved by setting the option
forward_mode_differentiation=True. Other than this, the execution method
is the same as the standard HMC-NUTS.
def fluxt(T0, alpha):
    temperature = opart.clip_temperature(opart.powerlaw_temperature(T0, alpha))
    mixing_ratio = opart.constant_mmr_profile(0.01)
    layer_params = [temperature, opart.pressure, opart.dParr, mixing_ratio]
    flux = opart(layer_params, layer_update_function)
    return flux
#PPL import
from numpyro.infer import MCMC, NUTS
import numpyro
import numpyro.distributions as dist
from jax import random
def model_c(y1):
    T0 = numpyro.sample('T0', dist.Uniform(800.0, 1000.0))
    alpha = numpyro.sample('alpha', dist.Uniform(0.05, 0.15))
    mu = fluxt(T0, alpha)
    sigmain = numpyro.sample('sigmain', dist.Exponential(0.001))
    numpyro.sample('y1', dist.Normal(mu, sigmain), obs=y1)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 100, 200
kernel = NUTS(model_c, forward_mode_differentiation=True)
#kernel = NUTS(model_c, forward_mode_differentiation=False)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key_, y1=mock_spectrum)
mcmc.print_summary()
sample: 100%|██████████| 300/300 [1:39:18<00:00, 19.86s/it, 3 steps of size 3.22e-03. acc. prob=0.93]
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        T0    899.97      0.08    899.96    899.83    900.09     81.13      1.00
     alpha      0.10      0.00      0.10      0.10      0.10    126.72      1.00
   sigmain   1002.54      1.74   1002.52    999.56   1005.38    229.04      1.01
Number of divergences: 0