# Source code for tomobar.methodsIR (extracted from the rendered documentation page)
"""Reconstruction class for regularised iterative methods (2D/3D).
* :func:`RecToolsIR.FISTA` FISTA - iterative regularised algorithm [BT2009]_, [Xu2016]_.
* :func:`RecToolsIR.ADMM` ADMM iterative regularised algorithm [Boyd2011]_.
* :func:`RecToolsIR.SIRT` and :func:`RecToolsIR.CGLS` algorithms are wrapped directly from the ASTRA package.
"""
import numpy as xp
from numpy import linalg
from typing import Union
try:
import cupy as cp
cupy_imported = True
except ImportError:
import numpy as xp
cupy_imported = False
try:
import astra
except ImportError:
print("____! Astra-toolbox package is missing, please install !____")
from tomobar.supp.dicts import dicts_check, _reinitialise_atools_OS
from tomobar.supp.suppTools import circ_mask
from tomobar.supp.funcs import _data_dims_swapper, _parse_device_argument
from tomobar.regularisers import prox_regul
from tomobar.astra_wrappers.astra_tools2d import AstraTools2D
from tomobar.astra_wrappers.astra_tools3d import AstraTools3D
# [docs] (Sphinx viewcode anchor, HTML-extraction artifact)
class RecToolsIR:
"""Iterative reconstruction algorithms (FISTA and ADMM) using ASTRA toolbox and CCPi-RGL toolkit.
Parameters for reconstruction algorithms should be provided in three dictionaries:
:data:`_data_`, :data:`_algorithm_`, and :data:`_regularisation_`. See :mod:`tomobar.supp.dicts`
function of ToMoBAR's :ref:`ref_api` for all parameters explained.
Args:
DetectorsDimH (int): Horizontal detector dimension.
DetectorsDimV (int): Vertical detector dimension for 3D case, 0 or None for 2D case.
CenterRotOffset (float): The Centre of Rotation (CoR) scalar or a vector for each angle.
AnglesVec (np.ndarray): Vector of projection angles in radians.
ObjSize (int): Reconstructed object dimensions (a scalar).
datafidelity (str): Data fidelity, choose from LS, KL, PWLS or SWLS.
device_projector (str, int): 'cpu' or 'gpu' device OR provide a GPU index (integer) of a specific GPU device.
cupyrun (bool, optional): instantiate CuPy modules.
"""
def __init__(
    self,
    DetectorsDimH,  # horizontal detector dimension
    DetectorsDimV,  # vertical detector dimension (3D), 0 or None for 2D
    CenterRotOffset,  # centre-of-rotation: a scalar or a per-angle vector
    AnglesVec,  # projection angles in radians
    ObjSize,  # reconstructed object size (scalar)
    datafidelity="LS",  # data fidelity: LS, KL, PWLS or SWLS
    device_projector="gpu",  # 'cpu'/'gpu' or an integer GPU index
    cupyrun=False,
):
    """Validate inputs, resolve the compute device and build the
    matching 2D or 3D ASTRA projector toolset."""
    # property setters below perform the validation
    self.datafidelity = datafidelity
    self.cupyrun = cupyrun
    # split the mixed device argument into a device type and a GPU index
    device_projector, GPUdevice_index = _parse_device_argument(device_projector)
    # a missing/zero vertical dimension selects the 2D (single-slice) geometry
    is_2d = DetectorsDimV == 0 or DetectorsDimV is None
    self.geom = "2D" if is_2d else "3D"
    if is_2d:
        self.Atools = AstraTools2D(
            DetectorsDimH,
            AnglesVec,
            CenterRotOffset,
            ObjSize,
            device_projector,
            GPUdevice_index,
        )
    else:
        self.Atools = AstraTools3D(
            DetectorsDimH,
            DetectorsDimV,
            AnglesVec,
            CenterRotOffset,
            ObjSize,
            device_projector,
            GPUdevice_index,
        )
@property
def datafidelity(self) -> str:
    """Data fidelity model name: one of "LS", "PWLS", "SWLS" or "KL".

    Note: the return annotation previously said ``int``, but the stored
    value is always one of the string names validated by the setter.
    """
    return self._datafidelity

@datafidelity.setter
def datafidelity(self, datafidelity_val):
    # reject unsupported fidelity models early, at assignment time
    if datafidelity_val not in ["LS", "PWLS", "SWLS", "KL"]:
        raise ValueError("Unknown data fidelity type, select: LS, PWLS, SWLS or KL")
    self._datafidelity = datafidelity_val
@property
def cupyrun(self) -> bool:
    """Whether CuPy-enabled routines are requested for this instance.

    Note: the return annotation previously said ``int``; the stored value
    is the boolean flag passed to the constructor.
    """
    return self._cupyrun

@cupyrun.setter
def cupyrun(self, cupyrun_val):
    self._cupyrun = cupyrun_val
# [docs] (Sphinx viewcode anchor, HTML-extraction artifact)
def SIRT(self, _data_: dict, _algorithm_: Union[dict, None] = None) -> xp.ndarray:
    """Simultaneous Iterations Reconstruction Technique, wrapped from the ASTRA toolbox.

    Args:
        _data_ (dict): Data dictionary, where input data is provided.
        _algorithm_ (dict, optional): Algorithm dictionary where algorithm parameters are provided.

    Returns:
        xp.ndarray: SIRT-reconstructed numpy array
    """
    # validate the dictionaries and fill in defaults for a SIRT run
    checked = dicts_check(self, _data_, _algorithm_, method_run="SIRT")
    data_checked, algorithm_checked = checked[0], checked[1]
    projections = data_checked["projection_norm_data"]
    n_iterations = algorithm_checked["iterations"]
    # delegate the actual iterations to the ASTRA wrapper
    return self.Atools._sirt(projections, n_iterations)
# [docs] (Sphinx viewcode anchor, HTML-extraction artifact)
def CGLS(self, _data_: dict, _algorithm_: Union[dict, None] = None) -> xp.ndarray:
    """Conjugate Gradient Least Squares, wrapped from the ASTRA toolbox.

    Args:
        _data_ (dict): Data dictionary, where input data is provided.
        _algorithm_ (dict, optional): Algorithm dictionary where algorithm parameters are provided.

    Returns:
        xp.ndarray: CGLS-reconstructed numpy array
    """
    # validate the dictionaries and fill in defaults for a CGLS run
    checked = dicts_check(self, _data_, _algorithm_, method_run="CGLS")
    data_checked, algorithm_checked = checked[0], checked[1]
    # delegate the actual iterations to the ASTRA wrapper
    return self.Atools._cgls(
        data_checked["projection_norm_data"],
        algorithm_checked["iterations"],
    )
# [docs] (Sphinx viewcode anchor, HTML-extraction artifact)
def powermethod(self, _data_: dict) -> float:
    """Power iteration algorithm to calculate the eigenvalue of the operator (projection matrix).

    projection_raw_data is required for PWLS and SWLS fidelities, otherwise it is ignored.

    Args:
        _data_ (dict): Data dictionary, where input data is provided.

    Returns:
        float: the Lipschitz constant
    """
    # Bind xp locally to the backend that matches the run mode. xp must be
    # bound on every path: falling back to NumPy when CuPy is requested but
    # unavailable avoids an UnboundLocalError later in this method.
    if self.cupyrun and cupy_imported:
        import cupy as xp
    else:
        import numpy as xp
    # route calls through the CuPy-enabled projector methods?
    use_cupy = cupy_imported and self.cupyrun

    if "data_axes_labels_order" not in _data_:
        _data_["data_axes_labels_order"] = None
    if (
        self.datafidelity in ["PWLS", "SWLS"]
        and "projection_raw_data" not in _data_
    ):
        raise ValueError("Please provide projection_raw_data for this model")
    if self.datafidelity in ["PWLS", "SWLS"]:
        sqweight = _data_["projection_raw_data"]

    # bring the data into the axes order the projectors expect
    if _data_["data_axes_labels_order"] is not None:
        if self.geom == "2D":
            labels = ["angles", "detX"]
        else:
            labels = ["detY", "angles", "detX"]
        _data_["projection_norm_data"] = _data_dims_swapper(
            _data_["projection_norm_data"],
            _data_["data_axes_labels_order"],
            labels,
        )
        if self.datafidelity in ["PWLS", "SWLS"]:
            _data_["projection_raw_data"] = _data_dims_swapper(
                _data_["projection_raw_data"],
                _data_["data_axes_labels_order"],
                labels,
            )
            sqweight = _data_["projection_raw_data"]
    # reset the swap option: the data has been modified already, so it must
    # not be swapped again inside the reconstruction methods
    _data_["data_axes_labels_order"] = None

    if _data_.get("OS_number") is None:
        _data_["OS_number"] = 1  # the classical approach (default)
    else:
        _data_ = _reinitialise_atools_OS(self, _data_)

    power_iterations = 15
    s = 1.0
    # geom_size of the volume geometry yields the reconstruction array shape
    vol_size = astra.geom_size(self.Atools.vol_geom)
    if use_cupy:
        x1 = cp.random.randn(*vol_size, dtype=cp.float32)
    else:
        x1 = xp.float32(xp.random.randn(*vol_size))

    if _data_["OS_number"] == 1:
        # classical (non-OS) approach
        y = self.Atools._forwprojCuPy(x1) if use_cupy else self.Atools._forwproj(x1)
        if self.datafidelity == "PWLS":
            y = xp.multiply(sqweight, y)
        for _ in range(power_iterations):
            x1 = self.Atools._backprojCuPy(y) if use_cupy else self.Atools._backproj(y)
            # xp is CuPy when use_cupy is True, so this covers both backends
            s = xp.linalg.norm(xp.ravel(x1), axis=0)
            x1 = x1 / s
            y = self.Atools._forwprojCuPy(x1) if use_cupy else self.Atools._forwproj(x1)
            if self.datafidelity == "PWLS":
                y = xp.multiply(sqweight, y)
    else:
        # ordered-subsets (OS) approach: iterate using the zeroth subset only
        y = (
            self.Atools._forwprojOSCuPy(x1, 0)
            if use_cupy
            else self.Atools._forwprojOS(x1, 0)
        )
        if self.datafidelity == "PWLS":
            if self.geom == "2D":
                y = xp.multiply(sqweight[self.Atools.newInd_Vec[0, :], :], y)
            else:
                y = xp.multiply(sqweight[:, self.Atools.newInd_Vec[0, :], :], y)
        for _ in range(power_iterations):
            x1 = (
                self.Atools._backprojOSCuPy(y, 0)
                if use_cupy
                else self.Atools._backprojOS(y, 0)
            )
            s = xp.linalg.norm(xp.ravel(x1), axis=0)
            x1 = x1 / s
            y = (
                self.Atools._forwprojOSCuPy(x1, 0)
                if use_cupy
                else self.Atools._forwprojOS(x1, 0)
            )
            if self.datafidelity == "PWLS":
                if self.geom == "2D":
                    y = xp.multiply(sqweight[self.Atools.newInd_Vec[0, :], :], y)
                else:
                    y = xp.multiply(sqweight[:, self.Atools.newInd_Vec[0, :], :], y)
    return s
# [docs] (Sphinx viewcode anchor, HTML-extraction artifact)
def FISTA(
    self,
    _data_: dict,
    _algorithm_: Union[dict, None] = None,
    _regularisation_: Union[dict, None] = None,
) -> xp.ndarray:
    """A Fast Iterative Shrinkage-Thresholding Algorithm with various types of regularisation and
    data fidelity terms provided in three dictionaries.

    See :mod:`tomobar.supp.dicts` for all parameters to the dictionaries below.

    Args:
        _data_ (dict): Data dictionary, where input data is provided.
        _algorithm_ (dict, optional): Algorithm dictionary where algorithm parameters are provided.
        _regularisation_ (dict, optional): Regularisation dictionary.

    Returns:
        np.ndarray: FISTA-reconstructed numpy array
    """
    # Rebind xp locally: NumPy for a regular run, CuPy for a CuPy-enabled run.
    # NOTE(review): if cupyrun is True while CuPy is not installed, xp is left
    # unbound here and the method would fail with UnboundLocalError — confirm
    # whether that configuration is meant to be rejected earlier.
    if not self.cupyrun:
        import numpy as xp
    else:
        if cupy_imported:
            import cupy as xp
    ######################################################################
    # parameters check and initialisation
    (_data_upd_, _algorithm_upd_, _regularisation_upd_) = dicts_check(
        self, _data_, _algorithm_, _regularisation_, method_run="FISTA"
    )
    if _data_upd_["OS_number"] > 1:
        # ordered subsets requested: re-initialise the ASTRA subset projectors
        _data_upd_ = _reinitialise_atools_OS(self, _data_upd_)
    ######################################################################
    L_const_inv = (
        1.0 / _algorithm_upd_["lipschitz_const"]
    )  # inverted Lipschitz constant
    if self.geom == "2D":
        # 2D reconstruction
        # initialise the solution
        if xp.size(_algorithm_upd_["initialise"]) == self.Atools.recon_size**2:
            # the object has been initialised with an array
            X = _algorithm_upd_["initialise"]
        else:
            X = xp.zeros(
                (self.Atools.recon_size, self.Atools.recon_size), "float32"
            )  # initialise with zeros
        r = xp.zeros(
            (self.Atools.detectors_x, 1), "float32"
        )  # 1D array of sparse "ring" variables (GH)
    if self.geom == "3D":
        # initialise the solution
        if xp.size(_algorithm_upd_["initialise"]) == self.Atools.recon_size**3:
            # the object has been initialised with an array
            X = _algorithm_upd_["initialise"]
        else:
            X = xp.zeros(
                (
                    self.Atools.detectors_y,
                    self.Atools.recon_size,
                    self.Atools.recon_size,
                ),
                "float32",
            )  # initialise with zeros
        r = xp.zeros(
            (self.Atools.detectors_y, self.Atools.detectors_x), "float32"
        )  # 2D array of sparse "ring" variables (GH)
    info_vec = (0, 1)  # (inner regulariser iterations, status) reported by prox_regul
    # ****************************************************************************#
    # FISTA (model-based modification) algorithm begins here:
    t = 1.0  # FISTA momentum variable
    denomN = 1.0 / xp.size(X)  # normaliser for the stopping-criterion norm
    X_t = xp.copy(X)
    r_x = r.copy()
    # Outer FISTA iterations
    for iter_no in range(_algorithm_upd_["iterations"]):
        r_old = r
        # Do GH fidelity pre-calculations using the full projections dataset for OS version
        if (
            (_data_upd_["OS_number"] != 1)
            and (_data_upd_["ringGH_lambda"] is not None)
            and (iter_no > 0)
        ):
            if self.geom == "2D":
                vec = xp.zeros((self.Atools.detectors_x))
            else:
                vec = xp.zeros((self.Atools.detectors_y, self.Atools.detectors_x))
            for sub_ind in range(_data_upd_["OS_number"]):
                # select a specific set of indeces for the subset (OS)
                indVec = self.Atools.newInd_Vec[sub_ind, :]
                if indVec[self.Atools.NumbProjBins - 1] == 0:
                    indVec = indVec[:-1]  # shrink vector size
                if self.geom == "2D":
                    res = (
                        self.Atools._forwprojOS(X_t, sub_ind)
                        - _data_upd_["projection_norm_data"][indVec, :]
                    )
                    res[:, 0:None] = (
                        res[:, 0:None] + _data_upd_["ringGH_accelerate"] * r_x[:, 0]
                    )
                    vec = vec + (1.0 / (_data_upd_["OS_number"])) * res.sum(axis=0)
                else:
                    res = (
                        self.Atools._forwprojOS(X_t, sub_ind)
                        - _data_upd_["projection_norm_data"][:, indVec, :]
                    )
                    for ang_index in range(len(indVec)):
                        res[:, ang_index, :] = (
                            res[:, ang_index, :]
                            + _data_upd_["ringGH_accelerate"] * r_x
                        )
                    vec = res.sum(axis=1)
            # gradient step on the ring variables
            if self.geom == "2D":
                r[:, 0] = r_x[:, 0] - xp.multiply(L_const_inv, vec)
            else:
                r = r_x - xp.multiply(L_const_inv, vec)
        # loop over subsets (OS)
        for sub_ind in range(_data_upd_["OS_number"]):
            X_old = X
            t_old = t
            if _data_upd_["OS_number"] > 1:
                # select a specific set of indeces for the subset (OS)
                indVec = self.Atools.newInd_Vec[sub_ind, :]
                if indVec[self.Atools.NumbProjBins - 1] == 0:
                    indVec = indVec[:-1]  # shrink vector size
                # OS-reduced residuals
                if self.geom == "2D":
                    if self.datafidelity == "LS":
                        # 2D Least-squares (LS) data fidelity - OS (linear)
                        res = (
                            self.Atools._forwprojOS(X_t, sub_ind)
                            - _data_upd_["projection_norm_data"][indVec, :]
                        )
                    if self.datafidelity == "PWLS":
                        # 2D Penalised Weighted Least-squares - OS data fidelity (approximately linear)
                        res = xp.multiply(
                            _data_upd_["projection_raw_data"][indVec, :],
                            (
                                self.Atools._forwprojOS(X_t, sub_ind)
                                - _data_upd_["projection_norm_data"][indVec, :]
                            ),
                        )
                    if self.datafidelity == "SWLS":
                        # 2D Stripe-Weighted Least-squares - OS data fidelity (helps to minimise stripe artifacts)
                        res = (
                            self.Atools._forwprojOS(X_t, sub_ind)
                            - _data_upd_["projection_norm_data"][indVec, :]
                        )
                        for det_index in range(self.Atools.detectors_x):
                            wk = _data_upd_["projection_raw_data"][
                                indVec, det_index
                            ]
                            res[:, det_index] = (
                                xp.multiply(wk, res[:, det_index])
                                - 1.0
                                / (xp.sum(wk) + _data_upd_["beta_SWLS"][det_index])
                                * (wk.dot(res[:, det_index]))
                                * wk
                            )
                    if self.datafidelity == "KL":
                        # 2D Kullback-Leibler (KL) data fidelity - OS
                        tmp = self.Atools._forwprojOS(X_t, sub_ind)
                        res = xp.divide(
                            tmp - _data_upd_["projection_norm_data"][indVec, :],
                            tmp + 1.0,
                        )
                    # ring removal part for Group-Huber (GH) fidelity (2D)
                    if (_data_upd_["ringGH_lambda"] is not None) and (iter_no > 0):
                        res[:, 0:None] = (
                            res[:, 0:None]
                            + _data_upd_["ringGH_accelerate"] * r_x[:, 0]
                        )
                else:  # 3D
                    if self.datafidelity == "LS":
                        # 3D Least-squares (LS) data fidelity - OS (linear)
                        res = (
                            self.Atools._forwprojOS(X_t, sub_ind)
                            - _data_upd_["projection_norm_data"][:, indVec, :]
                        )
                    if self.datafidelity == "PWLS":
                        # 3D Penalised Weighted Least-squares - OS data fidelity (approximately linear)
                        res = xp.multiply(
                            _data_upd_["projection_raw_data"][:, indVec, :],
                            (
                                self.Atools._forwprojOS(X_t, sub_ind)
                                - _data_upd_["projection_norm_data"][:, indVec, :]
                            ),
                        )
                    if self.datafidelity == "SWLS":
                        # 3D Stripe-Weighted Least-squares - OS data fidelity (helps to minimise stripe artifacts)
                        res = (
                            self.Atools._forwprojOS(X_t, sub_ind)
                            - _data_upd_["projection_norm_data"][:, indVec, :]
                        )
                        for detVert_index in range(self.Atools.detectors_y):
                            for detHorz_index in range(self.Atools.detectors_x):
                                wk = _data_upd_["projection_raw_data"][
                                    detVert_index, indVec, detHorz_index
                                ]
                                res[detVert_index, :, detHorz_index] = (
                                    xp.multiply(
                                        wk, res[detVert_index, :, detHorz_index]
                                    )
                                    - 1.0
                                    / (
                                        xp.sum(wk)
                                        + _data_upd_["beta_SWLS"][detHorz_index]
                                    )
                                    * (wk.dot(res[detVert_index, :, detHorz_index]))
                                    * wk
                                )
                    if self.datafidelity == "KL":
                        # 3D Kullback-Leibler (KL) data fidelity - OS
                        tmp = self.Atools._forwprojOS(X_t, sub_ind)
                        res = xp.divide(
                            tmp - _data_upd_["projection_norm_data"][:, indVec, :],
                            tmp + 1.0,
                        )
                    # GH - fidelity part (3D)
                    if (_data_upd_["ringGH_lambda"] is not None) and (iter_no > 0):
                        for ang_index in range(len(indVec)):
                            res[:, ang_index, :] = (
                                res[:, ang_index, :]
                                + _data_upd_["ringGH_accelerate"] * r_x
                            )
            else:  # CLASSICAL all-data approach
                if self.datafidelity == "LS":
                    # full residual for LS fidelity
                    res = (
                        self.Atools._forwproj(X_t)
                        - _data_upd_["projection_norm_data"]
                    )
                if self.datafidelity == "PWLS":
                    # full gradient for the PWLS fidelity
                    res = xp.multiply(
                        _data_upd_["projection_raw_data"],
                        (
                            self.Atools._forwproj(X_t)
                            - _data_upd_["projection_norm_data"]
                        ),
                    )
                if self.datafidelity == "KL":
                    # Kullback-Leibler (KL) data fidelity
                    tmp = self.Atools._forwproj(X_t)
                    res = xp.divide(
                        tmp - _data_upd_["projection_norm_data"], tmp + 1.0
                    )
                # Group-Huber ring-removal coupling and ring-variable update
                if (_data_upd_["ringGH_lambda"] is not None) and (iter_no > 0):
                    if self.geom == "2D":
                        res[0:None, :] = (
                            res[0:None, :]
                            + _data_upd_["ringGH_accelerate"] * r_x[:, 0]
                        )
                        vec = res.sum(axis=0)
                        r[:, 0] = r_x[:, 0] - xp.multiply(L_const_inv, vec)
                    else:  # 3D case
                        for ang_index in range(len(self.Atools.angles_vec)):
                            res[:, ang_index, :] = (
                                res[:, ang_index, :]
                                + _data_upd_["ringGH_accelerate"] * r_x
                            )
                        vec = res.sum(axis=1)
                        r = r_x - xp.multiply(L_const_inv, vec)
                if self.datafidelity == "SWLS":
                    # full residual for the Stripe-Weighted Least-squares fidelity
                    res = (
                        self.Atools._forwproj(X_t)
                        - _data_upd_["projection_norm_data"]
                    )
                    if self.geom == "2D":
                        for det_index in range(self.Atools.detectors_x):
                            wk = _data_upd_["projection_raw_data"][:, det_index]
                            res[:, det_index] = (
                                xp.multiply(wk, res[:, det_index])
                                - 1.0
                                / (xp.sum(wk) + _data_upd_["beta_SWLS"][det_index])
                                * (wk.dot(res[:, det_index]))
                                * wk
                            )
                    else:  # 3D case
                        for detVert_index in range(self.Atools.detectors_y):
                            for detHorz_index in range(self.Atools.detectors_x):
                                wk = _data_upd_["projection_raw_data"][
                                    detVert_index, :, detHorz_index
                                ]
                                res[detVert_index, :, detHorz_index] = (
                                    xp.multiply(
                                        wk, res[detVert_index, :, detHorz_index]
                                    )
                                    - 1.0
                                    / (
                                        xp.sum(wk)
                                        + _data_upd_["beta_SWLS"][detHorz_index]
                                    )
                                    * (wk.dot(res[detVert_index, :, detHorz_index]))
                                    * wk
                                )
            if _data_upd_["huber_threshold"] is not None:
                # apply Huber penalty: down-weight residuals above the threshold
                multHuber = xp.ones(xp.shape(res))
                multHuber[
                    (xp.where(xp.abs(res) > _data_upd_["huber_threshold"]))
                ] = xp.divide(
                    _data_upd_["huber_threshold"],
                    xp.abs(
                        res[(xp.where(xp.abs(res) > _data_upd_["huber_threshold"]))]
                    ),
                )
                if _data_upd_["OS_number"] != 1:
                    # OS-Huber-gradient
                    grad_fidelity = self.Atools._backprojOS(
                        xp.multiply(multHuber, res), sub_ind
                    )
                else:
                    # full Huber gradient
                    grad_fidelity = self.Atools._backproj(
                        xp.multiply(multHuber, res)
                    )
            elif _data_upd_["studentst_threshold"] is not None:
                # apply Students't penalty
                multStudent = xp.ones(xp.shape(res))
                multStudent = xp.divide(
                    2.0, _data_upd_["studentst_threshold"] ** 2 + res**2
                )
                if _data_upd_["OS_number"] != 1:
                    # OS-Students't-gradient
                    grad_fidelity = self.Atools._backprojOS(
                        xp.multiply(multStudent, res), sub_ind
                    )
                else:
                    # full Students't gradient
                    grad_fidelity = self.Atools._backproj(
                        xp.multiply(multStudent, res)
                    )
            else:
                if _data_upd_["OS_number"] != 1:
                    # OS reduced gradient
                    grad_fidelity = self.Atools._backprojOS(res, sub_ind)
                else:
                    # full gradient
                    grad_fidelity = self.Atools._backproj(res)
            # gradient step
            X = X_t - L_const_inv * grad_fidelity
            if _algorithm_upd_["nonnegativity"] == "ENABLE":
                X[X < 0.0] = 0.0
            if _algorithm_upd_["recon_mask_radius"] is not None:
                X = circ_mask(
                    X, _algorithm_upd_["recon_mask_radius"]
                )  # applying a circular mask
            if _regularisation_upd_["method"] is not None:
                ##### The proximal operator of the chosen regulariser #####
                (X, info_vec) = prox_regul(self, X, _regularisation_upd_)
                ###########################################################
            # updating t variable
            t = (1.0 + xp.sqrt(1.0 + 4.0 * t**2)) * 0.5
            X_t = X + ((t_old - 1.0) / t) * (X - X_old)  # updating X
        if (_data_upd_["ringGH_lambda"] is not None) and (iter_no > 0):
            r = xp.maximum(
                (xp.abs(r) - _data_upd_["ringGH_lambda"]), 0.0
            ) * xp.sign(
                r
            )  # soft-thresholding operator for ring vector
            r_x = r + ((t_old - 1.0) / t) * (r - r_old)  # updating r
        if _algorithm_upd_["verbose"]:
            if xp.mod(iter_no, (round)(_algorithm_upd_["iterations"] / 5) + 1) == 0:
                print(
                    "FISTA iteration (",
                    iter_no + 1,
                    ") using",
                    _regularisation_upd_["method"],
                    "regularisation for (",
                    (int)(info_vec[0]),
                    ") iterations",
                )
            if iter_no == _algorithm_upd_["iterations"] - 1:
                print("FISTA stopped at iteration (", iter_no + 1, ")")
        # stopping criteria (checked only after a reasonable number of iterations)
        if ((iter_no > 10) and (_data_upd_["OS_number"] > 1)) or (
            (iter_no > 150) and (_data_upd_["OS_number"] == 1)
        ):
            nrm = linalg.norm(X - X_old) * denomN
            if nrm < _algorithm_upd_["tolerance"]:
                if _algorithm_upd_["verbose"]:
                    print("FISTA stopped at iteration (", iter_no + 1, ")")
                break
    return X
# # *****************************FISTA ends here*********************************#
# **********************************ADMM***************************************#
# [docs] (Sphinx viewcode anchor, HTML-extraction artifact)
def ADMM(
    self,
    _data_: dict,
    _algorithm_: Union[dict, None] = None,
    _regularisation_: Union[dict, None] = None,
) -> xp.ndarray:
    """Alternating Directions Method of Multipliers with various types of regularisation and
    data fidelity terms provided in three dictionaries, see :mod:`tomobar.supp.dicts`.

    Args:
        _data_ (dict): Data dictionary, where input data is provided.
        _algorithm_ (dict, optional): Algorithm dictionary where algorithm parameters are provided.
        _regularisation_ (dict, optional): Regularisation dictionary.

    Returns:
        xp.ndarray: ADMM-reconstructed numpy array
    """
    try:
        import scipy.sparse.linalg
    except ImportError:
        print(
            "____! Scipy toolbox package is missing, please install for ADMM !____"
        )
    # ADMM relies on SciPy's GMRES and therefore always operates on NumPy
    # arrays. Binding xp unconditionally also fixes the UnboundLocalError
    # that the previous "if not self.cupyrun" conditional import raised
    # whenever cupyrun was enabled.
    import numpy as xp

    ######################################################################
    # parameters check and initialisation
    (_data_upd_, _algorithm_upd_, _regularisation_upd_) = dicts_check(
        self, _data_, _algorithm_, _regularisation_, method_run="ADMM"
    )
    ######################################################################

    def ADMM_Ax(x):
        # normal-equations operator: A^T A x + rho * x
        data_upd = self.Atools.A_optomo(x)
        x_temp = self.Atools.A_optomo.transposeOpTomo(data_upd)
        x_upd = x_temp + _algorithm_upd_["ADMM_rho_const"] * x
        return x_upd

    def ADMM_Atb(b):
        # adjoint of the projection operator (back-projection)
        b = self.Atools.A_optomo.transposeOpTomo(b)
        return b

    (data_dim, rec_dim) = xp.shape(self.Atools.A_optomo)

    # initialise the solution and other ADMM variables
    if xp.size(_algorithm_upd_["initialise"]) == rec_dim:
        # the object has been initialised with an array
        X = _algorithm_upd_["initialise"].ravel()
    else:
        X = xp.zeros(rec_dim, "float32")
    info_vec = (0, 2)
    denomN = 1.0 / xp.size(X)  # normaliser for the stopping-criterion norm
    z = xp.zeros(rec_dim, "float32")
    u = xp.zeros(rec_dim, "float32")
    # constant part of the right-hand side: A^T b
    b_to_solver_const = self.Atools.A_optomo.transposeOpTomo(
        _data_upd_["projection_norm_data"].ravel()
    )

    # Outer ADMM iterations
    for iter_no in range(_algorithm_upd_["iterations"]):
        X_old = X
        # solve the quadratic sub-problem with a matrix-free GMRES solver
        A_to_solver = scipy.sparse.linalg.LinearOperator(
            (rec_dim, rec_dim), matvec=ADMM_Ax, rmatvec=ADMM_Atb
        )
        b_to_solver = b_to_solver_const + _algorithm_upd_["ADMM_rho_const"] * (
            z - u
        )
        outputSolver = scipy.sparse.linalg.gmres(
            A_to_solver, b_to_solver, atol=1e-05, maxiter=15
        )
        X = xp.float32(outputSolver[0])  # get gmres solution
        if _algorithm_upd_["nonnegativity"] == "ENABLE":
            X[X < 0.0] = 0.0
        # z-update with relaxation
        zold = z.copy()
        x_hat = (
            _algorithm_upd_["ADMM_relax_par"] * X
            + (1.0 - _algorithm_upd_["ADMM_relax_par"]) * zold
        )
        if self.geom == "2D":
            x_prox_reg = (x_hat + u).reshape(
                [self.Atools.recon_size, self.Atools.recon_size]
            )
        if self.geom == "3D":
            x_prox_reg = (x_hat + u).reshape(
                [
                    self.Atools.detectors_y,
                    self.Atools.recon_size,
                    self.Atools.recon_size,
                ]
            )
        # Apply regularisation using CCPi-RGL toolkit. The proximal operator of the chosen regulariser
        if _regularisation_upd_["method"] is not None:
            (z, info_vec) = prox_regul(self, x_prox_reg, _regularisation_upd_)
            z = z.ravel()
        # update u variable (scaled dual ascent)
        u = u + (x_hat - z)
        if _algorithm_upd_["verbose"]:
            if xp.mod(iter_no, (round)(_algorithm_upd_["iterations"] / 5) + 1) == 0:
                print(
                    "ADMM iteration (",
                    iter_no + 1,
                    ") using",
                    _regularisation_upd_["method"],
                    "regularisation for (",
                    (int)(info_vec[0]),
                    ") iterations",
                )
            if iter_no == _algorithm_upd_["iterations"] - 1:
                print("ADMM stopped at iteration (", iter_no + 1, ")")
        # stopping criteria (checked after a reasonable number of iterations)
        if iter_no > 5:
            nrm = xp.linalg.norm(X - X_old) * denomN
            if nrm < _algorithm_upd_["tolerance"]:
                # report with 1-based iteration count, guarded by the
                # verbosity flag for consistency with FISTA
                if _algorithm_upd_["verbose"]:
                    print("ADMM stopped at iteration (", iter_no + 1, ")")
                break
    if self.geom == "2D":
        return X.reshape([self.Atools.recon_size, self.Atools.recon_size])
    if self.geom == "3D":
        return X.reshape(
            [
                self.Atools.detectors_y,
                self.Atools.recon_size,
                self.Atools.recon_size,
            ]
        )
    return X
# *****************************ADMM ends here*********************************#