왕현성
코딩발자취
왕현성
전체 방문자
오늘
어제
  • 코딩 (277)
    • Python (71)
    • Java (16)
    • MySQL (34)
    • 인공지능 (48)
      • 머신러닝 (16)
      • 딥러닝 (32)
    • 영상처리 (4)
    • Rest API (21)
    • Android Studio (25)
    • streamlit (13)
    • DevOps (22)
      • AWS (9)
      • PuTTY (5)
      • Git (4)
      • Serverless (2)
      • Docker (2)
    • IT 기술 용어 (6)
    • 디버깅 ( 오류 해결 과정 ) (17)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

태그

  • unsupervised
  • 의료이미징
  • alibi-detection
  • get_long_description
  • yolov8
  • alibidetect
  • encoding='utf-8'
  • imageprocessing
  • 비지도학습
  • TensorFlow
  • pip install labelme
  • matplotlib
  • labelme UnocodeDecodeError
  • OpenCV
  • 영상처리
  • numpy
  • ComputerVision
  • 영상기술
  • 영상처리역사
  • tune()
  • 딥러닝
  • ckpt_file
  • PIL
  • UnboundLocalError
  • pytorch
  • 기상탐사
  • PYTHON
  • 컴퓨터비전
  • labelme
  • maskimage

최근 댓글

최근 글

티스토리

250x250
hELLO · Designed By 정상우.
왕현성

코딩발자취

딥러닝 : Tensorflow의 모델을 저장하고 불러오는 방법
인공지능/딥러닝

딥러닝 : Tensorflow의 모델을 저장하고 불러오는 방법

2022. 12. 29. 15:04
728x90

https://hyunsungstory.tistory.com/161

 

딥러닝 : Fashion mnist 10개로 분류된 패션 이미지를 tensorflow를 이용해 분류하기 / Flatten , softmax , 분

1. 이미지와 행렬 모든 이미지 사진은 픽셀당 숫자로 되어있습니다. 0~255까지 되어있고 0이 검정색, 255가 흰색입니다. 그 숫자의 데이터 타입은 Uint8(Unsigned int) 이라고 적습니다. 먼저 검정부터 회

hyunsungstory.tistory.com

 

위 게시글과 똑같은 데이터셋을 이용하여 

 

모델링 과정 이후에 모델을 저장하고 불러오는 방법들에 대해 설명하겠습니다.

 

1. 전체 네트워크와 웨이트를 통으로 저장하고 불러오기

 

- 폴더 구조로 저장,불러오기

# 폴더 구조로 저장.
model.save('fashion_mnist_model')
# 저장된 인공지능을 불러오는 코드.
model2 = tf.keras.models.load_model('fashion_mnist_model')
model2.evaluate(X_test,y_test)

정상 작동 확인

 

 - 모델을 파일 하나로 저장, 불러오기

model.save('fashion_mnist_model.h5')
model3=tf.keras.models.load_model('fashion_mnist_model.h5')
model3.predict(X_test)

2. 네트워크와 웨이트를 따로 저장하고 불러오기.

- 네트워크만 저장하고 불러오기

model.to_json()

# 네트워크를 json 파일로 저장하는 코드
fashion_mnist_network=model.to_json()
with open('fashion_mnist_network.json' , 'w') as file :
  file.write(fashion_mnist_network)
# 저장된 네트워크를 읽어오는 코드
with open('fashion_mnist_network.json','r') as file:
  fashion_net = file.read()
# 위의 네트워크로부터 모델을 만들고싶으면
model4=tf.keras.models.model_from_json(fashion_net)
# model4는 네트워크만 가져온 것이지,,, 학습 완료된 웨이트는 가져온 것이 아니다
# 따라서 현재 웨이트는 랜덤으로 셋팅된 웨이트다.
# 이것으로 예측 수행하면 안된다.
 
 
 

- 웨이트를 저장하고 불러오기

# 웨이트를 저장하고 불러오는 코드
model.save_weights('fashion_mnist_weight.h5')

이후 위에서 만들 model4에 웨이트를 집어넣는 방법

model4.load_weights('fashion_mnist_weight.h5')
model4.predict(X_test)

정상 작동하는 것을 확인할 수 있다.

'인공지능 > 딥러닝' 카테고리의 다른 글

CNN의 convolution,Stride,Padding,Pooling / feature map의 사이즈를 구하는 공식  (0) 2022.12.29
딥러닝 : 레이블링된 y값을 원핫 인코딩으로 바꾸기 tf.keras.utils.to_categorical() / Mnist 손글씨 숫자 예측  (0) 2022.12.29
딥러닝 : Flatten()라이브러리 없이 이미지를 평탄화 하는 방법과 Validation_data= 파라미터 사용법  (0) 2022.12.29
딥러닝 : Tensorflow의 콜백클래스를 이용해서 원하는 조건이 되면 학습을 멈추게 하기  (0) 2022.12.29
딥러닝 : epochs의 횟수를 늘렸을 때 학습데이터/밸리데이션 데이터와 OverFitting  (0) 2022.12.29
    '인공지능/딥러닝' 카테고리의 다른 글
    • CNN의 convolution,Stride,Padding,Pooling / feature map의 사이즈를 구하는 공식
    • 딥러닝 : 레이블링된 y값을 원핫 인코딩으로 바꾸기 tf.keras.utils.to_categorical() / Mnist 손글씨 숫자 예측
    • 딥러닝 : Flatten()라이브러리 없이 이미지를 평탄화 하는 방법과 Validation_data= 파라미터 사용법
    • 딥러닝 : Tensorflow의 콜백클래스를 이용해서 원하는 조건이 되면 학습을 멈추게 하기
    왕현성
    왕현성
    AI 머신비전 학습일지

    티스토리툴바