对抗生成网络GAN应用场景
有一个原始的输入,想按照我的一个要求或者按照我的一个目标尽可能完美的生成我想要的东西,比如根据夏天的图片生成冬天的图片,根据人脸图片随机生成一个人脸。
超分辨率重构
告诉网络什么样的是低分辨率,什么样的是高分辨率,设计一个损失函数,学一学两者之间的联系,就可以根据低分辨率生成高分辨率的图片了或者根据侧脸生成全脸图片。
对抗生成网络定义
- 生成 :想要什么生成什么
- 对抗:用最强的矛捅最硬的盾
罪犯印假钱,罪犯希望印出来的假币越真越好,直到能够骗过警察;而警察希望能够分辨真假。
随机的变量(Random Vector)比如100维的向量,这是噪音变量或垃圾变量,然后经过生成网络(Generator Network)进行转换得到假的数据,生成网络希望它生成的假的数据能够骗过判决器(Discriminator Network)的识别;将真实的数据和假的数据都输入到判决器网络中,让判决器能够识别真假。判决网络把生成网络生成的数据识别为假的,把真实数据识别成真的。
- 生成网络可以用传统的神经网络去做,比如输入28x28x1=748个像素点的图片,最终生成的结果也是28x28x1=748个像素点的图片
- 也可以用卷积神经网络来做(生成图像数据的话,最终生成一个特征图)
特征图(H W C)C=3就会让网络往真实图片上靠拢
即定义生成的结果是什么样的,让网络去学习怎样生成这样的。无论网络怎么定义,最终要做成一件什么事情,是由损失函数决定的,损失函数决定了整个网络的走势。
导入类库
定义损失函数
损失函数2个参数,一个是预测值,一个是标签值。
定义随机的输入数据
这些输入数据就当作是预测结果了。
把预测结果全部传入Sigmoid函数中去,映射为0-1之间的值,满足了损失函数对于预测参数的要求。
所有的预测结果必须映射到0-1的范围当中。
判决网络其实是一个二分类网络,做的好不好,就是看真的是不是判断为真的,假的判断为假的即0/1问题
上述为定义真实的标签值。有了预测值,有了标签数据,看损失值是怎么计算的
损失值计算公式
-
t[i]是概率值即预测的概率值
-
log(o[i]) 对数
概率值 x 对数 + (1-概率值)x 对数
代码实现损失函数公式
- 第一个值0/1 就是实际的标签值即真实的类别,1-0/1是错误判断的类别
- 共有9个样本(真实值),计算每个样本的损失值
- 0-1之间的对数值都是负的,前面加个-号,使得变为正数
相当于9个样本的损失值加在一起,除以9,求平均值
BCELoss
- loss函数第一个参数是预测值,第二个参数是真实值
- m是sigmoid函数,将预测值映射到0-1之间
计算结果是手动计算是一样的,
BCEWithLogitsLoss
构建GAN网络
损失函数定义好之后,构建GAN网络
-
构建生成器网络
比如输入28x28x1=784个像素点的图片,这里没有用卷积,而是用最简单最基本的全连接做的。
in_feat=100 表示输入100维的向量,out_feat=128 表示中间隐层是128个神经元,128个特征;然后加上Relu激活函数
第一个block,输入100维向量,输出128维的向量;第二个block 输入128转换成256;第三个block输入256转换成512;第四个block输入512转换成1024;要得到的特征数需跟输入一致的,输入是28x28x1=748个特征,所以需要将1024转换成748。
- 构建判决器网络
第一层是全连接层。判断一张图片的真假,输入784个像素点的图片(生成网络生成的图片和原始图片是一样的,都是784个像素点),经过几个全连接层(512、256)和激活函数,最终得到一个预测值,再把预测值传入sigmoid函数,将预测概率映射到0-1之间。
使用gpu计算,速度更快,100个epoch,几分钟就训练完了。
数据预处理
如果数据之前没有下载,就先进行下载,然后进行数据预处理这些常规操作。
定义优化器
训练的过程
针对一个batch的数据,定义真假标签
-
真实数据标签定义为1
-
假数据标签为0
实际数据
这是一个4维的数据: batch size * channel * h * w
- zero_grad 表示 梯度清零
z表示随机构建的100维的向量
- imgs.shape[0]表示batch的个数
- opt.latent_dim=100 表示100维的向量
- random.normal是随机的高斯分布 均值为0 标准差为1
有多少个batch就生成多少个随机向量
-
生成网络
通过生成网络把100维的向量转换成784维的
生成器生成的数据让判决器认为是真的
第一个参数是预测值,第二个参数是真实值(标签值)。传入的标签值都是1,即告诉判决器,我生成器生成的数据是真的(生成器希望能够骗到判决器,逃脱判决器的法眼)
- 判决器认为生成器生成的数据是假的,而生成器希望能够骗过判决器的法眼,让判决器认为它生成的数据是真的
- 判决器认为真实数据是真的
-
real_imgs
读进来的实际的数据
判决器得有能力将实际数据认为是真数据
判决器得有能力识别生成的数据是假数据
- loss平均值=(真数据损失值+假数据损失值)/2
将GAN生成的数据保存下来
这是100个epoch生成的结果
生成的效果和MNIST数据集手写数字差不多了,这还是用的最简单最基本的方法仅迭代100次、全连接网络的对抗生成网络GAN。