keras-nightly 3.12.0.dev2025092403__py3-none-any.whl → 3.14.0.dev2026010104__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/layers/__init__.py +21 -0
  7. keras/_tf_keras/keras/ops/__init__.py +13 -0
  8. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  9. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  11. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  12. keras/_tf_keras/keras/quantizers/__init__.py +12 -0
  13. keras/callbacks/__init__.py +3 -0
  14. keras/distillation/__init__.py +16 -0
  15. keras/distribution/__init__.py +3 -0
  16. keras/layers/__init__.py +21 -0
  17. keras/ops/__init__.py +13 -0
  18. keras/ops/image/__init__.py +1 -0
  19. keras/ops/linalg/__init__.py +1 -0
  20. keras/ops/nn/__init__.py +3 -0
  21. keras/ops/numpy/__init__.py +9 -0
  22. keras/quantizers/__init__.py +12 -0
  23. keras/src/applications/imagenet_utils.py +4 -1
  24. keras/src/backend/common/backend_utils.py +30 -6
  25. keras/src/backend/common/dtypes.py +1 -1
  26. keras/src/backend/common/name_scope.py +2 -1
  27. keras/src/backend/common/variables.py +33 -16
  28. keras/src/backend/jax/core.py +92 -3
  29. keras/src/backend/jax/distribution_lib.py +16 -2
  30. keras/src/backend/jax/linalg.py +4 -0
  31. keras/src/backend/jax/nn.py +485 -20
  32. keras/src/backend/jax/numpy.py +92 -23
  33. keras/src/backend/jax/optimizer.py +3 -2
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +313 -2
  37. keras/src/backend/numpy/numpy.py +76 -7
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +1030 -185
  43. keras/src/backend/openvino/random.py +7 -14
  44. keras/src/backend/tensorflow/layer.py +43 -9
  45. keras/src/backend/tensorflow/linalg.py +24 -0
  46. keras/src/backend/tensorflow/nn.py +545 -1
  47. keras/src/backend/tensorflow/numpy.py +264 -54
  48. keras/src/backend/torch/core.py +3 -1
  49. keras/src/backend/torch/linalg.py +4 -0
  50. keras/src/backend/torch/nn.py +125 -0
  51. keras/src/backend/torch/numpy.py +84 -8
  52. keras/src/callbacks/__init__.py +1 -0
  53. keras/src/callbacks/callback_list.py +45 -11
  54. keras/src/callbacks/model_checkpoint.py +5 -0
  55. keras/src/callbacks/orbax_checkpoint.py +299 -0
  56. keras/src/callbacks/terminate_on_nan.py +54 -5
  57. keras/src/datasets/cifar10.py +5 -0
  58. keras/src/distillation/__init__.py +1 -0
  59. keras/src/distillation/distillation_loss.py +390 -0
  60. keras/src/distillation/distiller.py +598 -0
  61. keras/src/distribution/distribution_lib.py +14 -0
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/attention.py +1 -1
  70. keras/src/layers/attention/multi_head_attention.py +4 -1
  71. keras/src/layers/core/dense.py +191 -172
  72. keras/src/layers/core/einsum_dense.py +235 -186
  73. keras/src/layers/core/embedding.py +83 -93
  74. keras/src/layers/core/input_layer.py +1 -0
  75. keras/src/layers/core/reversible_embedding.py +390 -0
  76. keras/src/layers/input_spec.py +17 -17
  77. keras/src/layers/layer.py +40 -15
  78. keras/src/layers/merging/dot.py +4 -1
  79. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  80. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  81. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  82. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  83. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  84. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  85. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  86. keras/src/layers/preprocessing/discretization.py +6 -5
  87. keras/src/layers/preprocessing/index_lookup.py +19 -1
  88. keras/src/layers/preprocessing/normalization.py +16 -1
  89. keras/src/layers/regularization/dropout.py +43 -1
  90. keras/src/layers/rnn/gru.py +1 -1
  91. keras/src/layers/rnn/lstm.py +2 -2
  92. keras/src/layers/rnn/rnn.py +19 -0
  93. keras/src/layers/rnn/simple_rnn.py +1 -1
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/metrics/confusion_metrics.py +7 -6
  96. keras/src/models/cloning.py +4 -0
  97. keras/src/models/functional.py +11 -3
  98. keras/src/models/model.py +156 -27
  99. keras/src/ops/image.py +184 -3
  100. keras/src/ops/linalg.py +93 -0
  101. keras/src/ops/nn.py +268 -2
  102. keras/src/ops/numpy.py +541 -43
  103. keras/src/optimizers/adafactor.py +29 -10
  104. keras/src/optimizers/base_optimizer.py +22 -3
  105. keras/src/optimizers/loss_scale_optimizer.py +51 -18
  106. keras/src/optimizers/muon.py +65 -31
  107. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  108. keras/src/quantizers/__init__.py +12 -1
  109. keras/src/quantizers/gptq.py +8 -6
  110. keras/src/quantizers/gptq_config.py +36 -1
  111. keras/src/quantizers/gptq_core.py +150 -78
  112. keras/src/quantizers/quantization_config.py +232 -0
  113. keras/src/quantizers/quantizers.py +114 -38
  114. keras/src/quantizers/utils.py +23 -0
  115. keras/src/random/seed_generator.py +4 -2
  116. keras/src/saving/file_editor.py +81 -6
  117. keras/src/saving/saving_lib.py +1 -1
  118. keras/src/testing/__init__.py +1 -0
  119. keras/src/testing/test_case.py +45 -5
  120. keras/src/trainers/compile_utils.py +14 -5
  121. keras/src/utils/backend_utils.py +31 -4
  122. keras/src/utils/dataset_utils.py +234 -35
  123. keras/src/utils/file_utils.py +49 -11
  124. keras/src/utils/image_utils.py +14 -2
  125. keras/src/utils/jax_layer.py +187 -36
  126. keras/src/utils/module_utils.py +18 -0
  127. keras/src/utils/progbar.py +10 -12
  128. keras/src/utils/rng_utils.py +9 -1
  129. keras/src/version.py +1 -1
  130. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/METADATA +16 -6
  131. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/RECORD +133 -116
  132. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/WHEEL +0 -0
  133. {keras_nightly-3.12.0.dev2025092403.dist-info → keras_nightly-3.14.0.dev2026010104.dist-info}/top_level.txt +0 -0
keras/src/ops/numpy.py CHANGED
@@ -301,33 +301,6 @@ def all(x, axis=None, keepdims=False):
301
301
  return backend.numpy.all(x, axis=axis, keepdims=keepdims)
302
302
 
303
303
 
304
- class Any(Operation):
305
- def __init__(self, axis=None, keepdims=False, *, name=None):
306
- super().__init__(name=name)
307
- if isinstance(axis, int):
308
- self.axis = [axis]
309
- else:
310
- self.axis = axis
311
- self.keepdims = keepdims
312
-
313
- def call(self, x):
314
- return backend.numpy.any(
315
- x,
316
- axis=self.axis,
317
- keepdims=self.keepdims,
318
- )
319
-
320
- def compute_output_spec(self, x):
321
- return KerasTensor(
322
- reduce_shape(
323
- x.shape,
324
- axis=self.axis,
325
- keepdims=self.keepdims,
326
- ),
327
- dtype="bool",
328
- )
329
-
330
-
331
304
  class Angle(Operation):
332
305
  def call(self, x):
333
306
  return backend.numpy.angle(x)
@@ -363,6 +336,33 @@ def angle(x):
363
336
  return backend.numpy.angle(x)
364
337
 
365
338
 
339
+ class Any(Operation):
340
+ def __init__(self, axis=None, keepdims=False, *, name=None):
341
+ super().__init__(name=name)
342
+ if isinstance(axis, int):
343
+ self.axis = [axis]
344
+ else:
345
+ self.axis = axis
346
+ self.keepdims = keepdims
347
+
348
+ def call(self, x):
349
+ return backend.numpy.any(
350
+ x,
351
+ axis=self.axis,
352
+ keepdims=self.keepdims,
353
+ )
354
+
355
+ def compute_output_spec(self, x):
356
+ return KerasTensor(
357
+ reduce_shape(
358
+ x.shape,
359
+ axis=self.axis,
360
+ keepdims=self.keepdims,
361
+ ),
362
+ dtype="bool",
363
+ )
364
+
365
+
366
366
  @keras_export(["keras.ops.any", "keras.ops.numpy.any"])
367
367
  def any(x, axis=None, keepdims=False):
368
368
  """Test whether any array element along a given axis evaluates to `True`.
@@ -595,27 +595,28 @@ class Arange(Operation):
595
595
  super().__init__(name=name)
596
596
  self.dtype = None if dtype is None else backend.standardize_dtype(dtype)
597
597
 
598
- def call(self, start, stop=None, step=1):
598
+ def call(self, start, stop=None, step=None):
599
599
  return backend.numpy.arange(start, stop, step=step, dtype=self.dtype)
600
600
 
601
- def compute_output_spec(self, start, stop=None, step=1):
601
+ def compute_output_spec(self, start, stop=None, step=None):
602
602
  if stop is None:
603
603
  start, stop = 0, start
604
+ if step is None:
605
+ step = 1
604
606
  output_shape = [int(np.ceil((stop - start) / step))]
605
607
  dtype = self.dtype
606
608
  if dtype is None:
607
- dtypes_to_resolve = [
608
- getattr(start, "dtype", type(start)),
609
- getattr(step, "dtype", type(step)),
610
- ]
609
+ dtypes_to_resolve = [getattr(start, "dtype", type(start))]
611
610
  if stop is not None:
612
611
  dtypes_to_resolve.append(getattr(stop, "dtype", type(stop)))
612
+ if step is not None:
613
+ dtypes_to_resolve.append(getattr(step, "dtype", type(step)))
613
614
  dtype = dtypes.result_type(*dtypes_to_resolve)
614
615
  return KerasTensor(output_shape, dtype=dtype)
615
616
 
616
617
 
617
618
  @keras_export(["keras.ops.arange", "keras.ops.numpy.arange"])
618
- def arange(start, stop=None, step=1, dtype=None):
619
+ def arange(start, stop=None, step=None, dtype=None):
619
620
  """Return evenly spaced values within a given interval.
620
621
 
621
622
  `arange` can be called with a varying number of positional arguments:
@@ -923,6 +924,11 @@ def arctanh(x):
923
924
 
924
925
  Returns:
925
926
  Output tensor of same shape as `x`.
927
+
928
+ Example:
929
+ >>> x = keras.ops.convert_to_tensor([0, -0.5])
930
+ >>> keras.ops.arctanh(x)
931
+ array([ 0. , -0.54930615], dtype=float32)
926
932
  """
927
933
  if any_symbolic_tensors((x,)):
928
934
  return Arctanh().symbolic_call(x)
@@ -1123,6 +1129,68 @@ def array(x, dtype=None):
1123
1129
  return backend.numpy.array(x, dtype=dtype)
1124
1130
 
1125
1131
 
1132
+ class View(Operation):
1133
+ def __init__(self, dtype=None, *, name=None):
1134
+ super().__init__(name=name)
1135
+ self.dtype = None if dtype is None else backend.standardize_dtype(dtype)
1136
+
1137
+ def call(self, x):
1138
+ return backend.numpy.view(x, dtype=self.dtype)
1139
+
1140
+ def compute_output_spec(self, x):
1141
+ old_dtype = backend.standardize_dtype(x.dtype)
1142
+ new_dtype = backend.standardize_dtype(
1143
+ self.dtype if self.dtype else x.dtype
1144
+ )
1145
+
1146
+ old_itemsize = np.dtype(old_dtype).itemsize
1147
+ new_itemsize = np.dtype(new_dtype).itemsize
1148
+
1149
+ if old_itemsize == new_itemsize:
1150
+ return KerasTensor(x.shape, dtype=new_dtype)
1151
+
1152
+ if not x.shape:
1153
+ raise ValueError(
1154
+ "Cannot view a scalar as a different dtype if item sizes "
1155
+ "are different."
1156
+ )
1157
+
1158
+ output_shape = list(x.shape)
1159
+ if output_shape[-1] is not None:
1160
+ if (output_shape[-1] * old_itemsize) % new_itemsize != 0:
1161
+ raise ValueError(
1162
+ f"Cannot view array of shape {x.shape} and dtype {x.dtype} "
1163
+ f"as dtype {new_dtype} because the total number of bytes "
1164
+ "is not divisible by the new itemsize."
1165
+ )
1166
+ output_shape[-1] = output_shape[-1] * old_itemsize // new_itemsize
1167
+ return KerasTensor(tuple(output_shape), dtype=new_dtype)
1168
+
1169
+
1170
+ @keras_export(["keras.ops.view", "keras.ops.numpy.view"])
1171
+ def view(x, dtype=None):
1172
+ """Create a new bitwise view of the same data with the specified dtype.
1173
+
1174
+ Args:
1175
+ x: Input tensor.
1176
+ dtype: Data-type descriptor of the returned view,
1177
+ e.g., float32 or int16.
1178
+
1179
+ Returns:
1180
+ View of a tensor with data type dtype.
1181
+
1182
+ Examples:
1183
+ >>> x = keras.ops.array([1, 2, 3])
1184
+ >>> x
1185
+ array([1, 2, 3], dtype=int32)
1186
+ >>> keras.ops.view(x, dtype="float32")
1187
+ array([1.0e-45, 3.0e-45, 4.0e-45], dtype=float32)
1188
+ """
1189
+ if any_symbolic_tensors((x,)):
1190
+ return View(dtype=dtype).symbolic_call(x)
1191
+ return backend.numpy.view(x, dtype=dtype)
1192
+
1193
+
1126
1194
  class Average(Operation):
1127
1195
  def __init__(self, axis=None, *, name=None):
1128
1196
  super().__init__(name=name)
@@ -3051,6 +3119,48 @@ def empty(shape, dtype=None):
3051
3119
  return backend.numpy.empty(shape, dtype=dtype)
3052
3120
 
3053
3121
 
3122
+ class EmptyLike(Operation):
3123
+ def __init__(self, dtype=None, *, name=None):
3124
+ super().__init__(name=name)
3125
+ self.dtype = None if dtype is None else backend.standardize_dtype(dtype)
3126
+
3127
+ def call(self, x):
3128
+ return backend.numpy.empty_like(x, dtype=self.dtype)
3129
+
3130
+ def compute_output_spec(self, x):
3131
+ dtype = (
3132
+ backend.standardize_dtype(x.dtype)
3133
+ if self.dtype is None
3134
+ else self.dtype
3135
+ )
3136
+ return KerasTensor(x.shape, dtype=dtype)
3137
+
3138
+
3139
+ @keras_export(["keras.ops.empty_like", "keras.ops.numpy.empty_like"])
3140
+ def empty_like(x, dtype=None):
3141
+ """Return a new uninitialized tensor with the same shape and dtype as `x`.
3142
+
3143
+ Args:
3144
+ x: Input tensor to mimic shape and dtype.
3145
+ dtype: Optional data type. If None, uses `x.dtype`.
3146
+
3147
+ Returns:
3148
+ A tensor with the same shape and dtype as `x`, with arbitrary contents.
3149
+
3150
+ Example:
3151
+ >>> from keras import ops
3152
+ >>> x = ops.ones((2, 3), dtype="float32")
3153
+ >>> y = ops.empty_like(x)
3154
+ >>> y.shape
3155
+ (2, 3)
3156
+ >>> y.dtype
3157
+ dtype('float32')
3158
+ """
3159
+ if any_symbolic_tensors((x,)):
3160
+ return EmptyLike(dtype=dtype).symbolic_call(x)
3161
+ return backend.numpy.empty_like(x, dtype=dtype)
3162
+
3163
+
3054
3164
  class Equal(Operation):
3055
3165
  def call(self, x1, x2):
3056
3166
  return backend.numpy.equal(x1, x2)
@@ -3845,6 +3955,35 @@ def isposinf(x):
3845
3955
  return backend.numpy.isposinf(x)
3846
3956
 
3847
3957
 
3958
+ class Isreal(Operation):
3959
+ def call(self, x):
3960
+ return backend.numpy.isreal(x)
3961
+
3962
+ def compute_output_spec(self, x):
3963
+ return KerasTensor(x.shape, dtype="bool")
3964
+
3965
+
3966
+ @keras_export(["keras.ops.isreal", "keras.ops.numpy.isreal"])
3967
+ def isreal(x):
3968
+ """Test element-wise for real numbers.
3969
+
3970
+ Args:
3971
+ x: Input tensor.
3972
+
3973
+ Returns:
3974
+ Output boolean tensor.
3975
+
3976
+ Example:
3977
+ >>> from keras import ops
3978
+ >>> x = ops.array([1+1j, 1+0j, 4.5, 3, 2, 2j], dtype="complex64")
3979
+ >>> ops.isreal(x)
3980
+ array([False, True, True, True, True, False])
3981
+ """
3982
+ if any_symbolic_tensors((x,)):
3983
+ return Isreal().symbolic_call(x)
3984
+ return backend.numpy.isreal(x)
3985
+
3986
+
3848
3987
  class Kron(Operation):
3849
3988
  def call(self, x1, x2):
3850
3989
  return backend.numpy.kron(x1, x2)
@@ -3925,6 +4064,46 @@ def lcm(x1, x2):
3925
4064
  return backend.numpy.lcm(x1, x2)
3926
4065
 
3927
4066
 
4067
+ class Ldexp(Operation):
4068
+ def call(self, x1, x2):
4069
+ return backend.numpy.ldexp(x1, x2)
4070
+
4071
+ def compute_output_spec(self, x1, x2):
4072
+ x1_shape = getattr(x1, "shape", [])
4073
+ x2_shape = getattr(x2, "shape", [])
4074
+ output_shape = broadcast_shapes(x1_shape, x2_shape)
4075
+
4076
+ x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1)))
4077
+ x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2)))
4078
+ dtype = dtypes.result_type(x1_type, x2_type, float)
4079
+ return KerasTensor(output_shape, dtype=dtype)
4080
+
4081
+
4082
+ @keras_export(["keras.ops.ldexp", "keras.ops.numpy.ldexp"])
4083
+ def ldexp(x1, x2):
4084
+ """Multiply `x1` by 2 raised to the power of `x2`, element-wise.
4085
+
4086
+ This function computes:
4087
+ ldexp(x1, x2) = x1 * 2**x2
4088
+
4089
+ Args:
4090
+ x1: Float input tensor.
4091
+ x2: Integer exponent tensor.
4092
+
4093
+ Returns:
4094
+ Output tensor
4095
+
4096
+ Example:
4097
+ >>> x1 = keras.ops.convert_to_tensor([0.75, 1.5])
4098
+ >>> x2 = keras.ops.convert_to_tensor([1, 2])
4099
+ >>> keras.ops.ldexp(x1, x2)
4100
+ array([1.5, 6. ], dtype=float32)
4101
+ """
4102
+ if any_symbolic_tensors((x1, x2)):
4103
+ return Ldexp().symbolic_call(x1, x2)
4104
+ return backend.numpy.ldexp(x1, x2)
4105
+
4106
+
3928
4107
  class Less(Operation):
3929
4108
  def call(self, x1, x2):
3930
4109
  return backend.numpy.less(x1, x2)
@@ -4240,6 +4419,47 @@ def logaddexp(x1, x2):
4240
4419
  return backend.numpy.logaddexp(x1, x2)
4241
4420
 
4242
4421
 
4422
+ class Logaddexp2(Operation):
4423
+ def call(self, x1, x2):
4424
+ return backend.numpy.logaddexp2(x1, x2)
4425
+
4426
+ def compute_output_spec(self, x1, x2):
4427
+ x1_shape = getattr(x1, "shape", [])
4428
+ x2_shape = getattr(x2, "shape", [])
4429
+ output_shape = broadcast_shapes(x1_shape, x2_shape)
4430
+ dtype = dtypes.result_type(
4431
+ getattr(x1, "dtype", type(x1)),
4432
+ getattr(x2, "dtype", type(x2)),
4433
+ float,
4434
+ )
4435
+ return KerasTensor(output_shape, dtype=dtype)
4436
+
4437
+
4438
+ @keras_export(["keras.ops.logaddexp2", "keras.ops.numpy.logaddexp2"])
4439
+ def logaddexp2(x1, x2):
4440
+ """Base-2 logarithm of the sum of exponentiations of the inputs.
4441
+
4442
+ Calculates `log2(2**x1 + 2**x2)`.
4443
+
4444
+ Args:
4445
+ x1: Input tensor.
4446
+ x2: Input tensor.
4447
+
4448
+ Returns:
4449
+ Output tensor, element-wise log base 2 of the sum of 2**x1 and 2**x2.
4450
+
4451
+ Example:
4452
+ >>> from keras import ops
4453
+ >>> x1 = ops.array([1, 2, 3])
4454
+ >>> x2 = ops.array([1, 2, 3])
4455
+ >>> ops.logaddexp2(x1, x2)
4456
+ array([2., 3., 4.], dtype=float32)
4457
+ """
4458
+ if any_symbolic_tensors((x1, x2)):
4459
+ return Logaddexp2().symbolic_call(x1, x2)
4460
+ return backend.numpy.logaddexp2(x1, x2)
4461
+
4462
+
4243
4463
  class LogicalAnd(Operation):
4244
4464
  def call(self, x1, x2):
4245
4465
  return backend.numpy.logical_and(x1, x2)
@@ -5376,10 +5596,10 @@ def unravel_index(indices, shape):
5376
5596
  Tuple of arrays for each dimension with unraveled indices.
5377
5597
 
5378
5598
  Example:
5379
- >>> indices = 5
5380
- >>> shape = (3, 3)
5381
- >>> unravel_index(indices, shape)
5382
- (1, 2) # 5 is at row 1, column 2 in a 3x3 array
5599
+ >>> indices = 5
5600
+ >>> shape = (3, 3)
5601
+ >>> unravel_index(indices, shape)
5602
+ (1, 2) # 5 is at row 1, column 2 in a 3x3 array
5383
5603
  """
5384
5604
  if any_symbolic_tensors((indices,)):
5385
5605
  return UnravelIndex(shape).symbolic_call(indices)
@@ -6236,6 +6456,9 @@ class Tile(Operation):
6236
6456
  repeats = self.repeats
6237
6457
  if isinstance(repeats, int):
6238
6458
  repeats = [repeats]
6459
+ else:
6460
+ repeats = list(repeats)
6461
+
6239
6462
  if len(x_shape) > len(repeats):
6240
6463
  repeats = [1] * (len(x_shape) - len(repeats)) + repeats
6241
6464
  else:
@@ -6243,10 +6466,10 @@ class Tile(Operation):
6243
6466
 
6244
6467
  output_shape = []
6245
6468
  for x_size, repeat in zip(x_shape, repeats):
6246
- if x_size is None:
6247
- output_shape.append(None)
6248
- else:
6469
+ if isinstance(x_size, int):
6249
6470
  output_shape.append(x_size * repeat)
6471
+ else:
6472
+ output_shape.append(None)
6250
6473
  return KerasTensor(output_shape, dtype=x.dtype)
6251
6474
 
6252
6475
 
@@ -6294,8 +6517,13 @@ class Trace(Operation):
6294
6517
  x_shape[self.axis2] = -1
6295
6518
  output_shape = list(filter((-1).__ne__, x_shape))
6296
6519
  output_dtype = backend.standardize_dtype(x.dtype)
6297
- if output_dtype not in ("int64", "uint32", "uint64"):
6298
- output_dtype = dtypes.result_type(output_dtype, "int32")
6520
+ if output_dtype in ("bool", "int8", "int16"):
6521
+ output_dtype = "int32"
6522
+ elif output_dtype in ("uint8", "uint16"):
6523
+ output_dtype = "uint32"
6524
+ if output_dtype == "uint32" and backend.backend() == "torch":
6525
+ # Torch backend doesn't support uint32 dtype.
6526
+ output_dtype = "int32"
6299
6527
  return KerasTensor(output_shape, dtype=output_dtype)
6300
6528
 
6301
6529
 
@@ -6876,6 +7104,49 @@ def negative(x):
6876
7104
  return backend.numpy.negative(x)
6877
7105
 
6878
7106
 
7107
+ class Nextafter(Operation):
7108
+ def call(self, x1, x2):
7109
+ return backend.numpy.nextafter(x1, x2)
7110
+
7111
+ def compute_output_spec(self, x1, x2):
7112
+ x1_shape = getattr(x1, "shape", [])
7113
+ x2_shape = getattr(x2, "shape", [])
7114
+ output_shape = broadcast_shapes(x1_shape, x2_shape)
7115
+
7116
+ x1_type = backend.standardize_dtype(getattr(x1, "dtype", type(x1)))
7117
+ x2_type = backend.standardize_dtype(getattr(x2, "dtype", type(x2)))
7118
+ dtype = dtypes.result_type(x1_type, x2_type, float)
7119
+ return KerasTensor(output_shape, dtype=dtype)
7120
+
7121
+
7122
+ @keras_export(["keras.ops.nextafter", "keras.ops.numpy.nextafter"])
7123
+ def nextafter(x1, x2):
7124
+ """
7125
+ Return the next representable floating-point value after `x1` towards `x2`.
7126
+
7127
+ This function computes the next floating-point value
7128
+ following `x1` in the direction of `x2`, element-wise.
7129
+
7130
+ Args:
7131
+ x1: Input tensor whose values will be moved to the next
7132
+ representable floating-point value.
7133
+ x2: Input tensor indicating the direction toward which
7134
+ `x1` is moved.
7135
+
7136
+ Returns:
7137
+ Output tensor
7138
+
7139
+ Example:
7140
+ >>> x1 = keras.ops.convert_to_tensor([1.0, 1.0])
7141
+ >>> x2 = keras.ops.convert_to_tensor([2.0, 0.0])
7142
+ >>> keras.ops.nextafter(x1, x2)
7143
+ array([1.0000001, 0.99999994], dtype=float32)
7144
+ """
7145
+ if any_symbolic_tensors((x1, x2)):
7146
+ return Nextafter().symbolic_call(x1, x2)
7147
+ return backend.numpy.nextafter(x1, x2)
7148
+
7149
+
6879
7150
  class Square(Operation):
6880
7151
  def call(self, x):
6881
7152
  return backend.numpy.square(x)
@@ -7012,6 +7283,48 @@ def transpose(x, axes=None):
7012
7283
  return backend.numpy.transpose(x, axes=axes)
7013
7284
 
7014
7285
 
7286
+ class Trapezoid(Operation):
7287
+ def __init__(self, x=None, dx=1.0, axis=-1, *, name=None):
7288
+ super().__init__(name=name)
7289
+ self.x = x
7290
+ self.dx = dx
7291
+ self.axis = axis
7292
+
7293
+ def call(self, y):
7294
+ return backend.numpy.trapezoid(y, x=self.x, dx=self.dx, axis=self.axis)
7295
+
7296
+ def compute_output_spec(self, y):
7297
+ out_shape = list(y.shape)
7298
+ if self.axis is not None and len(out_shape) > 0:
7299
+ out_shape.pop(self.axis % len(out_shape))
7300
+ dtype = backend.result_type(getattr(y, "dtype", type(y)), float)
7301
+ return KerasTensor(tuple(out_shape), dtype=dtype)
7302
+
7303
+
7304
+ @keras_export(["keras.ops.trapezoid", "keras.ops.numpy.trapezoid"])
7305
+ def trapezoid(y, x=None, dx=1.0, axis=-1):
7306
+ """Integrate along the given axis using the composite trapezoidal rule.
7307
+
7308
+ Args:
7309
+ y: Input tensor.
7310
+ x: Optional tensor specifying sample points corresponding to `y`.
7311
+ If `None`, spacing is assumed to be `dx`.
7312
+ dx: Spacing between sample points when `x` is `None`.
7313
+ axis: Axis along which to integrate. Default is the last axis.
7314
+
7315
+ Returns:
7316
+ The approximate integral of `y` along the given axis.
7317
+
7318
+ Example:
7319
+ >>> y = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]])
7320
+ >>> keras.ops.trapezoid(y, axis=1)
7321
+ array([ 4., 10.], dtype=float32)
7322
+ """
7323
+ if any_symbolic_tensors((y,)):
7324
+ return Trapezoid(x=x, dx=dx, axis=axis).symbolic_call(y)
7325
+ return backend.numpy.trapezoid(y, x=x, dx=dx, axis=axis)
7326
+
7327
+
7015
7328
  class Mean(Operation):
7016
7329
  def __init__(self, axis=None, keepdims=False, *, name=None):
7017
7330
  super().__init__(name=name)
@@ -7057,6 +7370,77 @@ def mean(x, axis=None, keepdims=False):
7057
7370
  return backend.numpy.mean(x, axis=axis, keepdims=keepdims)
7058
7371
 
7059
7372
 
7373
+ class Vander(Operation):
7374
+ def __init__(self, N=None, increasing=False, *, name=None):
7375
+ super().__init__(name=name)
7376
+ self.N = N
7377
+ self.increasing = increasing
7378
+
7379
+ def call(self, x):
7380
+ return backend.numpy.vander(x, self.N, self.increasing)
7381
+
7382
+ def compute_output_spec(self, x):
7383
+ if self.N is None:
7384
+ N = x.shape[0]
7385
+ else:
7386
+ N = self.N
7387
+
7388
+ out_shape = x.shape + (N,)
7389
+ return KerasTensor(tuple(out_shape), dtype=x.dtype)
7390
+
7391
+
7392
+ @keras_export(["keras.ops.vander", "keras.ops.numpy.vander"])
7393
+ def vander(x, N=None, increasing=False):
7394
+ """Generate a Vandermonde matrix.
7395
+
7396
+ Args:
7397
+ x: 1D input tensor.
7398
+ N: Number of columns. If `None`, `N` = `len(x)`.
7399
+ increasing: Order of powers. If True, powers increase left to right.
7400
+
7401
+ Returns:
7402
+ Output tensor, Vandermonde matrix of shape `(len(x), N)`.
7403
+
7404
+ Example:
7405
+ >>> import numpy as np
7406
+ >>> import keras
7407
+ >>> x = np.array([1, 2, 3, 5])
7408
+ >>> keras.ops.vander(x)
7409
+ array([[ 1, 1, 1, 1],
7410
+ [ 8, 4, 2, 1],
7411
+ [ 27, 9, 3, 1],
7412
+ [125, 25, 5, 1]])
7413
+ """
7414
+
7415
+ if len(x.shape) != 1:
7416
+ raise ValueError(
7417
+ "Input tensor must be 1-dimensional. "
7418
+ f"Received: input.shape={x.shape}"
7419
+ )
7420
+
7421
+ if N is not None:
7422
+ if not isinstance(N, int):
7423
+ raise TypeError(
7424
+ f"Argument `N` must be of type `int`. "
7425
+ f"Received: N={N} of type {type(N)}"
7426
+ )
7427
+
7428
+ if N < 0:
7429
+ raise ValueError(
7430
+ f"Argument 'N' must be nonnegative. Received: N={N}"
7431
+ )
7432
+
7433
+ if not isinstance(increasing, bool):
7434
+ raise TypeError(
7435
+ f"Argument `increasing` must be of type `bool`. "
7436
+ f"Received: increasing={increasing} of type {type(increasing)}"
7437
+ )
7438
+
7439
+ if any_symbolic_tensors((x,)):
7440
+ return Vander(N=N, increasing=increasing).symbolic_call(x)
7441
+ return backend.numpy.vander(x, N=N, increasing=increasing)
7442
+
7443
+
7060
7444
  class Var(Operation):
7061
7445
  def __init__(self, axis=None, keepdims=False, *, name=None):
7062
7446
  super().__init__(name=name)
@@ -7186,6 +7570,19 @@ def eye(N, M=None, k=0, dtype=None):
7186
7570
  Returns:
7187
7571
  Tensor with ones on the k-th diagonal and zeros elsewhere.
7188
7572
  """
7573
+
7574
+ def is_floating_type(v):
7575
+ return (
7576
+ isinstance(v, float)
7577
+ or getattr(v, "dtype", None) in dtypes.FLOAT_TYPES
7578
+ )
7579
+
7580
+ if is_floating_type(N):
7581
+ raise TypeError("Argument `N` must be an integer or an integer tensor.")
7582
+ if is_floating_type(M):
7583
+ raise TypeError(
7584
+ "Argument `M` must be an integer, an integer tensor, or `None`."
7585
+ )
7189
7586
  return backend.numpy.eye(N, M=M, k=k, dtype=dtype)
7190
7587
 
7191
7588
 
@@ -7569,3 +7966,104 @@ def histogram(x, bins=10, range=None):
7569
7966
  f"Received: input.shape={x.shape}"
7570
7967
  )
7571
7968
  return backend.numpy.histogram(x, bins=bins, range=range)
7969
+
7970
+
7971
+ class ArraySplit(Operation):
7972
+ def __init__(self, indices_or_sections, axis=0, *, name=None):
7973
+ super().__init__(name=name)
7974
+
7975
+ self.indices_or_sections = indices_or_sections
7976
+ self.axis = axis
7977
+
7978
+ def call(self, x):
7979
+ return backend.numpy.array_split(
7980
+ x,
7981
+ indices_or_sections=self.indices_or_sections,
7982
+ axis=self.axis,
7983
+ )
7984
+
7985
+ def compute_output_spec(self, x):
7986
+ num_splits = self.indices_or_sections
7987
+
7988
+ axis = self.axis
7989
+ if axis < 0:
7990
+ axis += len(x.shape)
7991
+
7992
+ total_size = x.shape[axis]
7993
+
7994
+ if total_size is None:
7995
+ output_specs = []
7996
+ base_shape = list(x.shape)
7997
+ base_shape[axis] = None
7998
+ for _ in range(num_splits):
7999
+ output_specs.append(
8000
+ KerasTensor(shape=tuple(base_shape), dtype=x.dtype)
8001
+ )
8002
+ return tuple(output_specs)
8003
+
8004
+ split_size = total_size // num_splits
8005
+ remainder = total_size % num_splits
8006
+
8007
+ output_specs = []
8008
+ base_shape = list(x.shape)
8009
+ for i in range(num_splits):
8010
+ size = split_size + (1 if i < remainder else 0)
8011
+ shape = base_shape.copy()
8012
+ shape[axis] = size
8013
+ output_specs.append(KerasTensor(shape=tuple(shape), dtype=x.dtype))
8014
+
8015
+ return list(output_specs)
8016
+
8017
+
8018
+ @keras_export(["keras.ops.array_split", "keras.ops.numpy.array_split"])
8019
+ def array_split(x, indices_or_sections, axis=0):
8020
+ """Splits an array into multiple sub-arrays (unevenly).
8021
+
8022
+ This is similar to `keras.ops.split`, but it allows for
8023
+ unequal splits. `indices_or_sections` must be an integer
8024
+ that indicates the total number of sub-arrays to create.
8025
+ If the tensor cannot be divided evenly, the first `remainder`
8026
+ splits will have size `quotient + 1`, and the rest will
8027
+ have size `quotient`.
8028
+
8029
+ Args:
8030
+ x: Input tensor.
8031
+ indices_or_sections: An integer indicating the number of
8032
+ sub-arrays to create.
8033
+ axis: The axis along which to split. Defaults to 0.
8034
+
8035
+ Returns:
8036
+ A list of sub-tensors.
8037
+
8038
+ Example:
8039
+ >>> x = keras.ops.arange(10)
8040
+ >>> keras.ops.array_split(x, 3)
8041
+ (array([0, 1, 2, 3], dtype=int32),
8042
+ array([4, 5, 6], dtype=int32),
8043
+ array([7, 8, 9], dtype=int32))
8044
+ """
8045
+ if not isinstance(indices_or_sections, int):
8046
+ raise TypeError(
8047
+ "Argument `indices_or_sections` must be of type `int`. "
8048
+ f"Received: indices_or_sections={indices_or_sections}"
8049
+ )
8050
+
8051
+ if indices_or_sections <= 0:
8052
+ raise ValueError(
8053
+ "Argument `indices_or_sections` must be a positive integer. "
8054
+ f"Received: indices_or_sections={indices_or_sections}"
8055
+ )
8056
+
8057
+ if not isinstance(axis, int):
8058
+ raise TypeError(
8059
+ f"Argument `axis` must be of type `int`. Received: {axis}"
8060
+ )
8061
+
8062
+ if any_symbolic_tensors((x,)):
8063
+ return ArraySplit(
8064
+ indices_or_sections=indices_or_sections, axis=axis
8065
+ ).symbolic_call(x)
8066
+
8067
+ return backend.numpy.array_split(
8068
+ x, indices_or_sections=indices_or_sections, axis=axis
8069
+ )