Skip to content

Commit d88e087

Browse files
pierrot0The TensorFlow Datasets Authors
authored and committed
more robust way of inferring Builder name, support builder pkg located anywhere.
PiperOrigin-RevId: 502523472
1 parent d1701a6 commit d88e087

File tree

4 files changed

+117
-8
lines changed

4 files changed

+117
-8
lines changed

tensorflow_datasets/core/registered.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@
5454
# Keep track of Dict[str (module name), List[DatasetCollectionBuilder]]
5555
_MODULE_TO_DATASET_COLLECTIONS = collections.defaultdict(list)
5656

57+
# e.g. for dataset "foo": "tensorflow_datasets.datasets.foo.foo_dataset_builder".
58+
_BUILDER_MODULE_SUFFIX = '_dataset_builder'
59+
5760

5861
class DatasetNotFoundError(ValueError):
5962
"""Exception raised when the dataset cannot be found."""
@@ -169,12 +172,19 @@ def __init_subclass__(cls, skip_registration=False, **kwargs): # pylint: disabl
169172
if not cls.__dict__.get('name'):
170173
if cls.__name__ == 'Builder':
171174
# Config-based builders should be defined with a class named "Builder".
172-
# In such a case, the builder name is extracted from the package name:
173-
# 1 package = 1 dataset!.
174-
module_path_components = cls.__module__.rsplit('.', 2)
175-
if len(module_path_components) != 3:
176-
raise AssertionError(f'module path is too short: {cls.__module__}')
177-
cls.name = module_path_components[1]
175+
# In such a case, the builder name is extracted from the module if it
176+
# follows conventions:
177+
module_name = cls.__module__.rsplit('.', 1)[-1]
178+
if module_name.endswith(_BUILDER_MODULE_SUFFIX):
179+
cls.name = module_name[: -len(_BUILDER_MODULE_SUFFIX)]
180+
elif '.' in cls.__module__: # Extract dataset name from package name.
181+
cls.name = cls.__module__.rsplit('.', 2)[-2]
182+
else:
183+
raise AssertionError(
184+
'When using `Builder` as class name, the dataset builder name is '
185+
'inferred from module name if named "*_dataset_builder" or from '
186+
f'package name, but there is no package in "{cls.__module__}".'
187+
)
178188
else: # Legacy builders.
179189
cls.name = naming.camelcase_to_snakecase(cls.__name__)
180190

@@ -261,7 +271,9 @@ def _get_existing_dataset_packages(
261271
]
262272
if child.name not in exceptions:
263273
pkg_path = epath.Path(datasets_dir_path) / child.name
264-
builder_module = f'{ds_dir_pkg}.{child.name}.{child.name}_dataset_builder'
274+
builder_module = (
275+
f'{ds_dir_pkg}.{child.name}.{child.name}{_BUILDER_MODULE_SUFFIX}'
276+
)
265277
datasets[child.name] = (pkg_path, builder_module)
266278
return datasets
267279

tensorflow_datasets/core/registered_test.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
"""Tests for tensorflow_datasets.core.registered."""
1717

1818
import abc
19+
import re
20+
1921
from unittest import mock
2022
import pytest
2123

@@ -26,6 +28,7 @@
2628
from tensorflow_datasets.core import splits
2729
from tensorflow_datasets.core import utils
2830
from tensorflow_datasets.core.utils import py_utils
31+
import tensorflow_datasets.public_api as tfds
2932
from tensorflow_datasets.testing.dummy_config_based_datasets.dummy_ds_1 import dummy_ds_1_dataset_builder
3033

3134

@@ -278,13 +281,57 @@ def test_name_inferred_from_pkg():
278281
assert ds_builder.name == "dummy_ds_1"
279282

280283

284+
def test_name_inferred_from_pkg_level0_fails():
285+
pkg_path = (
286+
tfds.core.tfds_path() / "testing/dummy_config_based_datasets/dummy_ds_2"
287+
)
288+
expected_msg = (
289+
"When using `Builder` as class name, the dataset builder name is "
290+
'inferred from module name if named "*_dataset_builder" or from '
291+
'package name, but there is no package in "dummy_builder".'
292+
)
293+
with tfds.core.utils.add_sys_path(pkg_path):
294+
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
295+
tfds.core.community.builder_cls_from_module("dummy_builder")
296+
297+
298+
@mock.patch.dict(registered._DATASET_REGISTRY, {})
299+
def test_name_inferred_from_pkg_level1():
300+
pkg_path = tfds.core.tfds_path() / "testing/dummy_config_based_datasets"
301+
with tfds.core.utils.add_sys_path(pkg_path):
302+
ds_builder = tfds.core.community.builder_cls_from_module(
303+
"dummy_ds_2.dummy_builder"
304+
)
305+
assert ds_builder.name == "dummy_ds_2"
306+
307+
308+
@mock.patch.dict(registered._DATASET_REGISTRY, {})
309+
def test_name_inferred_from_pkg_level2():
310+
pkg_path = tfds.core.tfds_path() / "testing"
311+
with tfds.core.utils.add_sys_path(pkg_path):
312+
ds_builder = tfds.core.community.builder_cls_from_module(
313+
"dummy_config_based_datasets.dummy_ds_2.dummy_builder"
314+
)
315+
assert ds_builder.name == "dummy_ds_2"
316+
317+
318+
@mock.patch.dict(registered._DATASET_REGISTRY, {})
319+
def test_name_inferred_from_pkg_level3():
320+
pkg_path = tfds.core.tfds_path()
321+
with tfds.core.utils.add_sys_path(pkg_path):
322+
ds_builder = tfds.core.community.builder_cls_from_module(
323+
"testing.dummy_config_based_datasets.dummy_ds_2.dummy_builder"
324+
)
325+
assert ds_builder.name == "dummy_ds_2"
326+
327+
281328
class ConfigBasedBuildersTest(testing.TestCase):
282329

283330
def test__get_existing_dataset_packages(self):
284331
ds_packages = registered._get_existing_dataset_packages(
285332
"testing/dummy_config_based_datasets"
286333
)
287-
self.assertEqual(list(ds_packages.keys()), ["dummy_ds_1"])
334+
self.assertEqual(list(ds_packages.keys()), ["dummy_ds_1", "dummy_ds_2"])
288335
pkg_path, builder_module = ds_packages["dummy_ds_1"]
289336
self.assertEndsWith(
290337
str(pkg_path),
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Similar to `dummy_ds_1`, but the builder module does not follow
2+
the recommended naming.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# coding=utf-8
2+
# Copyright 2022 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Dummy config-based dataset self-contained in a directory.
17+
18+
The builder module intentionally does not follow the naming conventions.
19+
This is used in registered_test.py to check the logic to infer dataset
20+
name.
21+
"""
22+
23+
from __future__ import annotations
24+
25+
import numpy as np
26+
import tensorflow_datasets.public_api as tfds
27+
28+
29+
class Builder(tfds.core.GeneratorBasedBuilder):
30+
"""Dummy dataset."""
31+
32+
VERSION = tfds.core.Version('1.0.0')
33+
34+
def _info(self):
35+
return self.dataset_info_from_configs(
36+
features=tfds.features.FeaturesDict({'x': np.int64}),
37+
)
38+
39+
def _split_generators(self, dl_manager):
40+
return [
41+
tfds.core.SplitGenerator(
42+
name=tfds.Split.TRAIN,
43+
),
44+
]
45+
46+
def _generate_examples(self):
47+
for i in range(10):
48+
yield i, {'x': i}

0 commit comments

Comments
 (0)