3D Gaussian Splatting(3DGS)现在几乎成了3D视觉领域的标配技术。NVIDIA把它整合进COSMOS,Meta的新款AR眼镜可以直接在设备端跑3DGS做实时环境捕获和渲染。这技术已经不只是停留在论文阶段了,产品落地速度是相当快的。
所以这篇文章我们用PyTorch从头实现最初那篇3DGS论文,代码量控制在几百行以内。虽然实现很简洁但效果能达到SOTA水平。
需要说明的是,这里主要讲实现细节,不会涉及每个公式背后的数学推导。
场景如何表示
3DGS把场景表示成一堆各向异性的3D高斯分布。这跟NeRF那种把场景隐式编码在神经网络里的做法不太一样,3DGS的表示是显式的、可微的。
我们这次直接加载已经训练好的场景,只跑forward pass,不涉及训练和反向传播。
# Load the pre-trained Gaussian parameters of the kitchen scene and move them to the GPU.
# NOTE(review): the "_7000" suffix presumably denotes the training iteration — confirm.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
pos = torch.load('trained_gaussians/kitchen/pos_7000.pt').cuda()  # Gaussian means (3D positions)
opacity_raw = torch.load('trained_gaussians/kitchen/opacity_raw_7000.pt').cuda()  # raw opacity; render() applies the sigmoid
f_dc = torch.load('trained_gaussians/kitchen/f_dc_7000.pt').cuda()  # spherical-harmonics coefficients (first term)
f_rest = torch.load('trained_gaussians/kitchen/f_rest_7000.pt').cuda()  # remaining spherical-harmonics coefficients
scale_raw = torch.load('trained_gaussians/kitchen/scale_raw_7000.pt').cuda()  # scale part of the covariance parameterisation
q_raw = torch.load('trained_gaussians/kitchen/q_rot_7000.pt').cuda()  # rotation quaternion part of the covariance parameterisation
每个3D高斯分布用均值(pos)和协方差矩阵来描述,这俩参数定义了高斯分布在3D空间里的位置和形状——是被拉长的还是接近球形。协方差矩阵用尺度(scale_raw)和旋转(q_raw)来参数化,对应论文里的公式(6)。
另外每个高斯还存了不透明度(opacity_raw)和颜色信息。颜色是view-dependent的,跟NeRF类似,所以不能存成固定的RGB值。这里用球谐函数(Spherical Harmonics)来表示颜色随观察方向的变化,系数就是f_dc和f_rest。
渲染pipeline
3DGS的渲染可以分成两个主要阶段:
第一阶段是预处理。把3D高斯投影到图像平面并按深度排序然后组织成tile。这样做是为了并行处理更高效,避免重复计算。
第二阶段是针对每个tile做体积渲染。遍历跟tile重叠的那些高斯,用体积渲染方程累加它们的贡献。最后通过alpha compositing算出像素颜色,原理上跟NeRF的渲染类似,但这里是在屏幕空间显式实现的。
值得注意的是,3DGS在像素(tile)上并行化渲染,其实也可以在高斯上并行化。像素并行化的好处是相邻像素能共享空间信息,可以提前剔除不相关的高斯,渲染速度基本不受高斯数量影响。这两种策略在论文中都有讨论。
代码结构基本就是这两个阶段:
@torch.no_grad()
def render(pos, color, opacity_raw, sigma, c2w, H, W, fx, fy, cx, cy,
           near=2e-3, far=100, pix_guard=64, T=16, min_conis=1e-6,
           chi_square_clip=9.21, alpha_max=0.99, alpha_cutoff=1/255.):
    """Outline of the forward-only, tile-based Gaussian splatting renderer.

    The numbered "Block" comments are placeholders: the code for each block
    is presented separately and is meant to be pasted at the marked position
    (blocks 1-4 before the tile loop, blocks 5-6 inside it).

    Args:
        pos: Gaussian centres in world space.
        color: Per-Gaussian RGB colour for the current view.
        opacity_raw: Raw opacities; a sigmoid is applied in block 1.
        sigma: Per-Gaussian 3D covariance matrices in world space.
        c2w: Camera-to-world transform of the view to render.
        H, W: Output image height and width in pixels.
        fx, fy, cx, cy: Pinhole camera intrinsics.
        near, far: Gaussians whose camera-space depth is outside
            (near, far) are discarded.
        pix_guard: Margin in pixels around the image inside which a
            projected centre is still kept.
        T: Side length of a square tile in pixels.
        min_conis: Lower bound applied to the diagonal of the inverse
            2D covariance.
        chi_square_clip: Largest value of the quadratic form for which a
            Gaussian still contributes to a pixel.
        alpha_max: Upper clamp on the per-pixel alpha of a single Gaussian.
        alpha_cutoff: Alphas below this value are treated as zero.

    Returns:
        Rendered image of shape [H, W, 3] with values clamped to [0, 1].
    """
    # Block 1: Project the Gaussians onto the image plane
    # Block 2: Global depth sorting
    # Block 3: Tiling
    # Block 4: Compute the inverse covariance matrix
    final_image = torch.zeros((H * W, 3), device=pos.device, dtype=pos.dtype)
    # Iterate over tiles; [s0, s1) is the slice of gaussian_ids for this tile
    for tile_id, s0, s1 in zip(unique_tile_ids.tolist(), start.tolist(), end.tolist()):
        current_gaussian_ids = gaussian_ids[s0:s1]
        # Block 5: Compute pixel coordinates for this tile
        # Block 6: Apply the volumetric rendering equation
    return final_image.reshape((H, W, 3)).clamp(0, 1)
1、投影高斯分布
用相机内参(fx, fy, cx, cy)和相机到世界的变换矩阵(c2w)把3D的均值和协方差投影出来。论文里公式5定义了协方差矩阵Σ怎么从3D世界空间转到2D相机空间。
# Block 1: project the Gaussian centres and covariances onto the image plane.
uv, x, y, z = project_points(pos, c2w, fx, fy, cx, cy)

# Keep a Gaussian only if its projected centre lies inside the image grown by
# a pix_guard margin and its camera-space depth is strictly within (near, far).
u_ok = (uv[:, 0] > -pix_guard) & (uv[:, 0] < W + pix_guard)
v_ok = (uv[:, 1] > -pix_guard) & (uv[:, 1] < H + pix_guard)
z_ok = (z > near) & (z < far)
in_guard = u_ok & v_ok & z_ok

uv, pos, color, sigma = uv[in_guard], pos[in_guard], color[in_guard], sigma[in_guard]
x, y, z = x[in_guard], y[in_guard], z[in_guard]
opacity = torch.sigmoid(opacity_raw[in_guard]).clamp(0, 0.999)
# Positions of the surviving Gaussians in the unfiltered input
idx = in_guard.nonzero(as_tuple=False).squeeze(1)

# Jacobian of the pinhole projection (u, v) = (fx * x / z + cx, fy * y / z + cy)
rot_c2w = c2w[:3, :3]
rot_w2c = rot_c2w.t()
inv_depth = 1 / z.clamp_min(1e-6)
inv_depth_sq = inv_depth * inv_depth
jac = pos.new_zeros((pos.shape[0], 2, 3))
jac[:, 0, 0] = fx * inv_depth
jac[:, 1, 1] = fy * inv_depth
jac[:, 0, 2] = -fx * x * inv_depth_sq
jac[:, 1, 2] = -fy * y * inv_depth_sq

# Rotate the world-space covariance into the camera frame, then push it
# through the Jacobian to get the 2x2 screen-space covariance (Eq. 5).
cov_camera_3d = rot_w2c.unsqueeze(0) @ sigma @ rot_w2c.t().unsqueeze(0)
sigma_camera = jac @ cov_camera_3d @ jac.transpose(1, 2)
# Remove numerical asymmetry before the symmetric eigendecomposition
sigma_camera = 0.5 * (sigma_camera + sigma_camera.transpose(1, 2))
# Clamp the eigenvalues into [1e-6, 1e4] and rebuild, which keeps the
# matrix positive definite and bounds the splat size.
evals, evecs = torch.linalg.eigh(sigma_camera)
evals = evals.clamp(min=1e-6, max=1e4)
sigma_camera = evecs @ torch.diag_embed(evals) @ evecs.transpose(1, 2)
2、全局深度排序
为了保证透明度混合的正确性,要把高斯按深度从近到远排序(代码里对z做升序排序),这个思路跟NeRF的front-to-back累积是一样的。
# Block 2: global depth sorting — ascending z, i.e. nearest Gaussians first.
order = z.argsort(descending=False)
# Apply the same permutation to every per-Gaussian tensor used downstream.
uv, color, opacity = uv[order], color[order], opacity[order]
sigma_camera, evals, idx = sigma_camera[order], evals[order], idx[order]
# Split the sorted screen coordinates into their two components.
u, v = uv[:, 0], uv[:, 1]
3、Tiling划分
屏幕会被切成一个个tile(比如16×16像素),然后在tile内部的像素上并行渲染。这个设计跟Pulsar论文里提到的基于tile、像素并行的策略一致。
# Block 3: tiling — assign every Gaussian to the screen tiles touched by its
# axis-aligned bounding box, then group the (tile, Gaussian) pairs by tile.

# torch.linalg.eigh returns eigenvalues in ascending order, so column 1 holds
# the variance along the major axis of the projected Gaussian.
major_variance = evals[:, 1].clamp_min(1e-12).clamp_max(1e4)  # [N]
# Bounding radius in pixels: three standard deviations along the major axis
radius = torch.ceil(3.0 * torch.sqrt(major_variance)).to(torch.int64)
umin = torch.floor(u - radius).to(torch.int64)
umax = torch.floor(u + radius).to(torch.int64)
vmin = torch.floor(v - radius).to(torch.int64)
vmax = torch.floor(v + radius).to(torch.int64)

# Drop Gaussians whose bounding box does not overlap the image at all.
on_screen = (umax >= 0) & (umin < W) & (vmax >= 0) & (vmin < H)
if not on_screen.any():
    # RuntimeError is still caught by callers that catch Exception.
    raise RuntimeError("All projected points are off-screen")
u, v = u[on_screen], v[on_screen]
color = color[on_screen]
opacity = opacity[on_screen]
sigma_camera = sigma_camera[on_screen]
umin, umax = umin[on_screen], umax[on_screen]
vmin, vmax = vmin[on_screen], vmax[on_screen]
idx = idx[on_screen]

# Clip the bounding boxes to the image.
umin = umin.clamp(0, W - 1)
umax = umax.clamp(0, W - 1)
vmin = vmin.clamp(0, H - 1)
vmax = vmax.clamp(0, H - 1)

# Tile index for each AABB
umin_tile = (umin // T).to(torch.int64)  # [N]
umax_tile = (umax // T).to(torch.int64)  # [N]
vmin_tile = (vmin // T).to(torch.int64)  # [N]
vmax_tile = (vmax // T).to(torch.int64)  # [N]

# Number of tiles each gaussian intersects
n_u = umax_tile - umin_tile + 1  # [N]
n_v = vmax_tile - vmin_tile + 1  # [N]
nb_gaussians = umin_tile.shape[0]
nb_tiles_per_gaussian = n_u * n_v  # [N]

# Enumerate the (Gaussian, tile) pairs directly. Memory is proportional to
# the number of pairs P rather than N * max(n_u) * max(n_v), so a single
# very large splat no longer inflates the work done for every other splat.
gaussian_ids = torch.repeat_interleave(
    torch.arange(nb_gaussians, device=pos.device, dtype=torch.int64),
    nb_tiles_per_gaussian)  # [P], e.g. [0, 0, 0, 0, 1 ...]
# Offset of each Gaussian's first pair (exclusive prefix sum)
first_pair = torch.cumsum(nb_tiles_per_gaussian, dim=0) - nb_tiles_per_gaussian  # [N]
# Position of each pair inside its Gaussian's n_u x n_v tile rectangle
local = torch.arange(gaussian_ids.shape[0], device=pos.device,
                     dtype=torch.int64) - first_pair[gaussian_ids]  # [P]
n_v_per_pair = n_v[gaussian_ids]  # [P]
flat_tile_u = umin_tile[gaussian_ids] + torch.div(
    local, n_v_per_pair, rounding_mode='floor')  # [P]
flat_tile_v = vmin_tile[gaussian_ids] + local % n_v_per_pair  # [P]

nb_tiles_u = (W + T - 1) // T
flat_tile_id = flat_tile_v * nb_tiles_u + flat_tile_u  # [P]

# Sort the pairs by (tile id, Gaussian id) with a single composite key. The
# Gaussians are already depth-sorted, so a smaller id means closer to the
# camera and every tile ends up with its Gaussians in front-to-back order.
M = nb_gaussians + 1
comp = flat_tile_id * M + gaussian_ids
comp_sorted, perm = torch.sort(comp)
gaussian_ids = gaussian_ids[perm]
tile_ids_1d = torch.div(comp_sorted, M, rounding_mode='floor')

# tile_ids_1d [0, 0, 0, 1, 1, 2, 2, 2, 2]
# nb_gaussian_per_tile [3, 2, 4]
# start [0, 3, 5]
# end [3, 5, 9]
unique_tile_ids, nb_gaussian_per_tile = torch.unique_consecutive(tile_ids_1d, return_counts=True)
start = torch.zeros_like(unique_tile_ids)
start[1:] = torch.cumsum(nb_gaussian_per_tile[:-1], dim=0)
end = start + nb_gaussian_per_tile
4、逆协方差矩阵
要算每个高斯的不透明度贡献,先要计算出它的高斯概率密度函数(PDF),也就是论文里的公式(4)。所以就需要逆协方差矩阵,这是直接从相机坐标系下的协方差矩阵算出来。
# Block 4: invert every 2x2 screen-space covariance.
inverse_covariance = inv2x2(sigma_camera)
# Floor both diagonal entries at min_conis.
for k in (0, 1):
    inverse_covariance[:, k, k] = inverse_covariance[:, k, k].clamp(min=min_conis)
5、像素坐标计算
公式(4)还依赖每个像素到高斯中心的距离。在每个tile内部先算出屏幕空间的像素坐标,下一步会用这些坐标高效地计算所有高斯的公式(4)。
txi = tile_id % nb_tiles_u
tyi = tile_id // nb_tiles_u
x0, y0 = txi * T, tyi * T
x1, y1 = min((txi + 1) * T, W), min((tyi + 1) * T, H)
if x0 >= x1 or y0 >= y1:
continue
xs = torch.arange(x0, x1, device=pos.device, dtype=pos.dtype)
ys = torch.arange(y0, y1, device=pos.device, dtype=pos.dtype)
pu, pv = torch.meshgrid(xs, ys, indexing='xy')
px_u = pu.reshape(-1) # [T * T]
px_v = pv.reshape(-1)
pixel_idx_1d = (px_v * W + px_u).to(torch.int64)
6、体积渲染方程
最后用标准的alpha compositing累加颜色。每个高斯根据自己的不透明度和沿视线方向累积的透射率来贡献颜色。这步对应NeRF里的核心体积渲染原理。
# Block 6: alpha-composite the Gaussians of this tile, in their sorted order.
ids = current_gaussian_ids
center_u = u[ids]  # [N]
center_v = v[ids]  # [N]
rgb = color[ids]  # [N, 3]
base_opacity = opacity[ids]  # [N]
conic = inverse_covariance[ids]  # [N, 2, 2]

# Offset of every pixel of the tile from every Gaussian centre
du = px_u.unsqueeze(0) - center_u.unsqueeze(-1)  # [N, T * T]
dv = px_v.unsqueeze(0) - center_v.unsqueeze(-1)  # [N, T * T]

# Quadratic form d^T * conic * d, expanded for the 2x2 case
A11 = conic[:, 0, 0].unsqueeze(-1)  # [N, 1]
A12 = conic[:, 0, 1].unsqueeze(-1)
A22 = conic[:, 1, 1].unsqueeze(-1)
q = A11 * du * du + 2 * A12 * du * dv + A22 * dv * dv  # [N, T * T]

# Gaussian falloff; pixels beyond chi_square_clip receive no contribution
within_clip = q <= chi_square_clip
g = torch.exp(-0.5 * q.clamp(max=chi_square_clip))  # [N, T * T]
g = torch.where(within_clip, g, torch.zeros_like(g))

# Per-pixel alpha: capped at alpha_max, zeroed below alpha_cutoff
alpha = (base_opacity.unsqueeze(-1) * g).clamp_max(alpha_max)  # [N, T * T]
alpha = torch.where(alpha >= alpha_cutoff, alpha, torch.zeros_like(alpha))

# Transmittance reaching each Gaussian: product of (1 - alpha) over all the
# Gaussians in front of it, so the first one sees a transmittance of 1.
transmittance = torch.cumprod(1 - alpha, dim=0)
transmittance = torch.concatenate([
    torch.ones((1, alpha.shape[-1]), device=pos.device, dtype=pos.dtype),
    transmittance[:-1]], dim=0)
# Ignore contributions once the pixel is almost fully opaque
alive = (transmittance > 1e-4).float()
weights = alpha * transmittance * alive  # [N, T * T]
final_image[pixel_idx_1d] = (weights.unsqueeze(-1) * rgb.unsqueeze(1)).sum(dim=0)
球谐函数表示
没有给每个高斯分配单一的RGB值,而是用球谐函数(SH)把颜色表示成观察方向的平滑函数。
简单说,球谐函数就是球面上的傅里叶变换。可以把定义在所有方向上的函数(比如依赖视角的颜色)分解成一系列基函数的加权和。每个高斯学习一组SH系数,编码了颜色随观察方向的变化规律,能捕捉镜面高光或者表面相关的光照变化这类效果。
下面代码计算每个高斯的SH颜色:
# Constant factors of the spherical-harmonics basis terms used by
# evaluate_sh. The suffix of each name lists the monomials it multiplies;
# names sharing a value are assigned together.
SH_C0 = 0.28209479177387814
SH_C1_x = SH_C1_y = SH_C1_z = 0.4886025119029199
SH_C2_xy = SH_C2_xz = SH_C2_yz = 1.0925484305920792
SH_C2_zz = 0.31539156525252005
SH_C2_xx_yy = 0.5462742152960396
SH_C3_yxx_yyy = SH_C3_xxx_xyy = 0.5900435899266435
SH_C3_xyz = 2.890611442640554
SH_C3_yzz_yxx_yyy = SH_C3_xzz_xxx_xyy = 0.4570457994644658
SH_C3_zzz_zxx_zyy = 0.3731763325901154
SH_C3_zxx_zyy = 1.445305721320277
def evaluate_sh(f_dc, f_rest, points, c2w):
    """Evaluate the view-dependent colour of every Gaussian.

    Args:
        f_dc: First SH coefficient of each Gaussian, one value per colour
            channel.
        f_rest: Remaining 15 SH coefficients per channel, stored channel by
            channel along dim 1 (columns 0-14 red, 15-29 green, 30-44 blue).
        points: Gaussian centres, shape [N, 3].
        c2w: Camera-to-world transform; its translation column is used as
            the camera position.

    Returns:
        Tensor of shape [N, 3]: the sigmoid of the SH expansion evaluated
        along the normalised direction from the camera to each point.
    """
    # NOTE(review): Y5, Y7, Y9, Y11, Y13 and Y15 carry a positive sign here,
    # and the output goes through a sigmoid; the reference 3DGS code appears
    # to use negative constants for those terms and a "+0.5, clamp"
    # activation. Presumably this matches how these checkpoints were
    # trained — confirm before loading checkpoints from other code bases.
    n = points.shape[0]
    coeffs = torch.empty((n, 16, 3), device=points.device, dtype=points.dtype)
    coeffs[:, 0] = f_dc
    for ch in range(3):  # R, G, B
        coeffs[:, 1:, ch] = f_rest[:, 15 * ch:15 * (ch + 1)]

    # Unit direction from the camera centre to each Gaussian
    cam_center = c2w[:3, 3]
    direction = points - cam_center.unsqueeze(0)  # [N, 3]
    direction = direction / (direction.norm(dim=-1, keepdim=True) + 1e-8)
    x, y, z = direction[:, 0], direction[:, 1], direction[:, 2]
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z

    basis = torch.stack([
        torch.full_like(x, SH_C0),
        - SH_C1_y * y,
        SH_C1_z * z,
        - SH_C1_x * x,
        SH_C2_xy * xy,
        SH_C2_yz * yz,
        SH_C2_zz * (3 * zz - 1),
        SH_C2_xz * xz,
        SH_C2_xx_yy * (xx - yy),
        SH_C3_yxx_yyy * y * (3 * xx - yy),
        SH_C3_xyz * x * y * z,
        SH_C3_yzz_yxx_yyy * y * (4 * zz - xx - yy),
        SH_C3_zzz_zxx_zyy * z * (2 * zz - 3 * xx - 3 * yy),
        SH_C3_xzz_xxx_xyy * x * (4 * zz - xx - yy),
        SH_C3_zxx_zyy * z * (xx - yy),
        SH_C3_xxx_xyy * x * (xx - 3 * yy),
    ], dim=1)  # [N, 16]

    # Weighted sum over the 16 basis functions, independently per channel
    return torch.sigmoid((coeffs * basis.unsqueeze(2)).sum(dim=1))
这种表示方式很紧凑,表达能力又强能建模复杂的view-dependent效果,还不用在渲染时跑神经网络。实际应用中,3DGS一般用3阶球谐函数,对应每个颜色通道16个系数。这个配置在视觉真实感和内存效率之间平衡得不错。
有个常问的面试题:100万个高斯的3DGS表示需要多少存储空间?
每个高斯存储的内容包括:球谐函数48个系数(16×3),位置3个float,不透明度1个float,尺度3个float,旋转四元数4个float。总共59个float,按每个float 4字节算,大概236字节,100万个高斯就是236 MB(约225 MiB)左右。
另外一个进阶问题:如果只用2阶球谐函数呢?
那就是9×3=27个系数而不是48个,总数从59降到38个float,内存能省36%左右。
辅助函数实现
前面用到的几个辅助函数现在实现一下。首先是计算2×2矩阵的逆,这个比较简单:
def inv2x2(M, eps=1e-12):
    """Invert a batch of 2x2 matrices with the closed-form adjugate formula.

    Args:
        M: Tensor of shape [N, 2, 2].
        eps: Lower bound applied to the determinant before dividing, so a
            zero (or negative) determinant never causes a division by zero.

    Returns:
        Tensor of shape [N, 2, 2] holding adj(M) / max(det(M), eps).
    """
    m00, m01 = M[:, 0, 0], M[:, 0, 1]
    m10, m11 = M[:, 1, 0], M[:, 1, 1]
    denom = (m00 * m11 - m01 * m10).clamp(min=eps)
    top_row = torch.stack((m11 / denom, -m01 / denom), dim=-1)
    bottom_row = torch.stack((-m10 / denom, m00 / denom), dim=-1)
    return torch.stack((top_row, bottom_row), dim=-2)
然后是透视相机的光栅化,把3D点从世界空间投影到2D图像平面。这步用相机内参和外参确定每个高斯在屏幕空间的位置。
从学习到的参数构建协方差矩阵。每个高斯由尺度和旋转四元数定义,它们决定了3D空间里的形状和朝向。先把四元数转成旋转矩阵,再跟对角尺度矩阵组合成完整的3D协方差,渲染时这个协方差会投影到相机空间。
def project_points(pc, c2w, fx, fy, cx, cy):
    """Project world-space points onto the image plane of a pinhole camera.

    Args:
        pc: Points in world space, shape [N, 3].
        c2w: 4x4 camera-to-world transform.
        fx, fy: Focal lengths in pixels.
        cx, cy: Principal point in pixels.

    Returns:
        Tuple (uv, x, y, z): uv has shape [N, 2] and holds the pixel
        coordinates; x, y and z are the camera-space coordinates of the
        points, each of shape [N]. No depth test is applied here, so points
        with z == 0 produce non-finite pixel coordinates that the caller is
        expected to filter out.
    """
    R = c2w[:3, :3]
    t = c2w[:3, 3]
    # Invert the rigid transform analytically: w2c = [R^T | -R^T t].
    # The matrix is created with the dtype of `pc` so that non-float32
    # inputs do not hit a dtype mismatch in the matmul below.
    w2c = torch.eye(4, device=pc.device, dtype=pc.dtype)
    w2c[:3, :3] = R.t()
    w2c[:3, 3] = -R.t() @ t
    homogeneous = torch.cat([pc, torch.ones_like(pc[:, :1])], dim=1)  # [N, 4]
    PC = ((w2c @ homogeneous.t()).t())[:, :3]
    x, y, z = PC[:, 0], PC[:, 1], PC[:, 2]  # Camera space
    uv = torch.stack([fx * x / z + cx, fy * y / z + cy], dim=-1)
    return uv, x, y, z
完整流程整合
最后把所有模块组合起来
if __name__ == "__main__":
    import os

    # Load the trained scene. NOTE: torch.load and np.load(allow_pickle=True)
    # unpickle their input; only point them at files from a trusted source.
    pos = torch.load('trained_gaussians/kitchen/pos_7000.pt').cuda()
    opacity_raw = torch.load('trained_gaussians/kitchen/opacity_raw_7000.pt').cuda()
    f_dc = torch.load('trained_gaussians/kitchen/f_dc_7000.pt').cuda()
    f_rest = torch.load('trained_gaussians/kitchen/f_rest_7000.pt').cuda()
    scale_raw = torch.load('trained_gaussians/kitchen/scale_raw_7000.pt').cuda()
    q_raw = torch.load('trained_gaussians/kitchen/q_rot_7000.pt').cuda()

    cam_parameters = np.load('out_colmap/kitchen/cam_meta.npy',
                             allow_pickle=True).item()
    orbit_c2ws = torch.load('camera_trajectories/kitchen_orbit.pt').cuda()

    # The covariances do not depend on the camera, so build them once.
    sigma = build_sigma_from_params(scale_raw, q_raw)

    # Render at half the source resolution. The intrinsics are the same for
    # every frame, so they are computed once outside the loop.
    H_src = cam_parameters['height']
    W_src = cam_parameters['width']
    H = H_src // 2
    W = W_src // 2
    fx, fy = cam_parameters['fx'], cam_parameters['fy']
    cx, cy = W_src / 2, H_src / 2
    fx, fy, cx, cy = scale_intrinsics(H, W, H_src, W_src, fx, fy, cx, cy)

    # Make sure the output directory exists before the first frame is saved.
    os.makedirs('novel_views', exist_ok=True)

    with torch.no_grad():
        # Wrapping the poses (not the enumerate object) lets tqdm see the
        # number of frames and display a proper progress bar.
        for i, c2w in enumerate(tqdm(orbit_c2ws)):
            color = evaluate_sh(f_dc, f_rest, pos, c2w)
            img = render(pos, color, opacity_raw, sigma, c2w, H, W, fx, fy, cx, cy)
            Image.fromarray((img.cpu().detach().numpy() * 255).astype(np.uint8)
                            ).save(f'novel_views/frame_{i:04d}.png')
总结
这篇文章我们用纯PyTorch实现了3D Gaussian Splatting的完整渲染pipeline,代码量控制在几百行以内。整个实现围绕两个核心阶段展开:预处理阶段完成3D高斯到2D图像平面的投影、深度排序和tile划分;渲染阶段则通过体积渲染方程完成alpha compositing。