cs231n

cs231n Assignment 3: Q4-1 (Self-supervised Learning, SimCLR 설명)

츄츄츄츄츄츄츄 2023. 3. 9. 12:25

이번 과제에서 구현할 것은 SimCLR (A Simple Framework for Contrastive Learning of Visual Representations (arxiv.org))이다. 과제를 시작하기에 앞서 이번 포스팅에서는 SimCLR에 대한 이해를 한 후, 다음 포스팅에서 과제를 시작해 보겠다.

 

우리가 Assignment 1, 2에서 사용한 딥러닝 방식은 train data에 target label이 붙어 있는 supervised learning 방법을 사용했었다. 그러나 supervised learning의 단점은 위와 같이 모든 train data에 target label을 붙여주어야 하는데, 이는 보통 사람이 하게 된다. 예시로 CIFAR-10데이터에서 사진을 보고 (개, 고양이, 자동차, 트럭) 등등의 class 레이블을 붙이는 과정은 당연히 사람이 할 수 밖에 없다. 문제는 데이터가 매우 커지면, 또는 COCO datasets처럼 target label이 한 단어가 아니라 사진을 설명하는 문장이라면 몇백만개 이상의 사진과 단어를 사람들이 일일이 검토해야 하고 이는 수많은 노동력이 필요하고 노동력은 곧 비용을 뜻한다. self supervised learning은 위와 같은 labeling 과정 없이 사진 분류를 할 수 있게 도와준다.

 

Self supervised learning중  constrastive learning은 사람의 붙인 label을 기준으로 분류하는 것을 학습하지 않고, 사진을 서로 비교하면서 이미지의 주요 features를 살펴보며 비슷한 사진들끼리 비슷한 representation을 가질 수 있도록 하고, 다른 사진들 끼리는 다른 representation을 가질 수 있도록 하는 것을 학습한다. 예시로 두개의 사과 시진이 있으면 두 사진은 비슷한 representation 벡터를 만들 수 있도록 학습하고, 사과 바나나 사진이 있으면 두 사진은 다른 representation 벡터를 만들 수 있도록 학습한다.

 

SimCLR Diagram

SimCLR는 constrastive learning을 위해 각 사진 데이터 x에서 x의 데이터를 변형한 2개의 augmented data(x~i, x~j)를 만든다. 사진 data를 augment 하는 방법에는 여러 방법이 있다. 아래의 사진들이 data를 augment하는 방법들이다. 아래의 방법을 한개 또는 여러개 중첩하여 사용해서 augmented data (x~i, x~j) 를 만든다. x~i, x~j를 augment하는 방법은 반드시 같지 않아도 될 것이다. x~i에는 crop, color distort 등을 사용하고 x~j에는 Gaussian Blur등을 사용해 다른 방법으로 augment 해도 될 것이다. 물론 이런 assymetric 한 augmentation은 정확도가 더 떨어진다고 한다.

그래서 어떤 data augmentation을 사용할지 각 augmentation 방법 당 정확도를 측정해 보았다고 한다. ImageNet을 사용해 정확도를 측정했는데, ImageNet의 이미지 사이즈가 모두 다르기에 augmentation에 모두 crop/resize를 포함하고, 공정성을 위해 한 branch는 더이상 augment하지 않고 한 branch는 아래와 같은 순서로 (1st, 2nd) transformation을 취해 (겹치는 지점은 transformation 1개) 최대 정확도를 비교해 보았다고 한다. 그랬더니 1개의 transformation만을 취한 결과가 거의 항상 최악이었고, 2개를 취한 것 중 제일 좋은 것은 color distortion, cropping이 가장 효과가 좋았다고 한다. 

나아가 제일 좋은 color distort, random cropping 두개를 color distort를 취하지 않은것과 취한것을 비교해보기도 해보았다고 한다. color distort를 취하지 않은 경우 문제점이 있는데 아래의 without color distortion 히스토그램처럼 augmented data 들이 모두 비슷한 pixel intensities를 가지고 있었다고 한다. 비슷한 pixel intensity 만으로도 뒤의 Neural Network가 이미지 분류하는 데에는 어려움이 없어서 이렇게 좋지 않은 representation을 내보내는 결과가 보여졌다고 해 color distortion을 사용해야 좀더 general 한 features를 얻을 수 있을 것이라고 한다.

아무튼 이렇게 augmentation을 사용해 두 개의 augmented data를 얻어낸 후, 기존의 ResNet (사진에서는 f(.)) 과 같은 신경망을 통과시켜 ResNet에서의 features를 얻어 h를 얻어낸다. 다음으로 h를 projection layer (사진에서 g(.), MLP사용) 를 통과시켜 z를 얻는다. 여기서 projection layer의 MLP구조를 달리 해 가며도 정확도를 비교 해 보았는데, Non-linear layer에서 결과가 제일 좋고, dimension은 그렇게 큰 차이를 보여주지 않았다고 한다.

만약 사진이 총 N개 있다면, 각 사진당 2개의 z가 나오므로 z는 총 2*N 개가 생길 것이고, 2*N개중 각각 2개는 positive pair일 것이고 그 2개에 대해서 2*(N-1)개와는 negative pair를 이룰 것이다. SimCLR에서는 이 positive pair와 negative pair간의 NT-Xent loss를 사용해 loss 함수를 구한다. 물론 이 NT-Xent loss도 실험을 통해 구한 최적의 loss 함수이다.

XT-Xent loss function

NT-Xent에서는 temperature (tau) 와 l2 normalization을 사용할지 말지 구분 할 수 있는데 이도 각 케이스별로 실험해 보았다고 한다. l2norm을 사용하지 않으면 contrastive acc는 높았지만 결국 top1 test accuracy는 낮았고 l2norm을 사용하고 tau를 일정 값으로 조정 했을시 top1 test accuracy가 높게 나왔다고 한다.

그리고 학습 할 때 batch size를 달리 해 갔는데, batch size가 클수록 정확도가 올라갔다고 한다. batch size가 많아질 수록, x의 사진 하나 당 비교할 수 있는 negative pair가 많아지므로 각 사진들을 더 다양하게 비교해 볼 수 있고, 당연히 더 잘 분류 할 수 있게 될 것이다.

 

학습하는 과정에서 또한 positive pair와 negative pair의 전체를 기준으로 normalize 하는 Global Normalization, gradient update로는 batch size가 클때 사용하는 LARS Optimizer를 사용했다고 한다.

 

그러면 gradient update하는 parameter들은 어디에 위치해 있느것일까? 라고 생각할 수 있는데, 먼저 x 를 augment 하는 layer는 파라미터가 있을리가 없다. x를 augment하는 방법에는 여러 방법이 있지만 이는 우리가 조작하는 hyperparameter가 될 것이다.

 

다음으로 f(.)는 이미지에서 features를 가져오는 ResNet등의 layer이다. 여기서의 f(.) layer는 기존의 pretrained 된 ResNet을 사용하면 안된다. 기존의 pretrained weight는 이후 features를 통해 바로 직접적인 classification을 하는 것에 train 되어있으므로, ResNet의 모델 구조만 사용하고 weights들은 새로 학습시켜야 할 parameter 일 것이다.

 

그리고 남은 g(.)에서 아까 우리가 MLP를 사용한다 했는데, 이 MLP들의 weights들도 얻은 features를 통해 positive pair끼리의 좋은 representation을 만들기 위해 추론하는 역할을 할 것이다. 그러므로 이 weights들도 학습시키는 parameter 일 것이다.