torch-ddp 0.1.0

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
@@ -0,0 +1,155 @@
+ require "mkmf-rice"
+
+ $CXXFLAGS += " -std=c++17 $(optflags)"
+
+ # change to 0 for Linux pre-cxx11 ABI version
+ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
+
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
+
+ if apple_clang
+   # silence torch warnings
+   $CXXFLAGS += " -Wno-deprecated-declarations"
+ else
+   # silence rice warnings
+   $CXXFLAGS += " -Wno-noexcept-type"
+
+   # silence torch warnings
+   $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
+ end
+
+ paths = [
+   "/usr/local",
+   "/opt/homebrew",
+   "/home/linuxbrew/.linuxbrew"
+ ]
+
+ inc, lib = dir_config("torch")
+ inc ||= paths.map { |v| "#{v}/include" }.find { |v| Dir.exist?("#{v}/torch") }
+ lib ||= paths.map { |v| "#{v}/lib" }.find { |v| Dir["#{v}/*torch_cpu*"].any? }
+
+ unless inc && lib
+   abort "LibTorch not found"
+ end
+
+ cuda_inc, cuda_lib = dir_config("cuda")
+ cuda_lib ||= "/usr/local/cuda/lib64"
+
+ cudnn_inc, cudnn_lib = dir_config("cudnn")
+ cudnn_lib ||= "/usr/local/cuda/lib"
+
+ gloo_inc, _ = dir_config("gloo")
+ gloo_inc ||= "./vendor/gloo"
+
+ $LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
+ abort "LibTorch not found" unless have_library("torch")
+
+ have_library("mkldnn")
+ have_library("nnpack")
+
+ with_cuda = false
+ if Dir["#{lib}/*torch_cuda*"].any?
+   $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
+   $INCFLAGS += " -I#{cuda_inc}" if cuda_inc && Dir.exist?(cuda_inc)
+   $LDFLAGS += " -L#{cudnn_lib}" if Dir.exist?(cudnn_lib) && cudnn_lib != cuda_lib
+   with_cuda = have_library("cuda") && have_library("cudnn")
+ end
+
+ $INCFLAGS += " -I#{inc}"
+ $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
+
+ CONFIG["CC"] = CONFIG["CXX"]
+ $CFLAGS = $CXXFLAGS
+
+ supports_c10_cuda = with_cuda && try_compile(<<~CPP)
+   #include <torch/torch.h>
+   #include <c10/cuda/CUDAFunctions.h>
+
+   int main() {
+     c10::cuda::set_device(0);
+     return 0;
+   }
+ CPP
+
+ if supports_c10_cuda
+   $defs << "-DHAVE_C10_CUDA"
+ end
+
+ $LDFLAGS += " -Wl,-rpath,#{lib}"
+ if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
+   $LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
+ end
+ $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
+
+ # https://github.com/pytorch/pytorch/blob/v2.9.0/torch/utils/cpp_extension.py#L1351-L1364
+ $LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
+ if with_cuda
+   $LDFLAGS += " -lcuda -lnvrtc"
+   $LDFLAGS += " -lnvToolsExt" if File.exist?("#{cuda_lib}/libnvToolsExt.so")
+   $LDFLAGS += " -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
+   # TODO: figure out why this is needed
+   $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
+ end
+
+ supports_c10d = try_link(<<~CPP, "-DUSE_C10D")
+   #include <torch/torch.h>
+   #include <torch/csrc/distributed/c10d/FileStore.hpp>
+
+   int main() {
+     ::c10d::FileStore store("unused", 1);
+     return 0;
+   }
+ CPP
+
+ if supports_c10d
+   $defs << "-DUSE_C10D"
+   puts "Building with distributed support"
+
+   if find_header("gloo/algorithm.h", gloo_inc)
+     $INCFLAGS += " -I#{gloo_inc}"
+   else
+     puts "Gloo headers not found. Consider passing the --with-gloo-include option"
+   end
+ else
+   puts "Building without distributed support"
+ end
+
+ supports_c10d_gloo = supports_c10d && try_link(<<~CPP, "-DUSE_C10D -DUSE_C10D_GLOO")
+   #include <torch/torch.h>
+   #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
+   #include <torch/csrc/distributed/c10d/FileStore.hpp>
+
+   int main() {
+     auto store = c10::make_intrusive<::c10d::FileStore>("unused", 1);
+     auto opts = ::c10d::ProcessGroupGloo::Options::create();
+     opts->devices.push_back(::c10d::ProcessGroupGloo::createDefaultDevice());
+     ::c10d::ProcessGroupGloo pg(store, 0, 1, opts);
+     return static_cast<int>(pg.getRank());
+   }
+ CPP
+
+ supports_c10d_nccl = with_cuda && supports_c10_cuda && try_link(<<~CPP, "-DUSE_C10D -DUSE_C10D_NCCL")
+   #include <torch/torch.h>
+   #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+
+   int main() {
+     auto opts = c10::make_intrusive<::c10d::ProcessGroupNCCL::Options>();
+     opts->is_high_priority_stream = false;
+     return 0;
+   }
+ CPP
+
+ if supports_c10d_gloo
+   $defs << "-DUSE_C10D_GLOO"
+   puts "Gloo support detected"
+ end
+ unless supports_c10_cuda
+   puts "No c10 CUDA headers found. NCCL support is unavailable"
+ end
+ if supports_c10d_nccl
+   $defs << "-DUSE_C10D_NCCL"
+   puts "NCCL support detected"
+ end
+
+ # create makefile
+ create_makefile("torch/ddp_ext")
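
Note on configuration: the extconf above resolves LibTorch through mkmf's dir_config before falling back to the Homebrew-style path scan, so a build can be pointed at a custom LibTorch without editing the script. A minimal sketch of the standard mkmf mechanism (the /opt/libtorch prefix is illustrative):

    require "mkmf"

    # dir_config("torch") honors flags passed through gem install, e.g.
    #   gem install torch-ddp -- --with-torch-dir=/opt/libtorch
    # or the split --with-torch-include= / --with-torch-lib= variants.
    inc, lib = dir_config("torch")
    # => ["/opt/libtorch/include", "/opt/libtorch/lib"] when the flag is given,
    # => [nil, nil] otherwise, which triggers the fallback search above.

The same flags apply to dir_config("cuda"), dir_config("cudnn"), and dir_config("gloo").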
@@ -0,0 +1,325 @@
+ require "fiddle"
+
+ module Torch
+   module DDP
+     module MonkeyPatch
+       WARNING_PREFIX = "[torch-ddp]".freeze
+
+       class << self
+         def apply_if_needed
+           return if defined?(@applied) && @applied
+
+           missing = missing_features
+           return if missing.empty?
+
+           warn("#{WARNING_PREFIX} Applying torch compatibility patch for: #{missing.join(', ')}. Please upgrade the torch gem for native support.")
+           patch_cuda_set_device if missing.include?(:cuda_set_device)
+           patch_device_helpers
+           patch_load if missing.include?(:load_keywords)
+           patch_tensor_item if missing.include?(:tensor_item_scalar)
+           @applied = true
+         end
+
+         private
+
+         def missing_features
+           missing = []
+           missing << :cuda_set_device unless Torch.const_defined?(:CUDA) && Torch::CUDA.respond_to?(:set_device)
+           missing << :load_keywords unless load_supports_map_location_and_weights_only?
+           missing << :tensor_item_scalar unless tensor_item_returns_scalar?
+           missing
+         end
+
+         def load_supports_map_location_and_weights_only?
+           params = Torch.method(:load).parameters
+           keyword_names = params.select { |kind, _| [:key, :keyreq].include?(kind) }.map(&:last)
+           keyword_names.include?(:map_location) && keyword_names.include?(:weights_only)
+         rescue NameError
+           false
+         end
+
+         def tensor_item_returns_scalar?
+           value = Torch.tensor([[1]]).item
+           value.is_a?(Numeric) || value == true || value == false
+         rescue StandardError
+           true
+         end
+
+         def patch_cuda_set_device
+           return unless Torch.const_defined?(:CUDA)
+
+           Torch::CUDA.singleton_class.class_eval do
+             define_method(:set_device) do |device_id|
+               Torch::DDP::MonkeyPatch.cuda_set_device!(device_id)
+             end
+           end
+         end
+
+         def cuda_set_device!(device_id)
+           cuda_set_device_proc.call(Integer(device_id))
+         end
+         public :cuda_set_device!
+
+         def cuda_set_device_proc
+           @cuda_set_device_proc ||= begin
+             candidates = [
+               ENV["LIBCUDART_PATH"],
+               "/usr/local/cuda/lib64/libcudart.so",
+               "/usr/local/cuda/lib/libcudart.so",
+               "/usr/local/cuda/lib/libcudart.dylib",
+               "libcudart.so.12",
+               "libcudart.so.11",
+               "libcudart.so",
+               "libcudart.dylib"
+             ].compact
+
+             function = nil
+             candidates.each do |path|
+               begin
+                 handle = Fiddle.dlopen(path)
+                 function = Fiddle::Function.new(handle["cudaSetDevice"], [Fiddle::TYPE_INT], Fiddle::TYPE_INT)
+                 break
+               rescue Fiddle::DLError
+                 next
+               end
+             end
+
+             if function
+               ->(device_id) do
+                 result = function.call(device_id)
+                 raise Torch::Error, "cudaSetDevice(#{device_id}) failed with code #{result}" unless result.zero?
+                 nil
+               end
+             else
+               ->(_device_id) do
+                 raise Torch::Error, "Torch::CUDA.set_device is unavailable; ensure torch is built with CUDA or upgrade torch."
+               end
+             end
+           end
+         end
+
+         def patch_device_helpers
+           Torch::Device.class_eval do
+             define_method(:to_s) { _str }
+           end
+
+           unless Torch.const_defined?(:DeviceString)
+             Torch.const_set(
+               :DeviceString,
+               Class.new(String) do
+                 def initialize(device)
+                   @device = device
+                   super(device._str)
+                 end
+
+                 def type
+                   @device.type
+                 end
+
+                 def index
+                   @device.index
+                 end
+               end
+             )
+           end
+
+           Torch::Tensor.class_eval do
+             define_method(:device) { Torch::DeviceString.new(_device) }
+           end
+         end
+
+         def patch_tensor_item
+           Torch::Tensor.class_eval do
+             alias_method :_torch_ddp_original_item, :item unless method_defined?(:_torch_ddp_original_item)
+
+             def item
+               value = _torch_ddp_original_item
+               value.is_a?(Array) ? value.flatten.first : value
+             end
+           end
+         end
+
+         def patch_load
+           patch_load_helpers
+
+           Torch.singleton_class.class_eval do
+             alias_method :_torch_ddp_original_load, :load unless method_defined?(:_torch_ddp_original_load)
+
+             def load(filename, map_location: nil, weights_only: false)
+               load_device = map_location_device(map_location) if map_location
+               result =
+                 if load_device && respond_to?(:_load_with_device)
+                   Torch::DDP::MonkeyPatch.load_with_device(filename, load_device)
+                 else
+                   _torch_ddp_original_load(filename)
+                 end
+
+               ensure_weights_only_contents!(result) if weights_only
+               result = apply_map_location(result, map_location) if map_location
+               result
+             end
+           end
+         end
+
+         def patch_load_helpers
+           Torch.singleton_class.class_eval do
+             const_set(
+               :WEIGHTS_ONLY_PRIMITIVE_CLASSES,
+               [NilClass, TrueClass, FalseClass, Integer, Float, String].freeze
+             ) unless const_defined?(:WEIGHTS_ONLY_PRIMITIVE_CLASSES)
+
+             unless method_defined?(:ensure_weights_only_contents!)
+               def ensure_weights_only_contents!(obj)
+                 # The constant lives on Torch's singleton class; constants in a
+                 # class_eval block resolve lexically, so qualify the reference.
+                 case obj
+                 when *singleton_class::WEIGHTS_ONLY_PRIMITIVE_CLASSES, Tensor
+                   obj
+                 when Array
+                   obj.each { |value| ensure_weights_only_contents!(value) }
+                 when Hash
+                   obj.each do |key, value|
+                     ensure_weights_only_contents!(key)
+                     ensure_weights_only_contents!(value)
+                   end
+                 else
+                   raise Error, "weights_only load supports tensors, primitive Ruby types, arrays, and hashes (found #{obj.class.name})"
+                 end
+               end
+             end
+
+             unless method_defined?(:map_location_device)
+               def map_location_device(map_location)
+                 case map_location
+                 when Device, String, Symbol
+                   normalize_map_location_device(map_location)
+                 when Hash
+                   devices = map_location.values.filter_map do |value|
+                     begin
+                       normalize_map_location_device(value)
+                     rescue StandardError
+                       nil
+                     end
+                   end
+                   return nil if devices.empty?
+                   devices.uniq!
+                   devices.one? ? devices.first : nil
+                 else
+                   nil
+                 end
+               end
+             end
+
+             unless method_defined?(:apply_map_location)
+               def apply_map_location(obj, map_location)
+                 case obj
+                 when Tensor
+                   map_tensor_location(obj, map_location)
+                 when Array
+                   obj.map { |value| apply_map_location(value, map_location) }
+                 when Hash
+                   obj.each_with_object({}) do |(key, value), memo|
+                     memo[apply_map_location(key, map_location)] = apply_map_location(value, map_location)
+                   end
+                 else
+                   obj
+                 end
+               end
+             end
+
+             unless method_defined?(:map_tensor_location)
+               def map_tensor_location(tensor, map_location)
+                 case map_location
+                 when nil
+                   tensor
+                 when Hash
+                   target = lookup_map_location_target(map_location, tensor.device)
+                   return tensor if target.nil?
+                   map_tensor_location(tensor, target)
+                 else
+                   return map_tensor_location_callable(tensor, map_location) if map_location.respond_to?(:call)
+                   device = normalize_map_location_device(map_location)
+                   tensor.to(device)
+                 end
+               end
+             end
+
+             unless method_defined?(:map_tensor_location_callable)
+               def map_tensor_location_callable(tensor, callable)
+                 mapped = callable.call(tensor, map_location_device_tag(tensor.device))
+                 return tensor if mapped.nil?
+                 unless mapped.is_a?(Tensor)
+                   raise Error, "map_location callable must return a Tensor or nil (got #{mapped.class.name})"
+                 end
+                 mapped
+               end
+             end
+
+             unless method_defined?(:lookup_map_location_target)
+               def lookup_map_location_target(mapping, device)
+                 key = map_location_device_tag(device)
+                 mapping.each do |candidate, value|
+                   candidate_key =
+                     case candidate
+                     when Device
+                       map_location_device_tag(candidate)
+                     when String, Symbol
+                       candidate.to_s
+                     else
+                       candidate
+                     end
+                   return value if candidate_key == key
+                 end
+                 nil
+               end
+             end
+
+             unless method_defined?(:map_location_device_tag)
+               def map_location_device_tag(device)
+                 case device
+                 when Device
+                   tag = device.type
+                   tag += ":#{device.index}" unless device.index.nil?
+                   tag
+                 when String, Symbol
+                   device.to_s
+                 else
+                   raise Error, "Unknown device reference: #{device.inspect}"
+                 end
+               end
+             end
+
+             unless method_defined?(:normalize_map_location_device)
+               def normalize_map_location_device(location)
+                 case location
+                 when Device
+                   location
+                 when String, Symbol
+                   device(location.to_s)
+                 else
+                   raise Error, "Unsupported map_location: #{location.inspect}"
+                 end
+               end
+             end
+           end
+         end
+       end
+
+       module_function
+
+       def load_with_device(filename, device)
+         fallback_load =
+           if Torch.respond_to?(:_torch_ddp_original_load)
+             Torch.method(:_torch_ddp_original_load)
+           else
+             Torch.method(:load)
+           end
+
+         return fallback_load.call(filename) unless Torch.respond_to?(:_load_with_device)
+
+         device_str = device.respond_to?(:_str) ? device._str : device.to_s
+         Torch.send(:to_ruby, Torch._load_with_device(filename, device_str))
+       rescue StandardError
+         fallback_load.call(filename)
+       end
+     end
+   end
+ end
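
For orientation, a minimal usage sketch of the shim above, assuming the gem applies it at require time (the require path and checkpoint name are illustrative):

    require "torch"
    require "torch/ddp" # hypothetical entry point for this gem

    # Idempotent: patches only the features the installed torch gem lacks.
    Torch::DDP::MonkeyPatch.apply_if_needed

    # Backfilled Torch::CUDA.set_device, routed to cudaSetDevice via Fiddle.
    Torch::CUDA.set_device(0) if Torch::CUDA.available?

    # Backfilled keyword arguments on Torch.load.
    state = Torch.load("checkpoint.pt", map_location: "cpu", weights_only: true)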
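
The patched Torch.load accepts the three map_location forms handled by map_tensor_location above (filenames illustrative):

    # String, Symbol, or Torch::Device: move every tensor to that device.
    Torch.load("model.pt", map_location: "cpu")

    # Hash: remap by device tag; tensors on unlisted devices are left in place.
    Torch.load("model.pt", map_location: { "cuda:0" => "cpu" })

    # Callable: receives (tensor, device_tag) and returns a remapped tensor,
    # or nil to keep the original.
    Torch.load("model.pt", map_location: ->(t, tag) { tag.start_with?("cuda") ? t.to("cpu") : nil })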
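
Likewise, patch_device_helpers makes Tensor#device return Torch::DeviceString, a String subclass that keeps the structured accessors, so callers can treat the device as a plain string or as an object:

    device = Torch.tensor([1.0]).device # => Torch::DeviceString
    device == "cpu"                     # compares as a plain String
    device.type                         # => "cpu"
    device.index                        # => nil for an unindexed device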
@@ -0,0 +1,5 @@
+ module Torch
+   module DDP
+     VERSION = "0.1.0"
+   end
+ end