Source code for image_processing.kernels

"""Convolution kernel classes for edge detection."""

from __future__ import annotations

import abc
from typing import cast

import numpy as np
import PIL.Image
import PIL.ImageOps
import torch
import torchvision.transforms

from .params import BaseKernelParams, ElongatedMaskParams


def _resolve_device(device: torch.device | str | None) -> torch.device:
    if device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device)


[docs] class BaseKernel(abc.ABC): """Abstract base class for edge-detection convolution kernels. Concrete subclasses implement :meth:`build` to construct a ``(N, kH, kW)`` kernel stack on ``self.device``. The result is cached after the first access and can be discarded with :meth:`reset`. Parameters ---------- params : BaseKernelParams Kernel configuration. device : torch.device or str or None Computation device. Defaults to CUDA when available, otherwise CPU. Examples -------- Minimal subclass skeleton:: class MyKernel(BaseKernel): def build(self) -> torch.Tensor: # return tensor of shape (n_kernels, kH, kW) on self.device ... """ def __init__( self, params: BaseKernelParams, device: torch.device | str | None = None, ) -> None: self.params = params self.device = _resolve_device(device) self._cache: torch.Tensor | None = None
[docs] @abc.abstractmethod def build(self) -> torch.Tensor: """Build and return the kernel stack. Returns ------- torch.Tensor Shape ``(n_kernels, kH, kW)`` on ``self.device``. """
@property def kernels(self) -> torch.Tensor: """Return the lazily built and cached kernel tensor. Returns ------- torch.Tensor Shape ``(n_kernels, kH, kW)``. """ if self._cache is None: self._cache = self.build() return self._cache
[docs] def reset(self) -> None: """Invalidate the cached kernels. The next access to :attr:`kernels` triggers a fresh call to :meth:`build`. """ self._cache = None
[docs] class ElongatedMaskKernel(BaseKernel): """Rotated elongated-stripe convolution kernels for edge detection. Implements the method from *Contour-texture separation: part 2* (Antal, 2024). The construction proceeds as follows: 1. A horizontal stripe of decaying values is placed on a blank square canvas. Weights fall off along the stripe length as ``1 / (length_falloff * |j| + 1)`` and optionally across the width as ``1 / (width_falloff * i + 1)``. 2. The canvas is converted to a PIL image and rotated to ``n_angles`` orientations evenly distributed in ``[0°, 180°)``. 3. Each rotated mask is anti-symmetrised by subtracting its 180° rotation (mirror flip). The resulting kernels detect signed intensity gradients perpendicular to the stripe. With ``padding='same'`` the convolution output preserves the spatial size of the input image. Parameters ---------- params : ElongatedMaskParams or None Kernel configuration. Uses default :class:`~image_processing.ElongatedMaskParams` when ``None``. device : torch.device or str or None Computation device. Defaults to CUDA when available, otherwise CPU. Examples -------- Default configuration (10 angles, 41 x 41 kernel): >>> from image_processing import ElongatedMaskKernel >>> kernel = ElongatedMaskKernel() >>> kernel.kernels.shape torch.Size([10, 41, 41]) GPU-tuned configuration matching the notebook's GPU example: >>> from image_processing import ElongatedMaskKernel, ElongatedMaskParams >>> params = ElongatedMaskParams( ... n_angles=18, ... kernel_half_size=30, ... stripe_half_width=5, ... stripe_half_length=30, ... length_falloff=0.1, ... width_falloff=1.0, ... ) >>> kernel = ElongatedMaskKernel(params, device="cpu") >>> kernel.kernels.shape torch.Size([18, 61, 61]) """ def __init__( self, params: ElongatedMaskParams | None = None, device: torch.device | str | None = None, ) -> None: super().__init__( params if params is not None else ElongatedMaskParams(), device, )
[docs] def build(self) -> torch.Tensor: """Build the anti-symmetrised rotated stripe kernel stack. Returns ------- torch.Tensor Shape ``(n_angles, 2 * kernel_half_size + 1, 2 * kernel_half_size + 1)`` on ``self.device``. Raises ------ ValueError If the stripe dimensions exceed ``kernel_half_size``. """ p = cast(ElongatedMaskParams, self.params) half = p.kernel_half_size size = 2 * half + 1 # odd size → clear centre pixel, no asymmetric-pad warning errors: list[str] = [] if p.stripe_half_length > half: errors.append( f"stripe_half_length ({p.stripe_half_length}) must be " f"<= kernel_half_size ({half})" ) if p.stripe_half_width > half: errors.append( f"stripe_half_width ({p.stripe_half_width}) must be " f"<= kernel_half_size ({half})" ) if errors: raise ValueError(". ".join(errors) + ".") canvas = self._make_canvas(p, size, half) canvas_pil = torchvision.transforms.ToPILImage()(canvas) angles = np.arange(0.0, 180.0, 180.0 / p.n_angles) masks: list[torch.Tensor] = [] for phi in angles: rotated = canvas_pil.rotate( float(phi), PIL.Image.Resampling.NEAREST, expand=False ) fwd = torchvision.transforms.ToTensor()(rotated) rev = torchvision.transforms.ToTensor()( PIL.ImageOps.mirror(PIL.ImageOps.flip(rotated)) ) masks.append(fwd - rev) stacked = torch.cat(masks).view(p.n_angles, size, size) return stacked.to(self.device)
@staticmethod def _make_canvas(p: ElongatedMaskParams, size: int, half: int) -> torch.Tensor: """Create the base unrotated stripe canvas as a ``(size, size)`` tensor.""" canvas = torch.zeros((size, size)) j_offsets = torch.arange(-p.stripe_half_length, p.stripe_half_length) length_weights = 1.0 / (p.length_falloff * j_offsets.abs().float() + 1.0) col_start = half - p.stripe_half_length col_end = half + p.stripe_half_length for i in range(p.stripe_half_width): width_weight = ( 1.0 / (p.width_falloff * float(i) + 1.0) if p.width_falloff > 0.0 else 1.0 ) canvas[half + i, col_start:col_end] = length_weights * width_weight return canvas