过拟合问题实战
1.构建数据集
我们使用的数据集样本特性向量长度为 2,标签为 0 或 1,分别代表了 2 种类别。借助于 scikit-learn 库中提供的 make_moons 工具我们可以生成任意多数据的训练集。
import matplotlib.pyplot as plt # 导入数据集生成工具 import numpy as np import seaborn as sns from sklearn.datasets import make_moons from sklearn.model_selection import train_test_split from tensorflow.keras import layers, Sequential, regularizers from mpl_toolkits.mplot3d import Axes3D
为了演示过拟合现象,我们只采样了 1000 个样本数据,同时添加标准差为 0.25 的高斯噪声数据:
def load_dataset(): # 采样点数 N_SAMPLES = 1000 # 测试数量比率 TEST_SIZE = None # 从 moon 分布中随机采样 1000 个点,并切分为训练集-测试集 X, y = make_moons(n_samples=N_SAMPLES, noise=0.25, random_state=100) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=42) return X, y, X_train, X_test, y_train, y_test
make_plot 函数可以方便地根据样本的坐标 X 和样本的标签 y 绘制出数据的分布图:
def make_plot(X, y, plot_name, file_name, XX=None, YY=None, preds=None, dark=False, output_dir=OUTPUT_DIR): # 绘制数据集的分布, X 为 2D 坐标, y 为数据点的标签 if dark: plt.style.use('dark_background') else: sns.set_style("whitegrid") axes = plt.gca() axes.set_xlim([-2, 3]) axes.set_ylim([-1.5, 2]) axes.set(xlabel="$x_1$", ylabel="$x_2$") plt.title(plot_name, fontsize=20, fontproperties='SimHei') plt.subplots_adjust(left=0.20) plt.subplots_adjust(right=0.80) if XX is not None and YY is not None and preds is not None: plt.contourf(XX, YY, preds.reshape(XX.shape), 25, alpha=0.08, cmap=plt.cm.Spectral) plt.contour(XX, YY, preds.reshape(XX.shape), levels=[.5], cmap="Greys", vmin=0, vmax=.6) # 绘制散点图,根据标签区分颜色m=markers markers = ['o' if i == 1 else 's' for i in y.ravel()] mscatter(X[:, 0], X[:, 1], c=y.ravel(), s=20, cmap=plt.cm.Spectral, edgecolors='none', m=markers, ax=axes) # 保存矢量图 plt.savefig(output_dir + '/' + file_name) plt.close()
def mscatter(x, y, ax=None, m=None, **kw): import matplotlib.markers as mmarkers if not ax: ax = plt.gca() sc = ax.scatter(x, y, **kw) if (m is not None) and (len(m) == len(x)): paths = [] for marker in m: if isinstance(marker, mmarkers.MarkerStyle): marker_obj = marker else: marker_obj = mmarkers.MarkerStyle(marker) path = marker_obj.get_path().transformed( marker_obj.get_transform()) paths.append(path) sc.set_paths(paths) return sc
X, y, X_train, X_test, y_train, y_test = load_dataset() make_plot(X,y,"haha",'月牙形状二分类数据集分布.svg')
2.网络层数的影响
为了探讨不同的网络深度下的过拟合程度,我们共进行了 5 次训练实验。在