"""Frame extraction and deduplication for VideoProcessorPipeline.
Extraction uses OpenCV (CPU-bound → ProcessPoolExecutor).
Deduplication supports two algorithms:
- pHash : perceptual hash via imagehash (fast, default)
- SSIM : structural similarity via scikit-image (more accurate)
"""
from __future__ import annotations
import io
import os
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from ._models import DeduplicationMethod, FrameEntry
# ---------------------------------------------------------------------------
# Lazy-import helpers
# ---------------------------------------------------------------------------
def _require_cv2(): # type: ignore[return]
try:
import cv2
return cv2
except ImportError as exc:
raise ImportError(
"opencv-python is required for frame extraction. "
"Install with: pip install ractogateway[pipelines-video]"
) from exc
def _require_imagehash(): # type: ignore[return]
try:
import imagehash
return imagehash
except ImportError as exc:
raise ImportError(
"imagehash is required for pHash deduplication. "
"Install with: pip install ractogateway[pipelines-video]"
) from exc
def _require_pil(): # type: ignore[return]
try:
from PIL import Image
return Image
except ImportError as exc:
raise ImportError(
"Pillow is required for image processing. "
"Install with: pip install ractogateway[pipelines-video]"
) from exc
def _require_skimage(): # type: ignore[return]
try:
from skimage.metrics import structural_similarity
return structural_similarity
except ImportError as exc:
raise ImportError(
"scikit-image is required for SSIM deduplication. "
"Install with: pip install ractogateway[pipelines-video]"
) from exc
def _require_numpy(): # type: ignore[return]
try:
import numpy as np
return np
except ImportError as exc:
raise ImportError(
"numpy is required for SSIM deduplication. "
"Install with: pip install ractogateway[pipelines-video]"
) from exc
# ---------------------------------------------------------------------------
# Internal raw-frame dataclass (not part of public API)
# ---------------------------------------------------------------------------
@dataclass
class _RawFrame:
    """Record describing one raw (un-encoded) video frame; internal only.

    NOTE(review): the extraction and dedup functions alongside this class pass
    plain ``(frame_id, timestamp, image_bytes)`` tuples holding *encoded*
    image bytes rather than ``_RawFrame`` instances — confirm whether this
    class is still used anywhere before relying on it.
    """

    frame_id: int
    timestamp: float  # seconds
    # Raw OpenCV BGR array serialised via numpy tobytes(); presumably a
    # height x width x 3 uint8 buffer — TODO confirm layout/dtype.
    bgr_bytes: bytes
    height: int  # pixel rows needed to reshape bgr_bytes back into an array
    width: int  # pixel columns needed to reshape bgr_bytes back into an array
# ---------------------------------------------------------------------------
# Frame extraction (runs in worker processes)
# ---------------------------------------------------------------------------
def _extract_worker(
    video_path: str,
    worker_idx: int,
    n_workers: int,
    target_fps: float,
    max_frames: int | None,
    frame_format: str,
) -> list[tuple[int, float, bytes]]:
    """Extract one stripe of sampled frames from a video file (subprocess entry).

    The stream is sampled every ``step = max(1, round(native_fps / target_fps))``
    native frames. Sampled frames are numbered ``0, 1, 2 …`` in stream order,
    and this worker keeps the samples whose number ``s`` satisfies
    ``s % n_workers == worker_idx`` (stripe partition). The union of all
    workers' results therefore covers every sampled frame exactly once, with
    no ``frame_id`` collisions.

    Parameters
    ----------
    video_path:
        Path to the video file.
    worker_idx:
        0-based index of this worker, expected in ``range(n_workers)``.
    n_workers:
        Total number of workers sharing the stripe partition.
    target_fps:
        Desired sampling rate in frames per second.
    max_frames:
        Cap on the number of *sampled* frames considered across all workers
        together. ``None`` or ``0`` disables the cap.
    frame_format:
        ``"JPEG"`` (case-insensitive) encodes JPEG at quality 85; any other
        value encodes PNG.

    Returns
    -------
    list[tuple[int, float, bytes]]
        ``(frame_id, timestamp_seconds, image_bytes)`` for this worker's
        frames, where ``frame_id`` is the global sample number. Frames that
        fail to decode or encode are skipped.

    Raises
    ------
    RuntimeError
        If the video cannot be opened.
    """
    import cv2

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")

        # CAP_PROP_FPS reports 0.0 when the container does not expose a rate.
        native_fps: float = cap.get(cv2.CAP_PROP_FPS) or 25.0
        step = max(1, round(native_fps / target_fps))
        encode_ext = ".jpg" if frame_format.upper() == "JPEG" else ".png"
        encode_params: list[int] = (
            [cv2.IMWRITE_JPEG_QUALITY, 85] if encode_ext == ".jpg" else []
        )
        # Native index just past the last wanted sample (sample max_frames-1
        # sits at native index (max_frames-1)*step). None means "read to EOF".
        stop_idx = (max_frames - 1) * step + 1 if max_frames else None

        results: list[tuple[int, float, bytes]] = []
        frame_idx = 0
        while stop_idx is None or frame_idx < stop_idx:
            # read() is grab() + retrieve(); grabbing alone advances the
            # stream, and the image is only retrieved for frames we keep.
            if not cap.grab():
                break
            sample_idx, offset = divmod(frame_idx, step)
            if offset == 0 and sample_idx % n_workers == worker_idx:
                got, frame = cap.retrieve()
                if got:
                    ok, buf = cv2.imencode(encode_ext, frame, encode_params)
                    if ok:
                        results.append(
                            (sample_idx, frame_idx / native_fps, buf.tobytes())
                        )
            frame_idx += 1
        return results
    finally:
        # Release the decoder even if opening, decoding or encoding raised.
        cap.release()
[docs]
def load_frames_from_paths(
    frame_paths: list[Path],
    *,
    frame_format: str = "JPEG",
) -> list[tuple[int, float, bytes]]:
    """Load pre-extracted frame images from disk and re-encode them.

    Used when the caller skipped the extraction step. Each image is decoded,
    converted to RGB and re-encoded as JPEG (when *frame_format* is
    ``"JPEG"``, case-insensitive) or PNG (any other value).

    Timestamps are inferred as sequential integers (0.0, 1.0, 2.0 …) because
    no capture time is available, and ``frame_id`` equals the list position.

    Parameters
    ----------
    frame_paths:
        Image files, in the order they should be treated as a sequence.
    frame_format:
        Output encoding selector (see above).

    Returns
    -------
    list[tuple[int, float, bytes]]
        ``(frame_id, timestamp, image_bytes)`` tuples, one per input path.
    """
    pil = _require_pil()
    # The output encoding does not depend on the frame; decide it once.
    fmt = "JPEG" if frame_format.upper() == "JPEG" else "PNG"
    results: list[tuple[int, float, bytes]] = []
    for idx, path in enumerate(frame_paths):
        # Image.open is lazy and holds the file open; the context manager
        # closes it deterministically, even if decoding/conversion raises.
        with pil.open(path) as src:
            img = src.convert("RGB")
        buf = io.BytesIO()
        img.save(buf, format=fmt)
        results.append((idx, float(idx), buf.getvalue()))
    return results
# ---------------------------------------------------------------------------
# Similarity comparison
# ---------------------------------------------------------------------------
def _phash_similarity(img_bytes_a: bytes, img_bytes_b: bytes) -> float:
    """Return perceptual-hash similarity of two encoded images as a percentage.

    Both images are decoded, converted to RGB and hashed with
    ``imagehash.phash``. The Hamming distance between the hashes is mapped
    linearly onto 0-100 %, assuming a 64-bit hash (100 % = identical hashes).
    """
    imagehash = _require_imagehash()
    pil = _require_pil()

    def _hash_of(encoded: bytes):
        # Decode -> RGB -> perceptual hash.
        return imagehash.phash(pil.open(io.BytesIO(encoded)).convert("RGB"))

    # ImageHash subtraction yields the Hamming distance (0..64 for 64 bits).
    hamming = _hash_of(img_bytes_a) - _hash_of(img_bytes_b)
    return (1.0 - hamming / 64.0) * 100.0
def _phash_as_int(img_bytes: bytes) -> int:
    """Return the perceptual hash of an encoded image as a plain ``int``.

    ``imagehash`` renders the hash as a hex string; parsing that string gives
    a picklable integer whose Hamming distance to another hash can be taken
    with XOR + popcount, without PIL or imagehash.
    """
    imagehash = _require_imagehash()
    pil = _require_pil()
    rgb_image = pil.open(io.BytesIO(img_bytes)).convert("RGB")
    hex_digest = str(imagehash.phash(rgb_image))
    return int(hex_digest, 16)
def _hamming_similarity(hash_int_a: int, hash_int_b: int) -> float:
"""Hamming-based pHash similarity in % between two pre-computed hash ints.
O(1) — XOR + popcount, no PIL or imagehash required after pre-computation.
64-bit hash → max hamming distance = 64.
"""
distance = bin(hash_int_a ^ hash_int_b).count("1")
return (1.0 - distance / 64.0) * 100.0
def _hash_frame_tuple(frame_tuple: tuple[int, float, bytes]) -> int:
    """Hash the image bytes (third element) of a ``(id, timestamp, bytes)`` tuple.

    Defined at module level so it can be passed directly to
    ``ThreadPoolExecutor.map`` by :func:`deduplicate_frames_fast`.
    """
    image_bytes = frame_tuple[2]
    return _phash_as_int(image_bytes)
def _ssim_similarity(img_bytes_a: bytes, img_bytes_b: bytes) -> float:
    """Return SSIM structural similarity of two encoded images, scaled by 100.

    Both images are decoded to 8-bit greyscale (PIL mode ``"L"``). If their
    sizes differ, image *b* is resized to image *a*'s size before comparison.
    The SSIM index lies in [-1, 1], so the result is normally 0-100 % but can
    be negative for strongly anti-correlated images.
    """
    structural_similarity = _require_skimage()
    np = _require_numpy()
    pil = _require_pil()
    gray_a = pil.open(io.BytesIO(img_bytes_a)).convert("L")
    gray_b = pil.open(io.BytesIO(img_bytes_b)).convert("L")
    # Resize the already-decoded image rather than decoding img_bytes_b again.
    # PIL .size is (width, height), matching the resize() argument order.
    if gray_b.size != gray_a.size:
        gray_b = gray_b.resize(gray_a.size)
    arr_a = np.array(gray_a)
    arr_b = np.array(gray_b)
    score = float(structural_similarity(arr_a, arr_b, data_range=255))
    return score * 100.0
[docs]
def compute_similarity(
    img_bytes_a: bytes,
    img_bytes_b: bytes,
    method: DeduplicationMethod,
) -> float:
    """Return the similarity (%) between two encoded images.

    ``DeduplicationMethod.PHASH`` selects the perceptual-hash comparison;
    every other value selects SSIM.
    """
    comparator = (
        _phash_similarity
        if method == DeduplicationMethod.PHASH
        else _ssim_similarity
    )
    return comparator(img_bytes_a, img_bytes_b)
# ---------------------------------------------------------------------------
# Deduplication
# ---------------------------------------------------------------------------
[docs]
def deduplicate_frames(
    raw_frames: list[tuple[int, float, bytes]],
    *,
    similarity_threshold: float,
    method: DeduplicationMethod,
) -> list[FrameEntry]:
    """Flag near-duplicate frames by comparing each to the last *kept* frame.

    The first frame is always kept. Every later frame is compared with the
    most recently kept frame using :func:`compute_similarity`. It is
    **discarded** when that similarity is ``>= similarity_threshold``.

    Parameters
    ----------
    raw_frames:
        ``(frame_id, timestamp, image_bytes)`` tuples as returned by
        :func:`extract_frames`.
    similarity_threshold:
        Percentage (0-100); frames at or above it are dropped.
    method:
        :class:`DeduplicationMethod.PHASH` or :class:`DeduplicationMethod.SSIM`.

    Returns
    -------
    list[FrameEntry]
        One entry per input frame, in input order. Dropped frames carry
        ``kept=False`` and ``image_data=None``. ``similarity_to_prev`` is
        rounded to 2 decimals, or ``None`` for the first frame.
    """
    entries: list[FrameEntry] = []
    reference: bytes | None = None  # image bytes of the most recently kept frame

    for frame_id, timestamp, img_bytes in raw_frames:
        if reference is None:
            # Nothing kept yet -> keep unconditionally, no similarity score.
            score = None
            keep = True
        else:
            score = compute_similarity(img_bytes, reference, method)
            keep = score < similarity_threshold

        entries.append(
            FrameEntry(
                frame_id=frame_id,
                timestamp=timestamp,
                similarity_to_prev=None if score is None else round(score, 2),
                kept=keep,
                image_data=img_bytes if keep else None,
            )
        )
        if keep:
            reference = img_bytes

    return entries
[docs]
def deduplicate_frames_fast(
    raw_frames: list[tuple[int, float, bytes]],
    *,
    similarity_threshold: float,
    method: DeduplicationMethod,
    max_hash_workers: int = 4,
) -> list[FrameEntry]:
    """Deduplicate frames, pre-computing pHashes on a thread pool.

    **pHash path** (any *method* other than SSIM) runs in two stages:

    1. Every frame's perceptual hash is computed via
       ``ThreadPoolExecutor.map`` with ``min(max_hash_workers, len(raw_frames))``
       threads. How much this overlaps depends on how much of the decode/hash
       work releases the GIL.
    2. A sequential pass compares each hash with the last *kept* frame's hash
       using integer XOR + popcount. No image decoding is needed here.

    Keep/drop semantics match :func:`deduplicate_frames`. The first frame is
    kept, and a frame is dropped when its similarity is
    ``>= similarity_threshold``.

    **SSIM path** delegates to :func:`deduplicate_frames`. Each comparison
    needs the full pixels of the last kept frame, which is only known after
    the previous step.

    Returns an empty list for empty input.
    """
    if not raw_frames:
        return []

    if method == DeduplicationMethod.SSIM:
        # SSIM needs full pixels at every step -> sequential fallback.
        return deduplicate_frames(
            raw_frames,
            similarity_threshold=similarity_threshold,
            method=method,
        )

    # Fail fast on a missing dependency before any worker thread starts.
    _require_imagehash()
    pool_size = min(max_hash_workers, len(raw_frames))
    with ThreadPoolExecutor(max_workers=pool_size) as executor:
        frame_hashes: list[int] = list(executor.map(_hash_frame_tuple, raw_frames))

    entries: list[FrameEntry] = []
    anchor: int | None = None  # hash of the most recently kept frame

    for (frame_id, timestamp, img_bytes), frame_hash in zip(
        raw_frames, frame_hashes, strict=False
    ):
        if anchor is None:
            score = None
            keep = True
        else:
            score = _hamming_similarity(frame_hash, anchor)
            keep = score < similarity_threshold

        entries.append(
            FrameEntry(
                frame_id=frame_id,
                timestamp=timestamp,
                similarity_to_prev=None if score is None else round(score, 2),
                kept=keep,
                image_data=img_bytes if keep else None,
            )
        )
        if keep:
            anchor = frame_hash

    return entries