nshtrainer 0.10.4__tar.gz → 0.10.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/PKG-INFO +1 -1
  2. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/pyproject.toml +1 -1
  3. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/_environment.py +95 -95
  4. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/config.py +3 -0
  5. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/trainer/trainer.py +4 -0
  6. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/README.md +0 -0
  7. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/__init__.py +0 -0
  8. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/_checkpoint/loader.py +0 -0
  9. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/_checkpoint/metadata.py +0 -0
  10. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/_experimental/__init__.py +0 -0
  11. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/_experimental/flops/__init__.py +0 -0
  12. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/_experimental/flops/flop_counter.py +0 -0
  13. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/_experimental/flops/module_tracker.py +0 -0
  14. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/__init__.py +0 -0
  15. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/_throughput_monitor_callback.py +0 -0
  16. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/actsave.py +0 -0
  17. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/base.py +0 -0
  18. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/early_stopping.py +0 -0
  19. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/ema.py +0 -0
  20. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/finite_checks.py +0 -0
  21. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/gradient_skipping.py +0 -0
  22. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/interval.py +0 -0
  23. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/latest_epoch_checkpoint.py +0 -0
  24. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/log_epoch.py +0 -0
  25. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/model_checkpoint.py +0 -0
  26. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/norm_logging.py +0 -0
  27. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/on_exception_checkpoint.py +0 -0
  28. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/print_table.py +0 -0
  29. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/throughput_monitor.py +0 -0
  30. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/timer.py +0 -0
  31. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/callbacks/wandb_watch.py +0 -0
  32. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/data/__init__.py +0 -0
  33. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/data/balanced_batch_sampler.py +0 -0
  34. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/data/transform.py +0 -0
  35. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/__init__.py +0 -0
  36. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/_experimental.py +0 -0
  37. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/actsave.py +0 -0
  38. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/callbacks.py +0 -0
  39. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/config.py +0 -0
  40. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/data.py +0 -0
  41. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/log.py +0 -0
  42. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/lr_scheduler.py +0 -0
  43. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/model.py +0 -0
  44. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/nn.py +0 -0
  45. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/optimizer.py +0 -0
  46. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/runner.py +0 -0
  47. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/snapshot.py +0 -0
  48. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/snoop.py +0 -0
  49. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/trainer.py +0 -0
  50. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/typecheck.py +0 -0
  51. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/ll/util.py +0 -0
  52. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/lr_scheduler/__init__.py +0 -0
  53. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/lr_scheduler/_base.py +0 -0
  54. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/lr_scheduler/linear_warmup_cosine.py +0 -0
  55. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +0 -0
  56. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/metrics/__init__.py +0 -0
  57. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/metrics/_config.py +0 -0
  58. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/__init__.py +0 -0
  59. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/base.py +0 -0
  60. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/modules/callback.py +0 -0
  61. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/modules/debug.py +0 -0
  62. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/modules/distributed.py +0 -0
  63. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/modules/logger.py +0 -0
  64. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/modules/profiler.py +0 -0
  65. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/modules/rlp_sanity_checks.py +0 -0
  66. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/model/modules/shared_parameters.py +0 -0
  67. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/nn/__init__.py +0 -0
  68. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/nn/mlp.py +0 -0
  69. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/nn/module_dict.py +0 -0
  70. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/nn/module_list.py +0 -0
  71. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/nn/nonlinearity.py +0 -0
  72. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/optimizer.py +0 -0
  73. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/runner.py +0 -0
  74. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/scripts/check_env.py +0 -0
  75. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/scripts/find_packages.py +0 -0
  76. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/trainer/__init__.py +0 -0
  77. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/trainer/_runtime_callback.py +0 -0
  78. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/trainer/checkpoint_connector.py +0 -0
  79. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/trainer/signal_connector.py +0 -0
  80. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/util/environment.py +0 -0
  81. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/util/seed.py +0 -0
  82. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/util/slurm.py +0 -0
  83. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/util/typed.py +0 -0
  84. {nshtrainer-0.10.4 → nshtrainer-0.10.6}/src/nshtrainer/util/typing_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nshtrainer
3
- Version: 0.10.4
3
+ Version: 0.10.6
4
4
  Summary:
5
5
  Author: Nima Shoghi
6
6
  Author-email: nimashoghi@gmail.com
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "nshtrainer"
3
- version = "0.10.4"
3
+ version = "0.10.6"
4
4
  description = ""
5
5
  authors = ["Nima Shoghi <nimashoghi@gmail.com>"]
6
6
  readme = "README.md"
@@ -28,19 +28,19 @@ log = logging.getLogger(__name__)
28
28
  class EnvironmentClassInformationConfig(C.Config):
29
29
  """Configuration for class information in the environment."""
30
30
 
31
- name: str | None
31
+ name: str | None = None
32
32
  """The name of the class."""
33
33
 
34
- module: str | None
34
+ module: str | None = None
35
35
  """The module where the class is defined."""
36
36
 
37
- full_name: str | None
37
+ full_name: str | None = None
38
38
  """The fully qualified name of the class."""
39
39
 
40
- file_path: Path | None
40
+ file_path: Path | None = None
41
41
  """The file path where the class is defined."""
42
42
 
43
- source_file_path: Path | None
43
+ source_file_path: Path | None = None
44
44
  """The source file path of the class, if available."""
45
45
 
46
46
  @classmethod
@@ -77,37 +77,37 @@ class EnvironmentClassInformationConfig(C.Config):
77
77
  class EnvironmentSLURMInformationConfig(C.Config):
78
78
  """Configuration for SLURM environment information."""
79
79
 
80
- hostname: str | None
80
+ hostname: str | None = None
81
81
  """The hostname of the current node."""
82
82
 
83
- hostnames: list[str] | None
83
+ hostnames: list[str] | None = None
84
84
  """List of hostnames for all nodes in the job."""
85
85
 
86
- job_id: str | None
86
+ job_id: str | None = None
87
87
  """The SLURM job ID."""
88
88
 
89
- raw_job_id: str | None
89
+ raw_job_id: str | None = None
90
90
  """The raw SLURM job ID."""
91
91
 
92
- array_job_id: str | None
92
+ array_job_id: str | None = None
93
93
  """The SLURM array job ID, if applicable."""
94
94
 
95
- array_task_id: str | None
95
+ array_task_id: str | None = None
96
96
  """The SLURM array task ID, if applicable."""
97
97
 
98
- num_tasks: int | None
98
+ num_tasks: int | None = None
99
99
  """The number of tasks in the SLURM job."""
100
100
 
101
- num_nodes: int | None
101
+ num_nodes: int | None = None
102
102
  """The number of nodes in the SLURM job."""
103
103
 
104
- node: str | int | None
104
+ node: str | int | None = None
105
105
  """The node ID or name."""
106
106
 
107
- global_rank: int | None
107
+ global_rank: int | None = None
108
108
  """The global rank of the current process."""
109
109
 
110
- local_rank: int | None
110
+ local_rank: int | None = None
111
111
  """The local rank of the current process within its node."""
112
112
 
113
113
  @classmethod
@@ -174,34 +174,34 @@ class EnvironmentSLURMInformationConfig(C.Config):
174
174
  class EnvironmentLSFInformationConfig(C.Config):
175
175
  """Configuration for LSF environment information."""
176
176
 
177
- hostname: str | None
177
+ hostname: str | None = None
178
178
  """The hostname of the current node."""
179
179
 
180
- hostnames: list[str] | None
180
+ hostnames: list[str] | None = None
181
181
  """List of hostnames for all nodes in the job."""
182
182
 
183
- job_id: str | None
183
+ job_id: str | None = None
184
184
  """The LSF job ID."""
185
185
 
186
- array_job_id: str | None
186
+ array_job_id: str | None = None
187
187
  """The LSF array job ID, if applicable."""
188
188
 
189
- array_task_id: str | None
189
+ array_task_id: str | None = None
190
190
  """The LSF array task ID, if applicable."""
191
191
 
192
- num_tasks: int | None
192
+ num_tasks: int | None = None
193
193
  """The number of tasks in the LSF job."""
194
194
 
195
- num_nodes: int | None
195
+ num_nodes: int | None = None
196
196
  """The number of nodes in the LSF job."""
197
197
 
198
- node: str | int | None
198
+ node: str | int | None = None
199
199
  """The node ID or name."""
200
200
 
201
- global_rank: int | None
201
+ global_rank: int | None = None
202
202
  """The global rank of the current process."""
203
203
 
204
- local_rank: int | None
204
+ local_rank: int | None = None
205
205
  """The local rank of the current process within its node."""
206
206
 
207
207
  @classmethod
@@ -267,40 +267,40 @@ class EnvironmentLSFInformationConfig(C.Config):
267
267
  class EnvironmentLinuxEnvironmentConfig(C.Config):
268
268
  """Configuration for Linux environment information."""
269
269
 
270
- user: str | None
270
+ user: str | None = None
271
271
  """The current user."""
272
272
 
273
- hostname: str | None
273
+ hostname: str | None = None
274
274
  """The hostname of the machine."""
275
275
 
276
- system: str | None
276
+ system: str | None = None
277
277
  """The operating system name."""
278
278
 
279
- release: str | None
279
+ release: str | None = None
280
280
  """The operating system release."""
281
281
 
282
- version: str | None
282
+ version: str | None = None
283
283
  """The operating system version."""
284
284
 
285
- machine: str | None
285
+ machine: str | None = None
286
286
  """The machine type."""
287
287
 
288
- processor: str | None
288
+ processor: str | None = None
289
289
  """The processor type."""
290
290
 
291
- cpu_count: int | None
291
+ cpu_count: int | None = None
292
292
  """The number of CPUs."""
293
293
 
294
- memory: int | None
294
+ memory: int | None = None
295
295
  """The total system memory in bytes."""
296
296
 
297
- uptime: timedelta | None
297
+ uptime: timedelta | None = None
298
298
  """The system uptime."""
299
299
 
300
- boot_time: float | None
300
+ boot_time: float | None = None
301
301
  """The system boot time as a timestamp."""
302
302
 
303
- load_avg: tuple[float, float, float] | None
303
+ load_avg: tuple[float, float, float] | None = None
304
304
  """The system load average (1, 5, and 15 minutes)."""
305
305
 
306
306
  @classmethod
@@ -341,10 +341,10 @@ class EnvironmentLinuxEnvironmentConfig(C.Config):
341
341
  class EnvironmentSnapshotConfig(C.Config):
342
342
  """Configuration for environment snapshot information."""
343
343
 
344
- snapshot_dir: Path | None
344
+ snapshot_dir: Path | None = None
345
345
  """The directory where the snapshot is stored."""
346
346
 
347
- modules: list[str] | None
347
+ modules: list[str] | None = None
348
348
  """List of modules included in the snapshot."""
349
349
 
350
350
  @classmethod
@@ -364,25 +364,25 @@ class EnvironmentSnapshotConfig(C.Config):
364
364
  class EnvironmentPackageConfig(C.Config):
365
365
  """Configuration for Python package information."""
366
366
 
367
- name: str | None
367
+ name: str | None = None
368
368
  """The name of the package."""
369
369
 
370
- version: str | None
370
+ version: str | None = None
371
371
  """The version of the package."""
372
372
 
373
- path: Path | None
373
+ path: Path | None = None
374
374
  """The installation path of the package."""
375
375
 
376
- summary: str | None
376
+ summary: str | None = None
377
377
  """A brief summary of the package."""
378
378
 
379
- author: str | None
379
+ author: str | None = None
380
380
  """The author of the package."""
381
381
 
382
- license: str | None
382
+ license: str | None = None
383
383
  """The license of the package."""
384
384
 
385
- requires: list[str] | None
385
+ requires: list[str] | None = None
386
386
  """List of package dependencies."""
387
387
 
388
388
  @classmethod
@@ -423,19 +423,19 @@ class EnvironmentPackageConfig(C.Config):
423
423
  class EnvironmentGPUConfig(C.Config):
424
424
  """Configuration for individual GPU information."""
425
425
 
426
- name: str | None
426
+ name: str | None = None
427
427
  """Name of the GPU."""
428
428
 
429
- total_memory: int | None
429
+ total_memory: int | None = None
430
430
  """Total memory of the GPU in bytes."""
431
431
 
432
- major: int | None
432
+ major: int | None = None
433
433
  """Major version of CUDA capability."""
434
434
 
435
- minor: int | None
435
+ minor: int | None = None
436
436
  """Minor version of CUDA capability."""
437
437
 
438
- multi_processor_count: int | None
438
+ multi_processor_count: int | None = None
439
439
  """Number of multiprocessors on the GPU."""
440
440
 
441
441
  @classmethod
@@ -452,13 +452,13 @@ class EnvironmentGPUConfig(C.Config):
452
452
  class EnvironmentCUDAConfig(C.Config):
453
453
  """Configuration for CUDA environment information."""
454
454
 
455
- is_available: bool | None
455
+ is_available: bool | None = None
456
456
  """Whether CUDA is available."""
457
457
 
458
- version: str | None
458
+ version: str | None = None
459
459
  """CUDA version."""
460
460
 
461
- cudnn_version: int | None
461
+ cudnn_version: int | None = None
462
462
  """cuDNN version."""
463
463
 
464
464
  @classmethod
@@ -469,43 +469,43 @@ class EnvironmentCUDAConfig(C.Config):
469
469
  class EnvironmentHardwareConfig(C.Config):
470
470
  """Configuration for hardware information."""
471
471
 
472
- cpu_count_physical: int | None
472
+ cpu_count_physical: int | None = None
473
473
  """Number of physical CPU cores."""
474
474
 
475
- cpu_count_logical: int | None
475
+ cpu_count_logical: int | None = None
476
476
  """Number of logical CPU cores."""
477
477
 
478
- cpu_frequency_current: float | None
478
+ cpu_frequency_current: float | None = None
479
479
  """Current CPU frequency in MHz."""
480
480
 
481
- cpu_frequency_min: float | None
481
+ cpu_frequency_min: float | None = None
482
482
  """Minimum CPU frequency in MHz."""
483
483
 
484
- cpu_frequency_max: float | None
484
+ cpu_frequency_max: float | None = None
485
485
  """Maximum CPU frequency in MHz."""
486
486
 
487
- ram_total: int | None
487
+ ram_total: int | None = None
488
488
  """Total RAM in bytes."""
489
489
 
490
- ram_available: int | None
490
+ ram_available: int | None = None
491
491
  """Available RAM in bytes."""
492
492
 
493
- disk_total: int | None
493
+ disk_total: int | None = None
494
494
  """Total disk space in bytes."""
495
495
 
496
- disk_used: int | None
496
+ disk_used: int | None = None
497
497
  """Used disk space in bytes."""
498
498
 
499
- disk_free: int | None
499
+ disk_free: int | None = None
500
500
  """Free disk space in bytes."""
501
501
 
502
- gpu_count: int | None
502
+ gpu_count: int | None = None
503
503
  """Number of GPUs available."""
504
504
 
505
- gpus: list[EnvironmentGPUConfig] | None
505
+ gpus: list[EnvironmentGPUConfig] | None = None
506
506
  """List of GPU configurations."""
507
507
 
508
- cuda: EnvironmentCUDAConfig | None
508
+ cuda: EnvironmentCUDAConfig | None = None
509
509
  """CUDA environment configuration."""
510
510
 
511
511
  @classmethod
@@ -579,28 +579,28 @@ class EnvironmentHardwareConfig(C.Config):
579
579
  class GitRepositoryConfig(C.Config):
580
580
  """Configuration for Git repository information."""
581
581
 
582
- is_git_repo: bool | None
582
+ is_git_repo: bool | None = None
583
583
  """Whether the current directory is a Git repository."""
584
584
 
585
- branch: str | None
585
+ branch: str | None = None
586
586
  """The current Git branch."""
587
587
 
588
- commit_hash: str | None
588
+ commit_hash: str | None = None
589
589
  """The current commit hash."""
590
590
 
591
- commit_message: str | None
591
+ commit_message: str | None = None
592
592
  """The current commit message."""
593
593
 
594
- author: str | None
594
+ author: str | None = None
595
595
  """The author of the current commit."""
596
596
 
597
- commit_date: str | None
597
+ commit_date: str | None = None
598
598
  """The date of the current commit."""
599
599
 
600
- remote_url: str | None
600
+ remote_url: str | None = None
601
601
  """The URL of the remote repository."""
602
602
 
603
- is_dirty: bool | None
603
+ is_dirty: bool | None = None
604
604
  """Whether there are uncommitted changes."""
605
605
 
606
606
  @classmethod
@@ -653,61 +653,61 @@ class GitRepositoryConfig(C.Config):
653
653
  class EnvironmentConfig(C.Config):
654
654
  """Configuration for the overall environment."""
655
655
 
656
- cwd: Path | None
656
+ cwd: Path | None = None
657
657
  """The current working directory."""
658
658
 
659
- snapshot: EnvironmentSnapshotConfig | None
659
+ snapshot: EnvironmentSnapshotConfig | None = None
660
660
  """The environment snapshot configuration."""
661
661
 
662
- python_executable: Path | None
662
+ python_executable: Path | None = None
663
663
  """The path to the Python executable."""
664
664
 
665
- python_path: list[Path] | None
665
+ python_path: list[Path] | None = None
666
666
  """The Python path."""
667
667
 
668
- python_version: str | None
668
+ python_version: str | None = None
669
669
  """The Python version."""
670
670
 
671
- python_packages: dict[str, EnvironmentPackageConfig] | None
671
+ python_packages: dict[str, EnvironmentPackageConfig] | None = None
672
672
  """A mapping of package names to their configurations."""
673
673
 
674
- config: EnvironmentClassInformationConfig | None
674
+ config: EnvironmentClassInformationConfig | None = None
675
675
  """The configuration class information."""
676
676
 
677
- model: EnvironmentClassInformationConfig | None
677
+ model: EnvironmentClassInformationConfig | None = None
678
678
  """The Lightning module class information."""
679
679
 
680
- linux: EnvironmentLinuxEnvironmentConfig | None
680
+ linux: EnvironmentLinuxEnvironmentConfig | None = None
681
681
  """The Linux environment information."""
682
682
 
683
- hardware: EnvironmentHardwareConfig | None
683
+ hardware: EnvironmentHardwareConfig | None = None
684
684
  """Hardware configuration information."""
685
685
 
686
- slurm: EnvironmentSLURMInformationConfig | None
686
+ slurm: EnvironmentSLURMInformationConfig | None = None
687
687
  """The SLURM environment information."""
688
688
 
689
- lsf: EnvironmentLSFInformationConfig | None
689
+ lsf: EnvironmentLSFInformationConfig | None = None
690
690
  """The LSF environment information."""
691
691
 
692
- base_dir: Path | None
692
+ base_dir: Path | None = None
693
693
  """The base directory for the run."""
694
694
 
695
- log_dir: Path | None
695
+ log_dir: Path | None = None
696
696
  """The directory for logs."""
697
697
 
698
- checkpoint_dir: Path | None
698
+ checkpoint_dir: Path | None = None
699
699
  """The directory for checkpoints."""
700
700
 
701
- stdio_dir: Path | None
701
+ stdio_dir: Path | None = None
702
702
  """The directory for standard input/output files."""
703
703
 
704
- seed: int | None
704
+ seed: int | None = None
705
705
  """The global random seed."""
706
706
 
707
- seed_workers: bool | None
707
+ seed_workers: bool | None = None
708
708
  """Whether to seed workers."""
709
709
 
710
- git: GitRepositoryConfig | None
710
+ git: GitRepositoryConfig | None = None
711
711
  """Git repository information."""
712
712
 
713
713
  @classmethod
@@ -1121,6 +1121,9 @@ class SanityCheckingConfig(C.Config):
1121
1121
 
1122
1122
 
1123
1123
  class TrainerConfig(C.Config):
1124
+ ckpt_path: str | Path | None = None
1125
+ """Path to a checkpoint to load and resume training from."""
1126
+
1124
1127
  checkpoint_loading: CheckpointLoadingConfig | Literal["auto"] = "auto"
1125
1128
  """Checkpoint loading configuration options."""
1126
1129
 
@@ -304,6 +304,10 @@ class Trainer(LightningTrainer):
304
304
  log_dir = str(Path(log_dir).resolve())
305
305
  log.critical(f"LightningTrainer log directory: {self.log_dir}.")
306
306
 
307
+ # Set the checkpoint
308
+ if (ckpt_path := config.trainer.ckpt_path) is not None:
309
+ self.ckpt_path = str(Path(ckpt_path).resolve().absolute())
310
+
307
311
  def __runtime_tracker(self):
308
312
  return next(
309
313
  (
File without changes