인공지능/딥러닝

딥러닝 : Tensorflow의 콜백클래스를 이용해서 원하는 조건이 되면 학습을 멈추게 하기

왕현성 2022. 12. 29. 11:20
728x90

https://hyunsungstory.tistory.com/161

 

딥러닝 : Fashion mnist 10개로 분류된 패션 이미지를 tensorflow를 이용해 분류하기 / Flatten , softmax , 분

1. 이미지와 행렬 모든 이미지 사진은 픽셀당 숫자로 되어있습니다. 0~255까지 되어있고 0이 검정색, 255가 흰색입니다. 그 숫자의 데이터 타입은 Uint8(Unsigned int) 이라고 적습니다. 먼저 검정부터 회

hyunsungstory.tistory.com

이전 게시글의 데이터셋을 이용합니다.

 

epochs가 무조건 많다고 좋은 것이 아님을 이 바로 전 게시글에서 확인할 수 있었습니다. ( 오버 피팅 )

 

그래서 이제 val_accuracy가 88%가 넘으면 멈추도록 하고싶다.

 

class myCallback(tf.keras.callbacks.Callback) :
  def on_epoch_end(self,epoch,logs={}) :
    if logs['val_accuracy'] > 0.88:
      print('\n내가 정한 정확도에 도달했으니, 학습을 멈춘다')
      self.model.stop_training = True

위와 같이 원하는 조건이 되면 학습을 멈추게하는 코드를 작성해주고

my_cb=myCallback()

이를 변수로 저장합니다.

 

def build_model():
  model = Sequential()
  model.add( Flatten()  )
  model.add( Dense(128, 'relu') )
  model.add( Dense(64, 'relu') )
  model.add( Dense(10, 'softmax'))
  model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
  return model
model = build_model()
epoch_history = model.fit(X_train,y_train,epochs=30,validation_split=0.2,callbacks=[my_cb])

모델링 이후 변수 저장후 학습을 진행하게 되면

epochs를 30으로 설정 했어도 val_accuracy가 88%가 넘는 6번 째 epochs에서 학습이 끝난 것을 확인할 수 있다.