Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
530 changes: 530 additions & 0 deletions .github/workflows/export-nuitka-onnx.yml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ output
pretrained/
workspace
workspace/
build_nuitka_onnx/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
1 change: 1 addition & 0 deletions gui/TIPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Reset memory if needed.
Controls:

- Use left-click for foreground annotation and right-click for background annotation.
- With the SAM2 ONNX click backend, use shift + left-drag to draw a rectangle box prompt.
- Use number keys or the spinbox to change the object to be operated on. If it does not respond, most likely the correct number of objects was not specified during program startup.
- Use left/right arrows to move between frames, shift+arrow to move by 10 frames, and alt/option+arrow to move to the start/end.
- Use F/space and B to propagate forward and backward, respectively.
Expand Down
28 changes: 27 additions & 1 deletion gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def __init__(self, controller, cfg: DictConfig) -> None:
# callbacks to be set by the controller
self.on_mouse_motion_xy = None
self.click_fn = None
self.box_prompt_start_fn = None
self.box_prompt_update_fn = None
self.box_prompt_end_fn = None
self._dragging_box_prompt = False
self._box_prompt_start = None

self.controller = controller
self.cfg = cfg
Expand Down Expand Up @@ -441,6 +446,16 @@ def on_mouse_press(self, event):
return

ex, ey = self.get_scaled_pos(event.position().x(), event.position().y())
if (
event.button() == Qt.MouseButton.LeftButton
and event.modifiers() & Qt.KeyboardModifier.ShiftModifier
and self.box_prompt_start_fn is not None
):
self._dragging_box_prompt = True
self._box_prompt_start = (ex, ey)
self.box_prompt_start_fn(ex, ey)
return

if event.button() == Qt.MouseButton.LeftButton:
action = 'left'
elif event.button() == Qt.MouseButton.RightButton:
Expand All @@ -453,9 +468,20 @@ def on_mouse_press(self, event):
def on_mouse_motion(self, event):
ex, ey = self.get_scaled_pos(event.position().x(), event.position().y())
self.on_mouse_motion_xy(ex, ey)
if self._dragging_box_prompt and self.box_prompt_update_fn is not None:
self.box_prompt_update_fn(ex, ey)

def on_mouse_release(self, event):
pass
if not self._dragging_box_prompt:
return
if event.button() != Qt.MouseButton.LeftButton:
return

self._dragging_box_prompt = False
ex, ey = self.get_scaled_pos(event.position().x(), event.position().y())
if self.box_prompt_end_fn is not None:
self.box_prompt_end_fn(ex, ey)
self._box_prompt_start = None

def on_play_video(self):
if self.timer.isActive():
Expand Down
64 changes: 64 additions & 0 deletions gui_onnx/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# ONNX GUI Backend

This folder provides an ONNX-backed VOS processor for `interactive_demo_onnx.py` while keeping the same GUI workflow.

## Required ONNX files

- `weights/cutie_image_encoder.onnx`
- `weights/cutie_memory_write.onnx`
- `weights/cutie_read_decode.onnx`
- `weights/ritm_no_brs.onnx`
- Optional for SAM2-assisted first-frame clicks:
- `weights/sam2.1_hiera_small.encoder.onnx`
- `weights/sam2.1_hiera_small.decoder.onnx`

## Export commands

```bash
python -m scripts.export_onnx --output weights/cutie_image_encoder.onnx --weights weights/cutie-base-mega.pth --height 480 --width 864 --opset 18
python -m scripts.export_onnx_pipeline --weights weights/cutie-base-mega.pth --output-dir weights --use-dynamo --height 480 --width 864 --num-objects 1 --memory-frames 1
python -m scripts.export_ritm_onnx --weights weights/coco_lvis_h18_itermask.pth --output weights/ritm_no_brs.onnx --height 480 --width 864 --opset 18
python -m samexporter.export_sam2 --checkpoint weights/sam2.1_hiera_small.pt --output_encoder weights/sam2.1_hiera_small.encoder.onnx --output_decoder weights/sam2.1_hiera_small.decoder.onnx --model_type sam2.1_hiera_small
```

## Run GUI with ONNX backend

```bash
python interactive_demo_onnx.py --images <path_to_images>
python interactive_demo_onnx.py --video <path_to_video>
python interactive_demo_onnx.py <path_to_video>

# ONNX click model (NoBRS)
python interactive_demo_onnx.py --images <path_to_images>

# SAM2 clicks for first-frame object selection, Cutie ONNX for propagation
python interactive_demo_onnx.py \
--images <path_to_images> \
--click_backend_model sam2

# In the GUI:
# - left click: positive point
# - right click: negative point
# - shift + left-drag: rectangle box prompt for SAM2
```

Optional custom model paths:

```bash
python interactive_demo_onnx.py --ritm_onnx <ritm_no_brs.onnx> \
--onnx_encoder <encoder.onnx> \
--onnx_memory_write <memory_write.onnx> \
--onnx_read_decode <read_decode.onnx> \
--images <path_to_images>
```

## Notes

- Current `gui_onnx` backend is single-object oriented and should be run with `--num_objects 1`.
- `gui_onnx` reads limits from the ONNX files at runtime (not hardcoded in code).
- With `--click_backend_model sam2`, clicks only affect first-frame mask generation; Cutie still handles all temporal tracking/memory.
- The GitHub Actions release workflow exports and bundles the SAM2.1 small encoder/decoder into `weights/`.
- If your current ONNX was exported with small capacities, re-export with larger values:
- `--num-objects` controls max object count for `memory_write/read_decode`.
- `--memory-frames` controls temporal memory length accepted by `read_decode`.
- `num_objects` in GUI must be `<=` ONNX-exported object capacity.
9 changes: 9 additions & 0 deletions gui_onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
__all__ = ["MainControllerOnnxNumpy"]


def __getattr__(name):
if name == "MainControllerOnnxNumpy":
from .main_controller import MainControllerOnnxNumpy

return MainControllerOnnxNumpy
raise AttributeError(name)
125 changes: 125 additions & 0 deletions gui_onnx/click_controller_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

from typing import List, Tuple

import cv2
import numpy as np


class ClickControllerOnnxNumpy:
def __init__(
self,
onnx_path: str,
device: str = "cpu",
max_clicks: int = 8,
click_radius: int = 5,
with_flip: bool = True,
):
try:
import onnxruntime as ort
except ImportError as exc:
raise ModuleNotFoundError(
"onnxruntime is required for the ONNX click backend."
) from exc

providers = ["CPUExecutionProvider"]
if device == "cuda":
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

self.session = ort.InferenceSession(onnx_path, providers=providers)
self.input_names = [item.name for item in self.session.get_inputs()]
if len(self.input_names) < 2:
raise RuntimeError("RITM ONNX must expose image and coord_features inputs.")
self.image_input = self.input_names[0]
self.coord_input = self.input_names[1]
image_shape = self.session.get_inputs()[0].shape
self.with_prev_mask = int(image_shape[1]) == 4
self.max_clicks = int(max_clicks)
self.click_radius = int(click_radius)
self.with_flip = bool(with_flip)

self.anchored = False
self._image_np = None
self._initial_prev_mask = None
self._prev_prediction = None
self._clicks: List[Tuple[int, int, bool]] = []

def unanchor(self):
self.anchored = False
self._image_np = None
self._initial_prev_mask = None
self._prev_prediction = None
self._clicks = []

def _build_coord_features(self, h: int, w: int) -> np.ndarray:
pos = np.zeros((h, w), dtype=np.float32)
neg = np.zeros((h, w), dtype=np.float32)
for x, y, is_pos in self._clicks[-self.max_clicks :]:
if is_pos:
cv2.circle(pos, (x, y), self.click_radius, 1.0, thickness=-1)
else:
cv2.circle(neg, (x, y), self.click_radius, 1.0, thickness=-1)
return np.stack([pos, neg], axis=0)[None, ...]

def _infer_logits(self, image_np: np.ndarray, prev_np: np.ndarray, coord_np: np.ndarray) -> np.ndarray:
if self.with_prev_mask:
image_in = np.concatenate([image_np, prev_np], axis=1)
else:
image_in = image_np
return self.session.run(
None,
{
self.image_input: image_in.astype(np.float32),
self.coord_input: coord_np.astype(np.float32),
},
)[0].astype(np.float32)

def _run(self) -> np.ndarray:
prev = self._prev_prediction
if prev is None:
prev = self._initial_prev_mask
if prev is None:
prev = np.zeros((1, 1, self._image_np.shape[-2], self._image_np.shape[-1]), dtype=np.float32)

coord = self._build_coord_features(self._image_np.shape[-2], self._image_np.shape[-1])
logits = self._infer_logits(self._image_np, prev, coord)

if self.with_flip:
logits_flip = self._infer_logits(
np.flip(self._image_np, axis=-1).copy(),
np.flip(prev, axis=-1).copy(),
np.flip(coord, axis=-1).copy(),
)
logits = 0.5 * (logits + np.flip(logits_flip, axis=-1).copy())

self._prev_prediction = 1.0 / (1.0 + np.exp(-logits))
return self._prev_prediction[:, 0]

def interact(
self,
image: np.ndarray,
x: int,
y: int,
is_positive: bool,
prev_mask: np.ndarray | None,
) -> np.ndarray:
if not self.anchored:
self._image_np = image.astype(np.float32, copy=True)
self._initial_prev_mask = None if prev_mask is None else prev_mask.astype(np.float32, copy=True)
self._prev_prediction = None
self._clicks = []
self.anchored = True

self._clicks.append((int(x), int(y), bool(is_positive)))
return self._run()

def undo(self):
if len(self._clicks) == 0:
return None
self._clicks.pop()
if len(self._clicks) == 0:
self._prev_prediction = None
return None
self._prev_prediction = None
pred = self._run()
return (pred > 0.5).astype(np.float32)
70 changes: 70 additions & 0 deletions gui_onnx/interaction_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

from typing import Tuple

import numpy as np

from .click_controller_numpy import ClickControllerOnnxNumpy
from .sam2_click_controller_numpy import Sam2ClickControllerOnnxNumpy
from .interactive_utils_numpy import aggregate_wbg


class ClickInteractionOnnx:
def __init__(
self,
image: np.ndarray,
prev_mask: np.ndarray,
true_size: Tuple[int, int],
controller: ClickControllerOnnxNumpy | Sam2ClickControllerOnnxNumpy,
tar_obj: int,
):
self.image = image
self.prev_mask = prev_mask
self.controller = controller
self.h, self.w = true_size
self.tar_obj = tar_obj
self.first_click = True
self.obj_mask = None
self.out_prob = self.prev_mask.copy()

def push_point(self, x: int, y: int, is_neg: bool) -> None:
if self.first_click:
last_obj_mask = self.prev_mask[self.tar_obj : self.tar_obj + 1][None, ...]
self.obj_mask = self.controller.interact(
self.image[None, ...],
x,
y,
not is_neg,
prev_mask=last_obj_mask,
)
self.first_click = False
else:
self.obj_mask = self.controller.interact(
self.image[None, ...],
x,
y,
not is_neg,
prev_mask=None,
)

def set_box(self, x0: int, y0: int, x1: int, y1: int) -> None:
if not hasattr(self.controller, "set_box"):
raise NotImplementedError("Current click controller does not support box prompts.")

last_obj_mask = self.prev_mask[self.tar_obj : self.tar_obj + 1][None, ...]
self.obj_mask = self.controller.set_box(
self.image[None, ...],
x0,
y0,
x1,
y1,
prev_mask=last_obj_mask,
)
self.first_click = False

def predict(self) -> np.ndarray:
self.out_prob = self.prev_mask.copy()
self.out_prob = np.clip(self.out_prob, a_min=None, a_max=0.9)
self.out_prob[self.tar_obj] = self.obj_mask[0]
self.out_prob = aggregate_wbg(self.out_prob[1:], keep_bg=True, hard=True)
return self.out_prob.astype(np.float32)
Loading