tf-nightly-cpu 2.20.0.dev20250220__cp310-cp310-win_amd64.whl → 2.20.0.dev20250221__cp310-cp310-win_amd64.whl
Sign up to get free protection for your applications and to get access to all the features.
- tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
- tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
- tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
- tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -18
- tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
- tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
- tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
- tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
- tensorflow/include/tensorflow/core/public/release_version.h +39 -0
- tensorflow/include/tensorflow/core/public/version.h +112 -127
- tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
- tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -18
- tensorflow/include/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/xla/service/constant_value.h +1 -0
- tensorflow/include/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
- tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
- tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
- tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
- tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
- tensorflow/python/_pywrap_mlir.pyd +0 -0
- tensorflow/python/_pywrap_parallel_device.pyd +0 -0
- tensorflow/python/_pywrap_quantize_training.pyd +0 -0
- tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
- tensorflow/python/_pywrap_tfcompile.pyd +0 -0
- tensorflow/python/_pywrap_tfe.pyd +0 -0
- tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
- tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
- tensorflow/python/compat/compat.py +1 -1
- tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
- tensorflow/python/eager/imperative_grad.py +5 -5
- tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
- tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
- tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
- tensorflow/python/eager/tape.py +2 -2
- tensorflow/python/framework/_dtypes.pyd +0 -0
- tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
- tensorflow/python/framework/_op_def_registry.pyd +0 -0
- tensorflow/python/framework/_proto_comparators.pyd +0 -0
- tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
- tensorflow/python/framework/_test_metrics_util.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
- tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
- tensorflow/python/ops/summary_ops_v2.py +5 -1
- tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
- tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
- tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
- tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
- tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
- tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
- tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
- tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
- tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
- tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
- tensorflow/python/util/_pywrap_utils.pyd +0 -0
- tensorflow/python/util/_tf_stack.pyd +0 -0
- tensorflow/tools/pip_package/setup.py +2 -2
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/METADATA +1 -1
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/RECORD +113 -106
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/WHEEL +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/entry_points.txt +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/top_level.txt +0 -0
@@ -5,8 +5,8 @@
|
|
5
5
|
|
6
6
|
import sys as _sys
|
7
7
|
|
8
|
-
from tensorflow.python.ops.summary_ops_v2 import all_v2_summary_ops # line:
|
9
|
-
from tensorflow.python.ops.summary_ops_v2 import initialize # line:
|
8
|
+
from tensorflow.python.ops.summary_ops_v2 import all_v2_summary_ops # line: 665
|
9
|
+
from tensorflow.python.ops.summary_ops_v2 import initialize # line: 477
|
10
10
|
from tensorflow.python.proto_exports import Event # line: 28
|
11
11
|
from tensorflow.python.proto_exports import SessionLog # line: 47
|
12
12
|
from tensorflow.python.proto_exports import Summary # line: 50
|
@@ -19,8 +19,8 @@ from tensorflow.python.tpu.tpu_embedding_v2_utils import QuantizationConfig # li
|
|
19
19
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import RowIdInitializer # line: 1347
|
20
20
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import SGD # line: 363
|
21
21
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import TableConfig # line: 1161
|
22
|
-
from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line:
|
23
|
-
from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line:
|
22
|
+
from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 78
|
23
|
+
from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 482
|
24
24
|
|
25
25
|
from tensorflow.python.util import module_wrapper as _module_wrapper
|
26
26
|
|
@@ -6,17 +6,17 @@
|
|
6
6
|
import sys as _sys
|
7
7
|
|
8
8
|
from tensorflow._api.v2.compat.v2.summary import experimental
|
9
|
-
from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line:
|
10
|
-
from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line:
|
11
|
-
from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line:
|
12
|
-
from tensorflow.python.ops.summary_ops_v2 import flush # line:
|
13
|
-
from tensorflow.python.ops.summary_ops_v2 import graph # line:
|
14
|
-
from tensorflow.python.ops.summary_ops_v2 import record_if # line:
|
9
|
+
from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line: 248
|
10
|
+
from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line: 520
|
11
|
+
from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line: 645
|
12
|
+
from tensorflow.python.ops.summary_ops_v2 import flush # line: 1145
|
13
|
+
from tensorflow.python.ops.summary_ops_v2 import graph # line: 1057
|
14
|
+
from tensorflow.python.ops.summary_ops_v2 import record_if # line: 161
|
15
15
|
from tensorflow.python.ops.summary_ops_v2 import should_record_summaries # line: 133
|
16
|
-
from tensorflow.python.ops.summary_ops_v2 import trace_export # line:
|
17
|
-
from tensorflow.python.ops.summary_ops_v2 import trace_off # line:
|
18
|
-
from tensorflow.python.ops.summary_ops_v2 import trace_on # line:
|
19
|
-
from tensorflow.python.ops.summary_ops_v2 import write # line:
|
16
|
+
from tensorflow.python.ops.summary_ops_v2 import trace_export # line: 1394
|
17
|
+
from tensorflow.python.ops.summary_ops_v2 import trace_off # line: 1447
|
18
|
+
from tensorflow.python.ops.summary_ops_v2 import trace_on # line: 1338
|
19
|
+
from tensorflow.python.ops.summary_ops_v2 import write # line: 741
|
20
20
|
from tensorflow.python.summary.tb_summary import audio # line: 32
|
21
21
|
from tensorflow.python.summary.tb_summary import histogram # line: 89
|
22
22
|
from tensorflow.python.summary.tb_summary import image # line: 165
|
@@ -5,7 +5,7 @@
|
|
5
5
|
|
6
6
|
import sys as _sys
|
7
7
|
|
8
|
-
from tensorflow.python.ops.summary_ops_v2 import get_step # line:
|
9
|
-
from tensorflow.python.ops.summary_ops_v2 import set_step # line:
|
10
|
-
from tensorflow.python.ops.summary_ops_v2 import summary_scope # line:
|
11
|
-
from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line:
|
8
|
+
from tensorflow.python.ops.summary_ops_v2 import get_step # line: 218
|
9
|
+
from tensorflow.python.ops.summary_ops_v2 import set_step # line: 229
|
10
|
+
from tensorflow.python.ops.summary_ops_v2 import summary_scope # line: 700
|
11
|
+
from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line: 818
|
@@ -19,5 +19,5 @@ from tensorflow.python.tpu.tpu_embedding_v2_utils import QuantizationConfig # li
|
|
19
19
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import RowIdInitializer # line: 1347
|
20
20
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import SGD # line: 363
|
21
21
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import TableConfig # line: 1161
|
22
|
-
from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line:
|
23
|
-
from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line:
|
22
|
+
from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 78
|
23
|
+
from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 482
|
@@ -6,17 +6,17 @@
|
|
6
6
|
import sys as _sys
|
7
7
|
|
8
8
|
from tensorflow._api.v2.summary import experimental
|
9
|
-
from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line:
|
10
|
-
from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line:
|
11
|
-
from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line:
|
12
|
-
from tensorflow.python.ops.summary_ops_v2 import flush # line:
|
13
|
-
from tensorflow.python.ops.summary_ops_v2 import graph # line:
|
14
|
-
from tensorflow.python.ops.summary_ops_v2 import record_if # line:
|
9
|
+
from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line: 248
|
10
|
+
from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line: 520
|
11
|
+
from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line: 645
|
12
|
+
from tensorflow.python.ops.summary_ops_v2 import flush # line: 1145
|
13
|
+
from tensorflow.python.ops.summary_ops_v2 import graph # line: 1057
|
14
|
+
from tensorflow.python.ops.summary_ops_v2 import record_if # line: 161
|
15
15
|
from tensorflow.python.ops.summary_ops_v2 import should_record_summaries # line: 133
|
16
|
-
from tensorflow.python.ops.summary_ops_v2 import trace_export # line:
|
17
|
-
from tensorflow.python.ops.summary_ops_v2 import trace_off # line:
|
18
|
-
from tensorflow.python.ops.summary_ops_v2 import trace_on # line:
|
19
|
-
from tensorflow.python.ops.summary_ops_v2 import write # line:
|
16
|
+
from tensorflow.python.ops.summary_ops_v2 import trace_export # line: 1394
|
17
|
+
from tensorflow.python.ops.summary_ops_v2 import trace_off # line: 1447
|
18
|
+
from tensorflow.python.ops.summary_ops_v2 import trace_on # line: 1338
|
19
|
+
from tensorflow.python.ops.summary_ops_v2 import write # line: 741
|
20
20
|
from tensorflow.python.summary.tb_summary import audio # line: 32
|
21
21
|
from tensorflow.python.summary.tb_summary import histogram # line: 89
|
22
22
|
from tensorflow.python.summary.tb_summary import image # line: 165
|
@@ -5,7 +5,7 @@
|
|
5
5
|
|
6
6
|
import sys as _sys
|
7
7
|
|
8
|
-
from tensorflow.python.ops.summary_ops_v2 import get_step # line:
|
9
|
-
from tensorflow.python.ops.summary_ops_v2 import set_step # line:
|
10
|
-
from tensorflow.python.ops.summary_ops_v2 import summary_scope # line:
|
11
|
-
from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line:
|
8
|
+
from tensorflow.python.ops.summary_ops_v2 import get_step # line: 218
|
9
|
+
from tensorflow.python.ops.summary_ops_v2 import set_step # line: 229
|
10
|
+
from tensorflow.python.ops.summary_ops_v2 import summary_scope # line: 700
|
11
|
+
from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line: 818
|
@@ -19,5 +19,5 @@ from tensorflow.python.tpu.tpu_embedding_v2_utils import QuantizationConfig # li
|
|
19
19
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import RowIdInitializer # line: 1347
|
20
20
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import SGD # line: 363
|
21
21
|
from tensorflow.python.tpu.tpu_embedding_v2_utils import TableConfig # line: 1161
|
22
|
-
from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line:
|
23
|
-
from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line:
|
22
|
+
from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 78
|
23
|
+
from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 482
|
Binary file
|
Binary file
|
Binary file
|
tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
CHANGED
@@ -71,6 +71,18 @@ public:
|
|
71
71
|
unsigned firstIndex) override;
|
72
72
|
};
|
73
73
|
|
74
|
+
/// Succeeds if an op can be converted to its unsigned equivalent without
|
75
|
+
/// changing its semantics. This is the case when none of its operands or
|
76
|
+
/// results can be below 0 when analyzed from a signed perspective.
|
77
|
+
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op);
|
78
|
+
|
79
|
+
/// Succeeds when a value is statically non-negative in that it has a lower
|
80
|
+
/// bound on its value (if it is treated as signed) and that bound is
|
81
|
+
/// non-negative.
|
82
|
+
/// Note, the results of this query may not be accurate for `index` if you plan
|
83
|
+
/// to use a non-64-bit index.
|
84
|
+
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v);
|
85
|
+
|
74
86
|
} // end namespace dataflow
|
75
87
|
} // end namespace mlir
|
76
88
|
|
@@ -69,6 +69,13 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
|
|
69
69
|
// Check if any of the given types are mlir::quant::QuantizedType.
|
70
70
|
bool isAnyQuantizedTypes(TypeRange types);
|
71
71
|
|
72
|
+
// Creates a quantized element type based on the given parameters.
|
73
|
+
Type getQuantizedElementType(Location loc, Type storageType, Type expressedType,
|
74
|
+
ArrayRef<double> scales,
|
75
|
+
ArrayRef<int64_t> zeroPoints,
|
76
|
+
int32_t quantizedDimension, int64_t storageTypeMin,
|
77
|
+
int64_t storageTypeMax);
|
78
|
+
|
72
79
|
} // namespace stablehlo
|
73
80
|
} // namespace mlir
|
74
81
|
|
@@ -69,6 +69,13 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
|
|
69
69
|
// Check if any of the given types are mlir::quant::QuantizedType.
|
70
70
|
bool isAnyQuantizedTypes(TypeRange types);
|
71
71
|
|
72
|
+
// Creates a quantized element type based on the given parameters.
|
73
|
+
Type getQuantizedElementType(Location loc, Type storageType, Type expressedType,
|
74
|
+
ArrayRef<double> scales,
|
75
|
+
ArrayRef<int64_t> zeroPoints,
|
76
|
+
int32_t quantizedDimension, int64_t storageTypeMin,
|
77
|
+
int64_t storageTypeMax);
|
78
|
+
|
72
79
|
} // namespace stablehlo
|
73
80
|
} // namespace mlir
|
74
81
|
|
@@ -69,6 +69,13 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
|
|
69
69
|
// Check if any of the given types are mlir::quant::QuantizedType.
|
70
70
|
bool isAnyQuantizedTypes(TypeRange types);
|
71
71
|
|
72
|
+
// Creates a quantized element type based on the given parameters.
|
73
|
+
Type getQuantizedElementType(Location loc, Type storageType, Type expressedType,
|
74
|
+
ArrayRef<double> scales,
|
75
|
+
ArrayRef<int64_t> zeroPoints,
|
76
|
+
int32_t quantizedDimension, int64_t storageTypeMin,
|
77
|
+
int64_t storageTypeMax);
|
78
|
+
|
72
79
|
} // namespace stablehlo
|
73
80
|
} // namespace mlir
|
74
81
|
|
@@ -89,9 +89,10 @@ class KernelApiIrBuilder {
|
|
89
89
|
// read-only if it is not aliased with any result.
|
90
90
|
absl::flat_hash_set<int64_t> invariant_arguments;
|
91
91
|
|
92
|
-
//
|
92
|
+
// The set of buffers used by this kernel, can be empty if buffer assignment
|
93
93
|
// was not provided.
|
94
|
-
absl::InlinedVector<
|
94
|
+
absl::InlinedVector<BufferAllocation::Slice, 8> argument_buffers;
|
95
|
+
absl::InlinedVector<BufferAllocation::Slice, 8> result_buffers;
|
95
96
|
};
|
96
97
|
|
97
98
|
KernelApiIrBuilder(llvm::LLVMContext& context, Options options);
|
@@ -63,6 +63,8 @@ class KernelThunkBase : public Thunk {
|
|
63
63
|
const = 0;
|
64
64
|
|
65
65
|
virtual absl::Span<const BufferAllocation::Slice> results_buffers() const = 0;
|
66
|
+
|
67
|
+
virtual const absl::flat_hash_set<int64_t>& invariant_arguments() const = 0;
|
66
68
|
};
|
67
69
|
|
68
70
|
namespace internal {
|
@@ -95,6 +97,10 @@ class KernelThunk : public KernelThunkBase {
|
|
95
97
|
return absl::MakeSpan(results_buffers_);
|
96
98
|
}
|
97
99
|
|
100
|
+
const absl::flat_hash_set<int64_t>& invariant_arguments() const final {
|
101
|
+
return invariant_arguments_;
|
102
|
+
}
|
103
|
+
|
98
104
|
protected:
|
99
105
|
tsl::AsyncValueRef<ExecuteEvent> ExecuteInternal(const ExecuteParams& params);
|
100
106
|
|
@@ -129,7 +135,7 @@ class KernelThunk : public KernelThunkBase {
|
|
129
135
|
KernelThunk(Info info,
|
130
136
|
absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
131
137
|
absl::Span<const BufferAllocation::Slice> results_buffers,
|
132
|
-
|
138
|
+
absl::flat_hash_set<int64_t> invariant_arguments,
|
133
139
|
std::string kernel_name, se::ThreadDim thread_dim,
|
134
140
|
std::optional<uint64_t> min_alignment);
|
135
141
|
|
@@ -139,7 +145,7 @@ class KernelThunk : public KernelThunkBase {
|
|
139
145
|
ResultsBuffers results_buffers_;
|
140
146
|
|
141
147
|
// A set of invariant arguments (their indices).
|
142
|
-
|
148
|
+
absl::flat_hash_set<int64_t> invariant_arguments_;
|
143
149
|
|
144
150
|
size_t num_kernel_args_;
|
145
151
|
|
@@ -189,7 +195,7 @@ class KernelThunk final : public internal::KernelThunk<> {
|
|
189
195
|
absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
190
196
|
absl::Span<const BufferAllocation::Slice> results_buffers,
|
191
197
|
std::string kernel_name, se::ThreadDim thread_dim,
|
192
|
-
|
198
|
+
absl::flat_hash_set<int64_t> invariant_arguments,
|
193
199
|
std::optional<uint64_t> min_alignment = std::nullopt);
|
194
200
|
|
195
201
|
static absl::StatusOr<std::unique_ptr<Thunk>> Create(
|
@@ -44,15 +44,6 @@ namespace xla::cpu {
|
|
44
44
|
// A work queue that partitions `num_tasks` tasks into `num_partitions`
|
45
45
|
// partitions processed by parallel workers.
|
46
46
|
class WorkQueue {
|
47
|
-
// Align all atomic counters to a cache line boundary to avoid false
|
48
|
-
// sharing between multiple worker threads.
|
49
|
-
static constexpr size_t kAtomicAlignment =
|
50
|
-
#if defined(__cpp_lib_hardware_interference_size)
|
51
|
-
std::hardware_destructive_interference_size;
|
52
|
-
#else
|
53
|
-
64;
|
54
|
-
#endif
|
55
|
-
|
56
47
|
public:
|
57
48
|
WorkQueue(size_t num_tasks, size_t num_partitions);
|
58
49
|
|
@@ -60,13 +51,23 @@ class WorkQueue {
|
|
60
51
|
// if the partition is complete.
|
61
52
|
std::optional<size_t> Pop(size_t partition_index);
|
62
53
|
|
63
|
-
|
54
|
+
// Return the partition [begin, end) task range.
|
55
|
+
std::pair<size_t, size_t> partition_range(size_t partition_index) const;
|
64
56
|
|
65
|
-
|
57
|
+
size_t num_partitions() const { return partitions_.size(); }
|
66
58
|
|
67
59
|
private:
|
68
60
|
friend class Worker;
|
69
61
|
|
62
|
+
// Align all atomic counters to a cache line boundary to avoid false
|
63
|
+
// sharing between multiple worker threads.
|
64
|
+
static constexpr size_t kAtomicAlignment =
|
65
|
+
#if defined(__cpp_lib_hardware_interference_size)
|
66
|
+
std::hardware_destructive_interference_size;
|
67
|
+
#else
|
68
|
+
64;
|
69
|
+
#endif
|
70
|
+
|
70
71
|
struct Partition {
|
71
72
|
void Initialize(size_t begin, size_t end);
|
72
73
|
|
@@ -76,8 +77,21 @@ class WorkQueue {
|
|
76
77
|
size_t end;
|
77
78
|
};
|
78
79
|
|
80
|
+
// An empty work queue flag to stop worker threads from looping through all
|
81
|
+
// partitions looking for work.
|
82
|
+
bool IsEmpty() const { return empty_.load(std::memory_order_relaxed); }
|
83
|
+
void SetEmpty() { empty_.store(true, std::memory_order_relaxed); }
|
84
|
+
|
85
|
+
// Notify that one of the workers switched to the work stealing mode.
|
86
|
+
void NotifyWorkStealingWorker();
|
87
|
+
|
88
|
+
// Decrements the number of work stealing workers by at most `max_workers` and
|
89
|
+
// returns the number of decremented work stealing workers.
|
90
|
+
size_t DecrementWorkStealingWorkers(size_t max_workers);
|
91
|
+
|
79
92
|
absl::FixedArray<Partition, 32> partitions_;
|
80
93
|
alignas(kAtomicAlignment) std::atomic<bool> empty_;
|
94
|
+
alignas(kAtomicAlignment) std::atomic<size_t> num_work_stealing_workers_;
|
81
95
|
};
|
82
96
|
|
83
97
|
// Worker processes tasks from the work queue starting from the assigned
|
@@ -130,10 +144,14 @@ inline void WorkQueue::Partition::Initialize(size_t begin, size_t end) {
|
|
130
144
|
}
|
131
145
|
|
132
146
|
inline WorkQueue::WorkQueue(size_t num_tasks, size_t num_partitions)
|
133
|
-
: partitions_(num_partitions),
|
134
|
-
|
135
|
-
|
136
|
-
|
147
|
+
: partitions_(num_partitions),
|
148
|
+
empty_(num_tasks == 0),
|
149
|
+
num_work_stealing_workers_(0) {
|
150
|
+
size_t partition_size =
|
151
|
+
tsl::MathUtil::FloorOfRatio(num_tasks, num_partitions);
|
152
|
+
size_t rem_tasks = num_tasks % num_partitions;
|
153
|
+
for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
|
154
|
+
end = begin + partition_size + ((i < rem_tasks) ? 1 : 0);
|
137
155
|
partitions_[i].Initialize(begin, end);
|
138
156
|
}
|
139
157
|
}
|
@@ -154,6 +172,29 @@ inline std::optional<size_t> WorkQueue::Pop(size_t partition_index) {
|
|
154
172
|
: std::make_optional(index);
|
155
173
|
}
|
156
174
|
|
175
|
+
inline std::pair<size_t, size_t> WorkQueue::partition_range(
|
176
|
+
size_t partition_index) const {
|
177
|
+
DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
|
178
|
+
return {partitions_[partition_index].begin, partitions_[partition_index].end};
|
179
|
+
}
|
180
|
+
|
181
|
+
inline void WorkQueue::NotifyWorkStealingWorker() {
|
182
|
+
num_work_stealing_workers_.fetch_add(1, std::memory_order_relaxed);
|
183
|
+
}
|
184
|
+
|
185
|
+
inline size_t WorkQueue::DecrementWorkStealingWorkers(size_t max_workers) {
|
186
|
+
size_t n = num_work_stealing_workers_.load(std::memory_order_relaxed);
|
187
|
+
|
188
|
+
size_t decrement = std::min(n, max_workers);
|
189
|
+
while (decrement && !num_work_stealing_workers_.compare_exchange_weak(
|
190
|
+
n, n - decrement, std::memory_order_relaxed,
|
191
|
+
std::memory_order_relaxed)) {
|
192
|
+
decrement = std::min(n, max_workers);
|
193
|
+
}
|
194
|
+
|
195
|
+
return decrement;
|
196
|
+
}
|
197
|
+
|
157
198
|
inline Worker::Worker(size_t worker_index, WorkQueue* queue)
|
158
199
|
: worker_index_(worker_index),
|
159
200
|
partition_index_(worker_index),
|
@@ -163,7 +204,13 @@ inline std::optional<size_t> Worker::Pop() {
|
|
163
204
|
std::optional<size_t> task = queue_->Pop(partition_index_);
|
164
205
|
if (ABSL_PREDICT_TRUE(task)) return task;
|
165
206
|
|
166
|
-
|
207
|
+
// If we didn't find a task in the initially assigned partition, notify the
|
208
|
+
// work queue that we are switching to work stealing mode.
|
209
|
+
if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
|
210
|
+
queue_->NotifyWorkStealingWorker();
|
211
|
+
}
|
212
|
+
|
213
|
+
while (!task.has_value() && !queue_->IsEmpty()) {
|
167
214
|
// Wrap around to the first partition.
|
168
215
|
if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
|
169
216
|
partition_index_ = 0;
|
@@ -171,7 +218,7 @@ inline std::optional<size_t> Worker::Pop() {
|
|
171
218
|
|
172
219
|
// We checked all partitions and got back to the partition we started from.
|
173
220
|
if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
|
174
|
-
queue_->
|
221
|
+
queue_->SetEmpty();
|
175
222
|
break;
|
176
223
|
}
|
177
224
|
|
@@ -205,6 +252,7 @@ Worker::ParallelizeContext<ParallelTask>::ParallelizeContext(
|
|
205
252
|
parallel_task(std::forward<ParallelTask>(parallel_task)) {}
|
206
253
|
|
207
254
|
template <typename ParallelTask>
|
255
|
+
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
|
208
256
|
void Worker::ParallelizeWithContext(ParallelizeContext<ParallelTask>* ctx,
|
209
257
|
uint16_t start_index, uint16_t end_index) {
|
210
258
|
DCHECK_LT(start_index, end_index) << "Invalid worker index range";
|
@@ -223,11 +271,26 @@ void Worker::ParallelizeWithContext(ParallelizeContext<ParallelTask>* ctx,
|
|
223
271
|
while (end_index - start_index > 1) {
|
224
272
|
// If work queue is empty, we don't need to keep enqueuing more workers and
|
225
273
|
// can simply count down for the remaining workers.
|
226
|
-
if (ABSL_PREDICT_FALSE(ctx->work_queue.
|
274
|
+
if (ABSL_PREDICT_FALSE(ctx->work_queue.IsEmpty())) {
|
227
275
|
count_down(end_index - start_index, absl::OkStatus());
|
228
276
|
return;
|
229
277
|
}
|
230
278
|
|
279
|
+
// If we have workers in the work stealing mode, we can skip enqueuing
|
280
|
+
// more tasks as existing workers will process remaining partitions. By
|
281
|
+
// doing this optimization we avoid unnecessary thread pool overheads.
|
282
|
+
size_t skip_workers =
|
283
|
+
ctx->work_queue.DecrementWorkStealingWorkers(end_index - start_index);
|
284
|
+
if (ABSL_PREDICT_FALSE(skip_workers > 0)) {
|
285
|
+
DCHECK_LE(skip_workers, end_index - start_index);
|
286
|
+
count_down(skip_workers, absl::OkStatus());
|
287
|
+
|
288
|
+
end_index -= skip_workers;
|
289
|
+
if (start_index == end_index) return;
|
290
|
+
if (end_index - start_index == 1) break;
|
291
|
+
}
|
292
|
+
|
293
|
+
DCHECK_GE(end_index - start_index, 1);
|
231
294
|
uint16_t mid_index = (start_index + end_index) / 2;
|
232
295
|
ctx->device->enqueueNoNotification([ctx, mid_index, end_index] {
|
233
296
|
ParallelizeWithContext(ctx, mid_index, end_index);
|
@@ -17,12 +17,14 @@ limitations under the License.
|
|
17
17
|
#define XLA_CODEGEN_KERNEL_SPEC_H_
|
18
18
|
|
19
19
|
#include <cstddef>
|
20
|
+
#include <cstdint>
|
20
21
|
#include <optional>
|
21
22
|
#include <string>
|
22
23
|
|
24
|
+
#include "absl/container/flat_hash_set.h"
|
23
25
|
#include "absl/container/inlined_vector.h"
|
24
26
|
#include "absl/strings/string_view.h"
|
25
|
-
#include "xla/
|
27
|
+
#include "xla/service/buffer_assignment.h"
|
26
28
|
#include "xla/stream_executor/launch_dim.h"
|
27
29
|
|
28
30
|
namespace xla {
|
@@ -33,15 +35,17 @@ namespace xla {
|
|
33
35
|
// will load kernel PTX on device and instantiate a KernelThunk.
|
34
36
|
class KernelSpec {
|
35
37
|
public:
|
36
|
-
using
|
38
|
+
using Buffers = absl::InlinedVector<BufferAllocation::Slice, 8>;
|
37
39
|
|
38
40
|
KernelSpec(absl::string_view name, se::ThreadDim thread_dim,
|
39
|
-
|
41
|
+
Buffers argument_buffers, Buffers result_buffers,
|
42
|
+
absl::flat_hash_set<int64_t> invariant_arguments,
|
40
43
|
std::optional<size_t> scratch_bytes = std::nullopt);
|
41
44
|
|
42
45
|
KernelSpec(absl::string_view name, se::ClusterDim cluster_dim,
|
43
46
|
se::BlockDim block_dim, se::ThreadDim thread_dim,
|
44
|
-
|
47
|
+
Buffers argument_buffers, Buffers result_buffers,
|
48
|
+
absl::flat_hash_set<int64_t> invariant_arguments,
|
45
49
|
std::optional<size_t> scratch_bytes = std::nullopt);
|
46
50
|
|
47
51
|
// Get the backend specific name of the kernel.
|
@@ -67,15 +71,28 @@ class KernelSpec {
|
|
67
71
|
// managed buffer that is likely to be in L1/L2 cache).
|
68
72
|
std::optional<size_t> scratch_bytes() const { return scratch_bytes_; }
|
69
73
|
|
70
|
-
//
|
71
|
-
const
|
74
|
+
// Argument buffers read by the kernel.
|
75
|
+
const Buffers& argument_buffers() const { return argument_buffers_; }
|
76
|
+
// Result buffers written to by the kernel.
|
77
|
+
const Buffers& result_buffers() const { return result_buffers_; }
|
78
|
+
|
79
|
+
// Returns a set of invariant arguments (corresponding to the indices in the
|
80
|
+
// argument buffers list).
|
81
|
+
const absl::flat_hash_set<int64_t>& invariant_arguments() const {
|
82
|
+
return invariant_arguments_;
|
83
|
+
}
|
72
84
|
|
73
85
|
private:
|
74
86
|
std::string name_;
|
75
87
|
se::ClusterDim cluster_dim_;
|
76
88
|
se::BlockDim block_dim_;
|
77
89
|
se::ThreadDim thread_dim_;
|
78
|
-
|
90
|
+
|
91
|
+
Buffers argument_buffers_;
|
92
|
+
Buffers result_buffers_;
|
93
|
+
|
94
|
+
absl::flat_hash_set<int64_t> invariant_arguments_;
|
95
|
+
|
79
96
|
std::optional<size_t> scratch_bytes_;
|
80
97
|
};
|
81
98
|
|
@@ -44,28 +44,6 @@ T* Cast(HloInstruction* instr) {
|
|
44
44
|
return tsl::down_cast<T*>(instr);
|
45
45
|
}
|
46
46
|
|
47
|
-
// Downcasts a const HloInstruction pointer or returns nullptr if argument is
|
48
|
-
// nullptr. Dies if TargetClass::ClassOf() does not match.
|
49
|
-
template <typename T>
|
50
|
-
const T* CastOrNull(const HloInstruction* i) {
|
51
|
-
if (i == nullptr) {
|
52
|
-
return nullptr;
|
53
|
-
}
|
54
|
-
CHECK(T::ClassOf(i));
|
55
|
-
return tsl::down_cast<const T*>(i);
|
56
|
-
}
|
57
|
-
|
58
|
-
// Downcasts a const HloInstruction pointer or returns nullptr if argument is
|
59
|
-
// nullptr. Dies if TargetClass::ClassOf() does not match.
|
60
|
-
template <typename T>
|
61
|
-
T* CastOrNull(HloInstruction* i) {
|
62
|
-
if (i == nullptr) {
|
63
|
-
return nullptr;
|
64
|
-
}
|
65
|
-
CHECK(T::ClassOf(i));
|
66
|
-
return tsl::down_cast<T*>(i);
|
67
|
-
}
|
68
|
-
|
69
47
|
// Downcasts a const HloInstruction pointer or returns nullptr if
|
70
48
|
// TargetClass::ClassOf() does not match. Dies if argument is nullptr. Similar
|
71
49
|
// to LLVM's dyn_cast.
|
@@ -84,28 +62,6 @@ T* DynCast(HloInstruction* i) {
|
|
84
62
|
return !T::ClassOf(i) ? nullptr : tsl::down_cast<T*>(i);
|
85
63
|
}
|
86
64
|
|
87
|
-
// Downcasts a const HloInstruction pointer. Return nullptr if argument is
|
88
|
-
// nullptr or TargetClass::ClassOf() does not match. Similar to LLVM's
|
89
|
-
// dyn_cast_or_null.
|
90
|
-
template <typename T>
|
91
|
-
const T* DynCastOrNull(const HloInstruction* instruction) {
|
92
|
-
if (instruction == nullptr || !T::ClassOf(instruction)) {
|
93
|
-
return nullptr;
|
94
|
-
}
|
95
|
-
return tsl::down_cast<const T*>(instruction);
|
96
|
-
}
|
97
|
-
|
98
|
-
// Downcasts a non-const HloInstruction pointer. Return nullptr if argument is
|
99
|
-
// nullptr or TargetClass::ClassOf() does not match. Similar to LLVM's
|
100
|
-
// dyn_cast_or_null.
|
101
|
-
template <typename T>
|
102
|
-
T* DynCastOrNull(HloInstruction* instruction) {
|
103
|
-
if (instruction == nullptr || !T::ClassOf(instruction)) {
|
104
|
-
return nullptr;
|
105
|
-
}
|
106
|
-
return tsl::down_cast<T*>(instruction);
|
107
|
-
}
|
108
|
-
|
109
65
|
} // namespace xla
|
110
66
|
|
111
67
|
#endif // XLA_HLO_IR_HLO_CASTING_UTILS_H_
|