git clone https://github.com/tensorflow/tensorflow.git
3.下载训练用的模型
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
/Users/sam/Desktop/inception_model
├── LICENSE
├── classify_image_graph_def.pb
├── cropped_panda.jpg
├── imagenet_2012_challenge_label_map_proto.pbtxt
├── imagenet_synset_to_human_label_map.txt
└── inception-2015-12-05.tgz
test_data/rose.png
[2 4 3 1 0]
roses (score = 0.53206)
tulips (score = 0.29562)
sunflowers (score = 0.12412)
dandelion (score = 0.03425)
daisy (score = 0.01394)
# coding: utf-8
# Classify every file found under retrain/images/ with the retrained graph in
# retrain/output_graph.pb. For each file the script prints its path, shows the
# image, then prints every label from retrain/output_labels.txt together with
# its score, highest score first.
import tensorflow as tf
import os
import numpy as np
import re
from PIL import Image
import matplotlib.pyplot as plt

# Read the label file inside a context manager so the handle is closed.
with tf.gfile.GFile('retrain/output_labels.txt') as label_file:
    lines = label_file.readlines()

# Map the zero-based line number of each label to its text, with the
# trailing newline removed.
uid_to_human = {uid: line.strip('\n') for uid, line in enumerate(lines)}


def id_to_string(node_id):
    """Return the label stored for *node_id*, or '' when it is unknown."""
    return uid_to_human.get(node_id, '')


# Load the serialized graph and import it into the default graph.
with tf.gfile.GFile('retrain/output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    # Walk the image directory recursively.
    for root, _dirs, files in os.walk('retrain/images/'):
        for filename in files:
            image_path = os.path.join(root, filename)

            # Read the raw bytes; the context manager closes the handle so a
            # large directory does not leak one descriptor per image.
            with tf.gfile.GFile(image_path, 'rb') as image_file:
                image_data = image_file.read()

            # The bytes are fed to the graph's JPEG decode input.
            predictions = sess.run(
                softmax_tensor, {'DecodeJpeg/contents:0': image_data})
            # Drop singleton dimensions from the result.
            predictions = np.squeeze(predictions)

            # Print the image path and name.
            print(image_path)

            # Display the image; imshow is called while the file is still
            # open so the pixel data can be loaded before the handle closes.
            with Image.open(image_path) as img:
                plt.imshow(img)
            plt.axis('off')
            plt.show()

            # Indices ordered from highest score to lowest.
            top_k = predictions.argsort()[::-1]
            print(top_k)
            for node_id in top_k:
                # Look up the label text and its score.
                human_string = id_to_string(node_id)
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string, score))
            print()