天天看点

写给程序员的机器学习入门 (八 补充) - 使用 GPU 训练模型

在之前的文章中我训练模型都是使用的 CPU,因为家中黄脸婆不允许我浪费钱买电脑😭。终于的,附近一个废品回收站的朋友转让给我一台破烂旧电脑,所以我现在可以体验使用 GPU 训练模型了🥳。

pytorch, tensorflow 等主流的框架的 GPU 支持都基于 CUDA 框架,而目前提供 CUDA 支持的显卡只有 nvidia,这次我捡到的破烂是 GTX 1650 4GB 所以满足最低要求了。简单描述下目前各种显卡的支持程度:

Intel 核显:死心叭

APU:没法用

Nvidia Geforce

2GB 可以用来跑一些入门例子

4GB 可以跑一些简单模型

6GB 可以跑一些中级模型

8GB 可以跑一些高级模型

10GB以上 可以跑最前沿的模型

Radeon:要折腾,试试 ROCm

如果真的要玩机器学习推荐购买 RTX 系列,因为有 tensor 核心和 16 位浮点数支持,训练速度会快很多,并且使用 16 位浮点数可以让显存占用少一半。虽然在过几个星期就可以看到 3000 系列的显卡了,可惜没钱买🤒。此外,明年如果出支持机器学习的民用国产显卡必定会大力支持😡。

Windows 的话会通过 Windows Update 自动安装, pytorch 会自动检测出显卡,不需要做任何工作。Linux 需要安装 Nvidia 官方的闭源驱动 (开源的 Nouveau 驱动不支持 CUDA),如果是 Ubuntu 那么在安装系统的时候打个勾就可以自动安装,如果没打可以参考这篇文章,其他 Linux 系统如果源没有提供可以去 Nvidia 官方下载驱动。

安装以后可以执行以下代码看看 pytorch 是否可以检测出显卡:

如果输出类似以上的结果,那么就代表没有问题了。

pytorch 默认会把 tensor 对象的数据保存在内存上,计算会由 CPU 执行,如果我们想使用 GPU,可以调用 tensor 对象的 <code>cuda</code> 方法把对象的数据复制到显存上,复制以后的 tensor 对象运算会使用 GPU。注意在内存上的 tensor 对象和在显存上的 tensor 对象之间无法进行运算。

如果你想编写同时兼容 GPU 和 CPU 的代码可以使用以下写法,如果有支持的 GPU 则会使用 GPU,如果没有则会使用 CPU:

如果你插了多张显卡,以上的写法只会使用第一张,你可以通过 "cuda:序号" 来指定不同的显卡来实现分布式计算。

这里我拿前一篇文章的代码来展示怎样实际使用 GPU 训练识别验证码的模型,以下是修改后完整的代码:

如何生成训练数据和如何使用这份代码的说明请参考前一篇文章。

使用 diff 生成相差的部分如下:

可以看到只改动了五个部分,在头部添加了 device 的定义,然后在加载模型和 tensor 对象的时候使用 <code>.to(device)</code> 即可。

简单吧☺️。

那么训练速度相差如何呢?只训练一个 batch 使用 CPU 和 GPU 消耗的时间分别如下 (单位秒):

差了整整 7 倍😱,,如果是高端的显卡估计可以看到数十倍的差距。

如果你想查看训练过程中的显存占用情况,可以使用 <code>nvidia-smi</code> 命令,命令会输出以下的信息:

如果训练过程中出现显存不足,你会看到以下的异常信息:

如果你遇到显存不足的问题,那么可以尝试以下的办法解决,按实用程度排序:

出钱买新显卡🤒

减少训练批次大小 (例如每个批次 100 条数据,减为每个批次 50 条数据)

不使用的对象早回收,例如 <code>predicted = None</code>,pytorch 会在对象声明周期结束后自动释放显存

计算单值的时候使用 <code>item()</code>,例如 <code>acc_total += acc.item()</code>,但配合 <code>backward</code> 生成运算路径的计算不能用

如果你使用桌面 Linux,试试开机的时候添加 <code>rw init=/bin/bash</code> 进入命令行界面再训练,这样可以节省个几百 MB 显存

你可能会好奇为什了 pytorch 可以及时释放显存,这是因为 python 的对象使用了引用计数 (Reference Counted),GC 基本上只负责回收循环引用的对象,对象的引用计数归 0 的时候 python 会自动调用析构函数,不需要等待 GC。而 NET 和 Java 等语言则无法做到及时回收,除非你每个 tensor 对象都及时的去调用 Dispose 方法,或者使用 tensorflow 来编译静态运算路径然后把生命周期管理工作全部交给框架。这也是使用 Python 的一大好处🥳。

这篇本来应该放在最开始,可惜等到现在才有条件写。下一篇文章预计会介绍对象识别模型,包括 RCNN,FasterRCNN 和 YOLO,看看什么时候能出来吧。