Source code for taurex.util.util

"""General utility functions."""
import re
import typing as t

import numpy as np
import numpy.typing as npt
from astropy import units as u

from taurex.output.output import OutputGroup
from taurex.types import AnyValType, ScalarType

# Atomic masses in amu keyed by element symbol; "e-" is the electron.
# calculate_weight sums these to obtain molecular weights.
mass = {
    "H": 1.00794,
    "He": 4.002602,
    "Li": 6.941,
    "Be": 9.012182,
    "B": 10.811,
    "C": 12.011,
    "N": 14.00674,
    "O": 15.9994,
    "F": 18.9984032,
    "Ne": 20.1797,
    "Na": 22.989768,
    "Mg": 24.3050,
    "Al": 26.981539,
    "Si": 28.0855,
    "P": 30.973762,
    "S": 32.066,
    "Cl": 35.4527,
    "Ar": 39.948,
    "K": 39.0983,
    "Ca": 40.078,
    "Sc": 44.955910,
    "Ti": 47.88,
    "V": 50.9415,
    "Cr": 51.9961,
    "Mn": 54.93805,
    "Fe": 55.847,
    "Co": 58.93320,
    "Ni": 58.6934,
    "Cu": 63.546,
    "Zn": 65.39,
    "Ga": 69.723,
    "Ge": 72.61,
    "As": 74.92159,
    "Se": 78.96,
    "Br": 79.904,
    "Kr": 83.80,
    "Rb": 85.4678,
    "Sr": 87.62,
    "Y": 88.90585,
    "Zr": 91.224,
    "Nb": 92.90638,
    "Mo": 95.94,
    "Tc": 98,
    "Ru": 101.07,
    "Rh": 102.90550,
    "Pd": 106.42,
    "Ag": 107.8682,
    "Cd": 112.411,
    "In": 114.82,
    "Sn": 118.710,
    "Sb": 121.757,
    "Te": 127.60,
    "I": 126.90447,
    "Xe": 131.29,
    "Cs": 132.90543,
    "Ba": 137.327,
    "La": 138.9055,
    "Ce": 140.115,
    "Pr": 140.90765,
    "Nd": 144.24,
    "Pm": 145,
    "Sm": 150.36,
    "Eu": 151.965,
    "Gd": 157.25,
    "Tb": 158.92534,
    "Dy": 162.50,
    "Ho": 164.93032,
    "Er": 167.26,
    "Tm": 168.93421,
    "Yb": 173.04,
    "Lu": 174.967,
    "Hf": 178.49,
    "Ta": 180.9479,
    "W": 183.85,
    "Re": 186.207,
    "Os": 190.2,
    "Ir": 192.22,
    "Pt": 195.08,
    "Au": 196.96654,
    "Hg": 200.59,
    "Tl": 204.3833,
    "Pb": 207.2,
    "Bi": 208.98037,
    "Po": 209,
    "At": 210,
    "Rn": 222,
    "Fr": 223,
    "Ra": 226.0254,
    "Ac": 227,
    "Th": 232.0381,
    # Protactinium: was 213.0359, a digit transposition of 231.0359.
    "Pa": 231.0359,
    "U": 238.0289,
    "Np": 237.0482,
    "Pu": 244,
    "Am": 243,
    "Cm": 247,
    "Bk": 247,
    "Cf": 251,
    "Es": 252,
    "Fm": 257,
    "Md": 258,
    "No": 259,
    "Lr": 260,
    "Rf": 261,
    "Db": 262,
    "Sg": 263,
    "Bh": 262,
    "Hs": 265,
    "Mt": 266,
    "e-": 5.4857990907e-4,
}

# Element symbol -> atom count; the mapping type returned by
# split_molecule_elements and combined by merge_elements.
ElementType = t.Dict[str, int]


def calculate_weight(chem: str) -> float:
    """Return the molecular weight of a molecule in amu.

    The molecule is broken into element counts with
    ``split_molecule_elements`` and each count is multiplied by the
    matching entry of the module-level ``mass`` table.

    Parameters
    ----------
    chem : str
        Molecule name e.g. H2O, CO2, CH4, etc.

    Returns
    -------
    float
        Molecular weight in amu

    """
    total = 0.0
    for symbol, n_atoms in split_molecule_elements(chem).items():
        total += mass[symbol] * n_atoms
    return total
# def split_molecule_elements(chem): # s = re.findall('([A-Z][a-z]?)([0-9]*)', chem) # return s
def tokenize_molecule(molecule: str) -> list[str]:
    """Break a molecule string into element, number and symbol tokens.

    A token is, in order of preference, an upper-case letter optionally
    followed by one lower-case letter, a run of digits, or any other
    single character (brackets, signs, ...).

    Parameters
    ----------
    molecule : str
        Molecule string, e.g. ``"Ca(OH)2"``

    Returns
    -------
    list[str]
        Tokens in the order they appear in ``molecule``

    """
    import re

    token_pattern = re.compile(r"[A-Z][a-z]?|\d+|.")
    return [match.group(0) for match in token_pattern.finditer(molecule)]
def merge_elements(
    elem1: ElementType, elem2: ElementType, factor: t.Optional[int] = 1
) -> ElementType:
    """Combine two element-count dictionaries into a new one.

    Parameters
    ----------
    elem1 : dict[str, int]
        First element counts, taken as-is
    elem2 : dict[str, int]
        Second element counts, each multiplied by ``factor``
    factor : int, optional
        Multiplier applied to the counts of ``elem2`` (default 1)

    Returns
    -------
    dict[str, int]
        Counts for every element present in either input

    """
    merged = {}
    for symbol in set(elem1) | set(elem2):
        merged[symbol] = elem1.get(symbol, 0) + elem2.get(symbol, 0) * factor
    return merged
_OPENING_BRACKETS = ("{", "(", "[")
_CLOSING_BRACKETS = ("}", ")", "]")


def _read_multiplier(tokens: t.List[str], index: int) -> t.Tuple[int, int]:
    """Return ``(count, index)`` for an optional integer after ``tokens[index]``.

    When the next token parses as an integer it is consumed and the
    returned index points at it; otherwise the count is 1 and the index
    is returned unchanged.
    """
    try:
        return int(tokens[index + 1]), index + 1
    except (IndexError, ValueError):
        return 1, index


def _parse_element_tokens(tokens: t.List[str]) -> t.Tuple[ElementType, int]:
    """Count elements in ``tokens`` up to the first unmatched closing bracket.

    Always returns ``(counts, index)`` where ``index`` is the position of
    the closing bracket that ended parsing, or ``len(tokens)`` when the
    tokens ran out first.
    """
    elems: ElementType = {}
    index = 0
    while index < len(tokens):
        token = tokens[index]
        if token in mass:
            count, index = _read_multiplier(tokens, index)
            elems[token] = elems.get(token, 0) + count
        elif token in _OPENING_BRACKETS:
            # Parse the bracketed group on the remaining tokens; ``moved``
            # is relative to the slice, so adding it lands on the closer
            # (or on len(tokens) if the group was never closed).
            index += 1
            sub_elems, moved = _parse_element_tokens(tokens[index:])
            index += moved
            count, index = _read_multiplier(tokens, index)
            elems = merge_elements(elems, sub_elems, count)
        elif token in _CLOSING_BRACKETS:
            return elems, index
        # Anything else (isotope numbers, charges, separators) is skipped.
        index += 1
    return elems, index


def split_molecule_elements(
    molecule: t.Optional[str] = None, tokens: t.Optional[t.List[str]] = None
) -> ElementType:
    """Split a molecule string into its elements and numbers.

    For example when run with H2O:

    >>> split_molecule_elements('H2O')
    {'H': 2, 'O': 1}

    Bracketed groups followed by a number are multiplied out, e.g.
    ``Ca(OH)2`` gives one Ca, two O and two H. Tokens that are neither a
    known element of the ``mass`` table nor a bracket are ignored. A
    group that is never closed is counted as if it were closed at the end
    of the input, and parsing stops at a closing bracket that has no
    matching opener.

    Parameters
    ----------
    molecule : str, optional
        Molecule string to split. Takes precedence over ``tokens`` when
        non-empty.
    tokens : list[str], optional
        List of presplit tokens.

    Returns
    -------
    dict[str, int]
        Dictionary of elements and their counts; empty when neither
        ``molecule`` nor ``tokens`` is given.

    """
    if molecule:
        tokens = tokenize_molecule(molecule)
    if tokens is None:
        return {}
    elems, _ = _parse_element_tokens(tokens)
    return elems
def sanitize_molecule_string(molecule: str) -> str:
    """Strip a molecule string down to element symbols and their counts.

    Only ``<Upper>[lower][digits]`` runs are kept, so isotope prefixes
    and separators are dropped:

    >>> sanitize_molecule_string('H2O')
    'H2O'
    >>> sanitize_molecule_string('H2-16O')
    'H2O'

    Parameters
    ----------
    molecule : str
        Molecule to sanitize

    Returns
    -------
    str
        Sanitized name

    """
    kept_pieces = []
    for symbol, count in re.findall("([A-Z][a-z]?)([0-9]*)", molecule):
        kept_pieces.append(symbol + count)
    return "".join(kept_pieces)
# Lookup table used by molecule_texlabel. Keys are written in upper case
# (e.g. "HE", "SIO2"); values are matplotlib-style mathtext labels.
_mol_latex = {
    "HE": "He",
    "H2": "H$_2$",
    "N2": "N$_2$",
    "O2": "O$_2$",
    "CO2": "CO$_2$",
    "CH4": "CH$_4$",
    "CO": "CO",
    "NH3": "NH$_3$",
    "H2O": "H$_2$O",
    "C2H2": "C$_2$H$_2$",
    "HCN": "HCN",
    "H2S": "H$_2$S",
    "SIO2": "SiO$_2$",
    "SO2": "SO$_2$",
}
"""Latex versions of molecule names"""
def get_molecular_weight(gasname: str) -> float:
    """Return the molecular weight of a molecule scaled by ``AMU``.

    The weight in amu from ``calculate_weight`` is multiplied by
    ``taurex.constants.AMU`` (presumably the mass of one amu in kg, so
    the result is in kg -- confirm against ``taurex.constants``).

    Parameters
    ----------
    gasname : str
        Name of molecule

    Returns
    -------
    float
        ``calculate_weight(gasname) * AMU``

    """
    from taurex.constants import AMU

    weight_amu = calculate_weight(gasname)
    return weight_amu * AMU
# TODO: Generalize this to any molecule.
def molecule_texlabel(gasname: str) -> str:
    """Return the latex label registered for a molecule.

    The name is looked up verbatim (case-sensitively) in ``_mol_latex``.

    Parameters
    ----------
    gasname : str
        Name of molecule

    Returns
    -------
    str
        Latex form of the molecule, or ``gasname`` unchanged when it has
        no entry

    """
    return _mol_latex.get(gasname, gasname)
def bindown(
    original_bin: npt.NDArray, original_data: npt.NDArray, new_bin: npt.NDArray
) -> npt.NDArray:
    """Bin data down onto a coarser grid by averaging.

    Bin edges are placed half-way between neighbouring points of
    ``new_bin`` and extended by half a step at both ends. One-dimensional
    data is averaged with two histogram calls; data with more dimensions
    is averaged per bin along its last axis. Bins that receive no points
    produce nans.

    Parameters
    ----------
    original_bin : :obj:`numpy.array`
        The original bins that we want to bin down
    original_data : :obj:`numpy.array`
        The associated data that will be averaged along the new bins
    new_bin : :obj:`numpy.array`
        The new binnings we want to use (must have less points than the
        original, and at least two points)

    Returns
    -------
    :obj:`array`
        Binned mean of ``original_data``

    """
    import numpy as np

    edges = np.zeros(new_bin.shape[0] + 1)
    # Interior edges: midpoints between consecutive new bin centres.
    edges[1:-1] = (new_bin[1:] + new_bin[:-1]) / 2
    # Outer edges: extend half a step beyond the first and last centre.
    edges[0] = new_bin[0]
    edges[0] -= (new_bin[1] - new_bin[0]) / 2
    edges[-1] = new_bin[-1]
    edges[-1] += (new_bin[-1] - new_bin[-2]) / 2

    last_axis = len(original_data.shape) - 1
    if not last_axis:
        # 1-D data: weighted histogram (sum per bin) over plain histogram
        # (count per bin) gives the mean per bin.
        bin_sums = np.histogram(original_bin, edges, weights=original_data)[0]
        bin_counts = np.histogram(original_bin, edges)[0]
        return bin_sums / bin_counts

    membership = np.digitize(original_bin, edges, right=True)
    per_bin_means = []
    for bin_index in range(1, len(edges)):
        selected = original_data[..., membership == bin_index]
        per_bin_means.append(selected.mean(axis=last_axis))
    return np.column_stack(per_bin_means)
def movingaverage(a: npt.NDArray, n: t.Optional[int] = 3) -> npt.NDArray:
    """Compute the moving average of an array over a window of ``n`` points.

    Only full windows are returned, so the result has ``n - 1`` fewer
    points than the (flattened) input.

    Parameters
    ----------
    a : :obj:`array`
        Array to compute average
    n : int
        Averaging window

    Returns
    -------
    :obj:`array`
        Resultant array

    """
    import numpy as np

    window_sums = np.cumsum(a)
    # Cumulative sums n apart differ by exactly the sum of one window.
    window_sums[n:] = window_sums[n:] - window_sums[:-n]
    full_windows = window_sums[n - 1 :]
    return full_windows / n
def quantile_corner(
    x: npt.NDArray,
    q: t.Union[npt.NDArray, float],
    weights: t.Optional[t.Union[float, npt.NDArray]] = None,
) -> npt.NDArray:
    """Compute quantiles from an array with weighting.

    * Taken from corner.py

    __author__ = "Dan Foreman-Mackey (danfm@nyu.edu)"
    __copyright__ = "Copyright 2013-2015 Daniel Foreman-Mackey"

    Like numpy.percentile, but:

    * Values of q are quantiles [0., 1.] rather than percentiles [0., 100.]
    * optional weights on x

    Parameters
    ----------
    x : :obj:`array`
        Input array or object that can be converted to an array.
    q : :obj:`array` or float
        Quantile or sequence of quantiles to compute, which must be
        between 0 and 1 inclusive.
    weights : :obj:`array` , optional
        Weights on x, one per sample. Integer weights are accepted.

    Returns
    -------
    percentile : scalar or ndarray
        Without weights, the result of ``numpy.percentile``; with
        weights, the interpolated quantiles converted with ``tolist()``
        (a list for sequence ``q``, a float for scalar ``q``).

    """
    import numpy as np

    x = np.asarray(x)
    if weights is None:
        return np.percentile(x, 100.0 * np.asarray(q))

    # Float weights: an integer cumulative sum cannot be normalised in place.
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(x)
    sorted_x = x[order]
    cdf = np.add.accumulate(weights[order])
    cdf /= cdf[-1]
    return np.interp(q, cdf, sorted_x).tolist()
def loadtxt2d(intext: str) -> npt.NDArray:
    """Load text with ``numpy.loadtxt``, preferring a 2d result.

    ``numpy.loadtxt`` is first called with ``ndmin=2``; if that call
    raises, it is retried without ``ndmin`` and may then return a 1d
    array.

    Parameters
    ----------
    intext : str
        Input text: anything ``numpy.loadtxt`` accepts (``read_table``
        passes a ``StringIO``)

    Returns
    -------
    :obj:`array`
        2d array or 1d array

    """
    try:
        parsed = np.loadtxt(intext, ndmin=2)
    except Exception:
        parsed = np.loadtxt(intext)
    return parsed
def read_error_line(line: str) -> t.Tuple[str, float, float]:
    """Parse one ``name value +/- error`` line from multinest.

    The line is split once into a name part and a values part; both are
    stripped of surrounding colons and spaces, and the values part is
    split on ``" +/- "``. The intermediate pieces are emitted as DEBUG
    log records rather than printed.

    Parameters
    ----------
    line : str
        Line to parse

    Returns
    -------
    name : str
    value : float
    error : float

    Raises
    ------
    ValueError
        If the line does not contain the separator, the ``" +/- "``
        marker, or the numbers cannot be parsed as floats.

    """
    import logging

    log = logging.getLogger(__name__)
    log.debug("_read_error_line -> line> %s", line)
    # NOTE(review): the separator appears as a single space in this
    # listing; whitespace may have been collapsed -- confirm upstream.
    name, values = line.split(" ", 1)
    log.debug("_read_error_line -> name> %s", name)
    log.debug("_read_error_line -> values> %s", values)
    name = name.strip(": ").strip()
    values = values.strip(": ").strip()
    value, error = values.split(" +/- ")
    return name, float(value), float(error)
def read_error_into_dict(line: str, d: t.Dict[str, float]) -> None:
    """Parse a multinest error line and store the result in ``d``.

    The value is stored under the lower-cased name and the error under
    ``"<name> error"``; ``d`` is modified in place.

    Parameters
    ----------
    line : str
        Line to parse with ``read_error_line``
    d : dict
        Dictionary to update

    """
    name, value, error = read_error_line(line)
    key = name.lower()
    d[key] = value
    d["%s error" % key] = error
def read_table(
    txt: str, d: t.Optional[t.Dict[str, float]] = None, title: t.Optional[str] = None
):
    """Read a table from multinest output text.

    Parameters
    ----------
    txt : str
        Table text. When ``title`` is not given the first line is taken
        as the title; the next line is a header and is discarded.
    d : dict, optional
        If given, the loaded array is stored in it under the stripped,
        lower-cased title.
    title : str, optional
        Title to use instead of the first line of ``txt``

    Returns
    -------
    :obj:`array`
        Table values, reshaped to a single row if loading gave a 1d array

    """
    from io import StringIO

    import numpy as np

    if title is None:
        title, remainder = txt.split("\n", 1)
    else:
        remainder = txt

    # First remaining line is the column header; only the body is parsed.
    _header, body = remainder.split("\n", 1)
    data = loadtxt2d(StringIO(body))

    if d is not None:
        d[title.strip().lower()] = data

    if data.ndim == 1:
        data = np.reshape(data, (1, -1))
    return data
def decode_string_array(f):
    """Decode the first item of every entry of ``f`` as utf-8.

    Helper for string arrays read from hdf5.

    Parameters
    ----------
    f :
        Iterable whose entries hold a bytes object at index 0

    Returns
    -------
    list[str]
        Decoded strings, in order

    """
    decoded = []
    for entry in f:
        decoded.append(entry[0].decode("utf-8"))
    return decoded
# Union of the value types that store_thing is annotated to accept.
OutputItem = t.Union[
    float, int, np.int64, np.float64, np.ndarray, str, t.List, t.Tuple, dict
]
def recursively_save_dict_contents_to_output(
    output: OutputGroup, dic: t.Dict[str, OutputItem]
):
    """Write every entry of a dictionary into an output group.

    Each entry is passed to ``store_thing``, which calls back into this
    function for nested dictionaries.

    Parameters
    ----------
    output :
        Group (or root) in output file to write to
    dic : dict
        Dictionary we want to write

    Raises
    ------
    ValueError
        If item is not a supported type

    """
    for name, value in dic.items():
        try:
            store_thing(output, name, value)
        except TypeError as err:
            message = "Cannot save %s type" % type(value)
            raise ValueError(message) from err
def store_thing(output: OutputGroup, key: str, item: OutputItem) -> None:  # noqa: C901
    """Store a single item into output.

    Scalars, arrays and strings are written with the matching
    ``write_*`` method of ``output``. Lists and tuples are written as a
    string array if any element is a string, otherwise as a numeric
    array; if they cannot be written as one array, each element is stored
    separately under ``f"{key}{index}"``. Dictionaries become a new
    group that is filled recursively.

    Parameters
    ----------
    output :
        Group (or root) in output file to write to
    key : str
        Name of item
    item :
        Item to store

    Raises
    ------
    TypeError
        If item is not a supported type

    """
    if isinstance(item, (float, int, np.int64, np.float64)):
        output.write_scalar(key, item)
    elif isinstance(item, np.ndarray):
        output.write_array(key, item)
    elif isinstance(item, str):
        output.write_string(key, item)
    elif isinstance(item, (list, tuple)):
        if isinstance(item, tuple):
            item = list(item)
        if any(isinstance(x, str) for x in item):
            output.write_string_array(key, item)
            return

        stored = False
        try:
            as_array = np.array(item)
        except (TypeError, ValueError):
            # NumPy >= 1.24 raises ValueError for ragged sequences instead
            # of building an object array; fall back to per-element storage.
            as_array = None
        if as_array is not None:
            try:
                output.write_array(key, as_array)
                stored = True
            except TypeError:
                # The writer rejected the array; store the elements one by one.
                stored = False
        if not stored:
            for idx, val in enumerate(item):
                store_thing(output, f"{key}{idx}", val)
    elif isinstance(item, dict):
        group = output.create_group(key)
        recursively_save_dict_contents_to_output(group, item)
    else:
        raise TypeError(f"Cannot store item of type {type(item)}")
def weighted_avg_and_std(
    values: npt.ArrayLike, weights: npt.ArrayLike, axis: t.Optional[int] = None
) -> t.Tuple[AnyValType, AnyValType]:
    """Compute the weighted average and standard deviation.

    Parameters
    ----------
    values : :obj:`array`
        Input array
    weights : :obj:`array`
        Must be same shape as ``values`` (or, as accepted by
        ``numpy.average``, 1d with the length of ``values`` along
        ``axis``)
    axis : int , optional
        axis to perform weighting

    Returns
    -------
    average :
        Weighted average along ``axis``
    std :
        Square root of the weighted mean squared deviation from
        ``average`` along ``axis``

    """
    import numpy as np

    values = np.asanyarray(values)
    average = np.average(values, weights=weights, axis=axis)
    # Re-insert the reduced axis so the average broadcasts against
    # ``values`` for any axis, not only the leading one.
    centre = average if axis is None else np.expand_dims(average, axis)
    variance = np.average((values - centre) ** 2, weights=weights, axis=axis)
    return (average, np.sqrt(variance))
def random_int_iter(total: int, fraction: t.Optional[float] = 1.0) -> t.Iterator[int]:
    """Yield distinct random integers drawn from ``range(total)``.

    ``int(total * fraction)`` integers are drawn without replacement. For
    example with total = 100 and fraction = 0.1, 10 random integers
    between 0 and 99 are yielded.

    Parameters
    ----------
    total : int
        Maximum number (exclusive)
    fraction : float
        Fraction of total to sample

    Yields
    ------
    int
        Random integer

    """
    import random

    sample_size = int(total * fraction)
    for value in random.sample(range(total), sample_size):
        yield value
def compute_bin_edges(spectral_grid: npt.NDArray) -> t.Tuple[npt.NDArray, npt.NDArray]:
    """Compute bin edges and widths from a grid of bin centres.

    Interior edges lie half-way between neighbouring grid points; the
    outer edges extend half a step beyond the first and last point.

    Parameters
    ----------
    spectral_grid : :obj:`array`
        Spectral grid (at least two points)

    Returns
    -------
    :obj:`array`
        Bin edges (one more than the number of grid points)
    :obj:`array`
        Bin widths, as absolute differences of consecutive edges

    """
    import numpy as np

    half_steps = np.diff(spectral_grid) / 2
    lower_edge = spectral_grid[0] - half_steps[0]
    interior_edges = spectral_grid[:-1] + half_steps
    upper_edge = half_steps[-1] + spectral_grid[-1]
    edges = np.concatenate([[lower_edge], interior_edges, [upper_edge]])
    widths = np.abs(np.diff(edges))
    return edges, widths
def clip_native_to_wngrid(
    native_grid: npt.NDArray, spectral: npt.NDArray
) -> npt.NDArray:
    """Clip a native grid to the range covered by another spectral grid.

    The kept range is the minimum to maximum of ``spectral``, widened on
    both sides by the largest bin width of ``spectral``.

    Parameters
    ----------
    native_grid : :obj:`array`
        Native spectral grid
    spectral : :obj:`array`
        spectral grid

    Returns
    -------
    :obj:`array`
        Clipped native spectral grid

    """
    lowest = spectral.min()
    highest = spectral.max()

    # Pad by the widest bin so no native point feeding an edge bin is lost.
    _, bin_widths = compute_bin_edges(spectral)
    lower_limit = lowest - bin_widths.max()
    upper_limit = highest + bin_widths.max()

    keep = (native_grid >= lower_limit) & (native_grid <= upper_limit)
    return native_grid[keep]
def wnwidth_to_wlwidth(wngrid: npt.NDArray, wnwidth: npt.NDArray) -> npt.NDArray:
    r"""Convert wavenumber bin widths to wavelength bin widths and vice versa.

    Given a spectral grid and its associated spectral bin widths, the
    widths are converted with:

    .. math::

        \Delta \lambda = \frac{10000 \Delta \nu}{\nu^2}

    Parameters
    ----------
    wngrid : :obj:`array`
        Wavenumber grid in :math:`cm^{-1}`
    wnwidth : :obj:`array`
        Wavenumber width in :math:`cm^{-1}`

    Returns
    -------
    :obj:`array`
        Wavelength width in :math:`\mu m`

    """
    scaled_width = 10000 * wnwidth
    grid_squared = wngrid**2
    return scaled_width / grid_squared
def class_from_keyword(keyword, class_filter=None):
    """Find the first registered class that lists ``keyword`` as an input keyword.

    Parameters
    ----------
    keyword :
        Keyword looked up in each candidate's ``input_keywords()``
    class_filter : optional
        Base class, or sized collection of base classes, used to restrict
        the candidates via ``ClassFactory.list_from_base``. When omitted
        all temperature, pressure, chemistry, gas, planet, star, model
        and contribution classes of the ``ClassFactory`` are searched.

    Returns
    -------
    class or None
        First matching class, or ``None`` if nothing matches. Candidates
        whose ``input_keywords`` raises ``NotImplementedError`` are
        skipped.

    """
    from ..parameter.classfactory import ClassFactory

    factory = ClassFactory()

    if class_filter is None:
        candidates = [
            *factory.temperatureKlasses,
            *factory.pressureKlasses,
            *factory.chemistryKlasses,
            *factory.gasKlasses,
            *factory.planetKlasses,
            *factory.starKlasses,
            *factory.modelKlasses,
            *factory.contributionKlasses,
        ]
    elif hasattr(class_filter, "__len__"):
        candidates = []
        for base in class_filter:
            candidates += list(factory.list_from_base(base))
    else:
        candidates = list(factory.list_from_base(class_filter))

    for klass in candidates:
        try:
            if keyword in klass.input_keywords():
                return klass
        except NotImplementedError:
            continue

    return None
def class_for_name(class_name: str):
    """Convert a class name into the class registered under it.

    Searches the temperature, pressure, chemistry, gas, planet, star,
    model and contribution classes known to the TauREx3 ``ClassFactory``
    (including plugins) and returns the first whose ``__name__`` matches.

    Parameters
    ----------
    class_name : str
        Name of class; bytes are decoded first

    Returns
    -------
    class
        The matching class

    Raises
    ------
    Exception
        If no registered class has that name

    """
    from ..parameter.classfactory import ClassFactory

    factory = ClassFactory()
    registered = [
        *factory.temperatureKlasses,
        *factory.pressureKlasses,
        *factory.chemistryKlasses,
        *factory.gasKlasses,
        *factory.planetKlasses,
        *factory.starKlasses,
        *factory.modelKlasses,
        *factory.contributionKlasses,
    ]

    # Names read back from file may arrive as bytes.
    try:
        class_name = class_name.decode()
    except (UnicodeDecodeError, AttributeError):
        pass

    registered_names = [klass.__name__ for klass in registered]
    if class_name not in registered_names:
        raise Exception(f"Class of name {class_name} does not exist")
    return registered[registered_names.index(class_name)]
def create_grid_res(
    resolution: ScalarType, spectral_min: ScalarType, spectral_max: ScalarType
) -> npt.NDArray:
    r"""Create a spectral grid with a constant resolution.

    Resolution is defined as :math:`R = \frac{\lambda}{\Delta \lambda}`.
    Points are generated from ``spectral_min`` upwards until one reaches
    or passes ``spectral_max``.

    Parameters
    ----------
    resolution : float
        Resolution to use
    spectral_min : float
        Minimum wavelength
    spectral_max : float
        Maximum wavelength

    Returns
    -------
    :obj:`array`
        Array of shape ``(n, 2)``: grid points in the first column and
        their bin widths in the second

    """
    # Recurrence used below, with l = R * Dl and consecutive bins touching:
    #   l = l_prev + Dl / 2 + Dl_prev / 2
    #   --> (R - 1/2) * Dl = l_prev + Dl_prev / 2
    points = []
    widths = []

    current = spectral_min
    current_width = current / resolution
    scale = resolution - 0.5

    while current < spectral_max:
        current_width = current / scale + current_width / 2 / scale
        current = resolution * current_width
        widths.append(current_width)
        points.append(current)

    return np.array((points, widths)).T
def conversion_factor(from_unit: str, to_unit: str) -> float:
    """Determine the factor that converts one unit into another.

    Each unit is parsed with ``astropy.units.Unit``; if that raises
    ``ValueError`` it is parsed again with ``format="cds"``.

    Parameters
    ----------
    from_unit : :class:`~astropy.units.Unit`
        Unit to convert from
    to_unit : :class:`~astropy.units.Unit`
        Unit to convert to

    Returns
    -------
    float
        Conversion factor.

    """

    def _parse_unit(spec):
        """Parse ``spec`` with the default format, then the CDS format."""
        try:
            return u.Unit(spec)
        except ValueError:
            return u.Unit(spec, format="cds")

    source = _parse_unit(from_unit)
    target = _parse_unit(to_unit)
    return source.to(target)
def compute_dz(altitude: npt.NDArray) -> npt.NDArray:
    """Compute the step between consecutive altitude points.

    The last entry repeats the final step so the result has the same
    shape and dtype as ``altitude``.

    Parameters
    ----------
    altitude : :obj:`array`
        Altitude points (at least two)

    Returns
    -------
    :obj:`array`
        ``altitude[i + 1] - altitude[i]`` for every point, with the last
        difference used twice

    """
    steps = np.zeros_like(altitude)
    steps[:-1] = np.diff(altitude)
    steps[-1] = altitude[-1] - altitude[-2]
    return steps
def has_duplicates(arr: npt.ArrayLike) -> bool:
    """Report whether any value occurs more than once in ``arr``.

    Parameters
    ----------
    arr : :obj:`array`
        Sized iterable of hashable values

    Returns
    -------
    bool
        ``True`` if the number of distinct values differs from the length

    """
    n_items = len(arr)
    n_distinct = len(set(arr))
    return n_items != n_distinct
def find_closest_pair(arr, value) -> t.Tuple[int, int]:
    """Find the indices that lie to the left and right of the value.

    `arr[left] <= value <= arr[right]`

    If the value is less than the array minimum then it will always
    return left=0 and right=1.

    If the value is above the maximum then the last two indices are
    returned (left=n-2 and right=n-1 for an array of n points).

    Parameters
    ----------
    arr : :obj:`array`
        Array to search, must be sorted
    value : float
        Value to find in array

    Returns
    -------
    left : int

    right : int

    """
    last_index = arr.shape[0] - 1
    insertion_point = arr.searchsorted(value)
    # Clamp so that both ``right`` and ``right - 1`` are usable indices.
    right = max(min(last_index, insertion_point), 1)
    left = max(0, right - 1)
    return left, right
def ensure_string_utf8(val: str) -> str:
    """Return ``val`` decoded to ``str`` when it can be decoded.

    Objects without a ``decode`` method (such as ``str``) and bytes that
    fail to decode are returned unchanged.

    Parameters
    ----------
    val : str or bytes
        Value to decode

    Returns
    -------
    str
        Decoded value, or ``val`` itself

    """
    try:
        return val.decode()
    except (UnicodeDecodeError, AttributeError):
        return val