Skip to content

ascota_core.mobilesam_wand

The mobilesam_wand module provides an interactive MobileSAM session API for point-prompt segmentation.

It supports:

  • lazy model loading with checkpoint path resolution,
  • per-image predictor setup (set_image), and
  • iterative multi-point mask prediction (predict) for foreground/background edits.

MobileSAM session API for interactive point-prompt segmentation.

Loads the model once; call set_image once per frame, then predict() with multi-point prompts without reloading weights or re-encoding the image.

MobileSamSession

MobileSamSession()

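Holds one encoded image at a time. Each session serializes its own set_image/predict calls with an internal lock. The underlying predictor is shared process-wide, though, so use only one session (one worker) at a time.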

Source code in src/ascota_core/mobilesam_wand.py
81
82
83
def __init__(self) -> None:
    """Create a session that has not encoded any image yet."""
    # (height, width) of the most recently encoded image; None until set_image() runs.
    self._image_shape: Optional[Tuple[int, int]] = None
    # Serializes this session's set_image()/predict() calls on the predictor.
    self._lock = threading.Lock()

set_image

set_image(rgb)

Encode image for SAM (call once when the displayed image changes).

Source code in src/ascota_core/mobilesam_wand.py
85
86
87
88
89
90
91
92
93
94
95
96
def set_image(self, rgb: np.ndarray) -> None:
    """Encode image for SAM (call once when the displayed image changes).

    Args:
        rgb: Array-like RGB image of shape (H, W, 3). Integer/float input
            that is not uint8 is saturated to [0, 255] and cast to uint8.
            Values are NOT rescaled, so e.g. a float image in [0, 1] ends
            up (nearly) black.

    Raises:
        ValueError: if the input is not an H x W x 3 array.
    """
    arr = np.asarray(rgb)
    # Validate the shape first so malformed input fails fast, before any copy.
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError(f"Expected HxWx3 uint8 RGB, got shape {arr.shape}")
    if arr.dtype != np.uint8:
        if np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating):
            # A bare astype() wraps out-of-range values (256 -> 0, -1 -> 255);
            # saturate instead so over/under-exposed pixels stay white/black.
            arr = np.clip(arr, 0, 255)
        arr = arr.astype(np.uint8)

    predictor = _load_predictor()
    with self._lock:
        # Encode and record the shape under the session lock so both are
        # updated together.
        predictor.set_image(arr)
        self._image_shape = (arr.shape[0], arr.shape[1])

predict

predict(points)

Run SAM with multiple point prompts. Labels: 1=foreground, 0=background.

Source code in src/ascota_core/mobilesam_wand.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def predict(self, points: List[Point]) -> np.ndarray:
    """Run SAM with multiple point prompts. Labels: 1=foreground, 0=background.

    Args:
        points: Non-empty sequence of ``(x, y, label)`` triples in pixel
            coordinates of the image passed to ``set_image`` (x is checked
            against the width, y against the height).

    Returns:
        The single mask produced with ``multimask_output=False``, cast to
        uint8 (presumably an (H, W) 0/1 mask; confirm against SamPredictor).

    Raises:
        ValueError: if ``points`` is empty or any point lies outside the image.
        RuntimeError: if ``set_image`` has not been called on this session.
    """
    if not points:
        raise ValueError("points must be non-empty")

    with self._lock:
        # set_image() writes _image_shape under this same lock. Holding it from
        # the shape read through the SAM call guarantees the bounds check and
        # the prediction refer to the same encoded image.
        if self._image_shape is None:
            raise RuntimeError("Call set_image() before predict()")
        h, w = self._image_shape

        for i, (x, y, _) in enumerate(points):
            if not (0 <= x < w and 0 <= y < h):
                raise ValueError(f"Point {i} ({x}, {y}) out of bounds for image size {w}x{h}")

        coords = np.array([[p[0], p[1]] for p in points], dtype=np.float32)
        labels = np.array([p[2] for p in points], dtype=np.int64)

        # Fetched only after validation, so invalid calls never trigger a model load.
        predictor = _load_predictor()
        with torch.inference_mode():
            masks, _scores, _logits = predictor.predict(
                point_coords=coords,
                point_labels=labels,
                multimask_output=False,
            )

    return masks[0].astype(np.uint8)

get_mobile_sam_checkpoint_path

get_mobile_sam_checkpoint_path()

Resolve path to mobile_sam.pt (same search order as preprocess service).

Source code in src/ascota_core/mobilesam_wand.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def get_mobile_sam_checkpoint_path() -> Path:
    """Resolve path to mobile_sam.pt (same search order as preprocess service).

    Candidates are checked in order; the first existing file is returned:

    1. ``<repo>/preprocess/backend/weights/mobile_sam.pt``
    2. ``<repo>/weights/mobile_sam.pt``
    3. ``<parent of the mobile_sam package dir>/weights/mobile_sam.pt``. This
       is only tried if ``mobile_sam`` is importable and exposes a file path.

    ``<repo>`` is two directories above this file's directory (presumably the
    repository root in a ``src/`` layout).

    Returns:
        The first existing candidate. If none exists, candidate 1 is returned
        unchecked so callers can report a concrete expected location.
    """
    core_dir = Path(__file__).resolve().parent
    repo_root = core_dir.parent.parent

    preprocess_weights = repo_root / "preprocess" / "backend" / "weights" / "mobile_sam.pt"
    project_weights = repo_root / "weights" / "mobile_sam.pt"
    for candidate in (preprocess_weights, project_weights):
        if candidate.exists():
            return candidate

    # Only import mobile_sam once the repo-local locations have been ruled out.
    try:
        import mobile_sam
    except ImportError:
        return preprocess_weights

    # Namespace packages (and some embedded modules) have no usable __file__.
    # Skip this candidate rather than crash with AttributeError/TypeError.
    package_file = getattr(mobile_sam, "__file__", None)
    if package_file:
        package_weights = Path(package_file).resolve().parent.parent / "weights" / "mobile_sam.pt"
        if package_weights.exists():
            return package_weights

    return preprocess_weights

_load_predictor

_load_predictor()

Lazy-init global MobileSAM model and SamPredictor.

Source code in src/ascota_core/mobilesam_wand.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def _load_predictor():
    """Lazy-init global MobileSAM model and SamPredictor.

    The first call builds the ``vit_t`` MobileSAM model from the checkpoint
    found by ``get_mobile_sam_checkpoint_path``. It moves the model to CUDA
    when available (else CPU), switches it to eval mode and wraps it in a
    ``SamPredictor``. Later calls return the cached predictor. The whole check
    and build runs under ``_load_lock``, so concurrent first calls build the
    model only once.

    Returns:
        The process-wide ``SamPredictor`` instance.

    Raises:
        FileNotFoundError: if no checkpoint file exists at the resolved path.
        ImportError: if the ``mobile_sam`` package cannot be imported.
    """
    global _mobile_sam_model, _mobile_sam_predictor, _model_loaded

    with _load_lock:
        if _model_loaded:
            return _mobile_sam_predictor

        # Deferred import: mobile_sam is only needed once a predictor is requested.
        from mobile_sam import SamPredictor, sam_model_registry

        checkpoint_path = get_mobile_sam_checkpoint_path()
        if not checkpoint_path.exists():
            raise FileNotFoundError(
                f"MobileSAM checkpoint not found at {checkpoint_path}. "
                "Download mobile_sam.pt into repo weights/ or preprocess/backend/weights/."
            )

        model_type = "vit_t"
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Build in locals first: if any step raises, the module globals keep
        # their previous "not loaded" state instead of holding a partial model.
        model = sam_model_registry[model_type](checkpoint=str(checkpoint_path))
        model.to(device=device)
        model.eval()
        predictor = SamPredictor(model)

        _mobile_sam_model = model
        _mobile_sam_predictor = predictor
        _model_loaded = True

        return predictor