pytorch 实现 GRL Gradient Reversal Layer_pytorch grl-程序员宅基地

技术标签: 深度学习  pytorch  行人重识别  

在GRL中,要实现的目标是:在前向传导的时候,运算结果不变化,在梯度传导的时候,传递给前面的叶子节点的梯度变为原来的相反方向。举个例子最好说明了:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y
s =6* f.sum()

print(s)
s.backward()
print(x)
print(x.grad)

这个程序的运行结果是:

tensor(672., grad_fn=<MulBackward0>)
tensor([1., 2., 3.], requires_grad=True)
tensor([18., 30., 42.])

这个运算过程对于tensor中的每个维度上的运算为:

f(x)=(x^{2}+x)*6

那么对于x的导数为:

\frac{\mathrm{d} f}{\mathrm{d} x} = 12x+6

所以当输入x=[1,2,3]时,对应的梯度为:[18,30,42]

因此这个是正常的梯度求导过程,但是如何进行梯度翻转呢?很简单,看下方的代码:

import torch
from torch.autograd  import  Function

x = torch.tensor([1.,2.,3.],requires_grad=True)
y = torch.tensor([4.,5.,6.],requires_grad=True)

z = torch.pow(x,2) + torch.pow(y,2)
f = z + x + y

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input


Grl = GRL()

s =6* f.sum()
s = Grl(s)

print(s)
s.backward()
print(x)
print(x.grad)

运行结果为:

tensor(672., grad_fn=<GRL>)
tensor([1., 2., 3.], requires_grad=True)
tensor([-18., -30., -42.])

这个程序相对于上一个程序,只是差在加了一个梯度翻转层:

class GRL(Function):
    def forward(self,input):
        return input
    def backward(self,grad_output):
        grad_input = grad_output.neg()
        return grad_input

这个部分的forward没有进行任何操作,backward里面做了.neg()操作,相当于进行了梯度的翻转。在torch.autograd 中的FUnction 的backward部分,在不做任何操作的情况下,这里的grad_output的默认值是1.

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/t20134297/article/details/107870906

智能推荐

【MySQL】mysql The server time zone value “乱码” 错误_the server time zone value 乱码-程序员宅基地

文章浏览阅读7.8k次。稚语希听– 你忘了想起,我忘了忘记…mysql8以上版本时区问题:The server time zone value乱码XXXX异常类似:The server time zone value ‘�й���׼ʱ��’ is unrecognized or represents more than one time zone. You must configure either the server or JDBC driver (via the serverTimezone configuratio_the server time zone value 乱码

【WebApi】————.net WebApi开发(一)_webapi .net-程序员宅基地

文章浏览阅读8.5k次。【1】.部署环境.net4及以上版本。【2】.vs2010 开发需单独安装vs2010 sp1和mvc4mvc4:http://www.asp.net/mvc/mvc4【3】.开发1.新建项目选择ASP.net MVC 4 Web应用程序2.选择Web API 3.在新建立的项目里面有已经生成的webapi模版其中App_Start文件夹下WebApiCo..._webapi .net

几招教你阻止百度搜索自动跳转百度APP(其他网站也适用)!_百度自动跳转app怎么解决-程序员宅基地

文章浏览阅读10w+次,点赞15次,收藏33次。最近阿虚看到个消息说「百度」发布了新政策,禁止网站通过搜索引擎打开后折叠内容强迫下载APP客户端听起来似乎是百度难得良心一回?但实际上该政策仅限于手机百度APP内如果你是通过浏览器用百度搜索则与新政策完全没关系正好前不久不少粉丝来问过我这样一个问题:怎么屏蔽手机浏览器上的「跳转某某APP打开查看」提示那今天阿虚就来教一下怎么解决吧,毕竟这东西的确是有点烦人…屏蔽「跳转某某APP打开查看」这个问题我细看了下,还得分俩类:文章只能显示部分,然后提示你需要安装APP才能查看的,这种应该是大_百度自动跳转app怎么解决

PHP快速入门12-异常处理,自定义异常、抛出异常、断言异常等示例_php 抛出异常-程序员宅基地

文章浏览阅读843次。PHP的异常处理机制可以帮助我们在程序运行时遇到错误或异常情况时,及时发出警告并停止程序继续运行。下面是10个例子,分别展示了PHP异常处理的不同用法。_php 抛出异常

linux 清空docker容器日志_linux清理docker容器log-程序员宅基地

文章浏览阅读221次。【代码】linux 清空docker容器日志。_linux清理docker容器log

青岛大学开源OJ平台搭建_github oj开源-程序员宅基地

文章浏览阅读7.3k次,点赞3次,收藏15次。源码地址为:https://github.com/QingdaoU/OnlineJudge可参考的文档为:https://github.com/QingdaoU/OnlineJudgeDeploy/tree/2.0一、安装所依赖的环境sudo apt-get update && sudo apt-get install -y vim python-pip curl g..._github oj开源

随便推点

docker安装及部署mysql_docker部署mysql-程序员宅基地

文章浏览阅读1.5k次,点赞2次,收藏9次。docker安装与mysql部署_docker部署mysql

联想笔记本G510升级固态硬盘(SSD)血泪教程!!!_联想g510更换固态硬盘-程序员宅基地

文章浏览阅读8.5w次,点赞23次,收藏55次。#联想笔记本G510升级固态硬盘(SSD)血泪教程!!!用了5年的联想笔记本G510,经过了四年的游戏历程,然后四年后还老当益壮的挣扎在我工作的战斗一线,是我并肩作战多年,比兄弟还要亲的兄弟,虽然此时已经身躯残破,反应迟缓我依旧不舍得抛弃它(主要是没钱!)然后为了我个人的用户体验决定花少量的票子,让它多挣扎一会,最好是能坚持到我度过贫困期. 下面是我升级的悲催历程! - 首先为了提升运行速..._联想g510更换固态硬盘

问题记录——正则表达式匹配控制符_正则表达式匹配控制字符-程序员宅基地

文章浏览阅读910次。问题前端用xterm.js通过websocket连接docker虚拟终端,返回的字符中包括如下字符串,其中有两个控制字符,“ESC"和"BEL” ,想通过正则表达式匹配这一段字符,然后去掉这段字符:参考文档控制字符编码表转义符对照表通过上面查询得知,"ESC"和"BEL"这两个控制符的ASCII码分别为:十进制为27和7,十六进制为0x1B和0x07,转义符分别为:\e和\a代码**注意:**直接使用ASCII码匹配是不行的,一定要用转义符才行。如下测试代码中,只有regex3才能匹_正则表达式匹配控制字符

Android RIL框架分析-程序员宅基地

文章浏览阅读1.5k次。1.RIL框架 RIL,Radio Interface Layer。本层为一个协议转换层,提供Android Telephony与无线通信设备之间的抽象层。 Android RIL位于Telephony Frameworks之下,Modem之上的,根据源码,RIL可以分为两个部分:Frameworks 框架层中的java程序,简称RILJ。HAL层中C/C++程序,简称RILC,RILC具体的又包括LibRIL、Rild和Reference-RIL这三个部分。 Andr..._ril框架

Python编程基础:第六节 math包的基础使用Math Functions_ps math function-程序员宅基地

文章浏览阅读565次。第六节 math包的基础使用前言实践前言我们通常会对数值型变量进行计算,这里我们给出一些常用的函数用于辅助你的计算过程。常用的数学计算函数均在math包。实践首先我们导入math包,并定义一个浮点型变量pi将其赋值为3.14:import mathpi = 3.14如果我们需要计算浮点型变量四舍五入后的计算结果,用函数round()即可:print(round(pi))>>> 3如果我们需要向上取整,那就需要函数math.ceil():print(math.cei_ps math function

canal异常 Could not find first log file name in binary log index file_canal could not find first log file name in binary-程序员宅基地

文章浏览阅读4.4k次,点赞3次,收藏2次。Could not find first log file name in binary log index file问题解决解决过程问题最近在使用canal来监测数据库的变化,处理变动的数据。由于有一段时间没有用了,这次启动在日志文件中看到这个异常 Could not find first log file name in binary log index file,详细信息如下:2020-12-16 19:14:42.053 [destination = tradeAndRefund , addr_canal could not find first log file name in binary log index file