keras-hub-nightly 0.19.0.dev202412120352__py3-none-any.whl → 0.19.0.dev202412140350__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. keras_hub/api/layers/__init__.py +1 -0
  2. keras_hub/api/models/__init__.py +11 -6
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/converters.py +2 -2
  5. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  6. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  7. keras_hub/src/layers/modeling/rms_normalization.py +8 -6
  8. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  9. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  10. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  11. keras_hub/src/layers/modeling/transformer_encoder.py +3 -1
  12. keras_hub/src/metrics/bleu.py +1 -1
  13. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  14. keras_hub/src/models/bart/bart_backbone.py +4 -4
  15. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  16. keras_hub/src/models/bert/bert_presets.py +4 -2
  17. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  18. keras_hub/src/models/causal_lm.py +19 -15
  19. keras_hub/src/models/clip/clip_vision_embedding.py +1 -1
  20. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  21. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  22. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  23. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  24. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  25. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  26. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +17 -13
  27. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -3
  28. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +1 -1
  29. keras_hub/src/models/densenet/densenet_backbone.py +3 -1
  30. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -1
  31. keras_hub/src/models/densenet/densenet_presets.py +6 -6
  32. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  33. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  34. keras_hub/src/models/distil_bert/distil_bert_presets.py +2 -1
  35. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  36. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  37. keras_hub/src/models/efficientnet/cba.py +1 -1
  38. keras_hub/src/models/efficientnet/efficientnet_backbone.py +20 -8
  39. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +1 -1
  40. keras_hub/src/models/efficientnet/efficientnet_presets.py +12 -11
  41. keras_hub/src/models/efficientnet/fusedmbconv.py +3 -5
  42. keras_hub/src/models/efficientnet/mbconv.py +1 -1
  43. keras_hub/src/models/electra/electra_backbone.py +2 -2
  44. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  45. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  46. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  47. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  48. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  49. keras_hub/src/models/flux/flux_layers.py +46 -44
  50. keras_hub/src/models/flux/flux_maths.py +24 -17
  51. keras_hub/src/models/flux/flux_model.py +24 -19
  52. keras_hub/src/models/flux/flux_presets.py +2 -1
  53. keras_hub/src/models/flux/flux_text_to_image.py +7 -3
  54. keras_hub/src/models/gemma/gemma_backbone.py +27 -20
  55. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  56. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  57. keras_hub/src/models/gemma/gemma_presets.py +9 -3
  58. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  59. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  60. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  61. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  62. keras_hub/src/models/image_classifier_preprocessor.py +4 -1
  63. keras_hub/src/models/image_object_detector.py +2 -2
  64. keras_hub/src/models/image_object_detector_preprocessor.py +4 -4
  65. keras_hub/src/models/image_segmenter_preprocessor.py +2 -2
  66. keras_hub/src/models/llama/llama_backbone.py +34 -26
  67. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  68. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  69. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  70. keras_hub/src/models/mistral/mistral_causal_lm.py +3 -3
  71. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  72. keras_hub/src/models/mit/mit_backbone.py +4 -3
  73. keras_hub/src/models/mit/mit_layers.py +2 -1
  74. keras_hub/src/models/mobilenet/mobilenet_backbone.py +7 -7
  75. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  76. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +5 -3
  77. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +2 -2
  78. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  79. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  80. keras_hub/src/models/preprocessor.py +2 -2
  81. keras_hub/src/models/retinanet/feature_pyramid.py +3 -2
  82. keras_hub/src/models/retinanet/prediction_head.py +2 -2
  83. keras_hub/src/models/retinanet/retinanet_backbone.py +2 -2
  84. keras_hub/src/models/retinanet/retinanet_image_converter.py +1 -1
  85. keras_hub/src/models/retinanet/retinanet_object_detector.py +5 -6
  86. keras_hub/src/models/retinanet/retinanet_presets.py +2 -1
  87. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  88. keras_hub/src/models/roberta/roberta_presets.py +4 -2
  89. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  90. keras_hub/src/models/sam/sam_backbone.py +2 -2
  91. keras_hub/src/models/sam/sam_image_segmenter.py +6 -5
  92. keras_hub/src/models/sam/sam_layers.py +5 -3
  93. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  94. keras_hub/src/models/sam/sam_transformer.py +5 -4
  95. keras_hub/src/models/segformer/segformer_backbone.py +18 -14
  96. keras_hub/src/models/segformer/segformer_image_segmenter.py +51 -38
  97. keras_hub/src/models/segformer/segformer_presets.py +24 -12
  98. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  99. keras_hub/src/models/stable_diffusion_3/mmdit.py +20 -1
  100. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +1 -1
  101. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +13 -6
  102. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +2 -2
  103. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +7 -3
  104. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  105. keras_hub/src/models/task.py +4 -2
  106. keras_hub/src/models/text_classifier.py +2 -2
  107. keras_hub/src/models/text_to_image.py +5 -1
  108. keras_hub/src/models/vae/vae_layers.py +0 -1
  109. keras_hub/src/models/vit/__init__.py +5 -0
  110. keras_hub/src/models/vit/vit_backbone.py +152 -0
  111. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  112. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  113. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  114. keras_hub/src/models/vit/vit_layers.py +391 -0
  115. keras_hub/src/models/vit/vit_presets.py +49 -0
  116. keras_hub/src/models/vit_det/vit_det_backbone.py +4 -2
  117. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  118. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -3
  119. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  120. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  121. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  122. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  123. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  124. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  125. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  126. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  127. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  128. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  129. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  130. keras_hub/src/samplers/sampler.py +2 -1
  131. keras_hub/src/tests/test_case.py +2 -2
  132. keras_hub/src/tokenizers/byte_pair_tokenizer.py +2 -2
  133. keras_hub/src/tokenizers/byte_tokenizer.py +2 -8
  134. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  135. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +7 -12
  136. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +8 -5
  137. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +7 -3
  138. keras_hub/src/utils/preset_utils.py +25 -18
  139. keras_hub/src/utils/tensor_utils.py +4 -4
  140. keras_hub/src/utils/timm/convert_efficientnet.py +2 -4
  141. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  142. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  143. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  144. keras_hub/src/version_utils.py +1 -1
  145. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/METADATA +1 -1
  146. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/RECORD +148 -140
  147. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/WHEEL +0 -0
  148. {keras_hub_nightly-0.19.0.dev202412120352.dist-info → keras_hub_nightly-0.19.0.dev202412140350.dist-info}/top_level.txt +0 -0
keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py

@@ -3,8 +3,7 @@ from keras import ops
 
 
 class ContentAndQueryEmbedding(keras.layers.Layer):
-    """
-    Content and Query Embedding.
+    """Content and Query Embedding.
 
     This class creates Content and Query Embeddings for XLNet model
     which is later used in XLNet Encoder.
@@ -20,9 +19,8 @@ class ContentAndQueryEmbedding(keras.layers.Layer):
         **kwargs: other keyword arguments.
 
     References:
-     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
-     (https://arxiv.org/abs/1906.08237)
-    """
+     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
+    """  # noqa: E501
 
     def __init__(
         self, vocabulary_size, hidden_dim, dropout, name=None, **kwargs
keras_hub/src/models/xlnet/xlnet_encoder.py

@@ -11,17 +11,16 @@ def xlnet_kernel_initializer(stddev=0.02):
 
 
 class XLNetEncoder(keras.layers.Layer):
-    """
-    XLNet Encoder.
+    """XLNet Encoder.
 
     This class follows the architecture of the transformer encoder layer in the
     paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
     can instantiate multiple instances of this class to stack up an encoder.
 
     Contrary to the single hidden state used in the paper mentioned above, this
-    Encoder uses two hidden states, Content State and Query State. Thus calculates
-    Two Stream Relative Attention using both of the hidden states. To know more
-    please check the reference.
+    Encoder uses two hidden states, Content State and Query State. Thus
+    calculates Two Stream Relative Attention using both of the hidden states.
+    To know more please check the reference.
 
     Args:
         num_heads: int, the number of heads in the
@@ -44,9 +43,8 @@ class XLNetEncoder(keras.layers.Layer):
         **kwargs: other keyword arguments.
 
     References:
-     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding]
-     (https://arxiv.org/abs/1906.08237)
-    """
+     - [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
+    """  # noqa: E501
 
     def __init__(
         self,
@@ -60,7 +58,7 @@ class XLNetEncoder(keras.layers.Layer):
         kernel_initializer_range=0.02,
         bias_initializer="zeros",
         name=None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(name=name, **kwargs)
         self.num_heads = num_heads
keras_hub/src/samplers/contrastive_sampler.py

@@ -150,9 +150,8 @@ class ContrastiveSampler(Sampler):
             # The final score of each candidate token is weighted sum of
             # probability and similarity against previous tokens.
             accumulated_scores = (
-                (1 - self.alpha) * next_token_probabilities
-                - self.alpha * max_similarity_scores
-            )
+                1 - self.alpha
+            ) * next_token_probabilities - self.alpha * max_similarity_scores
             # Unflatten variables to shape [batch_size, self.k, ...] for
             # gather purpose.
             unflat_score = unflatten_beams(accumulated_scores)
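Only the formatting of the score expression changes in this hunk (the wrapping now follows the formatter); the arithmetic is identical. For reference, a minimal NumPy sketch of the contrastive-search scoring rule the sampler implements, with made-up values:

```python
import numpy as np

# Hypothetical values for k=3 candidate tokens at one decoding step.
alpha = 0.6
next_token_probabilities = np.array([0.50, 0.30, 0.20])
# Max similarity of each candidate against previously generated tokens.
max_similarity_scores = np.array([0.90, 0.20, 0.10])

# Same weighted sum the sampler computes: probability rewards likely
# tokens, while similarity to prior context penalizes degenerate repeats.
accumulated_scores = (
    1 - alpha
) * next_token_probabilities - alpha * max_similarity_scores
print(accumulated_scores)  # [-0.34  0.    0.02] -> third candidate wins
```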
keras_hub/src/samplers/sampler.py

@@ -95,7 +95,8 @@ class Sampler:
         def cond(prompt, cache, index):
             if stop_token_ids is None:
                 return True
-            # Stop if all sequences have produced a *new* id from stop_token_ids.
+            # Stop if all sequences have produced a *new* id from
+            # stop_token_ids.
             end_tokens = any_equal(prompt, stop_token_ids, ~mask)
             prompt_done = ops.any(end_tokens, axis=-1)
             return ops.logical_not(ops.all(prompt_done))
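The comment reflow above is cosmetic, but the predicate deserves a gloss: generation continues until every sequence has emitted a stop token at a position it generated itself, not one already present in the prompt. A NumPy sketch of the same logic with toy values; the real code uses `keras.ops` and `any_equal`:

```python
import numpy as np

stop_token_ids = (2,)  # hypothetical end-of-sequence id
prompt = np.array([[5, 7, 2, 0], [5, 9, 9, 9]])
mask = np.array([[True, True, False, False],
                 [True, True, False, False]])  # True = original prompt

# A stop token only counts where the model generated it (~mask).
end_tokens = np.isin(prompt, stop_token_ids) & ~mask
prompt_done = np.any(end_tokens, axis=-1)        # [ True False]
keep_going = np.logical_not(np.all(prompt_done))
print(keep_going)  # True: the second sequence has not stopped yet
```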
keras_hub/src/tests/test_case.py

@@ -458,8 +458,8 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
 
         # Check variable length sequences.
         if variable_length_data is None:
-            # If no variable length data passed, assume the second axis of all
-            # inputs is our sequence axis and create it ourselves.
+            # If no variable length data passed, assume the second axis of
+            # all inputs is our sequence axis and create it ourselves.
             variable_length_data = [
                 tree.map_structure(
                     lambda x: x[:, :seq_length, ...], input_data
keras_hub/src/tokenizers/byte_pair_tokenizer.py

@@ -200,8 +200,8 @@ class BytePairTokenizer(tokenizer.Tokenizer):
     """Bype-pair encoding tokenizer layer.
 
     This BPE tokenizer provides the same functionality as the official GPT-2
-    tokenizer. Given the same `vocabulary` which maps tokens to ids, and `merges`
-    which describes BPE merge rules, it should provide the same output
+    tokenizer. Given the same `vocabulary` which maps tokens to ids, and
+    `merges` which describes BPE merge rules, it should provide the same output
     as OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
     Different from OpenAI, this implementation is graph-compatible, so you can
     use it within a `tf.data` pipeline.
keras_hub/src/tokenizers/byte_tokenizer.py

@@ -1,13 +1,5 @@
 import numpy as np
 
-try:
-    import tensorflow as tf
-except ImportError:
-    raise ImportError(
-        "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
-        "The TensorFlow package is required for data preprocessing with any backend."
-    )
-
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.tokenizers import tokenizer
 from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
@@ -15,8 +7,10 @@ from keras_hub.src.utils.tensor_utils import is_int_dtype
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 try:
+    import tensorflow as tf
     import tensorflow_text as tf_text
 except ImportError:
+    tf = None
     tf_text = None
 
 
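This change, mirrored across the tokenizer modules below, turns TensorFlow from a hard import-time requirement into an optional dependency validated only when it is actually used. A minimal sketch of the pattern; the `tokenize` helper is hypothetical, not from the diff:

```python
# Optional-import pattern: fail at use time, not at import time.
try:
    import tensorflow as tf
    import tensorflow_text as tf_text
except ImportError:
    tf = None
    tf_text = None


def tokenize(text):
    # Deferred check: importing this module no longer requires TF.
    if tf is None or tf_text is None:
        raise ImportError(
            "This tokenizer requires `tensorflow` and `tensorflow-text`. "
            "Install them with `pip install tensorflow tensorflow-text`."
        )
    return tf_text.WhitespaceTokenizer().tokenize(tf.constant([text]))
```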
keras_hub/src/tokenizers/sentence_piece_tokenizer.py

@@ -4,14 +4,6 @@ import os
 
 import keras
 
-try:
-    import tensorflow as tf
-except ImportError:
-    raise ImportError(
-        "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
-        "The TensorFlow package is required for data preprocessing with any backend."
-    )
-
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.tokenizers import tokenizer
 from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
@@ -21,11 +13,12 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
 from keras_hub.src.utils.tensor_utils import tensor_to_list
 
 try:
+    import tensorflow as tf
     import tensorflow_text as tf_text
 except ImportError:
+    tf = None
     tf_text = None
 
-
 VOCAB_FILENAME = "vocabulary.spm"
 
 
keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py

@@ -1,17 +1,11 @@
 import io
 
-try:
-    import tensorflow as tf
-except ImportError:
-    raise ImportError(
-        "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
-        "The TensorFlow package is required for data preprocessing with any backend."
-    )
-
 try:
     import sentencepiece as spm
+    import tensorflow as tf
 except ImportError:
     spm = None
+    tf = None
 
 from keras_hub.src.api_export import keras_hub_export
 
@@ -52,7 +46,8 @@ def compute_sentence_piece_proto(
 
     Basic Usage (from Dataset).
     >>> inputs = tf.data.Dataset.from_tensor_slices(["Drifting Along"])
-    >>> proto = keras_hub.tokenizers.compute_sentence_piece_proto(inputs, vocabulary_size=15)
+    >>> proto = keras_hub.tokenizers.compute_sentence_piece_proto(
+    ...     inputs, vocabulary_size=15)
    >>> tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(proto=proto)
    >>> outputs = inputs.map(tokenizer)
    >>> for output in outputs:
@@ -92,7 +87,8 @@ def compute_sentence_piece_proto(
 
     if not isinstance(data, (list, tuple, tf.data.Dataset)):
         raise ValueError(
-            "The `data` argument must be either `tf.data.Dataset` or `tuple` or `list`. "
+            "The `data` argument must be either `tf.data.Dataset` or "
+            "`tuple` or `list`. "
             f"Received: type(data)={type(data)}."
         )
 
@@ -105,8 +101,7 @@ def compute_sentence_piece_proto(
     model_writer = (
         open(proto_output_file, "wb") if proto_output_file else io.BytesIO()
     )
-    is_dataset = isinstance(data, tf.data.Dataset)
-    if is_dataset:
+    if tf is not None and isinstance(data, tf.data.Dataset):
         spm.SentencePieceTrainer.train(
             sentence_iterator=data.as_numpy_iterator(),
             model_writer=model_writer,
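Note the order of operands in the rewritten condition: `isinstance(data, tf.data.Dataset)` would raise `AttributeError` when `tf` is `None`, so the availability check must come first to short-circuit. A sketch of the guard with a hypothetical list fallback:

```python
try:
    import tensorflow as tf
except ImportError:
    tf = None


def iterate_samples(data):
    # Touch `tf.data.Dataset` only once we know `tf` actually imported,
    # hence the short-circuiting `tf is not None` on the left.
    if tf is not None and isinstance(data, tf.data.Dataset):
        return data.as_numpy_iterator()
    if isinstance(data, (list, tuple)):
        return iter(data)
    raise ValueError(f"Unsupported data type: {type(data)}.")
```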
keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py

@@ -226,8 +226,9 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
         if normalization_form:
             if input_encoding != "UTF-8":
                 raise ValueError(
-                    """Normalization Forms are Only Supported for Input Encoding
-                    UTF-8"""
+                    "Normalization Forms are Only Supported for Input "
+                    "Encoding UTF-8"
+                    ""
                 )
 
         super().__init__(dtype=dtype, **kwargs)
@@ -259,8 +260,9 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
         return config
 
     def vocabulary_size(self):
-        """Get the size of the tokenizer vocabulary. None implies no vocabulary
-        size was provided"""
+        """Get the size of the tokenizer vocabulary.
+
+        None implies no vocabulary size was provided"""
         return self._vocabulary_size
 
     def get_vocabulary(self):
@@ -334,6 +336,7 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer):
             id = ord(token)
             if id >= self.vocabulary_size():
                 raise ValueError(
-                    f"Token {token} is not supported by `UnicodeCodepointTokenizer`."
+                    f"Token {token} is not supported by "
+                    "`UnicodeCodepointTokenizer`."
                 )
         return id
keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py

@@ -55,7 +55,8 @@ def compute_word_piece_vocabulary(
         suffix_indicator: str. The characters prepended to a
             WordPiece to indicate that it is a suffix to another subword.
             E.g. `"##ing"`. Defaults to `"##"`.
-        reserved_tokens: list of strings. A list of tokens that must be included in the vocabulary.
+        reserved_tokens: list of strings. A list of tokens that must be included
+            in the vocabulary.
 
     Returns:
         Returns a list of vocabulary terms.
@@ -67,7 +68,10 @@ def compute_word_piece_vocabulary(
     >>> vocab = compute_word_piece_vocabulary(inputs, 13)
     >>> vocab
     ['[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]', 'a', 'b', 'm', 'p', 'r', 's', 't', '##at']
-    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary=vocab, oov_token="[UNK]")
+    >>> tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
+    ...     vocabulary=vocab,
+    ...     oov_token="[UNK]",
+    ... )
     >>> outputs = inputs.map(tokenizer.tokenize)
     >>> for x in outputs:
     ...     print(x)
@@ -112,7 +116,7 @@ def compute_word_piece_vocabulary(
     tokenizer = keras_hub.tokenizers.WordPieceTokenizer(vocabulary=vocab)
     inputs.map(tokenizer.tokenize)
     ```
-    """
+    """  # noqa: E501
     # Read data files.
     if not isinstance(data, (list, tf.data.Dataset)):
         raise ValueError(
keras_hub/src/utils/preset_utils.py

@@ -16,8 +16,9 @@ try:
     import tensorflow as tf
 except ImportError:
     raise ImportError(
-        "To use `keras_hub`, please install Tensorflow: `pip install tensorflow`. "
-        "The TensorFlow package is required for data preprocessing with any backend."
+        "To use `keras_hub`, please install Tensorflow: "
+        "`pip install tensorflow`. The TensorFlow package is required for data "
+        "preprocessing with any backend."
     )
 
 try:
@@ -191,7 +192,8 @@ def get_file(preset, path):
     elif scheme == HF_SCHEME:
         if huggingface_hub is None:
             raise ImportError(
-                f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. "
+                "`from_preset()` requires the `huggingface_hub` package to "
+                "load from '{preset}'. "
                 "Please install with `pip install huggingface_hub`."
             )
         hf_handle = preset.removeprefix(HF_SCHEME + "://")
@@ -225,7 +227,8 @@ def get_file(preset, path):
         raise ValueError(
             "Unknown preset identifier. A preset must be a one of:\n"
             "1) a built-in preset identifier like `'bert_base_en'`\n"
-            "2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n"
+            "2) a Kaggle Models handle like "
+            "`'kaggle://keras/bert/keras/bert_base_en'`\n"
             "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
             "4) a path to a local preset directory like `'./bert_base_en`\n"
             "Use `print(cls.presets.keys())` to view all built-in presets for "
@@ -342,8 +345,8 @@ def create_model_card(preset):
         markdown_content += f"* **{k}:** {v}\n"
     markdown_content += "\n"
     markdown_content += (
-        "This model card has been generated automatically and should be completed "
-        "by the model author. See [Model Cards documentation]"
+        "This model card has been generated automatically and should be "
+        "completed by the model author. See [Model Cards documentation]"
         "(https://huggingface.co/docs/hub/model-cards) for more information.\n"
     )
 
@@ -388,20 +391,22 @@ def upload_preset(
     if uri.startswith(KAGGLE_PREFIX):
         if kagglehub is None:
             raise ImportError(
-                "Uploading a model to Kaggle Hub requires the `kagglehub` package. "
-                "Please install with `pip install kagglehub`."
+                "Uploading a model to Kaggle Hub requires the `kagglehub` "
+                "package. Please install with `pip install kagglehub`."
             )
         if parse(kagglehub.__version__) < parse("0.2.4"):
             raise ImportError(
-                "Uploading a model to Kaggle Hub requires the `kagglehub` package version `0.2.4` or higher. "
-                "Please upgrade with `pip install --upgrade kagglehub`."
+                "Uploading a model to Kaggle Hub requires the `kagglehub` "
+                "package version `0.2.4` or higher. Please upgrade with "
+                "`pip install --upgrade kagglehub`."
             )
         kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
         kagglehub.model_upload(kaggle_handle, preset)
     elif uri.startswith(HF_PREFIX):
         if huggingface_hub is None:
             raise ImportError(
-                f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
+                f"`upload_preset()` requires the `huggingface_hub` package "
+                f"to upload to '{uri}'. "
                 "Please install with `pip install huggingface_hub`."
             )
         hf_handle = uri.removeprefix(HF_PREFIX)
@@ -413,14 +418,15 @@ def upload_preset(
             raise ValueError(
                 "Unexpected Hugging Face URI. Hugging Face model handles "
                 "should have the form 'hf://[{org}/]{model}'. For example, "
-                "'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly"
-                f"upload to your user account. Received: URI={uri}."
+                "'hf://username/bert_base_en' or 'hf://bert_case_en' to "
+                f"implicitly upload to your user account. Received: URI={uri}."
             ) from e
         has_model_card = huggingface_hub.file_exists(
             repo_id=repo_url.repo_id, filename=README_FILE
        )
        if not has_model_card:
-            # Remote repo doesn't have a model card so a basic model card is automatically generated.
+            # Remote repo doesn't have a model card so a basic model card is
+            # automatically generated.
            create_model_card(preset)
        try:
            huggingface_hub.upload_folder(
@@ -428,13 +434,14 @@ def upload_preset(
            )
        finally:
            if not has_model_card:
-                # Clean up the preset directory in case user attempts to upload the
-                # preset directory into Kaggle hub as well.
+                # Clean up the preset directory in case user attempts to upload
+                # the preset directory into Kaggle hub as well.
                delete_model_card(preset)
    else:
        raise ValueError(
            "Unknown URI. An URI must be a one of:\n"
-            "1) a Kaggle Model handle like `'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
+            "1) a Kaggle Model handle like "
+            "`'kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>'`\n"
            "2) a Hugging Face handle like `'hf://[<HF_USERNAME>/]<MODEL>'`\n"
            f"Received: uri='{uri}'."
        )
@@ -778,7 +785,7 @@ class KerasPresetSaver:
         # E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
         # For `ResNetBackbone` we would have `ImageClassifier`.
         tasks = list_subclasses(Task)
-        tasks = filter(lambda x: x.backbone_cls == type(layer), tasks)
+        tasks = filter(lambda x: x.backbone_cls is type(layer), tasks)
         tasks = [task.__base__.__name__ for task in tasks]
 
         keras_version = keras.version() if hasattr(keras, "version") else None
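One behavioral nuance hides in the last hunk: `x.backbone_cls == type(layer)` becomes `x.backbone_cls is type(layer)`. For classes, `is` tests object identity directly, whereas `==` routes through `__eq__`, which a metaclass could override. A standalone illustration:

```python
class Backbone:
    pass


class SubBackbone(Backbone):
    pass


layer = SubBackbone()
print(type(layer) is SubBackbone)  # True: exact class identity
print(type(layer) is Backbone)     # False: subclasses do not match
# For plain classes `==` happens to agree, but `is` states the intent
# (identity) and cannot be affected by a custom `__eq__`.
print(type(layer) == SubBackbone)  # True here, via default `__eq__`
```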
keras_hub/src/utils/tensor_utils.py

@@ -293,10 +293,10 @@ def any_equal(inputs, values, padding_mask):
 
     Args:
         inputs: Input tensor.
-        values: List or iterable of tensors shaped like `inputs` or broadcastable
-            by bit operators.
-        padding_mask: Tensor with shape compatible with inputs that will condition
-            output.
+        values: List or iterable of tensors shaped like `inputs` or
+            broadcastable by bit operators.
+        padding_mask: Tensor with shape compatible with inputs that will
+            condition output.
 
     Returns:
         A tensor with `inputs` shape where each position is True if it contains
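For context, the contract this docstring describes, restated as a toy NumPy sketch (the real `any_equal` operates on backend tensors via `keras.ops`):

```python
import numpy as np

inputs = np.array([[1, 2, 3], [4, 2, 6]])
values = [2, 6]  # ids to search for
padding_mask = np.array([[True, True, False],
                         [True, True, True]])

# True where a position holds any of `values` and the mask allows it.
matches = np.zeros_like(inputs, dtype=bool)
for v in values:
    matches |= inputs == v
print(matches & padding_mask)
# [[False  True False]
#  [False  True  True]]
```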
keras_hub/src/utils/timm/convert_efficientnet.py

@@ -198,10 +198,10 @@ def convert_weights(backbone, loader, timm_config):
         port_bias=True,
         depth_multiplier=1,
     ):
-
         def convert_pt_conv2d_kernel(pt_kernel):
             out_channels, in_channels_per_group, height, width = pt_kernel.shape
-            # PT Convs are depthwise convs if and only if in_channels_per_group == 1
+            # PT Convs are depthwise convs if and only if
+            # `in_channels_per_group == 1`
             assert in_channels_per_group == 1
             pt_kernel = np.transpose(pt_kernel, (2, 3, 0, 1))
             in_channels = out_channels // depth_multiplier
@@ -248,7 +248,6 @@ def convert_weights(backbone, loader, timm_config):
     num_stacks = len(backbone.stackwise_kernel_sizes)
 
     for stack_index in range(num_stacks):
-
         block_type = backbone.stackwise_block_types[stack_index]
         expansion_ratio = backbone.stackwise_expansion_ratios[stack_index]
         repeats = backbone.stackwise_num_repeats[stack_index]
@@ -263,7 +262,6 @@ def convert_weights(backbone, loader, timm_config):
     ]
 
     for block_idx in range(repeats):
-
         conv_pw_count = 0
         bn_count = 1
 
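The removed blank lines are pure style; the logic worth noting here is the kernel layout conversion the first hunk documents. PyTorch stores depthwise kernels as `(out_channels, 1, h, w)` while Keras `DepthwiseConv2D` expects `(h, w, in_channels, depth_multiplier)`. A sketch of that transform under the diff's stated `in_channels_per_group == 1` assumption (shapes illustrative):

```python
import numpy as np

depth_multiplier = 1
pt_kernel = np.random.rand(32, 1, 3, 3)  # (out_c, in_c_per_group, h, w)
out_channels, in_channels_per_group, height, width = pt_kernel.shape
assert in_channels_per_group == 1  # depthwise conv in PyTorch terms

kernel = np.transpose(pt_kernel, (2, 3, 0, 1))  # -> (h, w, out_c, 1)
in_channels = out_channels // depth_multiplier
kernel = np.reshape(kernel, (height, width, in_channels, depth_multiplier))
print(kernel.shape)  # (3, 3, 32, 1): the Keras DepthwiseConv2D layout
```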
keras_hub/src/utils/transformers/convert_vit.py (new file)

@@ -0,0 +1,150 @@
+import numpy as np
+
+from keras_hub.src.models.vit.vit_backbone import ViTBackbone
+
+backbone_cls = ViTBackbone
+
+
+def convert_backbone_config(transformers_config):
+    image_size = transformers_config["image_size"]
+    return {
+        "image_shape": (image_size, image_size, 3),
+        "patch_size": transformers_config["patch_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_heads": transformers_config["num_attention_heads"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "mlp_dim": transformers_config["intermediate_size"],
+        "dropout_rate": transformers_config["hidden_dropout_prob"],
+        "attention_dropout": transformers_config[
+            "attention_probs_dropout_prob"
+        ],
+        "use_mha_bias": transformers_config["qkv_bias"],
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    def port_ln(keras_variable, weight_key):
+        loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
+        loader.port_weight(keras_variable.beta, f"{weight_key}.bias")
+
+    def port_dense(keras_variable, weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        if keras_variable.bias is not None:
+            loader.port_weight(keras_variable.bias, f"{weight_key}.bias")
+
+    def port_mha(keras_variable, weight_key, num_heads, hidden_dim):
+        # query
+        loader.port_weight(
+            keras_variable.query_dense.kernel,
+            f"{weight_key}.attention.query.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.query_dense.bias,
+            f"{weight_key}.attention.query.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # key
+        loader.port_weight(
+            keras_variable.key_dense.kernel,
+            f"{weight_key}.attention.key.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.key_dense.bias,
+            f"{weight_key}.attention.key.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # value
+        loader.port_weight(
+            keras_variable.value_dense.kernel,
+            f"{weight_key}.attention.value.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.value_dense.bias,
+            f"{weight_key}.attention.value.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # output
+        loader.port_weight(
+            keras_variable.output_dense.kernel,
+            f"{weight_key}.output.dense.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (num_heads, hidden_dim // num_heads, hidden_dim)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias"
+        )
+
+    loader.port_weight(
+        keras_variable=backbone.layers[1].patch_embedding.kernel,
+        hf_weight_key="vit.embeddings.patch_embeddings.projection.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+
+    loader.port_weight(
+        backbone.layers[1].patch_embedding.bias,
+        "vit.embeddings.patch_embeddings.projection.bias",
+    )
+
+    loader.port_weight(
+        backbone.layers[1].class_token,
+        "vit.embeddings.cls_token",
+    )
+
+    loader.port_weight(
+        backbone.layers[1].position_embedding.embeddings,
+        "vit.embeddings.position_embeddings",
+        hook_fn=lambda x, _: x[0],
+    )
+    encoder_layers = backbone.layers[2].encoder_layers
+    for i, encoder_block in enumerate(encoder_layers):
+        prefix = "vit.encoder.layer"
+        num_heads = encoder_block.num_heads
+        hidden_dim = encoder_block.hidden_dim
+
+        port_mha(
+            encoder_block.mha,
+            f"{prefix}.{i}.attention",
+            num_heads,
+            hidden_dim,
+        )
+        port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.layernorm_before")
+        port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.layernorm_after")
+
+        port_dense(
+            encoder_block.mlp.dense_1, f"{prefix}.{i}.intermediate.dense"
+        )
+        port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense")
+    port_ln(backbone.layers[2].layer_norm, "vit.layernorm")
+
+
+def convert_head(task, loader, transformers_config):
+    prefix = "classifier."
+    loader.port_weight(
+        task.output_dense.kernel,
+        hf_weight_key=prefix + "weight",
+        hook_fn=lambda x, _: x.T,
+    )
+    loader.port_weight(
+        task.output_dense.bias,
+        hf_weight_key=prefix + "bias",
+    )
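The repeated `hook_fn` reshapes above all serve one purpose: Hugging Face stores each attention projection as a fused `(hidden_dim, hidden_dim)` `nn.Linear` weight, while Keras `MultiHeadAttention` keeps per-head kernels of shape `(hidden_dim, num_heads, head_dim)`. A standalone sketch of the query-kernel case:

```python
import numpy as np

hidden_dim, num_heads = 768, 12
head_dim = hidden_dim // num_heads

# HF torch.nn.Linear weight layout is (out_features, in_features).
hf_query_weight = np.random.rand(hidden_dim, hidden_dim)

# Same transform as the hook_fn: transpose to (in, out), then split the
# output dimension into (num_heads, head_dim) for the per-head kernel.
keras_kernel = np.reshape(
    hf_query_weight.T, (hidden_dim, num_heads, head_dim)
)
print(keras_kernel.shape)  # (768, 12, 64)
```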
keras_hub/src/utils/transformers/preset_loader.py

@@ -1,5 +1,6 @@
 """Convert huggingface models to KerasHub."""
 
+from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.utils.preset_utils import PresetLoader
 from keras_hub.src.utils.preset_utils import jax_memory_cleanup
 from keras_hub.src.utils.transformers import convert_albert
@@ -11,6 +12,7 @@ from keras_hub.src.utils.transformers import convert_gpt2
 from keras_hub.src.utils.transformers import convert_llama3
 from keras_hub.src.utils.transformers import convert_mistral
 from keras_hub.src.utils.transformers import convert_pali_gemma
+from keras_hub.src.utils.transformers import convert_vit
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
 
 
@@ -37,6 +39,8 @@ class TransformersPresetLoader(PresetLoader):
             self.converter = convert_mistral
         elif model_type == "paligemma":
             self.converter = convert_pali_gemma
+        elif model_type == "vit":
+            self.converter = convert_vit
         else:
             raise ValueError(
                 "KerasHub has no converter for huggingface/transformers models "
@@ -55,6 +59,25 @@ class TransformersPresetLoader(PresetLoader):
             self.converter.convert_weights(backbone, loader, self.config)
         return backbone
 
+    def load_task(self, cls, load_weights, load_task_weights, **kwargs):
+        architecture = self.config["architectures"][0]
+        if (
+            not load_task_weights
+            or not issubclass(cls, ImageClassifier)
+            or architecture == "ViTModel"
+        ):
+            return super().load_task(
+                cls, load_weights, load_task_weights, **kwargs
+            )
+        # Support loading the classification head for classifier models.
+        if architecture == "ViTForImageClassification":
+            kwargs["num_classes"] = len(self.config["id2label"])
+        task = super().load_task(cls, load_weights, load_task_weights, **kwargs)
+        if load_task_weights:
+            with SafetensorLoader(self.preset, prefix="") as loader:
+                self.converter.convert_head(task, loader, self.config)
+        return task
+
     def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs):
         return self.converter.convert_tokenizer(cls, self.preset, **kwargs)
 
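With this override in place, loading a Hugging Face ViT classifier checkpoint should also pull in its classification head; `num_classes` is inferred from the checkpoint's `id2label` mapping, so it need not be passed. Usage would presumably look like this (the handle is illustrative; any `ViTForImageClassification` checkpoint should hit the new code path):

```python
import keras_hub

classifier = keras_hub.models.ImageClassifier.from_preset(
    "hf://google/vit-base-patch16-224"
)
```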
keras_hub/src/utils/transformers/safetensor_utils.py

@@ -42,12 +42,13 @@ class SafetensorLoader(contextlib.ExitStack):
         """
         Determine and return a prefixed key for a given hf weight key.
 
-        This method checks if there's a common prefix for the weight keys and caches it
-        for future use.
+        This method checks if there's a common prefix for the weight keys and
+        caches it for future use.
 
         Args:
             hf_weight_key (str): The hf weight key to check for a prefix.
-            dict_like (object): An object to get keys of safetensor file using keys() method.
+            dict_like (object): An object to get keys of safetensor file using
+                keys() method.
 
         Returns:
             str: The full key including the prefix (if any).
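The reflowed docstring describes prefix detection: safetensor checkpoints sometimes key weights as, e.g., `model.embed_tokens.weight` and sometimes as `embed_tokens.weight`. A toy sketch of the idea only; the real method detects a single common prefix and caches it on the loader, and the candidate prefixes here are hypothetical:

```python
def get_prefixed_key(hf_weight_key, dict_like):
    # Try the bare key first, then candidate checkpoint prefixes.
    for prefix in ("", "model.", "vit."):
        candidate = prefix + hf_weight_key
        if candidate in dict_like.keys():
            return candidate
    raise KeyError(hf_weight_key)


weights = {"model.embed_tokens.weight": "..."}
print(get_prefixed_key("embed_tokens.weight", weights))
# model.embed_tokens.weight
```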