news 2026/5/4 18:12:30

如何实现最先进的屏蔽自动编码器(MAE)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
如何实现最先进的屏蔽自动编码器(MAE)

原文:towardsdatascience.com/how-to-implement-state-of-the-art-masked-autoencoders-mae-6f454b736087

嗨,大家好!对于那些还不认识我的人来说,我叫弗朗索瓦,我是 Meta 的研究科学家。我对解释高级人工智能概念并使它们更容易理解充满热情。

今天,我很兴奋地深入探讨计算机视觉在视觉变换器之后的最重要的突破之一:屏蔽自动编码器(MAE)。本文是我之前文章的实践实现伴侣:屏蔽自动编码器(MAE)的终极指南

对于下面的教程,我们将使用这个 GitHub 仓库上的代码:

GitHub – FrancoisPorcher/awesome-ai-tutorials: 使你成为 AI 高手的最佳 AI 教程集合

这里是一个简要的提醒,说明它是如何工作的:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e847e324b22743184534273b00622cb2.png

图像来自文章 MAE are Scalable Learners

这是该方法的工作原理:

  1. 图像被分割成块。

  2. 这些块中的一小部分被随机屏蔽。

  3. 只有可见的块被输入到编码器中(这是至关重要的)。

  4. 解码器接收来自编码器的压缩表示,并尝试使用可见和屏蔽的块重建整个图像。

  5. 损失仅在屏蔽的块上计算。

让我们深入代码!

导入

  • einops:用于其“重复”功能

  • architectures.vit:标准 ViT 变换器的架构,我使用的是在 如何训练一个 ViT? 中提供的版本。

importtorchfromtorchimportnnimporttorch.nn.functionalasFfromeinopsimportrepeatfromarchitectures.vitimportTransformer

设置 MAE 类:

classMAE(nn.Module):def__init__(self,*,encoder,decoder_dim,masking_ratio=0.75,decoder_depth=1,decoder_heads=8,decoder_dim_head=64):super().__init__()# Ensure the masking ratio is validassert0<masking_ratio<1,'masking ratio must be between 0 and 1'self.masking_ratio=masking_ratio

我们定义了一个从 PyTorch 的nn.Module继承的类 MAE。

  • encoder:我们的视觉变换器模型。

  • decoder_dim:解码器嵌入空间的维度(例如 512)。

  • masking_ratio:要屏蔽的块的比例(文章发现 75%是最优的)。

  • 其他解码器配置,如depthheadsheaddimensions,这些是变换器的标准。

  • 我们断言屏蔽率在 0 和 1 之间。

块:

# Save the encoder (a Vision Transformer to be trained)self.encoder=encoder# Extract the number of patches and the encoder's dimensionality from the positional embeddingsnum_patches,encoder_dim=encoder.pos_embedding.shape[-2:]# Separate the patch embedding layers from the encoder# The first layer converts the image into patchesself.to_patch=encoder.to_patch_embedding[0]# The remaining layers embed the patchesself.patch_to_emb=nn.Sequential(*encoder.to_patch_embedding[1:])

发生了什么?

• 我们存储编码器并提取关键信息,如块的数量编码器的输出维度

我们将块嵌入过程分开:

  • self.to_patch:这一层将图像分割成更小的块。

  • self.patch_to_emb:这将每个块嵌入到向量空间中。

# Determine the dimensionality of the pixel values per patchpixel_values_per_patch=encoder.to_patch_embedding[2].weight.shape[-1]
  • 我们计算每个块中有多少像素值,我们稍后会需要这些信息。

设置解码器

  • self.enc_to_dec: 如果编码器和解码器具有不同的维度,我们将它们相应地映射。通常编码器较大,维度较高(例如 1024),而解码器可以更浅,维度更小(例如 512),但我们需要一个适配器来将编码器的维度映射到解码器的维度。

  • self.mask_token:一个可学习的标记,代表解码器中的掩码补丁。这是解码器在补丁被掩码时看到的标记。

  • 我们初始化解码器变压器和其他用于重建所需的层。

self.decoder=Transformer(dim=decoder_dim,depth=decoder_depth,heads=decoder_heads,dim_head=decoder_dim_head,mlp_dim_ratio=4)# Positional embeddings for the decoder tokensself.decoder_pos_emb=nn.Embedding(num_patches,decoder_dim)# Linear layer to reconstruct pixel values from decoder outputsself.to_pixels=nn.Linear(decoder_dim,pixel_values_per_patch)

到目前为止,你的 MAE 类应该初始化如下:

classMAE(nn.Module):def__init__(self,*,encoder,decoder_dim,masking_ratio=0.75,decoder_depth=1,decoder_heads=8,decoder_dim_head=64):super().__init__()# Ensure the masking ratio is validassert0<masking_ratio<1,'masking ratio must be between 0 and 1'self.masking_ratio=masking_ratio# Save the encoder (a Vision Transformer to be trained)self.encoder=encoder# Extract the number of patches and the encoder's dimensionality from the positional embeddingsnum_patches,encoder_dim=encoder.pos_embedding.shape[-2:]# Separate the patch embedding layers from the encoder# The first layer converts the image into patchesself.to_patch=encoder.to_patch_embedding[0]# The remaining layers embed the patchesself.patch_to_emb=nn.Sequential(*encoder.to_patch_embedding[1:])# Determine the dimensionality of the pixel values per patchpixel_values_per_patch=encoder.to_patch_embedding[2].weight.shape[-1]# Set up decoder parametersself.decoder_dim=decoder_dim# Map encoder dimensions to decoder dimensions if they differself.enc_to_dec=(nn.Linear(encoder_dim,decoder_dim)ifencoder_dim!=decoder_dimelsenn.Identity())# Learnable mask token for masked patchesself.mask_token=nn.Parameter(torch.randn(decoder_dim))# Define the decoder transformerself.decoder=Transformer(dim=decoder_dim,depth=decoder_depth,heads=decoder_heads,dim_head=decoder_dim_head,mlp_dim_ratio=4)# Positional embeddings for the decoder tokensself.decoder_pos_emb=nn.Embedding(num_patches,decoder_dim)# Linear layer to reconstruct pixel values from decoder outputsself.to_pixels=nn.Linear(decoder_dim,pixel_values_per_patch)

太好了!现在让我们看看如何在正向传递中使用这些不同的部分,这有点像拼图。

前向传递

让我们逐步分析正向函数,它定义了我们的模型如何处理输入数据。

defforward(self,img):device=img.device# Convert the input image into patchespatches=self.to_patch(img)# Shape: (batch_size, num_patches, patch_size)batch_size,num_patches,*_=patches.shape# Embed the patches using the encoder's patch embedding layerstokens=self.patch_to_emb(patches)# Shape: (batch_size, num_patches, encoder_dim)

开始部分非常标准,我们只需要分解“图像补丁化”操作与“投影到标记”操作,因为我们使用原始补丁作为计算损失的真实值。

  • 前向方法接收一个图像张量img作为输入。

  • 我们获取张量所在设备(CPU 或 GPU)。

  • 我们将图像分割成补丁。

  • 我们获取batch sizenumber of patches

  • 每个补丁被嵌入到一个向量中。

位置编码:

# Add positional embeddings to the tokensifself.encoder.pool=="cls":# If using CLS token, skip the first positional embeddingtokens+=self.encoder.pos_embedding[:,1:num_patches+1]elifself.encoder.pool=="mean":# If using mean pooling, use all positional embeddingstokens+=self.encoder.pos_embedding.to(device,dtype=tokens.dtype)
  • 我们为每个标记添加位置信息,以便模型知道每个补丁的来源。如果有额外的CLS标记,我们需要跳过它,因为它不是图像的一部分。

掩码和编码

现在我们来到了最有趣的部分,对图像进行掩码。

# Determine the number of patches to masknum_masked=int(self.masking_ratio*num_patches)# Generate random indices for maskingrand_indices=torch.rand(batch_size,num_patches,device=device).argsort(dim=-1)masked_indices=rand_indices[:,:num_masked]unmasked_indices=rand_indices[:,num_masked:]
  • 我们根据掩码比例计算我们将要掩码的补丁数量。

  • 我们为每个补丁序列生成一个随机排列。

  • 我们相应地定义masked_indicesunmasked_indices

# Select the tokens corresponding to unmasked patchesbatch_range=torch.arange(batch_size,device=device)[:,None]tokens=tokens[batch_range,unmasked_indices]# Select the original patches that are masked (for reconstruction loss)masked_patches=patches[batch_range,masked_indices]# Encode the unmasked tokens using the encoder's transformerencoded_tokens=self.encoder.transformer(tokens)
  • 我们选择具有相应masked_indicesmasked_patches

  • 我们只为未掩码的补丁保留标记进行编码。

解码

现在让我们跳到最激动人心但也是最困难的部分,解码!

# Map encoded tokens to decoder dimensions if necessarydecoder_tokens=self.enc_to_dec(encoded_tokens)# Add positional embeddings to the decoder tokens of unmasked patchesunmasked_decoder_tokens=decoder_tokens+self.decoder_pos_emb(unmasked_indices)# Create mask tokens for the masked patches and add positional embeddingsmask_tokens=repeat(self.mask_token,'d -> b n d',b=batch_size,n=num_masked)mask_tokens=mask_tokens+self.decoder_pos_emb(masked_indices)# Initialize the full sequence of decoder tokensdecoder_sequence=torch.zeros(batch_size,num_patches,self.decoder_dim,device=device)# Place unmasked decoder tokens and mask tokens in their original positionsdecoder_sequence[batch_range,unmasked_indices]=unmasked_decoder_tokens decoder_sequence[batch_range,masked_indices]=mask_tokens# Decode the full sequencedecoded_tokens=self.decoder(decoder_sequence)# Extract the decoded tokens corresponding to the masked patchesmasked_decoded_tokens=decoded_tokens[batch_range,masked_indices]
  • 我们通过self.enc_to_dec调整编码标记以匹配解码器期望的输入大小。

  • 我们为解码器标记添加位置嵌入。

  • 对于掩码位置,我们使用mask token并添加位置``嵌入

  • 我们通过将未掩码和掩码标记放回其原始位置来重建整个序列。

  • 我们将整个序列通过解码器传递。

  • 我们只提取与掩码补丁对应的解码标记。

# Reconstruct the pixel values from the masked decoded tokenspred_pixel_values=self.to_pixels(masked_decoded_tokens)# Compute the reconstruction loss (mean squared error)recon_loss=F.mse_loss(pred_pixel_values,masked_patches)returnrecon_loss
  • 我们尝试重建掩码补丁的原始像素值。

  • 我们通过比较重建的补丁与原始的掩码补丁来计算L2 损失

恭喜,你已经做到了!

感谢阅读!在你离开之前:

想要更多精彩的教程,请查看我在 Github 上的AI 教程汇编。

GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…

Y您应该会收到我的文章。在此订阅

如果您想获取 Medium 上的优质文章,只需每月支付 5 美元的会员费。如果您通过我的链接*注册**,您只需支付部分费用,无需额外成本即可支持我。*


如果您觉得这篇文章有见地且有益,请考虑关注我并为我点赞,以获取更多深入内容!您的支持帮助我继续创作有助于我们共同理解的内容。

参考文献

  • Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross B. Girshick.掩码自编码器是可扩展的视觉学习者.arXiv:2111.06377, 2021.arxiv.org/abs/2111.06377

  • github.com/lucidrains/vit-pytorch

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/1 6:13:36

Dify可视化开发模式对传统编码方式的颠覆

Dify可视化开发模式对传统编码方式的颠覆 在企业AI应用落地仍被高昂成本和复杂流程困扰的今天&#xff0c;一个市场部专员能否不写一行代码就上线一套智能客服系统&#xff1f;答案是肯定的——借助Dify这类新型开发平台&#xff0c;这已不再是设想。 过去构建一个基于大语言模…

作者头像 李华
网站建设 2026/5/2 5:05:15

adb卸载手机app

一、下载adb Windows版本&#xff1a;https://dl.google.com/android/repository/platform-tools-latest-windows.zip Mac版本&#xff1a;https://dl.google.com/android/repository/platform-tools-latest-windows.zip Linux版本&#xff1a;https://dl.google.com/android/r…

作者头像 李华
网站建设 2026/5/2 11:44:27

基于STM32的智能花盆系统设计与实现

基于STM32的智能花盆系统设计与实现 摘要 本文设计并实现了一种基于STM32F407VET6微控制器的智能花盆系统&#xff0c;通过多传感器融合与智能控制策略&#xff0c;实现了对植物生长环境的全方位监测与精准调控。系统集成DS18B20温度传感器、土壤湿度传感器、光敏电阻、雨滴传…

作者头像 李华
网站建设 2026/5/1 4:06:56

2025年华南师范大学计算机考研复试机试真题(附 AC 代码 + 解题思路)

2025年华南师范大学计算机考研复试机试真题 2025年华南师范大学计算机考研复试上机真题 历年华南师范大学计算机考研复试上机真题 历年华南师范大学计算机考研复试机试真题 更多学校题目开源地址&#xff1a;https://gitcode.com/verticallimit1/noobdream N 诺 DreamJudg…

作者头像 李华
网站建设 2026/5/3 12:59:41

leetcode热题岛屿数量

给你一个由 1&#xff08;陆地&#xff09;和 0&#xff08;水&#xff09;组成的的二维网格&#xff0c;请你计算网格中岛屿的数量。岛屿总是被水包围&#xff0c;并且每座岛屿只能由水平方向和/或竖直方向上相邻的陆地连接形成。此外&#xff0c;你可以假设该网格的四条边均被…

作者头像 李华