ai-edge-torch-nightly 0.4.0.dev20250314__py3-none-any.whl → 0.4.0.dev20250316__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to the supported registries. It is provided for informational purposes only.
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -28,7 +28,7 @@ import ai_edge_torch.generative.utilities.loader as loading_utils
  import torch
  from torch import nn

- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
      ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -43,7 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      lm_head=None,
  )

- ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
      ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -59,6 +59,11 @@ ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      final_norm="model.norm",
  )

+ TENSOR_NAMES_DICT = {
+     "safetensors": TENSOR_NAMES_SEP_QKV,
+     "kaggle": TENSOR_NAMES_FUSED_QKV,
+ }
+

  class Gemma2Block(attention.TransformerBlock):

@@ -300,18 +305,13 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:


  def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
-   try:
-     return model_builder.build_decoder_only_model(
-         checkpoint_path=checkpoint_path,
-         config=get_model_config_2b(**kwargs),
-         tensor_names=TENSOR_NAMES,
-         model_class=Gemma2,
-     )
-   except KeyError as ke:
-     # Also attempt to load with an alternative naming scheme.
-     return model_builder.build_decoder_only_model(
-         checkpoint_path=checkpoint_path,
-         config=get_model_config_2b(**kwargs),
-         tensor_names=ALT_TENSOR_NAMES,
-         model_class=Gemma2,
-     )
+   for tensor_names in TENSOR_NAMES_DICT.values():
+     try:
+       return model_builder.build_decoder_only_model(
+           checkpoint_path=checkpoint_path,
+           config=get_model_config_2b(**kwargs),
+           tensor_names=tensor_names,
+           model_class=Gemma2,
+       )
+     except KeyError as ke:
+       continue
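The rewritten loader iterates over TENSOR_NAMES_DICT and falls back to the next naming scheme whenever a mapping raises KeyError. As written, if every mapping fails the loop falls through and the function implicitly returns None; the TODO(b/403644647) in the gemma3 decoder below tracks this. A minimal sketch of a stricter variant, using a hypothetical helper that is not part of the package:

```python
from typing import Any, Callable, Mapping


def build_with_fallback(builders_by_source: Mapping[str, Callable[[], Any]]) -> Any:
  """Tries each checkpoint-source builder in turn; raises if all of them fail."""
  errors = {}
  for source, build in builders_by_source.items():
    try:
      return build()
    except KeyError as ke:  # A tensor name in this mapping is absent.
      errors[source] = ke
  raise ValueError(f"No tensor-name mapping matched the checkpoint: {errors}")
```

Called with, say, a dict mapping "safetensors" and "kaggle" to lambdas around model_builder.build_decoder_only_model, this keeps the fallback behavior while surfacing a clear error instead of None when no scheme matches.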
ai_edge_torch/generative/examples/gemma3/decoder.py CHANGED
@@ -29,7 +29,7 @@ import torch
  from torch import nn


- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+ TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
      ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -48,9 +48,8 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      lm_head=None,
  )

- # Please don't use tensor mapping for converting checkpoints hosted on Kaggle
- # or HuggingFace. Will be removed in the future.
- TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
+
+ TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
      ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -67,6 +66,11 @@ TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
      lm_head=None,
  )

+ TENSOR_NAMES_DICT = {
+     "safetensors": TENSOR_NAMES_SEP_QKV,
+     "kaggle": TENSOR_NAMES_FUSED_QKV,
+ }
+

  class DecoderBlock(attention.TransformerBlock):

@@ -428,9 +432,15 @@ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:


  def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
-   return model_builder.build_decoder_only_model(
-       checkpoint_path=checkpoint_path,
-       config=get_decoder_config_1b(**kwargs),
-       tensor_names=TENSOR_NAMES,
-       model_class=Decoder,
-   )
+   # TODO(b/403644647): Better error handling for loading checkpoints with
+   # different tensor names.
+   for tensor_names in TENSOR_NAMES_DICT.values():
+     try:
+       return model_builder.build_decoder_only_model(
+           checkpoint_path=checkpoint_path,
+           config=get_decoder_config_1b(**kwargs),
+           tensor_names=tensor_names,
+           model_class=Decoder,
+       )
+     except KeyError as ke:
+       continue
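The two mappings differ mainly in the attention weights: TENSOR_NAMES_SEP_QKV expects separate q_proj/k_proj/v_proj tensors (safetensors-style checkpoints), while TENSOR_NAMES_FUSED_QKV expects a single fused qkv_proj tensor (Kaggle-style). A toy sketch of the relationship, assuming a plain concatenation layout (real checkpoints may order heads differently):

```python
import torch

embed_dim, qkv_out = 8, 12  # toy sizes; real Gemma3 dims are much larger

q = torch.randn(qkv_out, embed_dim)  # q_proj.weight
k = torch.randn(qkv_out, embed_dim)  # k_proj.weight
v = torch.randn(qkv_out, embed_dim)  # v_proj.weight

# A fused checkpoint stores the three projections as one qkv_proj tensor.
qkv = torch.cat([q, k, v], dim=0)

x = torch.randn(1, embed_dim)
fused = x @ qkv.T
separate = torch.cat([x @ q.T, x @ k.T, x @ v.T], dim=-1)
assert torch.allclose(fused, separate)  # same math, different storage layout
```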
ai_edge_torch/generative/examples/gemma3/gemma3.py CHANGED
@@ -20,7 +20,7 @@ from typing import List, Optional, Tuple
  import xmlrpc

  from ai_edge_torch.generative.examples.gemma3 import decoder
- from ai_edge_torch.generative.examples.gemma3.cpu_only import image_encoder
+ from ai_edge_torch.generative.examples.gemma3 import image_encoder
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.4.0.dev20250314"
+ __version__ = "0.4.0.dev20250316"
ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/METADATA → ai_edge_torch_nightly-0.4.0.dev20250316.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.4.0.dev20250314
+ Version: 0.4.0.dev20250316
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/RECORD → ai_edge_torch_nightly-0.4.0.dev20250316.dist-info/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=PjlstuIJ-GlyKyFBMrwc7RQFRecNIkHpz5aIzvYNRKo,706
+ ai_edge_torch/version.py,sha256=jME-032g08KjA0-4jHbpsL3FKCJ7nOx1hgJCxDO5ePE,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,19 +57,15 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=CK1lHw-YQPAr26KMdrYA6icQHvKH59yHAQ4eC4X636o,11539
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=lR-T25GkjCfd_sN8mAKY_0XNA0MEkMgsj4ZBQnnytHo,11465
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
  ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
  ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=xAjMqhNrSv2srrBvrwCsnbLzdQXVpkZEOYImb3Mvw3w,3910
- ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=_7s_JrzwW4rX07f41VDuRLDZDJDshc3vqhXVY92K8q8,15423
- ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=n2EQVp5SrnMeb0csHrz46_gdNiHTpsApaRmcAc8xyj8,6482
- ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py,sha256=P11xO0F1MUbLMs8ySz6tu6qGDOOyK43q-HV_pqdsCUY,670
- ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py,sha256=4Ym4f8pvHu7dUSkTXfSToNuX8X3fhV5kKuhgEzOcyuw,3012
- ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py,sha256=fB2oYR08u7GcrWYjNbeADRZM5z1vTbE03mHXi497RRw,16140
- ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py,sha256=NeMqW67uQEQl09R7nE3vSpT84KXmAHEg9oy4-7TVC5k,8104
- ai_edge_torch/generative/examples/gemma3/cpu_only/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
+ ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=4Vf1zA94qLyNzj9iLU0jrd3kzFFZXft4uiItoIBjKyM,15632
+ ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=NQzqZ55cmC8tGlZ1SKkDeD0Su8mZ79KiazCS8X08xUY,6473
+ ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
  ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
@@ -243,8 +239,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
  ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/METADATA,sha256=n_c6T76WR-J-SCOmKKKzzuPoyM4i_2W2TO6ub8AuDw0,1966
- ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.4.0.dev20250314.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.4.0.dev20250316.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.4.0.dev20250316.dist-info/METADATA,sha256=Ua-f14kHLaTqaczlZePB7-9RspufHu5AMI4tbEnQCPc,1966
+ ai_edge_torch_nightly-0.4.0.dev20250316.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.4.0.dev20250316.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.4.0.dev20250316.dist-info/RECORD,,
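Each RECORD entry has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 encoding of the file's raw SHA-256 hash with trailing `=` padding stripped (the standard wheel RECORD format from PEP 376/427). A short sketch for verifying an entry, such as the new version.py digest above:

```python
import base64
import hashlib


def record_digest(path: str) -> str:
  """Computes a wheel-RECORD-style digest for the file at `path`."""
  with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).digest()
  return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()


# e.g. record_digest("ai_edge_torch/version.py") run against the 0.4.0.dev20250316
# wheel contents should reproduce "sha256=jME-032g08KjA0-4jHbpsL3FKCJ7nOx1hgJCxDO5ePE".
```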
ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py DELETED
@@ -1,14 +0,0 @@
- # Copyright 2025 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py DELETED
@@ -1,96 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Example of converting a Gemma3 model to multi-signature tflite model."""
-
- import os
- import pathlib
-
- from absl import app
- from absl import flags
- from ai_edge_torch.generative.examples.gemma3 import gemma3
- from ai_edge_torch.generative.utilities import converter
- from ai_edge_torch.generative.utilities.model_builder import ExportConfig
-
- _MODEL_SIZE = flags.DEFINE_string(
-     'model_size',
-     '1b',
-     'The size of the model to convert.',
- )
-
- _CHECKPOINT_PATH = flags.DEFINE_string(
-     'checkpoint_path',
-     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
-     'The path to the model checkpoint, or directory holding the checkpoint.',
- )
- _OUTPUT_PATH = flags.DEFINE_string(
-     'output_path',
-     '/tmp/',
-     'The path to export the tflite model.',
- )
- _OUTPUT_NAME_PREFIX = flags.DEFINE_string(
-     'output_name_prefix',
-     'gemma3',
-     'The prefix of the output tflite model name.',
- )
- _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-     'prefill_seq_lens',
-     (8, 64, 128, 256, 512, 1024),
-     'List of the maximum sizes of prefill input tensors.',
- )
- _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
-     'kv_cache_max_len',
-     1280,
-     'The maximum size of KV cache buffer, including both prefill and decode.',
- )
- _QUANTIZE = flags.DEFINE_bool(
-     'quantize',
-     True,
-     'Whether the model should be quantized.',
- )
- _LORA_RANKS = flags.DEFINE_multi_integer(
-     'lora_ranks',
-     None,
-     'If set, the model will be converted with the provided list of LoRA ranks.',
- )
-
-
- def main(_):
-   if _MODEL_SIZE.value == '1b':
-     pytorch_model = gemma3.build_model_1b(
-         _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
-     )
-     config = pytorch_model.config
-   elif _MODEL_SIZE.value == '4b':
-     pytorch_model = gemma3.build_model_4b(
-         _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
-     )
-     config = pytorch_model.config.decoder_config
-   else:
-     raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
-   converter.convert_to_tflite(
-       pytorch_model,
-       output_path=_OUTPUT_PATH.value,
-       output_name_prefix=_OUTPUT_NAME_PREFIX.value,
-       prefill_seq_len=_PREFILL_SEQ_LENS.value,
-       quantize=_QUANTIZE.value,
-       config=config,
-       lora_ranks=_LORA_RANKS.value,
-       export_config=ExportConfig(),
-   )
-
-
- if __name__ == '__main__':
-   app.run(main)
ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py DELETED
@@ -1,463 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Example of building a Decoder for Gemma3 model."""
-
- from typing import List, Optional, Tuple
-
- from ai_edge_torch.generative.layers import attention
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.model_config as cfg
- import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
- from ai_edge_torch.generative.utilities import model_builder
- import ai_edge_torch.generative.utilities.loader as loading_utils
- import torch
- from torch import nn
-
-
- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="model.layers.{}.mlp.up_proj",
-     ff_down_proj="model.layers.{}.mlp.down_proj",
-     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-     attn_query_proj="model.layers.{}.self_attn.q_proj",
-     attn_key_proj="model.layers.{}.self_attn.k_proj",
-     attn_value_proj="model.layers.{}.self_attn.v_proj",
-     attn_output_proj="model.layers.{}.self_attn.o_proj",
-     attn_query_norm="model.layers.{}.self_attn.q_norm",
-     attn_key_norm="model.layers.{}.self_attn.k_norm",
-     pre_attn_norm="model.layers.{}.input_layernorm",
-     post_attn_norm="model.layers.{}.post_attention_layernorm",
-     pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
-     post_ff_norm="model.layers.{}.post_feedforward_layernorm",
-     embedding="model.embed_tokens",
-     final_norm="model.norm",
-     lm_head=None,
- )
-
- # Please don't use tensor mapping for converting checkpoints hosted on Kaggle
- # or HuggingFace. Will be removed in the future.
- TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="model.layers.{}.mlp.up_proj",
-     ff_down_proj="model.layers.{}.mlp.down_proj",
-     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-     attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
-     attn_output_proj="model.layers.{}.self_attn.o_proj",
-     attn_query_norm="model.layers.{}.self_attn.query_norm",
-     attn_key_norm="model.layers.{}.self_attn.key_norm",
-     pre_attn_norm="model.layers.{}.input_layernorm",
-     post_attn_norm="model.layers.{}.post_attention_layernorm",
-     pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
-     post_ff_norm="model.layers.{}.post_feedforward_layernorm",
-     embedding="embedder",
-     final_norm="model.norm",
-     lm_head=None,
- )
-
-
- class DecoderBlock(attention.TransformerBlock):
-
-
-   def forward(
-       self,
-       x: torch.Tensor,
-       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-       mask: Optional[torch.Tensor] = None,
-       input_pos: Optional[torch.Tensor] = None,
-       kv_cache: kv_utils.KVCacheEntry = None,
-   ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
-     """Forward function of the Gemma3Block.
-
-     Exactly the same as TransformerBlock but we call the post-attention norm
-     immediately after attention and not after the residual pointwise addition.
-
-     Args:
-       x (torch.Tensor): the input tensor.
-       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-       mask (torch.Tensor): the optional mask tensor.
-       input_pos (torch.Tensor): the optional input position tensor.
-       kv_cache (KVCacheEntry): the optional kv cache entry.
-
-     Returns:
-       output activation from this transformer block, and updated kv cache (if
-       passed in).
-     """
-
-     x_norm = self.pre_atten_norm(x)
-     attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
-     attn_out_norm = self.post_atten_norm(attn_out)
-     x = x + attn_out_norm
-     output = x + self.ff(x)
-     return output, kv
-
-
- class Decoder(nn.Module):
-   """A Gemma3 decoder model built from the Edge Generative API layers."""
-
-   def __init__(self, config: cfg.ModelConfig):
-     super().__init__()
-
-     # Construct model layers.
-     self.tok_embedding = nn.Embedding(
-         config.vocab_size, config.embedding_dim, padding_idx=0
-     )
-     self.lm_head = nn.Linear(
-         config.embedding_dim,
-         config.vocab_size,
-         bias=config.lm_head_use_bias,
-     )
-     # Gemma3 re-uses the embedding as the head projection layer.
-     self.lm_head.weight.data = self.tok_embedding.weight.data
-     self.transformer_blocks = nn.ModuleList(
-         DecoderBlock(config.block_config(idx), config)
-         for idx in range(config.num_layers)
-     )
-     self.final_norm = builder.build_norm(
-         config.embedding_dim,
-         config.final_norm_config,
-     )
-     self.mask_cache = attn_utils.build_causal_mask_cache(
-         size=config.kv_cache_max,
-     )
-     # Gemma3 has same hyper parameters for each layer except for attention
-     # types. Use the first layer.
-     attn_config = config.block_config(0).attn_config
-     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-         size=config.kv_cache_max,
-         window_size=attn_config.sliding_window_size,
-     )
-     self.config = config
-
-   def get_attention_mask(
-       self,
-       attn_type: cfg.AttentionType,
-       input_pos: torch.Tensor,
-   ) -> torch.Tensor:
-     if attn_type == cfg.AttentionType.LOCAL_SLIDING:
-       return self.sliding_window_mask_cache.index_select(2, input_pos)
-     return self.mask_cache.index_select(2, input_pos)
-
-   def compose_mask(
-       self, mask: torch.Tensor, pixel_mask: torch.Tensor,
-       attn_type: cfg.AttentionType,
-   ) -> torch.Tensor:
-     mask = mask == 0
-     if attn_type == cfg.AttentionType.LOCAL_SLIDING:
-       mask = torch.logical_and(mask, pixel_mask)
-     else:
-       mask = torch.logical_or(mask, pixel_mask)
-     mask = torch.where(mask, 0, float("-inf"))
-     return mask
-
-   def build_pixel_mask(self, image_indices: torch.Tensor):
-     pixel_mask = image_indices >= 0
-     max_seq_len = self.config.kv_cache_max
-     if pixel_mask.size(1) < max_seq_len:
-       pixel_mask = torch.cat(
-           [
-               pixel_mask,
-               torch.zeros(
-                   (pixel_mask.size(0), max_seq_len - pixel_mask.size(1))
-               ),
-           ],
-           dim=1,
-       )
-     pixel_mask = torch.logical_and(
-         pixel_mask.unsqueeze(1), pixel_mask.unsqueeze(-1)
-     )
-     return pixel_mask.unsqueeze(1)
-
-   @torch.inference_mode
-   def forward(
-       self,
-       tokens: torch.Tensor,
-       input_pos: torch.Tensor,
-       kv_cache: kv_utils.KVCache,
-       input_embeds: Optional[torch.Tensor] = None,
-       mask: Optional[torch.Tensor] = None,
-       image_indices: Optional[torch.Tensor] = None,
-       export_config: Optional[model_builder.ExportConfig] = None,
-   ) -> dict[torch.Tensor, kv_utils.KVCache]:
-
-     pixel_mask = None
-     if input_embeds is None:
-       # token embeddings of shape (b, t, n_embd)
-       input_embeds = self.tok_embedding(tokens)
-       if self.config.embedding_scale is not None:
-         input_embeds = input_embeds * self.config.embedding_scale
-     if image_indices is not None:
-       pixel_mask = self.build_pixel_mask(image_indices)
-     # RoPE parameters are the same for all blocks. Use the first layer.
-     attn_config = self.config.block_config(0).attn_config
-     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-     # Different rotary base for global and local attention
-     # based on attention pattern
-     rope = [rotary_pos_emb.build_rope(
-         input_pos, attn_config.head_dim,
-         self.config.block_config(i).attn_config.rotary_base
-     ) for i in range(self.config.num_layers)]
-     mask = [self.get_attention_mask(
-         self.config.block_config(i).attn_config.attn_type, input_pos
-     ) for i in range(self.config.num_layers)]
-
-     return self._forward_with_embeds(
-         input_embeds, rope, mask, input_pos, kv_cache,
-         pixel_mask, export_config
-     )
-
-   def _forward_with_embeds(
-       self,
-       input_embeds: torch.Tensor,
-       rope: List[Tuple[torch.Tensor, torch.Tensor]],
-       mask: List[torch.Tensor],
-       input_pos: torch.Tensor,
-       kv_cache: kv_utils.KVCache,
-       pixel_mask: Optional[torch.Tensor] = None,
-       export_config: Optional[model_builder.ExportConfig] = None,
-   ) -> dict[torch.Tensor, kv_utils.KVCache]:
-     """Forwards the model with input embeddings."""
-     assert len(self.transformer_blocks) == len(kv_cache.caches), (
-         "The number of transformer blocks and the number of KV cache entries"
-         " must be the same."
-     )
-
-     x = input_embeds
-
-     if pixel_mask is not None:
-       pixel_mask = pixel_mask.index_select(2, input_pos)
-       mask = [
-           self.compose_mask(
-               mask[i],
-               pixel_mask,
-               self.config.block_config(i).attn_config.attn_type,
-           )
-           for i in range(self.config.num_layers)
-       ]
-     updated_kv_entries = []
-     for i, block in enumerate(self.transformer_blocks):
-       kv_entry = kv_cache.caches[i] if kv_cache else None
-       x, kv_entry = block(x, rope[i], mask[i], input_pos, kv_entry)
-       if kv_entry:
-         updated_kv_entries.append(kv_entry)
-     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
-     if export_config is not None:
-       if (
-           torch.numel(input_pos) > 1
-           and not export_config.output_logits_on_prefill
-       ):
-         return {"kv_cache": updated_kv_cache}
-
-     x = self.final_norm(x)
-     res = self.lm_head(x)  # (b, t, vocab_size)
-
-     return {"logits": res, "kv_cache": updated_kv_cache}
-
-
- def get_decoder_config_4b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
-   """Returns the model config for a Gemma3 4B model.
-
-   Args:
-     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-       is 2048.
-
-   Returns:
-     The model config for a Gemma 4B model.
-   """
-   norm_config = cfg.NormalizationConfig(
-       type=cfg.NormalizationType.RMS_NORM,
-       epsilon=1e-6,
-       zero_centered=True,
-   )
-   ff_config = cfg.FeedForwardConfig(
-       type=cfg.FeedForwardType.GATED,
-       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-       intermediate_size=4*2560,
-       pre_ff_norm_config=norm_config,
-       post_ff_norm_config=norm_config,
-   )
-
-   def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
-     attn_config = cfg.AttentionConfig(
-         num_heads=8,
-         head_dim=256,
-         num_query_groups=4,
-         rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
-         rotary_percentage=1.0,
-         qkv_transpose_before_split=True,
-         query_norm_config=norm_config,
-         key_norm_config=norm_config,
-         logit_softcap=None,
-         sliding_window_size=1024,
-         attn_type=(
-             cfg.AttentionType.GLOBAL
-             if (idx + 1) % 6 == 0
-             else cfg.AttentionType.LOCAL_SLIDING
-         ),
-     )
-     return cfg.TransformerBlockConfig(
-         attn_config=attn_config,
-         ff_config=ff_config,
-         pre_attention_norm_config=norm_config,
-         post_attention_norm_config=norm_config,
-     )
-
-   num_layers = 34
-   embedding_dim = 2560
-   config = cfg.ModelConfig(
-       vocab_size=262_144,
-       num_layers=num_layers,
-       max_seq_len=32_768,
-       embedding_dim=embedding_dim,
-       embedding_scale=embedding_dim**0.5,
-       kv_cache_max_len=kv_cache_max_len,
-       block_configs=[get_block_config(i) for i in range(num_layers)],
-       final_norm_config=norm_config,
-       lm_head_use_bias=False,
-       enable_hlfb=True,
-       final_logit_softcap=None,
-   )
-   return config
-
- def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
-   """Returns the model config for a Gemma3 1B model.
-
-   Args:
-     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-       is 2048.
-
-   Returns:
-     The model config for a Gemma 1B model.
-   """
-   norm_config = cfg.NormalizationConfig(
-       type=cfg.NormalizationType.RMS_NORM,
-       epsilon=1e-6,
-       zero_centered=True,
-   )
-   ff_config = cfg.FeedForwardConfig(
-       type=cfg.FeedForwardType.GATED,
-       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-       intermediate_size=6*1152,
-       pre_ff_norm_config=norm_config,
-       post_ff_norm_config=norm_config,
-   )
-
-   def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
-     attn_config = cfg.AttentionConfig(
-         num_heads=4,
-         head_dim=256,
-         num_query_groups=1,
-         rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
-         rotary_percentage=1.0,
-         qkv_transpose_before_split=True,
-         query_norm_config=norm_config,
-         key_norm_config=norm_config,
-         logit_softcap=None,
-         sliding_window_size=512,
-         attn_type=(
-             cfg.AttentionType.GLOBAL
-             if (idx + 1) % 6 == 0
-             else cfg.AttentionType.LOCAL_SLIDING
-         ),
-     )
-     return cfg.TransformerBlockConfig(
-         attn_config=attn_config,
-         ff_config=ff_config,
-         pre_attention_norm_config=norm_config,
-         post_attention_norm_config=norm_config,
-     )
-
-   num_layers = 26
-   embedding_dim = 1152
-   config = cfg.ModelConfig(
-       vocab_size=262_144,
-       num_layers=num_layers,
-       max_seq_len=32_768,
-       embedding_dim=embedding_dim,
-       embedding_scale=embedding_dim**0.5,
-       kv_cache_max_len=kv_cache_max_len,
-       block_configs=[get_block_config(i) for i in range(num_layers)],
-       final_norm_config=norm_config,
-       lm_head_use_bias=False,
-       enable_hlfb=True,
-       final_logit_softcap=None,
-   )
-   return config
-
-
- def get_fake_decoder_config_4b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-   """Returns a fake model config for a Gemma3 4B model.
-
-   Args:
-     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-       is 128.
-
-   Returns:
-     A fake model config for a Gemma 4B model.
-   """
-   config = get_decoder_config_4b(kv_cache_max_len)
-   config.vocab_size = 128
-   config.num_layers = 2
-   config.max_seq_len = 2 * kv_cache_max_len
-   config.embedding_dim = 128
-   config.embedding_scale = config.embedding_dim**0.5
-   config.block_configs = config.block_configs[: config.num_layers]
-   for block_config in config.block_configs:
-     block_config.attn_config.num_heads = 4
-     block_config.attn_config.head_dim = 64
-     block_config.attn_config.sliding_window_size = 64
-     block_config.ff_config.intermediate_size = 128
-   return config
-
- def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-   """Returns a fake model config for a Gemma3 1B model.
-
-   Args:
-     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-       is 128.
-
-   Returns:
-     A fake model config for a Gemma 1B model.
-   """
-   config = get_decoder_config_1b(kv_cache_max_len)
-   config.vocab_size = 128
-   config.num_layers = 2
-   config.max_seq_len = 2 * kv_cache_max_len
-   config.embedding_dim = 128
-   config.embedding_scale = config.embedding_dim**0.5
-   config.block_configs = config.block_configs[: config.num_layers]
-   for block_config in config.block_configs:
-     block_config.attn_config.num_heads = 4
-     block_config.attn_config.head_dim = 64
-     block_config.attn_config.sliding_window_size = 64
-     block_config.ff_config.intermediate_size = 128
-   return config
-
-
- def build_model_4b(checkpoint_path: str, **kwargs) -> nn.Module:
-   return model_builder.build_decoder_only_model(
-       checkpoint_path=checkpoint_path,
-       config=get_decoder_config_4b(**kwargs),
-       tensor_names=TENSOR_NAMES,
-       model_class=Decoder,
-   )
-
- def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
-   return model_builder.build_decoder_only_model(
-       checkpoint_path=checkpoint_path,
-       config=get_decoder_config_1b(**kwargs),
-       tensor_names=TENSOR_NAMES,
-       model_class=Decoder,
-   )
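One detail worth highlighting from the deleted configs (the same rule appears in the relocated decoder.py): both the 1B and 4B models interleave attention types with a fixed pattern, in which every sixth layer uses global attention with rotary base 1,000,000 and all other layers use local sliding-window attention with base 10,000. A quick illustration derived from the `(idx + 1) % 6 == 0` check in get_block_config above:

```python
def layer_pattern(num_layers: int) -> list[str]:
  """Lists the attention type per layer, following the config rule above."""
  return [
      "GLOBAL(base=1e6)" if (idx + 1) % 6 == 0 else "LOCAL_SLIDING(base=1e4)"
      for idx in range(num_layers)
  ]


# For the 1B config (26 layers), layers 6, 12, 18, and 24 come out global.
print(layer_pattern(26))
```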
ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py DELETED
@@ -1,212 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Example of building a Gemma3 model."""
-
- from dataclasses import dataclass
- from typing import List, Optional, Tuple
- import xmlrpc
-
- from ai_edge_torch.generative.examples.gemma3 import decoder
- from ai_edge_torch.generative.examples.gemma3 import image_encoder
- from ai_edge_torch.generative.layers import builder
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import ai_edge_torch.generative.layers.model_config as cfg
- from ai_edge_torch.generative.utilities import model_builder
- import ai_edge_torch.generative.utilities.loader as loading_utils
- import torch
- from torch import nn
-
-
- PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
-
- @dataclass
- class Gemma3MMConfig:
-   """Gemma3 model configurations."""
-
-   image_encoder_config: cfg.ModelConfig
-   decoder_config: cfg.ModelConfig
-   mm_norm_config: cfg.NormalizationConfig
-   mm_extra_tokens: int
-   image_token_id: int
-   image_projection_scale: float
-   image_projection_use_bias: bool = False
-
- class Gemma3MM(nn.Module):
-   """A Gemma3 multimodal model built from the Edge Generative API layers."""
-
-   def __init__(self, config: Gemma3MMConfig):
-     super().__init__()
-
-     self.image_encoder = image_encoder.SiglipVisionEncoderWithExit(
-         config.image_encoder_config)
-     self.decoder = decoder.Decoder(config.decoder_config)
-     self.mm_norm = builder.build_norm(
-         config.image_encoder_config.embedding_dim,
-         config.mm_norm_config,
-     )
-     self.extra_embedding = nn.Embedding(
-         config.mm_extra_tokens, config.image_encoder_config.embedding_dim)
-     self.image_projection = nn.Linear(
-         config.image_encoder_config.embedding_dim,
-         config.decoder_config.embedding_dim,
-         bias=config.image_projection_use_bias,
-     )
-     image_embedding_config = config.image_encoder_config.image_embedding
-     self.num_patches = (
-         image_embedding_config.image_size // image_embedding_config.patch_size
-     ) ** 2
-     self.config = config
-
-   @torch.inference_mode
-   def forward(
-       self,
-       tokens: torch.Tensor,
-       input_pos: torch.Tensor,
-       kv_cache: kv_utils.KVCache,
-       image_indices: Optional[torch.Tensor] = None,
-       image_feat_indices: Optional[torch.Tensor] = None,
-       pixel_values: torch.Tensor = None,
-       export_config: Optional[model_builder.ExportConfig] = None,
-   ) -> dict[torch.Tensor, kv_utils.KVCache]:
-     _, seq_len = tokens.size()
-     assert self.config.decoder_config.max_seq_len >= seq_len, (
-         f"Cannot forward sequence of length {seq_len}, max seq length is only"
-         f" {self.config.decoder_config.max_seq_len}"
-     )
-     if pixel_values is None:
-       return self.decoder(tokens=tokens,
-                           input_pos=input_pos,
-                           kv_cache=kv_cache,
-                           input_embeds=None,
-                           export_config=export_config,
-       )
-     vocab_size = self.config.decoder_config.vocab_size
-     input_embeds = self.decoder.tok_embedding(torch.clip(tokens, 0,
-                                                          vocab_size - 1))
-     if self.decoder.config.embedding_scale is not None:
-       input_embeds = input_embeds * self.decoder.config.embedding_scale
-
-     # TODO: Identify embedding path for hard tokens if required.
-     # extra_embeds = self.extra_embedding(
-     #     torch.clip(tokens - vocab_size, 0, self.config.mm_extra_tokens - 1)
-     # )
-     # extra_embeds = self.image_projection(extra_embeds)
-     # input_embeds = torch.where(tokens < self.config.decoder_config.vocab_size,
-     #                            input_embeds, extra_embeds)
-     # alternate method of implementation
-     # rows, cols = torch.where(tokens >= self.config.vocab_size)
-     # ext_embeds = self.ext_embedding(
-     #     tokens[rows, cols] - self.config.vocab_size
-     # )
-     # ext_embeds = self.mm_projection(extra_embeds)
-     # input_embeds[rows, cols, :] = extra_embeds
-
-     # Shape of pixel_values: (b, n, c, h, w)
-     batch_size, num_media, c, h, w = pixel_values.size()
-     pixel_values = pixel_values.view(-1, c, h, w)
-     image_encoded = self.image_encoder(pixel_values=pixel_values)
-     image_encoded = self.mm_norm(image_encoded)
-     image_encoded = self.image_projection(image_encoded)
-     _, num_patches, num_channels = image_encoded.size()
-     image_encoded = image_encoded.view(
-         batch_size, num_media, num_patches, num_channels
-     )
-
-     # Interleave the image soft embeddings with the text embeddings
-     for b in range(tokens.shape[0]):
-       unbatched_image_encoded = image_encoded[b]
-       image_features = unbatched_image_encoded[
-           image_indices[b], image_feat_indices[b]
-       ]
-       index_to_copy = torch.where(image_indices[b] >= 0)[0]
-       input_embeds[b] = torch.index_copy(input_embeds[b], 0, index_to_copy,
-                                          image_features[index_to_copy])
-     return self.decoder(
-         tokens=None,
-         input_pos=input_pos,
-         kv_cache=kv_cache,
-         input_embeds=input_embeds,
-         image_indices=image_indices,
-         export_config=export_config,
-     )
-
-
- def get_model_config_4b(**kwargs) -> Gemma3MMConfig:
-   return Gemma3MMConfig(
-       image_encoder_config=image_encoder.get_image_encoder_config(),
-       decoder_config=decoder.get_decoder_config_4b(),
-       image_token_id=257152,  # TODO: confirm
-       image_projection_scale=2048**0.5,
-       image_projection_use_bias=False,
-       mm_norm_config=cfg.NormalizationConfig(
-           type=cfg.NormalizationType.LAYER_NORM,
-           epsilon=1e-6,
-           enable_hlfb=True,
-       ),
-       mm_extra_tokens=128,
-   )
-
-
- def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
-   return Gemma3MMConfig(
-       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-       decoder_config=decoder.get_fake_decoder_config_4b(**kwargs),
-       image_token_id=127,
-       image_projection_scale=128**0.5,
-       image_projection_use_bias=False,
-       mm_norm_config=cfg.NormalizationConfig(
-           type=cfg.NormalizationType.LAYER_NORM,
-           epsilon=1e-6,
-           enable_hlfb=True,
-       ),
-       mm_extra_tokens=32,
-   )
-
-
- def build_model_4b(checkpoint_path: str, **kwargs) -> Gemma3MM:
-
-   decoder_tensor_names = decoder.TENSOR_NAMES
-
-   config = get_model_config_4b(**kwargs)
-   model = Gemma3MM(config)
-   # # TODO: Load the parameters of image encoder from checkpoint mapping Tensor names properly.
-   # loader = loading_utils.ModelLoader(
-   #     checkpoint_path, image_encoder.TENSOR_NAMES
-   # )
-   # loader.load(model.image_encoder, strict=False)
-   # # Load the parameters of decoder.
-   loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
-   loader.load(model.decoder, strict=False)
-
-   # # Load the parameters of image projection.
-   # loader = loading_utils.ModelLoader(checkpoint_path, None)
-   # state = loader.get_state()
-   # converted_state = dict()
-   # converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
-   # if config.image_projection_use_bias:
-   #   converted_state["bias"] = state.pop(f"{PROJECTION_TENSOR_NAME}.bias")
-   # model.image_projection.load_state_dict(converted_state)
-   model.eval()
-   return model
-
- def build_model_1b(checkpoint_path: str, **kwargs) -> decoder.Decoder:
-   if checkpoint_path:
-     model = decoder.build_model_1b(checkpoint_path, **kwargs)
-   else:
-     config = decoder.get_decoder_config_1b(**kwargs)
-     model = decoder.Decoder(config)
-   model.eval()
-   return model
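The interleaving loop in Gemma3MM.forward overwrites text-token embeddings with projected image features at the positions whose image_indices entry is non-negative, using torch.index_copy. A toy sketch of that step in isolation, with illustrative shapes only:

```python
import torch

seq_len, embed_dim = 6, 4
input_embeds = torch.zeros(seq_len, embed_dim)        # text embeddings
image_features = torch.ones(seq_len, embed_dim)       # projected image features
image_indices = torch.tensor([-1, -1, 0, 0, -1, -1])  # positions 2-3 are image tokens

# Positions with image_indices >= 0 receive image features; the rest keep text.
index_to_copy = torch.where(image_indices >= 0)[0]
input_embeds = torch.index_copy(
    input_embeds, 0, index_to_copy, image_features[index_to_copy]
)
print(input_embeds)  # rows 2 and 3 are now image features
```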