@@ -115,9 +115,10 @@ def __init__(self, configs):
115
115
# image processing thread worker
116
116
def process_worker (self , imgs , idx , use_pr = False ):
117
117
image_path = imgs [idx ]
118
- im = cv2 .imread (image_path , - 1 )
119
- if len (im .shape ) == 2 :
120
- im = cv2 .cvtColor (im , cv2 .COLOR_GRAY2BGR )
118
+ cv2_imread_flag = cv2 .IMREAD_COLOR
119
+ if self .config .channels == 4 :
120
+ cv2_imread_flag = cv2 .IMREAD_UNCHANGED
121
+ im = cv2 .imread (image_path , cv2_imread_flag )
121
122
channels = im .shape [2 ]
122
123
if channels != 3 and channels != 4 :
123
124
print ("Only support rgb(gray) or rgba image." )
@@ -133,8 +134,10 @@ def process_worker(self, imgs, idx, use_pr=False):
133
134
134
135
# if use models with no pre-processing/post-processing op optimizations
135
136
if not use_pr :
136
- im_mean = np .array (self .config .mean ).reshape ((3 , 1 , 1 ))
137
- im_std = np .array (self .config .std ).reshape ((3 , 1 , 1 ))
137
+ im_mean = np .array (self .config .mean ).reshape ((self .config .channels ,
138
+ 1 , 1 ))
139
+ im_std = np .array (self .config .std ).reshape ((self .config .channels , 1 ,
140
+ 1 ))
138
141
# HWC -> CHW, don't use transpose((2, 0, 1))
139
142
im = im .swapaxes (1 , 2 )
140
143
im = im .swapaxes (0 , 1 )
0 commit comments