keras-nightly 3.12.0.dev2025100503__py3-none-any.whl → 3.14.0.dev2026011604__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136) hide show
  1. keras/__init__.py +1 -0
  2. keras/_tf_keras/keras/__init__.py +1 -0
  3. keras/_tf_keras/keras/callbacks/__init__.py +3 -0
  4. keras/_tf_keras/keras/distillation/__init__.py +16 -0
  5. keras/_tf_keras/keras/distribution/__init__.py +3 -0
  6. keras/_tf_keras/keras/dtype_policies/__init__.py +3 -0
  7. keras/_tf_keras/keras/layers/__init__.py +21 -0
  8. keras/_tf_keras/keras/ops/__init__.py +13 -0
  9. keras/_tf_keras/keras/ops/image/__init__.py +1 -0
  10. keras/_tf_keras/keras/ops/linalg/__init__.py +1 -0
  11. keras/_tf_keras/keras/ops/nn/__init__.py +3 -0
  12. keras/_tf_keras/keras/ops/numpy/__init__.py +9 -0
  13. keras/_tf_keras/keras/quantizers/__init__.py +13 -0
  14. keras/callbacks/__init__.py +3 -0
  15. keras/distillation/__init__.py +16 -0
  16. keras/distribution/__init__.py +3 -0
  17. keras/dtype_policies/__init__.py +3 -0
  18. keras/layers/__init__.py +21 -0
  19. keras/ops/__init__.py +13 -0
  20. keras/ops/image/__init__.py +1 -0
  21. keras/ops/linalg/__init__.py +1 -0
  22. keras/ops/nn/__init__.py +3 -0
  23. keras/ops/numpy/__init__.py +9 -0
  24. keras/quantizers/__init__.py +13 -0
  25. keras/src/applications/imagenet_utils.py +4 -1
  26. keras/src/backend/common/backend_utils.py +30 -6
  27. keras/src/backend/common/name_scope.py +2 -1
  28. keras/src/backend/common/variables.py +30 -15
  29. keras/src/backend/jax/core.py +92 -3
  30. keras/src/backend/jax/distribution_lib.py +16 -2
  31. keras/src/backend/jax/linalg.py +4 -0
  32. keras/src/backend/jax/nn.py +509 -29
  33. keras/src/backend/jax/numpy.py +59 -8
  34. keras/src/backend/jax/trainer.py +14 -2
  35. keras/src/backend/numpy/linalg.py +4 -0
  36. keras/src/backend/numpy/nn.py +311 -1
  37. keras/src/backend/numpy/numpy.py +65 -2
  38. keras/src/backend/openvino/__init__.py +1 -0
  39. keras/src/backend/openvino/core.py +2 -23
  40. keras/src/backend/openvino/linalg.py +4 -0
  41. keras/src/backend/openvino/nn.py +271 -20
  42. keras/src/backend/openvino/numpy.py +943 -189
  43. keras/src/backend/tensorflow/layer.py +43 -9
  44. keras/src/backend/tensorflow/linalg.py +24 -0
  45. keras/src/backend/tensorflow/nn.py +545 -1
  46. keras/src/backend/tensorflow/numpy.py +250 -50
  47. keras/src/backend/torch/core.py +3 -1
  48. keras/src/backend/torch/linalg.py +4 -0
  49. keras/src/backend/torch/nn.py +125 -0
  50. keras/src/backend/torch/numpy.py +80 -2
  51. keras/src/callbacks/__init__.py +1 -0
  52. keras/src/callbacks/model_checkpoint.py +5 -0
  53. keras/src/callbacks/orbax_checkpoint.py +332 -0
  54. keras/src/callbacks/terminate_on_nan.py +54 -5
  55. keras/src/datasets/cifar10.py +5 -0
  56. keras/src/distillation/__init__.py +1 -0
  57. keras/src/distillation/distillation_loss.py +390 -0
  58. keras/src/distillation/distiller.py +598 -0
  59. keras/src/distribution/distribution_lib.py +14 -0
  60. keras/src/dtype_policies/__init__.py +2 -0
  61. keras/src/dtype_policies/dtype_policy.py +90 -1
  62. keras/src/export/__init__.py +2 -0
  63. keras/src/export/export_utils.py +39 -2
  64. keras/src/export/litert.py +248 -0
  65. keras/src/export/openvino.py +1 -1
  66. keras/src/export/tf2onnx_lib.py +3 -0
  67. keras/src/layers/__init__.py +13 -0
  68. keras/src/layers/activations/softmax.py +9 -4
  69. keras/src/layers/attention/multi_head_attention.py +4 -1
  70. keras/src/layers/core/dense.py +241 -111
  71. keras/src/layers/core/einsum_dense.py +316 -131
  72. keras/src/layers/core/embedding.py +84 -94
  73. keras/src/layers/core/input_layer.py +1 -0
  74. keras/src/layers/core/reversible_embedding.py +399 -0
  75. keras/src/layers/input_spec.py +17 -17
  76. keras/src/layers/layer.py +45 -15
  77. keras/src/layers/merging/dot.py +4 -1
  78. keras/src/layers/pooling/adaptive_average_pooling1d.py +65 -0
  79. keras/src/layers/pooling/adaptive_average_pooling2d.py +62 -0
  80. keras/src/layers/pooling/adaptive_average_pooling3d.py +63 -0
  81. keras/src/layers/pooling/adaptive_max_pooling1d.py +65 -0
  82. keras/src/layers/pooling/adaptive_max_pooling2d.py +62 -0
  83. keras/src/layers/pooling/adaptive_max_pooling3d.py +63 -0
  84. keras/src/layers/pooling/base_adaptive_pooling.py +63 -0
  85. keras/src/layers/preprocessing/discretization.py +6 -5
  86. keras/src/layers/preprocessing/feature_space.py +8 -4
  87. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +2 -2
  88. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +3 -3
  89. keras/src/layers/preprocessing/image_preprocessing/resizing.py +10 -0
  90. keras/src/layers/preprocessing/index_lookup.py +19 -1
  91. keras/src/layers/preprocessing/normalization.py +14 -1
  92. keras/src/layers/regularization/dropout.py +43 -1
  93. keras/src/layers/rnn/rnn.py +19 -0
  94. keras/src/losses/loss.py +1 -1
  95. keras/src/losses/losses.py +24 -0
  96. keras/src/metrics/confusion_metrics.py +7 -6
  97. keras/src/models/cloning.py +4 -0
  98. keras/src/models/functional.py +11 -3
  99. keras/src/models/model.py +172 -34
  100. keras/src/ops/image.py +257 -20
  101. keras/src/ops/linalg.py +93 -0
  102. keras/src/ops/nn.py +258 -0
  103. keras/src/ops/numpy.py +569 -36
  104. keras/src/optimizers/muon.py +65 -31
  105. keras/src/optimizers/schedules/learning_rate_schedule.py +4 -3
  106. keras/src/quantizers/__init__.py +14 -1
  107. keras/src/quantizers/awq.py +361 -0
  108. keras/src/quantizers/awq_config.py +140 -0
  109. keras/src/quantizers/awq_core.py +217 -0
  110. keras/src/quantizers/gptq.py +2 -8
  111. keras/src/quantizers/gptq_config.py +36 -1
  112. keras/src/quantizers/gptq_core.py +65 -79
  113. keras/src/quantizers/quantization_config.py +246 -0
  114. keras/src/quantizers/quantizers.py +127 -61
  115. keras/src/quantizers/utils.py +23 -0
  116. keras/src/random/seed_generator.py +6 -4
  117. keras/src/saving/file_editor.py +81 -6
  118. keras/src/saving/orbax_util.py +26 -0
  119. keras/src/saving/saving_api.py +37 -14
  120. keras/src/saving/saving_lib.py +1 -1
  121. keras/src/testing/__init__.py +1 -0
  122. keras/src/testing/test_case.py +45 -5
  123. keras/src/utils/backend_utils.py +31 -4
  124. keras/src/utils/dataset_utils.py +234 -35
  125. keras/src/utils/file_utils.py +49 -11
  126. keras/src/utils/image_utils.py +14 -2
  127. keras/src/utils/jax_layer.py +244 -55
  128. keras/src/utils/module_utils.py +29 -0
  129. keras/src/utils/progbar.py +10 -2
  130. keras/src/utils/rng_utils.py +9 -1
  131. keras/src/utils/tracking.py +5 -5
  132. keras/src/version.py +1 -1
  133. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/METADATA +16 -6
  134. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/RECORD +136 -115
  135. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/WHEEL +0 -0
  136. {keras_nightly-3.12.0.dev2025100503.dist-info → keras_nightly-3.14.0.dev2026011604.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ import openvino.opset14 as ov_opset
2
2
  from openvino import Type
3
3
 
4
4
  from keras.src import backend
5
+ from keras.src.backend.openvino.core import OPENVINO_DTYPES
5
6
  from keras.src.backend.openvino.core import OpenVINOKerasTensor
6
7
  from keras.src.backend.openvino.core import get_ov_output
7
8
 
@@ -16,6 +17,23 @@ def relu6(x):
16
17
  return OpenVINOKerasTensor(ov_opset.clamp(x, 0.0, 6.0).output(0))
17
18
 
18
19
 
20
+ def celu(x, alpha=1.0):
21
+ x = get_ov_output(x)
22
+ const_zero = get_ov_output(0.0, x.get_element_type())
23
+ const_alpha = get_ov_output(alpha, x.get_element_type())
24
+ const_one = get_ov_output(1.0, x.get_element_type())
25
+ exp_x_div_alpha = ov_opset.exp(ov_opset.divide(x, const_alpha)).output(0)
26
+ negative_branch = ov_opset.multiply(
27
+ const_alpha, ov_opset.subtract(exp_x_div_alpha, const_one)
28
+ )
29
+
30
+ celu_x = ov_opset.add(
31
+ ov_opset.maximum(x, const_zero).output(0),
32
+ ov_opset.minimum(negative_branch, const_zero).output(0),
33
+ )
34
+ return OpenVINOKerasTensor(celu_x.output(0))
35
+
36
+
19
37
  def sigmoid(x):
20
38
  x = get_ov_output(x)
21
39
  return OpenVINOKerasTensor(ov_opset.sigmoid(x).output(0))
@@ -26,6 +44,39 @@ def tanh(x):
26
44
  return OpenVINOKerasTensor(ov_opset.tanh(x).output(0))
27
45
 
28
46
 
47
+ def tanh_shrink(x):
48
+ x = get_ov_output(x)
49
+ return OpenVINOKerasTensor(ov_opset.subtract(x, ov_opset.tanh(x)).output(0))
50
+
51
+
52
+ def hard_tanh(x):
53
+ x = get_ov_output(x)
54
+ return OpenVINOKerasTensor(ov_opset.clamp(x, -1.0, 1.0).output(0))
55
+
56
+
57
+ def soft_shrink(x, threshold=0.5):
58
+ x = get_ov_output(x)
59
+ et = x.get_element_type()
60
+ thr = get_ov_output(threshold, et)
61
+ zero = get_ov_output(0.0, et)
62
+ abs_x = ov_opset.abs(x)
63
+ sub = ov_opset.subtract(abs_x, thr)
64
+ shrunk = ov_opset.maximum(sub, zero)
65
+ sign = ov_opset.sign(x)
66
+ out = ov_opset.multiply(sign, shrunk)
67
+ return OpenVINOKerasTensor(out.output(0))
68
+
69
+
70
+ def hard_shrink(x, threshold=0.5):
71
+ x = get_ov_output(x)
72
+ et = x.get_element_type()
73
+ thr = get_ov_output(threshold, et)
74
+ zero = get_ov_output(0.0, et)
75
+ cond = ov_opset.greater(ov_opset.abs(x), thr)
76
+ out = ov_opset.select(cond, x, zero)
77
+ return OpenVINOKerasTensor(out.output(0))
78
+
79
+
29
80
  def softplus(x):
30
81
  x = get_ov_output(x)
31
82
  return OpenVINOKerasTensor(ov_opset.softplus(x).output(0))
@@ -38,14 +89,15 @@ def softsign(x):
38
89
 
39
90
  def silu(x):
40
91
  x = get_ov_output(x)
41
- return OpenVINOKerasTensor(
42
- ov_opset.multiply(x, ov_opset.sigmoid(x)).output(0)
43
- )
92
+ beta = get_ov_output(1.0, x.get_element_type())
93
+ return OpenVINOKerasTensor(ov_opset.swish(x, beta=beta).output(0))
44
94
 
45
95
 
46
96
  def log_sigmoid(x):
47
- raise NotImplementedError(
48
- "`log_sigmoid` is not supported with openvino backend"
97
+ x = get_ov_output(x)
98
+ neg_x = ov_opset.negative(x)
99
+ return OpenVINOKerasTensor(
100
+ ov_opset.negative(ov_opset.softplus(neg_x)).output(0)
49
101
  )
50
102
 
51
103
 
@@ -58,6 +110,17 @@ def leaky_relu(x, negative_slope=0.2):
58
110
  return OpenVINOKerasTensor(leaky_relu)
59
111
 
60
112
 
113
+ def sparse_sigmoid(x):
114
+ x = get_ov_output(x)
115
+ et = x.get_element_type()
116
+ one = get_ov_output(1.0, et)
117
+ neg_one = get_ov_output(-1.0, et)
118
+ half = get_ov_output(0.5, et)
119
+ y = ov_opset.minimum(ov_opset.maximum(x, neg_one), one)
120
+ out = ov_opset.multiply(half, ov_opset.add(y, one))
121
+ return OpenVINOKerasTensor(out.output(0))
122
+
123
+
61
124
  def hard_sigmoid(x):
62
125
  x = get_ov_output(x)
63
126
  alpha = get_ov_output(1.0 / 6.0, x.get_element_type())
@@ -121,6 +184,48 @@ def log_softmax(x, axis=-1):
121
184
  return OpenVINOKerasTensor(ov_opset.log_softmax(x, axis).output(0))
122
185
 
123
186
 
187
+ def squareplus(x, b=4):
188
+ x = get_ov_output(x)
189
+ et = x.get_element_type()
190
+ b = get_ov_output(b, et)
191
+ two = get_ov_output(2.0, et)
192
+ x_squared = ov_opset.multiply(x, x)
193
+ inside = ov_opset.add(x_squared, b)
194
+ root = ov_opset.sqrt(inside)
195
+ summed = ov_opset.add(x, root)
196
+ out = ov_opset.divide(summed, two)
197
+ return OpenVINOKerasTensor(out.output(0))
198
+
199
+
200
+ def sparse_plus(x):
201
+ x = get_ov_output(x)
202
+ et = x.get_element_type()
203
+ one = get_ov_output(1.0, et)
204
+ neg_one = get_ov_output(-1.0, et)
205
+ zero = get_ov_output(0.0, et)
206
+ quarter = get_ov_output(0.25, et)
207
+ x_plus_1 = ov_opset.add(x, one)
208
+ quad = ov_opset.multiply(quarter, ov_opset.multiply(x_plus_1, x_plus_1))
209
+ leq_than_neg_one = ov_opset.less_equal(x, neg_one)
210
+ less_than_one = ov_opset.less(x, one)
211
+ out = ov_opset.select(
212
+ leq_than_neg_one,
213
+ zero,
214
+ ov_opset.select(less_than_one, quad, x),
215
+ )
216
+ return OpenVINOKerasTensor(out.output(0))
217
+
218
+
219
+ def threshold(x, threshold, default_value):
220
+ x = get_ov_output(x)
221
+ et = x.get_element_type()
222
+ thr = get_ov_output(threshold, et)
223
+ dv = get_ov_output(default_value, et)
224
+ cond = ov_opset.greater(x, thr)
225
+ out = ov_opset.select(cond, x, dv)
226
+ return OpenVINOKerasTensor(out.output(0))
227
+
228
+
124
229
  def max_pool(
125
230
  inputs,
126
231
  pool_size,
@@ -128,8 +233,18 @@ def max_pool(
128
233
  padding="valid",
129
234
  data_format=None,
130
235
  ):
131
- raise NotImplementedError(
132
- "`max_pool` is not supported with openvino backend"
236
+ num_spatial_dims = (
237
+ get_ov_output(inputs).get_partial_shape().rank.get_length() - 2
238
+ )
239
+ kwargs = {"dilations": [1] * num_spatial_dims} # required for ov max_pool
240
+ return _pool(
241
+ inputs,
242
+ pool_size,
243
+ ov_opset.max_pool,
244
+ strides,
245
+ padding,
246
+ data_format,
247
+ **kwargs,
133
248
  )
134
249
 
135
250
 
@@ -140,11 +255,62 @@ def average_pool(
140
255
  padding="valid",
141
256
  data_format=None,
142
257
  ):
143
- raise NotImplementedError(
144
- "`average_pool` is not supported with openvino backend"
258
+ return _pool(
259
+ inputs,
260
+ pool_size,
261
+ ov_opset.avg_pool,
262
+ strides,
263
+ padding,
264
+ data_format,
265
+ exclude_pad=True,
145
266
  )
146
267
 
147
268
 
269
+ def adaptive_average_pool(inputs, output_size, data_format=None):
270
+ """Adaptive average pooling - OpenVINO backend not yet supported."""
271
+ raise NotImplementedError("Adaptive pooling not implemented for OpenVINO.")
272
+
273
+
274
+ def adaptive_max_pool(inputs, output_size, data_format=None):
275
+ """Adaptive max pooling - OpenVINO backend not yet supported."""
276
+ raise NotImplementedError("Adaptive pooling not implemented for OpenVINO.")
277
+
278
+
279
+ def _pool(
280
+ inputs,
281
+ pool_size,
282
+ pooling_func,
283
+ strides=None,
284
+ padding="valid",
285
+ data_format=None,
286
+ **kwargs,
287
+ ):
288
+ data_format = backend.standardize_data_format(data_format)
289
+ inputs = get_ov_output(inputs)
290
+
291
+ num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2
292
+ if isinstance(pool_size, int):
293
+ pool_size = [pool_size] * num_spatial_dims
294
+
295
+ if strides is None:
296
+ strides = pool_size
297
+
298
+ strides = _adjust_strides_dilation(strides, num_spatial_dims)
299
+ pad_mode, pads_begin, pads_end = _adjust_padding(padding)
300
+ inputs = _adjust_input(inputs, num_spatial_dims, data_format)
301
+ pool_kwargs = {
302
+ "kernel_shape": pool_size,
303
+ "strides": strides,
304
+ "auto_pad": pad_mode,
305
+ "pads_begin": pads_begin,
306
+ "pads_end": pads_end,
307
+ **kwargs,
308
+ }
309
+ pooled = pooling_func(inputs, **pool_kwargs).output(0)
310
+ adjusted_pooled = _adjust_outputs(pooled, num_spatial_dims, data_format)
311
+ return OpenVINOKerasTensor(adjusted_pooled)
312
+
313
+
148
314
  def _adjust_strides_dilation(
149
315
  x,
150
316
  num_spatial_dims,
@@ -374,15 +540,33 @@ def conv_transpose(
374
540
 
375
541
 
376
542
  def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
377
- raise NotImplementedError(
378
- "`one_hot` is not supported with openvino backend"
379
- )
543
+ if sparse:
544
+ raise ValueError("`sparse=True` is not supported with openvino backend")
545
+ x = get_ov_output(x)
546
+ if dtype is None:
547
+ dtype = backend.floatx()
548
+ ov_dtype = OPENVINO_DTYPES[dtype]
549
+ on_value = get_ov_output(1, ov_dtype)
550
+ off_value = get_ov_output(0, ov_dtype)
551
+ one_hot_encoded = ov_opset.one_hot(
552
+ x,
553
+ depth=num_classes,
554
+ axis=axis,
555
+ on_value=on_value,
556
+ off_value=off_value,
557
+ ).output(0)
558
+ return OpenVINOKerasTensor(one_hot_encoded)
380
559
 
381
560
 
382
561
  def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
383
- raise NotImplementedError(
384
- "`multi_hot` is not supported with openvino backend"
385
- )
562
+ reduction_axis = 1 if len(x.shape) > 1 else 0
563
+ if backend.standardize_dtype(dtype) == "bool":
564
+ outputs = one_hot(x, num_classes, axis=axis, dtype=dtype, sparse=sparse)
565
+ result = ov_opset.reduce_logical_or(outputs, reduction_axis)
566
+ else:
567
+ outputs = one_hot(x, num_classes, axis=axis, dtype=dtype)
568
+ result = ov_opset.reduce_max(outputs, reduction_axis)
569
+ return OpenVINOKerasTensor(result.output(0))
386
570
 
387
571
 
388
572
  def categorical_crossentropy(target, output, from_logits=False, axis=-1):
@@ -465,9 +649,15 @@ def batch_normalization(
465
649
 
466
650
 
467
651
  def ctc_loss(target, output, target_length, output_length, mask_index=0):
468
- raise NotImplementedError(
469
- "`ctc_loss` is not supported with openvino backend"
652
+ target = get_ov_output(target)
653
+ output = get_ov_output(output)
654
+ target_length = get_ov_output(target_length)
655
+ output_length = get_ov_output(output_length)
656
+ ctc_loss_ = ov_opset.ctc_loss(
657
+ output, output_length, target, target_length, blank_index=mask_index
470
658
  )
659
+ ctc_loss_ = ov_opset.convert(ctc_loss_, OPENVINO_DTYPES[backend.floatx()])
660
+ return OpenVINOKerasTensor(ctc_loss_.output(0))
471
661
 
472
662
 
473
663
  def ctc_decode(
@@ -485,7 +675,27 @@ def ctc_decode(
485
675
 
486
676
 
487
677
  def psnr(x1, x2, max_val):
488
- raise NotImplementedError("`psnr` is not supported with openvino backend")
678
+ from keras.src.backend.openvino.numpy import log10
679
+
680
+ x1 = get_ov_output(x1)
681
+ x2 = get_ov_output(x2)
682
+ max_val = get_ov_output(max_val, x1.get_element_type())
683
+ diff = ov_opset.subtract(x1, x2)
684
+ squared_diff = ov_opset.multiply(diff, diff)
685
+ reduction_axes = list(range(0, x1.get_partial_shape().rank.get_length()))
686
+ mse = ov_opset.reduce_mean(squared_diff, reduction_axes).output(0)
687
+ log_max_val = get_ov_output(log10(OpenVINOKerasTensor(max_val)))
688
+ log_mse = get_ov_output(log10(OpenVINOKerasTensor(mse)))
689
+
690
+ psnr = ov_opset.subtract(
691
+ ov_opset.multiply(
692
+ ov_opset.constant(20, log_max_val.get_element_type()), log_max_val
693
+ ),
694
+ ov_opset.multiply(
695
+ ov_opset.constant(10, log_mse.get_element_type()), log_mse
696
+ ),
697
+ ).output(0)
698
+ return OpenVINOKerasTensor(psnr)
489
699
 
490
700
 
491
701
  def dot_product_attention(
@@ -499,6 +709,47 @@ def dot_product_attention(
499
709
  flash_attention=None,
500
710
  attn_logits_soft_cap=None,
501
711
  ):
502
- raise NotImplementedError(
503
- "`dot_product_attention` is not supported with openvino backend"
712
+ if bias is not None:
713
+ raise NotImplementedError(
714
+ "`dot_product_attention` with `bias` is not supported "
715
+ "with openvino backend"
716
+ )
717
+ if flash_attention:
718
+ raise NotImplementedError(
719
+ "`dot_product_attention` with `flash_attention` is not supported "
720
+ "with openvino backend"
721
+ )
722
+ if attn_logits_soft_cap is not None:
723
+ raise NotImplementedError(
724
+ "`dot_product_attention` with `attn_logits_soft_cap` is not "
725
+ "supported with openvino backend"
726
+ )
727
+ query = get_ov_output(query)
728
+ key = get_ov_output(key)
729
+ value = get_ov_output(value)
730
+ if query.get_element_type() != key.get_element_type():
731
+ ov_type = OPENVINO_DTYPES[backend.floatx()]
732
+ query = ov_opset.convert(query, ov_type).output(0)
733
+ key = ov_opset.convert(key, ov_type).output(0)
734
+ if value.get_element_type() != query.get_element_type():
735
+ value = ov_opset.convert(value, query.get_element_type()).output(0)
736
+ axes_const = ov_opset.constant([0, 2, 1, 3], Type.i32).output(0)
737
+
738
+ query = ov_opset.transpose(query, axes_const)
739
+ key = ov_opset.transpose(key, axes_const)
740
+ value = ov_opset.transpose(value, axes_const)
741
+ mask = get_ov_output(mask) if mask is not None else None
742
+ scale = (
743
+ get_ov_output(scale, query.get_element_type())
744
+ if scale is not None
745
+ else None
746
+ )
747
+ dpa = ov_opset.scaled_dot_product_attention(
748
+ query, key, value, attention_mask=mask, scale=scale, causal=is_causal
504
749
  )
750
+ dpa = ov_opset.transpose(dpa, axes_const)
751
+ return OpenVINOKerasTensor(dpa.output(0))
752
+
753
+
754
+ def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
755
+ raise NotImplementedError("`unfold` is not supported with openvino backend")