keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl

This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (105)
  1. keras_hub/__init__.py +15 -33
  2. keras_hub/layers/__init__.py +134 -0
  3. keras_hub/metrics/__init__.py +11 -0
  4. keras_hub/models/__init__.py +642 -0
  5. keras_hub/samplers/__init__.py +18 -0
  6. keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
  7. keras_hub/src/layers/preprocessing/image_converter.py +1 -0
  8. keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
  9. keras_hub/src/layers/preprocessing/random_swap.py +1 -1
  10. keras_hub/src/models/audio_to_text.py +66 -0
  11. keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
  12. keras_hub/src/models/backbone.py +5 -2
  13. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  14. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -1
  16. keras_hub/src/models/gemma/gemma_presets.py +10 -10
  17. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
  18. keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
  19. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  20. keras_hub/src/models/llama/llama_attention.py +24 -6
  21. keras_hub/src/models/llama/llama_backbone.py +50 -16
  22. keras_hub/src/models/llama/llama_decoder.py +20 -3
  23. keras_hub/src/models/llama/llama_presets.py +3 -3
  24. keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
  25. keras_hub/src/models/llama3/llama3_backbone.py +10 -2
  26. keras_hub/src/models/llama3/llama3_presets.py +84 -2
  27. keras_hub/src/models/mistral/mistral_presets.py +3 -3
  28. keras_hub/src/models/mixtral/__init__.py +5 -0
  29. keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
  30. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  31. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  32. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  33. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  34. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  35. keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
  36. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  37. keras_hub/src/models/moonshine/__init__.py +5 -0
  38. keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
  39. keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
  40. keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
  42. keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
  43. keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
  44. keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
  45. keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
  46. keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
  47. keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
  48. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
  49. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
  50. keras_hub/src/models/qwen/__init__.py +4 -0
  51. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  52. keras_hub/src/models/qwen/qwen_backbone.py +8 -1
  53. keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
  54. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
  55. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  56. keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
  57. keras_hub/src/models/qwen_moe/__init__.py +5 -0
  58. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
  59. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  60. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  61. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
  65. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  66. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  67. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  68. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  69. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
  71. keras_hub/src/models/task.py +5 -2
  72. keras_hub/src/models/xception/__init__.py +5 -0
  73. keras_hub/src/models/xception/xception_backbone.py +188 -0
  74. keras_hub/src/models/xception/xception_image_classifier.py +12 -0
  75. keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
  76. keras_hub/src/models/xception/xception_image_converter.py +8 -0
  77. keras_hub/src/models/xception/xception_presets.py +14 -0
  78. keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
  79. keras_hub/src/utils/coco/__init__.py +0 -0
  80. keras_hub/src/utils/coco/coco_utils.py +133 -0
  81. keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
  82. keras_hub/src/utils/keras_utils.py +11 -0
  83. keras_hub/src/utils/preset_utils.py +70 -10
  84. keras_hub/src/utils/tensor_utils.py +27 -1
  85. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  86. keras_hub/src/utils/timm/preset_loader.py +6 -6
  87. keras_hub/src/utils/transformers/convert_llama3.py +21 -1
  88. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  89. keras_hub/src/utils/transformers/convert_qwen.py +1 -0
  90. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  91. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  92. keras_hub/src/{version_utils.py → version.py} +1 -1
  93. keras_hub/tokenizers/__init__.py +117 -0
  94. keras_hub/utils/__init__.py +21 -0
  95. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
  96. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
  97. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
  98. keras_hub/api/__init__.py +0 -15
  99. keras_hub/api/layers/__init__.py +0 -86
  100. keras_hub/api/metrics/__init__.py +0 -11
  101. keras_hub/api/models/__init__.py +0 -416
  102. keras_hub/api/samplers/__init__.py +0 -16
  103. keras_hub/api/tokenizers/__init__.py +0 -58
  104. keras_hub/api/utils/__init__.py +0 -9
  105. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py
@@ -0,0 +1,46 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+ @keras_hub_export(
+     "keras_hub.tokenizers.QwenMoeTokenizer",
+ )
+ class QwenMoeTokenizer(BytePairTokenizer):
+     """Tokenizer for the Qwen MoE model.
+
+     This tokenizer implements byte-pair encoding (BPE) for Qwen models,
+     handling special tokens like BOS (beginning of sequence) and EOS (end of
+     sequence).
+
+     Args:
+         vocabulary: Dictionary mapping tokens to token IDs, or path to
+             vocabulary file.
+         merges: List of BPE merges, or path to merges file.
+         bos_token: Beginning of sequence token. Defaults to None.
+         eos_token: End of sequence token. Defaults to "<|endoftext|>".
+         misc_special_tokens: Set of additional special tokens. Defaults to
+             empty set.
+     """
+
+     backbone_cls = QwenMoeBackbone
+
+     def __init__(
+         self,
+         vocabulary=None,
+         merges=None,
+         **kwargs,
+     ):
+         # Add EOS token.
+         eos_token = "<|endoftext|>"
+         self._add_special_token(eos_token, "end_token")
+
+         self.start_token_id = None
+         self.start_token = None
+         self.pad_token_id = 0
+
+         super().__init__(
+             vocabulary=vocabulary,
+             merges=merges,
+             **kwargs,
+         )
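The new tokenizer is a thin `BytePairTokenizer` subclass that pins `"<|endoftext|>"` as the end token and pad ID 0. A minimal usage sketch; the vocabulary and merges paths here are hypothetical, real assets come from the presets registered in `qwen_moe_presets.py`:

```python
import keras_hub

# Hypothetical local BPE assets; any matching vocab.json/merges.txt pair works.
tokenizer = keras_hub.tokenizers.QwenMoeTokenizer(
    vocabulary="vocab.json",
    merges="merges.txt",
)
token_ids = tokenizer("The quick brown fox.")  # string -> token IDs
text = tokenizer.detokenize(token_ids)         # token IDs -> string
```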
keras_hub/src/models/retinanet/retinanet_image_converter.py
@@ -6,16 +6,3 @@ from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
  @keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
  class RetinaNetImageConverter(ImageConverter):
      backbone_cls = RetinaNetBackbone
-
-     def __init__(
-         self,
-         *args,
-         **kwargs,
-     ):
-         # TODO: update presets and remove these old config options. They were
-         # never needed.
-         if "norm_mean" in kwargs:
-             kwargs["offset"] = [-x for x in kwargs.pop("norm_mean")]
-         if "norm_std" in kwargs:
-             kwargs["scale"] = [1.0 / x for x in kwargs.pop("norm_std")]
-         super().__init__(*args, **kwargs)
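With the shim gone, any config still carrying the old `norm_mean`/`norm_std` keys must pass `scale`/`offset` directly. A sketch of the same translation the removed `__init__` performed; the mean/std values below are illustrative, not taken from any preset:

```python
import keras_hub

# Old-style normalization stats (illustrative values only).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# The removed __init__ mapped them to ImageConverter's scale/offset like so:
offset = [-x for x in norm_mean]     # negated mean
scale = [1.0 / x for x in norm_std]  # reciprocal of the std

converter = keras_hub.layers.RetinaNetImageConverter(scale=scale, offset=offset)
```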
keras_hub/src/models/retinanet/retinanet_presets.py
@@ -11,7 +11,7 @@ backbone_presets = {
              "params": 34121239,
              "path": "retinanet",
          },
-         "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/3",
+         "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/4",
      },
      "retinanet_resnet50_fpn_v2_coco": {
          "metadata": {
@@ -22,6 +22,6 @@ backbone_presets = {
              "params": 31558592,
              "path": "retinanet",
          },
-         "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/2",
+         "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/3",
      },
  }
keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py
@@ -1,5 +1,3 @@
- import keras
-
  from keras_hub.src.api_export import keras_hub_export
  from keras_hub.src.models.image_segmenter_preprocessor import (
      ImageSegmenterPreprocessor,
@@ -8,25 +8,9 @@ from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
  from keras_hub.src.models.segformer.segformer_image_converter import (
      SegFormerImageConverter,
  )
- from keras_hub.src.utils.tensor_utils import preprocessing_function
-
- IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
- IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]


  @keras_hub_export("keras_hub.models.SegFormerImageSegmenterPreprocessor")
  class SegFormerImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
      backbone_cls = SegFormerBackbone
      image_converter_cls = SegFormerImageConverter
-
-     @preprocessing_function
-     def call(self, x, y=None, sample_weight=None):
-         if self.image_converter:
-             x = self.image_converter(x)
-             if y is not None:
-                 y = self.image_converter(y)
-
-         x = x / 255
-         x = (x - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
-
-         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
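The deleted `call()` divided by 255 and then applied ImageNet mean/std normalization; the preset version bumps below presumably fold an equivalent transform into each preset's image converter. For reference, the same normalization collapses into a single scale/offset pair. A sketch of the arithmetic, not the shipped preset config:

```python
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

# (x / 255 - mean) / std  ==  x * scale + offset, per channel, with:
scale = [1.0 / (255.0 * s) for s in IMAGENET_DEFAULT_STD]
offset = [-m / s for m, s in zip(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]
```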
keras_hub/src/models/segformer/segformer_presets.py
@@ -10,7 +10,7 @@ presets = {
              "params": 3719027,
              "path": "segformer_b0",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_ade20k_512/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_ade20k_512/3",
      },
      "segformer_b1_ade20k_512": {
          "metadata": {
@@ -21,7 +21,7 @@ presets = {
              "params": 13682643,
              "path": "segformer_b1",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/5",
      },
      "segformer_b2_ade20k_512": {
          "metadata": {
@@ -32,7 +32,7 @@ presets = {
              "params": 24727507,
              "path": "segformer_b2",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_ade20k_512/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_ade20k_512/3",
      },
      "segformer_b3_ade20k_512": {
          "metadata": {
@@ -43,7 +43,7 @@ presets = {
              "params": 44603347,
              "path": "segformer_b3",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_ade20k_512/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_ade20k_512/3",
      },
      "segformer_b4_ade20k_512": {
          "metadata": {
@@ -54,7 +54,7 @@ presets = {
              "params": 61373907,
              "path": "segformer_b4",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_ade20k_512/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_ade20k_512/3",
      },
      "segformer_b5_ade20k_640": {
          "metadata": {
@@ -65,7 +65,7 @@ presets = {
              "params": 81974227,
              "path": "segformer_b5",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_ade20k_640/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_ade20k_640/3",
      },
      "segformer_b0_cityscapes_1024": {
          "metadata": {
@@ -76,7 +76,7 @@ presets = {
              "params": 3719027,
              "path": "segformer_b0",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_cityscapes_1024/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_cityscapes_1024/3",
      },
      "segformer_b1_cityscapes_1024": {
          "metadata": {
@@ -87,7 +87,7 @@ presets = {
              "params": 13682643,
              "path": "segformer_b1",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/1",
      },
      "segformer_b2_cityscapes_1024": {
          "metadata": {
@@ -98,7 +98,7 @@ presets = {
              "params": 24727507,
              "path": "segformer_b2",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_cityscapes_1024/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_cityscapes_1024/3",
      },
      "segformer_b3_cityscapes_1024": {
          "metadata": {
@@ -109,7 +109,7 @@ presets = {
              "params": 44603347,
              "path": "segformer_b3",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_cityscapes_1024/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_cityscapes_1024/3",
      },
      "segformer_b4_cityscapes_1024": {
          "metadata": {
@@ -120,7 +120,7 @@ presets = {
              "params": 61373907,
              "path": "segformer_b4",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_cityscapes_1024/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_cityscapes_1024/3",
      },
      "segformer_b5_cityscapes_1024": {
          "metadata": {
@@ -131,6 +131,6 @@ presets = {
              "params": 81974227,
              "path": "segformer_b5",
          },
-         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_cityscapes_1024/2",
+         "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_cityscapes_1024/3",
      },
  }
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py
@@ -550,6 +550,12 @@ class StableDiffusion3Backbone(Backbone):
          guidance_scale=None,
      ):
          step = ops.convert_to_tensor(step)
+         if not keras.utils.is_keras_tensor(num_steps):
+             num_steps = ops.convert_to_tensor(num_steps)
+         if guidance_scale is not None and not keras.utils.is_keras_tensor(
+             guidance_scale
+         ):
+             guidance_scale = ops.convert_to_tensor(guidance_scale)
          next_step = ops.add(step, 1)
          sigma, timestep = self.scheduler(step, num_steps)
          next_sigma, _ = self.scheduler(next_step, num_steps)
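The new guards avoid calling `ops.convert_to_tensor` on values that are already symbolic `KerasTensor`s, e.g. when the denoising step is traced inside a functional graph. The pattern generalizes; a sketch, not code from the package:

```python
import keras
from keras import ops

def ensure_tensor(value):
    # Leave symbolic KerasTensors untouched; convert eager Python
    # scalars/arrays to backend tensors.
    if value is not None and not keras.utils.is_keras_tensor(value):
        value = ops.convert_to_tensor(value)
    return value
```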
keras_hub/src/models/task.py
@@ -236,14 +236,17 @@ class Task(PipelineModel):
              objects_to_skip=backbone_layer_ids,
          )

-     def save_to_preset(self, preset_dir):
+     def save_to_preset(self, preset_dir, max_shard_size=10):
          """Save task to a preset directory.

          Args:
              preset_dir: The path to the local model preset directory.
+             max_shard_size: `int` or `float`. Maximum size in GB for each
+                 sharded file. If `None`, no sharding will be done. Defaults to
+                 `10`.
          """
          saver = get_preset_saver(preset_dir)
-         saver.save_task(self)
+         saver.save_task(self, max_shard_size=max_shard_size)

      @property
      def layers(self):
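A usage sketch of the new argument. The preset name is one registered elsewhere in this release, and whether it bundles task weights is an assumption of this sketch; any constructed `Task` works the same way:

```python
import keras_hub

# Any Task subclass; construction shown only to make the sketch runnable.
task = keras_hub.models.ImageClassifier.from_preset("xception_41_imagenet")

# Shard weight files at roughly 2 GB each (the default is 10 GB).
task.save_to_preset("./xception_preset", max_shard_size=2)

# Or disable sharding and write a single weights file.
task.save_to_preset("./xception_preset_unsharded", max_shard_size=None)
```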
keras_hub/src/models/xception/__init__.py
@@ -0,0 +1,5 @@
+ from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
+ from keras_hub.src.models.xception.xception_presets import backbone_presets
+ from keras_hub.src.utils.preset_utils import register_presets
+
+ register_presets(backbone_presets, XceptionBackbone)
keras_hub/src/models/xception/xception_backbone.py
@@ -0,0 +1,188 @@
+ import functools
+
+ from keras import layers
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ @keras_hub_export("keras_hub.models.XceptionBackbone")
+ class XceptionBackbone(Backbone):
+     """Xception core network with hyperparameters.
+
+     This class implements an Xception backbone as described in
+     [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357).
+
+     Most users will want the pretrained presets available with this model. If
+     you are creating a custom backbone, this model provides customizability
+     through the `stackwise_conv_filters` and `stackwise_pooling` arguments. This
+     backbone assumes the same basic structure as the original Xception model:
+     * Residuals and pre-activation everywhere but the first and last block.
+     * Conv layers for the first block only, separable conv layers elsewhere.
+
+     Args:
+         stackwise_conv_filters: list of list of ints. Each outermost list
+             entry represents a block, and each innermost list entry a conv
+             layer. The integer value specifies the number of filters for the
+             conv layer.
+         stackwise_pooling: list of bools. A list of booleans per block, where
+             each entry is true if the block should include a max pooling layer
+             and false if it should not.
+         image_shape: tuple. The input shape without the batch size.
+             Defaults to `(None, None, 3)`.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. If unspecified, the Keras default will be used.
+         dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the model's computations and weights.
+
+     Examples:
+     ```python
+     input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3))
+
+     # Pretrained Xception backbone.
+     model = keras_hub.models.Backbone.from_preset("xception_41_imagenet")
+     model(input_data)
+
+     # Randomly initialized Xception backbone with a custom config.
+     model = keras_hub.models.XceptionBackbone(
+         stackwise_conv_filters=[[32, 64], [64, 128], [256, 256]],
+         stackwise_pooling=[True, True, False],
+     )
+     model(input_data)
+     ```
+     """
+
+     def __init__(
+         self,
+         stackwise_conv_filters,
+         stackwise_pooling,
+         image_shape=(None, None, 3),
+         data_format=None,
+         dtype=None,
+         **kwargs,
+     ):
+         if len(stackwise_conv_filters) != len(stackwise_pooling):
+             raise ValueError("All stackwise args should have the same length.")
+
+         data_format = standardize_data_format(data_format)
+         channel_axis = -1 if data_format == "channels_last" else 1
+         num_blocks = len(stackwise_conv_filters)
+
+         # Layer shortcuts with common args.
+         norm = functools.partial(
+             layers.BatchNormalization,
+             axis=channel_axis,
+             dtype=dtype,
+         )
+         act = functools.partial(
+             layers.Activation,
+             activation="relu",
+             dtype=dtype,
+         )
+         conv = functools.partial(
+             layers.Conv2D,
+             kernel_size=(3, 3),
+             use_bias=False,
+             data_format=data_format,
+             dtype=dtype,
+         )
+         sep_conv = functools.partial(
+             layers.SeparableConv2D,
+             kernel_size=(3, 3),
+             padding="same",
+             use_bias=False,
+             data_format=data_format,
+             dtype=dtype,
+         )
+         point_conv = functools.partial(
+             layers.Conv2D,
+             kernel_size=(1, 1),
+             strides=(2, 2),
+             padding="same",
+             use_bias=False,
+             data_format=data_format,
+             dtype=dtype,
+         )
+         pool = functools.partial(
+             layers.MaxPool2D,
+             pool_size=(3, 3),
+             strides=(2, 2),
+             padding="same",
+             data_format=data_format,
+             dtype=dtype,
+         )
+
+         # === Functional Model ===
+         image_input = layers.Input(shape=image_shape)
+         x = image_input  # Intermediate result.
+
+         # Iterate through the blocks.
+         for block_i in range(num_blocks):
+             first_block, last_block = block_i == 0, block_i == num_blocks - 1
+             block_filters = stackwise_conv_filters[block_i]
+             use_pooling = stackwise_pooling[block_i]
+
+             # Save the block input as a residual.
+             residual = x
+             for conv_i, filters in enumerate(block_filters):
+                 # First block has post activation and strides on first conv.
+                 if first_block:
+                     prefix = f"block{block_i + 1}_conv{conv_i + 1}"
+                     strides = (2, 2) if conv_i == 0 else (1, 1)
+                     x = conv(filters, strides=strides, name=prefix)(x)
+                     x = norm(name=f"{prefix}_bn")(x)
+                     x = act(name=f"{prefix}_act")(x)
+                 # Last block has post activation.
+                 elif last_block:
+                     prefix = f"block{block_i + 1}_sepconv{conv_i + 1}"
+                     x = sep_conv(filters, name=prefix)(x)
+                     x = norm(name=f"{prefix}_bn")(x)
+                     x = act(name=f"{prefix}_act")(x)
+                 else:
+                     prefix = f"block{block_i + 1}_sepconv{conv_i + 1}"
+                     # The first conv in second block has no activation.
+                     if block_i != 1 or conv_i != 0:
+                         x = act(name=f"{prefix}_act")(x)
+                     x = sep_conv(filters, name=prefix)(x)
+                     x = norm(name=f"{prefix}_bn")(x)
+
+             # Optional block pooling.
+             if use_pooling:
+                 x = pool(name=f"block{block_i + 1}_pool")(x)
+
+             # Sum residual; first and last block do not have a residual.
+             if not first_block and not last_block:
+                 prefix = f"block{block_i + 1}_residual"
+                 filters = x.shape[channel_axis]
+                 # Match filters with a pointwise conv if needed.
+                 if filters != residual.shape[channel_axis]:
+                     residual = point_conv(filters, name=f"{prefix}_conv")(
+                         residual
+                     )
+                     residual = norm(name=f"{prefix}_bn")(residual)
+                 x = layers.Add(name=f"{prefix}_add", dtype=dtype)([x, residual])
+
+         super().__init__(
+             inputs=image_input,
+             outputs=x,
+             dtype=dtype,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.stackwise_conv_filters = stackwise_conv_filters
+         self.stackwise_pooling = stackwise_pooling
+         self.image_shape = image_shape
+         self.data_format = data_format
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "stackwise_conv_filters": self.stackwise_conv_filters,
+                 "stackwise_pooling": self.stackwise_pooling,
+                 "image_shape": self.image_shape,
+             }
+         )
+         return config
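The stackwise arguments map one list entry per block, so a quick shape check makes the contract concrete. A sketch reusing the docstring's custom config; the exact spatial output depends on the stride-2 stem conv and the two pooled blocks:

```python
import numpy as np
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone

backbone = XceptionBackbone(
    stackwise_conv_filters=[[32, 64], [64, 128], [256, 256]],  # 3 blocks
    stackwise_pooling=[True, True, False],  # pool after blocks 1 and 2 only
)
features = backbone(np.random.uniform(size=(2, 224, 224, 3)))
print(features.shape)  # channel count matches the last block's filters (256)
```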
keras_hub/src/models/xception/xception_image_classifier.py
@@ -0,0 +1,12 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_classifier import ImageClassifier
+ from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
+ from keras_hub.src.models.xception.xception_image_classifier_preprocessor import (  # noqa: E501
+     XceptionImageClassifierPreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.XceptionImageClassifier")
+ class XceptionImageClassifier(ImageClassifier):
+     backbone_cls = XceptionBackbone
+     preprocessor_cls = XceptionImageClassifierPreprocessor
keras_hub/src/models/xception/xception_image_classifier_preprocessor.py
@@ -0,0 +1,14 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_classifier_preprocessor import (
+     ImageClassifierPreprocessor,
+ )
+ from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
+ from keras_hub.src.models.xception.xception_image_converter import (
+     XceptionImageConverter,
+ )
+
+
+ @keras_hub_export("keras_hub.models.XceptionImageClassifierPreprocessor")
+ class XceptionImageClassifierPreprocessor(ImageClassifierPreprocessor):
+     backbone_cls = XceptionBackbone
+     image_converter_cls = XceptionImageConverter
keras_hub/src/models/xception/xception_image_converter.py
@@ -0,0 +1,8 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
+
+
+ @keras_hub_export("keras_hub.layers.XceptionImageConverter")
+ class XceptionImageConverter(ImageConverter):
+     backbone_cls = XceptionBackbone
keras_hub/src/models/xception/xception_presets.py
@@ -0,0 +1,14 @@
+ """Xception preset configurations."""
+
+ backbone_presets = {
+     "xception_41_imagenet": {
+         "metadata": {
+             "description": (
+                 "41-layer Xception model pre-trained on ImageNet 1k."
+             ),
+             "params": 20861480,
+             "path": "xception",
+         },
+         "kaggle_handle": "kaggle://keras/xception/keras/xception_41_imagenet/2",
+     },
+ }
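Because the `__init__.py` above registers this preset for `XceptionBackbone`, the bare name resolves without the full `kaggle://` handle, exactly as the backbone docstring shows:

```python
import keras_hub

# Plain preset name; the registry maps it to the kaggle_handle above.
backbone = keras_hub.models.Backbone.from_preset("xception_41_imagenet")
backbone.summary()
```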
keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py
@@ -0,0 +1,155 @@
+ import tensorflow as tf
+
+ from keras_hub.src.tokenizers.tokenizer import Tokenizer
+ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+ from keras_hub.src.utils.tensor_utils import is_int_dtype
+ from keras_hub.src.utils.tensor_utils import is_string_dtype
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+ class MockGemma3Tokenizer(Tokenizer):
+     def __init__(
+         self,
+         proto=None,
+         sequence_length=None,
+         dtype="int32",
+         add_bos=False,
+         add_eos=False,
+         **kwargs,
+     ):
+         if not is_int_dtype(dtype) and not is_string_dtype(dtype):
+             raise ValueError(
+                 "Output dtype must be an integer type or a string. "
+                 f"Received: dtype={dtype}"
+             )
+
+         super().__init__(dtype=dtype, **kwargs)
+
+         self.vocabulary = [
+             "<pad>",
+             "<bos>",
+             "<eos>",
+             "<unk>",
+             "<start_of_image>",
+             "<end_of_image>",
+             "<start_of_turn>",
+             "<end_of_turn>",
+             "<img>",
+             "the",
+             "brown",
+             "earth",
+             "fox",
+             "is",
+             "quick",
+             "round",
+             "\n\n",
+         ]
+         self.string_to_id = tf.lookup.StaticHashTable(
+             tf.lookup.KeyValueTensorInitializer(
+                 self.vocabulary, list(range(len(self.vocabulary)))
+             ),
+             default_value=3,
+         )
+         self.id_to_string = tf.lookup.StaticHashTable(
+             tf.lookup.KeyValueTensorInitializer(
+                 list(range(len(self.vocabulary))), self.vocabulary
+             ),
+             default_value="<unk>",
+         )
+
+         # The usual tokens.
+         self._add_special_token("<bos>", "start_token")
+         self._add_special_token("<eos>", "end_token")
+         self._add_special_token("<pad>", "pad_token")
+
+         # Image placeholder token.
+         self._add_special_token("<img>", "image_placeholder")
+
+         # Some tokens which are used in the preprocessor. We need to keep them
+         # here so that the preprocessor works with `tf.data`.
+         self._add_special_token("<start_of_image>", "start_of_image_token")
+         self._add_special_token("<end_of_image>", "end_of_image_token")
+
+         # self.special_token_ids = [
+         #     0, 1, 2, 4, 5, 8
+         # ]
+
+         self.sequence_length = sequence_length
+         self.add_bos = add_bos
+         self.add_eos = add_eos
+
+     def vocabulary_size(self):
+         return len(self.vocabulary)
+
+     def get_vocabulary(self):
+         return self.vocabulary
+
+     def id_to_token(self, id):
+         return self.vocabulary[id]
+
+     def token_to_id(self, token):
+         return self.vocabulary.index(token)
+
+     @preprocessing_function
+     def tokenize(self, inputs):
+         inputs = tf.convert_to_tensor(inputs)
+         unbatched = inputs.shape.rank == 0
+         if unbatched:
+             inputs = tf.expand_dims(inputs, 0)
+
+         inputs = tf.strings.regex_replace(
+             inputs, self.start_of_image_token, f" {self.start_of_image_token} "
+         )
+         inputs = tf.strings.regex_replace(
+             inputs, self.end_of_image_token, f" {self.end_of_image_token} "
+         )
+         inputs = tf.strings.regex_replace(
+             inputs, self.image_placeholder, f" {self.image_placeholder} "
+         )
+         inputs = tf.strings.regex_replace(inputs, "  ", " ")
+
+         sep_inputs = tf.strings.split(inputs, sep=" ")
+         tokens = self.string_to_id.lookup(sep_inputs)
+
+         if self.add_bos:
+             bos_tensor = tf.fill(
+                 value=self.start_token_id,
+                 dims=tokens.shape.as_list()[0:1] + [1],
+             )
+             tokens = tf.concat((bos_tensor, tokens), axis=-1)
+         if self.add_eos:
+             eos_tensor = tf.fill(
+                 value=self.end_token_id, dims=tokens.shape.as_list()[0:1] + [1]
+             )
+             tokens = tf.concat((tokens, eos_tensor), axis=-1)
+
+         # Convert to a dense output if input was a scalar.
+         if unbatched:
+             tokens = tf.squeeze(tokens, 0)
+
+         return tokens
+
+     @preprocessing_function
+     def detokenize(self, inputs):
+         inputs, unbatched, rectangular = convert_to_ragged_batch(inputs)
+         # tf-text sentencepiece does not handle int64.
+         inputs = tf.cast(inputs, "int32")
+
+         outputs = self.id_to_string.lookup(inputs)
+         outputs = tf.strings.reduce_join(outputs, axis=-1, separator=" ")
+
+         for token in [
+             self.start_token,
+             self.end_token,
+             self.pad_token,
+         ]:
+             outputs = tf.strings.regex_replace(outputs, token, "")
+
+         outputs = tf.strings.strip(outputs)
+
+         if unbatched:
+             outputs = tf.squeeze(outputs, 0)
+         return outputs
+
+     def __call__(self, inputs):
+         return self.tokenize(inputs)
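The mock tokenizes by splitting on spaces against its fixed 17-token vocabulary, mapping anything unknown to `<unk>` (ID 3). A round-trip sketch under that vocabulary; the IDs in the comments follow the index positions of the list above:

```python
from keras_hub.src.tests.mocks.mock_gemma3_tokenizer import MockGemma3Tokenizer

tokenizer = MockGemma3Tokenizer(add_bos=True, add_eos=True)

ids = tokenizer.tokenize("the quick brown fox")
# Should yield [1, 9, 14, 10, 12, 2]: <bos>, the, quick, brown, fox, <eos>.

text = tokenizer.detokenize(ids)
# -> "the quick brown fox" (special tokens stripped, tokens space-joined).
```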
{keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt: file without changes