# Source code for atomcloud.plots.plot_sumfit

# -*- coding: utf-8 -*-
"""
Created on Tue Mar 22 23:03:21 2022

@author: hofer
"""
from pathlib import Path
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from atomcloud.functions import MultiFunction1D, MultiFunction2D
from atomcloud.plots.plot_2d import Plot2DBase
from atomcloud.utils import img_utils
from atomcloud.utils.uncertain_utils import nominal_list


# __all__ = ["Plot2DSumFit"]


class Plot2DSumFit(Plot2DBase):
    """Class for plotting 2D sum fits.

    Inherits from the base 2D plotting class.
    """

    def plot_fit(
        self,
        fit_dicts: dict,
        XY_tuple: tuple[np.ndarray, np.ndarray],
        data: np.ndarray,
        mask: Optional[np.ndarray] = None,
        title: str = "",
        savepath: Optional[Union[str, Path]] = None,
        verbose: bool = True,
        *args,
        **kwargs
    ) -> None:
        """Plot the fit results from the 2D sum fit.

        The data is masked, summed along each axis, and evaluated against
        both the 2D multi-function and the 1D multi-function (once per
        axis sum). The 2D and 1D results are then drawn on a shared
        figure, which is optionally saved and always shown.

        Args:
            fit_dicts (dict): Dictionary containing the fit results for a
                2D multi-function fit. It is unpacked with
                ``unpack_fit_dicts``.
            XY_tuple (tuple[np.ndarray, np.ndarray]): The x and y
                coordinates of the data.
            data (np.ndarray): The original data which was fit using the
                multi-function to determine the parameters (params).
            mask (np.ndarray, optional): The mask to be applied to the
                data. Defaults to None.
            title (str, optional): The title of the plot. Defaults to ''.
            savepath (str or Path, optional): The path to save the plot.
                Nothing is saved when this is None. Defaults to None.
            verbose (bool, optional): If True, the data for each function
                will be plotted. Forced to False when the fit contains a
                single function. Defaults to True.
            *args: Accepted for call compatibility; not used by this
                method.
            **kwargs: Accepted for call compatibility; not used by this
                method.
        """
        data = self.mask_data(data, mask)
        fit_dict_params = self.unpack_fit_dicts(fit_dicts)
        func_strs, constraints, params2d, xparams, yparams = fit_dict_params

        # With only one function there is no per-function breakdown to show.
        if len(func_strs) == 1:
            verbose = False

        multi_func1d = MultiFunction1D(func_strs)
        multi_func2d = MultiFunction2D(func_strs)
        x, y, xsum, ysum = img_utils.img_data_to_sums(XY_tuple, data)

        # The same plot_data helper serves the 2D data and both 1D sums.
        pdata_dict = self.plot_data(
            multi_func2d, XY_tuple, data, params2d, func_strs, verbose
        )
        xdata_dict = self.plot_data(
            multi_func1d, x, xsum, xparams, func_strs, verbose
        )
        ydata_dict = self.plot_data(
            multi_func1d, y, ysum, yparams, func_strs, verbose
        )

        fig, axes_2d, axes_1d, cax = self.fig_axes(func_strs, verbose)
        self.plot_2d(XY_tuple, pdata_dict, axes_2d, cax)
        self.plot_1d(x, y, xdata_dict, ydata_dict, axes_1d, "sum")

        title_base, save_name = self.handle_title(
            title, base_title="2D Sum Fit", base_save_name="sumfit"
        )
        title_str = self.title_string(constraints, title_base)
        fig.suptitle(title_str)

        if savepath is not None:
            self.save_plot(savepath, save_name)
        plt.show()

    def unpack_fit_dicts(
        self, fit_dicts: dict
    ) -> tuple[
        list[str],
        list[str],
        list[list[float]],
        list[list[float]],
        list[list[float]],
    ]:
        """Unpack the sumfit fit dictionaries into the components for plotting.

        The components are nested in ``fit_dicts``, which is comprised of
        the 1d and 2d fit dictionaries stored under the keys ``"2d"``,
        ``"xsum"`` and ``"ysum"``. Each ``"params"`` entry is passed
        through ``nominal_list`` before being returned.

        Args:
            fit_dicts: The fit dictionaries for the 2D sum fit.

        Returns:
            func_strs: The function strings for the 2D fit
                (``fit_dicts["2d"]["equations"]``).
            constraints: The constraints for the 2D fit
                (``fit_dicts["2d"]["constraints"]``).
            params2d: The seed parameters for the 2D fit.
            xparams: The parameters for the x-axis 1D fit.
            yparams: The parameters for the y-axis 1D fit.

        Raises:
            KeyError: If any of the expected keys is missing from
                ``fit_dicts``.
        """
        func_strs = fit_dicts["2d"]["equations"]
        constraints = fit_dicts["2d"]["constraints"]
        params2d = nominal_list(fit_dicts["2d"]["params"])
        xparams = nominal_list(fit_dicts["xsum"]["params"])
        yparams = nominal_list(fit_dicts["ysum"]["params"])
        return func_strs, constraints, params2d, xparams, yparams