keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gemma/gemma_presets.py

@@ -6,9 +6,7 @@ backbone_presets = {
         "metadata": {
             "description": "2 billion parameter, 18-layer, base Gemma model.",
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/2",
     },
@@ -18,9 +16,7 @@ backbone_presets = {
                 "2 billion parameter, 18-layer, instruction tuned Gemma model."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/2",
     },
@@ -31,9 +27,7 @@ backbone_presets = {
                 "The 1.1 update improves model quality."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_2b_en/3",
     },
@@ -45,9 +39,7 @@ backbone_presets = {
                 "completion. The 1.1 update improves model quality."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_2b_en/1",
     },
@@ -59,9 +51,7 @@ backbone_presets = {
                 "completion."
             ),
             "params": 2506172416,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_2b_en/1",
     },
@@ -69,9 +59,7 @@ backbone_presets = {
         "metadata": {
             "description": "7 billion parameter, 28-layer, base Gemma model.",
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/2",
     },
@@ -81,9 +69,7 @@ backbone_presets = {
                 "7 billion parameter, 28-layer, instruction tuned Gemma model."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/2",
     },
@@ -94,9 +80,7 @@ backbone_presets = {
                 "The 1.1 update improves model quality."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma/keras/gemma_1.1_instruct_7b_en/3",
     },
@@ -108,9 +92,7 @@ backbone_presets = {
                 "completion."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_7b_en/1",
     },
@@ -122,9 +104,7 @@ backbone_presets = {
                 "to code."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_instruct_7b_en/1",
     },
@@ -136,9 +116,7 @@ backbone_presets = {
                 "to code. The 1.1 update improves model quality."
             ),
             "params": 8537680896,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/codegemma/keras/code_gemma_1.1_instruct_7b_en/1",
     },
@@ -146,9 +124,7 @@ backbone_presets = {
         "metadata": {
             "description": "2 billion parameter, 26-layer, base Gemma model.",
             "params": 2614341888,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_2b_en/1",
     },
@@ -156,9 +132,7 @@ backbone_presets = {
         "metadata": {
             "description": "2 billion parameter, 26-layer, instruction tuned Gemma model.",
             "params": 2614341888,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_2b_en/1",
     },
@@ -166,9 +140,7 @@ backbone_presets = {
         "metadata": {
             "description": "9 billion parameter, 42-layer, base Gemma model.",
             "params": 9241705984,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_9b_en/2",
     },
@@ -176,9 +148,7 @@ backbone_presets = {
         "metadata": {
             "description": "9 billion parameter, 42-layer, instruction tuned Gemma model.",
             "params": 9241705984,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_9b_en/2",
     },
@@ -186,9 +156,7 @@ backbone_presets = {
         "metadata": {
             "description": "27 billion parameter, 42-layer, base Gemma model.",
             "params": 27227128320,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_27b_en/1",
     },
@@ -196,9 +164,7 @@ backbone_presets = {
         "metadata": {
             "description": "27 billion parameter, 42-layer, instruction tuned Gemma model.",
             "params": 27227128320,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://keras/gemma2/keras/gemma2_instruct_27b_en/1",
     },
@@ -206,9 +172,7 @@ backbone_presets = {
         "metadata": {
             "description": "2 billion parameter, 26-layer, ShieldGemma model.",
             "params": 2614341888,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_2b_en/1",
     },
@@ -216,9 +180,7 @@ backbone_presets = {
         "metadata": {
             "description": "9 billion parameter, 42-layer, ShieldGemma model.",
             "params": 9241705984,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_9b_en/1",
     },
@@ -226,9 +188,7 @@ backbone_presets = {
         "metadata": {
             "description": "27 billion parameter, 42-layer, ShieldGemma model.",
             "params": 27227128320,
-            "official_name": "Gemma",
             "path": "gemma",
-            "model_card": "https://www.kaggle.com/models/google/gemma",
         },
         "kaggle_handle": "kaggle://google/shieldgemma/keras/shieldgemma_27b_en/1",
     },
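All of the preset hunks in this diff (Gemma above, plus GPT-2, LLaMA 2, and LLaMA 3 below) drop the `official_name` and `model_card` keys from preset metadata. As a sketch of the resulting shape (not part of the diff; the preset key is inferred from the kaggle handle), a trimmed `backbone_presets` entry now carries only `description`, `params`, and `path`:

```python
backbone_presets = {
    "gemma_2b_en": {  # preset key inferred from the kaggle handle below
        "metadata": {
            "description": "2 billion parameter, 18-layer, base Gemma model.",
            "params": 2506172416,
            "path": "gemma",
        },
        "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/2",
    },
}
```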
keras_hub/src/models/gpt2/gpt2_presets.py

@@ -9,9 +9,7 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 124439808,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en/2",
     },
@@ -22,9 +20,7 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 354823168,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_medium_en/2",
     },
@@ -35,9 +31,7 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 774030080,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_large_en/2",
     },
@@ -48,9 +42,7 @@ backbone_presets = {
                 "Trained on WebText."
             ),
             "params": 1557611200,
-            "official_name": "GPT-2",
             "path": "gpt2",
-            "model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
         },
         "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_extra_large_en/2",
     },
@@ -61,7 +53,6 @@ backbone_presets = {
                 "Finetuned on the CNN/DailyMail summarization dataset."
             ),
             "params": 124439808,
-            "official_name": "GPT-2",
             "path": "gpt2",
         },
         "kaggle_handle": "kaggle://keras/gpt2/keras/gpt2_base_en_cnn_dailymail/2",
keras_hub/src/models/image_object_detector.py

@@ -0,0 +1,87 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetector")
+class ImageObjectDetector(Task):
+    """Base class for all image object detection tasks.
+
+    The `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    object detection. `ImageObjectDetector` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is a string and `y` is dictionary with `boxes` and
+    `classes`.
+
+    All `ImageObjectDetector` tasks include a `from_preset()` constructor which
+    can be used to load a pre-trained config and weights.
+    """
+
+    def compile(
+        self,
+        optimizer="auto",
+        box_loss="auto",
+        classification_loss="auto",
+        metrics=None,
+        **kwargs,
+    ):
+        """Configures the `ImageObjectDetector` task for training.
+
+        The `ImageObjectDetector` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            box_loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.Huber` loss will be
+                applied for the object detector task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            classification_loss: `"auto"`, a loss name, or a `keras.losses.Loss`
+                instance. Defaults to `"auto"`, where a
+                `keras.losses.BinaryFocalCrossentropy` loss will be
+                applied for the object detector task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `None`.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if box_loss == "auto":
+            box_loss = keras.losses.Huber(reduction="sum")
+        if classification_loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.sigmoid
+            classification_loss = keras.losses.BinaryFocalCrossentropy(
+                from_logits=from_logits, reduction="sum"
+            )
+        if metrics is not None:
+            raise ValueError("User metrics not yet supported")
+
+        losses = {
+            "bbox_regression": box_loss,
+            "cls_logits": classification_loss,
+        }
+
+        super().compile(
+            optimizer=optimizer,
+            loss=losses,
+            metrics=metrics,
+            **kwargs,
+        )
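Not part of the diff: a minimal sketch of what the `"auto"` defaults in the new `compile()` above resolve to when spelled out explicitly. The preset name is borrowed from the `ImageObjectDetectorPreprocessor` docstring later in this diff and may not match a released handle; `from_logits=True` is shown for a non-sigmoid activation, per the `getattr` logic above.

```python
import keras
import keras_hub

# Any concrete ImageObjectDetector works here; the preset name is illustrative.
detector = keras_hub.models.ImageObjectDetector.from_preset("retinanet_resnet50")

# Equivalent to detector.compile() with its "auto" defaults:
detector.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    box_loss=keras.losses.Huber(reduction="sum"),
    classification_loss=keras.losses.BinaryFocalCrossentropy(
        from_logits=True,  # True whenever the task activation is not sigmoid
        reduction="sum",
    ),
)
```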
keras_hub/src/models/image_object_detector_preprocessor.py

@@ -0,0 +1,57 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetectorPreprocessor")
+class ImageObjectDetectorPreprocessor(Preprocessor):
+    """Base class for object detector preprocessing layers.
+
+    `ImageObjectDetectorPreprocessor` tasks wraps a
+    `keras_hub.layers.Preprocessor` to create a preprocessing layer for
+    object detection tasks. It is intended to be paired with a
+    `keras_hub.models.ImageObjectDetector` task.
+
+    All `ImageObjectDetectorPreprocessor` take three inputs, `x`, `y`, and
+    `sample_weight`. `x`, the first input, should always be included. It can
+    be a image or batch of images. See examples below. `y` and `sample_weight`
+    are optional inputs that will be passed through unaltered. Usually, `y` will
+    be the a dict of `{"boxes": Tensor(batch_size, num_boxes, 4),
+    "classes": (batch_size, num_boxes)}.
+
+    The layer will returns either `x`, an `(x, y)` tuple if labels were provided,
+    or an `(x, y, sample_weight)` tuple if labels and sample weight were
+    provided. `x` will be the input images after all model preprocessing has
+    been applied.
+
+    All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()`
+    constructor which can be used to load a pre-trained config and vocabularies.
+    You can call the `from_preset()` constructor directly on this base class, in
+    which case the correct class for your model will be automatically
+    instantiated.
+
+    Args:
+        image_converter: Preprocessing pipeline for images.
+
+    Examples.
+    ```python
+    preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset(
+        "retinanet_resnet50",
+    )
+    """
+
+    def __init__(
+        self,
+        image_converter=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_converter = image_converter
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None):
+        if self.image_converter:
+            x = self.image_converter(x)
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
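A short usage sketch of the call contract described in the docstring above (not part of the diff; the preset name comes from that docstring, and the label shapes follow its `boxes`/`classes` description):

```python
import numpy as np
import keras_hub

preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset(
    "retinanet_resnet50",
)
images = np.random.uniform(0, 255, size=(2, 512, 512, 3)).astype("float32")
labels = {
    "boxes": np.zeros((2, 4, 4), dtype="float32"),
    "classes": np.zeros((2, 4), dtype="int32"),
}
# Images go through the image converter; labels pass through unaltered.
x, y = preprocessor(images, labels)
```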
keras_hub/src/models/image_to_image.py

@@ -234,7 +234,7 @@ class ImageToImage(Task):
             input_is_scalar = True
         x = ops.image.resize(
             x,
-            (self.backbone.height, self.backbone.width),
+            (self.backbone.image_shape[0], self.backbone.image_shape[1]),
             interpolation="nearest",
             data_format=data_format,
         )
@@ -284,8 +284,8 @@ class ImageToImage(Task):
         self,
         inputs,
         num_steps,
-        guidance_scale,
         strength,
+        guidance_scale=None,
         seed=None,
     ):
         """Generate image based on the provided `inputs`.
@@ -313,30 +313,36 @@ class ImageToImage(Task):
                 - A `tf.data.Dataset` with `"images"`, `"prompts"` and/or
                     `"negative_prompts"` keys.
             num_steps: int. The number of diffusion steps to take.
-            guidance_scale: float. The classifier free guidance scale defined in
-                [Classifier-Free Diffusion Guidance](
-                https://arxiv.org/abs/2207.12598). A higher scale encourages
-                generating images more closely related to the prompts, typically
-                at the cost of lower image quality.
             strength: float. Indicates the extent to which the reference
                 `images` are transformed. Must be between `0.0` and `1.0`. When
                 `strength=1.0`, `images` is essentially ignore and added noise
                 is maximum and the denoising process runs for the full number of
                 iterations specified in `num_steps`.
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely related to the prompts, typically
+                at the cost of lower image quality. Note that some models don't
+                utilize classifier-free guidance.
             seed: optional int. Used as a random seed.
         """
         num_steps = int(num_steps)
-        guidance_scale = float(guidance_scale)
         strength = float(strength)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
         if strength < 0.0 or strength > 1.0:
             raise ValueError(
                 "`strength` must be between `0.0` and `1.0`. "
                 f"Received strength={strength}."
             )
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(float(guidance_scale))
+        else:
+            guidance_scale = None
         starting_step = int(num_steps * (1.0 - strength))
         starting_step = ops.convert_to_tensor(starting_step, "int32")
-        num_steps = ops.convert_to_tensor(num_steps, "int32")
+        num_steps = ops.convert_to_tensor(int(num_steps), "int32")

         # Check `inputs` format.
         required_keys = ["images", "prompts"]
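A sketch of the relaxed `generate()` signature above (not part of the diff). The preset name is illustrative; `guidance_scale` may now be omitted, or passed as `None`, for models that do not use classifier-free guidance:

```python
import numpy as np
import keras_hub

task = keras_hub.models.ImageToImage.from_preset("stable_diffusion_3_medium")
reference = np.random.uniform(-1.0, 1.0, size=(512, 512, 3)).astype("float32")
images = task.generate(
    {"images": reference, "prompts": "a photograph of an astronaut"},
    num_steps=30,
    strength=0.7,
    guidance_scale=7.0,  # now optional; omit or pass None to skip CFG
)
```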
keras_hub/src/models/inpaint.py

@@ -202,7 +202,7 @@ class Inpaint(Task):
             input_is_scalar = True
         x = ops.image.resize(
             x,
-            (self.backbone.height, self.backbone.width),
+            (self.backbone.image_shape[0], self.backbone.image_shape[1]),
             interpolation="nearest",
             data_format=data_format,
         )
@@ -240,7 +240,7 @@ class Inpaint(Task):
         x = ops.cast(x, "float32")
         x = ops.image.resize(
             x,
-            (self.backbone.height, self.backbone.width),
+            (self.backbone.image_shape[0], self.backbone.image_shape[1]),
             interpolation="nearest",
             data_format=data_format,
         )
@@ -303,7 +303,7 @@ class Inpaint(Task):
             input_is_scalar = True
         x = ops.image.resize(
             x,
-            (self.backbone.height, self.backbone.width),
+            (self.backbone.image_shape[0], self.backbone.image_shape[1]),
             interpolation="nearest",
             data_format=data_format,
         )
@@ -323,7 +323,7 @@ class Inpaint(Task):
         x = ops.cast(x, "float32")
         x = ops.image.resize(
             x,
-            (self.backbone.height, self.backbone.width),
+            (self.backbone.image_shape[0], self.backbone.image_shape[1]),
             interpolation="nearest",
             data_format=data_format,
         )
@@ -376,16 +376,16 @@ class Inpaint(Task):
         self,
         inputs,
         num_steps,
-        guidance_scale,
         strength,
+        guidance_scale=None,
         seed=None,
     ):
         """Generate image based on the provided `inputs`.

         Typically, `inputs` is a dict with `"images"` `"masks"` and `"prompts"`
         keys. `"images"` are reference images within a value range of
-        `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and
-        `self.backbone.width`, then encoded into latent space by the VAE
+        `[-1.0, 1.0]`, which will be resized to height and width from
+        `self.backbone.image_shape`, then encoded into latent space by the VAE
         encoder. `"masks"` are mask images with a boolean dtype, where white
         pixels are repainted while black pixels are preserved. `"prompts"` are
         strings that will be tokenized and encoded by the text encoder.
@@ -406,26 +406,33 @@ class Inpaint(Task):
                 - A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"`
                     and/or `"negative_prompts"` keys.
             num_steps: int. The number of diffusion steps to take.
-            guidance_scale: float. The classifier free guidance scale defined in
-                [Classifier-Free Diffusion Guidance](
-                https://arxiv.org/abs/2207.12598). A higher scale encourages
-                generating images more closely related to the prompts, typically
-                at the cost of lower image quality.
             strength: float. Indicates the extent to which the reference
                 `images` are transformed. Must be between `0.0` and `1.0`. When
                 `strength=1.0`, `images` is essentially ignore and added noise
                 is maximum and the denoising process runs for the full number of
                 iterations specified in `num_steps`.
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely related to the prompts, typically
+                at the cost of lower image quality. Note that some models don't
+                utilize classifier-free guidance.
             seed: optional int. Used as a random seed.
         """
         num_steps = int(num_steps)
-        guidance_scale = float(guidance_scale)
         strength = float(strength)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
         if strength < 0.0 or strength > 1.0:
             raise ValueError(
                 "`strength` must be between `0.0` and `1.0`. "
                 f"Received strength={strength}."
             )
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(guidance_scale)
+        else:
+            guidance_scale = None
         starting_step = int(num_steps * (1.0 - strength))
         starting_step = ops.convert_to_tensor(starting_step, "int32")
         num_steps = ops.convert_to_tensor(num_steps, "int32")
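And the matching `Inpaint.generate()` call, following the mask semantics in the docstring above (white/`True` pixels are repainted); again a sketch with an illustrative preset name, not part of the diff:

```python
import numpy as np
import keras_hub

task = keras_hub.models.Inpaint.from_preset("stable_diffusion_3_medium")
reference = np.random.uniform(-1.0, 1.0, size=(512, 512, 3)).astype("float32")
mask = np.zeros((512, 512), dtype="bool")
mask[128:384, 128:384] = True  # repaint only the center square
images = task.generate(
    {"images": reference, "masks": mask, "prompts": "a bouquet of flowers"},
    num_steps=30,
    strength=0.6,
)
```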
keras_hub/src/models/llama/llama_backbone.py

@@ -59,7 +59,7 @@ class LlamaBackbone(Backbone):
     }

     # Pretrained Llama decoder.
-    model = keras_hub.models.LlamaBackbone.from_preset("llama7b_base_en")
+    model = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en")
     model(input_data)

     # Randomly initialized Llama decoder with custom config.
keras_hub/src/models/llama/llama_presets.py

@@ -6,9 +6,7 @@ backbone_presets = {
         "metadata": {
             "description": "7 billion parameter, 32-layer, base LLaMA 2 model.",
             "params": 6738415616,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
         "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1",
     },
@@ -19,9 +17,7 @@ backbone_presets = {
                 "activation and weights quantized to int8."
             ),
             "params": 6739839488,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
         "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/1",
     },
@@ -32,9 +28,7 @@ backbone_presets = {
                 "model."
             ),
             "params": 6738415616,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
         "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1",
     },
@@ -45,9 +39,7 @@ backbone_presets = {
                 "model with activation and weights quantized to int8."
             ),
             "params": 6739839488,
-            "official_name": "LLaMA 2",
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
         "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/1",
     },
@@ -58,9 +50,7 @@ backbone_presets = {
                 "model."
             ),
             "params": 6738415616,
-            "official_name": "Vicuna",
-            "path": "vicuna",
-            "model_card": "https://github.com/lm-sys/FastChat",
+            "path": "llama",
         },
         "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/1",
     },
keras_hub/src/models/llama3/llama3_presets.py

@@ -6,9 +6,7 @@ backbone_presets = {
         "metadata": {
             "description": "8 billion parameter, 32-layer, base LLaMA 3 model.",
             "params": 8030261248,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/3",
     },
@@ -19,9 +17,7 @@ backbone_presets = {
                 "activation and weights quantized to int8."
             ),
             "params": 8031894016,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en_int8/1",
     },
@@ -32,9 +28,7 @@ backbone_presets = {
                 "model."
             ),
             "params": 8030261248,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/3",
     },
@@ -45,9 +39,7 @@ backbone_presets = {
                 "model with activation and weights quantized to int8."
             ),
             "params": 8031894016,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": (
             "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/1"