keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/backend/common/backend_utils.py
CHANGED

@@ -1,4 +1,5 @@
 import functools
+import math
 import operator
 import re
 import warnings

@@ -96,13 +97,13 @@ def _convert_conv_transpose_padding_args_from_keras_to_torch(
     )

     if torch_output_padding >= stride:
-        …
-            f"…
-            f"output_padding…
-            f"…
-            …
-            f"padding arguments, kernel or stride, or run on another backend. "
+        warnings.warn(
+            f"Torch backend requires output_padding < stride. "
+            f"Clamping output_padding {torch_output_padding} -> {stride - 1} "
+            f"for stride {stride}.",
+            UserWarning,
         )
+        torch_output_padding = stride - 1

     return torch_padding, torch_output_padding

@@ -184,6 +185,22 @@ def compute_conv_transpose_padding_args_for_torch(
         torch_paddings.append(torch_padding)
         torch_output_paddings.append(torch_output_padding)

+    # --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---
+    corrected_output_paddings = []
+    for s, op in zip(
+        strides
+        if isinstance(strides, (list, tuple))
+        else [strides] * num_spatial_dims,
+        torch_output_paddings,
+    ):
+        max_allowed = max(0, s - 1)
+        if op > max_allowed:
+            corrected_output_paddings.append(max_allowed)
+        else:
+            corrected_output_paddings.append(op)
+
+    torch_output_paddings = corrected_output_paddings
+
     return torch_paddings, torch_output_paddings

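Both changes above enforce PyTorch's requirement that output_padding be strictly smaller than the stride for transposed convolutions, clamping instead of erroring out. A minimal standalone sketch of the clamping rule (hypothetical helper, illustrative values only, not the Keras API itself):

def clamp_output_padding(output_padding, stride):
    # Torch's ConvTranspose ops require output_padding < stride, so any
    # computed value is clipped to stride - 1 (never below 0).
    max_allowed = max(0, stride - 1)
    return min(output_padding, max_allowed)

print(clamp_output_padding(2, 2))  # 1: 2 >= stride, clamped to stride - 1
print(clamp_output_padding(0, 3))  # 0: already valid, left unchanged
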
@@ -523,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
         -1 - axis
     )
     return x[tuple(slices)]
+
+
+def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
+    """Compute small and big window sizes for adaptive pooling."""
+    small = math.ceil(input_dim / output_dim)
+    big = small + 1
+    return small, big
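A quick worked example of the helper added above (input/output sizes chosen for illustration). Mapping input_dim positions onto output_dim outputs mixes windows of the "small" and "big" sizes:

import math

def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
    """Compute small and big window sizes for adaptive pooling."""
    small = math.ceil(input_dim / output_dim)
    big = small + 1
    return small, big

print(compute_adaptive_pooling_window_sizes(10, 4))  # (3, 4)
print(compute_adaptive_pooling_window_sizes(8, 4))   # (2, 3); evenly divisible, only the small window is needed
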
keras/src/backend/common/dtypes.py
CHANGED

@@ -232,18 +232,12 @@ def _resolve_weak_type(dtype, precision="32"):
     return f"float{precision}"


-BIT64_TO_BIT16_DTYPE = {
-    "int32": "int16",
-    "int64": "int16",
-    "uint32": "uint16",
-    "uint64": "uint16",
-    "float32": "float16",
-    "float64": "float16",
-}
 BIT64_TO_BIT32_DTYPE = {
-    …
+    # Since TF variables require int64 to be placed on the GPU, we exclusively
+    # enable the int64 dtype for TF.
+    "int64": "int64" if config.backend() == "tensorflow" else "int32",
     "uint64": "uint32",
-    "float64": "float32",
+    "float64": "float64" if config.backend() == "tensorflow" else "float32",
     "complex128": "complex64",
 }

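The net effect of the revised mapping is that 64-bit dtypes are only preserved on the TensorFlow backend; elsewhere they are still downcast to their 32-bit counterparts. A hypothetical helper mirroring the table above, for illustration only:

def downcast_64bit(dtype, backend):
    # TensorFlow keeps int64/float64; other backends fall back to 32 bits.
    mapping = {
        "int64": "int64" if backend == "tensorflow" else "int32",
        "uint64": "uint32",
        "float64": "float64" if backend == "tensorflow" else "float32",
        "complex128": "complex64",
    }
    return mapping.get(dtype, dtype)

print(downcast_64bit("float64", "jax"))         # float32
print(downcast_64bit("float64", "tensorflow"))  # float64
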
@@ -277,8 +271,8 @@ def _lattice_result_type(*args):
     if out_weak_type:
         out_dtype = _resolve_weak_type(out_dtype, precision=precision)

-    # Force to be 32-bit dtype when encountering 64-bit dtype.
-    # …
+    # Force to be 32-bit dtype when encountering 64-bit dtype. This is to
+    # be aligned with JAX's default behavior.
     out_dtype = BIT64_TO_BIT32_DTYPE.get(out_dtype, out_dtype)
     return out_dtype

keras/src/backend/common/variables.py
CHANGED

@@ -1,5 +1,3 @@
-import os.path
-
 import numpy as np

 from keras.src import backend

@@ -144,7 +142,7 @@ class Variable:
         self._name = name
         parent_path = current_path()
         if parent_path:
-            self._path = …
+            self._path = f"{parent_path}/{name}"
         else:
             self._path = name
         self._shape = None

@@ -278,13 +276,13 @@ class Variable:
         return self._maybe_autocast(self._value)

     def assign(self, value):
-        value = self._convert_to_tensor(value, dtype=self.…
+        value = self._convert_to_tensor(value, dtype=self._dtype)
         if not shape_equal(value.shape, self.shape):
             raise ValueError(
                 "The shape of the target variable and "
                 "the shape of the target value in "
                 "`variable.assign(value)` must match. "
-                f"variable.shape={self.…
+                f"variable.shape={self.shape}, "
                 f"Received: value.shape={value.shape}. "
                 f"Target variable: {self}"
             )

@@ -401,7 +399,11 @@ class Variable:
     def __repr__(self):
         value = None
         if hasattr(self, "_value") and self._value is not None:
-            …
+            try:
+                value = backend.core.convert_to_numpy(self._value)
+            except:
+                # In some cases the conversion to numpy can fail.
+                pass
         value_str = f", value={value}" if value is not None else ""
         return (
             f"<Variable path={self.path}, shape={self.shape}, "

@@ -597,30 +599,46 @@ def standardize_shape(shape):
         # `tf.TensorShape` may contain `Dimension` objects.
         # We need to convert the items in it to either int or `None`
         shape = shape.as_list()
-    shape = tuple(shape)

-    if config.backend() == "…
-        # `…
-        …
-        shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
+    if config.backend() == "jax":
+        # Replace `_DimExpr` (dimension expression) with None
+        from jax import export as jax_export

-    …
-    …
-    …
-    …
-    …
+        shape = tuple(
+            None if jax_export.is_symbolic_dim(d) else d for d in shape
+        )
+
+    # Handle dimensions that are not ints and not None, verify they're >= 0.
+    standardized_shape = []
+    for d in shape:
+        if d is None:
+            standardized_shape.append(d)
             continue
-        …
+
+        # Reject these even if they can be cast to int successfully.
+        if isinstance(d, (str, float)):
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
-                f"Found invalid…
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
             )
-        …
+
+        try:
+            # Cast numpy scalars, tf constant tensors, etc.
+            d = int(d)
+        except Exception as e:
+            raise ValueError(
+                f"Cannot convert '{shape}' to a shape. "
+                f"Found invalid dimension '{d}' of type '{type(d)}'. "
+            ) from e
+        if d < 0:
             raise ValueError(
                 f"Cannot convert '{shape}' to a shape. "
                 "Negative dimensions are not allowed."
             )
-        …
+        standardized_shape.append(d)
+
+    # This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
+    return tuple(standardized_shape)


 def shape_equal(a_shape, b_shape):
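Based on the rewritten logic above, standardize_shape now accepts any sequence whose entries are None or castable to a non-negative int, and it always returns a plain tuple. A hedged usage sketch (expected outputs shown as comments, assuming a standard Keras installation):

import numpy as np
from keras.src.backend.common.variables import standardize_shape

# numpy scalars and other int-like values are cast to plain Python ints.
print(standardize_shape((np.int64(3), None, 4)))  # expected: (3, None, 4)

# Floats and strings are rejected even though int() would accept them.
try:
    standardize_shape((3, 2.5))
except ValueError as e:
    print(e)  # expected: "Cannot convert '(3, 2.5)' to a shape. ..."
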
keras/src/backend/jax/core.py
CHANGED
@@ -3,6 +3,7 @@ import jax.experimental.sparse as jax_sparse
 import jax.numpy as jnp
 import ml_dtypes
 import numpy as np
+from jax import export as jax_export

 from keras.src import tree
 from keras.src.backend import config

@@ -29,9 +30,7 @@ class JaxVariable(KerasVariable):
         self._layout = layout
         super().__init__(*args, **kwargs)

-    def …
-        # Note that variable.shape is needed by distribution_lib
-        self._shape = self._validate_shape(value.shape)
+    def _initialize_layout(self):
         # We can't import the keras/distribution/distribution_lib
         # due to circular dependency.
         distribution = global_state.get_global_attribute("distribution")

@@ -43,8 +42,28 @@ class JaxVariable(KerasVariable):
             self._layout = tensor_layout.backend_layout
         else:
             self._layout = tensor_layout
+
+    def _initialize(self, value):
+        # Note that variable.shape is needed by distribution_lib
+        self._shape = self._validate_shape(value.shape)
+        self._initialize_layout()
         self._direct_assign(value)

+    def _initialize_with_initializer(self, initializer):
+        self._initialize_layout()
+        layout = self._layout
+        shape = self._shape
+        if should_shard_at_init(layout, shape):
+            jitted_initializer = jax.jit(
+                initializer.__call__,
+                out_shardings=layout,
+                static_argnames=["shape", "dtype"],
+            )
+            value = jitted_initializer(shape=self._shape, dtype=self._dtype)
+            self._value = value
+        else:
+            super()._initialize_with_initializer(initializer)
+
     def _direct_assign(self, value):
         if self._layout is not None:
             value = distribution_lib.distribute_variable(value, self._layout)

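The should_shard_at_init helper referenced here (added further down in this file) only shards a variable at initialization when its estimated size, assuming 4 bytes per element, reaches 250 MiB and the layout actually names a sharded axis. A small worked example of that size arithmetic:

import numpy as np

SIZE_THRESHOLD = 250 * 1024 * 1024  # 250 MiB, matching the helper below

def estimated_bytes(shape):
    # The heuristic assumes 4 bytes per element (e.g. float32).
    return int(np.prod(shape)) * 4

print(estimated_bytes((8192, 8192)) >= SIZE_THRESHOLD)  # True  (~256 MiB)
print(estimated_bytes((1024, 1024)) >= SIZE_THRESHOLD)  # False (~4 MiB)
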
@@ -111,6 +130,12 @@ if config.is_nnx_enabled():
             # The real value is now set in self._value, sync it to raw_value
             object.__setattr__(self, "raw_value", self._value)

+        def _initialize_with_initializer(self, initializer):
+            value = self._convert_to_tensor(
+                initializer(self._shape, dtype=self._dtype)
+            )
+            self._initialize(value)
+
         @property
         def _value(self):
             if hasattr(self, "raw_value"):

@@ -233,6 +258,71 @@ if config.is_nnx_enabled():

     Variable = NnxVariable

+    def _flatten_nnx_variable(variable):
+        children = (variable.raw_value,)
+        # We copy __dict__ to avoid side effects
+        keras_state = variable.__dict__.copy()
+        # Remove elements that might be problematic or redundant if
+        # nnx.Variable's __getstate__
+        keras_state.pop("raw_value", None)
+        aux_data = (
+            variable._var_metadata,
+            getattr(variable, "_trace_state", None),
+            keras_state,
+        )
+        return children, aux_data
+
+    def _unflatten_nnx_variable(aux_data, children):
+        var_metadata, trace_state, keras_state = aux_data
+        raw_value = children[0]
+
+        # Create uninitialized instance
+        variable = NnxVariable.__new__(NnxVariable)
+
+        # Restore state
+        variable._var_metadata = var_metadata
+        if trace_state is not None:
+            variable._trace_state = trace_state
+        variable.__dict__.update(keras_state)
+        variable.raw_value = raw_value
+
+        return variable
+
+    try:
+        jax.tree_util.register_pytree_node(
+            NnxVariable,
+            _flatten_nnx_variable,
+            _unflatten_nnx_variable,
+        )
+    except ValueError:
+        pass
+
+    def __setattr__(self, name, value):
+        # Mirror Keras attributes to _var_metadata to ensure persistence
+        # if the Pytree registration is not respected by NNX.
+        if (
+            name != "_var_metadata"
+            and name not in ("_raw_value", "_trace_state")
+            and hasattr(self, "_var_metadata")
+        ):
+            self._var_metadata[name] = value
+
+        object.__setattr__(self, name, value)
+
+    NnxVariable.__setattr__ = __setattr__
+
+
+def should_shard_at_init(init_layout, shape):
+    if not isinstance(init_layout, jax.sharding.NamedSharding):
+        return False
+
+    if all(dim is None for dim in init_layout.spec):
+        return False
+
+    size_threshold = 250 * 1024 * 1024
+    array_size = np.prod(shape) * 4
+    return array_size >= size_threshold
+


 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
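The registration above follows the standard jax.tree_util.register_pytree_node contract: the flatten function returns (children, aux_data) and the unflatten function receives (aux_data, children). A self-contained toy example of the same pattern, unrelated to NnxVariable itself:

import jax


class Box:
    """Toy container registered as a JAX pytree node."""

    def __init__(self, value, label):
        self.value = value  # traced leaf
        self.label = label  # static metadata

    def __repr__(self):
        return f"Box(value={self.value}, label={self.label!r})"


def _flatten_box(box):
    # children are transformed by JAX; aux_data stays static.
    return (box.value,), box.label


def _unflatten_box(label, children):
    return Box(children[0], label)


jax.tree_util.register_pytree_node(Box, _flatten_box, _unflatten_box)

print(jax.tree_util.tree_map(lambda x: x * 2, Box(3.0, "lr")))
# Box(value=6.0, label='lr')
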
@@ -282,8 +372,6 @@ def is_tensor(x):


 def shape(x):
-    # This will work as long as we disallow
-    # dynamic shapes in JAX.
     return x.shape


@@ -315,31 +403,29 @@ def compute_output_spec(fn, *args, **kwargs):
             else:
                 maybe_symbolic_kwargs[k] = v

-        # …
-        …
-        …
-        …
-        …
-        …
-        …
+        # Create a _DimExpr instance for one dimension by creating a symbolic
+        # shape with one dimension and extracting it.
+        #
+        # We create a single dynamic dimension and reuse it instead of creating
+        # N dynamic dimensions. This is for backwards compatibility. Previously
+        # we would fill all dynamic dimensions with the same concrete value.
+        # This can handle the case where there is an implicit assumption that
+        # two dimensions are the same (e.g. square images).
+        #
+        # We add the constraint "dynamic_dimension>=2" to prevent JAX from
+        # assuming that the dimension can be broadcastable or squeezable. It
+        # removes this ambiguity.
+        dynamic_dimension = jax_export.symbolic_shape(
+            "(dynamic_dimension)",
+            constraints=["dynamic_dimension>=2"],
+        )[0]
+
+        def convert_keras_tensor_to_jax(x):
             if isinstance(x, KerasTensor):
-                shape = …
-                …
-                …
-                …
-                    shape[i] = fill_value
-                jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
-                return jax_tensor
-            if isinstance(x, dict):
-                return {
-                    k: convert_keras_tensor_to_jax(v, fill_value=fill_value)
-                    for k, v in x.items()
-                }
-            if isinstance(x, list):
-                return [
-                    convert_keras_tensor_to_jax(xi, fill_value=fill_value)
-                    for xi in x
-                ]
+                shape = tuple(
+                    [d if d is not None else dynamic_dimension for d in x.shape]
+                )
+                return jax.ShapeDtypeStruct(shape, dtype=x.dtype)
             return x

         def wrapped_fn(*args, **kwargs):

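The single reusable symbolic dimension comes from jax.export.symbolic_shape, which parses a shape-polymorphism spec into dimension expressions that jax.eval_shape can propagate. A minimal standalone sketch (assumes a recent JAX release exposing the jax.export module; shapes are illustrative):

import jax
import jax.numpy as jnp
from jax import export as jax_export

# One symbolic dimension, constrained to be >= 2 so it cannot be mistaken
# for a broadcastable size-1 axis.
(batch,) = jax_export.symbolic_shape("(batch)", constraints=["batch>=2"])

spec = jax.ShapeDtypeStruct((batch, 16), jnp.float32)
out = jax.eval_shape(lambda x: jnp.sum(x, axis=-1), spec)
print(out.shape)  # (batch,) -- the leading dimension stays symbolic
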
@@ -374,63 +460,25 @@ def compute_output_spec(fn, *args, **kwargs):
             with StatelessScope():
                 return fn(*rec_args, **kwargs, **static_kwargs)

-        …
-            ms_args_1, ms_kwargs_1 = tree.map_structure(
-                lambda x: convert_keras_tensor_to_jax(x, fill_value=83),
-                (maybe_symbolic_args, maybe_symbolic_kwargs),
-            )
-            _, jax_out_1 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
-                *ms_args_1, **ms_kwargs_1
-            )
-
-            ms_args_2, ms_kwargs_2 = tree.map_structure(
-                lambda x: convert_keras_tensor_to_jax(x, fill_value=89),
-                (maybe_symbolic_args, maybe_symbolic_kwargs),
-            )
-            _, jax_out_2 = jax.make_jaxpr(wrapped_fn, return_shape=True)(
-                *ms_args_2, **ms_kwargs_2
-            )
-
-            def merge_shapes(shape1, shape2):
-                return tuple(
-                    [d1 if d1 == d2 else None for d1, d2 in zip(shape1, shape2)]
-                )
-
-            def convert_jax_specs_to_keras_tensor(x1, x2):
-                if isinstance(x1, jax.ShapeDtypeStruct):
-                    if not isinstance(x2, jax.ShapeDtypeStruct):
-                        raise ValueError("Indeterministic output ordering.")
-                    return KerasTensor(
-                        merge_shapes(x1.shape, x2.shape), dtype=x1.dtype
-                    )
-                elif isinstance(x1, jax_sparse.BCOO):
-                    if not isinstance(x2, jax_sparse.BCOO):
-                        raise ValueError("Indeterministic output ordering.")
-                    return KerasTensor(
-                        merge_shapes(x1.shape, x2.shape),
-                        dtype=x1.dtype,
-                        sparse=True,
-                    )
-                else:
-                    return x1
-
-            return tree.map_structure(
-                convert_jax_specs_to_keras_tensor, jax_out_1, jax_out_2
-            )
-
-        maybe_symbolic_args, maybe_symbolic_kwargs = tree.map_structure(
+        maybe_symbolic_args_jax, maybe_symbolic_kwargs_jax = tree.map_structure(
             convert_keras_tensor_to_jax,
             (maybe_symbolic_args, maybe_symbolic_kwargs),
         )
-        …
-            *…
+        jax_out = jax.eval_shape(
+            wrapped_fn, *maybe_symbolic_args_jax, **maybe_symbolic_kwargs_jax
         )

         def convert_jax_spec_to_keras_tensor(x):
             if isinstance(x, jax.ShapeDtypeStruct):
-                …
+                shape = tuple(
+                    d if isinstance(d, int) else None for d in x.shape
+                )
+                return KerasTensor(shape, x.dtype)
             elif isinstance(x, jax_sparse.BCOO):
-                …
+                shape = tuple(
+                    d if isinstance(d, int) else None for d in x.shape
+                )
+                return KerasTensor(shape, x.dtype, sparse=True)
             return x

         return tree.map_structure(convert_jax_spec_to_keras_tensor, jax_out)

keras/src/backend/jax/distribution_lib.py
CHANGED

@@ -27,6 +27,20 @@ def list_devices(device_type=None):
     return [f"{device.platform}:{device.id}" for device in jax_devices]


+def get_device_count(device_type=None):
+    """Returns the number of available JAX devices.
+    Args:
+        device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
+            If `None`, it defaults to counting "gpu" or "tpu" devices if
+            available, otherwise it counts "cpu" devices. It does not
+            return the sum of all device types.
+    Returns:
+        int: The total number of JAX devices for the specified type.
+    """
+    device_type = device_type.lower() if device_type else None
+    return jax.device_count(device_type)
+
+
 def distribute_variable(value, layout):
     """Create a distributed variable for JAX.

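A hedged usage sketch for the new helper; device counts depend on the host, so the printed values are purely illustrative (assumes the JAX backend dependencies are installed):

from keras.src.backend.jax import distribution_lib

# Count the default accelerator type (gpu/tpu if present, otherwise cpu).
print(distribution_lib.get_device_count())       # e.g. 8
# Or count a specific device type; the argument is lower-cased internally.
print(distribution_lib.get_device_count("CPU"))  # e.g. 1
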
@@ -146,13 +160,13 @@ def initialize_rng():
     # Check if the global seed generator is set and ensure it has an initialized
     # seed. Otherwise, reset the seed to the global seed.
     global_seed_generator = global_state.get_global_attribute(
-        …
+        seed_generator.GLOBAL_SEED_GENERATOR
     )
     if global_seed_generator is not None:
         seed = global_seed_generator.get_config()["seed"]
         if seed is None:
             global_state.set_global_attribute(
-                …
+                seed_generator.GLOBAL_SEED_GENERATOR,
                 seed_generator.SeedGenerator(
                     seed=global_seed,
                     name=global_seed_generator.name,
keras/src/backend/jax/layer.py
CHANGED
@@ -3,7 +3,9 @@ from keras.src.backend.config import is_nnx_enabled
 if is_nnx_enabled():
     from flax import nnx

-    BaseLayer…
+    class BaseLayer(nnx.Module):
+        def __init_subclass__(cls, **kwargs):
+            super().__init_subclass__(pytree=False, **kwargs)
 else:
     BaseLayer = object

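The new BaseLayer forwards a fixed pytree=False class-creation keyword to nnx.Module.__init_subclass__ for every subclass. A generic Python illustration of that mechanism, independent of flax/nnx (all names hypothetical):

class Base:
    def __init_subclass__(cls, pytree=True, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.pytree = pytree


class Fixed(Base):
    # Forces the keyword for every further subclass, like BaseLayer above.
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(pytree=False, **kwargs)


class Child(Fixed):
    pass


print(Child.pytree)  # False
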
keras/src/backend/jax/linalg.py
CHANGED