L2 regularization using GradientDescent
-------------------------------------------

.. code:: python

    import numpy as np
    import jax.numpy as jnp
    from jaxsot.core.weight import comp_weight, comp_omega
    from jaxsot.core.lc import gen_lightcurve
    from jaxsot.io.earth import binarymap
    import jaxopt
    import healpy as hp
    import matplotlib.pyplot as plt

.. code:: python

    # Input map used to synthesize the light curve.
    # NOTE(review): presumably a binary Earth map on a HEALPix grid of
    # resolution ``nside`` -- confirm against jaxsot.io.earth.binarymap.
    nside = 8
    mmap = binarymap(nside=nside, show=False)

    # Geometry arguments forwarded to comp_weight.
    # NOTE(review): angles look like radians (inclination, equinox phase,
    # obliquity) -- confirm against jaxsot.core.weight.comp_weight.
    inc = 0.0
    Thetaeq = np.pi
    zeta = np.pi / 3.0

    # Spin and orbital periods with their angular frequencies (2*pi / period).
    # NOTE(review): Pspin looks like a sidereal day expressed in days, so
    # Porb would be in days as well -- confirm the intended units.
    Pspin = 23.9344699 / 24.0
    wspin = 2 * np.pi / Pspin
    Porb = 40.0
    worb = 2.0 * np.pi / Porb

    # N evenly spaced observation times spanning [0, Porb], and the
    # corresponding orbital and spin phases (the latter wrapped to [0, 2*pi)).
    N = 1024
    obst = np.linspace(0.0, Porb, num=N)
    Thetav = worb * obst
    Phiv = np.mod(wspin * obst, 2 * np.pi)

    # Design matrix W is the element-wise product of the two weight arrays
    # returned by comp_weight.
    omega = comp_omega(nside)
    WI, WV = comp_weight(nside, zeta, inc, Thetaeq, Thetav, Phiv, omega)
    W = jnp.array(WI * WV)

    # Light curve generated from W and the input map.
    # NOTE(review): the third argument (0.1) is presumably a noise level --
    # confirm against jaxsot.core.lc.gen_lightcurve.
    lc = gen_lightcurve(W, mmap, 0.1)

.. code:: python

    def objective(params, lam):
        """Return the L2-regularized least-squares loss for ``params``.

        The loss is the mean squared residual between the light curve ``lc``
        and the model ``W @ params``, plus ``lam**2`` times the squared L2
        norm of ``params``.
        """
        misfit = lc - W @ params
        data_term = jnp.mean(jnp.square(misfit))
        penalty = lam * lam * jnp.dot(params, params)
        return data_term + penalty

    # Minimize the objective by gradient descent (at most 500 iterations),
    # starting from a standard-normal random map with one value per entry
    # of ``mmap``; ``lam`` is passed through to ``objective``.
    gd = jaxopt.GradientDescent(fun=objective, maxiter=500)
    res = gd.run(
        init_params=np.random.normal(0.0, 1.0, len(mmap)),
        lam=0.1,
    )
    params, state = res

    # Display the recovered map in Mollweide projection.
    hp.mollview(params)
    plt.show()