CycleGAN模型原理
前面我们了解了好几种GAN,它们大致可分为:
随机生成模型的GAN,包括GAN,DCGAN,WGAN,WGAN-GP等;
带条件生成模型的GAN,包括CGAN,InfoGAN,ACGAN等。
它们都是监督学习模型,即生成网络都有一个目标样本集。
除了监督学习模型,还有一类非监督学习模型,比如CycleGAN。
CycleGAN如下图所示,它能在油画到相片互相生成;马到斑马互相生成;夏天到冬天季节相互变化。它们都不是要把当前域的样本拟合到另一个域。
即:
1 X域所有图片的共性:油画 -> Y域所有图片的共性:真实照片,当前主要特征画内容的线条保持不变;
2 X域所有图片的共性:马的纹理 -> Y域所有图片的共性:斑马纹理,当前马的轮廓不会变;
3 X域所有图片的共性:夏天风景 -> Y域所有图片的共性:冬天风景,当前的河流,树木轮廓都不会变。
怎么才能达到上面的要求呢?以马到斑马为例:
1 在X马的样本里面,我们能找到它的公共特性就是马那黄色填充的毛,在Y斑马的样本里面,公共特性就是那条纹的着色填充。如果要把X拟合成Y,把Y拟合成X,那么很容易想到用GAN既可以解决,即用最小化“原始GAN的损失”来达到。
2 但是在步骤1完成后,我们会发现,马变成斑马后再也不像以前那匹马了。对人来说,感受物体最主要的特征就是物体的轮廓,所以我们采用最小化“Cycle一致性损失”,即对于单一的当前样本来说,X转成Y,然后再转回来X1,必须保持轮廓大致不变,这样人就会觉得物体没有变化。但对于计算机来说,它能看到的主要特征是大面积的着色填充,在X转成Y,或者Y再转回来的的过程中,着色填充确实做到了X到Y,Y到X的拟合。这样步骤1和步骤2就完成了计算机和人的双重欺骗。
3 怎么还有步骤3呢?这是论文后来加上的。在油画到真实照片的转换中,或反之,我们不希望它们的转换对自身产生作用,即X转成X1,Y转成Y1,不希望转换做任何改变,直接返回就可以了。论文中用最小化“Identity映射损失”来达到。
综上所述,就是用大面积的着色来骗计算机,用对人眼敏感的轮廓来骗人。这个理论也可以用在CGAN训练MNIST上,高斯随机噪声加数字输入,数字可以看作轮廓,噪声里面隐含了计算机能看到的着色。
损失函数组成
根据上面的分析,整个CycleGAN的目标损失函数由3部分组成:
1 原始GAN的损失:X通过生成网络G生成Y,到鉴别网络DY的输出;Y通过生成网络F生成X,到鉴别网络DX的输出。
如代码中:
a X到Y的鉴别网络loss(dB_loss_fake),生成网络loss(g_AB_loss):
img_A -> g_AB(img_A) -> fake_B -> d_B(fake_B) -> valid_B : 输入img_A,输出valid_B,鉴别网络d_B目标fake,生成网络g_AB目标valid;
b 真实Y的鉴别网络loss(dB_loss_real):
img_B-> d_B(img_B) -> valid_B : 输入img_B,输出valid_B,鉴别网络d_B目标valid;
c Y到X的鉴别网络loss(dA_loss_fake),生成网络loss(g_BA_loss):
img_B -> g_BA(img_B) -> fake_A -> d_A(fake_A) -> valid_A : 输入img_B,输出valid_A,鉴别网络d_A目标fake,生成网络g_BA目标valid;
d 真实X的鉴别网络loss(dA_loss_real):
img_A-> d_A(img_A) -> valid_A : 输入img_A,输出valid_A,鉴别网络d_A目标valid。
2 Cycle一致性损失:X通过生成网络G生成Y,然后再通过生成网络F回到X1,X到X1的损失。网络的目标是保证这种损失尽量小,即能还原X。
如代码中:
a 生成网络loss(g_AB_BA_loss):
img_A -> g_AB(img_A) -> fake_B -> g_BA(fake_B) -> reconstr_A :输入img_A,输出reconstr_A,生成网络g_AB -> g_BA目标imgs_A;
b 生成网络loss(g_BA_AB_loss):
img_B -> g_BA(img_B) -> fake_A -> g_AB(fake_A) -> reconstr_B :输入img_B,输出reconstr_B,生成网络g_BA -> g_AB目标imgs_B。
3 Identity映射损失:X通过网络F生成X1,网络需要保证X到X1几乎没有改变,即X域到X域的转变不需要修改,最大保留X的特性。
如代码中:
a 生成网络loss(g_BA_Ident_loss):
img_A -> g_BA -> img_A_id :输入img_A,输出img_A_id,目标imgs_A;
b 生成网络loss(g_AB_Ident_loss):
img_B -> g_AB -> img_B_id :输入img_B,输出img_B_id,目标img_B。
综上,鉴别网络的loss为:dA_loss_real,dA_loss_fake,dB_loss_real,dB_loss_fake;生成网络的loss为:g_AB_loss,g_BA_loss,g_AB_BA_loss,g_BA_AB_loss,g_AB_Ident_loss,g_BA_Ident_loss。