реклама
Бургер менюБургер меню

Джеймс Дэвис – Нейросети: создание и оптимизация будущего (страница 27)

18

4. Снижение сложности модели: уменьшить количество фильтров или слоев.

Пример кода для реализации обучения модели, который иллюстрирует переобучение. Здесь используется свёрточная нейронная сеть (CNN) на основе TensorFlow/Keras, обучающаяся на наборе данных CIFAR-10.

Мы намеренно создадим ситуацию переобучения, отключив регуляризацию и используя слишком большую архитектуру для небольшого набора данных.

Код:

```python

import tensorflow as tf

from tensorflow.keras.models import Sequential

from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

from tensorflow.keras.datasets import cifar10

from tensorflow.keras.utils import to_categorical

import matplotlib.pyplot as plt

# Загрузка данных CIFAR-10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Нормализация данных

x_train, x_test = x_train / 255.0, x_test / 255.0

# Кодирование меток в one-hot

y_train, y_test = to_categorical(y_train), to_categorical(y_test)

# Выбор небольшой части данных для обучения

x_train_small = x_train[:2000]

y_train_small = y_train[:2000]

x_val = x_train[2000:2500]

y_val = y_train[2000:2500]

# Определение модели

model = Sequential([

Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),

MaxPooling2D((2, 2)),

Conv2D(64, (3, 3), activation='relu'),

MaxPooling2D((2, 2)),

Conv2D(128, (3, 3), activation='relu'),

Flatten(),

Dense(128, activation='relu'),

Dense(10, activation='softmax')

])

# Компиляция модели

model.compile(optimizer='adam',

loss='categorical_crossentropy',

metrics=['accuracy'])

# Обучение модели

history = model.fit(

x_train_small, y_train_small,

epochs=15,

batch_size=32,

validation_data=(x_val, y_val),

verbose=2

)

# Визуализация результатов

plt.figure(figsize=(12, 5))

# График точности

plt.subplot(1, 2, 1)

plt.plot(history.history['accuracy'], label='Train Accuracy', marker='o')

plt.plot(history.history['val_accuracy'], label='Validation Accuracy', marker='o')

plt.title('Accuracy vs Epochs')

plt.xlabel('Epochs')

plt.ylabel('Accuracy')

plt.legend()

plt.grid(True)

# График потерь

plt.subplot(1, 2, 2)

plt.plot(history.history['loss'], label='Train Loss', marker='o')

plt.plot(history.history['val_loss'], label='Validation Loss', marker='o')

plt.title('Loss vs Epochs')

plt.xlabel('Epochs')

plt.ylabel('Loss')

plt.legend()

plt.grid(True)