from typing import Union
import cupy as cp
from tomobar.supp.funcs import _data_dims_swapper
# Validation and default-population helper for the reconstruction parameter dictionaries.
def dicts_check(
    self,
    _data_: dict,
    _algorithm_: Union[dict, None] = None,
    _regularisation_: Union[dict, None] = None,
    method_run: str = "FISTA",
) -> tuple:
    """Validate the `_data_`, `_algorithm_` and `_regularisation_` dictionaries and fill in defaults.

    The dictionaries are needed for iterative methods only. The most versatile
    methods that can accept a variety of different parameters are FISTA and ADMM.

    The dictionaries passed in are modified in place (missing keys are added and
    `_data_['projection_data']` may be replaced by a re-ordered / expanded array).
    The function also sets `self.data_fidelity` and `self.nonneg_regul`.

    Args:
        _data_ (dict): *Data dictionary where data-related items must be specified.*
        _algorithm_ (dict, optional): *Algorithm dictionary. Defaults to None.*
        _regularisation_ (dict, optional): *Regularisation dictionary. Needed only for FISTA and ADMM algorithms. Defaults to None.*
        method_run (str): *The name of the method to be run. Defaults to "FISTA".*

    Keyword Args:
        _data_['projection_data'] (ndarray): Can be either projection data after negative log or raw data given as a 3D CuPy array.
        _data_['data_axes_labels_order'] (list, None). The order of the axes labels for the input data. The default data labels are: ["detY", "angles", "detX"].
        _data_['data_fidelity'] (str). Data fidelity given as 'LS' (Least Squares), 'PWLS' (Penalised Weighted LS), 'KL' (Kullback Leibler). Defaults to 'LS'.

        _algorithm_['iterations'] (int): The number of iterations for the reconstruction algorithm.
        _algorithm_['nonnegativity'] (bool): Enable nonnegativity for the solution. Defaults to False.
        _algorithm_['recon_mask_radius'] (float): Enables the circular mask cutoff in the reconstructed image. Defaults to 1.0.
        _algorithm_['initialise'] (ndarray): Initialisation for the solution. An array of the expected output size must be provided.
        _algorithm_['lipschitz_const'] (float): Lipschitz constant for the FISTA and ADMM algorithms. If not provided, it will be calculated for each method call.
        _algorithm_['ADMM_rho_const'] (float): Augmented Lagrangian parameter for the ADMM algorithm. Defaults to 1.0.
        _algorithm_['ADMM_relax_par'] (float): Over relaxation parameter for the convergence acceleration of the ADMM algorithm. Values within 1.5-1.8 range work well. Defaults to 1.6.
        _algorithm_['tau_step_lanweber'] (float): Presumably the Landweber step size (name-based assumption, to be confirmed). Defaults to 1e-05.
        _algorithm_['tolerance'] (float): Tolerance to terminate reconstruction algorithm iterations earlier. Defaults to 0.0.
        _algorithm_['verbose'] (bool): Switch on printing of iterations number and other messages. Defaults to False.

        _regularisation_['method'] (str): Select the regularisation method for noise suppression. The supported methods are: ROF_TV, PD_TV. Defaults to None.
        _regularisation_['regul_param'] (float): The main regularisation parameter for regularisers to control the amount of smoothing. Defaults to 0.001.
        _regularisation_['iterations'] (int): The number of iterations for regularisers (INNER iterations). Defaults to 150.
        _regularisation_['tolerance'] (float): Tolerance to stop the inner regularisation iterations earlier. Defaults to 0.0.
        _regularisation_['device_regulariser'] (int): A GPU device index to perform operation on. Defaults to 0.
        _regularisation_['time_marching_step'] (float): Time step parameter for convergence of gradient-based methods: ROF_TV. Defaults to 0.005.
        _regularisation_['PD_LipschitzConstant'] (float): The Primal-Dual (PD) penalty related parameter for convergence (PD_TV specific). Defaults to 12.0.
        _regularisation_['methodTV'] (int): 0/1 - TV specific isotropic/anisotropic choice. Defaults to 0.

    Raises:
        NameError: If `_data_` is None, if 'projection_data' is missing, or if
            ordered subsets are requested for SIRT, CGLS or Landweber.
        ValueError: If `_data_['data_fidelity']` or `_algorithm_['nonnegativity']`
            holds an unsupported value.

    Returns:
        tuple: A tuple with three populated dictionaries (_data_, _algorithm_, _regularisation_).
    """
    _data_ = _populate_data_dict(self, _data_)

    # OS_number=None is read as "no ordered subsets", which is what the error
    # message below asks the user to set; comparing None with 1 would raise.
    ordered_subsets = self.OS_number is not None and self.OS_number > 1
    if ordered_subsets and method_run in {"SIRT", "CGLS", "Landweber"}:
        raise NameError(
            "There is no ordered-subsets implementation for this reconstruction method, please set OS_number=None"
        )

    _algorithm_ = _populate_algorithm_dict(
        self, _algorithm_, method_run, ordered_subsets
    )
    _regularisation_ = _populate_regularisation_dict(_regularisation_, method_run)
    return (_data_, _algorithm_, _regularisation_)


def _populate_data_dict(self, _data_: dict) -> dict:
    """Validate `_data_`, bring the projections to the expected axes order and set `self.data_fidelity`."""
    if _data_ is None:
        raise NameError("The data dictionary must be always provided")
    if _data_.get("projection_data") is None:
        raise NameError("'projection_data' needs to be provided")

    data2dinput = _data_["projection_data"].ndim == 2

    if _data_.get("data_axes_labels_order") is not None:
        if data2dinput:
            correct_labels_order = ["angles", "detX"]
        else:
            correct_labels_order = ["detY", "angles", "detX"]
        _data_["projection_data"] = _data_dims_swapper(
            _data_["projection_data"],
            _data_["data_axes_labels_order"],
            correct_labels_order,
        )
    # The swap option is reset because the data has already been re-ordered,
    # so the method itself must not swap it again.
    _data_["data_axes_labels_order"] = None

    if data2dinput:
        # a 2D input is given a leading axis of length one
        _data_["projection_data"] = cp.expand_dims(_data_["projection_data"], axis=0)

    if _data_.get("data_fidelity") is None:
        _data_["data_fidelity"] = "LS"
    if _data_["data_fidelity"] not in {"LS", "PWLS", "KL"}:
        raise ValueError(
            "_data_['data_fidelity'] should be provided as 'LS', 'PWLS', 'KL'."
        )
    self.data_fidelity = _data_["data_fidelity"]
    return _data_


def _populate_algorithm_dict(
    self, _algorithm_: Union[dict, None], method_run: str, ordered_subsets: bool
) -> dict:
    """Fill in the defaults of `_algorithm_` for `method_run` and set `self.nonneg_regul`."""
    if _algorithm_ is None:
        _algorithm_ = {}

    if method_run in {"SIRT", "CGLS", "power", "Landweber", "OSEM"}:
        _algorithm_["lipschitz_const"] = 0  # bypass Lipschitz const calculation
        # OSEM is absent from this table: its default depends on ordered subsets
        fixed_iterations = {"SIRT": 200, "CGLS": 30, "power": 15, "Landweber": 1500}
        if _algorithm_.get("iterations") is None and method_run in fixed_iterations:
            _algorithm_["iterations"] = fixed_iterations[method_run]
        if _algorithm_.get("tau_step_lanweber") is None:
            _algorithm_["tau_step_lanweber"] = 1e-05

    # default iterations as (ordered-subsets, classical) pairs
    os_dependent_iterations = {"OSEM": (15, 300), "FISTA": (20, 400), "ADMM": (10, 400)}
    if method_run in os_dependent_iterations and _algorithm_.get("iterations") is None:
        with_subsets, classical = os_dependent_iterations[method_run]
        _algorithm_["iterations"] = with_subsets if ordered_subsets else classical

    if method_run == "ADMM":
        # ADMM augmented Lagrangian parameter
        _algorithm_.setdefault("ADMM_rho_const", 1.0)
        # ADMM over-relaxation parameter to accelerate convergence
        _algorithm_.setdefault("ADMM_relax_par", 1.6)

    # initialise an algorithm with an array
    _algorithm_.setdefault("initialise", None)

    # ENABLE or DISABLE the nonnegativity for algorithm
    _algorithm_.setdefault("nonnegativity", False)
    if _algorithm_["nonnegativity"] not in [True, False]:
        raise ValueError("_algorithm_['nonnegativity'] should be set to True or False.")
    # the regularisers take the switch as an integer flag
    self.nonneg_regul = 1 if _algorithm_["nonnegativity"] else 0

    _algorithm_.setdefault("recon_mask_radius", 1.0)
    # tolerance to stop OUTER algorithm iterations earlier
    _algorithm_.setdefault("tolerance", 0.0)
    _algorithm_.setdefault("verbose", False)
    return _algorithm_


def _populate_regularisation_dict(
    _regularisation_: Union[dict, None], method_run: str
) -> dict:
    """Fill in the defaults of `_regularisation_`; 'method' is always present afterwards."""
    if _regularisation_ is None:
        _regularisation_ = {}
    # Defaulted even for a non-empty dictionary so that 'method' can always be read.
    _regularisation_.setdefault("method", None)

    if method_run in {"FISTA", "ADMM", "OSEM"}:
        regulariser_defaults = (
            ("regul_param", 0.001),  # regularisation parameter (main)
            ("iterations", 150),  # number of inner (regularisation) iterations
            ("tolerance", 0.0),  # stops inner regularisation iterations earlier
            ("time_marching_step", 0.005),  # for gradient-based methods: ROF_TV
            ("PD_LipschitzConstant", 12.0),  # Primal-Dual convergence parameter
            ("methodTV", 0),  # 0/1 - isotropic/anisotropic TV
            ("device_regulariser", 0),  # GPU device index for the regulariser
        )
        for key, default in regulariser_defaults:
            _regularisation_.setdefault(key, default)
    return _regularisation_