왕현성
코딩발자취
왕현성
전체 방문자
오늘
어제
  • 코딩 (277)
    • Python (71)
    • Java (16)
    • MySQL (34)
    • 인공지능 (48)
      • 머신러닝 (16)
      • 딥러닝 (32)
    • 영상처리 (4)
    • Rest API (21)
    • Android Studio (25)
    • streamlit (13)
    • DevOps (22)
      • AWS (9)
      • PuTTY (5)
      • Git (4)
      • Serverless (2)
      • Docker (2)
    • IT 기술 용어 (6)
    • 디버깅 ( 오류 해결 과정 ) (17)

블로그 메뉴

  • 홈
  • 태그
  • 방명록

공지사항

인기 글

태그

  • labelme
  • 의료이미징
  • maskimage
  • imageprocessing
  • alibidetect
  • labelme UnocodeDecodeError
  • yolov8
  • get_long_description
  • matplotlib
  • PIL
  • ComputerVision
  • PYTHON
  • pytorch
  • numpy
  • pip install labelme
  • ckpt_file
  • TensorFlow
  • 비지도학습
  • OpenCV
  • encoding='utf-8'
  • 영상처리
  • 기상탐사
  • unsupervised
  • tune()
  • 딥러닝
  • 컴퓨터비전
  • 영상기술
  • UnboundLocalError
  • 영상처리역사
  • alibi-detection

최근 댓글

최근 글

티스토리

250x250
hELLO · Designed By 정상우.
왕현성

코딩발자취

딥러닝 : 레이블링된 y값을 원핫 인코딩으로 바꾸기 tf.keras.utils.to_categorical() / Mnist 손글씨 숫자 예측
인공지능/딥러닝

딥러닝 : 레이블링된 y값을 원핫 인코딩으로 바꾸기 tf.keras.utils.to_categorical() / Mnist 손글씨 숫자 예측

2022. 12. 29. 17:19
728x90

 

다음과 같은 ANN을 만들 것이다.

 

이미지파일(28X28픽셀)이 입력으로 들어오면, 아웃풋으로는 0~9 까지의 10개 숫자로 분류하는 인공지능!

사진은 2차원 데이터이므로, 우리는 ANN의 입력에, 사진의 픽셀값을 flattening 하여 입력을 줄 것이다.

따라서 입력레이어는 784개, 히든1은 512, 히든2는 512, 아웃풋은 10개의 신경망 구축.

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
%matplotlib inline
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation
from tensorflow.keras.optimizers import Adam

 

우선 사용할 라이브러리들을 import해줍니다.

 

MNist 데이터를 가져온다. 이미 7만장의 손글씨 이미지 데이터가 있다. 

(X_train, y_train), (X_test, y_test) = mnist.load_data()

이후 데이터 모양 확인

실제 이미지 확인

 

# 넘파이의 레이블 인코딩된 값을 원핫 인코딩으로 바꾸는 방법!
# Tensorflow가 제공한다

 

1. 데이터를 딥러닝으로 처리하기 위해서, 행렬로 만들면서, 가로세로 값을 일렬로 만든다.

X_train=X_train.reshape(60000,(28*28))
X_test = X_test.reshape(10000,(28*28))

2. 데이터를 딥러닝에서 처리할 수 있도록 float로 바꿔준다.

 

X_train = X_train.astype(float)
X_test = X_test.astype(float)

3. 이미지라서, 숫자가 0~255 이므로, 0~1 사이로 정규화 시켜주자.

X_train=X_train / 255.0
X_test = X_test / 255.0

4. 분류의 문제이므로, y값을 확인하여, 카테고리컬 데이터를 원핫인코딩값으로 바꾼다.

 

from tensorflow.keras.utils import to_categorical

우선 라이브러리를 임포트 해줍니다.

 

# 넘파이의 레이블 인코딩된 값을 원핫 인코딩으로 바꾸는 방법!
# Tensorflow가 제공한다

아래 코드와 같이 작성 해주면 원핫 인코딩으로 바꿔준다.

y_train=tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test=tf.keras.utils.to_categorical(y_test,num_classes=10)

5. 모델 만들기, 컴파일

 

def build_model():
  model = Sequential()
  model.add( Dense(512, 'relu',input_shape=(784,)) )
  model.add( Dropout(0.4))
  model.add( Dense(10, 'softmax'))
  model.compile('adam', 'categorical_crossentropy', ['accuracy'])
  return model

분류의 문제에서 y값이 원핫 인코딩 되어있는 상황에서는

loss함수를 categorical_crossentropy로 설정을 해준다

 

6. 학습과 모델 평가

 

model = build_model()
epoch_history=model.fit(X_train,y_train,epochs=5,validation_data=(X_test,y_test))
model.evaluate(X_test,y_test)

이후 confusion_matrix를 해주기 위해 

원핫 인코딩 되어있는 값을

y_test2=y_test.argmax(axis=1)
y_pred2=y_pred.argmax(axis=1)

argmax(axis=1)을 이용해

위 사진과 같이 바꿔준다.

 

cm = confusion_matrix(y_test2,y_pred2)
import seaborn as sb
sb.heatmap(cm,annot=True,fmt='.0f',cmap='coolwarm')
plt.show()

seaborn의 heatmap으로 시각화 해주면 아래 사진과 같다.

'인공지능 > 딥러닝' 카테고리의 다른 글

딥러닝 : CNN 말과 사람 분류하기 / CNN모델링 방법 , 이미지파일을 학습 데이터로 만드는 방법(ImageDataGenerator)  (0) 2022.12.30
CNN의 convolution,Stride,Padding,Pooling / feature map의 사이즈를 구하는 공식  (0) 2022.12.29
딥러닝 : Tensorflow의 모델을 저장하고 불러오는 방법  (0) 2022.12.29
딥러닝 : Flatten()라이브러리 없이 이미지를 평탄화 하는 방법과 Validation_data= 파라미터 사용법  (0) 2022.12.29
딥러닝 : Tensorflow의 콜백클래스를 이용해서 원하는 조건이 되면 학습을 멈추게 하기  (0) 2022.12.29
    '인공지능/딥러닝' 카테고리의 다른 글
    • 딥러닝 : CNN 말과 사람 분류하기 / CNN모델링 방법 , 이미지파일을 학습 데이터로 만드는 방법(ImageDataGenerator)
    • CNN의 convolution,Stride,Padding,Pooling / feature map의 사이즈를 구하는 공식
    • 딥러닝 : Tensorflow의 모델을 저장하고 불러오는 방법
    • 딥러닝 : Flatten()라이브러리 없이 이미지를 평탄화 하는 방법과 Validation_data= 파라미터 사용법
    왕현성
    왕현성
    AI 머신비전 학습일지

    티스토리툴바