keras-nightly 3.14.0.dev2025122704__py3-none-any.whl → 3.14.0.dev2026012204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  2. keras/_tf_keras/keras/ops/__init__.py +3 -0
  3. keras/_tf_keras/keras/ops/numpy/__init__.py +3 -0
  4. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  5. keras/dtype_policies/__init__.py +3 -0
  6. keras/ops/__init__.py +3 -0
  7. keras/ops/numpy/__init__.py +3 -0
  8. keras/quantizers/__init__.py +1 -0
  9. keras/src/backend/jax/nn.py +26 -9
  10. keras/src/backend/jax/numpy.py +16 -0
  11. keras/src/backend/numpy/numpy.py +23 -0
  12. keras/src/backend/openvino/numpy.py +369 -16
  13. keras/src/backend/tensorflow/numpy.py +34 -1
  14. keras/src/backend/tensorflow/rnn.py +17 -7
  15. keras/src/backend/torch/numpy.py +36 -0
  16. keras/src/backend/torch/rnn.py +28 -11
  17. keras/src/callbacks/orbax_checkpoint.py +75 -42
  18. keras/src/dtype_policies/__init__.py +2 -0
  19. keras/src/dtype_policies/dtype_policy.py +90 -1
  20. keras/src/layers/core/dense.py +122 -6
  21. keras/src/layers/core/einsum_dense.py +151 -7
  22. keras/src/layers/core/embedding.py +1 -1
  23. keras/src/layers/core/reversible_embedding.py +10 -1
  24. keras/src/layers/layer.py +5 -0
  25. keras/src/layers/preprocessing/feature_space.py +8 -4
  26. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  27. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +13 -15
  28. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  29. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  30. keras/src/losses/losses.py +24 -0
  31. keras/src/models/model.py +18 -9
  32. keras/src/ops/image.py +109 -96
  33. keras/src/ops/numpy.py +181 -0
  34. keras/src/quantizers/__init__.py +2 -0
  35. keras/src/quantizers/awq.py +361 -0
  36. keras/src/quantizers/awq_config.py +140 -0
  37. keras/src/quantizers/awq_core.py +217 -0
  38. keras/src/quantizers/gptq.py +1 -2
  39. keras/src/quantizers/gptq_core.py +1 -1
  40. keras/src/quantizers/quantization_config.py +14 -0
  41. keras/src/quantizers/quantizers.py +61 -52
  42. keras/src/random/seed_generator.py +2 -2
  43. keras/src/saving/file_editor.py +81 -6
  44. keras/src/saving/orbax_util.py +50 -0
  45. keras/src/saving/saving_api.py +37 -14
  46. keras/src/utils/jax_layer.py +69 -31
  47. keras/src/utils/module_utils.py +11 -0
  48. keras/src/utils/tracking.py +5 -5
  49. keras/src/version.py +1 -1
  50. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/METADATA +1 -1
  51. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/RECORD +53 -49
  52. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/WHEEL +1 -1
  53. {keras_nightly-3.14.0.dev2025122704.dist-info → keras_nightly-3.14.0.dev2026012204.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,9 @@ since your modifications would be overwritten.
7
7
  from keras.src.dtype_policies import deserialize as deserialize
8
8
  from keras.src.dtype_policies import get as get
9
9
  from keras.src.dtype_policies import serialize as serialize
10
+ from keras.src.dtype_policies.dtype_policy import (
11
+ AWQDTypePolicy as AWQDTypePolicy,
12
+ )
10
13
  from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
11
14
  from keras.src.dtype_policies.dtype_policy import (
12
15
  FloatDTypePolicy as FloatDTypePolicy,
@@ -245,8 +245,10 @@ from keras.src.ops.numpy import mod as mod
245
245
  from keras.src.ops.numpy import moveaxis as moveaxis
246
246
  from keras.src.ops.numpy import multiply as multiply
247
247
  from keras.src.ops.numpy import nan_to_num as nan_to_num
248
+ from keras.src.ops.numpy import nansum as nansum
248
249
  from keras.src.ops.numpy import ndim as ndim
249
250
  from keras.src.ops.numpy import negative as negative
251
+ from keras.src.ops.numpy import nextafter as nextafter
250
252
  from keras.src.ops.numpy import nonzero as nonzero
251
253
  from keras.src.ops.numpy import not_equal as not_equal
252
254
  from keras.src.ops.numpy import ones as ones
@@ -255,6 +257,7 @@ from keras.src.ops.numpy import outer as outer
255
257
  from keras.src.ops.numpy import pad as pad
256
258
  from keras.src.ops.numpy import power as power
257
259
  from keras.src.ops.numpy import prod as prod
260
+ from keras.src.ops.numpy import ptp as ptp
258
261
  from keras.src.ops.numpy import quantile as quantile
259
262
  from keras.src.ops.numpy import ravel as ravel
260
263
  from keras.src.ops.numpy import real as real
@@ -129,8 +129,10 @@ from keras.src.ops.numpy import mod as mod
129
129
  from keras.src.ops.numpy import moveaxis as moveaxis
130
130
  from keras.src.ops.numpy import multiply as multiply
131
131
  from keras.src.ops.numpy import nan_to_num as nan_to_num
132
+ from keras.src.ops.numpy import nansum as nansum
132
133
  from keras.src.ops.numpy import ndim as ndim
133
134
  from keras.src.ops.numpy import negative as negative
135
+ from keras.src.ops.numpy import nextafter as nextafter
134
136
  from keras.src.ops.numpy import nonzero as nonzero
135
137
  from keras.src.ops.numpy import not_equal as not_equal
136
138
  from keras.src.ops.numpy import ones as ones
@@ -139,6 +141,7 @@ from keras.src.ops.numpy import outer as outer
139
141
  from keras.src.ops.numpy import pad as pad
140
142
  from keras.src.ops.numpy import power as power
141
143
  from keras.src.ops.numpy import prod as prod
144
+ from keras.src.ops.numpy import ptp as ptp
142
145
  from keras.src.ops.numpy import quantile as quantile
143
146
  from keras.src.ops.numpy import ravel as ravel
144
147
  from keras.src.ops.numpy import real as real
@@ -7,6 +7,7 @@ since your modifications would be overwritten.
7
7
  from keras.src.quantizers import deserialize as deserialize
8
8
  from keras.src.quantizers import get as get
9
9
  from keras.src.quantizers import serialize as serialize
10
+ from keras.src.quantizers.awq_config import AWQConfig as AWQConfig
10
11
  from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
11
12
  from keras.src.quantizers.quantization_config import (
12
13
  Float8QuantizationConfig as Float8QuantizationConfig,
@@ -7,6 +7,9 @@ since your modifications would be overwritten.
7
7
  from keras.src.dtype_policies import deserialize as deserialize
8
8
  from keras.src.dtype_policies import get as get
9
9
  from keras.src.dtype_policies import serialize as serialize
10
+ from keras.src.dtype_policies.dtype_policy import (
11
+ AWQDTypePolicy as AWQDTypePolicy,
12
+ )
10
13
  from keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy
11
14
  from keras.src.dtype_policies.dtype_policy import (
12
15
  FloatDTypePolicy as FloatDTypePolicy,
keras/ops/__init__.py CHANGED
@@ -245,8 +245,10 @@ from keras.src.ops.numpy import mod as mod
245
245
  from keras.src.ops.numpy import moveaxis as moveaxis
246
246
  from keras.src.ops.numpy import multiply as multiply
247
247
  from keras.src.ops.numpy import nan_to_num as nan_to_num
248
+ from keras.src.ops.numpy import nansum as nansum
248
249
  from keras.src.ops.numpy import ndim as ndim
249
250
  from keras.src.ops.numpy import negative as negative
251
+ from keras.src.ops.numpy import nextafter as nextafter
250
252
  from keras.src.ops.numpy import nonzero as nonzero
251
253
  from keras.src.ops.numpy import not_equal as not_equal
252
254
  from keras.src.ops.numpy import ones as ones
@@ -255,6 +257,7 @@ from keras.src.ops.numpy import outer as outer
255
257
  from keras.src.ops.numpy import pad as pad
256
258
  from keras.src.ops.numpy import power as power
257
259
  from keras.src.ops.numpy import prod as prod
260
+ from keras.src.ops.numpy import ptp as ptp
258
261
  from keras.src.ops.numpy import quantile as quantile
259
262
  from keras.src.ops.numpy import ravel as ravel
260
263
  from keras.src.ops.numpy import real as real
@@ -129,8 +129,10 @@ from keras.src.ops.numpy import mod as mod
129
129
  from keras.src.ops.numpy import moveaxis as moveaxis
130
130
  from keras.src.ops.numpy import multiply as multiply
131
131
  from keras.src.ops.numpy import nan_to_num as nan_to_num
132
+ from keras.src.ops.numpy import nansum as nansum
132
133
  from keras.src.ops.numpy import ndim as ndim
133
134
  from keras.src.ops.numpy import negative as negative
135
+ from keras.src.ops.numpy import nextafter as nextafter
134
136
  from keras.src.ops.numpy import nonzero as nonzero
135
137
  from keras.src.ops.numpy import not_equal as not_equal
136
138
  from keras.src.ops.numpy import ones as ones
@@ -139,6 +141,7 @@ from keras.src.ops.numpy import outer as outer
139
141
  from keras.src.ops.numpy import pad as pad
140
142
  from keras.src.ops.numpy import power as power
141
143
  from keras.src.ops.numpy import prod as prod
144
+ from keras.src.ops.numpy import ptp as ptp
142
145
  from keras.src.ops.numpy import quantile as quantile
143
146
  from keras.src.ops.numpy import ravel as ravel
144
147
  from keras.src.ops.numpy import real as real
@@ -7,6 +7,7 @@ since your modifications would be overwritten.
7
7
  from keras.src.quantizers import deserialize as deserialize
8
8
  from keras.src.quantizers import get as get
9
9
  from keras.src.quantizers import serialize as serialize
10
+ from keras.src.quantizers.awq_config import AWQConfig as AWQConfig
10
11
  from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
11
12
  from keras.src.quantizers.quantization_config import (
12
13
  Float8QuantizationConfig as Float8QuantizationConfig,
@@ -1471,25 +1471,42 @@ def _can_use_flash_attention(query, key, value, bias, raise_error=False):
1471
1471
  # Only support at least Ampere
1472
1472
  if not check_compute_capability("8.0"):
1473
1473
  raise RuntimeError("Require at least Ampere arch to run")
1474
- # Check inputs layout
1474
+
1475
+ # Inspect inputs of `check_layout`
1475
1476
  check_layout_params = list(
1476
1477
  inspect.signature(check_layout).parameters.keys()
1477
1478
  )
1478
1479
  for known_param in ("query", "key", "value", "bias", "layout"):
1479
1480
  check_layout_params.remove(known_param)
1480
1481
  # Defaults to `None` when not specified.
1481
- kwargs = {key: None for key in check_layout_params}
1482
+ check_layout_kwargs = {key: None for key in check_layout_params}
1482
1483
  check_layout(
1483
- query, key, value, bias, layout=_normalize_layout("BTNH"), **kwargs
1484
- )
1485
- check_is_flash_attention(
1486
1484
  query,
1487
1485
  key,
1488
- _normalize_layout("BTNH"),
1489
- cudnn_version,
1490
- bias is not None,
1491
- is_training=False,
1486
+ value,
1487
+ bias,
1488
+ layout=_normalize_layout("BTNH"),
1489
+ **check_layout_kwargs,
1492
1490
  )
1491
+
1492
+ # Inspect inputs of `check_is_flash_attention`
1493
+ check_is_flash_attention_params = inspect.signature(
1494
+ check_is_flash_attention
1495
+ ).parameters
1496
+ check_is_flash_attention_kwargs = {
1497
+ "query": query,
1498
+ "key": key,
1499
+ "value": value,
1500
+ "layout": _normalize_layout("BTNH"),
1501
+ "cudnn_version": cudnn_version,
1502
+ "has_bias": bias is not None,
1503
+ "is_training": False,
1504
+ }
1505
+ # Remove unsupported arguments
1506
+ for param in list(check_is_flash_attention_kwargs.keys()):
1507
+ if param not in check_is_flash_attention_params:
1508
+ check_is_flash_attention_kwargs.pop(param)
1509
+ check_is_flash_attention(**check_is_flash_attention_kwargs)
1493
1510
  return True
1494
1511
  except:
1495
1512
  if raise_error:
@@ -1013,6 +1013,11 @@ def moveaxis(x, source, destination):
1013
1013
  return jnp.moveaxis(x, source=source, destination=destination)
1014
1014
 
1015
1015
 
1016
+ def nansum(x, axis=None, keepdims=False):
1017
+ x = convert_to_tensor(x)
1018
+ return jnp.nansum(x, axis=axis, keepdims=keepdims)
1019
+
1020
+
1016
1021
  def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
1017
1022
  x = convert_to_tensor(x)
1018
1023
  return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
@@ -1063,6 +1068,11 @@ def prod(x, axis=None, keepdims=False, dtype=None):
1063
1068
  return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)
1064
1069
 
1065
1070
 
1071
+ def ptp(x, axis=None, keepdims=False):
1072
+ x = convert_to_tensor(x)
1073
+ return jnp.ptp(x, axis=axis, keepdims=keepdims)
1074
+
1075
+
1066
1076
  def quantile(x, q, axis=None, method="linear", keepdims=False):
1067
1077
  x = convert_to_tensor(x)
1068
1078
  q = convert_to_tensor(q)
@@ -1351,6 +1361,12 @@ def negative(x):
1351
1361
  return jnp.negative(x)
1352
1362
 
1353
1363
 
1364
+ def nextafter(x1, x2):
1365
+ x1 = convert_to_tensor(x1)
1366
+ x2 = convert_to_tensor(x2)
1367
+ return jnp.nextafter(x1, x2)
1368
+
1369
+
1354
1370
  @sparse.elementwise_unary(linear=False)
1355
1371
  def square(x):
1356
1372
  x = convert_to_tensor(x)
@@ -960,6 +960,17 @@ def moveaxis(x, source, destination):
960
960
  return np.moveaxis(x, source=source, destination=destination)
961
961
 
962
962
 
963
+ def nansum(x, axis=None, keepdims=False):
964
+ axis = standardize_axis_for_numpy(axis)
965
+ dtype = standardize_dtype(x.dtype)
966
+
967
+ if dtype in ("bool", "int8", "int16"):
968
+ dtype = "int32"
969
+ elif dtype in ("uint8", "uint16"):
970
+ dtype = "uint32"
971
+ return np.nansum(x, axis=axis, keepdims=keepdims).astype(dtype)
972
+
973
+
963
974
  def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
964
975
  return np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
965
976
 
@@ -1018,6 +1029,10 @@ def prod(x, axis=None, keepdims=False, dtype=None):
1018
1029
  return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)
1019
1030
 
1020
1031
 
1032
+ def ptp(x, axis=None, keepdims=False):
1033
+ return np.ptp(x, axis=axis, keepdims=keepdims)
1034
+
1035
+
1021
1036
  def quantile(x, q, axis=None, method="linear", keepdims=False):
1022
1037
  axis = standardize_axis_for_numpy(axis)
1023
1038
  x = convert_to_tensor(x)
@@ -1335,6 +1350,14 @@ def negative(x):
1335
1350
  return np.negative(x)
1336
1351
 
1337
1352
 
1353
+ def nextafter(x1, x2):
1354
+ x1 = convert_to_tensor(x1)
1355
+ x2 = convert_to_tensor(x2)
1356
+ dtype = dtypes.result_type(x1.dtype, x2.dtype, float)
1357
+
1358
+ return np.nextafter(x1, x2).astype(dtype)
1359
+
1360
+
1338
1361
  def square(x):
1339
1362
  x = convert_to_tensor(x)
1340
1363
  if standardize_dtype(x.dtype) == "bool":