keras-hub-nightly 0.19.0.dev202412170354__py3-none-any.whl → 0.19.0.dev202412190352__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +3 -0
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +3 -0
- keras_hub/src/models/vit/vit_presets.py +77 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.19.0.dev202412170354.dist-info → keras_hub_nightly-0.19.0.dev202412190352.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.19.0.dev202412170354.dist-info → keras_hub_nightly-0.19.0.dev202412190352.dist-info}/RECORD +14 -8
- {keras_hub_nightly-0.19.0.dev202412170354.dist-info → keras_hub_nightly-0.19.0.dev202412190352.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.19.0.dev202412170354.dist-info → keras_hub_nightly-0.19.0.dev202412190352.dist-info}/top_level.txt +0 -0
keras_hub/api/layers/__init__.py
CHANGED
@@ -35,6 +35,9 @@ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
|
|
35
35
|
from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
|
36
36
|
from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
|
37
37
|
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
|
38
|
+
from keras_hub.src.models.basnet.basnet_image_converter import (
|
39
|
+
BASNetImageConverter,
|
40
|
+
)
|
38
41
|
from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
|
39
42
|
from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
|
40
43
|
DeepLabV3ImageConverter,
|
keras_hub/api/models/__init__.py
CHANGED
@@ -29,6 +29,9 @@ from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import (
|
|
29
29
|
BartSeq2SeqLMPreprocessor,
|
30
30
|
)
|
31
31
|
from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
|
32
|
+
from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter
|
33
|
+
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
|
34
|
+
from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
|
32
35
|
from keras_hub.src.models.bert.bert_backbone import BertBackbone
|
33
36
|
from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM
|
34
37
|
from keras_hub.src.models.bert.bert_masked_lm_preprocessor import (
|
@@ -0,0 +1,122 @@
|
|
1
|
+
import keras
|
2
|
+
|
3
|
+
from keras_hub.src.api_export import keras_hub_export
|
4
|
+
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
|
5
|
+
from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
|
6
|
+
from keras_hub.src.models.image_segmenter import ImageSegmenter
|
7
|
+
|
8
|
+
|
9
|
+
@keras_hub_export("keras_hub.models.BASNetImageSegmenter")
class BASNetImageSegmenter(ImageSegmenter):
    """BASNet image segmentation task.

    Wraps a `BASNetBackbone` so that the task model returns only the
    backbone's `"refine_out"` map, while the training loss is accumulated
    over every output of the backbone (see `compute_loss`).

    Args:
        backbone: A `keras_hub.models.BASNetBackbone` instance.
        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
            a `keras.Layer` instance, or a callable. If `None` no preprocessing
            will be applied to the inputs.

    Example:
    ```python
    import keras
    import keras_hub
    import numpy as np

    images = np.ones(shape=(1, 288, 288, 3))
    labels = np.zeros(shape=(1, 288, 288, 1))

    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
        "resnet_18_imagenet",
        load_weights=False
    )
    backbone = keras_hub.models.BASNetBackbone(
        image_encoder,
        num_classes=1,
        image_shape=[288, 288, 3]
    )
    model = keras_hub.models.BASNetImageSegmenter(backbone)

    # Evaluate the model
    pred_labels = model(images)

    # Train the model
    model.compile(
        optimizer="adam",
        loss=keras.losses.BinaryCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )
    model.fit(images, labels, epochs=3)
    ```
    """

    backbone_cls = BASNetBackbone
    preprocessor_cls = BASNetPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Functional Model ===
        x = backbone.input
        outputs = backbone(x)
        # only return the refinement module's output as final prediction
        outputs = outputs["refine_out"]
        super().__init__(inputs=x, outputs=outputs, **kwargs)

        # === Config ===
        self.backbone = backbone
        self.preprocessor = preprocessor

    def compute_loss(self, x, y, y_pred, *args, **kwargs):
        """Sum the parent loss over every output of the backbone.

        Args:
            x: Input batch; it is fed through `self.backbone` again so that
                all of the backbone's outputs are available.
            y: Ground truth, compared against each backbone output.
            y_pred: Prediction of the task model. Unused here because the
                backbone outputs are recomputed from `x`.
            *args: Extra positional arguments forwarded to the parent
                `compute_loss`.
            **kwargs: Extra keyword arguments forwarded to the parent
                `compute_loss`. When a `training` entry is present it is
                also forwarded to the backbone call.

        Returns:
            The per-output losses summed along axis 0.
        """
        # train BASNet's prediction and refinement module outputs against the
        # same ground truth data
        #
        # Forward the training flag explicitly: this call happens outside of
        # the task model's own `__call__`, so without it the backbone has no
        # training flag to inherit. `None` keeps the previous behavior.
        training = kwargs.get("training")
        outputs = self.backbone(x, training=training)
        # A plain loop is kept on purpose: zero-argument `super()` does not
        # work inside a comprehension on Python versions before 3.12.
        losses = []
        for output in outputs.values():
            losses.append(super().compute_loss(x, y, output, *args, **kwargs))
        return keras.ops.sum(losses, axis=0)

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        metrics="auto",
        **kwargs,
    ):
        """Configures the `BASNet` task for training.

        `BASNet` extends the default compilation signature
        of `keras.Model.compile` with defaults for `optimizer` and `loss`. To
        override these defaults, pass any value to these arguments during
        compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses the default
                optimizer for `BASNet`. See `keras.Model.compile` and
                `keras.optimizers` for more info on possible `optimizer`
                values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, in which case the default loss
                computation of `BASNet` will be applied.
                See `keras.Model.compile` and `keras.losses` for more info on
                possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.Accuracy` will be applied to track the
                accuracy of the model during training.
                See `keras.Model.compile` and `keras.metrics` for
                more info on possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if loss == "auto":
            loss = keras.losses.BinaryCrossentropy()
        if metrics == "auto":
            # NOTE(review): `keras.metrics.Accuracy` scores exact equality of
            # labels and predictions; with sigmoid outputs a thresholded
            # metric such as `BinaryAccuracy` may be what is intended.
            # TODO confirm before changing the default.
            metrics = [keras.metrics.Accuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
|
@@ -0,0 +1,366 @@
|
|
1
|
+
import keras
|
2
|
+
|
3
|
+
from keras_hub.src.api_export import keras_hub_export
|
4
|
+
from keras_hub.src.models.backbone import Backbone
|
5
|
+
from keras_hub.src.models.resnet.resnet_backbone import (
|
6
|
+
apply_basic_block as resnet_basic_block,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
@keras_hub_export("keras_hub.models.BASNetBackbone")
class BASNetBackbone(Backbone):
    """BASNet architecture for semantic segmentation.

    A Keras model implementing the BASNet architecture described in [BASNet:
    Boundary-Aware Segmentation Network for Mobile and Web Applications](
    https://arxiv.org/abs/2101.04704). BASNet uses a predict-refine
    architecture for highly accurate image segmentation.

    The model outputs a dict of sigmoid-activated maps: `"refine_out"` for
    the refinement module and one `"predict_out_{i}"` entry (1-based) per
    output of the prediction module.

    Args:
        image_encoder: A `keras_hub.models.ResNetBackbone` instance. The
            backbone network for the model that is used as a feature extractor
            for BASNet prediction encoder. Currently supported backbones are
            ResNet18 and ResNet34.
            (Note: Do not specify `image_shape` within the backbone.
            Please provide these while initializing the 'BASNetBackbone' model)
        num_classes: int, the number of classes for the segmentation model.
        image_shape: optional shape tuple, defaults to (None, None, 3).
        projection_filters: int, number of filters in the convolution layer
            projecting low-level features from the `backbone`.
        prediction_heads: (Optional) List of `keras.layers.Layer` defining
            the prediction module head for the model. If not provided, a
            default head is created with a Conv2D layer followed by resizing.
        refinement_head: (Optional) a `keras.layers.Layer` defining the
            refinement module head for the model. If not provided, a default
            head is created with a Conv2D layer.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Raises:
        ValueError: If `image_encoder` is not a `keras.Model`, or if its
            `image_shape` is anything other than `(None, None, 3)`.
    """

    def __init__(
        self,
        image_encoder,
        num_classes,
        image_shape=(None, None, 3),
        projection_filters=64,
        prediction_heads=None,
        refinement_head=None,
        dtype=None,
        **kwargs,
    ):
        # `keras.Model` is itself a `keras.layers.Layer`, so checking for
        # `keras.Model` alone rejects exactly the same inputs as the former
        # "not a Layer or not a Model" condition.
        if not isinstance(image_encoder, keras.Model):
            raise ValueError(
                "Argument `image_encoder` must be a `keras.layers.Layer`"
                " instance or `keras.Model`. Received instead"
                f" image_encoder={image_encoder} (of type"
                f" {type(image_encoder)})."
            )

        if tuple(image_encoder.image_shape) != (None, None, 3):
            raise ValueError(
                "Do not specify `image_shape` within the"
                " `BASNetBackbone`'s image_encoder. \nPlease provide"
                " `image_shape` while initializing the 'BASNetBackbone' model."
            )

        # === Functional Model ===
        inputs = keras.layers.Input(shape=image_shape)
        x = inputs

        if prediction_heads is None:
            # Default heads: a 3x3 convolution, followed by bilinear
            # upsampling by `size` for every head whose `size` is not 1.
            prediction_heads = []
            for size in (1, 2, 4, 8, 16, 32, 32):
                head_layers = [
                    keras.layers.Conv2D(
                        num_classes,
                        kernel_size=(3, 3),
                        padding="same",
                        dtype=dtype,
                    )
                ]
                if size != 1:
                    head_layers.append(
                        keras.layers.UpSampling2D(
                            size=size, interpolation="bilinear", dtype=dtype
                        )
                    )
                prediction_heads.append(keras.Sequential(head_layers))

        if refinement_head is None:
            refinement_head = keras.Sequential(
                [
                    keras.layers.Conv2D(
                        num_classes,
                        kernel_size=(3, 3),
                        padding="same",
                        dtype=dtype,
                    ),
                ]
            )

        # Prediction model.
        predict_model = basnet_predict(
            x, image_encoder, projection_filters, prediction_heads, dtype=dtype
        )

        # Refinement model.
        refine_model = basnet_rrm(
            predict_model, projection_filters, refinement_head, dtype=dtype
        )

        # Combine outputs, refinement output first. A new list is built so
        # that `refine_model.outputs` itself is not extended in place.
        outputs = [*refine_model.outputs, *predict_model.outputs]

        output_names = ["refine_out"] + [
            f"predict_out_{i}" for i in range(1, len(outputs))
        ]

        outputs = {
            output_name: keras.layers.Activation(
                "sigmoid", name=output_name, dtype=dtype
            )(output)
            for output, output_name in zip(outputs, output_names)
        }

        super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs)

        # === Config ===
        self.image_encoder = image_encoder
        self.num_classes = num_classes
        self.image_shape = image_shape
        self.projection_filters = projection_filters
        self.prediction_heads = prediction_heads
        self.refinement_head = refinement_head

    def get_config(self):
        """Return the config, with encoder and heads serialized."""
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.saving.serialize_keras_object(
                    self.image_encoder
                ),
                "num_classes": self.num_classes,
                "image_shape": self.image_shape,
                "projection_filters": self.projection_filters,
                "prediction_heads": [
                    keras.saving.serialize_keras_object(prediction_head)
                    for prediction_head in self.prediction_heads
                ],
                "refinement_head": keras.saving.serialize_keras_object(
                    self.refinement_head
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Build the model from `config`, deserializing nested layers.

        Only dict-valued entries are deserialized; entries that are already
        layer instances (or `None`) are passed through untouched. The
        caller's `config` is not modified, so the same dict can be reused.
        """
        # Work on a shallow copy so the caller's dict keeps its serialized
        # entries.
        config = dict(config)

        if isinstance(config.get("image_encoder"), dict):
            config["image_encoder"] = keras.layers.deserialize(
                config["image_encoder"]
            )

        prediction_heads = config.get("prediction_heads")
        if isinstance(prediction_heads, list):
            # Build a fresh list instead of rewriting the caller's in place.
            config["prediction_heads"] = [
                keras.layers.deserialize(head)
                if isinstance(head, dict)
                else head
                for head in prediction_heads
            ]

        if isinstance(config.get("refinement_head"), dict):
            config["refinement_head"] = keras.layers.deserialize(
                config["refinement_head"]
            )
        return super().from_config(config)
|
180
|
+
|
181
|
+
|
182
|
+
def convolution_block(x_input, filters, dilation=1, dtype=None):
    """Run a 3x3 convolution, then batch normalization, then ReLU.

    Args:
        x_input: Input keras tensor.
        filters: int, number of output filters of the convolution.
        dilation: int, dilation rate of the convolution. Defaults to 1.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Returns:
        The tensor obtained after the convolution, batch normalization and
        ReLU activation have been applied in that order.
    """
    conv = keras.layers.Conv2D(
        filters=filters,
        kernel_size=(3, 3),
        padding="same",
        dilation_rate=dilation,
        dtype=dtype,
    )
    features = conv(x_input)

    batch_norm = keras.layers.BatchNormalization(dtype=dtype)
    normalized = batch_norm(features)

    relu = keras.layers.Activation("relu", dtype=dtype)
    return relu(normalized)
|
202
|
+
|
203
|
+
|
204
|
+
def get_resnet_block(_resnet, block_num):
    """Wrap one stage of a ResNet as a standalone Keras model.

    Args:
        _resnet: `keras.Model`. ResNet model instance.
        block_num: int, zero-based index of the stage to extract.

    Returns:
        A Keras Model named `resnet_block{block_num + 1}` that maps the
        stage's input tensor to the output of the final `_add` layer of
        that stage.
    """
    pyramid_levels = ("P2", "P3", "P4", "P5")
    blocks_per_stack = _resnet.stackwise_num_blocks

    # Stage 0 starts at the stem pooling layer; every later stage starts
    # at the pyramid output of the stage before it.
    stage_input = (
        _resnet.get_layer("pool1_pool").output
        if block_num == 0
        else _resnet.pyramid_outputs[pyramid_levels[block_num - 1]]
    )

    last_add_layer = f"stack{block_num}_block{blocks_per_stack[block_num]-1}_add"
    stage_output = _resnet.get_layer(last_add_layer).output

    return keras.models.Model(
        inputs=stage_input,
        outputs=stage_output,
        name=f"resnet_block{block_num + 1}",
    )
|
229
|
+
|
230
|
+
|
231
|
+
def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None):
    """Build the BASNet prediction module.

    The module chains an encoder (four ResNet stages plus two extra stages
    of basic blocks), a dilated-convolution bridge and a decoder with skip
    connections, and applies one segmentation head per decoder stage and
    one to the bridge.

    Args:
        x_input: Input keras tensor.
        backbone: `keras.Model`. The backbone network used as a feature
            extractor for BASNet prediction encoder.
        filters: int, the number of filters.
        segmentation_heads: List of `keras.layers.Layer`, the heads applied
            to the decoder stages (first to last) and then to the bridge.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Returns:
        A Keras Model mapping `x_input` to the list of side outputs of the
        prediction module.
    """
    num_stages = 6
    num_resnet_stages = 4
    wide_filters = filters * 8

    # -------------Encoder--------------
    x = keras.layers.Conv2D(
        filters, kernel_size=(3, 3), padding="same", dtype=dtype
    )(x_input)

    skip_features = []

    # The first four stages are taken from the ResNet backbone.
    for stage in range(num_resnet_stages):
        x = get_resnet_block(backbone, stage)(x)
        skip_features.append(x)

    # The remaining stages downsample, then stack three basic ResNet blocks.
    for stage in range(num_resnet_stages, num_stages):
        x = keras.layers.MaxPool2D(
            pool_size=(2, 2), strides=(2, 2), dtype=dtype
        )(x)
        for block in range(3):
            x = resnet_basic_block(
                x,
                filters=x.shape[3],
                conv_shortcut=False,
                name=f"v1_basic_block_{stage + 1}_{block + 1}",
                dtype=dtype,
            )
        skip_features.append(x)

    # -------------Bridge-------------
    for _ in range(3):
        x = convolution_block(
            x, filters=wide_filters, dilation=2, dtype=dtype
        )
    bridge = x
    skip_features.append(bridge)

    # -------------Decoder-------------
    # Built from the deepest stage up to the shallowest one.
    decoder_features = []
    for stage in range(num_stages - 1, -1, -1):
        # The deepest stage already matches the bridge resolution; all
        # other stages are upsampled before the skip connection is merged.
        if stage < num_stages - 1:
            x = keras.layers.UpSampling2D(
                size=2, interpolation="bilinear", dtype=dtype
            )(x)

        x = keras.layers.concatenate([skip_features[stage], x], axis=-1)
        for _ in range(3):
            x = convolution_block(x, filters=wide_filters, dtype=dtype)
        decoder_features.append(x)

    # Order the stages first-to-last and put the bridge at the end.
    head_inputs = decoder_features[::-1]
    head_inputs.append(bridge)

    # -------------Side Outputs--------------
    side_outputs = []
    for head, features in zip(segmentation_heads, head_inputs):
        side_outputs.append(head(features))

    return keras.models.Model(inputs=[x_input], outputs=side_outputs)
|
312
|
+
|
313
|
+
|
314
|
+
def basnet_rrm(base_model, filters, segmentation_head, dtype=None):
    """BASNet Residual Refinement Module (RRM).

    This module outputs a fine label map by integrating light encoder,
    bridge, and decoder blocks. The refined map is the first output of
    `base_model` plus a residual computed from that same output.

    Args:
        base_model: Keras model used as the base or coarse label map. Its
            first output is the map that gets refined.
        filters: int, the number of filters.
        segmentation_head: a `keras.layers.Layer`, A Keras layer serving
            as the segmentation head for refinement module.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights. It is applied
            to every layer created by this function.

    Returns:
        A Keras Model that constructs the Residual Refinement Module (RRM).
    """
    num_stages = 4

    x_input = base_model.output[0]

    # -------------Encoder--------------
    x = keras.layers.Conv2D(
        filters, kernel_size=(3, 3), padding="same", dtype=dtype
    )(x_input)

    encoder_blocks = []
    for _ in range(num_stages):
        # `dtype` is forwarded here like everywhere else in this module, so
        # these blocks follow the requested dtype policy as well.
        x = convolution_block(x, filters=filters, dtype=dtype)
        encoder_blocks.append(x)
        x = keras.layers.MaxPool2D(
            pool_size=(2, 2), strides=(2, 2), dtype=dtype
        )(x)

    # -------------Bridge--------------
    x = convolution_block(x, filters=filters, dtype=dtype)

    # -------------Decoder--------------
    for i in reversed(range(num_stages)):
        x = keras.layers.UpSampling2D(
            size=2, interpolation="bilinear", dtype=dtype
        )(x)
        x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters, dtype=dtype)

    x = segmentation_head(x)  # Refinement segmentation head.

    # ------------- refined = coarse + residual
    x = keras.layers.Add(dtype=dtype)(
        [x_input, x]
    )  # Add prediction + refinement output

    return keras.models.Model(inputs=base_model.input, outputs=[x])
|
@@ -0,0 +1,8 @@
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
2
|
+
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
|
3
|
+
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
|
4
|
+
|
5
|
+
|
6
|
+
# Image converter registered for BASNet. It adds no logic of its own on top
# of `ImageConverter`; it only ties the converter to `BASNetBackbone`.
@keras_hub_export("keras_hub.layers.BASNetImageConverter")
class BASNetImageConverter(ImageConverter):
    # Backbone class this converter is associated with.
    backbone_cls = BASNetBackbone
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
2
|
+
from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
|
3
|
+
from keras_hub.src.models.basnet.basnet_image_converter import (
|
4
|
+
BASNetImageConverter,
|
5
|
+
)
|
6
|
+
from keras_hub.src.models.image_segmenter_preprocessor import (
|
7
|
+
ImageSegmenterPreprocessor,
|
8
|
+
)
|
9
|
+
|
10
|
+
|
11
|
+
# Preprocessor registered for BASNet. It adds no logic of its own on top of
# `ImageSegmenterPreprocessor`; it only declares the associated classes.
@keras_hub_export("keras_hub.models.BASNetPreprocessor")
class BASNetPreprocessor(ImageSegmenterPreprocessor):
    # Backbone class this preprocessor is associated with.
    backbone_cls = BASNetBackbone
    # Image converter class used by this preprocessor.
    image_converter_cls = BASNetImageConverter
|
@@ -46,4 +46,81 @@ backbone_presets = {
|
|
46
46
|
},
|
47
47
|
"kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/1",
|
48
48
|
},
|
49
|
+
"vit_base_patch32_384_imagenet": {
|
50
|
+
"metadata": {
|
51
|
+
"description": (
|
52
|
+
"ViT-B32 model pre-trained on the ImageNet 1k dataset with "
|
53
|
+
"image resolution of 384x384 "
|
54
|
+
),
|
55
|
+
"params": 87528192,
|
56
|
+
"path": "vit",
|
57
|
+
},
|
58
|
+
"kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/1",
|
59
|
+
},
|
60
|
+
"vit_large_patch32_384_imagenet": {
|
61
|
+
"metadata": {
|
62
|
+
"description": (
|
63
|
+
"ViT-L32 model pre-trained on the ImageNet 1k dataset with "
|
64
|
+
"image resolution of 384x384 "
|
65
|
+
),
|
66
|
+
"params": 305607680,
|
67
|
+
"path": "vit",
|
68
|
+
},
|
69
|
+
"kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/1",
|
70
|
+
},
|
71
|
+
"vit_base_patch16_224_imagenet21k": {
|
72
|
+
"metadata": {
|
73
|
+
"description": (
|
74
|
+
"ViT-B16 backbone pre-trained on the ImageNet 21k dataset with "
|
75
|
+
"image resolution of 224x224 "
|
76
|
+
),
|
77
|
+
"params": 85798656,
|
78
|
+
"path": "vit",
|
79
|
+
},
|
80
|
+
"kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/1",
|
81
|
+
},
|
82
|
+
"vit_base_patch32_224_imagenet21k": {
|
83
|
+
"metadata": {
|
84
|
+
"description": (
|
85
|
+
"ViT-B32 backbone pre-trained on the ImageNet 21k dataset with "
|
86
|
+
"image resolution of 224x224 "
|
87
|
+
),
|
88
|
+
"params": 87455232,
|
89
|
+
"path": "vit",
|
90
|
+
},
|
91
|
+
"kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/1",
|
92
|
+
},
|
93
|
+
"vit_huge_patch14_224_imagenet21k": {
|
94
|
+
"metadata": {
|
95
|
+
"description": (
|
96
|
+
"ViT-H14 backbone pre-trained on the ImageNet 21k dataset with "
|
97
|
+
"image resolution of 224x224 "
|
98
|
+
),
|
99
|
+
"params": 630764800,
|
100
|
+
"path": "vit",
|
101
|
+
},
|
102
|
+
"kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/1",
|
103
|
+
},
|
104
|
+
"vit_large_patch16_224_imagenet21k": {
|
105
|
+
"metadata": {
|
106
|
+
"description": (
|
107
|
+
"ViT-L16 backbone pre-trained on the ImageNet 21k dataset with "
|
108
|
+
"image resolution of 224x224 "
|
109
|
+
),
|
110
|
+
"params": 303301632,
|
111
|
+
"path": "vit",
|
112
|
+
},
|
113
|
+
"kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/1",
|
114
|
+
},
|
115
|
+
"vit_large_patch32_224_imagenet21k": {
|
116
|
+
"metadata": {
|
117
|
+
"description": (
|
118
|
+
"ViT-L32 backbone pre-trained on the ImageNet 21k dataset with "
|
119
|
+
"image resolution of 224x224 "
|
120
|
+
),
|
121
|
+
"params": 305510400,
|
122
|
+
"path": "vit",
|
123
|
+
},
|
124
|
+
"kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/1",
|
125
|
+
},
|
49
126
|
}
|
keras_hub/src/version_utils.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: keras-hub-nightly
|
3
|
-
Version: 0.19.0.
|
3
|
+
Version: 0.19.0.dev202412190352
|
4
4
|
Summary: Industry-strength Natural Language Processing extensions for Keras.
|
5
5
|
Home-page: https://github.com/keras-team/keras-hub
|
6
6
|
Author: Keras team
|
@@ -1,15 +1,15 @@
|
|
1
1
|
keras_hub/__init__.py,sha256=QGdXyHgYt6cMUAP1ebxwc6oR86dE0dkMxNy2eOCQtFo,855
|
2
2
|
keras_hub/api/__init__.py,sha256=spMxsgqzjpeuC8rY4WP-2kAZ2qwwKRSbFwddXgUjqQE,524
|
3
3
|
keras_hub/api/bounding_box/__init__.py,sha256=T8R_X7BPm0et1xaZq8565uJmid7dylsSFSj4V-rGuFQ,1097
|
4
|
-
keras_hub/api/layers/__init__.py,sha256=
|
4
|
+
keras_hub/api/layers/__init__.py,sha256=YO_YLbcxMEboFEgmFkzRf_JfQciQukX2AseOGpWEbDo,3195
|
5
5
|
keras_hub/api/metrics/__init__.py,sha256=So8Ec-lOcTzn_UUMmAdzDm8RKkPu2dbRUm2px8gpUEI,381
|
6
|
-
keras_hub/api/models/__init__.py,sha256=
|
6
|
+
keras_hub/api/models/__init__.py,sha256=suTcar7FqO5w9nNtalqmfYn7Fs6XmNEGpbojK-gaMEY,16795
|
7
7
|
keras_hub/api/samplers/__init__.py,sha256=n-_SEXxr2LNUzK2FqVFN7alsrkx1P_HOVTeLZKeGCdE,730
|
8
8
|
keras_hub/api/tokenizers/__init__.py,sha256=mtJgQy1spfQnPAkeLoeinsT_W9iCWHlJXwzcol5W1aU,2524
|
9
9
|
keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
|
10
10
|
keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
11
|
keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
|
12
|
-
keras_hub/src/version_utils.py,sha256=
|
12
|
+
keras_hub/src/version_utils.py,sha256=mkSaU8Ln1tI0_K9qOrQhUYjd2Esml96pAUrGt42ls1Q,222
|
13
13
|
keras_hub/src/bounding_box/__init__.py,sha256=7i6KnGupN4AVivR_dFjQyuuTbI0GkHy8d-aMXeqZdU8,95
|
14
14
|
keras_hub/src/bounding_box/converters.py,sha256=UUp1hwegpDZyIo8sh9TLNy1v6JjwmvwzL6wmHFMAtbk,21916
|
15
15
|
keras_hub/src/bounding_box/formats.py,sha256=YmskOz2BOSat7NaE__J9VfpSNGPJJR0znSzA4lp8MMI,3868
|
@@ -85,6 +85,12 @@ keras_hub/src/models/bart/bart_presets.py,sha256=ppk9r_4Sm21XO6F9k3L946rkJBwWSLN
|
|
85
85
|
keras_hub/src/models/bart/bart_seq_2_seq_lm.py,sha256=0r9snJsqqmH8F1_CDQZyFgqLNMYJM8AYFkmqfxUNB1U,19262
|
86
86
|
keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py,sha256=3_e-ULIcm_3DKgt7X7cvyLZEDIEkpu9HdANgH6MjZgg,4373
|
87
87
|
keras_hub/src/models/bart/bart_tokenizer.py,sha256=Q7IXmIwXzhPSN427oQRyF9ufoExQGS184Yo_4boaOZo,2811
|
88
|
+
keras_hub/src/models/basnet/__init__.py,sha256=4N6XvIUYYJl5xtoaL3_9fawUX_qP3WmTYNEEU7tn8Gw,253
|
89
|
+
keras_hub/src/models/basnet/basnet.py,sha256=JA58Q9lmygdSOm5MUaPAlaL6B8XnmqCcRaGrk9c8P3Q,4287
|
90
|
+
keras_hub/src/models/basnet/basnet_backbone.py,sha256=t_52WW6jetONS7AnPf9YsiMLDqOjVwjNuayQEv6ZAk4,13503
|
91
|
+
keras_hub/src/models/basnet/basnet_image_converter.py,sha256=DwzAwtZeggYw_qyRQ-Abnnm885Wobv3wClxRzOTscI0,342
|
92
|
+
keras_hub/src/models/basnet/basnet_preprocessor.py,sha256=uM504utaXODSqR5zpKnopRuaV_l84zCg06RkNoNSKIs,510
|
93
|
+
keras_hub/src/models/basnet/basnet_presets.py,sha256=z6tR2q_EvYnUmGfsWIWYfmR_8gvWYPH3QmtpAu_T8f8,63
|
88
94
|
keras_hub/src/models/bert/__init__.py,sha256=K_UmCqDgOFFvXgzjXRn5oG0WWi53rAsQMOmUrsiBe1k,245
|
89
95
|
keras_hub/src/models/bert/bert_backbone.py,sha256=o8GXUpoKPXLpfFzx5u9wI_3rZJeabPfYJEYSI09Clos,8069
|
90
96
|
keras_hub/src/models/bert/bert_masked_lm.py,sha256=8gb1g8h5VFVLmKNEPfLe26z7SOlFnzf9R9okK3rp8AU,4045
|
@@ -339,7 +345,7 @@ keras_hub/src/models/vit/vit_image_classifier.py,sha256=lMVxiD1_6drx7XQ7P7YzlqnF
|
|
339
345
|
keras_hub/src/models/vit/vit_image_classifier_preprocessor.py,sha256=wu6YcBlXMWB9sKCPvmNdGBZKTLQt_HyHWS6P9nyDwsk,504
|
340
346
|
keras_hub/src/models/vit/vit_image_converter.py,sha256=5xVF04BzMcdTDc6aErAYj3_BuGmVd3zoJMcH1ho4T0g,2561
|
341
347
|
keras_hub/src/models/vit/vit_layers.py,sha256=s4j3n3qnJnv6W9AdUkNsO3Vsi_BhxEGECYkaLVCU6XY,13238
|
342
|
-
keras_hub/src/models/vit/vit_presets.py,sha256=
|
348
|
+
keras_hub/src/models/vit/vit_presets.py,sha256=1QSyagzonaK4zpJdnjW2UL70T85xGxktsmLdSxcZTjk,4479
|
343
349
|
keras_hub/src/models/vit_det/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
344
350
|
keras_hub/src/models/vit_det/vit_det_backbone.py,sha256=DOZ5J7c1t5PAZ6y0pMmBoQTMOUup7UoUrYVfCs69ltY,7697
|
345
351
|
keras_hub/src/models/vit_det/vit_layers.py,sha256=mnwu56chMc6zxmfp_hsLdR7TXYy1_YsWy1KwGX9M5Ic,19840
|
@@ -411,7 +417,7 @@ keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYum
|
|
411
417
|
keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
|
412
418
|
keras_hub/src/utils/transformers/preset_loader.py,sha256=DgGJXbTSB9Na8FIR-YWWVqQPOFxHwWrGm41EwcS_EFs,3797
|
413
419
|
keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
|
414
|
-
keras_hub_nightly-0.19.0.
|
415
|
-
keras_hub_nightly-0.19.0.
|
416
|
-
keras_hub_nightly-0.19.0.
|
417
|
-
keras_hub_nightly-0.19.0.
|
420
|
+
keras_hub_nightly-0.19.0.dev202412190352.dist-info/METADATA,sha256=4ggUncw0HlT-6YiKGo6xR7EWBavcvzwzTovgZ4hRwF8,7263
|
421
|
+
keras_hub_nightly-0.19.0.dev202412190352.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
422
|
+
keras_hub_nightly-0.19.0.dev202412190352.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
|
423
|
+
keras_hub_nightly-0.19.0.dev202412190352.dist-info/RECORD,,
|
File without changes
|