trainer.py
ultralytics\engine\trainer.py
目录
trainer.py
1.所需的库和模块
2.class BaseTrainer:
1.所需的库和模块
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Train a model on a dataset.
Usage:
$ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
"""
import gc
import math
import os
import subprocess
import time
import warnings
from copy import copy, deepcopy
from datetime import datetime, timedelta
from pathlib import Path
import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (
DEFAULT_CFG,
LOCAL_RANK,
LOGGER,
RANK,
TQDM,
__version__,
callbacks,
clean_url,
colorstr,
emojis,
yaml_save,
)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (
TORCH_2_4,
EarlyStopping,
ModelEMA,
autocast,
convert_optimizer_state_dict_to_fp16,
init_seeds,
one_cycle,
select_device,
strip_optimizer,
torch_distributed_zero_first,
)
2.class BaseTrainer:
# 这段代码定义了一个名为 BaseTrainer 的类,它是用于创建训练器的基础类,封装了训练深度学习模型所需的各种功能和属性。
# 定义了一个名为 BaseTrainer 的类,作为训练器的基础类,用于封装训练模型所需的各种功能。
class BaseTrainer:
# 用于创建训练器的基类。
"""
A base class for creating trainers.
Attributes:
args (SimpleNamespace): Configuration for the trainer.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
save_dir (Path): Directory to save results.
wdir (Path): Directory to save weights.
last (Path): Path to the last checkpoint.
best (Path): Path to the best checkpoint.
save_period (int): Save checkpoint every x epochs (disabled if < 1).
batch_size (int): Batch size for training.
epochs (int): Number of epochs to train for.
start_epoch (int): Starting epoch for training.
device (torch.device): Device to use for training.
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
scaler (amp.GradScaler): Gradient scaler for AMP.
data (str): Path to data.
trainset (torch.utils.data.Dataset): Training dataset.
testset (torch.utils.data.Dataset): Testing dataset.
ema (nn.Module): EMA (Exponential Moving Average) of the model.
resume (bool): Resume training from a checkpoint.
lf (nn.Module): Loss function.
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
best_fitness (float): The best fitness value achieved.
fitness (float): Current fitness value.
loss (float): Current loss value.
tloss (float): Total loss value.
loss_names (list): List of loss names.
csv (Path): Path to results CSV file.
"""
# 这段代码是 BaseTrainer 类的初始化方法 __init__ ,用于设置训练器的基本属性和初始化训练所需的资源。
# 定义了 BaseTrainer 类的初始化方法,接受以下参数 :
# 1.cfg :配置文件路径,默认为 DEFAULT_CFG 。
# 2.overrides :配置覆盖项,用于覆盖默认配置。
# 3._callbacks :回调函数字典,用于自定义训练过程中的行为。
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
# 初始化 BaseTrainer 类。
"""
Initializes the BaseTrainer class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
# 通过 get_cfg 函数加载配置文件,并应用覆盖项,将结果存储在 self.args 中。 self.args 是一个 SimpleNamespace 对象,用于存储训练器的配置参数。
# def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
# -> 用于处理和验证配置信息,最终返回一个配置对象。将最终的配置字典 cfg 转换为 IterableSimpleNamespace 对象并返回。 IterableSimpleNamespace 是一个可迭代的命名空间对象,支持通过点符号访问属性(如 cfg.name ),同时也支持字典操作(如 cfg["name"] )。
# -> return IterableSimpleNamespace(**cfg)
self.args = get_cfg(cfg, overrides)
# 调用 check_resume 方法 检查是否需要从检查点恢复训练 。如果 overrides 中包含恢复训练的配置,则更新 self.args 。
self.check_resume(overrides)
# 根据配置 选择训练设备 (如 GPU 或 CPU),并考虑批量大小对设备选择的影响。 select_device 函数会根据设备类型和批量大小选择最优的设备。
# def select_device(device="", batch=0, newline=False, verbose=True):
# -> 用于根据用户指定的设备字符串(如 "cpu" 、 "cuda" 、 "mps" 等)选择合适的计算设备(CPU、GPU 或 MPS),并进行一系列的检查和配置,以确保设备选择的正确性和兼容性。返回一个 torch.device 对象,表示 最终选择的设备 。
# -> return torch.device(arg)
self.device = select_device(self.args.device, self.args.batch)
# 初始化 验证器 为 None ,验证器用于 在训练过程中评估模型性能 。
self.validator = None
# 初始化 训练过程中的指标 为 None 。
self.metrics = None
# 初始化一个字典 用于存储训练过程中的可视化数据 。
self.plots = {}
# 初始化随机种子 ,确保训练过程的可重复性。 RANK 是分布式训练中的进程编号, deterministic 控制是否启用确定性算法。
# def init_seeds(seed=0, deterministic=False): -> 用于初始化随机种子,确保代码的可重复性(reproducibility)。它同时支持单 GPU 和多 GPU 环境,并提供了确定性(deterministic)和非确定性(non-deterministic)模式的选择。
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
# Dirs
# 通过 get_save_dir 函数获取保存训练结果的目录。
# def get_save_dir(args, name=None): -> 用于根据输入参数 args 和可选参数 name 生成一个保存目录路径( save_dir )。它的主要功能是根据用户提供的参数动态生成保存目录,并确保目录路径的唯一性。返回最终的保存目录路径,确保其类型为 Path 对象。 -> return Path(save_dir)
self.save_dir = get_save_dir(self.args)
# 更新配置中的名称 ,以便日志记录器使用。
self.args.name = self.save_dir.name # update name for loggers
# 定义 保存权重的目录 。
self.wdir = self.save_dir / "weights" # weights dir
# 如果当前进程是主进程( RANK 为 -1 或 0 )。 则创建权重目录并保存配置文件。 yaml_save 函数将配置参数保存为 YAML 文件。
if RANK in {-1, 0}:
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.args.save_dir = str(self.save_dir)
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
# 定义 保存最新检查点 和 最佳检查点的路径 。
self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths
# 设置 保存检查点的时间间隔 。
self.save_period = self.args.save_period
# 设置训练的 批量大小 。
self.batch_size = self.args.batch
# 设置训练的 总轮数 ,默认为 100 轮。如果用户未指定轮数,则默认为 100 轮。
self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training
# 初始化 起始轮数 为 0。
self.start_epoch = 0
# 如果当前是单进程训练,则打印配置参数。
if RANK == -1:
# def print_args(args: Optional[dict] = None, show_file=True, show_func=False): -> 用于打印函数的参数及其值。它支持可选的参数字典,并可以选择性地显示文件名和函数名。
print_args(vars(self.args))
# Device
# 如果设备是 CPU 或 MPS,则 将数据加载器的工作进程数 设置为 0,以加快 CPU 训练速度。
if self.device.type in {"cpu", "mps"}:
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataset
# 加载模型文件,并根据需要添加文件后缀(如 .pt )。 check_model_file_from_stem 函数会检查模型文件是否存在,并返回完整的路径。
# def check_model_file_from_stem(model="yolo11n"):
# -> 用于根据模型的名称(或“stem”)返回完整的模型文件名。如果条件满足,则使用 Path(model).with_suffix(".pt") 为模型名称添加 .pt 扩展名,并返回完整的文件路径。例如,输入 "yolo11n" 会返回 Path("yolo11n.pt") 。如果条件不满足,则直接返回原始的 model 输入。
# -> return Path(model).with_suffix(".pt") / return model
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolo11n -> yolo11n.pt
# 在分布式训练中, 确保数据集只被下载一次 。 torch_distributed_zero_first 是一个上下文管理器,用于在分布式训练中同步进程。
# def torch_distributed_zero_first(local_rank: int): -> 用于在分布式训练中确保所有进程等待本地主进程(rank 0)完成某个任务后再继续执行。这种机制常用于在分布式训练中同步操作,例如在主进程完成数据加载或模型初始化后再让其他进程继续执行。
with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times
self.trainset, self.testset = self.get_dataset()
# 初始化 指数移动平均 (EMA)模型为 None 。
self.ema = None
# Optimization utils init
# 初始化 学习率调度器的调度函数 为 None 。
self.lf = None
# 初始化 学习率调度器 为 None 。
self.scheduler = None
# Epoch level metrics
# 初始化 best_fitness 属性为 None 。 best_fitness 用于记录训练过程中模型达到的 最佳性能指标 。性能指标可以是准确率、mAP(平均精度)、损失值等,具体取决于任务类型。在训练过程中, best_fitness 会被更新为模型在验证集上表现最好的那个指标值。
self.best_fitness = None
# 初始化 fitness 属性为 None 。 fitness 用于记录当前轮次(epoch)模型的 性能指标 。与 best_fitness 不同, fitness 是动态变化的,反映了模型在当前轮次的表现。它通常用于判断模型是否在训练过程中有所改进。
self.fitness = None
# 初始化 loss 属性为 None 。 loss 用于记录 当前轮次的损失值 。损失值是模型训练过程中的关键指标,反映了模型预测值与真实值之间的差异。较低的损失值通常意味着模型表现更好。
self.loss = None
# 初始化 tloss 属性为 None 。 tloss 代表 总损失值 ,可能是一个累积值或平均值,用于记录整个训练过程中的损失变化。它通常用于绘制损失曲线,帮助分析模型的收敛情况。
self.tloss = None
# 初始化 loss_names 属性为一个列表,包含一个默认的损失名称 "Loss" 。 loss_names 用于 记录不同损失项的名称 ,例如在多任务学习或复杂模型中,可能会有多个损失项(如分类损失、回归损失等)。这个列表在后续的训练过程中可能会被更新,以反映实际的损失项名称。
self.loss_names = ["Loss"]
# 初始化 csv 属性为 保存训练结果的 CSV 文件路径 。该路径是通过 self.save_dir (保存训练结果的目录)和文件名 "results.csv" 拼接而成的。训练过程中的关键指标(如损失值、性能指标等)会被记录到这个 CSV 文件中,便于后续分析和可视化。
self.csv = self.save_dir / "results.csv"
# 初始化 plot_idx 属性为一个列表 [0, 1, 2] 。 plot_idx 用于 记录在训练过程中需要绘制的特定批次(batch)的索引 。这些索引通常用于可视化训练样本的预测结果,帮助开发者直观地了解模型的行为。例如,在某些特定的批次(如第一个、中间一个和最后一个)绘制预测结果,以便观察模型在不同阶段的表现。
self.plot_idx = [0, 1, 2]
# HUB
# 初始化 HUB 会话为 None 。
self.hub_session = None
# Callbacks
# 初始化 回调函数字典 ,如果未提供,则使用默认回调函数。
self.callbacks = _callbacks or callbacks.get_default_callbacks()
# 如果当前是主进程,则添加集成回调函数。
if RANK in {-1, 0}:
callbacks.add_integration_callbacks(self)
# 这段代码的初始化方法 __init__ 是 BaseTrainer 类的核心,负责设置训练器的基本属性和初始化训练所需的资源。它包括以下功能。加载配置文件:通过 get_cfg 函数加载配置文件,并应用覆盖项。选择设备:根据配置选择训练设备(如 GPU 或 CPU)。初始化目录和文件:创建保存训练结果的目录,并保存配置文件。加载模型和数据集:加载模型文件,并获取训练和测试数据集。初始化优化器和调度器:初始化学习率调度器和优化器。设置回调函数:初始化回调函数字典,并添加集成回调函数。设置随机种子:确保训练过程的可重复性。这些初始化步骤为后续的训练过程(如 _do_train 方法)提供了必要的基础。
# 这段代码定义了一个名为 add_callback 的方法,用于向指定事件添加回调函数。
# 定义了 add_callback 方法,它接受两个参数。
# 1.event (类型为 str ) :表示事件的名称,例如 "on_train_start" 、 "on_epoch_end" 等。
# 2.callback :一个可调用的函数,会在指定事件发生时被触发。
def add_callback(self, event: str, callback):
# 附加给定的回调。
"""Appends the given callback."""
# self.callbacks 是一个字典,键为 事件名称 ,值为 回调函数列表 。 通过 self.callbacks[event] 访问指定事件的回调函数列表,并使用 append 方法将新的回调函数 callback 添加到列表中。
self.callbacks[event].append(callback)
# add_callback 方法的主要功能是。动态添加回调函数:允许用户在运行时为特定事件添加自定义的回调函数。支持多个回调函数:同一个事件可以绑定多个回调函数,这些函数会被依次触发。灵活性和扩展性:通过回调机制,用户可以在不修改核心代码的情况下,扩展训练器的功能,例如:在训练开始时记录日志。在每个轮次结束时保存模型。在训练过程中动态调整学习率。
# 使用场景示例 :
# 假设正在使用 BaseTrainer 类进行模型训练,并希望在每个轮次结束时打印当前的损失值。可以定义一个回调函数并将其添加到 "on_epoch_end" 事件中 :
# def print_loss(trainer):
# print(f"Epoch {trainer.epoch}: Loss = {trainer.loss}")
# # 添加回调函数
# trainer.add_callback("on_epoch_end", print_loss)
# 在训练过程中,每当 "on_epoch_end" 事件被触发时, print_loss 函数就会被调用,并打印当前轮次的损失值。这种设计使得训练器更加灵活,用户可以根据自己的需求添加各种自定义行为,而无需修改训练器的核心代码。
# 这段代码定义了一个名为 set_callback 的方法,用于为指定事件设置(或覆盖)回调函数。与 add_callback 方法不同, set_callback 会覆盖指定事件的所有现有回调函数,而不是将新的回调函数添加到现有列表中。
# 定义了 set_callback 方法,它接受两个参数。
# 1.event (类型为 str ) :表示事件的名称,例如 "on_train_start" 、 "on_epoch_end" 等。
# 2.callback :一个可调用的函数,会在指定事件发生时被触发。
def set_callback(self, event: str, callback):
# 用给定的回调覆盖现有的回调。
"""Overrides the existing callbacks with the given callback."""
# self.callbacks 是一个字典,键为事件名称,值为回调函数列表。 通过 self.callbacks[event] 访问指定事件的回调函数列表,并将其设置为只包含一个回调函数 [callback] 。 这一步会覆盖掉之前为该事件设置的所有回调函数。
self.callbacks[event] = [callback]
# set_callback 方法的主要功能是。设置新的回调函数:为指定事件设置一个新的回调函数。覆盖现有回调函数:如果指定事件已经存在回调函数, set_callback 会覆盖它们,只保留新设置的回调函数。简化回调管理:当用户只想为某个事件设置一个回调函数时, set_callback 提供了一种更直接的方式,避免了手动清空现有回调函数的复杂性。通过结合使用 add_callback 和 set_callback ,用户可以灵活地管理训练过程中的回调函数,满足不同的需求。
# 这段代码定义了一个名为 run_callbacks 的方法,用于在指定事件发生时触发所有与该事件关联的回调函数。
# 定义了 run_callbacks 方法,它接受一个参数。
# 1.event (类型为 str ) :表示事件的名称,例如 "on_train_start" 、 "on_epoch_end" 等。
def run_callbacks(self, event: str):
# 运行与特定事件相关的所有现有回调。
"""Run all existing callbacks associated with a particular event."""
# self.callbacks 是一个字典,键为事件名称,值为回调函数列表。 使用 self.callbacks.get(event, []) 获取与指定事件关联的回调函数列表。如果没有找到该事件,则返回一个空列表 [] 。 遍历回调函数列表中的每一个回调函数 callback 。
for callback in self.callbacks.get(event, []):
# 调用回调函数,并将 self (即当前的训练器实例)作为参数传递给回调函数。这使得回调函数可以访问训练器的属性和方法,例如当前的轮次、损失值、模型等。
callback(self)
# run_callbacks 方法的主要功能是。触发回调函数:在指定事件发生时,触发所有与该事件关联的回调函数。提供灵活性:通过回调机制,用户可以在不修改核心代码的情况下,为训练过程添加自定义行为。确保安全调用:即使某个事件没有关联任何回调函数, run_callbacks 方法也不会报错,而是安全地跳过。
# 使用场景示例假 :
# 设已经为某个事件(如 "on_epoch_end" )添加了多个回调函数,例如保存模型、打印日志和调整学习率。在训练过程中,每当轮次结束时, run_callbacks 方法会被调用,触发所有这些回调函数 :
# # 定义回调函数
# def save_model(trainer):
# print(f"Saving model at epoch {trainer.epoch}")
# trainer.save_model()
# def print_loss(trainer):
# print(f"Epoch {trainer.epoch}: Loss = {trainer.loss}")
# # 添加回调函数
# trainer.add_callback("on_epoch_end", save_model)
# trainer.add_callback("on_epoch_end", print_loss)
# # 在训练过程中触发回调
# trainer.run_callbacks("on_epoch_end")
# 输出示例 :
# Saving model at epoch 1
# Epoch 1: Loss = 0.5
# 与 add_callback 和 set_callback 的关系 :
# add_callback :用于向指定事件添加新的回调函数。
# set_callback :用于覆盖指定事件的所有回调函数,只保留一个回调函数。
# run_callbacks :在指定事件发生时,触发所有与该事件关联的回调函数。
# 这三个方法共同构成了一个完整的回调机制,使得训练器能够在不同阶段执行用户定义的自定义逻辑。
# 这段代码定义了 train 方法,它是 BaseTrainer 类的核心方法,用于启动训练过程。它根据配置自动选择单机训练或多GPU分布式训练(DDP),并处理相关的逻辑。
# 定义了 train 方法,用于启动训练过程。
def train(self):
# 允许多 GPU 系统上的 device=''、device=None 默认为 device=0。
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
# 这段代码的作用是根据 self.args.device 的值来确定 训练时使用的设备数量 ( world_size )。 world_size 是分布式训练中的一个重要参数,它表示参与训练的设备(通常是 GPU)的数量。
# 如果 self.args.device 是一个非空字符串(例如 '0' 或 '0,1,2,3' ),表示设备编号以逗号分隔。
# 使用 split(",") 将字符串按逗号分割成一个列表,然后计算列表的长度,得到设备数量( world_size )。
# 示例 :如果 self.args.device = '0,1,2' ,则 world_size = 3 。
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
world_size = len(self.args.device.split(","))
# 如果 self.args.device 是一个元组或列表(例如 [0, 1, 2, 3] ),表示设备编号以列表形式给出。
# 直接计算列表或元组的长度,得到设备数量( world_size )。
# 示例 :如果 self.args.device = [0, 1, 2] ,则 world_size = 3 。
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
world_size = len(self.args.device)
# 如果 self.args.device 是 'cpu' 或 'mps' ,表示使用 CPU 或 Apple Metal Performance Shaders(MPS)进行训练。 在这种情况下, world_size 被设置为 0,表示不使用分布式训练。
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
world_size = 0
# 如果 self.args.device 未指定或为空(例如 device=None 或 device='' ),并且 CUDA 可用。 默认使用单个 GPU( world_size = 1 ),并假设使用设备 0。 这种情况下通常用于单 GPU 训练。
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
world_size = 1 # default to device 0
# 如果 self.args.device 未指定或为空,且 CUDA 不可用。 设置 world_size = 0 ,表示不使用 GPU,通常用于 CPU 训练。
else: # i.e. device=None or device=''
world_size = 0
# 这段代码通过检查 self.args.device 的值,确定训练时使用的设备数量( world_size )。它支持以下几种设备配置。多 GPU 分布式训练:通过字符串(如 '0,1,2,3' )或列表(如 [0, 1, 2, 3] )指定多个 GPU。world_size 为 GPU 数量。单 GPU 训练:如果未指定设备,但 CUDA 可用,默认使用单个 GPU( world_size = 1 )。CPU 或 MPS 训练:如果使用 'cpu' 或 'mps' , world_size 设置为 0,表示不使用分布式训练。这种设计使得代码能够灵活地处理不同的设备配置,并为后续的训练逻辑(如单机训练或分布式训练)提供必要的信息。
# 这段代码的作用是检查是否需要启动分布式数据并行(DDP)训练,并在启动之前对某些不兼容的参数进行调整。
# Run subprocess if DDP training, else train normally
# 检查是否需要启动分布式数据并行(DDP)训练。
# world_size > 1 :表示有多个设备(GPU)参与训练。
# "LOCAL_RANK" not in os.environ :表示当前环境尚未启动 DDP 训练( LOCAL_RANK 是 DDP 环境中用于标识本地进程的环境变量)。
# 如果同时满足这两个条件,则需要启动 DDP 训练。
if world_size > 1 and "LOCAL_RANK" not in os.environ:
# Argument checks
# 检查 self.args.rect 参数。
if self.args.rect:
# 如果 self.args.rect 为 True ,表示启用了矩形训练( rect=True ),但矩形训练与多 GPU 训练不兼容。 发出警告,并将 self.args.rect 设置为 False ,以确保训练可以正常进行。
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'") # 警告⚠️'rect=True' 与多 GPU 训练不兼容,请设置'rect=False'。
self.args.rect = False
# 检查 self.args.batch 参数。
if self.args.batch < 1.0:
# 如果 self.args.batch 小于 1.0,表示启用了自动批量大小调整(AutoBatch),但此功能与多 GPU 训练不兼容。 发出警告,并将批量大小设置为默认值 16。
LOGGER.warning(
"WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting " # 警告⚠️ AutoBatch 的“batch<1”与多 GPU 训练不兼容,设置默认“batch=16”。
"default 'batch=16'"
)
self.args.batch = 16
# 这段代码的主要功能是。检测是否需要启动 DDP 训练:如果 world_size > 1 且 LOCAL_RANK 环境变量不存在,则需要启动 DDP 训练。调整不兼容的参数:如果启用了矩形训练( rect=True ),但多 GPU 训练不支持此功能,将 rect 设置为 False 。如果启用了自动批量大小调整( batch<1 ),但多 GPU 训练不支持此功能,将批量大小设置为默认值 16。确保训练的兼容性:在启动 DDP 训练之前,通过调整参数确保训练过程不会因不兼容的设置而失败。
#
# 为什么需要这些检查?
# 在分布式训练中,某些功能可能与多 GPU 训练不兼容。例如 :
# 矩形训练( rect=True ) :通常用于单 GPU 训练,通过调整图像尺寸以减少填充,从而提高训练效率。但在多 GPU 训练中,这种调整可能会导致数据不一致。
# 自动批量大小调整( batch<1 ) :通常用于单 GPU 训练,通过动态调整批量大小以适应内存限制。但在多 GPU 训练中,批量大小通常需要手动设置,以确保所有 GPU 的负载均衡。
# 通过在启动 DDP 训练之前调整这些参数,可以避免潜在的错误,并确保训练过程的顺利进行。
# 这段代码是 train 方法的一部分,用于处理分布式数据并行(DDP)训练的启动逻辑。如果需要进行 DDP 训练,则通过子进程启动训练;否则,直接调用 _do_train 方法进行单机训练。
# Command
# 调用 generate_ddp_command 函数,生成用于启动 DDP 训练的命令 ( cmd )和 临时文件路径 ( file )。 world_size 是参与训练的设备数量(GPU 数量)。 self 是当前训练器实例,包含训练所需的配置和状态。
# def generate_ddp_command(world_size, trainer): -> 用于生成分布式训练所需的命令,并返回该命令及其对应的临时文件路径。返回生成的命令列表 cmd 和临时文件路径 file 。 -> return cmd, file
cmd, file = generate_ddp_command(world_size, self)
# 在尝试运行 DDP 训练之前,记录调试信息,打印生成的命令。 colorstr('DDP:') 用于在日志中添加颜色标记,突出显示 DDP 相关的日志。 {' '.join(cmd)} 将命令列表拼接成一个字符串,便于查看。
try:
LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}") # {colorstr('DDP:')} 调试命令 {' '.join(cmd)} 。
# 使用 subprocess.run 执行生成的 DDP 命令。 check=True 表示如果命令执行失败(返回非零退出码),会抛出异常。
subprocess.run(cmd, check=True)
# 捕获并重新抛出异常,确保错误能够被上层捕获和处理。 这一步确保了在 DDP 训练启动失败时,能够明确地报告错误原因。
except Exception as e:
raise e
# 无论训练是否成功,都会执行 ddp_cleanup 函数,清理临时文件和其他资源。 str(file) 是 临时文件的路径 ,用于在训练完成后进行清理。
finally:
# def ddp_cleanup(trainer, file): -> 用于在分布式数据并行(DDP)训练完成后清理临时文件。
ddp_cleanup(self, str(file))
# 如果不需要启动 DDP 训练( world_size <= 1 或 LOCAL_RANK 已设置),直接调用 _do_train 方法进行单机训练。 _do_train 方法是实际执行训练逻辑的核心函数。
else:
self._do_train(world_size)
# 这段代码的主要功能是。生成 DDP 训练命令:使用 generate_ddp_command 函数生成启动 DDP 训练所需的命令和临时文件路径。启动 DDP 训练:使用 subprocess.run 在子进程中执行生成的命令。如果命令执行失败,捕获并重新抛出异常,确保错误能够被明确报告。清理资源:在训练完成后,无论成功与否,都会调用 ddp_cleanup 函数清理临时文件和其他资源。单机训练逻辑:如果不需要 DDP 训练,则直接调用 _do_train 方法进行单机训练。
# 为什么需要通过子进程启动 DDP 训练?
# 在分布式训练中,每个进程需要独立运行,并且可能需要在不同的 GPU 上执行相同的代码。通过子进程启动训练,可以确保每个进程都正确初始化 PyTorch 的分布式环境(如 torch.distributed.init_process_group ),并正确分配到指定的 GPU 上。这种方法可以避免在主进程中直接启动分布式训练时可能出现的资源冲突或初始化问题。
# train 方法的主要功能是。检测训练环境:根据 self.args.device 的值,判断是单机训练还是多GPU分布式训练。检测 CUDA 是否可用,并设置设备数量( world_size )。处理分布式训练的特殊情况:如果启用多GPU训练,检查是否需要调整某些不兼容的参数(如 rect 和 batch )。生成分布式训练的命令,并通过子进程启动训练。执行训练逻辑:如果是分布式训练,通过子进程启动训练。如果是单机训练或分布式训练的子进程,调用 _do_train 方法执行实际的训练逻辑。异常处理和清理:捕获分布式训练过程中可能发生的异常。在训练完成后执行清理操作,确保资源被正确释放。这种设计使得 train 方法能够灵活地支持单机和分布式训练场景,同时确保训练过程的健壮性和可扩展性。
# 这段代码定义了 _setup_scheduler 方法,用于初始化学习率调度器(Learning Rate Scheduler)。学习率调度器的作用是在训练过程中动态调整学习率,以优化训练效果和收敛速度。
# 定义了 _setup_scheduler 方法,用于初始化学习率调度器。
def _setup_scheduler(self):
# 初始化训练学习率调度程序。
"""Initialize training learning rate scheduler."""
# 如果 self.args.cos_lr 为 True ,表示使用 余弦退火学习率调度器 (Cosine Annealing)。
# 调用 one_cycle 函数生成 学习率调度函数 self.lf 。
# 1 表示初始学习率的倍率(通常为 1)。
# self.args.lrf 表示最终学习率的倍率( lrf 是学习率衰减因子)。
# self.epochs 表示总训练轮数。
# 余弦退火调度器会根据余弦函数动态调整学习率,从初始值逐渐降低到目标值。
if self.args.cos_lr:
# def one_cycle(y1=0.0, y2=1.0, steps=100):
# -> 用于生成一个周期性的调度函数,该函数在给定的步数内从 y1 增加到 y2 ,然后在剩余的步数内从 y2 减少回 y1 。这种调度常用于调整学习率或其他超参数。返回一个 lambda 函数,该函数接受一个参数 x ,表示当前的步数。
# -> return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
# 如果不使用余弦退火调度器,则使用 线性退火学习率调度器 。
# 定义一个匿名函数 lambda 作为学习率调度函数 self.lf 。
# x 表示当前轮次。
# max(1 - x / self.epochs, 0) :计算线性衰减的倍率,确保值不小于 0。
# (1.0 - self.args.lrf) :初始学习率的调整范围。
# + self.args.lrf :最终学习率的下限。
# 线性退火调度器会线性地降低学习率,从初始值逐渐降低到目标值。
else:
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
# 使用 PyTorch 的 LambdaLR 学习率调度器,将定义好的 学习率调度函数 self.lf 应用到 优化器 self.optimizer 上。 lr_lambda 参数接收一个函数,该函数根据当前轮次动态调整学习率。
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
# _setup_scheduler 方法的主要功能是。选择学习率调度策略:如果 self.args.cos_lr 为 True ,使用余弦退火调度器。否则,使用线性退火调度器。定义学习率调度函数:余弦退火调度器使用 one_cycle 函数生成调度函数。线性退火调度器使用 lambda 函数定义调度逻辑。初始化学习率调度器:使用 PyTorch 的 LambdaLR 调度器,将调度函数应用到优化器上。
# 示例 :
# 假设 :
# self.args.cos_lr = True (使用余弦退火调度器)。
# self.args.lrf = 0.1 (最终学习率倍率为 0.1)。
# self.epochs = 100 (总训练轮数为 100)。
# 则 :
# self.lf 会被设置为一个余弦退火调度函数。
# 在训练过程中,学习率会从初始值逐渐降低到初始值的 10%。
# 如果 :
# self.args.cos_lr = False (使用线性退火调度器)。则 :
# self.lf 会被设置为一个线性退火调度函数。
# 在训练过程中,学习率会线性地从初始值降低到初始值的 10%。
# 为什么需要学习率调度器?
# 学习率是训练过程中的一个重要超参数,它决定了模型参数更新的步长。合适的学习率可以加速训练过程并提高模型性能。然而,固定的学习率可能在训练初期表现良好,但在训练后期可能导致收敛缓慢或过拟合。通过动态调整学习率,可以在训练初期快速收敛,并在训练后期精细调整模型参数,从而提高训练效果和模型性能。
# 这段代码定义了 _setup_ddp 方法,用于初始化 PyTorch 的分布式数据并行(DDP)环境。DDP 是一种用于多 GPU 训练的技术,允许模型在多个 GPU 上并行计算,同时保持模型参数同步。
# 定义了 _setup_ddp 方法,它接受一个参数。
# 1.world_size :表示参与训练的设备(GPU)总数。
def _setup_ddp(self, world_size):
# 初始化并设置用于训练的 DistributedDataParallel 参数。
"""Initializes and sets the DistributedDataParallel parameters for training."""
# 设置当前进程使用的 GPU 设备。
# RANK 是当前进程的全局排名(从 0 开始),表示当前进程使用的 GPU 编号。
# torch.cuda.set_device 确保当前进程只使用指定的 GPU,避免多进程之间的 GPU 冲突。
torch.cuda.set_device(RANK)
# 创建一个 PyTorch 设备对象,表示当前进程使用的 GPU。 torch.device("cuda", RANK) 指定使用编号为 RANK 的 GPU。
self.device = torch.device("cuda", RANK)
# 打印 DDP 相关信息,包括 当前进程的排名 ( RANK )、 设备总数 ( world_size )和 当前设备 ( self.device )。 这行代码被注释掉了,但可以在调试时取消注释以获取更多信息。
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
# 设置环境变量 TORCH_NCCL_BLOCKING_WAIT 为 "1" 。 这会启用 NCCL 的阻塞等待模式,确保在通信操作中强制执行超时机制。 这对于分布式训练中的同步操作非常重要,可以防止某些进程卡住或导致死锁。
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
# 初始化 PyTorch 的 分布式进程组 ( torch.distributed )。
dist.init_process_group(
# 指定 后端通信方式 。如果 NCCL 可用,则使用 "nccl" (适用于 GPU 通信);否则使用 "gloo" 。
backend="nccl" if dist.is_nccl_available() else "gloo",
# 设置 通信超时时间 ,这里设置为 3 小时( timedelta(seconds=10800) )。
timeout=timedelta(seconds=10800), # 3 hours
# 当前进程的 全局排名 ( RANK )。
rank=RANK,
# 参与训练的 设备总数 ( world_size )。
world_size=world_size,
)
# _setup_ddp 方法的主要功能是。设置当前进程的 GPU 设备:使用 torch.cuda.set_device(RANK) 确保当前进程只使用指定的 GPU。创建一个设备对象 self.device ,表示当前进程使用的 GPU。启用 NCCL 的阻塞等待模式:设置环境变量 TORCH_NCCL_BLOCKING_WAIT 为 "1" ,以确保通信操作中的超时机制。初始化分布式进程组:使用 torch.distributed.init_process_group 初始化分布式环境。配置通信后端( nccl 或 gloo )、超时时间、当前进程排名和设备总数。
# 在分布式训练中,多个进程需要协同工作,每个进程负责一部分计算任务。为了确保这些进程能够正确同步和通信,需要进行以下初始化 :
# 设置设备 :确保每个进程只使用一个 GPU,避免资源冲突。
# 启用超时机制 :防止通信操作中的死锁或卡住,确保训练过程的稳定性。
# 初始化进程组 :配置分布式环境,使进程能够正确同步和通信。
# 这段代码定义了 _setup_train 方法,用于在训练开始前完成一系列初始化工作,包括模型设置、层冻结、AMP(自动混合精度)配置、数据加载器创建、优化器和学习率调度器初始化等。
# 定义了 _setup_train 方法,它接受一个参数。
# 1.world_size :表示参与训练的设备(GPU)总数。
def _setup_train(self, world_size):
# 在正确的排名过程中构建数据加载器和优化器。
"""Builds dataloaders and optimizer on correct rank process."""
# Model
# 在训练开始前运行回调函数,允许用户在模型初始化之前执行自定义逻辑。
self.run_callbacks("on_pretrain_routine_start")
# 调用 setup_model 方法加载或创建模型。 ckpt 是加载的检查点(如果有)。
ckpt = self.setup_model()
# 将模型移动到指定的设备(如 GPU 或 CPU)。
self.model = self.model.to(self.device)
# 调用 set_model_attributes 方法设置模型的属性,例如类别名称、模型步幅等。
self.set_model_attributes()
# 这段代码的功能是冻结模型中指定的层,以减少训练时的计算量或保留某些预训练权重。
# Freeze layers
# 根据 self.args.freeze 的类型,确定需要冻结的层索引列表。
# 如果 self.args.freeze 是一个列表,直接使用该列表。
# 如果是一个整数,生成一个从 0 到该整数的范围( range(self.args.freeze) ),表示冻结前 n 层。
# 如果既不是列表也不是整数,默认为空列表(表示不冻结任何层)。
freeze_list = (
self.args.freeze
if isinstance(self.args.freeze, list)
else range(self.args.freeze)
if isinstance(self.args.freeze, int)
else []
)
# 定义一些始终需要冻结的层名称(例如 .dfl )。 这些层在任何情况下都不会参与训练 。
always_freeze_names = [".dfl"] # always freeze these layers
# 生成 需要冻结的层名称列表 。
# 对于 freeze_list 中的每个索引 x ,生成 对应的层名称 (如 model.0. 、 model.1. 等)。 将这些层名称与 always_freeze_names 合并,形成 完整的冻结层名称列表 。
freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
# 遍历模型的所有参数( named_parameters ),其中 k 是 参数名称 , v 是 参数张量 。
for k, v in self.model.named_parameters():
# 这行代码被注释掉了,它的作用是为每个参数注册一个钩子函数,将 NaN 值替换为 0。但由于可能导致训练结果不稳定,因此被注释掉了。
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
# 检查当前参数名称 k 是否包含在冻结层名称列表中。
if any(x in k for x in freeze_layer_names):
# 如果包含,则冻结该层(通过设置 v.requires_grad = False )。 同时记录一条日志,说明该层已被冻结。
LOGGER.info(f"Freezing layer '{k}'") # 冻结层‘{k}’。
v.requires_grad = False
# 如果当前参数已经被冻结( v.requires_grad = False ),但需要计算梯度(例如在某些自定义情况下)。
elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients
# 发出警告,说明正在重新启用该层的梯度计算。
LOGGER.info(
f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " # 警告⚠️为冻结层“{k}”设置“requires_grad=True”。
"See ultralytics.engine.trainer for customization of frozen layers." # 请参阅 ultralytics.engine.trainer 以了解冻结层的定制。
)
# 重新设置 v.requires_grad = True ,允许该层参与训练。
v.requires_grad = True
# 这段代码的主要功能是。生成冻结层列表:根据 self.args.freeze 的值生成需要冻结的层列表。包括用户指定的层和始终需要冻结的层(如 .dfl )。遍历模型参数:遍历模型的所有参数,检查参数名称是否在冻结层列表中。冻结特定层:如果参数名称匹配冻结层列表,将其 requires_grad 属性设置为 False ,避免在训练过程中更新这些参数。处理特殊情况:如果参数已经被冻结,但需要计算梯度(例如在某些自定义情况下),发出警告并重新启用梯度计算。
# 冻结层在训练过程中有以下用途 :
# 减少计算量 :冻结某些层可以减少训练过程中的计算量,加快训练速度。
# 固定预训练权重 :在微调预训练模型时,冻结某些层可以保留预训练的特征提取能力,只训练模型的顶部层(如分类头)。
# 避免过拟合 :冻结某些层可以减少模型的可训练参数数量,从而降低过拟合的风险。
# 这段代码定义了检查和配置自动混合精度(AMP)的逻辑,并根据训练环境(单 GPU 或分布式数据并行,DDP)进行相应的初始化。
# Check AMP
# 将 self.args.amp (表示是否启用 AMP 的布尔值)转换为一个张量,并移动到指定设备(如 GPU 或 CPU)。这一步确保 self.amp 是一个设备上的布尔张量。
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
# 如果启用了 AMP,并且当前是单 GPU 训练或分布式训练的主进程( RANK 为 -1 或 0 )。 备份默认的回调函数,因为 check_amp 函数可能会重置这些回调。
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
# 调用 check_amp 函数检查模型是否支持 AMP,并将结果(布尔值)转换为一个设备上的张量。 check_amp 函数通常会检查模型的兼容性,并返回是否可以启用 AMP。
# def check_amp(model):
# -> 用于检查当前环境是否支持自动混合精度(AMP)训练。它通过一系列检查确保在使用 AMP 时不会出现数值不稳定(如 NaN 损失)或性能下降(如零 mAP 结果)的问题。返回 False ,表示不支持 AMP。如果所有检查通过,则返回 True ,表示支持 AMP。
# -> return False / return True
self.amp = torch.tensor(check_amp(self.model), device=self.device)
# 恢复之前备份的默认回调函数。
callbacks.default_callbacks = callbacks_backup # restore callbacks
# 如果当前是分布式训练( RANK > -1 且 world_size > 1 )。
if RANK > -1 and world_size > 1: # DDP
# 使用 dist.broadcast 将主进程( rank=0 )的 self.amp 张量广播到所有其他进程。这确保所有进程的 AMP 配置一致。
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
# 将 self.amp 转换为布尔值,以便后续逻辑使用。
self.amp = bool(self.amp) # as boolean
# 根据 PyTorch 的版本,初始化 梯度缩放器 ( GradScaler )。
self.scaler = (
# 如果 PyTorch 版本为 2.4 或更高,使用 torch.amp.GradScaler 。 否则,使用 torch.cuda.amp.GradScaler 。 enabled=self.amp 确保只有在启用 AMP 时才激活梯度缩放器。
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
)
# 如果当前是分布式训练( world_size > 1 )。
if world_size > 1:
# 将模型包装为 DistributedDataParallel (DDP),以便在多个 GPU 上并行训练。 device_ids=[RANK] 指定当前进程使用的 GPU。 find_unused_parameters=True 确保 DDP 能够检测到未使用的参数,这对于某些动态模型(如多任务模型)很重要。
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
# 这段代码的主要功能是。检查 AMP 配置:根据 self.args.amp 的值确定是否启用 AMP。如果启用 AMP,检查模型的兼容性,并确保所有进程的 AMP 配置一致。初始化梯度缩放器:根据 PyTorch 版本初始化 GradScaler ,用于在 AMP 模式下稳定训练过程。分布式训练支持:如果是分布式训练,将模型包装为 DistributedDataParallel ,并确保所有进程的 AMP 配置一致。
# AMP 支持 :自动混合精度(AMP)可以显著加速训练过程,同时减少内存占用。通过检查模型的兼容性,可以确保 AMP 在训练中安全使用。
# 分布式训练 :在分布式训练中,所有进程需要共享相同的 AMP 配置。通过广播机制,可以确保所有进程的配置一致。
# 梯度缩放器 :在 AMP 模式下,梯度缩放器( GradScaler )用于防止梯度下溢(underflow),从而稳定训练过程。
# 这段代码的功能是检查和调整图像尺寸( imgsz )以及批量大小( batch size ),以确保它们与模型的步幅( stride )和训练环境兼容。
# Check imgsz
# 检查模型是否有一个名为 stride 的属性,该属性通常表示模型的最大步幅(例如在目标检测模型中,步幅是特征图的下采样率)。
# 如果模型有 stride 属性,取其最大值并转换为整数。
# 如果模型没有 stride 属性,则默认步幅为 32。
# 使用 max(..., 32) 确保步幅至少为 32,以避免过小的步幅导致问题。
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
# 调用 check_imgsz 函数,确保图像尺寸( imgsz )与模型的步幅( gs )兼容。
# check_imgsz 函数会调整图像尺寸, 使其能够被步幅整除 ,从而 避免在特征提取过程中出现尺寸不匹配的问题 。
# stride=gs :指定模型的最大步幅。
# floor=gs :指定图像尺寸的最小值,确保图像尺寸不会小于步幅。
# max_dim=1 :限制图像的最大维度(例如,避免图像尺寸过大导致内存不足)。
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
# 将 模型的最大步幅 ( gs )存储在 self.stride 中,以便在多尺度训练中使用。
self.stride = gs # for multiscale training
# Batch size
# 如果批量大小( self.batch_size )小于 1,且当前是单 GPU 训练( RANK == -1 )。
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
# 调用 auto_batch 方法自动估计最佳批量大小。 更新 self.args.batch 和 self.batch_size 为自动估计的值。
self.args.batch = self.batch_size = self.auto_batch()
# 这段代码的主要功能是。检查和调整图像尺寸:确保图像尺寸与模型的步幅兼容,避免在特征提取过程中出现尺寸不匹配的问题。使用 check_imgsz 函数调整图像尺寸,使其能够被步幅整除。存储模型步幅:将模型的最大步幅存储在 self.stride 中,以便在多尺度训练中使用。自动估计批量大小:如果批量大小未指定(小于 1),且当前是单 GPU 训练,自动估计最佳批量大小。
# 图像尺寸调整 :在许多深度学习模型中,输入图像的尺寸需要与模型的步幅兼容。例如,在目标检测模型中,特征图的尺寸需要能够被步幅整除,以避免边界效应和尺寸不匹配的问题。
# 多尺度训练 :存储模型的最大步幅( self.stride )可以用于多尺度训练,即在训练过程中动态调整图像尺寸,以提高模型的泛化能力。
# 批量大小估计 :在单 GPU 训练中,自动估计最佳批量大小可以充分利用 GPU 的计算资源,同时避免内存不足的问题。
# 这段代码的功能是初始化训练和验证数据加载器( DataLoader ),并设置与验证相关的组件(如验证器、指标和 EMA)。
# Dataloaders
# 计算每个设备(如 GPU)的 批量大小 。 self.batch_size 是全局批量大小。 world_size 是参与训练的设备数量。 如果 world_size > 1 (分布式训练),则将全局批量大小平均分配到每个设备上。 如果 world_size <= 1 (单机训练),则使用全局批量大小。
batch_size = self.batch_size // max(world_size, 1)
# 调用 self.get_dataloader 方法创建 训练数据加载器 。 self.trainset 是训练数据集。 batch_size 是每个设备的批量大小。 rank=LOCAL_RANK 指定当前进程的本地排名(用于分布式训练)。 mode="train" 表示这是训练数据加载器。
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train")
# 如果当前是主进程( RANK == -1 或 RANK == 0 ),执行以下操作。主进程通常负责验证、日志记录和保存模型。
if RANK in {-1, 0}:
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
# 创建 验证数据加载器 。
self.test_loader = self.get_dataloader(
# self.testset 是验证数据集。
# 如果任务是定向边界框检测( self.args.task == "obb" ),则使用与训练相同的批量大小。 否则,将批量大小加倍( batch_size * 2 ),以加快验证速度。
# rank=-1 表示这是主进程的数据加载器。
# mode="val" 表示这是验证数据加载器。
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
)
# 调用 self.get_validator 方法获取验证器实例,用于在验证阶段评估模型性能。
self.validator = self.get_validator()
# 定义 验证阶段的指标键 。
# self.validator.metrics.keys 是 验证器提供的指标键 。
# self.label_loss_items(prefix="val") 是 验证阶段的损失项键 。
# 将两者合并,生成 完整的指标键列表 。
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
# 初始化 验证指标字典 ,将所有指标值设置为 0。
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
# 创建模型的指数移动平均(EMA)实例,用于在训练过程中 维护模型权重的平滑版本 。
self.ema = ModelEMA(self.model)
# 如果启用了绘图( self.args.plots ),调用 self.plot_training_labels 方法绘制训练数据的标签分布,用于可视化训练数据的类别分布。
if self.args.plots:
self.plot_training_labels()
# 这段代码的主要功能是。初始化数据加载器:创建训练数据加载器,根据设备数量分配批量大小。创建验证数据加载器,根据任务类型调整批量大小。设置验证组件:获取验证器实例,用于评估模型性能。初始化验证指标字典,记录验证阶段的指标值。创建模型的 EMA 实例,用于维护模型权重的平滑版本。可视化训练数据:如果启用了绘图,绘制训练数据的标签分布。数据加载器:训练和验证数据加载器是训练过程中必不可少的组件,用于高效地加载和预处理数据。在分布式训练中,数据加载器需要根据设备数量分配批量大小,以充分利用多 GPU 资源。验证组件:验证器用于在训练过程中定期评估模型性能,确保模型不会过拟合。EMA 可以平滑模型权重的更新,提高模型的泛化能力。可视化:绘制训练数据的标签分布可以帮助开发者了解数据集的类别分布情况,从而更好地调整训练策略。
# 这段代码的功能是初始化优化器和学习率调度器,并设置早停机制、恢复训练状态以及运行预训练阶段的回调函数。
# Optimizer
# 计算 梯度累积步数 ( self.accumulate ),用于在小批量训练时累积梯度以模拟大批次训练的效果。 self.args.nbs 是 目标批量大小 (通常是一个较大的值,如 64)。 self.batch_size 是当前设备的 实际批量大小 。 round(self.args.nbs / self.batch_size) 计算 需要累积的步数 ,确保 累积后的批量大小接近目标批量大小 。 使用 max(..., 1) 确保至少累积一次。
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
# 根据 实际批量大小 和 梯度累积步数 调整 权重衰减 ( weight_decay )。 权重衰减需要根据实际的训练批量大小进行缩放,以保持与目标批量大小一致的效果。
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
# 计算 总迭代次数 ( iterations )。
# len(self.train_loader.dataset) 是训练数据集的 总样本数 。
# max(self.batch_size, self.args.nbs) 确保使用较大的批量大小( 实际批量大小 或 目标批量大小 )。
# math.ceil(...) 确保向上取整,避免遗漏最后一个不完整的批次。
# 乘以 self.epochs 得到 总迭代次数 。
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
# 调用 self.build_optimizer 方法构建优化器。
self.optimizer = self.build_optimizer(
# 指定优化的模型。
model=self.model,
# 指定优化器类型(如 SGD、Adam 等)。
name=self.args.optimizer,
# 初始学习率。
lr=self.args.lr0,
# 动量参数(对于 SGD 等优化器)。
momentum=self.args.momentum,
# 调整后的权重衰减。
decay=weight_decay,
# 总迭代次数,用于某些优化器(如 AdamW)的调整。
iterations=iterations,
)
# Scheduler
# 调用 _setup_scheduler 方法初始化学习率调度器。这通常包括设置学习率调度策略(如余弦退火或线性退火)。
self._setup_scheduler()
# 初始化 早停机制 ( EarlyStopping ),设置 耐心值 ( patience )为 self.args.patience 。 self.stop 是一个标志,用于 在训练过程中指示是否需要提前停止训练 。
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
# 如果存在检查点( ckpt ),调用 self.resume_training 方法 恢复训练状态 ,包括模型权重、优化器状态和训练轮次。
self.resume_training(ckpt)
# 设置学习率调度器的最后一个轮次为 self.start_epoch - 1 。这确保学习率调度器从正确的轮次开始。
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
# 调用 self.run_callbacks 方法,运行预训练阶段结束时的回调函数。这允许用户在训练开始前插入自定义逻辑,例如日志记录或模型检查。
self.run_callbacks("on_pretrain_routine_end")
# 这段代码的主要功能是。优化器初始化:计算梯度累积步数,调整权重衰减,并根据总迭代次数构建优化器。学习率调度器初始化:设置学习率调度策略,确保学习率在训练过程中动态调整。早停机制初始化:设置早停机制的耐心值,用于在验证指标不再提升时提前终止训练。恢复训练状态:如果存在检查点,恢复模型权重、优化器状态和训练轮次。运行回调函数:在预训练阶段结束时运行回调函数,允许用户插入自定义逻辑。优化器和学习率调度器:优化器和学习率调度器是训练过程中的关键组件,它们决定了模型的收敛速度和最终性能。早停机制:早停机制可以防止过拟合,节省训练时间和计算资源。恢复训练状态:支持从检查点恢复训练,确保训练过程的连续性,即使在意外中断后也能继续训练。回调函数:回调函数提供了灵活性,允许用户在训练的不同阶段插入自定义逻辑,增强训练器的功能。
# _setup_train 方法是一个综合性的初始化函数,它在训练开始前完成了一系列关键的准备工作,确保训练过程能够高效、稳定地进行。具体来说,它首先根据用户配置冻结特定的模型层,以减少计算量或固定预训练权重;接着检查并配置自动混合精度(AMP),以加速训练并减少内存占用。此外,该方法还根据模型的步幅调整图像尺寸,确保输入数据与模型兼容,并根据训练环境(如单 GPU 或分布式训练)初始化数据加载器、优化器和学习率调度器。同时,它还设置了早停机制以避免过拟合,并在存在检查点时恢复训练状态。最后,通过运行回调函数,提供了扩展性和灵活性,允许用户在训练初始化阶段插入自定义逻辑。这些步骤共同为训练过程奠定了坚实的基础,使得模型能够在各种训练场景下顺利运行。
# 这段代码定义了 _do_train 方法,它是训练过程的核心逻辑,负责执行模型的训练、验证、保存和早停机制。
# 定义了 _do_train 方法,它接受一个参数。
# 1.world_size :表示参与训练的设备(GPU)数量。默认值为 1,表示单 GPU 训练。
def _do_train(self, world_size=1):
# 训练完成,如果参数指定则进行评估和绘图。
"""Train completed, evaluate and plot if specified by arguments."""
# 如果 world_size > 1 ,表示有多于一个设备参与训练,调用 _setup_ddp 方法 初始化分布式数据并行(DDP)环境 。
if world_size > 1:
self._setup_ddp(world_size)
# 调用 _setup_train 方法, 完成训练前的准备工作 ,包括模型初始化、数据加载器创建、优化器和学习率调度器设置等。
self._setup_train(world_size)
# 这段代码初始化了训练过程中的关键变量,并记录了训练开始时的日志信息。
# 获取训练数据加载器中的 批次数量 ( nb ),用于后续计算预热迭代次数和训练进度。
nb = len(self.train_loader) # number of batches
# 计算预热(warmup)迭代次数。 self.args.warmup_epochs 是预热轮数。 nb 是每个轮次的批次数量。 如果预热轮数大于 0,则计算 预热迭代次数 ( self.args.warmup_epochs * nb ),并向上取整。 使用 max(..., 100) 确保预热迭代次数至少为 100 次。 如果 self.args.warmup_epochs <= 0 ,则设置为 -1 ,表示不进行预热。
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
# 初始化 最后一次优化步骤的索引 为 -1 ,用于 后续的梯度累积逻辑 。
last_opt_step = -1
# 初始化训练时间相关的变量。
# 当前轮次的训练时间 。
self.epoch_time = None
# 当前轮次的开始时间 。
self.epoch_time_start = time.time()
# 整个训练过程的开始时间。
self.train_time_start = time.time()
# 调用 self.run_callbacks 方法,运行训练开始时的回调函数。这允许用户在训练开始前插入自定义逻辑,例如初始化日志记录器或检查训练配置。
self.run_callbacks("on_train_start")
# 记录训练开始时的日志信息,包括。
LOGGER.info(
# 训练和验证的 图像尺寸 ( self.args.imgsz )。
f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n" # 图像大小 {self.args.imgsz} train, {self.args.imgsz} val 。
# 数据加载器的 工作进程数量 ( self.train_loader.num_workers * (world_size or 1) )。
f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n" # 使用 {self.train_loader.num_workers * (world_size 或 1)} 个数据加载器工作者。
# 日志保存路径 ( self.save_dir )。
f"Logging results to {colorstr('bold', self.save_dir)}\n" # 将结果记录到 {colorstr('bold', self.save_dir)}。
# 训练时长(小时或轮次)。 如果 self.args.time 指定了训练时间(小时),则显示 预计训练时间 。 否则,显示 总训练轮次 ( self.epochs )。
f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...") # 开始训练“ + (f"{self.args.time} 小时..." if self.args.time else f"{self.epochs} 时期...
)
# 这段代码的主要功能是。初始化关键变量:计算训练数据加载器的批次数量( nb )。计算预热迭代次数( nw ),并根据配置决定是否进行预热。初始化训练时间相关的变量,记录训练开始时间和当前轮次的开始时间。运行回调函数:调用 self.run_callbacks("on_train_start") ,允许用户在训练开始前插入自定义逻辑。记录日志信息:使用 LOGGER.info 记录训练开始时的关键信息,包括图像尺寸、数据加载器工作进程数量、日志保存路径和训练时长。这些步骤为训练过程提供了透明度和可扩展性,确保用户可以清楚地了解训练配置,并在训练开始前进行必要的初始化操作。
# 这段代码的功能是初始化训练过程中的马赛克增强关闭逻辑,并开始训练循环。
# 如果启用了关闭马赛克增强( self.args.close_mosaic ),计算 关闭马赛克增强的起始批次索引 ( base_idx )。
if self.args.close_mosaic:
# self.epochs :总训练轮数。
# self.args.close_mosaic :从倒数第几轮开始关闭马赛克增强。
# nb :每个轮次的批次数量。
base_idx = (self.epochs - self.args.close_mosaic) * nb
# 将关闭马赛克增强的批次索引及其后两个批次索引添加到 self.plot_idx 中,用于 后续的可视化 。
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
# 初始化当前轮次为 self.start_epoch ,这通常是从 0 开始,或者从检查点恢复时的轮次。
epoch = self.start_epoch
# 清零优化器的梯度 ,确保训练开始时的稳定性。这一步尤其重要,因为在恢复训练时,可能存在未清零的梯度。
self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start
# 开始训练循环,循环会一直运行,直到满足停止条件(如达到最大轮次或训练时间超限)。
while True:
# 设置 当前轮次 。
self.epoch = epoch
# 调用 self.run_callbacks 方法,运行轮次开始时的回调函数。这允许用户在每个轮次开始时插入自定义逻辑,例如日志记录或模型检查。
self.run_callbacks("on_train_epoch_start")
# 使用 warnings.catch_warnings() 和 warnings.simplefilter("ignore") 忽略可能的警告信息,特别是关于学习率调度器在优化器步骤之前被调用的警告。
with warnings.catch_warnings():
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
# 调用 self.scheduler.step() , 根据当前轮次调整学习率 。
self.scheduler.step()
# 这段代码的主要功能是。初始化马赛克增强关闭逻辑:如果启用了关闭马赛克增强,计算关闭马赛克增强的起始批次索引,并记录相关批次索引用于可视化。开始训练循环:初始化当前轮次,并清零优化器的梯度,确保训练开始时的稳定性。运行轮次开始时的回调函数:调用 self.run_callbacks("on_train_epoch_start") ,允许用户在每个轮次开始时插入自定义逻辑。调整学习率:调用学习率调度器的 step 方法,根据当前轮次调整学习率,并忽略可能的警告信息。这些步骤为训练循环的开始提供了必要的初始化和配置,确保训练过程能够顺利进行。
# 这段代码的功能是将模型设置为训练模式,并根据当前轮次更新数据加载器的属性(例如关闭马赛克增强)。
# 将模型设置为 训练模式 。这一步确保模型的某些层(如 Dropout 和 BatchNorm )在训练时的行为与推理时不同。
self.model.train()
# 如果 当前是分布式训练 ( RANK != -1 ),调用 self.train_loader.sampler.set_epoch(epoch) 。 这一步确保在每个轮次开始时,数据采样器会重新洗牌,从而保证每个进程在每个轮次中加载不同的数据子集。
if RANK != -1:
self.train_loader.sampler.set_epoch(epoch)
# 创建一个枚举器 pbar ,用于 遍历训练数据加载器中的每个批次 。 enumerate 会返回批次的索引和数据。
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
# 检查当前轮次是否是关闭马赛克增强的轮次。 self.args.close_mosaic 表示 从倒数第几轮开始关闭马赛克增强 。 如果当前轮次满足条件,则执行以下操作。
if epoch == (self.epochs - self.args.close_mosaic):
# 调用 _close_dataloader_mosaic 方法,关闭数据加载器中的马赛克增强。这通常是为了 在训练的最后阶段减少数据增强的复杂性 ,从而提高模型的泛化能力。
self._close_dataloader_mosaic()
# 调用 self.train_loader.reset() , 重置数据加载器 。这一步确保在关闭马赛克增强后,数据加载器能够正确加载数据。
self.train_loader.reset()
# 这段代码的主要功能是。设置模型为训练模式:确保模型在训练时的行为正确。更新数据采样器:在分布式训练中,确保每个进程在每个轮次中加载不同的数据子集。关闭马赛克增强:在指定的轮次关闭数据加载器中的马赛克增强,以减少数据增强的复杂性,提高模型的泛化能力。重置数据加载器:在关闭马赛克增强后,重置数据加载器,确保数据加载正确。这些步骤确保了训练过程中的数据加载和模型状态的正确性,特别是在分布式训练和数据增强策略调整时。
# 这段代码的功能是在主进程( RANK in {-1, 0} )中记录训练进度,并在每个批次开始时执行预热(warmup)逻辑,动态调整学习率和动量。
# 检查当前进程是否是主进程( RANK == -1 或 RANK == 0 )。 主进程 负责 日志记录 和 进度条显示 。
if RANK in {-1, 0}:
# 调用 self.progress_string() 方法,生成并记录 当前训练进度的信息 。这通常包括当前轮次、损失值、学习率等。
LOGGER.info(self.progress_string())
# 使用 TQDM 创建一个进度条,用于显示训练过程中的批次进度。 enumerate(self.train_loader) 生成批次的索引和数据。 total=nb 设置进度条的总批次数量。
pbar = TQDM(enumerate(self.train_loader), total=nb)
# 初始化 总损失 为 None ,用于在训练过程中累积损失值。
self.tloss = None
# 遍历训练数据加载器中的每个批次, i 是 批次索引 , batch 是 当前批次的数据 。
for i, batch in pbar:
# 调用 self.run_callbacks 方法,运行批次开始时的回调函数。这允许用户在每个批次开始时插入自定义逻辑,例如日志记录或数据预处理。
self.run_callbacks("on_train_batch_start")
# Warmup
# 计算 当前迭代的全局索引 ni ,用于预热逻辑。 i 是当前批次索引, nb 是每个轮次的批次数量, epoch 是当前轮次。
ni = i + nb * epoch
# 检查当前迭代是否在预热阶段。 nw 是预热迭代的总次数。
if ni <= nw:
# 定义插值的 x 轴范围,用于计算 学习率 和 动量 的 插值 。
xi = [0, nw] # x interp
# 根据当前迭代 ni ,使用线性插值计算 梯度累积步数 self.accumulate 。 self.args.nbs 是目标批量大小, self.batch_size 是当前设备的实际批量大小。 np.interp 用于在 [1, self.args.nbs / self.batch_size] 范围内插值,确保累积步数逐渐增加。
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
# 遍历 优化器的参数组 ,每个参数组可以有不同的学习率和动量。
for j, x in enumerate(self.optimizer.param_groups):
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
# 使用 线性插值 动态调整学习率。
# 如果是第一个参数组(通常是偏置项),学习率从 self.args.warmup_bias_lr 线性增加到 x["initial_lr"] * self.lf(epoch) 。
# 其他参数组的学习率从 0 线性增加到 x["initial_lr"] * self.lf(epoch) 。
# self.lf(epoch) 是学习率调度函数, 根据当前轮次调整学习率 。
x["lr"] = np.interp(
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
)
# 如果参数组中包含动量( momentum ),使用线性插值动态调整动量。 动量从 self.args.warmup_momentum 线性增加到 self.args.momentum 。
if "momentum" in x:
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
# 这段代码的主要功能是。记录训练进度:在主进程中记录当前训练进度,并显示进度条。初始化总损失为 None ,用于后续累积损失值。动态调整学习率和动量:在预热阶段( ni <= nw ),根据当前迭代动态调整学习率和动量。使用线性插值确保学习率和动量从初始值平滑过渡到目标值。运行批次开始时的回调函数:调用 self.run_callbacks("on_train_batch_start") ,允许用户在每个批次开始时插入自定义逻辑。这些步骤确保了训练过程中的学习率和动量能够平滑过渡,特别是在预热阶段,从而提高训练的稳定性和收敛速度。
# 这段代码的功能是执行模型的前向传播( forward )和反向传播( backward ),并处理自动混合精度(AMP)和梯度缩放。
# Forward
# 使用 autocast 上下文管理器,根据 self.amp 的值启用自动混合精度(AMP)。 如果 self.amp 为 True ,则启用 AMP,自动选择操作的精度以提高性能并减少内存占用。 如果 self.amp 为 False ,则禁用 AMP,所有操作使用默认精度(通常是 float32 )。
# def autocast(enabled: bool, device: str = "cuda"):
# -> 用于根据 PyTorch 的版本动态选择混合精度训练的上下文管理器。如果当前 PyTorch 版本为 1.13 或更高版本,返回 torch.amp.autocast 上下文管理器。返回 torch.cuda.amp.autocast ,并传递 enabled 参数,用于控制是否启用混合精度训练。
# -> return torch.amp.autocast(device, enabled=enabled) / return torch.cuda.amp.autocast(enabled)
with autocast(self.amp):
# 调用 self.preprocess_batch 方法对当前批次的数据进行预处理。这可能包括数据增强、归一化等操作。
batch = self.preprocess_batch(batch)
# 将预处理后的批次数据传递给模型,执行前向传播。 模型返回 损失值 ( self.loss )和 损失项的详细信息 ( self.loss_items )。
self.loss, self.loss_items = self.model(batch)
# 如果当前是分布式训练( RANK != -1 ),将损失值乘以设备数量( world_size )。 这一步是为了在分布式训练中同步损失值,确保所有设备的损失值一致。
if RANK != -1:
self.loss *= world_size
# 更新 总损失 ( self.tloss )。
# 如果 self.tloss 已经初始化(不为 None ),则使用移动平均的方式更新总损失。
# 如果 self.tloss 为 None ,则直接将其设置为当前批次的损失项( self.loss_items )。
# 这一步用于在训练过程中累积损失值,以便后续计算平均损失。
self.tloss = (
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
)
# Backward
# 使用梯度缩放器( GradScaler )对损失值进行缩放,然后执行反向传播。 这一步是为了在使用 AMP 时避免梯度下溢(underflow),确保梯度计算的稳定性。 self.scaler.scale(self.loss) 对损失值进行缩放, backward() 计算梯度。
self.scaler.scale(self.loss).backward()
# 这段代码的主要功能是。前向传播:使用 autocast 上下文管理器启用自动混合精度(AMP)。对批次数据进行预处理,并传递给模型执行前向传播。计算损失值并更新总损失。分布式训练同步:在分布式训练中,将损失值乘以设备数量,确保所有设备的损失值一致。反向传播:使用梯度缩放器对损失值进行缩放,然后执行反向传播,计算梯度。这一步确保在使用 AMP 时梯度计算的稳定性,避免梯度下溢。这些步骤确保了训练过程中的前向和反向传播能够高效、稳定地进行,特别是在使用 AMP 和分布式训练时。
# 这段代码的功能是在训练过程中执行优化步骤,并根据训练时间限制决定是否提前停止训练。
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
# 检查当前迭代 ni 是否达到了梯度累积步数 self.accumulate 。 last_opt_step 是上次执行优化步骤的迭代索引。 如果 ni - last_opt_step >= self.accumulate ,表示已经累积了足够的梯度,可以 执行优化步骤 。
if ni - last_opt_step >= self.accumulate:
# 调用 self.optimizer_step 方法,执行优化步骤。这通常包括 :使用梯度缩放器( GradScaler )对梯度进行缩放。 调用优化器的 step 方法,更新模型参数。 清零优化器的梯度。
self.optimizer_step()
# 更新 last_opt_step 为 当前迭代索引 ni ,记录 最后一次执行优化步骤的位置 。
last_opt_step = ni
# Timed stopping
# 检查是否设置了训练时间限制( self.args.time )。如果设置了训练时间(以小时为单位),则进入时间检查逻辑。
if self.args.time:
# 计算当前时间与训练开始时间的差值(以秒为单位),并与训练时间限制( self.args.time * 3600 秒)进行比较。 如果当前训练时间超过了设置的时间限制,将 self.stop 设置为 True ,表示需要停止训练。
self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
# 检查当前是否处于分布式训练模式( RANK != -1 )。 RANK 是当前进程的全局排名, -1 表示单机训练,非 -1 表示分布式训练。
if RANK != -1: # if DDP training
# 创建一个列表 broadcast_list ,用于存储停止信号( self.stop )。 如果当前是主进程( RANK == 0 ),将 self.stop 放入列表。 如果是其他进程,列表中放入 None 。 这一步是为了确保只有主进程的信号被广播,其他进程接收信号。
broadcast_list = [self.stop if RANK == 0 else None]
# 使用 torch.distributed.broadcast_object_list 方法将主进程的停止信号广播到所有其他进程。 broadcast_list 是包含停止信号的列表。 0 表示从主进程( RANK == 0 )广播信号。 这一步确保所有进程都能接收到主进程的停止信号,从而同步停止训练。
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
# 更新当前进程的 self.stop 为广播后的值。 这一步确保所有进程的 self.stop 保持一致,避免部分进程继续运行。
self.stop = broadcast_list[0]
# 如果 self.stop 为 True ,表示训练时间已超过限制,退出当前批次的训练循环。
if self.stop: # training time exceeded
break
# 这段代码的主要功能是。执行优化步骤:在累积了足够的梯度后,调用 self.optimizer_step 方法执行优化步骤,更新模型参数。更新 last_opt_step 为当前迭代索引,记录最后一次执行优化步骤的位置。时间限制检查:如果设置了训练时间限制,检查当前训练时间是否超过了限制。如果时间超过限制,将 self.stop 设置为 True ,并广播停止信号到所有进程。提前停止训练:如果 self.stop 为 True ,退出当前批次的训练循环,提前停止训练。梯度累积:在小批量训练中,通过累积梯度可以模拟大批次训练的效果,提高训练的稳定性和性能。时间限制:在某些情况下,用户可能希望在有限的时间内完成训练。通过设置训练时间限制,可以确保训练不会超过预定的时间。分布式训练同步:在分布式训练中,所有进程需要同步停止训练,以避免部分进程继续运行导致的不一致问题。
# 这段代码的功能是在主进程( RANK in {-1, 0} )中记录训练过程中的日志信息,并在每个批次结束时执行相关的回调函数。
# Log
# 检查当前进程是否是主进程( RANK == -1 或 RANK == 0 )。主进程负责日志记录和进度条更新。
if RANK in {-1, 0}:
# 计算 总损失 self.tloss 的长度。 如果 self.tloss 是一个多维张量(例如包含多个损失项),则取其第一维的长度。 如果 self.tloss 是一个标量(无形状),则设置长度为 1。
loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
# 更新进度条的描述信息,显示以下内容。
pbar.set_description(
("%11s" * 2 + "%11.4g" * (2 + loss_length))
% (
# 当前轮次和总轮次( f"{epoch + 1}/{self.epochs}" )。
f"{epoch + 1}/{self.epochs}",
# 当前 GPU 内存使用量( f"{self._get_memory():.3g}G" )。
f"{self._get_memory():.3g}G", # (GB) GPU memory util
# 损失值( self.tloss ),如果损失值是多维的,则展开显示每个损失项。
*(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses
# 当前批次的类别数量( batch["cls"].shape[0] )。
batch["cls"].shape[0], # batch size, i.e. 8
# 当前批次的图像尺寸( batch["img"].shape[-1] )。
batch["img"].shape[-1], # imgsz, i.e 640
)
)
# 调用 self.run_callbacks 方法,运行批次结束时的回调函数。这允许用户在每个批次结束时插入自定义逻辑,例如日志记录或性能监控。
self.run_callbacks("on_batch_end")
# 如果启用了绘图( self.args.plots ),并且当前迭代索引 ni 在绘图索引列表 self.plot_idx 中。
if self.args.plots and ni in self.plot_idx:
# 调用 self.plot_training_samples 方法,绘制当前批次的训练样本。这通常用于可视化训练过程中的数据和模型输出。
self.plot_training_samples(batch, ni)
# 调用 self.run_callbacks 方法,运行训练批次结束时的回调函数。这允许用户在每个训练批次结束时插入自定义逻辑,例如日志记录或性能监控。
self.run_callbacks("on_train_batch_end")
# 这段代码的主要功能是。日志记录:在主进程中更新进度条的描述信息,显示当前轮次、GPU 内存使用量、损失值、批次类别数量和图像尺寸。这些信息帮助用户实时监控训练过程的状态。回调函数执行:在每个批次结束时运行回调函数,允许用户插入自定义逻辑,例如日志记录、性能监控或数据可视化。特别地,如果启用了绘图且当前迭代索引在绘图索引列表中,绘制当前批次的训练样本。训练批次结束时的回调:在每个训练批次结束时运行回调函数,进一步扩展训练过程的灵活性。这些步骤确保了训练过程的透明性和可扩展性,帮助用户更好地监控和调整训练过程。
# 这段代码的功能是在每个训练轮次结束时执行一系列操作,包括记录学习率、运行轮次结束时的回调函数、执行验证、保存模型等。
# 记录 每个优化器参数组的学习率 ,以便后续的日志记录或监控。 self.optimizer.param_groups 是优化器的参数组列表,每个参数组可以有不同的学习率。 使用字典推导式生成一个字典,键为 lr/pg{ir} (其中 {ir} 是参数组的索引),值为对应的学习率。
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
# 调用 self.run_callbacks 方法,运行轮次结束时的回调函数。这允许用户在每个轮次结束时插入自定义逻辑,例如日志记录或性能监控。
self.run_callbacks("on_train_epoch_end")
# 检查当前进程是否是主进程( RANK == -1 或 RANK == 0 )。主进程负责验证、保存模型和日志记录。
if RANK in {-1, 0}:
# 检查当前轮次是否是最后一个轮次。如果是最后一个轮次, final_epoch 为 True 。
final_epoch = epoch + 1 >= self.epochs
# 更新指数移动平均(EMA)模型的属性,确保 EMA 模型与当前模型的某些属性保持一致。 include 参数指定了需要更新的属性列表。
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
# Validation
# 检查是 否需要执行验证 。 如果启用了验证( self.args.val )。 或者是最后一个轮次( final_epoch )。 或者早停机制可能触发( self.stopper.possible_stop )。 或者已经触发了停止信号( self.stop )。
if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
# 调用 self.validate 方法执行验证,返回 验证指标 ( self.metrics )和 模型的性能指标 ( self.fitness )。
self.metrics, self.fitness = self.validate()
# 保存训练和验证的指标。
# self.label_loss_items(self.tloss) :训练损失项。
# self.metrics :验证指标。
# self.lr :学习率。
# 这些指标将被保存到日志文件或监控工具中。
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
# 更新停止信号 self.stop 。 如果早停机制触发( self.stopper(epoch + 1, self.fitness) 返回 True )。 或者是最后一个轮次( final_epoch )。 这一步确保在满足早停条件或到达最后一个轮次时停止训练。
self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
# 如果设置了训练时间限制( self.args.time ),检查当前训练时间是否超过限制。
if self.args.time:
# 如果超过限制,将 self.stop 设置为 True ,停止训练。
self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
# Save model
# 检查是否需要保存模型。 如果启用了模型保存( self.args.save )。 或者是最后一个轮次( final_epoch )。
if self.args.save or final_epoch:
# 调用 self.save_model 方法保存当前模型。
self.save_model()
# 调用 self.run_callbacks 方法,运行模型保存时的回调函数。这允许用户在保存模型时插入自定义逻辑,例如日志记录或模型备份。
self.run_callbacks("on_model_save")
# 这段代码的主要功能是。记录学习率:记录每个优化器参数组的学习率,以便后续的日志记录或监控。运行轮次结束时的回调函数:允许用户在每个轮次结束时插入自定义逻辑。执行验证:在需要时执行验证,计算验证指标和模型性能。保存指标:保存训练和验证的指标,包括损失值、验证指标和学习率。更新停止信号:根据早停机制或训练时间限制更新停止信号。保存模型:在需要时保存当前模型,并运行模型保存时的回调函数。这些步骤确保了训练过程的完整性和灵活性,帮助用户监控训练进度、评估模型性能,并在适当的时候保存模型。
# 这段代码的功能是在每个轮次结束后更新学习率调度器,并根据训练时间限制动态调整总训练轮次。
# Scheduler
# 获取当前时间戳,用于计算时间相关的信息。
t = time.time()
# 计算 当前轮次的训练时间 ( self.epoch_time ),即当前时间戳减去轮次开始时间( self.epoch_time_start )。
self.epoch_time = t - self.epoch_time_start
# 更新轮次开始时间为当前时间戳,为下一个轮次的计时做准备。
self.epoch_time_start = t
# 检查是否设置了训练时间限制( self.args.time ),以小时为单位。
if self.args.time:
# 计算 平均每个轮次的训练时间 ( mean_epoch_time )。 从训练开始时间( self.train_time_start )到当前时间的总时间除以已完成的轮次数( epoch - self.start_epoch + 1 )。
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
# 根据 训练时间限制 动态调整总训练轮次 ( self.epochs )。
# 将设置的训练时间( self.args.time ,单位为小时)转换为秒( self.args.time * 3600 )。
# 除以平均每个轮次的训练时间( mean_epoch_time ),得到 理论上的最大轮次数 。
# 使用 math.ceil 向上取整,确保训练时间不超过限制。
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
# 调用 _setup_scheduler 方法重新初始化学习率调度器,以适应新的总轮次数( self.epochs )。
self._setup_scheduler()
# 更新学习率调度器的最后一个轮次为当前轮次( self.epoch )。这一步确保学习率调度器从正确的轮次开始调整学习率。
self.scheduler.last_epoch = self.epoch # do not move
# 更新停止信号( self.stop )。 如果当前轮次( epoch )已经达到或超过动态调整后的总轮次( self.epochs ),将 self.stop 设置为 True ,停止训练。
self.stop |= epoch >= self.epochs # stop if exceeded epochs
# 调用 self.run_callbacks 方法,运行轮次结束时的回调函数。这允许用户在每个轮次结束时插入自定义逻辑,例如日志记录或性能监控。
self.run_callbacks("on_fit_epoch_end")
# 调用 _clear_memory 方法,清理当前轮次占用的内存,释放资源。
self._clear_memory()
# 这段代码的主要功能是。更新轮次时间:计算当前轮次的训练时间,并更新轮次开始时间。动态调整总轮次:根据训练时间限制动态调整总训练轮次,确保训练不会超过预定时间。重新初始化学习率调度器:根据新的总轮次数重新初始化学习率调度器,确保学习率调整策略正确。更新停止信号:如果当前轮次达到或超过动态调整后的总轮次,设置停止信号,停止训练。运行回调函数:在轮次结束时运行回调函数,扩展训练过程的灵活性。清理内存:清理当前轮次占用的内存,释放资源。这些步骤确保了训练过程的灵活性和效率,特别是在设置了训练时间限制的情况下,能够动态调整训练策略,确保训练按时完成。
# 这段代码的功能是在分布式训练(DDP)中实现早停机制(Early Stopping),并确保所有进程同步停止训练。
# Early Stopping
# 检查当前是否处于分布式训练模式( RANK != -1 )。 RANK 是当前进程的全局排名, -1 表示单机训练,非 -1 表示分布式训练。
if RANK != -1: # if DDP training
# 创建一个列表 broadcast_list ,用于存储停止信号( self.stop )。 如果当前是主进程( RANK == 0 ),将 self.stop 放入列表。 如果是其他进程,列表中放入 None 。 这一步是为了确保只有主进程的信号被广播,其他进程接收信号。
broadcast_list = [self.stop if RANK == 0 else None]
# 使用 torch.distributed.broadcast_object_list 方法将主进程的停止信号广播到所有其他进程。 broadcast_list 是包含停止信号的列表。 0 表示从主进程( RANK == 0 )广播信号。 这一步确保所有进程都能接收到主进程的停止信号,从而同步停止训练。
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
# 更新当前进程的 self.stop 为广播后的值。 这一步确保所有进程的 self.stop 保持一致,避免部分进程继续运行。
self.stop = broadcast_list[0]
# 如果 self.stop 为 True ,表示需要停止训练,退出训练循环。 在分布式训练中,所有进程必须同步停止,因此需要在所有进程中执行 break 。
if self.stop:
break # must break all DDP ranks
# 如果没有触发停止信号,将当前轮次 epoch 加 1,继续下一个轮次的训练。
epoch += 1
# 这段代码的主要功能是。分布式训练同步:在分布式训练中,将主进程的停止信号广播到所有其他进程,确保所有进程同步停止训练。这一步避免了部分进程继续运行导致的不一致问题。早停机制:如果触发了停止信号( self.stop 为 True ),退出训练循环,停止训练。轮次递增:如果没有触发停止信号,继续下一个轮次的训练。在分布式训练中,多个进程并行运行,每个进程负责一部分计算任务。为了确保训练过程的同步性,需要在某些情况下(如早停机制触发或训练时间超过限制)让所有进程同时停止。通过广播机制,主进程可以将停止信号发送给所有其他进程,确保所有进程同步停止训练。
# 这段代码的功能是在训练完成后执行一系列收尾操作,包括最终验证、绘制训练指标、运行训练结束时的回调函数以及清理内存。
# 检查当前进程是否是主进程( RANK == -1 或 RANK == 0 )。主进程负责执行最终验证、绘制指标和日志记录。
if RANK in {-1, 0}:
# Do final val with best.pt
# 计算训练过程的总耗时(以秒为单位),从训练开始时间( self.train_time_start )到当前时间。
seconds = time.time() - self.train_time_start
# 记录训练完成的日志信息,显示完成的轮次数和总耗时(以小时为单位)。
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.") # {epoch - self.start_epoch + 1} 个纪元在 {seconds / 3600:.3f} 小时内完成。
# 调用 self.final_eval 方法,使用最佳模型( best.pt )进行最终验证。这一步通常用于评估模型在验证集上的性能,并记录 最终的验证指标 。
self.final_eval()
# 如果启用了绘图( self.args.plots ),调用 self.plot_metrics 方法绘制训练过程中的关键指标(如损失值、准确率等)。这一步用于可视化训练过程的性能变化。
if self.args.plots:
self.plot_metrics()
# 调用 self.run_callbacks 方法,运行训练结束时的回调函数。这允许用户在训练结束后插入自定义逻辑,例如保存最终结果、清理资源或发送通知。
self.run_callbacks("on_train_end")
# 调用 _clear_memory 方法,清理训练过程中占用的内存资源。这一步有助于释放 GPU 内存,避免内存泄漏。
self._clear_memory()
# 调用 self.run_callbacks 方法,运行清理阶段的回调函数。这一步用于执行训练结束后的清理操作,例如关闭日志记录器、释放资源等。
self.run_callbacks("teardown")
# 这段代码的主要功能是。记录训练完成信息:在主进程中记录训练完成的日志信息,显示完成的轮次数和总耗时。执行最终验证:使用最佳模型( best.pt )进行最终验证,评估模型在验证集上的性能。绘制训练指标:如果启用了绘图,绘制训练过程中的关键指标,帮助用户可视化训练效果。运行训练结束时的回调函数:允许用户在训练结束后插入自定义逻辑,例如保存最终结果或清理资源。清理内存:清理训练过程中占用的内存资源,释放 GPU 内存。执行清理阶段的回调函数:执行训练结束后的清理操作,确保资源被正确释放。这些步骤确保了训练过程的完整性和灵活性,帮助用户在训练结束后进行必要的评估和清理操作。
# _do_train 方法是训练过程的核心逻辑,负责执行模型的训练、验证、保存和早停机制。它首先根据训练环境(单 GPU 或分布式训练)初始化必要的组件,如数据加载器、优化器和学习率调度器。随后,它进入训练循环,逐轮执行以下操作:调整学习率、执行前向和反向传播、累积梯度并执行优化步骤、记录训练进度和损失值、动态调整训练时间限制下的总轮次、执行验证并保存模型。在训练过程中,它还支持早停机制以避免过拟合,并在每个轮次结束时运行回调函数以实现自定义逻辑。最终,在训练完成后,它执行最终验证、绘制训练指标、清理内存,并运行训练结束时的回调函数。这些步骤确保了训练过程的高效性、灵活性和可扩展性,同时提供了丰富的功能以满足不同训练场景的需求。
# 这段代码定义了 auto_batch 方法,用于自动估算最佳批量大小,以确保训练过程中不会因显存不足而导致训练失败。
# 定义了 auto_batch 方法,它接受一个可选参数。
# 1.max_num_obj :表示单张图像中对象的最大数量。默认值为 0,表示不考虑对象数量对批量大小的影响。
def auto_batch(self, max_num_obj=0):
# 通过计算模型的内存占用来获取批次大小。
"""Get batch size by calculating memory occupation of model."""
# 调用 check_train_batch_size 函数,传入以下参数。
# 该函数会根据模型、图像尺寸、AMP 状态和显存限制,自动 估算一个合适的批量大小 ,以确保训练过程中不会因显存不足而导致训练失败。
# def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1):
# -> 用于检查和自动调整训练时的最佳批量大小。它通过调用 autobatch 函数来实现这一功能,并支持自动混合精度(AMP)训练。autobatch 函数的作用是自动调整批量大小,以确保在训练过程中不会因显存不足而失败。它会根据显存占用情况动态调整批量大小,找到一个合适的值。
# -> return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj)
return check_train_batch_size(
# 当前训练的模型。
model=self.model,
# 训练图像的尺寸。
imgsz=self.args.imgsz,
# 是否启用自动混合精度(AMP)。
amp=self.amp,
# 当前的批量大小。
batch=self.batch_size,
# 单张图像中对象的最大数量。
max_num_obj=max_num_obj,
) # returns batch size
# auto_batch 方法的主要功能是。自动估算批量大小:根据模型、图像尺寸、AMP 状态和显存限制,动态估算一个合适的批量大小。避免显存不足:确保估算的批量大小不会导致显存不足,从而提高训练的稳定性。考虑对象数量:通过 max_num_obj 参数,可以根据单张图像中对象的最大数量进一步调整批量大小,以适应特定任务的需求。这种方法特别适用于目标检测任务,其中单张图像中的对象数量可能会影响显存占用。通过自动估算批量大小,可以优化训练过程,充分利用 GPU 资源,同时避免因显存不足导致的训练中断。
# 这段代码定义了 _get_memory 方法,用于获取当前设备(如 GPU 或 CPU)的显存或内存使用量。
# 定义了 _get_memory 方法,用于获取当前设备的显存或内存使用量。
def _get_memory(self):
# 获取加速器内存利用率(以 GB 为单位)。
"""Get accelerator memory utilization in GB."""
# 如果当前设备是 Apple Metal Performance Shaders(MPS),调用 torch.mps.driver_allocated_memory() 获取 显存使用量 。 torch.mps.driver_allocated_memory() 返回当前 MPS 设备分配的 显存总量 (以字节为单位)。
if self.device.type == "mps":
memory = torch.mps.driver_allocated_memory()
# 如果当前设备是 CPU,显存使用量为 0,因为 CPU 不涉及显存分配。
elif self.device.type == "cpu":
memory = 0
# 如果当前设备是 GPU(默认情况下为 CUDA 设备),调用 torch.cuda.memory_reserved() 获取 显存使用量 。 torch.cuda.memory_reserved() 返回当前 GPU 分配的 显存总量 (以字节为单位)。
else:
memory = torch.cuda.memory_reserved()
# 将显存或内存使用量从字节转换为 GB(1 GB = 1e9 字节)。 返回显存或内存使用量(以 GB 为单位)。
return memory / 1e9
# _get_memory 方法的主要功能是。获取显存或内存使用量:根据当前设备类型(MPS、CPU 或 GPU),获取显存或内存的使用量。对于 MPS 设备,使用 torch.mps.driver_allocated_memory() 。对于 CPU 设备,显存使用量为 0。对于 GPU 设备,使用 torch.cuda.memory_reserved() 。单位转换:将显存或内存使用量从字节转换为 GB,以便更直观地表示。通过这些步骤, _get_memory 方法可以灵活地获取不同设备的显存或内存使用量,为训练过程中的资源监控提供支持。
# 这段代码定义了 _clear_memory 方法,用于清理当前设备(如 GPU 或 CPU)的缓存,释放未使用的内存资源。
# 定义了 _clear_memory 方法,用于清理当前设备的缓存,释放未使用的内存资源。
def _clear_memory(self):
# 清除不同平台上的加速器内存。
"""Clear accelerator memory on different platforms."""
# 调用 Python 的垃圾回收器( gc.collect() ),清理未引用的对象,释放内存。 这一步确保 Python 级别的内存管理不会占用过多资源。
gc.collect()
# 如果当前设备是 Apple Metal Performance Shaders(MPS),调用 torch.mps.empty_cache() 清理 MPS 设备的显存。 torch.mps.empty_cache() 释放未使用的显存,减少显存占用。
if self.device.type == "mps":
torch.mps.empty_cache()
# 如果当前设备是 CPU,直接返回,因为 CPU 不涉及显存清理。
elif self.device.type == "cpu":
return
# 如果当前设备是 GPU(默认情况下为 CUDA 设备),调用 torch.cuda.empty_cache() 清理 GPU 的显存。 torch.cuda.empty_cache() 释放未使用的显存,减少显存占用。
else:
torch.cuda.empty_cache()
# _clear_memory 方法的主要功能是。清理 Python 垃圾回收:调用 gc.collect() ,清理未被引用的对象,释放内存。清理设备显存:根据当前设备类型,清理显存:对于 MPS 设备,调用 torch.mps.empty_cache() 。对于 GPU 设备,调用 torch.cuda.empty_cache() 。对于 CPU 设备,不执行任何操作,因为 CPU 不涉及显存清理。通过这些步骤, _clear_memory 方法可以有效地清理内存和显存,减少训练过程中的资源占用,避免内存泄漏和显存不足的问题。
# 这段代码定义了 read_results_csv 方法,用于读取训练结果的 CSV 文件,并将其内容转换为字典格式。
# 定义了 read_results_csv 方法,用于读取训练结果的 CSV 文件。
def read_results_csv(self):
# 使用 pandas 将 results.csv 读入字典。
"""Read results.csv into a dict using pandas."""
# 导入 pandas 库,并将其别名为 pd 。 注释说明 :将 pandas 的导入范围限制在该方法内,以加快 import ultralytics 的速度。这可能是为了避免在模块级别导入不必要的库,从而减少初始化时间。
import pandas as pd # scope for faster 'import ultralytics'
# 使用 pandas 的 read_csv 方法读取 CSV 文件,文件路径由 self.csv 指定。 将读取的 CSV 数据转换为字典格式,其中 每一列的数据被转换为一个列表 ( orient="list" )。 返回 转换后的字典 。
return pd.read_csv(self.csv).to_dict(orient="list")
# ead_results_csv 方法的主要功能是。读取 CSV 文件:使用 pandas 库读取训练结果的 CSV 文件。文件路径由 self.csv 属性指定,通常是一个包含训练指标(如损失值、准确率等)的文件。转换为字典格式:将 CSV 数据转换为字典格式,其中每一列的数据被转换为一个列表。这种格式便于后续处理和分析训练结果。返回结果:返回转换后的字典,方便其他方法或模块使用这些数据。这种格式使得训练结果可以方便地用于绘图、分析或其他后续处理。
# 这段代码定义了 save_model 方法,用于保存训练过程中的模型检查点(checkpoint)。
# 定义了 save_model 方法,用于保存训练过程中的模型检查点。
def save_model(self):
# 使用附加元数据保存模型训练检查点。
"""Save model training checkpoints with additional metadata."""
# 导入 io 模块,用于 创建字节缓冲区 ,以便 高效地序列化模型检查点 。
import io
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
# 创建一个字节缓冲区 buffer ,用于存储序列化的检查点数据。 将检查点数据序列化到字节缓冲区中,可以避免多次调用 torch.save() ,从而提高保存效率。
buffer = io.BytesIO()
# 使用 torch.save 将模型检查点数据序列化到字节缓冲区 buffer 中。
torch.save(
{
"epoch": self.epoch,
"model": None, # resume and final checkpoints derive from EMA
"ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates,
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
"train_args": vars(self.args), # save as dict
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": self.read_results_csv(),
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
},
buffer,
)
# 从字节缓冲区中获取序列化的检查点数据,准备保存到文件中。
serialized_ckpt = buffer.getvalue() # get the serialized content to save
# Save checkpoints
# 将序列化的检查点数据保存到 self.last 指定的路径(通常是 last.pt ),表示 最新的检查点 。
self.last.write_bytes(serialized_ckpt) # save last.pt
# 如果当前轮次的性能指标( self.fitness )等于最佳性能指标( self.best_fitness ),将检查点数据保存到 self.best 指定的路径(通常是 best.pt ),表示 最佳检查点 。
if self.best_fitness == self.fitness:
self.best.write_bytes(serialized_ckpt) # save best.pt
# 如果设置了保存周期( self.save_period > 0 ),并且当前轮次是保存周期的倍数,将检查点数据保存到指定的路径(如 epoch3.pt ),表示 按周期保存的检查点 。
if (self.save_period > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
# 如果启用了关闭马赛克增强,并且当前轮次是关闭马赛克增强的前一个轮次,将检查点数据保存到 last_mosaic.pt ,表示关闭马赛克增强前的检查点。 这段代码被注释掉了,可能在某些版本中不再使用。
# if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1):
# (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt) # save mosaic checkpoint
# save_model 方法的主要功能是。序列化检查点数据。使用 torch.save 将模型检查点数据序列化到字节缓冲区中,包括当前轮次、最佳性能指标、EMA 权重、优化器状态、训练参数、训练指标和训练结果。保存检查点文件:将序列化的检查点数据保存到 last.pt ,表示最新的检查点。如果当前性能指标等于最佳性能指标,保存到 best.pt ,表示最佳检查点。如果设置了保存周期,按周期保存检查点(如 epoch3.pt )。如果启用了关闭马赛克增强,保存关闭马赛克增强前的检查点( last_mosaic.pt )。通过这些步骤, save_model 方法可以高效地保存训练过程中的关键信息,便于后续恢复训练或评估模型性能。
# 这段代码定义了 get_dataset 方法,用于根据任务类型加载和验证数据集。它支持分类任务( classify )和检测/分割/姿态估计/定向边界框( detect , segment , pose , obb )任务。
# 定义了 get_dataset 方法,该函数不接受任何参数。
def get_dataset(self):
# 如果数据字典中存在,则获取 train、val 路径。
# 如果无法识别数据格式,则返回 None。
"""
Get train, val path from data dict if it exists.
Returns None if data format is not recognized.
"""
# 使用 try 块来捕获可能发生的异常,确保在数据集加载失败时能够给出明确的错误提示。
try:
# 如果任务类型是 分类 ( classify )。
if self.args.task == "classify":
# 调用 check_cls_dataset 函数,传入 数据集路径 ( self.args.data ),并返回 验证后的数据集信息 ( data )。 check_cls_dataset 函数通常会检查数据集的格式、路径是否有效,并返回 包含训练集 和 验证集路径 的字典。
# def check_cls_dataset(dataset, split=""):
# -> 用于检查和准备分类数据集。它支持自动下载数据集、解析数据集结构,并验证数据集的完整性和一致性。返回一个字典,包含 数据集路径 、 类别数量 和 类别名称 。
# -> return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names}
data = check_cls_dataset(self.args.data)
# 如果任务类型是 检测 、 分割 、 姿态估计 或 定向边界框 ( detect , segment , pose , obb ),或者数据集文件以 .yaml 或 .yml 结尾。
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
"detect",
"segment",
"pose",
"obb",
}:
# 调用 check_det_dataset 函数,传入数据集路径( self.args.data ),并返回 验证后的数据集信息 ( data )。 check_det_dataset 函数通常会解析 YAML 文件,检查数据集的格式是否正确,并返回 包含训练集 和 验证集路径 的字典。
# def check_det_dataset(dataset, autodownload=True): -> 用于检查和准备目标检测数据集。它支持从本地路径或远程地址加载数据集,自动下载缺失的文件,并验证数据集的 YAML 配置文件是否符合要求。返回更新后的 YAML 数据(字典格式),供后续使用。 -> return data # dictionary
data = check_det_dataset(self.args.data)
# 如果返回的 data 字典中包含 yaml_file 键。
if "yaml_file" in data:
# 更新 self.args.data 为 data["yaml_file"] 的值。 这一步是为了支持从 URL 下载数据集时,验证 YAML 文件路径是否正确。
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
# 如果在加载数据集过程中发生异常。
except Exception as e:
# 捕获异常并抛出一个 RuntimeError ,提示数据集加载失败,并显示具体的错误信息。 clean_url 函数用于清理 URL 中可能包含的敏感信息, emojis 函数用于在错误信息中添加表情符号以增强可读性。
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e # 数据集‘{clean_url(self.args.data)}’错误 ❌ {e} 。
# 将 加载的数据集信息 存储在 self.data 中,供后续使用。
self.data = data
# 返回 数据集的 训练集路径 ( data["train"] )和 验证集路径 ( data.get("val") 或 data.get("test") )。 如果数据集中没有验证集路径,但有测试集路径,则返回测试集路径作为验证集路径。
return data["train"], data.get("val") or data.get("test")
# get_dataset 方法的主要功能是。根据任务类型加载数据集:对于分类任务,调用 check_cls_dataset 函数。对于检测、分割、姿态估计或定向边界框任务,调用 check_det_dataset 函数。验证数据集路径和格式:检查数据集文件是否存在,并解析 YAML 文件(如果适用)。确保数据集路径有效,并支持从 URL 下载数据集时的验证。异常处理:如果数据集加载失败,抛出明确的错误提示,帮助用户快速定位问题。返回数据集路径:返回训练集和验证集(或测试集)的路径,供后续训练和验证使用。这个方法是训练器初始化过程中的重要组成部分,确保训练器能够正确加载和验证数据集,为训练过程提供数据支持。
# 这段代码定义了 setup_model 方法,用于初始化或加载模型。它支持从预训练权重加载模型,并根据配置文件( cfg )和权重( weights )构建模型。
# 定义了 setup_model 方法,用于初始化或加载模型。
def setup_model(self):
# 为任何任务加载/创建/下载模型。
"""Load/create/download model for any task."""
# 如果 self.model 已经是一个 torch.nn.Module 实例(即模型已经加载),则直接返回,无需进一步初始化。
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
# 初始化 cfg 为 self.model ,表示 模型的配置文件路径 。 初始化 weights 为 None ,表示 模型的权重路径 。
cfg, weights = self.model, None
# 初始化 ckpt 为 None ,表示 加载的检查点 (如果有)。
ckpt = None
# 如果 self.model 是一个以 .pt 结尾的路径,表示模型权重文件。
if str(self.model).endswith(".pt"):
# 调用 attempt_load_one_weight 函数加载权重文件,返回 权重路径 ( weights )和 检查点 ( ckpt )。
# def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): -> 用于加载单个模型权重文件,并对模型进行一系列的初始化和兼容性处理。返回 处理后的模型 和 检查点数据 。 -> return model, ckpt
weights, ckpt = attempt_load_one_weight(self.model)
# 从权重文件中提取 配置文件路径 ( weights.yaml ),并将其赋值给 cfg 。
cfg = weights.yaml
# 如果 self.args.pretrained 是一个 字符串 或 路径 ,表示 预训练权重路径 。
elif isinstance(self.args.pretrained, (str, Path)):
# 调用 attempt_load_one_weight 函数加载预训练权重,返回 权重路径 ( weights )。
weights, _ = attempt_load_one_weight(self.args.pretrained)
# 调用 self.get_model 方法,根据 配置文件 ( cfg )和 权重 ( weights )构建模型。
# cfg :模型的配置文件路径。
# weights :模型的权重路径。
# verbose :是否打印详细信息,仅在主进程( RANK == -1 )中打印。
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
# 返回 加载的检查点 ( ckpt ),如果有的话。
return ckpt
# setup_model 方法的主要功能是。检查模型是否已加载:如果 self.model 已经是一个 torch.nn.Module 实例,则直接返回,无需进一步初始化。加载模型权重:如果 self.model 是一个以 .pt 结尾的路径,加载权重文件并提取配置文件路径。如果 self.args.pretrained 是一个路径,加载预训练权重。构建模型:根据配置文件( cfg )和权重( weights )构建模型。调用 self.get_model 方法,支持从配置文件和权重路径加载模型。返回检查点:返回加载的检查点( ckpt ),如果有的话。通过这些步骤, setup_model 方法可以灵活地加载模型权重,并根据配置文件构建模型,支持预训练权重的加载。
# 这段代码定义了 optimizer_step 方法,用于执行优化器的一步更新操作,包括梯度缩放、梯度裁剪、优化器更新以及指数移动平均(EMA)的更新。
# 定义了 optimizer_step 方法,用于执行优化器的一步更新操作。
def optimizer_step(self):
# 使用梯度剪辑和 EMA 更新执行训练优化器的单步。
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
# 使用梯度缩放器( GradScaler )的 unscale_ 方法对 优化器的梯度进行反缩放 。 这一步确保梯度在裁剪之前恢复到原始比例, 避免裁剪操作受到缩放的影响 。
self.scaler.unscale_(self.optimizer) # unscale gradients
# 使用 torch.nn.utils.clip_grad_norm_ 对模型参数的梯度进行裁剪 ,限制梯度的全局范数不超过 max_norm=10.0 。 梯度裁剪可以防止梯度爆炸,提高训练的稳定性。
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
# 调用梯度缩放器的 step 方法,执行 优化器的更新步骤 。 这一步会 根据缩放后的梯度更新模型参数 。
self.scaler.step(self.optimizer)
# 调用梯度缩放器的 update 方法, 更新缩放因子 。 这一步根据当前梯度的大小调整缩放因子,确保后续的梯度缩放更加合理。
self.scaler.update()
# 调用优化器的 zero_grad 方法, 清零梯度 。 这一步确保在下一个批次的训练中,梯度不会累积。
self.optimizer.zero_grad()
# 如果启用了指数移动平均(EMA),调用 self.ema.update 方法更新 EMA 模型。 EMA 模型 通过平滑模型参数的变化 ,有助于 提高模型的泛化能力 。
if self.ema:
self.ema.update(self.model)
# optimizer_step 方法的主要功能是。梯度反缩放:使用梯度缩放器对优化器的梯度进行反缩放,确保梯度在裁剪之前恢复到原始比例。梯度裁剪:对模型参数的梯度进行裁剪,限制梯度的全局范数不超过 max_norm=10.0 ,防止梯度爆炸。优化器更新:调用梯度缩放器的 step 方法,执行优化器的更新步骤,更新模型参数。更新缩放因子:调用梯度缩放器的 update 方法,根据当前梯度的大小调整缩放因子。清零梯度:调用优化器的 zero_grad 方法,清零梯度,为下一个批次的训练做准备。更新 EMA 模型:如果启用了 EMA,更新 EMA 模型,通过平滑模型参数的变化,提高模型的泛化能力。通过这些步骤, optimizer_step 方法可以高效地完成优化器的一步更新操作,同时确保训练过程的稳定性和模型的泛化能力。
# 这段代码定义了 preprocess_batch 方法,用于对训练批次数据进行预处理。然而,当前实现是一个空实现,直接返回输入的批次数据,没有进行任何处理。
# 定义了 preprocess_batch 方法,它接受一个参数。
# 1.batch :表示当前批次的数据。
def preprocess_batch(self, batch):
# 允许根据任务类型自定义预处理模型输入和基本事实。
"""Allows custom preprocessing model inputs and ground truths depending on task type."""
# 直接返回输入的批次数据 batch ,没有进行任何预处理。
return batch
# 当前实现的 preprocess_batch 方法是一个空实现,它直接返回输入的批次数据,没有进行任何预处理。这通常是一个占位符方法,可以在后续实现中根据需要添加具体的预处理逻辑,例如数据增强、归一化、调整图像尺寸等。
# 为什么需要预处理?
# 在深度学习中,数据预处理是一个非常重要的步骤,它可以显著影响模型的性能和训练效果。常见的预处理操作包括 :
# 数据增强 :通过随机变换(如旋转、裁剪、翻转等)增加数据的多样性,提高模型的泛化能力。
# 归一化 :将图像像素值归一化到 [0, 1] 或 [-1, 1] 范围,避免数值不稳定。
# 调整尺寸 :将图像调整到模型所需的输入尺寸。
# 标签处理 :将标签转换为适合模型输入的格式。
# 这段代码定义了 validate 方法,用于在验证集上评估模型的性能,并更新最佳性能指标( best_fitness )。
# 定义了 validate 方法,用于在验证集上评估模型的性能。
def validate(self):
# 使用 self.validator 对测试集运行验证。
# 返回的字典应包含“fitness”键。
"""
Runs validation on test set using self.validator.
The returned dict is expected to contain "fitness" key.
"""
# 调用 self.validator 对象的验证方法,传入当前训练器实例( self )。 self.validator 是一个 验证器对象 ,负责在验证集上评估模型的性能,并返回一个包含验证指标的字典( metrics )。
metrics = self.validator(self)
# 从 metrics 字典中提取 fitness 指标(如果存在)。
# 如果 metrics 中没有 fitness 指标,则使用当前损失值( self.loss )作为性能指标。损失值取负值并转换为 NumPy 数组,以便与性能指标保持一致的符号( 性能指标通常越高越好 ,而 损失值越低越好 )。
# self.loss.detach().cpu().numpy() 确保损失值从 GPU 转移到 CPU,并转换为 NumPy 数组。
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
# 如果当前性能指标( fitness )高于之前记录的最佳性能指标( self.best_fitness ),更新 self.best_fitness 。 如果 self.best_fitness 为 None (即尚未记录最佳性能),则直接将当前性能指标设置为最佳性能。
if not self.best_fitness or self.best_fitness < fitness:
self.best_fitness = fitness
# 返回 验证指标 ( metrics )和 当前性能指标 ( fitness )。
return metrics, fitness
# validate 方法的主要功能是。执行验证:调用验证器( self.validator )在验证集上评估模型的性能,返回验证指标( metrics )。提取性能指标:从验证指标中提取 fitness 指标,如果不存在,则使用当前损失值作为性能指标。更新最佳性能指标:如果当前性能指标高于之前记录的最佳性能指标,更新 self.best_fitness 。返回结果:返回验证指标( metrics )和当前性能指标( fitness )。通过这些步骤, validate 方法可以有效地评估模型的性能,并记录最佳性能指标,以便后续的早停机制或模型保存。
# 这段代码定义了 get_model 方法,它是一个占位符方法,用于根据配置文件( cfg )和权重( weights )加载或构建模型。然而,当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持通过配置文件加载模型。
# 定义了 get_model 方法,它接受以下参数 :
# 1.cfg :模型的配置文件路径(可选,默认为 None )。
# 2.weights :模型的权重路径(可选,默认为 None )。
# 3.verbose :是否打印详细信息(默认为 True )。
def get_model(self, cfg=None, weights=None, verbose=True):
# 获取模型并引发 NotImplementedError 以加载 cfg 文件。
"""Get model and raise NotImplementedError for loading cfg files."""
# 抛出一个 NotImplementedError 异常,表明当前任务的训练器不支持通过配置文件加载模型。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器不支持从配置文件加载模型。
raise NotImplementedError("This task trainer doesn't support loading cfg files") # 此任务训练器不支持加载 cfg 文件。
# get_model 方法的主要功能是。加载或构建模型:根据提供的配置文件( cfg )和权重( weights )加载或构建模型。这是一个占位符方法,当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持通过配置文件加载模型。打印详细信息:如果 verbose 为 True ,打印加载模型的详细信息(当前实现中未实现)。
# 这段代码定义了 get_validator 方法,它是一个占位符方法,用于获取验证器(validator)实例。然而,当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持获取验证器。
# 定义了 get_validator 方法,该方法不接受任何参数。
def get_validator(self):
# 调用 get_validator 函数时返回 NotImplementedError。
"""Returns a NotImplementedError when the get_validator function is called."""
# 抛出一个 NotImplementedError 异常,表明当前任务的训练器不支持获取验证器实例。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器尚未实现验证功能。
raise NotImplementedError("get_validator function not implemented in trainer") # get_validator 函数未在训练器中实现。
# get_validator 方法的主要功能是。获取验证器实例:返回一个验证器对象,用于在验证集上评估模型的性能。当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持获取验证器。
# 这段代码定义了 get_dataloader 方法,它是一个占位符方法,用于根据数据集路径、批量大小、进程排名和模式(训练或验证)创建数据加载器( DataLoader )。然而,当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持自定义数据加载器的创建。
# 定义了 get_dataloader 方法,它接受以下参数 :
# 1.dataset_path :数据集的路径。
# 2.batch_size :每个批次的样本数量,默认值为 16。
# 3.rank :当前进程的排名,用于分布式训练,默认值为 0。
# 4.mode :数据加载器的模式,可以是 "train" 或 "val" ,默认值为 "train" 。
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
# 返回从 torch.data.Dataloader 派生的数据加载器。
"""Returns dataloader derived from torch.data.Dataloader."""
# 抛出一个 NotImplementedError 异常,表明当前任务的训练器不支持自定义数据加载器的创建。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器尚未实现数据加载器的创建逻辑。
raise NotImplementedError("get_dataloader function not implemented in trainer") # get_dataloader 函数未在训练器中实现。
# get_dataloader 方法的主要功能是。创建数据加载器:根据提供的数据集路径、批量大小、进程排名和模式,创建一个 PyTorch 数据加载器( DataLoader )。当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持自定义数据加载器的创建。
# 这段代码定义了 build_dataset 方法,它是一个占位符方法,用于根据图像路径和模式(训练或验证)构建数据集。当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持自定义数据集的构建。
# 定义了 build_dataset 方法,它接受以下参数 :
# 1.img_path :图像路径,指向数据集的目录或文件。
# 2.mode :数据集的模式,可以是 "train" 或 "val" ,默认值为 "train" 。
# 3.batch :可选参数,表示批量大小,用于某些数据集的特定处理,默认为 None 。
def build_dataset(self, img_path, mode="train", batch=None):
# 构建数据集。
"""Build dataset."""
# 抛出一个 NotImplementedError 异常,表明当前任务的训练器不支持自定义数据集的构建。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器尚未实现数据集的构建逻辑。
raise NotImplementedError("build_dataset function not implemented in trainer") # build_dataset 函数未在训练器中实现。
# build_dataset 方法的主要功能是。构建数据集:根据提供的图像路径和模式,构建一个 PyTorch 数据集( Dataset )。当前实现抛出了一个 NotImplementedError ,表明该方法尚未实现,或者当前任务的训练器不支持自定义数据集的构建。
# 这段代码定义了 label_loss_items 方法,用于处理和标记损失项( loss_items ),以便在训练或验证过程中记录和显示损失值。
# 定义了 label_loss_items 方法,它接受以下参数 :
# 1.loss_items :损失项的值或列表,可以是 None 。
# 2.prefix :前缀,用于标记损失项的来源,例如 "train" 或 "val" ,默认值为 "train" 。
def label_loss_items(self, loss_items=None, prefix="train"):
# 返回带有标记训练损失项张量的损失字典。
# 注意:
# 这对于分类来说不是必需的,但对于分割和检测来说是必需的。
"""
Returns a loss dict with labelled training loss items tensor.
Note:
This is not needed for classification but necessary for segmentation & detection
"""
# 如果 loss_items 不为 None 。 返回一个字典,键为 "loss" ,值为 loss_items 。这表示将损失项标记为 "loss" ,并返回其值。
# 如果 loss_items 为 None 。 返回一个列表 ["loss"] ,表示损失项的名称。
return {"loss": loss_items} if loss_items is not None else ["loss"]
# abel_loss_items 方法的主要功能是。处理损失项:如果提供了具体的损失项值( loss_items ),将其标记为 "loss" 并返回。如果没有提供损失项值,则返回损失项的名称列表。支持训练和验证模式:通过 prefix 参数,可以区分训练模式( "train" )和验证模式( "val" )。 label_loss_items 方法在训练和验证过程中非常有用,特别是在记录和显示损失值时。它可以根据提供的损失项值返回标记后的损失字典,或者返回损失项的名称列表,以便在日志记录或可视化工具中使用。
# 这段代码定义了 set_model_attributes 方法,用于设置模型的某些属性。当前实现中,它将模型的 names 属性设置为数据集中的类别名称。
# 定义了 set_model_attributes 方法,用于设置模型的某些属性。
def set_model_attributes(self):
# 在训练之前设置或更新模型参数。
"""To set or update model parameters before training."""
# 将模型的 names 属性设置为数据集中的类别名称。 self.data["names"] 是一个列表,包含数据集中的所有类别名称。 这一步确保模型能够访问和使用这些类别名称,例如在日志记录、可视化或模型输出时。
self.model.names = self.data["names"]
# set_model_attributes 方法的主要功能是。设置模型的 names 属性:将模型的 names 属性设置为数据集中的类别名称。这一步确保模型能够访问和使用这些类别名称,以便在训练和验证过程中正确标识和处理不同类别。通过这些步骤,模型可以正确地访问和使用类别名称,例如在日志记录或模型输出时。
# 这段代码定义了 build_targets 方法,它是一个占位符方法,用于根据模型的预测( preds )和目标( targets )构建目标张量。当前实现中,该方法没有执行任何操作( pass ),表明它尚未实现具体的功能。
# 定义了 build_targets 方法,它接受以下参数 :
# 1.preds :模型的预测输出。
# 2.targets :目标值(通常是标签或边界框等)。
def build_targets(self, preds, targets):
# 构建用于训练 YOLO 模型的目标张量。
"""Builds target tensors for training YOLO model."""
# pass 是一个占位符语句,表示该方法没有执行任何操作。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器尚未实现目标构建逻辑。
pass
# build_targets 方法的主要功能是。构建目标张量:根据模型的预测( preds )和目标( targets ),构建目标张量,用于计算损失函数。当前实现中,该方法没有执行任何操作,表明它尚未实现具体的功能。
# 这段代码定义了 progress_string 方法,它是一个占位符方法,用于生成训练进度的字符串。当前实现中,该方法直接返回一个空字符串,表明它尚未实现具体的功能。
# 定义了 progress_string 方法,用于生成训练进度的字符串。
def progress_string(self):
# 返回描述训练进度的字符串。
"""Returns a string describing training progress."""
# 直接返回一个空字符串,表示当前实现中没有生成任何进度信息。
return ""
# progress_string 方法的主要功能是。生成训练进度字符串:返回一个描述当前训练进度的字符串,通常包括当前轮次、损失值、学习率等信息。当前实现中,该方法返回一个空字符串,表明它尚未实现具体的功能。
# TODO: may need to put these following functions into callback TODO:可能需要将以下函数放入回调中。
# 这段代码定义了 plot_training_samples 方法,它是一个占位符方法,用于在训练过程中绘制训练样本的可视化图像。当前实现中,该方法没有执行任何操作( pass ),表明它尚未实现具体的功能。
# 定义了 plot_training_samples 方法,它接受以下参数 :
# 1.batch :当前批次的数据,通常是一个字典,包含图像和标签。
# 2.ni :当前迭代的全局索引( ni = i + nb * epoch ),用于标识当前批次在训练过程中的位置。
def plot_training_samples(self, batch, ni):
# 在 YOLO 训练期间绘制训练样本。
"""Plots training samples during YOLO training."""
# pass 是一个占位符语句,表示该方法没有执行任何操作。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器尚未实现样本可视化的逻辑。
pass
# plot_training_samples 方法的主要功能是。绘制训练样本的可视化图像:根据当前批次的数据( batch )和全局索引( ni ),生成训练样本的可视化图像。 当前实现中,该方法没有执行任何操作,表明它尚未实现具体的功能。
# 这段代码定义了 plot_training_labels 方法,它是一个占位符方法,用于绘制训练数据集中标签的分布情况。当前实现中,该方法没有执行任何操作( pass ),表明它尚未实现具体的功能。
# 定义了 plot_training_labels 方法,用于绘制训练数据集中标签的分布情况。
def plot_training_labels(self):
# 为 YOLO 模型绘制训练标签。
"""Plots training labels for YOLO model."""
# pass 是一个占位符语句,表示该方法没有执行任何操作。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器尚未实现标签分布可视化的逻辑。
pass
# plot_training_labels 方法的主要功能是。绘制标签分布:根据训练数据集中的标签信息,生成标签分布的可视化图像。当前实现中,该方法没有执行任何操作,表明它尚未实现具体的功能。
# 这段代码定义了 save_metrics 方法,用于将训练过程中的指标( metrics )保存到一个 CSV 文件中。
# 定义了 save_metrics 方法,它接受一个参数。
# 1.metrics :这是一个字典,包含当前轮次的训练指标。
def save_metrics(self, metrics):
# 将训练指标保存到 CSV 文件。
"""Saves training metrics to a CSV file."""
# 将 metrics 字典的键( keys )和值( vals )分别提取为列表。 keys 是 指标的名称 , vals 是 对应的值 。
keys, vals = list(metrics.keys()), list(metrics.values())
# 计算 CSV 文件中列的数量。 len(metrics) 是 指标的数量 ,加上 2 表示额外的两列: epoch 和 time 。
n = len(metrics) + 2 # number of cols
# 如果 CSV 文件已经存在( self.csv.exists() ),则 s 为空字符串。
# 如果 CSV 文件不存在,则生成一个表头( header ),包含列名 : epoch 和 time 是固定的两列。
# keys 是指标的名称。
# 使用 ("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n" 生成表头,并确保每列之间用逗号分隔。
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
# 计算 当前轮次的训练时间 ( t ),从训练开始时间( self.train_time_start )到当前时间的差值。
t = time.time() - self.train_time_start
# 打开 CSV 文件( self.csv ),以追加模式( "a" )写入数据。
with open(self.csv, "a") as f:
# 如果文件不存在,先写入表头( s )。
# 写入当前轮次的数据 :
# self.epoch + 1 是当前轮次。
# t 是当前轮次的训练时间。
# vals 是当前轮次的指标值。
# 使用 ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n" 生成一行数据,并确保每列之间用逗号分隔。
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
# ave_metrics 方法的主要功能是。保存训练指标:将当前轮次的训练指标( metrics )保存到一个 CSV 文件中。每行数据包括当前轮次、训练时间以及指标值。生成表头:如果 CSV 文件不存在,生成表头,包含列名( epoch 、 time 和指标名称)。通过这些步骤, save_metrics 方法可以灵活地保存训练过程中的指标,便于后续分析和可视化。
# 这段代码定义了 plot_metrics 方法,它是一个占位符方法,用于绘制训练过程中的关键指标(如损失值、准确率等)。当前实现中,该方法没有执行任何操作( pass ),表明它尚未实现具体的功能。
# 定义了 plot_metrics 方法,用于绘制训练过程中的关键指标。
def plot_metrics(self):
# 以可视化方式绘制和显示指标。
"""Plot and display metrics visually."""
# pass 是一个占位符语句,表示该方法没有执行任何操作。 这通常意味着该方法需要根据具体任务进行实现,或者当前任务的训练器尚未实现指标可视化的逻辑。
pass
# plot_metrics 方法的主要功能是。绘制训练指标:根据保存的训练指标(如损失值、准确率等),生成可视化图表。当前实现中,该方法没有执行任何操作,表明它尚未实现具体的功能。
# 这段代码定义了 on_plot 方法,用于记录和存储绘图数据及其时间戳。该方法可以用于在训练过程中动态记录绘图所需的参数和数据。
# 定义了 on_plot 方法,它接受以下参数 :
# 1.name :绘图的名称或路径,用于标识绘图。
# 2.data :可选参数,表示绘图所需的数据,默认为 None 。
def on_plot(self, name, data=None):
# 注册图表(例如在回调中使用)。
"""Registers plots (e.g. to be consumed in callbacks)."""
# 使用 Path 类(来自 pathlib 模块)将 name 转换为路径对象。这使得路径操作更加灵活和安全。
path = Path(name)
# 将 绘图数据 及 其时间戳 存储在 self.plots 字典中。
# 键为 路径对象 ( path ),表示绘图的唯一标识。
# 值为一个字典,包含两个键。
# "data" :存储绘图所需的数据。
# "timestamp" :记录当前时间的时间戳( time.time() ),用于后续的时间管理或排序。
self.plots[path] = {"data": data, "timestamp": time.time()}
# on_plot 方法的主要功能是。记录绘图数据:将绘图的名称或路径( name )转换为路径对象( path ),并将其作为键存储在 self.plots 字典中。将绘图所需的数据( data )和当前时间的时间戳( timestamp )存储在字典中。支持动态绘图:通过记录绘图数据及其时间戳,可以在训练过程中动态生成和更新绘图。通过这些步骤, on_plot 方法可以灵活地记录绘图数据及其时间戳,支持在训练过程中动态生成和更新绘图。
# 这段代码定义了 final_eval 方法,用于在训练完成后对模型进行最终验证。它会加载最新的检查点( last.pt )和最佳检查点( best.pt ),并更新最佳检查点的训练指标。
# 定义了 final_eval 方法,用于在训练完成后对模型进行最终验证。
def final_eval(self):
# 对对象检测 YOLO 模型进行最终评估和验证。
"""Performs final evaluation and validation for object detection YOLO model."""
# 初始化一个空字典 ckpt ,用于 存储检查点信息 。
ckpt = {}
# 遍历两个检查点文件。 self.last (最新的检查点)和 self.best (最佳检查点)。
for f in self.last, self.best:
# 检查当前检查点文件是否存在。
if f.exists():
# 如果当前文件是 self.last (最新的检查点),调用 strip_optimizer 函数加载检查点,并移除优化器状态。 将加载的检查点信息存储在 ckpt 中。
if f is self.last:
# def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
# -> 用于清理和优化 PyTorch 模型检查点文件(checkpoint)。它的主要功能是从检查点中移除不必要的信息(如优化器状态、EMA 状态等),并将模型转换为 FP16 格式,同时保留模型的元数据和训练参数。返回合并后的字典 combined ,包含清理后的检查点内容和元数据。 这使得函数的调用者可以获取清理后的检查点内容,方便后续使用。
# -> return combined
ckpt = strip_optimizer(f)
# 如果当前文件是 self.best (最佳检查点),更新其训练指标。
elif f is self.best:
# 指定要更新的键。
k = "train_results" # update best.pt train_metrics from last.pt
# 如果 ckpt 中存在 k ,将 ckpt[k] 的值更新到最佳检查点中。 调用 strip_optimizer 函数加载最佳检查点,并应用更新。
strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
# 记录日志,提示正在验证当前检查点文件。
LOGGER.info(f"\nValidating {f}...") # 正在验证 {f}...
# 将验 证器的绘图参数 设置为与 训练器相同的值 。
self.validator.args.plots = self.args.plots
# 调用验证器对当前模型进行验证,并获取 验证指标 。
self.metrics = self.validator(model=f)
# 从验证指标中移除 fitness 键(如果存在),因为 fitness 通常用于早停机制,而不是最终验证。
self.metrics.pop("fitness", None)
# 运行训练结束时的回调函数,允许用户在验证完成后执行自定义逻辑。
self.run_callbacks("on_fit_epoch_end")
# final_eval 方法的主要功能是。加载检查点:加载最新的检查点( last.pt )和最佳检查点( best.pt )。从最新的检查点中提取训练指标,并更新最佳检查点的训练指标。验证模型:对每个检查点进行验证,获取验证指标。记录验证过程的日志信息。运行回调函数:在验证完成后运行训练结束时的回调函数,允许用户执行自定义逻辑,例如保存验证结果或清理资源。通过这些步骤, final_eval 方法可以有效地完成最终验证,并确保训练过程的完整性和灵活性。
# 这段代码定义了 check_resume 方法,用于检查是否需要从一个检查点(checkpoint)恢复训练,并更新相关配置。
# 定义了 check_resume 方法,它接受一个参数。
# 1.overrides :这是一个字典,用于覆盖默认的配置参数。
def check_resume(self, overrides):
# 检查恢复检查点是否存在并相应地更新参数。
"""Check if resume checkpoint exists and update arguments accordingly."""
# 从 self.args 中获取 resume 属性的值。 resume 是一个 布尔值 或 路径 ,表示 是否需要从检查点恢复训练 。
resume = self.args.resume
# 如果 resume 为 True 或是一个有效的路径,则进入恢复训练的逻辑。
if resume:
# 使用 try 块来捕获可能发生的异常,确保在恢复失败时能够给出明确的错误提示。
try:
# 检查 resume 是否是一个 字符串 或 Path 对象,并且对应的文件是否存在。如果存在, exists 为 True 。
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
# 如果 exists 为 True ,调用 check_file 函数验证并规范化路径。
# 如果 exists 为 False ,调用 get_latest_run 函数获取最近一次运行的检查点路径。
# 将结果存储在 last 中, last 是一个 Path 对象,表示 恢复训练的检查点路径 。
# def check_file(file, suffix="", download=True, download_dir=".", hard=True):
# -> 用于检查文件是否存在,如果不存在则尝试下载文件,并返回文件的路径。直接返回 file 。返回下载后的文件路径,将其转换为字符串形式。如果找到文件,返回第一个匹配的文件路径。 如果未找到文件,返回空列表 [] 。
# -> return file / return str(file) / return files[0] if len(files) else [] # return file
# def get_latest_run(search_dir="."): -> 用于在指定目录中查找最新的 last.pt 文件,通常用于恢复训练模型。 如果 last_list 不为空(即找到了匹配的文件),使用 max 函数找到其中“最新”的文件。如果 last_list 为空(即没有找到匹配的文件),返回空字符串 "" 。 -> return max(last_list, key=os.path.getctime) if last_list else ""
last = Path(check_file(resume) if exists else get_latest_run())
# Check that resume data YAML exists, otherwise strip to force re-download of dataset 检查恢复数据 YAML 是否存在,否则删除以强制重新下载数据集。
# 调用 attempt_load_weights 函数加载检查点文件,并从中提取配置参数( args )。这些参数存储在 ckpt_args 中,用于 后续更新配置 。
# def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# -> 用于加载模型权重,支持加载单个模型或多个模型组成的集成模型(ensemble)。如果 ensemble 中只有一个模型,则直接返回该模型。返回最终的模型集合 ensemble 。
# -> return ensemble[-1] / return ensemble
ckpt_args = attempt_load_weights(last).args
# 检查 检查点中记录的 数据集路径 是否存在。
if not Path(ckpt_args["data"]).exists():
# 如果不存在,使用 当前配置中的数据集路径 替换检查点中的路径。这确保了即使检查点中的数据集路径无效,训练也能正常进行。
ckpt_args["data"] = self.args.data
# 将 resume 设置为 True ,表示 确实需要恢复训练 。
resume = True
# 使用 get_cfg 函数 将检查点中的配置参数 更新到 self.args 中,确保训练器的配置与检查点一致。
# def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None):
# -> 用于处理和验证配置信息,最终返回一个配置对象。将最终的配置字典 cfg 转换为 IterableSimpleNamespace 对象并返回。 IterableSimpleNamespace 是一个可迭代的命名空间对象,支持通过点符号访问属性(如 cfg.name ),同时也支持字典操作(如 cfg["name"] )。
# -> return IterableSimpleNamespace(**cfg)
self.args = get_cfg(ckpt_args)
# 将 self.args.model 和 self.args.resume 更新为 检查点路径 ( last 的字符串形式)。这确保了模型路径和恢复路径一致。
self.args.model = self.args.resume = str(last) # reinstate model
# 允许通过 overrides 参数覆盖某些配置项。
# 支持覆盖的配置项包括 imgsz (图像大小)、 batch (批量大小)、 device (设备)、 close_mosaic (是否关闭马赛克增强)。
for k in (
"imgsz",
"batch",
"device",
"close_mosaic",
): # allow arg updates to reduce memory or update device on resume
# 如果 overrides 中包含这些键,则使用 setattr 更新 self.args 中的对应值。
if k in overrides:
setattr(self.args, k, overrides[k])
# 如果在恢复过程中发生异常(如检查点文件不存在),捕获异常并抛出一个 FileNotFoundError ,提示用户提供有效的检查点路径。
except Exception as e:
raise FileNotFoundError(
"Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
"i.e. 'yolo train resume model=path/to/last.pt'" # 未找到恢复检查点。请传递有效的检查点以从中恢复,即“yolo train resume model=path/to/last.pt”。
) from e
# 将 self.resume 设置为最终的恢复标志( True 或 False ),表示 是否需要恢复训练 。
self.resume = resume
# check_resume 方法的主要功能是。检查恢复标志:根据 self.args.resume 判断是否需要从检查点恢复训练。验证检查点路径:确保提供的检查点路径有效,或者自动获取最近一次运行的检查点。更新配置:从检查点中加载配置参数,并根据需要更新当前配置。支持覆盖:允许通过 overrides 参数覆盖某些配置项,例如批量大小或设备。异常处理:如果恢复失败(如检查点不存在),抛出明确的错误提示。这个方法在训练器初始化时被调用,确保训练器能够正确地从检查点恢复训练,同时支持灵活的配置覆盖。
# 这段代码定义了 resume_training 方法,用于从检查点( ckpt )恢复训练。它加载检查点中的模型权重、优化器状态、EMA 状态,并设置训练的起始轮次。
# 定义了 resume_training 方法,它接受一个参数。
# 1.ckpt :表示检查点文件的内容。
def resume_training(self, ckpt):
# 从给定的时期和最佳适应度恢复 YOLO 训练。
"""Resume YOLO training from given epoch and best fitness."""
# 如果检查点 ckpt 为 None 或者 self.resume 为 False ,直接返回,不执行恢复操作。
if ckpt is None or not self.resume:
return
# 初始化 best_fitness 为 0.0,表示 最佳性能指标 。
best_fitness = 0.0
# 从检查点中获取 当前轮次 ( epoch ),默认值为 -1 。 将起始轮次设置为当前轮次加 1,即 从下一个轮次开始训练 。
start_epoch = ckpt.get("epoch", -1) + 1
# 如果检查点中 包含优化器状态 ( optimizer ),加载优化器状态。
if ckpt.get("optimizer", None) is not None:
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
# 更新 最佳性能指标 ( best_fitness )。
best_fitness = ckpt["best_fitness"]
# 如果启用了 EMA(指数移动平均)并且检查点中包含 EMA 状态( ema )。
if self.ema and ckpt.get("ema"):
# 加载 EMA 模型的权重。
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
# 更新 EMA 的更新次数( updates )。
self.ema.updates = ckpt["updates"]
# 断言起始轮次大于 0,确保训练尚未完成。 如果训练已经完成( start_epoch <= 0 ),抛出错误并提示用户重新开始训练。
assert start_epoch > 0, (
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n" # {self.args.model} 训练 {self.epochs} 个时期已完成,无需恢复。
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" # 开始新的训练而不恢复,即'yolo train model={self.args.model}'。
)
# 记录日志,提示用户正在从指定的轮次恢复训练。
LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs") # 恢复训练 {self.args.model},从第 {start_epoch + 1} 个周期到第 {self.epochs} 个周期。
# 如果总轮次( self.epochs )小于起始轮次( start_epoch ),说明模型已经训练了一部分轮次。
if self.epochs < start_epoch:
LOGGER.info(
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." # {self.model} 已训练了 {ckpt['epoch']} 个时期。 正在对 {self.epochs} 个时期进行微调。
)
# 更新总轮次,将额外的微调轮次加到总轮次上。
self.epochs += ckpt["epoch"] # finetune additional epochs
# 更新 最佳性能指标 ( best_fitness )和 起始轮次 ( start_epoch )。
self.best_fitness = best_fitness
self.start_epoch = start_epoch
# 如果起始轮次大于 关闭马赛克增强的轮次 ( self.epochs - self.args.close_mosaic ),调用 _close_dataloader_mosaic 方法关闭马赛克增强。
if start_epoch > (self.epochs - self.args.close_mosaic):
self._close_dataloader_mosaic()
# resume_training 方法的主要功能是。加载检查点:从检查点中加载模型权重、优化器状态和 EMA 状态。更新最佳性能指标( best_fitness )和起始轮次( start_epoch )。恢复训练:确保训练尚未完成,并从指定的轮次恢复训练。如果模型已经训练了一部分轮次,更新总轮次以进行微调。关闭马赛克增强:如果起始轮次大于关闭马赛克增强的轮次,关闭马赛克增强。
# 示例 :
# 假设 :检查点文件包含以下内容 :
# ckpt = {
# "epoch": 10,
# "best_fitness": 0.85,
# "optimizer": optimizer_state_dict,
# "ema": ema_state_dict,
# "updates": 100
# }
# 当前总轮次为 50,关闭马赛克增强的轮次为 20。调用 resume_training 方法后 :
# 加载优化器状态和 EMA 状态。
# 更新最佳性能指标为 0.85。
# 设置起始轮次为 11。
# 如果总轮次小于起始轮次,更新总轮次为 60(额外微调 10 轮)。
# 如果起始轮次大于 30(50 - 20),关闭马赛克增强。
# 通过这些步骤, resume_training 方法可以有效地从检查点恢复训练,确保训练过程的连续性和灵活性。
# 这段代码定义了 _close_dataloader_mosaic 方法,用于关闭训练数据加载器中的马赛克增强(Mosaic Augmentation)。马赛克增强是一种数据增强技术,通过将多个图像拼接在一起形成一个训练样本,从而增加数据的多样性。关闭马赛克增强通常在训练的后期阶段进行,以减少数据增强的复杂性并提高模型的泛化能力。
# 定义了 _close_dataloader_mosaic 方法,用于关闭训练数据加载器中的马赛克增强。
def _close_dataloader_mosaic(self):
# 更新数据加载器以停止使用马赛克增强。
"""Update dataloaders to stop using mosaic augmentation."""
# 检查训练数据集( self.train_loader.dataset )是否有一个名为 mosaic 的属性。 如果存在该属性,将其设置为 False ,关闭马赛克增强。
if hasattr(self.train_loader.dataset, "mosaic"):
self.train_loader.dataset.mosaic = False
# 检查训练数据集是否有一个名为 close_mosaic 的方法。
if hasattr(self.train_loader.dataset, "close_mosaic"):
LOGGER.info("Closing dataloader mosaic") # 关闭数据加载器马赛克。
# 如果存在该方法,调用 close_mosaic 方法关闭马赛克增强。传递 hyp=copy(self.args) 作为参数, hyp 是 超参数字典的副本 ,用于 在关闭马赛克增强时传递必要的配置 。
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
# _close_dataloader_mosaic 方法的主要功能是。关闭马赛克增强:检查训练数据集是否支持马赛克增强,并将其关闭。如果数据集提供了 close_mosaic 方法,调用该方法以更优雅的方式关闭马赛克增强。记录日志:记录关闭马赛克增强的日志信息,以便用户了解训练过程中的变化。通过这些步骤, _close_dataloader_mosaic 方法可以有效地关闭马赛克增强,减少训练后期的数据增强复杂性,从而提高模型的泛化能力。
# 这段代码定义了 build_optimizer 方法,用于根据模型和配置参数构建优化器。它支持多种优化器类型,并根据模型的参数分组来设置不同的权重衰减策略。
# 定义了 build_optimizer 方法,它接受以下参数 :
# 1.model :模型实例。
# 2.name :优化器的名称,默认为 "auto" ,表示自动选择优化器。
# 3.lr :初始学习率,默认为 0.001 。
# 4.momentum :动量参数,默认为 0.9 。
# 5.decay :权重衰减参数,默认为 1e-5 。
# 6.iterations :总迭代次数,默认为 1e5 。
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
# 根据指定的优化器名称、学习率、动量、权重衰减和迭代次数,为给定模型构建优化器。
"""
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
weight decay, and number of iterations.
Args:
model (torch.nn.Module): The model for which to build an optimizer.
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
based on the number of iterations. Default: 'auto'.
lr (float, optional): The learning rate for the optimizer. Default: 0.001.
momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
iterations (float, optional): The number of iterations, which determines the optimizer if
name is 'auto'. Default: 1e5.
Returns:
(torch.optim.Optimizer): The constructed optimizer.
"""
# 这段代码定义了优化器参数分组和自动选择优化器的逻辑。
# 初始化三个参数组 g ,用于存储不同类型的参数。
# g[0] :权重参数(应用权重衰减)。
# g[1] :归一化层的权重参数(不应用权重衰减)。
# g[2] :偏置参数(不应用权重衰减)。
g = [], [], [] # optimizer parameter groups
# 遍历 torch.nn 模块的所有类,筛选出名称中包含 "Norm" 的类(如 BatchNorm2d 、 LayerNorm 等)。 将这些归一化层的类存储在元组 bn 中,用于 后续判断参数是否属于归一化层 。
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
# 检查优化器名称是否为 "auto" 。如果是,则自动选择优化器类型、学习率和动量。
if name == "auto":
# 记录日志,提示用户优化器名称为 "auto" ,将自动选择优化器类型、学习率和动量。 忽略用户指定的初始学习率( self.args.lr0 )和动量( self.args.momentum )。
LOGGER.info(
f"{colorstr('optimizer:')} 'optimizer=auto' found, " # {colorstr('optimizer:')} 发现'optimizer=auto',
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and " # 忽略'lr0={self.args.lr0}'和'momentum={self.args.momentum}'
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... " # 并自动确定最佳'optimizer'、'lr0'和'momentum'...
)
# 从数据集配置中获取 类别数量 ( nc ),默认值为 10。 这个值用于 后续计算自动选择的学习率 。
nc = self.data.get("nc", 10) # number of classes
# 根据 类别数量 ( nc )计算 自动选择的学习率 ( lr_fit )。 使用公式 0.002 * 5 / (4 + nc) 计算学习率。 将结果四舍五入到小数点后 6 位。 这个公式是一个经验公式,用于 根据类别数量调整学习率 。
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
# 根据 总迭代次数 ( iterations )选择优化器类型、学习率和动量。
# 如果迭代次数大于 10000,使用 SGD 优化器,学习率为 0.01 ,动量为 0.9 。
# 否则,使用 AdamW 优化器,学习率为 lr_fit ,动量为 0.9 。
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
# 设置偏置参数的预热学习率为 0.0 ,确保在使用 Adam 优化器时,偏置参数的学习率不会过高(不超过 0.01 )。
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
# 这段代码的主要功能是。初始化参数组:初始化三个参数组 g ,用于存储不同类型的参数(权重、归一化层权重、偏置)。筛选归一化层:筛选出所有归一化层的类(如 BatchNorm2d ),用于后续判断参数是否属于归一化层。自动选择优化器:如果优化器名称为 "auto" ,根据类别数量和迭代次数自动选择优化器类型、学习率和动量。使用经验公式计算学习率,并根据迭代次数选择合适的优化器。设置偏置参数的预热学习率:确保在使用 Adam 优化器时,偏置参数的学习率不会过高。通过这些步骤, build_optimizer 方法可以灵活地构建优化器,并根据模型的参数类型设置不同的权重衰减策略,从而提高训练的效率和稳定性。
# 这段代码的功能是遍历模型的所有模块和参数,并根据参数的类型将它们分配到不同的参数组中。这些参数组将用于优化器的配置,以便对不同类型的参数应用不同的权重衰减策略。
# 遍历 模型的所有模块 (包括子模块), model.named_modules() 返回一个生成器,包含模块的名称和模块实例。
for module_name, module in model.named_modules():
# 遍历 当前模块的所有参数 。 module.named_parameters(recurse=False) 返回模块自身的参数,不递归遍历子模块。
for param_name, param in module.named_parameters(recurse=False):
# 构造参数的完整名称。如果模块有名称,则完整名称为 module_name.param_name ;否则,直接使用 param_name 。
fullname = f"{module_name}.{param_name}" if module_name else param_name
# 如果参数名称中包含 "bias" ,表示这是一个偏置参数,将其添加到参数组 g[2] 中。偏置参数通常不应用权重衰减。
if "bias" in fullname: # bias (no decay)
g[2].append(param)
# 如果当前模块是归一化层(如 BatchNorm2d ),将其权重参数添加到参数组 g[1] 中。归一化层的权重参数通常也不应用权重衰减。
elif isinstance(module, bn): # weight (no decay)
g[1].append(param)
# 如果参数不属于上述两种情况,则认为它是普通权重参数,将其添加到参数组 g[0] 中。普通权重参数通常应用权重衰减。
else: # weight (with decay)
g[0].append(param)
# 这段代码的主要功能是。遍历模型的所有模块和参数:使用 model.named_modules() 和 module.named_parameters(recurse=False) 遍历模型的所有模块和参数。根据参数类型分组:将偏置参数( bias )添加到参数组 g[2] 中,这些参数不应用权重衰减。将归一化层的权重参数添加到参数组 g[1] 中,这些参数也不应用权重衰减。将普通权重参数添加到参数组 g[0] 中,这些参数应用权重衰减。通过这些步骤,可以将模型的参数分为三组,以便在优化器中对不同类型的参数应用不同的权重衰减策略,从而提高训练的效率和稳定性。
# 示例 :
# 假设模型包含以下模块和参数 :
# Conv2d 层,包含权重参数 weight 和偏置参数 bias 。
# BatchNorm2d 层,包含权重参数 weight 和偏置参数 bias 。则 :
# Conv2d.weight 将被添加到 g[0] (应用权重衰减)。
# Conv2d.bias 和 BatchNorm2d.bias 将被添加到 g[2] (不应用权重衰减)。
# BatchNorm2d.weight 将被添加到 g[1] (不应用权重衰减)。
# 通过这种分组方式,优化器可以对不同类型的参数应用不同的权重衰减策略,从而提高训练的效率和稳定性。
# 这段代码的功能是根据指定的优化器名称创建优化器实例,并设置其参数。
# 定义了一个集合 optimizers ,包含 支持的优化器名称 。这些优化器包括常见的优化算法,如 Adam、Adamax、AdamW、NAdam、RAdam、RMSProp 和 SGD,以及一个特殊值 "auto" ,用于自动选择优化器。
optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
# 创建一个字典,将所有支持的优化器名称转换为小写,并映射回原始名称。 使用 name.lower() 将输入的优化器名称转换为小写,然后从字典中获取对应的原始名称。 这一步确保优化器名称的大小写不敏感。
name = {x.lower(): x for x in optimizers}.get(name.lower())
# 检查优化器名称是否属于 Adam 系列优化器(包括 Adam、Adamax、AdamW、NAdam 和 RAdam)。
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
# 使用 getattr 动态获取优化器类。如果指定的优化器名称无效,则默认使用 optim.Adam 。
# 创建优化器实例,初始化时使用 偏置参数 ( g[2] )。
# 设置学习率( lr )、动量( momentum )和权重衰减( weight_decay=0.0 )。
# Adam 系列优化器通常使用两个动量参数( betas=(momentum, 0.999) ),其中 0.999 是第二个动量参数的默认值。
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
# 检查优化器名称是否为 "RMSProp" 。
elif name == "RMSProp":
# 创建 RMSProp 优化器实例,初始化时使用 偏置参数 ( g[2] )。 设置学习率( lr )和动量( momentum )。
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
# 检查优化器名称是否为 "SGD" 。
elif name == "SGD":
# 创建 SGD 优化器实例,初始化时使用 偏置参数 ( g[2] )。 设置学习率( lr )、动量( momentum )并启用 Nesterov 动量( nesterov=True )。
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
# 如果优化器名称不属于上述任何一种。
else:
# 抛出 NotImplementedError 异常,提示用户指定的优化器名称无效。 建议用户在 Ultralytics GitHub 请求支持更多优化器。
raise NotImplementedError(
f"Optimizer '{name}' not found in list of available optimizers {optimizers}. " # 可用优化器 {optimizers} 列表中未找到优化器“{name}”。
"Request support for addition optimizers at https://github.com/ultralytics/ultralytics." # 在 https://github.com/ultralytics/ultralytics 请求对附加优化器的支持。
)
# 这段代码的主要功能是。支持多种优化器:定义了一个集合 optimizers ,包含支持的优化器名称。支持常见的优化器,如 Adam、Adamax、AdamW、NAdam、RAdam、RMSProp 和 SGD。动态创建优化器实例:根据指定的优化器名称动态创建优化器实例。设置优化器的参数,如学习率、动量和权重衰减。处理无效优化器名称:如果指定的优化器名称无效,抛出 NotImplementedError 异常,并提示用户请求支持更多优化器。通过这些步骤,代码可以灵活地创建不同类型的优化器实例,并根据用户的需求进行配置。
# 这段代码的功能是将不同类型的参数组添加到优化器中,并记录优化器的配置信息。
# 将 权重参数 ( g[0] )添加到优化器中,并设 置权重衰减 ( weight_decay )。 这些参数通常 需要应用权重衰减以防止过拟合 。
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
# 将 归一化层的权重参数 ( g[1] )添加到优化器中,并设置权重衰减为 0.0 。 归一化层的权重通常不应用权重衰减,因为它们的更新方式与普通权重不同。
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
# 记录优化器的配置信息,包括优化器类型、学习率、动量以及每个参数组的大小和权重衰减设置。 使用 colorstr 函数为日志信息添加颜色,使其在控制台中更易于识别。
LOGGER.info(
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)" # {colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, motivation={momentum}) 带参数组 {len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)。
)
# 返回 配置好的优化器实例 。
return optimizer
# 这段代码的主要功能是。添加参数组:将不同类型的参数(权重、归一化层权重、偏置)添加到优化器中。为权重参数设置权重衰减( weight_decay ),而归一化层的权重和偏置参数不应用权重衰减。记录优化器配置:记录优化器的类型、学习率、动量以及每个参数组的大小和权重衰减设置。这些信息有助于调试和验证优化器的配置是否符合预期。通过这些步骤,代码可以灵活地配置优化器,并确保不同类型的参数得到适当的处理。这对于提高训练的效率和稳定性至关重要。
# build_optimizer 方法的主要功能是。自动选择优化器:如果优化器名称为 "auto" ,根据模型的类别数量和迭代次数自动选择优化器类型、学习率和动量。参数分组:将模型的参数分为三组:偏置参数(不应用权重衰减)。归一化层的权重参数(不应用权重衰减)。其他权重参数(应用权重衰减)。创建优化器:根据指定的优化器类型创建优化器实例,并设置学习率、动量和权重衰减。记录配置信息:记录优化器的配置信息,包括优化器类型、学习率、动量和参数组的数量。通过这些步骤, build_optimizer 方法可以灵活地构建优化器,并根据模型的参数类型设置不同的权重衰减策略,从而提高训练的效率和稳定性。
# 总轮次(epoch)和总迭代次数(iterations)是深度学习训练过程中的两个重要概念,它们描述了训练的不同方面 :
# 总轮次(epoch) :
# 总轮次指的是在整个训练过程中,模型将完整地遍历整个训练数据集的次数。
# 每个轮次中,模型会看到训练数据集中的所有样本一次。
# 总轮次通常由用户根据训练需求和模型收敛情况设定,例如 self.epochs 。
# 总迭代次数(iterations) :
# 总迭代次数指的是在整个训练过程中,模型参数更新的总次数。
# 每个迭代中,模型会看到训练数据集中的一个批次(batch)的样本,并根据该批次的损失值更新参数。
# 总迭代次数可以通过以下公式计算 :
# 总迭代次数 = 总轮次 × 每个轮次的迭代次数
# 其中,每个轮次的迭代次数等于训练数据集的样本总数除以每个批次的样本数(即 len(self.train_loader) )。
# 在 BaseTrainer 类中,总轮次(epoch)和总迭代次数(iterations)的计算和使用如下 :
# 总轮次(epoch) :通过 self.epochs 属性设定,表示模型将完整地遍历整个训练数据集的次数。
# 总迭代次数(iterations) :通过以下公式计算 : 总迭代次数 = 总轮次 × 每个轮次的迭代次数 。
# 其中,每个轮次的迭代次数等于训练数据集的样本总数除以每个批次的样本数(即 len(self.train_loader) )。
# 在某些情况下,总迭代次数可能根据训练时间限制动态调整,例如在 self.args.time 指定了训练时间限制时。
# 通过这些计算, BaseTrainer 类可以灵活地管理训练过程中的轮次和迭代,确保训练的高效性和可扩展性。
# BaseTrainer 类是一个基础训练器类,封装了深度学习模型训练的通用逻辑和功能。它提供了从数据加载、模型初始化、优化器和学习率调度器配置,到训练循环、验证、早停机制以及模型保存的完整流程。通过灵活的回调机制和模块化设计, BaseTrainer 支持用户自定义训练过程中的各个阶段,同时提供了自动混合精度(AMP)、分布式训练(DDP)和动态批量大小调整等高级功能。此外,它还支持训练过程中的日志记录、指标可视化和资源管理,确保训练过程的高效性和可扩展性。