keras-nightly 3.12.0.dev2025083103__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras/__init__.py +1 -0
- keras/_tf_keras/keras/__init__.py +1 -0
- keras/_tf_keras/keras/callbacks/__init__.py +3 -0
- keras/_tf_keras/keras/distillation/__init__.py +16 -0
- keras/_tf_keras/keras/distribution/__init__.py +3 -0
- keras/_tf_keras/keras/dtype_policies/__init__.py +6 -0
- keras/_tf_keras/keras/layers/__init__.py +21 -0
- keras/_tf_keras/keras/ops/__init__.py +16 -0
- keras/_tf_keras/keras/ops/image/__init__.py +1 -0
- keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
- keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
- keras/_tf_keras/keras/ops/numpy/__init__.py +12 -0
- keras/_tf_keras/keras/quantizers/__init__.py +13 -0
- keras/callbacks/__init__.py +3 -0
- keras/distillation/__init__.py +16 -0
- keras/distribution/__init__.py +3 -0
- keras/dtype_policies/__init__.py +6 -0
- keras/layers/__init__.py +21 -0
- keras/ops/__init__.py +16 -0
- keras/ops/image/__init__.py +1 -0
- keras/ops/linalg/__init__.py +1 -0
- keras/ops/nn/__init__.py +3 -0
- keras/ops/numpy/__init__.py +12 -0
- keras/quantizers/__init__.py +13 -0
- keras/src/applications/imagenet_utils.py +4 -1
- keras/src/backend/common/backend_utils.py +30 -6
- keras/src/backend/common/dtypes.py +6 -12
- keras/src/backend/common/name_scope.py +2 -1
- keras/src/backend/common/variables.py +38 -20
- keras/src/backend/jax/core.py +126 -78
- keras/src/backend/jax/distribution_lib.py +16 -2
- keras/src/backend/jax/layer.py +3 -1
- keras/src/backend/jax/linalg.py +4 -0
- keras/src/backend/jax/nn.py +511 -29
- keras/src/backend/jax/numpy.py +109 -23
- keras/src/backend/jax/optimizer.py +3 -2
- keras/src/backend/jax/trainer.py +18 -3
- keras/src/backend/numpy/linalg.py +4 -0
- keras/src/backend/numpy/nn.py +313 -2
- keras/src/backend/numpy/numpy.py +97 -8
- keras/src/backend/openvino/__init__.py +1 -0
- keras/src/backend/openvino/core.py +6 -23
- keras/src/backend/openvino/linalg.py +4 -0
- keras/src/backend/openvino/nn.py +271 -20
- keras/src/backend/openvino/numpy.py +1369 -195
- keras/src/backend/openvino/random.py +7 -14
- keras/src/backend/tensorflow/layer.py +43 -9
- keras/src/backend/tensorflow/linalg.py +24 -0
- keras/src/backend/tensorflow/nn.py +545 -1
- keras/src/backend/tensorflow/numpy.py +351 -56
- keras/src/backend/tensorflow/trainer.py +6 -2
- keras/src/backend/torch/core.py +3 -1
- keras/src/backend/torch/linalg.py +4 -0
- keras/src/backend/torch/nn.py +125 -0
- keras/src/backend/torch/numpy.py +109 -9
- keras/src/backend/torch/trainer.py +8 -2
- keras/src/callbacks/__init__.py +1 -0
- keras/src/callbacks/callback_list.py +45 -11
- keras/src/callbacks/model_checkpoint.py +5 -0
- keras/src/callbacks/orbax_checkpoint.py +332 -0
- keras/src/callbacks/terminate_on_nan.py +54 -5
- keras/src/datasets/cifar10.py +5 -0
- keras/src/distillation/__init__.py +1 -0
- keras/src/distillation/distillation_loss.py +390 -0
- keras/src/distillation/distiller.py +598 -0
- keras/src/distribution/distribution_lib.py +14 -0
- keras/src/dtype_policies/__init__.py +4 -0
- keras/src/dtype_policies/dtype_policy.py +180 -1
- keras/src/export/__init__.py +2 -0
- keras/src/export/export_utils.py +39 -2
- keras/src/export/litert.py +248 -0
- keras/src/export/onnx.py +6 -0
- keras/src/export/openvino.py +1 -1
- keras/src/export/tf2onnx_lib.py +3 -0
- keras/src/layers/__init__.py +13 -0
- keras/src/layers/activations/softmax.py +9 -4
- keras/src/layers/attention/attention.py +1 -1
- keras/src/layers/attention/multi_head_attention.py +4 -1
- keras/src/layers/core/dense.py +406 -102
- keras/src/layers/core/einsum_dense.py +521 -116
- keras/src/layers/core/embedding.py +257 -99
- keras/src/layers/core/input_layer.py +1 -0
- keras/src/layers/core/reversible_embedding.py +399 -0
- keras/src/layers/input_spec.py +17 -17
- keras/src/layers/layer.py +50 -15
- keras/src/layers/merging/concatenate.py +6 -5
- keras/src/layers/merging/dot.py +4 -1
- keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
- keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
- keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
- keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
- keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
- keras/src/layers/preprocessing/discretization.py +6 -5
- keras/src/layers/preprocessing/feature_space.py +8 -4
- keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
- keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py +5 -5
- keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
- keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
- keras/src/layers/preprocessing/index_lookup.py +19 -1
- keras/src/layers/preprocessing/normalization.py +16 -1
- keras/src/layers/preprocessing/string_lookup.py +26 -28
- keras/src/layers/regularization/dropout.py +43 -1
- keras/src/layers/rnn/gru.py +1 -1
- keras/src/layers/rnn/lstm.py +2 -2
- keras/src/layers/rnn/rnn.py +19 -0
- keras/src/layers/rnn/simple_rnn.py +1 -1
- keras/src/legacy/preprocessing/image.py +4 -1
- keras/src/legacy/preprocessing/sequence.py +20 -12
- keras/src/losses/loss.py +1 -1
- keras/src/losses/losses.py +24 -0
- keras/src/metrics/confusion_metrics.py +7 -6
- keras/src/models/cloning.py +4 -0
- keras/src/models/functional.py +11 -3
- keras/src/models/model.py +195 -44
- keras/src/ops/image.py +257 -20
- keras/src/ops/linalg.py +93 -0
- keras/src/ops/nn.py +268 -2
- keras/src/ops/numpy.py +701 -44
- keras/src/ops/operation.py +90 -29
- keras/src/ops/operation_utils.py +2 -0
- keras/src/optimizers/adafactor.py +29 -10
- keras/src/optimizers/base_optimizer.py +22 -3
- keras/src/optimizers/loss_scale_optimizer.py +51 -18
- keras/src/optimizers/muon.py +65 -31
- keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
- keras/src/quantizers/__init__.py +14 -1
- keras/src/quantizers/awq.py +361 -0
- keras/src/quantizers/awq_config.py +140 -0
- keras/src/quantizers/awq_core.py +217 -0
- keras/src/quantizers/gptq.py +346 -207
- keras/src/quantizers/gptq_config.py +63 -13
- keras/src/quantizers/gptq_core.py +328 -215
- keras/src/quantizers/quantization_config.py +246 -0
- keras/src/quantizers/quantizers.py +407 -38
- keras/src/quantizers/utils.py +23 -0
- keras/src/random/seed_generator.py +6 -4
- keras/src/saving/file_editor.py +81 -6
- keras/src/saving/orbax_util.py +26 -0
- keras/src/saving/saving_api.py +37 -14
- keras/src/saving/saving_lib.py +1 -1
- keras/src/testing/__init__.py +1 -0
- keras/src/testing/test_case.py +45 -5
- keras/src/trainers/compile_utils.py +38 -17
- keras/src/trainers/data_adapters/grain_dataset_adapter.py +1 -5
- keras/src/tree/torchtree_impl.py +215 -0
- keras/src/tree/tree_api.py +6 -1
- keras/src/utils/backend_utils.py +31 -4
- keras/src/utils/dataset_utils.py +234 -35
- keras/src/utils/file_utils.py +49 -11
- keras/src/utils/image_utils.py +14 -2
- keras/src/utils/jax_layer.py +244 -55
- keras/src/utils/module_utils.py +29 -0
- keras/src/utils/progbar.py +10 -12
- keras/src/utils/python_utils.py +5 -0
- keras/src/utils/rng_utils.py +9 -1
- keras/src/utils/tracking.py +70 -5
- keras/src/version.py +1 -1
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +163 -142
- keras/src/quantizers/gptq_quant.py +0 -133
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
- {keras_nightly-3.12.0.dev2025083103.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
keras/src/utils/jax_layer.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import inspect
|
|
3
|
+
import itertools
|
|
4
|
+
import string
|
|
2
5
|
|
|
3
6
|
import numpy as np
|
|
4
7
|
|
|
@@ -8,10 +11,27 @@ from keras.src.api_export import keras_export
|
|
|
8
11
|
from keras.src.backend.common.variables import is_float_dtype
|
|
9
12
|
from keras.src.backend.common.variables import standardize_dtype
|
|
10
13
|
from keras.src.layers.layer import Layer
|
|
14
|
+
from keras.src.random.seed_generator import draw_seed
|
|
11
15
|
from keras.src.saving import serialization_lib
|
|
12
16
|
from keras.src.utils import jax_utils
|
|
13
17
|
from keras.src.utils import tracking
|
|
14
18
|
from keras.src.utils.module_utils import jax
|
|
19
|
+
from keras.src.utils.module_utils import tensorflow as tf
|
|
20
|
+
|
|
21
|
+
if backend.backend() == "tensorflow":
|
|
22
|
+
tf_no_automatic_dependency_tracking = (
|
|
23
|
+
tf.__internal__.tracking.no_automatic_dependency_tracking
|
|
24
|
+
)
|
|
25
|
+
else:
|
|
26
|
+
|
|
27
|
+
def tf_no_automatic_dependency_tracking(fn):
|
|
28
|
+
return fn
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _convert_to_jax_key(tensor):
|
|
32
|
+
if backend.backend() == "tensorflow":
|
|
33
|
+
return tf.bitcast(tensor, tf.uint32)[0]
|
|
34
|
+
return tensor
|
|
15
35
|
|
|
16
36
|
|
|
17
37
|
@keras_export("keras.layers.JaxLayer")
|
|
@@ -219,21 +239,15 @@ class JaxLayer(Layer):
|
|
|
219
239
|
seed=None,
|
|
220
240
|
**kwargs,
|
|
221
241
|
):
|
|
222
|
-
if backend.backend()
|
|
242
|
+
if backend.backend() not in ["jax", "tensorflow"]:
|
|
223
243
|
raise ValueError(
|
|
224
|
-
"
|
|
225
|
-
f"backend: {backend.backend()}"
|
|
226
|
-
)
|
|
227
|
-
|
|
228
|
-
if init_fn is None and params is None and state is None:
|
|
229
|
-
raise ValueError(
|
|
230
|
-
"`init_fn`, `params` and `state` cannot all be `None`."
|
|
244
|
+
f"{self.__class__.__name__} is only supported with the JAX or"
|
|
245
|
+
f" Tensorflow backend. Current backend: {backend.backend()}"
|
|
231
246
|
)
|
|
232
247
|
|
|
233
248
|
super().__init__(**kwargs)
|
|
234
249
|
self.call_fn = call_fn
|
|
235
250
|
self.init_fn = init_fn
|
|
236
|
-
self.seed_generator = backend.random.SeedGenerator(seed)
|
|
237
251
|
self.tracked_params = self._create_variables(params, trainable=True)
|
|
238
252
|
self.tracked_state = self._create_variables(state, trainable=False)
|
|
239
253
|
if self.params is not None or self.state is not None:
|
|
@@ -245,13 +259,35 @@ class JaxLayer(Layer):
|
|
|
245
259
|
{"params", "state", "rng", "inputs", "training"},
|
|
246
260
|
{"inputs"},
|
|
247
261
|
)
|
|
248
|
-
self.
|
|
262
|
+
self.call_fn_has_params = "params" in self.call_fn_arguments
|
|
263
|
+
self.call_fn_has_state = "state" in self.call_fn_arguments
|
|
264
|
+
call_fn_has_rng = "rng" in self.call_fn_arguments
|
|
265
|
+
|
|
266
|
+
if call_fn_has_rng:
|
|
267
|
+
self.seed_generator = backend.random.SeedGenerator(seed)
|
|
268
|
+
else:
|
|
269
|
+
self.seed_generator = None
|
|
270
|
+
|
|
271
|
+
if (
|
|
272
|
+
init_fn is None
|
|
273
|
+
and params is None
|
|
274
|
+
and state is None
|
|
275
|
+
and (self.call_fn_has_params or self.call_fn_has_state)
|
|
276
|
+
):
|
|
277
|
+
raise ValueError(
|
|
278
|
+
"`init_fn`, `params` and `state` cannot all be `None` when "
|
|
279
|
+
"`call_fn` takes a `params` or a `state` argument."
|
|
280
|
+
)
|
|
249
281
|
|
|
250
282
|
if init_fn:
|
|
251
283
|
self.init_fn_arguments = self._validate_signature(
|
|
252
284
|
init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
|
|
253
285
|
)
|
|
254
286
|
|
|
287
|
+
# Attributes for jax2tf functions
|
|
288
|
+
self.jax2tf_training_false_fn = None
|
|
289
|
+
self.jax2tf_training_true_fn = None
|
|
290
|
+
|
|
255
291
|
def _validate_signature(self, fn, fn_name, allowed, required):
|
|
256
292
|
fn_parameters = inspect.signature(fn).parameters
|
|
257
293
|
for parameter_name in required:
|
|
@@ -272,7 +308,81 @@ class JaxLayer(Layer):
|
|
|
272
308
|
|
|
273
309
|
return parameter_names
|
|
274
310
|
|
|
311
|
+
def _get_jax2tf_input_shape(self, input_shape):
|
|
312
|
+
"""Convert input shape in a format suitable for `jax2tf`.
|
|
313
|
+
|
|
314
|
+
`jax2tf` expects a letter for each unknown dimension, which allows
|
|
315
|
+
correlated dimensions. Since correlated dimensions are not supported by
|
|
316
|
+
Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. We
|
|
317
|
+
however use 'batch' for dimension 0 if not defined to correlate the
|
|
318
|
+
batch size across inputs.
|
|
319
|
+
|
|
320
|
+
Example (spaces added for readability):
|
|
321
|
+
```
|
|
322
|
+
input_shape: (None , 4 , None, None, 5 )
|
|
323
|
+
result: "(batch, 4 , a , b , 5 )"
|
|
324
|
+
```
|
|
325
|
+
|
|
326
|
+
Args:
|
|
327
|
+
input_shape: a single shape or a structure of shapes for the inputs.
|
|
328
|
+
Returns:
|
|
329
|
+
the shape or shapes structure in the `jax2tf` format as strings.
|
|
330
|
+
"""
|
|
331
|
+
dim_names = itertools.chain(
|
|
332
|
+
string.ascii_lowercase, # a, b, ... z
|
|
333
|
+
itertools.starmap( # aa, ab, ... az, ba, bb, ... zz
|
|
334
|
+
lambda a, b: a + b,
|
|
335
|
+
itertools.product(string.ascii_lowercase, repeat=2),
|
|
336
|
+
),
|
|
337
|
+
)
|
|
338
|
+
|
|
339
|
+
def get_single_jax2tf_shape(shape):
|
|
340
|
+
jax2tf_shape = []
|
|
341
|
+
|
|
342
|
+
for index, dim in enumerate(shape):
|
|
343
|
+
if dim is not None:
|
|
344
|
+
jax2tf_shape.append(str(dim))
|
|
345
|
+
elif index == 0:
|
|
346
|
+
jax2tf_shape.append("batch")
|
|
347
|
+
else:
|
|
348
|
+
jax2tf_shape.append(next(dim_names))
|
|
349
|
+
|
|
350
|
+
return "(" + ", ".join(jax2tf_shape) + ")"
|
|
351
|
+
|
|
352
|
+
res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)
|
|
353
|
+
return res
|
|
354
|
+
|
|
355
|
+
def _jax2tf_convert(self, fn, polymorphic_shapes):
|
|
356
|
+
from jax.experimental import jax2tf
|
|
357
|
+
|
|
358
|
+
converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)
|
|
359
|
+
# Autograph won't work with the output of jax2tf.
|
|
360
|
+
converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)
|
|
361
|
+
return converted_fn
|
|
362
|
+
|
|
363
|
+
def _partial_with_positional(self, fn, index, value):
|
|
364
|
+
"""Return a new partial with one positional argument set to a value.
|
|
365
|
+
|
|
366
|
+
This is needed because `jax2tf` only supports positional arguments and
|
|
367
|
+
`functools.partial` only supports setting positional arguments starting
|
|
368
|
+
from the left. Our use case is the `training` argument which is
|
|
369
|
+
typically the righmost argument.
|
|
370
|
+
|
|
371
|
+
Args:
|
|
372
|
+
fn: the function to wrap.
|
|
373
|
+
index: the index of the positional argument to set to `value`.
|
|
374
|
+
value: the value for the positional argument at `index`.
|
|
375
|
+
"""
|
|
376
|
+
|
|
377
|
+
@functools.wraps(fn)
|
|
378
|
+
def wrapper(*args):
|
|
379
|
+
args = args[0:index] + (value,) + args[index:]
|
|
380
|
+
return fn(*args)
|
|
381
|
+
|
|
382
|
+
return wrapper
|
|
383
|
+
|
|
275
384
|
@tracking.no_automatic_dependency_tracking
|
|
385
|
+
@tf_no_automatic_dependency_tracking
|
|
276
386
|
def _create_variables(self, values, trainable):
|
|
277
387
|
"""Create a structure of variables from a structure of JAX arrays.
|
|
278
388
|
|
|
@@ -296,14 +406,14 @@ class JaxLayer(Layer):
|
|
|
296
406
|
|
|
297
407
|
def create_variable(value):
|
|
298
408
|
if backend.is_tensor(value) or isinstance(
|
|
299
|
-
value, (np.ndarray, np.generic)
|
|
409
|
+
value, (np.ndarray, np.generic, jax.Array)
|
|
300
410
|
):
|
|
301
411
|
dtype = value.dtype
|
|
302
412
|
if is_float_dtype(dtype):
|
|
303
413
|
dtype = None # Use the layer dtype policy
|
|
304
414
|
return self.add_weight(
|
|
305
415
|
value.shape,
|
|
306
|
-
initializer=value,
|
|
416
|
+
initializer=backend.convert_to_tensor(value),
|
|
307
417
|
dtype=dtype,
|
|
308
418
|
trainable=trainable,
|
|
309
419
|
)
|
|
@@ -331,46 +441,69 @@ class JaxLayer(Layer):
|
|
|
331
441
|
flat_variables, _ = jax.tree_util.tree_flatten(variables)
|
|
332
442
|
return flat_variables
|
|
333
443
|
|
|
444
|
+
def _get_init_seed(self):
|
|
445
|
+
"""
|
|
446
|
+
Returns a single seed as a tensor of shape [2].
|
|
447
|
+
|
|
448
|
+
Call this within `_get_init_rng()` to obtain a new seed.
|
|
449
|
+
|
|
450
|
+
Returns:
|
|
451
|
+
A native tensor of shape [2] and the backend dtype for seeds.
|
|
452
|
+
"""
|
|
453
|
+
# Use the global SeedGenerator.
|
|
454
|
+
return draw_seed(None)
|
|
455
|
+
|
|
334
456
|
def _get_init_rng(self):
|
|
335
457
|
"""
|
|
336
|
-
Returns a
|
|
458
|
+
Returns a seed or seeds to pass as the `rng` argument of `init_fn`.
|
|
337
459
|
|
|
338
|
-
By default, this returns a single
|
|
339
|
-
`self.
|
|
340
|
-
|
|
460
|
+
By default, this returns a single seed. Override this to return a
|
|
461
|
+
different structure. Overrides should use `self._get_init_seed()` to
|
|
462
|
+
obtain new seeds.
|
|
341
463
|
|
|
342
464
|
Returns:
|
|
343
|
-
|
|
344
|
-
|
|
465
|
+
RNG key or structure of keys as tensors of shape [2] and the backend
|
|
466
|
+
dtype for seeds.
|
|
467
|
+
"""
|
|
468
|
+
return self._get_init_seed()
|
|
469
|
+
|
|
470
|
+
def _get_call_seed(self):
|
|
471
|
+
"""
|
|
472
|
+
Returns a single seed as a tensor of shape [2].
|
|
473
|
+
|
|
474
|
+
Call this within `_get_call_rng()` to obtain a new seed.
|
|
475
|
+
|
|
476
|
+
Returns:
|
|
477
|
+
A native tensor of shape [2] and the backend dtype for seeds.
|
|
345
478
|
"""
|
|
346
479
|
return self.seed_generator.next()
|
|
347
480
|
|
|
348
481
|
def _get_call_rng(self, training):
|
|
349
482
|
"""
|
|
350
|
-
Returns a
|
|
483
|
+
Returns a seed or seeds to pass as the `rng` argument of `call_fn`.
|
|
351
484
|
|
|
352
|
-
By default, this returns a
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
to
|
|
485
|
+
By default, this returns a seed when `training` is `True`, and `None`
|
|
486
|
+
when `training` is `False`. Override this to return a different
|
|
487
|
+
structure or to pass seeds in inference mode too. Overrides should use
|
|
488
|
+
`self._get_call_seed()` to obtain seeds.
|
|
356
489
|
|
|
357
490
|
Returns:
|
|
358
|
-
|
|
359
|
-
|
|
491
|
+
RNG key or structure of keys as tensors of shape [2] and the backend
|
|
492
|
+
dtype for seeds.
|
|
360
493
|
"""
|
|
361
494
|
if training:
|
|
362
|
-
return self.
|
|
495
|
+
return self._get_call_seed()
|
|
363
496
|
else:
|
|
364
497
|
return None
|
|
365
498
|
|
|
366
|
-
def
|
|
367
|
-
if
|
|
368
|
-
return
|
|
369
|
-
|
|
370
|
-
if jax_utils.is_in_jax_tracing_scope():
|
|
499
|
+
def _initialize_weights(self, input_shape):
|
|
500
|
+
if jax_utils.is_in_jax_tracing_scope() or tf.inside_function():
|
|
371
501
|
# This exception is not actually shown, it is caught and a detailed
|
|
372
502
|
# warning about calling 'build' is printed.
|
|
373
|
-
raise ValueError(
|
|
503
|
+
raise ValueError(
|
|
504
|
+
"'JaxLayer' cannot be built in tracing scope"
|
|
505
|
+
"or inside tf function"
|
|
506
|
+
)
|
|
374
507
|
|
|
375
508
|
# Initialize `params` and `state` if needed by calling `init_fn`.
|
|
376
509
|
def create_input(shape):
|
|
@@ -381,14 +514,19 @@ class JaxLayer(Layer):
|
|
|
381
514
|
init_args = []
|
|
382
515
|
for argument_name in self.init_fn_arguments:
|
|
383
516
|
if argument_name == "rng":
|
|
384
|
-
init_args.append(
|
|
517
|
+
init_args.append(
|
|
518
|
+
jax.tree_util.tree_map(
|
|
519
|
+
lambda x: jax.numpy.array(_convert_to_jax_key(x)),
|
|
520
|
+
self._get_init_rng(),
|
|
521
|
+
)
|
|
522
|
+
)
|
|
385
523
|
elif argument_name == "inputs":
|
|
386
524
|
init_args.append(init_inputs)
|
|
387
525
|
elif argument_name == "training":
|
|
388
526
|
init_args.append(True)
|
|
389
527
|
|
|
390
528
|
init_result = self.init_fn(*init_args)
|
|
391
|
-
if self.
|
|
529
|
+
if self.call_fn_has_state:
|
|
392
530
|
init_params, init_state = init_result
|
|
393
531
|
else:
|
|
394
532
|
init_params, init_state = init_result, None
|
|
@@ -398,6 +536,49 @@ class JaxLayer(Layer):
|
|
|
398
536
|
)
|
|
399
537
|
self.tracked_state = self._create_variables(init_state, trainable=False)
|
|
400
538
|
|
|
539
|
+
def build(self, input_shape):
|
|
540
|
+
if (
|
|
541
|
+
self.params is None
|
|
542
|
+
and self.state is None
|
|
543
|
+
and (self.call_fn_has_params or self.call_fn_has_state)
|
|
544
|
+
):
|
|
545
|
+
self._initialize_weights(input_shape)
|
|
546
|
+
|
|
547
|
+
if backend.backend() == "tensorflow":
|
|
548
|
+
polymorphic_shapes = []
|
|
549
|
+
for argument in self.call_fn_arguments:
|
|
550
|
+
if argument == "inputs":
|
|
551
|
+
polymorphic_shapes.append(
|
|
552
|
+
self._get_jax2tf_input_shape(input_shape)
|
|
553
|
+
)
|
|
554
|
+
elif argument != "training":
|
|
555
|
+
# params, state, rng
|
|
556
|
+
polymorphic_shapes.append("...")
|
|
557
|
+
|
|
558
|
+
if "training" in self.call_fn_arguments:
|
|
559
|
+
training_argument_index = self.call_fn_arguments.index(
|
|
560
|
+
"training"
|
|
561
|
+
)
|
|
562
|
+
self.jax2tf_training_false_fn = self._jax2tf_convert(
|
|
563
|
+
self._partial_with_positional(
|
|
564
|
+
self.call_fn, training_argument_index, False
|
|
565
|
+
),
|
|
566
|
+
polymorphic_shapes,
|
|
567
|
+
)
|
|
568
|
+
self.jax2tf_training_true_fn = self._jax2tf_convert(
|
|
569
|
+
self._partial_with_positional(
|
|
570
|
+
self.call_fn, training_argument_index, True
|
|
571
|
+
),
|
|
572
|
+
polymorphic_shapes,
|
|
573
|
+
)
|
|
574
|
+
else:
|
|
575
|
+
self.jax2tf_training_false_fn = self._jax2tf_convert(
|
|
576
|
+
self.call_fn,
|
|
577
|
+
polymorphic_shapes,
|
|
578
|
+
)
|
|
579
|
+
self.jax2tf_training_true_fn = None
|
|
580
|
+
super().build(input_shape)
|
|
581
|
+
|
|
401
582
|
def call(self, inputs, training=False):
|
|
402
583
|
def unwrap_variable(variable):
|
|
403
584
|
return None if variable is None else variable.value
|
|
@@ -413,11 +594,16 @@ class JaxLayer(Layer):
|
|
|
413
594
|
jax.tree_util.tree_map(unwrap_variable, self.state)
|
|
414
595
|
)
|
|
415
596
|
elif argument_name == "rng":
|
|
416
|
-
call_args.append(
|
|
597
|
+
call_args.append(
|
|
598
|
+
jax.tree_util.tree_map(
|
|
599
|
+
_convert_to_jax_key, self._get_call_rng(training)
|
|
600
|
+
)
|
|
601
|
+
)
|
|
417
602
|
elif argument_name == "inputs":
|
|
418
603
|
call_args.append(inputs)
|
|
419
604
|
elif argument_name == "training":
|
|
420
|
-
|
|
605
|
+
if backend.backend() == "jax":
|
|
606
|
+
call_args.append(training)
|
|
421
607
|
|
|
422
608
|
def assign_state_to_variable(value, variable):
|
|
423
609
|
# This exists only to make debugging this error case easier.
|
|
@@ -429,14 +615,23 @@ class JaxLayer(Layer):
|
|
|
429
615
|
)
|
|
430
616
|
variable.assign(value)
|
|
431
617
|
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
618
|
+
def call_with_fn(fn):
|
|
619
|
+
if self.call_fn_has_state:
|
|
620
|
+
predictions, new_state = fn(*call_args)
|
|
621
|
+
jax.tree_util.tree_map(
|
|
622
|
+
assign_state_to_variable, new_state, self.state
|
|
623
|
+
)
|
|
624
|
+
return predictions
|
|
625
|
+
else:
|
|
626
|
+
return fn(*call_args)
|
|
627
|
+
|
|
628
|
+
if backend.backend() == "jax":
|
|
629
|
+
return call_with_fn(self.call_fn)
|
|
630
|
+
elif backend.backend() == "tensorflow":
|
|
631
|
+
if training and self.jax2tf_training_true_fn is not None:
|
|
632
|
+
return call_with_fn(self.jax2tf_training_true_fn)
|
|
633
|
+
else:
|
|
634
|
+
return call_with_fn(self.jax2tf_training_false_fn)
|
|
440
635
|
|
|
441
636
|
def get_config(self):
|
|
442
637
|
config = {
|
|
@@ -554,18 +749,12 @@ class FlaxLayer(JaxLayer):
|
|
|
554
749
|
**kwargs,
|
|
555
750
|
):
|
|
556
751
|
# Late import to only require Flax when this is used.
|
|
557
|
-
from flax.
|
|
558
|
-
|
|
559
|
-
if backend.backend() != "jax":
|
|
560
|
-
raise ValueError(
|
|
561
|
-
"FlaxLayer is only supported with the JAX backend. Current "
|
|
562
|
-
f"backend: {backend.backend()}"
|
|
563
|
-
)
|
|
752
|
+
from flax.linen import DenyList
|
|
564
753
|
|
|
565
754
|
self.module = module
|
|
566
755
|
self.method = method
|
|
567
756
|
|
|
568
|
-
apply_mutable =
|
|
757
|
+
apply_mutable = DenyList(["params"])
|
|
569
758
|
|
|
570
759
|
def apply_with_training(params, state, rng, inputs, training):
|
|
571
760
|
return self.module.apply(
|
|
@@ -650,13 +839,13 @@ class FlaxLayer(JaxLayer):
|
|
|
650
839
|
|
|
651
840
|
def _get_init_rng(self):
|
|
652
841
|
return {
|
|
653
|
-
"params": self.
|
|
654
|
-
"dropout": self.
|
|
842
|
+
"params": self._get_init_seed(),
|
|
843
|
+
"dropout": self._get_init_seed(),
|
|
655
844
|
}
|
|
656
845
|
|
|
657
846
|
def _get_call_rng(self, training):
|
|
658
847
|
if training:
|
|
659
|
-
return {"dropout": self.
|
|
848
|
+
return {"dropout": self._get_call_seed()}
|
|
660
849
|
else:
|
|
661
850
|
return {}
|
|
662
851
|
|
keras/src/utils/module_utils.py
CHANGED
|
@@ -39,11 +39,31 @@ class LazyModule:
|
|
|
39
39
|
return f"LazyModule({self.name})"
|
|
40
40
|
|
|
41
41
|
|
|
42
|
+
class OrbaxLazyModule(LazyModule):
|
|
43
|
+
def initialize(self):
|
|
44
|
+
try:
|
|
45
|
+
parent_module = importlib.import_module("orbax.checkpoint")
|
|
46
|
+
self.module = parent_module.v1
|
|
47
|
+
self.parent_module = parent_module
|
|
48
|
+
except ImportError:
|
|
49
|
+
raise ImportError(self.import_error_msg)
|
|
50
|
+
|
|
51
|
+
def __getattr__(self, name):
|
|
52
|
+
if name == "_api_export_path":
|
|
53
|
+
raise AttributeError
|
|
54
|
+
if self.module is None:
|
|
55
|
+
self.initialize()
|
|
56
|
+
if name == "multihost":
|
|
57
|
+
return self.parent_module.multihost
|
|
58
|
+
return getattr(self.module, name)
|
|
59
|
+
|
|
60
|
+
|
|
42
61
|
tensorflow = LazyModule("tensorflow")
|
|
43
62
|
gfile = LazyModule("tensorflow.io.gfile", pip_name="tensorflow")
|
|
44
63
|
tensorflow_io = LazyModule("tensorflow_io")
|
|
45
64
|
scipy = LazyModule("scipy")
|
|
46
65
|
jax = LazyModule("jax")
|
|
66
|
+
h5py = LazyModule("h5py")
|
|
47
67
|
torch_xla = LazyModule(
|
|
48
68
|
"torch_xla",
|
|
49
69
|
import_error_msg=(
|
|
@@ -59,3 +79,12 @@ optree = LazyModule("optree")
|
|
|
59
79
|
dmtree = LazyModule("tree")
|
|
60
80
|
tf2onnx = LazyModule("tf2onnx")
|
|
61
81
|
grain = LazyModule("grain")
|
|
82
|
+
litert = LazyModule("ai_edge_litert")
|
|
83
|
+
ocp = OrbaxLazyModule(
|
|
84
|
+
"orbax.checkpoint.v1",
|
|
85
|
+
pip_name="orbax-checkpoint",
|
|
86
|
+
import_error_msg=(
|
|
87
|
+
"OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
|
|
88
|
+
"You can install it via pip install orbax-checkpoint"
|
|
89
|
+
),
|
|
90
|
+
)
|
keras/src/utils/progbar.py
CHANGED
|
@@ -3,7 +3,8 @@ import os
|
|
|
3
3
|
import sys
|
|
4
4
|
import time
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
7
8
|
from keras.src.api_export import keras_export
|
|
8
9
|
from keras.src.utils import io_utils
|
|
9
10
|
|
|
@@ -162,12 +163,10 @@ class Progbar:
|
|
|
162
163
|
for k in self._values_order:
|
|
163
164
|
info += f" - {k}:"
|
|
164
165
|
if isinstance(self._values[k], list):
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
)
|
|
170
|
-
avg = float(avg)
|
|
166
|
+
values, count = self._values[k]
|
|
167
|
+
if not isinstance(values, float):
|
|
168
|
+
values = np.mean(values)
|
|
169
|
+
avg = values / max(1, count)
|
|
171
170
|
if abs(avg) > 1e-3:
|
|
172
171
|
info += f" {avg:.4f}"
|
|
173
172
|
else:
|
|
@@ -194,11 +193,10 @@ class Progbar:
|
|
|
194
193
|
info += f" -{self._format_time(time_per_unit, self.unit_name)}"
|
|
195
194
|
for k in self._values_order:
|
|
196
195
|
info += f" - {k}:"
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
)
|
|
196
|
+
values, count = self._values[k]
|
|
197
|
+
if not isinstance(values, float):
|
|
198
|
+
values = np.mean(values)
|
|
199
|
+
avg = values / max(1, count)
|
|
202
200
|
if avg > 1e-3:
|
|
203
201
|
info += f" {avg:.4f}"
|
|
204
202
|
else:
|
keras/src/utils/python_utils.py
CHANGED
|
@@ -181,6 +181,8 @@ def pythonify_logs(logs):
|
|
|
181
181
|
A flattened dict with values converted to Python-native types if
|
|
182
182
|
possible.
|
|
183
183
|
"""
|
|
184
|
+
from keras.src import backend
|
|
185
|
+
|
|
184
186
|
logs = logs or {}
|
|
185
187
|
result = {}
|
|
186
188
|
for key, value in sorted(logs.items()):
|
|
@@ -188,6 +190,9 @@ def pythonify_logs(logs):
|
|
|
188
190
|
result.update(pythonify_logs(value))
|
|
189
191
|
else:
|
|
190
192
|
try:
|
|
193
|
+
# Prevent torch compiler from breaking the graph.
|
|
194
|
+
if backend.is_tensor(value):
|
|
195
|
+
value = backend.convert_to_numpy(value)
|
|
191
196
|
value = float(value)
|
|
192
197
|
except:
|
|
193
198
|
pass
|
keras/src/utils/rng_utils.py
CHANGED
|
@@ -5,6 +5,7 @@ import numpy as np
|
|
|
5
5
|
from keras.src import backend
|
|
6
6
|
from keras.src.api_export import keras_export
|
|
7
7
|
from keras.src.backend.common import global_state
|
|
8
|
+
from keras.src.random import seed_generator
|
|
8
9
|
from keras.src.utils.module_utils import tensorflow as tf
|
|
9
10
|
|
|
10
11
|
GLOBAL_RANDOM_SEED = "global_random_seed"
|
|
@@ -20,7 +21,7 @@ def set_random_seed(seed):
|
|
|
20
21
|
sources of randomness, or when certain non-deterministic cuDNN ops are
|
|
21
22
|
involved.
|
|
22
23
|
|
|
23
|
-
Calling this utility
|
|
24
|
+
Calling this utility does the following:
|
|
24
25
|
|
|
25
26
|
```python
|
|
26
27
|
import random
|
|
@@ -36,6 +37,9 @@ def set_random_seed(seed):
|
|
|
36
37
|
torch.manual_seed(seed)
|
|
37
38
|
```
|
|
38
39
|
|
|
40
|
+
Additionally, it resets the global Keras `SeedGenerator`, which is used by
|
|
41
|
+
`keras.random` functions when the `seed` is not provided.
|
|
42
|
+
|
|
39
43
|
Note that the TensorFlow seed is set even if you're not using TensorFlow
|
|
40
44
|
as your backend framework, since many workflows leverage `tf.data`
|
|
41
45
|
pipelines (which feature random shuffling). Likewise many workflows
|
|
@@ -52,6 +56,10 @@ def set_random_seed(seed):
|
|
|
52
56
|
|
|
53
57
|
# Store seed in global state so we can query it if set.
|
|
54
58
|
global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
|
|
59
|
+
# Remove global SeedGenerator, it will be recreated from the seed.
|
|
60
|
+
global_state.set_global_attribute(
|
|
61
|
+
seed_generator.GLOBAL_SEED_GENERATOR, None
|
|
62
|
+
)
|
|
55
63
|
random.seed(seed)
|
|
56
64
|
np.random.seed(seed)
|
|
57
65
|
if tf.available:
|