88 SymLogNorm )
99from nap_plot_tools .cmap import (cat10_mod_cmap ,
1010 cat10_mod_cmap_first_transparent )
11+ from scipy .stats import binned_statistic_2d
1112
1213from biaplotter .colormap import BiaColormap
13- from scipy . stats import binned_statistic_2d
14+
1415from .artists_base import Artist
1516
1617
@@ -52,8 +53,6 @@ class Scatter(Artist):
5253 >>> plt.show()
5354 """
5455
55-
56-
5756 def __init__ (
5857 self ,
5958 ax : plt .Axes = None ,
@@ -74,16 +73,17 @@ def __init__(
7473 def _refresh (self , force_redraw : bool = True ):
7574 """Creates the scatter plot with the data and default properties."""
7675
77- if force_redraw or self ._mpl_artists [' scatter' ] is None :
76+ if force_redraw or self ._mpl_artists [" scatter" ] is None :
7877 self ._remove_artists ()
7978 # Create a new scatter plot with the updated data
80- self ._mpl_artists ['scatter' ] = self .ax .scatter (
81- self ._data [:, 0 ], self ._data [:, 1 ])
79+ self ._mpl_artists ["scatter" ] = self .ax .scatter (
80+ self ._data [:, 0 ], self ._data [:, 1 ]
81+ )
8282 self .size = 50 # Default size
8383 self .alpha = 1 # Default alpha
8484 self .color_indices = 0
8585 else :
86- self ._mpl_artists [' scatter' ].set_offsets (
86+ self ._mpl_artists [" scatter" ].set_offsets (
8787 self ._data
8888 ) # somehow resets the size and alpha
8989 self .color_indices = self ._color_indices
@@ -95,22 +95,21 @@ def _colorize(self, indices: np.ndarray):
9595 Add a color to the drawn scatter points
9696 """
9797 rgba_colors = self .color_indices_to_rgba (indices )
98- self ._mpl_artists [' scatter' ].set_facecolor (rgba_colors )
99- self ._mpl_artists [' scatter' ].set_edgecolor ("white" )
98+ self ._mpl_artists [" scatter" ].set_facecolor (rgba_colors )
99+ self ._mpl_artists [" scatter" ].set_edgecolor ("white" )
100100
101101 return rgba_colors
102102
103103 def color_indices_to_rgba (
104- self ,
105- indices : np .ndarray ,
106- is_overlay : bool = True ) -> np .ndarray :
104+ self , indices : np .ndarray , is_overlay : bool = True
105+ ) -> np .ndarray :
107106 """
108107 Convert color indices to RGBA colors using the colormap.
109108 """
110109 norm = self ._get_normalization (indices )
111110 colormap = self .overlay_colormap .cmap
112111
113- rgba = colormap (norm (self . _color_indices ))
112+ rgba = colormap (norm (indices ))
114113 return rgba
115114
116115 def _get_normalization (self , values : np .ndarray ) -> Normalize :
@@ -127,7 +126,8 @@ def _get_normalization(self, values: np.ndarray) -> Normalize:
127126 }
128127
129128 normalization_func = norm_dispatch .get (
130- self ._color_normalization_method )
129+ self ._color_normalization_method
130+ )
131131 if normalization_func is None :
132132 raise ValueError (
133133 f"Unknown color normalization method: { self ._color_normalization_method } .\n "
@@ -165,7 +165,11 @@ def overlay_visible(self) -> bool:
165165 def overlay_visible (self , value : bool ):
166166 """Sets the visibility of the overlay colormap."""
167167 self ._overlay_visible = value
168- self ._colorize (self ._color_indices )
168+ if value :
169+ self ._colorize (self ._color_indices )
170+ else :
171+ self ._colorize (np .zeros_like (self ._color_indices ))
172+ self .draw ()
169173
170174 @property
171175 def color_normalization_method (self ) -> str :
@@ -193,7 +197,7 @@ def alpha(self) -> Union[float, np.ndarray]:
193197 alpha : float
194198 alpha value of the scatter plot.
195199 """
196- return self ._mpl_artists [' scatter' ].get_alpha ()
200+ return self ._mpl_artists [" scatter" ].get_alpha ()
197201
198202 @alpha .setter
199203 def alpha (self , value : Union [float , np .ndarray ]):
@@ -202,8 +206,8 @@ def alpha(self, value: Union[float, np.ndarray]):
202206
203207 if np .isscalar (value ):
204208 value = np .ones (len (self ._data )) * value
205- if ' scatter' in self ._mpl_artists .keys ():
206- self ._mpl_artists [' scatter' ].set_alpha (value )
209+ if " scatter" in self ._mpl_artists .keys ():
210+ self ._mpl_artists [" scatter" ].set_alpha (value )
207211 self .draw ()
208212
209213 @property
@@ -223,8 +227,8 @@ def size(self) -> Union[float, np.ndarray]:
223227 def size (self , value : Union [float , np .ndarray ]):
224228 """Sets the size of the points in the scatter plot."""
225229 self ._size = value
226- if ' scatter' in self ._mpl_artists .keys ():
227- self ._mpl_artists [' scatter' ].set_sizes (
230+ if " scatter" in self ._mpl_artists .keys ():
231+ self ._mpl_artists [" scatter" ].set_sizes (
228232 np .full (len (self ._data ), value )
229233 if np .isscalar (value )
230234 else value
@@ -305,14 +309,14 @@ def _refresh(self, force_redraw: bool = True):
305309 self ._histogram_rgba = self .color_indices_to_rgba (
306310 counts .T , is_overlay = False
307311 )
308- self ._mpl_artists [' histogram_image' ] = self .ax .imshow (
312+ self ._mpl_artists [" histogram_image" ] = self .ax .imshow (
309313 self ._histogram_rgba ,
310314 extent = [x_edges [0 ], x_edges [- 1 ], y_edges [0 ], y_edges [- 1 ]],
311315 origin = "lower" ,
312316 zorder = 1 ,
313317 interpolation = self ._histogram_interpolation ,
314318 alpha = 1 ,
315- aspect = ' auto'
319+ aspect = " auto" ,
316320 )
317321
318322 if force_redraw :
@@ -328,33 +332,35 @@ def _colorize(self, indices: np.ndarray):
328332 _ , x_edges , y_edges = self ._histogram
329333 # Assign median values to the bins (fill with NaNs if no data in the bin)
330334 statistic_histogram , _ , _ , _ = binned_statistic_2d (
331- x = self ._data [:, 0 ],
332- y = self ._data [:, 1 ],
335+ x = self ._data [:, 0 ],
336+ y = self ._data [:, 1 ],
333337 values = indices ,
334338 statistic = _median_np ,
335- bins = [x_edges , y_edges ]
339+ bins = [x_edges , y_edges ],
336340 )
337341 if not np .all (np .isnan (statistic_histogram )):
338342 # Draw the overlay
339343 self .overlay_histogram_rgba = self .color_indices_to_rgba (
340344 statistic_histogram .T , is_overlay = True
341345 )
342- self ._mpl_artists [' overlay_histogram_image' ] = self .ax .imshow (
346+ self ._mpl_artists [" overlay_histogram_image" ] = self .ax .imshow (
343347 self .overlay_histogram_rgba ,
344348 extent = [x_edges [0 ], x_edges [- 1 ], y_edges [0 ], y_edges [- 1 ]],
345349 origin = "lower" ,
346350 zorder = 2 ,
347351 interpolation = self ._overlay_interpolation ,
348352 alpha = self ._overlay_opacity ,
349- aspect = ' auto'
353+ aspect = " auto" ,
350354 )
351355
352- def color_indices_to_rgba (self , indices , is_overlay : bool = True ) -> np .ndarray :
356+ def color_indices_to_rgba (
357+ self , indices , is_overlay : bool = True
358+ ) -> np .ndarray :
353359 """
354360 Convert color indices to RGBA colors using the overlay colormap.
355361 """
356362 norm = self ._get_normalization (indices , is_overlay = is_overlay )
357-
363+
358364 if is_overlay :
359365 colormap = self .overlay_colormap .cmap
360366 else :
@@ -430,7 +436,7 @@ def cmin(self) -> int:
430436 minimum count for the histogram.
431437 """
432438 return self ._cmin
433-
439+
434440 @cmin .setter
435441 def cmin (self , value : int ):
436442 """Sets the minimum count for the histogram."""
@@ -524,8 +530,8 @@ def overlay_visible(self):
524530 def overlay_visible (self , value ):
525531 """Sets the visibility of the overlay histogram."""
526532 self ._overlay_visible = value
527- if ' overlay_histogram_image' in self ._mpl_artists :
528- self ._mpl_artists [' overlay_histogram_image' ].set_visible (value )
533+ if " overlay_histogram_image" in self ._mpl_artists :
534+ self ._mpl_artists [" overlay_histogram_image" ].set_visible (value )
529535 self .draw ()
530536
531537 @property
@@ -602,9 +608,8 @@ def _is_categorical_colormap(self, colormap):
602608 return False
603609
604610 def _get_normalization (
605- self ,
606- values : np .ndarray ,
607- is_overlay : bool = True ) -> Normalize :
611+ self , values : np .ndarray , is_overlay : bool = True
612+ ) -> Normalize :
608613 """
609614 Get the normalization class for the histogram data.
610615
@@ -634,17 +639,19 @@ def _get_normalization(
634639 # norm_dispatch is to be indexed like this:
635640 # norm_dispatch[is_categorical, color_normalization_method]
636641 norm_dispatch = {
637- (True , 'linear' ): lambda : self ._linear_normalization (values , is_categorical ),
638- (False , 'linear' ): lambda : self ._linear_normalization (values ),
639- (False , 'log' ): lambda : self ._log_normalization (values ),
640- (False , 'centered' ): lambda : self ._centered_normalization (values ),
641- (False , 'symlog' ): lambda : self ._symlog_normalization (values ),
642+ (True , "linear" ): lambda : self ._linear_normalization (
643+ values , is_categorical
644+ ),
645+ (False , "linear" ): lambda : self ._linear_normalization (values ),
646+ (False , "log" ): lambda : self ._log_normalization (values ),
647+ (False , "centered" ): lambda : self ._centered_normalization (values ),
648+ (False , "symlog" ): lambda : self ._symlog_normalization (values ),
642649 }
643650
644651 return norm_dispatch .get ((is_categorical , norm_method ))()
645652
646653
647- def _median_np (arr , method = ' lower' ) -> float :
654+ def _median_np (arr , method = " lower" ) -> float :
648655 """Calculate the median of a 1D array.
649656
650657 Parameters
@@ -661,4 +668,4 @@ def _median_np(arr, method='lower') -> float:
661668 """
662669 if len (arr ) == 0 :
663670 return np .nan
664- return np .nanpercentile (arr , 50 , method = method )
671+ return np .nanpercentile (arr , 50 , method = method )
0 commit comments