keras-nightly 3.14.0.dev2026012804-py3-none-any.whl → 3.14.0.dev2026013004-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -17,6 +17,9 @@ from keras.src.dtype_policies.dtype_policy import (
 from keras.src.dtype_policies.dtype_policy import (
     GPTQDTypePolicy as GPTQDTypePolicy,
 )
+from keras.src.dtype_policies.dtype_policy import (
+    Int4DTypePolicy as Int4DTypePolicy,
+)
 from keras.src.dtype_policies.dtype_policy import (
     QuantizedDTypePolicy as QuantizedDTypePolicy,
 )
@@ -24,6 +24,9 @@ from keras.src.quantizers.quantization_config import (
 from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
 from keras.src.quantizers.quantizers import Quantizer as Quantizer
 from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
+from keras.src.quantizers.quantizers import (
+    abs_max_quantize_grouped_with_zero_point as abs_max_quantize_grouped_with_zero_point,
+)
 from keras.src.quantizers.quantizers import (
     compute_float8_amax_history as compute_float8_amax_history,
 )
@@ -98,7 +98,7 @@ if config.is_nnx_enabled():
         ):
             # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable'
             # param takes precedence.
-            nnx_metadata["mutable"] = trainable if mutable is None else mutable
+            nnx_metadata["mutable"] = True if mutable is None else mutable
 
             # First, initialize a basic nnx.Variable with a dummy value
             # This sets up the NNX variable structure
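The change above flips the default for NNX variable mutability. A minimal sketch of the before/after behavior, in plain Python rather than the Keras/NNX API (the `trainable` and `mutable` names come from the surrounding constructor):

    # Before: mutability metadata defaulted to the variable's trainable flag.
    def resolve_mutable_before(trainable, mutable=None):
        return trainable if mutable is None else mutable

    # After: variables default to mutable regardless of trainability,
    # while an explicit `mutable` argument still takes precedence.
    def resolve_mutable_after(trainable, mutable=None):
        return True if mutable is None else mutable

    assert resolve_mutable_before(trainable=False) is False
    assert resolve_mutable_after(trainable=False) is True
    assert resolve_mutable_after(trainable=False, mutable=False) is False

Presumably this matters for non-trainable state that is updated in place during training, such as metric accumulators or BatchNorm moving statistics.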
@@ -603,7 +603,17 @@ def random_seed_dtype():
 
 
 def custom_gradient(fun):
-    return jax.custom_gradient(fun=fun)
+    fun_with_custom_gradient = jax.custom_gradient(fun=fun)
+
+    # Add a wrapper to unwrap variables, otherwise custom_gradient will fail
+    def fun_with_custom_gradient_wrapper(*args, **kwargs):
+        args, kwargs = tree.map_shape_structure(
+            lambda x: x.value if isinstance(x, KerasVariable) else x,
+            (args, kwargs),
+        )
+        return fun_with_custom_gradient(*args, **kwargs)
+
+    return fun_with_custom_gradient_wrapper
 
 
 def remat(f):
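With the wrapper above, arguments that are `KerasVariable` instances are unwrapped to their underlying arrays before `jax.custom_gradient` traces the function. A hedged usage sketch, adapted from the documented `keras.ops.custom_gradient` pattern (passing a variable directly is the case this change addresses):

    import keras
    from keras import ops

    @ops.custom_gradient
    def log1pexp(x):
        e = ops.exp(x)

        def grad(*args, upstream=None):
            if upstream is None:
                (upstream,) = args
            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1.0, e))

        return ops.log(1.0 + e), grad

    v = keras.Variable(3.0)
    y = log1pexp(v)  # on the JAX backend, `v` is now unwrapped first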
@@ -8,6 +8,7 @@ from keras.src.api_export import keras_export
 from keras.src.callbacks.monitor_callback import (
     MonitorCallback,  # For metric monitoring logic
 )
+from keras.src.saving import saving_lib
 from keras.src.utils.module_utils import ocp
 
 # Context and AsyncOptions are accessed through the lazy-loaded ocp module
@@ -117,6 +118,7 @@ class OrbaxCheckpoint(MonitorCallback):
         initial_value_threshold=None,
         max_to_keep=1,
         save_on_background=True,
+        save_weights_only=False,
     ):
         # Ensure orbax is available
         ocp.initialize()
@@ -131,10 +133,12 @@ class OrbaxCheckpoint(MonitorCallback):
         self.save_freq = save_freq
         self.max_to_keep = max_to_keep
         self.save_on_background = save_on_background
+        self.save_weights_only = save_weights_only
         self._batches_seen_since_last_saving = 0
         self._last_batch_seen = 0
         self._current_epoch = 0  # Keep track of epoch
         self._total_batches_seen = 0  # Global batch counter for step tracking
+        self._async_futures = []  # Track async save futures
 
         # Multi-host support
         self._multihost_initialized = self._is_multihost_initialized()
@@ -167,9 +171,14 @@ class OrbaxCheckpoint(MonitorCallback):
 
         # Create the V1 Checkpointer with direct parameter passing
         # Orbax will handle directory creation on all processes as needed
+        # save_decision_policy is required for proper coordination of
+        # rapid async saves
         self.checkpointer = ocp.training.Checkpointer(
             directory=directory,
             preservation_policy=preservation_policy,
+            save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
+                1
+            ),
         )
 
     def _is_multihost_initialized(self):
@@ -246,15 +255,35 @@ class OrbaxCheckpoint(MonitorCallback):
 
         # Save the nested state structures directly (preserving layer
         # names and structure)
-        composite_state = state_tree
+        if self.save_weights_only:
+            composite_state = {
+                "trainable_variables": state_tree["trainable_variables"],
+                "non_trainable_variables": state_tree[
+                    "non_trainable_variables"
+                ],
+            }
+        else:
+            composite_state = state_tree
+            # Include model configuration for full model restoration
+            # Use saving_lib helper to properly handle shared objects
+            config_json, _ = saving_lib._serialize_model_as_json(self.model)
+            composite_state["model_config"] = config_json
 
         # Use a single with statement. If context_options is empty,
         # Context() uses defaults.
         with ocp.Context():
-            if self.save_on_background:
-                self.checkpointer.save_pytree_async(step, composite_state)
-            else:
+            # Determine sync vs async based on save_on_background setting
+            use_sync = not self.save_on_background
+
+            if use_sync:
+                # Synchronous save
                 self.checkpointer.save_pytree(step, composite_state)
+            else:
+                # Async save
+                future = self.checkpointer.save_pytree_async(
+                    step, composite_state
+                )
+                self._async_futures.append(future)
 
     def on_train_batch_end(self, batch, logs=None):
         if self._should_save_on_batch(batch):
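For orientation, a sketch of the two composite-state layouts the branch above produces. Only the keys visible in the hunk are certain; the extra `state_tree` entry is an illustrative assumption:

    state_tree = {
        "trainable_variables": {"dense": {"kernel": [[0.1]], "bias": [0.0]}},
        "non_trainable_variables": {},
        "optimizer_variables": {},  # assumed; not visible in this diff
    }

    # save_weights_only=True: keep only the variable trees
    weights_only_state = {
        "trainable_variables": state_tree["trainable_variables"],
        "non_trainable_variables": state_tree["non_trainable_variables"],
    }

    # save_weights_only=False: full state plus the model's JSON config,
    # so a checkpoint can rebuild the model as well as restore its state
    full_state = dict(state_tree)
    full_state["model_config"] = '{"class_name": "Functional"}'  # JSON string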
@@ -306,12 +335,11 @@ class OrbaxCheckpoint(MonitorCallback):
 
         if should_save:
             # Use epoch number as the step for Orbax save
-            # Keras has already made the save decision - Checkpointer will
-            # save unconditionally
             self._save_checkpoint(step=epoch, logs=logs)
 
     def on_train_end(self, logs=None):
-        # Close the Checkpointer to ensure all pending saves complete
+        # Close the Checkpointer - this waits for any pending async saves
+        # to complete before closing
         try:
             self.checkpointer.close()
         except Exception:
@@ -325,7 +353,12 @@ class OrbaxCheckpoint(MonitorCallback):
         This method blocks until all asynchronous checkpoint save operations
         have completed across all hosts in a multi-host setup.
         """
-        # Wait for any async operations to complete on this host
+        # Wait for all tracked async futures to complete
+        for future in self._async_futures:
+            future.result()  # Wait for completion
+        self._async_futures.clear()  # Clear completed futures
+
+        # Wait for any remaining async operations to complete on this host
         self.checkpointer.wait()
 
         # Multi-host synchronization: ensure all hosts complete
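Taken together, the Orbax changes add weights-only saving and deterministic completion of async saves. A hedged end-to-end sketch (assuming the callback is exported as `keras.callbacks.OrbaxCheckpoint` in this nightly; the model, data, and directory are placeholders):

    import numpy as np
    import keras

    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

    checkpoint_cb = keras.callbacks.OrbaxCheckpoint(
        directory="/tmp/orbax_ckpts",
        save_freq="epoch",
        max_to_keep=3,
        save_on_background=True,  # async saves; futures tracked internally
        save_weights_only=False,  # embed model_config for full restoration
    )

    x = np.random.rand(32, 4).astype("float32")
    y = np.random.rand(32, 1).astype("float32")
    model.fit(x, y, epochs=2, callbacks=[checkpoint_cb])
    # on_train_end closes the Checkpointer, which waits for pending saves.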
@@ -6,6 +6,7 @@ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
 from keras.src.dtype_policies.dtype_policy import DTypePolicy
 from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
 from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
+from keras.src.dtype_policies.dtype_policy import Int4DTypePolicy
 from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
 from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
 from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
@@ -18,6 +19,7 @@ ALL_OBJECTS = {
     QuantizedFloat8DTypePolicy,
     DTypePolicyMap,
     GPTQDTypePolicy,
+    Int4DTypePolicy,
 }
 ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
 
@@ -288,6 +288,79 @@ class QuantizedFloat8DTypePolicy(QuantizedDTypePolicy):
         return config
 
 
+@keras_export("keras.dtype_policies.Int4DTypePolicy")
+class Int4DTypePolicy(QuantizedDTypePolicy):
+    """Quantized dtype policy for int4 quantization.
+
+    This policy helps propagate quantization settings for int4 sub-channel
+    quantization when loading a quantized model in Keras format.
+
+    Args:
+        mode: The quantization mode. This should be a string in the format
+            `"int4/<block_size>"`.
+            - `"int4"`: The identifier for the quantization algorithm.
+            - `<block_size>`: The block size for sub-channel quantization.
+              Use -1 for per-channel (legacy) quantization. Any positive
+              integer enables sub-channel quantization with that block size.
+              Example: `"int4/128"` for sub-channel with 128-element groups.
+        source_name: The source dtype policy name, e.g. "float32".
+    """
+
+    def __init__(
+        self,
+        mode,
+        source_name=None,
+    ):
+        parts = mode.split("/")
+        expected_format = "'int4/<block_size>'"
+
+        # Validate format
+        if len(parts) != 2 or parts[0] != "int4":
+            raise ValueError(
+                "Invalid mode for Int4DTypePolicy. Expected format "
+                f"{expected_format}, but got '{mode}'."
+            )
+
+        # Validate and cast block_size
+        try:
+            block_size = int(parts[1])
+        except ValueError:
+            raise ValueError(
+                "Invalid mode for Int4DTypePolicy. <block_size> must be an "
+                f"integer. Expected format {expected_format}, but got '{mode}'."
+            )
+
+        # Validate supported values
+        if block_size < -1 or block_size == 0:
+            raise ValueError(
+                "Invalid block_size in mode. Supported values are "
+                "-1 (per-channel) or a positive integer (sub-channel), "
+                f"but got {block_size} from '{mode}'."
+            )
+
+        base_mode = parts[0]
+        super().__init__(
+            mode=base_mode,
+            source_name=source_name,
+        )
+
+        self._name = f"{mode}_from_{source_name}"
+        self.mode = base_mode
+        self.block_size = block_size
+
+    def __eq__(self, other):
+        if super().__eq__(other) is False:
+            return False
+        return self.block_size == other.block_size
+
+    def get_config(self):
+        config = super().get_config()
+        # Reconstruct the full mode string for serialization
+        mode = f"{self.mode}/{self.block_size}"
+        config.update({"mode": mode})
+        return config
+
+
 @keras_export("keras.dtype_policies.GPTQDTypePolicy")
 class GPTQDTypePolicy(QuantizedDTypePolicy):
     """Quantized dtype policy for GPTQ quantization.
@@ -525,8 +598,14 @@ def _get_quantized_dtype_policy_by_str(policy):
         f"Received: policy={policy}"
     )
     mode, source_name = split_name
-    if policy.startswith("int8") or policy.startswith("int4"):
+    if policy.startswith("int8"):
         return QuantizedDTypePolicy(mode, source_name)
+    elif policy.startswith("int4"):
+        # Check if mode has block_size component (e.g., "int4/128")
+        if "/" in mode:
+            return Int4DTypePolicy(mode, source_name)
+        else:
+            return QuantizedDTypePolicy(mode, source_name)
     elif policy.startswith("gptq"):
         return GPTQDTypePolicy(mode, source_name)
     elif policy.startswith("awq"):
@@ -2,6 +2,7 @@ from keras.src import backend
 from keras.src import layers
 from keras.src.api_export import keras_export
 from keras.src.export.saved_model import _list_variables_used_by_fns
+from keras.src.saving import serialization_lib
 from keras.src.utils.module_utils import tensorflow as tf
 
 
@@ -146,3 +147,36 @@ class TFSMLayer(layers.Layer):
             "call_training_endpoint": self.call_training_endpoint,
         }
         return {**base_config, **config}
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None, safe_mode=None):
+        """Creates a TFSMLayer from its config.
+        Args:
+            config: A Python dictionary, typically the output of `get_config`.
+            custom_objects: Optional dictionary mapping names to custom objects.
+            safe_mode: Boolean, whether to disallow loading TFSMLayer.
+                When `safe_mode=True`, loading is disallowed because TFSMLayer
+                loads external SavedModels that may contain attacker-controlled
+                executable graph code. Defaults to `True`.
+        Returns:
+            A TFSMLayer instance.
+        """
+        # Follow the same pattern as Lambda layer for safe_mode handling
+        effective_safe_mode = (
+            safe_mode
+            if safe_mode is not None
+            else serialization_lib.in_safe_mode()
+        )
+
+        if effective_safe_mode is not False:
+            raise ValueError(
+                "Requested the deserialization of a `TFSMLayer`, which "
+                "loads an external SavedModel. This carries a potential risk "
+                "of arbitrary code execution and thus it is disallowed by "
+                "default. If you trust the source of the artifact, you can "
+                "override this error by passing `safe_mode=False` to the "
+                "loading function, or calling "
+                "`keras.config.enable_unsafe_deserialization()`."
+            )
+
+        return cls(**config)
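A hedged usage sketch of the new guard (the archive name is a placeholder): loading a model that contains a `TFSMLayer` now fails under the default safe mode and requires an explicit opt-out, mirroring the `Lambda` layer.

    import keras

    # Raises the ValueError above under the default safe_mode=True:
    # keras.models.load_model("model_with_tfsm_layer.keras")

    # Opt out only for artifacts from a trusted source:
    model = keras.models.load_model(
        "model_with_tfsm_layer.keras", safe_mode=False
    )

    # Or disable safe deserialization globally:
    # keras.config.enable_unsafe_deserialization()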