This article shows how to apply knowledge distillation to a DeiT model through an external distillation pipeline. It first recalls the two distillation objectives proposed for DeiT, then walks through an implementation of the second one: distilling a DeiT student from a convolutional teacher network.
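For reference, the two objectives come from the DeiT paper (Touvron et al., 2021). With student logits $Z_s$, teacher logits $Z_t$, softmax $\psi$, temperature $\tau$, ground-truth label $y$, and weight $\lambda$, soft distillation minimizes

$$
\mathcal{L} = (1-\lambda)\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y\big) + \lambda\,\tau^{2}\,\mathrm{KL}\big(\psi(Z_s/\tau)\,\big\|\,\psi(Z_t/\tau)\big),
$$

while hard distillation replaces the KL term with cross-entropy against the teacher's hard decision $y_t = \operatorname{argmax}_c Z_t(c)$:

$$
\mathcal{L}^{\mathrm{hard}} = \tfrac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y\big) + \tfrac{1}{2}\,\mathcal{L}_{\mathrm{CE}}\big(\psi(Z_s),\, y_t\big).
$$

The `alpha` parameter in the loss implementation below generalizes the paper's fixed $\tfrac{1}{2}$ weighting.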
```python
# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


# Register the model with timm so it can be created by name
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
```
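The training scripts below import `deit_tiny_distilled_patch16_224` from `models.models`, but the post never shows it. Here is a minimal sketch of the distilled variant, following the structure of the official DeiT repository (which pins an older timm); the extra `dist_token`, the extended positional embedding, and the second head are the essential pieces:

```python
class DistilledVisionTransformer(VisionTransformer):
    """ViT with an extra distillation token and a second classifier head,
    mirroring the official DeiT implementation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Learnable distillation token, analogous to the class token
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        # Positional embedding now covers patches + cls token + dist token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        # Second head, fed by the distillation token
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
            if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            # Training: return both heads so the KD loss can supervise the dist token
            return x, x_dist
        # Inference: average the two classifiers
        return (x + x_dist) / 2


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # Checkpoint URL as listed in the DeiT repo
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model
```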
```python
# Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """This module wraps a standard criterion and adds an extra knowledge
    distillation loss by taking a teacher model prediction and using it as
    additional supervision."""

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # KL divergence between the softened student and teacher distributions,
            # scaled by T^2 to keep gradient magnitudes comparable across temperatures
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
        elif self.distillation_type == 'hard':
            # Cross-entropy against the teacher's hard labels (argmax of its logits)
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss
```
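A minimal usage sketch of `DistillationLoss`; the `Toy` module, the class count, and the batch shapes are illustrative stand-ins, not from the article:

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    """Stand-in backbone: flattens the image and applies a linear head."""
    def __init__(self, num_classes=12):
        super().__init__()
        self.head = nn.Linear(3 * 224 * 224, num_classes)

    def forward(self, x):
        return self.head(x.flatten(1))


teacher, student = Toy(), Toy()
criterion = DistillationLoss(nn.CrossEntropyLoss(), teacher,
                             distillation_type='hard', alpha=0.5, tau=1.0)

images = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 12, (2,))
# A distilled DeiT returns (cls_logits, dist_logits) in training mode;
# emulate that tuple shape here with the same logits twice
logits = student(images)
loss = criterion(images, (logits, logits), labels)
loss.backward()  # gradients reach the student only; the teacher runs under no_grad
```

With `distillation_type='soft'`, `tau` controls how much the teacher's distribution is smoothed before the KL term is computed; for `'hard'` it is unused.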
```python
# Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from timm.models import regnetx_160
import json
import os


# Fix the random seeds for reproducibility
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training loop for one epoch
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Validation loop
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out.data, 1)
        correct += torch.sum(pred == target)
        print_loss = loss.data.item()
        test_loss += print_loss
    correct = correct.data.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    # Keep the checkpoint with the best validation accuracy
    if acc > Best_ACC:
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))
    return acc
```
```python
if __name__ == '__main__':
    # Create the folder that stores the teacher checkpoints
    file_dir = 'TeacherModel'
    os.makedirs(file_dir, exist_ok=True)

    # Global hyperparameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
```
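The post truncates the remainder of this `__main__` block after the seed is set. A minimal, hedged reconstruction follows; the dataset paths (`data/train`, `data/val`), the 12-class head, and the plain resize/normalize transform are assumptions, not details from the post:

```python
    # Hedged reconstruction: paths, class count, and transforms are assumptions
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('data/train', transform=transform),
        batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('data/val', transform=transform),
        batch_size=BATCH_SIZE, shuffle=False)

    # RegNetX-160 teacher; reset_classifier is timm's standard way to resize the head
    criterion = nn.CrossEntropyLoss()
    model = regnetx_160(pretrained=True)
    model.reset_classifier(12)  # number of classes is an assumption
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=modellr)

    Best_ACC = 0
    acc_history = {}
    for epoch in range(1, EPOCHS + 1):
        train(model, DEVICE, train_loader, optimizer, epoch)
        acc_history[epoch] = val(model, DEVICE, test_loader)

    # Persist per-epoch accuracy for the comparison plot at the end of the post
    with open('result.json', 'w', encoding='utf8') as fp:
        json.dump(acc_history, fp)
```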
```python
# Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os


# Fix the random seeds for reproducibility
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training loop for one epoch
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # In training mode the distilled DeiT returns (cls_logits, dist_logits);
        # this plain student baseline trains on the class-token head only
        out = model(data)[0]
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Validation loop
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # In eval mode the distilled DeiT averages both heads into one tensor
        out = model(data)
        loss = criterion(out, target)
        _, pred = torch.max(out.data, 1)
        correct += torch.sum(pred == target)
        print_loss = loss.data.item()
        test_loss += print_loss
    correct = correct.data.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    # Keep the checkpoint with the best validation accuracy
    if acc > Best_ACC:
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))
    return acc
```
```python
if __name__ == '__main__':
    # Create the folder that stores the baseline student checkpoints
    file_dir = 'StudentModel'
    os.makedirs(file_dir, exist_ok=True)

    # Global hyperparameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
```
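As with the teacher script, the rest of this `__main__` block is cut off in the post. A hedged reconstruction; paths, class count, and transforms are again assumptions:

```python
    # Hedged reconstruction: paths, class count, and transforms are assumptions
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('data/train', transform=transform),
        batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('data/val', transform=transform),
        batch_size=BATCH_SIZE, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    model = deit_tiny_distilled_patch16_224(pretrained=False, num_classes=12)
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=modellr)

    Best_ACC = 0
    acc_history = {}
    for epoch in range(1, EPOCHS + 1):
        train(model, DEVICE, train_loader, optimizer, epoch)
        acc_history[epoch] = val(model, DEVICE, test_loader)

    with open('result_student.json', 'w', encoding='utf8') as fp:
        json.dump(acc_history, fp)
```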
```python
# Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.loss import LabelSmoothingCrossEntropy
from torchvision import datasets
from models.models import deit_tiny_distilled_patch16_224
import json
import os

from losses import DistillationLoss


# Fix the random seeds for reproducibility
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


# Training loop: the student is optimized against the combined KD criterion
def train(s_net, t_net, device, criterionKD, train_loader, optimizer, epoch):
    s_net.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # In training mode the distilled DeiT returns (cls_logits, dist_logits),
        # which DistillationLoss expects as a tuple; the teacher is queried
        # inside criterionKD, so t_net is not called directly here
        out_s = s_net(data)
        loss = criterionKD(data, out_s, target)
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# Validation loop: only the plain classification loss is evaluated
@torch.no_grad()
def val(model, device, criterionCls, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        out_s = model(data)
        loss = criterionCls(out_s, target)
        _, pred = torch.max(out_s.data, 1)
        correct += torch.sum(pred == target)
        print_loss = loss.data.item()
        test_loss += print_loss
    correct = correct.data.item()
    acc = correct / total_num
    avgloss = test_loss / len(test_loader)
    # Keep the checkpoint with the best validation accuracy
    if acc > Best_ACC:
        torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avgloss, correct, len(test_loader.dataset), 100 * acc))
    return acc
```
```python
if __name__ == '__main__':
    # Create the folder that stores the distilled student checkpoints
    file_dir = 'KDModel'
    os.makedirs(file_dir, exist_ok=True)

    # Global hyperparameters
    modellr = 1e-4
    BATCH_SIZE = 4
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42
    seed_everything(SEED)
    # Distillation settings: hard distillation with equal weighting
    distillation_type = 'hard'
    distillation_alpha = 0.5
    distillation_tau = 1.0
```
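The post also truncates this block. A hedged reconstruction of the wiring between teacher, student, and `DistillationLoss`; the dataset paths, the class count, and the teacher checkpoint location are assumptions (the checkpoint path matches what the teacher script above would produce):

```python
    # Hedged reconstruction: paths, class count, and checkpoint location are assumptions
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('data/train', transform=transform),
        batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder('data/val', transform=transform),
        batch_size=BATCH_SIZE, shuffle=False)

    # Load the frozen teacher saved by the first script
    teacher = torch.load('TeacherModel/best.pth', map_location=DEVICE)
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False

    student = deit_tiny_distilled_patch16_224(pretrained=False, num_classes=12).to(DEVICE)

    criterionCls = nn.CrossEntropyLoss()
    criterionKD = DistillationLoss(criterionCls, teacher, distillation_type,
                                   distillation_alpha, distillation_tau)
    optimizer = optim.Adam(student.parameters(), lr=modellr)

    Best_ACC = 0
    acc_history = {}
    for epoch in range(1, EPOCHS + 1):
        train(student, teacher, DEVICE, criterionKD, train_loader, optimizer, epoch)
        acc_history[epoch] = val(student, DEVICE, criterionCls, test_loader)

    with open('result_kd.json', 'w', encoding='utf8') as fp:
        json.dump(acc_history, fp)
```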
```python
# Import the required libraries
import json
import matplotlib.pyplot as plt

# Paths of the per-epoch accuracy logs written by the three training scripts
teacher_file = 'result.json'
student_file = 'result_student.json'
student_kd_file = 'result_kd.json'


# Read an {epoch: accuracy} dict from a JSON file
def read_json(file):
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
        print(json_data)
    return json_data


# Load the three accuracy curves
teacher_data = read_json(teacher_file)
student_data = read_json(student_file)
student_kd_data = read_json(student_kd_file)

# Plot the curves for comparison
x = list(range(len(teacher_data)))
plt.plot(x, list(teacher_data.values()), label='teacher')
plt.plot(x, list(student_data.values()), label='student without KD')
plt.plot(x, list(student_kd_data.values()), label='student with KD')
plt.title('Test accuracy')
plt.legend()
plt.show()
```
This article has walked through knowledge distillation of a DeiT model with an external distillation algorithm. The experimental comparison shows that hard distillation effectively improves the student network's performance. The code and dataset are available via the accompanying link.
Reprinted from: http://sjdbz.baihongyu.com/