查看: 8233|回复: 11

[征集/原创]在Movidius计算棒上运行手写数字识别Keras模型(一)

  [复制链接]
  • TA的每日心情
    开心
    2018-12-28 02:39
  • 签到天数: 12 天

    连续签到: 1 天

    [LV.3]偶尔看看II

    发表于 2018-9-1 21:37:43 | 显示全部楼层 |阅读模式
    分享到:
    本帖最后由 Daniel-Wu 于 2018-9-1 22:58 编辑

    国内鲜有基于Movidius的教程,国外开源的项目也寥寥无几,本教程结合了一些国外优秀的教程和自己的一些实践经历。我想通过这个教程帮助大家尽可能快速上手使用Movidius。
    在本教程中我会向您展示培训简单的MNIST Keras模型并将其部署到树莓派+Movidius中。
    需要的材料有movidius + 树莓派或其他ubuntu电脑 + 摄像头(树莓派官方的摄像头和usb的均可)
    同时你还需要在树莓派上安装好Movidius ncsdk,有机会我会再写一个专门介绍movidius安装的教程
    有几个步骤,
    • 在Keras训练模型(TensorFlow后端)
    • 在Keras中保存模型文件和权重
    • 将Keras模型转换为TensorFlow
    • 将TensorFlow模型编译为NCS graph  (不知道graph是啥的请百度tensorflow的文档)
    • 在NCS上部署并运行graph

    -我使用的平台是up board(ubuntu)+神经计算棒+摄像头,均可以在爱板商城购买到
    IMG_2248.JPG

    一.什么是Keras?



    在keras中文wiki里面介绍了Keras是一个高层神经网络API,Keras由纯Python编写而成并基TensorflowTheano以及CNTK后端。Keras 为支持快速实验而生,能够把你的idea迅速转换为结果,如果你有如下需求,请选择Keras:

    • 简易和快速的原型设计(keras具有高度模块化,极简,和可扩充特性)
    • 支持CNN和RNN,或二者的结合
    • 无缝CPU和GPU切换
    简而言之,Keras就是简单易上手的神经网络,将TensorflowTheano以及CNTK等作为后端,将复杂的APi简单化,方便小白上手,小伙伴们可以百度keras中文文档进行学习,当然也可以你看我给大家的英文原版书籍,我会上传到附件,英文书籍正好可以练练阅读能力,学人工智能,英语阅读能力少不了的


    二,培训与输出模型


    了解了Keras的基本使用后
    再来看看如何培训Keras的模型的
    直接上培训与输出部分的代码,这部分在电脑上运行,windows和linux均可
    1. from keras import layers, models
    2. from keras.models import load_model
    3. from keras.datasets import mnist
    4. from keras.utils import to_categorical
    5. from keras import backend as K
    6. import tensorflow as tf

    7. (x_train, y_train), (x_test, y_test) = mnist.load_data()
    复制代码

    正确训练完后控制台打印如下 TIM图片20180901162216.png

    并且应该有三个文件生成
    TIM图片20180901162216.png
    三.将Keras转变为TensorFlow模型
    查看Movidius官网的WIki,由于Movidius NCSDK2仅编译TensorFlow或Caffe模型。现在tensorflow正好官方支持了树莓派平台,那就使用tensorflow来做吧。
    1. from keras.models import model_from_json
    2. from keras import backend as K
    3. import tensorflow as tf

    4. model_file = "model.json"
    5. weights_file = "weights.h5"

    6. with open(model_file, "r") as file:
    7.     config = file.read()

    8. K.set_learning_phase(0)
    9. model = model_from_json(config)
    10. model.load_weights(weights_file)

    11. saver = tf.train.Saver()
    12. sess = K.get_session()
    13. saver.save(sess, "./TF_Model/tf_model")

    14. fw = tf.summary.FileWriter('logs', sess.graph)
    15. fw.close()
    复制代码
    TIM图片20180901162216.png

    第一 我们关闭学习阶段,然后从我们先前保存的两个单独文件以标准Keras方式加载模型。
    通过 调用 K.get_session() 从使用基于TensorFlow后端的 Keras,将提供默认的TensorFlow  session 。您甚至可以调用 sess.graph.get_operations()进一步了解TensorFlow graph中的内容 它将返回一列模型中TensorFlow操作。这对于查找计算棒不支持的操作非常有用,众所周知,movidius的限制还是太多,希望国产的角峰鸟可以解决一些限制。最后,TensorFlow Saver类将模型保存到指定路径中的四个文件中。
    每个文件都有不同的用途,
    • checkpoint定义模型检查点路径,在我们的例子中是“tf_model”。
    • .meta存储图形结构,
    • .data存储图中每个变量的值
    • .index标识检查点。

    四.使用mvNCCompile编译TensorFlow模型

    在装有ncsdk的树莓派或主机上,使用mvNCCompile。mvNCCompile命令行工具自带NCSDK2工具包,可以转换来自Caffe或Tensorflow网络到graph,也就是可由Movidius计算棒使用的文件。我们将' conv2d_1_input '作为输入节点,将' dense_2 / Softmax '作为输出节点。一行命令即可在当前目录生成名为”graph“的文件。

    1. mvNCCompile TF_Model / tf_model.meta -in = conv2d_1_input -on = dense_2 / Softmax
    复制代码




    五.部署graph并进行预测

    1. from mvnc import mvncapi as mvnc
    2. # 获取计算棒的设备名
    3. devices = mvnc.enumerate_devices()
    4. dev = mvnc.Device(devices[0])
    5. # 从文件中读取已编译的网络图(为图形文件正确设置graph_filepath)
    6. with open("graph", mode='rb') as f:
    7.     graphFileBuff = f.read()

    8. graph = mvnc.Graph('graph1')

    9. # 在设备上分配图形并创建输入和输出
    10. in_fifo, out_fifo = graph.allocate_with_fifos(dev, graphFileBuff)

    11. # 将输入写入input_fifo缓冲区并在一次调用中对推理进行排队
    12. graph.queue_inference_with_fifo_elem(in_fifo, out_fifo, input_img.astype('float32'), 'user object')

    13. # 将结果读取到输出Fifo
    14. output, userobj = out_fifo.read_elem()
    15. print('Predicted:',output.argmax())
    复制代码



    六.使用Raspberry Pi上的摄像头实时图像进行预测

    训练MNIST模型以识别28×28分辨率的灰度图像的手写数字,所以必需转换摄像头捕获的图像,下面是一些预处理步骤。
    • 裁剪图像的中心区域
    • 使用边缘检测找到图像的边缘,此步骤也可以将图像转换为灰度
    • 扩大边缘,使边缘更厚,以填充两个紧密平行边缘之间的区域。
    • 将图像大小调整为28 x 28
    对于每个捕获的帧,我们将其传递给图像预处理函数,然后输入NCS graph,该graph返回最终的预测概率。从那里,我们将最终预测的结果呈现为在显示器上显示的图像上。
    图像处理过程的代码片段如下
    1. import numpy as np
    2. import cv2
    3. class ImageProcessor:
    4.     """
    5.     A singleton class for ImageProcessor
    6.     """

    7.     p1 = 90
    8.     p2 = 30
    9.     ROI_ratio = 0.2
    10.     label_text_color = (0, 120, 0)
    11.     min_score_percent = 60

    12.     def __new__(cls, min_score_percent=60):
    13.         if not hasattr(cls, 'instance'):
    14.             cls.instance = super(ImageProcessor, cls).__new__(cls)
    15.         return cls.instance

    16.     def __init__(self, min_score_percent=60):
    17.         self.min_score_percent = min_score_percent
    18.     def preprocess_image(self, input_image):
    19.         self.sz = input_image.shape
    20.         self.cx = self.sz[0]//2
    21.         self.cy = self.sz[1]//2
    22.         self.ROI = int(self.sz[0]*self.ROI_ratio)
    23.         edges = cv2.Canny(input_image,self.p1,self.p2)
    24.         cropped = edges[self.cx-self.ROI:self.cx+self.ROI,self.cy-self.ROI:self.cy+self.ROI]
    25.         kernel = np.ones((4,4),np.uint8)
    26.         cropped = cv2.dilate(cropped,kernel,iterations = 2)
    27.         cropped_input = cv2.resize(cropped,(28,28)) / 255.0
    28.         cv2.rectangle(input_image, (self.cy-self.ROI, self.cx-self.ROI), (self.cy+self.ROI, self.cx+self.ROI),(255,255,0), 5)
    29.         return cropped_input, cropped
    30.     def postprocess_image(self, input_image, percentage, label_text, cropped=None):
    31.         if cropped is not None:
    32.             cropped = np.stack((cropped,)*3, -1)
    33.             input_image[-cropped.shape[0]:, -cropped.shape[1]:] = cropped
    34.         if percentage >= self.min_score_percent:
    35.             cv2.putText(input_image, label_text, (self.cy-self.ROI - 1, self.cx-self.ROI - 1),cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
    36.         else:
    37.             cv2.putText(input_image, '?', (self.cy-self.ROI - 1, self.cx-self.ROI - 1),cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
    复制代码
    下面就是重头戏使用movidius ncs进行摄像头拍摄识别
    1. #!/usr/bin/env python3
    2. from mvnc import mvncapi as mvnc
    3. import numpy as np
    4. from ImageProcessor import ImageProcessor
    5. import cv2

    6. cv_window_name = "predict-mnist"
    7. CAMERA_INDEX = 0
    8. REQUEST_CAMERA_WIDTH = 640
    9. REQUEST_CAMERA_HEIGHT = 480

    10. def handle_keys(raw_key):
    11.     global processor
    12.     ascii_code = raw_key & 0xFF
    13.     if ((ascii_code == ord('q')) or (ascii_code == ord('Q'))):
    14.         return False
    15.     elif (ascii_code == ord('w')):
    16.         processor.p1 +=10
    17.         print('processor.p1:' + str(processor.p1))
    18.     elif (ascii_code == ord('s')):
    19.         processor.p1 -=10
    20.         print('processor.p1:' + str(processor.p1))
    21.     elif (ascii_code == ord('a')):
    22.         processor.p2 +=10
    23.         print('processor.p2:' + str(processor.p2))
    24.     elif (ascii_code == ord('d')):
    25.         processor.p2 -=10
    26.         print('processor.p1:' + str(processor.p2))
    27.     return True
    28. processor = ImageProcessor()
    29. # 检测有没有连接上计算棒
    30. devices = mvnc.enumerate_devices()
    31. if len(devices) == 0:
    32.     print('No devices found')
    33.     quit()

    34. dev = mvnc.Device(devices[0])

    35. # 尝试打开设备。 如果其他进程已经打开它,这将报异常
    36. try:
    37.     dev.open()
    38. except:
    39.     print("Error - Could not open NCS device.")
    40.     quit()

    41. # 从文件中读取已编译的graph
    42. with open("graph", mode='rb') as f:
    43.     graphFileBuff = f.read()

    44. graph = mvnc.Graph('graph1')

    45. # 在设备上分配graph并创建输入和输出Fifos
    46. in_fifo, out_fifo = graph.allocate_with_fifos(dev, graphFileBuff)
    47. cv2.namedWindow(cv_window_name)
    48. cv2.moveWindow(cv_window_name, 10,  10)

    49. cap = cv2.VideoCapture(CAMERA_INDEX)
    50. cap.set(cv2.CAP_PROP_FRAME_WIDTH, REQUEST_CAMERA_WIDTH)
    51. cap.set(cv2.CAP_PROP_FRAME_HEIGHT, REQUEST_CAMERA_HEIGHT)

    52. actual_frame_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    53. actual_frame_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    54. print ('actual video resolution: ' + str(actual_frame_width) + ' x ' + str(actual_frame_height))

    55. if ((cap == None) or (not cap.isOpened())):
    56.     print ('Could not open camera.  Make sure it is plugged in.')
    57.     print ('Also, if you installed python opencv via pip or pip3 you')
    58.     print ('need to uninstall it and install from source with -D WITH_V4L=ON')
    59.     print ('Use the provided script: install-opencv-from_source.sh')
    60.     exit_app = True
    61.     exit()
    62. exit_app = False
    63. while(True):
    64.     ret, input_image = cap.read()

    65.     if (not ret):
    66.         print("No image from from video device, exiting")
    67.         break

    68.     # 程序自检,看看检测窗口可被叉掉了
    69.     prop_val = cv2.getWindowProperty(cv_window_name, cv2.WND_PROP_ASPECT_RATIO)
    70.     if (prop_val < 0.0):
    71.         exit_app = True
    72.         break
    73.     cropped_input, cropped = processor.preprocess_image(input_image)
    74.     # 将输入写入input_fifo缓冲区并在一次调用中对推理进行排队
    75.     graph.queue_inference_with_fifo_elem(in_fifo, out_fifo, cropped_input.astype('float32'), 'user object')
    76.     # 将结果读取到输出Fifo
    77.     output, userobj = out_fifo.read_elem()
    78.     predict_label = output.argmax()
    79.     percentage = int(output[predict_label] * 100)
    80.     label_text = str(predict_label) + " (" + str(percentage) + "%)"
    81.     print('Predicted:',label_text)
    82.     processor.postprocess_image(input_image, percentage, label_text, cropped)
    83.     cv2.imshow(cv_window_name, input_image)
    84.     raw_key = cv2.waitKey(1)
    85.     if (raw_key != -1):
    86.         if (handle_keys(raw_key) == False):
    87.             exit_app = True
    88.             break
    89. cap.release()
    90. # 取消分配并清除fifo和图形,关闭设备
    91. try:
    92.     in_fifo.destroy()
    93.     out_fifo.destroy()
    94.     graph.destroy()
    95.     dev.close()
    96.     dev.destroy()
    97. except:
    98.     print("Error - could not close/destroy Graph/NCS device.")
    99.     quit()
    复制代码


    下面是实际运行的照片:
    TIM图片20180901162216.png
    Deep Learning with Keras.zip (8.71 MB, 下载次数: 12, 售价: 20 与非币)

    评分

    参与人数 1与非币 +150 收起 理由
    satoll + 150 AI征集奖励

    查看全部评分

    回复

    使用道具 举报

  • TA的每日心情
    无聊
    2024-9-4 09:09
  • 签到天数: 48 天

    连续签到: 1 天

    [LV.5]常住居民I

    发表于 2018-9-3 08:52:50 | 显示全部楼层
    大兄弟厉害了!
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    开心
    2017-12-29 15:29
  • 签到天数: 2 天

    连续签到: 1 天

    [LV.1]初来乍到

    发表于 2018-9-3 10:10:05 | 显示全部楼层
    厉害了,最近在入门深度学习,很好的学习资料,谢谢
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    开心
    2017-12-29 15:29
  • 签到天数: 2 天

    连续签到: 1 天

    [LV.1]初来乍到

    发表于 2018-9-3 10:13:27 | 显示全部楼层
    谢谢你的书
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    开心
    2019-11-4 13:48
  • 签到天数: 14 天

    连续签到: 1 天

    [LV.3]偶尔看看II

    发表于 2018-9-3 15:46:34 | 显示全部楼层
    哇,学习了,感谢楼主分享
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    无聊
    2018-8-15 10:21
  • 签到天数: 3 天

    连续签到: 1 天

    [LV.2]偶尔看看I

    发表于 2018-9-3 15:47:54 | 显示全部楼层
    太长知识了,感谢
    回复 支持 反对

    使用道具 举报

  • TA的每日心情

    2018-8-2 13:58
  • 签到天数: 1 天

    连续签到: 1 天

    [LV.1]初来乍到

    发表于 2018-9-4 10:16:48 | 显示全部楼层
    感谢楼主分享
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    无聊
    2024-9-4 09:09
  • 签到天数: 48 天

    连续签到: 1 天

    [LV.5]常住居民I

    发表于 2018-9-6 10:34:21 | 显示全部楼层
    没有(二)了吗?
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    开心
    2018-9-6 15:08
  • 签到天数: 11 天

    连续签到: 1 天

    [LV.3]偶尔看看II

    发表于 2018-9-6 15:07:49 | 显示全部楼层
    谢谢对AI小白的支持
    回复 支持 反对

    使用道具 举报

  • TA的每日心情
    开心
    2018-12-28 02:39
  • 签到天数: 12 天

    连续签到: 1 天

    [LV.3]偶尔看看II

     楼主| 发表于 2018-9-7 09:40:25 | 显示全部楼层
    satoll 发表于 2018-9-6 10:34
    没有(二)了吗?

    最近开学太忙,稍后更新
    回复 支持 反对

    使用道具 举报

    您需要登录后才可以回帖 注册/登录

    本版积分规则

    关闭

    站长推荐上一条 /4 下一条

    手机版|小黑屋|与非网

    GMT+8, 2024-11-20 00:41 , Processed in 0.220123 second(s), 36 queries , MemCache On.

    ICP经营许可证 苏B2-20140176  苏ICP备14012660号-2   苏州灵动帧格网络科技有限公司 版权所有.

    苏公网安备 32059002001037号

    Powered by Discuz! X3.4

    Copyright © 2001-2024, Tencent Cloud.