generated from daniil-berg/boilerplate-py
79 lines
4.1 KiB
Python
79 lines
4.1 KiB
Python
import sys
|
|
from pathlib import Path
|
|
from typing import Iterable, Sequence
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from keras.api._v2.keras.models import Model, load_model
|
|
from keras.api._v2.keras.layers import StringLookup
|
|
from keras.api._v2.keras.backend import ctc_decode
|
|
|
|
from .config import CONFIG
|
|
from .preprocess import process_image, decode_label, find_image_files, get_lookup_table
|
|
from .types import PathT, ImgT, Array
|
|
from .visualize import plot_images
|
|
|
|
|
|
def process_predictions(predictions: tf.Tensor) -> tf.Tensor:
    """
    Decode raw model output into padded sequences of label indices.

    `predictions` is the softmax output of the model: its first axis indexes the images passed in for inference,
    its second axis the (down-sampled) image width, and its third axis has the size of the vocabulary + 1.
    Greedy CTC decoding collapses the per-step float distributions along the width axis into single integers
    representing the inferred characters (see CTC concepts).

    Args:
        predictions: Softmax output of the inference model for a batch of images.

    Returns:
        A 2D tensor with one row per image, cut down to at most `CONFIG.MAX_STRING_LENGTH` columns.
        Each row starts with the label indices of the inferred characters; any remaining positions hold -1,
        which stands for a blank label. A backward lookup table can later decode the indices to characters.
    """
    batch_size, time_steps = predictions.shape[0], predictions.shape[1]
    # Every prediction spans the full down-sampled image width, so all input lengths are identical.
    input_lengths = np.ones(batch_size) * time_steps
    decoded, _ = ctc_decode(predictions, input_length=input_lengths, greedy=True)
    # The greedy approach yields exactly one candidate sequence per prediction; keep only that one.
    best_paths = decoded[0]
    # Anything beyond the maximum possible string length must be blank padding and can be dropped.
    return best_paths[:, :CONFIG.MAX_STRING_LENGTH]
|
|
|
|
|
|
def load_inference_model(model_dir: PathT) -> tuple[Model, StringLookup]:
    """
    Load a saved model together with its vocabulary and prepare both for inference.

    Args:
        model_dir: Directory holding the saved model and the vocabulary file named `CONFIG.VOCABULARY_FILE_NAME`.

    Returns:
        A 2-tuple of
        - a `Model` mapping the input of the layer named `CONFIG.LAYER_NAME_INPUT_IMAGE`
          to the output of the layer named `CONFIG.LAYER_NAME_OUTPUT`, and
        - the inverted lookup table built from the contents of the vocabulary file.
    """
    vocabulary = Path(model_dir, CONFIG.VOCABULARY_FILE_NAME).read_text()
    backward_lookup = get_lookup_table(vocabulary, invert=True)
    saved_model = load_model(model_dir)
    # Only the sub-graph between the image input and the output layer is needed for inference.
    image_input = saved_model.get_layer(name=CONFIG.LAYER_NAME_INPUT_IMAGE).input
    prediction_output = saved_model.get_layer(name=CONFIG.LAYER_NAME_OUTPUT).output
    return Model(image_input, prediction_output), backward_lookup
|
|
|
|
|
|
def predict_and_decode(images: Sequence[ImgT], model: Model, backward_lookup: StringLookup) -> tuple[Array, list[str]]:
    """
    Run the model on the given images and decode its predictions into labels.

    Args:
        images: The images to infer labels for; each one is passed through `process_image` first.
        model: The model whose `predict` method is called on the processed images.
        backward_lookup: Lookup table handed to `decode_label` for every encoded label.

    Returns:
        A 2-tuple of the array of processed images and the list of decoded labels (one per prediction).
    """
    processed = [process_image(img) for img in images]
    dataset = np.array(processed)
    raw_predictions = model.predict(dataset)
    decoded_labels = []
    for encoded in process_predictions(raw_predictions):
        decoded_labels.append(decode_label(encoded, backward_lookup))
    return dataset, decoded_labels
|
|
|
|
|
|
def load_and_infer(images: Sequence[ImgT], model_dir: PathT, plot_results: bool = False) -> list[str]:
    """
    Load the inference model from `model_dir`, infer labels for `images` and optionally plot the results.

    Args:
        images: The images to infer labels for.
        model_dir: Directory passed to `load_inference_model`.
        plot_results: If `True`, the processed images are plotted with their labels in batches of 24 per plot.

    Returns:
        The list of decoded labels.
    """
    model, backward_lookup = load_inference_model(model_dir)
    dataset, labels = predict_and_decode(images, model, backward_lookup)
    if not plot_results:
        return labels
    per_plot = 24
    for first in range(0, len(dataset), per_plot):
        last = first + per_plot
        plot_images(dataset[first:last], labels=labels[first:last])
    return labels
|
|
|
|
|
|
def start(model_dir: PathT, image_files: Sequence[Path] = (), images_dir: PathT = None,
          file_ext: Iterable[str] = CONFIG.DEFAULT_IMG_FILE_EXT, plot_results: bool = False) -> None:
    """
    Infer labels for a set of images and print one label per line.

    The images are chosen as follows:
    - If `images_dir` is given, the sorted result of `find_image_files(images_dir, file_ext=file_ext)` is used
      and `image_files` is ignored.
    - Otherwise `image_files` is used as passed.
    - If the resulting collection is empty, everything available on standard input is read as bytes and
      treated as a single image.

    Args:
        model_dir: Directory from which the inference model is loaded.
        image_files: Image file paths to run inference on.
        images_dir: Directory to search for image files; takes precedence over `image_files`.
        file_ext: File extensions passed on to `find_image_files`.
        plot_results: Passed on to `load_and_infer`.
    """
    if images_dir is None:
        candidates = image_files
    else:
        candidates = sorted(find_image_files(images_dir, file_ext=file_ext))
    # With nothing to work on, fall back to a single image supplied via standard input.
    inputs = candidates or [sys.stdin.buffer.read()]
    for label in load_and_infer(inputs, model_dir, plot_results=plot_results):
        print(label)
|