查看: 5082|回复: 1

[项目] 从零开始做一个你画AI猜的小游戏

[复制链接]
  • TA的每日心情

    2020-3-6 09:52
  • 签到天数: 13 天

    连续签到: 1 天

    [LV.3]偶尔看看II

    发表于 2019-2-28 16:41:01 | 显示全部楼层 |阅读模式
    分享到:

    神经网络能学会辨识随手画的灵魂涂鸦吗?只要数据够多就可以!

    今天想带大家从零开始实现一个谷歌开发的小游戏 —— Quick, Draw! 或是叫“限时涂鸦”!

    点击打开链接 (谷歌所有,需科学上网…)


    在本文将涉及到以下内容:

    • 准备数据
    • 训练用于图片分类的神经网络(Caffe
    • Python实现
    • 移植树莓派
    • 神经网络移植于Intel Movidius 芯片的一款嵌入式智能硬件,弥补树莓派计算力不足

    1.png

    游戏内容很简单也很有趣:给你20秒的时间和一个题目,在时间限制内画出来并让AI正确识别。
    2.png

    本灵魂画手的杰作。还没画完就已经被识别过关了于是变成了这样的半成品…

    似乎大家的水平看起来都差不多。

    3.png


    好了,言归正传,首先让我们来分析一下整个游戏的逻辑,然后分成小模块一个一个来攻破。


    4_副本.jpg

    我们先从核心的部分开始:如何训练卷积神经网络来识别涂鸦。


    获得图像数据和标签

    之所以谷歌的这个小游戏可以神奇地正确识别各位在座大触们的作品,原因就是大家玩这个游戏时所提供的海量的数据。

    猜猜现在已经有多少个了?

    What do 50 million drawings look like?

    Over 15 million players have contributed millions of drawings playing Quick, Draw! These doodles are a unique data set that can help developers train new neural networks, help researchers see patterns in how people around the world draw, and help artists create things we haven’t begun to think of. That’s why we’re open-sourcing them, for anyone to play with.      

    1500万玩家提供了345类共5000万张图片!你无论怎么画,总是有那么几个人和你思路差不多,这就是大数据的力量。


    可喜可贺的是良心的谷歌开源了这个数据库,并且提供了一个相比于图像识别更高端一些的识别方式教程,如果有兴趣的同学可以->(通过笔顺用RNN在Tensorflow上实现)。



    谷歌将数据存在了他们的云服务器上(Google Cloud),同样需要科学上网和谷歌账户才能下载,下载较为繁琐(如需下载全数据库请参考他们的Github)




    于是我重新整理了一个轻量级(其实也有100万张)的数据库用于方便大家使用,包括(机翻的)中文标签。剩下的训练用网络结构、指令、模型也放在了里面。

    链接: https://pan.baidu.com/s/1C5iENo6y8QijXMxOXDFIfw 密码: fgm9


    安装OpenCV,安装Caffe,这两步可以很简单也可能让人抓狂好几天…

    首先最好电脑上有支持Cuda的nvidia的显卡,没有也行,只是CPU训练远比GPU慢的多。具体安装过程就不赘述了,网上有成吨的教程。提供一下官方Caffe的Github链接

    注意编译PyCaffe!


    好了,一切准备工作就绪,可以开始训练了。

    (注意:我使用的系统是Ubuntu 16.04,其他操作系统可能有略微不同,如果哪里有坑欢迎大家补充交流)


    训练第一步:生成LMDB数据库


    Caffe支持用txt, hdf5以及lmdb格式训练,生成难度依次从低到高,但效率也同样从低到高。因为我们的数据姑且也是百万级的,所以在这里选择用LMDB格式。好在Caffe自己提供了方便的工具可以直接生成LMDB格式文件,只需要调用编译好的二进制convert_imageset 即可。


    以下为全部指令:

    1. [Caffe路径]/convert_imageset --resize_height=28 --resize_width=28 [数据根目录] ./label_list.txt ./train_lmdb
    2. [Caffe路径]/convert_imageset --resize_height=28 --resize_width=28 [数据根目录] ./label_list_test.txt ./test_lmdb
    复制代码

    在这里需要label_list.txt来指定图像数据的位置以及它的类别标签,比如说其中一行:

    1. /images_test/0000/00046005.png 0
    复制代码

    将0000文件夹的图片赋予标签0, 注意标签要从0开始依次往下排,不然Caffe会出错。

    list我已经生成好了放在了网盘里,如果想用自己的数据训练类似问题可以参考这个结构在Python用os.walk来遍历生成。


    训练第二步:Caffe指令

    1. [Caffe路径]/build/tools/caffe train --solver=./solver.prototxt --gpu 0 2>&1 | tee ./log.log
    复制代码


    这一步将开始训练,其中solver是用来设定配置文件以及其他训练时所需的参数。例如我使用的如下:

    1. net: "./train.prototxt"
    2. test_iter: 1500
    3. test_interval: 500

    4. base_lr: 0.01
    5. lr_policy: "multistep"
    6. gamma: 0.1

    7. stepvalue: 20000
    8. stepvalue: 45000
    9. stepvalue: 65000
    10. stepvalue: 300000
    11. max_iter: 1000000

    12. display: 200
    13. momentum: 0.9
    14. weight_decay: 0.0005
    15. snapshot: 20000
    16. snapshot_prefix: "./models/sketch"

    17. solver_mode: GPU
    复制代码


    如果没有GPU可以将最后改为CPU,并且降低max_iter以及上面学习率迭代次数相关的参数。snapshot_prefix指定模型存放路径,小心两点:这个是前缀,所以会生成 sketch_iter_xxxxxxx.caffemodel 另外记得要事先生成文件夹。 其他参数的含义也可在成吨的Caffe教程里找到详细解释。


    指令中--gpu 0 为指定跑在第几个GPU,使用CPU的话可先改solver然后删掉-gpu即可。之后的部分将Console输出的日志全部写进.log文件里, Caffe提供了一个Python脚本可以筛读这个log文件画图。


    Loss 与迭代次数的关系图:

    1. python [Caffe路径]/tools/extra/plot_training_log.py 6 ./log.png ./log.log
    复制代码

    训练正常开始的话,可以留意一下training loss和test accuracy的变化。

    我的如下,其实loss稳定长期不变了就可以关了省电。

    5.png

    因为输入图片很小 24x24,所以迭代次数应该非常快(假设GPU),大概半小时到一小时左右模型就可以正常使用了(视测试准确率而定)。

    【趁训练终于可以合法看一波动画或是打一局游戏了……】


    Python实现第一步:用测试图片在Python中验证模型是否正确

    休息归来,模型训练的差不多了,终于可以在Python中试验一番。

    首先是测试PyCaffe是否可以正常工作。如出现以下问题请先将Caffe路径加入Python的搜索路径中。

    1. >>> import caffe
    2. Traceback (most recent call last):
    3.   File "<stdin>", line 1, in <module>
    4. ImportError: No module named caffe
    复制代码

    添加Caffe路径:

    1. >>> import sys
    2. >>> sys.path.append('[Caffe路径]/caffe/python')
    3. >>> import caffe
    4. >>> caffe
    5. <module 'caffe' from '/home/vlab/SSD_Proj/caffe/python/caffe/__init__.pyc'
    复制代码

    载入Caffe模型,会哗啦啦出一大片日志。

    1. >>> net = caffe.Net('./deploy.prototxt', './deploy.caffemodel', caffe.TEST)
    复制代码

    Caffe没问题的话,可以开始随便载入一张小图片来试验一下了:

    我们来随便照一张图: 6.png 看起来是个扫把……

    1. >>> import cv2, numpy as np
    2. >>> img = cv2.imread('./test.png')
    3. >>> img = (img.astype(np.float)-127.5)/127.5
    复制代码

    这里用OpenCV读了图,并且归一化图像取值,因为图片本身就是24x24所以不需要cv2.resize来调整大小。


    接下来就是转化图像矩阵变为Caffe需要的格式,此处经常会出错一定要小心。Caffe为(1, 3, 24, 24),图像为(24, 24, 3)。['softmax']为最后输出层名字。

    1. >>> img_caffe = np.array([img]).transpose(0,3,1,2)
    2. >>> out = net.forward_all(**{net.inputs[0]: img_caffe})['softmax']
    复制代码

    out 为长度345的向量,每一维代表着对于每一个类别的置信度。来看一下最大的概率是哪个

    1. >>> np.argmax(out)
    2. 73
    复制代码
    1. >>> out[0,73]
    2. 0.88805795
    复制代码

    被认为ID=73的置信度有0.888,那应该是相当的肯定了,让我们看一下73号是什么:

    1. 耙 73
    复制代码

    还是看一下英文的好了…

    1. paintbrush 73
    复制代码

    虽说不是扫把,不过看起来确实是对的!

    好,最重要的一步做完了,剩下就是实现小游戏了。


    Python实现第二步:PyGame


    游戏规则是只要画的东西被识别排在前5名就正确。我决定用PyGame来实现这个游戏,或是说用Python的话目前能想到比较适合的只有这个了… (安装在Ubuntu的话可以用pip install pygame)


    先看看效果,在这个小窗里画一个城墙。(虽说题目要求我画盆栽)

    7.png

    看起来没问题。


    整个Demo逻辑上比较简单,PyGame窗口上画线然后每一帧抽取出图像输入进神经网络识别,最后排名输出结果。

    还是用代码说话:

    1. #coding: utf-8
    2. import pygame, random
    3. import cv2, numpy as np, sys, pdb

    4. sys.path.append('./caffe/python')  # <- Caffe path
    5. import caffe

    6. size = 200
    7. rz = 28.0
    8. ratio = rz/size

    9. draw_on = False
    10. last_pos = (0, 0)
    11. color = (255, 255, 255)
    12. radius = 8

    13. caffe.set_device(0)
    14. caffe.set_mode_gpu()
    15. net = caffe.Net('./deploy.prototxt', './deploy.caffemodel', caffe.TEST)

    16. fcl2 = open('./class_list.txt','r')
    17. fcl = open('./class_list_chn.txt','r')
    18. class_list = fcl.readlines()
    19. class_list_eng = fcl2.readlines()
    20. cls = []

    21. for line in class_list:
    22.         cls.append(line.split(' ')[0])

    23. screen = pygame.display.set_mode((size,size))

    24. def roundline(srf, color, start, end, radius=1):
    25.     pygame.draw.line(srf, color, start, end, radius)

    26. try:
    27.     pts = []
    28.     stage = 0
    29.     while True:
    30.         e = pygame.event.wait()
    31.         if e.type == pygame.QUIT:
    32.             raise StopIteration
    33.         if e.type == pygame.MOUSEBUTTONDOWN:
    34.             draw_on = True
    35.         if e.type == pygame.MOUSEBUTTONUP:
    36.             draw_on = False
    37.         if e.type == pygame.MOUSEMOTION:
    38.             if draw_on:
    39.                 pts = roundline(screen, color, e.pos, last_pos,  radius)
    40.             last_pos = e.pos
    41.         if e.type == pygame.KEYDOWN:
    42.                 if e.key == ord('q'):
    43.                         screen.fill((0,0,0))
    44.         data = pygame.image.tostring(screen, 'RGB')
    45.         img = np.fromstring(data, np.uint8).reshape(size,size,3)
    46.         img = cv2.resize(img,(28,28)).astype(float)/127.5-1
    47.                
    48.         img_caffe = np.array([img]).transpose(0, 3, 1, 2)
    49.         in_ = net.inputs[0]
    50.         net.forward_all(**{in_: img_caffe})
    51.         res = net.blobs['softmax'].data[0].copy()
    52.         res_label = np.argsort(res)[::-1][:5]
    53.         print('*******************')
    54.         chn = ''.join([i for i in cls[stage][:-1] if not i.isdigit()])
    55.         print('Draw %010s %s Stage:[%d] - Press Q to clear' % (chn, class_list_eng[stage].split(' ')[0], stage+1))
    56.         print('*******************')
    57.         for label in res_label:
    58.                        chn = ''.join([i for i in cls[label][:-1] if not i.isdigit()])
    59.                 print( '%s %s - %2.2f' % (chn,class_list_eng[label].split(' ')[0],res[label]))
    60.                 if label == stage:
    61.                         print('Congratulations! Stage pass [%d]' % stage)
    62.                         stage += 1
    63.         pygame.display.flip()

    64. except StopIteration:
    65.     pass

    66. pygame.quit()
    复制代码


    Python实现第三步:摄像头 + 真·手绘识别


    相信大部分人还是比起鼠标更喜欢用铅笔画,于是在此之上又做了一个新的扩展:用摄像头来识别白纸上的手绘去识别!

    整体逻辑是这样:


    在摄像头输出画面中间画一个小框

    把需要识别的手绘放在小框里

    利用图像处理方式抽取小框中的线条,使之变为类似数据库里的图片(黑底白线)

    剩下的和之前一样

    先上效果:

    8.png

    不错,0.99比之前用鼠标画的还高。

    再上代码:

    1. #coding: utf-8
    2. import pygame, random
    3. import cv2, numpy as np, sys, pdb

    4. sys.path.append('/home/vlab/SSD_Proj/caffe/python')
    5. import caffe

    6. size = 200
    7. rz = 28.0
    8. ratio = rz/size

    9. draw_on = False
    10. last_pos = (0, 0)
    11. color = (255, 255, 255)
    12. radius = 8

    13. caffe.set_device(0)
    14. caffe.set_mode_gpu()
    15. net = caffe.Net('./deploy.prototxt', './deploy.caffemodel', caffe.TEST)

    16. fcl2 = open('./class_list.txt','r')
    17. fcl = open('./class_list_chn.txt','r')
    18. class_list = fcl.readlines()
    19. class_list_eng = fcl2.readlines()
    20. cls = []

    21. for line in class_list:
    22.         cls.append(line.split(' ')[0])

    23. cap = cv2.VideoCapture(0)
    24. p1 = 120
    25. p2 = 45
    26. ROI_ratio = 0.2
    27. stage = 0
    28. while 1:
    29.         ret_val, input_image = cap.read()
    30.         sz = input_image.shape
    31.         cx = sz[0]/2
    32.         cy = sz[1]/2
    33.         ROI = int(sz[0]*ROI_ratio)
    34.         edges = cv2.Canny(input_image,p1,p2)
    35.         edges = cv2.cvtColor(edges,cv2.COLOR_GRAY2RGB)
    36.         print(edges.shape)
    37.         cropped = edges[cx-ROI:cx+ROI,cy-ROI:cy+ROI,:]
    38.         
    39.         kernel = np.ones((4,4),np.uint8)
    40.         cropped = cv2.dilate(cropped,kernel,iterations = 1)
    41.         cropped = cv2.resize(cropped,(28,28))/127.5 - 1
    42.         
    43.         img_caffe = np.array([cropped]).transpose(0, 3, 1, 2)
    44.         in_ = net.inputs[0]
    45.         net.forward_all(**{in_: img_caffe})
    46.         res = net.blobs['softmax'].data[0].copy()
    47.         res_label = np.argsort(res)[::-1][:5]
    48.         print('*******************')
    49.         chn = ''.join([i for i in cls[stage][:-1] if not i.isdigit()])
    50.         print('Draw %010s %s Stage:[%d]' % (chn, class_list_eng[stage].split(' ')[0], stage+1))
    51.         print('*******************')
    52.         for label in res_label:
    53.                    chn = ''.join([i for i in cls[label][:-1] if not i.isdigit()])
    54.                 print( '%s %s - %2.2f' % (chn,class_list_eng[label].split(' ')[0],res[label]))
    55.                 if label == stage:
    56.                         print('Congratulations! Stage pass [%d]' % stage)
    57.                         stage += 1
    58.                
    59.         cv2.rectangle(input_image, (cy-ROI, cx-ROI), (cy+ROI, cx+ROI),(255,255,0), 5)
    60.         cv2.imshow('ret',input_image)
    61.         cv2.imshow('ret2',cropped)
    62.         key = cv2.waitKey(1)
    63.         if key == ord('w'):
    64.                 p1 += 5
    65.         elif key == ord('s'):
    66.                 p1 -= 5
    67.         elif key == ord('e'):
    68.                 p2 += 5
    69.         elif key == ord('d'):
    70.                 p2 -= 5
    71.         elif key == ord('r'):
    72.                 ROI_ratio += 0.1
    73.         elif key == ord('f'):
    74.                 ROI_ratio -= 0.1
    75.         print([p1,p2])
    复制代码

    关于部分的图像处理算法:

    • 边缘检测使用了经典的Canny edge detector, 然后转回RGB的3个通道(虽说还是黑白,但因为训练时使用的是3通道)。
    • 接着使用了大小为4x4的Dilate滤波器,用来加粗线条。
    • 最后cv2.resize变成24x24即可


    下面就可以开始移植到树莓派了。

    为什么要移植树莓派呢?嵌入式开发除了挑战自我以外还有一个很大动机,就是摆脱笨重PC让算法跟着更轻便的主控放飞自我。毕竟在高达时代来临之前应该是不太可能见到可以背着大服务器满地跑的机器人了。

    比如说一个家用助教机器人搭载了这个游戏就可以用来教小孩画画了……或是认识英文单词。

    总之,让我们开始吧。


    树莓派实现第一步...
    其实考过去装个Caffe/OpenCV的CPU版本就可以直接跑了。


    但你发现事情并没这么简单,屏幕卡顿如同集成显卡吃鸡。目测大概帧率在2-3FPS吧。


    作为搭载嵌入式Ubuntu系统的树莓派,真正难题比起兼容性更多是计算力不足。实际上就算是当代顶配CPU也跑不动大部分神经网络。

    所以我借助了一个轻便的USB神经计算硬件。英特尔官方的NCS (Neural Computing Stick)虽说同样满足要求,但NCS比较更新缓慢,似乎不太会在功能方面作出比较大的拓展,毕竟Intel的重点是开发并卖其中Movidius芯片。于是我选择一款迭代更快的同样基于Movidius芯片、国人开发的新产品,叫角蜂鸟(似乎目前只在Intel大学生竞赛里使用,还没正式开卖,买的时候已经降价到600不到了)。

    角蜂鸟目前额外搭载一个树莓派摄像头,可以直接通过USB输出结果。
    9.jpg

    上面的是树莓派,下面的是角蜂鸟

    树莓派实现第二步

    按照说明安装角蜂鸟SDK之后就可以直接在Python调用了。

    在使用之前需要做一次模型转换,将Caffe转为半精度的Graph文件。

    这里直接把3个模式都整合了。不过目前角蜂鸟没法在内置摄像头和神经网络框架之间嵌入图像处理,只能通过外部取图再送回去重新识别,据说以后会开放更多功能。

    加上了外接的神经网络计算硬件,树莓派顿时没了计算压力,基本可以实时地跑了。

    总之,上代码:
    1. #coding: utf-8

    2. import pygame, random
    3. import cv2, numpy as np, sys

    4. sys.path.append('../api/')
    5. import hsapi as hs

    6. mode = 2
    7. # Mode 0 : Webcam mode
    8. # Mode 1 : Mouse drawing mode
    9. # Mode 2 : Sungem mode

    10. devices = hs.EnumerateDevices()
    11. if len(devices) == 0:
    12.     print('No devices found')
    13.     quit()

    14. device = hs.Device(devices[0])
    15. device.OpenDevice()

    16. graph_file_name = 'graphs/graph_sg'
    17. from datetime import datetime
    18. with open(graph_file_name, mode='rb') as f:
    19.     graph_in_memory = f.read()

    20. graph = device.AllocateGraph(graph_in_memory, 0.007843, -1.0)

    21. size = 200
    22. rz = 28.0
    23. ratio = rz/size

    24. draw_on = False
    25. last_pos = (0, 0)
    26. color = (255, 255, 255)
    27. radius = 8

    28. fcl2 = open('./misc/class_list.txt','r')
    29. fcl = open('./misc/class_list_chn.txt','r')
    30. class_list = fcl.readlines()
    31. class_list_eng = fcl2.readlines()
    32. cls = []

    33. for line in class_list:
    34.         cls.append(line.split(' ')[0])
    35. # Webcam mode
    36. if mode == 0:
    37.         cap = cv2.VideoCapture(0)
    38.         p1 = 120
    39.         p2 = 45
    40.         ROI_ratio = 0.1
    41.         stage = 0
    42.         while 1:
    43.                 ret_val, input_image = cap.read()
    44.                 sz = input_image.shape
    45.                 cx = int(sz[0]/2)
    46.                 cy = int(sz[1]/2)
    47.                 ROI = int(sz[0]*ROI_ratio)
    48.                 edges = cv2.Canny(input_image,p1,p2)
    49.                 edges = cv2.cvtColor(edges,cv2.COLOR_GRAY2RGB)
    50.                 print(edges.shape)
    51.                 cropped = edges[cx-ROI:cx+ROI,cy-ROI:cy+ROI,:]

    52.                 kernel = np.ones((4,4),np.uint8)
    53.                 cropped = cv2.dilate(cropped,kernel,iterations = 1)
    54.                 cropped = cv2.resize(cropped,(28,28))/127.5 - 1

    55.                 graph.LoadTensor(cropped.astype(np.float16), None)
    56.                 output, userobj = graph.GetResult()

    57.                 output_sort = np.argsort(output)[::-1]
    58.                 output_label = output_sort[:5]
    59.                 print('*******************')
    60.                 chn = ''.join([i for i in cls[stage][:-1] if not i.isdigit()])
    61.                 print('Draw %010s %s Stage:[%d]' % (chn, class_list_eng[stage], stage+1))
    62.                 print('*******************')
    63.                 cnt = 0
    64.                 for label in output_label:
    65.                         chn = ''.join([i for i in cls[label][:-1] if not i.isdigit()])
    66.                         string = '%s %s - %2.2f' % (chn,class_list_eng[label].split(' ')[0],output[label])
    67.                         print(string)
    68.                         cnt += 1
    69.                         if label == stage and output[label] > 0.1:
    70.                                 print('Congratulations! Stage pass [%d]' % stage)
    71.                                 stage += 1
    72.                         
    73.                 cv2.rectangle(input_image, (cy-ROI, cx-ROI), (cy+ROI, cx+ROI),(255,255,0), 5)
    74.                 rank = np.where(output_sort == stage)[0]
    75.                 string = '%s - Rank: %d' % (class_list_eng[stage].split(' ')[0:-1],rank)
    76.                
    77.                 cv2.putText(input_image, string, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, int(255*(1-rank/350.0)), int(255*rank/350.0)), 3)
    78.                
    79.                 cv2.imshow('ret',input_image)
    80.                 cv2.imshow('ret2',cropped)
    81.                 key = cv2.waitKey(1)
    82.                 if key == ord('w'):
    83.                         p1 += 5
    84.                 elif key == ord('s'):
    85.                         p1 -= 5
    86.                 elif key == ord('e'):
    87.                         p2 += 5
    88.                 elif key == ord('d'):
    89.                         p2 -= 5
    90.                 elif key == ord('r'):
    91.                         ROI_ratio += 0.1
    92.                 elif key == ord('f'):
    93.                         ROI_ratio -= 0.1
    94.                 print([p1,p2])
    95.                
    96. elif mode == 1:
    97.         screen = pygame.display.set_mode((size,size))
    98.         def roundline(srf, color, start, end, radius=1):
    99.                 pygame.draw.line(srf, color, start, end, radius)

    100.         try:
    101.                 pts = []
    102.                 stage = 0
    103.                 while True:
    104.                         e = pygame.event.wait()
    105.                         if e.type == pygame.QUIT:
    106.                                 raise StopIteration
    107.                         if e.type == pygame.MOUSEBUTTONDOWN:
    108.                                 draw_on = True
    109.                         if e.type == pygame.MOUSEBUTTONUP:
    110.                                 draw_on = False
    111.                         if e.type == pygame.MOUSEMOTION:
    112.                                 if draw_on:

    113.                                         pts = roundline(screen, color, e.pos, last_pos,  radius)
    114.                                 last_pos = e.pos
    115.                         if e.type == pygame.KEYDOWN:
    116.                                 if e.key == ord('q'):
    117.                                         screen.fill((0,0,0))
    118.                         data = pygame.image.tostring(screen, 'RGB')
    119.                         img = np.fromstring(data, np.uint8).reshape(size,size,3)
    120.                         img = cv2.resize(img,(28,28)).astype(float)/127.5-1
    121.                         graph.LoadTensor(img.astype(np.float16), None)
    122.                         output, userobj = graph.GetResult()
    123.                         output_label = np.argsort(output)[::-1][:5]
    124.                         print('*******************')
    125.                         chn = ''.join([i for i in cls[stage][:-1] if not i.isdigit()])
    126.                         print('Draw %010s %s Stage:[%d] - Press Q to clear' % (chn, class_list_eng[stage].split(' ')[0], stage+1))
    127.                         print('*******************')
    128.                         for label in output_label:
    129.                                 chn = ''.join([i for i in cls[label][:-1] if not i.isdigit()])
    130.                                 print( '%s %s - %2.2f' % (chn,class_list_eng[label].split(' ')[0],output[label]))
    131.                                 if label == stage:
    132.                                         print('Congratulations! Stage pass [%d]' % stage)
    133.                                         stage += 1
    134.                         pygame.display.flip()

    135.         except StopIteration:
    136.                 pass

    137.         pygame.quit()
    138. elif mode == 2:
    139.         stage = 0
    140.         p1 = 120
    141.         p2 = 45
    142.         ROI_ratio = 0.1
    143.         while 1:
    144.                 input_image = graph.GetImage()
    145.                 sz = input_image.shape
    146.                 output, userobj = graph.GetResult() # 这里的输出目前没用
    147.                
    148.                 sz = input_image.shape
    149.                 cx = int(sz[0]/2)
    150.                 cy = int(sz[1]/2)
    151.                 ROI = int(sz[0]*ROI_ratio)
    152.                 edges = cv2.Canny(input_image,p1,p2)
    153.                 edges = cv2.cvtColor(edges,cv2.COLOR_GRAY2RGB)
    154.                 print(edges.shape)
    155.                 cropped = edges[cx-ROI:cx+ROI,cy-ROI:cy+ROI,:]

    156.                 kernel = np.ones((4,4),np.uint8)
    157.                 cropped = cv2.dilate(cropped,kernel,iterations = 1)
    158.                 cropped = cv2.resize(cropped,(28,28))/127.5 - 1

    159.                 graph.LoadTensor(cropped.astype(np.float16), None)
    160.                 output, userobj = graph.GetResult()
    161.                
    162.                 output_sort = np.argsort(output)[::-1]
    163.                 output_label = output_sort[:5]
    164.                 print('*******************')
    165.                 chn = ''.join([i for i in cls[stage][:-1] if not i.isdigit()])
    166.                 print('Draw %010s %s Stage:[%d]' % (chn, class_list_eng[stage], stage+1))
    167.                 print('*******************')
    168.                 cnt = 0
    169.                 for label in output_label:
    170.                         chn = ''.join([i for i in cls[label][:-1] if not i.isdigit()])
    171.                         string = '%s %s - %2.2f' % (chn,class_list_eng[label].split(' ')[0],output[label])
    172.                         print(string)
    173.                         cnt += 1
    174.                         if label == stage and output[label] > 0.1:
    175.                                 print('Congratulations! Stage pass [%d]' % stage)
    176.                                 stage += 1
    177.                         
    178.                 cv2.rectangle(input_image, (cy-ROI, cx-ROI), (cy+ROI, cx+ROI),(255,255,0), 5)
    179.                 rank = np.where(output_sort == stage)[0]
    180.                 string = '%s - Rank: %d' % (class_list_eng[stage].split(' ')[0:-1],rank)
    181.                
    182.                 cv2.putText(input_image, string, (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, int(255*(1-rank/350.0)), int(255*rank/350.0)), 3)
    183.                
    184.                 cv2.imshow('ret',input_image)
    185.                 cv2.imshow('ret2',cropped)
    186.                 key = cv2.waitKey(1)
    187.                 if key == ord('w'):
    188.                         p1 += 5
    189.                 elif key == ord('s'):
    190.                         p1 -= 5
    191.                 elif key == ord('e'):
    192.                         p2 += 5
    193.                 elif key == ord('d'):
    194.                         p2 -= 5
    195.                 elif key == ord('r'):
    196.                         ROI_ratio += 0.1
    197.                 elif key == ord('f'):
    198.                         ROI_ratio -= 0.1
    199.                 print([p1,p2])
    复制代码

    大概就到这里,感谢观看,欢迎批评补充!

    ---------------------
    作者:MS2308
    来源:CSDN


    回复

    使用道具 举报

  • TA的每日心情
    难过
    2021-2-27 22:16
  • 签到天数: 1568 天

    连续签到: 1 天

    [LV.Master]伴坛终老

    发表于 2019-4-8 10:59:51 | 显示全部楼层
    赞的,果然是技术高超 522.jpg
    回复 支持 反对

    使用道具 举报

    您需要登录后才可以回帖 注册/登录

    本版积分规则

    关闭

    站长推荐上一条 /3 下一条



    手机版|小黑屋|与非网

    GMT+8, 2025-1-15 17:17 , Processed in 0.141883 second(s), 20 queries , MemCache On.

    ICP经营许可证 苏B2-20140176  苏ICP备14012660号-2   苏州灵动帧格网络科技有限公司 版权所有.

    苏公网安备 32059002001037号

    Powered by Discuz! X3.4

    Copyright © 2001-2024, Tencent Cloud.