keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
keras_hub/src/utils/tensor_utils.py
@@ -12,7 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
+import functools
+import inspect
+import threading
+
 import keras
+import numpy as np
 from keras import ops
 
 try:
@@ -23,6 +29,181 @@ except ImportError:
     tf_text = None
 
 
+NO_CONVERT_COUNTER = threading.local()
+
+
+@contextlib.contextmanager
+def no_convert_scope():
+    try:
+        NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) + 1
+        yield
+    finally:
+        NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) - 1
+
+
+def in_no_convert_scope():
+    return getattr(NO_CONVERT_COUNTER, "count", 0) > 0
+
+
+def preprocessing_function(fn):
+    """Wraps a preprocessing function to handle tf tensor conversion."""
+    if tf is None:
+        return fn
+
+    params = inspect.signature(fn).parameters
+    accepts_labels = all(k in params for k in ("x", "y", "sample_weight"))
+    if not accepts_labels:
+
+        @functools.wraps(fn)
+        def wrapper(self, x, **kwargs):
+            with tf.device("cpu"):
+                x = convert_preprocessing_inputs(x)
+                with no_convert_scope():
+                    x = fn(self, x, **kwargs)
+                return convert_preprocessing_outputs(x)
+
+    else:
+
+        @functools.wraps(fn)
+        def wrapper(self, x, y=None, sample_weight=None, **kwargs):
+            with tf.device("cpu"):
+                x, y, sample_weight = convert_preprocessing_inputs(
+                    (x, y, sample_weight)
+                )
+                with no_convert_scope():
+                    x = fn(self, x, y=y, sample_weight=sample_weight, **kwargs)
+                return convert_preprocessing_outputs(x)
+
+    return wrapper
+
+
+def convert_preprocessing_inputs(x):
+    """Convert raw inputs for preprocessing.
+
+    This function is used to convert raw inputs (strings, lists, `np.ndarray`s,
+    `jax.Array`s, `torch.Tensor`s, etc) to a canonical format for
+    preprocessing layers. All inputs will be converted to backend tensors if
+    possible, except ragged inputs and string inputs, which will be converted
+    to tf tensors regardless of backend.
+
+    `tuple` and `list` elements are handled differently by this function. A
+    `tuple` is assumed to enumerate separate inputs, and a `list` is assumed to
+    enumerate elements in a single array-like input. This makes it possible to
+    represent ragged and string inputs in a multi-backend format, as shown in
+    the examples below.
+
+    Examples:
+    ```python
+    # Two ragged arrays of token ids.
+    x = ([[1, 2, 3], [4, 5]], [[1, 2], [3, 4, 5]])
+    keras_hub.utils.convert_preprocessing_inputs(x)
+
+    # A batch of three samples each with two string segments.
+    x = (["hi", "hello", "hey"], ["bye", "later", "so long"])
+    keras_hub.utils.convert_preprocessing_inputs(x)
+
+    # A batch of features in a dictionary.
+    x = {
+        "text": ["hi", "hello", "hey"],
+        "images": np.ones((3, 64, 64, 3)),
+        "labels": [1, 0, 1],
+    }
+    keras_hub.utils.convert_preprocessing_inputs(x)
+    ```
+    """
+    if not tf.executing_eagerly() or in_no_convert_scope():
+        return x
+
+    if isinstance(x, dict):
+        return {k: convert_preprocessing_inputs(x[k]) for k, v in x.items()}
+    if isinstance(x, tuple):
+        return tuple(convert_preprocessing_inputs(v) for v in x)
+    if isinstance(x, (str, bytes)):
+        return tf.constant(x)
+    if isinstance(x, list):
+        try:
+            numpy_x = np.array(x)
+        except ValueError as e:
+            # If numpy conversion failed, try converting to a ragged array.
+            try:
+                return tf.ragged.constant(x)
+            except ValueError:
+                # If ragged conversion also failed, raise the numpy error.
+                raise e
+        # If we have a string input, use tf.tensor.
+        if numpy_x.dtype.type is np.str_ or numpy_x.dtype.type is np.bytes_:
+            return tf.convert_to_tensor(x)
+        # Numpy will default to int64, int32 works with more ops.
+        if numpy_x.dtype == np.int64:
+            numpy_x = numpy_x.astype(np.int32)
+        # We have non-ragged, non-string input. Use backend type.
+        x = ops.convert_to_tensor(numpy_x)
+        # Torch will complain about device placement for GPU tensors.
+        if keras.config.backend() == "torch":
+            x = x.cpu()
+        return x
+    if is_tensor_type(x):
+        # String or ragged types we keep as tf.
+        if isinstance(x, tf.RaggedTensor) or x.dtype == tf.string:
+            return x
+        # If we have a string input, use tf.tensor.
+        if isinstance(x, np.ndarray) and x.dtype.type is np.str_:
+            return tf.convert_to_tensor(x)
+        x = ops.convert_to_tensor(x)
+        # Torch will complain about device placement for GPU tensors.
+        if keras.config.backend() == "torch":
+            x = x.cpu()
+        return x
+    return x
+
+
+def convert_preprocessing_outputs(x):
+    """Convert outputs after preprocessing to a backend agnostic format.
+
+    This function is used to convert `tf.Tensor` and `tf.RaggedTensor` output
+    from preprocessing layers to either:
+
+    - The correct tensor type for the Keras backend framework.
+    - Python lists, in the case of ragged and string data.
+
+    This will automatically be called on the output of preprocessing
+    layers or `keras_hub.models.Task`s with preprocessing included. It could be
+    used directly to convert a `tf.data.Dataset` output to a backend agnostic
+    type.
+
+    Examples:
+    ```python
+    # Two ragged arrays of token ids.
+    x = tf.ragged.constant([[1, 2, 3], [4, 5]])
+    keras_hub.utils.convert_preprocessing_outputs(x)
+
+    # A batch of three samples each with two string segments.
+    x = (tf.constant(["hi", "yo", "hey"]), tf.constant(["bye", "ciao", ""]))
+    keras_hub.utils.convert_preprocessing_outputs(x)
+
+    # A batch of features in a dictionary.
+    x = {
+        "text": tf.constant(["hi", "hello", "hey"]),
+        "images": tf.ones((3, 64, 64, 3)),
+        "labels": tf.constant([1, 0, 1]),
+    }
+    keras_hub.utils.convert_preprocessing_outputs(x)
+    ```
+    """
+    if not tf.executing_eagerly() or in_no_convert_scope():
+        return x
+
+    def convert(x):
+        if x is None:
+            return x
+        if isinstance(x, tf.RaggedTensor) or x.dtype == tf.string:
+            return tensor_to_list(x)
+        dtype = keras.backend.standardize_dtype(x.dtype)
+        return ops.convert_to_tensor(x, dtype=dtype)
+
+    return keras.tree.map_structure(convert, x)
+
+
 def _decode_strings_to_utf8(inputs):
     """Recursively decodes to list of strings with 'utf-8' encoding."""
     if isinstance(inputs, bytes):
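The hunk above adds the conversion scaffolding used by the new preprocessing flows. The sketch below is not part of the diff; it is a minimal illustration of how `preprocessing_function` is meant to wrap a preprocessing method, assuming a TensorFlow-enabled install of this nightly. The class and method names are hypothetical.

```python
import tensorflow as tf

from keras_hub.src.utils.tensor_utils import preprocessing_function


class LowercasePreprocessor:
    # Hypothetical class, for illustration only.
    @preprocessing_function
    def generate_preprocess(self, x):
        # By the time this body runs, `x` has been converted to tf tensors on
        # the CPU, and nested calls skip re-conversion via `no_convert_scope`.
        return tf.strings.lower(x)


preprocessor = LowercasePreprocessor()
# String output is converted back to a plain Python list by the wrapper.
print(preprocessor.generate_preprocess(["Hello", "WORLD"]))  # ['hello', 'world']
```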
@@ -52,75 +233,15 @@ def tensor_to_list(inputs):
     return list_outputs
 
 
-def convert_to_backend_tensor_or_python_list(x):
-    """
-    Convert a tensor to the backend friendly representation of the data.
-
-    This wraps `ops.convert_to_tensor` to account for the fact that torch and
-    jax both lack native types for ragged and string data.
-
-    If we encounter one of these types in torch or jax, we will instead covert
-    the tensor to simple pythonic types (lists of strings).
-    """
-    if isinstance(x, tf.RaggedTensor) or getattr(x, "dtype", None) == tf.string:
-        return tensor_to_list(x)
-    dtype = getattr(x, "dtype", "float32")
-    dtype = keras.backend.standardize_dtype(dtype)
-    return ops.convert_to_tensor(x, dtype=dtype)
-
-
 def convert_to_ragged_batch(inputs):
-    """Convert pythonic or numpy-like input to a 2-D `tf.RaggedTensor`.
-
-    This is useful for text preprocessing layers which deal with already
-    tokenized or split text.
-
-    Args:
-        inputs: A pythonic or numpy-like input to covert. This input should
-            represent a possibly batched list of token sequences.
-
-    Returns:
-        An `(inputs, unbatched, rectangular)` tuple, where `inputs` is a
-        2-D `tf.RaggedTensor`, `unbatched` is `True` if the inputs were
-        origianlly rank 1, and `rectangular` is `True` if the inputs rows are
-        all of equal lengths.
-    """
-    # `tf.keras.layers.Layer` does a weird conversion in __call__, where a list
-    # of lists of ints will become a list of list of scalar tensors. We could
-    # clean this up if we no longer need to care about that case.
-    if isinstance(inputs, (list, tuple)):
-        if isinstance(inputs[0], (list, tuple)):
-            rectangular = len(set([len(row) for row in inputs])) == 1
-            rows = [
-                tf.convert_to_tensor(row, dtype_hint="int32") for row in inputs
-            ]
-            inputs = tf.ragged.stack(rows).with_row_splits_dtype("int64")
-        else:
-            inputs = tf.convert_to_tensor(inputs)
-            rectangular = True
-    elif isinstance(inputs, tf.Tensor):
-        rectangular = True
-    elif isinstance(inputs, tf.RaggedTensor):
-        rectangular = False
-    elif hasattr(inputs, "__array__"):
-        inputs = tf.convert_to_tensor(ops.convert_to_numpy(inputs))
-        rectangular = True
-    else:
-        raise ValueError(
-            f"Unknown tensor type. Tensor input can be passed as "
-            "tensors, numpy arrays, or python lists. Received: "
-            f"`type(inputs)={type(inputs)}`"
-        )
-    if inputs.shape.rank < 1 or inputs.shape.rank > 2:
-        raise ValueError(
-            f"Tokenized tensor input should be rank 1 (unbatched) or "
-            f"rank 2 (batched). Received: `inputs.shape={input.shape}`"
-        )
+    """Ensure a tf.Tensor is a ragged rank 2 tensor."""
+    if not isinstance(inputs, (tf.RaggedTensor, tf.Tensor)):
+        inputs = tf.convert_to_tensor(inputs)
     unbatched = inputs.shape.rank == 1
-    rectangular = rectangular or unbatched
+    rectangular = isinstance(inputs, tf.Tensor)
     if unbatched:
         inputs = tf.expand_dims(inputs, 0)
-    if isinstance(inputs, tf.Tensor):
+    if rectangular:
         inputs = tf.RaggedTensor.from_tensor(inputs)
     return inputs, unbatched, rectangular
 
@@ -135,10 +256,7 @@ def truncate_at_token(inputs, token, mask):
 
 def strip_to_ragged(token_ids, mask, ids_to_strip):
     """Remove masked and special tokens from a sequence before detokenizing."""
-    token_ids = ops.convert_to_numpy(token_ids)
-    token_ids = token_ids.astype("int32")
-    mask = ops.convert_to_numpy(mask)
-    mask = mask.astype("bool")
+    mask = tf.cast(mask, "bool")
    for id in ids_to_strip:
         mask = mask & (token_ids != id)
     return tf.ragged.boolean_mask(token_ids, mask)
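After this simplification, `convert_to_ragged_batch` expects input that is already a `tf.Tensor` or `tf.RaggedTensor` (or something `tf.convert_to_tensor` accepts), since raw Python inputs are now normalized earlier by `convert_preprocessing_inputs`. A minimal sketch of the new contract, assuming a TensorFlow-enabled install; the printed values in comments follow from the code above.

```python
import tensorflow as tf

from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch

# Rank-1 dense input: batched, then converted to a ragged tensor.
ids, unbatched, rectangular = convert_to_ragged_batch(tf.constant([1, 2, 3]))
print(ids.to_list(), unbatched, rectangular)  # [[1, 2, 3]] True True

# Rank-2 ragged input passes through unchanged.
ids, unbatched, rectangular = convert_to_ragged_batch(
    tf.ragged.constant([[1, 2, 3], [4, 5]])
)
print(ids.to_list(), unbatched, rectangular)  # [[1, 2, 3], [4, 5]] False False
```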
keras_hub/src/utils/timm/convert_resnet.py
@@ -13,10 +13,9 @@
 # limitations under the License.
 import numpy as np
 
-from keras_hub.src.utils.preset_utils import HF_CONFIG_FILE
-from keras_hub.src.utils.preset_utils import jax_memory_cleanup
-from keras_hub.src.utils.preset_utils import load_config
-from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
+from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
+
+backbone_cls = ResNetBackbone
 
 
 def convert_backbone_config(timm_config):
@@ -56,6 +55,8 @@ def convert_backbone_config(timm_config):
         stackwise_num_strides=[1, 2, 2, 2],
         block_type=block_type,
         use_pre_activation=use_pre_activation,
+        input_conv_filters=[64],
+        input_conv_kernel_sizes=[7],
     )
 
 
@@ -100,10 +101,10 @@ def convert_weights(backbone, loader, timm_config):
     for stack_index in range(num_stacks):
         for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
             if version == "v1":
-                keras_name = f"v1_stack{stack_index}_block{block_idx}"
+                keras_name = f"stack{stack_index}_block{block_idx}"
                 hf_name = f"layer{stack_index+1}.{block_idx}"
             else:
-                keras_name = f"v2_stack{stack_index}_block{block_idx}"
+                keras_name = f"stack{stack_index}_block{block_idx}"
                 hf_name = f"stages.{stack_index}.blocks.{block_idx}"
 
             if version == "v1":
@@ -159,13 +160,15 @@ def convert_weights(backbone, loader, timm_config):
     normalization_layer.build(normalization_layer._build_input_shape)
 
 
-def load_resnet_backbone(cls, preset, load_weights, **kwargs):
-    timm_config = load_config(preset, HF_CONFIG_FILE)
-    keras_config = convert_backbone_config(timm_config)
-    backbone = cls(**keras_config, **kwargs)
-    if load_weights:
-        jax_memory_cleanup(backbone)
-        # Use prefix="" to avoid using `get_prefixed_key`.
-        with SafetensorLoader(preset, prefix="") as loader:
-            convert_weights(backbone, loader, timm_config)
-    return backbone
+def convert_head(task, loader, timm_config):
+    v2 = "resnetv2_" in timm_config["architecture"]
+    prefix = "head.fc." if v2 else "fc."
+    loader.port_weight(
+        task.output_dense.kernel,
+        hf_weight_key=prefix + "weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.output_dense.bias,
+        hf_weight_key=prefix + "bias",
+    )
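The new `convert_head` ports the timm classification head into the task's `output_dense` layer. The numpy snippet below only illustrates the weight-layout flip performed by its `hook_fn`; the shapes are illustrative, and the note about conv-style heads is an assumption about why `np.squeeze` is applied.

```python
import numpy as np

# timm stores `fc.weight` as (num_classes, hidden_dim); a Keras Dense kernel
# is (hidden_dim, num_classes). `np.squeeze` additionally drops any singleton
# dims (e.g. a conv-style head stored as (num_classes, hidden_dim, 1, 1)).
hf_fc_weight = np.random.rand(1000, 2048)
keras_kernel = np.transpose(np.squeeze(hf_fc_weight))
assert keras_kernel.shape == (2048, 1000)
```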
keras_hub/src/utils/timm/preset_loader.py (new file)
@@ -0,0 +1,66 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert timm models to KerasHub."""
+
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.utils.preset_utils import PresetLoader
+from keras_hub.src.utils.preset_utils import jax_memory_cleanup
+from keras_hub.src.utils.timm import convert_resnet
+from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
+
+
+class TimmPresetLoader(PresetLoader):
+    def __init__(self, preset, config):
+        super().__init__(preset, config)
+        architecture = self.config["architecture"]
+        if "resnet" in architecture:
+            self.converter = convert_resnet
+        else:
+            raise ValueError(
+                "KerasHub has no converter for timm models "
+                f"with architecture `'{architecture}'`."
+            )
+
+    def check_backbone_class(self):
+        return self.converter.backbone_cls
+
+    def load_backbone(self, cls, load_weights, **kwargs):
+        keras_config = self.converter.convert_backbone_config(self.config)
+        backbone = cls(**{**keras_config, **kwargs})
+        if load_weights:
+            jax_memory_cleanup(backbone)
+            # Use prefix="" to avoid using `get_prefixed_key`.
+            with SafetensorLoader(self.preset, prefix="") as loader:
+                self.converter.convert_weights(backbone, loader, self.config)
+        return backbone
+
+    def load_task(self, cls, load_weights, load_task_weights, **kwargs):
+        if not load_task_weights or not issubclass(cls, ImageClassifier):
+            return super().load_task(
+                cls, load_weights, load_task_weights, **kwargs
+            )
+        # Support loading the classification head for classifier models.
+        kwargs["num_classes"] = self.config["num_classes"]
+        task = super().load_task(cls, load_weights, load_task_weights, **kwargs)
+        if load_task_weights:
+            with SafetensorLoader(self.preset, prefix="") as loader:
+                self.converter.convert_head(task, loader, self.config)
+        return task
+
+    def load_image_converter(self, cls, **kwargs):
+        pretrained_cfg = self.config.get("pretrained_cfg", None)
+        if not pretrained_cfg or "input_size" not in pretrained_cfg:
+            return None
+        input_size = pretrained_cfg["input_size"]
+        return cls(width=input_size[1], height=input_size[2])
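With `TimmPresetLoader` in place, timm checkpoints on the Hugging Face Hub should be loadable through the usual `from_preset` entry points. A hedged usage sketch; the preset handle is illustrative and depends on what is published on the Hub.

```python
import keras_hub

# Backbone only: goes through `load_backbone` and the convert_resnet converter.
backbone = keras_hub.models.ResNetBackbone.from_preset("hf://timm/resnet18.a1_in1k")

# Full classifier: `load_task` additionally ports the head via `convert_head`.
classifier = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/resnet18.a1_in1k",
)
```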
keras_hub/src/utils/transformers/convert_albert.py (new file)
@@ -0,0 +1,193 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from keras_hub.src.models.albert.albert_backbone import AlbertBackbone
+from keras_hub.src.utils.preset_utils import get_file
+
+backbone_cls = AlbertBackbone
+
+
+def convert_backbone_config(transformers_config):
+    return {
+        "vocabulary_size": transformers_config["vocab_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_heads": transformers_config["num_attention_heads"],
+        "embedding_dim": transformers_config["embedding_size"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "num_groups": transformers_config["num_hidden_groups"],
+        "num_inner_repetitions": transformers_config["inner_group_num"],
+        "dropout": transformers_config["attention_probs_dropout_prob"],
+        "max_sequence_length": transformers_config["max_position_embeddings"],
+        "num_segments": transformers_config["type_vocab_size"],
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    # Embeddings
+    loader.port_weight(
+        keras_variable=backbone.token_embedding.embeddings,
+        hf_weight_key="albert.embeddings.word_embeddings.weight",
+    )
+    loader.port_weight(
+        keras_variable=backbone.position_embedding.position_embeddings,
+        hf_weight_key="albert.embeddings.position_embeddings.weight",
+    )
+    loader.port_weight(
+        keras_variable=backbone.segment_embedding.embeddings,
+        hf_weight_key="albert.embeddings.token_type_embeddings.weight",
+    )
+
+    # Normalization
+    loader.port_weight(
+        keras_variable=backbone.embeddings_layer_norm.gamma,
+        hf_weight_key="albert.embeddings.LayerNorm.weight",
+    )
+    loader.port_weight(
+        keras_variable=backbone.embeddings_layer_norm.beta,
+        hf_weight_key="albert.embeddings.LayerNorm.bias",
+    )
+
+    # Encoder Embeddings
+    loader.port_weight(
+        keras_variable=backbone.embeddings_projection.kernel,
+        hf_weight_key="albert.encoder.embedding_hidden_mapping_in.weight",
+        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+    )
+    loader.port_weight(
+        keras_variable=backbone.embeddings_projection.bias,
+        hf_weight_key="albert.encoder.embedding_hidden_mapping_in.bias",
+    )
+
+    # Encoder Group Layers
+    for group_idx in range(backbone.num_groups):
+        for inner_layer_idx in range(backbone.num_inner_repetitions):
+            keras_group = backbone.get_layer(
+                f"group_{group_idx}_inner_layer_{inner_layer_idx}"
+            )
+            hf_group_prefix = (
+                "albert.encoder.albert_layer_groups."
+                f"{group_idx}.albert_layers.{inner_layer_idx}."
+            )
+
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.query_dense.kernel,
+                hf_weight_key=f"{hf_group_prefix}attention.query.weight",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    np.transpose(hf_tensor), keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.query_dense.bias,
+                hf_weight_key=f"{hf_group_prefix}attention.query.bias",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    hf_tensor, keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.key_dense.kernel,
+                hf_weight_key=f"{hf_group_prefix}attention.key.weight",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    np.transpose(hf_tensor), keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.key_dense.bias,
+                hf_weight_key=f"{hf_group_prefix}attention.key.bias",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    hf_tensor, keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.value_dense.kernel,
+                hf_weight_key=f"{hf_group_prefix}attention.value.weight",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    np.transpose(hf_tensor), keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.value_dense.bias,
+                hf_weight_key=f"{hf_group_prefix}attention.value.bias",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    hf_tensor, keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.output_dense.kernel,
+                hf_weight_key=f"{hf_group_prefix}attention.dense.weight",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    np.transpose(hf_tensor), keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer.output_dense.bias,
+                hf_weight_key=f"{hf_group_prefix}attention.dense.bias",
+                hook_fn=lambda hf_tensor, keras_shape: np.reshape(
+                    hf_tensor, keras_shape
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer_norm.gamma,
+                hf_weight_key=f"{hf_group_prefix}attention.LayerNorm.weight",
+            )
+            loader.port_weight(
+                keras_variable=keras_group._self_attention_layer_norm.beta,
+                hf_weight_key=f"{hf_group_prefix}attention.LayerNorm.bias",
+            )
+            loader.port_weight(
+                keras_variable=keras_group._feedforward_intermediate_dense.kernel,
+                hf_weight_key=f"{hf_group_prefix}ffn.weight",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._feedforward_intermediate_dense.bias,
+                hf_weight_key=f"{hf_group_prefix}ffn.bias",
+            )
+            loader.port_weight(
+                keras_variable=keras_group._feedforward_output_dense.kernel,
+                hf_weight_key=f"{hf_group_prefix}ffn_output.weight",
+                hook_fn=lambda hf_tensor, _: np.transpose(
+                    hf_tensor, axes=(1, 0)
+                ),
+            )
+            loader.port_weight(
+                keras_variable=keras_group._feedforward_output_dense.bias,
+                hf_weight_key=f"{hf_group_prefix}ffn_output.bias",
+            )
+            loader.port_weight(
+                keras_variable=keras_group._feedforward_layer_norm.gamma,
+                hf_weight_key=f"{hf_group_prefix}full_layer_layer_norm.weight",
+            )
+            loader.port_weight(
+                keras_variable=keras_group._feedforward_layer_norm.beta,
+                hf_weight_key=f"{hf_group_prefix}full_layer_layer_norm.bias",
+            )
+
+    # Pooler
+    loader.port_weight(
+        keras_variable=backbone.pooled_dense.kernel,
+        hf_weight_key="albert.pooler.weight",
+        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
+    )
+    loader.port_weight(
+        keras_variable=backbone.pooled_dense.bias,
+        hf_weight_key="albert.pooler.bias",
+    )
+
+
+def convert_tokenizer(cls, preset, **kwargs):
+    return cls(get_file(preset, "spiece.model"), **kwargs)
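End to end, the new ALBERT converter lets `from_preset` consume a Hugging Face ALBERT checkpoint directly, reusing `convert_backbone_config`, `convert_weights`, and `convert_tokenizer` above. A hedged sketch; the checkpoint handle and the exact exported class paths are assumptions based on the rest of this release.

```python
import keras_hub

# Illustrative Hugging Face handle for an ALBERT checkpoint.
backbone = keras_hub.models.AlbertBackbone.from_preset("hf://albert/albert-base-v2")
tokenizer = keras_hub.tokenizers.AlbertTokenizer.from_preset(
    "hf://albert/albert-base-v2"
)
```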