ai-edge-torch-nightly 0.6.0.dev20250619__py3-none-any.whl → 0.6.0.dev20250815__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +2 -0
  2. ai_edge_torch/_convert/test/test_convert.py +18 -0
  3. ai_edge_torch/generative/examples/gemma/gemma1.py +43 -8
  4. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +8 -0
  5. ai_edge_torch/generative/examples/gemma3/decoder.py +76 -0
  6. ai_edge_torch/generative/examples/gemma3/gemma3.py +19 -0
  7. ai_edge_torch/generative/examples/gemma3/image_encoder.py +11 -10
  8. ai_edge_torch/generative/layers/attention_utils.py +9 -3
  9. ai_edge_torch/generative/layers/einsum.py +8 -2
  10. ai_edge_torch/generative/quantize/quant_recipes.py +0 -1
  11. ai_edge_torch/lowertools/translate_recipe.py +2 -2
  12. ai_edge_torch/odml_torch/export.py +3 -2
  13. ai_edge_torch/odml_torch/lowerings/_basic.py +12 -0
  14. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +143 -8
  15. ai_edge_torch/odml_torch/optimization_barrier.py +71 -0
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250815.dist-info}/METADATA +2 -2
  18. {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250815.dist-info}/RECORD +21 -20
  19. {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250815.dist-info}/LICENSE +0 -0
  20. {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250815.dist-info}/WHEEL +0 -0
  21. {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250815.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
  from typing import Any, Callable
  from ai_edge_torch import fx_infra
  from ai_edge_torch import lowertools
+ from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
  import torch
  import torch.utils._pytree as pytree

@@ -276,6 +277,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
      # Explicitly reshape back to the original shape. This places the ReshapeOp
      # outside of the HLFB.
      output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
+     output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
      return output

    node.target = embedding
@@ -576,6 +576,24 @@ class TestConvert(googletest.TestCase):
        self.fail(f"Conversion failed with bfloat16 inputs: {err}")
      # pylint: enable=broad-except

+   def test_convert_model_with_torch_linspace_operation(self):
+     """Tests converting a simple model with a torch.linspace operation."""
+
+     class SampleModel(nn.Module):
+
+       def forward(self, x: torch.Tensor):
+         return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)
+
+     model = SampleModel().eval()
+     args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)
+
+     try:
+       # Expect the conversion to succeed without errors.
+       ai_edge_torch.convert(model, args)
+     except Exception as err:
+       self.fail(f"Conversion failed with int64 inputs: {err}")
+     # pylint: enable=broad-except
+
    def test_compile_model(self):
      """Tests AOT compilation of a simple Add module."""

@@ -23,7 +23,7 @@ import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
  from torch import nn

- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
      ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -36,6 +36,24 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      lm_head=None,
  )

+ TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     ff_gate_proj="model.layers.{}.mlp.gate_proj",
+     attn_query_proj="model.layers.{}.self_attn.q_proj",
+     attn_key_proj="model.layers.{}.self_attn.k_proj",
+     attn_value_proj="model.layers.{}.self_attn.v_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+ )
+
+ TENSOR_NAMES_DICT = {
+     "safetensors": TENSOR_NAMES_SEP_QKV,
+     "kaggle": TENSOR_NAMES_FUSED_QKV,
+ }

  class Gemma1(model_builder.DecoderOnlyModel):
    """A Gemma1 model built from the Edge Generative API layers."""
@@ -94,11 +112,28 @@ def build_2b_model(
      custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
      mask_cache_size: int = 0,
  ) -> nn.Module:
-   return model_builder.build_decoder_only_model(
-       checkpoint_path=checkpoint_path,
-       config=get_model_config_2b(),
-       tensor_names=TENSOR_NAMES,
-       model_class=Gemma1,
-       custom_loader=custom_loader,
-       mask_cache_size=mask_cache_size,
+
+   # A list to store the reasons for each failure
+   key_errors = []
+
+   for tensor_names in TENSOR_NAMES_DICT.values():
+     try:
+       return model_builder.build_decoder_only_model(
+           checkpoint_path=checkpoint_path,
+           config=get_model_config_2b(),
+           tensor_names=tensor_names,
+           model_class=Gemma1,
+           custom_loader=custom_loader,
+           mask_cache_size=mask_cache_size,
+       )
+     except KeyError as ke:
+       # Store the specific key that was missing for later
+       key_errors.append(f"Missing key: {ke}")
+       continue
+
+   # If the loop finishes, raise an error with all the collected details
+   error_details = "\n".join(key_errors)
+   raise RuntimeError(
+       "Failed to build model after trying all configurations. "
+       f"Encountered the following errors:\n{error_details}"
    )
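
For context, a minimal usage sketch of the new fallback behavior (the checkpoint path is a placeholder):

    from ai_edge_torch.generative.examples.gemma import gemma1

    # Works for both safetensors-style and Kaggle-style checkpoints; the
    # builder tries each tensor-name mapping in TENSOR_NAMES_DICT in turn
    # and only raises once every mapping has failed with a KeyError.
    model = gemma1.build_2b_model("/path/to/gemma1_2b", mask_cache_size=1024)
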
@@ -42,6 +42,14 @@ def main(_):
          ),
          mask_cache_size=converter.get_mask_cache_size_from_flags(),
      )
+   elif _MODEL_SIZE.value == '270m':
+     pytorch_model = gemma3.build_model_270m(
+         checkpoint_path,
+         custom_loader=loader.maybe_get_custom_loader(
+             checkpoint_path, flags.FLAGS.custom_checkpoint_loader
+         ),
+         mask_cache_size=converter.get_mask_cache_size_from_flags(),
+     )
    else:
      raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')

@@ -391,6 +391,60 @@ def get_decoder_config_1b() -> cfg.ModelConfig:
    return config


+ def get_decoder_config_270m() -> cfg.ModelConfig:
+   """Returns the model config for a Gemma3 270M model."""
+   norm_config = cfg.NormalizationConfig(
+       type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+       intermediate_size=2048,
+       pre_ff_norm_config=norm_config,
+       post_ff_norm_config=norm_config,
+   )
+
+   def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+     attn_config = cfg.AttentionConfig(
+         num_heads=4,
+         head_dim=256,
+         num_query_groups=1,
+         rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
+         rotary_percentage=1.0,
+         qkv_transpose_before_split=True,
+         query_norm_config=norm_config,
+         key_norm_config=norm_config,
+         logit_softcap=None,
+         sliding_window_size=512,
+         attn_type=(
+             cfg.AttentionType.GLOBAL
+             if (idx + 1) % 6 == 0
+             else cfg.AttentionType.LOCAL_SLIDING
+         ),
+     )
+     return cfg.TransformerBlockConfig(
+         attn_config=attn_config,
+         ff_config=ff_config,
+         pre_attention_norm_config=norm_config,
+         post_attention_norm_config=norm_config,
+     )
+
+   num_layers = 18
+   embedding_dim = 640
+   config = cfg.ModelConfig(
+       vocab_size=262_144,
+       num_layers=num_layers,
+       max_seq_len=32_768,
+       embedding_dim=embedding_dim,
+       embedding_scale=embedding_dim**0.5,
+       block_configs=[get_block_config(i) for i in range(num_layers)],
+       final_norm_config=norm_config,
+       lm_head_use_bias=False,
+       final_logit_softcap=None,
+   )
+   return config
+
+
  def get_fake_decoder_config_1b() -> cfg.ModelConfig:
    """Returns a fake model config for a Gemma3 1B model."""
    config = get_decoder_config_1b()
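
A quick way to inspect the attention layout this config encodes (a sketch, assuming the `cfg` config classes expose their fields as attributes):

    from ai_edge_torch.generative.examples.gemma3 import decoder

    config = decoder.get_decoder_config_270m()
    # Every 6th block (indices 5, 11, 17) uses global attention; the rest
    # use 512-token local sliding-window attention.
    attn_types = [b.attn_config.attn_type for b in config.block_configs]
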
@@ -427,3 +481,25 @@ def build_model_1b(
        )
      except KeyError as ke:
        continue


+ def build_model_270m(
+     checkpoint_path: str,
+     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+     mask_cache_size: int = 0,
+ ) -> nn.Module:
+   """Builds a Gemma3 270M model."""
+   # TODO(b/403644647): Better error handling for loading checkpoints with
+   # different tensor names.
+   for tensor_names in TENSOR_NAMES_DICT.values():
+     try:
+       return model_builder.build_decoder_only_model(
+           checkpoint_path=checkpoint_path,
+           config=get_decoder_config_270m(),
+           tensor_names=tensor_names,
+           model_class=Decoder,
+           custom_loader=custom_loader,
+           mask_cache_size=mask_cache_size,
+       )
+     except KeyError as _:
+       continue
@@ -179,3 +179,22 @@ def build_model_1b(
      # TODO: Load the parameters of decoder from checkpoint.
    model.eval()
    return model


+ def build_model_270m(
+     checkpoint_path: str,
+     custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
+     mask_cache_size: int = 0,
+ ) -> decoder.Decoder:
+   """Builds a Gemma3 270M model."""
+   if checkpoint_path:
+     model = decoder.build_model_270m(
+         checkpoint_path, custom_loader, mask_cache_size
+     )
+   else:
+     config = decoder.get_decoder_config_270m()
+     model = decoder.Decoder(config, mask_cache_size)
+     # TODO: Load the parameters of decoder from checkpoint.
+   model.eval()
+   return model
+
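
A minimal sketch of building the new 270M variant; per the code above, an empty checkpoint path yields a randomly initialized decoder, which is enough for conversion smoke tests:

    from ai_edge_torch.generative.examples.gemma3 import gemma3

    # No checkpoint: decoder is constructed from get_decoder_config_270m()
    # with uninitialized weights and put in eval mode.
    model = gemma3.build_model_270m("", mask_cache_size=1024)
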
@@ -24,26 +24,27 @@ import torch.nn.functional as F


  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
-     ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
+     ff_up_proj="siglip_vision_model.encoder_blocks.{}.mlp.fc1",
+     ff_down_proj="siglip_vision_model.encoder_blocks.{}.mlp.fc2",
      attn_query_proj=(
-         "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
+         "siglip_vision_model.encoder_blocks.{}.self_attn.q_proj"
      ),
      attn_key_proj=(
-         "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
+         "siglip_vision_model.encoder_blocks.{}.self_attn.k_proj"
      ),
      attn_value_proj=(
-         "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
+         "siglip_vision_model.encoder_blocks.{}.self_attn.v_proj"
      ),
      attn_output_proj=(
-         "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
+         "siglip_vision_model.encoder_blocks.{}.self_attn.o_proj"
      ),
-     pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
-     embedding="vision_tower.vision_model.embeddings.patch_embedding",
+     pre_attn_norm="siglip_vision_model.encoder_blocks.{}.layer_norm1",
+     pre_ff_norm="siglip_vision_model.encoder_blocks.{}.layer_norm2",
+     embedding="siglip_vision_model.patch_embedding",
      embedding_position=(
-         "vision_tower.vision_model.embeddings.position_embedding.weight"
+         "siglip_vision_model.position_embedding.weight"
      ),
-     final_norm="vision_tower.vision_model.post_layernorm",
+     final_norm="siglip_vision_model.final_norm",
  )


@@ -61,6 +61,7 @@ def build_causal_mask_cache(
      size: int,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
+     mask_value: float = float('-inf'),
  ) -> torch.Tensor:
    """Build a cache for causal attention mask.

@@ -70,6 +71,8 @@ def build_causal_mask_cache(
        torch.float32.
      device (torch.device, optional): Output tensor's device. Defaults to
        None in which case "cpu" is used.
+     mask_value (float, optional): The value to set the mask to. Defaults to
+       float('-inf').

    Returns:
      torch.Tensor: Causal attention mask.
@@ -77,7 +80,7 @@ def build_causal_mask_cache(

    if device is None:
      device = torch.device('cpu')
-   mask = torch.full((size, size), float('-inf'), dtype=dtype, device=device)
+   mask = torch.full((size, size), mask_value, dtype=dtype, device=device)
    return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)


@@ -86,6 +89,7 @@ def build_sliding_window_mask_cache(
      window_size: int,
      dtype: torch.dtype = torch.float32,
      device: torch.device = None,
+     mask_value: float = float('-inf'),
  ) -> torch.Tensor:
    """Build a cache for a sliding window mask.

@@ -96,18 +100,20 @@ def build_sliding_window_mask_cache(
        torch.float32.
      device (torch.device, optional): Output tensor's device. Defaults to
        None in which case "cpu" is used.
+     mask_value (float, optional): The value to set the mask to. Defaults to
+       float('-inf').

    Returns:
      torch.Tensor: Sliding window attention mask.
    """

-   mask = build_causal_mask_cache(size, dtype, device)
+   mask = build_causal_mask_cache(size, dtype, device, mask_value)
    all_ones = torch.ones_like(mask)
    window_size = min(size, window_size)
    sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
        all_ones, window_size - 1
    )
-   return torch.where(sliding_mask == 1, mask, float('-inf'))
+   return torch.where(sliding_mask == 1, mask, mask_value)


  def relative_position_bucket(
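
The new `mask_value` parameter lets callers substitute a finite sentinel for `-inf`, e.g. so masks stay representable after dtype narrowing. A minimal sketch:

    import torch
    from ai_edge_torch.generative.layers import attention_utils

    # Use the most negative float32 instead of -inf.
    min_val = torch.finfo(torch.float32).min
    causal = attention_utils.build_causal_mask_cache(size=8, mask_value=min_val)
    sliding = attention_utils.build_sliding_window_mask_cache(
        size=8, window_size=4, mask_value=min_val
    )
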
@@ -14,7 +14,7 @@
  # ==============================================================================
  # Einsum layer implementation.

- from typing import Sequence
+ from typing import Callable, Sequence
  import torch
  from torch import nn

@@ -22,7 +22,12 @@ from torch import nn
  class Einsum(nn.Module):
    """Einsum layer wrapping over torch.einsum."""

-   def __init__(self, shape: Sequence[int], einsum_str: str):
+   def __init__(
+       self,
+       shape: Sequence[int],
+       einsum_str: str,
+       init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
+   ):
      super().__init__()
      self.shape = shape
      self.einsum_str = einsum_str
@@ -30,6 +35,7 @@ class Einsum(nn.Module):
          torch.empty(shape, dtype=torch.float32),
          requires_grad=False,
      )
+     init_fn(self.w)
      self.einsum_fn = lambda x: torch.einsum(einsum_str, x, self.w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
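
A short sketch of the new `init_fn` hook; the shape and einsum string here are illustrative:

    import torch
    from ai_edge_torch.generative.layers.einsum import Einsum

    # init_fn fills the weight in place instead of leaving torch.empty
    # garbage; any torch.nn.init function fits the Callable signature.
    layer = Einsum(shape=(8, 4), einsum_str="bd,dk->bk", init_fn=torch.nn.init.ones_)
    y = layer(torch.randn(2, 8))  # y.shape == (2, 4)
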
@@ -76,7 +76,6 @@ def all_supported_int4_dynamic_block_recipe(
          default=quant_recipe_utils.create_layer_quant_int4_dynamic_block(
              block_size
          ),
-         embedding=quant_recipe_utils.create_layer_quant_int8_dynamic(),
          _model_config=mcfg,
      )
  )
@@ -25,13 +25,13 @@ _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig

  _DEFAULT_REGEX_STR = '.*'
  _SINGULAR_TRANSFORMER_BLOCK_REGEX_STR = 'transformer_block'
- _IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
+ _IDX_TRANSFORMER_BLOCKS_REGEX_STR = r'transformer_blocks\[{}\]'
  _ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
  _FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
  _EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
  # TODO: b/415833584 - Improve the regex for pre-softmax layer.
  _DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall'
- _ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
+ _ANY_TWO_DIGITS_REGEX_STR = r'\d{1,2}'


  def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
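
The raw-string prefix only changes how Python parses the literal, not the compiled pattern: `'\['` and `'\d'` in a plain string are invalid escape sequences that newer Python versions warn about. A quick illustration:

    import re

    # Identical pattern to the non-raw literal, but warning-free.
    pattern = r'transformer_blocks\[{}\]'.format(12)
    assert re.search(pattern, 'transformer_blocks[12]/attention')
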
@@ -21,7 +21,6 @@ import operator
  from typing import Any, Callable, Optional

  from ai_edge_torch import fx_infra
- from jax.lib import xla_extension
  from jax._src.lib.mlir import ir
  from jax._src.lib.mlir.dialects import func
  from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -35,6 +34,8 @@ from . import lowerings

  LoweringContext = lowerings.context.LoweringContext

+ from jaxlib._jax.mlir import serialize_portable_artifact
+

  def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
    """Build flattened inputs and metadata from exported program's signature."""
@@ -233,7 +234,7 @@ class MlirLowered:
      target_version = stablehlo.get_version_from_compatibility_requirement(
          stablehlo.StablehloCompatibilityRequirement.WEEK_12
      )
-     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
+     module_bytecode = serialize_portable_artifact(
          self.module_bytecode, target_version
      )
      return module_bytecode
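
This tracks the relocation of the helper in recent jaxlib releases, which dropped `jax.lib.xla_extension`. A hedged compatibility shim (not part of this package) could look like:

    try:
      # Newer jaxlib releases expose the StableHLO serializer here.
      from jaxlib._jax.mlir import serialize_portable_artifact
    except ImportError:
      # Hypothetical fallback for older jaxlib releases.
      from jax.lib import xla_extension
      serialize_portable_artifact = xla_extension.mlir.serialize_portable_artifact
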
@@ -331,6 +331,18 @@ def _aten_sym_size_int(lctx, x: ir.Value, dim: int):
    return stablehlo.get_dimension_size(x, dim)


+ # Lowering for the addition operator (`+`).
+ # Handles cases where one operand is an integer (scalar) and the other is a
+ # tensor, broadcasting the scalar to the tensor's shape before addition.
+ @lower(operator.add)
+ def _operator_add(lctx, self: int | ir.Value, other: int | ir.Value):
+   if isinstance(self, int) and isinstance(other, ir.Value):
+     self = utils.splat(self, other.type.element_type, other.type.shape)
+   if isinstance(other, int) and isinstance(self, ir.Value):
+     other = utils.splat(other, self.type.element_type, self.type.shape)
+   return stablehlo.add(self, other)
+
+
  # Lowering for the subtraction operator (`-`).
  # Handles cases where one operand is an integer (scalar) and the other is a
  # tensor, broadcasting the scalar to the tensor's shape before subtraction.
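
Plain `operator.add` nodes like this typically appear when Python ints are combined with symbolic shape values; a hypothetical module that would plausibly exercise the lowering when its first dimension is marked dynamic at export time:

    import torch

    class AddSeqLen(torch.nn.Module):

      def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With a dynamic dim, x.shape[0] is symbolic and `+ 1` is traced
        # as an operator.add node rather than aten.add.
        return x + (x.shape[0] + 1)
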
@@ -78,8 +78,6 @@ lower_by_torch_xla2(torch.ops.aten._unsafe_index)
  lower_by_torch_xla2(torch.ops.aten._unsafe_view)
  lower_by_torch_xla2(torch.ops.aten.acos)
  lower_by_torch_xla2(torch.ops.aten.acosh)
- lower_by_torch_xla2(torch.ops.aten.add.Scalar)
- lower_by_torch_xla2(torch.ops.aten.add.Tensor)
  lower_by_torch_xla2(torch.ops.aten.addbmm.default)
  lower_by_torch_xla2(torch.ops.aten.addmm)
  lower_by_torch_xla2(torch.ops.aten.addmv)
@@ -159,7 +157,6 @@ lower_by_torch_xla2(torch.ops.aten.logical_and)
  lower_by_torch_xla2(torch.ops.aten.logical_not)
  lower_by_torch_xla2(torch.ops.aten.logical_or)
  lower_by_torch_xla2(torch.ops.aten.logical_xor)
- lower_by_torch_xla2(torch.ops.aten.lt)
  lower_by_torch_xla2(torch.ops.aten.max)
  lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices)
  lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
@@ -170,8 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mean)
  lower_by_torch_xla2(torch.ops.aten.min)
  lower_by_torch_xla2(torch.ops.aten.minimum)
  lower_by_torch_xla2(torch.ops.aten.mm)
- lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
- lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
  lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
  lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
  lower_by_torch_xla2(torch.ops.aten.ne)
@@ -215,8 +210,6 @@ lower_by_torch_xla2(torch.ops.aten.sqrt)
  lower_by_torch_xla2(torch.ops.aten.squeeze)
  lower_by_torch_xla2(torch.ops.aten.squeeze_copy)
  lower_by_torch_xla2(torch.ops.aten.stack)
- lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
- lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
  lower_by_torch_xla2(torch.ops.aten.sum)
  lower_by_torch_xla2(torch.ops.aten.t)
  lower_by_torch_xla2(torch.ops.aten.tan)
@@ -244,7 +237,6 @@ lower_by_torch_xla2(torch.ops.aten.view_as_real)
  lower_by_torch_xla2(torch.ops.aten.view_copy)
  lower_by_torch_xla2(torch.ops.aten.where.ScalarOther)
  lower_by_torch_xla2(torch.ops.aten.where.ScalarSelf)
- lower_by_torch_xla2(torch.ops.aten.where.self)
  lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim)
  lower_by_torch_xla2(torch.ops.prims.var)

@@ -259,6 +251,149 @@ def _aten_copy(self, src, **kwargs):
    return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)


+ @registry.lower(torch.ops.aten.add.Scalar)
+ def _aten_add_scalar(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.add.Scalar)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     other_dtype = jnp.result_type(other)
+     promoted_type = jnp.promote_types(self.dtype, other_dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.add(
+         self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+     )
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.add.Tensor)
+ def _aten_add_tensor(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.add.Tensor)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     promoted_type = jnp.promote_types(self.dtype, other.dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.add(self.astype(promoted_type), other.astype(promoted_type))
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.sub.Scalar)
+ def _aten_sub_scalar(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.sub.Scalar)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     other_dtype = jnp.result_type(other)
+     promoted_type = jnp.promote_types(self.dtype, other_dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.subtract(
+         self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+     )
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.sub.Tensor)
+ def _aten_sub_tensor(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.sub.Tensor)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     promoted_type = jnp.promote_types(self.dtype, other.dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.subtract(self.astype(promoted_type), other.astype(promoted_type))
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.lt.Scalar)
+ def _aten_lt_scalar(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.lt.Scalar)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     other_dtype = jnp.result_type(other)
+     promoted_type = jnp.promote_types(self.dtype, other_dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.less(
+         self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+     )
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.lt.Tensor)
+ def _aten_lt_tensor(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.lt.Tensor)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     promoted_type = jnp.promote_types(self.dtype, other.dtype)
+     return jnp.less(self.astype(promoted_type), other.astype(promoted_type))
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.mul.Scalar)
+ def _aten_mul_scalar(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.mul.Scalar)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     other_dtype = jnp.result_type(other)
+     promoted_type = jnp.promote_types(self.dtype, other_dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.multiply(
+         self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+     )
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.mul.Tensor)
+ def _aten_mul_tensor(lctx: LoweringContext, self, other):
+   _log_usage(torch.ops.aten.mul.Tensor)
+
+   @jax_bridge.wrap
+   def jax_lowering(self, other):
+     other_dtype = jnp.result_type(other)
+     promoted_type = jnp.promote_types(self.dtype, other_dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.multiply(
+         self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
+     )
+
+   return jax_lowering(lctx, self, other)
+
+
+ @registry.lower(torch.ops.aten.where.self)
+ def _aten_where_self(lctx: LoweringContext, condition, self, other):
+   _log_usage(torch.ops.aten.where.self)
+
+   @jax_bridge.wrap
+   def jax_lowering(condition, self, other):
+     promoted_type = jnp.promote_types(self.dtype, other.dtype)
+     if promoted_type == jnp.float64:
+       promoted_type = jnp.float32
+     return jnp.where(
+         condition,
+         self.astype(promoted_type),
+         other.astype(promoted_type),
+     )
+
+   return jax_lowering(lctx, condition, self, other)
+
+
  # Schema:
  #   - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
  #     -> Tensor
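
All of these lowerings share one quirk: NumPy-style promotion can yield float64, which these functions clamp back to float32 because the TFLite conversion path does not carry float64 tensors. A standalone illustration of the promotion rule being clamped:

    import jax.numpy as jnp

    # float32 combined with float64 promotes to float64 under NumPy rules;
    # the lowerings then force such results back to float32.
    assert jnp.promote_types(jnp.float32, jnp.float64) == jnp.float64
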
@@ -0,0 +1,71 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Optimization barrier op definition and lowering."""
+
+ from ai_edge_torch.odml_torch import _torch_library
+ from ai_edge_torch.odml_torch.lowerings import registry
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import torch
+ import torch.utils._pytree as pytree
+
+ _torch_library.ODML_TORCH_LIB.define(
+     "optimization_barrier(Tensor[] inputs) -> Tensor[]"
+ )
+
+ optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default
+
+
+ def optimization_barrier(*inputs: pytree.PyTree):
+   """Apply optimization barrier to the tensors nested within arbitrary pytrees.
+
+   Args:
+     *inputs: A list of tensors or tensor pytrees.
+
+   Returns:
+     The tensors after optimization barrier in the same pytrees structures.
+   """
+   if len(inputs) == 1:
+     inputs = inputs[0]
+   tensors, spec = pytree.tree_flatten(inputs)
+   tensors = optimization_barrier_op(tuple(tensors))
+   outputs = pytree.tree_unflatten(tensors, spec)
+   return outputs
+
+
+ @torch.library.impl(
+     _torch_library.ODML_TORCH_LIB,
+     "optimization_barrier",
+     "CompositeExplicitAutograd",
+ )
+ def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]):
+   return tuple(inputs)
+
+
+ @torch.library.impl(
+     _torch_library.ODML_TORCH_LIB,
+     "optimization_barrier",
+     "Meta",
+ )
+ def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]):
+   return tuple([torch.empty_like(x) for x in inputs])
+
+
+ @registry.lower(torch.ops.odml_torch.optimization_barrier.default)
+ def _optimization_barrier_lowering(
+     lctx, inputs: tuple[ir.Value, ...]
+ ) -> ir.Value:
+   del lctx
+   return stablehlo.optimization_barrier(inputs)
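
The new module gives graph authors a way to pin tensors against compiler reordering or fusion, as the embedding pass above now does. A minimal usage sketch:

    import torch
    from ai_edge_torch.odml_torch import optimization_barrier as ob

    x, y = torch.randn(2, 4), torch.randn(2, 4)
    # Eagerly this is an identity; during lowering the pair is wrapped in a
    # StableHLO optimization_barrier, so the compiler cannot fuse or move
    # computation across this point.
    x, y = ob.optimization_barrier(x, y)
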
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@

  # The next version of ai-edge-torch.
  # The minor version code should be bumped after every release.
- __version__ = "0.6.0.dev20250619"
+ __version__ = "0.6.0.dev20250815"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.6.0.dev20250619
+ Version: 0.6.0.dev20250815
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -31,7 +31,7 @@ Requires-Dist: transformers
  Requires-Dist: kagglehub
  Requires-Dist: tabulate
  Requires-Dist: torch>=2.4.0
- Requires-Dist: tf-nightly>=2.19.0.dev20250101
+ Requires-Dist: tf-nightly<=2.20.0.dev20250615
  Requires-Dist: ai-edge-litert-nightly
  Requires-Dist: ai-edge-quantizer-nightly
  Requires-Dist: jax
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
- ai_edge_torch/version.py,sha256=aCO6sA_1IPQGd5f8Ya-ce4ZKJE1EEt2BkypXJLQ3qvI,806
+ ai_edge_torch/version.py,sha256=IVKlDTOouEuxQaKYFDSypWA39jM5U2e7-X4oph76QyM,806
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -10,7 +10,7 @@ ai_edge_torch/_convert/converter.py,sha256=6MLKELzAwFoiXv-b7KRYi7gc7Z57XOeowcz9A
  ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
  ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
  ai_edge_torch/_convert/fx_passes/__init__.py,sha256=uHek7k9KIW3kaEM_lcygbukJ69JLjm-xnYUWzAEIZII,1345
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=dgUO-lI9Id9hIOHP5XmegVlu5Fl79GR4_b-lDUehzoo,11428
+ ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=D_7RQp3F7XzJZ3d0Sgay7hf_oz3IhKZCZ-F-sDH5LEY,11589
  ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
  ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py,sha256=jMl9YHIbx08KQHbp9UgDnxviUUWiN-FSsiUgR2HCT5s,1576
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=oXbr9G5Jc21xd1dr2CDrp774I4crs0_kkN490K5fNn0,7312
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/_convert/test/test_convert.py,sha256=yQC0WZk_gzReguTOfgWWodK71jnfMjYoRF29_Kafnuw,18692
+ ai_edge_torch/_convert/test/test_convert.py,sha256=ioeEpeS07NAs14-nHiPI-6lLTtnALxl8uNtTKvoHdgE,19316
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -65,16 +65,16 @@ ai_edge_torch/generative/examples/deepseek/verify_util.py,sha256=xf2-jhv-ORwr_-R
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=m5N3MxNX4Yu8-0vXRszkMgfVuFKN6gskVoUIGqonJFk,1814
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=fR4869w1UZIuUVGLDdPof0IDEMDhqn2ej5kl7QW9vNA,1882
- ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=TH9XQAp5p4S829XbaWbJQZBwB18WizDRIQMsUkKqj38,3377
+ ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=8oLB0PaSrHMy0tDT--qPIEjTm0qNT16kqRBCHqAALSc,4674
  ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=E6jotWYYIx6SXUWqurKWjiZpbfj_M2jJrBc2rQ90z1s,11782
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=USyist332qZkhCBof2tJwQSqtnKjTQsKAK_jCE_CO2U,1853
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=eAM7EVVMW-QCqjeZEss7TOkVKArgUs1La51LAC-5a9A,1962
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=b12naCj4zZxOjkIKrd08qovtajYuX-Ba3fbrv6kkDZs,8410
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=UEDNN3JmI31WfE2pvacxeJpqumKK86L2dEus3yTURaY,2114
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=1UVv9SFFg5degX3wf-Fefx7nor1AzJj2NWBVuo8bRnM,15540
- ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=fFMyIS8si3GdwW8EsdhYk1OKyg_27xDv1HTQ2Gv4N8E,6616
- ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=tUOI99kdny33qcDM7-z0R6F-1aU1lZ24kG5zeLVdwow,5129
+ ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=b2l8R3-3xhurmXn3LRZ-C12J11C8tx77NXKze6jFYbA,2425
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=DEkOnfHZJPaxbjwegarRjGpBa7gr6IuWMf8xgZP2mtU,17951
+ ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=PAtuqvBkq6TeNFImoJfCaLX5pZCBOcaQjFvwur_XQJ4,7141
+ ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=OCMIAQfNmPR4uQUAtlYL6j4xkG0dw2Ays4-lnThcWqQ,5110
  ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
  ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=5OmUwz38kVHYLA-v8U8evvDN9da2WioZtGo-XK6yq1o,10067
  ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -171,10 +171,10 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6imHQ-aCXtmPXCh_o,13911
  ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
- ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
+ ai_edge_torch/generative/layers/attention_utils.py,sha256=2qfg7Tzk9ikKph5w3geOHC1I6EyOCdDsWXMr7F7IOZM,7630
  ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
  ai_edge_torch/generative/layers/builder.py,sha256=2bUgkyowDkDznkF8XaHyZs4nowHr1QEHYLM7pMaFmIk,4921
- ai_edge_torch/generative/layers/einsum.py,sha256=EsZSWNVWUs0-1plp4TBnhP4ZhaRDBa2VlDO6hWpUAqU,1288
+ ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
  ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
  ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
  ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
@@ -196,7 +196,7 @@ ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FT
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=plMsd7JBi98r2NHsAdMdvS6TPTXAoRFLCwOXu8H3-24,2004
  ai_edge_torch/generative/quantize/quant_recipe.py,sha256=CEW-ewHxwb59x_GISx4jr7WMihvn-jKWVcBonllzDS4,5724
  ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=h3k_na6rbR08Ip79-2JbkeH8RDk_rrnEGiytuzFDhqc,2678
- ai_edge_torch/generative/quantize/quant_recipes.py,sha256=45DJfcQXZ1FA1qI4LgYoYE4UD4yvfIYoY9LgYTeKFVw,2845
+ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=mmPsgQ2vloMWflqJ6ALmD1lANacTkWotEhurb5foK30,2771
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=TwR2FpQuBEORy6FshEyHNBMKARWlA2MVtTfX9tXV5aE,1488
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
@@ -231,12 +231,13 @@ ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5z
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
  ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
- ai_edge_torch/lowertools/translate_recipe.py,sha256=JNsRc1Jmpj5W6PBww8KRMkbtxcv7ssl8Rr1R3x5_7to,6283
+ ai_edge_torch/lowertools/translate_recipe.py,sha256=Oavs0ENKVnIeB-WidXvokTPqNlFfOuP0GMV0d5RK2Rg,6285
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
- ai_edge_torch/odml_torch/export.py,sha256=lbLpdGa8MDE8oWNA7aSV3tOCQ9P9I2Ox95dSPEssn-g,14930
+ ai_edge_torch/odml_torch/export.py,sha256=FDseiAOOcgN8HOuwv6HT9sALZgXGj7vZi6kwZPWPF14,14935
  ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
+ ai_edge_torch/odml_torch/optimization_barrier.py,sha256=2lmSiu5iXWLFWpupZHvsVeNYNzG5AVGSK3K_CNhS5Sk,2290
  ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
  ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
@@ -248,11 +249,11 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VWb5HEeVljnuXi1eecKp1ieOIcBrSLlu7YIZnxnrozU,12198
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=sC4N5-7RS9yKecs97kM9J56enGvsZj1CJo7y79cuzRg,12784
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
  ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=tkaDo232HjuZvJHyua0n6tdHecifUuVzclJAGq7PPYs,11428
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=xUuQjoR0NJhuwG36GuycpKHo9jg783bDSHj9wE4F1Sg,15439
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -268,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.6.0.dev20250619.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.6.0.dev20250619.dist-info/METADATA,sha256=_nUnboHwt2qZ0ejpyggXZf5q5iuNSNUUboXq6e8uQGw,2074
- ai_edge_torch_nightly-0.6.0.dev20250619.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.6.0.dev20250619.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.6.0.dev20250619.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.6.0.dev20250815.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.6.0.dev20250815.dist-info/METADATA,sha256=ImIHtR76k9wPyaSXpEQLjWUrDq21cfiSR0pNw9RH4cU,2074
+ ai_edge_torch_nightly-0.6.0.dev20250815.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.6.0.dev20250815.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.6.0.dev20250815.dist-info/RECORD,,