딥러닝
[딥러닝] MNIST 손글씨 숫자 실습, 원핫인코딩: to_categorical, categorical_crossentropy
eyoo
2022. 6. 14. 11:49
MNist 데이터를 가져온다.
# 이미 7만장의 손글씨 이미지 데이터가 있다.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
첫번째 학습 데이터를 시각화 해보자
in;
plt.imshow(X_train[0],cmap='gray')
plt.show()
out:
fromarray을 사용하여 실제 이미지를 볼수있다.
in:
Image.fromarray(X_train[0])
out:
실습 1. 데이터를 딥러닝으로 처리하기 위해서, 행렬로 만들면서, 가로세로 값을 일렬로 만든다.
X_train 형태를 확인한다.
in:
X_train.shape
out:
(60000, 28, 28)
학습할수있는 형태로 변환한다.
X_train = X_train.reshape(60000,28*28)
X_train 형태를 다시 확인한다.
in:
X_train.shape
out:
(60000, 784)
X_test의 형태도 변환한다.
X_test = X_test.reshape(10000,28*28)
실습 2. 데이터를 딥러닝에서 처리할 수 있도록 float로 바꿔준다.
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
실습 3. 이미지라서, 숫자가 0~255 이므로, 0~1 사이로 정규화 시켜주자.
X_train = X_train/255.0
X_test = X_test/255.0
실습 4. 분류의 문제이므로, y값을 확인하여, 카테고리컬 데이터를 원핫인코딩값으로 바꾼다.
to_categorical 함수를 사용하여 변환할수있다.
in:
from tensorflow.keras.utils import to_categorical
y_train = to_categorical(y_train, num_classes = 10)
y_train[0]
out:
array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)
# 원핫인코딩 되었다.
y_test 도 원핫인코딩 하자.
y_test = to_categorical(y_test, num_classes = 10)
실습 5. 모델을 만든다.
model = Sequential()
model.add(Dense(512, 'relu', input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(10, 'softmax'))
# 사진 데이터가 28행 28열 이기에 784를 input_shape에 넣었다.
실습 6. 컴파일 한다.
model.compile('adam','categorical_crossentropy',metrics=['accuracy'])
# 원핫인코딩을한 3개이상의 카테고리컬 데이터기에 categorical_crossentropy 로스함수를 사용한다.
실습 7. 학습시킨다.
epoch_history = model.fit(X_train,y_train, epochs= 5, validation_data= (X_test,y_test))
실습 8. 모델 평가.
in:
model.evaluate(X_test,y_test)
out:
313/313 [==============================] - 1s 3ms/step - loss: 0.0735 - accuracy: 0.9812
[0.07351833581924438, 0.9811999797821045]
# 오차값 0.07, 정확도 0.98이 나왔다.
컨퓨젼 매트릭스를 히트맵으로 나타내자.
in:
from sklearn.metrics import confusion_matrix
y_pred = model.predict(X_test)
y_pred = y_pred.argmax(axis=1)
y_test = y_test.argmax(axis=1)
cm = confusion_matrix(y_test,y_pred)
sb.heatmap(cm,annot= True, fmt= '.0f' ,cmap='RdPu')
plt.show()
out:

# 원핫인코딩 된 데이터를 argmax를 이용하여 다시 변환시켰다.