keras-hub-nightly 0.16.1.dev202409240339-py3-none-any.whl → 0.16.1.dev202409260340-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. keras_hub/api/layers/__init__.py +5 -0
  2. keras_hub/api/models/__init__.py +19 -0
  3. keras_hub/api/tokenizers/__init__.py +1 -0
  4. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
  5. keras_hub/src/models/clip/clip_preprocessor.py +147 -0
  6. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
  7. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
  8. keras_hub/src/models/densenet/__init__.py +6 -0
  9. keras_hub/src/models/densenet/densenet_backbone.py +11 -8
  10. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
  11. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  12. keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
  13. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  14. keras_hub/src/models/image_segmenter.py +86 -0
  15. keras_hub/src/models/sam/__init__.py +13 -0
  16. keras_hub/src/models/sam/sam_backbone.py +153 -0
  17. keras_hub/src/models/sam/sam_image_segmenter.py +237 -0
  18. keras_hub/src/models/sam/sam_layers.py +402 -0
  19. keras_hub/src/models/sam/sam_mask_decoder.py +270 -0
  20. keras_hub/src/models/sam/sam_prompt_encoder.py +336 -0
  21. keras_hub/src/models/sam/sam_transformer.py +159 -0
  22. keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
  23. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
  24. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
  25. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
  26. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
  27. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
  28. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
  29. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
  30. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
  31. keras_hub/src/models/text_to_image.py +295 -0
  32. keras_hub/src/models/vit_det/vit_det_backbone.py +17 -12
  33. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  34. keras_hub/src/utils/timm/preset_loader.py +3 -0
  35. keras_hub/src/version_utils.py +1 -1
  36. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
  37. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +40 -24
  38. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  39. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  40. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  41. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  42. keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
  43. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
  44. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
@@ -92,11 +92,14 @@ class DenseNetBackbone(FeaturePyramidBackbone):
92
92
  channel_axis,
93
93
  stackwise_num_repeats[stack_index],
94
94
  growth_rate,
95
- name=f"conv{index}",
95
+ name=f"stack{stack_index+1}",
96
96
  )
97
97
  pyramid_outputs[f"P{index}"] = x
98
98
  x = apply_transition_block(
99
- x, channel_axis, compression_ratio, name=f"pool{index}"
99
+ x,
100
+ channel_axis,
101
+ compression_ratio,
102
+ name=f"transition{stack_index+1}",
100
103
  )
101
104
 
102
105
  x = apply_dense_block(
@@ -104,7 +107,7 @@ class DenseNetBackbone(FeaturePyramidBackbone):
104
107
  channel_axis,
105
108
  stackwise_num_repeats[-1],
106
109
  growth_rate,
107
- name=f"conv{len(stackwise_num_repeats) + 1}",
110
+ name=f"stack{len(stackwise_num_repeats)}",
108
111
  )
109
112
  pyramid_outputs[f"P{len(stackwise_num_repeats) + 1}"] = x
110
113
  x = keras.layers.BatchNormalization(
@@ -148,7 +151,7 @@ def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None):
148
151
 
149
152
  for i in range(num_repeats):
150
153
  x = apply_conv_block(
151
- x, channel_axis, growth_rate, name=f"{name}_block_{i}"
154
+ x, channel_axis, growth_rate, name=f"{name}_block{i+1}"
152
155
  )
153
156
  return x
154
157
 
@@ -196,9 +199,9 @@ def apply_conv_block(x, channel_axis, growth_rate, name=None):
196
199
 
197
200
  shortcut = x
198
201
  x = keras.layers.BatchNormalization(
199
- axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_0_bn"
202
+ axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_1_bn"
200
203
  )(x)
201
- x = keras.layers.Activation("relu", name=f"{name}_0_relu")(x)
204
+ x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x)
202
205
  x = keras.layers.Conv2D(
203
206
  4 * growth_rate,
204
207
  1,
@@ -207,9 +210,9 @@ def apply_conv_block(x, channel_axis, growth_rate, name=None):
207
210
  name=f"{name}_1_conv",
208
211
  )(x)
209
212
  x = keras.layers.BatchNormalization(
210
- axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_1_bn"
213
+ axis=channel_axis, epsilon=BN_EPSILON, name=f"{name}_2_bn"
211
214
  )(x)
212
- x = keras.layers.Activation("relu", name=f"{name}_1_relu")(x)
215
+ x = keras.layers.Activation("relu", name=f"{name}_2_relu")(x)
213
216
  x = keras.layers.Conv2D(
214
217
  growth_rate,
215
218
  3,
@@ -15,6 +15,9 @@ import keras
15
15
 
16
16
  from keras_hub.src.api_export import keras_hub_export
17
17
  from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
18
+ from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import (
19
+ DenseNetImageClassifierPreprocessor,
20
+ )
18
21
  from keras_hub.src.models.image_classifier import ImageClassifier
19
22
 
20
23
 
@@ -32,7 +35,13 @@ class DenseNetImageClassifier(ImageClassifier):
32
35
  num_classes: int. The number of classes to predict.
33
36
  activation: `None`, str or callable. The activation function to use on
34
37
  the `Dense` layer. Set `activation=None` to return the output
35
- logits. Defaults to `"softmax"`.
38
+ logits. Defaults to `None`.
39
+ pooling: A pooling layer to use before the final classification layer,
40
+ must be one of "avg" or "max". Use "avg" for
41
+ `GlobalAveragePooling2D` and "max" for "GlobalMaxPooling2D.
42
+ preprocessor: A `keras_hub.models.DenseNetImageClassifierPreprocessor`
43
+ or `None`. If `None`, this model will not apply preprocessing, and
44
+ inputs should be preprocessed before calling the model.
36
45
 
37
46
  Examples:
38
47
 
@@ -86,18 +95,29 @@ class DenseNetImageClassifier(ImageClassifier):
86
95
  """
87
96
 
88
97
  backbone_cls = DenseNetBackbone
98
+ preprocessor_cls = DenseNetImageClassifierPreprocessor
89
99
 
90
100
  def __init__(
91
101
  self,
92
102
  backbone,
93
103
  num_classes,
94
- activation="softmax",
95
- preprocessor=None, # adding this dummy arg for saved model test
96
- # TODO: once preprocessor flow is figured out, this needs to be updated
104
+ activation=None,
105
+ pooling="avg",
106
+ preprocessor=None,
97
107
  **kwargs,
98
108
  ):
99
109
  # === Layers ===
100
110
  self.backbone = backbone
111
+ self.preprocessor = preprocessor
112
+ if pooling == "avg":
113
+ self.pooler = keras.layers.GlobalAveragePooling2D()
114
+ elif pooling == "max":
115
+ self.pooler = keras.layers.GlobalMaxPooling2D()
116
+ else:
117
+ raise ValueError(
118
+ "Unknown `pooling` type. Polling should be either `'avg'` or "
119
+ f"`'max'`. Received: pooling={pooling}."
120
+ )
101
121
  self.output_dense = keras.layers.Dense(
102
122
  num_classes,
103
123
  activation=activation,
@@ -107,6 +127,7 @@ class DenseNetImageClassifier(ImageClassifier):
107
127
  # === Functional Model ===
108
128
  inputs = self.backbone.input
109
129
  x = self.backbone(inputs)
130
+ x = self.pooler(x)
110
131
  outputs = self.output_dense(x)
111
132
  super().__init__(
112
133
  inputs=inputs,
@@ -117,6 +138,7 @@ class DenseNetImageClassifier(ImageClassifier):
117
138
  # === Config ===
118
139
  self.num_classes = num_classes
119
140
  self.activation = activation
141
+ self.pooling = pooling
120
142
 
121
143
  def get_config(self):
122
144
  # Backbone serialized in `super`
@@ -125,6 +147,7 @@ class DenseNetImageClassifier(ImageClassifier):
125
147
  {
126
148
  "num_classes": self.num_classes,
127
149
  "activation": self.activation,
150
+ "pooling": self.pooling,
128
151
  }
129
152
  )
130
153
  return config
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras_hub.src.api_export import keras_hub_export
15
+ from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
16
+ from keras_hub.src.models.densenet.densenet_image_converter import (
17
+ DenseNetImageConverter,
18
+ )
19
+ from keras_hub.src.models.image_classifier_preprocessor import (
20
+ ImageClassifierPreprocessor,
21
+ )
22
+
23
+
24
+ @keras_hub_export("keras_hub.models.DenseNetImageClassifierPreprocessor")
25
+ class DenseNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
26
+ backbone_cls = DenseNetBackbone
27
+ image_converter_cls = DenseNetImageConverter
@@ -0,0 +1,23 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras_hub.src.api_export import keras_hub_export
15
+ from keras_hub.src.layers.preprocessing.resizing_image_converter import (
16
+ ResizingImageConverter,
17
+ )
18
+ from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
19
+
20
+
21
+ @keras_hub_export("keras_hub.layers.DenseNetImageConverter")
22
+ class DenseNetImageConverter(ResizingImageConverter):
23
+ backbone_cls = DenseNetBackbone
@@ -0,0 +1,56 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """DenseNet preset configurations."""
15
+
16
+ backbone_presets = {
17
+ "densenet_121_imagenet": {
18
+ "metadata": {
19
+ "description": (
20
+ "121-layer DenseNet model pre-trained on the ImageNet 1k dataset "
21
+ "at a 224x224 resolution."
22
+ ),
23
+ "params": 7037504,
24
+ "official_name": "DenseNet",
25
+ "path": "densenet",
26
+ "model_card": "https://arxiv.org/abs/1608.06993",
27
+ },
28
+ "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_121_imagenet",
29
+ },
30
+ "densenet_169_imagenet": {
31
+ "metadata": {
32
+ "description": (
33
+ "169-layer DenseNet model pre-trained on the ImageNet 1k dataset "
34
+ "at a 224x224 resolution."
35
+ ),
36
+ "params": 12642880,
37
+ "official_name": "DenseNet",
38
+ "path": "densenet",
39
+ "model_card": "https://arxiv.org/abs/1608.06993",
40
+ },
41
+ "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_169_imagenet",
42
+ },
43
+ "densenet_201_imagenet": {
44
+ "metadata": {
45
+ "description": (
46
+ "201-layer DenseNet model pre-trained on the ImageNet 1k dataset "
47
+ "at a 224x224 resolution."
48
+ ),
49
+ "params": 18321984,
50
+ "official_name": "DenseNet",
51
+ "path": "densenet",
52
+ "model_card": "https://arxiv.org/abs/1608.06993",
53
+ },
54
+ "kaggle_handle": "kaggle://kerashub/densenet/keras/densenet_201_imagenet",
55
+ },
56
+ }
@@ -0,0 +1,86 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import keras
15
+
16
+ from keras_hub.src.api_export import keras_hub_export
17
+ from keras_hub.src.models.task import Task
18
+
19
+
20
+ @keras_hub_export("keras_hub.models.ImageSegmenter")
21
+ class ImageSegmenter(Task):
22
+ """Base class for all image segmentation tasks.
23
+
24
+ `ImageSegmenter` tasks wrap a `keras_hub.models.Task` and
25
+ a `keras_hub.models.Preprocessor` to create a model that can be used for
26
+ image segmentation.
27
+
28
+ All `ImageSegmenter` tasks include a `from_preset()` constructor which can
29
+ be used to load a pre-trained config and weights.
30
+ """
31
+
32
+ def __init__(self, *args, **kwargs):
33
+ super().__init__(*args, **kwargs)
34
+ # Default compilation.
35
+ self.compile()
36
+
37
+ def compile(
38
+ self,
39
+ optimizer="auto",
40
+ loss="auto",
41
+ *,
42
+ metrics="auto",
43
+ **kwargs,
44
+ ):
45
+ """Configures the `ImageSegmenter` task for training.
46
+
47
+ The `ImageSegmenter` task extends the default compilation signature of
48
+ `keras.Model.compile` with defaults for `optimizer`, `loss`, and
49
+ `metrics`. To override these defaults, pass any value
50
+ to these arguments during compilation.
51
+
52
+ Args:
53
+ optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
54
+ instance. Defaults to `"auto"`, which uses the default optimizer
55
+ for the given model and task. See `keras.Model.compile` and
56
+ `keras.optimizers` for more info on possible `optimizer` values.
57
+ loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
58
+ Defaults to `"auto"`, where a
59
+ `keras.losses.SparseCategoricalCrossentropy` loss will be
60
+ applied for the classification task. See
61
+ `keras.Model.compile` and `keras.losses` for more info on
62
+ possible `loss` values.
63
+ metrics: `"auto"`, or a list of metrics to be evaluated by
64
+ the model during training and testing. Defaults to `"auto"`,
65
+ where a `keras.metrics.SparseCategoricalAccuracy` will be
66
+ applied to track the accuracy of the model during training.
67
+ See `keras.Model.compile` and `keras.metrics` for
68
+ more info on possible `metrics` values.
69
+ **kwargs: See `keras.Model.compile` for a full list of arguments
70
+ supported by the compile method.
71
+ """
72
+ if optimizer == "auto":
73
+ optimizer = keras.optimizers.Adam(5e-5)
74
+ if loss == "auto":
75
+ activation = getattr(self, "activation", None)
76
+ activation = keras.activations.get(activation)
77
+ from_logits = activation != keras.activations.softmax
78
+ loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
79
+ if metrics == "auto":
80
+ metrics = [keras.metrics.CategoricalAccuracy()]
81
+ super().compile(
82
+ optimizer=optimizer,
83
+ loss=loss,
84
+ metrics=metrics,
85
+ **kwargs,
86
+ )
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,153 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+
17
+ from keras_hub.src.api_export import keras_hub_export
18
+ from keras_hub.src.models.backbone import Backbone
19
+
20
+
21
+ @keras_hub_export("keras_hub.models.SAMBackbone")
22
+ class SAMBackbone(Backbone):
23
+ """A backbone for the Segment Anything Model (SAM).
24
+
25
+ Args:
26
+ image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor for
27
+ the input images.
28
+ prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer to
29
+ compute embeddings for points, box, and mask prompt.
30
+ mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
31
+ generate segmentation masks given the embeddings generated by the
32
+ backbone and the prompt encoder.
33
+ dtype: The dtype of the layer weights.
34
+
35
+ Example:
36
+ ```python
37
+ image_size=128
38
+ batch_size=2
39
+ input_data = {
40
+ "images": np.ones(
41
+ (batch_size, image_size, image_size, 3),
42
+ dtype="float32",
43
+ ),
44
+ "points": np.ones((batch_size, 1, 2), dtype="float32"),
45
+ "labels": np.ones((batch_size, 1), dtype="float32"),
46
+ "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
47
+ "masks": np.zeros(
48
+ (batch_size, 0, image_size, image_size, 1)
49
+ ),
50
+ }
51
+ image_encoder = keras_hub.models.ViTDetBackbone(
52
+ hidden_size=16,
53
+ num_layers=16,
54
+ intermediate_dim=16 * 4,
55
+ num_heads=16,
56
+ global_attention_layer_indices=[2, 5, 8, 11],
57
+ patch_size=16,
58
+ num_output_channels=8,
59
+ window_size=2,
60
+ image_shape=(image_size, image_size, 3),
61
+ )
62
+ prompt_encoder = keras_hub.layers.SAMPromptEncoder(
63
+ hidden_size=8,
64
+ image_embedding_size=(8, 8),
65
+ input_image_size=(
66
+ image_size,
67
+ image_size,
68
+ ),
69
+ mask_in_channels=16,
70
+ )
71
+ mask_decoder = keras_hub.layers.SAMMaskDecoder(
72
+ num_layers=2,
73
+ hidden_size=8,
74
+ intermediate_dim=32,
75
+ num_heads=8,
76
+ embedding_dim=8,
77
+ num_multimask_outputs=3,
78
+ iou_head_depth=3,
79
+ iou_head_hidden_dim=8,
80
+ )
81
+ backbone = keras_hub.models.SAMBackbone(
82
+ image_encoder=image_encoder,
83
+ prompt_encoder=prompt_encoder,
84
+ mask_decoder=mask_decoder,
85
+ image_shape=(image_size, image_size, 3),
86
+ )
87
+ backbone(input_data)
88
+ ```
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ image_encoder,
94
+ prompt_encoder,
95
+ mask_decoder,
96
+ dtype=None,
97
+ **kwargs,
98
+ ):
99
+ # === Layers ===
100
+ self.image_encoder = image_encoder
101
+ self.prompt_encoder = prompt_encoder
102
+ self.mask_decoder = mask_decoder
103
+ # === Functional model
104
+ image_input = self.image_encoder.input
105
+
106
+ inputs = {
107
+ "images": image_input,
108
+ "points": keras.Input(shape=[None, 2], name="points"),
109
+ "labels": keras.Input(shape=[None], name="labels"),
110
+ "boxes": keras.Input(shape=[None, 2, 2], name="boxes"),
111
+ "masks": keras.Input(shape=[None, None, None, 1], name="masks"),
112
+ }
113
+ image_embeddings = self.image_encoder.output
114
+ prompt_embeddings = self.prompt_encoder(**inputs)
115
+ outputs = {
116
+ "image_embeddings": image_embeddings,
117
+ }
118
+ outputs.update(prompt_embeddings)
119
+ super().__init__(
120
+ inputs=inputs,
121
+ outputs=outputs,
122
+ dtype=dtype,
123
+ **kwargs,
124
+ )
125
+
126
+ def get_config(self):
127
+ config = super().get_config()
128
+ config.update(
129
+ {
130
+ "image_encoder": keras.layers.serialize(self.image_encoder),
131
+ "prompt_encoder": keras.layers.serialize(self.prompt_encoder),
132
+ "mask_decoder": keras.layers.serialize(self.mask_decoder),
133
+ }
134
+ )
135
+ return config
136
+
137
+ @classmethod
138
+ def from_config(cls, config):
139
+ config.update(
140
+ {
141
+ "image_encoder": keras.layers.deserialize(
142
+ config["image_encoder"]
143
+ ),
144
+ "prompt_encoder": keras.layers.deserialize(
145
+ config["prompt_encoder"]
146
+ ),
147
+ "mask_decoder": keras.layers.deserialize(
148
+ config["mask_decoder"]
149
+ ),
150
+ }
151
+ )
152
+
153
+ return super().from_config(config)