Source code for torchbp.util

import torch
from torch import Tensor
import torch.nn.functional as F
from math import pi
import numpy as np
from scipy.signal import get_window
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .grid import PolarGrid, CartesianGrid


[docs] def bp_polar_range_dealias( img: Tensor, origin: Tensor, fc: float, grid_polar: "PolarGrid | dict", alias_fmod: float = 0 ) -> Tensor: """ De-alias range-axis spectrum of polar SAR image processed with backprojection. [1]_ Parameters ---------- img : Tensor Complex input image. Shape should be: [Range, azimuth]. origin : Tensor Center of the platform position. fc : float RF center frequency. grid_polar : PolarGrid or dict Polar grid definition. Can be: - PolarGrid object: PolarGrid(r_range=(r0, r1), theta_range=(theta0, theta1), nr=nr, ntheta=ntheta) - dict: {"r": (r0, r1), "theta": (theta0, theta1), "nr": nr, "ntheta": ntheta} alias_fmod : float Range modulation frequency applied to input. References ---------- .. [1] T. Shi, X. Mao, A. Jakobsson and Y. Liu, "Extended PGA for Spotlight SAR-Filtered Backprojection Imagery," in IEEE Geoscience and Remote Sensing Letters, vol. 19, pp. 1-5, 2022, Art no. 4516005. Returns ------- img : Tensor SAR image without range spectrum aliasing. """ r0, r1 = grid_polar["r"] theta0, theta1 = grid_polar["theta"] ntheta = grid_polar["ntheta"] nr = grid_polar["nr"] dtheta = (theta1 - theta0) / ntheta dr = (r1 - r0) / nr er = torch.arange(nr, device=img.device) r = r0 + dr * er theta = theta0 + dtheta * torch.arange(ntheta, device=img.device) x = r[:, None] * torch.sqrt(1 - torch.square(theta))[None, :] y = r[:, None] * theta[None, :] if origin.dim() == 2: origin = origin[0] d = torch.sqrt((x - origin[0]) ** 2 + (y - origin[1]) ** 2 + origin[2] ** 2) c0 = 299792458 phase = torch.exp(-1j * 4 * pi * fc * d / c0 + 1j*alias_fmod*er[:,None]) if img.dim() == 3: phase = phase.unsqueeze(0) return phase * img
[docs] def bp_polar_range_alias( img: Tensor, origin: Tensor, fc: float, grid_polar: "PolarGrid | dict", alias_fmod: float=0 ) -> Tensor: """ Inverse of bp_polar_range_dealias. Parameters ---------- img : Tensor Complex input image. Shape should be: [Range, azimuth]. origin : Tensor Center of the platform position. fc : float RF center frequency. grid_polar : PolarGrid or dict Polar grid definition. Can be: - PolarGrid object: PolarGrid(r_range=(r0, r1), theta_range=(theta0, theta1), nr=nr, ntheta=ntheta) - dict: {"r": (r0, r1), "theta": (theta0, theta1), "nr": nr, "ntheta": ntheta} alias_fmod : float Range modulation frequency applied to output. Returns ------- img : Tensor SAR image with range spectrum aliasing. """ return bp_polar_range_dealias(img, origin, -fc, grid_polar, -alias_fmod)
[docs] def diff(x: Tensor, dim: int = -1, same_size: bool = False) -> Tensor: """ ``np.diff`` implemented in torch. Parameters ---------- x : Tensor Input tensor. dim : int Dimension. same_size : bool Pad output to same size as input. Returns ------- d : Tensor Difference tensor. """ if dim == 0: if same_size: padding = [0 for i in range(2*len(x.shape))] padding[-2] = 1 padding = tuple(padding) return torch.nn.functional.pad(x[1:] - x[:-1], padding) else: return x[1:] - x[:-1] if dim != -1: raise NotImplementedError("Only dim=0 and dim=-1 is implemented") if same_size: return torch.nn.functional.pad(x[..., 1:] - x[..., :-1], (1, 0)) else: return x[..., 1:] - x[..., :-1]
[docs] def unwrap(phi: Tensor, dim: int = -1) -> Tensor: """ ``np.unwrap`` implemented in torch. Parameters ---------- phi : Tensor Input tensor. dim : int Dimension. Returns ------- phi : Tensor Unwrapped tensor. """ if dim != -1: raise NotImplementedError("Only dim=-1 is implemented") phi_wrap = ((phi + torch.pi) % (2 * torch.pi)) - torch.pi dphi = diff(phi_wrap, same_size=True) dphi_m = ((dphi + torch.pi) % (2 * torch.pi)) - torch.pi dphi_m[(dphi_m == -torch.pi) & (dphi > 0)] = torch.pi phi_adj = dphi_m - dphi phi_adj[dphi.abs() < torch.pi] = 0 return phi_wrap + phi_adj.cumsum(dim)
[docs] def unwrap_ref(x: Tensor, y: Tensor) -> Tensor: """ Solve for integer array k such that x + k*2pi is closest to y. `k = round((y - x) / (2pi))`. Parameters ---------- x : Tensor Phase wrapped signal. y : Tensor Reference signal. Returns ------- unwrapped_x : Tensor Phase unwrapped x """ k = torch.round((y - x) / (2 * torch.pi)) unwrapped_x = x + k * 2 * torch.pi return unwrapped_x
[docs] def quad_interp(a: Tensor, v: int) -> Tensor: """ Quadractic peak interpolation. Useful for FFT peak interpolation. Parameters ---------- a : Tensor Input tensor. v : int Peak index. Returns ------- f : Tensor Estimated fractional peak index. """ a1 = a[(v - 1) % len(a)] a2 = a[v % len(a)] a3 = a[(v + 1) % len(a)] return 0.5 * (a1 - a3) / (a1 - 2 * a2 + a3)
[docs] def argmax_nd(x: Tensor) -> tuple[int, ...]: """ `torch.argmax` but returns N-dimensional index of the peak """ d = torch.argmax(x).item() res = [] for s in x.shape[::-1]: d, m = divmod(d, s) res.append(m) return tuple(res)[::-1]
[docs] def find_image_shift_1d(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: """ Find shift between images that maximizes correlation. Parameters ---------- x : Tensor Input tensor. y : Tensor Input tensor. Should have same shape as x. dim : int Dimensions to shift. Returns ------- c : Tensor Estimated shift. """ if x.shape != y.shape: raise ValueError("Input shapes should be identical") if dim < 0: dim = x.dim() + dim fx = torch.fft.fft(x, dim=dim) fy = torch.fft.fft(y, dim=dim) c = (fx * fy.conj()) / (torch.abs(fx) * torch.abs(fy)) other_dims = [i for i in range(x.dim()) if i != dim] c = torch.abs(torch.fft.ifft(c, dim=dim)) if len(other_dims) > 0: c = torch.mean(c, dim=other_dims) return torch.argmax(c)
[docs] def subset_cart( img: Tensor, grid_cart: "CartesianGrid | dict", x0: float, x1: float, y0: float, y1: float ) -> tuple[Tensor, dict]: """Cartesian image subset. Parameters ---------- img : Tensor Input image. grid_cart : CartesianGrid or dict Cartesian grid definition. Can be: - CartesianGrid object: CartesianGrid(x_range=(x0, x1), y_range=(y0, y1), nx=nx, ny=ny) - dict: {"x": (x0, x1), "y": (y0, y1), "nx": nx, "ny": ny} x0 : float Subset x0. x1 : float Subset x1. y0 : float Subset y0. y1 : float Subset y1. Returns ------- img : Tensor Subset of input image. grid : dict Grid """ gx0, gx1 = grid_cart["x"] gy0, gy1 = grid_cart["y"] nx = grid_cart["nx"] ny = grid_cart["ny"] dx = (gx1 - gx0) / nx dy = (gy1 - gy0) / ny nx0 = max(0, min(nx, int((x0 - gx0) / dx))) nx1 = max(0, min(nx, int((x1 - gx0) / dx))) ny0 = max(0, min(ny, int((y0 - gy0) / dy))) ny1 = max(0, min(ny, int((y1 - gy0) / dy))) out = img[..., nx0:nx1, ny0:ny1] grid_new = { "x": (gx0 + dx * nx0, gx0 + dx * nx1), "y": (gy0 + ny0 * dy, gy0 + ny1 * dy), "nr": out.shape[-2], "ntheta": out.shape[-1], } return out, grid_new
[docs] def subset_polar( img: Tensor, grid_polar: "PolarGrid | dict", r0: float, r1: float, theta0: float, theta1: float ) -> tuple[Tensor, dict]: """Polar image subset. Parameters ---------- img : Tensor Input image. grid_polar : PolarGrid or dict Polar grid definition. PolarGrid object or dictionary. r0 : float Subset r0. r1 : float Subset r1. theta0 : float Subset theta0. theta1 : float Subset theta1. Returns ------- img : Tensor Subset of input image. grid_new : dict Grid. """ gr0, gr1 = grid_polar["r"] gtheta0, gtheta1 = grid_polar["theta"] nr = grid_polar["nr"] ntheta = grid_polar["ntheta"] dr = (gr1 - gr0) / nr dtheta = (gtheta1 - gtheta0) / ntheta nr0 = max(0, min(nr, int((r0 - gr0) / dr))) nr1 = max(0, min(nr, int((r1 - gr0) / dr))) ntheta0 = max(0, min(ntheta, int((theta0 - gtheta0) / dtheta))) ntheta1 = max(0, min(ntheta, int((theta1 - gtheta0) / dtheta))) out = img[..., nr0:nr1, ntheta0:ntheta1] grid_new = { "r": (gr0 + dr * nr0, gr0 + dr * nr1), "theta": (gtheta0 + ntheta0 * dtheta, gtheta0 + ntheta1 * dtheta), "nr": out.shape[-2], "ntheta": out.shape[-1], } return out, grid_new
[docs] def find_image_shift_2d( x: Tensor, y: Tensor, dim: tuple = (-2, -1), interpolate=False ) -> tuple: """ Find shift between images that maximizes correlation. Parameters ---------- x : Tensor Input tensor. y : int Input tensor. Should have same shape as x. dim : tuple Dimension. Returns ------- c : tuple Estimated shift. a : float Peak of correlation. """ if x.shape != y.shape: raise ValueError("Input shapes should be identical") if dim != (-2, -1): raise NotImplentedError("dim must be (-2, -1)") d2 = [] for i in dim: if i < 0: d2.append(x.dim() + i) else: d2.append(i) dims = d2 fx = torch.fft.fft2(x, dim=dim) fy = torch.fft.fft2(y, dim=dim) c = (fx * fy.conj()) / (torch.abs(fx) * torch.abs(fy)) other_dims = [i for i in range(x.dim()) if i not in dim] c = torch.abs(torch.fft.ifft2(c, dim=dim)) idx = argmax_nd(torch.abs(c)) a = c[idx].item() if interpolate: # Apply quad_interp to each spatial dimension interp_idx = list(idx) dim_idx = dims[0] # Extract 1D slice along this dimension at the peak location slice_indices = list(idx) slice_indices[dim_idx] = slice(None) # Replace with slice for this dimension c_slice = c[tuple(slice_indices)] delta_0 = quad_interp(c_slice, idx[dim_idx]) interp_idx[dim_idx] = idx[dim_idx] + delta_0 # Interpolate along the second spatial dimension (dim[-1]) dim_idx = dims[1] slice_indices = list(idx) slice_indices[dim_idx] = slice(None) c_slice = c[tuple(slice_indices)] delta_1 = quad_interp(c_slice, idx[dim_idx]) interp_idx[dim_idx] = idx[dim_idx] + delta_1 idx = tuple([i.item() for i in interp_idx]) idx = [ idx[i] - c.shape[i] if idx[i] > c.shape[i] // 2 else idx[i] for i in range(len(idx)) ] return idx, a
[docs] def fft_peak_1d(x: Tensor, dim: int = -1, fractional: bool = True) -> Tensor: """ Find fractional peak of ``abs(fft(x))``. Parameters ---------- x : Tensor Input tensor. dim : int Dimension to calculate peak. fractional : bool Estimate peak location with fractional index accuracy. Returns ------- a : int or float Estimated peak index. """ fx = torch.abs(torch.fft.fft(x, dim=dim)) a = torch.argmax(fx) if fractional: a = a + quad_interp(fx, a) l = x.shape[dim] if a > l // 2: a = l - a return a
[docs] def detrend(x: Tensor) -> Tensor: """ Removes linear trend Parameters ---------- x : Tensor Input tensor. Should be 1 dimensional. Returns ------- x : Tensor x with linear trend removed. """ n = x.shape[0] k = torch.arange(n, device=x.device, dtype=x.dtype) / n # Solve least squares problem: k * a + b = x ones = torch.ones(n, device=x.device, dtype=x.dtype) A = torch.stack([k, ones], dim=1) params = torch.linalg.lstsq(A, x).solution a, b = params[0], params[1] # Remove linear trend return x - (a * k + b)
[docs] def entropy(x: Tensor) -> Tensor: """ Calculates entropy: ``-sum(y*log(y))`` where ``y = abs(x) / sum(abs(x))``. Parameters ---------- x : Tensor Input tensor. Returns ------- entropy : Tensor Calculated entropy of the input. """ ax = torch.abs(x) ax /= torch.sum(ax) return -torch.sum(torch.xlogy(ax, ax))
[docs] def contrast(x: Tensor, dim: int = -1) -> Tensor: """ Calculates negative contrast: ``-mean(std/mu)`` where ``mu`` is mean and ``std`` is standard deviation of ``abs(x)`` along dimension ``dim``. Parameters ---------- x : Tensor Input tensor. Returns ------- contrast: Tensor Calculated negative contrast of the input. """ std, mu = torch.std_mean(torch.abs(x), dim=dim) contrast = torch.mean(std / mu) return -contrast
[docs] def shift_spectrum(x: Tensor, dim: int = -1) -> Tensor: """ Equivalent to: ``fft(ifftshift(ifft(x, dim), dim), dim)``, but avoids calculating FFTs. Parameters ---------- x : Tensor Input tensor. Returns ------- y : Tensor Shifted tensor. """ if dim != -1: raise NotImplementedError("dim should be -1") shape = [1] * len(x.shape) shape[dim] = x.shape[dim] c = torch.ones(shape, dtype=torch.float32, device=x.device) c[..., 1::2] = -1 return x * c
[docs] def generate_fmcw_data( target_pos: Tensor, target_rcs: Tensor, pos: Tensor, fc: float, bw: float, tsweep: float, fs: float, d0: float = 0, g: Tensor | None = None, g_extent: list | None = None, att: Tensor = None, rvp: bool = True, vel: Tensor | None = None, ) -> Tensor: """ Generate FMCW radar time-domain IF signal. Parameters ---------- target_pos : Tensor [ntargets, 3] tensor of target XYZ positions. target_rcs : Tensor [ntargets, 1] tensor of target reflectivity. pos : Tensor [nsweeps, 3] tensor of platform positions. When `vel` is provided, `pos[s]` is the platform position at the midpoint of sweep `s`. fc : float RF center frequency in Hz. bw : float RF bandwidth in Hz. tsweep : float Length of one sweep in seconds. fs : float Sampling frequency in Hz. d0 : float Zero range. g : Tensor or None Square-root of two-way antenna gain in spherical coordinates, shape: [elevation, azimuth]. If TX antenna equals RX antenna, then this should be just antenna gain. (0, 0) angle is at the beam center. Isotropic antenna is assumed if g is None. g_extent : list or None List of [g_el0, g_az0, g_el1, g_az1]. g_el0, g_el1 are grx and gtx elevation axis start and end values. Units in radians. -pi/2 + +pi/2 if including data over the whole sphere. g_az0, g_az1 are grx and gtx azimuth axis start and end values. Units in radians. -pi to +pi if including data over the whole sphere. att : Tensor Euler angles of the radar antenna at each data point. Shape should be [nsweeps, 3]. [Roll, pitch, yaw]. Only roll and yaw are used at the moment. rvp : bool True to include residual video phase term. vel : Tensor or None [nsweeps, 3] tensor of platform velocities in m/s. When given, the two-way delay is evaluated per intra-sweep sample at the instantaneous platform position `pos[s] + vel[s] * (t_sample - tsweep/2)`, i.e. the stop-and-go approximation is removed. None (default) reproduces the stop-and-go model where `pos[s]` is held fixed across the chirp. Returns ------- data : Tensor [nsweeps, nsamples] measurement data. """ if pos.dim() != 2: raise ValueError("pos tensor should have 2 dimensions") if pos.shape[1] != 3: raise ValueError("positions should be 3 dimensional") if vel is not None: if vel.shape != pos.shape: raise ValueError("vel must have the same shape as pos") npos = pos.shape[0] nsamples = int(fs * tsweep) device = pos.device data = torch.zeros((npos, nsamples), dtype=torch.complex64, device=device) t = torch.arange(nsamples, dtype=torch.float32, device=device) / fs k = bw / tsweep c0 = 299792458 use_rvp = 1 if rvp else 0 antenna_gain = g is not None and att is not None if antenna_gain: if g_extent is None: raise ValueError("g_extent is None, but g is not None") if len(g_extent) != 4: raise ValueError("g_extent should be a 4 element list") g_el0, g_az0, g_el1, g_az1 = g_extent nelevation, nazimuth = g.shape # Add batch and channel dimensions to g g_batch = g.unsqueeze(0).unsqueeze(0) t = t[None, :] # Sample-time offset from sweep midpoint, used only for non-stop-and-go. t_rel = (t - 0.5 * tsweep) if vel is not None else None for e, target in enumerate(target_pos): rcs_phase = torch.angle(target_rcs[e]) rcs_abs = torch.sqrt(torch.abs(target_rcs[e])) dpos = pos - target[None, :] # [nsweeps, 3] if vel is None: d = torch.linalg.vector_norm(dpos, dim=-1)[:, None] + d0 # [nsweeps, 1] else: # |dpos + vel * t_rel|^2 = |dpos|^2 + 2<dpos,vel> t_rel + |vel|^2 t_rel^2 dp_sq = (dpos * dpos).sum(dim=-1, keepdim=True) # [nsweeps, 1] dp_v = (dpos * vel).sum(dim=-1, keepdim=True) # [nsweeps, 1] v_sq = (vel * vel).sum(dim=-1, keepdim=True) # [nsweeps, 1] d = torch.sqrt(dp_sq + 2.0 * dp_v * t_rel + v_sq * t_rel**2) + d0 # [nsweeps, nsamples] tau = 2 * d / c0 if antenna_gain: # Antenna gain evaluated at sweep midpoint (slow-varying across chirp). d_mid = torch.linalg.vector_norm(dpos, dim=-1)[:, None] + d0 look_angle = torch.asin(pos[:, 2] / d_mid[:, 0]) el_deg = -look_angle - att[:, 0] az_deg = ( torch.atan2(target[1] - pos[:, 1], target[0] - pos[:, 0]) - att[:, 2] ) az_norm = 2.0 * (az_deg - g_az0) / (g_az1 - g_az0) - 1.0 el_norm = 2.0 * (el_deg - g_el0) / (g_el1 - g_el0) - 1.0 # grid_sample maps grid[..., 0] -> width (azimuth) and # grid[..., 1] -> height (elevation) of g, which is [nel, naz]. grid = torch.stack([az_norm, el_norm], dim=-1) grid = grid.unsqueeze(0).unsqueeze(0) # [1, 1, N, 2] g_a = F.grid_sample( g_batch, grid, mode="bilinear", padding_mode="zeros", align_corners=False, ) g_a = g_a.reshape(d_mid.shape) else: g_a = 1 data += (g_a * rcs_abs / d**2) * torch.exp( 1j * 2 * pi * (-fc * tau - k * tau * t + use_rvp * 0.5 * k * tau**2) + 1j * rcs_phase ) return data
[docs] def make_polar_grid( r0: float, r1: float, nr: int, ntheta: int, theta_limit: int = 1, squint: float = 0 ) -> "PolarGrid": """ Generate PolarGrid object. Parameters ---------- r0 : float Minimum range in m. r1 : float Maximum range in m. nr : float Number of range points. ntheta : float Number of azimuth points. theta_limit : float Theta axis limits, symmetrical around zero. Units are sin of angle (0 to 1 valid range). Default is 1. squint : float Grid azimuth mean angle, radians. Returns ------- grid_polar : PolarGrid Polar grid object. """ from .grid import PolarGrid t0 = float(np.clip(np.sin(squint) - theta_limit, -1, 1)) t1 = float(np.clip(np.sin(squint) + theta_limit, -1, 1)) return PolarGrid(r_range=(r0, r1), theta_range=(t0, t1), nr=nr, ntheta=ntheta)
# Alias for backward compatibility make_polar_grid_obj = make_polar_grid
[docs] def phase_to_distance(p: Tensor, fc: float) -> Tensor: """ Convert radar reflection phase shift to distance. Parameters ---------- p : Tensor Phase shift tensor. fc : float RF center frequency. """ c0 = 299792458 return c0 * p / (4 * torch.pi * fc)
[docs] def wiener_normalize( sar: Tensor, tx_power: Tensor, eps: float | None = None, calib_quantile: float = 0.1, ) -> Tensor: """ SNR-aware radiometric normalization of a SAR image by an illumination map. Plain division ``sar / tx_power`` inverts the illumination, but where the illumination is weak (swath edges, antenna nulls) it divides receiver noise by a near-zero number and the result blows up. This applies the Wiener (MMSE) estimate instead: .. math:: \\hat{s} = \\frac{\\mathrm{sar}\\cdot\\mathrm{tx\\_power}} {\\mathrm{tx\\_power}^2 + \\varepsilon^2} which equals ``(sar / tx_power) * SNR / (1 + SNR)`` with the per-pixel power SNR ``= (tx_power / eps)**2``. Where the illumination is strong it reduces to the full normalization ``sar / tx_power``; where it is weak the gain rolls off as ``tx_power**2`` so the output goes to zero instead of amplifying noise. The SNR map itself is ``(tx_power / eps)**2``. The regularization ``eps`` is the noise-to-signal amplitude ratio :math:`\\varepsilon = \\sigma_n / \\sigma_s` **in tx_power units**, where :math:`\\sigma_n` is the additive noise amplitude in the image and :math:`\\sigma_s` is the reflectivity scale relating image to illumination (``sar = s * tx_power + n``). It is the illumination level at which the SNR equals one. Note this is *not* simply the noise level :math:`\\sigma_n`: it must be divided by the reflectivity scale (which also absorbs the leftover radiometric calibration constant of ``tx_power``). Parameters ---------- sar : Tensor Complex or magnitude SAR image, same pseudo-polar shape as ``tx_power``. tx_power : Tensor Illumination map from :func:`torchbp.ops.backprojection_polar_2d_tx_power` (square root of power returned for unit reflectivity). Non-finite entries (un-illuminated pixels) are treated as no-data and mapped to zero. eps : float or None Regularization level :math:`\\sigma_n / \\sigma_s` in ``tx_power`` units. If None it is estimated from the data using the identity :math:`E|\\mathrm{sar}|^2 = \\sigma_s^2\\,\\mathrm{tx\\_power}^2 + \\sigma_n^2`: :math:`\\sigma_s^2` from the brightest ``calib_quantile`` fraction of pixels and :math:`\\sigma_n^2` from the dimmest fraction (with the residual signal subtracted). For real data prefer passing an explicit value from a known shadow region and a calibration target. calib_quantile : float Fraction (0-0.5) of pixels at each illumination extreme used to estimate ``eps`` when it is not given. Default 0.1 (dimmest/brightest 10%). Returns ------- s_hat : Tensor Normalized image, same dtype as ``sar``. """ finite = torch.isfinite(tx_power) txp = torch.where(finite, tx_power, torch.zeros_like(tx_power)) if eps is None: x = txp[finite].flatten().float() y = (sar.abs()[finite].flatten().float()) ** 2 # |sar|^2 # torch.quantile caps at ~16M elements; subsample large images. if x.numel() > 1_000_000: sel = torch.randperm(x.numel(), device=x.device)[:1_000_000] xq, yq = x[sel], y[sel] else: xq, yq = x, y lo = torch.quantile(xq, calib_quantile) hi = torch.quantile(xq, 1.0 - calib_quantile) him = xq >= hi lom = xq <= lo # sigma_s^2 from bright (signal-dominated), sigma_n^2 from dim with the # residual signal sigma_s^2 * tx_power^2 removed. sigma_s2 = yq[him].mean() / (xq[him] ** 2).mean().clamp_min(1e-30) sigma_n2 = (yq[lom].mean() - sigma_s2 * (xq[lom] ** 2).mean()).clamp_min(0.0) eps = float((sigma_n2 / sigma_s2.clamp_min(1e-30)).sqrt()) eps2 = float(eps) ** 2 s_hat = sar * txp / (txp * txp + eps2) # Un-illuminated (non-finite tx_power) pixels carry no signal -> zero. return torch.where(finite, s_hat, torch.zeros_like(s_hat))
[docs] def next_fast_len(n: int) -> int: """CuFFT-friendly length (powers of 2,3,5,7)""" def is_fast(k: int) -> bool: for p in (2,3,5,7): while k % p == 0: k //= p return k == 1 while not is_fast(n): n += 1 return n
[docs] def fft_lowpass_filter_precalculate_window( data_length: int, window_width: int, device: str, window: str | tuple, circular_conv: bool = False, fast_len: bool = True) -> Tensor: """ Precompute window to be used with `fft_lowpass_filter_window`. Returns ------- w : Tensor Windowing Tensor. pad_size : int Amount of padding added to signal (needed for extraction). """ half_width = (window_width + 1) // 2 # Original padding for linear convolution pad_size = 2 * half_width if not circular_conv else 0 # FFT length fft_len = data_length + pad_size if fast_len: fft_len = next_fast_len(fft_len) # Window centered at DC (symmetric for zero-phase) half_window = get_window(window, 2 * half_width - 1, fftbins=True)[half_width - 1:] w = np.zeros(fft_len, dtype=np.float32) w[:half_width] = half_window w[-half_width + 1:] = np.flip(half_window[1:]) return torch.tensor(w, device=device)
[docs] def fft_lowpass_filter_window( target_data: Tensor, window: str | tuple | Tensor = "hamming", window_width: int = None, circular_conv: bool = False, fast_len: bool = True, ) -> Tensor: """ FFT low-pass filtering with a configurable window function. """ if window_width is None or window_width > target_data.shape[-1]: return target_data half_width = (window_width + 1) // 2 N = target_data.shape[-1] # Determine padding and FFT length if isinstance(window, Tensor): fft_len = window.numel() pad_size = 2 * half_width if not circular_conv else 0 else: pad_size = 2 * half_width if not circular_conv else 0 fft_len = N + pad_size if fast_len: fft_len = next_fast_len(fft_len) fdata = torch.fft.fft(target_data, dim=-1, n=fft_len) if isinstance(window, Tensor): w = window else: w, pad_size = fft_lowpass_filter_precalculate_window( N, window_width, target_data.device, window, circular_conv=circular_conv, fast_len=fast_len ) filtered_data = torch.fft.ifft(fdata * w, dim=-1) # Trim to original length filtered_data = filtered_data[..., :N] return filtered_data
[docs] def center_pos(pos: Tensor) -> tuple[Tensor, Tensor]: """ Center position to origin. Centers X and Y coordinates, but doesn't modify Z. Useful for preparing positions for polar backprojection Parameters ---------- pos : Tensor 3D positions. Shape should be [N, 3]. Returns ------- pos_local : Tensor Centered positions. origin : Tensor Position subtracted from the pos. """ origin = torch.tensor( [torch.mean(pos[:, 0]), torch.mean(pos[:, 1]), 0], device=pos.device, dtype=torch.float32, )[None, :] pos_local = pos - origin return pos_local, origin
[docs] def bounding_cart_grid( grid_polar: "PolarGrid | dict", origin: tuple, origin_angle: float, ) -> dict: """ Return the bounding Cartesian grid for polar input grid. Parameters ---------- grid_polar : PolarGrid or dict Polar grid definition. Can be: - PolarGrid object: PolarGrid(r_range=(r0, r1), theta_range=(theta0, theta1), nr=nr, ntheta=ntheta) - dict: {"r": (r0, r1), "theta": (theta0, theta1), "nr": nr, "ntheta": ntheta} where theta is sin of angle (-1, 1 for 180 degree view). origin : tuple Origin coordinates of grid_polar in the Cartesian grid. origin_angle : float Reference direction (radians) that corresponds to s = 0. Returns ------- (xmin, ymin, xmax, ymax) : tuple[float, float, float, float] Coordinates of the smallest axis‑aligned rectangle containing the grid. """ (r0, r1) = grid_polar["r"] (s0, s1) = grid_polar["theta"] # Convert the stored sine values back to angles and shift by the origin. a0 = origin_angle + np.arcsin(s0) a1 = origin_angle + np.arcsin(s1) a_min, a_max = (a0, a1) if a0 <= a1 else (a1, a0) # Quadrantal angles where x or y may reach an extremum. candidate_angles = np.linspace(a_min, a_max, 20, endpoint=True) xmin = ymin = float("inf") xmax = ymax = -float("inf") for r in (r0, r1): for a in candidate_angles: x = r * np.cos(a) + origin[0] y = r * np.sin(a) + origin[1] xmin = min(xmin, x) xmax = max(xmax, x) ymin = min(ymin, y) ymax = max(ymax, y) dr = (grid_polar["r"][1] - grid_polar["r"][0]) / grid_polar["nr"] nx = int((xmax - xmin) / dr) ny = int((ymax - ymin) / dr) grid_cart = {"x": (xmin, xmax), "y": (ymin, ymax), "nx": nx, "ny": ny} return grid_cart
[docs] def create_triangular_weights(patch_size: int, overlap: int, device: str = "cpu") -> Tensor: """ Create triangular weights for smooth blending of overlapping patches. Parameters ---------- patch_size : int Side length of patches. overlap : int Overlap between patches. device : str Pytorch device. Returns ------- weights_2d : Tensor Weight tensor of [patch_size, patch_size] with triangular weighting """ K = patch_size O = overlap if O == 0: # No overlap - use uniform weights return torch.ones(K, K, device=device) # Create 1D triangular weights weights_1d = torch.ones(K, device=device) # Apply triangular weighting at the edges fade_length = O # Left edge: linear fade-in for i in range(min(fade_length, K)): weights_1d[i] = (i + 1) / (fade_length + 1) # Right edge: linear fade-out for i in range(max(0, K - fade_length), K): weights_1d[i] = (K - i) / (fade_length + 1) # Create 2D triangular weights using outer product weights_2d = torch.outer(weights_1d, weights_1d) return weights_2d
[docs] def extract_overlapping_patches(img: Tensor, patch_size: int, overlap: int) -> tuple[Tensor, tuple[int, ...]]: """ Extract overlapping patches from a tensor. Parameters ---------- img : Tensor Input tensor of shape [C, N, M]. patch_size : int Side length of square patches (K). overlap : int Overlap between patches. Returns ------- patches : Tensor Tensor of shape [C, P, K, K] where P is the number of patches. dim : tuple Original image dimensions. """ C, N, M = img.shape K = patch_size O = overlap # Calculate stride (distance between patch centers) stride = K - O pad_height = (K - N % stride) % stride pad_width = (K - M % stride) % stride if pad_height > 0 or pad_width > 0: # Pad with reflection to avoid edge artifacts img_padded = F.pad(img, (0, pad_width, 0, pad_height), mode="reflect") else: img_padded = img # Add batch dimension img_batch = img_padded.unsqueeze(0) # Extract patches patches = F.unfold(img_batch, kernel_size=K, stride=stride) patches = patches.squeeze(0) patches = patches.view(C, K * K, -1) P = patches.shape[2] patches = patches.transpose(1, 2).contiguous().view(C, P, K, K) return patches, img_padded.shape[1:] # Return padded dimensions
[docs] def merge_patches_with_triangular_weights( patches: Tensor, original_shape: tuple[int, int], patch_size: int, overlap: int, padded_shape: tuple[int, int] | None = None ) -> Tensor: """ Merge overlapping patches back into an image using triangular weighting. Parameters ---------- patches : Tensor Tensor of shape [C, P, K, K] containing patches original_shape : tuple Original shape of the image (N, M). patch_size : int Side length of square patches (K). overlap : int Overlap between patches. padded_shape : tuple Tuple (N_pad, M_pad) of padded dimensions. Returns ------- img : Tensor Reconstructed image tensor of shape [C, N, M]. """ C, P, K, K_check = patches.shape assert K == K_check, "Patches must be square" N, M = original_shape stride = K - overlap # Determine reconstruction dimensions if padded_shape is not None: N_recon, M_recon = padded_shape else: N_recon, M_recon = N, M if overlap == 0: # No overlap case - use uniform weights and simple reconstruction weights = torch.ones(K, K, device=patches.device) weighted_patches = patches * weights.unsqueeze(0).unsqueeze(0) # Reshape patches for fold operation weighted_patches_flat = ( weighted_patches.view(C, P, K * K).transpose(1, 2).contiguous() ) weighted_patches_flat = weighted_patches_flat.view(C * K * K, P) # Add batch dimension for fold weighted_patches_batch = weighted_patches_flat.unsqueeze(0) # Reconstruct using fold reconstructed = F.fold( weighted_patches_batch, output_size=(N_recon, M_recon), kernel_size=K, stride=stride, ) # Remove batch dimension reconstructed = reconstructed.squeeze(0) # Crop back to original size if padding was used if padded_shape is not None: reconstructed = reconstructed[:, :N, :M] return reconstructed # Overlapping case - use triangular weights weights = create_triangular_weights(patch_size, overlap, device=patches.device) # Apply weights to patches weighted_patches = patches * weights.unsqueeze(0).unsqueeze(0) # Reshape patches for fold operation weighted_patches_flat = ( weighted_patches.view(C, P, K * K).transpose(1, 2).contiguous() ) weighted_patches_flat = weighted_patches_flat.view(C * K * K, P) # Also create weight patches for normalization weight_patches = weights.unsqueeze(0).expand(P, -1, -1).unsqueeze(0) weight_patches_flat = weight_patches.view(1, P, K * K).transpose(1, 2).contiguous() weight_patches_flat = weight_patches_flat.view(K * K, P) # Add batch dimension for fold weighted_patches_batch = weighted_patches_flat.unsqueeze(0) weight_patches_batch = weight_patches_flat.unsqueeze(0) # Reconstruct using fold reconstructed = F.fold( weighted_patches_batch, output_size=(N_recon, M_recon), kernel_size=K, stride=stride, ) weight_sum = F.fold( weight_patches_batch, output_size=(N_recon, M_recon), kernel_size=K, stride=stride, ) # Remove batch dimension reconstructed = reconstructed.squeeze(0) weight_sum = weight_sum.squeeze(0) # Normalize by weight sum to handle overlaps epsilon = 1e-8 reconstructed = reconstructed / (weight_sum + epsilon) if padded_shape is not None: reconstructed = reconstructed[:, :N, :M] return reconstructed
[docs] def process_image_with_patches(img: Tensor, patch_size: int, overlap: int, process_fn) -> Tensor: """ Process an image by extracting patches, applying a function, and merging back. Parameters ---------- img : Tensor Input tensor of shape [C, N, M] or [N, M]. patch_size : int Side length of square patches. overlap : int Overlap between patches. process_fn : function Function to apply to patches. Returns ------- img : Tensor Processed image with same shape as the input. """ N, M = img.shape[1], img.shape[2] # Extract patches patches, padded_shape = extract_overlapping_patches(img, patch_size, overlap) # Apply processing function processed_patches = process_fn(patches) # Merge back result = merge_patches_with_triangular_weights( processed_patches, (N, M), patch_size, overlap, padded_shape ) return result