ai-edge-torch-nightly 0.3.0.dev20250123__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28):
  1. ai_edge_torch/_config.py +9 -0
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
  5. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  8. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  10. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  12. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  13. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  14. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  15. ai_edge_torch/generative/layers/kv_cache.py +2 -1
  16. ai_edge_torch/generative/layers/model_config.py +5 -1
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  18. ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
  19. ai_edge_torch/generative/utilities/converter.py +18 -2
  20. ai_edge_torch/generative/utilities/model_builder.py +6 -1
  21. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
  22. ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +28 -18
  26. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/experimental/types.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # A listing of types describes the K and V tensors in KV caches.
16
+
17
+ import enum
18
+ from enum import Enum, auto
19
+ from typing import Tuple
20
+ from torch import nn
21
+
22
+
23
+ @enum.unique
24
+ class TensorDims(Enum):
25
+ BATCH = enum.auto()
26
+ SEQUENCE = enum.auto()
27
+ NUM_HEADS = enum.auto()
28
+ HEAD_DIM = enum.auto()
29
+ MODEL_DIM = enum.auto() # often num_heads * head_dim
30
+
31
+
32
+ DIM_TO_LETTER = {
33
+ TensorDims.BATCH: 'B',
34
+ TensorDims.SEQUENCE: 'T',
35
+ TensorDims.NUM_HEADS: 'N',
36
+ TensorDims.HEAD_DIM: 'H',
37
+ TensorDims.MODEL_DIM: 'D',
38
+ }
39
+
40
+
41
+ class TensorDimensionMeta(type):
42
+ """Metaclass to create classes representing an order of tensor dimensions."""
43
+
44
+ def __new__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
45
+ """Creates a new class with the given name and tensor dimension order.
46
+
47
+ Args:
48
+ name: Name of the new class.
49
+ bases: Base classes for the new class.
50
+ attrs: Attributes for the new class.
51
+ dimensions: A tuple of TensorDims defining the order.
52
+ """
53
+
54
+ attrs['dimensions'] = (
55
+ dimensions # Store the dimensions as a class attribute
56
+ )
57
+ return super().__new__(cls, name, bases, attrs)
58
+
59
+ def __init__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
60
+ super().__init__(name, bases, attrs)
61
+
62
+ def __repr__(cls):
63
+ return f'{cls.__name__}'
64
+
65
+
66
+ def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
67
+ """Creates a TensorDimensionMeta class with the specified dimensions.
68
+
69
+ Args:
70
+ dimensions: A tuple of TensorDims.
71
+
72
+ Returns:
73
+ A new class representing the tensor dimension order.
74
+ """
75
+ name = ''.join(DIM_TO_LETTER[d] for d in dims)
76
+ # Derive from nn.Module for torch tracing compatiblity.
77
+ return TensorDimensionMeta(name, (nn.Module,), {}, dimensions=dims)
78
+
79
+
80
+ BTNH = create_tensor_dimension_order_class((
81
+ TensorDims.BATCH,
82
+ TensorDims.SEQUENCE,
83
+ TensorDims.NUM_HEADS,
84
+ TensorDims.HEAD_DIM,
85
+ ))
86
+ BNTH = create_tensor_dimension_order_class((
87
+ TensorDims.BATCH,
88
+ TensorDims.NUM_HEADS,
89
+ TensorDims.SEQUENCE,
90
+ TensorDims.HEAD_DIM,
91
+ ))
92
+ BNHT = create_tensor_dimension_order_class((
93
+ TensorDims.BATCH,
94
+ TensorDims.NUM_HEADS,
95
+ TensorDims.HEAD_DIM,
96
+ TensorDims.SEQUENCE,
97
+ ))
ai_edge_torch/generative/layers/kv_cache.py CHANGED
@@ -81,7 +81,8 @@ class KVCache:
81
81
  """
82
82
  caches = [
83
83
  KVCacheEntry.from_model_config(
84
- config.kv_cache_max,
84
+ config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
85
+ else config.block_config(idx).kv_cache_max_len,
85
86
  config.block_config(idx).attn_config,
86
87
  dtype,
87
88
  device,
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -164,6 +164,9 @@ class TransformerBlockConfig:
164
164
  parallel_residual: bool = False
165
165
  # The Attention computation will include relative positional bias.
166
166
  relative_attention: bool = False
167
+ # KV Cache length for this block. Only used when attention types are different
168
+ # across blocks
169
+ kv_cache_max_len: Optional[int] = None
167
170
 
168
171
 
169
172
  @dataclasses.dataclass
@@ -200,7 +203,8 @@ class ModelConfig:
200
203
  embedding_use_bias: bool = False
201
204
  # Image embedding parameters.
202
205
  image_embedding: Optional[ImageEmbeddingConfig] = None
203
-
206
+ # Number of image tokens
207
+ num_mm_tokens_per_image: Optional[int] = None
204
208
  # Use bias term within LLM's HEAD.
205
209
  lm_head_use_bias: bool = False
206
210
  # Whether LLM's HEAD shares the weight of the embedding.
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -17,6 +17,7 @@
17
17
 
18
18
  import ai_edge_torch
19
19
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
20
+ from ai_edge_torch.generative.examples.deepseek import deepseek
20
21
  from ai_edge_torch.generative.examples.gemma import gemma1
21
22
  from ai_edge_torch.generative.examples.gemma import gemma2
22
23
  from ai_edge_torch.generative.examples.llama import llama
@@ -150,16 +151,15 @@ class TestModelConversion(googletest.TestCase):
150
151
  ai_edge_torch.config.in_oss,
151
152
  reason="tests with custom ops are not supported in oss",
152
153
  )
153
-
154
154
  def test_smollm2(self):
155
155
  config = smollm.get_fake_model_config_v2()
156
156
  pytorch_model = smollm.SmolLM2(config).eval()
157
157
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
158
+
158
159
  @googletest.skipIf(
159
160
  ai_edge_torch.config.in_oss,
160
161
  reason="tests with custom ops are not supported in oss",
161
162
  )
162
-
163
163
  def test_openelm(self):
164
164
  config = openelm.get_fake_model_config()
165
165
  pytorch_model = openelm.OpenELM(config).eval()
@@ -174,6 +174,15 @@ class TestModelConversion(googletest.TestCase):
174
174
  pytorch_model = qwen.Qwen(config).eval()
175
175
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
176
176
 
177
+ @googletest.skipIf(
178
+ ai_edge_torch.config.in_oss,
179
+ reason="tests with custom ops are not supported in oss",
180
+ )
181
+ def test_deepseek(self):
182
+ config = deepseek.get_fake_model_config()
183
+ pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
184
+ self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
185
+
177
186
  @googletest.skipIf(
178
187
  ai_edge_torch.config.in_oss,
179
188
  reason="tests with custom ops are not supported in oss",
ai_edge_torch/generative/utilities/bmm_4d.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Common utility functions for data loading etc.
16
+ from dataclasses import dataclass
17
+ import glob
18
+ import os
19
+ from typing import Sequence
20
+ from ai_edge_torch.odml_torch import lowerings
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import torch
24
+
25
+
26
+ # Use torch.library.custom_op to define a new custom operator.
27
+ @torch.library.custom_op("ai_edge_torch::bmm_4d", mutates_args=())
28
+ def bmm_4d(
29
+ lhs: torch.Tensor,
30
+ rhs: torch.Tensor,
31
+ ) -> torch.Tensor:
32
+ if not (lhs.ndim == 4 and rhs.ndim == 4):
33
+ raise ValueError("bmm_4d requires LHS and RHS have rank 4.")
34
+ d0_can_bcast = lhs.shape[0] == rhs.shape[0] or lhs.shape[0] == 1 or rhs.shape[0] == 1
35
+ d1_can_bcast = lhs.shape[1] == rhs.shape[1] or lhs.shape[1] == 1 or rhs.shape[1] == 1
36
+ if not (d0_can_bcast and d1_can_bcast):
37
+ raise ValueError("bmm_4d requires that dimensions 0 and 1 can broadcast.")
38
+
39
+ if not lhs.shape[-1] == rhs.shape[-1]:
40
+ raise ValueError("bmm_4d requires LHS and RHS have same last dimension.")
41
+
42
+ return torch.einsum("abcd,abed->abce", lhs, rhs)
43
+
44
+
45
+ # Use register_fake to add a ``FakeTensor`` kernel for the operator
46
+ @bmm_4d.register_fake
47
+ def _(lhs, rhs):
48
+ return torch.einsum("abcd,abed->abce", lhs, rhs)
49
+
50
+
51
+ @lowerings.lower(torch.ops.ai_edge_torch.bmm_4d)
52
+ def _bmm_4d_lower(
53
+ lctx,
54
+ lhs: ir.Value,
55
+ rhs: ir.Value,
56
+ ):
57
+ dot_dnums = stablehlo.DotDimensionNumbers.get(
58
+ lhs_batching_dimensions=[0, 1],
59
+ rhs_batching_dimensions=[0, 1],
60
+ lhs_contracting_dimensions=(3,),
61
+ rhs_contracting_dimensions=(3,),
62
+ )
63
+ return stablehlo.dot_general(
64
+ ir.RankedTensorType.get(
65
+ (
66
+ lhs.type.shape[0],
67
+ lhs.type.shape[1],
68
+ lhs.type.shape[2],
69
+ rhs.type.shape[2],
70
+ ),
71
+ lhs.type.element_type,
72
+ ),
73
+ lhs,
74
+ rhs,
75
+ dot_dnums,
76
+ )
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -19,7 +19,6 @@ import os
19
19
  from typing import Optional, Union
20
20
  from ai_edge_torch._convert import converter as converter_utils
21
21
  from ai_edge_torch.generative.layers import lora as lora_utils
22
- import ai_edge_torch.generative.layers.kv_cache as kv_utils
23
22
  import ai_edge_torch.generative.layers.model_config as cfg
24
23
  from ai_edge_torch.generative.quantize import quant_recipes
25
24
  from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -151,9 +150,21 @@ def _export_helper(
151
150
  else None
152
151
  )
153
152
 
153
+ if export_config.prefill_mask is None:
154
+ prefill_masks = None
155
+ elif isinstance(export_config.prefill_mask, torch.Tensor):
156
+ prefill_masks = [export_config.prefill_mask]
157
+ elif isinstance(export_config.prefill_mask, list):
158
+ prefill_masks = export_config.prefill_mask
159
+ else:
160
+ raise ValueError('Prefill masks unrecognized.')
161
+
162
+ if prefill_masks:
163
+ assert len(prefill_masks) == len(prefill_seq_lens)
164
+
154
165
  decode_token = torch.tensor([[0]], dtype=torch.int)
155
166
  decode_input_pos = torch.tensor([0], dtype=torch.int)
156
- kv = kv_utils.KVCache.from_model_config(config)
167
+ kv = export_config.kvcache_cls.from_model_config(config)
157
168
 
158
169
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
159
170
 
@@ -174,6 +185,9 @@ def _export_helper(
174
185
  'input_pos': prefill_input_pos,
175
186
  'kv_cache': kv,
176
187
  }
188
+ if prefill_masks is not None:
189
+ sample_kwargs['mask'] = prefill_masks[i]
190
+
177
191
  if lora is not None:
178
192
  prefill_signature_name += f'_lora_r{lora.get_rank()}'
179
193
  sample_kwargs['lora'] = lora
@@ -199,6 +213,8 @@ def _export_helper(
199
213
  'input_pos': decode_input_pos,
200
214
  'kv_cache': kv,
201
215
  }
216
+ if export_config.decode_mask is not None:
217
+ sample_kwargs['mask'] = export_config.decode_mask
202
218
  if lora is not None:
203
219
  sample_kwargs['lora'] = lora
204
220
 
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  import copy
19
19
  from dataclasses import dataclass
20
- from typing import Optional, Tuple
20
+ from typing import List, Optional, Tuple
21
21
 
22
22
  from ai_edge_torch.generative.layers import attention
23
23
  from ai_edge_torch.generative.layers import builder
@@ -55,6 +55,11 @@ class ExportConfig:
55
55
  # On prefill signatures, should the model produce logit output?
56
56
  # When False, only decode signatures will produce output.
57
57
  output_logits_on_prefill: bool = False
58
+ # Attention masks given as inputs to the model.
59
+ prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
60
+ decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
61
+ # The KV Cache class for K and V buffers in attention.
62
+ kvcache_cls: type = kv_utils.KVCache
58
63
 
59
64
 
60
65
  class DecoderOnlyModel(nn.Module):
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -29,7 +29,7 @@ LoweringContext = context.LoweringContext
29
29
 
30
30
  @functools.cache
31
31
  def _log_usage(op):
32
- logging.warning("Use jax lowering: %s", str(op))
32
+ logging.info("Use JAX lowering: %s", str(op))
33
33
 
34
34
 
35
35
  def lower_by_jax(op, ir_input_names=None):
ai_edge_torch/quantize/pt2e_quantizer_utils.py CHANGED
@@ -21,8 +21,6 @@ from typing import Callable, Dict, List, NamedTuple, Optional
21
21
  import torch
22
22
  from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
23
23
  from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
24
- from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
25
- from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
26
24
  from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
27
25
  from torch.ao.quantization.quantizer import QuantizationAnnotation
28
26
  from torch.ao.quantization.quantizer import QuantizationSpec
@@ -47,6 +45,28 @@ __all__ = [
47
45
  "propagate_annotation",
48
46
  ]
49
47
 
48
+ # Example inputs for conv-bn1d patterns
49
+ _conv1d_bn_example_inputs = (
50
+ torch.randn(1, 1, 3), # x
51
+ torch.randn(1, 1, 1), # conv_weight
52
+ torch.randn(1), # conv_bias
53
+ torch.randn(1), # bn_weight
54
+ torch.randn(1), # bn_bias
55
+ torch.randn(1), # bn_running_mean
56
+ torch.randn(1), # bn_running_var
57
+ )
58
+
59
+ # Example inputs for conv-bn2d patterns
60
+ _conv2d_bn_example_inputs = (
61
+ torch.randn(1, 1, 3, 3), # x
62
+ torch.randn(1, 1, 1, 1), # conv_weight
63
+ torch.randn(1), # conv_bias
64
+ torch.randn(1), # bn_weight
65
+ torch.randn(1), # bn_bias
66
+ torch.randn(1), # bn_running_mean
67
+ torch.randn(1), # bn_running_var
68
+ )
69
+
50
70
 
51
71
  @dataclass(eq=True, frozen=True)
52
72
  class QuantizationConfig:
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20250123"
16
+ __version__ = "0.3.0.dev20250125"
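{ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@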
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20250123
3
+ Version: 0.3.0.dev20250125
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
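{ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD CHANGED
@@ -1,8 +1,8 @@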
1
1
  ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
2
- ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
2
+ ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=szrxg2aB7mcm59IL_QVIqapmbw9Nz8AQ28vc9684bqY,706
5
+ ai_edge_torch/version.py,sha256=yuz53SwRvngiQ41D-VX7MPmVGe-Vi-UR3v12E-o3P4I,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -16,11 +16,11 @@ ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZv
16
16
  ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
17
17
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
18
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
19
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
19
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
20
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
21
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
21
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=OCFcPP618zH8IE12KTBQm2hRTtsaSeO3egvlOBUpNxA,13911
22
22
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
23
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=lgoH32l6zAbWTCpa_4-RWkHjqbNaPsBnhSObLIX8dL4,10551
23
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
24
24
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
25
25
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
26
26
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
@@ -49,6 +49,10 @@ ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0H
49
49
  ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
50
50
  ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
51
51
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
52
+ ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
53
+ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=I5eA-XfFdHjYwDsLIjn23T2e-IgnSCQ129-5DOU8j44,2532
54
+ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
55
+ ai_edge_torch/generative/examples/deepseek/verify.py,sha256=sDYBhmE_CeZw5iLIQ7rJNGLjhcTyKUQGdg7_QQBh9WM,2398
52
56
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
53
57
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
54
58
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
@@ -97,7 +101,7 @@ ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68
97
101
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
98
102
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
99
103
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
100
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=sB_7-PVri8PxKnFG7c8GsTGyrxGEda-oZwGyyScTL3o,5239
104
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=GtwKAByEk0ENGEWbUmC2mAAPkbLZ3M5xH1HIToyu8QE,5307
101
105
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=sQKQ-k6H9kG2brgwLsktjCMeN2h0POyfMP6iNsPNKWc,16271
102
106
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6W58LxmHHkz2ctgpknQkyoDANZAnE9Byp_svfqLpQf0,34793
103
107
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -128,12 +132,17 @@ ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8k
128
132
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
129
133
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
130
134
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
131
- ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
135
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
132
136
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
133
- ai_edge_torch/generative/layers/model_config.py,sha256=9yPEmWNw3-_2wXBmPmZ7RUKcPXHF2ZbJwksyQoXTA6M,7784
137
+ ai_edge_torch/generative/layers/model_config.py,sha256=ZVRWEGw1BnLbLCuoR71kWGqQteKp-UM1YvMbbWYlkNw,7999
134
138
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
135
139
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
136
140
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
141
+ ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
142
+ ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
143
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
144
+ ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
145
+ ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
137
146
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
138
147
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
139
148
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -151,14 +160,15 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
151
160
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
152
161
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
153
162
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
154
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF_xz0UR3kD3x74ELsmJetbQnmv7-9gyQ,12473
163
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=AJs_ARfWUqwuFRwYtQQOLd87CiD4mUDwAhq885cqc4Q,12875
155
164
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
156
165
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
157
166
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
158
- ai_edge_torch/generative/utilities/converter.py,sha256=yNIZ-O6RdXYl8yuWM_sTENRxozPnKGS-TZRhiiTaraE,7515
167
+ ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
168
+ ai_edge_torch/generative/utilities/converter.py,sha256=6siSpCvH_cLV-eP40lkF_AqjBpYv68xeMRQ722fKgE0,8065
159
169
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
160
170
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
161
- ai_edge_torch/generative/utilities/model_builder.py,sha256=3CQLxJ02pFIo2DlS-RCn9cT6OvR4NiIuYRH597UXLiI,6530
171
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
162
172
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
163
173
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
164
174
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -197,7 +207,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw5
197
207
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
198
208
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
199
209
  ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
200
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
210
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=fte81SZxgxeMcI3wWVKSTnUjIxVVilOJ6H3TybXyDmQ,11558
201
211
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
202
212
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
203
213
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -207,13 +217,13 @@ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZm
207
217
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
208
218
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
209
219
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
210
- ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=eARD1LxLi5m7Z0n_psAkeX_AtUp4fNkE--oECBfivv4,36208
220
+ ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
211
221
  ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
212
222
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
213
223
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
214
224
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
215
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
216
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/METADATA,sha256=1IZCBOcKVCWbEfAQvEMgt39cuATDIzpK6AhW_gTnIY4,1966
217
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
218
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
219
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/RECORD,,
225
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
226
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/METADATA,sha256=BkUH2iAinJYGmBLTMdeYSpihXAHY_mBOkeprZLPaDGk,1966
227
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
228
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
229
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,