# 2D overlaid plot editor
# Author: David Young, 2018, 2020
"""Editor for 2D plot with overlaid planes.
Integrates with :class:`atlas_editor.AtlasEditor` for synchronized 3D
view of orthogonal planes.
"""
import textwrap
import time
from typing import Callable, List, Optional, Sequence, TYPE_CHECKING, Union
from matplotlib import patches
import numpy as np
from skimage import draw
from magmap.gui import image_viewer, pixel_display
from magmap.io import libmag, np_io
from magmap.settings import config
from magmap.atlas import ontology
from magmap.plot import plot_support
if TYPE_CHECKING:
from matplotlib import axes, colors, backend_bases, image
_logger = config.logger.getChild(__name__)
[docs]
class PlotAxImg:
    """Container pairing a displayed Matplotlib image with its settings.

    Keeps the values requested for an image separate from what the image
    currently displays, such as ``None`` to request auto-intensity, and
    brightness and contrast values that are not stored in the image itself.

    Attributes:
        ax_img: Displayed Matplotlib image.
        vmin: Specified vmin; defaults to None for auto-scaling. See
            ``ax_img.norm.vmin`` for the output vmin.
        vmax: Specified vmax; defaults to None for auto-scaling. See
            ``ax_img.norm.vmax`` for the output vmax.
        brightness: Brightness addend; defaults to 0.0.
        contrast: Contrast factor; defaults to 1.0.
        alpha: Opacity level; defaults to None.
        alpha_blend: Opacity level of the first level in the area of blending
            between two images; defaults to None.

    """

    def __init__(
            self, ax_img: "image.AxesImage", vmin: Optional[float] = None,
            vmax: Optional[float] = None, img: Optional[np.ndarray] = None):
        # values given directly by the caller
        self.ax_img = ax_img
        self.vmin = vmin
        self.vmax = vmax

        # display adjustments tracked outside of the Matplotlib image
        self.brightness: float = 0.0
        self.contrast: float = 1.0
        self.alpha: Optional[float] = None
        self.alpha_blend: Optional[float] = None

        #: True if the image is displayed as RGB(A); defaults to False.
        self.rgb: bool = False

        if img is None:
            # snapshot the displayed array so that the array in ``ax_img``
            # can be adjusted while the original data is retained
            img = np.copy(self.ax_img.get_array())
        #: Original array of the displayed Matplotlib image. If not given,
        #: a copy of the displayed image's array is stored instead.
        self.img: np.ndarray = img

        #: Original input image data; defaults to None.
        self.input_img: Optional[np.ndarray] = None
[docs]
class PlotEditor:
    """Show a scrollable, editable plot of sequential planes in a 3D image.

    Attributes:
        intensity (int): Chosen intensity value of ``img3d_labels``.
        intensity_spec (int): Intensity value specified directly
            rather than chosen from the labels image.
        intensity_shown (int): Displayed intensity in the
            :obj:`matplotlib.AxesImage` corresponding to ``img3d_labels``.
        coord (List[int]): Coordinates in ``z, y, x``.
        scale_bar (bool): True to add a scale bar; defaults to False.
        enable_painting (bool): True to enable label painting; defaults to True.
        max_intens_proj (int): Number of planes to include in a maximum
            intensity projection; defaults to 0 for no projection. The
            planes are taken starting from the given z-value in ``coord``,
            limited by the number of planes available. Applied to the first
            intensity image.

    """
    #: Default opacity of the labels image, used to initialize ``alpha``.
    ALPHA_DEFAULT = 0.5
    #: Names of keyboard modifier keys.
    # NOTE(review): presumably matched against Matplotlib key event names;
    # the usage is not visible in this part of the file -- confirm.
    _KEY_MODIFIERS = ("shift", "alt", "control")
def __init__(self, overlayer, img3d,
             img3d_labels=None, cmap_labels=None, plane=None,
             fn_update_coords=None, fn_refresh_images=None,
             scaling=None, plane_slider=None, img3d_borders=None,
             cmap_borders=None, fn_show_label_3d=None, interp_planes=None,
             fn_update_intensity=None, max_size=None, fn_status_bar=None,
             img3d_extras=None):
    """Initialize the plot editor.

    Stores the images, colormaps, and callbacks, registers
    :meth:`update_plane_slider` on the plane slider if one is given, and
    pre-computes the shape, scale relative to the main image, and
    downsampling factor for each 3D image.

    Args:
        overlayer (:class:`plot_support.ImageOverlayer`): Manager for
            plotting overlaid images; its ``ax`` is stored as
            :attr:`axes`.
        img3d (:obj:`np.ndarray`): Main 3D image.
        img3d_labels (:obj:`np.ndarray`): Labels 3D image; defaults
            to None.
        cmap_labels (:obj:`matplotlib.colors.ListedColormap`): Labels
            colormap, generally a :obj:`colormaps.DiscreteColormap`;
            defaults to None.
        plane (str): One of :attr:`config.PLANE` specifying the
            orthogonal plane to view; defaults to None to use the first
            value in :attr:`config.PLANE`.
        fn_update_coords (function): Callback when updating coordinates;
            takes the updated coordinates and ``plane``. Defaults to None.
        fn_refresh_images (function): Callback when refreshing the image.
            Typically takes two arguments, this ``PlotEditor`` object
            and a boolean where True will update synchronized
            ``AtlasEditor``s.
        scaling (List[float]): Scaling/spacing in z,y,x; defaults to None
            to use :attr:`config.labels_scaling`.
        plane_slider (:obj:`matplotlib.widgets.Slider`): Slider for
            scrolling through planes.
        img3d_borders (:obj:`np.ndarray`): Borders 3D image; defaults
            to None.
        cmap_borders (:obj:`matplotlib.colors.ListedColormap`): Borders
            colormap, generally a :obj:`colormaps.DiscreteColormap`;
            defaults to None.
        fn_show_label_3d (function): Callback to show a label at the
            current 3D coordinates; defaults to None.
        interp_planes (:obj:`atlas_editor.InterpolatePlanes`): Plane
            interpolation object; defaults to None.
        fn_update_intensity (function): Callback when updating the
            intensity value; defaults to None.
        max_size (int): Maximum size of either side of the 2D plane shown;
            defaults to None. Images whose larger in-plane dimension is
            at least twice this size are downsampled by the integer
            ratio of the two.
        fn_status_bar (func): Function to call during status bar updates
            in :class:`pixel_display.PixelDisplay`; defaults to None.
        img3d_extras (List[:obj:`np.ndarray`]): Sequence of additional
            intensity images to display; defaults to None.

    """
    #: Manager for plotting overlaid images.
    self.overlayer: "plot_support.ImageOverlayer" = overlayer
    self.axes: "axes.Axes" = self.overlayer.ax
    #: Main 3D image.
    self.img3d: np.ndarray = img3d
    #: Labels 3D image; defaults to None.
    self.img3d_labels: Optional[np.ndarray] = img3d_labels
    #: Labels colormap, generally of
    #: :class:`magmap.plot.colormaps.DiscreteColormap`. Defaults to None.
    self.cmap_labels: Optional["colors.ListedColormap"] = cmap_labels
    #: One of :attr:`magmap.settings.config.PLANE` specifying the
    #: orthogonal plane to view.
    self.plane: str = plane if plane else config.PLANE[0]
    #: Callback when updating coordinates, typically mouse click events
    #: in x,y; takes two arguments, the updated coordinates and ``plane``
    #: to indicate the coordinates' orientation.
    self.fn_update_coords: Optional[Callable[
        [Sequence[int], str], None]] = fn_update_coords
    self.fn_refresh_images = fn_refresh_images
    self.scaling = config.labels_scaling if scaling is None else scaling
    self.plane_slider = plane_slider
    if self.plane_slider:
        # scroll planes whenever the slider value changes
        self.plane_slider.on_changed(self.update_plane_slider)
    self.img3d_borders = img3d_borders
    self.cmap_borders = cmap_borders
    self.fn_show_label_3d = fn_show_label_3d
    self.interp_planes = interp_planes
    self.fn_update_intensity = fn_update_intensity
    self.fn_status_bar = fn_status_bar
    self.img3d_extras = img3d_extras

    #: Main image opacity as a sequence of values from 0-1 for each
    #: channel. Defaults to a list of the first element in
    #: :attr:`config.alphas`.
    self.alpha_img3d: Sequence[float] = [config.alphas[0]]
    #: Labels opacity from 0-1; defaults to :const:`ALPHA_DEFAULT`.
    self.alpha: float = self.ALPHA_DEFAULT
    self.intensity = None  # picked intensity of underlying img3d_label
    self.intensity_spec = None  # specified intensity
    self.intensity_shown = None  # shown intensity in AxesImage
    # Matplotlib event connection IDs; None until connected
    self.cidpress = None
    self.cidrelease = None
    self.cidmotion = None
    self.cidenter = None
    self.cidleave = None
    self.cidkeypress = None
    self.radius = 5
    self.circle = None
    self.background = None
    self.last_loc = None
    self.last_loc_data = None
    self.press_loc_data = None
    self.connected = False
    # crossline artists, created on demand in ``draw_crosslines``
    self.hline = None
    self.vline = None
    self.coord = None
    # saved pan/zoom limits, restored in ``show_overview``
    self.xlim = None
    self.ylim = None
    self.edited = False  # True if labels image was edited
    #: Atlas labels edit mode status; defaults to False.
    self._edit_mode: bool = False
    self.region_label = None
    self.scale_bar = False
    self.max_intens_proj = 0
    self.enable_painting = True
    #: Ontology level at which to show region names.
    self.labels_level: Optional[int] = None
    #: Blit manager.
    self.blitter: Optional["image_viewer.Blitter"] = None
    #: Threshold for triggering label motion events. Use 0 (default) to
    #: respond to all events except those within the same pixel. Higher
    #: values filter more movements, especially fast and short motions.
    #: Use `np.inf` to ignore motion and respond to left-click instead.
    self.label_motion_thresh: float = 0.
    #: Threshold for triggering navigation (eg pan/zoom) events. See
    #: :attr:`label_motion_thresh`, except `np.inf` simply turns off
    #: navigation.
    self.nav_motion_thresh: float = 0.

    # nested list of ``PlotAxImg`` per image group, set in ``show_overview``
    self._plot_ax_imgs = None
    self._ax_img_labels = None  # displayed labels image
    self._channels = None  # displayed channels list
    # track label editing during mouse click/movement for plane interp
    self._editing = False
    self._show_labels = True  # show atlas labels on mouseover
    self._show_crosslines = False  # show crosslines to orthogonal views
    self._colorbar = None  # colorbar for first main image
    #: Region label text offset in axes units.
    self._region_label_offset: float = 0.05
    #: Last time of mouse motion. Defaults to initialization time.
    self._last_time: float = time.perf_counter()

    # ROI offset and size in z,y,x
    self._roi_offset = None
    self._roi_size = None
    self._roi_patch_preview = None

    # pre-compute image shapes, scales, and downsampling factors for
    # each type of 3D image; indices are 0 for the main image, 1 for
    # labels, 2 for borders, and 3 onward for any extra images
    img3ds = [self.img3d, self.img3d_labels, self.img3d_borders]
    if self.img3d_extras is not None:
        img3ds.extend(self.img3d_extras)
    num_img3ds = len(img3ds)
    self._img3d_shapes = [None] * num_img3ds
    self._img3d_scales = [None] * num_img3ds
    self._downsample = [1] * num_img3ds
    for i, img in enumerate(img3ds):
        if img is None: continue
        self._img3d_shapes[i] = img.shape
        if i > 0:
            # ratio of this image's first three dimensions to those of
            # the main image
            lbls_scale = np.divide(img.shape[:3], self.img3d.shape[:3])
            if not all(lbls_scale == 1):
                # replace with scale if not all 1
                self._img3d_scales[i] = lbls_scale
        if max_size:
            # integer factor by which the larger in-plane dimension
            # exceeds the maximum size
            downsample = max(img.shape[1:3]) // max_size
            if downsample > 1:
                # only downsample if factor is over 1
                self._downsample[i] = downsample
    if config.verbose and not np.all(np.array(self._downsample) == 1):
        _logger.debug(
            "Plane %s downsampling factors by image: %s", self.plane,
            self._downsample)
@property
def edit_mode(self) -> bool:
"""Atlas labels edit mode status.
Return:
True if in edit mode, False otherwise.
"""
return self._edit_mode
@edit_mode.setter
def edit_mode(self, val: bool):
"""Set atlas labels edit mode status.
Adds the atlas labels axes image to the blitter in edit mode and
removes the image when leaving edit mode.
Args:
val: True to start editing atlas labels.
"""
self._edit_mode = val
if self.blitter:
if val:
# add labels axes images to blitter
self.blitter.add_artist(self._ax_img_labels)
else:
# remove from blitter
artists = self.blitter.artists
if self._ax_img_labels in artists:
artists.remove(self._ax_img_labels)
@property
def plot_ax_imgs(self) -> Optional[List["PlotAxImg"]]:
return self._plot_ax_imgs
[docs]
def connect(self):
"""Connect events to functions.
"""
canvas = self.axes.figure.canvas
self.cidpress = canvas.mpl_connect("button_press_event", self.on_press)
self.cidrelease = canvas.mpl_connect(
"button_release_event", self.on_release)
self.cidmotion = canvas.mpl_connect(
"motion_notify_event", self.on_motion)
self.cidkeypress = canvas.mpl_connect(
"key_press_event", self.on_key_press)
self.connected = True
[docs]
def disconnect(self):
"""Disconnect event listeners.
"""
self.circle = None
listeners = [
self.cidpress, self.cidrelease, self.cidmotion, self.cidenter,
self.cidleave, self.cidkeypress]
canvas = self.axes.figure.canvas
for listener in listeners:
if listener:
canvas.mpl_disconnect(listener)
self.connected = False
[docs]
def update_coord(self, coord=None):
"""Update the displayed image for the given coordinates.
Scroll to the given z-plane if changed and draw crosshairs to
indicated the corresponding ``x,y`` values.
Args:
coord (List[int]): Coordinates in `z,y,x`, assumed to be transposed
so the z-plane is show in this editor; defaults to None
to use :attr:`coord`.
"""
update_overview = True
if coord is not None:
# no need to refresh z-plane if same coordinate
update_overview = self.coord is None or coord[0] != self.coord[0]
self.coord = coord
if update_overview:
self.show_overview()
if self._show_crosslines:
self.draw_crosslines()
[docs]
def translate_coord(
self, coord: Sequence[int], up: bool = False,
coord_slice: Optional[Union[slice, Sequence[slice]]] = None
) -> Sequence[int]:
"""Translate coordinate based on downsampling factor of the main image.
Coordinates sent to and received from the Atlas Editor are assumed to
be in the original image space. All overlaid images are assumed to be
resized to the shape of the main image.
Args:
coord: Coordinates in z,y,x.
up: True to upsample; defaults to False, which adjusts
coordinates for downsampled images.
coord_slice: Slice of each set of coordinates to
transpose. Defaults to None, which gives a slice starting
at 1 so that the z-value will not be adjusted on the
assumption that downsampling is only performed in ``x,y``.
Returns:
List[int]: The translated coordinates.
"""
if coord_slice is None:
coord_slice = slice(1, None)
coord_tr = np.copy(coord)
if up:
coord_tr[coord_slice] = np.multiply(
coord_tr[coord_slice], self._downsample[0])
else:
coord_tr[coord_slice] = np.divide(
coord_tr[coord_slice], self._downsample[0])
coord_tr = list(coord_tr.astype(int))
# print("translated from {} to {}".format(coord, coord_tr))
return coord_tr
[docs]
def draw_crosslines(self, show=None):
"""Draw crosshairs depicting the x and y values in orthogonal viewers.
Args:
show (bool): True to show crosslines, False to make them invisible;
defaults to None to use :attr:`show_crosslines`.
"""
if show is None:
show = self._show_crosslines
# translate coordinate down for any downsampling
coord = self.translate_coord(self.coord)
if self.hline is None:
# draw new crosshairs
self.hline = self.axes.axhline(coord[1], linestyle=":")
self.vline = self.axes.axvline(coord[2], linestyle=":")
else:
# toggle visibility
self.hline.set_visible(show)
self.vline.set_visible(show)
if show:
# update positions of current crosshairs
self.hline.set_ydata([coord[1]])
self.vline.set_xdata([coord[2]])
[docs]
def set_show_label(self, val):
"""Set whether to show labels on hover.
Args:
val (bool): True to show labels, False otherwise.
"""
if not val:
# reset text to trigger a figure refresh
self.region_label.set_text("")
self._show_labels = val
[docs]
def show_labels(self, show: bool = True, **kwargs):
"""Show or remove labels for all regions.
Args:
show: True (default) to show all labels; False to remove them.
kwargs: Arguments passed to
:meth:`magmap.plot_support.ImageOverlayer.annotate_labels`.
"""
if show:
if (len(self._plot_ax_imgs) > 1 and
self._plot_ax_imgs[1] is not None and config.labels_ref
and config.labels_ref.ref_lookup):
# add label annotations
self.overlayer.annotate_labels(
self._plot_ax_imgs[1][0].img, config.labels_ref.ref_lookup,
self.labels_level, **kwargs)
else:
# remove all labels
self.overlayer.remove_labels()
# refresh view
self.axes.figure.canvas.draw_idle()
def _get_img2d(
self, i: int, img: np.ndarray, max_intens: int = 0) -> np.ndarray:
"""Get the 2D image from the given 3D image from the current coordinate.
Scales and downsamples as necessary.
Args:
i: Index of 3D image in sequence of 3D images, assuming
order of ``(main_image, labels_img, borders_img)``.
img: 3D image from which to extract a 2D plane.
max_intens: Number of planes to incorporate for maximum
intensity projection; defaults to 0 to not perform this
projection.
Returns:
2D plane, downsampled if necessary.
"""
z = self.coord[0]
z_scale = 1
if self._img3d_scales[i] is not None:
# rescale z-coordinate based on image scaling to the main image
z_scale = self._img3d_scales[i][0]
z = int(z * z_scale)
num_z = len(img)
if z >= num_z:
# keep index within image planes
z = num_z - 1
# downsample to reduce access time; use same factor for both x and y
# to retain aspect ratio
downsample = self._downsample[i]
img = img[:, ::downsample, ::downsample]
if max_intens:
# max intensity projection (MIP) across the given number of
# planes available
z_stop = z + int(max_intens * z_scale)
if z_stop > num_z:
z_stop = num_z
z_range = np.arange(z, z_stop)
img2d = plot_support.extract_planes(
img[None], z_range, max_intens_proj=True)[0]
else:
# get single plane
img2d = img[z]
return img2d
def _get_ax_imgs(self) -> List["image.AxesImage"]:
    """Collect the axes images from all stored plotted image containers.

    Returns:
        Flat list of the ``ax_img`` of each container in the flattened
        plotted images.

    """
    ax_imgs = []
    for plot_ax_img in libmag.flatten(self._plot_ax_imgs):
        ax_imgs.append(plot_ax_img.ax_img)
    return ax_imgs
[docs]
def show_overview(self):
    """Show the main 2D plane, taken as a z-plane.

    Clears the axes, extracts the plane at the current z-coordinate from
    each available 3D image (main, labels, borders, and extra images),
    overlays them through :attr:`overlayer`, and stores each displayed
    image in a :class:`PlotAxImg`. Display settings from previously
    shown images are carried over: intensity limits, opacity, alpha
    blending, brightness, and contrast for the main image, and opacity
    for the labels and borders images. Also adds or updates the colorbar,
    redraws the stored ROI, restores any saved pan/zoom limits, connects
    event handlers if not yet connected, and recreates the region label.

    """
    self.axes.clear()
    # references to crossline artists are stale once the axes are cleared
    self.hline = None
    self.vline = None
    if self.blitter:
        # remove existing artists in this editor from blitter
        artists = self.blitter.artists
        animated = [self.region_label, self.circle]
        if self._plot_ax_imgs:
            animated.extend(self._get_ax_imgs())
        for artist in animated:
            if artist in artists:
                artists.remove(artist)

    # prep 2D image from main image, assumed to be an intensity image,
    # with settings for each channel within this main image
    imgs2d = [self._get_img2d(0, self.img3d, self.max_intens_proj)]
    self._channels = [config.channel]
    cmaps = [config.cmaps]  # for all channels in image, even if not shown

    # rest of settings are only for channels in config.channel; each list
    # holds one entry per image group, in the same order as ``imgs2d``
    alphas = list(self.alpha_img3d)
    alpha_is_default = True
    alpha_blends = [None]
    shapes = [self._img3d_shapes[0][1:3]]
    vmaxs = [None]
    vmins = [None]
    # brightness and contrast are only tracked for the main image
    brightnesses = [None]
    contrasts = [None]
    if self._plot_ax_imgs:
        # use vmin/vmax from norm values in previously displayed images
        # if available; None specifies auto-scaling
        vmaxs[0] = [p.vmax for p in self._plot_ax_imgs[0]]
        vmins[0] = [p.vmin for p in self._plot_ax_imgs[0]]

        # use opacity, brightness, and contrast from prior images
        alphas[0] = [p.alpha for p in self._plot_ax_imgs[0]]
        alpha_is_default = False
        alpha_blends[0] = [p.alpha_blend for p in self._plot_ax_imgs[0]]
        brightnesses[0] = [p.brightness for p in self._plot_ax_imgs[0]]
        contrasts[0] = [p.contrast for p in self._plot_ax_imgs[0]]

    img2d_lbl = None
    if self.img3d_labels is not None:
        # prep labels with discrete colormap and prior alpha if available
        img2d_lbl = self._get_img2d(1, self.img3d_labels)
        imgs2d.append(img2d_lbl)
        self._channels.append([0])
        cmaps.append(self.cmap_labels)
        alphas.append(
            self._ax_img_labels.get_alpha() if self._ax_img_labels
            else self.alpha)
        alpha_blends.append(None)
        shapes.append(self._img3d_shapes[1][1:3])
        vmaxs.append(None)
        vmins.append(None)

    if self.img3d_borders is not None:
        # prep borders image, which may have an extra channels
        # dimension for multiple sets of borders
        img2d = self._get_img2d(2, self.img3d_borders)
        # NOTE(review): the channel count is taken from the number of
        # dimensions rather than the size of the last axis -- confirm
        # that this is intended
        channels = img2d.ndim if img2d.ndim >= 3 else 1
        for i, channel in enumerate(range(channels - 1, -1, -1)):
            # show first (original) borders image last so that its
            # colormap values take precedence to highlight original bounds
            img_add = img2d[..., channel] if channels > 1 else img2d
            imgs2d.append(img_add)
            self._channels.append([0])
            cmaps.append(self.cmap_borders[channel])

            # get alpha for last corresponding borders plane if available
            ax_img = libmag.get_if_within(self._plot_ax_imgs, 2 + i, None)
            alpha = (ax_img[i].alpha if ax_img else
                     libmag.get_if_within(config.alphas, 2 + i, 1))
            alphas.append(alpha)
            alpha_blends.append(None)

            shapes.append(self._img3d_shapes[2][1:3])
            vmaxs.append(None)
            vmins.append(None)

    if self.img3d_extras is not None:
        for i, img in enumerate(self.img3d_extras):
            # prep additional intensity image; extra images follow the
            # main, labels, and borders images in the pre-computed lists
            imgi = 3 + i
            imgs2d.append(self._get_img2d(imgi, img))
            self._channels.append([0])
            cmaps.append(("Greys",))
            alphas.append(0.4)
            alpha_blends.append(None)
            shapes.append(self._img3d_shapes[imgi][1:3])
            vmaxs.append(None)
            vmins.append(None)

    def make_abs_chls(vals):
        # `overlay_images` requires channel-specific settings to be given
        # for all channels of first (intensity) image in absolute rather
        # than relative position
        if libmag.is_seq(vals[0]):
            # initialize for all channels and slot in channel-specific vals
            vals_chls = [None] * nchls
            for chl, v in zip(config.channel, vals[0]):
                vals_chls[chl] = v
            vals = list(vals)
            vals[0] = vals_chls
        return vals

    # overlay all images and set labels for footer value on mouseover;
    # if first time showing image, need to check for images with single
    # value since they fail to update on subsequent updates for unclear
    # reasons
    nchls = np_io.get_num_channels(self.img3d, True)
    ax_imgs = self.overlayer.overlay_images(
        imgs2d, self._channels, cmaps,
        # expand for all intensity channels
        make_abs_chls(alphas), make_abs_chls(vmins), make_abs_chls(vmaxs),
        check_single=(self._ax_img_labels is None),
        alpha_blends=alpha_blends)

    # add or update colorbar
    colobar_prof = config.roi_profile["colorbar"]
    if self._colorbar:
        self._colorbar.update_normal(ax_imgs[0][0])
    elif colobar_prof:
        # store colorbar since it's tied to the artist, which will be
        # replaced with the next display and cannot be further accessed
        self._colorbar = self.axes.figure.colorbar(
            ax_imgs[0][0], ax=self.axes, **colobar_prof)

    # display coordinates and label values for each image
    self.axes.format_coord = pixel_display.PixelDisplay(
        imgs2d, ax_imgs, shapes, cmap_labels=self.cmap_labels)

    # trigger actual display through slider or call update directly
    if self.plane_slider:
        self.plane_slider.set_val(self.coord[0])
    else:
        self._update_overview(self.coord[0])

    # redraw the stored ROI, if any, since the axes were cleared
    self.show_roi()
    if self.scale_bar:
        plot_support.add_scale_bar(self.axes, self._downsample[0])

    if len(ax_imgs) > 1:
        # store labels axes image separately for frequent access
        self._ax_img_labels = ax_imgs[1][0]

    # store displayed images in the PlotAxImg container class and update
    # displayed brightness/contrast
    self._plot_ax_imgs = []
    for i, imgs in enumerate(ax_imgs):
        plot_ax_imgs = []
        for j, img in enumerate(imgs):
            # for 2D label images, use the original 2D labels, without
            # cmap index conversion
            img_orig = img2d_lbl if i == 1 else None

            # initialized the plotted image storage instance
            plot_ax_img = PlotAxImg(img, img=img_orig)
            if i == 0 and imgs2d:
                # store the original intensity image plane, which may
                # differ from the displayed axes image array
                plot_ax_img.input_img = libmag.get_if_within(imgs2d[0], j)

            if i == 0:
                # specified vmin/vmax, in contrast to the AxesImages's
                # norm, which holds the values used for the displayed image
                plot_ax_img.vmin = libmag.get_if_within(vmins[i], j)
                plot_ax_img.vmax = libmag.get_if_within(vmaxs[i], j)

                # set brightness/contrast; note that this runs before
                # ``rgb`` is stored on the instance below
                self.change_brightness_contrast(
                    plot_ax_img, libmag.get_if_within(brightnesses[i], j),
                    libmag.get_if_within(contrasts[i], j))

            # get alpha from image if default since image overlayer scales
            # the default alpha for each channel
            plot_ax_img.alpha = (
                plot_ax_img.ax_img.get_alpha() if alpha_is_default
                else libmag.get_if_within(alphas[i], j))

            # store rest of settings
            plot_ax_img.alpha_blend = libmag.get_if_within(
                alpha_blends[i], j)
            plot_ax_img.rgb = self.overlayer.rgb

            plot_ax_imgs.append(plot_ax_img)
        self._plot_ax_imgs.append(plot_ax_imgs)

    if self.xlim is not None and self.ylim is not None:
        # restore pan/zoom view
        self.axes.set_xlim(self.xlim)
        self.axes.set_ylim(self.ylim)

    if not self.connected:
        # connect once get AxesImage
        self.connect()

    # text label with color for visibility on axes plus fig background
    self.region_label = self.axes.text(
        0, 0, "", color="k", bbox=dict(facecolor="xkcd:silver", alpha=0.5))
    if self.blitter:
        self.blitter.add_artist(self.region_label)
    # drop the reference to any circle artist from before the redraw
    self.circle = None

    if self.overlayer.labels_annots:
        # regenerate label annotations if previously shown
        self.show_labels()
def _update_overview(self, z_overview_new: int):
"""Update overview plot to the given plane.
Ignores updates if the update function is missing or the plane is
unchanged. In these cases, call :meth:`show_overview` directly.
Args:
z_overview_new: Z-plane index to show.
"""
if self.fn_update_coords and z_overview_new != self.coord[0]:
# move only if step registered and changing position
coord = list(self.coord)
coord[0] = z_overview_new
self.fn_update_coords(coord, self.plane)
[docs]
def update_plane_slider(self, val):
self._update_overview(int(val))
[docs]
def view_subimg(self, offset: Sequence[int], size: Sequence[int]):
"""View a sub-image.
Args:
offset: Sub-image offset in ``y, x``.
size: Sub-image size in ``y, x``.
"""
coord_slice = slice(0, None)
off_trans = self.translate_coord(offset, coord_slice=coord_slice)
size_trans = self.translate_coord(size, coord_slice=coord_slice)
# print("view subimg offset", offset, "translated offset", off_trans,
# "size", size, "translated size", size_trans)
self.axes.set_xlim(off_trans[1], off_trans[1] + size_trans[1] - 1)
ylim = (off_trans[0], off_trans[0] + size_trans[0] - 1)
ylim_orig = self.axes.get_ylim()
if ylim_orig[1] - ylim_orig[0] < 0:
# set "bottom" first (higher y-values) if originally flipped
ylim = ylim[::-1]
self.axes.set_ylim(*ylim)
self.xlim = self.axes.get_xlim()
self.ylim = self.axes.get_ylim()
[docs]
def show_roi(self, offset=None, size=None, preview=False):
"""Show an ROI as an empty rectangular patch.
If ``offset`` and ``size`` cannot be retrieved, no ROI will be shown.
Args:
offset (List[int]): ROI offset in ``y, x``. Defaults to None
to use the saved ROI offset if available.
size (List[int]): ROI size in ``y, x``. Defaults to None
to use the saved ROI size if available.
preview (bool): True if the ROI should be displayed as a preview,
which is lighter and transient, removed when another preview
ROI is displayed. Defaults to False. If False, ROI parameters
for a displayed ROI will be shown.
"""
# translate coordinates if given; otherwise, use stored coordinates
coord_slice = slice(0, None)
if offset is None:
offset = self._roi_offset
else:
offset = self.translate_coord(offset, coord_slice=coord_slice)
if size is None:
size = self._roi_size
else:
size = self.translate_coord(size, coord_slice=coord_slice)
if offset is None or size is None: return
# generate and display ROI patch as empty rectangle, with lighter
# border for preview ROI
linewidth = 1 if preview else 2
linestyle = "--" if preview else "-"
alpha = 0.5 if preview else 1
patch = patches.Rectangle(
offset[::-1], *size[::-1], fill=False, edgecolor="yellow",
linewidth=linewidth, linestyle=linestyle, alpha=alpha)
if preview:
# remove prior preview and store new preview
if self._roi_patch_preview:
self._roi_patch_preview.remove()
self._roi_patch_preview = patch
else:
# store coordinates
self._roi_offset = offset
self._roi_size = size
self.axes.add_patch(patch)
[docs]
def refresh_img3d_labels(self):
"""Replace the displayed labels image with underlying plane's data.
"""
if self._ax_img_labels is not None:
# underlying plane may need to be translated to a linear
# scale for the colormap before display
self._ax_img_labels.set_data(self.cmap_labels.convert_img_labels(
self.img3d_labels[self.coord[0]]))
[docs]
@staticmethod
def get_plot_ax_img(
plot_ax_imgs: Sequence[Sequence["PlotAxImg"]], imgi: int,
channels: Optional[Sequence[Sequence[int]]] = None, chl:
Optional[int] = None) -> Optional["PlotAxImg"]:
"""Get a plotted image based on image group and channel.
Args:
plot_ax_imgs: Plotted image objects, organized as
``[[img0_chl0, img0_chl1, ...], [img1_chl0, ...], ....]``.
imgi: Index of image group in ``plot_ax_imgs``.
channels: List of channel lists corresponding
to ``plot_ax_imgs``; defalts to None to use ``chl`` directly.
chl: Index of channel within the selected image group;
defaults to None to use the first channel.
Returns:
The selected plotted image, or None if image corresponding to
``imgi`` and ``chl`` is not found.
"""
plot_ax_img = None
if plot_ax_imgs and imgi < len(plot_ax_imgs):
if chl is None:
chl = 0
if channels:
# translate channel to index of displayed channels
if chl not in channels[imgi]:
return None
chl = channels[imgi].index(chl)
if plot_ax_imgs[imgi] and chl < len(plot_ax_imgs[imgi]):
# get image channel within given set of images
plot_ax_img = plot_ax_imgs[imgi][chl]
return plot_ax_img
[docs]
def get_displayed_img(self, imgi, chl=None):
"""Get display settings for the given image.
Args:
imgi (int): Index of image group.
chl (int): Index of channel within the group; defaults to None.
Returns:
:obj:`PlotAxImg`: The currently displayed image.
"""
return self.get_plot_ax_img(
self._plot_ax_imgs, imgi, self._channels, chl)
[docs]
@staticmethod
def change_brightness_contrast(
        plot_ax_img: PlotAxImg, brightness: Optional[float],
        contrast: Optional[float]):
    """Apply brightness and contrast to a displayed image.

    All changes are made relative to the original image, so both
    brightness and contrast should be given together. For example, if
    brightness is changed, and later contrast is changed, the contrast
    change will wipe out the brightness change unless the brightness
    value is given again. The displayed array is modified in place and
    the applied values are stored in ``plot_ax_img``.

    Args:
        plot_ax_img: Axes image storage instance.
        brightness: Brightness value, centered on 0. Can be None to
            ignore brightness changes.
        contrast: Contrast value, centered on 1. Can be None to
            ignore contrast changes.

    """
    # get displayed image array, which adjusts dynamically to array changes
    shown = plot_ax_img.ax_img.get_array()

    # get info range from data type, or assume 0-1 for RGB images
    info = libmag.get_dtype_info(shown)
    if plot_ax_img.rgb:
        lo, hi = 0, 1
    else:
        lo, hi = info.min, info.max

    # start from the original image rather than the displayed one
    adjusted = plot_ax_img.img
    if brightness is not None:
        # shift original image array by brightness
        adjusted = np.clip(adjusted + brightness, lo, hi)
        shown[:] = adjusted
        plot_ax_img.brightness = brightness

    if contrast is not None:
        # stretch adjusted image array by contrast, limited to the range
        # of the data type
        adjusted = np.clip(adjusted * contrast, info.min, info.max)
        shown[:] = adjusted
        plot_ax_img.contrast = contrast
[docs]
@staticmethod
def update_plot_ax_img_display(
plot_ax_img: PlotAxImg, minimum: float = np.nan,
maximum: float = np.nan, brightness: Optional[float] = None,
contrast: Optional[float] = None, alpha: Optional[float] = None,
**kwargs) -> Optional[PlotAxImg]:
"""Update plotted image display settings.
Args:
plot_ax_img: Plotted image.
minimum: Vmin; can be None for auto setting; defaults
to ``np.nan`` to ignore.
maximum: Vmax; can be None for auto setting; defaults
to ``np.nan`` to ignore.
brightness: Brightness addend; defaults to None.
contrast: Contrast multiplier; defaults to None.
alpha: Opacity value; defalts to None.
**kwargs: Extra arguments, currently ignored.
Returns:
The updated axes image plot.
"""
if not plot_ax_img:
return None
if minimum is not np.nan or maximum is not np.nan:
# store the specified intensity limits
if minimum is not np.nan:
plot_ax_img.vmin = minimum
if maximum is not np.nan:
plot_ax_img.vmax = maximum
# set vmin and vmax; use norm rather than get_clim since norm
# holds the actual limits
clim = [minimum, maximum]
norm = plot_ax_img.ax_img.norm
for j, (lim, ax_lim) in enumerate(zip(
clim, (norm.vmin, norm.vmax))):
if lim is np.nan:
# default to using current value
clim[j] = ax_lim
if None not in clim and clim[0] > clim[1]:
# ensure min is <= max, setting to the currently adjusted val
if minimum is np.nan:
clim[0] = clim[1]
else:
clim[1] = clim[0]
# directly update norm rather than clim so that None vals are
# not skipped for auto-scaling
norm.vmin, norm.vmax = clim
plot_ax_img.ax_img.autoscale_None()
if norm.vmin > norm.vmax:
# auto-scaling may cause vmin to exceed vmax
if minimum is np.nan:
norm.vmin = norm.vmax
else:
norm.vmax = norm.vmin
if brightness is not None or contrast is not None:
# adjust brightness and contrast together
PlotEditor.change_brightness_contrast(
plot_ax_img, brightness, contrast)
if alpha is not None:
# adjust and store opacity
try:
plot_ax_img.ax_img.set_alpha(alpha)
except ValueError:
# WORKAROUND: matplotlib v3.8 give an error when stored alpha
# is an array and new alpha is a scalar
plot_ax_img.ax_img._alpha = None
plot_ax_img.ax_img.set_alpha(alpha)
plot_ax_img.alpha = alpha
return plot_ax_img
[docs]
def update_alpha_blend(
self, imgi: int, alpha_blend: float) -> List[PlotAxImg]:
"""Update alpha blending between two images.
Args:
imgi: Index of image group. The first two images in this group
will be blended.
alpha_blend: Alpha opacity level for the blending.
Returns:
List of the updated axes image storage instances.
"""
# get the first two channels in the image group
plot_ax_imgs = [self.get_displayed_img(imgi, c) for c in (0, 1)]
if None in plot_ax_imgs:
return plot_ax_imgs
# blend the images
alpha1, alpha2 = plot_support.alpha_blend_intersection(
plot_ax_imgs[0].img, plot_ax_imgs[1].img, alpha_blend)
plot_ax_imgs[0].ax_img.set_alpha(alpha1)
plot_ax_imgs[1].ax_img.set_alpha(alpha2)
# store the alpha blend level while retaining the original stored alpha
for p in plot_ax_imgs: p.alpha_blend = alpha_blend
return plot_ax_imgs
[docs]
def update_img_display(
        self, imgi: int, chl: Optional[int] = None, minimum: float = np.nan,
        maximum: float = np.nan, brightness: Optional[float] = None,
        contrast: Optional[float] = None, alpha: Optional[float] = None,
        alpha_blend: Optional[float] = None) -> PlotAxImg:
    """Apply display settings to an image and request a canvas redraw.

    Args:
        imgi: Index of image group.
        chl: Index of channel within the group; defaults to None.
        minimum: Vmin; can be None for auto setting; defaults
            to ``np.nan`` to ignore.
        maximum: Vmax; can be None for auto setting; defaults
            to ``np.nan`` to ignore.
        brightness: Brightness addend; defaults to None.
        contrast: Contrast multiplier; defaults to None.
        alpha: Opacity value; defaults to None.
        alpha_blend: Opacity blending value; defaults to None. False
            turns off alpha blending, resetting the first two images to
            their stored alpha values. Any other falsy value leaves
            blending untouched and updates only the selected image.

    Returns:
        The updated axes image plot. When blending is turned on, this is
        the first blended image; when turned off, it is the result for
        channel 1.

    """
    if alpha_blend is False:
        # blending is assumed to involve the first 2 images; clear their
        # blend level and re-apply each image's own stored opacity
        updated = None
        for blend_chl in (0, 1):
            updated = self.get_displayed_img(imgi, blend_chl)
            if updated is None:
                continue
            updated.alpha_blend = None
            updated = self.update_img_display(
                imgi, blend_chl, minimum, maximum, brightness, contrast,
                updated.alpha)
    elif alpha_blend:
        # blend the first two images of the group
        updated = self.update_alpha_blend(imgi, alpha_blend)[0]
    else:
        # update only the selected image
        updated = self.update_plot_ax_img_display(
            self.get_displayed_img(imgi, chl), minimum, maximum,
            brightness, contrast, alpha)

    self.axes.figure.canvas.draw_idle()
    return updated
[docs]
def alpha_updater(self, alpha):
    """Store a new labels opacity and apply it to the labels image.

    Args:
        alpha (float): Alpha level.

    """
    self.alpha = alpha
    labels_img = self._ax_img_labels
    if labels_img is None:
        # no labels image is displayed, so only the value is stored
        return
    labels_img.set_alpha(self.alpha)
@staticmethod
def _is_pan(event: "backend_bases.MouseEvent") -> bool:
    """Check if a mouse event is for panning navigation.

    Panning is a single (non-double) click with either the middle button
    or the left button while "shift" is held.
    """
    if event.dblclick:
        return False
    return event.button == 2 or (
        event.button == 1 and event.key == "shift")
@staticmethod
def _is_zoom(event: "backend_bases.MouseEvent") -> bool:
    """Check if a mouse event is for zooming navigation.

    Zooming is a single (non-double) click with either the right button
    or the left button while "control" is held.
    """
    if event.dblclick:
        return False
    return event.button == 3 or (
        event.button == 1 and event.key == "control")
[docs]
def on_press(self, event):
"""Respond to mouse press events.

Stores the press position in data and figure coordinates and re-adds
the axes images to the blitter when the press starts a pan or zoom.
For left-button presses in edit mode with a labels image, the
painting intensity is chosen; otherwise a press without a key in
``_KEY_MODIFIERS`` moves the crosshairs to the pressed position. A
press with the "3" key passes the labels image value at the position
to ``fn_show_label_3d`` when that handler and the labels image are set.

Args:
event: Mouse press event; ignored unless it occurred in this
editor's axes.
"""
if event.inaxes != self.axes: return
# truncate the data position to integer pixel indices
x = int(event.xdata)
y = int(event.ydata)
self.press_loc_data = (x, y)
self.last_loc_data = tuple(self.press_loc_data)
self.last_loc = (int(event.x), int(event.y))
# re-translate downsampled coordinates back up
coord = self.translate_coord([self.coord[0], y, x], up=True)
if (self._is_pan(event) or self._is_zoom(event)) and self.blitter:
# add axes images to blitter for navigation
for ax_img in self._get_ax_imgs():
artists = self.blitter.artists
if ax_img in artists:
# drop any existing entry before the image is added again
artists.remove(ax_img)
self.blitter.add_artist(ax_img)
if event.button == 1:
if self.edit_mode and self.img3d_labels is not None:
# label painting in edit mode
if event.key is not None and "alt" in event.key:
# "alt" keeps the intensity from an earlier pick
print("using previously picked intensity instead,",
self.intensity)
elif self.intensity_spec is None:
# click while in editing mode to initialize intensity value
# for painting, using values at current position for
# the underlying and displayed images
self.intensity = self.img3d_labels[tuple(coord)]
self.intensity_shown = self._ax_img_labels.get_array()[y, x]
print("got intensity {} at x,y,z = {},{},{}"
.format(self.intensity, *coord[::-1]))
if self.fn_update_intensity:
# trigger text box update
self.fn_update_intensity(self.intensity)
else:
# apply user-specified intensity value; assume it will
# be reset on mouse release
print("using specified intensity of", self.intensity_spec)
self.intensity = self.intensity_spec
self.intensity_shown = self.cmap_labels.convert_img_labels(
self.intensity)
elif event.key not in self._KEY_MODIFIERS:
# click without modifiers to update crosshairs and
# corresponding planes
self.coord = coord
if self.fn_update_coords:
# trigger provided handler
self.fn_update_coords(self.coord, self.plane)
# NOTE(review): identity check only matches when the threshold is the
# ``np.inf`` object itself; an equal infinite float from another source
# (e.g. ``float("inf")``) would not match -- confirm how
# ``label_motion_thresh`` is assigned
if self.label_motion_thresh is np.inf:
# update region label on click if set to ignore motion
self._update_region_label(x, y, coord)
self._redraw_animated()
if event.key == "3" and self.fn_show_label_3d is not None:
if self.img3d_labels is not None:
# extract label ID and display in 3D viewer
self.fn_show_label_3d(self.img3d_labels[tuple(coord)])
[docs]
def on_axes_exit(self, event):
    """Clear the pen circle and region label when the mouse leaves the axes.

    Args:
        event: Axes exit event; ignored unless it is for this editor's
            axes.

    """
    if event.inaxes != self.axes:
        return

    pen_circle = self.circle
    if pen_circle and not self.blitter:
        # a circle managed by the blitter is left in place
        pen_circle.remove()
        self.circle = None

    label = self.region_label
    if label:
        # blank out the region label text
        label.set_text("")
def _redraw_animated(self):
    """Refresh the plot, blitting when a blitter is set up.

    Without a blitter, a full redraw of the figure canvas is requested
    instead.
    """
    if not self.blitter:
        # need explicit draw call for figs embedded in TraitsUI
        self.axes.figure.canvas.draw_idle()
        return
    # blit updates for more efficient rendering
    self.blitter.update()
def _update_region_label(self, x: int, y: int, coord: Sequence[int]):
    """Set the text and placement of the region label near the mouse.

    Args:
        x: Mouse x-value.
        y: Mouse y-value.
        coord: Coordinates translated to region space.

    """
    label_text = ""
    if self._show_labels:
        # look up the atlas label for the labels image value under the
        # mouse pointer
        found = ontology.get_label(
            coord, self.img3d_labels, config.labels_ref.ref_lookup,
            self.scaling, self.labels_level)
        if found is not None:
            # show both the label name and its ID
            label_name = ontology.get_label_name(found)
            label_id = ontology.get_label_item(
                found, config.ABAKeys.ABA_ID.value)
            label_text = "{} ({})".format(label_name, label_id)
            _logger.debug("Found label: %s", label_text)

    # word-wrap and flip sides at the plot midlines to reduce the chance of
    # text overflowing the axes; offsets are in axes coordinates so that the
    # distance from the pointer stays consistent across zoom levels
    label_text = "\n".join(textwrap.wrap(label_text, 30))
    offset = self._region_label_offset
    ax_x, ax_y = self.axes.transLimits.transform((x, y))

    # horizontal shift
    if ax_x > 0.5:
        halign, ax_x = "right", ax_x - offset
    else:
        halign, ax_x = "left", ax_x + offset

    # vertical shift
    if ax_y > 0.5:
        valign, ax_y = "top", ax_y - offset
    else:
        valign, ax_y = "bottom", ax_y + offset

    # apply alignment, then place the label after converting the shifted
    # position back to data coordinates
    label = self.region_label
    label.set_horizontalalignment(halign)
    label.set_verticalalignment(valign)
    to_data = self.axes.transLimits.inverted()
    label.set_position(to_data.transform((ax_x, ax_y)))
    label.set_text(label_text)
[docs]
def on_motion(self, event):
"""Handle motion events, including navigation and label editing.

Movements are throttled by the product of the distance moved in figure
coordinates and the time elapsed since the last handled event, compared
against ``nav_motion_thresh`` for pan/zoom and ``label_motion_thresh``
otherwise. Panning shifts the axes limits by the offset in data
coordinates. Zooming moves each axes limit toward or away from the
press location in proportion to the vertical movement in figure
coordinates, skipping updates that would flip the limits. Any other
movement within the image bounds updates the pen circle, paints labels
or moves the crosshairs while the left button is held, and updates the
region label. The status bar handler, if set, receives the formatted
coordinate message.

Args:
event: Mouse motion event; ignored unless it occurred in this
editor's axes.
"""
if event.inaxes != self.axes: return
# get mouse position and return if no change from last pixel coord
x = int(event.xdata)
y = int(event.ydata)
x_fig = int(event.x)
y_fig = int(event.y)
# current state
loc = (x_fig, y_fig)
loc_data = (x, y)
curr_time = time.perf_counter()
# get action type
pan = self._is_pan(event)
zoom = self._is_zoom(event)
# copy last state since it may change from other threads
last_loc = self.last_loc
last_loc_data = self.last_loc_data
last_time = self._last_time
if last_loc is not None:
# skip movements that are fast and short; all movements within the
# same px will be skipped if threshold is non-neg
# NOTE(review): assumes ``_last_time`` holds a number whenever
# ``last_loc`` is set -- TODO confirm against where both are initialized
time_diff = curr_time - last_time
dist = np.hypot(*np.subtract(last_loc, loc))
# product of distance and elapsed time, so a move must be both short
# and quick to fall at or below a small threshold
movt = dist * time_diff
if pan or zoom:
# navigation threshold
if movt <= self.nav_motion_thresh: return
else:
# region label threshold
if movt <= self.label_motion_thresh: return
# update instance state
self.last_loc = loc
self.last_loc_data = loc_data
self._last_time = curr_time
if pan:
# pan by middle-click or shift+left-click during mouseover
# use data coordinates so same part of image stays under mouse
if last_loc_data is None:
last_loc_data = loc_data
dx = x - last_loc_data[0]
dy = y - last_loc_data[1]
# data itself moved, so update saved location for this movement
loc_data = (x - dx, y - dy)
self.last_loc_data = loc_data
# update axes view for the movement
xlim = self.axes.get_xlim()
self.axes.set_xlim(xlim[0] - dx, xlim[1] - dx)
ylim = self.axes.get_ylim()
self.axes.set_ylim(ylim[0] - dy, ylim[1] - dy)
self._redraw_animated()
# store the resulting view limits on the editor
self.xlim = self.axes.get_xlim()
self.ylim = self.axes.get_ylim()
elif zoom:
# zooming by right-click or ctrl+click (which converts button event
# to 3 on Mac at least) while moving mouse up/down in y
# use figure coordinates since data pixels will scale during zoom
# NOTE(review): assumes ``last_loc`` and ``press_loc_data`` were set by a
# preceding press -- TODO confirm a zoom motion cannot arrive without one
zoom_speed = (y_fig - last_loc[1]) * 0.01
# update axes view for the zoom
xlim = self.axes.get_xlim()
xlim_update = (
xlim[0] + (self.press_loc_data[0] - xlim[0]) * zoom_speed,
xlim[1] + (self.press_loc_data[0] - xlim[1]) * zoom_speed)
ylim = self.axes.get_ylim()
ylim_update = (
ylim[0] + (self.press_loc_data[1] - ylim[0]) * zoom_speed,
ylim[1] + (self.press_loc_data[1] - ylim[1]) * zoom_speed)
# avoid flip by checking that relationship between high and low
# values in updated limits is in the same order as in the current
# limits, which might otherwise flip if zoom speed is high
if ((xlim_update[1] - xlim_update[0]) * (xlim[1] - xlim[0]) > 0 and
(ylim_update[1] - ylim_update[0]) *
(ylim[1] - ylim[0]) > 0):
self.axes.set_xlim(xlim_update)
self.axes.set_ylim(ylim_update)
self._redraw_animated()
# store the resulting view limits on the editor
self.xlim = self.axes.get_xlim()
self.ylim = self.axes.get_ylim()
else:
# hover movements over image
if 0 <= x < self.img3d.shape[2] and 0 <= y < self.img3d.shape[1]:
if self.enable_painting and self.edit_mode:
if self.circle:
# update pen circle position
self.circle.center = x, y
# does not appear necessary since text update already
# triggers redraw, but would also trigger if no update
self.circle.stale = True
else:
# generate new circle if not yet present
self.circle = patches.Circle(
(x, y), radius=self.radius, linestyle=":",
fill=False, edgecolor="w")
self.axes.add_patch(self.circle)
if self.blitter:
self.blitter.add_artist(self.circle)
# re-translate downsampled coordinates to original space
coord = self.translate_coord([self.coord[0], y, x], up=True)
if event.button == 1:
if (self.enable_painting and self.edit_mode
and self.intensity is not None):
# click in editing mode to overwrite images with pen
# of the current radius using chosen intensity for the
# underlying and displayed images
if self._ax_img_labels is not None:
# edit displayed image
img = self._ax_img_labels.get_array()
rr, cc = draw.disk(
(y, x), self.radius, shape=img.shape)
img[rr, cc] = self.intensity_shown
# edit underlying labels image
# pen radius is scaled by the first downsampling factor for the
# underlying image
rr, cc = draw.disk(
(coord[1], coord[2]),
self.radius * self._downsample[0],
shape=self.img3d_labels[self.coord[0]].shape)
self.img3d_labels[
self.coord[0], rr, cc] = self.intensity
_logger.debug(
"Changed intensity at x,y,z = %s to %s",
coord[::-1], self.intensity)
# refresh through the provided handler if set, otherwise only this
# editor's labels image
if self.fn_refresh_images is None:
self.refresh_img3d_labels()
else:
self.fn_refresh_images(self, True)
# flag the edit; ``_editing`` is read and reset in ``on_release``
self.edited = True
self._editing = True
else:
# mouse left-click drag (separate from the initial
# press) moves crosshairs
self.coord = coord
if self.fn_update_coords:
self.fn_update_coords(self.coord, self.plane)
# NOTE(review): identity check against ``np.inf`` only excludes the
# ``np.inf`` object itself, not other equal infinite floats
if (self.img3d_labels is not None and
config.labels_ref is not None and
config.labels_ref.ref_lookup and
self.label_motion_thresh is not np.inf):
# update region label unless motion thresh is set to ignore
self._update_region_label(x, y, coord)
self._redraw_animated()
if self.fn_status_bar:
# pass the formatted coordinate message to the status bar handler
self.fn_status_bar(self.axes.format_coord.get_msg(event))
[docs]
def on_release(self, event):
"""Respond to mouse button release events.
If labels were edited during the current mouse press, update
plane interpolation values. After a pan or zoom with a blitter, remove
the axes images from the blitter's artists, keeping the labels image
while in edit mode. Also reset any specified intensity value.
Args:
event: Mouse button release event.
"""
if self._editing:
if self.interp_planes is not None:
# register the edited plane, position, and intensity for interpolation
self.interp_planes.update_plane(
self.plane, self.coord[0], self.intensity)
self._editing = False
if (self._is_pan(event) or self._is_zoom(event)) and self.blitter:
# remove axes images except labels image in editing mode
for ax_img in self._get_ax_imgs():
artists = self.blitter.artists
if ax_img in artists and not (
self.edit_mode and ax_img is self._ax_img_labels):
artists.remove(ax_img)
# redraw to refresh any last event
self.axes.figure.canvas.draw_idle()
if self.intensity_spec is not None:
# reset all plot editors' specified intensity to allow updating
# with clicked intensities
self.intensity_spec = None
if self.fn_update_intensity:
# notify the handler that no intensity is specified
self.fn_update_intensity(None)
[docs]
def on_key_press(self, event):
    """Change pen radius with bracket ([/]) buttons.

    The "ctrl" modifier will halve the increment. The radius is only
    decreased while it is above 1. Events without a key value are ignored.

    Args:
        event: Key press event.

    """
    key = event.key
    if key is None:
        # Matplotlib may report no key value, in which case the substring
        # checks below would raise a TypeError
        return

    rad_orig = self.radius
    increment = 0.5 if "ctrl" in key else 1
    if "[" in key and self.radius > 1:
        self.radius -= increment
    elif "]" in key:
        self.radius += increment

    if rad_orig != self.radius and self.circle:
        # keep the displayed pen circle in sync with the new radius
        self.circle.radius = self.radius