ai-edge-torch-nightly 0.3.0.dev20241216__py3-none-any.whl → 0.3.0.dev20241220__py3-none-any.whl

ai_edge_torch/_convert/conversion.py CHANGED
@@ -78,7 +78,8 @@ def convert_signatures(
     *,
     strict_export: Union[Literal["auto"], bool] = True,
     quant_config: Optional[qcfg.QuantConfig] = None,
-    _tfl_converter_flags: Optional[dict[str, Any]],
+    _tfl_converter_flags: Optional[dict[str, Any]] = None,
+    _saved_model_dir: Optional[str] = None,
 ) -> model.TfLiteModel:
   """Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.

@@ -93,6 +94,8 @@ def convert_signatures(
     quant_config: User-defined quantization method and scheme of the model.
     _tfl_converter_flags: A nested dictionary allowing setting flags for the
       underlying tflite converter.
+    _saved_model_dir: Directory for the intermediate saved model. If not
+      specified, a random temporary directory will be used.

   Returns:
     The converted `model.TfLiteModel` object.
@@ -140,6 +143,7 @@ def convert_signatures(
       signatures,
       quant_config=quant_config,
       _tfl_converter_flags=_tfl_converter_flags,
+      _saved_model_dir=_saved_model_dir,
   )

   return model.TfLiteModel(tflite_model)

ai_edge_torch/_convert/converter.py CHANGED
@@ -106,6 +106,7 @@ class Converter:
       quant_config: Optional[qcfg.QuantConfig] = None,
       dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
       _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
+      _saved_model_dir: Optional[str] = None,
   ) -> model.TfLiteModel:
     """Finalizes the conversion and produces an edge model.

@@ -139,6 +140,8 @@ class Converter:
         of this function and so needs to be treated as such. Please do not rely
         on this parameter except for local debugging as this can be removed in a
         future release.
+      _saved_model_dir: Directory for the intermediate saved model. If not
+        specified, a random temporary directory will be used.

     Returns:
       The converted edge model.
@@ -171,6 +174,7 @@ class Converter:
         strict_export=strict_export,
         quant_config=quant_config,
         _tfl_converter_flags=_ai_edge_converter_flags,
+        _saved_model_dir=_saved_model_dir,
     )


@@ -216,6 +220,7 @@ def convert(
     quant_config: Optional[qcfg.QuantConfig] = None,
     dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
     _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
+    _saved_model_dir: Optional[str] = None,
 ) -> model.TfLiteModel:
   """Converts a PyTorch model to an edge model with a default signature.

@@ -240,6 +245,8 @@ def convert(
       this function and so needs to be treated as such. Please do not rely on
       this parameter except for local debugging as this can be removed in a
       future release.
+    _saved_model_dir: Directory for the intermediate saved model. If not
+      specified, a random temporary directory will be used.

   Returns:
     The converted edge model.
@@ -259,4 +266,5 @@ def convert(
       quant_config=quant_config,
      dynamic_shapes=dynamic_shapes,
       _ai_edge_converter_flags=_ai_edge_converter_flags,
+      _saved_model_dir=_saved_model_dir,
   )
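
The new `_saved_model_dir` hook threads through every layer of the conversion stack in this release. A minimal sketch of how a caller might use it for local debugging (the `torch.nn.Linear` model and the paths are made-up examples; the parameter is private and, per its docstring, intended only for local debugging):

import ai_edge_torch
import torch

model = torch.nn.Linear(4, 2).eval()
sample_args = (torch.randn(1, 4),)

# Keep the intermediate TF SavedModel around for inspection instead of
# letting it vanish with the random temporary directory.
edge_model = ai_edge_torch.convert(
    model,
    sample_args,
    _saved_model_dir="/tmp/debug_saved_model",
)
edge_model.export("/tmp/model.tflite")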

ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -145,24 +141,27 @@ class Gemma2(nn.Module):
         " must be the same."
     )

-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )

     # token embeddings of shape (b, t, n_embd)
     x = self.tok_embedding(tokens)
     x = x * (self.config.embedding_dim**0.5)

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       mask = self.get_attention_mask(
           block.config.attn_config.attn_type, input_pos
       )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
+        updated_kv_entries.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (

ai_edge_torch/generative/layers/attention.py CHANGED
@@ -26,33 +26,6 @@ import torch
 from torch import nn


-def _embed_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    n_elem: int,
-    rope: Tuple[torch.Tensor, torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Embed rotary positional embedding for query and key.
-
-  Args:
-    q (torch.Tensor): query tensor.
-    k (torch.Tensor): key tensor.
-    n_elem (int): number of elements to embed rotarty positional embedding.
-    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-  """
-  if n_elem > 0:
-    cos, sin = rope
-    q_roped = rotary_pos_emb.apply_rope(
-        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    k_roped = rotary_pos_emb.apply_rope(
-        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
-    )
-    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-  return q, k
-
-
 class TransformerBlock(nn.Module):

   def __init__(
@@ -238,7 +211,8 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -374,7 +348,8 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      q, k = _embed_rope(q, k, n_elem, rope)
+      cos, sin = rope
+      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)

ai_edge_torch/generative/layers/rotary_position_embedding.py CHANGED
@@ -32,57 +32,64 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
-  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
-  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
-  roped = (x * cos) + (rotated * sin)
+  x1, x2 = torch.split(x, head_size // 2, dim=-1)
+  left = x1 * cos - x2 * sin
+  right = x2 * cos + x1 * sin
+  roped = torch.cat([left, right], dim=-1)
   return roped.transpose(1, 2).type_as(x)


-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
+def build_rope(
     input_pos: torch.Tensor,
     n_elem: int,
+    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
+  """Computes rotary positional embedding cosine and sine tensors.

   Args:
-    q: the query tensor.
-    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
+    base: the base of the exponentiated value for RoPE.

   Returns:
-    output the RoPE'd query and key.
+    cos, sin tensors
   """

   if n_elem <= 0:
-    return q, k
+    return None, None

   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      q.shape[-1] // 2, dtype=torch.float32
+      head_dim // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians).type_as(q)
-  sin = torch.sin(radians).type_as(q)
+  cos = torch.cos(radians)
+  sin = torch.sin(radians)
+  return cos, sin
+

-  def apply(x, sin, cos):
-    x = x.transpose(1, 2)
-    b, h, s, d = x.shape
-    ans = torch.split(x, d // 2, dim=-1)
-    x1, x2 = ans
-    left = x1 * cos - x2 * sin
-    right = x2 * cos + x1 * sin
-    res = torch.cat([left, right], dim=-1)
-    res = res.transpose(1, 2)
-    return res
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    cos: the cosine tensor.
+    sin: the sine tensor.
+
+  Returns:
+    output the RoPE'd query and key.
+  """

-  q_roped = apply(q, sin, cos)
-  k_roped = apply(k, sin, cos)
+  q_roped = apply_rope(q, cos, sin)
+  k_roped = apply_rope(k, cos, sin)
   return q_roped, k_roped
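
Taken together, `build_rope` and `apply_rope_inline` replace the old precomputed `rope_cache`: the cos/sin tensors are now derived from the live `input_pos` on every forward call, which is exactly what the Gemma2 hunk above (and the DecoderOnlyModel hunk below) switch to. A minimal sketch of the new flow; the shapes and `rotary_percentage = 1.0` are illustrative assumptions:

import torch
from ai_edge_torch.generative.layers import rotary_position_embedding as rotary_pos_emb

head_dim = 64
n_elem = int(1.0 * head_dim)   # rotary_percentage * head_dim
input_pos = torch.arange(8)    # positions of the tokens in this call

# Build cos/sin once per forward pass from the live positions...
cos, sin = rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, base=10_000)

# ...then apply them to the query/key inside each attention block.
q = torch.randn(1, 8, 4, head_dim)  # (batch, seq_len, n_heads, head_dim)
k = torch.randn(1, 8, 4, head_dim)
q_roped, k_roped = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)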

ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -24,6 +24,7 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -85,13 +86,6 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -113,11 +107,16 @@ class DecoderOnlyModel(nn.Module):

     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]

+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
@@ -141,13 +140,13 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale

-    updated_kv_entires = []
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (

ai_edge_torch/lowertools/_shim.py CHANGED
@@ -50,6 +50,7 @@ def exported_programs_to_tflite(
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     _tfl_converter_flags: Optional[dict[str, Any]] = None,
+    _saved_model_dir: Optional[str] = None
 ):
   """Converts a list of ExportedProgram to a TFLite model.

@@ -57,6 +58,8 @@ def exported_programs_to_tflite(
     exported_programs: A list of ExportedProgram.
     signatures: A list of Signature.
     quant_config: A QuantConfig.
+    _saved_model_dir: Directory for the intermediate saved model. If not
+      specified, a random temporary directory will be used.
     _tfl_converter_flags: A dict of flags for TFLiteConverter.

   Returns:
@@ -79,4 +82,5 @@ def exported_programs_to_tflite(
       signatures,
       quant_config=quant_config,
       _tfl_converter_flags=_tfl_converter_flags,
+      _saved_model_dir=_saved_model_dir,
   )

ai_edge_torch/lowertools/odml_torch_utils.py CHANGED
@@ -138,6 +138,7 @@ def merged_bundle_to_tfl_model(
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     _tfl_converter_flags: dict = {},
+    _saved_model_dir: Optional[str] = None,
 ):
   tf_state_dict = merged_bundle.bundles[0].state_dict

@@ -173,6 +174,9 @@ def merged_bundle_to_tfl_model(
   # We need to temporarily save since TFLite's from_concrete_functions does not
   # allow providing names for each of the concrete functions.
   with tempfile.TemporaryDirectory() as temp_dir_path:
+    if _saved_model_dir is not None:
+      temp_dir_path = _saved_model_dir
+
     tf.saved_model.save(
         tf_module,
         temp_dir_path,
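
Both lowering backends use the same override pattern: the `tempfile.TemporaryDirectory` context is still entered (and cleaned up on exit) even when `_saved_model_dir` is supplied; the override only redirects where the SavedModel is written, so a caller-provided directory survives after conversion. A condensed sketch of the pattern, with a hypothetical `write_saved_model` callable standing in for the `tf.saved_model.save` call:

import tempfile

def save_intermediate(write_saved_model, _saved_model_dir=None):
    with tempfile.TemporaryDirectory() as temp_dir_path:
        if _saved_model_dir is not None:
            # Redirect the write; the unused temp dir is still removed on exit.
            temp_dir_path = _saved_model_dir
        write_saved_model(temp_dir_path)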

ai_edge_torch/lowertools/torch_xla_utils.py CHANGED
@@ -192,6 +192,7 @@ def merged_bundle_to_tfl_model(
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     _tfl_converter_flags: dict = {},
+    _saved_model_dir: Optional[str] = None,
 ) -> None:
   """Converts a StableHLOGraphModule to a tflite model.

@@ -200,6 +201,8 @@ def merged_bundle_to_tfl_model(
     signatures: List of signatures from which names of the signatures is
       extracted.
     quant_config: User-defined quantization method and scheme of the model.
+    _saved_model_dir: Directory for the intermediate saved model. If not
+      specified, a random temporary directory will be used.
     _tfl_converter_flags: A nested dictionary allowing setting flags for the
       underlying tflite converter.
   """
@@ -246,6 +249,9 @@ def merged_bundle_to_tfl_model(
   # We need to temporarily save since TFLite's from_concrete_functions does not
   # allow providing names for each of the concrete functions.
   with tempfile.TemporaryDirectory() as temp_dir_path:
+    if _saved_model_dir is not None:
+      temp_dir_path = _saved_model_dir
+
     tf.saved_model.save(
         tf_module,
         temp_dir_path,

ai_edge_torch/odml_torch/export.py CHANGED
@@ -304,9 +304,13 @@ def exported_program_to_mlir(
   )

   _convert_i64_to_i32(exported_program)
+
   exported_program = _torch_future.safe_run_decompositions(
       exported_program, lowerings.decompositions()
   )
+
+  # Passes below mutate the exported program to a state not executable by torch.
+  # Do not call run_decompositions after applying the passes.
   _convert_q_dq_per_channel_args_to_list(exported_program)

   with export_utils.create_ir_context() as context, ir.Location.unknown():

ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py CHANGED
@@ -52,10 +52,13 @@ def _uniform_quantized_type(
   assert isinstance(scale, (list, tuple))
   assert isinstance(zero_point, (list, tuple))

+  scale = list(scale)
+  zero_point = list(zero_point)
+
   if len(scale) == 1:
-    scale *= channel_axis_size
+    scale = scale * channel_axis_size
   if len(zero_point) == 1:
-    zero_point *= channel_axis_size
+    zero_point = zero_point * channel_axis_size

   assert len(scale) == len(zero_point) == channel_axis_size
   scale_zp_strs = []
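
The copy matters because `scale *= channel_axis_size` on a list mutates it in place, so a caller's per-tensor quantization parameters would silently grow on each invocation. A toy reproduction of the aliasing bug this hunk fixes (the values are made up):

caller_scale = [0.5]          # caller's per-tensor scale

scale = caller_scale
scale *= 3                    # old code: in-place repeat mutates the caller's list
assert caller_scale == [0.5, 0.5, 0.5]

caller_scale = [0.5]
scale = list(caller_scale)    # new code: defensive copy first
scale = scale * 3             # rebinding leaves the original untouched
assert caller_scale == [0.5]
assert scale == [0.5, 0.5, 0.5]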
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.3.0.dev20241216"
+__version__ = "0.3.0.dev20241220"

{ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241216
+Version: 0.3.0.dev20241220
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/RECORD RENAMED
@@ -3,11 +3,11 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=Ghtv10LCkTCFTvDn5IHArtqXgxeWVsVxoZa6YIVIXtA,706
+ai_edge_torch/version.py,sha256=xD-MWAEa1ROHhyF3rY7MaL28xsuON0aJwaiXbJ04qfc,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
+ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
-ai_edge_torch/_convert/converter.py,sha256=DYbTZMZos8bvm9mLyDv3W1P8ER_iGKVohbFAmLZD4r8,9534
+ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=N0jKVZA3qWKOaHVbIM3WmQh3u0Sq7MTw_oO3Zo16wCw,3456
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=roEwWVXASbk5BFj7jojjEJpHui6gCelT51l-TtN_ZaQ,9367
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -107,7 +107,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=C9dzJFK3TybxKpM1vSdLjOKftkJ72DGjr8YR4H7vCe8,4664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -115,14 +115,14 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoTLzXXIy8dNtU8xC4,13199
+ai_edge_torch/generative/layers/attention.py,sha256=_OmamS3f0m_JtW73ljwGLwFPeMLL837JCLY-dJ3iRUg,12453
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
 ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=zbFTNgQdOT-tcKK1QaIX6fG-50syYwQX_ZbLhg2C98c,2691
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -147,7 +147,7 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7yy0dRxL74W7kVmZsxUjpOQ,6379
+ai_edge_torch/generative/utilities/model_builder.py,sha256=q82-1E2zYlzpbFW6Vw-MWrJivRXHKpRh8jUxpR-w0sY,6349
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -160,16 +160,16 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNgh
 ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
-ai_edge_torch/lowertools/_shim.py,sha256=xJIHDSWNoF4PkkT0JkjeJxgguQ9JGEwooJf9xZNkVRU,3058
+ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=dgnNGBVkHBz0brlWALX2hGXpQ4YzCKdwbkF4oAfEu4I,13062
+ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -187,7 +187,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
-ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=GEs83mtEjh8GOW_OATI_ur11VKujrOL2xdZeZ0l1HtM,6100
+ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
@@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/METADATA,sha256=ofd2acZU87DZojjV3S5U4Uw9Pzm4NRB_2WQL-omfbFY,1966
-ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241216.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/METADATA,sha256=PfyYhqbf7VEibw2TEDRb8tBOIPG9dfXhT9tNNou_iZg,1966
+ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/RECORD,,