keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,366 @@
|
|
1
|
+
import keras
|
2
|
+
|
3
|
+
from keras_hub.src.api_export import keras_hub_export
|
4
|
+
from keras_hub.src.models.backbone import Backbone
|
5
|
+
from keras_hub.src.models.resnet.resnet_backbone import (
|
6
|
+
apply_basic_block as resnet_basic_block,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
@keras_hub_export("keras_hub.models.BASNetBackbone")
class BASNetBackbone(Backbone):
    """BASNet architecture for semantic segmentation.

    A Keras model implementing the BASNet architecture described in [BASNet:
    Boundary-Aware Segmentation Network for Mobile and Web Applications](
    https://arxiv.org/abs/2101.04704). BASNet uses a predict-refine
    architecture for highly accurate image segmentation.

    Args:
        image_encoder: A `keras_hub.models.ResNetBackbone` instance. The
            backbone network for the model that is used as a feature extractor
            for BASNet prediction encoder. Currently supported backbones are
            ResNet18 and ResNet34. It must be a `keras.Model`; its
            `image_shape`, `stackwise_num_blocks`, `pyramid_outputs` and
            `get_layer()` are used to carve out the encoder stages.
            (Note: Do not specify `image_shape` within the backbone.
            Please provide these while initializing the 'BASNetBackbone' model)
        num_classes: int, the number of classes for the segmentation model.
        image_shape: optional shape tuple, defaults to (None, None, 3).
        projection_filters: int, number of filters in the convolution layer
            projecting low-level features from the `backbone`.
        prediction_heads: (Optional) List of `keras.layers.Layer` defining
            the prediction module head for the model. If not provided, a
            default head is created with a Conv2D layer followed by resizing.
        refinement_head: (Optional) a `keras.layers.Layer` defining the
            refinement module head for the model. If not provided, a default
            head is created with a Conv2D layer.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Outputs:
        A dict of sigmoid-activated tensors. `"refine_out"` holds the refined
        map. `"predict_out_1"` ... `"predict_out_N"` hold the side outputs of
        the prediction module, in the order that module returns them.

    Raises:
        ValueError: if `image_encoder` is not a `keras.Model`, or if its
            `image_shape` is not `(None, None, 3)`.
    """

    def __init__(
        self,
        image_encoder,
        num_classes,
        image_shape=(None, None, 3),
        projection_filters=64,
        prediction_heads=None,
        refinement_head=None,
        dtype=None,
        **kwargs,
    ):
        # `keras.Model` is itself a `keras.layers.Layer`, so the requirement
        # is simply "a `keras.Model`": the encoder is sliced into sub-models
        # via `get_layer(...).output` / `pyramid_outputs`, which plain layers
        # do not provide.
        if not isinstance(image_encoder, keras.Model):
            raise ValueError(
                "Argument `image_encoder` must be a `keras.Model` instance."
                f" Received instead image_encoder={image_encoder} (of type"
                f" {type(image_encoder)})."
            )

        if tuple(image_encoder.image_shape) != (None, None, 3):
            raise ValueError(
                "Do not specify `image_shape` within the"
                " `BASNetBackbone`'s image_encoder. \nPlease provide"
                " `image_shape` while initializing the 'BASNetBackbone' model."
            )

        # === Functional Model ===
        inputs = keras.layers.Input(shape=image_shape)
        x = inputs

        if prediction_heads is None:
            # One default head per side output; each head upsamples its
            # decoder stage by the listed factor.
            prediction_heads = []
            for size in (1, 2, 4, 8, 16, 32, 32):
                head_layers = [
                    keras.layers.Conv2D(
                        num_classes,
                        kernel_size=(3, 3),
                        padding="same",
                        dtype=dtype,
                    )
                ]
                if size != 1:
                    head_layers.append(
                        keras.layers.UpSampling2D(
                            size=size, interpolation="bilinear", dtype=dtype
                        )
                    )
                prediction_heads.append(keras.Sequential(head_layers))

        if refinement_head is None:
            refinement_head = keras.Sequential(
                [
                    keras.layers.Conv2D(
                        num_classes,
                        kernel_size=(3, 3),
                        padding="same",
                        dtype=dtype,
                    ),
                ]
            )

        # Prediction model.
        predict_model = basnet_predict(
            x, image_encoder, projection_filters, prediction_heads, dtype=dtype
        )

        # Refinement model.
        refine_model = basnet_rrm(
            predict_model, projection_filters, refinement_head, dtype=dtype
        )

        # Refined output first, then every prediction side output. Build a
        # fresh list instead of extending `refine_model.outputs`, which may
        # be that model's own internal list.
        outputs = list(refine_model.outputs) + list(predict_model.outputs)

        output_names = ["refine_out"] + [
            f"predict_out_{i}" for i in range(1, len(outputs))
        ]

        outputs = {
            output_name: keras.layers.Activation(
                "sigmoid", name=output_name, dtype=dtype
            )(output)
            for output, output_name in zip(outputs, output_names)
        }

        super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs)

        # === Config ===
        self.image_encoder = image_encoder
        self.num_classes = num_classes
        self.image_shape = image_shape
        self.projection_filters = projection_filters
        self.prediction_heads = prediction_heads
        self.refinement_head = refinement_head

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.saving.serialize_keras_object(
                    self.image_encoder
                ),
                "num_classes": self.num_classes,
                "image_shape": self.image_shape,
                "projection_filters": self.projection_filters,
                "prediction_heads": [
                    keras.saving.serialize_keras_object(prediction_head)
                    for prediction_head in self.prediction_heads
                ],
                "refinement_head": keras.saving.serialize_keras_object(
                    self.refinement_head
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Work on a shallow copy so the caller's config (and the list stored
        # under "prediction_heads") is not mutated in place.
        config = dict(config)
        if "image_encoder" in config:
            config["image_encoder"] = keras.layers.deserialize(
                config["image_encoder"]
            )
        prediction_heads = config.get("prediction_heads")
        if isinstance(prediction_heads, list):
            config["prediction_heads"] = [
                keras.layers.deserialize(head)
                if isinstance(head, dict)
                else head
                for head in prediction_heads
            ]
        if isinstance(config.get("refinement_head"), dict):
            config["refinement_head"] = keras.layers.deserialize(
                config["refinement_head"]
            )
        return super().from_config(config)
|
180
|
+
|
181
|
+
|
182
|
+
def convolution_block(x_input, filters, dilation=1, dtype=None):
    """Run `x_input` through a 3x3 Conv2D, BatchNormalization and ReLU.

    The convolution uses `"same"` padding.

    Args:
        x_input: Input keras tensor.
        filters: int, number of output channels of the convolution.
        dilation: int, dilation rate of the convolution. Defaults to 1.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the layers' computations and weights.

    Returns:
        The tensor produced by convolution, batch normalization and ReLU,
        applied in that order.
    """
    conv = keras.layers.Conv2D(
        filters,
        (3, 3),
        padding="same",
        dilation_rate=dilation,
        dtype=dtype,
    )
    norm = keras.layers.BatchNormalization(dtype=dtype)
    relu = keras.layers.Activation("relu", dtype=dtype)
    return relu(norm(conv(x_input)))
|
202
|
+
|
203
|
+
|
204
|
+
def get_resnet_block(_resnet, block_num):
    """Slice one stage out of a ResNet model as a standalone `keras.Model`.

    The sub-model's input depends on `block_num`. Stage 0 starts at the
    output of the `"pool1_pool"` layer. Stage `k > 0` starts at pyramid
    level `P{k + 1}` of `_resnet.pyramid_outputs`. Every stage ends at the
    `add` layer of its last block.

    Args:
        _resnet: `keras.Model`. ResNet model instance.
        block_num: int, index of the stage to extract.

    Returns:
        A Keras Model named `resnet_block{block_num + 1}` representing the
        requested ResNet stage.
    """
    num_blocks = _resnet.stackwise_num_blocks

    if block_num != 0:
        pyramid_level = ("P2", "P3", "P4", "P5")[block_num - 1]
        stage_input = _resnet.pyramid_outputs[pyramid_level]
    else:
        stage_input = _resnet.get_layer("pool1_pool").output

    last_block = num_blocks[block_num] - 1
    stage_output = _resnet.get_layer(
        f"stack{block_num}_block{last_block}_add"
    ).output

    return keras.models.Model(
        inputs=stage_input,
        outputs=stage_output,
        name=f"resnet_block{block_num + 1}",
    )
|
229
|
+
|
230
|
+
|
231
|
+
def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None):
    """BASNet Prediction Module.

    This module outputs a coarse label map by integrating heavy
    encoder, bridge, and decoder blocks.

    Args:
        x_input: Input keras tensor.
        backbone: `keras.Model`. The backbone network used as a feature
            extractor for BASNet prediction encoder.
        filters: int, the number of filters.
        segmentation_heads: List of `keras.layers.Layer`, A list of Keras
            layers serving as the segmentation head for prediction module.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Returns:
        A Keras Model from `x_input` to a list of side outputs. The decoder
        stage outputs (shallowest first), followed by the bridge output, are
        zipped with `segmentation_heads`, and each is passed through its
        head.
    """
    num_stages = 6

    x = x_input

    # -------------Encoder--------------
    x = keras.layers.Conv2D(
        filters, kernel_size=(3, 3), padding="same", dtype=dtype
    )(x)

    encoder_blocks = []
    for i in range(num_stages):
        if i < 4:
            # First four stages are adopted from the ResNet backbone.
            x = get_resnet_block(backbone, i)(x)
        else:
            # Last two stages: downsample, then three basic ResNet blocks
            # that keep the current channel count.
            x = keras.layers.MaxPool2D(
                pool_size=(2, 2), strides=(2, 2), dtype=dtype
            )(x)
            for j in range(3):
                x = resnet_basic_block(
                    x,
                    filters=x.shape[3],
                    conv_shortcut=False,
                    name=f"v1_basic_block_{i + 1}_{j + 1}",
                    dtype=dtype,
                )
        encoder_blocks.append(x)

    # -------------Bridge-------------
    for _ in range(3):
        x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype)
    encoder_blocks.append(x)

    # -------------Decoder-------------
    decoder_blocks = []
    for i in reversed(range(num_stages)):
        if i != (num_stages - 1):
            # The bridge keeps the spatial size of the deepest stage, so only
            # the later (shallower) decoder stages need to upsample.
            x = keras.layers.UpSampling2D(
                size=2, interpolation="bilinear", dtype=dtype
            )(x)

        # Pass `dtype` so the concatenation follows the same dtype policy
        # as every other layer in this module.
        x = keras.layers.concatenate(
            [encoder_blocks[i], x], axis=-1, dtype=dtype
        )
        for _ in range(3):
            x = convolution_block(x, filters=filters * 8, dtype=dtype)
        decoder_blocks.append(x)

    decoder_blocks.reverse()  # Change order from last to first decoder stage.
    decoder_blocks.append(encoder_blocks[-1])  # Copy bridge to decoder.

    # -------------Side Outputs--------------
    decoder_blocks = [
        segmentation_head(decoder_block)  # Prediction segmentation head.
        for segmentation_head, decoder_block in zip(
            segmentation_heads, decoder_blocks
        )
    ]

    return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)
|
312
|
+
|
313
|
+
|
314
|
+
def basnet_rrm(base_model, filters, segmentation_head, dtype=None):
    """BASNet Residual Refinement Module (RRM).

    This module outputs a fine label map by integrating light encoder,
    bridge, and decoder blocks. The result is `coarse + residual`. Here
    `coarse` is the first output of `base_model`, and `residual` is what
    this module predicts from it.

    Args:
        base_model: Keras model used as the base or coarse label map. Its
            first output is refined.
        filters: int, the number of filters.
        segmentation_head: a `keras.layers.Layer`, A Keras layer serving
            as the segmentation head for refinement module. Its output must
            have the same shape as the first output of `base_model`, because
            the two are added.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights. It is applied
            to every layer created here.

    Returns:
        A Keras Model that constructs the Residual Refinement Module (RRM).
        It shares `base_model.input` and has a single output.
    """
    num_stages = 4

    x_input = base_model.output[0]

    # -------------Encoder--------------
    x = keras.layers.Conv2D(
        filters, kernel_size=(3, 3), padding="same", dtype=dtype
    )(x_input)

    encoder_blocks = []
    for _ in range(num_stages):
        # `dtype` is forwarded so these layers honor the requested policy.
        x = convolution_block(x, filters=filters, dtype=dtype)
        encoder_blocks.append(x)
        x = keras.layers.MaxPool2D(
            pool_size=(2, 2), strides=(2, 2), dtype=dtype
        )(x)

    # -------------Bridge--------------
    x = convolution_block(x, filters=filters, dtype=dtype)

    # -------------Decoder--------------
    for i in reversed(range(num_stages)):
        x = keras.layers.UpSampling2D(
            size=2, interpolation="bilinear", dtype=dtype
        )(x)
        x = keras.layers.concatenate(
            [encoder_blocks[i], x], axis=-1, dtype=dtype
        )
        x = convolution_block(x, filters=filters, dtype=dtype)

    x = segmentation_head(x)  # Refinement segmentation head.

    # ------------- refined = coarse + residual
    x = keras.layers.Add(dtype=dtype)([x_input, x])

    return keras.models.Model(inputs=base_model.input, outputs=[x])
|
@@ -0,0 +1,8 @@
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
2
|
+
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
|
3
|
+
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
|
4
|
+
|
5
|
+
|
6
|
+
@keras_hub_export("keras_hub.layers.BASNetImageConverter")
class BASNetImageConverter(ImageConverter):
    """Image converter associated with `BASNetBackbone`.

    All conversion behavior is inherited unchanged from `ImageConverter`.
    This subclass only sets `backbone_cls`, which links it to
    `BASNetBackbone`.
    """

    # Backbone class this converter is paired with.
    backbone_cls = BASNetBackbone
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
2
|
+
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
|
3
|
+
from keras_hub.src.models.basnet.basnet_image_converter import (
|
4
|
+
BASNetImageConverter,
|
5
|
+
)
|
6
|
+
from keras_hub.src.models.image_segmenter_preprocessor import (
|
7
|
+
ImageSegmenterPreprocessor,
|
8
|
+
)
|
9
|
+
|
10
|
+
|
11
|
+
@keras_hub_export("keras_hub.models.BASNetPreprocessor")
class BASNetPreprocessor(ImageSegmenterPreprocessor):
    """Image segmentation preprocessor for BASNet.

    All preprocessing behavior is inherited from
    `ImageSegmenterPreprocessor`. This subclass only sets the backbone class
    and the image converter class it is paired with.
    """

    # Backbone this preprocessor belongs to.
    backbone_cls = BASNetBackbone
    # Converter class used for the image inputs.
    image_converter_cls = BASNetImageConverter
|
@@ -0,0 +1,17 @@
|
|
1
|
+
"""BASNet model preset configurations."""
|
2
|
+
|
3
|
+
# Preset name -> {"metadata": {...}, "kaggle_handle": ...}.
basnet_presets = dict(
    basnet_duts=dict(
        metadata=dict(
            description=(
                "BASNet model with a 34-layer ResNet backbone, pre-trained "
                "on the DUTS image dataset at a 288x288 resolution. Model "
                "training was performed by Hamid Ali "
                "(https://github.com/hamidriasat/BASNet)."
            ),
            params=108886792,
            path="basnet",
        ),
        kaggle_handle="kaggle://keras/basnet/keras/base1",
    ),
)
|
@@ -8,11 +8,9 @@ backbone_presets = {
|
|
8
8
|
"Trained on English Wikipedia + BooksCorpus."
|
9
9
|
),
|
10
10
|
"params": 4385920,
|
11
|
-
"official_name": "BERT",
|
12
11
|
"path": "bert",
|
13
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
14
12
|
},
|
15
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased/
|
13
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased/3",
|
16
14
|
},
|
17
15
|
"bert_small_en_uncased": {
|
18
16
|
"metadata": {
|
@@ -21,11 +19,9 @@ backbone_presets = {
|
|
21
19
|
"Trained on English Wikipedia + BooksCorpus."
|
22
20
|
),
|
23
21
|
"params": 28763648,
|
24
|
-
"official_name": "BERT",
|
25
22
|
"path": "bert",
|
26
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
27
23
|
},
|
28
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_small_en_uncased/
|
24
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_small_en_uncased/3",
|
29
25
|
},
|
30
26
|
"bert_medium_en_uncased": {
|
31
27
|
"metadata": {
|
@@ -34,11 +30,9 @@ backbone_presets = {
|
|
34
30
|
"Trained on English Wikipedia + BooksCorpus."
|
35
31
|
),
|
36
32
|
"params": 41373184,
|
37
|
-
"official_name": "BERT",
|
38
33
|
"path": "bert",
|
39
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
40
34
|
},
|
41
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_medium_en_uncased/
|
35
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_medium_en_uncased/3",
|
42
36
|
},
|
43
37
|
"bert_base_en_uncased": {
|
44
38
|
"metadata": {
|
@@ -47,11 +41,9 @@ backbone_presets = {
|
|
47
41
|
"Trained on English Wikipedia + BooksCorpus."
|
48
42
|
),
|
49
43
|
"params": 109482240,
|
50
|
-
"official_name": "BERT",
|
51
44
|
"path": "bert",
|
52
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
53
45
|
},
|
54
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_en_uncased/
|
46
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_en_uncased/3",
|
55
47
|
},
|
56
48
|
"bert_base_en": {
|
57
49
|
"metadata": {
|
@@ -60,11 +52,9 @@ backbone_presets = {
|
|
60
52
|
"Trained on English Wikipedia + BooksCorpus."
|
61
53
|
),
|
62
54
|
"params": 108310272,
|
63
|
-
"official_name": "BERT",
|
64
55
|
"path": "bert",
|
65
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
66
56
|
},
|
67
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_en/
|
57
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_en/3",
|
68
58
|
},
|
69
59
|
"bert_base_zh": {
|
70
60
|
"metadata": {
|
@@ -72,23 +62,20 @@ backbone_presets = {
|
|
72
62
|
"12-layer BERT model. Trained on Chinese Wikipedia."
|
73
63
|
),
|
74
64
|
"params": 102267648,
|
75
|
-
"official_name": "BERT",
|
76
65
|
"path": "bert",
|
77
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
78
66
|
},
|
79
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_zh/
|
67
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_zh/3",
|
80
68
|
},
|
81
69
|
"bert_base_multi": {
|
82
70
|
"metadata": {
|
83
71
|
"description": (
|
84
|
-
"12-layer BERT model where case is maintained. Trained on
|
72
|
+
"12-layer BERT model where case is maintained. Trained on "
|
73
|
+
"trained on Wikipedias of 104 languages"
|
85
74
|
),
|
86
75
|
"params": 177853440,
|
87
|
-
"official_name": "BERT",
|
88
76
|
"path": "bert",
|
89
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
90
77
|
},
|
91
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_multi/
|
78
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_base_multi/3",
|
92
79
|
},
|
93
80
|
"bert_large_en_uncased": {
|
94
81
|
"metadata": {
|
@@ -97,11 +84,9 @@ backbone_presets = {
|
|
97
84
|
"Trained on English Wikipedia + BooksCorpus."
|
98
85
|
),
|
99
86
|
"params": 335141888,
|
100
|
-
"official_name": "BERT",
|
101
87
|
"path": "bert",
|
102
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
103
88
|
},
|
104
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_large_en_uncased/
|
89
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_large_en_uncased/3",
|
105
90
|
},
|
106
91
|
"bert_large_en": {
|
107
92
|
"metadata": {
|
@@ -110,22 +95,19 @@ backbone_presets = {
|
|
110
95
|
"Trained on English Wikipedia + BooksCorpus."
|
111
96
|
),
|
112
97
|
"params": 333579264,
|
113
|
-
"official_name": "BERT",
|
114
98
|
"path": "bert",
|
115
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
116
99
|
},
|
117
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/
|
100
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/3",
|
118
101
|
},
|
119
102
|
"bert_tiny_en_uncased_sst2": {
|
120
103
|
"metadata": {
|
121
104
|
"description": (
|
122
|
-
"The bert_tiny_en_uncased backbone model fine-tuned on the
|
105
|
+
"The bert_tiny_en_uncased backbone model fine-tuned on the "
|
106
|
+
"SST-2 sentiment analysis dataset."
|
123
107
|
),
|
124
108
|
"params": 4385920,
|
125
|
-
"official_name": "BERT",
|
126
109
|
"path": "bert",
|
127
|
-
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
|
128
110
|
},
|
129
|
-
"kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/
|
111
|
+
"kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/5",
|
130
112
|
},
|
131
113
|
}
|
@@ -34,9 +34,9 @@ class BertTextClassifier(TextClassifier):
|
|
34
34
|
Args:
|
35
35
|
backbone: A `keras_hub.models.BertBackbone` instance.
|
36
36
|
num_classes: int. Number of classes to predict.
|
37
|
-
preprocessor: A `keras_hub.models.BertTextClassifierPreprocessor` or
|
38
|
-
`None`, this model will not apply preprocessing, and
|
39
|
-
be preprocessed before calling the model.
|
37
|
+
preprocessor: A `keras_hub.models.BertTextClassifierPreprocessor` or
|
38
|
+
`None`. If `None`, this model will not apply preprocessing, and
|
39
|
+
inputs should be preprocessed before calling the model.
|
40
40
|
activation: Optional `str` or callable. The
|
41
41
|
activation function to use on the model outputs. Set
|
42
42
|
`activation="softmax"` to return output probabilities.
|
@@ -8,11 +8,9 @@ backbone_presets = {
|
|
8
8
|
"trained on 45 natural languages and 12 programming languages."
|
9
9
|
),
|
10
10
|
"params": 559214592,
|
11
|
-
"official_name": "BLOOM",
|
12
11
|
"path": "bloom",
|
13
|
-
"model_card": "https://huggingface.co/bigscience/bloom-560m",
|
14
12
|
},
|
15
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/
|
13
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/4",
|
16
14
|
},
|
17
15
|
"bloom_1.1b_multi": {
|
18
16
|
"metadata": {
|
@@ -21,11 +19,9 @@ backbone_presets = {
|
|
21
19
|
"trained on 45 natural languages and 12 programming languages."
|
22
20
|
),
|
23
21
|
"params": 1065314304,
|
24
|
-
"official_name": "BLOOM",
|
25
22
|
"path": "bloom",
|
26
|
-
"model_card": "https://huggingface.co/bigscience/bloom-1b1",
|
27
23
|
},
|
28
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.1b_multi/
|
24
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.1b_multi/2",
|
29
25
|
},
|
30
26
|
"bloom_1.7b_multi": {
|
31
27
|
"metadata": {
|
@@ -34,11 +30,9 @@ backbone_presets = {
|
|
34
30
|
"trained on 45 natural languages and 12 programming languages."
|
35
31
|
),
|
36
32
|
"params": 1722408960,
|
37
|
-
"official_name": "BLOOM",
|
38
33
|
"path": "bloom",
|
39
|
-
"model_card": "https://huggingface.co/bigscience/bloom-1b7",
|
40
34
|
},
|
41
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.7b_multi/
|
35
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.7b_multi/2",
|
42
36
|
},
|
43
37
|
"bloom_3b_multi": {
|
44
38
|
"metadata": {
|
@@ -47,11 +41,9 @@ backbone_presets = {
|
|
47
41
|
"trained on 45 natural languages and 12 programming languages."
|
48
42
|
),
|
49
43
|
"params": 3002557440,
|
50
|
-
"official_name": "BLOOM",
|
51
44
|
"path": "bloom",
|
52
|
-
"model_card": "https://huggingface.co/bigscience/bloom-3b",
|
53
45
|
},
|
54
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_3b_multi/
|
46
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloom_3b_multi/2",
|
55
47
|
},
|
56
48
|
"bloomz_560m_multi": {
|
57
49
|
"metadata": {
|
@@ -60,11 +52,9 @@ backbone_presets = {
|
|
60
52
|
"finetuned on crosslingual task mixture (xP3) dataset."
|
61
53
|
),
|
62
54
|
"params": 559214592,
|
63
|
-
"official_name": "BLOOMZ",
|
64
55
|
"path": "bloom",
|
65
|
-
"model_card": "https://huggingface.co/bigscience/bloomz-560m",
|
66
56
|
},
|
67
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_560m_multi/
|
57
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_560m_multi/2",
|
68
58
|
},
|
69
59
|
"bloomz_1.1b_multi": {
|
70
60
|
"metadata": {
|
@@ -73,11 +63,9 @@ backbone_presets = {
|
|
73
63
|
"finetuned on crosslingual task mixture (xP3) dataset."
|
74
64
|
),
|
75
65
|
"params": 1065314304,
|
76
|
-
"official_name": "BLOOMZ",
|
77
66
|
"path": "bloom",
|
78
|
-
"model_card": "https://huggingface.co/bigscience/bloomz-1b1",
|
79
67
|
},
|
80
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.1b_multi/
|
68
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.1b_multi/2",
|
81
69
|
},
|
82
70
|
"bloomz_1.7b_multi": {
|
83
71
|
"metadata": {
|
@@ -86,11 +74,9 @@ backbone_presets = {
|
|
86
74
|
"finetuned on crosslingual task mixture (xP3) dataset."
|
87
75
|
),
|
88
76
|
"params": 1722408960,
|
89
|
-
"official_name": "BLOOMZ",
|
90
77
|
"path": "bloom",
|
91
|
-
"model_card": "https://huggingface.co/bigscience/bloomz-1b7",
|
92
78
|
},
|
93
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.7b_multi/
|
79
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.7b_multi/2",
|
94
80
|
},
|
95
81
|
"bloomz_3b_multi": {
|
96
82
|
"metadata": {
|
@@ -99,10 +85,8 @@ backbone_presets = {
|
|
99
85
|
"finetuned on crosslingual task mixture (xP3) dataset."
|
100
86
|
),
|
101
87
|
"params": 3002557440,
|
102
|
-
"official_name": "BLOOMZ",
|
103
88
|
"path": "bloom",
|
104
|
-
"model_card": "https://huggingface.co/bigscience/bloomz-3b",
|
105
89
|
},
|
106
|
-
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_3b_multi/
|
90
|
+
"kaggle_handle": "kaggle://keras/bloom/keras/bloomz_3b_multi/2",
|
107
91
|
},
|
108
92
|
}
|