Skip to content

【Hackathon 8th No.37】Add YOLO11 for PaddleYOLO #259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions configs/yolo11/_base_/optimizer_600e.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
epoch: 600

LearningRate:
base_lr: 0.01
schedulers:
- !YOLOv5LRDecay
max_epochs: 600
min_lr_ratio: 0.01
- !ExpWarmup
epochs: 3

OptimizerBuilder:
optimizer:
type: Momentum
momentum: 0.937
use_nesterov: True
regularizer:
factor: 0.0005
type: L2
clip_grad_by_value: 10.
20 changes: 20 additions & 0 deletions configs/yolo11/_base_/optimizer_600e_high.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
epoch: 600

LearningRate:
base_lr: 0.01
schedulers:
- !YOLOv5LRDecay
max_epochs: 600
min_lr_ratio: 0.1
- !ExpWarmup
epochs: 3

OptimizerBuilder:
optimizer:
type: Momentum
momentum: 0.937
use_nesterov: True
regularizer:
factor: 0.0005
type: L2
clip_grad_by_value: 10.
35 changes: 35 additions & 0 deletions configs/yolo11/_base_/yolo11_cspdarknet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
architecture: YOLO11
norm_type: sync_bn
use_ema: True
ema_decay: 0.9999
ema_decay_type: "exponential"
act: silu
find_unused_parameters: True

depth_mult: 1.0 # default: L version
width_mult: 1.0
max_channels: 1024

YOLO11:
backbone: YOLO11CSPDarkNet
neck: YOLO11CSPPAN
yolo_head: YOLO11Head
post_process: ~

YOLO11CSPDarkNet:
return_idx: [2, 3, 4]

YOLO11Head:
fpn_strides: [8, 16, 32]
loss_weight: {class: 0.5, iou: 7.5, dfl: 1.5}
assigner:
name: TaskAlignedAssigner
topk: 10
alpha: 0.5
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 3000
keep_top_k: 300
score_threshold: 0.001
nms_threshold: 0.7
45 changes: 45 additions & 0 deletions configs/yolo11/_base_/yolo11_reader.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
input_height: &input_height 640
input_width: &input_width 640
input_size: &input_size [*input_height, *input_width]
mosaic_epoch: &mosaic_epoch 490 # last 10 epochs close mosaic, totally 500 epochs as default

worker_num: 4
TrainReader:
sample_transforms:
- Decode: {}
- MosaicPerspective: {mosaic_prob: 1.0, boxes_normed: False, target_size: *input_size}
- RandomHSV: {hgain: 0.015, sgain: 0.7, vgain: 0.4}
- RandomFlip: {}
batch_transforms:
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
- PadGT: {}
batch_size: 8
shuffle: True
drop_last: False
use_shared_memory: True
collate_batch: True
mosaic_epoch: *mosaic_epoch


EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *input_size, keep_ratio: True, interp: 1}
- Pad: {size: *input_size, fill_value: [114., 114., 114.]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 8


TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: *input_size, keep_ratio: True, interp: 1}
- Pad: {size: *input_size, fill_value: [114., 114., 114.]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
fuse_normalize: False
45 changes: 45 additions & 0 deletions configs/yolo11/_base_/yolo11_reader_high_aug.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
input_height: &input_height 640
input_width: &input_width 640
input_size: &input_size [*input_height, *input_width]
mosaic_epoch: &mosaic_epoch 490 # last 10 epochs close mosaic, totally 500 epochs as default

worker_num: 4
TrainReader:
sample_transforms:
- Decode: {}
- MosaicPerspective: {mosaic_prob: 1.0, boxes_normed: False, target_size: *input_size, scale: 0.9, mixup_prob: 0.1, copy_paste_prob: 0.1}
- RandomHSV: {hgain: 0.015, sgain: 0.7, vgain: 0.4}
- RandomFlip: {}
batch_transforms:
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
- PadGT: {}
batch_size: 8
shuffle: True
drop_last: False
use_shared_memory: True
collate_batch: True
mosaic_epoch: *mosaic_epoch


EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *input_size, keep_ratio: True, interp: 1}
- Pad: {size: *input_size, fill_value: [114., 114., 114.]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1


TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: *input_size, keep_ratio: True, interp: 1}
- Pad: {size: *input_size, fill_value: [114., 114., 114.]}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
fuse_normalize: False
18 changes: 18 additions & 0 deletions configs/yolo11/yolo11_n_600e_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_600e.yml',
'_base_/yolo11_cspdarknet.yml',
'_base_/yolo11_reader.yml',
]

depth_mult: 0.5
width_mult: 0.25
max_channels: 1024

log_iter: 50
snapshot_epoch: 1
weights: output/yolo11_n_600e_coco/model_final

TrainReader:
batch_size: 128 # default 1 gpu, total bs = 128
4 changes: 3 additions & 1 deletion ppdet/modeling/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from . import rtmdet
from . import detr
from . import yolov10
from . import yolo11

from .meta_arch import *
from .yolo import *
Expand All @@ -36,4 +37,5 @@
from .yolov8 import *
from .rtmdet import *
from .detr import *
from .yolov10 import *
from .yolov10 import *
from .yolo11 import *
114 changes: 114 additions & 0 deletions ppdet/modeling/architectures/yolo11.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register, create
from .meta_arch import BaseArch

__all__ = ["YOLO11"]


@register
class YOLO11(BaseArch):
__category__ = "architecture"
__inject__ = ["post_process"]
__shared__ = ["with_mask"]

def __init__(
self,
backbone="YOLO11CSPDarkNet",
neck="YOLO11CSPPAN",
yolo_head="YOLOv8Head",
post_process="BBoxPostProcess",
with_mask=False,
for_mot=False,
):
"""
YOLOv11

Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck instance
yolo_head (nn.Layer): anchor_head instance
for_mot (bool): whether return other features for multi-object tracking
models, default False in pure object detection models.
"""
super().__init__()
self.backbone = backbone
self.neck = neck
self.yolo_head = yolo_head
self.post_process = post_process
self.for_mot = for_mot
self.with_mask = with_mask

@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg["backbone"])

# fpn
kwargs = {"input_shape": backbone.out_shape}
neck = create(cfg["neck"], **kwargs)

# head
kwargs = {"input_shape": neck.out_shape}
yolo_head = create(cfg["yolo_head"], **kwargs)

return {
"backbone": backbone,
"neck": neck,
"yolo_head": yolo_head,
}

def _forward(self):
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats, self.for_mot)

if self.training:
yolo_losses = self.yolo_head(neck_feats, self.inputs)
return yolo_losses
else:
yolo_head_outs = self.yolo_head(neck_feats)
post_outs = self.yolo_head.post_process(
yolo_head_outs,
im_shape=self.inputs["im_shape"],
scale_factor=self.inputs["scale_factor"],
infer_shape=self.inputs["image"].shape[2:],
)

if not isinstance(post_outs, (tuple, list)):
# if set exclude_post_process, concat([pred_bboxes, pred_scores]) not scaled to origin
# export onnx as torch yolo models
return post_outs
else:
if not self.with_mask:
# if set exclude_nms, [pred_bboxes, pred_scores] scaled to origin
bbox, bbox_num = post_outs # default for end-to-end eval/infer
output = {"bbox": bbox, "bbox_num": bbox_num}
else:
bbox, bbox_num, mask = (
post_outs # default for end-to-end eval/infer
)
output = {"bbox": bbox, "bbox_num": bbox_num, "mask": mask}
# Note: YOLOv8 Ins models don't support exclude_post_process or exclude_nms
return output

def get_loss(self):
return self._forward()

def get_pred(self):
return self._forward()
3 changes: 3 additions & 0 deletions ppdet/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from . import vit_mae
from . import hgnet_v2
from . import yolov10_csp_darknet
from . import yolo11_csp_darknet

from .resnet import *
from .darknet import *
Expand All @@ -47,3 +48,5 @@
from .vit_mae import *
from .hgnet_v2 import *
from .yolov10_csp_darknet import *
from .yolo11_csp_darknet import *

Loading