@@ -296,24 +296,31 @@ def decode_example(self, example):
296
296
297
297
def decode_example_np(self, example: bytes) -> np.ndarray:
  """Decode encoded image bytes into a NumPy array.

  The channel count is read from the last entry of `self._shape` and is
  forwarded to whichever decoder is selected.

  Args:
    example: The encoded image bytes.

  Returns:
    The array produced by the selected decoder.
  """
  num_channels = self._shape[-1]
  # PIL does not handle multi-channel 16-bit images, so we use OpenCV for
  # uint16 features and PIL for everything else.
  decode_fn = (
      self.decode_example_np_with_opencv
      if self._dtype == np.uint16
      else self.decode_example_np_with_pil
  )
  return decode_fn(example, num_channels)
304
305
305
def decode_example_np_with_opencv(
    self, example: bytes, channels: int
) -> np.ndarray:
  """Decode encoded image bytes with OpenCV.

  Args:
    example: The encoded image bytes.
    channels: Number of channels a 2-D (grayscale) decode is expanded to.

  Returns:
    The decoded image. A 2-D decode is passed through
    `_reshape_grayscale_image`; otherwise the last axis is re-indexed with
    the order given by `_opencv_to_tfds_channels(self._shape)`.

  Raises:
    Exception: If OpenCV cannot be imported.
    ValueError: If OpenCV cannot decode the bytes, or if raised by
      `_opencv_to_tfds_channels` for an unsupported channel count.
  """
  try:
    cv2 = lazy_imports_lib.lazy_imports.cv2
  except ImportError as e:
    raise Exception(
        'Decoding 16-bit images with NumPy requires OpenCV.'
    ) from e
  buffer = np.frombuffer(example, dtype=np.uint8)
  image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
  if image is None:
    # cv2.imdecode reports undecodable input by returning None rather than
    # raising; fail here with a clear message instead of an AttributeError
    # on `.ndim` below.
    raise ValueError('OpenCV could not decode the image bytes.')
  if image.ndim == 2:
    # Grayscale decode: no channel axis to reorder.
    return _reshape_grayscale_image(image, channels)
  return image[:, :, _opencv_to_tfds_channels(self._shape)]
320
+
321
+ def decode_example_np_with_pil (
322
+ self , example : bytes , channels : int
323
+ ) -> np .ndarray :
317
324
try :
318
325
PIL_Image = lazy_imports_lib .lazy_imports .PIL_Image # pylint: disable=invalid-name
319
326
except ImportError as e :
@@ -322,9 +329,7 @@ def decode_example_np_with_pil(self, example: bytes) -> np.ndarray:
322
329
with PIL_Image .open (bytes_io ) as image :
323
330
dtype = self .np_dtype if self .np_dtype != np .float32 else np .uint8
324
331
np_array = np .asarray (image , dtype = dtype )
325
- # Reshape the array if needed
326
- if np_array .ndim == 2 : # (h, w)
327
- np_array = np_array [..., None ] # (h, w, 1)
332
+ np_array = _reshape_grayscale_image (np_array , channels )
328
333
if self .np_dtype == np .uint8 :
329
334
return np_array
330
335
# Bitcast 4 channels uint8 -> 1 channel float32.
@@ -536,14 +541,33 @@ def _validate_np_array(
536
541
537
542
538
543
@py_utils.memoize()
def _opencv_to_tfds_channels(shape: utils.Shape) -> List[int]:
  """Restore channel in the expected order: OpenCV uses BGR rather than RGB.

  Args:
    shape: The image feature shape; only its last entry (the number of
      channels) is read.

  Returns:
    The indices to apply on the channel axis of an OpenCV image.

  Raises:
    ValueError: If the number of channels is neither 3 nor 4.
  """
  num_channels = shape[-1]
  if num_channels == 3:
    return [2, 1, 0]  # BGR -> RGB
  if num_channels == 4:
    return [2, 1, 0, 3]  # BGRa -> RGBa
  raise ValueError(f'Unsupported number of channels: {num_channels}')
553
+
554
+
555
def _reshape_grayscale_image(
    image: np.ndarray, num_channels: int
) -> np.ndarray:
  """Expand a grayscale image to `num_channels` channels.

  Maps (h, w) or (h, w, 1) -> (h, w, num_channels) by repeating the single
  channel. This reproduces the transformation TensorFlow applies to grayscale
  images. Any other input is returned unchanged.

  Args:
    image: An image as an np.ndarray.
    num_channels: The number of channels in the image feature.

  Returns:
    The reshaped image, or `image` itself when it is not grayscale.
  """
  has_no_channel_axis = image.ndim == 2  # (h, w)
  has_single_channel = image.ndim == 3 and image.shape[2] == 1  # (h, w, 1)
  if not (has_no_channel_axis or has_single_channel):
    return image
  # Add the missing channel axis first so both cases repeat along axis -1.
  single_channel = image[..., None] if has_no_channel_axis else image
  return np.repeat(single_channel, num_channels, axis=-1)
0 commit comments