tf-nightly-cpu 2.20.0.dev20250220-cp39-cp39-win_amd64.whl → 2.20.0.dev20250222-cp39-cp39-win_amd64.whl
- tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
- tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
- tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
- tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
- tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
- tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
- tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -19
- tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
- tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
- tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
- tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
- tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
- tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
- tensorflow/include/tensorflow/core/public/release_version.h +39 -0
- tensorflow/include/tensorflow/core/public/version.h +112 -127
- tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
- tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
- tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -19
- tensorflow/include/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
- tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/xla/service/constant_value.h +1 -0
- tensorflow/include/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
- tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
- tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
- tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
- tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
- tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
- tensorflow/python/_pywrap_mlir.pyd +0 -0
- tensorflow/python/_pywrap_parallel_device.pyd +0 -0
- tensorflow/python/_pywrap_quantize_training.pyd +0 -0
- tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
- tensorflow/python/_pywrap_tfcompile.pyd +0 -0
- tensorflow/python/_pywrap_tfe.pyd +0 -0
- tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
- tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
- tensorflow/python/compat/compat.py +1 -1
- tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
- tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyd +0 -0
- tensorflow/python/eager/imperative_grad.py +5 -5
- tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
- tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
- tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
- tensorflow/python/eager/tape.py +2 -2
- tensorflow/python/framework/_dtypes.pyd +0 -0
- tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
- tensorflow/python/framework/_op_def_registry.pyd +0 -0
- tensorflow/python/framework/_proto_comparators.pyd +0 -0
- tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
- tensorflow/python/framework/_test_metrics_util.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
- tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
- tensorflow/python/ops/summary_ops_v2.py +5 -1
- tensorflow/python/platform/_pywrap_tf2.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
- tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
- tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
- tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
- tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
- tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
- tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
- tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
- tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
- tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
- tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
- tensorflow/python/util/_pywrap_utils.pyd +0 -0
- tensorflow/python/util/_tf_stack.pyd +0 -0
- tensorflow/tools/pip_package/setup.py +2 -2
- tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +128 -123
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
- tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
xla/pjrt/gpu/se_gpu_topology_description.h
@@ -0,0 +1,126 @@
+/* Copyright 2025 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
+#define XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/pjrt/gpu/gpu_topology.h"
+#include "xla/pjrt/pjrt_compiler.h"
+#include "xla/pjrt/pjrt_device_description.h"
+#include "xla/pjrt/pjrt_stream_executor_device_description.h"
+
+namespace xla {
+
+class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
+ public:
+  StreamExecutorGpuTopologyDescription(
+      const PjRtPlatformId platform_id, const absl::string_view platform_name,
+      std::shared_ptr<const GpuTopology> gpu_topology,
+      const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
+          {},
+      std::optional<stream_executor::GpuTargetConfigProto> target_config =
+          std::nullopt)
+      : platform_id_(platform_id),
+        platform_name_(platform_name),
+        gpu_topology_(std::move(gpu_topology)),
+        attributes_(attributes),
+        target_config_(std::move(target_config)) {}
+
+  bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
+    return this->platform_id() == other.platform_id() &&
+           this->platform_name() == other.platform_name() &&
+           this->platform_version() == other.platform_version() &&
+           this->gpu_topology() == other.gpu_topology();
+  }
+
+  PjRtPlatformId platform_id() const override { return platform_id_; }
+
+  absl::string_view platform_name() const override { return platform_name_; }
+
+  absl::string_view platform_version() const override {
+    return gpu_topology_->platform_version();
+  }
+
+  std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
+      const override {
+    std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
+    devices.reserve(gpu_topology_->number_of_devices());
+    for (const int device_id : gpu_topology_->device_ids()) {
+      devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
+          device_id, std::string(platform_version())));
+    }
+    return devices;
+  }
+
+  const GpuTopology& gpu_topology() const { return *gpu_topology_; }
+  const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
+
+  // No subslice is supported.
+  bool is_subslice_topology() const override { return false; }
+
+  absl::StatusOr<int> ProcessCount() const override {
+    return gpu_topology_->number_of_hosts();
+  }
+
+  absl::StatusOr<int> CoreCountOfDefaultType() const override {
+    return gpu_topology_->number_of_devices();
+  }
+
+  absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
+    return gpu_topology_->number_of_devices();
+  }
+
+  absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
+    return gpu_topology_->number_of_devices();
+  }
+
+  absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
+    return 1;
+  }
+
+  absl::StatusOr<std::string> Serialize() const override;
+
+  const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
+      const {
+    return target_config_;
+  }
+
+  // Returns vendor specific attributes about the topology.
+  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
+      const override {
+    return attributes_;
+  }
+
+  absl::StatusOr<Layout> GetDefaultLayout(
+      PrimitiveType element_type,
+      absl::Span<const int64_t> dims) const override;
+
+ private:
+  const PjRtPlatformId platform_id_;
+  const std::string platform_name_;
+  std::shared_ptr<const GpuTopology> gpu_topology_;
+  absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
+  std::optional<stream_executor::GpuTargetConfigProto> target_config_;
+};
+}  // namespace xla
+
+#endif  // XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
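This header only relocates `StreamExecutorGpuTopologyDescription` out of `se_gpu_pjrt_client.h` (the `+1 -92` change in that file above); the API itself is unchanged. A minimal usage sketch, assuming a `GpuTopology` obtained from the usual PJRT GPU client setup and the `CudaId()`/`CudaName()` platform helpers from `xla/pjrt/pjrt_compiler.h`:

```cpp
#include <iostream>
#include <memory>
#include <utility>

#include "xla/pjrt/gpu/se_gpu_topology_description.h"

// Hypothetical caller: `gpu_topology` comes from whatever built the client.
void DescribeTopology(std::shared_ptr<const xla::GpuTopology> gpu_topology) {
  xla::StreamExecutorGpuTopologyDescription topology(
      xla::CudaId(), xla::CudaName(), std::move(gpu_topology));

  // One PjRtStreamExecutorDeviceDescription is minted per device id.
  for (const auto& device : topology.DeviceDescriptions()) {
    std::cout << device->id() << ": " << device->device_kind() << "\n";
  }

  // Host/device counts are all derived from the wrapped GpuTopology.
  if (auto hosts = topology.ProcessCount(); hosts.ok()) {
    std::cout << "hosts: " << *hosts << "\n";
  }
}
```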
xla/pjrt/pjrt_stream_executor_client.h
@@ -52,9 +52,9 @@ limitations under the License.
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_common.h"
 #include "xla/pjrt/pjrt_compiler.h"
-#include "xla/pjrt/pjrt_device_description.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/pjrt_future.h"
+#include "xla/pjrt/pjrt_stream_executor_device_description.h"
 #include "xla/pjrt/tracked_device_buffer.h"
 #include "xla/pjrt/transpose.h"
 #include "xla/pjrt/utils.h"
@@ -77,54 +77,6 @@ limitations under the License.
 
 namespace xla {
 
-class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
- public:
-  explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
-                                               int process_index = 0)
-      : id_(id),
-        process_index_(process_index),
-        device_kind_(std::move(device_kind)) {}
-
-  int id() const override { return id_; }
-
-  int process_index() const override { return process_index_; }
-
-  absl::string_view device_kind() const override { return device_kind_; }
-
-  absl::string_view ToString() const override { return to_string_; }
-
-  absl::string_view DebugString() const override { return debug_string_; }
-
-  absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
-
-  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
-      const override {
-    return attributes_;
-  }
-
-  void SetAttributes(
-      absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
-    attributes_ = std::move(attributes);
-  }
-
-  void SetDebugString(std::string debug_string) {
-    debug_string_ = std::move(debug_string);
-  }
-
-  void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
-
-  void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
-
- private:
-  const int id_;
-  const int process_index_;
-  const std::string device_kind_;
-  std::string debug_string_ = "<unknown SE device>";
-  std::string to_string_ = "<unknown SE device>";
-  absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
-  std::array<int, 1> coords_;
-};
-
 class PjRtStreamExecutorDevice : public PjRtDevice {
  public:
  explicit PjRtStreamExecutorDevice(
xla/pjrt/pjrt_stream_executor_device_description.h
@@ -0,0 +1,75 @@
+/* Copyright 2025 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+#define XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+
+#include <array>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/pjrt/pjrt_device_description.h"
+
+namespace xla {
+
+class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
+ public:
+  explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
+                                               int process_index = 0)
+      : id_(id),
+        process_index_(process_index),
+        device_kind_(std::move(device_kind)) {}
+
+  int id() const override { return id_; }
+
+  int process_index() const override { return process_index_; }
+
+  absl::string_view device_kind() const override { return device_kind_; }
+
+  absl::string_view ToString() const override { return to_string_; }
+
+  absl::string_view DebugString() const override { return debug_string_; }
+
+  absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
+
+  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
+      const override {
+    return attributes_;
+  }
+
+  void SetAttributes(
+      absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
+    attributes_ = std::move(attributes);
+  }
+
+  void SetDebugString(std::string debug_string) {
+    debug_string_ = std::move(debug_string);
+  }
+
+  void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
+
+  void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
+
+ private:
+  const int id_;
+  const int process_index_;
+  const std::string device_kind_;
+  std::string debug_string_ = "<unknown SE device>";
+  std::string to_string_ = "<unknown SE device>";
+  absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
+  std::array<int, 1> coords_;
+};
+}  // namespace xla
+
+#endif  // XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
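This is the class deleted from `pjrt_stream_executor_client.h` above, moved verbatim into its own header (matching the include swapped in there). A small sketch, assuming the class is used as the plain value type it appears to be; the example values are hypothetical:

```cpp
#include "xla/pjrt/pjrt_stream_executor_device_description.h"

// The setters simply overwrite the "<unknown SE device>" defaults shown in
// the header above.
xla::PjRtStreamExecutorDeviceDescription MakeDescription() {
  xla::PjRtStreamExecutorDeviceDescription desc(
      /*id=*/0, /*device_kind=*/"gpu", /*process_index=*/0);
  desc.SetToString("SeDevice(id=0)");
  desc.SetDebugString("SeDevice(id=0, process_index=0)");
  desc.SetCoords({0});
  return desc;
}
```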
xla/pjrt/plugin/xla_cpu/cpu_execute_options.h
@@ -0,0 +1,57 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
+#define XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
+
+#include <optional>
+
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
+#include "xla/pjrt/pjrt_executable.h"
+
+namespace xla {
+
+// ExecuteContext for XLA:CPU PjRtLoadedExecutable::Execute calls.
+class CpuExecuteContext : public ExecuteContext {
+ public:
+  ~CpuExecuteContext() override = default;
+
+  // If specified, override the process ID specified in
+  // `CpuClientOptions::process_id` for a particular call of
+  // PjRtLoadedExecutable::Execute.
+  //
+  // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
+  // option with `ExecuteOptions::multi_slice_config`.
+  std::optional<int>& process_index() { return process_index_; }
+  std::optional<int> process_index() const { return process_index_; }
+
+  // If specified, override CPU collectives specified in
+  // `CpuClientOptions::collectives` for a particular call of
+  // PjRtLoadedExecutable::Execute. Must remain valid until the execution
+  // finishes.
+  //
+  // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
+  // option with `ExecuteOptions::multi_slice_config`.
+  cpu::CpuCollectives*& collectives() { return collectives_; }
+  cpu::CpuCollectives* collectives() const { return collectives_; }
+
+ private:
+  std::optional<int> process_index_;
+  cpu::CpuCollectives* collectives_ = nullptr;
+};
+
+}  // namespace xla
+
+#endif  // XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
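A usage sketch for the new `CpuExecuteContext`, assuming `ExecuteOptions::context` is the standard hook for passing an `ExecuteContext` into `PjRtLoadedExecutable::Execute`; the executable, arguments, and collectives object are placeholders:

```cpp
#include "xla/pjrt/plugin/xla_cpu/cpu_execute_options.h"

// Per-call overrides: both accessors return references, so they are set by
// assignment.
xla::CpuExecuteContext context;
context.process_index() = 2;                   // overrides CpuClientOptions::process_id
context.collectives() = my_collectives.get();  // must outlive the execution

xla::ExecuteOptions options;
options.context = &context;
auto results = executable->Execute(arguments, options);
```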
xla/pjrt/plugin/xla_cpu/cpu_topology.h
@@ -69,6 +69,10 @@ inline int UnpackCpuProcessIndex(PjRtGlobalDeviceId global_device_id) {
   return global_device_id.value() / kMaxCpuDevicesPerProcess;
 }
 
+inline int UnpackCpuLocalDeviceId(PjRtGlobalDeviceId global_device_id) {
+  return global_device_id.value() % kMaxCpuDevicesPerProcess;
+}
+
 }  // namespace xla
 
 #endif  // XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_
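Together with the existing `UnpackCpuProcessIndex` (division), the new helper (modulo) inverts the packing `global = process_index * kMaxCpuDevicesPerProcess + local_device_id`. A round-trip sketch, assuming `PackCpuDeviceId` is the packing counterpart defined in the same header:

```cpp
#include "absl/log/check.h"
#include "xla/pjrt/plugin/xla_cpu/cpu_topology.h"

void RoundTrip() {
  xla::PjRtGlobalDeviceId global =
      xla::PackCpuDeviceId(/*process_index=*/3, /*local_device_id=*/5);
  CHECK_EQ(xla::UnpackCpuProcessIndex(global), 3);   // global / kMaxCpuDevicesPerProcess
  CHECK_EQ(xla::UnpackCpuLocalDeviceId(global), 5);  // global % kMaxCpuDevicesPerProcess
}
```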
xla/service/hlo_module_util.h
@@ -19,16 +19,67 @@ limitations under the License.
 #include <functional>
 #include <memory>
 #include <optional>
+#include <string>
 
-#include "absl/
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/service/compiler.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/shape.h"
+#include "xla/util.h"
 
 namespace xla {
 
+// Converts an HloModule from the given hlo textual IR string (in
+// HloModule::ToString format).
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
+    absl::string_view hlo_string,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Creates an HloModule from the given proto.
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+    const HloModuleProto& proto,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Create an HLO state from serialized representation. In addition to
+// creating the proto with HloModule::CreateFromProto(...) it also
+// uses HloVerifier to ensure basic invariants are held.
+// The HLO module could be a pre-optimizations (default) or post-optimizations
+// module, which affects how the HLO module is verified, e.g., mixed-precision
+// is allowed in post-optimizations HLOs.
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+    const HloModuleProto& proto, const HloModuleConfig& module_config,
+    bool is_module_post_optimizations = false);
+
+// Reads the proto file in xla.HloProto format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
+    absl::string_view filename,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Reads the proto file in xla.HloModule format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
+    absl::string_view filename, const DebugOptions& debug_options);
+
+// Reads the HLO text dump file in HloModule::ToString format, creates and
+// returns the HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
+    absl::string_view filename,
+    const DebugOptions& debug_options = DebugOptions::default_instance(),
+    const HloParserOptions& options = HloParserOptions());
+
+// Reads the proto file in xla.HloProto format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
+    absl::string_view hlo_file,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
 // Creates an HloModuleConfig for a given program shape and arguments.
 // If execution_options does not set num_replicas, default_num_replicas is used.
 // num_threads is optional; if not given, intra_op_parallelism_threads not set.
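These declarations consolidate the module-creation helpers into `hlo_module_util.h` (the verifier-backed `CreateModuleFromProto` overload is correspondingly deleted from `hlo_proto_util.h` below). A short sketch of the text entry point, using a stand-in HLO string:

```cpp
#include "absl/status/status.h"
#include "xla/service/hlo_module_util.h"

absl::Status BuildAndUse() {
  // Parses HloModule::ToString-format text into an HloModule.
  absl::StatusOr<std::unique_ptr<xla::HloModule>> module =
      xla::CreateModuleFromString(R"(
        HloModule add
        ENTRY add {
          a = f32[] parameter(0)
          b = f32[] parameter(1)
          ROOT sum = f32[] add(a, b)
        }
      )");
  if (!module.ok()) return module.status();
  // ReadModuleFromHloTextFile(filename) reads the same format from disk.
  return absl::OkStatus();
}
```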
xla/service/hlo_proto_util.h
@@ -18,8 +18,6 @@ limitations under the License.
 #ifndef XLA_SERVICE_HLO_PROTO_UTIL_H_
 #define XLA_SERVICE_HLO_PROTO_UTIL_H_
 
-#include <string>
-
 #include "absl/status/status.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/buffer_assignment.h"
@@ -35,16 +33,6 @@ HloProto MakeHloProto(const HloModule& module,
 // will not be included in the output.
 HloProto MakeHloProto(const HloModule& module);
 
-// Create an HLO state from serialized representation. In addition to
-// creating the proto with HloModule::CreateFromProto(...) it also
-// uses HloVerifier to ensure basic invariants are held.
-// The HLO module could be a pre-optimizations (default) or post-optimizations
-// module, which affects how the HLO module is verified, e.g., mixed-precision
-// is allowed in post-optimizations HLOs.
-absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
-    const HloModuleProto& proto, const HloModuleConfig& module_config,
-    bool is_module_post_optimizations = false);
-
 // Returns the shapes of the parameters of the entry computation. Shape pointers
 // refer to shapes inside of the given HloProto.
 absl::StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
xla/tsl/concurrency/async_value.h
@@ -35,9 +35,6 @@ limitations under the License.
 #include "xla/tsl/platform/logging.h"
 
 namespace tsl {
-
-class NotifierListNode;
-
 namespace internal {
 
 template <typename T>
@@ -277,6 +274,8 @@ class AsyncValue {
  protected:
   friend class IndirectAsyncValue;
 
+  struct WaiterListNode;
+
   static constexpr uint16_t kUnknownTypeId = 0;
 
   // Utility template for tag dispatching.
@@ -311,7 +310,7 @@ class AsyncValue {
 
   void NotifyAvailable(State available_state);
   void Destroy();
-  void RunWaiters(NotifierListNode* list);
+  void RunWaiters(WaiterListNode* list);
 
   // IsTypeIdCompatible returns true if the type value stored in this AsyncValue
   // instance can be safely cast to `T`. This is a conservative check. I.e.
@@ -369,6 +368,16 @@ class AsyncValue {
   // This is a 16-bit value that identifies the type.
   uint16_t type_id_ = 0;
 
+  // This is a singly linked list of nodes waiting for notification, hanging off
+  // of AsyncValue. When the value becomes available or if an error occurs, the
+  // callbacks are informed.
+  struct WaiterListNode {
+    virtual ~WaiterListNode() = default;
+    virtual void operator()() = 0;
+
+    WaiterListNode* next = nullptr;
+  };
+
   // The waiter list and the state are compacted into one single atomic word as
   // accesses to them are tightly related. To change the state from unavailable
   // (i.e. kUnconstructed or kConstructed) to available
@@ -379,7 +388,7 @@ class AsyncValue {
   // Invariant: If the state is not available, then the waiter list must be
   // nullptr.
   struct WaitersAndState {
-    // We rely on the fact that all `NotifierListNode` values are aligned at
+    // We rely on the fact that all `WaiterListNode` values are aligned at
     // least to 4 bytes and we can encode state in the lowest 2 bits. We use
     // the conservative estimation of the minimal alignment of pointers returned
    // from memory allocation functions.
@@ -390,7 +399,7 @@ class AsyncValue {
     static constexpr uintptr_t kStateMask = (1ull << 2) - 1;
     static constexpr uintptr_t kPointerMask = ~kStateMask;
 
-    WaitersAndState(NotifierListNode* ptr, State state) {
+    WaitersAndState(WaiterListNode* ptr, State state) {
       value = (reinterpret_cast<uintptr_t>(ptr) & kPointerMask) |
               (state & kStateMask);
     }
@@ -399,8 +408,8 @@ class AsyncValue {
       return State(static_cast<State::StateEnum>(value & kStateMask));
     }
 
-    NotifierListNode* waiter() const {
-      return reinterpret_cast<NotifierListNode*>(value & kPointerMask);
+    WaiterListNode* waiter() const {
+      return reinterpret_cast<WaiterListNode*>(value & kPointerMask);
     }
 
     uintptr_t value;
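The `WaitersAndState` comment above is the whole trick: waiter-list head pointers are at least 4-byte aligned, so the two low bits of a single word can carry the lifecycle state while the remaining bits carry the pointer. A freestanding sketch of the same tagging scheme (illustration only, not TSL code):

```cpp
#include <cassert>
#include <cstdint>

enum State : uintptr_t { kUnconstructed = 0, kConstructed = 1, kConcrete = 2, kError = 3 };

constexpr uintptr_t kStateMask = (1ull << 2) - 1;  // low 2 bits
constexpr uintptr_t kPointerMask = ~kStateMask;    // everything else

struct Node { Node* next = nullptr; };  // stand-in for WaiterListNode

// Pack a list head and a state into one word, as WaitersAndState does.
uintptr_t Pack(Node* head, State s) {
  return (reinterpret_cast<uintptr_t>(head) & kPointerMask) | (s & kStateMask);
}

int main() {
  static Node node;  // heap/static nodes are >= 4-byte aligned
  uintptr_t word = Pack(&node, kConstructed);
  assert(reinterpret_cast<Node*>(word & kPointerMask) == &node);
  assert(static_cast<State>(word & kStateMask) == kConstructed);
}
```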
@@ -466,8 +475,26 @@ class AsyncValue {
     return (*type_info_table)[type_id_ - 1];
   }
 
-  void EnqueueWaiter(absl::AnyInvocable<void()> waiter,
-                     WaitersAndState old_value);
+  // Adds a waiter list node to the waiter linked list. If the value is
+  // available or becomes available, this calls the waiter immediately.
+  // Otherwise, we add waiter to the list where it will be called when the value
+  // becomes available.
+  void EnqueueWaiterListNode(WaiterListNode* waiter,
+                             WaitersAndState waiters_and_state);
+
+  template <typename Waiter>
+  void EnqueueWaiter(Waiter&& waiter, WaitersAndState waiters_and_state) {
+    static_assert(std::is_invocable_v<Waiter>, "Waiter must be invocable");
+
+    struct Node final : public WaiterListNode {
+      explicit Node(Waiter waiter) : waiter(std::move(waiter)) {}
+      void operator()() final { waiter(); }
+      Waiter waiter;
+    };
+
+    EnqueueWaiterListNode(new Node{std::forward<Waiter>(waiter)},
+                          waiters_and_state);
+  }
 
   // This is a global counter of the number of AsyncValue instances currently
   // live in the process. This is intended to be used for debugging only, and
@@ -983,14 +1010,15 @@ void AsyncValue::AndThen(Waiter&& waiter) {
   // Clients generally want to use AndThen without them each having to check
   // to see if the value is present. Check for them, and immediately run the
   // waiter if it is already here.
-  auto old_value = waiters_and_state_.load(std::memory_order_acquire);
-  if (old_value.state() == State::kConcrete ||
-      old_value.state() == State::kError) {
-    DCHECK_EQ(old_value.waiter(), nullptr);
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
     waiter();
     return;
   }
-  EnqueueWaiter(std::forward<Waiter>(waiter), old_value);
+
+  EnqueueWaiter(std::forward<Waiter>(waiter), waiters_and_state);
 }
 
 template <typename Waiter>
@@ -998,18 +1026,19 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) {
   // Clients generally want to use AndThen without them each having to check
   // to see if the value is present. Check for them, and immediately run the
   // waiter if it is already here.
-  auto old_value = waiters_and_state_.load(std::memory_order_acquire);
-  if (old_value.state() == State::kConcrete ||
-      old_value.state() == State::kError) {
-    DCHECK_EQ(old_value.waiter(), nullptr);
+  auto waiters_and_state = waiters_and_state_.load(std::memory_order_acquire);
+  if (waiters_and_state.state() == State::kConcrete ||
+      waiters_and_state.state() == State::kError) {
+    DCHECK_EQ(waiters_and_state.waiter(), nullptr);
     executor.Execute(std::forward<Waiter>(waiter));
     return;
   }
+
   EnqueueWaiter(
-      [&executor, waiter = std::forward<Waiter>(waiter)]
+      [&executor, waiter = std::forward<Waiter>(waiter)] {
         executor.Execute(std::move(waiter));
       },
-      old_value);
+      waiters_and_state);
 }
 
 inline void AsyncValue::Destroy() {
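From the caller's side nothing changes: `AndThen` still runs the callback immediately if the value is already concrete or in error, and otherwise enqueues a waiter (now an intrusive, heap-allocated `Node` instead of the old notifier list). A usage sketch, assuming an `AsyncValueRef<int>` named `av` from the surrounding TSL API:

```cpp
av.AndThen([av] {
  // Runs immediately if `av` is already available, otherwise when it becomes
  // available or transitions to error.
  if (av.IsError()) {
    LOG(ERROR) << "async value failed: " << av.GetError();
  } else {
    LOG(INFO) << "async value ready: " << *av;
  }
});
```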
xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h
@@ -1604,12 +1604,12 @@ SpatialConvolution(const Input& input, const Kernel& kernel,
                    Index padding_left = 0, Index padding_right = 0) {
   typedef typename internal::traits<Input>::Index TensorIndex;
   typedef typename internal::traits<Input>::Scalar InputScalar;
-  TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
-                   internal::traits<Input>::Layout, TensorIndex> >
+  TensorRef<const Tensor<InputScalar, internal::traits<Input>::NumDimensions,
+                         internal::traits<Input>::Layout, TensorIndex> >
       in(input);
-  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
-                   internal::traits<Kernel>::NumDimensions,
-                   internal::traits<Kernel>::Layout, TensorIndex> >
+  TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
+                         internal::traits<Kernel>::NumDimensions,
+                         internal::traits<Kernel>::Layout, TensorIndex> >
       kern(kernel);
 
   EIGEN_STATIC_ASSERT(
tensorflow/core/kernels/data/experimental/random_access_ops.h
@@ -35,8 +35,6 @@ class GetElementAtIndexOp : public AsyncOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
   }
 
-  ~GetElementAtIndexOp() override {}
-
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() {
       ctx->SetStatus(DoCompute(ctx));
tensorflow/core/kernels/eigen_attention.h
@@ -68,8 +68,8 @@ struct GlimpseExtractionOp {
   template <typename Input>
   DSizes<Index, 4> dimensions(const Input& input) const {
     typedef typename internal::traits<Input>::Index IndexType;
-    typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
-                             internal::traits<Input>::Layout, IndexType> >
+    typedef TensorRef<const Tensor<typename internal::traits<Input>::Scalar, 4,
+                                   internal::traits<Input>::Layout, IndexType> >
         Ref;
     Ref in(input);
 
@@ -86,8 +86,8 @@ struct GlimpseExtractionOp {
   EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output,
                               const Device& device) const {
     typedef typename internal::traits<Input>::Index IndexType;
-    typedef TensorRef<Tensor<typename internal::traits<Input>::Scalar, 4,
-                             internal::traits<Input>::Layout, IndexType> >
+    typedef TensorRef<const Tensor<typename internal::traits<Input>::Scalar, 4,
+                                   internal::traits<Input>::Layout, IndexType> >
         Ref;
     Ref in(input);
     const Index num_channels = in.dimension(0);
tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -394,13 +394,13 @@ CuboidConvolutionBackwardKernel(
     const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
     const DenseIndex strideCols = 1) {
   typedef typename internal::traits<Input>::Index TensorIndex;
-  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
-                   internal::traits<Input>::NumDimensions,
-                   internal::traits<Input>::Layout, TensorIndex>>
+  TensorRef<const Tensor<typename internal::traits<Input>::Scalar,
+                         internal::traits<Input>::NumDimensions,
+                         internal::traits<Input>::Layout, TensorIndex>>
       in(input);
-  TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
-                   internal::traits<OutputBackward>::NumDimensions,
-                   internal::traits<OutputBackward>::Layout, TensorIndex>>
+  TensorRef<const Tensor<typename internal::traits<OutputBackward>::Scalar,
+                         internal::traits<OutputBackward>::NumDimensions,
+                         internal::traits<OutputBackward>::Layout, TensorIndex>>
      out(output_backward);
 
   EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
|