Source code for image_processing.detector

"""Edge detector that applies a kernel stack to an image and combines responses."""

from __future__ import annotations

import torch
import torch.nn.functional as torch_f

from .combination import CombineFn, sum_of_squares
from .kernels import BaseKernel, ElongatedMaskKernel


class EdgeDetector:
    """Apply a stack of convolution kernels to detect edges in an image.

    The detector runs the following steps:

    1. Moves the input image to the kernel's device and casts it to the
       kernel stack's floating dtype if needed (``float32`` when the kernel
       stack is ``float32``).
    2. Applies all ``N`` kernel orientations to each of the ``C`` colour
       channels via :func:`torch.nn.functional.conv2d` with
       ``padding='same'``, producing a ``(C, N, H, W)`` response tensor.
    3. Passes the response tensor to ``combine_fn`` to obtain a ``(H, W)``
       edge-strength map.
    4. Optionally normalises the map to ``[0, 1]``.

    The output spatial size always matches the input spatial size.

    Parameters
    ----------
    kernel : BaseKernel or None
        Provider of the oriented kernel stack. Defaults to
        :class:`~image_processing.ElongatedMaskKernel` with default
        parameters.
    combine_fn : CombineFn or None
        Callable ``(C, N, H, W) → (H, W)`` that aggregates the response
        tensor. Defaults to
        :func:`~image_processing.combination.sum_of_squares`.
    normalize : bool
        When ``True`` (default), divides the edge map by its maximum so the
        output lies in ``[0, 1]``. Has no effect when the map's maximum is
        not positive (e.g. the map is all zeros).

    Examples
    --------
    Minimal usage with default parameters:

    >>> import torch
    >>> from image_processing import EdgeDetector
    >>> detector = EdgeDetector()
    >>> image = torch.rand(3, 256, 256)  # (C, H, W) float image
    >>> edges = detector.detect(image)
    >>> edges.shape
    torch.Size([256, 256])

    Custom kernel and combination function:

    >>> from image_processing import (
    ...     EdgeDetector,
    ...     ElongatedMaskKernel,
    ...     ElongatedMaskParams,
    ... )
    >>> from image_processing.combination import sum_of_powers
    >>> params = ElongatedMaskParams(
    ...     n_angles=18,
    ...     kernel_half_size=30,
    ...     length_falloff=0.1,
    ...     width_falloff=1.0,
    ... )
    >>> detector = EdgeDetector(
    ...     kernel=ElongatedMaskKernel(params, device="cpu"),
    ...     combine_fn=sum_of_powers(3.0),
    ... )
    >>> detector.detect(torch.rand(3, 128, 128)).shape
    torch.Size([128, 128])
    """

    def __init__(
        self,
        kernel: BaseKernel | None = None,
        combine_fn: CombineFn | None = None,
        normalize: bool = True,
    ) -> None:
        # Defaults are built lazily so that no kernel stack is constructed
        # when the caller supplies their own.
        self.kernel: BaseKernel = (
            kernel if kernel is not None else ElongatedMaskKernel()
        )
        self.combine_fn: CombineFn = (
            combine_fn if combine_fn is not None else sum_of_squares
        )
        self.normalize = normalize

    def detect(self, image: torch.Tensor) -> torch.Tensor:
        """Detect edges in a single image.

        Parameters
        ----------
        image : torch.Tensor
            Input image of shape ``(C, H, W)`` or ``(H, W)``. Grayscale
            inputs ``(H, W)`` are expanded to ``(1, H, W)`` automatically.
            Images whose dtype differs from the kernel stack's floating
            dtype (integer images, or e.g. ``float64`` images with a
            ``float32`` kernel stack) are cast to it.

        Returns
        -------
        torch.Tensor
            Edge map of shape ``(H, W)`` on the kernel's device.

        Raises
        ------
        ValueError
            If ``image`` is not 2-D or 3-D.

        Examples
        --------
        >>> import torch
        >>> from image_processing import EdgeDetector
        >>> EdgeDetector().detect(torch.rand(3, 64, 64)).shape
        torch.Size([64, 64])
        """
        if image.dim() == 2:
            image = image.unsqueeze(0)
        if image.dim() != 3:
            raise ValueError(
                f"Expected a 2-D (H, W) or 3-D (C, H, W) tensor, "
                f"got shape {tuple(image.shape)}."
            )

        k = self.kernel.kernels

        # conv2d requires input and weight to share both device and dtype.
        # Align the image with the kernel stack in one step; this also covers
        # floating images of a different precision (e.g. float64), which
        # would otherwise make conv2d raise a dtype-mismatch RuntimeError.
        target_dtype = k.dtype if k.is_floating_point() else torch.float32
        image = image.to(device=k.device, dtype=target_dtype)

        # Treat colour channels as independent batch items so that a single
        # conv2d call applies all N kernels to each channel simultaneously:
        # input (C, 1, H, W), weight (N, 1, kH, kW).
        x = image.unsqueeze(1)
        w = k.unsqueeze(1)

        response = torch_f.conv2d(x, w, padding="same")  # (C, N, H, W)
        edge_map = self.combine_fn(response)  # (H, W)

        if self.normalize:
            # Reduce once and reuse the result for both the test and the
            # division.
            peak = edge_map.max()
            if peak > 0:
                edge_map = edge_map / peak

        return edge_map

    def detect_batch(self, images: list[torch.Tensor]) -> list[torch.Tensor]:
        """Detect edges in a list of images.

        Images may have different spatial sizes; each is processed
        independently via :meth:`detect`.

        Parameters
        ----------
        images : list of torch.Tensor
            Each element is a ``(C, H, W)`` or ``(H, W)`` tensor.

        Returns
        -------
        list of torch.Tensor
            One ``(H, W)`` edge map per input image.

        Examples
        --------
        >>> import torch
        >>> from image_processing import EdgeDetector
        >>> imgs = [torch.rand(3, 64, 64), torch.rand(3, 128, 96)]
        >>> [m.shape for m in EdgeDetector().detect_batch(imgs)]
        [torch.Size([64, 64]), torch.Size([128, 96])]
        """
        return [self.detect(img) for img in images]