Skip to content

Commit 4735be1

Browse files
fchollettensorflower-gardener
authored andcommitted
Remove private Keras imports.
PiperOrigin-RevId: 563439874
1 parent 4733c85 commit 4735be1

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_transforms.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,17 @@
2929
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
3030

3131
try:
32-
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top
32+
# OSS
33+
import keras # pylint: disable=g-import-not-at-top
34+
if hasattr(keras, 'src'):
35+
# Path as seen in pip packages as of TF/Keras 2.13.
36+
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top
37+
else:
38+
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top,g-importing-member
3339
except ImportError:
34-
# Path as seen in pip packages as of TF/Keras 2.13.
35-
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top
40+
# Internal
41+
unique_object_name = tf._keras_internal.backend.unique_object_name # pylint: disable=protected-access
42+
3643

3744
LayerNode = transforms.LayerNode
3845
LayerPattern = transforms.LayerPattern

0 commit comments

Comments
 (0)