ai-edge-torch-nightly 0.3.0.dev20250123__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl

Files changed (28)
  1. ai_edge_torch/_config.py +9 -0
  2. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
  3. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
  5. ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
  7. ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
  8. ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
  10. ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
  11. ai_edge_torch/generative/layers/experimental/attention.py +269 -0
  12. ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
  13. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
  14. ai_edge_torch/generative/layers/experimental/types.py +97 -0
  15. ai_edge_torch/generative/layers/kv_cache.py +2 -1
  16. ai_edge_torch/generative/layers/model_config.py +5 -1
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
  18. ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
  19. ai_edge_torch/generative/utilities/converter.py +18 -2
  20. ai_edge_torch/generative/utilities/model_builder.py +6 -1
  21. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
  22. ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +28 -18
  26. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/experimental/types.py ADDED
@@ -0,0 +1,97 @@
+ # Copyright 2025 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A listing of types that describe the K and V tensors in KV caches.
+
+ import enum
+ from enum import Enum, auto
+ from typing import Tuple
+ from torch import nn
+
+
+ @enum.unique
+ class TensorDims(Enum):
+   BATCH = enum.auto()
+   SEQUENCE = enum.auto()
+   NUM_HEADS = enum.auto()
+   HEAD_DIM = enum.auto()
+   MODEL_DIM = enum.auto()  # often num_heads * head_dim
+
+
+ DIM_TO_LETTER = {
+     TensorDims.BATCH: 'B',
+     TensorDims.SEQUENCE: 'T',
+     TensorDims.NUM_HEADS: 'N',
+     TensorDims.HEAD_DIM: 'H',
+     TensorDims.MODEL_DIM: 'D',
+ }
+
+
+ class TensorDimensionMeta(type):
+   """Metaclass to create classes representing an order of tensor dimensions."""
+
+   def __new__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
+     """Creates a new class with the given name and tensor dimension order.
+
+     Args:
+       name: Name of the new class.
+       bases: Base classes for the new class.
+       attrs: Attributes for the new class.
+       dimensions: A tuple of TensorDims defining the order.
+     """
+
+     attrs['dimensions'] = (
+         dimensions  # Store the dimensions as a class attribute
+     )
+     return super().__new__(cls, name, bases, attrs)
+
+   def __init__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
+     super().__init__(name, bases, attrs)
+
+   def __repr__(cls):
+     return f'{cls.__name__}'
+
+
+ def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
+   """Creates a TensorDimensionMeta class with the specified dimensions.
+
+   Args:
+     dims: A tuple of TensorDims.
+
+   Returns:
+     A new class representing the tensor dimension order.
+   """
+   name = ''.join(DIM_TO_LETTER[d] for d in dims)
+   # Derive from nn.Module for torch tracing compatibility.
+   return TensorDimensionMeta(name, (nn.Module,), {}, dimensions=dims)
+
+
+ BTNH = create_tensor_dimension_order_class((
+     TensorDims.BATCH,
+     TensorDims.SEQUENCE,
+     TensorDims.NUM_HEADS,
+     TensorDims.HEAD_DIM,
+ ))
+ BNTH = create_tensor_dimension_order_class((
+     TensorDims.BATCH,
+     TensorDims.NUM_HEADS,
+     TensorDims.SEQUENCE,
+     TensorDims.HEAD_DIM,
+ ))
+ BNHT = create_tensor_dimension_order_class((
+     TensorDims.BATCH,
+     TensorDims.NUM_HEADS,
+     TensorDims.HEAD_DIM,
+     TensorDims.SEQUENCE,
+ ))
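
For orientation, a minimal sketch of how these dimension-order classes might be consumed. The assertion relies only on the `dimensions` attribute stored by the metaclass; the `to_bnth` helper is hypothetical and purely illustrative:

import torch
from ai_edge_torch.generative.layers.experimental import types

# Each generated class records its dimension order as a class attribute.
assert types.BTNH.dimensions == (
    types.TensorDims.BATCH,
    types.TensorDims.SEQUENCE,
    types.TensorDims.NUM_HEADS,
    types.TensorDims.HEAD_DIM,
)


def to_bnth(x: torch.Tensor, layout) -> torch.Tensor:
  # Hypothetical helper: permute a [B, T, N, H] tensor to [B, N, T, H] when the
  # declared layout is BTNH; otherwise assume it already matches.
  if layout is types.BTNH:
    return x.permute(0, 2, 1, 3)
  return x


print(to_bnth(torch.zeros(1, 16, 8, 64), types.BTNH).shape)  # (1, 8, 16, 64)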
ai_edge_torch/generative/layers/kv_cache.py CHANGED
@@ -81,7 +81,8 @@ class KVCache:
    """
    caches = [
        KVCacheEntry.from_model_config(
-           config.kv_cache_max,
+           config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
+           else config.block_config(idx).kv_cache_max_len,
            config.block_config(idx).attn_config,
            dtype,
            device,
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -164,6 +164,9 @@ class TransformerBlockConfig:
  parallel_residual: bool = False
  # The Attention computation will include relative positional bias.
  relative_attention: bool = False
+  # KV Cache length for this block. Only used when attention types are different
+  # across blocks
+  kv_cache_max_len: Optional[int] = None


 @dataclasses.dataclass
@@ -200,7 +203,8 @@ class ModelConfig:
  embedding_use_bias: bool = False
  # Image embedding parameters.
  image_embedding: Optional[ImageEmbeddingConfig] = None
-
+  # Number of image tokens
+  num_mm_tokens_per_image: Optional[int] = None
  # Use bias term within LLM's HEAD.
  lm_head_use_bias: bool = False
  # Whether LLM's HEAD shares the weight of the embedding.
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -17,6 +17,7 @@

import ai_edge_torch
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
+ from ai_edge_torch.generative.examples.deepseek import deepseek
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.examples.llama import llama
@@ -150,16 +151,15 @@ class TestModelConversion(googletest.TestCase):
      ai_edge_torch.config.in_oss,
      reason="tests with custom ops are not supported in oss",
  )
-
  def test_smollm2(self):
    config = smollm.get_fake_model_config_v2()
    pytorch_model = smollm.SmolLM2(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
+
  @googletest.skipIf(
      ai_edge_torch.config.in_oss,
      reason="tests with custom ops are not supported in oss",
  )
-
  def test_openelm(self):
    config = openelm.get_fake_model_config()
    pytorch_model = openelm.OpenELM(config).eval()
@@ -174,6 +174,15 @@ class TestModelConversion(googletest.TestCase):
    pytorch_model = qwen.Qwen(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

+   @googletest.skipIf(
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
+   )
+   def test_deepseek(self):
+     config = deepseek.get_fake_model_config()
+     pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
+     self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
+
  @googletest.skipIf(
      ai_edge_torch.config.in_oss,
      reason="tests with custom ops are not supported in oss",
ai_edge_torch/generative/utilities/bmm_4d.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Common utility functions for data loading etc.
+ from dataclasses import dataclass
+ import glob
+ import os
+ from typing import Sequence
+ from ai_edge_torch.odml_torch import lowerings
+ from jax._src.lib.mlir import ir
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import torch
+
+
+ # Use torch.library.custom_op to define a new custom operator.
+ @torch.library.custom_op("ai_edge_torch::bmm_4d", mutates_args=())
+ def bmm_4d(
+     lhs: torch.Tensor,
+     rhs: torch.Tensor,
+ ) -> torch.Tensor:
+   if not (lhs.ndim == 4 and rhs.ndim == 4):
+     raise ValueError("bmm_4d requires LHS and RHS have rank 4.")
+   d0_can_bcast = lhs.shape[0] == rhs.shape[0] or lhs.shape[0] == 1 or rhs.shape[0] == 1
+   d1_can_bcast = lhs.shape[1] == rhs.shape[1] or lhs.shape[1] == 1 or rhs.shape[1] == 1
+   if not (d0_can_bcast and d1_can_bcast):
+     raise ValueError("bmm_4d requires that dimensions 0 and 1 can broadcast.")
+
+   if not lhs.shape[-1] == rhs.shape[-1]:
+     raise ValueError("bmm_4d requires LHS and RHS have same last dimension.")
+
+   return torch.einsum("abcd,abed->abce", lhs, rhs)
+
+
+ # Use register_fake to add a ``FakeTensor`` kernel for the operator
+ @bmm_4d.register_fake
+ def _(lhs, rhs):
+   return torch.einsum("abcd,abed->abce", lhs, rhs)
+
+
+ @lowerings.lower(torch.ops.ai_edge_torch.bmm_4d)
+ def _bmm_4d_lower(
+     lctx,
+     lhs: ir.Value,
+     rhs: ir.Value,
+ ):
+   dot_dnums = stablehlo.DotDimensionNumbers.get(
+       lhs_batching_dimensions=[0, 1],
+       rhs_batching_dimensions=[0, 1],
+       lhs_contracting_dimensions=(3,),
+       rhs_contracting_dimensions=(3,),
+   )
+   return stablehlo.dot_general(
+       ir.RankedTensorType.get(
+           (
+               lhs.type.shape[0],
+               lhs.type.shape[1],
+               lhs.type.shape[2],
+               rhs.type.shape[2],
+           ),
+           lhs.type.element_type,
+       ),
+       lhs,
+       rhs,
+       dot_dnums,
+   )
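
As a quick usage sketch (shapes chosen for illustration): importing the module above registers the `ai_edge_torch::bmm_4d` op as its decorators run, after which the op can be called through `torch.ops`; the result matches `einsum("abcd,abed->abce")`.

import torch
import ai_edge_torch.generative.utilities.bmm_4d  # noqa: F401 -- registers the custom op

lhs = torch.randn(2, 8, 16, 64)  # e.g. [batch, num_heads, q_len, head_dim]
rhs = torch.randn(2, 8, 32, 64)  # e.g. [batch, num_heads, kv_len, head_dim]
out = torch.ops.ai_edge_torch.bmm_4d(lhs, rhs)
print(out.shape)  # torch.Size([2, 8, 16, 32])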
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -19,7 +19,6 @@ import os
from typing import Optional, Union
from ai_edge_torch._convert import converter as converter_utils
from ai_edge_torch.generative.layers import lora as lora_utils
- import ai_edge_torch.generative.layers.kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -151,9 +150,21 @@ def _export_helper(
      else None
  )

+  if export_config.prefill_mask is None:
+    prefill_masks = None
+  elif isinstance(export_config.prefill_mask, torch.Tensor):
+    prefill_masks = [export_config.prefill_mask]
+  elif isinstance(export_config.prefill_mask, list):
+    prefill_masks = export_config.prefill_mask
+  else:
+    raise ValueError('Prefill masks unrecognized.')
+
+  if prefill_masks:
+    assert len(prefill_masks) == len(prefill_seq_lens)
+
  decode_token = torch.tensor([[0]], dtype=torch.int)
  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(config)
+  kv = export_config.kvcache_cls.from_model_config(config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None

@@ -174,6 +185,9 @@ def _export_helper(
        'input_pos': prefill_input_pos,
        'kv_cache': kv,
    }
+    if prefill_masks is not None:
+      sample_kwargs['mask'] = prefill_masks[i]
+
    if lora is not None:
      prefill_signature_name += f'_lora_r{lora.get_rank()}'
      sample_kwargs['lora'] = lora
@@ -199,6 +213,8 @@ def _export_helper(
      'input_pos': decode_input_pos,
      'kv_cache': kv,
  }
+  if export_config.decode_mask is not None:
+    sample_kwargs['mask'] = export_config.decode_mask
  if lora is not None:
    sample_kwargs['lora'] = lora

ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -17,7 +17,7 @@

import copy
from dataclasses import dataclass
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
@@ -55,6 +55,11 @@ class ExportConfig:
  # On prefill signatures, should the model produce logit output?
  # When False, only decode signatures will produce output.
  output_logits_on_prefill: bool = False
+  # Attention masks given as inputs to the model.
+  prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
+  decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
+  # The KV Cache class for K and V buffers in attention.
+  kvcache_cls: type = kv_utils.KVCache


class DecoderOnlyModel(nn.Module):
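
For context, a sketch of how a conversion script could populate the new ExportConfig fields, which the converter changes above consume; the mask shapes and prefill lengths are invented for illustration, only the field names come from the diff:

import torch
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

prefill_seq_lens = [128, 512]  # hypothetical prefill signature lengths
export_config = ExportConfig(
    # One prefill mask per prefill signature, matched by list position.
    prefill_mask=[
        torch.triu(torch.full((1, 1, n, n), float('-inf')), diagonal=1)
        for n in prefill_seq_lens
    ],
    # A single-step mask for the decode signature (shape is illustrative).
    decode_mask=torch.zeros((1, 1, 1, 1024)),
)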
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -29,7 +29,7 @@ LoweringContext = context.LoweringContext

@functools.cache
def _log_usage(op):
-   logging.warning("Use jax lowering: %s", str(op))
+   logging.info("Use JAX lowering: %s", str(op))


def lower_by_jax(op, ir_input_names=None):
ai_edge_torch/quantize/pt2e_quantizer_utils.py CHANGED
@@ -21,8 +21,6 @@ from typing import Callable, Dict, List, NamedTuple, Optional
import torch
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
- from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
- from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.ao.quantization.quantizer import QuantizationSpec
@@ -47,6 +45,28 @@ __all__ = [
    "propagate_annotation",
]

+ # Example inputs for conv-bn1d patterns
+ _conv1d_bn_example_inputs = (
+     torch.randn(1, 1, 3),  # x
+     torch.randn(1, 1, 1),  # conv_weight
+     torch.randn(1),  # conv_bias
+     torch.randn(1),  # bn_weight
+     torch.randn(1),  # bn_bias
+     torch.randn(1),  # bn_running_mean
+     torch.randn(1),  # bn_running_var
+ )
+
+ # Example inputs for conv-bn2d patterns
+ _conv2d_bn_example_inputs = (
+     torch.randn(1, 1, 3, 3),  # x
+     torch.randn(1, 1, 1, 1),  # conv_weight
+     torch.randn(1),  # conv_bias
+     torch.randn(1),  # bn_weight
+     torch.randn(1),  # bn_bias
+     torch.randn(1),  # bn_running_mean
+     torch.randn(1),  # bn_running_var
+ )
+

@dataclass(eq=True, frozen=True)
class QuantizationConfig:
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
# limitations under the License.
# ==============================================================================

- __version__ = "0.3.0.dev20250123"
+ __version__ = "0.3.0.dev20250125"
{ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20250123
+ Version: 0.3.0.dev20250125
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
Home-page: https://github.com/google-ai-edge/ai-edge-torch
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD CHANGED
@@ -1,8 +1,8 @@
ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
- ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
+ ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=szrxg2aB7mcm59IL_QVIqapmbw9Nz8AQ28vc9684bqY,706
+ ai_edge_torch/version.py,sha256=yuz53SwRvngiQ41D-VX7MPmVGe-Vi-UR3v12E-o3P4I,706
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -16,11 +16,11 @@ ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZv
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=OCFcPP618zH8IE12KTBQm2hRTtsaSeO3egvlOBUpNxA,13911
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=lgoH32l6zAbWTCpa_4-RWkHjqbNaPsBnhSObLIX8dL4,10551
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
@@ -49,6 +49,10 @@ ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0H
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
+ ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
+ ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=I5eA-XfFdHjYwDsLIjn23T2e-IgnSCQ129-5DOU8j44,2532
+ ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
+ ai_edge_torch/generative/examples/deepseek/verify.py,sha256=sDYBhmE_CeZw5iLIQ7rJNGLjhcTyKUQGdg7_QQBh9WM,2398
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
@@ -97,7 +101,7 @@ ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=sB_7-PVri8PxKnFG7c8GsTGyrxGEda-oZwGyyScTL3o,5239
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=GtwKAByEk0ENGEWbUmC2mAAPkbLZ3M5xH1HIToyu8QE,5307
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=sQKQ-k6H9kG2brgwLsktjCMeN2h0POyfMP6iNsPNKWc,16271
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6W58LxmHHkz2ctgpknQkyoDANZAnE9Byp_svfqLpQf0,34793
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -128,12 +132,17 @@ ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8k
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
- ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
- ai_edge_torch/generative/layers/model_config.py,sha256=9yPEmWNw3-_2wXBmPmZ7RUKcPXHF2ZbJwksyQoXTA6M,7784
+ ai_edge_torch/generative/layers/model_config.py,sha256=ZVRWEGw1BnLbLCuoR71kWGqQteKp-UM1YvMbbWYlkNw,7999
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
+ ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
+ ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
+ ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
+ ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -151,14 +160,15 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF_xz0UR3kD3x74ELsmJetbQnmv7-9gyQ,12473
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=AJs_ARfWUqwuFRwYtQQOLd87CiD4mUDwAhq885cqc4Q,12875
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
- ai_edge_torch/generative/utilities/converter.py,sha256=yNIZ-O6RdXYl8yuWM_sTENRxozPnKGS-TZRhiiTaraE,7515
+ ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
+ ai_edge_torch/generative/utilities/converter.py,sha256=6siSpCvH_cLV-eP40lkF_AqjBpYv68xeMRQ722fKgE0,8065
ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
- ai_edge_torch/generative/utilities/model_builder.py,sha256=3CQLxJ02pFIo2DlS-RCn9cT6OvR4NiIuYRH597UXLiI,6530
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -197,7 +207,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw5
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=fte81SZxgxeMcI3wWVKSTnUjIxVVilOJ6H3TybXyDmQ,11558
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
@@ -207,13 +217,13 @@ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZm
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
- ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=eARD1LxLi5m7Z0n_psAkeX_AtUp4fNkE--oECBfivv4,36208
+ ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/METADATA,sha256=1IZCBOcKVCWbEfAQvEMgt39cuATDIzpK6AhW_gTnIY4,1966
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/METADATA,sha256=BkUH2iAinJYGmBLTMdeYSpihXAHY_mBOkeprZLPaDGk,1966
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,