ai-edge-torch-nightly 0.4.0.dev20250226__py3-none-any.whl → 0.4.0.dev20250228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,7 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.smollm import smollm
24
24
  from ai_edge_torch.generative.utilities import converter
25
- from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
+ from ai_edge_torch.generative.utilities import model_builder
26
26
 
27
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
28
28
  'checkpoint_path',
@@ -59,6 +59,11 @@ _LORA_RANKS = flags.DEFINE_multi_integer(
59
59
  None,
60
60
  'If set, the model will be converted with the provided list of LoRA ranks.',
61
61
  )
62
+ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
63
+ 'decode_batch_size',
64
+ 1,
65
+ 'The batch size for the decode signature.',
66
+ )
62
67
 
63
68
 
64
69
  def main(_):
@@ -72,7 +77,9 @@ def main(_):
72
77
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
73
78
  quantize=_QUANTIZE.value,
74
79
  lora_ranks=_LORA_RANKS.value,
75
- export_config=ExportConfig(),
80
+ export_config=model_builder.ExportConfig(
81
+ decode_batch_size=_DECODE_BATCH_SIZE.value
82
+ ),
76
83
  )
77
84
 
78
85
 
@@ -22,17 +22,22 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.smollm import smollm
24
24
  from ai_edge_torch.generative.utilities import converter
25
- from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
+ from ai_edge_torch.generative.utilities import model_builder
26
26
 
27
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
28
28
  'checkpoint_path',
29
29
  os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
30
30
  'The path to the model checkpoint, or directory holding the checkpoint.',
31
31
  )
32
- _TFLITE_PATH = flags.DEFINE_string(
33
- 'tflite_path',
32
+ _OUTPUT_PATH = flags.DEFINE_string(
33
+ 'output_path',
34
34
  '/tmp/',
35
- 'The tflite file path to export.',
35
+ 'The path to export the tflite model.',
36
+ )
37
+ _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
38
+ 'output_name_prefix',
39
+ 'smollm2',
40
+ 'The prefix of the output tflite model name.',
36
41
  )
37
42
  _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
38
43
  'prefill_seq_lens',
@@ -49,6 +54,16 @@ _QUANTIZE = flags.DEFINE_bool(
49
54
  True,
50
55
  'Whether the model should be quantized.',
51
56
  )
57
+ _LORA_RANKS = flags.DEFINE_multi_integer(
58
+ 'lora_ranks',
59
+ None,
60
+ 'If set, the model will be converted with the provided list of LoRA ranks.',
61
+ )
62
+ _DECODE_BATCH_SIZE = flags.DEFINE_integer(
63
+ 'decode_batch_size',
64
+ 1,
65
+ 'The batch size for the decode signature.',
66
+ )
52
67
 
53
68
 
54
69
  def main(_):
@@ -56,14 +71,16 @@ def main(_):
56
71
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
57
72
  )
58
73
 
59
- quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
60
- output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
61
74
  converter.convert_to_tflite(
62
75
  pytorch_model,
63
- tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
76
+ output_path=_OUTPUT_PATH.value,
77
+ output_name_prefix=_OUTPUT_NAME_PREFIX.value,
64
78
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
65
79
  quantize=_QUANTIZE.value,
66
- export_config=ExportConfig(),
80
+ lora_ranks=_LORA_RANKS.value,
81
+ export_config=model_builder.ExportConfig(
82
+ decode_batch_size=_DECODE_BATCH_SIZE.value
83
+ ),
67
84
  )
68
85
 
69
86
 
@@ -48,7 +48,6 @@ class TransformerBlock(nn.Module):
48
48
  config.pre_attention_norm_config,
49
49
  )
50
50
  self.atten_func = CausalSelfAttention(
51
- model_config.batch_size,
52
51
  model_config.embedding_dim,
53
52
  config.attn_config,
54
53
  model_config.enable_hlfb,
@@ -115,7 +114,6 @@ class CausalSelfAttention(nn.Module):
115
114
 
116
115
  def __init__(
117
116
  self,
118
- batch_size: int,
119
117
  dim: int,
120
118
  config: cfg.AttentionConfig,
121
119
  enable_hlfb: bool,
@@ -123,14 +121,12 @@ class CausalSelfAttention(nn.Module):
123
121
  """Initialize an instance of CausalSelfAttention.
124
122
 
125
123
  Args:
126
- batch_size (int): batch size of the input tensor.
127
124
  dim (int): causal attention's input/output dimmension.
128
125
  config (cfg.AttentionConfig): attention specific configurations.
129
126
  enable_hlfb (bool): whether hlfb is enabled or not.
130
127
  """
131
128
  super().__init__()
132
129
  self.kv_cache = None
133
- self.batch_size = batch_size
134
130
  qkv_shape = (
135
131
  config.num_heads + 2 * config.num_query_groups
136
132
  ) * config.head_dim
@@ -179,11 +175,6 @@ class CausalSelfAttention(nn.Module):
179
175
  """
180
176
  # Batch size, sequence length, embedding dimensionality.
181
177
  B, T, E = x.size()
182
- assert B == self.batch_size, (
183
- "batch size of input tensor must match with the batch size specified in"
184
- " the model configuration."
185
- )
186
-
187
178
  qkv = self.qkv_projection(x)
188
179
 
189
180
  # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -290,7 +281,6 @@ class CrossAttention(nn.Module):
290
281
 
291
282
  def __init__(
292
283
  self,
293
- batch_size: int,
294
284
  query_dim: int,
295
285
  cross_dim: int,
296
286
  hidden_dim: int,
@@ -301,7 +291,6 @@ class CrossAttention(nn.Module):
301
291
  """Initialize an instance of CrossAttention.
302
292
 
303
293
  Args:
304
- batch_size (int): batch size of the input tensor.
305
294
  query_dim (int): query tensor's dimension.
306
295
  cross_dim (int): cross attention's dimensions, for key and value tensors.
307
296
  hidden_dim (int): hidden dimension that q, k, v tensors project to.
@@ -376,7 +365,6 @@ class CrossAttention(nn.Module):
376
365
 
377
366
  if rope is not None:
378
367
  # Compute rotary positional embedding for query and key.
379
- n_elem = int(self.config.rotary_percentage * self.config.head_dim)
380
368
  cos, sin = rope
381
369
  q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
382
370
 
@@ -52,7 +52,6 @@ class TransformerBlock(nn.Module):
52
52
  config.pre_attention_norm_config,
53
53
  )
54
54
  self.atten_func = CausalSelfAttention(
55
- model_config.batch_size,
56
55
  model_config.embedding_dim,
57
56
  config.attn_config,
58
57
  model_config.enable_hlfb,
@@ -119,7 +118,6 @@ class CausalSelfAttention(nn.Module):
119
118
 
120
119
  def __init__(
121
120
  self,
122
- batch_size: int,
123
121
  dim: int,
124
122
  config: cfg.AttentionConfig,
125
123
  enable_hlfb: bool,
@@ -127,14 +125,12 @@ class CausalSelfAttention(nn.Module):
127
125
  """Initialize an instance of CausalSelfAttention.
128
126
 
129
127
  Args:
130
- batch_size (int): batch size of the input tensor.
131
128
  dim (int): causal attention's input/output dimmension.
132
129
  config (cfg.AttentionConfig): attention specific configurations.
133
130
  enable_hlfb (bool): whether hlfb is enabled or not.
134
131
  """
135
132
  super().__init__()
136
133
  self.kv_cache = None
137
- self.batch_size = batch_size
138
134
  qkv_shape = (
139
135
  config.num_heads + 2 * config.num_query_groups
140
136
  ) * config.head_dim
@@ -180,10 +176,6 @@ class CausalSelfAttention(nn.Module):
180
176
  """
181
177
  # Batch size, sequence length, embedding dimensionality.
182
178
  B, T, E = x.size()
183
- assert B == self.batch_size, (
184
- "batch size of input tensor must match with the batch size specified in"
185
- " the model configuration."
186
- )
187
179
 
188
180
  qkv = self.qkv_projection(x)
189
181
 
@@ -21,23 +21,19 @@ This is an experimental implementation and is subject to change at any time.
21
21
  import dataclasses
22
22
  from typing import List, Tuple
23
23
 
24
- from ai_edge_torch import hlfb
25
24
  from ai_edge_torch.generative.layers import model_config
26
- from ai_edge_torch.generative.layers.experimental import types as types
27
- from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
25
+ from ai_edge_torch.generative.layers.experimental import types
26
+ from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
28
27
  import torch
29
- import torch.nn as nn
30
28
  import torch.utils._pytree as pytree
31
29
 
32
- BATCH_SIZE = 1
33
-
34
30
 
35
31
  @dataclasses.dataclass
36
32
  class KVCacheEntryBase:
37
33
  """A single cache entry that includes K and V caches.
38
34
 
39
35
  The chaches are built based on the provided config with the shape of
40
- (batch_size=1, kv_cache_max, num_query_groups, head_dim).
36
+ (batch_size, kv_cache_max, num_query_groups, head_dim).
41
37
  """
42
38
 
43
39
  k_cache: torch.Tensor
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
46
42
  @classmethod
47
43
  def _from_model_config(
48
44
  cls,
49
- kv_cache_max: int,
50
- config: model_config.AttentionConfig,
51
- k_shape: Tuple,
52
- v_shape: Tuple,
45
+ k_shape: Tuple[int, ...],
46
+ v_shape: Tuple[int, ...],
53
47
  dtype: torch.dtype = torch.float32,
54
48
  device: torch.device = None,
55
49
  ) -> "KVCacheEntryBase":
@@ -66,12 +60,11 @@ class KVCacheEntryBase:
66
60
  config: model_config.AttentionConfig,
67
61
  dtype: torch.dtype = torch.float32,
68
62
  device: torch.device = None,
63
+ batch_size: int = 1,
69
64
  ) -> "KVCacheEntryBase":
70
65
  """Build an instance of the class based on model config."""
71
- shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
72
- return cls._from_model_config(
73
- kv_cache_max, config, shape, shape, dtype, device
74
- )
66
+ shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
67
+ return cls._from_model_config(shape, shape, dtype, device)
75
68
 
76
69
 
77
70
  @dataclasses.dataclass
@@ -93,24 +86,22 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
93
86
  config: model_config.AttentionConfig,
94
87
  dtype: torch.dtype = torch.float32,
95
88
  device: torch.device = None,
89
+ batch_size: int = 1,
96
90
  ) -> "KVCacheEntryBase":
97
91
  """Build an instance of the class based on model config."""
98
- num_kv_heads = config.num_query_groups
99
92
  k_shape = (
100
- 1,
101
- BATCH_SIZE * num_kv_heads,
93
+ batch_size,
94
+ config.num_query_groups,
102
95
  kv_cache_max,
103
96
  config.head_dim,
104
- ) # 1, bk, s, h
97
+ ) # b, k, s, h
105
98
  v_shape = (
106
- 1,
107
- BATCH_SIZE * num_kv_heads,
99
+ batch_size,
100
+ config.num_query_groups,
108
101
  config.head_dim,
109
102
  kv_cache_max,
110
- ) # 1, bk, h, s
111
- return cls._from_model_config(
112
- kv_cache_max, config, k_shape, v_shape, dtype, device
113
- )
103
+ ) # b, k, h, s
104
+ return cls._from_model_config(k_shape, v_shape, dtype, device)
114
105
 
115
106
 
116
107
  @dataclasses.dataclass
@@ -126,6 +117,7 @@ class KVCacheBase:
126
117
  config: model_config.ModelConfig,
127
118
  dtype: torch.dtype = torch.float32,
128
119
  device: torch.device = None,
120
+ batch_size: int = 1,
129
121
  ) -> "KVCacheBase":
130
122
  caches = [
131
123
  kv_entry_cls.from_model_config(
@@ -133,6 +125,7 @@ class KVCacheBase:
133
125
  config.block_config(idx).attn_config,
134
126
  dtype,
135
127
  device,
128
+ batch_size,
136
129
  )
137
130
  for idx in range(config.num_layers)
138
131
  ]
@@ -145,6 +138,7 @@ class KVCacheBase:
145
138
  config: model_config.ModelConfig,
146
139
  dtype: torch.dtype = torch.float32,
147
140
  device: torch.device = None,
141
+ batch_size: int = 1,
148
142
  ) -> "KVCacheBase":
149
143
  """Build an instance of the class based on model config.
150
144
 
@@ -154,12 +148,19 @@ class KVCacheBase:
154
148
  Defaults to torch.float32.
155
149
  device (torch.device, optional): The device placement of the cache
156
150
  tensors. Defaults to None.
151
+ batch_size (int, optional): The batch size of the cache tensors.
152
+ Defaults to 1.
157
153
 
158
154
  Returns:
159
155
  KVCacheBase: The created cache object.
160
156
  """
157
+ assert batch_size == 1, "Batch size must be 1 for KV Cache."
161
158
  return cls._from_model_config(
162
- KVCacheEntryBase, config=config, dtype=dtype, device=device
159
+ KVCacheEntryBase,
160
+ config=config,
161
+ dtype=dtype,
162
+ device=device,
163
+ batch_size=batch_size,
163
164
  )
164
165
 
165
166
  def flatten(self) -> List[torch.Tensor]:
@@ -177,9 +178,14 @@ class KVCacheBTNH(KVCacheBase):
177
178
  config: model_config.ModelConfig,
178
179
  dtype: torch.dtype = torch.float32,
179
180
  device: torch.device = None,
181
+ batch_size: int = 1,
180
182
  ) -> "KVCacheBTNH":
181
183
  return cls._from_model_config(
182
- KVCacheEntryBTNH, config=config, dtype=dtype, device=device
184
+ KVCacheEntryBTNH,
185
+ config=config,
186
+ dtype=dtype,
187
+ device=device,
188
+ batch_size=batch_size,
183
189
  )
184
190
 
185
191
 
@@ -192,9 +198,14 @@ class KVCacheTransposed(KVCacheBase):
192
198
  config: model_config.ModelConfig,
193
199
  dtype: torch.dtype = torch.float32,
194
200
  device: torch.device = None,
201
+ batch_size: int = 1,
195
202
  ) -> "KVCacheBTNH":
196
203
  return cls._from_model_config(
197
- KVCacheEntryTransposed, config=config, dtype=dtype, device=device
204
+ KVCacheEntryTransposed,
205
+ config=config,
206
+ dtype=dtype,
207
+ device=device,
208
+ batch_size=batch_size,
198
209
  )
199
210
 
200
211
 
@@ -258,7 +269,6 @@ def update(
258
269
  input_pos: torch.Tensor,
259
270
  k_slice: torch.Tensor,
260
271
  v_slice: torch.Tensor,
261
- use_dus: bool = True,
262
272
  ) -> KVCacheEntryBase:
263
273
  """Out of place update of Cache buffer.
264
274
 
@@ -309,6 +319,10 @@ def _update_kv_impl(
309
319
  positions = input_pos.clone()
310
320
  k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
311
321
  v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
312
- k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
313
- v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
322
+ k = dus_utils.dynamic_update_slice(
323
+ cache.k_cache, k_slice, [x for x in k_slice_indices]
324
+ )
325
+ v = dus_utils.dynamic_update_slice(
326
+ cache.v_cache, v_slice, [x for x in v_slice_indices]
327
+ )
314
328
  return KVCacheEntryTransposed(k, v)
@@ -18,14 +18,11 @@
18
18
  import dataclasses
19
19
  from typing import List, Tuple
20
20
 
21
- from ai_edge_torch import hlfb
22
21
  from ai_edge_torch.generative.layers import model_config
23
22
  from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
24
23
  import torch
25
24
  import torch.utils._pytree as pytree
26
25
 
27
- BATCH_SIZE = 1
28
-
29
26
 
30
27
  @dataclasses.dataclass
31
28
  class KVCacheEntry:
@@ -45,9 +42,10 @@ class KVCacheEntry:
45
42
  config: model_config.AttentionConfig,
46
43
  dtype: torch.dtype = torch.float32,
47
44
  device: torch.device = None,
45
+ batch_size: int = 1,
48
46
  ) -> "KVCacheEntry":
49
47
  """Build an instance of the class based on model config."""
50
- shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
48
+ shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
51
49
  k = torch.zeros(shape, dtype=dtype, device=device)
52
50
  v = torch.zeros(shape, dtype=dtype, device=device)
53
51
  obj = cls(k_cache=k, v_cache=v)
@@ -66,6 +64,7 @@ class KVCache:
66
64
  config: model_config.ModelConfig,
67
65
  dtype: torch.dtype = torch.float32,
68
66
  device: torch.device = None,
67
+ batch_size: int = 1,
69
68
  ) -> "KVCache":
70
69
  """Build an instance of the class based on model config.
71
70
 
@@ -75,17 +74,21 @@ class KVCache:
75
74
  Defaults to torch.float32.
76
75
  device (torch.device, optional): The device placement of the cache
77
76
  tensors. Defaults to None.
77
+ batch_size (int, optional): The batch size of the cache tensors.
78
+ Defaults to 1.
78
79
 
79
80
  Returns:
80
81
  KVCache: The created cache object.
81
82
  """
82
83
  caches = [
83
84
  KVCacheEntry.from_model_config(
84
- config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
85
+ config.kv_cache_max
86
+ if not config.block_config(idx).kv_cache_max_len
85
87
  else config.block_config(idx).kv_cache_max_len,
86
88
  config.block_config(idx).attn_config,
87
89
  dtype,
88
90
  device,
91
+ batch_size,
89
92
  )
90
93
  for idx in range(config.num_layers)
91
94
  ]
@@ -220,9 +220,6 @@ class ModelConfig:
220
220
  # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
221
221
  kv_cache_max_len: int = 0
222
222
 
223
- # Default batch size of the exported model. Default value is 1.
224
- batch_size: int = 1
225
-
226
223
  # Softcap on the model output logits.
227
224
  final_logit_softcap: Optional[float] = None
228
225
 
@@ -110,6 +110,11 @@ def convert_to_tflite(
110
110
  lora_suffix = (
111
111
  '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
112
112
  )
113
+
114
+ if export_config is not None:
115
+ if export_config.decode_batch_size > 1:
116
+ output_name_prefix += f'_dbs{export_config.decode_batch_size}'
117
+
113
118
  output_filename = (
114
119
  f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
115
120
  )
@@ -162,9 +167,14 @@ def _export_helper(
162
167
  if prefill_masks:
163
168
  assert len(prefill_masks) == len(prefill_seq_lens)
164
169
 
165
- decode_token = torch.tensor([[0]], dtype=torch.int)
170
+ decode_token = torch.tensor(
171
+ [[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
172
+ )
166
173
  decode_input_pos = torch.tensor([0], dtype=torch.int)
167
- kv = export_config.kvcache_cls.from_model_config(config)
174
+ prefill_kv = export_config.kvcache_cls.from_model_config(config)
175
+ decode_kv = export_config.kvcache_cls.from_model_config(
176
+ config, batch_size=export_config.decode_batch_size
177
+ )
168
178
 
169
179
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
170
180
 
@@ -183,7 +193,7 @@ def _export_helper(
183
193
  sample_kwargs = {
184
194
  'tokens': prefill_tokens,
185
195
  'input_pos': prefill_input_pos,
186
- 'kv_cache': kv,
196
+ 'kv_cache': prefill_kv,
187
197
  }
188
198
  if prefill_masks is not None:
189
199
  sample_kwargs['mask'] = prefill_masks[i]
@@ -211,7 +221,7 @@ def _export_helper(
211
221
  sample_kwargs = {
212
222
  'tokens': decode_token,
213
223
  'input_pos': decode_input_pos,
214
- 'kv_cache': kv,
224
+ 'kv_cache': decode_kv,
215
225
  }
216
226
  if export_config.decode_mask is not None:
217
227
  sample_kwargs['mask'] = export_config.decode_mask
@@ -60,6 +60,8 @@ class ExportConfig:
60
60
  decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
61
61
  # The KV Cache class for K and V buffers in attention.
62
62
  kvcache_cls: type = kv_utils.KVCache
63
+ # The batch size of the decode signature.
64
+ decode_batch_size: int = 1
63
65
 
64
66
 
65
67
  class DecoderOnlyModel(nn.Module):
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.4.0.dev20250226"
16
+ __version__ = "0.4.0.dev20250228"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.4.0.dev20250226
3
+ Version: 0.4.0.dev20250228
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=Redqgp3EjtlXSINPZlLb-pjbUH61Ie1ejPLMQ8bl_lE,706
5
+ ai_edge_torch/version.py,sha256=-EqWeDLQh8HxiqQxA-N-t0YXsYU9QT1iaq2h-kCDBdo,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -102,8 +102,8 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
102
102
  ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
103
103
  ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
104
104
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
105
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=megskv1oiPhwHSnguoG7zV-esXp1Ns_FPeMLAYKhDb0,2522
106
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=CjY1i0iCYxFSjhCpQZwxkmVxILgeo0zu1m0oBrHqyDU,2311
105
+ ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=hWko-RJB8eXNUfi4EzQ2yjW30YE4UB4zAz7rd2Q5qpg,2708
106
+ ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=sJ-o385eqQsciv0TEQRkixvS0DD6dKruAuK0zlEsDoY,2715
107
107
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
108
108
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
109
109
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -136,19 +136,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
136
136
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
137
137
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
138
138
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
139
- ai_edge_torch/generative/layers/attention.py,sha256=Pm8FLKh-NnOvUjqQC9oX5oghPbdivZvlPVkgOVTShoU,13703
139
+ ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDqYuGOZdYAI9EooOM,13247
140
140
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
141
141
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
142
142
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
143
- ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
143
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=jwbt0-2fd_CNWS2fp4nf0zvh6kk5citINGlFC_RtEUU,6540
144
144
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
145
- ai_edge_torch/generative/layers/model_config.py,sha256=EA1Ey5-c1IOLRNANSUnZ7gtNTA0o6OJxrz_I_mp8cjw,8244
145
+ ai_edge_torch/generative/layers/model_config.py,sha256=wNsZDzZQoimOKdZ9FWMCktPj-pQ_0D7084hgzMT5XYo,8155
146
146
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
147
147
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
148
148
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
149
149
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
150
- ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
151
- ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
150
+ ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
151
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=VN4gn4ylaVOwaTR5EXKv0YTVgpQ850bmjGLCgCCI1ps,9267
152
152
  ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
153
153
  ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
154
154
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -173,10 +173,10 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728Fc
173
173
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
174
174
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
175
175
  ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
176
- ai_edge_torch/generative/utilities/converter.py,sha256=_PO9lYCdNNYPVsAqh8QQVMG_8TUBshKwmaR1cdT6Ang,8065
176
+ ai_edge_torch/generative/utilities/converter.py,sha256=VtG42CVz657XbvTj-FZJiCFW0Hm11OVKKC_mr2tjxhc,8413
177
177
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
178
178
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
179
- ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
179
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=eY3qAcBhupIn955YnWuzUi9hoWYvl4ntRWA6PBudzMo,6888
180
180
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
181
181
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
182
182
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
230
230
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
231
231
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
232
232
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
233
- ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
234
- ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/METADATA,sha256=N6T5-MKa5Ztwx_XE7OJ8wiw2BC00e0dQxgngvI9S6CU,1966
235
- ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
236
- ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
237
- ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/RECORD,,
233
+ ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
234
+ ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/METADATA,sha256=oGVZ_Z3zOzdyxj4cJ5XTT-YzPpTa99SBgFJo5zUBqJU,1966
235
+ ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
236
+ ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
237
+ ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/RECORD,,