# Source code for tomobar.supp.dicts

import numpy as np
from tomobar.astra_wrappers.astra_tools2d import AstraTools2D
from tomobar.astra_wrappers.astra_tools3d import AstraTools3D

from typing import Union
from tomobar.supp.funcs import _data_dims_swapper

# Prefer CuPy when it is installed AND a CUDA-capable GPU is reachable;
# otherwise fall back to NumPy under the same ``xp`` alias so downstream
# code can use ``xp`` transparently.  ``gpu_enabled`` is now bound on every
# path (previously it was left undefined when the GPU was unavailable,
# which would raise NameError at first use).
try:
    import cupy as xp

    try:
        xp.cuda.Device(0).compute_capability  # probe device 0 for a usable GPU
        gpu_enabled = True  # CuPy is installed and GPU is available
    except xp.cuda.runtime.CUDARuntimeError:
        gpu_enabled = False  # CuPy installed but no usable CUDA runtime/device
        import numpy as xp
except ImportError:
    gpu_enabled = False  # CuPy is not installed
    import numpy as xp


def dicts_check(
    self,
    _data_: dict,
    _algorithm_: Union[dict, None] = None,
    _regularisation_: Union[dict, None] = None,
    method_run: str = "FISTA",
) -> tuple:
    """Check the ``_data_``, ``_algorithm_`` and ``_regularisation_`` dictionaries
    and populate missing entries with defaults.

    The dictionaries are needed for iterative methods only. The most versatile
    methods accepting a variety of parameters are FISTA and ADMM.

    Args:
        _data_ (dict): Data dictionary where data-related items must be specified.
        _algorithm_ (dict, optional): Algorithm dictionary. Defaults to None.
        _regularisation_ (dict, optional): Regularisation dictionary, needed only
            for the FISTA and ADMM algorithms. Defaults to None.
        method_run (str): The name of the method to be run. Defaults to "FISTA".

    Keyword Args:
        _data_['projection_norm_data'] (ndarray): Negative log(normalised)
            projection data as a 2D sinogram or a 3D data array.
        _data_['projection_raw_data'] (ndarray): Raw data for PWLS and SWLS
            models (FISTA-related).
        _data_['data_axes_labels_order'] (list, None): The order of the axes
            labels of the input data. Defaults to ["detY", "angles", "detX"].
        _data_['OS_number'] (int): The number of ordered subsets; None or 1 runs
            the classical (full data) algorithm. Defaults to 1.
        _data_['huber_threshold'] (float): Huber data fidelity threshold to
            suppress outliers [KAZ1_2017]_.
        _data_['studentst_threshold'] (float): Student's t data fidelity
            threshold to suppress outliers [KAZ1_2017]_.
        _data_['ringGH_lambda'] (float): Group-Huber data model parameter
            [PM2015]_ to suppress full rings of uniform intensity.
        _data_['ringGH_accelerate'] (float): Group-Huber acceleration factor
            (higher values can lead to divergence). Defaults to 50.
        _data_['beta_SWLS'] (float): Stripe-weighted data model parameter
            [HOA2017]_ for ring artefacts removal. Defaults to 0.1.
        _algorithm_['iterations'] (int): The number of (outer) algorithm iterations.
        _algorithm_['nonnegativity'] (bool): Nonnegativity for the solution.
            Defaults to False.
        _algorithm_['recon_mask_radius'] (float): Circular mask cutoff in the
            reconstructed image. Defaults to 1.0.
        _algorithm_['initialise'] (ndarray): Initialisation for the solution; an
            array of the expected output size.
        _algorithm_['lipschitz_const'] (float): Lipschitz constant for FISTA;
            calculated automatically when not provided.
        _algorithm_['ADMM_rho_const'] (float): Augmented Lagrangian parameter (ADMM).
        _algorithm_['ADMM_relax_par'] (float): Over-relaxation parameter for
            convergence acceleration (ADMM).
        _algorithm_['tolerance'] (float): Tolerance for early termination of the
            outer iterations. Defaults to 0.0.
        _algorithm_['verbose'] (bool): Print iteration numbers and other
            messages. Defaults to False.
        _regularisation_['method'] (str): Regularisation method from the
            CCPi-regularisation toolkit [KAZ2019]_: ROF_TV, FGP_TV, PD_TV,
            SB_TV, LLT_ROF, TGV, NDF, Diff4th, NLTV. With `pypwt` installed,
            append the WAVELETS string to any name (e.g. ROF_TV_WAVELETS) for
            dual regularisation; see `regul_param2` for wavelet smoothing.
        _regularisation_['regul_param'] (float): Main regularisation parameter
            controlling the amount of smoothing. Defaults to 0.001.
        _regularisation_['iterations'] (int): INNER (regularisation) iterations.
            Defaults to 150.
        _regularisation_['device_regulariser'] (str, int): 'cpu', 'gpu', or a
            GPU index (integer) of a specific device.
        _regularisation_['edge_threhsold'] (float): Noise-related threshold for
            NDF and Diff4th regularisers (key spelling kept as consumed downstream).
        _regularisation_['tolerance'] (float): Tolerance to stop inner
            regularisation iterations prematurely.
        _regularisation_['time_marching_step'] (float): Time step for
            gradient-based methods: ROF_TV, LLT_ROF, NDF, Diff4th.
        _regularisation_['regul_param2'] (float): Second regularisation
            parameter, for LLT_ROF or when WAVELETS are used.
        _regularisation_['TGV_alpha1'] (float): TGV 1st-order term parameter.
        _regularisation_['TGV_alpha2'] (float): TGV 2nd-order term parameter.
        _regularisation_['PD_LipschitzConstant'] (float): Primal-Dual
            convergence parameter (PD_TV and TGV specific).
        _regularisation_['NDF_penalty'] (str): NDF penalty type: Huber
            (default), Perona, Tukey.
        _regularisation_['NLTV_H_i'] (ndarray): NLTV weights, i-related indices.
        _regularisation_['NLTV_H_j'] (ndarray): NLTV weights, j-related indices.
        _regularisation_['NLTV_Weights'] (ndarray): NLTV weights array.
        _regularisation_['methodTV'] (int): 0/1 - isotropic/anisotropic TV.

    Returns:
        tuple: Three populated dictionaries (_data_, _algorithm_, _regularisation_).

    Raises:
        NameError: If `_data_` or a mandatory entry in it is missing, or an
            unknown `NDF_penalty` value is given.
    """
    if _data_ is None:
        raise NameError("The data dictionary must be always provided")
    else:
        # -------- dealing with _data_ dictionary ------------
        if _data_.get("projection_norm_data") is None:
            raise NameError("No input 'projection_norm_data' has been provided")
        # projection raw data for PWLS/SWLS type data models
        if _data_.get("projection_raw_data") is None:
            if self.datafidelity in ["PWLS", "SWLS"]:
                raise NameError(
                    "No input 'projection_raw_data' provided for PWLS or SWLS data fidelity"
                )
        if "data_axes_labels_order" not in _data_:
            _data_["data_axes_labels_order"] = None
        if _data_["data_axes_labels_order"] is not None:
            if self.geom == "2D":
                _data_["projection_norm_data"] = _data_dims_swapper(
                    _data_["projection_norm_data"],
                    _data_["data_axes_labels_order"],
                    ["angles", "detX"],
                )
                if self.datafidelity in ["PWLS", "SWLS"]:
                    _data_["projection_raw_data"] = _data_dims_swapper(
                        _data_["projection_raw_data"],
                        _data_["data_axes_labels_order"],
                        ["angles", "detX"],
                    )
            else:
                _data_["projection_norm_data"] = _data_dims_swapper(
                    _data_["projection_norm_data"],
                    _data_["data_axes_labels_order"],
                    ["detY", "angles", "detX"],
                )
                if self.datafidelity in ["PWLS", "SWLS"]:
                    _data_["projection_raw_data"] = _data_dims_swapper(
                        _data_["projection_raw_data"],
                        _data_["data_axes_labels_order"],
                        ["detY", "angles", "detX"],
                    )
            # we need to reset the swap option here as the data already been
            # modified so we don't swap it again in the method
            _data_["data_axes_labels_order"] = None
        if _data_.get("OS_number") is None:
            _data_["OS_number"] = 1  # classical approach (default)
        self.OS_number = _data_["OS_number"]
        if method_run == "FISTA":
            if self.datafidelity == "SWLS":
                if _data_.get("beta_SWLS") is None:
                    # SWLS related parameter (ring supression)
                    _data_["beta_SWLS"] = 0.1 * np.ones(self.Atools.detectors_x)
                else:
                    _data_["beta_SWLS"] = _data_["beta_SWLS"] * np.ones(
                        self.Atools.detectors_x
                    )
            # Huber data model to supress artifacts
            if "huber_threshold" not in _data_:
                _data_["huber_threshold"] = None
            # Students't data model to supress artifacts
            if "studentst_threshold" not in _data_:
                _data_["studentst_threshold"] = None
            # Group-Huber data model to supress full rings of the same intensity
            if "ringGH_lambda" not in _data_:
                _data_["ringGH_lambda"] = None
            # Group-Huber data model acceleration factor (use carefully to avoid divergence)
            if "ringGH_accelerate" not in _data_:
                _data_["ringGH_accelerate"] = 50
    # ---------- dealing with _algorithm_ --------------
    if _algorithm_ is None:
        _algorithm_ = {}
    if method_run in {"SIRT", "CGLS", "power", "ADMM", "Landweber"}:
        _algorithm_["lipschitz_const"] = 0  # bypass Lipshitz const calculation bellow
        if _algorithm_.get("iterations") is None:
            if method_run == "SIRT":
                _algorithm_["iterations"] = 200
            if method_run == "CGLS":
                _algorithm_["iterations"] = 30
            if method_run in {"power", "ADMM"}:
                _algorithm_["iterations"] = 15
            if method_run == "Landweber":
                _algorithm_["iterations"] = 1500
        # NOTE(review): key spelling "tau_step_lanweber" kept as-is, it is
        # part of the dictionary interface consumed elsewhere
        if _algorithm_.get("tau_step_lanweber") is None:
            _algorithm_["tau_step_lanweber"] = 1e-05
    if method_run == "FISTA":
        # default iterations number for FISTA reconstruction algorithm
        if _algorithm_.get("iterations") is None:
            if _data_["OS_number"] > 1:
                _algorithm_["iterations"] = 20  # Ordered - Subsets
            else:
                _algorithm_["iterations"] = 400  # Classical
        if _algorithm_.get("lipschitz_const") is None:
            # if not provided calculate Lipschitz constant automatically
            _algorithm_["lipschitz_const"] = self.powermethod(_data_)
    if method_run == "ADMM":
        # ADMM -algorithm augmented Lagrangian parameter
        if "ADMM_rho_const" not in _algorithm_:
            _algorithm_["ADMM_rho_const"] = 1000.0
        # ADMM over-relaxation parameter to accelerate convergence
        if "ADMM_relax_par" not in _algorithm_:
            _algorithm_["ADMM_relax_par"] = 1.0
    # initialise an algorithm with an array
    if "initialise" not in _algorithm_:
        _algorithm_["initialise"] = None
    # ENABLE or DISABLE the nonnegativity for algorithm
    if "nonnegativity" not in _algorithm_:
        _algorithm_["nonnegativity"] = False
    if _algorithm_["nonnegativity"]:
        self.nonneg_regul = 1  # enable nonnegativity for regularisers
    else:
        self.nonneg_regul = 0  # disable nonnegativity for regularisers
    if "recon_mask_radius" not in _algorithm_:
        _algorithm_["recon_mask_radius"] = 1.0
    # tolerance to stop OUTER algorithm iterations earlier
    if "tolerance" not in _algorithm_:
        _algorithm_["tolerance"] = 0.0
    if "verbose" not in _algorithm_:
        _algorithm_["verbose"] = False
    # ---------- deal with _regularisation_ --------------
    if _regularisation_ is None:
        _regularisation_ = {}
    if bool(_regularisation_) is False:
        _regularisation_["method"] = None
    if method_run in {"FISTA", "ADMM"}:
        # regularisation parameter (main)
        if "regul_param" not in _regularisation_:
            _regularisation_["regul_param"] = 0.001
        # regularisation parameter second (LLT_ROF)
        if "regul_param2" not in _regularisation_:
            _regularisation_["regul_param2"] = 0.001
        # set the number of inner (regularisation) iterations
        if "iterations" not in _regularisation_:
            _regularisation_["iterations"] = 150
        # tolerance to stop inner regularisation iterations prematurely
        if "tolerance" not in _regularisation_:
            _regularisation_["tolerance"] = 0.0
        # time marching step to ensure convergence for gradient based methods:
        # ROF_TV, LLT_ROF, NDF, Diff4th
        if "time_marching_step" not in _regularisation_:
            _regularisation_["time_marching_step"] = 0.005
        # TGV specific parameter for the 1st order term
        if "TGV_alpha1" not in _regularisation_:
            _regularisation_["TGV_alpha1"] = 1.0
        # TGV specific parameter for the 2nd order term
        if "TGV_alpha2" not in _regularisation_:
            _regularisation_["TGV_alpha2"] = 2.0
        # Primal-dual parameter for convergence (TGV specific)
        if "PD_LipschitzConstant" not in _regularisation_:
            _regularisation_["PD_LipschitzConstant"] = 12.0
        # edge (noise) threshold parameter for NDF and DIFF4th models
        if "edge_threhsold" not in _regularisation_:
            _regularisation_["edge_threhsold"] = 0.001
        # NDF specific penalty type: Huber (default), Perona, Tukey
        if "NDF_penalty" not in _regularisation_:
            _regularisation_["NDF_penalty"] = "Huber"
            self.NDF_method = 1
        else:
            if _regularisation_["NDF_penalty"] == "Huber":
                self.NDF_method = 1
            elif _regularisation_["NDF_penalty"] == "Perona":
                self.NDF_method = 2
            elif _regularisation_["NDF_penalty"] == "Tukey":
                self.NDF_method = 3
            else:
                raise NameError("For NDF_penalty choose Huber, Perona or Tukey")
        # NLTV penalty related weights, the array of i-related indices
        if "NLTV_H_i" not in _regularisation_:
            _regularisation_["NLTV_H_i"] = 0
        # NLTV penalty related weights, the array of j-related indices
        if "NLTV_H_j" not in _regularisation_:
            _regularisation_["NLTV_H_j"] = 0
        # NLTV-specific penalty type, the array of Weights
        if "NLTV_Weights" not in _regularisation_:
            _regularisation_["NLTV_Weights"] = 0
        # 0/1 - TV specific isotropic/anisotropic choice
        if "methodTV" not in _regularisation_:
            _regularisation_["methodTV"] = 0
        # choose the type of the device for the regulariser
        if "device_regulariser" not in _regularisation_:
            _regularisation_["device_regulariser"] = "gpu"
    return (_data_, _algorithm_, _regularisation_)
def _reinitialise_atools_OS(self, _data_: dict):
    """Reinitialise the ordered-subsets geometry by overwriting the existing Atools.

    Note:
        Not an ideal thing to do as it can lead to various problems; worth
        considering moving the subsets definition to the class init.

    Args:
        _data_ (dict): data dictionary (its "OS_number" entry is read here)
    """
    old_tools = self.Atools
    # arguments shared by the 2D and 3D constructors, in positional order
    common_args = (
        old_tools.angles_vec,
        old_tools.centre_of_rotation,
        old_tools.recon_size,
        old_tools.processing_arch,
        old_tools.device_index,
        _data_["OS_number"],
    )
    if self.geom == "2D":
        # initiate 2D ASTRA class OS object
        self.Atools = AstraTools2D(old_tools.detectors_x, *common_args)
    else:
        # initiate 3D ASTRA class OS object
        self.Atools = AstraTools3D(
            old_tools.detectors_x, old_tools.detectors_y, *common_args
        )
    return _data_