keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +13 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
- keras/_tf_keras/keras/quantizers/__init__.py +12 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +13 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +9 -0
- keras/quantizers/__init__.py +12 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +1 -1
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +33 -16
- keras/src/backend/jax/core.py +92 -3
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +485 -20
- keras/src/backend/jax/numpy.py +92 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +14 -2
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +76 -7
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +2 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1030 -185
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +264 -54
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +84 -8
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +299 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +191 -172
- keras/src/layers/core/einsum_dense.py +235 -186
- keras/src/layers/core/embedding.py +83 -93
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +390 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +40 -15
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/losses/loss.py +1 -1
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +156 -27
- keras/src/ops/image.py +184 -3
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +541 -43
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +12 -1
- keras/src/quantizers/gptq.py +8 -6
- keras/src/quantizers/gptq_config.py +36 -1
- keras/src/quantizers/gptq_core.py +150 -78
- keras/src/quantizers/quantization_config.py +232 -0
- keras/src/quantizers/quantizers.py +114 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +4 -2
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +14 -5
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +187 -36
- keras/src/utils/module_utils.py +18 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/rng_utils.py +9 -1
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/utils/jax_layer.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import inspect
|
|
3
|
+
import itertools
|
|
4
|
+
import string
|
|
2
5
|
|
|
3
6
|
import numpy as np
|
|
4
7
|
|
|
@@ -12,6 +15,22 @@ from keras.src.saving import serialization_lib
|
|
|
12
15
|
from keras.src.utils import jax_utils
|
|
13
16
|
from keras.src.utils import tracking
|
|
14
17
|
from keras.src.utils.module_utils import jax
|
|
18
|
+
from keras.src.utils.module_utils import tensorflow as tf
|
|
19
|
+
|
|
20
|
+
if backend.backend() == "tensorflow":
|
|
21
|
+
tf_no_automatic_dependency_tracking = (
|
|
22
|
+
tf.__internal__.tracking.no_automatic_dependency_tracking
|
|
23
|
+
)
|
|
24
|
+
else:
|
|
25
|
+
|
|
26
|
+
def tf_no_automatic_dependency_tracking(fn):
|
|
27
|
+
return fn
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _convert_to_jax_key(tensor):
|
|
31
|
+
if backend.backend() == "tensorflow":
|
|
32
|
+
return tf.bitcast(tensor, tf.uint32)[0]
|
|
33
|
+
return tensor
|
|
15
34
|
|
|
16
35
|
|
|
17
36
|
@keras_export("keras.layers.JaxLayer")
|
|
@@ -219,10 +238,10 @@ class JaxLayer(Layer):
|
|
|
219
238
|
seed=None,
|
|
220
239
|
**kwargs,
|
|
221
240
|
):
|
|
222
|
-
if backend.backend()
|
|
241
|
+
if backend.backend() not in ["jax", "tensorflow"]:
|
|
223
242
|
raise ValueError(
|
|
224
|
-
"
|
|
225
|
-
f"backend: {backend.backend()}"
|
|
243
|
+
f"{self.__class__.__name__} is only supported with the JAX or"
|
|
244
|
+
f" Tensorflow backend. Current backend: {backend.backend()}"
|
|
226
245
|
)
|
|
227
246
|
|
|
228
247
|
if init_fn is None and params is None and state is None:
|
|
@@ -252,6 +271,10 @@ class JaxLayer(Layer):
|
|
|
252
271
|
init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
|
|
253
272
|
)
|
|
254
273
|
|
|
274
|
+
# Attributes for jax2tf functions
|
|
275
|
+
self.jax2tf_training_false_fn = None
|
|
276
|
+
self.jax2tf_training_true_fn = None
|
|
277
|
+
|
|
255
278
|
def _validate_signature(self, fn, fn_name, allowed, required):
|
|
256
279
|
fn_parameters = inspect.signature(fn).parameters
|
|
257
280
|
for parameter_name in required:
|
|
@@ -272,7 +295,81 @@ class JaxLayer(Layer):
|
|
|
272
295
|
|
|
273
296
|
return parameter_names
|
|
274
297
|
|
|
298
|
+
def _get_jax2tf_input_shape(self, input_shape):
|
|
299
|
+
"""Convert input shape in a format suitable for `jax2tf`.
|
|
300
|
+
|
|
301
|
+
`jax2tf` expects a letter for each unknown dimension, which allows
|
|
302
|
+
correlated dimensions. Since correlated dimensions are not supported by
|
|
303
|
+
Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. We
|
|
304
|
+
however use 'batch' for dimension 0 if not defined to correlate the
|
|
305
|
+
batch size across inputs.
|
|
306
|
+
|
|
307
|
+
Example (spaces added for readability):
|
|
308
|
+
```
|
|
309
|
+
input_shape: (None , 4 , None, None, 5 )
|
|
310
|
+
result: "(batch, 4 , a , b , 5 )"
|
|
311
|
+
```
|
|
312
|
+
|
|
313
|
+
Args:
|
|
314
|
+
input_shape: a single shape or a structure of shapes for the inputs.
|
|
315
|
+
Returns:
|
|
316
|
+
the shape or shapes structure in the `jax2tf` format as strings.
|
|
317
|
+
"""
|
|
318
|
+
dim_names = itertools.chain(
|
|
319
|
+
string.ascii_lowercase, # a, b, ... z
|
|
320
|
+
itertools.starmap( # aa, ab, ... az, ba, bb, ... zz
|
|
321
|
+
lambda a, b: a + b,
|
|
322
|
+
itertools.product(string.ascii_lowercase, repeat=2),
|
|
323
|
+
),
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
def get_single_jax2tf_shape(shape):
|
|
327
|
+
jax2tf_shape = []
|
|
328
|
+
|
|
329
|
+
for index, dim in enumerate(shape):
|
|
330
|
+
if dim is not None:
|
|
331
|
+
jax2tf_shape.append(str(dim))
|
|
332
|
+
elif index == 0:
|
|
333
|
+
jax2tf_shape.append("batch")
|
|
334
|
+
else:
|
|
335
|
+
jax2tf_shape.append(next(dim_names))
|
|
336
|
+
|
|
337
|
+
return "(" + ", ".join(jax2tf_shape) + ")"
|
|
338
|
+
|
|
339
|
+
res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
|
|
340
|
+
return res
|
|
341
|
+
|
|
342
|
+
def _jax2tf_convert(self, fn, polymorphic_shapes):
|
|
343
|
+
from jax.experimental import jax2tf
|
|
344
|
+
|
|
345
|
+
converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
|
|
346
|
+
# Autograph won't work with the output of jax2tf.
|
|
347
|
+
converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
|
|
348
|
+
return converted_fn
|
|
349
|
+
|
|
350
|
+
def _partial_with_positional(self, fn, index, value):
|
|
351
|
+
"""Return a new partial with one positional argument set to a value.
|
|
352
|
+
|
|
353
|
+
This is needed because `jax2tf` only supports positional arguments and
|
|
354
|
+
`functools.partial` only supports setting positional arguments starting
|
|
355
|
+
from the left. Our use case is the `training` argument which is
|
|
356
|
+
typically the rightmost argument.
|
|
357
|
+
|
|
358
|
+
Args:
|
|
359
|
+
fn: the function to wrap.
|
|
360
|
+
index: the index of the positional argument to set to `value`.
|
|
361
|
+
value: the value for the positional argument at `index`.
|
|
362
|
+
"""
|
|
363
|
+
|
|
364
|
+
@functools.wraps(fn)
|
|
365
|
+
def wrapper(*args):
|
|
366
|
+
args = args[0:index] + (value,) + args[index:]
|
|
367
|
+
return fn(*args)
|
|
368
|
+
|
|
369
|
+
return wrapper
|
|
370
|
+
|
|
275
371
|
@tracking.no_automatic_dependency_tracking
|
|
372
|
+
@tf_no_automatic_dependency_tracking
|
|
276
373
|
def _create_variables(self, values, trainable):
|
|
277
374
|
"""Create a structure of variables from a structure of JAX arrays.
|
|
278
375
|
|
|
@@ -296,14 +393,14 @@ class JaxLayer(Layer):
|
|
|
296
393
|
|
|
297
394
|
def create_variable(value):
|
|
298
395
|
if backend.is_tensor(value) or isinstance(
|
|
299
|
-
value, (np.ndarray, np.generic)
|
|
396
|
+
value, (np.ndarray, np.generic, jax.Array)
|
|
300
397
|
):
|
|
301
398
|
dtype = value.dtype
|
|
302
399
|
if is_float_dtype(dtype):
|
|
303
400
|
dtype = None # Use the layer dtype policy
|
|
304
401
|
return self.add_weight(
|
|
305
402
|
value.shape,
|
|
306
|
-
initializer=value,
|
|
403
|
+
initializer=backend.convert_to_tensor(value),
|
|
307
404
|
dtype=dtype,
|
|
308
405
|
trainable=trainable,
|
|
309
406
|
)
|
|
@@ -333,44 +430,46 @@ class JaxLayer(Layer):
|
|
|
333
430
|
|
|
334
431
|
def _get_init_rng(self):
|
|
335
432
|
"""
|
|
336
|
-
Returns a
|
|
433
|
+
Returns a key in the form of a backend array of size 2, dtype uint32,
|
|
434
|
+
to pass to `init_fn`.
|
|
337
435
|
|
|
338
|
-
By default, this returns a
|
|
436
|
+
By default, this returns a Jax or TF array of size 2 by calling
|
|
339
437
|
`self.seed_generator.next()`. Override this to return a different
|
|
340
438
|
structure.
|
|
341
439
|
|
|
342
440
|
Returns:
|
|
343
|
-
a
|
|
344
|
-
the `rng` argument of `init_fn`.
|
|
441
|
+
a key as a Jax or TF array of size 2, dtype uint32, to be passed
|
|
442
|
+
as the `rng` argument of `init_fn`.
|
|
345
443
|
"""
|
|
346
444
|
return self.seed_generator.next()
|
|
347
445
|
|
|
348
446
|
def _get_call_rng(self, training):
|
|
349
447
|
"""
|
|
350
|
-
Returns a
|
|
448
|
+
Returns a key in the form of a backend array of size 2, dtype uint32,
|
|
449
|
+
to pass to `call_fn`.
|
|
351
450
|
|
|
352
|
-
By default, this returns a
|
|
451
|
+
By default, this returns a Jax or TF array of size 2 by calling
|
|
353
452
|
`self.seed_generator.next()` when `training` is `True`, and `None` when
|
|
354
453
|
`training` is `False`. Override this to return a different structure or
|
|
355
454
|
to pass RNGs in inference mode too.
|
|
356
455
|
|
|
357
456
|
Returns:
|
|
358
|
-
a
|
|
359
|
-
the `rng` argument of `call_fn`.
|
|
457
|
+
a key as a Jax or TF array of size 2, dtype uint32, to be passed
|
|
458
|
+
as the `rng` argument of `call_fn`.
|
|
360
459
|
"""
|
|
361
460
|
if training:
|
|
362
461
|
return self.seed_generator.next()
|
|
363
462
|
else:
|
|
364
463
|
return None
|
|
365
464
|
|
|
366
|
-
def
|
|
367
|
-
if
|
|
368
|
-
return
|
|
369
|
-
|
|
370
|
-
if jax_utils.is_in_jax_tracing_scope():
|
|
465
|
+
def _initialize_weights(self, input_shape):
|
|
466
|
+
if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
|
|
371
467
|
# This exception is not actually shown, it is caught and a detailed
|
|
372
468
|
# warning about calling 'build' is printed.
|
|
373
|
-
raise ValueError(
|
|
469
|
+
raise ValueError(
|
|
470
|
+
"'JaxLayer' cannot be built in tracing scope"
|
|
471
|
+
"or inside tf function"
|
|
472
|
+
)
|
|
374
473
|
|
|
375
474
|
# Initialize `params` and `state` if needed by calling `init_fn`.
|
|
376
475
|
def create_input(shape):
|
|
@@ -381,7 +480,12 @@ class JaxLayer(Layer):
|
|
|
381
480
|
init_args = []
|
|
382
481
|
for argument_name in self.init_fn_arguments:
|
|
383
482
|
if argument_name == "rng":
|
|
384
|
-
init_args.append(
|
|
483
|
+
init_args.append(
|
|
484
|
+
jax.tree_util.tree_map(
|
|
485
|
+
lambda x: jax.numpy.array(_convert_to_jax_key(x)),
|
|
486
|
+
self._get_init_rng(),
|
|
487
|
+
)
|
|
488
|
+
)
|
|
385
489
|
elif argument_name == "inputs":
|
|
386
490
|
init_args.append(init_inputs)
|
|
387
491
|
elif argument_name == "training":
|
|
@@ -398,6 +502,45 @@ class JaxLayer(Layer):
|
|
|
398
502
|
)
|
|
399
503
|
self.tracked_state = self._create_variables(init_state, trainable=False)
|
|
400
504
|
|
|
505
|
+
def build(self, input_shape):
|
|
506
|
+
if self.params is None and self.state is None:
|
|
507
|
+
self._initialize_weights(input_shape)
|
|
508
|
+
|
|
509
|
+
if backend.backend() == "tensorflow":
|
|
510
|
+
polymorphic_shapes = []
|
|
511
|
+
for argument in self.call_fn_arguments:
|
|
512
|
+
if argument == "inputs":
|
|
513
|
+
polymorphic_shapes.append(
|
|
514
|
+
self._get_jax2tf_input_shape(input_shape)
|
|
515
|
+
)
|
|
516
|
+
elif argument != "training":
|
|
517
|
+
# params, state, rng
|
|
518
|
+
polymorphic_shapes.append("...")
|
|
519
|
+
|
|
520
|
+
if "training" in self.call_fn_arguments:
|
|
521
|
+
training_argument_index = self.call_fn_arguments.index(
|
|
522
|
+
"training"
|
|
523
|
+
)
|
|
524
|
+
self.jax2tf_training_false_fn = self._jax2tf_convert(
|
|
525
|
+
self._partial_with_positional(
|
|
526
|
+
self.call_fn, training_argument_index, False
|
|
527
|
+
),
|
|
528
|
+
polymorphic_shapes,
|
|
529
|
+
)
|
|
530
|
+
self.jax2tf_training_true_fn = self._jax2tf_convert(
|
|
531
|
+
self._partial_with_positional(
|
|
532
|
+
self.call_fn, training_argument_index, True
|
|
533
|
+
),
|
|
534
|
+
polymorphic_shapes,
|
|
535
|
+
)
|
|
536
|
+
else:
|
|
537
|
+
self.jax2tf_training_false_fn = self._jax2tf_convert(
|
|
538
|
+
self.call_fn,
|
|
539
|
+
polymorphic_shapes,
|
|
540
|
+
)
|
|
541
|
+
self.jax2tf_training_true_fn = None
|
|
542
|
+
super().build(input_shape)
|
|
543
|
+
|
|
401
544
|
def call(self, inputs, training=False):
|
|
402
545
|
def unwrap_variable(variable):
|
|
403
546
|
return None if variable is None else variable.value
|
|
@@ -413,11 +556,16 @@ class JaxLayer(Layer):
|
|
|
413
556
|
jax.tree_util.tree_map(unwrap_variable, self.state)
|
|
414
557
|
)
|
|
415
558
|
elif argument_name == "rng":
|
|
416
|
-
call_args.append(
|
|
559
|
+
call_args.append(
|
|
560
|
+
jax.tree_util.tree_map(
|
|
561
|
+
_convert_to_jax_key, self._get_call_rng(training)
|
|
562
|
+
)
|
|
563
|
+
)
|
|
417
564
|
elif argument_name == "inputs":
|
|
418
565
|
call_args.append(inputs)
|
|
419
566
|
elif argument_name == "training":
|
|
420
|
-
|
|
567
|
+
if backend.backend() == "jax":
|
|
568
|
+
call_args.append(training)
|
|
421
569
|
|
|
422
570
|
def assign_state_to_variable(value, variable):
|
|
423
571
|
# This exists only to make debugging this error case easier.
|
|
@@ -429,14 +577,23 @@ class JaxLayer(Layer):
|
|
|
429
577
|
)
|
|
430
578
|
variable.assign(value)
|
|
431
579
|
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
580
|
+
def call_with_fn(fn):
|
|
581
|
+
if self.has_state:
|
|
582
|
+
predictions, new_state = fn(*call_args)
|
|
583
|
+
jax.tree_util.tree_map(
|
|
584
|
+
assign_state_to_variable, new_state, self.state
|
|
585
|
+
)
|
|
586
|
+
return predictions
|
|
587
|
+
else:
|
|
588
|
+
return fn(*call_args)
|
|
589
|
+
|
|
590
|
+
if backend.backend() == "jax":
|
|
591
|
+
return call_with_fn(self.call_fn)
|
|
592
|
+
elif backend.backend() == "tensorflow":
|
|
593
|
+
if training and self.jax2tf_training_true_fn is not None:
|
|
594
|
+
return call_with_fn(self.jax2tf_training_true_fn)
|
|
595
|
+
else:
|
|
596
|
+
return call_with_fn(self.jax2tf_training_false_fn)
|
|
440
597
|
|
|
441
598
|
def get_config(self):
|
|
442
599
|
config = {
|
|
@@ -556,12 +713,6 @@ class FlaxLayer(JaxLayer):
|
|
|
556
713
|
# Late import to only require Flax when this is used.
|
|
557
714
|
from flax.core import scope as flax_scope
|
|
558
715
|
|
|
559
|
-
if backend.backend() != "jax":
|
|
560
|
-
raise ValueError(
|
|
561
|
-
"FlaxLayer is only supported with the JAX backend. Current "
|
|
562
|
-
f"backend: {backend.backend()}"
|
|
563
|
-
)
|
|
564
|
-
|
|
565
716
|
self.module = module
|
|
566
717
|
self.method = method
|
|
567
718
|
|
keras/src/utils/module_utils.py
CHANGED
|
@@ -39,6 +39,15 @@ class LazyModule:
|
|
|
39
39
|
return f"LazyModule({self.name})"
|
|
40
40
|
|
|
41
41
|
|
|
42
|
+
class OrbaxLazyModule(LazyModule):
|
|
43
|
+
def initialize(self):
|
|
44
|
+
try:
|
|
45
|
+
parent_module = importlib.import_module("orbax.checkpoint")
|
|
46
|
+
self.module = parent_module.v1
|
|
47
|
+
except ImportError:
|
|
48
|
+
raise ImportError(self.import_error_msg)
|
|
49
|
+
|
|
50
|
+
|
|
42
51
|
tensorflow = LazyModule("tensorflow")
|
|
43
52
|
gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
|
|
44
53
|
tensorflow_io = LazyModule("tensorflow_io")
|
|
@@ -59,3 +68,12 @@ optree = LazyModule("optree")
|
|
|
59
68
|
dmtree = LazyModule("tree")
|
|
60
69
|
tf2onnx = LazyModule("tf2onnx")
|
|
61
70
|
grain = LazyModule("grain")
|
|
71
|
+
litert = LazyModule("ai_edge_litert")
|
|
72
|
+
ocp = OrbaxLazyModule(
|
|
73
|
+
"orbax.checkpoint.v1",
|
|
74
|
+
pip_name="orbax-checkpoint",
|
|
75
|
+
import_error_msg=(
|
|
76
|
+
"OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
|
|
77
|
+
"You can install it via pip install orbax-checkpoint"
|
|
78
|
+
),
|
|
79
|
+
)
|
keras/src/utils/progbar.py
CHANGED
|
@@ -3,7 +3,8 @@ import os
|
|
|
3
3
|
import sys
|
|
4
4
|
import time
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
7
8
|
from keras.src.api_export import keras_export
|
|
8
9
|
from keras.src.utils import io_utils
|
|
9
10
|
|
|
@@ -162,12 +163,10 @@ class Progbar:
|
|
|
162
163
|
for k in self._values_order:
|
|
163
164
|
info += f" - {k}:"
|
|
164
165
|
if isinstance(self._values[k], list):
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
)
|
|
170
|
-
avg = float(avg)
|
|
166
|
+
values, count = self._values[k]
|
|
167
|
+
if not isinstance(values, float):
|
|
168
|
+
values = np.mean(values)
|
|
169
|
+
avg = values / max(1, count)
|
|
171
170
|
if abs(avg) > 1e-3:
|
|
172
171
|
info += f" {avg:.4f}"
|
|
173
172
|
else:
|
|
@@ -194,11 +193,10 @@ class Progbar:
|
|
|
194
193
|
info += f" -{self._format_time(time_per_unit, self.unit_name)}"
|
|
195
194
|
for k in self._values_order:
|
|
196
195
|
info += f" - {k}:"
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
)
|
|
196
|
+
values, count = self._values[k]
|
|
197
|
+
if not isinstance(values, float):
|
|
198
|
+
values = np.mean(values)
|
|
199
|
+
avg = values / max(1, count)
|
|
202
200
|
if avg > 1e-3:
|
|
203
201
|
info += f" {avg:.4f}"
|
|
204
202
|
else:
|
keras/src/utils/rng_utils.py
CHANGED
|
@@ -5,6 +5,7 @@ import numpy as np
|
|
|
5
5
|
from keras.src import backend
|
|
6
6
|
from keras.src.api_export import keras_export
|
|
7
7
|
from keras.src.backend.common import global_state
|
|
8
|
+
from keras.src.random import seed_generator
|
|
8
9
|
from keras.src.utils.module_utils import tensorflow as tf
|
|
9
10
|
|
|
10
11
|
GLOBAL_RANDOM_SEED = "global_random_seed"
|
|
@@ -20,7 +21,7 @@ def set_random_seed(seed):
|
|
|
20
21
|
sources of randomness, or when certain non-deterministic cuDNN ops are
|
|
21
22
|
involved.
|
|
22
23
|
|
|
23
|
-
Calling this utility
|
|
24
|
+
Calling this utility does the following:
|
|
24
25
|
|
|
25
26
|
```python
|
|
26
27
|
import random
|
|
@@ -36,6 +37,9 @@ def set_random_seed(seed):
|
|
|
36
37
|
torch.manual_seed(seed)
|
|
37
38
|
```
|
|
38
39
|
|
|
40
|
+
Additionally, it resets the global Keras `SeedGenerator`, which is used by
|
|
41
|
+
`keras.random` functions when the `seed` is not provided.
|
|
42
|
+
|
|
39
43
|
Note that the TensorFlow seed is set even if you're not using TensorFlow
|
|
40
44
|
as your backend framework, since many workflows leverage `tf.data`
|
|
41
45
|
pipelines (which feature random shuffling). Likewise many workflows
|
|
@@ -52,6 +56,10 @@ def set_random_seed(seed):
|
|
|
52
56
|
|
|
53
57
|
# Store seed in global state so we can query it if set.
|
|
54
58
|
global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
|
|
59
|
+
# Remove global SeedGenerator, it will be recreated from the seed.
|
|
60
|
+
global_state.set_global_attribute(
|
|
61
|
+
seed_generator.GLOBAL_SEED_GENERATOR, None
|
|
62
|
+
)
|
|
55
63
|
random.seed(seed)
|
|
56
64
|
np.random.seed(seed)
|
|
57
65
|
if tf.available:
|
keras/src/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: keras-nightly
|
|
3
|
-
Version: 3.
|
|
3
|
+
Version: 3.14.0.dev2026010104
|
|
4
4
|
Summary: Multi-backend Keras
|
|
5
5
|
Author-email: Keras team <keras-users@googlegroups.com>
|
|
6
6
|
License: Apache License 2.0
|
|
@@ -8,15 +8,15 @@ Project-URL: Home, https://keras.io/
|
|
|
8
8
|
Project-URL: Repository, https://github.com/keras-team/keras
|
|
9
9
|
Classifier: Development Status :: 4 - Beta
|
|
10
10
|
Classifier: Programming Language :: Python :: 3
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
12
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
13
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
14
14
|
Classifier: Operating System :: Unix
|
|
15
15
|
Classifier: Operating System :: MacOS
|
|
16
16
|
Classifier: Intended Audience :: Science/Research
|
|
17
17
|
Classifier: Topic :: Scientific/Engineering
|
|
18
18
|
Classifier: Topic :: Software Development
|
|
19
|
-
Requires-Python: >=3.
|
|
19
|
+
Requires-Python: >=3.11
|
|
20
20
|
Description-Content-Type: text/markdown
|
|
21
21
|
Requires-Dist: absl-py
|
|
22
22
|
Requires-Dist: numpy
|
|
@@ -56,9 +56,8 @@ pip install keras --upgrade
|
|
|
56
56
|
|
|
57
57
|
2. Install backend package(s).
|
|
58
58
|
|
|
59
|
-
To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`.
|
|
60
|
-
|
|
61
|
-
as well as `tf.data` pipelines.
|
|
59
|
+
To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`. Additionally,
|
|
60
|
+
the `openvino` backend is available with support for model inference only.
|
|
62
61
|
|
|
63
62
|
### Local installation
|
|
64
63
|
|
|
@@ -85,6 +84,17 @@ python pip_build.py --install
|
|
|
85
84
|
./shell/api_gen.sh
|
|
86
85
|
```
|
|
87
86
|
|
|
87
|
+
## Backend Compatibility Table
|
|
88
|
+
|
|
89
|
+
The following table lists the minimum supported versions of each backend for the latest stable release of Keras (v3.x):
|
|
90
|
+
|
|
91
|
+
| Backend | Minimum Supported Version |
|
|
92
|
+
|------------|---------------------------|
|
|
93
|
+
| TensorFlow | 2.16.1 |
|
|
94
|
+
| JAX | 0.4.20 |
|
|
95
|
+
| PyTorch | 2.1.0 |
|
|
96
|
+
| OpenVINO | 2025.3.0 |
|
|
97
|
+
|
|
88
98
|
#### Adding GPU support
|
|
89
99
|
|
|
90
100
|
The `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also
|