monai-weekly 1.4.dev2428__py3-none-any.whl → 1.4.dev2430__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/apps/auto3dseg/hpo_gen.py +1 -1
- monai/apps/detection/utils/anchor_utils.py +2 -2
- monai/apps/pathology/transforms/post/array.py +7 -4
- monai/auto3dseg/analyzer.py +1 -1
- monai/bundle/scripts.py +204 -22
- monai/bundle/utils.py +1 -0
- monai/data/dataset_summary.py +1 -0
- monai/data/meta_tensor.py +2 -2
- monai/data/test_time_augmentation.py +2 -0
- monai/data/utils.py +9 -6
- monai/data/wsi_reader.py +2 -2
- monai/engines/__init__.py +3 -1
- monai/engines/trainer.py +281 -2
- monai/engines/utils.py +76 -1
- monai/handlers/mlflow_handler.py +21 -4
- monai/inferers/__init__.py +5 -0
- monai/inferers/inferer.py +1279 -1
- monai/metrics/cumulative_average.py +2 -0
- monai/metrics/panoptic_quality.py +1 -1
- monai/metrics/rocauc.py +2 -2
- monai/networks/blocks/__init__.py +3 -0
- monai/networks/blocks/attention_utils.py +128 -0
- monai/networks/blocks/crossattention.py +168 -0
- monai/networks/blocks/rel_pos_embedding.py +56 -0
- monai/networks/blocks/selfattention.py +74 -5
- monai/networks/blocks/spade_norm.py +95 -0
- monai/networks/blocks/spatialattention.py +82 -0
- monai/networks/blocks/transformerblock.py +25 -4
- monai/networks/blocks/upsample.py +22 -10
- monai/networks/layers/__init__.py +2 -1
- monai/networks/layers/factories.py +12 -1
- monai/networks/layers/simplelayers.py +1 -1
- monai/networks/layers/utils.py +14 -1
- monai/networks/layers/vector_quantizer.py +233 -0
- monai/networks/nets/__init__.py +9 -0
- monai/networks/nets/autoencoderkl.py +702 -0
- monai/networks/nets/controlnet.py +465 -0
- monai/networks/nets/diffusion_model_unet.py +1913 -0
- monai/networks/nets/patchgan_discriminator.py +230 -0
- monai/networks/nets/quicknat.py +8 -6
- monai/networks/nets/resnet.py +3 -4
- monai/networks/nets/spade_autoencoderkl.py +480 -0
- monai/networks/nets/spade_diffusion_model_unet.py +934 -0
- monai/networks/nets/spade_network.py +435 -0
- monai/networks/nets/swin_unetr.py +4 -3
- monai/networks/nets/transformer.py +157 -0
- monai/networks/nets/vqvae.py +472 -0
- monai/networks/schedulers/__init__.py +17 -0
- monai/networks/schedulers/ddim.py +294 -0
- monai/networks/schedulers/ddpm.py +250 -0
- monai/networks/schedulers/pndm.py +316 -0
- monai/networks/schedulers/scheduler.py +205 -0
- monai/networks/utils.py +22 -0
- monai/transforms/croppad/array.py +8 -8
- monai/transforms/croppad/dictionary.py +4 -4
- monai/transforms/croppad/functional.py +1 -1
- monai/transforms/regularization/array.py +4 -0
- monai/transforms/spatial/array.py +1 -1
- monai/transforms/utils_create_transform_ims.py +2 -4
- monai/utils/__init__.py +1 -0
- monai/utils/misc.py +5 -4
- monai/utils/ordering.py +207 -0
- monai/visualize/class_activation_maps.py +5 -5
- monai/visualize/img2tensorboard.py +3 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/METADATA +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/RECORD +71 -50
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/WHEEL +1 -1
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/LICENSE +0 -0
- {monai_weekly-1.4.dev2428.dist-info → monai_weekly-1.4.dev2430.dist-info}/top_level.txt +0 -0
monai/__init__.py
CHANGED
monai/_version.py
CHANGED
@@ -8,11 +8,11 @@ import json
|
|
8
8
|
|
9
9
|
version_json = '''
|
10
10
|
{
|
11
|
-
"date": "2024-07-
|
11
|
+
"date": "2024-07-28T02:19:22+0000",
|
12
12
|
"dirty": false,
|
13
13
|
"error": null,
|
14
|
-
"full-revisionid": "
|
15
|
-
"version": "1.4.
|
14
|
+
"full-revisionid": "9dd92b4a07706d4b80edace3d39fe008dc805d5a",
|
15
|
+
"version": "1.4.dev2430"
|
16
16
|
}
|
17
17
|
''' # END VERSION_JSON
|
18
18
|
|
monai/apps/auto3dseg/hpo_gen.py
CHANGED
@@ -189,7 +189,7 @@ class AnchorGenerator(nn.Module):
|
|
189
189
|
w_ratios = 1 / area_scale
|
190
190
|
h_ratios = area_scale
|
191
191
|
# if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]
|
192
|
-
|
192
|
+
else:
|
193
193
|
area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)
|
194
194
|
w_ratios = 1 / area_scale
|
195
195
|
h_ratios = aspect_ratios_t[:, 0] / area_scale
|
@@ -199,7 +199,7 @@ class AnchorGenerator(nn.Module):
|
|
199
199
|
hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)
|
200
200
|
if self.spatial_dims == 2:
|
201
201
|
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0
|
202
|
-
elif self.spatial_dims == 3:
|
202
|
+
else: # elif self.spatial_dims == 3:
|
203
203
|
ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)
|
204
204
|
base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0
|
205
205
|
|
@@ -28,7 +28,7 @@ from monai.transforms import (
|
|
28
28
|
SobelGradients,
|
29
29
|
)
|
30
30
|
from monai.transforms.transform import Transform
|
31
|
-
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
|
31
|
+
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique, where
|
32
32
|
from monai.utils import TransformBackends, convert_to_numpy, optional_import
|
33
33
|
from monai.utils.misc import ensure_tuple_rep
|
34
34
|
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
|
@@ -162,7 +162,8 @@ class GenerateWatershedMask(Transform):
|
|
162
162
|
pred = label(pred)[0]
|
163
163
|
if self.remove_small_objects is not None:
|
164
164
|
pred = self.remove_small_objects(pred)
|
165
|
-
pred
|
165
|
+
pred_indices = np.where(pred > 0)
|
166
|
+
pred[pred_indices] = 1
|
166
167
|
|
167
168
|
return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0]
|
168
169
|
|
@@ -338,7 +339,8 @@ class GenerateWatershedMarkers(Transform):
|
|
338
339
|
instance_border = instance_border >= self.threshold # uncertain area
|
339
340
|
|
340
341
|
marker = mask - convert_to_dst_type(instance_border, mask)[0] # certain foreground
|
341
|
-
marker
|
342
|
+
marker_indices = where(marker < 0)
|
343
|
+
marker[marker_indices] = 0 # type: ignore[index]
|
342
344
|
marker = self.postprocess_fn(marker)
|
343
345
|
marker = convert_to_numpy(marker)
|
344
346
|
|
@@ -379,6 +381,7 @@ class GenerateSuccinctContour(Transform):
|
|
379
381
|
"""
|
380
382
|
|
381
383
|
p_delta = (current[0] - previous[0], current[1] - previous[1])
|
384
|
+
row, col = -1, -1
|
382
385
|
|
383
386
|
if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)):
|
384
387
|
row = int(current[0] + 0.5)
|
@@ -634,7 +637,7 @@ class GenerateInstanceType(Transform):
|
|
634
637
|
|
635
638
|
seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0]
|
636
639
|
|
637
|
-
inst_type = type_map_crop[seg_map_crop]
|
640
|
+
inst_type = type_map_crop[seg_map_crop] # type: ignore[index]
|
638
641
|
type_list, type_pixels = unique(inst_type, return_counts=True)
|
639
642
|
type_list = list(zip(type_list, type_pixels))
|
640
643
|
type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
|
monai/auto3dseg/analyzer.py
CHANGED
@@ -470,7 +470,7 @@ class LabelStats(Analyzer):
|
|
470
470
|
|
471
471
|
unique_label = unique(ndas_label)
|
472
472
|
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
|
473
|
-
unique_label = unique_label.data.cpu().numpy()
|
473
|
+
unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]
|
474
474
|
|
475
475
|
unique_label = unique_label.astype(np.int16).tolist()
|
476
476
|
|
monai/bundle/scripts.py
CHANGED
@@ -16,6 +16,7 @@ import json
|
|
16
16
|
import os
|
17
17
|
import re
|
18
18
|
import warnings
|
19
|
+
import zipfile
|
19
20
|
from collections.abc import Mapping, Sequence
|
20
21
|
from pathlib import Path
|
21
22
|
from pydoc import locate
|
@@ -26,7 +27,7 @@ from typing import Any, Callable
|
|
26
27
|
import torch
|
27
28
|
from torch.cuda import is_available
|
28
29
|
|
29
|
-
from monai.
|
30
|
+
from monai._version import get_versions
|
30
31
|
from monai.apps.utils import _basename, download_url, extractall, get_logger
|
31
32
|
from monai.bundle.config_item import ConfigComponent
|
32
33
|
from monai.bundle.config_parser import ConfigParser
|
@@ -66,6 +67,9 @@ logger = get_logger(module_name=__name__)
|
|
66
67
|
DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting")
|
67
68
|
PPRINT_CONFIG_N = 5
|
68
69
|
|
70
|
+
MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
|
71
|
+
NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit"
|
72
|
+
|
69
73
|
|
70
74
|
def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
|
71
75
|
"""
|
@@ -168,12 +172,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
|
|
168
172
|
|
169
173
|
|
170
174
|
def _get_ngc_bundle_url(model_name: str, version: str) -> str:
|
171
|
-
return f"
|
175
|
+
return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
|
176
|
+
|
177
|
+
|
178
|
+
def _get_ngc_private_base_url(repo: str) -> str:
|
179
|
+
return f"https://api.ngc.nvidia.com/v2/{repo}/models"
|
180
|
+
|
181
|
+
|
182
|
+
def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
|
183
|
+
return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip"
|
172
184
|
|
173
185
|
|
174
186
|
def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
|
175
|
-
|
176
|
-
return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
|
187
|
+
return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
|
177
188
|
|
178
189
|
|
179
190
|
def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:
|
@@ -219,29 +230,168 @@ def _download_from_ngc(
|
|
219
230
|
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
|
220
231
|
|
221
232
|
|
233
|
+
def _download_from_ngc_private(
|
234
|
+
download_path: Path, filename: str, version: str, remove_prefix: str | None, repo: str, headers: dict | None = None
|
235
|
+
) -> None:
|
236
|
+
# ensure prefix is contained
|
237
|
+
filename = _add_ngc_prefix(filename)
|
238
|
+
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
|
239
|
+
if has_requests:
|
240
|
+
headers = {} if headers is None else headers
|
241
|
+
response = requests_get(request_url, headers=headers)
|
242
|
+
response.raise_for_status()
|
243
|
+
else:
|
244
|
+
raise ValueError("NGC API requires requests package. Please install it.")
|
245
|
+
|
246
|
+
zip_path = download_path / f"{filename}_v{version}.zip"
|
247
|
+
with open(zip_path, "wb") as f:
|
248
|
+
f.write(response.content)
|
249
|
+
logger.info(f"Downloading: {zip_path}.")
|
250
|
+
if remove_prefix:
|
251
|
+
filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
|
252
|
+
extract_path = download_path / f"{filename}"
|
253
|
+
with zipfile.ZipFile(zip_path, "r") as z:
|
254
|
+
z.extractall(extract_path)
|
255
|
+
logger.info(f"Writing into directory: {extract_path}.")
|
256
|
+
|
257
|
+
|
258
|
+
def _get_ngc_token(api_key, retry=0):
|
259
|
+
"""Try to connect to NGC."""
|
260
|
+
url = "https://authn.nvidia.com/token?service=ngc"
|
261
|
+
headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
|
262
|
+
if has_requests:
|
263
|
+
response = requests_get(url, headers=headers)
|
264
|
+
if not response.ok:
|
265
|
+
# retry 3 times, if failed, raise an error.
|
266
|
+
if retry < 3:
|
267
|
+
logger.info(f"Retrying {retry} time(s) to GET {url}.")
|
268
|
+
return _get_ngc_token(url, retry + 1)
|
269
|
+
raise RuntimeError("NGC API response is not ok. Failed to get token.")
|
270
|
+
else:
|
271
|
+
token = response.json()["token"]
|
272
|
+
return token
|
273
|
+
|
274
|
+
|
222
275
|
def _get_latest_bundle_version_monaihosting(name):
|
223
|
-
|
224
|
-
full_url = f"{url}/{name.lower()}"
|
276
|
+
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
|
225
277
|
requests_get, has_requests = optional_import("requests", name="get")
|
226
278
|
if has_requests:
|
227
279
|
resp = requests_get(full_url)
|
228
280
|
resp.raise_for_status()
|
229
281
|
else:
|
230
|
-
raise ValueError("NGC API requires requests package.
|
282
|
+
raise ValueError("NGC API requires requests package. Please install it.")
|
231
283
|
model_info = json.loads(resp.text)
|
232
284
|
return model_info["model"]["latestVersionIdStr"]
|
233
285
|
|
234
286
|
|
235
|
-
def
|
287
|
+
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
|
288
|
+
"""Examine if the package version is compatible with the MONAI version in the metadata."""
|
289
|
+
version_dict = get_versions()
|
290
|
+
package_version = version_dict.get("version", "0+unknown")
|
291
|
+
if package_version == "0+unknown":
|
292
|
+
return False, "Package version is not available. Skipping version check."
|
293
|
+
if monai_version == "0+unknown":
|
294
|
+
return False, "MONAI version is not specified in the bundle. Skipping version check."
|
295
|
+
# treat rc versions as the same as the release version
|
296
|
+
package_version = re.sub(r"rc\d.*", "", package_version)
|
297
|
+
monai_version = re.sub(r"rc\d.*", "", monai_version)
|
298
|
+
if package_version < monai_version:
|
299
|
+
return (
|
300
|
+
False,
|
301
|
+
f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.",
|
302
|
+
)
|
303
|
+
return True, ""
|
304
|
+
|
305
|
+
|
306
|
+
def _check_monai_version(bundle_dir: PathLike, name: str) -> None:
|
307
|
+
"""Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version"""
|
308
|
+
metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json"
|
309
|
+
if not metadata_file.exists():
|
310
|
+
logger.warning(f"metadata file not found in {metadata_file}.")
|
311
|
+
return
|
312
|
+
with open(metadata_file) as f:
|
313
|
+
metadata = json.load(f)
|
314
|
+
is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown"))
|
315
|
+
if not is_compatible:
|
316
|
+
logger.warning(msg)
|
317
|
+
|
318
|
+
|
319
|
+
def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:
|
320
|
+
"""
|
321
|
+
Extract the latest versions from the data dictionary.
|
322
|
+
|
323
|
+
Args:
|
324
|
+
data: the data dictionary.
|
325
|
+
max_versions: the maximum number of versions to return.
|
326
|
+
|
327
|
+
Returns:
|
328
|
+
versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0'].
|
329
|
+
"""
|
330
|
+
# Check if the data is a dictionary and it has the key 'modelVersions'
|
331
|
+
if not isinstance(data, dict) or "modelVersions" not in data:
|
332
|
+
raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.")
|
333
|
+
|
334
|
+
# Extract the list of model versions
|
335
|
+
model_versions = data["modelVersions"]
|
336
|
+
|
337
|
+
if (
|
338
|
+
not isinstance(model_versions, list)
|
339
|
+
or len(model_versions) == 0
|
340
|
+
or "createdDate" not in model_versions[0]
|
341
|
+
or "versionId" not in model_versions[0]
|
342
|
+
):
|
343
|
+
raise ValueError(
|
344
|
+
"The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'."
|
345
|
+
)
|
346
|
+
|
347
|
+
# Sort the versions by the 'createdDate' in descending order
|
348
|
+
sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True)
|
349
|
+
return [v["versionId"] for v in sorted_versions[:max_versions]]
|
350
|
+
|
351
|
+
|
352
|
+
def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:
|
353
|
+
base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL
|
354
|
+
version_endpoint = base_url + f"/{name.lower()}/versions/"
|
355
|
+
|
356
|
+
if not has_requests:
|
357
|
+
raise ValueError("requests package is required, please install it.")
|
358
|
+
|
359
|
+
version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
|
360
|
+
if headers:
|
361
|
+
version_header.update(headers)
|
362
|
+
resp = requests_get(version_endpoint, headers=version_header)
|
363
|
+
resp.raise_for_status()
|
364
|
+
model_info = json.loads(resp.text)
|
365
|
+
latest_versions = _list_latest_versions(model_info)
|
366
|
+
|
367
|
+
for version in latest_versions:
|
368
|
+
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
|
369
|
+
resp = requests_get(file_endpoint, headers=headers)
|
370
|
+
metadata = json.loads(resp.text)
|
371
|
+
resp.raise_for_status()
|
372
|
+
# if the package version is not available or the model is compatible with the package version
|
373
|
+
is_compatible, _ = _examine_monai_version(metadata["monai_version"])
|
374
|
+
if is_compatible:
|
375
|
+
if version != latest_versions[0]:
|
376
|
+
logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.")
|
377
|
+
return version
|
378
|
+
|
379
|
+
# if no compatible version is found, return the latest version
|
380
|
+
return latest_versions[0]
|
381
|
+
|
382
|
+
|
383
|
+
def _get_latest_bundle_version(
|
384
|
+
source: str, name: str, repo: str, **kwargs: Any
|
385
|
+
) -> dict[str, list[str] | str] | Any | None:
|
236
386
|
if source == "ngc":
|
237
387
|
name = _add_ngc_prefix(name)
|
238
|
-
|
239
|
-
for v in model_dict.values():
|
240
|
-
if v["name"] == name:
|
241
|
-
return v["latest"]
|
242
|
-
return None
|
388
|
+
return _get_latest_bundle_version_ngc(name)
|
243
389
|
elif source == "monaihosting":
|
244
390
|
return _get_latest_bundle_version_monaihosting(name)
|
391
|
+
elif source == "ngc_private":
|
392
|
+
headers = kwargs.pop("headers", {})
|
393
|
+
name = _add_ngc_prefix(name)
|
394
|
+
return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)
|
245
395
|
elif source == "github":
|
246
396
|
repo_owner, repo_name, tag_name = repo.split("/")
|
247
397
|
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
|
@@ -308,6 +458,9 @@ def download(
|
|
308
458
|
# Execute this module as a CLI entry, and download bundle via URL:
|
309
459
|
python -m monai.bundle download --name <bundle_name> --url <url>
|
310
460
|
|
461
|
+
# Execute this module as a CLI entry, and download bundle from ngc_private with latest version:
|
462
|
+
python -m monai.bundle download --name <bundle_name> --source "ngc_private" --bundle_dir "./" --repo "org/org_name"
|
463
|
+
|
311
464
|
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
|
312
465
|
# Other args still can override the default args at runtime.
|
313
466
|
# The content of the JSON / YAML file is a dictionary. For example:
|
@@ -328,10 +481,13 @@ def download(
|
|
328
481
|
Default is `bundle` subfolder under `torch.hub.get_dir()`.
|
329
482
|
source: storage location name. This argument is used when `url` is `None`.
|
330
483
|
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
|
331
|
-
it should be "ngc", "monaihosting", "github", or "huggingface_hub".
|
484
|
+
it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub".
|
485
|
+
If source is "ngc_private", you need specify the NGC_API_KEY in the environment variable.
|
332
486
|
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
|
333
487
|
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
|
334
488
|
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
|
489
|
+
If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name",
|
490
|
+
or you can specify the environment variable NGC_ORG and NGC_TEAM.
|
335
491
|
url: url to download the data. If not `None`, data will be downloaded directly
|
336
492
|
and `source` will not be checked.
|
337
493
|
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
|
@@ -363,11 +519,18 @@ def download(
|
|
363
519
|
|
364
520
|
bundle_dir_ = _process_bundle_dir(bundle_dir_)
|
365
521
|
if repo_ is None:
|
366
|
-
|
367
|
-
|
368
|
-
|
522
|
+
org_ = os.getenv("NGC_ORG", None)
|
523
|
+
team_ = os.getenv("NGC_TEAM", None)
|
524
|
+
if org_ is not None and source_ == "ngc_private":
|
525
|
+
repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}"
|
526
|
+
else:
|
527
|
+
repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
|
528
|
+
if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private":
|
529
|
+
raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.")
|
530
|
+
if len(repo_.split("/")) != 3 and source_ == "github":
|
531
|
+
raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.")
|
369
532
|
elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
|
370
|
-
raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name
|
533
|
+
raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.")
|
371
534
|
if url_ is not None:
|
372
535
|
if name_ is not None:
|
373
536
|
filepath = bundle_dir_ / f"{name_}.zip"
|
@@ -376,14 +539,22 @@ def download(
|
|
376
539
|
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
|
377
540
|
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
|
378
541
|
else:
|
542
|
+
headers = {}
|
379
543
|
if name_ is None:
|
380
544
|
raise ValueError(f"To download from source: {source_}, `name` must be provided.")
|
545
|
+
if source == "ngc_private":
|
546
|
+
api_key = os.getenv("NGC_API_KEY", None)
|
547
|
+
if api_key is None:
|
548
|
+
raise ValueError("API key is required for ngc_private source.")
|
549
|
+
else:
|
550
|
+
token = _get_ngc_token(api_key)
|
551
|
+
headers = {"Authorization": f"Bearer {token}"}
|
552
|
+
|
381
553
|
if version_ is None:
|
382
|
-
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
|
554
|
+
version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
|
383
555
|
if source_ == "github":
|
384
|
-
if version_ is not None
|
385
|
-
|
386
|
-
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
|
556
|
+
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
|
557
|
+
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
|
387
558
|
elif source_ == "monaihosting":
|
388
559
|
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
|
389
560
|
elif source_ == "ngc":
|
@@ -394,6 +565,15 @@ def download(
|
|
394
565
|
remove_prefix=remove_prefix_,
|
395
566
|
progress=progress_,
|
396
567
|
)
|
568
|
+
elif source_ == "ngc_private":
|
569
|
+
_download_from_ngc_private(
|
570
|
+
download_path=bundle_dir_,
|
571
|
+
filename=name_,
|
572
|
+
version=version_,
|
573
|
+
remove_prefix=remove_prefix_,
|
574
|
+
repo=repo_,
|
575
|
+
headers=headers,
|
576
|
+
)
|
397
577
|
elif source_ == "huggingface_hub":
|
398
578
|
extract_path = os.path.join(bundle_dir_, name_)
|
399
579
|
huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
|
@@ -403,6 +583,8 @@ def download(
|
|
403
583
|
f"got source: {source_}."
|
404
584
|
)
|
405
585
|
|
586
|
+
_check_monai_version(bundle_dir_, name_)
|
587
|
+
|
406
588
|
|
407
589
|
@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
|
408
590
|
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
|
monai/bundle/utils.py
CHANGED
@@ -221,6 +221,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
|
|
221
221
|
raise ValueError(f"Cannot find config file '{full_cname}'")
|
222
222
|
|
223
223
|
ardata = archive.read(full_cname)
|
224
|
+
cdata = {}
|
224
225
|
|
225
226
|
if full_cname.lower().endswith("json"):
|
226
227
|
cdata = json.loads(ardata, **load_kw_args)
|
monai/data/dataset_summary.py
CHANGED
monai/data/meta_tensor.py
CHANGED
@@ -505,7 +505,7 @@ class MetaTensor(MetaObj, torch.Tensor):
|
|
505
505
|
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
|
506
506
|
return 1 if a is None else int(max(1, len(a) - 1))
|
507
507
|
|
508
|
-
def new_empty(self, size, dtype=None, device=None, requires_grad=False):
|
508
|
+
def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override]
|
509
509
|
"""
|
510
510
|
must be defined for deepcopy to work
|
511
511
|
|
@@ -580,7 +580,7 @@ class MetaTensor(MetaObj, torch.Tensor):
|
|
580
580
|
img.affine = MetaTensor.get_default_affine()
|
581
581
|
return img
|
582
582
|
|
583
|
-
def __repr__(self):
|
583
|
+
def __repr__(self): # type: ignore[override]
|
584
584
|
"""
|
585
585
|
Prints a representation of the tensor.
|
586
586
|
Prepends "meta" to ``torch.Tensor.__repr__``.
|
monai/data/utils.py
CHANGED
@@ -53,10 +53,6 @@ from monai.utils import (
|
|
53
53
|
pytorch_after,
|
54
54
|
)
|
55
55
|
|
56
|
-
if pytorch_after(1, 13):
|
57
|
-
# import private code for reuse purposes, comment in case things break in the future
|
58
|
-
from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map
|
59
|
-
|
60
56
|
pd, _ = optional_import("pandas")
|
61
57
|
DataFrame, _ = optional_import("pandas", name="DataFrame")
|
62
58
|
nib, _ = optional_import("nibabel")
|
@@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
|
|
454
450
|
Collate a sequence of meta tensor into a single batched metatensor. This is called by `collage_meta_tensor`
|
455
451
|
and so should not be used as a collate function directly in dataloaders.
|
456
452
|
"""
|
457
|
-
|
458
|
-
|
453
|
+
if pytorch_after(1, 13):
|
454
|
+
from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
|
455
|
+
|
456
|
+
collated = collate_tensor_fn(batch)
|
457
|
+
else:
|
458
|
+
collated = default_collate(batch)
|
459
|
+
|
459
460
|
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
|
460
461
|
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
|
461
462
|
if common_:
|
@@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence):
|
|
496
497
|
|
497
498
|
if pytorch_after(1, 13):
|
498
499
|
# needs to go here to avoid circular import
|
500
|
+
from torch.utils.data._utils.collate import default_collate_fn_map
|
501
|
+
|
499
502
|
from monai.data.meta_tensor import MetaTensor
|
500
503
|
|
501
504
|
default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
|
monai/data/wsi_reader.py
CHANGED
@@ -1097,8 +1097,8 @@ class TiffFileWSIReader(BaseWSIReader):
|
|
1097
1097
|
):
|
1098
1098
|
unit = wsi.pages[level].tags.get("ResolutionUnit")
|
1099
1099
|
if unit is not None:
|
1100
|
-
unit = str(unit.value)
|
1101
|
-
|
1100
|
+
unit = str(unit.value.name)
|
1101
|
+
if unit is None or len(unit) == 0:
|
1102
1102
|
warnings.warn("The resolution unit is missing. `micrometer` will be used as default.")
|
1103
1103
|
unit = "micrometer"
|
1104
1104
|
|
monai/engines/__init__.py
CHANGED
@@ -12,12 +12,14 @@
|
|
12
12
|
from __future__ import annotations
|
13
13
|
|
14
14
|
from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
|
15
|
-
from .trainer import GanTrainer, SupervisedTrainer, Trainer
|
15
|
+
from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
|
16
16
|
from .utils import (
|
17
|
+
DiffusionPrepareBatch,
|
17
18
|
IterationEvents,
|
18
19
|
PrepareBatch,
|
19
20
|
PrepareBatchDefault,
|
20
21
|
PrepareBatchExtraInput,
|
22
|
+
VPredictionPrepareBatch,
|
21
23
|
default_make_latent,
|
22
24
|
default_metric_cmp_fn,
|
23
25
|
default_prepare_batch,
|