tf-nightly-cpu 2.20.0.dev20250220__cp310-cp310-win_amd64.whl → 2.20.0.dev20250222__cp310-cp310-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (128) hide show
  1. tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
  2. tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
  3. tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
  4. tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
  5. tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
  6. tensorflow/_api/v2/summary/__init__.py +10 -10
  7. tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
  8. tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
  9. tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
  10. tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
  11. tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
  12. tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
  13. tensorflow/include/external/llvm-project/mlir/include/mlir/Dialect/Math/IR/MathOps.h.inc +4 -0
  14. tensorflow/include/external/shardy/shardy/dialect/sdy/transforms/propagation/aggressive_factor_propagation.h +9 -0
  15. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
  16. tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
  17. tensorflow/include/external/stablehlo/_virtual_includes/version/stablehlo/dialect/Version.h +1 -1
  18. tensorflow/include/external/stablehlo/stablehlo/dialect/Version.h +1 -1
  19. tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
  20. tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  21. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  22. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  23. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -19
  24. tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
  25. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
  26. tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_instruction.h +12 -0
  27. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  28. tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  29. tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
  30. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  31. tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  32. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  33. tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  34. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  35. tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  36. tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
  37. tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
  38. tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
  39. tensorflow/include/tensorflow/compiler/xla/tsl/concurrency/async_value.h +50 -21
  40. tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  41. tensorflow/include/tensorflow/core/kernels/data/experimental/random_access_ops.h +0 -2
  42. tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
  43. tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
  44. tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
  45. tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
  46. tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
  47. tensorflow/include/tensorflow/core/public/release_version.h +39 -0
  48. tensorflow/include/tensorflow/core/public/version.h +112 -127
  49. tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
  50. tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
  51. tensorflow/include/xla/backends/cpu/runtime/convolution_thunk_internal.h +8 -10
  52. tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
  53. tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -19
  54. tensorflow/include/xla/codegen/kernel_spec.h +24 -7
  55. tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
  56. tensorflow/include/xla/hlo/ir/hlo_instruction.h +12 -0
  57. tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
  58. tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
  59. tensorflow/include/xla/pjrt/distributed/client.h +5 -0
  60. tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
  61. tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
  62. tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
  63. tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
  64. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
  65. tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
  66. tensorflow/include/xla/service/constant_value.h +1 -0
  67. tensorflow/include/xla/service/hlo_module_util.h +52 -1
  68. tensorflow/include/xla/service/hlo_proto_util.h +0 -12
  69. tensorflow/include/xla/tsl/concurrency/async_value.h +50 -21
  70. tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
  71. tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
  72. tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
  73. tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
  74. tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
  75. tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
  76. tensorflow/python/_pywrap_mlir.pyd +0 -0
  77. tensorflow/python/_pywrap_parallel_device.pyd +0 -0
  78. tensorflow/python/_pywrap_quantize_training.pyd +0 -0
  79. tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
  80. tensorflow/python/_pywrap_tfcompile.pyd +0 -0
  81. tensorflow/python/_pywrap_tfe.pyd +0 -0
  82. tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
  83. tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
  84. tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
  85. tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
  86. tensorflow/python/compat/compat.py +1 -1
  87. tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
  88. tensorflow/python/eager/imperative_grad.py +5 -5
  89. tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
  90. tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
  91. tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
  92. tensorflow/python/eager/tape.py +2 -2
  93. tensorflow/python/framework/_dtypes.pyd +0 -0
  94. tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
  95. tensorflow/python/framework/_op_def_registry.pyd +0 -0
  96. tensorflow/python/framework/_proto_comparators.pyd +0 -0
  97. tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
  98. tensorflow/python/framework/_test_metrics_util.pyd +0 -0
  99. tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
  100. tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
  101. tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
  102. tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
  103. tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
  104. tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
  105. tensorflow/python/ops/summary_ops_v2.py +5 -1
  106. tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
  107. tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
  108. tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
  109. tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
  110. tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
  111. tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
  112. tensorflow/python/tpu/tpu_embedding_v3_checkpoint_adapter.py +10 -1
  113. tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
  114. tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
  115. tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
  116. tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
  117. tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
  118. tensorflow/python/util/_pywrap_utils.pyd +0 -0
  119. tensorflow/python/util/_tf_stack.pyd +0 -0
  120. tensorflow/tools/pip_package/setup.py +2 -2
  121. tensorflow/xla_aot_runtime_src/xla/tsl/concurrency/async_value.cc +26 -51
  122. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/METADATA +1 -1
  123. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/RECORD +126 -121
  124. tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/concurrency.h +0 -77
  125. tensorflow/include/xla/backends/cpu/runtime/concurrency.h +0 -77
  126. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/WHEEL +0 -0
  127. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/entry_points.txt +0 -0
  128. {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250222.dist-info}/top_level.txt +0 -0
@@ -5,8 +5,8 @@
5
5
 
6
6
  import sys as _sys
7
7
 
8
- from tensorflow.python.ops.summary_ops_v2 import all_v2_summary_ops # line: 661
9
- from tensorflow.python.ops.summary_ops_v2 import initialize # line: 473
8
+ from tensorflow.python.ops.summary_ops_v2 import all_v2_summary_ops # line: 665
9
+ from tensorflow.python.ops.summary_ops_v2 import initialize # line: 477
10
10
  from tensorflow.python.proto_exports import Event # line: 28
11
11
  from tensorflow.python.proto_exports import SessionLog # line: 47
12
12
  from tensorflow.python.proto_exports import Summary # line: 50
@@ -19,8 +19,8 @@ from tensorflow.python.tpu.tpu_embedding_v2_utils import QuantizationConfig # li
19
19
  from tensorflow.python.tpu.tpu_embedding_v2_utils import RowIdInitializer # line: 1347
20
20
  from tensorflow.python.tpu.tpu_embedding_v2_utils import SGD # line: 363
21
21
  from tensorflow.python.tpu.tpu_embedding_v2_utils import TableConfig # line: 1161
22
- from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 77
23
- from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 475
22
+ from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 78
23
+ from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 482
24
24
 
25
25
  from tensorflow.python.util import module_wrapper as _module_wrapper
26
26
 
@@ -6,17 +6,17 @@
6
6
  import sys as _sys
7
7
 
8
8
  from tensorflow._api.v2.compat.v2.summary import experimental
9
- from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line: 244
10
- from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line: 516
11
- from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line: 641
12
- from tensorflow.python.ops.summary_ops_v2 import flush # line: 1141
13
- from tensorflow.python.ops.summary_ops_v2 import graph # line: 1053
14
- from tensorflow.python.ops.summary_ops_v2 import record_if # line: 157
9
+ from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line: 248
10
+ from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line: 520
11
+ from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line: 645
12
+ from tensorflow.python.ops.summary_ops_v2 import flush # line: 1145
13
+ from tensorflow.python.ops.summary_ops_v2 import graph # line: 1057
14
+ from tensorflow.python.ops.summary_ops_v2 import record_if # line: 161
15
15
  from tensorflow.python.ops.summary_ops_v2 import should_record_summaries # line: 133
16
- from tensorflow.python.ops.summary_ops_v2 import trace_export # line: 1390
17
- from tensorflow.python.ops.summary_ops_v2 import trace_off # line: 1443
18
- from tensorflow.python.ops.summary_ops_v2 import trace_on # line: 1334
19
- from tensorflow.python.ops.summary_ops_v2 import write # line: 737
16
+ from tensorflow.python.ops.summary_ops_v2 import trace_export # line: 1394
17
+ from tensorflow.python.ops.summary_ops_v2 import trace_off # line: 1447
18
+ from tensorflow.python.ops.summary_ops_v2 import trace_on # line: 1338
19
+ from tensorflow.python.ops.summary_ops_v2 import write # line: 741
20
20
  from tensorflow.python.summary.tb_summary import audio # line: 32
21
21
  from tensorflow.python.summary.tb_summary import histogram # line: 89
22
22
  from tensorflow.python.summary.tb_summary import image # line: 165
@@ -5,7 +5,7 @@
5
5
 
6
6
  import sys as _sys
7
7
 
8
- from tensorflow.python.ops.summary_ops_v2 import get_step # line: 214
9
- from tensorflow.python.ops.summary_ops_v2 import set_step # line: 225
10
- from tensorflow.python.ops.summary_ops_v2 import summary_scope # line: 696
11
- from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line: 814
8
+ from tensorflow.python.ops.summary_ops_v2 import get_step # line: 218
9
+ from tensorflow.python.ops.summary_ops_v2 import set_step # line: 229
10
+ from tensorflow.python.ops.summary_ops_v2 import summary_scope # line: 700
11
+ from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line: 818
@@ -19,5 +19,5 @@ from tensorflow.python.tpu.tpu_embedding_v2_utils import QuantizationConfig # li
19
19
  from tensorflow.python.tpu.tpu_embedding_v2_utils import RowIdInitializer # line: 1347
20
20
  from tensorflow.python.tpu.tpu_embedding_v2_utils import SGD # line: 363
21
21
  from tensorflow.python.tpu.tpu_embedding_v2_utils import TableConfig # line: 1161
22
- from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 77
23
- from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 475
22
+ from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 78
23
+ from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 482
@@ -6,17 +6,17 @@
6
6
  import sys as _sys
7
7
 
8
8
  from tensorflow._api.v2.summary import experimental
9
- from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line: 244
10
- from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line: 516
11
- from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line: 641
12
- from tensorflow.python.ops.summary_ops_v2 import flush # line: 1141
13
- from tensorflow.python.ops.summary_ops_v2 import graph # line: 1053
14
- from tensorflow.python.ops.summary_ops_v2 import record_if # line: 157
9
+ from tensorflow.python.ops.summary_ops_v2 import SummaryWriter # line: 248
10
+ from tensorflow.python.ops.summary_ops_v2 import create_file_writer_v2 as create_file_writer # line: 520
11
+ from tensorflow.python.ops.summary_ops_v2 import create_noop_writer # line: 645
12
+ from tensorflow.python.ops.summary_ops_v2 import flush # line: 1145
13
+ from tensorflow.python.ops.summary_ops_v2 import graph # line: 1057
14
+ from tensorflow.python.ops.summary_ops_v2 import record_if # line: 161
15
15
  from tensorflow.python.ops.summary_ops_v2 import should_record_summaries # line: 133
16
- from tensorflow.python.ops.summary_ops_v2 import trace_export # line: 1390
17
- from tensorflow.python.ops.summary_ops_v2 import trace_off # line: 1443
18
- from tensorflow.python.ops.summary_ops_v2 import trace_on # line: 1334
19
- from tensorflow.python.ops.summary_ops_v2 import write # line: 737
16
+ from tensorflow.python.ops.summary_ops_v2 import trace_export # line: 1394
17
+ from tensorflow.python.ops.summary_ops_v2 import trace_off # line: 1447
18
+ from tensorflow.python.ops.summary_ops_v2 import trace_on # line: 1338
19
+ from tensorflow.python.ops.summary_ops_v2 import write # line: 741
20
20
  from tensorflow.python.summary.tb_summary import audio # line: 32
21
21
  from tensorflow.python.summary.tb_summary import histogram # line: 89
22
22
  from tensorflow.python.summary.tb_summary import image # line: 165
@@ -5,7 +5,7 @@
5
5
 
6
6
  import sys as _sys
7
7
 
8
- from tensorflow.python.ops.summary_ops_v2 import get_step # line: 214
9
- from tensorflow.python.ops.summary_ops_v2 import set_step # line: 225
10
- from tensorflow.python.ops.summary_ops_v2 import summary_scope # line: 696
11
- from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line: 814
8
+ from tensorflow.python.ops.summary_ops_v2 import get_step # line: 218
9
+ from tensorflow.python.ops.summary_ops_v2 import set_step # line: 229
10
+ from tensorflow.python.ops.summary_ops_v2 import summary_scope # line: 700
11
+ from tensorflow.python.ops.summary_ops_v2 import write_raw_pb # line: 818
@@ -19,5 +19,5 @@ from tensorflow.python.tpu.tpu_embedding_v2_utils import QuantizationConfig # li
19
19
  from tensorflow.python.tpu.tpu_embedding_v2_utils import RowIdInitializer # line: 1347
20
20
  from tensorflow.python.tpu.tpu_embedding_v2_utils import SGD # line: 363
21
21
  from tensorflow.python.tpu.tpu_embedding_v2_utils import TableConfig # line: 1161
22
- from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 77
23
- from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 475
22
+ from tensorflow.python.tpu.tpu_embedding_v3 import SparseCoreEmbeddingConfig # line: 78
23
+ from tensorflow.python.tpu.tpu_embedding_v3 import TPUEmbeddingV2 # line: 482
Binary file
@@ -71,6 +71,18 @@ public:
71
71
  unsigned firstIndex) override;
72
72
  };
73
73
 
74
+ /// Succeeds if an op can be converted to its unsigned equivalent without
75
+ /// changing its semantics. This is the case when none of its operands or
76
+ /// results can be below 0 when analyzed from a signed perspective.
77
+ LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op);
78
+
79
+ /// Succeeds when a value is statically non-negative in that it has a lower
80
+ /// bound on its value (if it is treated as signed) and that bound is
81
+ /// non-negative.
82
+ /// Note, the results of this query may not be accurate for `index` if you plan
83
+ /// to use a non-64-bit index.
84
+ LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v);
85
+
74
86
  } // end namespace dataflow
75
87
  } // end namespace mlir
76
88
 
@@ -5711,6 +5711,7 @@ public:
5711
5711
  static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
5712
5712
  ::llvm::LogicalResult verifyInvariantsImpl();
5713
5713
  ::llvm::LogicalResult verifyInvariants();
5714
+ ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
5714
5715
  static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
5715
5716
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
5716
5717
  void print(::mlir::OpAsmPrinter &_odsPrinter);
@@ -5925,6 +5926,7 @@ public:
5925
5926
  static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
5926
5927
  ::llvm::LogicalResult verifyInvariantsImpl();
5927
5928
  ::llvm::LogicalResult verifyInvariants();
5929
+ ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
5928
5930
  static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
5929
5931
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
5930
5932
  void print(::mlir::OpAsmPrinter &_odsPrinter);
@@ -6139,6 +6141,7 @@ public:
6139
6141
  static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
6140
6142
  ::llvm::LogicalResult verifyInvariantsImpl();
6141
6143
  ::llvm::LogicalResult verifyInvariants();
6144
+ ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
6142
6145
  static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
6143
6146
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
6144
6147
  void print(::mlir::OpAsmPrinter &_odsPrinter);
@@ -6353,6 +6356,7 @@ public:
6353
6356
  static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties);
6354
6357
  ::llvm::LogicalResult verifyInvariantsImpl();
6355
6358
  ::llvm::LogicalResult verifyInvariants();
6359
+ ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
6356
6360
  static ::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes);
6357
6361
  static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
6358
6362
  void print(::mlir::OpAsmPrinter &_odsPrinter);
@@ -81,6 +81,15 @@ class AggressiveFactorPropagation : public BasicFactorPropagation {
81
81
  PropagationDirectionAlongFactor directionAlongFactor,
82
82
  ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
83
83
  bool conservativePropagation) const override;
84
+
85
+ private:
86
+ // Returns the axes to propagate to an individual factor in the given
87
+ // `tensorFactorShardings` of a tensor.
88
+ SmallVector<AxisRefAttr> getPropagatedFactorSharding(
89
+ int64_t factorIndex, const TensorFactorShardings& tensorFactorShardings,
90
+ const FactorIndexToSharding& factorIndexToSharding,
91
+ AxesPerFactorRef axesPerFactor, MeshAttr mesh,
92
+ bool conservativePropagation, ArrayRef<int64_t> factorSizes) const;
84
93
  };
85
94
 
86
95
  } // namespace sdy
@@ -69,6 +69,13 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
69
69
  // Check if any of the given types are mlir::quant::QuantizedType.
70
70
  bool isAnyQuantizedTypes(TypeRange types);
71
71
 
72
+ // Creates a quantized element type based on the given parameters.
73
+ Type getQuantizedElementType(Location loc, Type storageType, Type expressedType,
74
+ ArrayRef<double> scales,
75
+ ArrayRef<int64_t> zeroPoints,
76
+ int32_t quantizedDimension, int64_t storageTypeMin,
77
+ int64_t storageTypeMax);
78
+
72
79
  } // namespace stablehlo
73
80
  } // namespace mlir
74
81
 
@@ -69,6 +69,13 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
69
69
  // Check if any of the given types are mlir::quant::QuantizedType.
70
70
  bool isAnyQuantizedTypes(TypeRange types);
71
71
 
72
+ // Creates a quantized element type based on the given parameters.
73
+ Type getQuantizedElementType(Location loc, Type storageType, Type expressedType,
74
+ ArrayRef<double> scales,
75
+ ArrayRef<int64_t> zeroPoints,
76
+ int32_t quantizedDimension, int64_t storageTypeMin,
77
+ int64_t storageTypeMax);
78
+
72
79
  } // namespace stablehlo
73
80
  } // namespace mlir
74
81
 
@@ -38,7 +38,7 @@ class Version {
38
38
  static FailureOr<Version> fromString(llvm::StringRef versionRef);
39
39
 
40
40
  /// Return a Version representing the current VHLO dialect version.
41
- static Version getCurrentVersion() { return Version(1, 9, 2); }
41
+ static Version getCurrentVersion() { return Version(1, 9, 3); }
42
42
 
43
43
  /// Return a Version representing the minimum supported VHLO dialect version.
44
44
  static Version getMinimumVersion() { return Version(0, 9, 0); }
@@ -38,7 +38,7 @@ class Version {
38
38
  static FailureOr<Version> fromString(llvm::StringRef versionRef);
39
39
 
40
40
  /// Return a Version representing the current VHLO dialect version.
41
- static Version getCurrentVersion() { return Version(1, 9, 2); }
41
+ static Version getCurrentVersion() { return Version(1, 9, 3); }
42
42
 
43
43
  /// Return a Version representing the minimum supported VHLO dialect version.
44
44
  static Version getMinimumVersion() { return Version(0, 9, 0); }
@@ -69,6 +69,13 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
69
69
  // Check if any of the given types are mlir::quant::QuantizedType.
70
70
  bool isAnyQuantizedTypes(TypeRange types);
71
71
 
72
+ // Creates a quantized element type based on the given parameters.
73
+ Type getQuantizedElementType(Location loc, Type storageType, Type expressedType,
74
+ ArrayRef<double> scales,
75
+ ArrayRef<int64_t> zeroPoints,
76
+ int32_t quantizedDimension, int64_t storageTypeMin,
77
+ int64_t storageTypeMax);
78
+
72
79
  } // namespace stablehlo
73
80
  } // namespace mlir
74
81
 
@@ -89,9 +89,10 @@ class KernelApiIrBuilder {
89
89
  // read-only if it is not aliased with any result.
90
90
  absl::flat_hash_set<int64_t> invariant_arguments;
91
91
 
92
- // the set of buffer uses for this kernel, can be empty if buffer
92
+ // The set of buffers used by this kernel, can be empty if buffer assignment
93
93
  // was not provided.
94
- absl::InlinedVector<BufferUse, 8> buffer_uses;
94
+ absl::InlinedVector<BufferAllocation::Slice, 8> argument_buffers;
95
+ absl::InlinedVector<BufferAllocation::Slice, 8> result_buffers;
95
96
  };
96
97
 
97
98
  KernelApiIrBuilder(llvm::LLVMContext& context, Options options);
@@ -22,7 +22,7 @@ limitations under the License.
22
22
  #include <memory>
23
23
  #include <utility>
24
24
 
25
- #include "xla/backends/cpu/runtime/concurrency.h"
25
+ #include "xla/backends/cpu/runtime/work_queue.h"
26
26
  #include "xla/tsl/concurrency/async_value_ref.h"
27
27
  #include "xla/tsl/concurrency/chain.h"
28
28
  #include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" // IWYU pragma: keep
@@ -30,7 +30,6 @@ limitations under the License.
30
30
 
31
31
  #define EIGEN_USE_THREADS
32
32
  #include "Eigen/Core"
33
- #include "Eigen/ThreadPool"
34
33
  #include "unsupported/Eigen/CXX11/Tensor"
35
34
 
36
35
  namespace xla::cpu::internal {
@@ -384,8 +383,9 @@ void EigenGenericConv2D(
384
383
  auto num_tasks = Eigen::numext::div_ceil(feature_group_count, task_size);
385
384
 
386
385
  if (use_thunk_runtime) {
387
- ScheduleAll(
388
- &device, num_tasks, [=, &device](Eigen::Index task_index) mutable {
386
+ Worker::Parallelize(
387
+ &device, /*num_workers=*/num_tasks, num_tasks,
388
+ [=, &device](Eigen::Index task_index) mutable {
389
389
  Eigen::Index start = task_index * task_size;
390
390
  Eigen::Index end = std::min(start + task_size, feature_group_count);
391
391
  for (Eigen::Index i = start; i < end; ++i) {
@@ -395,18 +395,16 @@ void EigenGenericConv2D(
395
395
  }
396
396
  });
397
397
  } else {
398
- Eigen::Barrier barrier(num_tasks);
399
- ScheduleAll(
400
- &device, num_tasks, [=, &device, &barrier](Eigen::Index task_index) {
398
+ tsl::BlockUntilReady(Worker::Parallelize(
399
+ &device, /*num_workers=*/num_tasks, num_tasks,
400
+ [=, &device](Eigen::Index task_index) {
401
401
  Eigen::Index start = task_index * task_size;
402
402
  Eigen::Index end = std::min(start + task_size, feature_group_count);
403
403
  for (Eigen::Index i = start; i < end; ++i) {
404
404
  auto [output, convolved] = convolve_group(i);
405
405
  output.device(device) = convolved;
406
406
  }
407
- barrier.Notify();
408
- });
409
- barrier.Wait();
407
+ }));
410
408
  }
411
409
 
412
410
  } else {
@@ -63,6 +63,8 @@ class KernelThunkBase : public Thunk {
63
63
  const = 0;
64
64
 
65
65
  virtual absl::Span<const BufferAllocation::Slice> results_buffers() const = 0;
66
+
67
+ virtual const absl::flat_hash_set<int64_t>& invariant_arguments() const = 0;
66
68
  };
67
69
 
68
70
  namespace internal {
@@ -95,6 +97,10 @@ class KernelThunk : public KernelThunkBase {
95
97
  return absl::MakeSpan(results_buffers_);
96
98
  }
97
99
 
100
+ const absl::flat_hash_set<int64_t>& invariant_arguments() const final {
101
+ return invariant_arguments_;
102
+ }
103
+
98
104
  protected:
99
105
  tsl::AsyncValueRef<ExecuteEvent> ExecuteInternal(const ExecuteParams& params);
100
106
 
@@ -129,7 +135,7 @@ class KernelThunk : public KernelThunkBase {
129
135
  KernelThunk(Info info,
130
136
  absl::Span<const BufferAllocation::Slice> arguments_buffers,
131
137
  absl::Span<const BufferAllocation::Slice> results_buffers,
132
- std::optional<absl::flat_hash_set<int64_t>> invariant_arguments,
138
+ absl::flat_hash_set<int64_t> invariant_arguments,
133
139
  std::string kernel_name, se::ThreadDim thread_dim,
134
140
  std::optional<uint64_t> min_alignment);
135
141
 
@@ -139,7 +145,7 @@ class KernelThunk : public KernelThunkBase {
139
145
  ResultsBuffers results_buffers_;
140
146
 
141
147
  // A set of invariant arguments (their indices).
142
- std::optional<absl::flat_hash_set<int64_t>> invariant_arguments_;
148
+ absl::flat_hash_set<int64_t> invariant_arguments_;
143
149
 
144
150
  size_t num_kernel_args_;
145
151
 
@@ -189,7 +195,7 @@ class KernelThunk final : public internal::KernelThunk<> {
189
195
  absl::Span<const BufferAllocation::Slice> arguments_buffers,
190
196
  absl::Span<const BufferAllocation::Slice> results_buffers,
191
197
  std::string kernel_name, se::ThreadDim thread_dim,
192
- std::optional<absl::flat_hash_set<int64_t>> invariant_arguments,
198
+ absl::flat_hash_set<int64_t> invariant_arguments,
193
199
  std::optional<uint64_t> min_alignment = std::nullopt);
194
200
 
195
201
  static absl::StatusOr<std::unique_ptr<Thunk>> Create(
@@ -29,7 +29,6 @@ limitations under the License.
29
29
  #include "absl/base/attributes.h"
30
30
  #include "absl/base/optimization.h"
31
31
  #include "absl/container/fixed_array.h"
32
- #include "absl/log/check.h"
33
32
  #include "absl/status/status.h"
34
33
  #include "xla/tsl/concurrency/async_value_ref.h"
35
34
  #include "xla/tsl/concurrency/chain.h"
@@ -44,15 +43,6 @@ namespace xla::cpu {
44
43
  // A work queue that partitions `num_tasks` tasks into `num_partitions`
45
44
  // partitions processed by parallel workers.
46
45
  class WorkQueue {
47
- // Align all atomic counters to a cache line boundary to avoid false
48
- // sharing between multiple worker threads.
49
- static constexpr size_t kAtomicAlignment =
50
- #if defined(__cpp_lib_hardware_interference_size)
51
- std::hardware_destructive_interference_size;
52
- #else
53
- 64;
54
- #endif
55
-
56
46
  public:
57
47
  WorkQueue(size_t num_tasks, size_t num_partitions);
58
48
 
@@ -60,13 +50,23 @@ class WorkQueue {
60
50
  // if the partition is complete.
61
51
  std::optional<size_t> Pop(size_t partition_index);
62
52
 
63
- size_t num_partitions() const { return partitions_.size(); }
53
+ // Return the partition [begin, end) task range.
54
+ std::pair<size_t, size_t> partition_range(size_t partition_index) const;
64
55
 
65
- bool empty() const { return empty_.load(std::memory_order_relaxed); }
56
+ size_t num_partitions() const { return partitions_.size(); }
66
57
 
67
58
  private:
68
59
  friend class Worker;
69
60
 
61
+ // Align all atomic counters to a cache line boundary to avoid false
62
+ // sharing between multiple worker threads.
63
+ static constexpr size_t kAtomicAlignment =
64
+ #if defined(__cpp_lib_hardware_interference_size)
65
+ std::hardware_destructive_interference_size;
66
+ #else
67
+ 64;
68
+ #endif
69
+
70
70
  struct Partition {
71
71
  void Initialize(size_t begin, size_t end);
72
72
 
@@ -76,8 +76,21 @@ class WorkQueue {
76
76
  size_t end;
77
77
  };
78
78
 
79
+ // An empty work queue flag to stop worker threads from looping through all
80
+ // partitions looking for work.
81
+ bool IsEmpty() const { return empty_.load(std::memory_order_relaxed); }
82
+ void SetEmpty() { empty_.store(true, std::memory_order_relaxed); }
83
+
84
+ // Notify that one of the workers switched to the work stealing mode.
85
+ void NotifyWorkStealingWorker();
86
+
87
+ // Decrements the number of work stealing workers by at most `max_workers` and
88
+ // returns the number of decremented work stealing workers.
89
+ size_t DecrementWorkStealingWorkers(size_t max_workers);
90
+
79
91
  absl::FixedArray<Partition, 32> partitions_;
80
92
  alignas(kAtomicAlignment) std::atomic<bool> empty_;
93
+ alignas(kAtomicAlignment) std::atomic<size_t> num_work_stealing_workers_;
81
94
  };
82
95
 
83
96
  // Worker processes tasks from the work queue starting from the assigned
@@ -130,10 +143,14 @@ inline void WorkQueue::Partition::Initialize(size_t begin, size_t end) {
130
143
  }
131
144
 
132
145
  inline WorkQueue::WorkQueue(size_t num_tasks, size_t num_partitions)
133
- : partitions_(num_partitions), empty_(num_tasks == 0) {
134
- size_t partition_size = tsl::MathUtil::CeilOfRatio(num_tasks, num_partitions);
135
- for (size_t i = 0, begin = 0, end = partition_size; i < num_partitions;
136
- ++i, begin = end, end = std::min(num_tasks, end + partition_size)) {
146
+ : partitions_(num_partitions),
147
+ empty_(num_tasks == 0),
148
+ num_work_stealing_workers_(0) {
149
+ size_t partition_size =
150
+ tsl::MathUtil::FloorOfRatio(num_tasks, num_partitions);
151
+ size_t rem_tasks = num_tasks % num_partitions;
152
+ for (size_t i = 0, begin = 0, end = 0; i < num_partitions; ++i, begin = end) {
153
+ end = begin + partition_size + ((i < rem_tasks) ? 1 : 0);
137
154
  partitions_[i].Initialize(begin, end);
138
155
  }
139
156
  }
@@ -154,6 +171,29 @@ inline std::optional<size_t> WorkQueue::Pop(size_t partition_index) {
154
171
  : std::make_optional(index);
155
172
  }
156
173
 
174
+ inline std::pair<size_t, size_t> WorkQueue::partition_range(
175
+ size_t partition_index) const {
176
+ DCHECK(partition_index < partitions_.size()) << "Invalid partition index";
177
+ return {partitions_[partition_index].begin, partitions_[partition_index].end};
178
+ }
179
+
180
+ inline void WorkQueue::NotifyWorkStealingWorker() {
181
+ num_work_stealing_workers_.fetch_add(1, std::memory_order_relaxed);
182
+ }
183
+
184
+ inline size_t WorkQueue::DecrementWorkStealingWorkers(size_t max_workers) {
185
+ size_t n = num_work_stealing_workers_.load(std::memory_order_relaxed);
186
+
187
+ size_t decrement = std::min(n, max_workers);
188
+ while (decrement && !num_work_stealing_workers_.compare_exchange_weak(
189
+ n, n - decrement, std::memory_order_relaxed,
190
+ std::memory_order_relaxed)) {
191
+ decrement = std::min(n, max_workers);
192
+ }
193
+
194
+ return decrement;
195
+ }
196
+
157
197
  inline Worker::Worker(size_t worker_index, WorkQueue* queue)
158
198
  : worker_index_(worker_index),
159
199
  partition_index_(worker_index),
@@ -163,7 +203,13 @@ inline std::optional<size_t> Worker::Pop() {
163
203
  std::optional<size_t> task = queue_->Pop(partition_index_);
164
204
  if (ABSL_PREDICT_TRUE(task)) return task;
165
205
 
166
- while (!task.has_value() && !queue_->empty()) {
206
+ // If we didn't find a task in the initially assigned partition, notify the
207
+ // work queue that we are switching to work stealing mode.
208
+ if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
209
+ queue_->NotifyWorkStealingWorker();
210
+ }
211
+
212
+ while (!task.has_value() && !queue_->IsEmpty()) {
167
213
  // Wrap around to the first partition.
168
214
  if (ABSL_PREDICT_FALSE(++partition_index_ >= queue_->num_partitions())) {
169
215
  partition_index_ = 0;
@@ -171,7 +217,7 @@ inline std::optional<size_t> Worker::Pop() {
171
217
 
172
218
  // We checked all partitions and got back to the partition we started from.
173
219
  if (ABSL_PREDICT_FALSE(partition_index_ == worker_index_)) {
174
- queue_->empty_.store(true, std::memory_order_relaxed);
220
+ queue_->SetEmpty();
175
221
  break;
176
222
  }
177
223
 
@@ -205,6 +251,7 @@ Worker::ParallelizeContext<ParallelTask>::ParallelizeContext(
205
251
  parallel_task(std::forward<ParallelTask>(parallel_task)) {}
206
252
 
207
253
  template <typename ParallelTask>
254
+ // NOLINTNEXTLINE(readability-function-cognitive-complexity)
208
255
  void Worker::ParallelizeWithContext(ParallelizeContext<ParallelTask>* ctx,
209
256
  uint16_t start_index, uint16_t end_index) {
210
257
  DCHECK_LT(start_index, end_index) << "Invalid worker index range";
@@ -223,11 +270,26 @@ void Worker::ParallelizeWithContext(ParallelizeContext<ParallelTask>* ctx,
223
270
  while (end_index - start_index > 1) {
224
271
  // If work queue is empty, we don't need to keep enqueuing more workers and
225
272
  // can simply count down for the remaining workers.
226
- if (ABSL_PREDICT_FALSE(ctx->work_queue.empty())) {
273
+ if (ABSL_PREDICT_FALSE(ctx->work_queue.IsEmpty())) {
227
274
  count_down(end_index - start_index, absl::OkStatus());
228
275
  return;
229
276
  }
230
277
 
278
+ // If we have workers in the work stealing mode, we can skip enqueuing
279
+ // more tasks as existing workers will process remaining partitions. By
280
+ // doing this optimization we avoid unnecessary thread pool overheads.
281
+ size_t skip_workers =
282
+ ctx->work_queue.DecrementWorkStealingWorkers(end_index - start_index);
283
+ if (ABSL_PREDICT_FALSE(skip_workers > 0)) {
284
+ DCHECK_LE(skip_workers, end_index - start_index);
285
+ count_down(skip_workers, absl::OkStatus());
286
+
287
+ end_index -= skip_workers;
288
+ if (start_index == end_index) return;
289
+ if (end_index - start_index == 1) break;
290
+ }
291
+
292
+ DCHECK_GE(end_index - start_index, 1);
231
293
  uint16_t mid_index = (start_index + end_index) / 2;
232
294
  ctx->device->enqueueNoNotification([ctx, mid_index, end_index] {
233
295
  ParallelizeWithContext(ctx, mid_index, end_index);
@@ -17,12 +17,14 @@ limitations under the License.
17
17
  #define XLA_CODEGEN_KERNEL_SPEC_H_
18
18
 
19
19
  #include <cstddef>
20
+ #include <cstdint>
20
21
  #include <optional>
21
22
  #include <string>
22
23
 
24
+ #include "absl/container/flat_hash_set.h"
23
25
  #include "absl/container/inlined_vector.h"
24
26
  #include "absl/strings/string_view.h"
25
- #include "xla/runtime/buffer_use.h"
27
+ #include "xla/service/buffer_assignment.h"
26
28
  #include "xla/stream_executor/launch_dim.h"
27
29
 
28
30
  namespace xla {
@@ -33,15 +35,17 @@ namespace xla {
33
35
  // will load kernel PTX on device and instantiate a KernelThunk.
34
36
  class KernelSpec {
35
37
  public:
36
- using BufferUses = absl::InlinedVector<BufferUse, 8>;
38
+ using Buffers = absl::InlinedVector<BufferAllocation::Slice, 8>;
37
39
 
38
40
  KernelSpec(absl::string_view name, se::ThreadDim thread_dim,
39
- BufferUses buffer_uses,
41
+ Buffers argument_buffers, Buffers result_buffers,
42
+ absl::flat_hash_set<int64_t> invariant_arguments,
40
43
  std::optional<size_t> scratch_bytes = std::nullopt);
41
44
 
42
45
  KernelSpec(absl::string_view name, se::ClusterDim cluster_dim,
43
46
  se::BlockDim block_dim, se::ThreadDim thread_dim,
44
- BufferUses buffer_uses,
47
+ Buffers argument_buffers, Buffers result_buffers,
48
+ absl::flat_hash_set<int64_t> invariant_arguments,
45
49
  std::optional<size_t> scratch_bytes = std::nullopt);
46
50
 
47
51
  // Get the backend specific name of the kernel.
@@ -67,15 +71,28 @@ class KernelSpec {
67
71
  // managed buffer that is likely to be in L1/L2 cache).
68
72
  std::optional<size_t> scratch_bytes() const { return scratch_bytes_; }
69
73
 
70
- // Buffers (buffer allocation slices) used by the kernel.
71
- const BufferUses& buffer_uses() const { return buffer_uses_; }
74
+ // Argument buffers read by the kernel.
75
+ const Buffers& argument_buffers() const { return argument_buffers_; }
76
+ // Result buffers written to by the kernel.
77
+ const Buffers& result_buffers() const { return result_buffers_; }
78
+
79
+ // Returns a set of invariant arguments (corresponding to the indices in the
80
+ // argument buffers list).
81
+ const absl::flat_hash_set<int64_t>& invariant_arguments() const {
82
+ return invariant_arguments_;
83
+ }
72
84
 
73
85
  private:
74
86
  std::string name_;
75
87
  se::ClusterDim cluster_dim_;
76
88
  se::BlockDim block_dim_;
77
89
  se::ThreadDim thread_dim_;
78
- BufferUses buffer_uses_;
90
+
91
+ Buffers argument_buffers_;
92
+ Buffers result_buffers_;
93
+
94
+ absl::flat_hash_set<int64_t> invariant_arguments_;
95
+
79
96
  std::optional<size_t> scratch_bytes_;
80
97
  };
81
98