The Role of Masking for Efficient Supervised Knowledge Distillation of Vision Transformers

Pohang University of Science and Technology (POSTECH), South Korea.
ECCV 2024

Our method, MaskedKD, reduces supervision cost by masking teacher ViT input based on student attention, maintaining student accuracy while saving computation.

Abstract

Knowledge distillation is an effective method for training lightweight vision models. However, acquiring teacher supervision for training samples is often costly, especially from large-scale models like vision transformers (ViTs). In this paper, we develop a simple framework to reduce the supervision cost of ViT distillation: masking out a fraction of input tokens given to the teacher. By masking input tokens, one can skip the computations associated with the masked tokens without requiring any change to teacher parameters or architecture. We find that masking patches with the lowest student attention scores is highly effective, saving up to 50% of teacher FLOPs without any drop in student accuracy, while other masking criteria lead to suboptimal efficiency gains. Through in-depth analyses, we reveal that student-guided masking provides a good curriculum to the student, making teacher supervision easier to follow during the early stage of training and more challenging in the later stage.

Framework


MaskedKD (Masked Knowledge Distillation) is a very simple yet effective approach for reducing supervision cost by masking the teacher's input tokens. Given a training image-label sample $(x, y)$, we generate a masked version $x_{\mathtt{mask}}$ of the input image by removing some patches of the image $x$. We then train the student with the following loss: \begin{equation} \ell(x,y) = \ell_{\mathtt{CE}}(f_S(x),y) + \lambda \cdot \ell_{\mathtt{KD}}(f_S(x),f_T(x_{\mathtt{mask}})),\label{eq:maskedkd} \end{equation} where $\ell_{\mathtt{CE}}$ denotes the cross-entropy loss, $\ell_{\mathtt{KD}}$ denotes the distillation loss, and $\lambda \geq 0$ is a balancing hyperparameter.
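
Below is a minimal PyTorch-style sketch of this objective with a temperature-scaled KL distillation term. The `student` and `teacher` callables and the `teacher(x, keep_idx=...)` interface (a ViT forward pass that processes only the kept patch indices) are illustrative assumptions, not a reference implementation.

import torch
import torch.nn.functional as F

def maskedkd_loss(student, teacher, x, y, keep_idx, lam=1.0, tau=1.0):
    """MaskedKD objective: cross-entropy on the full student input plus a
    distillation term computed from the teacher's masked input.
    `keep_idx` holds the indices of the unmasked (kept) patches."""
    student_logits = student(x)  # the student sees the full image
    with torch.no_grad():
        # Assumed interface: the teacher only processes the kept patches.
        teacher_logits = teacher(x, keep_idx=keep_idx)

    ce = F.cross_entropy(student_logits, y)
    # Soft-label distillation with temperature tau.
    kd = F.kl_div(
        F.log_softmax(student_logits / tau, dim=-1),
        F.softmax(teacher_logits / tau, dim=-1),
        reduction="batchmean",
    ) * tau ** 2
    return ce + lam * kd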

Student-guided saliency score. To select the patches to be masked, we use the last-layer attention scores of the student ViT as the patch saliency score: \begin{align} \mathbf{a}^{(h)} = \mathrm{Softmax}\left( \big(q_{\mathtt{cls}}^{(h)\top} k_1^{(h)},\:q_{\mathtt{cls}}^{(h)\top} k_2^{(h)},\:\cdots,\:q_{\mathtt{cls}}^{(h)\top} k_N^{(h)}\big) / \sqrt{d}\right), \quad h \in \{1,2,\ldots,H\}, \end{align} where $q_{\mathtt{cls}}^{(h)}$ is the query vector of the class token in head $h$, $k_{i}^{(h)}$ is the key vector of the $i$-th image patch token, $d$ is the dimension of the query and key vectors, and $H$ is the number of attention heads. The final patch saliency score is the average over heads, $\bar{\mathbf{a}} = (\sum_{h=1}^H \mathbf{a}^{(h)})/H$; we mask the patches with the lowest scores and feed only the remaining patches to the teacher.
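
A sketch of this selection step, assuming the student's final-block attention probabilities are available as a tensor of shape (batch, heads, N+1, N+1) with the class token at index 0; the `keep_ratio` argument is an illustrative placeholder.

import torch

def select_patches(attn_last, keep_ratio=0.5):
    """Pick the patches to keep for the teacher from the student's
    last-layer attention.

    attn_last: (B, H, N+1, N+1) attention probabilities of the student's
               final block, with the class token at index 0.
    Returns the indices of the keep_ratio * N most salient patches.
    """
    # Class-token attention over image patches, averaged across heads: (B, N)
    cls_to_patches = attn_last[:, :, 0, 1:].mean(dim=1)
    num_keep = int(cls_to_patches.shape[1] * keep_ratio)
    keep_idx = cls_to_patches.topk(num_keep, dim=1).indices
    return keep_idx

The returned keep_idx can then be passed to the teacher's forward call in the loss sketch above, so that the teacher skips the computation for the masked (lowest-saliency) patches.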

Main Results


MaskedKD reduces the supervision cost by 25-50% without degrading the student accuracy. We also observe that, in most cases, masking a small fraction of patches even improves the performance of the trained student.

A Closer Look at MaskedKD

To answer why the proposed student-guided masking works well, we conduct an in-depth comparative analysis. Our observations can be summarized as follows:

  • Student-guided masking provides a good curriculum for distillation, enhancing the student training.
  • Masking tokens at the input preserves the supervision quality better than gradually removing tokens in the intermediate layers.
  • In supervised knowledge distillation, masking the teacher model is beneficial, as masking the student significantly degrades accuracy, contrary to standard mask-based SSL practices.


Compared with various masking mechanisms, student-guided masking (MaskedKD) is the only approach that matches or improves upon traditional logit distillation; it acts as an implicit curriculum for distillation. During the early stage of training, masking makes the teacher less accurate, making its predictions easier for the student to mimic.


Left & Middle. Masking tokens at the input (MaskedKD) is effective for preserving supervision quality, while removing tokens at intermediate layers (ToMe) is better at keeping the teacher's prediction quality high.

Right. Even at a low masking ratio, masking the student input degrades the final accuracy after distillation, while masking the teacher input does not. This contrasts with mask-based SSL, where masking the student input is essential.

Acknowledgements

This work was partly supported by the National Research Foundation of Korea (NRF) grant funded by the Korean government (MSIT) (RS2023-00213710, RS2023-00210466), and the Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korean government (MSIT) (RS-2019-II191906, Artificial Intelligence Graduate School Program (POSTECH), RS-2022-II220959, Few-Shot learning of Causal Inference in Vision and Language for Decision Making), and POSCO Creative Ideas grant (2023Q024, 2023Q032).

BibTeX

@inproceedings{son2023maskedkd,
  title={The Role of Masking for Efficient Supervised Knowledge Distillation of Vision Transformers},
  author={Son, Seungwoo and Ryu, Jegwang and Lee, Namhoon and Lee, Jaeho},
  booktitle={European Conference on Computer Vision},
  year={2024}
}