xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/utils/experiments.py
@@ -0,0 +1,802 @@
+ """Functions for managing experiments."""
+
+ import contextlib
+ import datetime
+ import enum
+ import functools
+ import hashlib
+ import inspect
+ import itertools
+ import logging
+ import math
+ import os
+ import random
+ import re
+ import shutil
+ import sys
+ import tempfile
+ import textwrap
+ import time
+ import traceback
+ import urllib.error
+ import urllib.request
+ import warnings
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import Any, Iterator, TypeVar, cast
+ from urllib.parse import urlparse
+
+ import git
+ import requests
+ from jaxtyping import Array
+ from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf
+
+ from xax.core.conf import get_data_dir, get_pretrained_models_dir, load_user_config
+ from xax.core.state import State
+ from xax.utils.text import colored
+
+ logger = logging.getLogger(__name__)
+
+ # Date format for staging environments.
+ DATE_FORMAT = "%Y-%m-%d"
+
+ USER_AGENT = "xax"
+
+ T = TypeVar("T")
+
+
+ class CumulativeTimer:
+     """Defines a simple timer to track an average value."""
+
+     def __init__(self) -> None:
+         self.steps = 0
+         self.elapsed_time = 0.0
+
+     @functools.cached_property
+     def start_time(self) -> float:
+         return time.time()
+
+     def step(self, steps: int, cur_time: float) -> None:
+         if steps != self.steps:
+             self.steps = steps
+             self.elapsed_time = cur_time - self.start_time
+
+     @property
+     def steps_per_second(self) -> float:
+         return 0.0 if self.elapsed_time < 1e-4 else self.steps / self.elapsed_time
+
+     @property
+     def steps_per_hour(self) -> float:
+         return self.steps_per_second * 60 * 60
+
+     @property
+     def seconds_per_step(self) -> float:
+         return 0.0 if self.steps <= 0 else self.elapsed_time / self.steps
+
+     @property
+     def hours_per_step(self) -> float:
+         return self.seconds_per_step / (60 * 60)
+
+
+ class IterationTimer:
+     """Defines a simple timer to track consecutive values."""
+
+     def __init__(self) -> None:
+         self.iteration_time = 0.0
+         self.last_time = time.time()
+
+     def step(self, cur_time: float) -> None:
+         self.iteration_time = cur_time - self.last_time
+         self.last_time = cur_time
+
+     @property
+     def iter_seconds(self) -> float:
+         return self.iteration_time
+
+     @property
+     def iter_hours(self) -> float:
+         return self.iter_seconds / (60 * 60)
+
+
+ class StateTimer:
+     """Defines a timer for all state information."""
+
+     def __init__(self) -> None:
+         self.step_timer = CumulativeTimer()
+         self.sample_timer = CumulativeTimer()
+         self.iter_timer = IterationTimer()
+
+     def step(self, state: State) -> None:
+         cur_time = time.time()
+         self.step_timer.step(state.num_steps, cur_time)
+         self.sample_timer.step(state.num_samples, cur_time)
+         self.iter_timer.step(cur_time)
+
+     def log_dict(self) -> dict[str, dict[str, int | float]]:
+         logs: dict[str, dict[str, int | float]] = {}
+
+         # Logs step statistics.
+         logs["⏰ steps"] = {
+             "total": self.step_timer.steps,
+             "per-second": self.step_timer.steps_per_second,
+         }
+
+         # Logs sample statistics.
+         logs["⏰ samples"] = {
+             "total": self.sample_timer.steps,
+             "per-second": self.sample_timer.steps_per_second,
+         }
+
+         # Logs full iteration statistics.
+         logs["🔧 dt"] = {
+             "iter": self.iter_timer.iter_seconds,
+         }
+
+         return logs
+
+
+ class IntervalTicker:
+     def __init__(self, interval: float) -> None:
+         self.interval = interval
+         self.last_tick_time: float | None = None
+
+     def tick(self, elapsed_time: float) -> bool:
+         if self.last_tick_time is None or elapsed_time - self.last_tick_time > self.interval:
+             self.last_tick_time = elapsed_time
+             return True
+         return False
+
+
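To make the division of labor concrete, here is a minimal sketch of driving a `StateTimer` from a training loop. The `State` constructor arguments are assumptions for illustration (the class comes from `xax/core/state.py`, also added in this release); `StateTimer.step` only reads its `num_steps` and `num_samples` fields:

```python
import time

from xax.core.state import State
from xax.utils.experiments import StateTimer

timer = StateTimer()

for step in range(1, 4):
    time.sleep(0.01)  # Stand-in for a real training step.
    # Hypothetical State fields; StateTimer.step only reads
    # state.num_steps and state.num_samples.
    state = State(num_steps=step, num_samples=step * 32)
    timer.step(state)

# Nested dict of totals and rates, keyed by "⏰ steps", "⏰ samples", "🔧 dt".
print(timer.log_dict())
```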
+ def abs_path(path: str) -> str:
+     return str(Path(path).resolve())
+
+
+ OmegaConf.register_new_resolver("ml.abs_path", abs_path, replace=True)
+
+
+ def cpu_count(default: int) -> int:
+     if (cpu_count := os.cpu_count()) is not None:
+         return cpu_count
+     return default
+
+
+ OmegaConf.register_new_resolver("ml.cpu_count", cpu_count, replace=True)
+
+
+ def date_str(_: str) -> str:
+     return time.strftime("%Y-%m-%d")
+
+
+ OmegaConf.register_new_resolver("ml.date_str", date_str, replace=True)
+
+
+ def get_random_port(default: int = 1337) -> int:
+     try:
+         return (hash(time.time()) + random.randint(0, 100000)) % (65_535 - 10_000) + 10_000
+     except Exception:
+         return default
+
+
+ OmegaConf.register_new_resolver("xax.get_random_port", get_random_port, replace=True)
+
+
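Importing this module registers the resolvers above, after which any OmegaConf config can reference them. A minimal sketch with invented config keys:

```python
from omegaconf import OmegaConf

import xax.utils.experiments  # noqa: F401  # Importing registers the resolvers.

cfg = OmegaConf.create(
    {
        "workers": "${ml.cpu_count:4}",  # Falls back to 4 if os.cpu_count() is None.
        "date": "${ml.date_str:unused}",  # Today's date, formatted as YYYY-MM-DD.
        "port": "${xax.get_random_port:1337}",  # Pseudo-random port in [10000, 65535).
    }
)
print(OmegaConf.to_container(cfg, resolve=True))
```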
+ class NaNError(Exception):
+     """Raised when NaNs are detected in the model parameters."""
+
+
+ class TrainingFinishedError(Exception):
+     """Raised when training is finished."""
+
+
+ class MinGradScaleError(TrainingFinishedError):
+     """Raised when the minimum gradient scale is reached.
+
+     This is a subclass of :class:`TrainingFinishedError` because it indicates
+     that training is finished and causes the post-training hooks to be run.
+     """
+
+
+ def diff_configs(
+     first: ListConfig | DictConfig,
+     second: ListConfig | DictConfig,
+     prefix: str | None = None,
+ ) -> tuple[list[str], list[str]]:
+     """Returns the difference between two configs.
+
+     Args:
+         first: The first (original) config
+         second: The second (new) config
+         prefix: The prefix to check (used for recursion, not the main call)
+
+     Returns:
+         Two lists of lines describing the diff between the two configs
+     """
+
+     def get_diff_string(prefix: str | None, val: Any) -> str:  # noqa: ANN401
+         if isinstance(val, (str, float, int)):
+             return f"{prefix}={val}"
+         return f"{prefix}= ... ({type(val)})"
+
+     def cast_enums(k: Any) -> Any:  # noqa: ANN401
+         return k.name if isinstance(k, enum.Enum) else k
+
+     new_first: list[str] = []
+     new_second: list[str] = []
+
+     any_config = (ListConfig, DictConfig)
+
+     if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+         first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))
+
+         # Gets the new keys in each config.
+         new_first += [key if prefix is None else f"{prefix}.{key}" for key in first_keys.difference(second_keys)]
+         new_second += [key if prefix is None else f"{prefix}.{key}" for key in second_keys.difference(first_keys)]
+
+         # Gets the new sub-keys in each config.
+         for key in first_keys.intersection(second_keys):
+             sub_prefix = key if prefix is None else f"{prefix}.{key}"
+             if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
+                 if not OmegaConf.is_missing(first, key):
+                     new_first += [get_diff_string(sub_prefix, first[key])]
+                 if not OmegaConf.is_missing(second, key):
+                     new_second += [get_diff_string(sub_prefix, second[key])]
+             elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
+                 sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
+                 new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
+             elif cast_enums(first[key]) != cast_enums(second[key]):
+                 first_val, second_val = first[key], second[key]
+                 new_first += [get_diff_string(sub_prefix, first_val)]
+                 new_second += [get_diff_string(sub_prefix, second_val)]
+
+     elif isinstance(first, ListConfig) and isinstance(second, ListConfig):
+         if len(first) > len(second):
+             for i in range(len(second), len(first)):
+                 new_first += [get_diff_string(prefix, first[i])]
+         elif len(second) > len(first):
+             for i in range(len(first), len(second)):
+                 new_second += [get_diff_string(prefix, second[i])]
+
+         for i in range(min(len(first), len(second))):
+             sub_prefix = str(i) if prefix is None else f"{prefix}.{i}"
+             if isinstance(first[i], any_config) and isinstance(second[i], any_config):
+                 sub_new_first, sub_new_second = diff_configs(first[i], second[i], prefix=sub_prefix)
+                 new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
+     else:
+         new_first += [get_diff_string(prefix, first)]
+         new_second += [get_diff_string(prefix, second)]
+
+     return new_first, new_second
+
+
+ def get_diff_string(config_diff: tuple[list[str], list[str]]) -> str | None:
+     added_keys, deleted_keys = config_diff
+     if not added_keys and not deleted_keys:
+         return None
+     change_lines: list[str] = []
+     change_lines += [f" ↪ {colored('+', 'green')} {added_key}" for added_key in added_keys]
+     change_lines += [f" ↪ {colored('-', 'red')} {deleted_key}" for deleted_key in deleted_keys]
+     change_summary = "\n".join(change_lines)
+     return change_summary
+
+
+ def save_config(config_path: Path, raw_config: DictConfig) -> None:
+     if config_path.exists():
+         config_diff = diff_configs(raw_config, cast(DictConfig, OmegaConf.load(config_path)))
+         diff_string = get_diff_string(config_diff)
+         if diff_string is not None:
+             logger.warning("Overwriting config %s:\n%s", config_path, diff_string)
+             OmegaConf.save(raw_config, config_path)
+     else:
+         config_path.parent.mkdir(exist_ok=True, parents=True)
+         OmegaConf.save(raw_config, config_path)
+         logger.info("Saved config to %s", config_path)
+
+
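A small sketch of how the two helpers compose (the configs are invented):

```python
from omegaconf import OmegaConf

from xax.utils.experiments import diff_configs, get_diff_string

first = OmegaConf.create({"model": {"hidden_size": 256}, "lr": 1e-3})
second = OmegaConf.create({"model": {"hidden_size": 512}, "seed": 42})

# Entries prefixed with a green "+" come from `first`, red "-" from `second`;
# keys present in both but with different values show up in both lists.
print(get_diff_string(diff_configs(first, second)))
```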
+ def to_markdown_table(config: DictConfig) -> str:
+     """Converts a config to a markdown table string.
+
+     Args:
+         config: The config to convert to a table.
+
+     Returns:
+         The config, formatted as a Markdown string.
+     """
+
+     def format_as_string(value: Any) -> str:  # noqa: ANN401
+         if isinstance(value, str):
+             return value
+         if isinstance(value, Array):
+             value = value.item()
+         if isinstance(value, bool):  # Checked before int, since bool subclasses int.
+             return "true" if value else "false"
+         if isinstance(value, (int, float)):
+             return f"{value:.4g}"
+         if isinstance(value, datetime.datetime):
+             return value.isoformat()
+         if isinstance(value, datetime.timedelta):
+             return f"{value.total_seconds():.4g}s"
+         if value is None:
+             return ""
+         if value is MISSING:
+             return ""
+         return str(value)
+
+     def iter_flat(config: dict) -> Iterator[tuple[list[str | None], str]]:
+         for key, value in reversed(config.items()):
+             if isinstance(value, dict):
+                 is_first = True
+                 for sub_key_list, sub_value in iter_flat(value):
+                     yield [format_as_string(key) if is_first else None] + sub_key_list, sub_value
+                     is_first = False
+             elif isinstance(value, (list, tuple)):
+                 is_first = True
+                 for i, sub_value in enumerate(value):
+                     for sub_key_list, sub_sub_value in iter_flat({f"{i}": sub_value}):
+                         yield [format_as_string(key) if is_first else None] + sub_key_list, sub_sub_value
+                         is_first = False
+             else:
+                 yield [format_as_string(key)], format_as_string(value)
+
+     config_dict = cast(dict, OmegaConf.to_container(config, resolve=True, throw_on_missing=False, enum_to_str=True))
+     config_flat = list(iter_flat(config_dict))
+
+     # Gets rows of strings.
+     rows: list[list[str]] = []
+     for key_list, value in config_flat:
+         row = ["" if key is None else key for key in key_list] + [value]
+         rows.append(row)
+
+     # Pads all rows to the same length.
+     max_len = max(len(row) for row in rows)
+     rows = [row[:-1] + [""] * (max_len - len(row)) + row[-1:] for row in rows]
+
+     # Converts to a markdown table.
+     header_str = "| " + " | ".join([f"key_{i}" for i in range(max_len - 1)]) + " | value |"
+     header_sep_str = "|-" + "-|-" * (max_len - 1) + "-|"
+     rows_str = "\n".join(["| " + " | ".join(row) + " |" for row in rows])
+     return "\n".join([header_str, header_sep_str, rows_str])
+
+
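As an illustration, a nested config flattens to one row per leaf value, with key columns padded to a common depth (the config values here are invented):

```python
from omegaconf import OmegaConf

from xax.utils.experiments import to_markdown_table

config = OmegaConf.create({"model": {"hidden_size": 256, "num_layers": 4}, "lr": 1e-3})
print(to_markdown_table(config))
# | key_0 | key_1 | value |
# |--|--|--|
# | lr |  | 0.001 |
# | model | num_layers | 4 |
# |  | hidden_size | 256 |
```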
+ def stage_environment(obj: object, root: Path) -> None:
+     """Stages the current task to a staging directory.
+
+     Args:
+         obj: The object with the module to stage.
+         root: The root directory to stage to.
+     """
+     root.mkdir(exist_ok=True, parents=True)
+
+     # Gets the path to the root module. This is done heuristically, so it may
+     # not work in all cases, but it should generally work.
+     if (mod := inspect.getmodule(obj.__class__)) is None:
+         raise RuntimeError(f"Could not find module for task {obj.__class__}!")
+     if (spec := mod.__spec__) is None:
+         raise RuntimeError(f"Could not find spec for module {mod}!")
+     if spec.origin is None:
+         raise RuntimeError(f"Could not find origin for spec {spec}!")
+     root_mod = spec.name.split(".", 1)[0]
+     path_parts = Path(spec.origin).parts[:-1]
+     if root_mod not in path_parts:
+         raise RuntimeError(f"Could not find root module {root_mod} in path {path_parts}!")
+     root_path = Path(*path_parts[: path_parts.index(root_mod) + 1])
+
+     # Gets files to stage.
+     fpaths: set[tuple[Path, Path]] = set()
+     for module in sys.modules.values():
+         if (fpath_str := getattr(module, "__file__", None)) is None:
+             continue
+         fpath = Path(fpath_str).resolve()
+         try:
+             rel_fpath = fpath.relative_to(root_path)
+             fpaths.add((fpath, rel_fpath))
+         except ValueError:
+             pass
+
+     # Computes the hash of all files and returns early if it matches the previous hash.
+     hashobj = hashlib.md5()
+     for fpath, _ in fpaths:
+         with open(fpath, "rb") as f:
+             while data := f.read(65536):
+                 hashobj.update(data)
+     hashval = hashobj.hexdigest()
+     prev_hashval: str | None = None
+     hash_file = root / ".hash"
+     if hash_file.exists():
+         prev_hashval = hash_file.read_text().strip()
+     if prev_hashval == hashval:
+         return
+     hash_file.write_text(hashval)
+
+     # Copies all files to the staging directory.
+     if (root / root_mod).exists():
+         shutil.rmtree(root / root_mod, ignore_errors=True)
+     for fpath, rel_fpath in fpaths:
+         new_fpath = root / root_mod / rel_fpath
+         new_fpath.parent.mkdir(exist_ok=True, parents=True)
+         shutil.copyfile(fpath, new_fpath)
+
+
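A usage sketch, with `my_task` standing in for a hypothetical task object defined somewhere inside your project's package (the function resolves the package root from the object's module):

```python
from pathlib import Path

from xax.utils.experiments import stage_environment

# Copies every imported module file under the task's root package into
# ./staging/<package_name>, skipping the copy when the MD5 over all staged
# files matches the value recorded in ./staging/.hash.
stage_environment(my_task, Path("staging"))  # `my_task` is hypothetical.
```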
+ def get_git_state(obj: object) -> str:
+     """Gets the state of the Git repo that an object is in as a string.
+
+     Args:
+         obj: The object which is in the target Git repo.
+
+     Returns:
+         A nicely-formatted string showing the current task's Git state.
+     """
+     try:
+         task_file = inspect.getfile(type(obj))
+         repo = git.Repo(task_file, search_parent_directories=True)
+         branch = repo.active_branch
+         commit = repo.head.commit
+         status = textwrap.indent(str(repo.git.status()), " ")
+         diff = textwrap.indent(str(repo.git.diff(color=False)), " ")
+         return "\n".join(
+             [
+                 f"Path: {task_file}",
+                 f"Branch: {branch}",
+                 f"Commit: {commit}",
+                 "Status:",
+                 status,
+                 "Diff:",
+                 diff,
+             ]
+         )
+
+     except Exception:
+         return traceback.format_exc()
+
+
+ def get_training_code(obj: object) -> str:
+     """Gets the text from the file containing the provided object.
+
+     Args:
+         obj: The object to get the file from.
+
+     Returns:
+         The text from the file containing the object.
+     """
+     try:
+         task_file = inspect.getfile(type(obj))
+         with open(task_file, "r") as f:
+             return f.read()
+     except Exception:
+         return traceback.format_exc()
+
+
+ def check_md5(file_path: str | Path, hash_str: str | None, chunk_size: int = 2**16) -> bool:
+     """Checks the MD5 of the downloaded file.
+
+     Args:
+         file_path: Path to the downloaded file.
+         hash_str: Expected MD5 of the file; if None, return True.
+         chunk_size: Size of the chunks to read from the file.
+
+     Returns:
+         True if the MD5 matches, False otherwise.
+     """
+     if hash_str is None:
+         return True
+
+     md5 = hashlib.md5()
+
+     with open(file_path, "rb") as f:
+         for chunk in iter(lambda: f.read(chunk_size), b""):
+             md5.update(chunk)
+
+     return md5.hexdigest() == hash_str
+
+
+ def check_sha256(file_path: str | Path, hash_str: str | None, chunk_size: int = 2**16) -> bool:
+     """Checks the SHA256 of the downloaded file.
+
+     Args:
+         file_path: Path to the downloaded file.
+         hash_str: Expected SHA256 of the file; if None, return True.
+         chunk_size: Size of the chunks to read from the file.
+
+     Returns:
+         True if the SHA256 matches, False otherwise.
+     """
+     if hash_str is None:
+         return True
+
+     sha256 = hashlib.sha256()
+
+     with open(file_path, "rb") as f:
+         for chunk in iter(lambda: f.read(chunk_size), b""):
+             sha256.update(chunk)
+
+     return sha256.hexdigest() == hash_str
+
+
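Both helpers stream the file in fixed-size chunks, so they stay memory-safe for large downloads. A self-contained sketch (the file name is invented):

```python
import hashlib
from pathlib import Path

from xax.utils.experiments import check_md5, check_sha256

path = Path("example.bin")  # Hypothetical file.
path.write_bytes(b"hello world")

assert check_sha256(path, hashlib.sha256(b"hello world").hexdigest())
assert check_md5(path, None)  # A None expected hash always passes.
```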
+ class BaseFileDownloader(ABC):
+     """Provides a simple interface for downloading URLs.
+
+     This class is meant to be subclassed to provide different download
+     locations. For example, when downloading pretrained models, use the
+     :class:`ModelDownloader` class.
+
+     Typically, you should simply use the :func:`ensure_downloaded` function
+     to make sure the file is downloaded to the correct location.
+
+     This is adapted in large part from the reference implementation in the
+     Torchvision library.
+
+     Parameters:
+         url: The URL to download from.
+         dnames: The directory names to download to.
+         md5: The expected MD5 of the file.
+         sha256: The expected SHA256 of the file.
+         is_tmp: Whether to download to a temporary directory.
+         recheck_hash: Whether to recheck the hash after downloading.
+         max_redirect_hops: The maximum number of redirects to follow.
+     """
+
+     def __init__(
+         self,
+         url: str,
+         *dnames: str,
+         md5: str | None = None,
+         sha256: str | None = None,
+         is_tmp: bool = False,
+         recheck_hash: bool = False,
+         max_redirect_hops: int = 3,
+     ) -> None:
+         super().__init__()
+
+         assert len(dnames) >= 1, "Must provide at least 1 directory name"
+         filepath = Path(tempfile.mkdtemp("models")) if is_tmp else self.get_root_directory()
+         for dname in dnames:
+             filepath = filepath / dname
+         (root := filepath.parent).mkdir(parents=True, exist_ok=True)
+
+         self.url = url
+         self.filename = filepath.name
+         self.root = root
+         self.md5 = md5
+         self.sha256 = sha256
+         self.recheck_hash = recheck_hash
+         self.max_redirect_hops = max_redirect_hops
+
+     @abstractmethod
+     def get_root_directory(self) -> Path: ...
+
+     @property
+     def filepath(self) -> Path:
+         return self.root / self.filename
+
+     @property
+     def is_downloaded(self) -> bool:
+         if not self.filepath.exists():
+             return False
+         if self.recheck_hash and not self.check_hashes():
+             logger.warning("A file was found for %s in %s, but its hashes do not match.", self.url, self.filepath)
+             self.filepath.unlink()
+             return False
+         return True
+
+     def check_hashes(self) -> bool:
+         return check_sha256(self.filepath, self.sha256) and check_md5(self.filepath, self.md5)
+
+     def ensure_downloaded(self) -> Path:
+         """Ensures the file is downloaded and returns the path to it.
+
+         By default, we only check the hash once when the file is downloaded,
+         and we don't bother rechecking unless ``recheck_hash`` is set to True.
+
+         Returns:
+             The path to the downloaded file.
+         """
+         if not self.is_downloaded:
+             self.download()
+             if not self.check_hashes():
+                 self.filepath.unlink()
+                 raise RuntimeError(f"Hashes for {self.filepath} do not match. The corrupted file has been deleted.")
+         return self.filepath
+
+     def download(self) -> None:
+         root = self.root.expanduser()
+         root.mkdir(parents=True, exist_ok=True)
+
+         # Expands the redirect chain if needed.
+         url = self._get_redirect_url(self.url, max_hops=self.max_redirect_hops)
+
+         # Checks if file is located on Google Drive.
+         file_id = self._get_google_drive_file_id(url)
+         if file_id is not None:
+             return self.download_file_from_google_drive(file_id, root, self.filename)
+
+         # Downloads the file.
+         try:
+             logger.info("Downloading %s to %s", url, self.filepath)
+             self._urlretrieve(url, self.filepath)
+         except (urllib.error.URLError, OSError) as e:
+             if url[:5] == "https":
+                 url = url.replace("https:", "http:")
+                 logger.warning("Download failed. Trying HTTP instead of HTTPS: %s to %s", url, self.filepath)
+                 self._urlretrieve(url, self.filepath)
+             else:
+                 raise e
+
+     @classmethod
+     def _save_response_content(cls, content: Iterator[bytes], destination: Path) -> None:
+         with open(destination, "wb") as fh:
+             for chunk in content:
+                 if not chunk:  # Filter out keep-alive new chunks.
+                     continue
+                 fh.write(chunk)
+
+     @classmethod
+     def _urlretrieve(cls, url: str, filename: Path, chunk_size: int = 1024 * 32) -> None:
+         with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
+             cls._save_response_content(iter(lambda: response.read(chunk_size), b""), filename)
+
+     @classmethod
+     def _extract_gdrive_api_response(
+         cls,
+         response: requests.Response,
+         chunk_size: int = 32 * 1024,
+     ) -> tuple[str | None, Iterator[bytes]]:
+         content = response.iter_content(chunk_size)
+         first_chunk = None
+         while not first_chunk:  # Filter out keep-alive new chunks.
+             first_chunk = next(content)
+         content = itertools.chain([first_chunk], content)
+
+         try:
+             match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
+             api_response = match["api_response"] if match is not None else None
+         except UnicodeDecodeError:
+             api_response = None
+         return api_response, content
+
+     @classmethod
+     def download_file_from_google_drive(cls, file_id: str, root: Path, filename: str | None = None) -> None:
+         root = root.expanduser()
+         if not filename:
+             filename = file_id
+         fpath = root / filename
+         root.mkdir(parents=True, exist_ok=True)
+
+         url = "https://drive.google.com/uc"
+         params = dict(id=file_id, export="download")
+         with requests.Session() as session:
+             response = session.get(url, params=params, stream=True)
+
+             token: str | None = None
+             for key, value in response.cookies.items():
+                 if key.startswith("download_warning"):
+                     token = value
+                     break
+             else:
+                 api_response, content = cls._extract_gdrive_api_response(response)
+                 token = "t" if api_response == "Virus scan warning" else None
+
+             if token is not None:
+                 response = session.get(url, params=dict(params, confirm=token), stream=True)
+                 api_response, content = cls._extract_gdrive_api_response(response)
+
+             if api_response == "Quota exceeded":
+                 raise RuntimeError(
+                     f"The daily quota of the file {filename} is exceeded and it "
+                     f"can't be downloaded. This is a limitation of Google Drive "
+                     f"and can only be overcome by trying again later."
+                 )
+
+             cls._save_response_content(content, fpath)
+
+         # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB with only text.
+         if os.stat(fpath).st_size < 10 * 1024:
+             with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
+                 text = fh.read()
+
+                 # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
+                 if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
+                     warnings.warn(
+                         f"We detected some HTML elements in the downloaded file. "
+                         f"This most likely means that the download triggered an unhandled API response by GDrive. "
+                         f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
+                         f"the response:\n\n{text}"
+                     )
+
+     @classmethod
+     def _get_google_drive_file_id(cls, url: str) -> str | None:
+         parts = urlparse(url)
+         if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+             return None
+         match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
+         if match is None:
+             return None
+         return match.group("id")
+
+     @classmethod
+     def _get_redirect_url(cls, url: str, max_hops: int = 3) -> str:
+         initial_url = url
+         headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
+         for _ in range(max_hops + 1):
+             with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
+                 if response.url == url or response.url is None:
+                     return url
+                 url = response.url
+         raise RecursionError(f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect was {url}.")
+
+
+ class ModelDownloader(BaseFileDownloader):
+     def get_root_directory(self) -> Path:
+         return get_pretrained_models_dir()
+
+
+ class DataDownloader(BaseFileDownloader):
+     def get_root_directory(self) -> Path:
+         return get_data_dir()
+
+
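A sketch of the intended call pattern; the URL and directory names are invented, and the data root comes from `get_data_dir()` in `xax/core/conf.py`:

```python
from xax.utils.experiments import DataDownloader

downloader = DataDownloader(
    "https://example.com/mnist.npz",  # Hypothetical URL.
    "datasets",
    "mnist.npz",  # Saved to <data_dir>/datasets/mnist.npz.
    sha256=None,  # Optionally pin the expected hash here.
)

# Downloads on first call; later calls return the cached path immediately.
path = downloader.ensure_downloaded()
```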
+ def get_state_dict_prefix(
+     ckpt: dict[str, T],
+     prefix: str | None = None,
+     suffix: str | None = None,
+     regexp: re.Pattern[str] | None = None,
+ ) -> dict[str, T]:
+     """Returns the parts of a checkpoint which begin with a prefix.
+
+     Args:
+         ckpt: The checkpoint to modify
+         prefix: The prefix to clip
+         suffix: The suffix to clip
+         regexp: The regexp to search for (doesn't modify any keys)
+
+     Returns:
+         The modified checkpoint
+     """
+     if prefix is not None:
+         ckpt = {k[len(prefix) :]: v for k, v in ckpt.items() if k.startswith(prefix)}
+     if suffix is not None:
+         ckpt = {k[: -len(suffix)]: v for k, v in ckpt.items() if k.endswith(suffix)}
+     if regexp is not None:
+         ckpt = {k: v for k, v in ckpt.items() if regexp.match(k)}
+     return ckpt
+
+
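For example, stripping a DDP-style `module.` prefix, or filtering keys by pattern without renaming them (the checkpoint dict is invented):

```python
import re

from xax.utils.experiments import get_state_dict_prefix

ckpt = {"module.encoder.weight": 1, "module.decoder.weight": 2, "step": 3}

print(get_state_dict_prefix(ckpt, prefix="module."))
# {'encoder.weight': 1, 'decoder.weight': 2}

print(get_state_dict_prefix(ckpt, regexp=re.compile(r".*encoder.*")))
# {'module.encoder.weight': 1}
```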
+ def split_n_items_across_workers(n: int, worker_id: int, num_workers: int) -> tuple[int, int]:
+     """Computes offsets for splitting N items across K workers.
+
+     This returns the start and end indices for the items to be processed by the
+     given worker. The end index is exclusive.
+
+     Args:
+         n: The number of items to process.
+         worker_id: The ID of the current worker.
+         num_workers: The total number of workers.
+
+     Returns:
+         The start and end index for the items in the current worker.
+     """
+     assert n >= num_workers, f"n ({n}) must be >= num_workers ({num_workers})"
+     assert 0 <= worker_id < num_workers, f"worker_id ({worker_id}) must be >= 0 and < num_workers ({num_workers})"
+
+     # The number of items to process per worker.
+     items_per_worker = math.ceil(n / num_workers)
+
+     # The start and end indices for the items to process.
+     start = worker_id * items_per_worker
+     end = min(start + items_per_worker, n)
+
+     return start, end
+
+
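Concretely, 10 items across 3 workers gives each worker `ceil(10 / 3) = 4` items, with the last worker taking the remainder:

```python
from xax.utils.experiments import split_n_items_across_workers

for worker_id in range(3):
    print(worker_id, split_n_items_across_workers(10, worker_id, 3))
# 0 (0, 4)
# 1 (4, 8)
# 2 (8, 10)
```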
+ def num_workers(default: int) -> int:
+     max_workers = load_user_config().experiment.max_workers
+     if hasattr(os, "sched_getaffinity"):
+         try:
+             return min(len(os.sched_getaffinity(0)), max_workers)
+         except Exception:
+             pass
+     if (cpu_count := os.cpu_count()) is not None:
+         return min(cpu_count, max_workers)
+     return min(default, max_workers)
+
+
+ OmegaConf.register_new_resolver("mlfab.num_workers", num_workers, replace=True)