model.fit() 方法是 TensorFlow Keras 中用于训练模型的核心方法。
其中里面的callbacks参数是实现模型保存、监控、以及和tensorboard联动的重要API

1 model.fit() 方法的参数及使用

必需参数

  • x: 训练数据的输入。可以是 NumPy 数组、TensorFlow tf.data.Dataset、Python 生成器或 keras.utils.Sequence 实例。
  • y: 训练数据的目标(标签)。与输入 x 相对应,应该是 NumPy 数组或 TensorFlow tf.data.Dataset。当 xtf.data.Dataset、生成器或 Sequence 实例时,y 应该不被提供,因为 x 已经包含了输入和目标。

常用可选参数

  • batch_size: 整数,指定进行梯度更新时每个批次的样本数。默认值为 32。注意,当使用 tf.data.Dataset、生成器或 Sequence 作为输入时,不应指定 batch_size,因为这些数据结构已经定义了批次大小。
  • epochs: 整数,训练模型的轮数,即整个数据集的前向和反向传播次数。
  • verbose: 整数,日志显示模式。0 = 不在标准输出流中输出日志信息,1 = 进度条(默认),2 = 每轮一行。
  • callbacks: keras.callbacks.Callback 实例的列表。一系列在训练过程中会被调用的回调函数,用于查看训练过程中内部状态和统计信息。
  • validation_split: 浮点数,0 到 1 之间,用来指定一定比例的训练数据作为验证数据的比例。模型会在这些数据上评估损失和任何模型指标,但这些数据不会用于训练。
  • validation_data: 用作验证的数据。格式可以是 (X_val, y_val) 的元组,或者是 tf.data.Dataset。如果提供此参数,则不会根据 validation_split 从训练数据中分割验证数据。
  • shuffle: 布尔值或字符串,表示是否在每轮训练前打乱数据。默认为 True。当设置为 False 时,不会打乱数据。当输入为 tf.data.Dataset、生成器或 Sequence 实例时,此参数无效,因为这些数据结构可能已经定义了自己的打乱数据的方式。
  • initial_epoch: 用于恢复之前的训练。从该轮次开始训练,之前的轮次被视为已经训练过。

高级参数

  • steps_per_epoch: 整数,当使用生成器或 Sequence 实例作为输入时定义一个 epoch 完成并开始下一个 epoch 的总步数(批次数)。通常,应该等于数据集的样本数除以批次大小。
  • validation_steps: 当 validation_data 是生成器或 Sequence 实例时,此参数指定在停止前验证集的总步数(批次数)。
  • validation_batch_size: 整数,仅当 validation_data 是 NumPy 数组时有效。指定验证批次的大小。
  • validation_freq: 指定验证的频率。可以是整数,也可以是 'epoch' 或列表。如果是整数,则表示每多少个 epoch 验证一次。如果是列表,则列表中的元素指定了需要进行验证的 epoch。

使用示例

基本用法:

model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2)

使用验证数据:

model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))

使用回调函数:

from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss', patience=3)
model.fit(x_train, y_train, epochs=10, validation_split=0.2, callbacks=[early_stopping])

使用 tf.data.Dataset

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(32)
model.fit(train_dataset, epochs=10, validation_data=val_dataset)

model.fit() 方法提供了灵活的方式来训练模型,通过合理设置参数,可以有效地控制训练过程和评估模型性能。

2 callbacks参数使用

callbacks 参数是 model.fit() 方法中一个重要参数,属于keras的高级用法,它允许在训练的不同阶段(如训练开始、训练结束、每个 epoch 开始/结束时等)执行特定的操作。

callbacks 是一个 tf.keras.callbacks.Callback 实例的列表,每个实例都能够访问到模型的内部状态和统计信息。TensorFlow Keras 提供了多种内置的回调函数,同时也支持自定义回调。
以下是callbacks类的全部方法类(https://keras.io/api/callbacks/):
在这里插入图片描述

  1. ModelCheckpoint: 在训练过程中保存模型或模型权重。

    • filepath: 保存模型的路径。
    • monitor: 被监视的数据。
    • verbose: 详细信息模式。
    • save_best_only: 若为 True,则只保存在验证集上性能最好的模型。
    • save_weights_only: 若为 True,则只保存模型的权重。
    • mode: {auto, min, max} 中的一个。决定监视的数据是应该最大化还是最小化。
    • save_freq: 保存模型的频率。
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath='model.h5', save_best_only=True, monitor='val_loss', mode='min')
    
  2. EarlyStopping: 当被监视的数据不再提升,则停止训练。

    • monitor: 被监视的数据。
    • min_delta: 改进的最小变化量,小于这个量的改进将被忽略。
    • patience: 没有进步的训练轮数,在这之后训练将被停止。
    • verbose: 详细信息模式。
    • mode: {auto, min, max} 中的一个。
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
    
  3. ReduceLROnPlateau: 当学习停滞时,减少学习率。

    • monitor: 被监视的数据。
    • factor: 学习率将以这个因子减少。新的学习率 = 学习率 * 因子。
    • patience: 没有进步的训练轮数,在这之后学习率将被减少。
    • verbose: 详细信息模式。
    • mode: {auto, min, max} 中的一个。
    • min_lr: 学习率的下限。
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
    
  4. TensorBoard: 为 TensorFlow 提供的可视化工具。

    • log_dir: 用来保存日志文件的路径,TensorBoard 将读取这个路径下的日志。
    • histogram_freq: 对于模型层的激活和权重直方图的计算频率(每个 epoch)。
    • write_graph: 是否在 TensorBoard 中可视化图形。如果 write_graph 被打开,日志文件会变得非常大。
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./logs')
    

使用示例

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
    tf.keras.callbacks.ModelCheckpoint(filepath='model.h5', save_best_only=True, monitor='val_loss', mode='min'),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001),
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
]

model.fit(x_train, y_train, validation_split=0.2, epochs=50, callbacks=callbacks)

自定义回调

也可以通过继承 tf.keras.callbacks.Callback 类来创建自定义回调,允许在训练的不同阶段执行自定义的逻辑。

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # 每个 epoch 结束时执行
        keys = list(logs.keys())
        print(f"结束 epoch {epoch},损失 = {logs['loss']}, 验证损失 = {logs['val_loss']}")

model.fit(x_train, y_train, validation_split=0.2, epochs=50, callbacks=[CustomCallback()])

回调提供了一种灵活的方式来嵌入训练过程,使得你可以在不改变模型代码的情况下,监控模型的训练、保存模型、调整学习率等。

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐