Siamese Network & Triplet NetWork-程序员宅基地

技术标签: 孪生网络  深度学习  

Siamese Network(孪生网络)

简单来说,孪生网络就是共享参数的两个神经网络

在孪生网络中,我们把一张图片 X 1 X_1 X1作为输入,得到该图片的编码 G W ( X 1 ) G_W(X_1) GW(X1)。然后,我们在不对网络参数进行任何更新的情况下,输入另一张图片 X 2 X_2 X2,并得到改图片的编码 G W ( X 2 ) G_W(X_2) GW(X2)。由于相似的图片应该具有相似的特征(编码),利用这一点,我们就可以比较并判断两张图片的相似性

孪生网络的损失函数

传统的Siamese Network使用Contrastive Loss(对比损失函数)
L = ( 1 − Y ) 1 2 ( D W ) 2 + ( Y ) 1 2 { m a x ( 0 , m − D W ) } 2 \mathcal{L} = (1-Y)\frac{1}{2}(D_W)^2+(Y)\frac{1}{2}\{max(0, m-D_W)\}^2 L=(1Y)21(DW)2+(Y)21{ max(0,mDW)}2
其中 D W D_W DW被定义为孪生网络两个输入之间的欧氏距离,即
D W = { G W ( X 1 ) − G W ( X 2 ) } 2 D_W = \sqrt{\{G_W(X_1)-G_W(X_2)\}^2} DW={ GW(X1)GW(X2)}2

  • Y Y Y值为0或1,如果 X 1 , X 2 X_1,X_2 X1,X2这对样本属于同一类,则 Y = 0 Y=0 Y=0,反之 Y = 1 Y=1 Y=1
  • m m m是边际价值(margin value),即当 Y = 1 Y=1 Y=1,如果 X 1 X_1 X1 X 2 X_2 X2之间距离大于 m m m,则不做优化(省时省力);如果 X 1 X_1 X1 X 2 X_2 X2之间的距离小于 m m m,则调整参数使其距离增大到 m m m
Contrastive Loss代码
import torch
import numpy as np
import torch.nn.functional as F

class ContrastiveLoss(torch.nn.Module):
    "Contrastive loss function"
    def __init__(self, m=2.0):
        super(ContrastiveLoss, self).__init__()
        self.m = m
            
    def forward(self, output1, output2, label):
        d_w = F.pairwise_distance(output1, output2)
        contrastive_loss = torch.mean((1-label) * 0.5 * torch.pow(d_w, 2) +
                                      (label) * 0.5 * torch.pow(torch.clamp(self.m - d_w, min=0.0), 2))

        return contrastive_loss

其中,F.pairwise_distance(x1, x2, p=2)函数公式如下
( ∑ i = 1 n ( ∣ x 1 − x 2 ∣ p ) ) 1 p x 1 , x 2 ∈ R b × n (\sum_{i=1}^n(|x_1-x_2|^p))^{\frac{1}{p}}\\ x_1,x_2 \in \mathbb{R}^{b\times n} (i=1n(x1x2p))p1x1,x2Rb×n

pairwise_distance(x1, x2, p) Computes the batchwise pairwise distance between vectors x 1 x_1 x1, x 2 x_2 x2 using the p-norm

孪生网络的用途

简单来说,孪生网络的直接用途就是衡量两个输入的差异程度(或者说相似程度)。将两个输入分别送入两个神经网络,得到其在新空间的representation,然后通过Loss Function来计算它们的差异程度(或相似程度)

  • 词汇语义相似度分析,QA中question和answer的匹配
  • 手写体识别也可以用Siamese Network
  • Kaggle上Quora的Question Pair比赛,即判断两个提问是否为同一个问题
Pseudo-Siamese Network(伪孪生网络)

对于伪孪生网络来说,两边可以是不同的神经网络(如一个是lstm,一个是cnn),并且如果是相同的神经网络,是不共享参数

孪生网络和伪孪生网络分别适用的场景
  • 孪生网络适用于处理两个输入比较类似的情况
  • 伪孪生网络适用于处理两个输入有一定差别的情况

例如,计算两个句子或者词汇的语义相似度,使用Siamese Network比较合适;验证标题与正文的描述是否一致(标题和正文长度差别很大),或者文字是否描述了一幅图片(一个是图片,一个是文字)就应该使用Pseudo-Siamese Network

Triplet Network(三胞胎网络)

如果说Siamese Network是双胞胎,那Triplet Network就是三胞胎。它的输入是三个:一个正例+两个负例,或一个负例+两个正例。训练的目标仍然是让相同类别间的距离尽可能小,不同类别间的距离尽可能大。Triplet Network在CIFAR,MNIST数据集上效果均超过了Siamese Network

损失函数定义如下:
L = m a x ( d ( a , p ) − d ( a , n ) + m a r g i n , 0 ) \mathcal{L}=max(d(a,p)-d(a,n)+margin, 0) L=max(d(a,p)d(a,n)+margin,0)

  • a a a表示anchor图像
  • p p p表示positive图像
  • n n n表示negative图像

我们希望 a a a p p p的距离应该小于 a a a n n n的距离。 m a r g i n margin margin是个超参数,它表示 d ( a , p ) d(a,p) d(a,p) d ( a , n ) d(a,n) d(a,n)之间应该相差多少,例如,假设 m a r g i n = 0.2 margin=0.2 margin=0.2,并且 d ( a , p ) = 0.5 d(a,p)=0.5 d(a,p)=0.5,那么 d ( a , n ) d(a,n) d(a,n)应该大于等于 0.7 0.7 0.7

Reference
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_37236745/article/details/109057383

智能推荐

【Unity3D游戏开发实战】Unity3D实现休闲类游戏《2048》——算法、源代码_unity小游戏2048源码-程序员宅基地

文章浏览阅读1w次,点赞29次,收藏133次。推荐阅读CSDN主页GitHub开源地址Unity3D插件分享简书地址我的个人博客QQ群:1040082875大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。一、前言最近有粉丝要参加游戏创作大赛,问我需要准备学习什么知识,以及参加比赛的注意事项:参加这类比赛是非常有好处的,不仅提高了技术,也增长了见识。因为是兴趣驱动学习,在完善自己心爱游戏的过程中,要不断的去学习,不断的提高自己。更能在这个过程中找到志同道合的好朋友。那今天就._unity小游戏2048源码

【python】Open3D,Write PLY failed解决方法_write pcd failed: unable to generate header.-程序员宅基地

文章浏览阅读6.5k次。写了一个简单的函数,将三维点云(ndarray)保存为.ply文件:def save_points_as_ply(points, ply_path): """ 将点云保存为.ply文件,保存成功会打印'ply_path 已保存' :param points: ndarray, (-1,3) :param ply_path: str,'xxx/xxxx.ply' """ pcd = o3d.geometry.PointCloud() pcd.point_write pcd failed: unable to generate header.

详解 Android Views 元素的 layout_weight 属性-程序员宅基地

文章浏览阅读75次。所有View(视图)元素中都有一个XML属性android:layout_weight,其值为0,1,2,3...等整数值。使用了之后,其对应界面中的元素比例就会发生变化,变大或者变小。layout_weight属性其实就是一个元素重要度的属性,用于在线性布局中为不同的view元素设置不同的重要度。  所有的视图都有一个layout_weight值,其默认值为0,表示视图多大就占据..._android view获取当前的layout_with 的值

hosts文件修改后无法保存问题_linux hosts文件无法保存-程序员宅基地

文章浏览阅读8.6k次,点赞11次,收藏14次。hosts文件在windows目录下的位置(我的是win10系统,其他系统大同小异)C:\Windows\system32\drivers\etc\hostslinux系统hosts位置/etc/hostsLinux系统一般来说linux系统出现无法修改的情况是比较少的,基本没有,只要你处于root权限下是都可以修改的,因为root默认是有rwx权限的如果不能修改,r..._linux hosts文件无法保存

Java中自定义异常的两个小例子_public int getlength(){return length}-程序员宅基地

文章浏览阅读962次。Java 异常处理异常是程序中的一些错误,但并不是所有的错误都是异常,并且错误有时候是可以避免的。比如说,你的代码少了一个分号,那么运行出来结果是提示是错误 java.lang.Error;如果你用System.out.println(11/0),那么你是因为你用0做了除数,会抛出 java.lang.ArithmeticException 的异常。异常发生的原因有很多,通常包含以下几..._public int getlength(){return length}

解决:Ubuntu18环境Docker安装成功,但启动报错Unit docker.service is not loaded properly: Bad message_loaded: error (reason: bad message)-程序员宅基地

文章浏览阅读6.2k次,点赞2次,收藏4次。解决:Ubuntu18环境Docker安装成功,但启动报错Unit docker.service is not loaded properly: Bad message.文章目录解决:Ubuntu18环境Docker安装成功,但启动报错Unit docker.service is not loaded properly: Bad message.前言:微信交流群:分析解决方案使用存储库安装 Docker-ce**设置存储库****安装 Docker-ce****测试 Docker-ce****升级 Doc_loaded: error (reason: bad message)

随便推点

操作系统安全---实验三:Windows7操作系统安全_操作系统的基本安全设置实验总结-程序员宅基地

文章浏览阅读3.8k次,点赞3次,收藏39次。目录一、实验目的及要求二、实验原理三、实验环境四、实验步骤及内容4.1账户与口令4.2审核与日志4.3安全模板五、实验总结六、分析与思考一、实验目的及要求了解Windows账户与密码的安全策略设置,掌握用户和用户组的权限管理、审核,以及日志的启用,并学会使用安全模版来分析配置计算机。二、实验原理Windows系列是目前世界上使用用户最多的桌面操作系统。由于历史原因,Windows的很多用户都直接以管理员权限运行系统,对计算机安全构成很大隐患。从Wi..._操作系统的基本安全设置实验总结

Python + Selenium自动化测试 -- 自定义Log类_selenium python 创建log类-程序员宅基地

文章浏览阅读3.2k次,点赞2次,收藏15次。本文用日志来记录我们测试脚本做的事情,其实最好的办法是写事件监听(对于小白的我,暂时不会,先从日志学起)。 下面写一个日之类,用来输出不同级别的日志信息到本地文件夹下的日志文件里。 目标输出效果: 解决思路: 1. 封装Log类,类名为Logger; 2. 在Logger类中创建记录器logger; 3. 创建一个handler,用于写入日志文件,写到磁盘;再创建一个handler,_selenium python 创建log类

canal 整合 springboot_canalboot-程序员宅基地

文章浏览阅读632次。mysql 开启bin_logvi /etc/my.cnf末尾增加如下配置log_bin=mysql-bin binlog-format=ROW #选择row模式server-id = 1expire_logs_days=5 #日志过期时间为5天 重启mysql [5.7]service mysqld restart 修改canal 配置vi canal/conf/canal.properties#唯一标识 新增canal.id =123 _canalboot

零基础HTML教程(14)--hr:黄昏的地平线_html水平线-程序员宅基地

文章浏览阅读1w次,点赞11次,收藏13次。本文目录1. 水平线的概念2. 水平线的用法3. 小结1. 水平线的概念HTML中有一个比较特别的标签,叫做水平线,写作<hr>。该标签可以在网页上显示一条横线,一般用来分隔不同的网页内容。2. 水平线的用法使用方法很简单,在需要分割的地方,添加一个<hr>标签即可。例如:<!DOCTYPE html><html><head> <title>水平线实例</title> <meta c_html水平线

python查看已安装包的版本_python 如何查看networks的版本-程序员宅基地

文章浏览阅读4k次,点赞7次,收藏9次。pip freeze就不要说了,当你安装1000个包的时候就不会用这种蠢办法。第一种办法,打开终端/CMDpip freeze | findstr numpy这是windows下的,numpy只是个例子,要查什么自己改pip freeze | grep numpylinux下的第二种方法,python里去看。打开python命令行界面。一般来说包的版本都会用一个.__versio..._python 如何查看networks的版本

element ui 点击表格某一行改变行背景颜色_element ui 表格第一行变色-程序员宅基地

文章浏览阅读7.5k次,点赞5次,收藏10次。template<el-table :data="data" :row-class-name="tableRowClassName" //设置类 :row-style="selectedstyle" //设置行的样式 @row-click="rowClick" //点击></el-table>scriptdata() { return { data:[], getIndex:"", }},met_element ui 表格第一行变色

推荐文章

热门文章

相关标签