news 2026/5/22 18:42:23

Reproducing PoolFormer in PyTorch: A Vision Transformer That Replaces Attention with Average Pooling (Full Code Included)


张小明

Front-end Development Engineer


Implementing PoolFormer from Scratch: A Complete Guide to Building an Efficient Vision Transformer with Average Pooling

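In computer vision, Transformer architectures have become a strong alternative to traditional convolutional neural networks. Most vision Transformers, however, rely on self-attention, whose cost grows quadratically with the number of tokens. This article walks through a PyTorch implementation of PoolFormer, which replaces attention with plain average pooling. The pooling token mixer has no parameters and its cost grows only linearly with the number of tokens, yet the model remains competitive on ImageNet.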

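1. Environment Setup and Basic Modules

First make sure PyTorch 1.8+ and torchvision are installed. We recommend creating a virtual environment with conda (fvcore is only needed for the FLOP comparison in Section 5):

```shell
conda create -n poolformer python=3.8
conda activate poolformer
pip install torch torchvision timm fvcore
```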

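The defining component of PoolFormer is its token mixer: where a Transformer uses attention to exchange information between tokens, PoolFormer uses average pooling. Let's implement the basic components first: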

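```python
import torch
import torch.nn as nn

try:
    from timm.layers import DropPath          # newer timm releases
except ImportError:
    from timm.models.layers import DropPath   # older timm releases


class GroupNorm(nn.GroupNorm):
    """GroupNorm with a single group: normalizes each sample over (C, H, W)"""
    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


class Pooling(nn.Module):
    """Parameter-free pooling token mixer that replaces attention"""
    def __init__(self, pool_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False
        )

    def forward(self, x):
        # Subtract the input because the block already adds x back via its residual connection
        return self.pool(x) - x
```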

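Note: the subtraction in Pooling does not create a residual connection. It compensates for the one the block already has. PoolFormerBlock computes x + mixer(norm(x)), so returning pool(x) - x means the pooling branch contributes only the difference between each token's local average and the token itself, instead of adding a second copy of the input.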
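To make the role of the subtraction concrete, here is a pure-Python sketch on a single-channel 3x3 grid. `avg_pool_3x3` and `token_mixer` are illustrative helpers, not part of the article's code. They mimic `nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)`, and normalization and LayerScale are left out. Adding the mixer output back through the residual connection yields exactly the pooled map:

```python
def avg_pool_3x3(grid):
    """3x3 average pooling, stride 1, padding 1, count_include_pad=False:
    each output averages only the neighbors that fall inside the grid."""
    h, w = len(grid), len(grid[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            vals = [grid[a][b]
                    for a in range(max(0, i - 1), min(h, i + 2))
                    for b in range(max(0, j - 1), min(w, j + 2))]
            out[i][j] = sum(vals) / len(vals)
    return out


def token_mixer(grid):
    """Pooling token mixer: pool(x) - x, element by element."""
    pooled = avg_pool_3x3(grid)
    return [[p - v for p, v in zip(prow, xrow)] for prow, xrow in zip(pooled, grid)]


x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]

mixed = token_mixer(x)
# Residual branch without norm/LayerScale: x + (pool(x) - x)
block_out = [[v + m for v, m in zip(xrow, mrow)] for xrow, mrow in zip(x, mixed)]

print(avg_pool_3x3(x)[0][0])        # 3.0
print(block_out == avg_pool_3x3(x))  # True
```

With `count_include_pad=False`, the corner output averages only its four valid neighbors (1, 2, 4, 5 give 3.0) instead of dividing by 9 with zero padding.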

2. Building the Core PoolFormer Block

A complete PoolFormerBlock has three main parts: normalization, the pooling token mixer, and a channel MLP:

```python
class Mlp(nn.Module):
    """Channel-mixing MLP implemented with 1x1 convolutions"""
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class PoolFormerBlock(nn.Module):
    """A complete PoolFormer block: norm -> pooling mixer -> residual, norm -> MLP -> residual"""
    def __init__(self, dim, pool_size=3, mlp_ratio=4., act_layer=nn.GELU,
                 norm_layer=GroupNorm, drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            # Per-channel scales; reshaped to (C, 1, 1) in forward to broadcast over H and W
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```

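Key implementation details:

  • LayerScale: a learnable per-channel scale on each residual branch, initialized to a small value (1e-5 here), which helps deep networks train
  • DropPath: stochastic depth, a regularizer that randomly drops entire residual branches during training
  • 1x1-convolution MLP: equivalent to a position-wise fully connected layer, but it works directly on (B, C, H, W) feature maps without reshaping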
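To see what the 1x1-convolution MLP actually computes, the sketch below compares `nn.Conv2d(kernel_size=1)` with an `nn.Linear` applied at every spatial position. It copies the conv's weights into the linear layer and checks that the outputs match. The tensor sizes are arbitrary choices for the demonstration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C_in, C_out, H, W = 2, 8, 16, 5, 5
x = torch.randn(B, C_in, H, W)

conv = nn.Conv2d(C_in, C_out, kernel_size=1)
linear = nn.Linear(C_in, C_out)
# Conv weight is (C_out, C_in, 1, 1); linear weight is (C_out, C_in)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(C_out, C_in))
    linear.bias.copy_(conv.bias)

out_conv = conv(x)  # B, C_out, H, W
# Move channels last, apply the linear layer per position, move channels back
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```

So `Mlp` could equally be written with `nn.Linear` after permuting to (B, H, W, C). The convolution form simply avoids the two permutes.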

3. Implementing Patch Embedding and Downsampling

A vision Transformer first splits the image into patches. The implementation:

```python
# Note: the PoolFormer class in Section 4 builds its stem and downsampling layers
# inline with strided convolutions; these two modules are standalone building blocks.

def to_2tuple(x):
    """Turn an int into an (x, x) pair; pass tuples through unchanged"""
    return x if isinstance(x, tuple) else (x, x)


class PatchEmbed(nn.Module):
    """Image-to-token embedding layer (non-overlapping patches)"""
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # Keep the B x C x H' x W' layout instead of flattening to B x N x C as ViT does:
        # PoolFormerBlock (pooling, GroupNorm, 1x1 convs) operates on 4D feature maps
        x = self.proj(x)
        x = self.norm(x)
        return x


class PatchMerging(nn.Module):
    """Hierarchical downsampling by merging 2x2 neighborhoods (channels-first version)"""
    def __init__(self, dim, norm_layer=GroupNorm):
        super().__init__()
        self.dim = dim
        self.norm = norm_layer(4 * dim)
        self.reduction = nn.Conv2d(4 * dim, 2 * dim, 1, bias=False)

    def forward(self, x):
        # x: B x C x H x W with even H and W
        x0 = x[:, :, 0::2, 0::2]  # B C H/2 W/2
        x1 = x[:, :, 1::2, 0::2]  # B C H/2 W/2
        x2 = x[:, :, 0::2, 1::2]  # B C H/2 W/2
        x3 = x[:, :, 1::2, 1::2]  # B C H/2 W/2
        x = torch.cat([x0, x1, x2, x3], dim=1)  # B 4C H/2 W/2
        x = self.norm(x)
        x = self.reduction(x)  # B 2C H/2 W/2
        return x
```
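For a non-overlapping patch embedding like `PatchEmbed` (kernel size equal to stride), the convolution is the same as cutting the image into p x p patches, flattening each patch, and applying one shared linear projection. The sketch checks this numerically on a random tensor. The sizes are arbitrary, and the reshape/permute sequence is one way to extract patches, not code from the article:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, p, D = 2, 3, 8, 8, 4, 5
x = torch.randn(B, C, H, W)
conv = nn.Conv2d(C, D, kernel_size=p, stride=p)

# Split into non-overlapping p x p patches: B, C, H/p, p, W/p, p
patches = x.reshape(B, C, H // p, p, W // p, p)
# Reorder to B, H/p, W/p, C, p, p and flatten each patch to a C*p*p vector
patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, H // p, W // p, C * p * p)

# Conv weight is (D, C, p, p); flattening gives the same (C, p, p) order as the patches
weight = conv.weight.reshape(D, C * p * p)
out_linear = patches @ weight.T + conv.bias  # B, H/p, W/p, D
out_conv = conv(x).permute(0, 2, 3, 1)       # B, H/p, W/p, D

print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```

PoolFormer's own stem differs in one respect: its 7x7 kernel with stride 4 makes neighboring patches overlap, so it is no longer a pure per-patch projection.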

4. Assembling the Full PoolFormer Architecture

Now we can assemble the complete PoolFormer model. Here is the S24 variant:

```python
class PoolFormer(nn.Module):
    """Complete PoolFormer implementation"""
    def __init__(self, layers, embed_dims=None, mlp_ratios=None, downsamples=None,
                 num_classes=1000, in_chans=3, norm_layer=GroupNorm, act_layer=nn.GELU,
                 drop_rate=0., drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = embed_dims[-1]

        # Stem: overlapping 7x7 conv with stride 4 (224 -> 56)
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, embed_dims[0], kernel_size=7, stride=4, padding=2),
            norm_layer(embed_dims[0]),
        )

        # Hierarchical stages; the DropPath rate grows linearly with block depth
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]
        cur = 0
        for i in range(len(layers)):
            stage = nn.Sequential(*[
                PoolFormerBlock(
                    dim=embed_dims[i], pool_size=3, mlp_ratio=mlp_ratios[i],
                    act_layer=act_layer, norm_layer=norm_layer,
                    drop=drop_rate, drop_path=dp_rates[cur + j],
                    use_layer_scale=use_layer_scale,
                    layer_scale_init_value=layer_scale_init_value)
                for j in range(layers[i])
            ])
            self.stages.append(stage)
            cur += layers[i]
            # Downsample between stages: a 3x3 conv with stride 2 halves H and W
            if i < len(layers) - 1 and downsamples[i]:
                self.stages.append(nn.Sequential(
                    norm_layer(embed_dims[i]),
                    nn.Conv2d(embed_dims[i], embed_dims[i + 1],
                              kernel_size=3, stride=2, padding=1),
                ))

        # Classification head: norm -> global average pooling -> linear
        self.head = nn.Sequential(
            norm_layer(self.num_features),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity(),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.GroupNorm):  # the model uses GroupNorm, not LayerNorm
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def poolformer_s24(pretrained=False, **kwargs):
    """PoolFormer-S24, about 21M parameters"""
    if pretrained:
        # This from-scratch version has no matching checkpoint; for ImageNet weights use
        # timm.create_model("poolformer_s24", pretrained=True)
        raise NotImplementedError("pretrained weights are not provided for this implementation")
    layers = [4, 4, 12, 4]
    embed_dims = [64, 128, 320, 512]
    mlp_ratios = [4, 4, 4, 4]
    downsamples = [True, True, True, True]
    model = PoolFormer(layers, embed_dims=embed_dims, mlp_ratios=mlp_ratios,
                       downsamples=downsamples, **kwargs)
    return model
```
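The DropPath rates in `PoolFormer.__init__` come from `torch.linspace(0, drop_path_rate, sum(layers))`, so deeper blocks are dropped more often. Below is a pure-Python sketch of that schedule for the S24 layout. `drop_path_schedule` is a hypothetical helper, and `drop_path_rate=0.1` is just an example value:

```python
def drop_path_schedule(layers, drop_path_rate):
    """Linearly increase the DropPath rate from 0 to drop_path_rate over all blocks
    (like torch.linspace(0, drop_path_rate, sum(layers))) and split it per stage."""
    total = sum(layers)
    if total > 1:
        rates = [drop_path_rate * i / (total - 1) for i in range(total)]
    else:
        rates = [0.0]  # linspace with a single step returns only the start value
    per_stage, cur = [], 0
    for n in layers:
        per_stage.append(rates[cur:cur + n])
        cur += n
    return per_stage


stages = drop_path_schedule([4, 4, 12, 4], drop_path_rate=0.1)
for i, rates in enumerate(stages):
    print(f"stage {i + 1}: {len(rates)} blocks, rate {rates[0]:.4f} -> {rates[-1]:.4f}")
```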

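Architecture highlights:

  • Four-stage hierarchy: progressive downsampling similar to ResNet (output strides 4, 8, 16, 32)
  • Channels grow stage by stage: 64 → 128 → 320 → 512
  • Uneven depth: most blocks sit in the third stage (12 of 24)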
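The output strides follow from the convolution output-size formula floor((size + 2 * padding - kernel) / stride) + 1. This small pure-Python sketch applies it to the stem (kernel 7, stride 4, padding 2) and the three downsampling convolutions (kernel 3, stride 2, padding 1) used in the `PoolFormer` class. `conv_out` and `stage_resolutions` are illustrative helpers:

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_resolutions(img_size=224):
    """Spatial size entering each of the four stages."""
    sizes = [conv_out(img_size, 7, 4, 2)]      # stem
    for _ in range(3):                         # downsampling between stages
        sizes.append(conv_out(sizes[-1], 3, 2, 1))
    return sizes


embed_dims = [64, 128, 320, 512]
for i, (s, c) in enumerate(zip(stage_resolutions(224), embed_dims)):
    print(f"stage {i + 1}: {c} x {s} x {s}")
# stage 1: 64 x 56 x 56 ... stage 4: 512 x 7 x 7
```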

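5. Training Tips and Practical Use

The paper reports 80.3% ImageNet top-1 accuracy for PoolFormer-S24 (the 82.1% figure belongs to the larger PoolFormer-M36). Reaching numbers like these depends on the following training details: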

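```python
import math

from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


def get_optimizer(model, lr=5e-4, weight_decay=0.05):
    # Exclude biases and all 1-D parameters (GroupNorm weights, LayerScale) from weight decay;
    # the model has no LayerNorm, so matching "LayerNorm.weight" would exclude nothing
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith(".bias"):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    return AdamW(optimizer_grouped_parameters, lr=lr)


def get_scheduler(optimizer, epochs=300, warmup_epochs=20, min_lr_ratio=2e-3):
    # Linear warmup, then cosine decay down to min_lr_ratio * base LR
    # (2e-3 = 1e-6 / 5e-4). Call scheduler.step() once per epoch.
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, epochs - warmup_epochs)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)
```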

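Key training hyperparameters:

| Parameter | Recommended value | Notes |
|---|---|---|
| Batch size | 1024 | The per-step batch can be smaller when using gradient accumulation |
| Base learning rate | 5e-4 | With linear warmup |
| Weight decay | 0.05 | Set per parameter group (biases and normalization weights excluded) |
| Data augmentation | RandAugment | Magnitude 9, Mixup 0.8 |
| Label smoothing | 0.1 | Improves generalization |
| DropPath rate | 0.1 | Stochastic depth regularization |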
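Label smoothing replaces the one-hot target with (1 - eps) * one_hot + eps / K for K classes. PyTorch exposes this as the `label_smoothing` argument of `nn.CrossEntropyLoss`, which was added in PyTorch 1.10 (newer than the 1.8 minimum mentioned at the start). The sketch checks the built-in loss against the explicit formula on random logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, eps = 5, 0.1
logits = torch.randn(4, num_classes)
labels = torch.tensor([0, 3, 1, 4])

# Built-in label smoothing
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, labels)

# Explicit version: soft target = (1 - eps) * one_hot + eps / num_classes
soft_targets = (1 - eps) * F.one_hot(labels, num_classes).float() + eps / num_classes
manual = -(soft_targets * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

print(torch.allclose(builtin, manual, atol=1e-5))  # True
```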

For deployment, PoolFormer's main advantage over a standard vision Transformer is its much lower compute cost:

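```python
# Compare compute cost (requires: pip install fvcore)
import timm
import torch
from fvcore.nn import FlopCountAnalysis


def calculate_flops(model, input_size=(1, 3, 224, 224)):
    # fvcore counts one fused multiply-add as one flop, so the total is effectively in MACs
    model.eval()
    with torch.no_grad():
        flops = FlopCountAnalysis(model, torch.randn(input_size))
        return flops.total()


vit_flops = calculate_flops(timm.create_model("vit_base_patch16_224"))  # about 17.6G
poolformer_flops = calculate_flops(poolformer_s24())  # the paper reports about 3.4G MACs for S24
```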
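Much of the gap comes from the token mixer. Here is a back-of-the-envelope sketch for one mixer layer at ViT-Base size (196 tokens, 768 channels). Self-attention needs the Q/K/V and output projections (4NC²) plus the two N x N matrix products (2N²C). A 3x3 average pooling needs roughly k² operations per element and has no weights. The counts ignore the class token, softmax, and the MLP that both models share, and the helper names are hypothetical:

```python
def attention_mixer_macs(n_tokens, dim):
    """MACs of one self-attention token mixer: 4*N*C^2 for the projections
    plus 2*N^2*C for QK^T and attention @ V."""
    return 4 * n_tokens * dim * dim + 2 * n_tokens * n_tokens * dim


def pooling_mixer_ops(n_tokens, dim, pool_size=3):
    """Approximate operations of pool_size x pool_size average pooling:
    about pool_size^2 per element, linear in the number of tokens."""
    return n_tokens * dim * pool_size * pool_size


n, c = 14 * 14, 768  # 224x224 image with 16x16 patches, ViT-Base width
attn = attention_mixer_macs(n, c)
pool = pooling_mixer_ops(n, c)
print(f"attention: {attn / 1e6:.1f}M MACs, pooling: {pool / 1e6:.2f}M ops, "
      f"ratio ~{attn / pool:.0f}x")
```

Doubling the number of tokens doubles the pooling cost but more than doubles the attention cost, because of the N² term.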

Fine-tuning PoolFormer on a custom dataset:

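```python
import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

device = "cuda" if torch.cuda.is_available() else "cpu"

# Prepare the dataset; crop to a fixed 224x224 so images can be stacked into batches
transform = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = ImageFolder("path/to/your/data", transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Load ImageNet-pretrained weights from timm (the from-scratch poolformer_s24 above
# has no matching checkpoint) and size the classifier for our classes
model = timm.create_model("poolformer_s24", pretrained=True,
                          num_classes=len(dataset.classes)).to(device)

# Fine-tuning loop
criterion = nn.CrossEntropyLoss()
optimizer = get_optimizer(model, lr=1e-4)
scheduler = get_scheduler(optimizer, epochs=50, warmup_epochs=5)
for epoch in range(50):
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()  # one scheduler step per epoch
```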

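Common problems and fixes:

  1. Unstable training: lower the learning rate or lengthen the warmup
  2. Overfitting: use stronger data augmentation or a higher DropPath rate
  3. Out of GPU memory: use gradient accumulation or mixed-precision training
  4. Slow convergence: check that the learning-rate scheduler is stepping as intended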
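Gradient accumulation reaches a large effective batch by summing gradients over several smaller micro-batches before each optimizer step. The toy sketch below uses a small `nn.Linear` as a stand-in for the model. It checks that four micro-batches of 8, each with the loss divided by the number of steps, give the same gradient as one batch of 32:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()  # mean reduction over the batch
x = torch.randn(32, 10)
y = torch.randint(0, 3, (32,))

# Reference: one backward pass on the full batch of 32
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 8, each loss scaled by 1 / accum_steps
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = criterion(model(xb), yb) / accum_steps
    loss.backward()  # gradients add up in .grad
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))  # True
```

In a real loop you would call `optimizer.step()` and `optimizer.zero_grad()` once every `accum_steps` micro-batches. The match (up to floating-point rounding) relies on equal-sized micro-batches. PoolFormer's GroupNorm normalizes each sample independently, so the micro-batch size does not change the normalization, unlike BatchNorm.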