ai-edge-torch-nightly 0.3.0.dev20240902__py3-none-any.whl → 0.3.0.dev20240905__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +32 -42
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +33 -15
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +29 -14
- ai_edge_torch/generative/examples/gemma/gemma.py +26 -41
- ai_edge_torch/generative/examples/gemma/gemma2.py +29 -48
- ai_edge_torch/generative/examples/phi2/phi2.py +25 -37
- ai_edge_torch/generative/examples/test_models/toy_model.py +50 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +21 -9
- ai_edge_torch/generative/test/test_model_conversion.py +0 -29
- ai_edge_torch/generative/test/test_quantize.py +23 -10
- ai_edge_torch/generative/utilities/loader.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +20 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +1 -1
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +24 -12
- ai_edge_torch/odml_torch/export.py +1 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD +21 -22
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -14
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/experimental/gemma/gemma.py CHANGED
@@ -21,15 +21,16 @@ import os
 from pathlib import Path
 from typing import Tuple
 
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-
+from ai_edge_torch.generative.layers.experimental import attention
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
+
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -48,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Gemma(nn.Module):
+  """A Gemma model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
@@ -65,7 +67,7 @@ class Gemma(nn.Module):
     # Gemma re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -95,9 +97,9 @@ class Gemma(nn.Module):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.EKVCache,
   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-
-    assert self.config.max_seq_len >=
-    f"Cannot forward sequence of length {
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
 
@@ -125,6 +127,15 @@ class Gemma(nn.Module):
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Gemma 2B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Gemma 2B model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -160,41 +171,18 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=128,
-  )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-  )
-  config = cfg.ModelConfig(
-      vocab_size=128,
-      num_layers=2,
-      max_seq_len=2 * kv_cache_max_len,
-      embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=False,
-      lm_head_use_bias=False,
-      enable_hlfb=True,
-  )
+  config = get_model_config_2b(kv_cache_max_len)
+  config.ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
   return config
 
 
-def build_2b_model(
+def build_2b_model(
+    checkpoint_path: str, test_model: bool = False, **kwargs
+) -> nn.Module:
+  """Instantiates the model instance and load checkpoint if provided."""
   config = (
       get_fake_model_config(**kwargs)
       if test_model
@@ -210,7 +198,9 @@ def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
   return model
 
 
-def define_and_run_2b(checkpoint_path, test_model=False) -> None:
+def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
+  """Instantiates and runs a Gemma 2B model."""
+
   kv_cache_max_len = 1024
   model = build_2b_model(
       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -225,5 +215,5 @@ def define_and_run_2b(checkpoint_path, test_model=False) -> None:
 
 
 if __name__ == "__main__":
-
-  define_and_run_2b(
+  input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
+  define_and_run_2b(input_checkpoint_path)
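
The hunks above replace the hand-rolled fake config with one derived from the real Gemma 2B config and move `TransformerBlock` behind the `attention` module. A minimal sketch of how the refactored helper is exercised (module path as in this diff; no checkpoint is needed for the fake config):

```python
# Sketch only: build the shrunken test-sized model from the refactored helper.
# Assumes the experimental module layout shown in the diff above.
from ai_edge_torch.generative.examples.experimental.gemma import gemma

config = gemma.get_fake_model_config(kv_cache_max_len=128)
model = gemma.Gemma(config)  # 2 layers, vocab_size=128, derived from get_model_config_2b()
```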

ai_edge_torch/generative/examples/experimental/phi/phi2.py CHANGED
@@ -17,20 +17,20 @@
 # Note: This is an experimental version of phi2 with external KV cache.
 # Please use with caution.
 
-
 import os
 from pathlib import Path
 from typing import Tuple
 
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-
+from ai_edge_torch.generative.layers.experimental import attention
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
+
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -47,6 +47,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Phi2(nn.Module):
+  """A Phi-2 model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
@@ -60,7 +61,7 @@ class Phi2(nn.Module):
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -90,9 +91,9 @@ class Phi2(nn.Module):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.EKVCache,
   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-
-    assert self.config.max_seq_len >=
-    f"Cannot forward sequence of length {
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
 
@@ -118,6 +119,15 @@ class Phi2(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Phi-2 model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Phi-2 model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=80,
@@ -150,15 +160,21 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def
-  config = get_model_config(
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
   config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.ff_config.intermediate_size = 128
   return config
 
 
-def build_model(
+def build_model(
+    checkpoint_path: str, test_model: bool = False, **kwargs
+) -> nn.Module:
+  """Instantiates the model instance and load checkpoint if provided."""
   config = (
-
+      get_fake_model_config(**kwargs)
       if test_model
       else get_model_config(**kwargs)
   )
@@ -170,7 +186,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
   return model
 
 
-def define_and_run(checkpoint_path, test_model=False) -> None:
+def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+  """Instantiates and runs a Phi-2 model."""
+
   kv_cache_max_len = 1024
   model = build_model(
       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -185,5 +203,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
 
 
 if __name__ == "__main__":
-
-  define_and_run(
+  input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+  define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py CHANGED
@@ -17,20 +17,20 @@
 # Note: This is an experimental version of TinyLlama with external KV cache.
 # Please use with caution.
 
-
 import os
 from pathlib import Path
 from typing import Tuple
 
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-
+from ai_edge_torch.generative.layers.experimental import attention
 from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
+
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -49,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class TinyLLamma(nn.Module):
+  """A TinyLlama model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
@@ -62,7 +63,7 @@ class TinyLLamma(nn.Module):
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
    )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -92,9 +93,9 @@ class TinyLLamma(nn.Module):
       input_pos: torch.Tensor,
       kv_cache: kv_utils.EKVCache,
   ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-
-    assert self.config.max_seq_len >=
-    f"Cannot forward sequence of length {
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
 
@@ -121,6 +122,15 @@ class TinyLLamma(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a TinyLlama model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a TinyLlama model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=64,
@@ -149,7 +159,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
   config = get_model_config(**kwargs)
   config.vocab_size = 128
   config.num_layers = 2
@@ -157,9 +167,12 @@ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
   return config
 
 
-def build_model(
+def build_model(
+    checkpoint_path: str, test_model: bool = False, **kwargs
+) -> nn.Module:
+  """Instantiates the model instance and load checkpoint if provided."""
   config = (
-
+      get_fake_model_config(**kwargs)
       if test_model
       else get_model_config(**kwargs)
   )
@@ -171,7 +184,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
   return model
 
 
-def define_and_run(checkpoint_path, test_model=False) -> None:
+def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+  """Instantiates and runs a TinyLlama model."""
+
   kv_cache_max_len = 1024
   model = build_model(
       checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -186,5 +201,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
 
 
 if __name__ == "__main__":
-
-  define_and_run(
+  input_checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
+  define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -17,14 +17,14 @@
 import os
 from pathlib import Path
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Gemma(nn.Module):
+  """A Gemma model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
@@ -60,7 +61,7 @@ class Gemma(nn.Module):
     # Gemma re-uses the embedding as the head projection layer.
     self.lm_head.weight.data = self.tok_embedding.weight.data
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -88,9 +89,9 @@ class Gemma(nn.Module):
   # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-    assert self.config.max_seq_len >=
-    f"Cannot forward sequence of length {
+    _, seq_len = idx.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
 
@@ -104,7 +105,7 @@ class Gemma(nn.Module):
     x = self.tok_embedding(idx)
     x = x * (self.config.embedding_dim**0.5)
 
-    for
+    for _, block in enumerate(self.transformer_blocks):
       x = block(x, (cos, sin), mask, input_pos)
 
     x = self.final_norm(x)
@@ -113,6 +114,15 @@ class Gemma(nn.Module):
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Gemma 2B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Gemma 2B model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -147,43 +157,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-# TODO(b/363021962): Clean up this part to streamline fake model config generation.
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=128,
-  )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-  )
-  config = cfg.ModelConfig(
-      vocab_size=128,
-      num_layers=2,
-      max_seq_len=2 * kv_cache_max_len,
-      embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=False,
-      lm_head_use_bias=False,
-      enable_hlfb=True,
-  )
+  config = get_model_config_2b(kv_cache_max_len)
+  config.ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
   return config
 
 
-def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -195,6 +178,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run_2b() -> None:
+  """Instantiates and runs a Gemma 2B model."""
+
   current_dir = Path(__file__).parent.resolve()
   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 

ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -18,14 +18,14 @@ import os
 from pathlib import Path
 from typing import Optional, Tuple
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,7 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class Gemma2Block(TransformerBlock):
+class Gemma2Block(attention.TransformerBlock):
 
   def forward(
       self,
@@ -76,6 +76,7 @@ class Gemma2Block(TransformerBlock):
 
 
 class Gemma2(nn.Module):
+  """A Gemma2 model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
@@ -138,9 +139,9 @@ class Gemma2(nn.Module):
 
   @torch.inference_mode
   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-    assert self.config.max_seq_len >=
-    f"Cannot forward sequence of length {
+    _, seq_len = idx.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
 
@@ -166,6 +167,15 @@ class Gemma2(nn.Module):
 
 
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Gemma2 2B model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Gemma 2B model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=8,
       head_dim=256,
@@ -210,50 +220,19 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-
-
-
-
-      * 13,
-  )
-
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=128,
-      pre_ff_norm_config=norm_config,
-      post_ff_norm_config=norm_config,
-  )
-  config = cfg.ModelConfig(
-      vocab_size=128,
-      num_layers=2,
-      max_seq_len=2 * kv_cache_max_len,
-      embedding_dim=128,
-      kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=False,
-      lm_head_use_bias=False,
-      enable_hlfb=True,
-      final_logit_softcap=30.0,
-  )
+  config = get_model_config_2b(kv_cache_max_len)
+  config.attn_config.num_heads = 4
+  config.attn_config.head_dim = 64
+  config.attn_config.sliding_window_size = 64
+  config.ff_config.intermediate_size = 128
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.embedding_dim = 128
   return config
 
 
-def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config_2b(**kwargs)
   model = Gemma2(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -265,6 +244,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run_2b() -> None:
+  """Instantiates and runs a Gemma2 2B model."""
+
   current_dir = Path(__file__).parent.resolve()
   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
   print("Running GEMMA 2")

ai_edge_torch/generative/examples/phi2/phi2.py CHANGED
@@ -18,14 +18,14 @@
 import os
 from pathlib import Path
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.fc1",
@@ -42,6 +42,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class Phi2(nn.Module):
+  """A Phi-2 model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
@@ -55,7 +56,7 @@ class Phi2(nn.Module):
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -83,9 +84,9 @@ class Phi2(nn.Module):
   # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-    assert self.config.max_seq_len >=
-    f"Cannot forward sequence of length {
+    _, seq_len = idx.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
 
@@ -98,7 +99,7 @@ class Phi2(nn.Module):
     # forward the model itself
     x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
-    for
+    for _, block in enumerate(self.transformer_blocks):
       x = block(x, (cos, sin), mask, input_pos)
 
     x = self.final_norm(x)
@@ -107,6 +108,15 @@ class Phi2(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Phi-2 model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Phi-2 model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=80,
@@ -140,35 +150,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
 def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-
-
-
-
-
-      qkv_use_bias=True,
-      output_proj_use_bias=True,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.SEQUENTIAL,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=128,
-      use_bias=True,
-  )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
-  config = cfg.ModelConfig(
-      vocab_size=128,
-      num_layers=2,
-      max_seq_len=2 * kv_cache_max_len,
-      kv_cache_max_len=kv_cache_max_len,
-      embedding_dim=128,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=True,
-      lm_head_use_bias=True,
-      enable_hlfb=True,
-  )
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  config.ff_config.intermediate_size = 128
   return config
 
 
@@ -181,6 +167,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run() -> None:
+  """Instantiates and runs a Phi-2 model."""
+
   current_dir = Path(__file__).parent.resolve()
   phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
   kv_cache_max_len = 1024

ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
     return self.lm_head(x)
 
 
+class ToySingleLayerModelWeightSharing(torch.nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig) -> None:
+    super().__init__()
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_block = TransformerBlock(config)
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.max_seq_len,
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device('cpu'),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    x = self.tok_embedding(idx)
+    cos, sin = self.rope_cache
+
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.max_seq_len]
+
+    x = self.transformer_block(x, (cos, sin), mask, input_pos)
+    x = self.final_norm(x)
+    res = self.lm_head(x)
+    return res
+
+
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
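
The new `ToySingleLayerModelWeightSharing` ties the LM head to the token embedding (`lm_head.weight.data = tok_embedding.weight.data`), which is the case exercised by the new weight-sharing quantization test further below. A rough usage sketch, mirroring the calls that appear in `test_quantize.py`:

```python
# Sketch: run the weight-sharing toy model eagerly with the same inputs the test uses.
import torch

from ai_edge_torch.generative.examples.test_models import toy_model

config = toy_model.get_model_config()
model = toy_model.ToySingleLayerModelWeightSharing(config)

idx = torch.unsqueeze(torch.arange(0, 100), 0)  # (1, 100) token ids
input_pos = torch.arange(0, 100)                # positions for the RoPE / mask caches
logits = model(idx, input_pos)
```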

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -17,14 +17,14 @@
 import os
 from pathlib import Path
 
-from ai_edge_torch.generative.layers
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import numpy as np
 import torch
-
+from torch import nn
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class TinyLLamma(nn.Module):
+  """A TinyLlama model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
     super().__init__()
@@ -56,7 +57,7 @@ class TinyLLamma(nn.Module):
         config.vocab_size, config.embedding_dim, padding_idx=0
     )
     self.transformer_blocks = nn.ModuleList(
-        TransformerBlock(config) for _ in range(config.num_layers)
+        attention.TransformerBlock(config) for _ in range(config.num_layers)
     )
     self.final_norm = builder.build_norm(
         config.embedding_dim,
@@ -84,9 +85,9 @@ class TinyLLamma(nn.Module):
   # This can be eliminated if we handle k/v cache updates inside the model itself.
   @torch.inference_mode
   def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
-
-    assert self.config.max_seq_len >=
-    f"Cannot forward sequence of length {
+    _, seq_len = idx.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
 
@@ -99,7 +100,7 @@ class TinyLLamma(nn.Module):
     # forward the model itself
     x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
-    for
+    for _, block in enumerate(self.transformer_blocks):
       x = block(x, (cos, sin), mask, input_pos)
 
     x = self.final_norm(x)
@@ -109,6 +110,15 @@ class TinyLLamma(nn.Module):
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a TinyLlama model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a TinyLlama model.
+  """
   attn_config = cfg.AttentionConfig(
       num_heads=32,
       head_dim=64,
@@ -145,7 +155,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
   return config
 
 
-def build_model(checkpoint_path, **kwargs) -> nn.Module:
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
   model = TinyLLamma(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -154,6 +164,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
 def define_and_run() -> None:
+  """Instantiates and runs a TinyLlama model."""
+
   current_dir = Path(__file__).parent.resolve()
   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
   kv_cache_max_len = 1024

ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -70,35 +70,6 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  @googletest.skipIf(
-      ai_edge_config.Config.use_torch_xla,
-      reason="tests with custom ops are not supported on oss",
-  )
-  def test_toy_model_with_multi_batches(self):
-    self.skipTest("b/362842043")
-    config = toy_model_with_kv_cache.get_model_config()
-    config.batch_size = 2
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
-    idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
-    edge_model.set_interpreter_builder(
-        self._interpreter_builder(edge_model.tflite_model())
-    )
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
-
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",

ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
 from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
-from parameterized import parameterized
 import torch
 
 from absl.testing import absltest as googletest
+from absl.testing import parameterized
 
 
-class TestVerifyRecipes(
+class TestVerifyRecipes(parameterized.TestCase):
   """Unit tests that check for model quantization recipes."""
 
-  @parameterized.
+  @parameterized.parameters([
       (Dtype.FP32, Dtype.FP32),
       (Dtype.INT8, Dtype.INT8),
       (Dtype.INT8, Dtype.FP16),
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
     with self.assertRaises(ValueError):
       quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
-  @parameterized.
+  @parameterized.parameters([
       (
           Dtype.FP32,
           Dtype.INT8,
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
       ).verify()
 
 
-class TestQuantizeConvert(
+class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""
 
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
         )
     )
 
-  @parameterized.
+  @parameterized.parameters([
      (quant_recipes.full_fp16_recipe()),
      (quant_recipes.full_int8_dynamic_recipe()),
      (quant_recipes.full_int8_weight_only_recipe()),
      (_attention_int8_dynamic_recipe()),
      (_feedforward_int8_dynamic_recipe()),
  ])
-  @googletest.skipIf(
-      not config.Config.use_torch_xla,
-      reason="Not working with odml_torch at the moment.",
-  )
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
         "Quantized model isn't smaller than F32 model.",
     )
 
+  def test_quantize_convert_toy_weight_sharing(self):
+    config = toy_model.get_model_config()
+    pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
+    idx = torch.unsqueeze(torch.arange(0, 100), 0)
+    input_pos = torch.arange(0, 100)
+
+    quant_config = quant_recipes.full_int8_dynamic_recipe()
+    quantized_model = ai_edge_torch.convert(
+        pytorch_model, (idx, input_pos), quant_config=quant_config
+    )
+    float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    self.assertLess(
+        len(quantized_model._tflite_model),
+        len(float_model._tflite_model),
+        "Quantized model isn't smaller than F32 model.",
+    )
+
   def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
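
The test module now uses absl's `parameterized` instead of the third-party `parameterized` package; the decorator spelling stays nearly the same. A self-contained sketch of the pattern (test names are illustrative, not from the diff):

```python
# Sketch of the absl-style parameterized test pattern used above.
from absl.testing import absltest as googletest
from absl.testing import parameterized


class ExampleTest(parameterized.TestCase):

  @parameterized.parameters([
      (1, 2),
      (3, 4),
  ])
  def test_first_is_smaller(self, a, b):
    self.assertLess(a, b)


if __name__ == "__main__":
  googletest.main()
```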

ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -208,7 +208,7 @@ class ModelLoader:
     if self._file_name.endswith(".safetensors"):
       return load_safetensors
 
-    if self._file_name.endswith(".bin") or self._file_name.endswith("
+    if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
       return load_pytorch_statedict
 
     raise ValueError("File format not supported.")

ai_edge_torch/lowertools/odml_torch_utils.py CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.odml_torch import export
 from ai_edge_torch.odml_torch import export_utils
 from ai_edge_torch.quantize import quant_config as qcfg
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
   converter._experimental_enable_composite_direct_lowering = True
   converter.model_origin_framework = "PYTORCH"
 
+  conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+        quant_config.generative_recipe
+    )
+
   conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
   tflite_model = converter.convert()
 
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    tflite_model = translate_recipe.quantize_model(
+        tflite_model, translated_recipe
+    )
+
   return tflite_model
 
 
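
With this change the odml_torch backend, like the torch_xla backend, routes `AI_EDGE_QUANTIZER` recipes through `lowertools.translate_recipe`: the generative recipe is translated before `converter.convert()` and applied to the resulting flatbuffer afterwards. Caller-side usage is unchanged; a minimal sketch where `pytorch_model`, `idx`, and `input_pos` stand in for whatever the caller already has:

```python
# Sketch: convert with a generative quantization recipe (same calls as in test_quantize.py).
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    pytorch_model, (idx, input_pos), quant_config=quant_config
)
# Per this diff, the backend calls translate_recipe.translate_to_ai_edge_recipe()
# before conversion and translate_recipe.quantize_model() on the produced TFLite model.
```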
|

ai_edge_torch/lowertools/torch_xla_utils.py CHANGED
@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
-from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 from torch_xla import stablehlo

ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py RENAMED
@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
-
+_ComputePrecision = quantizer.qtyping.ComputePrecision
+_QuantGranularity = quantizer.qtyping.QuantGranularity
 _OpName = quantizer.qtyping.TFLOperationName
 _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
 _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
     return quantizer.qtyping.TensorDataType.INT
 
 
-def
+def _get_compute_precision_from_mode(
+    mode: quant_attrs.Mode,
+) -> _ComputePrecision:
   if mode == quant_attrs.Mode.DYNAMIC_RANGE:
-    return
+    return _ComputePrecision.INTEGER
   elif mode == quant_attrs.Mode.WEIGHT_ONLY:
-    return
+    return _ComputePrecision.FLOAT
   raise ValueError('Unimplemented execution mode')
 
 
-def
+def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return False
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return True
+  raise ValueError('Unimplemented execution mode')
+
+
+def _get_granularity(
     granularity: quant_attrs.Granularity,
 ) -> bool:
   if granularity == quant_attrs.Granularity.CHANNELWISE:
-    return
-
-    return
+    return _QuantGranularity.CHANNELWISE
+  if granularity == quant_attrs.Granularity.NONE:
+    return _QuantGranularity.TENSORWISE
   raise ValueError('Unimplemented granularity')
 
 
@@ -88,12 +99,13 @@ def _set_quant_config(
           weight_tensor_config=_TensorQuantConfig(
              num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
              symmetric=True,
-
-              layer_recipe.granularity
-          ),
+              granularity=_get_granularity(layer_recipe.granularity),
              dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
          ),
-
+          compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
+          explicit_dequantize=_get_explicit_dequant_from_mode(
+              layer_recipe.mode
+          ),
      ),
      algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
  )
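
The glue module now targets ai-edge-quantizer's `ComputePrecision` and `QuantGranularity` enums. The translation implied by the new helpers, read off the hunks above:

```python
# Summary of the new mode/granularity translation (from the diff above):
#   quant_attrs.Mode.DYNAMIC_RANGE      -> ComputePrecision.INTEGER, explicit_dequantize=False
#   quant_attrs.Mode.WEIGHT_ONLY        -> ComputePrecision.FLOAT,   explicit_dequantize=True
#   quant_attrs.Granularity.CHANNELWISE -> QuantGranularity.CHANNELWISE
#   quant_attrs.Granularity.NONE        -> QuantGranularity.TENSORWISE
```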

ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240902
+Version: 0.3.0.dev20240905
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
 Requires-Dist: tf-nightly>=2.18.0.dev20240722
-Requires-Dist: ai-edge-quantizer-nightly
+Requires-Dist: ai-edge-quantizer-nightly
 
 Library that supports converting PyTorch models into a .tflite format, which can
 then be run with TensorFlow Lite and MediaPipe. This enables applications for

{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256
+ai_edge_torch/version.py,sha256=-vQGdl2EaV-VpHRty3RwZzH0UVntVt1tmjhtKOIDscw,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,21 +42,21 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
-ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=
+ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
 ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
-ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=
+ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
 ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
-ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
 ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=lc1-CfIObHj9D5VJy78BOtGTrQM4TYMI6NfVi8KM5qA,6747
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=OcUQLFR136e3QRVXRnmtYnRHXyHJS9EYEFlJ1ymXyRY,8859
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
-ai_edge_torch/generative/examples/phi2/phi2.py,sha256=
+ai_edge_torch/generative/examples/phi2/phi2.py,sha256=FFnhv1kx4fHRhSeOreLGj8kAqPnmkz9pD1RRSDVlM_w,6332
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -77,12 +77,12 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
 ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
 ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=4ku0ni3MOWamhPrzLap0BmtdNFk7CH0hwjPNoRAKpvQ,6278
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
 ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=KZ0uCeOdKMKyW8jBE8aOjweZmws4mvz37u8zH4XayVU,5285
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
-ai_edge_torch/generative/test/test_quantize.py,sha256=
+ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=
+ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
+ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
 ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/METADATA,sha256=8yrrm7TEYgaRhKdUwgStjCqrTWs8YcnnlzoTJt2NrJg,1859
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/RECORD,,

ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py DELETED
@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================