keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev20240915160609__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186) hide show
  1. keras_hub/api/__init__.py +1 -0
  2. keras_hub/api/bounding_box/__init__.py +36 -0
  3. keras_hub/api/layers/__init__.py +14 -0
  4. keras_hub/api/models/__init__.py +97 -48
  5. keras_hub/api/tokenizers/__init__.py +30 -0
  6. keras_hub/src/bounding_box/__init__.py +13 -0
  7. keras_hub/src/bounding_box/converters.py +529 -0
  8. keras_hub/src/bounding_box/formats.py +162 -0
  9. keras_hub/src/bounding_box/iou.py +263 -0
  10. keras_hub/src/bounding_box/to_dense.py +95 -0
  11. keras_hub/src/bounding_box/to_ragged.py +99 -0
  12. keras_hub/src/bounding_box/utils.py +194 -0
  13. keras_hub/src/bounding_box/validate_format.py +99 -0
  14. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  15. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  16. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  17. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  18. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  19. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  20. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  21. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  22. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  23. keras_hub/src/models/albert/__init__.py +1 -2
  24. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  25. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  26. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  27. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  28. keras_hub/src/models/backbone.py +12 -34
  29. keras_hub/src/models/bart/__init__.py +1 -2
  30. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  31. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  32. keras_hub/src/models/bert/__init__.py +1 -5
  33. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  34. keras_hub/src/models/bert/bert_presets.py +1 -4
  35. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  36. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  37. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  38. keras_hub/src/models/bloom/__init__.py +1 -2
  39. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  40. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  41. keras_hub/src/models/causal_lm.py +10 -29
  42. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  43. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  44. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  45. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  46. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  47. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  48. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  49. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  50. keras_hub/src/models/distil_bert/__init__.py +1 -4
  51. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  52. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  53. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  54. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  55. keras_hub/src/models/efficientnet/__init__.py +13 -0
  56. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  57. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  58. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  59. keras_hub/src/models/electra/__init__.py +1 -2
  60. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  61. keras_hub/src/models/f_net/__init__.py +1 -2
  62. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  63. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  64. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  65. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  66. keras_hub/src/models/falcon/__init__.py +1 -2
  67. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  68. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  69. keras_hub/src/models/gemma/__init__.py +1 -2
  70. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  71. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  72. keras_hub/src/models/gpt2/__init__.py +1 -2
  73. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  74. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  75. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  76. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  77. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  78. keras_hub/src/models/image_classifier.py +0 -5
  79. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  80. keras_hub/src/models/llama/__init__.py +1 -2
  81. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  82. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  83. keras_hub/src/models/llama3/__init__.py +1 -2
  84. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  85. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  86. keras_hub/src/models/masked_lm.py +0 -2
  87. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  88. keras_hub/src/models/mistral/__init__.py +1 -2
  89. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  90. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  91. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  92. keras_hub/src/models/mobilenet/__init__.py +13 -0
  93. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  94. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  95. keras_hub/src/models/opt/__init__.py +1 -2
  96. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  97. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  98. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  99. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  100. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  101. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  102. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  103. keras_hub/src/models/phi3/__init__.py +1 -2
  104. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  105. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  106. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  107. keras_hub/src/models/preprocessor.py +72 -83
  108. keras_hub/src/models/resnet/__init__.py +6 -0
  109. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  110. keras_hub/src/models/resnet/resnet_image_classifier.py +24 -3
  111. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  112. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  113. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  114. keras_hub/src/models/roberta/__init__.py +1 -2
  115. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  116. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  117. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  118. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  119. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  120. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  121. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  122. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  123. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  124. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  125. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  126. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  127. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  128. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  129. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  130. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  131. keras_hub/src/models/t5/__init__.py +1 -2
  132. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  133. keras_hub/src/models/task.py +71 -116
  134. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  135. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  136. keras_hub/src/models/whisper/__init__.py +1 -2
  137. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  138. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  139. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  140. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  141. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  142. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  143. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  144. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  145. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  146. keras_hub/src/tests/test_case.py +38 -0
  147. keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
  148. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  149. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
  150. keras_hub/src/tokenizers/tokenizer.py +67 -32
  151. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  152. keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
  153. keras_hub/src/utils/keras_utils.py +0 -50
  154. keras_hub/src/utils/preset_utils.py +220 -67
  155. keras_hub/src/utils/tensor_utils.py +187 -69
  156. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  157. keras_hub/src/utils/timm/preset_loader.py +66 -0
  158. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  159. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  160. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  161. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  162. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  163. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  164. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  165. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  166. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  167. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  168. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  169. keras_hub/src/version_utils.py +1 -1
  170. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/METADATA +1 -2
  171. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/RECORD +173 -143
  172. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/WHEEL +1 -1
  173. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  174. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  175. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  176. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  177. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  178. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  179. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  180. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  181. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  182. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  183. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  184. keras_hub/src/utils/timm/convert.py +0 -37
  185. keras_hub/src/utils/transformers/convert.py +0 -101
  186. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev20240915160609.dist-info}/top_level.txt +0 -0
@@ -60,6 +60,8 @@ TOKENIZER_ASSET_DIR = "assets/tokenizer"
60
60
  # Config file names.
61
61
  CONFIG_FILE = "config.json"
62
62
  TOKENIZER_CONFIG_FILE = "tokenizer.json"
63
+ AUDIO_CONVERTER_CONFIG_FILE = "audio_converter.json"
64
+ IMAGE_CONVERTER_CONFIG_FILE = "image_converter.json"
63
65
  TASK_CONFIG_FILE = "task.json"
64
66
  PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
65
67
  METADATA_FILE = "metadata.json"
@@ -77,10 +79,10 @@ SAFETENSOR_FILE = "model.safetensors"
77
79
 
78
80
  # Global state for preset registry.
79
81
  BUILTIN_PRESETS = {}
80
- BUILTIN_PRESETS_FOR_CLASS = collections.defaultdict(dict)
82
+ BUILTIN_PRESETS_FOR_BACKBONE = collections.defaultdict(dict)
81
83
 
82
84
 
83
- def register_presets(presets, classes):
85
+ def register_presets(presets, backbone_cls):
84
86
  """Register built-in presets for a backbone class.
85
87
 
86
88
  Note that this is intended only for models and presets shipped in the
@@ -88,18 +90,26 @@ def register_presets(presets, classes):
88
90
  """
89
91
  for preset in presets:
90
92
  BUILTIN_PRESETS[preset] = presets[preset]
91
- for cls in classes:
92
- BUILTIN_PRESETS_FOR_CLASS[cls][preset] = presets[preset]
93
+ BUILTIN_PRESETS_FOR_BACKBONE[backbone_cls][preset] = presets[preset]
93
94
 
94
95
 
95
- def list_presets(cls):
96
+ def builtin_presets(cls):
96
97
  """Find all registered built-in presets for a class."""
97
- return dict(BUILTIN_PRESETS_FOR_CLASS[cls])
98
+ presets = {}
99
+ if cls in BUILTIN_PRESETS_FOR_BACKBONE:
100
+ presets.update(BUILTIN_PRESETS_FOR_BACKBONE[cls])
101
+ backbone_cls = getattr(cls, "backbone_cls", None)
102
+ if backbone_cls:
103
+ presets.update(builtin_presets(backbone_cls))
104
+ for subclass in list_subclasses(cls):
105
+ presets.update(builtin_presets(subclass))
106
+ return presets
98
107
 
99
108
 
100
109
  def list_subclasses(cls):
101
110
  """Find all registered subclasses of a class."""
102
- custom_objects = keras.saving.get_custom_objects().values()
111
+ # Deduplicate the list, since we have to register objects twice for compat.
112
+ custom_objects = set(keras.saving.get_custom_objects().values())
103
113
  subclasses = []
104
114
  for x in custom_objects:
105
115
  if inspect.isclass(x) and x != cls and issubclass(x, cls):
@@ -107,6 +117,26 @@ def list_subclasses(cls):
107
117
  return subclasses
108
118
 
109
119
 
120
+ def find_subclass(preset, cls, backbone_cls):
121
+ """Find a subclass that is compatible with backbone_cls."""
122
+ subclasses = list_subclasses(cls)
123
+ subclasses = filter(lambda x: x.backbone_cls == backbone_cls, subclasses)
124
+ subclasses = list(subclasses)
125
+ if not subclasses:
126
+ raise ValueError(
127
+ f"Unable to find a subclass of {cls.__name__} that is compatible "
128
+ f"with {backbone_cls.__name__} found in preset '{preset}'."
129
+ )
130
+ # If we find multiple subclasses, try to filter to direct subclasses of
131
+ # the class we are trying to instantiate.
132
+ if len(subclasses) > 1:
133
+ directs = list(filter(lambda x: x in cls.__bases__, subclasses))
134
+ if len(directs) > 1:
135
+ subclasses = directs
136
+ # Return the subclass that was registered first (prefer built-in classes).
137
+ return subclasses[0]
138
+
139
+
110
140
  def get_file(preset, path):
111
141
  """Download a preset file if necessary and return the local path."""
112
142
  # TODO: Add tests for FileNotFound exceptions.
@@ -197,7 +227,7 @@ def get_file(preset, path):
197
227
  else:
198
228
  raise ValueError(message)
199
229
  elif os.path.exists(preset):
200
- # Assume a local filepath.
230
+ # Assume a local filepath.
201
231
  local_path = os.path.join(preset, path)
202
232
  if not os.path.exists(local_path):
203
233
  raise FileNotFoundError(
@@ -272,6 +302,7 @@ def recursive_pop(config, key):
272
302
  recursive_pop(value, key)
273
303
 
274
304
 
305
+ # TODO: refactor saving routines into a PresetSaver class?
275
306
  def make_preset_dir(preset):
276
307
  os.makedirs(preset, exist_ok=True)
277
308
 
@@ -314,19 +345,9 @@ def save_metadata(layer, preset):
314
345
  metadata_file.write(json.dumps(metadata, indent=4))
315
346
 
316
347
 
317
- def _validate_tokenizer(preset, allow_incomplete=False):
348
+ def _validate_tokenizer(preset):
318
349
  if not check_file_exists(preset, TOKENIZER_CONFIG_FILE):
319
- if allow_incomplete:
320
- logging.warning(
321
- f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`."
322
- )
323
- return
324
- else:
325
- raise FileNotFoundError(
326
- f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. "
327
- "To upload the model without a tokenizer, "
328
- "set `allow_incomplete=True`."
329
- )
350
+ return
330
351
  config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
331
352
  try:
332
353
  with open(config_path, encoding="utf-8") as config_file:
@@ -377,7 +398,7 @@ def _validate_backbone(preset):
377
398
  )
378
399
 
379
400
 
380
- def get_snake_case(name):
401
+ def to_snake_case(name):
381
402
  name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
382
403
  return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
383
404
 
@@ -386,7 +407,7 @@ def create_model_card(preset):
386
407
  model_card_path = os.path.join(preset, README_FILE)
387
408
  markdown_content = ""
388
409
 
389
- config = load_config(preset, CONFIG_FILE)
410
+ config = load_json(preset, CONFIG_FILE)
390
411
  model_name = (
391
412
  config["class_name"].replace("Backbone", "")
392
413
  if config["class_name"].endswith("Backbone")
@@ -395,7 +416,7 @@ def create_model_card(preset):
395
416
 
396
417
  task_type = None
397
418
  if check_file_exists(preset, TASK_CONFIG_FILE):
398
- task_config = load_config(preset, TASK_CONFIG_FILE)
419
+ task_config = load_json(preset, TASK_CONFIG_FILE)
399
420
  task_type = (
400
421
  task_config["class_name"].replace(model_name, "")
401
422
  if task_config["class_name"].startswith(model_name)
@@ -407,12 +428,12 @@ def create_model_card(preset):
407
428
  markdown_content += "library_name: keras-hub\n"
408
429
  if task_type == "CausalLM":
409
430
  markdown_content += "pipeline_tag: text-generation\n"
410
- elif task_type == "Classifier":
431
+ elif task_type == "TextClassifier":
411
432
  markdown_content += "pipeline_tag: text-classification\n"
412
433
  markdown_content += "---\n"
413
434
 
414
435
  model_link = (
415
- f"https://keras.io/api/keras_hub/models/{get_snake_case(model_name)}"
436
+ f"https://keras.io/api/keras_hub/models/{to_snake_case(model_name)}"
416
437
  )
417
438
  markdown_content += (
418
439
  f"This is a [`{model_name}` model]({model_link}) "
@@ -454,7 +475,6 @@ def delete_model_card(preset):
454
475
  def upload_preset(
455
476
  uri,
456
477
  preset,
457
- allow_incomplete=False,
458
478
  ):
459
479
  """Upload a preset directory to a model hub.
460
480
 
@@ -466,9 +486,6 @@ def upload_preset(
466
486
  `hf://[<HF_USERNAME>/]<MODEL>` will be uploaded to the Hugging
467
487
  Face Hub.
468
488
  preset: The path to the local model preset directory.
469
- allow_incomplete: If True, allows the upload of presets without
470
- a tokenizer configuration. Otherwise, a tokenizer
471
- is required.
472
489
  """
473
490
 
474
491
  # Check if preset directory exists.
@@ -476,7 +493,7 @@ def upload_preset(
476
493
  raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")
477
494
 
478
495
  _validate_backbone(preset)
479
- _validate_tokenizer(preset, allow_incomplete)
496
+ _validate_tokenizer(preset)
480
497
 
481
498
  if uri.startswith(KAGGLE_PREFIX):
482
499
  if kagglehub is None:
@@ -533,42 +550,14 @@ def upload_preset(
533
550
  )
534
551
 
535
552
 
536
- def load_config(preset, config_file=CONFIG_FILE):
553
+ def load_json(preset, config_file=CONFIG_FILE):
537
554
  config_path = get_file(preset, config_file)
538
555
  with open(config_path, encoding="utf-8") as config_file:
539
556
  config = json.load(config_file)
540
557
  return config
541
558
 
542
559
 
543
- def check_format(preset):
544
- if check_file_exists(preset, SAFETENSOR_FILE) or check_file_exists(
545
- preset, SAFETENSOR_CONFIG_FILE
546
- ):
547
- # Determine the format by parsing the config file.
548
- config = load_config(preset, HF_CONFIG_FILE)
549
- if "hf://timm" in preset or "architecture" in config:
550
- return "timm"
551
- return "transformers"
552
-
553
- if not check_file_exists(preset, METADATA_FILE):
554
- raise FileNotFoundError(
555
- f"The preset directory `{preset}` doesn't have a file named `{METADATA_FILE}`, "
556
- "or you do not have access to it. This file is required to load a Keras model "
557
- "preset. Please verify that the model you are trying to load is a Keras model."
558
- )
559
- metadata = load_config(preset, METADATA_FILE)
560
- if "keras_version" not in metadata:
561
- raise ValueError(
562
- f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't have `keras_version`. "
563
- "Please verify that the model you are trying to load is a Keras model."
564
- )
565
- return "keras"
566
-
567
-
568
- def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
569
- kwargs = kwargs or {}
570
- config = load_config(preset, config_file)
571
-
560
+ def load_serialized_object(config, **kwargs):
572
561
  # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
573
562
  # Ensure that `dtype` is properly configured.
574
563
  dtype = kwargs.pop("dtype", None)
@@ -578,14 +567,8 @@ def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
578
567
  return keras.saving.deserialize_keras_object(config)
579
568
 
580
569
 
581
- def check_config_class(
582
- preset,
583
- config_file=CONFIG_FILE,
584
- ):
570
+ def check_config_class(config):
585
571
  """Validate a preset is being loaded on the correct class."""
586
- config_path = get_file(preset, config_file)
587
- with open(config_path, encoding="utf-8") as config_file:
588
- config = json.load(config_file)
589
572
  return keras.saving.get_registered_object(config["registered_name"])
590
573
 
591
574
 
@@ -619,3 +602,173 @@ def set_dtype_in_config(config, dtype=None):
619
602
  for k in policy_map_config["policy_map"].keys():
620
603
  policy_map_config["policy_map"][k]["config"]["source_name"] = dtype
621
604
  return config
605
+
606
+
607
+ def get_preset_loader(preset):
608
+ if not check_file_exists(preset, CONFIG_FILE):
609
+ raise ValueError(
610
+ f"Preset {preset} has no {CONFIG_FILE}. Make sure the URI or "
611
+ "directory you are trying to load is a valid KerasHub preset and "
612
+ "and that you have permissions to read/download from this location."
613
+ )
614
+ # We currently assume all formats we support have a `config.json`; this is
615
+ # true for Keras, Transformers, and timm. We infer the on-disk format by
616
+ # inspecting the `config.json` file.
617
+ config = load_json(preset, CONFIG_FILE)
618
+ if "registered_name" in config:
619
+ # If we see registered_name, we assume a serialized Keras object.
620
+ return KerasPresetLoader(preset, config)
621
+ elif "model_type" in config:
622
+ # Avoid circular import.
623
+ from keras_hub.src.utils.transformers.preset_loader import (
624
+ TransformersPresetLoader,
625
+ )
626
+
627
+ # If we see model_type, we assume a Transformers style config.
628
+ return TransformersPresetLoader(preset, config)
629
+ elif "architecture" in config:
630
+ # Avoid circular import.
631
+ from keras_hub.src.utils.timm.preset_loader import TimmPresetLoader
632
+
633
+ # If we see "architecture", we assume a timm config. We could make this
634
+ # more robust later on if we need to.
635
+ return TimmPresetLoader(preset, config)
636
+
637
+ else:
638
+ contents = json.dumps(config, indent=4)
639
+ raise ValueError(
640
+ f"Unrecognized format for {CONFIG_FILE} in {preset}. "
641
+ "Create a preset with the `save_to_preset` utility on KerasHub "
642
+ f"models. Contents of {CONFIG_FILE}:\n{contents}"
643
+ )
644
+
645
+
646
+ class PresetLoader:
647
+ def __init__(self, preset, config):
648
+ self.config = config
649
+ self.preset = preset
650
+
651
+ def check_backbone_class(self):
652
+ """Infer the backbone architecture."""
653
+ raise NotImplementedError
654
+
655
+ def load_backbone(self, cls, load_weights, **kwargs):
656
+ """Load the backbone model from the preset."""
657
+ raise NotImplementedError
658
+
659
+ def load_tokenizer(self, cls, **kwargs):
660
+ """Load a tokenizer layer from the preset."""
661
+ raise NotImplementedError
662
+
663
+ def load_audio_converter(self, cls, **kwargs):
664
+ """Load an audio converter layer from the preset."""
665
+ raise NotImplementedError
666
+
667
+ def load_image_converter(self, cls, **kwargs):
668
+ """Load an image converter layer from the preset."""
669
+ raise NotImplementedError
670
+
671
+ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
672
+ """Load a task model from the preset.
673
+
674
+ By default, we create a task from a backbone and preprocessor with
675
+ default arguments. This means `load_task_weights` is ignored here.
676
+ """
677
+ if "backbone" not in kwargs:
678
+ backbone_class = cls.backbone_cls
679
+ # Forward dtype to backbone.
680
+ backbone_kwargs = {"dtype": kwargs.pop("dtype", None)}
681
+ kwargs["backbone"] = self.load_backbone(
682
+ backbone_class, load_weights, **backbone_kwargs
683
+ )
684
+ if "preprocessor" not in kwargs and cls.preprocessor_cls:
685
+ kwargs["preprocessor"] = self.load_preprocessor(
686
+ cls.preprocessor_cls,
687
+ )
688
+ return cls(**kwargs)
689
+
690
+ def load_preprocessor(self, cls, **kwargs):
691
+ """Load a preprocessor layer from the preset.
692
+
693
+ By default, we create a preprocessor from a tokenizer with default
694
+ arguments. This allows us to support transformers checkpoints by
695
+ only converting the backbone and tokenizer.
696
+ """
697
+ if "tokenizer" not in kwargs and cls.tokenizer_cls:
698
+ kwargs["tokenizer"] = self.load_tokenizer(cls.tokenizer_cls)
699
+ if "audio_converter" not in kwargs and cls.audio_converter_cls:
700
+ kwargs["audio_converter"] = self.load_audio_converter(
701
+ cls.audio_converter_cls
702
+ )
703
+ if "image_converter" not in kwargs and cls.image_converter_cls:
704
+ kwargs["image_converter"] = self.load_image_converter(
705
+ cls.image_converter_cls
706
+ )
707
+ return cls(**kwargs)
708
+
709
+
710
+ class KerasPresetLoader(PresetLoader):
711
+ def check_backbone_class(self):
712
+ return check_config_class(self.config)
713
+
714
+ def load_backbone(self, cls, load_weights, **kwargs):
715
+ backbone = load_serialized_object(self.config, **kwargs)
716
+ if load_weights:
717
+ jax_memory_cleanup(backbone)
718
+ backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
719
+ return backbone
720
+
721
+ def load_tokenizer(self, cls, **kwargs):
722
+ tokenizer_config = load_json(self.preset, TOKENIZER_CONFIG_FILE)
723
+ tokenizer = load_serialized_object(tokenizer_config, **kwargs)
724
+ tokenizer.load_preset_assets(self.preset)
725
+ return tokenizer
726
+
727
+ def load_audio_converter(self, cls, **kwargs):
728
+ converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
729
+ return load_serialized_object(converter_config, **kwargs)
730
+
731
+ def load_image_converter(self, cls, **kwargs):
732
+ converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
733
+ return load_serialized_object(converter_config, **kwargs)
734
+
735
+ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
736
+ # If there is no `task.json` or it's for the wrong class delegate to the
737
+ # super class loader.
738
+ if not check_file_exists(self.preset, TASK_CONFIG_FILE):
739
+ return super().load_task(
740
+ cls, load_weights, load_task_weights, **kwargs
741
+ )
742
+ task_config = load_json(self.preset, TASK_CONFIG_FILE)
743
+ if not issubclass(check_config_class(task_config), cls):
744
+ return super().load_task(
745
+ cls, load_weights, load_task_weights, **kwargs
746
+ )
747
+ # We found a `task.json` with a complete config for our class.
748
+ task = load_serialized_object(task_config, **kwargs)
749
+ if task.preprocessor and task.preprocessor.tokenizer:
750
+ task.preprocessor.tokenizer.load_preset_assets(self.preset)
751
+ if load_weights:
752
+ has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
753
+ if has_task_weights and load_task_weights:
754
+ jax_memory_cleanup(task)
755
+ task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
756
+ task.load_task_weights(task_weights)
757
+ else:
758
+ jax_memory_cleanup(task.backbone)
759
+ backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
760
+ task.backbone.load_weights(backbone_weights)
761
+ return task
762
+
763
+ def load_preprocessor(self, cls, **kwargs):
764
+ # If there is no `preprocessor.json` or it's for the wrong class,
765
+ # delegate to the super class loader.
766
+ if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
767
+ return super().load_preprocessor(cls, **kwargs)
768
+ preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
769
+ if not issubclass(check_config_class(preprocessor_json), cls):
770
+ return super().load_preprocessor(cls, **kwargs)
771
+ # We found a `preprocessor.json` with a complete config for our class.
772
+ preprocessor = load_serialized_object(preprocessor_json, **kwargs)
773
+ preprocessor.tokenizer.load_preset_assets(self.preset)
774
+ return preprocessor