ai-edge-torch-nightly 0.4.0.dev20250225__py3-none-any.whl → 0.4.0.dev20250227__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/_convert/conversion.py CHANGED
@@ -125,6 +125,10 @@ def convert_signatures(
   else:
     exported_program = torch.export.export(**kwargs, strict=True)
 
+  exported_program = fx_infra.graph_utils.reset_from_node_meta(
+      exported_program
+  )
+
   exported_program = fx_infra.safe_run_decompositions(
       exported_program,
       fx_infra.decomp.pre_convert_decomp(),
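Note: the reset runs after torch.export but before the decomposition passes, so each node's "from_node" history is clean before decompositions rewrite the graph. A minimal standalone sketch of the same flow (TinyModel is illustrative; the fx_infra entry points are taken from the hunk above):

import torch
from ai_edge_torch import fx_infra

class TinyModel(torch.nn.Module):
  def forward(self, x):
    return torch.nn.functional.relu(x) + 1

# Export, reset the per-node "from_node" history, then run pre-convert decompositions.
ep = torch.export.export(TinyModel(), (torch.randn(2, 4),), strict=True)
ep = fx_infra.graph_utils.reset_from_node_meta(ep)
ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_convert_decomp())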
ai_edge_torch/fx_infra/graph_utils.py CHANGED
@@ -13,7 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 """FX graph utilities."""
+
+from packaging import version
 import torch
+from torch.fx import traceback
 
 
 def remove_dangling_args(graph_module: torch.fx.GraphModule):
@@ -40,3 +43,32 @@ def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
   graph_module.graph.lint()
   graph_module.recompile()
   return graph_module
+
+
+def is_torch_version_under(torch_version: str) -> bool:
+  """Checks if the current torch version is under the given version."""
+  if not torch_version:
+    raise ValueError("torch_version cannot be empty.")
+  current_version = version.parse(torch.__version__)
+  compared_version = version.parse(torch_version)
+  return current_version < compared_version
+
+
+def reset_from_node_meta(ep: torch.export.ExportedProgram):
+  """Resets the "from_node" meta field to only the fx node name for the exported program."""
+
+  for node in ep.graph.nodes:
+    if not hasattr(node, "meta") or "from_node" not in node.meta:
+      continue
+    if is_torch_version_under("2.6.0.dev0"):
+      # For torch versions under 2.6.0, the history stack is a list of tuples.
+      # We only keep the current node's name in the history stack.
+      history = [(node.name,)]
+    else:
+      # Clean up the history stack by keeping only the current node info (fx
+      # node name and graph id) in a list of size 1. Clear the "from_node"
+      # field to prevent redundant additions to the history stack.
+      history = [traceback.NodeSource(node)]
+      history[0].from_node = []
+    node.meta["from_node"] = history
+  return ep
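A quick way to see what the reset leaves behind (sketch; the element type of the history list depends on the installed torch version, per the branches above):

import torch
from ai_edge_torch.fx_infra import graph_utils

ep = torch.export.export(torch.nn.Linear(4, 4), (torch.randn(1, 4),))
ep = graph_utils.reset_from_node_meta(ep)
for node in ep.graph.nodes:
  # torch < 2.6: [(node.name,)]; torch >= 2.6: [NodeSource(node)] with from_node == []
  print(node.name, node.meta.get("from_node"))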
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -22,7 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+from ai_edge_torch.generative.utilities import model_builder
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -59,6 +59,11 @@ _LORA_RANKS = flags.DEFINE_multi_integer(
     None,
     'If set, the model will be converted with the provided list of LoRA ranks.',
 )
+_DECODE_BATCH_SIZE = flags.DEFINE_integer(
+    'decode_batch_size',
+    1,
+    'The batch size for the decode signature.',
+)
 
 
 def main(_):
@@ -72,7 +77,9 @@ def main(_):
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
       lora_ranks=_LORA_RANKS.value,
-      export_config=ExportConfig(),
+      export_config=model_builder.ExportConfig(
+          decode_batch_size=_DECODE_BATCH_SIZE.value
+      ),
   )
 
 
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py CHANGED
@@ -22,17 +22,22 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+from ai_edge_torch.generative.utilities import model_builder
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,6 +54,16 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+_DECODE_BATCH_SIZE = flags.DEFINE_integer(
+    'decode_batch_size',
+    1,
+    'The batch size for the decode signature.',
+)
 
 
 def main(_):
@@ -56,14 +71,16 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
 
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
-      export_config=ExportConfig(),
+      lora_ranks=_LORA_RANKS.value,
+      export_config=model_builder.ExportConfig(
+          decode_batch_size=_DECODE_BATCH_SIZE.value
+      ),
   )
 
 
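With both scripts on the same flag set, a hypothetical invocation of the v2 script could look like this (paths and values illustrative; every flag used here is defined above):

import subprocess

subprocess.run([
    'python', 'convert_v2_to_tflite.py',
    '--checkpoint_path', '/tmp/llm_data/smollm2',
    '--output_path', '/tmp/',
    '--output_name_prefix', 'smollm2',
    '--prefill_seq_lens', '1024',
    '--decode_batch_size', '2',
], check=True)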
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -48,7 +48,6 @@ class TransformerBlock(nn.Module):
         config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        model_config.batch_size,
         model_config.embedding_dim,
         config.attn_config,
         model_config.enable_hlfb,
@@ -115,7 +114,6 @@ class CausalSelfAttention(nn.Module):
 
   def __init__(
       self,
-      batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
      enable_hlfb: bool,
@@ -123,14 +121,12 @@ class CausalSelfAttention(nn.Module):
     """Initialize an instance of CausalSelfAttention.
 
     Args:
-      batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
     self.kv_cache = None
-    self.batch_size = batch_size
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
     ) * config.head_dim
@@ -179,11 +175,6 @@ class CausalSelfAttention(nn.Module):
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert B == self.batch_size, (
-        "batch size of input tensor must match with the batch size specified in"
-        " the model configuration."
-    )
-
     qkv = self.qkv_projection(x)
 
     # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -290,7 +281,6 @@ class CrossAttention(nn.Module):
 
   def __init__(
       self,
-      batch_size: int,
       query_dim: int,
       cross_dim: int,
       hidden_dim: int,
@@ -301,7 +291,6 @@ class CrossAttention(nn.Module):
     """Initialize an instance of CrossAttention.
 
     Args:
-      batch_size (int): batch size of the input tensor.
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
       hidden_dim (int): hidden dimension that q, k, v tensors project to.
@@ -376,7 +365,6 @@ class CrossAttention(nn.Module):
 
     if rope is not None:
       # Compute rotary positional embedding for query and key.
-      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
ai_edge_torch/generative/layers/kv_cache.py CHANGED
@@ -18,14 +18,11 @@
 import dataclasses
 from typing import List, Tuple
 
-from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
 import torch
 import torch.utils._pytree as pytree
 
-BATCH_SIZE = 1
-
 
 @dataclasses.dataclass
 class KVCacheEntry:
@@ -45,9 +42,10 @@ class KVCacheEntry:
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntry":
     """Build an instance of the class based on model config."""
-    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
     k = torch.zeros(shape, dtype=dtype, device=device)
     v = torch.zeros(shape, dtype=dtype, device=device)
     obj = cls(k_cache=k, v_cache=v)
@@ -66,6 +64,7 @@ class KVCache:
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCache":
     """Build an instance of the class based on model config.
 
@@ -75,17 +74,21 @@
           Defaults to torch.float32.
       device (torch.device, optional): The device placement of the cache
         tensors. Defaults to None.
+      batch_size (int, optional): The batch size of the cache tensors.
+        Defaults to 1.
 
     Returns:
       KVCache: The created cache object.
     """
     caches = [
         KVCacheEntry.from_model_config(
-            config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
+            config.kv_cache_max
+            if not config.block_config(idx).kv_cache_max_len
             else config.block_config(idx).kv_cache_max_len,
             config.block_config(idx).attn_config,
             dtype,
             device,
+            batch_size,
         )
         for idx in range(config.num_layers)
     ]
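The cache layout is unchanged apart from the new leading batch dimension. A standalone sketch of one entry's buffers (dimension values illustrative; the real ones come from the AttentionConfig):

import torch

batch_size, kv_cache_max, num_query_groups, head_dim = 2, 1280, 3, 64
shape = (batch_size, kv_cache_max, num_query_groups, head_dim)
k_cache = torch.zeros(shape, dtype=torch.float32)
v_cache = torch.zeros(shape, dtype=torch.float32)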
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -220,9 +220,6 @@ class ModelConfig:
   # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
   kv_cache_max_len: int = 0
 
-  # Default batch size of the exported model. Default value is 1.
-  batch_size: int = 1
-
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None
 
ai_edge_torch/generative/utilities/converter.py CHANGED
@@ -110,6 +110,11 @@ def convert_to_tflite(
   lora_suffix = (
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
   )
+
+  if export_config is not None:
+    if export_config.decode_batch_size > 1:
+      output_name_prefix += f'_dbs{export_config.decode_batch_size}'
+
   output_filename = (
       f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
   )
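The resulting naming scheme, reproduced as a runnable sketch (values illustrative; quant_suffix and kv_size are computed earlier in convert_to_tflite):

output_name_prefix, quant_suffix, kv_size = 'smollm2', 'q8', 1280
lora_ranks, decode_batch_size = [16], 2

lora_suffix = '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
if decode_batch_size > 1:
  output_name_prefix += f'_dbs{decode_batch_size}'
print(f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite')
# -> smollm2_dbs2_q8_ekv1280_lora16.tflite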
@@ -162,9 +167,14 @@ def _export_helper(
   if prefill_masks:
     assert len(prefill_masks) == len(prefill_seq_lens)
 
-  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_token = torch.tensor(
+      [[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
+  )
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = export_config.kvcache_cls.from_model_config(config)
+  prefill_kv = export_config.kvcache_cls.from_model_config(config)
+  decode_kv = export_config.kvcache_cls.from_model_config(
+      config, batch_size=export_config.decode_batch_size
+  )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
 
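Decode inputs now carry the batch dimension explicitly while prefill stays at batch 1. A quick shape check (decode_batch_size illustrative):

import torch

decode_batch_size = 2
decode_token = torch.tensor([[0] for _ in range(decode_batch_size)], dtype=torch.int)
print(decode_token.shape)  # torch.Size([2, 1])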
@@ -183,7 +193,7 @@
     sample_kwargs = {
         'tokens': prefill_tokens,
         'input_pos': prefill_input_pos,
-        'kv_cache': kv,
+        'kv_cache': prefill_kv,
     }
     if prefill_masks is not None:
       sample_kwargs['mask'] = prefill_masks[i]
@@ -211,7 +221,7 @@
     sample_kwargs = {
         'tokens': decode_token,
         'input_pos': decode_input_pos,
-        'kv_cache': kv,
+        'kv_cache': decode_kv,
     }
     if export_config.decode_mask is not None:
       sample_kwargs['mask'] = export_config.decode_mask
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -60,6 +60,8 @@ class ExportConfig:
   decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   # The KV Cache class for K and V buffers in attention.
   kvcache_cls: type = kv_utils.KVCache
+  # The batch size of the decode signature.
+  decode_batch_size: int = 1
 
 
 class DecoderOnlyModel(nn.Module):
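Usage sketch for the new field (assumes the package is installed; the default keeps the old single-batch behavior):

from ai_edge_torch.generative.utilities import model_builder

cfg = model_builder.ExportConfig()  # decode_batch_size == 1; output filename unchanged
batched = model_builder.ExportConfig(decode_batch_size=4)  # converter appends a '_dbs4' marker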
ai_edge_torch/odml_torch/debuginfo/_build.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 import re
+from ai_edge_torch.fx_infra import graph_utils
 import torch
 
 
@@ -48,8 +49,6 @@ def _get_canonical_filename(filename: str):
     The canonicalized filename.
   """
 
-  # TODO(yijieyang): We should add a config option to provide a regex to strip
-  # from the debug info. Currently absolute path is used.
   return filename
 
 
@@ -57,9 +56,13 @@ def _get_canoical_nodename(node: torch.fx.Node) -> str:
   """Get the canonical node name from the node's history."""
 
   history = node.meta.get("from_node", [])
+  if not history:
+    return None
 
-  if len(history) > 1:  # Compatible with torch version under 2.6.0
-    return history[1][0]
+  # Compatible with torch versions under 2.6.0: the history stack is a list of
+  # tuples, and the first element of the first tuple is the node name.
+  if graph_utils.is_torch_version_under("2.6.0.dev0"):
+    return history[0][0]
 
   if not hasattr(history[0], "name"):
     return None
@@ -68,12 +71,10 @@ def _get_canoical_nodename(node: torch.fx.Node) -> str:
     names.append(history[0].name)
     history = history[0].from_node
 
-  # Based on the experiment, the third to last name in the history stack
-  # can be mapped to the original torch node name. The history stack is
-  # generated by tracing the node's transformation history during lowering.
-  if len(names) > 2:
-    return names[-3]
-  return None
+  # The history stack is generated by tracing the node's transformation
+  # history during lowering. The last name in the history stack is used to
+  # map to the original torch fx node name.
+  return names[-1]
 
 
 def build_mlir_file_debuginfo(node: torch.fx.Node):
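The history walk can be illustrated with stand-in objects (FakeSource is hypothetical; on torch >= 2.6 the real entries are torch.fx.traceback.NodeSource instances):

class FakeSource:
  def __init__(self, name, from_node=()):
    self.name = name
    self.from_node = list(from_node)

# add_1 <- aten_add <- add: the original fx node name sits at the end of the chain.
history = [FakeSource('add_1', [FakeSource('aten_add', [FakeSource('add')])])]
names = []
while history:
  names.append(history[0].name)
  history = history[0].from_node
print(names[-1])  # 'add'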
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.4.0.dev20250225"
+__version__ = "0.4.0.dev20250227"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.4.0.dev20250225
+Version: 0.4.0.dev20250227
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,9 +2,9 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=IlCclQHtv864xDiY226T7A6oRSPGetrRq1_B_aIyGAg,706
+ai_edge_torch/version.py,sha256=K2jtDrBNGi74j_uQYVUT6MJ2-aQFKkKy5ZYur9iWdVU,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
+ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
@@ -41,7 +41,7 @@ ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpx
 ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
 ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=ZbWheeZ8ydsxCk2aVGUgUynrkEkBOMjBCzPhS5uq4sU,2595
 ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
-ai_edge_torch/fx_infra/graph_utils.py,sha256=3UZAOHWOUh2LCj1E2_AKQn3gRDILi9JCdqSScjyOd4M,1535
+ai_edge_torch/fx_infra/graph_utils.py,sha256=nqGe-xIJ77RamSUh0UYyI2XHOsZqFDWax-vpRAtVR_E,2796
 ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -102,8 +102,8 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
 ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
-ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=hWko-RJB8eXNUfi4EzQ2yjW30YE4UB4zAz7rd2Q5qpg,2708
+ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=sJ-o385eqQsciv0TEQRkixvS0DD6dKruAuK0zlEsDoY,2715
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -136,13 +136,13 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=Pm8FLKh-NnOvUjqQC9oX5oghPbdivZvlPVkgOVTShoU,13703
+ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDqYuGOZdYAI9EooOM,13247
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
-ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
+ai_edge_torch/generative/layers/kv_cache.py,sha256=jwbt0-2fd_CNWS2fp4nf0zvh6kk5citINGlFC_RtEUU,6540
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=EA1Ey5-c1IOLRNANSUnZ7gtNTA0o6OJxrz_I_mp8cjw,8244
+ai_edge_torch/generative/layers/model_config.py,sha256=wNsZDzZQoimOKdZ9FWMCktPj-pQ_0D7084hgzMT5XYo,8155
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
@@ -173,10 +173,10 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728Fc
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
-ai_edge_torch/generative/utilities/converter.py,sha256=_PO9lYCdNNYPVsAqh8QQVMG_8TUBshKwmaR1cdT6Ang,8065
+ai_edge_torch/generative/utilities/converter.py,sha256=VtG42CVz657XbvTj-FZJiCFW0Hm11OVKKC_mr2tjxhc,8413
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
+ai_edge_torch/generative/utilities/model_builder.py,sha256=eY3qAcBhupIn955YnWuzUi9hoWYvl4ntRWA6PBudzMo,6888
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -205,7 +205,7 @@ ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xK
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
 ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
-ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=k9Kas790kpMS5OrVcLzIr48ejAzcc2smrroKAHHM7TQ,3311
+ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=6Ns2rlfOilLJEk5cUxlkRwm2uxOgEF2-0S2DMcOqr6A,3319
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.4.0.dev20250225.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.4.0.dev20250225.dist-info/METADATA,sha256=R_qRD9bYEGBzEuxpKkHA0pjZyNWiQ5CG8hT43wNYDeM,1966
-ai_edge_torch_nightly-0.4.0.dev20250225.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.4.0.dev20250225.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.4.0.dev20250225.dist-info/RECORD,,
+ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/METADATA,sha256=cHcz3adq1WwVddazAJ06h7SKITJm70eMpFVjoNa2Jw4,1966
+ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/RECORD,,