tf-nightly-cpu 2.20.0.dev20250220__cp310-cp310-win_amd64.whl → 2.20.0.dev20250222__cp310-cp310-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (128) hide show
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
  14. tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
  15. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  17. tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
  18. tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
  19. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  20. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  21. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  22. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  23. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -19
  24. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  25. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  26. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
  27. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  28. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  30. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  31. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  32. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  33. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  34. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  35. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  36. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  37. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  38. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  39. tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
  40. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  41. tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
  42. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  43. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  44. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  45. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  46. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  47. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  48. tensorflow/include/tensorflow/core/public/version.h +112 -127
  49. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  50. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  51. tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  52. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  53. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -19
  54. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  55. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  56. tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
  57. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  58. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  59. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  60. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  61. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  62. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  63. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  64. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  65. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  66. tensorflow/include/xla/service/constant_value.h +1 -0
  67. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  68. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  69. tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
  70. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  71. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  72. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  73. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  74. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  75. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  76. tensorflow/python/_pywrap_mlir.pyd +0 -0
  77. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  78. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  79. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  80. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  81. tensorflow/python/_pywrap_tfe.pyd +0 -0
  82. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  83. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  84. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  85. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  86. tensorflow/python/compat/compat.py +1 -1
  87. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  88. tensorflow/python/eager/imperative_grad.py +5 -5
  89. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  90. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  91. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  92. tensorflow/python/eager/tape.py +2 -2
  93. tensorflow/python/framework/_dtypes.pyd +0 -0
  94. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  95. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  96. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  97. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  98. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  99. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  100. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  101. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  102. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  103. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  104. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  105. tensorflow/python/ops/summary_ops_v2.py +5 -1
  106. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  107. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  108. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  109. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  110. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  111. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  112. tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
  113. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  114. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  115. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  116. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  117. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  118. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  119. tensorflow/python/util/_tf_stack.pyd +0 -0
  120. tensorflow/tools/pip_package/setup.py +2 -2
  121. tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
  122. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
  123. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +126 -121
  124. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
  125. tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
  126. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
  127. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
  128. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
3
3
  #ifdef GEN_PASS_DECL
4
4
  // Generate declarations for all passes.
5
5
  #define GEN_PASS_DECL_CHLORECOMPOSEOPSPASS
6
+ #define GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
6
7
  #define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
7
8
  #define GEN_PASS_DECL_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
8
9
  #define GEN_PASS_DECL_STABLEHLOFLATTENTUPLEPASS
@@ -87,6 +88,82 @@ std::unique_ptr<::mlir::Pass> createChloRecomposeOpsPass() {
87
88
  #undef GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
88
89
  #endif // GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
89
90
 
91
+ //===----------------------------------------------------------------------===//
92
+ // StablehloAddQDQAfterConvPass
93
+ //===----------------------------------------------------------------------===//
94
+ #ifdef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
95
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
96
+ #undef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
97
+ #endif // GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
98
+ #ifdef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
99
+
100
+ namespace impl {
101
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
102
+ } // namespace impl
103
+ namespace impl {
104
+
105
+ template <typename DerivedT>
106
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
107
+ public:
108
+ using Base = StablehloAddQDQAfterConvPassBase;
109
+
110
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
111
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
112
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
113
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
114
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
115
+ ~StablehloAddQDQAfterConvPassBase() = default;
116
+
117
+ /// Returns the command-line argument attached to this pass.
118
+ static constexpr ::llvm::StringLiteral getArgumentName() {
119
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
120
+ }
121
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
122
+
123
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
124
+
125
+ /// Returns the derived pass name.
126
+ static constexpr ::llvm::StringLiteral getPassName() {
127
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
128
+ }
129
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
130
+
131
+ /// Support isa/dyn_cast functionality for the derived pass class.
132
+ static bool classof(const ::mlir::Pass *pass) {
133
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
134
+ }
135
+
136
+ /// A clone method to create a copy of this pass.
137
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
138
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
139
+ }
140
+
141
+ /// Return the dialect that must be loaded in the context before this pass.
142
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
143
+ registry.insert<mlir::quant::QuantDialect>();
144
+ registry.insert<stablehlo::StablehloDialect>();
145
+ }
146
+
147
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
148
+ /// instantiation because Pass classes should only be visible by the current
149
+ /// library.
150
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
151
+
152
+ protected:
153
+ private:
154
+
155
+ friend std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
156
+ return std::make_unique<DerivedT>();
157
+ }
158
+ };
159
+ } // namespace impl
160
+
161
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
162
+ return impl::createStablehloAddQDQAfterConvPass();
163
+ }
164
+ #undef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
165
+ #endif // GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
166
+
90
167
  //===----------------------------------------------------------------------===//
91
168
  // StablehloCanonicalizeDynamismPass
92
169
  //===----------------------------------------------------------------------===//
@@ -360,9 +437,9 @@ public:
360
437
 
361
438
  /// Returns the command-line argument attached to this pass.
362
439
  static constexpr ::llvm::StringLiteral getArgumentName() {
363
- return ::llvm::StringLiteral("legalize-quant-composite");
440
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
364
441
  }
365
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
442
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
366
443
 
367
444
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
368
445
 
@@ -576,6 +653,23 @@ inline void registerChloRecomposeOpsPassPass() {
576
653
  });
577
654
  }
578
655
 
656
+ //===----------------------------------------------------------------------===//
657
+ // StablehloAddQDQAfterConvPass Registration
658
+ //===----------------------------------------------------------------------===//
659
+
660
+ inline void registerStablehloAddQDQAfterConvPass() {
661
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
662
+ return createStablehloAddQDQAfterConvPass();
663
+ });
664
+ }
665
+
666
+ // Old registration code, kept for temporary backwards compatibility.
667
+ inline void registerStablehloAddQDQAfterConvPassPass() {
668
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
669
+ return createStablehloAddQDQAfterConvPass();
670
+ });
671
+ }
672
+
579
673
  //===----------------------------------------------------------------------===//
580
674
  // StablehloCanonicalizeDynamismPass Registration
581
675
  //===----------------------------------------------------------------------===//
@@ -684,6 +778,7 @@ inline void registerStablehloRefineShapesPassPass() {
684
778
 
685
779
  inline void registerPasses() {
686
780
  registerChloRecomposeOpsPass();
781
+ registerStablehloAddQDQAfterConvPass();
687
782
  registerStablehloCanonicalizeDynamismPass();
688
783
  registerStablehloFlattenEntryFunctionTuplesPass();
689
784
  registerStablehloFlattenTuplePass();
@@ -745,6 +840,56 @@ public:
745
840
  protected:
746
841
  };
747
842
 
843
+ template <typename DerivedT>
844
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
845
+ public:
846
+ using Base = StablehloAddQDQAfterConvPassBase;
847
+
848
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
849
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
850
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
851
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
852
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
853
+ ~StablehloAddQDQAfterConvPassBase() = default;
854
+
855
+ /// Returns the command-line argument attached to this pass.
856
+ static constexpr ::llvm::StringLiteral getArgumentName() {
857
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
858
+ }
859
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
860
+
861
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
862
+
863
+ /// Returns the derived pass name.
864
+ static constexpr ::llvm::StringLiteral getPassName() {
865
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
866
+ }
867
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
868
+
869
+ /// Support isa/dyn_cast functionality for the derived pass class.
870
+ static bool classof(const ::mlir::Pass *pass) {
871
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
872
+ }
873
+
874
+ /// A clone method to create a copy of this pass.
875
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
876
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
877
+ }
878
+
879
+ /// Register the dialects that must be loaded in the context before this pass.
880
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
881
+ registry.insert<mlir::quant::QuantDialect>();
882
+ registry.insert<stablehlo::StablehloDialect>();
883
+ }
884
+
885
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
886
+ /// instantiation because Pass classes should only be visible by the current
887
+ /// library.
888
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
889
+
890
+ protected:
891
+ };
892
+
748
893
  template <typename DerivedT>
749
894
  class StablehloCanonicalizeDynamismPassBase : public ::mlir::OperationPass<func::FuncOp> {
750
895
  public:
@@ -907,9 +1052,9 @@ public:
907
1052
 
908
1053
  /// Returns the command-line argument attached to this pass.
909
1054
  static constexpr ::llvm::StringLiteral getArgumentName() {
910
- return ::llvm::StringLiteral("legalize-quant-composite");
1055
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
911
1056
  }
912
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
1057
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
913
1058
 
914
1059
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
915
1060
 
@@ -145,6 +145,11 @@ class DistributedRuntimeClient {
145
145
  std::string barrier_id, absl::Duration timeout,
146
146
  std::optional<absl::Span<const int32_t>> nodes) = 0;
147
147
 
148
+ // Returns the subset of live nodes. See CoordinationService.GetAliveTasks for
149
+ // detailed semantics.
150
+ virtual absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
151
+ absl::Span<const int32_t> nodes) = 0;
152
+
148
153
  // Returns pointer to coordination service agent, or InternalError if the
149
154
  // client does not use coordination service.
150
155
  virtual absl::StatusOr<tsl::CoordinationServiceAgent*>
@@ -37,6 +37,7 @@ limitations under the License.
37
37
  #include "xla/pjrt/distributed/key_value_store_interface.h"
38
38
  #include "xla/pjrt/gpu/gpu_topology.h"
39
39
  #include "xla/pjrt/gpu/gpu_topology.pb.h"
40
+ #include "xla/pjrt/gpu/se_gpu_topology_description.h"
40
41
  #include "xla/pjrt/local_device_state.h"
41
42
  #include "xla/pjrt/pjrt_client.h"
42
43
  #include "xla/pjrt/pjrt_compiler.h"
@@ -57,98 +58,6 @@ using DeviceTopologyPair =
57
58
  std::pair<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>,
58
59
  GpuTopologyProto>;
59
60
 
60
- class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
61
- public:
62
- StreamExecutorGpuTopologyDescription(
63
- const PjRtPlatformId platform_id, const absl::string_view platform_name,
64
- std::shared_ptr<const GpuTopology> gpu_topology,
65
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
66
- {},
67
- std::optional<stream_executor::GpuTargetConfigProto> target_config =
68
- std::nullopt)
69
- : platform_id_(platform_id),
70
- platform_name_(platform_name),
71
- gpu_topology_(std::move(gpu_topology)),
72
- attributes_(attributes),
73
- target_config_(std::move(target_config)) {}
74
-
75
- bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
76
- return this->platform_id() == other.platform_id() &&
77
- this->platform_name() == other.platform_name() &&
78
- this->platform_version() == other.platform_version() &&
79
- this->gpu_topology() == other.gpu_topology();
80
- }
81
-
82
- PjRtPlatformId platform_id() const override { return platform_id_; }
83
-
84
- absl::string_view platform_name() const override { return platform_name_; }
85
-
86
- absl::string_view platform_version() const override {
87
- return gpu_topology_->platform_version();
88
- }
89
-
90
- std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
91
- const override {
92
- std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
93
- devices.reserve(gpu_topology_->number_of_devices());
94
- for (const int device_id : gpu_topology_->device_ids()) {
95
- devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
96
- device_id, std::string(platform_version())));
97
- }
98
- return devices;
99
- }
100
-
101
- const GpuTopology& gpu_topology() const { return *gpu_topology_; }
102
- const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
103
-
104
- // No subslice is supported.
105
- bool is_subslice_topology() const override { return false; }
106
-
107
- absl::StatusOr<int> ProcessCount() const override {
108
- return gpu_topology_->number_of_hosts();
109
- }
110
-
111
- absl::StatusOr<int> CoreCountOfDefaultType() const override {
112
- return gpu_topology_->number_of_devices();
113
- }
114
-
115
- absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
116
- return gpu_topology_->number_of_devices();
117
- }
118
-
119
- absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
120
- return gpu_topology_->number_of_devices();
121
- }
122
-
123
- absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
124
- return 1;
125
- }
126
-
127
- absl::StatusOr<std::string> Serialize() const override;
128
-
129
- const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
130
- const {
131
- return target_config_;
132
- }
133
-
134
- // Returns vendor specific attributes about the topology.
135
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
136
- const override {
137
- return attributes_;
138
- }
139
-
140
- absl::StatusOr<Layout> GetDefaultLayout(
141
- PrimitiveType element_type,
142
- absl::Span<const int64_t> dims) const override;
143
-
144
- private:
145
- const PjRtPlatformId platform_id_;
146
- const std::string platform_name_;
147
- std::shared_ptr<const GpuTopology> gpu_topology_;
148
- absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
149
- std::optional<stream_executor::GpuTargetConfigProto> target_config_;
150
- };
151
-
152
61
  class StreamExecutorGpuDevice : public PjRtStreamExecutorDevice {
153
62
  public:
154
63
  StreamExecutorGpuDevice(int id,
@@ -0,0 +1,126 @@
1
+ /* Copyright 2025 The OpenXLA Authors.
2
+ Licensed under the Apache License, Version 2.0 (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+ http://www.apache.org/licenses/LICENSE-2.0
6
+ Unless required by applicable law or agreed to in writing, software
7
+ distributed under the License is distributed on an "AS IS" BASIS,
8
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ See the License for the specific language governing permissions and
10
+ limitations under the License.
11
+ ==============================================================================*/
12
+ #ifndef XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
13
+ #define XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
14
+
15
+ #include <cstdint>
16
+ #include <memory>
17
+ #include <optional>
18
+ #include <string>
19
+ #include <utility>
20
+ #include <vector>
21
+
22
+ #include "absl/container/flat_hash_map.h"
23
+ #include "absl/status/statusor.h"
24
+ #include "absl/strings/string_view.h"
25
+ #include "absl/types/span.h"
26
+ #include "xla/pjrt/gpu/gpu_topology.h"
27
+ #include "xla/pjrt/pjrt_compiler.h"
28
+ #include "xla/pjrt/pjrt_device_description.h"
29
+ #include "xla/pjrt/pjrt_stream_executor_device_description.h"
30
+
31
+ namespace xla {
32
+
33
+ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
34
+ public:
35
+ StreamExecutorGpuTopologyDescription(
36
+ const PjRtPlatformId platform_id, const absl::string_view platform_name,
37
+ std::shared_ptr<const GpuTopology> gpu_topology,
38
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
39
+ {},
40
+ std::optional<stream_executor::GpuTargetConfigProto> target_config =
41
+ std::nullopt)
42
+ : platform_id_(platform_id),
43
+ platform_name_(platform_name),
44
+ gpu_topology_(std::move(gpu_topology)),
45
+ attributes_(attributes),
46
+ target_config_(std::move(target_config)) {}
47
+
48
+ bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
49
+ return this->platform_id() == other.platform_id() &&
50
+ this->platform_name() == other.platform_name() &&
51
+ this->platform_version() == other.platform_version() &&
52
+ this->gpu_topology() == other.gpu_topology();
53
+ }
54
+
55
+ PjRtPlatformId platform_id() const override { return platform_id_; }
56
+
57
+ absl::string_view platform_name() const override { return platform_name_; }
58
+
59
+ absl::string_view platform_version() const override {
60
+ return gpu_topology_->platform_version();
61
+ }
62
+
63
+ std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
64
+ const override {
65
+ std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
66
+ devices.reserve(gpu_topology_->number_of_devices());
67
+ for (const int device_id : gpu_topology_->device_ids()) {
68
+ devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
69
+ device_id, std::string(platform_version())));
70
+ }
71
+ return devices;
72
+ }
73
+
74
+ const GpuTopology& gpu_topology() const { return *gpu_topology_; }
75
+ const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
76
+
77
+ // No subslice is supported.
78
+ bool is_subslice_topology() const override { return false; }
79
+
80
+ absl::StatusOr<int> ProcessCount() const override {
81
+ return gpu_topology_->number_of_hosts();
82
+ }
83
+
84
+ absl::StatusOr<int> CoreCountOfDefaultType() const override {
85
+ return gpu_topology_->number_of_devices();
86
+ }
87
+
88
+ absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
89
+ return gpu_topology_->number_of_devices();
90
+ }
91
+
92
+ absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
93
+ return gpu_topology_->number_of_devices();
94
+ }
95
+
96
+ absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
97
+ return 1;
98
+ }
99
+
100
+ absl::StatusOr<std::string> Serialize() const override;
101
+
102
+ const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
103
+ const {
104
+ return target_config_;
105
+ }
106
+
107
+ // Returns vendor specific attributes about the topology.
108
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
109
+ const override {
110
+ return attributes_;
111
+ }
112
+
113
+ absl::StatusOr<Layout> GetDefaultLayout(
114
+ PrimitiveType element_type,
115
+ absl::Span<const int64_t> dims) const override;
116
+
117
+ private:
118
+ const PjRtPlatformId platform_id_;
119
+ const std::string platform_name_;
120
+ std::shared_ptr<const GpuTopology> gpu_topology_;
121
+ absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
122
+ std::optional<stream_executor::GpuTargetConfigProto> target_config_;
123
+ };
124
+ } // namespace xla
125
+
126
+ #endif // XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
@@ -52,9 +52,9 @@ limitations under the License.
52
52
  #include "xla/pjrt/pjrt_client.h"
53
53
  #include "xla/pjrt/pjrt_common.h"
54
54
  #include "xla/pjrt/pjrt_compiler.h"
55
- #include "xla/pjrt/pjrt_device_description.h"
56
55
  #include "xla/pjrt/pjrt_executable.h"
57
56
  #include "xla/pjrt/pjrt_future.h"
57
+ #include "xla/pjrt/pjrt_stream_executor_device_description.h"
58
58
  #include "xla/pjrt/tracked_device_buffer.h"
59
59
  #include "xla/pjrt/transpose.h"
60
60
  #include "xla/pjrt/utils.h"
@@ -77,54 +77,6 @@ limitations under the License.
77
77
 
78
78
  namespace xla {
79
79
 
80
- class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
81
- public:
82
- explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
83
- int process_index = 0)
84
- : id_(id),
85
- process_index_(process_index),
86
- device_kind_(std::move(device_kind)) {}
87
-
88
- int id() const override { return id_; }
89
-
90
- int process_index() const override { return process_index_; }
91
-
92
- absl::string_view device_kind() const override { return device_kind_; }
93
-
94
- absl::string_view ToString() const override { return to_string_; }
95
-
96
- absl::string_view DebugString() const override { return debug_string_; }
97
-
98
- absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
99
-
100
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
101
- const override {
102
- return attributes_;
103
- }
104
-
105
- void SetAttributes(
106
- absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
107
- attributes_ = std::move(attributes);
108
- }
109
-
110
- void SetDebugString(std::string debug_string) {
111
- debug_string_ = std::move(debug_string);
112
- }
113
-
114
- void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
115
-
116
- void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
117
-
118
- private:
119
- const int id_;
120
- const int process_index_;
121
- const std::string device_kind_;
122
- std::string debug_string_ = "<unknown SE device>";
123
- std::string to_string_ = "<unknown SE device>";
124
- absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
125
- std::array<int, 1> coords_;
126
- };
127
-
128
80
  class PjRtStreamExecutorDevice : public PjRtDevice {
129
81
  public:
130
82
  explicit PjRtStreamExecutorDevice(
@@ -0,0 +1,75 @@
1
+ /* Copyright 2025 The OpenXLA Authors.
2
+ Licensed under the Apache License, Version 2.0 (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+ http://www.apache.org/licenses/LICENSE-2.0
6
+ Unless required by applicable law or agreed to in writing, software
7
+ distributed under the License is distributed on an "AS IS" BASIS,
8
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ See the License for the specific language governing permissions and
10
+ limitations under the License.
11
+ ==============================================================================*/
12
+ #ifndef XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
13
+ #define XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
14
+
15
+ #include <array>
16
+ #include <string>
17
+ #include <utility>
18
+
19
+ #include "absl/container/flat_hash_map.h"
20
+ #include "absl/strings/string_view.h"
21
+ #include "absl/types/span.h"
22
+ #include "xla/pjrt/pjrt_device_description.h"
23
+
24
+ namespace xla {
25
+
26
+ class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
27
+ public:
28
+ explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
29
+ int process_index = 0)
30
+ : id_(id),
31
+ process_index_(process_index),
32
+ device_kind_(std::move(device_kind)) {}
33
+
34
+ int id() const override { return id_; }
35
+
36
+ int process_index() const override { return process_index_; }
37
+
38
+ absl::string_view device_kind() const override { return device_kind_; }
39
+
40
+ absl::string_view ToString() const override { return to_string_; }
41
+
42
+ absl::string_view DebugString() const override { return debug_string_; }
43
+
44
+ absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
45
+
46
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
47
+ const override {
48
+ return attributes_;
49
+ }
50
+
51
+ void SetAttributes(
52
+ absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
53
+ attributes_ = std::move(attributes);
54
+ }
55
+
56
+ void SetDebugString(std::string debug_string) {
57
+ debug_string_ = std::move(debug_string);
58
+ }
59
+
60
+ void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
61
+
62
+ void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
63
+
64
+ private:
65
+ const int id_;
66
+ const int process_index_;
67
+ const std::string device_kind_;
68
+ std::string debug_string_ = "<unknown SE device>";
69
+ std::string to_string_ = "<unknown SE device>";
70
+ absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
71
+ std::array<int, 1> coords_;
72
+ };
73
+ } // namespace xla
74
+
75
+ #endif // XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
@@ -0,0 +1,57 @@
1
+ /* Copyright 2025 The OpenXLA Authors.
2
+
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ ==============================================================================*/
15
+
16
+ #ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
17
+ #define XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
18
+
19
+ #include <optional>
20
+
21
+ #include "xla/backends/cpu/collectives/cpu_collectives.h"
22
+ #include "xla/pjrt/pjrt_executable.h"
23
+
24
+ namespace xla {
25
+
26
+ // ExecuteContext for XLA:CPU PjRtLoadedExecutable::Execute calls.
27
+ class CpuExecuteContext : public ExecuteContext {
28
+ public:
29
+ ~CpuExecuteContext() override = default;
30
+
31
+ // If specified, override the process ID specified in
32
+ // `CpuClientOptions::process_id` for a particular call of
33
+ // PjRtLoadedExecutable::Execute.
34
+ //
35
+ // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
36
+ // option with `ExecuteOptions::multi_slice_config`.
37
+ std::optional<int>& process_index() { return process_index_; }
38
+ std::optional<int> process_index() const { return process_index_; }
39
+
40
+ // If specified, override CPU collectives specified in
41
+ // `CpuClientOptions::collectives` for a particular call of
42
+ // PjRtLoadedExecutable::Execute. Must remain valid until the execution
43
+ // finishes.
44
+ //
45
+ // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
46
+ // option with `ExecuteOptions::multi_slice_config`.
47
+ cpu::CpuCollectives*& collectives() { return collectives_; }
48
+ cpu::CpuCollectives* collectives() const { return collectives_; }
49
+
50
+ private:
51
+ std::optional<int> process_index_;
52
+ cpu::CpuCollectives* collectives_ = nullptr;
53
+ };
54
+
55
+ } // namespace xla
56
+
57
+ #endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
@@ -69,6 +69,10 @@ inline int UnpackCpuProcessIndex(PjRtGlobalDeviceId global_device_id) {
69
69
  return global_device_id.value() / kMaxCpuDevicesPerProcess;
70
70
  }
71
71
 
72
+ inline int UnpackCpuLocalDeviceId(PjRtGlobalDeviceId global_device_id) {
73
+ return global_device_id.value() % kMaxCpuDevicesPerProcess;
74
+ }
75
+
72
76
  } // namespace xla
73
77
 
74
78
  #endif // XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_
@@ -18,6 +18,7 @@ limitations under the License.
18
18
 
19
19
  #include <string>
20
20
 
21
+ #include "absl/base/casts.h"
21
22
  #include "absl/status/statusor.h"
22
23
  #include "xla/literal.h"
23
24
  #include "xla/util.h"