keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +15 -33
- keras_hub/layers/__init__.py +134 -0
- keras_hub/metrics/__init__.py +11 -0
- keras_hub/models/__init__.py +642 -0
- keras_hub/samplers/__init__.py +18 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
- keras_hub/src/layers/preprocessing/image_converter.py +1 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
- keras_hub/src/layers/preprocessing/random_swap.py +1 -1
- keras_hub/src/models/audio_to_text.py +66 -0
- keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
- keras_hub/src/models/backbone.py +5 -2
- keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
- keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -1
- keras_hub/src/models/gemma/gemma_presets.py +10 -10
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
- keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/llama/llama_attention.py +24 -6
- keras_hub/src/models/llama/llama_backbone.py +50 -16
- keras_hub/src/models/llama/llama_decoder.py +20 -3
- keras_hub/src/models/llama/llama_presets.py +3 -3
- keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
- keras_hub/src/models/llama3/llama3_backbone.py +10 -2
- keras_hub/src/models/llama3/llama3_presets.py +84 -2
- keras_hub/src/models/mistral/mistral_presets.py +3 -3
- keras_hub/src/models/mixtral/__init__.py +5 -0
- keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
- keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
- keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
- keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
- keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
- keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
- keras_hub/src/models/moonshine/__init__.py +5 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
- keras_hub/src/models/qwen/__init__.py +4 -0
- keras_hub/src/models/qwen/qwen_attention.py +3 -1
- keras_hub/src/models/qwen/qwen_backbone.py +8 -1
- keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
- keras_hub/src/models/qwen/qwen_presets.py +61 -0
- keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
- keras_hub/src/models/qwen_moe/__init__.py +5 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
- keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
- keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
- keras_hub/src/models/segformer/segformer_presets.py +12 -12
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
- keras_hub/src/models/task.py +5 -2
- keras_hub/src/models/xception/__init__.py +5 -0
- keras_hub/src/models/xception/xception_backbone.py +188 -0
- keras_hub/src/models/xception/xception_image_classifier.py +12 -0
- keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/xception/xception_image_converter.py +8 -0
- keras_hub/src/models/xception/xception_presets.py +14 -0
- keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
- keras_hub/src/utils/coco/__init__.py +0 -0
- keras_hub/src/utils/coco/coco_utils.py +133 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
- keras_hub/src/utils/keras_utils.py +11 -0
- keras_hub/src/utils/preset_utils.py +70 -10
- keras_hub/src/utils/tensor_utils.py +27 -1
- keras_hub/src/utils/timm/convert_cspnet.py +94 -23
- keras_hub/src/utils/timm/preset_loader.py +6 -6
- keras_hub/src/utils/transformers/convert_llama3.py +21 -1
- keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
- keras_hub/src/utils/transformers/convert_qwen.py +1 -0
- keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/{version_utils.py → version.py} +1 -1
- keras_hub/tokenizers/__init__.py +117 -0
- keras_hub/utils/__init__.py +21 -0
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
- keras_hub/api/__init__.py +0 -15
- keras_hub/api/layers/__init__.py +0 -86
- keras_hub/api/metrics/__init__.py +0 -11
- keras_hub/api/models/__init__.py +0 -416
- keras_hub/api/samplers/__init__.py +0 -16
- keras_hub/api/tokenizers/__init__.py +0 -58
- keras_hub/api/utils/__init__.py +0 -9
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
keras_hub/src/layers/modeling/reversible_embedding.py
CHANGED
@@ -1,3 +1,5 @@
+import inspect
+
 import keras
 from keras import ops
 
@@ -184,31 +186,33 @@ class ReversibleEmbedding(keras.layers.Embedding):
         else:
             self._quantization_mode_error(self.quantization_mode)
 
-    def _int8_build(
-
-
-
-
-
-
-
-
-
+    def _int8_build(self, embeddings_shape=None):
+        if (
+            "embeddings_shape"
+            in inspect.signature(super()._int8_build).parameters
+        ):
+            if embeddings_shape is None:
+                embeddings_shape = (self.input_dim, self.output_dim)
+            super()._int8_build(embeddings_shape=embeddings_shape)
+        else:
+            # Backward compatibility for older versions of Keras.
+            super()._int8_build()
         self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
         if not self.tie_weights:
             self.reverse_embeddings = self.add_weight(
                 name="reverse_embeddings",
                 shape=(self.output_dim, self.input_dim),
-                initializer=
+                initializer="zeros",
                 dtype="int8",
                 trainable=False,
             )
             self.reverse_embeddings_scale = self.add_weight(
                 name="reverse_embeddings_scale",
                 shape=(self.input_dim,),
-                initializer=
+                initializer="ones",
                 trainable=False,
             )
+        self._is_quantized = True
 
     def _int8_call(self, inputs, reverse=False):
         if reverse:
@@ -232,27 +236,20 @@ class ReversibleEmbedding(keras.layers.Embedding):
         return super()._int8_call(inputs)
 
     def quantize(self, mode, type_check=True):
-        import gc
-
         if type_check and type(self) is not ReversibleEmbedding:
-            raise
-                f"Layer {self.__class__.__name__} does not have a `quantize()` "
-                "method implemented."
-            )
-        self._check_quantize_args(mode, self.compute_dtype)
+            raise self._not_implemented_error(self.quantize)
 
         def abs_max_quantize(inputs, axis):
            return keras.quantizers.abs_max_quantize(
                inputs, axis=axis, to_numpy=True
            )
 
-        self.
+        embeddings_shape = (self.input_dim, self.output_dim)
         if mode == "int8":
            embeddings, embeddings_scale = abs_max_quantize(
                self._embeddings, axis=-1
            )
            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
-           self._untrack_variable(self._embeddings)
            del self._embeddings
            if not self.tie_weights:
                reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
@@ -261,24 +258,17 @@ class ReversibleEmbedding(keras.layers.Embedding):
                reverse_embeddings_scale = ops.squeeze(
                    reverse_embeddings_scale, axis=0
                )
-               self._untrack_variable(self.reverse_embeddings)
                del self.reverse_embeddings
-
-
-
-           self.
-
-
-
-               lambda shape, dtype: reverse_embeddings_scale,
-           )
-       else:
-           raise self._quantization_mode_error(mode)
-       self._tracker.lock()
+       self.quantized_build(embeddings_shape, mode)
+       if mode == "int8":
+           self._embeddings.assign(embeddings)
+           self.embeddings_scale.assign(embeddings_scale)
+           if not self.tie_weights:
+               self.reverse_embeddings.assign(reverse_embeddings)
+               self.reverse_embeddings_scale.assign(reverse_embeddings_scale)
 
        if self.dtype_policy.quantization_mode is None:
            policy = keras.dtype_policies.get(
                f"{mode}_from_{self.dtype_policy.name}"
            )
            self.dtype_policy = policy
-       gc.collect()
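Note: the `quantize()` rewrite above delegates weight creation to `quantized_build()` and then assigns the freshly quantized values, while `_int8_build` only forwards `embeddings_shape` when the installed Keras version accepts it. A minimal usage sketch of the layer-level API touched here (layer sizes are illustrative, not taken from the diff):

```python
import keras
import keras_hub

# A small tied-weights reversible embedding; sizes are illustrative.
layer = keras_hub.layers.ReversibleEmbedding(input_dim=100, output_dim=16)

token_ids = keras.ops.expand_dims(keras.ops.arange(8), 0)
embeddings = layer(token_ids)             # (1, 8, 16); also builds the layer.
logits = layer(embeddings, reverse=True)  # (1, 8, 100); reverse projection.

# Post-training int8 quantization, the code path reworked in this diff.
layer.quantize("int8")
print(layer(token_ids).shape, layer(embeddings, reverse=True).shape)
```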
keras_hub/src/layers/preprocessing/random_deletion.py
CHANGED
@@ -125,7 +125,7 @@ class RandomDeletion(PreprocessingLayer):
 
         self.rate = rate
         self.max_deletions = max_deletions
-        self.seed = random.randint(1, 1e9) if seed is None else seed
+        self.seed = random.randint(1, int(1e9)) if seed is None else seed
         self._generator = tf.random.Generator.from_seed(self.seed)
         self.skip_list = skip_list
         self.skip_fn = skip_fn
keras_hub/src/layers/preprocessing/random_swap.py
CHANGED
@@ -127,7 +127,7 @@ class RandomSwap(PreprocessingLayer):
 
         self.rate = rate
         self.max_swaps = max_swaps
-        self.seed = random.randint(1, 1e9) if seed is None else seed
+        self.seed = random.randint(1, int(1e9)) if seed is None else seed
         self._generator = tf.random.Generator.from_seed(self.seed)
         self.skip_list = skip_list
         self.skip_fn = skip_fn
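Note: the `int(1e9)` change in both `RandomDeletion` and `RandomSwap` matters because `random.randint` requires integer endpoints; passing the float `1e9` was deprecated in Python 3.10 and is rejected on recent versions. A small illustrative sketch:

```python
import random

random.seed(0)

# Works everywhere: both endpoints are ints.
print(random.randint(1, int(1e9)))

# On recent Python versions this raises, because 1e9 is a float, not an int.
try:
    random.randint(1, 1e9)
except (TypeError, ValueError) as e:
    print(f"rejected float endpoint: {e}")
```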
keras_hub/src/models/audio_to_text.py
ADDED
@@ -0,0 +1,66 @@
+from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+
+
+class AudioToText(Seq2SeqLM):
+    """Base class for audio-to-text models.
+
+    `AudioToText` tasks wrap a `keras_hub.models.Backbone` (capable of
+    processing audio and text features) and a
+    `keras_hub.models.AudioToTextPreprocessor` to create a model for
+    audio-to-text tasks like speech recognition or audio transcription.
+
+    These models typically consist of an encoder that processes audio input
+    and a decoder that generates a textual representation.
+
+    `AudioToText` tasks provide a high-level `generate()` method for
+    auto-regressively generating text from audio input. An optional text
+    prompt can also be provided to the decoder to guide generation. The
+    sampling strategy for generation (e.g., greedy, top-k, top-p) can be
+    controlled via the `sampler` argument in the `compile()` method.
+
+    When calling `fit()`, inputs should consist of audio data and corresponding
+    target text transcriptions. The model is trained to predict the target text
+    token-by-token.
+
+    All `AudioToText` tasks include a `from_preset()` constructor which
+    can be used to load pre-trained configurations and weights for specific
+    audio-to-text models.
+    This constructor can also be called on the base `AudioToText` class,
+    which will automatically select the correct subclass based on the preset.
+
+    Examples:
+    ```python
+    # Load a Moonshine backbone with pre-trained weights.
+    # AudioToText is a base class. You will typically work with a specific
+    # implementation, such as `keras_hub.models.MoonshineAudioToText`.
+    # The following examples demonstrate common usage patterns.
+
+    # Initialize a model from a preset using the specific subclass.
+    audio_to_text = keras_hub.models.MoonshineAudioToText.from_preset(
+        "moonshine_base_en"
+    )
+
+    # Initialize a model from a preset using the base class.
+    audio_to_text_model_base = keras_hub.models.AudioToText.from_preset(
+        "moonshine_base_en"
+    )
+
+    # Generate text from an audio input.
+    audio_input_tensor = keras.random.normal((1, 16000, 1))
+    generated_output = audio_to_text_model.generate(
+        {"audio": audio_input_tensor}
+    )
+
+    # Generate conditioned on the `"The quick brown fox."` as an input sequence.
+    prompted_output = audio_to_text_model.generate(
+        {"audio": audio_input_tensor, "text": "The quick brown fox."}
+    )
+
+    # Use a different sampling strategy for generation.
+    audio_to_text_model.compile(sampler="greedy")
+    greedy_output = audio_to_text_model.generate(
+        {"audio": audio_input_tensor}
+    )
+    """
+
+    # TODO: Fill in once audio to text task model requirements are clearer.
keras_hub/src/models/audio_to_text_preprocessor.py
ADDED
@@ -0,0 +1,80 @@
+from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
+
+
+class AudioToTextPreprocessor(Seq2SeqLMPreprocessor):
+    """Base class for audio-to-text preprocessing layers.
+
+    `AudioToTextPreprocessor` layers wrap an audio feature extractor (specific
+    to the subclass) and a `keras_hub.tokenizer.Tokenizer` to create a
+    preprocessing layer for audio-to-text tasks. It is intended to be
+    paired with a `keras_hub.models.AudioToText` task.
+
+    Subclasses are expected to handle the conversion of raw audio data into
+    numerical features suitable for an encoder, and raw text data into token IDs
+    for a decoder.
+
+    All `AudioToTextPreprocessor` layers take a dictionary as input,
+    typically with keys like `"audio"` (for audio data) and `"text"` (for
+    target transcriptions or decoder prompts).
+
+    This layer will always output a `(x, y, sample_weight)` tuple, where `x`
+    is a dictionary containing processed audio features for the encoder and
+    tokenized text inputs for the decoder. `y` contains the target token IDs
+    (decoder input tokens shifted by one position), and `sample_weight`
+    indicates padding in `y`. The exact keys and structure of features within
+    `x` will depend on the specific subclass and the paired `AudioToText` model.
+
+    An `AudioToTextPreprocessor` includes `generate_preprocess` and
+    `generate_postprocess` methods for use during inference with an
+    `AudioToText` model's `generate()` method.
+
+    All `AudioToTextPreprocessor` tasks include a `from_preset()` constructor
+    which can be used to load a pre-trained configuration, including tokenizer
+    vocabularies and audio feature extraction settings. Calling `from_preset()`
+    on this base class can instantiate the correct subclass registered for the
+    given preset.
+
+    Examples:
+    ```python
+    preprocessor = keras_hub.models.AudioToTextPreprocessor.from_preset(
+        "moonshine_base_en",
+        decoder_sequence_length=10
+    )
+
+    # Process a single audio-text pair.
+    x = {
+        "audio": keras.random.normal((1, 16000, 1)),
+        "text": ["the quick brown fox"]
+    }
+    x, y, sample_weight = preprocessor(x)
+
+    # Process a batch of audio-text pairs.
+    x = {
+        "audio": keras.random.normal((2, 16000, 1)),
+        "text": ["first sentence", "second sentence"]
+    }
+    x, y, sample_weight = preprocessor(x)
+
+    # With a `tf.data.Dataset`.
+    audio_tf = keras.ops.convert_to_tensor(batch_input["audio"])
+    text_tf = batch_input["text"]  # List of strings
+    x = {"audio": audio_tf, "text": text_tf}
+    ds = tf.data.Dataset.from_tensor_slices(x)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ds = ds.batch(2)  # Batching after map
+
+    # Generate preprocess and postprocess.
+    x = preprocessor.generate_preprocess({
+        "audio": keras.random.normal((1, 16000, 1)),
+        "text": ["optional prompt text"]
+    })
+    x = preprocessor.generate_postprocess({
+        "decoder_token_ids": keras.ops.array([[10, 20, 30, 2, 0]]),
+        "decoder_padding_mask": keras.ops.array([
+            [True, True, True, True, False]
+        ])
+    })
+    ```
+    """
+
+    # TODO: Fill in once audio to text task model requirements are clearer.
keras_hub/src/models/backbone.py
CHANGED
@@ -177,14 +177,17 @@ class Backbone(keras.Model):
         )
         return loader.load_backbone(backbone_cls, load_weights, **kwargs)
 
-    def save_to_preset(self, preset_dir):
+    def save_to_preset(self, preset_dir, max_shard_size=10):
         """Save backbone to a preset directory.
 
         Args:
             preset_dir: The path to the local model preset directory.
+            max_shard_size: `int` or `float`. Maximum size in GB for each
+                sharded file. If `None`, no sharding will be done. Defaults to
+                `10`.
         """
         saver = get_preset_saver(preset_dir)
-        saver.save_backbone(self)
+        saver.save_backbone(self, max_shard_size=max_shard_size)
 
     def get_lora_target_names(self):
         """Returns list of layer names which are to be LoRA-fied.
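Note: the new `max_shard_size` argument caps the size (in GB) of each weights file written by `save_to_preset`, with `None` disabling sharding. A hedged usage sketch (the preset name and paths are illustrative):

```python
import keras_hub

# Load any backbone and save it as a local preset. Weight files larger than
# `max_shard_size` (in GB) are split across multiple shards.
backbone = keras_hub.models.Backbone.from_preset("gemma_2b_en")
backbone.save_to_preset("./my_preset", max_shard_size=2)

# The saved directory can be reloaded with `from_preset`.
restored = keras_hub.models.Backbone.from_preset("./my_preset")
```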
keras_hub/src/models/cspnet/cspnet_backbone.py
CHANGED
@@ -81,7 +81,7 @@ class CSPNetBackbone(FeaturePyramidBackbone):
 
     # Pretrained backbone
     model = keras_hub.models.CSPNetBackbone.from_preset(
-        "
+        "csp_darknet_53_ra_imagenet"
     )
     model(input_data)
 
@@ -357,18 +357,6 @@ def bottleneck_block(
            dtype=dtype,
            name=f"{name}_bottleneck_block_bn_3",
        )(x)
-       if activation == "leaky_relu":
-           x = layers.LeakyReLU(
-               negative_slope=0.01,
-               dtype=dtype,
-               name=f"{name}_bottleneck_block_activation_3",
-           )(x)
-       else:
-           x = layers.Activation(
-               activation,
-               dtype=dtype,
-               name=f"{name}_bottleneck_block_activation_3",
-           )(x)
 
        x = layers.add(
            [x, shortcut], dtype=dtype, name=f"{name}_bottleneck_block_add"
@@ -673,6 +661,13 @@ def cross_stage(
                name=f"{name}_csp_activation_1",
            )(x)
        else:
+           if strides > 1:
+               x = layers.ZeroPadding2D(
+                   1,
+                   data_format=data_format,
+                   dtype=dtype,
+                   name=f"{name}_csp_conv_pad_1",
+               )(x)
            x = layers.Conv2D(
                filters=down_chs,
                kernel_size=3,
@@ -882,6 +877,13 @@ def cross_stage3(
                name=f"{name}_cs3_activation_1",
            )(x)
        else:
+           if strides > 1:
+               x = layers.ZeroPadding2D(
+                   1,
+                   data_format=data_format,
+                   dtype=dtype,
+                   name=f"{name}_cs3_conv_pad_1",
+               )(x)
            x = layers.Conv2D(
                filters=down_chs,
                kernel_size=3,
@@ -1062,6 +1064,13 @@ def dark_stage(
                name=f"{name}_dark_activation_1",
            )(x)
        else:
+           if strides > 1:
+               x = layers.ZeroPadding2D(
+                   1,
+                   data_format=data_format,
+                   dtype=dtype,
+                   name=f"{name}_dark_conv_pad_1",
+               )(x)
            x = layers.Conv2D(
                filters=filters,
                kernel_size=3,
@@ -1091,18 +1100,18 @@ def dark_stage(
                dtype=dtype,
                name=f"{name}_dark_activation_1",
            )(x)
-
-
-
-
-
-
-
-
-
-
-
-
+       for i in range(depth):
+           x = block_fn(
+               filters=block_channels,
+               dilation=dilation,
+               bottle_ratio=bottle_ratio,
+               groups=groups,
+               activation=activation,
+               data_format=data_format,
+               channel_axis=channel_axis,
+               dtype=dtype,
+               name=f"{name}_block_{i}",
+           )(x)
        return x
 
    return apply
@@ -1135,6 +1144,13 @@ def create_csp_stem(
            or (i == last_idx and strides > 2 and not pooling)
            else 1
        )
+       if conv_strides > 1:
+           x = layers.ZeroPadding2D(
+               (kernel_size - 1) // 2,
+               data_format=data_format,
+               dtype=dtype,
+               name=f"csp_stem_pad_{i}",
+           )(x)
        x = layers.Conv2D(
            filters=chs,
            kernel_size=kernel_size,
@@ -1167,10 +1183,19 @@ def create_csp_stem(
 
    if pooling == "max":
        assert strides > 2
+       # Use manual padding to handle edge case scenario to ignore zero's
+       # as max value instead consider negative values from Leaky Relu type
+       # of activations.
+       pad_width = [[1, 1], [1, 1]]
+       if data_format == "channels_last":
+           pad_width += [[0, 0]]
+       else:
+           pad_width = [[0, 0]] + pad_width
+       pad_width = [[0, 0]] + pad_width
+       x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf"))
        x = layers.MaxPooling2D(
            pool_size=3,
            strides=2,
-           padding="same",
            data_format=data_format,
            dtype=dtype,
            name="csp_stem_pool",
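Note: the stem change above swaps the pool's `padding="same"` for explicit `-inf` padding, matching the in-code comment: if the border is padded with zeros, a padded zero can win the max over negative activations (as they can be after a LeakyReLU), whereas `-inf` padding keeps border maxima driven by real values. A toy NumPy sketch of the arithmetic (values are illustrative):

```python
import numpy as np

# A 1D toy example: all activations are negative, as after a LeakyReLU.
# Window size 3, stride 2, one window reaching past the right border.
row = np.array([-0.5, -0.2, -0.4, -0.1], dtype="float32")

# Zero padding at the border: the padded zero wins the max.
zero_padded = np.concatenate([row, [0.0]])
print(zero_padded[2:5].max())  # 0.0 -- comes from padding, not the data

# -inf padding: the border max still comes from real activations.
inf_padded = np.concatenate([row, [-np.inf]])
print(inf_padded[2:5].max())  # -0.1
```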
keras_hub/src/models/cspnet/cspnet_presets.py
CHANGED
@@ -6,11 +6,46 @@ backbone_presets = {
             "description": (
                 "A CSP-DarkNet (Cross-Stage-Partial) image classification model"
                 " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
-                "a
+                "a 256x256 resolution."
             ),
-            "params":
+            "params": 27642184,
             "path": "cspnet",
         },
-        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_darknet_53_ra_imagenet/2",
+    },
+    "csp_resnext_50_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "A CSP-ResNeXt (Cross-Stage-Partial) image classification model"
+                " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+                "a 256x256 resolution."
+            ),
+            "params": 20569896,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnext_50_ra_imagenet/1",
+    },
+    "csp_resnet_50_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "A CSP-ResNet (Cross-Stage-Partial) image classification model"
+                " pre-trained on the Randomly Augmented ImageNet 1k dataset at "
+                "a 256x256 resolution."
+            ),
+            "params": 21616168,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/csp_resnet_50_ra_imagenet/1",
+    },
+    "darknet_53_imagenet": {
+        "metadata": {
+            "description": (
+                "A DarkNet image classification model pre-trained on the"
+                "ImageNet 1k dataset at a 256x256 resolution."
+            ),
+            "params": 41609928,
+            "path": "cspnet",
+        },
+        "kaggle_handle": "kaggle://keras/cspdarknet/keras/darknet_53_imagenet/1",
     },
 }
keras_hub/src/models/falcon/falcon_backbone.py
CHANGED
@@ -29,7 +29,7 @@ class FalconBackbone(Backbone):
         layer_norm_epsilon: float. Epsilon for the layer normalization layers in
             the transformer decoder.
         attention_dropout_rate: float. Dropout probability for the attention.
-        feedforward_dropout_rate:
+        feedforward_dropout_rate: float. Dropout probability for the
             feedforward.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
keras_hub/src/models/gemma/gemma_presets.py
CHANGED
@@ -61,7 +61,7 @@ backbone_presets = {
             "params": 8537680896,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/4",
     },
     "gemma_instruct_7b_en": {
         "metadata": {
@@ -71,7 +71,7 @@ backbone_presets = {
             "params": 8537680896,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/4",
     },
     "gemma_1.1_instruct_7b_en": {
         "metadata": {
@@ -82,7 +82,7 @@ backbone_presets = {
             "params": 8537680896,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/5",
     },
     "code_gemma_7b_en": {
         "metadata": {
@@ -94,7 +94,7 @@ backbone_presets = {
             "params": 8537680896,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/3",
     },
     "code_gemma_instruct_7b_en": {
         "metadata": {
@@ -106,7 +106,7 @@ backbone_presets = {
             "params": 8537680896,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/3",
     },
     "code_gemma_1.1_instruct_7b_en": {
         "metadata": {
@@ -118,7 +118,7 @@ backbone_presets = {
             "params": 8537680896,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/3",
     },
     "gemma2_2b_en": {
         "metadata": {
@@ -144,7 +144,7 @@ backbone_presets = {
             "params": 9241705984,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/4",
     },
     "gemma2_instruct_9b_en": {
         "metadata": {
@@ -154,7 +154,7 @@ backbone_presets = {
             "params": 9241705984,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/4",
     },
     "gemma2_27b_en": {
         "metadata": {
@@ -162,7 +162,7 @@ backbone_presets = {
             "params": 27227128320,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/3",
     },
     "gemma2_instruct_27b_en": {
         "metadata": {
@@ -172,7 +172,7 @@ backbone_presets = {
             "params": 27227128320,
             "path": "gemma",
         },
-        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/
+        "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/3",
     },
     "shieldgemma_2b_en": {
         "metadata": {
keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
CHANGED
@@ -40,13 +40,13 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
 
     For use with generation, the layer also exposes two methods
     `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
-    is attached to a `keras_hub.models.
+    is attached to a `keras_hub.models.Gemma3CausalLM` instance, these methods
     will be called implicitly in `generate()`. They can also be called
     standalone (e.g. to precompute preprocessing inputs for generation in a
     separate process).
 
     Args:
-        tokenizer: A `keras_hub.models.
+        tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
         image_converter: A `keras_hub.layers.ImageConverter` instance. Defaults
             to `None`.
         sequence_length: The length of the packed inputs. Defaults to 1024.
@@ -512,6 +512,7 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
 
         # Extract text part of the input.
         prompts, responses = x["prompts"], x["responses"]
+        tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))])
 
         # Find out if the input is batched/not batched. Uprank if not batched.
         # In other preprocessors, we don't have to do this, but here, all