ai-edge-torch-nightly 0.3.0.dev20240812__py3-none-any.whl → 0.3.0.dev20240814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +2 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +67 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +3 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +250 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -2
- ai_edge_torch/generative/examples/t5/t5.py +4 -4
- ai_edge_torch/generative/examples/t5/t5_attention.py +3 -3
- ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +1 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +1 -1
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/layers/attention.py +12 -5
- ai_edge_torch/generative/layers/attention_utils.py +30 -0
- ai_edge_torch/generative/layers/builder.py +5 -0
- ai_edge_torch/generative/layers/feed_forward.py +15 -3
- ai_edge_torch/generative/layers/model_config.py +35 -13
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +25 -9
- ai_edge_torch/generative/test/test_model_conversion.py +29 -1
- ai_edge_torch/generative/utilities/loader.py +29 -7
- ai_edge_torch/generative/utilities/t5_loader.py +8 -8
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240812.dist-info → ai_edge_torch_nightly-0.3.0.dev20240814.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240812.dist-info → ai_edge_torch_nightly-0.3.0.dev20240814.dist-info}/RECORD +27 -25
- {ai_edge_torch_nightly-0.3.0.dev20240812.dist-info → ai_edge_torch_nightly-0.3.0.dev20240814.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240812.dist-info → ai_edge_torch_nightly-0.3.0.dev20240814.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240812.dist-info → ai_edge_torch_nightly-0.3.0.dev20240814.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/experimental/gemma/gemma.py
CHANGED
@@ -40,7 +40,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     attn_value_proj="model.layers.{}.self_attn.v_proj",
     attn_output_proj="model.layers.{}.self_attn.o_proj",
     pre_attn_norm="model.layers.{}.input_layernorm",
-
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
     embedding="model.embed_tokens",
     final_norm="model.norm",
     lm_head=None,
@@ -150,7 +150,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
       parallel_residual=False,
       lm_head_use_bias=False,
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py
CHANGED
@@ -41,7 +41,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     attn_value_proj="model.layers.{}.self_attn.v_proj",
     attn_output_proj="model.layers.{}.self_attn.o_proj",
     pre_attn_norm="model.layers.{}.input_layernorm",
-
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
     embedding="model.embed_tokens",
     final_norm="model.norm",
     lm_head="lm_head",
@@ -142,7 +142,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
ADDED
@@ -0,0 +1,67 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+from pathlib import Path
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_gemma_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """Converting a Gemma 2 2B model to multi-signature
+  tflite model.
+
+  Args:
+      checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+      prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+        Defaults to 512.
+      kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+        including both prefill and decode. Defaults to 1024.
+      quantize (bool, optional): Whether the model should be quanized.
+        Defaults to True.
+  """
+  pytorch_model = gemma2.build_2b_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+      )
+      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(
+      f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
+  convert_gemma_to_tflite(checkpoint_path)
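For reference, a minimal usage sketch of the new conversion entry point added above. It assumes the nightly wheel is installed and that a Gemma 2 2B checkpoint has already been downloaded; the checkpoint directory below is a hypothetical placeholder, not something shipped with the package.

# Hedged sketch: module and function names come from the diff above;
# the checkpoint path is a placeholder chosen for illustration only.
from ai_edge_torch.generative.examples.gemma import convert_gemma2_to_tflite

convert_gemma2_to_tflite.convert_gemma_to_tflite(
    checkpoint_path="/data/checkpoints/gemma2-2b",  # hypothetical local path
    prefill_seq_len=512,
    kv_cache_max_len=1024,
    quantize=True,
)
# Per the script above, the result is written to
# /tmp/gemma2_seq512_kv1024.tflite with 'prefill' and 'decode' signatures.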
ai_edge_torch/generative/examples/gemma/gemma.py
CHANGED
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     attn_value_proj="model.layers.{}.self_attn.v_proj",
     attn_output_proj="model.layers.{}.self_attn.o_proj",
     pre_attn_norm="model.layers.{}.input_layernorm",
-
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
     embedding="model.embed_tokens",
     final_norm="model.norm",
     lm_head=None,
@@ -138,7 +138,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
       parallel_residual=False,
       lm_head_use_bias=False,
@@ -160,6 +160,7 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
   # since embedding and lm-head use the same weight, we need to set strict
   # to False.
   loader.load(model, strict=False)
+  model.eval()
   return model
 
 
ai_edge_torch/generative/examples/gemma/gemma2.py
ADDED
@@ -0,0 +1,250 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Example of building the Gemma2 2B model.
+
+import os
+from pathlib import Path
+from typing import Optional, Tuple
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+import torch.nn as nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    ff_gate_proj="model.layers.{}.mlp.gate_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
+    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
+    embedding="embedder",
+    final_norm="model.norm",
+    lm_head=None,
+)
+
+
+class Gemma2Block(TransformerBlock):
+
+  def forward(
+      self,
+      x: torch.Tensor,
+      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+      mask: Optional[torch.Tensor] = None,
+      input_pos: Optional[torch.Tensor] = None,
+  ) -> torch.Tensor:
+    """Forward function of the Gemma2Block.
+
+    Exactly the same as TransformerBlock but we call the post-attention norm
+    immediately after attention and not after the residual pointwise addition.
+
+    Args:
+      x (torch.Tensor): the input tensor.
+      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+      mask (torch.Tensor): the optional mask tensor.
+      input_pos (torch.Tensor): the optional input position tensor.
+
+    Returns:
+      output activation from this transformer block.
+    """
+
+    x_norm = self.pre_atten_norm(x)
+    attn_out = self.atten_func(x_norm, rope, mask, input_pos)
+    attn_out_norm = self.post_atten_norm(attn_out)
+    x = x + attn_out_norm
+    output = x + self.ff(x)
+    return output
+
+
+class Gemma2(nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    self.config = config
+    # Construct model layers.
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    # Gemma re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_blocks = nn.ModuleList(
+        Gemma2Block(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+
+    self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
+        size=config.kv_cache_max,
+        window_size=self.config.attn_config.sliding_window_size,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+
+    self.config = config
+
+  def get_attention_mask(
+      self, idx: int, input_pos: torch.Tensor
+  ) -> torch.Tensor:
+    if self.config.attn_config.attn_types:
+      if (
+          self.config.attn_config.attn_types[idx]
+          == cfg.AttentionType.LOCAL_SLIDING
+      ):
+        return self.sliding_window_mask_cache.index_select(2, input_pos)
+
+    return self.mask_cache.index_select(2, input_pos)
+
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    B, T = idx.size()
+    assert self.config.max_seq_len >= T, (
+        f"Cannot forward sequence of length {T}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(idx)
+    x = x * (self.config.embedding_dim**0.5)
+
+    for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(i, input_pos)
+      x = block(x, (cos, sin), mask, input_pos)
+
+    x = self.final_norm(x)
+    res = self.lm_head(x)  # (b, t, vocab_size)
+    if self.config.final_logit_softcap is not None:
+      res = res / self.config.final_logit_softcap
+      res = torch.tanh(res)
+      res = res * self.config.final_logit_softcap
+    return res
+
+
+def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=4,
+      rotary_percentage=1.0,
+      qkv_transpose_before_split=True,
+      logit_softcap=50.0,
+      sliding_window_size=4096,
+      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
+      * 13,
+  )
+
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=9216,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=256000,
+      num_layers=26,
+      max_seq_len=8192,
+      embedding_dim=2304,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=False,
+      final_logit_softcap=30.0,
+  )
+  return config
+
+
+def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
+  config = get_model_config_2b()
+  config.num_layers = 2
+  return config
+
+
+def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+  config = get_model_config_2b(**kwargs)
+  model = Gemma2(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run_2b() -> None:
+  current_dir = Path(__file__).parent.resolve()
+  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
+  print("Running GEMMA 2")
+  kv_cache_max_len = 1024
+  checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
+  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  toks = torch.from_numpy(
+      np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
+  )
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :9] = toks
+  input_pos = torch.arange(0, kv_cache_max_len)
+  out = model.forward(tokens, input_pos)
+  out_final = out[0, 8, :]
+  assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
+  print(out)
+
+
+if __name__ == "__main__":
+  torch.set_printoptions(sci_mode=True)
+  define_and_run_2b()
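The Gemma2Block docstring above notes that the post-attention norm is applied before the residual addition rather than after it. A toy standalone comparison of the two placements (plain PyTorch; nn.LayerNorm stands in for the RMSNorm the package actually configures):

import torch
from torch import nn

norm = nn.LayerNorm(4)  # stand-in for the configured post-attention norm
x = torch.randn(1, 4)
attn_out = torch.randn(1, 4)

# TransformerBlock (layers/attention.py): add the residual first, then normalize.
standard = norm(x + attn_out)

# Gemma2Block (gemma2.py above): normalize the attention output, then add.
gemma2_style = x + norm(attn_out)

print(torch.allclose(standard, gemma2_style))  # generally False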
ai_edge_torch/generative/examples/stable_diffusion/clip.py
CHANGED
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     pre_attn_norm=(
         "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1"
     ),
-
+    post_attn_norm=(
         "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2"
     ),
     embedding=(
@@ -120,7 +120,7 @@ def get_model_config() -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
ai_edge_torch/generative/examples/t5/t5.py
CHANGED
@@ -38,7 +38,7 @@ ENCDEC_TENSOR_NAMES = {
         "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias"
     ),
     "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
-    "
+    "post_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
     "final_norm": "{prefix}.final_layer_norm",
 }
 
@@ -396,7 +396,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
       relative_attention=True,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
       parallel_residual=False,
       lm_head_use_bias=False,
@@ -419,7 +419,7 @@ def build_t5_model(checkpoint_path: str) -> nn.Module:
       "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
       "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
       # In the decoder, the FF is layer 2 in the Transformer block
-      "
+      "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
       # In the decoder, the cross attention is layer 1 in the Transformer block
       "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
   }
@@ -475,7 +475,7 @@ def build_t5_decoder_model(
       "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
       "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
       # In the decoder, the FF is layer 2 in the Transformer block
-      "
+      "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
       # In the decoder, the cross attention is layer 1 in the Transformer block
       "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
   }
ai_edge_torch/generative/examples/t5/t5_attention.py
CHANGED
@@ -68,8 +68,8 @@ class EncoderDecoderBlock(nn.Module):
     else:
       self.cross_atten_func = None
 
-    self.
-        config.embedding_dim, config.
+    self.post_atten_norm = builder.build_norm(
+        config.embedding_dim, config.post_attention_norm_config
     )
     self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
     self.config = config
@@ -118,7 +118,7 @@ class EncoderDecoderBlock(nn.Module):
     )
     attn_out = hidden_states + attn_out
 
-    forwarded = self.
+    forwarded = self.post_atten_norm(attn_out)
     forwarded = self.ff(forwarded)
     hidden_states = attn_out + forwarded
 
ai_edge_torch/generative/examples/test_models/toy_model.py
CHANGED
@@ -93,7 +93,7 @@ def get_model_config() -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
   )
   return config
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py
CHANGED
@@ -107,7 +107,7 @@ def get_model_config() -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
      final_norm_config=norm_config,
      enable_hlfb=True,
   )
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
CHANGED
@@ -94,7 +94,7 @@ def get_model_config() -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
CHANGED
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     attn_value_proj="model.layers.{}.self_attn.v_proj",
     attn_output_proj="model.layers.{}.self_attn.o_proj",
     pre_attn_norm="model.layers.{}.input_layernorm",
-
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
     embedding="model.embed_tokens",
     final_norm="model.norm",
     lm_head="lm_head",
@@ -130,7 +130,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       attn_config=attn_config,
       ff_config=ff_config,
       pre_attention_norm_config=norm_config,
-
+      post_attention_norm_config=norm_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
   )
ai_edge_torch/generative/layers/attention.py
CHANGED
@@ -74,8 +74,8 @@ class TransformerBlock(nn.Module):
         config.kv_cache_max,
         config.enable_hlfb,
     )
-    self.
-        config.embedding_dim, config.
+    self.post_atten_norm = builder.build_norm(
+        config.embedding_dim, config.post_attention_norm_config
     )
     self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
     self.config = config
@@ -108,7 +108,7 @@ class TransformerBlock(nn.Module):
     x_norm = self.pre_atten_norm(x)
     attn_out = self.atten_func(x_norm, rope, mask, input_pos)
     x = x + attn_out
-    x_norm = self.
+    x_norm = self.post_atten_norm(x)
     output = x + self.ff(x_norm)
 
     return output
@@ -228,8 +228,15 @@ class CausalSelfAttention(nn.Module):
     # TODO(haoliang): Handle when execeeding max sequence length.
     k, v = self.kv_cache.update_cache(input_pos, k, v)
 
-    y = self.sdpa_func(
-
+    y = self.sdpa_func(
+        q,
+        k,
+        v,
+        self.config.head_dim,
+        mask=mask,
+        softcap=self.config.logit_softcap,
+    )
+    y = y.reshape(B, T, -1)
 
     # Compute the output projection.
     y = self.output_projection(y)
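The attention layer now forwards softcap=self.config.logit_softcap to the SDPA helper; the capping itself is implemented in the scaled_dot_product_attention hunk further below as softcap * tanh(scores / softcap). A standalone sketch of just that step (plain PyTorch, independent of the package):

import torch

softcap = 50.0  # value used by the Gemma 2 attention config in this diff
scores = torch.tensor([-200.0, -20.0, 0.0, 20.0, 200.0])

# Soft-capping squashes raw attention scores into (-softcap, softcap)
# while staying roughly linear for small magnitudes.
capped = softcap * torch.tanh(scores / softcap)
print(capped)  # extremes saturate near ±50, values near 0 are ~unchanged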
ai_edge_torch/generative/layers/attention_utils.py
CHANGED
@@ -74,12 +74,42 @@ def build_causal_mask_cache(
   Returns:
       torch.Tensor: Causal attention mask.
   """
+
   if device is None:
     device = torch.device('cpu')
   mask = torch.full((size, size), float('-inf'), dtype=dtype, device=device)
   return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
 
 
+def build_sliding_window_mask_cache(
+    size: int,
+    window_size: int,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = None,
+) -> torch.Tensor:
+  """Build a cache for a sliding window mask.
+
+  Args:
+      size (int): The size of the built mask cache.
+      window_size (int): The window size that is "seen" by a token.
+      dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+        torch.float32.
+      device (torch.device, optional): Output tensor's data type. Defaults to
+        None in which case "cpu" is used.
+
+  Returns:
+      torch.Tensor: Causal attention mask.
+  """
+
+  mask = build_causal_mask_cache(size, dtype, device)
+  all_ones = torch.ones_like(mask)
+  window_size = min(size, window_size)
+  sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
+      all_ones, window_size - 1
+  )
+  return torch.where(sliding_mask == 1, mask, -2.3819763e38)
+
+
 def relative_position_bucket(
     relative_position: torch.Tensor,
     bidirectional: bool,
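To see what the new sliding-window mask produces, here is a standalone sketch for a tiny cache size (plain PyTorch, duplicating the logic above rather than importing the package):

import torch

size, window_size = 6, 3

# Causal mask: 0 on and below the diagonal, -inf strictly above it.
causal = torch.triu(
    torch.full((size, size), float("-inf")), diagonal=1
).unsqueeze(0).unsqueeze(0)

all_ones = torch.ones_like(causal)
window = min(size, window_size)
band = torch.triu(all_ones, -window + 1) * torch.tril(all_ones, window - 1)

# Inside the band: keep the causal mask; outside: a large negative constant.
sliding = torch.where(band == 1, causal, torch.tensor(-2.3819763e38))
print(sliding[0, 0])
# Row i allows columns max(0, i - window + 1) .. i; older positions are masked.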
ai_edge_torch/generative/layers/builder.py
CHANGED
@@ -89,11 +89,16 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
 
   activation = get_activation(config.activation)
 
+  pre_ff_norm = build_norm(dim, config.pre_ff_norm_config)
+  post_ff_norm = build_norm(dim, config.post_ff_norm_config)
+
   return ff_module(
       dim=dim,
       hidden_dim=config.intermediate_size,
       activation=activation,
       use_bias=config.use_bias,
+      pre_ff_norm=pre_ff_norm,
+      post_ff_norm=post_ff_norm,
   )
 
 
ai_edge_torch/generative/layers/feed_forward.py
CHANGED
@@ -14,7 +14,7 @@
 # ==============================================================================
 # Common building blocks for FeedForward layers.
 
-from typing import Callable
+from typing import Callable, Optional
 
 import torch
 from torch import nn
@@ -30,6 +30,8 @@ class SequentialFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.
 
@@ -41,6 +43,8 @@ class SequentialFeedForward(nn.Module):
     self.act = activation
     self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
+    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
+    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
 
   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -51,7 +55,9 @@ class SequentialFeedForward(nn.Module):
     Returns:
       torch.Tensor: output tensor after feedforward.
     """
-
+    x_norm = self.pre_ff_norm(x)
+    out = self.w2(self.act(self.w1(x_norm)))
+    return self.post_ff_norm(out)
 
 
 class GatedFeedForward(nn.Module):
@@ -66,6 +72,8 @@ class GatedFeedForward(nn.Module):
       hidden_dim: int,
       activation: Callable[[torch.Tensor], torch.Tensor],
       use_bias=False,
+      pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
+      post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
   ):
     """Init function for feedforward layer.
 
@@ -78,6 +86,8 @@ class GatedFeedForward(nn.Module):
     self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
     self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
     self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
+    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
+    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
 
   def forward(self, x):
     """Forward pass for Feedforward layer.
@@ -88,4 +98,6 @@ class GatedFeedForward(nn.Module):
     Returns:
       torch.Tensor: output tensor after feedforward.
     """
-
+    x_norm = self.pre_ff_norm(x)
+    out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
+    return self.post_ff_norm(out)
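After this change the feed-forward path is: optional pre-norm, gated MLP, optional post-norm. A compact standalone sketch of the same pattern (plain PyTorch; names mirror the diff but nothing is imported from the package, and nn.LayerNorm stands in for the configured norm):

import torch
from torch import nn


class GatedFFSketch(nn.Module):
  """Gated feed forward with optional pre/post normalization, as in the diff."""

  def __init__(self, dim, hidden_dim, pre_ff_norm=None, post_ff_norm=None):
    super().__init__()
    self.w1 = nn.Linear(dim, hidden_dim, bias=False)
    self.w2 = nn.Linear(hidden_dim, dim, bias=False)
    self.w3 = nn.Linear(dim, hidden_dim, bias=False)
    self.act = nn.GELU(approximate="tanh")
    # Identity when no norm is configured, matching the `lambda x: x` default.
    self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
    self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x

  def forward(self, x):
    x_norm = self.pre_ff_norm(x)
    out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
    return self.post_ff_norm(out)


ff = GatedFFSketch(8, 32, pre_ff_norm=nn.LayerNorm(8), post_ff_norm=nn.LayerNorm(8))
print(ff(torch.randn(2, 8)).shape)  # torch.Size([2, 8])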
ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -16,7 +16,7 @@
 from dataclasses import dataclass
 from dataclasses import field
 import enum
-from typing import Optional
+from typing import Optional, Sequence
 
 
 @enum.unique
@@ -53,6 +53,11 @@ class FeedForwardType(enum.Enum):
   GATED = enum.auto()
 
 
+class AttentionType(enum.Enum):
+  GLOBAL = enum.auto()
+  LOCAL_SLIDING = enum.auto()
+
+
 @dataclass
 class AttentionConfig:
   """Attention model's parameters."""
@@ -78,6 +83,12 @@ class AttentionConfig:
   enable_kv_cache: bool = True
   relative_attention_num_buckets: int = 0
   relative_attention_max_distance: int = 0
+  # Softcap on the output logits.
+  logit_softcap: Optional[float] = None
+  # The types of attention used in the layers of the model.
+  attn_types: Optional[Sequence[AttentionType]] = None
+  # The size of the sliding window used for local attention.
+  sliding_window_size: Optional[int] = None
 
 
 @dataclass
@@ -88,16 +99,6 @@ class ActivationConfig:
   dim_out: Optional[int] = None
 
 
-@dataclass
-class FeedForwardConfig:
-  """FeedForward module's parameters."""
-
-  type: FeedForwardType
-  activation: ActivationConfig
-  intermediate_size: int
-  use_bias: bool = False
-
-
 @dataclass
 class NormalizationConfig:
   """Normalizater parameters."""
@@ -109,6 +110,24 @@ class NormalizationConfig:
   group_num: Optional[float] = None
 
 
+@dataclass
+class FeedForwardConfig:
+  """FeedForward module's parameters."""
+
+  type: FeedForwardType
+  activation: ActivationConfig
+  intermediate_size: int
+  use_bias: bool = False
+  # The normalization applied to feed forward's input.
+  pre_ff_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+  # The normalization applied to feed forward's output.
+  post_ff_norm_config: NormalizationConfig = field(
+      default_factory=NormalizationConfig
+  )
+
+
 @dataclass
 class ModelConfig:
   """Base configurations for building a transformer architecture."""
@@ -124,8 +143,8 @@ class ModelConfig:
   pre_attention_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )
-  # The normalization applied to
-
+  # The normalization applied to attentions's output.
+  post_attention_norm_config: NormalizationConfig = field(
       default_factory=NormalizationConfig
   )
   # The normalization applied before LM head.
@@ -151,6 +170,9 @@ class ModelConfig:
   # Default batch size of the exported model. Default value is 1.
   batch_size: int = 1
 
+  # Softcap on the model output logits.
+  final_logit_softcap: Optional[float] = None
+
   @property
   def kv_cache_max(self) -> int:
     if self.kv_cache_max_len > 0:
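FeedForwardConfig is moved below NormalizationConfig so its two new norm members can use field(default_factory=NormalizationConfig). A minimal standalone illustration of why dataclasses need default_factory for such fields (plain Python, independent of the package; the class names here are invented for the example):

from dataclasses import dataclass, field


@dataclass
class Norm:
  epsilon: float = 1e-6


@dataclass
class FF:
  # A plain `pre: Norm = Norm()` default would raise
  # "mutable default ... use default_factory"; default_factory also
  # gives every FF instance its own Norm object.
  pre: Norm = field(default_factory=Norm)


a, b = FF(), FF()
a.pre.epsilon = 1e-5
print(b.pre.epsilon)  # 1e-06: b's Norm is unaffected by mutating a's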
ai_edge_torch/generative/layers/scaled_dot_product_attention.py
CHANGED
@@ -29,6 +29,7 @@ def scaled_dot_product_attention(
     head_size: int,
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
+    softcap: Optional[float] = None,
 ):
   """Scaled dot product attention.
 
@@ -53,15 +54,26 @@ def scaled_dot_product_attention(
   # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
   k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
   v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-
-
-
-
-
-
-
-
-
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)
+
   return y.transpose(1, 2)
 
 
@@ -72,6 +84,7 @@ def scaled_dot_product_attention_with_hlfb(
     head_size: int,
     mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
+    softcap: Optional[float] = None,
 ):
   """Scaled dot product attention with high-level function boundary enabled.
 
@@ -86,6 +99,9 @@ def scaled_dot_product_attention_with_hlfb(
     The output tensor of scaled_dot_product_attention.
   """
 
+  if softcap is not None:
+    raise NotImplementedError("SDPA with HLFB not available with softcap.")
+
   if scale is None:
     scale = 1.0 / math.sqrt(head_size)
 
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED
@@ -16,7 +16,7 @@
 import copy
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma, gemma2
 from ai_edge_torch.generative.examples.phi2 import phi2
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
@@ -202,6 +202,34 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
+  def test_gemma2(self):
+    self.skipTest("b/338288901")
+    config = gemma2.get_fake_model_config_2b_for_test()
+    model = gemma2.Gemma2(config)
+    model.eval()
+
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, 10)
+
+    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+
+    # TODO: b/338288901 - re-enable test to check output tensors.
+    skip_output_check = True
+    if not skip_output_check:
+      # TODO(talumbau, haoliang): debug numerical diff.
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-2,
+              rtol=1e-5,
+          )
+      )
+
   def test_phi2(self):
     self.skipTest("b/338288901")
     config = phi2.get_fake_model_config_for_test()
ai_edge_torch/generative/utilities/loader.py
CHANGED
@@ -107,7 +107,9 @@ class ModelLoader:
     ff_gate_proj: str = None
 
     pre_attn_norm: str = None
+    post_attn_norm: str = None
     pre_ff_norm: str = None
+    post_ff_norm: str = None
     embedding: str = None
     embedding_position: str = None
     final_norm: str = None
@@ -258,6 +260,26 @@ class ModelLoader:
             f"{ff_gate_proj_name}.bias"
         )
 
+    if self._names.pre_ff_norm is not None:
+      pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
+      converted_state[f"{prefix}.ff.pre_ff_norm.weight"] = state.pop(
+          f"{pre_ff_norm_name}.weight"
+      )
+      if f"{pre_ff_norm_name}.bias" in state:
+        converted_state[f"{prefix}.ff.pre_ff_norm.bias"] = state.pop(
+            f"{pre_ff_norm_name}.bias"
+        )
+
+    if self._names.post_ff_norm is not None:
+      post_ff_norm_name = self._names.post_ff_norm.format(idx)
+      converted_state[f"{prefix}.ff.post_ff_norm.weight"] = state.pop(
+          f"{post_ff_norm_name}.weight"
+      )
+      if f"{post_ff_norm_name}.bias" in state:
+        converted_state[f"{prefix}.ff.post_ff_norm.bias"] = state.pop(
+            f"{post_ff_norm_name}.bias"
+        )
+
   def _map_attention(
       self,
       idx: int,
@@ -325,14 +347,14 @@ class ModelLoader:
           f"{pre_attn_norm_name}.bias"
       )
 
-    if self._names.
-
-      converted_state[f"{prefix}.
-          f"{
+    if self._names.post_attn_norm is not None:
+      post_attn_norm_name = self._names.post_attn_norm.format(idx)
+      converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
+          f"{post_attn_norm_name}.weight"
      )
-      if f"{
-        converted_state[f"{prefix}.
-            f"{
+      if f"{post_attn_norm_name}.bias" in state:
+        converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
+            f"{post_attn_norm_name}.bias"
        )
 
   def _fuse_qkv(
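The new entries follow the loader's existing convention: each TensorNames field is a str.format template that gets instantiated per layer index, and the loader pops the resulting "<name>.weight" (and ".bias" if present) keys from the checkpoint state dict. A minimal illustration (plain Python, no package import; the template strings are the ones used by the Gemma examples above):

post_attn_norm = "model.layers.{}.post_attention_layernorm"
post_ff_norm = "model.layers.{}.post_feedforward_layernorm"

for idx in range(2):
  # These are the checkpoint keys the loader looks up for layer `idx`.
  print(post_attn_norm.format(idx) + ".weight")
  print(post_ff_norm.format(idx) + ".weight")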
ai_edge_torch/generative/utilities/t5_loader.py
CHANGED
@@ -113,7 +113,7 @@ class ModelLoader:
 
     pre_attn_norm: str = None
     pre_cross_attn_norm: str = None
-
+    post_attn_norm: str = None
     embedding: str = None
     final_norm: str = None
     lm_head: str = None
@@ -484,14 +484,14 @@ class ModelLoader:
          state.pop(f"{pre_cross_attn_norm_name}.bias")
      )
 
-    if names.
-
-      converted_state[f"{prefix}.
-          f"{
+    if names.post_attn_norm is not None:
+      post_attn_norm_name = names.post_attn_norm.format(idx)
+      converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
+          f"{post_attn_norm_name}.weight"
      )
-      if f"{
-        converted_state[f"{prefix}.
-            f"{
+      if f"{post_attn_norm_name}.bias" in state:
+        converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
+            f"{post_attn_norm_name}.bias"
        )
 
   def _fuse_qkv(
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240812.dist-info → ai_edge_torch_nightly-0.3.0.dev20240814.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240812
+Version: 0.3.0.dev20240814
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240812.dist-info → ai_edge_torch_nightly-0.3.0.dev20240814.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=5DYNpFVwvI1w0JbAC1hn83NJVGS1WPX7n742419PMqs,4558
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=BlH3JqkXwVHXFYAd5rF04dUvLCthvKVqnfgO3abgh14,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,22 +42,24 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
-ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=
+ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=8313wSsddvuxZ5ZYVdaITBV2FF1k22dcCujnq0UZvKs,6699
 ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
 ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=u-VJX5mjzQKspXtAhNi53LCITtag-3nCaRTKdk5Z1sc,6231
 ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
-ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=zQYtyk3xYdiRAnzMKN58Q_wgTQFnDujxp6L4RFQjiD4,6383
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
 ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=cCki-0cKvmGxK4Md6dRNdPDWZUyhkJUI854OCTFf3h0,6262
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=j-zxJ-JNRnQ_kDzUESmsyy_a_4IxWZ510HmIImc0LDc,8240
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
 ai_edge_torch/generative/examples/phi2/phi2.py,sha256=C_kFYsPrEQ9GJCnc6h-jh8B5qQryvEpI6O6t4FBxg1I,5858
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
 ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=7ra36nM5tQwSw-vi6QCFLx5IssZhT-6yVK4H3XsAc4w,5044
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
@@ -72,27 +74,27 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
 ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
 ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz70WOvNpTJ9LFkiDnlwgJiXfUZCVk,4548
-ai_edge_torch/generative/examples/t5/t5.py,sha256=
-ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=
+ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
+ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
-ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LfWO_gSr1f66V1pxAc6yh21mtaJs7TVeuO9748zXBnE,3963
+ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mXXFYJfo8yegSOFOndCR0oYxFPchYb9vTJ4ThXGIFLU,5940
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
-ai_edge_torch/generative/layers/attention_utils.py,sha256=
-ai_edge_torch/generative/layers/builder.py,sha256=
-ai_edge_torch/generative/layers/feed_forward.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=2UujQePRJ1LK02PN-hGcuMu0ooCJC6ETfPvzEYVFyho,12284
+ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
+ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
+ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
 ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lfvKca8JbWtFx2CE,3082
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
 ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=x2bOmrTgOISXcb06IDP7X3xgftpPpxOjBXw_OxTMVns,3874
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -109,12 +111,12 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=T5-O2RVLJTH7v9w1_uBfp-Y7o3sdGzYq2Tj2wLRNHyI,4357
 ai_edge_torch/generative/test/test_loader.py,sha256=1ZqAq0HY5uIioumsReOVIsbGBx0WkYcl18PvttdJKrk,3381
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=52ciFy_Qol2Xuym6P6EqdL29oai35LSWGvsUwyEdFTo,8477
 ai_edge_torch/generative/test/test_quantize.py,sha256=3SmJm7Kq98gAneU6IGwwJrJYCVH1qwWR6oUxPfb6qiI,5346
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=bAWZ7FM4v_pPnX_AmEdGxHkDH65QdL-MjIP3PxscZmI,12649
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
-ai_edge_torch/generative/utilities/t5_loader.py,sha256=
+ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -134,8 +136,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/METADATA,sha256=eYXq0PpFouGnXKu9vXIzyaXj8XsLDxlDn903GJFR3ak,1885
+ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240814.dist-info/RECORD,,