Skip to content

Commit 8fd12f2

Browse files
fchollettensorflower-gardener
authored andcommitted
Remove private Keras imports.
PiperOrigin-RevId: 564240456
1 parent 4735be1 commit 8fd12f2

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_transforms.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@
2828
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry
2929
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
3030

31-
3231
try:
33-
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top
32+
import keras # pylint: disable=g-import-not-at-top
33+
if hasattr(keras, 'src'):
34+
# Path as seen in pip packages as of TF/Keras 2.13.
35+
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top,g-importing-member
36+
else:
37+
from keras.backend import unique_object_name # pylint: disable=g-import-not-at-top,g-importing-member
3438
except ImportError:
35-
# Path as seen in pip packages as of TF/Keras 2.13.
36-
from keras.src.backend import unique_object_name # pylint: disable=g-import-not-at-top
39+
unique_object_name = tf._keras_internal.backend.unique_object_name # pylint: disable=protected-access
3740

3841
LayerNode = transforms.LayerNode
3942
LayerPattern = transforms.LayerPattern

0 commit comments

Comments
 (0)