From e4e103854038532dd4798d6662a0017b97f4d0a1 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Fri, 3 Jul 2026 02:34:08 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 942036435 --- export/orbax/export/modules/tensorflow_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/export/orbax/export/modules/tensorflow_module.py b/export/orbax/export/modules/tensorflow_module.py index 99564565e..d1e59ad19 100644 --- a/export/orbax/export/modules/tensorflow_module.py +++ b/export/orbax/export/modules/tensorflow_module.py @@ -280,7 +280,7 @@ def jax_params_to_tf_variables( ) -> PyTree: """Converts `params` to tf.Variables in the same pytree structure.""" mesh = dtensor_utils.get_current_mesh() - default_cpu_device = tf.config.list_logical_devices('CPU')[0] + default_cpu_device = tf.config.list_logical_devices('CPU')[0].name if mesh is not None: if pspecs is None: raise ValueError(