训练Tensorflow图像分类器分类五种鲜花
0x00 运行环境
首先安装anaconda和tensorflow_gpu环境, 网上有教程,不多赘述。装完之后好像tensorflow_gpu会在anaconda下创建一个虚拟环境。
然后将Tensorflow图像分类器项目克隆到本地,项目地址(https://github.com/akshaypai/tfClassifier)
具体步骤,打开命令提示符,输入命令
git clone https://github.com/akshaypai/tfClassifier
注意:此时的路径就是你clone项目之后项目的保存路径,比如我的是在D盘。命令行下载可能会失败,也可以在项目主页上点击下载
数据集
从网上找了一个数据集,地址 http://www.sykv.com/m/view.php?aid=15258
有五种花,每种800张左右。
设置图像文件夹,新建一个父文件夹,名称为flowers,在该文件夹下新建5个子文件夹,分别对应五种花的名字,将五种花的图片放到对应的文件夹下,由于数据集提供者已经建好,这里只用改一下文件夹的名称即可
0x01 训练过程
打开anaconda命令行
将环境切换到tensorflow_gpu下
将路径切换到GitHub项目下载的地方,存有retrain.py文件的路径下,该文件用来重新训练模型
原始代码有部分是tensorflow 1.0时的函数,现在的tensorflow有所改动,具体为将retrain.py文件下第750行及其后一句中的tf.train.SummaryWriter改为tf.summary.FileWriter
改完之后执行代码
python retrain.py --model_dir D:\tfClassifier-master\image_classification\inception --image_dir F:\Download\flowers --output_graph D:\tfClassifier-master\image_classification\output.pb --output_labels D:\tfClassifier-master\image_classification\labels.txt --how_many_training_steps 500
这是我的路径,可以根据自己的路径修改
上图只给了部分参数,实际上还有其他参数,注意output_graph和output_labels要具体到文件,否则会报错,具体参见( https://stackoverflow.com/questions/45076911/tensorflow-failed-to-create-a-newwriteablefile-when-retraining-inception ),这个网址里也有关于预训练更多的参数
执行命令行
首先会生成label文件,由于我没有指定路径
即这里面的battleneck路径,文件生成在D盘根目录下,分类器项目我放在D盘
建议这里还是指定一下battleneck的路径
生成了tmp文件夹,其中bottleneck是五种花的标签(label)文件,下面是日志文件
生成完之后开始训练,按照之前的设置迭代500次
最终准确率84.4%
0x02 测试retrain的模型
具体命令为
这里可以参考项目上给出的链接 https://sourcedexter.com/retrain-tensorflow-inception-model/
测试准确率为94.9%,搞定!
参考
[1] https://zhuanlan.zhihu.com/p/28223165
[2] https://suiwo.xyz/2017/11/08/TensorFlow不同版本引起的错误/
[3] https://sourcedexter.com/retrain-tensorflow-inception-model/
[4] https://stackoverflow.com/questions/45076911/tensorflow-failed-to-create-a-newwriteablefile-when-retraining-inception