tf-nightly-cpu 2.20.0.dev20250220__cp39-cp39-win_amd64.whl → 2.20.0.dev20250221__cp39-cp39-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (115) hide show
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  14. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  15. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  17. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  18. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -18
  19. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  20. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  21. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  22. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  23. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  24. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  25. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  26. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  27. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  28. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  30. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  31. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  32. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  33. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  34. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  35. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  36. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  37. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  38. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  39. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  40. tensorflow/include/tensorflow/core/public/version.h +112 -127
  41. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  42. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  43. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  44. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -18
  45. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  46. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  47. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  48. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  49. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  50. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  51. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  52. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  53. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  54. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  55. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  56. tensorflow/include/xla/service/constant_value.h +1 -0
  57. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  58. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  59. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  60. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  61. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  62. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  63. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  64. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  65. tensorflow/python/_pywrap_mlir.pyd +0 -0
  66. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  67. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  68. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  69. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  70. tensorflow/python/_pywrap_tfe.pyd +0 -0
  71. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  72. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  73. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  74. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  75. tensorflow/python/compat/compat.py +1 -1
  76. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  77. tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyd +0 -0
  78. tensorflow/python/eager/imperative_grad.py +5 -5
  79. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  80. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  81. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  82. tensorflow/python/eager/tape.py +2 -2
  83. tensorflow/python/framework/_dtypes.pyd +0 -0
  84. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  85. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  86. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  87. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  88. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  89. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  90. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  91. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  92. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  93. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  94. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  95. tensorflow/python/ops/summary_ops_v2.py +5 -1
  96. tensorflow/python/platform/_pywrap_tf2.pyd +0 -0
  97. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  98. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  99. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  100. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  101. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  102. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  103. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  104. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  105. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  106. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  107. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  108. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  109. tensorflow/python/util/_tf_stack.pyd +0 -0
  110. tensorflow/tools/pip_package/setup.py +2 -2
  111. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/METADATA +1 -1
  112. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/RECORD +115 -108
  113. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/WHEEL +0 -0
  114. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/entry_points.txt +0 -0
  115. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
3
3
  #ifdef GEN_PASS_DECL
4
4
  // Generate declarations for all passes.
5
5
  #define GEN_PASS_DECL_CHLORECOMPOSEOPSPASS
6
+ #define GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
6
7
  #define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
7
8
  #define GEN_PASS_DECL_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
8
9
  #define GEN_PASS_DECL_STABLEHLOFLATTENTUPLEPASS
@@ -87,6 +88,82 @@ std::unique_ptr<::mlir::Pass> createChloRecomposeOpsPass() {
87
88
  #undef GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
88
89
  #endif // GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
89
90
 
91
+ //===----------------------------------------------------------------------===//
92
+ // StablehloAddQDQAfterConvPass
93
+ //===----------------------------------------------------------------------===//
94
+ #ifdef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
95
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
96
+ #undef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
97
+ #endif // GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
98
+ #ifdef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
99
+
100
+ namespace impl {
101
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
102
+ } // namespace impl
103
+ namespace impl {
104
+
105
+ template <typename DerivedT>
106
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
107
+ public:
108
+ using Base = StablehloAddQDQAfterConvPassBase;
109
+
110
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
111
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
112
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
113
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
114
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
115
+ ~StablehloAddQDQAfterConvPassBase() = default;
116
+
117
+ /// Returns the command-line argument attached to this pass.
118
+ static constexpr ::llvm::StringLiteral getArgumentName() {
119
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
120
+ }
121
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
122
+
123
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
124
+
125
+ /// Returns the derived pass name.
126
+ static constexpr ::llvm::StringLiteral getPassName() {
127
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
128
+ }
129
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
130
+
131
+ /// Support isa/dyn_cast functionality for the derived pass class.
132
+ static bool classof(const ::mlir::Pass *pass) {
133
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
134
+ }
135
+
136
+ /// A clone method to create a copy of this pass.
137
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
138
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
139
+ }
140
+
141
+ /// Return the dialect that must be loaded in the context before this pass.
142
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
143
+ registry.insert<mlir::quant::QuantDialect>();
144
+ registry.insert<stablehlo::StablehloDialect>();
145
+ }
146
+
147
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
148
+ /// instantiation because Pass classes should only be visible by the current
149
+ /// library.
150
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
151
+
152
+ protected:
153
+ private:
154
+
155
+ friend std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
156
+ return std::make_unique<DerivedT>();
157
+ }
158
+ };
159
+ } // namespace impl
160
+
161
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
162
+ return impl::createStablehloAddQDQAfterConvPass();
163
+ }
164
+ #undef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
165
+ #endif // GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
166
+
90
167
  //===----------------------------------------------------------------------===//
91
168
  // StablehloCanonicalizeDynamismPass
92
169
  //===----------------------------------------------------------------------===//
@@ -360,9 +437,9 @@ public:
360
437
 
361
438
  /// Returns the command-line argument attached to this pass.
362
439
  static constexpr ::llvm::StringLiteral getArgumentName() {
363
- return ::llvm::StringLiteral("legalize-quant-composite");
440
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
364
441
  }
365
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
442
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
366
443
 
367
444
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
368
445
 
@@ -576,6 +653,23 @@ inline void registerChloRecomposeOpsPassPass() {
576
653
  });
577
654
  }
578
655
 
656
+ //===----------------------------------------------------------------------===//
657
+ // StablehloAddQDQAfterConvPass Registration
658
+ //===----------------------------------------------------------------------===//
659
+
660
+ inline void registerStablehloAddQDQAfterConvPass() {
661
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
662
+ return createStablehloAddQDQAfterConvPass();
663
+ });
664
+ }
665
+
666
+ // Old registration code, kept for temporary backwards compatibility.
667
+ inline void registerStablehloAddQDQAfterConvPassPass() {
668
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
669
+ return createStablehloAddQDQAfterConvPass();
670
+ });
671
+ }
672
+
579
673
  //===----------------------------------------------------------------------===//
580
674
  // StablehloCanonicalizeDynamismPass Registration
581
675
  //===----------------------------------------------------------------------===//
@@ -684,6 +778,7 @@ inline void registerStablehloRefineShapesPassPass() {
684
778
 
685
779
  inline void registerPasses() {
686
780
  registerChloRecomposeOpsPass();
781
+ registerStablehloAddQDQAfterConvPass();
687
782
  registerStablehloCanonicalizeDynamismPass();
688
783
  registerStablehloFlattenEntryFunctionTuplesPass();
689
784
  registerStablehloFlattenTuplePass();
@@ -745,6 +840,56 @@ public:
745
840
  protected:
746
841
  };
747
842
 
843
+ template <typename DerivedT>
844
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
845
+ public:
846
+ using Base = StablehloAddQDQAfterConvPassBase;
847
+
848
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
849
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
850
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
851
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
852
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
853
+ ~StablehloAddQDQAfterConvPassBase() = default;
854
+
855
+ /// Returns the command-line argument attached to this pass.
856
+ static constexpr ::llvm::StringLiteral getArgumentName() {
857
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
858
+ }
859
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
860
+
861
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
862
+
863
+ /// Returns the derived pass name.
864
+ static constexpr ::llvm::StringLiteral getPassName() {
865
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
866
+ }
867
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
868
+
869
+ /// Support isa/dyn_cast functionality for the derived pass class.
870
+ static bool classof(const ::mlir::Pass *pass) {
871
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
872
+ }
873
+
874
+ /// A clone method to create a copy of this pass.
875
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
876
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
877
+ }
878
+
879
+ /// Register the dialects that must be loaded in the context before this pass.
880
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
881
+ registry.insert<mlir::quant::QuantDialect>();
882
+ registry.insert<stablehlo::StablehloDialect>();
883
+ }
884
+
885
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
886
+ /// instantiation because Pass classes should only be visible by the current
887
+ /// library.
888
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
889
+
890
+ protected:
891
+ };
892
+
748
893
  template <typename DerivedT>
749
894
  class StablehloCanonicalizeDynamismPassBase : public ::mlir::OperationPass<func::FuncOp> {
750
895
  public:
@@ -907,9 +1052,9 @@ public:
907
1052
 
908
1053
  /// Returns the command-line argument attached to this pass.
909
1054
  static constexpr ::llvm::StringLiteral getArgumentName() {
910
- return ::llvm::StringLiteral("legalize-quant-composite");
1055
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
911
1056
  }
912
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
1057
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
913
1058
 
914
1059
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
915
1060
 
@@ -3,6 +3,7 @@
3
3
  #ifdef GEN_PASS_DECL
4
4
  // Generate declarations for all passes.
5
5
  #define GEN_PASS_DECL_CHLORECOMPOSEOPSPASS
6
+ #define GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
6
7
  #define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
7
8
  #define GEN_PASS_DECL_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
8
9
  #define GEN_PASS_DECL_STABLEHLOFLATTENTUPLEPASS
@@ -87,6 +88,82 @@ std::unique_ptr<::mlir::Pass> createChloRecomposeOpsPass() {
87
88
  #undef GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
88
89
  #endif // GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
89
90
 
91
+ //===----------------------------------------------------------------------===//
92
+ // StablehloAddQDQAfterConvPass
93
+ //===----------------------------------------------------------------------===//
94
+ #ifdef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
95
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
96
+ #undef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
97
+ #endif // GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
98
+ #ifdef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
99
+
100
+ namespace impl {
101
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
102
+ } // namespace impl
103
+ namespace impl {
104
+
105
+ template <typename DerivedT>
106
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
107
+ public:
108
+ using Base = StablehloAddQDQAfterConvPassBase;
109
+
110
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
111
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
112
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
113
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
114
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
115
+ ~StablehloAddQDQAfterConvPassBase() = default;
116
+
117
+ /// Returns the command-line argument attached to this pass.
118
+ static constexpr ::llvm::StringLiteral getArgumentName() {
119
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
120
+ }
121
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
122
+
123
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
124
+
125
+ /// Returns the derived pass name.
126
+ static constexpr ::llvm::StringLiteral getPassName() {
127
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
128
+ }
129
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
130
+
131
+ /// Support isa/dyn_cast functionality for the derived pass class.
132
+ static bool classof(const ::mlir::Pass *pass) {
133
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
134
+ }
135
+
136
+ /// A clone method to create a copy of this pass.
137
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
138
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
139
+ }
140
+
141
+ /// Return the dialect that must be loaded in the context before this pass.
142
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
143
+ registry.insert<mlir::quant::QuantDialect>();
144
+ registry.insert<stablehlo::StablehloDialect>();
145
+ }
146
+
147
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
148
+ /// instantiation because Pass classes should only be visible by the current
149
+ /// library.
150
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
151
+
152
+ protected:
153
+ private:
154
+
155
+ friend std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
156
+ return std::make_unique<DerivedT>();
157
+ }
158
+ };
159
+ } // namespace impl
160
+
161
+ std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
162
+ return impl::createStablehloAddQDQAfterConvPass();
163
+ }
164
+ #undef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
165
+ #endif // GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
166
+
90
167
  //===----------------------------------------------------------------------===//
91
168
  // StablehloCanonicalizeDynamismPass
92
169
  //===----------------------------------------------------------------------===//
@@ -360,9 +437,9 @@ public:
360
437
 
361
438
  /// Returns the command-line argument attached to this pass.
362
439
  static constexpr ::llvm::StringLiteral getArgumentName() {
363
- return ::llvm::StringLiteral("legalize-quant-composite");
440
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
364
441
  }
365
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
442
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
366
443
 
367
444
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
368
445
 
@@ -576,6 +653,23 @@ inline void registerChloRecomposeOpsPassPass() {
576
653
  });
577
654
  }
578
655
 
656
+ //===----------------------------------------------------------------------===//
657
+ // StablehloAddQDQAfterConvPass Registration
658
+ //===----------------------------------------------------------------------===//
659
+
660
+ inline void registerStablehloAddQDQAfterConvPass() {
661
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
662
+ return createStablehloAddQDQAfterConvPass();
663
+ });
664
+ }
665
+
666
+ // Old registration code, kept for temporary backwards compatibility.
667
+ inline void registerStablehloAddQDQAfterConvPassPass() {
668
+ ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
669
+ return createStablehloAddQDQAfterConvPass();
670
+ });
671
+ }
672
+
579
673
  //===----------------------------------------------------------------------===//
580
674
  // StablehloCanonicalizeDynamismPass Registration
581
675
  //===----------------------------------------------------------------------===//
@@ -684,6 +778,7 @@ inline void registerStablehloRefineShapesPassPass() {
684
778
 
685
779
  inline void registerPasses() {
686
780
  registerChloRecomposeOpsPass();
781
+ registerStablehloAddQDQAfterConvPass();
687
782
  registerStablehloCanonicalizeDynamismPass();
688
783
  registerStablehloFlattenEntryFunctionTuplesPass();
689
784
  registerStablehloFlattenTuplePass();
@@ -745,6 +840,56 @@ public:
745
840
  protected:
746
841
  };
747
842
 
843
+ template <typename DerivedT>
844
+ class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
845
+ public:
846
+ using Base = StablehloAddQDQAfterConvPassBase;
847
+
848
+ StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
849
+ StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
850
+ StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
851
+ StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
852
+ StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
853
+ ~StablehloAddQDQAfterConvPassBase() = default;
854
+
855
+ /// Returns the command-line argument attached to this pass.
856
+ static constexpr ::llvm::StringLiteral getArgumentName() {
857
+ return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
858
+ }
859
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
860
+
861
+ ::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
862
+
863
+ /// Returns the derived pass name.
864
+ static constexpr ::llvm::StringLiteral getPassName() {
865
+ return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
866
+ }
867
+ ::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
868
+
869
+ /// Support isa/dyn_cast functionality for the derived pass class.
870
+ static bool classof(const ::mlir::Pass *pass) {
871
+ return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
872
+ }
873
+
874
+ /// A clone method to create a copy of this pass.
875
+ std::unique_ptr<::mlir::Pass> clonePass() const override {
876
+ return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
877
+ }
878
+
879
+ /// Register the dialects that must be loaded in the context before this pass.
880
+ void getDependentDialects(::mlir::DialectRegistry &registry) const override {
881
+ registry.insert<mlir::quant::QuantDialect>();
882
+ registry.insert<stablehlo::StablehloDialect>();
883
+ }
884
+
885
+ /// Explicitly declare the TypeID for this class. We declare an explicit private
886
+ /// instantiation because Pass classes should only be visible by the current
887
+ /// library.
888
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
889
+
890
+ protected:
891
+ };
892
+
748
893
  template <typename DerivedT>
749
894
  class StablehloCanonicalizeDynamismPassBase : public ::mlir::OperationPass<func::FuncOp> {
750
895
  public:
@@ -907,9 +1052,9 @@ public:
907
1052
 
908
1053
  /// Returns the command-line argument attached to this pass.
909
1054
  static constexpr ::llvm::StringLiteral getArgumentName() {
910
- return ::llvm::StringLiteral("legalize-quant-composite");
1055
+ return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
911
1056
  }
912
- ::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
1057
+ ::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
913
1058
 
914
1059
  ::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
915
1060
 
@@ -145,6 +145,11 @@ class DistributedRuntimeClient {
145
145
  std::string barrier_id, absl::Duration timeout,
146
146
  std::optional<absl::Span<const int32_t>> nodes) = 0;
147
147
 
148
+ // Returns the subset of live nodes. See CoordinationService.GetAliveTasks for
149
+ // detailed semantics.
150
+ virtual absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
151
+ absl::Span<const int32_t> nodes) = 0;
152
+
148
153
  // Returns pointer to coordination service agent, or InternalError if the
149
154
  // client does not use coordination service.
150
155
  virtual absl::StatusOr<tsl::CoordinationServiceAgent*>
@@ -37,6 +37,7 @@ limitations under the License.
37
37
  #include "xla/pjrt/distributed/key_value_store_interface.h"
38
38
  #include "xla/pjrt/gpu/gpu_topology.h"
39
39
  #include "xla/pjrt/gpu/gpu_topology.pb.h"
40
+ #include "xla/pjrt/gpu/se_gpu_topology_description.h"
40
41
  #include "xla/pjrt/local_device_state.h"
41
42
  #include "xla/pjrt/pjrt_client.h"
42
43
  #include "xla/pjrt/pjrt_compiler.h"
@@ -57,98 +58,6 @@ using DeviceTopologyPair =
57
58
  std::pair<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>,
58
59
  GpuTopologyProto>;
59
60
 
60
- class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
61
- public:
62
- StreamExecutorGpuTopologyDescription(
63
- const PjRtPlatformId platform_id, const absl::string_view platform_name,
64
- std::shared_ptr<const GpuTopology> gpu_topology,
65
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
66
- {},
67
- std::optional<stream_executor::GpuTargetConfigProto> target_config =
68
- std::nullopt)
69
- : platform_id_(platform_id),
70
- platform_name_(platform_name),
71
- gpu_topology_(std::move(gpu_topology)),
72
- attributes_(attributes),
73
- target_config_(std::move(target_config)) {}
74
-
75
- bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
76
- return this->platform_id() == other.platform_id() &&
77
- this->platform_name() == other.platform_name() &&
78
- this->platform_version() == other.platform_version() &&
79
- this->gpu_topology() == other.gpu_topology();
80
- }
81
-
82
- PjRtPlatformId platform_id() const override { return platform_id_; }
83
-
84
- absl::string_view platform_name() const override { return platform_name_; }
85
-
86
- absl::string_view platform_version() const override {
87
- return gpu_topology_->platform_version();
88
- }
89
-
90
- std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
91
- const override {
92
- std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
93
- devices.reserve(gpu_topology_->number_of_devices());
94
- for (const int device_id : gpu_topology_->device_ids()) {
95
- devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
96
- device_id, std::string(platform_version())));
97
- }
98
- return devices;
99
- }
100
-
101
- const GpuTopology& gpu_topology() const { return *gpu_topology_; }
102
- const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
103
-
104
- // No subslice is supported.
105
- bool is_subslice_topology() const override { return false; }
106
-
107
- absl::StatusOr<int> ProcessCount() const override {
108
- return gpu_topology_->number_of_hosts();
109
- }
110
-
111
- absl::StatusOr<int> CoreCountOfDefaultType() const override {
112
- return gpu_topology_->number_of_devices();
113
- }
114
-
115
- absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
116
- return gpu_topology_->number_of_devices();
117
- }
118
-
119
- absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
120
- return gpu_topology_->number_of_devices();
121
- }
122
-
123
- absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
124
- return 1;
125
- }
126
-
127
- absl::StatusOr<std::string> Serialize() const override;
128
-
129
- const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
130
- const {
131
- return target_config_;
132
- }
133
-
134
- // Returns vendor specific attributes about the topology.
135
- const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
136
- const override {
137
- return attributes_;
138
- }
139
-
140
- absl::StatusOr<Layout> GetDefaultLayout(
141
- PrimitiveType element_type,
142
- absl::Span<const int64_t> dims) const override;
143
-
144
- private:
145
- const PjRtPlatformId platform_id_;
146
- const std::string platform_name_;
147
- std::shared_ptr<const GpuTopology> gpu_topology_;
148
- absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
149
- std::optional<stream_executor::GpuTargetConfigProto> target_config_;
150
- };
151
-
152
61
  class StreamExecutorGpuDevice : public PjRtStreamExecutorDevice {
153
62
  public:
154
63
  StreamExecutorGpuDevice(int id,
@@ -0,0 +1,126 @@
1
+ /* Copyright 2025 The OpenXLA Authors.
2
+ Licensed under the Apache License, Version 2.0 (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+ http://www.apache.org/licenses/LICENSE-2.0
6
+ Unless required by applicable law or agreed to in writing, software
7
+ distributed under the License is distributed on an "AS IS" BASIS,
8
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ See the License for the specific language governing permissions and
10
+ limitations under the License.
11
+ ==============================================================================*/
12
+ #ifndef XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
13
+ #define XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
14
+
15
+ #include <cstdint>
16
+ #include <memory>
17
+ #include <optional>
18
+ #include <string>
19
+ #include <utility>
20
+ #include <vector>
21
+
22
+ #include "absl/container/flat_hash_map.h"
23
+ #include "absl/status/statusor.h"
24
+ #include "absl/strings/string_view.h"
25
+ #include "absl/types/span.h"
26
+ #include "xla/pjrt/gpu/gpu_topology.h"
27
+ #include "xla/pjrt/pjrt_compiler.h"
28
+ #include "xla/pjrt/pjrt_device_description.h"
29
+ #include "xla/pjrt/pjrt_stream_executor_device_description.h"
30
+
31
+ namespace xla {
32
+
33
+ class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
34
+ public:
35
+ StreamExecutorGpuTopologyDescription(
36
+ const PjRtPlatformId platform_id, const absl::string_view platform_name,
37
+ std::shared_ptr<const GpuTopology> gpu_topology,
38
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
39
+ {},
40
+ std::optional<stream_executor::GpuTargetConfigProto> target_config =
41
+ std::nullopt)
42
+ : platform_id_(platform_id),
43
+ platform_name_(platform_name),
44
+ gpu_topology_(std::move(gpu_topology)),
45
+ attributes_(attributes),
46
+ target_config_(std::move(target_config)) {}
47
+
48
+ bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
49
+ return this->platform_id() == other.platform_id() &&
50
+ this->platform_name() == other.platform_name() &&
51
+ this->platform_version() == other.platform_version() &&
52
+ this->gpu_topology() == other.gpu_topology();
53
+ }
54
+
55
+ PjRtPlatformId platform_id() const override { return platform_id_; }
56
+
57
+ absl::string_view platform_name() const override { return platform_name_; }
58
+
59
+ absl::string_view platform_version() const override {
60
+ return gpu_topology_->platform_version();
61
+ }
62
+
63
+ std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
64
+ const override {
65
+ std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
66
+ devices.reserve(gpu_topology_->number_of_devices());
67
+ for (const int device_id : gpu_topology_->device_ids()) {
68
+ devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
69
+ device_id, std::string(platform_version())));
70
+ }
71
+ return devices;
72
+ }
73
+
74
+ const GpuTopology& gpu_topology() const { return *gpu_topology_; }
75
+ const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
76
+
77
+ // No subslice is supported.
78
+ bool is_subslice_topology() const override { return false; }
79
+
80
+ absl::StatusOr<int> ProcessCount() const override {
81
+ return gpu_topology_->number_of_hosts();
82
+ }
83
+
84
+ absl::StatusOr<int> CoreCountOfDefaultType() const override {
85
+ return gpu_topology_->number_of_devices();
86
+ }
87
+
88
+ absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
89
+ return gpu_topology_->number_of_devices();
90
+ }
91
+
92
+ absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
93
+ return gpu_topology_->number_of_devices();
94
+ }
95
+
96
+ absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
97
+ return 1;
98
+ }
99
+
100
+ absl::StatusOr<std::string> Serialize() const override;
101
+
102
+ const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
103
+ const {
104
+ return target_config_;
105
+ }
106
+
107
+ // Returns vendor specific attributes about the topology.
108
+ const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
109
+ const override {
110
+ return attributes_;
111
+ }
112
+
113
+ absl::StatusOr<Layout> GetDefaultLayout(
114
+ PrimitiveType element_type,
115
+ absl::Span<const int64_t> dims) const override;
116
+
117
+ private:
118
+ const PjRtPlatformId platform_id_;
119
+ const std::string platform_name_;
120
+ std::shared_ptr<const GpuTopology> gpu_topology_;
121
+ absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
122
+ std::optional<stream_executor::GpuTargetConfigProto> target_config_;
123
+ };
124
+ } // namespace xla
125
+
126
+ #endif // XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_