Fitting to the RV data

Update: April 4/2022, Hajime Kawahara

ExoJAX provides some functions for planetary dynamics. Here, we show an example of fitting a radial-velocity (RV) curve using HMC-NUTS.

../_images/resultsrv.png
import pandas as pd
import numpy as np
import jax.numpy as jnp
import tqdm
from jax import random
from jax import vmap, jit
import matplotlib.pyplot as plt
from exojax.dynamics.rvfunc import rvf
from exojax.dynamics.getE import getE

We first generate a mock RV curve.

# Generate a mock RV curve: N random observation epochs in [0, 100), a
# Keplerian RV model from rvf, and Gaussian noise of standard deviation sigma.
# Seed BEFORE the first random draw so that both the epochs t and the noise
# are reproducible (seeding after drawing t left t different on every run).
np.random.seed(1)
N=100
t=np.random.rand(N)*100
# "True" orbital parameters used to build the mock data.
T0=0.0
P=10.0
e=0.3
omegaA=0.5
Ksini=10.0
Vsys=5.0
model=rvf(t,T0,P,e,omegaA,Ksini,Vsys)
sigma=3.0
noise=np.random.normal(0.0,sigma,N)
rv=model+noise
# NOTE: the quoted per-point errors are half the injected noise level. The
# fitting model adds a jitter term in quadrature, which is presumably meant to
# absorb the excess scatter (looks deliberate — confirm).
err=sigma*np.ones(N)/2.0

The following code defines the probabilistic model and fits it to the mock data with HMC-NUTS.

import numpyro.distributions as dist
import numpyro
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.diagnostics import hpdi

def model_c(t1,y1,e1):
    """NumPyro model of a single Keplerian RV curve with an extra jitter term.

    Args:
        t1: observation times.
        y1: observed radial velocities (conditioned on as site "y1").
        e1: per-point RV uncertainties; combined in quadrature with the
            sampled jitter ``sigmajit`` to form the likelihood scale.

    Eccentricity ``e`` and argument of periastron ``omegaA`` are sampled via
    ``sesinw = sqrt(e) sin(omegaA)`` and ``secosw = sqrt(e) cos(omegaA)``,
    each uniform on [-1, 1]. Only the unit disk of that square corresponds to
    bound orbits (e < 1). Points outside the disk are rejected with an
    infinite log-density penalty. The previous code clipped them to e = 1,
    which gave a spurious ~1 - pi/4 (about 21%) of the prior mass to e = 1.
    """
    P = numpyro.sample("P", dist.Uniform(8.0, 12.0))
    # TODO: replace with a (modified) Jeffreys prior.
    Ksini = numpyro.sample('Ksini', dist.Exponential(0.1))
    T0 = numpyro.sample('T0', dist.Uniform(-6.0, 6.0))
    sesinw = numpyro.sample('sesinw', dist.Uniform(-1.0, 1.0))
    secosw = numpyro.sample('secosw', dist.Uniform(-1.0, 1.0))
    etmp = sesinw**2 + secosw**2
    in_disk = etmp < 1.0
    # Zero log-density inside the unit disk, -inf outside: rejects e >= 1.
    numpyro.factor("e_bound", jnp.where(in_disk, 0.0, -jnp.inf))
    # Rejected points get a harmless placeholder eccentricity so that rvf
    # stays finite; their value never matters because the factor above
    # already rules them out.
    e = jnp.where(in_disk, etmp, 0.0)
    omegaA = jnp.arctan2(sesinw, secosw)
    sigmajit = numpyro.sample('sigmajit', dist.Exponential(1.0))
    Vsys = numpyro.sample('Vsys', dist.Uniform(-10, 10.0))
    mu = rvf(t1, T0, P, e, omegaA, Ksini, Vsys)
    # Total per-point scale: measurement error and jitter in quadrature.
    errall = jnp.sqrt(e1**2 + sigmajit**2)
    numpyro.sample("y1", dist.Normal(mu, errall), obs=y1)


# Run HMC-NUTS on model_c with the mock data (t, rv, err).
# Chain length: number of warm-up (adaptation) steps and of retained draws.
num_warmup = 2000
num_samples = 4000

# Split the root key PRNGKey(0); the second half drives the sampler.
rng_key, rng_key_ = random.split(random.PRNGKey(0))

kernel = NUTS(model_c)
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
)
mcmc.run(rng_key_, t1=t, y1=rv, e1=err)
mcmc.print_summary()

Save the posterior samples and plot the predicted RV curves over the data.

#Post-processing
import os

posterior_sample = mcmc.get_samples()
# NOTE: the sample dict is saved wrapped in a list, so it is stored as a
# pickled object array under "arr_0". Reading it back requires
# np.load("savepos.npz", allow_pickle=True)["arr_0"]. The format is kept
# unchanged for existing readers.
np.savez("savepos.npz",[posterior_sample])

fig=plt.figure(figsize=(10,7))
ax=fig.add_subplot(111)
ax.errorbar(t,rv,yerr=err,ls="none")
ax.plot(t,rv,"o")

# Convert the sampled (sqrt(e) sin w, sqrt(e) cos w) back to (e, omega).
sesinw=posterior_sample["sesinw"]
secosw=posterior_sample["secosw"]
# Cap e at 1 so the plotted eccentricity never exceeds what model_c passes
# to rvf.
eps=jnp.minimum(sesinw**2+secosw**2,1.0)
omegaAps=jnp.arctan2(sesinw,secosw)

nsample=len(posterior_sample["P"])
tpre=jnp.linspace(np.min(t),np.max(t),3600)
# Overplot every 10th posterior draw across the WHOLE chain. The previous
# range(len(P[::10])) only visited the first 10% of the samples.
for i in tqdm.tqdm(range(0,nsample,10)):
    e=eps[i]
    T0=posterior_sample["T0"][i]
    P=posterior_sample["P"][i]
    omegaA=omegaAps[i]
    Ksini=posterior_sample["Ksini"][i]
    Vsys=posterior_sample["Vsys"][i]
    model=rvf(tpre,T0,P,e,omegaA,Ksini,Vsys)
    ax.plot(tpre,model,alpha=0.05,color="gray")

# savefig does not create missing directories; make sure "npz/" exists.
os.makedirs("npz",exist_ok=True)
plt.savefig("npz/results.png", bbox_inches="tight", pad_inches=0.0)
plt.show()