keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +12 -0
- keras_hub/api/models/__init__.py +32 -0
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/rms_normalization.py +34 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
- keras_hub/src/layers/preprocessing/image_converter.py +5 -0
- keras_hub/src/models/albert/albert_presets.py +0 -8
- keras_hub/src/models/bart/bart_presets.py +0 -6
- keras_hub/src/models/bert/bert_presets.py +0 -20
- keras_hub/src/models/bloom/bloom_presets.py +0 -16
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +0 -6
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
- keras_hub/src/models/efficientnet/mbconv.py +52 -21
- keras_hub/src/models/electra/electra_presets.py +0 -12
- keras_hub/src/models/f_net/f_net_presets.py +0 -4
- keras_hub/src/models/falcon/falcon_presets.py +0 -2
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +494 -0
- keras_hub/src/models/flux/flux_maths.py +218 -0
- keras_hub/src/models/flux/flux_model.py +231 -0
- keras_hub/src/models/flux/flux_presets.py +14 -0
- keras_hub/src/models/flux/flux_text_to_image.py +142 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_presets.py +0 -40
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_to_image.py +16 -10
- keras_hub/src/models/inpaint.py +20 -13
- keras_hub/src/models/llama/llama_backbone.py +1 -1
- keras_hub/src/models/llama/llama_presets.py +5 -15
- keras_hub/src/models/llama3/llama3_presets.py +0 -8
- keras_hub/src/models/mistral/mistral_presets.py +0 -6
- keras_hub/src/models/mit/mit_backbone.py +41 -27
- keras_hub/src/models/mit/mit_layers.py +9 -7
- keras_hub/src/models/mit/mit_presets.py +12 -24
- keras_hub/src/models/opt/opt_presets.py +0 -8
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
- keras_hub/src/models/phi3/phi3_presets.py +0 -4
- keras_hub/src/models/resnet/resnet_presets.py +10 -42
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
- keras_hub/src/models/roberta/roberta_presets.py +0 -4
- keras_hub/src/models/sam/sam_backbone.py +0 -1
- keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
- keras_hub/src/models/sam/sam_presets.py +0 -6
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +163 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +124 -0
- keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +41 -13
- keras_hub/src/models/text_to_image.py +13 -5
- keras_hub/src/models/vgg/vgg_backbone.py +1 -1
- keras_hub/src/models/vgg/vgg_presets.py +0 -8
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
- keras_hub/src/models/whisper/whisper_presets.py +0 -20
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
- keras_hub/src/tests/test_case.py +25 -0
- keras_hub/src/utils/preset_utils.py +17 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
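
To confirm which of the two nightly builds above is actually installed in a given environment, the distribution version can be read from the installed wheel metadata with the standard library. A minimal sketch; the distribution name is taken from the wheel filenames at the top of this diff:

```python
from importlib import metadata

# Prints e.g. "0.19.0.dev202412070351" if the newer wheel above is installed.
print(metadata.version("keras-hub-nightly"))
```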
keras_hub/src/models/flux/flux_maths.py
@@ -0,0 +1,218 @@
+import keras
+from keras import ops
+
+
+class TimestepEmbedding(keras.layers.Layer):
+    """
+    Creates sinusoidal timestep embeddings.
+
+    Call arguments:
+        t: KerasTensor of shape (N,), representing N indices, one per batch element.
+            These values may be fractional.
+        dim: int. The dimension of the output.
+        max_period: int, optional. Controls the minimum frequency of the embeddings. Defaults to 10000.
+        time_factor: float, optional. A scaling factor applied to `t`. Defaults to 1000.0.
+
+    Returns:
+        KerasTensor: A tensor of shape (N, D) representing the positional embeddings,
+        where N is the number of batch elements and D is the specified dimension `dim`.
+    """
+
+    def call(self, t, dim, max_period=10000, time_factor=1000.0):
+        t = time_factor * t
+        half_dim = dim // 2
+        freqs = ops.exp(
+            ops.cast(-ops.log(max_period), dtype=t.dtype)
+            * ops.arange(half_dim, dtype=t.dtype)
+            / half_dim
+        )
+        args = t[:, None] * freqs[None]
+        embedding = ops.concatenate([ops.cos(args), ops.sin(args)], axis=-1)
+
+        if dim % 2 != 0:
+            embedding = ops.concatenate(
+                [embedding, ops.zeros_like(embedding[:, :1])], axis=-1
+            )
+
+        return embedding
+
+
+class RotaryPositionalEmbedding(keras.layers.Layer):
+    """
+    Applies Rotary Positional Embedding (RoPE) to the input tensor.
+
+    Call arguments:
+        pos: KerasTensor. The positional tensor with shape (..., n, d).
+        dim: int. The embedding dimension, should be even.
+        theta: int. The base frequency.
+
+    Returns:
+        KerasTensor: The tensor with applied RoPE transformation.
+    """
+
+    def call(self, pos, dim, theta):
+        scale = ops.arange(0, dim, 2, dtype="float32") / dim
+        omega = 1.0 / (theta**scale)
+        out = ops.einsum("...n,d->...nd", pos, omega)
+        out = ops.stack(
+            [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1
+        )
+        out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2))
+        return ops.cast(out, dtype="float32")
+
+
+class ApplyRoPE(keras.layers.Layer):
+    """
+    Applies the RoPE transformation to the query and key tensors.
+
+    Call arguments:
+        xq: KerasTensor. The query tensor of shape (..., L, D).
+        xk: KerasTensor. The key tensor of shape (..., L, D).
+        freqs_cis: KerasTensor. The frequency complex numbers tensor with shape (..., 2).
+
+    Returns:
+        tuple[KerasTensor, KerasTensor]: The transformed query and key tensors.
+    """
+
+    def call(self, xq, xk, freqs_cis):
+        xq_ = ops.reshape(xq, (*ops.shape(xq)[:-1], -1, 1, 2))
+        xk_ = ops.reshape(xk, (*ops.shape(xk)[:-1], -1, 1, 2))
+
+        xq_out = (
+            freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+        )
+        xk_out = (
+            freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+        )
+
+        return ops.reshape(xq_out, ops.shape(xq)), ops.reshape(
+            xk_out, ops.shape(xk)
+        )
+
+
+class FluxRoPEAttention(keras.layers.Layer):
+    """
+    Computes the attention mechanism with the RoPE transformation applied to the query and key tensors.
+
+    Args:
+        dropout_p: float, optional. Dropout probability. Defaults to 0.0.
+        is_causal: bool, optional. If True, applies causal masking. Defaults to False.
+
+    Call arguments:
+        q: KerasTensor. Query tensor of shape (..., L, D).
+        k: KerasTensor. Key tensor of shape (..., S, D).
+        v: KerasTensor. Value tensor of shape (..., S, D).
+        positional_encoding: KerasTensor. Positional encoding tensor.
+
+    Returns:
+        KerasTensor: The resulting tensor from the attention mechanism.
+    """
+
+    def __init__(self, dropout_p=0.0, is_causal=False):
+        super(FluxRoPEAttention, self).__init__()
+        self.dropout_p = dropout_p
+        self.is_causal = is_causal
+
+    def call(self, q, k, v, positional_encoding):
+        # Apply the RoPE transformation
+        q, k = ApplyRoPE()(q, k, positional_encoding)
+
+        # Scaled dot-product attention
+        x = scaled_dot_product_attention(
+            q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal
+        )
+        x = ops.transpose(x, (0, 2, 1, 3))
+        b, l, h, d = ops.shape(x)
+        return ops.reshape(x, (b, l, h * d))
+
+
+# TODO: This is probably already implemented in several places, but is needed to ensure numeric equivalence to the original
+# implementation. It uses torch.functional.scaled_dot_product_attention() - do we have an equivalent already in Keras?
+def scaled_dot_product_attention(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+):
+    """
+    Computes the scaled dot-product attention.
+
+    Args:
+        query: KerasTensor. Query tensor of shape (..., L, D).
+        key: KerasTensor. Key tensor of shape (..., S, D).
+        value: KerasTensor. Value tensor of shape (..., S, D).
+        attn_mask: KerasTensor, optional. Attention mask tensor. Defaults to None.
+        dropout_p: float, optional. Dropout probability. Defaults to 0.0.
+        is_causal: bool, optional. If True, applies causal masking. Defaults to False.
+        scale: float, optional. Scale factor for attention. Defaults to None.
+
+    Returns:
+        KerasTensor: The output tensor from the attention mechanism.
+    """
+    L, S = ops.shape(query)[-2], ops.shape(key)[-2]
+    scale_factor = (
+        1 / ops.sqrt(ops.cast(ops.shape(query)[-1], dtype=query.dtype))
+        if scale is None
+        else scale
+    )
+    attn_bias = ops.zeros((L, S), dtype=query.dtype)
+
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = ops.ones((L, S), dtype=ops.bool)
+        temp_mask = ops.tril(temp_mask, diagonal=0)
+        attn_bias = ops.where(temp_mask, attn_bias, float("-inf"))
+
+    if attn_mask is not None:
+        if ops.shape(attn_mask)[-1] == 1:  # If the mask is 3D
+            attn_bias += attn_mask
+        else:
+            attn_bias = ops.where(attn_mask, attn_bias, float("-inf"))
+
+    # Compute attention weights
+    attn_weight = (
+        ops.matmul(query, ops.transpose(key, axes=[0, 1, 3, 2])) * scale_factor
+    )
+    attn_weight += attn_bias
+    attn_weight = keras.activations.softmax(attn_weight, axis=-1)
+
+    if dropout_p > 0.0:
+        attn_weight = keras.layers.Dropout(dropout_p)(
+            attn_weight, training=True
+        )
+
+    return ops.matmul(attn_weight, value)
+
+
+def rearrange_symbolic_tensors(qkv, K, H):
+    """
+    Splits the qkv tensor into query (q), key (k), and value (v) components.
+
+    Mimics rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=num_heads),
+    for graph-mode TensorFlow support when doing functional subclassing
+    models.
+
+    Arguments:
+        qkv: np.ndarray. Input tensor of shape (B, L, K*H*D).
+        K: int. Number of components (q, k, v).
+        H: int. Number of attention heads.
+
+    Returns:
+        tuple: q, k, v tensors of shape (B, H, L, D).
+    """
+    # Get the shape of qkv and calculate L and D
+    B, L, dim = ops.shape(qkv)
+    D = dim // (K * H)
+
+    # Reshape and transpose the qkv tensor
+    qkv_reshaped = ops.reshape(qkv, (B, L, K, H, D))
+    qkv_transposed = ops.transpose(qkv_reshaped, (2, 0, 3, 1, 4))
+
+    # Split q, k, v along the first dimension (K)
+    qkv_splits = ops.split(qkv_transposed, K, axis=0)
+    q, k, v = [ops.squeeze(split, 0) for split in qkv_splits]
+
+    return q, k, v
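
The `TimestepEmbedding` layer in the hunk above maps a batch of (possibly fractional) timesteps to an (N, dim) sinusoidal embedding. Below is a minimal sketch that restates the same frequency math directly against `keras.ops` (rather than importing the layer) just to illustrate the expected output shape; the input values are purely illustrative:

```python
import numpy as np
from keras import ops

# Restates the math from TimestepEmbedding.call above with example inputs.
t = ops.convert_to_tensor(np.array([0.0, 1.0, 2.0, 3.0], dtype="float32"))
dim, max_period, time_factor = 256, 10000.0, 1000.0

t = time_factor * t
half_dim = dim // 2
freqs = ops.exp(
    -ops.log(max_period) * ops.arange(half_dim, dtype="float32") / half_dim
)
args = ops.expand_dims(t, -1) * ops.expand_dims(freqs, 0)
embedding = ops.concatenate([ops.cos(args), ops.sin(args)], axis=-1)
print(ops.shape(embedding))  # (4, 256): one embedding per timestep
```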
keras_hub/src/models/flux/flux_model.py
@@ -0,0 +1,231 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.flux.flux_layers import DoubleStreamBlock
+from keras_hub.src.models.flux.flux_layers import EmbedND
+from keras_hub.src.models.flux.flux_layers import LastLayer
+from keras_hub.src.models.flux.flux_layers import MLPEmbedder
+from keras_hub.src.models.flux.flux_layers import SingleStreamBlock
+from keras_hub.src.models.flux.flux_maths import TimestepEmbedding
+
+
+@keras_hub_export("keras_hub.models.FluxBackbone")
+class FluxBackbone(Backbone):
+    """
+    Transformer model for flow matching on sequences.
+
+    The model processes image and text data with associated positional and timestep
+    embeddings, and optionally applies guidance embedding. Double-stream blocks
+    handle separate image and text streams, while single-stream blocks combine
+    these streams. Ported from: https://github.com/black-forest-labs/flux
+
+    Args:
+        input_channels: int. The number of input channels.
+        hidden_size: int. The hidden size of the transformer, must be divisible by `num_heads`.
+        mlp_ratio: float. The ratio of the MLP dimension to the hidden size.
+        num_heads: int. The number of attention heads.
+        depth: int. The number of double-stream blocks.
+        depth_single_blocks: int. The number of single-stream blocks.
+        axes_dim: list[int]. A list of dimensions for the positional embedding axes.
+        theta: int. The base frequency for positional embeddings.
+        use_bias: bool. Whether to apply bias to the query, key, and value projections.
+        guidance_embed: bool. If True, applies guidance embedding in the model.
+
+    Call arguments:
+        image: KerasTensor. Image input tensor of shape (N, L, D) where N is the batch size,
+            L is the sequence length, and D is the feature dimension.
+        image_ids: KerasTensor. Image ID input tensor of shape (N, L, D) corresponding
+            to the image sequences.
+        text: KerasTensor. Text input tensor of shape (N, L, D).
+        text_ids: KerasTensor. Text ID input tensor of shape (N, L, D) corresponding
+            to the text sequences.
+        timesteps: KerasTensor. Timestep tensor used to compute positional embeddings.
+        y: KerasTensor. Additional vector input, such as target values.
+        guidance: KerasTensor, optional. Guidance input tensor used
+            in guidance-embedded models.
+    Raises:
+        ValueError: If `hidden_size` is not divisible by `num_heads`, or if `sum(axes_dim)` is not equal to the
+            positional embedding dimension.
+    """
+
+    def __init__(
+        self,
+        input_channels,
+        hidden_size,
+        mlp_ratio,
+        num_heads,
+        depth,
+        depth_single_blocks,
+        axes_dim,
+        theta,
+        use_bias,
+        guidance_embed=False,
+        # These will be inferred from the CLIP/T5 encoders later
+        image_shape=(None, 768, 3072),
+        text_shape=(None, 768, 3072),
+        image_ids_shape=(None, 768, 3072),
+        text_ids_shape=(None, 768, 3072),
+        y_shape=(None, 128),
+        **kwargs,
+    ):
+
+        # === Layers ===
+        self.positional_embedder = EmbedND(theta=theta, axes_dim=axes_dim)
+        self.image_input_embedder = keras.layers.Dense(
+            hidden_size, use_bias=True
+        )
+        self.time_input_embedder = MLPEmbedder(hidden_dim=hidden_size)
+        self.vector_embedder = MLPEmbedder(hidden_dim=hidden_size)
+        self.guidance_input_embedder = (
+            MLPEmbedder(hidden_dim=hidden_size)
+            if guidance_embed
+            else keras.layers.Identity()
+        )
+        self.text_input_embedder = keras.layers.Dense(hidden_size)
+
+        self.double_blocks = [
+            DoubleStreamBlock(
+                hidden_size,
+                num_heads,
+                mlp_ratio=mlp_ratio,
+                use_bias=use_bias,
+            )
+            for _ in range(depth)
+        ]
+
+        self.single_blocks = [
+            SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+            for _ in range(depth_single_blocks)
+        ]
+
+        self.final_layer = LastLayer(hidden_size, 1, input_channels)
+        self.timestep_embedding = TimestepEmbedding()
+        self.guidance_embed = guidance_embed
+
+        # === Functional Model ===
+        image_input = keras.Input(shape=image_shape, name="image")
+        image_ids = keras.Input(shape=image_ids_shape, name="image_ids")
+        text_input = keras.Input(shape=text_shape, name="text")
+        text_ids = keras.Input(shape=text_ids_shape, name="text_ids")
+        y = keras.Input(shape=y_shape, name="y")
+        timesteps_input = keras.Input(shape=(), name="timesteps")
+        guidance_input = keras.Input(shape=(), name="guidance")
+
+        # running on sequences image
+        image = self.image_input_embedder(image_input)
+        modulation_encoding = self.time_input_embedder(
+            self.timestep_embedding(timesteps_input, dim=256)
+        )
+        if self.guidance_embed:
+            if guidance_input is None:
+                raise ValueError(
+                    "Didn't get guidance strength for guidance distilled model."
+                )
+            modulation_encoding = (
+                modulation_encoding
+                + self.guidance_input_embedder(
+                    self.timestep_embedding(guidance_input, dim=256)
+                )
+            )
+
+        modulation_encoding = modulation_encoding + self.vector_embedder(y)
+        text = self.text_input_embedder(text_input)
+
+        ids = keras.ops.concatenate((text_ids, image_ids), axis=1)
+        positional_encoding = self.positional_embedder(ids)
+
+        for block in self.double_blocks:
+            image, text = block(
+                image=image,
+                text=text,
+                modulation_encoding=modulation_encoding,
+                positional_encoding=positional_encoding,
+            )
+
+        image = keras.ops.concatenate((text, image), axis=1)
+        for block in self.single_blocks:
+            image = block(
+                image,
+                modulation_encoding=modulation_encoding,
+                positional_encoding=positional_encoding,
+            )
+        image = image[:, text.shape[1] :, ...]
+
+        image = self.final_layer(
+            image, modulation_encoding
+        )  # (N, T, patch_size ** 2 * output_channels)
+
+        super().__init__(
+            inputs={
+                "image": image_input,
+                "image_ids": image_ids,
+                "text": text_input,
+                "text_ids": text_ids,
+                "y": y,
+                "timesteps": timesteps_input,
+                "guidance": guidance_input,
+            },
+            outputs=image,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.input_channels = input_channels
+        self.output_channels = self.input_channels
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.image_shape = image_shape
+        self.text_shape = text_shape
+        self.image_ids_shape = image_ids_shape
+        self.text_ids_shape = text_ids_shape
+        self.y_shape = y_shape
+        self.mlp_ratio = mlp_ratio
+        self.depth = depth
+        self.depth_single_blocks = depth_single_blocks
+        self.axes_dim = axes_dim
+        self.theta = theta
+        self.use_bias = use_bias
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "input_channels": self.input_channels,
+                "hidden_size": self.hidden_size,
+                "mlp_ratio": self.mlp_ratio,
+                "num_heads": self.num_heads,
+                "depth": self.depth,
+                "depth_single_blocks": self.depth_single_blocks,
+                "axes_dim": self.axes_dim,
+                "theta": self.theta,
+                "use_bias": self.use_bias,
+                "guidance_embed": self.guidance_embed,
+                "image_shape": self.image_shape,
+                "text_shape": self.text_shape,
+                "image_ids_shape": self.image_ids_shape,
+                "text_ids_shape": self.text_ids_shape,
+                "y_shape": self.y_shape,
+            }
+        )
+        return config
+
+    def encode_text_step(self, token_ids, negative_token_ids):
+        raise NotImplementedError("Not implemented yet")
+
+        def encode(token_ids):
+            raise NotImplementedError("Not implemented yet")
+
+    def encode_image_step(self, images):
+        raise NotImplementedError("Not implemented yet")
+
+    def add_noise_step(self, latents, noises, step, num_steps):
+        raise NotImplementedError("Not implemented yet")
+
+    def denoise_step(
+        self,
+    ):
+        raise NotImplementedError("Not implemented yet")
+
+    def decode_step(self, latents):
+        raise NotImplementedError("Not implemented yet")
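
The `Raises` section of the docstring above ties `hidden_size`, `num_heads`, and `axes_dim` together: in the upstream Flux implementation the positional-embedding dimension is the per-head width `hidden_size // num_heads`, and `sum(axes_dim)` has to match it. A small sketch of that arithmetic; the concrete hyperparameter values below are assumptions taken from the upstream black-forest-labs/flux configuration, not read from this diff:

```python
# Hypothetical hyperparameters, used only to illustrate the documented checks.
hidden_size = 3072
num_heads = 24
axes_dim = [16, 56, 56]

assert hidden_size % num_heads == 0, "hidden_size must divide evenly by num_heads"
pe_dim = hidden_size // num_heads  # per-head width consumed by the RoPE embedder
assert sum(axes_dim) == pe_dim, "sum(axes_dim) must equal the positional dim"
print(pe_dim)  # 128
```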
keras_hub/src/models/flux/flux_presets.py
@@ -0,0 +1,14 @@
+"""FLUX model preset configurations."""
+
+presets = {
+    "schnell": {
+        "metadata": {
+            "description": (
+                "A 12 billion parameter rectified flow transformer capable of generating images from text descriptions."
+            ),
+            "params": 124439808,
+            "path": "flux",
+        },
+        "kaggle_handle": "TBA",
+    },
+}
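
Because the entry above registers its `kaggle_handle` as `"TBA"`, it cannot be downloaded through `from_preset()` yet, but the preset table itself is a plain dictionary and can be inspected directly. A minimal sketch, assuming the newer nightly wheel from this diff is installed; the import path follows the file list at the top:

```python
from keras_hub.src.models.flux.flux_presets import presets

for name, preset in presets.items():
    print(name, preset["metadata"]["params"], preset["kaggle_handle"])
# schnell 124439808 TBA
```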
keras_hub/src/models/flux/flux_text_to_image.py
@@ -0,0 +1,142 @@
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
+    FluxTextToImagePreprocessor,
+)
+from keras_hub.src.models.text_to_image import TextToImage
+
+
+@keras_hub_export("keras_hub.models.FluxTextToImage")
+class FluxTextToImage(TextToImage):
+    """An end-to-end Flux model for text-to-image generation.
+
+    This model has a `generate()` method, which generates image based on a
+    prompt.
+
+    Args:
+        backbone: A `keras_hub.models.FluxBackbone` instance.
+        preprocessor: A
+            `keras_hub.models.FluxTextToImagePreprocessor` instance.
+
+    Examples:
+
+    Use `generate()` to do image generation.
+    ```python
+    text_to_image = keras_hub.models.FluxTextToImage.from_preset(
+        "TBA", height=512, width=512
+    )
+    text_to_image.generate(
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+    )
+
+    # Generate with batched prompts.
+    text_to_image.generate(
+        ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+    )
+
+    # Generate with different `num_steps` and `guidance_scale`.
+    text_to_image.generate(
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+        num_steps=50,
+        guidance_scale=5.0,
+    )
+
+    # Generate with `negative_prompts`.
+    text_to_image.generate(
+        {
+            "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+            "negative_prompts": "green color",
+        }
+    )
+    ```
+    """
+
+    backbone_cls = FluxBackbone
+    preprocessor_cls = FluxTextToImagePreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        outputs = backbone.output
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Currently, `fit` is not supported for " "`FluxTextToImage`."
+        )
+
+    def generate_step(
+        self,
+        latents,
+        token_ids,
+        num_steps,
+        guidance_scale,
+    ):
+        """A compilable generation function for batched of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for batched inputs.
+
+        Args:
+            latents: A (batch_size, height, width, channels) tensor
+                containing the latents to start generation from. Typically, this
+                tensor is sampled from the Gaussian distribution.
+            token_ids: A pair of (batch_size, num_tokens) tensor containing the
+                tokens based on the input prompts and negative prompts.
+            num_steps: int. The number of diffusion steps to take.
+            guidance_scale: float. The classifier free guidance scale defined in
+                [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). Higher scale encourages to
+                generate images that are closely linked to prompts, usually at
+                the expense of lower image quality.
+        """
+        token_ids, negative_token_ids = token_ids
+
+        # Encode prompts.
+        embeddings = self.backbone.encode_text_step(
+            token_ids, negative_token_ids
+        )
+
+        # Denoise.
+        def body_fun(step, latents):
+            return self.backbone.denoise_step(
+                latents,
+                embeddings,
+                step,
+                num_steps,
+                guidance_scale,
+            )
+
+        latents = ops.fori_loop(0, num_steps, body_fun, latents)
+
+        # Decode.
+        return self.backbone.decode_step(latents)
+
+    def generate(
+        self,
+        inputs,
+        num_steps=28,
+        guidance_scale=7.0,
+        seed=None,
+    ):
+        return super().generate(
+            inputs,
+            num_steps=num_steps,
+            guidance_scale=guidance_scale,
+            seed=seed,
+        )
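
`generate_step()` above drives the denoising loop with `keras.ops.fori_loop`, which applies `body_fun` a fixed number of times while threading the carried latents through as loop state. A standalone sketch of that looping pattern with a trivial body, just to show how the state is carried; the stand-in body replaces the backbone's `denoise_step` call:

```python
from keras import ops

def body_fun(step, latents):
    # Stand-in for backbone.denoise_step: simply add 1 on every iteration.
    return latents + 1.0

latents = ops.zeros((2, 3))
latents = ops.fori_loop(0, 4, body_fun, latents)
print(ops.convert_to_numpy(latents))  # every entry equals 4.0 after 4 steps
```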
keras_hub/src/models/flux/flux_text_to_image_preprocessor.py
@@ -0,0 +1,73 @@
+import keras
+from keras import layers
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.flux.flux_model import FluxBackbone
+from keras_hub.src.models.preprocessor import Preprocessor
+
+
+@keras_hub_export("keras_hub.models.FluxTextToImagePreprocessor")
+class FluxTextToImagePreprocessor(Preprocessor):
+    """Flux text-to-image model preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_hub.models.FluxTextToImagePreprocessor`.
+
+    For use with generation, the layer exposes one methods
+    `generate_preprocess()`.
+
+    Args:
+        clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
+        t5_preprocessor: A optional `keras_hub.models.T5Preprocessor` instance.
+    """
+
+    backbone_cls = FluxBackbone
+
+    def __init__(
+        self,
+        clip_l_preprocessor,
+        t5_preprocessor=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.clip_l_preprocessor = clip_l_preprocessor
+        self.t5_preprocessor = t5_preprocessor
+
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self.clip_l_preprocessor.sequence_length
+
+    def build(self, input_shape):
+        self.built = True
+
+    def generate_preprocess(self, x):
+        token_ids = {}
+        token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
+        if self.t5_preprocessor is not None:
+            token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
+        return token_ids
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "clip_l_preprocessor": layers.serialize(
+                    self.clip_l_preprocessor
+                ),
+                "t5_preprocessor": layers.serialize(self.t5_preprocessor),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        for layer_name in (
+            "clip_l_preprocessor",
+            "t5_preprocessor",
+        ):
+            if layer_name in config and isinstance(config[layer_name], dict):
+                config[layer_name] = keras.layers.deserialize(
+                    config[layer_name]
+                )
+        return cls(**config)
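
`get_config()`/`from_config()` above round-trip the wrapped preprocessors through `keras.layers.serialize` and `keras.layers.deserialize`. A minimal sketch of that round-trip using a plain `Dense` layer as a stand-in for the CLIP/T5 preprocessors (which would otherwise require downloading real tokenizer assets); the stand-in name is hypothetical:

```python
import keras
from keras import layers

# Stand-in sub-layer; a real FluxTextToImagePreprocessor would hold a
# CLIPPreprocessor here instead.
stand_in = layers.Dense(8, name="clip_l_stand_in")
config = {"clip_l_preprocessor": layers.serialize(stand_in)}

# Mirrors the from_config() logic above: deserialize dict-valued entries.
if isinstance(config["clip_l_preprocessor"], dict):
    config["clip_l_preprocessor"] = keras.layers.deserialize(
        config["clip_l_preprocessor"]
    )
print(type(config["clip_l_preprocessor"]).__name__)  # Dense
```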