ccaptchas/src/ccaptchas/__main__.py

160 lines
5.9 KiB
Python

from argparse import ArgumentParser, ArgumentTypeError
from pathlib import Path
from typing import Any, Optional, Sequence

from .config import CONFIG
# Key under which argparse stores the name of the selected sub-command.
CMD = 'command'

# Sub-command names.
TRAIN = 'train'
INFER = 'infer'

# Option keys below double as argparse `dest` names (and thus as keys of the dict
# returned by `parse_cli`); their long flags use '-' in place of '_'.

# Options of the `train` sub-command (FILE_EXT is a top-level option used by both commands).
DATA_DIR = 'data_dir'
SAVE_DIR = 'save_dir'
FILE_EXT = 'file_ext'
BATCH_SIZE = 'batch_size'
VALIDATION_RATIO = 'validation_ratio'
IMG_WIDTH = 'img_width'
IMG_HEIGHT = 'img_height'
NUM_EPOCHS = 'num_epochs'
EARLY_STOPPING_PATIENCE = 'early_stopping_patience'

# Options of the `infer` sub-command.
MODEL_DIR = 'model_dir'
IMAGES_DIR = 'images_dir'
IMAGE_FILES = 'image_files'
PLOT_RESULTS = 'plot_results'

# Keys of the options that `train` feeds into dataset loading/preprocessing ...
_PREPROCESSING_KEYS = (
    DATA_DIR,
    FILE_EXT,
    BATCH_SIZE,
    VALIDATION_RATIO,
    IMG_WIDTH,
    IMG_HEIGHT,
)
# ... and of the options listed under "Training options" in `train --help`.
_TRAINING_KEYS = (
    SAVE_DIR,
    NUM_EPOCHS,
    EARLY_STOPPING_PATIENCE,
)
class _InvalidExtensionError(ArgumentTypeError, ValueError):
    """Raised by `ext_list` for an entry that is not a dot-prefixed file extension.

    Subclassing `ArgumentTypeError` makes argparse show this exception's message to the
    user (for a plain `ValueError` raised by a `type=` callable, argparse discards the
    message and prints a generic "invalid ext_list value" error instead). Subclassing
    `ValueError` keeps any direct caller that catches `ValueError` working.
    """


def ext_list(string: str) -> list[str]:
    """Parse a comma-separated list of file extensions, e.g. ``".png, .jpg"`` -> ``[".png", ".jpg"]``.

    Intended as an argparse ``type=`` callable; whitespace around each entry is stripped.

    Args:
        string: Comma-separated extensions, each of which must start with a dot.

    Returns:
        The stripped extensions, in the order given.

    Raises:
        ValueError: (also an `argparse.ArgumentTypeError`) if any entry does not start
            with a dot -- this includes empty entries, e.g. from an empty string or a
            trailing comma.
    """
    extensions = [part.strip() for part in string.split(',')]
    for ext in extensions:
        if not ext.startswith('.'):
            raise _InvalidExtensionError(f"Extensions must start with a dot, got {ext!r}")
    return extensions
def _long_flag(dest: str) -> str:
    """Return the long-option spelling of an argparse dest name, e.g. ``'img_width'`` -> ``'--img-width'``."""
    return f'--{dest.replace("_", "-")}'


def _add_train_command(subparsers: Any) -> None:
    """Register the ``train`` sub-command, with its preprocessing and training options, on *subparsers*."""
    parser_train = subparsers.add_parser(TRAIN, help="trains a new model")
    parser_train.add_argument(
        DATA_DIR,
        type=Path,
        help="Directory containing the image files to be used for training/testing the model."
    )
    # Argument groups only affect how `--help` is laid out; all parsed values still
    # end up side by side in the same namespace.
    preprocessing_group = parser_train.add_argument_group("Preprocessing options")
    training_group = parser_train.add_argument_group("Training options")
    training_group.add_argument(
        '-s', _long_flag(SAVE_DIR),
        default=CONFIG.DEFAULT_SAVE_DIR,
        type=Path,
        help=f"Directory in which to save trained models. A subdirectory for each training session named with the "
             f"current date and time will be created there and the model will be saved in that subdirectory. "
             f"Defaults to '{CONFIG.DEFAULT_SAVE_DIR}'."
    )
    preprocessing_group.add_argument(
        '-b', _long_flag(BATCH_SIZE),
        default=CONFIG.DEFAULT_BATCH_SIZE,
        type=int,
        help=f"The dataset will be divided into batches; this determines the number of images in each batch. "
             f"Defaults to {CONFIG.DEFAULT_BATCH_SIZE}."
    )
    preprocessing_group.add_argument(
        '-r', _long_flag(VALIDATION_RATIO),
        default=CONFIG.DEFAULT_VALIDATION_RATIO,
        type=float,
        help=f"The dataset will be split into training and validation data; this argument should be a float "
             f"between 0 and 1 determining the relative size of the validation dataset to the whole dataset. "
             f"Defaults to {round(CONFIG.DEFAULT_VALIDATION_RATIO, 3)}."
    )
    preprocessing_group.add_argument(
        '-W', _long_flag(IMG_WIDTH),
        default=CONFIG.DEFAULT_IMG_WIDTH,
        type=int,
        help=f"The width of an image in pixels. Defaults to {CONFIG.DEFAULT_IMG_WIDTH}."
    )
    preprocessing_group.add_argument(
        '-H', _long_flag(IMG_HEIGHT),
        default=CONFIG.DEFAULT_IMG_HEIGHT,
        type=int,
        help=f"The height of an image in pixels. Defaults to {CONFIG.DEFAULT_IMG_HEIGHT}."
    )
    training_group.add_argument(
        '-n', _long_flag(NUM_EPOCHS),
        default=CONFIG.DEFAULT_NUM_EPOCHS,
        type=int,
        help=f"The number of training epochs. Defaults to {CONFIG.DEFAULT_NUM_EPOCHS}."
    )
    training_group.add_argument(
        '-p', _long_flag(EARLY_STOPPING_PATIENCE),
        default=CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE,
        type=int,
        help=f"The number of training epochs with no improvement over a previously achieved optimum to allow before "
             f"stopping training early (i.e. without completing all epochs). "
             f"Defaults to {CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE}."
    )


def _add_infer_command(subparsers: Any) -> None:
    """Register the ``infer`` sub-command and its options on *subparsers*."""
    parser_infer = subparsers.add_parser(INFER, help="uses an existing model to make inferences")
    parser_infer.add_argument(
        MODEL_DIR,
        type=Path,
        help="Directory containing the model to use for inference."
    )
    # At most one image source may be given; this parser does not require either.
    data_group = parser_infer.add_mutually_exclusive_group()
    data_group.add_argument(
        '-f', _long_flag(IMAGE_FILES),
        type=Path,
        nargs='*',
        metavar='PATH',
        help="Paths to image files to use the model on."
    )
    data_group.add_argument(
        '-d', _long_flag(IMAGES_DIR),
        type=Path,
        metavar='PATH',
        help="Path to directory containing the image files to use the model on."
    )
    parser_infer.add_argument(
        '-p', _long_flag(PLOT_RESULTS),
        action='store_true',
        help="If set, a plot will be displayed, showing the images with the inferred labels."
    )


def parse_cli(args: Optional[Sequence[str]] = None) -> dict[str, Any]:
    """Build the command-line parser and parse *args*.

    Args:
        args: Arguments to parse; ``None`` (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The parsed options as a plain dict keyed by argparse dest name. It always holds
        ``CMD`` (the chosen sub-command) and ``FILE_EXT``, plus the chosen sub-command's options.

    Raises:
        SystemExit: raised by argparse on invalid input, a missing sub-command, or ``--help``.
    """
    parser = ArgumentParser(
        prog=CONFIG.PROGRAM_NAME,
        description="Character CAPTCHA Solver",
    )
    # Top-level option, so it must be given before the sub-command name.
    parser.add_argument(
        '-E', _long_flag(FILE_EXT),
        default=CONFIG.DEFAULT_IMG_FILE_EXT,
        type=ext_list,
        help=f"When used in `{TRAIN}` mode, extensions of the image files to be used for training/testing the model. "
             f"When used in `{INFER}` mode, extensions of the image files to use the model on. "
             f"Defaults to {CONFIG.DEFAULT_IMG_FILE_EXT}."
    )
    # Without `required=True` argparse accepts a missing sub-command and stores None under CMD.
    subparsers = parser.add_subparsers(dest=CMD, title="Commands", required=True)
    _add_train_command(subparsers)
    _add_infer_command(subparsers)
    return vars(parser.parse_args(args))
def main() -> None:
    """Parse the command line and dispatch to the selected sub-command.

    ``train``: the preprocessing options are passed to ``preprocess.load_datasets``; its
    three results plus all remaining options are passed to ``model.start``.
    ``infer``: all options are passed to ``infer.start``.

    Raises:
        SystemExit: with a non-zero status and a message if no known sub-command was selected.
    """
    kwargs = parse_cli()
    cmd = kwargs.pop(CMD)
    # Imports are deferred so that only the modules needed by the chosen command get loaded.
    if cmd == TRAIN:
        from .model import start
        from .preprocess import load_datasets
        pre_kwargs = {k: kwargs.pop(k) for k in _PREPROCESSING_KEYS}
        training_data, validation_data, vocabulary = load_datasets(**pre_kwargs)
        # Whatever options remain after popping the preprocessing ones go to training.
        start(training_data, validation_data, vocabulary, **kwargs)
    elif cmd == INFER:
        from .infer import start
        start(**kwargs)
    else:
        # Reachable whenever argparse lets a missing sub-command through (cmd is None).
        # A bare `raise SystemExit` would exit silently with status 0, i.e. report success.
        raise SystemExit(
            f"{CONFIG.PROGRAM_NAME}: expected a command ({TRAIN!r} or {INFER!r}), got {cmd!r}; "
            f"run with --help for usage."
        )


if __name__ == '__main__':
    main()