Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions neat_ml/bubblesam/bubblesam.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def analyze_and_filter_masks(
if len(props_list) == 0:
continue

rp = props_list[0]
# take the region properties from the segmentation map with the greatest area
rp_areas = [x.area for x in props_list]
rp = props_list[np.argmax(rp_areas)]
area = rp.area
perimeter = rp.perimeter
if perimeter == 0:
Expand All @@ -141,19 +143,18 @@ def analyze_and_filter_masks(
major_axis = rp.major_axis_length
minor_axis = rp.minor_axis_length
h, w = seg.shape[:2]
# Using a small margin (2 pixels) to be safe
# Using a small margin (2 pixels) to be safe,
# filter any segmentations with bounding boxes close to the size of the image
# because SAM-2 can sometimes detect the image background itself.
bbox_area = (rp.bbox[2] - rp.bbox[0]) * (rp.bbox[3] - rp.bbox[1])
Comment thread
tylerjereddy marked this conversation as resolved.
max_allowed_area = (h - 2) * (w - 2)
if area >= area_threshold and circ >= circularity_threshold:
if (area >= area_threshold and circ >= circularity_threshold
and bbox_area < max_allowed_area):
binary_mask = seg.astype('uint8') * 255
contours, _ = cv2.findContours(binary_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
# reshape contours for plotting and remove any contours
# close to the size of the image because cv2.findContours
# can sometimes detect the image edge itself.
all_contours = [
c.reshape(-1, 2)[:, ::-1]
for c in contours
if cv2.contourArea(c) < max_allowed_area
]
# keep only the largest contour in each segmentation area
# and reshape for plotting
max_contour = max(contours, key=cv2.contourArea).squeeze(axis=1)
radius = np.sqrt(area / np.pi)
euler_number = rp.euler_number
# output of cucim ``rp`` stores values as objects
Expand All @@ -164,7 +165,7 @@ def analyze_and_filter_masks(
euler_number = euler_number.item()
mask_info = {
'bbox': rp.bbox,
'contour': all_contours,
'contour': max_contour,
'major_axis': major_axis,
'minor_axis': minor_axis,
'area': area,
Expand Down Expand Up @@ -202,7 +203,7 @@ def plot_filtered_masks(
for idx, row in masks_summary_df.iterrows():
contour = row['contour']
bbox = row['bbox']
ax.plot(contour[0][:, 1], contour[0][:, 0], linewidth=1, color='blue')
ax.plot(contour[:, 0], contour[:, 1], linewidth=1, color='blue')
min_row, min_col, max_row, max_col = bbox
rect = Rectangle(
(min_col, min_row),
Expand Down Expand Up @@ -271,7 +272,7 @@ def bubblesam_detection(
)

# save filtered dataframe as parquet file
# convert ``contours`` and ``bbox`` columns to list to save as parquet
# convert ``contour`` and ``bbox`` columns to list to save as parquet
save_filtered_df = filtered_df.copy()
save_filtered_df["bbox"] = save_filtered_df["bbox"].apply(list)
save_filtered_df["contour"] = save_filtered_df["contour"].apply(
Expand Down
33 changes: 33 additions & 0 deletions neat_ml/tests/test_bubblesam.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,36 @@ def test_run_bubblesam_model_cfg_error():
"""
with pytest.raises(ValueError, match="Must provide model configuration"):
run_bubblesam(pd.DataFrame(), Path("output"), detection_cfg={})

@pytest.mark.parametrize("seg_params, exp_bbox",
[
# a test case where the segmentation contains two disjoint areas
([[50, 60], [40, 45]], (50, 50, 60, 60)),
# a test case where the segmentation contains a region that touches
# the image boundary at the bottom right corner
([[90, 100]], (90, 90, 100, 100)),
]
)
def test_bubblesam_contours(seg_params, exp_bbox):
"""
test that running `analyze_and_filter_masks` generates a dataframe with
only a single contour per detection and without background areas
"""
# create two segmentation maps, one that takes up the whole image (background)
# and one containing the segmentation map generated using the test case parameters
seg = np.ones((100, 100)).astype(bool)
seg2 = np.zeros((100, 100)).astype(bool)
for seg_param in seg_params:
start = seg_param[0]
end = seg_param[1]
seg2[start:end, start:end] = True
input_df = pd.DataFrame({"segmentation": [seg, seg2]})
# call `analyze_and_filter_masks` to return filtered dataframe
# (the circularity of a perfect square is ~0.8, so lower the
# circularity threshold so that the background only gets filtered
# out by the bounding box area)
df = analyze_and_filter_masks(input_df, 25, 0.7, device="cpu")
# assert that there is only a single dataframe row after filtration
# corresponding to the appropriate segmentation map to keep from `seg2`
assert df.bbox.item() == exp_bbox
assert df.contour.item().shape == (36, 2)