Contrastive Learning - SimCLR 논문 리뷰
A Simple Framework for Contrastive Learning of Visual Representations (ICML 2020)
SimCLR은 데이터 증강(Data Augmentation)과 대조 학습(Contrastive Learning)을 통해 시각적 표현을 강화한다.
SimCLR의 framework는 다음과 같다.
1. 입력데이터 $x$로부터 두 개의 collerative view $\tilde{x}_i$와 $\tilde{x}_j$를 생성
동일한 증강 기법의 집합으로부터 얻은 두 개의 데이터 증강(Augmentation) 연산자 $t$ ~ $\tau$, $t'$ ~ $\tau$을 통해 두 개의 collerative view를 생성한다. 이 두 뷰를 양의 쌍(Positive View)로 간주한다.
(1), (2), (3)을 순차적으로 수행함.
(1) Random crop 후 원래 크기로 resize, (2) Random color distortion, (3) Random gaussian blur
Random crop만 할 시 유사한 색 분포를 공유하여 shortcut 문제가 발생할 수 있음 따라서 random color distortion을 함께 진행함.
지도 학습(Supervised Learning)과 다르게 대조 학습(Contrastive Learning)에선 강한 색상 왜곡이 성능 향상에 크게 기여.
데이터셋: ImageNet, 모델: ResNet-50


2. 인코더 $f(\cdot)$를 통해 표현 벡터(Representation Vector) $h_i, h_j$를 추출
기본인코더로 ResNet을 사용하며 $ h \in \mathbb{R}^d$ 는 평균 풀링(Average Pooling) 층 이후의 출력 벡터이다.
3. 프로젝션 헤드 $g(\cdot)$를 통해 $z_i, z_j$를 추출
Contrastive loss를 재기 위해 데이터 변환(Augmentation)에 불변성(Invariance)을 갖도록 학습됨. 즉, $g(\cdot)$가 색상, 방향 등 특정 정보를 제거할 가능성이 있음.
- 표현 벡터 $h$를 contrastive loss를 적용할 공간으로 변환.
- MLP (Multi-Layer Perceptron) 구조를 사용, $z = g(h) = W^{(2)}\sigma (W^{(1)}h)$
- $\sigma$는 ReLU 활성화 함수
- 비선형 함수를 쓰는 것이 contrastive loss를 적용할 때 제일 좋음. (선형 함수 및 안 썼을 때와 비교해봄)
4. $z_i, z_j$간의 constrative loss를 구하여 두 개의 뷰 간 유사도를 최대화하도록 학습
즉 대조 손실은 양의 쌍 $z_i, z_j$를 가깝게 만들고, 나머지 $z_k$와는 멀어지도록 학습함.

- : Positive pair의 두 벡터 (같은 이미지에서 augmentation된 두 개의 표현 벡터)
- $\text{sim}(z_i, z_j)$: Cosine Similarity로 두 벡터의 유사도를 측정
- $\tau$: Temperature parameter
- 작을수록 더 극단적인 유사도 차이를 강조
- $1_{[k \neq i]}$: Indicator function (자기 자신을 제외한 나머지 벡터만 사용)
- 핵심 아이디어
- Positive Pair: $z_i, z_j$가 까워지도록 손실을 줄임
- Negative Pair: 같은 미니배치 내의 다른 샘플들 ${z_k}$와 거리를 멀어지도록 학습
- 온도 매개변수 $\tau$를 조절하여 hard negative mining 효과를 조정
- Hard negative mining: 학습이 어려운 Negative 샘플을 추출하는 방법
def contrastive_loss_fn(z_i, z_j, temperature=0.1):
"""Compute NT-Xent contrastive loss."""
batch_size = z_i.shape[0]
labels = torch.arange(batch_size, device=z_i.device)
mask = torch.eye(batch_size, device=z_i.device)
# logits 계산:
# logits_aa: hidden1과 자기 자신(hidden1_large) 간의 유사도 (self-similarity 제거)
logits_ii = torch.mm(z_i, z_i.t()) / temperature
logits_ii = logits_ii - mask * LARGE_NUM
# logits_bb: hidden2와 자기 자신(hidden2_large) 간의 유사도 (self-similarity 제거)
logits_jj = torch.mm(z_j, z_j.t()) / temperature
logits_jj = logits_jj - mask * LARGE_NUM
# logits_ab: 첫 번째 view와 두 번째 view 간의 유사도
logits_ij = torch.mm(z_i, z_j.t()) / temperature
# logits_ba: 두 번째 view와 첫 번째 view 간의 유사도
logits_ji = torch.mm(z_j, z_i.t()) / temperature
# 두 branch에 대해 concatenation:
# 각 샘플에 대해, 정답 짝이 포함된 첫 batch_size 열이 정답
logits_a = torch.cat([logits_ij, logits_ii], dim=1) # shape: (batch_size, 2*batch_size)
logits_b = torch.cat([logits_ji, logits_jj], dim=1) # shape: (batch_size, 2*batch_size)
# cross entropy loss 계산:
loss_a = F.cross_entropy(logits_a, labels, reduction="mean")
loss_b = F.cross_entropy(logits_b, labels, reduction="mean")
loss = loss_a + loss_b
return loss
* Contrastive learning 성능 좋아지는 법
Batch Size, Model depth, Model width를 늘리자