Джеймс Дэвис – Нейросети: создание и оптимизация будущего (страница 27)
4. Снижение сложности модели: уменьшить количество фильтров или слоев.
Пример кода для реализации обучения модели, который иллюстрирует переобучение. Здесь используется свёрточная нейронная сеть (CNN) на основе TensorFlow/Keras, обучающаяся на наборе данных CIFAR-10.
Мы намеренно создадим ситуацию переобучения, отключив регуляризацию и используя слишком большую архитектуру для небольшого набора данных.
Код:
```python
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
# Загрузка данных CIFAR-10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Нормализация данных
x_train, x_test = x_train / 255.0, x_test / 255.0
# Кодирование меток в one-hot
y_train, y_test = to_categorical(y_train), to_categorical(y_test)
# Выбор небольшой части данных для обучения
x_train_small = x_train[:2000]
y_train_small = y_train[:2000]
x_val = x_train[2000:2500]
y_val = y_train[2000:2500]
# Определение модели
model = Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
MaxPooling2D((2, 2)),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D((2, 2)),
Conv2D(128, (3, 3), activation='relu'),
Flatten(),
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
# Компиляция модели
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
# Обучение модели
history = model.fit(
x_train_small, y_train_small,
epochs=15,
batch_size=32,
validation_data=(x_val, y_val),
verbose=2
)
# Визуализация результатов
plt.figure(figsize=(12, 5))
# График точности
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Train Accuracy', marker='o')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy', marker='o')
plt.title('Accuracy vs Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
# График потерь
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Train Loss', marker='o')
plt.plot(history.history['val_loss'], label='Validation Loss', marker='o')
plt.title('Loss vs Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)