秀站网,秀执着,秀梦想,一个爱秀的地方!

自媒体资讯网

热门关键词:  as  i hone  阿萨德  网站
在Keras上实现GAN:构建消除图片模糊的应用
来源:
作者:
时间:2018-04-19
浏览热度:
#评论#
[ 导读 ] 2014 年,Ian Goodfellow 提出了生成对抗网络(GAN),今天,GAN 已经成为深度学习最热门的方向之一。本文将重点介绍如何利用 Keras 将 GAN 应用于图像去模糊(image deblurring)任务当中。
2014 年,Ian Goodfellow 提出了生成对抗网络(GAN),今天,GAN 已经成为深度学习最热门的方向之一。本文将重点介绍如何利用 Keras 将 GAN 应用于图像去模糊(image deblurring)任务当中。

Keras 代码地址:https://github.com/RaphaelMeudec/deblur-gan

此外,请查阅 DeblurGAN 的原始论文(https://arxiv.org/pdf/1711.07064.pdf)及其 Pytorch 版本实现:https://github.com/KupynOrest/DeblurGAN/。

生成对抗网络简介

在生成对抗网络中,有两个网络互相进行训练。生成器通过生成逼真的虚假输入来误导判别器,而判别器会分辨输入是真实的还是人造的。

 

在Keras上实现GAN:构建消除图片模糊的应用

 

GAN 训练流程

训练过程中有三个关键步骤:

使用生成器根据噪声创造虚假输入;
利用真实输入和虚假输入训练判别器;
训练整个模型:该模型是判别器和生成器连接所构建的。

请注意,判别器的权重在第三步中被冻结。

对两个网络进行连接的原因是不存在单独对生成器输出的反馈。我们唯一的衡量标准是判别器是否能接受生成的样本。

以上,我们简要介绍了 GAN 的架构。如果你觉得不够详尽,可以参考这篇优秀的介绍:生成对抗网络初学入门:一文读懂 GAN 的基本原理(附资源)。

数据

Ian Goodfellow 首先应用 GAN 模型生成 MNIST 数据。而在本教程中,我们将生成对抗网络应用于图像去模糊。因此,生成器的输入不是噪声,而是模糊的图像。

我们采用的数据集是 GOPRO 数据集。该数据集包含来自多个街景的人工模糊图像。根据场景的不同,该数据集在不同子文件夹中分类。

你可以下载简单版:https://drive.google.com/file/d/1H0PIXvJH4c40pk7ou6nAwoxuR4Qh_Sa2/view

或完整版:https://drive.google.com/file/d/1SlURvdQsokgsoyTosAaELc4zRjQz9T2U/view

我们首先将图像分配到两个文件夹 A(模糊)B(清晰)中。这种 A&B 的架构对应于原始的 pix2pix 论文。为此我创建了一个自定义的脚本在 github 中执行这个任务,请按照 README 的说明去使用它:

https://github.com/RaphaelMeudec/deblur-gan/blob/master/organize_gopro_dataset.py

模型

训练过程保持不变。首先,让我们看看神经网络的架构吧!

生成器

该生成器旨在重现清晰的图像。该网络基于 ResNet 模块,它不断地追踪关于原始模糊图像的演变。本文同样使用了一个基于 UNet 的版本,但我还没有实现这个版本。这两种模块应该都适合图像去模糊。

 

在Keras上实现GAN:构建消除图片模糊的应用

 

DeblurGAN 生成器网络架构,源论文《DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks》。

其核心是应用于原始图像上采样的 9 个 ResNet 模块。让我们来看看 Keras 上的代码实现!

from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout

def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
"""
Instanciate a Keras Resnet Block using sequential API.
:param input: Input tensor
:param filters: Number of filters to use
:param kernel_size: Shape of the kernel for the convolution
:param strides: Shape of the strides for the convolution
:param use_dropout: Boolean value to determine the use of dropout
:return: Keras Model
"""
x = ReflectionPadding2D((1,1))(input)
x = Conv2D(filters=filters,
kernel_size=kernel_size,
strides=strides,)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

if use_dropout:
x = Dropout(0.5)(x)

x = ReflectionPadding2D((1,1))(x)
x = Conv2D(filters=filters,
kernel_size=kernel_size,
strides=strides,)(x)
x = BatchNormalization()(x)

# Two convolution layers followed by a direct connection between input and output
merged = Add()([input, x])
return merged

该 ResNet 层基本是卷积层,其输入和输出都被添加以形成最终的输出。

from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9


def generator_model():
"""Build generator architecture."""
# Current version : ResNet block
inputs = Input(shape=image_shape)

x = ReflectionPadding2D((3, 3))(inputs)
x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

# Increase filter number
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

# Apply 9 ResNet blocks
mult = 2**n_downsampling
for i in range(n_blocks_gen):
x = res_block(x, ngf*mult, use_dropout=True)

相关文档:

客户是否会点击你的广告?机器学..

如何应对机器学习模型的一致性风..

全方位对比深度学习和经典机器学..