ai-edge-torch-nightly 0.5.0.dev20250424__py3-none-any.whl → 0.5.0.dev20250426__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (43)
  1. ai_edge_torch/_convert/conversion.py +1 -3
  2. ai_edge_torch/_convert/fx_passes/__init__.py +0 -1
  3. ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +63 -2
  4. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
  5. ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +3 -3
  6. ai_edge_torch/generative/examples/deepseek/deepseek.py +1 -0
  7. ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -38
  8. ai_edge_torch/generative/examples/hammer/__init__.py +14 -0
  9. ai_edge_torch/generative/examples/hammer/convert_to_tflite.py +92 -0
  10. ai_edge_torch/generative/examples/hammer/hammer.py +107 -0
  11. ai_edge_torch/generative/examples/hammer/verify.py +86 -0
  12. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +1 -3
  13. ai_edge_torch/generative/examples/llama/llama.py +3 -1
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +1 -2
  15. ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py +1 -2
  16. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +1 -2
  17. ai_edge_torch/generative/examples/phi/phi2.py +1 -1
  18. ai_edge_torch/generative/examples/phi/phi3.py +3 -1
  19. ai_edge_torch/generative/examples/phi/phi4.py +3 -1
  20. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -3
  21. ai_edge_torch/generative/examples/qwen/qwen.py +1 -0
  22. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +5 -3
  23. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +4 -3
  24. ai_edge_torch/generative/examples/smollm/smollm.py +3 -1
  25. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +1 -2
  26. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +3 -1
  27. ai_edge_torch/generative/layers/kv_cache.py +2 -4
  28. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +51 -0
  29. ai_edge_torch/generative/layers/sdpa_with_kv_update.py +4 -6
  30. ai_edge_torch/generative/test/test_model_conversion.py +3 -33
  31. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -75
  32. ai_edge_torch/generative/utilities/converter.py +11 -1
  33. ai_edge_torch/generative/utilities/export_config.py +30 -0
  34. ai_edge_torch/model.py +2 -0
  35. ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +2 -0
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/RECORD +41 -39
  39. ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +0 -129
  40. ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +0 -93
  41. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/LICENSE +0 -0
  42. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/WHEEL +0 -0
  43. {ai_edge_torch_nightly-0.5.0.dev20250424.dist-info → ai_edge_torch_nightly-0.5.0.dev20250426.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/conversion.py
@@ -35,13 +35,11 @@ def _run_convert_passes(
   )
 
   passes = [
-      fx_passes.CastInputsBf16ToF32Pass(),
-      fx_passes.BuildInterpolateCompositePass(),
-      fx_passes.CanonicalizePass(),
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
       fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CastInputsBf16ToF32Pass(),
   ]
 
   # Debuginfo is not injected automatically by odml_torch. Only inject
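
For orientation: these passes run inside the public conversion entry point. A minimal sketch, assuming the `ai_edge_torch.convert` API from the project README (the toy model and output path are placeholders):

import torch
import ai_edge_torch

# Converting any eval-mode torch.nn.Module drives the fx pass pipeline above,
# now with CastInputsBf16ToF32Pass moved to the end and the interpolate
# composite pass removed in favor of the new aten composite builders.
model = torch.nn.Linear(4, 4).eval()
edge_model = ai_edge_torch.convert(model, (torch.randn(1, 4),))
edge_model.export('/tmp/linear.tflite')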
ai_edge_torch/_convert/fx_passes/__init__.py
@@ -16,7 +16,6 @@
 from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
-from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
 from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py
@@ -20,7 +20,8 @@ import torch
 import torch.utils._pytree as pytree
 
 _composite_builders: dict[
-    Callable, Callable[[torch.fx.GraphModule, torch.fx.Node], None]
+    Callable[[Any, ...], Any],
+    Callable[[torch.fx.GraphModule, torch.fx.Node], None],
 ] = {}
 
 
@@ -272,13 +273,73 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
     output = op(**full_kwargs)
     output = builder.mark_outputs(output)
 
-    # Explicitly reshape back to the original shape. This places the ReshapeOp outside of the HLFB.
+    # Explicitly reshape back to the original shape. This places the ReshapeOp
+    # outside of the HLFB.
     output = torch.reshape(output, (*(original_idx_shape), embedding_dim))
     return output
 
   node.target = embedding
 
 
+@_register_composite_builder(torch.ops.aten.upsample_bilinear2d.vec)
+def _aten_upsample_bilinear2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_bilinear2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_bilinear2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="odml.upsample_bilinear2d",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "align_corners": full_kwargs["align_corners"],
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_bilinear2d_vec
+
+
+@_register_composite_builder(torch.ops.aten.upsample_nearest2d.vec)
+def _aten_upsample_nearest2d_vec(_, node: torch.fx.Node):
+  """Build a composite for aten.upsample_nearest2d.vec."""
+  op = node.target
+  args_mapper = TorchOpArgumentsMapper(op)
+  # Assumes later FX passes does not change the args/kwargs of the op.
+  # Which is a valid assumption for, given that composite/mark_tensor wrapper
+  # should semantically prevents any future mutations on the op.
+  output_h, output_w = node.meta["val"].shape[-2:]
+
+  def upsample_nearest2d_vec(*args, **kwargs):
+    nonlocal op, args_mapper
+    full_kwargs = args_mapper.get_full_kwargs(args, kwargs)
+
+    builder = lowertools.StableHLOCompositeBuilder(
+        name="tfl.resize_nearest_neighbor",
+        attr={
+            "size": (int(output_h), int(output_w)),
+            "is_nchw_op": True,
+        },
+    )
+    full_kwargs["input"] = builder.mark_inputs(full_kwargs["input"])
+    output = op(**full_kwargs)
+    output = builder.mark_outputs(output)
+    return output
+
+  node.target = upsample_nearest2d_vec
+
+
 class BuildAtenCompositePass(fx_infra.PassBase):
 
   def call(self, graph_module: torch.fx.GraphModule):
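
The two new builders replace the deleted build_interpolate_composite_pass.py: they fire on `aten.upsample_bilinear2d.vec` and `aten.upsample_nearest2d.vec`, which is what `torch.nn.functional.interpolate` lowers to during export. A hedged sketch of a module that should exercise the bilinear path during conversion (class name and shapes are illustrative):

import torch
import torch.nn.functional as F
import ai_edge_torch

class Upsample2x(torch.nn.Module):

  def forward(self, x):
    # Lowers to aten.upsample_bilinear2d.vec; the builder above wraps it in
    # an "odml.upsample_bilinear2d" composite carrying the static output size.
    return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

edge_model = ai_edge_torch.convert(Upsample2x().eval(), (torch.randn(1, 3, 16, 16),))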
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
@@ -17,6 +17,7 @@
 import operator
 
 import ai_edge_torch
+from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -24,7 +25,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
-StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
+StableHLOCompositeBuilder = lowertools.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py
@@ -17,11 +17,11 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.deepseek import deepseek
+from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
-flags = converter.define_conversion_flags("deepseek")
-ExportConfig = export_config.ExportConfig
+flags = converter.define_conversion_flags('deepseek')
 
 def main(_):
   pytorch_model = deepseek.build_model(
@@ -34,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/deepseek/deepseek.py
@@ -53,6 +53,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.RMS_NORM,
       epsilon=1e-06,
+      enable_hlfb=True,
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
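
The recurring `enable_hlfb=True` additions in this release (here and in the llama, phi, and smollm configs below) wrap the normalization layer in a StableHLO composite, i.e. a high-level function boundary, so the converter can recognize and lower the whole norm as one unit rather than its elementwise decomposition. The resulting config, mirroring the hunk above:

import ai_edge_torch.generative.layers.model_config as cfg

# With enable_hlfb=True the RMS norm is emitted inside a composite that
# downstream lowering can pattern-match; without it, the norm decomposes
# into primitive elementwise ops.
norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.RMS_NORM,
    epsilon=1e-06,
    enable_hlfb=True,
)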
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py
@@ -17,14 +17,10 @@
 
 from absl import app
 from ai_edge_torch.generative.examples.gemma3 import gemma3
-from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
-import torch
 
 flags = converter.define_conversion_flags('gemma3-1b')
-ExportConfig = export_config.ExportConfig
-
 
 _MODEL_SIZE = flags.DEFINE_string(
     'model_size',
@@ -33,55 +29,23 @@ _MODEL_SIZE = flags.DEFINE_string(
 )
 
 
-def _create_mask(mask_len, kv_cache_max_len):
-  mask = torch.full(
-      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  return mask
-
-
-def _create_export_config(
-    prefill_seq_lens: list[int], kv_cache_max_len: int
-) -> ExportConfig:
-  """Creates the export config for the model."""
-  export_config = ExportConfig()
-  if isinstance(prefill_seq_lens, list):
-    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
-  else:
-    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
-
-  export_config.prefill_mask = prefill_mask
-
-  decode_mask = torch.full(
-      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
-  )
-  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
-  export_config.decode_mask = decode_mask
-  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
-  return export_config
-
-
 def main(_):
   if _MODEL_SIZE.value == '1b':
     pytorch_model = gemma3.build_model_1b(
         flags.FLAGS.checkpoint_path,
         kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
     )
-    config = pytorch_model.config
   else:
     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
+
   converter.convert_to_tflite(
       pytorch_model,
       output_path=flags.FLAGS.output_path,
       output_name_prefix=flags.FLAGS.output_name_prefix,
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
-      config=config,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=_create_export_config(
-          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
-      ),
+      export_config=export_config.get_from_flags(),
   )
 
 
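
The hand-rolled `_create_export_config` above is superseded by `export_config.get_from_flags()`, added to `export_config.py` in this release (+30 lines, body not shown in this section). A hypothetical sketch of its shape, inferred from the hammer converter below, which still builds the equivalent config inline behind a `transpose_kv_cache` flag:

from ai_edge_torch.generative.layers import kv_cache
from ai_edge_torch.generative.utilities import export_config as export_cfg

# Hypothetical reconstruction only; the real implementation is not in this diff.
def get_from_flags_sketch(flags) -> export_cfg.ExportConfig:
  config = export_cfg.ExportConfig()
  if flags.FLAGS.transpose_kv_cache:
    config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
  return config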
ai_edge_torch/generative/examples/hammer/__init__.py (new file)
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/hammer/convert_to_tflite.py (new file)
@@ -0,0 +1,92 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting hammer 2.1 models to multi-signature tflite model."""
+
+from absl import app
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.layers import kv_cache
+from ai_edge_torch.generative.utilities import converter
+from ai_edge_torch.generative.utilities import export_config as export_cfg
+import torch
+
+
+flags = converter.define_conversion_flags('hammer')
+ExportConfig = export_cfg.ExportConfig
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    'model_size',
+    '1.5b',
+    ['0.5b', '1.5b'],
+    'The size of the model to convert.',
+)
+
+_BUILDER = {
+    '0.5b': hammer.build_0_5b_model,
+    '1.5b': hammer.build_1_5b_model,
+}
+
+
+def _create_mask(mask_len, kv_cache_max_len):
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  return mask
+
+
+def _create_export_config(
+    prefill_seq_lens: list[int], kv_cache_max_len: int
+) -> ExportConfig:
+  """Creates the export config for the model."""
+  export_config = ExportConfig()
+  if isinstance(prefill_seq_lens, list):
+    prefill_mask = [_create_mask(i, kv_cache_max_len) for i in prefill_seq_lens]
+  else:
+    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
+
+  export_config.prefill_mask = prefill_mask
+
+  decode_mask = torch.full(
+      (1, kv_cache_max_len), float('-inf'), dtype=torch.float32
+  )
+  decode_mask = torch.triu(decode_mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+  export_config.decode_mask = decode_mask
+  export_config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
+  return export_config
+
+
+def main(_):
+  pytorch_model = _BUILDER[_MODEL_SIZE.value](
+      flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
+  )
+  converter.convert_to_tflite(
+      pytorch_model,
+      output_path=flags.FLAGS.output_path,
+      output_name_prefix=flags.FLAGS.output_name_prefix,
+      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
+      quantize=flags.FLAGS.quantize,
+      lora_ranks=flags.FLAGS.lora_ranks,
+      export_config=_create_export_config(
+          flags.FLAGS.prefill_seq_lens, flags.FLAGS.kv_cache_max_len
+      )
+      if flags.FLAGS.transpose_kv_cache
+      else ExportConfig(),
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
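
A quick worked example of `_create_mask` above: it returns a `(1, 1, mask_len, kv_cache_max_len)` causal mask with `-inf` strictly above the diagonal and `0` on and below it, so position i may only attend to cache slots 0..i:

import torch

mask = torch.full((3, 8), float('-inf'), dtype=torch.float32)
mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
print(mask.shape)         # torch.Size([1, 1, 3, 8])
print(mask[0, 0, 0, :4])  # tensor([0., -inf, -inf, -inf])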
ai_edge_torch/generative/examples/hammer/hammer.py (new file)
@@ -0,0 +1,107 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building Hammer 2.1 models."""
+
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+from torch import nn
+
+TENSOR_NAMES = model_builder.TENSOR_NAMES
+
+
+class Hammer(model_builder.DecoderOnlyModel):
+  """A Hammer model built from the Edge Generative API layers."""
+  pass
+
+
+def get_1_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Hammer 2.1 1.5B model."""
+  attn_config = cfg.AttentionConfig(
+      num_heads=12,
+      head_dim=128,
+      num_query_groups=2,
+      rotary_base=1000000,
+      rotary_percentage=1.0,
+      qkv_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=8960,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-06,
+      enable_hlfb=True,
+  )
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=151665,
+      num_layers=28,
+      max_seq_len=32768,
+      embedding_dim=1536,
+      kv_cache_max_len=kv_cache_max_len,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_0_5b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Hammer 2.1 0.5B model."""
+  config = get_1_5b_model_config(kv_cache_max_len)
+  # Hammer has only one block config.
+  block_config = config.block_config(0)
+  block_config.attn_config.num_heads = 14
+  block_config.attn_config.head_dim = 64
+  block_config.ff_config.intermediate_size = 4864
+  config.num_layers = 24
+  config.embedding_dim = 896
+  return config
+
+
+def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
+  config = get_1_5b_model_config(**kwargs)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.embedding_dim = 16
+  # Hammer has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 64
+  return config
+
+
+def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_1_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Hammer,
+  )
+
+
+def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  return model_builder.build_decoder_only_model(
+      checkpoint_path=checkpoint_path,
+      config=get_0_5b_model_config(**kwargs),
+      tensor_names=TENSOR_NAMES,
+      model_class=Hammer,
+  )
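
Building the models follows the other examples; a short sketch with a placeholder checkpoint directory (e.g. a local download of the MadeAgents/Hammer2.1-1.5b weights used by verify.py below):

from ai_edge_torch.generative.examples.hammer import hammer

# kv_cache_max_len is forwarded to get_1_5b_model_config via **kwargs.
model = hammer.build_1_5b_model(
    '/path/to/Hammer2.1-1.5b', kv_cache_max_len=1024
)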
ai_edge_torch/generative/examples/hammer/verify.py (new file)
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Hammer 2.1 0.5B and 1.5B models."""
+
+import logging
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.hammer import hammer
+from ai_edge_torch.generative.utilities import transformers_verifier
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+
+_MODEL_SIZE = flags.DEFINE_enum(
+    "model_size",
+    "0.5b",
+    ["0.5b", "1.5b"],
+    "The size of the model to verify.",
+)
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+_CHECKPOINT = {
+    "0.5b": "MadeAgents/Hammer2.1-0.5b",
+    "1.5b": "MadeAgents/Hammer2.1-1.5b",
+}
+
+_BUILDER = {
+    "0.5b": hammer.build_0_5b_model,
+    "1.5b": hammer.build_1_5b_model,
+}
+
+
+def main(_):
+  checkpoint = _CHECKPOINT[_MODEL_SIZE.value]
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
+  reauthored_model = _BUILDER[_MODEL_SIZE.value](reauthored_checkpoint)
+
+  logging.info("Loading the tokenizer from: %s", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
+      generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -22,8 +22,6 @@ from ai_edge_torch.generative.utilities import export_config
 
 
 flags = converter.define_conversion_flags('llama')
-ExportConfig = export_config.ExportConfig
-
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -49,7 +47,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/llama/llama.py
@@ -121,7 +121,9 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("phi3")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/phi/convert_phi4_to_tflite.py
@@ -21,7 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("phi4")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -22,7 +22,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("phi2")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -36,7 +35,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 
ai_edge_torch/generative/examples/phi/phi2.py
@@ -65,7 +65,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       use_bias=True,
   )
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.LAYER_NORM,
+      type=cfg.NormalizationType.LAYER_NORM, enable_hlfb=True
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
ai_edge_torch/generative/examples/phi/phi3.py
@@ -162,7 +162,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True,
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/phi/phi4.py
@@ -112,7 +112,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
       intermediate_size=8192,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -21,8 +21,6 @@ from ai_edge_torch.generative.utilities import converter
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags('qwen')
-ExportConfig = export_config.ExportConfig
-
 
 _MODEL_SIZE = flags.DEFINE_enum(
     'model_size',
@@ -37,6 +35,7 @@ _BUILDER = {
     '3b': qwen.build_3b_model,
 }
 
+
 def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       flags.FLAGS.checkpoint_path, kv_cache_max_len=flags.FLAGS.kv_cache_max_len
@@ -48,7 +47,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )
 
 