tf-nightly-cpu 2.20.0.dev20250220__cp310-cp310-win_amd64.whl → 2.20.0.dev20250222__cp310-cp310-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (128) hide show
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
  14. tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
  15. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  17. tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
  18. tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
  19. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  20. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  21. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  22. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  23. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -19
  24. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  25. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  26. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
  27. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  28. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  30. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  31. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  32. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  33. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  34. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  35. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  36. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  37. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  38. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  39. tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
  40. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  41. tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
  42. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  43. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  44. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  45. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  46. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  47. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  48. tensorflow/include/tensorflow/core/public/version.h +112 -127
  49. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  50. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  51. tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  52. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  53. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -19
  54. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  55. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  56. tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
  57. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  58. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  59. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  60. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  61. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  62. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  63. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  64. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  65. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  66. tensorflow/include/xla/service/constant_value.h +1 -0
  67. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  68. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  69. tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
  70. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  71. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  72. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  73. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  74. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  75. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  76. tensorflow/python/_pywrap_mlir.pyd +0 -0
  77. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  78. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  79. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  80. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  81. tensorflow/python/_pywrap_tfe.pyd +0 -0
  82. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  83. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  84. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  85. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  86. tensorflow/python/compat/compat.py +1 -1
  87. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  88. tensorflow/python/eager/imperative_grad.py +5 -5
  89. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  90. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  91. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  92. tensorflow/python/eager/tape.py +2 -2
  93. tensorflow/python/framework/_dtypes.pyd +0 -0
  94. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  95. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  96. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  97. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  98. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  99. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  100. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  101. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  102. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  103. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  104. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  105. tensorflow/python/ops/summary_ops_v2.py +5 -1
  106. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  107. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  108. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  109. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  110. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  111. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  112. tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
  113. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  114. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  115. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  116. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  117. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  118. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  119. tensorflow/python/util/_tf_stack.pyd +0 -0
  120. tensorflow/tools/pip_package/setup.py +2 -2
  121. tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
  122. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
  123. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +126 -121
  124. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
  125. tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
  126. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
  127. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
  128. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
@@ -44,28 +44,6 @@ T* Cast(HloInstruction* instr) {
44
44
  return tsl::down_cast<T*>(instr);
45
45
  }
46
46
 
47
- // Downcasts a const HloInstruction pointer or returns nullptr if argument is
48
- // nullptr. Dies if TargetClass::ClassOf() does not match.
49
- template <typename T>
50
- const T* CastOrNull(const HloInstruction* i) {
51
- if (i == nullptr) {
52
- return nullptr;
53
- }
54
- CHECK(T::ClassOf(i));
55
- return tsl::down_cast<const T*>(i);
56
- }
57
-
58
- // Downcasts a const HloInstruction pointer or returns nullptr if argument is
59
- // nullptr. Dies if TargetClass::ClassOf() does not match.
60
- template <typename T>
61
- T* CastOrNull(HloInstruction* i) {
62
- if (i == nullptr) {
63
- return nullptr;
64
- }
65
- CHECK(T::ClassOf(i));
66
- return tsl::down_cast<T*>(i);
67
- }
68
-
69
47
  // Downcasts a const HloInstruction pointer or returns nullptr if
70
48
  // TargetClass::ClassOf() does not match. Dies if argument is nullptr. Similar
71
49
  // to LLVM's dyn_cast.
@@ -84,28 +62,6 @@ T* DynCast(HloInstruction* i) {
84
62
  return !T::ClassOf(i) ? nullptr : tsl::down_cast<T*>(i);
85
63
  }
86
64
 
87
- // Downcasts a const HloInstruction pointer. Return nullptr if argument is
88
- // nullptr orTargetClass::ClassOf() does not match. Similar to LLVM's
89
- // dyn_cast_or_null.
90
- template <typename T>
91
- const T* DynCastOrNull(const HloInstruction* instruction) {
92
- if (instruction == nullptr || !T::ClassOf(instruction)) {
93
- return nullptr;
94
- }
95
- return tsl::down_cast<const T*>(instruction);
96
- }
97
-
98
- // Downcasts a non-const HloInstruction pointer. Return nullptr if argument is
99
- // nullptr orTargetClass::ClassOf() does not match. Similar to LLVM's
100
- // dyn_cast_or_null.
101
- template <typename T>
102
- T* DynCastOrNull(HloInstruction* instruction) {
103
- if (instruction == nullptr || !T::ClassOf(instruction)) {
104
- return nullptr;
105
- }
106
- return tsl::down_cast<T*>(instruction);
107
- }
108
-
109
65
  } // namespace xla
110
66
 
111
67
  #endif // XLA_HLO_IR_HLO_CASTING_UTILS_H_
@@ -1914,6 +1914,18 @@ class HloInstruction {
1914
1914
  result_accuracy().mode() != ResultAccuracy::DEFAULT);
1915
1915
  }
1916
1916
 
1917
+ bool equal_result_accuracy(const HloInstruction* other) const {
1918
+ return result_accuracy().has_tolerance() ==
1919
+ other->result_accuracy().has_tolerance() &&
1920
+ result_accuracy().tolerance().atol() ==
1921
+ other->result_accuracy().tolerance().atol() &&
1922
+ result_accuracy().tolerance().rtol() ==
1923
+ other->result_accuracy().tolerance().rtol() &&
1924
+ result_accuracy().tolerance().ulps() ==
1925
+ other->result_accuracy().tolerance().ulps() &&
1926
+ result_accuracy().mode() == other->result_accuracy().mode();
1927
+ }
1928
+
1917
1929
  void add_single_statistic(Statistic statistic) {
1918
1930
  *mutable_rare()->statistics_viz.add_statistics() = std::move(statistic);
1919
1931
  }
@@ -3,6 +3,7 @@
3
3
  #ifdef GEN_PASS_DECL
4
4
  // Generate declarations for all passes.
5
5
  #define GEN_PASS_DECL_CHLORECOMPOSEOPSPASS
6
+ #define GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
6
7
  #define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
7
8
  #define GEN_PASS_DECL_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
8
9
  #define GEN_PASS_DECL_STABLEHLOFLATTENTUPLEPASS
@@ -87,6 +88,82 @@ std::unique_ptr<::mlir::Pass> createChloRecomposeOpsPass() {
87
88
  #undef GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
88
89
  #endif // GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
89
90
 
91
+ //===----------------------------------------------------------------------===//
92
+ // StablehloAddQDQAfterConvPass
93
+ //===----------------------------------------------------------------------===//
94
+ #ifdef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
95
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
96
+ #undef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
97
+ #endif // GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
98
+ #ifdef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
99
+
100
+ namespace impl {
101
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
102
+ } // namespace impl
103
+ namespace impl {
104
+
105
+ template <typename DerivedT>
106
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
107
+ public:
108
+ using Base = StablehloAddQDQAfterConvPassBase;
109
+
110
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
111
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
112
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
113
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
114
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
115
+ ~StablehloAddQDQAfterConvPassBase() = default;
116
+
117
+ /// Returns the command-line argument attached to this pass.
118
+ static constexpr ::llvm::StringLiteral getArgumentName() {
119
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
120
+ }
121
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
122
+
123
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
124
+
125
+ /// Returns the derived pass name.
126
+ static constexpr ::llvm::StringLiteral getPassName() {
127
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
128
+ }
129
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
130
+
131
+ /// Support isa/dyn_cast functionality for the derived pass class.
132
+ static bool classof(const ::mlir::Pass *pass) {
133
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
134
+ }
135
+
136
+ /// A clone method to create a copy of this pass.
137
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
138
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
139
+ }
140
+
141
+ /// Return the dialect that must be loaded in the context before this pass.
142
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
143
+ registry.insert<mlir::quant::QuantDialect>();
144
+ registry.insert<stablehlo::StablehloDialect>();
145
+ }
146
+
147
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
148
+ /// instantiation because Pass classes should only be visible by the current
149
+ /// library.
150
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
151
+
152
+ protected:
153
+ private:
154
+
155
+ friend std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
156
+ return std::make_unique<DerivedT>();
157
+ }
158
+ };
159
+ } // namespace impl
160
+
161
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
162
+ return impl::createStablehloAddQDQAfterConvPass();
163
+ }
164
+ #undef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
165
+ #endif // GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
166
+
90
167
  //===----------------------------------------------------------------------===//
91
168
  // StablehloCanonicalizeDynamismPass
92
169
  //===----------------------------------------------------------------------===//
@@ -360,9 +437,9 @@ public:
360
437
 
361
438
  /// Returns the command-line argument attached to this pass.
362
439
  static constexpr ::llvm::StringLiteral getArgumentName() {
363
- return ::llvm::StringLiteral("legalize-quant-composite");
440
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
364
441
  }
365
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
442
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
366
443
 
367
444
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
368
445
 
@@ -576,6 +653,23 @@ inline void registerChloRecomposeOpsPassPass() {
576
653
  });
577
654
  }
578
655
 
656
+ //===----------------------------------------------------------------------===//
657
+ // StablehloAddQDQAfterConvPass Registration
658
+ //===----------------------------------------------------------------------===//
659
+
660
+ inline void registerStablehloAddQDQAfterConvPass() {
661
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
662
+ return createStablehloAddQDQAfterConvPass();
663
+ });
664
+ }
665
+
666
+ // Old registration code, kept for temporary backwards compatibility.
667
+ inline void registerStablehloAddQDQAfterConvPassPass() {
668
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
669
+ return createStablehloAddQDQAfterConvPass();
670
+ });
671
+ }
672
+
579
673
  //===----------------------------------------------------------------------===//
580
674
  // StablehloCanonicalizeDynamismPass Registration
581
675
  //===----------------------------------------------------------------------===//
@@ -684,6 +778,7 @@ inline void registerStablehloRefineShapesPassPass() {
684
778
 
685
779
  inline void registerPasses() {
686
780
  registerChloRecomposeOpsPass();
781
+ registerStablehloAddQDQAfterConvPass();
687
782
  registerStablehloCanonicalizeDynamismPass();
688
783
  registerStablehloFlattenEntryFunctionTuplesPass();
689
784
  registerStablehloFlattenTuplePass();
@@ -745,6 +840,56 @@ public:
745
840
  protected:
746
841
  };
747
842
 
843
+ template <typename DerivedT>
844
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
845
+ public:
846
+ using Base = StablehloAddQDQAfterConvPassBase;
847
+
848
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
849
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
850
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
851
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
852
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
853
+ ~StablehloAddQDQAfterConvPassBase() = default;
854
+
855
+ /// Returns the command-line argument attached to this pass.
856
+ static constexpr ::llvm::StringLiteral getArgumentName() {
857
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
858
+ }
859
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
860
+
861
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
862
+
863
+ /// Returns the derived pass name.
864
+ static constexpr ::llvm::StringLiteral getPassName() {
865
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
866
+ }
867
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
868
+
869
+ /// Support isa/dyn_cast functionality for the derived pass class.
870
+ static bool classof(const ::mlir::Pass *pass) {
871
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
872
+ }
873
+
874
+ /// A clone method to create a copy of this pass.
875
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
876
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
877
+ }
878
+
879
+ /// Register the dialects that must be loaded in the context before this pass.
880
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
881
+ registry.insert<mlir::quant::QuantDialect>();
882
+ registry.insert<stablehlo::StablehloDialect>();
883
+ }
884
+
885
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
886
+ /// instantiation because Pass classes should only be visible by the current
887
+ /// library.
888
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
889
+
890
+ protected:
891
+ };
892
+
748
893
  template <typename DerivedT>
749
894
  class StablehloCanonicalizeDynamismPassBase : public ::mlir::OperationPass<func::FuncOp> {
750
895
  public:
@@ -907,9 +1052,9 @@ public:
907
1052
 
908
1053
  /// Returns the command-line argument attached to this pass.
909
1054
  static constexpr ::llvm::StringLiteral getArgumentName() {
910
- return ::llvm::StringLiteral("legalize-quant-composite");
1055
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
911
1056
  }
912
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
1057
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
913
1058
 
914
1059
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
915
1060
 
@@ -3,6 +3,7 @@
3
3
  #ifdef GEN_PASS_DECL
4
4
  // Generate declarations for all passes.
5
5
  #define GEN_PASS_DECL_CHLORECOMPOSEOPSPASS
6
+ #define GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
6
7
  #define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
7
8
  #define GEN_PASS_DECL_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
8
9
  #define GEN_PASS_DECL_STABLEHLOFLATTENTUPLEPASS
@@ -87,6 +88,82 @@ std::unique_ptr<::mlir::Pass> createChloRecomposeOpsPass() {
87
88
  #undef GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
88
89
  #endif // GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
89
90
 
91
+ //===----------------------------------------------------------------------===//
92
+ // StablehloAddQDQAfterConvPass
93
+ //===----------------------------------------------------------------------===//
94
+ #ifdef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
95
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
96
+ #undef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
97
+ #endif // GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
98
+ #ifdef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
99
+
100
+ namespace impl {
101
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
102
+ } // namespace impl
103
+ namespace impl {
104
+
105
+ template <typename DerivedT>
106
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
107
+ public:
108
+ using Base = StablehloAddQDQAfterConvPassBase;
109
+
110
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
111
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
112
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
113
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
114
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
115
+ ~StablehloAddQDQAfterConvPassBase() = default;
116
+
117
+ /// Returns the command-line argument attached to this pass.
118
+ static constexpr ::llvm::StringLiteral getArgumentName() {
119
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
120
+ }
121
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
122
+
123
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
124
+
125
+ /// Returns the derived pass name.
126
+ static constexpr ::llvm::StringLiteral getPassName() {
127
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
128
+ }
129
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
130
+
131
+ /// Support isa/dyn_cast functionality for the derived pass class.
132
+ static bool classof(const ::mlir::Pass *pass) {
133
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
134
+ }
135
+
136
+ /// A clone method to create a copy of this pass.
137
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
138
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
139
+ }
140
+
141
+ /// Return the dialect that must be loaded in the context before this pass.
142
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
143
+ registry.insert<mlir::quant::QuantDialect>();
144
+ registry.insert<stablehlo::StablehloDialect>();
145
+ }
146
+
147
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
148
+ /// instantiation because Pass classes should only be visible by the current
149
+ /// library.
150
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
151
+
152
+ protected:
153
+ private:
154
+
155
+ friend std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
156
+ return std::make_unique<DerivedT>();
157
+ }
158
+ };
159
+ } // namespace impl
160
+
161
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
162
+ return impl::createStablehloAddQDQAfterConvPass();
163
+ }
164
+ #undef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
165
+ #endif // GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
166
+
90
167
  //===----------------------------------------------------------------------===//
91
168
  // StablehloCanonicalizeDynamismPass
92
169
  //===----------------------------------------------------------------------===//
@@ -360,9 +437,9 @@ public:
360
437
 
361
438
  /// Returns the command-line argument attached to this pass.
362
439
  static constexpr ::llvm::StringLiteral getArgumentName() {
363
- return ::llvm::StringLiteral("legalize-quant-composite");
440
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
364
441
  }
365
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
442
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
366
443
 
367
444
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
368
445
 
@@ -576,6 +653,23 @@ inline void registerChloRecomposeOpsPassPass() {
576
653
  });
577
654
  }
578
655
 
656
+ //===----------------------------------------------------------------------===//
657
+ // StablehloAddQDQAfterConvPass Registration
658
+ //===----------------------------------------------------------------------===//
659
+
660
+ inline void registerStablehloAddQDQAfterConvPass() {
661
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
662
+ return createStablehloAddQDQAfterConvPass();
663
+ });
664
+ }
665
+
666
+ // Old registration code, kept for temporary backwards compatibility.
667
+ inline void registerStablehloAddQDQAfterConvPassPass() {
668
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
669
+ return createStablehloAddQDQAfterConvPass();
670
+ });
671
+ }
672
+
579
673
  //===----------------------------------------------------------------------===//
580
674
  // StablehloCanonicalizeDynamismPass Registration
581
675
  //===----------------------------------------------------------------------===//
@@ -684,6 +778,7 @@ inline void registerStablehloRefineShapesPassPass() {
684
778
 
685
779
  inline void registerPasses() {
686
780
  registerChloRecomposeOpsPass();
781
+ registerStablehloAddQDQAfterConvPass();
687
782
  registerStablehloCanonicalizeDynamismPass();
688
783
  registerStablehloFlattenEntryFunctionTuplesPass();
689
784
  registerStablehloFlattenTuplePass();
@@ -745,6 +840,56 @@ public:
745
840
  protected:
746
841
  };
747
842
 
843
+ template <typename DerivedT>
844
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
845
+ public:
846
+ using Base = StablehloAddQDQAfterConvPassBase;
847
+
848
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
849
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
850
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
851
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
852
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
853
+ ~StablehloAddQDQAfterConvPassBase() = default;
854
+
855
+ /// Returns the command-line argument attached to this pass.
856
+ static constexpr ::llvm::StringLiteral getArgumentName() {
857
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
858
+ }
859
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
860
+
861
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
862
+
863
+ /// Returns the derived pass name.
864
+ static constexpr ::llvm::StringLiteral getPassName() {
865
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
866
+ }
867
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
868
+
869
+ /// Support isa/dyn_cast functionality for the derived pass class.
870
+ static bool classof(const ::mlir::Pass *pass) {
871
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
872
+ }
873
+
874
+ /// A clone method to create a copy of this pass.
875
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
876
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
877
+ }
878
+
879
+ /// Register the dialects that must be loaded in the context before this pass.
880
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
881
+ registry.insert<mlir::quant::QuantDialect>();
882
+ registry.insert<stablehlo::StablehloDialect>();
883
+ }
884
+
885
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
886
+ /// instantiation because Pass classes should only be visible by the current
887
+ /// library.
888
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
889
+
890
+ protected:
891
+ };
892
+
748
893
  template <typename DerivedT>
749
894
  class StablehloCanonicalizeDynamismPassBase : public ::mlir::OperationPass<func::FuncOp> {
750
895
  public:
@@ -907,9 +1052,9 @@ public:
907
1052
 
908
1053
  /// Returns the command-line argument attached to this pass.
909
1054
  static constexpr ::llvm::StringLiteral getArgumentName() {
910
- return ::llvm::StringLiteral("legalize-quant-composite");
1055
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
911
1056
  }
912
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
1057
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
913
1058
 
914
1059
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
915
1060
 
@@ -145,6 +145,11 @@ class DistributedRuntimeClient {
145
145
  std::string barrier_id, absl::Duration timeout,
146
146
  std::optional<absl::Span<const int32_t>> nodes) = 0;
147
147
 
148
+ // Returns the subset of live nodes. See CoordinationService.GetAliveTasks for
149
+ // detailed semantics.
150
+ virtual absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
151
+ absl::Span<const int32_t> nodes) = 0;
152
+
148
153
  // Returns pointer to coordination service agent, or InternalError if the
149
154
  // client does not use coordination service.
150
155
  virtual absl::StatusOr<tsl::CoordinationServiceAgent*>
@@ -37,6 +37,7 @@ limitations under the License.
37
37
  #include "xla/pjrt/distributed/key_value_store_interface.h"
38
38
  #include "xla/pjrt/gpu/gpu_topology.h"
39
39
  #include "xla/pjrt/gpu/gpu_topology.pb.h"
40
+ #include "xla/pjrt/gpu/se_gpu_topology_description.h"
40
41
  #include "xla/pjrt/local_device_state.h"
41
42
  #include "xla/pjrt/pjrt_client.h"
42
43
  #include "xla/pjrt/pjrt_compiler.h"
@@ -57,98 +58,6 @@ using DeviceTopologyPair =
57
58
  std::pair<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>,
58
59
  GpuTopologyProto>;
59
60
 
60
- class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
61
- public:
62
- StreamExecutorGpuTopologyDescription(
63
- const PjRtPlatformId platform_id, const absl::string_view platform_name,
64
- std::shared_ptr<const GpuTopology> gpu_topology,
65
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
66
- {},
67
- std::optional<stream_executor::GpuTargetConfigProto> target_config =
68
- std::nullopt)
69
- : platform_id_(platform_id),
70
- platform_name_(platform_name),
71
- gpu_topology_(std::move(gpu_topology)),
72
- attributes_(attributes),
73
- target_config_(std::move(target_config)) {}
74
-
75
- bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
76
- return this->platform_id() == other.platform_id() &&
77
- this->platform_name() == other.platform_name() &&
78
- this->platform_version() == other.platform_version() &&
79
- this->gpu_topology() == other.gpu_topology();
80
- }
81
-
82
- PjRtPlatformId platform_id() const override { return platform_id_; }
83
-
84
- absl::string_view platform_name() const override { return platform_name_; }
85
-
86
- absl::string_view platform_version() const override {
87
- return gpu_topology_->platform_version();
88
- }
89
-
90
- std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
91
- const override {
92
- std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
93
- devices.reserve(gpu_topology_->number_of_devices());
94
- for (const int device_id : gpu_topology_->device_ids()) {
95
- devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
96
- device_id, std::string(platform_version())));
97
- }
98
- return devices;
99
- }
100
-
101
- const GpuTopology& gpu_topology() const { return *gpu_topology_; }
102
- const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
103
-
104
- // No subslice is supported.
105
- bool is_subslice_topology() const override { return false; }
106
-
107
- absl::StatusOr<int> ProcessCount() const override {
108
- return gpu_topology_->number_of_hosts();
109
- }
110
-
111
- absl::StatusOr<int> CoreCountOfDefaultType() const override {
112
- return gpu_topology_->number_of_devices();
113
- }
114
-
115
- absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
116
- return gpu_topology_->number_of_devices();
117
- }
118
-
119
- absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
120
- return gpu_topology_->number_of_devices();
121
- }
122
-
123
- absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
124
- return 1;
125
- }
126
-
127
- absl::StatusOr<std::string> Serialize() const override;
128
-
129
- const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
130
- const {
131
- return target_config_;
132
- }
133
-
134
- // Returns vendor specific attributes about the topology.
135
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
136
- const override {
137
- return attributes_;
138
- }
139
-
140
- absl::StatusOr<Layout> GetDefaultLayout(
141
- PrimitiveType element_type,
142
- absl::Span<const int64_t> dims) const override;
143
-
144
- private:
145
- const PjRtPlatformId platform_id_;
146
- const std::string platform_name_;
147
- std::shared_ptr<const GpuTopology> gpu_topology_;
148
- absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
149
- std::optional<stream_executor::GpuTargetConfigProto> target_config_;
150
- };
151
-
152
61
  class StreamExecutorGpuDevice : public PjRtStreamExecutorDevice {
153
62
  public:
154
63
  StreamExecutorGpuDevice(int id,