Computer Vision/Self-Supervised Learning

Contrastive Learning - SimCLR 논문 리뷰

BeBeom 2025. 3. 29. 09:00

A Simple Framework for Contrastive Learning of Visual Representations (ICML 2020)

SimCLR은 데이터 증강(Data Augmentation)과 대조 학습(Contrastive Learning)을 통해 시각적 표현을 강화한다.

 

SimCLR의 framework는 다음과 같다.

1. 입력데이터 $x$로부터 두 개의 collerative view $\tilde{x}_i$와 $\tilde{x}_j$를 생성

동일한 증강 기법의 집합으로부터 얻은 두 개의 데이터 증강(Augmentation) 연산자 $t$ ~ $\tau$, $t'$ ~ $\tau$을 통해 두 개의 collerative view를 생성한다. 이 두 뷰를 양의 쌍(Positive View)로 간주한다.

더보기

(1), (2), (3)을 순차적으로 수행함.

(1) Random crop 후 원래 크기로 resize, (2) Random color distortion, (3) Random gaussian blur

Random crop만 할 시 유사한 색 분포를 공유하여 shortcut 문제가 발생할 수 있음 따라서 random color distortion을 함께 진행함.

지도 학습(Supervised Learning)과 다르게 대조 학습(Contrastive Learning)에선 강한 색상 왜곡이 성능 향상에 크게 기여.

더보기

데이터셋: ImageNet, 모델: ResNet-50

첫 번째의 random crop을 하고 두 번째에 random color distortion을 하는 것이 좋다는 걸 실험을 통해 증명하였다.
Supervised learning과 다르게 constrative learning은 데이터에 강한 색상 왜곡이 성능에 향상에 좋으며 gaussian blur까지 입힌 것이 제일 좋음(Default Setting)

2. 인코더 $f(\cdot)$를 통해 표현 벡터(Representation Vector) $h_i, h_j$를 추출

기본인코더로 ResNet을 사용하며 $ h \in \mathbb{R}^d$ 는 평균 풀링(Average Pooling) 층 이후의 출력 벡터이다.

3. 프로젝션 헤드 $g(\cdot)$를 통해 $z_i, z_j$를 추출

Contrastive loss를 재기 위해 데이터 변환(Augmentation)불변성(Invariance)을 갖도록 학습됨. 즉, $g(\cdot)$가 색상, 방향 등 특정 정보를 제거할 가능성이 있음.

더보기
  • 표현 벡터 $h$를 contrastive loss를 적용할 공간으로 변환.
  • MLP (Multi-Layer Perceptron) 구조를 사용, $z = g(h) = W^{(2)}\sigma (W^{(1)}h)$
    • $\sigma$는 ReLU 활성화 함수
    • 비선형 함수를 쓰는 것이 contrastive loss를 적용할 때 제일 좋음. (선형 함수 및 안 썼을 때와 비교해봄)

4. $z_i, z_j$간의 constrative loss를 구하여 두 개의 뷰 간 유사도를 최대화하도록 학습

즉 대조 손실은 양의 쌍 $z_i, z_j$를 가깝게 만들고, 나머지 $z_k$와는 멀어지도록 학습함.

더보기
    • : Positive pair의 두 벡터 (같은 이미지에서 augmentation된 두 개의 표현 벡터)
    • $\text{sim}(z_i, z_j)$: Cosine Similarity로 두 벡터의 유사도를 측정
    • $\tau$: Temperature parameter
      • 작을수록 더 극단적인 유사도 차이를 강조
    • $1_{[k \neq i]}$: Indicator function (자기 자신을 제외한 나머지 벡터만 사용)
  • 핵심 아이디어
    • Positive Pair: $z_i, z_j$가 까워지도록 손실을 줄임
    • Negative Pair: 같은 미니배치 내의 다른 샘플들 ${z_k}$와 거리를 멀어지도록 학습
    • 온도 매개변수 $\tau$를 조절하여 hard negative mining 효과를 조정
      • Hard negative mining: 학습이 어려운 Negative 샘플을 추출하는 방법
더보기
def contrastive_loss_fn(z_i, z_j, temperature=0.1):
    """Compute NT-Xent contrastive loss."""
    batch_size = z_i.shape[0]
    labels = torch.arange(batch_size, device=z_i.device)
    mask = torch.eye(batch_size, device=z_i.device)

    # logits 계산:
    # logits_aa: hidden1과 자기 자신(hidden1_large) 간의 유사도 (self-similarity 제거)
    logits_ii = torch.mm(z_i, z_i.t()) / temperature
    logits_ii = logits_ii - mask * LARGE_NUM

    # logits_bb: hidden2와 자기 자신(hidden2_large) 간의 유사도 (self-similarity 제거)
    logits_jj = torch.mm(z_j, z_j.t()) / temperature
    logits_jj = logits_jj - mask * LARGE_NUM

    # logits_ab: 첫 번째 view와 두 번째 view 간의 유사도
    logits_ij = torch.mm(z_i, z_j.t()) / temperature

    # logits_ba: 두 번째 view와 첫 번째 view 간의 유사도
    logits_ji = torch.mm(z_j, z_i.t()) / temperature

    # 두 branch에 대해 concatenation:
    # 각 샘플에 대해, 정답 짝이 포함된 첫 batch_size 열이 정답
    logits_a = torch.cat([logits_ij, logits_ii], dim=1)  # shape: (batch_size, 2*batch_size)
    logits_b = torch.cat([logits_ji, logits_jj], dim=1)  # shape: (batch_size, 2*batch_size)

    # cross entropy loss 계산:
    loss_a = F.cross_entropy(logits_a, labels, reduction="mean")
    loss_b = F.cross_entropy(logits_b, labels, reduction="mean")
    loss = loss_a + loss_b

    return loss

* Contrastive learning 성능 좋아지는 법

Batch Size, Model depth, Model width를 늘리자