본문 바로가기
딥러닝

[딥러닝] ModelCheckpoint 와 CSVLogger 콜백

by eyoo 2022. 6. 16.

ModelCheckpoint 콜백을 활용하여 학습을 할때 가장 좋은 모델을 자동으로 저장할수있다.

 

또한 CSVLogger 콜백을 통해 학습한 결과를 csv로 저장할수 있다.

 

먼저 해당 라이브러리를 임포트 한다.

 

from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import CSVLogger

 

그리고 파이썬으로 디렉토리를 만드는 코드를 작성한다.

 

PROJECT_PATH= '.'  # 현재 디렉토리, os.chdir로 구글드라이브 경로로 셋팅

if not os.path.exists(PROJECT_PATH + '/checkpoints/' + model_type + '/'):
  os.makedirs(PROJECT_PATH + '/checkpoints/' + model_type + '/')
if not os.path.exists(PROJECT_PATH + '/log/' + model_type + '/'):
  os.makedirs(PROJECT_PATH + '/log/' + model_type + '/')

# checkpoints와 log 디렉토리가 없으면 해당 폴더를 만들어준다.

 

체크포인트 콜백을 변수로 저장한다.

 

cp = ModelCheckpoint(CHECKPOINT_PATH,monitor= 'val_accuracy', save_best_only=True, verbose=1)

#

 

로그 콜백을 변수로 저장한다.

 

csv_logger = CSVLogger(LOGFILE_PATH, append= True)

 

 

학습할때 callbacks 파라미터에 넣어서 사용한다.

 

epoch_history = model.fit(train_generator, epochs=40, validation_data= (X_val, y_val), callbacks=[cp,csv_logger], steps_per_epoch=10)

#

 

 

 

 

 

 

댓글