keras-hub-nightly 0.22.0.dev202507150421__py3-none-any.whl → 0.22.0.dev202507170424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +3 -0
- keras_hub/src/models/clip/clip_backbone.py +3 -102
- keras_hub/src/models/clip/clip_layers.py +295 -0
- keras_hub/src/models/clip/clip_preprocessor.py +57 -48
- keras_hub/src/models/clip/clip_text_encoder.py +2 -2
- keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
- keras_hub/src/models/dinov2/__init__.py +5 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
- keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
- keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
- keras_hub/src/models/dinov2/dinov2_presets.py +4 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/hgnetv2/__init__.py +5 -0
- keras_hub/src/models/hgnetv2/hgnetv2_presets.py +5 -5
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
- keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +23 -32
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
- keras_hub/src/utils/preset_utils.py +4 -1
- keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
- keras_hub/src/utils/transformers/export/gemma.py +89 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
- keras_hub/src/utils/transformers/preset_loader.py +4 -1
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/RECORD +32 -25
- keras_hub/src/models/clip/clip_encoder_block.py +0 -111
- keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/flux/flux_text_to_image_preprocessor.py

@@ -43,9 +43,13 @@ class FluxTextToImagePreprocessor(Preprocessor):
 
     def generate_preprocess(self, x):
         token_ids = {}
-        token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
+        token_ids["clip_l"] = self.clip_l_preprocessor(
+            {"prompts": x, "images": None}
+        )["token_ids"]
         if self.t5_preprocessor is not None:
-            token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
+            token_ids["t5"] = self.t5_preprocessor(
+                {"prompts": x, "images": None}
+            )["token_ids"]
         return token_ids
 
     def get_config(self):
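
The CLIP and T5 sub-preprocessors are now called with an explicit {"prompts": ..., "images": None} dict instead of the raw prompt batch. A minimal sketch of the new calling convention, using a stub in place of the real CLIP preprocessor (all names below are illustrative, not taken from keras-hub):

import numpy as np

def stub_clip_preprocessor(inputs):
    # The real preprocessor tokenizes inputs["prompts"]; here we just fake ids.
    return {"token_ids": np.zeros((len(inputs["prompts"]), 77), dtype="int32")}

def generate_preprocess(prompts):
    token_ids = {}
    token_ids["clip_l"] = stub_clip_preprocessor(
        {"prompts": prompts, "images": None}  # the new input format
    )["token_ids"]
    return token_ids

print(generate_preprocess(["a photo of a red bicycle"])["clip_l"].shape)  # (1, 77)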

keras_hub/src/models/hgnetv2/hgnetv2_presets.py

@@ -9,7 +9,7 @@ backbone_presets = {
             "params": 13599072,
             "path": "hgnetv2",
         },
-        "kaggle_handle": "",
+        "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b4_ssld_stage2_ft_in1k/1",
     },
     "hgnetv2_b5_ssld_stage1_in22k_in1k": {
         "metadata": {
@@ -20,7 +20,7 @@ backbone_presets = {
             "params": 33419680,
             "path": "hgnetv2",
         },
-        "kaggle_handle": "",
+        "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b5_ssld_stage1_in22k_in1k/1",
     },
     "hgnetv2_b5_ssld_stage2_ft_in1k": {
         "metadata": {
@@ -31,7 +31,7 @@ backbone_presets = {
             "params": 33419680,
             "path": "hgnetv2",
         },
-        "kaggle_handle": "",
+        "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b5_ssld_stage2_ft_in1k/1",
     },
     "hgnetv2_b6_ssld_stage1_in22k_in1k": {
         "metadata": {
@@ -42,7 +42,7 @@ backbone_presets = {
             "params": 69179888,
             "path": "hgnetv2",
         },
-        "kaggle_handle": "",
+        "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b6_ssld_stage1_in22k_in1k/1",
     },
     "hgnetv2_b6_ssld_stage2_ft_in1k": {
         "metadata": {
@@ -53,6 +53,6 @@ backbone_presets = {
             "params": 69179888,
             "path": "hgnetv2",
         },
-        "kaggle_handle": "",
+        "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b6_ssld_stage2_ft_in1k/1",
     },
 }
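
With the Kaggle handles filled in, these HGNetV2 presets become loadable by name. A minimal sketch, assuming the preset download succeeds (the parameter count in the comment comes from the preset metadata above):

import keras_hub

backbone = keras_hub.models.Backbone.from_preset(
    "hgnetv2_b5_ssld_stage1_in22k_in1k"
)
print(backbone.count_params())  # ~33.4M parameters per the preset metadata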

keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py

@@ -38,7 +38,6 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
         timesteps = ops.flip(timesteps, axis=0)
         sigmas = self._timestep_to_sigma(timesteps)
 
-        self.timesteps = ops.multiply(sigmas, num_train_timesteps)
         self.sigma_min = sigmas[-1]
         self.sigma_max = sigmas[0]
 
@@ -54,14 +53,24 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
         )
         return sigma
 
+    def set_sigmas(self, num_steps):
+        timesteps = ops.linspace(
+            self._sigma_to_timestep(self.sigma_max),
+            self._sigma_to_timestep(self.sigma_min),
+            num_steps,
+        )
+        sigmas = self._timestep_to_sigma(timesteps)
+        sigmas = ops.concatenate([sigmas, ops.zeros((1,), dtype=sigmas.dtype)])
+        self.sigmas = sigmas
+
     def call(self, inputs, num_steps):
-
-
-
-
+        if not hasattr(self, "sigmas"):
+            self.set_sigmas(num_steps)
+
+        step = ops.expand_dims(
+            ops.convert_to_tensor(inputs, dtype="int32"), axis=0
         )
-
-        sigma = ops.maximum(self._timestep_to_sigma(timestep), 0.0)
+        sigma = ops.take(self.sigmas, step)
         timestep = self._sigma_to_timestep(sigma)
         return sigma, timestep
 
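
The scheduler now precomputes the full sigma table once, with a trailing zero so the final step resolves to sigma = 0, and then indexes into it per step instead of recomputing sigma from the timestep on every call. A minimal NumPy sketch of the same idea, using a simple linear timestep-to-sigma mapping as a stand-in for the real _timestep_to_sigma:

import numpy as np

NUM_TRAIN_TIMESTEPS = 1000  # illustrative; the real layer takes this as a constructor argument

def timestep_to_sigma(t):
    # Stand-in mapping; the real scheduler's mapping also involves a shift term.
    return t / NUM_TRAIN_TIMESTEPS

def set_sigmas(sigma_max, sigma_min, num_steps):
    timesteps = np.linspace(
        sigma_max * NUM_TRAIN_TIMESTEPS, sigma_min * NUM_TRAIN_TIMESTEPS, num_steps
    )
    sigmas = timestep_to_sigma(timesteps)
    # Append a final zero so step == num_steps maps to sigma = 0.
    return np.concatenate([sigmas, np.zeros((1,), dtype=sigmas.dtype)])

sigmas = set_sigmas(1.0, 1.0 / NUM_TRAIN_TIMESTEPS, num_steps=28)
sigma = sigmas[5]                        # table lookup replaces per-step recomputation
timestep = sigma * NUM_TRAIN_TIMESTEPS   # inverse of the stand-in mapping
print(sigma, timestep)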

keras_hub/src/models/stable_diffusion_3/mmdit.py

@@ -10,6 +10,63 @@ from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import gelu_approximate
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
+# TODO: Deprecate this in favor of
+# `keras.layers.RMSNormalization` once we require Keras 3.9 or later.
+if hasattr(layers, "RMSNormalization"):
+    RMSNormalization = layers.RMSNormalization
+else:
+
+    class RMSNormalization(layers.Layer):
+        """A normalization layer for MMDiT that implements RMS normalization."""
+
+        def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
+            super().__init__(**kwargs)
+            self.axis = axis
+            self.epsilon = epsilon
+
+        def build(self, input_shape):
+            if isinstance(self.axis, list):
+                shape = tuple([input_shape[dim] for dim in self.axis])
+            else:
+                shape = (input_shape[self.axis],)
+                self.axis = [self.axis]
+
+            self.scale = self.add_weight(
+                name="scale", shape=shape, initializer="ones"
+            )
+
+            self.built = True
+
+        def call(self, x):
+            x = ops.cast(
+                x, keras.backend.result_type(self.compute_dtype, "float32")
+            )
+            rrms = ops.rsqrt(
+                ops.mean(ops.square(x), axis=self.axis, keepdims=True)
+                + self.epsilon
+            )
+            return (x * rrms) * ops.cast(self.scale, x.dtype)
+
+        def compute_output_shape(self, input_shape):
+            if isinstance(self.axis, int):
+                axes = [self.axis]
+            else:
+                axes = self.axis
+
+            for axis in axes:
+                if axis >= len(input_shape) or axis < -len(input_shape):
+                    raise ValueError(
+                        f"Axis {axis} is out of bounds for "
+                        f"input shape {input_shape}. "
+                        f"Received: axis={self.axis}"
+                    )
+            return input_shape
+
+        def get_config(self):
+            config = super().get_config()
+            config.update({"axis": self.axis, "epsilon": self.epsilon})
+            return config
+
 
 
 class AdaptiveLayerNormalization(layers.Layer):
     """Adaptive layer normalization.
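
The fallback layer implements root-mean-square normalization: each feature vector is divided by the RMS of its elements (plus epsilon) and then multiplied by a learned per-feature scale. A minimal NumPy sketch of the same computation (the epsilon matches the diff; the input array is illustrative):

import numpy as np

def rms_normalize(x, scale, epsilon=1e-6, axis=-1):
    # Reciprocal root-mean-square over the feature axis.
    rrms = 1.0 / np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + epsilon)
    return (x * rrms) * scale

x = np.array([[1.0, 2.0, 3.0, 4.0]])
scale = np.ones((4,))  # learned by the layer, initialized to ones
out = rms_normalize(x, scale)
print(np.sqrt(np.mean(np.square(out))))  # ~1.0: unit RMS after normalization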
@@ -402,11 +459,11 @@ def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
     if qk_norm is None:
         pass
     elif qk_norm == "rms_norm":
-        q_norm =
-            epsilon=1e-6,
+        q_norm = RMSNormalization(
+            axis=-1, epsilon=1e-6, dtype="float32", name=q_norm_name
         )
-        k_norm =
-            epsilon=1e-6,
+        k_norm = RMSNormalization(
+            axis=-1, epsilon=1e-6, dtype="float32", name=k_norm_name
         )
     else:
         raise NotImplementedError(

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

@@ -96,26 +96,10 @@ class LatentRescaling(layers.Rescaling):
         return (self.backend.cast(inputs, dtype) / scale) + offset
 
 
-class
-    def call(
-        self,
-        latents,
-        positive_contexts,
-        negative_contexts,
-        positive_pooled_projections,
-        negative_pooled_projections,
-        timestep,
-    ):
+class TimestepBroadcastTo(layers.Layer):
+    def call(self, latents, timestep):
         timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-
-        contexts = ops.concatenate(
-            [positive_contexts, negative_contexts], axis=0
-        )
-        pooled_projections = ops.concatenate(
-            [positive_pooled_projections, negative_pooled_projections], axis=0
-        )
-        timesteps = ops.concatenate([timestep, timestep], axis=0)
-        return latents, contexts, pooled_projections, timesteps
+        return timestep
 
 
 class ClassifierFreeGuidance(layers.Layer):
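
The old multi-input concatenation layer is replaced by a layer that only broadcasts the scalar timestep across the latent batch; the classifier-free-guidance concatenation moves into the backbone itself (see the @@ -562 hunk below). A minimal NumPy sketch of the broadcast (shapes illustrative):

import numpy as np

latents = np.zeros((4, 64, 64, 16))  # illustrative latent batch
timestep = np.array(981.0)           # scalar timestep from the scheduler

# Equivalent of ops.broadcast_to(timestep, ops.shape(latents)[:1]):
timestep_batch = np.broadcast_to(timestep, latents.shape[:1])
print(timestep_batch.shape)  # (4,) -- one timestep per latent in the batch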
@@ -330,8 +314,8 @@ class StableDiffusion3Backbone(Backbone):
             name="diffuser",
         )
         self.vae = vae
-        self.
-            dtype=dtype, name="
+        self.timestep_broadcast_to = TimestepBroadcastTo(
+            dtype=dtype, name="timestep_broadcast_to"
         )
         self.cfg = ClassifierFreeGuidance(
             dtype=dtype, name="classifier_free_guidance"
@@ -538,6 +522,9 @@ class StableDiffusion3Backbone(Backbone):
         latents = self.vae.encode(images)
         return self.image_rescaling(latents)
 
+    def configure_scheduler(self, num_steps):
+        self.scheduler.set_sigmas(num_steps)
+
     def add_noise_step(self, latents, noises, step, num_steps):
         return self.scheduler.add_noise(latents, noises, step, num_steps)
 
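
configure_scheduler gives callers a way to build the scheduler's sigma table once before the denoising loop (the three SD3 task files in the list above each gain one line, presumably wiring this in). A hedged sketch of how a generation loop might use it; only configure_scheduler comes from this hunk, while the loop and the denoise_step signature are assumed from context:

def generate(backbone, latents, embeddings, num_steps=28, guidance_scale=7.0):
    # Precompute sigmas once instead of lazily on the first scheduler call.
    backbone.configure_scheduler(num_steps)
    for step in range(num_steps):
        latents = backbone.denoise_step(
            latents, embeddings, step, num_steps, guidance_scale
        )
    return latents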
@@ -562,11 +549,15 @@ class StableDiffusion3Backbone(Backbone):
 
         # Concatenation for classifier-free guidance.
         if guidance_scale is not None:
-
-
+            timestep = self.timestep_broadcast_to(latents, timestep)
+            timesteps = ops.concatenate([timestep, timestep], axis=0)
+            concated_latents = ops.concatenate([latents, latents], axis=0)
+            contexts = ops.concatenate([embeddings[0], embeddings[1]], axis=0)
+            pooled_projs = ops.concatenate(
+                [embeddings[2], embeddings[3]], axis=0
             )
         else:
-            timesteps =
+            timesteps = self.timestep_broadcast_to(latents, timestep)
             concated_latents = latents
             contexts = embeddings[0]
             pooled_projs = embeddings[2]
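
For classifier-free guidance the backbone now builds the doubled batch inline: latents, timesteps, contexts and pooled projections are concatenated along the batch axis so one diffuser call yields both the conditional and unconditional predictions. A minimal sketch of the batching plus the standard CFG combination (the combine formula is stated for context; it is not part of this hunk):

import numpy as np

def cfg_denoise(diffuser, latents, pos_ctx, neg_ctx, timestep, guidance_scale):
    # Double the batch: first half conditional, second half unconditional.
    latents2 = np.concatenate([latents, latents], axis=0)
    contexts = np.concatenate([pos_ctx, neg_ctx], axis=0)
    timesteps = np.concatenate([timestep, timestep], axis=0)
    pred = diffuser(latents2, contexts, timesteps)
    pos_pred, neg_pred = np.split(pred, 2, axis=0)
    # Standard classifier-free guidance combination.
    return neg_pred + guidance_scale * (pos_pred - neg_pred)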
@@ -623,20 +614,20 @@ class StableDiffusion3Backbone(Backbone):
     def from_config(cls, config, custom_objects=None):
         config = config.copy()
 
-        # Propagate `dtype` to
+        # Propagate `dtype` to the VAE if needed.
         if "dtype" in config and config["dtype"] is not None:
             dtype_config = config["dtype"]
             if "dtype" not in config["vae"]["config"]:
                 config["vae"]["config"]["dtype"] = dtype_config
-
-
-
-                config["clip_g"]["config"]["dtype"] = dtype_config
+
+        # Text encoders default to float16 dtype if not specified.
+        for text_encoder in ("clip_l", "clip_g", "t5"):
             if (
-
-                and
+                text_encoder in config
+                and config[text_encoder] is not None
+                and "dtype" not in config[text_encoder]["config"]
             ):
-                config[
+                config[text_encoder]["config"]["dtype"] = "float16"
 
         # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
         config["vae"] = layers.deserialize(
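
The effect is that a serialized backbone config only needs a top-level dtype: the VAE inherits it, and any text encoder without an explicit dtype falls back to float16. A small self-contained sketch of that defaulting logic on a plain dict (only the keys touched by from_config are shown; values are illustrative):

config = {
    "dtype": "bfloat16",
    "vae": {"config": {}},
    "clip_l": {"config": {}},
    "clip_g": {"config": {"dtype": "float32"}},  # explicit dtype is preserved
    "t5": None,                                  # optional encoder may be absent
}

if config.get("dtype") is not None:
    if "dtype" not in config["vae"]["config"]:
        config["vae"]["config"]["dtype"] = config["dtype"]
for text_encoder in ("clip_l", "clip_g", "t5"):
    if (
        text_encoder in config
        and config[text_encoder] is not None
        and "dtype" not in config[text_encoder]["config"]
    ):
        config[text_encoder]["config"]["dtype"] = "float16"

print(config["vae"]["config"]["dtype"])     # bfloat16
print(config["clip_l"]["config"]["dtype"])  # float16
print(config["clip_g"]["config"]["dtype"])  # float32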

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py

@@ -50,8 +50,12 @@ class StableDiffusion3TextToImagePreprocessor(TextToImagePreprocessor):
 
     def generate_preprocess(self, x):
         token_ids = {}
-        token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
-        token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"]
+        token_ids["clip_l"] = self.clip_l_preprocessor(
+            {"prompts": x, "images": None}
+        )["token_ids"]
+        token_ids["clip_g"] = self.clip_g_preprocessor(
+            {"prompts": x, "images": None}
+        )["token_ids"]
         if self.t5_preprocessor is not None:
             token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
         return token_ids

keras_hub/src/utils/preset_utils.py

@@ -649,7 +649,10 @@ class KerasPresetLoader(PresetLoader):
         return check_config_class(self.config)
 
     def load_backbone(self, cls, load_weights, **kwargs):
-
+        config = self.config.copy()
+        backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
+        config["config"] = {**config["config"], **backbone_kwargs}
+        backbone = self._load_serialized_object(config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             self._load_backbone_weights(backbone)
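
Merging backbone_kwargs into the serialized config before deserialization means keyword overrides passed to from_preset now reach the backbone constructor. A hedged sketch; the preset name and the overridden argument are illustrative, and which kwargs are accepted depends on the backbone class:

import keras_hub

# Illustrative: override a backbone constructor argument at load time.
backbone = keras_hub.models.Backbone.from_preset("gemma_2b_en", dtype="bfloat16")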

keras_hub/src/utils/transformers/convert_dinov2.py (new file)

@@ -0,0 +1,180 @@
+import numpy as np
+
+from keras_hub.src.models.dinov2.dinov2_backbone import DINOV2Backbone
+
+backbone_cls = DINOV2Backbone
+
+
+def convert_backbone_config(transformers_config):
+    model_type = transformers_config["model_type"]
+    antialias_in_interpolation = False if model_type == "dinov2" else True
+    image_size = transformers_config["image_size"]
+    intermediate_dim = int(
+        transformers_config["hidden_size"] * transformers_config["mlp_ratio"]
+    )
+    return {
+        "patch_size": transformers_config["patch_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "num_heads": transformers_config["num_attention_heads"],
+        "intermediate_dim": intermediate_dim,
+        "layer_scale_init_value": transformers_config["layerscale_value"],
+        "num_register_tokens": transformers_config.get(
+            "num_register_tokens", 0
+        ),
+        "use_mask_token": transformers_config.get("use_mask_token", True),
+        "use_swiglu_ffn": transformers_config["use_swiglu_ffn"],
+        "dropout_rate": transformers_config["hidden_dropout_prob"],
+        "drop_path_rate": transformers_config["drop_path_rate"],
+        "image_shape": (image_size, image_size, 3),
+        "position_embedding_shape": (image_size, image_size),
+        "antialias_in_interpolation": antialias_in_interpolation,
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    if not isinstance(backbone, DINOV2Backbone):
+        raise ValueError(
+            "The provided backbone must be an instance of DINOV2Backbone. "
+            f"Received: {type(backbone)}"
+        )
+
+    def port_ln(keras_variable, weight_key):
+        loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
+        loader.port_weight(keras_variable.beta, f"{weight_key}.bias")
+
+    def port_dense(keras_variable, weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        if keras_variable.bias is not None:
+            loader.port_weight(keras_variable.bias, f"{weight_key}.bias")
+
+    def port_mha(keras_variable, weight_key, num_heads, hidden_dim):
+        # query
+        loader.port_weight(
+            keras_variable.query_dense.kernel,
+            f"{weight_key}.attention.query.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.query_dense.bias,
+            f"{weight_key}.attention.query.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # key
+        loader.port_weight(
+            keras_variable.key_dense.kernel,
+            f"{weight_key}.attention.key.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.key_dense.bias,
+            f"{weight_key}.attention.key.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # value
+        loader.port_weight(
+            keras_variable.value_dense.kernel,
+            f"{weight_key}.attention.value.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.value_dense.bias,
+            f"{weight_key}.attention.value.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # output
+        loader.port_weight(
+            keras_variable.output_dense.kernel,
+            f"{weight_key}.output.dense.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (num_heads, hidden_dim // num_heads, hidden_dim)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias"
+        )
+
+    # Embedding.
+    loader.port_weight(
+        keras_variable=backbone.embeddings.cls_token,
+        hf_weight_key="embeddings.cls_token",
+    )
+    if backbone.use_mask_token:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.mask_token,
+            hf_weight_key="embeddings.mask_token",
+        )
+    if backbone.num_register_tokens > 0:
+        loader.port_weight(
+            keras_variable=backbone.embeddings.register_tokens,
+            hf_weight_key="embeddings.register_tokens",
+        )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.position_embeddings,
+        hf_weight_key="embeddings.position_embeddings",
+    )
+    # Interpolate position embeddings to match the image shape.
+    backbone.embeddings.interpolated_position_embeddings.assign(
+        backbone.embeddings._interpolate_position_embeddings(
+            backbone.embeddings.position_embeddings,
+            patch_size=backbone.patch_size,
+            source_shape=backbone.embeddings.position_embedding_shape,
+            target_shape=backbone.image_shape,
+            antialias=backbone.embeddings.antialias_in_interpolation,
+        )
+    )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.kernel,
+        hf_weight_key="embeddings.patch_embeddings.projection.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+    loader.port_weight(
+        keras_variable=backbone.embeddings.patch_embeddings.projection.bias,
+        hf_weight_key="embeddings.patch_embeddings.projection.bias",
+    )
+
+    # Encoder.
+    hidden_dim = backbone.hidden_dim
+    num_heads = backbone.num_heads
+    for i, layer in enumerate(backbone.encoder.layers):
+        prefix = f"encoder.layer.{i}"
+        port_ln(layer.norm1, f"{prefix}.norm1")
+        port_mha(
+            layer.attention.attention,
+            f"{prefix}.attention",
+            num_heads,
+            hidden_dim,
+        )
+        loader.port_weight(
+            keras_variable=layer.layer_scale1.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale1.lambda1",
+        )
+        port_ln(layer.norm2, f"{prefix}.norm2")
+        if backbone.use_swiglu_ffn:
+            port_dense(layer.mlp.weights_in, f"{prefix}.mlp.weights_in")
+            port_dense(layer.mlp.weights_out, f"{prefix}.mlp.weights_out")
+        else:
+            port_dense(layer.mlp.fc1, f"{prefix}.mlp.fc1")
+            port_dense(layer.mlp.fc2, f"{prefix}.mlp.fc2")
+        loader.port_weight(
+            keras_variable=layer.layer_scale2.lambda1,
+            hf_weight_key=f"{prefix}.layer_scale2.lambda1",
+        )
+
+    port_ln(backbone.layernorm, "layernorm")
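
Together with the new dinov2_backbone.py and the preset-loader registration, this converter makes Hugging Face DINOv2 checkpoints loadable directly. A hedged usage sketch; the hf:// handle is an example id and is not pinned by this diff:

import keras_hub

backbone = keras_hub.models.Backbone.from_preset("hf://facebook/dinov2-base")
print(backbone.count_params())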

keras_hub/src/utils/transformers/export/gemma.py (new file)

@@ -0,0 +1,89 @@
+import keras.ops as ops
+
+
+def get_gemma_config(backbone):
+    hf_config = {
+        "vocab_size": backbone.vocabulary_size,
+        "num_hidden_layers": backbone.num_layers,
+        "num_attention_heads": backbone.num_query_heads,
+        "num_key_value_heads": backbone.num_key_value_heads,
+        "hidden_size": backbone.hidden_dim,
+        "intermediate_size": backbone.intermediate_dim // 2,
+        "head_dim": backbone.head_dim,
+        "max_position_embeddings": 8192,
+    }
+    return hf_config
+
+
+def get_gemma_weights_map(backbone):
+    weights_dict = {}
+
+    # Map token embedding
+    token_embedding_layer = backbone.get_layer("token_embedding")
+    weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"decoder_block_{i}")
+
+        # Pre-attention normalization
+        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
+            decoder_layer.pre_attention_norm.weights[0]
+        )
+
+        # Attention query projection
+        query_kernel = decoder_layer.attention.query_dense.weights[0]
+        query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
+        query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
+        query_kernel = ops.transpose(query_kernel)
+        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel
+
+        # Attention key projection
+        key_kernel = decoder_layer.attention.key_dense.weights[0][0]
+        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
+            ops.transpose(key_kernel)
+        )
+
+        # Attention value projection
+        value_kernel = decoder_layer.attention.value_dense.weights[0][0]
+        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
+            ops.transpose(value_kernel)
+        )
+
+        # Attention output projection
+        out_kernel = decoder_layer.attention.output_dense.weights[0]
+        out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
+        out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
+        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel
+
+        # Post-attention normalization
+        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
+            decoder_layer.pre_ffw_norm.weights[0]
+        )
+
+        # MLP gate projection
+        gate_kernel = decoder_layer.gating_ffw.weights[0]
+        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
+            gate_kernel
+        )
+
+        # MLP up projection
+        up_kernel = decoder_layer.gating_ffw_2.weights[0]
+        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
+            up_kernel
+        )
+
+        # MLP down projection
+        down_kernel = decoder_layer.ffw_linear.weights[0]
+        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
+            down_kernel
+        )
+
+    # Map final normalization
+    weights_dict["model.norm.weight"] = backbone.get_layer(
+        "final_normalization"
+    ).weights[0]
+
+    # Tie weights, but clone to avoid sharing memory issues
+    weights_dict["lm_head.weight"] = ops.copy(token_embedding_layer.weights[0])
+
+    return weights_dict
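
get_gemma_config reads the Hugging Face Gemma hyperparameters straight off the backbone; the one non-obvious mapping is intermediate_size, which the exporter halves because keras-hub's intermediate_dim counts the gate and up projections together. A small sketch with a stand-in backbone object (values are illustrative, roughly Gemma-2B-like):

from types import SimpleNamespace

backbone = SimpleNamespace(
    vocabulary_size=256000,
    num_layers=18,
    num_query_heads=8,
    num_key_value_heads=1,
    hidden_dim=2048,
    intermediate_dim=32768,  # gate + up projections combined
    head_dim=256,
)

hf_config = {
    "vocab_size": backbone.vocabulary_size,
    "num_hidden_layers": backbone.num_layers,
    "num_attention_heads": backbone.num_query_heads,
    "num_key_value_heads": backbone.num_key_value_heads,
    "hidden_size": backbone.hidden_dim,
    "intermediate_size": backbone.intermediate_dim // 2,  # 16384
    "head_dim": backbone.head_dim,
    "max_position_embeddings": 8192,
}
print(hf_config["intermediate_size"])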