tf-keras-nightly 2.20.0.dev2025051909__py3-none-any.whl → 2.20.0.dev2025052109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tf_keras/__init__.py CHANGED
@@ -27,4 +27,4 @@ from tf_keras.src.engine.sequential import Sequential
27
27
  from tf_keras.src.engine.training import Model
28
28
 
29
29
 
30
- __version__ = "2.20.0.dev2025051909"
30
+ __version__ = "2.20.0.dev2025052109"
@@ -308,6 +308,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
308
308
  self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
309
309
  ):
310
310
  self._instrument_layer_creation()
311
+ self._called = False
311
312
 
312
313
  # These properties should be set by the user via keyword arguments.
313
314
  # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -326,6 +327,10 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
326
327
  # Validate optional keyword arguments.
327
328
  generic_utils.validate_kwargs(kwargs, allowed_kwargs)
328
329
 
330
+ # Track the built-in call-context arguments. These are arguments that
331
+ # are tracked and propagated across the call-stack by default.
332
+ self._call_context_args = {"training"}
333
+
329
334
  # Mutable properties
330
335
  # Indicates whether the layer's weights are updated during training
331
336
  # and whether the layer's updates are run during training.
@@ -411,6 +416,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
411
416
 
412
417
  self._init_call_fn_args()
413
418
 
419
+ # Track the built-in call-context arguments.
420
+ self._call_spec._update_call_context_arguments(self._call_context_args)
421
+
414
422
  # Whether the `call` method can be used to build a TF graph without
415
423
  # issues. This attribute has no effect if the model is created using
416
424
  # the Functional API. Instead, `model.dynamic` is determined based on
@@ -1042,6 +1050,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1042
1050
  # - input_spec compatibility is only checked against `inputs`
1043
1051
  # - mixed precision casting (autocast) is only applied to `inputs`,
1044
1052
  # not to any other argument.
1053
+ self._called = True
1045
1054
  inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs)
1046
1055
  input_list = tf.nest.flatten(inputs)
1047
1056
 
@@ -1080,17 +1089,21 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1080
1089
  if self._expects_mask_arg and mask_is_implicit:
1081
1090
  kwargs["mask"] = input_masks
1082
1091
 
1083
- # Training mode for `Layer.call` is set via (in order of priority):
1084
- # (1) The `training` argument passed to this `Layer.call`, if it is not
1085
- # None
1086
- # (2) The training mode of an outer `Layer.call`.
1087
- # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if
1088
- # set)
1089
- # (4) Any non-None default value for `training` specified in the call
1092
+ # Call-context arguments for `Layer.call` is set via (in order of
1093
+ # priority):
1094
+ # (1) The argument passed to this `Layer.call`, if it is not None
1095
+ # (2) The argument value of an outer `Layer.call`.
1096
+ # (3) (only for "training") The default mode set by
1097
+ # `tf.keras.backend.set_learning_phase` (if set)
1098
+ # (4) Any non-None default value for the argument specified in the call
1090
1099
  # signature
1091
- # (5) False (treating the layer as if it's in inference)
1092
- args, kwargs, training_mode = self._set_training_mode(
1093
- args, kwargs, call_context
1100
+ # (5) False
1101
+ (
1102
+ args,
1103
+ kwargs,
1104
+ propagated,
1105
+ ) = self._get_propagated_call_context_arguments(
1106
+ args, kwargs, call_context, self._call_context_args
1094
1107
  )
1095
1108
 
1096
1109
  # Losses are cleared for all sublayers on the outermost `Layer.call`.
@@ -1104,7 +1117,7 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1104
1117
  layer=self,
1105
1118
  inputs=inputs,
1106
1119
  build_graph=not eager,
1107
- training=training_mode,
1120
+ call_context_args=propagated,
1108
1121
  ):
1109
1122
  input_spec.assert_input_compatibility(
1110
1123
  self.input_spec, inputs, self.name
@@ -1152,6 +1165,55 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
1152
1165
 
1153
1166
  return outputs
1154
1167
 
1168
+ def _register_call_context_args(self, *argument_names):
1169
+ """Registers call-context args for this layer.
1170
+ If this layer declares a `call()` method that accepts
1171
+ one or more of the given args, those args will be
1172
+ automatically injected into the call signature of this
1173
+ layer. This layer will also propagate the args to any
1174
+ nested sublayers that are called from within this layer.
1175
+ If this layer doesn't declare a `call()` method that
1176
+ accepts one or more of the given args, these args will
1177
+ simply be propagated to any nested sublayers without
1178
+ being injected into the call signature of this layer.
1179
+ This is useful for propagating custom arguments
1180
+ from top-level layers/models to sublayers.
1181
+ Example:
1182
+ ```
1183
+ class Inner(layers.Layer):
1184
+ def __init__(self):
1185
+ super().__init__()
1186
+ # Register `foo_mode` as a call-context arg
1187
+ self._register_call_context_args("foo_mode")
1188
+ def call(self, x, foo_mode=False):
1189
+ # If foo_mode=True add 1, otherwise add 0
1190
+ add_val = ops.where(foo_mode, 1.0, 0.0)
1191
+ return x + add_val
1192
+ class Outer(layers.Layer):
1193
+ def __init__(self):
1194
+ super().__init__()
1195
+ self.inner = Inner()
1196
+ def call(self, x):
1197
+ # We don't explicitly pass foo_mode here—Base Layer.__call__
1198
+ # should inject it into `self.inner`
1199
+ return self.inner(x)
1200
+ sample_input = np.array([[1.0], [2.0]])
1201
+ # Sequential model
1202
+ seq = models.Sequential([Outer()])
1203
+ # Tell the Sequential model to propagate foo_mode down
1204
+ # the call-stack
1205
+ seq.register_call_context_args("foo_mode")
1206
+ # foo_mode=True -> input + 1
1207
+ out_true = seq(sample_input, foo_mode=True)
1208
+ """
1209
+ if self._called:
1210
+ raise RuntimeError(
1211
+ "Cannot add call-context args after the layer has been called."
1212
+ )
1213
+ self._call_context_args |= set(argument_names)
1214
+ self._call_spec._update_call_context_arguments(argument_names)
1215
+ self._call_spec._update_call_context_argument_defaults(argument_names)
1216
+
1155
1217
  def _get_unnested_name_scope(self):
1156
1218
  if _is_name_scope_on_model_declaration_enabled:
1157
1219
  with _name_scope_unnester(
@@ -2535,47 +2597,57 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
2535
2597
  kwargs["mask"] = input_masks
2536
2598
  mask_arg_passed_by_framework = True
2537
2599
 
2538
- # If `training` argument is None or not explicitly passed,
2539
- # propagate `training` value from this layer's calling layer.
2540
- training_value = None
2541
- training_arg_passed_by_framework = False
2542
- # Priority 1: `training` was explicitly passed a non-None value.
2543
- if self._call_spec.arg_was_passed("training", args, kwargs):
2544
- training_value = self._call_spec.get_arg_value(
2545
- "training", args, kwargs
2546
- )
2547
- if not self._expects_training_arg:
2548
- kwargs.pop("training")
2549
-
2550
- if training_value is None:
2551
- # Priority 2: `training` was passed to a parent layer.
2552
- if call_context.training is not None:
2553
- training_value = call_context.training
2554
- # Priority 3: `learning_phase()` has been set.
2555
- elif backend.global_learning_phase_is_set():
2556
- training_value = backend.learning_phase()
2557
- # Force the training_value to be bool type which matches to the
2558
- # contract for layer/model call args.
2559
- if tf.is_tensor(training_value):
2560
- training_value = tf.cast(training_value, tf.bool)
2600
+ propagated = dict()
2601
+ args_passed_by_framework = dict()
2602
+ for context_arg in self._call_context_args:
2603
+ # If `training` argument is None or not explicitly passed,
2604
+ # propagate `training` value from this layer's calling layer.
2605
+ value = None
2606
+ args_passed_by_framework[context_arg] = False
2607
+ # Priority 1: `training` was explicitly passed a non-None value.
2608
+ if self._call_spec.arg_was_passed(context_arg, args, kwargs):
2609
+ value = self._call_spec.get_arg_value(context_arg, args, kwargs)
2610
+ if not self._expects_context_arg(context_arg):
2611
+ kwargs.pop(context_arg)
2612
+
2613
+ if value is None:
2614
+ # Priority 2: `training` was passed to a parent layer.
2615
+ if call_context.get_call_context_arg(context_arg) is not None:
2616
+ value = call_context.get_call_context_arg(context_arg)
2617
+ # Priority 3: `learning_phase()` has been set.
2618
+ elif (
2619
+ context_arg == "training"
2620
+ and backend.global_learning_phase_is_set()
2621
+ ):
2622
+ value = backend.learning_phase()
2623
+ # Force the training_value to be bool type which matches to
2624
+ # the contract for layer/model call args.
2625
+ if tf.is_tensor(value):
2626
+ value = tf.cast(value, tf.bool)
2627
+ else:
2628
+ value = bool(value)
2629
+ # Priority 4: trace layer with the default training argument
2630
+ # specified in the `call` signature (or in inference mode if the
2631
+ # `call` signature specifies no non-None default).
2561
2632
  else:
2562
- training_value = bool(training_value)
2563
- # Priority 4: trace layer with the default training argument
2564
- # specified in the `call` signature (or in inference mode if the
2565
- # `call` signature specifies no non-None default).
2566
- else:
2567
- training_value = self._call_spec.default_training_arg
2568
- # In cases (2), (3), (4) the training argument is passed
2569
- # automatically by the framework, and will not be hard-coded into
2570
- # the model.
2571
- if self._expects_training_arg:
2572
- args, kwargs = self._call_spec.set_arg_value(
2573
- "training", training_value, args, kwargs
2574
- )
2575
- training_arg_passed_by_framework = True
2633
+ value = self._call_spec.get_context_arg_default(context_arg)
2634
+ # In cases (2), (3), (4) the training argument is passed
2635
+ # automatically by the framework, and will not be hard-coded
2636
+ # into the model.
2637
+ if self._expects_context_arg(context_arg):
2638
+ args, kwargs = self._call_spec.set_arg_value(
2639
+ context_arg, value, args, kwargs
2640
+ )
2641
+ args_passed_by_framework[context_arg] = True
2642
+
2643
+ if value is not None:
2644
+ propagated[context_arg] = value
2576
2645
 
2577
2646
  with call_context.enter(
2578
- layer=self, inputs=inputs, build_graph=True, training=training_value
2647
+ layer=self,
2648
+ inputs=inputs,
2649
+ build_graph=True,
2650
+ call_context_args=propagated,
2579
2651
  ):
2580
2652
  # Check input assumptions set after layer building, e.g. input
2581
2653
  # shape.
@@ -2601,10 +2673,13 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
2601
2673
  "Tensor or a list of Tensors, not None "
2602
2674
  "(layer: " + self.name + ")."
2603
2675
  )
2604
- if training_arg_passed_by_framework:
2605
- args, kwargs = self._call_spec.set_arg_value(
2606
- "training", None, args, kwargs, pop_kwarg_if_none=True
2607
- )
2676
+
2677
+ for context_arg, is_passed in args_passed_by_framework.items():
2678
+ if is_passed:
2679
+ args, kwargs = self._call_spec.set_arg_value(
2680
+ context_arg, None, args, kwargs, pop_kwarg_if_none=True
2681
+ )
2682
+
2608
2683
  if mask_arg_passed_by_framework:
2609
2684
  kwargs.pop("mask")
2610
2685
  # Node connectivity does not special-case the first argument.
@@ -2613,52 +2688,100 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
2613
2688
  )
2614
2689
  return outputs
2615
2690
 
2616
- def _set_training_mode(self, args, kwargs, call_context):
2617
- training_mode = None
2618
- if self._expects_training_arg:
2619
- # (1) `training` was passed to this `Layer.call`.
2620
- if self._call_spec.arg_was_passed("training", args, kwargs):
2621
- training_mode = self._call_spec.get_arg_value(
2622
- "training", args, kwargs
2623
- )
2624
- # If no `training` arg was passed, or `None` was explicitly passed,
2625
- # the framework will make a decision about the training mode is.
2626
- if training_mode is None:
2627
- call_ctx_training = call_context.training
2628
- # (2) `training` mode is inferred from an outer `Layer.call`.
2629
- if call_ctx_training is not None:
2630
- training_mode = call_ctx_training
2631
- # (3) User set `tf.keras.backend.set_learning_phase`.
2632
- elif backend.global_learning_phase_is_set():
2633
- training_mode = backend.learning_phase()
2634
- # Ensure value is a `bool` or `tf.bool`.
2635
- if isinstance(training_mode, bool):
2636
- pass
2637
- elif tf.is_tensor(training_mode):
2638
- training_mode = tf.cast(training_mode, tf.bool)
2691
+ def _get_propagated_call_context_arguments(
2692
+ self, args, kwargs, call_context, local_call_context_arguments
2693
+ ):
2694
+ """Resolves the values for propagated call context arguments for the
2695
+ current layer.
2696
+
2697
+ Args:
2698
+ args: The arguments passed to the current layer's `call` method.
2699
+ kwargs: The keyword arguments passed to the current layer's `call`
2700
+ method.
2701
+ call_context: The `CallContext` for the current call-stack.
2702
+ local_call_context_arguments: The call-context arguments registered
2703
+ with to the current layer's `Layer.call` method.
2704
+
2705
+ Returns:
2706
+ A tuple of the following:
2707
+ 1. Updated args
2708
+ 2. Updated kwargs
2709
+ 3. A dictionary of the resolved call-context arguments that should
2710
+ be propagated to the next layer in the call-stack.
2711
+ """
2712
+ propagated_context = dict()
2713
+ relevant_arguments = call_context.call_context_args.keys() | set(
2714
+ local_call_context_arguments
2715
+ )
2716
+
2717
+ for argument in relevant_arguments:
2718
+ authoritative_value = None
2719
+ was_explicitly_passed = self._call_spec.arg_was_passed(
2720
+ argument, args, kwargs
2721
+ )
2722
+ if self._expects_context_arg(argument):
2723
+ # (1) `arg_name` was passed to this `Layer.call`.
2724
+ if was_explicitly_passed:
2725
+ authoritative_value = self._call_spec.get_arg_value(
2726
+ argument, args, kwargs
2727
+ )
2728
+ # If no `arg_name` arg was passed, or `None` was explicitly
2729
+ # passed, the framework will make a decision about the training
2730
+ # mode is.
2731
+ if authoritative_value is None:
2732
+ value_from_context = call_context.get_call_context_arg(
2733
+ argument
2734
+ )
2735
+ # (2) `arg_name` mode is inferred from an outer
2736
+ # `Layer.call`.
2737
+ if value_from_context is not None:
2738
+ authoritative_value = value_from_context
2739
+ # (3) User set `tf.keras.backend.set_learning_phase`.
2740
+ elif (
2741
+ argument == "training"
2742
+ and backend.global_learning_phase_is_set()
2743
+ ):
2744
+ authoritative_value = backend.learning_phase()
2745
+ # Ensure value is a `bool` or `tf.bool`.
2746
+ if isinstance(authoritative_value, bool):
2747
+ pass
2748
+ elif tf.is_tensor(authoritative_value):
2749
+ authoritative_value = tf.cast(
2750
+ authoritative_value, tf.bool
2751
+ )
2752
+ else:
2753
+ authoritative_value = bool(authoritative_value)
2754
+ # (4) We default to using `call`'s default value for
2755
+ # `arg_name`, or treating the layer as if it is in inference
2756
+ # if no non-None default is specified in the `call`
2757
+ # signature.
2639
2758
  else:
2640
- training_mode = bool(training_mode)
2641
- # (4) We default to using `call`'s default value for `training`,
2642
- # or treating the layer as if it is in inference if no non-None
2643
- # default is specified in the `call` signature.
2644
- else:
2645
- training_mode = self._call_spec.default_training_arg
2759
+ authoritative_value = (
2760
+ self._call_spec.get_context_arg_default(argument)
2761
+ )
2646
2762
 
2647
- # For case (2), (3), (4) `training` arg is passed by framework.
2648
- args, kwargs = self._call_spec.set_arg_value(
2649
- "training", training_mode, args, kwargs
2650
- )
2651
- else:
2652
- if "training" in kwargs:
2653
- # `training` was passed to this `Layer` but is not needed for
2654
- # `Layer.call`. It will set the default mode for inner
2655
- # `Layer.call`s.
2656
- training_mode = kwargs.pop("training")
2763
+ # For case (2), (3), (4) `arg_name` arg is passed by
2764
+ # framework.
2765
+ args, kwargs = self._call_spec.set_arg_value(
2766
+ argument, authoritative_value, args, kwargs
2767
+ )
2657
2768
  else:
2658
- # Grab the current `training` mode from any outer `Layer.call`.
2659
- training_mode = call_context.training
2769
+ if argument in kwargs:
2770
+ # `arg_name` was passed to this `Layer` but is not needed
2771
+ # for `Layer.call`. It will set the default mode for inner
2772
+ # `Layer.call`s.
2773
+ authoritative_value = kwargs.pop(argument)
2774
+ else:
2775
+ # Grab the current `arg_name` mode from any outer
2776
+ # `Layer.call`.
2777
+ authoritative_value = call_context.get_call_context_arg(
2778
+ argument
2779
+ )
2660
2780
 
2661
- return args, kwargs, training_mode
2781
+ if authoritative_value is not None:
2782
+ propagated_context[argument] = authoritative_value
2783
+
2784
+ return args, kwargs, propagated_context
2662
2785
 
2663
2786
  def _autographed_call(self):
2664
2787
  # Wrapping `call` function in autograph to allow for dynamic control
@@ -3351,7 +3474,8 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
3351
3474
 
3352
3475
  def _init_call_fn_args(self, expects_training_arg=None):
3353
3476
  self._call_spec = layer_utils.CallFunctionSpec(
3354
- tf_inspect.getfullargspec(self.call)
3477
+ tf_inspect.getfullargspec(self.call),
3478
+ getattr(self, "_call_context_args", set()),
3355
3479
  )
3356
3480
  if expects_training_arg is not None:
3357
3481
  self._call_spec.expects_training_arg = expects_training_arg
@@ -3361,6 +3485,9 @@ class Layer(tf.Module, version_utils.LayerVersionSelector):
3361
3485
  """Whether the call function uses 'training' as a parameter."""
3362
3486
  return self._call_spec.expects_training_arg
3363
3487
 
3488
+ def _expects_context_arg(self, argument_name):
3489
+ return argument_name in self._call_spec.expected_context_args
3490
+
3364
3491
  @property
3365
3492
  def _expects_mask_arg(self):
3366
3493
  return self._call_spec.expects_mask_arg
@@ -480,7 +480,8 @@ class CallContext:
480
480
  layer: The `Layer` whose `call` is currently active.
481
481
  inputs: The inputs to the currently active `Layer`.
482
482
  build_graph: Whether currently inside a Graph or FuncGraph.
483
- training: Whether currently executing in training or inference mode.
483
+ call_context_args: The call-context arguments being propagated through the
484
+ the call-stack.
484
485
  saving: Whether currently saving to SavedModel.
485
486
  frozen: Whether currently executing inside a `Layer` with `trainable` set
486
487
  to `False`.
@@ -495,6 +496,7 @@ class CallContext:
495
496
  "layer": None,
496
497
  "inputs": None,
497
498
  "build_graph": False,
499
+ "call_context_args": dict(),
498
500
  "training": None,
499
501
  "saving": None,
500
502
  }
@@ -502,14 +504,17 @@ class CallContext:
502
504
  # refactor.
503
505
  self._in_keras_graph = False
504
506
 
505
- def enter(self, layer, inputs, build_graph, training, saving=None):
507
+ def enter(
508
+ self, layer, inputs, build_graph, call_context_args=dict(), saving=None
509
+ ):
506
510
  """Push a Layer and its inputs and state onto the current call context.
507
511
 
508
512
  Args:
509
513
  layer: The `Layer` whose `call` is currently active.
510
514
  inputs: The inputs to the currently active `Layer`.
511
515
  build_graph: Whether currently inside a Graph or FuncGraph.
512
- training: Whether currently executing in training or inference mode.
516
+ call_context_args: The call-context arguments being propagated through
517
+ the call-stack.
513
518
  saving: Whether currently saving to SavedModel.
514
519
 
515
520
  Returns:
@@ -519,7 +524,7 @@ class CallContext:
519
524
  "layer": layer,
520
525
  "inputs": inputs,
521
526
  "build_graph": build_graph,
522
- "training": training,
527
+ "call_context_args": call_context_args,
523
528
  "saving": saving,
524
529
  }
525
530
  return CallContextManager(self, state)
@@ -538,7 +543,14 @@ class CallContext:
538
543
 
539
544
  @property
540
545
  def training(self):
541
- return self._state["training"]
546
+ return self.call_context_args.get("training", None)
547
+
548
+ @property
549
+ def call_context_args(self):
550
+ return self._state["call_context_args"]
551
+
552
+ def get_call_context_arg(self, arg_name):
553
+ return self.call_context_args.get(arg_name, None)
542
554
 
543
555
  @property
544
556
  def saving(self):
@@ -132,6 +132,7 @@ class Layer(base_layer.Layer):
132
132
  self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
133
133
  ):
134
134
  self._instrument_layer_creation()
135
+ self._called = False
135
136
 
136
137
  # These properties should be set by the user via keyword arguments.
137
138
  # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -165,6 +166,8 @@ class Layer(base_layer.Layer):
165
166
  self._input_spec = None
166
167
  self.supports_masking = False
167
168
 
169
+ self._call_context_args = {"training"}
170
+
168
171
  self._init_set_name(name)
169
172
  self._activity_regularizer = regularizers.get(
170
173
  kwargs.pop("activity_regularizer", None)
@@ -705,6 +708,7 @@ class Layer(base_layer.Layer):
705
708
  RuntimeError: if `super().__init__()` was not called in the
706
709
  constructor.
707
710
  """
711
+ self._called = True
708
712
  self._assert_built_as_v1()
709
713
 
710
714
  if not hasattr(self, "_thread_local"):
@@ -803,7 +807,12 @@ class Layer(base_layer.Layer):
803
807
  if build_graph and base_layer_utils.needs_keras_history(inputs):
804
808
  base_layer_utils.create_keras_history(inputs)
805
809
 
806
- with call_context.enter(self, inputs, build_graph, training_value):
810
+ with call_context.enter(
811
+ self,
812
+ inputs,
813
+ build_graph,
814
+ call_context_args={"training": training_value},
815
+ ):
807
816
  # Check input assumptions set after layer building, e.g. input
808
817
  # shape.
809
818
  if build_graph:
@@ -259,6 +259,10 @@ class TFOpLambda(Layer):
259
259
 
260
260
  self._call_spec.expects_training_arg = False
261
261
  self._call_spec.expects_mask_arg = False
262
+ # Clear the call-context arguments for the layer's call method.
263
+ # Otherwise, Keras ends up injecting context arguments into the op-call
264
+ # when the call method accepts kwargs.
265
+ self._call_spec._expected_context_args.clear()
262
266
 
263
267
  def _call_wrapper(self, *args, **kwargs):
264
268
  created_variables = []
@@ -52,9 +52,21 @@ class _RNNCellWrapper(AbstractRNNCell):
52
52
  super().__init__(*args, **kwargs)
53
53
  self.cell = cell
54
54
  cell_call_spec = tf_inspect.getfullargspec(cell.call)
55
+ accepts_kwargs = cell_call_spec.varkw is not None
56
+
55
57
  self._call_spec.expects_training_arg = (
56
58
  "training" in cell_call_spec.args
57
- ) or (cell_call_spec.varkw is not None)
59
+ ) or accepts_kwargs
60
+
61
+ # Filter _expects_context_arg. An argument is kept if:
62
+ # 1. It's an explicit argument in cell_call_spec.args OR
63
+ # 2. The cell accepts arbitrary keyword arguments (**kwargs),
64
+ # meaning it could potentially handle the context argument.
65
+ self._call_spec._expected_context_args = {
66
+ arg
67
+ for arg in self._call_spec._expected_context_args
68
+ if (arg in cell_call_spec.args) or accepts_kwargs
69
+ }
58
70
 
59
71
  def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
60
72
  """Calls the wrapped cell and performs the wrapping logic.
@@ -219,7 +219,11 @@ def wrap_layer_functions(layer, serialization_cache):
219
219
  with tracing_scope():
220
220
  call_collection.trace_with_input_signature()
221
221
  with base_layer_utils.call_context().enter(
222
- layer, inputs=None, build_graph=True, training=None, saving=True
222
+ layer,
223
+ inputs=None,
224
+ build_graph=True,
225
+ call_context_args={},
226
+ saving=True,
223
227
  ):
224
228
  for fn in fns.values():
225
229
  if fn is not None and not isinstance(fn, LayerCall):
@@ -515,19 +519,28 @@ class LayerCallCollection:
515
519
  else:
516
520
  add_trace_to_queue(fn, args, kwargs)
517
521
 
518
- def training_arg_was_passed(self, args, kwargs):
522
+ def arg_was_passed(self, arg_name, args, kwargs):
523
+ """Returns True if the argument was passed to the call function."""
519
524
  return self._call_spec.arg_was_passed(
520
- "training", args, kwargs, inputs_in_args=True
525
+ arg_name, args, kwargs, inputs_in_args=True
521
526
  )
522
527
 
523
- def get_training_arg_value(self, args, kwargs):
528
+ def training_arg_was_passed(self, args, kwargs):
529
+ """Returns True if the training arg was passed to the call function."""
530
+ return self.arg_was_passed("training", args, kwargs)
531
+
532
+ def get_arg_value(self, arg_name, args, kwargs):
533
+ """Returns the value of the given argument or None if not found."""
524
534
  try:
525
535
  return self._call_spec.get_arg_value(
526
- "training", args, kwargs, inputs_in_args=True
536
+ arg_name, args, kwargs, inputs_in_args=True
527
537
  )
528
- except KeyError: # Training is not in args or kwargs.
538
+ except KeyError: # Arg not found in args or kwargs.
529
539
  return None
530
540
 
541
+ def get_training_arg_value(self, args, kwargs):
542
+ return self.get_arg_value("training", args, kwargs)
543
+
531
544
  def get_input_arg_value(self, args, kwargs):
532
545
  return self._call_spec.get_arg_value(
533
546
  self._input_arg_name, args, kwargs, inputs_in_args=True
@@ -613,20 +626,23 @@ def layer_call_wrapper(call_collection, method, name):
613
626
  def wrapper(*args, **kwargs):
614
627
  """Calls method within call context."""
615
628
  layer = call_collection.layer
616
- training = None
629
+ propagated = {"training": None}
617
630
  inputs = _filtered_inputs([args, kwargs])
618
631
 
619
- if (args or kwargs) and call_collection.training_arg_was_passed(
620
- args, kwargs
621
- ):
622
- training = call_collection.get_training_arg_value(args, kwargs)
632
+ for context_arg in layer._call_context_args:
633
+ if (args or kwargs) and call_collection.arg_was_passed(
634
+ context_arg, args, kwargs
635
+ ):
636
+ propagated[context_arg] = call_collection.get_arg_value(
637
+ context_arg, args, kwargs
638
+ )
623
639
 
624
640
  original_losses = _reset_layer_losses(layer)
625
641
  with base_layer_utils.call_context().enter(
626
642
  layer,
627
643
  inputs=inputs,
628
644
  build_graph=False,
629
- training=training,
645
+ call_context_args=propagated,
630
646
  saving=True,
631
647
  ):
632
648
  with autocast_variable.enable_auto_cast_variables(
@@ -138,12 +138,24 @@ def trace_model_call(model, input_signature=None):
138
138
  @tf.function
139
139
  def _wrapped_model(*args, **kwargs):
140
140
  """A concrete tf.function that wraps the model's call function."""
141
+ call_context = base_layer_utils.call_context()
142
+
143
+ args, kwargs, propagated = model._get_propagated_call_context_arguments(
144
+ args, kwargs, call_context, model._call_context_args
145
+ )
146
+
141
147
  (args, kwargs,) = model._call_spec.set_arg_value(
142
148
  "training", False, args, kwargs, inputs_in_args=True
143
149
  )
144
150
 
145
- with base_layer_utils.call_context().enter(
146
- model, inputs=None, build_graph=False, training=False, saving=True
151
+ propagated["training"] = False
152
+
153
+ with call_context.enter(
154
+ model,
155
+ inputs=None,
156
+ build_graph=False,
157
+ call_context_args=propagated,
158
+ saving=True,
147
159
  ):
148
160
  outputs = model(*args, **kwargs)
149
161
 
@@ -775,11 +775,13 @@ class CallFunctionSpec:
775
775
  """Caches the spec and provides utilities for handling call function
776
776
  args."""
777
777
 
778
- def __init__(self, full_argspec):
778
+ def __init__(self, full_argspec, call_context_args=set()):
779
779
  """Initialies a `CallFunctionSpec`.
780
780
 
781
781
  Args:
782
782
  full_argspec: the FullArgSpec of a call function of a layer.
783
+ call_context_args: The set of call-context arguments registered
784
+ with to the current layer.
783
785
  """
784
786
  self._full_argspec = full_argspec
785
787
 
@@ -797,6 +799,18 @@ class CallFunctionSpec:
797
799
  "mask" in self._arg_names or call_accepts_kwargs
798
800
  )
799
801
 
802
+ # Track the set of call-context arguments that the current layer's
803
+ # `call` method accepts.
804
+ self._expected_context_args = set()
805
+ self._update_call_context_arguments(call_context_args)
806
+
807
+ self._context_arg_defaults = dict()
808
+ self._update_call_context_argument_defaults(call_context_args)
809
+
810
+ def _update_call_context_argument_defaults(self, context_args):
811
+ """Updates the set of call-context argument defaults for the current
812
+ layer's `call` method.
813
+ """
800
814
  call_fn_defaults = self._full_argspec.defaults or []
801
815
  defaults = dict()
802
816
  # The call arg defaults are an n-tuple of the last n elements of the
@@ -806,7 +820,21 @@ class CallFunctionSpec:
806
820
  # The default training arg will be any (non-None) default specified in
807
821
  # the method signature, or None if no value is specified.
808
822
  defaults.update(self._full_argspec.kwonlydefaults or {})
809
- self._default_training_arg = defaults.get("training")
823
+
824
+ for arg in context_args:
825
+ self._context_arg_defaults[arg] = defaults.get(arg)
826
+
827
+ def _update_call_context_arguments(self, context_args):
828
+ """Updates the set of call-context arguments that the current layer's
829
+ `call` method accepts.
830
+ """
831
+ call_accepts_kwargs = self._full_argspec.varkw is not None
832
+ args_to_add = {
833
+ arg
834
+ for arg in context_args
835
+ if call_accepts_kwargs or arg in self._arg_names
836
+ }
837
+ self._expected_context_args.update(args_to_add)
810
838
 
811
839
  @property
812
840
  def full_argspec(self):
@@ -843,6 +871,16 @@ class CallFunctionSpec:
843
871
  def expects_training_arg(self, value):
844
872
  self._expects_training_arg = value
845
873
 
874
+ @property
875
+ def expected_context_args(self):
876
+ """The set of call-context arguments that the current layer's
877
+ `call` method accepts."""
878
+ return self._expected_context_args
879
+
880
+ @expected_context_args.setter
881
+ def expected_context_args(self, value):
882
+ self._expected_context_args = value
883
+
846
884
  @property
847
885
  def expects_mask_arg(self):
848
886
  """Whether the call function uses `mask` as a parameter."""
@@ -855,7 +893,11 @@ class CallFunctionSpec:
855
893
  @property
856
894
  def default_training_arg(self):
857
895
  """The default value given to the "training" argument."""
858
- return self._default_training_arg
896
+ return self.get_context_arg_default("training")
897
+
898
+ def get_context_arg_default(self, arg_name):
899
+ """The default value given to the call context arguments."""
900
+ return self._context_arg_defaults.get(arg_name, None)
859
901
 
860
902
  def arg_was_passed(self, arg_name, args, kwargs, inputs_in_args=False):
861
903
  """Returns true if argument is present in `args` or `kwargs`.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tf_keras-nightly
3
- Version: 2.20.0.dev2025051909
3
+ Version: 2.20.0.dev2025052109
4
4
  Summary: Deep learning for humans.
5
5
  Home-page: https://keras.io/
6
6
  Download-URL: https://github.com/keras-team/tf-keras/tags
@@ -1,4 +1,4 @@
1
- tf_keras/__init__.py,sha256=MeElv0cYC0nUrQ5IYpomZsYZi3a11G5xFSbPyYPz0e8,911
1
+ tf_keras/__init__.py,sha256=pDkm8tudOxPPAmZy23qJxL-G-qVsy58igxL1yOF3aPM,911
2
2
  tf_keras/__internal__/__init__.py,sha256=OHQbeIC0QtRBI7dgXaJaVbH8F00x8dCI-DvEcIfyMsE,671
3
3
  tf_keras/__internal__/backend/__init__.py,sha256=LnMs2A6685gDG79fxqmdulIYlVE_3WmXlBTBo9ZWYcw,162
4
4
  tf_keras/__internal__/layers/__init__.py,sha256=F5SGMhOTPzm-PR44VrfinURHcVeQPIEdwnZlAkSTB3A,176
@@ -271,9 +271,9 @@ tf_keras/src/dtensor/lazy_variable.py,sha256=c3yylbga0se3Geflutss3fz5RzBYuY2vkU3
271
271
  tf_keras/src/dtensor/test_util.py,sha256=9QAbt44mlirdqwG2ertTsoXNKG2V4Z0bqJFxGdxy5BY,4572
272
272
  tf_keras/src/dtensor/utils.py,sha256=2TTSCEOA61Ia1FAPfQWJ2CRfiocBGUZreXH9UBFzFbk,6441
273
273
  tf_keras/src/engine/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
274
- tf_keras/src/engine/base_layer.py,sha256=H7TK3ezXORn1B79TZO3klFGE_hmXHAauS3h5k_8xJvA,156553
275
- tf_keras/src/engine/base_layer_utils.py,sha256=YMJF5sZJhFF_yzfqOtqi4YTsyUE2ZQ_cJJOIdXnuS2w,35795
276
- tf_keras/src/engine/base_layer_v1.py,sha256=cX-OCSNio3Tr2M6twr_PUgKulePLZDNW_4xLXjgYbN4,102700
274
+ tf_keras/src/engine/base_layer.py,sha256=cCSNR0ZMZBPRzpq6bByylNm5rwA8KzvIf16vKaXMpjE,161974
275
+ tf_keras/src/engine/base_layer_utils.py,sha256=AFjqwXM-WShf0dfsyIotlXYIRJlqYyjQhAf50xZgyos,36166
276
+ tf_keras/src/engine/base_layer_v1.py,sha256=V3Jyip5kiDyN-lwHAThS368iKxpznSw4tj4mwfWLobc,102896
277
277
  tf_keras/src/engine/base_preprocessing_layer.py,sha256=xne5VVtj9_IE1_cjh-kaPk-utoMY7mYwTOcgybFfY34,12650
278
278
  tf_keras/src/engine/compile_utils.py,sha256=F6KxbaXnppns5XCOJl8wzsiQ1riEp43s0G0SWsWAUE0,31757
279
279
  tf_keras/src/engine/data_adapter.py,sha256=UqYJBUDiS-vyu7euVYxQrXw0U9-piO7SwTetkGBSMwg,71654
@@ -343,7 +343,7 @@ tf_keras/src/layers/core/embedding.py,sha256=iOdkBiP1IzwOVPjsKWA54NXrlk5KgJ0DfQ8
343
343
  tf_keras/src/layers/core/identity.py,sha256=yj5cWlUTlYq_J_ZQb1iLzM0bqaM4V6TXVwM4iuBFp9U,1301
344
344
  tf_keras/src/layers/core/lambda_layer.py,sha256=QzetX-lV9ybonQKg_6QzSm8w9Vkq8CPAM4BcAke7CZk,16481
345
345
  tf_keras/src/layers/core/masking.py,sha256=19p6HYGlKdUfQnelsAoee6wf87fWx67NSGinyjagNc4,3340
346
- tf_keras/src/layers/core/tf_op_layer.py,sha256=4WDRrT8dVwnD7avcWvMCk9mnGwfHcaN3Dmhf7CBeqzQ,21066
346
+ tf_keras/src/layers/core/tf_op_layer.py,sha256=R6dFECVkPbmKi1nQVcxJy5lNxSVwiMlaWXB7j0PjI7Q,21320
347
347
  tf_keras/src/layers/experimental/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
348
348
  tf_keras/src/layers/experimental/dynamic_embedding.py,sha256=KuVIawm3avPEa5c2IDOyBH14xiU5bYbPqcm_HugfWYA,10730
349
349
  tf_keras/src/layers/experimental/dynamic_lookup.py,sha256=CMNOaxAIkB1ChPcusuymhLAYTvobEbCBli6YkuWw8RE,13720
@@ -454,7 +454,7 @@ tf_keras/src/layers/rnn/base_cudnn_rnn.py,sha256=cuPVg6r4L1pVWYTp3WFbJhikuIR2Vmg
454
454
  tf_keras/src/layers/rnn/base_rnn.py,sha256=I7mWl4KQC26gILDt9pZ9moZ81yM57lvci6hzJ9ROrxo,41968
455
455
  tf_keras/src/layers/rnn/base_wrapper.py,sha256=x4GANiXtmh9ztAFh7QtfbnQE76UVCGpaHp_XhrSs0Os,3159
456
456
  tf_keras/src/layers/rnn/bidirectional.py,sha256=JyZuBU0q2lt4augThwm8vyTvYwEJxyawsHmgNIul5vU,22670
457
- tf_keras/src/layers/rnn/cell_wrappers.py,sha256=T3FIiY9vIr0Or1N_SWNVnHR3LH6xnJ5DgNNYLk-sV6c,26874
457
+ tf_keras/src/layers/rnn/cell_wrappers.py,sha256=fMGpdFFoRWRIuKz88NcnMvAtevv8OYHzxkF86Ltmwfk,27384
458
458
  tf_keras/src/layers/rnn/conv_lstm1d.py,sha256=suShze6ipNXabGlKJTxkOia17ZP4SeEei3Mi4F8lFOQ,8761
459
459
  tf_keras/src/layers/rnn/conv_lstm2d.py,sha256=myxOioB3yNn0L_-gMh0R41sb-MwTXO993lAT05_N0Zw,8874
460
460
  tf_keras/src/layers/rnn/conv_lstm3d.py,sha256=GT4OoPFtCr5xgaaqy3ezt5DyDu8Ut-wQEihCOHFk0D4,8969
@@ -544,7 +544,7 @@ tf_keras/src/saving/legacy/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74
544
544
  tf_keras/src/saving/legacy/hdf5_format.py,sha256=IqFXHN96fuqKwu_akaqTyf9ISRPavP3Ahjydat948O4,42438
545
545
  tf_keras/src/saving/legacy/model_config.py,sha256=ZE6H_dKdmo2dlBWkr2nYO8SXcMEhshgza3sHPCpeu-k,4140
546
546
  tf_keras/src/saving/legacy/save.py,sha256=TdjiEamZ8MAsPAWsYMEtrdCRppbHcBIwJh9eVfdUS3k,23612
547
- tf_keras/src/saving/legacy/saving_utils.py,sha256=VTxnYFSWZ7m_40deANuWbtyhyXq0o0c5vLuBeCbgwi8,13745
547
+ tf_keras/src/saving/legacy/saving_utils.py,sha256=0iXchqZQNw9s5kB9_7SIj2p3Qd_21jdfHJQ3b3YVWQs,14042
548
548
  tf_keras/src/saving/legacy/serialization.py,sha256=OrmHQPolQFsR-UCMxNTxkIFTKY4DKcAgMm1jdhF7TqU,22285
549
549
  tf_keras/src/saving/legacy/saved_model/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
550
550
  tf_keras/src/saving/legacy/saved_model/base_serialization.py,sha256=dALR19_zt4c80zVw3yjCj9wfRoJufDjCrvkJyS82Dnk,5104
@@ -558,7 +558,7 @@ tf_keras/src/saving/legacy/saved_model/model_serialization.py,sha256=IxQ1TfBGagV
558
558
  tf_keras/src/saving/legacy/saved_model/network_serialization.py,sha256=ofbKN9V3syw0AQebgy2PlvaiAHi3SnBFTg-PUgclTng,1180
559
559
  tf_keras/src/saving/legacy/saved_model/order_preserving_set.py,sha256=zvNFzss8wSc0vngv74dNnQO_hxpxmEWWBBv1TTLsbPY,3250
560
560
  tf_keras/src/saving/legacy/saved_model/save.py,sha256=2-AaGFhFxzfZLkIW1qx9-rTcaZvYMFkQYP7ijfwA-ZI,6395
561
- tf_keras/src/saving/legacy/saved_model/save_impl.py,sha256=cWDBJ0uYfV1dNB5pdKFqDwYLQaaxWNeG8GSLiABAmec,29751
561
+ tf_keras/src/saving/legacy/saved_model/save_impl.py,sha256=mcdNPwJYwzOsdSisgwkBEbnoSABEZdshN7BRmTttK2c,30420
562
562
  tf_keras/src/saving/legacy/saved_model/serialized_attributes.py,sha256=nlmtIzLUBGSQU6gDKcg4-ypSRX3RbS4vPmLIhG3HSbk,15009
563
563
  tf_keras/src/saving/legacy/saved_model/utils.py,sha256=2OCwun0U8nsZvxUbv7Toq2EeC1HU32LxnLDan8cw4Dc,9953
564
564
  tf_keras/src/testing_infra/__init__.py,sha256=yrmnTOUMQ09fOgD3PD4NjpaeKz2OXCUmmoExRWhg9AY,690
@@ -585,7 +585,7 @@ tf_keras/src/utils/io_utils.py,sha256=XhCTkjwtfBc2hWSenjVdt0-2PsIc2bjJVWEP1880NU
585
585
  tf_keras/src/utils/keras_logging.py,sha256=Fv4eOMemx3Jg1hEdHIxx9GblG5YTnW1q1D1zLF3JxUE,882
586
586
  tf_keras/src/utils/kernelized_utils.py,sha256=s475SAos2zHQ1NT9AHZmbWUSahHKOhdctP6uIou0nRo,4517
587
587
  tf_keras/src/utils/kpl_test_utils.py,sha256=vnaJkySSTVhXsFEdDxNJArwXaah0yPNTK8o_3rYZvOE,7365
588
- tf_keras/src/utils/layer_utils.py,sha256=5SGCXE5Tc8QmCt7JHTbpVjYBkEMtdI7TQ5jC-9VLnCY,41764
588
+ tf_keras/src/utils/layer_utils.py,sha256=cLKqiqJ2em16zZyXaXFErsL6yja28qE6kgPs2TTcdcY,43427
589
589
  tf_keras/src/utils/losses_utils.py,sha256=oPHJSNLY8U57ieQD59vnGHNavZpMpeTZtL7VIlDwwfM,16919
590
590
  tf_keras/src/utils/metrics_utils.py,sha256=feW5GoiznbQKkxmE3Url2nlXfWgvHMpJPXdGKCdiV_U,39803
591
591
  tf_keras/src/utils/mode_keys.py,sha256=_QYq58qr_b-RhvMYBYnL47NkC0G1ng8NYcVnS_IYi-A,856
@@ -606,7 +606,7 @@ tf_keras/src/utils/legacy/__init__.py,sha256=EfMmeHYDzwvxNaktPhQbkTdcPSIGCqMhBND
606
606
  tf_keras/utils/__init__.py,sha256=b7_d-USe_EmLo02_P99Q1rUCzKBYayPCfiYFStP-0nw,2735
607
607
  tf_keras/utils/experimental/__init__.py,sha256=DzGogE2AosjxOVILQBT8PDDcqbWTc0wWnZRobCdpcec,97
608
608
  tf_keras/utils/legacy/__init__.py,sha256=7ujlDa5HeSRcth2NdqA0S1P2-VZF1kB3n68jye6Dj-8,189
609
- tf_keras_nightly-2.20.0.dev2025051909.dist-info/METADATA,sha256=K7yAGm5jLqINAdEvJPtd0TljkfiUrzCmYa3sOhl5x3M,1857
610
- tf_keras_nightly-2.20.0.dev2025051909.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
611
- tf_keras_nightly-2.20.0.dev2025051909.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
612
- tf_keras_nightly-2.20.0.dev2025051909.dist-info/RECORD,,
609
+ tf_keras_nightly-2.20.0.dev2025052109.dist-info/METADATA,sha256=pQdkJ0Fra-C-xPyhWHsEn_akYf6iD0Sh6kLn2q6-urw,1857
610
+ tf_keras_nightly-2.20.0.dev2025052109.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
611
+ tf_keras_nightly-2.20.0.dev2025052109.dist-info/top_level.txt,sha256=LC8FK7zHDNKxB17C6lGKvrZ_fZZGJsRiBK23SfiDegY,9
612
+ tf_keras_nightly-2.20.0.dev2025052109.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.7.1)
2
+ Generator: setuptools (80.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5