tf-nightly-cpu 2.20.0.dev20250220-cp310-cp310-win_amd64.whl → 2.20.0.dev20250221-cp310-cp310-win_amd64.whl

Files changed (113)
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  14. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  15. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  17. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  18. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -18
  19. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  20. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  21. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  22. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  23. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  24. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  25. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  26. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  27. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  28. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  30. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  31. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  32. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  33. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  34. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  35. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  36. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  37. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  38. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  39. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  40. tensorflow/include/tensorflow/core/public/version.h +112 -127
  41. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  42. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  43. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  44. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -18
  45. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  46. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  47. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  48. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  49. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  50. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  51. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  52. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  53. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  54. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  55. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  56. tensorflow/include/xla/service/constant_value.h +1 -0
  57. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  58. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  59. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  60. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  61. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  62. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  63. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  64. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  65. tensorflow/python/_pywrap_mlir.pyd +0 -0
  66. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  67. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  68. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  69. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  70. tensorflow/python/_pywrap_tfe.pyd +0 -0
  71. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  72. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  73. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  74. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  75. tensorflow/python/compat/compat.py +1 -1
  76. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  77. tensorflow/python/eager/imperative_grad.py +5 -5
  78. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  79. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  80. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  81. tensorflow/python/eager/tape.py +2 -2
  82. tensorflow/python/framework/_dtypes.pyd +0 -0
  83. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  84. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  85. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  86. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  87. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  88. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  89. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  90. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  91. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  92. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  93. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  94. tensorflow/python/ops/summary_ops_v2.py +5 -1
  95. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  96. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  97. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  98. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  99. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  100. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  101. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  102. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  103. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  104. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  105. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  106. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  107. tensorflow/python/util/_tf_stack.pyd +0 -0
  108. tensorflow/tools/pip_package/setup.py +2 -2
  109. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/METADATA +1 -1
  110. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/RECORD +113 -106
  111. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/WHEEL +0 -0
  112. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/entry_points.txt +0 -0
  113. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/top_level.txt +0 -0
tensorflow/include/tensorflow/core/public/version.h

@@ -1,127 +1,112 @@
- /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- ==============================================================================*/
-
- #ifndef TENSORFLOW_CORE_PUBLIC_VERSION_H_
- #define TENSORFLOW_CORE_PUBLIC_VERSION_H_
-
- // TensorFlow uses semantic versioning, see http://semver.org/.
-
- // Also update tensorflow/tensorflow.bzl and
- // tensorflow/tools/pip_package/setup.py
- #define TF_MAJOR_VERSION 2
- #define TF_MINOR_VERSION 20
- #define TF_PATCH_VERSION 0
-
- // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
- // "-beta", "-rc", "-rc.1")
- #define TF_VERSION_SUFFIX "-dev20250220"
-
- #define TF_STR_HELPER(x) #x
- #define TF_STR(x) TF_STR_HELPER(x)
-
- // e.g. "0.5.0" or "0.6.0-alpha".
- #define TF_VERSION_STRING \
- (TF_STR(TF_MAJOR_VERSION) "." TF_STR(TF_MINOR_VERSION) "." TF_STR( \
- TF_PATCH_VERSION) TF_VERSION_SUFFIX)
-
- // GraphDef compatibility versions (the versions field in graph.proto).
- //
- // Each graph has producer and min_consumer versions, and each
- // consumer has its own version and a min_producer. In addition, graphs can
- // mark specific consumer versions as bad (to prevent bugs from executing).
- // A consumer will execute a graph if the consumer's version is at least the
- // graph's min_consumer, the graph's producer version is at least the consumer's
- // min_producer, and the consumer version isn't specifically disallowed by the
- // graph.
- //
- // By default, newly created graphs have producer version TF_GRAPH_DEF_VERSION
- // min_consumer TF_GRAPH_DEF_MIN_CONSUMER, and no other bad consumer versions.
- //
- // Version history:
- //
- // 0. Graphs created before GraphDef versioning
- // 1. First real version (2dec2015)
- // 2. adjust_contrast only takes float, doesn't perform clamping (11dec2015)
- // 3. Remove TileGrad, since it was equivalent to reduce_sum (30dec2015)
- // 4. When support for this version is removed, we can safely make AttrValue
- // parsing more strict with respect to empty list values (see
- // 111635679, 7jan2016).
- // 5. Graphs are wholly-validated during Session::Create() (7jan2016).
- // 6. TensorFlow is scalar strict within Google (27jan2016).
- // 7. Remove TopK in favor of TopKV2 (5feb2016).
- // 8. Replace RandomCrop from C++ with pure Python (5feb2016).
- // 9. Deprecate batch_norm_with_global_normalization (16feb2016).
- // 10. Deprecate conv3d_backprop_{filter,input} (10jun2016).
- // 11. Deprecate {batch}_self_adjoint_eig (3aug2016).
- // 12. Graph consumers understand the node_def field of FunctionDef (22aug2016).
- // 13. Deprecate multiple batch linear algebra ops (9sep2016).
- // 14. Deprecate batch_matrix_* ops. (10sep2016).
- // 15. Deprecate batch_fft_* ops. (14sep2016).
- // 16. Deprecate tensor_array (v1) ops in favor of v2 (10nov2016).
- // 17. Deprecate inv (11nov2016).
- // 17. Expose reverse_v2 (10nov2016)
- // 18. Add VariableV2 (30nov2016)
- // 19. Deprecated ops created by models moved out of core SkipGram, NegTrain.
- // (08dec2016)
- // 20. Catch all version 1.0 changes to Python API generation. SplitV is now
- // used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is
- // now used by tf.concat. Graphs use flooring
- // division and mod semantics. TensorArrayV3. (12dec2016)
- // Also considered the version for when it is required for reduction
- // ops' indices to be scalar or vector, and not higher rank.
- // Some earlier graph def versions allowed this.
- // 21. Dropped FunctionDef.Node support, switched to node_def introduced
- // in version 12. (11jan2017)
- // 22. Placeholder now can specify and enforce scalar and partial
- // shapes, particularly when restoring a graph from GraphDef
- // produced at version 22 or later. (04/10/2016)
- // 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
- // 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
- // 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
- // 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25).
- // 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating
- // whether default-valued attrs have been stripped from the nodes in the
- // GraphDef. (7dec2017)
- // 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops
- // deprecated in favor of V2 ops. (2018/01/23)
- // 28. Deprecate MatrixExponential op in favor of Python implementation.
- // (2018/08/21).
- // (2019/02/15). Added `control_ret` field to FunctionDef proto, and
- // `control_output` field to OpDef proto.
- // 29. Deprecate StatefulStandardNormal op in favor of StatefulStandardNormalV2.
- // (2019/03/25).
- // (2019/04/17). Added `arg_attr` field to FunctionDefProto.
- // 30. (2019/05/09) First date based GraphDef version. GraphDef
- // versions advance by 1 each day after this point.
-
- #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
- #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
- #define TF_GRAPH_DEF_VERSION 2143 // Updated: 2025/2/19
-
- // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
- //
- // The checkpoint versions have the same semantics as GraphDef versions, but the
- // numbering scheme is separate. We have no plans to ever deprecate checkpoint
- // versions, but it's good to have this in place in case we ever need to.
- //
- // Version history:
- //
- // 0. Checkpoints saved before checkpoint versioning.
- // 1. First real version (10feb2015).
- #define TF_CHECKPOINT_VERSION_MIN_PRODUCER 0
- #define TF_CHECKPOINT_VERSION_MIN_CONSUMER 0
- #define TF_CHECKPOINT_VERSION 1
-
- #endif // TENSORFLOW_CORE_PUBLIC_VERSION_H_
+ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef TENSORFLOW_CORE_PUBLIC_VERSION_H_
+ #define TENSORFLOW_CORE_PUBLIC_VERSION_H_
+
+ // TensorFlow uses semantic versioning, see http://semver.org/.
+
+ #define TF_STR_HELPER(x) #x
+ #define TF_STR(x) TF_STR_HELPER(x)
+
+ // GraphDef compatibility versions (the versions field in graph.proto).
+ //
+ // Each graph has producer and min_consumer versions, and each
+ // consumer has its own version and a min_producer. In addition, graphs can
+ // mark specific consumer versions as bad (to prevent bugs from executing).
+ // A consumer will execute a graph if the consumer's version is at least the
+ // graph's min_consumer, the graph's producer version is at least the consumer's
+ // min_producer, and the consumer version isn't specifically disallowed by the
+ // graph.
+ //
+ // By default, newly created graphs have producer version TF_GRAPH_DEF_VERSION
+ // min_consumer TF_GRAPH_DEF_MIN_CONSUMER, and no other bad consumer versions.
+ //
+ // Version history:
+ //
+ // 0. Graphs created before GraphDef versioning
+ // 1. First real version (2dec2015)
+ // 2. adjust_contrast only takes float, doesn't perform clamping (11dec2015)
+ // 3. Remove TileGrad, since it was equivalent to reduce_sum (30dec2015)
+ // 4. When support for this version is removed, we can safely make AttrValue
+ // parsing more strict with respect to empty list values (see
+ // 111635679, 7jan2016).
+ // 5. Graphs are wholly-validated during Session::Create() (7jan2016).
+ // 6. TensorFlow is scalar strict within Google (27jan2016).
+ // 7. Remove TopK in favor of TopKV2 (5feb2016).
+ // 8. Replace RandomCrop from C++ with pure Python (5feb2016).
+ // 9. Deprecate batch_norm_with_global_normalization (16feb2016).
+ // 10. Deprecate conv3d_backprop_{filter,input} (10jun2016).
+ // 11. Deprecate {batch}_self_adjoint_eig (3aug2016).
+ // 12. Graph consumers understand the node_def field of FunctionDef (22aug2016).
+ // 13. Deprecate multiple batch linear algebra ops (9sep2016).
+ // 14. Deprecate batch_matrix_* ops. (10sep2016).
+ // 15. Deprecate batch_fft_* ops. (14sep2016).
+ // 16. Deprecate tensor_array (v1) ops in favor of v2 (10nov2016).
+ // 17. Deprecate inv (11nov2016).
+ // 17. Expose reverse_v2 (10nov2016)
+ // 18. Add VariableV2 (30nov2016)
+ // 19. Deprecated ops created by models moved out of core SkipGram, NegTrain.
+ // (08dec2016)
+ // 20. Catch all version 1.0 changes to Python API generation. SplitV is now
+ // used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is
+ // now used by tf.concat. Graphs use flooring
+ // division and mod semantics. TensorArrayV3. (12dec2016)
+ // Also considered the version for when it is required for reduction
+ // ops' indices to be scalar or vector, and not higher rank.
+ // Some earlier graph def versions allowed this.
+ // 21. Dropped FunctionDef.Node support, switched to node_def introduced
+ // in version 12. (11jan2017)
+ // 22. Placeholder now can specify and enforce scalar and partial
+ // shapes, particularly when restoring a graph from GraphDef
+ // produced at version 22 or later. (04/10/2016)
+ // 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
+ // 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
+ // 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
+ // 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25).
+ // 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating
+ // whether default-valued attrs have been stripped from the nodes in the
+ // GraphDef. (7dec2017)
+ // 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops
+ // deprecated in favor of V2 ops. (2018/01/23)
+ // 28. Deprecate MatrixExponential op in favor of Python implementation.
+ // (2018/08/21).
+ // (2019/02/15). Added `control_ret` field to FunctionDef proto, and
+ // `control_output` field to OpDef proto.
+ // 29. Deprecate StatefulStandardNormal op in favor of StatefulStandardNormalV2.
+ // (2019/03/25).
+ // (2019/04/17). Added `arg_attr` field to FunctionDefProto.
+ // 30. (2019/05/09) First date based GraphDef version. GraphDef
+ // versions advance by 1 each day after this point.
+
+ #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
+ #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
+ #define TF_GRAPH_DEF_VERSION 2144 // Updated: 2025/2/20
+
+ // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
+ //
+ // The checkpoint versions have the same semantics as GraphDef versions, but the
+ // numbering scheme is separate. We have no plans to ever deprecate checkpoint
+ // versions, but it's good to have this in place in case we ever need to.
+ //
+ // Version history:
+ //
+ // 0. Checkpoints saved before checkpoint versioning.
+ // 1. First real version (10feb2015).
+ #define TF_CHECKPOINT_VERSION_MIN_PRODUCER 0
+ #define TF_CHECKPOINT_VERSION_MIN_CONSUMER 0
+ #define TF_CHECKPOINT_VERSION 1
+
+ #endif // TENSORFLOW_CORE_PUBLIC_VERSION_H_
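
The GraphDef comment block above fully specifies the compatibility rule, so it can be illustrated with a small standalone sketch. This is not TensorFlow's actual implementation; GraphVersions and ConsumerCanExecute are illustrative names that mirror the producer, min_consumer, and bad_consumers fields of the versions message in graph.proto.

    // Sketch of the rule: a consumer may execute a graph only if
    //   consumer_version >= graph.min_consumer,
    //   graph.producer   >= consumer_min_producer, and
    //   consumer_version is not listed in graph.bad_consumers.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct GraphVersions {  // illustrative stand-in for the proto fields
      int32_t producer = 0;
      int32_t min_consumer = 0;
      std::vector<int32_t> bad_consumers;
    };

    bool ConsumerCanExecute(const GraphVersions& graph, int32_t consumer_version,
                            int32_t consumer_min_producer) {
      if (consumer_version < graph.min_consumer) return false;
      if (graph.producer < consumer_min_producer) return false;
      return std::find(graph.bad_consumers.begin(), graph.bad_consumers.end(),
                       consumer_version) == graph.bad_consumers.end();
    }

With the values in this release, a freshly created graph carries producer 2144 and min_consumer 0, so it is accepted by any consumer whose min_producer does not exceed 2144.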

tensorflow/include/tensorflow/python/eager/pywrap_tfe.h

@@ -443,7 +443,7 @@ EagerContextThreadLocalData* GetEagerContextThreadLocalData(
  // wish to destroy thread-local state associated with a single py_eager_context
  // for multiple threads, then you must call this method from each thread.
  //
- // Thread-local state assocaited with eager contexts is also automatically
+ // Thread-local state associated with eager contexts is also automatically
  // cleaned up when the thread is destroyed.
  //
  // This function assumes that the Python GIL is held (and does not perform its

xla/backends/cpu/codegen/kernel_api_ir_builder.h

@@ -89,9 +89,10 @@ class KernelApiIrBuilder {
  // read-only if it is not aliased with any result.
  absl::flat_hash_set<int64_t> invariant_arguments;

- // the set of buffer uses for this kernel, can be empty if buffer
+ // The set of buffers used by this kernel, can be empty if buffer assignment
  // was not provided.
- absl::InlinedVector<BufferUse, 8> buffer_uses;
+ absl::InlinedVector<BufferAllocation::Slice, 8> argument_buffers;
+ absl::InlinedVector<BufferAllocation::Slice, 8> result_buffers;
  };

  KernelApiIrBuilder(llvm::LLVMContext& context, Options options);

xla/backends/cpu/runtime/kernel_thunk.h

@@ -63,6 +63,8 @@ class KernelThunkBase : public Thunk {
  const = 0;

  virtual absl::Span<const BufferAllocation::Slice> results_buffers() const = 0;
+
+ virtual const absl::flat_hash_set<int64_t>& invariant_arguments() const = 0;
  };

  namespace internal {
@@ -95,6 +97,10 @@ class KernelThunk : public KernelThunkBase {
  return absl::MakeSpan(results_buffers_);
  }

+ const absl::flat_hash_set<int64_t>& invariant_arguments() const final {
+ return invariant_arguments_;
+ }
+
  protected:
  tsl::AsyncValueRef<ExecuteEvent> ExecuteInternal(const ExecuteParams& params);

@@ -129,7 +135,7 @@ class KernelThunk : public KernelThunkBase {
  KernelThunk(Info info,
  absl::Span<const BufferAllocation::Slice> arguments_buffers,
  absl::Span<const BufferAllocation::Slice> results_buffers,
- std::optional<absl::flat_hash_set<int64_t>> invariant_arguments,
+ absl::flat_hash_set<int64_t> invariant_arguments,
  std::string kernel_name, se::ThreadDim thread_dim,
  std::optional<uint64_t> min_alignment);

@@ -139,7 +145,7 @@ class KernelThunk : public KernelThunkBase {
  ResultsBuffers results_buffers_;

  // A set of invariant arguments (their indices).
- std::optional<absl::flat_hash_set<int64_t>> invariant_arguments_;
+ absl::flat_hash_set<int64_t> invariant_arguments_;

  size_t num_kernel_args_;

@@ -189,7 +195,7 @@ class KernelThunk final : public internal::KernelThunk<> {
  absl::Span<const BufferAllocation::Slice> arguments_buffers,
  absl::Span<const BufferAllocation::Slice> results_buffers,
  std::string kernel_name, se::ThreadDim thread_dim,
- std::optional<absl::flat_hash_set<int64_t>> invariant_arguments,
+ absl::flat_hash_set<int64_t> invariant_arguments,
  std::optional<uint64_t> min_alignment = std::nullopt);

  static absl::StatusOr<std::unique_ptr<Thunk>> Create(

xla/backends/cpu/runtime/work_queue.h

@@ -44,15 +44,6 @@ namespace xla::cpu {
  // A work queue that partitions `num_tasks` tasks into `num_partitions`
  // partitions processed by parallel workers.
  class WorkQueue {
- // Align all atomic counters to a cache line boundary to avoid false
- // sharing between multiple worker threads.
- static constexpr size_t kAtomicAlignment =
- #if defined(__cpp_lib_hardware_interference_size)
- std::hardware_destructive_interference_size;
- #else
- 64;
- #endif
-
  public:
  WorkQueue(size_t num_tasks, size_t num_partitions);

@@ -60,13 +51,23 @@ class WorkQueue {
  // if the partition is complete.
  std::optional<size_t> Pop(size_t partition_index);

- size_t num_partitions() const { return partitions_.size(); }
+ // Return the partition [begin, end) task range.
+ std::pair<size_t, size_t> partition_range(size_t partition_index) const;

- bool empty() const { return empty_.load(std::memory_order_relaxed); }
+ size_t num_partitions() const { return partitions_.size(); }

  private:
  friend class Worker;

+ // Align all atomic counters to a cache line boundary to avoid false
+ // sharing between multiple worker threads.
+ static constexpr size_t kAtomicAlignment =
+ #if defined(__cpp_lib_hardware_interference_size)
+ std::hardware_destructive_interference_size;
+ #else
+ 64;
+ #endif
+
  struct Partition {
  void Initialize(size_t begin, size_t end);

@@ -76,8 +77,21 @@ class WorkQueue {
  size_t end;
  };

+ // An empty work queue flag to stop worker threads from looping through all
+ // partitions looking for work.
+ bool IsEmpty() const { return empty_.load(std::memory_order_relaxed); }
+ void SetEmpty() { empty_.store(true, std::memory_order_relaxed); }
+
+ // Notify that one of the workers switched to the work stealing mode.
+ void NotifyWorkStealingWorker();
+
+ // Decrements the number of work stealing workers by at most `max_workers` and
+ // returns the number of decremented work stealing workers.
+ size_t DecrementWorkStealingWorkers(size_t max_workers);
+
  absl::FixedArray<Partition, 32> partitions_;
  alignas(kAtomicAlignment) std::atomic<bool> empty_;
+ alignas(kAtomicAlignment) std::atomic<size_t> num_work_stealing_workers_;
  };

  // Worker processes tasks from the work queue starting from the assigned
@@ -130,10 +144,14 @@ inline void WorkQueue::Partition::Initialize(size_t begin, size_t end) {
  }

  inline WorkQueue::WorkQueue(size_t num_tasks, size_t num_partitions)
- : partitions_(num_partitions), empty_(num_tasks == 0) {
- size_t partition_size = tsl::MathUtil::CeilOfRatio(num_tasks, num_partitions);
- for (size_t i = 0, begin = 0, end = partition_size; i < num_partitions;
- ++i, begin = end, end = std::min(num_tasks, end + partition_size)) {
+ : partitions_(num_partitions),
+ empty_(num_tasks == 0),
+ num_work_stealing_workers_(0) {
+ size_t partition_size =
+ tsl::MathUtil::FloorOfRatio(num_tasks, num_partitions);
+ size_t rem_tasks = num_tasks % num_partitions;
+ for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
+ end = begin + partition_size + ((i < rem_tasks) ? 1 : 0);
  partitions_[i].Initialize(begin, end);
  }
  }
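
The new constructor above switches from a CeilOfRatio split to a floor-plus-remainder split. A minimal standalone sketch of that partitioning (SplitTasks is an illustrative name, not part of XLA):

    #include <cstddef>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Splits num_tasks into num_partitions contiguous [begin, end) ranges whose
    // sizes differ by at most one, mirroring the WorkQueue constructor above.
    std::vector<std::pair<size_t, size_t>> SplitTasks(size_t num_tasks,
                                                      size_t num_partitions) {
      std::vector<std::pair<size_t, size_t>> ranges(num_partitions);
      size_t partition_size = num_tasks / num_partitions;  // FloorOfRatio
      size_t rem_tasks = num_tasks % num_partitions;
      for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
        end = begin + partition_size + ((i < rem_tasks) ? 1 : 0);
        ranges[i] = {begin, end};
      }
      return ranges;
    }

    int main() {
      // 10 tasks over 4 partitions -> [0, 3) [3, 6) [6, 8) [8, 10).
      for (auto [b, e] : SplitTasks(10, 4)) std::printf("[%zu, %zu)\n", b, e);
    }

The previous CeilOfRatio scheme could leave trailing partitions empty (9 tasks over 4 partitions gave sizes 3, 3, 3, 0); the floor-plus-remainder split yields 3, 2, 2, 2 instead.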

@@ -154,6 +172,29 @@ inline std::optional<size_t> WorkQueue::Pop(size_t partition_index) {
  : std::make_optional(index);
  }

+ inline std::pair<size_t, size_t> WorkQueue::partition_range(
+ size_t partition_index) const {
+ DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
+ return {partitions_[partition_index].begin, partitions_[partition_index].end};
+ }
+
+ inline void WorkQueue::NotifyWorkStealingWorker() {
+ num_work_stealing_workers_.fetch_add(1, std::memory_order_relaxed);
+ }
+
+ inline size_t WorkQueue::DecrementWorkStealingWorkers(size_t max_workers) {
+ size_t n = num_work_stealing_workers_.load(std::memory_order_relaxed);
+
+ size_t decrement = std::min(n, max_workers);
+ while (decrement && !num_work_stealing_workers_.compare_exchange_weak(
+ n, n - decrement, std::memory_order_relaxed,
+ std::memory_order_relaxed)) {
+ decrement = std::min(n, max_workers);
+ }
+
+ return decrement;
+ }
+
  inline Worker::Worker(size_t worker_index, WorkQueue* queue)
  : worker_index_(worker_index),
  partition_index_(worker_index),
@@ -163,7 +204,13 @@ inline std::optional<size_t> Worker::Pop() {
  std::optional<size_t> task = queue_->Pop(partition_index_);
  if (ABSL_PREDICT_TRUE(task)) return task;

- while (!task.has_value() && !queue_->empty()) {
+ // If we didn't find a task in the initially assigned partition, notify the
+ // work queue that we are switching to work stealing mode.
+ if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
+ queue_->NotifyWorkStealingWorker();
+ }
+
+ while (!task.has_value() && !queue_->IsEmpty()) {
  // Wrap around to the first partition.
  if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
  partition_index_ = 0;
@@ -171,7 +218,7 @@ inline std::optional<size_t> Worker::Pop() {

  // We checked all partitions and got back to the partition we started from.
  if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
- queue_->empty_.store(true, std::memory_order_relaxed);
+ queue_->SetEmpty();
  break;
  }

@@ -205,6 +252,7 @@ Worker::ParallelizeContext<ParallelTask>::ParallelizeContext(
  parallel_task(std::forward<ParallelTask>(parallel_task)) {}

  template <typename ParallelTask>
+ // NOLINTNEXTLINE(readability-function-cognitive-complexity)
  void Worker::ParallelizeWithContext(ParallelizeContext<ParallelTask>* ctx,
  uint16_t start_index, uint16_t end_index) {
  DCHECK_LT(start_index, end_index) << "Invalid worker index range";
@@ -223,11 +271,26 @@ void Worker::ParallelizeWithContext(ParallelizeContext<ParallelTask>* ctx,
  while (end_index - start_index > 1) {
  // If work queue is empty, we don't need to keep enqueuing more workers and
  // can simply count down for the remaining workers.
- if (ABSL_PREDICT_FALSE(ctx->work_queue.empty())) {
+ if (ABSL_PREDICT_FALSE(ctx->work_queue.IsEmpty())) {
  count_down(end_index - start_index, absl::OkStatus());
  return;
  }

+ // If we have workers in the work stealing mode, we can skip enqueuing
+ // more tasks as existing workers will process remaining partitions. By
+ // doing this optimization we avoid unnecessary thread pool overheads.
+ size_t skip_workers =
+ ctx->work_queue.DecrementWorkStealingWorkers(end_index - start_index);
+ if (ABSL_PREDICT_FALSE(skip_workers > 0)) {
+ DCHECK_LE(skip_workers, end_index - start_index);
+ count_down(skip_workers, absl::OkStatus());
+
+ end_index -= skip_workers;
+ if (start_index == end_index) return;
+ if (end_index - start_index == 1) break;
+ }
+
+ DCHECK_GE(end_index - start_index, 1);
  uint16_t mid_index = (start_index + end_index) / 2;
  ctx->device->enqueueNoNotification([ctx, mid_index, end_index] {
  ParallelizeWithContext(ctx, mid_index, end_index);

xla/codegen/kernel_spec.h

@@ -17,12 +17,14 @@ limitations under the License.
  #define XLA_CODEGEN_KERNEL_SPEC_H_

  #include <cstddef>
+ #include <cstdint>
  #include <optional>
  #include <string>

+ #include "absl/container/flat_hash_set.h"
  #include "absl/container/inlined_vector.h"
  #include "absl/strings/string_view.h"
- #include "xla/runtime/buffer_use.h"
+ #include "xla/service/buffer_assignment.h"
  #include "xla/stream_executor/launch_dim.h"

  namespace xla {
@@ -33,15 +35,17 @@ namespace xla {
  // will load kernel PTX on device and instantiate a KernelThunk.
  class KernelSpec {
  public:
- using BufferUses = absl::InlinedVector<BufferUse, 8>;
+ using Buffers = absl::InlinedVector<BufferAllocation::Slice, 8>;

  KernelSpec(absl::string_view name, se::ThreadDim thread_dim,
- BufferUses buffer_uses,
+ Buffers argument_buffers, Buffers result_buffers,
+ absl::flat_hash_set<int64_t> invariant_arguments,
  std::optional<size_t> scratch_bytes = std::nullopt);

  KernelSpec(absl::string_view name, se::ClusterDim cluster_dim,
  se::BlockDim block_dim, se::ThreadDim thread_dim,
- BufferUses buffer_uses,
+ Buffers argument_buffers, Buffers result_buffers,
+ absl::flat_hash_set<int64_t> invariant_arguments,
  std::optional<size_t> scratch_bytes = std::nullopt);

  // Get the backend specific name of the kernel.
@@ -67,15 +71,28 @@ class KernelSpec {
  // managed buffer that is likely to be in L1/L2 cache).
  std::optional<size_t> scratch_bytes() const { return scratch_bytes_; }

- // Buffers (buffer allocation slices) used by the kernel.
- const BufferUses& buffer_uses() const { return buffer_uses_; }
+ // Argument buffers read by the kernel.
+ const Buffers& argument_buffers() const { return argument_buffers_; }
+ // Result buffers written to by the kernel.
+ const Buffers& result_buffers() const { return result_buffers_; }
+
+ // Returns a set of invariant arguments (corresponding to the indices in the
+ // argument buffers list).
+ const absl::flat_hash_set<int64_t>& invariant_arguments() const {
+ return invariant_arguments_;
+ }

  private:
  std::string name_;
  se::ClusterDim cluster_dim_;
  se::BlockDim block_dim_;
  se::ThreadDim thread_dim_;
- BufferUses buffer_uses_;
+
+ Buffers argument_buffers_;
+ Buffers result_buffers_;
+
+ absl::flat_hash_set<int64_t> invariant_arguments_;
+
  std::optional<size_t> scratch_bytes_;
  };

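A hedged caller-side sketch of how the reworked KernelSpec accessors above might be consumed, assuming an XLA build; DescribeKernelBuffers is an illustrative name and not part of the library:

    #include <cstdint>

    #include "xla/codegen/kernel_spec.h"

    // Arguments whose index appears in invariant_arguments() can be bound as
    // read-only; all result buffers are written by the kernel.
    void DescribeKernelBuffers(const xla::KernelSpec& spec) {
      for (int64_t i = 0;
           i < static_cast<int64_t>(spec.argument_buffers().size()); ++i) {
        const bool read_only = spec.invariant_arguments().contains(i);
        (void)read_only;  // e.g. choose a read-only vs. read-write binding here
      }
      for (const auto& result : spec.result_buffers()) {
        (void)result;  // always treated as writable
      }
    }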

xla/hlo/ir/hlo_casting_utils.h

@@ -44,28 +44,6 @@ T* Cast(HloInstruction* instr) {
  return tsl::down_cast<T*>(instr);
  }

- // Downcasts a const HloInstruction pointer or returns nullptr if argument is
- // nullptr. Dies if TargetClass::ClassOf() does not match.
- template <typename T>
- const T* CastOrNull(const HloInstruction* i) {
- if (i == nullptr) {
- return nullptr;
- }
- CHECK(T::ClassOf(i));
- return tsl::down_cast<const T*>(i);
- }
-
- // Downcasts a const HloInstruction pointer or returns nullptr if argument is
- // nullptr. Dies if TargetClass::ClassOf() does not match.
- template <typename T>
- T* CastOrNull(HloInstruction* i) {
- if (i == nullptr) {
- return nullptr;
- }
- CHECK(T::ClassOf(i));
- return tsl::down_cast<T*>(i);
- }
-
  // Downcasts a const HloInstruction pointer or returns nullptr if
  // TargetClass::ClassOf() does not match. Dies if argument is nullptr. Similar
  // to LLVM's dyn_cast.
@@ -84,28 +62,6 @@ T* DynCast(HloInstruction* i) {
  return !T::ClassOf(i) ? nullptr : tsl::down_cast<T*>(i);
  }

- // Downcasts a const HloInstruction pointer. Return nullptr if argument is
- // nullptr orTargetClass::ClassOf() does not match. Similar to LLVM's
- // dyn_cast_or_null.
- template <typename T>
- const T* DynCastOrNull(const HloInstruction* instruction) {
- if (instruction == nullptr || !T::ClassOf(instruction)) {
- return nullptr;
- }
- return tsl::down_cast<const T*>(instruction);
- }
-
- // Downcasts a non-const HloInstruction pointer. Return nullptr if argument is
- // nullptr orTargetClass::ClassOf() does not match. Similar to LLVM's
- // dyn_cast_or_null.
- template <typename T>
- T* DynCastOrNull(HloInstruction* instruction) {
- if (instruction == nullptr || !T::ClassOf(instruction)) {
- return nullptr;
- }
- return tsl::down_cast<T*>(instruction);
- }
-
  } // namespace xla

  #endif // XLA_HLO_IR_HLO_CASTING_UTILS_H_
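
The nullptr-tolerant CastOrNull/DynCastOrNull helpers are gone from this header. Callers that still need that behavior can, as a hypothetical sketch built on the surviving DynCast<T>, guard for null themselves (DynCastIfNonNull is an illustrative name, not an XLA API):

    #include "xla/hlo/ir/hlo_casting_utils.h"
    #include "xla/hlo/ir/hlo_instruction.h"

    // Returns nullptr when the instruction is null or ClassOf() does not match,
    // deferring to DynCast<T> for the non-null case.
    template <typename T>
    const T* DynCastIfNonNull(const xla::HloInstruction* instr) {
      return instr == nullptr ? nullptr : xla::DynCast<T>(instr);
    }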