Source code for torchbp.interferometry

from .util import process_image_with_patches
from .grid import unpack_polar_grid, unpack_cartesian_grid
import torch
import torch.nn.functional as F
from torch import Tensor
from typing import TYPE_CHECKING
from .ops import subpixel_correlation_op
from scipy.optimize import minimize
import numpy as np

if TYPE_CHECKING:
    from .grid import PolarGrid, CartesianGrid

def _goldstein_patch(patches: Tensor, alpha: float, w: int=3) -> Tensor:
    fpatch = torch.fft.fft2(patches)
    fpatch = torch.fft.fftshift(fpatch, dim=(-2, -1))

    C, P, K, _ = fpatch.shape
    # Reshape to combine first two dims for pooled processing
    fpatches_reshaped = fpatch.view(C * P, 1, K, K)

    # Apply box filter (average pooling)
    filtered = F.avg_pool2d(torch.abs(fpatches_reshaped), kernel_size=w, stride=1, padding=w//2)

    filtered = filtered.view(C, P, K, K)

    filtered = fpatch * filtered**alpha
    del fpatch
    del fpatches_reshaped

    filtered = torch.fft.ifftshift(filtered, dim=(-2, -1))

    filtered = torch.fft.ifft2(filtered)

    return filtered


def goldstein_filter(igram: Tensor, patch_size: int=64, w: int=3, alpha: float=1, overlap: float=0.75) -> Tensor:
    """
    Goldstein phase filter. [1]_

    The interferogram is processed in overlapping square patches; each
    patch spectrum is weighted by its own smoothed magnitude raised to
    ``alpha``.

    Parameters
    ----------
    igram : Tensor
        Complex interferogram. A 2D input is treated as a single-image
        stack and returned as 2D again.
    patch_size : int
        Patch side-length.
    w : int
        Smoothing window size. Must be odd.
    alpha : float
        Smoothing exponent.
    overlap : float
        Overlap between patches as fraction of patch_size.

    Returns
    -------
    filtered : Tensor
        Filtered interferogram.

    Raises
    ------
    ValueError
        If ``w`` is even.

    References
    ----------
    .. [1] R. M. Goldstein and C. L. Werner, "Radar interferogram filtering
        for geophysical applications," in Geophysical Research Letters,
        vol. 25, no. 21, pp 4035-4038, 1998
    """
    if w % 2 == 0:
        raise ValueError(f"w must be odd, got {w}")

    was_2d = igram.dim() == 2
    if was_2d:
        igram = igram.unsqueeze(0)

    # Reflect-pad by up to one patch on every side so that border pixels are
    # covered by as many patches as interior ones. Reflect padding must be
    # smaller than the image dimension, hence the clamp. The margin is
    # cropped off again before returning.
    rows, cols = igram.shape[-2], igram.shape[-1]
    margin = min(patch_size, rows - 1, cols - 1)
    padded = margin > 0
    if padded:
        igram = F.pad(igram, (margin, margin, margin, margin), mode="reflect")

    def filter_patches(x):
        return _goldstein_patch(x, alpha, w)

    overlap_px = int(overlap * patch_size)
    result = process_image_with_patches(igram, patch_size, overlap_px, filter_patches)

    if padded:
        out_rows, out_cols = result.shape[-2], result.shape[-1]
        result = result[..., margin:out_rows - margin, margin:out_cols - margin]

    if was_2d:
        result = result.squeeze(0)

    return result
def phase_to_elevation(unw: Tensor, coords: Tensor, origin1: Tensor, origin2: Tensor, fc: float) -> Tensor:
    """
    Convert phase unwrapped interferogram to elevation.

    Uses the first-order height sensitivity for a flat imaging plane::

        dphi/dh = (4pi/lambda) * (origin2_z/r2 - origin1_z/r1)

    where ``r1`` and ``r2`` are the distances from each antenna phase
    center to the image position.

    Parameters
    ----------
    unw : Tensor
        Unwrapped phase tensor. Shape: [Nx, Ny].
    coords : Tensor
        Coordinates for each position in image. Shape: [3, Nx, Ny].
    origin1 : Tensor
        3D antenna phase center location of the master image.
    origin2 : Tensor
        3D antenna phase center location of the slave image.
    fc : float
        RF center frequency in Hz.

    Returns
    -------
    z : Tensor
        Elevation tensor with the same shape as unw. Positions where the
        sensitivity term is exactly zero are not guarded and yield
        inf/nan from the division.
    """
    c0 = 299792458
    wl = c0 / fc

    # Broadcast each [3] origin against the [3, Nx, Ny] coordinate stack and
    # reduce over the xyz axis to get per-pixel ranges.
    r1 = torch.linalg.norm(coords - origin1[:, None, None], dim=0)
    r2 = torch.linalg.norm(coords - origin2[:, None, None], dim=0)

    sensitivity = origin2[2] / r2 - origin1[2] / r1
    z = -wl * unw / (4 * torch.pi * sensitivity)
    return z
def phase_to_elevation_polar(unw: Tensor, origin1: Tensor, origin2: Tensor, fc: float, grid: "PolarGrid | dict") -> Tensor:
    """
    Convert phase unwrapped interferogram on a polar grid to elevation.

    Builds the z = 0 Cartesian coordinates of every polar grid cell and
    delegates to :func:`phase_to_elevation`.

    Parameters
    ----------
    unw : Tensor
        Unwrapped phase tensor. Shape: [Nx, Ny].
    origin1 : Tensor
        3D antenna phase center location of the master image.
    origin2 : Tensor
        3D antenna phase center location of the slave image.
    fc : float
        RF center frequency in Hz.
    grid : PolarGrid or dict
        Image grid definition. PolarGrid object or dictionary.

    Returns
    -------
    z : Tensor
        Elevation tensor with the same shape as unw.
    """
    device = unw.device
    r0, _r1, theta0, _theta1, nr, ntheta, dr, dtheta = unpack_polar_grid(grid)

    range_axis = r0 + dr * torch.arange(nr, device=device)
    theta_axis = theta0 + dtheta * torch.arange(ntheta, device=device)

    # The theta coordinate enters as y/r, so x follows from sqrt(1 - theta^2).
    x = range_axis[:, None] * torch.sqrt(1 - theta_axis**2)[None, :]
    y = range_axis[:, None] * theta_axis[None, :]
    ground = torch.zeros_like(unw)

    coords = torch.stack([x, y, ground])
    return phase_to_elevation(unw, coords, origin1, origin2, fc)
def phase_to_elevation_cart(unw: Tensor, origin1: Tensor, origin2: Tensor, fc: float, grid: "CartesianGrid | dict") -> Tensor:
    """
    Convert phase unwrapped interferogram on a Cartesian grid to elevation.

    Builds the z = 0 coordinates of every grid cell and delegates to
    :func:`phase_to_elevation`.

    Parameters
    ----------
    unw : Tensor
        Unwrapped phase tensor. Shape: [Nx, Ny].
    origin1 : Tensor
        3D antenna phase center location of the master image.
    origin2 : Tensor
        3D antenna phase center location of the slave image.
    fc : float
        RF center frequency in Hz.
    grid : CartesianGrid or dict
        Image grid definition. CartesianGrid object or dictionary.

    Returns
    -------
    z : Tensor
        Elevation tensor with the same shape as unw.
    """
    device = unw.device
    x0, _x1, y0, _y1, nx, ny, dx, dy = unpack_cartesian_grid(grid)

    x_axis = x0 + dx * torch.arange(nx, device=device)
    y_axis = y0 + dy * torch.arange(ny, device=device)
    # Single-sample z axis at height zero, matched to the x axis dtype.
    z_axis = torch.tensor([0], device=device, dtype=x_axis.dtype)

    xx, yy, zz = torch.meshgrid([x_axis, y_axis, z_axis], indexing="ij")
    # Drop the singleton z dimension to get a [3, nx, ny] coordinate stack.
    coords = torch.stack((xx, yy, zz))[..., 0]

    return phase_to_elevation(unw, coords, origin1, origin2, fc)
def flat_earth_phase_polar(origin1: Tensor, origin2: Tensor, fc: float, grid: "PolarGrid | dict") -> Tensor:
    """
    Compute flat earth interferometric phase for a polar grid.

    Images formed by backprojection onto a flat (z=0) grid carry baseline
    geometry fringes in their interferometric phase even over flat
    terrain. The phase returned here models those fringes so they can be
    subtracted, leaving the topographic signal.

    Parameters
    ----------
    origin1 : Tensor
        3D antenna phase center of the master image [x, y, z].
    origin2 : Tensor
        3D antenna phase center of the slave image [x, y, z].
    fc : float
        RF center frequency in Hz.
    grid : PolarGrid or dict
        Polar grid definition.

    Returns
    -------
    phase : Tensor
        Flat earth phase tensor. Shape: [nr, ntheta].
    """
    r0, _r1, theta0, _theta1, nr, ntheta, dr, dtheta = unpack_polar_grid(grid)
    device = origin1.device

    range_axis = r0 + dr * torch.arange(nr, device=device)
    theta_axis = theta0 + dtheta * torch.arange(ntheta, device=device)

    # Ground-plane Cartesian position of every (range, theta) cell.
    x = range_axis[:, None] * torch.sqrt(1 - theta_axis[None, :] ** 2)
    y = range_axis[:, None] * theta_axis[None, :]

    def distance_to(origin):
        # Distance from an antenna phase center to the z = 0 grid cells.
        return torch.sqrt((x - origin[0]) ** 2 + (y - origin[1]) ** 2 + origin[2] ** 2)

    c0 = 299792458
    d1 = distance_to(origin1)
    d2 = distance_to(origin2)
    return 4 * torch.pi * fc / c0 * (d1 - d2)
def flat_earth_phase_cart(origin1: Tensor, origin2: Tensor, fc: float, grid: "CartesianGrid | dict") -> Tensor:
    """
    Compute flat earth interferometric phase for a Cartesian grid.

    The phase is ``4*pi*fc/c0 * (d1 - d2)`` where ``d1`` and ``d2`` are
    the distances from each antenna phase center to the z = 0 grid cells.

    Parameters
    ----------
    origin1 : Tensor
        3D antenna phase center of the master image [x, y, z].
    origin2 : Tensor
        3D antenna phase center of the slave image [x, y, z].
    fc : float
        RF center frequency in Hz.
    grid : CartesianGrid or dict
        Cartesian grid definition.

    Returns
    -------
    phase : Tensor
        Flat earth phase tensor. Shape: [nx, ny].
    """
    x0, _x1, y0, _y1, nx, ny, dx, dy = unpack_cartesian_grid(grid)
    device = origin1.device

    # Column / row vectors that broadcast to the full [nx, ny] grid.
    x_col = (x0 + dx * torch.arange(nx, device=device))[:, None]
    y_row = (y0 + dy * torch.arange(ny, device=device))[None, :]

    def distance_to(origin):
        # Distance from an antenna phase center to the z = 0 grid cells.
        return torch.sqrt((x_col - origin[0]) ** 2 + (y_row - origin[1]) ** 2 + origin[2] ** 2)

    c0 = 299792458
    d1 = distance_to(origin1)
    d2 = distance_to(origin2)
    return 4 * torch.pi * fc / c0 * (d1 - d2)
def elevation_to_phase_slant_polar(
    z: "Tensor | float", origin1: Tensor, origin2: Tensor, fc: float, grid: "PolarGrid | dict"
) -> Tensor:
    """
    Compute traditional (slant-range) interferometric phase at given elevation.

    For scatterers at height z above the imaging plane::

        phi = (4pi fc/c0)*(r1 - r2)

    where r_n is the slant range from the nth antenna phase center to the
    scatterer. With z = 0 the expression is the same one evaluated by
    :func:`flat_earth_phase_polar`.

    Parameters
    ----------
    z : Tensor or float
        Elevation map [nr, ntheta] or scalar height.
    origin1 : Tensor
        Master APC [x, y, z].
    origin2 : Tensor
        Slave APC [x, y, z].
    fc : float
        RF center frequency (Hz).
    grid : PolarGrid or dict
        Polar grid definition.

    Returns
    -------
    phase : Tensor
        Interferometric phase [nr, ntheta].
    """
    r0, _r1, theta0, _theta1, nr, ntheta, dr, dtheta = unpack_polar_grid(grid)
    device = origin1.device

    range_axis = r0 + dr * torch.arange(nr, device=device)
    theta_axis = theta0 + dtheta * torch.arange(ntheta, device=device)

    # Ground-plane Cartesian position of every (range, theta) cell.
    x = range_axis[:, None] * torch.sqrt(1 - theta_axis[None, :] ** 2)
    y = range_axis[:, None] * theta_axis[None, :]

    def slant_range(origin):
        # Distance from an antenna phase center to the scatterer at height z.
        return torch.sqrt((x - origin[0]) ** 2 + (y - origin[1]) ** 2 + (origin[2] - z) ** 2)

    c0 = 299792458
    d1 = slant_range(origin1)
    d2 = slant_range(origin2)
    return 4 * torch.pi * fc / c0 * (d1 - d2)
def phase_to_elevation_slant_polar(
    unw: Tensor, origin1: Tensor, origin2: Tensor, fc: float, grid: "PolarGrid | dict", n_iter: int = 5
) -> Tensor:
    """
    Convert unwrapped topographic phase to elevation via Newton iteration.

    Solves for z in the traditional interferometric relationship::

        phi_topo(z) = elevation_to_phase_slant_polar(z, ...)
                      - elevation_to_phase_slant_polar(0, ...)

    The iteration is seeded with the linearised BP-interferometry estimate
    and then refined with Newton's method.

    Parameters
    ----------
    unw : Tensor
        Unwrapped topographic phase (flat-earth removed) [nr, ntheta].
    origin1 : Tensor
        Master APC [x, y, z].
    origin2 : Tensor
        Slave APC [x, y, z].
    fc : float
        RF center frequency (Hz).
    grid : PolarGrid or dict
        Polar grid definition.
    n_iter : int
        Number of Newton iterations (default 5).

    Returns
    -------
    z : Tensor
        Estimated elevation [nr, ntheta].
    """
    # Seed: phase_to_elevation_polar computes z = -wl*unw/(4*pi*sens), where
    # the minus sign belongs to the BP conjugate-phase convention. The phase
    # here follows the direct convention, so the seed is negated.
    z = -phase_to_elevation_polar(unw, origin1, origin2, fc, grid)
    phi_flat = elevation_to_phase_slant_polar(0.0, origin1, origin2, fc, grid)

    r0, _r1, theta0, _theta1, nr, ntheta, dr, dtheta = unpack_polar_grid(grid)
    device = unw.device

    range_axis = r0 + dr * torch.arange(nr, device=device)
    theta_axis = theta0 + dtheta * torch.arange(ntheta, device=device)
    x = range_axis[:, None] * torch.sqrt(1 - theta_axis[None, :] ** 2)
    y = range_axis[:, None] * theta_axis[None, :]

    c0 = 299792458
    k = 4 * torch.pi * fc / c0

    # Squared horizontal distances do not depend on z; compute them once.
    horiz_sq1 = (x - origin1[0]) ** 2 + (y - origin1[1]) ** 2
    horiz_sq2 = (x - origin2[0]) ** 2 + (y - origin2[1]) ** 2

    for _ in range(n_iter):
        dz1 = origin1[2] - z
        dz2 = origin2[2] - z
        d1 = torch.sqrt(horiz_sq1 + dz1 ** 2)
        d2 = torch.sqrt(horiz_sq2 + dz2 ** 2)

        # Mismatch between the measured and the modelled topographic phase.
        residual = unw - (k * (d1 - d2) - phi_flat)
        # dphi/dz = k * ((h2 - z)/r2 - (h1 - z)/r1)
        slope = k * (dz2 / d2 - dz1 / d1)
        z = z + residual / slope

    return z
def subpixel_correlation(im_m: Tensor, im_s: Tensor) -> tuple[Tensor, Tensor]:
    """
    Solve for subpixel offset that maximize coherent correlation between
    the two input images. [1]_

    The polynomial coefficients returned by ``subpixel_correlation_op``
    are moved to the CPU and, for each entry along their first dimension,
    the offset in the unit square [0, 1] x [0, 1] with the highest
    correlation is found using SciPy's SLSQP optimiser.

    Parameters
    ----------
    im_m : Tensor
        Master image.
    im_s : Tensor
        Slave image.

    Returns
    -------
    offsets : Tensor
        Solved X and Y subpixel offsets.
    corrs : Tensor
        Correlation at the best subpixel offset.

    References
    ----------
    .. [1] D. Li and Y. Zhang, "A Fast Offset Estimation Approach for InSAR
        Image Subpixel Registration," in IEEE Geoscience and Remote Sensing
        Letters, vol. 9, no. 2, pp. 267-271, March 2012.
    """
    a, b, c = subpixel_correlation_op(im_m, im_s)
    a = a.cpu().numpy()
    b = b.cpu().numpy()
    c = c.cpu().numpy()

    def negative_correlation(shift, ai, bi, ci):
        # Negated so that minimising it maximises the correlation.
        dx, dy = shift
        cross = bi[0] + bi[1]*dx + bi[2]*dy + bi[3]*dx*dy
        power = (ai[0] + ai[1]*dx + ai[2]*dy + ai[3]*dx**2 + ai[4]*dy**2
                 + ai[5]*dx*dy + ai[6]*dx**2*dy + ai[7]*dx*dy**2 + ai[8]*dx**2*dy**2)
        return -np.abs(cross) / np.sqrt(ci * power)

    n_out = c.shape[0]
    corrs = torch.zeros(n_out, dtype=torch.float32, device=im_m.device)
    offsets = torch.zeros((n_out, 2), dtype=torch.float32, device=im_m.device)

    unit_square = [(0, 1), (0, 1)]
    for i in range(a.shape[0]):
        sol = minimize(
            negative_correlation,
            [0, 0],
            args=(a[i], b[i], c[i]),
            method="SLSQP",
            bounds=unit_square,
        )
        offsets[i, 0] = sol.x[0]
        offsets[i, 1] = sol.x[1]
        corrs[i] = -sol.fun

    return offsets, corrs