tf-nightly-cpu 2.20.0.dev20250220__cp312-cp312-win_amd64.whl → 2.20.0.dev20250221__cp312-cp312-win_amd64.whl
Sign up to get free protection for your applications and to get access to all the features.
- tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
- tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
- tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
- tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -18
- tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
- tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
- tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
- tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
- tensorflow/include/tensorflow/core/public/release_version.h +39 -0
- tensorflow/include/tensorflow/core/public/version.h +112 -127
- tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
- tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -18
- tensorflow/include/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/xla/service/constant_value.h +1 -0
- tensorflow/include/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
- tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
- tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
- tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
- tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
- tensorflow/python/_pywrap_mlir.pyd +0 -0
- tensorflow/python/_pywrap_parallel_device.pyd +0 -0
- tensorflow/python/_pywrap_quantize_training.pyd +0 -0
- tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
- tensorflow/python/_pywrap_tfcompile.pyd +0 -0
- tensorflow/python/_pywrap_tfe.pyd +0 -0
- tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
- tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
- tensorflow/python/compat/compat.py +1 -1
- tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
- tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyd +0 -0
- tensorflow/python/eager/imperative_grad.py +5 -5
- tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
- tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
- tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
- tensorflow/python/eager/tape.py +2 -2
- tensorflow/python/framework/_dtypes.pyd +0 -0
- tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
- tensorflow/python/framework/_op_def_registry.pyd +0 -0
- tensorflow/python/framework/_proto_comparators.pyd +0 -0
- tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
- tensorflow/python/framework/_test_metrics_util.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
- tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
- tensorflow/python/ops/summary_ops_v2.py +5 -1
- tensorflow/python/platform/_pywrap_tf2.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
- tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
- tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
- tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
- tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
- tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
- tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
- tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
- tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
- tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
- tensorflow/python/util/_pywrap_utils.pyd +0 -0
- tensorflow/python/util/_tf_stack.pyd +0 -0
- tensorflow/tools/pip_package/setup.py +2 -2
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/METADATA +1 -1
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/RECORD +115 -108
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/WHEEL +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/entry_points.txt +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
|
|
3
3
|
#ifdef GEN_PASS_DECL
|
4
4
|
// Generate declarations for all passes.
|
5
5
|
#define GEN_PASS_DECL_CHLORECOMPOSEOPSPASS
|
6
|
+
#define GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
6
7
|
#define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
|
7
8
|
#define GEN_PASS_DECL_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
|
8
9
|
#define GEN_PASS_DECL_STABLEHLOFLATTENTUPLEPASS
|
@@ -87,6 +88,82 @@ std::unique_ptr<::mlir::Pass> createChloRecomposeOpsPass() {
|
|
87
88
|
#undef GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
|
88
89
|
#endif // GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
|
89
90
|
|
91
|
+
//===----------------------------------------------------------------------===//
|
92
|
+
// StablehloAddQDQAfterConvPass
|
93
|
+
//===----------------------------------------------------------------------===//
|
94
|
+
#ifdef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
95
|
+
std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
|
96
|
+
#undef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
97
|
+
#endif // GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
98
|
+
#ifdef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
|
99
|
+
|
100
|
+
namespace impl {
|
101
|
+
std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
|
102
|
+
} // namespace impl
|
103
|
+
namespace impl {
|
104
|
+
|
105
|
+
template <typename DerivedT>
|
106
|
+
class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
|
107
|
+
public:
|
108
|
+
using Base = StablehloAddQDQAfterConvPassBase;
|
109
|
+
|
110
|
+
StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
|
111
|
+
StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
|
112
|
+
StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
|
113
|
+
StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
|
114
|
+
StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
|
115
|
+
~StablehloAddQDQAfterConvPassBase() = default;
|
116
|
+
|
117
|
+
/// Returns the command-line argument attached to this pass.
|
118
|
+
static constexpr ::llvm::StringLiteral getArgumentName() {
|
119
|
+
return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
|
120
|
+
}
|
121
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
|
122
|
+
|
123
|
+
::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
|
124
|
+
|
125
|
+
/// Returns the derived pass name.
|
126
|
+
static constexpr ::llvm::StringLiteral getPassName() {
|
127
|
+
return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
|
128
|
+
}
|
129
|
+
::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
|
130
|
+
|
131
|
+
/// Support isa/dyn_cast functionality for the derived pass class.
|
132
|
+
static bool classof(const ::mlir::Pass *pass) {
|
133
|
+
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
|
134
|
+
}
|
135
|
+
|
136
|
+
/// A clone method to create a copy of this pass.
|
137
|
+
std::unique_ptr<::mlir::Pass> clonePass() const override {
|
138
|
+
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
|
139
|
+
}
|
140
|
+
|
141
|
+
/// Return the dialect that must be loaded in the context before this pass.
|
142
|
+
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
143
|
+
registry.insert<mlir::quant::QuantDialect>();
|
144
|
+
registry.insert<stablehlo::StablehloDialect>();
|
145
|
+
}
|
146
|
+
|
147
|
+
/// Explicitly declare the TypeID for this class. We declare an explicit private
|
148
|
+
/// instantiation because Pass classes should only be visible by the current
|
149
|
+
/// library.
|
150
|
+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
|
151
|
+
|
152
|
+
protected:
|
153
|
+
private:
|
154
|
+
|
155
|
+
friend std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
|
156
|
+
return std::make_unique<DerivedT>();
|
157
|
+
}
|
158
|
+
};
|
159
|
+
} // namespace impl
|
160
|
+
|
161
|
+
std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
|
162
|
+
return impl::createStablehloAddQDQAfterConvPass();
|
163
|
+
}
|
164
|
+
#undef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
|
165
|
+
#endif // GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
|
166
|
+
|
90
167
|
//===----------------------------------------------------------------------===//
|
91
168
|
// StablehloCanonicalizeDynamismPass
|
92
169
|
//===----------------------------------------------------------------------===//
|
@@ -360,9 +437,9 @@ public:
|
|
360
437
|
|
361
438
|
/// Returns the command-line argument attached to this pass.
|
362
439
|
static constexpr ::llvm::StringLiteral getArgumentName() {
|
363
|
-
return ::llvm::StringLiteral("legalize-quant-composite");
|
440
|
+
return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
|
364
441
|
}
|
365
|
-
::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
|
442
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
|
366
443
|
|
367
444
|
::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
|
368
445
|
|
@@ -576,6 +653,23 @@ inline void registerChloRecomposeOpsPassPass() {
|
|
576
653
|
});
|
577
654
|
}
|
578
655
|
|
656
|
+
//===----------------------------------------------------------------------===//
|
657
|
+
// StablehloAddQDQAfterConvPass Registration
|
658
|
+
//===----------------------------------------------------------------------===//
|
659
|
+
|
660
|
+
inline void registerStablehloAddQDQAfterConvPass() {
|
661
|
+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
662
|
+
return createStablehloAddQDQAfterConvPass();
|
663
|
+
});
|
664
|
+
}
|
665
|
+
|
666
|
+
// Old registration code, kept for temporary backwards compatibility.
|
667
|
+
inline void registerStablehloAddQDQAfterConvPassPass() {
|
668
|
+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
669
|
+
return createStablehloAddQDQAfterConvPass();
|
670
|
+
});
|
671
|
+
}
|
672
|
+
|
579
673
|
//===----------------------------------------------------------------------===//
|
580
674
|
// StablehloCanonicalizeDynamismPass Registration
|
581
675
|
//===----------------------------------------------------------------------===//
|
@@ -684,6 +778,7 @@ inline void registerStablehloRefineShapesPassPass() {
|
|
684
778
|
|
685
779
|
inline void registerPasses() {
|
686
780
|
registerChloRecomposeOpsPass();
|
781
|
+
registerStablehloAddQDQAfterConvPass();
|
687
782
|
registerStablehloCanonicalizeDynamismPass();
|
688
783
|
registerStablehloFlattenEntryFunctionTuplesPass();
|
689
784
|
registerStablehloFlattenTuplePass();
|
@@ -745,6 +840,56 @@ public:
|
|
745
840
|
protected:
|
746
841
|
};
|
747
842
|
|
843
|
+
template <typename DerivedT>
|
844
|
+
class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
|
845
|
+
public:
|
846
|
+
using Base = StablehloAddQDQAfterConvPassBase;
|
847
|
+
|
848
|
+
StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
|
849
|
+
StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
|
850
|
+
StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
|
851
|
+
StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
|
852
|
+
StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
|
853
|
+
~StablehloAddQDQAfterConvPassBase() = default;
|
854
|
+
|
855
|
+
/// Returns the command-line argument attached to this pass.
|
856
|
+
static constexpr ::llvm::StringLiteral getArgumentName() {
|
857
|
+
return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
|
858
|
+
}
|
859
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
|
860
|
+
|
861
|
+
::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
|
862
|
+
|
863
|
+
/// Returns the derived pass name.
|
864
|
+
static constexpr ::llvm::StringLiteral getPassName() {
|
865
|
+
return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
|
866
|
+
}
|
867
|
+
::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
|
868
|
+
|
869
|
+
/// Support isa/dyn_cast functionality for the derived pass class.
|
870
|
+
static bool classof(const ::mlir::Pass *pass) {
|
871
|
+
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
|
872
|
+
}
|
873
|
+
|
874
|
+
/// A clone method to create a copy of this pass.
|
875
|
+
std::unique_ptr<::mlir::Pass> clonePass() const override {
|
876
|
+
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
|
877
|
+
}
|
878
|
+
|
879
|
+
/// Register the dialects that must be loaded in the context before this pass.
|
880
|
+
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
881
|
+
registry.insert<mlir::quant::QuantDialect>();
|
882
|
+
registry.insert<stablehlo::StablehloDialect>();
|
883
|
+
}
|
884
|
+
|
885
|
+
/// Explicitly declare the TypeID for this class. We declare an explicit private
|
886
|
+
/// instantiation because Pass classes should only be visible by the current
|
887
|
+
/// library.
|
888
|
+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
|
889
|
+
|
890
|
+
protected:
|
891
|
+
};
|
892
|
+
|
748
893
|
template <typename DerivedT>
|
749
894
|
class StablehloCanonicalizeDynamismPassBase : public ::mlir::OperationPass<func::FuncOp> {
|
750
895
|
public:
|
@@ -907,9 +1052,9 @@ public:
|
|
907
1052
|
|
908
1053
|
/// Returns the command-line argument attached to this pass.
|
909
1054
|
static constexpr ::llvm::StringLiteral getArgumentName() {
|
910
|
-
return ::llvm::StringLiteral("legalize-quant-composite");
|
1055
|
+
return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
|
911
1056
|
}
|
912
|
-
::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
|
1057
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
|
913
1058
|
|
914
1059
|
::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
|
915
1060
|
|
@@ -3,6 +3,7 @@
|
|
3
3
|
#ifdef GEN_PASS_DECL
|
4
4
|
// Generate declarations for all passes.
|
5
5
|
#define GEN_PASS_DECL_CHLORECOMPOSEOPSPASS
|
6
|
+
#define GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
6
7
|
#define GEN_PASS_DECL_STABLEHLOCANONICALIZEDYNAMISMPASS
|
7
8
|
#define GEN_PASS_DECL_STABLEHLOFLATTENENTRYFUNCTIONTUPLESPASS
|
8
9
|
#define GEN_PASS_DECL_STABLEHLOFLATTENTUPLEPASS
|
@@ -87,6 +88,82 @@ std::unique_ptr<::mlir::Pass> createChloRecomposeOpsPass() {
|
|
87
88
|
#undef GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
|
88
89
|
#endif // GEN_PASS_DEF_CHLORECOMPOSEOPSPASS
|
89
90
|
|
91
|
+
//===----------------------------------------------------------------------===//
|
92
|
+
// StablehloAddQDQAfterConvPass
|
93
|
+
//===----------------------------------------------------------------------===//
|
94
|
+
#ifdef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
95
|
+
std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
|
96
|
+
#undef GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
97
|
+
#endif // GEN_PASS_DECL_STABLEHLOADDQDQAFTERCONVPASS
|
98
|
+
#ifdef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
|
99
|
+
|
100
|
+
namespace impl {
|
101
|
+
std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass();
|
102
|
+
} // namespace impl
|
103
|
+
namespace impl {
|
104
|
+
|
105
|
+
template <typename DerivedT>
|
106
|
+
class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
|
107
|
+
public:
|
108
|
+
using Base = StablehloAddQDQAfterConvPassBase;
|
109
|
+
|
110
|
+
StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
|
111
|
+
StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
|
112
|
+
StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
|
113
|
+
StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
|
114
|
+
StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
|
115
|
+
~StablehloAddQDQAfterConvPassBase() = default;
|
116
|
+
|
117
|
+
/// Returns the command-line argument attached to this pass.
|
118
|
+
static constexpr ::llvm::StringLiteral getArgumentName() {
|
119
|
+
return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
|
120
|
+
}
|
121
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
|
122
|
+
|
123
|
+
::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
|
124
|
+
|
125
|
+
/// Returns the derived pass name.
|
126
|
+
static constexpr ::llvm::StringLiteral getPassName() {
|
127
|
+
return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
|
128
|
+
}
|
129
|
+
::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
|
130
|
+
|
131
|
+
/// Support isa/dyn_cast functionality for the derived pass class.
|
132
|
+
static bool classof(const ::mlir::Pass *pass) {
|
133
|
+
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
|
134
|
+
}
|
135
|
+
|
136
|
+
/// A clone method to create a copy of this pass.
|
137
|
+
std::unique_ptr<::mlir::Pass> clonePass() const override {
|
138
|
+
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
|
139
|
+
}
|
140
|
+
|
141
|
+
/// Return the dialect that must be loaded in the context before this pass.
|
142
|
+
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
143
|
+
registry.insert<mlir::quant::QuantDialect>();
|
144
|
+
registry.insert<stablehlo::StablehloDialect>();
|
145
|
+
}
|
146
|
+
|
147
|
+
/// Explicitly declare the TypeID for this class. We declare an explicit private
|
148
|
+
/// instantiation because Pass classes should only be visible by the current
|
149
|
+
/// library.
|
150
|
+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
|
151
|
+
|
152
|
+
protected:
|
153
|
+
private:
|
154
|
+
|
155
|
+
friend std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
|
156
|
+
return std::make_unique<DerivedT>();
|
157
|
+
}
|
158
|
+
};
|
159
|
+
} // namespace impl
|
160
|
+
|
161
|
+
std::unique_ptr<::mlir::Pass> createStablehloAddQDQAfterConvPass() {
|
162
|
+
return impl::createStablehloAddQDQAfterConvPass();
|
163
|
+
}
|
164
|
+
#undef GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
|
165
|
+
#endif // GEN_PASS_DEF_STABLEHLOADDQDQAFTERCONVPASS
|
166
|
+
|
90
167
|
//===----------------------------------------------------------------------===//
|
91
168
|
// StablehloCanonicalizeDynamismPass
|
92
169
|
//===----------------------------------------------------------------------===//
|
@@ -360,9 +437,9 @@ public:
|
|
360
437
|
|
361
438
|
/// Returns the command-line argument attached to this pass.
|
362
439
|
static constexpr ::llvm::StringLiteral getArgumentName() {
|
363
|
-
return ::llvm::StringLiteral("legalize-quant-composite");
|
440
|
+
return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
|
364
441
|
}
|
365
|
-
::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
|
442
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
|
366
443
|
|
367
444
|
::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
|
368
445
|
|
@@ -576,6 +653,23 @@ inline void registerChloRecomposeOpsPassPass() {
|
|
576
653
|
});
|
577
654
|
}
|
578
655
|
|
656
|
+
//===----------------------------------------------------------------------===//
|
657
|
+
// StablehloAddQDQAfterConvPass Registration
|
658
|
+
//===----------------------------------------------------------------------===//
|
659
|
+
|
660
|
+
inline void registerStablehloAddQDQAfterConvPass() {
|
661
|
+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
662
|
+
return createStablehloAddQDQAfterConvPass();
|
663
|
+
});
|
664
|
+
}
|
665
|
+
|
666
|
+
// Old registration code, kept for temporary backwards compatibility.
|
667
|
+
inline void registerStablehloAddQDQAfterConvPassPass() {
|
668
|
+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
|
669
|
+
return createStablehloAddQDQAfterConvPass();
|
670
|
+
});
|
671
|
+
}
|
672
|
+
|
579
673
|
//===----------------------------------------------------------------------===//
|
580
674
|
// StablehloCanonicalizeDynamismPass Registration
|
581
675
|
//===----------------------------------------------------------------------===//
|
@@ -684,6 +778,7 @@ inline void registerStablehloRefineShapesPassPass() {
|
|
684
778
|
|
685
779
|
inline void registerPasses() {
|
686
780
|
registerChloRecomposeOpsPass();
|
781
|
+
registerStablehloAddQDQAfterConvPass();
|
687
782
|
registerStablehloCanonicalizeDynamismPass();
|
688
783
|
registerStablehloFlattenEntryFunctionTuplesPass();
|
689
784
|
registerStablehloFlattenTuplePass();
|
@@ -745,6 +840,56 @@ public:
|
|
745
840
|
protected:
|
746
841
|
};
|
747
842
|
|
843
|
+
template <typename DerivedT>
|
844
|
+
class StablehloAddQDQAfterConvPassBase : public ::mlir::OperationPass<ModuleOp> {
|
845
|
+
public:
|
846
|
+
using Base = StablehloAddQDQAfterConvPassBase;
|
847
|
+
|
848
|
+
StablehloAddQDQAfterConvPassBase() : ::mlir::OperationPass<ModuleOp>(::mlir::TypeID::get<DerivedT>()) {}
|
849
|
+
StablehloAddQDQAfterConvPassBase(const StablehloAddQDQAfterConvPassBase &other) : ::mlir::OperationPass<ModuleOp>(other) {}
|
850
|
+
StablehloAddQDQAfterConvPassBase& operator=(const StablehloAddQDQAfterConvPassBase &) = delete;
|
851
|
+
StablehloAddQDQAfterConvPassBase(StablehloAddQDQAfterConvPassBase &&) = delete;
|
852
|
+
StablehloAddQDQAfterConvPassBase& operator=(StablehloAddQDQAfterConvPassBase &&) = delete;
|
853
|
+
~StablehloAddQDQAfterConvPassBase() = default;
|
854
|
+
|
855
|
+
/// Returns the command-line argument attached to this pass.
|
856
|
+
static constexpr ::llvm::StringLiteral getArgumentName() {
|
857
|
+
return ::llvm::StringLiteral("stablehlo-ext-add-qdq-after-conv");
|
858
|
+
}
|
859
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-add-qdq-after-conv"; }
|
860
|
+
|
861
|
+
::llvm::StringRef getDescription() const override { return "Add quant and dequant ops after convolution op."; }
|
862
|
+
|
863
|
+
/// Returns the derived pass name.
|
864
|
+
static constexpr ::llvm::StringLiteral getPassName() {
|
865
|
+
return ::llvm::StringLiteral("StablehloAddQDQAfterConvPass");
|
866
|
+
}
|
867
|
+
::llvm::StringRef getName() const override { return "StablehloAddQDQAfterConvPass"; }
|
868
|
+
|
869
|
+
/// Support isa/dyn_cast functionality for the derived pass class.
|
870
|
+
static bool classof(const ::mlir::Pass *pass) {
|
871
|
+
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
|
872
|
+
}
|
873
|
+
|
874
|
+
/// A clone method to create a copy of this pass.
|
875
|
+
std::unique_ptr<::mlir::Pass> clonePass() const override {
|
876
|
+
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
|
877
|
+
}
|
878
|
+
|
879
|
+
/// Register the dialects that must be loaded in the context before this pass.
|
880
|
+
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
|
881
|
+
registry.insert<mlir::quant::QuantDialect>();
|
882
|
+
registry.insert<stablehlo::StablehloDialect>();
|
883
|
+
}
|
884
|
+
|
885
|
+
/// Explicitly declare the TypeID for this class. We declare an explicit private
|
886
|
+
/// instantiation because Pass classes should only be visible by the current
|
887
|
+
/// library.
|
888
|
+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StablehloAddQDQAfterConvPassBase<DerivedT>)
|
889
|
+
|
890
|
+
protected:
|
891
|
+
};
|
892
|
+
|
748
893
|
template <typename DerivedT>
|
749
894
|
class StablehloCanonicalizeDynamismPassBase : public ::mlir::OperationPass<func::FuncOp> {
|
750
895
|
public:
|
@@ -907,9 +1052,9 @@ public:
|
|
907
1052
|
|
908
1053
|
/// Returns the command-line argument attached to this pass.
|
909
1054
|
static constexpr ::llvm::StringLiteral getArgumentName() {
|
910
|
-
return ::llvm::StringLiteral("legalize-quant-composite");
|
1055
|
+
return ::llvm::StringLiteral("stablehlo-ext-legalize-quant-composite");
|
911
1056
|
}
|
912
|
-
::llvm::StringRef getArgument() const override { return "legalize-quant-composite"; }
|
1057
|
+
::llvm::StringRef getArgument() const override { return "stablehlo-ext-legalize-quant-composite"; }
|
913
1058
|
|
914
1059
|
::llvm::StringRef getDescription() const override { return "Lowers the quantization related composites op to native quantized ops."; }
|
915
1060
|
|
@@ -145,6 +145,11 @@ class DistributedRuntimeClient {
|
|
145
145
|
std::string barrier_id, absl::Duration timeout,
|
146
146
|
std::optional<absl::Span<const int32_t>> nodes) = 0;
|
147
147
|
|
148
|
+
// Returns the subset of live nodes. See CoordinationService.GetAliveTasks for
|
149
|
+
// detailed semantics.
|
150
|
+
virtual absl::StatusOr<std::vector<int32_t>> GetLiveNodes(
|
151
|
+
absl::Span<const int32_t> nodes) = 0;
|
152
|
+
|
148
153
|
// Returns pointer to coordination service agent, or InternalError if the
|
149
154
|
// client does not use coordination service.
|
150
155
|
virtual absl::StatusOr<tsl::CoordinationServiceAgent*>
|
@@ -37,6 +37,7 @@ limitations under the License.
|
|
37
37
|
#include "xla/pjrt/distributed/key_value_store_interface.h"
|
38
38
|
#include "xla/pjrt/gpu/gpu_topology.h"
|
39
39
|
#include "xla/pjrt/gpu/gpu_topology.pb.h"
|
40
|
+
#include "xla/pjrt/gpu/se_gpu_topology_description.h"
|
40
41
|
#include "xla/pjrt/local_device_state.h"
|
41
42
|
#include "xla/pjrt/pjrt_client.h"
|
42
43
|
#include "xla/pjrt/pjrt_compiler.h"
|
@@ -57,98 +58,6 @@ using DeviceTopologyPair =
|
|
57
58
|
std::pair<std::vector<std::unique_ptr<PjRtStreamExecutorDevice>>,
|
58
59
|
GpuTopologyProto>;
|
59
60
|
|
60
|
-
class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
|
61
|
-
public:
|
62
|
-
StreamExecutorGpuTopologyDescription(
|
63
|
-
const PjRtPlatformId platform_id, const absl::string_view platform_name,
|
64
|
-
std::shared_ptr<const GpuTopology> gpu_topology,
|
65
|
-
const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
|
66
|
-
{},
|
67
|
-
std::optional<stream_executor::GpuTargetConfigProto> target_config =
|
68
|
-
std::nullopt)
|
69
|
-
: platform_id_(platform_id),
|
70
|
-
platform_name_(platform_name),
|
71
|
-
gpu_topology_(std::move(gpu_topology)),
|
72
|
-
attributes_(attributes),
|
73
|
-
target_config_(std::move(target_config)) {}
|
74
|
-
|
75
|
-
bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
|
76
|
-
return this->platform_id() == other.platform_id() &&
|
77
|
-
this->platform_name() == other.platform_name() &&
|
78
|
-
this->platform_version() == other.platform_version() &&
|
79
|
-
this->gpu_topology() == other.gpu_topology();
|
80
|
-
}
|
81
|
-
|
82
|
-
PjRtPlatformId platform_id() const override { return platform_id_; }
|
83
|
-
|
84
|
-
absl::string_view platform_name() const override { return platform_name_; }
|
85
|
-
|
86
|
-
absl::string_view platform_version() const override {
|
87
|
-
return gpu_topology_->platform_version();
|
88
|
-
}
|
89
|
-
|
90
|
-
std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
|
91
|
-
const override {
|
92
|
-
std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
|
93
|
-
devices.reserve(gpu_topology_->number_of_devices());
|
94
|
-
for (const int device_id : gpu_topology_->device_ids()) {
|
95
|
-
devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
|
96
|
-
device_id, std::string(platform_version())));
|
97
|
-
}
|
98
|
-
return devices;
|
99
|
-
}
|
100
|
-
|
101
|
-
const GpuTopology& gpu_topology() const { return *gpu_topology_; }
|
102
|
-
const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
|
103
|
-
|
104
|
-
// No subslice is supported.
|
105
|
-
bool is_subslice_topology() const override { return false; }
|
106
|
-
|
107
|
-
absl::StatusOr<int> ProcessCount() const override {
|
108
|
-
return gpu_topology_->number_of_hosts();
|
109
|
-
}
|
110
|
-
|
111
|
-
absl::StatusOr<int> CoreCountOfDefaultType() const override {
|
112
|
-
return gpu_topology_->number_of_devices();
|
113
|
-
}
|
114
|
-
|
115
|
-
absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
|
116
|
-
return gpu_topology_->number_of_devices();
|
117
|
-
}
|
118
|
-
|
119
|
-
absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
|
120
|
-
return gpu_topology_->number_of_devices();
|
121
|
-
}
|
122
|
-
|
123
|
-
absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
|
124
|
-
return 1;
|
125
|
-
}
|
126
|
-
|
127
|
-
absl::StatusOr<std::string> Serialize() const override;
|
128
|
-
|
129
|
-
const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
|
130
|
-
const {
|
131
|
-
return target_config_;
|
132
|
-
}
|
133
|
-
|
134
|
-
// Returns vendor specific attributes about the topology.
|
135
|
-
const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
|
136
|
-
const override {
|
137
|
-
return attributes_;
|
138
|
-
}
|
139
|
-
|
140
|
-
absl::StatusOr<Layout> GetDefaultLayout(
|
141
|
-
PrimitiveType element_type,
|
142
|
-
absl::Span<const int64_t> dims) const override;
|
143
|
-
|
144
|
-
private:
|
145
|
-
const PjRtPlatformId platform_id_;
|
146
|
-
const std::string platform_name_;
|
147
|
-
std::shared_ptr<const GpuTopology> gpu_topology_;
|
148
|
-
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
|
149
|
-
std::optional<stream_executor::GpuTargetConfigProto> target_config_;
|
150
|
-
};
|
151
|
-
|
152
61
|
class StreamExecutorGpuDevice : public PjRtStreamExecutorDevice {
|
153
62
|
public:
|
154
63
|
StreamExecutorGpuDevice(int id,
|
@@ -0,0 +1,126 @@
|
|
1
|
+
/* Copyright 2025 The OpenXLA Authors.
|
2
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
you may not use this file except in compliance with the License.
|
4
|
+
You may obtain a copy of the License at
|
5
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
Unless required by applicable law or agreed to in writing, software
|
7
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
See the License for the specific language governing permissions and
|
10
|
+
limitations under the License.
|
11
|
+
==============================================================================*/
|
12
|
+
#ifndef XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
|
13
|
+
#define XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
|
14
|
+
|
15
|
+
#include <cstdint>
|
16
|
+
#include <memory>
|
17
|
+
#include <optional>
|
18
|
+
#include <string>
|
19
|
+
#include <utility>
|
20
|
+
#include <vector>
|
21
|
+
|
22
|
+
#include "absl/container/flat_hash_map.h"
|
23
|
+
#include "absl/status/statusor.h"
|
24
|
+
#include "absl/strings/string_view.h"
|
25
|
+
#include "absl/types/span.h"
|
26
|
+
#include "xla/pjrt/gpu/gpu_topology.h"
|
27
|
+
#include "xla/pjrt/pjrt_compiler.h"
|
28
|
+
#include "xla/pjrt/pjrt_device_description.h"
|
29
|
+
#include "xla/pjrt/pjrt_stream_executor_device_description.h"
|
30
|
+
|
31
|
+
namespace xla {
|
32
|
+
|
33
|
+
class StreamExecutorGpuTopologyDescription : public PjRtTopologyDescription {
|
34
|
+
public:
|
35
|
+
StreamExecutorGpuTopologyDescription(
|
36
|
+
const PjRtPlatformId platform_id, const absl::string_view platform_name,
|
37
|
+
std::shared_ptr<const GpuTopology> gpu_topology,
|
38
|
+
const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& attributes =
|
39
|
+
{},
|
40
|
+
std::optional<stream_executor::GpuTargetConfigProto> target_config =
|
41
|
+
std::nullopt)
|
42
|
+
: platform_id_(platform_id),
|
43
|
+
platform_name_(platform_name),
|
44
|
+
gpu_topology_(std::move(gpu_topology)),
|
45
|
+
attributes_(attributes),
|
46
|
+
target_config_(std::move(target_config)) {}
|
47
|
+
|
48
|
+
bool operator==(const StreamExecutorGpuTopologyDescription& other) const {
|
49
|
+
return this->platform_id() == other.platform_id() &&
|
50
|
+
this->platform_name() == other.platform_name() &&
|
51
|
+
this->platform_version() == other.platform_version() &&
|
52
|
+
this->gpu_topology() == other.gpu_topology();
|
53
|
+
}
|
54
|
+
|
55
|
+
PjRtPlatformId platform_id() const override { return platform_id_; }
|
56
|
+
|
57
|
+
absl::string_view platform_name() const override { return platform_name_; }
|
58
|
+
|
59
|
+
absl::string_view platform_version() const override {
|
60
|
+
return gpu_topology_->platform_version();
|
61
|
+
}
|
62
|
+
|
63
|
+
std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
|
64
|
+
const override {
|
65
|
+
std::vector<std::unique_ptr<const PjRtDeviceDescription>> devices;
|
66
|
+
devices.reserve(gpu_topology_->number_of_devices());
|
67
|
+
for (const int device_id : gpu_topology_->device_ids()) {
|
68
|
+
devices.push_back(std::make_unique<PjRtStreamExecutorDeviceDescription>(
|
69
|
+
device_id, std::string(platform_version())));
|
70
|
+
}
|
71
|
+
return devices;
|
72
|
+
}
|
73
|
+
|
74
|
+
const GpuTopology& gpu_topology() const { return *gpu_topology_; }
|
75
|
+
const GpuTopology* gpu_topology_ptr() const { return gpu_topology_.get(); }
|
76
|
+
|
77
|
+
// No subslice is supported.
|
78
|
+
bool is_subslice_topology() const override { return false; }
|
79
|
+
|
80
|
+
absl::StatusOr<int> ProcessCount() const override {
|
81
|
+
return gpu_topology_->number_of_hosts();
|
82
|
+
}
|
83
|
+
|
84
|
+
absl::StatusOr<int> CoreCountOfDefaultType() const override {
|
85
|
+
return gpu_topology_->number_of_devices();
|
86
|
+
}
|
87
|
+
|
88
|
+
absl::StatusOr<int> LogicalDeviceCountOfDefaultType() const override {
|
89
|
+
return gpu_topology_->number_of_devices();
|
90
|
+
}
|
91
|
+
|
92
|
+
absl::StatusOr<int> CoreCountOfDefaultTypePerProcess() const override {
|
93
|
+
return gpu_topology_->number_of_devices();
|
94
|
+
}
|
95
|
+
|
96
|
+
absl::StatusOr<int> CoreCountOfDefaultTypePerChip() const override {
|
97
|
+
return 1;
|
98
|
+
}
|
99
|
+
|
100
|
+
absl::StatusOr<std::string> Serialize() const override;
|
101
|
+
|
102
|
+
const std::optional<stream_executor::GpuTargetConfigProto>& target_config()
|
103
|
+
const {
|
104
|
+
return target_config_;
|
105
|
+
}
|
106
|
+
|
107
|
+
// Returns vendor specific attributes about the topology.
|
108
|
+
const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
|
109
|
+
const override {
|
110
|
+
return attributes_;
|
111
|
+
}
|
112
|
+
|
113
|
+
absl::StatusOr<Layout> GetDefaultLayout(
|
114
|
+
PrimitiveType element_type,
|
115
|
+
absl::Span<const int64_t> dims) const override;
|
116
|
+
|
117
|
+
private:
|
118
|
+
const PjRtPlatformId platform_id_;
|
119
|
+
const std::string platform_name_;
|
120
|
+
std::shared_ptr<const GpuTopology> gpu_topology_;
|
121
|
+
absl::flat_hash_map<std::string, xla::PjRtDeviceAttribute> attributes_;
|
122
|
+
std::optional<stream_executor::GpuTargetConfigProto> target_config_;
|
123
|
+
};
|
124
|
+
} // namespace xla
|
125
|
+
|
126
|
+
#endif // XLA_PJRT_GPU_SE_GPU_TOPOLOGY_DESCRIPTION_H_
|