ai-edge-torch-nightly 0.3.0.dev20240813__py3-none-any.whl → 0.3.0.dev20240817__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the package's registry page for more details.

Files changed (32):
  1. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +2 -2
  2. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +2 -2
  3. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +67 -0
  4. ai_edge_torch/generative/examples/gemma/gemma.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +250 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -2
  7. ai_edge_torch/generative/examples/t5/t5.py +4 -4
  8. ai_edge_torch/generative/examples/t5/t5_attention.py +3 -3
  9. ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +1 -1
  11. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +1 -1
  12. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  13. ai_edge_torch/generative/layers/attention.py +12 -5
  14. ai_edge_torch/generative/layers/attention_utils.py +30 -0
  15. ai_edge_torch/generative/layers/builder.py +5 -0
  16. ai_edge_torch/generative/layers/feed_forward.py +15 -3
  17. ai_edge_torch/generative/layers/model_config.py +35 -13
  18. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +25 -9
  19. ai_edge_torch/generative/test/test_model_conversion.py +29 -1
  20. ai_edge_torch/generative/utilities/loader.py +29 -7
  21. ai_edge_torch/generative/utilities/t5_loader.py +8 -8
  22. ai_edge_torch/hlfb/test/test_mark_pattern.py +32 -8
  23. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +5 -0
  24. ai_edge_torch/lowertools/__init__.py +1 -0
  25. ai_edge_torch/lowertools/odml_torch_utils.py +3 -0
  26. ai_edge_torch/lowertools/test_utils.py +60 -0
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/RECORD +32 -29
  30. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240817.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@
16
16
  from dataclasses import dataclass
17
17
  from dataclasses import field
18
18
  import enum
19
- from typing import Optional
19
+ from typing import Optional, Sequence
20
20
 
21
21
 
22
22
  @enum.unique
@@ -53,6 +53,11 @@ class FeedForwardType(enum.Enum):
53
53
  GATED = enum.auto()
54
54
 
55
55
 
56
+ class AttentionType(enum.Enum):
57
+ GLOBAL = enum.auto()
58
+ LOCAL_SLIDING = enum.auto()
59
+
60
+
56
61
  @dataclass
57
62
  class AttentionConfig:
58
63
  """Attention model's parameters."""
@@ -78,6 +83,12 @@ class AttentionConfig:
78
83
  enable_kv_cache: bool = True
79
84
  relative_attention_num_buckets: int = 0
80
85
  relative_attention_max_distance: int = 0
86
+ # Softcap on the output logits.
87
+ logit_softcap: Optional[float] = None
88
+ # The types of attention used in the layers of the model.
89
+ attn_types: Optional[Sequence[AttentionType]] = None
90
+ # The size of the sliding window used for local attention.
91
+ sliding_window_size: Optional[int] = None
81
92
 
82
93
 
83
94
  @dataclass
@@ -88,16 +99,6 @@ class ActivationConfig:
88
99
  dim_out: Optional[int] = None
89
100
 
90
101
 
91
- @dataclass
92
- class FeedForwardConfig:
93
- """FeedForward module's parameters."""
94
-
95
- type: FeedForwardType
96
- activation: ActivationConfig
97
- intermediate_size: int
98
- use_bias: bool = False
99
-
100
-
101
102
  @dataclass
102
103
  class NormalizationConfig:
103
104
  """Normalizater parameters."""
@@ -109,6 +110,24 @@ class NormalizationConfig:
109
110
  group_num: Optional[float] = None
110
111
 
111
112
 
113
+ @dataclass
114
+ class FeedForwardConfig:
115
+ """FeedForward module's parameters."""
116
+
117
+ type: FeedForwardType
118
+ activation: ActivationConfig
119
+ intermediate_size: int
120
+ use_bias: bool = False
121
+ # The normalization applied to feed forward's input.
122
+ pre_ff_norm_config: NormalizationConfig = field(
123
+ default_factory=NormalizationConfig
124
+ )
125
+ # The normalization applied to feed forward's output.
126
+ post_ff_norm_config: NormalizationConfig = field(
127
+ default_factory=NormalizationConfig
128
+ )
129
+
130
+
112
131
  @dataclass
113
132
  class ModelConfig:
114
133
  """Base configurations for building a transformer architecture."""
@@ -124,8 +143,8 @@ class ModelConfig:
124
143
  pre_attention_norm_config: NormalizationConfig = field(
125
144
  default_factory=NormalizationConfig
126
145
  )
127
- # The normalization applied to feed forward's input.
128
- pre_ff_norm_config: NormalizationConfig = field(
146
+ # The normalization applied to attentions's output.
147
+ post_attention_norm_config: NormalizationConfig = field(
129
148
  default_factory=NormalizationConfig
130
149
  )
131
150
  # The normalization applied before LM head.
@@ -151,6 +170,9 @@ class ModelConfig:
151
170
  # Default batch size of the exported model. Default value is 1.
152
171
  batch_size: int = 1
153
172
 
173
+ # Softcap on the model output logits.
174
+ final_logit_softcap: Optional[float] = None
175
+
154
176
  @property
155
177
  def kv_cache_max(self) -> int:
156
178
  if self.kv_cache_max_len > 0:
@@ -29,6 +29,7 @@ def scaled_dot_product_attention(
29
29
  head_size: int,
30
30
  mask: Optional[torch.Tensor] = None,
31
31
  scale: Optional[float] = None,
32
+ softcap: Optional[float] = None,
32
33
  ):
33
34
  """Scaled dot product attention.
34
35
 
@@ -53,15 +54,26 @@ def scaled_dot_product_attention(
53
54
  # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
54
55
  k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
55
56
  v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
56
- y = F.scaled_dot_product_attention(
57
- q,
58
- k,
59
- v,
60
- attn_mask=mask,
61
- dropout_p=0.0,
62
- is_causal=mask is None,
63
- scale=scale,
64
- )
57
+ if softcap is None:
58
+ y = F.scaled_dot_product_attention(
59
+ q,
60
+ k,
61
+ v,
62
+ attn_mask=mask,
63
+ dropout_p=0.0,
64
+ is_causal=mask is None,
65
+ scale=scale,
66
+ )
67
+ else:
68
+ q.mul_(scale)
69
+ scores = q @ k.transpose(-1, -2)
70
+ scores = scores / softcap
71
+ scores = torch.tanh(scores)
72
+ scores = scores * softcap
73
+ scores = scores + mask
74
+ out = F.softmax(scores.float(), dim=-1).type_as(q)
75
+ y = torch.matmul(out, v)
76
+
65
77
  return y.transpose(1, 2)
66
78
 
67
79
 
@@ -72,6 +84,7 @@ def scaled_dot_product_attention_with_hlfb(
72
84
  head_size: int,
73
85
  mask: Optional[torch.Tensor] = None,
74
86
  scale: Optional[float] = None,
87
+ softcap: Optional[float] = None,
75
88
  ):
76
89
  """Scaled dot product attention with high-level function boundary enabled.
77
90
 
@@ -86,6 +99,9 @@ def scaled_dot_product_attention_with_hlfb(
86
99
  The output tensor of scaled_dot_product_attention.
87
100
  """
88
101
 
102
+ if softcap is not None:
103
+ raise NotImplementedError("SDPA with HLFB not available with softcap.")
104
+
89
105
  if scale is None:
90
106
  scale = 1.0 / math.sqrt(head_size)
91
107
 
@@ -16,7 +16,7 @@
16
16
  import copy
17
17
 
18
18
  import ai_edge_torch
19
- from ai_edge_torch.generative.examples.gemma import gemma
19
+ from ai_edge_torch.generative.examples.gemma import gemma, gemma2
20
20
  from ai_edge_torch.generative.examples.phi2 import phi2
21
21
  from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
22
22
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
@@ -202,6 +202,34 @@ class TestModelConversion(googletest.TestCase):
202
202
  )
203
203
  )
204
204
 
205
+ def test_gemma2(self):
206
+ self.skipTest("b/338288901")
207
+ config = gemma2.get_fake_model_config_2b_for_test()
208
+ model = gemma2.Gemma2(config)
209
+ model.eval()
210
+
211
+ idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
212
+ tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
213
+ tokens[0, :4] = idx
214
+ input_pos = torch.arange(0, 10)
215
+
216
+ edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
217
+
218
+ # TODO: b/338288901 - re-enable test to check output tensors.
219
+ skip_output_check = True
220
+ if not skip_output_check:
221
+ # TODO(talumbau, haoliang): debug numerical diff.
222
+ self.assertTrue(
223
+ model_coverage.compare_tflite_torch(
224
+ edge_model,
225
+ model,
226
+ (tokens, input_pos),
227
+ num_valid_inputs=1,
228
+ atol=1e-2,
229
+ rtol=1e-5,
230
+ )
231
+ )
232
+
205
233
  def test_phi2(self):
206
234
  self.skipTest("b/338288901")
207
235
  config = phi2.get_fake_model_config_for_test()
@@ -107,7 +107,9 @@ class ModelLoader:
107
107
  ff_gate_proj: str = None
108
108
 
109
109
  pre_attn_norm: str = None
110
+ post_attn_norm: str = None
110
111
  pre_ff_norm: str = None
112
+ post_ff_norm: str = None
111
113
  embedding: str = None
112
114
  embedding_position: str = None
113
115
  final_norm: str = None
@@ -258,6 +260,26 @@ class ModelLoader:
258
260
  f"{ff_gate_proj_name}.bias"
259
261
  )
260
262
 
263
+ if self._names.pre_ff_norm is not None:
264
+ pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
265
+ converted_state[f"{prefix}.ff.pre_ff_norm.weight"] = state.pop(
266
+ f"{pre_ff_norm_name}.weight"
267
+ )
268
+ if f"{pre_ff_norm_name}.bias" in state:
269
+ converted_state[f"{prefix}.ff.pre_ff_norm.bias"] = state.pop(
270
+ f"{pre_ff_norm_name}.bias"
271
+ )
272
+
273
+ if self._names.post_ff_norm is not None:
274
+ post_ff_norm_name = self._names.post_ff_norm.format(idx)
275
+ converted_state[f"{prefix}.ff.post_ff_norm.weight"] = state.pop(
276
+ f"{post_ff_norm_name}.weight"
277
+ )
278
+ if f"{post_ff_norm_name}.bias" in state:
279
+ converted_state[f"{prefix}.ff.post_ff_norm.bias"] = state.pop(
280
+ f"{post_ff_norm_name}.bias"
281
+ )
282
+
261
283
  def _map_attention(
262
284
  self,
263
285
  idx: int,
@@ -325,14 +347,14 @@ class ModelLoader:
325
347
  f"{pre_attn_norm_name}.bias"
326
348
  )
327
349
 
328
- if self._names.pre_ff_norm is not None:
329
- pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
330
- converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
331
- f"{pre_ff_norm_name}.weight"
350
+ if self._names.post_attn_norm is not None:
351
+ post_attn_norm_name = self._names.post_attn_norm.format(idx)
352
+ converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
353
+ f"{post_attn_norm_name}.weight"
332
354
  )
333
- if f"{pre_ff_norm_name}.bias" in state:
334
- converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
335
- f"{pre_ff_norm_name}.bias"
355
+ if f"{post_attn_norm_name}.bias" in state:
356
+ converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
357
+ f"{post_attn_norm_name}.bias"
336
358
  )
337
359
 
338
360
  def _fuse_qkv(
@@ -113,7 +113,7 @@ class ModelLoader:
113
113
 
114
114
  pre_attn_norm: str = None
115
115
  pre_cross_attn_norm: str = None
116
- pre_ff_norm: str = None
116
+ post_attn_norm: str = None
117
117
  embedding: str = None
118
118
  final_norm: str = None
119
119
  lm_head: str = None
@@ -484,14 +484,14 @@ class ModelLoader:
484
484
  state.pop(f"{pre_cross_attn_norm_name}.bias")
485
485
  )
486
486
 
487
- if names.pre_ff_norm is not None:
488
- pre_ff_norm_name = names.pre_ff_norm.format(idx)
489
- converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
490
- f"{pre_ff_norm_name}.weight"
487
+ if names.post_attn_norm is not None:
488
+ post_attn_norm_name = names.post_attn_norm.format(idx)
489
+ converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
490
+ f"{post_attn_norm_name}.weight"
491
491
  )
492
- if f"{pre_ff_norm_name}.bias" in state:
493
- converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
494
- f"{pre_ff_norm_name}.bias"
492
+ if f"{post_attn_norm_name}.bias" in state:
493
+ converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
494
+ f"{post_attn_norm_name}.bias"
495
495
  )
496
496
 
497
497
  def _fuse_qkv(
@@ -51,7 +51,12 @@ class TestMarkPattern(googletest.TestCase):
51
51
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
52
52
  mlir = _export_stablehlo_mlir(exported_program)
53
53
 
54
- self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
54
+ lowertools.assert_string_count(
55
+ self,
56
+ mlir,
57
+ {'stablehlo.composite "test.add"': 2},
58
+ {"stablehlo.custom_call @mark_tensor": 6},
59
+ )
55
60
 
56
61
  def test_mark_pattern_with_attr_builder(self):
57
62
  class TestModel(torch.nn.Module):
@@ -72,9 +77,15 @@ class TestMarkPattern(googletest.TestCase):
72
77
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
73
78
  mlir = _export_stablehlo_mlir(exported_program)
74
79
 
75
- self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
76
- self.assertEqual(
77
- mlir.count('composite_attributes = {alias = "test.test_add"}'), 2
80
+ lowertools.assert_string_count(
81
+ self,
82
+ mlir,
83
+ {
84
+ 'stablehlo.composite "test.add"': 2,
85
+ 'composite_attributes = {alias = "test.test_add"}': 2,
86
+ },
87
+ {"stablehlo.custom_call @mark_tensor": 6},
88
+ {'{"alias": "test.test_add"}': 2},
78
89
  )
79
90
 
80
91
  def test_mark_pattern_with_scalar_attr_tracker(self):
@@ -104,9 +115,17 @@ class TestMarkPattern(googletest.TestCase):
104
115
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
105
116
  mlir = _export_stablehlo_mlir(exported_program)
106
117
 
107
- self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 5)
108
- self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 3)
109
- self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 2)
118
+ lowertools.assert_string_count(
119
+ self,
120
+ mlir,
121
+ {
122
+ 'stablehlo.composite "test.log_softmax"': 5,
123
+ "composite_attributes = {dim = 0 : i64}": 3,
124
+ "composite_attributes = {dim = 1 : i64}": 2,
125
+ },
126
+ {"stablehlo.custom_call @mark_tensor": 10},
127
+ {'{"dim": 0}': 3, '{"dim": 1}': 2},
128
+ )
110
129
 
111
130
  def test_mark_tangent_model_and_pattern_input(self):
112
131
  class TestModel(torch.nn.Module):
@@ -128,7 +147,12 @@ class TestMarkPattern(googletest.TestCase):
128
147
  mark_pattern.mark_pattern(exported_program.graph_module, pattern)
129
148
  mlir = _export_stablehlo_mlir(exported_program)
130
149
 
131
- self.assertEqual(mlir.count('stablehlo.composite "test.relu'), 1)
150
+ lowertools.assert_string_count(
151
+ self,
152
+ mlir,
153
+ {'stablehlo.composite "test.relu"': 1},
154
+ {"stablehlo.custom_call @mark_tensor": 2},
155
+ )
132
156
 
133
157
 
134
158
  if __name__ == "__main__":
@@ -16,6 +16,7 @@
16
16
 
17
17
  import math
18
18
 
19
+ from ai_edge_torch import config
19
20
  from ai_edge_torch import lowertools
20
21
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
21
22
  import torch
@@ -29,6 +30,10 @@ def _export_stablehlo_mlir(model, args):
29
30
  return lowertools.exported_program_to_mlir_text(ep)
30
31
 
31
32
 
33
+ @googletest.skipIf(
34
+ not config.Config.use_torch_xla,
35
+ reason="The odml_torch counter part is in odml_torch.",
36
+ )
32
37
  class TestStableHLOCompositeBuilder(googletest.TestCase):
33
38
 
34
39
  def test_build_composite(self):
@@ -14,3 +14,4 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from ._shim import *
17
+ from .test_utils import *
@@ -28,6 +28,7 @@ import tensorflow as tf
28
28
  import torch
29
29
 
30
30
  from tensorflow.compiler.tf2xla.python import xla as tfxla
31
+ from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
31
32
 
32
33
  MlirBundle = odml_torch.export.MlirLowered
33
34
 
@@ -162,7 +163,9 @@ def merged_bundle_to_tfl_model(
162
163
  )
163
164
 
164
165
  converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
166
+ converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
165
167
  converter._experimental_enable_composite_direct_lowering = True
168
+ converter.model_origin_framework = "PYTORCH"
166
169
 
167
170
  conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
168
171
 
@@ -0,0 +1,60 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import re
17
+ from typing import Optional
18
+ from ai_edge_torch import config
19
+ from tensorflow.python.platform import googletest
20
+
21
+
22
+ def _extract_backend_configs(mlir):
23
+ mlir = mlir.replace("\\22", '"')
24
+ configs = []
25
+ for match in re.finditer(r"backend_config\s*=\s*\"(\{.*\})\"", mlir):
26
+ configs.append(match.group(1))
27
+ return "\n".join(configs)
28
+
29
+
30
+ def assert_string_count(
31
+ test_case: googletest.TestCase,
32
+ mlir: str,
33
+ torch_xla_pattern_counter: dict[str, int],
34
+ odml_torch_pattern_counter: dict[str, int],
35
+ odml_torch_attr_counter: Optional[dict[str, int]] = None,
36
+ ):
37
+
38
+ if odml_torch_attr_counter is None:
39
+ odml_torch_attr_counter = {}
40
+
41
+ if config.Config.use_torch_xla:
42
+ for key in torch_xla_pattern_counter:
43
+ test_case.assertEqual(
44
+ mlir.count(key),
45
+ torch_xla_pattern_counter[key],
46
+ )
47
+ else:
48
+ for key in odml_torch_pattern_counter:
49
+ test_case.assertEqual(
50
+ mlir.count(key),
51
+ odml_torch_pattern_counter[key],
52
+ )
53
+ backend_configs = _extract_backend_configs(mlir)
54
+ print("backend_configs:")
55
+ print(backend_configs)
56
+ for key in odml_torch_attr_counter:
57
+ test_case.assertEqual(
58
+ backend_configs.count(key),
59
+ odml_torch_attr_counter[key],
60
+ )
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240813"
16
+ __version__ = "0.3.0.dev20240817"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240813
3
+ Version: 0.3.0.dev20240817
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=5DYNpFVwvI1w0JbAC1hn83NJVGS1WPX7n742419PMqs,4558
5
- ai_edge_torch/version.py,sha256=C9Lsgh_kXnELi8xLPUgnmDTLFOKW5S5z6lXAVwLMypU,706
5
+ ai_edge_torch/version.py,sha256=C4yfJq9TbtZBH5gwhPSUtBgiIe04GkxvCq5TImNopww,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,22 +42,24 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
42
42
  ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
43
  ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
44
44
  ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=QoFbUUCTJrW1IYZg0vfb2-K-X0q1-NJFbWNGPQGwBgk,6688
45
+ ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=8313wSsddvuxZ5ZYVdaITBV2FF1k22dcCujnq0UZvKs,6699
46
46
  ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
47
47
  ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
48
48
  ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=u-VJX5mjzQKspXtAhNi53LCITtag-3nCaRTKdk5Z1sc,6231
49
49
  ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
50
  ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
51
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=GOLLd9yCBnlNXeW7xrVy1wjOltcTbRdSpiJycbMj8TA,6372
51
+ ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=zQYtyk3xYdiRAnzMKN58Q_wgTQFnDujxp6L4RFQjiD4,6383
52
52
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
53
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
53
54
  ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
54
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=5Dn9JgJiXN-hWGQj9YqCr8Iik8mh5s0dX0VfyY8KDDo,6236
55
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=cCki-0cKvmGxK4Md6dRNdPDWZUyhkJUI854OCTFf3h0,6262
56
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=j-zxJ-JNRnQ_kDzUESmsyy_a_4IxWZ510HmIImc0LDc,8240
55
57
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
56
58
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
57
59
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=C_kFYsPrEQ9GJCnc6h-jh8B5qQryvEpI6O6t4FBxg1I,5858
58
60
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
59
61
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
60
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=VR09iAnj1e-sr-oam2rh24Wnb_JdZZQvpJIjylfgnS8,4468
62
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
61
63
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=7ra36nM5tQwSw-vi6QCFLx5IssZhT-6yVK4H3XsAc4w,5044
62
64
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
63
65
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
@@ -72,27 +74,27 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
72
74
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
73
75
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
74
76
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz70WOvNpTJ9LFkiDnlwgJiXfUZCVk,4548
75
- ai_edge_torch/generative/examples/t5/t5.py,sha256=6Rkisv7UI2w5KV8ogPPzeIiPWYwDLfFfSIncqD7Eenc,20854
76
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=gp7DV8pv4FwICQhYlUYfYZ7BE5jzDIsD_V3a_4-T4Ds,8492
77
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
78
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
77
79
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
78
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=DhxOrIKe-tilBjbh1q4MsmCmmKMc4c1BPUzhnaJDD6M,3955
79
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=bW0QB-_h9cfwAQf11AxFxOBq3HrEep_UlpBjXz3JSew,5801
80
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=CRja_CT0_eaH16rSDxwHKJS_CGUJMW0Fxd4r45Ii8Uo,4833
80
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LfWO_gSr1f66V1pxAc6yh21mtaJs7TVeuO9748zXBnE,3963
81
+ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
82
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
81
83
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
82
84
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
83
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=nu3Il8Vxe7JwM8-AnGNXoGoZ9eVXKHMYEAqVEP-gwe8,5929
85
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mXXFYJfo8yegSOFOndCR0oYxFPchYb9vTJ4ThXGIFLU,5940
84
86
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
85
87
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
86
88
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
87
- ai_edge_torch/generative/layers/attention.py,sha256=xq10Gw4GudK4M2eY8-H4fi3qmpmZCfE-CziAXDZvqiQ,12177
88
- ai_edge_torch/generative/layers/attention_utils.py,sha256=2hzBVZvWCqqLfI-f3RJA1hi6T8cuaIJBPt8cdjQCA5s,6420
89
- ai_edge_torch/generative/layers/builder.py,sha256=JvPmwrG8_M4-kO2MM6sDZhpS32Wx3wVVhlVO4yPJKJ0,4161
90
- ai_edge_torch/generative/layers/feed_forward.py,sha256=RukSYr9h_DehcYVZWLS_rfCTY73Uj__pTRUatjxJtv8,2788
89
+ ai_edge_torch/generative/layers/attention.py,sha256=2UujQePRJ1LK02PN-hGcuMu0ooCJC6ETfPvzEYVFyho,12284
90
+ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
91
+ ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
92
+ ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
91
93
  ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lfvKca8JbWtFx2CE,3082
92
- ai_edge_torch/generative/layers/model_config.py,sha256=CTvKFwsBR3Rc-Kf73NA7k0799m1WnEvaEBKCnnfNkyo,4961
94
+ ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
93
95
  ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
94
96
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
95
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=6WMe-A5KSSujQcZ34hIeSnnor3AXrw10cQ5FKy-30IU,3390
97
+ ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=x2bOmrTgOISXcb06IDP7X3xgftpPpxOjBXw_OxTMVns,3874
96
98
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
97
99
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
98
100
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -109,23 +111,24 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha
109
111
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
110
112
  ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=T5-O2RVLJTH7v9w1_uBfp-Y7o3sdGzYq2Tj2wLRNHyI,4357
111
113
  ai_edge_torch/generative/test/test_loader.py,sha256=1ZqAq0HY5uIioumsReOVIsbGBx0WkYcl18PvttdJKrk,3381
112
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=4RTB1oPA2eWPyuof2-ZB1BxVKzKy5Q9vCux7psmV6zc,7615
114
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=52ciFy_Qol2Xuym6P6EqdL29oai35LSWGvsUwyEdFTo,8477
113
115
  ai_edge_torch/generative/test/test_quantize.py,sha256=3SmJm7Kq98gAneU6IGwwJrJYCVH1qwWR6oUxPfb6qiI,5346
114
116
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
115
- ai_edge_torch/generative/utilities/loader.py,sha256=XfVRvwvZyQuofctxIedLNDKQrsy9UlRr4wpScZJLWcw,11779
117
+ ai_edge_torch/generative/utilities/loader.py,sha256=bAWZ7FM4v_pPnX_AmEdGxHkDH65QdL-MjIP3PxscZmI,12649
116
118
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
117
- ai_edge_torch/generative/utilities/t5_loader.py,sha256=jz2qnDtH6oyxcqaBwEVfiiKmq_93LTDeUKNJ2cWpLwg,16856
119
+ ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
118
120
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
119
121
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
120
122
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
121
123
  ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=uiYRfzD1T8deCEAGfdAFusRbI41m14zeTt0Lz5lNT3M,9808
122
124
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
123
- ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=oYB0RPW-tHOwW9gQFH9GtHKO_Mmh1lkoiemXmTfySqc,4383
124
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=3vSX5E9ZFFhTPZZX6TMiAsGa_kaXABbN851bRbTFsC0,8297
125
- ai_edge_torch/lowertools/__init__.py,sha256=0M9TOR80sS5y6dikOsIFYx0P9IomqAdNIuYpgkP4PcI,693
125
+ ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=e53YNSO2w7Sd9Y717jAr6WKjnXq34Tx_52hXRGtGs3A,4833
126
+ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=7Qbba7GJCBc-J1TUwWIvrpBK0Hwza9nift7sKpW2YVE,8449
127
+ ai_edge_torch/lowertools/__init__.py,sha256=uKGibEN7n4Tqbe0HiXOEEXWmPL9AUmh34xaYA9yx2sg,719
126
128
  ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
127
129
  ai_edge_torch/lowertools/common_utils.py,sha256=emClsZ_MBlbLG_0BBtyLpkdz4dMWp6SyrNioygRBylk,2973
128
- ai_edge_torch/lowertools/odml_torch_utils.py,sha256=5XVI2ovptp1wsDZcyyaZDgT4oUa1McOiE-PrKhXNhFo,6316
130
+ ai_edge_torch/lowertools/odml_torch_utils.py,sha256=32cak8uiXFIVdkaYFhIW1fWG4NzLrYq-w8xK0pNkhYc,6547
131
+ ai_edge_torch/lowertools/test_utils.py,sha256=vsjaX3Ix2U1163jVUNSJgK9io2WNUtJjRvNFE9DrqF4,1932
129
132
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-g0NldtVOTCQtX3V2XEjuCQO_I52nSNQlu0r_rIS2IE,8635
130
133
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
131
134
  ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -134,8 +137,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
134
137
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
135
138
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
136
139
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
137
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
138
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/METADATA,sha256=TMkI635DYqK0Fg6W6tZbg8ZTT54_9QkkCcd3XOxjyho,1885
139
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
140
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
141
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/RECORD,,
140
+ ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
141
+ ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/METADATA,sha256=GZgUf21m2RQYBvHxmeTujeniBxbbUTVpQQB9vjNSTaM,1885
142
+ ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
143
+ ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
144
+ ai_edge_torch_nightly-0.3.0.dev20240817.dist-info/RECORD,,