Skip to content

Commit e37b7c6

Browse files
authored
Add Panoptic Segmentation.
1 parent a79cb15 commit e37b7c6

32 files changed

+3755
-4
lines changed

contrib/PanopticDeepLab/README.md

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
2+
# Panoptic DeepLab
3+
4+
基于PaddlePaddle实现[Panoptic Deeplab](https://arxiv.org/abs/1911.10194)全景分割算法。
5+
6+
Panoptic DeepLab首次证实了bottem-up算法能够达到state-of-the-art的效果。Panoptic DeepLab预测三个输出:Semantic Segmentation, Center Prediction 和 Center Regression。实例类别像素根据最近距离原则聚集到实例中心点得到实例分割结果。最后按照majority-vote规则融合语义分割结果和实例分割结果,得到最终的全景分割结果。
7+
其通过将每一个像素赋值给每一个类别或实例达到分割的效果。
8+
![](./docs/panoptic_deeplab.jpg)
9+
10+
## Model Baselines
11+
12+
### Cityscapes
13+
| Backbone | Batch Size |Resolution | Training Iters | PQ | SQ | RQ | AP | mIoU | Links |
14+
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
15+
|ResNet50_OS32| 8 | 2049x1025|90000|58.35%|80.03%|71.52%|25.80%|79.18%|[model](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005/train.log)|
16+
|ResNet50_OS32| 64 | 1025x513|90000|60.32%|80.56%|73.56%|26.77%|79.67%|[model](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pnoptic_segmentation/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005/train.log)|
17+
18+
## 环境准备
19+
20+
1. 系统环境
21+
* PaddlePaddle >= 2.0.0
22+
* Python >= 3.6+
23+
推荐使用GPU版本的PaddlePaddle版本。详细安装教程请参考官方网站[PaddlePaddle](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/windows-pip.html)
24+
25+
2. 下载PaddleSeg repo
26+
```shell
27+
git clone https://github.com/PaddlePaddle/PaddleSeg
28+
```
29+
30+
3. 安装paddleseg
31+
```shell
32+
cd PaddleSeg
33+
pip install -e .
34+
```
35+
36+
4. 进入PaddleSeg/contrib/PanopticDeepLab目录
37+
```shell
38+
cd contrib/PanopticDeepLab
39+
```
40+
41+
## 数据集准备
42+
43+
将数据集放置于`data`目录下。
44+
45+
### Cityscapes
46+
47+
前往[CityScapes官网](https://www.cityscapes-dataset.com/)下载数据集并整理成如下结构:
48+
49+
```
50+
cityscapes/
51+
|--gtFine/
52+
| |--train/
53+
| | |--aachen/
54+
| | | |--*_color.png, *_instanceIds.png, *_labelIds.png, *_polygons.json,
55+
| | | |--*_labelTrainIds.png
56+
| | | |--...
57+
| |--val/
58+
| |--test/
59+
| |--cityscapes_panoptic_train_trainId.json
60+
| |--cityscapes_panoptic_train_trainId/
61+
| | |-- *_panoptic.png
62+
| |--cityscapes_panoptic_val_trainId.json
63+
| |--cityscapes_panoptic_val_trainId/
64+
| | |-- *_panoptic.png
65+
|--leftImg8bit/
66+
| |--train/
67+
| |--val/
68+
| |--test/
69+
70+
```
71+
72+
安装CityscapesScripts
73+
```shell
74+
pip install git+https://github.com/mcordts/cityscapesScripts.git
75+
```
76+
77+
`*_panoptic.png` 生成命令(需找到`createPanopticImgs.py`文件):
78+
```shell
79+
python /path/to/cityscapesscripts/preparation/createPanopticImgs.py \
80+
--dataset-folder data/cityscapes/gtFine/ \
81+
--output-folder data/cityscapes/gtFine/ \
82+
--use-train-id
83+
```
84+
85+
## 训练
86+
```shell
87+
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # 根据实际情况进行显卡数量的设置
88+
python -m paddle.distributed.launch train.py \
89+
--config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml \
90+
--do_eval \
91+
--use_vdl \
92+
--save_interval 5000 \
93+
--save_dir output
94+
```
95+
96+
**note:** 使用--do_eval会影响训练速度及增加显存消耗,根据选择进行开闭。
97+
98+
更多参数信息请运行如下命令进行查看:
99+
```shell
100+
python train.py --help
101+
```
102+
103+
## 评估
104+
```shell
105+
python val.py \
106+
--config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_bs8_90k_lr00005.yml \
107+
--model_path output/iter_90000/model.pdparams
108+
```
109+
你可以直接下载我们提供的模型进行评估。
110+
111+
更多参数信息请运行如下命令进行查看:
112+
```shell
113+
python val.py --help
114+
```
115+
116+
## 预测及可视化结果保存
117+
```shell
118+
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # 根据实际情况进行显卡数量的设置
119+
python -m paddle.distributed.launch predict.py \
120+
--config configs/panoptic_deeplab/panoptic_deeplab_resnet50_os32_cityscapes_1025x513_120k.yml \
121+
--model_path output/iter_90000/model.pdparams \
122+
--image_path data/cityscapes/leftImg8bit/val/ \
123+
--save_dir ./output/result
124+
```
125+
你可以直接下载我们提供的模型进行预测。
126+
127+
更多参数信息请运行如下命令进行查看:
128+
```shell
129+
python predict.py --help
130+
```
131+
全景分割结果:
132+
<center>
133+
<img src="docs/visualization_panoptic.png">
134+
</center>
135+
136+
语义分割结果:
137+
<center>
138+
<img src="docs/visualization_semantic.png">
139+
</center>
140+
141+
实例分割结果:
142+
<center>
143+
<img src="docs/visualization_instance.png">
144+
</center>
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
train_dataset:
2+
type: CityscapesPanoptic
3+
dataset_root: data/cityscapes
4+
transforms:
5+
- type: ResizeStepScaling
6+
min_scale_factor: 0.5
7+
max_scale_factor: 2.0
8+
scale_step_size: 0.25
9+
- type: RandomPaddingCrop
10+
crop_size: [2049, 1025]
11+
label_padding_value: [0, 0, 0]
12+
- type: RandomHorizontalFlip
13+
- type: RandomDistort
14+
brightness_range: 0.4
15+
contrast_range: 0.4
16+
saturation_range: 0.4
17+
- type: Normalize
18+
mode: train
19+
ignore_stuff_in_offset: True
20+
small_instance_area: 4096
21+
small_instance_weight: 3
22+
23+
val_dataset:
24+
type: CityscapesPanoptic
25+
dataset_root: data/cityscapes
26+
transforms:
27+
- type: Padding
28+
target_size: [2049, 1025]
29+
label_padding_value: [0, 0, 0]
30+
- type: Normalize
31+
mode: val
32+
ignore_stuff_in_offset: True
33+
small_instance_area: 4096
34+
small_instance_weight: 3
35+
36+
37+
optimizer:
38+
type: adam
39+
40+
learning_rate:
41+
value: 0.00005
42+
decay:
43+
type: poly
44+
power: 0.9
45+
end_lr: 0.0
46+
47+
loss:
48+
types:
49+
- type: CrossEntropyLoss
50+
top_k_percent_pixels: 0.2
51+
- type: MSELoss
52+
reduction: "none"
53+
- type: L1Loss
54+
reduction: "none"
55+
coef: [1, 200, 0.001]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
_base_: ./panoptic_deeplab_resnet50_os32_cityscapes_2049x1025_bs1_90k_lr00005.yml
2+
3+
batch_size: 8
4+
5+
train_dataset:
6+
transforms:
7+
- type: ResizeStepScaling
8+
min_scale_factor: 0.5
9+
max_scale_factor: 2.0
10+
scale_step_size: 0.25
11+
- type: RandomPaddingCrop
12+
crop_size: [1025, 513]
13+
label_padding_value: [0, 0, 0]
14+
- type: RandomHorizontalFlip
15+
- type: RandomDistort
16+
brightness_range: 0.4
17+
contrast_range: 0.4
18+
saturation_range: 0.4
19+
- type: Normalize
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
_base_: ../_base_/cityscapes_panoptic.yml
2+
3+
batch_size: 1
4+
iters: 90000
5+
6+
model:
7+
type: PanopticDeepLab
8+
backbone:
9+
type: ResNet50_vd
10+
output_stride: 32
11+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
12+
backbone_indices: [2,1,0,3]
13+
aspp_ratios: [1, 3, 6, 9]
14+
aspp_out_channels: 256
15+
decoder_channels: 256
16+
low_level_channels_projects: [128, 64, 32]
17+
align_corners: True
18+
instance_aspp_out_channels: 256
19+
instance_decoder_channels: 128
20+
instance_low_level_channels_projects: [64, 32, 16]
21+
instance_num_classes: [1, 2]
22+
instance_head_channels: 32
23+
instance_class_key: ["center", "offset"]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .train import train
16+
from .val import evaluate
17+
from .predict import predict
18+
from . import infer
19+
20+
__all__ = ['train', 'evaluate', 'predict']

0 commit comments

Comments
 (0)