tensorflow 使用 image_retraining 做图像分类 - PHP程序员学习笔记|如何学习PHP

PHP程序员学习笔记|如何学习PHP

个人总结的PHP学习方法


tensorflow 使用 image_retraining 做图像分类

2018-3-8 0phpcom keras神级网络


1.数据集下载

cd
~ curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz

2.下载训练的代码

git clone https://github.com/tensorflow/tensorflow.git

3.下载训练用的模型
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

创建一个文件夹里面放 inception-2015-12-05.tgz

/Users/sam/Desktop/inception_model

├── LICENSE

├── classify_image_graph_def.pb

├── cropped_panda.jpg

├── imagenet_2012_challenge_label_map_proto.pbtxt

├── imagenet_synset_to_human_label_map.txt

└── inception-2015-12-05.tgz


4.训练模型
python3.6 /Users/sam/Desktop/tensorflow/tensorflow/examples/image_retraining/retrain.py --bottleneck_dir bottleneck --how_many_training_steps 200 --model_dir  /Users/sam/Desktop/inception_model/ --output_graph output_graph.pb --output_labels output_labels.txt --image_dir /Users/sam/Desktop/tensorflow/tensorflow/examples/image_retraining/data/

这个时候就会得到一个 bottleneck 文件夹这个就是输出了

5.使用训练好的模型做图像识别
code:
这个时候回返回一个结果集
		

test_data/rose.png [2 4 3 1 0] roses (score = 0.53206) tulips (score = 0.29562) sunflowers (score = 0.12412) dandelion (score = 0.03425) daisy (score = 0.01394)

# coding: utf-8 import tensorflow as tf import os import numpy as np import re from PIL import Image import matplotlib.pyplot as plt lines = tf.gfile.GFile('retrain/output_labels.txt').readlines() uid_to_human = {} #一行一行读取数据 for uid,line in enumerate(lines) : #去掉换行符 line=line.strip('\n') uid_to_human[uid] = line def id_to_string(node_id): if node_id not in uid_to_human: return '' return uid_to_human[node_id] #创建一个图来存放google训练好的模型 with tf.gfile.FastGFile('retrain/output_graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) tf.import_graph_def(graph_def, name='') with tf.Session() as sess: softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') #遍历目录 for root,dirs,files in os.walk('retrain/images/'): for file in files: #载入图片 image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read() predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式 predictions = np.squeeze(predictions)#把结果转为1维数据 #打印图片路径及名称 image_path = os.path.join(root,file) print(image_path) #显示图片 img=Image.open(image_path) plt.imshow(img) plt.axis('off') plt.show() #排序 top_k = predictions.argsort()[::-1] print(top_k) for node_id in top_k: #获取分类名称 human_string = id_to_string(node_id) #获取该分类的置信度 score = predictions[node_id] print('%s (score = %.5f)' % (human_string, score)) print()


« 《复仇者联盟3》完整剧透版 警告!剧透慎入! | 解决XDEBUG 在外部设备访问php时无法创建断点的问题»
发表评论:









订阅Rss