# generated from daniil-berg/boilerplate-py
from argparse import ArgumentParser
from pathlib import Path
from typing import Any, Optional, Sequence

from .config import CONFIG
|
|
|
|
|
|
# Destination key under which the argument parser stores the chosen subcommand.
CMD = 'command'

# `train` subcommand name and the destination keys of its arguments.
TRAIN = 'train'

DATA_DIR = 'data_dir'
SAVE_DIR = 'save_dir'
FILE_EXT = 'file_ext'
BATCH_SIZE = 'batch_size'
VALIDATION_RATIO = 'validation_ratio'
IMG_WIDTH = 'img_width'
IMG_HEIGHT = 'img_height'
NUM_EPOCHS = 'num_epochs'
EARLY_STOPPING_PATIENCE = 'early_stopping_patience'

# Keys popped from the parsed arguments and forwarded to `load_datasets` in `main`.
_PREPROCESSING_KEYS = (DATA_DIR, FILE_EXT, BATCH_SIZE, VALIDATION_RATIO, IMG_WIDTH, IMG_HEIGHT)
# NOTE(review): `_TRAINING_KEYS` is not referenced anywhere in this file — in `main`
# the remaining kwargs are passed to `start` wholesale. Possibly dead; confirm before removing.
_TRAINING_KEYS = (SAVE_DIR, NUM_EPOCHS, EARLY_STOPPING_PATIENCE)

# `infer` subcommand name and the destination keys of its arguments.
INFER = 'infer'
MODEL_DIR = 'model_dir'
IMAGES_DIR = 'images_dir'
IMAGE_FILES = 'image_files'
PLOT_RESULTS = 'plot_results'
|
|
|
|
|
|
def ext_list(string: str) -> list[str]:
    """Parse a comma-separated string of file extensions into a list.

    Intended for use as an argparse `type=` converter.

    Args:
        string: Comma-separated extensions, e.g. ``".png,.jpg"``.
            Whitespace around each item is stripped.

    Returns:
        List of extension strings, each starting with a dot.

    Raises:
        ValueError: If any item does not start with a dot.
    """
    out = []
    for ext in string.split(','):
        ext = ext.strip()
        if not ext.startswith('.'):
            # Include the offending value so the argparse error output is actionable.
            raise ValueError(f"Extensions must start with a dot: {ext!r}")
        out.append(ext)
    return out
|
|
|
|
|
|
def parse_cli(args: Optional[Sequence[str]] = None) -> dict[str, Any]:
    """Parse command line arguments into a keyword dictionary.

    Builds the full parser (global options plus the `train` and `infer`
    subcommands), parses `args`, and returns the namespace as a plain dict.

    Args:
        args: Argument strings to parse; `None` lets argparse fall back to
            `sys.argv[1:]`.

    Returns:
        Mapping of argument destination names (see the module-level key
        constants) to their parsed values; the chosen subcommand is stored
        under `CMD`.
    """
    parser = ArgumentParser(
        prog=CONFIG.PROGRAM_NAME,
        description="Character CAPTCHA Solver",
    )
    parser.add_argument(
        '-E', f'--{FILE_EXT.replace("_", "-")}',
        default=CONFIG.DEFAULT_IMG_FILE_EXT,
        type=ext_list,
        help=f"When used in `{TRAIN}` mode, extensions of the image files to be used for training/testing the model. "
             f"When used in `{INFER}` mode, extensions of the image files to use the model on. "
             f"Defaults to {CONFIG.DEFAULT_IMG_FILE_EXT}."
    )
    subparsers = parser.add_subparsers(dest=CMD, title="Commands")
    _add_train_parser(subparsers)
    _add_infer_parser(subparsers)
    return vars(parser.parse_args(args))


def _add_train_parser(subparsers: Any) -> None:
    """Attach the `train` subcommand with its preprocessing and training options."""
    parser_train = subparsers.add_parser(TRAIN, help="trains a new model")
    parser_train.add_argument(
        DATA_DIR,
        type=Path,
        help="Directory containing the image files to be used for training/testing the model."
    )
    preprocessing_group = parser_train.add_argument_group("Preprocessing options")
    training_group = parser_train.add_argument_group("Training options")
    training_group.add_argument(
        '-s', f'--{SAVE_DIR.replace("_", "-")}',
        default=CONFIG.DEFAULT_SAVE_DIR,
        type=Path,
        help=f"Directory in which to save trained models. A subdirectory for each training session named with the "
             f"current date and time will be created there and the model will be saved in that subdirectory. "
             f"Defaults to '{CONFIG.DEFAULT_SAVE_DIR}'."
    )
    preprocessing_group.add_argument(
        '-b', f'--{BATCH_SIZE.replace("_", "-")}',
        default=CONFIG.DEFAULT_BATCH_SIZE,
        type=int,
        help=f"The dataset will be divided into batches; this determines the number of images in each batch. "
             f"Defaults to {CONFIG.DEFAULT_BATCH_SIZE}."
    )
    preprocessing_group.add_argument(
        '-r', f'--{VALIDATION_RATIO.replace("_", "-")}',
        default=CONFIG.DEFAULT_VALIDATION_RATIO,
        type=float,
        help=f"The dataset will split into training and validation data; this argument should be a float between 0 "
             f"and 1 determining the relative size of the validation dataset to the whole dataset. "
             f"Defaults to {round(CONFIG.DEFAULT_VALIDATION_RATIO, 3)}."
    )
    preprocessing_group.add_argument(
        '-W', f'--{IMG_WIDTH.replace("_", "-")}',
        default=CONFIG.DEFAULT_IMG_WIDTH,
        type=int,
        help=f"The width of an image in pixels. Defaults to {CONFIG.DEFAULT_IMG_WIDTH}."
    )
    preprocessing_group.add_argument(
        '-H', f'--{IMG_HEIGHT.replace("_", "-")}',
        default=CONFIG.DEFAULT_IMG_HEIGHT,
        type=int,
        help=f"The height of an image in pixels. Defaults to {CONFIG.DEFAULT_IMG_HEIGHT}."
    )
    training_group.add_argument(
        '-n', f'--{NUM_EPOCHS.replace("_", "-")}',
        default=CONFIG.DEFAULT_NUM_EPOCHS,
        type=int,
        help=f"The number of training epochs. Defaults to {CONFIG.DEFAULT_NUM_EPOCHS}."
    )
    training_group.add_argument(
        '-p', f'--{EARLY_STOPPING_PATIENCE.replace("_", "-")}',
        default=CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE,
        type=int,
        help=f"The number of training epochs with no improvement over a previously achieved optimum to allow before "
             f"stopping training early (i.e. without completing all epochs). "
             f"Defaults to {CONFIG.DEFAULT_EARLY_STOPPING_PATIENCE}."
    )


def _add_infer_parser(subparsers: Any) -> None:
    """Attach the `infer` subcommand with its data-source and output options."""
    parser_infer = subparsers.add_parser(INFER, help="uses an existing model to make inferences")
    parser_infer.add_argument(
        MODEL_DIR,
        type=Path,
        help="Directory containing the model to use for inference."
    )
    # Explicit image files and an images directory are mutually exclusive data sources.
    data_group = parser_infer.add_mutually_exclusive_group()
    data_group.add_argument(
        '-f', f'--{IMAGE_FILES.replace("_", "-")}',
        type=Path,
        nargs='*',
        metavar='PATH',
        help="Paths to image files to use the model on."
    )
    data_group.add_argument(
        '-d', f'--{IMAGES_DIR.replace("_", "-")}',
        type=Path,
        metavar='PATH',
        help="Path to directory containing the image files to use the model on."
    )
    parser_infer.add_argument(
        '-p', f'--{PLOT_RESULTS.replace("_", "-")}',
        action='store_true',
        help="If set, a plot will be displayed, showing the images with the inferred labels."
    )
|
|
|
|
|
|
def main() -> None:
    """Entry point: parse the CLI arguments and dispatch to the chosen command.

    Heavy project modules are imported lazily inside each branch so that only
    the code needed for the selected command is loaded.
    """
    kwargs = parse_cli()
    command = kwargs.pop(CMD)
    if command == TRAIN:
        from .model import start
        from .preprocess import load_datasets
        # Split off the preprocessing arguments; everything left goes to training.
        preprocessing_kwargs = {key: kwargs.pop(key) for key in _PREPROCESSING_KEYS}
        training_data, validation_data, vocabulary = load_datasets(**preprocessing_kwargs)
        start(training_data, validation_data, vocabulary, **kwargs)
        return
    if command == INFER:
        from .infer import start
        start(**kwargs)
        return
    raise SystemExit  # Should be unreachable since argument parser will throw an error earlier
|
|
|
|
|
|
# Allow running this module directly as a script.
if __name__ == '__main__':
    main()
|