SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

🏷️ Model Name

SimCLR – A Simple Framework for Contrastive Learning of Visual Representations

🧠 Core Idea

SimCLR learns visual representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss in the latent space. In effect, it pulls positive pairs (two augmented views of the same image) close together while pushing them away from all other images in the batch (the negatives).


πŸ–ΌοΈ Architecture

            β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
            β”‚          Raw image x        β”‚
            β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                            β”‚
         β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
         β”‚                              β”‚
 β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
 β”‚ Data Augmentation1 aug1β”‚   β”‚ Data Augmentation2 aug2β”‚
 β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         β”‚                              β”‚
         β–Ό                              β–Ό
 Encoder (ResNet) f()            Encoder (ResNet) f()
         β”‚                              β”‚
         β–Ό                              β–Ό
      Feature h₁                    Feature hβ‚‚
         β”‚                              β”‚
         β–Ό                              β–Ό
Projection Head (MLP) g()      Projection Head (MLP) g()
         β”‚                              β”‚
         β–Ό                              β–Ό
      Vector z₁                        Vector zβ‚‚
         β”‚                              β”‚
         β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                          β–Ό
                  Contrastive Loss NT-Xent

The pseudocode of SimCLR’s main learning algorithm is as follows:

input: batch size N, constant Ο„, structure of f, g, T.
for sampled minibatch {x_k}^N_{k=1} do
    for all k∈{1,...,N} do
        draw two augmentation functions t∼T, tβ€²βˆΌT
        # the first augmentation
        x_{2kβˆ’1} = t(x_k)
        h_{2kβˆ’1} = f(x_{2kβˆ’1}) # representation
        z_{2kβˆ’1} = g(h_{2kβˆ’1}) # projection
        # the second augmentation
        x_{2k} = tβ€²(x_k)
        h_{2k} = f(x_{2k}) # representation
        z_{2k} = g(h_{2k}) # projection
    end for
    for all i∈{1,...,2N} and j ∈{1,...,2N} do
        s_{i,j} = sim(z_i, z_j) # pairwise similarity
    end for
    define β„“(i,j) as NT-Xent # the normalized temperature-scaled cross entropy loss)
    L=1/(2N)*sum[β„“(2kβˆ’1,2k) + β„“(2k,2kβˆ’1) for k from 1 to N]
    update networks f and g to minimize L
end for
return encoder network f(Β·), and throw away g(Β·)
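
For concreteness, here is a minimal PyTorch-style sketch of one iteration of this algorithm; `f`, `g`, `augment`, and `optimizer` are assumed placeholders for the encoder, projection head, stochastic augmentation pipeline, and optimizer, and the compact loss computation is spelled out pair-by-pair in section 4️⃣ below.

```python
import torch
import torch.nn.functional as F

def simclr_step(f, g, optimizer, x, augment, tau=0.5):
    # x: minibatch of N raw images; augment: stochastic view-generating function
    n = x.size(0)
    x1, x2 = augment(x), augment(x)                  # two correlated views per image
    z = torch.cat([g(f(x1)), g(f(x2))], dim=0)       # 2N projections
    z = F.normalize(z, dim=1)                        # l2-normalize for cosine similarity
    sim = z @ z.t() / tau                            # s_{i,j} / tau, shape (2N, 2N)
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float("-inf"))       # exclude self-similarity (k != i)
    # the positive of view i is the other augmented view of the same image
    pos = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    loss = F.cross_entropy(sim, pos)                 # NT-Xent averaged over all 2N views
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```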

1️⃣ Stochastic Data Augmentation ($x \rightarrow (\tilde{x}_i, \tilde{x}_j)$)

Randomly transforms any given input data example ($x$) to generate two correlated views, denoted $\tilde{x}_i$ and $\tilde{x}_j$, which are considered a positive pair.

The standard augmentation policy sequentially applies three simple augmentations:

  • Random cropping followed by resizing back to the original size (which often includes random horizontal flipping).
  • Random color distortions (including color jittering and color dropping). The combination of random crop and color distortion is crucial for good performance.
  • Random Gaussian blur.
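
Assuming a torchvision-based pipeline, a sketch of this policy could look as follows; the crop size, jitter strengths, blur kernel, and application probabilities are illustrative values rather than prescriptions.

```python
from torchvision import transforms

s = 1.0  # color-distortion strength (illustrative)
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)

simclr_augment = transforms.Compose([
    transforms.RandomResizedCrop(224),                  # random crop + resize back
    transforms.RandomHorizontalFlip(),                  # flip, applied alongside the crop
    transforms.RandomApply([color_jitter], p=0.8),      # color jittering
    transforms.RandomGrayscale(p=0.2),                  # color dropping
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))],
        p=0.5,
    ),                                                   # random Gaussian blur
    transforms.ToTensor(),
])

# Applying the pipeline twice to the same image yields the positive pair (x̃_i, x̃_j).
```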

2️⃣ Neural Network Base Encoder $f(\cdot)$

The encoder network extracts the representation vectors ($h$) from the augmented data examples. The SimCLR framework allows various network architectures without constraints. The authors typically adopt the commonly used ResNet architecture. The output $h_i = f(\tilde{x}_i)$ is obtained after the average pooling layer.
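
Assuming a torchvision ResNet-50, taking $h$ from the layer after average pooling amounts to dropping the classification head; a minimal sketch:

```python
import torch.nn as nn
from torchvision.models import resnet50

encoder = resnet50(weights=None)   # f(·): the framework allows any backbone; ResNet is the usual choice
encoder.fc = nn.Identity()         # drop the classifier so the output is h after average pooling
# h = encoder(x̃) has shape (batch_size, 2048) for ResNet-50
```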

3️⃣ Small Neural Network Projection Head $g(\cdot)$

This component maps the representation $h$ to the vector $z$ in the space where the contrastive loss is applied. Using a learnable nonlinear transformation here substantially improves the quality of the learned representations.

  • The authors use a Multi-Layer Perceptron (MLP) with one hidden layer and a ReLU non-linearity: $z_i = g(h_i) = W^{(2)}\sigma(W^{(1)}h_i)$

After training is completed, the projection head $g(\cdot)$ is discarded, and the encoder $f(\cdot)$ and representation $h$ are used for downstream tasks.
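
A sketch of $g(\cdot)$ matching the formula above; the 2048 → 2048 → 128 layer sizes are a commonly used choice and are assumptions here, not requirements of the framework.

```python
import torch.nn as nn

# z = g(h) = W2 · ReLU(W1 · h); used only for the contrastive loss during pretraining
projection_head = nn.Sequential(
    nn.Linear(2048, 2048),   # W1
    nn.ReLU(inplace=True),   # σ
    nn.Linear(2048, 128),    # W2
)
```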

4️⃣ Contrastive Loss Function

The contrastive prediction task aims to identify $\tilde{x}_j$ (the positive counterpart) within a set of augmented examples $\{\tilde{x}_k\}_{k \neq i}$, given $\tilde{x}_i$.

  • SimCLR utilizes a minibatch of $N$ examples, generating $2N$ augmented data points. The other $2(N-1)$ augmented examples within the minibatch are treated as negative examples for a given positive pair.
  • The specific loss function used is the Normalized Temperature-scaled Cross-Entropy loss (NT-Xent): $$ \ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)} $$
    • $\text{sim}(\cdot,\cdot)$: cosine similarity, i.e., the dot product between $\ell_2$-normalized vectors $u$ and $v$: $\text{sim}(u, v) = u^\top v / (\|u\| \|v\|)$.
    • $\tau$: temperature parameter. Using $\ell_2$ normalization together with an appropriate temperature effectively weights different examples and helps the model learn from hard negatives.
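
Read literally, the per-pair loss $\ell_{i,j}$ above can be transcribed as the short sketch below (the helper name and the assumption that `z` already holds $\ell_2$-normalized projections are mine); in practice all $2N$ per-anchor terms are computed at once with a masked softmax cross-entropy, as in the training-step sketch earlier.

```python
import torch

def ntxent_pair(z, i, j, tau=0.5):
    # z: (2N, d) tensor of l2-normalized projections; returns ℓ_{i,j}
    sim = z @ z.t()                                   # sim(z_i, z_k) = z_i · z_k for unit vectors
    num = torch.exp(sim[i, j] / tau)
    den = sum(torch.exp(sim[i, k] / tau)              # denominator with indicator 1[k != i]
              for k in range(z.size(0)) if k != i)
    return -torch.log(num / den)
```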

🎯 Downstream Tasks

After the self-supervised pretraining phase is complete, the projection head $g(\cdot)$ is discarded, and the encoder network $f(\cdot)$ and the representation $h$ (the layer before the projection head) are used as the feature extractor for the downstream tasks.

  • Linear Evaluation Protocol (sketched after this list)
  • Semi-Supervised Learning
  • Transfer Learning
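
For the linear evaluation protocol, a minimal sketch: freeze $f(\cdot)$ and fit a linear classifier on top of $h$. The names `encoder` and `train_loader`, as well as the optimizer settings and epoch count, are assumed placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def linear_evaluation(encoder, train_loader, num_classes, feat_dim=2048,
                      epochs=90, device="cuda"):
    # Freeze the pretrained encoder f(·) and train only a linear classifier on h.
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad = False
    clf = nn.Linear(feat_dim, num_classes).to(device)
    opt = torch.optim.SGD(clf.parameters(), lr=0.1, momentum=0.9)
    for _ in range(epochs):
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                h = encoder(x)                       # frozen representation h = f(x)
            loss = F.cross_entropy(clf(h), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return clf
```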

πŸ’‘ Strengths

  • Simplicity and Architecture Agnosticism: SimCLR is a simple framework for contrastive learning. It does not require specialized architectures or an explicit memory bank. The framework allows for various choices of network architecture without constraints, typically adopting the standard ResNet.
  • State-of-the-Art Performance: By combining its core components, SimCLR considerably outperforms previous methods in self-supervised, semi-supervised, and transfer learning on ImageNet.
    • Its representations achieve 76.5% top-1 accuracy on linear evaluation, matching the performance of a supervised ResNet-50.
    • When fine-tuned with only 1% of ImageNet labels, it achieves 85.8% top-5 accuracy.
  • Effective Feature Extraction for High-Level Tasks: As a joint-embedding method, SimCLR produces highly semantic features that are well suited to classification tasks, and its representations achieve competitive results under linear probing.
  • Improved Representation Quality: The introduction of a learnable nonlinear projection head ($g(\cdot)$) substantially improves the quality of the representations learned in the layer before the head ($h$).

⚠️ Limitations

  • Reliance on Large Batch Sizes: Contrastive learning in SimCLR benefits from significantly larger batch sizes (e.g., up to 8192 examples), because the in-batch examples are the only source of negatives; small batches provide too few negatives for effective learning.
  • High Computational and Training Demand:
    • The framework benefits from longer training compared to its supervised counterpart.
    • The model benefits more from bigger models (increased depth and width) than supervised learning.
    • Training with large batch sizes may be unstable using standard optimizers, requiring specialized methods like the LARS optimizer.
  • Need for Architectural/Optimization Specifics:
    • To prevent the model from exploiting local information leakage through per-device batch-norm statistics, SimCLR aggregates BN statistics across all devices (global BN) during distributed training (see the sketch after this list).
    • Special care is required to handle negative samples/collapse. The use of in-batch negatives can lead to false negatives.
  • Sensitivity to Augmentation Policy: The system requires tuning the augmentations. Specifically, the composition of multiple data augmentation operations is crucial for defining effective predictive tasks and learning good representations.
  • Suboptimal for Low-Level Tasks: SimCLR, being a Joint Embedding model, is not fit for low-level tasks such as denoising or superresolution (compared to Masked Image Modeling approaches).
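
Concerning the global BN point in the list above: in a PyTorch DistributedDataParallel setup, one common approximation is to synchronize batch-norm statistics across devices before wrapping the model. This is a hedged sketch, not the paper's exact implementation.

```python
import torch.nn as nn
from torchvision.models import resnet50

encoder = resnet50(weights=None)
# Aggregate BN statistics across all processes so per-device statistics cannot
# leak which positive views sit on the same device (approximating global BN).
encoder = nn.SyncBatchNorm.convert_sync_batchnorm(encoder)
```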

πŸ“š Reference


See also