인공지능/딥러닝
딥러닝 : Tensorflow의 모델을 저장하고 불러오는 방법
왕현성
2022. 12. 29. 15:04
728x90
https://hyunsungstory.tistory.com/161
딥러닝 : Fashion mnist 10개로 분류된 패션 이미지를 tensorflow를 이용해 분류하기 / Flatten , softmax , 분
1. 이미지와 행렬 모든 이미지 사진은 픽셀당 숫자로 되어있습니다. 0~255까지 되어있고 0이 검정색, 255가 흰색입니다. 그 숫자의 데이터 타입은 Uint8(Unsigned int) 이라고 적습니다. 먼저 검정부터 회
hyunsungstory.tistory.com
위 게시글과 똑같은 데이터셋을 이용하여
모델링 과정 이후에 모델을 저장하고 불러오는 방법들에 대해 설명하겠습니다.
1. 전체 네트워크와 웨이트를 통으로 저장하고 불러오기
- 폴더 구조로 저장,불러오기
# 폴더 구조로 저장.
model.save('fashion_mnist_model')
# 저장된 인공지능을 불러오는 코드.
model2 = tf.keras.models.load_model('fashion_mnist_model')
model2.evaluate(X_test,y_test)
정상 작동 확인
- 모델을 파일 하나로 저장, 불러오기
model.save('fashion_mnist_model.h5')
model3=tf.keras.models.load_model('fashion_mnist_model.h5')
model3.predict(X_test)
2. 네트워크와 웨이트를 따로 저장하고 불러오기.
- 네트워크만 저장하고 불러오기
model.to_json()
# 네트워크를 json 파일로 저장하는 코드
fashion_mnist_network=model.to_json()
with open('fashion_mnist_network.json' , 'w') as file :
file.write(fashion_mnist_network)
# 저장된 네트워크를 읽어오는 코드
with open('fashion_mnist_network.json','r') as file:
fashion_net = file.read()
# 위의 네트워크로부터 모델을 만들고싶으면
model4=tf.keras.models.model_from_json(fashion_net)
# model4는 네트워크만 가져온 것이지,,, 학습 완료된 웨이트는 가져온 것이 아니다
# 따라서 현재 웨이트는 랜덤으로 셋팅된 웨이트다.
# 이것으로 예측 수행하면 안된다.
- 웨이트를 저장하고 불러오기
# 웨이트를 저장하고 불러오는 코드
model.save_weights('fashion_mnist_weight.h5')
이후 위에서 만들 model4에 웨이트를 집어넣는 방법
model4.load_weights('fashion_mnist_weight.h5')
model4.predict(X_test)
정상 작동하는 것을 확인할 수 있다.