diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index 4186dec68..9b676d4cc 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -2,11 +2,17 @@ hardware: 'tpu' skip_jax_distributed_system: False attention: 'flash' -a2v_attention_kernel: 'flash' +a2v_attention_kernel: 'dot_product' v2a_attention_kernel: 'dot_product' attention_sharding_uniform: True precision: 'bf16' + +# For scanning transformer layers scan_layers: True + +# For scanning diffusion loop +scan_diffusion_loop: True + names_which_can_be_saved: [] names_which_can_be_offloaded: [] remat_policy: "NONE" diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 583695fa7..6a5109fe9 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -287,11 +287,11 @@ def _tpu_flash_attention( ) -> jax.Array: """TPU Flash Attention""" - block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) num_context_shards = mesh.shape["context"] query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) key, _ = _reshape_data_for_flash(key, heads, num_context_shards) value, _ = _reshape_data_for_flash(value, heads, num_context_shards) + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) q_axis_names = nn.logical_to_mesh_axes(axis_names_q) kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) @@ -892,7 +892,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp")), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "mlp")), bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) self.act = get_activation(activation_fn) @@ -904,8 +904,8 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, - kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", None)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None,)), ) def __call__(self, hidden_states: Array) -> Array: diff --git a/src/maxdiffusion/models/ltx2/attention_ltx2.py b/src/maxdiffusion/models/ltx2/attention_ltx2.py index 398b0f473..574f1c38e 100644 --- a/src/maxdiffusion/models/ltx2/attention_ltx2.py +++ b/src/maxdiffusion/models/ltx2/attention_ltx2.py @@ -20,6 +20,7 @@ import jax.numpy as jnp from ... import common_types from ..attention_flax import NNXAttentionOp +from maxdiffusion.tpu_utils import get_tpu_type, TpuType Array = common_types.Array Mesh = common_types.Mesh @@ -349,6 +350,9 @@ def __init__( rope_type: str = "interleaved", flash_block_sizes: BlockSizes = None, flash_min_seq_length: int = 4096, + qkv_sharding_spec: Optional[tuple] = None, + out_sharding_spec: Optional[tuple] = None, + out_bias_sharding_spec: Optional[tuple] = None, ): self.heads = heads self.rope_type = rope_type @@ -356,16 +360,30 @@ def __init__( self.inner_dim = dim_head * heads self.dropout_rate = dropout + # Auto-detect hardware for sharding specs if not overridden + tpu_type = get_tpu_type() + is_ironwood = tpu_type == TpuType.TPU_7X + + # Hardware-aware sharding: Ironwood (v7x) uses 1D sharding along the heads dimension (leaving the embedding dimension replicated) + # to minimize cross-device communication, while other hardware defaults to 2D sharding along both heads and embed dimensions. + # This has currently only been tested on Trillium (v6e) and Ironwood (v7x). + if qkv_sharding_spec is None: + qkv_sharding_spec = (None, "heads") if is_ironwood else ("embed", "heads") + if out_sharding_spec is None: + out_sharding_spec = ("heads", None) if is_ironwood else ("heads", "embed") + if out_bias_sharding_spec is None: + out_bias_sharding_spec = (None,) if is_ironwood else ("embed",) + # 1. Define Partitioned Initializers (Logical Axes) # Q, K, V kernels: [in_features (embed), out_features (heads)] - qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")) + qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), qkv_sharding_spec) # Q, K, V biases: [out_features (heads)] qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",)) # Out kernel: [in_features (heads), out_features (embed)] - out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")) + out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), out_sharding_spec) # Out bias: [out_features (embed)] - out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",)) + out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), out_bias_sharding_spec) # Norm scales norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",)) diff --git a/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py b/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py index 1b43457af..20436f42f 100644 --- a/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py +++ b/src/maxdiffusion/models/ltx2/latent_upsampler_ltx2.py @@ -165,12 +165,12 @@ def __init__(self, in_channels: int, mid_channels: int = 1024, scale: float = 2. in_channels, (num**2) * self.mid_channels, kernel_size=(3, 3), padding=((1, 1), (1, 1)), rngs=rngs ) self.pixel_shuffle = PixelShuffleND(dims=2, upscale_factors=(num, num)) - self.blur = BlurDownsample(dims=2, stride=den) + self.blur_down = BlurDownsample(dims=2, stride=den) def __call__(self, x: jax.Array) -> jax.Array: x = self.conv(x) x = self.pixel_shuffle(x) - x = self.blur(x) + x = self.blur_down(x) return x diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 485d500ef..a3ec9591c 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -28,6 +28,7 @@ from flax import nnx from flax.linen import partitioning as nn_partitioning from transformers import AutoTokenizer, GemmaTokenizer, GemmaTokenizerFast, Gemma3ForConditionalGeneration +from maxdiffusion.tpu_utils import get_tpu_type, TpuType import qwix from ...utils import logging from ...schedulers import FlaxFlowMatchScheduler @@ -127,6 +128,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict): ltx2_config["dtype"] = config.activations_dtype ltx2_config["weights_dtype"] = config.weights_dtype ltx2_config["attention_kernel"] = config.attention + ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash") + ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product") ltx2_config["precision"] = get_precision(config) ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config) ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096) @@ -826,32 +829,67 @@ def encode_prompt( else: batch_size = prompt_embeds.shape[0] - if prompt_embeds is None: - prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - scale_factor=scale_factor, - dtype=dtype, - ) + tpu_type = get_tpu_type() + # Batching text encoder gives better results on Ironwood (v7x) but poor on Trillium (v6e) + use_batched_text_encoder = tpu_type == TpuType.TPU_7X - if do_classifier_free_guidance and negative_prompt_embeds is None: + if use_batched_text_encoder and prompt_embeds is None and do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." - ) + if isinstance(prompt, str): + prompt = [prompt] - negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( - prompt=negative_prompt, + combined_prompts = prompt + negative_prompt + + combined_embeds, combined_mask = self._get_gemma_prompt_embeds( + prompt=combined_prompts, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, scale_factor=scale_factor, dtype=dtype, ) + split_idx = batch_size * num_videos_per_prompt + + if isinstance(combined_embeds, list): + prompt_embeds = [state[:split_idx] for state in combined_embeds] + negative_prompt_embeds = [state[split_idx:] for state in combined_embeds] + else: + prompt_embeds = combined_embeds[:split_idx] + negative_prompt_embeds = combined_embeds[split_idx:] + + prompt_attention_mask = combined_mask[:split_idx] + negative_prompt_attention_mask = combined_mask[split_idx:] + else: + # Non-batched path (Sequential) + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + dtype=dtype, + ) + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask def check_inputs( @@ -1315,62 +1353,90 @@ def __call__( audio_embeds_sharded = jax.device_put(audio_embeds, spec) timesteps_jax = jnp.array(timesteps, dtype=jnp.float32) - for i in range(len(timesteps_jax)): - t = timesteps_jax[i] - # Isolate input sharding to scan_layers=False to avoid affecting the standard path - latents_jax_sharded = latents_jax - audio_latents_jax_sharded = audio_latents_jax + scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True) - if not self.transformer.scan_layers: - activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) - latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names) - audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names) - - noise_pred, noise_pred_audio = transformer_forward_pass( + if scan_diffusion_loop: + latents_jax, audio_latents_jax = run_diffusion_loop( graphdef, state, - latents_jax_sharded, - audio_latents_jax_sharded, - t, + scheduler_state, + timesteps_jax, + latents_jax, + audio_latents_jax, video_embeds_sharded, audio_embeds_sharded, new_attention_mask, - new_attention_mask, - guidance_scale > 1.0, guidance_scale, latent_num_frames, latent_height, latent_width, audio_num_frames, frame_rate, + batch_size, + self.transformer.scan_layers, + self.scheduler.step, + tuple(tuple(rule) if isinstance(rule, list) else rule for rule in self.config.logical_axis_rules), ) - - if guidance_scale > 1.0: - noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # Audio guidance - noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0) - noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) - - latents_step = latents_jax[batch_size:] - audio_latents_step = audio_latents_jax[batch_size:] - else: - latents_step = latents_jax - audio_latents_step = audio_latents_jax - - # Step - latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False) - audio_latents_step, _ = self.scheduler.step( - scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False - ) - - if guidance_scale > 1.0: - latents_jax = jnp.concatenate([latents_step] * 2, axis=0) - audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0) - else: - latents_jax = latents_step - audio_latents_jax = audio_latents_step + else: + # Old Python loop path + for i in range(len(timesteps_jax)): + t = timesteps_jax[i] + + # Isolate input sharding to scan_layers=False to avoid affecting the standard path + latents_jax_sharded = latents_jax + audio_latents_jax_sharded = audio_latents_jax + + if not self.transformer.scan_layers: + activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) + latents_jax_sharded = jax.lax.with_sharding_constraint(latents_jax, activation_axis_names) + audio_latents_jax_sharded = jax.lax.with_sharding_constraint(audio_latents_jax, activation_axis_names) + + noise_pred, noise_pred_audio = transformer_forward_pass( + graphdef, + state, + latents_jax_sharded, + audio_latents_jax_sharded, + t, + video_embeds_sharded, + audio_embeds_sharded, + new_attention_mask, + new_attention_mask, + guidance_scale > 1.0, + guidance_scale, + latent_num_frames, + latent_height, + latent_width, + audio_num_frames, + frame_rate, + ) + + if guidance_scale > 1.0: + noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # Audio guidance + noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0) + noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + + latents_step = latents_jax[batch_size:] + audio_latents_step = audio_latents_jax[batch_size:] + else: + latents_step = latents_jax + audio_latents_step = audio_latents_jax + + # Step + latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False) + + audio_latents_step, _ = self.scheduler.step( + scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False + ) + + if guidance_scale > 1.0: + latents_jax = jnp.concatenate([latents_step] * 2, axis=0) + audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0) + else: + latents_jax = latents_step + audio_latents_jax = audio_latents_step # 8. Decode Latents if guidance_scale > 1.0: @@ -1543,3 +1609,120 @@ def transformer_forward_pass( ) return noise_pred, noise_pred_audio + + +@partial( + jax.jit, + static_argnames=( + "guidance_scale", + "latent_num_frames", + "latent_height", + "latent_width", + "audio_num_frames", + "fps", + "batch_size", + "scan_layers", + "scheduler_step", + "logical_axis_rules", + ), +) +def run_diffusion_loop( + graphdef, + state, + scheduler_state, + timesteps_jax, + latents_jax, + audio_latents_jax, + video_embeds_sharded, + audio_embeds_sharded, + new_attention_mask, + guidance_scale, + latent_num_frames, + latent_height, + latent_width, + audio_num_frames, + fps, + batch_size, + scan_layers, + scheduler_step, + logical_axis_rules, +): + latents_jax = latents_jax.astype(jnp.float32) + audio_latents_jax = audio_latents_jax.astype(jnp.float32) + transformer = nnx.merge(graphdef, state) + + def scan_body(carry, t, model): + latents, audio_latents, s_state = carry + + with nn_partitioning.axis_rules(logical_axis_rules): + latents_sharded = latents + audio_latents_sharded = audio_latents + + if not scan_layers: + activation_axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed")) + latents_sharded = jax.lax.with_sharding_constraint(latents, activation_axis_names) + audio_latents_sharded = jax.lax.with_sharding_constraint(audio_latents, activation_axis_names) + + # Expand timestep to batch size + t_expanded = jnp.expand_dims(t, 0).repeat(latents.shape[0]) + + noise_pred, noise_pred_audio = model( + hidden_states=latents_sharded, + encoder_hidden_states=video_embeds_sharded, + timestep=t_expanded, + encoder_attention_mask=new_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + audio_hidden_states=audio_latents_sharded, + audio_encoder_hidden_states=audio_embeds_sharded, + audio_encoder_attention_mask=new_attention_mask, + fps=fps, + audio_num_frames=audio_num_frames, + return_dict=False, + ) + + if guidance_scale > 1.0: + noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # Audio guidance + ( + noise_pred_audio_uncond, + noise_pred_audio_text, + ) = jnp.split(noise_pred_audio, 2, axis=0) + noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + + latents_step = latents[batch_size:] + audio_latents_step = audio_latents[batch_size:] + else: + latents_step = latents + audio_latents_step = audio_latents + + # Step scheduler + latents_step, _ = scheduler_step(s_state, noise_pred, t, latents_step, return_dict=False) + latents_step = latents_step.astype(latents.dtype) + + audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False) + audio_latents_step = audio_latents_step.astype(audio_latents.dtype) + + if guidance_scale > 1.0: + latents_next = jnp.concatenate([latents_step] * 2, axis=0) + audio_latents_next = jnp.concatenate([audio_latents_step] * 2, axis=0) + else: + latents_next = latents_step + audio_latents_next = audio_latents_step + + new_carry = (latents_next, audio_latents_next, s_state) + return new_carry, None + + # Initial carry + initial_carry = (latents_jax, audio_latents_jax, scheduler_state) + + # Run scan + final_carry, _ = nnx.scan( + scan_body, + in_axes=(nnx.Carry, 0, None), + out_axes=(nnx.Carry, 0), + )(initial_carry, timesteps_jax, transformer) + + return final_carry[0], final_carry[1] diff --git a/src/maxdiffusion/tests/ltx2/test_pipeline_ltx2.py b/src/maxdiffusion/tests/ltx2/test_pipeline_ltx2.py index 62d96775b..b3f13a6da 100644 --- a/src/maxdiffusion/tests/ltx2/test_pipeline_ltx2.py +++ b/src/maxdiffusion/tests/ltx2/test_pipeline_ltx2.py @@ -118,9 +118,14 @@ def test_check_inputs(self): with self.assertRaises(ValueError): pipeline.check_inputs(prompt="test", height=64, width=63) + @patch("maxdiffusion.pipelines.ltx2.ltx2_pipeline.get_tpu_type") @patch("maxdiffusion.pipelines.ltx2.ltx2_pipeline.LTX2Pipeline._get_gemma_prompt_embeds") - def test_encode_prompt(self, list_embed_mock): + def test_encode_prompt(self, list_embed_mock, mock_get_tpu_type): """Test conditional encoding of positive and negative prompts.""" + from maxdiffusion.tpu_utils import TpuType + + mock_get_tpu_type.return_value = TpuType.TPU_7X + pipeline = LTX2Pipeline( scheduler=MagicMock(), vae=MagicMock(), @@ -132,29 +137,24 @@ def test_encode_prompt(self, list_embed_mock): vocoder=MagicMock(), ) - prompt_embeds = jnp.zeros((1, 10, 10)) - prompt_attention_mask = jnp.ones((1, 10)) - neg_prompt_embeds = jnp.zeros((1, 10, 10)) - neg_prompt_attention_mask = jnp.ones((1, 10)) + combined_embeds = jnp.zeros((2, 10, 10)) + combined_attention_mask = jnp.ones((2, 10)) - # Mock return values for positive then negative prompt encoding - list_embed_mock.side_effect = [ - (prompt_embeds, prompt_attention_mask), - (neg_prompt_embeds, neg_prompt_attention_mask), - ] + # Mock return value for combined prompt encoding + list_embed_mock.return_value = (combined_embeds, combined_attention_mask) p_e, p_a, n_e, n_a = pipeline.encode_prompt( prompt=["A cute cat"], negative_prompt=["ugly"], do_classifier_free_guidance=True ) # Check mock calls - self.assertEqual(list_embed_mock.call_count, 2) + self.assertEqual(list_embed_mock.call_count, 1) # Check returns - np.testing.assert_array_equal(p_e, prompt_embeds) - np.testing.assert_array_equal(p_a, prompt_attention_mask) - np.testing.assert_array_equal(n_e, neg_prompt_embeds) - np.testing.assert_array_equal(n_a, neg_prompt_attention_mask) + np.testing.assert_array_equal(p_e, combined_embeds[:1]) + np.testing.assert_array_equal(p_a, combined_attention_mask[:1]) + np.testing.assert_array_equal(n_e, combined_embeds[1:]) + np.testing.assert_array_equal(n_a, combined_attention_mask[1:]) @patch("maxdiffusion.pipelines.ltx2.ltx2_pipeline.LTX2Pipeline._get_gemma_prompt_embeds") def test_encode_prompt_no_cfg(self, list_embed_mock): diff --git a/src/maxdiffusion/tpu_utils.py b/src/maxdiffusion/tpu_utils.py index 5697f60cd..7f80a6478 100644 --- a/src/maxdiffusion/tpu_utils.py +++ b/src/maxdiffusion/tpu_utils.py @@ -15,6 +15,7 @@ """ import jax +from enum import Enum def print_device_memory_info(devices): @@ -42,3 +43,23 @@ def print_array_info(array, name): for device_idx in num_devices: jax.debug.print("shape on device {x} : {y}", x=device_idx, y=array.device_buffers[0].shape) jax.debug.print("size on device {x} : {y}", x=device_idx, y=array.device_buffers[device_idx].size / array.size) + + +class TpuType(Enum): + TPU_V6_LITE = "v6e" + TPU_7X = "v7x" + UNKNOWN = "unknown" + + +def get_tpu_type() -> TpuType: + """Detects the current TPU hardware generation.""" + try: + device_kind = jax.devices()[0].device_kind + if "7x" in device_kind: + return TpuType.TPU_7X + elif "v6 lite" in device_kind: + return TpuType.TPU_V6_LITE + else: + return TpuType.UNKNOWN + except Exception: + return TpuType.UNKNOWN