订阅
纠错
加入自媒体

更复杂的体系结构能保证更好的模型吗?

2022-04-19 14:46
磐创AI
关注

使用的数据集和数据预处理

我们将使用Kaggle的狗与猫数据集。它是根据知识共享许可证授权的,这意味着你可以免费使用它:

该数据集相当大——25000张图像均匀分布在不同的类中(12500张狗图像和12500张猫图像)。它应该足够大,以训练一个像样的图像分类器。

你还应该删除train/cat/666.jpg和train/dog/11702.jpg图像,这些已经损坏,你的模型将无法使用它们进行训练。

接下来,让我们看看如何使用TensorFlow加载图像。

如何使用TensorFlow加载图像数据

今天你将看到的模型将比前几篇文章中的模型具有更多的层。

为了可读性,我们将从TensorFlow中导入单个类。如果你正在跟进,请确保有一个带有GPU的系统,或者至少使用Google Colab。

让我们把库的导入放在一边:

image.png

这是很多,但模型会因此看起来格外干净。

我们现在将像往常一样加载图像数据——使用ImageDataGenerator类。

我们将把图像矩阵转换为0–1范围,使用用三个颜色通道,将所有图像调整为224x224。出于内存方面的考虑,我们将barch大小降低到32:

image.png

以下是你应该看到的输出:

让我们鼓捣第一个模型!

向TensorFlow模型中添加层会有什么不同吗?

从头开始编写卷积模型总是一项棘手的任务。网格搜索最优架构是不可行的,因为卷积模型需要很长时间来训练,而且有太多的参数需要检查。实际上,你更有可能使用迁移学习。这是我们将在不久的将来探讨的主题。

今天,这一切都是为了理解为什么在模型架构上大刀阔斧是不值得的。我们用一个简单的模型获得了75%的准确率,所以这是我们必须超越的基线。

模型1-两个卷积块

我们将宣布第一个模型在某种程度上类似于VGG体系结构——两个卷积层,后面是一个池层。滤波器设置如下,第一个块32个,第二个块64个。

至于损失和优化器,我们将坚持基本原则——分类交叉熵和Adam。数据集中的类是完全平衡的,这意味着我们只需跟踪准确率即可:

model_1 = tf.keras.Sequential([

   Conv2D(filters=32, kernel_size=(3, 3), input_shape=(224, 224, 3), activation='relu'),

   Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Flatten(),

   Dense(units=128, activation='relu'),

   Dense(units=2, activation='softmax')

])

model_1.compile(

   loss=categorical_crossentropy,

   optimizer=Adam(),

   metrics=[BinaryAccuracy(name='accuracy')]

model_1_history = model_1.fit(

   train_data,

   validation_data=valid_data,

   epochs=10

以下是经过10个epoch后的训练结果:

看起来我们的表现并没有超过基线,因为验证准确率仍然在75%左右。如果我们再加上一个卷积块会发生什么?

模型2-三个卷积块

我们将保持模型体系结构相同,唯一的区别是增加了一个包含128个滤波器的卷积块:

model_2 = Sequential([

   Conv2D(filters=32, kernel_size=(3, 3), input_shape=(224, 224, 3), activation='relu'),

   Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=128, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=128, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Flatten(),

   Dense(units=128, activation='relu'),

   Dense(units=2, activation='softmax')

])

model_2.compile(

   loss=categorical_crossentropy,

   optimizer=Adam(),

   metrics=[BinaryAccuracy(name='accuracy')]

model_2_history = model_2.fit(

   train_data,

   validation_data=valid_data,

   epochs=10

日志如下:

效果变差了。虽然你可以随意调整batch大小和学习率,但效果可能仍然不行。第一个架构在我们的数据集上工作得更好,所以让我们试着继续调整一下。

模型3-带Dropout的卷积块

第三个模型的架构与第一个模型相同,唯一的区别是增加了一个全连接层和一个Dropout层。让我们看看这是否会有所不同:

model_3 = tf.keras.Sequential([

   Conv2D(filters=32, kernel_size=(3, 3), input_shape=(224, 224, 3), activation='relu'),

   Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),

   MaxPool2D(pool_size=(2, 2), padding='same'),
   

   Flatten(),

   Dense(units=512, activation='relu'),

   Dropout(rate=0.3),

   Dense(units=128),

   Dense(units=2, activation='softmax')

])

model_3.compile(

   loss=categorical_crossentropy,

   optimizer=Adam(),

   metrics=[BinaryAccuracy(name='accuracy')]

model_3_history = model_3.fit(

   train_data,

   validation_data=valid_data,

   epochs=10

以下是训练日志:

太可怕了,现在还不到70%!上一篇文章中的简单架构非常好。反而是数据质量问题限制了模型的预测能力。

结论

这就证明了,更复杂的模型体系结构并不一定会产生性能更好的模型。也许你可以找到一个更适合猫狗数据集的架构,但这可能是徒劳的。

你应该将重点转移到提高数据集质量上。当然,有20K个训练图像,但我们仍然可以增加多样性。这就是数据增强的用武之地。

感谢阅读!

       原文标题 : 更复杂的体系结构能保证更好的模型吗?

声明: 本文由入驻维科号的作者撰写,观点仅代表作者本人,不代表OFweek立场。如有侵权或其他问题,请联系举报。

发表评论

0条评论,0人参与

请输入评论内容...

请输入评论/评论长度6~500个字

您提交的评论过于频繁,请输入验证码继续

暂无评论

暂无评论

    人工智能 猎头职位 更多
    扫码关注公众号
    OFweek人工智能网
    获取更多精彩内容
    文章纠错
    x
    *文字标题:
    *纠错内容:
    联系邮箱:
    *验 证 码:

    粤公网安备 44030502002758号