nshtrainer 0.9.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. nshtrainer/__init__.py +2 -1
  2. nshtrainer/callbacks/__init__.py +17 -1
  3. nshtrainer/{actsave/_callback.py → callbacks/actsave.py} +68 -10
  4. nshtrainer/callbacks/base.py +7 -5
  5. nshtrainer/callbacks/ema.py +1 -1
  6. nshtrainer/callbacks/finite_checks.py +1 -1
  7. nshtrainer/callbacks/gradient_skipping.py +1 -1
  8. nshtrainer/callbacks/latest_epoch_checkpoint.py +50 -14
  9. nshtrainer/callbacks/model_checkpoint.py +187 -0
  10. nshtrainer/callbacks/norm_logging.py +1 -1
  11. nshtrainer/callbacks/on_exception_checkpoint.py +76 -22
  12. nshtrainer/callbacks/print_table.py +1 -1
  13. nshtrainer/callbacks/throughput_monitor.py +1 -1
  14. nshtrainer/callbacks/timer.py +1 -1
  15. nshtrainer/callbacks/wandb_watch.py +1 -1
  16. nshtrainer/ll/__init__.py +0 -1
  17. nshtrainer/ll/actsave.py +2 -1
  18. nshtrainer/metrics/__init__.py +1 -0
  19. nshtrainer/metrics/_config.py +37 -0
  20. nshtrainer/model/__init__.py +11 -11
  21. nshtrainer/model/_environment.py +777 -0
  22. nshtrainer/model/base.py +5 -114
  23. nshtrainer/model/config.py +49 -501
  24. nshtrainer/model/modules/logger.py +11 -6
  25. nshtrainer/runner.py +3 -6
  26. nshtrainer/trainer/_checkpoint_metadata.py +102 -0
  27. nshtrainer/trainer/_checkpoint_resolver.py +319 -0
  28. nshtrainer/trainer/_runtime_callback.py +120 -0
  29. nshtrainer/trainer/checkpoint_connector.py +63 -0
  30. nshtrainer/trainer/signal_connector.py +12 -9
  31. nshtrainer/trainer/trainer.py +111 -31
  32. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/METADATA +3 -1
  33. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/RECORD +34 -27
  34. nshtrainer/actsave/__init__.py +0 -3
  35. {nshtrainer-0.9.1.dist-info → nshtrainer-0.10.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,777 @@
1
+ import getpass
2
+ import inspect
3
+ import logging
4
+ import os
5
+ import platform
6
+ import socket
7
+ import sys
8
+ from datetime import timedelta
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING, Any, cast
11
+
12
+ import git
13
+ import nshconfig as C
14
+ import psutil
15
+ import torch
16
+ from typing_extensions import Self
17
+
18
+ from ..util.slurm import parse_slurm_node_list
19
+
20
+ if TYPE_CHECKING:
21
+ from .base import LightningModuleBase
22
+ from .config import BaseConfig
23
+
24
+
25
+ log = logging.getLogger(__name__)
26
+
27
+
28
class EnvironmentClassInformationConfig(C.Config):
    """Configuration for class information in the environment.

    Describes where a class is defined (name, module, and file locations).
    """

    name: str | None
    """The name of the class."""

    module: str | None
    """The module where the class is defined."""

    full_name: str | None
    """The fully qualified name of the class."""

    file_path: Path | None
    """The file path where the class is defined."""

    source_file_path: Path | None
    """The source file path of the class, if available."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            name=None,
            module=None,
            full_name=None,
            file_path=None,
            source_file_path=None,
        )

    @classmethod
    def from_class(cls, cls_: type):
        """Build the class information for ``cls_``.

        Args:
            cls_: The class to describe.

        Returns:
            A populated instance. ``file_path`` and ``source_file_path`` are
            ``None`` when ``inspect`` cannot locate a file for the class.
        """
        file_path: Path | None = None
        source_file_path: Path | None = None
        try:
            file_path = Path(inspect.getfile(cls_))
            if source_file := inspect.getsourcefile(cls_):
                source_file_path = Path(source_file)
        except (TypeError, OSError):
            # inspect raises TypeError for built-in classes and OSError for
            # classes defined in ``__main__`` without a backing file (REPL /
            # notebook). Both path fields are optional, so record whatever was
            # resolved instead of aborting the whole environment capture.
            pass

        return cls(
            name=cls_.__name__,
            module=cls_.__module__,
            full_name=f"{cls_.__module__}.{cls_.__qualname__}",
            file_path=file_path,
            source_file_path=source_file_path,
        )

    @classmethod
    def from_instance(cls, instance: object):
        """Build the class information for the type of ``instance``."""
        return cls.from_class(type(instance))
75
+
76
+
77
class EnvironmentSLURMInformationConfig(C.Config):
    """Configuration for SLURM environment information."""

    hostname: str | None
    """The hostname of the current node."""

    hostnames: list[str] | None
    """List of hostnames for all nodes in the job."""

    job_id: str | None
    """The SLURM job ID."""

    raw_job_id: str | None
    """The raw SLURM job ID."""

    array_job_id: str | None
    """The SLURM array job ID, if applicable."""

    array_task_id: str | None
    """The SLURM array task ID, if applicable."""

    num_tasks: int | None
    """The number of tasks in the SLURM job."""

    num_nodes: int | None
    """The number of nodes in the SLURM job."""

    node: str | int | None
    """The node ID or name."""

    global_rank: int | None
    """The global rank of the current process."""

    local_rank: int | None
    """The local rank of the current process within its node."""

    @classmethod
    def empty(cls):
        """Return an instance whose fields are all ``None``."""
        field_names = (
            "hostname",
            "hostnames",
            "job_id",
            "raw_job_id",
            "array_job_id",
            "array_task_id",
            "num_tasks",
            "num_nodes",
            "node",
            "global_rank",
            "local_rank",
        )
        return cls(**dict.fromkeys(field_names))

    @classmethod
    def from_current_environment(cls):
        """Read the SLURM job description from the process environment.

        Returns:
            A populated instance, or ``None`` when SLURM is not detected, when
            Lightning's SLURM plugin cannot be imported, or when a required
            ``SLURM_*`` variable is missing or malformed.
        """
        try:
            from lightning.fabric.plugins.environments.slurm import SLURMEnvironment

            if not SLURMEnvironment.detect():
                return None

            env = os.environ
            this_host = socket.gethostname()

            # Fall back to just this host when no node list is exported.
            raw_node_list = env.get("SLURM_JOB_NODELIST", "")
            if raw_node_list:
                all_hosts = parse_slurm_node_list(raw_node_list)
            else:
                all_hosts = [this_host]

            raw_job_id = env["SLURM_JOB_ID"]
            array_job_id = env.get("SLURM_ARRAY_JOB_ID")
            array_task_id = env.get("SLURM_ARRAY_TASK_ID")
            # Array jobs are reported as "<array job id>_<array task id>".
            if array_job_id and array_task_id:
                job_id = f"{array_job_id}_{array_task_id}"
            else:
                job_id = raw_job_id

            return cls(
                hostname=this_host,
                hostnames=all_hosts,
                job_id=job_id,
                raw_job_id=raw_job_id,
                array_job_id=array_job_id,
                array_task_id=array_task_id,
                num_tasks=int(env["SLURM_NTASKS"]),
                num_nodes=int(env["SLURM_JOB_NUM_NODES"]),
                node=env.get("SLURM_NODEID"),
                global_rank=int(env["SLURM_PROCID"]),
                local_rank=int(env["SLURM_LOCALID"]),
            )
        except (ImportError, RuntimeError, ValueError, KeyError):
            return None
172
+
173
+
174
class EnvironmentLSFInformationConfig(C.Config):
    """Configuration for LSF environment information."""

    hostname: str | None
    """The hostname of the current node."""

    hostnames: list[str] | None
    """List of hostnames for all nodes in the job."""

    job_id: str | None
    """The LSF job ID."""

    array_job_id: str | None
    """The LSF array job ID, if applicable."""

    array_task_id: str | None
    """The LSF array task ID, if applicable."""

    num_tasks: int | None
    """The number of tasks in the LSF job."""

    num_nodes: int | None
    """The number of nodes in the LSF job."""

    node: str | int | None
    """The node ID or name."""

    global_rank: int | None
    """The global rank of the current process."""

    local_rank: int | None
    """The local rank of the current process within its node."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            hostname=None,
            hostnames=None,
            job_id=None,
            array_job_id=None,
            array_task_id=None,
            num_tasks=None,
            num_nodes=None,
            node=None,
            global_rank=None,
            local_rank=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Read the LSF job description from the process environment.

        Returns:
            A populated instance, or ``None`` when ``LSB_JOBID`` is not set or
            a numeric variable cannot be parsed.
        """
        try:
            hostname = socket.gethostname()

            # Whitespace-separated host entries; fall back to this host alone
            # when the variable is unset or empty.
            lsb_hosts = os.environ.get("LSB_HOSTS", "").split()
            hostnames = lsb_hosts or [hostname]

            job_id = os.environ["LSB_JOBID"]
            # NOTE(review): both array fields read LSB_JOBINDEX (unchanged
            # from the previous behavior) -- confirm whether array_job_id
            # should come from a different variable.
            array_job_id = os.environ.get("LSB_JOBINDEX")
            array_task_id = os.environ.get("LSB_JOBINDEX")

            num_tasks = int(os.environ.get("LSB_DJOB_NUMPROC", 1))
            num_nodes = len(set(hostnames))

            # Position of this host in LSB_HOSTS. A hostname that is not
            # listed (e.g. short name vs. FQDN) yields None rather than a
            # ValueError that would discard every other field.
            node_id = lsb_hosts.index(hostname) if hostname in lsb_hosts else None

            # LSF doesn't have direct equivalents for global_rank and local_rank
            # You might need to calculate these based on your specific setup
            global_rank = int(os.environ.get("PMI_RANK", 0))
            local_rank = int(os.environ.get("LSB_RANK", 0))

            return cls(
                hostname=hostname,
                hostnames=hostnames,
                job_id=job_id,
                array_job_id=array_job_id,
                array_task_id=array_task_id,
                num_tasks=num_tasks,
                num_nodes=num_nodes,
                node=node_id,
                global_rank=global_rank,
                local_rank=local_rank,
            )
        except (ImportError, RuntimeError, ValueError, KeyError):
            return None
265
+
266
+
267
class EnvironmentLinuxEnvironmentConfig(C.Config):
    """Configuration for Linux environment information."""

    user: str | None
    """The current user."""

    hostname: str | None
    """The hostname of the machine."""

    system: str | None
    """The operating system name."""

    release: str | None
    """The operating system release."""

    version: str | None
    """The operating system version."""

    machine: str | None
    """The machine type."""

    processor: str | None
    """The processor type."""

    cpu_count: int | None
    """The number of CPUs."""

    memory: int | None
    """The total system memory in bytes."""

    uptime: timedelta | None
    """The system uptime."""

    boot_time: float | None
    """The system boot time as a timestamp."""

    load_avg: tuple[float, float, float] | None
    """The system load average (1, 5, and 15 minutes)."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            user=None,
            hostname=None,
            system=None,
            release=None,
            version=None,
            machine=None,
            processor=None,
            cpu_count=None,
            memory=None,
            uptime=None,
            boot_time=None,
            load_avg=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Capture user, platform, CPU, memory, uptime and load information.

        Returns:
            A populated instance describing the machine this process runs on.
        """
        # Local import: the block needs ``time`` only to derive the uptime.
        import time

        # Read the boot timestamp once so ``boot_time`` and ``uptime`` agree.
        boot_time = psutil.boot_time()
        # Uptime is the time elapsed since boot, not the boot timestamp
        # itself; clamp at zero in case the clock reads earlier than boot.
        uptime_seconds = max(0.0, time.time() - boot_time)

        return cls(
            user=getpass.getuser(),
            hostname=platform.node(),
            system=platform.system(),
            release=platform.release(),
            version=platform.version(),
            machine=platform.machine(),
            processor=platform.processor(),
            cpu_count=os.cpu_count(),
            memory=psutil.virtual_memory().total,
            uptime=timedelta(seconds=uptime_seconds),
            boot_time=boot_time,
            load_avg=os.getloadavg(),
        )
339
+
340
+
341
class EnvironmentSnapshotConfig(C.Config):
    """Configuration for environment snapshot information."""

    snapshot_dir: Path | None
    """The directory where the snapshot is stored."""

    modules: list[str] | None
    """List of modules included in the snapshot."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(snapshot_dir=None, modules=None)

    @classmethod
    def from_current_environment(cls):
        """Read snapshot information from the ``NSHRUNNER_SNAPSHOT_*`` variables.

        Returns:
            An instance whose fields are ``None`` for every variable that is
            unset or empty.
        """
        snapshot_dir = os.environ.get("NSHRUNNER_SNAPSHOT_DIR")
        modules = os.environ.get("NSHRUNNER_SNAPSHOT_MODULES")

        draft = cls.draft()
        # The fields declare no default, so always assign them explicitly
        # rather than leaving them unset on the draft when a variable is
        # absent. NOTE(review): whether finalize() tolerates unset fields is
        # not visible from this file.
        draft.snapshot_dir = Path(snapshot_dir) if snapshot_dir else None
        draft.modules = modules.split(",") if modules else None
        return draft.finalize()
362
+
363
+
364
class EnvironmentPackageConfig(C.Config):
    """Configuration for Python package information."""

    name: str | None
    """The name of the package."""

    version: str | None
    """The version of the package."""

    path: Path | None
    """The installation path of the package."""

    summary: str | None
    """A brief summary of the package."""

    author: str | None
    """The author of the package."""

    license: str | None
    """The license of the package."""

    requires: list[str] | None
    """List of package dependencies."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            name=None,
            version=None,
            path=None,
            summary=None,
            author=None,
            license=None,
            requires=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Describe every distribution in ``pkg_resources.working_set``.

        Returns:
            A mapping from each distribution's ``key`` to its configuration.
            The mapping is empty when ``pkg_resources`` cannot be imported.
        """
        # Add Python package information
        python_packages: dict[str, Self] = {}
        try:
            import pkg_resources
        except ImportError:
            log.warning("pkg_resources not available, skipping package information")
            return python_packages

        for package in pkg_resources.working_set:
            python_packages[package.key] = cls(
                name=package.project_name,
                version=package.version,
                path=Path(package.location) if package.location else None,
                # Distribution objects are not guaranteed to expose these
                # attributes; fall back to None instead of raising
                # AttributeError, which the ImportError handler never caught.
                summary=getattr(package, "summary", None),
                author=getattr(package, "author", None),
                license=getattr(package, "license", None),
                requires=[str(req) for req in package.requires()],
            )

        return python_packages
421
+
422
+
423
class EnvironmentGPUConfig(C.Config):
    """Configuration for individual GPU information."""

    name: str | None
    """Name of the GPU."""

    total_memory: int | None
    """Total memory of the GPU in bytes."""

    major: int | None
    """Major version of CUDA capability."""

    minor: int | None
    """Minor version of CUDA capability."""

    multi_processor_count: int | None
    """Number of multiprocessors on the GPU."""

    @classmethod
    def empty(cls):
        """Return an instance whose fields are all ``None``."""
        field_names = (
            "name",
            "total_memory",
            "major",
            "minor",
            "multi_processor_count",
        )
        # dict.fromkeys maps every listed field to None.
        return cls(**dict.fromkeys(field_names))
450
+
451
+
452
class EnvironmentCUDAConfig(C.Config):
    """Configuration for CUDA environment information."""

    is_available: bool | None
    """Whether CUDA is available."""

    version: str | None
    """CUDA version."""

    cudnn_version: int | None
    """cuDNN version."""

    @classmethod
    def empty(cls):
        """Return an instance whose fields are all ``None``."""
        unset = {
            "is_available": None,
            "version": None,
            "cudnn_version": None,
        }
        return cls(**unset)
467
+
468
+
469
class EnvironmentHardwareConfig(C.Config):
    """Configuration for hardware information."""

    cpu_count_physical: int | None
    """Number of physical CPU cores."""

    cpu_count_logical: int | None
    """Number of logical CPU cores."""

    cpu_frequency_current: float | None
    """Current CPU frequency in MHz."""

    cpu_frequency_min: float | None
    """Minimum CPU frequency in MHz."""

    cpu_frequency_max: float | None
    """Maximum CPU frequency in MHz."""

    ram_total: int | None
    """Total RAM in bytes."""

    ram_available: int | None
    """Available RAM in bytes."""

    disk_total: int | None
    """Total disk space in bytes."""

    disk_used: int | None
    """Used disk space in bytes."""

    disk_free: int | None
    """Free disk space in bytes."""

    gpu_count: int | None
    """Number of GPUs available."""

    gpus: list[EnvironmentGPUConfig] | None
    """List of GPU configurations."""

    cuda: EnvironmentCUDAConfig | None
    """CUDA environment configuration."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            cpu_count_physical=None,
            cpu_count_logical=None,
            cpu_frequency_current=None,
            cpu_frequency_min=None,
            cpu_frequency_max=None,
            ram_total=None,
            ram_available=None,
            disk_total=None,
            disk_used=None,
            disk_free=None,
            gpu_count=None,
            gpus=None,
            cuda=None,
        )

    @classmethod
    def from_current_environment(cls):
        """Capture CPU, RAM, disk, CUDA and GPU information for this machine.

        Every field is assigned explicitly (``None`` when the information is
        unavailable) because the fields declare no default.
        NOTE(review): whether ``finalize()`` tolerates unset draft fields is
        not visible from this file.

        Returns:
            The finalized hardware configuration.
        """
        draft = cls.draft()

        # CPU information
        draft.cpu_count_physical = psutil.cpu_count(logical=False)
        draft.cpu_count_logical = psutil.cpu_count(logical=True)
        cpu_freq = psutil.cpu_freq()
        if cpu_freq:
            draft.cpu_frequency_current = cpu_freq.current
            draft.cpu_frequency_min = cpu_freq.min
            draft.cpu_frequency_max = cpu_freq.max
        else:
            # No frequency reading was returned.
            draft.cpu_frequency_current = None
            draft.cpu_frequency_min = None
            draft.cpu_frequency_max = None

        # RAM information
        ram = psutil.virtual_memory()
        draft.ram_total = ram.total
        draft.ram_available = ram.available

        # Disk information (usage of the root filesystem)
        disk = psutil.disk_usage("/")
        draft.disk_total = disk.total
        draft.disk_used = disk.used
        draft.disk_free = disk.free

        # GPU and CUDA information
        cuda = EnvironmentCUDAConfig(
            is_available=torch.cuda.is_available(),
            version=cast(Any, torch).version.cuda,
            cudnn_version=torch.backends.cudnn.version()
            if torch.backends.cudnn.is_available()
            else None,
        )
        draft.cuda = cuda

        # Build the GPU list locally and assign it once, instead of mutating
        # an attribute of the draft in place.
        gpu_count: int | None = None
        gpus: list[EnvironmentGPUConfig] | None = None
        if cuda.is_available:
            gpu_count = torch.cuda.device_count()
            gpus = []
            for i in range(gpu_count):
                gpu_props = torch.cuda.get_device_properties(i)
                gpus.append(
                    EnvironmentGPUConfig(
                        name=gpu_props.name,
                        total_memory=gpu_props.total_memory,
                        major=gpu_props.major,
                        minor=gpu_props.minor,
                        multi_processor_count=gpu_props.multi_processor_count,
                    )
                )
        draft.gpu_count = gpu_count
        draft.gpus = gpus

        return draft.finalize()
577
+
578
+
579
class GitRepositoryConfig(C.Config):
    """Configuration for Git repository information."""

    is_git_repo: bool | None
    """Whether the current directory is a Git repository."""

    branch: str | None
    """The current Git branch."""

    commit_hash: str | None
    """The current commit hash."""

    commit_message: str | None
    """The current commit message."""

    author: str | None
    """The author of the current commit."""

    commit_date: str | None
    """The date of the current commit."""

    remote_url: str | None
    """The URL of the remote repository."""

    is_dirty: bool | None
    """Whether there are uncommitted changes."""

    @classmethod
    def empty(cls):
        """Return an instance with every field set to ``None``."""
        return cls(
            is_git_repo=None,
            branch=None,
            commit_hash=None,
            commit_message=None,
            author=None,
            commit_date=None,
            remote_url=None,
            is_dirty=None,
        )

    @classmethod
    def from_current_directory(cls):
        """Describe the Git repository containing the current working directory.

        Parent directories are searched for the repository root.

        Returns:
            The finalized configuration. ``is_git_repo`` is ``False`` when no
            repository is found and ``None`` when collecting the information
            failed for any other reason (a warning is logged in that case).
        """
        draft = cls.draft()

        # The fields declare no default, so give each one an explicit value
        # up front; the non-repository and error paths below would otherwise
        # leave most of them unassigned.
        draft.is_git_repo = None
        draft.branch = None
        draft.commit_hash = None
        draft.commit_message = None
        draft.author = None
        draft.commit_date = None
        draft.remote_url = None
        draft.is_dirty = None

        try:
            repo = git.Repo(os.getcwd(), search_parent_directories=True)
            draft.is_git_repo = True

            # ``active_branch`` raises on a detached HEAD (tag / commit
            # checkouts); keep the branch as None there so the commit
            # information below is still recorded.
            if not repo.head.is_detached:
                draft.branch = repo.active_branch.name

            commit = repo.head.commit
            draft.commit_hash = commit.hexsha

            # Handle both str and bytes for commit message
            message = commit.message
            if isinstance(message, bytes):
                message = message.decode("utf-8", errors="replace")
            elif not isinstance(message, str):
                message = str(message)
            draft.commit_message = message.strip()

            draft.author = f"{commit.author.name} <{commit.author.email}>"
            draft.commit_date = commit.committed_datetime.isoformat()

            if repo.remotes:
                # Prefer "origin"; fall back to the first remote when the
                # repository has remotes but none with that name.
                remote = next(
                    (r for r in repo.remotes if r.name == "origin"),
                    repo.remotes[0],
                )
                draft.remote_url = remote.url

            draft.is_dirty = repo.is_dirty()
        except git.InvalidGitRepositoryError:
            draft.is_git_repo = False
        except Exception as e:
            # Best effort: environment capture must not fail because of Git.
            log.warning(f"Failed to get Git repository information: {e}")
            draft.is_git_repo = None

        return draft.finalize()
651
+
652
+
653
class EnvironmentConfig(C.Config):
    """Configuration for the overall environment."""

    cwd: Path | None
    """The current working directory."""

    snapshot: EnvironmentSnapshotConfig | None
    """The environment snapshot configuration."""

    python_executable: Path | None
    """The path to the Python executable."""

    python_path: list[Path] | None
    """The Python path."""

    python_version: str | None
    """The Python version."""

    python_packages: dict[str, EnvironmentPackageConfig] | None
    """A mapping of package names to their configurations."""

    config: EnvironmentClassInformationConfig | None
    """The configuration class information."""

    model: EnvironmentClassInformationConfig | None
    """The Lightning module class information."""

    linux: EnvironmentLinuxEnvironmentConfig | None
    """The Linux environment information."""

    hardware: EnvironmentHardwareConfig | None
    """Hardware configuration information."""

    slurm: EnvironmentSLURMInformationConfig | None
    """The SLURM environment information."""

    lsf: EnvironmentLSFInformationConfig | None
    """The LSF environment information."""

    base_dir: Path | None
    """The base directory for the run."""

    log_dir: Path | None
    """The directory for logs."""

    checkpoint_dir: Path | None
    """The directory for checkpoints."""

    stdio_dir: Path | None
    """The directory for standard input/output files."""

    seed: int | None
    """The global random seed."""

    seed_workers: bool | None
    """Whether to seed workers."""

    git: GitRepositoryConfig | None
    """Git repository information."""

    @classmethod
    def empty(cls):
        """Return an instance whose fields are all ``None``."""
        field_names = (
            "cwd",
            "snapshot",
            "python_executable",
            "python_path",
            "python_version",
            "python_packages",
            "config",
            "model",
            "linux",
            "hardware",
            "slurm",
            "lsf",
            "base_dir",
            "log_dir",
            "checkpoint_dir",
            "stdio_dir",
            "seed",
            "seed_workers",
            "git",
        )
        return cls(**dict.fromkeys(field_names))

    @classmethod
    def from_current_environment(
        cls,
        root_config: "BaseConfig",
        model: "LightningModuleBase",
    ):
        """Capture the full environment for a run.

        Args:
            root_config: The root configuration; its class, ``id`` and
                ``directory`` are used to fill in class and path information.
            model: The Lightning module whose class is recorded.

        Returns:
            The finalized environment configuration.
        """
        draft = cls.draft()

        # Interpreter and process state.
        draft.cwd = Path(os.getcwd())
        draft.python_executable = Path(sys.executable)
        draft.python_path = list(map(Path, sys.path))
        draft.python_version = sys.version
        draft.python_packages = EnvironmentPackageConfig.from_current_environment()

        # Classes behind the config and the model.
        draft.config = EnvironmentClassInformationConfig.from_instance(root_config)
        draft.model = EnvironmentClassInformationConfig.from_instance(model)

        # Machine and scheduler information.
        draft.linux = EnvironmentLinuxEnvironmentConfig.from_current_environment()
        draft.hardware = EnvironmentHardwareConfig.from_current_environment()
        draft.slurm = EnvironmentSLURMInformationConfig.from_current_environment()
        draft.lsf = EnvironmentLSFInformationConfig.from_current_environment()

        # Run directories, all resolved from the run id.
        directory = root_config.directory
        run_id = root_config.id
        draft.base_dir = directory.resolve_run_root_directory(run_id)
        draft.log_dir = directory.resolve_subdirectory(run_id, "log")
        draft.checkpoint_dir = directory.resolve_subdirectory(run_id, "checkpoint")
        draft.stdio_dir = directory.resolve_subdirectory(run_id, "stdio")

        # Seeding state, read from the PL_* environment variables.
        global_seed = os.environ.get("PL_GLOBAL_SEED")
        draft.seed = int(global_seed) if global_seed else None

        seed_workers = os.environ.get("PL_SEED_WORKERS")
        draft.seed_workers = bool(int(seed_workers)) if seed_workers else None

        draft.snapshot = EnvironmentSnapshotConfig.from_current_environment()
        draft.git = GitRepositoryConfig.from_current_directory()
        return draft.finalize()