keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
@@ -40,15 +40,15 @@ class SegFormerBackbone(Backbone):
40
40
  import keras_hub
41
41
 
42
42
  backbone = keras_hub.models.MiTBackbone(
43
- depths=[2, 2, 2, 2],
44
43
  image_shape=(224, 224, 3),
45
- hidden_dims=[32, 64, 160, 256],
46
44
  num_layers=4,
47
- blockwise_num_heads=[1, 2, 5, 8],
48
- blockwise_sr_ratios=[8, 4, 2, 1],
45
+ hidden_dims=[32, 64, 160, 256],
46
+ layerwise_depths=[2, 2, 2, 2],
47
+ layerwise_num_heads=[1, 2, 5, 8],
48
+ layerwise_sr_ratios=[8, 4, 2, 1],
49
+ layerwise_patch_sizes=[7, 3, 3, 3],
50
+ layerwise_strides=[4, 2, 2, 2],
49
51
  max_drop_path_rate=0.1,
50
- patch_sizes=[7, 3, 3, 3],
51
- strides=[4, 2, 2, 2],
52
52
  )
53
53
 
54
54
  segformer_backbone = keras_hub.models.SegFormerBackbone(
@@ -3,10 +3,8 @@ import math
3
3
  from keras import initializers
4
4
  from keras import layers
5
5
  from keras import ops
6
+ from keras.layers import ReversibleEmbedding
6
7
 
7
- from keras_hub.src.layers.modeling.reversible_embedding import (
8
- ReversibleEmbedding,
9
- )
10
8
  from keras_hub.src.utils.keras_utils import clone_initializer
11
9
  from keras_hub.src.utils.keras_utils import gelu_approximate
12
10
  from keras_hub.src.utils.keras_utils import standardize_data_format
@@ -1,9 +1,7 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
- from keras_hub.src.layers.modeling.reversible_embedding import (
5
- ReversibleEmbedding,
6
- )
7
5
  from keras_hub.src.models.backbone import Backbone
8
6
  from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
9
7
 
@@ -1,8 +1,6 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
- from keras_hub.src.layers.modeling.reversible_embedding import (
4
- ReversibleEmbedding,
5
- )
6
4
  from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
7
5
  from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
8
6
 
@@ -1,9 +1,7 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
- from keras_hub.src.layers.modeling.reversible_embedding import (
5
- ReversibleEmbedding,
6
- )
7
5
  from keras_hub.src.models.backbone import Backbone
8
6
  from keras_hub.src.models.t5.t5_layer_norm import T5LayerNorm
9
7
  from keras_hub.src.models.t5.t5_transformer_layer import T5TransformerLayer
@@ -1,9 +1,7 @@
1
1
  import keras
2
+ from keras.layers import ReversibleEmbedding
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
4
- from keras_hub.src.layers.modeling.reversible_embedding import (
5
- ReversibleEmbedding,
6
- )
7
5
  from keras_hub.src.models.backbone import Backbone
8
6
  from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
9
7
  from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer
@@ -361,7 +361,7 @@ class Task(PipelineModel):
361
361
 
362
362
  # Output captured summary for non-interactive logging.
363
363
  if print_fn:
364
- print_fn(console.end_capture(), line_break=False)
364
+ print_fn(console.end_capture().rstrip("\n"))
365
365
 
366
366
  super().summary(
367
367
  line_length=line_length,
@@ -1,18 +1,19 @@
1
+ import gc
1
2
  import json
2
3
  import os
3
4
  import pathlib
4
5
  import re
6
+ import tempfile
5
7
 
6
8
  import keras
7
9
  import numpy as np
10
+ import packaging.version
8
11
  import tensorflow as tf
9
12
  from absl.testing import parameterized
10
13
  from keras import ops
11
14
  from keras import tree
15
+ from keras.layers import ReversibleEmbedding
12
16
 
13
- from keras_hub.src.layers.modeling.reversible_embedding import (
14
- ReversibleEmbedding,
15
- )
16
17
  from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
17
18
  from keras_hub.src.tokenizers.tokenizer import Tokenizer
18
19
  from keras_hub.src.utils.tensor_utils import is_float_dtype
@@ -433,6 +434,396 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
433
434
  restored_output = restored_model(input_data)
434
435
  self.assertAllClose(model_output, restored_output, atol=atol, rtol=rtol)
435
436
 
437
+ def _verify_litert_outputs(
438
+ self,
439
+ keras_output,
440
+ litert_output,
441
+ sig_outputs,
442
+ expected_output_shape=None,
443
+ verify_numerics=True,
444
+ comparison_mode="strict",
445
+ output_thresholds=None,
446
+ ):
447
+ """Verify LiteRT outputs against expected shape and Keras outputs.
448
+
449
+ Args:
450
+ keras_output: Keras model output (can be None if not verifying
451
+ numerics)
452
+ litert_output: LiteRT interpreter output
453
+ sig_outputs: Output names from SignatureDef
454
+ expected_output_shape: Expected output shape (optional)
455
+ verify_numerics: Whether to verify numerical correctness
456
+ comparison_mode: "strict" or "statistical"
457
+ output_thresholds: Thresholds for statistical comparison
458
+ """
459
+ # Handle single output case: if Keras has single output but LiteRT
460
+ # returns dict
461
+ if (
462
+ not isinstance(keras_output, dict)
463
+ and isinstance(litert_output, dict)
464
+ and len(litert_output) == 1
465
+ ):
466
+ litert_output = list(litert_output.values())[0]
467
+
468
+ # Verify output shape if specified
469
+ if expected_output_shape is not None:
470
+ self.assertEqual(litert_output.shape, expected_output_shape)
471
+
472
+ # Verify numerical correctness if requested
473
+ if verify_numerics:
474
+ self._verify_litert_numerics(
475
+ keras_output,
476
+ litert_output,
477
+ sig_outputs,
478
+ output_thresholds,
479
+ comparison_mode,
480
+ )
481
+
482
+ def _verify_litert_numerics(
483
+ self,
484
+ keras_output,
485
+ litert_output,
486
+ sig_outputs,
487
+ output_thresholds,
488
+ comparison_mode,
489
+ ):
490
+ """Verify numerical accuracy between Keras and LiteRT outputs.
491
+
492
+ This method compares outputs using the SignatureDef output names to
493
+ match Keras outputs with LiteRT outputs properly.
494
+
495
+ Args:
496
+ keras_output: Keras model output (tensor or dict)
497
+ litert_output: LiteRT interpreter output (tensor or dict)
498
+ sig_outputs: List of output names from SignatureDef
499
+ output_thresholds: Dict of thresholds for comparison
500
+ comparison_mode: "strict" or "statistical"
501
+ """
502
+ if isinstance(keras_output, dict) and isinstance(litert_output, dict):
503
+ # Both outputs are dicts - compare using SignatureDef output names
504
+ for output_name in sig_outputs:
505
+ if output_name not in keras_output:
506
+ self.fail(
507
+ f"SignatureDef output '{output_name}' not found in "
508
+ f"Keras outputs.\n"
509
+ f"Keras keys: {list(keras_output.keys())}"
510
+ )
511
+ if output_name not in litert_output:
512
+ self.fail(
513
+ f"SignatureDef output '{output_name}' not found in "
514
+ f"LiteRT outputs.\n"
515
+ f"LiteRT keys: {list(litert_output.keys())}"
516
+ )
517
+
518
+ keras_val_np = ops.convert_to_numpy(keras_output[output_name])
519
+ litert_val = litert_output[output_name]
520
+ output_threshold = output_thresholds.get(
521
+ output_name,
522
+ output_thresholds.get("*", {"max": 10.0, "mean": 0.1}),
523
+ )
524
+ self._compare_outputs(
525
+ keras_val_np,
526
+ litert_val,
527
+ comparison_mode,
528
+ output_name,
529
+ output_threshold["max"],
530
+ output_threshold["mean"],
531
+ )
532
+ elif not isinstance(keras_output, dict) and not isinstance(
533
+ litert_output, dict
534
+ ):
535
+ # Both outputs are single tensors - direct comparison
536
+ keras_output_np = ops.convert_to_numpy(keras_output)
537
+ output_threshold = output_thresholds.get(
538
+ "*", {"max": 1e-2, "mean": 1e-3}
539
+ )
540
+ self._compare_outputs(
541
+ keras_output_np,
542
+ litert_output,
543
+ comparison_mode,
544
+ key=None,
545
+ max_threshold=output_threshold["max"],
546
+ mean_threshold=output_threshold["mean"],
547
+ )
548
+ else:
549
+ keras_type = type(keras_output).__name__
550
+ litert_type = type(litert_output).__name__
551
+ self.fail(
552
+ f"Output structure mismatch: Keras returns "
553
+ f"{keras_type}, LiteRT returns {litert_type}"
554
+ )
555
+
556
+ def run_litert_export_test(
557
+ self,
558
+ cls=None,
559
+ init_kwargs=None,
560
+ input_data=None,
561
+ expected_output_shape=None,
562
+ model=None,
563
+ verify_numerics=True,
564
+ # No LiteRT output in model saving test; remove undefined return
565
+ output_thresholds=None,
566
+ **export_kwargs,
567
+ ):
568
+ """Export model to LiteRT format and verify outputs.
569
+
570
+ Args:
571
+ cls: Model class to test (optional if model is provided)
572
+ init_kwargs: Initialization arguments for the model (optional
573
+ if model is provided)
574
+ input_data: Input data to test with (dict or tensor)
575
+ expected_output_shape: Expected output shape from LiteRT inference
576
+ model: Pre-created model instance (optional, if provided cls and
577
+ init_kwargs are ignored)
578
+ verify_numerics: Whether to verify numerical correctness
579
+ between Keras and LiteRT outputs. Set to False for preset
580
+ models with load_weights=False where outputs are random.
581
+ comparison_mode: "strict" (default) or "statistical".
582
+ - "strict": All elements must be within default tolerances
583
+ (1e-6)
584
+ - "statistical": Check mean/max absolute differences against
585
+ provided thresholds
586
+ output_thresholds: Dict mapping output names to threshold dicts
587
+ with "max" and "mean" keys. Use "*" as wildcard for defaults.
588
+ Example: {"output1": {"max": 1e-4, "mean": 1e-5},
589
+ "*": {"max": 1e-3, "mean": 1e-4}}
590
+ **export_kwargs: Additional keyword arguments to pass to
591
+ model.export(), such as allow_custom_ops=True or
592
+ enable_select_tf_ops=True.
593
+ """
594
+ # Skip test if Keras version is less than 3.13
595
+ if packaging.version.Version(
596
+ keras.__version__
597
+ ) < packaging.version.Version("3.13.0"):
598
+ self.skipTest("LiteRT export requires Keras >= 3.13")
599
+
600
+ self.skipTest(
601
+ "#TODO: [#2572] Re-enable LiteRT tests after a new tf release. "
602
+ "Can't test with tf 2.20 due to tf.lite module deprecation."
603
+ )
604
+
605
+ # Extract comparison_mode from export_kwargs if provided
606
+ comparison_mode = export_kwargs.pop("comparison_mode", "strict")
607
+ if keras.backend.backend() != "tensorflow":
608
+ self.skipTest("LiteRT export only supports TensorFlow backend")
609
+
610
+ try:
611
+ from ai_edge_litert.interpreter import Interpreter
612
+ except ImportError:
613
+ Interpreter = tf.lite.Interpreter
614
+
615
+ if output_thresholds is None:
616
+ output_thresholds = {"*": {"max": 10.0, "mean": 0.1}}
617
+
618
+ if model is None:
619
+ if cls is None or init_kwargs is None:
620
+ raise ValueError(
621
+ "Either 'model' or 'cls' and 'init_kwargs' must be provided"
622
+ )
623
+ model = cls(**init_kwargs)
624
+ _ = model(input_data)
625
+
626
+ interpreter = None
627
+ try:
628
+ with tempfile.TemporaryDirectory() as temp_dir:
629
+ export_path = os.path.join(temp_dir, "model.tflite")
630
+
631
+ # Step 1: Export model and get Keras output
632
+ model.export(export_path, format="litert", **export_kwargs)
633
+ self.assertTrue(os.path.exists(export_path))
634
+ self.assertGreater(os.path.getsize(export_path), 0)
635
+
636
+ keras_output = model(input_data) if verify_numerics else None
637
+
638
+ # Step 2: Load interpreter and verify SignatureDef
639
+ interpreter = Interpreter(model_path=export_path)
640
+ signature_defs = interpreter.get_signature_list()
641
+ self.assertIn(
642
+ "serving_default",
643
+ signature_defs,
644
+ "Missing serving_default signature",
645
+ )
646
+
647
+ serving_sig = signature_defs["serving_default"]
648
+ sig_inputs = serving_sig.get("inputs", [])
649
+ sig_outputs = serving_sig.get("outputs", [])
650
+
651
+ self.assertGreater(
652
+ len(sig_inputs),
653
+ 0,
654
+ "Should have at least one input in SignatureDef",
655
+ )
656
+ self.assertGreater(
657
+ len(sig_outputs),
658
+ 0,
659
+ "Should have at least one output in SignatureDef",
660
+ )
661
+
662
+ # Verify input signature
663
+ if isinstance(input_data, dict):
664
+ expected_inputs = set(input_data.keys())
665
+ actual_inputs = set(sig_inputs)
666
+ # Check that all expected inputs are in the signature
667
+ # (allow signature to have additional optional inputs)
668
+ missing_inputs = expected_inputs - actual_inputs
669
+ if missing_inputs:
670
+ self.fail(
671
+ f"Missing inputs in SignatureDef: "
672
+ f"{sorted(missing_inputs)}. "
673
+ f"Expected: {sorted(expected_inputs)}, "
674
+ f"SignatureDef has: {sorted(actual_inputs)}"
675
+ )
676
+ else:
677
+ # For numpy arrays, just verify we have exactly one input
678
+ # (since we're passing a single tensor)
679
+ if len(sig_inputs) != 1:
680
+ self.fail(
681
+ "Expected 1 input for numpy array input_data, "
682
+ f"but SignatureDef has {len(sig_inputs)}: "
683
+ f"{sig_inputs}"
684
+ )
685
+
686
+ # Verify output signature
687
+ if verify_numerics and isinstance(keras_output, dict):
688
+ expected_outputs = set(keras_output.keys())
689
+ actual_outputs = set(sig_outputs)
690
+ if expected_outputs != actual_outputs:
691
+ self.fail(
692
+ f"Output name mismatch: Expected "
693
+ f"{sorted(expected_outputs)}, "
694
+ f"but SignatureDef has {sorted(actual_outputs)}"
695
+ )
696
+
697
+ # Step 3: Run LiteRT inference
698
+ os.remove(export_path)
699
+ # Simple inference implementation
700
+ runner = interpreter.get_signature_runner("serving_default")
701
+
702
+ # Convert input data dtypes to match TFLite expectations
703
+ def convert_for_tflite(x):
704
+ """Convert tensor/array to TFLite-compatible dtypes."""
705
+ if hasattr(x, "dtype"):
706
+ if isinstance(x, np.ndarray):
707
+ if x.dtype == bool:
708
+ return x.astype(np.int32)
709
+ elif x.dtype == np.float64:
710
+ return x.astype(np.float32)
711
+ elif x.dtype == np.int64:
712
+ return x.astype(np.int32)
713
+ else: # TensorFlow tensor
714
+ if x.dtype == tf.bool:
715
+ return ops.cast(x, "int32").numpy()
716
+ elif x.dtype == tf.float64:
717
+ return ops.cast(x, "float32").numpy()
718
+ elif x.dtype == tf.int64:
719
+ return ops.cast(x, "int32").numpy()
720
+ else:
721
+ return x.numpy() if hasattr(x, "numpy") else x
722
+ elif hasattr(x, "numpy"):
723
+ return x.numpy()
724
+ return x
725
+
726
+ if isinstance(input_data, dict):
727
+ converted_input_data = tree.map_structure(
728
+ convert_for_tflite, input_data
729
+ )
730
+ litert_output = runner(**converted_input_data)
731
+ else:
732
+ # For single tensor inputs, get the input name
733
+ sig_inputs = serving_sig.get("inputs", [])
734
+ input_name = sig_inputs[
735
+ 0
736
+ ] # We verified len(sig_inputs) == 1 above
737
+ converted_input = convert_for_tflite(input_data)
738
+ litert_output = runner(**{input_name: converted_input})
739
+
740
+ # Step 4: Verify outputs
741
+ self._verify_litert_outputs(
742
+ keras_output,
743
+ litert_output,
744
+ sig_outputs,
745
+ expected_output_shape=expected_output_shape,
746
+ verify_numerics=verify_numerics,
747
+ comparison_mode=comparison_mode,
748
+ output_thresholds=output_thresholds,
749
+ )
750
+ finally:
751
+ if interpreter is not None:
752
+ del interpreter
753
+ if model is not None and cls is not None:
754
+ del model
755
+ gc.collect()
756
+
757
+ def _compare_outputs(
758
+ self,
759
+ keras_val,
760
+ litert_val,
761
+ comparison_mode,
762
+ key=None,
763
+ max_threshold=10.0,
764
+ mean_threshold=0.1,
765
+ ):
766
+ """Compare Keras and LiteRT outputs using specified comparison mode.
767
+
768
+ Args:
769
+ keras_val: Keras model output (numpy array)
770
+ litert_val: LiteRT model output (numpy array)
771
+ comparison_mode: "strict" or "statistical"
772
+ key: Output key name for error messages (optional)
773
+ max_threshold: Maximum absolute difference threshold for statistical
774
+ mode
775
+ mean_threshold: Mean absolute difference threshold for statistical
776
+ mode
777
+ """
778
+ key_msg = f" for output key '{key}'" if key else ""
779
+
780
+ # Check if shapes are compatible for comparison
781
+ self.assertEqual(
782
+ keras_val.shape,
783
+ litert_val.shape,
784
+ f"Shape mismatch{key_msg}: Keras shape "
785
+ f"{keras_val.shape}, LiteRT shape {litert_val.shape}. "
786
+ "Numerical comparison cannot proceed due to incompatible shapes.",
787
+ )
788
+
789
+ if comparison_mode == "strict":
790
+ # Original strict element-wise comparison with default tolerances
791
+ self.assertAllClose(
792
+ keras_val,
793
+ litert_val,
794
+ atol=1e-6,
795
+ rtol=1e-6,
796
+ msg=f"Mismatch{key_msg}",
797
+ )
798
+ elif comparison_mode == "statistical":
799
+ # Statistical comparison
800
+
801
+ # Calculate element-wise absolute differences
802
+ abs_diff = np.abs(keras_val - litert_val)
803
+
804
+ # Element-wise statistics
805
+ mean_abs_diff = np.mean(abs_diff)
806
+ max_abs_diff = np.max(abs_diff)
807
+
808
+ # Assert reasonable bounds on statistical differences
809
+ self.assertLessEqual(
810
+ mean_abs_diff,
811
+ mean_threshold,
812
+ f"Mean absolute difference too high: {mean_abs_diff:.6e}"
813
+ f"{key_msg} (threshold: {mean_threshold})",
814
+ )
815
+ self.assertLessEqual(
816
+ max_abs_diff,
817
+ max_threshold,
818
+ f"Max absolute difference too high: {max_abs_diff:.6e}"
819
+ f"{key_msg} (threshold: {max_threshold})",
820
+ )
821
+ else:
822
+ raise ValueError(
823
+ f"Unknown comparison_mode: {comparison_mode}. Must be "
824
+ "'strict' or 'statistical'"
825
+ )
826
+
436
827
  def run_backbone_test(
437
828
  self,
438
829
  cls,
@@ -11,6 +11,7 @@ from typing import Iterable
11
11
 
12
12
  import keras
13
13
  import regex as re
14
+ from keras.src.saving import serialization_lib
14
15
 
15
16
  from keras_hub.src.api_export import keras_hub_export
16
17
  from keras_hub.src.tokenizers import tokenizer
@@ -21,9 +22,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
21
22
 
22
23
  try:
23
24
  import tensorflow as tf
24
- import tensorflow_text as tf_text
25
25
  except ImportError:
26
26
  tf = None
27
+ try:
28
+ import tensorflow_text as tf_text
29
+ except ImportError:
27
30
  tf_text = None
28
31
 
29
32
  VOCAB_FILENAME = "vocabulary.json"
@@ -135,7 +138,13 @@ def split_strings_for_bpe(inputs, unsplittable_tokens=None):
135
138
  return remove_strings_from_inputs(raw_tokens, "६")
136
139
 
137
140
 
138
- class BytePairTokenizerCache(tf.Module if tf is not None else object):
141
+ try:
142
+ _base_class = tf.Module
143
+ except (AttributeError, TypeError):
144
+ _base_class = object
145
+
146
+
147
+ class BytePairTokenizerCache(_base_class):
139
148
  """Cache that stores the encoded result of seen tokens.
140
149
 
141
150
  The cache key is string tensor or python strings, and the value is split
@@ -331,6 +340,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
331
340
  return
332
341
 
333
342
  if isinstance(vocabulary, str):
343
+ if serialization_lib.in_safe_mode():
344
+ raise ValueError(
345
+ "Requested the loading of a vocabulary file outside of the "
346
+ "model archive. This carries a potential risk of loading "
347
+ "arbitrary and sensitive files and thus it is disallowed "
348
+ "by default. If you trust the source of the artifact, you "
349
+ "can override this error by passing `safe_mode=False` to "
350
+ "the loading function, or calling "
351
+ "`keras.config.enable_unsafe_deserialization()`. "
352
+ f"Vocabulary file: '{vocabulary}'"
353
+ )
334
354
  with open(vocabulary, "r", encoding="utf-8") as f:
335
355
  self.vocabulary = json.load(f)
336
356
  elif isinstance(vocabulary, dict):
@@ -342,6 +362,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
342
362
  f"`type(vocabulary)={type(vocabulary)}`."
343
363
  )
344
364
  if isinstance(merges, str):
365
+ if serialization_lib.in_safe_mode():
366
+ raise ValueError(
367
+ "Requested the loading of a merges file outside of the "
368
+ "model archive. This carries a potential risk of loading "
369
+ "arbitrary and sensitive files and thus it is disallowed "
370
+ "by default. If you trust the source of the artifact, you "
371
+ "can override this error by passing `safe_mode=False` to "
372
+ "the loading function, or calling "
373
+ "`keras.config.enable_unsafe_deserialization()`. "
374
+ f"Merges file: '{merges}'"
375
+ )
345
376
  with open(merges, encoding="utf-8") as f:
346
377
  self.merges = [bp.rstrip() for bp in f]
347
378
  elif isinstance(merges, Iterable):
@@ -8,9 +8,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
8
8
 
9
9
  try:
10
10
  import tensorflow as tf
11
- import tensorflow_text as tf_text
12
11
  except ImportError:
13
12
  tf = None
13
+ try:
14
+ import tensorflow_text as tf_text
15
+ except ImportError:
14
16
  tf_text = None
15
17
 
16
18
 
@@ -3,6 +3,7 @@ import binascii
3
3
  import os
4
4
 
5
5
  import keras
6
+ from keras.src.saving import serialization_lib
6
7
 
7
8
  from keras_hub.src.api_export import keras_hub_export
8
9
  from keras_hub.src.tokenizers import tokenizer
@@ -14,9 +15,11 @@ from keras_hub.src.utils.tensor_utils import tensor_to_list
14
15
 
15
16
  try:
16
17
  import tensorflow as tf
17
- import tensorflow_text as tf_text
18
18
  except ImportError:
19
19
  tf = None
20
+ try:
21
+ import tensorflow_text as tf_text
22
+ except ImportError:
20
23
  tf_text = None
21
24
 
22
25
  VOCAB_FILENAME = "vocabulary.spm"
@@ -145,6 +148,17 @@ class SentencePieceTokenizer(tokenizer.Tokenizer):
145
148
  except binascii.Error:
146
149
  pass
147
150
  if not is_base64:
151
+ if serialization_lib.in_safe_mode():
152
+ raise ValueError(
153
+ "Requested the loading of a proto file outside of "
154
+ "the model archive. This carries a potential risk of "
155
+ "loading arbitrary and sensitive files and thus it is "
156
+ "disallowed by default. If you trust the source of the "
157
+ "artifact, you can override this error by passing "
158
+ "`safe_mode=False` to the loading function, or calling "
159
+ "`keras.config.enable_unsafe_deserialization()`. "
160
+ f"Proto file: '{proto}'"
161
+ )
148
162
  proto_bytes = open(proto, "rb").read()
149
163
  elif isinstance(proto, bytes):
150
164
  proto_bytes = proto
@@ -6,9 +6,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function
6
6
 
7
7
  try:
8
8
  import tensorflow as tf
9
- import tensorflow_text as tf_text
10
9
  except ImportError:
11
10
  tf = None
11
+ try:
12
+ import tensorflow_text as tf_text
13
+ except ImportError:
12
14
  tf_text = None
13
15
 
14
16