xax 0.2.1__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.1/xax.egg-info → xax-0.2.2}/PKG-INFO +1 -1
- {xax-0.2.1 → xax-0.2.2}/xax/__init__.py +1 -1
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/cpu_stats.py +12 -9
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/gpu_stats.py +14 -11
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/process.py +14 -8
- {xax-0.2.1 → xax-0.2.2/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.1 → xax-0.2.2}/LICENSE +0 -0
- {xax-0.2.1 → xax-0.2.2}/MANIFEST.in +0 -0
- {xax-0.2.1 → xax-0.2.2}/README.md +0 -0
- {xax-0.2.1 → xax-0.2.2}/pyproject.toml +0 -0
- {xax-0.2.1 → xax-0.2.2}/setup.cfg +0 -0
- {xax-0.2.1 → xax-0.2.2}/setup.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/core/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/core/conf.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/core/state.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/embeddings.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/equinox.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/export.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/functions.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/geom.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/losses.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/norm.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/parallel.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/nn/ssm.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/py.typed +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/requirements-dev.txt +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/requirements.txt +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/base.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/base.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/logger.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/json.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/state.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/train.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/script.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/task/task.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/data/collate.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/debugging.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/experiments.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/jax.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/logging.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/numpy.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/profile.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/pytree.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/text.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.1 → xax-0.2.2}/xax.egg-info/top_level.txt +0 -0
@@ -218,33 +218,36 @@ class CPUStatsMonitor:
|
|
218
218
|
class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
219
219
|
"""Defines a task mixin for getting CPU statistics."""
|
220
220
|
|
221
|
-
_cpu_stats_monitor: CPUStatsMonitor
|
221
|
+
_cpu_stats_monitor: CPUStatsMonitor | None
|
222
222
|
|
223
223
|
def __init__(self, config: Config) -> None:
|
224
224
|
super().__init__(config)
|
225
225
|
|
226
|
-
self.
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
)
|
226
|
+
if (ctx := self.multiprocessing_context) is not None and (mgr := self.multiprocessing_manager) is not None:
|
227
|
+
self._cpu_stats_monitor = CPUStatsMonitor(self.config.cpu_stats.ping_interval, ctx, mgr)
|
228
|
+
else:
|
229
|
+
self._cpu_stats_monitor = None
|
231
230
|
|
232
231
|
def on_training_start(self, state: State) -> State:
|
233
232
|
state = super().on_training_start(state)
|
234
233
|
|
235
|
-
self._cpu_stats_monitor
|
234
|
+
if (monitor := self._cpu_stats_monitor) is not None:
|
235
|
+
monitor.start()
|
236
236
|
return state
|
237
237
|
|
238
238
|
def on_training_end(self, state: State) -> State:
|
239
239
|
state = super().on_training_end(state)
|
240
240
|
|
241
|
-
self._cpu_stats_monitor
|
241
|
+
if (monitor := self._cpu_stats_monitor) is not None:
|
242
|
+
monitor.stop()
|
242
243
|
return state
|
243
244
|
|
244
245
|
def on_step_start(self, state: State) -> State:
|
245
246
|
state = super().on_step_start(state)
|
246
247
|
|
247
|
-
monitor
|
248
|
+
if (monitor := self._cpu_stats_monitor) is None:
|
249
|
+
return state
|
250
|
+
|
248
251
|
stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
|
249
252
|
|
250
253
|
if stats is not None:
|
@@ -234,24 +234,27 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
234
234
|
def __init__(self, config: Config) -> None:
|
235
235
|
super().__init__(config)
|
236
236
|
|
237
|
-
|
238
|
-
|
239
|
-
self.
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
237
|
+
if (
|
238
|
+
shutil.which("nvidia-smi") is not None
|
239
|
+
and (ctx := self.multiprocessing_context) is not None
|
240
|
+
and (mgr := self.multiprocessing_manager) is not None
|
241
|
+
):
|
242
|
+
self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, ctx, mgr)
|
243
|
+
else:
|
244
|
+
self._gpu_stats_monitor = None
|
244
245
|
|
245
246
|
def on_training_start(self, state: State) -> State:
|
246
247
|
state = super().on_training_start(state)
|
247
|
-
|
248
|
-
|
248
|
+
|
249
|
+
if (monitor := self._gpu_stats_monitor) is not None:
|
250
|
+
monitor.start()
|
249
251
|
return state
|
250
252
|
|
251
253
|
def on_training_end(self, state: State) -> State:
|
252
254
|
state = super().on_training_end(state)
|
253
|
-
|
254
|
-
|
255
|
+
|
256
|
+
if (monitor := self._gpu_stats_monitor) is not None:
|
257
|
+
monitor.stop()
|
255
258
|
return state
|
256
259
|
|
257
260
|
def on_step_start(self, state: State) -> State:
|
@@ -20,6 +20,7 @@ logger: logging.Logger = logging.getLogger(__name__)
|
|
20
20
|
@dataclass
|
21
21
|
class ProcessConfig(BaseConfig):
|
22
22
|
multiprocessing_context: str | None = field("spawn", help="The multiprocessing context to use")
|
23
|
+
disable_multiprocessing: bool = field(False, help="If set, disable multiprocessing")
|
23
24
|
|
24
25
|
|
25
26
|
Config = TypeVar("Config", bound=ProcessConfig)
|
@@ -28,27 +29,32 @@ Config = TypeVar("Config", bound=ProcessConfig)
|
|
28
29
|
class ProcessMixin(BaseTask[Config], Generic[Config]):
|
29
30
|
"""Defines a base trainer mixin for handling monitoring processes."""
|
30
31
|
|
31
|
-
_mp_ctx: BaseContext
|
32
|
-
_mp_manager: SyncManager
|
32
|
+
_mp_ctx: BaseContext | None
|
33
|
+
_mp_manager: SyncManager | None
|
33
34
|
|
34
35
|
def __init__(self, config: Config) -> None:
|
35
36
|
super().__init__(config)
|
36
37
|
|
37
|
-
self.
|
38
|
-
|
38
|
+
if self.config.disable_multiprocessing:
|
39
|
+
self._mp_ctx = None
|
40
|
+
self._mp_manager = None
|
41
|
+
else:
|
42
|
+
self._mp_ctx = mp.get_context(config.multiprocessing_context)
|
43
|
+
self._mp_manager = self._mp_ctx.Manager()
|
39
44
|
|
40
45
|
@property
|
41
|
-
def multiprocessing_context(self) -> BaseContext:
|
46
|
+
def multiprocessing_context(self) -> BaseContext | None:
|
42
47
|
return self._mp_ctx
|
43
48
|
|
44
49
|
@property
|
45
|
-
def multiprocessing_manager(self) -> SyncManager:
|
50
|
+
def multiprocessing_manager(self) -> SyncManager | None:
|
46
51
|
return self._mp_manager
|
47
52
|
|
48
53
|
def on_training_end(self, state: State) -> State:
|
49
54
|
state = super().on_training_end(state)
|
50
55
|
|
51
|
-
self._mp_manager
|
52
|
-
|
56
|
+
if self._mp_manager is not None:
|
57
|
+
self._mp_manager.shutdown()
|
58
|
+
self._mp_manager.join()
|
53
59
|
|
54
60
|
return state
|
{xax-0.2.1 → xax-0.2.2}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.2.1 → xax-0.2.2}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|