monai-weekly 1.5.dev2444__py3-none-any.whl → 1.5.dev2446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/bundle/scripts.py +2 -0
- monai/networks/blocks/__init__.py +1 -0
- monai/networks/blocks/mednext_block.py +309 -0
- monai/networks/nets/__init__.py +19 -0
- monai/networks/nets/mednext.py +354 -0
- monai/networks/nets/vista3d.py +0 -1
- monai/networks/trt_compiler.py +161 -55
- monai/networks/utils.py +11 -5
- monai/transforms/utility/array.py +2 -2
- monai/utils/__init__.py +1 -0
- monai/utils/module.py +41 -0
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/METADATA +1 -1
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/RECORD +18 -16
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/WHEEL +1 -1
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2444.dist-info → monai_weekly-1.5.dev2446.dist-info}/top_level.txt +0 -0
monai/networks/trt_compiler.py
CHANGED
@@ -18,12 +18,12 @@ import threading
|
|
18
18
|
from collections import OrderedDict
|
19
19
|
from pathlib import Path
|
20
20
|
from types import MethodType
|
21
|
-
from typing import Any, Dict, List, Union
|
21
|
+
from typing import Any, Dict, List, Tuple, Union
|
22
22
|
|
23
23
|
import torch
|
24
24
|
|
25
25
|
from monai.apps.utils import get_logger
|
26
|
-
from monai.networks.utils import add_casts_around_norms, convert_to_onnx,
|
26
|
+
from monai.networks.utils import add_casts_around_norms, convert_to_onnx, get_profile_shapes
|
27
27
|
from monai.utils.module import optional_import
|
28
28
|
|
29
29
|
polygraphy, polygraphy_imported = optional_import("polygraphy")
|
@@ -125,6 +125,7 @@ class TRTEngine:
|
|
125
125
|
self.output_names = []
|
126
126
|
self.dtypes = []
|
127
127
|
self.cur_profile = 0
|
128
|
+
self.input_table = {}
|
128
129
|
dtype_dict = trt_to_torch_dtype_dict()
|
129
130
|
for idx in range(self.engine.num_io_tensors):
|
130
131
|
binding = self.engine[idx]
|
@@ -134,6 +135,9 @@ class TRTEngine:
|
|
134
135
|
self.output_names.append(binding)
|
135
136
|
dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]
|
136
137
|
self.dtypes.append(dtype)
|
138
|
+
self.logger.info(
|
139
|
+
f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}"
|
140
|
+
)
|
137
141
|
|
138
142
|
def allocate_buffers(self, device):
|
139
143
|
"""
|
@@ -163,7 +167,8 @@ class TRTEngine:
|
|
163
167
|
last_profile = self.cur_profile
|
164
168
|
|
165
169
|
def try_set_inputs():
|
166
|
-
for binding
|
170
|
+
for binding in self.input_names:
|
171
|
+
t = feed_dict.get(self.input_table[binding], None)
|
167
172
|
if t is not None:
|
168
173
|
t = t.contiguous()
|
169
174
|
shape = t.shape
|
@@ -180,7 +185,8 @@ class TRTEngine:
|
|
180
185
|
raise
|
181
186
|
self.cur_profile = next_profile
|
182
187
|
ctx.set_optimization_profile_async(self.cur_profile, stream)
|
183
|
-
|
188
|
+
except Exception:
|
189
|
+
raise
|
184
190
|
left = ctx.infer_shapes()
|
185
191
|
assert len(left) == 0
|
186
192
|
|
@@ -217,6 +223,74 @@ class TRTEngine:
|
|
217
223
|
return self.tensors
|
218
224
|
|
219
225
|
|
226
|
+
def make_tensor(d):
    """
    Return ``d`` as a tensor.

    Args:
        d: a ``torch.Tensor`` or any value accepted by ``torch.tensor()``.

    Returns:
        ``d`` itself (unchanged, on whatever device it already lives) when it is
        already a ``torch.Tensor``; otherwise a new tensor built from ``d`` and
        moved to the current CUDA device.
    """
    if isinstance(d, torch.Tensor):
        return d
    return torch.tensor(d).cuda()
|
228
|
+
|
229
|
+
|
230
|
+
def unroll_input(input_names, input_example):
    """
    Flatten named inputs into a single-level dict, simulating the list/tuple
    unrolling that happens during ONNX export.

    Args:
        input_names: names to look up in ``input_example``; every name must be present.
        input_example: mapping from input name to a value, a list/tuple of values, or None.

    Returns:
        Dict in which each list/tuple input ``name`` contributes entries
        ``name_0``, ``name_1``, ... and every other non-None input contributes
        ``name``. Each stored value is passed through ``make_tensor``.
        Inputs whose value is None are omitted.

    Raises:
        KeyError: if a name from ``input_names`` is missing in ``input_example``.
    """
    flat = {}
    for name in input_names:
        value = input_example[name]
        if value is None:
            # None inputs do not appear among the unrolled inputs at all
            continue
        if isinstance(value, (list, tuple)):
            for idx, item in enumerate(value):
                flat[f"{name}_{idx}"] = make_tensor(item)
            continue
        flat[name] = make_tensor(value)
    return flat
|
242
|
+
|
243
|
+
|
244
|
+
def _group_size(spec):
    """Decode one 'output_lists' entry: 0 = single tensor, n > 0 = list of n items, -1 = dynamic list."""
    assert len(spec) == 0 or len(spec) == 1
    size = spec[0] if spec else 0
    if size < -1:
        # Previously such an entry matched no branch and was silently skipped,
        # which dropped outputs without any diagnostics.
        raise ValueError(f"Invalid group size {size} in output_lists: expected [], [0], [n > 0] or [-1]")
    return size


def parse_groups(
    ret: List[torch.Tensor], output_lists: List[List[int]]
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]:
    """
    Implements parsing of 'output_lists' arg of trt_compile().

    Args:
        ret: plain list of Tensors

        output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list
                      of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
        Format: [[group_n] | [], ...]
          [] or group_n == 0 : next output from ret is a single item (not wrapped in a list)
          group_n > 0  :       next output from ret is a list of group_n length
          group_n == -1:       next output is a dynamic list. This entry can be at any
                               position in output_lists, but can appear only once.
    Returns:
        Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists

    Raises:
        ValueError: if more than one -1 entry is present, or a group size is below -1.
        AssertionError: if an entry of output_lists has more than one element.
    """
    # Accumulate in a list and convert once at the end instead of rebuilding a tuple per group.
    groups: list = []
    cur = 0
    for pos, spec in enumerate(output_lists):
        size = _group_size(spec)
        if size == 0:
            groups.append(ret[cur])
            cur += 1
        elif size > 0:
            groups.append(ret[cur : cur + size])
            cur += size
        else:
            # Dynamic list: its length is whatever remains after the fixed-size groups
            # that follow it are carved off the END of 'ret', so walk those backwards.
            tail: list = []
            rcur = len(ret)
            for rspec in reversed(output_lists[pos + 1 :]):
                rsize = _group_size(rspec)
                if rsize == 0:
                    rcur -= 1
                    tail.append(ret[rcur])
                elif rsize > 0:
                    rcur -= rsize
                    tail.append(ret[rcur : rcur + rsize])
                else:
                    raise ValueError("Two -1 lists in output")
            groups.append(ret[cur:rcur])
            # 'tail' was collected back-to-front; restore the declared order.
            groups.extend(reversed(tail))
            break
    return tuple(groups)
|
292
|
+
|
293
|
+
|
220
294
|
class TrtCompiler:
|
221
295
|
"""
|
222
296
|
This class implements:
|
@@ -233,6 +307,7 @@ class TrtCompiler:
|
|
233
307
|
method="onnx",
|
234
308
|
input_names=None,
|
235
309
|
output_names=None,
|
310
|
+
output_lists=None,
|
236
311
|
export_args=None,
|
237
312
|
build_args=None,
|
238
313
|
input_profiles=None,
|
@@ -240,6 +315,7 @@ class TrtCompiler:
|
|
240
315
|
use_cuda_graph=False,
|
241
316
|
timestamp=None,
|
242
317
|
fallback=False,
|
318
|
+
forward_override=None,
|
243
319
|
logger=None,
|
244
320
|
):
|
245
321
|
"""
|
@@ -255,6 +331,8 @@ class TrtCompiler:
|
|
255
331
|
'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.
|
256
332
|
input_names: Optional list of input names. If None, will be read from the function signature.
|
257
333
|
output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
|
334
|
+
output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list
|
335
|
+
of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
|
258
336
|
export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.
|
259
337
|
build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.
|
260
338
|
input_profiles: Optional list of profiles for TRT builder and ONNX export.
|
@@ -279,6 +357,7 @@ class TrtCompiler:
|
|
279
357
|
self.method = method
|
280
358
|
self.return_dict = output_names is not None
|
281
359
|
self.output_names = output_names or []
|
360
|
+
self.output_lists = output_lists or []
|
282
361
|
self.profiles = input_profiles or []
|
283
362
|
self.dynamic_batchsize = dynamic_batchsize
|
284
363
|
self.export_args = export_args or {}
|
@@ -289,11 +368,19 @@ class TrtCompiler:
|
|
289
368
|
self.disabled = False
|
290
369
|
|
291
370
|
self.logger = logger or get_logger("monai.networks.trt_compiler")
|
371
|
+
self.argspec = inspect.getfullargspec(model.forward)
|
292
372
|
|
293
373
|
# Normally we read input_names from forward() but can be overridden
|
294
374
|
if input_names is None:
|
295
|
-
|
296
|
-
|
375
|
+
input_names = self.argspec.args[1:]
|
376
|
+
self.defaults = {}
|
377
|
+
if self.argspec.defaults is not None:
|
378
|
+
for i in range(len(self.argspec.defaults)):
|
379
|
+
d = self.argspec.defaults[-i - 1]
|
380
|
+
if d is not None:
|
381
|
+
d = make_tensor(d)
|
382
|
+
self.defaults[self.argspec.args[-i - 1]] = d
|
383
|
+
|
297
384
|
self.input_names = input_names
|
298
385
|
self.old_forward = model.forward
|
299
386
|
|
@@ -314,9 +401,18 @@ class TrtCompiler:
|
|
314
401
|
"""
|
315
402
|
try:
|
316
403
|
self.engine = TRTEngine(self.plan_path, self.logger)
|
317
|
-
|
404
|
+
# Make sure we have names correct
|
405
|
+
input_table = {}
|
406
|
+
for name in self.engine.input_names:
|
407
|
+
if name.startswith("__") and name not in self.input_names:
|
408
|
+
orig_name = name[2:]
|
409
|
+
else:
|
410
|
+
orig_name = name
|
411
|
+
input_table[name] = orig_name
|
412
|
+
self.engine.input_table = input_table
|
413
|
+
self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}")
|
318
414
|
except Exception as e:
|
319
|
-
self.logger.
|
415
|
+
self.logger.info(f"Exception while loading the engine:\n{e}")
|
320
416
|
|
321
417
|
def forward(self, model, argv, kwargs):
|
322
418
|
"""
|
@@ -329,6 +425,11 @@ class TrtCompiler:
|
|
329
425
|
Returns: Passing through wrapped module's forward() return value(s)
|
330
426
|
|
331
427
|
"""
|
428
|
+
args = self.defaults
|
429
|
+
args.update(kwargs)
|
430
|
+
if len(argv) > 0:
|
431
|
+
args.update(self._inputs_to_dict(argv))
|
432
|
+
|
332
433
|
if self.engine is None and not self.disabled:
|
333
434
|
# Restore original forward for export
|
334
435
|
new_forward = model.forward
|
@@ -336,11 +437,10 @@ class TrtCompiler:
|
|
336
437
|
try:
|
337
438
|
self._load_engine()
|
338
439
|
if self.engine is None:
|
339
|
-
build_args =
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
# This will reassign input_names from the engine
|
440
|
+
build_args = args.copy()
|
441
|
+
with torch.no_grad():
|
442
|
+
self._build_and_save(model, build_args)
|
443
|
+
# This will reassign input_names from the engine
|
344
444
|
self._load_engine()
|
345
445
|
assert self.engine is not None
|
346
446
|
except Exception as e:
|
@@ -355,19 +455,16 @@ class TrtCompiler:
|
|
355
455
|
del param
|
356
456
|
# Call empty_cache to release GPU memory
|
357
457
|
torch.cuda.empty_cache()
|
458
|
+
# restore TRT hook
|
358
459
|
model.forward = new_forward
|
359
460
|
# Run the engine
|
360
461
|
try:
|
361
|
-
if len(argv) > 0:
|
362
|
-
kwargs.update(self._inputs_to_dict(argv))
|
363
|
-
argv = ()
|
364
|
-
|
365
462
|
if self.engine is not None:
|
366
463
|
# forward_trt is not thread safe as we do not use per-thread execution contexts
|
367
464
|
with lock_sm:
|
368
465
|
device = torch.cuda.current_device()
|
369
466
|
stream = torch.cuda.Stream(device=device)
|
370
|
-
self.engine.set_inputs(
|
467
|
+
self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream)
|
371
468
|
self.engine.allocate_buffers(device=device)
|
372
469
|
# Need this to synchronize with Torch stream
|
373
470
|
stream.wait_stream(torch.cuda.current_stream())
|
@@ -375,11 +472,13 @@ class TrtCompiler:
|
|
375
472
|
# if output_names is not None, return dictionary
|
376
473
|
if not self.return_dict:
|
377
474
|
ret = list(ret.values())
|
378
|
-
if
|
475
|
+
if self.output_lists:
|
476
|
+
ret = parse_groups(ret, self.output_lists)
|
477
|
+
elif len(ret) == 1:
|
379
478
|
ret = ret[0]
|
380
479
|
return ret
|
381
480
|
except Exception as e:
|
382
|
-
if
|
481
|
+
if self.fallback:
|
383
482
|
self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...")
|
384
483
|
else:
|
385
484
|
raise e
|
@@ -391,16 +490,11 @@ class TrtCompiler:
|
|
391
490
|
"""
|
392
491
|
|
393
492
|
profiles = []
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
p = Profile()
|
400
|
-
for name, dims in input_profile.items():
|
401
|
-
assert len(dims) == 3
|
402
|
-
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
|
403
|
-
profiles.append(p)
|
493
|
+
for profile in self.profiles:
|
494
|
+
p = Profile()
|
495
|
+
for id, val in profile.items():
|
496
|
+
p.add(id, min=val[0], opt=val[1], max=val[2])
|
497
|
+
profiles.append(p)
|
404
498
|
|
405
499
|
build_args = self.build_args.copy()
|
406
500
|
build_args["tf32"] = self.precision != "fp32"
|
@@ -425,7 +519,7 @@ class TrtCompiler:
|
|
425
519
|
return
|
426
520
|
|
427
521
|
export_args = self.export_args
|
428
|
-
|
522
|
+
engine_bytes = None
|
429
523
|
add_casts_around_norms(model)
|
430
524
|
|
431
525
|
if self.method == "torch_trt":
|
@@ -435,7 +529,6 @@ class TrtCompiler:
|
|
435
529
|
elif self.precision == "bf16":
|
436
530
|
enabled_precisions.append(torch.bfloat16)
|
437
531
|
inputs = list(input_example.values())
|
438
|
-
ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True)
|
439
532
|
|
440
533
|
def get_torch_trt_input(input_shape, dynamic_batchsize):
|
441
534
|
min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
|
@@ -445,12 +538,7 @@ class TrtCompiler:
|
|
445
538
|
|
446
539
|
tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs]
|
447
540
|
engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
|
448
|
-
|
449
|
-
"forward",
|
450
|
-
inputs=tt_inputs,
|
451
|
-
ir="torchscript",
|
452
|
-
enabled_precisions=enabled_precisions,
|
453
|
-
**export_args,
|
541
|
+
model, "forward", arg_inputs=tt_inputs, enabled_precisions=enabled_precisions, **export_args
|
454
542
|
)
|
455
543
|
else:
|
456
544
|
dbs = self.dynamic_batchsize
|
@@ -459,33 +547,47 @@ class TrtCompiler:
|
|
459
547
|
raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
|
460
548
|
if len(dbs) != 3:
|
461
549
|
raise ValueError("dynamic_batchsize has to have len ==3 ")
|
462
|
-
|
550
|
+
profile = {}
|
463
551
|
for id, val in input_example.items():
|
464
|
-
sh = val.shape[1:]
|
465
|
-
profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
|
466
|
-
self.profiles = [profiles]
|
467
552
|
|
468
|
-
|
469
|
-
|
553
|
+
def add_profile(id, val):
|
554
|
+
sh = val.shape
|
555
|
+
if len(sh) > 0:
|
556
|
+
sh = sh[1:]
|
557
|
+
profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
|
558
|
+
|
559
|
+
if isinstance(val, list) or isinstance(val, tuple):
|
560
|
+
for i in range(len(val)):
|
561
|
+
add_profile(f"{id}_{i}", val[i])
|
562
|
+
elif isinstance(val, torch.Tensor):
|
563
|
+
add_profile(id, val)
|
564
|
+
self.profiles = [profile]
|
565
|
+
|
566
|
+
self.dynamic_axes = get_dynamic_axes(self.profiles)
|
567
|
+
|
568
|
+
if len(self.dynamic_axes) > 0:
|
569
|
+
export_args.update({"dynamic_axes": self.dynamic_axes})
|
470
570
|
|
471
571
|
# Use temporary directory for easy cleanup in case of external weights
|
472
572
|
with tempfile.TemporaryDirectory() as tmpdir:
|
473
|
-
|
573
|
+
unrolled_input = unroll_input(self.input_names, input_example)
|
574
|
+
onnx_path = str(Path(tmpdir) / "model.onnx")
|
474
575
|
self.logger.info(
|
475
|
-
f"Exporting to {onnx_path}:\
|
576
|
+
f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n"
|
577
|
+
+ f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}"
|
476
578
|
)
|
477
579
|
convert_to_onnx(
|
478
580
|
model,
|
479
581
|
input_example,
|
480
|
-
filename=
|
481
|
-
input_names=
|
582
|
+
filename=onnx_path,
|
583
|
+
input_names=list(unrolled_input.keys()),
|
482
584
|
output_names=self.output_names,
|
483
585
|
**export_args,
|
484
586
|
)
|
485
587
|
self.logger.info("Export to ONNX successful.")
|
486
|
-
engine_bytes = self._onnx_to_trt(
|
487
|
-
|
488
|
-
|
588
|
+
engine_bytes = self._onnx_to_trt(onnx_path)
|
589
|
+
if engine_bytes:
|
590
|
+
open(self.plan_path, "wb").write(engine_bytes)
|
489
591
|
|
490
592
|
|
491
593
|
def trt_forward(self, *argv, **kwargs):
|
@@ -505,7 +607,9 @@ def trt_compile(
|
|
505
607
|
) -> torch.nn.Module:
|
506
608
|
"""
|
507
609
|
Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
|
508
|
-
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x
|
610
|
+
Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x.
|
611
|
+
NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
|
612
|
+
Review the TensorRT Support Matrix for which GPUs are supported.
|
509
613
|
Args:
|
510
614
|
model: module to patch with TrtCompiler object.
|
511
615
|
base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
|
@@ -540,9 +644,11 @@ def trt_compile(
|
|
540
644
|
args["timestamp"] = timestamp
|
541
645
|
|
542
646
|
def wrap(model, path):
|
543
|
-
|
544
|
-
|
545
|
-
|
647
|
+
if not hasattr(model, "_trt_compiler"):
|
648
|
+
model.orig_forward = model.forward
|
649
|
+
wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args)
|
650
|
+
model._trt_compiler = wrapper
|
651
|
+
model.forward = MethodType(trt_forward, model)
|
546
652
|
|
547
653
|
def find_sub(parent, submodule):
|
548
654
|
idx = submodule.find(".")
|
monai/networks/utils.py
CHANGED
@@ -632,7 +632,6 @@ def convert_to_onnx(
|
|
632
632
|
use_trace: bool = True,
|
633
633
|
do_constant_folding: bool = True,
|
634
634
|
constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
|
635
|
-
dynamo=False,
|
636
635
|
**kwargs,
|
637
636
|
):
|
638
637
|
"""
|
@@ -673,6 +672,9 @@ def convert_to_onnx(
|
|
673
672
|
# let torch.onnx.export to trace the model.
|
674
673
|
mode_to_export = model
|
675
674
|
torch_versioned_kwargs = kwargs
|
675
|
+
if "dynamo" in kwargs and kwargs["dynamo"] and verify:
|
676
|
+
torch_versioned_kwargs["verify"] = verify
|
677
|
+
verify = False
|
676
678
|
else:
|
677
679
|
if not pytorch_after(1, 10):
|
678
680
|
if "example_outputs" not in kwargs:
|
@@ -695,13 +697,13 @@ def convert_to_onnx(
|
|
695
697
|
f = temp_file.name
|
696
698
|
else:
|
697
699
|
f = filename
|
698
|
-
|
700
|
+
print(f"torch_versioned_kwargs={torch_versioned_kwargs}")
|
699
701
|
torch.onnx.export(
|
700
702
|
mode_to_export,
|
701
703
|
onnx_inputs,
|
702
704
|
f=f,
|
703
705
|
input_names=input_names,
|
704
|
-
output_names=output_names,
|
706
|
+
output_names=output_names or None,
|
705
707
|
dynamic_axes=dynamic_axes,
|
706
708
|
opset_version=opset_version,
|
707
709
|
do_constant_folding=do_constant_folding,
|
@@ -710,11 +712,15 @@ def convert_to_onnx(
|
|
710
712
|
onnx_model = onnx.load(f)
|
711
713
|
|
712
714
|
if do_constant_folding and polygraphy_imported:
|
713
|
-
from polygraphy.backend.onnx.loader import fold_constants
|
715
|
+
from polygraphy.backend.onnx.loader import fold_constants, save_onnx
|
714
716
|
|
715
|
-
fold_constants(onnx_model, size_threshold=constant_size_threshold)
|
717
|
+
onnx_model = fold_constants(onnx_model, size_threshold=constant_size_threshold)
|
718
|
+
save_onnx(onnx_model, f)
|
716
719
|
|
717
720
|
if verify:
|
721
|
+
if isinstance(inputs, dict):
|
722
|
+
inputs = list(inputs.values())
|
723
|
+
|
718
724
|
if device is None:
|
719
725
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
720
726
|
|
@@ -1609,9 +1609,9 @@ class ImageFilter(Transform):
|
|
1609
1609
|
|
1610
1610
|
def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None:
|
1611
1611
|
if isinstance(filter, str):
|
1612
|
-
if not filter_size:
|
1612
|
+
if filter != "gauss" and not filter_size: # Gauss is the only filter that does not require `filter_size`
|
1613
1613
|
raise ValueError("`filter_size` must be specified when specifying filters by string.")
|
1614
|
-
if filter_size % 2 == 0:
|
1614
|
+
if filter_size and filter_size % 2 == 0:
|
1615
1615
|
raise ValueError("`filter_size` should be a single uneven integer.")
|
1616
1616
|
if filter not in self.supported_filters:
|
1617
1617
|
raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.")
|
monai/utils/__init__.py
CHANGED
monai/utils/module.py
CHANGED
@@ -634,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
|
|
634
634
|
if is_prerelease:
|
635
635
|
return False
|
636
636
|
return True
|
637
|
+
|
638
|
+
|
639
|
+
@functools.lru_cache(None)
def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:
    """
    Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
    The current system GPU CUDA compute capability is determined by the first GPU in the system.
    The compared version is a string in the form of "major.minor".

    Args:
        major: major version number to be compared with.
        minor: minor version number to be compared with. Defaults to 0.
        current_ver_string: if None, the current system GPU CUDA compute capability will be used.

    Returns:
        True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
        When `current_ver_string` is None: True if `pynvml` is not importable (capability cannot be queried),
        False if CUDA is not available.
    """
    if current_ver_string is None:
        cuda_available = torch.cuda.is_available()
        pynvml, has_pynvml = optional_import("pynvml")
        if not has_pynvml:  # assuming that the user has Ampere and later GPU
            return True
        if not cuda_available:
            return False
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # get the first GPU
            major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        finally:
            # always release NVML, even if the device query fails
            pynvml.nvmlShutdown()
        current_ver_string = f"{major_c}.{minor_c}"

    ver, has_ver = optional_import("packaging.version", name="parse")
    if has_ver:
        # inclusive comparison, consistent with the fallback branch below and the "after or equal" contract
        return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}")  # type: ignore
    # fallback without `packaging`: compare (major, minor) integer pairs, ignoring any "+local" suffix
    parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
    while len(parts) < 2:
        parts += ["0"]
    c_major, c_minor = parts[:2]
    c_mn = int(c_major), int(c_minor)
    mn = int(major), int(minor)
    return c_mn >= mn
|
@@ -1,5 +1,5 @@
|
|
1
|
-
monai/__init__.py,sha256=
|
2
|
-
monai/_version.py,sha256=
|
1
|
+
monai/__init__.py,sha256=_63We8aEdR2cwJUl4GLqweKPjfwhJEQki2Y4S5eBTIU,4095
|
2
|
+
monai/_version.py,sha256=IWuhCmYr4-WIkzDUUBHMXFcqPiZq5p5wXJeqkcOIg98,503
|
3
3
|
monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
|
5
5
|
monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
|
@@ -114,7 +114,7 @@ monai/bundle/config_item.py,sha256=rMjXSGkjJZdi04BwSHwCcIwzIb_TflmC3xDhC3SVJRs,1
|
|
114
114
|
monai/bundle/config_parser.py,sha256=cGyEn-cqNk0rEEZ1Qiv6UydmIDvtWZcMVljyfVm5i50,23025
|
115
115
|
monai/bundle/properties.py,sha256=iN3K4FVmN9ny1Hw9p5j7_ULcCdSD8PmrR7qXxbNz49k,11582
|
116
116
|
monai/bundle/reference_resolver.py,sha256=5YTzVEoQDJSv-PF79abwYggXCZcFxaOa3veFVElme-M,16463
|
117
|
-
monai/bundle/scripts.py,sha256=
|
117
|
+
monai/bundle/scripts.py,sha256=ZxdkNI1D1LpAJpufnJSexZt13EpihSnzFzdo5DbH3NU,89316
|
118
118
|
monai/bundle/utils.py,sha256=t-22uFvLn7Yy-dr1v1U33peNOxgAmU4TJiGAbsBrUKs,10108
|
119
119
|
monai/bundle/workflows.py,sha256=a9X_yqVz_NPRj0N2ByXRDGXBWEiijzYEKv2qH14C324,24682
|
120
120
|
monai/config/__init__.py,sha256=CN28CfTdsp301gv8YXfVvkbztCfbAqrLKrJi_C8oP9s,1048
|
@@ -241,9 +241,9 @@ monai/metrics/surface_distance.py,sha256=bKDTm7ulhjfiphHLrDJoA3OKI3npwQy2Z5wY-Jk
|
|
241
241
|
monai/metrics/utils.py,sha256=eQ9QGGvuNmYFrgtVFNiA44pBhaHLCkmpyeK2FcK_2Pc,46941
|
242
242
|
monai/metrics/wrapper.py,sha256=c1zg-xcypQyZ840TEuhhLgr4sClYMWTxlv1OieJTtvE,11781
|
243
243
|
monai/networks/__init__.py,sha256=ZzU2Qo8gDXNiRBF0JapIo3xlecZHjXsJuarF0IKVKKY,1086
|
244
|
-
monai/networks/trt_compiler.py,sha256=
|
245
|
-
monai/networks/utils.py,sha256=
|
246
|
-
monai/networks/blocks/__init__.py,sha256
|
244
|
+
monai/networks/trt_compiler.py,sha256=IFfsM1qFZvmCUBbEvbHnZe6_zmMcXghkpkzmP43dZbk,27535
|
245
|
+
monai/networks/utils.py,sha256=Dio8_0Q2WQt3crgYxKlBrNvKWkRaDdbU1gvorF8v0lo,57184
|
246
|
+
monai/networks/blocks/__init__.py,sha256=xf-4SLQjL3bU7T_vCnAIbeBzz0Ys2rrtlegJM5bej-Q,2355
|
247
247
|
monai/networks/blocks/acti_norm.py,sha256=bVGXbTZ_ssRvmED5R7LOQ7jj4V6WbVFl8JMO-4iZ2Dk,4275
|
248
248
|
monai/networks/blocks/activation.py,sha256=S5k3zcP2PsHBkeIxgWgNg8ppW80tTResVP2j9ZsvTFw,5839
|
249
249
|
monai/networks/blocks/aspp.py,sha256=GGGE7NfWj77RkaWHbcLuUP4Aff-WeiDrtgtFuSoekQk,4380
|
@@ -261,6 +261,7 @@ monai/networks/blocks/fcn.py,sha256=mnCMrxhUdj2yZ0DPIj0Xf9OKVdv-qhG1BpnAg5j7q6c,
|
|
261
261
|
monai/networks/blocks/feature_pyramid_network.py,sha256=_DeAy_lNnPqjNiJLcopjqe_PHVThACctrgbXmSSB3Jw,10554
|
262
262
|
monai/networks/blocks/fft_utils_t.py,sha256=8bOvhLgP5nDLz8QwzD4XnRaxE9-tGba2-b_QDK8IWSs,8263
|
263
263
|
monai/networks/blocks/localnet_block.py,sha256=b2-ZZvkMPphHJZYTbwEZDhqA-mMBSFM5WQOoohk_6W4,11456
|
264
|
+
monai/networks/blocks/mednext_block.py,sha256=GKaFkRvmho79yxwfYyeSaJtHFtk185dY0tA4_rPnsQA,10487
|
264
265
|
monai/networks/blocks/mlp.py,sha256=qw_jgyrYwoQ5WYBM1rtSSaO4C837ZbctoRKhh_BQQFI,3341
|
265
266
|
monai/networks/blocks/patchembedding.py,sha256=tp0coxpi70LcUk03HbnygFeCxcBv5bNHJbw1crIG_Js,8956
|
266
267
|
monai/networks/blocks/pos_embed_utils.py,sha256=vFEQqxZ6UAmjcy_icFDL9EwjRHYXuIbWr1chWUJqO7g,4070
|
@@ -288,7 +289,7 @@ monai/networks/layers/spatial_transforms.py,sha256=fz2t7-ibijNLqTYpAn4ZgdXtzBSIy
|
|
288
289
|
monai/networks/layers/utils.py,sha256=k_2xVO8BTEMMVJtemUyKBWw4_5xtqd6OOTOG8qld8To,4916
|
289
290
|
monai/networks/layers/vector_quantizer.py,sha256=0PCcaH5_uaxFORHgEetQKazq74jgOVmvQJ3h4Ywat6Y,10058
|
290
291
|
monai/networks/layers/weight_init.py,sha256=ehwI5F7jm_lmDkK4qVL7ocIzCEPx5UPgLaURcsfMNwk,2253
|
291
|
-
monai/networks/nets/__init__.py,sha256=
|
292
|
+
monai/networks/nets/__init__.py,sha256=sEmOdnrwy-eCb6-HEPf9ySFMyEmF0GcdXzERLwM7szA,4152
|
292
293
|
monai/networks/nets/ahnet.py,sha256=RT-loCa5Z_3I2DWB8lmRkhxGXSsnMVBCEDpwo68-YB4,21570
|
293
294
|
monai/networks/nets/attentionunet.py,sha256=lqsrzpy0sRuuFjAtKUUJ0hT3lGF9skpepWXLG0JBo-k,9427
|
294
295
|
monai/networks/nets/autoencoder.py,sha256=QuLdDfDwhefIqA2n8XfmFyi5T8enP6O4PETdBKmFMKc,12586
|
@@ -309,6 +310,7 @@ monai/networks/nets/fullyconnectednet.py,sha256=j5uo68qnYSxgH_sEMRh7s3QGNKFaJAIx
|
|
309
310
|
monai/networks/nets/generator.py,sha256=q20EAl9N7Q56t78JiZaUEkPhYWyD02oqO0yekJCd9x0,6581
|
310
311
|
monai/networks/nets/highresnet.py,sha256=1Mx8lR5K4sRXGWjspDAHaKq0WrX9Q7qz8CcBCKZxIXk,8883
|
311
312
|
monai/networks/nets/hovernet.py,sha256=gQDeDGqCwjJACTPmQLAx9nPRBO_D65F-scx15w3Ho_Q,28645
|
313
|
+
monai/networks/nets/mednext.py,sha256=svsIk0dH7MdNI8Fr7eP2YM8j1IBJ2paF7m_2VWpLOZ4,13258
|
312
314
|
monai/networks/nets/milmodel.py,sha256=aUDgYJG0kS3p4nBW_dF7b4cWwuC31w3KIzmUzXA08HE,9813
|
313
315
|
monai/networks/nets/netadapter.py,sha256=JtcME9pcg8ud4jHKZKM9fE-8leP2PQXgUIfKBdB0wcA,6102
|
314
316
|
monai/networks/nets/patchgan_discriminator.py,sha256=yTT0on0lzlDwSu4B9McMqdxqu5xD7Ws9wCwEkxvJEu0,8620
|
@@ -329,7 +331,7 @@ monai/networks/nets/transformer.py,sha256=-nzl20Z5xdtn7xChOd_cRbbPVoPIFGVfTQw3fI
|
|
329
331
|
monai/networks/nets/unet.py,sha256=t2an-NZ8QRpWal6uh1WpxG1tbekKRDgQtpT7YeXWFvY,13543
|
330
332
|
monai/networks/nets/unetr.py,sha256=G67kjiBMz13MzP4eV8XK-GydSogMwgXaBMFDShF5sB8,8252
|
331
333
|
monai/networks/nets/varautoencoder.py,sha256=Pd9BdXW1iVjmAVCZIc2ElGtSDAWRBaLwEKxLDicyxZI,6282
|
332
|
-
monai/networks/nets/vista3d.py,sha256=
|
334
|
+
monai/networks/nets/vista3d.py,sha256=jsQfEl_EzEmj0LCo8rs9wK9oOqN8Udisn5xZXAu6mRg,43314
|
333
335
|
monai/networks/nets/vit.py,sha256=yEzFFQln5ieknnF8A1_ecB_c0SuOBBnrXPesm_kzVts,5934
|
334
336
|
monai/networks/nets/vitautoenc.py,sha256=vfQBWjTb0k7EY4uC76rmuOCIUUgeBvf_EIXBofCzVHQ,5740
|
335
337
|
monai/networks/nets/vnet.py,sha256=zaJi5kSiTLAuFHThSZfhJvHP6zKh3oBWsTWG-328O_g,10820
|
@@ -392,9 +394,9 @@ monai/transforms/spatial/array.py,sha256=5EKivdPYCP4i4qYUlkK1RpYQFzaU_baYyzgubid
|
|
392
394
|
monai/transforms/spatial/dictionary.py,sha256=t0SvEDSVNFUEw2fK66OVF20sqSzCNxil17HmvsMFBt8,133752
|
393
395
|
monai/transforms/spatial/functional.py,sha256=IwS0witCqbGkyuxzu_R4Ztp90S0pg9hY1irG7feXqig,33886
|
394
396
|
monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
|
395
|
-
monai/transforms/utility/array.py,sha256=
|
397
|
+
monai/transforms/utility/array.py,sha256=Ju5WvLDmujqh2yPbi9iX-qbStbWa5iqUMrbuOwy0x6w,78188
|
396
398
|
monai/transforms/utility/dictionary.py,sha256=N6E230-g2zupG63oCsAXWgkdfZmF---TZbvk7p5FQU8,78079
|
397
|
-
monai/utils/__init__.py,sha256=
|
399
|
+
monai/utils/__init__.py,sha256=2_AIpb1wqGMkmgoZ3r43muFTEsnMTCkPu3LtckipYHg,3793
|
398
400
|
monai/utils/component_store.py,sha256=Fe9jbHgwwBBAeJAw0nI02Ae13v17wlwF6N9uUue8tJg,4525
|
399
401
|
monai/utils/decorators.py,sha256=qhhdmJMjMfZIUM6x_VGUGF7kaq2cBUAam8WymAU_mhw,3156
|
400
402
|
monai/utils/deprecate_utils.py,sha256=gKeEV4MsI51qeQ5gci2me_C-0e-tDwa3VZzd3XPQqLk,14759
|
@@ -402,7 +404,7 @@ monai/utils/dist.py,sha256=7brB42CvdS8Jvr8Y7hfqov1uk6NNnYea9dYfgMYy0BY,8578
|
|
402
404
|
monai/utils/enums.py,sha256=orCV7SGDajYtl3DhTTjbLDbayr6WxkMSw_bZ6yeGGTY,19513
|
403
405
|
monai/utils/jupyter_utils.py,sha256=kQqfLTLAre3TLzXTt091X_XeWy5K0QKAcTuYlJ8BOag,15650
|
404
406
|
monai/utils/misc.py,sha256=R-sCS5u7SA8hX6e7x6WSc8FgLcNpqKFRRDMWxUd2wCo,31759
|
405
|
-
monai/utils/module.py,sha256=
|
407
|
+
monai/utils/module.py,sha256=2G9mgrUhytkIADHWPAH4xWKXgIhknBYzj_RCKZdYHJA,26123
|
406
408
|
monai/utils/nvtx.py,sha256=i9JBxR1uhW1ZCgLPLlTx8b907QlXkFzJyTBLMlFjhtU,6876
|
407
409
|
monai/utils/ordering.py,sha256=0nlA5b5QpVCHbtiCbTC-YsqjTmjm0bub0IeJhGFBOes,8270
|
408
410
|
monai/utils/profiling.py,sha256=V2_cSHgrcmVF48_G3nUi2-O6fnXsS89nSlb8jj58YLo,15937
|
@@ -416,8 +418,8 @@ monai/visualize/img2tensorboard.py,sha256=NnMcyfIFqX-jD7TBO3Rn02zt5uug79d_7pIIaV
|
|
416
418
|
monai/visualize/occlusion_sensitivity.py,sha256=OQHEJLyIhB8zWqQsfKaX-1kvCjWFVYtLfS4dFC0nKFI,18160
|
417
419
|
monai/visualize/utils.py,sha256=B-MhTVs7sQbIqYS3yPnpBwPw2K82rE2PBtGIfpwZtWM,9894
|
418
420
|
monai/visualize/visualizer.py,sha256=qckyaMZCbezYUwE20k5yc-Pb7UozVavMDbrmyQwfYHY,1377
|
419
|
-
monai_weekly-1.5.
|
420
|
-
monai_weekly-1.5.
|
421
|
-
monai_weekly-1.5.
|
422
|
-
monai_weekly-1.5.
|
423
|
-
monai_weekly-1.5.
|
421
|
+
monai_weekly-1.5.dev2446.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
422
|
+
monai_weekly-1.5.dev2446.dist-info/METADATA,sha256=XpCgc-ynPAgTlLREH4IqmxkhrBE9V5PYXPqR1OkqP18,11187
|
423
|
+
monai_weekly-1.5.dev2446.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
|
424
|
+
monai_weekly-1.5.dev2446.dist-info/top_level.txt,sha256=UaNwRzLGORdus41Ip446s3bBfViLkdkDsXDo34J2P44,6
|
425
|
+
monai_weekly-1.5.dev2446.dist-info/RECORD,,
|
File without changes
|
File without changes
|