tf-nightly-cpu: 2.20.0.dev20250220-cp39-cp39-win_amd64.whl → 2.20.0.dev20250221-cp39-cp39-win_amd64.whl
- tensorflow/_api/v2/compat/v1/summary/__init__.py +2 -2
- tensorflow/_api/v2/compat/v1/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/compat/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/compat/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/compat/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/_api/v2/summary/__init__.py +10 -10
- tensorflow/_api/v2/summary/experimental/__init__.py +4 -4
- tensorflow/_api/v2/tpu/experimental/embedding/__init__.py +2 -2
- tensorflow/compiler/mlir/stablehlo/stablehlo_extension.pyd +0 -0
- tensorflow/compiler/tf2tensorrt/_pywrap_py_utils.pyd +0 -0
- tensorflow/compiler/tf2xla/ops/_xla_ops.so +0 -0
- tensorflow/include/external/llvm-project/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +12 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_pass_utils/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/_virtual_includes/stablehlo_passes/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/external/stablehlo/stablehlo/transforms/PassUtils.h +7 -0
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/tensorflow/compiler/xla/backends/cpu/runtime/work_queue.h +81 -18
- tensorflow/include/tensorflow/compiler/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/tensorflow/compiler/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/tensorflow/compiler/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/tensorflow/compiler/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/tensorflow/compiler/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/tensorflow/compiler/xla/service/constant_value.h +1 -0
- tensorflow/include/tensorflow/compiler/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/tensorflow/compiler/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/tensorflow/compiler/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/include/tensorflow/core/kernels/eigen_attention.h +4 -4
- tensorflow/include/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +10 -8
- tensorflow/include/tensorflow/core/kernels/eigen_cuboid_convolution.h +6 -6
- tensorflow/include/tensorflow/core/kernels/eigen_pooling.h +12 -12
- tensorflow/include/tensorflow/core/public/release_version.h +39 -0
- tensorflow/include/tensorflow/core/public/version.h +112 -127
- tensorflow/include/tensorflow/python/eager/pywrap_tfe.h +1 -1
- tensorflow/include/xla/backends/cpu/codegen/kernel_api_ir_builder.h +3 -2
- tensorflow/include/xla/backends/cpu/runtime/kernel_thunk.h +9 -3
- tensorflow/include/xla/backends/cpu/runtime/work_queue.h +81 -18
- tensorflow/include/xla/codegen/kernel_spec.h +24 -7
- tensorflow/include/xla/hlo/ir/hlo_casting_utils.h +0 -44
- tensorflow/include/xla/mlir_hlo/_virtual_includes/stablehlo_extension_pass_inc_gen/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/mlir_hlo/stablehlo_ext/transforms/passes.h.inc +149 -4
- tensorflow/include/xla/pjrt/distributed/client.h +5 -0
- tensorflow/include/xla/pjrt/gpu/se_gpu_pjrt_client.h +1 -92
- tensorflow/include/xla/pjrt/gpu/se_gpu_topology_description.h +126 -0
- tensorflow/include/xla/pjrt/pjrt_stream_executor_client.h +1 -49
- tensorflow/include/xla/pjrt/pjrt_stream_executor_device_description.h +75 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_execute_options.h +57 -0
- tensorflow/include/xla/pjrt/plugin/xla_cpu/cpu_topology.h +4 -0
- tensorflow/include/xla/service/constant_value.h +1 -0
- tensorflow/include/xla/service/hlo_module_util.h +52 -1
- tensorflow/include/xla/service/hlo_proto_util.h +0 -12
- tensorflow/include/xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h +5 -5
- tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so +0 -0
- tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.pyd +0 -0
- tensorflow/lite/python/interpreter_wrapper/_pywrap_tensorflow_interpreter_wrapper.pyd +0 -0
- tensorflow/lite/python/optimize/_pywrap_tensorflow_lite_calibration_wrapper.pyd +0 -0
- tensorflow/python/_pywrap_dtensor_device.pyd +0 -0
- tensorflow/python/_pywrap_mlir.pyd +0 -0
- tensorflow/python/_pywrap_parallel_device.pyd +0 -0
- tensorflow/python/_pywrap_quantize_training.pyd +0 -0
- tensorflow/python/_pywrap_tensorflow_internal.pyd +0 -0
- tensorflow/python/_pywrap_tfcompile.pyd +0 -0
- tensorflow/python/_pywrap_tfe.pyd +0 -0
- tensorflow/python/client/_pywrap_debug_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_device_lib.pyd +0 -0
- tensorflow/python/client/_pywrap_events_writer.pyd +0 -0
- tensorflow/python/client/_pywrap_tf_session.pyd +0 -0
- tensorflow/python/compat/compat.py +1 -1
- tensorflow/python/data/experimental/service/_pywrap_server_lib.pyd +0 -0
- tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyd +0 -0
- tensorflow/python/eager/imperative_grad.py +5 -5
- tensorflow/python/eager/polymorphic_function/atomic_function.py +1 -1
- tensorflow/python/eager/polymorphic_function/compiler_ir.py +1 -1
- tensorflow/python/eager/polymorphic_function/polymorphic_function.py +45 -41
- tensorflow/python/eager/tape.py +2 -2
- tensorflow/python/framework/_dtypes.pyd +0 -0
- tensorflow/python/framework/_op_def_library_pybind.pyd +0 -0
- tensorflow/python/framework/_op_def_registry.pyd +0 -0
- tensorflow/python/framework/_proto_comparators.pyd +0 -0
- tensorflow/python/framework/_pywrap_python_op_gen.pyd +0 -0
- tensorflow/python/framework/_test_metrics_util.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_cluster.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_item.pyd +0 -0
- tensorflow/python/grappler/_pywrap_tf_optimizer.pyd +0 -0
- tensorflow/python/lib/core/_pywrap_py_func.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_file_io.pyd +0 -0
- tensorflow/python/lib/io/_pywrap_record_io.pyd +0 -0
- tensorflow/python/ops/summary_ops_v2.py +5 -1
- tensorflow/python/platform/_pywrap_tf2.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler.pyd +0 -0
- tensorflow/python/profiler/internal/_pywrap_profiler_plugin.pyd +0 -0
- tensorflow/python/saved_model/pywrap_saved_model.pyd +0 -0
- tensorflow/python/tpu/_pywrap_sparse_core_layout.pyd +0 -0
- tensorflow/python/tpu/_pywrap_tpu_embedding.pyd +0 -0
- tensorflow/python/tpu/tpu_embedding_v3.py +14 -7
- tensorflow/python/util/_pywrap_checkpoint_reader.pyd +0 -0
- tensorflow/python/util/_pywrap_kernel_registry.pyd +0 -0
- tensorflow/python/util/_pywrap_stat_summarizer.pyd +0 -0
- tensorflow/python/util/_pywrap_tfprof.pyd +0 -0
- tensorflow/python/util/_pywrap_transform_graph.pyd +0 -0
- tensorflow/python/util/_pywrap_utils.pyd +0 -0
- tensorflow/python/util/_tf_stack.pyd +0 -0
- tensorflow/tools/pip_package/setup.py +2 -2
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/METADATA +1 -1
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/RECORD +115 -108
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/WHEEL +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/entry_points.txt +0 -0
- {tf_nightly_cpu-2.20.0.dev20250220.dist-info → tf_nightly_cpu-2.20.0.dev20250221.dist-info}/top_level.txt +0 -0

xla/pjrt/pjrt_stream_executor_client.h
CHANGED
@@ -52,9 +52,9 @@ limitations under the License.
 #include "xla/pjrt/pjrt_client.h"
 #include "xla/pjrt/pjrt_common.h"
 #include "xla/pjrt/pjrt_compiler.h"
-#include "xla/pjrt/pjrt_device_description.h"
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/pjrt_future.h"
+#include "xla/pjrt/pjrt_stream_executor_device_description.h"
 #include "xla/pjrt/tracked_device_buffer.h"
 #include "xla/pjrt/transpose.h"
 #include "xla/pjrt/utils.h"
@@ -77,54 +77,6 @@ limitations under the License.
 
 namespace xla {
 
-class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
- public:
-  explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
-                                               int process_index = 0)
-      : id_(id),
-        process_index_(process_index),
-        device_kind_(std::move(device_kind)) {}
-
-  int id() const override { return id_; }
-
-  int process_index() const override { return process_index_; }
-
-  absl::string_view device_kind() const override { return device_kind_; }
-
-  absl::string_view ToString() const override { return to_string_; }
-
-  absl::string_view DebugString() const override { return debug_string_; }
-
-  absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
-
-  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
-      const override {
-    return attributes_;
-  }
-
-  void SetAttributes(
-      absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
-    attributes_ = std::move(attributes);
-  }
-
-  void SetDebugString(std::string debug_string) {
-    debug_string_ = std::move(debug_string);
-  }
-
-  void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
-
-  void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
-
- private:
-  const int id_;
-  const int process_index_;
-  const std::string device_kind_;
-  std::string debug_string_ = "<unknown SE device>";
-  std::string to_string_ = "<unknown SE device>";
-  absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
-  std::array<int, 1> coords_;
-};
-
 class PjRtStreamExecutorDevice : public PjRtDevice {
  public:
   explicit PjRtStreamExecutorDevice(

xla/pjrt/pjrt_stream_executor_device_description.h
ADDED
@@ -0,0 +1,75 @@
+/* Copyright 2025 The OpenXLA Authors.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+#define XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
+
+#include <array>
+#include <string>
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/pjrt/pjrt_device_description.h"
+
+namespace xla {
+
+class PjRtStreamExecutorDeviceDescription : public PjRtDeviceDescription {
+ public:
+  explicit PjRtStreamExecutorDeviceDescription(int id, std::string device_kind,
+                                               int process_index = 0)
+      : id_(id),
+        process_index_(process_index),
+        device_kind_(std::move(device_kind)) {}
+
+  int id() const override { return id_; }
+
+  int process_index() const override { return process_index_; }
+
+  absl::string_view device_kind() const override { return device_kind_; }
+
+  absl::string_view ToString() const override { return to_string_; }
+
+  absl::string_view DebugString() const override { return debug_string_; }
+
+  absl::Span<int const> coords() const { return absl::MakeSpan(coords_); }
+
+  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
+      const override {
+    return attributes_;
+  }
+
+  void SetAttributes(
+      absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes) {
+    attributes_ = std::move(attributes);
+  }
+
+  void SetDebugString(std::string debug_string) {
+    debug_string_ = std::move(debug_string);
+  }
+
+  void SetToString(std::string to_string) { to_string_ = std::move(to_string); }
+
+  void SetCoords(std::array<int, 1> coords) { coords_ = coords; }
+
+ private:
+  const int id_;
+  const int process_index_;
+  const std::string device_kind_;
+  std::string debug_string_ = "<unknown SE device>";
+  std::string to_string_ = "<unknown SE device>";
+  absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_;
+  std::array<int, 1> coords_;
+};
+
+}  // namespace xla
+
+#endif  // XLA_PJRT_PJRT_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
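
A minimal usage sketch for the relocated class (illustrative only, not part of the diff; the id, kind, and coordinate values are made up):

    #include "xla/pjrt/pjrt_stream_executor_device_description.h"

    void DescribeDevice() {
      // id = 3, device kind "gpu"; process_index defaults to 0.
      xla::PjRtStreamExecutorDeviceDescription desc(3, "gpu");
      desc.SetToString("gpu:3");                          // returned by ToString()
      desc.SetDebugString("StreamExecutorDevice(id=3)");  // returned by DebugString()
      desc.SetCoords({3});                                // one-element coordinate array
      // Through the PjRtDeviceDescription interface: desc.id() == 3,
      // desc.process_index() == 0, desc.device_kind() == "gpu".
    }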

xla/pjrt/plugin/xla_cpu/cpu_execute_options.h
ADDED
@@ -0,0 +1,57 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
+#define XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
+
+#include <optional>
+
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
+#include "xla/pjrt/pjrt_executable.h"
+
+namespace xla {
+
+// ExecuteContext for XLA:CPU PjRtLoadedExecutable::Execute calls.
+class CpuExecuteContext : public ExecuteContext {
+ public:
+  ~CpuExecuteContext() override = default;
+
+  // If specified, override the process ID specified in
+  // `CpuClientOptions::process_id` for a particular call of
+  // PjRtLoadedExecutable::Execute.
+  //
+  // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
+  // option with `ExecuteOptions::multi_slice_config`.
+  std::optional<int>& process_index() { return process_index_; }
+  std::optional<int> process_index() const { return process_index_; }
+
+  // If specified, override CPU collectives specified in
+  // `CpuClientOptions::collectives` for a particular call of
+  // PjRtLoadedExecutable::Execute. Must remain valid until the execution
+  // finishes.
+  //
+  // TODO(hyeontaek): Look for a collectives-agnostic way and combine this
+  // option with `ExecuteOptions::multi_slice_config`.
+  cpu::CpuCollectives*& collectives() { return collectives_; }
+  cpu::CpuCollectives* collectives() const { return collectives_; }
+
+ private:
+  std::optional<int> process_index_;
+  cpu::CpuCollectives* collectives_ = nullptr;
+};
+
+}  // namespace xla
+
+#endif  // XLA_PJRT_PLUGIN_XLA_CPU_CPU_EXECUTE_OPTIONS_H_
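
A hedged usage sketch for the new CpuExecuteContext (not part of the diff). It assumes xla::ExecuteOptions exposes a `context` pointer as in pjrt_executable.h, and the executable call is left hypothetical:

    #include "xla/pjrt/pjrt_executable.h"
    #include "xla/pjrt/plugin/xla_cpu/cpu_execute_options.h"

    void ConfigureCpuExecute() {
      xla::CpuExecuteContext cpu_context;
      // Override CpuClientOptions::process_id for this Execute call only.
      cpu_context.process_index() = 1;
      // cpu_context.collectives() could likewise point at a live
      // cpu::CpuCollectives; it must stay valid until execution finishes.

      xla::ExecuteOptions options;
      options.context = &cpu_context;  // assumption: ExecuteOptions::context
      // executable->Execute(argument_handles, options);  // hypothetical call
    }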

xla/pjrt/plugin/xla_cpu/cpu_topology.h
CHANGED
@@ -69,6 +69,10 @@ inline int UnpackCpuProcessIndex(PjRtGlobalDeviceId global_device_id) {
   return global_device_id.value() / kMaxCpuDevicesPerProcess;
 }
 
+inline int UnpackCpuLocalDeviceId(PjRtGlobalDeviceId global_device_id) {
+  return global_device_id.value() % kMaxCpuDevicesPerProcess;
+}
+
 }  // namespace xla
 
 #endif  // XLA_PJRT_PLUGIN_XLA_CPU_CPU_TOPOLOGY_H_
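
The two helpers invert a simple packing of (process index, local device id) into one global id. An illustrative round-trip sketch, not part of the diff, assuming kMaxCpuDevicesPerProcess is visible to includers of cpu_topology.h (its value is not shown here):

    #include "xla/pjrt/plugin/xla_cpu/cpu_topology.h"

    // global = process_index * kMaxCpuDevicesPerProcess + local_device_id, so:
    //   UnpackCpuProcessIndex(id)  == global / kMaxCpuDevicesPerProcess
    //   UnpackCpuLocalDeviceId(id) == global % kMaxCpuDevicesPerProcess  (new)
    inline void RoundTripExample(int process_index, int local_device_id) {
      const int global =
          process_index * xla::kMaxCpuDevicesPerProcess + local_device_id;
      const xla::PjRtGlobalDeviceId id(global);
      // xla::UnpackCpuProcessIndex(id) recovers process_index;
      // xla::UnpackCpuLocalDeviceId(id) recovers local_device_id.
    }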

xla/service/hlo_module_util.h
CHANGED
@@ -19,16 +19,67 @@ limitations under the License.
 #include <functional>
 #include <memory>
 #include <optional>
+#include <string>
 
-#include "absl/
+#include "absl/log/check.h"
+#include "absl/log/log.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/parser/hlo_parser.h"
 #include "xla/service/compiler.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/shape.h"
+#include "xla/util.h"
 
 namespace xla {
 
+// Converts an HloModule from the given hlo textual IR string (in
+// HloModule::ToString format).
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString(
+    absl::string_view hlo_string,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Creates an HloModule from the given proto.
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+    const HloModuleProto& proto,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Create an HLO state from serialized representation. In addition to
+// creating the proto with HloModule::CreateFromProto(...) it also
+// uses HloVerifier to ensure basic invariants are held.
+// The HLO module could be a pre-optimizations (default) or post-optimizations
+// module, which affects how the HLO module is verified, e.g., mixed-precision
+// is allowed in post-optimizations HLOs.
+absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+    const HloModuleProto& proto, const HloModuleConfig& module_config,
+    bool is_module_post_optimizations = false);
+
+// Reads the proto file in xla.HloProto format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromBinaryProtoFile(
+    absl::string_view filename,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
+// Reads the proto file in xla.HloModule format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromModuleBinaryProtofile(
+    absl::string_view filename, const DebugOptions& debug_options);
+
+// Reads the HLO text dump file in HloModule::ToString format, creates and
+// returns the HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloTextFile(
+    absl::string_view filename,
+    const DebugOptions& debug_options = DebugOptions::default_instance(),
+    const HloParserOptions& options = HloParserOptions());
+
+// Reads the proto file in xla.HloProto format, creates and returns the
+// HloModule.
+absl::StatusOr<std::unique_ptr<HloModule>> ReadModuleFromTextProtoFile(
+    absl::string_view hlo_file,
+    const DebugOptions& debug_options = DebugOptions::default_instance());
+
 // Creates an HloModuleConfig for a given program shape and arguments.
 // If execution_options does not set num_replicas, default_num_replicas is used.
 // num_threads is optional; if not given, intra_op_parallelism_threads not set.
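
A hedged sketch of the parsing entry points now declared in this header (not part of the diff; the HLO text is a made-up example):

    #include <memory>

    #include "absl/status/statusor.h"
    #include "absl/strings/string_view.h"
    #include "xla/hlo/ir/hlo_module.h"
    #include "xla/service/hlo_module_util.h"

    absl::StatusOr<std::unique_ptr<xla::HloModule>> ParseExample() {
      constexpr absl::string_view kHloText = R"(
        HloModule add
        ENTRY main {
          p0 = f32[2] parameter(0)
          p1 = f32[2] parameter(1)
          ROOT sum = f32[2] add(p0, p1)
        })";
      // Parses HloModule::ToString-format text; default DebugOptions apply.
      return xla::CreateModuleFromString(kHloText);
    }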

xla/service/hlo_proto_util.h
CHANGED
@@ -18,8 +18,6 @@ limitations under the License.
 #ifndef XLA_SERVICE_HLO_PROTO_UTIL_H_
 #define XLA_SERVICE_HLO_PROTO_UTIL_H_
 
-#include <string>
-
 #include "absl/status/status.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/buffer_assignment.h"
@@ -35,16 +33,6 @@ HloProto MakeHloProto(const HloModule& module,
 // will not be included in the output.
 HloProto MakeHloProto(const HloModule& module);
 
-// Create an HLO state from serialized representation. In addition to
-// creating the proto with HloModule::CreateFromProto(...) it also
-// uses HloVerifier to ensure basic invariants are held.
-// The HLO module could be a pre-optimizations (default) or post-optimizations
-// module, which affects how the HLO module is verified, e.g., mixed-precision
-// is allowed in post-optimizations HLOs.
-absl::StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
-    const HloModuleProto& proto, const HloModuleConfig& module_config,
-    bool is_module_post_optimizations = false);
-
 // Returns the shapes of the parameters of the entry computation. Shape pointers
 // refer to shapes inside of the given HloProto.
 absl::StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(

xla/tsl/framework/convolution/eigen_spatial_convolutions-inl.h
CHANGED
@@ -1604,12 +1604,12 @@ SpatialConvolution(const Input& input, const Kernel& kernel,
                    Index padding_left = 0, Index padding_right = 0) {
   typedef typename internal::traits<Input>::Index TensorIndex;
   typedef typename internal::traits<Input>::Scalar InputScalar;
-  TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
-                   internal::traits<Input>::Layout, TensorIndex> >
+  TensorRef<const Tensor<InputScalar, internal::traits<Input>::NumDimensions,
+                         internal::traits<Input>::Layout, TensorIndex> >
       in(input);
-  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
-                   internal::traits<Kernel>::NumDimensions,
-                   internal::traits<Kernel>::Layout, TensorIndex> >
+  TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
+                         internal::traits<Kernel>::NumDimensions,
+                         internal::traits<Kernel>::Layout, TensorIndex> >
       kern(kernel);
 
   EIGEN_STATIC_ASSERT(

tensorflow/python/compat/compat.py
CHANGED
@@ -29,7 +29,7 @@ from tensorflow.python.util.tf_export import tf_export
 # This value changes every day with an automatic CL. It can be modified in code
 # via `forward_compatibility_horizon()` or with the environment variable
 # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2,
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2025, 2, 20)
 _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
 _FORWARD_COMPATIBILITY_DATE_NUMBER = None
 

tensorflow/python/eager/imperative_grad.py
CHANGED
@@ -43,13 +43,13 @@ def imperative_grad(tape,
     target: either a Tensor or list of Tensors to be differentiated.
     sources: list of Tensors for which we want gradients
     output_gradients: if not None, a list of gradient provided for each Target,
-
+      or None if we are to use the target's computed downstream gradient.
     sources_raw: if not None, a list of the source python objects from which the
-
-
+      sources were generated. Should have the same length as sources. Only needs
+      to be populated if unconnected_gradients is 'zero'.
     unconnected_gradients: determines the value returned if the target and
-
-
+      sources are unconnected. When 'none' the value returned is None whereas
+      when 'zero' a zero tensor in the same shape as the sources is returned.
 
   Returns:
     the gradient wrt each of the sources.

tensorflow/python/eager/polymorphic_function/atomic_function.py
CHANGED
@@ -55,7 +55,7 @@ class CallOptions:
   # Used by ACD to list Ops/Tensors/Callables that must be called in advance.
   control_captures: List[Any] = dataclasses.field(default_factory=list)
 
-  # Determines what kind of
+  # Determines what kind of partitioned call is used for this function.
   is_stateful: bool = False
 
 

tensorflow/python/eager/polymorphic_function/compiler_ir.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""
+"""Implementation for defining get_compiler_ir."""
 from typing import List, Optional
 import warnings
 

tensorflow/python/eager/polymorphic_function/polymorphic_function.py
CHANGED
@@ -966,7 +966,7 @@ class Function(core.PolymorphicFunction, trackable.Trackable):
 
 def _check_inputs(args, kwargs):
   all_inputs = list(args) + list(kwargs.values())
-  #
+  # Empty input is okay.
   if not all_inputs:
     return
   if any(map(is_tensor_spec, all_inputs)) and any(
@@ -1423,7 +1423,8 @@ def function(
  thought of as compile-time constants), and builds a separate `tf.Graph` for
  each set of Python arguments that it encounters.
  For more information, see the
-  [tf.function
+  [tf.function
+  guide](https://www.tensorflow.org/guide/function#rules_of_tracing)
 
  Executing a `PolymorphicFunction` will select and execute the appropriate
  `ConcreteFunction` based on the argument types and values.
@@ -1440,14 +1441,17 @@ def function(
  >>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
  True
 
-  `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but
-
+  `ConcreteFunction`s can be executed just like `PolymorphicFunction`s, but
+  their
+  input is restricted to the types to which they're specialized.
 
  ## Retracing
 
-  `ConcreteFunctions` are built (traced) on the fly, as the
+  `ConcreteFunctions` are built (traced) on the fly, as the
+  `PolymorphicFunction` is
  called with new TensorFlow types or shapes, or with new Python values as
-  arguments. When `PolymorphicFunction` builds a new trace, it is said that
+  arguments. When `PolymorphicFunction` builds a new trace, it is said that
+  `func`
  is retraced. Retracing is a frequent performance concern for `tf.function` as
  it can be considerably slower than executing a graph that's already been
  traced. It is ideal to minimize the amount of retracing in your code.
@@ -1473,7 +1477,8 @@ def function(
 
  ## Input signatures
 
-  For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction`
+  For Tensor arguments, `PolymorphicFunction`creates a new `ConcreteFunction`
+  for
  every unique set of input shapes and datatypes. The example below creates two
  separate `ConcreteFunction`s, each specialized to a different shape:
 
@@ -1580,59 +1585,58 @@ def function(
      `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
    autograph: Whether autograph should be applied on `func` before tracing a
      graph. Data-dependent Python control flow statements require
-      `autograph=True`. For more information, see the
-
+      `autograph=True`. For more information, see the [tf.function and AutoGraph
+      guide](
      https://www.tensorflow.org/guide/function#autograph_transformations).
    jit_compile: If `True`, compiles the function using
      [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations,
      such as fusion, and attempts to emit more efficient code. This may
-      drastically improve the performance. If set to `True`,
-
-
-
-
-
-
-
-
-
-
-
-      amount of retracing, for example by using more generic shapes. This
-      can be controlled for user objects by customizing their associated
+      drastically improve the performance. If set to `True`, the whole function
+      needs to be compilable by XLA, or an `errors.InvalidArgumentError` is
+      thrown. If `None` (default), compiles the function with XLA when running
+      on TPU and goes through the regular function execution path when running
+      on other devices. If `False`, executes the function without XLA
+      compilation. Set this value to `False` when directly running a
+      multi-device function on TPUs (e.g. two TPU cores, one TPU core and its
+      host CPU). Not all functions are compilable, see a list of [sharp
+      corners](https://tensorflow.org/xla/known_issues).
+    reduce_retracing: When True, `tf.function` attempts to reduce the amount of
+      retracing, for example by using more generic shapes. This can be
+      controlled for user objects by customizing their associated
      `tf.types.experimental.TraceType`.
    experimental_implements: If provided, contains a name of a "known" function
-      this implements. For example "mycompany.my_recurrent_cell".
-
-
-
-      for details. For an example of utilizing
+      this implements. For example "mycompany.my_recurrent_cell". This is stored
+      as an attribute in inference function, which can then be detected when
+      processing serialized function. See [standardizing composite
+      ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md)
+      # pylint: disable=line-too-long for details. For an example of utilizing
+      this attribute see this
      [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc)
      The code above automatically detects and substitutes function that
      implements "embedded_matmul" and allows TFLite to substitute its own
-      implementations. For instance, a tensorflow user can use this
-
-
-
-
-
-
-
-
+      implementations. For instance, a tensorflow user can use this attribute to
+      mark that their function also implements `embedded_matmul` (perhaps more
+      efficiently!) by specifying it using this parameter:
+      `@tf.function(experimental_implements="embedded_matmul")` This can either
+      be specified as just the string name of the function or a NameAttrList
+      corresponding to a list of key-value attributes associated with the
+      function name. The name of the function will be in the 'name' field of the
+      NameAttrList. To define a formal TF op for this function implements, try
+      the experimental [composite
+      TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr)
      project.
    experimental_autograph_options: Optional tuple of
      `tf.autograph.experimental.Feature` values.
    experimental_attributes: Optional dictionary of attributes to include in the
      generated FunctionDefs.
-    experimental_relax_shapes: Deprecated. Use `reduce_retracing`
-      instead.
+    experimental_relax_shapes: Deprecated. Use `reduce_retracing` instead.
    experimental_compile: Deprecated alias to 'jit_compile'.
    experimental_follow_type_hints: Deprecated. Please use input_signature or
      reduce_retracing instead.
 
  Returns:
-    If `func` is not None, returns a
+    If `func` is not None, returns a
+    `tf.types.experimental.PolymorphicFunction`.
    If `func` is None, returns a decorator that, when invoked with a single
    `func` argument, returns a `tf.types.experimental.PolymorphicFunction`.
 
tensorflow/python/eager/tape.py
CHANGED
@@ -48,9 +48,9 @@ def watch(tape, tensor):
 def default_get_variables(variable):
   return [variable]
 
-# Gets a list of changed variables. Can be
+# Gets a list of changed variables. Can be overridden using
 # register_variables_override. An example of overriding is for getting the
-#
+# variables within a distributed context.
 _variables_override = default_get_variables
 
 

tensorflow/python/ops/summary_ops_v2.py
CHANGED
@@ -151,7 +151,11 @@ def _legacy_contrib_should_record_summaries():
 
 def is_recording_summaries():
   """Returns non-Tensor boolean indicating if summaries are being recorded."""
-
+  if _summary_state.writer is None:
+    return False
+  if _summary_state.is_recording is None:
+    return False
+  return _summary_state.is_recording
 
 
 @tf_export("summary.record_if", v1=[])