monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. monai/__init__.py +1 -1
  2. monai/_version.py +3 -3
  3. monai/apps/auto3dseg/hpo_gen.py +1 -1
  4. monai/apps/detection/utils/anchor_utils.py +2 -2
  5. monai/apps/pathology/transforms/post/array.py +7 -4
  6. monai/auto3dseg/analyzer.py +1 -1
  7. monai/bundle/scripts.py +204 -22
  8. monai/bundle/utils.py +1 -0
  9. monai/data/dataset_summary.py +1 -0
  10. monai/data/meta_tensor.py +2 -2
  11. monai/data/test_time_augmentation.py +2 -0
  12. monai/data/utils.py +9 -6
  13. monai/data/wsi_reader.py +2 -2
  14. monai/engines/__init__.py +3 -1
  15. monai/engines/trainer.py +281 -2
  16. monai/engines/utils.py +76 -1
  17. monai/handlers/mlflow_handler.py +21 -4
  18. monai/inferers/__init__.py +5 -0
  19. monai/inferers/inferer.py +1279 -1
  20. monai/metrics/cumulative_average.py +2 -0
  21. monai/metrics/panoptic_quality.py +1 -1
  22. monai/metrics/rocauc.py +2 -2
  23. monai/networks/blocks/__init__.py +3 -0
  24. monai/networks/blocks/attention_utils.py +128 -0
  25. monai/networks/blocks/crossattention.py +168 -0
  26. monai/networks/blocks/rel_pos_embedding.py +56 -0
  27. monai/networks/blocks/selfattention.py +74 -5
  28. monai/networks/blocks/spade_norm.py +95 -0
  29. monai/networks/blocks/spatialattention.py +82 -0
  30. monai/networks/blocks/transformerblock.py +25 -4
  31. monai/networks/blocks/upsample.py +22 -10
  32. monai/networks/layers/__init__.py +2 -1
  33. monai/networks/layers/factories.py +12 -1
  34. monai/networks/layers/simplelayers.py +1 -1
  35. monai/networks/layers/utils.py +14 -1
  36. monai/networks/layers/vector_quantizer.py +233 -0
  37. monai/networks/nets/__init__.py +9 -0
  38. monai/networks/nets/autoencoderkl.py +702 -0
  39. monai/networks/nets/controlnet.py +465 -0
  40. monai/networks/nets/diffusion_model_unet.py +1913 -0
  41. monai/networks/nets/patchgan_discriminator.py +230 -0
  42. monai/networks/nets/quicknat.py +8 -6
  43. monai/networks/nets/resnet.py +3 -4
  44. monai/networks/nets/spade_autoencoderkl.py +480 -0
  45. monai/networks/nets/spade_diffusion_model_unet.py +934 -0
  46. monai/networks/nets/spade_network.py +435 -0
  47. monai/networks/nets/swin_unetr.py +4 -3
  48. monai/networks/nets/transformer.py +157 -0
  49. monai/networks/nets/vqvae.py +472 -0
  50. monai/networks/schedulers/__init__.py +17 -0
  51. monai/networks/schedulers/ddim.py +294 -0
  52. monai/networks/schedulers/ddpm.py +250 -0
  53. monai/networks/schedulers/pndm.py +316 -0
  54. monai/networks/schedulers/scheduler.py +205 -0
  55. monai/networks/utils.py +22 -0
  56. monai/transforms/croppad/array.py +8 -8
  57. monai/transforms/croppad/dictionary.py +4 -4
  58. monai/transforms/croppad/functional.py +1 -1
  59. monai/transforms/regularization/array.py +4 -0
  60. monai/transforms/spatial/array.py +1 -1
  61. monai/transforms/utils_create_transform_ims.py +2 -4
  62. monai/utils/__init__.py +1 -0
  63. monai/utils/misc.py +5 -4
  64. monai/utils/ordering.py +207 -0
  65. monai/visualize/class_activation_maps.py +5 -5
  66. monai/visualize/img2tensorboard.py +3 -1
  67. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
  68. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
  69. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
  70. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
  71. {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/__init__.py CHANGED
@@ -93,4 +93,4 @@ except BaseException:
93
93
 
94
94
  if MONAIEnvVars.debug():
95
95
  raise
96
- __commit_id__ = "14b086b553693f5d344ff054f37d12ce6839da06"
96
+ __commit_id__ = "2e53df78e580131046dc8db7f7638063db1f5045"
monai/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-07-14T02:21:44+0000",
11
+ "date": "2024-07-28T02:19:22+0000",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "2d242d8f1b2876133bcafbe7fa5d967728a74998",
15
- "version": "1.4.dev2428"
14
+ "full-revisionid": "9dd92b4a07706d4b80edace3d39fe008dc805d5a",
15
+ "version": "1.4.dev2430"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -53,7 +53,7 @@ class HPOGen(AlgoGen):
53
53
  raise NotImplementedError
54
54
 
55
55
  @abstractmethod
56
- def set_score(self):
56
+ def set_score(self, *args, **kwargs):
57
57
  """Report the evaluated results to HPO."""
58
58
  raise NotImplementedError
59
59
 
@@ -189,7 +189,7 @@ class AnchorGenerator(nn.Module):
189
189
  w_ratios = 1 / area_scale
190
190
  h_ratios = area_scale
191
191
  # if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]
192
- elif self.spatial_dims == 3:
192
+ else:
193
193
  area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)
194
194
  w_ratios = 1 / area_scale
195
195
  h_ratios = aspect_ratios_t[:, 0] / area_scale
@@ -199,7 +199,7 @@ class AnchorGenerator(nn.Module):
199
199
  hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)
200
200
  if self.spatial_dims == 2:
201
201
  base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0
202
- elif self.spatial_dims == 3:
202
+ else: # elif self.spatial_dims == 3:
203
203
  ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)
204
204
  base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0
205
205
 
@@ -28,7 +28,7 @@ from monai.transforms import (
28
28
  SobelGradients,
29
29
  )
30
30
  from monai.transforms.transform import Transform
31
- from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
31
+ from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique, where
32
32
  from monai.utils import TransformBackends, convert_to_numpy, optional_import
33
33
  from monai.utils.misc import ensure_tuple_rep
34
34
  from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
@@ -162,7 +162,8 @@ class GenerateWatershedMask(Transform):
162
162
  pred = label(pred)[0]
163
163
  if self.remove_small_objects is not None:
164
164
  pred = self.remove_small_objects(pred)
165
- pred[pred > 0] = 1
165
+ pred_indices = np.where(pred > 0)
166
+ pred[pred_indices] = 1
166
167
 
167
168
  return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0]
168
169
 
@@ -338,7 +339,8 @@ class GenerateWatershedMarkers(Transform):
338
339
  instance_border = instance_border >= self.threshold # uncertain area
339
340
 
340
341
  marker = mask - convert_to_dst_type(instance_border, mask)[0] # certain foreground
341
- marker[marker < 0] = 0
342
+ marker_indices = where(marker < 0)
343
+ marker[marker_indices] = 0 # type: ignore[index]
342
344
  marker = self.postprocess_fn(marker)
343
345
  marker = convert_to_numpy(marker)
344
346
 
@@ -379,6 +381,7 @@ class GenerateSuccinctContour(Transform):
379
381
  """
380
382
 
381
383
  p_delta = (current[0] - previous[0], current[1] - previous[1])
384
+ row, col = -1, -1
382
385
 
383
386
  if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)):
384
387
  row = int(current[0] + 0.5)
@@ -634,7 +637,7 @@ class GenerateInstanceType(Transform):
634
637
 
635
638
  seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0]
636
639
 
637
- inst_type = type_map_crop[seg_map_crop]
640
+ inst_type = type_map_crop[seg_map_crop] # type: ignore[index]
638
641
  type_list, type_pixels = unique(inst_type, return_counts=True)
639
642
  type_list = list(zip(type_list, type_pixels))
640
643
  type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
@@ -470,7 +470,7 @@ class LabelStats(Analyzer):
470
470
 
471
471
  unique_label = unique(ndas_label)
472
472
  if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
473
- unique_label = unique_label.data.cpu().numpy()
473
+ unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]
474
474
 
475
475
  unique_label = unique_label.astype(np.int16).tolist()
476
476
 
monai/bundle/scripts.py CHANGED
@@ -16,6 +16,7 @@ import json
16
16
  import os
17
17
  import re
18
18
  import warnings
19
+ import zipfile
19
20
  from collections.abc import Mapping, Sequence
20
21
  from pathlib import Path
21
22
  from pydoc import locate
@@ -26,7 +27,7 @@ from typing import Any, Callable
26
27
  import torch
27
28
  from torch.cuda import is_available
28
29
 
29
- from monai.apps.mmars.mmars import _get_all_ngc_models
30
+ from monai._version import get_versions
30
31
  from monai.apps.utils import _basename, download_url, extractall, get_logger
31
32
  from monai.bundle.config_item import ConfigComponent
32
33
  from monai.bundle.config_parser import ConfigParser
@@ -66,6 +67,9 @@ logger = get_logger(module_name=__name__)
66
67
  DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting")
67
68
  PPRINT_CONFIG_N = 5
68
69
 
70
+ MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
71
+ NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit"
72
+
69
73
 
70
74
  def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
71
75
  """
@@ -168,12 +172,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
168
172
 
169
173
 
170
174
  def _get_ngc_bundle_url(model_name: str, version: str) -> str:
171
- return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"
175
+ return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
176
+
177
+
178
+ def _get_ngc_private_base_url(repo: str) -> str:
179
+ return f"https://api.ngc.nvidia.com/v2/{repo}/models"
180
+
181
+
182
+ def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
183
+ return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip"
172
184
 
173
185
 
174
186
  def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
175
- monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
176
- return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
187
+ return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
177
188
 
178
189
 
179
190
  def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:
@@ -219,29 +230,168 @@ def _download_from_ngc(
219
230
  extractall(filepath=filepath, output_dir=extract_path, has_base=True)
220
231
 
221
232
 
233
+ def _download_from_ngc_private(
234
+ download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
235
+ ) -> None:
236
+ # ensure prefix is contained
237
+ filename = _add_ngc_prefix(filename)
238
+ request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
239
+ if has_requests:
240
+ headers = {} if headers is None else headers
241
+ response = requests_get(request_url, headers=headers)
242
+ response.raise_for_status()
243
+ else:
244
+ raise ValueError("NGC API requires requests package. Please install it.")
245
+
246
+ zip_path = download_path / f"{filename}_v{version}.zip"
247
+ with open(zip_path, "wb") as f:
248
+ f.write(response.content)
249
+ logger.info(f"Downloading: {zip_path}.")
250
+ if remove_prefix:
251
+ filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
252
+ extract_path = download_path / f"{filename}"
253
+ with zipfile.ZipFile(zip_path, "r") as z:
254
+ z.extractall(extract_path)
255
+ logger.info(f"Writing into directory: {extract_path}.")
256
+
257
+
258
+ def _get_ngc_token(api_key, retry=0):
259
+ """Try to connect to NGC."""
260
+ url = "https://authn.nvidia.com/token?service=ngc"
261
+ headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
262
+ if has_requests:
263
+ response = requests_get(url, headers=headers)
264
+ if not response.ok:
265
+ # retry 3 times, if failed, raise an error.
266
+ if retry < 3:
267
+ logger.info(f"Retrying {retry} time(s) to GET {url}.")
268
+ return _get_ngc_token(url, retry + 1)
269
+ raise RuntimeError("NGC API response is not ok. Failed to get token.")
270
+ else:
271
+ token = response.json()["token"]
272
+ return token
273
+
274
+
222
275
  def _get_latest_bundle_version_monaihosting(name):
223
- url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
224
- full_url = f"{url}/{name.lower()}"
276
+ full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
225
277
  requests_get, has_requests = optional_import("requests", name="get")
226
278
  if has_requests:
227
279
  resp = requests_get(full_url)
228
280
  resp.raise_for_status()
229
281
  else:
230
- raise ValueError("NGC API requires requests package. Please install it.")
282
+ raise ValueError("NGC API requires requests package. Please install it.")
231
283
  model_info = json.loads(resp.text)
232
284
  return model_info["model"]["latestVersionIdStr"]
233
285
 
234
286
 
235
- def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None:
287
+ def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
288
+ """Examine if the package version is compatible with the MONAI version in the metadata."""
289
+ version_dict = get_versions()
290
+ package_version = version_dict.get("version", "0+unknown")
291
+ if package_version == "0+unknown":
292
+ return False, "Package version is not available. Skipping version check."
293
+ if monai_version == "0+unknown":
294
+ return False, "MONAI version is not specified in the bundle. Skipping version check."
295
+ # treat rc versions as the same as the release version
296
+ package_version = re.sub(r"rc\d.*", "", package_version)
297
+ monai_version = re.sub(r"rc\d.*", "", monai_version)
298
+ if package_version < monai_version:
299
+ return (
300
+ False,
301
+ f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.",
302
+ )
303
+ return True, ""
304
+
305
+
306
+ def _check_monai_version(bundle_dir: PathLike, name: str) -> None:
307
+ """Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version"""
308
+ metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json"
309
+ if not metadata_file.exists():
310
+ logger.warning(f"metadata file not found in {metadata_file}.")
311
+ return
312
+ with open(metadata_file) as f:
313
+ metadata = json.load(f)
314
+ is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown"))
315
+ if not is_compatible:
316
+ logger.warning(msg)
317
+
318
+
319
+ def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:
320
+ """
321
+ Extract the latest versions from the data dictionary.
322
+
323
+ Args:
324
+ data: the data dictionary.
325
+ max_versions: the maximum number of versions to return.
326
+
327
+ Returns:
328
+ versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0'].
329
+ """
330
+ # Check if the data is a dictionary and it has the key 'modelVersions'
331
+ if not isinstance(data, dict) or "modelVersions" not in data:
332
+ raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.")
333
+
334
+ # Extract the list of model versions
335
+ model_versions = data["modelVersions"]
336
+
337
+ if (
338
+ not isinstance(model_versions, list)
339
+ or len(model_versions) == 0
340
+ or "createdDate" not in model_versions[0]
341
+ or "versionId" not in model_versions[0]
342
+ ):
343
+ raise ValueError(
344
+ "The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'."
345
+ )
346
+
347
+ # Sort the versions by the 'createdDate' in descending order
348
+ sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True)
349
+ return [v["versionId"] for v in sorted_versions[:max_versions]]
350
+
351
+
352
+ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:
353
+ base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL
354
+ version_endpoint = base_url + f"/{name.lower()}/versions/"
355
+
356
+ if not has_requests:
357
+ raise ValueError("requests package is required, please install it.")
358
+
359
+ version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
360
+ if headers:
361
+ version_header.update(headers)
362
+ resp = requests_get(version_endpoint, headers=version_header)
363
+ resp.raise_for_status()
364
+ model_info = json.loads(resp.text)
365
+ latest_versions = _list_latest_versions(model_info)
366
+
367
+ for version in latest_versions:
368
+ file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
369
+ resp = requests_get(file_endpoint, headers=headers)
370
+ metadata = json.loads(resp.text)
371
+ resp.raise_for_status()
372
+ # if the package version is not available or the model is compatible with the package version
373
+ is_compatible, _ = _examine_monai_version(metadata["monai_version"])
374
+ if is_compatible:
375
+ if version != latest_versions[0]:
376
+ logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.")
377
+ return version
378
+
379
+ # if no compatible version is found, return the latest version
380
+ return latest_versions[0]
381
+
382
+
383
+ def _get_latest_bundle_version(
384
+ source: str, name: str, repo: str, **kwargs: Any
385
+ ) -> dict[str, list[str] | str] | Any | None:
236
386
  if source == "ngc":
237
387
  name = _add_ngc_prefix(name)
238
- model_dict = _get_all_ngc_models(name)
239
- for v in model_dict.values():
240
- if v["name"] == name:
241
- return v["latest"]
242
- return None
388
+ return _get_latest_bundle_version_ngc(name)
243
389
  elif source == "monaihosting":
244
390
  return _get_latest_bundle_version_monaihosting(name)
391
+ elif source == "ngc_private":
392
+ headers = kwargs.pop("headers", {})
393
+ name = _add_ngc_prefix(name)
394
+ return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)
245
395
  elif source == "github":
246
396
  repo_owner, repo_name, tag_name = repo.split("/")
247
397
  return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
@@ -308,6 +458,9 @@ def download(
308
458
  # Execute this module as a CLI entry, and download bundle via URL:
309
459
  python -m monai.bundle download --name <bundle_name> --url <url>
310
460
 
461
+ # Execute this module as a CLI entry, and download bundle from ngc_private with latest version:
462
+ python -m monai.bundle download --name <bundle_name> --source "ngc_private" --bundle_dir "./" --repo "org/org_name"
463
+
311
464
  # Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
312
465
  # Other args still can override the default args at runtime.
313
466
  # The content of the JSON / YAML file is a dictionary. For example:
@@ -328,10 +481,13 @@ def download(
328
481
  Default is `bundle` subfolder under `torch.hub.get_dir()`.
329
482
  source: storage location name. This argument is used when `url` is `None`.
330
483
  In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
331
- it should be "ngc", "monaihosting", "github", or "huggingface_hub".
484
+ it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub".
485
+ If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable.
332
486
  repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
333
487
  If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
334
488
  If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
489
+ If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name",
490
+ or you can specify the environment variable NGC_ORG and NGC_TEAM.
335
491
  url: url to download the data. If not `None`, data will be downloaded directly
336
492
  and `source` will not be checked.
337
493
  If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
@@ -363,11 +519,18 @@ def download(
363
519
 
364
520
  bundle_dir_ = _process_bundle_dir(bundle_dir_)
365
521
  if repo_ is None:
366
- repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
367
- if len(repo_.split("/")) != 3 and source_ != "huggingface_hub":
368
- raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")
522
+ org_ = os.getenv("NGC_ORG", None)
523
+ team_ = os.getenv("NGC_TEAM", None)
524
+ if org_ is not None and source_ == "ngc_private":
525
+ repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}"
526
+ else:
527
+ repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
528
+ if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private":
529
+ raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.")
530
+ if len(repo_.split("/")) != 3 and source_ == "github":
531
+ raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.")
369
532
  elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
370
- raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`")
533
+ raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.")
371
534
  if url_ is not None:
372
535
  if name_ is not None:
373
536
  filepath = bundle_dir_ / f"{name_}.zip"
@@ -376,14 +539,22 @@ def download(
376
539
  download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
377
540
  extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
378
541
  else:
542
+ headers = {}
379
543
  if name_ is None:
380
544
  raise ValueError(f"To download from source: {source_}, `name` must be provided.")
545
+ if source == "ngc_private":
546
+ api_key = os.getenv("NGC_API_KEY", None)
547
+ if api_key is None:
548
+ raise ValueError("API key is required for ngc_private source.")
549
+ else:
550
+ token = _get_ngc_token(api_key)
551
+ headers = {"Authorization": f"Bearer {token}"}
552
+
381
553
  if version_ is None:
382
- version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
554
+ version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
383
555
  if source_ == "github":
384
- if version_ is not None:
385
- name_ = "_v".join([name_, version_])
386
- _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
556
+ name_ver = "_v".join([name_, version_]) if version_ is not None else name_
557
+ _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
387
558
  elif source_ == "monaihosting":
388
559
  _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
389
560
  elif source_ == "ngc":
@@ -394,6 +565,15 @@ def download(
394
565
  remove_prefix=remove_prefix_,
395
566
  progress=progress_,
396
567
  )
568
+ elif source_ == "ngc_private":
569
+ _download_from_ngc_private(
570
+ download_path=bundle_dir_,
571
+ filename=name_,
572
+ version=version_,
573
+ remove_prefix=remove_prefix_,
574
+ repo=repo_,
575
+ headers=headers,
576
+ )
397
577
  elif source_ == "huggingface_hub":
398
578
  extract_path = os.path.join(bundle_dir_, name_)
399
579
  huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
@@ -403,6 +583,8 @@ def download(
403
583
  f"got source: {source_}."
404
584
  )
405
585
 
586
+ _check_monai_version(bundle_dir_, name_)
587
+
406
588
 
407
589
  @deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
408
590
  @deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
monai/bundle/utils.py CHANGED
@@ -221,6 +221,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
221
221
  raise ValueError(f"Cannot find config file '{full_cname}'")
222
222
 
223
223
  ardata = archive.read(full_cname)
224
+ cdata = {}
224
225
 
225
226
  if full_cname.lower().endswith("json"):
226
227
  cdata = json.loads(ardata, **load_kw_args)
@@ -84,6 +84,7 @@ class DatasetSummary:
84
84
  """
85
85
 
86
86
  for data in self.data_loader:
87
+ meta_dict = {}
87
88
  if isinstance(data[self.image_key], MetaTensor):
88
89
  meta_dict = data[self.image_key].meta
89
90
  elif self.meta_key in data:
monai/data/meta_tensor.py CHANGED
@@ -505,7 +505,7 @@ class MetaTensor(MetaObj, torch.Tensor):
505
505
  a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
506
506
  return 1 if a is None else int(max(1, len(a) - 1))
507
507
 
508
- def new_empty(self, size, dtype=None, device=None, requires_grad=False):
508
+ def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override]
509
509
  """
510
510
  must be defined for deepcopy to work
511
511
 
@@ -580,7 +580,7 @@ class MetaTensor(MetaObj, torch.Tensor):
580
580
  img.affine = MetaTensor.get_default_affine()
581
581
  return img
582
582
 
583
- def __repr__(self):
583
+ def __repr__(self): # type: ignore[override]
584
584
  """
585
585
  Prints a representation of the tensor.
586
586
  Prepends "meta" to ``torch.Tensor.__repr__``.
@@ -106,6 +106,8 @@ class TestTimeAugmentation:
106
106
  mode, mean, std, vvc = tt_aug(test_data)
107
107
  """
108
108
 
109
+ __test__ = False # indicate to pytest that this class is not intended for collection
110
+
109
111
  def __init__(
110
112
  self,
111
113
  transform: InvertibleTransform,
monai/data/utils.py CHANGED
@@ -53,10 +53,6 @@ from monai.utils import (
53
53
  pytorch_after,
54
54
  )
55
55
 
56
- if pytorch_after(1, 13):
57
- # import private code for reuse purposes, comment in case things break in the future
58
- from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map
59
-
60
56
  pd, _ = optional_import("pandas")
61
57
  DataFrame, _ = optional_import("pandas", name="DataFrame")
62
58
  nib, _ = optional_import("nibabel")
@@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
454
450
  Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
455
451
  and so should not be used as a collate function directly in dataloaders.
456
452
  """
457
- collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate
458
- collated = collate_fn(batch) # type: ignore
453
+ if pytorch_after(1, 13):
454
+ from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
455
+
456
+ collated = collate_tensor_fn(batch)
457
+ else:
458
+ collated = default_collate(batch)
459
+
459
460
  meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
460
461
  common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
461
462
  if common_:
@@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence):
496
497
 
497
498
  if pytorch_after(1, 13):
498
499
  # needs to go here to avoid circular import
500
+ from torch.utils.data._utils.collate import default_collate_fn_map
501
+
499
502
  from monai.data.meta_tensor import MetaTensor
500
503
 
501
504
  default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
monai/data/wsi_reader.py CHANGED
@@ -1097,8 +1097,8 @@ class TiffFileWSIReader(BaseWSIReader):
1097
1097
  ):
1098
1098
  unit = wsi.pages[level].tags.get("ResolutionUnit")
1099
1099
  if unit is not None:
1100
- unit = str(unit.value)[8:]
1101
- else:
1100
+ unit = str(unit.value.name)
1101
+ if unit is None or len(unit) == 0:
1102
1102
  warnings.warn("The resolution unit is missing. `micrometer` will be used as default.")
1103
1103
  unit = "micrometer"
1104
1104
 
monai/engines/__init__.py CHANGED
@@ -12,12 +12,14 @@
12
12
  from __future__ import annotations
13
13
 
14
14
  from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
15
- from .trainer import GanTrainer, SupervisedTrainer, Trainer
15
+ from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
16
16
  from .utils import (
17
+ DiffusionPrepareBatch,
17
18
  IterationEvents,
18
19
  PrepareBatch,
19
20
  PrepareBatchDefault,
20
21
  PrepareBatchExtraInput,
22
+ VPredictionPrepareBatch,
21
23
  default_make_latent,
22
24
  default_metric_cmp_fn,
23
25
  default_prepare_batch,