Skip to content

Commit 530492b

Browse files
authored
Merge pull request #49 from BiAPoL/fix_show_overlay
Fix show overlay
2 parents 233336f + bb777c0 commit 530492b

9 files changed

Lines changed: 175 additions & 84 deletions

File tree

docs/plotter_api.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
1212
~CanvasWidget.active_artist
1313
~CanvasWidget.active_selector
14+
~CanvasWidget.show_color_overlay
1415
1516
.. rubric:: Methods Summary
1617
@@ -21,19 +22,20 @@
2122
~CanvasWidget.add_selector
2223
~CanvasWidget.remove_selector
2324
~CanvasWidget.on_enable_selector
24-
~CanvasWidget.hide_color_overlay
2525
2626
.. rubric:: Signals Summary
2727
2828
.. autosummary::
2929
3030
~CanvasWidget.artist_changed_signal
3131
~CanvasWidget.selector_changed_signal
32+
~CanvasWidget.show_overlay_signal
3233
3334
.. rubric:: Properties Documentation
3435
3536
.. autoattribute:: active_artist
3637
.. autoattribute:: active_selector
38+
.. autoattribute:: show_color_overlay
3739
3840
.. rubric:: Methods Documentation
3941
@@ -42,10 +44,10 @@
4244
.. automethod:: add_selector
4345
.. automethod:: remove_selector
4446
.. automethod:: on_enable_selector
45-
.. automethod:: hide_color_overlay
4647
4748
.. rubric:: Signals Documentation
4849
4950
.. autoattribute:: artist_changed_signal
5051
.. autoattribute:: selector_changed_signal
52+
.. autoattribute:: show_overlay_signal
5153
```

src/biaplotter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.2.0"
1+
__version__ = "0.3.0"
22
from .artists import Histogram2D, Scatter
33
from .colormap import BiaColormap
44
from .plotter import CanvasWidget

src/biaplotter/_tests/test_artists.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def on_color_indices_changed(color_indices):
5050
assert scatter.color_indices.shape == (size,)
5151

5252
# Test scatter colors
53-
colors = scatter._mpl_artists['scatter'].get_facecolors()
53+
colors = scatter._mpl_artists["scatter"].get_facecolors()
5454
assert np.all(colors[0] == scatter.overlay_colormap(0))
5555
assert np.all(colors[50] == scatter.overlay_colormap(2))
5656

@@ -69,27 +69,27 @@ def on_color_indices_changed(color_indices):
6969
# Test size property
7070
scatter.size = 5.0
7171
assert scatter.size == 5.0
72-
sizes = scatter._mpl_artists['scatter'].get_sizes()
72+
sizes = scatter._mpl_artists["scatter"].get_sizes()
7373
assert np.all(sizes == 5.0)
7474

7575
scatter.size = np.linspace(1, 10, size)
7676
assert np.all(scatter.size == np.linspace(1, 10, size))
77-
sizes = scatter._mpl_artists['scatter'].get_sizes()
77+
sizes = scatter._mpl_artists["scatter"].get_sizes()
7878
assert np.all(sizes == np.linspace(1, 10, size))
7979

8080
# Test size reset when new data is set
8181
scatter.data = np.random.rand(size // 2, 2)
8282
assert np.all(scatter.size == 50.0) # that's the default
83-
sizes = scatter._mpl_artists['scatter'].get_sizes()
83+
sizes = scatter._mpl_artists["scatter"].get_sizes()
8484
assert np.all(sizes == 50.0)
8585

8686
# test alpha
8787
scatter.alpha = 0.5
88-
assert np.all(scatter._mpl_artists['scatter'].get_alpha() == 0.5)
88+
assert np.all(scatter._mpl_artists["scatter"].get_alpha() == 0.5)
8989

9090
# test alpha reset when new data is set
9191
scatter.data = np.random.rand(size, 2)
92-
assert np.all(scatter._mpl_artists['scatter'].get_alpha() == 1.0)
92+
assert np.all(scatter._mpl_artists["scatter"].get_alpha() == 1.0)
9393

9494
# Test changing overlay_colormap
9595
assert scatter.overlay_colormap.name == "cat10_modified"
@@ -98,7 +98,7 @@ def on_color_indices_changed(color_indices):
9898

9999
# Test scatter color indices after continuous overlay_colormap
100100
scatter.color_indices = np.linspace(0, 1, size)
101-
colors = scatter._mpl_artists['scatter'].get_facecolors()
101+
colors = scatter._mpl_artists["scatter"].get_facecolors()
102102
assert np.all(colors[0] == plt.cm.viridis(0))
103103

104104
# Test scatter color_normalization_method
@@ -181,7 +181,9 @@ def on_color_indices_changed(color_indices):
181181
assert histogram.cmin == 0
182182

183183
# Test overlay colors
184-
overlay_array = histogram._mpl_artists['overlay_histogram_image'].get_array()
184+
overlay_array = histogram._mpl_artists[
185+
"overlay_histogram_image"
186+
].get_array()
185187
assert overlay_array.shape == (bins, bins, 4)
186188
# indices where overlay_array is not zero
187189
indices = np.where(overlay_array[..., -1] != 0)
@@ -229,7 +231,7 @@ def on_color_indices_changed(color_indices):
229231

230232
# Don't draw overlay histogram if color_indices are nan
231233
histogram.color_indices = np.nan
232-
assert 'overlay_histogram_image' not in histogram._mpl_artists.keys()
234+
assert "overlay_histogram_image" not in histogram._mpl_artists.keys()
233235

234236

235237
# Test calculate_statistic_histogram_method for different statistics

src/biaplotter/_tests/test_widget.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,12 @@ def test_disable_all_selectors(canvas_widget):
8484
assert selector._selector is None
8585

8686

87-
def test_hide_color_overlay(canvas_widget):
88-
"""Test the hide_color_overlay method."""
89-
canvas_widget.hide_color_overlay(True)
87+
def test_show_color_overlay(canvas_widget):
88+
"""Test the show_color_overlay method."""
89+
canvas_widget.show_color_overlay = False
9090
assert not canvas_widget.active_artist.overlay_visible
9191

92-
canvas_widget.hide_color_overlay(False)
92+
canvas_widget.show_color_overlay = True
9393
assert canvas_widget.active_artist.overlay_visible
9494

9595

src/biaplotter/artists.py

Lines changed: 49 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
SymLogNorm)
99
from nap_plot_tools.cmap import (cat10_mod_cmap,
1010
cat10_mod_cmap_first_transparent)
11+
from scipy.stats import binned_statistic_2d
1112

1213
from biaplotter.colormap import BiaColormap
13-
from scipy.stats import binned_statistic_2d
14+
1415
from .artists_base import Artist
1516

1617

@@ -52,8 +53,6 @@ class Scatter(Artist):
5253
>>> plt.show()
5354
"""
5455

55-
56-
5756
def __init__(
5857
self,
5958
ax: plt.Axes = None,
@@ -74,16 +73,17 @@ def __init__(
7473
def _refresh(self, force_redraw: bool = True):
7574
"""Creates the scatter plot with the data and default properties."""
7675

77-
if force_redraw or self._mpl_artists['scatter'] is None:
76+
if force_redraw or self._mpl_artists["scatter"] is None:
7877
self._remove_artists()
7978
# Create a new scatter plot with the updated data
80-
self._mpl_artists['scatter'] = self.ax.scatter(
81-
self._data[:, 0], self._data[:, 1])
79+
self._mpl_artists["scatter"] = self.ax.scatter(
80+
self._data[:, 0], self._data[:, 1]
81+
)
8282
self.size = 50 # Default size
8383
self.alpha = 1 # Default alpha
8484
self.color_indices = 0
8585
else:
86-
self._mpl_artists['scatter'].set_offsets(
86+
self._mpl_artists["scatter"].set_offsets(
8787
self._data
8888
) # somehow resets the size and alpha
8989
self.color_indices = self._color_indices
@@ -95,22 +95,21 @@ def _colorize(self, indices: np.ndarray):
9595
Add a color to the drawn scatter points
9696
"""
9797
rgba_colors = self.color_indices_to_rgba(indices)
98-
self._mpl_artists['scatter'].set_facecolor(rgba_colors)
99-
self._mpl_artists['scatter'].set_edgecolor("white")
98+
self._mpl_artists["scatter"].set_facecolor(rgba_colors)
99+
self._mpl_artists["scatter"].set_edgecolor("white")
100100

101101
return rgba_colors
102102

103103
def color_indices_to_rgba(
104-
self,
105-
indices: np.ndarray,
106-
is_overlay: bool = True) -> np.ndarray:
104+
self, indices: np.ndarray, is_overlay: bool = True
105+
) -> np.ndarray:
107106
"""
108107
Convert color indices to RGBA colors using the colormap.
109108
"""
110109
norm = self._get_normalization(indices)
111110
colormap = self.overlay_colormap.cmap
112111

113-
rgba = colormap(norm(self._color_indices))
112+
rgba = colormap(norm(indices))
114113
return rgba
115114

116115
def _get_normalization(self, values: np.ndarray) -> Normalize:
@@ -127,7 +126,8 @@ def _get_normalization(self, values: np.ndarray) -> Normalize:
127126
}
128127

129128
normalization_func = norm_dispatch.get(
130-
self._color_normalization_method)
129+
self._color_normalization_method
130+
)
131131
if normalization_func is None:
132132
raise ValueError(
133133
f"Unknown color normalization method: {self._color_normalization_method}.\n"
@@ -165,7 +165,11 @@ def overlay_visible(self) -> bool:
165165
def overlay_visible(self, value: bool):
166166
"""Sets the visibility of the overlay colormap."""
167167
self._overlay_visible = value
168-
self._colorize(self._color_indices)
168+
if value:
169+
self._colorize(self._color_indices)
170+
else:
171+
self._colorize(np.zeros_like(self._color_indices))
172+
self.draw()
169173

170174
@property
171175
def color_normalization_method(self) -> str:
@@ -193,7 +197,7 @@ def alpha(self) -> Union[float, np.ndarray]:
193197
alpha : float
194198
alpha value of the scatter plot.
195199
"""
196-
return self._mpl_artists['scatter'].get_alpha()
200+
return self._mpl_artists["scatter"].get_alpha()
197201

198202
@alpha.setter
199203
def alpha(self, value: Union[float, np.ndarray]):
@@ -202,8 +206,8 @@ def alpha(self, value: Union[float, np.ndarray]):
202206

203207
if np.isscalar(value):
204208
value = np.ones(len(self._data)) * value
205-
if 'scatter' in self._mpl_artists.keys():
206-
self._mpl_artists['scatter'].set_alpha(value)
209+
if "scatter" in self._mpl_artists.keys():
210+
self._mpl_artists["scatter"].set_alpha(value)
207211
self.draw()
208212

209213
@property
@@ -223,8 +227,8 @@ def size(self) -> Union[float, np.ndarray]:
223227
def size(self, value: Union[float, np.ndarray]):
224228
"""Sets the size of the points in the scatter plot."""
225229
self._size = value
226-
if 'scatter' in self._mpl_artists.keys():
227-
self._mpl_artists['scatter'].set_sizes(
230+
if "scatter" in self._mpl_artists.keys():
231+
self._mpl_artists["scatter"].set_sizes(
228232
np.full(len(self._data), value)
229233
if np.isscalar(value)
230234
else value
@@ -305,14 +309,14 @@ def _refresh(self, force_redraw: bool = True):
305309
self._histogram_rgba = self.color_indices_to_rgba(
306310
counts.T, is_overlay=False
307311
)
308-
self._mpl_artists['histogram_image'] = self.ax.imshow(
312+
self._mpl_artists["histogram_image"] = self.ax.imshow(
309313
self._histogram_rgba,
310314
extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
311315
origin="lower",
312316
zorder=1,
313317
interpolation=self._histogram_interpolation,
314318
alpha=1,
315-
aspect='auto'
319+
aspect="auto",
316320
)
317321

318322
if force_redraw:
@@ -328,33 +332,35 @@ def _colorize(self, indices: np.ndarray):
328332
_, x_edges, y_edges = self._histogram
329333
# Assign median values to the bins (fill with NaNs if no data in the bin)
330334
statistic_histogram, _, _, _ = binned_statistic_2d(
331-
x = self._data[:, 0],
332-
y= self._data[:, 1],
335+
x=self._data[:, 0],
336+
y=self._data[:, 1],
333337
values=indices,
334338
statistic=_median_np,
335-
bins=[x_edges, y_edges]
339+
bins=[x_edges, y_edges],
336340
)
337341
if not np.all(np.isnan(statistic_histogram)):
338342
# Draw the overlay
339343
self.overlay_histogram_rgba = self.color_indices_to_rgba(
340344
statistic_histogram.T, is_overlay=True
341345
)
342-
self._mpl_artists['overlay_histogram_image'] = self.ax.imshow(
346+
self._mpl_artists["overlay_histogram_image"] = self.ax.imshow(
343347
self.overlay_histogram_rgba,
344348
extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
345349
origin="lower",
346350
zorder=2,
347351
interpolation=self._overlay_interpolation,
348352
alpha=self._overlay_opacity,
349-
aspect='auto'
353+
aspect="auto",
350354
)
351355

352-
def color_indices_to_rgba(self, indices, is_overlay: bool = True) -> np.ndarray:
356+
def color_indices_to_rgba(
357+
self, indices, is_overlay: bool = True
358+
) -> np.ndarray:
353359
"""
354360
Convert color indices to RGBA colors using the overlay colormap.
355361
"""
356362
norm = self._get_normalization(indices, is_overlay=is_overlay)
357-
363+
358364
if is_overlay:
359365
colormap = self.overlay_colormap.cmap
360366
else:
@@ -430,7 +436,7 @@ def cmin(self) -> int:
430436
minimum count for the histogram.
431437
"""
432438
return self._cmin
433-
439+
434440
@cmin.setter
435441
def cmin(self, value: int):
436442
"""Sets the minimum count for the histogram."""
@@ -524,8 +530,8 @@ def overlay_visible(self):
524530
def overlay_visible(self, value):
525531
"""Sets the visibility of the overlay histogram."""
526532
self._overlay_visible = value
527-
if 'overlay_histogram_image' in self._mpl_artists:
528-
self._mpl_artists['overlay_histogram_image'].set_visible(value)
533+
if "overlay_histogram_image" in self._mpl_artists:
534+
self._mpl_artists["overlay_histogram_image"].set_visible(value)
529535
self.draw()
530536

531537
@property
@@ -602,9 +608,8 @@ def _is_categorical_colormap(self, colormap):
602608
return False
603609

604610
def _get_normalization(
605-
self,
606-
values: np.ndarray,
607-
is_overlay: bool = True) -> Normalize:
611+
self, values: np.ndarray, is_overlay: bool = True
612+
) -> Normalize:
608613
"""
609614
Get the normalization class for the histogram data.
610615
@@ -634,17 +639,19 @@ def _get_normalization(
634639
# norm_dispatch is to be indexed like this:
635640
# norm_dispatch[is_categorical, color_normalization_method]
636641
norm_dispatch = {
637-
(True, 'linear'): lambda: self._linear_normalization(values, is_categorical),
638-
(False, 'linear'): lambda: self._linear_normalization(values),
639-
(False, 'log'): lambda: self._log_normalization(values),
640-
(False, 'centered'): lambda: self._centered_normalization(values),
641-
(False, 'symlog'): lambda: self._symlog_normalization(values),
642+
(True, "linear"): lambda: self._linear_normalization(
643+
values, is_categorical
644+
),
645+
(False, "linear"): lambda: self._linear_normalization(values),
646+
(False, "log"): lambda: self._log_normalization(values),
647+
(False, "centered"): lambda: self._centered_normalization(values),
648+
(False, "symlog"): lambda: self._symlog_normalization(values),
642649
}
643650

644651
return norm_dispatch.get((is_categorical, norm_method))()
645652

646653

647-
def _median_np(arr, method='lower') -> float:
654+
def _median_np(arr, method="lower") -> float:
648655
"""Calculate the median of a 1D array.
649656
650657
Parameters
@@ -661,4 +668,4 @@ def _median_np(arr, method='lower') -> float:
661668
"""
662669
if len(arr) == 0:
663670
return np.nan
664-
return np.nanpercentile(arr, 50, method=method)
671+
return np.nanpercentile(arr, 50, method=method)

0 commit comments

Comments
 (0)