# ccaptchas/src/ccaptchas/model.py

import json
import logging
from datetime import datetime
from pathlib import Path

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, History
from tensorflow.keras.layers import (
    Bidirectional,
    Conv2D,
    Dense,
    Dropout,
    Input,
    LSTM,
    MaxPooling2D,
    Reshape,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, Optimizer

from .config import CONFIG
from .ctc_layer import CTCLayer
from .types import PathT

log = logging.getLogger(__name__)


def build_model(alphabet_size: int,
                img_width: int = CONFIG.DEFAULT_IMG_WIDTH,
                img_height: int = CONFIG.DEFAULT_IMG_HEIGHT,
                optimizer: Optimizer = Adam()) -> Model:
    """Builds and compiles the CNN-RNN model used for captcha recognition.

    The output layer has `alphabet_size + 1` units: one per character in the
    vocabulary plus one for the blank token required by CTC.
    """
    log.info("Building model")
    # Inputs to the model; the labels input has variable length
    input_img = Input(
        shape=(img_width, img_height, 1),
        dtype='float32',
        name=CONFIG.LAYER_NAME_INPUT_IMAGE,
    )
    labels = Input(
        shape=(None, ),
        dtype='float32',
        name=CONFIG.LAYER_NAME_INPUT_LABEL,
    )
    # First conv block: 32 3x3 filters followed by 2x2 max-pooling
    x = Conv2D(
        filters=32,
        kernel_size=(3, 3),
        activation='relu',
        kernel_initializer='he_normal',
        padding='same',
        name='conv1',
    )(input_img)
    x = MaxPooling2D(
        pool_size=(2, 2),
        name='pool1',
    )(x)
    # Second conv block: 64 3x3 filters followed by 2x2 max-pooling
    x = Conv2D(
        filters=64,
        kernel_size=(3, 3),
        activation='relu',
        kernel_initializer='he_normal',
        padding='same',
        name='conv2',
    )(x)
    x = MaxPooling2D(
        pool_size=(2, 2),
        name='pool2',
    )(x)
    # The two max-pooling layers above each halve both spatial dimensions,
    # so the feature maps are down-sampled by a factor of 4. The last conv
    # layer has 64 filters. Reshape accordingly before passing the output to
    # the RNN part of the model: one time step per remaining "width" position.
    down_sample_factor = 4
    new_shape = (
        (img_width // down_sample_factor),
        (img_height // down_sample_factor) * 64,
    )
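    # Worked example (illustrative numbers, not necessarily CONFIG's defaults):
    # for img_width=200 and img_height=50, the RNN input becomes
    # (200 // 4, (50 // 4) * 64) = (50, 768), i.e. 50 time steps of 768 features.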
    x = Reshape(
        target_shape=new_shape,
        name='reshape',
    )(x)
    x = Dense(
        units=64,
        activation='relu',
        name='dense1',
    )(x)
    x = Dropout(rate=0.2)(x)
    # RNNs: two stacked bidirectional LSTMs over the time-step axis
    x = Bidirectional(
        LSTM(
            units=128,
            return_sequences=True,
            dropout=0.25,
        )
    )(x)
    x = Bidirectional(
        LSTM(
            units=64,
            return_sequences=True,
            dropout=0.25,
        )
    )(x)
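    # Note: Bidirectional defaults to merge_mode='concat', so the second
    # LSTM's output carries 2 * 64 = 128 features per time step.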
    # Output layer: one softmax unit per vocabulary character plus the CTC blank
    x = Dense(
        units=alphabet_size + 1,
        activation='softmax',
        name=CONFIG.LAYER_NAME_OUTPUT,
    )(x)
    # Add CTC layer for calculating CTC loss at each step
    output = CTCLayer(name='ctc_loss')(labels, x)
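    # CTCLayer (defined in .ctc_layer) presumably computes the CTC loss from
    # `labels` and the softmax output and registers it via `self.add_loss()`,
    # which is why `model.compile()` below is called without a `loss` argument.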
    # Define the model over both inputs; the labels input is only consumed
    # by the CTC loss layer during training
    model = Model(
        inputs=[input_img, labels],
        outputs=output,
        name=CONFIG.MODEL_NAME,
    )
    log.debug("Compiling model")
    model.compile(optimizer=optimizer)
    return model


def train_model(model: Model, train_dataset: tf.data.Dataset, valid_dataset: tf.data.Dataset,
                num_epochs: int = CONFIG.DEFAULT_NUM_EPOCHS,
                early_stopping_patience: int = CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE) -> History:
    """Trains `model`, stopping early once the validation loss stops improving."""
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=early_stopping_patience,
        restore_best_weights=True,
    )
    log.debug("Beginning training")
    history = model.fit(
        x=train_dataset,
        validation_data=valid_dataset,
        epochs=num_epochs,
        callbacks=[early_stopping],
    )
    return history


def start(training_data: tf.data.Dataset, validation_data: tf.data.Dataset, vocabulary: str,
          save_dir: PathT = CONFIG.DEFAULT_SAVE_DIR, num_epochs: int = CONFIG.DEFAULT_NUM_EPOCHS,
          early_stopping_patience: int = CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE) -> None:
    """Builds and trains a model, then saves it along with vocabulary and history."""
    # Save everything into a fresh, timestamped subdirectory
    save_dir = Path(save_dir, datetime.now().strftime('%Y-%m-%d_%H-%M'))
    save_dir.mkdir(parents=True)
    model = build_model(len(vocabulary))
    history = train_model(model, training_data, validation_data,
                          num_epochs=num_epochs, early_stopping_patience=early_stopping_patience)
    log.debug("Saving model")
    model.save(save_dir)
    log.debug("Saving vocabulary")
    with open(Path(save_dir, CONFIG.VOCABULARY_FILE_NAME), 'w') as f:
        f.write(vocabulary)
    log.debug("Saving history")
    with open(Path(save_dir, CONFIG.HISTORY_FILE_NAME), 'w') as f:
        json.dump(history.history, f, indent=4)
    log.info("All saved!")
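

# Minimal usage sketch (not part of the original module): build the model for
# a 36-character vocabulary and inspect its architecture. The tf.data.Dataset
# pipelines required by `start()` are assumed to be constructed elsewhere in
# the package, so this only smoke-tests `build_model()`.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    demo_model = build_model(alphabet_size=36)
    demo_model.summary()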