Source code for tomobar.methodsDIR
"""Reconstruction class for direct reconstruction methods (2D/3D).
* :func:`RecToolsDIR.FORWPROJ` and :func:`RecToolsDIR.BACKPROJ` Forward/Backward 2D/3D projection (ASTRA-Toolbox)
* :func:`RecToolsDIR.FOURIER` Fourier Slice Theorem-based reconstruction in 2D only (adopted from the Tim Day's code)
* :func:`RecToolsDIR.FBP` Filtered Back Projection 2D/3D (ASTRA with the custom built filter).
"""
import numpy as np
import scipy.fftpack
from tomobar.astra_wrappers.astra_tools2d import AstraTools2D
from tomobar.astra_wrappers.astra_tools3d import AstraTools3D
from tomobar.supp.funcs import _data_dims_swapper, _parse_device_argument
def _filtersinc3D(projection3D: np.ndarray):
"""Applies a 3D filter to 3D projection data for FBP
Args:
projection3D (np.ndarray): projection data
Returns:
np.ndarray: Filtered data
"""
a = 1.1
[DetectorsLengthV, projectionsNum, DetectorsLengthH] = np.shape(projection3D)
w = np.linspace(
-np.pi,
np.pi - (2 * np.pi) / DetectorsLengthH,
DetectorsLengthH,
dtype="float32",
)
rn1 = np.abs(2.0 / a * np.sin(a * w / 2.0))
rn2 = np.sin(a * w / 2.0)
rd = (a * w) / 2.0
rd_c = np.zeros([1, DetectorsLengthH])
rd_c[0, :] = rd
r = rn1 * (np.dot(rn2, np.linalg.pinv(rd_c))) ** 2
multiplier = 1.0 / projectionsNum
f = scipy.fftpack.fftshift(r)
# making a 2d filter for projection
f_2d = np.zeros((DetectorsLengthV, DetectorsLengthH), dtype="float32")
f_2d[0::, :] = f
filtered = np.zeros(np.shape(projection3D), dtype="float32")
for i in range(0, projectionsNum):
IMG = scipy.fftpack.fft2(projection3D[:, i, :])
fimg = IMG * f_2d
filtered[:, i, :] = np.real(scipy.fftpack.ifft2(fimg))
return multiplier * filtered
def _filtersinc2D(sinogram):
# applies filters to __2D projection data__ in order to achieve FBP
a = 1.1
[projectionsNum, DetectorsLengthH] = np.shape(sinogram)
w = np.linspace(
-np.pi,
np.pi - (2 * np.pi) / DetectorsLengthH,
DetectorsLengthH,
dtype="float32",
)
rn1 = np.abs(2.0 / a * np.sin(a * w / 2.0))
rn2 = np.sin(a * w / 2.0)
rd = (a * w) / 2.0
rd_c = np.zeros([1, DetectorsLengthH])
rd_c[0, :] = rd
r = rn1 * (np.dot(rn2, np.linalg.pinv(rd_c))) ** 2
multiplier = 1.0 / projectionsNum
f = scipy.fftpack.fftshift(r)
filtered = np.zeros(np.shape(sinogram))
for i in range(0, projectionsNum):
IMG = scipy.fftpack.fft(sinogram[i, :])
fimg = IMG * f
filtered[i, :] = multiplier * np.real(scipy.fftpack.ifft(fimg))
return np.float32(filtered)
[docs]
class RecToolsDIR:
"""Reconstruction class using DIRect methods (FBP and Fourier).
Args:
DetectorsDimH (int): Horizontal detector dimension.
DetectorsDimV (int): Vertical detector dimension for 3D case, 0 or None for 2D case.
CenterRotOffset (float, ndarray): The Centre of Rotation (CoR) scalar or a vector for each angle.
AnglesVec (np.ndarray): Vector of projection angles in radians.
ObjSize (int): Reconstructed object dimensions (a scalar).
device_projector (str, int, optional): 'cpu' or 'gpu' device OR provide a GPU index (integer) of a specific GPU device.
cupyrun (bool, optional): instantiate CuPy class if True.
"""
def __init__(
self,
DetectorsDimH, # DetectorsDimH # detector dimension (horizontal)
DetectorsDimV, # DetectorsDimV # detector dimension (vertical) for 3D case only
CenterRotOffset, # Centre of Rotation (CoR) scalar or a vector
AnglesVec, # Array of angles in radians
ObjSize, # A scalar to define reconstructed object dimensions
device_projector="gpu", # Choose the device to be 'cpu' or 'gpu' OR provide a GPU index (integer) of a specific device
):
device_projector, GPUdevice_index = _parse_device_argument(device_projector)
if DetectorsDimV == 0 or DetectorsDimV is None:
self.geom = "2D"
self.Atools = AstraTools2D(
DetectorsDimH,
AnglesVec,
CenterRotOffset,
ObjSize,
device_projector,
GPUdevice_index,
)
else:
self.geom = "3D"
self.Atools = AstraTools3D(
DetectorsDimH,
DetectorsDimV,
AnglesVec,
CenterRotOffset,
ObjSize,
device_projector,
GPUdevice_index,
)
[docs]
def FORWPROJ(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""Module to perform forward projection of 2d/3d data numpy array
Args:
data (np.ndarray): 2D or 3D object
Keyword Args:
data_axes_labels_order (Union[list, None], optional): The order of the axes labels for the OUTPUT data.
When "None" we assume ["angles", "detX"] for 2D and ["detY", "angles", "detX"] for 3D.
Returns:
np.ndarray: Forward projected numpy array (projection data)
"""
projected = self.Atools._forwproj(data)
for key, value in kwargs.items():
if key == "data_axes_labels_order" and value is not None:
if self.geom == "2D":
projected = _data_dims_swapper(projected, value, ["angles", "detX"])
else:
projected = _data_dims_swapper(
projected, value, ["detY", "angles", "detX"]
)
return projected
[docs]
def BACKPROJ(self, projdata: np.ndarray, **kwargs) -> np.ndarray:
"""Module to perform back-projection of 2d/3d data numpy array
Args:
projdata (np.ndarray): 2D/3D projection data
Keyword Args:
data_axes_labels_order (Union[list, None], optional): The order of the axes labels for the input data.
When "None" we assume ["angles", "detX"] for 2D and ["detY", "angles", "detX"] for 3D.
Returns:
np.ndarray: Backprojected 2D/3D object
"""
for key, value in kwargs.items():
if key == "data_axes_labels_order" and value is not None:
if self.geom == "2D":
projdata = _data_dims_swapper(projdata, value, ["angles", "detX"])
else:
projdata = _data_dims_swapper(
projdata, value, ["detY", "angles", "detX"]
)
return self.Atools._backproj(projdata)
[docs]
def FBP(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""Filtered backprojection reconstruction module for 2D or 3D data.
Args:
data (np.ndarray): 2D or 3D projection data.
Keyword Args:
data_axes_labels_order (Union[list, None], optional): The order of the axes labels for the input data.
When "None" we assume ["angles", "detX"] for 2D and ["detY", "angles", "detX"] for 3D.
Returns:
np.ndarray: FBP reconstructed 2D or 3D object.
"""
for key, value in kwargs.items():
if key == "data_axes_labels_order" and value is not None:
if self.geom == "2D":
data = _data_dims_swapper(data, value, ["angles", "detX"])
else:
data = _data_dims_swapper(data, value, ["detY", "angles", "detX"])
if self.geom == "2D":
"dealing with FBP 2D not working for parallel_vec geometry and CPU"
if self.Atools.processing_arch == "gpu":
return self.Atools._fbp(data)
else:
return self.Atools._backproj(_filtersinc2D(data))
else:
return self.Atools._backproj(_filtersinc3D(data))
[docs]
def FOURIER(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""2D Reconstruction using Fourier slice theorem (scipy required)
for griddata interpolation module choose nearest, linear or cubic
Args:
data (np.ndarray): 2D sinogram data
Keyword Args:
data_axes_labels_order (Union[list, None], optional): The order of the axes labels for the input data.
When "None" we assume ["angles", "detX"] for 2D and ["detY", "angles", "detX"] for 3D.
method (str, optional): Interpolation type (nearest, linear, or cubic). Defaults to "linear".
Returns:
np.ndarray: Reconstructed object
"""
if data.ndim == 3:
raise ValueError(
"Fourier method is currently for 2D data only, use FBP if 3D reconstruction needed"
)
for key, value in kwargs.items():
if key == "data_axes_labels_order" and value is not None:
data = _data_dims_swapper(data, value, ["angles", "detX"])
if key == "method":
if value not in ["linear", "nearest", "cubic"]:
raise ValueError(
"For griddata interpolation module choose nearest, linear or cubic"
)
else:
method = value
from scipy.fft import fftshift, ifftshift, fft, ifft2
from scipy.interpolate import griddata
ObjSize = self.Atools.recon_size
# pad sinogram and move it to compensate for CoR
oversampling = 2 # 2 or larger
angles_tot, DetectorsDimH = data.shape
if (DetectorsDimH % 2) != 0:
raise ValueError(
"The horizontal detector size of the projection data (sinogram) must be even"
)
det_x_up = oversampling * DetectorsDimH
sino_up = np.zeros([angles_tot, det_x_up], dtype=np.float32)
pad_from = DetectorsDimH // 2 + int(self.Atools.centre_of_rotation)
pad_to = det_x_up - DetectorsDimH // 2 + int(self.Atools.centre_of_rotation)
sino_up[:, pad_from:pad_to] = data
# Fourier transform the rows of the sinogram, move the DC component to the row's centre
sinogram_fft_rows = fftshift(fft(ifftshift(sino_up, axes=1)), axes=1)
# Coordinates of sinogram FFT-ed rows' samples in 2D FFT space
a = -self.Atools.angles_vec
r = np.arange(det_x_up) - det_x_up / 2
r, a = np.meshgrid(r, a)
r = r.flatten()
a = a.flatten()
srcx = (det_x_up / 2) + r * np.cos(a)
srcy = (det_x_up / 2) + r * np.sin(a)
# Coordinates of regular grid in 2D FFT space
dstx, dsty = np.meshgrid(np.arange(det_x_up), np.arange(det_x_up))
dstx = dstx.flatten()
dsty = dsty.flatten()
# Interpolate the 2D Fourier space grid from the transformed sinogram rows
fft2 = griddata(
(srcy, srcx),
sinogram_fft_rows.flatten(),
(dsty, dstx),
method,
fill_value=0.0,
).reshape((det_x_up, det_x_up))
# Transform from 2D Fourier space back to a reconstruction of the target
recon = np.real(fftshift(ifft2(ifftshift(fft2))))
# Cropping the reconstruction to size of the original image
unpad_from = det_x_up // 2 - ObjSize // 2
unpad_to = det_x_up // 2 + ObjSize // 2
return recon[unpad_from:unpad_to, unpad_from:unpad_to]