Running the sampler locally with multiprocessing

We’ll generate some simulated radial velocity measurements of a source and run The Joker using Python’s multiprocessing package. Parallelizing the rejection sampling can speed it up by up to a factor equal to the number of cores, because evaluating the marginal likelihood for each prior sample is independent of all the others (the problem is “embarrassingly parallel”).
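To make the “embarrassingly parallel” structure concrete, here is a minimal, self-contained sketch of rejection sampling in plain Python. It is not The Joker’s implementation: the toy log-likelihood `toy_ln_likelihood`, the chunking helper `evaluate_chunk`, and the `map_func` argument are all hypothetical stand-ins. Each chunk of prior samples is evaluated independently, so the chunks can be handed to any `map` and the results concatenated. Only the final accept/reject step needs all the likelihood values at once, because it normalizes by the global maximum.

```python
import math
import random


def toy_ln_likelihood(period):
    """Hypothetical stand-in for the marginal log-likelihood of one prior sample.

    Peaks at a period of 61 days; a real likelihood would depend on the data.
    """
    return -0.5 * ((period - 61.0) / 5.0) ** 2


def evaluate_chunk(chunk):
    """Evaluate the log-likelihood for every prior sample in one chunk."""
    return [toy_ln_likelihood(p) for p in chunk]


def rejection_sample(prior_samples, n_chunks=4, map_func=map, seed=42):
    """Keep each prior sample with probability L_i / max(L)."""
    # Split the prior samples into independent chunks of work
    size = max(1, math.ceil(len(prior_samples) / n_chunks))
    chunks = [prior_samples[i:i + size]
              for i in range(0, len(prior_samples), size)]

    # Each chunk could go to a different worker; map returns results in order
    ln_like = [val for chunk_vals in map_func(evaluate_chunk, chunks)
               for val in chunk_vals]

    # Accept/reject needs the global maximum, so it runs after the map
    max_ln_like = max(ln_like)
    rng = random.Random(seed)
    return [p for p, ll in zip(prior_samples, ln_like)
            if rng.random() < math.exp(ll - max_ln_like)]


if __name__ == "__main__":
    prior_rng = random.Random(0)
    # Log-uniform prior on period between 8 and 256 days
    prior = [math.exp(prior_rng.uniform(math.log(8.0), math.log(256.0)))
             for _ in range(10_000)]
    kept = rejection_sample(prior)
    print(len(kept), "of", len(prior), "prior samples kept")
```

With the built-in `map` everything runs serially; passing `map_func=pool.map` from a `multiprocessing.Pool` (or, we expect, a schwimmbad pool, whose `map` takes a worker and an iterable of tasks) would spread the chunks across processes without changing the result.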

In [1]:
from astropy.time import Time
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import schwimmbad
%matplotlib inline

from thejoker import mpl_style
plt.style.use(mpl_style)
from thejoker.celestialmechanics import SimulatedRVOrbit
from thejoker.data import RVData
from thejoker.sampler import JokerParams, TheJoker
from thejoker.plot import plot_rv_curves

rnd = np.random.RandomState(seed=123)

For data, we’ll generate simulated observations of the exoplanet GJ 876 b (with parameters taken from exoplanets.org):

In [2]:
t0 = Time(2450546.80, format='jd', scale='utc')

truth = dict()
truth['P'] = 61.1166 * u.day
truth['K'] = 214. * u.m/u.s
truth['ecc'] = 0.0324 * u.one
# Convert the epoch t0 into an orbital phase (radians, wrapped to [0, 2*pi))
phi0 = 2*np.pi*t0.tcb.mjd / truth['P'].to(u.day).value
truth['phi0'] = (phi0 % (2*np.pi)) * u.radian
truth['omega'] = 50.3 * u.degree
truth['v0'] = -1.52 * u.km/u.s

orbit = SimulatedRVOrbit(**truth)
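For reference, a Keplerian radial-velocity curve of the kind `SimulatedRVOrbit` presumably models can be written in a few lines. This is a minimal sketch under stated assumptions: we take the mean anomaly as M = 2πt/P − φ0 (The Joker’s own phase convention may differ), use the textbook relation v(t) = v0 + K[cos(ω + f) + e cos ω] with true anomaly f, and solve Kepler’s equation M = E − e sin E with Newton’s method. The solver stops once the update falls below a tolerance or after a maximum number of iterations; hitting that cap is the situation reported by the RuntimeWarning at the end of this notebook. The names `eccentric_anomaly` and `rv_model` are hypothetical, not The Joker’s API.

```python
import math
import warnings


def eccentric_anomaly(mean_anomaly, ecc, tol=1e-11, maxiter=128):
    """Solve Kepler's equation M = E - e*sin(E) for E with Newton's method."""
    # Common starting guess: E = M for moderate e, E = pi for high e
    E = mean_anomaly if ecc < 0.8 else math.pi
    for _ in range(maxiter):
        residual = E - ecc * math.sin(E) - mean_anomaly
        dE = -residual / (1.0 - ecc * math.cos(E))
        E += dE
        if abs(dE) < tol:
            return E
    warnings.warn(
        f"Kepler solver reached maximum number of iterations ({maxiter})",
        RuntimeWarning)
    return E


def rv_model(t, P, K, ecc, omega, phi0, v0):
    """Radial velocity at time t (same time unit as P); angles in radians.

    Assumes mean anomaly M = 2*pi*t/P - phi0 (a hypothetical convention).
    """
    # Mean anomaly wrapped to [0, 2*pi)
    M = (2.0 * math.pi * t / P - phi0) % (2.0 * math.pi)
    E = eccentric_anomaly(M, ecc)
    # True anomaly from eccentric anomaly, with the correct quadrant
    f = 2.0 * math.atan2(math.sqrt(1.0 + ecc) * math.sin(E / 2.0),
                         math.sqrt(1.0 - ecc) * math.cos(E / 2.0))
    return v0 + K * (math.cos(omega + f) + ecc * math.cos(omega))
```

For example, `rv_model(0.0, P=61.1166, K=214.0, ecc=0.0324, omega=math.radians(50.3), phi0=0.0, v0=-1520.0)` evaluates the curve in m/s at t = 0 under this assumed convention.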

We generate the data by sampling times uniformly over 350 days relative to an arbitrary epoch in MJD:

In [3]:
n_data = 6
t = rnd.uniform(0, 350, n_data) + 55557. # arbitrary epoch
t.sort()
rv = orbit.generate_rv_curve(t)

# 25 m/s uncertainty on every point; add Gaussian noise at that level
err = np.full_like(t, 25) * u.m/u.s
rv = rv + rnd.normal(0, err.value)*err.unit

Now we create an RVData object to store the “observations”:

In [4]:
data = RVData(t=t, rv=rv, stddev=err)
ax = data.plot()
ax.set_xlabel("BMJD")
ax.set_ylabel("RV [km/s]")
[Figure: the simulated RV observations, RV [km/s] vs. BMJD]

We’ll restrict the period range somewhat for this example (in practice, you should use a much wider range of periods):

In [5]:
params = JokerParams(P_min=8*u.day, P_max=256*u.day, anomaly_tol=1E-11)
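To see what the period limits control, here is a sketch of drawing prior samples for the nonlinear orbital parameters. We assume, as we understand The Joker’s default, that the period is uniform in ln P between `P_min` and `P_max`, and that the angles are uniform in [0, 2π). The eccentricity Beta parameters (0.867, 3.03) are an assumption, and `sample_prior` is a hypothetical helper, not part of thejoker.

```python
import math
import random


def sample_prior(n_samples, P_min=8.0, P_max=256.0, seed=None):
    """Draw prior samples for the nonlinear parameters (hypothetical helper).

    Period (days) is log-uniform between P_min and P_max; omega and phi0 are
    uniform in [0, 2*pi); the eccentricity Beta parameters are an assumption.
    """
    rng = random.Random(seed)
    ln_P_min, ln_P_max = math.log(P_min), math.log(P_max)
    samples = []
    for _ in range(n_samples):
        samples.append({
            "P": math.exp(rng.uniform(ln_P_min, ln_P_max)),
            "ecc": rng.betavariate(0.867, 3.03),
            "omega": rng.uniform(0.0, 2.0 * math.pi),
            "phi0": rng.uniform(0.0, 2.0 * math.pi),
        })
    return samples


if __name__ == "__main__":
    prior = sample_prior(2**10, seed=123)
    print(min(s["P"] for s in prior), max(s["P"] for s in prior))
```

With a log-uniform prior each factor-of-two range of periods receives the same share of samples; since 256/8 = 2^5, about 20% of the samples fall between 8 and 16 days. Widening the range therefore spreads a fixed number of prior samples more thinly, which is why a larger range generally calls for more prior samples.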

To run using multiprocessing, we create a schwimmbad.MultiPool instance and pass it to TheJoker. In this case, we only need the pool for the rejection sampling, so we use it as a context manager to make sure the worker processes are all cleaned up afterward:

In [6]:
%%time
with schwimmbad.MultiPool() as pool:
    joker = TheJoker(params, pool=pool)
    samples = joker.rejection_sample(data, n_prior_samples=2**18)
INFO: 31 good samples after rejection sampling [thejoker.sampler.sampler]
CPU times: user 121 ms, sys: 97.8 ms, total: 218 ms
Wall time: 23.4 s
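The large gap between the CPU time (218 ms) and the wall time (23.4 s) above is expected: `%%time` reports CPU time for the notebook’s own process, while the likelihood evaluations run in the pool’s worker processes. The sketch below reproduces that pattern with the standard library’s `multiprocessing.Pool` used as a context manager, as a stand-in for `schwimmbad.MultiPool` (not its implementation). `expensive_task` and `run_in_pool` are hypothetical names, and the sketch assumes the `"fork"` start method is available (Linux and macOS).

```python
import math
import multiprocessing as mp
import time


def expensive_task(n):
    """Hypothetical CPU-bound task standing in for a chunk of likelihood work."""
    return sum(math.sqrt(i) for i in range(n))


def run_in_pool(tasks, processes=4):
    """Map tasks over worker processes and report wall vs. parent CPU time."""
    # "fork" is assumed available; it avoids re-importing this module in workers
    ctx = mp.get_context("fork")
    wall_start, cpu_start = time.perf_counter(), time.process_time()
    # Leaving the with-block shuts the worker processes down
    with ctx.Pool(processes=processes) as pool:
        results = pool.map(expensive_task, tasks)
    wall = time.perf_counter() - wall_start
    cpu = time.process_time() - cpu_start  # CPU time of this process only
    return results, wall, cpu


if __name__ == "__main__":
    _, wall, cpu = run_in_pool([2_000_000] * 8)
    print(f"wall: {wall:.2f} s, parent CPU: {cpu:.2f} s")
```

Because the parent process mostly waits on `pool.map`, its CPU time stays small even when the workers are busy for many seconds, just as in the timing output of the cell above.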

Now we’ll plot the samples in two projections of the parameters, P vs. K and P vs. eccentricity, with the true values marked by green lines:

In [7]:
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

y_key = 'P'
y_unit = u.day
x_keys = ['K', 'ecc']
x_units = [u.km/u.s, u.one]

for ax, x_key, x_unit in zip(axes, x_keys, x_units):
    ax.scatter(samples[x_key].to(x_unit).value,
               samples[y_key].to(y_unit).value,
               marker='.', color='k', alpha=0.45)
    ax.set_xlabel(r"{} [{}]".format(x_key, x_unit.to_string('latex')))

    ax.axvline(truth[x_key].to(x_unit).value,
               zorder=-100, color='#31a354', alpha=0.4)
    ax.axhline(truth[y_key].to(y_unit).value,
               zorder=-100, color='#31a354', alpha=0.4)

axes[0].set_ylabel(r"{} [{}]".format(y_key, y_unit.to_string('latex')))
[Figure: posterior samples in the P vs. K (left) and P vs. ecc (right) planes; green lines mark the true values]

Finally, we plot RV curves for the posterior samples over the data:

In [8]:
fig, ax = plt.subplots(1, 1, figsize=(8,5))
t_grid = np.linspace(data.t.mjd.min()-10, data.t.mjd.max()+10, 1024)
fig = plot_rv_curves(samples, t_grid, rv_unit=u.km/u.s, data=data, ax=ax,
                     plot_kwargs=dict(color='#74a9cf', zorder=-100))
/Users/adrian/anaconda/envs/thejoker-dev/lib/python3.5/site-packages/thejoker-0.1.dev375-py3.5.egg/thejoker/celestialmechanics/celestialmechanics.py:103: RuntimeWarning: eccentric_anomaly_from_mean_anomaly() reached maximum number of iterations (128)
  "number of iterations ({})".format(maxiter), RuntimeWarning)
[Figure: RV curves of the posterior samples (blue) plotted over the simulated data]