ai-edge-torch-nightly 0.3.0.dev20240902__py3-none-any.whl → 0.3.0.dev20240905__py3-none-any.whl
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +32 -42
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +33 -15
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +29 -14
- ai_edge_torch/generative/examples/gemma/gemma.py +26 -41
- ai_edge_torch/generative/examples/gemma/gemma2.py +29 -48
- ai_edge_torch/generative/examples/phi2/phi2.py +25 -37
- ai_edge_torch/generative/examples/test_models/toy_model.py +50 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +21 -9
- ai_edge_torch/generative/test/test_model_conversion.py +0 -29
- ai_edge_torch/generative/test/test_quantize.py +23 -10
- ai_edge_torch/generative/utilities/loader.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +20 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +1 -1
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +24 -12
- ai_edge_torch/odml_torch/export.py +1 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD +21 -22
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -14
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/experimental/gemma/gemma.py
CHANGED

@@ -21,15 +21,16 @@ import os
 from pathlib import Path
 from typing import Tuple
 
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-
+from ai_edge_torch.generative.layers.experimental import attention
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
+
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -48,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Gemma(nn.Module):
+"""A Gemma model built from the Edge Generative API layers."""
 
 def __init__(self, config: cfg.ModelConfig):
 super().__init__()
@@ -65,7 +67,7 @@ class Gemma(nn.Module):
 # Gemma re-uses the embedding as the head projection layer.
 self.lm_head.weight.data = self.tok_embedding.weight.data
 self.transformer_blocks = nn.ModuleList(
-TransformerBlock(config) for _ in range(config.num_layers)
+attention.TransformerBlock(config) for _ in range(config.num_layers)
 )
 self.final_norm = builder.build_norm(
 config.embedding_dim,
@@ -95,9 +97,9 @@ class Gemma(nn.Module):
 input_pos: torch.Tensor,
 kv_cache: kv_utils.EKVCache,
 ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-
-assert self.config.max_seq_len >=
-f"Cannot forward sequence of length {
+_, seq_len = tokens.size()
+assert self.config.max_seq_len >= seq_len, (
+f"Cannot forward sequence of length {seq_len}, max seq length is only"
 f" {self.config.max_seq_len}"
 )
 
@@ -125,6 +127,15 @@ class Gemma(nn.Module):
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+"""Returns the model config for a Gemma 2B model.
+
+Args:
+kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+is 1024.
+
+Returns:
+The model config for a Gemma 2B model.
+"""
 attn_config = cfg.AttentionConfig(
 num_heads=8,
 head_dim=256,
@@ -160,41 +171,18 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-)
-ff_config = cfg.FeedForwardConfig(
-type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-intermediate_size=128,
-)
-norm_config = cfg.NormalizationConfig(
-type=cfg.NormalizationType.RMS_NORM,
-epsilon=1e-6,
-zero_centered=True,
-)
-config = cfg.ModelConfig(
-vocab_size=128,
-num_layers=2,
-max_seq_len=2 * kv_cache_max_len,
-embedding_dim=2048,
-kv_cache_max_len=kv_cache_max_len,
-attn_config=attn_config,
-ff_config=ff_config,
-pre_attention_norm_config=norm_config,
-post_attention_norm_config=norm_config,
-final_norm_config=norm_config,
-parallel_residual=False,
-lm_head_use_bias=False,
-enable_hlfb=True,
-)
+config = get_model_config_2b(kv_cache_max_len)
+config.ff_config.intermediate_size = 128
+config.vocab_size = 128
+config.num_layers = 2
+config.max_seq_len = 2 * kv_cache_max_len
 return config
 
 
-def build_2b_model(
+def build_2b_model(
+checkpoint_path: str, test_model: bool = False, **kwargs
+) -> nn.Module:
+"""Instantiates the model instance and load checkpoint if provided."""
 config = (
 get_fake_model_config(**kwargs)
 if test_model
@@ -210,7 +198,9 @@ def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
 return model
 
 
-def define_and_run_2b(checkpoint_path, test_model=False) -> None:
+def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
+"""Instantiates and runs a Gemma 2B model."""
+
 kv_cache_max_len = 1024
 model = build_2b_model(
 checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -225,5 +215,5 @@ def define_and_run_2b(checkpoint_path, test_model=False) -> None:
 
 
 if __name__ == "__main__":
-
-define_and_run_2b(
+input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
+define_and_run_2b(input_checkpoint_path)

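Across the Gemma, Phi-2, and TinyLlama examples in this release, the hand-built test configuration is replaced by one derived from the full model config. A minimal sketch of that pattern, using a stand-in dataclass rather than the real `cfg.ModelConfig` (all field names below are illustrative, not the library's):

```python
# Illustrative sketch only: a stand-in config, not ai_edge_torch's cfg.ModelConfig.
import copy
import dataclasses


@dataclasses.dataclass
class FeedForwardConfig:
  intermediate_size: int = 16384


@dataclasses.dataclass
class ModelConfig:
  vocab_size: int = 256000
  num_layers: int = 18
  max_seq_len: int = 8192
  kv_cache_max_len: int = 1024
  ff_config: FeedForwardConfig = dataclasses.field(default_factory=FeedForwardConfig)


def get_model_config_2b(kv_cache_max_len: int = 1024) -> ModelConfig:
  return ModelConfig(kv_cache_max_len=kv_cache_max_len)


def get_fake_model_config(kv_cache_max_len: int = 128) -> ModelConfig:
  # Start from the full config and shrink it, so the test config cannot
  # silently drift away from the production config field by field.
  config = copy.deepcopy(get_model_config_2b(kv_cache_max_len))
  config.ff_config.intermediate_size = 128
  config.vocab_size = 128
  config.num_layers = 2
  config.max_seq_len = 2 * kv_cache_max_len
  return config
```

The real code mutates the freshly built config directly; the `deepcopy` above is only a defensive choice in this sketch.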
ai_edge_torch/generative/examples/experimental/phi/phi2.py
CHANGED

@@ -17,20 +17,20 @@
 # Note: This is an experimental version of phi2 with external KV cache.
 # Please use with caution.
 
-
 import os
 from pathlib import Path
 from typing import Tuple
 
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-
+from ai_edge_torch.generative.layers.experimental import attention
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
+
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.fc1",
@@ -47,6 +47,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Phi2(nn.Module):
+"""A Phi-2 model built from the Edge Generative API layers."""
 
 def __init__(self, config: cfg.ModelConfig):
 super().__init__()
@@ -60,7 +61,7 @@ class Phi2(nn.Module):
 config.vocab_size, config.embedding_dim, padding_idx=0
 )
 self.transformer_blocks = nn.ModuleList(
-TransformerBlock(config) for _ in range(config.num_layers)
+attention.TransformerBlock(config) for _ in range(config.num_layers)
 )
 self.final_norm = builder.build_norm(
 config.embedding_dim,
@@ -90,9 +91,9 @@ class Phi2(nn.Module):
 input_pos: torch.Tensor,
 kv_cache: kv_utils.EKVCache,
 ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-
-assert self.config.max_seq_len >=
-f"Cannot forward sequence of length {
+_, seq_len = tokens.size()
+assert self.config.max_seq_len >= seq_len, (
+f"Cannot forward sequence of length {seq_len}, max seq length is only"
 f" {self.config.max_seq_len}"
 )
 
@@ -118,6 +119,15 @@ class Phi2(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+"""Returns the model config for a Phi-2 model.
+
+Args:
+kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+is 1024.
+
+Returns:
+The model config for a Phi-2 model.
+"""
 attn_config = cfg.AttentionConfig(
 num_heads=32,
 head_dim=80,
@@ -150,15 +160,21 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 return config
 
 
-def
-config = get_model_config(
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+config = get_model_config(kv_cache_max_len)
+config.vocab_size = 128
 config.num_layers = 2
+config.max_seq_len = 2 * kv_cache_max_len
+config.ff_config.intermediate_size = 128
 return config
 
 
-def build_model(
+def build_model(
+checkpoint_path: str, test_model: bool = False, **kwargs
+) -> nn.Module:
+"""Instantiates the model instance and load checkpoint if provided."""
 config = (
-
+get_fake_model_config(**kwargs)
 if test_model
 else get_model_config(**kwargs)
 )
@@ -170,7 +186,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
 return model
 
 
-def define_and_run(checkpoint_path, test_model=False) -> None:
+def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+"""Instantiates and runs a Phi-2 model."""
+
 kv_cache_max_len = 1024
 model = build_model(
 checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -185,5 +203,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
 
 
 if __name__ == "__main__":
-
-define_and_run(
+input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py
CHANGED

@@ -17,20 +17,20 @@
 # Note: This is an experimental version of TinyLlama with external KV cache.
 # Please use with caution.
 
-
 import os
 from pathlib import Path
 from typing import Tuple
 
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-
+from ai_edge_torch.generative.layers.experimental import attention
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
+
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -49,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class TinyLLamma(nn.Module):
+"""A TinyLlama model built from the Edge Generative API layers."""
 
 def __init__(self, config: cfg.ModelConfig):
 super().__init__()
@@ -62,7 +63,7 @@ class TinyLLamma(nn.Module):
 config.vocab_size, config.embedding_dim, padding_idx=0
 )
 self.transformer_blocks = nn.ModuleList(
-TransformerBlock(config) for _ in range(config.num_layers)
+attention.TransformerBlock(config) for _ in range(config.num_layers)
 )
 self.final_norm = builder.build_norm(
 config.embedding_dim,
@@ -92,9 +93,9 @@ class TinyLLamma(nn.Module):
 input_pos: torch.Tensor,
 kv_cache: kv_utils.EKVCache,
 ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-
-assert self.config.max_seq_len >=
-f"Cannot forward sequence of length {
+_, seq_len = tokens.size()
+assert self.config.max_seq_len >= seq_len, (
+f"Cannot forward sequence of length {seq_len}, max seq length is only"
 f" {self.config.max_seq_len}"
 )
 
@@ -121,6 +122,15 @@ class TinyLLamma(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+"""Returns the model config for a TinyLlama model.
+
+Args:
+kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+is 1024.
+
+Returns:
+The model config for a TinyLlama model.
+"""
 attn_config = cfg.AttentionConfig(
 num_heads=32,
 head_dim=64,
@@ -149,7 +159,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 return config
 
 
-def
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 config = get_model_config(**kwargs)
 config.vocab_size = 128
 config.num_layers = 2
@@ -157,9 +167,12 @@ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
 return config
 
 
-def build_model(
+def build_model(
+checkpoint_path: str, test_model: bool = False, **kwargs
+) -> nn.Module:
+"""Instantiates the model instance and load checkpoint if provided."""
 config = (
-
+get_fake_model_config(**kwargs)
 if test_model
 else get_model_config(**kwargs)
 )
@@ -171,7 +184,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
 return model
 
 
-def define_and_run(checkpoint_path, test_model=False) -> None:
+def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+"""Instantiates and runs a TinyLlama model."""
+
 kv_cache_max_len = 1024
 model = build_model(
 checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -186,5 +201,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
 
 
 if __name__ == "__main__":
-
-define_and_run(
+input_checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
+define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/gemma/gemma.py
CHANGED

@@ -17,14 +17,14 @@
 import os
 from pathlib import Path
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Gemma(nn.Module):
+"""A Gemma model built from the Edge Generative API layers."""
 
 def __init__(self, config: cfg.ModelConfig):
 super().__init__()
@@ -60,7 +61,7 @@ class Gemma(nn.Module):
 # Gemma re-uses the embedding as the head projection layer.
 self.lm_head.weight.data = self.tok_embedding.weight.data
 self.transformer_blocks = nn.ModuleList(
-TransformerBlock(config) for _ in range(config.num_layers)
+attention.TransformerBlock(config) for _ in range(config.num_layers)
 )
 self.final_norm = builder.build_norm(
 config.embedding_dim,
@@ -88,9 +89,9 @@ class Gemma(nn.Module):
 # This can be eliminated if we handle k/v cache updates inside the model itself.
 @torch.inference_mode
 def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-assert self.config.max_seq_len >=
-f"Cannot forward sequence of length {
+_, seq_len = idx.size()
+assert self.config.max_seq_len >= seq_len, (
+f"Cannot forward sequence of length {seq_len}, max seq length is only"
 f" {self.config.max_seq_len}"
 )
 
@@ -104,7 +105,7 @@ class Gemma(nn.Module):
 x = self.tok_embedding(idx)
 x = x * (self.config.embedding_dim**0.5)
 
-for
+for _, block in enumerate(self.transformer_blocks):
 x = block(x, (cos, sin), mask, input_pos)
 
 x = self.final_norm(x)
@@ -113,6 +114,15 @@ class Gemma(nn.Module):
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+"""Returns the model config for a Gemma 2B model.
+
+Args:
+kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+is 1024.
+
+Returns:
+The model config for a Gemma 2B model.
+"""
 attn_config = cfg.AttentionConfig(
 num_heads=8,
 head_dim=256,
@@ -147,43 +157,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 return config
 
 
-# TODO(b/363021962): Clean up this part to streamline fake model config generation.
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-)
-ff_config = cfg.FeedForwardConfig(
-type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-intermediate_size=128,
-)
-norm_config = cfg.NormalizationConfig(
-type=cfg.NormalizationType.RMS_NORM,
-epsilon=1e-6,
-zero_centered=True,
-)
-config = cfg.ModelConfig(
-vocab_size=128,
-num_layers=2,
-max_seq_len=2 * kv_cache_max_len,
-embedding_dim=2048,
-kv_cache_max_len=kv_cache_max_len,
-attn_config=attn_config,
-ff_config=ff_config,
-pre_attention_norm_config=norm_config,
-post_attention_norm_config=norm_config,
-final_norm_config=norm_config,
-parallel_residual=False,
-lm_head_use_bias=False,
-enable_hlfb=True,
-)
+config = get_model_config_2b(kv_cache_max_len)
+config.ff_config.intermediate_size = 128
+config.vocab_size = 128
+config.num_layers = 2
+config.max_seq_len = 2 * kv_cache_max_len
 return config
 
 
-def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
 config = get_model_config_2b(**kwargs)
 model = Gemma(config)
 loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -195,6 +178,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run_2b() -> None:
+"""Instantiates and runs a Gemma 2B model."""
+
 current_dir = Path(__file__).parent.resolve()
 gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 

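The forward methods in these files now read the sequence length from the input once and use it both in the assertion and in the error message. A compact, self-contained sketch of that guard outside any model class:

```python
import torch


def check_seq_len(idx: torch.Tensor, max_seq_len: int) -> None:
  # idx is a (batch, seq_len) token tensor; reject inputs longer than the model supports.
  _, seq_len = idx.size()
  assert max_seq_len >= seq_len, (
      f"Cannot forward sequence of length {seq_len}, max seq length is only"
      f" {max_seq_len}"
  )


check_seq_len(torch.zeros((1, 8), dtype=torch.long), max_seq_len=16)
```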
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED

@@ -18,14 +18,14 @@ import os
 from pathlib import Path
 from typing import Optional, Tuple
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,7 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Gemma2Block(TransformerBlock):
+class Gemma2Block(attention.TransformerBlock):
 
 def forward(
 self,
@@ -76,6 +76,7 @@ class Gemma2Block(TransformerBlock):
 
 
 class Gemma2(nn.Module):
+"""A Gemma2 model built from the Edge Generative API layers."""
 
 def __init__(self, config: cfg.ModelConfig):
 super().__init__()
@@ -138,9 +139,9 @@ class Gemma2(nn.Module):
 
 @torch.inference_mode
 def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-assert self.config.max_seq_len >=
-f"Cannot forward sequence of length {
+_, seq_len = idx.size()
+assert self.config.max_seq_len >= seq_len, (
+f"Cannot forward sequence of length {seq_len}, max seq length is only"
 f" {self.config.max_seq_len}"
 )
 
@@ -166,6 +167,15 @@ class Gemma2(nn.Module):
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+"""Returns the model config for a Gemma2 2B model.
+
+Args:
+kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+is 1024.
+
+Returns:
+The model config for a Gemma 2B model.
+"""
 attn_config = cfg.AttentionConfig(
 num_heads=8,
 head_dim=256,
@@ -210,50 +220,19 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-
-
-
-
-* 13,
-)
-
-norm_config = cfg.NormalizationConfig(
-type=cfg.NormalizationType.RMS_NORM,
-epsilon=1e-6,
-zero_centered=True,
-)
-ff_config = cfg.FeedForwardConfig(
-type=cfg.FeedForwardType.GATED,
-activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-intermediate_size=128,
-pre_ff_norm_config=norm_config,
-post_ff_norm_config=norm_config,
-)
-config = cfg.ModelConfig(
-vocab_size=128,
-num_layers=2,
-max_seq_len=2 * kv_cache_max_len,
-embedding_dim=128,
-kv_cache_max_len=kv_cache_max_len,
-attn_config=attn_config,
-ff_config=ff_config,
-pre_attention_norm_config=norm_config,
-post_attention_norm_config=norm_config,
-final_norm_config=norm_config,
-parallel_residual=False,
-lm_head_use_bias=False,
-enable_hlfb=True,
-final_logit_softcap=30.0,
-)
+config = get_model_config_2b(kv_cache_max_len)
+config.attn_config.num_heads = 4
+config.attn_config.head_dim = 64
+config.attn_config.sliding_window_size = 64
+config.ff_config.intermediate_size = 128
+config.vocab_size = 128
+config.num_layers = 2
+config.max_seq_len = 2 * kv_cache_max_len
+config.embedding_dim = 128
 return config
 
 
-def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
 config = get_model_config_2b(**kwargs)
 model = Gemma2(config)
 loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -265,6 +244,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run_2b() -> None:
+"""Instantiates and runs a Gemma2 2B model."""
+
 current_dir = Path(__file__).parent.resolve()
 gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
 print("Running GEMMA 2")

ai_edge_torch/generative/examples/phi2/phi2.py
CHANGED

@@ -18,14 +18,14 @@
 import os
 from pathlib import Path
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.fc1",
@@ -42,6 +42,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Phi2(nn.Module):
+"""A Phi-2 model built from the Edge Generative API layers."""
 
 def __init__(self, config: cfg.ModelConfig):
 super().__init__()
@@ -55,7 +56,7 @@ class Phi2(nn.Module):
 config.vocab_size, config.embedding_dim, padding_idx=0
 )
 self.transformer_blocks = nn.ModuleList(
-TransformerBlock(config) for _ in range(config.num_layers)
+attention.TransformerBlock(config) for _ in range(config.num_layers)
 )
 self.final_norm = builder.build_norm(
 config.embedding_dim,
@@ -83,9 +84,9 @@ class Phi2(nn.Module):
 # This can be eliminated if we handle k/v cache updates inside the model itself.
 @torch.inference_mode
 def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-assert self.config.max_seq_len >=
-f"Cannot forward sequence of length {
+_, seq_len = idx.size()
+assert self.config.max_seq_len >= seq_len, (
+f"Cannot forward sequence of length {seq_len}, max seq length is only"
 f" {self.config.max_seq_len}"
 )
 
@@ -98,7 +99,7 @@ class Phi2(nn.Module):
 # forward the model itself
 x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
-for
+for _, block in enumerate(self.transformer_blocks):
 x = block(x, (cos, sin), mask, input_pos)
 
 x = self.final_norm(x)
@@ -107,6 +108,15 @@ class Phi2(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+"""Returns the model config for a Phi-2 model.
+
+Args:
+kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+is 1024.
+
+Returns:
+The model config for a Phi-2 model.
+"""
 attn_config = cfg.AttentionConfig(
 num_heads=32,
 head_dim=80,
@@ -140,35 +150,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-qkv_use_bias=True,
-output_proj_use_bias=True,
-)
-ff_config = cfg.FeedForwardConfig(
-type=cfg.FeedForwardType.SEQUENTIAL,
-activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-intermediate_size=128,
-use_bias=True,
-)
-norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
-config = cfg.ModelConfig(
-vocab_size=128,
-num_layers=2,
-max_seq_len=2 * kv_cache_max_len,
-kv_cache_max_len=kv_cache_max_len,
-embedding_dim=128,
-attn_config=attn_config,
-ff_config=ff_config,
-pre_attention_norm_config=norm_config,
-final_norm_config=norm_config,
-parallel_residual=True,
-lm_head_use_bias=True,
-enable_hlfb=True,
-)
+config = get_model_config(kv_cache_max_len)
+config.vocab_size = 128
+config.num_layers = 2
+config.max_seq_len = 2 * kv_cache_max_len
+config.ff_config.intermediate_size = 128
 return config
 
 
@@ -181,6 +167,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run() -> None:
+"""Instantiates and runs a Phi-2 model."""
+
 current_dir = Path(__file__).parent.resolve()
 phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
 kv_cache_max_len = 1024

ai_edge_torch/generative/examples/test_models/toy_model.py
CHANGED

@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
 return self.lm_head(x)
 
 
+class ToySingleLayerModelWeightSharing(torch.nn.Module):
+
+def __init__(self, config: cfg.ModelConfig) -> None:
+super().__init__()
+self.lm_head = nn.Linear(
+config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+)
+self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+self.lm_head = nn.Linear(
+config.embedding_dim,
+config.vocab_size,
+bias=config.lm_head_use_bias,
+)
+self.lm_head.weight.data = self.tok_embedding.weight.data
+self.transformer_block = TransformerBlock(config)
+self.final_norm = builder.build_norm(
+config.embedding_dim,
+config.final_norm_config,
+)
+self.rope_cache = attn_utils.build_rope_cache(
+size=config.max_seq_len,
+dim=int(
+config.attn_config.rotary_percentage * config.attn_config.head_dim
+),
+base=10_000,
+condense_ratio=1,
+dtype=torch.float32,
+device=torch.device('cpu'),
+)
+self.mask_cache = attn_utils.build_causal_mask_cache(
+size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+)
+self.config = config
+
+@torch.inference_mode
+def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+x = self.tok_embedding(idx)
+cos, sin = self.rope_cache
+
+cos = cos.index_select(0, input_pos)
+sin = sin.index_select(0, input_pos)
+mask = self.mask_cache.index_select(2, input_pos)
+mask = mask[:, :, :, : self.config.max_seq_len]
+
+x = self.transformer_block(x, (cos, sin), mask, input_pos)
+x = self.final_norm(x)
+res = self.lm_head(x)
+return res
+
+
 def get_model_config() -> cfg.ModelConfig:
 attn_config = cfg.AttentionConfig(
 num_heads=32,

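The new `ToySingleLayerModelWeightSharing` ties the output projection to the token embedding by giving both modules the same weight tensor. A small, self-contained PyTorch sketch of that weight-sharing idea, independent of the Edge Generative layers (the diff assigns `.weight.data`; the sketch below uses the more common parameter assignment):

```python
import torch
from torch import nn


class TiedHeadToy(nn.Module):
  """Toy module whose LM head re-uses the token-embedding weights."""

  def __init__(self, vocab_size: int = 128, embedding_dim: int = 64):
    super().__init__()
    self.tok_embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)
    # Share one weight tensor between the embedding and the head;
    # shapes match because both are (vocab_size, embedding_dim).
    self.lm_head.weight = self.tok_embedding.weight

  def forward(self, idx: torch.Tensor) -> torch.Tensor:
    return self.lm_head(self.tok_embedding(idx))


model = TiedHeadToy()
# Both modules now point at the same storage, which is what the quantizer
# test added below exercises during conversion.
assert model.lm_head.weight.data_ptr() == model.tok_embedding.weight.data_ptr()
```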
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
CHANGED

@@ -17,14 +17,14 @@
 import os
 from pathlib import Path
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class TinyLLamma(nn.Module):
+"""A TinyLlama model built from the Edge Generative API layers."""
 
 def __init__(self, config: cfg.ModelConfig):
 super().__init__()
@@ -56,7 +57,7 @@ class TinyLLamma(nn.Module):
 config.vocab_size, config.embedding_dim, padding_idx=0
 )
 self.transformer_blocks = nn.ModuleList(
-TransformerBlock(config) for _ in range(config.num_layers)
+attention.TransformerBlock(config) for _ in range(config.num_layers)
 )
 self.final_norm = builder.build_norm(
 config.embedding_dim,
@@ -84,9 +85,9 @@ class TinyLLamma(nn.Module):
 # This can be eliminated if we handle k/v cache updates inside the model itself.
 @torch.inference_mode
 def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-assert self.config.max_seq_len >=
-f"Cannot forward sequence of length {
+_, seq_len = idx.size()
+assert self.config.max_seq_len >= seq_len, (
+f"Cannot forward sequence of length {seq_len}, max seq length is only"
 f" {self.config.max_seq_len}"
 )
 
@@ -99,7 +100,7 @@ class TinyLLamma(nn.Module):
 # forward the model itself
 x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
-for
+for _, block in enumerate(self.transformer_blocks):
 x = block(x, (cos, sin), mask, input_pos)
 
 x = self.final_norm(x)
@@ -109,6 +110,15 @@ class TinyLLamma(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+"""Returns the model config for a TinyLlama model.
+
+Args:
+kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+is 1024.
+
+Returns:
+The model config for a TinyLlama model.
+"""
 attn_config = cfg.AttentionConfig(
 num_heads=32,
 head_dim=64,
@@ -145,7 +155,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
 return config
 
 
-def build_model(checkpoint_path, **kwargs) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
 config = get_model_config(**kwargs)
 model = TinyLLamma(config)
 loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -154,6 +164,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run() -> None:
+"""Instantiates and runs a TinyLlama model."""
+
 current_dir = Path(__file__).parent.resolve()
 tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
 kv_cache_max_len = 1024

ai_edge_torch/generative/test/test_model_conversion.py
CHANGED

@@ -70,35 +70,6 @@ class TestModelConversion(googletest.TestCase):
 )
 )
 
-@googletest.skipIf(
-ai_edge_config.Config.use_torch_xla,
-reason="tests with custom ops are not supported on oss",
-)
-def test_toy_model_with_multi_batches(self):
-self.skipTest("b/362842043")
-config = toy_model_with_kv_cache.get_model_config()
-config.batch_size = 2
-pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
-idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
-[10], dtype=torch.int64
-)
-
-edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
-edge_model.set_interpreter_builder(
-self._interpreter_builder(edge_model.tflite_model())
-)
-
-self.assertTrue(
-model_coverage.compare_tflite_torch(
-edge_model,
-pytorch_model,
-(idx, input_pos),
-num_valid_inputs=1,
-atol=1e-5,
-rtol=1e-5,
-)
-)
-
 @googletest.skipIf(
 ai_edge_config.Config.use_torch_xla,
 reason="tests with custom ops are not supported on oss",

ai_edge_torch/generative/test/test_quantize.py
CHANGED

@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
 from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
-from parameterized import parameterized
 import torch
 
 from absl.testing import absltest as googletest
+from absl.testing import parameterized
 
 
-class TestVerifyRecipes(
+class TestVerifyRecipes(parameterized.TestCase):
 """Unit tests that check for model quantization recipes."""
 
-@parameterized.
+@parameterized.parameters([
 (Dtype.FP32, Dtype.FP32),
 (Dtype.INT8, Dtype.INT8),
 (Dtype.INT8, Dtype.FP16),
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
 with self.assertRaises(ValueError):
 quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
-@parameterized.
+@parameterized.parameters([
 (
 Dtype.FP32,
 Dtype.INT8,
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
 ).verify()
 
 
-class TestQuantizeConvert(
+class TestQuantizeConvert(parameterized.TestCase):
 """Test conversion with quantization."""
 
 def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
 )
 )
 
-@parameterized.
+@parameterized.parameters([
 (quant_recipes.full_fp16_recipe()),
 (quant_recipes.full_int8_dynamic_recipe()),
 (quant_recipes.full_int8_weight_only_recipe()),
 (_attention_int8_dynamic_recipe()),
 (_feedforward_int8_dynamic_recipe()),
 ])
-@googletest.skipIf(
-not config.Config.use_torch_xla,
-reason="Not working with odml_torch at the moment.",
-)
 def test_quantize_convert_toy_sizes(self, quant_config):
 config = toy_model.get_model_config()
 pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
 "Quantized model isn't smaller than F32 model.",
 )
 
+def test_quantize_convert_toy_weight_sharing(self):
+config = toy_model.get_model_config()
+pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
+idx = torch.unsqueeze(torch.arange(0, 100), 0)
+input_pos = torch.arange(0, 100)
+
+quant_config = quant_recipes.full_int8_dynamic_recipe()
+quantized_model = ai_edge_torch.convert(
+pytorch_model, (idx, input_pos), quant_config=quant_config
+)
+float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+self.assertLess(
+len(quantized_model._tflite_model),
+len(float_model._tflite_model),
+"Quantized model isn't smaller than F32 model.",
+)
+
 def test_quantize_convert_compare_toy(self):
 self.skipTest("b/338288901")
 config = toy_model_with_kv_cache.get_model_config()

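test_quantize.py also drops the third-party `parameterized` package in favor of `absl.testing.parameterized`, so the test classes now subclass `parameterized.TestCase` and decorate cases with `@parameterized.parameters`. A minimal, stand-alone example of that absl pattern (the test below is illustrative and unrelated to the real quantization recipes):

```python
from absl.testing import absltest
from absl.testing import parameterized


class SquareTest(parameterized.TestCase):

  @parameterized.parameters(
      (2, 4),
      (3, 9),
      (-4, 16),
  )
  def test_square(self, value, expected):
    # Each tuple above becomes one generated test case.
    self.assertEqual(value * value, expected)


if __name__ == "__main__":
  absltest.main()
```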
ai_edge_torch/generative/utilities/loader.py
CHANGED

@@ -208,7 +208,7 @@ class ModelLoader:
 if self._file_name.endswith(".safetensors"):
 return load_safetensors
 
-if self._file_name.endswith(".bin") or self._file_name.endswith("
+if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
 return load_pytorch_statedict
 
 raise ValueError("File format not supported.")

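The loader fix above completes the suffix check so that checkpoint names ending in ".bin" or "pt" route to the PyTorch state-dict loader, while ".safetensors" keeps its own path. A schematic version of that dispatch, with placeholder loader functions standing in for the real ones:

```python
from typing import Callable, Dict


def load_safetensors(path: str) -> Dict:
  ...  # placeholder for the real safetensors loader


def load_pytorch_statedict(path: str) -> Dict:
  ...  # placeholder for the real torch.load-based loader


def get_loader(file_name: str) -> Callable[[str], Dict]:
  if file_name.endswith(".safetensors"):
    return load_safetensors
  # ".bin" shards and any name ending in "pt" (e.g. model.pt) use the
  # PyTorch state-dict loader.
  if file_name.endswith(".bin") or file_name.endswith("pt"):
    return load_pytorch_statedict
  raise ValueError("File format not supported.")
```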
ai_edge_torch/lowertools/odml_torch_utils.py
CHANGED

@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.odml_torch import export
 from ai_edge_torch.odml_torch import export_utils
 from ai_edge_torch.quantize import quant_config as qcfg
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
 converter._experimental_enable_composite_direct_lowering = True
 converter.model_origin_framework = "PYTORCH"
 
+conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
+if (
+quant_config is not None
+and quant_config._quantizer_mode
+== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+):
+translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+quant_config.generative_recipe
+)
+
 conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
 tflite_model = converter.convert()
 
+if (
+quant_config is not None
+and quant_config._quantizer_mode
+== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+):
+tflite_model = translate_recipe.quantize_model(
+tflite_model, translated_recipe
+)
+
 return tflite_model
 
 

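With this change the odml_torch lowering path handles AI Edge Quantizer recipes the same way the torch_xla path already did: when the quant config selects the AI_EDGE_QUANTIZER mode, the generative recipe is translated before conversion and the translated recipe is applied to the converted flatbuffer afterwards. A hedged, self-contained sketch of that two-step flow; the converter and recipe helpers below are stubs, and names not shown in the diff (`convert_to_tflite`) are hypothetical:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class QuantizerMode(Enum):
  NONE = auto()
  AI_EDGE_QUANTIZER = auto()


@dataclass
class QuantConfig:
  quantizer_mode: QuantizerMode = QuantizerMode.NONE
  generative_recipe: Optional[dict] = None


def translate_to_ai_edge_recipe(recipe: dict) -> dict:
  return {"translated": recipe}  # stub for the real recipe translation


def quantize_model(tflite_model: bytes, recipe: dict) -> bytes:
  return tflite_model  # stub: the real call runs the AI Edge Quantizer


def convert_to_tflite(model, quant_config: Optional[QuantConfig]) -> bytes:
  use_aeq = (
      quant_config is not None
      and quant_config.quantizer_mode == QuantizerMode.AI_EDGE_QUANTIZER
  )
  # 1) Translate the generative recipe before conversion.
  recipe = (
      translate_to_ai_edge_recipe(quant_config.generative_recipe) if use_aeq else None
  )
  tflite_model = b"converted-flatbuffer"  # stub for converter.convert()
  # 2) Post-process the converted flatbuffer with the translated recipe.
  if use_aeq:
    tflite_model = quantize_model(tflite_model, recipe)
  return tflite_model
```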
ai_edge_torch/lowertools/torch_xla_utils.py
CHANGED

@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
-from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 from torch_xla import stablehlo

ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py
CHANGED

@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
-
+_ComputePrecision = quantizer.qtyping.ComputePrecision
+_QuantGranularity = quantizer.qtyping.QuantGranularity
 _OpName = quantizer.qtyping.TFLOperationName
 _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
 _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
 return quantizer.qtyping.TensorDataType.INT
 
 
-def
+def _get_compute_precision_from_mode(
+mode: quant_attrs.Mode,
+) -> _ComputePrecision:
 if mode == quant_attrs.Mode.DYNAMIC_RANGE:
-return
+return _ComputePrecision.INTEGER
 elif mode == quant_attrs.Mode.WEIGHT_ONLY:
-return
+return _ComputePrecision.FLOAT
 raise ValueError('Unimplemented execution mode')
 
 
-def
+def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
+if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+return False
+elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+return True
+raise ValueError('Unimplemented execution mode')
+
+
+def _get_granularity(
 granularity: quant_attrs.Granularity,
 ) -> bool:
 if granularity == quant_attrs.Granularity.CHANNELWISE:
-return
-
-return
+return _QuantGranularity.CHANNELWISE
+if granularity == quant_attrs.Granularity.NONE:
+return _QuantGranularity.TENSORWISE
 raise ValueError('Unimplemented granularity')
 
 
@@ -88,12 +99,13 @@ def _set_quant_config(
 weight_tensor_config=_TensorQuantConfig(
 num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
 symmetric=True,
-
-layer_recipe.granularity
-),
+granularity=_get_granularity(layer_recipe.granularity),
 dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
 ),
-
+compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
+explicit_dequantize=_get_explicit_dequant_from_mode(
+layer_recipe.mode
+),
 ),
 algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
 )

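The relocated translate_recipe module now maps the generative quantization attributes onto the quantizer's `ComputePrecision` and `QuantGranularity` types: dynamic-range mode becomes integer compute without an explicit dequantize, weight-only becomes float compute with an explicit dequantize, and channelwise/none granularity map to channelwise/tensorwise. A self-contained sketch of those mappings using local stand-in enums (the real code aliases types from `ai_edge_quantizer.quantizer.qtyping`):

```python
from enum import Enum, auto


class Mode(Enum):
  DYNAMIC_RANGE = auto()
  WEIGHT_ONLY = auto()


class Granularity(Enum):
  NONE = auto()
  CHANNELWISE = auto()


class ComputePrecision(Enum):
  INTEGER = auto()
  FLOAT = auto()


class QuantGranularity(Enum):
  TENSORWISE = auto()
  CHANNELWISE = auto()


def compute_precision_from_mode(mode: Mode) -> ComputePrecision:
  if mode == Mode.DYNAMIC_RANGE:
    return ComputePrecision.INTEGER
  if mode == Mode.WEIGHT_ONLY:
    return ComputePrecision.FLOAT
  raise ValueError("Unimplemented execution mode")


def explicit_dequant_from_mode(mode: Mode) -> bool:
  # Weight-only keeps float compute, so the dequantize op stays explicit.
  return mode == Mode.WEIGHT_ONLY


def granularity_from_attr(g: Granularity) -> QuantGranularity:
  if g == Granularity.CHANNELWISE:
    return QuantGranularity.CHANNELWISE
  if g == Granularity.NONE:
    return QuantGranularity.TENSORWISE
  raise ValueError("Unimplemented granularity")
```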
ai_edge_torch/version.py
CHANGED
{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.
+Version: 0.3.0.dev20240905
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
 Requires-Dist: tf-nightly>=2.18.0.dev20240722
-Requires-Dist: ai-edge-quantizer-nightly
+Requires-Dist: ai-edge-quantizer-nightly
 
 Library that supports converting PyTorch models into a .tflite format, which can
 then be run with TensorFlow Lite and MediaPipe. This enables applications for

{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD
CHANGED

@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256
+ai_edge_torch/version.py,sha256=-vQGdl2EaV-VpHRty3RwZzH0UVntVt1tmjhtKOIDscw,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,21 +42,21 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
-ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=
+ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
 ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
-ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=
+ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
 ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
-ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
 ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=lc1-CfIObHj9D5VJy78BOtGTrQM4TYMI6NfVi8KM5qA,6747
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=OcUQLFR136e3QRVXRnmtYnRHXyHJS9EYEFlJ1ymXyRY,8859
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
-ai_edge_torch/generative/examples/phi2/phi2.py,sha256=
+ai_edge_torch/generative/examples/phi2/phi2.py,sha256=FFnhv1kx4fHRhSeOreLGj8kAqPnmkz9pD1RRSDVlM_w,6332
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -77,12 +77,12 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
 ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
 ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=4ku0ni3MOWamhPrzLap0BmtdNFk7CH0hwjPNoRAKpvQ,6278
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
 ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=KZ0uCeOdKMKyW8jBE8aOjweZmws4mvz37u8zH4XayVU,5285
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
+ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
 ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/METADATA,sha256=8yrrm7TEYgaRhKdUwgStjCqrTWs8YcnnlzoTJt2NrJg,1859
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/RECORD,,

ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py
DELETED

@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
File without changes
File without changes