ai-edge-torch-nightly 0.4.0.dev20250314__py3-none-any.whl → 0.4.0.dev20250315__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma2.py +17 -17
- ai_edge_torch/generative/examples/gemma3/decoder.py +20 -10
- ai_edge_torch/generative/examples/gemma3/gemma3.py +1 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/RECORD +10 -14
- ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py +0 -14
- ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py +0 -96
- ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py +0 -463
- ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py +0 -212
- /ai_edge_torch/generative/examples/gemma3/{cpu_only/image_encoder.py → image_encoder.py} +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -28,7 +28,7 @@ import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
 
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -43,7 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     lm_head=None,
 )
 
-ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -59,6 +59,11 @@ ALT_TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     final_norm="model.norm",
 )
 
+TENSOR_NAMES_DICT = {
+    "safetensors": TENSOR_NAMES_SEP_QKV,
+    "kaggle": TENSOR_NAMES_FUSED_QKV,
+}
+
 
 class Gemma2Block(attention.TransformerBlock):
 
@@ -300,18 +305,13 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 
 
 def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
[… removed lines 303–312 were not captured in this extract …]
-        checkpoint_path=checkpoint_path,
-        config=get_model_config_2b(**kwargs),
-        tensor_names=ALT_TENSOR_NAMES,
-        model_class=Gemma2,
-    )
+  for tensor_names in TENSOR_NAMES_DICT.values():
+    try:
+      return model_builder.build_decoder_only_model(
+          checkpoint_path=checkpoint_path,
+          config=get_model_config_2b(**kwargs),
+          tensor_names=tensor_names,
+          model_class=Gemma2,
+      )
+    except KeyError as ke:
+      continue
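Note on the gemma2.py change: build_2b_model previously committed to a single tensor-name mapping and now iterates over TENSOR_NAMES_DICT, so the same entry point loads both checkpoint layouts: "safetensors" checkpoints with separate q/k/v projections and "kaggle" checkpoints with a fused qkv projection. A minimal, self-contained sketch of this fallback pattern (the dict contents are abbreviated and load_model is a stand-in for model_builder.build_decoder_only_model, not the library API):

    # Illustrative sketch only; names mirror the diff, internals are simplified.
    TENSOR_NAMES_DICT = {
        "safetensors": {"attn_query_proj": "model.layers.{}.self_attn.q_proj"},
        "kaggle": {"attn_fused_qkv_proj": "model.layers.{}.self_attn.qkv_proj"},
    }

    def load_model(checkpoint: dict, tensor_names: dict) -> dict:
      # Stand-in loader: raises KeyError when the checkpoint lacks the
      # tensors the mapping expects, as a real tensor lookup would.
      return {k: checkpoint[v.format(0)] for k, v in tensor_names.items()}

    def build_model(checkpoint: dict) -> dict:
      for tensor_names in TENSOR_NAMES_DICT.values():
        try:
          return load_model(checkpoint, tensor_names)
        except KeyError:
          continue  # wrong layout; try the next candidate
      raise ValueError("no known tensor layout matches this checkpoint")

    # A fused-qkv ("kaggle"-style) checkpoint resolves via the second mapping.
    print(build_model({"model.layers.0.self_attn.qkv_proj": "w"}))

One behavioral detail worth noting: unlike the sketch, the code in the diff falls through silently and implicitly returns None when every mapping raises KeyError.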
ai_edge_torch/generative/examples/gemma3/decoder.py
CHANGED
@@ -29,7 +29,7 @@ import torch
 from torch import nn
 
 
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+TENSOR_NAMES_SEP_QKV = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -48,9 +48,8 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     lm_head=None,
 )
 
-# Please don't use tensor mapping for converting checkpoints hosted on Kaggle
-# or HuggingFace. Will be removed in the future.
-TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
+
+TENSOR_NAMES_FUSED_QKV = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",
@@ -67,6 +66,11 @@ TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
     lm_head=None,
 )
 
+TENSOR_NAMES_DICT = {
+    "safetensors": TENSOR_NAMES_SEP_QKV,
+    "kaggle": TENSOR_NAMES_FUSED_QKV,
+}
+
 
 class DecoderBlock(attention.TransformerBlock):
 
@@ -428,9 +432,15 @@ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
 
 
 def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_decoder_config_1b(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Decoder,
-  )
+  # TODO(b/403644647): Better error handling for loading checkpoints with
+  # different tensor names.
+  for tensor_names in TENSOR_NAMES_DICT.values():
+    try:
+      return model_builder.build_decoder_only_model(
+          checkpoint_path=checkpoint_path,
+          config=get_decoder_config_1b(**kwargs),
+          tensor_names=tensor_names,
+          model_class=Decoder,
+      )
+    except KeyError as ke:
+      continue
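The Gemma3 1B decoder adopts the same dictionary-driven fallback, with a TODO (b/403644647) left for better error handling. Callers should be unaffected; assuming model_builder.build_decoder_only_model raises KeyError on a mismatched mapping (as the except clause implies), a checkpoint in either supported layout loads through the unchanged entry point:

    # Hypothetical usage; the checkpoint path is a placeholder.
    from ai_edge_torch.generative.examples.gemma3 import decoder

    # Works for separated-qkv (safetensors) and fused-qkv (Kaggle) checkpoints.
    model = decoder.build_model_1b("/path/to/gemma3-1b", kv_cache_max_len=1280)

As in gemma2.py, exhausting every mapping leaves the function returning None implicitly.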
ai_edge_torch/generative/examples/gemma3/gemma3.py
CHANGED
@@ -20,7 +20,7 @@ from typing import List, Optional, Tuple
 import xmlrpc
 
 from ai_edge_torch.generative.examples.gemma3 import decoder
-from ai_edge_torch.generative.examples.gemma3
+from ai_edge_torch.generative.examples.gemma3 import image_encoder
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.4.0.dev20250314
+Version: 0.4.0.dev20250315
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=cc0LZB58uTtggRLayoxvEd0TugJID7l0oxymtDMgoPI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,19 +57,15 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=lR-T25GkjCfd_sN8mAKY_0XNA0MEkMgsj4ZBQnnytHo,11465
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=xAjMqhNrSv2srrBvrwCsnbLzdQXVpkZEOYImb3Mvw3w,3910
-ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
-ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=
-ai_edge_torch/generative/examples/gemma3/
-ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py,sha256=4Ym4f8pvHu7dUSkTXfSToNuX8X3fhV5kKuhgEzOcyuw,3012
-ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py,sha256=fB2oYR08u7GcrWYjNbeADRZM5z1vTbE03mHXi497RRw,16140
-ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py,sha256=NeMqW67uQEQl09R7nE3vSpT84KXmAHEg9oy4-7TVC5k,8104
-ai_edge_torch/generative/examples/gemma3/cpu_only/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=4Vf1zA94qLyNzj9iLU0jrd3kzFFZXft4uiItoIBjKyM,15632
+ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=NQzqZ55cmC8tGlZ1SKkDeD0Su8mZ79KiazCS8X08xUY,6473
+ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
 ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
@@ -243,8 +239,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
-ai_edge_torch_nightly-0.4.0.
+ai_edge_torch_nightly-0.4.0.dev20250315.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.4.0.dev20250315.dist-info/METADATA,sha256=UU6qKt_CJAMOHhclRMR64frDLizMR27Tp3n8-eDnPss,1966
+ai_edge_torch_nightly-0.4.0.dev20250315.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.4.0.dev20250315.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.4.0.dev20250315.dist-info/RECORD,,
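For reference, each RECORD line above has the form path,sha256=<digest>,<size>: the digest is an unpadded URL-safe base64 encoding of the SHA-256 of the file's bytes, the size is in bytes, and the RECORD file lists itself with empty hash and size fields. A small standard-library sketch for recomputing an entry (the path is a placeholder):

    import base64
    import hashlib

    def record_hash(path: str) -> str:
      # RECORD digests: unpadded URL-safe base64 over SHA-256 of the raw bytes.
      with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
      return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

    print(record_hash("ai_edge_torch/version.py"))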
ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py
DELETED
@@ -1,14 +0,0 @@
-# Copyright 2025 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
@@ -1,96 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""Example of converting a Gemma3 model to multi-signature tflite model."""
|
17
|
-
|
18
|
-
import os
|
19
|
-
import pathlib
|
20
|
-
|
21
|
-
from absl import app
|
22
|
-
from absl import flags
|
23
|
-
from ai_edge_torch.generative.examples.gemma3 import gemma3
|
24
|
-
from ai_edge_torch.generative.utilities import converter
|
25
|
-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
|
-
|
27
|
-
_MODEL_SIZE = flags.DEFINE_string(
|
28
|
-
'model_size',
|
29
|
-
'1b',
|
30
|
-
'The size of the model to convert.',
|
31
|
-
)
|
32
|
-
|
33
|
-
_CHECKPOINT_PATH = flags.DEFINE_string(
|
34
|
-
'checkpoint_path',
|
35
|
-
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
|
36
|
-
'The path to the model checkpoint, or directory holding the checkpoint.',
|
37
|
-
)
|
38
|
-
_OUTPUT_PATH = flags.DEFINE_string(
|
39
|
-
'output_path',
|
40
|
-
'/tmp/',
|
41
|
-
'The path to export the tflite model.',
|
42
|
-
)
|
43
|
-
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
44
|
-
'output_name_prefix',
|
45
|
-
'gemma3',
|
46
|
-
'The prefix of the output tflite model name.',
|
47
|
-
)
|
48
|
-
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
49
|
-
'prefill_seq_lens',
|
50
|
-
(8, 64, 128, 256, 512, 1024),
|
51
|
-
'List of the maximum sizes of prefill input tensors.',
|
52
|
-
)
|
53
|
-
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
54
|
-
'kv_cache_max_len',
|
55
|
-
1280,
|
56
|
-
'The maximum size of KV cache buffer, including both prefill and decode.',
|
57
|
-
)
|
58
|
-
_QUANTIZE = flags.DEFINE_bool(
|
59
|
-
'quantize',
|
60
|
-
True,
|
61
|
-
'Whether the model should be quantized.',
|
62
|
-
)
|
63
|
-
_LORA_RANKS = flags.DEFINE_multi_integer(
|
64
|
-
'lora_ranks',
|
65
|
-
None,
|
66
|
-
'If set, the model will be converted with the provided list of LoRA ranks.',
|
67
|
-
)
|
68
|
-
|
69
|
-
|
70
|
-
def main(_):
|
71
|
-
if _MODEL_SIZE.value == '1b':
|
72
|
-
pytorch_model = gemma3.build_model_1b(
|
73
|
-
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
74
|
-
)
|
75
|
-
config = pytorch_model.config
|
76
|
-
elif _MODEL_SIZE.value == '4b':
|
77
|
-
pytorch_model = gemma3.build_model_4b(
|
78
|
-
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
79
|
-
)
|
80
|
-
config = pytorch_model.config.decoder_config
|
81
|
-
else:
|
82
|
-
raise ValueError(f'Unsupported model size: {_MODEL_SIZE.value}')
|
83
|
-
converter.convert_to_tflite(
|
84
|
-
pytorch_model,
|
85
|
-
output_path=_OUTPUT_PATH.value,
|
86
|
-
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
|
87
|
-
prefill_seq_len=_PREFILL_SEQ_LENS.value,
|
88
|
-
quantize=_QUANTIZE.value,
|
89
|
-
config=config,
|
90
|
-
lora_ranks=_LORA_RANKS.value,
|
91
|
-
export_config=ExportConfig(),
|
92
|
-
)
|
93
|
-
|
94
|
-
|
95
|
-
if __name__ == '__main__':
|
96
|
-
app.run(main)
|
ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py
DELETED
@@ -1,463 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Example of building a Decoder for Gemma3 model."""
-
-from typing import List, Optional, Tuple
-
-from ai_edge_torch.generative.layers import attention
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
-from ai_edge_torch.generative.utilities import model_builder
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
-
-
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    attn_query_norm="model.layers.{}.self_attn.q_norm",
-    attn_key_norm="model.layers.{}.self_attn.k_norm",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
-    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head=None,
-)
-
-# Please don't use tensor mapping for converting checkpoints hosted on Kaggle
-# or HuggingFace. Will be removed in the future.
-TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    attn_query_norm="model.layers.{}.self_attn.query_norm",
-    attn_key_norm="model.layers.{}.self_attn.key_norm",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
-    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
-    embedding="embedder",
-    final_norm="model.norm",
-    lm_head=None,
-)
-
-
-class DecoderBlock(attention.TransformerBlock):
-
-
-  def forward(
-      self,
-      x: torch.Tensor,
-      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-      mask: Optional[torch.Tensor] = None,
-      input_pos: Optional[torch.Tensor] = None,
-      kv_cache: kv_utils.KVCacheEntry = None,
-  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntry]]:
-    """Forward function of the Gemma3Block.
-
-    Exactly the same as TransformerBlock but we call the post-attention norm
-    immediately after attention and not after the residual pointwise addition.
-
-    Args:
-      x (torch.Tensor): the input tensor.
-      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
-      mask (torch.Tensor): the optional mask tensor.
-      input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): the optional kv cache entry.
-
-    Returns:
-      output activation from this transformer block, and updated kv cache (if
-      passed in).
-    """
-
-    x_norm = self.pre_atten_norm(x)
-    attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
-    attn_out_norm = self.post_atten_norm(attn_out)
-    x = x + attn_out_norm
-    output = x + self.ff(x)
-    return output, kv
-
-
-class Decoder(nn.Module):
-  """A Gemma3 decoder model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
-    )
-    # Gemma3 re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_blocks = nn.ModuleList(
-        DecoderBlock(config.block_config(idx), config)
-        for idx in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
-    # Gemma3 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
-        size=config.kv_cache_max,
-        window_size=attn_config.sliding_window_size,
-    )
-    self.config = config
-
-  def get_attention_mask(
-      self,
-      attn_type: cfg.AttentionType,
-      input_pos: torch.Tensor,
-  ) -> torch.Tensor:
-    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
-      return self.sliding_window_mask_cache.index_select(2, input_pos)
-    return self.mask_cache.index_select(2, input_pos)
-
-  def compose_mask(
-      self, mask: torch.Tensor, pixel_mask: torch.Tensor,
-      attn_type: cfg.AttentionType,
-  ) -> torch.Tensor:
-    mask = mask == 0
-    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
-      mask = torch.logical_and(mask, pixel_mask)
-    else:
-      mask = torch.logical_or(mask, pixel_mask)
-    mask = torch.where(mask, 0, float("-inf"))
-    return mask
-
-  def build_pixel_mask(self, image_indices: torch.Tensor):
-    pixel_mask = image_indices >= 0
-    max_seq_len = self.config.kv_cache_max
-    if pixel_mask.size(1) < max_seq_len:
-      pixel_mask = torch.cat(
-          [
-              pixel_mask,
-              torch.zeros(
-                  (pixel_mask.size(0), max_seq_len - pixel_mask.size(1))
-              ),
-          ],
-          dim=1,
-      )
-    pixel_mask = torch.logical_and(
-        pixel_mask.unsqueeze(1), pixel_mask.unsqueeze(-1)
-    )
-    return pixel_mask.unsqueeze(1)
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      input_embeds: Optional[torch.Tensor] = None,
-      mask: Optional[torch.Tensor] = None,
-      image_indices: Optional[torch.Tensor] = None,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-
-    pixel_mask = None
-    if input_embeds is None:
-      # token embeddings of shape (b, t, n_embd)
-      input_embeds = self.tok_embedding(tokens)
-      if self.config.embedding_scale is not None:
-        input_embeds = input_embeds * self.config.embedding_scale
-    if image_indices is not None:
-      pixel_mask = self.build_pixel_mask(image_indices)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    # Different rotary base for global and local attention
-    # based on attention pattern
-    rope = [rotary_pos_emb.build_rope(
-        input_pos, attn_config.head_dim,
-        self.config.block_config(i).attn_config.rotary_base
-    ) for i in range(self.config.num_layers)]
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache,
-        pixel_mask, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: List[Tuple[torch.Tensor, torch.Tensor]],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      pixel_mask: Optional[torch.Tensor] = None,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
-    assert len(self.transformer_blocks) == len(kv_cache.caches), (
-        "The number of transformer blocks and the number of KV cache entries"
-        " must be the same."
-    )
-
-    x = input_embeds
-
-    if pixel_mask is not None:
-      pixel_mask = pixel_mask.index_select(2, input_pos)
-      mask = [
-          self.compose_mask(
-              mask[i],
-              pixel_mask,
-              self.config.block_config(i).attn_config.attn_type,
-          )
-          for i in range(self.config.num_layers)
-      ]
-    updated_kv_entries = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope[i], mask[i], input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
-    if export_config is not None:
-      if (
-          torch.numel(input_pos) > 1
-          and not export_config.output_logits_on_prefill
-      ):
-        return {"kv_cache": updated_kv_cache}
-
-    x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-
-    return {"logits": res, "kv_cache": updated_kv_cache}
-
-
-def get_decoder_config_4b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma3 4B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 2048.
-
-  Returns:
-    The model config for a Gemma 4B model.
-  """
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=4*2560,
-      pre_ff_norm_config=norm_config,
-      post_ff_norm_config=norm_config,
-  )
-
-  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
-    attn_config = cfg.AttentionConfig(
-        num_heads=8,
-        head_dim=256,
-        num_query_groups=4,
-        rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
-        rotary_percentage=1.0,
-        qkv_transpose_before_split=True,
-        query_norm_config=norm_config,
-        key_norm_config=norm_config,
-        logit_softcap=None,
-        sliding_window_size=1024,
-        attn_type=(
-            cfg.AttentionType.GLOBAL
-            if (idx + 1) % 6 == 0
-            else cfg.AttentionType.LOCAL_SLIDING
-        ),
-    )
-    return cfg.TransformerBlockConfig(
-        attn_config=attn_config,
-        ff_config=ff_config,
-        pre_attention_norm_config=norm_config,
-        post_attention_norm_config=norm_config,
-    )
-
-  num_layers = 34
-  embedding_dim = 2560
-  config = cfg.ModelConfig(
-      vocab_size=262_144,
-      num_layers=num_layers,
-      max_seq_len=32_768,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
-      block_configs=[get_block_config(i) for i in range(num_layers)],
-      final_norm_config=norm_config,
-      lm_head_use_bias=False,
-      enable_hlfb=True,
-      final_logit_softcap=None,
-  )
-  return config
-
-def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma3 1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 2048.
-
-  Returns:
-    The model config for a Gemma 1B model.
-  """
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=6*1152,
-      pre_ff_norm_config=norm_config,
-      post_ff_norm_config=norm_config,
-  )
-
-  def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
-    attn_config = cfg.AttentionConfig(
-        num_heads=4,
-        head_dim=256,
-        num_query_groups=1,
-        rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
-        rotary_percentage=1.0,
-        qkv_transpose_before_split=True,
-        query_norm_config=norm_config,
-        key_norm_config=norm_config,
-        logit_softcap=None,
-        sliding_window_size=512,
-        attn_type=(
-            cfg.AttentionType.GLOBAL
-            if (idx + 1) % 6 == 0
-            else cfg.AttentionType.LOCAL_SLIDING
-        ),
-    )
-    return cfg.TransformerBlockConfig(
-        attn_config=attn_config,
-        ff_config=ff_config,
-        pre_attention_norm_config=norm_config,
-        post_attention_norm_config=norm_config,
-    )
-
-  num_layers = 26
-  embedding_dim = 1152
-  config = cfg.ModelConfig(
-      vocab_size=262_144,
-      num_layers=num_layers,
-      max_seq_len=32_768,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
-      kv_cache_max_len=kv_cache_max_len,
-      block_configs=[get_block_config(i) for i in range(num_layers)],
-      final_norm_config=norm_config,
-      lm_head_use_bias=False,
-      enable_hlfb=True,
-      final_logit_softcap=None,
-  )
-  return config
-
-
-def get_fake_decoder_config_4b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  """Returns a fake model config for a Gemma3 4B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 128.
-
-  Returns:
-    A fake model config for a Gemma 4B model.
-  """
-  config = get_decoder_config_4b(kv_cache_max_len)
-  config.vocab_size = 128
-  config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
-  config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
-  config.block_configs = config.block_configs[: config.num_layers]
-  for block_config in config.block_configs:
-    block_config.attn_config.num_heads = 4
-    block_config.attn_config.head_dim = 64
-    block_config.attn_config.sliding_window_size = 64
-    block_config.ff_config.intermediate_size = 128
-  return config
-
-def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  """Returns a fake model config for a Gemma3 1B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 128.
-
-  Returns:
-    A fake model config for a Gemma 1B model.
-  """
-  config = get_decoder_config_1b(kv_cache_max_len)
-  config.vocab_size = 128
-  config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
-  config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
-  config.block_configs = config.block_configs[: config.num_layers]
-  for block_config in config.block_configs:
-    block_config.attn_config.num_heads = 4
-    block_config.attn_config.head_dim = 64
-    block_config.attn_config.sliding_window_size = 64
-    block_config.ff_config.intermediate_size = 128
-  return config
-
-
-def build_model_4b(checkpoint_path: str, **kwargs) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_decoder_config_4b(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Decoder,
-  )
-
-def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
-  return model_builder.build_decoder_only_model(
-      checkpoint_path=checkpoint_path,
-      config=get_decoder_config_1b(**kwargs),
-      tensor_names=TENSOR_NAMES,
-      model_class=Decoder,
-  )
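A note on the removed decoder's mask handling (the retained gemma3/decoder.py appears to keep the same logic, since its diff above only touches tensor naming and model building): compose_mask treats the additive attention mask as booleans (0 means "attend", -inf means "blocked"), ANDs it with the pixel mask for LOCAL_SLIDING layers, ORs it for GLOBAL layers, then converts back to additive form. A toy demonstration with reduced shapes:

    import torch

    additive = torch.tensor([[0.0, float("-inf")], [0.0, 0.0]])
    pixel = torch.tensor([[True, False], [True, True]])

    allowed = additive == 0                     # True where attention is allowed
    local = torch.logical_and(allowed, pixel)   # LOCAL_SLIDING composition
    global_ = torch.logical_or(allowed, pixel)  # GLOBAL composition
    print(torch.where(local, 0.0, float("-inf")))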
ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py
DELETED
@@ -1,212 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Example of building a Gemma3 model."""
-
-from dataclasses import dataclass
-from typing import List, Optional, Tuple
-import xmlrpc
-
-from ai_edge_torch.generative.examples.gemma3 import decoder
-from ai_edge_torch.generative.examples.gemma3 import image_encoder
-from ai_edge_torch.generative.layers import builder
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import ai_edge_torch.generative.layers.model_config as cfg
-from ai_edge_torch.generative.utilities import model_builder
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import torch
-from torch import nn
-
-
-PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
-
-@dataclass
-class Gemma3MMConfig:
-  """Gemma3 model configurations."""
-
-  image_encoder_config: cfg.ModelConfig
-  decoder_config: cfg.ModelConfig
-  mm_norm_config: cfg.NormalizationConfig
-  mm_extra_tokens: int
-  image_token_id: int
-  image_projection_scale: float
-  image_projection_use_bias: bool = False
-
-class Gemma3MM(nn.Module):
-  """A Gemma3 multimodal model built from the Edge Generative API layers."""
-
-  def __init__(self, config: Gemma3MMConfig):
-    super().__init__()
-
-    self.image_encoder = image_encoder.SiglipVisionEncoderWithExit(
-        config.image_encoder_config)
-    self.decoder = decoder.Decoder(config.decoder_config)
-    self.mm_norm = builder.build_norm(
-        config.image_encoder_config.embedding_dim,
-        config.mm_norm_config,
-    )
-    self.extra_embedding = nn.Embedding(
-        config.mm_extra_tokens, config.image_encoder_config.embedding_dim)
-    self.image_projection = nn.Linear(
-        config.image_encoder_config.embedding_dim,
-        config.decoder_config.embedding_dim,
-        bias=config.image_projection_use_bias,
-    )
-    image_embedding_config = config.image_encoder_config.image_embedding
-    self.num_patches = (
-        image_embedding_config.image_size // image_embedding_config.patch_size
-    ) ** 2
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      image_indices: Optional[torch.Tensor] = None,
-      image_feat_indices: Optional[torch.Tensor] = None,
-      pixel_values: torch.Tensor = None,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.decoder_config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.decoder_config.max_seq_len}"
-    )
-    if pixel_values is None:
-      return self.decoder(tokens=tokens,
-                          input_pos=input_pos,
-                          kv_cache=kv_cache,
-                          input_embeds=None,
-                          export_config=export_config,
-      )
-    vocab_size = self.config.decoder_config.vocab_size
-    input_embeds = self.decoder.tok_embedding(torch.clip(tokens, 0,
-                                                         vocab_size - 1))
-    if self.decoder.config.embedding_scale is not None:
-      input_embeds = input_embeds * self.decoder.config.embedding_scale
-
-    # TODO: Identify embedding path for hard tokens if required.
-    # extra_embeds = self.extra_embedding(
-    #     torch.clip(tokens - vocab_size, 0, self.config.mm_extra_tokens - 1)
-    # )
-    # extra_embeds = self.image_projection(extra_embeds)
-    # input_embeds = torch.where(tokens < self.config.decoder_config.vocab_size,
-    #                            input_embeds, extra_embeds)
-    # alternate method of implementation
-    # rows, cols = torch.where(tokens >= self.config.vocab_size)
-    # ext_embeds = self.ext_embedding(
-    #     tokens[rows, cols] - self.config.vocab_size
-    # )
-    # ext_embeds = self.mm_projection(extra_embeds)
-    # input_embeds[rows, cols, :] = extra_embeds
-
-    # Shape of pixel_values: (b, n, c, h, w)
-    batch_size, num_media, c, h, w = pixel_values.size()
-    pixel_values = pixel_values.view(-1, c, h, w)
-    image_encoded = self.image_encoder(pixel_values=pixel_values)
-    image_encoded = self.mm_norm(image_encoded)
-    image_encoded = self.image_projection(image_encoded)
-    _, num_patches, num_channels = image_encoded.size()
-    image_encoded = image_encoded.view(
-        batch_size, num_media, num_patches, num_channels
-    )
-
-    # Interleave the image soft embeddings with the text embeddings
-    for b in range(tokens.shape[0]):
-      unbatched_image_encoded = image_encoded[b]
-      image_features = unbatched_image_encoded[
-          image_indices[b], image_feat_indices[b]
-      ]
-      index_to_copy = torch.where(image_indices[b] >= 0)[0]
-      input_embeds[b] = torch.index_copy(input_embeds[b], 0, index_to_copy,
-                                         image_features[index_to_copy])
-    return self.decoder(
-        tokens=None,
-        input_pos=input_pos,
-        kv_cache=kv_cache,
-        input_embeds=input_embeds,
-        image_indices=image_indices,
-        export_config=export_config,
-    )
-
-
-def get_model_config_4b(**kwargs) -> Gemma3MMConfig:
-  return Gemma3MMConfig(
-      image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config_4b(),
-      image_token_id=257152,  # TODO: confirm
-      image_projection_scale=2048**0.5,
-      image_projection_use_bias=False,
-      mm_norm_config=cfg.NormalizationConfig(
-          type=cfg.NormalizationType.LAYER_NORM,
-          epsilon=1e-6,
-          enable_hlfb=True,
-      ),
-      mm_extra_tokens=128,
-  )
-
-
-def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
-  return Gemma3MMConfig(
-      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
-      decoder_config=decoder.get_fake_decoder_config_4b(**kwargs),
-      image_token_id=127,
-      image_projection_scale=128**0.5,
-      image_projection_use_bias=False,
-      mm_norm_config=cfg.NormalizationConfig(
-          type=cfg.NormalizationType.LAYER_NORM,
-          epsilon=1e-6,
-          enable_hlfb=True,
-      ),
-      mm_extra_tokens=32,
-  )
-
-
-def build_model_4b(checkpoint_path: str, **kwargs) -> Gemma3MM:
-
-  decoder_tensor_names = decoder.TENSOR_NAMES
-
-  config = get_model_config_4b(**kwargs)
-  model = Gemma3MM(config)
-  # # TODO: Load the parameters of image encoder from checkpoint mapping Tensor names properly.
-  # loader = loading_utils.ModelLoader(
-  #     checkpoint_path, image_encoder.TENSOR_NAMES
-  # )
-  # loader.load(model.image_encoder, strict=False)
-  # # Load the parameters of decoder.
-  loader = loading_utils.ModelLoader(checkpoint_path, decoder_tensor_names)
-  loader.load(model.decoder, strict=False)
-
-  # # Load the parameters of image projection.
-  # loader = loading_utils.ModelLoader(checkpoint_path, None)
-  # state = loader.get_state()
-  # converted_state = dict()
-  # converted_state["weight"] = state.pop(f"{PROJECTION_TENSOR_NAME}.weight")
-  # if config.image_projection_use_bias:
-  #   converted_state["bias"] = state.pop(f"{PROJECTION_TENSOR_NAME}.bias")
-  # model.image_projection.load_state_dict(converted_state)
-  model.eval()
-  return model
-
-def build_model_1b(checkpoint_path: str, **kwargs) -> decoder.Decoder:
-  if checkpoint_path:
-    model = decoder.build_model_1b(checkpoint_path, **kwargs)
-  else:
-    config = decoder.get_decoder_config_1b(**kwargs)
-    model = decoder.Decoder(config)
-  model.eval()
-  return model
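The interleaving step in the removed Gemma3MM.forward scatters projected image features into the text-embedding sequence at every position whose image_indices entry is non-negative. A toy reproduction of that index_copy step (shapes invented for illustration):

    import torch

    embeds = torch.zeros(5, 4)      # (seq_len, channels) text embeddings
    image_feats = torch.ones(5, 4)  # projected image features, same layout
    image_indices = torch.tensor([-1, 0, 0, -1, -1])

    # Rows holding image tokens receive the corresponding image features.
    index_to_copy = torch.where(image_indices >= 0)[0]
    embeds = torch.index_copy(embeds, 0, index_to_copy,
                              image_feats[index_to_copy])
    print(embeds)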
/ai_edge_torch/generative/examples/gemma3/{cpu_only/image_encoder.py → image_encoder.py}
File without changes
{ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/LICENSE
File without changes
{ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/WHEEL
File without changes
{ai_edge_torch_nightly-0.4.0.dev20250314.dist-info → ai_edge_torch_nightly-0.4.0.dev20250315.dist-info}/top_level.txt
File without changes