0%

Keras 实现识别 12306 验证码

Python+Keras 实现识别12306验证码

前言

CNN卷积神经网络在图像识别领域大放异彩使图形验证码的防护作用大大下降,我们需要寻找一种新的方式来代替传统的验证方式,识别操作者究竟是屏幕前的真人还是网络中的机器人。

本文旨在通过识别12306验证码为例来验证,来探讨图形验证码在当下深度学习技术的支持下,究竟还有多少用武之地。

环境

  • tensorflow-gpu 2.1.0
  • keras 2.3.1
  • python 3.7.4
  • flask 1.1.1

分析

分析12306的验证码,我们发现识别分为两个部分。第一步识别文字部分,然后再根据文字部分识别图片部分。

以及通过大量的观察,我们发现验证码图片类型分为80个类。

获取样本

12306 验证码接口

1
https://kyfw.12306.cn/passport/captcha/captcha-image64

返回的结果是一串base64编码后的图片

我们可以编写一个脚本批量下载验证码。
编写 GetCapcha.py 并在当前目录下新建文件夹 originCaptcha 用于保存验证码图片

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import requests
import base64
import json


url = "https://kyfw.12306.cn/passport/captcha/captcha-image64"
headers = {
"User-Agent":"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36"
}

def getCaptcha(savePath):
try
req = requests.get(url=url, headers=headers)
imageBase64str = json.loads(req.content)["image"]
with open(imageBase64str, "wb") as f:
f.write(base64.b64decode(imageBase64str))
except:
pass

if __name__ = "__main__":

# 下载验证码的数量
downloadCount = 10000

# 验证码保存目录
saveDir = os.path.join(os.getcwd(), "originCaptcha")
for i in range(downloadCount):
savePath = os.path.join(saveDir, str(i) + ".jpg")
getCaptcha(savePath)
print("{}/{}".format(i+1, downloadCount))

拿到的验证码样本

分割样本

分割样本,分别提取label部分和image部分

这里使用了一个Python图像处理库PIL和进度显示库tqdm,安装方式。

1
2
pip install pillow
pip install tqdm

分割Label部分

新建 CutCaptcha.py 以及 cutedCaptcha 目录进行保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from PIL import Image, ImageFile
import os
import hashlib
import numpy as np
from tqdm import tqdm

def get_img_sha1(img):
'''
计算图像的sha1值为文件名
:param img:
:return:
'''
img = np.asarray(img)
sha1obj = hashlib.sha1()
sha1obj.update(img)
hash = sha1obj.hexdigest()
return str(hash).upper()

def cut_label(path):
'''
切割标签
:param path:
:return:
'''
image = Image.open(path)
x = 117
y = 0
w = 180
h = 30
return image.crop((x, y, w, h))

def cut_image(path):
'''
切割图片
:param path:
:return:
'''
image = Image.open(path)
space = 67 + 5
x0, y0, w0, h0 = 0*space+5, 0*space+41, 1*space, 0*space+41+67
x1, y1, w1, h1 = 0*space+5, 1*space+41, 1*space, 1*space+41+67
x2, y2, w2, h2 = 1*space+5, 0*space+41, 2*space, 0*space+41+67
x3, y3, w3, h3 = 1*space+5, 1*space+41, 2*space, 1*space+41+67
x4, y4, w4, h4 = 2*space+5, 0*space+41, 3*space, 0*space+41+67
x5, y5, w5, h5 = 2*space+5, 1*space+41, 3*space, 1*space+41+67
x6, y6, w6, h6 = 3*space+5, 0*space+41, 4*space, 0*space+41+67
x7, y7, w7, h7 = 3*space+5, 1*space+41, 4*space, 1*space+41+67
image0 = image.crop((x0, y0, w0, h0))
image1 = image.crop((x1, y1, w1, h1))
image2 = image.crop((x2, y2, w2, h2))
image3 = image.crop((x3, y3, w3, h3))
image4 = image.crop((x4, y4, w4, h4))
image5 = image.crop((x5, y5, w5, h5))
image6 = image.crop((x6, y6, w6, h6))
image7 = image.crop((x7, y7, w7, h7))
return image0, image1, image2, image3, image4, image5, image6, image7

# 初始化各个参数
ImageFile.LOAD_TRUNCATED_IMAGES = True
captcha_path_list = []
captcha_input_dir = os.path.join(os.getcwd(), "originCaptcha")
captcha_output_dir = os.path.join(os.getcwd(), "cutedCaptcha")
# 遍历验证码目录获取路径列表
for root, dirs, imgs in os.walk(captcha_input_dir):
for img in imgs:
captcha_path_list.append(os.path.join(root, img))

# 分割验证码
with tqdm(total=len(captcha_path_list), desc="Cut captcha") as pbar:
for captcha_input_path in captcha_path_list:
original_captcha_name = os.path.basename(captcha_input_path).split(".")[0]
captcha_output_dir_second = os.path.join(captcha_output_dir, original_captcha_name)
image = [ n for n in range(8)]
image_name = [ str(n) for n in range(8)]
image[0], image[1], image[2], image[3], image[4], image[5], image[6], image[7] = cut_image(captcha_input_path)
for i in range(len(image)):
image_name[i] = get_img_sha1(image[i])
for i, img in enumerate(image):
captcha_output_path = os.path.join(captcha_output_dir_second, str(image_name[i]) + ".jpg")
if not os.path.exists(captcha_output_dir_second):
os.makedirs(captcha_output_dir_second)
img.save(captcha_output_path)
captcha_output_label_name = "label_" + str(original_captcha_name) + ".jpg"
captcha_output_label_path = os.path.join(captcha_output_dir_second, captcha_output_label_name)
cut_label(captcha_input_path).save(str(captcha_output_label_path))
pbar.update()

每个目录保存一个切割后的验证码的label部分和image部分。

我们先把label部分提取出来放在一个目录

新建 CopyLabel.py 和 label 目录

1
2
3
4
5
6
7
8
9
10
11
import os
from tqdm import tqdm
import shutil

for root, dir, imgs in os.walk("cutedCaptcha"):
for img in imgs:
img_path = os.path.join(root, img)
if "label" in img:
new_img_path = os.path.join(os.getcwd(), "label", img)
shutil.copy(img_path, new_img_path)
print(new_img_path)

结果如下

接下来就到了最枯燥的时候了,我们需要预先标记一些样本,标记成这样的。

分割Image部分

图像部分同理,处理成如下格式

创建字符集

新建一个 Python Package 命名为 charset,并在里面新建一个label.py文件用于保存图像分类标签

label.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
LABEL = [
"蒸笼",
"蚂蚁",
"篮球",
"苍蝇拍",
"菠萝",
"薯条",
"耳塞",
"剪纸",
"绿豆",
"蜥蜴",
"热水袋",
"红酒",
"红枣",
"漏斗",
"公交卡",
"红豆",
"金字塔",
"狮子",
"文具盒",
"棉棒",
"海鸥",
"开瓶器",
"沙拉",
"电线",
"盘子",
"网球拍",
"烛台",
"老虎",
"药片",
"订书机",
"中国结",
"双面胶",
"茶几",
"挂钟",
"档案袋",
"创可贴",
"安全帽",
"樱桃",
"调色板",
"冰箱",
"钟表",
"黑板",
"卷尺",
"印章",
"手掌印",
"话梅",
"铃铛",
"牌坊",
"蜜蜂",
"茶盅",
"路灯",
"锦旗",
"雨靴",
"刺绣",
"本子",
"拖把",
"珊瑚",
"锅铲",
"海苔",
"锣",
"排风机",
"龙舟",
"鞭炮",
"仪表盘",
"跑步机",
"日历",
"毛线",
"打字机",
"电饭煲",
"海报",
"高压锅",
"沙包",
"护腕",
"口哨",
"电子秤",
"辣椒酱",
"航母",
"蜡烛",
"啤酒",
"风铃"
]

新建配置文件

新建一个 Config.py 文件方便我们把需要用的配置拿出来统一处理

Config.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
from charset import LABEL

# 设置配置文件的工作目录
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

# 项目设置
ITEM_NAME = "Label_12306_SmallCNN4"

# 图像尺寸以及放大倍率
IMAGE_MAGNIFICATION = 1
IMAGE_HEIGHT = 30 * IMAGE_MAGNIFICATION
IMAGE_WIDTH = 63 * IMAGE_MAGNIFICATION
IMAGE_CHANNEL = 1

# 字符集
CHAR_SET_LIST = LABEL
# 验证码长度
CAPTCHA_LENGTH = 1

# 训练设置
# INIT_EPOCHS 如果不是第一次训练, 可以设置初始化EPOCHS, 这样可以避免Tensorboard显示混乱
INIT_EPOCHS = 0
# INIT_EPOCHS = 31
EPOCHS = 200
BATCH_SIZE = 32
STEP = 500

# 模型保存路径
MODEL_SAVE_PATH = os.path.join(CURRENT_DIR, "model", ITEM_NAME + ".model")

# 数据集目录
DATA_SET_DIR = "这里写训练数据集的目录"
VALID_SET_DIR = "这里写验证数据集的目录"
TEST_SET_DIR = "这里写测试数据集的目录"

# tensor board 目录
TENSOR_BOARD_DIR = os.path.join(CURRENT_DIR, "tensorboard", ITEM_NAME)

# CSV 文件保存路径, 用于可视化训练过程
CSV_SAVE_PATH = os.path.join(CURRENT_DIR, "csv", ITEM_NAME + ".csv")

# 模型结构图
MODEL_STRUCTURE_IMAGE_DIR = os.path.join(CURRENT_DIR, "model_structure_image")

定义模型

新建 net_model 包并在里面新建 SmallCNN.py 用于定义一个简单的CNN模型

SmallCNN.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import keras
import numpy as np
from keras import backend as K
from keras.regularizers import l2
from keras.models import Model
from keras.layers import Dropout, Input, Lambda, Activation, Conv2D, MaxPooling2D, ZeroPadding2D, Reshape, Concatenate, Flatten, Dense, BatchNormalization, MaxPool2D
from keras.models import Sequential
from keras.optimizers import SGD, Adam
from keras import optimizers
import Config


_l2_reg = 0.0005 # L2 正则化

def _conv2d(input, filters, kernel_size, name):
'''
创建卷积层
:param input:
:param filters: 整数, 输出空间的维度, 卷积中滤波器的输出数量
:param kernel_size: 一个整数, 或者是2个整数表示的元组或列表, 指明2D卷积窗口的宽度和高度, 可以是一个整数, 为所有空间维度指定相同的值
:param name: 卷积层的名字
:return: 卷积层
'''
return Conv2D(
filters,
kernel_size,
activation="relu",
padding="same",
kernel_initializer="he_normal",
kernel_regularizer=l2(_l2_reg),
name=name
)(input)

def _pooling(input, name):
'''
池化层
:param name:
:return:
'''
return MaxPooling2D(
pool_size=(1, 2),
strides=(1, 1),
padding="same",
name=name
)(input)

def build(layres=2):
'''
构建一个简单CNN网络
:return:
'''
input_tensor = Input((Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH, Config.IMAGE_CHANNEL))
x = input_tensor
layresList = [2 for _ in range(layres)]
for i, n_cnn in enumerate(layresList):
for j in range(n_cnn):
x = Conv2D(32 * 2 ** min(i, 3), kernel_size=3, padding='same', kernel_initializer='he_uniform')(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling2D(2)(x)
x = Dropout(rate=0.5)(x)

x = Flatten()(x)
x = [
Dense(len(Config.CHAR_SET_LIST), activation='softmax', name='c%d' % (i + 1))(x)
for i in range(Config.CAPTCHA_LENGTH)
]
model = Model(inputs=input_tensor, outputs=x, name="SmallCNN-Layers-{}".format(layres))
model.summary()
model.compile(loss='categorical_crossentropy',
optimizer=Adam(1e-3, amsgrad=True),
metrics=['accuracy'])
return model

if __name__ == '__main__':

model = build(layres=4)

新建一个工具类

新建一个工具类用于图像处理的操作

Utils.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from PIL import Image
import os
import Config
import shutil
import numpy as np
from io import BytesIO


class CheckInvalid(object):

@classmethod
def delete(cls, dstDir):
pathList = []
for root, dirs, files in os.walk(dstDir):
for file in files:
pathList.append(os.path.join(root, file))
total = len(pathList)
removed = 0
for i, imagePath in enumerate(pathList):
try:
Image.open(imagePath).convert("RGB")
except OSError:
os.remove(imagePath)
print("remove ---> ", imagePath)
removed += 1
print("{}/{}".format(i+1, total))
print("remove number ", removed)

class Preprocess(object):

@classmethod
def convertImageToArrayByChannel(cls, image):
'''
input image (PIL) object convert to array by channel
:param image: image obj
:return: numpy array obj
'''
if Config.IMAGE_CHANNEL == 3:
return np.array(image.convert("RGB"))
if Config.IMAGE_CHANNEL == 1:
imageArray = np.array(image.convert("L"))
imageArray = np.expand_dims(imageArray, axis=2)
return imageArray
return np.array(image.convert("RGB"))

@classmethod
def convertImageFormat(cls, image, format="PNG"):
'''
input image obj, convert to png format
:param image: image obj
:return: image obj
'''
if image is None:
raise TypeError("image object is none.")
with BytesIO() as imageIO:
image.save(imageIO, format=format)
imageByte = imageIO.getvalue()
return Image.open(BytesIO(imageByte))


if __name__ == '__main__':

pass




读取数据

创建一个读取数据的类

LoadData.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from keras.utils import Sequence
from keras import backend as K
import numpy as np
import Config
import math
import random
import os
import cv2
from PIL import Image
from io import BytesIO
import Utils


class LoadTrainData(Sequence):

def __init__(self,
dataSetDir,
characters=Config.CHAR_SET_LIST,
batchSize=Config.BATCH_SIZE,
steps=Config.STEP,
n_len=Config.CAPTCHA_LENGTH,
width=Config.IMAGE_WIDTH,
height=Config.IMAGE_HEIGHT,
isGRU=False):

self.dataSetDir = dataSetDir
self.characters = characters
self.batchSize = batchSize
self.steps = steps
self.n_len = n_len
self.width = width
self.height = height
self.n_class = len(characters)
# self.generator = ImageCaptcha(width=width, height=height)
self.dataSetlist = self._getDataSetList()
self.isGRU = isGRU

def _parseImageText(self, image_path):
'''
解析验证码的标签.
验证码格式为 --> 标签_XXX.png
XXX为一串随机数, 后缀任意格式, PIL 或 cv2 会自动解析
:param image_path: c:/123_XXX.jpg
:return:XXX
'''
return os.path.basename(image_path).split(".")[0].split("_")[0].strip()

def _getDataSetList(self):
'''
获取全部验证码路径用于喂数据
:return:
'''
dataSetList = []
for root, dirs, files in os.walk(self.dataSetDir):
for file in files:
imagePath = os.path.join(root, file)
dataSetList.append({
"tag": self._parseImageText(imagePath),
"image_path": imagePath
})
return dataSetList

def __len__(self):
return self.steps

def __getitem__(self, idx):
x = np.zeros((self.batchSize, self.height, self.width, Config.IMAGE_CHANNEL), dtype=np.float32)
y = [np.zeros((self.batchSize, self.n_class), dtype=np.uint8) for _ in range(self.n_len)]
for i in range(self.batchSize):
entity = random.choice(self.dataSetlist)
tag = entity["tag"]
# 加载图片
image = Image.open(entity["image_path"]).convert("RGB")
# 统一图片格式, 默认: PNG
image = Utils.Preprocess.convertImageFormat(image)
# 缩放图片到指定尺寸
image = image.resize((Config.IMAGE_WIDTH, Config.IMAGE_HEIGHT), Image.ANTIALIAS)
# 检测通道数并转换
imageArray = Utils.Preprocess.convertImageToArrayByChannel(image)
# 把色度值压缩到 0-1 区间
x[i] = imageArray / 255.0
# 编码验证码标签
for j in range(Config.CAPTCHA_LENGTH):
y[j][i, :] = 0
y[j][i, self.characters.index(tag)] = 1

return x, y

class LoadPredictData():

def __init__(self):
pass

def decode(self, y, isIndefiniteLength=False):
'''
将向量列表解码成字符
:param y:
:return:
'''
if isIndefiniteLength:
y = np.array(y)
y = np.argmax(y, axis=2)[:, 0]
text = ""
for x in y:
if x < len(Config.CHAR_SET_LIST):
text += Config.CHAR_SET_LIST[x]
return text
else:
y = np.array(y)
y = np.argmax(y, axis=2)[:, 0]
return ''.join([Config.CHAR_SET_LIST[x] for x in y])

def _parseImageText(self, imagePath):
'''
解析验证码的标签.
验证码格式为 --> XXX_标签.png
XXX为一串随机数, 后缀任意格式, PIL 或 cv2 会自动解析
:param image_path:
:return:
'''
image_name = os.path.basename(imagePath)
tag = image_name.split(".")[0].split("_")[0].strip()
return tag

# 多通道
def getPredictDataFromPath(self, imagePath):
x = np.zeros((1, Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH, Config.IMAGE_CHANNEL), dtype=np.float32)
image = Image.open(imagePath).convert("RGB")
# 统一格式, 训练和预测的格式不同会导致准确率急剧下降(我也不知道为啥)
image = Utils.Preprocess.convertImageFormat(image)
# 缩放图片到指定尺寸
image = image.resize((Config.IMAGE_WIDTH, Config.IMAGE_HEIGHT), Image.ANTIALIAS)
# 检测通道数并转换
imageArray = Utils.Preprocess.convertImageToArrayByChannel(image)
# 把色度值压缩到 0-1 区间
x[0] = imageArray / 255.0
return x

def getPredictDataFromByte(self, imageByte):
x = np.zeros((1, Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH, Config.IMAGE_CHANNEL), dtype=np.float32)
image = Image.open(BytesIO(imageByte)).convert("RGB")
image = Utils.Preprocess.convertImageFormat(image)
image = image.resize((Config.IMAGE_WIDTH, Config.IMAGE_HEIGHT), Image.ANTIALIAS)
imageArray = Utils.Preprocess.convertImageToArrayByChannel(image)
x[0] = imageArray / 255.0
return x

def getPredictDataFromIO(self, imageIO):
x = np.zeros((1, Config.IMAGE_HEIGHT, Config.IMAGE_WIDTH, Config.IMAGE_CHANNEL), dtype=np.float32)
image = Image.open(imageIO).convert("RGB")
image = Utils.Preprocess.convertImageFormat(image)
image = image.resize((Config.IMAGE_WIDTH, Config.IMAGE_HEIGHT), Image.ANTIALIAS)
imageArray = Utils.Preprocess.convertImageToArrayByChannel(image)
x[0] = imageArray / 255.0
return x

训练模型

新建 Training.py 用于训练模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import Config
from keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from keras.models import load_model
import os
from LoadData import LoadTrainData
import net_model.ResNet as ResNet
import net_model.SmallCNN as SmallCNN

class Training(object):

def __init__(self):
# 构建模型
# self.model = SmallCNN.build(layres=4)
# self.model = ResNet.build(numberOfLayers=32)
self.model = SmallCNN.build(layres=4)

# 读取已构建的模型继续训练
# self.model = load_model(Config.MODEL_SAVE_PATH)

# 获取训练和验证数据序列
self.trainSetSequence = LoadTrainData(Config.DATA_SET_DIR)
self.validSetSequence = LoadTrainData(Config.VALID_SET_DIR)

def train(self):
# 创建回调
callbacks = [
# 当监测值不再改善时,该回调函数将中止训练
EarlyStopping(patience=20),
# 将epoch的训练结果保存在csv文件中,支持所有可被转换为string的值,包括1D的可迭代数值如np.ndarray.
CSVLogger(Config.CSV_SAVE_PATH),
# 该回调函数将在每个epoch后保存模型到filepath
ModelCheckpoint(Config.MODEL_SAVE_PATH, save_best_only=True),
# 当评价指标不在提升时,减少学习率
ReduceLROnPlateau(factor=0.1, patience=5),
# 该回调函数是一个可视化的展示器
TensorBoard(Config.TENSOR_BOARD_DIR)
]
# 训练
self.model.fit_generator(
generator=self.trainSetSequence,
# 训练多少步, step = data set size / batch size
steps_per_epoch=Config.STEP,
# 迭代多少遍数据集
epochs=Config.EPOCHS,
validation_data=self.validSetSequence,
validation_steps=64,
callbacks=callbacks,
workers=2,
use_multiprocessing=False,
initial_epoch=Config.INIT_EPOCHS
)


def prediction(self):
pass

def evaluation(self):
pass

if __name__ == '__main__':

task = Training()
task.train()

在我的GTX980ti上训练了20个epoch准确率达到0.99

测试模型

新建 Prediction.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import Config
from keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from keras.models import load_model
import os
from LoadData import LoadPredictData
import time
import numpy as np
import random


class Prediction(object):

def __init__(self):
# 读取已构建的模型
self.model = load_model(Config.MODEL_SAVE_PATH)
# self.model = load_model("./model/itgood_resnet20.model")
self.loadData = LoadPredictData()


def predictFromImagePath(self, imagePath):
X = self.loadData.getPredictDataFromPath(imagePath)
y_pred = self.model.predict(X)
origin = os.path.basename(imagePath).split(".")[0].split("_")[0].strip()
if Config.CAPTCHA_LENGTH == 1:
predict = self.loadData.decode([y_pred])
else:
predict = self.loadData.decode(y_pred, isIndefiniteLength=True)
print("original:{0: <10} prediction:{1: <10}".format(origin, predict))
if origin.lower() == predict.lower():
return True
return False

def predictFromImageByte(self, imageByte):
X = self.loadData.getPredictDataFromByte(imageByte)
y_pred = self.model.predict(X)
origin = os.path.basename(imageByte).split(".")[0].split("_")[-1].strip()
if Config.CAPTCHA_LENGTH == 1:
predict = self.loadData.decode([y_pred])
else:
predict = self.loadData.decode(y_pred)
print("original:{} prediction:{}".format(origin, predict))
return self.loadData.decode(y_pred)

def predictFromImageDir(self, imageDir):
imageNameList = os.listdir(imageDir)
total = len(imageNameList)
error = 0
for i, imageName in enumerate(imageNameList):
imagePath = os.path.join(imageDir, imageName)
if not self.predictFromImagePath(imagePath):
error += 1
correct = total - error
print("sample total: {} accuracy: {:.5f}".format(len(imageNameList), correct / total))


if __name__ == '__main__':

t1 = time.time()
task = Prediction()

task.predictFromImageDir(Config.TEST_SET_DIR)

准确率基本为1.0

训练图像部分

同训练Label部分一样。