谷歌今天又开源了,这次是Sketch-RNN
liuian 2025-04-11 01:00 17 浏览
前不久,谷歌公布了一项最新技术,可以教机器画画。今天,谷歌开源了代码。在我们研究其代码之前,首先先按要求设置Magenta环境。(
https://github.com/tensorflow/magenta/blob/master/README.md)
本文详细解释了Sketch-RNN的TensorFlow代码,即之前发布的两篇文章《Teaching Machines to Draw》和《A Neural Representation of Sketch Drawings》中描述的循环神经网络模型(RNN)。
模型概览
sketch-rnn是序列到序列的变体自动编码器。编码器RNN是双向RNN,解码器是自回归混合密度RNN。你可以使用enc_model,dec_model,enc_size,dec_size设置指定要使用的RNN单元格的类型和RNN的大小。
编码器将采用一个潜在代码z,一个维度为z_size的浮点矢量。像VAE一样,我们可以对z强制执行高斯IID分布,并使用kl_weight来控制KL发散损失项的强度。KL散度损失与重建损失之间将会有一个权衡。我们还允许潜在的代码存储信息的一些空间,而不是纯高斯IID。一旦KL损失期限低于kl_tolerance,我们将停止对该期限的优化。
对于中小型数据集,丢失(dropout)和数据扩充是避免过度拟合的非常有用的技术。我们提供了输入丢失、输出丢失、不存在内存丢失的循环丢失三个选项。实际上,我们只使用循环丢失,通常根据数据集将其设置在65%到90%之间。层次归一化和反复丢失可以一起使用,形成了一个强大的组合,用于在小型数据集上训练循环神经网络。
谷歌提供了两种数据增强技术。第一个是随机缩放训练图像大小的random_scale_factor。第二种增加技术(sketch-rnn论文中未使用)剔除线笔划中的随机点。给定一个具有超过2点的线段,我们可以随机放置线段内的点,并且仍然保持类似的矢量图像。这种类型的数据增强在小数据集上使用时非常强大,并且对矢量图是唯一的,因为难以在文本或MIDI数据中删除随机字符或音符,并且也不可能在像素图像数据中丢弃随机像素而不引起大的视觉差异。我们通常将数据增加参数设置为10%至20%。如果在与普通示例相比较的情况下,人类观众几乎没有差异,那么我们应用数据增强技术,而不考虑训练数据集的大小。
有效地使用丢弃和数据扩充,可以避免过度拟合到一个小的训练集。
训练模型
要训练模型,首先需要一个包含训练/验证/测试例子的数据集。我们提供了指向aaron_sheep数据集的链接,默认情况下,该模型将使用此轻量级数据集。
使用示例:
sketch_rnn_train --log_root=checkpoint_path --data_dir=dataset_path --hparams={"data_set"="dataset_filename.npz"}
我们建议你在模型和数据集内部创建子目录,以保存自己的数据和检查点。 TensorBoard日志将存储在checkpoint_path内,用于查看训练/验证/测试数据集中各种损失的训练曲线。
以下是模型的完整选项列表以及默认设置:
data_set='aaron_sheep.npz', # Our dataset.
num_steps=10000000, # Total number of training set. Keeplarge.
save_every=500, # Number of batches percheckpoint creation.
dec_rnn_size=512, # Size of decoder.
dec_model='lstm', # Decoder: lstm, layer_norm orhyper.
enc_rnn_size=256, # Size of encoder.
enc_model='lstm', # Encoder: lstm, layer_norm orhyper.
z_size=128, # Size of latent vector z.Recommend 32, 64 or 128.
kl_weight=0.5, # KL weight of loss equation.Recommend 0.5 or 1.0.
kl_weight_start=0.01, # KL start weight when annealing.
kl_tolerance=0.2, # Level of KL loss at which to stopoptimizing for KL.
batch_size=100, # Minibatch size. Recommendleaving at 100.
grad_clip=1.0, # Gradient clipping. Recommendleaving at 1.0.
num_mixture=20, # Number of mixtures in Gaussianmixture model.
learning_rate=0.001, # Learning rate.
decay_rate=0.9999, # Learning rate decay per minibatch.
kl_decay_rate=0.99995, # KL annealing decay rate per minibatch.
min_learning_rate=0.00001, # Minimum learning rate.
use_recurrent_dropout=True, # Recurrent Dropout without Memory Loss.Recomended.
recurrent_dropout_prob=0.90, # Probabilityof recurrent dropout keep.
use_input_dropout=False, # Input dropout. Recommend leaving False.
input_dropout_prob=0.90, # Probability of input dropout keep.
use_output_dropout=False, # Output droput. Recommend leaving False.
output_dropout_prob=0.90, # Probability of output dropout keep.
random_scale_factor=0.15, # Random scaling data augmentionproportion.
augment_stroke_prob=0.10, # Point dropping augmentation proportion.
conditional=True, # If False, use decoder-only model.
以下是一些可能需要用于在非常大的数据集上训练模型的选项,并使用HyperLSTM作为RNN单元。对于小于10K的训练样本的小数据集,具有层规范化(包括enc_model和dec_model的layer_norm)的LSTM效果最佳。
sketch_rnn_train --log_root=models/big_model --data_dir=datasets/big_dataset --hparams={"data_set"="big_dataset_filename.npz","dec_model":"hyper","dec_rnn_size":2048,"enc_model":"layer_norm","enc_rnn_size":512,"save_every":5000,"grad_clip":1.0,"use_recurrent_dropout":0}
对于Python 2.7,我们已经在TensorFlow 1.0和1.1上测试了这个模型。
数据集
由于大小限制,此报告不包含任何数据集。
我们已经准备好了许多使用Sketch-RNN开箱即用的数据集。Google QuickDraw数据集(
https://quickdraw.withgoogle.com/data)是涵盖345个类别的50M矢量草图的集合。在quickdraw数据集中,有一个名为Sketch-RNNQuickDraw Dataset的部分描述了可用于此项目的预处理数据文件。每个类别类都存储在其自己的文件中,如cat.npz,并包含70000/2500/2500示例的训练/验证/测试集大小。
从Google云(
https://console.cloud.google.com/storage/quickdraw_dataset/sketchrnn)
下载.npz数据集,以供本地使用。我们建议你创建一个名为datasets / quickdraw的子目录,并将这些.npz文件保存在此子目录中。
除了QuickDraw数据集之外,我们还在较小的数据集上测试了该模型。在sketch-rnn-datasets(
https://github.com/hardmaru/sketch-rnn-datasets)报告中,还有3个数据集:AaronKoblin Sheep Market、Kanji和Omniglot。如果你希望在本地使用它们,我们建议你为每个数据集创建一个子目录,如datasets/ aaron_sheep。如前所述,在小型数据集上训练模型以避免过度拟合时,应使用循环退出和数据增加。
创建自己的数据集
请创建你自己有趣的数据集并训练这些算法!创建新的数据集是乐趣的一部分。你很可能发现有趣的矢量线图数据集,为什么要用现有的预先打包好的数据集呢?在我们的实验中,由几千个例子组成的数据集大小足以产生一些有意义的结果。在这里,我们描述模型期望看到的数据集文件的格式。
数据集中的每个示例都存储为坐标偏移的列表:Δx,Δy用来二进制值表示笔是否从纸张提起。这种格式,我们称之为stroke-3,在论文中有描述(
https://arxiv.org/abs/1308.0850)。 请注意,论文中描述的数据格式有5个元素(stroke-5格式),此转换在DataLoader内自动完成。以下是使用以下格式的乌龟示例草图:
图:作为(Δx,Δy,二进制笔状态)序列的示例草图点和渲染形式。在渲染草图中,线条颜色对应于顺序笔画排列。
在我们的数据集中,示例列表中的每个示例都用np.int16数据类型表示为np.array。你可以将它们存储为np.int8,你可以将其存储起来以节省存储空间。如果你的数据必须是浮点格式,也可以使用np.float16。np.float32可能会浪费存储空间。在我们的数据中,Δx和Δy偏移通常用像素位置表示,它们大于神经网络模型喜欢看到的数字范围,所以在模型中内置了归一化缩放过程。当我们加载训练数据时,模型将自动转换为np.float并在训练前相应规范化。
如果要创建自己的数据集,则必须为训练/验证/测试集创建三个示例列表,以避免过度拟合到训练集。该模型将使用验证集来处理早期停止。对于aaron_sheep数据集,我们使用了7400/300/300的示例,并将每个内容放在python列表中,名为train_data,validation_data和test_data。之后,我们创建了一个名为datasets / aaron_sheep的子目录,我们使用内置的savez_compressed方法将数据集的压缩版本保存在aaron_sheep.npz文件中。在我们的所有实验中,每个数据集的大小是100的确切倍数。
filename = os.path.join('datasets/your_dataset_directory', 'your_dataset_name.npz')
我们还通过执行简单的笔画简化来预处理数据,称为Ramer-Douglas-Peucker。 在这里应用这个算法有一些易于使用的开源代码(
https://github.com/fhirschmann/rdp)。 实际上,我们可以将epsilon参数设置为0.2到3.0之间的值,具体取决于我们想要简单的线条。 在本文中,我们使用了一个2.0的epsilon参数。 我们建议你建立最大序列长度小于250的数据集。
如果你有大量简单的SVG图像,则可以使用一些可用的库(
https://pypi.python.org/pypi/svg.path)来将SVG的子集转换为线段,然后可以在将数据转换为stroke-3格式之前对线段应用RDP。
预训练模型
我们为aaron_sheep数据集提供了预先训练的模型,用于条件和无条件训练模式,使用vanilla LSTM单元以及带有层规范化的LSTM单元。这些型号将通过运行Jupyter Notebook下载。它们存储在:
/tmp/sketch_rnn/models/aaron_sheep/lstm
/tmp/sketch_rnn/models/aaron_sheep/lstm_uncond
/tmp/sketch_rnn/models/aaron_sheep/layer_norm
/tmp/sketch_rnn/models/aaron_sheep/layer_norm_uncond
此外,我们为选定的QuickDraw数据集提供了预先训练的模型:
/tmp/sketch_rnn/models/owl/lstm
/tmp/sketch_rnn/models/flamingo/lstm_uncond
/tmp/sketch_rnn/models/catbus/lstm
/tmp/sketch_rnn/models/elephantpig/lstm
使用Jupyter notebook的模型
让我们来模拟猫和公车之间的插值!
我们涵盖了一个简单的Jupyter notebook(
http://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn.ipynb),向你展示如何加载预先训练的模型并生成矢量草图。你能够在两个矢量图像之间进行编码,解码和变形,并生成新的随机图像。采样图像时,可以调整temperature参数来控制不确定度。
来源:
https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md
相关推荐
- GANs为何引爆机器学习?这篇基于TensorFlow的实例教程为你解惑!
-
「机器人圈导览」:生成对抗网络无疑是机器学习领域近三年来最火爆的研究领域,相关论文层出不求,各种领域的应用层出不穷。那么,GAN到底如何实践?本文编译自Medium,该文作者以一朵玫瑰花为例,详细阐...
- 高丽大学等机构联合发布StarGAN:可自定义表情和面部特征
-
原文来源:arXiv、GitHub作者:YunjeyChoi、MinjeChoi、MunyoungKim、Jung-WooHa、SungKim、JaegulChoo「雷克世界」编译:嗯~...
- TensorFlow和PyTorch相继发布最新版,有何变化
-
原文来源:GitHub「机器人圈」编译:嗯~阿童木呀、多啦A亮Tensorflow主要特征和改进在Tensorflow库中添加封装评估量。所添加的评估量列表如下:1.深度神经网络分类器(DNNCl...
- 「2022 年」崔庆才 Python3 爬虫教程 - 深度学习识别滑动验证码缺口
-
上一节我们使用OpenCV识别了图形验证码躯壳欧。这时候就有朋友可能会说了,现在深度学习不是对图像识别很准吗?那深度学习可以用在识别滑动验证码缺口位置吗?当然也是可以的,本节我们就来了解下使用深度...
- 20K star!搞定 LLM 微调的开源利器
-
LLM(大语言模型)微调一直都是老大难问题,不仅因为微调需要大量的计算资源,而且微调的方法也很多,要去尝试每种方法的效果,需要安装大量的第三方库和依赖,甚至要接入一些框架,可能在还没开始微调就已经因为...
- 大模型DeepSeek本地部署后如何进行自定义调整?
-
1.理解模型架构a)查看深度求索官方文档或提供的源代码文件,了解模型的结构、输入输出格式以及支持的功能。模型是否为预训练权重?如果是,可以在预训练的基础上进行微调(Fine-tuning)。是否需要...
- 因配置不当,约5000个AI模型与数据集在公网暴露
-
除了可访问机器学习模型外,暴露的数据还可能包括训练数据集、超参数,甚至是用于构建模型的原始数据。前情回顾·人工智能安全动态向ChatGPT植入恶意“长期记忆”,持续窃取用户输入数据多模态大语言模型的致...
- 基于pytorch的深度学习人员重识别
-
基于pytorch的深度学习人员重识别Torchreid是一个库。基于pytorch的深度学习人员重识别。特点:支持多GPU训练支持图像的人员重识别与视频的人员重识别端到端的训练与评估简单的re...
- DeepSeek本地部署:轻松训练你的AI模型
-
引言:为什么选择本地部署?在AI技术飞速发展的今天,越来越多的企业和个人希望将AI技术应用于实际场景中。然而,对于一些对数据隐私和计算资源有特殊需求的用户来说,云端部署可能并不是最佳选择。此时,本地部...
- 谷歌今天又开源了,这次是Sketch-RNN
-
前不久,谷歌公布了一项最新技术,可以教机器画画。今天,谷歌开源了代码。在我们研究其代码之前,首先先按要求设置Magenta环境。(https://github.com/tensorflow/magen...
- Tensorflow 使用预训练模型训练的完整流程
-
前面已经介绍了深度学习框架Tensorflow的图像的标注和训练数据的准备工作,本文介绍一下使用预训练模型完成训练并导出训练的模型。1.选择预训练模型1.1下载预训练模型首先需要在Tensorf...
- 30天大模型调优学习计划(30分钟训练大模型)
-
30天大模型调优学习计划,结合Unsloth和Lora进行大模型微调,掌握大模型基础知识和调优方法,熟练应用。第1周:基础入门目标:了解大模型基础并熟悉Unsloth等工具的基本使用。Day1:大模...
- python爬取喜马拉雅音频,json参数解析
-
一.抓包分析json,获取加密方式1.抓包获取音频界面f12打开抓包工具,播放一个(非vip)视频,点击“媒体”单击打开可以复制URL,发现就是我们要的音频。复制“CKwRIJEEXn-cABa0Tg...
- 五、JSONPath使用(Python)(json数据python)
-
1.安装方法pipinstalljsonpath2.jsonpath与Xpath下面表格是jsonpath语法与Xpath的完整概述和比较。Xpathjsonpath概述/$根节点.@当前节点...
- Python网络爬虫的时候json=就是让你少写个json.dumps()
-
大家好,我是皮皮。一、前言前几天在Python白银交流群【空翼】问了一个Python网络爬虫的问题,提问截图如下:登录请求地址是这个:二、实现过程这里【甯同学】给了一个提示,如下所示:估计很多小伙伴和...
- 一周热门
-
-
Python实现人事自动打卡,再也不会被批评
-
Psutil + Flask + Pyecharts + Bootstrap 开发动态可视化系统监控
-
一个解决支持HTML/CSS/JS网页转PDF(高质量)的终极解决方案
-
再见Swagger UI 国人开源了一款超好用的 API 文档生成框架,真香
-
【验证码逆向专栏】vaptcha 手势验证码逆向分析
-
网页转成pdf文件的经验分享 网页转成pdf文件的经验分享怎么弄
-
C++ std::vector 简介
-
python使用fitz模块提取pdf中的图片
-
《人人译客》如何规划你的移动电商网站(2)
-
Jupyterhub安装教程 jupyter怎么安装包
-
- 最近发表
- 标签列表
-
- python判断字典是否为空 (50)
- crontab每周一执行 (48)
- aes和des区别 (43)
- bash脚本和shell脚本的区别 (35)
- canvas库 (33)
- dataframe筛选满足条件的行 (35)
- gitlab日志 (33)
- lua xpcall (36)
- blob转json (33)
- python判断是否在列表中 (34)
- python html转pdf (36)
- 安装指定版本npm (37)
- idea搜索jar包内容 (33)
- css鼠标悬停出现隐藏的文字 (34)
- linux nacos启动命令 (33)
- gitlab 日志 (36)
- adb pull (37)
- table.render (33)
- uniapp textarea (33)
- python判断元素在不在列表里 (34)
- python 字典删除元素 (34)
- react-admin (33)
- vscode切换git分支 (35)
- vscode美化代码 (33)
- python bytes转16进制 (35)