ccaptchas/src/ccaptchas/infer.py

79 lines
4.1 KiB
Python

import sys
from pathlib import Path
from typing import Iterable, Sequence
import numpy as np
import tensorflow as tf
from keras.api._v2.keras.models import Model, load_model
from keras.api._v2.keras.layers import StringLookup
from keras.api._v2.keras.backend import ctc_decode
from .config import CONFIG
from .preprocess import process_image, decode_label, find_image_files, get_lookup_table
from .types import PathT, ImgT, Array
from .visualize import plot_images
def process_predictions(predictions: tf.Tensor) -> tf.Tensor:
    """Greedily CTC-decode the raw model output into sequences of label indices.

    Args:
        predictions:
            Output of the model's softmax layer. Axis 0 indexes the images passed into the model for inference,
            axis 1 the (down-sampled) image width, i.e. the CTC time steps, and axis 2 the character classes
            (the size of the vocabulary + 1, one more than the number of distinct characters in a label).

    Returns:
        A 2D tensor with one row per image, truncated to at most `CONFIG.MAX_STRING_LENGTH` columns.
        Each row starts with the inferred label indices; the remaining positions are -1 (blank).
        A backward lookup table can decode these indices into the actual characters.
    """
    num_predictions = predictions.shape[0]
    output_width = predictions.shape[1]
    # The softmax output holds per-time-step probabilities along the width axis; CTC decoding turns them into
    # single integers representing the inferred characters (see CTC concepts).
    # Every prediction spans the full down-sampled width, so all sequence lengths are identical integers.
    seq_lengths = np.full(num_predictions, output_width, dtype=np.int32)
    decoded, _log_probabilities = ctc_decode(predictions, input_length=seq_lengths, greedy=True)
    # Greedy decoding yields a list with exactly one dense tensor of shape (num_predictions, output_width).
    # Each row is filled from the left with the inferred label indices and padded with -1 for blanks.
    sequences = decoded[0]
    # A label can never be longer than MAX_STRING_LENGTH, so every column beyond that must be blank.
    return sequences[:, :CONFIG.MAX_STRING_LENGTH]
def load_inference_model(model_dir: PathT) -> tuple[Model, StringLookup]:
    """Load a saved model together with its vocabulary and prepare both for inference.

    Args:
        model_dir:
            Directory the model was saved to. It must also contain the vocabulary file named
            `CONFIG.VOCABULARY_FILE_NAME`.

    Returns:
        A new `Model` and the backward (inverted) lookup table built from the stored vocabulary.
        The model maps the input of the layer named `CONFIG.LAYER_NAME_INPUT_IMAGE` to the output of the layer
        named `CONFIG.LAYER_NAME_OUTPUT` in the saved model.
    """
    vocabulary = Path(model_dir, CONFIG.VOCABULARY_FILE_NAME).read_text()
    backward_lookup = get_lookup_table(vocabulary, invert=True)
    saved_model = load_model(model_dir)
    # Cut the saved model down to the image input -> output path; other inputs of the saved model
    # (presumably training-only ones such as labels -- confirm) are not needed for inference.
    image_input = saved_model.get_layer(name=CONFIG.LAYER_NAME_INPUT_IMAGE).input
    prediction_output = saved_model.get_layer(name=CONFIG.LAYER_NAME_OUTPUT).output
    return Model(image_input, prediction_output), backward_lookup
def predict_and_decode(images: Sequence[ImgT], model: Model, backward_lookup: StringLookup) -> tuple[Array, list[str]]:
    """Run the model on a batch of images and decode the predicted labels into strings.

    Args:
        images: The images to infer labels for; each one is passed through `process_image` first.
        model: The inference model, e.g. as returned by `load_inference_model`.
        backward_lookup: Lookup table mapping label indices back to characters.

    Returns:
        The array of preprocessed images that was fed to the model, and one decoded label per image.
    """
    processed = [process_image(image) for image in images]
    dataset = np.array(processed)
    raw_predictions = model.predict(dataset)
    labels = []
    for encoded in process_predictions(raw_predictions):
        labels.append(decode_label(encoded, backward_lookup))
    return dataset, labels
def load_and_infer(images: Sequence[ImgT], model_dir: PathT, plot_results: bool = False, *,
                   per_plot: int = 24) -> list[str]:
    """Load the model from `model_dir` and infer a label for each of the given images.

    Args:
        images: The images to infer labels for.
        model_dir: Directory containing the saved model and its vocabulary file.
        plot_results: If `True`, plot the preprocessed images together with their inferred labels.
        per_plot: Maximum number of images shown in a single plot (only used if `plot_results` is `True`).

    Returns:
        One inferred label per image, in input order.

    Raises:
        ValueError: If `plot_results` is `True` and `per_plot` is not positive.
    """
    # Validate before the (expensive) model load; a zero step would crash `range` and a negative one
    # would silently produce no plots.
    if plot_results and per_plot < 1:
        raise ValueError(f"per_plot must be a positive integer, got {per_plot}")
    model, backward_lookup = load_inference_model(model_dir)
    # Keep the caller's `images` untouched; the plots show the preprocessed version the model actually saw.
    processed_images, labels = predict_and_decode(images, model, backward_lookup)
    if plot_results:
        for start_idx in range(0, len(processed_images), per_plot):
            stop_idx = start_idx + per_plot
            plot_images(processed_images[start_idx:stop_idx], labels=labels[start_idx:stop_idx])
    return labels
def start(model_dir: PathT, image_files: Sequence[Path] = (), images_dir: PathT = None,
          file_ext: Iterable[str] = CONFIG.DEFAULT_IMG_FILE_EXT, plot_results: bool = False) -> None:
    """Infer labels for a set of images and print one label per line to standard output.

    Args:
        model_dir: Directory containing the saved model and its vocabulary file.
        image_files: Paths of the images to process.
        images_dir:
            If given, the image files found in this directory (filtered by `file_ext`, sorted) replace
            `image_files`.
        file_ext: File extensions considered when searching `images_dir`.
        plot_results: If `True`, also plot the images together with their inferred labels.

    Raises:
        ValueError: If no image files were given or found and standard input provided no data either.
    """
    if images_dir is not None:
        image_files = sorted(find_image_files(images_dir, file_ext=file_ext))
    if not image_files:
        # Fall back to treating the raw bytes piped into stdin as a single image.
        stdin_data = sys.stdin.buffer.read()
        if not stdin_data:
            # Fail with a clear message instead of handing empty bytes to image preprocessing.
            raise ValueError("No image files given or found, and no image data received on stdin")
        image_files = [stdin_data]
    labels = load_and_infer(image_files, model_dir, plot_results=plot_results)
    for label in labels:
        print(label)