在我们深入了解CNN之前,让我们先补充一些背景知识。早在上世纪90年代,Yann LeCun就使用CNN做了一个手写数字识别的程序。而随着时代的发展,尤其是计算机性能和GPU的改进,研究人员有了更加丰富的想象空间。 2010年斯坦福的机器视觉实验室发布了ImageNet项目。该项目包含1400万带有描述标签的图片。这个几乎已经成为了比较CNN模型的标准。目前,最好的模型在这个数据集上能达到94%的准确率。人们不断的改善模型来提高准确率。在2014年GoogLeNet 和VGGNet成为了最好的模型,而在此之前是ZFNet。CNN应用于ImageNet的第一个可行例子是AlexNet,在此之前,研究人员试图使用传统的计算机视觉技术,但AlexNet的表现要比其他一切都高出15%。让我们一起看一下LeNet
这个图中并没有显示激活层,整个的流程是:
输入图片 →卷积层 →Relu → 最大池化→卷积层 →Relu→ 最大池化→隐藏层 →Softmax (activation)→输出层。
# -*- coding:utf-8 -*- import matplotlib.pyplot as plt import cv2 import keras from keras.datasets import mnist from keras.models import Sequential from keras.layers import Dense, Dropout from keras.optimizers import RMSprop from keras.layers import Activation, Dense from keras.layers.convolutional import Conv2D,MaxPooling2D import numpy as np cat = cv2.imread('cat.jpg'); plt.imshow(cat) plt.show()
cat.shape model = Sequential(); model.add(Conv2D(3,(3,3), padding='valid',input_shape=cat.shape)) # keras expects batches of images, so we have to add a dimension to trick it into being nice # keras预计会有批量的图像,所以我们必须增加一个维度来欺骗它 cat_bath = np.expand_dims(cat,axis=0) conv_cat = model.predict(cat_bath) def visualize_cat(cat_bath): cat = np.squeeze(cat_bath,axis=0) plt.imshow(cat) plt.show() # here we get rid of that added dimension and plot the image # 在这里,我们去掉了添加的尺寸和图像 visualize_cat(conv_cat)
def nice_cat_printer(model,cat): cat_bath = np.expand_dims(cat,axis=0) conv_cat2 = model.predict(cat_bath) conv_cat2 = np.squeeze(conv_cat2,axis=0) print(conv_cat2.shape) conv_cat2 = conv_cat2.reshape(conv_cat2.shape[:2]) print(conv_cat2.shape) plt.imshow(conv_cat2) plt.show() model = Sequential(); model.add(Conv2D(1,(3,3), padding='valid',input_shape=cat.shape)) nice_cat_printer(model,cat)
# 增加一个池化层 model = Sequential(); model.add(Conv2D(1,(3,3), padding='valid',input_shape=cat.shape)) model.add(MaxPooling2D(pool_size=(2,2))) nice_cat_printer(model,cat)#激活和最大池化 model = Sequential() model.add(Conv2D(1,(3,3), padding='valid',input_shape=cat.shape)) model.add(Activation('relu')) model.add(MaxPooling2D(pool_size=(5,5))) nice_cat_printer(model,cat)
参考文章:https://yq.aliyun.com/articles/73544?spm=5176.8278999.602941.2