标准化流(Normalizing Flow)_Jie Qiao的博客-程序员ITS301_标准化流

技术标签: 人工智能  

Normalizing Flow

flow的核心思想就是这个分布变换的公式,如果 y = f ( x ) \displaystyle y=f( x) y=f(x),且 f \displaystyle f f是可逆的,则

p x ( x ) = p y ( f ( x ) ) ∗ ∣ det ⁡ J f ( x ) ∣ p y ( y ) = p x ( f − 1 ( y ) ) ∗ ∣ det ⁡ J f − 1 ( y ) ∣ p_{x} (x)=p_{y} (f(x))*|\det Jf(x)|\\ p_{y} (y)=p_{x} (f^{-1} (y))*|\det Jf^{-1} (y)| px(x)=py(f(x))detJf(x)py(y)=px(f1(y))detJf1(y)

那如果有很多个 f \displaystyle f f,那么取log之后我们只需简单地加起来就好了:

z K = f K ∘ … ∘ f 2 ∘ f 1 ( z 0 ) log ⁡ q K ( z K ) = log ⁡ q 0 ( z 0 ) − ∑ k = 1 K log ⁡ det ⁡ ∣ ∂ f k ∂ z k ∣ \begin{aligned} \mathbf{z}_{K} & =f_{K} \circ \dotsc \circ f_{2} \circ f_{1}(\mathbf{z}_{0})\\ \log q_{K}(\mathbf{z}_{K}) & =\log q_{0}(\mathbf{z}_{0}) -\sum ^{K}_{k=1}\log\operatorname{det}\left| \frac{\partial f_{k}}{\partial \mathbf{z}_{k}}\right| \end{aligned} zKlogqK(zK)=fKf2f1(z0)=logq0(z0)k=1Klogdetzkfk

想要对这个分布变换了解更多的可以看我前一篇文章:
理解Jacobian矩阵与分布变换

NICE: Additive coupling layer

从分布变换的公式可以看出,想要处理好这种变换,首先要保证f是可逆的,其次,f的Jacobian的行列式也要好求才行,而最好求的行列式自然就是三角行列式,等于对角线的。

那么现在介绍最基本的做法就是NICE的做法:首先对于D维数据的x,将其随意地划分为两部分, x 1 , x 2 \displaystyle x_{1} ,x_{2} x1,x2,并做变换:

y 1 = x 1 y 2 = x 2 + m ( x 1 ) \begin{aligned} & \mathbf{y}_{1} =\boldsymbol{x}_{1}\\ & \mathbf{y}_{2} =\boldsymbol{x}_{2} +\boldsymbol{m} (\boldsymbol{x}_{1} ) \end{aligned} y1=x1y2=x2+m(x1)

其中m是一个任意的MLP函数,通过这样的变换,我们发现这个函数是可逆的,即:

x 1 = y 1 x 2 = y 2 − m ( x 1 ) \begin{aligned} & \boldsymbol{x}_{1} =\mathbf{y}_{1}\\ & \boldsymbol{x}_{2} =\mathbf{y}_{2} -\boldsymbol{m} (\boldsymbol{x}_{1} ) \end{aligned} x1=y1x2=y2m(x1)

事实上随后的改进都只是进一步地将这个可逆函数变得更加地“复杂”,这样的加性确实是简单了点。

它的基本原理很简单,就是将x和y划分成两块,其中 x 1 = x 1 : d \displaystyle x_{1} =x_{1:d} x1=x1:d前d个元素, x 2 = x d + 1 : D \displaystyle x_{2} =x_{d+1:D} x2=xd+1:D后面的元素。(所以这个似乎只能处理x和y都是高维的情况)。注意观察,y2和x2的关系其实是线性,而且利用了x1和y1的等价关系,所以用x2表示y2的时候,直接把x1换成y1就可以了,因此这个函数非常容易求逆。其实我觉得这个东西的本质其实应该是利用了一个Z的共享变量,来构造的可逆函数

        Z     /         \   X   −   Y \begin{array}{l} \ \ \ \ \ \ \ Z\\ \ \ \ /\ \ \ \ \ \ \ \backslash \\ \ X\ -\ Y \end{array}        Z   /       \ X  Y

于是

y 2 = x 2 + m ( z ) y_{2} =x_{2} +m(z) y2=x2+m(z)

而这个函数求导也很简单

∂ y ∂ x = [ ∂ y 1 ∂ x 1 ∂ y 1 ∂ x 2 ∂ y 2 ∂ x 1 ∂ y 2 ∂ x 2 ] = [ I 1 : d 0 ∂ m ( x 1 ) ∂ x 1 I d : D ] \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =\left[\begin{array}{ c c } \frac{\partial y_{1}}{\partial x_{1}} & \frac{\partial y_{1}}{\partial x_{2}}\\ \frac{\partial y_{2}}{\partial x_{1}} & \frac{\partial y_{2}}{\partial x_{2}} \end{array}\right] =\left[\begin{array}{ c c } \mathbb{I}_{1:d} & 0\\ \frac{\partial m( x_{1})}{\partial x_{1}} & \mathbb{I}_{d:D} \end{array}\right] xy=[x1y1x1y2x2y1x2y2]=[I1:dx1m(x1)0Id:D]

神奇的事情出现了,这个可逆函数的Jacobian矩阵居然只是一个三角矩阵,他的行列式就等于对角线的乘积为1,而其对数为0,

Real NVP: Affine Coupling layers

只是加性就太简单了,所以我们变得复杂一点,这是Real NVP中提出的做法:

y 1 = x 1 y 2 = s ( x 1 ) ⊙ x 2 + t ( x 1 ) \begin{aligned} & \mathbf{y}_{1} =\boldsymbol{x}_{1}\\ & \mathbf{y}_{2} =\mathbf{s}(\mathbf{x}_{1})\boldsymbol{\odot x}_{2} +t(\boldsymbol{x}_{1} ) \end{aligned} y1=x1y2=s(x1)x2+t(x1)

我乘一个非线性函数 s ( x 1 ) \displaystyle \mathbf{s}(\mathbf{x}_{1}) s(x1)上去,这里 ⊙ \displaystyle \odot 是点乘,其实他们本质上还是一个线性变换而已(所以才叫affine),s就是斜率,t是截距。既然线性变换可以,肯定也有多项式变换等等变种,事实上已经有类似的工作,比如Neural Spline Flows这种就是一个多项式的可逆函数,这里先不说这个。我们先看看这个Affine Coupling layer的jacobian是长什么样:

∂ y ∂ x = [ ∂ y 1 ∂ x 1 ∂ y 1 ∂ x 2 ∂ y 2 ∂ x 1 ∂ y 2 ∂ x 2 ] = [ I d 0 ∂ s ∂ x 1 ⊗ x 2 + ∂ t ∂ x 1 diag ⁡ ( s ) ] \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =\left[\begin{array}{ c c } \frac{\partial y_{1}}{\partial x_{1}} & \frac{\partial y_{1}}{\partial x_{2}}\\ \frac{\partial y_{2}}{\partial x_{1}} & \frac{\partial y_{2}}{\partial x_{2}} \end{array}\right] =\left[\begin{array}{ c c } \mathbb{I}_{d} & 0\\ \frac{\partial \mathbf{s}}{\partial \mathbf{x}_{1}} \otimes \mathbf{x}_{2} +\frac{\partial t}{\partial \mathbf{x}_{1}} & \operatorname{diag}( s) \end{array}\right] xy=[x1y1x1y2x2y1x2y2]=[Idx1sx2+x1t0diag(s)]

其中之所以是对角矩阵是因为

( x 1 ⋮ x n ) ⊙ ( y 1 ⋮ y n ) = ( x 1 ⋯ 0 ⋮ ⋱ ⋮ 0 ⋯ x n ) ( y 1 ⋮ y n ) \left(\begin{array}{ c } x_{1}\\ \vdots \\ x_{n} \end{array}\right) \odot \left(\begin{array}{ c } y_{1}\\ \vdots \\ y_{n} \end{array}\right) =\left(\begin{array}{ c c c } x_{1} & \cdots & 0\\ \vdots & \ddots & \vdots \\ 0 & \cdots & x_{n} \end{array}\right)\left(\begin{array}{ c } y_{1}\\ \vdots \\ y_{n} \end{array}\right) x1xny1yn=x100xny1yn

所以

∂ y 2 ∂ x 2 = ∂diag ⁡ ( s ) x 2 ∂ x 2 = diag ⁡ ( s ) \frac{\partial y_{2}}{\partial x_{2}} =\frac{\operatorname{\partial diag}(\mathbf{s}) x_{2}}{\partial x_{2}} =\operatorname{diag}(\mathbf{s}) x2y2=x2diag(s)x2=diag(s)

这里一般会约束s大于0,所以一般神经网络输出的是log(s),然后取指数变回来。

Glow 一种可逆的1x1卷积

我们刚才随机划分的方法感觉处理图片的时候怪怪的,而且就算是随机交换channel也感觉不太对,有没有更优雅的方法?Glow给出了解决的方法,我们可以引入1x1可逆卷积核来代替这个划分的操作。其实1x1卷积核本身就有随机置换的味道在里面,只不过这篇文章的贡献在于用了一个trick保证了他的可逆性。

可以看个小例子(引用来自苏剑林的博客:https://kexue.fm/archives/5807),随机置换的操作其实就是一个简单的线性变换:

( 2 1 4 3 ) = ( 0 1 0 0 1 0 0 0 0 0 0 1 0 0 1 0 ) ( 1 2 3 4 ) \begin{pmatrix} 2\\ 1\\ 4\\ 3 \end{pmatrix} =\begin{pmatrix} 0 & 1 & 0 & 0\\ 1 & 0 & 0 & 0\\ 0 & 0 & 0 & 1\\ 0 & 0 & 1 & 0 \end{pmatrix}\begin{pmatrix} 1\\ 2\\ 3\\ 4 \end{pmatrix} 2143=01001000000100101234

我们知道一个卷积核其实是可以表达成一个矩阵乘积的,只要我们把所有channel展开成一条向量,就可以写出一个矩阵的计算公式:

Y = X W Y=XW Y=XW

比如, x \displaystyle x x h ∗ w ∗ c \displaystyle h*w*c hwc的张量,c表示channel,对于1x1卷积来说,W就是一个 c ∗ c \displaystyle c*c cc的矩阵,就是他们的乘积,实际上可以将x想象成 h ∗ w \displaystyle h*w hw c \displaystyle c c列的矩阵,然后乘以W。

所以接下来的问题只有一个,如何保证这个W是可逆的。为了构造一个一定可逆的矩阵W,我们可以利用LU分解。因为任意矩阵都可以表达成

W = P L U W=PLU W=PLU

其中P是置换矩阵, L是下三角矩阵,对角线元素全为1,U是上三角矩阵,所以为了保证矩阵W可逆,那么只要保证P, L,U满秩就可以了,又因为P,L一定是满秩的,所以只要保证U满秩即可。那么一个方便的方法就是:

W = P L ( U + d i a g ( s ) ) W=PL( U+diag( s)) W=PL(U+diag(s))

其中,U是严格上三角矩阵,其对角线为0,我们只需保证这个s不为0即可。于是最终我们的可逆卷积核其导数行列式的求解为:

log ⁡ ∣ det ⁡ ( d conv ⁡ 2 D ( h ; W ) d h ) ∣ = h ⋅ w ⋅ log ⁡ ∣ det ⁡ ( W ) ∣ = h ⋅ w ⋅ log ⁡ ∣ d i a g ( s ) ∣ = h ⋅ w ⋅ s u m ( l o g ( ∣ s ∣ ) ) \log\left| \operatorname{det}\left(\frac{d\operatorname{conv} 2\mathrm{D} (\mathbf{h} ;\mathbf{W} )}{d\mathbf{h}}\right)\right| =h\cdot w\cdot \log |\operatorname{det} (\mathbf{W} )|\\ =h\cdot w\cdot \log |diag( s) |=h\cdot w\cdot sum( log( |s|)) logdet(dhdconv2D(h;W))=hwlogdet(W)=hwlogdiag(s)=hwsum(log(s))

这里有个绝对值是因为这个分布变换jacobian的行列式一定是大于0的。实际做的时候,P固定这,只要更新L,U和s的参数就好了。

此外,从实际的角度,如果加了BN后,其实相比加性耦合,仿射耦合效果的提升并不高,所以要训练大型的模型,为了节省资源,一般都只用加性耦合,比如Glow训练256x256的高清人脸生成模型,就只用到了加性耦合。

参考资料

Dinh L, Krueger D, Bengio Y. NICE: Non-linear Independent Components Estimation[J]. 2014, 1(2): 1–13.

Dinh L, Sohl-Dickstein J, Bengio S. Density estimation using Real NVP[J]. 2016.

Kingma D P, Dhariwal P, Francisco S. Glow: Generative Flow with Invertible 1×1 Convolutions[J]. : 1–15.

https://kexue.fm/archives/5807

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/a358463121/article/details/104193164

智能推荐

转-关于图片或者文件在数据库的存储方式归纳_weixin_30323631的博客-程序员ITS301

商品图片,用户上传的头像,其他方面的图片。目前业界存储图片有两种做法:1、 把图片直接以二进制形式存储在数据库中一般数据库提供一个二进制字段来存储二进制数据。比如mysql中有个blob字段。oracle数据库中是blob或bfile类型2、 图片存储在磁盘上,数据库字段中保存的是图片的路径。一、图片以二进制形式直接存储在数据库中第一种存储实现(php语言...

struts1.3 入门级例子_stevenbill的博客-程序员ITS301

[code="java"]//web.xml struts action org.apache.struts.action.ActionServlet config /WEB-INF/classes/struts-config.xml debug 3...

Android 11遇到的小坑,调用系统拍照后,图片回调显示不出来_ygz0915的博客-程序员ITS301

调用系统摄像头,一开始用了 Uri.fromFile(file);Intent intent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);uri = Uri.fromFile(file);intent.putExtra(MediaStore.EXTRA_OUTPUT, uri);startActivityForResult(intent, Activity.DEFAULT_KEYS_DIALER);百度发现需要用这个uri = FileProvid

系统启动流程 - 理解modules加载流程_Hacker_Albert的博客-程序员ITS301_modules.dep

linux 启动流程1.启动过程分为三个部分BIOS 上电自检(POST)引导装载程序 (GRUB2)内核初始化启动 systemd,其是所有进程之父。1.1.BIOS 上电自检(POST)  BIOS stands for Basic Input/Output System. In simple terms, the BIOS loads and executes the Master Boot Record (MBR) boot loader.  When you first .

python 重复元素判定_Python 计算不重复元素的个数_man4acs的博客-程序员ITS301

情景:计算日志文件中,独立IP的个数,也就是unique visitor。计算量:每小时大概有70万左右的记录,每天24小时,大概1400-1500万条记录一开始,想到用一个list来保存客户端IP,从日志里边取出一个IP,判断是否已经存在,如果存在,就忽略,否则添加到这个list中去addrs = []for line in f.readlines() :addr = get_addr(line...

matlab利用已有激光雷达数据寻找地平面和车辆周围的障碍物仿真实验_爱打瞌睡的CV君的博客-程序员ITS301

matlab利用已有激光雷达数据寻找地平面和车辆周围的障碍物仿真实验

随便推点

性能优化专题六--进程保活十种方式_沙漠一只雕得儿得儿的博客-程序员ITS301_进程保活的几种解决方案

LMK(Low Memory Killer)进程被杀死无非就是由于系统内存过低,并且进程的优先级比较低,所以才会被系统kill掉,想要进程保活必须提高进程的优先级。为什么引入LMK?进程的启动分冷启动和热启动,当用户退出某一个进程的时候,并不会真正的将进程退出,而是将这个进程放到后台,以便下次启动的时候可以马上启动起来,这个过程名为热启动,这也是Android的设计理念之一。这个机制会带来一个问题,每个进程都有自己独立的内存地址空间,随着应用打开数量的增多,系统已使用的内存越来越大,就很有可能导

数据的结构特征(结构化数据)与存储系统类型---分布式存储系统的分类_diaoju3333的博客-程序员ITS301

数据的结构特征非结构化数据:包括所有格式的办公文档、文本、图片、图像、音频、视频信息等。结构化数据:一般会存储在关系型数据库中,可用二位关系的表结构来对数据进行描述,数据的模式需要预先进行定义。半结构化数据:介于结构化数据和半结构化数据直接,HTML文档就属于半结构化数据。它一般是自描述的,与结构化数据的最大区别之处在于,半结构化的数据模式和内容混在一起,没有明显的界限和区分。根...

NIO原理剖析与Netty初步----浅谈高性能服务器开发(一)_止戈(Frank)的博客-程序员ITS301_nio的复杂性

在博主不长的工作经历中,NIO用的并不多,由于使用原生的Java NIO编程的复杂性,大多数时候我们会选择Netty,mina等开源框架,但理解NIO的原理就不重要了吗?恰恰相反,理解NIO底层机制是理解这一切的基础,由此我总结一下当初学习NIO时的笔记,以便后续复习。     以下是我理解的Java原生NIO开发大致流程:           上图大致描述的是服务端的NIO操作。...

java ArrayList数组_路没有尽头的博客-程序员ITS301

@ljsArrayList#ArrayList 该类也是实现了List的接口,实现了可变大小的数组,随机访问和遍历元素时,提供更好的性能。该类也是非同步的,在多线程的情况下不要使用。ArrayList 增长当前长度的50%,插入删除效率低。package first;import java.util.*;public class TestArrayList { public st...

RabbitMQ可视化界面登录不了,报错:Login failed_rabbitmq登录失败_zhoupenghui168的博客-程序员ITS301

​使用docker_lnmp安装了php环境,以及对应的rabbitmq扩展,登录时却登录失败,错误信息:{"error":"not_authorised","reason":"Login failed"}

【AKOJ】1024-A+B Problem_dearvee的博客-程序员ITS301

A+B ProblemTime Limit:1000MS  Memory Limit:65536KTotal Submit:308 Accepted:181原题链接DescriptionA+B Problem 水题水题, 没有最水,只有更水。Input连续输入两个整数 a、 b,当a=0且b=0时结束输入Output输出a+b的值,并换行。Sam

推荐文章

热门文章

相关标签