|
| 1 | +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import os |
| 16 | +import json |
| 17 | +import math |
| 18 | +from unittest import result |
| 19 | + |
| 20 | +import cv2 |
| 21 | +import numpy as np |
| 22 | +import paddle |
| 23 | +from PIL import Image |
| 24 | + |
| 25 | +from paddleseg import utils |
| 26 | +from paddleseg.core import infer |
| 27 | +from paddleseg.utils import logger, progbar, visualize, metrics |
| 28 | + |
| 29 | + |
| 30 | +def mkdir(path): |
| 31 | + sub_dir = os.path.dirname(path) |
| 32 | + if not os.path.exists(sub_dir): |
| 33 | + os.makedirs(sub_dir) |
| 34 | + |
| 35 | + |
| 36 | +def partition_list(arr, m): |
| 37 | + """split the list 'arr' into m pieces""" |
| 38 | + n = int(math.ceil(len(arr) / float(m))) |
| 39 | + return [arr[i:i + n] for i in range(0, len(arr), n)] |
| 40 | + |
| 41 | + |
| 42 | +def preprocess(im_path, transforms): |
| 43 | + data = {} |
| 44 | + data['img'] = im_path |
| 45 | + data = transforms(data) |
| 46 | + data['img'] = data['img'][np.newaxis, ...] |
| 47 | + data['img'] = paddle.to_tensor(data['img']) |
| 48 | + return data |
| 49 | + |
| 50 | + |
| 51 | +def analyse(model, |
| 52 | + model_path, |
| 53 | + transforms, |
| 54 | + val_dataset, |
| 55 | + save_dir='output', |
| 56 | + aug_pred=False, |
| 57 | + scales=1.0, |
| 58 | + flip_horizontal=True, |
| 59 | + flip_vertical=False, |
| 60 | + is_slide=False, |
| 61 | + stride=None, |
| 62 | + crop_size=None, |
| 63 | + custom_color=None): |
| 64 | + """ |
| 65 | + predict and visualize the image_list. |
| 66 | +
|
| 67 | + Args: |
| 68 | + model (nn.Layer): Used to predict for input image. |
| 69 | + model_path (str): The path of pretrained model. |
| 70 | + transforms (transform.Compose): Preprocess for input image. |
| 71 | + val_dataset (paddle.io.Dataset): Used to read and process validation datasets. |
| 72 | + save_dir (str, optional): The directory to save the visualized results. Default: 'output'. |
| 73 | + aug_pred (bool, optional): Whether to use mulit-scales and flip augment for predition. Default: False. |
| 74 | + scales (list|float, optional): Scales for augment. It is valid when `aug_pred` is True. Default: 1.0. |
| 75 | + flip_horizontal (bool, optional): Whether to use flip horizontally augment. It is valid when `aug_pred` is True. Default: True. |
| 76 | + flip_vertical (bool, optional): Whether to use flip vertically augment. It is valid when `aug_pred` is True. Default: False. |
| 77 | + is_slide (bool, optional): Whether to predict by sliding window. Default: False. |
| 78 | + stride (tuple|list, optional): The stride of sliding window, the first is width and the second is height. |
| 79 | + It should be provided when `is_slide` is True. |
| 80 | + crop_size (tuple|list, optional): The crop size of sliding window, the first is width and the second is height. |
| 81 | + It should be provided when `is_slide` is True. |
| 82 | + custom_color (list, optional): Save images with a custom color map. Default: None, use paddleseg's default color map. |
| 83 | +
|
| 84 | + """ |
| 85 | + utils.utils.load_entire_model(model, model_path) |
| 86 | + model.eval() |
| 87 | + file_list = val_dataset.file_list |
| 88 | + dataset_root = val_dataset.dataset_root |
| 89 | + nranks = paddle.distributed.get_world_size() |
| 90 | + local_rank = paddle.distributed.get_rank() |
| 91 | + if nranks > 1: |
| 92 | + paddle.distributed.init_parallel_env() |
| 93 | + file_list = partition_list(file_list, nranks) |
| 94 | + else: |
| 95 | + file_list = [file_list] |
| 96 | + |
| 97 | + added_saved_dir = os.path.join(save_dir, 'added_prediction') |
| 98 | + pred_saved_dir = os.path.join(save_dir, 'pseudo_color_prediction') |
| 99 | + results = {} |
| 100 | + |
| 101 | + logger.info("Start to predict...") |
| 102 | + progbar_pred = progbar.Progbar(target=len(file_list[0]), verbose=1) |
| 103 | + color_map = visualize.get_color_map_list(256, custom_color=custom_color) |
| 104 | + with paddle.no_grad(): |
| 105 | + for i, (im_path, label_path) in enumerate(file_list[local_rank]): |
| 106 | + data = preprocess(im_path, transforms) |
| 107 | + |
| 108 | + if aug_pred: |
| 109 | + pred, _ = infer.aug_inference( |
| 110 | + model, |
| 111 | + data['img'], |
| 112 | + trans_info=data['trans_info'], |
| 113 | + scales=scales, |
| 114 | + flip_horizontal=flip_horizontal, |
| 115 | + flip_vertical=flip_vertical, |
| 116 | + is_slide=is_slide, |
| 117 | + stride=stride, |
| 118 | + crop_size=crop_size) |
| 119 | + else: |
| 120 | + pred, _ = infer.inference( |
| 121 | + model, |
| 122 | + data['img'], |
| 123 | + trans_info=data['trans_info'], |
| 124 | + is_slide=is_slide, |
| 125 | + stride=stride, |
| 126 | + crop_size=crop_size) |
| 127 | + pred = paddle.squeeze(pred) |
| 128 | + |
| 129 | + # calculate miou for the image |
| 130 | + label = paddle.to_tensor( |
| 131 | + np.asarray(Image.open(label_path)), dtype=pred.dtype) |
| 132 | + intersect_area, pred_area, label_area = metrics.calculate_area( |
| 133 | + pred, label, val_dataset.num_classes) |
| 134 | + class_iou_per_img, miou_per_img = metrics.mean_iou( |
| 135 | + intersect_area, pred_area, label_area) |
| 136 | + results[im_path] = { |
| 137 | + 'miou': miou_per_img, |
| 138 | + 'class_iou': list(class_iou_per_img), |
| 139 | + 'label_path': label_path |
| 140 | + } |
| 141 | + |
| 142 | + pred = pred.numpy().astype('uint8') |
| 143 | + # get the saved name |
| 144 | + if dataset_root is not None: |
| 145 | + im_file = im_path.replace(dataset_root, '') |
| 146 | + else: |
| 147 | + im_file = os.path.basename(im_path) |
| 148 | + if im_file[0] == '/' or im_file[0] == '\\': |
| 149 | + im_file = im_file[1:] |
| 150 | + |
| 151 | + # save added image |
| 152 | + added_image = utils.visualize.visualize( |
| 153 | + im_path, pred, color_map, weight=0.6) |
| 154 | + added_image_path = os.path.join(added_saved_dir, im_file) |
| 155 | + mkdir(added_image_path) |
| 156 | + cv2.imwrite(added_image_path, added_image) |
| 157 | + results[im_path]['added_path'] = added_image_path |
| 158 | + |
| 159 | + # save pseudo color prediction |
| 160 | + pred_mask = utils.visualize.get_pseudo_color_map(pred, color_map) |
| 161 | + pred_saved_path = os.path.join( |
| 162 | + pred_saved_dir, os.path.splitext(im_file)[0] + ".png") |
| 163 | + mkdir(pred_saved_path) |
| 164 | + pred_mask.save(pred_saved_path) |
| 165 | + results[im_path]['prediction_path'] = pred_saved_path |
| 166 | + |
| 167 | + progbar_pred.update(i + 1) |
| 168 | + if nranks > 1: |
| 169 | + results_list = [] |
| 170 | + paddle.distributed.all_gather_object(results_list, results) |
| 171 | + if local_rank == 0: |
| 172 | + results = {} |
| 173 | + for d in results_list: |
| 174 | + results.update(d) |
| 175 | + if local_rank == 0: |
| 176 | + with open(os.path.join(save_dir, 'analysis_results.json'), 'w') as f: |
| 177 | + json.dump(results, f, indent=4) |
| 178 | + |
| 179 | + logger.info("Samples analysis finished, the results are save in {}.".format( |
| 180 | + save_dir)) |
0 commit comments