ai-edge-torch-nightly 0.3.0.dev20241226__py3-none-any.whl → 0.3.0.dev20250107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. ai_edge_torch/_config.py +26 -9
  2. ai_edge_torch/_convert/conversion.py +22 -18
  3. ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +1 -0
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
  5. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
  7. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  8. ai_edge_torch/generative/layers/attention.py +29 -4
  9. ai_edge_torch/generative/layers/normalization.py +2 -50
  10. ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
  11. ai_edge_torch/generative/utilities/model_builder.py +14 -14
  12. ai_edge_torch/generative/utilities/verifier.py +4 -4
  13. ai_edge_torch/lowertools/torch_xla_utils.py +3 -0
  14. ai_edge_torch/odml_torch/export.py +1 -6
  15. ai_edge_torch/odml_torch/tf_integration.py +12 -50
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/METADATA +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/RECORD +21 -21
  19. {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/LICENSE +0 -0
  20. {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/WHEEL +0 -0
  21. {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py CHANGED
@@ -22,6 +22,18 @@ import os
 __all__ = ["config"]
 
 
+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, "false")
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""
 
@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)
 
   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True
 
+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if lowering group norm in StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
 
 config = _Config()
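
Note: as a usage sketch (not part of the diff), the new property pair round-trips through the environment variable, so plain env vars and in-process assignment stay in sync:

    import ai_edge_torch

    # The setter writes "y"/"n" into ENABLE_GROUP_NORM_COMPOSITE; the getter
    # re-parses it through _get_bool_env_var.
    ai_edge_torch.config.enable_group_norm_composite = True
    assert ai_edge_torch.config.enable_group_norm_composite is True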
ai_edge_torch/_convert/conversion.py CHANGED
@@ -14,9 +14,9 @@
 # ==============================================================================
 
 import logging
-import os
 from typing import Any, Literal, Optional, Union
 
+import ai_edge_torch
 from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
@@ -26,8 +26,6 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
-os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
-
 
 def _run_convert_passes(
     exported_program: torch.export.ExportedProgram,
@@ -35,21 +33,27 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  exported_program = fx_pass_base.run_passes(
-      exported_program,
-      [
-          fx_passes.BuildInterpolateCompositePass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.OptimizeLayoutTransposesPass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.BuildAtenCompositePass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.RemoveNonUserOutputsPass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.InjectMlirDebuginfoPass(),
-          fx_passes.CanonicalizePass(),
-      ],
-  )
+
+  passes = [
+      fx_passes.BuildInterpolateCompositePass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.OptimizeLayoutTransposesPass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.BuildAtenCompositePass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CanonicalizePass(),
+  ]
+
+  # Debuginfo is not injected automatically by odml_torch. Only inject
+  # debuginfo via fx pass when using torch_xla.
+  if ai_edge_torch.config.use_torch_xla:
+    passes += [
+        fx_passes.InjectMlirDebuginfoPass(),
+        fx_passes.CanonicalizePass(),
+    ]
+
+  exported_program = fx_pass_base.run_passes(exported_program, passes)
   return exported_program
 
 
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py CHANGED
@@ -62,6 +62,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
 
 
 class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
+  """DEPRECATED: Debuginfo is injected automatically by odml_torch."""
 
   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator
 
+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
 
+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
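
Note: a standalone sketch of the decision the new checker makes; NHWCable and the helper are re-declared locally here for illustration only, the package keeps them alongside the other checkers:

    import collections

    NHWCable = collections.namedtuple("NHWCable", ["can_be", "must_be"])

    def group_norm_nhwcable(shape, composite_enabled: bool) -> NHWCable:
      # Only 4-D activations can be laid out as NHWC; when the composite
      # flag is on, NHWC becomes mandatory so the rewriter can emit
      # odml.group_norm with channel_axis=3.
      can_be = len(shape) == 4
      return NHWCable(can_be=can_be, must_be=can_be and composite_enabled)

    print(group_norm_nhwcable((1, 32, 16, 16), composite_enabled=True))
    # -> NHWCable(can_be=True, must_be=True)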
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py CHANGED
@@ -16,6 +16,7 @@
 
 import operator
 
+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm
 
 
+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):
 
@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
       flattened_inner_size: int,
       num_groups: int,
       eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
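
Note: the rewriter above follows the general mark_inputs/mark_outputs pattern of StableHLOCompositeBuilder. A minimal sketch of that pattern on a simpler op, assuming the builder's single-tensor convenience return (the softmax example is illustrative, not from the diff):

    import torch
    from ai_edge_torch.hlfb import StableHLOCompositeBuilder

    def softmax_composite(x: torch.Tensor) -> torch.Tensor:
      # Everything between mark_inputs and mark_outputs is grouped into a
      # single stablehlo.composite op ("odml.softmax") during lowering.
      builder = StableHLOCompositeBuilder(name="odml.softmax", attr={"axis": -1})
      x = builder.mark_inputs(x)
      y = torch.nn.functional.softmax(x, dim=-1)
      return builder.mark_outputs(y)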
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -15,14 +15,13 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    if self.config.embedding_scale is not None:
-      input_embeds = input_embeds * self.config.embedding_scale
-    x = input_embeds
-    updated_kv_entries = []
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
  config.block_configs = config.block_configs[: config.num_layers]
  for block_config in config.block_configs:
    block_config.attn_config.num_heads = 4
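
Note: the Gemma2 change trades per-call build_rope for a precomputed (cos, sin) cache indexed by position. A shape-level sketch of the lookup, with sizes assumed for illustration:

    import torch

    kv_cache_max, n_elem = 1024, 128               # assumed sizes
    cos_cache = torch.randn(kv_cache_max, n_elem)  # stands in for build_rope_cache
    sin_cache = torch.randn(kv_cache_max, n_elem)

    input_pos = torch.tensor([5, 6, 7])            # positions of the current tokens
    cos = cos_cache.index_select(0, input_pos)     # (3, n_elem)
    sin = sin_cache.index_select(0, input_pos)
    assert cos.shape == (3, n_elem)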
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
+        updated_kv_entires.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -26,6 +26,33 @@ import torch
 from torch import nn
 
 
+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotarty positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -211,8 +238,7 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -348,8 +374,7 @@ class CrossAttention(nn.Module):
    if rope is not None:
      # Compute rotary positional embedding for query and key.
      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
    if kv_cache is not None:
      kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -80,6 +80,7 @@ class RMSNorm(torch.nn.Module):
     output = self._norm(x.float()).type_as(x)
     return output * w
 
+
 class GroupNorm(torch.nn.Module):
 
   def __init__(
@@ -115,16 +116,7 @@ class GroupNorm(torch.nn.Module):
     Returns:
       torch.Tensor: output tensor after applying GroupNorm.
     """
-    if self.enable_hlfb:
-      return group_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.group_num,
-          self.eps,
-      )
-    else:
-      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+    return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
 
 
 class LayerNorm(torch.nn.Module):
@@ -169,46 +161,6 @@ class LayerNorm(torch.nn.Module):
     )
 
 
-def group_norm_with_hlfb(
-    x: torch.Tensor,
-    w: torch.Tensor,
-    b: torch.Tensor,
-    num_groups: int,
-    eps: float,
-):
-  """Group Normalization with high-level function boundary enabled.
-
-  Args:
-    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
-    w (torch.Tensor): The weight tensor for the normalization.
-    b (torch.Tensor): The bias tensor for the normalization.
-    num_groups (int): Number of groups to separate the channels into.
-    eps (float): A small float value to ensure numerical stability.
-
-  Returns:
-    The output tensor of Group Normalization.
-  """
-  x = torch.permute(x, (0, 2, 3, 1))
-
-  builder = StableHLOCompositeBuilder(
-      name="odml.group_norm",
-      attr={
-          "num_groups": num_groups,
-          "epsilon": eps,
-          "reduction_axes": [3],
-          "channel_axis": 3,
-      },
-  )
-  x, w, b = builder.mark_inputs(x, w, b)
-  x = torch.permute(x, (0, 3, 1, 2))
-  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
-  y = torch.permute(y, (0, 2, 3, 1))
-  y = builder.mark_outputs(y)
-
-  y = torch.permute(y, (0, 3, 1, 2))
-  return y
-
-
 def rms_norm_with_hlfb(
     x: torch.Tensor,
     w: torch.Tensor,
ai_edge_torch/generative/layers/rotary_position_embedding.py CHANGED
@@ -32,64 +32,57 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1, x2 = torch.split(x, head_size // 2, dim=-1)
-  left = x1 * cos - x2 * sin
-  right = x2 * cos + x1 * sin
-  roped = torch.cat([left, right], dim=-1)
+  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+  roped = (x * cos) + (rotated * sin)
   return roped.transpose(1, 2).type_as(x)
 
 
-def build_rope(
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
     input_pos: torch.Tensor,
     n_elem: int,
-    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding cosine and sine tensors.
+  """Computes rotary positional embedding inline for a query and key.
 
   Args:
+    q: the query tensor.
+    k: the key tensor.
    input_pos: the sequence indices for the query and key
    n_elem: number of elements of the head dimension for RoPE computation
-    base: the base of the exponentiated value for RoPE.
 
  Returns:
-    cos, sin tensors
+    output the RoPE'd query and key.
  """
 
  if n_elem <= 0:
-    return None, None
+    return q, k
 
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  freq_exponents = (2.0 / n_elem) * torch.arange(
-      head_dim // 2, dtype=torch.float32
+      q.shape[-1] // 2, dtype=torch.float32
  )
  timescale = float(base) ** freq_exponents
  radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
      0
  ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
-  return cos, sin
-
+  cos = torch.cos(radians).type_as(q)
+  sin = torch.sin(radians).type_as(q)
 
-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
-
-  Args:
-    q: the query tensor.
-    k: the key tensor.
-    cos: the cosine tensor.
-    sin: the sine tensor.
-
-  Returns:
-    output the RoPE'd query and key.
-  """
+  def apply(x, sin, cos):
+    x = x.transpose(1, 2)
+    b, h, s, d = x.shape
+    ans = torch.split(x, d // 2, dim=-1)
+    x1, x2 = ans
+    left = x1 * cos - x2 * sin
+    right = x2 * cos + x1 * sin
+    res = torch.cat([left, right], dim=-1)
+    res = res.transpose(1, 2)
+    return res
 
-  q_roped = apply_rope(q, cos, sin)
-  k_roped = apply_rope(k, cos, sin)
+  q_roped = apply(q, sin, cos)
+  k_roped = apply(k, sin, cos)
   return q_roped, k_roped
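
Note: the apply_rope rewrite switches from the split-halves form to the rotate-half form. The two agree once cos/sin are repeated across both halves, which is why _embed_rope in attention.py passes cos.repeat(1, 2) and sin.repeat(1, 2). A small check, with shapes assumed for illustration:

    import torch

    B, nh, T, hs = 1, 2, 3, 8
    x = torch.randn(B, nh, T, hs)
    angle = torch.randn(T, hs // 2)
    cos, sin = torch.cos(angle), torch.sin(angle)
    cos2, sin2 = cos.repeat(1, 2), sin.repeat(1, 2)  # (T, hs)

    # Split-halves form (the removed code path).
    x1, x2 = x[..., : hs // 2], x[..., hs // 2 :]
    split = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

    # Rotate-half form (the new code path).
    rotated = torch.cat((-x2, x1), dim=-1)
    assert torch.allclose(x * cos2 + rotated * sin2, split, atol=1e-6)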
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -86,6 +85,13 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -107,22 +113,16 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-
+    cos, sin = self.rope_cache
+    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    return self._forward_with_embeds(
+    return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
-  def _forward_with_embeds(
+  def forward_with_embeds(
      self,
      input_embeds: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
@@ -141,13 +141,13 @@ class DecoderOnlyModel(nn.Module):
    if self.config.embedding_scale is not None:
      x = x * self.config.embedding_scale
 
-    updated_kv_entries = []
+    updated_kv_entires = []
    for i, block in enumerate(self.transformer_blocks):
      kv_entry = kv_cache.caches[i] if kv_cache else None
      x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
      if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
    if export_config is not None:
      if (
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""
 
 import logging
-from typing import Any,List
+from typing import Any, List, Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -134,7 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
      max_new_tokens: int,
      pixel_values: torch.Tensor = None,
-      eos_token_id: int = 1,
+      eos_token_id: Optional[int] = None,
  ) -> torch.IntTensor:
    input_ids = prompts[0].int().tolist()
    tokens = torch.tensor([input_ids])
@@ -146,7 +146,7 @@ class ReauthoredModelWrapper(ModelWrapper):
      )
      generated_token = logits[0][-1].argmax().item()
      input_ids.append(generated_token)
-      if generated_token == eos_token_id:
+      if eos_token_id is not None and generated_token == eos_token_id:
        break
      tokens = torch.tensor([[generated_token]])
      input_pos = torch.tensor([len(input_ids) - 1])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
  outputs_reauthored = reauthored_model.generate(
      prompt_tokens,
      max_new_tokens,
-      eos_token_id=tokenizer.tokenizer.eos_token_id,
+      eos_token_id=getattr(tokenizer.tokenizer, "eos_token_id", None),
  )
  response_reauthored = tokenizer.decode(outputs_reauthored[0])
  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
ai_edge_torch/lowertools/torch_xla_utils.py CHANGED
@@ -27,6 +27,9 @@ if "PJRT_DEVICE" not in os.environ:
   # https://github.com/google-ai-edge/ai-edge-torch/issues/326
   os.environ["PJRT_DEVICE"] = "CPU"
 
+os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
+
+
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
ai_edge_torch/odml_torch/export.py CHANGED
@@ -202,7 +202,7 @@ class MlirLowered:
       target_version = stablehlo.get_minimum_version()
     else:
       target_version = stablehlo.get_version_from_compatibility_requirement(
-          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+          stablehlo.StablehloCompatibilityRequirement.WEEK_12
       )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
@@ -222,11 +222,6 @@ class MlirLowered:
     # Lazy importing TF when execution is needed.
     return self.tf_function(*args)
 
-  def to_flatbuffer(self):
-    from . import tf_integration
-
-    return tf_integration.mlir_to_flatbuffer(self)
-
 
 # TODO(b/331481564) Make this a ai_edge_torch FX pass.
 def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
ai_edge_torch/odml_torch/tf_integration.py CHANGED
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""APIs to convert lowered MLIR from PyTorch to TensorFlow and TFLite artifacts."""
+"""APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""
 
 import re
-import tempfile
 
 import tensorflow as tf
 import torch
@@ -104,20 +103,26 @@ def _extract_call_args(
 def _wrap_as_tf_func(lowered, tf_state_dict):
   """Build tf.function from lowered and tf_state_dict."""
 
-  def inner(*args):
+  version = 6
+  if hasattr(tfxla, "call_module_maximum_supported_version"):
+    version = tfxla.call_module_maximum_supported_version()
+
+  def tf_func(*args):
     t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
     s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
     call_args = _extract_call_args(lowered, args, tf_state_dict)
     return tfxla.call_module(
         tuple(call_args),
-        version=5,
+        version=version,
         Tout=t_outs,  # dtype information
-        Sout=s_outs,  # Shape information
+        Sout=s_outs,  # shape information
         function_list=[],
-        module=lowered.module_bytecode,
+        module=lowered.module_bytecode_vhlo,
+        has_token_input_output=False,
+        platforms=["CPU"],
     )
 
-  return inner
+  return tf_func
 
 
 def _make_input_signatures(
@@ -149,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
       _wrap_as_tf_func(lowered, tf_state_dict),
       input_signature=_make_input_signatures(lowered),
   )
-
-
-def mlir_to_flatbuffer(lowered: export.MlirLowered):
-  """Convert the MLIR lowered to a TFLite flatbuffer binary."""
-  tf_state_dict = _build_tf_state_dict(lowered)
-  signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
-  tf_signatures = [_make_input_signatures(lowered)]
-  tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
-
-  tf_module = tf.Module()
-  tf_module.f = []
-
-  for tf_sig, func in zip(tf_signatures, tf_functions):
-    tf_module.f.append(
-        tf.function(
-            func,
-            input_signature=tf_sig,
-        )
-    )
-
-  tf_module._variables = list(tf_state_dict.values())
-
-  tf_concrete_funcs = [
-      func.get_concrete_function(*tf_sig)
-      for func, tf_sig in zip(tf_module.f, tf_signatures)
-  ]
-
-  # We need to temporarily save since TFLite's from_concrete_functions does not
-  # allow providing names for each of the concrete functions.
-  with tempfile.TemporaryDirectory() as temp_dir_path:
-    tf.saved_model.save(
-        tf_module,
-        temp_dir_path,
-        signatures={
-            sig_name: tf_concrete_funcs[idx]
-            for idx, sig_name in enumerate(signature_names)
-        },
-    )
-
-    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
-    tflite_model = converter.convert()
-
-  return tflite_model
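
Note: with mlir_to_flatbuffer removed, this module now stops at tf.function. A hypothetical sketch (not from the diff; `lowered` is assumed to be an export.MlirLowered obtained elsewhere) of how a caller could still reach a TFLite flatbuffer from the returned function:

    import tensorflow as tf

    tf_func = mlir_to_tf_function(lowered)
    # The tf.function carries an input_signature, so no example args are needed.
    concrete = tf_func.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete])
    tflite_model = converter.convert()  # flatbuffer bytes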
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241226"
+__version__ = "0.3.0.dev20250107"
{ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241226
+Version: 0.3.0.dev20250107
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/RECORD CHANGED
@@ -1,11 +1,11 @@
 ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
-ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
+ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=E_WDoV5Y1AG1Kq9M0_73bQYoSSnhDCJ7dxLCdKpkJJE,706
+ai_edge_torch/version.py,sha256=X0ZEB5T3xcR8MsIE8VOHDAdHnCZTzJLBQQ9j2xZ4_qA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
+ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
@@ -13,12 +13,12 @@ ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDiu
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
 ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
-ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
+ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=VTM2nO3TqK2d1DyEb2MiHc-Tyw2lMcUXyOhvg0H5ENY,10147
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -109,7 +109,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=C9dzJFK3TybxKpM1vSdLjOKftkJ72DGjr8YR4H7vCe8,4664
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,14 +117,14 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=_OmamS3f0m_JtW73ljwGLwFPeMLL837JCLY-dJ3iRUg,12453
+ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoTLzXXIy8dNtU8xC4,13199
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
-ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=zbFTNgQdOT-tcKK1QaIX6fG-50syYwQX_ZbLhg2C98c,2691
+ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -149,12 +149,12 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=S08WNqVKCmxd2QjtMlwETd7J97UnlME_bTKdz5LMkGU,6352
+ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7yy0dRxL74W7kVmZsxUjpOQ,6379
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=awO-sQrEpsFxIkZw72ysWZenYEmkLOLOuj62o2c7XeQ,11994
+ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -166,14 +166,14 @@ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5S
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=QzOPmcNPB7R-KhhPEP0oGVbDRgGPptIxRSoz3S8py9I,13405
+ai_edge_torch/odml_torch/export.py,sha256=sqIMXmxK_qIuVC-_DNJ6wKlIWiXq4_WOCKbSqMRFudg,13293
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
-ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
+ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/METADATA,sha256=khQQVRgopWndD2IbqOblhMzGAlOzSS6f0SbpP1oZ5xw,1966
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/METADATA,sha256=p2F-coQaq7CbpMOkQLVnpFB01cCKqftVRGZ4dCVu8Ck,1966
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/RECORD,,