generated from daniil-berg/boilerplate-py
150 lines
4.4 KiB
Python
150 lines
4.4 KiB
Python
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional

import tensorflow as tf

from .config import CONFIG
from .ctc_layer import CTCLayer
from .keras.callbacks import EarlyStopping, History
from .keras.layers import Bidirectional, Conv2D, Dense, Dropout, Input, LSTM, MaxPooling2D, Reshape
from .keras.models import Model
from .keras.optimizers import Adam, Optimizer
from .types import PathT
|
|
|
|
|
|
# Module-level logger, named after this module's dotted import path.
log = logging.getLogger(__name__)
|
|
|
|
|
|
def build_model(alphabet_size: int,
                img_width: int = CONFIG.DEFAULT_IMG_WIDTH,
                img_height: int = CONFIG.DEFAULT_IMG_HEIGHT,
                optimizer: Optional[Optimizer] = None) -> Model:
    """
    Build and compile the CTC-based text recognition model.

    Architecture: two convolution + max-pooling blocks, a reshape that turns the
    feature maps into a sequence along the first spatial axis (image width), a dense
    layer with dropout, two bidirectional LSTM layers, a softmax output layer and
    finally a `CTCLayer` that receives the label input together with the softmax
    predictions.

    Args:
        alphabet_size:
            Number of distinct characters. The output layer has `alphabet_size + 1`
            units (presumably the extra unit is the CTC blank -- TODO confirm
            against `CTCLayer`).
        img_width (optional):
            Image width; used as the first spatial axis of the image input.
        img_height (optional):
            Image height; used as the second spatial axis of the image input.
        optimizer (optional):
            Optimizer passed to `Model.compile`. If omitted (or `None`), a new
            `Adam` instance with default settings is created for this call.

    Returns:
        The compiled model with inputs `[image, labels]` and the CTC layer's
        result as its output.
    """
    if optimizer is None:
        # Created per call on purpose: a default argument like `optimizer=Adam()`
        # is evaluated only once at import time, so every model built with the
        # default would share a single stateful optimizer instance.
        optimizer = Adam()
    log.info("Building model")
    # Inputs to the model
    input_img = Input(
        shape=(img_width, img_height, 1),
        dtype='float32',
        name=CONFIG.LAYER_NAME_INPUT_IMAGE,
    )
    labels = Input(
        shape=(None, ),
        dtype='float32',
        name=CONFIG.LAYER_NAME_INPUT_LABEL,
    )
    # Number of filters in the last convolution; the reshape below depends on it.
    last_conv_filters = 64
    # First conv block
    x = Conv2D(
        filters=32,
        kernel_size=(3, 3),
        activation='relu',
        kernel_initializer='he_normal',
        padding='same',
        name='conv1',
    )(input_img)
    x = MaxPooling2D(
        pool_size=(2, 2),
        name='pool1',
    )(x)
    # Second conv block
    x = Conv2D(
        filters=last_conv_filters,
        kernel_size=(3, 3),
        activation='relu',
        kernel_initializer='he_normal',
        padding='same',
        name='conv2',
    )(x)
    x = MaxPooling2D(
        pool_size=(2, 2),
        name='pool2',
    )(x)
    # Two max-pooling layers with pool size (and default strides) 2 shrink both
    # spatial axes by a factor of 4 (rounded down, since pooling uses the default
    # 'valid' padding). Collapse height and channels into one feature axis so the
    # width axis becomes the time axis for the RNN part of the model.
    down_sample_factor = 2 * 2
    new_shape = (
        img_width // down_sample_factor,
        (img_height // down_sample_factor) * last_conv_filters,
    )
    x = Reshape(
        target_shape=new_shape,
        name='reshape',
    )(x)
    x = Dense(
        units=64,
        activation='relu',
        name='dense1',
    )(x)
    x = Dropout(rate=0.2)(x)
    # RNNs
    x = Bidirectional(
        LSTM(
            units=128,
            return_sequences=True,
            dropout=0.25,
        )
    )(x)
    x = Bidirectional(
        LSTM(
            units=64,
            return_sequences=True,
            dropout=0.25,
        )
    )(x)
    # Output layer
    x = Dense(
        units=alphabet_size + 1,
        activation='softmax',
        name=CONFIG.LAYER_NAME_OUTPUT,
    )(x)
    # Add CTC layer for calculating CTC loss at each step
    output = CTCLayer(name='ctc_loss')(labels, x)
    # Define the model
    model = Model(
        inputs=[input_img, labels],
        outputs=output,
        name=CONFIG.MODEL_NAME,
    )
    log.debug("Compiling model")
    model.compile(optimizer=optimizer)
    return model
|
|
|
|
|
|
def train_model(model: Model, train_dataset: tf.data.Dataset, valid_dataset: tf.data.Dataset,
                num_epochs: int = CONFIG.DEFAULT_NUM_EPOCHS,
                early_stopping_patience: int = CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE) -> History:
    """
    Fit `model` on `train_dataset`, validating on `valid_dataset`, with early stopping.

    Training stops once `val_loss` has not improved for `early_stopping_patience`
    epochs; the weights from the best epoch are then restored into `model`.

    Args:
        model:
            An already compiled model.
        train_dataset:
            Dataset passed to `fit` as `x`.
        valid_dataset:
            Dataset passed to `fit` as `validation_data`.
        num_epochs (optional):
            Maximum number of epochs to train for.
        early_stopping_patience (optional):
            Number of epochs without `val_loss` improvement before stopping.

    Returns:
        The `History` object returned by `model.fit`.
    """
    stopper = EarlyStopping(
        monitor='val_loss',
        patience=early_stopping_patience,
        restore_best_weights=True,
    )
    fit_kwargs = dict(
        x=train_dataset,
        validation_data=valid_dataset,
        epochs=num_epochs,
        callbacks=[stopper],
    )
    log.debug("Beginning training")
    return model.fit(**fit_kwargs)
|
|
|
|
|
|
def start(training_data: tf.data.Dataset, validation_data: tf.data.Dataset, vocabulary: str,
          save_dir: PathT = CONFIG.DEFAULT_SAVE_DIR, num_epochs: int = CONFIG.DEFAULT_NUM_EPOCHS,
          early_stopping_patience: int = CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE) -> None:
    """
    Build a model for `vocabulary`, train it, and save all artifacts to a new directory.

    The artifacts are written to a sub-directory of `save_dir` named after the
    current local time (`YYYY-MM-DD_HH-MM`):
    the saved model itself, the vocabulary string (`CONFIG.VOCABULARY_FILE_NAME`)
    and the training history as indented JSON (`CONFIG.HISTORY_FILE_NAME`).

    Args:
        training_data:
            Dataset used for fitting.
        validation_data:
            Dataset used for validation and early stopping.
        vocabulary:
            String whose length determines the model's alphabet size; saved verbatim.
        save_dir (optional):
            Parent directory for the timestamped output directory; created if missing.
        num_epochs (optional):
            Maximum number of training epochs.
        early_stopping_patience (optional):
            Epochs without validation-loss improvement before training stops.

    Raises:
        FileExistsError:
            If the timestamped directory already exists (e.g. two runs started
            within the same minute); this happens before any training is done.
    """
    run_dir = Path(save_dir) / datetime.now().strftime('%Y-%m-%d_%H-%M')
    run_dir.mkdir(parents=True)

    model = build_model(len(vocabulary))
    history = train_model(
        model,
        training_data,
        validation_data,
        num_epochs=num_epochs,
        early_stopping_patience=early_stopping_patience,
    )

    log.debug("Saving model")
    model.save(run_dir)

    log.debug("Saving vocabulary")
    # NOTE(review): written with the platform default encoding, same as the
    # original `open(..., 'w')`; the reader elsewhere presumably relies on that.
    (run_dir / CONFIG.VOCABULARY_FILE_NAME).write_text(vocabulary)

    log.debug("Saving history")
    with open(run_dir / CONFIG.HISTORY_FILE_NAME, 'w') as history_file:
        json.dump(history.history, history_file, indent=4)

    log.info("All saved!")
|