개인 개발 프로젝트/AI 숫자 판별 앱

[AI 숫자 판별 앱] 2. 숫자 판별 함수 작성

종범2 2019. 12. 16. 22:45

딥러닝에 관한 상세한 내용을 다루는 글이 아니기 때문에 딥러닝과 Mnist에 관한 설명은 생략 하겠다. 대부분의 내용은 밑바닥부터 시작한 딥러닝이라는 책을 참고하였다. 여기에 숫자 손글씨를 이미지로 실제 숫자를 판별하는 딥러닝 모델이 존재하므로 그대로 사용하기로 하였다. 

 

파이썬과 딥러닝 모델을 사용하기 위해 아나콘다를 설치하였다. 

https://www.anaconda.com/distribution/#download-section

 

Anaconda Python/R Distribution - Free Download

Anaconda Distribution is the world's most popular Python data science platform. Download the free version to access over 1500 data science packages and manage libraries and dependencies with Conda.

www.anaconda.com

 

딥러닝 모델을 불러오는 함수와 숫자를 판별하는 함수는 그대로 이용하였고, flask를 이용하여 이미지를 input으로 받아 저장하고 이미지의 숫자를 판별한 결과값을 outuput으로 출력하는 api를 구현하였다. 코드는 다음과 같다.

 

app.py

from flask import Flask, request
from flask_cors import CORS
app = Flask (__name__)
CORS(app)
import numpy as np
from PIL import Image
import glob
import pickle as pickle
from common.functions import sigmoid, softmax

def init_network():
    with open("sample_weight.pkl",'rb') as f:
        network = pickle.load(f)
    return network

def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']
    
    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2,W3) + b3
    y = softmax(a3)
    
    return y

@app.route('/')
def hello_world():
    return 'Hello, Mnist App!'

@app.route('/number', methods=['POST'])
def number():
    # get file from request
    f = request.files['file']
    # save file
    fileName = 'tempImgFile.'+f.filename.rsplit('.', 1)[1]
    f.save(fileName)
    # predict
    network = init_network()
    global result
    for image_path in glob.glob(fileName):
        img = Image.open(image_path).convert("L")
        img = np.resize(img, (1,784))
        img = 255-(img)
        y = predict(network, img)
        result = np.argmax(y)
    return str(result)

if __name__ == "__main__":
    app.run()

 

init_network 함수는 sample_weight.pkl 파일을 열어서 딥러닝 network를 정의하는데 필요한 정보들을 불러온다. 여기서는 W1, W2, W3, b1, b2, b3 값을 불러오며 모두 벡터이다.

 

predict 함수는 network를 정의하는데 필요한 정보들과 이미지정보를 input으로 받아 판별한 숫자를 output으로 반환한다. input으로 들어가는 이미지 정보는 1*784 크기의 벡터로 28*28 사이즈의 이미지를 1*784로 벡터로 변환한 값이다. 일반적으로 이미지는 각 픽셀의 값을 (r, g, b)로 가지므로 28*28 사이즈의 이미지라면 각 픽셀에 3개의 값이 존재한다. 하지만 현재 모델에서는 모든 input 이미지를 28*28 사이즈의 흑백 이미지라고 가정한다. 흑백 이미지라면 각 픽셀에 어두운 정도를 나타내는 하나의 값으로 gray scale 값만 존재한다. 만약 흑백이 아닌 이미지가 input으로 들어온다면 흑백 이미지로 변환한다 (number 함수에 해당 코드 존재함).

 

number 함수는 다음과 같은 순서로 실행된다.

1. request로 file을 받는다.

2. 받은 이미지 파일을 같은 경로에 tempImgFile이라는 이름으로 저장한다.

3. init_network 함수를 실행하여 network를 정의하는 정보를 불러와 fileName에 저장한다.

4. convert함수를 실행하여 이미지를 흑백으로 전환한다.

5. img = 255-img로 gray scale 값을 반전한다. 흑백을 반전하는 이유는 트레이닝한 이미지와 내가 input으로 넣을 이미지의 gray scale 값이 완전 반대이기 때문이다. Mnist 데이터는 대부분 검정 바탕에 흰색 글씨인데 내가 input으로 넣을 데이터는 흰 바탕에 검정 글씨이다. 예시는 다음과 같다.

 

Mnist 데이터의 3 이미지
그림판으로 그린 3 이미지

5. predict 함수를 실행하여 모델의 숫자 판별 결과를 y에 저장한다. y값은 1*10 벡터로 각 값은 이미지의 숫자가 0, 1, 2, 3, 4, 5, 6, 7, 8, 9일 확률을 나타낸다.

6. argmax 함수를 실행하여 벡터 요소들중에 가장 값이 큰 요소의 인덱스를 가져와  result에 저장한다.

7. 최종적으로 숫자를 판별한 값인 result를 문자열로 전환하여 반환한다.

 

여기까지가 숫자 판별 함수와 관련된 내용이다. 아직 flask와 관련된 코드를 설명하지 않았는데, 다음 글에서는 flask를 이용한 api 작성과 작성한 api를 Heroku에 배포하는 방법을 설명하겠다.

 

다음 글 바로가기

[AI 숫자 판별 앱] 3. 숫자 판별 API 작성 및 테스트