xax 0.3.1__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. {xax-0.3.1/xax.egg-info → xax-0.3.2}/PKG-INFO +1 -1
  2. {xax-0.3.1 → xax-0.3.2}/xax/__init__.py +1 -1
  3. {xax-0.3.1 → xax-0.3.2}/xax/core/state.py +1 -1
  4. {xax-0.3.1 → xax-0.3.2}/xax/task/base.py +3 -0
  5. {xax-0.3.1 → xax-0.3.2}/xax/task/launchers/single_process.py +2 -1
  6. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/artifacts.py +10 -1
  7. {xax-0.3.1 → xax-0.3.2}/xax/utils/logging.py +3 -1
  8. {xax-0.3.1 → xax-0.3.2/xax.egg-info}/PKG-INFO +1 -1
  9. {xax-0.3.1 → xax-0.3.2}/LICENSE +0 -0
  10. {xax-0.3.1 → xax-0.3.2}/MANIFEST.in +0 -0
  11. {xax-0.3.1 → xax-0.3.2}/README.md +0 -0
  12. {xax-0.3.1 → xax-0.3.2}/pyproject.toml +0 -0
  13. {xax-0.3.1 → xax-0.3.2}/setup.cfg +0 -0
  14. {xax-0.3.1 → xax-0.3.2}/setup.py +0 -0
  15. {xax-0.3.1 → xax-0.3.2}/xax/cli/__init__.py +0 -0
  16. {xax-0.3.1 → xax-0.3.2}/xax/cli/edit_config.py +0 -0
  17. {xax-0.3.1 → xax-0.3.2}/xax/core/__init__.py +0 -0
  18. {xax-0.3.1 → xax-0.3.2}/xax/core/conf.py +0 -0
  19. {xax-0.3.1 → xax-0.3.2}/xax/nn/__init__.py +0 -0
  20. {xax-0.3.1 → xax-0.3.2}/xax/nn/attention.py +0 -0
  21. {xax-0.3.1 → xax-0.3.2}/xax/nn/embeddings.py +0 -0
  22. {xax-0.3.1 → xax-0.3.2}/xax/nn/functions.py +0 -0
  23. {xax-0.3.1 → xax-0.3.2}/xax/nn/geom.py +0 -0
  24. {xax-0.3.1 → xax-0.3.2}/xax/nn/losses.py +0 -0
  25. {xax-0.3.1 → xax-0.3.2}/xax/nn/metrics.py +0 -0
  26. {xax-0.3.1 → xax-0.3.2}/xax/nn/parallel.py +0 -0
  27. {xax-0.3.1 → xax-0.3.2}/xax/nn/ssm.py +0 -0
  28. {xax-0.3.1 → xax-0.3.2}/xax/py.typed +0 -0
  29. {xax-0.3.1 → xax-0.3.2}/xax/requirements-dev.txt +0 -0
  30. {xax-0.3.1 → xax-0.3.2}/xax/requirements.txt +0 -0
  31. {xax-0.3.1 → xax-0.3.2}/xax/task/__init__.py +0 -0
  32. {xax-0.3.1 → xax-0.3.2}/xax/task/launchers/__init__.py +0 -0
  33. {xax-0.3.1 → xax-0.3.2}/xax/task/launchers/base.py +0 -0
  34. {xax-0.3.1 → xax-0.3.2}/xax/task/launchers/cli.py +0 -0
  35. {xax-0.3.1 → xax-0.3.2}/xax/task/logger.py +0 -0
  36. {xax-0.3.1 → xax-0.3.2}/xax/task/loggers/__init__.py +0 -0
  37. {xax-0.3.1 → xax-0.3.2}/xax/task/loggers/callback.py +0 -0
  38. {xax-0.3.1 → xax-0.3.2}/xax/task/loggers/json.py +0 -0
  39. {xax-0.3.1 → xax-0.3.2}/xax/task/loggers/state.py +0 -0
  40. {xax-0.3.1 → xax-0.3.2}/xax/task/loggers/stdout.py +0 -0
  41. {xax-0.3.1 → xax-0.3.2}/xax/task/loggers/tensorboard.py +0 -0
  42. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/__init__.py +0 -0
  43. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/checkpointing.py +0 -0
  44. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/compile.py +0 -0
  45. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/cpu_stats.py +0 -0
  46. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/data_loader.py +0 -0
  47. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/gpu_stats.py +0 -0
  48. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/process.py +0 -0
  50. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/runnable.py +0 -0
  51. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/step_wrapper.py +0 -0
  52. {xax-0.3.1 → xax-0.3.2}/xax/task/mixins/train.py +0 -0
  53. {xax-0.3.1 → xax-0.3.2}/xax/task/script.py +0 -0
  54. {xax-0.3.1 → xax-0.3.2}/xax/task/task.py +0 -0
  55. {xax-0.3.1 → xax-0.3.2}/xax/utils/__init__.py +0 -0
  56. {xax-0.3.1 → xax-0.3.2}/xax/utils/data/__init__.py +0 -0
  57. {xax-0.3.1 → xax-0.3.2}/xax/utils/data/collate.py +0 -0
  58. {xax-0.3.1 → xax-0.3.2}/xax/utils/debugging.py +0 -0
  59. {xax-0.3.1 → xax-0.3.2}/xax/utils/experiments.py +0 -0
  60. {xax-0.3.1 → xax-0.3.2}/xax/utils/jax.py +0 -0
  61. {xax-0.3.1 → xax-0.3.2}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.3.1 → xax-0.3.2}/xax/utils/numpy.py +0 -0
  63. {xax-0.3.1 → xax-0.3.2}/xax/utils/profile.py +0 -0
  64. {xax-0.3.1 → xax-0.3.2}/xax/utils/pytree.py +0 -0
  65. {xax-0.3.1 → xax-0.3.2}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.3.1 → xax-0.3.2}/xax/utils/text.py +0 -0
  67. {xax-0.3.1 → xax-0.3.2}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.3.1 → xax-0.3.2}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.3.1 → xax-0.3.2}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.3.1 → xax-0.3.2}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.3.1 → xax-0.3.2}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.3.1 → xax-0.3.2}/xax.egg-info/entry_points.txt +0 -0
  73. {xax-0.3.1 → xax-0.3.2}/xax.egg-info/requires.txt +0 -0
  74. {xax-0.3.1 → xax-0.3.2}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.1"
15
+ __version__ = "0.3.2"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -107,7 +107,7 @@ class State:
107
107
  @classmethod
108
108
  def from_dict(cls, **d: Unpack[StateDict]) -> "State":
109
109
  if "phase" in d:
110
- d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
110
+ d["_phase"] = _phase_to_int(d.pop("phase"))
111
111
 
112
112
  int32_arr = jnp.array(
113
113
  [
@@ -82,6 +82,9 @@ class BaseTask(Generic[Config]):
82
82
  def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
83
83
  return state
84
84
 
85
+ def add_logger_handlers(self, logger: logging.Logger) -> None:
86
+ pass
87
+
85
88
  @functools.cached_property
86
89
  def task_class_name(self) -> str:
87
90
  return self.__class__.__name__
@@ -15,8 +15,9 @@ def run_single_process_training(
15
15
  *cfgs: RawConfigType,
16
16
  use_cli: bool | list[str] = True,
17
17
  ) -> None:
18
- configure_logging()
18
+ logger = configure_logging()
19
19
  task_obj = task.get_task(*cfgs, use_cli=use_cli)
20
+ task_obj.add_logger_handlers(logger)
20
21
  task_obj.run()
21
22
 
22
23
 
@@ -14,7 +14,7 @@ from xax.core.state import State
14
14
  from xax.nn.parallel import is_master
15
15
  from xax.task.base import BaseConfig, BaseTask
16
16
  from xax.utils.experiments import stage_environment
17
- from xax.utils.logging import LOG_STATUS
17
+ from xax.utils.logging import LOG_STATUS, RankFilter
18
18
  from xax.utils.text import show_info
19
19
 
20
20
  logger = logging.getLogger(__name__)
@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
24
24
  @dataclass
25
25
  class ArtifactsConfig(BaseConfig):
26
26
  exp_dir: str | None = field(None, help="The fixed experiment directory")
27
+ log_to_file: bool = field(True, help="If set, add a file handler to the logger to write all logs to the exp dir")
27
28
 
28
29
 
29
30
  Config = TypeVar("Config", bound=ArtifactsConfig)
@@ -39,6 +40,14 @@ class ArtifactsMixin(BaseTask[Config]):
39
40
  self._exp_dir = None
40
41
  self._stage_dir = None
41
42
 
43
+ def add_logger_handlers(self, logger: logging.Logger) -> None:
44
+ super().add_logger_handlers(logger)
45
+ if is_master() and self.config.log_to_file:
46
+ file_handler = logging.FileHandler(self.exp_dir / "logs.txt")
47
+ file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
48
+ file_handler.addFilter(RankFilter(rank=0))
49
+ logger.addHandler(file_handler)
50
+
42
51
  @functools.cached_property
43
52
  def run_dir(self) -> Path:
44
53
  run_dir = get_run_dir()
@@ -146,7 +146,7 @@ def configure_logging(
146
146
  rank: int | None = None,
147
147
  world_size: int | None = None,
148
148
  debug: bool | None = None,
149
- ) -> None:
149
+ ) -> logging.Logger:
150
150
  """Instantiates logging.
151
151
 
152
152
  This captures logs and reroutes them to the Toasts module, which is
@@ -186,6 +186,8 @@ def configure_logging(
186
186
  logging.getLogger("PIL").setLevel(logging.WARNING)
187
187
  logging.getLogger("torch").setLevel(logging.WARNING)
188
188
 
189
+ return root_logger
190
+
189
191
 
190
192
  def get_unused_port(default: int | None = None) -> int:
191
193
  """Returns an unused port number on the local machine.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes