keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- keras_hub/__init__.py +15 -33
- keras_hub/layers/__init__.py +134 -0
- keras_hub/metrics/__init__.py +11 -0
- keras_hub/models/__init__.py +642 -0
- keras_hub/samplers/__init__.py +18 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
- keras_hub/src/layers/preprocessing/image_converter.py +1 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
- keras_hub/src/layers/preprocessing/random_swap.py +1 -1
- keras_hub/src/models/audio_to_text.py +66 -0
- keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
- keras_hub/src/models/backbone.py +5 -2
- keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
- keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -1
- keras_hub/src/models/gemma/gemma_presets.py +10 -10
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
- keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/llama/llama_attention.py +24 -6
- keras_hub/src/models/llama/llama_backbone.py +50 -16
- keras_hub/src/models/llama/llama_decoder.py +20 -3
- keras_hub/src/models/llama/llama_presets.py +3 -3
- keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
- keras_hub/src/models/llama3/llama3_backbone.py +10 -2
- keras_hub/src/models/llama3/llama3_presets.py +84 -2
- keras_hub/src/models/mistral/mistral_presets.py +3 -3
- keras_hub/src/models/mixtral/__init__.py +5 -0
- keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
- keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
- keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
- keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
- keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
- keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
- keras_hub/src/models/moonshine/__init__.py +5 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
- keras_hub/src/models/qwen/__init__.py +4 -0
- keras_hub/src/models/qwen/qwen_attention.py +3 -1
- keras_hub/src/models/qwen/qwen_backbone.py +8 -1
- keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
- keras_hub/src/models/qwen/qwen_presets.py +61 -0
- keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
- keras_hub/src/models/qwen_moe/__init__.py +5 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
- keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
- keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
- keras_hub/src/models/segformer/segformer_presets.py +12 -12
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
- keras_hub/src/models/task.py +5 -2
- keras_hub/src/models/xception/__init__.py +5 -0
- keras_hub/src/models/xception/xception_backbone.py +188 -0
- keras_hub/src/models/xception/xception_image_classifier.py +12 -0
- keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/xception/xception_image_converter.py +8 -0
- keras_hub/src/models/xception/xception_presets.py +14 -0
- keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
- keras_hub/src/utils/coco/__init__.py +0 -0
- keras_hub/src/utils/coco/coco_utils.py +133 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
- keras_hub/src/utils/keras_utils.py +11 -0
- keras_hub/src/utils/preset_utils.py +70 -10
- keras_hub/src/utils/tensor_utils.py +27 -1
- keras_hub/src/utils/timm/convert_cspnet.py +94 -23
- keras_hub/src/utils/timm/preset_loader.py +6 -6
- keras_hub/src/utils/transformers/convert_llama3.py +21 -1
- keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
- keras_hub/src/utils/transformers/convert_qwen.py +1 -0
- keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/{version_utils.py → version.py} +1 -1
- keras_hub/tokenizers/__init__.py +117 -0
- keras_hub/utils/__init__.py +21 -0
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
- keras_hub/api/__init__.py +0 -15
- keras_hub/api/layers/__init__.py +0 -86
- keras_hub/api/metrics/__init__.py +0 -11
- keras_hub/api/models/__init__.py +0 -416
- keras_hub/api/samplers/__init__.py +0 -16
- keras_hub/api/tokenizers/__init__.py +0 -58
- keras_hub/api/utils/__init__.py +0 -9
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/coco/coco_utils.py (new file):

```diff
@@ -0,0 +1,133 @@
+from keras_hub.src.api_export import keras_hub_export
+
+
+@keras_hub_export("keras_hub.utils.coco_id_to_name")
+def coco_id_to_name(id):
+    """Convert a single COCO class ID to a class name.
+
+    Args:
+        id: An integer class id from 0 to 91.
+
+    Returns:
+        The human readable image class name, e.g. "bicycle".
+
+    Example:
+    >>> keras_hub.utils.coco_id_to_name(2)
+    'bicycle'
+    """
+    return COCO_NAMES[id]
+
+
+@keras_hub_export("keras_hub.utils.coco_name_to_id")
+def coco_name_to_id(name):
+    """Convert a single COCO class name to a class ID.
+
+    Args:
+        name: A human readable image class name, e.g. "bicycle".
+
+    Returns:
+        The integer class id from 0 to 91.
+
+    Example:
+    >>> keras_hub.utils.coco_name_to_id("bicycle")
+    2
+    """
+    return COCO_IDS[name]
+
+
+COCO_NAMES = {
+    0: "unlabeled",
+    1: "person",
+    2: "bicycle",
+    3: "car",
+    4: "motorcycle",
+    5: "airplane",
+    6: "bus",
+    7: "train",
+    8: "truck",
+    9: "boat",
+    10: "traffic_light",
+    11: "fire_hydrant",
+    12: "street_sign",
+    13: "stop_sign",
+    14: "parking_meter",
+    15: "bench",
+    16: "bird",
+    17: "cat",
+    18: "dog",
+    19: "horse",
+    20: "sheep",
+    21: "cow",
+    22: "elephant",
+    23: "bear",
+    24: "zebra",
+    25: "giraffe",
+    26: "hat",
+    27: "backpack",
+    28: "umbrella",
+    29: "shoe",
+    30: "eye_glasses",
+    31: "handbag",
+    32: "tie",
+    33: "suitcase",
+    34: "frisbee",
+    35: "skis",
+    36: "snowboard",
+    37: "sports_ball",
+    38: "kite",
+    39: "baseball_bat",
+    40: "baseball_glove",
+    41: "skateboard",
+    42: "surfboard",
+    43: "tennis_racket",
+    44: "bottle",
+    45: "plate",
+    46: "wine_glass",
+    47: "cup",
+    48: "fork",
+    49: "knife",
+    50: "spoon",
+    51: "bowl",
+    52: "banana",
+    53: "apple",
+    54: "sandwich",
+    55: "orange",
+    56: "broccoli",
+    57: "carrot",
+    58: "hot_dog",
+    59: "pizza",
+    60: "donut",
+    61: "cake",
+    62: "chair",
+    63: "couch",
+    64: "potted_plant",
+    65: "bed",
+    66: "mirror",
+    67: "dining_table",
+    68: "window",
+    69: "desk",
+    70: "toilet",
+    71: "door",
+    72: "tv",
+    73: "laptop",
+    74: "mouse",
+    75: "remote",
+    76: "keyboard",
+    77: "cell_phone",
+    78: "microwave",
+    79: "oven",
+    80: "toaster",
+    81: "sink",
+    82: "refrigerator",
+    83: "blender",
+    84: "book",
+    85: "clock",
+    86: "vase",
+    87: "scissors",
+    88: "teddy_bear",
+    89: "hair_drier",
+    90: "toothbrush",
+    91: "hair_brush",
+}
+
+COCO_IDS = {v: k for k, v in COCO_NAMES.items()}
```
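The two new helpers are simple inverse lookups over this table. A minimal usage sketch, assuming the 0.21.0.dev1 wheel is installed (values taken from the docstrings above):

```python
import keras_hub

# IDs and names round-trip through the two lookup tables.
assert keras_hub.utils.coco_id_to_name(2) == "bicycle"
assert keras_hub.utils.coco_name_to_id("bicycle") == 2
```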
keras_hub/src/utils/imagenet/imagenet_utils.py:

```diff
@@ -3,6 +3,40 @@ from keras import ops
 from keras_hub.src.api_export import keras_hub_export
 
 
+@keras_hub_export("keras_hub.utils.imagenet_id_to_name")
+def imagenet_id_to_name(id):
+    """Convert a single ImageNet class ID to a class name.
+
+    Args:
+        id: An integer class id from 0 to 999.
+
+    Returns:
+        The human readable image class name, e.g. "goldfish".
+
+    Example:
+    >>> keras_hub.utils.imagenet_id_to_name(1)
+    'goldfish'
+    """
+    return IMAGENET_NAMES[id][1]
+
+
+@keras_hub_export("keras_hub.utils.imagenet_name_to_id")
+def imagenet_name_to_id(name):
+    """Convert a single ImageNet class name to a class ID.
+
+    Args:
+        name: A human readable image class name, e.g. "goldfish".
+
+    Returns:
+        The integer class id from 0 to 999.
+
+    Example:
+    >>> keras_hub.utils.imagenet_name_to_id("goldfish")
+    1
+    """
+    return IMAGENET_IDS[name]
+
+
 @keras_hub_export("keras_hub.utils.decode_imagenet_predictions")
 def decode_imagenet_predictions(preds, top=5, include_synset_ids=False):
     """Decodes the predictions for an ImageNet-1k prediction.
@@ -1052,3 +1086,5 @@ IMAGENET_NAMES = {
     998: ("n13133613", "ear"),
     999: ("n15075141", "toilet_tissue"),
 }
+
+IMAGENET_IDS = {v[1]: k for k, v in IMAGENET_NAMES.items()}
```
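As with the COCO helpers, these are inverse lookups, keyed on the human readable name rather than the synset ID (`IMAGENET_NAMES` maps id to a `(synset_id, name)` tuple, so the id lookup returns element `[1]`):

```python
import keras_hub

assert keras_hub.utils.imagenet_id_to_name(1) == "goldfish"
assert keras_hub.utils.imagenet_name_to_id("goldfish") == 1
```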
keras_hub/src/utils/keras_utils.py:

```diff
@@ -1,3 +1,4 @@
+import inspect
 import sys
 
 import keras
@@ -147,3 +148,13 @@ def get_gpu_names():
         ]
     else:
         return [""]
+
+
+def sharded_weights_available():
+    """Whether sharded weights serialization is available.
+
+    Returns:
+        `True` if sharded weights are available, `False` otherwise.
+    """
+    save_weights_signature = inspect.signature(keras.saving.save_weights)
+    return "max_shard_size" in save_weights_signature.parameters
```
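The helper feature-detects sharding support by introspecting the installed Keras rather than comparing version strings: sharded serialization exists exactly when `keras.saving.save_weights` accepts a `max_shard_size` parameter. The same check can be reproduced standalone:

```python
import inspect

import keras

# True on Keras versions whose `save_weights` accepts `max_shard_size`.
params = inspect.signature(keras.saving.save_weights).parameters
print("sharded weights supported:", "max_shard_size" in params)
```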
keras_hub/src/utils/preset_utils.py:

```diff
@@ -10,6 +10,8 @@ from absl import logging
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.utils.keras_utils import print_msg
+from keras_hub.src.utils.keras_utils import sharded_weights_available
+from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
 
 try:
     import kagglehub
@@ -48,6 +50,7 @@ METADATA_FILE = "metadata.json"
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
 
 # HuggingFace filenames.
 README_FILE = "README.md"
@@ -647,7 +650,7 @@ class KerasPresetLoader(PresetLoader):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ class KerasPresetLoader(PresetLoader):
             task.load_task_weights(task_weights)
         else:
             jax_memory_cleanup(task.backbone)
-            backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-            task.backbone.load_weights(backbone_weights)
+            self._load_backbone_weights(task.backbone)
         return task
 
     def load_preprocessor(
@@ -726,18 +728,64 @@ class KerasPresetLoader(PresetLoader):
             config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)
 
+    def _get_sharded_filenames(self, config_path):
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+        weight_map = config["weight_map"]
+        return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        # Detect if the backbone is sharded or not.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            if not sharded_weights_available():
+                raise RuntimeError(
+                    "Sharded weights loading is not supported in the current "
+                    f"Keras version {keras.__version__}. "
+                    "Please update to a newer version."
+                )
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir
 
-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=10):
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)
 
+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # If the size of the backbone is larger than `max_shard_size`, save
+        # sharded weights.
+        if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
@@ -755,7 +803,7 @@ class KerasPresetSaver:
     def save_image_converter(self, converter):
         self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
 
-    def save_task(self, task):
+    def save_task(self, task, max_shard_size=10):
         # Save task specific config and weights.
         self._save_serialized_object(task, TASK_CONFIG_FILE)
         if task.has_task_weights():
@@ -763,10 +811,12 @@
             task.save_task_weights(task_weight_path)
         # Save backbone.
         if hasattr(task.backbone, "save_to_preset"):
-            task.backbone.save_to_preset(self.preset_dir)
+            task.backbone.save_to_preset(
+                self.preset_dir, max_shard_size=max_shard_size
+            )
         else:
             # Allow saving a `keras.Model` that is not a backbone subclass.
-            self.save_backbone(task.backbone)
+            self.save_backbone(task.backbone, max_shard_size=max_shard_size)
         # Save preprocessor.
         if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
             task.preprocessor.save_to_preset(self.preset_dir)
@@ -801,7 +851,7 @@ class KerasPresetSaver:
 
     def _save_metadata(self, layer):
         from keras_hub.src.models.task import Task
-        from keras_hub.src.version_utils import __version__ as keras_hub_version
+        from keras_hub.src.version import __version__ as keras_hub_version
 
         # Find all tasks that are compatible with the backbone.
         # E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
@@ -823,3 +873,13 @@ class KerasPresetSaver:
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += get_tensor_size_in_bits(shape, dtype)
+        return total_memory_size / 8
```
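Net effect: saving a preset now shards large backbones, and loading transparently handles either layout. A sketch of how this surfaces in the public API (the preset paths are placeholders; `max_shard_size` is in gigabytes, per the `1024**3` division above):

```python
import keras_hub

# Hypothetical local preset path, for illustration only.
backbone = keras_hub.models.Backbone.from_preset("./some_preset")

# Backbones whose variables exceed 10 GB are written as shards plus a
# `model.weights.json` index; smaller ones keep a single `model.weights.h5`.
backbone.save_to_preset("./exported_preset", max_shard_size=10)

# Loading detects which layout is present and downloads every shard.
restored = keras_hub.models.Backbone.from_preset("./exported_preset")
```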
keras_hub/src/utils/tensor_utils.py:

```diff
@@ -1,6 +1,8 @@
 import contextlib
 import functools
 import inspect
+import math
+import re
 import threading
 
 import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
     return "string" in keras.backend.standardize_dtype(dtype)
 
 
+def get_dtype_size_in_bits(dtype):
+    """Get the size of a given dtype in bits."""
+    dtype = keras.backend.standardize_dtype(dtype)
+    # If dtype is bool, return 1 immediately.
+    if dtype == "bool":
+        return 1
+    # Else, we extract the bit size from the string.
+    return int(re.sub(r"bfloat|float|uint|int", "", dtype))
+
+
+def get_tensor_size_in_bits(shape, dtype):
+    """Calculate the size given dtype and shape in bits.
+
+    Args:
+        shape: List of integers representing the shape of the tensor.
+        dtype: The dtype of the tensor.
+
+    Returns:
+        The size of the tensor in bits.
+    """
+    return math.prod(shape) * get_dtype_size_in_bits(dtype)
+
+
 def any_equal(inputs, values, padding_mask):
     """Return a mask that is True anywhere `inputs` has a value in `values`.
 
@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
     Returns:
         A tensor with `inputs` shape where each position is True if it contains
         a value from any `values`. Padding mask will be applied before
-        returning."""
+        returning.
+    """
     output = ops.equal(inputs, values[0])
     for value in values[1:]:
         value_equality = ops.equal(inputs, value)
```
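A quick sanity check of the arithmetic: a float32 tensor of shape (1024, 1024) has 2^20 elements at 32 bits each, i.e. 4 MiB.

```python
from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits

bits = get_tensor_size_in_bits((1024, 1024), "float32")
print(bits / 8 / 1024**2)  # 4.0 (MiB)
```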
keras_hub/src/utils/timm/convert_cspnet.py:

```diff
@@ -17,10 +17,69 @@ def convert_backbone_config(timm_config):
         bottle_ratio = (0.5,) + (1.0,)
         block_ratio = (1.0,) + (0.5,)
         expand_ratio = (2.0,) + (1.0,)
+        stem_padding = "same"
+        stem_pooling = None
         stage_type = "csp"
+        groups = 1
         block_type = "dark_block"
         down_growth = True
-        stackwise_strides = 2
+        stackwise_strides = [2, 2, 2, 2, 2]
+        avg_down = False
+        cross_linear = False
+    elif timm_architecture == "cspresnet50":
+        stem_filters = 64
+        stem_kernel_size = 7
+        stem_strides = 4
+        stackwise_depth = [3, 3, 5, 2]
+        stackwise_strides = [1, 2, 2, 2]
+        stackwise_num_filters = [128, 256, 512, 1024]
+        block_type = "bottleneck_block"
+        stage_type = "csp"
+        bottle_ratio = [0.5]
+        block_ratio = [1.0]
+        expand_ratio = [2.0]
+        stem_padding = "valid"
+        stem_pooling = "max"
+        avg_down = False
+        groups = 1
+        down_growth = False
+        cross_linear = True
+    elif timm_architecture == "cspresnext50":
+        stem_filters = 64
+        stem_kernel_size = 7
+        stem_strides = 4
+        stackwise_depth = [3, 3, 5, 2]
+        stackwise_num_filters = [256, 512, 1024, 2048]
+        bottle_ratio = [1.0]
+        block_ratio = [0.5]
+        expand_ratio = [1.0]
+        stage_type = "csp"
+        block_type = "bottleneck_block"
+        stem_pooling = "max"
+        stackwise_strides = [1, 2, 2, 2]
+        groups = 32
+        stem_padding = "valid"
+        avg_down = False
+        down_growth = False
+        cross_linear = True
+    elif timm_architecture == "darknet53":
+        stem_filters = 32
+        stem_kernel_size = 3
+        stem_strides = 1
+        stackwise_depth = [1, 2, 8, 8, 4]
+        stackwise_num_filters = [64, 128, 256, 512, 1024]
+        bottle_ratio = [0.5]
+        block_ratio = [1.0]
+        groups = 1
+        expand_ratio = [1.0]
+        stage_type = "dark"
+        block_type = "dark_block"
+        stem_pooling = None
+        stackwise_strides = [2, 2, 2, 2, 2]
+        stem_padding = "same"
+        avg_down = False
+        down_growth = False
+        cross_linear = False
     else:
         raise ValueError(
             f"Currently, the architecture {timm_architecture} is not supported."
@@ -38,6 +97,11 @@ def convert_backbone_config(timm_config):
         block_type=block_type,
         stackwise_strides=stackwise_strides,
         down_growth=down_growth,
+        stem_pooling=stem_pooling,
+        stem_padding=stem_padding,
+        avg_down=avg_down,
+        cross_linear=cross_linear,
+        groups=groups,
     )
 
 
@@ -81,21 +145,36 @@ def convert_weights(backbone, loader, timm_config):
     stackwise_depth = backbone.stackwise_depth
     stage_type = backbone.stage_type
     block_type = backbone.block_type
+    strides = backbone.stackwise_strides
 
     for idx, block in enumerate(stackwise_depth):
-        [13 lines removed; their content is not preserved in this diff view]
+        if strides[idx] != 1 or stage_type == "dark":
+            if strides[idx] == 2 and backbone.avg_down:
+                port_conv2d(
+                    f"stages.{idx}.conv_down.1.conv",
+                    f"stage_{idx}_{stage_type}_conv_down_1",
+                )
+                port_batch_normalization(
+                    f"stages.{idx}.conv_down.1.bn",
+                    f"stage_{idx}_{stage_type}_bn_1",
+                )
+            else:
+                port_conv2d(
+                    f"stages.{idx}.conv_down.conv",
+                    f"stage_{idx}_{stage_type}_conv_down_1",
+                )
+                port_batch_normalization(
+                    f"stages.{idx}.conv_down.bn",
+                    f"stage_{idx}_{stage_type}_bn_1",
+                )
+        if stage_type != "dark":
+            port_conv2d(
+                f"stages.{idx}.conv_exp.conv",
+                f"stage_{idx}_{stage_type}_conv_exp",
+            )
+            port_batch_normalization(
+                f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2"
+            )
 
         for i in range(block):
             port_conv2d(
@@ -133,16 +212,8 @@ def convert_weights(backbone, loader, timm_config):
             f"stages.{idx}.conv_transition_b.bn",
             f"stage_{idx}_{stage_type}_transition_b_bn",
         )
-        port_conv2d(
-            f"stages.{idx}.conv_transition.conv",
-            f"stage_{idx}_{stage_type}_conv_transition",
-        )
-        port_batch_normalization(
-            f"stages.{idx}.conv_transition.bn",
-            f"stage_{idx}_{stage_type}_transition_bn",
-        )
 
-
+        if stage_type != "dark":
             port_conv2d(
                 f"stages.{idx}.conv_transition.conv",
                 f"stage_{idx}_{stage_type}_conv_transition",
```
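For reference, the new `darknet53` branch corresponds roughly to the following backbone construction. This is a sketch only: the argument names are copied from the converter above and have not been verified against the `CSPNetBackbone` signature.

```python
import keras_hub

# Values copied from the darknet53 branch of convert_backbone_config above.
backbone = keras_hub.models.CSPNetBackbone(
    stem_filters=32,
    stem_kernel_size=3,
    stem_strides=1,
    stackwise_depth=[1, 2, 8, 8, 4],
    stackwise_num_filters=[64, 128, 256, 512, 1024],
    stackwise_strides=[2, 2, 2, 2, 2],
    stage_type="dark",
    block_type="dark_block",
    bottle_ratio=[0.5],
    block_ratio=[1.0],
    expand_ratio=[1.0],
    groups=1,
    stem_pooling=None,
    stem_padding="same",
    avg_down=False,
    down_growth=False,
    cross_linear=False,
)
```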
keras_hub/src/utils/timm/preset_loader.py:

```diff
@@ -16,17 +16,17 @@ class TimmPresetLoader(PresetLoader):
     def __init__(self, preset, config):
         super().__init__(preset, config)
         architecture = self.config["architecture"]
-        if "resnet" in architecture:
+        if architecture.startswith("resnet"):
             self.converter = convert_resnet
-        elif "csp" in architecture:
+        elif architecture.startswith(("csp", "dark")):
             self.converter = convert_cspnet
-        elif "densenet" in architecture:
+        elif architecture.startswith("densenet"):
             self.converter = convert_densenet
-        elif "mobilenet" in architecture:
+        elif architecture.startswith("mobilenet"):
             self.converter = convert_mobilenet
-        elif "vgg" in architecture:
+        elif architecture.startswith("vgg"):
             self.converter = convert_vgg
-        elif "efficientnet" in architecture:
+        elif architecture.startswith("efficientnet"):
             self.converter = convert_efficientnet
         else:
             raise ValueError(
```
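With `("csp", "dark")` added to the dispatch, darknet checkpoints from timm now route through `convert_cspnet`. Loading one straight from the Hugging Face Hub should look like this (the model handle is illustrative):

```python
import keras_hub

# `hf://timm/...` presets are converted on the fly by TimmPresetLoader.
model = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/darknet53.c2ns_in1k"
)
```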
keras_hub/src/utils/transformers/convert_llama3.py:

```diff
@@ -7,7 +7,7 @@ backbone_cls = Llama3Backbone
 
 
 def convert_backbone_config(transformers_config):
-    return {
+    backbone_config = {
         "vocabulary_size": transformers_config["vocab_size"],
         "num_layers": transformers_config["num_hidden_layers"],
         "num_query_heads": transformers_config["num_attention_heads"],
@@ -15,8 +15,28 @@ def convert_backbone_config(transformers_config):
         "intermediate_dim": transformers_config["intermediate_size"],
         "num_key_value_heads": transformers_config["num_key_value_heads"],
         "tie_word_embeddings": transformers_config["tie_word_embeddings"],
+        "rope_max_wavelength": transformers_config["rope_theta"],
     }
 
+    if transformers_config.get("rope_scaling", None) is not None:
+        if transformers_config["rope_scaling"]["rope_type"] != "llama3":
+            raise ValueError("The config should be a valid llama3 config.")
+        backbone_config["rope_frequency_adjustment_factor"] = (
+            transformers_config["rope_scaling"]["factor"]
+        )
+        backbone_config["rope_low_freq_factor"] = transformers_config[
+            "rope_scaling"
+        ]["low_freq_factor"]
+        backbone_config["rope_high_freq_factor"] = transformers_config[
+            "rope_scaling"
+        ]["high_freq_factor"]
+        backbone_config["rope_pretraining_sequence_length"] = (
+            transformers_config["rope_scaling"][
+                "original_max_position_embeddings"
+            ]
+        )
+    return backbone_config
+
 
 def convert_weights(backbone, loader, transformers_config):
     loader.port_weight(
```
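The converter now understands the Llama 3.1-style `rope_scaling` block from a Hugging Face `config.json`. The keys it reads are exactly these; the values shown are the usual Llama 3.1 defaults, included for illustration only:

```python
# Fragment of a Hugging Face Llama 3.1 `config.json`, as a Python dict.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
```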
|