ai-edge-torch-nightly 0.4.0.dev20250227__py3-none-any.whl → 0.4.0.dev20250228__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -52,7 +52,6 @@ class TransformerBlock(nn.Module):
         config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        model_config.batch_size,
         model_config.embedding_dim,
         config.attn_config,
         model_config.enable_hlfb,
@@ -119,7 +118,6 @@ class CausalSelfAttention(nn.Module):

   def __init__(
       self,
-      batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
@@ -127,14 +125,12 @@ class CausalSelfAttention(nn.Module):
     """Initialize an instance of CausalSelfAttention.

     Args:
-      batch_size (int): batch size of the input tensor.
       dim (int): causal attention's input/output dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
     super().__init__()
     self.kv_cache = None
-    self.batch_size = batch_size
     qkv_shape = (
         config.num_heads + 2 * config.num_query_groups
     ) * config.head_dim
@@ -180,10 +176,6 @@ class CausalSelfAttention(nn.Module):
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert B == self.batch_size, (
-        "batch size of input tensor must match with the batch size specified in"
-        " the model configuration."
-    )

     qkv = self.qkv_projection(x)

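With the batch-size argument and the runtime assert gone, the attention layer no longer carries a fixed batch size as construction-time state; the forward pass simply reads the batch dimension off the input, as the retained `B, T, E = x.size()` line shows. A minimal sketch of that behaviour, with tensor sizes that are assumptions for illustration, not values from the package:

    import torch

    # Assumed sizes, for illustration only: (batch, seq_len, embedding_dim).
    x = torch.zeros(2, 16, 2048)
    # forward() derives the batch dimension from the input itself; previously
    # B had to equal the batch size captured at construction time or an assert fired.
    B, T, E = x.size()
    assert (B, T, E) == (2, 16, 2048)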
@@ -21,23 +21,19 @@ This is an experimental implementation and is subject to change at any time.
 import dataclasses
 from typing import List, Tuple

-from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
-from ai_edge_torch.generative.layers.experimental import types as types
-from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
+from ai_edge_torch.generative.layers.experimental import types
+from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree

-BATCH_SIZE = 1
-

 @dataclasses.dataclass
 class KVCacheEntryBase:
   """A single cache entry that includes K and V caches.

   The caches are built based on the provided config with the shape of
-  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  (batch_size, kv_cache_max, num_query_groups, head_dim).
   """

   k_cache: torch.Tensor
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
   @classmethod
   def _from_model_config(
       cls,
-      kv_cache_max: int,
-      config: model_config.AttentionConfig,
-      k_shape: Tuple,
-      v_shape: Tuple,
+      k_shape: Tuple[int, ...],
+      v_shape: Tuple[int, ...],
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
   ) -> "KVCacheEntryBase":
@@ -66,12 +60,11 @@ class KVCacheEntryBase:
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
-    return cls._from_model_config(
-        kv_cache_max, config, shape, shape, dtype, device
-    )
+    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
+    return cls._from_model_config(shape, shape, dtype, device)


 @dataclasses.dataclass
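Per the class docstring above, the default cache entry is laid out as (batch_size, kv_cache_max, num_query_groups, head_dim). A worked example of the shape computation under the new signature; the concrete config values below are assumptions chosen only to make the arithmetic visible:

    # Assumed config values, for illustration only.
    batch_size, kv_cache_max, num_query_groups, head_dim = 1, 1024, 8, 256

    # Mirrors the updated classmethod body: one shape, used for both K and V buffers.
    shape = (batch_size, kv_cache_max, num_query_groups, head_dim)
    assert shape == (1, 1024, 8, 256)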
@@ -93,24 +86,22 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    num_kv_heads = config.num_query_groups
     k_shape = (
-        1,
-        BATCH_SIZE * num_kv_heads,
+        batch_size,
+        config.num_query_groups,
         kv_cache_max,
         config.head_dim,
-    )  # 1, bk, s, h
+    )  # b, k, s, h
     v_shape = (
-        1,
-        BATCH_SIZE * num_kv_heads,
+        batch_size,
+        config.num_query_groups,
         config.head_dim,
         kv_cache_max,
-    )  # 1, bk, h, s
-    return cls._from_model_config(
-        kv_cache_max, config, k_shape, v_shape, dtype, device
-    )
+    )  # b, k, h, s
+    return cls._from_model_config(k_shape, v_shape, dtype, device)


 @dataclasses.dataclass
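For the transposed layout the tensor extents are unchanged while batch_size is 1; what changes is how the leading dimensions are labelled, from the old "1, bk" packing to an explicit "b, k" pair. A small sketch with assumed config values (num_query_groups=4, kv_cache_max=512, head_dim=64):

    # Assumed: batch_size=1, num_query_groups=4, kv_cache_max=512, head_dim=64.
    old_k_shape = (1, 1 * 4, 512, 64)   # "1, bk, s, h" packing (pre-change)
    new_k_shape = (1, 4, 512, 64)       # "b, k, s, h" (post-change)
    assert old_k_shape == new_k_shape   # identical extents while batch_size == 1;
                                        # larger batches stay blocked by the assert
                                        # added in KVCacheBase.from_model_config below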
@@ -126,6 +117,7 @@ class KVCacheBase:
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     caches = [
         kv_entry_cls.from_model_config(
@@ -133,6 +125,7 @@ class KVCacheBase:
             config.block_config(idx).attn_config,
             dtype,
             device,
+            batch_size,
         )
         for idx in range(config.num_layers)
     ]
@@ -145,6 +138,7 @@ class KVCacheBase:
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     """Build an instance of the class based on model config.

@@ -154,12 +148,19 @@ class KVCacheBase:
         Defaults to torch.float32.
       device (torch.device, optional): The device placement of the cache
         tensors. Defaults to None.
+      batch_size (int, optional): The batch size of the cache tensors.
+        Defaults to 1.

     Returns:
       KVCacheBase: The created cache object.
     """
+    assert batch_size == 1, "Batch size must be 1 for KV Cache."
     return cls._from_model_config(
-        KVCacheEntryBase, config=config, dtype=dtype, device=device
+        KVCacheEntryBase,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )

   def flatten(self) -> List[torch.Tensor]:
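A hedged usage sketch of the widened classmethod. The module path comes from the RECORD listing further below; `model_cfg` stands in for whatever model_config.ModelConfig instance the caller already has, and the import alias is purely illustrative. Note that the newly added assert still restricts callers to batch_size=1 for now.

    import torch
    # Module path as listed in RECORD; the alias is chosen for this example.
    from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils

    cache = kv_utils.KVCacheBase.from_model_config(
        model_cfg,              # an existing model_config.ModelConfig instance
        dtype=torch.float32,
        device=None,
        batch_size=1,           # any other value trips the new assert
    )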
@@ -177,9 +178,14 @@ class KVCacheBTNH(KVCacheBase):
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryBTNH, config=config, dtype=dtype, device=device
+        KVCacheEntryBTNH,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )


@@ -192,9 +198,14 @@ class KVCacheTransposed(KVCacheBase):
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryTransposed, config=config, dtype=dtype, device=device
+        KVCacheEntryTransposed,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )


@@ -258,7 +269,6 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    use_dus: bool = True,
 ) -> KVCacheEntryBase:
   """Out of place update of Cache buffer.

@@ -309,6 +319,10 @@ def _update_kv_impl(
   positions = input_pos.clone()
   k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
   v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
-  k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
-  v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
+  k = dus_utils.dynamic_update_slice(
+      cache.k_cache, k_slice, [x for x in k_slice_indices]
+  )
+  v = dus_utils.dynamic_update_slice(
+      cache.v_cache, v_slice, [x for x in v_slice_indices]
+  )
   return KVCacheEntryTransposed(k, v)
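For readers unfamiliar with the helper being re-imported here: dynamic_update_slice performs an out-of-place write of a small slice into a larger buffer at runtime-computed start indices. The sketch below is a conceptual eager-PyTorch equivalent written under that assumption; it is not the ai_edge_torch implementation, and the helper and tensor shapes in the example are made up for illustration.

    import torch

    def dynamic_update_slice_sketch(buffer, update, start_indices):
      # Conceptual equivalent only: copy the buffer, then overwrite the region
      # that begins at start_indices and spans update's shape.
      out = buffer.clone()
      region = tuple(
          slice(int(s), int(s) + size)
          for s, size in zip(start_indices, update.shape)
      )
      out[region] = update
      return out

    cache = torch.zeros(1, 4, 8, 2)    # assumed (b, k, s, h) buffer
    patch = torch.ones(1, 4, 1, 2)     # one new timestep per KV head
    updated = dynamic_update_slice_sketch(cache, patch, [0, 0, 3, 0])
    assert updated[0, 0, 3, 0] == 1 and cache[0, 0, 3, 0] == 0  # out of place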
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.4.0.dev20250227"
+__version__ = "0.4.0.dev20250228"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.4.0.dev20250227
+Version: 0.4.0.dev20250228
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=K2jtDrBNGi74j_uQYVUT6MJ2-aQFKkKy5ZYur9iWdVU,706
+ai_edge_torch/version.py,sha256=-EqWeDLQh8HxiqQxA-N-t0YXsYU9QT1iaq2h-kCDBdo,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -147,8 +147,8 @@ ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBiz
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
-ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
+ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
+ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=VN4gn4ylaVOwaTR5EXKv0YTVgpQ850bmjGLCgCCI1ps,9267
 ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
 ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/METADATA,sha256=cHcz3adq1WwVddazAJ06h7SKITJm70eMpFVjoNa2Jw4,1966
-ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/RECORD,,
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/METADATA,sha256=oGVZ_Z3zOzdyxj4cJ5XTT-YzPpTa99SBgFJo5zUBqJU,1966
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/RECORD,,