Source code for taurex.optimizer.nestle

"""Retrieval using nestle library."""
import time
import typing as t

import nestle
import numpy as np
import numpy.typing as npt

from taurex.model import ForwardModel
from taurex.output import OutputGroup
from taurex.spectrum import BaseSpectrum
from taurex.util import quantile_corner, recursively_save_dict_contents_to_output

from .optimizer import FitParamOutput, Optimizer


# The functional ``TypedDict`` form is required here: the keys that
# ``NestleOptimizer.store_nestle_output`` writes contain hyphens and therefore
# cannot be spelled as class attributes.
NestleStatsOutput = t.TypedDict(
    "NestleStatsOutput",
    {
        # Log-evidence reported by nestle (``result.logz``).
        "Log-Evidence": float,
        # Uncertainty on the log-evidence (``result.logzerr``).
        "Log-Evidence-Error": float,
        # Value of ``result.h`` from the nestle result.
        "Peakiness": float,
    },
)
NestleStatsOutput.__doc__ = """Dictionary for storing nestle stats output.

Keys
----
Log-Evidence:
    Log-evidence of the fit.
Log-Evidence-Error:
    Error on the log-evidence.
Peakiness:
    The ``h`` attribute of the nestle result.
"""
class NestleSolutionOutput(t.TypedDict):
    """Dictionary for storing nestle solution output.

    This is the shape of the ``"solution"`` entry assembled by
    ``NestleOptimizer.store_nestle_output``.
    """

    # Samples taken from the nestle result (``result.samples``).
    samples: npt.NDArray[np.float64]
    # Weight attached to each sample (``result.weights``).
    weights: npt.NDArray[np.float64]
    # Covariance returned by ``nestle.mean_and_cov(samples, weights)``.
    covariance: npt.NDArray[np.float64]
    # Per-parameter summary, keyed by fitting-parameter name.
    fitparams: t.Dict[str, FitParamOutput]
class NestleOptimizer(Optimizer):
    """Optimizer that runs retrievals with the
    `nestle <http://kylebarbary.com/nestle/>`_ nested sampling library.
    """

    # Legacy spellings accepted for ``method`` mapped to the name nestle expects.
    _METHOD_ALIASES: t.ClassVar[t.Dict[str, str]] = {"mcmc": "classic"}

    def __init__(
        self,
        observed: t.Optional[BaseSpectrum] = None,
        model: t.Optional[ForwardModel] = None,
        num_live_points: t.Optional[int] = 1500,
        method: t.Optional[
            t.Literal["classic", "single", "multi", "mcmc"]
        ] = "multi",
        tol: t.Optional[float] = 0.5,
        sigma_fraction: t.Optional[float] = 0.1,
    ):
        """Initialize and setup nestle.

        Parameters
        ----------
        observed:
            Observed spectrum to optimize to
        model:
            Forward model to optimize
        num_live_points:
            Number of live points to use in sampling
        method:
            Nested sampling method to use. ``classic`` uses MCMC exploration,
            ``single`` uses a single ellipsoid and ``multi`` uses multiple
            ellipsoids (similar to Multinest). ``mcmc`` is accepted as an
            alias of ``classic``.
        tol:
            Evidence tolerance value to stop the fit. This is based on an
            estimate of the remaining prior volumes contribution to the
            evidence.
        sigma_fraction:
            Fraction of weights to use in computing the error. (Default: 0.1)
        """
        super().__init__("Nestle", observed, model, sigma_fraction)
        self._nlive = int(num_live_points)  # number of live points
        self._method = method  # nested sampling method handed to nestle
        self._tol = tol  # the stopping criterion
        # Filled by ``compute_fit``; holds the ``"Stats"`` and ``"solution"``
        # dictionaries built by ``store_nestle_output``.
        self._nestle_output: t.Optional[t.Dict[str, t.Any]] = None

    @property
    def tolerance(self) -> float:
        """Tolerance value for stopping the fit."""
        return self._tol

    @tolerance.setter
    def tolerance(self, value: float) -> None:
        """Set the tolerance value for stopping the fit."""
        self._tol = value

    @property
    def numLivePoints(self) -> int:  # noqa: N802
        """Number of live points to use in the fit."""
        return self._nlive

    @numLivePoints.setter
    def numLivePoints(self, value: int) -> None:  # noqa: N802
        """Set the number of live points to use in the fit."""
        # Coerce exactly like ``__init__`` so both entry points agree.
        self._nlive = int(value)

    def _sampling_method(self) -> str:
        """Return the ``method`` string to hand to ``nestle.sample``.

        ``None`` falls back to ``"multi"``, which is what this optimizer
        always used before the configured method was honoured.
        """
        method = self._method or "multi"
        return self._METHOD_ALIASES.get(method, method)

    def compute_fit(self) -> None:
        """Computes the fit using nestle and stores the processed result."""

        def nestle_uniform_prior(theta):
            # Prior callback for nestle: converts parameters from the
            # normalised grid to the actual prior via ``prior_transform``.
            return tuple(self.prior_transform(theta))

        ndim = len(self.fitting_parameters)
        self.warning("Beginning fit......")

        t0 = time.time()
        res = nestle.sample(
            self.log_likelihood,
            nestle_uniform_prior,
            ndim,
            # Honour the method chosen at construction instead of a constant.
            method=self._sampling_method(),
            npoints=self.numLivePoints,
            dlogz=self.tolerance,
            callback=nestle.print_progress,
        )
        res = t.cast(nestle.Result, res)
        timenestle = time.time() - t0

        print(res.summary())

        self.warning("Time taken to run 'Nestle' is %s seconds", timenestle)
        self.warning("Fit complete.....")

        self._nestle_output = self.store_nestle_output(res)

    def get_samples(self, solution_idx: int) -> npt.NDArray[np.float64]:
        """Returns the samples from the fit.

        ``solution_idx`` is not used: only one solution is stored.
        """
        return self._nestle_output["solution"]["samples"]

    def get_weights(self, solution_idx: int) -> npt.NDArray[np.float64]:
        """Returns the weights of the samples.

        ``solution_idx`` is not used: only one solution is stored.
        """
        return self._nestle_output["solution"]["weights"]

    def get_solution(
        self,
    ) -> t.Generator[
        t.Tuple[
            int,
            npt.NDArray[np.float64],
            npt.NDArray[np.float64],
            t.Tuple[
                t.Tuple[t.Literal["Statistics"], NestleStatsOutput],
                t.Tuple[t.Literal["fit_params"], t.Dict[str, FitParamOutput]],
                t.Tuple[t.Literal["tracedata"], npt.NDArray[np.float64]],
                t.Tuple[t.Literal["weights"], npt.NDArray[np.float64]],
            ],
        ],
        None,
        None,
    ]:
        """Generator for solutions and their median and MAP values

        Yields
        ------
        solution_no:
            Solution number (always 0)
        map:
            Map values
        median:
            Median values
        extra:
            statistics, fit_params, tracedata, weights
        """
        solution = self._nestle_output["solution"]
        names = self.fit_names
        # NOTE(review): ``fit_values`` is read twice on purpose; this assumes
        # each access yields an independent container - confirm in Optimizer.
        opt_map = self.fit_values
        opt_values = self.fit_values

        for name, summary in solution["fitparams"].items():
            idx = names.index(name)
            opt_map[idx] = summary["map"]
            opt_values[idx] = summary["value"]

        yield 0, opt_map, opt_values, (
            ("Statistics", self._nestle_output["Stats"]),
            ("fit_params", solution["fitparams"]),
            ("tracedata", solution["samples"]),
            ("weights", solution["weights"]),
        )

    def write_optimizer(self, output: OutputGroup) -> OutputGroup:
        """Writes the optimizer settings to the output group."""
        opt = super().write_optimizer(output)

        # number of live points
        opt.write_scalar("num_live_points", self._nlive)
        # nested sampling method, exactly as configured
        opt.write_string("method", self._method)
        # evidence tolerance used as the stopping criterion
        opt.write_scalar("tol", self._tol)
        return opt

    def write_fit(self, output: OutputGroup) -> OutputGroup:
        """Writes the fit to the output group."""
        fit = super().write_fit(output)

        # Nothing to store until ``compute_fit`` has produced a result.
        if self._nestle_output:
            recursively_save_dict_contents_to_output(output, self._nestle_output)

        return fit

    def store_nestle_output(self, result: nestle.Result) -> t.Dict[str, t.Any]:
        """Turn the output from nestle into a dictionary.

        Parameters
        ----------
        result:
            Result object returned by ``nestle.sample``.

        Returns
        -------
        dict
            ``"Stats"`` holds the summary statistics (``Log-Evidence``,
            ``Log-Evidence-Error``, ``Peakiness``) and ``"solution"`` holds
            the samples, weights, covariance and per-parameter summaries.
        """
        samples = result.samples
        weights = result.weights
        mean, cov = nestle.mean_and_cov(samples, weights)

        # Index of the highest-weight sample; its values are stored as "map".
        max_weight = weights.argmax()
        # Converted once here rather than on every loop iteration.
        weight_array = np.asarray(weights)

        fitparams = {}
        for idx, param_name in enumerate(self.fit_names):
            trace = samples[:, idx]
            q_16, q_50, q_84 = quantile_corner(
                trace, [0.16, 0.5, 0.84], weights=weight_array
            )
            fitparams[param_name] = {
                "mean": mean[idx],
                # NOTE(review): this stores row ``idx`` of the covariance
                # matrix, not a standard deviation. Kept unchanged because
                # it determines the shape of the stored output - confirm
                # whether ``np.sqrt(cov[idx, idx])`` was intended.
                "sigma": cov[idx],
                "value": q_50,
                "sigma_m": q_50 - q_16,
                "sigma_p": q_84 - q_50,
                "trace": trace,
                "map": trace[max_weight],
            }

        return {
            "Stats": {
                "Log-Evidence": result.logz,
                "Log-Evidence-Error": result.logzerr,
                "Peakiness": result.h,
            },
            "solution": {
                "samples": samples,
                "weights": weights,
                "covariance": cov,
                "fitparams": fitparams,
            },
        }

    @classmethod
    def input_keywords(cls) -> t.Tuple[str, ...]:
        """Keywords that select this optimizer."""
        return ("nestle",)

    # Citation for the nestle library.
    BIBTEX_ENTRIES = [
        "@misc{nestle, author = {Kyle Barbary}, "
        "title = {Nestle sampling library}, "
        "publisher = {GitHub}, "
        "journal = {GitHub repository}, "
        "year = 2015, "
        "howpublished = {https://github.com/kbarbary/nestle}, }"
    ]