keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -40,15 +40,15 @@ class SegFormerBackbone(Backbone):
|
|
|
40
40
|
import keras_hub
|
|
41
41
|
|
|
42
42
|
backbone = keras_hub.models.MiTBackbone(
|
|
43
|
-
depths=[2, 2, 2, 2],
|
|
44
43
|
image_shape=(224, 224, 3),
|
|
45
|
-
hidden_dims=[32, 64, 160, 256],
|
|
46
44
|
num_layers=4,
|
|
47
|
-
|
|
48
|
-
|
|
45
|
+
hidden_dims=[32, 64, 160, 256],
|
|
46
|
+
layerwise_depths=[2, 2, 2, 2],
|
|
47
|
+
layerwise_num_heads=[1, 2, 5, 8],
|
|
48
|
+
layerwise_sr_ratios=[8, 4, 2, 1],
|
|
49
|
+
layerwise_patch_sizes=[7, 3, 3, 3],
|
|
50
|
+
layerwise_strides=[4, 2, 2, 2],
|
|
49
51
|
max_drop_path_rate=0.1,
|
|
50
|
-
patch_sizes=[7, 3, 3, 3],
|
|
51
|
-
strides=[4, 2, 2, 2],
|
|
52
52
|
)
|
|
53
53
|
|
|
54
54
|
segformer_backbone = keras_hub.models.SegFormerBackbone(
|
|
@@ -3,10 +3,8 @@ import math
|
|
|
3
3
|
from keras import initializers
|
|
4
4
|
from keras import layers
|
|
5
5
|
from keras import ops
|
|
6
|
+
from keras.layers import ReversibleEmbedding
|
|
6
7
|
|
|
7
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
8
|
-
ReversibleEmbedding,
|
|
9
|
-
)
|
|
10
8
|
from keras_hub.src.utils.keras_utils import clone_initializer
|
|
11
9
|
from keras_hub.src.utils.keras_utils import gelu_approximate
|
|
12
10
|
from keras_hub.src.utils.keras_utils import standardize_data_format
|
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
5
|
-
ReversibleEmbedding,
|
|
6
|
-
)
|
|
7
5
|
from keras_hub.src.models.backbone import Backbone
|
|
8
6
|
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
|
|
9
7
|
|
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
4
|
-
ReversibleEmbedding,
|
|
5
|
-
)
|
|
6
4
|
from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
|
|
7
5
|
from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
|
|
8
6
|
|
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
5
|
-
ReversibleEmbedding,
|
|
6
|
-
)
|
|
7
5
|
from keras_hub.src.models.backbone import Backbone
|
|
8
6
|
from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
|
|
9
7
|
from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
|
|
@@ -1,9 +1,7 @@
|
|
|
1
1
|
import keras
|
|
2
|
+
from keras.layers import ReversibleEmbedding
|
|
2
3
|
|
|
3
4
|
from keras_hub.src.api_export import keras_hub_export
|
|
4
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
5
|
-
ReversibleEmbedding,
|
|
6
|
-
)
|
|
7
5
|
from keras_hub.src.models.backbone import Backbone
|
|
8
6
|
from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
|
|
9
7
|
from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer
|
keras_hub/src/models/task.py
CHANGED
|
@@ -361,7 +361,7 @@ class Task(PipelineModel):
|
|
|
361
361
|
|
|
362
362
|
# Output captured summary for non-interactive logging.
|
|
363
363
|
if print_fn:
|
|
364
|
-
print_fn(console.end_capture())
|
|
364
|
+
print_fn(console.end_capture().rstrip("\n"))
|
|
365
365
|
|
|
366
366
|
super().summary(
|
|
367
367
|
line_length=line_length,
|
keras_hub/src/tests/test_case.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
1
|
+
import gc
|
|
1
2
|
import json
|
|
2
3
|
import os
|
|
3
4
|
import pathlib
|
|
4
5
|
import re
|
|
6
|
+
import tempfile
|
|
5
7
|
|
|
6
8
|
import keras
|
|
7
9
|
import numpy as np
|
|
10
|
+
import packaging.version
|
|
8
11
|
import tensorflow as tf
|
|
9
12
|
from absl.testing import parameterized
|
|
10
13
|
from keras import ops
|
|
11
14
|
from keras import tree
|
|
15
|
+
from keras.layers import ReversibleEmbedding
|
|
12
16
|
|
|
13
|
-
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
14
|
-
ReversibleEmbedding,
|
|
15
|
-
)
|
|
16
17
|
from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
|
|
17
18
|
from keras_hub.src.tokenizers.tokenizer import Tokenizer
|
|
18
19
|
from keras_hub.src.utils.tensor_utils import is_float_dtype
|
|
@@ -433,6 +434,396 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
|
433
434
|
restored_output = restored_model(input_data)
|
|
434
435
|
self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol)
|
|
435
436
|
|
|
437
|
+
def _verify_litert_outputs(
|
|
438
|
+
self,
|
|
439
|
+
keras_output,
|
|
440
|
+
litert_output,
|
|
441
|
+
sig_outputs,
|
|
442
|
+
expected_output_shape=None,
|
|
443
|
+
verify_numerics=True,
|
|
444
|
+
comparison_mode="strict",
|
|
445
|
+
output_thresholds=None,
|
|
446
|
+
):
|
|
447
|
+
"""Verify LiteRT outputs against expected shape and Keras outputs.
|
|
448
|
+
|
|
449
|
+
Args:
|
|
450
|
+
keras_output: Keras model output (can be None if not verifying
|
|
451
|
+
numerics)
|
|
452
|
+
litert_output: LiteRT interpreter output
|
|
453
|
+
sig_outputs: Output names from SignatureDef
|
|
454
|
+
expected_output_shape: Expected output shape (optional)
|
|
455
|
+
verify_numerics: Whether to verify numerical correctness
|
|
456
|
+
comparison_mode: "strict" or "statistical"
|
|
457
|
+
output_thresholds: Thresholds for statistical comparison
|
|
458
|
+
"""
|
|
459
|
+
# Handle single output case: if Keras has single output but LiteRT
|
|
460
|
+
# returns dict
|
|
461
|
+
if (
|
|
462
|
+
not isinstance(keras_output, dict)
|
|
463
|
+
and isinstance(litert_output, dict)
|
|
464
|
+
and len(litert_output) == 1
|
|
465
|
+
):
|
|
466
|
+
litert_output = list(litert_output.values())[0]
|
|
467
|
+
|
|
468
|
+
# Verify output shape if specified
|
|
469
|
+
if expected_output_shape is not None:
|
|
470
|
+
self.assertEqual(litert_output.shape, expected_output_shape)
|
|
471
|
+
|
|
472
|
+
# Verify numerical correctness if requested
|
|
473
|
+
if verify_numerics:
|
|
474
|
+
self._verify_litert_numerics(
|
|
475
|
+
keras_output,
|
|
476
|
+
litert_output,
|
|
477
|
+
sig_outputs,
|
|
478
|
+
output_thresholds,
|
|
479
|
+
comparison_mode,
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
def _verify_litert_numerics(
|
|
483
|
+
self,
|
|
484
|
+
keras_output,
|
|
485
|
+
litert_output,
|
|
486
|
+
sig_outputs,
|
|
487
|
+
output_thresholds,
|
|
488
|
+
comparison_mode,
|
|
489
|
+
):
|
|
490
|
+
"""Verify numerical accuracy between Keras and LiteRT outputs.
|
|
491
|
+
|
|
492
|
+
This method compares outputs using the SignatureDef output names to
|
|
493
|
+
match Keras outputs with LiteRT outputs properly.
|
|
494
|
+
|
|
495
|
+
Args:
|
|
496
|
+
keras_output: Keras model output (tensor or dict)
|
|
497
|
+
litert_output: LiteRT interpreter output (tensor or dict)
|
|
498
|
+
sig_outputs: List of output names from SignatureDef
|
|
499
|
+
output_thresholds: Dict of thresholds for comparison
|
|
500
|
+
comparison_mode: "strict" or "statistical"
|
|
501
|
+
"""
|
|
502
|
+
if isinstance(keras_output, dict) and isinstance(litert_output, dict):
|
|
503
|
+
# Both outputs are dicts - compare using SignatureDef output names
|
|
504
|
+
for output_name in sig_outputs:
|
|
505
|
+
if output_name not in keras_output:
|
|
506
|
+
self.fail(
|
|
507
|
+
f"SignatureDef output '{output_name}' not found in "
|
|
508
|
+
f"Keras outputs.\n"
|
|
509
|
+
f"Keras keys: {list(keras_output.keys())}"
|
|
510
|
+
)
|
|
511
|
+
if output_name not in litert_output:
|
|
512
|
+
self.fail(
|
|
513
|
+
f"SignatureDef output '{output_name}' not found in "
|
|
514
|
+
f"LiteRT outputs.\n"
|
|
515
|
+
f"LiteRT keys: {list(litert_output.keys())}"
|
|
516
|
+
)
|
|
517
|
+
|
|
518
|
+
keras_val_np = ops.convert_to_numpy(keras_output[output_name])
|
|
519
|
+
litert_val = litert_output[output_name]
|
|
520
|
+
output_threshold = output_thresholds.get(
|
|
521
|
+
output_name,
|
|
522
|
+
output_thresholds.get("*", {"max": 10.0, "mean": 0.1}),
|
|
523
|
+
)
|
|
524
|
+
self._compare_outputs(
|
|
525
|
+
keras_val_np,
|
|
526
|
+
litert_val,
|
|
527
|
+
comparison_mode,
|
|
528
|
+
output_name,
|
|
529
|
+
output_threshold["max"],
|
|
530
|
+
output_threshold["mean"],
|
|
531
|
+
)
|
|
532
|
+
elif not isinstance(keras_output, dict) and not isinstance(
|
|
533
|
+
litert_output, dict
|
|
534
|
+
):
|
|
535
|
+
# Both outputs are single tensors - direct comparison
|
|
536
|
+
keras_output_np = ops.convert_to_numpy(keras_output)
|
|
537
|
+
output_threshold = output_thresholds.get(
|
|
538
|
+
"*", {"max": 1e-2, "mean": 1e-3}
|
|
539
|
+
)
|
|
540
|
+
self._compare_outputs(
|
|
541
|
+
keras_output_np,
|
|
542
|
+
litert_output,
|
|
543
|
+
comparison_mode,
|
|
544
|
+
key=None,
|
|
545
|
+
max_threshold=output_threshold["max"],
|
|
546
|
+
mean_threshold=output_threshold["mean"],
|
|
547
|
+
)
|
|
548
|
+
else:
|
|
549
|
+
keras_type = type(keras_output).__name__
|
|
550
|
+
litert_type = type(litert_output).__name__
|
|
551
|
+
self.fail(
|
|
552
|
+
f"Output structure mismatch: Keras returns "
|
|
553
|
+
f"{keras_type}, LiteRT returns {litert_type}"
|
|
554
|
+
)
|
|
555
|
+
|
|
556
|
+
def run_litert_export_test(
|
|
557
|
+
self,
|
|
558
|
+
cls=None,
|
|
559
|
+
init_kwargs=None,
|
|
560
|
+
input_data=None,
|
|
561
|
+
expected_output_shape=None,
|
|
562
|
+
model=None,
|
|
563
|
+
verify_numerics=True,
|
|
564
|
+
# No LiteRT output in model saving test; remove undefined return
|
|
565
|
+
output_thresholds=None,
|
|
566
|
+
**export_kwargs,
|
|
567
|
+
):
|
|
568
|
+
"""Export model to LiteRT format and verify outputs.
|
|
569
|
+
|
|
570
|
+
Args:
|
|
571
|
+
cls: Model class to test (optional if model is provided)
|
|
572
|
+
init_kwargs: Initialization arguments for the model (optional
|
|
573
|
+
if model is provided)
|
|
574
|
+
input_data: Input data to test with (dict or tensor)
|
|
575
|
+
expected_output_shape: Expected output shape from LiteRT inference
|
|
576
|
+
model: Pre-created model instance (optional, if provided cls and
|
|
577
|
+
init_kwargs are ignored)
|
|
578
|
+
verify_numerics: Whether to verify numerical correctness
|
|
579
|
+
between Keras and LiteRT outputs. Set to False for preset
|
|
580
|
+
models with load_weights=False where outputs are random.
|
|
581
|
+
comparison_mode: "strict" (default) or "statistical".
|
|
582
|
+
- "strict": All elements must be within default tolerances
|
|
583
|
+
(1e-6)
|
|
584
|
+
- "statistical": Check mean/max absolute differences against
|
|
585
|
+
provided thresholds
|
|
586
|
+
output_thresholds: Dict mapping output names to threshold dicts
|
|
587
|
+
with "max" and "mean" keys. Use "*" as wildcard for defaults.
|
|
588
|
+
Example: {"output1": {"max": 1e-4, "mean": 1e-5},
|
|
589
|
+
"*": {"max": 1e-3, "mean": 1e-4}}
|
|
590
|
+
**export_kwargs: Additional keyword arguments to pass to
|
|
591
|
+
model.export(), such as allow_custom_ops=True or
|
|
592
|
+
enable_select_tf_ops=True.
|
|
593
|
+
"""
|
|
594
|
+
# Skip test if Keras version is less than 3.13
|
|
595
|
+
if packaging.version.Version(
|
|
596
|
+
keras.__version__
|
|
597
|
+
) < packaging.version.Version("3.13.0"):
|
|
598
|
+
self.skipTest("LiteRT export requires Keras >= 3.13")
|
|
599
|
+
|
|
600
|
+
self.skipTest(
|
|
601
|
+
"#TODO: [#2572] Re-enable LiteRT tests after a new tf release. "
|
|
602
|
+
"Can't test with tf 2.20 due to tf.lite module deprecation."
|
|
603
|
+
)
|
|
604
|
+
|
|
605
|
+
# Extract comparison_mode from export_kwargs if provided
|
|
606
|
+
comparison_mode = export_kwargs.pop("comparison_mode", "strict")
|
|
607
|
+
if keras.backend.backend() != "tensorflow":
|
|
608
|
+
self.skipTest("LiteRT export only supports TensorFlow backend")
|
|
609
|
+
|
|
610
|
+
try:
|
|
611
|
+
from ai_edge_litert.interpreter import Interpreter
|
|
612
|
+
except ImportError:
|
|
613
|
+
Interpreter = tf.lite.Interpreter
|
|
614
|
+
|
|
615
|
+
if output_thresholds is None:
|
|
616
|
+
output_thresholds = {"*": {"max": 10.0, "mean": 0.1}}
|
|
617
|
+
|
|
618
|
+
if model is None:
|
|
619
|
+
if cls is None or init_kwargs is None:
|
|
620
|
+
raise ValueError(
|
|
621
|
+
"Either 'model' or 'cls' and 'init_kwargs' must be provided"
|
|
622
|
+
)
|
|
623
|
+
model = cls(**init_kwargs)
|
|
624
|
+
_ = model(input_data)
|
|
625
|
+
|
|
626
|
+
interpreter = None
|
|
627
|
+
try:
|
|
628
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
629
|
+
export_path = os.path.join(temp_dir, "model.tflite")
|
|
630
|
+
|
|
631
|
+
# Step 1: Export model and get Keras output
|
|
632
|
+
model.export(export_path, format="litert", **export_kwargs)
|
|
633
|
+
self.assertTrue(os.path.exists(export_path))
|
|
634
|
+
self.assertGreater(os.path.getsize(export_path), 0)
|
|
635
|
+
|
|
636
|
+
keras_output = model(input_data) if verify_numerics else None
|
|
637
|
+
|
|
638
|
+
# Step 2: Load interpreter and verify SignatureDef
|
|
639
|
+
interpreter = Interpreter(model_path=export_path)
|
|
640
|
+
signature_defs = interpreter.get_signature_list()
|
|
641
|
+
self.assertIn(
|
|
642
|
+
"serving_default",
|
|
643
|
+
signature_defs,
|
|
644
|
+
"Missing serving_default signature",
|
|
645
|
+
)
|
|
646
|
+
|
|
647
|
+
serving_sig = signature_defs["serving_default"]
|
|
648
|
+
sig_inputs = serving_sig.get("inputs", [])
|
|
649
|
+
sig_outputs = serving_sig.get("outputs", [])
|
|
650
|
+
|
|
651
|
+
self.assertGreater(
|
|
652
|
+
len(sig_inputs),
|
|
653
|
+
0,
|
|
654
|
+
"Should have at least one input in SignatureDef",
|
|
655
|
+
)
|
|
656
|
+
self.assertGreater(
|
|
657
|
+
len(sig_outputs),
|
|
658
|
+
0,
|
|
659
|
+
"Should have at least one output in SignatureDef",
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
# Verify input signature
|
|
663
|
+
if isinstance(input_data, dict):
|
|
664
|
+
expected_inputs = set(input_data.keys())
|
|
665
|
+
actual_inputs = set(sig_inputs)
|
|
666
|
+
# Check that all expected inputs are in the signature
|
|
667
|
+
# (allow signature to have additional optional inputs)
|
|
668
|
+
missing_inputs = expected_inputs - actual_inputs
|
|
669
|
+
if missing_inputs:
|
|
670
|
+
self.fail(
|
|
671
|
+
f"Missing inputs in SignatureDef: "
|
|
672
|
+
f"{sorted(missing_inputs)}. "
|
|
673
|
+
f"Expected: {sorted(expected_inputs)}, "
|
|
674
|
+
f"SignatureDef has: {sorted(actual_inputs)}"
|
|
675
|
+
)
|
|
676
|
+
else:
|
|
677
|
+
# For numpy arrays, just verify we have exactly one input
|
|
678
|
+
# (since we're passing a single tensor)
|
|
679
|
+
if len(sig_inputs) != 1:
|
|
680
|
+
self.fail(
|
|
681
|
+
"Expected 1 input for numpy array input_data, "
|
|
682
|
+
f"but SignatureDef has {len(sig_inputs)}: "
|
|
683
|
+
f"{sig_inputs}"
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
# Verify output signature
|
|
687
|
+
if verify_numerics and isinstance(keras_output, dict):
|
|
688
|
+
expected_outputs = set(keras_output.keys())
|
|
689
|
+
actual_outputs = set(sig_outputs)
|
|
690
|
+
if expected_outputs != actual_outputs:
|
|
691
|
+
self.fail(
|
|
692
|
+
f"Output name mismatch: Expected "
|
|
693
|
+
f"{sorted(expected_outputs)}, "
|
|
694
|
+
f"but SignatureDef has {sorted(actual_outputs)}"
|
|
695
|
+
)
|
|
696
|
+
|
|
697
|
+
# Step 3: Run LiteRT inference
|
|
698
|
+
os.remove(export_path)
|
|
699
|
+
# Simple inference implementation
|
|
700
|
+
runner = interpreter.get_signature_runner("serving_default")
|
|
701
|
+
|
|
702
|
+
# Convert input data dtypes to match TFLite expectations
|
|
703
|
+
def convert_for_tflite(x):
|
|
704
|
+
"""Convert tensor/array to TFLite-compatible dtypes."""
|
|
705
|
+
if hasattr(x, "dtype"):
|
|
706
|
+
if isinstance(x, np.ndarray):
|
|
707
|
+
if x.dtype == bool:
|
|
708
|
+
return x.astype(np.int32)
|
|
709
|
+
elif x.dtype == np.float64:
|
|
710
|
+
return x.astype(np.float32)
|
|
711
|
+
elif x.dtype == np.int64:
|
|
712
|
+
return x.astype(np.int32)
|
|
713
|
+
else: # TensorFlow tensor
|
|
714
|
+
if x.dtype == tf.bool:
|
|
715
|
+
return ops.cast(x, "int32").numpy()
|
|
716
|
+
elif x.dtype == tf.float64:
|
|
717
|
+
return ops.cast(x, "float32").numpy()
|
|
718
|
+
elif x.dtype == tf.int64:
|
|
719
|
+
return ops.cast(x, "int32").numpy()
|
|
720
|
+
else:
|
|
721
|
+
return x.numpy() if hasattr(x, "numpy") else x
|
|
722
|
+
elif hasattr(x, "numpy"):
|
|
723
|
+
return x.numpy()
|
|
724
|
+
return x
|
|
725
|
+
|
|
726
|
+
if isinstance(input_data, dict):
|
|
727
|
+
converted_input_data = tree.map_structure(
|
|
728
|
+
convert_for_tflite, input_data
|
|
729
|
+
)
|
|
730
|
+
litert_output = runner(**converted_input_data)
|
|
731
|
+
else:
|
|
732
|
+
# For single tensor inputs, get the input name
|
|
733
|
+
sig_inputs = serving_sig.get("inputs", [])
|
|
734
|
+
input_name = sig_inputs[
|
|
735
|
+
0
|
|
736
|
+
] # We verified len(sig_inputs) == 1 above
|
|
737
|
+
converted_input = convert_for_tflite(input_data)
|
|
738
|
+
litert_output = runner(**{input_name: converted_input})
|
|
739
|
+
|
|
740
|
+
# Step 4: Verify outputs
|
|
741
|
+
self._verify_litert_outputs(
|
|
742
|
+
keras_output,
|
|
743
|
+
litert_output,
|
|
744
|
+
sig_outputs,
|
|
745
|
+
expected_output_shape=expected_output_shape,
|
|
746
|
+
verify_numerics=verify_numerics,
|
|
747
|
+
comparison_mode=comparison_mode,
|
|
748
|
+
output_thresholds=output_thresholds,
|
|
749
|
+
)
|
|
750
|
+
finally:
|
|
751
|
+
if interpreter is not None:
|
|
752
|
+
del interpreter
|
|
753
|
+
if model is not None and cls is not None:
|
|
754
|
+
del model
|
|
755
|
+
gc.collect()
|
|
756
|
+
|
|
757
|
+
def _compare_outputs(
|
|
758
|
+
self,
|
|
759
|
+
keras_val,
|
|
760
|
+
litert_val,
|
|
761
|
+
comparison_mode,
|
|
762
|
+
key=None,
|
|
763
|
+
max_threshold=10.0,
|
|
764
|
+
mean_threshold=0.1,
|
|
765
|
+
):
|
|
766
|
+
"""Compare Keras and LiteRT outputs using specified comparison mode.
|
|
767
|
+
|
|
768
|
+
Args:
|
|
769
|
+
keras_val: Keras model output (numpy array)
|
|
770
|
+
litert_val: LiteRT model output (numpy array)
|
|
771
|
+
comparison_mode: "strict" or "statistical"
|
|
772
|
+
key: Output key name for error messages (optional)
|
|
773
|
+
max_threshold: Maximum absolute difference threshold for statistical
|
|
774
|
+
mode
|
|
775
|
+
mean_threshold: Mean absolute difference threshold for statistical
|
|
776
|
+
mode
|
|
777
|
+
"""
|
|
778
|
+
key_msg = f" for output key '{key}'" if key else ""
|
|
779
|
+
|
|
780
|
+
# Check if shapes are compatible for comparison
|
|
781
|
+
self.assertEqual(
|
|
782
|
+
keras_val.shape,
|
|
783
|
+
litert_val.shape,
|
|
784
|
+
f"Shape mismatch{key_msg}: Keras shape "
|
|
785
|
+
f"{keras_val.shape}, LiteRT shape {litert_val.shape}. "
|
|
786
|
+
"Numerical comparison cannot proceed due to incompatible shapes.",
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
if comparison_mode == "strict":
|
|
790
|
+
# Original strict element-wise comparison with default tolerances
|
|
791
|
+
self.assertAllClose(
|
|
792
|
+
keras_val,
|
|
793
|
+
litert_val,
|
|
794
|
+
atol=1e-6,
|
|
795
|
+
rtol=1e-6,
|
|
796
|
+
msg=f"Mismatch{key_msg}",
|
|
797
|
+
)
|
|
798
|
+
elif comparison_mode == "statistical":
|
|
799
|
+
# Statistical comparison
|
|
800
|
+
|
|
801
|
+
# Calculate element-wise absolute differences
|
|
802
|
+
abs_diff = np.abs(keras_val - litert_val)
|
|
803
|
+
|
|
804
|
+
# Element-wise statistics
|
|
805
|
+
mean_abs_diff = np.mean(abs_diff)
|
|
806
|
+
max_abs_diff = np.max(abs_diff)
|
|
807
|
+
|
|
808
|
+
# Assert reasonable bounds on statistical differences
|
|
809
|
+
self.assertLessEqual(
|
|
810
|
+
mean_abs_diff,
|
|
811
|
+
mean_threshold,
|
|
812
|
+
f"Mean absolute difference too high: {mean_abs_diff:.6e}"
|
|
813
|
+
f"{key_msg} (threshold: {mean_threshold})",
|
|
814
|
+
)
|
|
815
|
+
self.assertLessEqual(
|
|
816
|
+
max_abs_diff,
|
|
817
|
+
max_threshold,
|
|
818
|
+
f"Max absolute difference too high: {max_abs_diff:.6e}"
|
|
819
|
+
f"{key_msg} (threshold: {max_threshold})",
|
|
820
|
+
)
|
|
821
|
+
else:
|
|
822
|
+
raise ValueError(
|
|
823
|
+
f"Unknown comparison_mode: {comparison_mode}. Must be "
|
|
824
|
+
"'strict' or 'statistical'"
|
|
825
|
+
)
|
|
826
|
+
|
|
436
827
|
def run_backbone_test(
|
|
437
828
|
self,
|
|
438
829
|
cls,
|
|
@@ -11,6 +11,7 @@ from typing import Iterable
|
|
|
11
11
|
|
|
12
12
|
import keras
|
|
13
13
|
import regex as re
|
|
14
|
+
from keras.src.saving import serialization_lib
|
|
14
15
|
|
|
15
16
|
from keras_hub.src.api_export import keras_hub_export
|
|
16
17
|
from keras_hub.src.tokenizers import tokenizer
|
|
@@ -21,9 +22,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
|
|
|
21
22
|
|
|
22
23
|
try:
|
|
23
24
|
import tensorflow as tf
|
|
24
|
-
import tensorflow_text as tf_text
|
|
25
25
|
except ImportError:
|
|
26
26
|
tf = None
|
|
27
|
+
try:
|
|
28
|
+
import tensorflow_text as tf_text
|
|
29
|
+
except ImportError:
|
|
27
30
|
tf_text = None
|
|
28
31
|
|
|
29
32
|
VOCAB_FILENAME = "vocabulary.json"
|
|
@@ -135,7 +138,13 @@ def split_strings_for_bpe(inputs, unsplittable_tokens=None):
|
|
|
135
138
|
return remove_strings_from_inputs(raw_tokens, "६")
|
|
136
139
|
|
|
137
140
|
|
|
138
|
-
|
|
141
|
+
try:
|
|
142
|
+
_base_class = tf.Module
|
|
143
|
+
except (AttributeError, TypeError):
|
|
144
|
+
_base_class = object
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class BytePairTokenizerCache(_base_class):
|
|
139
148
|
"""Cache that stores the encoded result of seen tokens.
|
|
140
149
|
|
|
141
150
|
The cache key is string tensor or python strings, and the value is split
|
|
@@ -331,6 +340,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
|
|
|
331
340
|
return
|
|
332
341
|
|
|
333
342
|
if isinstance(vocabulary, str):
|
|
343
|
+
if serialization_lib.in_safe_mode():
|
|
344
|
+
raise ValueError(
|
|
345
|
+
"Requested the loading of a vocabulary file outside of the "
|
|
346
|
+
"model archive. This carries a potential risk of loading "
|
|
347
|
+
"arbitrary and sensitive files and thus it is disallowed "
|
|
348
|
+
"by default. If you trust the source of the artifact, you "
|
|
349
|
+
"can override this error by passing `safe_mode=False` to "
|
|
350
|
+
"the loading function, or calling "
|
|
351
|
+
"`keras.config.enable_unsafe_deserialization()`. "
|
|
352
|
+
f"Vocabulary file: '{vocabulary}'"
|
|
353
|
+
)
|
|
334
354
|
with open(vocabulary, "r", encoding="utf-8") as f:
|
|
335
355
|
self.vocabulary = json.load(f)
|
|
336
356
|
elif isinstance(vocabulary, dict):
|
|
@@ -342,6 +362,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
|
|
|
342
362
|
f"`type(vocabulary)={type(vocabulary)}`."
|
|
343
363
|
)
|
|
344
364
|
if isinstance(merges, str):
|
|
365
|
+
if serialization_lib.in_safe_mode():
|
|
366
|
+
raise ValueError(
|
|
367
|
+
"Requested the loading of a merges file outside of the "
|
|
368
|
+
"model archive. This carries a potential risk of loading "
|
|
369
|
+
"arbitrary and sensitive files and thus it is disallowed "
|
|
370
|
+
"by default. If you trust the source of the artifact, you "
|
|
371
|
+
"can override this error by passing `safe_mode=False` to "
|
|
372
|
+
"the loading function, or calling "
|
|
373
|
+
"`keras.config.enable_unsafe_deserialization()`. "
|
|
374
|
+
f"Merges file: '{merges}'"
|
|
375
|
+
)
|
|
345
376
|
with open(merges, encoding="utf-8") as f:
|
|
346
377
|
self.merges = [bp.rstrip() for bp in f]
|
|
347
378
|
elif isinstance(merges, Iterable):
|
|
@@ -8,9 +8,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
|
|
|
8
8
|
|
|
9
9
|
try:
|
|
10
10
|
import tensorflow as tf
|
|
11
|
-
import tensorflow_text as tf_text
|
|
12
11
|
except ImportError:
|
|
13
12
|
tf = None
|
|
13
|
+
try:
|
|
14
|
+
import tensorflow_text as tf_text
|
|
15
|
+
except ImportError:
|
|
14
16
|
tf_text = None
|
|
15
17
|
|
|
16
18
|
|
|
@@ -3,6 +3,7 @@ import binascii
|
|
|
3
3
|
import os
|
|
4
4
|
|
|
5
5
|
import keras
|
|
6
|
+
from keras.src.saving import serialization_lib
|
|
6
7
|
|
|
7
8
|
from keras_hub.src.api_export import keras_hub_export
|
|
8
9
|
from keras_hub.src.tokenizers import tokenizer
|
|
@@ -14,9 +15,11 @@ from keras_hub.src.utils.tensor_utils import tensor_to_list
|
|
|
14
15
|
|
|
15
16
|
try:
|
|
16
17
|
import tensorflow as tf
|
|
17
|
-
import tensorflow_text as tf_text
|
|
18
18
|
except ImportError:
|
|
19
19
|
tf = None
|
|
20
|
+
try:
|
|
21
|
+
import tensorflow_text as tf_text
|
|
22
|
+
except ImportError:
|
|
20
23
|
tf_text = None
|
|
21
24
|
|
|
22
25
|
VOCAB_FILENAME = "vocabulary.spm"
|
|
@@ -145,6 +148,17 @@ class SentencePieceTokenizer(tokenizer.Tokenizer):
|
|
|
145
148
|
except binascii.Error:
|
|
146
149
|
pass
|
|
147
150
|
if not is_base64:
|
|
151
|
+
if serialization_lib.in_safe_mode():
|
|
152
|
+
raise ValueError(
|
|
153
|
+
"Requested the loading of a proto file outside of "
|
|
154
|
+
"the model archive. This carries a potential risk of "
|
|
155
|
+
"loading arbitrary and sensitive files and thus it is "
|
|
156
|
+
"disallowed by default. If you trust the source of the "
|
|
157
|
+
"artifact, you can override this error by passing "
|
|
158
|
+
"`safe_mode=False` to the loading function, or calling "
|
|
159
|
+
"`keras.config.enable_unsafe_deserialization()`. "
|
|
160
|
+
f"Proto file: '{proto}'"
|
|
161
|
+
)
|
|
148
162
|
proto_bytes = open(proto, "rb").read()
|
|
149
163
|
elif isinstance(proto, bytes):
|
|
150
164
|
proto_bytes = proto
|
|
@@ -6,9 +6,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
|
|
|
6
6
|
|
|
7
7
|
try:
|
|
8
8
|
import tensorflow as tf
|
|
9
|
-
import tensorflow_text as tf_text
|
|
10
9
|
except ImportError:
|
|
11
10
|
tf = None
|
|
11
|
+
try:
|
|
12
|
+
import tensorflow_text as tf_text
|
|
13
|
+
except ImportError:
|
|
12
14
|
tf_text = None
|
|
13
15
|
|
|
14
16
|
|