keras-nightly 3.14.0.dev2026012804__py3-none-any.whl → 3.14.0.dev2026013004__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/quantizers/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/quantizers/__init__.py +3 -0
- keras/src/backend/jax/core.py +12 -2
- keras/src/callbacks/orbax_checkpoint.py +41 -8
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +80 -1
- keras/src/export/tfsm_layer.py +34 -0
- keras/src/layers/core/dense.py +278 -95
- keras/src/layers/core/einsum_dense.py +350 -181
- keras/src/layers/core/embedding.py +236 -49
- keras/src/layers/core/reversible_embedding.py +177 -35
- keras/src/layers/preprocessing/discretization.py +30 -1
- keras/src/quantizers/__init__.py +6 -0
- keras/src/quantizers/quantization_config.py +98 -4
- keras/src/quantizers/quantizers.py +262 -32
- keras/src/saving/saving_api.py +66 -2
- keras/src/version.py +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/METADATA +1 -1
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/RECORD +23 -23
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/WHEEL +0 -0
- {keras_nightly-3.14.0.dev2026012804.dist-info → keras_nightly-3.14.0.dev2026013004.dist-info}/top_level.txt +0 -0
|
@@ -17,6 +17,9 @@ from keras.src.dtype_policies.dtype_policy import (
|
|
|
17
17
|
from keras.src.dtype_policies.dtype_policy import (
|
|
18
18
|
GPTQDTypePolicy as GPTQDTypePolicy,
|
|
19
19
|
)
|
|
20
|
+
from keras.src.dtype_policies.dtype_policy import (
|
|
21
|
+
Int4DTypePolicy as Int4DTypePolicy,
|
|
22
|
+
)
|
|
20
23
|
from keras.src.dtype_policies.dtype_policy import (
|
|
21
24
|
QuantizedDTypePolicy as QuantizedDTypePolicy,
|
|
22
25
|
)
|
|
@@ -24,6 +24,9 @@ from keras.src.quantizers.quantization_config import (
|
|
|
24
24
|
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
|
|
25
25
|
from keras.src.quantizers.quantizers import Quantizer as Quantizer
|
|
26
26
|
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
|
|
27
|
+
from keras.src.quantizers.quantizers import (
|
|
28
|
+
abs_max_quantize_grouped_with_zero_point as abs_max_quantize_grouped_with_zero_point,
|
|
29
|
+
)
|
|
27
30
|
from keras.src.quantizers.quantizers import (
|
|
28
31
|
compute_float8_amax_history as compute_float8_amax_history,
|
|
29
32
|
)
|
keras/dtype_policies/__init__.py
CHANGED
|
@@ -17,6 +17,9 @@ from keras.src.dtype_policies.dtype_policy import (
|
|
|
17
17
|
from keras.src.dtype_policies.dtype_policy import (
|
|
18
18
|
GPTQDTypePolicy as GPTQDTypePolicy,
|
|
19
19
|
)
|
|
20
|
+
from keras.src.dtype_policies.dtype_policy import (
|
|
21
|
+
Int4DTypePolicy as Int4DTypePolicy,
|
|
22
|
+
)
|
|
20
23
|
from keras.src.dtype_policies.dtype_policy import (
|
|
21
24
|
QuantizedDTypePolicy as QuantizedDTypePolicy,
|
|
22
25
|
)
|
keras/quantizers/__init__.py
CHANGED
|
@@ -24,6 +24,9 @@ from keras.src.quantizers.quantization_config import (
|
|
|
24
24
|
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
|
|
25
25
|
from keras.src.quantizers.quantizers import Quantizer as Quantizer
|
|
26
26
|
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize
|
|
27
|
+
from keras.src.quantizers.quantizers import (
|
|
28
|
+
abs_max_quantize_grouped_with_zero_point as abs_max_quantize_grouped_with_zero_point,
|
|
29
|
+
)
|
|
27
30
|
from keras.src.quantizers.quantizers import (
|
|
28
31
|
compute_float8_amax_history as compute_float8_amax_history,
|
|
29
32
|
)
|
keras/src/backend/jax/core.py
CHANGED
|
@@ -98,7 +98,7 @@ if config.is_nnx_enabled():
|
|
|
98
98
|
):
|
|
99
99
|
# Ensure 'mutable' is in nnx_metadata, but explicit 'mutable'
|
|
100
100
|
# param takes precedence.
|
|
101
|
-
nnx_metadata["mutable"] =
|
|
101
|
+
nnx_metadata["mutable"] = True if mutable is None else mutable
|
|
102
102
|
|
|
103
103
|
# First, initialize a basic nnx.Variable with a dummy value
|
|
104
104
|
# This sets up the NNX variable structure
|
|
@@ -603,7 +603,17 @@ def random_seed_dtype():
|
|
|
603
603
|
|
|
604
604
|
|
|
605
605
|
def custom_gradient(fun):
|
|
606
|
-
|
|
606
|
+
fun_with_custom_gradient = jax.custom_gradient(fun=fun)
|
|
607
|
+
|
|
608
|
+
# Add a wrapper to unwrap variables, otherwise custom_gradient will fail
|
|
609
|
+
def fun_with_custom_gradient_wrapper(*args, **kwargs):
|
|
610
|
+
args, kwargs = tree.map_shape_structure(
|
|
611
|
+
lambda x: x.value if isinstance(x, KerasVariable) else x,
|
|
612
|
+
(args, kwargs),
|
|
613
|
+
)
|
|
614
|
+
return fun_with_custom_gradient(*args, **kwargs)
|
|
615
|
+
|
|
616
|
+
return fun_with_custom_gradient_wrapper
|
|
607
617
|
|
|
608
618
|
|
|
609
619
|
def remat(f):
|
|
@@ -8,6 +8,7 @@ from keras.src.api_export import keras_export
|
|
|
8
8
|
from keras.src.callbacks.monitor_callback import (
|
|
9
9
|
MonitorCallback, # For metric monitoring logic
|
|
10
10
|
)
|
|
11
|
+
from keras.src.saving import saving_lib
|
|
11
12
|
from keras.src.utils.module_utils import ocp
|
|
12
13
|
|
|
13
14
|
# Context and AsyncOptions are accessed through the lazy-loaded ocp module
|
|
@@ -117,6 +118,7 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
117
118
|
initial_value_threshold=None,
|
|
118
119
|
max_to_keep=1,
|
|
119
120
|
save_on_background=True,
|
|
121
|
+
save_weights_only=False,
|
|
120
122
|
):
|
|
121
123
|
# Ensure orbax is available
|
|
122
124
|
ocp.initialize()
|
|
@@ -131,10 +133,12 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
131
133
|
self.save_freq = save_freq
|
|
132
134
|
self.max_to_keep = max_to_keep
|
|
133
135
|
self.save_on_background = save_on_background
|
|
136
|
+
self.save_weights_only = save_weights_only
|
|
134
137
|
self._batches_seen_since_last_saving = 0
|
|
135
138
|
self._last_batch_seen = 0
|
|
136
139
|
self._current_epoch = 0 # Keep track of epoch
|
|
137
140
|
self._total_batches_seen = 0 # Global batch counter for step tracking
|
|
141
|
+
self._async_futures = [] # Track async save futures
|
|
138
142
|
|
|
139
143
|
# Multi-host support
|
|
140
144
|
self._multihost_initialized = self._is_multihost_initialized()
|
|
@@ -167,9 +171,14 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
167
171
|
|
|
168
172
|
# Create the V1 Checkpointer with direct parameter passing
|
|
169
173
|
# Orbax will handle directory creation on all processes as needed
|
|
174
|
+
# save_decision_policy is required for proper coordination of
|
|
175
|
+
# rapid async saves
|
|
170
176
|
self.checkpointer = ocp.training.Checkpointer(
|
|
171
177
|
directory=directory,
|
|
172
178
|
preservation_policy=preservation_policy,
|
|
179
|
+
save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(
|
|
180
|
+
1
|
|
181
|
+
),
|
|
173
182
|
)
|
|
174
183
|
|
|
175
184
|
def _is_multihost_initialized(self):
|
|
@@ -246,15 +255,35 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
246
255
|
|
|
247
256
|
# Save the nested state structures directly (preserving layer
|
|
248
257
|
# names and structure)
|
|
249
|
-
|
|
258
|
+
if self.save_weights_only:
|
|
259
|
+
composite_state = {
|
|
260
|
+
"trainable_variables": state_tree["trainable_variables"],
|
|
261
|
+
"non_trainable_variables": state_tree[
|
|
262
|
+
"non_trainable_variables"
|
|
263
|
+
],
|
|
264
|
+
}
|
|
265
|
+
else:
|
|
266
|
+
composite_state = state_tree
|
|
267
|
+
# Include model configuration for full model restoration
|
|
268
|
+
# Use saving_lib helper to properly handle shared objects
|
|
269
|
+
config_json, _ = saving_lib._serialize_model_as_json(self.model)
|
|
270
|
+
composite_state["model_config"] = config_json
|
|
250
271
|
|
|
251
272
|
# Use a single with statement. If context_options is empty,
|
|
252
273
|
# Context() uses defaults.
|
|
253
274
|
with ocp.Context():
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
275
|
+
# Determine sync vs async based on save_on_background setting
|
|
276
|
+
use_sync = not self.save_on_background
|
|
277
|
+
|
|
278
|
+
if use_sync:
|
|
279
|
+
# Synchronous save
|
|
257
280
|
self.checkpointer.save_pytree(step, composite_state)
|
|
281
|
+
else:
|
|
282
|
+
# Async save
|
|
283
|
+
future = self.checkpointer.save_pytree_async(
|
|
284
|
+
step, composite_state
|
|
285
|
+
)
|
|
286
|
+
self._async_futures.append(future)
|
|
258
287
|
|
|
259
288
|
def on_train_batch_end(self, batch, logs=None):
|
|
260
289
|
if self._should_save_on_batch(batch):
|
|
@@ -306,12 +335,11 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
306
335
|
|
|
307
336
|
if should_save:
|
|
308
337
|
# Use epoch number as the step for Orbax save
|
|
309
|
-
# Keras has already made the save decision - Checkpointer will
|
|
310
|
-
# save unconditionally
|
|
311
338
|
self._save_checkpoint(step=epoch, logs=logs)
|
|
312
339
|
|
|
313
340
|
def on_train_end(self, logs=None):
|
|
314
|
-
# Close the Checkpointer
|
|
341
|
+
# Close the Checkpointer - this waits for any pending async saves
|
|
342
|
+
# to complete before closing
|
|
315
343
|
try:
|
|
316
344
|
self.checkpointer.close()
|
|
317
345
|
except Exception:
|
|
@@ -325,7 +353,12 @@ class OrbaxCheckpoint(MonitorCallback):
|
|
|
325
353
|
This method blocks until all asynchronous checkpoint save operations
|
|
326
354
|
have completed across all hosts in a multi-host setup.
|
|
327
355
|
"""
|
|
328
|
-
# Wait for
|
|
356
|
+
# Wait for all tracked async futures to complete
|
|
357
|
+
for future in self._async_futures:
|
|
358
|
+
future.result() # Wait for completion
|
|
359
|
+
self._async_futures.clear() # Clear completed futures
|
|
360
|
+
|
|
361
|
+
# Wait for any remaining async operations to complete on this host
|
|
329
362
|
self.checkpointer.wait()
|
|
330
363
|
|
|
331
364
|
# Multi-host synchronization: ensure all hosts complete
|
|
@@ -6,6 +6,7 @@ from keras.src.dtype_policies.dtype_policy import AWQDTypePolicy
|
|
|
6
6
|
from keras.src.dtype_policies.dtype_policy import DTypePolicy
|
|
7
7
|
from keras.src.dtype_policies.dtype_policy import FloatDTypePolicy
|
|
8
8
|
from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
|
|
9
|
+
from keras.src.dtype_policies.dtype_policy import Int4DTypePolicy
|
|
9
10
|
from keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy
|
|
10
11
|
from keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy
|
|
11
12
|
from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
|
|
@@ -18,6 +19,7 @@ ALL_OBJECTS = {
|
|
|
18
19
|
QuantizedFloat8DTypePolicy,
|
|
19
20
|
DTypePolicyMap,
|
|
20
21
|
GPTQDTypePolicy,
|
|
22
|
+
Int4DTypePolicy,
|
|
21
23
|
}
|
|
22
24
|
ALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}
|
|
23
25
|
|
|
@@ -288,6 +288,79 @@ class QuantizedFloat8DTypePolicy(QuantizedDTypePolicy):
|
|
|
288
288
|
return config
|
|
289
289
|
|
|
290
290
|
|
|
291
|
+
@keras_export("keras.dtype_policies.Int4DTypePolicy")
|
|
292
|
+
class Int4DTypePolicy(QuantizedDTypePolicy):
|
|
293
|
+
"""Quantized dtype policy for int4 quantization.
|
|
294
|
+
|
|
295
|
+
This policy helps propagate quantization settings for int4 sub-channel
|
|
296
|
+
quantization when loading a quantized model in Keras format.
|
|
297
|
+
|
|
298
|
+
Args:
|
|
299
|
+
mode: The quantization mode. This should be a string in the format
|
|
300
|
+
`"int4/<block_size>"`.
|
|
301
|
+
- `"int4"`: The identifier for the quantization algorithm.
|
|
302
|
+
- `<block_size>`: The block size for sub-channel quantization.
|
|
303
|
+
Use -1 for per-channel (legacy) quantization. Any positive
|
|
304
|
+
integer enables sub-channel quantization with that block size.
|
|
305
|
+
Example: `"int4/128"` for sub-channel with 128-element groups.
|
|
306
|
+
source_name: The source dtype policy name, e.g. "float32".
|
|
307
|
+
"""
|
|
308
|
+
|
|
309
|
+
def __init__(
|
|
310
|
+
self,
|
|
311
|
+
mode,
|
|
312
|
+
source_name=None,
|
|
313
|
+
):
|
|
314
|
+
parts = mode.split("/")
|
|
315
|
+
expected_format = "'int4/<block_size>'"
|
|
316
|
+
|
|
317
|
+
# Validate format
|
|
318
|
+
if len(parts) != 2 or parts[0] != "int4":
|
|
319
|
+
raise ValueError(
|
|
320
|
+
"Invalid mode for Int4DTypePolicy. Expected format "
|
|
321
|
+
f"{expected_format}, but got '{mode}'."
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
# Validate and cast block_size
|
|
325
|
+
try:
|
|
326
|
+
block_size = int(parts[1])
|
|
327
|
+
except ValueError:
|
|
328
|
+
raise ValueError(
|
|
329
|
+
"Invalid mode for Int4DTypePolicy. <block_size> must be an "
|
|
330
|
+
f"integer. Expected format {expected_format}, but got '{mode}'."
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
# Validate supported values
|
|
334
|
+
if block_size < -1 or block_size == 0:
|
|
335
|
+
raise ValueError(
|
|
336
|
+
"Invalid block_size in mode. Supported values are "
|
|
337
|
+
"-1 (per-channel) or a positive integer (sub-channel), "
|
|
338
|
+
f"but got {block_size} from '{mode}'."
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
base_mode = parts[0]
|
|
342
|
+
super().__init__(
|
|
343
|
+
mode=base_mode,
|
|
344
|
+
source_name=source_name,
|
|
345
|
+
)
|
|
346
|
+
|
|
347
|
+
self._name = f"{mode}_from_{source_name}"
|
|
348
|
+
self.mode = base_mode
|
|
349
|
+
self.block_size = block_size
|
|
350
|
+
|
|
351
|
+
def __eq__(self, other):
|
|
352
|
+
if super().__eq__(other) is False:
|
|
353
|
+
return False
|
|
354
|
+
return self.block_size == other.block_size
|
|
355
|
+
|
|
356
|
+
def get_config(self):
|
|
357
|
+
config = super().get_config()
|
|
358
|
+
# Reconstruct the full mode string for serialization
|
|
359
|
+
mode = f"{self.mode}/{self.block_size}"
|
|
360
|
+
config.update({"mode": mode})
|
|
361
|
+
return config
|
|
362
|
+
|
|
363
|
+
|
|
291
364
|
@keras_export("keras.dtype_policies.GPTQDTypePolicy")
|
|
292
365
|
class GPTQDTypePolicy(QuantizedDTypePolicy):
|
|
293
366
|
"""Quantized dtype policy for GPTQ quantization.
|
|
@@ -525,8 +598,14 @@ def _get_quantized_dtype_policy_by_str(policy):
|
|
|
525
598
|
f"Received: policy={policy}"
|
|
526
599
|
)
|
|
527
600
|
mode, source_name = split_name
|
|
528
|
-
if policy.startswith("int8")
|
|
601
|
+
if policy.startswith("int8"):
|
|
529
602
|
return QuantizedDTypePolicy(mode, source_name)
|
|
603
|
+
elif policy.startswith("int4"):
|
|
604
|
+
# Check if mode has block_size component (e.g., "int4/128")
|
|
605
|
+
if "/" in mode:
|
|
606
|
+
return Int4DTypePolicy(mode, source_name)
|
|
607
|
+
else:
|
|
608
|
+
return QuantizedDTypePolicy(mode, source_name)
|
|
530
609
|
elif policy.startswith("gptq"):
|
|
531
610
|
return GPTQDTypePolicy(mode, source_name)
|
|
532
611
|
elif policy.startswith("awq"):
|
keras/src/export/tfsm_layer.py
CHANGED
|
@@ -2,6 +2,7 @@ from keras.src import backend
|
|
|
2
2
|
from keras.src import layers
|
|
3
3
|
from keras.src.api_export import keras_export
|
|
4
4
|
from keras.src.export.saved_model import _list_variables_used_by_fns
|
|
5
|
+
from keras.src.saving import serialization_lib
|
|
5
6
|
from keras.src.utils.module_utils import tensorflow as tf
|
|
6
7
|
|
|
7
8
|
|
|
@@ -146,3 +147,36 @@ class TFSMLayer(layers.Layer):
|
|
|
146
147
|
"call_training_endpoint": self.call_training_endpoint,
|
|
147
148
|
}
|
|
148
149
|
return {**base_config, **config}
|
|
150
|
+
|
|
151
|
+
@classmethod
|
|
152
|
+
def from_config(cls, config, custom_objects=None, safe_mode=None):
|
|
153
|
+
"""Creates a TFSMLayer from its config.
|
|
154
|
+
Args:
|
|
155
|
+
config: A Python dictionary, typically the output of `get_config`.
|
|
156
|
+
custom_objects: Optional dictionary mapping names to custom objects.
|
|
157
|
+
safe_mode: Boolean, whether to disallow loading TFSMLayer.
|
|
158
|
+
When `safe_mode=True`, loading is disallowed because TFSMLayer
|
|
159
|
+
loads external SavedModels that may contain attacker-controlled
|
|
160
|
+
executable graph code. Defaults to `True`.
|
|
161
|
+
Returns:
|
|
162
|
+
A TFSMLayer instance.
|
|
163
|
+
"""
|
|
164
|
+
# Follow the same pattern as Lambda layer for safe_mode handling
|
|
165
|
+
effective_safe_mode = (
|
|
166
|
+
safe_mode
|
|
167
|
+
if safe_mode is not None
|
|
168
|
+
else serialization_lib.in_safe_mode()
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
if effective_safe_mode is not False:
|
|
172
|
+
raise ValueError(
|
|
173
|
+
"Requested the deserialization of a `TFSMLayer`, which "
|
|
174
|
+
"loads an external SavedModel. This carries a potential risk "
|
|
175
|
+
"of arbitrary code execution and thus it is disallowed by "
|
|
176
|
+
"default. If you trust the source of the artifact, you can "
|
|
177
|
+
"override this error by passing `safe_mode=False` to the "
|
|
178
|
+
"loading function, or calling "
|
|
179
|
+
"`keras.config.enable_unsafe_deserialization()."
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
return cls(**config)
|