PointNet: 포인트 집합에 대한 심층학습을 통한 3D 분류 및 분할
문제 정의 — 포인트 클라우드의 3가지 근본 어려움
3D 데이터는 복셀(Voxel), 메시(Mesh), 포인트 클라우드(Point Cloud) 세 가지로 표현할 수 있습니다. 복셀은 3D CNN을 적용하기 쉽지만 메모리 사용량이 O(n³)으로 폭발적으로 증가합니다. 포인트 클라우드는 LiDAR나 깊이 센서의 원본 출력이어서 빠르고 직관적이지만, 딥러닝에 직접 적용하기 어려운 3가지 근본적인 문제가 있습니다:
- 순서 불변성(Unordered): 같은 형상의 1000개 포인트를 1000! 가지 서로 다른 순서로 배열해도 동일한 3D 형상입니다. 입력 포인트의 순서와 무관하게 동일한 출력이 나와야 합니다.
- 불균일한 밀도(Irregular): 카메라 근처의 영역은 포인트가 밀집하고 먼 영역은 희박합니다. 픽셀처럼 균일하게 분포한다고 가정할 수 없습니다.
- 변환 불변성(Transformation Variance): 물체를 회전하거나 이동해도 동일한 물체로 인식해야 합니다.
핵심 아이디어 — 대칭 함수(Symmetric Function) + T-Net
해결 방법: 입력 순서에 무관한(대칭적) 집계 함수를 사용합니다. f({x₁,...,xₙ}) ≈ g(h(x₁),...,h(xₙ)) 여기서 g는 대칭 함수입니다. 여러 후보(최댓값, 합, 평균) 중에서 최댓값 풀링(Max Pooling)이 가장 강력한 표현을 학습합니다. N개의 각 포인트에 공유된 MLP h를 적용하고 채널별로 최댓값을 취하면, 포인트 개수에 무관한 1024차원의 전역 특징벡터가 생성됩니다.
회전 및 이동에 대한 불변성은 T-Net(공간 변환 네트워크, STN)으로 해결합니다. 입력 공간에서는 3×3 아핀 변환을, 특징 공간에서는 64×64 아핀 변환을 예측·적용해 정규화된 위치로 정렬합니다. T-Net의 정규화 손실함수는 L_reg = ‖I − AA^T‖²_F 로 직교성(orthogonality)을 강제합니다.
아키텍처
분할 가지(Segmentation Branch): 전역 특징(1024차원)과 포인트별 지역 특징(64차원)을 연결(concat) → 각 포인트마다 클래스 점수를 예측합니다. 전역적 맥락과 지역 정보를 동시에 활용합니다.
💻 코드 — PointNet 분류기 (PyTorch)
import torch, torch.nn as nn
class TNet(nn.Module):
def __init__(self, k=3):
super().__init__(); self.k = k
self.mlp = nn.Sequential(
nn.Conv1d(k,64,1), nn.BatchNorm1d(64), nn.ReLU(),
nn.Conv1d(64,128,1), nn.BatchNorm1d(128), nn.ReLU(),
nn.Conv1d(128,1024,1), nn.BatchNorm1d(1024), nn.ReLU())
self.fc = nn.Sequential(
nn.Linear(1024,512), nn.BatchNorm1d(512), nn.ReLU(),
nn.Linear(512,256), nn.BatchNorm1d(256), nn.ReLU(),
nn.Linear(256, k*k))
nn.init.zeros_(self.fc[-1].weight)
nn.init.eye_(self.fc[-1].bias.view(k,k))
def forward(self, x): # x: (B, k, N)
feat = self.mlp(x).max(-1)[0] # (B, 1024)
return self.fc(feat).view(-1, self.k, self.k)
# 직교성 손실함수 계산
def ortho_loss(A):
B,k,_ = A.shape
I = torch.eye(k, device=A.device).unsqueeze(0).expand(B,-1,-1)
return ((I - A @ A.transpose(1,2))**2).sum((1,2)).mean()
class PointNetCls(nn.Module):
def __init__(self, n_cls=40):
super().__init__()
self.tnet3 = TNet(3); self.tnet64 = TNet(64)
self.mlp1 = nn.Sequential(nn.Conv1d(3,64,1), nn.BatchNorm1d(64), nn.ReLU(),
nn.Conv1d(64,64,1), nn.BatchNorm1d(64), nn.ReLU())
self.mlp2 = nn.Sequential(nn.Conv1d(64,128,1), nn.BatchNorm1d(128), nn.ReLU(),
nn.Conv1d(128,1024,1), nn.BatchNorm1d(1024), nn.ReLU())
self.head = nn.Sequential(nn.Linear(1024,512), nn.BatchNorm1d(512), nn.ReLU(),
nn.Dropout(0.3), nn.Linear(512,256), nn.BatchNorm1d(256),
nn.ReLU(), nn.Dropout(0.3), nn.Linear(256,n_cls))
def forward(self, x): # x: (B,3,N)
x = (self.tnet3(x) @ x.transpose(1,2)).transpose(1,2)
x = self.mlp1(x)
x = (self.tnet64(x) @ x.transpose(1,2)).transpose(1,2)
feat = x
g = self.mlp2(x).max(-1)[0] # 전역 특징 (B,1024)
return self.head(g), ortho_loss(self.tnet64(feat))
# 샘플 사용
pts = torch.randn(4,3,1024)
logits, reg = PointNetCls()(pts)
print(logits.shape, reg.item()) # (4,40) 스칼라
결과
| 작업 | 데이터셋 | 평가지표 | 성능 |
|---|---|---|---|
| 3D 분류 | ModelNet40 | 정확도 | 89.2% |
| 부분 분할 | ShapeNet Part | mIoU | 83.7% |
| 장면 분할 | S3DIS | 정확도 | 78.6% |
💬 평가
PointNet은 포인트 클라우드를 딥러닝으로 직접 처리한 최초의 실용적 방법으로 3D 비전의 패러다임을 바꿨습니다. "대칭 함수로 순열 불변성(permutation invariance)을 해결한다"는 통찰력은 지금도 유효합니다.
그러나 명확한 한계가 있습니다: 각 포인트를 독립적으로 처리하기 때문에 인접한 포인트들 간의 관계(지역 기하학적 구조)를 무시합니다. 불균일한 밀도에도 취약합니다. 이 두 가지 한계를 계층적 지역 학습으로 극복한 것이 PointNet++입니다. 그럼에도 불구하고 PointNet은 오늘날 PointPillar, PointRCNN 등 실시간 3D 객체 감지 시스템의 기본 모듈로서 여전히 활발히 사용되고 있습니다.