xax 0.2.1__tar.gz → 0.2.2__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.1/xax.egg-info → xax-0.2.2}/PKG-INFO +1 -1
  2. {xax-0.2.1 → xax-0.2.2}/xax/__init__.py +1 -1
  3. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/cpu_stats.py +12 -9
  4. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/gpu_stats.py +14 -11
  5. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/process.py +14 -8
  6. {xax-0.2.1 → xax-0.2.2/xax.egg-info}/PKG-INFO +1 -1
  7. {xax-0.2.1 → xax-0.2.2}/LICENSE +0 -0
  8. {xax-0.2.1 → xax-0.2.2}/MANIFEST.in +0 -0
  9. {xax-0.2.1 → xax-0.2.2}/README.md +0 -0
  10. {xax-0.2.1 → xax-0.2.2}/pyproject.toml +0 -0
  11. {xax-0.2.1 → xax-0.2.2}/setup.cfg +0 -0
  12. {xax-0.2.1 → xax-0.2.2}/setup.py +0 -0
  13. {xax-0.2.1 → xax-0.2.2}/xax/core/__init__.py +0 -0
  14. {xax-0.2.1 → xax-0.2.2}/xax/core/conf.py +0 -0
  15. {xax-0.2.1 → xax-0.2.2}/xax/core/state.py +0 -0
  16. {xax-0.2.1 → xax-0.2.2}/xax/nn/__init__.py +0 -0
  17. {xax-0.2.1 → xax-0.2.2}/xax/nn/embeddings.py +0 -0
  18. {xax-0.2.1 → xax-0.2.2}/xax/nn/equinox.py +0 -0
  19. {xax-0.2.1 → xax-0.2.2}/xax/nn/export.py +0 -0
  20. {xax-0.2.1 → xax-0.2.2}/xax/nn/functions.py +0 -0
  21. {xax-0.2.1 → xax-0.2.2}/xax/nn/geom.py +0 -0
  22. {xax-0.2.1 → xax-0.2.2}/xax/nn/losses.py +0 -0
  23. {xax-0.2.1 → xax-0.2.2}/xax/nn/norm.py +0 -0
  24. {xax-0.2.1 → xax-0.2.2}/xax/nn/parallel.py +0 -0
  25. {xax-0.2.1 → xax-0.2.2}/xax/nn/ssm.py +0 -0
  26. {xax-0.2.1 → xax-0.2.2}/xax/py.typed +0 -0
  27. {xax-0.2.1 → xax-0.2.2}/xax/requirements-dev.txt +0 -0
  28. {xax-0.2.1 → xax-0.2.2}/xax/requirements.txt +0 -0
  29. {xax-0.2.1 → xax-0.2.2}/xax/task/__init__.py +0 -0
  30. {xax-0.2.1 → xax-0.2.2}/xax/task/base.py +0 -0
  31. {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/__init__.py +0 -0
  32. {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/base.py +0 -0
  33. {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/cli.py +0 -0
  34. {xax-0.2.1 → xax-0.2.2}/xax/task/launchers/single_process.py +0 -0
  35. {xax-0.2.1 → xax-0.2.2}/xax/task/logger.py +0 -0
  36. {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/__init__.py +0 -0
  37. {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/callback.py +0 -0
  38. {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/json.py +0 -0
  39. {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/state.py +0 -0
  40. {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/stdout.py +0 -0
  41. {xax-0.2.1 → xax-0.2.2}/xax/task/loggers/tensorboard.py +0 -0
  42. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/__init__.py +0 -0
  43. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/artifacts.py +0 -0
  44. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/checkpointing.py +0 -0
  45. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/compile.py +0 -0
  46. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/logger.py +0 -0
  48. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/runnable.py +0 -0
  49. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/step_wrapper.py +0 -0
  50. {xax-0.2.1 → xax-0.2.2}/xax/task/mixins/train.py +0 -0
  51. {xax-0.2.1 → xax-0.2.2}/xax/task/script.py +0 -0
  52. {xax-0.2.1 → xax-0.2.2}/xax/task/task.py +0 -0
  53. {xax-0.2.1 → xax-0.2.2}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.1 → xax-0.2.2}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.1 → xax-0.2.2}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.1 → xax-0.2.2}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.1 → xax-0.2.2}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.1 → xax-0.2.2}/xax/utils/jax.py +0 -0
  59. {xax-0.2.1 → xax-0.2.2}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.1 → xax-0.2.2}/xax/utils/logging.py +0 -0
  61. {xax-0.2.1 → xax-0.2.2}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.1 → xax-0.2.2}/xax/utils/profile.py +0 -0
  63. {xax-0.2.1 → xax-0.2.2}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.1 → xax-0.2.2}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.1 → xax-0.2.2}/xax/utils/text.py +0 -0
  66. {xax-0.2.1 → xax-0.2.2}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.1 → xax-0.2.2}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.1 → xax-0.2.2}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.1 → xax-0.2.2}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.1 → xax-0.2.2}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.1 → xax-0.2.2}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.1 → xax-0.2.2}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.2.1"
15
+ __version__ = "0.2.2"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -218,33 +218,36 @@ class CPUStatsMonitor:
218
218
  class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
219
219
  """Defines a task mixin for getting CPU statistics."""
220
220
 
221
- _cpu_stats_monitor: CPUStatsMonitor
221
+ _cpu_stats_monitor: CPUStatsMonitor | None
222
222
 
223
223
  def __init__(self, config: Config) -> None:
224
224
  super().__init__(config)
225
225
 
226
- self._cpu_stats_monitor = CPUStatsMonitor(
227
- ping_interval=self.config.cpu_stats.ping_interval,
228
- context=self._mp_ctx,
229
- manager=self._mp_manager,
230
- )
226
+ if (ctx := self.multiprocessing_context) is not None and (mgr := self.multiprocessing_manager) is not None:
227
+ self._cpu_stats_monitor = CPUStatsMonitor(self.config.cpu_stats.ping_interval, ctx, mgr)
228
+ else:
229
+ self._cpu_stats_monitor = None
231
230
 
232
231
  def on_training_start(self, state: State) -> State:
233
232
  state = super().on_training_start(state)
234
233
 
235
- self._cpu_stats_monitor.start()
234
+ if (monitor := self._cpu_stats_monitor) is not None:
235
+ monitor.start()
236
236
  return state
237
237
 
238
238
  def on_training_end(self, state: State) -> State:
239
239
  state = super().on_training_end(state)
240
240
 
241
- self._cpu_stats_monitor.stop()
241
+ if (monitor := self._cpu_stats_monitor) is not None:
242
+ monitor.stop()
242
243
  return state
243
244
 
244
245
  def on_step_start(self, state: State) -> State:
245
246
  state = super().on_step_start(state)
246
247
 
247
- monitor = self._cpu_stats_monitor
248
+ if (monitor := self._cpu_stats_monitor) is None:
249
+ return state
250
+
248
251
  stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
249
252
 
250
253
  if stats is not None:
@@ -234,24 +234,27 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
234
234
  def __init__(self, config: Config) -> None:
235
235
  super().__init__(config)
236
236
 
237
- self._gpu_stats_monitor = None
238
- if shutil.which("nvidia-smi") is not None:
239
- self._gpu_stats_monitor = GPUStatsMonitor(
240
- config.gpu_stats.ping_interval,
241
- self._mp_ctx,
242
- self._mp_manager,
243
- )
237
+ if (
238
+ shutil.which("nvidia-smi") is not None
239
+ and (ctx := self.multiprocessing_context) is not None
240
+ and (mgr := self.multiprocessing_manager) is not None
241
+ ):
242
+ self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, ctx, mgr)
243
+ else:
244
+ self._gpu_stats_monitor = None
244
245
 
245
246
  def on_training_start(self, state: State) -> State:
246
247
  state = super().on_training_start(state)
247
- if self._gpu_stats_monitor is not None:
248
- self._gpu_stats_monitor.start()
248
+
249
+ if (monitor := self._gpu_stats_monitor) is not None:
250
+ monitor.start()
249
251
  return state
250
252
 
251
253
  def on_training_end(self, state: State) -> State:
252
254
  state = super().on_training_end(state)
253
- if self._gpu_stats_monitor is not None:
254
- self._gpu_stats_monitor.stop()
255
+
256
+ if (monitor := self._gpu_stats_monitor) is not None:
257
+ monitor.stop()
255
258
  return state
256
259
 
257
260
  def on_step_start(self, state: State) -> State:
@@ -20,6 +20,7 @@ logger: logging.Logger = logging.getLogger(__name__)
20
20
  @dataclass
21
21
  class ProcessConfig(BaseConfig):
22
22
  multiprocessing_context: str | None = field("spawn", help="The multiprocessing context to use")
23
+ disable_multiprocessing: bool = field(False, help="If set, disable multiprocessing")
23
24
 
24
25
 
25
26
  Config = TypeVar("Config", bound=ProcessConfig)
@@ -28,27 +29,32 @@ Config = TypeVar("Config", bound=ProcessConfig)
28
29
  class ProcessMixin(BaseTask[Config], Generic[Config]):
29
30
  """Defines a base trainer mixin for handling monitoring processes."""
30
31
 
31
- _mp_ctx: BaseContext
32
- _mp_manager: SyncManager
32
+ _mp_ctx: BaseContext | None
33
+ _mp_manager: SyncManager | None
33
34
 
34
35
  def __init__(self, config: Config) -> None:
35
36
  super().__init__(config)
36
37
 
37
- self._mp_ctx = mp.get_context(config.multiprocessing_context)
38
- self._mp_manager = self._mp_ctx.Manager()
38
+ if self.config.disable_multiprocessing:
39
+ self._mp_ctx = None
40
+ self._mp_manager = None
41
+ else:
42
+ self._mp_ctx = mp.get_context(config.multiprocessing_context)
43
+ self._mp_manager = self._mp_ctx.Manager()
39
44
 
40
45
  @property
41
- def multiprocessing_context(self) -> BaseContext:
46
+ def multiprocessing_context(self) -> BaseContext | None:
42
47
  return self._mp_ctx
43
48
 
44
49
  @property
45
- def multiprocessing_manager(self) -> SyncManager:
50
+ def multiprocessing_manager(self) -> SyncManager | None:
46
51
  return self._mp_manager
47
52
 
48
53
  def on_training_end(self, state: State) -> State:
49
54
  state = super().on_training_end(state)
50
55
 
51
- self._mp_manager.shutdown()
52
- self._mp_manager.join()
56
+ if self._mp_manager is not None:
57
+ self._mp_manager.shutdown()
58
+ self._mp_manager.join()
53
59
 
54
60
  return state
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.2.1
3
+ Version: 0.2.2
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes