# Transform images with multiprocessing
# Author: David Young, 2019, 2020
"""Transform large images with multiprocessing, including up/downsampling
and image transposition.
"""
from time import time
from typing import Optional, Sequence, Tuple, Union

import numpy as np
from skimage import transform

from magmap.cv import chunking, cv_nd
from magmap.settings import config
from magmap.io import importer
from magmap.io import libmag
from magmap.plot import plot_3d
# module logger, created as a child of the logger held in magmap's config
_logger = config.logger.getChild(__name__)
class Downsampler(object):
    """Downsample (or theoretically upsample) a large image in a way
    that allows multiprocessing without global variables.

    Attributes:
        img (:obj:`np.ndarray`): Full image array, stored as a class
            attribute so that worker processes can slice sub-ROIs from it
            without the array being passed to each task.
    """
    img = None

    @classmethod
    def set_data(cls, img):
        """Set the class attributes to be shared during multiprocessing.

        Args:
            img (:obj:`np.ndarray`): See attributes.
        """
        cls.img = img

    @classmethod
    def rescale_sub_roi(
            cls, coord: Sequence[int],
            slices: Sequence[slice],
            target_size: Optional[Union[float, Sequence[int]]],
            multichannel: bool,
            sub_roi: Optional[np.ndarray] = None
    ) -> Tuple[Sequence[int], np.ndarray]:
        """Rescale or resize a sub-ROI.

        Args:
            coord: Coordinates as a tuple of (z, y, x) of the sub-ROI within
                the chunked ROI.
            slices: Sequence of slices within :attr:`img` defining the
                sub-ROI; only used when ``sub_roi`` is not given.
            target_size: Either a rescaling factor or a target size in
                (z, y, x) for the given sub-ROI; passed through unchanged
                to :func:`magmap.cv.cv_nd.rescale_resize`.
            multichannel: True if the final dimension is for channels.
            sub_roi: Array chunk to rescale/resize; defaults to None to
                extract it from :attr:`img`.

        Returns:
            Tuple of ``coord`` and the rescaled sub-ROI, where ``coord`` is
            the same as the given parameter to identify where the sub-ROI
            is located during multiprocessing tasks.

        Raises:
            ValueError: if ``sub_roi`` is None and no image has been set
            with :meth:`set_data`.
        """
        if sub_roi is None:
            if cls.img is None:
                # fail clearly rather than passing None on to the rescaler
                raise ValueError(
                    "No sub-ROI given and no image set; call set_data first")
            sub_roi = cls.img[slices]
        rescaled = cv_nd.rescale_resize(sub_roi, target_size, multichannel)
        return coord, rescaled
def make_modifier_plane(plane):
    """Build the filename modifier for an orthogonal plane transposition.

    Args:
        plane: Name of the plane to which the image was transposed.

    Returns:
        The string ``"plane"`` followed by the upper-cased plane name.
    """
    plane_label = plane.upper()
    return f"plane{plane_label}"
def make_modifier_scale(scale):
    """Build the modifier string for a scaling transformation.

    Typically used in filenames of rescaled images.

    Args:
        scale (float): Factor by which the image was rescaled.

    Returns:
        str: ``"scale"`` followed by ``scale``, with every ``"."``
        spelled out as ``"pt"`` so the modifier cannot be mistaken for a
        path extension.
    """
    return f"scale{scale}".replace(".", "pt")
def make_modifier_resized(target_size):
    """Build the modifier string for a resize transformation.

    The modifier reflects only the size that was targeted; the final image
    size may differ slightly from it.

    Args:
        target_size: Target size of rescaling in x,y,z; the first three
            values are used.

    Returns:
        String of the form ``"resized(x,y,z)"``.
    """
    template = "resized({},{},{})"
    return template.format(*target_size)
def get_transposed_image_path(
        img_path: str, scale: float = None, target_size: Sequence[int] = None
) -> str:
    """Get the image path modified for any transposition.

    Args:
        img_path: Unmodified image path.
        scale: Scaling factor, which takes precedence over ``target_size``;
            defaults to None.
        target_size: Target size in ``x, y, z``, typically given by an atlas
            profile; defaults to None.

    Returns:
        The path with a scale or resize modifier inserted before its
        extension, or ``img_path`` unchanged if both ``scale`` and
        ``target_size`` are None.
    """
    if scale is None and target_size is None:
        # no transposition requested, so the original path applies
        return img_path

    # use scaled image for pixel comparison, retrieving saved scaling as of
    # v.0.6.0
    if scale is None:
        # fall back to the target size when no scale is given
        modifier = make_modifier_resized(target_size)
        print("loading resized file with {} modifier".format(modifier))
    else:
        # scale takes priority as command-line argument
        modifier = make_modifier_scale(scale)
        print("loading scaled file with {} modifier".format(modifier))
    return libmag.insert_before_ext(img_path, "_" + modifier)
def _rescale_in_chunks(roi, max_pixels, scale, multichannel):
    """Rescale or resize an array chunk by chunk in a multiprocessing pool.

    Args:
        roi (:obj:`np.ndarray`): Array to rescale, without any preceding
            time dimension.
        max_pixels (Sequence[int]): Maximum chunk size used to split ``roi``.
        scale: Rescaling factor or per-chunk target size, passed through to
            :meth:`Downsampler.rescale_sub_roi`.
        multichannel (bool): True if the final dimension is for channels.

    Returns:
        :obj:`np.ndarray`: Object array shaped like the chunk grid, holding
        each rescaled chunk at its grid coordinate.
    """
    sub_roi_slices, _ = chunking.stack_splitter(roi.shape, max_pixels)
    is_fork = chunking.is_fork()
    if is_fork:
        # workers read chunks from the shared class attribute
        Downsampler.set_data(roi)
    sub_rois = np.zeros_like(sub_roi_slices)
    grid_last = np.add(sub_roi_slices.shape, -1)
    pool = chunking.get_mp_pool()
    try:
        pool_results = []
        for z in range(sub_roi_slices.shape[0]):
            for y in range(sub_roi_slices.shape[1]):
                for x in range(sub_roi_slices.shape[2]):
                    coord = (z, y, x)
                    slices = sub_roi_slices[coord]
                    args = [coord, slices, scale, multichannel]
                    if not is_fork:
                        # pickle chunk if img not directly available
                        args.append(roi[slices])
                    pool_results.append(pool.apply_async(
                        Downsampler.rescale_sub_roi, args=args))
        for result in pool_results:
            coord, sub_roi = result.get()
            print("replacing sub_roi at {} of {}".format(coord, grid_last))
            sub_rois[coord] = sub_roi
    except BaseException:
        # stop outstanding workers rather than leaving the pool running
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    return sub_rois


def transpose_img(
        filename: str, series: Optional[int], plane: Optional[str] = None,
        rescale: Optional[float] = None,
        target_size: Optional[Sequence[int]] = None):
    """Transpose large NumPy saved arrays, including rescaling or resizing.

    Loads a saved array to transpose its planar orientation. Supports large
    arrays, with rescaling/resizing performed in multiprocessing and file
    saving through memmap-based arrays to minimize RAM usage. Output
    filenames are based on the ``make_modifier_[task]`` functions.
    Currently, transposes all channels, ignoring the
    :attr:`magmap.settings.config.channel` parameter.

    ``config.resolutions`` is updated in place to match the output image.

    Args:
        filename: Full path of the image file to load.
        series: Series within multi-series file.
        plane: Planar orientation (see :attr:`magmap.settings.config.PLANE`).
            Defaults to None, in which case no planar transformation will
            occur.
        rescale: Rescaling factor; defaults to None. Takes precedence over
            ``target_size``.
        target_size: Target shape in x,y,z; defaults to None, in which case
            the target size will be extracted from the atlas profile if
            available. An empty target size is treated as no target size.
    """
    if target_size is None:
        target_size = config.atlas_profile["target_size"]
    # treat an empty target size the same as no target size
    has_target = target_size is not None and len(target_size) > 0
    if plane is None and rescale is None and not has_target:
        print("No transposition to perform, skipping")
        return

    time_start = time()
    # even if loaded already, reread to get image metadata
    # TODO: consider saving metadata in config and retrieving from there
    img5d = importer.read_file(filename, series)
    info = img5d.meta
    image5d = img5d.img
    sizes = info["sizes"]

    # make filenames based on transpositions
    modifier = ""
    if plane is not None:
        modifier = make_modifier_plane(plane)
    # either rescaling or resizing, with rescaling taking precedence
    if rescale is not None:
        modifier += make_modifier_scale(rescale)
    elif has_target:
        # target size may differ from final output size but allows a known
        # size to be used for finding the file later
        modifier += make_modifier_resized(target_size)
    filename_image5d_npz, filename_info_npz = importer.make_filenames(
        filename, series, modifier=modifier)

    # TODO: image5d should assume 4/5 dimensions
    # index of the z axis: 0 for 3D arrays, 1 after a leading time axis
    offset = 0 if image5d.ndim <= 3 else 1
    multichannel = image5d.ndim >= 5
    image5d_swapped = image5d

    if plane is not None and plane != config.PLANE[0]:
        # swap z-y to get (y, z, x) order for xz orientation
        image5d_swapped = np.swapaxes(image5d_swapped, offset, offset + 1)
        config.resolutions[0] = libmag.swap_elements(
            config.resolutions[0], 0, 1)
        if plane == config.PLANE[2]:
            # swap new y-x to get (x, z, y) order for yz orientation
            image5d_swapped = np.swapaxes(
                image5d_swapped, offset, offset + 2)
            config.resolutions[0] = libmag.swap_elements(
                config.resolutions[0], 0, 2)

    scaling = None
    if rescale is not None or has_target:
        # rescale based on scaling factor or target specific size
        rescaled = image5d_swapped
        # TODO: generalize for more than 1 preceding dimension?
        if offset > 0:
            rescaled = rescaled[0]
        max_pixels = [100, 500, 500]
        sub_roi_size = None
        if rescale is None:
            # to avoid artifacts from thin chunks, fit image into even
            # number of pixels per chunk by rounding up number of chunks
            # and resizing each chunk by ratio of total size to chunk num
            target_size_zyx = target_size[::-1]
            shape = rescaled.shape[:3]
            num_chunks = np.ceil(np.divide(shape, max_pixels))
            max_pixels = np.ceil(
                np.divide(shape, num_chunks)).astype(int)
            sub_roi_size = np.floor(
                np.divide(target_size_zyx, num_chunks)).astype(int)
            print("Resizing image of shape {} to target_size: {}, using "
                  "num_chunks: {}, max_pixels: {}, sub_roi_size: {}"
                  .format(rescaled.shape, target_size_zyx, num_chunks,
                          max_pixels, sub_roi_size))
        else:
            print("Rescaling image of shape {} by factor of {}"
                  .format(rescaled.shape, rescale))

        # rescale in chunks with multiprocessing
        sub_rois = _rescale_in_chunks(
            rescaled, max_pixels,
            rescale if rescale is not None else sub_roi_size, multichannel)

        rescaled_shape = chunking.get_split_stack_total_shape(sub_rois)
        if offset > 0:
            rescaled_shape = np.concatenate(([1], rescaled_shape))
        print("rescaled_shape: {}".format(rescaled_shape))
        # rescale chunks directly into memmap-backed array to minimize RAM
        # usage
        image5d_transposed = np.lib.format.open_memmap(
            filename_image5d_npz, mode="w+", dtype=sub_rois[0, 0, 0].dtype,
            shape=tuple(rescaled_shape))
        chunking.merge_split_stack2(
            sub_rois, None, offset, image5d_transposed)
        if rescale is not None:
            # scale resolutions based on single rescaling factor
            config.resolutions = np.multiply(
                config.resolutions, 1 / rescale)
        else:
            # scale resolutions based on the size ratio of each z,y,x
            # dimension, which start at the offset axis (not always axis 1)
            size_ratio = np.divide(image5d_swapped.shape, rescaled_shape)
            config.resolutions = np.multiply(
                config.resolutions, size_ratio[offset:offset + 3])
        sizes[0] = rescaled_shape
        scaling = importer.calc_scaling(image5d_swapped, image5d_transposed)
    else:
        # transfer directly to memmap-backed array
        image5d_transposed = np.lib.format.open_memmap(
            filename_image5d_npz, mode="w+", dtype=image5d_swapped.dtype,
            shape=image5d_swapped.shape)
        if plane == config.PLANE[1] or plane == config.PLANE[2]:
            # flip upside-down if re-orienting planes
            if offset:
                image5d_transposed[0, :] = np.fliplr(image5d_swapped[0, :])
            else:
                image5d_transposed[:] = np.fliplr(image5d_swapped[:])
        else:
            image5d_transposed[:] = image5d_swapped[:]
        sizes[0] = image5d_swapped.shape

    # save image metadata
    print("detector.resolutions: {}".format(config.resolutions))
    print("sizes: {}".format(sizes))
    image5d.flush()
    # make sure the transposed output memmap is written to disk
    image5d_transposed.flush()
    importer.save_image_info(
        filename_info_npz, info["names"], sizes, config.resolutions,
        info["magnification"], info["zoom"],
        *importer.calc_intensity_bounds(image5d_transposed), scaling, plane)
    print("saved transposed file to {} with shape {}".format(
        filename_image5d_npz, image5d_transposed.shape))
    print("time elapsed (s): {}".format(time() - time_start))
def rotate_img(roi, rotate=None, order=None):
    """Rotate an ROI based on atlas profile settings.

    Args:
        roi (:obj:`np.ndarray`): Region of interest array (z,y,x[,c]).
        rotate (dict): Rotation settings read through the ``rotation``,
            ``resize``, and (when ``order`` is None) ``order`` keys.
            Defaults to None to use ``config.atlas_profile["rotate"]``.
        order (int): Spline interpolation order; defaults to None to use
            ``rotate["order"]``. Should be 0 for labels.

    Returns:
        :obj:`np.ndarray`: The rotated image array. A copy of ``roi`` is
        returned even when no rotations are listed.
    """
    settings = config.atlas_profile["rotate"] if rotate is None else rotate
    interp_order = settings["order"] if order is None else order
    rotated = np.copy(roi)
    for rot in settings["rotation"]:
        print("rotating by", rot)
        # each entry supplies the two positional rotation arguments
        rotated = cv_nd.rotate_nd(
            rotated, rot[0], rot[1], order=interp_order,
            resize=settings["resize"])
    return rotated
def preprocess_img(image5d, preprocs, channel, out_path):
    """Pre-process an image in 3D.

    Only the first time point of ``image5d`` is processed. The result is
    saved to ``out_path``.

    Args:
        image5d (:obj:`np.ndarray`): 5D array in t,z,y,x[,c].
        preprocs (Union[str, list[str]]): Pre-processing tasks that will be
            converted to enums in :class:`config.PreProcessKeys` to perform
            in the order given.
        channel (int): Channel to preprocess, or None for all channels;
            not used by the rotation task.
        out_path (str): Output base path.

    Returns:
        :obj:`np.ndarray`: The pre-processed image array, or None if
        ``preprocs`` is None and nothing was processed.
    """
    if preprocs is None:
        print("No preprocessing tasks to perform, skipping")
        return None
    if not libmag.is_seq(preprocs):
        preprocs = [preprocs]
    roi = image5d[0]
    for preproc in preprocs:
        # perform global pre-processing task
        task = libmag.get_enum(preproc, config.PreProcessKeys)
        _logger.info("Pre-processing task: %s", task)
        if task is config.PreProcessKeys.SATURATE:
            roi = plot_3d.saturate_roi(roi, channel=channel)
        elif task is config.PreProcessKeys.DENOISE:
            roi = plot_3d.denoise_roi(roi, channel)
        elif task is config.PreProcessKeys.REMAP:
            roi = plot_3d.remap_intensity(roi, channel)
        elif task is config.PreProcessKeys.ROTATE:
            roi = rotate_img(roi)
        else:
            # Logger.warn is a deprecated alias of Logger.warning
            _logger.warning("No preprocessing task found for: %s", preproc)

    # save to new file
    image5d = importer.roi_to_image5d(roi)
    importer.save_np_image(image5d, out_path)
    return image5d