
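Deploying a Keras Model with Flask and Providing a RESTful API

Preface

Once we have trained a model, we may want to publish it and expose a RESTful API for outside use, and any web framework can do that job. I have used Django before but found it somewhat cumbersome, so I prefer Flask: its route-decorator style is very similar to the annotations in Java's Spring Boot. Below we deploy the model we trained earlier for recognizing 12306 captchas.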

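Create the project

In PyCharm, create a new Flask project.

  • Create a controller package containing three subpackages, charset, models and utils, which hold the character set, the models and the utility classes respectively.

  • In the charset package, create a Label.py module containing the character set used when training the model.

  • Copy the trained models into the models package.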

Create the utility modules

  • Create a few utility modules under the utils package.

  • Base64.py

    import base64


    class Base64(object):

        @classmethod
        def convertToBytes(cls, imageBase64: str) -> bytes:
            return base64.b64decode(imageBase64)
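The decoder above calls `base64.b64decode` in its default, lenient mode, which silently discards characters outside the Base64 alphabet. If you would rather reject malformed input early, a minimal sketch using the standard library's `validate=True` flag could look like this; `decodeStrict` is a hypothetical helper, not part of the project.

```python
import base64
import binascii
from typing import Optional


def decodeStrict(imageBase64: str) -> Optional[bytes]:
    """Hypothetical helper: decode Base64 strictly, returning None on bad input."""
    try:
        # validate=True raises binascii.Error on any non-alphabet character
        return base64.b64decode(imageBase64, validate=True)
    except (binascii.Error, ValueError):
        return None


# "aGVsbG8=" is the Base64 encoding of b"hello"
print(decodeStrict("aGVsbG8="))        # b'hello'
print(decodeStrict("aGVs bG8="))       # None: the space is rejected in strict mode
print(base64.b64decode("aGVs bG8="))   # b'hello': the lenient default drops the space
```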
  • ConvertFormat.py

    from io import BytesIO

    import numpy as np
    from PIL import Image


    class ConvertFormat(object):

        @classmethod
        def convertImageToArrayByChannel(cls, image, channel: int):
            '''
            Convert a PIL image to a numpy array with the given number of channels.
            :param image: PIL image object
            :param channel: 1 for grayscale, 3 for RGB
            :return: numpy array of shape (height, width, channel)
            '''
            if channel == 1:
                # grayscale gives (height, width); add a channel axis -> (height, width, 1)
                imageArray = np.array(image.convert("L"))
                return np.expand_dims(imageArray, axis=2)
            # channel == 3, and the fallback for any other value: RGB
            return np.array(image.convert("RGB"))

        @classmethod
        def convertImageFormat(cls, image, format="PNG"):
            '''
            Re-encode a PIL image in the given format (PNG by default).
            :param image: PIL image object
            :param format: target format name understood by PIL
            :return: new PIL image object
            '''
            if image is None:
                raise TypeError("image object is none.")
            with BytesIO() as imageIO:
                image.save(imageIO, format=format)
                imageByte = imageIO.getvalue()
            return Image.open(BytesIO(imageByte))
  • HashUtils.py

    import hashlib


    class HashUtils(object):

        @staticmethod
        def md5(string: str):
            md5 = hashlib.md5(string.encode("utf-8"))
            return md5.hexdigest()

        @staticmethod
        def sha1(string: str):
            sha1 = hashlib.sha1(string.encode("utf-8"))
            return sha1.hexdigest()

        @staticmethod
        def sha256(string: str):
            sha256 = hashlib.sha256(string.encode("utf-8"))
            return sha256.hexdigest()
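A quick way to confirm these helpers behave as expected is to check them against the well-known "abc" test vectors from RFC 1321 (MD5) and FIPS 180 (SHA-1, SHA-256). The block below is a standalone copy of the class for that purpose only.

```python
import hashlib


class HashUtils(object):
    """Standalone copy of the project's HashUtils for a quick self-test."""

    @staticmethod
    def md5(string: str) -> str:
        return hashlib.md5(string.encode("utf-8")).hexdigest()

    @staticmethod
    def sha1(string: str) -> str:
        return hashlib.sha1(string.encode("utf-8")).hexdigest()

    @staticmethod
    def sha256(string: str) -> str:
        return hashlib.sha256(string.encode("utf-8")).hexdigest()


print(HashUtils.md5("abc"))     # 900150983cd24fb0d6963f7d28e17f72
print(HashUtils.sha1("abc"))    # a9993e364706816aba3e25717850c26c9cd0d89d
print(HashUtils.sha256("abc"))  # ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
```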
  • RequestUtils.py

    from typing import Any

    from flask import Request

    from constants import ResponseStatus
    from exception import RequestDataFormatException


    class RequestUtils(object):

        @staticmethod
        def checkPostDataIsJson(req: Request):
            """Check that the POST body is JSON.

            Args:
                req: flask request object

            Raises:
                RequestDataFormatException: if the body is missing or is not JSON.
            """
            # get_json(silent=True) returns None instead of raising when the body is
            # not valid JSON or the Content-Type is not application/json
            # (newer Flask versions raise a 415 error on req.json in that case)
            if not req.get_json(silent=True):
                raise RequestDataFormatException(
                    ResponseStatus.FAILED,
                    "data invalid, not json.",
                    400
                )

        @staticmethod
        def deleteNoneAttributes(obj: Any) -> Any:
            """Delete every attribute whose value is None from an instance.

            Args:
                obj: instance

            Returns: the same instance, modified in place
            """
            if obj is None:
                return obj
            # iterate over a copy of the keys, since attributes are deleted in the loop
            for attributeName in list(obj.__dict__.keys()):
                if getattr(obj, attributeName) is None:
                    delattr(obj, attributeName)
            return obj
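The `list(...)` snapshot in `deleteNoneAttributes` matters: deleting entries from `__dict__` while iterating over its live key view raises `RuntimeError: dictionary changed size during iteration`. A standalone demo of the function, using a hypothetical `Payload` class:

```python
from typing import Any


def deleteNoneAttributes(obj: Any) -> Any:
    """Standalone copy of RequestUtils.deleteNoneAttributes."""
    if obj is None:
        return obj
    # list(...) takes a snapshot of the keys so we can delete while looping
    for attributeName in list(obj.__dict__.keys()):
        if getattr(obj, attributeName) is None:
            delattr(obj, attributeName)
    return obj


class Payload(object):
    """Hypothetical object used only for this demo."""

    def __init__(self):
        self.status = "success"
        self.message = None
        self.result = [0, 3]


print(vars(deleteNoneAttributes(Payload())))  # {'status': 'success', 'result': [0, 3]}
```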

  • ResponseUtils.py

    from typing import Any

    from flask import make_response


    class ResponseUtils(object):

        @staticmethod
        def responseJson(obj: object, headers: dict = None) -> Any:
            """Wrap obj in a Flask response with a JSON content type.

            The object must override __str__ so that it returns a JSON
            string, for example:

                {
                    "status": "...",
                    ...
                }

            Extra response headers can be passed in via headers.

            Args:
                obj: any object whose __str__ returns JSON
                headers: optional extra response headers

            Returns: flask response object
            """
            response = make_response(str(obj))
            response.headers["Content-Type"] = "application/json"
            if headers:
                for k, v in headers.items():
                    response.headers[str(k)] = v
            return response

Create the controller modules

  • Create the controller-layer modules inside the controller package.

  • CutCaptcha.py splits the original captcha into the label and the eight candidate images

    from io import BytesIO
    from typing import Tuple

    from PIL import Image

    from entities import ResultCutCaptcha


    class CutCaptcha(object):

        @classmethod
        def _imageToBytes(cls, image: Image.Image) -> bytes:
            """Serialize a PIL image to PNG bytes."""
            imageIO = BytesIO()
            image.save(imageIO, format="PNG")
            return imageIO.getvalue()

        @classmethod
        def _cutLabel(cls, imageByte: bytes) -> bytes:
            """Crop the label (the text prompt at the top) and return it as PNG bytes."""
            label = Image.open(BytesIO(imageByte)).convert("RGB")
            # PIL crop box is (left, upper, right, lower): a 63x30 region
            left, upper, right, lower = 117, 0, 180, 30
            label = label.crop((left, upper, right, lower))
            return cls._imageToBytes(label)

        @classmethod
        def _cutImage(cls, imageByte: bytes) -> Tuple[bytes, ...]:
            """Crop the 8 candidate images (67x67 each) and return them as PNG bytes.

            Ids run column by column: 0 2 4 6 on the top row, 1 3 5 7 on the bottom row.
            """
            image = Image.open(BytesIO(imageByte)).convert("RGB")
            space = 67 + 5  # tile size plus the gap between tiles
            results = []
            for col in range(4):
                for row in range(2):
                    left = col * space + 5
                    upper = row * space + 41
                    right = (col + 1) * space
                    lower = row * space + 41 + 67
                    results.append(cls._imageToBytes(image.crop((left, upper, right, lower))))
            return tuple(results)

        @classmethod
        def cut(cls, imageByte: bytes) -> ResultCutCaptcha:
            return ResultCutCaptcha(
                label=cls._cutLabel(imageByte),
                images=cls._cutImage(imageByte)
            )
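Crop boxes are easy to get wrong, so it helps to print them. This stdlib-only sketch recomputes the label box and the eight tile boxes from the same constants used in `CutCaptcha` (a 67-pixel tile, a 72-pixel step, a left offset of 5 and a top offset of 41); `tileBoxes` and the constant names are hypothetical.

```python
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # PIL crop order: (left, upper, right, lower)

TILE = 67          # tile width and height in pixels
SPACE = TILE + 5   # step from one tile to the next (tile plus gap)
LEFT = 5           # x offset of the first column
TOP = 41           # y offset of the first row

LABEL_BOX: Box = (117, 0, 180, 30)  # the 63x30 text prompt


def tileBoxes() -> List[Box]:
    """Return the 8 tile boxes in id order 0..7 (column by column)."""
    boxes = []
    for col in range(4):
        for row in range(2):
            left = col * SPACE + LEFT
            upper = row * SPACE + TOP
            boxes.append((left, upper, (col + 1) * SPACE, upper + TILE))
    return boxes


for i, box in enumerate(tileBoxes()):
    print(i, box, "size:", box[2] - box[0], "x", box[3] - box[1])
```

Every tile comes out 67x67, matching the `MARK_CONFIG["image"]` input size, and the label box is 63x30, matching `MARK_CONFIG["label"]`.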
  • DrawMarkedResult.py draws the marking result onto the captcha

    from io import BytesIO

    from PIL import Image, ImageDraw

    import Config
    from constants import IMAGE_POSITION_COORDINATES


    class DrawMarkedResult(object):

        @classmethod
        def _sumCenterPoint(cls, a: tuple, b: tuple) -> tuple:
            """Midpoint of the box with corners a and b."""
            x = int((b[0] - a[0]) // 2 + a[0])
            y = int((b[1] - a[1]) // 2 + a[1])
            # +30 compensates the y-axis offset subtracted by the caller
            return (x, y + 30)

        @classmethod
        def _imageToBytes(cls, image: Image.Image) -> bytes:
            imageIO = BytesIO()
            image.save(imageIO, format="JPEG")
            return imageIO.getvalue()

        @classmethod
        def checkmark(cls, imageByte: bytes, ids: list) -> bytes:
            """Draw a checkmark on every marked tile of the original image."""
            image = Image.open(BytesIO(imageByte)).convert("RGB")
            draw = ImageDraw.Draw(image)
            color = Config.DRAW_CONFIG["color"]
            lineSize = Config.DRAW_CONFIG["lineSize"]
            for idx in ids:
                (x0, y0), (x1, y1) = IMAGE_POSITION_COORDINATES[idx]
                centerPoint = cls._sumCenterPoint((x0, y0 - 30), (x1, y1 - 30))
                # bottom vertex of the checkmark
                tickMidpoint = (centerPoint[0] + 2, centerPoint[1] + 5)
                # a and b are the left and right ends of the checkmark
                a = (centerPoint[0] - 5, centerPoint[1] - 5)
                b = (centerPoint[0] + 10, centerPoint[1] - 10)
                draw.line([a, tickMidpoint], fill=color, width=lineSize)
                draw.line([tickMidpoint, b], fill=color, width=lineSize)
            return cls._imageToBytes(image)

        @classmethod
        def rectangle(cls, imageByte: bytes, ids: list) -> bytes:
            """Draw a rectangle around every marked tile of the original image."""
            image = Image.open(BytesIO(imageByte)).convert("RGB")
            draw = ImageDraw.Draw(image)
            for idx in ids:
                draw.rectangle(
                    [IMAGE_POSITION_COORDINATES[idx][0], IMAGE_POSITION_COORDINATES[idx][1]],
                    outline=Config.DRAW_CONFIG["color"],
                    width=Config.DRAW_CONFIG["lineSize"]
                )
            return cls._imageToBytes(image)
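The checkmark geometry is slightly indirect: `checkmark` subtracts 30 from the y coordinates and `_sumCenterPoint` adds it back, so the tick ends up centered on the tile. This pure-Python sketch repeats the arithmetic for tile 0 so the three points can be inspected without drawing anything; `centerPoint` and `checkmarkPoints` are hypothetical names.

```python
from typing import Sequence, Tuple

Point = Tuple[int, int]

# Corners of tile 0, copied from IMAGE_POSITION_COORDINATES: [(left, top), (right, bottom)]
TILE0 = [(5, 12 + 30), (72, 79 + 30)]


def centerPoint(a: Point, b: Point) -> Point:
    """Same arithmetic as DrawMarkedResult._sumCenterPoint: midpoint, then +30 on y."""
    x = (b[0] - a[0]) // 2 + a[0]
    y = (b[1] - a[1]) // 2 + a[1]
    return (x, y + 30)


def checkmarkPoints(corners: Sequence[Point]) -> Tuple[Point, Point, Point]:
    """Return (left end, bottom vertex, right end) of the tick drawn by checkmark()."""
    (x0, y0), (x1, y1) = corners
    # checkmark() subtracts 30 here, and centerPoint() adds it back
    cx, cy = centerPoint((x0, y0 - 30), (x1, y1 - 30))
    return (cx - 5, cy - 5), (cx + 2, cy + 5), (cx + 10, cy - 10)


print(checkmarkPoints(TILE0))  # ((33, 70), (40, 80), (48, 65))
```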
  • ErrorHandler.py provides global error handling

    """
    Register error handlers in the Flask framework.
    """
    from flask import Blueprint, jsonify

    from exception import RequestDataFormatException
    from exception import CaptchaValidException
    from exception import GetCaptchaException


    errorHandler = Blueprint("errorHandler", __name__)


    def _toResponse(error):
        """Serialize a project exception to JSON with its HTTP status code."""
        response = jsonify(error.toDict())
        # Flask reads status_code; assigning statusCode would be silently ignored
        response.status_code = error.statusCode
        return response


    @errorHandler.app_errorhandler(RequestDataFormatException)
    def handleRequestDataFormatException(error):
        return _toResponse(error)


    @errorHandler.app_errorhandler(CaptchaValidException)
    def handleCaptchaValidException(error):
        return _toResponse(error)


    @errorHandler.app_errorhandler(GetCaptchaException)
    def handleGetCaptchaException(error):
        return _toResponse(error)
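Flask takes the HTTP status of a response from `response.status_code`; an attribute named `statusCode` is simply ignored and the client gets 200. A minimal check with Flask's test client, assuming Flask is installed; the exception class and the `/boom` route are hypothetical stand-ins for the project's own.

```python
from flask import Blueprint, Flask, jsonify


class CaptchaValidException(Exception):
    """Hypothetical stand-in for the project's exception class."""

    def __init__(self, message: str, statusCode: int = 400):
        super().__init__(message)
        self.message = message
        self.statusCode = statusCode

    def toDict(self) -> dict:
        return {"status": "failed", "message": self.message, "statusCode": self.statusCode}


errorHandler = Blueprint("errorHandler", __name__)


@errorHandler.app_errorhandler(CaptchaValidException)
def handleCaptchaValidException(error):
    response = jsonify(error.toDict())
    # this is the attribute Flask actually sends as the HTTP status
    response.status_code = error.statusCode
    return response


app = Flask(__name__)
app.register_blueprint(errorHandler)


@app.route("/boom")
def boom():
    # hypothetical route that always fails validation
    raise CaptchaValidException("captcha is invalid, broken data.")


resp = app.test_client().get("/boom")
print(resp.status_code, resp.get_json())
```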
  • GetCaptcha.py fetches a captcha from 12306 for testing

    import requests

    from constants import ResponseStatus
    from exception import GetCaptchaException


    class GetCaptcha(object):

        @classmethod
        def getCaptchaFrom12306(cls) -> str:
            '''
            Fetch a captcha from 12306.
            :return: base64 string
            '''
            url = "https://kyfw.12306.cn/passport/captcha/captcha-image64"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.102 Safari/537.36"
            }
            # a timeout keeps a slow upstream from blocking the worker indefinitely
            req = requests.get(url=url, headers=headers, timeout=10)
            jsonData = req.json()
            # 12306 answers with this message ("system maintenance time") during maintenance
            if jsonData.get("result_message") == "系统维护时间":
                # BaseException requires both status and message
                raise GetCaptchaException(
                    status=ResponseStatus.FAILED,
                    message="system maintenance, get captcha failed."
                )
            return jsonData["image"]
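Separating the HTTP call from the parsing makes the maintenance check testable offline. A sketch, assuming the response shape the code above relies on (`result_message` and `image` keys); `parseCaptchaResponse`, the stand-in exception and the sample bodies are hypothetical.

```python
import json


class GetCaptchaException(Exception):
    """Hypothetical stand-in for the project's exception class."""


# The message 12306 returns during maintenance (it means "system maintenance time")
MAINTENANCE_MESSAGE = "系统维护时间"


def parseCaptchaResponse(body: str) -> str:
    """Return the Base64 image from a captcha-image64 style JSON body."""
    data = json.loads(body)
    if data.get("result_message") == MAINTENANCE_MESSAGE:
        raise GetCaptchaException("system maintenance, get captcha failed.")
    image = data.get("image")
    if not image:
        raise GetCaptchaException("response contains no image.")
    return image


# Made-up sample body, shaped like the fields GetCaptcha reads
print(parseCaptchaResponse('{"result_message": "ok", "image": "aGVsbG8="}'))  # aGVsbG8=
```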
  • MarkCaptcha.py produces the marking result

    from keras.models import load_model

    import Config
    from controller.utils import Base64
    from .CutCaptcha import CutCaptcha
    from .PreprocessingCaptcha import PreprocessingCaptcha


    class MarkCaptcha(object):

        def __init__(self):
            # load both models once so that every request reuses them
            self._labelModel = load_model(Config.MODEL_PATH["label"])
            self._imageModel = load_model(Config.MODEL_PATH["image"])

        def _label(self, imageByte: bytes) -> str:
            """Classify the label (the text prompt)."""
            x = PreprocessingCaptcha.loadData(
                imageByte=imageByte,
                height=Config.MARK_CONFIG["label"]["height"],
                width=Config.MARK_CONFIG["label"]["width"],
                channel=Config.MARK_CONFIG["label"]["channel"],
            )
            result = self._labelModel.predict(x)
            return PreprocessingCaptcha.decode([result])

        def _image(self, imageByte: bytes) -> str:
            """Classify one candidate image."""
            x = PreprocessingCaptcha.loadData(
                imageByte=imageByte,
                height=Config.MARK_CONFIG["image"]["height"],
                width=Config.MARK_CONFIG["image"]["width"],
                channel=Config.MARK_CONFIG["image"]["channel"],
            )
            result = self._imageModel.predict(x)
            return PreprocessingCaptcha.decode([result])

        def mark(self, imageBase64: str) -> list:
            """Return the ids (0-7) of the images whose class matches the label."""
            cutResults = CutCaptcha.cut(Base64.convertToBytes(imageBase64))
            labelStr = self._label(cutResults.getLabel)
            markedIds = []
            for i, image in enumerate(cutResults.getImages):
                imageStr = self._image(image)
                if labelStr.strip() == imageStr.strip():
                    markedIds.append(i)
            return markedIds
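Stripped of the Keras calls, `mark` is a filter: classify every tile and keep the ids whose predicted class string equals the label's. The sketch below replaces the model with a stub so the control flow can be tested without loading any weights; `markIds` and the fake tiles are hypothetical.

```python
from typing import Callable, List, Sequence


def markIds(label: str, tiles: Sequence[bytes], classify: Callable[[bytes], str]) -> List[int]:
    """Return the ids of the tiles whose predicted class equals the label's class.

    `classify` stands in for the image model (a hypothetical stub here).
    """
    target = label.strip()
    return [i for i, tile in enumerate(tiles) if classify(tile).strip() == target]


# Hypothetical stub: each fake tile's bytes simply are its class name
fakeTiles = [b"cat", b"dog", b"cat", b"bus", b"cat", b"dog", b"bus", b"bus"]
print(markIds("cat", fakeTiles, lambda tile: tile.decode()))  # [0, 2, 4]
```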
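  • PreCheck.py checks whether the Base64-encoded captcha sent by the front end is valid

    import base64
    from io import BytesIO

    from PIL import Image

    from constants import ResponseStatus
    from exception import CaptchaValidException


    class PreCheck(object):

        @staticmethod
        def checkImageIsValid(imageBase64: str):
            """Raise CaptchaValidException if imageBase64 is not a decodable image."""
            try:
                # convert() forces Pillow to actually decode the pixel data
                Image.open(BytesIO(base64.b64decode(imageBase64))).convert("RGB")
            except (TypeError, ValueError, OSError):
                # TypeError: not a string at all; binascii.Error (bad Base64) subclasses
                # ValueError; Pillow's unidentified or truncated image errors subclass OSError
                raise CaptchaValidException(
                    status=ResponseStatus.FAILED,
                    message="captcha is invalid, broken data."
                )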
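`PreCheck` decodes the whole image with Pillow, which is the thorough option. If you only want to reject obvious garbage before touching Pillow, a lighter sketch is to decode the Base64 strictly and check the PNG or JPEG file signature. `looksLikeImage` is hypothetical, and passing it does not prove that the rest of the file is intact.

```python
import base64
import binascii

# Leading "magic" bytes of the PNG and JPEG file formats
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
JPEG_MAGIC = b"\xff\xd8\xff"


def looksLikeImage(imageBase64: str) -> bool:
    """Hypothetical cheap pre-check: strict Base64 decode, then sniff the signature."""
    try:
        data = base64.b64decode(imageBase64, validate=True)
    except (binascii.Error, ValueError):
        return False
    return data.startswith(PNG_MAGIC) or data.startswith(JPEG_MAGIC)


print(looksLikeImage(base64.b64encode(PNG_MAGIC + b"rest").decode()))  # True
print(looksLikeImage("aGVsbG8="))                                      # False: decodes to b"hello"
```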
  • PreprocessingCaptcha.py preprocesses the captcha: it has to be converted to a NumPy array before it can be fed to the model

    from io import BytesIO

    import numpy as np
    from PIL import Image

    import Config
    from controller.utils import ConvertFormat


    class PreprocessingCaptcha(object):

        @classmethod
        def decode(cls, y: list) -> str:
            """Decode the model's probability output into the label string."""
            y = np.array(y)
            # index of the most likely class for each prediction
            y = np.argmax(y, axis=2)[:, 0]
            return "".join([Config.LABEL_LIST[x] for x in y])

        @classmethod
        def loadData(cls, imageByte: bytes, height: int, width: int, channel: int):
            """Load image bytes into a float32 array of shape (1, height, width, channel)."""
            x = np.zeros((1, height, width, channel), dtype=np.float32)
            # read into an image object
            image = Image.open(BytesIO(imageByte)).convert("RGB")
            # re-encode as PNG
            image = ConvertFormat.convertImageFormat(image)
            # resize; Image.ANTIALIAS was removed in Pillow 10, LANCZOS is the same filter
            image = image.resize((width, height), Image.LANCZOS)
            # convert to the channel count the model expects
            imageArray = ConvertFormat.convertImageToArrayByChannel(image, channel)
            # normalize pixel values to [0, 1]
            x[0] = imageArray / 255.0
            return x
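Two steps in this module are worth seeing in isolation: `decode` takes the argmax over the class probabilities and looks that index up in `Config.LABEL_LIST`, and `loadData` divides pixel values by 255. A NumPy-free sketch of both, assuming a made-up four-class label list (the real list comes from `controller/charset/Label.py`); `decodeOne` and `normalize` are hypothetical names.

```python
from typing import List, Sequence

# Hypothetical 4-class label list; the real one comes from controller/charset/Label.py
LABEL_LIST = ["cat", "dog", "bus", "fan"]


def decodeOne(probabilities: Sequence[float], labels: List[str]) -> str:
    """NumPy-free equivalent of decode() for a single prediction row (argmax lookup)."""
    # max() returns the first index on ties, like np.argmax
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return labels[best]


def normalize(pixels: Sequence[int]) -> List[float]:
    """Scale 0..255 pixel values into 0..1, as loadData() does with / 255.0."""
    return [p / 255.0 for p in pixels]


print(decodeOne([0.10, 0.70, 0.15, 0.05], LABEL_LIST))  # dog
print(normalize([0, 51, 255]))                           # [0.0, 0.2, 1.0]
```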
  • Scheduler.py wraps everything into one module that exposes a single marking method

    import base64
    import random
    from typing import List

    from constants import IMAGE_POSITION_COORDINATES
    from controller import MarkCaptcha
    from entities import RequestMark
    from entities import ResultMarked
    from .DrawMarkedResult import DrawMarkedResult
    from .PreCheck import PreCheck


    class Scheduler(object):

        def __init__(self):
            self._markCaptcha = MarkCaptcha()

        def markCaptcha(self, requestMark: RequestMark) -> ResultMarked:
            originCaptcha = requestMark.originCaptcha
            # reject data that is not a decodable image
            PreCheck.checkImageIsValid(originCaptcha)
            # ids of the tiles that match the label
            ids = self._markCaptcha.mark(originCaptcha)
            # one click coordinate per marked tile
            results = self._sumCoordinate(ids)
            # draw rectangles on the original and re-encode it as Base64 text
            markedCaptcha = base64.b64encode(
                DrawMarkedResult.rectangle(base64.b64decode(originCaptcha), ids)
            ).decode("utf-8")
            return ResultMarked(
                originCaptcha=originCaptcha,
                ids=ids,
                results=results,
                markedCaptcha=markedCaptcha
            )

        def _sumCoordinate(self, ids: list) -> List[tuple]:
            """Calculate click coordinates from the marked ids.

            Args:
                ids: marked result id list

            Returns: e.g. [(12, 43), ...], one random point inside each marked tile
            """
            markedResults = []
            for idx in ids:
                a = IMAGE_POSITION_COORDINATES[idx][0]
                b = IMAGE_POSITION_COORDINATES[idx][1]
                # stay at least 10 pixels away from the tile edges
                x = random.randint(a[0] + 10, b[0] - 10)
                y = random.randint(a[1] + 10, b[1] - 10)
                markedResults.append((x, y))
            return markedResults
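`_sumCoordinate` turns each marked id into a random point inside that tile, keeping a 10-pixel margin from the edges (presumably so that repeated clicks do not land on identical pixels). A standalone sketch that takes an explicit `random.Random`, so the output can be reproduced in tests; `sumCoordinate` is a hypothetical free-function version.

```python
import random
from typing import List, Sequence, Tuple

# Copied from constants/ImagePositionCoordinates.py: [(left, top), (right, bottom)] per id
IMAGE_POSITION_COORDINATES = [
    [(5, 42), (72, 109)], [(5, 114), (72, 181)],
    [(77, 42), (142, 109)], [(77, 114), (142, 181)],
    [(147, 42), (214, 109)], [(147, 114), (214, 181)],
    [(221, 42), (286, 109)], [(221, 114), (286, 181)],
]


def sumCoordinate(ids: Sequence[int], rng: random.Random) -> List[Tuple[int, int]]:
    """Pick one random point per marked tile, at least 10 px from each edge."""
    points = []
    for idx in ids:
        (left, top), (right, bottom) = IMAGE_POSITION_COORDINATES[idx]
        # randint is inclusive on both ends
        points.append((rng.randint(left + 10, right - 10), rng.randint(top + 10, bottom - 10)))
    return points


print(sumCoordinate([0, 3, 7], random.Random(42)))  # same seed, same points
```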

Create entities to wrap the data

  • Create an entities package and add a few entity classes to it that wrap the data.

  • BaseEntity.py

    import json

    from constants import ResponseStatus
    from exception import RequestDataFormatException


    class BaseEntity(object):
        """Base entity.

        Provides JSON serialization through __str__ and a check
        that no attribute is None.
        """

        def __str__(self):
            # default= serializes nested objects through their __dict__
            return json.dumps(self, default=lambda obj: obj.__dict__, ensure_ascii=False, indent=4, sort_keys=True)

        def checkAttributesIsNone(self, on: bool = True):
            """Check that no attribute is None.

            Args:
                on: pass False to skip the check

            Raises:
                RequestDataFormatException: if any attribute is None.
            """
            if not on:
                return
            for attributeName in list(self.__dict__.keys()):
                if getattr(self, attributeName) is None:
                    raise RequestDataFormatException(
                        status=ResponseStatus.FAILED,
                        message="Attribute {} not found.".format(attributeName),
                        statusCode=400
                    )
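Two details of `__str__` are easy to miss: nested entities are serialized through their `__dict__`, and the JSON keys are the raw attribute names, so an attribute stored as `_originCaptcha` appears with its leading underscore. A trimmed sketch with hypothetical `Tile` and `Answer` entities; `indent=4` is dropped here only to keep the output on one line.

```python
import json


class BaseEntity(object):
    """Trimmed copy of the serialization part of BaseEntity."""

    def __str__(self):
        # json calls default() for any object it cannot serialize natively;
        # returning __dict__ makes nested entities serialize recursively
        return json.dumps(self, default=lambda obj: obj.__dict__,
                          ensure_ascii=False, sort_keys=True)


class Tile(BaseEntity):
    """Hypothetical nested entity."""

    def __init__(self, idx: int):
        self.idx = idx


class Answer(BaseEntity):
    """Hypothetical outer entity with a non-ASCII value and an underscore attribute."""

    def __init__(self):
        self.label = "café"
        self._note = "kept"
        self.tiles = [Tile(0), Tile(2)]


print(Answer())
# {"_note": "kept", "label": "café", "tiles": [{"idx": 0}, {"idx": 2}]}
```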
  • RequestMark.py, the request entity

    from entities import BaseEntity


    class RequestMark(BaseEntity):

        def __init__(self, requestJson: dict):
            # .get() reads the field without inserting a default into the request data
            self._originCaptcha = requestJson.get("originCaptcha")
            # raises RequestDataFormatException if the field is missing
            self.checkAttributesIsNone()

        @property
        def originCaptcha(self) -> str:
            return self._originCaptcha
  • ResponseGeneral.py, the response entity

    import json
    from enum import Enum


    class ResponseGeneral(object):

        def __init__(self, status: Enum, message: str, **kwargs):
            self.status = status.value
            self.message = message
            # extra top-level fields, e.g. result=...
            for k, v in kwargs.items():
                setattr(self, k, v)

        def __str__(self):
            # every attribute becomes a top-level JSON field
            return json.dumps(self.__dict__)
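The envelope every endpoint returns is therefore `status`, `message`, plus whatever keyword arguments are passed (the endpoints use `result`). A self-contained sketch of the resulting JSON with the same enum values; note that tuples, such as the click coordinates, come out as JSON arrays.

```python
import json
from enum import Enum


class ResponseStatus(Enum):
    SUCCESS = "success"
    FAILED = "failed"


class ResponseGeneral(object):
    def __init__(self, status: Enum, message: str, **kwargs):
        self.status = status.value  # store the plain string, not the Enum member
        self.message = message
        for k, v in kwargs.items():
            setattr(self, k, v)

    def __str__(self):
        return json.dumps(self.__dict__)


print(ResponseGeneral(ResponseStatus.SUCCESS, "mark successfully", result=[(38, 75)]))
# {"status": "success", "message": "mark successfully", "result": [[38, 75]]}
```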
  • ResultCutCaptcha.py wraps the result of cutting the captcha

    from typing import Tuple


    class ResultCutCaptcha(object):

        def __init__(self, label: bytes, images: Tuple[bytes, ...]):
            self.label = label    # PNG bytes of the label
            self.images = images  # PNG bytes of the 8 images, ids 0-7

        @property
        def getLabel(self) -> bytes:
            return self.label

        @property
        def getImages(self) -> Tuple[bytes, ...]:
            return self.images
  • ResultMarked.py wraps the marking result

    import json


    class ResultMarked(object):

        def __init__(self, originCaptcha: str, ids: list, results: list, markedCaptcha: str):
            self._originCaptcha = originCaptcha
            self._ids = ids
            self._results = results
            self._markedCaptcha = markedCaptcha

        @property
        def originCaptcha(self):
            return self._originCaptcha

        @property
        def ids(self):
            return self._ids

        @property
        def results(self):
            return self._results

        @property
        def markedResult(self):
            return self._markedCaptcha

        def __str__(self):
            return json.dumps(
                {
                    "originCaptcha": self._originCaptcha,
                    "ids": self._ids,
                    "results": self._results,
                    "markedCaptcha": self._markedCaptcha
                }
            )

Error handling

  • Create a few exception classes for custom error handling.

  • BaseException.py

    from enum import Enum


    # Note: this name shadows Python's built-in BaseException in every module that imports it.
    class BaseException(Exception):

        STATUS_CODE = 200  # default HTTP status when none is given

        def __init__(self, status: Enum, message: str, statusCode=None, payload=None):
            super().__init__(message)
            self.status = status
            self.message = message
            if statusCode is not None:
                self.statusCode = statusCode
            else:
                self.statusCode = self.STATUS_CODE
            self.payload = payload

        def toDict(self):
            # payload fields first; the standard fields override them
            rv = dict(self.payload or ())
            rv["statusCode"] = int(self.statusCode)
            rv["status"] = self.status.value
            rv["message"] = self.message
            return rv

        def __str__(self):
            return self.message

  • RequestException.py

    from exception import BaseException


    class RequestDataFormatException(BaseException):
        pass


    class CaptchaValidException(BaseException):
        pass


    class GetCaptchaException(BaseException):
        pass
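Because every subclass inherits `toDict`, the error handlers can serialize all of them the same way: `payload` keys are merged first and the standard fields override them. A sketch of what a handler receives, with the base class renamed to the hypothetical `AppBaseException` so the sketch does not shadow the built-in.

```python
from enum import Enum


class ResponseStatus(Enum):
    SUCCESS = "success"
    FAILED = "failed"


class AppBaseException(Exception):
    """Hypothetical rename of the project's BaseException."""

    STATUS_CODE = 200

    def __init__(self, status: Enum, message: str, statusCode=None, payload=None):
        super().__init__(message)
        self.status = status
        self.message = message
        self.statusCode = statusCode if statusCode is not None else self.STATUS_CODE
        self.payload = payload

    def toDict(self) -> dict:
        rv = dict(self.payload or ())
        rv["statusCode"] = int(self.statusCode)
        rv["status"] = self.status.value
        rv["message"] = self.message
        return rv


class RequestDataFormatException(AppBaseException):
    pass


err = RequestDataFormatException(ResponseStatus.FAILED, "data invalid, not json.", 400,
                                 payload={"field": "originCaptcha"})
print(err.toDict())
# {'field': 'originCaptcha', 'statusCode': 400, 'status': 'failed', 'message': 'data invalid, not json.'}
```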

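Create the configuration and initialization modules

  • Initialization.py creates a global scheduler object

    from controller import Scheduler

    # created once, on first import, so the Keras models are loaded only once
    scheduler = Scheduler()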
  • Config.py, the global configuration file

    import os

    from controller.charset import LABEL


    CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

    # general configuration
    APP_NAME = "Mark12306Captcha"
    APP_VERSION = "v1.0"

    # model and mark configuration
    MODEL_PATH = {
        "label": os.path.join(CURRENT_DIR, "controller", "models", "Label_12306_SmallCNN4.model"),
        "image": os.path.join(CURRENT_DIR, "controller", "models", "Image_12306_SmallCNN4.model")
    }

    LABEL_LIST = LABEL

    MARK_CONFIG = {
        # input size (pixels) and channel count expected by each model
        "label": {
            "height": 30,
            "width": 63,
            "channel": 1
        },
        "image": {
            "height": 67,
            "width": 67,
            "channel": 3
        }
    }

    # draw marked result
    DRAW_CONFIG = {
        # RGB
        "color": (255, 0, 0),
        "lineSize": 5
    }
  • App.py registers the APIs and is the program's entry point

    from flask import Flask

    from api import mark
    from api import get
    from controller import errorHandler


    app = Flask(__name__)
    app.register_blueprint(errorHandler)
    app.register_blueprint(mark)
    app.register_blueprint(get)


    if __name__ == '__main__':
        # Flask development server; listens on http://127.0.0.1:5000 by default
        app.run()

Create the API

  • Provide two APIs: one fetches a captcha from 12306, the other marks a captcha.

  • Get.py

    from flask import Blueprint

    import Config
    from constants import ResponseStatus
    from controller import GetCaptcha
    from controller.utils import ResponseUtils
    from entities import ResponseGeneral


    # URL prefix: /Mark12306Captcha/api/v1.0/get
    get = Blueprint("get", __name__, url_prefix="/{}/api/{}/get".format(Config.APP_NAME, Config.APP_VERSION))


    @get.route("/captcha", methods=["GET"])
    def getCaptchaFrom12306():
        """Fetch a captcha from 12306 and return it as a Base64 string."""
        return ResponseUtils.responseJson(ResponseGeneral(
            status=ResponseStatus.SUCCESS,
            message="get captcha successfully",
            result=GetCaptcha.getCaptchaFrom12306()
        ))
  • Mark.py

    import json

    from flask import Blueprint, request

    import Config
    from constants import ResponseStatus
    from controller.utils import RequestUtils
    from controller.utils import ResponseUtils
    from entities import RequestMark
    from entities import ResponseGeneral

    # URL prefix: /Mark12306Captcha/api/v1.0/mark
    mark = Blueprint("mark", __name__, url_prefix="/{}/api/{}/mark".format(Config.APP_NAME, Config.APP_VERSION))


    @mark.route("", methods=["POST"])
    def markCaptcha():
        """Mark the captcha and return the full result."""
        RequestUtils.checkPostDataIsJson(request)
        # imported lazily, so the Keras models load on the first request, not at import time
        from Initialization import scheduler
        requestMark = RequestMark(request.json)
        # Scheduler.markCaptcha already runs PreCheck.checkImageIsValid
        resultMarked = scheduler.markCaptcha(requestMark)
        return ResponseUtils.responseJson(ResponseGeneral(
            status=ResponseStatus.SUCCESS,
            message="mark successfully",
            result=json.loads(str(resultMarked))
        ))


    @mark.route("/lite", methods=["POST"])
    def markCaptchaLite():
        """Mark the captcha and return only the click coordinates."""
        RequestUtils.checkPostDataIsJson(request)
        from Initialization import scheduler
        requestMark = RequestMark(request.json)
        resultMarked = scheduler.markCaptcha(requestMark)
        return ResponseUtils.responseJson(ResponseGeneral(
            status=ResponseStatus.SUCCESS,
            message="mark successfully",
            result=resultMarked.results
        ))
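To call the mark API from another Python program, the client only needs to POST a JSON body with an `originCaptcha` field. A stdlib-only sketch, assuming the app runs on Flask's default development address `http://127.0.0.1:5000` and that the lite route resolves to `.../mark/lite`; `buildMarkRequest` is hypothetical, and the path follows from `Config.APP_NAME` and `Config.APP_VERSION`.

```python
import base64
import json
import urllib.request

# Assumption: started with app.run(), whose default address is 127.0.0.1:5000
BASE_URL = "http://127.0.0.1:5000"
# Built from Config.APP_NAME and Config.APP_VERSION
MARK_URL = BASE_URL + "/Mark12306Captcha/api/v1.0/mark"


def buildMarkRequest(imageBytes: bytes, lite: bool = False) -> urllib.request.Request:
    """Hypothetical helper: build (but do not send) a POST to the mark API."""
    body = json.dumps({"originCaptcha": base64.b64encode(imageBytes).decode("ascii")})
    return urllib.request.Request(
        MARK_URL + ("/lite" if lite else ""),
        data=body.encode("utf-8"),
        # Flask only parses request.json when the Content-Type is application/json
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# With the server running (captcha.jpg is a placeholder file name):
# with open("captcha.jpg", "rb") as f:
#     with urllib.request.urlopen(buildMarkRequest(f.read(), lite=True)) as resp:
#         print(json.load(resp))
```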

Create a shared constants package

  • ImagePositionCoordinates.py stores the coordinates of the 8 sub-images in the captcha

    # [(left, top), (right, bottom)] of each sub-image, in id order 0-7
    IMAGE_POSITION_COORDINATES = [
        [(5, 12 + 30), (72, 79 + 30)],
        [(5, 84 + 30), (72, 151 + 30)],
        [(77, 12 + 30), (142, 79 + 30)],
        [(77, 84 + 30), (142, 151 + 30)],
        [(147, 12 + 30), (214, 79 + 30)],
        [(147, 84 + 30), (214, 151 + 30)],
        [(221, 12 + 30), (286, 79 + 30)],
        [(221, 84 + 30), (286, 151 + 30)]
    ]
  • ResponseStatus.py, the API response status

    from enum import Enum


    class ResponseStatus(Enum):

        SUCCESS = "success"
        FAILED = "failed"

Test the API

  • Run the project

  • Test fetching a captcha

  • Inspect the fetched captcha

  • Test the mark API with the Base64 string we just fetched

  • The marking succeeded

  • and the response contains the result

  • The images are numbered as follows:

        0 2 4 6
        1 3 5 7
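The numbering runs column by column, which is why `CutCaptcha` and `IMAGE_POSITION_COORDINATES` list the top tile of each column before the bottom one. A small sketch of the conversion, in case a client needs to turn an id from the `ids` field into a grid position; both helpers are hypothetical.

```python
from typing import Tuple


def idToGrid(idx: int) -> Tuple[int, int]:
    """Map a tile id to (row, col) for the layout
        0 2 4 6
        1 3 5 7
    """
    if not 0 <= idx < 8:
        raise ValueError("tile id must be in 0..7")
    return idx % 2, idx // 2


def gridToId(row: int, col: int) -> int:
    """Inverse of idToGrid: two tiles per column, top tile first."""
    return col * 2 + row


print([idToGrid(i) for i in range(8)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
```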

All done!

Project repository

The deployed API