Skip to content

Commit 4b125f4

Browse files
marcenacp and The TensorFlow Datasets Authors
authored and committed
Fix decoding of grayscale images in TFless path
If the feature is Image(shape=(None, None, 3)), grayscale images should be decoded to (h, w, 3), not to (h, w, 1) as was previously done.
PiperOrigin-RevId: 512554994
1 parent 412449f commit 4b125f4

File tree

4 files changed

+73
-15
lines changed

4 files changed

+73
-15
lines changed

tensorflow_datasets/core/features/image_feature.py

Lines changed: 39 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -296,24 +296,31 @@ def decode_example(self, example):
296296

297297
def decode_example_np(self, example: bytes) -> np.ndarray:
298298
"""Reconstruct the image with PIL from bytes."""
299+
channels = self._shape[-1]
299300
if self._dtype == np.uint16:
300301
# PIL does not handle multi-channel 16-bit images, so we use OpenCV.
301-
return self.decode_example_np_with_opencv(example)
302+
return self.decode_example_np_with_opencv(example, channels)
302303
else:
303-
return self.decode_example_np_with_pil(example)
304+
return self.decode_example_np_with_pil(example, channels)
304305

305-
def decode_example_np_with_opencv(self, example: bytes) -> np.ndarray:
306+
def decode_example_np_with_opencv(
307+
self, example: bytes, channels: int
308+
) -> np.ndarray:
306309
try:
307310
cv2 = lazy_imports_lib.lazy_imports.cv2
308311
except ImportError as e:
309312
raise Exception(
310313
'Decoding 16-bit images with NumPy requires OpenCV.'
311314
) from e
312-
buffer = np.frombuffer(example, dtype=np.uint8)
313-
image_cv2 = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
314-
return image_cv2[:, :, _opencv_to_tfds_channels(self._shape)]
315-
316-
def decode_example_np_with_pil(self, example: bytes) -> np.ndarray:
315+
example = np.frombuffer(example, dtype=np.uint8)
316+
example = cv2.imdecode(example, cv2.IMREAD_UNCHANGED)
317+
if example.ndim == 2:
318+
return _reshape_grayscale_image(example, channels)
319+
return example[:, :, _opencv_to_tfds_channels(self._shape)]
320+
321+
def decode_example_np_with_pil(
322+
self, example: bytes, channels: int
323+
) -> np.ndarray:
317324
try:
318325
PIL_Image = lazy_imports_lib.lazy_imports.PIL_Image # pylint: disable=invalid-name
319326
except ImportError as e:
@@ -322,9 +329,7 @@ def decode_example_np_with_pil(self, example: bytes) -> np.ndarray:
322329
with PIL_Image.open(bytes_io) as image:
323330
dtype = self.np_dtype if self.np_dtype != np.float32 else np.uint8
324331
np_array = np.asarray(image, dtype=dtype)
325-
# Reshape the array if needed
326-
if np_array.ndim == 2: # (h, w)
327-
np_array = np_array[..., None] # (h, w, 1)
332+
np_array = _reshape_grayscale_image(np_array, channels)
328333
if self.np_dtype == np.uint8:
329334
return np_array
330335
# Bitcast 4 channels uint8 -> 1 channel float32.
@@ -536,14 +541,33 @@ def _validate_np_array(
536541

537542

538543
@py_utils.memoize()
539-
def _opencv_to_tfds_channels(shape: utils.Shape) -> Union[int, List[int]]:
544+
def _opencv_to_tfds_channels(shape: utils.Shape) -> Union[int, None, List[int]]:
540545
"""Restore channel in the expected order: OpenCV uses BGR rather than RGB."""
541546
num_channels = shape[-1]
542-
if num_channels == 1:
543-
return 0
544-
elif num_channels == 3:
547+
if num_channels == 3:
545548
return [2, 1, 0] # BGR -> RGB
546549
elif num_channels == 4:
547550
return [2, 1, 0, 3] # BGRa -> RGBa
548551
else:
549552
raise ValueError(f'Unsupported number of channels: {num_channels}')
553+
554+
555+
def _reshape_grayscale_image(
556+
image: np.ndarray, num_channels: int
557+
) -> np.ndarray:
558+
"""Reshape grayscale images: (h, w) or (h, w, 1) -> (h, w, num_channels).
559+
560+
This reproduces the transformation TensorFlow applies to grayscale images.
561+
562+
Args:
563+
image: An image as an np.ndarray.
564+
num_channels: The number of channels in the image feature.
565+
566+
Returns:
567+
The reshaped image.
568+
"""
569+
if image.ndim == 2: # (h, w)
570+
return np.repeat(image[..., None], num_channels, axis=-1)
571+
if image.ndim == 3 and image.shape[2] == 1: # (h, w, 1)
572+
return np.repeat(image, num_channels, axis=-1)
573+
return image

tensorflow_datasets/core/features/image_feature_test.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,17 @@ class ImageFeatureTest(
3333
):
3434

3535
@parameterized.parameters(
36+
# Grayscale images
37+
(np.uint8, np.uint8, 1),
38+
(tf.uint8, np.uint8, 1),
39+
(np.uint16, np.uint16, 1),
40+
(tf.uint16, np.uint16, 1),
41+
# 3 channels
3642
(np.uint8, np.uint8, 3),
3743
(tf.uint8, np.uint8, 3),
3844
(np.uint16, np.uint16, 3),
3945
(tf.uint16, np.uint16, 3),
46+
# 4 channels
4047
(np.uint8, np.uint8, 4),
4148
(tf.uint8, np.uint8, 4),
4249
(np.uint16, np.uint16, 4),
@@ -48,10 +55,12 @@ def test_images(self, dtype, np_dtype, channels):
4855

4956
filename = {
5057
np.uint8: {
58+
1: '6pixels_grayscale.png',
5159
3: '6pixels.png',
5260
4: '6pixels_4chan.png',
5361
},
5462
np.uint16: {
63+
1: '6pixels_grayscale_16bit.png',
5564
3: '6pixels_16bit.png',
5665
4: '6pixels_16bit_4chan.png',
5766
},
@@ -116,6 +125,31 @@ def test_images(self, dtype, np_dtype, channels):
116125
raise_cls_np=ValueError,
117126
raise_msg='dtype should be',
118127
),
128+
],
129+
test_attributes=dict(
130+
_encoding_format=None,
131+
_use_colormap=False,
132+
),
133+
)
134+
135+
@parameterized.parameters(
136+
# 3 channels
137+
(np.uint8, np.uint8, 3),
138+
(tf.uint8, np.uint8, 3),
139+
(np.uint16, np.uint16, 3),
140+
(tf.uint16, np.uint16, 3),
141+
# 4 channels
142+
(np.uint8, np.uint8, 4),
143+
(tf.uint8, np.uint8, 4),
144+
(np.uint16, np.uint16, 4),
145+
(tf.uint16, np.uint16, 4),
146+
)
147+
def test_images_with_invalid_shape(self, dtype, np_dtype, channels):
148+
self.assertFeature(
149+
feature=features_lib.Image(shape=(None, None, channels), dtype=dtype),
150+
shape=(None, None, channels),
151+
dtype=dtype,
152+
tests=[
119153
# Invalid number of dimensions
120154
testing.FeatureExpectationItem(
121155
value=randint(256, size=(128, 128), dtype=np_dtype),
Loading
Loading

0 commit comments

Comments
 (0)