keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -18,8 +18,6 @@ import keras
18
18
  from absl import logging
19
19
  from packaging.version import parse
20
20
 
21
- from keras_hub.src.utils.tensor_utils import is_tensor_type
22
-
23
21
  try:
24
22
  import tensorflow as tf
25
23
  except ImportError:
@@ -39,54 +37,6 @@ def clone_initializer(initializer):
39
37
  return initializer.__class__.from_config(config)
40
38
 
41
39
 
42
- def convert_inputs_to_list_of_tensor_segments(x):
43
- """Converts user inputs to a list of a tensor segments.
44
-
45
- For models and layers which accept lists of string tensors to pack together,
46
- this method converts user inputs to a uniform format in a way that can be
47
- considered canonical for the library.
48
-
49
- We handle the following:
50
-
51
- - A single string will be converted to a tensor and wrapped in a list.
52
- - A list of strings will be converted to a tensor and wrapped in a list.
53
- - A single tensor will be wrapped in a list.
54
- - A list of tensors will be passed through unaltered.
55
-
56
- All other inputs will result in an error. This effectively means that users
57
- who would like to pack multiple segments together should convert those
58
- segments to tensors before calling the layer. This removes any ambiguity
59
- in the input for those cases.
60
- """
61
- # Check the input type.
62
- is_string = isinstance(x, (str, bytes))
63
- is_tensor = is_tensor_type(x)
64
- is_string_list = (
65
- isinstance(x, (list, tuple)) and x and isinstance(x[0], (str, bytes))
66
- )
67
- is_tensor_list = isinstance(x, (list, tuple)) and x and is_tensor_type(x[0])
68
-
69
- if is_string or is_string_list:
70
- # Automatically convert raw strings or string lists to tensors.
71
- # Wrap this input as a single (possibly batched) segment.
72
- x = [tf.convert_to_tensor(x)]
73
- elif is_tensor:
74
- # Automatically wrap a single tensor as a single segment.
75
- x = [x]
76
- elif is_tensor_list:
77
- # Pass lists of tensors though unaltered.
78
- x = x
79
- else:
80
- # Error for all other input.
81
- raise ValueError(
82
- f"Unsupported input for `x`. `x` should be a string, a list of "
83
- "strings, or a list of tensors. If passing multiple segments "
84
- "which should packed together, please convert your inputs to a "
85
- f"list of tensors. Received `x={x}`"
86
- )
87
- return x
88
-
89
-
90
40
  def print_msg(message, line_break=True):
91
41
  """Print the message to absl logging or stdout."""
92
42
  # Copied from core Keras.
@@ -60,6 +60,8 @@ TOKENIZER_ASSET_DIR = "assets/tokenizer"
60
60
  # Config file names.
61
61
  CONFIG_FILE = "config.json"
62
62
  TOKENIZER_CONFIG_FILE = "tokenizer.json"
63
+ AUDIO_CONVERTER_CONFIG_FILE = "audio_converter.json"
64
+ IMAGE_CONVERTER_CONFIG_FILE = "image_converter.json"
63
65
  TASK_CONFIG_FILE = "task.json"
64
66
  PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
65
67
  METADATA_FILE = "metadata.json"
@@ -77,10 +79,10 @@ SAFETENSOR_FILE = "model.safetensors"
77
79
 
78
80
  # Global state for preset registry.
79
81
  BUILTIN_PRESETS = {}
80
- BUILTIN_PRESETS_FOR_CLASS = collections.defaultdict(dict)
82
+ BUILTIN_PRESETS_FOR_BACKBONE = collections.defaultdict(dict)
81
83
 
82
84
 
83
- def register_presets(presets, classes):
85
+ def register_presets(presets, backbone_cls):
84
86
  """Register built-in presets for a set of classes.
85
87
 
86
88
  Note that this is intended only for models and presets shipped in the
@@ -88,18 +90,26 @@ def register_presets(presets, classes):
88
90
  """
89
91
  for preset in presets:
90
92
  BUILTIN_PRESETS[preset] = presets[preset]
91
- for cls in classes:
92
- BUILTIN_PRESETS_FOR_CLASS[cls][preset] = presets[preset]
93
+ BUILTIN_PRESETS_FOR_BACKBONE[backbone_cls][preset] = presets[preset]
93
94
 
94
95
 
95
- def list_presets(cls):
96
+ def builtin_presets(cls):
96
97
  """Find all registered built-in presets for a class."""
97
- return dict(BUILTIN_PRESETS_FOR_CLASS[cls])
98
+ presets = {}
99
+ if cls in BUILTIN_PRESETS_FOR_BACKBONE:
100
+ presets.update(BUILTIN_PRESETS_FOR_BACKBONE[cls])
101
+ backbone_cls = getattr(cls, "backbone_cls", None)
102
+ if backbone_cls:
103
+ presets.update(builtin_presets(backbone_cls))
104
+ for subclass in list_subclasses(cls):
105
+ presets.update(builtin_presets(subclass))
106
+ return presets
98
107
 
99
108
 
100
109
  def list_subclasses(cls):
101
110
  """Find all registered subclasses of a class."""
102
- custom_objects = keras.saving.get_custom_objects().values()
111
+ # Deduplicate the lists, since we have to register object twice for compat.
112
+ custom_objects = set(keras.saving.get_custom_objects().values())
103
113
  subclasses = []
104
114
  for x in custom_objects:
105
115
  if inspect.isclass(x) and x != cls and issubclass(x, cls):
@@ -107,6 +117,26 @@ def list_subclasses(cls):
107
117
  return subclasses
108
118
 
109
119
 
120
+ def find_subclass(preset, cls, backbone_cls):
121
+ """Find a subclass that is compatible with backbone_cls."""
122
+ subclasses = list_subclasses(cls)
123
+ subclasses = filter(lambda x: x.backbone_cls == backbone_cls, subclasses)
124
+ subclasses = list(subclasses)
125
+ if not subclasses:
126
+ raise ValueError(
127
+ f"Unable to find a subclass of {cls.__name__} that is compatible "
128
+ f"with {backbone_cls.__name__} found in preset '{preset}'."
129
+ )
130
+ # If we find multiple subclasses, try to filter to direct subclasses of
131
+ # the class we are trying to instantiate.
132
+ if len(subclasses) > 1:
133
+ directs = list(filter(lambda x: x in cls.__bases__, subclasses))
134
+ if len(directs) > 1:
135
+ subclasses = directs
136
+ # Return the subclass that was registered first (prefer built-in classes).
137
+ return subclasses[0]
138
+
139
+
110
140
  def get_file(preset, path):
111
141
  """Download a preset file in necessary and return the local path."""
112
142
  # TODO: Add tests for FileNotFound exceptions.
@@ -197,7 +227,7 @@ def get_file(preset, path):
197
227
  else:
198
228
  raise ValueError(message)
199
229
  elif os.path.exists(preset):
200
- # Assume a local filepath.
230
+ # Assume a local filepath.pyth
201
231
  local_path = os.path.join(preset, path)
202
232
  if not os.path.exists(local_path):
203
233
  raise FileNotFoundError(
@@ -272,6 +302,7 @@ def recursive_pop(config, key):
272
302
  recursive_pop(value, key)
273
303
 
274
304
 
305
+ # TODO: refactor saving routines into a PresetSaver class?
275
306
  def make_preset_dir(preset):
276
307
  os.makedirs(preset, exist_ok=True)
277
308
 
@@ -314,19 +345,9 @@ def save_metadata(layer, preset):
314
345
  metadata_file.write(json.dumps(metadata, indent=4))
315
346
 
316
347
 
317
- def _validate_tokenizer(preset, allow_incomplete=False):
348
+ def _validate_tokenizer(preset):
318
349
  if not check_file_exists(preset, TOKENIZER_CONFIG_FILE):
319
- if allow_incomplete:
320
- logging.warning(
321
- f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`."
322
- )
323
- return
324
- else:
325
- raise FileNotFoundError(
326
- f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. "
327
- "To upload the model without a tokenizer, "
328
- "set `allow_incomplete=True`."
329
- )
350
+ return
330
351
  config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
331
352
  try:
332
353
  with open(config_path, encoding="utf-8") as config_file:
@@ -377,7 +398,7 @@ def _validate_backbone(preset):
377
398
  )
378
399
 
379
400
 
380
- def get_snake_case(name):
401
+ def to_snake_case(name):
381
402
  name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
382
403
  return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
383
404
 
@@ -386,7 +407,7 @@ def create_model_card(preset):
386
407
  model_card_path = os.path.join(preset, README_FILE)
387
408
  markdown_content = ""
388
409
 
389
- config = load_config(preset, CONFIG_FILE)
410
+ config = load_json(preset, CONFIG_FILE)
390
411
  model_name = (
391
412
  config["class_name"].replace("Backbone", "")
392
413
  if config["class_name"].endswith("Backbone")
@@ -395,7 +416,7 @@ def create_model_card(preset):
395
416
 
396
417
  task_type = None
397
418
  if check_file_exists(preset, TASK_CONFIG_FILE):
398
- task_config = load_config(preset, TASK_CONFIG_FILE)
419
+ task_config = load_json(preset, TASK_CONFIG_FILE)
399
420
  task_type = (
400
421
  task_config["class_name"].replace(model_name, "")
401
422
  if task_config["class_name"].startswith(model_name)
@@ -407,12 +428,12 @@ def create_model_card(preset):
407
428
  markdown_content += "library_name: keras-hub\n"
408
429
  if task_type == "CausalLM":
409
430
  markdown_content += "pipeline_tag: text-generation\n"
410
- elif task_type == "Classifier":
431
+ elif task_type == "TextClassifier":
411
432
  markdown_content += "pipeline_tag: text-classification\n"
412
433
  markdown_content += "---\n"
413
434
 
414
435
  model_link = (
415
- f"https://keras.io/api/keras_hub/models/{get_snake_case(model_name)}"
436
+ f"https://keras.io/api/keras_hub/models/{to_snake_case(model_name)}"
416
437
  )
417
438
  markdown_content += (
418
439
  f"This is a [`{model_name}` model]({model_link}) "
@@ -454,7 +475,6 @@ def delete_model_card(preset):
454
475
  def upload_preset(
455
476
  uri,
456
477
  preset,
457
- allow_incomplete=False,
458
478
  ):
459
479
  """Upload a preset directory to a model hub.
460
480
 
@@ -466,9 +486,6 @@ def upload_preset(
466
486
  `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
467
487
  Face Hub.
468
488
  preset: The path to the local model preset directory.
469
- allow_incomplete: If True, allows the upload of presets without
470
- a tokenizer configuration. Otherwise, a tokenizer
471
- is required.
472
489
  """
473
490
 
474
491
  # Check if preset directory exists.
@@ -476,7 +493,7 @@ def upload_preset(
476
493
  raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")
477
494
 
478
495
  _validate_backbone(preset)
479
- _validate_tokenizer(preset, allow_incomplete)
496
+ _validate_tokenizer(preset)
480
497
 
481
498
  if uri.startswith(KAGGLE_PREFIX):
482
499
  if kagglehub is None:
@@ -533,42 +550,14 @@ def upload_preset(
533
550
  )
534
551
 
535
552
 
536
- def load_config(preset, config_file=CONFIG_FILE):
553
+ def load_json(preset, config_file=CONFIG_FILE):
537
554
  config_path = get_file(preset, config_file)
538
555
  with open(config_path, encoding="utf-8") as config_file:
539
556
  config = json.load(config_file)
540
557
  return config
541
558
 
542
559
 
543
- def check_format(preset):
544
- if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
545
- preset, SAFETENSOR_CONFIG_FILE
546
- ):
547
- # Determine the format by parsing the config file.
548
- config = load_config(preset, HF_CONFIG_FILE)
549
- if "hf://timm" in preset or "architecture" in config:
550
- return "timm"
551
- return "transformers"
552
-
553
- if not check_file_exists(preset, METADATA_FILE):
554
- raise FileNotFoundError(
555
- f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
556
- "or you do not have access to it. This file is required to load a Keras model "
557
- "preset. Please verify that the model you are trying to load is a Keras model."
558
- )
559
- metadata = load_config(preset, METADATA_FILE)
560
- if "keras_version" not in metadata:
561
- raise ValueError(
562
- f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
563
- "Please verify that the model you are trying to load is a Keras model."
564
- )
565
- return "keras"
566
-
567
-
568
- def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
569
- kwargs = kwargs or {}
570
- config = load_config(preset, config_file)
571
-
560
+ def load_serialized_object(config, **kwargs):
572
561
  # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
573
562
  # Ensure that `dtype` is properly configured.
574
563
  dtype = kwargs.pop("dtype", None)
@@ -578,15 +567,18 @@ def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
578
567
  return keras.saving.deserialize_keras_object(config)
579
568
 
580
569
 
581
- def check_config_class(
582
- preset,
583
- config_file=CONFIG_FILE,
584
- ):
570
+ def check_config_class(config):
585
571
  """Validate a preset is being loaded on the correct class."""
586
- config_path = get_file(preset, config_file)
587
- with open(config_path, encoding="utf-8") as config_file:
588
- config = json.load(config_file)
589
- return keras.saving.get_registered_object(config["registered_name"])
572
+ registered_name = config["registered_name"]
573
+ cls = keras.saving.get_registered_object(registered_name)
574
+ if cls is None:
575
+ raise ValueError(
576
+ f"Attempting to load class {registered_name} with "
577
+ "`from_preset()`, but there is no class registered with Keras "
578
+ f"for {registered_name}. Make sure to register any custom "
579
+ "classes with `register_keras_serializable()`."
580
+ )
581
+ return cls
590
582
 
591
583
 
592
584
  def jax_memory_cleanup(layer):
@@ -619,3 +611,173 @@ def set_dtype_in_config(config, dtype=None):
619
611
  for k in policy_map_config["policy_map"].keys():
620
612
  policy_map_config["policy_map"][k]["config"]["source_name"] = dtype
621
613
  return config
614
+
615
+
616
+ def get_preset_loader(preset):
617
+ if not check_file_exists(preset, CONFIG_FILE):
618
+ raise ValueError(
619
+ f"Preset {preset} has no {CONFIG_FILE}. Make sure the URI or "
620
+ "directory you are trying to load is a valid KerasHub preset and "
621
+ "and that you have permissions to read/download from this location."
622
+ )
623
+ # We currently assume all formats we support have a `config.json`, this is
624
+ # true, for Keras, Transformers, and timm. We infer the on disk format by
625
+ # inspecting the `config.json` file.
626
+ config = load_json(preset, CONFIG_FILE)
627
+ if "registered_name" in config:
628
+ # If we see registered_name, we assume a serialized Keras object.
629
+ return KerasPresetLoader(preset, config)
630
+ elif "model_type" in config:
631
+ # Avoid circular import.
632
+ from keras_hub.src.utils.transformers.preset_loader import (
633
+ TransformersPresetLoader,
634
+ )
635
+
636
+ # If we see model_type, we assume a Transformers style config.
637
+ return TransformersPresetLoader(preset, config)
638
+ elif "architecture" in config:
639
+ # Avoid circular import.
640
+ from keras_hub.src.utils.timm.preset_loader import TimmPresetLoader
641
+
642
+ # If we see "architecture", we assume a timm config. We could make this
643
+ # more robust later on if we need to.
644
+ return TimmPresetLoader(preset, config)
645
+
646
+ else:
647
+ contents = json.dumps(config, indent=4)
648
+ raise ValueError(
649
+ f"Unrecognized format for {CONFIG_FILE} in {preset}. "
650
+ "Create a preset with the `save_to_preset` utility on KerasHub "
651
+ f"models. Contents of {CONFIG_FILE}:\n{contents}"
652
+ )
653
+
654
+
655
+ class PresetLoader:
656
+ def __init__(self, preset, config):
657
+ self.config = config
658
+ self.preset = preset
659
+
660
+ def check_backbone_class(self):
661
+ """Infer the backbone architecture."""
662
+ raise NotImplementedError
663
+
664
+ def load_backbone(self, cls, load_weights, **kwargs):
665
+ """Load the backbone model from the preset."""
666
+ raise NotImplementedError
667
+
668
+ def load_tokenizer(self, cls, **kwargs):
669
+ """Load a tokenizer layer from the preset."""
670
+ raise NotImplementedError
671
+
672
+ def load_audio_converter(self, cls, **kwargs):
673
+ """Load an audio converter layer from the preset."""
674
+ raise NotImplementedError
675
+
676
+ def load_image_converter(self, cls, **kwargs):
677
+ """Load an image converter layer from the preset."""
678
+ raise NotImplementedError
679
+
680
+ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
681
+ """Load a task model from the preset.
682
+
683
+ By default, we create a task from a backbone and preprocessor with
684
+ default arguments. This means
685
+ """
686
+ if "backbone" not in kwargs:
687
+ backbone_class = cls.backbone_cls
688
+ # Forward dtype to backbone.
689
+ backbone_kwargs = {"dtype": kwargs.pop("dtype", None)}
690
+ kwargs["backbone"] = self.load_backbone(
691
+ backbone_class, load_weights, **backbone_kwargs
692
+ )
693
+ if "preprocessor" not in kwargs and cls.preprocessor_cls:
694
+ kwargs["preprocessor"] = self.load_preprocessor(
695
+ cls.preprocessor_cls,
696
+ )
697
+ return cls(**kwargs)
698
+
699
+ def load_preprocessor(self, cls, **kwargs):
700
+ """Load a prepocessor layer from the preset.
701
+
702
+ By default, we create a preprocessor from a tokenizer with default
703
+ arguments. This allow us to support transformers checkpoints by
704
+ only converting the backbone and tokenizer.
705
+ """
706
+ if "tokenizer" not in kwargs and cls.tokenizer_cls:
707
+ kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls)
708
+ if "audio_converter" not in kwargs and cls.audio_converter_cls:
709
+ kwargs["audio_converter"] = self.load_audio_converter(
710
+ cls.audio_converter_cls
711
+ )
712
+ if "image_converter" not in kwargs and cls.image_converter_cls:
713
+ kwargs["image_converter"] = self.load_image_converter(
714
+ cls.image_converter_cls
715
+ )
716
+ return cls(**kwargs)
717
+
718
+
719
+ class KerasPresetLoader(PresetLoader):
720
+ def check_backbone_class(self):
721
+ return check_config_class(self.config)
722
+
723
+ def load_backbone(self, cls, load_weights, **kwargs):
724
+ backbone = load_serialized_object(self.config, **kwargs)
725
+ if load_weights:
726
+ jax_memory_cleanup(backbone)
727
+ backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
728
+ return backbone
729
+
730
+ def load_tokenizer(self, cls, **kwargs):
731
+ tokenizer_config = load_json(self.preset, TOKENIZER_CONFIG_FILE)
732
+ tokenizer = load_serialized_object(tokenizer_config, **kwargs)
733
+ tokenizer.load_preset_assets(self.preset)
734
+ return tokenizer
735
+
736
+ def load_audio_converter(self, cls, **kwargs):
737
+ converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
738
+ return load_serialized_object(converter_config, **kwargs)
739
+
740
+ def load_image_converter(self, cls, **kwargs):
741
+ converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
742
+ return load_serialized_object(converter_config, **kwargs)
743
+
744
+ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
745
+ # If there is no `task.json` or it's for the wrong class delegate to the
746
+ # super class loader.
747
+ if not check_file_exists(self.preset, TASK_CONFIG_FILE):
748
+ return super().load_task(
749
+ cls, load_weights, load_task_weights, **kwargs
750
+ )
751
+ task_config = load_json(self.preset, TASK_CONFIG_FILE)
752
+ if not issubclass(check_config_class(task_config), cls):
753
+ return super().load_task(
754
+ cls, load_weights, load_task_weights, **kwargs
755
+ )
756
+ # We found a `task.json` with a complete config for our class.
757
+ task = load_serialized_object(task_config, **kwargs)
758
+ if task.preprocessor and task.preprocessor.tokenizer:
759
+ task.preprocessor.tokenizer.load_preset_assets(self.preset)
760
+ if load_weights:
761
+ has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
762
+ if has_task_weights and load_task_weights:
763
+ jax_memory_cleanup(task)
764
+ task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
765
+ task.load_task_weights(task_weights)
766
+ else:
767
+ jax_memory_cleanup(task.backbone)
768
+ backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
769
+ task.backbone.load_weights(backbone_weights)
770
+ return task
771
+
772
+ def load_preprocessor(self, cls, **kwargs):
773
+ # If there is no `preprocessing.json` or it's for the wrong class,
774
+ # delegate to the super class loader.
775
+ if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
776
+ return super().load_preprocessor(cls, **kwargs)
777
+ preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
778
+ if not issubclass(check_config_class(preprocessor_json), cls):
779
+ return super().load_preprocessor(cls, **kwargs)
780
+ # We found a `preprocessing.json` with a complete config for our class.
781
+ preprocessor = load_serialized_object(preprocessor_json, **kwargs)
782
+ preprocessor.tokenizer.load_preset_assets(self.preset)
783
+ return preprocessor