# -*- coding: utf-8 -*-
"""
Created on Sat Mar 19 14:49:49 2022
@author: hofer
"""
import time
from abc import ABC
from typing import Iterable, Optional, Union
import numpy as np
import uncertainties
from scipy.optimize import curve_fit
from atomcloud.analysis import calc_chi_squared
from atomcloud.utils import fit_utils
# jaxfit is an optional dependency: JAX records whether it could be imported
# so the fitting class can fall back to SciPy when it is missing.
try:
    import jaxfit

    JAX = True
except ImportError:
    JAX = False
# __all__ = ["MultiFunctionFit"]
class MultiFunctionFit(ABC):
    """Fit data to a multi-function model with either SciPy or JAXFit.

    The model object is built by calling ``multi_func(function_names,
    constraints, use_jax=...)``. A NumPy version is always created; a JAX
    version and a ``jaxfit.CurveFit`` object are created only when ``jaxfit``
    could be imported and ``scipy_length`` is not None.
    """

    def __init__(
        self,
        function_names: list[str],
        multi_func: object,
        func_registry: dict,
        fit_label: str,
        max_nfev_scalar: int = 50,
        constraints: Optional[list[str]] = None,
        scipy_length: Optional[int] = 1000,
        fixed_length: Optional[int] = None,
    ) -> None:
        """
        This class is used to create a multi-function fit object using either
        numpy or JAX. The resulting object is then used to do a multi-function
        fit to the data using either scipy or JAXFit.

        Args:
            function_names: names of the functions to be used in the fit
            multi_func: callable invoked as ``multi_func(function_names,
                constraints, use_jax=bool)`` to build the model object
            func_registry: mapping from function name to the function object
                that provides ``default_bounds()`` and ``initial_seed()``
            fit_label: label stored under ``"fit_type"`` in the info dict
            max_nfev_scalar: integer multiplied by the model's ``num_args``
                to obtain the ``max_nfev`` passed to the fitting routine
            constraints: list of constraints to be used in the multi-function
            scipy_length: data length up to which SciPy is used instead of
                JAXFit; None disables JAXFit entirely
            fixed_length: fixed length passed to ``jaxfit.CurveFit`` as
                ``flength``

        Raises:
            TypeError: if ``max_nfev_scalar`` is not an integer
        """
        # Validate first so a bad argument fails before any (potentially
        # expensive) model / JAXFit construction takes place.
        if not isinstance(max_nfev_scalar, (int, np.integer)):
            raise TypeError(
                f"max_nfev_scalar must be an integer, not {type(max_nfev_scalar)}"
            )

        self.function_names = function_names
        self.constraints = constraints
        self.scipy_length = scipy_length
        self.fixed_length = fixed_length
        self.func_registry = func_registry
        self.fit_label = fit_label

        # make rename func to func and make cloud_func obj _call__ function()
        self.func = multi_func(function_names, constraints, use_jax=False)
        self.fit_object_init(multi_func)

        # int() keeps max_nfev a plain Python int even for NumPy integers.
        self.max_nfev = int(max_nfev_scalar) * self.func.num_args

        # Keys of the info dictionary, in the order get_info_dict fills them.
        self.info_keys = [
            "fit_type",
            "equations",
            "constraints",
            "params",
            "fit_metrics",
            "data_sum",
        ]

    def fit_object_init(self, multi_func: object) -> None:
        """Initialize the JAXFit object and function. This is only done if
        JAX is installed and SciPy length is not None.

        Otherwise JAX fitting is disabled, ``scipy_length`` is reset to 0 and
        the JAX attributes are set to None so they always exist.

        Args:
            multi_func: multi-function object to be used
        """
        if self.scipy_length is not None and JAX:
            self.jax = True
            self.jax_func = multi_func(
                self.function_names, self.constraints, use_jax=True
            )
            self.jcf = jaxfit.CurveFit(flength=self.fixed_length)
        else:
            self.jax = False
            self.scipy_length = 0
            self.jax_func = None
            self.jcf = None

    def get_fit_obj(self, flat_data) -> tuple:
        """Returns the correct fit package to use based on the length of the
        data and whether JAX is installed.

        Args:
            flat_data: flattened data to be fit

        Returns:
            curvefit: curvefit function to use
            fit_func: function object to be fit
            print_label: label to print to console if verbose is enabled
            used_jax: True when the JAXFit path was selected
            kwargs: extra keyword arguments for the curvefit function
        """
        # A cutoff of None (e.g. set through set_cutoff_length) means
        # "always SciPy" rather than an invalid comparison against None.
        use_jax = (
            self.jax
            and self.scipy_length is not None
            and len(flat_data) > self.scipy_length
        )
        if use_jax:
            kwargs = {"return_eval": True}
            return self.jcf.curve_fit, self.jax_func, "JAXFit", True, kwargs
        return curve_fit, self.func, "SciPy", False, {}

    def get_fit(
        self,
        coords: Union[np.ndarray, Iterable[np.ndarray]],
        data: np.ndarray,
        seed: Optional[list[list[float]]] = None,
        bounds: Optional[list[list[float]]] = None,
        sigma: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        uncertainty: bool = False,
        verbose: bool = False,
    ) -> tuple[list[float], dict]:
        """Fit the data to the multi-function

        Args:
            coords: coordinates of the data
            data: data to be fit (must be same shape as coords)
            seed: initial seed for the fit in terms of the multi-functions
                individual functions (ie. list of lists where the top level
                list corresponds to the function and the second level list
                corresponds to the parameters of that function).
            bounds: tuple of min and max values for each parameter, but the
                min and max values are each formatted as a list of lists (see
                seed for formatting)
            sigma: standard deviation of the data (must be same shape as data)
                or a covariance matrix with each axis the same length as
                the data.
            mask: mask to be applied to the data which is the same shape as the
                data
            uncertainty: whether to return the uncertainties of the fit
            verbose: whether to print the fit information to the console.

        Returns:
            params: list of the fit parameters
            info: dictionary of the fit information
        """
        flat_coords, flat_data = self.flatten_fit_data(coords, data)

        # A sigma with the same shape as the data is a per-point standard
        # deviation and has to follow the data through flattening and
        # masking, otherwise the solver sees mismatching shapes. Any other
        # shape (e.g. a covariance matrix) is passed through untouched.
        elementwise_sigma = sigma is not None and np.shape(sigma) == np.shape(data)
        flat_sigma = np.asarray(sigma).flatten() if elementwise_sigma else sigma

        if mask is not None:
            mask = mask.flatten()
            if elementwise_sigma:
                # NOTE(review): assumes get_masked_data applies the mask to
                # its second argument regardless of what it holds and does
                # not modify flat_coords in place -- confirm in fit_utils.
                _, flat_sigma = fit_utils.get_masked_data(
                    flat_coords, flat_sigma, mask
                )
            flat_coords, flat_data = fit_utils.get_masked_data(
                flat_coords, flat_data, mask
            )

        data_sum = np.sum(flat_data)
        curvefit, fit_func, print_label, used_jax, kwargs = self.get_fit_obj(
            flat_data
        )

        if seed is None:
            seed = self.get_default_seed(flat_coords, flat_data)
        if bounds is None:
            bounds = self.get_default_bounds()

        # convert the seed and bounds to a single list matching the fit func
        seed = self.func.params_to_args(seed)
        bounds = [self.func.params_to_args(bound) for bound in bounds]

        start = time.perf_counter()
        fit_results = curvefit(
            fit_func.fit_function,
            flat_coords,
            flat_data,
            p0=seed,
            bounds=bounds,
            sigma=flat_sigma,
            max_nfev=self.max_nfev,
            **kwargs,
        )

        if used_jax:
            # The JAXFit path is called with return_eval=True and yields the
            # function evaluation as a third result.
            popt, pcov, func_eval = fit_results
        else:
            popt, pcov = fit_results
            func_eval = self.func.fit_function(flat_coords, *popt)

        if verbose:
            print(print_label, time.perf_counter() - start)

        params = self.func.args_to_params(popt)
        save_params = self.handle_uncertainty(popt, pcov, params, uncertainty)
        fit_metrics = self.get_fit_metrics(params, func_eval, flat_data, flat_sigma)
        fit_dict = self.get_info_dict(save_params, fit_metrics, data_sum)
        return params, fit_dict

    def flatten_fit_data(
        self, coords: Union[np.ndarray, Iterable[np.ndarray]], data: np.ndarray
    ):
        """Flatten the data and coordinates to be fit.

        Args:
            coords: coordinates of the data, either a single array or a
                tuple/list (including subclasses such as namedtuples) of
                coordinate arrays
            data: data to be fit (must be same shape as coords)

        Returns:
            flat_coords: flattened coordinates of the data (a list of
                flattened arrays when coords is a tuple or list)
            flat_data: flattened data to be fit
        """
        flat_data = data.flatten()
        if isinstance(coords, (tuple, list)):
            flat_coords = [coord_array.flatten() for coord_array in coords]
        else:
            flat_coords = coords.flatten()
        return flat_coords, flat_data

    def handle_uncertainty(
        self,
        popt: np.ndarray,
        pcov: np.ndarray,
        func_params: list[list[float]],
        uncertainty: bool,
    ) -> list[list[float]]:
        """Handle the uncertainty of the fit. If uncertainty is True, then
        the covariance matrix and the fit parameters are used to create the
        uncertainty fit parameters. The package is designed to handle
        these uncertainty parameters throughout, but these are more difficult
        to work with due to limits on the accepted operations and thus the
        user might wish to neglect using them for their own custom functions.

        Args:
            popt: fit parameters
            pcov: covariance matrix for the fit parameters
            func_params: fit parameters in the format list of individual
                functions fit parameters
            uncertainty: whether to return the uncertainty parameters

        Returns:
            save_params: fit parameters to be saved in the info dictionary
        """
        if not uncertainty:
            return func_params
        upopt = uncertainties.correlated_values(popt, pcov)
        return self.func.args_to_params(upopt)

    def get_fit_metrics(
        self,
        params: list[list[float]],
        func_eval: np.ndarray,
        flat_data: Union[np.ndarray, Iterable[np.ndarray]],
        sigma: Optional[np.ndarray] = None,
    ) -> dict:
        """Get the fit metrics for the fit. Currently only chi squared and
        reduced chi squared are calculated.

        Args:
            params: fit parameters
            func_eval: fit function evaluated at the fitted coordinates
            flat_data: flattened data that was fit
            sigma: standard deviation of the data (must be same shape as data)
                or a covariance matrix with each axis the same length as
                the data.

        Returns:
            fit_metrics: dictionary of the fit metrics
        """
        # TODO: change function evaluation to use JAX rather than numpy
        # current issues is that coords are in numpy and need to be reconverted
        # NOTE(review): len(params) is the number of functions (params is a
        # list of per-function lists), not the number of fitted arguments --
        # confirm this is what calc_chi_squared expects as its first argument.
        chi_values = calc_chi_squared(len(params), flat_data, func_eval, sigma)
        return {"chi_squared": chi_values[0], "chi_squared_red": chi_values[1]}

    def get_info_dict(
        self, fit_parameters: list[list[float]], fit_metrics: dict, data_sum: float
    ) -> dict:
        """Get the fit information dictionary this will be saved and used
        throughout the package to do things like plotting, integrating, etc.

        Args:
            fit_parameters: fit parameters
            fit_metrics: dictionary of the fit metrics
            data_sum: sum of the data that was fit

        Returns:
            fit_dict: dictionary of the fit information
        """
        # Order must match self.info_keys, which supplies the dictionary keys.
        fit_info = [
            self.fit_label,
            self.function_names,
            self.constraints,
            fit_parameters,
            fit_metrics,
            data_sum,
        ]
        return dict(zip(self.info_keys, fit_info))

    def _lookup_function(self, function_name: str):
        """Return the registry entry for function_name or raise KeyError."""
        function_object = self.func_registry.get(function_name)
        if function_object is None:
            raise KeyError(
                f"Function '{function_name}' was not found in the function registry"
            )
        return function_object

    def get_default_bounds(self) -> tuple[list[list[float]], ...]:
        """Get the default bounds for the fit. This is a tuple of two lists
        one for the lower bounds and one for the upper bounds. The lower and
        upper bounds are each returned as a list of lists where
        each sublist is the bounds for a single function.

        Returns:
            bounds: tuple of the lower and upper bounds

        Raises:
            KeyError: if a function name is missing from the registry
        """
        bounds = [
            self._lookup_function(function_name).default_bounds()
            for function_name in self.function_names
        ]
        min_bounds = [bound[0] for bound in bounds]
        max_bounds = [bound[1] for bound in bounds]
        return min_bounds, max_bounds

    def get_default_seed(
        self, coords: Union[np.ndarray, Iterable[np.ndarray]], data: np.ndarray
    ) -> list[list[float]]:
        """Get the default seed for the fit. This is a list of lists where
        each sublist is the seed for a single function.

        Args:
            coords: coordinates of the data
            data: data to be fit

        Returns:
            seed: list of the seed parameters

        Raises:
            KeyError: if a function name is missing from the registry
        """
        return [
            self._lookup_function(function_name).initial_seed(coords, data)
            for function_name in self.function_names
        ]

    def set_cutoff_length(self, scipy_length: Optional[int]) -> None:
        """Allows the user to change the length of data that will
        trigger a JAX vs. SciPy fit (see init docstring for more info).

        This has no effect on the chosen backend when JAX fitting was
        disabled at initialization. Passing None selects SciPy for every fit.

        Args:
            scipy_length: length of data to use scipy over jax
        """
        self.scipy_length = scipy_length

    def set_fixed_length(self, fixed_length: int) -> None:
        """Set a fixed length, but jaxfit needs to be changed to allow
        this to work.

        Currently a no-op: the argument is ignored and ``self.fixed_length``
        is left unchanged.
        """
        # Intentionally does not assign self.fixed_length yet (see docstring).