ai-edge-torch-nightly 0.6.0.dev20250528__py3-none-any.whl → 0.6.0.dev20250530__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -134,8 +134,6 @@ class AttentionConfig:
134
134
  value_norm_config: NormalizationConfig = dataclasses.field(
135
135
  default_factory=NormalizationConfig
136
136
  )
137
- # Whether the KV cache is shared with the previous attention block.
138
- kv_shared: bool = False
139
137
  relative_attention_num_buckets: int = 0
140
138
  relative_attention_max_distance: int = 0
141
139
  # Softcap on the output logits.
@@ -268,6 +266,12 @@ class ModelConfig:
268
266
  # export.
269
267
  use_mask_cache: bool = True
270
268
 
269
+ # An interleaved sequence of the attention types used in the model.
270
+ # E.g. [AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING,
271
+ # AttentionType.GLOBAL] means that the model has an attention pattern of 2
272
+ # local attentions followed by a global attention in a repeated pattern.
273
+ attention_patterns: Optional[Sequence[AttentionType]] = None
274
+
271
275
  @property
272
276
  def kv_cache_max(self) -> int:
273
277
  if self.kv_cache_max_len > 0:
@@ -286,3 +290,19 @@ class ModelConfig:
286
290
  @property
287
291
  def causal_mask_value(self) -> float:
288
292
  return self.block_config(0).attn_config.causal_mask_value
293
+
294
def check_if_global_attention_layer(self, layer_idx: int) -> bool:
  """Returns True if the layer at `layer_idx` uses global attention.

  Args:
    layer_idx: Zero-based index of the transformer block to query.

  Returns:
    True when `attention_patterns` is unset (every layer is then treated as
    global). Otherwise, whether the block's attention config has
    `attn_type == AttentionType.GLOBAL`.

  Raises:
    AssertionError: If `attention_patterns` is set and `layer_idx` is not in
      `[0, num_layers)`.
  """
  if self.attention_patterns is None:
    # Without an explicit attention pattern, the model is assumed to use
    # global attention in every layer, so no range check is performed.
    return True
  # Both halves of the message are f-strings so the offending index is
  # actually interpolated into the error text.
  assert 0 <= layer_idx < self.num_layers, (
      f"Layer index {layer_idx} is out of range for num_layers:"
      f" {self.num_layers}"
  )
  return (
      self.block_config(layer_idx).attn_config.attn_type
      == AttentionType.GLOBAL
  )
@@ -0,0 +1,29 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Test utils for generative layers."""
16
+ from typing import Sequence
17
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
18
+ import torch
19
+
20
+
21
def initialize_kv_cache_all_zeros(
    kv_shape: Sequence[int],
    layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
) -> kv_utils.KVCacheEntry:
  """Creates a KV cache entry whose key and value buffers are zero-filled.

  Args:
    kv_shape: Shape used for both the key tensor and the value tensor.
    layout: Layout recorded on the returned entry; defaults to
      `kv_utils.KV_LAYOUT_DEFAULT`.

  Returns:
    A `kv_utils.KVCacheEntry` holding two float32 zero tensors of shape
    `kv_shape`.
  """
  # Two independent allocations, so keys and values do not share storage.
  key_buffer = torch.zeros(kv_shape, dtype=torch.float32)
  value_buffer = torch.zeros(kv_shape, dtype=torch.float32)
  return kv_utils.KVCacheEntry(
      k_cache=key_buffer,
      v_cache=value_buffer,
      kv_layout=layout,
  )
ai_edge_torch/model.py CHANGED
@@ -155,7 +155,8 @@ class TfLiteModel(Model):
155
155
  Args:
156
156
  path: The path to file to which the model is serialized.
157
157
  """
158
- os.makedirs(os.path.dirname(path), exist_ok=True)
158
+ if os.path.dirname(path):
159
+ os.makedirs(os.path.dirname(path), exist_ok=True)
159
160
  with open(path, 'wb') as file_handle:
160
161
  file_handle.write(self._tflite_model)
161
162
 
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.6.0.dev20250528"
18
+ __version__ = "0.6.0.dev20250530"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.6.0.dev20250528
3
+ Version: 0.6.0.dev20250530
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,8 +1,8 @@
1
1
  ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,1290
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
- ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=MuW0AEZVV7KlferCv485Nb_a1fonWf_MSQEeft5h9yU,806
4
+ ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
+ ai_edge_torch/version.py,sha256=B-sOsG_3lPrDKxH_MJPNpivWVftaRufBHKPbBig2z3E,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -180,7 +180,7 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7L
180
180
  ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
181
181
  ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
182
182
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
183
- ai_edge_torch/generative/layers/model_config.py,sha256=0FH3UJPVnEhgBO4eUlNaHuQBDo_OKH17ChG5-Ybj2T4,9895
183
+ ai_edge_torch/generative/layers/model_config.py,sha256=gGr6SZzgFVRXemfXwX__TZ0OXgAFnU70M6U0eql06TE,10712
184
184
  ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
185
185
  ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
186
186
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
@@ -215,6 +215,7 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSyd
215
215
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
216
216
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
217
217
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
218
+ ai_edge_torch/generative/utilities/test_utils.py,sha256=fhUMCMxoeMzxYbOCjNeX5wbQmF6Y88Hi52FtRiZYJAk,1147
218
219
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=l54bmmhj613eB2oCoONIAKEHhf8TQOhC9Gwjp6lxHAE,1659
219
220
  ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
220
221
  ai_edge_torch/generative/utilities/verifier.py,sha256=ETO2ShU5KXG7MLP8eVOWuzuRLCUtapafYHcZ6TZHIkw,13061
@@ -267,8 +268,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
267
268
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
268
269
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
269
270
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
270
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
271
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/METADATA,sha256=4Zdcdi5qzXkx7bVEuHMUdtsdN0UljrDl-8ud8-Q1hQQ,2074
272
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
273
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
274
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/RECORD,,
271
+ ai_edge_torch_nightly-0.6.0.dev20250530.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
272
+ ai_edge_torch_nightly-0.6.0.dev20250530.dist-info/METADATA,sha256=7oeZ6wSsBUuvNXH20tOHtlWkw_Rfmmr0EADK7hSt6AQ,2074
273
+ ai_edge_torch_nightly-0.6.0.dev20250530.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
274
+ ai_edge_torch_nightly-0.6.0.dev20250530.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
275
+ ai_edge_torch_nightly-0.6.0.dev20250530.dist-info/RECORD,,