Skip to content

Commit 3fea7bb

Browse files
authored
Add analysis for val dataset (#3513)
1 parent 63f95e6 commit 3fea7bb

File tree

4 files changed

+342
-2
lines changed

4 files changed

+342
-2
lines changed

paddleseg/core/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .train import train
1616
from .val import evaluate
1717
from .predict import predict
18+
from .analyse import analyse
1819
from . import infer
1920

20-
__all__ = ['train', 'evaluate', 'predict']
21+
__all__ = ['train', 'evaluate', 'predict', 'analyse']

paddleseg/core/analyse.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import json
17+
import math
18+
from unittest import result
19+
20+
import cv2
21+
import numpy as np
22+
import paddle
23+
from PIL import Image
24+
25+
from paddleseg import utils
26+
from paddleseg.core import infer
27+
from paddleseg.utils import logger, progbar, visualize, metrics
28+
29+
30+
def mkdir(path):
31+
sub_dir = os.path.dirname(path)
32+
if not os.path.exists(sub_dir):
33+
os.makedirs(sub_dir)
34+
35+
36+
def partition_list(arr, m):
37+
"""split the list 'arr' into m pieces"""
38+
n = int(math.ceil(len(arr) / float(m)))
39+
return [arr[i:i + n] for i in range(0, len(arr), n)]
40+
41+
42+
def preprocess(im_path, transforms):
43+
data = {}
44+
data['img'] = im_path
45+
data = transforms(data)
46+
data['img'] = data['img'][np.newaxis, ...]
47+
data['img'] = paddle.to_tensor(data['img'])
48+
return data
49+
50+
51+
def analyse(model,
52+
model_path,
53+
transforms,
54+
val_dataset,
55+
save_dir='output',
56+
aug_pred=False,
57+
scales=1.0,
58+
flip_horizontal=True,
59+
flip_vertical=False,
60+
is_slide=False,
61+
stride=None,
62+
crop_size=None,
63+
custom_color=None):
64+
"""
65+
predict and visualize the image_list.
66+
67+
Args:
68+
model (nn.Layer): Used to predict for input image.
69+
model_path (str): The path of pretrained model.
70+
transforms (transform.Compose): Preprocess for input image.
71+
val_dataset (paddle.io.Dataset): Used to read and process validation datasets.
72+
save_dir (str, optional): The directory to save the visualized results. Default: 'output'.
73+
aug_pred (bool, optional): Whether to use mulit-scales and flip augment for predition. Default: False.
74+
scales (list|float, optional): Scales for augment. It is valid when `aug_pred` is True. Default: 1.0.
75+
flip_horizontal (bool, optional): Whether to use flip horizontally augment. It is valid when `aug_pred` is True. Default: True.
76+
flip_vertical (bool, optional): Whether to use flip vertically augment. It is valid when `aug_pred` is True. Default: False.
77+
is_slide (bool, optional): Whether to predict by sliding window. Default: False.
78+
stride (tuple|list, optional): The stride of sliding window, the first is width and the second is height.
79+
It should be provided when `is_slide` is True.
80+
crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height.
81+
It should be provided when `is_slide` is True.
82+
custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map.
83+
84+
"""
85+
utils.utils.load_entire_model(model, model_path)
86+
model.eval()
87+
file_list = val_dataset.file_list
88+
dataset_root = val_dataset.dataset_root
89+
nranks = paddle.distributed.get_world_size()
90+
local_rank = paddle.distributed.get_rank()
91+
if nranks > 1:
92+
paddle.distributed.init_parallel_env()
93+
file_list = partition_list(file_list, nranks)
94+
else:
95+
file_list = [file_list]
96+
97+
added_saved_dir = os.path.join(save_dir, 'added_prediction')
98+
pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction')
99+
results = {}
100+
101+
logger.info("Start to predict...")
102+
progbar_pred = progbar.Progbar(target=len(file_list[0]), verbose=1)
103+
color_map = visualize.get_color_map_list(256, custom_color=custom_color)
104+
with paddle.no_grad():
105+
for i, (im_path, label_path) in enumerate(file_list[local_rank]):
106+
data = preprocess(im_path, transforms)
107+
108+
if aug_pred:
109+
pred, _ = infer.aug_inference(
110+
model,
111+
data['img'],
112+
trans_info=data['trans_info'],
113+
scales=scales,
114+
flip_horizontal=flip_horizontal,
115+
flip_vertical=flip_vertical,
116+
is_slide=is_slide,
117+
stride=stride,
118+
crop_size=crop_size)
119+
else:
120+
pred, _ = infer.inference(
121+
model,
122+
data['img'],
123+
trans_info=data['trans_info'],
124+
is_slide=is_slide,
125+
stride=stride,
126+
crop_size=crop_size)
127+
pred = paddle.squeeze(pred)
128+
129+
# calculate miou for the image
130+
label = paddle.to_tensor(
131+
np.asarray(Image.open(label_path)), dtype=pred.dtype)
132+
intersect_area, pred_area, label_area = metrics.calculate_area(
133+
pred, label, val_dataset.num_classes)
134+
class_iou_per_img, miou_per_img = metrics.mean_iou(
135+
intersect_area, pred_area, label_area)
136+
results[im_path] = {
137+
'miou': miou_per_img,
138+
'class_iou': list(class_iou_per_img),
139+
'label_path': label_path
140+
}
141+
142+
pred = pred.numpy().astype('uint8')
143+
# get the saved name
144+
if dataset_root is not None:
145+
im_file = im_path.replace(dataset_root, '')
146+
else:
147+
im_file = os.path.basename(im_path)
148+
if im_file[0] == '/' or im_file[0] == '\\':
149+
im_file = im_file[1:]
150+
151+
# save added image
152+
added_image = utils.visualize.visualize(
153+
im_path, pred, color_map, weight=0.6)
154+
added_image_path = os.path.join(added_saved_dir, im_file)
155+
mkdir(added_image_path)
156+
cv2.imwrite(added_image_path, added_image)
157+
results[im_path]['added_path'] = added_image_path
158+
159+
# save pseudo color prediction
160+
pred_mask = utils.visualize.get_pseudo_color_map(pred, color_map)
161+
pred_saved_path = os.path.join(
162+
pred_saved_dir, os.path.splitext(im_file)[0] + ".png")
163+
mkdir(pred_saved_path)
164+
pred_mask.save(pred_saved_path)
165+
results[im_path]['prediction_path'] = pred_saved_path
166+
167+
progbar_pred.update(i + 1)
168+
if nranks > 1:
169+
results_list = []
170+
paddle.distributed.all_gather_object(results_list, results)
171+
if local_rank == 0:
172+
results = {}
173+
for d in results_list:
174+
results.update(d)
175+
if local_rank == 0:
176+
with open(os.path.join(save_dir, 'analysis_results.json'), 'w') as f:
177+
json.dump(results, f, indent=4)
178+
179+
logger.info("Samples analysis finished, the results are save in {}.".format(
180+
save_dir))

paddleseg/utils/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def mean_iou(intersect_area, pred_area, label_area):
133133
class_iou = []
134134
for i in range(len(intersect_area)):
135135
if union[i] == 0:
136-
iou = 0
136+
iou = 1
137137
else:
138138
iou = intersect_area[i] / union[i]
139139
class_iou.append(float(iou))

tools/analyse.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
from paddleseg.cvlibs import manager, Config, SegBuilder
18+
from paddleseg.utils import get_sys_env, logger, utils
19+
from paddleseg.core import analyse
20+
from paddleseg.transforms import Compose
21+
22+
23+
def parse_args():
24+
parser = argparse.ArgumentParser(description='Model prediction')
25+
26+
# Common params
27+
parser.add_argument("--config", help="The path of config file.", type=str)
28+
parser.add_argument(
29+
'--model_path',
30+
help='The path of trained weights for prediction.',
31+
type=str)
32+
parser.add_argument(
33+
'--image_path',
34+
help='The image to predict, which can be a path of image, or a file list containing image paths, or a directory including images',
35+
type=str)
36+
parser.add_argument(
37+
'--dataset_path',
38+
help='The image to predict, which can be a path of image, or a file list containing image paths, or a directory including images',
39+
type=str)
40+
41+
parser.add_argument(
42+
'--save_dir',
43+
help='The directory for saving the predicted results.',
44+
type=str,
45+
default='./output/result')
46+
parser.add_argument(
47+
'--device',
48+
help='Set the device place for predicting model.',
49+
default='gpu',
50+
choices=['cpu', 'gpu', 'xpu', 'npu', 'mlu'],
51+
type=str)
52+
parser.add_argument(
53+
'--device_id',
54+
help='Set the device id for predicting model.',
55+
default=0,
56+
type=int)
57+
58+
# Data augment params
59+
parser.add_argument(
60+
'--aug_pred',
61+
help='Whether to use mulit-scales and flip augment for prediction',
62+
action='store_true')
63+
parser.add_argument(
64+
'--scales',
65+
nargs='+',
66+
help='Scales for augment, e.g., `--scales 0.75 1.0 1.25`.',
67+
type=float,
68+
default=1.0)
69+
parser.add_argument(
70+
'--flip_horizontal',
71+
help='Whether to use flip horizontally augment',
72+
action='store_true')
73+
parser.add_argument(
74+
'--flip_vertical',
75+
help='Whether to use flip vertically augment',
76+
action='store_true')
77+
78+
# Sliding window evaluation params
79+
parser.add_argument(
80+
'--is_slide',
81+
help='Whether to predict images in sliding window method',
82+
action='store_true')
83+
parser.add_argument(
84+
'--crop_size',
85+
nargs=2,
86+
help='The crop size of sliding window, the first is width and the second is height.'
87+
'For example, `--crop_size 512 512`',
88+
type=int)
89+
parser.add_argument(
90+
'--stride',
91+
nargs=2,
92+
help='The stride of sliding window, the first is width and the second is height.'
93+
'For example, `--stride 512 512`',
94+
type=int)
95+
96+
# Custom color map
97+
parser.add_argument(
98+
'--custom_color',
99+
nargs='+',
100+
help='Save images with a custom color map. Default: None, use paddleseg\'s default color map.',
101+
type=int)
102+
103+
parser.add_argument(
104+
'--opts',
105+
help='Update the key-value pairs of all options.',
106+
default=None,
107+
nargs='+')
108+
109+
return parser.parse_args()
110+
111+
112+
def merge_test_config(cfg, args):
113+
test_config = cfg.test_config
114+
if 'aug_eval' in test_config:
115+
test_config.pop('aug_eval')
116+
if args.aug_pred:
117+
test_config['aug_pred'] = args.aug_pred
118+
test_config['scales'] = args.scales
119+
test_config['flip_horizontal'] = args.flip_horizontal
120+
test_config['flip_vertical'] = args.flip_vertical
121+
if args.is_slide:
122+
test_config['is_slide'] = args.is_slide
123+
test_config['crop_size'] = args.crop_size
124+
test_config['stride'] = args.stride
125+
if args.custom_color:
126+
test_config['custom_color'] = args.custom_color
127+
return test_config
128+
129+
130+
def main(args):
131+
assert args.config is not None, \
132+
'No configuration file specified, please set --config'
133+
cfg = Config(args.config, opts=args.opts)
134+
builder = SegBuilder(cfg)
135+
test_config = merge_test_config(cfg, args)
136+
137+
utils.show_env_info()
138+
utils.show_cfg_info(cfg)
139+
utils.set_device(args.device)
140+
141+
model = builder.model
142+
transforms = Compose(builder.val_transforms)
143+
val_dataset = builder.val_dataset
144+
145+
logger.info('The number of samples to analyse: {}'.format(
146+
len(val_dataset.file_list)))
147+
148+
analyse(
149+
model,
150+
model_path=args.model_path,
151+
transforms=transforms,
152+
val_dataset=val_dataset,
153+
save_dir=args.save_dir,
154+
**test_config)
155+
156+
157+
if __name__ == '__main__':
158+
args = parse_args()
159+
main(args)

0 commit comments

Comments
 (0)