keras-hub-nightly 0.16.1.dev202410030339__py3-none-any.whl → 0.16.1.dev202410040340__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +9 -0
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
- keras_hub/src/models/task.py +20 -15
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +172 -0
- keras_hub/src/models/vae/vae_layers.py +740 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD +23 -14
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202410030339.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/top_level.txt +0 -0
keras_hub/api/layers/__init__.py
CHANGED
@@ -34,6 +34,9 @@ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
 from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
 from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter,
 )
keras_hub/api/models/__init__.py
CHANGED
@@ -85,6 +85,15 @@ from keras_hub.src.models.deberta_v3.deberta_v3_text_classifier_preprocessor imp
 from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer,
 )
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
+    DeepLabV3ImageSegmenter,
+)
 from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
 from keras_hub.src.models.densenet.densenet_image_classifier import (
     DenseNetImageClassifier,
keras_hub/src/models/deeplab_v3/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, DeepLabV3Backbone)
keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py
ADDED
@@ -0,0 +1,196 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
+    SpatialPyramidPooling,
+)
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3Backbone")
+class DeepLabV3Backbone(Backbone):
+    """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation.
+
+    This class implements a DeepLabV3 & DeepLabV3Plus architecture as described
+    in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](
+    https://arxiv.org/abs/1802.02611)(ECCV 2018)
+    and [Rethinking Atrous Convolution for Semantic Image Segmentation](
+    https://arxiv.org/abs/1706.05587)(CVPR 2017)
+
+    Args:
+        image_encoder: `keras.Model`. An instance that is used as a feature
+            extractor for the Encoder. Should either be a
+            `keras_hub.models.Backbone` or a `keras.Model` that implements the
+            `pyramid_outputs` property with keys "P2", "P3" etc as values.
+            A somewhat sensible backbone to use in many cases is
+            the `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`.
+        projection_filters: int. Number of filters in the convolution layer
+            projecting low-level features from the `image_encoder`.
+        spatial_pyramid_pooling_key: str. A layer level to extract and perform
+            `spatial_pyramid_pooling`, one of the key from the `image_encoder`
+            `pyramid_outputs` property such as "P4", "P5" etc.
+        upsampling_size: int or tuple of 2 integers. The upsampling factors for
+            rows and columns of `spatial_pyramid_pooling` layer.
+            If `low_level_feature_key` is given then `spatial_pyramid_pooling`s
+            layer resolution should match with the `low_level_feature`s layer
+            resolution to concatenate both the layers for combined encoder
+            outputs.
+        dilation_rates: list. A `list` of integers for parallel dilated conv applied to
+            `SpatialPyramidPooling`. Usually a
+            sample choice of rates are `[6, 12, 18]`.
+        low_level_feature_key: str optional. A layer level to extract the feature
+            from one of the key from the `image_encoder`s `pyramid_outputs`
+            property such as "P2", "P3" etc which will be the Decoder block.
+            Required only when the DeepLabV3Plus architecture needs to be applied.
+        image_shape: tuple. The input shape without the batch size.
+            Defaults to `(None, None, 3)`.
+
+    Example:
+    ```python
+    # Load a trained backbone to extract features from it's `pyramid_outputs`.
+    image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")
+
+    model = keras_hub.models.DeepLabV3Backbone(
+        image_encoder=image_encoder,
+        projection_filters=48,
+        low_level_feature_key="P2",
+        spatial_pyramid_pooling_key="P5",
+        upsampling_size = 8,
+        dilation_rates = [6, 12, 18]
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        spatial_pyramid_pooling_key,
+        upsampling_size,
+        dilation_rates,
+        low_level_feature_key=None,
+        projection_filters=48,
+        image_shape=(None, None, 3),
+        **kwargs,
+    ):
+        if not isinstance(image_encoder, keras.Model):
+            raise ValueError(
+                "Argument `image_encoder` must be a `keras.Model` instance. Received instead "
+                f"{image_encoder} (of type {type(image_encoder)})."
+            )
+        data_format = keras.config.image_data_format()
+        channel_axis = -1 if data_format == "channels_last" else 1
+
+        # === Layers ===
+        inputs = keras.layers.Input(image_shape, name="inputs")
+
+        fpn_model = keras.Model(
+            image_encoder.inputs, image_encoder.pyramid_outputs
+        )
+
+        fpn_outputs = fpn_model(inputs)
+
+        spatial_pyramid_pooling = SpatialPyramidPooling(
+            dilation_rates=dilation_rates
+        )
+        spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key]
+        spp_outputs = spatial_pyramid_pooling(spatial_backbone_features)
+
+        encoder_outputs = keras.layers.UpSampling2D(
+            size=upsampling_size,
+            interpolation="bilinear",
+            name="encoder_output_upsampling",
+            data_format=data_format,
+        )(spp_outputs)
+
+        if low_level_feature_key:
+            decoder_feature = fpn_outputs[low_level_feature_key]
+            low_level_projected_features = apply_low_level_feature_network(
+                decoder_feature, projection_filters, channel_axis
+            )
+
+            encoder_outputs = keras.layers.Concatenate(
+                axis=channel_axis, name="encoder_decoder_concat"
+            )([encoder_outputs, low_level_projected_features])
+        # upsampling to the original image size
+        upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // (
+            int(upsampling_size[0])
+            if isinstance(upsampling_size, tuple)
+            else upsampling_size
+        )
+        # === Functional Model ===
+        x = keras.layers.Conv2D(
+            name="segmentation_head_conv",
+            filters=256,
+            kernel_size=1,
+            padding="same",
+            use_bias=False,
+            data_format=data_format,
+        )(encoder_outputs)
+        x = keras.layers.BatchNormalization(
+            name="segmentation_head_norm", axis=channel_axis
+        )(x)
+        x = keras.layers.ReLU(name="segmentation_head_relu")(x)
+        x = keras.layers.UpSampling2D(
+            size=upsampling,
+            interpolation="bilinear",
+            data_format=data_format,
+            name="backbone_output_upsampling",
+        )(x)
+
+        super().__init__(inputs=inputs, outputs=x, **kwargs)
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.image_encoder = image_encoder
+        self.projection_filters = projection_filters
+        self.upsampling_size = upsampling_size
+        self.dilation_rates = dilation_rates
+        self.low_level_feature_key = low_level_feature_key
+        self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_encoder": keras.saving.serialize_keras_object(
+                    self.image_encoder
+                ),
+                "projection_filters": self.projection_filters,
+                "dilation_rates": self.dilation_rates,
+                "upsampling_size": self.upsampling_size,
+                "low_level_feature_key": self.low_level_feature_key,
+                "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key,
+                "image_shape": self.image_shape,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
+
+
+def apply_low_level_feature_network(
+    input_tensor, projection_filters, channel_axis
+):
+    data_format = keras.config.image_data_format()
+    x = keras.layers.Conv2D(
+        name="decoder_conv",
+        filters=projection_filters,
+        kernel_size=1,
+        padding="same",
+        use_bias=False,
+        data_format=data_format,
+    )(input_tensor)
+
+    x = keras.layers.BatchNormalization(name="decoder_norm", axis=channel_axis)(
+        x
+    )
+    x = keras.layers.ReLU(name="decoder_relu")(x)
+    return x
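For orientation, the constructor above makes the DeepLabV3 vs. DeepLabV3Plus split hinge entirely on `low_level_feature_key`: when it is `None`, the decoder branch is skipped. A minimal sketch of both configurations, reusing the preset name from the class docstring (shapes and preset availability are assumptions, not something this diff guarantees):

```python
import keras_hub

# Shared feature extractor; anything exposing `pyramid_outputs` works.
image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")

# DeepLabV3: ASPP on "P5" only, no decoder branch.
deeplab_v3 = keras_hub.models.DeepLabV3Backbone(
    image_encoder=image_encoder,
    spatial_pyramid_pooling_key="P5",
    upsampling_size=8,
    dilation_rates=[6, 12, 18],
)

# DeepLabV3Plus: `low_level_feature_key` activates the decoder, which
# projects "P2" features and concatenates them with the upsampled ASPP
# output. `upsampling_size=8` lifts the stride-32 "P5" map to the
# stride-4 resolution of "P2" so the concatenation lines up.
deeplab_v3_plus = keras_hub.models.DeepLabV3Backbone(
    image_encoder=image_encoder,
    projection_filters=48,
    low_level_feature_key="P2",
    spatial_pyramid_pooling_key="P5",
    upsampling_size=8,
    dilation_rates=[6, 12, 18],
)
```

With `spatial_pyramid_pooling_key="P5"` and `upsampling_size=8`, the final head then upsamples by `2**5 // 8 = 4`, per the `upsampling` arithmetic in `__init__` above.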
keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py
ADDED
@@ -0,0 +1,10 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+
+
+@keras_hub_export("keras_hub.layers.DeepLabV3ImageConverter")
+class DeepLabV3ImageConverter(ImageConverter):
+    backbone_cls = DeepLabV3Backbone
keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py
ADDED
@@ -0,0 +1,16 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
+from keras_hub.src.models.image_segmenter_preprocessor import (
+    ImageSegmenterPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenterPreprocessor")
+class DeepLabV3ImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
+    backbone_cls = DeepLabV3Backbone
+    image_converter_cls = DeepLabV3ImageConverter
keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py
ADDED
@@ -0,0 +1,215 @@
+import keras
+from keras import ops
+
+
+class SpatialPyramidPooling(keras.layers.Layer):
+    """Implements the Atrous Spatial Pyramid Pooling.
+
+    Reference for Atrous Spatial Pyramid Pooling [Rethinking Atrous Convolution
+    for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) and
+    [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+    Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
+
+    Args:
+        dilation_rates: list of ints. The dilation rate for parallel dilated conv.
+            Usually a sample choice of rates are `[6, 12, 18]`.
+        num_channels: int. The number of output channels, defaults to `256`.
+        activation: str. Activation to be used, defaults to `relu`.
+        dropout: float. The dropout rate of the final projection output after the
+            activations and batch norm, defaults to `0.0`, which means no dropout is
+            applied to the output.
+
+    Example:
+    ```python
+    inp = keras.layers.Input((384, 384, 3))
+    backbone = keras.applications.EfficientNetB0(
+        input_tensor=inp,
+        include_top=False)
+    output = backbone(inp)
+    output = SpatialPyramidPooling(
+        dilation_rates=[6, 12, 18])(output)
+    ```
+    """
+
+    def __init__(
+        self,
+        dilation_rates,
+        num_channels=256,
+        activation="relu",
+        dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.dilation_rates = dilation_rates
+        self.num_channels = num_channels
+        self.activation = activation
+        self.dropout = dropout
+        self.data_format = keras.config.image_data_format()
+        self.channel_axis = -1 if self.data_format == "channels_last" else 1
+
+    def build(self, input_shape):
+        channels = input_shape[self.channel_axis]
+
+        # This is the parallel networks that process the input features with
+        # different dilation rates. The output from each channel will be merged
+        # together and feed to the output.
+        self.aspp_parallel_channels = []
+
+        # Channel1 with Conv2D and 1x1 kernel size.
+        conv_sequential = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=self.num_channels,
+                    kernel_size=(1, 1),
+                    use_bias=False,
+                    data_format=self.data_format,
+                    name="aspp_conv_1",
+                ),
+                keras.layers.BatchNormalization(
+                    axis=self.channel_axis, name="aspp_bn_1"
+                ),
+                keras.layers.Activation(
+                    self.activation, name="aspp_activation_1"
+                ),
+            ]
+        )
+        conv_sequential.build(input_shape)
+        self.aspp_parallel_channels.append(conv_sequential)
+
+        # Channel 2 and afterwards are based on self.dilation_rates, and each of
+        # them will have conv2D with 3x3 kernel size.
+        for i, dilation_rate in enumerate(self.dilation_rates):
+            conv_sequential = keras.Sequential(
+                [
+                    keras.layers.Conv2D(
+                        filters=self.num_channels,
+                        kernel_size=(3, 3),
+                        padding="same",
+                        dilation_rate=dilation_rate,
+                        use_bias=False,
+                        data_format=self.data_format,
+                        name=f"aspp_conv_{i+2}",
+                    ),
+                    keras.layers.BatchNormalization(
+                        axis=self.channel_axis, name=f"aspp_bn_{i+2}"
+                    ),
+                    keras.layers.Activation(
+                        self.activation, name=f"aspp_activation_{i+2}"
+                    ),
+                ]
+            )
+            conv_sequential.build(input_shape)
+            self.aspp_parallel_channels.append(conv_sequential)
+
+        # Last channel is the global average pooling with conv2D 1x1 kernel.
+        if self.channel_axis == -1:
+            reshape = keras.layers.Reshape((1, 1, channels), name="reshape")
+        else:
+            reshape = keras.layers.Reshape((channels, 1, 1), name="reshape")
+        pool_sequential = keras.Sequential(
+            [
+                keras.layers.GlobalAveragePooling2D(
+                    data_format=self.data_format, name="average_pooling"
+                ),
+                reshape,
+                keras.layers.Conv2D(
+                    filters=self.num_channels,
+                    kernel_size=(1, 1),
+                    use_bias=False,
+                    data_format=self.data_format,
+                    name="conv_pooling",
+                ),
+                keras.layers.BatchNormalization(
+                    axis=self.channel_axis, name="bn_pooling"
+                ),
+                keras.layers.Activation(
+                    self.activation, name="activation_pooling"
+                ),
+            ]
+        )
+        pool_sequential.build(input_shape)
+        self.aspp_parallel_channels.append(pool_sequential)
+
+        # Final projection layers
+        projection = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=self.num_channels,
+                    kernel_size=(1, 1),
+                    use_bias=False,
+                    data_format=self.data_format,
+                    name="conv_projection",
+                ),
+                keras.layers.BatchNormalization(
+                    axis=self.channel_axis, name="bn_projection"
+                ),
+                keras.layers.Activation(
+                    self.activation, name="activation_projection"
+                ),
+                keras.layers.Dropout(rate=self.dropout, name="dropout"),
+            ],
+        )
+        projection_input_channels = (
+            2 + len(self.dilation_rates)
+        ) * self.num_channels
+        if self.data_format == "channels_first":
+            projection.build(
+                (input_shape[0],)
+                + (projection_input_channels,)
+                + (input_shape[2:])
+            )
+        else:
+            projection.build((input_shape[:-1]) + (projection_input_channels,))
+        self.projection = projection
+        self.built = True
+
+    def call(self, inputs):
+        """Calls the Atrous Spatial Pyramid Pooling layer on an input.
+
+        Args:
+            inputs: A tensor of shape [batch, height, width, channels]
+
+        Returns:
+            A tensor of shape [batch, height, width, num_channels]
+        """
+        result = []
+
+        for channel in self.aspp_parallel_channels:
+            temp = ops.cast(channel(inputs), inputs.dtype)
+            result.append(temp)
+
+        image_shape = ops.shape(inputs)
+        if self.channel_axis == -1:
+            height, width = image_shape[1], image_shape[2]
+        else:
+            height, width = image_shape[2], image_shape[3]
+        result[-1] = keras.layers.Resizing(
+            height,
+            width,
+            interpolation="bilinear",
+            data_format=self.data_format,
+            name="resizing",
+        )(result[-1])
+
+        result = ops.concatenate(result, axis=self.channel_axis)
+        return self.projection(result)
+
+    def compute_output_shape(self, inputs_shape):
+        if self.data_format == "channels_first":
+            return tuple(
+                (inputs_shape[0],) + (self.num_channels,) + (inputs_shape[2:])
+            )
+        else:
+            return tuple((inputs_shape[:-1]) + (self.num_channels,))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "dilation_rates": self.dilation_rates,
+                "num_channels": self.num_channels,
+                "activation": self.activation,
+                "dropout": self.dropout,
+            }
+        )
+        return config
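A quick way to see the branch arithmetic in `build` above: with rates `[6, 12, 18]` the layer holds five parallel branches (one 1x1 conv, three dilated 3x3 convs, one pooled branch), so the projection consumes `(2 + len(dilation_rates)) * num_channels` features and emits `num_channels`. A sketch assuming the default `channels_last` data format; the feature shape is illustrative:

```python
import keras
from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
    SpatialPyramidPooling,
)

aspp = SpatialPyramidPooling(dilation_rates=[6, 12, 18], num_channels=256)
features = keras.random.uniform((1, 24, 24, 2048))  # e.g. a "P5" feature map
outputs = aspp(features)

# Spatial size is preserved; channels are projected down to `num_channels`.
print(outputs.shape)                                 # (1, 24, 24, 256)
print(aspp.compute_output_shape((1, 24, 24, 2048)))  # (1, 24, 24, 256)
```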
keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py
ADDED
@@ -0,0 +1,109 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenter")
+class DeepLabV3ImageSegmenter(ImageSegmenter):
+    """DeepLabV3 and DeeplabV3 and DeeplabV3Plus segmentation task.
+
+    Args:
+        backbone: A `keras_hub.models.DeepLabV3` instance.
+        num_classes: int. The number of classes for the detection model. Note
+            that the `num_classes` contains the background class, and the
+            classes from the data should be represented by integers with range
+            `[0, num_classes]`.
+        activation: str or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `None`.
+        preprocessor: A `keras_hub.models.DeepLabV3ImageSegmenterPreprocessor`
+            or `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
+
+    Example:
+    Load a DeepLabV3 preset with all the 21 class, pretrained segmentation head.
+    ```python
+    images = np.ones(shape=(1, 96, 96, 3))
+    labels = np.zeros(shape=(1, 96, 96, 1))
+    segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+        "deeplabv3_resnet50_pascalvoc",
+    )
+    segmenter.predict(images)
+    ```
+
+    Specify `num_classes` to load randomly initialized segmentation head.
+    ```python
+    segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+        "deeplabv3_resnet50_pascalvoc",
+        num_classes=2,
+    )
+    segmenter.fit(images, labels, epochs=3)
+    segmenter.predict(images)  # Trained 2 class segmentation.
+    ```
+    Load DeepLabv3+ presets a extension of DeepLabv3 by adding a simple yet
+    effective decoder module to refine the segmentation results especially
+    along object boundaries.
+    ```python
+    segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
+        "deeplabv3_plus_resnet50_pascalvoc",
+    )
+    segmenter.predict(images)
+    ```
+    """
+
+    backbone_cls = DeepLabV3Backbone
+    preprocessor_cls = DeepLabV3ImageSegmenterPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        activation=None,
+        preprocessor=None,
+        **kwargs,
+    ):
+        data_format = keras.config.image_data_format()
+        # === Layers ===
+        self.output_conv = keras.layers.Conv2D(
+            name="segmentation_output",
+            filters=num_classes,
+            kernel_size=1,
+            use_bias=False,
+            padding="same",
+            activation=activation,
+            data_format=data_format,
+        )
+
+        # === Functional Model ===
+        inputs = backbone.input
+        x = backbone(inputs)
+        outputs = self.output_conv(x)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.backbone = backbone
+        self.num_classes = num_classes
+        self.activation = activation
+        self.preprocessor = preprocessor
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "activation": self.activation,
+            }
+        )
+        return config
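Putting the pieces of this release together, the segmenter is just the backbone plus a 1x1 `Conv2D` head with `num_classes` filters. A hedged end-to-end sketch (the encoder preset name comes from the backbone docstring; `num_classes=21` follows the Pascal VOC convention cited above):

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.DeepLabV3Backbone(
    image_encoder=keras_hub.models.ResNetBackbone.from_preset(
        "resnet_50_imagenet"
    ),
    low_level_feature_key="P2",
    spatial_pyramid_pooling_key="P5",
    upsampling_size=8,
    dilation_rates=[6, 12, 18],
)
segmenter = keras_hub.models.DeepLabV3ImageSegmenter(
    backbone=backbone,
    num_classes=21,  # 20 foreground classes plus background
    activation="softmax",
)
# One class score per pixel; the spatial size follows the backbone's
# upsampled output.
masks = segmenter.predict(np.ones((1, 96, 96, 3)))
```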
keras_hub/src/models/image_segmenter_preprocessor.py
CHANGED
@@ -19,9 +19,11 @@ class ImageSegmenterPreprocessor(Preprocessor):
 
     - `x`: The first input, should always be included. It can be an image or
       a batch of images.
-    - `y`: (Optional) Usually the segmentation mask(s),
-
+    - `y`: (Optional) Usually the segmentation mask(s), if `resize_output_mask`
+      is set to `True` this will be resized to input image shape else will be
+      passed through unaltered.
     - `sample_weight`: (Optional) Will be passed through unaltered.
+    - `resize_output_mask` bool: If set to `True` the output mask will be resized to the same size as the input image. Defaults to `False`.
 
     The layer will output either `x`, an `(x, y)` tuple if labels were provided,
     or an `(x, y, sample_weight)` tuple if labels and sample weight were
@@ -29,7 +31,7 @@ class ImageSegmenterPreprocessor(Preprocessor):
     been applied.
 
     All `ImageSegmenterPreprocessor` tasks include a `from_preset()`
-    constructor which can be used to load a pre-trained config
+    constructor which can be used to load a pre-trained config.
     You can call the `from_preset()` constructor directly on this base class, in
     which case the correct class for your model will be automatically
     instantiated.
@@ -49,7 +51,8 @@ class ImageSegmenterPreprocessor(Preprocessor):
     x, y = preprocessor(x, y)
 
     # Resize a batch of images and masks.
-    x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))],
+    x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))],
+        [np.ones((512, 512, 1)), np.zeros((512, 512, 1))]
     x, y = preprocessor(x, y)
 
     # Use a `tf.data.Dataset`.
@@ -61,13 +64,35 @@ class ImageSegmenterPreprocessor(Preprocessor):
     def __init__(
         self,
         image_converter=None,
+        resize_output_mask=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.image_converter = image_converter
+        self.resize_output_mask = resize_output_mask
 
     @preprocessing_function
     def call(self, x, y=None, sample_weight=None):
         if self.image_converter:
             x = self.image_converter(x)
+
+        if y is not None and self.image_converter and self.resize_output_mask:
+
+            y = keras.layers.Resizing(
+                height=(
+                    self.image_converter.image_size[0]
+                    if self.image_converter.image_size
+                    else None
+                ),
+                width=(
+                    self.image_converter.image_size[1]
+                    if self.image_converter.image_size
+                    else None
+                ),
+                crop_to_aspect_ratio=self.image_converter.crop_to_aspect_ratio,
+                interpolation="nearest",
+                data_format=self.image_converter.data_format,
+                dtype=self.dtype_policy,
+                name="mask_resizing",
+            )(y)
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
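The practical effect of the new `resize_output_mask` flag: when the image converter resizes `x`, the mask `y` previously kept its original shape and fell out of alignment. A hypothetical sketch (the preset name is borrowed from the DeepLabV3 docstrings above, and forwarding the flag through `from_preset` this way is an assumption):

```python
import numpy as np
import keras_hub

preprocessor = keras_hub.models.DeepLabV3ImageSegmenterPreprocessor.from_preset(
    "deeplabv3_plus_resnet50_pascalvoc",
    resize_output_mask=True,
)

x = np.ones((1, 1024, 768, 3))   # image at its native size
y = np.zeros((1, 1024, 768, 1))  # mask, same spatial size as the image
x, y = preprocessor(x, y)
# x is resized by the image converter; y now follows it, resized with
# nearest-neighbor interpolation so integer class ids are not blended.
```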
keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py
CHANGED
@@ -27,7 +27,7 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
     https://arxiv.org/abs/2403.03206).
     """
 
-    def __init__(self, num_train_timesteps=1000, shift=
+    def __init__(self, num_train_timesteps=1000, shift=3.0, **kwargs):
         super().__init__(**kwargs)
         self.num_train_timesteps = int(num_train_timesteps)
         self.shift = float(shift)
@@ -65,6 +65,13 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
         timestep = self._sigma_to_timestep(sigma)
         return sigma, timestep
 
+    def add_noise(self, inputs, noises, step, num_steps):
+        sigma, _ = self(step, num_steps)
+        return ops.add(
+            ops.multiply(sigma, noises),
+            ops.multiply(ops.subtract(1.0, sigma), inputs),
+        )
+
     def get_config(self):
         config = super().get_config()
         config.update(
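The new `add_noise` method is the flow-matching forward process: a straight-line interpolation `x_t = sigma * noise + (1 - sigma) * x_0`, where `sigma` is obtained by calling the scheduler at `(step, num_steps)`. A standalone numpy sketch of the same arithmetic:

```python
import numpy as np

def add_noise(inputs, noises, sigma):
    # Same arithmetic as the method above, with sigma passed in directly.
    return sigma * noises + (1.0 - sigma) * inputs

x0 = np.full((2, 2), 5.0)  # "clean" latents
eps = np.zeros((2, 2))     # a noise sample
print(add_noise(x0, eps, sigma=0.0))  # sigma=0 -> x0 unchanged
print(add_noise(x0, eps, sigma=1.0))  # sigma=1 -> pure noise (all zeros here)
```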