본문 바로가기
딥러닝

[딥러닝] 딥러닝 모델 저장하고 불러오기

by eyoo 2022. 6. 14.

모델을 저장하는 방법은 크게 3가지로 분류할수있다.

 

  1. 전체 네트워크와 웨이트를 함께 저장
  2. 네트워크만 저장
  3. 웨이트만 저장

전체 네트워크와 웨이트를 함께 저장할때는 간단히 save 함수를 사용하면 된다.

 

model.save('fashion_mnist_model')

# 이렇게 하면 폴더로 저장된다.

# 코랩을 사용하여 구글 드라이브 마운트 후, 경로를 지정하고 저장했다.

 

 

파일 확장자를 함께 쓰면 파일로 저장된다.

 

model.save('fashion_mnist_model.h5')

# 모델을 저장하는 파일 확장자는 h5를 많이쓴다.

 

 

load_model 함수를 사용하여 모델을 불러올수있다.

 

loaded_model = tf.keras.models.load_model('fashion_mnist_model.h5')

# 기존 만들어 저장했던 모델과 똑같은 모델을 불러왔다.

 

 

네트워크만 저장할때는 to_json 함수를 활용한다.

 

my_network = model.to_json()

with open('my_network','w') as json_file:
  json_file.write(my_network)

# open 함수에 저장한 변수와 w(write)를 사용하고 write함수에 파일명을 입력하여 저장했다.

 

 

저장된 네트워크를 불러오자.

 

with open('my_network','r') as json_file:
  my_net2 = json_file.read()

# 이번엔 open에 불러올 파일명과 r(read)를 사용하고 read함수를 사용하여 불러왔다.

 

 

model_from_json을 사용하여 네트워크를 구성해보자.

 

loaded_model2 = tf.keras.models.model_from_json(my_net2)

# 웨이트는 없고 모델 구조만 불러온 것이다.

# 학습이 안된 랜덤한 웨이트로 세팅된 모델인것에 주의하자.

 

 

이제 save_weights로 웨이트만 저장해보자.

 

model.save_weights('fashion_mnist_weight.h5')

 

 

저장된 웨이트를 load_weights를 사용하여 위의 빈 모델에 적용하자.

 

loaded_model2.load_weights('fashion_mnist_weight.h5')

# 기존에 있던 네트워크와 웨이트가 적용된 모델, 즉 기존과 동일한 모델을 불러왔다.

 

 

 

 

 

댓글