keras-hub 0.20.0.dev1-py3-none-any.whl → 0.21.0.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. keras_hub/__init__.py +15 -33
  2. keras_hub/layers/__init__.py +134 -0
  3. keras_hub/metrics/__init__.py +11 -0
  4. keras_hub/models/__init__.py +642 -0
  5. keras_hub/samplers/__init__.py +18 -0
  6. keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
  7. keras_hub/src/layers/preprocessing/image_converter.py +1 -0
  8. keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
  9. keras_hub/src/layers/preprocessing/random_swap.py +1 -1
  10. keras_hub/src/models/audio_to_text.py +66 -0
  11. keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
  12. keras_hub/src/models/backbone.py +5 -2
  13. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  14. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -1
  16. keras_hub/src/models/gemma/gemma_presets.py +10 -10
  17. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
  18. keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
  19. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  20. keras_hub/src/models/llama/llama_attention.py +24 -6
  21. keras_hub/src/models/llama/llama_backbone.py +50 -16
  22. keras_hub/src/models/llama/llama_decoder.py +20 -3
  23. keras_hub/src/models/llama/llama_presets.py +3 -3
  24. keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
  25. keras_hub/src/models/llama3/llama3_backbone.py +10 -2
  26. keras_hub/src/models/llama3/llama3_presets.py +84 -2
  27. keras_hub/src/models/mistral/mistral_presets.py +3 -3
  28. keras_hub/src/models/mixtral/__init__.py +5 -0
  29. keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
  30. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  31. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  32. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  33. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  34. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  35. keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
  36. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  37. keras_hub/src/models/moonshine/__init__.py +5 -0
  38. keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
  39. keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
  40. keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
  42. keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
  43. keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
  44. keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
  45. keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
  46. keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
  47. keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
  48. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
  49. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
  50. keras_hub/src/models/qwen/__init__.py +4 -0
  51. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  52. keras_hub/src/models/qwen/qwen_backbone.py +8 -1
  53. keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
  54. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
  55. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  56. keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
  57. keras_hub/src/models/qwen_moe/__init__.py +5 -0
  58. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
  59. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  60. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  61. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
  65. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  66. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  67. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  68. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  69. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
  71. keras_hub/src/models/task.py +5 -2
  72. keras_hub/src/models/xception/__init__.py +5 -0
  73. keras_hub/src/models/xception/xception_backbone.py +188 -0
  74. keras_hub/src/models/xception/xception_image_classifier.py +12 -0
  75. keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
  76. keras_hub/src/models/xception/xception_image_converter.py +8 -0
  77. keras_hub/src/models/xception/xception_presets.py +14 -0
  78. keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
  79. keras_hub/src/utils/coco/__init__.py +0 -0
  80. keras_hub/src/utils/coco/coco_utils.py +133 -0
  81. keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
  82. keras_hub/src/utils/keras_utils.py +11 -0
  83. keras_hub/src/utils/preset_utils.py +70 -10
  84. keras_hub/src/utils/tensor_utils.py +27 -1
  85. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  86. keras_hub/src/utils/timm/preset_loader.py +6 -6
  87. keras_hub/src/utils/transformers/convert_llama3.py +21 -1
  88. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  89. keras_hub/src/utils/transformers/convert_qwen.py +1 -0
  90. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  91. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  92. keras_hub/src/{version_utils.py → version.py} +1 -1
  93. keras_hub/tokenizers/__init__.py +117 -0
  94. keras_hub/utils/__init__.py +21 -0
  95. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
  96. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
  97. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
  98. keras_hub/api/__init__.py +0 -15
  99. keras_hub/api/layers/__init__.py +0 -86
  100. keras_hub/api/metrics/__init__.py +0 -11
  101. keras_hub/api/models/__init__.py +0 -416
  102. keras_hub/api/samplers/__init__.py +0 -16
  103. keras_hub/api/tokenizers/__init__.py +0 -58
  104. keras_hub/api/utils/__init__.py +0 -9
  105. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
keras_hub/src/layers/modeling/reversible_embedding.py
@@ -1,3 +1,5 @@
+ import inspect
+
  import keras
  from keras import ops

@@ -184,31 +186,33 @@ class ReversibleEmbedding(keras.layers.Embedding):
  else:
  self._quantization_mode_error(self.quantization_mode)

- def _int8_build(
- self,
- embeddings_initializer="zeros",
- embeddings_scale_initializer="ones",
- reverse_embeddings_initializer="zeros",
- reverse_embeddings_scale_initializer="ones",
- ):
- super()._int8_build(
- embeddings_initializer, embeddings_scale_initializer
- )
+ def _int8_build(self, embeddings_shape=None):
+ if (
+ "embeddings_shape"
+ in inspect.signature(super()._int8_build).parameters
+ ):
+ if embeddings_shape is None:
+ embeddings_shape = (self.input_dim, self.output_dim)
+ super()._int8_build(embeddings_shape=embeddings_shape)
+ else:
+ # Backward compatibility for older versions of Keras.
+ super()._int8_build()
  self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
  if not self.tie_weights:
  self.reverse_embeddings = self.add_weight(
  name="reverse_embeddings",
  shape=(self.output_dim, self.input_dim),
- initializer=reverse_embeddings_initializer,
+ initializer="zeros",
  dtype="int8",
  trainable=False,
  )
  self.reverse_embeddings_scale = self.add_weight(
  name="reverse_embeddings_scale",
  shape=(self.input_dim,),
- initializer=reverse_embeddings_scale_initializer,
+ initializer="ones",
  trainable=False,
  )
+ self._is_quantized = True

  def _int8_call(self, inputs, reverse=False):
  if reverse:
@@ -232,27 +236,20 @@ class ReversibleEmbedding(keras.layers.Embedding):
  return super()._int8_call(inputs)

  def quantize(self, mode, type_check=True):
- import gc
-
  if type_check and type(self) is not ReversibleEmbedding:
- raise NotImplementedError(
- f"Layer {self.__class__.__name__} does not have a `quantize()` "
- "method implemented."
- )
- self._check_quantize_args(mode, self.compute_dtype)
+ raise self._not_implemented_error(self.quantize)

  def abs_max_quantize(inputs, axis):
  return keras.quantizers.abs_max_quantize(
  inputs, axis=axis, to_numpy=True
  )

- self._tracker.unlock()
+ embeddings_shape = (self.input_dim, self.output_dim)
  if mode == "int8":
  embeddings, embeddings_scale = abs_max_quantize(
  self._embeddings, axis=-1
  )
  embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
- self._untrack_variable(self._embeddings)
  del self._embeddings
  if not self.tie_weights:
  reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
@@ -261,24 +258,17 @@ class ReversibleEmbedding(keras.layers.Embedding):
  reverse_embeddings_scale = ops.squeeze(
  reverse_embeddings_scale, axis=0
  )
- self._untrack_variable(self.reverse_embeddings)
  del self.reverse_embeddings
- else:
- reverse_embeddings = None
- reverse_embeddings_scale = None
- self._int8_build(
- lambda shape, dtype: embeddings,
- lambda shape, dtype: embeddings_scale,
- lambda shape, dtype: reverse_embeddings,
- lambda shape, dtype: reverse_embeddings_scale,
- )
- else:
- raise self._quantization_mode_error(mode)
- self._tracker.lock()
+ self.quantized_build(embeddings_shape, mode)
+ if mode == "int8":
+ self._embeddings.assign(embeddings)
+ self.embeddings_scale.assign(embeddings_scale)
+ if not self.tie_weights:
+ self.reverse_embeddings.assign(reverse_embeddings)
+ self.reverse_embeddings_scale.assign(reverse_embeddings_scale)

  if self.dtype_policy.quantization_mode is None:
  policy = keras.dtype_policies.get(
  f"{mode}_from_{self.dtype_policy.name}"
  )
  self.dtype_policy = policy
- gc.collect()
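The new `_int8_build` probes the parent method's signature before delegating, so the layer works both on newer Keras versions that accept `embeddings_shape` and on older ones that do not. A minimal sketch of that pattern, using hypothetical `Base`/`Child` classes rather than the actual Keras internals:

```python
import inspect


class Base:
    # Stand-in for the parent layer; newer releases accept `embeddings_shape`.
    def _int8_build(self, embeddings_shape=None):
        return embeddings_shape


class Child(Base):
    def _int8_build(self, embeddings_shape=None):
        params = inspect.signature(super()._int8_build).parameters
        if "embeddings_shape" in params:
            # Newer parent: forward the keyword argument.
            return super()._int8_build(embeddings_shape=embeddings_shape)
        # Older parent: fall back to the argument-free call.
        return super()._int8_build()


print(Child()._int8_build((32, 8)))  # (32, 8)
```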
keras_hub/src/layers/preprocessing/image_converter.py
@@ -246,6 +246,7 @@ class ImageConverter(PreprocessingLayer):
  self.antialias = antialias
  self.bounding_box_format = bounding_box_format
  self.data_format = standardize_data_format(data_format)
+ self.built = True

  @property
  def image_size(self):
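Setting `self.built = True` in the constructor is the usual Keras pattern for layers that create no weights, so no deferred `build()` step is needed. A small sketch with a hypothetical stateless layer (the `Rescale` class below is illustrative, not part of keras-hub):

```python
import keras


class Rescale(keras.layers.Layer):
    """Hypothetical stateless layer: no weights, so it can build eagerly."""

    def __init__(self, scale=1.0 / 255.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self.built = True  # nothing for build() to create lazily

    def call(self, x):
        return x * self.scale


print(Rescale()(keras.ops.ones((2, 2))))
```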
keras_hub/src/layers/preprocessing/random_deletion.py
@@ -125,7 +125,7 @@ class RandomDeletion(PreprocessingLayer):

  self.rate = rate
  self.max_deletions = max_deletions
- self.seed = random.randint(1, 1e9) if seed is None else seed
+ self.seed = random.randint(1, int(1e9)) if seed is None else seed
  self._generator = tf.random.Generator.from_seed(self.seed)
  self.skip_list = skip_list
  self.skip_fn = skip_fn
keras_hub/src/layers/preprocessing/random_swap.py
@@ -127,7 +127,7 @@ class RandomSwap(PreprocessingLayer):

  self.rate = rate
  self.max_swaps = max_swaps
- self.seed = random.randint(1, 1e9) if seed is None else seed
+ self.seed = random.randint(1, int(1e9)) if seed is None else seed
  self._generator = tf.random.Generator.from_seed(self.seed)
  self.skip_list = skip_list
  self.skip_fn = skip_fn
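Both one-line fixes address the same detail: `random.randint` expects integer bounds, and the literal `1e9` is a float, which recent Python versions deprecate as a bound and Python 3.12+ rejects. A quick illustration:

```python
import random

print(random.randint(1, int(1e9)))  # OK: both bounds are ints
# random.randint(1, 1e9)  # float bound: deprecated, TypeError on Python 3.12+
```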
keras_hub/src/models/audio_to_text.py
@@ -0,0 +1,66 @@
+ from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+
+
+ class AudioToText(Seq2SeqLM):
+ """Base class for audio-to-text models.
+
+ `AudioToText` tasks wrap a `keras_hub.models.Backbone` (capable of
+ processing audio and text features) and a
+ `keras_hub.models.AudioToTextPreprocessor` to create a model for
+ audio-to-text tasks like speech recognition or audio transcription.
+
+ These models typically consist of an encoder that processes audio input
+ and a decoder that generates a textual representation.
+
+ `AudioToText` tasks provide a high-level `generate()` method for
+ auto-regressively generating text from audio input. An optional text
+ prompt can also be provided to the decoder to guide generation. The
+ sampling strategy for generation (e.g., greedy, top-k, top-p) can be
+ controlled via the `sampler` argument in the `compile()` method.
+
+ When calling `fit()`, inputs should consist of audio data and corresponding
+ target text transcriptions. The model is trained to predict the target text
+ token-by-token.
+
+ All `AudioToText` tasks include a `from_preset()` constructor which
+ can be used to load pre-trained configurations and weights for specific
+ audio-to-text models.
+ This constructor can also be called on the base `AudioToText` class,
+ which will automatically select the correct subclass based on the preset.
+
+ Examples:
+ ```python
+ # Load a Moonshine model with pre-trained weights.
+ # AudioToText is a base class. You will typically work with a specific
+ # implementation, such as `keras_hub.models.MoonshineAudioToText`.
+ # The following examples demonstrate common usage patterns.
+
+ # Initialize a model from a preset using the specific subclass.
+ audio_to_text_model = keras_hub.models.MoonshineAudioToText.from_preset(
+ "moonshine_base_en"
+ )
+
+ # Initialize a model from a preset using the base class.
+ audio_to_text_model_base = keras_hub.models.AudioToText.from_preset(
+ "moonshine_base_en"
+ )
+
+ # Generate text from an audio input.
+ audio_input_tensor = keras.random.normal((1, 16000, 1))
+ generated_output = audio_to_text_model.generate(
+ {"audio": audio_input_tensor}
+ )
+
+ # Generate conditioned on `"The quick brown fox."` as a text prompt.
+ prompted_output = audio_to_text_model.generate(
+ {"audio": audio_input_tensor, "text": "The quick brown fox."}
+ )
+
+ # Use a different sampling strategy for generation.
+ audio_to_text_model.compile(sampler="greedy")
+ greedy_output = audio_to_text_model.generate(
+ {"audio": audio_input_tensor}
+ )
+ ```
+ """
+
+ # TODO: Fill in once audio to text task model requirements are clearer.
keras_hub/src/models/audio_to_text_preprocessor.py
@@ -0,0 +1,80 @@
+ from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
+
+
+ class AudioToTextPreprocessor(Seq2SeqLMPreprocessor):
+ """Base class for audio-to-text preprocessing layers.
+
+ `AudioToTextPreprocessor` layers wrap an audio feature extractor (specific
+ to the subclass) and a `keras_hub.tokenizer.Tokenizer` to create a
+ preprocessing layer for audio-to-text tasks. It is intended to be
+ paired with a `keras_hub.models.AudioToText` task.
+
+ Subclasses are expected to handle the conversion of raw audio data into
+ numerical features suitable for an encoder, and raw text data into token IDs
+ for a decoder.
+
+ All `AudioToTextPreprocessor` layers take a dictionary as input,
+ typically with keys like `"audio"` (for audio data) and `"text"` (for
+ target transcriptions or decoder prompts).
+
+ This layer will always output a `(x, y, sample_weight)` tuple, where `x`
+ is a dictionary containing processed audio features for the encoder and
+ tokenized text inputs for the decoder. `y` contains the target token IDs
+ (decoder input tokens shifted by one position), and `sample_weight`
+ indicates padding in `y`. The exact keys and structure of features within
+ `x` will depend on the specific subclass and the paired `AudioToText` model.
+
+ An `AudioToTextPreprocessor` includes `generate_preprocess` and
+ `generate_postprocess` methods for use during inference with an
+ `AudioToText` model's `generate()` method.
+
+ All `AudioToTextPreprocessor` tasks include a `from_preset()` constructor
+ which can be used to load a pre-trained configuration, including tokenizer
+ vocabularies and audio feature extraction settings. Calling `from_preset()`
+ on this base class can instantiate the correct subclass registered for the
+ given preset.
+
+ Examples:
+ ```python
+ preprocessor = keras_hub.models.AudioToTextPreprocessor.from_preset(
+ "moonshine_base_en",
+ decoder_sequence_length=10
+ )
+
+ # Process a single audio-text pair.
+ x = {
+ "audio": keras.random.normal((1, 16000, 1)),
+ "text": ["the quick brown fox"]
+ }
+ x, y, sample_weight = preprocessor(x)
+
+ # Process a batch of audio-text pairs.
+ x = {
+ "audio": keras.random.normal((2, 16000, 1)),
+ "text": ["first sentence", "second sentence"]
+ }
+ x, y, sample_weight = preprocessor(x)
+
+ # With a `tf.data.Dataset`.
+ audio_tf = keras.ops.convert_to_tensor(batch_input["audio"])
+ text_tf = batch_input["text"] # List of strings
+ x = {"audio": audio_tf, "text": text_tf}
+ ds = tf.data.Dataset.from_tensor_slices(x)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+ ds = ds.batch(2) # Batching after map
+
+ # Generate preprocess and postprocess.
+ x = preprocessor.generate_preprocess({
+ "audio": keras.random.normal((1, 16000, 1)),
+ "text": ["optional prompt text"]
+ })
+ x = preprocessor.generate_postprocess({
+ "decoder_token_ids": keras.ops.array([[10, 20, 30, 2, 0]]),
+ "decoder_padding_mask": keras.ops.array([
+ [True, True, True, True, False]
+ ])
+ })
+ ```
+ """
+
+ # TODO: Fill in once audio to text task model requirements are clearer.
keras_hub/src/models/backbone.py
@@ -177,14 +177,17 @@ class Backbone(keras.Model):
  )
  return loader.load_backbone(backbone_cls, load_weights, **kwargs)

- def save_to_preset(self, preset_dir):
+ def save_to_preset(self, preset_dir, max_shard_size=10):
  """Save backbone to a preset directory.

  Args:
  preset_dir: The path to the local model preset directory.
+ max_shard_size: `int` or `float`. Maximum size in GB for each
+ sharded file. If `None`, no sharding will be done. Defaults to
+ `10`.
  """
  saver = get_preset_saver(preset_dir)
- saver.save_backbone(self)
+ saver.save_backbone(self, max_shard_size=max_shard_size)

  def get_lora_target_names(self):
  """Returns list of layer names which are to be LoRA-fied.
keras_hub/src/models/cspnet/cspnet_backbone.py
@@ -81,7 +81,7 @@ class CSPNetBackbone(FeaturePyramidBackbone):

  # Pretrained backbone
  model = keras_hub.models.CSPNetBackbone.from_preset(
- "cspdarknet53_ra_imagenet"
+ "csp_darknet_53_ra_imagenet"
  )
  model(input_data)

@@ -357,18 +357,6 @@ def bottleneck_block(
  dtype=dtype,
  name=f"{name}_bottleneck_block_bn_3",
  )(x)
- if activation == "leaky_relu":
- x = layers.LeakyReLU(
- negative_slope=0.01,
- dtype=dtype,
- name=f"{name}_bottleneck_block_activation_3",
- )(x)
- else:
- x = layers.Activation(
- activation,
- dtype=dtype,
- name=f"{name}_bottleneck_block_activation_3",
- )(x)

  x = layers.add(
  [x, shortcut], dtype=dtype, name=f"{name}_bottleneck_block_add"
@@ -673,6 +661,13 @@ cross_stage(
  name=f"{name}_csp_activation_1",
  )(x)
  else:
+ if strides > 1:
+ x = layers.ZeroPadding2D(
+ 1,
+ data_format=data_format,
+ dtype=dtype,
+ name=f"{name}_csp_conv_pad_1",
+ )(x)
  x = layers.Conv2D(
  filters=down_chs,
  kernel_size=3,
@@ -882,6 +877,13 @@ cross_stage3(
  name=f"{name}_cs3_activation_1",
  )(x)
  else:
+ if strides > 1:
+ x = layers.ZeroPadding2D(
+ 1,
+ data_format=data_format,
+ dtype=dtype,
+ name=f"{name}_cs3_conv_pad_1",
+ )(x)
  x = layers.Conv2D(
  filters=down_chs,
  kernel_size=3,
@@ -1062,6 +1064,13 @@ dark_stage(
  name=f"{name}_dark_activation_1",
  )(x)
  else:
+ if strides > 1:
+ x = layers.ZeroPadding2D(
+ 1,
+ data_format=data_format,
+ dtype=dtype,
+ name=f"{name}_dark_conv_pad_1",
+ )(x)
  x = layers.Conv2D(
  filters=filters,
  kernel_size=3,
@@ -1091,18 +1100,18 @@ dark_stage(
  dtype=dtype,
  name=f"{name}_dark_activation_1",
  )(x)
- for i in range(depth):
- x = block_fn(
- filters=block_channels,
- dilation=dilation,
- bottle_ratio=bottle_ratio,
- groups=groups,
- activation=activation,
- data_format=data_format,
- channel_axis=channel_axis,
- dtype=dtype,
- name=f"{name}_block_{i}",
- )(x)
+ for i in range(depth):
+ x = block_fn(
+ filters=block_channels,
+ dilation=dilation,
+ bottle_ratio=bottle_ratio,
+ groups=groups,
+ activation=activation,
+ data_format=data_format,
+ channel_axis=channel_axis,
+ dtype=dtype,
+ name=f"{name}_block_{i}",
+ )(x)
  return x

  return apply
@@ -1135,6 +1144,13 @@ create_csp_stem(
  or (i == last_idx and strides > 2 and not pooling)
  else 1
  )
+ if conv_strides > 1:
+ x = layers.ZeroPadding2D(
+ (kernel_size - 1) // 2,
+ data_format=data_format,
+ dtype=dtype,
+ name=f"csp_stem_pad_{i}",
+ )(x)
  x = layers.Conv2D(
  filters=chs,
  kernel_size=kernel_size,
@@ -1167,10 +1183,19 @@ create_csp_stem(

  if pooling == "max":
  assert strides > 2
+ # Pad manually with -inf so that padded positions are never selected
+ # as the max; with zero padding, activations that can go negative
+ # (e.g. LeakyReLU outputs) would lose to the padding value.
+ pad_width = [[1, 1], [1, 1]]
+ if data_format == "channels_last":
+ pad_width += [[0, 0]]
+ else:
+ pad_width = [[0, 0]] + pad_width
+ pad_width = [[0, 0]] + pad_width
+ x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf"))
  x = layers.MaxPooling2D(
  pool_size=3,
  strides=2,
- padding="same",
  data_format=data_format,
  dtype=dtype,
  name="csp_stem_pool",
keras_hub/src/models/cspnet/cspnet_presets.py
@@ -6,11 +6,46 @@ backbone_presets = {
  "description": (
  "A CSP-DarkNet (Cross-Stage-Partial) image classification model"
  " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
- "a 224x224 resolution."
+ "a 256x256 resolution."
  ),
- "params": 26652512,
+ "params": 27642184,
  "path": "cspnet",
  },
- "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/1",
+ "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/2",
+ },
+ "csp_resnext_50_ra_imagenet": {
+ "metadata": {
+ "description": (
+ "A CSP-ResNeXt (Cross-Stage-Partial) image classification model"
+ " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+ "a 256x256 resolution."
+ ),
+ "params": 20569896,
+ "path": "cspnet",
+ },
+ "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnext_50_ra_imagenet/1",
+ },
+ "csp_resnet_50_ra_imagenet": {
+ "metadata": {
+ "description": (
+ "A CSP-ResNet (Cross-Stage-Partial) image classification model"
+ " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+ "a 256x256 resolution."
+ ),
+ "params": 21616168,
+ "path": "cspnet",
+ },
+ "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnet_50_ra_imagenet/1",
+ },
+ "darknet_53_imagenet": {
+ "metadata": {
+ "description": (
+ "A DarkNet image classification model pre-trained on the "
+ "ImageNet 1k dataset at a 256x256 resolution."
+ ),
+ "params": 41609928,
+ "path": "cspnet",
+ },
+ "kaggle_handle": "kaggle://keras/cspdarknet/keras/darknet_53_imagenet/1",
  },
  }
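A hedged loading sketch for the newly registered presets; only the preset names come from the diff above, and the input shape is an arbitrary illustration matching the 256x256 resolution quoted in the descriptions:

```python
import numpy as np
import keras_hub

# The three newly added presets; any of them load the same way, e.g.
# "csp_resnet_50_ra_imagenet" or "darknet_53_imagenet".
backbone = keras_hub.models.CSPNetBackbone.from_preset(
    "csp_resnext_50_ra_imagenet"
)

# Run a dummy batch through the backbone, as in the class docstring example.
features = backbone(np.random.rand(1, 256, 256, 3).astype("float32"))
```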
keras_hub/src/models/falcon/falcon_backbone.py
@@ -29,7 +29,7 @@ class FalconBackbone(Backbone):
  layer_norm_epsilon: float. Epsilon for the layer normalization layers in
  the transformer decoder.
  attention_dropout_rate: float. Dropout probability for the attention.
- feedforward_dropout_rate: flaot. Dropout probability for the
+ feedforward_dropout_rate: float. Dropout probability for the
  feedforward.
  dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
  for model computations and weights. Note that some computations,
keras_hub/src/models/gemma/gemma_presets.py
@@ -61,7 +61,7 @@ backbone_presets = {
  "params": 8537680896,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/3",
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/4",
  },
  "gemma_instruct_7b_en": {
  "metadata": {
@@ -71,7 +71,7 @@ backbone_presets = {
  "params": 8537680896,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/3",
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/4",
  },
  "gemma_1.1_instruct_7b_en": {
  "metadata": {
@@ -82,7 +82,7 @@ backbone_presets = {
  "params": 8537680896,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/4",
+ "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/5",
  },
  "code_gemma_7b_en": {
  "metadata": {
@@ -94,7 +94,7 @@ backbone_presets = {
  "params": 8537680896,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/2",
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/3",
  },
  "code_gemma_instruct_7b_en": {
  "metadata": {
@@ -106,7 +106,7 @@ backbone_presets = {
  "params": 8537680896,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/2",
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/3",
  },
  "code_gemma_1.1_instruct_7b_en": {
  "metadata": {
@@ -118,7 +118,7 @@ backbone_presets = {
  "params": 8537680896,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/2",
+ "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/3",
  },
  "gemma2_2b_en": {
  "metadata": {
@@ -144,7 +144,7 @@ backbone_presets = {
  "params": 9241705984,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/3",
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/4",
  },
  "gemma2_instruct_9b_en": {
  "metadata": {
@@ -154,7 +154,7 @@ backbone_presets = {
  "params": 9241705984,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/3",
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/4",
  },
  "gemma2_27b_en": {
  "metadata": {
@@ -162,7 +162,7 @@ backbone_presets = {
  "params": 27227128320,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/2",
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/3",
  },
  "gemma2_instruct_27b_en": {
  "metadata": {
@@ -172,7 +172,7 @@ backbone_presets = {
  "params": 27227128320,
  "path": "gemma",
  },
- "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/2",
+ "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/3",
  },
  "shieldgemma_2b_en": {
  "metadata": {
keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
@@ -40,13 +40,13 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):

  For use with generation, the layer also exposes two methods
  `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
- is attached to a `keras_hub.models.GemmaCausalLM` instance, these methods
+ is attached to a `keras_hub.models.Gemma3CausalLM` instance, these methods
  will be called implicitly in `generate()`. They can also be called
  standalone (e.g. to precompute preprocessing inputs for generation in a
  separate process).

  Args:
- tokenizer: A `keras_hub.models.GemmaTokenizer` instance.
+ tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
  image_converter: A `keras_hub.layers.ImageConverter` instance. Defaults
  to `None`.
  sequence_length: The length of the packed inputs. Defaults to 1024.
@@ -512,6 +512,7 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):

  # Extract text part of the input.
  prompts, responses = x["prompts"], x["responses"]
+ tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])

  # Find out if the input is batched/not batched. Uprank if not batched.
  # In other preprocessors, we don't have to do this, but here, all
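The added check enforces that `prompts` and `responses` are rank-1 tensors of the same length before any packing happens. A hedged stand-alone sketch of what `tf.debugging.assert_shapes` verifies here, with made-up example strings:

```python
import tensorflow as tf

prompts = tf.constant(["what is keras?", "hi"])
responses = tf.constant(["a deep learning API", "hello"])

# Passes: both tensors are rank-1 and share the same length, bound to "N".
tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])

# A mismatched pair (e.g. 2 prompts vs. 3 responses) would raise an error
# here instead of failing later inside the packing logic.
```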