nshtrainer 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nshtrainer/__init__.py +2 -1
- nshtrainer/callbacks/__init__.py +17 -1
- nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
- nshtrainer/callbacks/base.py +7 -5
- nshtrainer/callbacks/ema.py +1 -1
- nshtrainer/callbacks/finite_checks.py +1 -1
- nshtrainer/callbacks/gradient_skipping.py +1 -1
- nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
- nshtrainer/callbacks/model_checkpoint.py +187 -0
- nshtrainer/callbacks/norm_logging.py +1 -1
- nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
- nshtrainer/callbacks/print_table.py +1 -1
- nshtrainer/callbacks/throughput_monitor.py +1 -1
- nshtrainer/callbacks/timer.py +1 -1
- nshtrainer/callbacks/wandb_watch.py +1 -1
- nshtrainer/ll/__init__.py +0 -1
- nshtrainer/ll/actsave.py +2 -1
- nshtrainer/metrics/__init__.py +1 -0
- nshtrainer/metrics/_config.py +37 -0
- nshtrainer/model/__init__.py +11 -11
- nshtrainer/model/_environment.py +777 -0
- nshtrainer/model/base.py +5 -114
- nshtrainer/model/config.py +49 -501
- nshtrainer/model/modules/logger.py +11 -6
- nshtrainer/runner.py +3 -6
- nshtrainer/trainer/_checkpoint_metadata.py +102 -0
- nshtrainer/trainer/_checkpoint_resolver.py +319 -0
- nshtrainer/trainer/_runtime_callback.py +120 -0
- nshtrainer/trainer/checkpoint_connector.py +63 -0
- nshtrainer/trainer/signal_connector.py +12 -9
- nshtrainer/trainer/trainer.py +111 -31
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
- nshtrainer/actsave/__init__.py +0 -3
- {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,777 @@
|
|
|
1
|
+
import getpass
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
import platform
|
|
6
|
+
import socket
|
|
7
|
+
import sys
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
11
|
+
|
|
12
|
+
import git
|
|
13
|
+
import nshconfig as C
|
|
14
|
+
import psutil
|
|
15
|
+
import torch
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
from ..util.slurm import parse_slurm_node_list
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from .base import LightningModuleBase
|
|
22
|
+
from .config import BaseConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Module-level logger; used below to warn when package or Git information
# cannot be collected.
log = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class EnvironmentClassInformationConfig(C.Config):
    """Configuration for class information in the environment."""

    name: str | None
    """The name of the class."""

    module: str | None
    """The module where the class is defined."""

    full_name: str | None
    """The fully qualified name of the class."""

    file_path: Path | None
    """The file path where the class is defined, if it can be determined."""

    source_file_path: Path | None
    """The source file path of the class, if available."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            name=None,
            module=None,
            full_name=None,
            file_path=None,
            source_file_path=None,
        )

    @classmethod
    def from_class(cls, cls_: type):
        """Describe ``cls_``: its names and, when resolvable, where it is defined.

        ``inspect.getfile`` / ``inspect.getsourcefile`` raise ``TypeError`` for
        built-in classes and ``OSError`` for classes defined in a ``__main__``
        module that has no ``__file__`` (e.g. an interactive session or a
        notebook). In those cases the path fields are ``None`` instead of the
        whole environment capture failing.

        Args:
            cls_: The class to describe.

        Returns:
            A populated ``EnvironmentClassInformationConfig``.
        """
        try:
            file_path = inspect.getfile(cls_)
        except (TypeError, OSError):
            file_path = None

        try:
            source_file_path = inspect.getsourcefile(cls_)
        except (TypeError, OSError):
            source_file_path = None

        return cls(
            name=cls_.__name__,
            module=cls_.__module__,
            full_name=f"{cls_.__module__}.{cls_.__qualname__}",
            file_path=Path(file_path) if file_path else None,
            source_file_path=Path(source_file_path) if source_file_path else None,
        )

    @classmethod
    def from_instance(cls, instance: object):
        """Describe the class of ``instance`` (see :meth:`from_class`)."""
        return cls.from_class(type(instance))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class EnvironmentSLURMInformationConfig(C.Config):
    """Configuration for SLURM environment information."""

    hostname: str | None
    """The hostname of the current node."""

    hostnames: list[str] | None
    """List of hostnames for all nodes in the job."""

    job_id: str | None
    """The SLURM job ID."""

    raw_job_id: str | None
    """The raw SLURM job ID."""

    array_job_id: str | None
    """The SLURM array job ID, if applicable."""

    array_task_id: str | None
    """The SLURM array task ID, if applicable."""

    num_tasks: int | None
    """The number of tasks in the SLURM job."""

    num_nodes: int | None
    """The number of nodes in the SLURM job."""

    node: str | int | None
    """The node ID or name."""

    global_rank: int | None
    """The global rank of the current process."""

    local_rank: int | None
    """The local rank of the current process within its node."""

    @classmethod
    def empty(cls):
        """Build an instance in which every field is ``None``."""
        return cls(
            hostname=None,
            hostnames=None,
            job_id=None,
            raw_job_id=None,
            array_job_id=None,
            array_task_id=None,
            num_tasks=None,
            num_nodes=None,
            node=None,
            global_rank=None,
            local_rank=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Read SLURM job information from the process environment.

        Returns ``None`` when Lightning's ``SLURMEnvironment`` cannot be
        imported, when it does not detect SLURM, or when a required ``SLURM_*``
        variable is missing or not an integer.

        For array jobs, ``job_id`` is ``"<array_job_id>_<array_task_id>"`` and
        ``raw_job_id`` keeps the plain ``SLURM_JOB_ID``.
        """
        try:
            from lightning.fabric.plugins.environments.slurm import SLURMEnvironment

            if not SLURMEnvironment.detect():
                return None

            env = os.environ
            this_host = socket.gethostname()
            node_list = env.get("SLURM_JOB_NODELIST", "")
            all_hosts = parse_slurm_node_list(node_list) if node_list else [this_host]

            plain_id = env["SLURM_JOB_ID"]
            parent_id = env.get("SLURM_ARRAY_JOB_ID")
            task_index = env.get("SLURM_ARRAY_TASK_ID")
            if parent_id and task_index:
                effective_id = f"{parent_id}_{task_index}"
            else:
                effective_id = plain_id

            return cls(
                hostname=this_host,
                hostnames=all_hosts,
                job_id=effective_id,
                raw_job_id=plain_id,
                array_job_id=parent_id,
                array_task_id=task_index,
                num_tasks=int(env["SLURM_NTASKS"]),
                num_nodes=int(env["SLURM_JOB_NUM_NODES"]),
                node=env.get("SLURM_NODEID"),
                global_rank=int(env["SLURM_PROCID"]),
                local_rank=int(env["SLURM_LOCALID"]),
            )
        except (ImportError, RuntimeError, ValueError, KeyError):
            return None
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class EnvironmentLSFInformationConfig(C.Config):
    """Configuration for LSF environment information."""

    hostname: str | None
    """The hostname of the current node."""

    hostnames: list[str] | None
    """Unique hostnames of all nodes in the job, in allocation order."""

    job_id: str | None
    """The LSF job ID."""

    array_job_id: str | None
    """The LSF array job ID, if applicable."""

    array_task_id: str | None
    """The LSF array task ID, if applicable."""

    num_tasks: int | None
    """The number of tasks in the LSF job."""

    num_nodes: int | None
    """The number of nodes in the LSF job."""

    node: str | int | None
    """The node ID or name."""

    global_rank: int | None
    """The global rank of the current process."""

    local_rank: int | None
    """The local rank of the current process within its node."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            hostname=None,
            hostnames=None,
            job_id=None,
            array_job_id=None,
            array_task_id=None,
            num_tasks=None,
            num_nodes=None,
            node=None,
            global_rank=None,
            local_rank=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Read LSF job information from the process environment.

        ``LSB_JOBID`` is required: if it is absent (``KeyError``) we are not in
        an LSF job and ``None`` is returned. ``None`` is also returned when an
        integer variable cannot be parsed.

        ``LSB_HOSTS`` lists one entry per allocated slot, so a host providing
        several slots appears several times (e.g. ``"hostA hostA hostB"``). The
        list is de-duplicated (order preserved) to obtain the node list, and
        ``node`` is this host's position in that de-duplicated list. If this
        host cannot be found in it (e.g. FQDN vs. short name on both sides),
        ``node`` is ``None`` rather than discarding all LSF information.
        """
        try:
            hostname = socket.gethostname()
            lsb_hosts = os.environ.get("LSB_HOSTS", "")
            # One entry per slot -> unique hosts in first-seen order.
            hostnames = list(dict.fromkeys(lsb_hosts.split())) or [hostname]

            job_id = os.environ["LSB_JOBID"]

            # LSB_JOBINDEX is the array element index ("0" or unset for jobs
            # that are not array elements). For an array element, LSB_JOBID
            # identifies the parent array job.
            array_task_id = os.environ.get("LSB_JOBINDEX")
            is_array_element = array_task_id not in (None, "", "0")
            array_job_id = job_id if is_array_element else None

            num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
            num_nodes = len(hostnames)

            node_id = None
            if lsb_hosts.strip():
                if hostname in hostnames:
                    node_id = hostnames.index(hostname)
                else:
                    # Fall back to comparing short names (strip domain parts).
                    short_name = hostname.split(".")[0]
                    short_hosts = [h.split(".")[0] for h in hostnames]
                    if short_name in short_hosts:
                        node_id = short_hosts.index(short_name)

            # LSF has no direct equivalents for global/local rank; these read
            # launcher-provided variables and default to 0 when absent.
            global_rank = int(os.environ.get("PMI_RANK", 0))
            local_rank = int(os.environ.get("LSB_RANK", 0))

            return cls(
                hostname=hostname,
                hostnames=hostnames,
                job_id=job_id,
                array_job_id=array_job_id,
                array_task_id=array_task_id,
                num_tasks=num_tasks,
                num_nodes=num_nodes,
                node=node_id,
                global_rank=global_rank,
                local_rank=local_rank,
            )
        except (ImportError, RuntimeError, ValueError, KeyError):
            return None
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
class EnvironmentLinuxEnvironmentConfig(C.Config):
    """Configuration for Linux environment information."""

    user: str | None
    """The current user."""

    hostname: str | None
    """The hostname of the machine."""

    system: str | None
    """The operating system name."""

    release: str | None
    """The operating system release."""

    version: str | None
    """The operating system version."""

    machine: str | None
    """The machine type."""

    processor: str | None
    """The processor type."""

    cpu_count: int | None
    """The number of CPUs."""

    memory: int | None
    """The total system memory in bytes."""

    uptime: timedelta | None
    """The system uptime (time elapsed since boot)."""

    boot_time: float | None
    """The system boot time as a timestamp."""

    load_avg: tuple[float, float, float] | None
    """The system load average (1, 5, and 15 minutes)."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            user=None,
            hostname=None,
            system=None,
            release=None,
            version=None,
            machine=None,
            processor=None,
            cpu_count=None,
            memory=None,
            uptime=None,
            boot_time=None,
            load_avg=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Snapshot OS, user, memory and load information for this machine.

        Values that cannot be determined are recorded as ``None`` instead of
        raising:

        - ``getpass.getuser()`` fails when the uid has no passwd entry and no
          user environment variable is set (common in containers).
        - ``os.getloadavg()`` does not exist on Windows and raises ``OSError``
          when the load average is unobtainable.
        """
        import time

        try:
            user = getpass.getuser()
        except (KeyError, OSError, ImportError):
            user = None

        try:
            load_avg = os.getloadavg()
        except (AttributeError, OSError):
            load_avg = None

        boot_time = psutil.boot_time()
        return cls(
            user=user,
            hostname=platform.node(),
            system=platform.system(),
            release=platform.release(),
            version=platform.version(),
            machine=platform.machine(),
            processor=platform.processor(),
            cpu_count=os.cpu_count(),
            memory=psutil.virtual_memory().total,
            # boot_time is an epoch timestamp; uptime is the time elapsed since
            # then. Clamp at 0 in case of clock adjustments.
            uptime=timedelta(seconds=max(0.0, time.time() - boot_time)),
            boot_time=boot_time,
            load_avg=load_avg,
        )
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class EnvironmentSnapshotConfig(C.Config):
    """Configuration for environment snapshot information."""

    snapshot_dir: Path | None
    """The directory where the snapshot is stored."""

    modules: list[str] | None
    """List of modules included in the snapshot."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(snapshot_dir=None, modules=None)

    @classmethod
    def from_current_environment(cls):
        """Read snapshot information exported by nshrunner.

        Reads ``NSHRUNNER_SNAPSHOT_DIR`` and the comma-separated
        ``NSHRUNNER_SNAPSHOT_MODULES``. Both fields are always assigned
        (``None`` when the variable is unset or empty) because the fields
        declare no defaults and an unassigned field may be rejected by
        ``finalize()``. Blank entries in the module list (e.g. from a trailing
        comma) are dropped.
        """
        draft = cls.draft()

        snapshot_dir = os.environ.get("NSHRUNNER_SNAPSHOT_DIR")
        draft.snapshot_dir = Path(snapshot_dir) if snapshot_dir else None

        modules = os.environ.get("NSHRUNNER_SNAPSHOT_MODULES")
        if modules:
            names = [m.strip() for m in modules.split(",") if m.strip()]
            draft.modules = names or None
        else:
            draft.modules = None

        return draft.finalize()
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
class EnvironmentPackageConfig(C.Config):
    """Configuration for Python package information."""

    name: str | None
    """The name of the package."""

    version: str | None
    """The version of the package."""

    path: Path | None
    """The installation path of the package."""

    summary: str | None
    """A brief summary of the package."""

    author: str | None
    """The author of the package."""

    license: str | None
    """The license of the package."""

    requires: list[str] | None
    """List of package dependencies."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            name=None,
            version=None,
            path=None,
            summary=None,
            author=None,
            license=None,
            requires=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Collect metadata for every installed Python distribution.

        Despite the name, this returns a mapping (not a single instance) from a
        normalized project key to that package's configuration. Keys follow the
        ``pkg_resources`` ``Distribution.key`` convention: runs of characters
        other than letters, digits and ``.`` are replaced with ``-``, and the
        result is lower-cased. This keeps the keys used by earlier versions.

        Uses :mod:`importlib.metadata`. The earlier ``pkg_resources`` version
        read ``summary``/``author``/``license`` attributes that
        ``pkg_resources.Distribution`` does not define; ``pkg_resources`` is
        also deprecated. If several distributions share a key, the first one
        found on ``sys.path`` wins. A distribution whose metadata cannot be
        read is skipped with a warning.

        Returns:
            ``dict`` mapping project key -> ``EnvironmentPackageConfig``.
        """
        import re
        from importlib import metadata

        python_packages: dict[str, Self] = {}
        for dist in metadata.distributions():
            try:
                meta = dist.metadata
                name = meta.get("Name")
                if not name:
                    continue

                key = re.sub(r"[^A-Za-z0-9.]+", "-", name).lower()
                if key in python_packages:
                    continue

                location = dist.locate_file("")
                python_packages[key] = cls(
                    name=name,
                    version=dist.version,
                    path=Path(str(location)) if location else None,
                    summary=meta.get("Summary"),
                    author=meta.get("Author") or meta.get("Author-email"),
                    license=meta.get("License") or meta.get("License-Expression"),
                    requires=list(dist.requires or []),
                )
            except Exception as e:
                # Best effort: one broken distribution must not abort the
                # whole environment snapshot.
                log.warning(f"Failed to read metadata for a Python package: {e}")

        return python_packages
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
class EnvironmentGPUConfig(C.Config):
    """Per-device GPU properties captured from the CUDA runtime."""

    name: str | None
    """Name of the GPU."""

    total_memory: int | None
    """Total memory of the GPU in bytes."""

    major: int | None
    """Major version of CUDA capability."""

    minor: int | None
    """Minor version of CUDA capability."""

    multi_processor_count: int | None
    """Number of multiprocessors on the GPU."""

    @classmethod
    def empty(cls):
        """Build an instance in which every field is ``None``."""
        unknown = dict.fromkeys(
            ("name", "total_memory", "major", "minor", "multi_processor_count")
        )
        return cls(**unknown)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
class EnvironmentCUDAConfig(C.Config):
    """CUDA / cuDNN availability and versions reported by PyTorch."""

    is_available: bool | None
    """Whether CUDA is available."""

    version: str | None
    """CUDA version."""

    cudnn_version: int | None
    """cuDNN version."""

    @classmethod
    def empty(cls):
        """Build an instance in which every field is ``None``."""
        unknown = dict.fromkeys(("is_available", "version", "cudnn_version"))
        return cls(**unknown)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
class EnvironmentHardwareConfig(C.Config):
    """Configuration for hardware information."""

    cpu_count_physical: int | None
    """Number of physical CPU cores."""

    cpu_count_logical: int | None
    """Number of logical CPU cores."""

    cpu_frequency_current: float | None
    """Current CPU frequency in MHz."""

    cpu_frequency_min: float | None
    """Minimum CPU frequency in MHz."""

    cpu_frequency_max: float | None
    """Maximum CPU frequency in MHz."""

    ram_total: int | None
    """Total RAM in bytes."""

    ram_available: int | None
    """Available RAM in bytes."""

    disk_total: int | None
    """Total disk space in bytes."""

    disk_used: int | None
    """Used disk space in bytes."""

    disk_free: int | None
    """Free disk space in bytes."""

    gpu_count: int | None
    """Number of GPUs available (``None`` when CUDA is unavailable)."""

    gpus: list[EnvironmentGPUConfig] | None
    """List of GPU configurations."""

    cuda: EnvironmentCUDAConfig | None
    """CUDA environment configuration."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            cpu_count_physical=None,
            cpu_count_logical=None,
            cpu_frequency_current=None,
            cpu_frequency_min=None,
            cpu_frequency_max=None,
            ram_total=None,
            ram_available=None,
            disk_total=None,
            disk_used=None,
            disk_free=None,
            gpu_count=None,
            gpus=None,
            cuda=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Snapshot CPU, RAM, root-filesystem disk, GPU and CUDA information.

        Every field is assigned explicitly, with ``None`` when the value is
        unknown. The fields declare no defaults, so an unassigned field could
        be rejected by ``finalize()``. The CPU frequency is unknown when
        ``psutil.cpu_freq()`` returns ``None`` or raises on this platform. The
        GPU fields are ``None`` when PyTorch reports CUDA as unavailable.
        """
        draft = cls.draft()

        # CPU information
        draft.cpu_count_physical = psutil.cpu_count(logical=False)
        draft.cpu_count_logical = psutil.cpu_count(logical=True)
        try:
            cpu_freq = psutil.cpu_freq()
        except (NotImplementedError, OSError):
            cpu_freq = None
        draft.cpu_frequency_current = cpu_freq.current if cpu_freq else None
        draft.cpu_frequency_min = cpu_freq.min if cpu_freq else None
        draft.cpu_frequency_max = cpu_freq.max if cpu_freq else None

        # RAM information
        ram = psutil.virtual_memory()
        draft.ram_total = ram.total
        draft.ram_available = ram.available

        # Disk information (root filesystem)
        disk = psutil.disk_usage("/")
        draft.disk_total = disk.total
        draft.disk_used = disk.used
        draft.disk_free = disk.free

        # GPU and CUDA information
        cuda_available = torch.cuda.is_available()
        draft.cuda = EnvironmentCUDAConfig(
            is_available=cuda_available,
            version=cast(Any, torch).version.cuda,
            cudnn_version=torch.backends.cudnn.version()
            if torch.backends.cudnn.is_available()
            else None,
        )

        if cuda_available:
            gpu_count = torch.cuda.device_count()
            gpus = []
            for index in range(gpu_count):
                props = torch.cuda.get_device_properties(index)
                gpus.append(
                    EnvironmentGPUConfig(
                        name=props.name,
                        total_memory=props.total_memory,
                        major=props.major,
                        minor=props.minor,
                        multi_processor_count=props.multi_processor_count,
                    )
                )
            draft.gpu_count = gpu_count
            draft.gpus = gpus
        else:
            draft.gpu_count = None
            draft.gpus = None

        return draft.finalize()
|
|
577
|
+
|
|
578
|
+
|
|
579
|
+
class GitRepositoryConfig(C.Config):
    """Configuration for Git repository information."""

    is_git_repo: bool | None
    """Whether the current directory is a Git repository."""

    branch: str | None
    """The current Git branch (``None`` for a detached HEAD)."""

    commit_hash: str | None
    """The current commit hash."""

    commit_message: str | None
    """The current commit message."""

    author: str | None
    """The author of the current commit."""

    commit_date: str | None
    """The date of the current commit."""

    remote_url: str | None
    """The URL of the remote repository (``origin`` if present, else the first remote)."""

    is_dirty: bool | None
    """Whether there are uncommitted changes."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            is_git_repo=None,
            branch=None,
            commit_hash=None,
            commit_message=None,
            author=None,
            commit_date=None,
            remote_url=None,
            is_dirty=None,
        )

    @classmethod
    def from_current_directory(cls):
        """Collect information about the Git repository containing the CWD.

        ``is_git_repo`` is ``True`` inside a repository and ``False`` when the
        working directory is not inside one. It is ``None`` when the lookup
        failed for another reason, which is logged as a warning. Every other
        field is ``None`` when unknown.

        The repository handle is closed before returning so that GitPython's
        helper processes do not leak. A detached HEAD, common in CI checkouts,
        yields ``branch=None`` instead of aborting.
        """
        draft = cls.draft()
        # Start from an all-unknown state; the fields declare no defaults.
        draft.is_git_repo = None
        draft.branch = None
        draft.commit_hash = None
        draft.commit_message = None
        draft.author = None
        draft.commit_date = None
        draft.remote_url = None
        draft.is_dirty = None

        try:
            with git.Repo(os.getcwd(), search_parent_directories=True) as repo:
                draft.is_git_repo = True
                draft.branch = (
                    None if repo.head.is_detached else repo.active_branch.name
                )

                commit = repo.head.commit
                draft.commit_hash = commit.hexsha

                # The commit message may be str or bytes.
                message = commit.message
                if isinstance(message, bytes):
                    message = message.decode("utf-8", errors="replace")
                draft.commit_message = str(message).strip()

                draft.author = f"{commit.author.name} <{commit.author.email}>"
                draft.commit_date = commit.committed_datetime.isoformat()

                if repo.remotes:
                    remote = next(
                        (r for r in repo.remotes if r.name == "origin"),
                        repo.remotes[0],
                    )
                    draft.remote_url = remote.url

                draft.is_dirty = repo.is_dirty()
        except git.InvalidGitRepositoryError:
            draft.is_git_repo = False
        except Exception as e:
            log.warning(f"Failed to get Git repository information: {e}")
            draft.is_git_repo = None

        return draft.finalize()
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
class EnvironmentConfig(C.Config):
    """Top-level snapshot of the runtime environment of a training run."""

    cwd: Path | None
    """The current working directory."""

    snapshot: EnvironmentSnapshotConfig | None
    """The environment snapshot configuration."""

    python_executable: Path | None
    """The path to the Python executable."""

    python_path: list[Path] | None
    """The Python path."""

    python_version: str | None
    """The Python version."""

    python_packages: dict[str, EnvironmentPackageConfig] | None
    """A mapping of package names to their configurations."""

    config: EnvironmentClassInformationConfig | None
    """The configuration class information."""

    model: EnvironmentClassInformationConfig | None
    """The Lightning module class information."""

    linux: EnvironmentLinuxEnvironmentConfig | None
    """The Linux environment information."""

    hardware: EnvironmentHardwareConfig | None
    """Hardware configuration information."""

    slurm: EnvironmentSLURMInformationConfig | None
    """The SLURM environment information."""

    lsf: EnvironmentLSFInformationConfig | None
    """The LSF environment information."""

    base_dir: Path | None
    """The base directory for the run."""

    log_dir: Path | None
    """The directory for logs."""

    checkpoint_dir: Path | None
    """The directory for checkpoints."""

    stdio_dir: Path | None
    """The directory for standard input/output files."""

    seed: int | None
    """The global random seed."""

    seed_workers: bool | None
    """Whether to seed workers."""

    git: GitRepositoryConfig | None
    """Git repository information."""

    @classmethod
    def empty(cls):
        """Build an instance in which every field is ``None``."""
        return cls(
            cwd=None,
            snapshot=None,
            python_executable=None,
            python_path=None,
            python_version=None,
            python_packages=None,
            config=None,
            model=None,
            linux=None,
            hardware=None,
            slurm=None,
            lsf=None,
            base_dir=None,
            log_dir=None,
            checkpoint_dir=None,
            stdio_dir=None,
            seed=None,
            seed_workers=None,
            git=None,
        )

    @classmethod
    def from_current_environment(
        cls,
        root_config: "BaseConfig",
        model: "LightningModuleBase",
    ):
        """Capture the full environment for a run.

        Args:
            root_config: The run's root config. It provides the run id and the
                directory resolver for the base, log, checkpoint and stdio
                directories.
            model: The LightningModule being trained; only its class is recorded.

        The seed fields are read from ``PL_GLOBAL_SEED`` and
        ``PL_SEED_WORKERS``. Each is ``None`` when its variable is unset.
        """
        draft = cls.draft()

        # Interpreter / process.
        draft.cwd = Path(os.getcwd())
        draft.python_executable = Path(sys.executable)
        draft.python_path = list(map(Path, sys.path))
        draft.python_version = sys.version
        draft.python_packages = EnvironmentPackageConfig.from_current_environment()

        # Classes involved in the run.
        draft.config = EnvironmentClassInformationConfig.from_instance(root_config)
        draft.model = EnvironmentClassInformationConfig.from_instance(model)

        # Machine and scheduler.
        draft.linux = EnvironmentLinuxEnvironmentConfig.from_current_environment()
        draft.hardware = EnvironmentHardwareConfig.from_current_environment()
        draft.slurm = EnvironmentSLURMInformationConfig.from_current_environment()
        draft.lsf = EnvironmentLSFInformationConfig.from_current_environment()

        # Run directories.
        run_id = root_config.id
        draft.base_dir = root_config.directory.resolve_run_root_directory(run_id)
        draft.log_dir = root_config.directory.resolve_subdirectory(run_id, "log")
        draft.checkpoint_dir = root_config.directory.resolve_subdirectory(
            run_id, "checkpoint"
        )
        draft.stdio_dir = root_config.directory.resolve_subdirectory(run_id, "stdio")

        # Seeding, as exported by Lightning's seed_everything.
        seed_text = os.environ.get("PL_GLOBAL_SEED")
        if seed_text:
            draft.seed = int(seed_text)
        else:
            draft.seed = None

        workers_text = os.environ.get("PL_SEED_WORKERS")
        if workers_text:
            draft.seed_workers = bool(int(workers_text))
        else:
            draft.seed_workers = None

        # External tooling.
        draft.snapshot = EnvironmentSnapshotConfig.from_current_environment()
        draft.git = GitRepositoryConfig.from_current_directory()

        return draft.finalize()
|