昇思25天学习打卡营第5天|GAN图像生成

文章目录

      • 昇思MindSpore应用实践
        • 基于MindSpore的生成对抗网络图像生成
          • 1、生成对抗网络简介
            • 零和博弈 vs 极大极小博弈
            • GAN的生成对抗损失
          • 2、基于MindSpore的 Vanilla GAN
          • 3、基于MindSpore的手写数字图像生成
            • 导入数据
            • 数据可视化
            • 模型训练
      • Reference

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

基于MindSpore的生成对抗网络图像生成
1、生成对抗网络简介
零和博弈 vs 极大极小博弈

生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)
这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)

零和博弈(Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。

极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。

假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。

在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布;
  2. 判别器通过求取梯度和损失函数对网络进行优化,将接近真实数据分布的数据判定为1 D ( x ) = 1 D(x)=1 D(x)=1),将接近生成器生成数据分布数据判定为0(( G ( z ) = 0 G(z)=0 G(z)=0)),即希望 min ⁡ G max ⁡ D V ( G , D ) \underset{G}{\min} \underset{D}{\max}V(G, D) GminDmaxV(G,D)
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据;
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失

min ⁡ G max ⁡ D V ( G , D ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。

当生成对抗损失达到纳什均衡时,判别器对真假数据的判别概率都是0.5,即 D ( x ) = 1 − G ( z ) = 0.5 D(x)=1-G(z)=0.5 D(x)=1G(z)=0.5

l o g ( D ( x ) ) = l o g ( 1 − G ( z ) ) ≈ 0.693 log(D(x))=log(1-G(z))\approx0.693 log(D(x))=log(1G(z))0.693

由于数据x和G(z)不仅是一张图片,再分别取两者的均值 E \mathbb{E} E,相加,就得到了生成对抗损失。

近十年来著名的GAN网络结构:
在这里插入图片描述

2、基于MindSpore的 Vanilla GAN

生成器部分:生成器 Generator 的功能是将隐码映射到数据空间。通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内,并返回一张28x28的图像作为生成结果。

from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)28x28

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,通过第一层线性变换将其映射到128维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # 通过第二层线性变换将其映射到256维
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器部分:判别器 Discriminator 是一个二分类网络模型,在训练时,判别器接收生成器的生成图像与对应的真实数据相对比,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。

 # 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')
3、基于MindSpore的手写数字图像生成
导入数据
import numpy as np
import mindspore.dataset as ds

batch_size = 128
latent_size = 100  # 潜在编码的长度

train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')

def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False)
    
    # 数据增强
    mnist_ds = dataset1.map(  # 通过map方法给每张图像映射一个潜在编码
    # 将图像数据转换为 float32 类型
    # 生成一个长度为 latent_size 的服从正态分布的随机向量,并将其转换为 float32 类型
        operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
        output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])

    # 批量操作
    mnist_ds = mnist_ds.batch(batch_size, True)

    return mnist_ds

mnist_ds = data_load(train_dataset)

iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
数据可视化
import matplotlib.pyplot as plt

data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):
    image = data_iter['image'][idx]
    figure.add_subplot(rows, cols, idx)
    plt.axis("off")
    plt.imshow(image.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述
潜在编码(latent code)的构造:
为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise输入到生成器中,通过这组固定的潜在编码(也叫隐码)所生成的图像效果来评估生成器的生成质量。

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype

# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
模型训练

定义损失函数和优化器:

lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

训练分为两个主要部分,也就是要训练两个网络:生成与对抗网络。

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 l o g D ( x ) + l o g ( 1 − D ( G ( z ) ) log D(x) + log(1 - D(G(z)) logD(x)+log(1D(G(z)) 的值。

第二部分是训练生成器。如论文所述,最小化 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z))) 来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将固定隐码批量推送到生成器中,以直观地跟踪生成器 Generator 的训练效果。

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint

total_epoch = 24  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小

# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'

checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径

# 生成器计算损失过程
def generator_forward(test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
    return loss_g

# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    real_out = net_d(real_data)
    real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
    fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
    loss_d = real_loss + fake_loss
    return loss_d

# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())

def train_step(real_data, latent_code):
    # 计算判别器损失和梯度
    loss_d, grads_d = grad_d(real_data, latent_code)
    optimizer_d(grads_d)
    loss_g, grads_g = grad_g(latent_code)
    optimizer_g(grads_g)

    return loss_d, loss_g

# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
    for i3 in range(gen_imgs1.shape[0]):
        plt.subplot(5, 5, i3 + 1)
        plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
        plt.axis("off")
    plt.savefig(image_path + "/test_{}.png".format(idx))

# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)

net_g.set_train()
net_d.set_train()

# 储存生成器和判别器loss
losses_g, losses_d = [], []

for epoch in range(total_epoch):
    start = time.time()
    for (iter, data) in enumerate(mnist_ds):
        start1 = time.time()
        image, latent_code = data
        image = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]
        image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
        d_loss, g_loss = train_step(image, latent_code)
        end1 = time.time()
        if iter % 10 == 0:
            print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
                  f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
                  f"loss_d:{d_loss.asnumpy():>4f} , "
                  f"loss_g:{g_loss.asnumpy():>4f} , "
                  f"time:{(end1 - start1):>3f}s, "
                  f"lr:{lr:>6f}")

    end = time.time()
    print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))

    losses_d.append(d_loss.asnumpy())
    losses_g.append(g_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    gen_imgs = net_g(test_noise)
    save_imgs(gen_imgs.asnumpy(), epoch)

    # 根据epoch保存模型权重文件
    if epoch % 1 == 0:
        save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
        save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'Wayn_Fan-sail')

使用cpu进行12个epoch的生成效果如下:
在这里插入图片描述

Reference

昇思官方文档-GAN图像生成
昇思大模型平台

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/777367.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

01 企业网站架构部署于优化之Web基础与HTTP协议

目录 1.1 Web基础 1.1.1 域名和DNS 1. 域名的概念 2. Hosts文件 3. DNS 4. 域名注册 1.1.2 网页与HTML 1. 网页概述 2. HTML概述 3. HTML基本标签 4. 网站和主页 5. Web1.0与Web2.0 1.1.3 静态网页与动态网页 1. 静态网页 2. 动态网页 3. 动态网页语言 1.2 HTTP协议 1…

如何快速开展每日待办工作 待办任务高效管理

每天,我们都需要处理大量的待办工作,如何高效有序地开展这些工作成为了我们必须要面对的问题。仅仅依靠个人的记忆和脑力去管理这些繁杂的事务,显然是一项艰巨的挑战。在这个时候,如果能有一款实用的待办工具来辅助我们&#xff0…

7年跨境业务资深从业者的代理IP使用经验分享

作为一个拥有7年跨境业务经验的资深从业者,今天大家分享一下在使用各种代理IP服务中的宝贵经验。无论你是新手还是老手,这篇文章都能为你带来一些新的启发和实用技巧。 在我刚开始跨境业务的那几年,最大的挑战之一就是如何跨境访问&#xff0…

ORB 特征点提取

FAST关键点 选取像素p,假设它的亮度为Ip; . 设置一个阈值T(比如Ip的20%); 以像素p为中心,选取半径为3的圆上的16个像素点; 假如选取的圆上,有连续的N个点的亮度大于IpT或小于…

CSS实现图片裁剪居中(只截取剪裁图片中间部分,图片不变形)

1.第一种方式:(直接给图片设置:object-fit:cover;) .imgbox{width: 100%;height:200px;overflow: hidden;position: relative;img{width: 100%;height: 100%; //图片要设置高度display: block;position: absolute;left: 0;right…

【Python基础篇】你了解python中运算符吗

文章目录 1. 算数运算符1.1 //整除1.2 %取模1.3 **幂 2. 赋值运算符3. 位运算符3.1 &&#xff08;按位与&#xff09;3.2 |&#xff08;按位或&#xff09;3.3 ^&#xff08;按位异或&#xff09;3.4 ~&#xff08;按位取反&#xff09;3.5 <<&#xff08;左移&#…

【JavaWeb程序设计】JSP编程II

目录 一、输入并运行下面的import_test.jsp页面 1.1 代码运行结果 1.2 修改编码之后的运行结果 二、errorPage属性和isErrorPage属性的使用 2.1 下面的hello.jsp页面执行时将抛出一个异常&#xff0c;它指定了错误处理页面为errorHandler.jsp。 2.1.2 运行截图 2.2 下面…

压测工具---Ultron

压测工具&#xff1a;Ultron 类型&#xff1a;接口级和全链路 接口级 对于接口级别的压测我们可以进行 http接口压测、thrift压测、redis压测、kafka压测、DDMQ压测、MySQL压测等&#xff0c;选对对应的业务线、选择好压测执行的时间和轮数就可以执行压测操作了 全链路 对…

Java新特性梳理——Java15

highlight: xcode theme: vuepress 概述 2020 年 9 月 15 日&#xff0c;Java 15 正式发布&#xff0c;(风平浪静的一个版本)共有 14 个 JEP&#xff0c;是时间驱动形式发布的第六个版本。相关文档&#xff1a;https://openjdk.java.net/projects/jdk/15/ 语法层面变化 密封类 …

【机器学习】基于密度的聚类算法:DBSCAN详解

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 基于密度的聚类算法&#xff1a;DBSCAN详解引言DBSCAN的基本概念点的分类聚类过…

JVM原理(十七):JVM虚拟机即时编译器详解

编译器无论在何时、在何种状态下把Class文件转换成与本地基础设施相关的二进制机器码&#xff0c;他都可以视为整个编译过程的后端。 后端编译器编译性能的好坏、代码优化质量的高低却是衡量一款商用虛拟机优秀与否的关键指标之一。 1. 即时编译器 即时编译器是一个把Java的…

19.【C语言】初识指针(重难点)

内存&#xff1a;所有程序的运行在内存中 用Cheat Engine查看任意程序的内存(16进制&#xff09;&#xff1a; 显示大量的数据 想要定位某个数字 &#xff0c;需要知道地址(类比二维坐标) 如F8的地址为00BCB90008,所以是00BCB908(偏移) ctrlG 则有 内存单元的说明&#xff1…

动态颤抖的眼睛效果404页面源码

动态颤抖的眼睛效果404页面源码&#xff0c; 源码由HTMLCSSJS组成&#xff0c;记事本打开源码文件可以进行内容文字之类的修改&#xff0c;双击html文件可以本地运行效果&#xff0c;也可以上传到服务器里面&#xff0c;重定向这个界面 动态颤抖的眼睛效果404页面源码

Portainer 是一个开源的容器管理平台-非常直观好用的Docker图形化项目

在这个容器化技术大行其道的时代&#xff0c;Docker和Kubernetes几乎成了技术圈的新宠。可是管理起容器来&#xff0c;有时候还是有点头大。命令行操作对于某些小伙伴来说&#xff0c;可能还是有点不太友好。 今天开源君分享一个叫 Portainer 的开源项目&#xff0c;一个用来简…

AI大模型时代的存储发展趋势

从2022年下半年&#xff0c;大模型和AIGC这两个词变得极其火热&#xff0c;而GPU的市场也是一卡难求。对于这种迷乱和火热&#xff0c;让我想起了当年的比特币挖矿和IPFS。似乎世界一年一个新风口&#xff0c;比特币、元宇宙、NFT、AIGC&#xff0c;金钱永不眠&#xff0c;IT炒…

【React】React18 Hooks 之 useReducer

目录 useReducer案例1&#xff1a;useReducer不带初始化函数案例2&#xff1a;useReducer带初始化函数注意事项1&#xff1a;dispatch函数不会改变正在运行的代码的状态注意事项2&#xff1a;获取dispatch函数触发后 JavaScript 变量的值注意事项3&#xff1a;触发了reducer&am…

【MotionCap】pycharm 远程在wsl2 ubuntu20.04中root的miniconda3环境

pycharm wsl2 链接到pycharmsbin 都能看到内容,/root 下内容赋予了zhangbin 所有,pycharm还是看不到/root 下内容。sudo 安装了miniconda3 引发了这些问题 由于是在 root 用户安装的miniconda3 所以安装路径在/root/miniconda3 里 这导致了环境也是root用户的,会触发告警 WA…

Xilinx原语

1. 原语介绍 原语是 Xilinx 器件底层硬件中的功能模块&#xff0c;它使用专用的资源来实现一系列的功能。相比于 IP 核&#xff0c;原语的调用方法更简单&#xff0c;但是一般只用于实现一些简单的功能。本章主要用到了 BUFG、 BUFIO、 IDDR、 ODDR、IDELAYE2 和 IDELAYCTRL。…

14-29 剑和诗人3 – 利用知识图谱增强 LLM 推理能力

知识图谱提供了一种结构化的方式来表示现实世界的事实及其关系。通过将知识图谱整合到大型语言模型中&#xff0c;我们可以增强它们的事实知识和推理能力。让我们探索如何实现这一点。 知识图谱构建 在利用知识图谱进行语言模型增强之前&#xff0c;我们需要从可靠的来源构建…

AIGC | 为机器学习工作站安装NVIDIA 4070 Ti Super显卡驱动

[ 知识是人生的灯塔&#xff0c;只有不断学习&#xff0c;才能照亮前行的道路 ] 0x00 前言简述 话接上篇《AIGC | Ubuntu24.04桌面版安装后必要配置》文章&#xff0c;作为作者进行机器学习的基础篇&#xff08;筑基期&#xff09;&#xff0c;后续将主要介绍机器学习环境之如何…