02. Study/Keras
keras model 저장 및 callback
미카이
2019. 8. 7. 23:14
딥러닝 학습중 커널이 죽는 경우가 종종 발생하는데, 그럴때 항상 처음부터 모델을 학습하기에는 너무 오랜시간이 걸리고 다시 학습시 weight들의 초기값에 따라 결과가 조금씩 달라질 수 있는데, 이럴때 사용할 수 있는 방법이 epoch마다 weight 저장과 모델을 저장하는것이다. 또한, 매번 저장하면 용량이 커질수 있으니 val_loss가 낮아질때마다 저장시킬 수 있는 callback방법도 알아보자.
일반적으로 h5와 hdf5 형식으로 많이 저장하는데, 한번 특징을 알아보자
Hierachical Data Format Version 5(.h5, .hdf5) - 대용량 데이터를 저장을 위한 파일 포맷이며 NASA에서 대용량 데이터를 저장하기 위한 도구로 개발되었고 tree 구조를 가지고 있다.
ref : https://www.hdfgroup.org/
ref : https://reference.wolfram.com/language/ref/format/HDF5.html
Weight만 저장
1
2
3
4
5
6
7
8
9
10
11
12
|
import keras
import numpy
y = [1,3,5,7,9]
model = keras.models.Sequential()
model.compile('SGD', 'mse')
model.save_weights('model.h5') #학습된 weight 저장
http://colorscripter.com/info#e" target="_blank" style="color:#4f4f4ftext-decoration:none">Colored by Color Scripter
|
Model 저장
1
2
3
4
5
6
7
8
9
10
11
12
|
import keras
import numpy
y = [1,3,5,7,9]
model = keras.models.Sequential()
model.compile('SGD', 'mse')
http://colorscripter.com/info#e" target="_blank" style="color:#4f4f4ftext-decoration:none">Colored by Color Scripter
|
Weight load - weight load 시 model 구조가 동일해야됨
1
2
3
4
5
6
7
8
9
10
11
12
|
import numpy
import keras
model = keras.models.Sequential()
model.compile(optimizer='SGD', loss='mse')
model.load_weights("model.h5")
print('Predictions:', model.predict(x).flatten()) # 예측 결과
http://colorscripter.com/info#e" target="_blank" style="color:#4f4f4ftext-decoration:none">Colored by Color Scripter
|
Model load
1
2
3
4
5
6
7
8
9
10
|
import numpy
model=load_model('model.h5')
print('Predictions:', model.predict(x).flatten()) # 예측 결과
http://colorscripter.com/info#e" target="_blank" style="color:#4f4f4ftext-decoration:none">Colored by Color Scripter
|
keras 에서 제공하는 callback함수는 모델 훈련과정을 제어 할 수 있다. (checkpoint, earlystopping etc...)
ref : https://keras.io/callbacks/
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
callback_list = [
keras.callbacks.EarlyStopping( #성능 향상이 멈추면 훈련을 중지
monitor='val_acc', #모델 검증 정확도를 모니터링
patience=1 #1 에포크 보다 더 길게(즉, 2에포크 동안 정확도가 향상되지 않으면 훈련 중지
),
keras.callbacks.ModelCheckpoint( #에포크마다 현재 가중치를 저장
filepath="my_model.h5", #모델 파일 경로
monitor='val_loss', # val_loss 가 좋아지지 않으면 모델 파일을 덮어쓰지 않음.
save_best_only=True
)
]
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc']) #정확도를 모니터링하므로 모델 지표에 포함되어야 함.
history = model.fit(input_train, y_train,
epochs=10,
callbacks=callback_list,
batch_size=128,
validation_split=0.2)# 콜백이 검증과 손실 정확도를 모니터링 하기 때문에 validation_Data 매개변수에 검증 데이터 전달해야 함
http://colorscripter.com/info#e" target="_blank" style="color:#4f4f4ftext-decoration:none">Colored by Color Scripter
|
전체코드
model save_load - 클릭
callback - 클릭