keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +30 -15
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +509 -29
- keras/src/backend/jax/numpy.py +59 -8
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +311 -1
- keras/src/backend/numpy/numpy.py +65 -2
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +943 -189
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +250 -50
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +80 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +2 -0
- keras/src/dtype_policies/dtype_policy.py +90 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +241 -111
- keras/src/layers/core/einsum_dense.py +316 -131
- keras/src/layers/core/embedding.py +84 -94
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +45 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +14 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +172 -34
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +258 -0
- keras/src/ops/numpy.py +569 -36
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +2 -8
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +65 -79
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +127 -61
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -2
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +5 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/common/backend_utils.py
CHANGED

@@ -1,4 +1,5 @@
 import functools
+import math
 import operator
 import re
 import warnings
@@ -96,13 +97,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
         )

     if torch_output_padding >= stride:
-        raise ValueError(
-            f"
-            f"output_padding
-            f"
-
-            f"padding arguments, kernel or stride, or run on another backend. "
+        warnings.warn(
+            f"Torch backend requires output_padding < stride. "
+            f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
+            f"for stride {stride}.",
+            UserWarning,
         )
+        torch_output_padding = stride - 1

     return torch_padding, torch_output_padding

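For context on the change above: PyTorch rejects transposed convolutions configured with output_padding >= stride (in the common dilation=1 case), so where the old code raised a ValueError, the new code emits a UserWarning and clamps the value to stride - 1. A minimal standalone sketch of the same rule; clamp_output_padding is a hypothetical helper, not part of the package:

import warnings

def clamp_output_padding(output_padding, stride):
    # Torch requires output_padding < stride; warn and clamp instead of raising.
    max_allowed = max(0, stride - 1)
    if output_padding > max_allowed:
        warnings.warn(
            f"Clamping output_padding {output_padding} -> {max_allowed} "
            f"for stride {stride}.",
            UserWarning,
        )
        return max_allowed
    return output_padding

assert clamp_output_padding(3, 2) == 1  # clamped, with a warning
assert clamp_output_padding(0, 2) == 0  # unchanged
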
@@ -184,6 +185,22 @@ def compute_conv_transpose_padding_args_for_torch(
         torch_paddings.append(torch_padding)
         torch_output_paddings.append(torch_output_padding)

+    # --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
+    corrected_output_paddings = []
+    for s, op in zip(
+        strides
+        if isinstance(strides, (list, tuple))
+        else [strides] * num_spatial_dims,
+        torch_output_paddings,
+    ):
+        max_allowed = max(0, s - 1)
+        if op > max_allowed:
+            corrected_output_paddings.append(max_allowed)
+        else:
+            corrected_output_paddings.append(op)
+
+    torch_output_paddings = corrected_output_paddings
+
     return torch_paddings, torch_output_paddings

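This hunk applies the same constraint per spatial dimension, broadcasting a scalar strides value before clamping each output_padding entry. A quick worked example of the loop's effect, with made-up values:

strides = 2                     # scalar stride, two spatial dims
output_paddings = [0, 3]
strides_list = [strides] * len(output_paddings)   # broadcast: [2, 2]
corrected = [
    min(op, max(0, s - 1)) for s, op in zip(strides_list, output_paddings)
]
assert corrected == [0, 1]      # 3 exceeds stride - 1 = 1 and is clamped
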
@@ -523,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
             -1 - axis
         )
     return x[tuple(slices)]
+
+
+def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
+    """Compute small and big window sizes for adaptive pooling."""
+    small = math.ceil(input_dim / output_dim)
+    big = small + 1
+    return small, big
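The new helper supports the adaptive pooling layers listed in the file summary: an input of length input_dim is covered by output_dim windows of at most two sizes, ceil(input_dim / output_dim) ("small") and that value plus one ("big"). Worked example:

import math

input_dim, output_dim = 10, 4
small = math.ceil(input_dim / output_dim)  # 3
big = small + 1                            # 4
print(small, big)                          # 3 4
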
keras/src/backend/common/variables.py
CHANGED

@@ -276,13 +276,13 @@ class Variable:
         return self._maybe_autocast(self._value)

     def assign(self, value):
-        value = self._convert_to_tensor(value, dtype=self.dtype)
+        value = self._convert_to_tensor(value, dtype=self._dtype)
         if not shape_equal(value.shape, self.shape):
             raise ValueError(
                 "The shape of the target variable and "
                 "the shape of the target value in "
                 "`variable.assign(value)` must match. "
-                f"variable.shape={self.value.shape}, "
+                f"variable.shape={self.shape}, "
                 f"Received: value.shape={value.shape}. "
                 f"Target variable: {self}"
             )
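The assign change above casts the incoming value to the variable's stored _dtype and reports the variable's declared shape in the mismatch error. A usage sketch of the public behavior, which is unchanged at the API level:

import numpy as np
import keras

v = keras.Variable(np.zeros((2, 2), dtype="float32"))
v.assign(np.ones((2, 2)))      # OK: value is converted to the variable's dtype
try:
    v.assign(np.ones((3, 3)))  # shape mismatch
except ValueError as err:
    print(err)                 # reports variable.shape=(2, 2), value.shape=(3, 3)
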
@@ -399,7 +399,11 @@ class Variable:
     def __repr__(self):
         value = None
         if hasattr(self, "_value") and self._value is not None:
-            value = backend.core.convert_to_numpy(self._value)
+            try:
+                value = backend.core.convert_to_numpy(self._value)
+            except:
+                # In some cases the conversion to numpy can fail.
+                pass
         value_str = f", value={value}" if value is not None else ""
         return (
             f"<Variable path={self.path}, shape={self.shape}, "
@@ -595,7 +599,6 @@ def standardize_shape(shape):
         # `tf.TensorShape` may contain `Dimension` objects.
         # We need to convert the items in it to either int or `None`
         shape = shape.as_list()
-        shape = tuple(shape)

     if config.backend() == "jax":
         # Replace `_DimExpr` (dimension expression) with None
@@ -605,25 +608,37 @@ def standardize_shape(shape):
             None if jax_export.is_symbolic_dim(d) else d for d in shape
         )

-    if config.backend() == "torch":
-        # `shape` might be `torch.Size`. We need to convert the items in it
-        # to either int or `None`
-        shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
-
-    for e in shape:
-        if e is None:
+    # Handle dimensions that are not ints and not None, verify they're >= 0.
+    standardized_shape = []
+    for d in shape:
+        if d is None:
+            standardized_shape.append(d)
             continue
-        if not is_int_dtype(type(e)):
+
+        # Reject these even if they can be cast to int successfully.
+        if isinstance(d, (str, float)):
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
-                f"Found invalid
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
             )
-        if e < 0:
+
+        try:
+            # Cast numpy scalars, tf constant tensors, etc.
+            d = int(d)
+        except Exception as e:
+            raise ValueError(
+                f"Cannot convert '{shape}' to a shape. "
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
+            ) from e
+        if d < 0:
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
                 "Negative dimensions are not allowed."
             )
-    return shape
+        standardized_shape.append(d)
+
+    # This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
+    return tuple(standardized_shape)


 def shape_equal(a_shape, b_shape):
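Per the reworked loop above, standardize_shape now casts int-like entries (e.g. numpy scalars) via int(d), rejects str and float even when they would cast cleanly, and always returns a plain tuple. An illustrative sketch of the resulting behavior:

import numpy as np
from keras.src.backend.common.variables import standardize_shape

print(standardize_shape((np.int64(3), None, 4)))  # (3, None, 4), plain ints
print(standardize_shape([2, 2]))                  # (2, 2), lists become tuples
# standardize_shape((2.0, 2)) raises ValueError: floats are rejected
# even though int(2.0) would succeed.
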
keras/src/backend/jax/core.py
CHANGED
@@ -30,9 +30,7 @@ class JaxVariable(KerasVariable):
         self._layout = layout
         super().__init__(*args, **kwargs)

-    def _initialize(self, value):
-        # Note that variable.shape is needed by distribution_lib
-        self._shape = self._validate_shape(value.shape)
+    def _initialize_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
         distribution = global_state.get_global_attribute("distribution")
@@ -44,8 +42,28 @@
             self._layout = tensor_layout.backend_layout
         else:
             self._layout = tensor_layout
+
+    def _initialize(self, value):
+        # Note that variable.shape is needed by distribution_lib
+        self._shape = self._validate_shape(value.shape)
+        self._initialize_layout()
         self._direct_assign(value)

+    def _initialize_with_initializer(self, initializer):
+        self._initialize_layout()
+        layout = self._layout
+        shape = self._shape
+        if should_shard_at_init(layout, shape):
+            jitted_initializer = jax.jit(
+                initializer.__call__,
+                out_shardings=layout,
+                static_argnames=["shape", "dtype"],
+            )
+            value = jitted_initializer(shape=self._shape, dtype=self._dtype)
+            self._value = value
+        else:
+            super()._initialize_with_initializer(initializer)
+
     def _direct_assign(self, value):
         if self._layout is not None:
             value = distribution_lib.distribute_variable(value, self._layout)
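The new _initialize_with_initializer path jit-compiles the initializer with out_shardings so that, when the layout calls for sharding, each device materializes only its own shard of a large weight rather than receiving a redistributed copy of the full array. A standalone JAX sketch of the same pattern; the mesh, shape, and initializer are invented for illustration, and shape[0] should divide evenly across devices:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
layout = NamedSharding(mesh, PartitionSpec("model", None))

def initializer(shape, dtype):
    return jnp.zeros(shape, dtype)

# `out_shardings` makes the jitted initializer produce an already-sharded
# array instead of a replicated one.
jitted_initializer = jax.jit(
    initializer, out_shardings=layout, static_argnames=["shape", "dtype"]
)
value = jitted_initializer(shape=(8, 4), dtype=jnp.float32)
print(value.sharding)  # NamedSharding over the "model" axis
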
@@ -112,6 +130,12 @@ if config.is_nnx_enabled():
             # The real value is now set in self._value, sync it to raw_value
             object.__setattr__(self, "raw_value", self._value)

+        def _initialize_with_initializer(self, initializer):
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype)
+            )
+            self._initialize(value)
+
         @property
         def _value(self):
             if hasattr(self, "raw_value"):
@@ -234,6 +258,71 @@ if config.is_nnx_enabled():

     Variable = NnxVariable

+    def _flatten_nnx_variable(variable):
+        children = (variable.raw_value,)
+        # We copy __dict__ to avoid side effects
+        keras_state = variable.__dict__.copy()
+        # Remove elements that might be problematic or redundant if
+        # nnx.Variable's __getstate__
+        keras_state.pop("raw_value", None)
+        aux_data = (
+            variable._var_metadata,
+            getattr(variable, "_trace_state", None),
+            keras_state,
+        )
+        return children, aux_data
+
+    def _unflatten_nnx_variable(aux_data, children):
+        var_metadata, trace_state, keras_state = aux_data
+        raw_value = children[0]
+
+        # Create uninitialized instance
+        variable = NnxVariable.__new__(NnxVariable)
+
+        # Restore state
+        variable._var_metadata = var_metadata
+        if trace_state is not None:
+            variable._trace_state = trace_state
+        variable.__dict__.update(keras_state)
+        variable.raw_value = raw_value
+
+        return variable
+
+    try:
+        jax.tree_util.register_pytree_node(
+            NnxVariable,
+            _flatten_nnx_variable,
+            _unflatten_nnx_variable,
+        )
+    except ValueError:
+        pass
+
+    def __setattr__(self, name, value):
+        # Mirror Keras attributes to _var_metadata to ensure persistence
+        # if the Pytree registration is not respected by NNX.
+        if (
+            name != "_var_metadata"
+            and name not in ("_raw_value", "_trace_state")
+            and hasattr(self, "_var_metadata")
+        ):
+            self._var_metadata[name] = value
+
+        object.__setattr__(self, name, value)
+
+    NnxVariable.__setattr__ = __setattr__
+
+
+def should_shard_at_init(init_layout, shape):
+    if not isinstance(init_layout, jax.sharding.NamedSharding):
+        return False
+
+    if all(dim is None for dim in init_layout.spec):
+        return False
+
+    size_threshold = 250 * 1024 * 1024
+    array_size = np.prod(shape) * 4
+    return array_size >= size_threshold
+

 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
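should_shard_at_init gates that jitted path: it fires only for a NamedSharding whose spec actually partitions an axis, and only when the estimated array size (element count times 4 bytes, i.e. assuming 4-byte elements such as float32) reaches the 250 MiB threshold. Worked numbers:

import numpy as np

size_threshold = 250 * 1024 * 1024         # 262,144,000 bytes
print(int(np.prod((1024, 1024))) * 4)      # 4,194,304   -> below threshold
print(int(np.prod((16384, 8192))) * 4)     # 536,870,912 -> sharded at init
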
keras/src/backend/jax/distribution_lib.py
CHANGED

@@ -27,6 +27,20 @@ def list_devices(device_type=None):
     return [f"{device.platform}:{device.id}" for device in jax_devices]


+def get_device_count(device_type=None):
+    """Returns the number of available JAX devices.
+    Args:
+        device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+            If `None`, it defaults to counting "gpu" or "tpu" devices if
+            available, otherwise it counts "cpu" devices. It does not
+            return the sum of all device types.
+    Returns:
+        int: The total number of JAX devices for the specified type.
+    """
+    device_type = device_type.lower() if device_type else None
+    return jax.device_count(device_type)
+
+
 def distribute_variable(value, layout):
     """Create a distributed variable for JAX.

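get_device_count is a thin wrapper over jax.device_count, which counts devices of the default backend rather than summing across device types, as the docstring notes. Usage sketch:

import jax

print(jax.device_count())        # e.g. 1 on a CPU-only machine
print(jax.device_count("cpu"))   # count CPU devices specifically
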
@@ -146,13 +160,13 @@ def initialize_rng():
     # Check if the global seed generator is set and ensure it has an initialized
     # seed. Otherwise, reset the seed to the global seed.
     global_seed_generator = global_state.get_global_attribute(
-        "global_seed_generator"
+        seed_generator.GLOBAL_SEED_GENERATOR
     )
     if global_seed_generator is not None:
         seed = global_seed_generator.get_config()["seed"]
         if seed is None:
             global_state.set_global_attribute(
-                "global_seed_generator",
+                seed_generator.GLOBAL_SEED_GENERATOR,
                 seed_generator.SeedGenerator(
                     seed=global_seed,
                     name=global_seed_generator.name,
keras/src/backend/jax/linalg.py
CHANGED