keras-nightly 3.12.0.dev2025082103__py3-none-any.whl → 3.12.0.dev2025082203__py3-none-any.whl

This diff shows the changes between two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (114)
  1. keras/_tf_keras/keras/quantizers/__init__.py +1 -0
  2. keras/quantizers/__init__.py +1 -0
  3. keras/src/applications/convnext.py +20 -20
  4. keras/src/applications/densenet.py +21 -21
  5. keras/src/applications/efficientnet.py +16 -16
  6. keras/src/applications/efficientnet_v2.py +28 -28
  7. keras/src/applications/inception_resnet_v2.py +7 -7
  8. keras/src/applications/inception_v3.py +5 -5
  9. keras/src/applications/mobilenet_v2.py +13 -20
  10. keras/src/applications/mobilenet_v3.py +15 -15
  11. keras/src/applications/nasnet.py +7 -8
  12. keras/src/applications/resnet.py +32 -32
  13. keras/src/applications/xception.py +10 -10
  14. keras/src/backend/common/dtypes.py +3 -3
  15. keras/src/backend/common/variables.py +3 -1
  16. keras/src/backend/jax/export.py +1 -1
  17. keras/src/backend/jax/trainer.py +1 -1
  18. keras/src/backend/openvino/numpy.py +1 -1
  19. keras/src/backend/tensorflow/trainer.py +19 -1
  20. keras/src/backend/torch/core.py +6 -9
  21. keras/src/backend/torch/trainer.py +1 -1
  22. keras/src/callbacks/backup_and_restore.py +2 -2
  23. keras/src/callbacks/csv_logger.py +1 -1
  24. keras/src/callbacks/model_checkpoint.py +1 -1
  25. keras/src/callbacks/tensorboard.py +6 -6
  26. keras/src/datasets/boston_housing.py +1 -1
  27. keras/src/datasets/california_housing.py +1 -1
  28. keras/src/datasets/cifar10.py +1 -1
  29. keras/src/datasets/cifar100.py +2 -2
  30. keras/src/datasets/imdb.py +2 -2
  31. keras/src/datasets/mnist.py +1 -1
  32. keras/src/datasets/reuters.py +2 -2
  33. keras/src/dtype_policies/dtype_policy.py +1 -1
  34. keras/src/dtype_policies/dtype_policy_map.py +1 -1
  35. keras/src/export/tf2onnx_lib.py +1 -3
  36. keras/src/layers/input_spec.py +6 -6
  37. keras/src/layers/layer.py +1 -1
  38. keras/src/layers/preprocessing/category_encoding.py +3 -3
  39. keras/src/layers/preprocessing/data_layer.py +159 -0
  40. keras/src/layers/preprocessing/discretization.py +3 -3
  41. keras/src/layers/preprocessing/feature_space.py +4 -4
  42. keras/src/layers/preprocessing/image_preprocessing/aug_mix.py +7 -4
  43. keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py +3 -0
  44. keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py +2 -2
  45. keras/src/layers/preprocessing/image_preprocessing/center_crop.py +1 -1
  46. keras/src/layers/preprocessing/image_preprocessing/cut_mix.py +6 -3
  47. keras/src/layers/preprocessing/image_preprocessing/equalization.py +1 -1
  48. keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py +3 -0
  49. keras/src/layers/preprocessing/image_preprocessing/mix_up.py +7 -4
  50. keras/src/layers/preprocessing/image_preprocessing/rand_augment.py +3 -1
  51. keras/src/layers/preprocessing/image_preprocessing/random_brightness.py +1 -1
  52. keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py +3 -0
  53. keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py +3 -0
  54. keras/src/layers/preprocessing/image_preprocessing/random_contrast.py +1 -1
  55. keras/src/layers/preprocessing/image_preprocessing/random_crop.py +1 -1
  56. keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py +3 -0
  57. keras/src/layers/preprocessing/image_preprocessing/random_erasing.py +6 -3
  58. keras/src/layers/preprocessing/image_preprocessing/random_flip.py +1 -1
  59. keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py +3 -0
  60. keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py +1 -1
  61. keras/src/layers/preprocessing/image_preprocessing/random_hue.py +3 -0
  62. keras/src/layers/preprocessing/image_preprocessing/random_invert.py +3 -0
  63. keras/src/layers/preprocessing/image_preprocessing/random_perspective.py +3 -0
  64. keras/src/layers/preprocessing/image_preprocessing/random_posterization.py +3 -0
  65. keras/src/layers/preprocessing/image_preprocessing/random_rotation.py +1 -1
  66. keras/src/layers/preprocessing/image_preprocessing/random_saturation.py +3 -0
  67. keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py +3 -0
  68. keras/src/layers/preprocessing/image_preprocessing/random_shear.py +3 -0
  69. keras/src/layers/preprocessing/image_preprocessing/random_translation.py +3 -3
  70. keras/src/layers/preprocessing/image_preprocessing/random_zoom.py +3 -3
  71. keras/src/layers/preprocessing/image_preprocessing/resizing.py +3 -3
  72. keras/src/layers/preprocessing/image_preprocessing/solarization.py +3 -0
  73. keras/src/layers/preprocessing/mel_spectrogram.py +29 -25
  74. keras/src/layers/preprocessing/normalization.py +5 -2
  75. keras/src/layers/preprocessing/rescaling.py +3 -3
  76. keras/src/layers/rnn/bidirectional.py +4 -4
  77. keras/src/legacy/backend.py +9 -23
  78. keras/src/legacy/preprocessing/image.py +11 -22
  79. keras/src/legacy/preprocessing/text.py +1 -1
  80. keras/src/models/functional.py +2 -2
  81. keras/src/models/model.py +21 -3
  82. keras/src/ops/function.py +1 -1
  83. keras/src/ops/numpy.py +5 -5
  84. keras/src/ops/operation.py +3 -2
  85. keras/src/optimizers/base_optimizer.py +3 -4
  86. keras/src/quantizers/gptq.py +350 -0
  87. keras/src/quantizers/gptq_config.py +169 -0
  88. keras/src/quantizers/gptq_core.py +335 -0
  89. keras/src/quantizers/gptq_quant.py +133 -0
  90. keras/src/saving/file_editor.py +22 -20
  91. keras/src/saving/object_registration.py +1 -1
  92. keras/src/saving/saving_lib.py +4 -4
  93. keras/src/saving/serialization_lib.py +3 -5
  94. keras/src/trainers/compile_utils.py +1 -1
  95. keras/src/trainers/data_adapters/array_data_adapter.py +9 -3
  96. keras/src/trainers/data_adapters/data_adapter_utils.py +15 -5
  97. keras/src/trainers/data_adapters/generator_data_adapter.py +2 -0
  98. keras/src/trainers/data_adapters/grain_dataset_adapter.py +8 -2
  99. keras/src/trainers/data_adapters/tf_dataset_adapter.py +4 -2
  100. keras/src/trainers/data_adapters/torch_data_loader_adapter.py +3 -1
  101. keras/src/tree/dmtree_impl.py +19 -3
  102. keras/src/tree/optree_impl.py +3 -3
  103. keras/src/tree/tree_api.py +5 -2
  104. keras/src/utils/file_utils.py +13 -5
  105. keras/src/utils/io_utils.py +1 -1
  106. keras/src/utils/model_visualization.py +1 -1
  107. keras/src/utils/progbar.py +5 -5
  108. keras/src/utils/summary_utils.py +4 -4
  109. keras/src/version.py +1 -1
  110. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/METADATA +1 -1
  111. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/RECORD +113 -109
  112. keras/src/layers/preprocessing/tf_data_layer.py +0 -78
  113. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/WHEEL +0 -0
  114. {keras_nightly-3.12.0.dev2025082103.dist-info → keras_nightly-3.12.0.dev2025082203.dist-info}/top_level.txt +0 -0

keras/src/applications/mobilenet_v3.py
@@ -385,10 +385,10 @@ def MobileNetV3(
             model_type, "_minimalistic" if minimalistic else "", str(alpha)
         )
         if include_top:
-            file_name = "weights_mobilenet_v3_" + model_name + ".h5"
+            file_name = f"weights_mobilenet_v3_{model_name}.h5"
             file_hash = WEIGHTS_HASHES[model_name][0]
         else:
-            file_name = "weights_mobilenet_v3_" + model_name + "_no_top_v2.h5"
+            file_name = f"weights_mobilenet_v3_{model_name}_no_top_v2.h5"
             file_hash = WEIGHTS_HASHES[model_name][1]
         weights_path = file_utils.get_file(
             file_name,
@@ -570,23 +570,23 @@ def _depth(v, divisor=8, min_value=None):
 
 def _se_block(inputs, filters, se_ratio, prefix):
     x = layers.GlobalAveragePooling2D(
-        keepdims=True, name=prefix + "squeeze_excite_avg_pool"
+        keepdims=True, name=f"{prefix}squeeze_excite_avg_pool"
     )(inputs)
     x = layers.Conv2D(
         _depth(filters * se_ratio),
         kernel_size=1,
         padding="same",
-        name=prefix + "squeeze_excite_conv",
+        name=f"{prefix}squeeze_excite_conv",
     )(x)
-    x = layers.ReLU(name=prefix + "squeeze_excite_relu")(x)
+    x = layers.ReLU(name=f"{prefix}squeeze_excite_relu")(x)
     x = layers.Conv2D(
         filters,
         kernel_size=1,
         padding="same",
-        name=prefix + "squeeze_excite_conv_1",
+        name=f"{prefix}squeeze_excite_conv_1",
     )(x)
     x = hard_sigmoid(x)
-    x = layers.Multiply(name=prefix + "squeeze_excite_mul")([inputs, x])
+    x = layers.Multiply(name=f"{prefix}squeeze_excite_mul")([inputs, x])
     return x
 
 
@@ -605,33 +605,33 @@ def _inverted_res_block(
             kernel_size=1,
             padding="same",
             use_bias=False,
-            name=prefix + "expand",
+            name=f"{prefix}expand",
         )(x)
         x = layers.BatchNormalization(
             axis=channel_axis,
             epsilon=1e-3,
             momentum=0.999,
-            name=prefix + "expand_bn",
+            name=f"{prefix}expand_bn",
         )(x)
         x = activation(x)
 
     if stride == 2:
         x = layers.ZeroPadding2D(
             padding=imagenet_utils.correct_pad(x, kernel_size),
-            name=prefix + "depthwise_pad",
+            name=f"{prefix}depthwise_pad",
         )(x)
     x = layers.DepthwiseConv2D(
         kernel_size,
         strides=stride,
         padding="same" if stride == 1 else "valid",
         use_bias=False,
-        name=prefix + "depthwise",
+        name=f"{prefix}depthwise",
     )(x)
     x = layers.BatchNormalization(
         axis=channel_axis,
         epsilon=1e-3,
         momentum=0.999,
-        name=prefix + "depthwise_bn",
+        name=f"{prefix}depthwise_bn",
     )(x)
     x = activation(x)
 
@@ -643,17 +643,17 @@ def _inverted_res_block(
         kernel_size=1,
         padding="same",
         use_bias=False,
-        name=prefix + "project",
+        name=f"{prefix}project",
     )(x)
     x = layers.BatchNormalization(
         axis=channel_axis,
         epsilon=1e-3,
         momentum=0.999,
-        name=prefix + "project_bn",
+        name=f"{prefix}project_bn",
     )(x)
 
     if stride == 1 and infilters == filters:
-        x = layers.Add(name=prefix + "add")([shortcut, x])
+        x = layers.Add(name=f"{prefix}add")([shortcut, x])
     return x
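
Nearly all of the application-model hunks above and below follow one mechanical pattern: string concatenation with "+" is rewritten as an f-string. A minimal sketch of the equivalence, with illustrative values rather than code taken from the package:

    prefix = "block5"
    # before
    name = prefix + "_sepconv1_act"
    # after
    name = f"{prefix}_sepconv1_act"
    assert name == "block5_sepconv1_act"  # identical result, clearer intent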

keras/src/applications/nasnet.py
@@ -11,10 +11,10 @@ from keras.src.utils import file_utils
 BASE_WEIGHTS_PATH = (
     "https://storage.googleapis.com/tensorflow/keras-applications/nasnet/"
 )
-NASNET_MOBILE_WEIGHT_PATH = BASE_WEIGHTS_PATH + "NASNet-mobile.h5"
-NASNET_MOBILE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + "NASNet-mobile-no-top.h5"
-NASNET_LARGE_WEIGHT_PATH = BASE_WEIGHTS_PATH + "NASNet-large.h5"
-NASNET_LARGE_WEIGHT_PATH_NO_TOP = BASE_WEIGHTS_PATH + "NASNet-large-no-top.h5"
+NASNET_MOBILE_WEIGHT_PATH = f"{BASE_WEIGHTS_PATH}NASNet-mobile.h5"
+NASNET_MOBILE_WEIGHT_PATH_NO_TOP = f"{BASE_WEIGHTS_PATH}NASNet-mobile-no-top.h5"
+NASNET_LARGE_WEIGHT_PATH = f"{BASE_WEIGHTS_PATH}NASNet-large.h5"
+NASNET_LARGE_WEIGHT_PATH_NO_TOP = f"{BASE_WEIGHTS_PATH}NASNet-large-no-top.h5"
 
 
 def NASNet(
@@ -137,10 +137,9 @@ def NASNet(
         and weights == "imagenet"
     ):
         raise ValueError(
-            "When specifying the input shape of a NASNet"
-            " and loading `ImageNet` weights, "
-            "the input_shape argument must be static "
-            "(no None entries). Got: `input_shape=" + str(input_shape) + "`."
+            "When specifying the input shape of a NASNet and loading "
+            "`ImageNet` weights, the input_shape argument must be static"
+            f" (no None entries). Got: `input_shape={input_shape}`."
         )
 
     if default_size is None:

keras/src/applications/resnet.py
@@ -196,16 +196,16 @@ def ResNet(
     # Load weights.
     if (weights == "imagenet") and (weights_name in WEIGHTS_HASHES):
         if include_top:
-            file_name = weights_name + "_weights_tf_dim_ordering_tf_kernels.h5"
+            file_name = f"{weights_name}_weights_tf_dim_ordering_tf_kernels.h5"
             file_hash = WEIGHTS_HASHES[weights_name][0]
         else:
             file_name = (
-                weights_name + "_weights_tf_dim_ordering_tf_kernels_notop.h5"
+                f"{weights_name}_weights_tf_dim_ordering_tf_kernels_notop.h5"
             )
             file_hash = WEIGHTS_HASHES[weights_name][1]
         weights_path = file_utils.get_file(
             file_name,
-            BASE_WEIGHTS_PATH + file_name,
+            f"{BASE_WEIGHTS_PATH}{file_name}",
             cache_subdir="models",
             file_hash=file_hash,
         )
@@ -241,35 +241,35 @@ def residual_block_v1(
 
     if conv_shortcut:
         shortcut = layers.Conv2D(
-            4 * filters, 1, strides=stride, name=name + "_0_conv"
+            4 * filters, 1, strides=stride, name=f"{name}_0_conv"
         )(x)
         shortcut = layers.BatchNormalization(
-            axis=bn_axis, epsilon=1.001e-5, name=name + "_0_bn"
+            axis=bn_axis, epsilon=1.001e-5, name=f"{name}_0_bn"
         )(shortcut)
     else:
         shortcut = x
 
-    x = layers.Conv2D(filters, 1, strides=stride, name=name + "_1_conv")(x)
+    x = layers.Conv2D(filters, 1, strides=stride, name=f"{name}_1_conv")(x)
     x = layers.BatchNormalization(
-        axis=bn_axis, epsilon=1.001e-5, name=name + "_1_bn"
+        axis=bn_axis, epsilon=1.001e-5, name=f"{name}_1_bn"
     )(x)
-    x = layers.Activation("relu", name=name + "_1_relu")(x)
+    x = layers.Activation("relu", name=f"{name}_1_relu")(x)
 
     x = layers.Conv2D(
-        filters, kernel_size, padding="SAME", name=name + "_2_conv"
+        filters, kernel_size, padding="SAME", name=f"{name}_2_conv"
     )(x)
     x = layers.BatchNormalization(
-        axis=bn_axis, epsilon=1.001e-5, name=name + "_2_bn"
+        axis=bn_axis, epsilon=1.001e-5, name=f"{name}_2_bn"
    )(x)
-    x = layers.Activation("relu", name=name + "_2_relu")(x)
+    x = layers.Activation("relu", name=f"{name}_2_relu")(x)
 
-    x = layers.Conv2D(4 * filters, 1, name=name + "_3_conv")(x)
+    x = layers.Conv2D(4 * filters, 1, name=f"{name}_3_conv")(x)
     x = layers.BatchNormalization(
-        axis=bn_axis, epsilon=1.001e-5, name=name + "_3_bn"
+        axis=bn_axis, epsilon=1.001e-5, name=f"{name}_3_bn"
     )(x)
 
-    x = layers.Add(name=name + "_add")([shortcut, x])
-    x = layers.Activation("relu", name=name + "_out")(x)
+    x = layers.Add(name=f"{name}_add")([shortcut, x])
+    x = layers.Activation("relu", name=f"{name}_out")(x)
     return x
 
 
@@ -287,10 +287,10 @@ def stack_residual_blocks_v1(x, filters, blocks, stride1=2, name=None):
         Output tensor for the stacked blocks.
     """
 
-    x = residual_block_v1(x, filters, stride=stride1, name=name + "_block1")
+    x = residual_block_v1(x, filters, stride=stride1, name=f"{name}_block1")
     for i in range(2, blocks + 1):
         x = residual_block_v1(
-            x, filters, conv_shortcut=False, name=name + "_block" + str(i)
+            x, filters, conv_shortcut=False, name=f"{name}_block{i}"
         )
     return x
 
@@ -319,13 +319,13 @@ def residual_block_v2(
         bn_axis = 1
 
     preact = layers.BatchNormalization(
-        axis=bn_axis, epsilon=1.001e-5, name=name + "_preact_bn"
+        axis=bn_axis, epsilon=1.001e-5, name=f"{name}_preact_bn"
     )(x)
-    preact = layers.Activation("relu", name=name + "_preact_relu")(preact)
+    preact = layers.Activation("relu", name=f"{name}_preact_relu")(preact)
 
     if conv_shortcut:
         shortcut = layers.Conv2D(
-            4 * filters, 1, strides=stride, name=name + "_0_conv"
+            4 * filters, 1, strides=stride, name=f"{name}_0_conv"
         )(preact)
     else:
         shortcut = (
@@ -333,28 +333,28 @@ def residual_block_v2(
     )
 
     x = layers.Conv2D(
-        filters, 1, strides=1, use_bias=False, name=name + "_1_conv"
+        filters, 1, strides=1, use_bias=False, name=f"{name}_1_conv"
     )(preact)
     x = layers.BatchNormalization(
-        axis=bn_axis, epsilon=1.001e-5, name=name + "_1_bn"
+        axis=bn_axis, epsilon=1.001e-5, name=f"{name}_1_bn"
     )(x)
-    x = layers.Activation("relu", name=name + "_1_relu")(x)
+    x = layers.Activation("relu", name=f"{name}_1_relu")(x)
 
-    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=name + "_2_pad")(x)
+    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=f"{name}_2_pad")(x)
     x = layers.Conv2D(
         filters,
         kernel_size,
         strides=stride,
         use_bias=False,
-        name=name + "_2_conv",
+        name=f"{name}_2_conv",
     )(x)
     x = layers.BatchNormalization(
-        axis=bn_axis, epsilon=1.001e-5, name=name + "_2_bn"
+        axis=bn_axis, epsilon=1.001e-5, name=f"{name}_2_bn"
    )(x)
-    x = layers.Activation("relu", name=name + "_2_relu")(x)
+    x = layers.Activation("relu", name=f"{name}_2_relu")(x)
 
-    x = layers.Conv2D(4 * filters, 1, name=name + "_3_conv")(x)
-    x = layers.Add(name=name + "_out")([shortcut, x])
+    x = layers.Conv2D(4 * filters, 1, name=f"{name}_3_conv")(x)
+    x = layers.Add(name=f"{name}_out")([shortcut, x])
     return x
 
 
@@ -372,11 +372,11 @@ def stack_residual_blocks_v2(x, filters, blocks, stride1=2, name=None):
         Output tensor for the stacked blocks.
     """
 
-    x = residual_block_v2(x, filters, conv_shortcut=True, name=name + "_block1")
+    x = residual_block_v2(x, filters, conv_shortcut=True, name=f"{name}_block1")
    for i in range(2, blocks):
-        x = residual_block_v2(x, filters, name=name + "_block" + str(i))
+        x = residual_block_v2(x, filters, name=f"{name}_block{i}")
     x = residual_block_v2(
-        x, filters, stride=stride1, name=name + "_block" + str(blocks)
+        x, filters, stride=stride1, name=f"{name}_block{str(blocks)}"
     )
     return x

keras/src/applications/xception.py
@@ -212,40 +212,40 @@ def Xception(
 
     for i in range(8):
         residual = x
-        prefix = "block" + str(i + 5)
+        prefix = f"block{i + 5}"
 
-        x = layers.Activation("relu", name=prefix + "_sepconv1_act")(x)
+        x = layers.Activation("relu", name=f"{prefix}_sepconv1_act")(x)
         x = layers.SeparableConv2D(
             728,
             (3, 3),
             padding="same",
             use_bias=False,
-            name=prefix + "_sepconv1",
+            name=f"{prefix}_sepconv1",
         )(x)
         x = layers.BatchNormalization(
-            axis=channel_axis, name=prefix + "_sepconv1_bn"
+            axis=channel_axis, name=f"{prefix}_sepconv1_bn"
         )(x)
-        x = layers.Activation("relu", name=prefix + "_sepconv2_act")(x)
+        x = layers.Activation("relu", name=f"{prefix}_sepconv2_act")(x)
         x = layers.SeparableConv2D(
             728,
             (3, 3),
             padding="same",
             use_bias=False,
-            name=prefix + "_sepconv2",
+            name=f"{prefix}_sepconv2",
         )(x)
         x = layers.BatchNormalization(
-            axis=channel_axis, name=prefix + "_sepconv2_bn"
+            axis=channel_axis, name=f"{prefix}_sepconv2_bn"
         )(x)
-        x = layers.Activation("relu", name=prefix + "_sepconv3_act")(x)
+        x = layers.Activation("relu", name=f"{prefix}_sepconv3_act")(x)
         x = layers.SeparableConv2D(
             728,
             (3, 3),
             padding="same",
             use_bias=False,
-            name=prefix + "_sepconv3",
+            name=f"{prefix}_sepconv3",
         )(x)
         x = layers.BatchNormalization(
-            axis=channel_axis, name=prefix + "_sepconv3_bn"
+            axis=channel_axis, name=f"{prefix}_sepconv3_bn"
         )(x)
 
         x = layers.add([x, residual])

keras/src/backend/common/dtypes.py
@@ -225,11 +225,11 @@ def _resolve_weak_type(dtype, precision="32"):
     if dtype_indicator == "b":
         return "bool"
     elif dtype_indicator == "i":
-        return "int" + precision
+        return f"int{precision}"
     elif dtype_indicator == "u":
-        return "uint" + precision
+        return f"uint{precision}"
     else:
-        return "float" + precision
+        return f"float{precision}"
 
 
 BIT64_TO_BIT16_DTYPE = {

keras/src/backend/common/variables.py
@@ -1,3 +1,5 @@
+import os.path
+
 import numpy as np
 
 from keras.src import backend
@@ -142,7 +144,7 @@ class Variable:
         self._name = name
         parent_path = current_path()
         if parent_path:
-            self._path = current_path() + "/" + name
+            self._path = os.path.join(current_path(), name)
         else:
             self._path = name
         self._shape = None
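
Since os.path.join uses the platform separator, variable paths keep their "parent/name" form on POSIX systems; a minimal sketch of the equivalence, with illustrative values rather than the Keras internals:

    import os.path

    parent_path, name = "dense", "kernel"
    # On POSIX, the new construction matches the old "/"-concatenation.
    assert os.path.join(parent_path, name) == parent_path + "/" + name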

keras/src/backend/jax/export.py
@@ -159,7 +159,7 @@ class JaxExportArchive:
                     poly_shape.append("batch")
                 else:
                     poly_shape.append(next(dim_names))
-            return "(" + ", ".join(poly_shape) + ")"
+            return f"({', '.join(poly_shape)})"
 
         return tree.map_structure(convert_shape, struct)
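
The rewritten return renders a tensor shape as a polymorphic-shape string for export, with dynamic dimensions replaced by symbolic names. A minimal sketch of the string construction, assuming a shape tuple that uses None for dynamic dimensions (a standalone reimplementation, not the Keras internals):

    dim_names = iter(["a", "b", "c"])
    shape = (None, 28, 28, 3)
    poly_shape = []
    for i, dim in enumerate(shape):
        if dim is not None:
            poly_shape.append(str(dim))  # static dim: keep the number
        elif i == 0:
            poly_shape.append("batch")   # leading dynamic dim: "batch"
        else:
            poly_shape.append(next(dim_names))  # other dynamic dims
    print(f"({', '.join(poly_shape)})")  # -> (batch, 28, 28, 3)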

keras/src/backend/jax/trainer.py
@@ -504,7 +504,7 @@ class JAXTrainer(base_trainer.Trainer):
                     _use_cached_eval_dataset=True,
                 )
                 val_logs = {
-                    "val_" + name: val for name, val in val_logs.items()
+                    f"val_{name}": val for name, val in val_logs.items()
                 }
                 epoch_logs.update(val_logs)

keras/src/backend/openvino/numpy.py
@@ -687,7 +687,7 @@ def diff(a, n=1, axis=-1):
     if n == 0:
         return OpenVINOKerasTensor(get_ov_output(a))
     if n < 0:
-        raise ValueError("order must be non-negative but got " + repr(n))
+        raise ValueError(f"order must be non-negative but got {repr(n)}")
     a = get_ov_output(a)
     a_type = a.get_element_type()
     if isinstance(a, np.ndarray):

keras/src/backend/tensorflow/trainer.py
@@ -1,4 +1,5 @@
 import contextlib
+import functools
 import warnings
 
 import numpy as np
@@ -107,6 +108,21 @@ class TensorFlowTrainer(base_trainer.Trainer):
             y_pred = self(x)
         return y_pred
 
+    def _autoconvert_optionals(self, step_func):
+        # Wrapper converting (nested) TF Optional in input data to None
+        @functools.wraps(step_func)
+        def wrapper(data):
+            converted_data = tree.map_structure(
+                lambda i: (
+                    None if isinstance(i, tf.experimental.Optional) else i
+                ),
+                data,
+            )
+            result = step_func(converted_data)
+            return result
+
+        return wrapper
+
     def _make_function(self, step_function):
         @tf.autograph.experimental.do_not_convert
         def one_step_on_data(data):
@@ -125,6 +141,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
                 reduce_retracing=True,
                 jit_compile=self.jit_compile,
             )
+        one_step_on_data = self._autoconvert_optionals(one_step_on_data)
 
         @tf.autograph.experimental.do_not_convert
         def multi_step_on_iterator(iterator):
@@ -253,6 +270,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
             one_step_on_data = tf.function(
                 one_step_on_data, reduce_retracing=True, jit_compile=True
             )
+            one_step_on_data = self._autoconvert_optionals(one_step_on_data)
 
         @tf.autograph.experimental.do_not_convert
         def one_step_on_data_distributed(data):
@@ -409,7 +427,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
                     _use_cached_eval_dataset=True,
                 )
                 val_logs = {
-                    "val_" + name: val for name, val in val_logs.items()
+                    f"val_{name}": val for name, val in val_logs.items()
                 }
                 epoch_logs.update(val_logs)
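
The new _autoconvert_optionals wrapper replaces every tf.experimental.Optional leaf in a step's input structure with None before the wrapped step function runs. A minimal sketch of that conversion on a single input tuple (the data values are illustrative, not from the package):

    import tensorflow as tf

    # The second component arrived as an empty TF Optional rather than None.
    data = (
        tf.constant([1.0, 2.0]),
        tf.experimental.Optional.empty(tf.TensorSpec([], tf.float32)),
    )
    # Map every Optional leaf to None, as the wrapper does via
    # tree.map_structure over the full nested structure.
    converted = tuple(
        None if isinstance(i, tf.experimental.Optional) else i for i in data
    )
    print(converted[1])  # -> None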

keras/src/backend/torch/core.py
@@ -191,21 +191,18 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
         raise ValueError("`sparse=True` is not supported with torch backend")
     if ragged:
         raise ValueError("`ragged=True` is not supported with torch backend")
-    if isinstance(x, Variable):
-        if dtype is None:
-            return x.value
-        x = x.value
-        return x.to(to_torch_dtype(dtype))
-    if is_tensor(x):
+    if isinstance(x, Variable) or is_tensor(x):
+        if isinstance(x, Variable):
+            x = x.value
         device = get_device()
         if x.device != device:
             if x.is_meta:
                 x = torch.empty_like(x, device=device)
             else:
                 x = x.to(device)
-        if dtype is None:
-            return x
-        return x.to(to_torch_dtype(dtype))
+        if dtype is not None:
+            x = x.to(to_torch_dtype(dtype))
+        return x
     if dtype is None:
         if isinstance(x, bool):
             return torch.as_tensor(x, dtype=torch.bool, device=get_device())
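
This refactor folds the Variable and tensor branches into one path: unwrap a Variable to its underlying tensor, normalize the device, then apply the dtype cast only when one was requested. A minimal sketch of the resulting ordering (a hypothetical helper, not the actual Keras function):

    import torch

    def to_device_then_dtype(x, device="cpu", torch_dtype=None):
        # Move to the target device first, then cast once at the end.
        if x.device.type != device:
            x = x.to(device)
        if torch_dtype is not None:
            x = x.to(torch_dtype)
        return x

    print(to_device_then_dtype(torch.ones(2), torch_dtype=torch.float64).dtype)
    # -> torch.float64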

keras/src/backend/torch/trainer.py
@@ -299,7 +299,7 @@ class TorchTrainer(base_trainer.Trainer):
                     _use_cached_eval_dataset=True,
                 )
                 val_logs = {
-                    "val_" + name: val for name, val in val_logs.items()
+                    f"val_{name}": val for name, val in val_logs.items()
                 }
                 epoch_logs.update(val_logs)

keras/src/callbacks/backup_and_restore.py
@@ -99,9 +99,9 @@ class BackupAndRestore(Callback):
         self._training_metadata_path = file_utils.join(
             backup_dir, "training_metadata.json"
         )
-        self._prev_weights_path = self._weights_path + ".bkp"
+        self._prev_weights_path = f"{self._weights_path}.bkp"
         self._prev_training_metadata_path = (
-            self._training_metadata_path + ".bkp"
+            f"{self._training_metadata_path}.bkp"
         )
         if save_freq != "epoch" and not isinstance(save_freq, int):
             raise ValueError(

keras/src/callbacks/csv_logger.py
@@ -79,7 +79,7 @@ class CSVLogger(Callback):
                     val_keys_found = True
                     break
             if not val_keys_found and self.keys:
-                self.keys.extend(["val_" + k for k in self.keys])
+                self.keys.extend([f"val_{k}" for k in self.keys])
 
         if not self.writer:

keras/src/callbacks/model_checkpoint.py
@@ -372,7 +372,7 @@ class ModelCheckpoint(MonitorCallback):
         """
         dir_name = os.path.dirname(pattern)
         base_name = os.path.basename(pattern)
-        base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"
+        base_name_regex = f"^{re.sub(r'{.*}', r'.*', base_name)}$"
 
         latest_mod_time = 0
         file_path_with_latest_mod_time = None
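
The rewritten line builds a regex that matches any concrete filename generated from a checkpoint filepath pattern, by replacing the format field with ".*" and anchoring both ends. A short sketch of what it produces, using an illustrative pattern:

    import re

    base_name = "weights.{epoch:02d}.keras"
    base_name_regex = f"^{re.sub(r'{.*}', r'.*', base_name)}$"
    print(base_name_regex)  # -> ^weights..*.keras$
    print(bool(re.match(base_name_regex, "weights.07.keras")))  # -> True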

keras/src/callbacks/tensorboard.py
@@ -424,7 +424,7 @@ class TensorBoard(Callback):
             with self._val_writer.as_default():
                 for name, value in logs.items():
                     self.summary.scalar(
-                        "evaluation_" + name + "_vs_iterations",
+                        f"evaluation_{name}_vs_iterations",
                         value,
                         step=self.model.optimizer.iterations,
                     )
@@ -460,7 +460,7 @@ class TensorBoard(Callback):
         if isinstance(logs, dict):
             for name, value in logs.items():
                 self.summary.scalar(
-                    "batch_" + name, value, step=self._global_train_batch
+                    f"batch_{name}", value, step=self._global_train_batch
                 )
 
         if not self._should_trace:
@@ -548,12 +548,12 @@ class TensorBoard(Callback):
         if train_logs:
             with self._train_writer.as_default():
                 for name, value in train_logs.items():
-                    self.summary.scalar("epoch_" + name, value, step=epoch)
+                    self.summary.scalar(f"epoch_{name}", value, step=epoch)
         if val_logs:
             with self._val_writer.as_default():
                 for name, value in val_logs.items():
                     name = name[4:]  # Remove 'val_' prefix.
-                    self.summary.scalar("epoch_" + name, value, step=epoch)
+                    self.summary.scalar(f"epoch_{name}", value, step=epoch)
 
     def _log_weights(self, epoch):
         """Logs the weights of the Model to TensorBoard."""
@@ -562,14 +562,14 @@ class TensorBoard(Callback):
             for weight in layer.weights:
                 weight_name = weight.name.replace(":", "_")
                 # Add a suffix to prevent summary tag name collision.
-                histogram_weight_name = weight_name + "/histogram"
+                histogram_weight_name = f"{weight_name}/histogram"
                 self.summary.histogram(
                     histogram_weight_name, weight, step=epoch
                 )
                 if self.write_images:
                     # Add a suffix to prevent summary tag name
                     # collision.
-                    image_weight_name = weight_name + "/image"
+                    image_weight_name = f"{weight_name}/image"
                     self._log_weight_as_image(
                         weight, image_weight_name, epoch
                     )

keras/src/datasets/boston_housing.py
@@ -48,7 +48,7 @@ def load_data(path="boston_housing.npz", test_split=0.2, seed=113):
     )
     path = get_file(
         path,
-        origin=origin_folder + "boston_housing.npz",
+        origin=f"{origin_folder}boston_housing.npz",
         file_hash=(  # noqa: E501
             "f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5"
         ),

keras/src/datasets/california_housing.py
@@ -73,7 +73,7 @@ def load_data(
     )
     path = get_file(
         path,
-        origin=origin_folder + "california_housing.npz",
+        origin=f"{origin_folder}california_housing.npz",
         file_hash=(  # noqa: E501
             "1a2e3a52e0398de6463aebe6f4a8da34fb21fbb6b934cf88c3425e766f2a1a6f"
         ),

keras/src/datasets/cifar10.py
@@ -79,7 +79,7 @@ def load_data():
     # batches are within an inner folder
     path = os.path.join(path, "cifar-10-batches-py")
     for i in range(1, 6):
-        fpath = os.path.join(path, "data_batch_" + str(i))
+        fpath = os.path.join(path, f"data_batch_{i}")
         (
             x_train[(i - 1) * 10000 : i * 10000, :, :, :],
             y_train[(i - 1) * 10000 : i * 10000],

keras/src/datasets/cifar100.py
@@ -71,10 +71,10 @@ def load_data(label_mode="fine"):
 
     path = os.path.join(path, "cifar-100-python")
     fpath = os.path.join(path, "train")
-    x_train, y_train = load_batch(fpath, label_key=label_mode + "_labels")
+    x_train, y_train = load_batch(fpath, label_key=f"{label_mode}_labels")
 
     fpath = os.path.join(path, "test")
-    x_test, y_test = load_batch(fpath, label_key=label_mode + "_labels")
+    x_test, y_test = load_batch(fpath, label_key=f"{label_mode}_labels")
 
     y_train = np.reshape(y_train, (len(y_train), 1))
     y_test = np.reshape(y_test, (len(y_test), 1))

keras/src/datasets/imdb.py
@@ -78,7 +78,7 @@ def load_data(
     )
     path = get_file(
         fname=path,
-        origin=origin_folder + "imdb.npz",
+        origin=f"{origin_folder}imdb.npz",
         file_hash=(  # noqa: E501
             "69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f"
         ),
@@ -181,7 +181,7 @@ def get_word_index(path="imdb_word_index.json"):
     )
     path = get_file(
         fname=path,
-        origin=origin_folder + "imdb_word_index.json",
+        origin=f"{origin_folder}imdb_word_index.json",
         file_hash="bfafd718b763782e994055a2d397834f",
     )
     with open(path) as f:

keras/src/datasets/mnist.py
@@ -59,7 +59,7 @@ def load_data(path="mnist.npz"):
     )
     path = get_file(
         fname=path,
-        origin=origin_folder + "mnist.npz",
+        origin=f"{origin_folder}mnist.npz",
         file_hash=(  # noqa: E501
             "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"
         ),