distill
【免费下载链接】amctAMCT是CANN提供的昇腾AI处理器亲和的模型压缩工具仓。项目地址: https://gitcode.com/cann/amct
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
蒸馏接口,将输入的待蒸馏的图结构按照给定的蒸馏量化配置文件进行蒸馏处理,返回修改后的torch.nn.Module蒸馏模型。
函数原型
distill_model = distill(model, compress_model, config_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=None, optimizer=None)参数说明
返回值说明
修改后的torch.nn.Module蒸馏模型。
调用示例
import amct_pytorch as amct # 建立待进行蒸馏量化的网络图结构 model = build_model() model.load_state_dict(torch.load(state_dict_path)) compress_model = compress(model) input_data = tuple([torch.randn(input_shape)]) train_loader = torch.utils.data.DataLoader(input_data) loss = torch.nn.MSELoss() optimizer = torch.optim.AdamW(compress_model.parameters(), lr=0.1) # 蒸馏 distill_model = amct.distill( model, compress_model config_json_file, train_loader, epochs=1, lr=1e-3, sample_instance=None, loss=loss, optimizer=optimizer)【免费下载链接】amctAMCT是CANN提供的昇腾AI处理器亲和的模型压缩工具仓。项目地址: https://gitcode.com/cann/amct
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考