generated from daniil-berg/boilerplate-py
29 lines
938 B
Python
29 lines
938 B
Python
from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def plot_images(images: Sequence[np.ndarray], labels: Optional[Sequence[str]] = None,
                num_columns: int = 4, transpose: bool = True) -> None:
    """Show a batch of images in a grid of grayscale subplots.

    Only the first channel of each image is drawn. Pixel values are
    multiplied by 255 and cast to ``uint8`` before display, so the images are
    presumably floats in ``[0, 1]`` (NOTE(review): this is not checked).

    Args:
        images: A batch indexable as ``(N, dim1, dim2, channels)``, e.g. a
            4-D NumPy array, an eager ``tf.Tensor`` or a sequence of
            equally-shaped 3-D arrays.
        labels: Optional subplot titles, one per image, in the same order
            as *images*.
        num_columns: Number of subplot columns. Enough rows are created for
            every image to get its own cell.
        transpose: If true, swap the two spatial axes (``dim1`` and ``dim2``)
            of every image before drawing.

    Raises:
        ValueError: If *num_columns* is less than 1.
    """
    if num_columns < 1:
        raise ValueError(f"num_columns must be at least 1, got {num_columns}")

    # np.asarray accepts NumPy arrays, sequences of arrays and eager tf.Tensors
    # alike, so the rest of the function works on plain NumPy data no matter
    # what the caller passed (and no matter the value of `transpose`).
    pixels = np.asarray(images)
    if len(pixels) == 0:
        return  # nothing to draw

    if transpose:
        pixels = np.transpose(pixels, (0, 2, 1, 3))
    # Keep only the first channel and rescale to 8-bit grey levels.
    pixels = (pixels[:, :, :, 0] * 255).astype('uint8')

    # Ceiling division: a partially filled last row must still get its cells.
    num_rows = (len(pixels) + num_columns - 1) // num_columns
    # squeeze=False always returns a 2-D array of Axes, so there is no need to
    # special-case a single row or a single column.
    _, axs = plt.subplots(num_rows, num_columns, figsize=(10, 5), squeeze=False)

    for idx, ax in enumerate(axs.flat):
        ax.axis('off')
        if idx >= len(pixels):
            continue  # unused cell in the last row: leave it blank
        ax.imshow(pixels[idx], cmap='gray')
        if labels is not None:
            ax.set_title(labels[idx])

    plt.show()