可微分空间缩放

1. 直觉的冲突:缩放真的不能求导吗?

如果你把缩放理解为 Photoshop 里的"图像大小调整"——比如把 256×256 硬 resize 成 128×128,像素坐标是整数跳跃——那它确实不可导。

但深度学习中使用的 F.grid_sample 走的是另一条路:它不直接搬动像素,而是让"采样坐标"随着参数平滑移动,通过双线性插值计算输出。 坐标是连续的,插值权重是连续的,输出是连续的,Loss 也是连续的——整条链可导。

2. 一维热身:先理解"坐标移动"如何产生梯度

假设输入是一维信号,只有两个采样点 x0,x1x_0, x_1。输出图像上有一个固定位置 uu(比如 u=0.5u=0.5),缩放因子为 aa

采样坐标aa 决定:

p=uap = \frac{u}{a}

grid_sample 的线性插值输出为:

y=(1p)x0+px1y = (1-p) \cdot x_0 + p \cdot x_1

Loss 对 aa 求导,链条如下:

Lossa=Lossy上游梯度yp像素差值 (x1x0)pa坐标移动速度 (u/a2)\frac{\partial \text{Loss}}{\partial a} = \underbrace{\frac{\partial \text{Loss}}{\partial y}}_{\text{上游梯度}} \cdot \underbrace{\frac{\partial y}{\partial p}}_{\text{像素差值 }(x_1-x_0)} \cdot \underbrace{\frac{\partial p}{\partial a}}_{\text{坐标移动速度 }(-u/a^2)}

关键洞察aa 的梯度不是来自"缩放"本身,而是来自"采样坐标 ppaa 移动时,插值权重发生了多少变化"。

3. 进入二维:为什么需要 theta

一维公式 p=u/ap = u/a 很简洁,但图像是二维的,有 xxyy 两个方向。我们需要一个统一的形式来描述"输出坐标如何映射到输入坐标"——这就是仿射变换矩阵 theta

3.1. 仿射变换的数学形式

二维仿射变换的标准写法:

[x采样y采样]=[θ11θ12θ13θ21θ22θ23][x输出y输出1]\begin{bmatrix} x_{\text{采样}} \\ y_{\text{采样}} \end{bmatrix} = \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{bmatrix} \begin{bmatrix} x_{\text{输出}} \\ y_{\text{输出}} \\ 1 \end{bmatrix}

  • (x输出,y输出)(x_{\text{输出}}, y_{\text{输出}}):输出图像上像素的归一化坐标(范围 [1,1][-1, 1])。
  • (x采样,y采样)(x_{\text{采样}}, y_{\text{采样}}):对应到输入图像上的采样坐标
  • 矩阵前两列管线性变换(缩放、旋转、剪切),最后一列管平移。

3.2. 纯缩放时,theta 长什么样?

这段代码只做了等比例缩放,没有旋转和平移。因此:

θ=[1/s0001/s0]\theta = \begin{bmatrix} 1/s & 0 & 0 \\ 0 & 1/s & 0 \end{bmatrix}

其中 ss 就是 self.spatial_scale

代入公式:

x采样=1sx输出y采样=1sy输出\begin{aligned} x_{\text{采样}} &= \frac{1}{s} \cdot x_{\text{输出}} \\ y_{\text{采样}} &= \frac{1}{s} \cdot y_{\text{输出}} \end{aligned}

看到了吗?这就是你那个一维公式 p=u/ap = u/a 的二维版本。 theta 只是把"除以 ss"这个动作,写成了矩阵乘法的形式,以便同时处理 xxyy 两个方向。

4. affine_grid`:从"全局公式"到"逐像素坐标表"

theta 只有 6 个数,它描述的是变换规则。但 grid_sample 需要的是一个逐像素的坐标表grid),里面存着"输出图像上每个像素具体去输入图像的哪里采样"。

具体来说:

张量 形状 含义
theta B×2×3B \times 2 \times 3 全局仿射变换矩阵(“怎么变”)
grid B×H×W×2B \times H \times W \times 2 每个像素的采样坐标(“去哪采”)

affine_grid 的作用就是**“把 theta 广播到每一个像素上”**。它内部会:

  1. 生成输出图像上所有像素的归一化坐标 (xo,yo)(x_o, y_o)
  2. 对每个坐标执行矩阵乘法 θ[xo,yo,1]T\theta \cdot [x_o, y_o, 1]^T
  3. 输出完整的坐标表 grid
1
2
3
4
5
6
7
8
9
theta (2×3 矩阵)

affine_grid: "对每个像素执行矩阵乘法"

grid (H×W×2 坐标表)

grid_sample: "按坐标表去原图插值采样"

output

5. 数值例子:走一遍完整流程

假设:

  • 输入图像 x4×44 \times 4
  • spatial_scale = 2.0(放大 2 倍)
  • inv_scale = 0.5

Step 1:构造 theta

θ=[0.50000.50]\theta = \begin{bmatrix} 0.5 & 0 & 0 \\ 0 & 0.5 & 0 \end{bmatrix}

Step 2:affine_grid 生成坐标

输出图像左上角像素的归一化坐标约为 (1,1)(-1, -1)。经过变换:

[x采样y采样]=[0.50000.50][111]=[0.50.5]\begin{bmatrix} x_{\text{采样}} \\ y_{\text{采样}} \end{bmatrix} = \begin{bmatrix} 0.5 & 0 & 0 \\ 0 & 0.5 & 0 \end{bmatrix} \begin{bmatrix} -1 \\ -1 \\ 1 \end{bmatrix} = \begin{bmatrix} -0.5 \\ -0.5 \end{bmatrix}

含义:输出图像的左上角,要去输入图像的 (0.5,0.5)(-0.5, -0.5) 坐标采样。

因为 s=2s=2(放大),输出图像比输入图像"大",所以所有采样坐标都向中心收缩了 0.50.5 倍。grid 里存的就是这种收缩后的坐标。

Step 3:grid_sample 双线性插值

grid_sample 根据 grid 里的坐标,找到输入图像上相邻的 4 个像素,按小数距离加权插值:

输出像素=w1xi,j+w2xi+1,j+w3xi,j+1+w4xi+1,j+1\text{输出像素} = w_1 x_{i,j} + w_2 x_{i+1,j} + w_3 x_{i,j+1} + w_4 x_{i+1,j+1}

权重 ww 由采样坐标的小数部分决定。如果坐标发生微小偏移,权重会平滑变化,输出也平滑变化——这就是可导性的来源。

6. 梯度链条:从 Loss 一路传回 spatial_scale

现在把完整的梯度链路拼起来:

1
2
3
4
5
6
7
8
9
10
spatial_scale (s)
↓ 倒数运算
inv_scale (1/s)
↓ 广播乘法
theta = [[1/s, 0, 0], [0, 1/s, 0]]
↓ affine_grid(对 theta 可导)
grid(每个像素的采样坐标)
↓ grid_sample(对 grid 可导,双线性插值)
output
↓ Loss

反向传播时:

  1. grid_sample 阶段:算出 grid 的梯度。采样坐标往某个方向移动一点,插值结果怎么变?这取决于相邻像素的差值(也就是一维例子里的 x1x0x_1 - x_0)。

  2. affine_grid 阶段:把 grid 里所有像素的梯度汇总theta。因为每个像素的坐标都是 θ[xo,yo,1]T\theta \cdot [x_o, y_o, 1]^T 算出来的,所以 θ\theta 的每个元素都会收到来自全图像素的梯度贡献。

  3. theta 阶段theta = I \cdot \text{inv\_scale},所以 theta 的梯度会乘上单位矩阵,原封不动地传回 inv_scale

  4. inv_scale 阶段:`inv_scale = 1/s$,导数是 1/s2-1/s^2

最终 spatial_scale 的梯度 = 全图所有像素上"像素差值 × 坐标移动速度"的总和。

7. 工程陷阱(简要)

理解了原理之后,实现时有两个常见的梯度断裂点:

陷阱一:torch.tensor([...]) 会吃掉梯度

1
2
3
4
5
# 错误:inv_scale 被转成 Python float,脱离计算图
theta = torch.tensor([[inv_scale, 0, 0], [0, inv_scale, 0]], ...)

# 正确:用 Tensor 乘法保留梯度
theta = torch.tensor([[1.,0.,0.],[0.,1.,0.]], ...) * inv_scale

陷阱二:训练时的短路返回

1
2
3
4
5
6
7
# 错误:训练时如果 scale≈1 直接 return x,spatial_scale 不参与计算图
if abs(self.spatial_scale.item() - 1.0) < 1e-4:
return x

# 正确:只在推理时启用短路,训练时永远走 grid_sample
if not self.training and abs(self.spatial_scale.item() - 1.0) < 1e-4:
return x

下面是可以直接插入你博客的完整章节,承接前面的原理讲解,专门讲为什么这个参数不好优化,以及工程上怎么治

8. 优化曲面与训练陷阱:为什么 spatial_scale 不好学?

spatial_scale 很少出现传统意义上"多个谷底"的局部最优,但极易陷入平坦高原**、梯度不对称和**尺度歧义**。训练时的典型症状是:参数卡住不动、单向漂移后回不来、或者在 1.0 附近来回震荡。

8.1. 梯度大小的决定性因素

把之前推导的梯度链条再往前推一步,最终 spatial_scale(记为 ss)的梯度可以写成:

LsLy任务梯度spatialI图像空间梯度1s2坐标缩放率\frac{\partial \mathcal{L}}{\partial s} \propto \underbrace{\frac{\partial \mathcal{L}}{\partial y}}_{\text{任务梯度}} \cdot \underbrace{\nabla_{\text{spatial}} I}_{\text{图像空间梯度}} \cdot \underbrace{\frac{1}{s^2}}_{\text{坐标缩放率}}

核心洞察ss 的梯度大小正比于输入图像本身的空间梯度(相邻像素的差值)。

这意味着:

  • 平滑区域(天空、白墙、深层网络的模糊特征图):空间梯度 0\approx 0ss 的梯度 0\approx 0
  • 边缘/纹理区域(物体轮廓、细节):空间梯度大 → ss 的梯度大。

spatial_scale 的优化不是一个"均匀"的过程,它极度依赖当前 batch 里有没有足够的边缘信息。 如果连续几个 batch 都是平滑区域,ss 就像掉进了死海,几乎得不到更新。

8.2. 三种训练陷阱

8.2.1. 陷阱一:平坦高原(Flat Region)

如果输入图像本身很平滑,或者当前 ss 使得所有采样点恰好落在平滑区域:

spatialI0Ls0\nabla_{\text{spatial}} I \approx 0 \Rightarrow \frac{\partial \mathcal{L}}{\partial s} \approx 0

优化器看到梯度为 0,以为到了谷底,实际上只是一片高原。参数会在这里长期停滞,看起来像卡住了。

深层网络里尤其常见:经过十几层卷积后,特征图的空间梯度已经被 ReLU 和 Pooling 抹平了很多,此时在深层特征上学一个全局 spatial_scale,信号极其微弱。

8.2.2. 陷阱二:极端缩放的不对称性(Asymmetric Landscape)

ss 远离 1.0 时,会出现两种不对称的困境:

情况 发生了什么 梯度状况
s1s \ll 1(极度缩小) 大量像素被压缩到很小的区域,采样点密集,高频信息被平均/混叠掉 图像空间梯度被"抹平",梯度微弱,很难推回 s=1s=1
s1s \gg 1(极度放大) 采样坐标外扩,大量采样点落在原图边界外(padding_mode='zeros'),输出被黑色/零值填充 有效像素少,梯度信号稀疏,容易单向漂移向更大的 ss

这造成 Loss Landscape 是严重不对称的:从 s=1s=1 往两边走,一边可能是缓坡,另一边可能是悬崖。一旦优化器把 ss 推到极端值,它就很难自己爬回来。

8.2.3. 陷阱三:尺度歧义(Scale Ambiguity)——真正的局部最优

在某些特定输入下,确实存在多个"看起来都不错"的尺度:

  • 周期性纹理:输入是棋盘格、条纹织物。s=1s=1s=2s=2 采样后,由于混叠(aliasing),可能产生视觉上非常相似的 pattern,导致 Loss 在多个尺度附近都有低谷。
  • Padding 边界主导:如果目标图像本身有很多黑色边界,s=2s=2s=3s=3 都可能让原图内容缩成中心一小块,周围全是黑边,Loss 差异极小。

不过这种情况在自然图像和常规任务中相对少见,前两个陷阱才是工程中的主要敌人。

8.3. 实际训练中的典型症状

基于上面的分析,你在训练日志里可能会观察到:

  1. ss 在 1.0 附近震荡:因为 1.0 附近插值失真最小,梯度最"正常";一旦偏离,梯度不对称,优化器把它推回来,形成震荡。
  2. ss 单向漂移后卡住:比如一直增大到 2.0 附近,然后因为采样点大量出界、梯度消失,再也回不来了。
  3. 深层特征图上学 scale 几乎不动:深层特征空间上非常平滑,空间梯度弱,ss 的更新信号被噪声淹没。

8.4. 缓解措施(工程经验)

问题 具体解法 原理
平坦高原 / 梯度消失 不要只用 pixel-wise Loss(L1/L2/MSE)。改用 Perceptual Loss(VGG feature loss)或 对比学习损失 Pixel loss 对空间缩放极度敏感且容易平坦。Perceptual Loss 在高维特征里保留更多空间结构梯度,让 ss 有信号可学。
极端缩放漂移 加正则化。比如 Lreg=λ(s1.0)2\mathcal{L}_{\text{reg}} = \lambda \cdot (s - 1.0)^2,或者给 ss 设上下界(clamp)。 ss 锚定在 1.0 附近,防止它漂向极端值后掉进梯度死区。
初始化 ss 初始化为 1.0 这是插值最"忠实"原图的状态,也是梯度最对称、最稳定的起点。
学习率 spatial_scale 单独设一个较小的学习率(比如其他参数的 0.1 倍),或者用大 batch / gradient accumulation。 ss 是全局参数,影响所有像素,单步梯度噪声大。小学习率防止它在噪声中乱跳。
多尺度监督 不要只在一个分辨率上算 Loss。在多个尺度上同时监督。 避免某个单一尺度上的平坦区域主导训练。
参数耦合 不要单独只学一个 spatial_scale。把它放进完整的 STN(Spatial Transformer Network) 里,和旋转、平移一起学。 多个几何参数耦合时,Landscape 更复杂,但也更不容易让单一参数卡死。
Warmup 策略 训练前几轮固定 s=1.0s=1.0requires_grad=False),等网络其他部分稳定后再放开。 防止训练初期信号混乱时,ss 被随机梯度推到极端值。