diff --git a/tomotwin/modules/inference/locator.py b/tomotwin/modules/inference/locator.py index ac71db8..0a41dae 100644 --- a/tomotwin/modules/inference/locator.py +++ b/tomotwin/modules/inference/locator.py @@ -76,12 +76,9 @@ def nms(boxes: pd.DataFrame, size: int, nms_threshold=0.6) -> pd.DataFrame: boxes_i_rep = ones * box_i ious = Locator._bbox_iou_vec_3d(boxes_i_rep, boxes_data[close_indicis]) - iou_mask = np.empty(len(boxes_data), dtype=int) - iou_mask_close = ious > nms_threshold - - iou_mask[close_indicis] = iou_mask_close - iou_mask[i] = 0 # ignore current - iou_mask = iou_mask == 1 + iou_mask = np.zeros(len(boxes_data), dtype=bool) + iou_mask[close_indicis] = ious > nms_threshold + iou_mask[i] = False boxes_data[iou_mask, 6] = 0