首页IT科技基于深度学习的图片上色

基于深度学习的图片上色

时间2025-08-01 08:51:04分类IT科技浏览4541
导读:如果有不懂的,欢迎下方评论,你还在为毕设课设烦恼吗?注意下方图片右下角水印,解决一切问题,欢迎咨询。...

如果有不懂的               ,欢迎下方评论                      ,你还在为毕设课设烦恼吗?注意下方图片右下角水印       ,解决一切问题       ,欢迎咨询               。

1. 前言

本文基于pytorch和opencv使用生成对抗网络对灰度图像自动上色                      ,然后可以对上色后的图片手动调节亮度对比度等信息              ,最后可以保存上色后的图像       ,闲话少说                      ,先看一下效果              ,文章最后附有全部代码及数据集下载链接                      。

灰度图自动上色

b站视频地址:b站视频地址

2.图像格式(RGB,HSV                      ,Lab)

2.1 RGB

想要对灰度图片上色                     ,首先要了解图像的格式,对于一副普通的图像通常为RGB格式的               ,即红              、绿                      、蓝三个通道                     ,可以使用opencv分离图像的三个通道       ,代码如下所示:

import cv2 img=cv2.imread(pic/7.jpg) B,G,R=cv2.split(img) cv2.imshow(img,img) cv2.imshow(B,B) cv2.imshow(G,G) cv2.imshow(R,R) cv2.waitKey(0)

代码运行结果如下所示       。

2.2 hsv

hsv是图像的另一种格式               ,其中h代表图像的色调                      ,s代表饱和度       ,v代表图像亮度       ,可以通过调节h        、s              、v的值来改变图像的色调                     、饱和度        、亮度等信息       。

同样可以使用opencv将图像从RGB格式转换成hsv格式                      。然后可以分离h       、s                     、v三个通道并显示图像代码如下所示: import cv2 img=cv2.imread(pic/7.jpg) hsv=cv2.cvtColor(img,cv2.COLOR_BGR2HSV) h,s,v=cv2.split(hsv) cv2.imshow(hsv,hsv) cv2.imshow(h,h) cv2.imshow(s,s) cv2.imshow(v,v) cv2.waitKey(0)

运行结果如下所示:

2.3 Lab

Lab是图像的另一种格式                      ,也是本文使用的格式              ,其中L代表灰度图像       ,a               、b代表颜色通道                      ,本文使用L通道灰度图作为输入              ,ab两个颜色通道作为输出,训练生成对抗网络                      ,将图像由RGB格式转换成Lab格式的代码如下所示:

import cv2 img=cv2.imread(pic/7.jpg) Lab=cv2.cvtColor(img,cv2.COLOR_BGR2Lab) L,a,b=cv2.split(Lab) cv2.imshow(Lab,Lab) cv2.imshow(L,L) cv2.imshow(a,a) cv2.imshow(b,b) cv2.waitKey(0)

3. 生成对抗网络(GAN)

生成对抗网络主要包含两部分                     ,分别是生成网络和判别网络              。 生成网络负责生成图像,判别网络负责鉴定生成图像的好坏               ,二者相辅相成                     ,相互博弈       。 本文使用U-net作为生成网络       ,使用ResNet18作为判别网络                      。U-net网络的结构图如下所示:

3.1 生成网络(Unet)

pytorch构建unet网络的代码如下所示:

class DownsampleLayer(nn.Module): def __init__(self,in_ch,out_ch): super(DownsampleLayer, self).__init__() self.Conv_BN_ReLU_2=nn.Sequential( nn.Conv2d(in_channels=in_ch,out_channels=out_ch,kernel_size=3,stride=1,padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(), nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1,padding=1), nn.BatchNorm2d(out_ch), nn.ReLU() ) self.downsample=nn.Sequential( nn.Conv2d(in_channels=out_ch,out_channels=out_ch,kernel_size=3,stride=2,padding=1), nn.BatchNorm2d(out_ch), nn.ReLU() ) def forward(self,x): """ :param x: :return: out输出到深层               ,out_2输入到下一层                      , """ out=self.Conv_BN_ReLU_2(x) out_2=self.downsample(out) return out,out_2 class UpSampleLayer(nn.Module): def __init__(self,in_ch,out_ch): # 512-1024-512 # 1024-512-256 # 512-256-128 # 256-128-64 super(UpSampleLayer, self).__init__() self.Conv_BN_ReLU_2 = nn.Sequential( nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1), nn.BatchNorm2d(out_ch*2), nn.ReLU(), nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1), nn.BatchNorm2d(out_ch*2), nn.ReLU() ) self.upsample=nn.Sequential( nn.ConvTranspose2d(in_channels=out_ch*2,out_channels=out_ch,kernel_size=3,stride=2,padding=1,output_padding=1), nn.BatchNorm2d(out_ch), nn.ReLU() ) def forward(self,x,out): :param x: 输入卷积层 :param out:与上采样层进行cat :return: x_out=self.Conv_BN_ReLU_2(x) x_out=self.upsample(x_out) cat_out=torch.cat((x_out,out),dim=1) return cat_out class UNet(nn.Module): def __init__(self): super(UNet, self).__init__() out_channels=[2**(i+6) for i in range(5)] #[64, 128, 256, 512, 1024] #下采样 self.d1=DownsampleLayer(3,out_channels[0])#3-64 self.d2=DownsampleLayer(out_channels[0],out_channels[1])#64-128 self.d3=DownsampleLayer(out_channels[1],out_channels[2])#128-256 self.d4=DownsampleLayer(out_channels[2],out_channels[3])#256-512 #上采样 self.u1=UpSampleLayer(out_channels[3],out_channels[3])#512-1024-512 self.u2=UpSampleLayer(out_channels[4],out_channels[2])#1024-512-256 self.u3=UpSampleLayer(out_channels[3],out_channels[1])#512-256-128 self.u4=UpSampleLayer(out_channels[2],out_channels[0])#256-128-64 #输出 self.o=nn.Sequential( nn.Conv2d(out_channels[1],out_channels[0],kernel_size=3,stride=1,padding=1), nn.BatchNorm2d(out_channels[0]), nn.ReLU(), nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channels[0]), nn.ReLU(), nn.Conv2d(out_channels[0],3,3,1,1), nn.Sigmoid(), # BCELoss ) def forward(self,x): out_1,out1=self.d1(x) out_2,out2=self.d2(out1) out_3,out3=self.d3(out2) out_4,out4=self.d4(out3) out5=self.u1(out4,out_4) out6=self.u2(out5,out_3) out7=self.u3(out6,out_2) out8=self.u4(out7,out_1) out=self.o(out8) return out

3.2 判别网络(resnet18)

resnet18的结构图如下所示:

在pytorch内部自带resnet18模型       ,只需一行代码即可构建resnet18模型       ,然后还需要去除网络最后的全连接层                      ,代码如下所示: from torchvision import models resnet18=models.resnet18(pretrained=False) del resnet18.fc print(resnet18)

4. 数据集

本文使用的是自然风景类的数据图片              ,在网站上爬取了大概1000多张数据图片       ,部分图片如下所示

5. 模型训练与预测流程图

5.1 训练流程图

如下图所示                      ,首先将RGB图像转换成Lab图像              ,然后将L通道作为生成网络输入,生成网络的输出为新的ab两通道                      ,然后将图像原始的ab通道                     ,与生成网络生成的ab通道输入判别网络中              。

5.2 预测流程图

下图为模型的预测过程,在预测过程中判别网络已经没有作用了               ,首先将RGB图像转换成                     ,Lab图像       ,接着将L灰度图输入生成网络可以得到新的ab通道图像               ,接着将L通道图像与生成的ab通道图像进行拼接(concate),拼接以后可以得到一张新的Lab图像                      ,然后再将其转换成RGB格式       ,此时图像即为上色以后的图像。

6. 模型预测效果

下图为模型的预测效果                      。左侧的为灰度图像       ,中间的为原始的彩色图像                      ,右侧的是模型上色以后的图像                     。整体上看              ,网络的上色效果还不错。

7. GUI界面制作

为了更加方便使用模型       ,本文使用pyqt5制作操作界面                      ,其界面如下图所示:首先可以从电脑中加载图像              ,还可以切换上一张或者下一张,可以将图像灰度化显示               。可以对其上色                      ,然后可以调整上色后图像的H       、S                     、V信息                     ,最后支持图像导出,可以将上色后的图像保存到本地中                     。

8.代码下载

链接中包含了训练代码               ,测试代码                     ,以及界面代码       。此外还包含1000多张数据集       ,直接运行main.py程序即可弹出操作界面               。

代码及数据集下载链接
声明:本站所有文章               ,如无特殊说明或标注                      ,均为本站原创发布                      。任何个人或组织       ,在未征得本站同意时       ,禁止复制               、盗用、采集                     、发布本站内容到任何网站                      、书籍等各类媒体平台       。如若本站内容侵犯了原著者的合法权益                      ,可联系我们进行处理       。

创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!

展开全文READ MORE
css动画定义(巧用 CSS 变量,实现动画函数复用,制作高级感拉满的网格动画) 拦截器 handler参数(Sa-Token v.1.31.0 新增拦截器 SaInterceptor 功能说明,以及旧代码迁移示例)