keras-hub-nightly 0.22.0.dev202507150421__py3-none-any.whl → 0.22.0.dev202507170424__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. keras_hub/layers/__init__.py +3 -0
  2. keras_hub/models/__init__.py +3 -0
  3. keras_hub/src/models/clip/clip_backbone.py +3 -102
  4. keras_hub/src/models/clip/clip_layers.py +295 -0
  5. keras_hub/src/models/clip/clip_preprocessor.py +57 -48
  6. keras_hub/src/models/clip/clip_text_encoder.py +2 -2
  7. keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
  8. keras_hub/src/models/dinov2/__init__.py +5 -0
  9. keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
  10. keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
  11. keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
  12. keras_hub/src/models/dinov2/dinov2_presets.py +4 -0
  13. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
  14. keras_hub/src/models/hgnetv2/__init__.py +5 -0
  15. keras_hub/src/models/hgnetv2/hgnetv2_presets.py +5 -5
  16. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
  17. keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +23 -32
  19. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
  20. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
  21. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
  22. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
  23. keras_hub/src/utils/preset_utils.py +4 -1
  24. keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
  25. keras_hub/src/utils/transformers/export/gemma.py +89 -0
  26. keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
  27. keras_hub/src/utils/transformers/preset_loader.py +4 -1
  28. keras_hub/src/version.py +1 -1
  29. {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/METADATA +1 -1
  30. {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/RECORD +32 -25
  31. keras_hub/src/models/clip/clip_encoder_block.py +0 -111
  32. keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
  33. {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/WHEEL +0 -0
  34. {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/dinov2/dinov2_presets.py
@@ -0,0 +1,4 @@
+ """DINOV2 model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {}
keras_hub/src/models/flux/flux_text_to_image_preprocessor.py
@@ -43,9 +43,13 @@ class FluxTextToImagePreprocessor(Preprocessor):

      def generate_preprocess(self, x):
          token_ids = {}
-         token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
+         token_ids["clip_l"] = self.clip_l_preprocessor(
+             {"prompts": x, "images": None}
+         )["token_ids"]
          if self.t5_preprocessor is not None:
-             token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
+             token_ids["t5"] = self.t5_preprocessor(
+                 {"prompts": x, "images": None}
+             )["token_ids"]
          return token_ids

      def get_config(self):
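For context: this hunk (and the matching Stable Diffusion 3 preprocessor change later in this diff) switches the sub-preprocessor calls from a bare prompt input to a dict with "prompts" and "images" keys, so image inputs can travel through the same call signature. A minimal, hedged usage sketch of the resulting behavior; the preset name is illustrative only and the exported class path is assumed, not confirmed by this diff:

import keras_hub

# Hypothetical preset name, used only to illustrate the call.
preprocessor = keras_hub.models.FluxTextToImagePreprocessor.from_preset(
    "flux_schnell"
)
# Internally, each sub-preprocessor now receives
# {"prompts": ..., "images": None} instead of the raw prompts.
token_ids = preprocessor.generate_preprocess(["a photo of an astronaut"])
print(token_ids.keys())  # expected: "clip_l" and, if configured, "t5"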
keras_hub/src/models/hgnetv2/__init__.py
@@ -0,0 +1,5 @@
+ from keras_hub.src.models.hgnetv2.hgnetv2_backbone import HGNetV2Backbone
+ from keras_hub.src.models.hgnetv2.hgnetv2_presets import backbone_presets
+ from keras_hub.src.utils.preset_utils import register_presets
+
+ register_presets(backbone_presets, HGNetV2Backbone)
keras_hub/src/models/hgnetv2/hgnetv2_presets.py
@@ -9,7 +9,7 @@ backbone_presets = {
              "params": 13599072,
              "path": "hgnetv2",
          },
-         "kaggle_handle": "",
+         "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b4_ssld_stage2_ft_in1k/1",
      },
      "hgnetv2_b5_ssld_stage1_in22k_in1k": {
          "metadata": {
@@ -20,7 +20,7 @@ backbone_presets = {
              "params": 33419680,
              "path": "hgnetv2",
          },
-         "kaggle_handle": "",
+         "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b5_ssld_stage1_in22k_in1k/1",
      },
      "hgnetv2_b5_ssld_stage2_ft_in1k": {
          "metadata": {
@@ -31,7 +31,7 @@ backbone_presets = {
              "params": 33419680,
              "path": "hgnetv2",
          },
-         "kaggle_handle": "",
+         "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b5_ssld_stage2_ft_in1k/1",
      },
      "hgnetv2_b6_ssld_stage1_in22k_in1k": {
          "metadata": {
@@ -42,7 +42,7 @@ backbone_presets = {
              "params": 69179888,
              "path": "hgnetv2",
          },
-         "kaggle_handle": "",
+         "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b6_ssld_stage1_in22k_in1k/1",
      },
      "hgnetv2_b6_ssld_stage2_ft_in1k": {
          "metadata": {
@@ -53,6 +53,6 @@ backbone_presets = {
              "params": 69179888,
              "path": "hgnetv2",
          },
-         "kaggle_handle": "",
+         "kaggle_handle": "kaggle://keras/hgnetv2/keras/hgnetv2_b6_ssld_stage2_ft_in1k/1",
      },
  }
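These hunks fill in the previously empty `kaggle_handle` entries, which is what makes the HGNetV2 presets loadable by name once `register_presets` (see the `__init__.py` hunk above) has run. A hedged usage sketch:

import keras_hub

# Keras Hub resolves the preset name to the kaggle:// handle registered in
# hgnetv2_presets.py and downloads the weights on first use.
backbone = keras_hub.models.Backbone.from_preset(
    "hgnetv2_b5_ssld_stage1_in22k_in1k"
)
# The metadata above lists 33,419,680 parameters for this preset.
print(backbone.count_params())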
keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py
@@ -38,7 +38,6 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
          timesteps = ops.flip(timesteps, axis=0)
          sigmas = self._timestep_to_sigma(timesteps)

-         self.timesteps = ops.multiply(sigmas, num_train_timesteps)
          self.sigma_min = sigmas[-1]
          self.sigma_max = sigmas[0]

@@ -54,14 +53,24 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
              )
          return sigma

+     def set_sigmas(self, num_steps):
+         timesteps = ops.linspace(
+             self._sigma_to_timestep(self.sigma_max),
+             self._sigma_to_timestep(self.sigma_min),
+             num_steps,
+         )
+         sigmas = self._timestep_to_sigma(timesteps)
+         sigmas = ops.concatenate([sigmas, ops.zeros((1,), dtype=sigmas.dtype)])
+         self.sigmas = sigmas
+
      def call(self, inputs, num_steps):
-         start = self._sigma_to_timestep(self.sigma_max)
-         end = self._sigma_to_timestep(self.sigma_min)
-         step_size = ops.divide(
-             ops.subtract(end, start), ops.subtract(num_steps, 1)
+         if not hasattr(self, "sigmas"):
+             self.set_sigmas(num_steps)
+
+         step = ops.expand_dims(
+             ops.convert_to_tensor(inputs, dtype="int32"), axis=0
          )
-         timestep = ops.add(start, ops.multiply(inputs, step_size))
-         sigma = ops.maximum(self._timestep_to_sigma(timestep), 0.0)
+         sigma = ops.take(self.sigmas, step)
          timestep = self._sigma_to_timestep(sigma)
          return sigma, timestep

keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -10,6 +10,63 @@ from keras_hub.src.utils.keras_utils import fused_attention_op_available
  from keras_hub.src.utils.keras_utils import gelu_approximate
  from keras_hub.src.utils.keras_utils import standardize_data_format

+ # TODO: Deprecate this in favor of
+ # `keras.layers.RMSNormalization` once we require Keras 3.9 or later.
+ if hasattr(layers, "RMSNormalization"):
+     RMSNormalization = layers.RMSNormalization
+ else:
+
+     class RMSNormalization(layers.Layer):
+         """A normalization layer for MMDiT that implements RMS normalization."""
+
+         def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
+             super().__init__(**kwargs)
+             self.axis = axis
+             self.epsilon = epsilon
+
+         def build(self, input_shape):
+             if isinstance(self.axis, list):
+                 shape = tuple([input_shape[dim] for dim in self.axis])
+             else:
+                 shape = (input_shape[self.axis],)
+                 self.axis = [self.axis]
+
+             self.scale = self.add_weight(
+                 name="scale", shape=shape, initializer="ones"
+             )
+
+             self.built = True
+
+         def call(self, x):
+             x = ops.cast(
+                 x, keras.backend.result_type(self.compute_dtype, "float32")
+             )
+             rrms = ops.rsqrt(
+                 ops.mean(ops.square(x), axis=self.axis, keepdims=True)
+                 + self.epsilon
+             )
+             return (x * rrms) * ops.cast(self.scale, x.dtype)
+
+         def compute_output_shape(self, input_shape):
+             if isinstance(self.axis, int):
+                 axes = [self.axis]
+             else:
+                 axes = self.axis
+
+             for axis in axes:
+                 if axis >= len(input_shape) or axis < -len(input_shape):
+                     raise ValueError(
+                         f"Axis {axis} is out of bounds for "
+                         f"input shape {input_shape}. "
+                         f"Received: axis={self.axis}"
+                     )
+             return input_shape
+
+         def get_config(self):
+             config = super().get_config()
+             config.update({"axis": self.axis, "epsilon": self.epsilon})
+             return config
+

  class AdaptiveLayerNormalization(layers.Layer):
      """Adaptive layer normalization.
@@ -402,11 +459,11 @@ def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
      if qk_norm is None:
          pass
      elif qk_norm == "rms_norm":
-         q_norm = layers.LayerNormalization(
-             epsilon=1e-6, rms_scaling=True, dtype="float32", name=q_norm_name
+         q_norm = RMSNormalization(
+             axis=-1, epsilon=1e-6, dtype="float32", name=q_norm_name
          )
-         k_norm = layers.LayerNormalization(
-             epsilon=1e-6, rms_scaling=True, dtype="float32", name=k_norm_name
+         k_norm = RMSNormalization(
+             axis=-1, epsilon=1e-6, dtype="float32", name=k_norm_name
          )
      else:
          raise NotImplementedError(
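For readers skimming the MMDiT hunks: the fallback layer defined above is plain RMS normalization, and the second hunk swaps `layers.LayerNormalization(rms_scaling=True)` for it (or for `keras.layers.RMSNormalization` on Keras versions that ship it, per the TODO). The math is y = x / sqrt(mean(x**2) + eps) * scale; a minimal NumPy check of that formula (not the Keras layer itself):

import numpy as np


def rms_norm(x, scale, axis=-1, epsilon=1e-6):
    # Root-mean-square normalization: divide by the RMS of x along `axis`,
    # then apply a learned per-feature scale.
    rrms = 1.0 / np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + epsilon)
    return (x * rrms) * scale


x = np.random.randn(2, 4).astype("float32")
out = rms_norm(x, scale=np.ones((4,), dtype="float32"))
# Each normalized row should have an RMS of roughly 1.
print(np.sqrt(np.mean(np.square(out), axis=-1)))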
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -96,26 +96,10 @@ class LatentRescaling(layers.Rescaling):
          return (self.backend.cast(inputs, dtype) / scale) + offset


- class ClassifierFreeGuidanceConcatenate(layers.Layer):
-     def call(
-         self,
-         latents,
-         positive_contexts,
-         negative_contexts,
-         positive_pooled_projections,
-         negative_pooled_projections,
-         timestep,
-     ):
+ class TimestepBroadcastTo(layers.Layer):
+     def call(self, latents, timestep):
          timestep = ops.broadcast_to(timestep, ops.shape(latents)[:1])
-         latents = ops.concatenate([latents, latents], axis=0)
-         contexts = ops.concatenate(
-             [positive_contexts, negative_contexts], axis=0
-         )
-         pooled_projections = ops.concatenate(
-             [positive_pooled_projections, negative_pooled_projections], axis=0
-         )
-         timesteps = ops.concatenate([timestep, timestep], axis=0)
-         return latents, contexts, pooled_projections, timesteps
+         return timestep


  class ClassifierFreeGuidance(layers.Layer):
@@ -330,8 +314,8 @@ class StableDiffusion3Backbone(Backbone):
              name="diffuser",
          )
          self.vae = vae
-         self.cfg_concat = ClassifierFreeGuidanceConcatenate(
-             dtype=dtype, name="classifier_free_guidance_concat"
+         self.timestep_broadcast_to = TimestepBroadcastTo(
+             dtype=dtype, name="timestep_broadcast_to"
          )
          self.cfg = ClassifierFreeGuidance(
              dtype=dtype, name="classifier_free_guidance"
@@ -538,6 +522,9 @@ class StableDiffusion3Backbone(Backbone):
          latents = self.vae.encode(images)
          return self.image_rescaling(latents)

+     def configure_scheduler(self, num_steps):
+         self.scheduler.set_sigmas(num_steps)
+
      def add_noise_step(self, latents, noises, step, num_steps):
          return self.scheduler.add_noise(latents, noises, step, num_steps)

@@ -562,11 +549,15 @@ class StableDiffusion3Backbone(Backbone):

          # Concatenation for classifier-free guidance.
          if guidance_scale is not None:
-             concated_latents, contexts, pooled_projs, timesteps = (
-                 self.cfg_concat(latents, *embeddings, timestep)
+             timestep = self.timestep_broadcast_to(latents, timestep)
+             timesteps = ops.concatenate([timestep, timestep], axis=0)
+             concated_latents = ops.concatenate([latents, latents], axis=0)
+             contexts = ops.concatenate([embeddings[0], embeddings[1]], axis=0)
+             pooled_projs = ops.concatenate(
+                 [embeddings[2], embeddings[3]], axis=0
              )
          else:
-             timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1])
+             timesteps = self.timestep_broadcast_to(latents, timestep)
              concated_latents = latents
              contexts = embeddings[0]
              pooled_projs = embeddings[2]
@@ -623,20 +614,20 @@ class StableDiffusion3Backbone(Backbone):
      def from_config(cls, config, custom_objects=None):
          config = config.copy()

-         # Propagate `dtype` to text encoders if needed.
+         # Propagate `dtype` to the VAE if needed.
          if "dtype" in config and config["dtype"] is not None:
              dtype_config = config["dtype"]
              if "dtype" not in config["vae"]["config"]:
                  config["vae"]["config"]["dtype"] = dtype_config
-             if "dtype" not in config["clip_l"]["config"]:
-                 config["clip_l"]["config"]["dtype"] = dtype_config
-             if "dtype" not in config["clip_g"]["config"]:
-                 config["clip_g"]["config"]["dtype"] = dtype_config
+
+         # Text encoders default to float16 dtype if not specified.
+         for text_encoder in ("clip_l", "clip_g", "t5"):
              if (
-                 config["t5"] is not None
-                 and "dtype" not in config["t5"]["config"]
+                 text_encoder in config
+                 and config[text_encoder] is not None
+                 and "dtype" not in config[text_encoder]["config"]
              ):
-                 config["t5"]["config"]["dtype"] = dtype_config
+                 config[text_encoder]["config"]["dtype"] = "float16"

          # We expect `vae`, `clip_l`, `clip_g` and/or `t5` to be instantiated.
          config["vae"] = layers.deserialize(
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py
@@ -169,6 +169,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
          guidance_scale=7.0,
          seed=None,
      ):
+         self.backbone.configure_scheduler(num_steps)
          return super().generate(
              inputs,
              num_steps=num_steps,

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py
@@ -184,6 +184,7 @@ class StableDiffusion3Inpaint(Inpaint):
          guidance_scale=7.0,
          seed=None,
      ):
+         self.backbone.configure_scheduler(num_steps)
          return super().generate(
              inputs,
              num_steps=num_steps,

keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py
@@ -141,6 +141,7 @@ class StableDiffusion3TextToImage(TextToImage):
          guidance_scale=7.0,
          seed=None,
      ):
+         self.backbone.configure_scheduler(num_steps)
          return super().generate(
              inputs,
              num_steps=num_steps,
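The three one-line hunks above make each task's `generate()` precompute the sigma schedule for the requested step count (via the new `configure_scheduler` backbone method) before delegating to the base implementation. A hedged usage sketch; the preset name is illustrative and assumed rather than confirmed by this diff:

import keras_hub

# Illustrative preset name; any Stable Diffusion 3 preset should behave the same.
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", dtype="float16"
)
# generate() now calls backbone.configure_scheduler(num_steps) first, so the
# scheduler's sigma table matches the requested number of steps.
images = text_to_image.generate("a cat sitting on a windowsill", num_steps=28)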
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py
@@ -50,8 +50,12 @@ class StableDiffusion3TextToImagePreprocessor(TextToImagePreprocessor):

      def generate_preprocess(self, x):
          token_ids = {}
-         token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
-         token_ids["clip_g"] = self.clip_g_preprocessor(x)["token_ids"]
+         token_ids["clip_l"] = self.clip_l_preprocessor(
+             {"prompts": x, "images": None}
+         )["token_ids"]
+         token_ids["clip_g"] = self.clip_g_preprocessor(
+             {"prompts": x, "images": None}
+         )["token_ids"]
          if self.t5_preprocessor is not None:
              token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
          return token_ids
keras_hub/src/utils/preset_utils.py
@@ -649,7 +649,10 @@ class KerasPresetLoader(PresetLoader):
          return check_config_class(self.config)

      def load_backbone(self, cls, load_weights, **kwargs):
-         backbone = self._load_serialized_object(self.config, **kwargs)
+         config = self.config.copy()
+         backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
+         config["config"] = {**config["config"], **backbone_kwargs}
+         backbone = self._load_serialized_object(config, **kwargs)
          if load_weights:
              jax_memory_cleanup(backbone)
              self._load_backbone_weights(backbone)
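The loader change above copies the preset config and merges backbone constructor overrides (collected by `get_backbone_kwargs`) into it before deserialization. The apparent effect, stated here as an assumption rather than a guarantee, is that constructor-level keyword arguments passed to `from_preset` reach the backbone's `__init__` directly; a hedged sketch:

import keras_hub

# Hedged example: `dtype` is a backbone constructor override that the updated
# loader merges into the serialized config before instantiation.
backbone = keras_hub.models.Backbone.from_preset(
    "hgnetv2_b5_ssld_stage1_in22k_in1k", dtype="bfloat16"
)
print(backbone.dtype_policy)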
keras_hub/src/utils/transformers/convert_dinov2.py
@@ -0,0 +1,180 @@
+ import numpy as np
+
+ from keras_hub.src.models.dinov2.dinov2_backbone import DINOV2Backbone
+
+ backbone_cls = DINOV2Backbone
+
+
+ def convert_backbone_config(transformers_config):
+     model_type = transformers_config["model_type"]
+     antialias_in_interpolation = False if model_type == "dinov2" else True
+     image_size = transformers_config["image_size"]
+     intermediate_dim = int(
+         transformers_config["hidden_size"] * transformers_config["mlp_ratio"]
+     )
+     return {
+         "patch_size": transformers_config["patch_size"],
+         "num_layers": transformers_config["num_hidden_layers"],
+         "hidden_dim": transformers_config["hidden_size"],
+         "num_heads": transformers_config["num_attention_heads"],
+         "intermediate_dim": intermediate_dim,
+         "layer_scale_init_value": transformers_config["layerscale_value"],
+         "num_register_tokens": transformers_config.get(
+             "num_register_tokens", 0
+         ),
+         "use_mask_token": transformers_config.get("use_mask_token", True),
+         "use_swiglu_ffn": transformers_config["use_swiglu_ffn"],
+         "dropout_rate": transformers_config["hidden_dropout_prob"],
+         "drop_path_rate": transformers_config["drop_path_rate"],
+         "image_shape": (image_size, image_size, 3),
+         "position_embedding_shape": (image_size, image_size),
+         "antialias_in_interpolation": antialias_in_interpolation,
+     }
+
+
+ def convert_weights(backbone, loader, transformers_config):
+     if not isinstance(backbone, DINOV2Backbone):
+         raise ValueError(
+             "The provided backbone must be an instance of DINOV2Backbone. "
+             f"Received: {type(backbone)}"
+         )
+
+     def port_ln(keras_variable, weight_key):
+         loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
+         loader.port_weight(keras_variable.beta, f"{weight_key}.bias")
+
+     def port_dense(keras_variable, weight_key):
+         loader.port_weight(
+             keras_variable.kernel,
+             f"{weight_key}.weight",
+             hook_fn=lambda x, _: x.T,
+         )
+         if keras_variable.bias is not None:
+             loader.port_weight(keras_variable.bias, f"{weight_key}.bias")
+
+     def port_mha(keras_variable, weight_key, num_heads, hidden_dim):
+         # query
+         loader.port_weight(
+             keras_variable.query_dense.kernel,
+             f"{weight_key}.attention.query.weight",
+             hook_fn=lambda x, _: np.reshape(
+                 x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+             ),
+         )
+         loader.port_weight(
+             keras_variable.query_dense.bias,
+             f"{weight_key}.attention.query.bias",
+             hook_fn=lambda x, _: np.reshape(
+                 x, (num_heads, hidden_dim // num_heads)
+             ),
+         )
+         # key
+         loader.port_weight(
+             keras_variable.key_dense.kernel,
+             f"{weight_key}.attention.key.weight",
+             hook_fn=lambda x, _: np.reshape(
+                 x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+             ),
+         )
+         loader.port_weight(
+             keras_variable.key_dense.bias,
+             f"{weight_key}.attention.key.bias",
+             hook_fn=lambda x, _: np.reshape(
+                 x, (num_heads, hidden_dim // num_heads)
+             ),
+         )
+         # value
+         loader.port_weight(
+             keras_variable.value_dense.kernel,
+             f"{weight_key}.attention.value.weight",
+             hook_fn=lambda x, _: np.reshape(
+                 x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+             ),
+         )
+         loader.port_weight(
+             keras_variable.value_dense.bias,
+             f"{weight_key}.attention.value.bias",
+             hook_fn=lambda x, _: np.reshape(
+                 x, (num_heads, hidden_dim // num_heads)
+             ),
+         )
+         # output
+         loader.port_weight(
+             keras_variable.output_dense.kernel,
+             f"{weight_key}.output.dense.weight",
+             hook_fn=lambda x, _: np.reshape(
+                 x.T, (num_heads, hidden_dim // num_heads, hidden_dim)
+             ),
+         )
+         loader.port_weight(
+             keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias"
+         )
+
+     # Embedding.
+     loader.port_weight(
+         keras_variable=backbone.embeddings.cls_token,
+         hf_weight_key="embeddings.cls_token",
+     )
+     if backbone.use_mask_token:
+         loader.port_weight(
+             keras_variable=backbone.embeddings.mask_token,
+             hf_weight_key="embeddings.mask_token",
+         )
+     if backbone.num_register_tokens > 0:
+         loader.port_weight(
+             keras_variable=backbone.embeddings.register_tokens,
+             hf_weight_key="embeddings.register_tokens",
+         )
+     loader.port_weight(
+         keras_variable=backbone.embeddings.position_embeddings,
+         hf_weight_key="embeddings.position_embeddings",
+     )
+     # Interpolate position embeddings to match the image shape.
+     backbone.embeddings.interpolated_position_embeddings.assign(
+         backbone.embeddings._interpolate_position_embeddings(
+             backbone.embeddings.position_embeddings,
+             patch_size=backbone.patch_size,
+             source_shape=backbone.embeddings.position_embedding_shape,
+             target_shape=backbone.image_shape,
+             antialias=backbone.embeddings.antialias_in_interpolation,
+         )
+     )
+     loader.port_weight(
+         keras_variable=backbone.embeddings.patch_embeddings.projection.kernel,
+         hf_weight_key="embeddings.patch_embeddings.projection.weight",
+         hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+     )
+     loader.port_weight(
+         keras_variable=backbone.embeddings.patch_embeddings.projection.bias,
+         hf_weight_key="embeddings.patch_embeddings.projection.bias",
+     )
+
+     # Encoder.
+     hidden_dim = backbone.hidden_dim
+     num_heads = backbone.num_heads
+     for i, layer in enumerate(backbone.encoder.layers):
+         prefix = f"encoder.layer.{i}"
+         port_ln(layer.norm1, f"{prefix}.norm1")
+         port_mha(
+             layer.attention.attention,
+             f"{prefix}.attention",
+             num_heads,
+             hidden_dim,
+         )
+         loader.port_weight(
+             keras_variable=layer.layer_scale1.lambda1,
+             hf_weight_key=f"{prefix}.layer_scale1.lambda1",
+         )
+         port_ln(layer.norm2, f"{prefix}.norm2")
+         if backbone.use_swiglu_ffn:
+             port_dense(layer.mlp.weights_in, f"{prefix}.mlp.weights_in")
+             port_dense(layer.mlp.weights_out, f"{prefix}.mlp.weights_out")
+         else:
+             port_dense(layer.mlp.fc1, f"{prefix}.mlp.fc1")
+             port_dense(layer.mlp.fc2, f"{prefix}.mlp.fc2")
+         loader.port_weight(
+             keras_variable=layer.layer_scale2.lambda1,
+             hf_weight_key=f"{prefix}.layer_scale2.lambda1",
+         )
+
+     port_ln(backbone.layernorm, "layernorm")
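A note on the `port_mha` reshapes above: Hugging Face stores each attention projection as a flat `(hidden_dim, hidden_dim)` matrix (PyTorch `Linear` convention, `(out_features, in_features)`), while Keras `MultiHeadAttention` keeps a per-head kernel of shape `(hidden_dim, num_heads, head_dim)`. A small NumPy sketch verifying that the transpose-then-reshape used in the converter preserves the computation (shapes and values are synthetic):

import numpy as np

hidden_dim, num_heads = 8, 2
head_dim = hidden_dim // num_heads

# Fake HF-style query projection weight, stored as (out_features, in_features).
hf_weight = np.random.randn(hidden_dim, hidden_dim).astype("float32")

# The converter's hook: transpose to (in, out), then split `out` into heads.
keras_kernel = np.reshape(hf_weight.T, (hidden_dim, num_heads, head_dim))

x = np.random.randn(3, hidden_dim).astype("float32")
flat = x @ hf_weight.T  # what torch.nn.Linear would compute (no bias)
per_head = np.einsum("bi,ihd->bhd", x, keras_kernel)  # Keras per-head kernel

# Both layouts produce the same values once the flat output is split by head.
print(np.allclose(flat.reshape(3, num_heads, head_dim), per_head, atol=1e-5))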
keras_hub/src/utils/transformers/export/gemma.py
@@ -0,0 +1,89 @@
+ import keras.ops as ops
+
+
+ def get_gemma_config(backbone):
+     hf_config = {
+         "vocab_size": backbone.vocabulary_size,
+         "num_hidden_layers": backbone.num_layers,
+         "num_attention_heads": backbone.num_query_heads,
+         "num_key_value_heads": backbone.num_key_value_heads,
+         "hidden_size": backbone.hidden_dim,
+         "intermediate_size": backbone.intermediate_dim // 2,
+         "head_dim": backbone.head_dim,
+         "max_position_embeddings": 8192,
+     }
+     return hf_config
+
+
+ def get_gemma_weights_map(backbone):
+     weights_dict = {}
+
+     # Map token embedding
+     token_embedding_layer = backbone.get_layer("token_embedding")
+     weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]
+
+     for i in range(backbone.num_layers):
+         decoder_layer = backbone.get_layer(f"decoder_block_{i}")
+
+         # Pre-attention normalization
+         weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
+             decoder_layer.pre_attention_norm.weights[0]
+         )
+
+         # Attention query projection
+         query_kernel = decoder_layer.attention.query_dense.weights[0]
+         query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
+         query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
+         query_kernel = ops.transpose(query_kernel)
+         weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel
+
+         # Attention key projection
+         key_kernel = decoder_layer.attention.key_dense.weights[0][0]
+         weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
+             ops.transpose(key_kernel)
+         )
+
+         # Attention value projection
+         value_kernel = decoder_layer.attention.value_dense.weights[0][0]
+         weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
+             ops.transpose(value_kernel)
+         )
+
+         # Attention output projection
+         out_kernel = decoder_layer.attention.output_dense.weights[0]
+         out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
+         out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
+         weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel
+
+         # Post-attention normalization
+         weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
+             decoder_layer.pre_ffw_norm.weights[0]
+         )
+
+         # MLP gate projection
+         gate_kernel = decoder_layer.gating_ffw.weights[0]
+         weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
+             gate_kernel
+         )
+
+         # MLP up projection
+         up_kernel = decoder_layer.gating_ffw_2.weights[0]
+         weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
+             up_kernel
+         )
+
+         # MLP down projection
+         down_kernel = decoder_layer.ffw_linear.weights[0]
+         weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
+             down_kernel
+         )
+
+     # Map final normalization
+     weights_dict["model.norm.weight"] = backbone.get_layer(
+         "final_normalization"
+     ).weights[0]
+
+     # Tie weights, but clone to avoid sharing memory issues
+     weights_dict["lm_head.weight"] = ops.copy(token_embedding_layer.weights[0])
+
+     return weights_dict
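`get_gemma_weights_map` returns a dict keyed by Hugging Face weight names; the actual serialization lives in the new `hf_exporter.py` (not shown in this excerpt). As a hedged illustration of how such a map could be written out with `safetensors` independently of that exporter, where the preset name and file layout are assumptions rather than the package's real export path:

import keras
import numpy as np
from safetensors.numpy import save_file

import keras_hub
from keras_hub.src.utils.transformers.export.gemma import (
    get_gemma_config,
    get_gemma_weights_map,
)

# Illustrative preset name; any Gemma backbone should work the same way.
backbone = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
weights_map = get_gemma_weights_map(backbone)

# Convert backend tensors to contiguous NumPy arrays and save them in the
# safetensors layout that Hugging Face `transformers` expects.
tensors = {
    k: np.ascontiguousarray(keras.ops.convert_to_numpy(v))
    for k, v in weights_map.items()
}
save_file(tensors, "model.safetensors")
print(get_gemma_config(backbone))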