keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. keras_hub/layers/__init__.py +6 -0
  2. keras_hub/models/__init__.py +21 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  5. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  6. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  7. keras_hub/src/models/backbone.py +10 -15
  8. keras_hub/src/models/d_fine/__init__.py +0 -0
  9. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  10. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  11. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  12. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  13. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  14. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  15. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  16. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  17. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  18. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  19. keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
  20. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  21. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  22. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  23. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  24. keras_hub/src/models/parseq/__init__.py +0 -0
  25. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  26. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  27. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  28. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  29. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  30. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  31. keras_hub/src/tests/test_case.py +37 -1
  32. keras_hub/src/utils/preset_utils.py +49 -0
  33. keras_hub/src/utils/tensor_utils.py +23 -1
  34. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  35. keras_hub/src/version.py +1 -1
  36. keras_hub/tokenizers/__init__.py +3 -0
  37. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
  38. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
  39. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
  40. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,418 @@
1
+ import keras
2
+ from keras import ops
3
+
4
+ from keras_hub.src.layers.modeling.cached_multi_head_attention import (
5
+ CachedMultiHeadAttention,
6
+ )
7
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
8
+ compute_causal_mask,
9
+ )
10
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
11
+ merge_padding_and_attention_mask,
12
+ )
13
+ from keras_hub.src.models.vit.vit_layers import MLP
14
+
15
+
16
class PARSeqDecoderBlock(keras.layers.Layer):
    """A decoder block for the PARSeq model.

    This block consists of self-attention, cross-attention, and a multilayer
    perceptron (MLP). It also includes layer normalization and dropout layers.
    The same sub-layers are applied to two streams: a "query" stream and a
    "content" stream (see `call`).

    Args:
        hidden_dim: int. The dimension of the hidden layers.
        num_heads: int. The number of attention heads.
        mlp_dim: int. The dimension of the MLP hidden layer.
        dropout_rate: float. The dropout rate used in the feedforward layers.
        attention_dropout: float. The dropout rate for the attention weights.
        layer_norm_epsilon: float. A small float added to the denominator for
            numerical stability in layer normalization.
        key_dim: int, optional. The per-head size passed to the attention
            layers. Defaults to `hidden_dim // num_heads`. It is included in
            `get_config()`, so it must also be accepted here for
            `from_config(get_config())` to round-trip.
        **kwargs: Additional keyword arguments passed to the base
            `keras.layers.Layer` constructor.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        mlp_dim,
        dropout_rate=0.1,
        attention_dropout=0.1,
        layer_norm_epsilon=1e-5,
        key_dim=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # `key_dim` is serialized by `get_config()`; without accepting it
        # here it would be forwarded to `keras.layers.Layer.__init__` via
        # `**kwargs`, which rejects unknown keyword arguments.
        if key_dim is None:
            key_dim = hidden_dim // num_heads

        # === Config ===
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.key_dim = key_dim
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon

    def build(self, input_shape):
        self.query_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            name="query_layer_norm",
            dtype=self.dtype_policy,
        )
        self.query_layer_norm.build(input_shape)
        self.content_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            name="content_layer_norm",
            dtype=self.dtype_policy,
        )
        self.content_layer_norm.build(input_shape)
        self.self_attention = CachedMultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
            dropout=self.attention_dropout,
            name="self_attention",
            dtype=self.dtype_policy,
        )
        self.self_attention.build(input_shape, input_shape)
        self.cross_attention = CachedMultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
            dropout=self.attention_dropout,
            name="cross_attention",
            dtype=self.dtype_policy,
        )
        self.cross_attention.build(input_shape, input_shape)

        self.layer_norm_1 = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            name="ln_1",
            dtype=self.dtype_policy,
        )
        self.layer_norm_1.build((None, None, self.hidden_dim))
        self.layer_norm_2 = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            name="ln_2",
            dtype=self.dtype_policy,
        )
        self.layer_norm_2.build((None, None, self.hidden_dim))
        self.mlp = MLP(
            hidden_dim=self.hidden_dim,
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            name="mlp",
            dtype=self.dtype_policy,
        )
        self.mlp.build((None, None, self.hidden_dim))
        self.dropout = keras.layers.Dropout(
            rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="decoder_block_dropout",
        )

        self.built = True

    def forward_stream(
        self,
        target,
        target_norm,
        target_kv,
        memory,
        padding_mask=None,
        self_attention_cache=None,
        self_attention_cache_update_index=0,
        train_attention_mask=None,
    ):
        """Applies self-attention, cross-attention and the MLP to one stream.

        Args:
            target: The un-normalized stream; used for the residual adds.
            target_norm: The normalized stream; used as the self-attention
                query.
            target_kv: Keys/values for self-attention.
            memory: Keys/values for cross-attention (the encoder output).
            padding_mask: Optional padding mask merged into the
                self-attention mask.
            self_attention_cache: Optional cache for the self-attention
                layer. When given, the updated cache is returned.
            self_attention_cache_update_index: Index at which the cache is
                updated.
            train_attention_mask: Optional explicit self-attention mask. When
                given, it is merged with `padding_mask` and NO causal mask is
                applied. Otherwise a causal mask is built.

        Returns:
            A tuple `(target, new_cache)`. `new_cache` is `None` when no
            `self_attention_cache` was given.
        """
        self_attention_new_cache = None
        if train_attention_mask is None:
            target_attention_mask = self._compute_attention_mask(
                target_norm,
                padding_mask,
                self_attention_cache,
                self_attention_cache_update_index,
            )
        else:
            target_attention_mask = merge_padding_and_attention_mask(
                target_norm, padding_mask, attention_mask=train_attention_mask
            )

        if self_attention_cache is not None:
            target2, self_attention_new_cache = self.self_attention(
                target_norm,
                target_kv,
                target_kv,
                attention_mask=target_attention_mask,
                cache=self_attention_cache,
                cache_update_index=self_attention_cache_update_index,
            )
        else:
            target2 = self.self_attention(
                target_norm,
                target_kv,
                target_kv,
                attention_mask=target_attention_mask,
            )
        target = ops.add(target, self.dropout(target2))
        target2 = self.cross_attention(
            self.layer_norm_1(target),
            memory,
            memory,
        )
        target = ops.add(target, self.dropout(target2))

        target2 = self.mlp(self.layer_norm_2(target))
        target = ops.add(target, target2)

        return target, self_attention_new_cache

    def call(
        self,
        query,
        content,
        memory,
        padding_mask=None,
        update_content=True,
        query_self_attention_cache=None,
        query_self_attention_cache_update_index=0,
        content_self_attention_cache=None,
        content_self_attention_cache_update_index=0,
        query_mask=None,
        content_mask=None,
    ):
        """Runs the query stream and, optionally, the content stream.

        Both streams share this block's sub-layers. The query stream uses the
        normalized content as self-attention keys/values. The content stream
        attends to itself. Both cross-attend to `memory`.

        In `PARSeqDecoder`, `query` carries positional query embeddings only,
        while `content` carries token embeddings (plus positions after the
        first token).

        Returns:
            A tuple `(query, content)`, extended as follows:
            - the updated query cache, if `query_self_attention_cache` was
              given;
            - then the content cache, if `content_self_attention_cache` was
              given. This is the cache passed in, unchanged, when
              `update_content=False`.
            `content` itself is returned unchanged when
            `update_content=False`.
        """
        # Positional queries (no token information).
        query_norm = self.query_layer_norm(query)
        # Token (+ position) embeddings; keys/values for both streams.
        content_norm = self.content_layer_norm(content)
        (
            query,
            query_self_attention_new_cache,
        ) = self.forward_stream(
            query,
            query_norm,
            content_norm,
            memory,
            padding_mask=padding_mask,
            train_attention_mask=query_mask,
            self_attention_cache=query_self_attention_cache,
            self_attention_cache_update_index=query_self_attention_cache_update_index,
        )

        if update_content:
            (
                content,
                content_self_attention_new_cache,
            ) = self.forward_stream(
                content,
                content_norm,
                content_norm,
                memory,  # image embeddings (encoder embeddings)
                padding_mask=padding_mask,
                train_attention_mask=content_mask,
                self_attention_cache=content_self_attention_cache,
                self_attention_cache_update_index=content_self_attention_cache_update_index,
            )

        return_values = [query, content]

        if query_self_attention_cache is not None:
            return_values.append(query_self_attention_new_cache)
        if update_content and content_self_attention_cache is not None:
            return_values.append(content_self_attention_new_cache)
        elif not update_content and content_self_attention_cache is not None:
            return_values.append(content_self_attention_cache)

        return tuple(return_values)

    def _compute_attention_mask(
        self, x, padding_mask, cache, cache_update_index
    ):
        """Builds a causal self-attention mask, merged with `padding_mask`.

        When `cache` is given, its axis 2 is used as the key length. This
        presumably is the cached sequence length; confirm against
        `CachedMultiHeadAttention`'s cache layout.
        """
        decoder_mask = merge_padding_and_attention_mask(
            inputs=x, padding_mask=padding_mask, attention_mask=None
        )
        batch_size = ops.shape(x)[0]
        input_length = output_length = ops.shape(x)[1]
        if cache is not None:
            input_length = ops.shape(cache)[2]

        causal_mask = compute_causal_mask(
            batch_size=batch_size,
            input_length=input_length,
            output_length=output_length,
            cache_index=cache_update_index,
        )

        # A position is attendable only if both masks allow it.
        return (
            ops.minimum(decoder_mask, causal_mask)
            if decoder_mask is not None
            else causal_mask
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "key_dim": self.key_dim,
                "mlp_dim": self.mlp_dim,
                "dropout_rate": self.dropout_rate,
                "attention_dropout": self.attention_dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
264
+
265
+
266
class PARSeqDecoder(keras.layers.Layer):
    """The PARSeq decoder.

    Token ids are embedded and combined with learned positional queries.
    Both are then passed through a stack of `PARSeqDecoderBlock`s that attend
    to the encoder `memory`. Only the final query stream is returned, after
    one closing layer normalization.

    Args:
        vocabulary_size: int. The size of the vocabulary.
        max_label_length: int. The maximum length of the label sequence.
        num_layers: int. The number of decoder layers.
        hidden_dim: int. The dimension of the hidden layers.
        mlp_dim: int. The dimension of the MLP hidden layer.
        num_heads: int. The number of attention heads.
        dropout_rate: float. The dropout rate.
        attention_dropout: float. The dropout rate for the attention weights.
        layer_norm_epsilon: float. A small float added to the denominator for
            numerical stability in layer normalization.
        **kwargs: Additional keyword arguments passed to the base
            `keras.layers.Layer` constructor.
    """

    def __init__(
        self,
        vocabulary_size,
        max_label_length,
        num_layers,
        hidden_dim,
        mlp_dim,
        num_heads,
        dropout_rate=0.1,
        attention_dropout=0.1,
        layer_norm_epsilon=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.max_label_length = max_label_length
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon

    def build(self, input_shape):
        # Sub-layers and weights are created in a fixed order. Keep it
        # stable so the resulting weight ordering does not change.
        hidden_shape = (None, None, self.hidden_dim)

        self.token_embedding = keras.layers.Embedding(
            input_dim=self.vocabulary_size,
            output_dim=self.hidden_dim,
            dtype=self.dtype_policy,
            name="token_embedding",
        )
        self.token_embedding.build((1, self.vocabulary_size))

        # Learned positional queries: `max_label_length + 1` slots (the extra
        # slot is presumably for the end-of-sequence position).
        self.pos_query_embeddings = self.add_weight(
            shape=(1, self.max_label_length + 1, self.hidden_dim),
            name="pos_query_embeddings",
            dtype=self.dtype,
        )

        self.dropout = keras.layers.Dropout(
            self.dropout_rate, dtype=self.dtype_policy, name="decoder_dropout"
        )

        self.decoder_layers = []
        for layer_index in range(self.num_layers):
            block = PARSeqDecoderBlock(
                hidden_dim=self.hidden_dim,
                num_heads=self.num_heads,
                mlp_dim=self.mlp_dim,
                dropout_rate=self.dropout_rate,
                attention_dropout=self.attention_dropout,
                layer_norm_epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name=f"decoder_layer_{layer_index}",
            )
            block.build(hidden_shape)
            self.decoder_layers.append(block)

        self.layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm",
        )
        self.layer_norm.build(hidden_shape)
        self.built = True

    def call(
        self,
        token_ids,
        memory,
        padding_mask=None,
        query_mask=None,
        content_mask=None,
    ):
        batch_size, seq_len = ops.shape(token_ids)
        scale = self.hidden_dim**0.5

        # The leading <bos> token is the null context. It gets no positional
        # information; positions are added only to the tokens after it.
        content = scale * self.token_embedding(token_ids[:, :1])
        if seq_len > 1:
            positioned_tokens = self.pos_query_embeddings[
                :, : seq_len - 1, :
            ] + scale * self.token_embedding(token_ids[:, 1:])
            content = ops.concatenate([content, positioned_tokens], axis=1)
        content = self.dropout(content)

        # Queries are the positional embeddings, broadcast over the batch.
        query = ops.multiply(
            ops.ones((batch_size, 1, 1), dtype=self.dtype),
            self.pos_query_embeddings[:, :seq_len, :],
        )
        query = self.dropout(query)

        # The last block's content output is never consumed, so it skips
        # updating the content stream.
        final_index = self.num_layers - 1
        for layer_index, block in enumerate(self.decoder_layers):
            query, content = block(
                query=query,
                content=content,
                memory=memory,
                padding_mask=padding_mask,
                update_content=layer_index != final_index,
                query_mask=query_mask,
                content_mask=content_mask,
            )

        return self.layer_norm(query)

    def compute_output_shape(self, input_shape):
        return (None, None, self.hidden_dim)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "max_label_length": self.max_label_length,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "mlp_dim": self.mlp_dim,
                "dropout_rate": self.dropout_rate,
                "attention_dropout": self.attention_dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config
@@ -0,0 +1,8 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
3
+ from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
4
+
5
+
6
@keras_hub_export("keras_hub.layers.PARSeqImageConverter")
class PARSeqImageConverter(ImageConverter):
    """Image converter for PARSeq models.

    Adds no behavior of its own on top of `ImageConverter`. It only sets
    `backbone_cls` to `PARSeqBackbone`.
    """

    backbone_cls = PARSeqBackbone
@@ -0,0 +1,221 @@
1
+ import os
2
+ import re
3
+ from typing import Iterable
4
+
5
+ import keras
6
+
7
+ from keras_hub.src.api_export import keras_hub_export
8
+ from keras_hub.src.tokenizers import tokenizer
9
+ from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
10
+ from keras_hub.src.utils.tensor_utils import is_int_dtype
11
+ from keras_hub.src.utils.tensor_utils import is_string_dtype
12
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
13
+
14
+ try:
15
+ import tensorflow as tf
16
+ import tensorflow_text as tf_text
17
+ except ImportError:
18
+ tf = None
19
+ tf_text = None
20
+
21
# Default PARSeq character set, one entry per character, in this order:
# digits, lowercase ASCII letters, uppercase ASCII letters, then ASCII
# punctuation (94 characters in total).
PARSEQ_VOCAB = [
    char
    for char in (
        "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"
        "\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    )
]

# File name under which the tokenizer vocabulary is saved/loaded as an asset.
VOCAB_FILENAME = "vocabulary.txt"
27
+
28
+
29
@keras_hub_export(
    [
        "keras_hub.tokenizers.PARSeqTokenizer",
        "keras_hub.models.PARSeqTokenizer",
    ]
)
class PARSeqTokenizer(tokenizer.Tokenizer):
    """A Tokenizer for PARSeq models, designed for OCR tasks.

    This tokenizer converts strings into sequences of integer IDs or string
    tokens, and vice-versa. It supports various preprocessing steps such as
    whitespace removal, Unicode normalization, and limiting the maximum label
    length. It also provides functionality to save and load the vocabulary
    from a file.

    Ids are laid out as: `0` -> `"[E]"` (end), `1..len(vocabulary)` -> the
    vocabulary characters in order, then `"[B]"` (start) and `"[P]"` (pad).
    Characters missing from the lookup table map to the `"[E]"` id.

    Args:
        vocabulary: str. A string or iterable representing the vocabulary to
            use. If a string, it's treated as the path to a vocabulary file.
            If an iterable, it's treated as a list of characters forming
            the vocabulary. Defaults to `PARSEQ_VOCAB`.
        remove_whitespace: bool. Whether to remove whitespace characters from
            the input. Defaults to `True`.
        normalize_unicode: bool. Whether to normalize Unicode characters in the
            input using NFKD normalization and remove non-ASCII characters.
            Defaults to `True`.
        max_label_length: int. The maximum length, in characters, of the
            tokenized output. Longer labels will be truncated. Defaults to
            `25`.
        dtype: str. The data type of the tokenized output. Must be an integer
            type (e.g., "int32") or a string type ("string").
            Defaults to `"int32"`.
        **kwargs: Additional keyword arguments passed to the base
            `keras.layers.Layer` constructor.
    """

    def __init__(
        self,
        vocabulary=PARSEQ_VOCAB,
        remove_whitespace=True,
        normalize_unicode=True,
        max_label_length=25,
        dtype="int32",
        **kwargs,
    ):
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )
        super().__init__(dtype=dtype, **kwargs)
        self.remove_whitespace = remove_whitespace
        self.normalize_unicode = normalize_unicode
        self.max_label_length = max_label_length
        self.file_assets = [VOCAB_FILENAME]

        self.set_vocabulary(vocabulary)

    def save_assets(self, dir_path):
        """Writes the vocabulary to `dir_path`, one character per line."""
        path = os.path.join(dir_path, VOCAB_FILENAME)
        with open(path, "w", encoding="utf-8") as file:
            for token in self.vocabulary:
                file.write(f"{token}\n")

    def load_assets(self, dir_path):
        """Loads the vocabulary file previously written by `save_assets`."""
        path = os.path.join(dir_path, VOCAB_FILENAME)
        self.set_vocabulary(path)

    def set_vocabulary(self, vocabulary):
        """Set the tokenizer vocabulary to a file or list of strings.

        Also rebuilds the derived state: case-only flags, the regex of
        unsupported characters, the id <-> token mappings, the special
        tokens, and the TensorFlow lookup tables.

        Raises:
            ValueError: If `vocabulary` is neither a path, an iterable nor
                `None`.
        """
        if vocabulary is None:
            self.vocabulary = None
            return

        if isinstance(vocabulary, str):
            with open(vocabulary, "r", encoding="utf-8") as file:
                # Strip only the line terminator written by `save_assets`.
                # A bare `rstrip()` would also drop whitespace tokens such
                # as " ", changing the vocabulary size on reload.
                self.vocabulary = "".join(line.rstrip("\n") for line in file)
        elif isinstance(vocabulary, Iterable):
            self.vocabulary = "".join(vocabulary)
        else:
            raise ValueError(
                "Vocabulary must be an file path or list of terms. "
                f"Received: vocabulary={vocabulary}"
            )

        self.lowercase_only = self.vocabulary == self.vocabulary.lower()
        self.uppercase_only = self.vocabulary == self.vocabulary.upper()
        escaped_charset = re.escape(self.vocabulary)  # Escape for safe regex
        self.unsupported_regex = f"[^{escaped_charset}]"
        self._itos = ("[E]",) + tuple(self.vocabulary) + ("[B]", "[P]")
        self._stoi = {s: i for i, s in enumerate(self._itos)}

        self._add_special_token("[B]", "start_token")
        self._add_special_token("[E]", "end_token")
        self._add_special_token("[P]", "pad_token")
        # Create lookup tables.
        self.char_to_id = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=list(self._stoi.keys()),
                values=list(self._stoi.values()),
                key_dtype=tf.string,
                value_dtype=tf.int32,
            ),
            default_value=self._stoi["[E]"],
        )
        self.id_to_char = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=list(self._stoi.values()),
                values=list(self._stoi.keys()),
                key_dtype=tf.int32,
                value_dtype=tf.string,
            ),
            default_value=self.pad_token,
        )

    def get_vocabulary(self):
        """Get the tokenizer vocabulary as a list of strings tokens."""
        return list(self.vocabulary)

    def id_to_token(self, id):
        """Returns the token for `id`; raises `ValueError` if out of range."""
        if id >= self.vocabulary_size() or id < 0:
            raise ValueError(
                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
                f"Received: {id}"
            )
        return self._itos[id]

    def token_to_id(self, token):
        """Returns the id for `token`; raises `KeyError` if unknown."""
        return self._stoi[token]

    def _preprocess(self, inputs):
        """Cleans a label string so that it contains only supported chars.

        Steps, in order: optional whitespace removal; optional NFKD
        normalization followed by removal of everything outside printable
        ASCII (`!`-`~`); case folding when the vocabulary is single-case;
        removal of characters not in the vocabulary; truncation to
        `max_label_length` characters.
        """
        if self.remove_whitespace:
            inputs = tf.strings.regex_replace(inputs, r"\s+", "")

        if self.normalize_unicode:
            inputs = tf_text.normalize_utf8(inputs, normalization_form="NFKD")
            inputs = tf.strings.regex_replace(inputs, r"[^!-~]", "")

        if self.lowercase_only:
            inputs = tf.strings.lower(inputs)
        elif self.uppercase_only:
            inputs = tf.strings.upper(inputs)

        inputs = tf.strings.regex_replace(inputs, self.unsupported_regex, "")
        # Truncate by characters, not bytes: the default `unit="BYTE"` can
        # cut a multi-byte UTF-8 character in half when the vocabulary is
        # not pure ASCII. The result is identical for ASCII input.
        inputs = tf.strings.substr(
            inputs, 0, self.max_label_length, unit="UTF8_CHAR"
        )

        return inputs

    @preprocessing_function
    def tokenize(self, inputs):
        """Maps label string(s) to int32 character ids.

        A scalar string yields a 1-D tensor of ids. A batch of strings
        yields a ragged tensor with one row per string. Each row has at most
        `max_label_length` ids.
        """
        inputs = tf.convert_to_tensor(inputs)
        unbatched = inputs.shape.rank == 0
        if unbatched:
            inputs = tf.expand_dims(inputs, 0)

        inputs = tf.map_fn(
            self._preprocess, inputs, fn_output_signature=tf.string
        )

        token_ids = tf.cond(
            tf.size(inputs) > 0,
            lambda: self.char_to_id.lookup(
                tf.strings.unicode_split(inputs, "UTF-8")
            ),
            lambda: tf.RaggedTensor.from_row_splits(
                values=tf.constant([], dtype=tf.int32),
                row_splits=tf.constant([0], dtype=tf.int64),
            ),
        )
        if unbatched:
            # No exact-length shape assertion here: labels are variable
            # length (already truncated to at most `max_label_length` by
            # `_preprocess`), so asserting a length of exactly
            # `max_label_length` would reject every shorter label.
            token_ids = tf.squeeze(token_ids, 0)
        return token_ids

    @preprocessing_function
    def detokenize(self, inputs):
        """Maps ids back to their token strings (one string per id)."""
        inputs, unbatched, _ = convert_to_ragged_batch(inputs)
        # The `id_to_char` table is keyed on int32, so other integer dtypes
        # (e.g. int64) must be cast before the lookup.
        inputs = tf.cast(inputs, "int32")
        outputs = self.id_to_char.lookup(inputs)
        if unbatched:
            outputs = tf.squeeze(outputs, 0)
        return outputs

    def vocabulary_size(self):
        """Get the integer size of the tokenizer vocabulary."""
        # The vocabulary characters plus the "[E]", "[B]" and "[P]" tokens.
        return len(self.vocabulary) + 3

    def compute_output_spec(self, input_spec):
        return keras.KerasTensor(
            input_spec.shape + (self.max_label_length,),
            dtype=self.compute_dtype,
        )
@@ -499,6 +499,7 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
499
499
  init_kwargs,
500
500
  input_data,
501
501
  expected_output_shape,
502
+ spatial_output_keys=None,
502
503
  expected_pyramid_output_keys=None,
503
504
  expected_pyramid_image_sizes=None,
504
505
  variable_length_data=None,
@@ -557,12 +558,47 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
557
558
  input_data = ops.transpose(input_data, axes=(2, 0, 1))
558
559
  elif len(input_data_shape) == 4:
559
560
  input_data = ops.transpose(input_data, axes=(0, 3, 1, 2))
560
- if len(expected_output_shape) == 3:
561
+ if isinstance(expected_output_shape, dict):
562
+ # Handle dictionary of shapes.
563
+ transposed_shapes = {}
564
+ for key, shape in expected_output_shape.items():
565
+ if spatial_output_keys and key not in spatial_output_keys:
566
+ transposed_shapes[key] = shape
567
+ continue
568
+ if len(shape) == 3:
569
+ transposed_shapes[key] = (shape[0], shape[2], shape[1])
570
+ elif len(shape) == 4:
571
+ transposed_shapes[key] = (
572
+ shape[0],
573
+ shape[3],
574
+ shape[1],
575
+ shape[2],
576
+ )
577
+ else:
578
+ transposed_shapes[key] = shape
579
+ expected_output_shape = transposed_shapes
580
+ elif len(expected_output_shape) == 3:
561
581
  x = expected_output_shape
562
582
  expected_output_shape = (x[0], x[2], x[1])
563
583
  elif len(expected_output_shape) == 4:
564
584
  x = expected_output_shape
565
585
  expected_output_shape = (x[0], x[3], x[1], x[2])
586
+ original_init_kwargs = init_kwargs.copy()
587
+ init_kwargs = original_init_kwargs.copy()
588
+ # Handle nested `keras.Model` instances passed within `init_kwargs`.
589
+ for k, v in init_kwargs.items():
590
+ if isinstance(v, keras.Model) and hasattr(v, "data_format"):
591
+ config = v.get_config()
592
+ config["data_format"] = "channels_first"
593
+ if (
594
+ "image_shape" in config
595
+ and config["image_shape"] is not None
596
+ and len(config["image_shape"]) == 3
597
+ ):
598
+ config["image_shape"] = tuple(
599
+ reversed(config["image_shape"])
600
+ )
601
+ init_kwargs[k] = v.__class__.from_config(config)
566
602
  if "image_shape" in init_kwargs:
567
603
  init_kwargs = init_kwargs.copy()
568
604
  init_kwargs["image_shape"] = tuple(