计算机工程与应用 (Computer Engineering and Applications), 2022, Vol. 58, Issue (5): 193-199. DOI: 10.3778/j.issn.1002-8331.2010-0353

• Pattern Recognition and Artificial Intelligence •

Semi-supervised Learning Method via Wasserstein Distance Under Small Sample Condition

MA Menghao, WANG Zhe   

  1. School of Information Science and Engineering, East China University of Science and Technology, Shanghai 200237, China
  • Online: 2022-03-01  Published: 2022-03-01

Abstract: In recent years, deep neural network models trained on large-scale labeled data sets have shown excellent performance on image tasks, but large quantities of labeled data are expensive and difficult to collect. To make better use of unlabeled data, this paper proposes a semi-supervised learning method, Wasserstein consistency training (WCT), which introduces the Jensen-Shannon divergence to simulate co-training and organizes massive unlabeled data to improve the efficiency of co-training. Adversarial examples generated by the fast gradient sign method encourage the difference between the views, and the Wasserstein distance serves as the measure for the network-difference constraint, preventing the deep neural networks from collapsing and keeping the network output smooth on the low-dimensional manifold. Experimental results show that the proposed method achieves an error rate of 0.85% on MNIST and of 11.96% on CIFAR-10 using only 4,000 labeled examples, demonstrating its good performance in semi-supervised image classification under the small-sample condition.
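The abstract names three ingredients: a Jensen-Shannon consistency term that couples two networks (simulated co-training), adversarial examples from the fast gradient sign method (FGSM) to push the two views apart, and a Wasserstein distance that bounds how far the networks' outputs may drift. The following PyTorch sketch shows one way these terms could be assembled into a training loss; the two-network setup, the pseudo-label FGSM attack, the 1-D Wasserstein over class probabilities, and all hyperparameters (eps, lam_js, lam_w) are illustrative assumptions, not the authors' implementation.

# Illustrative sketch of the loss components named in the abstract;
# every design choice below is an assumption, not the paper's code.
import torch
import torch.nn.functional as F

def js_divergence(p_logits, q_logits):
    # Jensen-Shannon divergence between the two networks' predictions.
    p = F.softmax(p_logits, dim=1)
    q = F.softmax(q_logits, dim=1)
    m = 0.5 * (p + q)
    return 0.5 * (F.kl_div(m.log(), p, reduction="batchmean")
                  + F.kl_div(m.log(), q, reduction="batchmean"))

def fgsm_example(model, x, eps=8 / 255):
    # FGSM on unlabeled data, attacking the model's own hard pseudo-labels
    # (an assumed choice for the unlabeled setting).
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)
    loss = F.cross_entropy(logits, logits.argmax(dim=1))
    grad = torch.autograd.grad(loss, x)[0]
    return (x + eps * grad.sign()).detach()

def wasserstein_1d(p_logits, q_logits):
    # 1-Wasserstein distance between class distributions, treating class
    # indices as points on a line (a simplifying assumption).
    p_cdf = F.softmax(p_logits, dim=1).cumsum(dim=1)
    q_cdf = F.softmax(q_logits, dim=1).cumsum(dim=1)
    return (p_cdf - q_cdf).abs().sum(dim=1).mean()

def wct_loss(net_a, net_b, x_lab, y, x_unlab, lam_js=1.0, lam_w=0.1):
    # Supervised cross-entropy on the small labeled set, for both networks.
    sup = F.cross_entropy(net_a(x_lab), y) + F.cross_entropy(net_b(x_lab), y)
    # JS consistency on unlabeled data simulates co-training between views.
    js = js_divergence(net_a(x_unlab), net_b(x_unlab))
    # Adversarial examples encourage the views to differ; the Wasserstein
    # term constrains the two networks' outputs from drifting apart.
    x_adv = fgsm_example(net_a, x_unlab)
    w = wasserstein_1d(net_a(x_adv), net_b(x_adv))
    return sup + lam_js * js + lam_w * w

In a training loop, wct_loss would be minimized jointly over both networks on mini-batches that mix the labeled pool (e.g. the 4,000 labeled CIFAR-10 images) with unlabeled data; the weights lam_js and lam_w would need tuning per data set.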

Key words: small sample, semi-supervised learning, adversarial examples, deep neural network