知识蒸馏的诞生背景
近年来,深度神经网络(DNN)在工业界和学术界都取得了巨大成功,尤其是在计算机视觉任务方面。深度学习的成功很大程度上归功于其具有数十亿参数的用于编码数据的可扩展性架构,其训练目标是在已有的训练数据集上建模输入和输出之间的关系,其性能高度依赖于网络的复杂程度以及标注训练数据的数量和质量。
相比于计算机视觉领域的传统算法,大多数基于 DNN 的模型都因为过参数化而具备强大的泛化能力。这种泛化能力体现在对某个问题输入的所有数据,模型都能给出较好的预测结果,无论是训练数据、测试数据,还是属于该问题的未知数据。
在当前深度学习的背景下,算法工程师为了提升业务算法的预测效果,常常会有两种方案:
01使用过参数化的更复杂的网络,这类网络学习能力非常强,但需要大量的计算资源来训练,并且推理速度较慢。
02集成模型将许多效果弱一些的模型集成起来,通常包括参数的集成和结果的集成。
这两种方案能显著提升现有算法的效果,但都提升了模型的规模,产生了较大的计算负担,需要的计算和存储资源很大。
在工作中,各种算法模型的最终目的都是要服务于某个应用。就像在买卖中我们需要控制收入和支出一样,工业应用中,除了要求模型要有好的预测以外,计算资源的使用也要严格控制,不能只考虑结果不考虑效率。在输入数据编码量高的计算机视觉领域,计算资源更显有限,控制算法的资源占用就更为重要。
通常来说,规模较大的模型预测效果更好,但训练时间长、推理速度慢的问题使得模型难以实时部署,尤其在视频监控、自动驾驶汽车和高吞吐量云端环境等计算资源有限的设备上,响应速度显然不够用。规模较小的模型虽然推理速度较快,但是因为参数量不足,推理效果和泛化性能可能就没那么好。如何权衡大规模模型和小规模模型一直是一个热门话题,当前的解决方法大多是根据部署环境的终端设备性能选择合适规模的 DNN 模型。
如果我们希望有一个规模较小的模型,能在保持较快推理速度的前提下,达到和大模型相当或接近的效果该如何做到呢?
从头训练一个小模型,从经验上看是很难达到效果的。在机器学习中,我们常常假定输入到输出有一个潜在的映射函数关系,这个函数是未知的:从头学习一个新模型就是从输入数据和对应标签中近似一个未知的映射函数。一般来说,提升一个算法的性能最有效的方式是标注更多的输入数据,也就是提供更多的监督信息,这可以让学习到的映射函数更具鲁棒性,性能更好。
举两个例子,在计算机视觉领域中,实例分割任务通过额外提供掩膜信息,可以提高目标包围框检测的效果;迁移学习任务通过提供在更大数据集上的预训练模型,显著提升新任务的预测效果。因此提供更多的监督信息,可能是缩短小规模模型和大规模模型差距的关键。
图1 带有师生学习框架的 KD 方法的插图[3]。(a)模型压缩 (b)(c)知识迁移,如半监督学习和自监督学习
按照之前的说法,想要获取更多的监督信息意味着标注更多的训练数据,这往往需要巨大的成本,那么有没有一种低成本又高效的监督信息获取方法呢?2006 年的文献[1]中指出,可以让新模型近似(approximate)原模型(模型即函数)。因为原模型的函数是已知的,新模型训练时等于天然地增加了更多的监督信息,这显然要更可行。
进一步思考,原模型带来的监督信息可能蕴含着不同维度的知识,这些与众不同的信息可能是新模型仅凭自己不能捕捉到的,在某种程度上来说,这对于新模型也是一种“跨域”的学习。
2015 年 Hinton 在论文《Distilling the Knowledge in a Neural Network》[2]中率先提出“知识蒸馏(Knowledge Distillation, KD)”的概念:可以先训练出一个大而强的模型,然后将其包含的知识转移给小的模型,就实现了在“保持小模型较快推理速度的同时,达到和大模型相当或接近的效果”的目的。
这其中先训练的大模型可以称之为教师模型,后训练的小模型可以称之为学生模型,整个训练过程可以形象地比喻为“师生学习”。随后几年,涌现了大量的知识蒸馏工作,为工业界提供了更多新的解决思路。目前,KD 已广泛应用于两个不同的领域:模型压缩和知识迁移。[3]
网易易盾作为新一代数字内容风控服务商,利用人工智能技术构筑起坚固的内容安全防线。易盾算法团队也通过知识蒸馏的应用,在项目指标上取得了显著提升,有效改进了内容安全的精细化识别能力。本文主要介绍知识蒸馏领域的两种经典算法,一篇是基于输出 logits 的知识蒸馏工作 KD [2],一篇是基于 CNN 的中间层特征的知识蒸馏工作 FitNet [4]。
1 知识蒸馏[2]Knowledge Distillation
1.1 算法简介
知识蒸馏(Knowledge Distillation)是一种基于“教师-学生网络”思想的模型压缩方法,由于简单有效,在工业界被广泛应用。其目的是将已经训练好的大模型包含的知识,蒸馏(Distill)提取到另一个小的模型中去。那怎么让大模型的知识,或者说泛化能力转移到小模型身上去呢?KD 论文把大模型对样本输出的概率向量作为软目标(soft targets)提供给小模型,让小模型的输出尽量去向这个软目标靠,而原来是和 one-hot 编码上靠,去近似学习大模型的行为。
在传统的硬标签训练过程中,所有负标签都被统一对待,但这种方式把类别间的关系割裂开了。例如识别手写数字,同样是标签为“3”的图片,可能有的比较像“8”,有的比较像“2”,硬标签区分不出来这个信息,但是一个训练良好的大模型可以给出。大模型 softmax 层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。近似学习这一行为使得每个样本给学生网络带来的信息量大于传统的训练方式。
因此,作者在训练学生网络时修改了一下损失函数,让小模型在拟合训练数据的真值(ground truth)标签的同时,也要拟合大模型输出的概率分布。这个方法叫做知识蒸馏训练(Knowledge Distillation Training, KD Training)。知识蒸馏过程所用的训练样本可以和训练大模型用的训练样本一样,或者另找一个独立的 Transfer set。
1.2 算法详解
具体来说,知识蒸馏使用的是 Teacher—Student 模型,其中 teacher 是“知识”的输出者,student 是“知识”的接受者。知识蒸馏的过程分为 2 个阶段:
①教师模型训练: 训练"Teacher 模型", 简称为 Net-T,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。对"Teacher 模型"不作任何关于模型架构、参数量、是否集成方面的限制,因为该模型不需要部署,唯一的要求就是,对于输入 X, 其都能输出 Y,其中 Y 经过 softmax 的映射,输出值对应相应类别的概率值。
②学生模型训练: 训练"Student 模型", 简称为 Net-S,它是参数量较小、模型结构相对简单的单模型。同样,对于输入 X,其都能输出 Y,Y 经过 softmax 映射后能输出相应类别的概率值。
由于使用 softmax 的网络的结果很容易走向极端,即某一类的置信度超高,其他类的置信度都很低,此时学生模型关注到的正类信息可能还是仅属于某一类。除此之外,因为不同类别的负类信息也有相对的重要性,所有负类分数都差不多也不好,达不到知识蒸馏的目的。为了解决这个问题,引入温度(Temperature)的概念,使用高温将小概率值所携带的信息蒸馏出来。具体来说,在 logits 过 softmax 函数前除以温度T。
训练时首先将教师模型学习到的知识蒸馏给小模型,具体来说对样本 x,大模型的倒数第二层先除以一个温度 T,然后通过 Softmax 预测一个软目标 Soft target,小模型也一样,倒数第二层除以同样的温度 T,再通过 Softmax 预测一个结果,接着把这个结果和软目标的交叉熵作为训练的 total loss 的一部分。下一步再将小模型正常的输出和真值标签(hard target)的交叉熵作为训练的 total loss 的另一部分。Total loss 把这两个损失加权合起来作为训练小模型的最终的 loss。
在小模型训练好后的预测环节,就不需要再有温度 T 了,直接按照常规的 Softmax 输出就可以了。
2 FitNet[4]Hints for Thin Deep Nets
2.1 算法简介
FitNet 论文在蒸馏时引入了中间层隐藏映射(intermediate-level hints)来指导学生模型的训练。使用一个宽而浅的教师模型来训练一个窄而深的学生模型。在进行 hint 引导时,提出使用一个层来匹配 hint 层和 guided 层的输出 shape,这在后人的工作里面常被称为 adaptation layer。
总的来说,相当于是在做知识蒸馏时,不仅用到了教师模型的 logit 输出,还用到了教师模型的中间层特征图作为监督信息。可以想到的是,直接让小模型在输出端模仿大模型,这个对于小模型来说太难了。一般而言,模型越深越难训练,最后一层的监督信号要向前传递增加工作量。
因此,不如在中间加一些监督信号,使得模型在训练时可以逐层接受学习更难的映射函数,而不是直接学习最难的映射函数。除此之外,hint 引导加速了学生模型的收敛,在一个非凸问题上找到更好的局部最小值,使得学生网络能更深的同时,还能提升训练速度。这感觉就好像是,我们的目的是让学生做高考题,那么就先把初中的题目给他教会了(先让小模型用前半个模型学会提取图像底层特征),再回到本来的目的:去学高考题(用 KD 调整小模型的全部参数)。
这篇文章是提出蒸馏中间特征图的始祖,提出的算法很简单,但思路具有开创性。
2.2 算法详解
FitNets 的具体做法是:
①确定教师网络,并训练成熟,将教师网络的中间特征层 hint 提取出来。
②设定学生网络,该网络一般较教师网络更窄且更深。训练学生网络使得学生网络的中间特征层与教师模型的 hint 相匹配。由于学生网络的中间特征层和与教师 hint 尺寸不同,因此需要在学生网络中间特征层后添加回归器用于特征升维,以匹配 hint 层尺寸。其中,匹配教师网络的 hint 层与回归器转化后的学生网络的中间特征层的损失函数为均方差损失函数。
实际训练的时候往往和上一节的 KD Training 联合使用,用两阶段法训练:先用 hint training 去 pretrain 小模型前半部分的参数,再用 KD Training 去训练全体参数。由于蒸馏过程中使用了更多的监督信息,基于中间特征图的蒸馏方法比基于结果 logits 的蒸馏方法效果要好,但是训练时间更久。
总结
知识蒸馏对于将知识从集成或从高度正则化的大型模型转移到较小的模型中非常有效。即使在用于训练蒸馏模型的迁移数据集中缺少任何一个或多个类的数据时,蒸馏的效果也非常好。在经典之作 KD 和 FitNet 提出之后,各种各样的蒸馏方法如雨后春笋般涌现,未来易盾人工智能技术团队也将会在模型压缩和知识迁移领域做出更进一步的探索。
引用
[1] Buciluǎ C, Caruana R, Niculescu-Mizil A. Model compression[C]//Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining. 2006: 535-541.
[2] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015, 2(7).
[3] Wang L, Yoon K J. Knowledge distillation and student-teacher learning for visual intelligence: A review and new outlooks[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2021.
[4] Romero A, Ballas N, Kahou S E, et al. Fitnets: Hints for thin deep nets[J]. arXiv preprint arXiv:1412.6550, 2014.
关于易盾
网易易盾是网易集团旗下一站式数字内容风控品牌,依托网易 20 多年的先进技术沉淀和一线实践经验,作为国内领先的数字内容风控服务商,为面向数字化业务的客户提供专业可靠的安全服务,涵盖内容安全、业务安全、移动安全三大领域,全方位保障客户业务合规、稳健和安全运营。目前,网易易盾已服务超 40 万开发者与数千家付费客户,其中不乏人民网、外交部、华泰证券、中信银行、OPPO、vivo、滴滴、知乎、B 站等知名企事业单位。