딥러닝
[딥러닝] 딥러닝 모델 저장하고 불러오기
eyoo
2022. 6. 14. 11:43
모델을 저장하는 방법은 크게 3가지로 분류할수있다.
- 전체 네트워크와 웨이트를 함께 저장
- 네트워크만 저장
- 웨이트만 저장
전체 네트워크와 웨이트를 함께 저장할때는 간단히 save 함수를 사용하면 된다.
model.save('fashion_mnist_model')
# 이렇게 하면 폴더로 저장된다.
# 코랩을 사용하여 구글 드라이브 마운트 후, 경로를 지정하고 저장했다.
파일 확장자를 함께 쓰면 파일로 저장된다.
model.save('fashion_mnist_model.h5')
# 모델을 저장하는 파일 확장자는 h5를 많이쓴다.
load_model 함수를 사용하여 모델을 불러올수있다.
loaded_model = tf.keras.models.load_model('fashion_mnist_model.h5')
# 기존 만들어 저장했던 모델과 똑같은 모델을 불러왔다.
네트워크만 저장할때는 to_json 함수를 활용한다.
my_network = model.to_json()
with open('my_network','w') as json_file:
json_file.write(my_network)
# open 함수에 저장한 변수와 w(write)를 사용하고 write함수에 파일명을 입력하여 저장했다.
저장된 네트워크를 불러오자.
with open('my_network','r') as json_file:
my_net2 = json_file.read()
# 이번엔 open에 불러올 파일명과 r(read)를 사용하고 read함수를 사용하여 불러왔다.
model_from_json을 사용하여 네트워크를 구성해보자.
loaded_model2 = tf.keras.models.model_from_json(my_net2)
# 웨이트는 없고 모델 구조만 불러온 것이다.
# 학습이 안된 랜덤한 웨이트로 세팅된 모델인것에 주의하자.
이제 save_weights로 웨이트만 저장해보자.
model.save_weights('fashion_mnist_weight.h5')
저장된 웨이트를 load_weights를 사용하여 위의 빈 모델에 적용하자.
loaded_model2.load_weights('fashion_mnist_weight.h5')
# 기존에 있던 네트워크와 웨이트가 적용된 모델, 즉 기존과 동일한 모델을 불러왔다.