tf-nightly-cpu 2.20.0.dev20250220__cp39-cp39-win_amd64.whl → 2.20.0.dev20250222__cp39-cp39-win_amd64.whl

Files changed (130)
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
  14. tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
  15. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  17. tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
  18. tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
  19. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  20. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  21. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  22. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  23. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -19
  24. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  25. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  26. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
  27. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  28. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  30. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  31. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  32. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  33. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  34. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  35. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  36. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  37. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  38. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  39. tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
  40. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  41. tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
  42. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  43. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  44. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  45. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  46. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  47. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  48. tensorflow/include/tensorflow/core/public/version.h +112 -127
  49. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  50. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  51. tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  52. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  53. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -19
  54. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  55. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  56. tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
  57. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  58. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  59. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  60. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  61. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  62. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  63. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  64. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  65. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  66. tensorflow/include/xla/service/constant_value.h +1 -0
  67. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  68. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  69. tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
  70. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  71. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  72. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  73. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  74. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  75. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  76. tensorflow/python/_pywrap_mlir.pyd +0 -0
  77. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  78. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  79. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  80. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  81. tensorflow/python/_pywrap_tfe.pyd +0 -0
  82. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  83. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  84. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  85. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  86. tensorflow/python/compat/compat.py +1 -1
  87. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  88. tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyd +0 -0
  89. tensorflow/python/eager/imperative_grad.py +5 -5
  90. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  91. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  92. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  93. tensorflow/python/eager/tape.py +2 -2
  94. tensorflow/python/framework/_dtypes.pyd +0 -0
  95. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  96. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  97. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  98. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  99. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  100. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  101. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  102. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  103. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  104. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  105. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  106. tensorflow/python/ops/summary_ops_v2.py +5 -1
  107. tensorflow/python/platform/_pywrap_tf2.pyd +0 -0
  108. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  109. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  110. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  111. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  112. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  113. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  114. tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
  115. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  116. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  117. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  118. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  119. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  120. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  121. tensorflow/python/util/_tf_stack.pyd +0 -0
  122. tensorflow/tools/pip_package/setup.py +2 -2
  123. tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
  124. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
  125. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +128 -123
  126. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
  127. tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
  128. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
  129. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
  130. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
+ /* Copyright 2025 The OpenXLA Authors.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+ #ifndef XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
+ #define XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
+
+ #include <cstdint>
+ #include <memory>
+ #include <optional>
+ #include <string>
+ #include <utility>
+ #include <vector>
+
+ #include "absl/container/flat_hash_map.h"
+ #include "absl/status/statusor.h"
+ #include "absl/strings/string_view.h"
+ #include "absl/types/span.h"
+ #include "xla/pjrt/gpu/gpu_topology.h"
+ #include "xla/pjrt/pjrt_compiler.h"
+ #include "xla/pjrt/pjrt_device_description.h"
+ #include "xla/pjrt/pjrt_stream_executor_device_description.h"
+
+ namespace xla {
+
+ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
+ public:
+ StreamExecutorGpuTopologyDescription(
+ const PjRtPlatformId platform_id, const absl::string_view platform_name,
+ std::shared_ptr<const GpuTopology> gpu_topology,
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
+ {},
+ std::optional<stream_executor::GpuTargetConfigProto> target_config =
+ std::nullopt)
+ : platform_id_(platform_id),
+ platform_name_(platform_name),
+ gpu_topology_(std::move(gpu_topology)),
+ attributes_(attributes),
+ target_config_(std::move(target_config)) {}
+
+ bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
+ return this->platform_id() == other.platform_id() &&
+ this->platform_name() == other.platform_name() &&
+ this->platform_version() == other.platform_version() &&
+ this->gpu_topology() == other.gpu_topology();
+ }
+
+ PjRtPlatformId platform_id() const override { return platform_id_; }
+
+ absl::string_view platform_name() const override { return platform_name_; }
+
+ absl::string_view platform_version() const override {
+ return gpu_topology_->platform_version();
+ }
+
+ std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
+ const override {
+ std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
+ devices.reserve(gpu_topology_->number_of_devices());
+ for (const int device_id : gpu_topology_->device_ids()) {
+ devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
+ device_id, std::string(platform_version())));
+ }
+ return devices;
+ }
+
+ const GpuTopology& gpu_topology() const { return *gpu_topology_; }
+ const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
+
+ // No subslice is supported.
+ bool is_subslice_topology() const override { return false; }
+
+ absl::StatusOr<int> ProcessCount() const override {
+ return gpu_topology_->number_of_hosts();
+ }
+
+ absl::StatusOr<int> CoreCountOfDefaultType() const override {
+ return gpu_topology_->number_of_devices();
+ }
+
+ absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
+ return gpu_topology_->number_of_devices();
+ }
+
+ absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
+ return gpu_topology_->number_of_devices();
+ }
+
+ absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
+ return 1;
+ }
+
+ absl::StatusOr<std::string> Serialize() const override;
+
+ const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
+ const {
+ return target_config_;
+ }
+
+ // Returns vendor specific attributes about the topology.
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
+ const override {
+ return attributes_;
+ }
+
+ absl::StatusOr<Layout> GetDefaultLayout(
+ PrimitiveType element_type,
+ absl::Span<const int64_t> dims) const override;
+
+ private:
+ const PjRtPlatformId platform_id_;
+ const std::string platform_name_;
+ std::shared_ptr<const GpuTopology> gpu_topology_;
+ absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
+ std::optional<stream_executor::GpuTargetConfigProto> target_config_;
+ };
+ } // namespace xla
+
+ #endif // XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
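
The new se_gpu_topology_description.h header above is self-contained, so a caller only needs the accessors it declares. A minimal sketch, assuming a StreamExecutorGpuTopologyDescription obtained from the GPU PjRt client setup (DescribeTopology is a hypothetical helper, not part of the header):

    // Sketch only: exercises the accessors declared in the header above.
    #include <iostream>
    #include <memory>

    #include "xla/pjrt/gpu/se_gpu_topology_description.h"

    void DescribeTopology(const xla::StreamExecutorGpuTopologyDescription& topology) {
      std::cout << "platform: " << topology.platform_name() << " ("
                << topology.platform_version() << ")\n";
      // DeviceDescriptions() materializes one PjRtStreamExecutorDeviceDescription
      // per device id in the wrapped GpuTopology.
      for (const auto& device : topology.DeviceDescriptions()) {
        std::cout << "  device " << device->id() << ": " << device->device_kind()
                  << "\n";
      }
    }
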
@@ -52,9 +52,9 @@ limitations under the License.
  #include "xla/pjrt/pjrt_client.h"
  #include "xla/pjrt/pjrt_common.h"
  #include "xla/pjrt/pjrt_compiler.h"
- #include "xla/pjrt/pjrt_device_description.h"
  #include "xla/pjrt/pjrt_executable.h"
  #include "xla/pjrt/pjrt_future.h"
+ #include "xla/pjrt/pjrt_stream_executor_device_description.h"
  #include "xla/pjrt/tracked_device_buffer.h"
  #include "xla/pjrt/transpose.h"
  #include "xla/pjrt/utils.h"
@@ -77,54 +77,6 @@ limitations under the License.

  namespace xla {

- class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
- public:
- explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
- int process_index = 0)
- : id_(id),
- process_index_(process_index),
- device_kind_(std::move(device_kind)) {}
-
- int id() const override { return id_; }
-
- int process_index() const override { return process_index_; }
-
- absl::string_view device_kind() const override { return device_kind_; }
-
- absl::string_view ToString() const override { return to_string_; }
-
- absl::string_view DebugString() const override { return debug_string_; }
-
- absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
-
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
- const override {
- return attributes_;
- }
-
- void SetAttributes(
- absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
- attributes_ = std::move(attributes);
- }
-
- void SetDebugString(std::string debug_string) {
- debug_string_ = std::move(debug_string);
- }
-
- void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
-
- void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
-
- private:
- const int id_;
- const int process_index_;
- const std::string device_kind_;
- std::string debug_string_ = "<unknown SE device>";
- std::string to_string_ = "<unknown SE device>";
- absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
- std::array<int, 1> coords_;
- };
-
  class PjRtStreamExecutorDevice : public PjRtDevice {
  public:
  explicit PjRtStreamExecutorDevice(
@@ -0,0 +1,75 @@
+ /* Copyright 2025 The OpenXLA Authors.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+ #ifndef XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+ #define XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+
+ #include <array>
+ #include <string>
+ #include <utility>
+
+ #include "absl/container/flat_hash_map.h"
+ #include "absl/strings/string_view.h"
+ #include "absl/types/span.h"
+ #include "xla/pjrt/pjrt_device_description.h"
+
+ namespace xla {
+
+ class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
+ public:
+ explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
+ int process_index = 0)
+ : id_(id),
+ process_index_(process_index),
+ device_kind_(std::move(device_kind)) {}
+
+ int id() const override { return id_; }
+
+ int process_index() const override { return process_index_; }
+
+ absl::string_view device_kind() const override { return device_kind_; }
+
+ absl::string_view ToString() const override { return to_string_; }
+
+ absl::string_view DebugString() const override { return debug_string_; }
+
+ absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
+
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
+ const override {
+ return attributes_;
+ }
+
+ void SetAttributes(
+ absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
+ attributes_ = std::move(attributes);
+ }
+
+ void SetDebugString(std::string debug_string) {
+ debug_string_ = std::move(debug_string);
+ }
+
+ void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
+
+ void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
+
+ private:
+ const int id_;
+ const int process_index_;
+ const std::string device_kind_;
+ std::string debug_string_ = "<unknown SE device>";
+ std::string to_string_ = "<unknown SE device>";
+ absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
+ std::array<int, 1> coords_;
+ };
+ } // namespace xla
+
+ #endif // XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
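
Since PjRtStreamExecutorDeviceDescription now lives in its own header, it can be built without pulling in the full stream-executor client. A minimal sketch; the "gpu" device kind, the coordinate value, and MakeDescription itself are arbitrary examples rather than anything prescribed by the header:

    // Sketch only: constructs a description and fills in the optional fields.
    #include <string>

    #include "xla/pjrt/pjrt_stream_executor_device_description.h"

    xla::PjRtStreamExecutorDeviceDescription MakeDescription(int id) {
      xla::PjRtStreamExecutorDeviceDescription desc(id, /*device_kind=*/"gpu",
                                                    /*process_index=*/0);
      desc.SetToString("StreamExecutorDevice(id=" + std::to_string(id) + ")");
      desc.SetDebugString("SE device #" + std::to_string(id));
      desc.SetCoords({id});  // std::array<int, 1>
      return desc;
    }
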
@@ -0,0 +1,57 @@
+ /* Copyright 2025 The OpenXLA Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
+ #define XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
+
+ #include <optional>
+
+ #include "xla/backends/cpu/collectives/cpu_collectives.h"
+ #include "xla/pjrt/pjrt_executable.h"
+
+ namespace xla {
+
+ // ExecuteContext for XLA:CPU PjRtLoadedExecutable::Execute calls.
+ class CpuExecuteContext : public ExecuteContext {
+ public:
+ ~CpuExecuteContext() override = default;
+
+ // If specified, override the process ID specified in
+ // `CpuClientOptions::process_id` for a particular call of
+ // PjRtLoadedExecutable::Execute.
+ //
+ // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
+ // option with `ExecuteOptions::multi_slice_config`.
+ std::optional<int>& process_index() { return process_index_; }
+ std::optional<int> process_index() const { return process_index_; }
+
+ // If specified, override CPU collectives specified in
+ // `CpuClientOptions::collectives` for a particular call of
+ // PjRtLoadedExecutable::Execute. Must remain valid until the execution
+ // finishes.
+ //
+ // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
+ // option with `ExecuteOptions::multi_slice_config`.
+ cpu::CpuCollectives*& collectives() { return collectives_; }
+ cpu::CpuCollectives* collectives() const { return collectives_; }
+
+ private:
+ std::optional<int> process_index_;
+ cpu::CpuCollectives* collectives_ = nullptr;
+ };
+
+ } // namespace xla
+
+ #endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
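
A minimal sketch of how the new CpuExecuteContext might be populated; how the context reaches the executable (typically via the ExecuteContext pointer on ExecuteOptions) is assumed from the PjRt API rather than shown by this header, and ConfigureCpuExecution is a hypothetical helper:

    // Sketch only: fills in the per-call overrides declared above.
    #include "xla/pjrt/plugin/xla_cpu/cpu_execute_options.h"

    void ConfigureCpuExecution(xla::CpuExecuteContext& context,
                               xla::cpu::CpuCollectives* collectives) {
      context.process_index() = 0;          // overrides CpuClientOptions::process_id
      context.collectives() = collectives;  // must stay alive until execution ends
    }
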
@@ -69,6 +69,10 @@ inline int UnpackCpuProcessIndex(PjRtGlobalDeviceId global_device_id) {
  return global_device_id.value() / kMaxCpuDevicesPerProcess;
  }

+ inline int UnpackCpuLocalDeviceId(PjRtGlobalDeviceId global_device_id) {
+ return global_device_id.value() % kMaxCpuDevicesPerProcess;
+ }
+
  } // namespace xla

  #endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_
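
The new accessor is the counterpart of UnpackCpuProcessIndex above: the division and modulo imply that a global CPU device id encodes (process_index, local_device_id) as process_index * kMaxCpuDevicesPerProcess + local_device_id. A small sketch of that round trip (RoundTrips is a hypothetical helper):

    // Sketch only: checks the division/modulo relationship implied above.
    #include "xla/pjrt/plugin/xla_cpu/cpu_topology.h"

    inline bool RoundTrips(xla::PjRtGlobalDeviceId global_device_id) {
      const int process = xla::UnpackCpuProcessIndex(global_device_id);
      const int local = xla::UnpackCpuLocalDeviceId(global_device_id);
      return process * xla::kMaxCpuDevicesPerProcess + local ==
             global_device_id.value();
    }
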
@@ -18,6 +18,7 @@ limitations under the License.

  #include <string>

+ #include "absl/base/casts.h"
  #include "absl/status/statusor.h"
  #include "xla/literal.h"
  #include "xla/util.h"
@@ -19,16 +19,67 @@ limitations under the License.
  #include <functional>
  #include <memory>
  #include <optional>
+ #include <string>

- #include "absl/status/status.h"
+ #include "absl/log/check.h"
+ #include "absl/log/log.h"
  #include "absl/status/statusor.h"
+ #include "absl/strings/string_view.h"
  #include "absl/types/span.h"
+ #include "xla/hlo/ir/hlo_module.h"
+ #include "xla/hlo/parser/hlo_parser.h"
  #include "xla/service/compiler.h"
  #include "xla/service/hlo_module_config.h"
  #include "xla/shape.h"
+ #include "xla/util.h"

  namespace xla {

+ // Converts an HloModule from the given hlo textual IR string (in
+ // HloModule::ToString format).
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
+ absl::string_view hlo_string,
+ const DebugOptions& debug_options = DebugOptions::default_instance());
+
+ // Creates an HloModule from the given proto.
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+ const HloModuleProto& proto,
+ const DebugOptions& debug_options = DebugOptions::default_instance());
+
+ // Create an HLO state from serialized representation. In addition to
+ // creating the proto with HloModule::CreateFromProto(...) it also
+ // uses HloVerifier to ensure basic invariants are held.
+ // The HLO module could be a pre-optimizations (default) or post-optimizations
+ // module, which affects how the HLO module is verified, e.g., mixed-precision
+ // is allowed in post-optimizations HLOs.
+ absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+ const HloModuleProto& proto, const HloModuleConfig& module_config,
+ bool is_module_post_optimizations = false);
+
+ // Reads the proto file in xla.HloProto format, creates and returns the
+ // HloModule.
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
+ absl::string_view filename,
+ const DebugOptions& debug_options = DebugOptions::default_instance());
+
+ // Reads the proto file in xla.HloModule format, creates and returns the
+ // HloModule.
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
+ absl::string_view filename, const DebugOptions& debug_options);
+
+ // Reads the HLO text dump file in HloModule::ToString format, creates and
+ // returns the HloModule.
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
+ absl::string_view filename,
+ const DebugOptions& debug_options = DebugOptions::default_instance(),
+ const HloParserOptions& options = HloParserOptions());
+
+ // Reads the proto file in xla.HloProto format, creates and returns the
+ // HloModule.
+ absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
+ absl::string_view hlo_file,
+ const DebugOptions& debug_options = DebugOptions::default_instance());
+
  // Creates an HloModuleConfig for a given program shape and arguments.
  // If execution_options does not set num_replicas, default_num_replicas is used.
  // num_threads is optional; if not given, intra_op_parallelism_threads not set.
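
These additions consolidate the module-construction helpers in hlo_module_util.h; the CreateModuleFromProto overload that takes an HloModuleConfig is correspondingly removed from hlo_proto_util.h in a later hunk below. A minimal sketch of the string overload, assuming the include path xla/service/hlo_module_util.h from the file list; the HLO program and ParseAddModule are arbitrary examples:

    // Sketch only: parses a tiny HLO module with the default DebugOptions.
    #include <memory>

    #include "absl/status/statusor.h"
    #include "absl/strings/string_view.h"
    #include "xla/hlo/ir/hlo_module.h"
    #include "xla/service/hlo_module_util.h"

    absl::StatusOr<std::unique_ptr<xla::HloModule>> ParseAddModule() {
      constexpr absl::string_view kHloText = R"(
        HloModule add

        ENTRY main {
          x = f32[4] parameter(0)
          y = f32[4] parameter(1)
          ROOT sum = f32[4] add(x, y)
        }
      )";
      return xla::CreateModuleFromString(kHloText);
    }
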
@@ -18,8 +18,6 @@ limitations under the License.
  #ifndef XLA_SERVICE_HLO_PROTO_UTIL_H_
  #define XLA_SERVICE_HLO_PROTO_UTIL_H_

- #include <string>
-
  #include "absl/status/status.h"
  #include "xla/hlo/ir/hlo_module.h"
  #include "xla/service/buffer_assignment.h"
@@ -35,16 +33,6 @@ HloProto MakeHloProto(const HloModule& module,
  // will not be included in the output.
  HloProto MakeHloProto(const HloModule& module);

- // Create an HLO state from serialized representation. In addition to
- // creating the proto with HloModule::CreateFromProto(...) it also
- // uses HloVerifier to ensure basic invariants are held.
- // The HLO module could be a pre-optimizations (default) or post-optimizations
- // module, which affects how the HLO module is verified, e.g., mixed-precision
- // is allowed in post-optimizations HLOs.
- absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
- const HloModuleProto& proto, const HloModuleConfig& module_config,
- bool is_module_post_optimizations = false);
-
  // Returns the shapes of the parameters of the entry computation. Shape pointers
  // refer to shapes inside of the given HloProto.
  absl::StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
@@ -35,9 +35,6 @@ limitations under the License.
  #include "xla/tsl/platform/logging.h"

  namespace tsl {
-
- class NotifierListNode;
-
  namespace internal {

  template <typename T>
@@ -277,6 +274,8 @@ class AsyncValue {
  protected:
  friend class IndirectAsyncValue;

+ struct WaiterListNode;
+
  static constexpr uint16_t kUnknownTypeId = 0;

  // Utility template for tag dispatching.
@@ -311,7 +310,7 @@ class AsyncValue {

  void NotifyAvailable(State available_state);
  void Destroy();
- void RunWaiters(NotifierListNode* list);
+ void RunWaiters(WaiterListNode* list);

  // IsTypeIdCompatible returns true if the type value stored in this AsyncValue
  // instance can be safely cast to `T`. This is a conservative check. I.e.
@@ -369,6 +368,16 @@ class AsyncValue {
  // This is a 16-bit value that identifies the type.
  uint16_t type_id_ = 0;

+ // This is a singly linked list of nodes waiting for notification, hanging off
+ // of AsyncValue. When the value becomes available or if an error occurs, the
+ // callbacks are informed.
+ struct WaiterListNode {
+ virtual ~WaiterListNode() = default;
+ virtual void operator()() = 0;
+
+ WaiterListNode* next = nullptr;
+ };
+
  // The waiter list and the state are compacted into one single atomic word as
  // accesses to them are tightly related. To change the state from unavailable
  // (i.e. kUnconstructed or kConstructed) to available
@@ -379,7 +388,7 @@ class AsyncValue {
  // Invariant: If the state is not available, then the waiter list must be
  // nullptr.
  struct WaitersAndState {
- // We rely on the fact that all `NotifierListNode` values are aligned at
+ // We rely on the fact that all `WaiterListNode` values are aligned at
  // least to 4 bytes and we can encode state in the lowest 2 bits. We use
  // the conservative estimation of the minimal alignment of pointers returned
  // from memory allocation functions.
@@ -390,7 +399,7 @@ class AsyncValue {
  static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
  static constexpr uintptr_t kPointerMask = ~kStateMask;

- WaitersAndState(NotifierListNode* ptr, State state) {
+ WaitersAndState(WaiterListNode* ptr, State state) {
  value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
  (state & kStateMask);
  }
@@ -399,8 +408,8 @@ class AsyncValue {
  return State(static_cast<State::StateEnum>(value & kStateMask));
  }

- NotifierListNode* waiter() const {
- return reinterpret_cast<NotifierListNode*>(value & kPointerMask);
+ WaiterListNode* waiter() const {
+ return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
  }

  uintptr_t value;
@@ -466,8 +475,26 @@ class AsyncValue {
  return (*type_info_table)[type_id_ - 1];
  }

- void EnqueueWaiter(absl::AnyInvocable<void()> waiter,
- WaitersAndState old_value);
+ // Adds a waiter list node to the waiter linked list. If the value is
+ // available or becomes available, this calls the waiter immediately.
+ // Otherwise, we add waiter to the list where it will be called when the value
+ // becomes available.
+ void EnqueueWaiterListNode(WaiterListNode* waiter,
+ WaitersAndState waiters_and_state);
+
+ template <typename Waiter>
+ void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
+ static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
+
+ struct Node final : public WaiterListNode {
+ explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
+ void operator()() final { waiter(); }
+ Waiter waiter;
+ };
+
+ EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
+ waiters_and_state);
+ }

  // This is a global counter of the number of AsyncValue instances currently
  // live in the process. This is intended to be used for debugging only, and
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
  // Clients generally want to use AndThen without them each having to check
  // to see if the value is present. Check for them, and immediately run the
  // waiter if it is already here.
- auto old_value = waiters_and_state_.load(std::memory_order_acquire);
- if (old_value.state() == State::kConcrete ||
- old_value.state() == State::kError) {
- DCHECK_EQ(old_value.waiter(), nullptr);
+ auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+ if (waiters_and_state.state() == State::kConcrete ||
+ waiters_and_state.state() == State::kError) {
+ DCHECK_EQ(waiters_and_state.waiter(), nullptr);
  waiter();
  return;
  }
- EnqueueWaiter(std::forward<Waiter>(waiter), old_value);
+
+ EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
  }

  template <typename Waiter>
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
  // Clients generally want to use AndThen without them each having to check
  // to see if the value is present. Check for them, and immediately run the
  // waiter if it is already here.
- auto old_value = waiters_and_state_.load(std::memory_order_acquire);
- if (old_value.state() == State::kConcrete ||
- old_value.state() == State::kError) {
- DCHECK_EQ(old_value.waiter(), nullptr);
+ auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+ if (waiters_and_state.state() == State::kConcrete ||
+ waiters_and_state.state() == State::kError) {
+ DCHECK_EQ(waiters_and_state.waiter(), nullptr);
  executor.Execute(std::forward<Waiter>(waiter));
  return;
  }
+
  EnqueueWaiter(
- [&executor, waiter = std::forward<Waiter>(waiter)]() mutable {
+ [&executor, waiter = std::forward<Waiter>(waiter)] {
  executor.Execute(std::move(waiter));
  },
- old_value);
+ waiters_and_state);
  }

  inline void AsyncValue::Destroy() {
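
The switch from the typed NotifierListNode to the nested WaiterListNode does not change the AndThen contract: a waiter still runs inline if the value is already concrete or in error, and is otherwise wrapped in a heap-allocated node on the waiter list and run on the state transition. A minimal sketch, assuming the AsyncValueRef helpers from async_value_ref.h in the same concurrency library:

    // Sketch only: one deferred waiter, fired when the value becomes concrete.
    #include <iostream>

    #include "xla/tsl/concurrency/async_value_ref.h"

    void Demo() {
      tsl::AsyncValueRef<int> value = tsl::MakeUnconstructedAsyncValueRef<int>();

      // Not available yet, so this callback lands on the waiter list as a node.
      value.AndThen([value] { std::cout << "value = " << *value << "\n"; });

      // Transition to the concrete state; the queued waiter runs here.
      value.emplace(42);
    }
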
@@ -1604,12 +1604,12 @@ SpatialConvolution(const Input& input, const Kernel& kernel,
  Index padding_left = 0, Index padding_right = 0) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<Input>::Scalar InputScalar;
- TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
- internal::traits<Input>::Layout, TensorIndex> >
+ TensorRef<const Tensor<InputScalar, internal::traits<Input>::NumDimensions,
+ internal::traits<Input>::Layout, TensorIndex> >
  in(input);
- TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
- internal::traits<Kernel>::NumDimensions,
- internal::traits<Kernel>::Layout, TensorIndex> >
+ TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
+ internal::traits<Kernel>::NumDimensions,
+ internal::traits<Kernel>::Layout, TensorIndex> >
  kern(kernel);

  EIGEN_STATIC_ASSERT(
@@ -35,8 +35,6 @@ class GetElementAtIndexOp : public AsyncOpKernel {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

- ~GetElementAtIndexOp() override {}
-
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
  unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
  ctx->SetStatus(DoCompute(ctx));
@@ -68,8 +68,8 @@ struct GlimpseExtractionOp {
  template <typename Input>
  DSizes<Index, 4> dimensions(const Input& input) const {
  typedef typename internal::traits<Input>::Index IndexType;
- typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
- internal::traits<Input>::Layout, IndexType> >
+ typedef TensorRef<const Tensor<typename internal::traits<Input>::Scalar, 4,
+ internal::traits<Input>::Layout, IndexType> >
  Ref;
  Ref in(input);

@@ -86,8 +86,8 @@ struct GlimpseExtractionOp {
  EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output,
  const Device& device) const {
  typedef typename internal::traits<Input>::Index IndexType;
- typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
- internal::traits<Input>::Layout, IndexType> >
+ typedef TensorRef<const Tensor<typename internal::traits<Input>::Scalar, 4,
+ internal::traits<Input>::Layout, IndexType> >
  Ref;
  Ref in(input);
  const Index num_channels = in.dimension(0);
@@ -394,13 +394,13 @@ CuboidConvolutionBackwardKernel(
  const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
  const DenseIndex strideCols = 1) {
  typedef typename internal::traits<Input>::Index TensorIndex;
- TensorRef<Tensor<typename internal::traits<Input>::Scalar,
- internal::traits<Input>::NumDimensions,
- internal::traits<Input>::Layout, TensorIndex>>
+ TensorRef<const Tensor<typename internal::traits<Input>::Scalar,
+ internal::traits<Input>::NumDimensions,
+ internal::traits<Input>::Layout, TensorIndex>>
  in(input);
- TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
- internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex>>
+ TensorRef<const Tensor<typename internal::traits<OutputBackward>::Scalar,
+ internal::traits<OutputBackward>::NumDimensions,
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
  out(output_backward);

  EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==