ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250426__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -3
- ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
- ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
- ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
- ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
- ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
- ai_edge_torch/generative/examples/hammer/verify.py +86 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/llama/llama.py +3 -1
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/phi/phi2.py +1 -1
- ai_edge_torch/generative/examples/phi/phi3.py +3 -1
- ai_edge_torch/generative/examples/phi/phi4.py +3 -1
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -3
- ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
- ai_edge_torch/generative/layers/kv_cache.py +2 -4
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
- ai_edge_torch/generative/test/test_model_conversion.py +3 -33
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -75
- ai_edge_torch/generative/utilities/converter.py +11 -1
- ai_edge_torch/generative/utilities/export_config.py +30 -0
- ai_edge_torch/model.py +2 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/RECORD +41 -39
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py
@@ -35,13 +35,11 @@ def _run_convert_passes(
   )

   passes = [
-      fx_passes.CastInputsBf16ToF32Pass(),
-      fx_passes.BuildInterpolateCompositePass(),
-      fx_passes.CanonicalizePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
       fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CastInputsBf16ToF32Pass(),
   ]

   # Debuginfo is not injected automatically by odml_torch. Only inject
ai_edge_torch/_convert/fx_passes/__init__.py
@@ -16,7 +16,6 @@
 from typing import Sequence, Union

 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
@@ -20,7 +20,8 @@ import torch
 import torch.utils._pytree as pytree

 _composite_builders: dict[
-    Callable,
+    Callable[[Any, ...], Any],
+    Callable[[torch.fx.GraphModule, torch.fx.Node], None],
 ] = {}


@@ -272,13 +273,73 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
     output = op(**full_kwargs)
     output = builder.mark_outputs(output)

-    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # outside of the HLFB.
     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
     return output

   node.target = embedding


+@_register_composite_builder(torch.ops.aten.upsample_bilinear2d.vec)
+def _aten_upsample_bilinear2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_bilinear2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_bilinear2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="odml.upsample_bilinear2d",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "align_corners": full_kwargs["align_corners"],
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_bilinear2d_vec
+
+
+@_register_composite_builder(torch.ops.aten.upsample_nearest2d.vec)
+def _aten_upsample_nearest2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_nearest2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_nearest2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="tfl.resize_nearest_neighbor",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_nearest2d_vec
+
+
 class BuildAtenCompositePass(fx_infra.PassBase):

   def call(self, graph_module: torch.fx.GraphModule):
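For context, these two new builders fire on the FX nodes that torch.nn.functional.interpolate lowers to, taking over from the removed BuildInterpolateCompositePass. A minimal sketch (not part of this diff) of a module that should exercise the bilinear path:

import torch
import torch.nn.functional as F


class Upsample2x(torch.nn.Module):

  def forward(self, x):
    # With a 4D NCHW input, this should lower to aten.upsample_bilinear2d.vec
    # in the exported graph, which BuildAtenCompositePass then wraps in an
    # odml.upsample_bilinear2d composite.
    return F.interpolate(x, scale_factor=2, mode="bilinear")


ep = torch.export.export(Upsample2x(), (torch.randn(1, 3, 8, 8),))
print(ep.graph)  # look for upsample_bilinear2d in the printed graph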
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
@@ -17,6 +17,7 @@
 import operator

 import ai_edge_torch
+from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -24,7 +25,7 @@ import torch
 import torch.utils._pytree as pytree

 aten = torch.ops.aten
-StableHLOCompositeBuilder =
+StableHLOCompositeBuilder = lowertools.StableHLOCompositeBuilder

 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]

ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
@@ -17,11 +17,11 @@

 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

-flags = converter.define_conversion_flags(
-ExportConfig = export_config.ExportConfig
+flags = converter.define_conversion_flags('deepseek')

 def main(_):
   pytorch_model = deepseek.build_model(
@@ -34,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )

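The export_config.get_from_flags() call that this and the other conversion scripts now share comes from export_config.py (+30 lines; its body is not shown in this section). Judging from the per-example helpers it replaces (see the gemma3 hunk below) and the transpose_kv_cache flag used in the new hammer script, it presumably maps the shared conversion flags onto an ExportConfig. A hypothetical sketch only, not the shipped implementation:

# Assumption: flag plumbing, parameters, and defaults below are illustrative.
from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities import export_config


def get_from_flags_sketch(transpose_kv_cache: bool = False):
  config = export_config.ExportConfig()
  if transpose_kv_cache:
    # Mirrors the removed per-example helpers, which switched the KV cache
    # to the transposed layout.
    config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
  return config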
ai_edge_torch/generative/examples/deepseek/deepseek.py
@@ -53,6 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
@@ -17,14 +17,10 @@

 from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
-import torch

 flags = converter.define_conversion_flags('gemma3-1b')
-ExportConfig = export_config.ExportConfig
-

 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
@@ -33,55 +29,23 @@ _MODEL_SIZE = flags.DEFINE_string(
 )


-def _create_mask(mask_len, kv_cache_max_len):
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
-
-
-def _create_export_config(
-    prefill_seq_lens: list[int], kv_cache_max_len: int
-) -> ExportConfig:
-  """Creates the export config for the model."""
-  export_config = ExportConfig()
-  if isinstance(prefill_seq_lens, list):
-    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
-  else:
-    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
-
-  export_config.prefill_mask = prefill_mask
-
-  decode_mask = torch.full(
-      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  export_config.decode_mask = decode_mask
-  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
-  return export_config
-
-
 def main(_):
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
         flags.FLAGS.checkpoint_path,
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
     )
-    config = pytorch_model.config
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
-      config=config,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=_create_export_config(
-          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-      ),
+      export_config=export_config.get_from_flags(),
   )

ai_edge_torch/generative/examples/hammer/__init__.py (new file)
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py (new file)
@@ -0,0 +1,92 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting hammer 2.1 models to multi-signature tflite model."""
+
+from absl import app
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities import export_config as export_cfg
+import torch
+
+
+flags = converter.define_conversion_flags('hammer')
+ExportConfig = export_cfg.ExportConfig
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '1.5b',
+    ['0.5b', '1.5b'],
+    'The size of the model to convert.',
+)
+
+_BUILDER = {
+    '0.5b': hammer.build_0_5b_model,
+    '1.5b': hammer.build_1_5b_model,
+}
+
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
+
+
+def main(_):
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+  )
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      )
+      if flags.FLAGS.transpose_kv_cache
+      else ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
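A worked example (not in the diff) of what _create_mask above produces: a (1, 1, mask_len, kv_cache_max_len) causal mask that is 0 at and below the diagonal and -inf strictly above it, presumably applied additively to the attention scores:

import torch

mask = torch.full((3, 5), float("-inf"), dtype=torch.float32)
mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
print(mask.shape)  # torch.Size([1, 1, 3, 5])
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf]])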
ai_edge_torch/generative/examples/hammer/hammer.py (new file)
@@ -0,0 +1,107 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Hammer 2.1 models."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
+
+TENSOR_NAMES = model_builder.TENSOR_NAMES
+
+
+class Hammer(model_builder.DecoderOnlyModel):
+  """A Hammer model built from the Edge Generative API layers."""
+  pass
+
+
+def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Hammer 2.1 1.5B model."""
+  attn_config = cfg.AttentionConfig(
+      num_heads=12,
+      head_dim=128,
+      num_query_groups=2,
+      rotary_base=1000000,
+      rotary_percentage=1.0,
+      qkv_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=8960,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-06,
+      enable_hlfb=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151665,
+      num_layers=28,
+      max_seq_len=32768,
+      embedding_dim=1536,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Hammer 2.1 0.5B model."""
+  config = get_1_5b_model_config(kv_cache_max_len)
+  # Hammer has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 14
+  block_config.attn_config.head_dim = 64
+  block_config.ff_config.intermediate_size = 4864
+  config.num_layers = 24
+  config.embedding_dim = 896
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_1_5b_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.embedding_dim = 16
+  # Hammer has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Hammer,
+  )
+
+
+def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Hammer,
+  )
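Usage sketch for the new builders, mirroring the call in convert_to_tflite.py above (the checkpoint path is illustrative; per verify.py below, weights come from the MadeAgents/Hammer2.1-* checkpoints):

from ai_edge_torch.generative.examples.hammer import hammer

# Build the reauthored 1.5B model from a local checkpoint directory.
model = hammer.build_1_5b_model(
    '/path/to/Hammer2.1-1.5b', kv_cache_max_len=1024
)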
ai_edge_torch/generative/examples/hammer/verify.py (new file)
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "0.5b",
+    ["0.5b", "1.5b"],
+    "The size of the model to verify.",
+)
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+_CHECKPOINT = {
+    "0.5b": "MadeAgents/Hammer2.1-0.5b",
+    "1.5b": "MadeAgents/Hammer2.1-1.5b",
+}
+
+_BUILDER = {
+    "0.5b": hammer.build_0_5b_model,
+    "1.5b": hammer.build_1_5b_model,
+}
+
+
+def main(_):
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -22,8 +22,6 @@ from ai_edge_torch.generative.utilities import export_config


 flags = converter.define_conversion_flags('llama')
-ExportConfig = export_config.ExportConfig
-

 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -49,7 +47,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )

ai_edge_torch/generative/examples/llama/llama.py
@@ -121,7 +121,9 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags("phi3")
-ExportConfig = export_config.ExportConfig


 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )

ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags("phi4")
-ExportConfig = export_config.ExportConfig


 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -22,7 +22,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags("phi2")
-ExportConfig = export_config.ExportConfig


 def main(_):
@@ -36,7 +35,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )

ai_edge_torch/generative/examples/phi/phi2.py
@@ -65,7 +65,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/phi/phi3.py
@@ -162,7 +162,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/phi/phi4.py
@@ -112,7 +112,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -21,8 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config

 flags = converter.define_conversion_flags('qwen')
-ExportConfig = export_config.ExportConfig
-

 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -37,6 +35,7 @@ _BUILDER = {
     '3b': qwen.build_3b_model,
 }

+
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -48,7 +47,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=
+      export_config=export_config.get_from_flags(),
   )
