ai-edge-torch-nightly 0.4.0.dev20250227__py3-none-any.whl → 0.4.0.dev20250228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/layers/experimental/attention.py +0 -8
- ai_edge_torch/generative/layers/experimental/kv_cache.py +45 -31
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/top_level.txt +0 -0
@@ -52,7 +52,6 @@ class TransformerBlock(nn.Module):
|
|
52
52
|
config.pre_attention_norm_config,
|
53
53
|
)
|
54
54
|
self.atten_func = CausalSelfAttention(
|
55
|
-
model_config.batch_size,
|
56
55
|
model_config.embedding_dim,
|
57
56
|
config.attn_config,
|
58
57
|
model_config.enable_hlfb,
|
@@ -119,7 +118,6 @@ class CausalSelfAttention(nn.Module):
|
|
119
118
|
|
120
119
|
def __init__(
|
121
120
|
self,
|
122
|
-
batch_size: int,
|
123
121
|
dim: int,
|
124
122
|
config: cfg.AttentionConfig,
|
125
123
|
enable_hlfb: bool,
|
@@ -127,14 +125,12 @@ class CausalSelfAttention(nn.Module):
|
|
127
125
|
"""Initialize an instance of CausalSelfAttention.
|
128
126
|
|
129
127
|
Args:
|
130
|
-
batch_size (int): batch size of the input tensor.
|
131
128
|
dim (int): causal attention's input/output dimmension.
|
132
129
|
config (cfg.AttentionConfig): attention specific configurations.
|
133
130
|
enable_hlfb (bool): whether hlfb is enabled or not.
|
134
131
|
"""
|
135
132
|
super().__init__()
|
136
133
|
self.kv_cache = None
|
137
|
-
self.batch_size = batch_size
|
138
134
|
qkv_shape = (
|
139
135
|
config.num_heads + 2 * config.num_query_groups
|
140
136
|
) * config.head_dim
|
@@ -180,10 +176,6 @@ class CausalSelfAttention(nn.Module):
|
|
180
176
|
"""
|
181
177
|
# Batch size, sequence length, embedding dimensionality.
|
182
178
|
B, T, E = x.size()
|
183
|
-
assert B == self.batch_size, (
|
184
|
-
"batch size of input tensor must match with the batch size specified in"
|
185
|
-
" the model configuration."
|
186
|
-
)
|
187
179
|
|
188
180
|
qkv = self.qkv_projection(x)
|
189
181
|
|
@@ -21,23 +21,19 @@ This is an experimental implementation and is subject to change at any time.
|
|
21
21
|
import dataclasses
|
22
22
|
from typing import List, Tuple
|
23
23
|
|
24
|
-
from ai_edge_torch import hlfb
|
25
24
|
from ai_edge_torch.generative.layers import model_config
|
26
|
-
from ai_edge_torch.generative.layers.experimental import types
|
27
|
-
from ai_edge_torch.generative.utilities
|
25
|
+
from ai_edge_torch.generative.layers.experimental import types
|
26
|
+
from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
|
28
27
|
import torch
|
29
|
-
import torch.nn as nn
|
30
28
|
import torch.utils._pytree as pytree
|
31
29
|
|
32
|
-
BATCH_SIZE = 1
|
33
|
-
|
34
30
|
|
35
31
|
@dataclasses.dataclass
|
36
32
|
class KVCacheEntryBase:
|
37
33
|
"""A single cache entry that includes K and V caches.
|
38
34
|
|
39
35
|
The chaches are built based on the provided config with the shape of
|
40
|
-
(batch_size
|
36
|
+
(batch_size, kv_cache_max, num_query_groups, head_dim).
|
41
37
|
"""
|
42
38
|
|
43
39
|
k_cache: torch.Tensor
|
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
|
|
46
42
|
@classmethod
|
47
43
|
def _from_model_config(
|
48
44
|
cls,
|
49
|
-
|
50
|
-
|
51
|
-
k_shape: Tuple,
|
52
|
-
v_shape: Tuple,
|
45
|
+
k_shape: Tuple[int, ...],
|
46
|
+
v_shape: Tuple[int, ...],
|
53
47
|
dtype: torch.dtype = torch.float32,
|
54
48
|
device: torch.device = None,
|
55
49
|
) -> "KVCacheEntryBase":
|
@@ -66,12 +60,11 @@ class KVCacheEntryBase:
|
|
66
60
|
config: model_config.AttentionConfig,
|
67
61
|
dtype: torch.dtype = torch.float32,
|
68
62
|
device: torch.device = None,
|
63
|
+
batch_size: int = 1,
|
69
64
|
) -> "KVCacheEntryBase":
|
70
65
|
"""Build an instance of the class based on model config."""
|
71
|
-
shape = (
|
72
|
-
return cls._from_model_config(
|
73
|
-
kv_cache_max, config, shape, shape, dtype, device
|
74
|
-
)
|
66
|
+
shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
|
67
|
+
return cls._from_model_config(shape, shape, dtype, device)
|
75
68
|
|
76
69
|
|
77
70
|
@dataclasses.dataclass
|
@@ -93,24 +86,22 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
|
|
93
86
|
config: model_config.AttentionConfig,
|
94
87
|
dtype: torch.dtype = torch.float32,
|
95
88
|
device: torch.device = None,
|
89
|
+
batch_size: int = 1,
|
96
90
|
) -> "KVCacheEntryBase":
|
97
91
|
"""Build an instance of the class based on model config."""
|
98
|
-
num_kv_heads = config.num_query_groups
|
99
92
|
k_shape = (
|
100
|
-
|
101
|
-
|
93
|
+
batch_size,
|
94
|
+
config.num_query_groups,
|
102
95
|
kv_cache_max,
|
103
96
|
config.head_dim,
|
104
|
-
) #
|
97
|
+
) # b, k, s, h
|
105
98
|
v_shape = (
|
106
|
-
|
107
|
-
|
99
|
+
batch_size,
|
100
|
+
config.num_query_groups,
|
108
101
|
config.head_dim,
|
109
102
|
kv_cache_max,
|
110
|
-
) #
|
111
|
-
return cls._from_model_config(
|
112
|
-
kv_cache_max, config, k_shape, v_shape, dtype, device
|
113
|
-
)
|
103
|
+
) # b, k, h, s
|
104
|
+
return cls._from_model_config(k_shape, v_shape, dtype, device)
|
114
105
|
|
115
106
|
|
116
107
|
@dataclasses.dataclass
|
@@ -126,6 +117,7 @@ class KVCacheBase:
|
|
126
117
|
config: model_config.ModelConfig,
|
127
118
|
dtype: torch.dtype = torch.float32,
|
128
119
|
device: torch.device = None,
|
120
|
+
batch_size: int = 1,
|
129
121
|
) -> "KVCacheBase":
|
130
122
|
caches = [
|
131
123
|
kv_entry_cls.from_model_config(
|
@@ -133,6 +125,7 @@ class KVCacheBase:
|
|
133
125
|
config.block_config(idx).attn_config,
|
134
126
|
dtype,
|
135
127
|
device,
|
128
|
+
batch_size,
|
136
129
|
)
|
137
130
|
for idx in range(config.num_layers)
|
138
131
|
]
|
@@ -145,6 +138,7 @@ class KVCacheBase:
|
|
145
138
|
config: model_config.ModelConfig,
|
146
139
|
dtype: torch.dtype = torch.float32,
|
147
140
|
device: torch.device = None,
|
141
|
+
batch_size: int = 1,
|
148
142
|
) -> "KVCacheBase":
|
149
143
|
"""Build an instance of the class based on model config.
|
150
144
|
|
@@ -154,12 +148,19 @@ class KVCacheBase:
|
|
154
148
|
Defaults to torch.float32.
|
155
149
|
device (torch.device, optional): The device placement of the cache
|
156
150
|
tensors. Defaults to None.
|
151
|
+
batch_size (int, optional): The batch size of the cache tensors.
|
152
|
+
Defaults to 1.
|
157
153
|
|
158
154
|
Returns:
|
159
155
|
KVCacheBase: The created cache object.
|
160
156
|
"""
|
157
|
+
assert batch_size == 1, "Batch size must be 1 for KV Cache."
|
161
158
|
return cls._from_model_config(
|
162
|
-
KVCacheEntryBase,
|
159
|
+
KVCacheEntryBase,
|
160
|
+
config=config,
|
161
|
+
dtype=dtype,
|
162
|
+
device=device,
|
163
|
+
batch_size=batch_size,
|
163
164
|
)
|
164
165
|
|
165
166
|
def flatten(self) -> List[torch.Tensor]:
|
@@ -177,9 +178,14 @@ class KVCacheBTNH(KVCacheBase):
|
|
177
178
|
config: model_config.ModelConfig,
|
178
179
|
dtype: torch.dtype = torch.float32,
|
179
180
|
device: torch.device = None,
|
181
|
+
batch_size: int = 1,
|
180
182
|
) -> "KVCacheBTNH":
|
181
183
|
return cls._from_model_config(
|
182
|
-
KVCacheEntryBTNH,
|
184
|
+
KVCacheEntryBTNH,
|
185
|
+
config=config,
|
186
|
+
dtype=dtype,
|
187
|
+
device=device,
|
188
|
+
batch_size=batch_size,
|
183
189
|
)
|
184
190
|
|
185
191
|
|
@@ -192,9 +198,14 @@ class KVCacheTransposed(KVCacheBase):
|
|
192
198
|
config: model_config.ModelConfig,
|
193
199
|
dtype: torch.dtype = torch.float32,
|
194
200
|
device: torch.device = None,
|
201
|
+
batch_size: int = 1,
|
195
202
|
) -> "KVCacheBTNH":
|
196
203
|
return cls._from_model_config(
|
197
|
-
KVCacheEntryTransposed,
|
204
|
+
KVCacheEntryTransposed,
|
205
|
+
config=config,
|
206
|
+
dtype=dtype,
|
207
|
+
device=device,
|
208
|
+
batch_size=batch_size,
|
198
209
|
)
|
199
210
|
|
200
211
|
|
@@ -258,7 +269,6 @@ def update(
|
|
258
269
|
input_pos: torch.Tensor,
|
259
270
|
k_slice: torch.Tensor,
|
260
271
|
v_slice: torch.Tensor,
|
261
|
-
use_dus: bool = True,
|
262
272
|
) -> KVCacheEntryBase:
|
263
273
|
"""Out of place update of Cache buffer.
|
264
274
|
|
@@ -309,6 +319,10 @@ def _update_kv_impl(
|
|
309
319
|
positions = input_pos.clone()
|
310
320
|
k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
|
311
321
|
v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
|
312
|
-
k = dynamic_update_slice(
|
313
|
-
|
322
|
+
k = dus_utils.dynamic_update_slice(
|
323
|
+
cache.k_cache, k_slice, [x for x in k_slice_indices]
|
324
|
+
)
|
325
|
+
v = dus_utils.dynamic_update_slice(
|
326
|
+
cache.v_cache, v_slice, [x for x in v_slice_indices]
|
327
|
+
)
|
314
328
|
return KVCacheEntryTransposed(k, v)
|
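ai_edge_torch/version.py
CHANGED
(the version.py hunk was not captured in this rendering; the hunk below belongs to METADATA)
{ai_edge_torch_nightly-0.4.0.dev20250227.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/METADATA
CHANGED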
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.dev20250227
|
3
|
+
Version: 0.4.0.dev20250228
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256
|
5
|
+
ai_edge_torch/version.py,sha256=-EqWeDLQh8HxiqQxA-N-t0YXsYU9QT1iaq2h-kCDBdo,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -147,8 +147,8 @@ ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBiz
|
|
147
147
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
148
148
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
|
149
149
|
ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
|
150
|
-
ai_edge_torch/generative/layers/experimental/attention.py,sha256=
|
151
|
-
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=
|
150
|
+
ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
|
151
|
+
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=VN4gn4ylaVOwaTR5EXKv0YTVgpQ850bmjGLCgCCI1ps,9267
|
152
152
|
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
|
153
153
|
ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
|
154
154
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
230
230
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
231
231
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
232
232
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
233
|
-
ai_edge_torch_nightly-0.4.0.
|
234
|
-
ai_edge_torch_nightly-0.4.0.
|
235
|
-
ai_edge_torch_nightly-0.4.0.
|
236
|
-
ai_edge_torch_nightly-0.4.0.
|
237
|
-
ai_edge_torch_nightly-0.4.0.
|
233
|
+
ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
234
|
+
ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/METADATA,sha256=oGVZ_Z3zOzdyxj4cJ5XTT-YzPpTa99SBgFJo5zUBqJU,1966
|
235
|
+
ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
236
|
+
ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
237
|
+
ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/RECORD,,
|
File without changes
|
File without changes
|