Joyful Programming

[Deep Learning] Training a CNN on Your Own Data - 5. Training the Neural Network

수수께끼 고양이 (Mystery Cat) 2023. 11. 6. 10:40

 

5. Training the Neural Network

 

_04_cnn_training_4.py

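from _04_cnn_training_3 import *  # provides x_train, y_train, x_valid, y_valid from the previous step

import tensorflow as tf
import matplotlib.pyplot as plt

# Donkey Car-style CNN: stride-2 convolutions shrink the image,
# Dropout(0.2) after each layer reduces overfitting,
# and the final 4-way softmax outputs one probability per class.
model=tf.keras.Sequential([ # donkey car CNN
    tf.keras.layers.Conv2D(24,(5,5), strides=(2,2), padding='same', activation='relu',
                           input_shape=x_train.shape[1:]),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Conv2D(32,(5,5), strides=(2,2), padding='same', activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Conv2D(64,(5,5), strides=(2,2), padding='same', activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Conv2D(64,(3,3), padding='same', activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(100,activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(50,activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(4,activation='softmax')
])
model.summary()  # prints each layer's output shape and parameter count

# categorical_crossentropy expects one-hot labels (one column per class)
model.compile(loss='categorical_crossentropy',
              optimizer='adam', metrics=['accuracy'])

# Train for 50 epochs, evaluating on the validation set after each epoch
history=model.fit(x_train, y_train, epochs=50,
                  validation_data=(x_valid, y_valid))

# Per-epoch loss values recorded by fit()
loss=history.history['loss']
val_loss=history.history['val_loss']

epochs=range(1,len(loss)+1)

# Plot training vs. validation loss to check for overfitting
plt.plot(epochs, loss, 'g', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Save the architecture and trained weights in HDF5 format
model.save('model.h5')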
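To see what `model.summary()` will report, the output shape and parameter count of each layer can be worked out by hand. The sketch below does that arithmetic in plain Python. It assumes a hypothetical 120x160x3 input (the real shape is whatever `x_train.shape[1:]` is in the previous step) and uses the rule that a Keras `Conv2D` with `padding='same'` produces `ceil(input / stride)` along each spatial axis. The Dropout layers are skipped because they change neither the shape nor the parameter count.

```python
import math

# Hypothetical input shape (height, width, channels); the real value is
# x_train.shape[1:] from the previous step, which is not shown here.
INPUT_SHAPE = (120, 160, 3)

# (filters, kernel size, stride) of each Conv2D, in the order used in the model.
CONV_LAYERS = [(24, 5, 2), (32, 5, 2), (64, 5, 2), (64, 3, 1)]
# Units of each Dense layer after Flatten. Dropout layers are omitted:
# they add no parameters and do not change the shape.
DENSE_UNITS = [100, 50, 4]


def conv_same(h, w, c_in, filters, kernel, stride):
    """Output shape and parameter count of a Conv2D with padding='same'."""
    out_h = math.ceil(h / stride)
    out_w = math.ceil(w / stride)
    params = kernel * kernel * c_in * filters + filters  # weights + one bias per filter
    return (out_h, out_w, filters), params


def summarize(input_shape):
    """Return (layer name, output shape, params) for each layer."""
    h, w, c = input_shape
    rows = []
    for filters, kernel, stride in CONV_LAYERS:
        (h, w, c), p = conv_same(h, w, c, filters, kernel, stride)
        rows.append((f"Conv2D {kernel}x{kernel}/{stride}", (h, w, c), p))
    n = h * w * c  # Flatten turns the feature map into one vector
    rows.append(("Flatten", (n,), 0))
    for units in DENSE_UNITS:
        p = n * units + units  # full weight matrix + one bias per unit
        n = units
        rows.append((f"Dense {units}", (n,), p))
    return rows


rows = summarize(INPUT_SHAPE)
for name, shape, params in rows:
    print(f"{name:<16}{str(shape):<16}{params:>10,}")
print("Total params:", f"{sum(p for _, _, p in rows):,}")
```

Under this assumed input, the Flatten output has 15 x 20 x 64 = 19,200 values, so the first Dense(100) layer alone holds 1,920,100 of the 2,034,602 parameters. The convolutional layers are comparatively cheap; most of the weights sit where the feature map is flattened.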

 

 

 

 

 

 

 

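The model ends in a 4-way softmax and is compiled with `loss='categorical_crossentropy'`, which is the value plotted as the training loss. The following NumPy sketch shows roughly what that loss computes for a batch. Keras's own implementation differs in details such as its internal clipping, so treat this as an illustration rather than a re-implementation; the labels are hypothetical. It also shows why the labels must be one-hot: the loss only reads the predicted probability of the true class.

```python
import numpy as np


def softmax(logits):
    # Subtracting the row max keeps exp() from overflowing; it does not change the result.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # y_true: one-hot labels, shape (batch, num_classes)
    # y_pred: softmax probabilities, same shape
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    per_sample = -np.sum(y_true * np.log(y_pred), axis=-1)  # -log(p of true class)
    return per_sample.mean()  # average over the batch


# Hypothetical integer class labels (0-3), converted to one-hot rows.
labels = np.array([0, 2, 3, 1])
y_true = np.eye(4)[labels]

# An untrained-looking model: equal logits give probability 0.25 for every class.
uniform = softmax(np.zeros((4, 4)))
print(categorical_crossentropy(y_true, uniform))  # equals ln(4), about 1.386
```

With four classes, a model that still outputs near-uniform probabilities scores about ln 4 = 1.386, so a training loss curve that stays near that value suggests the network has not yet learned much.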