xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/utils/experiments.py
ADDED
@@ -0,0 +1,802 @@
"""Functions for managing experiments."""

import contextlib
import datetime
import enum
import functools
import hashlib
import inspect
import itertools
import logging
import math
import os
import random
import re
import shutil
import sys
import tempfile
import textwrap
import time
import traceback
import urllib.error
import urllib.request
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Iterator, TypeVar, cast
from urllib.parse import urlparse

import git
import requests
from jaxtyping import Array
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf

from xax.core.conf import get_data_dir, get_pretrained_models_dir, load_user_config
from xax.core.state import State
from xax.utils.text import colored

logger = logging.getLogger(__name__)

# Date format for staging environments.
DATE_FORMAT = "%Y-%m-%d"

USER_AGENT = "xax"

T = TypeVar("T")


class CumulativeTimer:
    """Defines a simple timer to track an average value."""

    def __init__(self) -> None:
        self.steps = 0
        self.elapsed_time = 0.0

    @functools.cached_property
    def start_time(self) -> float:
        return time.time()

    def step(self, steps: int, cur_time: float) -> None:
        if steps != self.steps:
            self.steps = steps
            self.elapsed_time = cur_time - self.start_time

    @property
    def steps_per_second(self) -> float:
        return 0.0 if self.elapsed_time < 1e-4 else self.steps / self.elapsed_time

    @property
    def steps_per_hour(self) -> float:
        return self.steps_per_second * 60 * 60

    @property
    def seconds_per_step(self) -> float:
        return 0.0 if self.steps <= 0 else self.elapsed_time / self.steps

    @property
    def hours_per_step(self) -> float:
        return self.seconds_per_step / (60 * 60)


class IterationTimer:
    """Defines a simple timer to track consecutive values."""

    def __init__(self) -> None:
        self.iteration_time = 0.0
        self.last_time = time.time()

    def step(self, cur_time: float) -> None:
        self.iteration_time = cur_time - self.last_time
        self.last_time = cur_time

    @property
    def iter_seconds(self) -> float:
        return self.iteration_time

    @property
    def iter_hours(self) -> float:
        return self.iter_seconds / (60 * 60)


class StateTimer:
    """Defines a timer for all state information."""

    def __init__(self) -> None:
        self.step_timer = CumulativeTimer()
        self.sample_timer = CumulativeTimer()
        self.iter_timer = IterationTimer()

    def step(self, state: State) -> None:
        cur_time = time.time()
        self.step_timer.step(state.num_steps, cur_time)
        self.sample_timer.step(state.num_samples, cur_time)
        self.iter_timer.step(cur_time)

    def log_dict(self) -> dict[str, dict[str, int | float]]:
        logs: dict[str, dict[str, int | float]] = {}

        # Logs step statistics.
        logs["⏰ steps"] = {
            "total": self.step_timer.steps,
            "per-second": self.step_timer.steps_per_second,
        }

        # Logs sample statistics.
        logs["⏰ samples"] = {
            "total": self.sample_timer.steps,
            "per-second": self.sample_timer.steps_per_second,
        }

        # Logs full iteration statistics.
        logs["🔧 dt"] = {
            "iter": self.iter_timer.iter_seconds,
        }

        return logs


class IntervalTicker:
    def __init__(self, interval: float) -> None:
        self.interval = interval
        self.last_tick_time: float | None = None

    def tick(self, elapsed_time: float) -> bool:
        if self.last_tick_time is None or elapsed_time - self.last_tick_time > self.interval:
            self.last_tick_time = elapsed_time
            return True
        return False
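A minimal sketch (not part of the package) of how the timer helpers compose in a training loop; `do_work` is a hypothetical stand-in for a real training step.

import time

from xax.utils.experiments import CumulativeTimer, IntervalTicker, IterationTimer


def do_work() -> None:
    time.sleep(0.01)  # Stand-in for one training step.


step_timer = CumulativeTimer()
iter_timer = IterationTimer()
ticker = IntervalTicker(interval=1.0)  # Rate-limit logging to roughly once per second.

for step in range(1, 101):
    do_work()
    now = time.time()
    step_timer.step(step, now)  # Tracks cumulative steps and elapsed time.
    iter_timer.step(now)  # Tracks the duration of the most recent iteration.
    if ticker.tick(now - step_timer.start_time):
        print(f"{step_timer.steps_per_second:.1f} steps/s, last iter {iter_timer.iter_seconds * 1000:.1f} ms")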
def abs_path(path: str) -> str:
    return str(Path(path).resolve())


OmegaConf.register_new_resolver("ml.abs_path", abs_path, replace=True)


def cpu_count(default: int) -> int:
    if (cpu_count := os.cpu_count()) is not None:
        return cpu_count
    return default


OmegaConf.register_new_resolver("ml.cpu_count", cpu_count, replace=True)


def date_str(_: str) -> str:
    return time.strftime("%Y-%m-%d")


OmegaConf.register_new_resolver("ml.date_str", date_str, replace=True)


def get_random_port(default: int = 1337) -> int:
    try:
        return (hash(time.time()) + random.randint(0, 100000)) % (65_535 - 10_000) + 10_000
    except Exception:
        return default


OmegaConf.register_new_resolver("xax.get_random_port", get_random_port, replace=True)
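A hedged sketch of referencing the resolvers registered above from an OmegaConf config; importing this module is what performs the registration, and the argument values below are arbitrary examples.

from omegaconf import OmegaConf

import xax.utils.experiments  # noqa: F401  # Importing registers the resolvers.

cfg = OmegaConf.create(
    {
        "data_root": "${ml.abs_path:data}",  # Resolves to an absolute path.
        "num_workers": "${ml.cpu_count:4}",  # Falls back to 4 if os.cpu_count() is None.
        "run_name": "run-${ml.date_str:now}",  # The argument is ignored; yields today's date.
        "port": "${xax.get_random_port:1337}",  # 1337 is the fallback port.
    }
)
print(OmegaConf.to_container(cfg, resolve=True))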
class NaNError(Exception):
    """Raised when NaNs are detected in the model parameters."""


class TrainingFinishedError(Exception):
    """Raised when training is finished."""


class MinGradScaleError(TrainingFinishedError):
    """Raised when the minimum gradient scale is reached.

    This is a subclass of :class:`TrainingFinishedError` because it indicates
    that training is finished and causes the post-training hooks to be run.
    """


def diff_configs(
    first: ListConfig | DictConfig,
    second: ListConfig | DictConfig,
    prefix: str | None = None,
) -> tuple[list[str], list[str]]:
    """Returns the difference between two configs.

    Args:
        first: The first (original) config
        second: The second (new) config
        prefix: The prefix to check (used for recursion, not main call)

    Returns:
        Two lists of lines describing the diff between the two configs
    """

    def get_diff_string(prefix: str | None, val: Any) -> str:  # noqa: ANN401
        if isinstance(val, (str, float, int)):
            return f"{prefix}={val}"
        return f"{prefix}= ... ({type(val)})"

    def cast_enums(k: Any) -> Any:  # noqa: ANN401
        return k.name if isinstance(k, enum.Enum) else k

    new_first: list[str] = []
    new_second: list[str] = []

    any_config = (ListConfig, DictConfig)

    if isinstance(first, DictConfig) and isinstance(second, DictConfig):
        first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))

        # Gets the new keys in each config.
        new_first += [f"{prefix}.{key}" for key in first_keys.difference(second_keys)]
        new_second += [f"{prefix}.{key}" for key in second_keys.difference(first_keys)]

        # Gets the new sub-keys in each config.
        for key in first_keys.intersection(second_keys):
            sub_prefix = key if prefix is None else f"{prefix}.{key}"
            if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
                if not OmegaConf.is_missing(first, key):
                    new_first += [get_diff_string(sub_prefix, first[key])]
                if not OmegaConf.is_missing(second, key):
                    new_second += [get_diff_string(sub_prefix, second[key])]
            elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
                sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
                new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
            elif cast_enums(first[key]) != cast_enums(second[key]):
                first_val, second_val = first[key], second[key]
                new_first += [get_diff_string(sub_prefix, first_val)]
                new_second += [get_diff_string(sub_prefix, second_val)]

    elif isinstance(first, ListConfig) and isinstance(second, ListConfig):
        if len(first) > len(second):
            for i in range(len(second), len(first)):
                new_first += [get_diff_string(prefix, first[i])]
        elif len(second) > len(first):
            for i in range(len(first), len(second)):
                new_second += [get_diff_string(prefix, second[i])]

        for i in range(min(len(first), len(second))):
            sub_prefix = str(i) if prefix is None else f"{prefix}.{i}"
            if isinstance(first[i], any_config) and isinstance(second[i], any_config):
                sub_new_first, sub_new_second = diff_configs(first[i], second[i], prefix=sub_prefix)
                new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
    else:
        new_first += [get_diff_string(prefix, first)]
        new_second += [get_diff_string(prefix, second)]

    return new_first, new_second


def get_diff_string(config_diff: tuple[list[str], list[str]]) -> str | None:
    added_keys, deleted_keys = config_diff
    if not added_keys and not deleted_keys:
        return None
    change_lines: list[str] = []
    change_lines += [f" ↪ {colored('+', 'green')} {added_key}" for added_key in added_keys]
    change_lines += [f" ↪ {colored('-', 'red')} {deleted_key}" for deleted_key in deleted_keys]
    change_summary = "\n".join(change_lines)
    return change_summary


def save_config(config_path: Path, raw_config: DictConfig) -> None:
    if config_path.exists():
        config_diff = diff_configs(raw_config, cast(DictConfig, OmegaConf.load(config_path)))
        diff_string = get_diff_string(config_diff)
        if diff_string is not None:
            logger.warning("Overwriting config %s:\n%s", config_path, diff_string)
            OmegaConf.save(raw_config, config_path)
    else:
        config_path.parent.mkdir(exist_ok=True, parents=True)
        OmegaConf.save(raw_config, config_path)
        logger.info("Saved config to %s", config_path)
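A small sketch (not from the package) of using diff_configs and get_diff_string to summarize how a config changed between runs.

from omegaconf import OmegaConf

from xax.utils.experiments import diff_configs, get_diff_string

old_cfg = OmegaConf.create({"model": {"hidden_size": 256, "depth": 4}, "lr": 1e-3})
new_cfg = OmegaConf.create({"model": {"hidden_size": 512, "depth": 4}, "lr": 1e-3})

summary = get_diff_string(diff_configs(old_cfg, new_cfg))
print(summary)  # A "+" line for model.hidden_size=256 and a "-" line for model.hidden_size=512.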
def to_markdown_table(config: DictConfig) -> str:
    """Converts a config to a markdown table string.

    Args:
        config: The config to convert to a table.

    Returns:
        The config, formatted as a Markdown string.
    """

    def format_as_string(value: Any) -> str:  # noqa: ANN401
        if isinstance(value, str):
            return value
        if isinstance(value, Array):
            value = value.item()
        if isinstance(value, (int, float)):
            return f"{value:.4g}"
        if isinstance(value, bool):
            return "true" if value else "false"
        if isinstance(value, datetime.datetime):
            return value.isoformat()
        if isinstance(value, datetime.timedelta):
            return f"{value.total_seconds():.4g}s"
        if value is None:
            return ""
        if value is MISSING:
            return ""
        return str(value)

    def iter_flat(config: dict) -> Iterator[tuple[list[str | None], str]]:
        for key, value in reversed(config.items()):
            if isinstance(value, dict):
                is_first = True
                for sub_key_list, sub_value in iter_flat(value):
                    yield [format_as_string(key) if is_first else None] + sub_key_list, sub_value
                    is_first = False
            elif isinstance(value, (list, tuple)):
                is_first = True
                for i, sub_value in enumerate(value):
                    for sub_key_list, sub_sub_value in iter_flat({f"{i}": sub_value}):
                        yield [format_as_string(key) if is_first else None] + sub_key_list, sub_sub_value
                        is_first = False
            else:
                yield [format_as_string(key)], format_as_string(value)

    config_dict = cast(dict, OmegaConf.to_container(config, resolve=True, throw_on_missing=False, enum_to_str=True))
    config_flat = list(iter_flat(config_dict))

    # Gets rows of strings.
    rows: list[list[str]] = []
    for key_list, value in config_flat:
        row = ["" if key is None else key for key in key_list] + [value]
        rows.append(row)

    # Pads all rows to the same length.
    max_len = max(len(row) for row in rows)
    rows = [row[:-1] + [""] * (max_len - len(row)) + row[-1:] for row in rows]

    # Converts to a markdown table.
    header_str = "| " + " | ".join([f"key_{i}" for i in range(max_len - 1)]) + " | value |"
    header_sep_str = "|-" + "-|-" * (max_len - 1) + "-|"
    rows_str = "\n".join(["| " + " | ".join(row) + " |" for row in rows])
    return "\n".join([header_str, header_sep_str, rows_str])
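A minimal sketch of rendering a config as a Markdown table, for instance when writing experiment metadata to a logger.

from omegaconf import OmegaConf

from xax.utils.experiments import to_markdown_table

cfg = OmegaConf.create({"optimizer": {"name": "adam", "lr": 3e-4}, "batch_size": 32})
print(to_markdown_table(cfg))
# Produces a table with key_0, key_1, ... columns plus a value column, one row per leaf value.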
def stage_environment(obj: object, root: Path) -> None:
    """Stages the current task to a staging directory.

    Args:
        obj: The object with the module to stage.
        root: The root directory to stage to.
    """
    root.mkdir(exist_ok=True, parents=True)

    # Gets the path to the root module. This is done heuristically, so it may
    # not work in all cases, but it should generally work.
    if (mod := inspect.getmodule(obj.__class__)) is None:
        raise RuntimeError(f"Could not find module for task {obj.__class__}!")
    if (spec := mod.__spec__) is None:
        raise RuntimeError(f"Could not find spec for module {mod}!")
    if spec.origin is None:
        raise RuntimeError(f"Could not find origin for spec {spec}!")
    root_mod = spec.name.split(".", 1)[0]
    path_parts = Path(spec.origin).parts[:-1]
    if root_mod not in path_parts:
        raise RuntimeError(f"Could not find root module {root_mod} in path {path_parts}!")
    root_path = Path(*path_parts[: path_parts.index(root_mod) + 1])

    # Gets files to stage.
    fpaths: set[tuple[Path, Path]] = set()
    for module in sys.modules.values():
        if (fpath_str := getattr(module, "__file__", None)) is None:
            continue
        fpath = Path(fpath_str).resolve()
        try:
            rel_fpath = fpath.relative_to(root_path)
            fpaths.add((fpath, rel_fpath))
        except ValueError:
            pass

    # Computes hash of all files and return if it matches the previous hash.
    hashobj = hashlib.md5()
    for fpath, _ in fpaths:
        with open(fpath, "rb") as f:
            while data := f.read(65536):
                hashobj.update(data)
    hashval = hashobj.hexdigest()
    prev_hashval: str | None = None
    hash_file = root / ".hash"
    if hash_file.exists():
        prev_hashval = hash_file.read_text().strip()
    if prev_hashval == hashval:
        return
    hash_file.write_text(hashval)

    # Copies all files to the staging directory.
    if (root / root_mod).exists():
        shutil.rmtree(root / root_mod, ignore_errors=True)
    for fpath, rel_fpath in fpaths:
        new_fpath = root / root_mod / rel_fpath
        new_fpath.parent.mkdir(exist_ok=True, parents=True)
        shutil.copyfile(fpath, new_fpath)
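A hedged sketch of stage_environment. In practice you would pass your own task object so that your project's package is snapshotted alongside the run; here an object from the installed requests package is used only so the snippet runs standalone.

import tempfile
from pathlib import Path

import requests

from xax.utils.experiments import stage_environment

staging_dir = Path(tempfile.mkdtemp("staging"))
stage_environment(requests.Session(), staging_dir)
# The package that defines the object's class (here, requests) is copied under
# staging_dir/<package_name>, and the .hash file lets unchanged code skip re-staging.
print(sorted(p.name for p in staging_dir.iterdir()))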
def get_git_state(obj: object) -> str:
    """Gets the state of the Git repo that an object is in as a string.

    Args:
        obj: The object which is in the target Git repo.

    Returns:
        A nicely-formatted string showing the current task's Git state.
    """
    try:
        task_file = inspect.getfile(type(obj))
        repo = git.Repo(task_file, search_parent_directories=True)
        branch = repo.active_branch
        commit = repo.head.commit
        status = textwrap.indent(str(repo.git.status()), " ")
        diff = textwrap.indent(str(repo.git.diff(color=False)), " ")
        return "\n".join(
            [
                f"Path: {task_file}",
                f"Branch: {branch}",
                f"Commit: {commit}",
                "Status:",
                status,
                "Diff:",
                diff,
            ]
        )

    except Exception:
        return traceback.format_exc()


def get_training_code(obj: object) -> str:
    """Gets the text from the file containing the provided object.

    Args:
        obj: The object to get the file from.

    Returns:
        The text from the file containing the object.
    """
    try:
        task_file = inspect.getfile(type(obj))
        with open(task_file, "r") as f:
            return f.read()
    except Exception:
        return traceback.format_exc()
def check_md5(file_path: str | Path, hash_str: str | None, chunk_size: int = 2**16) -> bool:
    """Checks the MD5 of the downloaded file.

    Args:
        file_path: Path to the downloaded file.
        hash_str: Expected MD5 of the file; if None, return True.
        chunk_size: Size of the chunks to read from the file.

    Returns:
        True if the MD5 matches, False otherwise.
    """
    if hash_str is None:
        return True

    md5 = hashlib.md5()

    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)

    return md5.hexdigest() == hash_str


def check_sha256(file_path: str | Path, hash_str: str | None, chunk_size: int = 2**16) -> bool:
    """Checks the SHA256 of the downloaded file.

    Args:
        file_path: Path to the downloaded file.
        hash_str: Expected SHA256 of the file; if None, return True.
        chunk_size: Size of the chunks to read from the file.

    Returns:
        True if the SHA256 matches, False otherwise.
    """
    if hash_str is None:
        return True

    sha256 = hashlib.sha256()

    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)

    return sha256.hexdigest() == hash_str
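A small sketch of verifying a file against a known digest with the helpers above; the file and its contents are made up for the example.

import hashlib
from pathlib import Path

from xax.utils.experiments import check_sha256

path = Path("example.bin")
path.write_bytes(b"hello world")
expected = hashlib.sha256(b"hello world").hexdigest()
assert check_sha256(path, expected)  # Matches the expected digest.
assert check_sha256(path, None)  # A missing expected hash always passes.
assert not check_sha256(path, "0" * 64)  # A wrong digest fails.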
class BaseFileDownloader(ABC):
    """Provides a simple interface for downloading URLs.

    This class is meant to be subclassed to provide different download
    locations. For example, when downloading pretrained models, use the
    :class:`ModelDownloader` class.

    Typically, you should simply use the :func:`ensure_downloaded` function
    to make sure the file is downloaded to the correct location.

    This is adapted in large part from the reference implementation in the
    Torchvision library.

    Parameters:
        url: The URL to download from.
        dnames: The directory names to download to.
        md5: The expected MD5 of the file.
        sha256: The expected SHA256 of the file.
        is_tmp: Whether to download to a temporary directory.
        recheck_hash: Whether to recheck the hash after downloading.
        max_redirect_hops: The maximum number of redirects to follow.
    """

    def __init__(
        self,
        url: str,
        *dnames: str,
        md5: str | None = None,
        sha256: str | None = None,
        is_tmp: bool = False,
        recheck_hash: bool = False,
        max_redirect_hops: int = 3,
    ) -> None:
        super().__init__()

        assert len(dnames) >= 1, "Must provide at least 1 directory name"
        filepath = Path(tempfile.mkdtemp("models")) if is_tmp else self.get_root_directory()
        for dname in dnames:
            filepath = filepath / dname
        (root := filepath.parent).mkdir(parents=True, exist_ok=True)

        self.url = url
        self.filename = filepath.name
        self.root = root
        self.md5 = md5
        self.sha256 = sha256
        self.recheck_hash = recheck_hash
        self.max_redirect_hops = max_redirect_hops

    @abstractmethod
    def get_root_directory(self) -> Path: ...

    @property
    def filepath(self) -> Path:
        return self.root / self.filename

    @property
    def is_downloaded(self) -> bool:
        if not self.filepath.exists():
            return False
        if self.recheck_hash and not self.check_hashes():
            logger.warning("A file was found for %s in %s, but its hashes do not match.", self.url, self.filepath)
            self.filepath.unlink()
            return False
        return True

    def check_hashes(self) -> bool:
        return check_sha256(self.filepath, self.sha256) and check_md5(self.filepath, self.md5)

    def ensure_downloaded(self) -> Path:
        """Ensures the file is downloaded and returns the path to it.

        By default, we only check the hash once when the file is downloaded,
        and we don't bother rechecking unless ``recheck_hash`` is set to True.

        Returns:
            The path to the downloaded file.
        """
        if not self.is_downloaded:
            self.download()
            if not self.check_hashes():
                self.filepath.unlink()
                raise RuntimeError(f"Hashes for {self.filepath} do not match. The corrupted file has been deleted.")
        return self.filepath

    def download(self) -> None:
        root = self.root.expanduser()
        root.mkdir(parents=True, exist_ok=True)

        # Expands the redirect chain if needed.
        url = self._get_redirect_url(self.url, max_hops=self.max_redirect_hops)

        # Checks if file is located on Google Drive.
        file_id = self._get_google_drive_file_id(url)
        if file_id is not None:
            return self.download_file_from_google_drive(file_id, root, self.filename)

        # Downloads the file.
        try:
            logger.info("Downloading %s to %s", url, self.filepath)
            self._urlretrieve(url, self.filepath)
        except (urllib.error.URLError, OSError) as e:
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                logger.warning("Download failed. Trying HTTP instead of HTTPS: %s to %s", url, self.filepath)
                self._urlretrieve(url, self.filepath)
            else:
                raise e

    @classmethod
    def _save_response_content(cls, content: Iterator[bytes], destination: Path) -> None:
        with open(destination, "wb") as fh:
            for chunk in content:
                if not chunk:  # Filter out keep-alive new chunks.
                    continue
                fh.write(chunk)

    @classmethod
    def _urlretrieve(cls, url: str, filename: Path, chunk_size: int = 1024 * 32) -> None:
        with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
            cls._save_response_content(iter(lambda: response.read(chunk_size), b""), filename)

    @classmethod
    def _extract_gdrive_api_response(
        cls,
        response: requests.Response,
        chunk_size: int = 32 * 1024,
    ) -> tuple[str | None, Iterator[bytes]]:
        content = response.iter_content(chunk_size)
        first_chunk = None
        while not first_chunk:  # Filter out keep-alive new chunks.
            first_chunk = next(content)
        content = itertools.chain([first_chunk], content)

        try:
            match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
            api_response = match["api_response"] if match is not None else None
        except UnicodeDecodeError:
            api_response = None
        return api_response, content

    @classmethod
    def download_file_from_google_drive(cls, file_id: str, root: Path, filename: str | None = None) -> None:
        root = root.expanduser()
        if not filename:
            filename = file_id
        fpath = root / filename
        root.mkdir(parents=True, exist_ok=True)

        url = "https://drive.google.com/uc"
        params = dict(id=file_id, export="download")
        with requests.Session() as session:
            response = session.get(url, params=params, stream=True)

            token: str | None = None
            for key, value in response.cookies.items():
                if key.startswith("download_warning"):
                    token = value
                    break
            else:
                api_response, content = cls._extract_gdrive_api_response(response)
                token = "t" if api_response == "Virus scan warning" else None

            if token is not None:
                response = session.get(url, params=dict(params, confirm=token), stream=True)
                api_response, content = cls._extract_gdrive_api_response(response)

            if api_response == "Quota exceeded":
                raise RuntimeError(
                    f"The daily quota of the file {filename} is exceeded and it "
                    f"can't be downloaded. This is a limitation of Google Drive "
                    f"and can only be overcome by trying again later."
                )

            cls._save_response_content(content, fpath)

        # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB with only text.
        if os.stat(fpath).st_size < 10 * 1024:
            with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
                text = fh.read()

                # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
                if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
                    warnings.warn(
                        f"We detected some HTML elements in the downloaded file. "
                        f"This most likely means that the download triggered an unhandled API response by GDrive. "
                        f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
                        f"the response:\n\n{text}"
                    )

    @classmethod
    def _get_google_drive_file_id(cls, url: str) -> str | None:
        parts = urlparse(url)
        if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
            return None
        match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
        if match is None:
            return None
        return match.group("id")

    @classmethod
    def _get_redirect_url(cls, url: str, max_hops: int = 3) -> str:
        initial_url = url
        headers = {"Method": "HEAD", "User-Agent": USER_AGENT}
        for _ in range(max_hops + 1):
            with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
                if response.url == url or response.url is None:
                    return url
                url = response.url
        raise RecursionError(f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect was {url}.")


class ModelDownloader(BaseFileDownloader):
    def get_root_directory(self) -> Path:
        return get_pretrained_models_dir()


class DataDownloader(BaseFileDownloader):
    def get_root_directory(self) -> Path:
        return get_data_dir()
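A hedged sketch of the downloader interface; the URL and directory names below are placeholders. ModelDownloader stores files under get_pretrained_models_dir(), which assumes the xax user config points at a writable models directory (DataDownloader works the same way, but rooted at get_data_dir()).

from xax.utils.experiments import ModelDownloader

downloader = ModelDownloader(
    "https://example.com/checkpoints/encoder.bin",  # Placeholder URL.
    "encoder",  # Sub-directory under the pretrained-models root.
    "encoder.bin",  # File name on disk.
    sha256=None,  # Optionally pin the expected SHA256 digest here.
)
ckpt_path = downloader.ensure_downloaded()  # Downloads on first use, then returns the cached path.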
def get_state_dict_prefix(
    ckpt: dict[str, T],
    prefix: str | None = None,
    suffix: str | None = None,
    regexp: re.Pattern[str] | None = None,
) -> dict[str, T]:
    """Returns the parts of a checkpoint which begin with a prefix.

    Args:
        ckpt: The checkpoint to modify
        prefix: The prefix to clip
        suffix: The suffix to clip
        regexp: The regexp to search for (doesn't modify any keys)

    Returns:
        The modified checkpoint
    """
    if prefix is not None:
        ckpt = {k[len(prefix) :]: v for k, v in ckpt.items() if k.startswith(prefix)}
    if suffix is not None:
        ckpt = {k[: -len(suffix)]: v for k, v in ckpt.items() if k.endswith(suffix)}
    if regexp is not None:
        ckpt = {k: v for k, v in ckpt.items() if regexp.match(k)}
    return ckpt
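A minimal sketch of filtering and trimming checkpoint keys with get_state_dict_prefix; the key names are made up for the example.

import re

from xax.utils.experiments import get_state_dict_prefix

ckpt = {"encoder.weight": 1, "encoder.bias": 2, "decoder.weight": 3}
print(get_state_dict_prefix(ckpt, prefix="encoder."))  # {'weight': 1, 'bias': 2}
print(get_state_dict_prefix(ckpt, regexp=re.compile("decoder")))  # {'decoder.weight': 3}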
def split_n_items_across_workers(n: int, worker_id: int, num_workers: int) -> tuple[int, int]:
    """Computes offsets for splitting N items across K workers.

    This returns the start and end indices for the items to be processed by the
    given worker. The end index is exclusive.

    Args:
        n: The number of items to process.
        worker_id: The ID of the current worker.
        num_workers: The total number of workers.

    Returns:
        The start and end index for the items in the current worker.
    """
    assert n >= num_workers, f"n ({n}) must be >= num_workers ({num_workers})"
    assert 0 <= worker_id < num_workers, f"worker_id ({worker_id}) must be >= 0 and < num_workers ({num_workers})"

    # The number of items to process per worker.
    items_per_worker = math.ceil(n / num_workers)

    # The start and end indices for the items to process.
    start = worker_id * items_per_worker
    end = min(start + items_per_worker, n)

    return start, end
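A worked example of split_n_items_across_workers: 10 items over 3 workers gives ceil(10 / 3) = 4 items per worker, with the final worker taking whatever remains.

from xax.utils.experiments import split_n_items_across_workers

for worker_id in range(3):
    print(worker_id, split_n_items_across_workers(10, worker_id, 3))
# 0 (0, 4)
# 1 (4, 8)
# 2 (8, 10)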
def num_workers(default: int) -> int:
    max_workers = load_user_config().experiment.max_workers
    if hasattr(os, "sched_getaffinity"):
        try:
            return min(len(os.sched_getaffinity(0)), max_workers)
        except Exception:
            pass
    if (cpu_count := os.cpu_count()) is not None:
        return min(cpu_count, max_workers)
    return min(default, max_workers)


OmegaConf.register_new_resolver("mlfab.num_workers", num_workers, replace=True)
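A short sketch of the num_workers resolver registered above, which is exposed under the "mlfab.num_workers" name; this assumes the default xax user config can be loaded so that experiment.max_workers is available.

from omegaconf import OmegaConf

import xax.utils.experiments  # noqa: F401  # Importing registers the resolver.

cfg = OmegaConf.create({"loader": {"num_workers": "${mlfab.num_workers:8}"}})
print(OmegaConf.to_container(cfg, resolve=True))  # The worker count is capped by experiment.max_workers.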