ccaptchas/src/ccaptchas/visualize.py

29 lines
938 B
Python

from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
def plot_images(images: Sequence[np.ndarray], labels: Optional[Sequence[str]] = None, num_columns: int = 4,
                transpose: bool = True, figsize: tuple = (10, 5)) -> None:
    """Display a batch of images in a grid, showing channel 0 in grayscale.

    Args:
        images: Batch indexable as ``images[i, d1, d2, c]``: a 4-D array or
            tensor, or a sequence of equally-shaped 3-D arrays that
            ``np.asarray`` can stack. Only channel 0 is shown. Values are
            multiplied by 255 and cast to ``uint8``, so inputs are presumably
            floats in [0, 1] (TODO confirm against callers).
        labels: Optional per-image titles; ``labels[i]`` titles ``images[i]``.
        num_columns: Number of grid columns. Enough rows are created for
            every image to get its own axis.
        transpose: If True, swap axes 1 and 2 of each image before plotting
            (presumably because images are stored width-first).
        figsize: Matplotlib figure size in inches.
    """
    # np.asarray accepts ndarrays, lists of ndarrays and eager tf tensors
    # alike, so both the transpose and non-transpose paths work.
    batch = np.asarray(images)
    if len(batch) == 0:
        return  # nothing to draw
    if transpose:
        batch = np.transpose(batch, (0, 2, 1, 3))
    batch = (batch[:, :, :, 0] * 255).astype('uint8')

    # Ceiling division: a partially-filled last row still needs its axes.
    num_rows = -(-len(batch) // num_columns)
    # squeeze=False always yields a 2-D axes array, so a single indexing
    # scheme works for one-row, one-column and full-grid layouts.
    _, axs = plt.subplots(num_rows, num_columns, figsize=figsize, squeeze=False)
    for idx, ax in enumerate(axs.flat):
        if idx < len(batch):
            ax.imshow(batch[idx], cmap='gray')
            if labels is not None:
                ax.set_title(labels[idx])
        # Hide ticks/frames, including on unused trailing cells.
        ax.axis('off')
    plt.show()