SwAV: Swapping Assignments between Views

🏷️ Model Name

SwAV – Swapping Assignments between Views

🧠 Core Idea

Simultaneously cluster the data and learn visual representations by enforcing consistency between cluster assignments, or “codes,” generated from different augmented views of the same image.

Figure: SwAV architecture

🖼️ Architecture

Input Image x
 ├──> Augmentation 1 (x₁)
 │     └──> Encoder f_θ
 │            └──> Projection g_θ
 │                   └──> Feature z₁
 │                          └──> Compute assignment q₁ to prototypes C
 │
 └──> Augmentation 2 (x₂)
       └──> Encoder f_θ
              └──> Projection g_θ
                     └──> Feature z₂
                            └──> Compute assignment q₂ to prototypes C

Then

Predict q₂ from z₁, and predict q₁ from z₂
Loss = cross_entropy(pred₁→q₂) + cross_entropy(pred₂→q₁)

The SwAV architecture consists of two main components that are trained jointly:

  • Image Encoder ($f_{\theta}$): This is the convolutional network (e.g., ResNet) used to extract features from the input images. The features are then passed through a projection head, typically a Multi-Layer Perceptron (MLP), to produce a feature vector $z$.
  • Prototypes ($C$): A set of $K$ trainable prototype vectors, $C = \{c_1, \dots, c_K\}$, which act as cluster centers. The matrix $C$ contains these vectors as columns.

1️⃣ Multi-Crop Data Augmentation

The input image ($x$) is transformed into multiple augmented views ($V+2$ views). This strategy is called multi-crop:

  • Global Views: Two standard-resolution crops (e.g., $224 \times 224$ pixels each) are sampled.
  • Small Views: $V$ additional low-resolution crops (e.g., $96 \times 96$ pixels each) are sampled; each covers only a small part of the image.
  • The resulting views are $x^{t_1}, x^{t_2}, \dots, x^{t_{V+2}}$
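A minimal NumPy sketch of multi-crop sampling, assuming an H×W×C image array. The helper names (`random_resized_square_crop`, `multi_crop`), the square crops, the nearest-neighbour resizing, and the area-scale ranges `(0.14, 1.0)` and `(0.05, 0.14)` are illustrative assumptions; a real pipeline would also jitter the aspect ratio and apply color, blur, and flip augmentations.

```python
import numpy as np


def random_resized_square_crop(img, out_size, scale, rng):
    """Crop a random square covering a `scale`-sampled fraction of the image area,
    then resize it to out_size x out_size by nearest-neighbour indexing."""
    h, w = img.shape[:2]
    area_frac = rng.uniform(scale[0], scale[1])
    side = int(round(np.sqrt(area_frac * h * w)))
    side = max(1, min(side, h, w))  # the square must fit inside the image
    top = rng.integers(0, h - side + 1)
    left = rng.integers(0, w - side + 1)
    crop = img[top:top + side, left:left + side]
    idx = (np.arange(out_size) * side / out_size).astype(int)
    return crop[idx][:, idx]  # pick rows, then columns


def multi_crop(img, n_small=6, rng=None,
               global_size=224, small_size=96,
               global_scale=(0.14, 1.0), small_scale=(0.05, 0.14)):
    """Return 2 full-resolution views followed by n_small low-resolution views."""
    rng = np.random.default_rng() if rng is None else rng
    views = [random_resized_square_crop(img, global_size, global_scale, rng)
             for _ in range(2)]
    views += [random_resized_square_crop(img, small_size, small_scale, rng)
              for _ in range(n_small)]
    return views


image = np.random.default_rng(0).random((300, 400, 3))
views = multi_crop(image, n_small=6, rng=np.random.default_rng(1))
print([v.shape for v in views])  # two (224, 224, 3) views, then six (96, 96, 3) views
```

Only the first two views would later be used to compute codes; all eight are used as predictors.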

2️⃣ Feature Extraction ($f_{\theta}$) and Normalization ($\ell_2$)

All $V+2$ augmented views are passed through the image encoder $f_{\theta}$:

  1. For any augmented view $x^{nt}$ (image $n$ under transformation $t$), the encoder produces a feature representation $f_{\theta}(x^{nt})$.
  2. This feature is then $\ell_2$-normalized (projected to the unit sphere) to yield the final feature vector $z^{nt} = f_{\theta}(x^{nt}) / \|f_{\theta}(x^{nt})\|_2$.
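The normalization step can be sketched in NumPy as below. The random weights stand in for a trained encoder and projection head; the two-layer head, its 512/128 widths, and the omission of BatchNorm and biases are assumptions for illustration only.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """Project each row onto the unit sphere: z = x / ||x||_2."""
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


def projection_head(h, w1, w2):
    """Hypothetical 2-layer MLP head (Linear -> ReLU -> Linear)."""
    return np.maximum(h @ w1, 0.0) @ w2


rng = np.random.default_rng(0)
h = rng.normal(size=(4, 2048))            # stand-in for pooled ResNet-50 features of 4 views
w1 = rng.normal(size=(2048, 512)) * 0.02  # hypothetical hidden width
w2 = rng.normal(size=(512, 128)) * 0.02   # hypothetical output dimension
z = l2_normalize(projection_head(h, w1, w2))
print(z.shape, np.linalg.norm(z, axis=1))
```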

3️⃣ Online Code (Cluster Assignment) Computation

SwAV computes image “codes” or soft cluster assignments ($Q$) for the features in the current batch ($B$) by mapping them to the prototypes ($C$).

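  1. Similarity Calculation: The features of the batch ($Z = [z_1, \dots, z_B]$) are compared to the prototypes ($C$) through dot products, giving the $K \times B$ score matrix $C^\top Z$; since features and prototypes are unit-norm, these are cosine similarities.
  2. Optimal Transport: The assignment process maximizes the similarity between features and prototypes while enforcing an equipartition constraint. Concretely, it solves $\max_{Q \in \mathcal{Q}} \mathrm{Tr}(Q^\top C^\top Z) + \epsilon H(Q)$, where $H(Q) = -\sum_{ij} Q_{ij} \log Q_{ij}$ is the entropy and $\mathcal{Q}$ is the set of nonnegative $K \times B$ matrices whose rows sum to $1/K$ and whose columns sum to $1/B$.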
    • This constraint ensures that, on average, each prototype is selected equally often in the batch, preventing trivial solutions where all samples receive the same code.
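  3. Sinkhorn-Knopp Algorithm: The soft assignment matrix $Q^*$ (the codes) is obtained by solving this optimal transport problem with a few (e.g., 3) Sinkhorn-Knopp iterations. The solution takes the form $Q^* = \mathrm{Diag}(u) \exp\left(\frac{C^\top Z}{\epsilon}\right) \mathrm{Diag}(v)$, where the iterations compute the renormalization vectors $u \in \mathbb{R}^K$ and $v \in \mathbb{R}^B$.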
  4. Code Input Restriction: Importantly, the codes ($q$) are typically computed using only the full-resolution crops (the two global views), not the small crops, to avoid degrading assignment quality due to partial information.
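A minimal NumPy sketch of the Sinkhorn-Knopp step, assuming the scores are the feature-prototype dot products laid out as a $(B, K)$ array. `eps=0.05` and `n_iters=3` are assumed defaults, and the single-device loop is a simplification (a multi-GPU version would sum the normalizers across devices). In training, this would run without gradient tracking, since the codes serve as fixed targets.

```python
import numpy as np


def sinkhorn(scores, eps=0.05, n_iters=3):
    """Soft codes from prototype scores via Sinkhorn-Knopp.

    scores: (B, K) array of similarities between B features and K prototypes.
    Returns a (B, K) array whose rows are soft assignments (each sums to 1) and
    whose per-prototype totals are pushed toward B / K (equipartition).
    """
    # Subtracting the global max leaves the result unchanged (it cancels in the
    # normalizations) but keeps exp() from overflowing when eps is small.
    Q = np.exp((scores - scores.max()) / eps).T  # (K, B)
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(axis=1, keepdims=True)  # each prototype (row) gets total mass 1/K
        Q /= K
        Q /= Q.sum(axis=0, keepdims=True)  # each sample (column) gets total mass 1/B
        Q /= B
    Q *= B  # rescale so each sample's code sums to 1
    return Q.T


rng = np.random.default_rng(0)
B, K = 8, 4
# Degenerate batch: all samples share the same scores, as with collapsed features.
collapsed = np.tile(rng.uniform(-1, 1, size=(1, K)), (B, 1))
codes = sinkhorn(collapsed)
print(codes.round(3))  # every entry is 0.25: mass is spread, not piled on one prototype
```

Because the constraint forces every prototype to receive the same total mass, a batch of identical features cannot put all of its mass on one prototype; each sample ends up with a uniform code instead.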

4️⃣ Swapped Prediction Loss

The core of SwAV is the “swapped” prediction problem, where the code generated by one augmented view is predicted by the feature representation of another augmented view of the same image.

  1. Swapped Loss Components: For any pair of features ($z^t, z^s$) and their corresponding codes ($q^t, q^s$), the loss minimizes the cross-entropy between:
    • The prediction derived from feature $z^t$ and the target code $q^s$.
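    • The prediction derived from feature $z^s$ and the target code $q^t$.
    $$L(z^t, z^s) = \mathcal{L}(z^t, q^s) + \mathcal{L}(z^s, q^t)$$
    Each term is the cross-entropy between a code and the softmax of a feature's similarities to the prototypes, with temperature $\tau$: $$\mathcal{L}(z^t, q^s) = -\sum_{k} q^s_k \log p^t_k, \qquad p^t_k = \frac{\exp\left((z^t)^\top c_k / \tau\right)}{\sum_{k'} \exp\left((z^t)^\top c_{k'} / \tau\right)}$$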
  2. Generalization to Multi-Crop: The total loss generalizes this concept across all $V+2$ views. For the views $z^{t_1}, \dots, z^{t_{V+2}}$, every view $z^{t_v}$ predicts the codes $q^{t_i}$ of the full-resolution crops ($i \in \{1, 2\}$), except its own code: $$L(z^{t_1}, \dots, z^{t_{V+2}}) = \sum_{i \in \{1, 2\}} \sum_{v=1, v \neq i}^{V+2} \mathcal{L}(z^{t_v}, q^{t_i})$$
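The swapped multi-crop loss can be sketched as follows, assuming features and prototypes are already $\ell_2$-normalized and the codes come from a step like Sinkhorn-Knopp. The function names and the temperature `tau=0.1` are assumptions; the sketch averages over the $2(V+1)$ terms instead of summing them, which only rescales the loss.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable row-wise log-softmax."""
    m = logits.max(axis=1, keepdims=True)
    return logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))


def swapped_loss(z_views, codes, C, tau=0.1):
    """Multi-crop swapped prediction loss (averaged over terms).

    z_views: list of V+2 arrays of shape (B, D), l2-normalized; the first two
             are the full-resolution crops.
    codes:   list of 2 arrays of shape (B, K), the codes of those two crops,
             treated as fixed targets.
    C:       (D, K) prototype matrix with unit-norm columns.
    """
    total, n_terms = 0.0, 0
    for i, q in enumerate(codes):          # full-resolution crop whose code is the target
        for v, z in enumerate(z_views):
            if v == i:                      # a view never predicts its own code
                continue
            log_p = log_softmax(z @ C / tau)  # log p^{t_v}_k for every sample and prototype
            total += -np.mean(np.sum(q * log_p, axis=1))
            n_terms += 1
    return total / n_terms                  # n_terms == 2 * (V + 1)


def unit_rows(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


rng = np.random.default_rng(0)
B, D, K, n_views = 16, 32, 10, 8            # 2 global + 6 small crops
z_views = [unit_rows(rng.normal(size=(B, D))) for _ in range(n_views)]
C = unit_rows(rng.normal(size=(K, D))).T    # (D, K) with unit-norm columns
uniform_codes = [np.full((B, K), 1.0 / K) for _ in range(2)]
loss = swapped_loss(z_views, uniform_codes, C)
print(loss)
```

With uniform codes every term is at least $\log K$ (by Jensen's inequality), which makes a convenient sanity check.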

5️⃣ Joint Parameter Optimization

The loss function is jointly minimized using stochastic optimization (e.g., SGD with the LARS optimizer):

  1. Encoder Update ($\theta$): The parameters of the image encoder ($f_{\theta}$) are updated via backpropagation to minimize the swapped prediction loss.
  2. Prototype Update ($C$): The prototype vectors ($C$) are also learned by backpropagation and updated jointly with the ConvNet parameters.
  3. Prototype Normalization: After each update, the prototype vectors (the columns of $C$) are $\ell_2$-normalized so that they stay on the unit sphere.
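The prototype update can be sketched with a single plain-SGD step in NumPy, holding the features and codes fixed. This is a simplification: the encoder update is omitted, LARS and learning-rate schedules are not modelled, and the function name `prototype_sgd_step` and its hyperparameters are hypothetical. The gradient is the standard one for softmax cross-entropy, `z.T @ (p - q) / (tau * B)` for row-stacked features, which relies on each row of the codes summing to 1.

```python
import numpy as np


def prototype_sgd_step(C, z, q, lr=0.1, tau=0.1):
    """One hypothetical plain-SGD step on the prototypes, then re-normalization.

    C: (D, K) prototypes with unit-norm columns.
    z: (B, D) unit-norm features; q: (B, K) codes whose rows sum to 1.
    The codes are constants here: no gradient flows through q.
    """
    logits = z @ C / tau
    logits -= logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)              # predicted assignments
    grad_C = z.T @ (p - q) / (tau * z.shape[0])    # d(mean cross-entropy)/dC
    C = C - lr * grad_C
    return C / np.linalg.norm(C, axis=0, keepdims=True)  # back onto the unit sphere


rng = np.random.default_rng(0)
D, K, B = 32, 10, 16
C = rng.normal(size=(D, K))
C /= np.linalg.norm(C, axis=0, keepdims=True)
z = rng.normal(size=(B, D))
z /= np.linalg.norm(z, axis=1, keepdims=True)
q = np.eye(K)[rng.integers(0, K, size=B)]          # hypothetical one-hot codes
C_new = prototype_sgd_step(C, z, q)
print(np.linalg.norm(C_new, axis=0))               # all ones
```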

🎯 Downstream Tasks

  • ImageNet Evaluation Protocols
  • Transfer Learning (Linear Classification on Frozen Features)
  • Object Detection and Instance Segmentation (Finetuning)

💡 Strengths

  • Architectural and Efficiency Advantages
    • Avoids Pairwise Comparisons: SwAV takes advantage of contrastive learning concepts but does not require the computation of explicit pairwise feature comparisons. This simplifies the objective and improves tractability compared to classical contrastive methods.
    • Memory Efficiency: The method is more memory efficient because it does not require a large memory bank (like NPID or MoCo).
    • No Momentum Encoder Required: SwAV does not require a special momentum network (like MoCo).
    • Scalability: SwAV is an online algorithm: features and codes are learned online, which allows the method to scale to potentially unlimited amounts of data.
    • Batch Size Flexibility: SwAV can be trained effectively with both large and small batches. When trained with a small batch size (256), it only needs to store a small queue of features (around 3,840 vectors), compared to 65,536 features required by MoCov2 for good performance.
  • Performance and Training Speed
    • State-of-the-Art Performance: SwAV achieves 75.3% top-1 accuracy on ImageNet with a standard ResNet-50 under the linear evaluation protocol, outperforming the prior state of the art by 4.2%.
    • Outperforms Supervised Pretraining: SwAV is the first self-supervised method reported to surpass supervised ImageNet pretraining on all considered transfer tasks (including linear classification tasks like Places205 and object detection tasks like VOC07+12 and COCO).
    • Faster Training Convergence: SwAV learns much faster than contrastive methods, reaching higher performance in four times fewer epochs than MoCov2 in the small batch setting. SwAV achieves strong performance (72.1% top-1 accuracy) after just 100 epochs (approx. 6 hours 15 minutes).
    • Scales with Architecture Capacity: The performance of SwAV consistently increases with the width and capacity of the model, shrinking the gap with supervised training to 0.6% for large architectures.
  • Multi-Crop Augmentation
    • Effective Data Augmentation Strategy: SwAV proposes the multi-crop strategy which uses a mix of views with different resolutions (e.g., two full-resolution crops and several low-resolution crops).
    • Consistent Performance Boost: Multi-crop consistently improves the performance of SwAV and other self-supervised methods (like SimCLR, DeepCluster, and SeLa) by a significant margin of 2% to 4% top-1 accuracy on ImageNet without substantially increasing memory or computational requirements.
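The small-batch queue mentioned above can be sketched as follows: features from recent batches are kept only to enlarge the Sinkhorn-Knopp problem, and only the current batch's codes are used as targets. At batch size 256, 3,840 stored features correspond to 15 past batches (15 × 256 = 3,840). The `FeatureQueue` class, its FIFO policy, and the compact `sinkhorn` helper are hypothetical illustrations, not the reference implementation.

```python
import numpy as np
from collections import deque


def sinkhorn(scores, eps=0.05, n_iters=3):
    """Compact Sinkhorn-Knopp: (N, K) scores -> (N, K) soft codes with rows summing to 1."""
    Q = np.exp((scores - scores.max()) / eps).T  # (K, N)
    Q /= Q.sum()
    K, N = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(axis=1, keepdims=True) * K    # equal mass per prototype
        Q /= Q.sum(axis=0, keepdims=True) * N    # equal mass per sample
    return (Q * N).T


class FeatureQueue:
    """Hypothetical FIFO of recent feature batches used only to enlarge the Sinkhorn problem."""

    def __init__(self, n_batches):
        self.batches = deque(maxlen=n_batches)   # oldest batch is dropped automatically

    def codes_for(self, z, C, eps=0.05):
        """Codes for the current batch z (B, D); queued features are scored with the current C."""
        all_z = np.concatenate(list(self.batches) + [z], axis=0)  # queued first, current last
        q = sinkhorn(all_z @ C, eps=eps)[-z.shape[0]:]            # keep the current batch only
        self.batches.append(z)                                    # enqueue after use
        return q


rng = np.random.default_rng(0)
D, K, B = 32, 20, 8
C = rng.normal(size=(D, K))
C /= np.linalg.norm(C, axis=0, keepdims=True)
queue = FeatureQueue(n_batches=15)   # mirrors 15 x 256 = 3,840 features at batch size 256
for _ in range(20):
    z = rng.normal(size=(B, D))
    z /= np.linalg.norm(z, axis=1, keepdims=True)
    q = queue.codes_for(z, C)
print(q.shape, len(queue.batches))   # (8, 20) 15
```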

⚠️ Limitations

  • Computational and Speed Constraints
    • Slower Wall Clock Time Per Epoch: Although SwAV converges faster in terms of epochs, one epoch of SwAV is generally slower in wall clock time (1.2× to 1.4× slower) than an epoch of SimCLR or MoCov2, due to the additional back-propagation step and the Sinkhorn algorithm calculation.
    • Increased Computation with More Prototypes: While the number of prototypes has little influence on final accuracy (as long as there are “enough”), using more prototypes increases the computational time needed for both the Sinkhorn algorithm and back-propagation.
  • Dependence on Assignment Details
    • Soft vs. Hard Codes: The default soft assignments (continuous codes) perform better than hard (discrete) assignments; hard assignments converge faster but to a worse solution.
    • Trivial Solution Risk: Like other clustering methods, SwAV must actively prevent the trivial solution where every image has the same code. This is managed by enforcing an equipartition constraint via the Sinkhorn-Knopp algorithm.
    • Sensitivity to Sinkhorn Iterations: If too few iterations are used in the Sinkhorn-Knopp algorithm (e.g., 1 iteration), the loss fails to converge.
    • Sensitivity to Regularization Parameter: A strong entropy regularization (high $\epsilon$) generally leads to a trivial solution where all samples collapse into a unique representation, so $\epsilon$ must be kept low in practice.
    • Code Quality Degradation with Small Crops: When using the multi-crop strategy, codes must be computed only using the full-resolution crops. Computing codes using the low-resolution crops (partial information) degrades the assignment quality and consequently hurts the transfer performance of the resulting network.
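The effect of $\epsilon$ and of hard versus soft codes can be illustrated on random scores with the sketch below. It is a toy, assuming random similarities in $[-1, 1]$ and taking hard codes as the per-sample argmax of the soft codes (one possible discretization); it does not reproduce the training behaviour reported for SwAV. As $\epsilon$ grows, every sample's code approaches the same uniform vector, which is the degenerate case described above.

```python
import numpy as np


def sinkhorn(scores, eps, n_iters=3):
    """Compact Sinkhorn-Knopp: (B, K) scores -> (B, K) soft codes with rows summing to 1."""
    Q = np.exp((scores - scores.max()) / eps).T
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(axis=1, keepdims=True) * K
        Q /= Q.sum(axis=0, keepdims=True) * B
    return (Q * B).T


def hard_codes(soft):
    """Discretize soft codes into one-hot assignments (argmax per sample)."""
    return np.eye(soft.shape[1])[soft.argmax(axis=1)]


def mean_entropy(q):
    """Average entropy of the per-sample codes; log(K) means a uniform code."""
    return float(-np.mean(np.sum(q * np.log(q + 1e-12), axis=1)))


rng = np.random.default_rng(0)
scores = rng.uniform(-1, 1, size=(64, 16))   # stand-in for feature-prototype similarities
for eps in (0.05, 1.0, 100.0):
    soft = sinkhorn(scores, eps)
    print(f"eps={eps}: mean code entropy {mean_entropy(soft):.3f} "
          f"(uniform = {np.log(16):.3f})")
hard = hard_codes(sinkhorn(scores, 0.05))   # discards the soft information entirely
```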
