ai-edge-torch-nightly 0.6.0.dev20250528__py3-none-any.whl → 0.6.0.dev20250529__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -268,6 +268,12 @@ class ModelConfig:
268
268
  # export.
269
269
  use_mask_cache: bool = True
270
270
 
271
+ # An interleaved sequence of the attention types used in the model.
272
+ # E.g. [AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING,
273
+ # AttentionType.GLOBAL] means that the model has an attention pattern of 2
274
+ # local attentions followed by a global attention in a repeated pattern.
275
+ attention_patterns: Optional[Sequence[AttentionType]] = None
276
+
271
277
  @property
272
278
  def kv_cache_max(self) -> int:
273
279
  if self.kv_cache_max_len > 0:
@@ -286,3 +292,19 @@ class ModelConfig:
286
292
  @property
287
293
  def causal_mask_value(self) -> float:
288
294
  return self.block_config(0).attn_config.causal_mask_value
295
+
296
def check_if_global_attention_layer(self, layer_idx: int) -> bool:
  """Returns True if the layer is a global attention layer.

  When `attention_patterns` is unset, every layer is treated as global and
  `layer_idx` is not validated.

  Args:
    layer_idx: Index of the layer to query; must satisfy
      `0 <= layer_idx < num_layers` when `attention_patterns` is set.

  Returns:
    True if `attention_patterns` is None, or if the attention type configured
    for the block at `layer_idx` equals `AttentionType.GLOBAL`.

  Raises:
    AssertionError: If `attention_patterns` is set and `layer_idx` is outside
      the range `[0, num_layers)`.
  """
  if self.attention_patterns is None:
    # If attention_patterns is not set, we assume the model has global
    # attention.
    return True
  # Both message fragments must be f-strings; without the prefix on the first
  # one, "{layer_idx}" is emitted literally instead of the offending index.
  assert 0 <= layer_idx < self.num_layers, (
      f"Layer index {layer_idx} is out of range for num_layers:"
      f" {self.num_layers}"
  )

  return (
      self.block_config(layer_idx).attn_config.attn_type
      == AttentionType.GLOBAL
  )
@@ -0,0 +1,29 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Test utils for generative layers."""
16
+ from typing import Sequence
17
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
18
+ import torch
19
+
20
+
21
def initialize_kv_cache_all_zeros(
    kv_shape: Sequence[int],
    layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT,
) -> kv_utils.KVCacheEntry:
  """Builds a KV cache entry whose key and value caches are all zeros.

  Args:
    kv_shape: Shape used for both the key cache and the value cache tensors.
    layout: KV layout recorded on the returned entry.

  Returns:
    A `KVCacheEntry` holding two separate float32 zero tensors of shape
    `kv_shape` and the given layout.
  """
  # Allocate two distinct tensors so the key and value caches never alias.
  zero_keys = torch.zeros(kv_shape, dtype=torch.float32)
  zero_values = torch.zeros(kv_shape, dtype=torch.float32)
  return kv_utils.KVCacheEntry(
      k_cache=zero_keys,
      v_cache=zero_values,
      kv_layout=layout,
  )
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.6.0.dev20250528"
18
+ __version__ = "0.6.0.dev20250529"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.6.0.dev20250528
3
+ Version: 0.6.0.dev20250529
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=MuW0AEZVV7KlferCv485Nb_a1fonWf_MSQEeft5h9yU,806
5
+ ai_edge_torch/version.py,sha256=j84SLcYNp64OJ8Smn3CjFhF-ojE6L2pi7LJltVN_0Hg,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -180,7 +180,7 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7L
180
180
  ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
181
181
  ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
182
182
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
183
- ai_edge_torch/generative/layers/model_config.py,sha256=0FH3UJPVnEhgBO4eUlNaHuQBDo_OKH17ChG5-Ybj2T4,9895
183
+ ai_edge_torch/generative/layers/model_config.py,sha256=UH7vQPUCLZi6NRALiA28tYhoG3O2XW_P_QtjTj2r0Ts,10808
184
184
  ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
185
185
  ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
186
186
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
@@ -215,6 +215,7 @@ ai_edge_torch/generative/utilities/model_builder.py,sha256=tBfOcsI_NcneggHqkCSyd
215
215
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
216
216
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
217
217
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
218
+ ai_edge_torch/generative/utilities/test_utils.py,sha256=fhUMCMxoeMzxYbOCjNeX5wbQmF6Y88Hi52FtRiZYJAk,1147
218
219
  ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=l54bmmhj613eB2oCoONIAKEHhf8TQOhC9Gwjp6lxHAE,1659
219
220
  ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
220
221
  ai_edge_torch/generative/utilities/verifier.py,sha256=ETO2ShU5KXG7MLP8eVOWuzuRLCUtapafYHcZ6TZHIkw,13061
@@ -267,8 +268,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
267
268
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
268
269
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
269
270
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
270
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
271
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/METADATA,sha256=4Zdcdi5qzXkx7bVEuHMUdtsdN0UljrDl-8ud8-Q1hQQ,2074
272
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
273
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
274
- ai_edge_torch_nightly-0.6.0.dev20250528.dist-info/RECORD,,
271
+ ai_edge_torch_nightly-0.6.0.dev20250529.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
272
+ ai_edge_torch_nightly-0.6.0.dev20250529.dist-info/METADATA,sha256=uBCPe1_F0-0_ZyaQRalg_9Qs_iTHnpWr8p4jr-DCLMk,2074
273
+ ai_edge_torch_nightly-0.6.0.dev20250529.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
274
+ ai_edge_torch_nightly-0.6.0.dev20250529.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
275
+ ai_edge_torch_nightly-0.6.0.dev20250529.dist-info/RECORD,,