ai-edge-torch-nightly 0.6.0.dev20250620__py3-none-any.whl → 0.6.0.dev20250816__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +2 -0
- ai_edge_torch/_convert/test/test_convert.py +18 -0
- ai_edge_torch/generative/examples/gemma/gemma1.py +43 -8
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +8 -0
- ai_edge_torch/generative/examples/gemma3/decoder.py +76 -0
- ai_edge_torch/generative/examples/gemma3/gemma3.py +19 -0
- ai_edge_torch/generative/layers/einsum.py +8 -2
- ai_edge_torch/generative/quantize/quant_recipes.py +0 -1
- ai_edge_torch/lowertools/translate_recipe.py +2 -2
- ai_edge_torch/odml_torch/export.py +3 -2
- ai_edge_torch/odml_torch/lowerings/_basic.py +12 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +143 -8
- ai_edge_torch/odml_torch/optimization_barrier.py +71 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250620.dist-info → ai_edge_torch_nightly-0.6.0.dev20250816.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.6.0.dev20250620.dist-info → ai_edge_torch_nightly-0.6.0.dev20250816.dist-info}/RECORD +19 -18
- {ai_edge_torch_nightly-0.6.0.dev20250620.dist-info → ai_edge_torch_nightly-0.6.0.dev20250816.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250620.dist-info → ai_edge_torch_nightly-0.6.0.dev20250816.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250620.dist-info → ai_edge_torch_nightly-0.6.0.dev20250816.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@
|
|
16
16
|
from typing import Any, Callable
|
17
17
|
from ai_edge_torch import fx_infra
|
18
18
|
from ai_edge_torch import lowertools
|
19
|
+
from ai_edge_torch.odml_torch import optimization_barrier as optimization_barrier_lib
|
19
20
|
import torch
|
20
21
|
import torch.utils._pytree as pytree
|
21
22
|
|
@@ -276,6 +277,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
|
|
276
277
|
# Explicitly reshape back to the original shape. This places the ReshapeOp
|
277
278
|
# outside of the HLFB.
|
278
279
|
output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
|
280
|
+
output, _ = optimization_barrier_lib.optimization_barrier(output, idx)
|
279
281
|
return output
|
280
282
|
|
281
283
|
node.target = embedding
|
@@ -576,6 +576,24 @@ class TestConvert(googletest.TestCase):
|
|
576
576
|
self.fail(f"Conversion failed with bloat16 inputs: {err}")
|
577
577
|
# pylint: enable=broad-except
|
578
578
|
|
579
|
+
def test_convert_model_with_torch_linspace_operation(self):
|
580
|
+
"""Test converting a simple model with torch.linspace operation."""
|
581
|
+
|
582
|
+
class SampleModel(nn.Module):
|
583
|
+
|
584
|
+
def forward(self, x: torch.Tensor):
|
585
|
+
return torch.linspace(0.5, 10.5, steps=x.shape[0], dtype=torch.float64)
|
586
|
+
|
587
|
+
model = SampleModel().eval()
|
588
|
+
args = (torch.randint(0, 100, (10, 10), dtype=torch.int64),)
|
589
|
+
|
590
|
+
try:
|
591
|
+
# Expect this to fix the error during conversion
|
592
|
+
ai_edge_torch.convert(model, args)
|
593
|
+
except Exception as err:
|
594
|
+
self.fail(f"Conversion failed with int64 inputs: {err}")
|
595
|
+
# pylint: enable=broad-except
|
596
|
+
|
579
597
|
def test_compile_model(self):
|
580
598
|
"""Tests AOT compilation of a simple Add module."""
|
581
599
|
|
@@ -23,7 +23,7 @@ import ai_edge_torch.generative.utilities.loader as loading_utils
|
|
23
23
|
import torch
|
24
24
|
from torch import nn
|
25
25
|
|
26
|
-
|
26
|
+
TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
|
27
27
|
ff_up_proj="model.layers.{}.mlp.up_proj",
|
28
28
|
ff_down_proj="model.layers.{}.mlp.down_proj",
|
29
29
|
ff_gate_proj="model.layers.{}.mlp.gate_proj",
|
@@ -36,6 +36,24 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
|
36
36
|
lm_head=None,
|
37
37
|
)
|
38
38
|
|
39
|
+
TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
|
40
|
+
ff_up_proj="model.layers.{}.mlp.up_proj",
|
41
|
+
ff_down_proj="model.layers.{}.mlp.down_proj",
|
42
|
+
ff_gate_proj="model.layers.{}.mlp.gate_proj",
|
43
|
+
attn_query_proj="model.layers.{}.self_attn.q_proj",
|
44
|
+
attn_key_proj="model.layers.{}.self_attn.k_proj",
|
45
|
+
attn_value_proj="model.layers.{}.self_attn.v_proj",
|
46
|
+
attn_output_proj="model.layers.{}.self_attn.o_proj",
|
47
|
+
pre_attn_norm="model.layers.{}.input_layernorm",
|
48
|
+
post_attn_norm="model.layers.{}.post_attention_layernorm",
|
49
|
+
embedding="model.embed_tokens",
|
50
|
+
final_norm="model.norm",
|
51
|
+
)
|
52
|
+
|
53
|
+
TENSOR_NAMES_DICT = {
|
54
|
+
"safetensors": TENSOR_NAMES_SEP_QKV,
|
55
|
+
"kaggle": TENSOR_NAMES_FUSED_QKV,
|
56
|
+
}
|
39
57
|
|
40
58
|
class Gemma1(model_builder.DecoderOnlyModel):
|
41
59
|
"""A Gemma1 model built from the Edge Generative API layers."""
|
@@ -94,11 +112,28 @@ def build_2b_model(
|
|
94
112
|
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
95
113
|
mask_cache_size: int = 0,
|
96
114
|
) -> nn.Module:
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
115
|
+
|
116
|
+
# A list to store the reasons for each failure
|
117
|
+
key_errors = []
|
118
|
+
|
119
|
+
for tensor_names in TENSOR_NAMES_DICT.values():
|
120
|
+
try:
|
121
|
+
return model_builder.build_decoder_only_model(
|
122
|
+
checkpoint_path=checkpoint_path,
|
123
|
+
config=get_model_config_2b(),
|
124
|
+
tensor_names=tensor_names,
|
125
|
+
model_class=Gemma1,
|
126
|
+
custom_loader=custom_loader,
|
127
|
+
mask_cache_size=mask_cache_size,
|
128
|
+
)
|
129
|
+
except KeyError as ke:
|
130
|
+
# Store the specific key that was missing for later
|
131
|
+
key_errors.append(f"Missing key: {ke}")
|
132
|
+
continue
|
133
|
+
|
134
|
+
# If the loop finishes, raise an error with all the collected details
|
135
|
+
error_details = "\n".join(key_errors)
|
136
|
+
raise RuntimeError(
|
137
|
+
"Failed to build model after trying all configurations. "
|
138
|
+
f"Encountered the following errors:\n{error_details}"
|
104
139
|
)
|
@@ -42,6 +42,14 @@ def main(_):
|
|
42
42
|
),
|
43
43
|
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
44
44
|
)
|
45
|
+
elif _MODEL_SIZE.value == '270m':
|
46
|
+
pytorch_model = gemma3.build_model_270m(
|
47
|
+
checkpoint_path,
|
48
|
+
custom_loader=loader.maybe_get_custom_loader(
|
49
|
+
checkpoint_path, flags.FLAGS.custom_checkpoint_loader
|
50
|
+
),
|
51
|
+
mask_cache_size=converter.get_mask_cache_size_from_flags(),
|
52
|
+
)
|
45
53
|
else:
|
46
54
|
raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
|
47
55
|
|
@@ -391,6 +391,60 @@ def get_decoder_config_1b() -> cfg.ModelConfig:
|
|
391
391
|
return config
|
392
392
|
|
393
393
|
|
394
|
+
def get_decoder_config_270m() -> cfg.ModelConfig:
|
395
|
+
"""Returns the model config for a Gemma3 270M model."""
|
396
|
+
norm_config = cfg.NormalizationConfig(
|
397
|
+
type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
|
398
|
+
)
|
399
|
+
ff_config = cfg.FeedForwardConfig(
|
400
|
+
type=cfg.FeedForwardType.GATED,
|
401
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
|
402
|
+
intermediate_size=2048,
|
403
|
+
pre_ff_norm_config=norm_config,
|
404
|
+
post_ff_norm_config=norm_config,
|
405
|
+
)
|
406
|
+
|
407
|
+
def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
|
408
|
+
attn_config = cfg.AttentionConfig(
|
409
|
+
num_heads=4,
|
410
|
+
head_dim=256,
|
411
|
+
num_query_groups=1,
|
412
|
+
rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
|
413
|
+
rotary_percentage=1.0,
|
414
|
+
qkv_transpose_before_split=True,
|
415
|
+
query_norm_config=norm_config,
|
416
|
+
key_norm_config=norm_config,
|
417
|
+
logit_softcap=None,
|
418
|
+
sliding_window_size=512,
|
419
|
+
attn_type=(
|
420
|
+
cfg.AttentionType.GLOBAL
|
421
|
+
if (idx + 1) % 6 == 0
|
422
|
+
else cfg.AttentionType.LOCAL_SLIDING
|
423
|
+
),
|
424
|
+
)
|
425
|
+
return cfg.TransformerBlockConfig(
|
426
|
+
attn_config=attn_config,
|
427
|
+
ff_config=ff_config,
|
428
|
+
pre_attention_norm_config=norm_config,
|
429
|
+
post_attention_norm_config=norm_config,
|
430
|
+
)
|
431
|
+
|
432
|
+
num_layers = 18
|
433
|
+
embedding_dim = 640
|
434
|
+
config = cfg.ModelConfig(
|
435
|
+
vocab_size=262_144,
|
436
|
+
num_layers=num_layers,
|
437
|
+
max_seq_len=32_768,
|
438
|
+
embedding_dim=embedding_dim,
|
439
|
+
embedding_scale=embedding_dim**0.5,
|
440
|
+
block_configs=[get_block_config(i) for i in range(num_layers)],
|
441
|
+
final_norm_config=norm_config,
|
442
|
+
lm_head_use_bias=False,
|
443
|
+
final_logit_softcap=None,
|
444
|
+
)
|
445
|
+
return config
|
446
|
+
|
447
|
+
|
394
448
|
def get_fake_decoder_config_1b() -> cfg.ModelConfig:
|
395
449
|
"""Returns a fake model config for a Gemma3 1B model."""
|
396
450
|
config = get_decoder_config_1b()
|
@@ -427,3 +481,25 @@ def build_model_1b(
|
|
427
481
|
)
|
428
482
|
except KeyError as ke:
|
429
483
|
continue
|
484
|
+
|
485
|
+
|
486
|
+
def build_model_270m(
|
487
|
+
checkpoint_path: str,
|
488
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
489
|
+
mask_cache_size: int = 0,
|
490
|
+
) -> nn.Module:
|
491
|
+
"""Builds a Gemma3 270M model."""
|
492
|
+
# TODO(b/403644647): Better error handling for loading checkpoints with
|
493
|
+
# different tensor names.
|
494
|
+
for tensor_names in TENSOR_NAMES_DICT.values():
|
495
|
+
try:
|
496
|
+
return model_builder.build_decoder_only_model(
|
497
|
+
checkpoint_path=checkpoint_path,
|
498
|
+
config=get_decoder_config_270m(),
|
499
|
+
tensor_names=tensor_names,
|
500
|
+
model_class=Decoder,
|
501
|
+
custom_loader=custom_loader,
|
502
|
+
mask_cache_size=mask_cache_size,
|
503
|
+
)
|
504
|
+
except KeyError as _:
|
505
|
+
continue
|
@@ -179,3 +179,22 @@ def build_model_1b(
|
|
179
179
|
# TODO: Load the parameters of decoder from checkpoint.
|
180
180
|
model.eval()
|
181
181
|
return model
|
182
|
+
|
183
|
+
|
184
|
+
def build_model_270m(
|
185
|
+
checkpoint_path: str,
|
186
|
+
custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
|
187
|
+
mask_cache_size: int = 0,
|
188
|
+
) -> decoder.Decoder:
|
189
|
+
"""Builds a Gemma3 270M model."""
|
190
|
+
if checkpoint_path:
|
191
|
+
model = decoder.build_model_270m(
|
192
|
+
checkpoint_path, custom_loader, mask_cache_size
|
193
|
+
)
|
194
|
+
else:
|
195
|
+
config = decoder.get_decoder_config_270m()
|
196
|
+
model = decoder.Decoder(config, mask_cache_size)
|
197
|
+
# TODO: Load the parameters of decoder from checkpoint.
|
198
|
+
model.eval()
|
199
|
+
return model
|
200
|
+
|
@@ -14,7 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
# Einsum layer implementation.
|
16
16
|
|
17
|
-
from typing import Sequence
|
17
|
+
from typing import Callable, Sequence
|
18
18
|
import torch
|
19
19
|
from torch import nn
|
20
20
|
|
@@ -22,7 +22,12 @@ from torch import nn
|
|
22
22
|
class Einsum(nn.Module):
|
23
23
|
"""Einsum layer wrapping over torch.einsum."""
|
24
24
|
|
25
|
-
def __init__(
|
25
|
+
def __init__(
|
26
|
+
self,
|
27
|
+
shape: Sequence[int],
|
28
|
+
einsum_str: str,
|
29
|
+
init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
|
30
|
+
):
|
26
31
|
super().__init__()
|
27
32
|
self.shape = shape
|
28
33
|
self.einsum_str = einsum_str
|
@@ -30,6 +35,7 @@ class Einsum(nn.Module):
|
|
30
35
|
torch.empty(shape, dtype=torch.float32),
|
31
36
|
requires_grad=False,
|
32
37
|
)
|
38
|
+
init_fn(self.w)
|
33
39
|
self.einsum_fn = lambda x: torch.einsum(einsum_str, x, self.w)
|
34
40
|
|
35
41
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
@@ -25,13 +25,13 @@ _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
|
|
25
25
|
|
26
26
|
_DEFAULT_REGEX_STR = '.*'
|
27
27
|
_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR = 'transformer_block'
|
28
|
-
_IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
|
28
|
+
_IDX_TRANSFORMER_BLOCKS_REGEX_STR = r'transformer_blocks\[{}\]'
|
29
29
|
_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
|
30
30
|
_FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
|
31
31
|
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
|
32
32
|
# TODO: b/415833584 - Improve the regex for pre-softmax layer.
|
33
33
|
_DECODE_LOGITS_REGEX_STR = 'StatefulPartitionedCall'
|
34
|
-
_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'
|
34
|
+
_ANY_TWO_DIGITS_REGEX_STR = r'\d{1,2}'
|
35
35
|
|
36
36
|
|
37
37
|
def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
|
@@ -21,7 +21,6 @@ import operator
|
|
21
21
|
from typing import Any, Callable, Optional
|
22
22
|
|
23
23
|
from ai_edge_torch import fx_infra
|
24
|
-
from jax.lib import xla_extension
|
25
24
|
from jax._src.lib.mlir import ir
|
26
25
|
from jax._src.lib.mlir.dialects import func
|
27
26
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
@@ -35,6 +34,8 @@ from . import lowerings
|
|
35
34
|
|
36
35
|
LoweringContext = lowerings.context.LoweringContext
|
37
36
|
|
37
|
+
from jaxlib._jax.mlir import serialize_portable_artifact
|
38
|
+
|
38
39
|
|
39
40
|
def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
|
40
41
|
"""Build flattened inputs and metadata from exported program's signature."""
|
@@ -233,7 +234,7 @@ class MlirLowered:
|
|
233
234
|
target_version = stablehlo.get_version_from_compatibility_requirement(
|
234
235
|
stablehlo.StablehloCompatibilityRequirement.WEEK_12
|
235
236
|
)
|
236
|
-
module_bytecode =
|
237
|
+
module_bytecode = serialize_portable_artifact(
|
237
238
|
self.module_bytecode, target_version
|
238
239
|
)
|
239
240
|
return module_bytecode
|
@@ -331,6 +331,18 @@ def _aten_sym_size_int(lctx, x: ir.Value, dim: int):
|
|
331
331
|
return stablehlo.get_dimension_size(x, dim)
|
332
332
|
|
333
333
|
|
334
|
+
# Lowering for the addition operator (`+`).
|
335
|
+
# Handles cases where one operand is an integer (scalar) and the other is a
|
336
|
+
# tensor, broadcasting the scalar to the tensor's shape before addition.
|
337
|
+
@lower(operator.add)
|
338
|
+
def _operator_add(lctx, self: int | ir.Value, other: int | ir.Value):
|
339
|
+
if isinstance(self, int) and isinstance(other, ir.Value):
|
340
|
+
self = utils.splat(self, other.type.element_type, other.type.shape)
|
341
|
+
if isinstance(other, int) and isinstance(self, ir.Value):
|
342
|
+
other = utils.splat(other, self.type.element_type, self.type.shape)
|
343
|
+
return stablehlo.add(self, other)
|
344
|
+
|
345
|
+
|
334
346
|
# Lowering for the subtraction operator (`-`).
|
335
347
|
# Handles cases where one operand is an integer (scalar) and the other is a
|
336
348
|
# tensor, broadcasting the scalar to the tensor's shape before subtraction.
|
@@ -78,8 +78,6 @@ lower_by_torch_xla2(torch.ops.aten._unsafe_index)
|
|
78
78
|
lower_by_torch_xla2(torch.ops.aten._unsafe_view)
|
79
79
|
lower_by_torch_xla2(torch.ops.aten.acos)
|
80
80
|
lower_by_torch_xla2(torch.ops.aten.acosh)
|
81
|
-
lower_by_torch_xla2(torch.ops.aten.add.Scalar)
|
82
|
-
lower_by_torch_xla2(torch.ops.aten.add.Tensor)
|
83
81
|
lower_by_torch_xla2(torch.ops.aten.addbmm.default)
|
84
82
|
lower_by_torch_xla2(torch.ops.aten.addmm)
|
85
83
|
lower_by_torch_xla2(torch.ops.aten.addmv)
|
@@ -159,7 +157,6 @@ lower_by_torch_xla2(torch.ops.aten.logical_and)
|
|
159
157
|
lower_by_torch_xla2(torch.ops.aten.logical_not)
|
160
158
|
lower_by_torch_xla2(torch.ops.aten.logical_or)
|
161
159
|
lower_by_torch_xla2(torch.ops.aten.logical_xor)
|
162
|
-
lower_by_torch_xla2(torch.ops.aten.lt)
|
163
160
|
lower_by_torch_xla2(torch.ops.aten.max)
|
164
161
|
lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices)
|
165
162
|
lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
|
@@ -170,8 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mean)
|
|
170
167
|
lower_by_torch_xla2(torch.ops.aten.min)
|
171
168
|
lower_by_torch_xla2(torch.ops.aten.minimum)
|
172
169
|
lower_by_torch_xla2(torch.ops.aten.mm)
|
173
|
-
lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
|
174
|
-
lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
|
175
170
|
lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
|
176
171
|
lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
|
177
172
|
lower_by_torch_xla2(torch.ops.aten.ne)
|
@@ -215,8 +210,6 @@ lower_by_torch_xla2(torch.ops.aten.sqrt)
|
|
215
210
|
lower_by_torch_xla2(torch.ops.aten.squeeze)
|
216
211
|
lower_by_torch_xla2(torch.ops.aten.squeeze_copy)
|
217
212
|
lower_by_torch_xla2(torch.ops.aten.stack)
|
218
|
-
lower_by_torch_xla2(torch.ops.aten.sub.Scalar)
|
219
|
-
lower_by_torch_xla2(torch.ops.aten.sub.Tensor)
|
220
213
|
lower_by_torch_xla2(torch.ops.aten.sum)
|
221
214
|
lower_by_torch_xla2(torch.ops.aten.t)
|
222
215
|
lower_by_torch_xla2(torch.ops.aten.tan)
|
@@ -244,7 +237,6 @@ lower_by_torch_xla2(torch.ops.aten.view_as_real)
|
|
244
237
|
lower_by_torch_xla2(torch.ops.aten.view_copy)
|
245
238
|
lower_by_torch_xla2(torch.ops.aten.where.ScalarOther)
|
246
239
|
lower_by_torch_xla2(torch.ops.aten.where.ScalarSelf)
|
247
|
-
lower_by_torch_xla2(torch.ops.aten.where.self)
|
248
240
|
lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim)
|
249
241
|
lower_by_torch_xla2(torch.ops.prims.var)
|
250
242
|
|
@@ -259,6 +251,149 @@ def _aten_copy(self, src, **kwargs):
|
|
259
251
|
return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
|
260
252
|
|
261
253
|
|
254
|
+
@registry.lower(torch.ops.aten.add.Scalar)
|
255
|
+
def _aten_add_scalar(lctx: LoweringContext, self, other):
|
256
|
+
_log_usage(torch.ops.aten.add.Scalar)
|
257
|
+
|
258
|
+
@jax_bridge.wrap
|
259
|
+
def jax_lowering(self, other):
|
260
|
+
other_dtype = jnp.result_type(other)
|
261
|
+
promoted_type = jnp.promote_types(self.dtype, other_dtype)
|
262
|
+
if promoted_type == jnp.float64:
|
263
|
+
promoted_type = jnp.float32
|
264
|
+
return jnp.add(
|
265
|
+
self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
|
266
|
+
)
|
267
|
+
|
268
|
+
return jax_lowering(lctx, self, other)
|
269
|
+
|
270
|
+
|
271
|
+
@registry.lower(torch.ops.aten.add.Tensor)
|
272
|
+
def _aten_add_tensor(lctx: LoweringContext, self, other):
|
273
|
+
_log_usage(torch.ops.aten.add.Tensor)
|
274
|
+
|
275
|
+
@jax_bridge.wrap
|
276
|
+
def jax_lowering(self, other):
|
277
|
+
promoted_type = jnp.promote_types(self.dtype, other.dtype)
|
278
|
+
if promoted_type == jnp.float64:
|
279
|
+
promoted_type = jnp.float32
|
280
|
+
return jnp.add(self.astype(promoted_type), other.astype(promoted_type))
|
281
|
+
|
282
|
+
return jax_lowering(lctx, self, other)
|
283
|
+
|
284
|
+
|
285
|
+
@registry.lower(torch.ops.aten.sub.Scalar)
|
286
|
+
def _aten_sub_scalar(lctx: LoweringContext, self, other):
|
287
|
+
_log_usage(torch.ops.aten.sub.Scalar)
|
288
|
+
|
289
|
+
@jax_bridge.wrap
|
290
|
+
def jax_lowering(self, other):
|
291
|
+
other_dtype = jnp.result_type(other)
|
292
|
+
promoted_type = jnp.promote_types(self.dtype, other_dtype)
|
293
|
+
if promoted_type == jnp.float64:
|
294
|
+
promoted_type = jnp.float32
|
295
|
+
return jnp.subtract(
|
296
|
+
self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
|
297
|
+
)
|
298
|
+
|
299
|
+
return jax_lowering(lctx, self, other)
|
300
|
+
|
301
|
+
|
302
|
+
@registry.lower(torch.ops.aten.sub.Tensor)
|
303
|
+
def _aten_sub_tensor(lctx: LoweringContext, self, other):
|
304
|
+
_log_usage(torch.ops.aten.sub.Tensor)
|
305
|
+
|
306
|
+
@jax_bridge.wrap
|
307
|
+
def jax_lowering(self, other):
|
308
|
+
promoted_type = jnp.promote_types(self.dtype, other.dtype)
|
309
|
+
if promoted_type == jnp.float64:
|
310
|
+
promoted_type = jnp.float32
|
311
|
+
return jnp.subtract(self.astype(promoted_type), other.astype(promoted_type))
|
312
|
+
|
313
|
+
return jax_lowering(lctx, self, other)
|
314
|
+
|
315
|
+
|
316
|
+
@registry.lower(torch.ops.aten.lt.Scalar)
|
317
|
+
def _aten_lt_scalar(lctx: LoweringContext, self, other):
|
318
|
+
_log_usage(torch.ops.aten.lt.Scalar)
|
319
|
+
|
320
|
+
@jax_bridge.wrap
|
321
|
+
def jax_lowering(self, other):
|
322
|
+
other_dtype = jnp.result_type(other)
|
323
|
+
promoted_type = jnp.promote_types(self.dtype, other_dtype)
|
324
|
+
if promoted_type == jnp.float64:
|
325
|
+
promoted_type = jnp.float32
|
326
|
+
return jnp.less(
|
327
|
+
self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
|
328
|
+
)
|
329
|
+
|
330
|
+
return jax_lowering(lctx, self, other)
|
331
|
+
|
332
|
+
|
333
|
+
@registry.lower(torch.ops.aten.lt.Tensor)
|
334
|
+
def _aten_lt_tensor(lctx: LoweringContext, self, other):
|
335
|
+
_log_usage(torch.ops.aten.lt.Tensor)
|
336
|
+
|
337
|
+
@jax_bridge.wrap
|
338
|
+
def jax_lowering(self, other):
|
339
|
+
promoted_type = jnp.promote_types(self.dtype, other.dtype)
|
340
|
+
return jnp.less(self.astype(promoted_type), other.astype(promoted_type))
|
341
|
+
|
342
|
+
return jax_lowering(lctx, self, other)
|
343
|
+
|
344
|
+
|
345
|
+
@registry.lower(torch.ops.aten.mul.Scalar)
|
346
|
+
def _aten_mul_scalar(lctx: LoweringContext, self, other):
|
347
|
+
_log_usage(torch.ops.aten.mul.Scalar)
|
348
|
+
|
349
|
+
@jax_bridge.wrap
|
350
|
+
def jax_lowering(self, other):
|
351
|
+
other_dtype = jnp.result_type(other)
|
352
|
+
promoted_type = jnp.promote_types(self.dtype, other_dtype)
|
353
|
+
if promoted_type == jnp.float64:
|
354
|
+
promoted_type = jnp.float32
|
355
|
+
return jnp.multiply(
|
356
|
+
self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
|
357
|
+
)
|
358
|
+
|
359
|
+
return jax_lowering(lctx, self, other)
|
360
|
+
|
361
|
+
|
362
|
+
@registry.lower(torch.ops.aten.mul.Tensor)
|
363
|
+
def _aten_mul_tensor(lctx: LoweringContext, self, other):
|
364
|
+
_log_usage(torch.ops.aten.mul.Tensor)
|
365
|
+
|
366
|
+
@jax_bridge.wrap
|
367
|
+
def jax_lowering(self, other):
|
368
|
+
other_dtype = jnp.result_type(other)
|
369
|
+
promoted_type = jnp.promote_types(self.dtype, other_dtype)
|
370
|
+
if promoted_type == jnp.float64:
|
371
|
+
promoted_type = jnp.float32
|
372
|
+
return jnp.multiply(
|
373
|
+
self.astype(promoted_type), jnp.array(other, dtype=promoted_type)
|
374
|
+
)
|
375
|
+
|
376
|
+
return jax_lowering(lctx, self, other)
|
377
|
+
|
378
|
+
|
379
|
+
@registry.lower(torch.ops.aten.where.self)
|
380
|
+
def _aten_where_self(lctx: LoweringContext, condition, self, other):
|
381
|
+
_log_usage(torch.ops.aten.where.self)
|
382
|
+
|
383
|
+
@jax_bridge.wrap
|
384
|
+
def jax_lowering(condition, self, other):
|
385
|
+
promoted_type = jnp.promote_types(self.dtype, other.dtype)
|
386
|
+
if promoted_type == jnp.float64:
|
387
|
+
promoted_type = jnp.float32
|
388
|
+
return jnp.where(
|
389
|
+
condition,
|
390
|
+
self.astype(promoted_type),
|
391
|
+
other.astype(promoted_type),
|
392
|
+
)
|
393
|
+
|
394
|
+
return jax_lowering(lctx, condition, self, other)
|
395
|
+
|
396
|
+
|
262
397
|
# Schema:
|
263
398
|
# - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
|
264
399
|
# -> Tensor
|
@@ -0,0 +1,71 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Optimization barrier op definition and lowering."""
|
16
|
+
|
17
|
+
from ai_edge_torch.odml_torch import _torch_library
|
18
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
19
|
+
from jax._src.lib.mlir import ir
|
20
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
21
|
+
import torch
|
22
|
+
import torch.utils._pytree as pytree
|
23
|
+
|
24
|
+
_torch_library.ODML_TORCH_LIB.define(
|
25
|
+
"optimization_barrier(Tensor[] inputs) -> Tensor[]"
|
26
|
+
)
|
27
|
+
|
28
|
+
optimization_barrier_op = torch.ops.odml_torch.optimization_barrier.default
|
29
|
+
|
30
|
+
|
31
|
+
def optimization_barrier(*inputs: pytree.PyTree):
|
32
|
+
"""Apply optimization barrier to the tensors nested within arbitrary pytrees.
|
33
|
+
|
34
|
+
Args:
|
35
|
+
*inputs: A list of tensors or tensor pytrees.
|
36
|
+
|
37
|
+
Returns:
|
38
|
+
The tensors after optimization barrier in the same pytrees structures.
|
39
|
+
"""
|
40
|
+
if len(inputs) == 1:
|
41
|
+
inputs = inputs[0]
|
42
|
+
tensors, spec = pytree.tree_flatten(inputs)
|
43
|
+
tensors = optimization_barrier_op(tuple(tensors))
|
44
|
+
outputs = pytree.tree_unflatten(tensors, spec)
|
45
|
+
return outputs
|
46
|
+
|
47
|
+
|
48
|
+
@torch.library.impl(
|
49
|
+
_torch_library.ODML_TORCH_LIB,
|
50
|
+
"optimization_barrier",
|
51
|
+
"CompositeExplicitAutograd",
|
52
|
+
)
|
53
|
+
def _optimization_barrier_impl(inputs: tuple[torch.Tensor, ...]):
|
54
|
+
return tuple(inputs)
|
55
|
+
|
56
|
+
|
57
|
+
@torch.library.impl(
|
58
|
+
_torch_library.ODML_TORCH_LIB,
|
59
|
+
"optimization_barrier",
|
60
|
+
"Meta",
|
61
|
+
)
|
62
|
+
def _optimization_barrier_fake(inputs: tuple[torch.Tensor, ...]):
|
63
|
+
return tuple([torch.empty_like(x) for x in inputs])
|
64
|
+
|
65
|
+
|
66
|
+
@registry.lower(torch.ops.odml_torch.optimization_barrier.default)
|
67
|
+
def _optimization_barrier_lowering(
|
68
|
+
lctx, inputs: tuple[ir.Value, ...]
|
69
|
+
) -> ir.Value:
|
70
|
+
del lctx
|
71
|
+
return stablehlo.optimization_barrier(inputs)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.6.0.
|
3
|
+
Version: 0.6.0.dev20250816
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -31,7 +31,7 @@ Requires-Dist: transformers
|
|
31
31
|
Requires-Dist: kagglehub
|
32
32
|
Requires-Dist: tabulate
|
33
33
|
Requires-Dist: torch>=2.4.0
|
34
|
-
Requires-Dist: tf-nightly
|
34
|
+
Requires-Dist: tf-nightly<=2.20.0.dev20250615
|
35
35
|
Requires-Dist: ai-edge-litert-nightly
|
36
36
|
Requires-Dist: ai-edge-quantizer-nightly
|
37
37
|
Requires-Dist: jax
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=l7eZe2YNZxV10qw-t4qOqrQ97K8OtzpEfzORpOBAi88,806
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -10,7 +10,7 @@ ai_edge_torch/_convert/converter.py,sha256=6MLKELzAwFoiXv-b7KRYi7gc7Z57XOeowcz9A
|
|
10
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
11
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
12
12
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=uHek7k9KIW3kaEM_lcygbukJ69JLjm-xnYUWzAEIZII,1345
|
13
|
-
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=
|
13
|
+
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=D_7RQp3F7XzJZ3d0Sgay7hf_oz3IhKZCZ-F-sDH5LEY,11589
|
14
14
|
ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
|
15
15
|
ai_edge_torch/_convert/fx_passes/eliminate_dead_code_pass.py,sha256=jMl9YHIbx08KQHbp9UgDnxviUUWiN-FSsiUgR2HCT5s,1576
|
16
16
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
|
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
28
28
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=oXbr9G5Jc21xd1dr2CDrp774I4crs0_kkN490K5fNn0,7312
|
29
29
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
30
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
30
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=ioeEpeS07NAs14-nHiPI-6lLTtnALxl8uNtTKvoHdgE,19316
|
31
31
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
32
32
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
33
33
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -65,15 +65,15 @@ ai_edge_torch/generative/examples/deepseek/verify_util.py,sha256=xf2-jhv-ORwr_-R
|
|
65
65
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
66
66
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=m5N3MxNX4Yu8-0vXRszkMgfVuFKN6gskVoUIGqonJFk,1814
|
67
67
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=fR4869w1UZIuUVGLDdPof0IDEMDhqn2ej5kl7QW9vNA,1882
|
68
|
-
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=
|
68
|
+
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=8oLB0PaSrHMy0tDT--qPIEjTm0qNT16kqRBCHqAALSc,4674
|
69
69
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=E6jotWYYIx6SXUWqurKWjiZpbfj_M2jJrBc2rQ90z1s,11782
|
70
70
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=USyist332qZkhCBof2tJwQSqtnKjTQsKAK_jCE_CO2U,1853
|
71
71
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=eAM7EVVMW-QCqjeZEss7TOkVKArgUs1La51LAC-5a9A,1962
|
72
72
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=b12naCj4zZxOjkIKrd08qovtajYuX-Ba3fbrv6kkDZs,8410
|
73
73
|
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
74
|
-
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
|
75
|
-
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
|
76
|
-
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=
|
74
|
+
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=b2l8R3-3xhurmXn3LRZ-C12J11C8tx77NXKze6jFYbA,2425
|
75
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=DEkOnfHZJPaxbjwegarRjGpBa7gr6IuWMf8xgZP2mtU,17951
|
76
|
+
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=PAtuqvBkq6TeNFImoJfCaLX5pZCBOcaQjFvwur_XQJ4,7141
|
77
77
|
ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=OCMIAQfNmPR4uQUAtlYL6j4xkG0dw2Ays4-lnThcWqQ,5110
|
78
78
|
ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
|
79
79
|
ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=5OmUwz38kVHYLA-v8U8evvDN9da2WioZtGo-XK6yq1o,10067
|
@@ -174,7 +174,7 @@ ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAe
|
|
174
174
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=2qfg7Tzk9ikKph5w3geOHC1I6EyOCdDsWXMr7F7IOZM,7630
|
175
175
|
ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
|
176
176
|
ai_edge_torch/generative/layers/builder.py,sha256=2bUgkyowDkDznkF8XaHyZs4nowHr1QEHYLM7pMaFmIk,4921
|
177
|
-
ai_edge_torch/generative/layers/einsum.py,sha256=
|
177
|
+
ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
|
178
178
|
ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
|
179
179
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
|
180
180
|
ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
|
@@ -196,7 +196,7 @@ ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FT
|
|
196
196
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=plMsd7JBi98r2NHsAdMdvS6TPTXAoRFLCwOXu8H3-24,2004
|
197
197
|
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=CEW-ewHxwb59x_GISx4jr7WMihvn-jKWVcBonllzDS4,5724
|
198
198
|
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=h3k_na6rbR08Ip79-2JbkeH8RDk_rrnEGiytuzFDhqc,2678
|
199
|
-
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
|
199
|
+
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=mmPsgQ2vloMWflqJ6ALmD1lANacTkWotEhurb5foK30,2771
|
200
200
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=TwR2FpQuBEORy6FshEyHNBMKARWlA2MVtTfX9tXV5aE,1488
|
201
201
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
202
202
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
|
@@ -231,12 +231,13 @@ ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5z
|
|
231
231
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
|
232
232
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
233
233
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
234
|
-
ai_edge_torch/lowertools/translate_recipe.py,sha256=
|
234
|
+
ai_edge_torch/lowertools/translate_recipe.py,sha256=Oavs0ENKVnIeB-WidXvokTPqNlFfOuP0GMV0d5RK2Rg,6285
|
235
235
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
236
236
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
237
237
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
238
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
238
|
+
ai_edge_torch/odml_torch/export.py,sha256=FDseiAOOcgN8HOuwv6HT9sALZgXGj7vZi6kwZPWPF14,14935
|
239
239
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
240
|
+
ai_edge_torch/odml_torch/optimization_barrier.py,sha256=2lmSiu5iXWLFWpupZHvsVeNYNzG5AVGSK3K_CNhS5Sk,2290
|
240
241
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
241
242
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
242
243
|
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
|
@@ -248,11 +249,11 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
|
|
248
249
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
249
250
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
250
251
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
|
251
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
252
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=sC4N5-7RS9yKecs97kM9J56enGvsZj1CJo7y79cuzRg,12784
|
252
253
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
253
254
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
254
255
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=ybOdoFE5HIJTkyiYcc73zpyUyUpioVnAca6k0wyJPs4,2572
|
255
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
256
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=xUuQjoR0NJhuwG36GuycpKHo9jg783bDSHj9wE4F1Sg,15439
|
256
257
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
257
258
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
258
259
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
@@ -268,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
268
269
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
269
270
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
270
271
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
271
|
-
ai_edge_torch_nightly-0.6.0.
|
272
|
-
ai_edge_torch_nightly-0.6.0.
|
273
|
-
ai_edge_torch_nightly-0.6.0.
|
274
|
-
ai_edge_torch_nightly-0.6.0.
|
275
|
-
ai_edge_torch_nightly-0.6.0.
|
272
|
+
ai_edge_torch_nightly-0.6.0.dev20250816.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
273
|
+
ai_edge_torch_nightly-0.6.0.dev20250816.dist-info/METADATA,sha256=x3HCuvXK0BSgJhf34jwigF6tJIjfq55lFwrPZIrPjVA,2074
|
274
|
+
ai_edge_torch_nightly-0.6.0.dev20250816.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
275
|
+
ai_edge_torch_nightly-0.6.0.dev20250816.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
276
|
+
ai_edge_torch_nightly-0.6.0.dev20250816.dist-info/RECORD,,
|
File without changes
|
File without changes
|