ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.2.0.dev20240808__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240808.dist-info/RECORD +141 -0
- ai_edge_torch/convert/conversion_utils.py +0 -439
- ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/top_level.txt +0 -0
@@ -28,17 +28,17 @@ def convert_tiny_llama_to_tflite(
 kv_cache_max_len: int = 1024,
 quantize: bool = True,
 ):
-"""
-tflite model.
+"""Converts TinyLlama model to multi-signature tflite model.
 
 Args:
-checkpoint_path (str): The filepath to the model checkpoint, or directory
+checkpoint_path (str): The filepath to the model checkpoint, or directory
+holding the checkpoint.
 prefill_seq_len (int, optional): The maximum size of prefill input tensor.
 Defaults to 512.
 kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
 including both prefill and decode. Defaults to 1024.
-quantize (bool, optional): Whether the model should be quanized.
-
+quantize (bool, optional): Whether the model should be quanized. Defaults
+to True.
 """
 pytorch_model = tiny_llama.build_model(
 checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -64,7 +64,9 @@ class TinyLLamma(nn.Module):
 )
 self.rope_cache = attn_utils.build_rope_cache(
 size=config.kv_cache_max,
-dim=int(
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
 base=10_000,
 condense_ratio=1,
 dtype=torch.float32,
@@ -109,6 +111,7 @@ class TinyLLamma(nn.Module):
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=32,
+head_dim=64,
 num_query_groups=4,
 rotary_percentage=1.0,
 )
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.
-from ai_edge_torch.
+from ai_edge_torch._convert.fx_passes import CanonicalizePass
+from ai_edge_torch._convert.fx_passes import run_passes
 from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass # NOQA
 import torch
 
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch
-from ai_edge_torch.
+from ai_edge_torch import lowertools
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult
 import torch
 
 
@@ -27,7 +28,7 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
 for node in graph.nodes:
 if not (
 node.op == "call_function"
-and node.target ==
+and node.target == lowertools.mark_tensor_op
 ):
 continue
 
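
The check above is the usual torch.fx node-matching pattern: walk the exported graph and keep only call_function nodes whose target is the marker op. A minimal sketch of that pattern, with a hypothetical helper name that is not part of the package:

import torch

def find_nodes_with_target(graph_module: torch.fx.GraphModule, target):
  # Collect call_function nodes whose target matches, e.g. lowertools.mark_tensor_op.
  return [
      node
      for node in graph_module.graph.nodes
      if node.op == "call_function" and node.target is target
  ]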
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
 import torch
 from torch import nn
-import torch.nn.functional as F
 
 
 def _embed_rope(
@@ -60,8 +59,8 @@ class TransformerBlock(nn.Module):
 """Initialize an instance of the TransformerBlock.
 
 Args:
-config (cfg.ModelConfig): the configuration object
-
+config (cfg.ModelConfig): the configuration object for this transformer
+block.
 """
 
 super().__init__()
@@ -131,20 +130,23 @@ class CausalSelfAttention(nn.Module):
 batch_size (int): batch size of the input tensor.
 dim (int): causal attention's input/output dimmension.
 config (cfg.AttentionConfig): attention specific configurations.
-kv_cache_max (int): determines the size of the KV Cache buffer, if
+kv_cache_max (int): determines the size of the KV Cache buffer, if
+enabled.
 enable_hlfb (bool): whether hlfb is enabled or not.
 """
 super().__init__()
-self.head_dim = dim // config.num_heads
-shape = (config.num_heads + 2 * config.num_query_groups) * self.head_dim
-# Key, query, value projections for all heads.
-self.qkv_projection = nn.Linear(dim, shape, bias=config.qkv_use_bias)
-self.output_projection = nn.Linear(
-dim, dim, bias=config.output_proj_use_bias
-)
 self.config = config
 self.kv_cache = None
 self.batch_size = batch_size
+qkv_shape = (
+config.num_heads + 2 * config.num_query_groups
+) * config.head_dim
+output_shape = config.num_heads * config.head_dim
+# Key, query, value projections for all heads.
+self.qkv_projection = nn.Linear(dim, qkv_shape, bias=config.qkv_use_bias)
+self.output_projection = nn.Linear(
+output_shape, dim, bias=config.output_proj_use_bias
+)
 
 # Build a k/v cache with size (batch_size, kv_cache_max, n_heads, head_dim).
 if config.enable_kv_cache:
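
The new projection sizes follow directly from the explicit head_dim. A rough arithmetic check using the TinyLlama values shown in an earlier hunk (num_heads=32, num_query_groups=4, head_dim=64); the numbers come from that hunk, the snippet itself is illustrative:

num_heads, num_query_groups, head_dim = 32, 4, 64
qkv_shape = (num_heads + 2 * num_query_groups) * head_dim  # (32 + 8) * 64 = 2560
output_shape = num_heads * head_dim  # 32 * 64 = 2048
print(qkv_shape, output_shape)  # 2560 2048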
@@ -152,7 +154,7 @@ class CausalSelfAttention(nn.Module):
 batch_size,
 kv_cache_max,
 config.num_query_groups,
-
+config.head_dim,
 enable_hlfb,
 )
 
@@ -169,6 +171,7 @@ class CausalSelfAttention(nn.Module):
 input_pos: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
 """Forward function of the CausalSelfAttention layer, which can support
+
 MQA, GQA and MHA.
 
 Args:
@@ -193,7 +196,7 @@ class CausalSelfAttention(nn.Module):
 q_per_kv = self.config.num_heads // self.config.num_query_groups
 # Each group has >=1 queries, 1 key, and 1 value.
 if self.config.qkv_transpose_before_split:
-qkv = qkv.view(B, T, -1, self.head_dim)
+qkv = qkv.view(B, T, -1, self.config.head_dim)
 q, k, v = qkv.split(
 (
 q_per_kv * self.config.num_query_groups,
@@ -205,22 +208,27 @@ class CausalSelfAttention(nn.Module):
 else:
 qkv = qkv.view(B, T, self.config.num_query_groups, -1)
 q, k, v = qkv.split(
-(
+(
+q_per_kv * self.config.head_dim,
+self.config.head_dim,
+self.config.head_dim,
+),
+dim=-1,
 )
 
-q = q.reshape(B, T, -1, self.head_dim)
-k = k.reshape(B, T, -1, self.head_dim)
-v = v.reshape(B, T, -1, self.head_dim)
+q = q.reshape(B, T, -1, self.config.head_dim)
+k = k.reshape(B, T, -1, self.config.head_dim)
+v = v.reshape(B, T, -1, self.config.head_dim)
 
 # Compute rotary positional embedding for query and key.
-n_elem = int(self.config.rotary_percentage * self.head_dim)
+n_elem = int(self.config.rotary_percentage * self.config.head_dim)
 q, k = _embed_rope(q, k, n_elem, rope)
 
 if self.kv_cache is not None:
 # TODO(haoliang): Handle when execeeding max sequence length.
 k, v = self.kv_cache.update_cache(input_pos, k, v)
 
-y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
+y = self.sdpa_func(q, k, v, self.config.head_dim, mask=mask)
 y = y.reshape(B, T, E)
 
 # Compute the output projection.
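
To see what the per-group split above does, here is a small self-contained sketch with the same TinyLlama-style numbers (num_heads=32, num_query_groups=4, head_dim=64, so q_per_kv=8); the shapes in the comments are the point, the tensor contents are random:

import torch

B, T = 1, 16
num_heads, num_query_groups, head_dim = 32, 4, 64
q_per_kv = num_heads // num_query_groups  # 8 query heads share each k/v head
qkv = torch.randn(B, T, (num_heads + 2 * num_query_groups) * head_dim)
qkv = qkv.view(B, T, num_query_groups, -1)  # last dim: (8 + 1 + 1) * 64 = 640
q, k, v = qkv.split((q_per_kv * head_dim, head_dim, head_dim), dim=-1)
q = q.reshape(B, T, -1, head_dim)  # (1, 16, 32, 64)
k = k.reshape(B, T, -1, head_dim)  # (1, 16, 4, 64)
v = v.reshape(B, T, -1, head_dim)  # (1, 16, 4, 64)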
@@ -274,12 +282,12 @@ class CrossAttention(nn.Module):
 query_dim (int): query tensor's dimension.
 cross_dim (int): cross attention's dimensions, for key and value tensors.
 config (cfg.AttentionConfig): attention specific configurations.
-kv_cache_max (int): determines the size of the KV Cache buffer, if
+kv_cache_max (int): determines the size of the KV Cache buffer, if
+enabled.
 enable_hlfb (bool): whether hlfb is enabled or not.
 """
 super().__init__()
 self.config = config
-self.head_dim = query_dim // config.num_heads
 self.n_heads = config.num_heads
 self.q_projection = nn.Linear(
 query_dim, query_dim, bias=config.qkv_use_bias
@@ -301,7 +309,7 @@ class CrossAttention(nn.Module):
 batch_size,
 kv_cache_max,
 config.num_query_groups,
-self.head_dim,
+self.config.head_dim,
 enable_hlfb,
 )
 
@@ -324,7 +332,8 @@ class CrossAttention(nn.Module):
 x (torch.Tensor): the target tensor, with shape [B, target_seq_len, ...].
 y (torch.Tensor): the source tensor, with shape [B, source_seq_len, ...].
 rope (Tuple[torch.Tensor, torch.Tensor]): the optional input rope tensor.
-mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
+mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
+[B, n_heads, target_seq_len, source_seq_len].
 input_pos (torch.Tensor): the optional input position tensor.
 
 Returns:
@@ -338,13 +347,13 @@ class CrossAttention(nn.Module):
 k = self.k_projection(y)
 v = self.v_projection(y)
 
-interim_shape = (batch_size, -1, self.n_heads, self.head_dim)
+interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
 q = q.view(interim_shape)
 k = k.view(interim_shape)
 v = v.view(interim_shape)
 
 # Compute rotary positional embedding for query and key.
-n_elem = int(self.config.rotary_percentage * self.head_dim)
+n_elem = int(self.config.rotary_percentage * self.config.head_dim)
 q, k = _embed_rope(q, k, n_elem, rope)
 
 if self.kv_cache is not None:
@@ -354,7 +363,7 @@ class CrossAttention(nn.Module):
 mask = torch.zeros(
 (batch_size, 1, target_seq_len, source_seq_len), dtype=torch.float32
 )
-y = self.sdpa_func(q, k, v, self.head_dim, mask=mask)
+y = self.sdpa_func(q, k, v, self.config.head_dim, mask=mask)
 y = y.reshape(batch_size, target_seq_len, -1)
 
 # Compute the output projection.
@@ -28,7 +28,9 @@ def build_rope_cache(
 dtype: torch.dtype = torch.float32,
 device: torch.device = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-"""
+"""Precomputes Rotary Positional Embeddings.
+
+Precompute Rotary Positional Embedding Sin and Cos values for quick lookup
 during the inference.
 
 Args:
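
The sin/cos lookup the docstring refers to is the usual RoPE precomputation. A minimal sketch, not the package's exact implementation; the parameter names mirror the call site shown in an earlier hunk (size, dim, base, condense_ratio, dtype):

import torch

def rope_cache_sketch(size, dim, base=10_000, condense_ratio=1, dtype=torch.float32):
  # One (cos, sin) pair per position and per rotary channel pair.
  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  seq_idx = torch.arange(size) / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  return torch.cos(idx_theta).to(dtype), torch.sin(idx_theta).to(dtype)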
@@ -84,16 +86,22 @@ def relative_position_bucket(
 num_buckets: int,
 max_distance: int,
 ) -> torch.Tensor:
-"""
-
+"""Adapted from Mesh Tensorflow:
+
 https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
 
-Translate relative position to a bucket number for relative attention. The
-
-
-
-
-
+Translate relative position to a bucket number for relative attention. The
+relative position is defined as
+memory_position - query_position, i.e. the distance in tokens from the
+attending position to the attended-to
+position. If bidirectional=False, then positive relative positions are
+invalid. We use smaller buckets for
+small absolute relative_position and larger buckets for larger absolute
+relative_positions. All relative
+positions >=max_distance map to the same bucket. All relative positions
+<=-max_distance map to the same bucket.
+This should allow for more graceful generalization to longer sequences than
+the model has been trained on
 
 Args:
 relative_position: an int32 Tensor
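
The reflowed docstring describes the standard Mesh-TensorFlow/T5 bucketing scheme: half of the buckets cover small exact offsets, the rest cover exponentially growing ranges up to max_distance. A hedged sketch of that scheme for non-negative offsets, based on the referenced Mesh Tensorflow code rather than copied from this package (the clamp is only there to avoid log(0)):

import math
import torch

def bucket_large_positions(relative_position, num_buckets=32, max_distance=128):
  # relative_position is assumed non-negative here; sign handling happens earlier upstream.
  max_exact = num_buckets // 2
  is_small = relative_position < max_exact
  if_large = max_exact + (
      torch.log(relative_position.float().clamp(min=1) / max_exact)
      / math.log(max_distance / max_exact)
      * (num_buckets - max_exact)
  ).to(torch.long)
  if_large = torch.min(if_large, torch.full_like(if_large, num_buckets - 1))
  return torch.where(is_small, relative_position.to(torch.long), if_large)

Under this sketch, with the defaults num_buckets=32 and max_distance=128, offsets 0 through 15 keep their own buckets, while offsets 16, 32, 64 and 127 land in log-spaced buckets 16, 21, 26 and 31.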
@@ -102,7 +110,8 @@ def relative_position_bucket(
 max_distance: an integer for max distance.
 
 Returns:
-a Tensor with the same shape as relative_position, containing int32 values
+a Tensor with the same shape as relative_position, containing int32 values
+in the range [0, num_buckets)
 """
 relative_buckets = 0
 if bidirectional:
@@ -119,7 +128,8 @@ def relative_position_bucket(
 max_exact = num_buckets // 2
 is_small = relative_position < max_exact
 
-# The other half of the buckets are for logarithmically bigger bins in
+# The other half of the buckets are for logarithmically bigger bins in
+# positions up to max_distance
 relative_position_if_large = max_exact + (
 torch.log(relative_position.float() / max_exact)
 / math.log(max_distance / max_exact)
@@ -148,7 +158,8 @@ def build_relative_position_buckets(
 Args:
 query_length: an integer of length of current query tensor.
 key_length: an integer of length of current key tensor.
-bidirectional: a boolean - whether the attention is bidirectional, default
+bidirectional: a boolean - whether the attention is bidirectional, default
+is True.
 num_buckets: an integer for number of buckets, default is 32.
 max_distance: an integer for max distance, default is 128.
 
@@ -33,11 +33,9 @@ class SequentialFeedForward(nn.Module):
 ):
 """Init function for feedforward layer.
 
-Args:
-
-
-activation(Callable): activation function used in this block.
-use_bias(Boolean): whether to use bias. Default is false.
+Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
+feedforward layer. activation(Callable): activation function used in this
+block. use_bias(Boolean): whether to use bias. Default is false.
 """
 super().__init__()
 self.act = activation
@@ -71,11 +69,9 @@ class GatedFeedForward(nn.Module):
 ):
 """Init function for feedforward layer.
 
-Args:
-
-
-activation(Callable): activation function used in this block.
-use_bias(Boolean): whether to use bias. Default is false.
+Args: dim(int): embedding size. hidden_dim(int): hidden dim size of the
+feedforward layer. activation(Callable): activation function used in this
+block. use_bias(Boolean): whether to use bias. Default is false.
 """
 super().__init__()
 self.act = activation
@@ -55,9 +55,10 @@ class FeedForwardType(enum.Enum):
 
 @dataclass
 class AttentionConfig:
-"""Attention
+"""Attention model's parameters."""
 
 num_heads: int
+head_dim: int
 # Used to determine number of groups in grouped query attention (GQA)
 # https://arxiv.org/pdf/2305.13245.pdf
 num_query_groups: Optional[int]
@@ -156,7 +157,3 @@ class ModelConfig:
 return self.kv_cache_max_len
 else:
 return self.max_seq_len
-
-@property
-def head_dim(self) -> int:
-return self.embedding_dim // self.attn_config.num_heads
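
Taken together with the AttentionConfig change above, this moves head_dim from a value derived inside ModelConfig to an explicit per-attention field that callers set directly (as the TinyLlama hunk does with head_dim=64). A small before/after sketch; the 2048-wide embedding is an assumption used only to show the two values agree:

embedding_dim, num_heads = 2048, 32  # assumed TinyLlama-style widths, not read from the package

head_dim_derived = embedding_dim // num_heads  # old ModelConfig.head_dim property
head_dim_explicit = 64                         # new AttentionConfig.head_dim field
assert head_dim_derived == head_dim_explicit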
@@ -21,12 +21,12 @@ import torch
 class RMSNorm(torch.nn.Module):
 
 def __init__(self, dim: int, eps: float = 1e-6, zero_centered_gamma=False):
-"""
-Initialize the RMSNorm layer.
+"""Initialize the RMSNorm layer.
 
 Args:
 dim (int): dimension of the input tensor.
-eps (float): A small float value to ensure numerical stability (default:
+eps (float): A small float value to ensure numerical stability (default:
+1e-6).
 """
 super().__init__()
 self.eps = eps
@@ -34,8 +34,7 @@ class RMSNorm(torch.nn.Module):
 self.zero_centered_gamma = zero_centered_gamma
 
 def _norm(self, x):
-"""
-Apply RMSNorm normalization.
+"""Apply RMSNorm normalization.
 
 Args:
 x (torch.Tensor): input tensor.
@@ -46,8 +45,7 @@ class RMSNorm(torch.nn.Module):
 return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
 def forward(self, x):
-"""
-Running the forward pass of RMSNorm layer.
+"""Running the forward pass of RMSNorm layer.
 
 Args:
 x (torch.Tensor): input tensor.
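
As a quick sanity check of the _norm expression shown in context above, the normalization drives the per-row RMS to roughly 1; a tiny illustrative snippet, not taken from the package:

import torch

x = torch.randn(2, 8)
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(y.pow(2).mean(-1))  # both values are ~1.0: the RMS has been normalized away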
@@ -22,9 +22,9 @@ def apply_rope(
 """Computes rotary positional embedding.
 
 Args:
-x
-cos
-sin
+x: the input tensor.
+cos: cosine value for the rope.
+sin: sin value for the rope.
 
 Returns:
 output tensor of RoPE.
@@ -105,7 +105,6 @@ class AttentionBlock2D(nn.Module):
 """2D self attention block
 
 x = SelfAttention(Norm(input_tensor)) + x
-
 """
 
 def __init__(self, config: unet_cfg.AttentionBlock2DConfig):
@@ -161,14 +160,14 @@ class CrossAttentionBlock2D(nn.Module):
 """2D cross attention block
 
 x = CrossAttention(Norm(input_tensor), context) + x
-
 """
 
 def __init__(self, config: unet_cfg.CrossAttentionBlock2DConfig):
 """Initialize an instance of the AttentionBlock2D.
 
 Args:
-config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this
+config (unet_cfg.CrossAttentionBlock2DConfig): the configuration of this
+block.
 """
 super().__init__()
 self.config = config
@@ -191,7 +190,8 @@ class CrossAttentionBlock2D(nn.Module):
 
 Args:
 input_tensor (torch.Tensor): the input tensor.
-context_tensor (torch.Tensor): the context tensor to apply cross attention
+context_tensor (torch.Tensor): the context tensor to apply cross attention
+on.
 
 Returns:
 output activation tensor after cross attention.
@@ -220,7 +220,6 @@ class FeedForwardBlock2D(nn.Module):
 """2D feed forward block
 
 x = w2(Activation(w1(Norm(x)))) + x
-
 """
 
 def __init__(
@@ -291,15 +290,14 @@ class TransformerBlock2D(nn.Module):
 └─────────┬─────────┘
 ▼
 hidden_states
-
-
 """
 
 def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
 """Initialize an instance of the TransformerBlock2D.
 
 Args:
-config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
+config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
+block.
 """
 super().__init__()
 self.config = config
@@ -329,7 +327,8 @@ class TransformerBlock2D(nn.Module):
 
 Args:
 input_tensor (torch.Tensor): the input tensor.
-context_tensor (torch.Tensor): the context tensor to apply cross attention
+context_tensor (torch.Tensor): the context tensor to apply cross attention
+on.
 
 Returns:
 output activation tensor after transformer block.
@@ -377,7 +376,8 @@ class DownEncoderBlock2D(nn.Module):
 """Initialize an instance of the DownEncoderBlock2D.
 
 Args:
-config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this
+config (unet_cfg.DownEncoderBlock2DConfig): the configuration of this
+block.
 """
 super().__init__()
 self.config = config
@@ -418,10 +418,13 @@ class DownEncoderBlock2D(nn.Module):
 
 Args:
 input_tensor (torch.Tensor): the input tensor.
-time_emb (torch.Tensor): optional time embedding tensor, if the block is
-time embedding.
-context_tensor (torch.Tensor): optional context tensor, if the block if
-
+time_emb (torch.Tensor): optional time embedding tensor, if the block is
+configured to accept time embedding.
+context_tensor (torch.Tensor): optional context tensor, if the block if
+configured to use transofrmer block.
+output_hidden_states (bool): whether to output hidden states, usually for
+skip connections.
+
 Returns:
 output hidden_states tensor after DownEncoderBlock2D.
 """
@@ -523,9 +526,10 @@ class UpDecoderBlock2D(nn.Module):
 
 Args:
 input_tensor (torch.Tensor): the input tensor.
-time_emb (torch.Tensor): optional time embedding tensor, if the block is
-time embedding.
-context_tensor (torch.Tensor): optional context tensor, if the block if
+time_emb (torch.Tensor): optional time embedding tensor, if the block is
+configured to accept time embedding.
+context_tensor (torch.Tensor): optional context tensor, if the block if
+configured to use transofrmer block.
 
 Returns:
 output hidden_states tensor after UpDecoderBlock2D.
@@ -576,7 +580,8 @@ class SkipUpDecoderBlock2D(nn.Module):
 """Initialize an instance of the SkipUpDecoderBlock2D.
 
 Args:
-config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this
+config (unet_cfg.SkipUpDecoderBlock2DConfig): the configuration of this
+block.
 """
 super().__init__()
 self.config = config
@@ -632,10 +637,12 @@ class SkipUpDecoderBlock2D(nn.Module):
 
 Args:
 input_tensor (torch.Tensor): the input tensor.
-skip_connection_tensors (List[torch.Tensor]): the skip connection tensors
-
-
-
+skip_connection_tensors (List[torch.Tensor]): the skip connection tensors
+from encoder blocks.
+time_emb (torch.Tensor): optional time embedding tensor, if the block is
+configured to accept time embedding.
+context_tensor (torch.Tensor): optional context tensor, if the block if
+configured to use transofrmer block.
 
 Returns:
 output hidden_states tensor after SkipUpDecoderBlock2D.
@@ -738,10 +745,10 @@ class MidBlock2D(nn.Module):
 
 Args:
 input_tensor (torch.Tensor): the input tensor.
-time_emb (torch.Tensor): optional time embedding tensor, if the block is
-time embedding.
-context_tensor (torch.Tensor): optional context tensor, if the block if
-transofrmer block.
+time_emb (torch.Tensor): optional time embedding tensor, if the block is
+configured to accept time embedding.
+context_tensor (torch.Tensor): optional context tensor, if the block if
+configured to use transofrmer block.
 
 Returns:
 output hidden_states tensor after MidBlock2D.