keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,218 @@
+ import keras
+ from keras import ops
+
+
+ class TimestepEmbedding(keras.layers.Layer):
+     """
+     Creates sinusoidal timestep embeddings.
+
+     Call arguments:
+         t: KerasTensor of shape (N,), representing N indices, one per batch element.
+             These values may be fractional.
+         dim: int. The dimension of the output.
+         max_period: int, optional. Controls the minimum frequency of the embeddings. Defaults to 10000.
+         time_factor: float, optional. A scaling factor applied to `t`. Defaults to 1000.0.
+
+     Returns:
+         KerasTensor: A tensor of shape (N, D) representing the positional embeddings,
+             where N is the number of batch elements and D is the specified dimension `dim`.
+     """
+
+     def call(self, t, dim, max_period=10000, time_factor=1000.0):
+         t = time_factor * t
+         half_dim = dim // 2
+         freqs = ops.exp(
+             ops.cast(-ops.log(max_period), dtype=t.dtype)
+             * ops.arange(half_dim, dtype=t.dtype)
+             / half_dim
+         )
+         args = t[:, None] * freqs[None]
+         embedding = ops.concatenate([ops.cos(args), ops.sin(args)], axis=-1)
+
+         if dim % 2 != 0:
+             embedding = ops.concatenate(
+                 [embedding, ops.zeros_like(embedding[:, :1])], axis=-1
+             )
+
+         return embedding
+
+
+ class RotaryPositionalEmbedding(keras.layers.Layer):
+     """
+     Applies Rotary Positional Embedding (RoPE) to the input tensor.
+
+     Call arguments:
+         pos: KerasTensor. The positional tensor with shape (..., n, d).
+         dim: int. The embedding dimension, should be even.
+         theta: int. The base frequency.
+
+     Returns:
+         KerasTensor: The tensor with applied RoPE transformation.
+     """
+
+     def call(self, pos, dim, theta):
+         scale = ops.arange(0, dim, 2, dtype="float32") / dim
+         omega = 1.0 / (theta**scale)
+         out = ops.einsum("...n,d->...nd", pos, omega)
+         out = ops.stack(
+             [ops.cos(out), -ops.sin(out), ops.sin(out), ops.cos(out)], axis=-1
+         )
+         out = ops.reshape(out, ops.shape(out)[:-1] + (2, 2))
+         return ops.cast(out, dtype="float32")
+
+
+ class ApplyRoPE(keras.layers.Layer):
+     """
+     Applies the RoPE transformation to the query and key tensors.
+
+     Call arguments:
+         xq: KerasTensor. The query tensor of shape (..., L, D).
+         xk: KerasTensor. The key tensor of shape (..., L, D).
+         freqs_cis: KerasTensor. The frequency complex numbers tensor with shape (..., 2).
+
+     Returns:
+         tuple[KerasTensor, KerasTensor]: The transformed query and key tensors.
+     """
+
+     def call(self, xq, xk, freqs_cis):
+         xq_ = ops.reshape(xq, (*ops.shape(xq)[:-1], -1, 1, 2))
+         xk_ = ops.reshape(xk, (*ops.shape(xk)[:-1], -1, 1, 2))
+
+         xq_out = (
+             freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+         )
+         xk_out = (
+             freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+         )
+
+         return ops.reshape(xq_out, ops.shape(xq)), ops.reshape(
+             xk_out, ops.shape(xk)
+         )
+
+
+ class FluxRoPEAttention(keras.layers.Layer):
+     """
+     Computes the attention mechanism with the RoPE transformation applied to the query and key tensors.
+
+     Args:
+         dropout_p: float, optional. Dropout probability. Defaults to 0.0.
+         is_causal: bool, optional. If True, applies causal masking. Defaults to False.
+
+     Call arguments:
+         q: KerasTensor. Query tensor of shape (..., L, D).
+         k: KerasTensor. Key tensor of shape (..., S, D).
+         v: KerasTensor. Value tensor of shape (..., S, D).
+         positional_encoding: KerasTensor. Positional encoding tensor.
+
+     Returns:
+         KerasTensor: The resulting tensor from the attention mechanism.
+     """
+
+     def __init__(self, dropout_p=0.0, is_causal=False, **kwargs):
+         super().__init__(**kwargs)
+         self.dropout_p = dropout_p
+         self.is_causal = is_causal
+
+     def call(self, q, k, v, positional_encoding):
+         # Apply the RoPE transformation
+         q, k = ApplyRoPE()(q, k, positional_encoding)
+
+         # Scaled dot-product attention
+         x = scaled_dot_product_attention(
+             q, k, v, dropout_p=self.dropout_p, is_causal=self.is_causal
+         )
+         x = ops.transpose(x, (0, 2, 1, 3))
+         b, l, h, d = ops.shape(x)
+         return ops.reshape(x, (b, l, h * d))
+
+
+ # TODO: This is probably already implemented in several places, but is needed to ensure numeric equivalence to the
+ # original implementation. It uses torch.nn.functional.scaled_dot_product_attention() - do we have an equivalent already in Keras?
+ def scaled_dot_product_attention(
+     query,
+     key,
+     value,
+     attn_mask=None,
+     dropout_p=0.0,
+     is_causal=False,
+     scale=None,
+ ):
+     """
+     Computes the scaled dot-product attention.
+
+     Args:
+         query: KerasTensor. Query tensor of shape (..., L, D).
+         key: KerasTensor. Key tensor of shape (..., S, D).
+         value: KerasTensor. Value tensor of shape (..., S, D).
+         attn_mask: KerasTensor, optional. Attention mask tensor. Defaults to None.
+         dropout_p: float, optional. Dropout probability. Defaults to 0.0.
+         is_causal: bool, optional. If True, applies causal masking. Defaults to False.
+         scale: float, optional. Scale factor for attention. Defaults to None.
+
+     Returns:
+         KerasTensor: The output tensor from the attention mechanism.
+     """
+     L, S = ops.shape(query)[-2], ops.shape(key)[-2]
+     scale_factor = (
+         1 / ops.sqrt(ops.cast(ops.shape(query)[-1], dtype=query.dtype))
+         if scale is None
+         else scale
+     )
+     attn_bias = ops.zeros((L, S), dtype=query.dtype)
+
+     if is_causal:
+         assert attn_mask is None
+         temp_mask = ops.ones((L, S), dtype="bool")
+         temp_mask = ops.tril(temp_mask, diagonal=0)
+         attn_bias = ops.where(temp_mask, attn_bias, float("-inf"))
+
+     if attn_mask is not None:
+         if ops.shape(attn_mask)[-1] == 1:  # If the mask is 3D
+             attn_bias += attn_mask
+         else:
+             attn_bias = ops.where(attn_mask, attn_bias, float("-inf"))
+
+     # Compute attention weights
+     attn_weight = (
+         ops.matmul(query, ops.transpose(key, axes=[0, 1, 3, 2])) * scale_factor
+     )
+     attn_weight += attn_bias
+     attn_weight = keras.activations.softmax(attn_weight, axis=-1)
+
+     if dropout_p > 0.0:
+         attn_weight = keras.layers.Dropout(dropout_p)(
+             attn_weight, training=True
+         )
+
+     return ops.matmul(attn_weight, value)
+
+
+ def rearrange_symbolic_tensors(qkv, K, H):
+     """
+     Splits the qkv tensor into query (q), key (k), and value (v) components.
+
+     Mimics rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=num_heads),
+     for graph-mode TensorFlow support when doing functional subclassing
+     models.
+
+     Args:
+         qkv: KerasTensor. Input tensor of shape (B, L, K*H*D).
+         K: int. Number of components (q, k, v).
+         H: int. Number of attention heads.
+
+     Returns:
+         tuple: q, k, v tensors of shape (B, H, L, D).
+     """
+     # Get the shape of qkv and calculate L and D
+     B, L, dim = ops.shape(qkv)
+     D = dim // (K * H)
+
+     # Reshape and transpose the qkv tensor
+     qkv_reshaped = ops.reshape(qkv, (B, L, K, H, D))
+     qkv_transposed = ops.transpose(qkv_reshaped, (2, 0, 3, 1, 4))
+
+     # Split q, k, v along the first dimension (K)
+     qkv_splits = ops.split(qkv_transposed, K, axis=0)
+     q, k, v = [ops.squeeze(split, 0) for split in qkv_splits]
+
+     return q, k, v
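
The helpers in `flux_maths.py` are stateless: `TimestepEmbedding` maps a batch of scalar timesteps to `(N, dim)` sinusoidal features, and `rearrange_symbolic_tensors` splits a fused qkv projection into per-head tensors without einops. A minimal sketch of how they compose, assuming the module above is importable; the shapes are illustrative, not taken from any Flux checkpoint:

```python
import keras
from keras import ops

from keras_hub.src.models.flux.flux_maths import TimestepEmbedding
from keras_hub.src.models.flux.flux_maths import rearrange_symbolic_tensors

# (N,) fractional timesteps -> (N, 256) cos/sin features over
# geometrically spaced frequencies, after scaling t by time_factor.
t = ops.convert_to_tensor([0.0, 0.5, 1.0])
emb = TimestepEmbedding()(t, dim=256)  # shape (3, 256)

# Fused (B, L, K*H*D) projection -> q, k, v of shape (B, H, L, D).
qkv = keras.random.normal((2, 16, 3 * 8 * 64))  # B=2, L=16, H=8, D=64
q, k, v = rearrange_symbolic_tensors(qkv, K=3, H=8)  # each (2, 8, 16, 64)
```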
@@ -0,0 +1,231 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.flux.flux_layers import DoubleStreamBlock
+ from keras_hub.src.models.flux.flux_layers import EmbedND
+ from keras_hub.src.models.flux.flux_layers import LastLayer
+ from keras_hub.src.models.flux.flux_layers import MLPEmbedder
+ from keras_hub.src.models.flux.flux_layers import SingleStreamBlock
+ from keras_hub.src.models.flux.flux_maths import TimestepEmbedding
+
+
+ @keras_hub_export("keras_hub.models.FluxBackbone")
+ class FluxBackbone(Backbone):
+     """
+     Transformer model for flow matching on sequences.
+
+     The model processes image and text data with associated positional and timestep
+     embeddings, and optionally applies guidance embedding. Double-stream blocks
+     handle separate image and text streams, while single-stream blocks combine
+     these streams. Ported from: https://github.com/black-forest-labs/flux
+
+     Args:
+         input_channels: int. The number of input channels.
+         hidden_size: int. The hidden size of the transformer, must be divisible by `num_heads`.
+         mlp_ratio: float. The ratio of the MLP dimension to the hidden size.
+         num_heads: int. The number of attention heads.
+         depth: int. The number of double-stream blocks.
+         depth_single_blocks: int. The number of single-stream blocks.
+         axes_dim: list[int]. A list of dimensions for the positional embedding axes.
+         theta: int. The base frequency for positional embeddings.
+         use_bias: bool. Whether to apply bias to the query, key, and value projections.
+         guidance_embed: bool. If True, applies guidance embedding in the model.
+
+     Call arguments:
+         image: KerasTensor. Image input tensor of shape (N, L, D) where N is the batch size,
+             L is the sequence length, and D is the feature dimension.
+         image_ids: KerasTensor. Image ID input tensor of shape (N, L, D) corresponding
+             to the image sequences.
+         text: KerasTensor. Text input tensor of shape (N, L, D).
+         text_ids: KerasTensor. Text ID input tensor of shape (N, L, D) corresponding
+             to the text sequences.
+         timesteps: KerasTensor. Timestep tensor used to compute positional embeddings.
+         y: KerasTensor. Additional vector input, such as target values.
+         guidance: KerasTensor, optional. Guidance input tensor used
+             in guidance-embedded models.
+     Raises:
+         ValueError: If `hidden_size` is not divisible by `num_heads`, or if `sum(axes_dim)` is not equal to the
+             positional embedding dimension.
+     """
+
+     def __init__(
+         self,
+         input_channels,
+         hidden_size,
+         mlp_ratio,
+         num_heads,
+         depth,
+         depth_single_blocks,
+         axes_dim,
+         theta,
+         use_bias,
+         guidance_embed=False,
+         # These will be inferred from the CLIP/T5 encoders later
+         image_shape=(None, 768, 3072),
+         text_shape=(None, 768, 3072),
+         image_ids_shape=(None, 768, 3072),
+         text_ids_shape=(None, 768, 3072),
+         y_shape=(None, 128),
+         **kwargs,
+     ):
+
+         # === Layers ===
+         self.positional_embedder = EmbedND(theta=theta, axes_dim=axes_dim)
+         self.image_input_embedder = keras.layers.Dense(
+             hidden_size, use_bias=True
+         )
+         self.time_input_embedder = MLPEmbedder(hidden_dim=hidden_size)
+         self.vector_embedder = MLPEmbedder(hidden_dim=hidden_size)
+         self.guidance_input_embedder = (
+             MLPEmbedder(hidden_dim=hidden_size)
+             if guidance_embed
+             else keras.layers.Identity()
+         )
+         self.text_input_embedder = keras.layers.Dense(hidden_size)
+
+         self.double_blocks = [
+             DoubleStreamBlock(
+                 hidden_size,
+                 num_heads,
+                 mlp_ratio=mlp_ratio,
+                 use_bias=use_bias,
+             )
+             for _ in range(depth)
+         ]
+
+         self.single_blocks = [
+             SingleStreamBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+             for _ in range(depth_single_blocks)
+         ]
+
+         self.final_layer = LastLayer(hidden_size, 1, input_channels)
+         self.timestep_embedding = TimestepEmbedding()
+         self.guidance_embed = guidance_embed
+
+         # === Functional Model ===
+         image_input = keras.Input(shape=image_shape, name="image")
+         image_ids = keras.Input(shape=image_ids_shape, name="image_ids")
+         text_input = keras.Input(shape=text_shape, name="text")
+         text_ids = keras.Input(shape=text_ids_shape, name="text_ids")
+         y = keras.Input(shape=y_shape, name="y")
+         timesteps_input = keras.Input(shape=(), name="timesteps")
+         guidance_input = keras.Input(shape=(), name="guidance")
+
+         # Run on the image sequence.
+         image = self.image_input_embedder(image_input)
+         modulation_encoding = self.time_input_embedder(
+             self.timestep_embedding(timesteps_input, dim=256)
+         )
+         if self.guidance_embed:
+             if guidance_input is None:
+                 raise ValueError(
+                     "Didn't get guidance strength for guidance distilled model."
+                 )
+             modulation_encoding = (
+                 modulation_encoding
+                 + self.guidance_input_embedder(
+                     self.timestep_embedding(guidance_input, dim=256)
+                 )
+             )
+
+         modulation_encoding = modulation_encoding + self.vector_embedder(y)
+         text = self.text_input_embedder(text_input)
+
+         ids = keras.ops.concatenate((text_ids, image_ids), axis=1)
+         positional_encoding = self.positional_embedder(ids)
+
+         for block in self.double_blocks:
+             image, text = block(
+                 image=image,
+                 text=text,
+                 modulation_encoding=modulation_encoding,
+                 positional_encoding=positional_encoding,
+             )
+
+         image = keras.ops.concatenate((text, image), axis=1)
+         for block in self.single_blocks:
+             image = block(
+                 image,
+                 modulation_encoding=modulation_encoding,
+                 positional_encoding=positional_encoding,
+             )
+         image = image[:, text.shape[1] :, ...]
+
+         image = self.final_layer(
+             image, modulation_encoding
+         )  # (N, T, patch_size ** 2 * output_channels)
+
+         super().__init__(
+             inputs={
+                 "image": image_input,
+                 "image_ids": image_ids,
+                 "text": text_input,
+                 "text_ids": text_ids,
+                 "y": y,
+                 "timesteps": timesteps_input,
+                 "guidance": guidance_input,
+             },
+             outputs=image,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.input_channels = input_channels
+         self.output_channels = self.input_channels
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         self.image_shape = image_shape
+         self.text_shape = text_shape
+         self.image_ids_shape = image_ids_shape
+         self.text_ids_shape = text_ids_shape
+         self.y_shape = y_shape
+         self.mlp_ratio = mlp_ratio
+         self.depth = depth
+         self.depth_single_blocks = depth_single_blocks
+         self.axes_dim = axes_dim
+         self.theta = theta
+         self.use_bias = use_bias
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "input_channels": self.input_channels,
+                 "hidden_size": self.hidden_size,
+                 "mlp_ratio": self.mlp_ratio,
+                 "num_heads": self.num_heads,
+                 "depth": self.depth,
+                 "depth_single_blocks": self.depth_single_blocks,
+                 "axes_dim": self.axes_dim,
+                 "theta": self.theta,
+                 "use_bias": self.use_bias,
+                 "guidance_embed": self.guidance_embed,
+                 "image_shape": self.image_shape,
+                 "text_shape": self.text_shape,
+                 "image_ids_shape": self.image_ids_shape,
+                 "text_ids_shape": self.text_ids_shape,
+                 "y_shape": self.y_shape,
+             }
+         )
+         return config
+
+     def encode_text_step(self, token_ids, negative_token_ids):
+         raise NotImplementedError("Not implemented yet")
+
+     def encode(self, token_ids):
+         raise NotImplementedError("Not implemented yet")
+
+     def encode_image_step(self, images):
+         raise NotImplementedError("Not implemented yet")
+
+     def add_noise_step(self, latents, noises, step, num_steps):
+         raise NotImplementedError("Not implemented yet")
+
+     def denoise_step(
+         self,
+     ):
+         raise NotImplementedError("Not implemented yet")
+
+     def decode_step(self, latents):
+         raise NotImplementedError("Not implemented yet")
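
Since the backbone wires up its functional graph in `__init__`, a toy instantiation shows the contract between the arguments. The numbers below are illustrative only, far smaller than any real Flux checkpoint, and assume, per the docstring, that `hidden_size` is divisible by `num_heads` and that `sum(axes_dim)` matches the per-head positional embedding dimension:

```python
# Hypothetical small config; real presets will ship their own sizes.
backbone = FluxBackbone(
    input_channels=64,
    hidden_size=128,       # divisible by num_heads
    mlp_ratio=4.0,
    num_heads=4,
    depth=2,               # double-stream blocks
    depth_single_blocks=2,
    axes_dim=[8, 12, 12],  # sums to hidden_size // num_heads == 32
    theta=10_000,
    use_bias=True,
    guidance_embed=False,
)
```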
@@ -0,0 +1,14 @@
+ """FLUX model preset configurations."""
+
+ presets = {
+     "schnell": {
+         "metadata": {
+             "description": (
+                 "A 12 billion parameter rectified flow transformer capable of generating images from text descriptions."
+             ),
+             "params": 124439808,
+             "path": "flux",
+         },
+         "kaggle_handle": "TBA",
+     },
+ }
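
The `kaggle_handle` is still `"TBA"`, so this preset cannot resolve weights yet; once a handle lands, the usual keras-hub flow would presumably be:

```python
import keras_hub

# Hypothetical until the Kaggle handle is published for this preset.
backbone = keras_hub.models.FluxBackbone.from_preset("schnell")
```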
@@ -0,0 +1,142 @@
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.flux.flux_model import FluxBackbone
+ from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
+     FluxTextToImagePreprocessor,
+ )
+ from keras_hub.src.models.text_to_image import TextToImage
+
+
+ @keras_hub_export("keras_hub.models.FluxTextToImage")
+ class FluxTextToImage(TextToImage):
+     """An end-to-end Flux model for text-to-image generation.
+
+     This model has a `generate()` method, which generates images based on a
+     prompt.
+
+     Args:
+         backbone: A `keras_hub.models.FluxBackbone` instance.
+         preprocessor: A
+             `keras_hub.models.FluxTextToImagePreprocessor` instance.
+
+     Examples:
+
+     Use `generate()` to do image generation.
+     ```python
+     text_to_image = keras_hub.models.FluxTextToImage.from_preset(
+         "TBA", height=512, width=512
+     )
+     text_to_image.generate(
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+     )
+
+     # Generate with batched prompts.
+     text_to_image.generate(
+         ["cute wallpaper art of a cat", "cute wallpaper art of a dog"]
+     )
+
+     # Generate with different `num_steps` and `guidance_scale`.
+     text_to_image.generate(
+         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+         num_steps=50,
+         guidance_scale=5.0,
+     )
+
+     # Generate with `negative_prompts`.
+     text_to_image.generate(
+         {
+             "prompts": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+             "negative_prompts": "green color",
+         }
+     )
+     ```
+     """
+
+     backbone_cls = FluxBackbone
+     preprocessor_cls = FluxTextToImagePreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         preprocessor,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = backbone.input
+         outputs = backbone.output
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     def fit(self, *args, **kwargs):
+         raise NotImplementedError(
+             "Currently, `fit` is not supported for `FluxTextToImage`."
+         )
+
+     def generate_step(
+         self,
+         latents,
+         token_ids,
+         num_steps,
+         guidance_scale,
+     ):
+         """A compilable generation function for a batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for batched inputs.
+
+         Args:
+             latents: A (batch_size, height, width, channels) tensor
+                 containing the latents to start generation from. Typically, this
+                 tensor is sampled from the Gaussian distribution.
+             token_ids: A pair of (batch_size, num_tokens) tensors containing
+                 the tokens based on the input prompts and negative prompts.
+             num_steps: int. The number of diffusion steps to take.
+             guidance_scale: float. The classifier-free guidance scale defined in
+                 [Classifier-Free Diffusion Guidance](
+                 https://arxiv.org/abs/2207.12598). A higher scale encourages
+                 the model to generate images more closely linked to the
+                 prompt, usually at the expense of lower image quality.
+         """
+         token_ids, negative_token_ids = token_ids
+
+         # Encode prompts.
+         embeddings = self.backbone.encode_text_step(
+             token_ids, negative_token_ids
+         )
+
+         # Denoise.
+         def body_fun(step, latents):
+             return self.backbone.denoise_step(
+                 latents,
+                 embeddings,
+                 step,
+                 num_steps,
+                 guidance_scale,
+             )
+
+         latents = ops.fori_loop(0, num_steps, body_fun, latents)
+
+         # Decode.
+         return self.backbone.decode_step(latents)
+
+     def generate(
+         self,
+         inputs,
+         num_steps=28,
+         guidance_scale=7.0,
+         seed=None,
+     ):
+         return super().generate(
+             inputs,
+             num_steps=num_steps,
+             guidance_scale=guidance_scale,
+             seed=seed,
+         )
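
`generate_step()` folds the whole denoising loop into one XLA-compilable graph with `ops.fori_loop`, whose `body_fun(i, state)` returns the next state. A standalone sketch of that pattern, with a stand-in step function since `FluxBackbone.denoise_step()` still raises `NotImplementedError` in this diff:

```python
import keras
from keras import ops


def denoise_step(latents, step):
    # Stand-in: a real step would run the transformer and move the
    # latents one step along the flow trajectory.
    return latents * 0.99


num_steps = 28
latents = keras.random.normal((1, 64, 64, 16))  # hypothetical latent shape
latents = ops.fori_loop(
    0,
    num_steps,
    lambda step, x: denoise_step(x, step),  # body_fun(i, state) -> state
    latents,
)
```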
@@ -0,0 +1,73 @@
+ import keras
+ from keras import layers
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.flux.flux_model import FluxBackbone
+ from keras_hub.src.models.preprocessor import Preprocessor
+
+
+ @keras_hub_export("keras_hub.models.FluxTextToImagePreprocessor")
+ class FluxTextToImagePreprocessor(Preprocessor):
+     """Flux text-to-image model preprocessor.
+
+     This preprocessing layer is meant for use with
+     `keras_hub.models.FluxTextToImage`.
+
+     For use with generation, the layer exposes one method,
+     `generate_preprocess()`.
+
+     Args:
+         clip_l_preprocessor: A `keras_hub.models.CLIPPreprocessor` instance.
+         t5_preprocessor: An optional `keras_hub.models.T5Preprocessor` instance.
+     """
+
+     backbone_cls = FluxBackbone
+
+     def __init__(
+         self,
+         clip_l_preprocessor,
+         t5_preprocessor=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.clip_l_preprocessor = clip_l_preprocessor
+         self.t5_preprocessor = t5_preprocessor
+
+     @property
+     def sequence_length(self):
+         """The padded length of model input sequences."""
+         return self.clip_l_preprocessor.sequence_length
+
+     def build(self, input_shape):
+         self.built = True
+
+     def generate_preprocess(self, x):
+         token_ids = {}
+         token_ids["clip_l"] = self.clip_l_preprocessor(x)["token_ids"]
+         if self.t5_preprocessor is not None:
+             token_ids["t5"] = self.t5_preprocessor(x)["token_ids"]
+         return token_ids
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "clip_l_preprocessor": layers.serialize(
+                     self.clip_l_preprocessor
+                 ),
+                 "t5_preprocessor": layers.serialize(self.t5_preprocessor),
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config):
+         for layer_name in (
+             "clip_l_preprocessor",
+             "t5_preprocessor",
+         ):
+             if layer_name in config and isinstance(config[layer_name], dict):
+                 config[layer_name] = keras.layers.deserialize(
+                     config[layer_name]
+                 )
+         return cls(**config)
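
`generate_preprocess()` returns a dict keyed by encoder, mirroring the two text towers. A usage sketch, assuming you already have a `keras_hub.models.CLIPPreprocessor` instance (however constructed); the T5 branch is optional:

```python
preprocessor = FluxTextToImagePreprocessor(
    clip_l_preprocessor=clip_l_preprocessor,  # a CLIPPreprocessor instance
    t5_preprocessor=None,                     # optional second tower
)
token_ids = preprocessor.generate_preprocess(
    ["Astronaut in a jungle, cold color palette"]
)
# token_ids == {"clip_l": <(batch, sequence_length) int token ids>}
```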