5. 파이썬
83536 fashion_checker.py
패스트코드블로그
2020. 6. 2. 22:27
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
from mnist_test.number_checker import NumberChecker
# NOTE: the former `from mnist_test.fashion_checker import FashionChecker`
# was removed: this file *is* that module, so the self-import re-entered a
# partially initialised module (ImportError) and shadowed nothing useful.


class FashionChecker:
    """Train and evaluate a small dense classifier on Fashion-MNIST."""

    def __init__(self):
        # Human-readable label names; position i is the name of class id i.
        self.class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    def create_model(self, epochs: int = 5) -> keras.Sequential:
        """Build, train, and evaluate the classifier, then return it.

        Loads Fashion-MNIST through ``keras.datasets``, trains a
        Flatten -> Dense(128, relu) -> Dense(softmax) network, and prints
        the accuracy on the held-out test split.

        Args:
            epochs: Number of passes over the training set (defaults to 5,
                the value previously hard-coded).

        Returns:
            The compiled and trained ``keras.Sequential`` model.
        """
        fashion_mnist = keras.datasets.fashion_mnist
        (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

        # Scale pixel intensities from [0, 255] to [0.0, 1.0]; the dense
        # layers train much better on normalised inputs.
        train_images = train_images / 255.0
        test_images = test_images / 255.0

        # Modeling: flatten each 28x28 image to a vector, one hidden ReLU
        # layer, then a softmax with one output per class name.
        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(28, 28)),
            keras.layers.Dense(128, activation='relu'),
            keras.layers.Dense(len(self.class_names), activation='softmax'),
        ])
        # Labels are integer class ids, hence the *sparse* categorical loss.
        model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

        # Learning
        model.fit(train_images, train_labels, epochs=epochs)

        # Test
        _, test_acc = model.evaluate(test_images, test_labels)
        print('테스트 정확도: {}'.format(test_acc))
        return model


if __name__ == '__main__':
    # Replace with NumberChecker() to run the digit classifier instead.
    t = FashionChecker()
    t.create_model()