diff --git a/neat_ml/bubblesam/bubblesam.py b/neat_ml/bubblesam/bubblesam.py index bdc3ee2..880b0e7 100644 --- a/neat_ml/bubblesam/bubblesam.py +++ b/neat_ml/bubblesam/bubblesam.py @@ -131,7 +131,9 @@ def analyze_and_filter_masks( if len(props_list) == 0: continue - rp = props_list[0] + # take the region properties from the segmentation map with the greatest area + rp_areas = [x.area for x in props_list] + rp = props_list[np.argmax(rp_areas)] area = rp.area perimeter = rp.perimeter if perimeter == 0: @@ -141,19 +143,18 @@ def analyze_and_filter_masks( major_axis = rp.major_axis_length minor_axis = rp.minor_axis_length h, w = seg.shape[:2] - # Using a small margin (2 pixels) to be safe + # Using a small margin (2 pixels) to be safe, + # filter any segmentations with bounding boxes close to the size of the image + # because SAM-2 can sometimes detect the image background itself. + bbox_area = (rp.bbox[2] - rp.bbox[0]) * (rp.bbox[3] - rp.bbox[1]) max_allowed_area = (h - 2) * (w - 2) - if area >= area_threshold and circ >= circularity_threshold: + if (area >= area_threshold and circ >= circularity_threshold + and bbox_area < max_allowed_area): binary_mask = seg.astype('uint8') * 255 contours, _ = cv2.findContours(binary_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) - # reshape contours for plotting and remove any contours - # close to the size of the image because cv2.findContours - # can sometimes detect the image edge itself. - all_contours = [ - c.reshape(-1, 2)[:, ::-1] - for c in contours - if cv2.contourArea(c) < max_allowed_area - ] + # keep only the largest contour in each segmentation area + # and reshape for plotting + max_contour = max(contours, key=cv2.contourArea).squeeze(axis=1) radius = np.sqrt(area / np.pi) euler_number = rp.euler_number # output of cucim ``rp`` stores values as objects @@ -164,7 +165,7 @@ def analyze_and_filter_masks( euler_number = euler_number.item() mask_info = { 'bbox': rp.bbox, - 'contour': all_contours, + 'contour': max_contour, 'major_axis': major_axis, 'minor_axis': minor_axis, 'area': area, @@ -202,7 +203,7 @@ def plot_filtered_masks( for idx, row in masks_summary_df.iterrows(): contour = row['contour'] bbox = row['bbox'] - ax.plot(contour[0][:, 1], contour[0][:, 0], linewidth=1, color='blue') + ax.plot(contour[:, 0], contour[:, 1], linewidth=1, color='blue') min_row, min_col, max_row, max_col = bbox rect = Rectangle( (min_col, min_row), @@ -271,7 +272,7 @@ def bubblesam_detection( ) # save filtered dataframe as parquet file - # convert ``contours`` and ``bbox`` columns to list to save as parquet + # convert ``contour`` and ``bbox`` columns to list to save as parquet save_filtered_df = filtered_df.copy() save_filtered_df["bbox"] = save_filtered_df["bbox"].apply(list) save_filtered_df["contour"] = save_filtered_df["contour"].apply( diff --git a/neat_ml/tests/test_bubblesam.py b/neat_ml/tests/test_bubblesam.py index caf6d29..6322e65 100644 --- a/neat_ml/tests/test_bubblesam.py +++ b/neat_ml/tests/test_bubblesam.py @@ -318,3 +318,36 @@ def test_run_bubblesam_model_cfg_error(): """ with pytest.raises(ValueError, match="Must provide model configuration"): run_bubblesam(pd.DataFrame(), Path("output"), detection_cfg={}) + +@pytest.mark.parametrize("seg_params, exp_bbox", + [ + # a test case where the segmentation contains two disjoint areas + ([[50, 60], [40, 45]], (50, 50, 60, 60)), + # a test case where the segmentation contains a region that touches + # the image boundary at the bottom right corner + ([[90, 100]], (90, 90, 100, 100)), + ] +) +def test_bubblesam_contours(seg_params, exp_bbox): + """ + test that running `analyze_and_filter_masks` generates a dataframe with + only a single contour per detection and without background areas + """ + # create two segmentation maps, one that takes up the whole image (background) + # and one containing the segmentation map generated using the test case parameters + seg = np.ones((100, 100)).astype(bool) + seg2 = np.zeros((100, 100)).astype(bool) + for seg_param in seg_params: + start = seg_param[0] + end = seg_param[1] + seg2[start:end, start:end] = True + input_df = pd.DataFrame({"segmentation": [seg, seg2]}) + # call `analyze_and_filter_masks` to return filtered dataframe + # (the circularity of a perfect square is ~0.8, so lower the + # circularity threshold so that the background only gets filtered + # out by the bounding box area) + df = analyze_and_filter_masks(input_df, 25, 0.7, device="cpu") + # assert that there is only a single dataframe row after filtration + # corresponding to the appropriate segmentation map to keep from `seg2` + assert df.bbox.item() == exp_bbox + assert df.contour.item().shape == (36, 2)