Skip to content

Commit cf37a4b

Browse files
committed
WIP
WIP WIP WIP Can test version Can test version modify for dump onnx ready version ready version ready version ready version ready version ready version
1 parent 36eec70 commit cf37a4b

37 files changed

+3379
-47
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2+
import torch.utils.data
3+
import torchvision
4+
5+
from .coco import build as build_coco
6+
7+
8+
def get_coco_api_from_dataset(dataset):
9+
for _ in range(10):
10+
# if isinstance(dataset, torchvision.datasets.CocoDetection):
11+
# break
12+
if isinstance(dataset, torch.utils.data.Subset):
13+
dataset = dataset.dataset
14+
if isinstance(dataset, torchvision.datasets.CocoDetection):
15+
return dataset.coco
16+
17+
18+
def build_dataset(image_set, args):
19+
if args.dataset_file == 'coco':
20+
return build_coco(image_set, args)
21+
if args.dataset_file == 'coco_panoptic':
22+
# to avoid making panopticapi required for coco
23+
from .coco_panoptic import build as build_coco_panoptic
24+
return build_coco_panoptic(image_set, args)
25+
raise ValueError(f'dataset {args.dataset_file} not supported')

examples/DETR_ptq/datasets/coco.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2+
"""
3+
COCO dataset which returns image_id for evaluation.
4+
5+
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
6+
"""
7+
from pathlib import Path
8+
9+
import torch
10+
import torch.utils.data
11+
import torchvision
12+
from pycocotools import mask as coco_mask
13+
14+
import datasets.transforms as T
15+
16+
17+
class CocoDetection(torchvision.datasets.CocoDetection):
18+
def __init__(self, img_folder, ann_file, transforms, return_masks):
19+
super(CocoDetection, self).__init__(img_folder, ann_file)
20+
self._transforms = transforms
21+
self.prepare = ConvertCocoPolysToMask(return_masks)
22+
23+
def __getitem__(self, idx):
24+
img, target = super(CocoDetection, self).__getitem__(idx)
25+
image_id = self.ids[idx]
26+
target = {'image_id': image_id, 'annotations': target}
27+
img, target = self.prepare(img, target)
28+
if self._transforms is not None:
29+
img, target = self._transforms(img, target)
30+
return img, target
31+
32+
33+
def convert_coco_poly_to_mask(segmentations, height, width):
34+
masks = []
35+
for polygons in segmentations:
36+
rles = coco_mask.frPyObjects(polygons, height, width)
37+
mask = coco_mask.decode(rles)
38+
if len(mask.shape) < 3:
39+
mask = mask[..., None]
40+
mask = torch.as_tensor(mask, dtype=torch.uint8)
41+
mask = mask.any(dim=2)
42+
masks.append(mask)
43+
if masks:
44+
masks = torch.stack(masks, dim=0)
45+
else:
46+
masks = torch.zeros((0, height, width), dtype=torch.uint8)
47+
return masks
48+
49+
50+
class ConvertCocoPolysToMask(object):
51+
def __init__(self, return_masks=False):
52+
self.return_masks = return_masks
53+
54+
def __call__(self, image, target):
55+
w, h = image.size
56+
57+
image_id = target["image_id"]
58+
image_id = torch.tensor([image_id])
59+
60+
anno = target["annotations"]
61+
62+
anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
63+
64+
boxes = [obj["bbox"] for obj in anno]
65+
# guard against no boxes via resizing
66+
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
67+
boxes[:, 2:] += boxes[:, :2]
68+
boxes[:, 0::2].clamp_(min=0, max=w)
69+
boxes[:, 1::2].clamp_(min=0, max=h)
70+
71+
classes = [obj["category_id"] for obj in anno]
72+
classes = torch.tensor(classes, dtype=torch.int64)
73+
74+
if self.return_masks:
75+
segmentations = [obj["segmentation"] for obj in anno]
76+
masks = convert_coco_poly_to_mask(segmentations, h, w)
77+
78+
keypoints = None
79+
if anno and "keypoints" in anno[0]:
80+
keypoints = [obj["keypoints"] for obj in anno]
81+
keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
82+
num_keypoints = keypoints.shape[0]
83+
if num_keypoints:
84+
keypoints = keypoints.view(num_keypoints, -1, 3)
85+
86+
keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
87+
boxes = boxes[keep]
88+
classes = classes[keep]
89+
if self.return_masks:
90+
masks = masks[keep]
91+
if keypoints is not None:
92+
keypoints = keypoints[keep]
93+
94+
target = {}
95+
target["boxes"] = boxes
96+
target["labels"] = classes
97+
if self.return_masks:
98+
target["masks"] = masks
99+
target["image_id"] = image_id
100+
if keypoints is not None:
101+
target["keypoints"] = keypoints
102+
103+
# for conversion to coco api
104+
area = torch.tensor([obj["area"] for obj in anno])
105+
iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
106+
target["area"] = area[keep]
107+
target["iscrowd"] = iscrowd[keep]
108+
109+
target["orig_size"] = torch.as_tensor([int(h), int(w)])
110+
target["size"] = torch.as_tensor([int(h), int(w)])
111+
112+
return image, target
113+
114+
115+
def make_coco_transforms(image_set):
116+
117+
normalize = T.Compose([
118+
T.ToTensor(),
119+
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
120+
])
121+
122+
scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
123+
124+
if image_set == 'train':
125+
return T.Compose([
126+
T.RandomHorizontalFlip(),
127+
T.RandomSelect(
128+
T.RandomResize(scales, max_size=1333),
129+
T.Compose([
130+
T.RandomResize([400, 500, 600]),
131+
T.RandomSizeCrop(384, 600),
132+
T.RandomResize(scales, max_size=1333),
133+
])
134+
),
135+
normalize,
136+
])
137+
138+
if image_set == 'val':
139+
return T.Compose([
140+
T.RandomResize([800], max_size=1333),
141+
normalize,
142+
])
143+
144+
raise ValueError(f'unknown {image_set}')
145+
146+
147+
def build(image_set, args):
148+
root = Path(args.coco_path)
149+
assert root.exists(), f'provided COCO path {root} does not exist'
150+
mode = 'instances'
151+
PATHS = {
152+
"train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
153+
"val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
154+
}
155+
156+
img_folder, ann_file = PATHS[image_set]
157+
dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms("val"), return_masks=args.masks)
158+
return dataset

0 commit comments

Comments
 (0)