tensor-optix 1.2.2__tar.gz → 1.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/PKG-INFO +83 -2
  2. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/README.md +80 -0
  3. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/pyproject.toml +3 -2
  4. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/adapters/pytorch/torch_agent.py +14 -4
  5. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/tf_ppo.py +10 -0
  6. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/tf_ppo_continuous.py +6 -0
  7. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/torch_dqn.py +7 -0
  8. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/torch_ppo.py +13 -0
  9. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/torch_ppo_continuous.py +13 -0
  10. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/torch_sac.py +7 -0
  11. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/base_agent.py +9 -0
  12. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/loop_controller.py +5 -0
  13. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/policy_manager.py +3 -0
  14. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix.egg-info/PKG-INFO +83 -2
  15. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix.egg-info/requires.txt +2 -1
  16. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/setup.cfg +0 -0
  17. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/__init__.py +0 -0
  18. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/adapters/__init__.py +0 -0
  19. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/adapters/pytorch/__init__.py +0 -0
  20. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/adapters/pytorch/torch_evaluator.py +0 -0
  21. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/adapters/tensorflow/__init__.py +0 -0
  22. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/adapters/tensorflow/tf_agent.py +0 -0
  23. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/adapters/tensorflow/tf_evaluator.py +0 -0
  24. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/__init__.py +0 -0
  25. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/tf_dqn.py +0 -0
  26. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/algorithms/tf_sac.py +0 -0
  27. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/__init__.py +0 -0
  28. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/backoff_scheduler.py +0 -0
  29. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/base_evaluator.py +0 -0
  30. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/base_optimizer.py +0 -0
  31. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/base_pipeline.py +0 -0
  32. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/checkpoint_registry.py +0 -0
  33. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/diagnostic_controller.py +0 -0
  34. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/ensemble_agent.py +0 -0
  35. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/meta_controller.py +0 -0
  36. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/normalizers.py +0 -0
  37. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/regime_detector.py +0 -0
  38. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/replay_buffer.py +0 -0
  39. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/trajectory_buffer.py +0 -0
  40. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/core/types.py +0 -0
  41. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/exploration/__init__.py +0 -0
  42. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/exploration/rnd.py +0 -0
  43. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/optimizer.py +0 -0
  44. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/optimizers/__init__.py +0 -0
  45. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/optimizers/backoff_optimizer.py +0 -0
  46. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/optimizers/momentum_optimizer.py +0 -0
  47. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/optimizers/pbt_optimizer.py +0 -0
  48. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/optimizers/spsa_optimizer.py +0 -0
  49. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/orchestrator.py +0 -0
  50. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/pipeline/__init__.py +0 -0
  51. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/pipeline/batch_pipeline.py +0 -0
  52. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/pipeline/live_pipeline.py +0 -0
  53. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix/pipeline/vector_pipeline.py +0 -0
  54. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix.egg-info/SOURCES.txt +0 -0
  55. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix.egg-info/dependency_links.txt +0 -0
  56. {tensor_optix-1.2.2 → tensor_optix-1.2.4}/tensor_optix.egg-info/top_level.txt +0 -0
@@ -1,16 +1,17 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tensor-optix
3
- Version: 1.2.2
3
+ Version: 1.2.4
4
4
  Summary: Autonomous training loop for any sequential learning model — built-in PPO, DQN, and SAC for TensorFlow and PyTorch
5
5
  Author: sup3rus3r
6
6
  License-Expression: MIT
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  Requires-Dist: tensorflow>=2.18.0
10
- Requires-Dist: gymnasium>=1.0.0
10
+ Requires-Dist: gymnasium[box2d]>=1.0.0
11
11
  Requires-Dist: numpy>=1.24.0
12
12
  Requires-Dist: matplotlib>=3.7.0
13
13
  Requires-Dist: optuna>=3.0.0
14
+ Requires-Dist: swig>=4.4.1
14
15
  Provides-Extra: torch
15
16
  Requires-Dist: torch>=2.0.0; extra == "torch"
16
17
  Requires-Dist: torchvision; extra == "torch"
@@ -926,6 +927,86 @@ tensor_optix/
926
927
 
927
928
  ---
928
929
 
930
+ ## Common Pitfalls & Best Practices
931
+
932
+ ### Device management
933
+
934
+ Every built-in Torch agent (`TorchPPOAgent`, `TorchGaussianPPOAgent`, `TorchDQNAgent`, `TorchSACAgent`) accepts a `device` parameter and moves its networks there on construction. The default is `"auto"`, which selects CUDA if available.
935
+
936
+ ```python
937
+ agent = TorchPPOAgent(actor=actor, critic=critic, optimizer=opt,
938
+ hyperparams=hp, device="cuda") # or "cpu", "auto"
939
+ ```
940
+
941
+ The base `TorchAgent` adapter now also accepts `device="auto"` and applies it consistently in `act()` and `load_weights()`. If you subclass `TorchAgent` directly, pass `device` to `super().__init__()` — otherwise obs tensors and loaded checkpoints default to CPU even on a CUDA machine.
942
+
943
+ **Watch out:** constructing the optimizer _before_ calling `.to(device)` on the model is safe because optimizers hold references to parameter tensors, not copies. However, if `agent.load_weights()` restores weights to the wrong device and you create the optimizer _after_ that, parameters can end up split between CPU and GPU, which causes a silent slowdown rather than an error.
944
+
945
+ ### Ensemble memory on GPU
946
+
947
+ Spawning agents with `PolicyManager.spawn_variant()` or `agent_factory` mode creates new networks on the target device. Calling `prune()` removes agents from the ensemble and automatically calls `agent.teardown()` on each removed agent, which moves its networks to CPU and calls `torch.cuda.empty_cache()`.
948
+
949
+ If you remove agents from the ensemble by any other means (e.g., rebuilding `_ensemble` manually), call `teardown()` yourself:
950
+
951
+ ```python
952
+ removed = pm.prune(bottom_k=2) # teardown() is called automatically
953
+
954
+ # If removing manually:
955
+ agent.teardown()
956
+ ```
957
+
958
+ For long PBT-style runs with frequent spawning, monitor GPU memory with `torch.cuda.memory_allocated()`. If memory grows despite pruning, the likely cause is optimizer state — gradient moments accumulate per parameter. Re-creating the optimizer on each spawn (as the built-in `agent_factory` pattern does) avoids this.
959
+
960
+ ### On-policy vs. off-policy rollback
961
+
962
+ `rollback_on_degradation=True` is safe for PPO but harmful for DQN and SAC. Off-policy agents accumulate experience in a replay buffer across many policies. Rolling back weights without clearing the buffer means the restored policy immediately trains on transitions it never generated — corrupted Bellman targets drag it back down.
963
+
964
+ The framework handles this automatically: any agent where `is_on_policy` returns `False` skips the weight rollback even when `rollback_on_degradation=True`. If you write a custom off-policy agent, override the property:
965
+
966
+ ```python
967
+ @property
968
+ def is_on_policy(self) -> bool:
969
+ return False
970
+ ```
971
+
972
+ ### Wiring PolicyManager early stopping
973
+
974
+ `PolicyManager.as_callback()` returns a `PolicyManagerCallback` that stops training when the spawn budget is exhausted — but only if you wire the stop function:
975
+
976
+ ```python
977
+ pm_cb = pm.as_callback(agent, agent_factory=my_factory)
978
+ rl_opt = RLOptimizer(...)
979
+ pm_cb.set_stop_fn(rl_opt.stop) # required — without this, training runs the full budget
980
+ rl_opt.add_callback(pm_cb)
981
+ rl_opt.run()
982
+ ```
983
+
984
+ Without `set_stop_fn`, the callback prints the training report when the budget runs out but cannot halt the loop. Training continues until `max_episodes` is reached.
985
+
986
+ For the factory-mode PPO path (where `agent_factory` is passed to `RLOptimizer` and `pm_cb` is created inside the factory), wire the stop function inside the factory — `rl_opt` is already bound in the enclosing scope by the time the factory is called:
987
+
988
+ ```python
989
+ def agent_factory_full(params):
990
+ agent = make_agent(params)
991
+ pm_cb = pm.as_callback(agent, agent_factory=lambda: make_agent(params))
992
+ pm_cb.set_stop_fn(rl_opt.stop) # rl_opt is bound before run() calls this factory
993
+ rl_opt.add_callback(pm_cb)
994
+ return agent
995
+
996
+ rl_opt = RLOptimizer(agent_factory=agent_factory_full, ...)
997
+ rl_opt.run()
998
+ ```
999
+
1000
+ ### Checkpoint directory hygiene
1001
+
1002
+ Each run writes checkpoints to `checkpoint_dir`. If you reuse the same directory across restarts without clearing it, `CheckpointRegistry` will load stale snapshots from a previous run and roll back to them during training. Either pass a unique directory per run (include seed and timestamp) or call `shutil.rmtree(ckpt_dir, ignore_errors=True)` at the start of each run.
1003
+
1004
+ ### State dict key mismatches during weight averaging / spawning
1005
+
1006
+ `average_weights()` and `load_weights()` use PyTorch `state_dict` keys. If the architecture passed to a spawned agent shell differs from the one that was checkpointed (different layer names, sizes, or number of layers), `load_state_dict()` will raise a `RuntimeError` with a key mismatch message. The framework does not catch this — it is the user's responsibility to pass a compatible shell. The safest pattern is to use the same `agent_factory` for both the primary agent and all spawned variants.
1007
+
1008
+ ---
1009
+
929
1010
  ## Math & Science Reference
930
1011
 
931
1012
  ### SPSA Gradient Estimate (`SPSAOptimizer`)
@@ -884,6 +884,86 @@ tensor_optix/
884
884
 
885
885
  ---
886
886
 
887
+ ## Common Pitfalls & Best Practices
888
+
889
+ ### Device management
890
+
891
+ Every built-in Torch agent (`TorchPPOAgent`, `TorchGaussianPPOAgent`, `TorchDQNAgent`, `TorchSACAgent`) accepts a `device` parameter and moves its networks there on construction. The default is `"auto"`, which selects CUDA if available.
892
+
893
+ ```python
894
+ agent = TorchPPOAgent(actor=actor, critic=critic, optimizer=opt,
895
+ hyperparams=hp, device="cuda") # or "cpu", "auto"
896
+ ```
897
+
898
+ The base `TorchAgent` adapter now also accepts `device="auto"` and applies it consistently in `act()` and `load_weights()`. If you subclass `TorchAgent` directly, pass `device` to `super().__init__()` — otherwise obs tensors and loaded checkpoints default to CPU even on a CUDA machine.
899
+
900
+ **Watch out:** constructing the optimizer _before_ calling `.to(device)` on the model is safe because optimizers hold references to parameter tensors, not copies. However, if `agent.load_weights()` restores weights to the wrong device and you create the optimizer _after_ that, parameters can end up split between CPU and GPU, which causes a silent slowdown rather than an error.
901
+
902
+ ### Ensemble memory on GPU
903
+
904
+ Spawning agents with `PolicyManager.spawn_variant()` or `agent_factory` mode creates new networks on the target device. Calling `prune()` removes agents from the ensemble and automatically calls `agent.teardown()` on each removed agent, which moves its networks to CPU and calls `torch.cuda.empty_cache()`.
905
+
906
+ If you remove agents from the ensemble by any other means (e.g., rebuilding `_ensemble` manually), call `teardown()` yourself:
907
+
908
+ ```python
909
+ removed = pm.prune(bottom_k=2) # teardown() is called automatically
910
+
911
+ # If removing manually:
912
+ agent.teardown()
913
+ ```
914
+
915
+ For long PBT-style runs with frequent spawning, monitor GPU memory with `torch.cuda.memory_allocated()`. If memory grows despite pruning, the likely cause is optimizer state — gradient moments accumulate per parameter. Re-creating the optimizer on each spawn (as the built-in `agent_factory` pattern does) avoids this.
916
+
917
+ ### On-policy vs. off-policy rollback
918
+
919
+ `rollback_on_degradation=True` is safe for PPO but harmful for DQN and SAC. Off-policy agents accumulate experience in a replay buffer across many policies. Rolling back weights without clearing the buffer means the restored policy immediately trains on transitions it never generated — corrupted Bellman targets drag it back down.
920
+
921
+ The framework handles this automatically: any agent where `is_on_policy` returns `False` skips the weight rollback even when `rollback_on_degradation=True`. If you write a custom off-policy agent, override the property:
922
+
923
+ ```python
924
+ @property
925
+ def is_on_policy(self) -> bool:
926
+ return False
927
+ ```
928
+
929
+ ### Wiring PolicyManager early stopping
930
+
931
+ `PolicyManager.as_callback()` returns a `PolicyManagerCallback` that stops training when the spawn budget is exhausted — but only if you wire the stop function:
932
+
933
+ ```python
934
+ pm_cb = pm.as_callback(agent, agent_factory=my_factory)
935
+ rl_opt = RLOptimizer(...)
936
+ pm_cb.set_stop_fn(rl_opt.stop) # required — without this, training runs the full budget
937
+ rl_opt.add_callback(pm_cb)
938
+ rl_opt.run()
939
+ ```
940
+
941
+ Without `set_stop_fn`, the callback prints the training report when the budget runs out but cannot halt the loop. Training continues until `max_episodes` is reached.
942
+
943
+ For the factory-mode PPO path (where `agent_factory` is passed to `RLOptimizer` and `pm_cb` is created inside the factory), wire the stop function inside the factory — `rl_opt` is already bound in the enclosing scope by the time the factory is called:
944
+
945
+ ```python
946
+ def agent_factory_full(params):
947
+ agent = make_agent(params)
948
+ pm_cb = pm.as_callback(agent, agent_factory=lambda: make_agent(params))
949
+ pm_cb.set_stop_fn(rl_opt.stop) # rl_opt is bound before run() calls this factory
950
+ rl_opt.add_callback(pm_cb)
951
+ return agent
952
+
953
+ rl_opt = RLOptimizer(agent_factory=agent_factory_full, ...)
954
+ rl_opt.run()
955
+ ```
956
+
957
+ ### Checkpoint directory hygiene
958
+
959
+ Each run writes checkpoints to `checkpoint_dir`. If you reuse the same directory across restarts without clearing it, `CheckpointRegistry` will load stale snapshots from a previous run and roll back to them during training. Either pass a unique directory per run (include seed and timestamp) or call `shutil.rmtree(ckpt_dir, ignore_errors=True)` at the start of each run.
960
+
961
+ ### State dict key mismatches during weight averaging / spawning
962
+
963
+ `average_weights()` and `load_weights()` use PyTorch `state_dict` keys. If the architecture passed to a spawned agent shell differs from the one that was checkpointed (different layer names, sizes, or number of layers), `load_state_dict()` will raise a `RuntimeError` with a key mismatch message. The framework does not catch this — it is the user's responsibility to pass a compatible shell. The safest pattern is to use the same `agent_factory` for both the primary agent and all spawned variants.
964
+
965
+ ---
966
+
887
967
  ## Math & Science Reference
888
968
 
889
969
  ### SPSA Gradient Estimate (`SPSAOptimizer`)
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "tensor-optix"
7
- version = "1.2.2"
7
+ version = "1.2.4"
8
8
  description = "Autonomous training loop for any sequential learning model — built-in PPO, DQN, and SAC for TensorFlow and PyTorch"
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -15,10 +15,11 @@ authors = [
15
15
 
16
16
  dependencies = [
17
17
  "tensorflow>=2.18.0",
18
- "gymnasium>=1.0.0",
18
+ "gymnasium[box2d]>=1.0.0",
19
19
  "numpy>=1.24.0",
20
20
  "matplotlib>=3.7.0",
21
21
  "optuna>=3.0.0",
22
+ "swig>=4.4.1",
22
23
  ]
23
24
 
24
25
  [project.optional-dependencies]
@@ -31,13 +31,17 @@ class TorchAgent(BaseAgent):
31
31
  model=model,
32
32
  optimizer=torch.optim.Adam(model.parameters(), lr=3e-4),
33
33
  hyperparams=HyperparamSet(params={"learning_rate": 3e-4, "gamma": 0.99}, episode_id=0),
34
+ device="auto", # "cuda" if available, else "cpu"
34
35
  )
35
36
  """
36
37
 
37
- def __init__(self, model, optimizer, hyperparams: HyperparamSet, compute_loss_fn=None):
38
+ def __init__(self, model, optimizer, hyperparams: HyperparamSet, compute_loss_fn=None, device: str = "auto"):
38
39
  import torch
39
40
  self._torch = torch
40
- self.model = model
41
+ if device == "auto":
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+ self._device = torch.device(device)
44
+ self.model = model.to(self._device)
41
45
  self.optimizer = optimizer
42
46
  self._hyperparams = hyperparams.copy()
43
47
  self._compute_loss_fn = compute_loss_fn
@@ -48,7 +52,7 @@ class TorchAgent(BaseAgent):
48
52
  Override for continuous actions or custom sampling strategies.
49
53
  """
50
54
  import torch
51
- obs = torch.as_tensor(np.atleast_2d(observation), dtype=torch.float32)
55
+ obs = torch.as_tensor(np.atleast_2d(observation), dtype=torch.float32).to(self._device)
52
56
  with torch.no_grad():
53
57
  logits = self.model(obs)
54
58
  action = int(torch.argmax(logits, dim=-1).item())
@@ -140,5 +144,11 @@ class TorchAgent(BaseAgent):
140
144
 
141
145
  def load_weights(self, path: str) -> None:
142
146
  import torch
143
- state = torch.load(os.path.join(path, "model.pt"), map_location="cpu")
147
+ state = torch.load(os.path.join(path, "model.pt"), map_location=self._device)
144
148
  self.model.load_state_dict(state)
149
+
150
+ def teardown(self) -> None:
151
+ """Move model to CPU and free CUDA memory."""
152
+ import torch
153
+ self.model.cpu()
154
+ torch.cuda.empty_cache()
@@ -291,6 +291,16 @@ class TFPPOAgent(BaseAgent):
291
291
  # Internal helpers
292
292
  # ------------------------------------------------------------------
293
293
 
294
+ def reset_cache(self) -> None:
295
+ """
296
+ Discard all entries in the rollout cache without learning from them.
297
+ Called by LoopController after a val-pipeline window is collected, so
298
+ that val-rollout entries never bleed into the next training learn() call.
299
+ """
300
+ self._cache_obs.clear()
301
+ self._cache_log_probs.clear()
302
+ self._cache_values.clear()
303
+
294
304
  def _clear_cache(self, T: int) -> None:
295
305
  del self._cache_obs[:T]
296
306
  del self._cache_log_probs[:T]
@@ -319,6 +319,12 @@ class TFGaussianPPOAgent(BaseAgent):
319
319
  for v in module.trainable_variables:
320
320
  v.assign(v * (1.0 + noise_scale * tf.random.normal(v.shape)))
321
321
 
322
+ def reset_cache(self) -> None:
323
+ """Discard all rollout cache entries without learning from them."""
324
+ self._cache_obs.clear()
325
+ self._cache_log_probs.clear()
326
+ self._cache_values.clear()
327
+
322
328
  @staticmethod
323
329
  def _explained_variance(values: np.ndarray, returns: np.ndarray) -> float:
324
330
  var_returns = float(np.var(returns))
@@ -254,3 +254,10 @@ class TorchDQNAgent(BaseAgent):
254
254
  for param in self._q.parameters():
255
255
  param.mul_(1.0 + noise_scale * torch.randn_like(param))
256
256
  self._q_target.load_state_dict(self._q.state_dict())
257
+
258
+ def teardown(self) -> None:
259
+ """Move networks to CPU and free CUDA memory."""
260
+ import torch
261
+ self._q.cpu()
262
+ self._q_target.cpu()
263
+ torch.cuda.empty_cache()
@@ -275,6 +275,19 @@ class TorchPPOAgent(BaseAgent):
275
275
  for param in module.parameters():
276
276
  param.mul_(1.0 + noise_scale * torch.randn_like(param))
277
277
 
278
+ def reset_cache(self) -> None:
279
+ """Discard all rollout cache entries without learning from them."""
280
+ self._cache_obs.clear()
281
+ self._cache_log_probs.clear()
282
+ self._cache_values.clear()
283
+
284
+ def teardown(self) -> None:
285
+ """Move networks to CPU and free CUDA memory."""
286
+ import torch
287
+ self._actor.cpu()
288
+ self._critic.cpu()
289
+ torch.cuda.empty_cache()
290
+
278
291
  @staticmethod
279
292
  def _explained_variance(values: np.ndarray, returns: np.ndarray) -> float:
280
293
  var_returns = float(np.var(returns))
@@ -343,6 +343,19 @@ class TorchGaussianPPOAgent(BaseAgent):
343
343
  for param in module.parameters():
344
344
  param.mul_(1.0 + noise_scale * torch.randn_like(param))
345
345
 
346
+ def reset_cache(self) -> None:
347
+ """Discard all rollout cache entries without learning from them."""
348
+ self._cache_obs.clear()
349
+ self._cache_log_probs.clear()
350
+ self._cache_values.clear()
351
+
352
+ def teardown(self) -> None:
353
+ """Move networks to CPU and free CUDA memory."""
354
+ import torch
355
+ self._actor.cpu()
356
+ self._critic.cpu()
357
+ torch.cuda.empty_cache()
358
+
346
359
  @staticmethod
347
360
  def _explained_variance(values: np.ndarray, returns: np.ndarray) -> float:
348
361
  var_returns = float(np.var(returns))
@@ -265,6 +265,13 @@ class TorchSACAgent(BaseAgent):
265
265
  # Internal helpers
266
266
  # ------------------------------------------------------------------
267
267
 
268
+ def teardown(self) -> None:
269
+ """Move all networks to CPU and free CUDA memory."""
270
+ import torch
271
+ for module in (self._actor, self._c1, self._c2, self._c1_tgt, self._c2_tgt):
272
+ module.cpu()
273
+ torch.cuda.empty_cache()
274
+
268
275
  def _sample_action(self, obs):
269
276
  import torch
270
277
  out = self._actor(obs)
@@ -114,3 +114,12 @@ class BaseAgent(ABC):
114
114
  restores the best checkpoint — so perturbation is always relative
115
115
  to the best known weights, not the current (possibly degraded) ones.
116
116
  """
117
+
118
+ def teardown(self) -> None:
119
+ """
120
+ Release resources held by this agent (GPU memory, file handles, etc.).
121
+
122
+ Called by PolicyManager.prune() when an agent is removed from the
123
+ ensemble. Override in framework-specific subclasses to move networks
124
+ to CPU and free CUDA memory. Default: no-op.
125
+ """
@@ -227,6 +227,11 @@ class LoopController:
227
227
  val_episode = next(self._val_gen)
228
228
  val_episode.episode_id = episode_id
229
229
  val_metrics = self._evaluator.score_validation(val_episode)
230
+ # Val pipeline calls agent.act() to collect rollouts, which
231
+ # populates the on-policy rollout cache. That data must never
232
+ # be consumed by the next training learn() call. Clear it now.
233
+ if hasattr(self._agent, "reset_cache"):
234
+ self._agent.reset_cache()
230
235
  eval_metrics = self._evaluator.combine(train_metrics, val_metrics)
231
236
  eval_metrics.episode_id = episode_id
232
237
  logger.debug(
@@ -215,6 +215,9 @@ class PolicyManager:
215
215
  self._score_history = new_score_history
216
216
  self._prune_count += bottom_k
217
217
 
218
+ for agent in removed_agents:
219
+ agent.teardown()
220
+
218
221
  logger.info(
219
222
  "PolicyManager.prune: removed %d agent(s), ensemble size now %d",
220
223
  len(removed_agents),
@@ -1,16 +1,17 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tensor-optix
3
- Version: 1.2.2
3
+ Version: 1.2.4
4
4
  Summary: Autonomous training loop for any sequential learning model — built-in PPO, DQN, and SAC for TensorFlow and PyTorch
5
5
  Author: sup3rus3r
6
6
  License-Expression: MIT
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  Requires-Dist: tensorflow>=2.18.0
10
- Requires-Dist: gymnasium>=1.0.0
10
+ Requires-Dist: gymnasium[box2d]>=1.0.0
11
11
  Requires-Dist: numpy>=1.24.0
12
12
  Requires-Dist: matplotlib>=3.7.0
13
13
  Requires-Dist: optuna>=3.0.0
14
+ Requires-Dist: swig>=4.4.1
14
15
  Provides-Extra: torch
15
16
  Requires-Dist: torch>=2.0.0; extra == "torch"
16
17
  Requires-Dist: torchvision; extra == "torch"
@@ -926,6 +927,86 @@ tensor_optix/
926
927
 
927
928
  ---
928
929
 
930
+ ## Common Pitfalls & Best Practices
931
+
932
+ ### Device management
933
+
934
+ Every built-in Torch agent (`TorchPPOAgent`, `TorchGaussianPPOAgent`, `TorchDQNAgent`, `TorchSACAgent`) accepts a `device` parameter and moves its networks there on construction. The default is `"auto"`, which selects CUDA if available.
935
+
936
+ ```python
937
+ agent = TorchPPOAgent(actor=actor, critic=critic, optimizer=opt,
938
+ hyperparams=hp, device="cuda") # or "cpu", "auto"
939
+ ```
940
+
941
+ The base `TorchAgent` adapter now also accepts `device="auto"` and applies it consistently in `act()` and `load_weights()`. If you subclass `TorchAgent` directly, pass `device` to `super().__init__()` — otherwise obs tensors and loaded checkpoints default to CPU even on a CUDA machine.
942
+
943
+ **Watch out:** constructing the optimizer _before_ calling `.to(device)` on the model is safe because optimizers hold references to parameter tensors, not copies. However, if `agent.load_weights()` restores weights to the wrong device and you create the optimizer _after_ that, parameters can end up split between CPU and GPU, which causes a silent slowdown rather than an error.
944
+
945
+ ### Ensemble memory on GPU
946
+
947
+ Spawning agents with `PolicyManager.spawn_variant()` or `agent_factory` mode creates new networks on the target device. Calling `prune()` removes agents from the ensemble and automatically calls `agent.teardown()` on each removed agent, which moves its networks to CPU and calls `torch.cuda.empty_cache()`.
948
+
949
+ If you remove agents from the ensemble by any other means (e.g., rebuilding `_ensemble` manually), call `teardown()` yourself:
950
+
951
+ ```python
952
+ removed = pm.prune(bottom_k=2) # teardown() is called automatically
953
+
954
+ # If removing manually:
955
+ agent.teardown()
956
+ ```
957
+
958
+ For long PBT-style runs with frequent spawning, monitor GPU memory with `torch.cuda.memory_allocated()`. If memory grows despite pruning, the likely cause is optimizer state — gradient moments accumulate per parameter. Re-creating the optimizer on each spawn (as the built-in `agent_factory` pattern does) avoids this.
959
+
960
+ ### On-policy vs. off-policy rollback
961
+
962
+ `rollback_on_degradation=True` is safe for PPO but harmful for DQN and SAC. Off-policy agents accumulate experience in a replay buffer across many policies. Rolling back weights without clearing the buffer means the restored policy immediately trains on transitions it never generated — corrupted Bellman targets drag it back down.
963
+
964
+ The framework handles this automatically: any agent where `is_on_policy` returns `False` skips the weight rollback even when `rollback_on_degradation=True`. If you write a custom off-policy agent, override the property:
965
+
966
+ ```python
967
+ @property
968
+ def is_on_policy(self) -> bool:
969
+ return False
970
+ ```
971
+
972
+ ### Wiring PolicyManager early stopping
973
+
974
+ `PolicyManager.as_callback()` returns a `PolicyManagerCallback` that stops training when the spawn budget is exhausted — but only if you wire the stop function:
975
+
976
+ ```python
977
+ pm_cb = pm.as_callback(agent, agent_factory=my_factory)
978
+ rl_opt = RLOptimizer(...)
979
+ pm_cb.set_stop_fn(rl_opt.stop) # required — without this, training runs the full budget
980
+ rl_opt.add_callback(pm_cb)
981
+ rl_opt.run()
982
+ ```
983
+
984
+ Without `set_stop_fn`, the callback prints the training report when the budget runs out but cannot halt the loop. Training continues until `max_episodes` is reached.
985
+
986
+ For the factory-mode PPO path (where `agent_factory` is passed to `RLOptimizer` and `pm_cb` is created inside the factory), wire the stop function inside the factory — `rl_opt` is already bound in the enclosing scope by the time the factory is called:
987
+
988
+ ```python
989
+ def agent_factory_full(params):
990
+ agent = make_agent(params)
991
+ pm_cb = pm.as_callback(agent, agent_factory=lambda: make_agent(params))
992
+ pm_cb.set_stop_fn(rl_opt.stop) # rl_opt is bound before run() calls this factory
993
+ rl_opt.add_callback(pm_cb)
994
+ return agent
995
+
996
+ rl_opt = RLOptimizer(agent_factory=agent_factory_full, ...)
997
+ rl_opt.run()
998
+ ```
999
+
1000
+ ### Checkpoint directory hygiene
1001
+
1002
+ Each run writes checkpoints to `checkpoint_dir`. If you reuse the same directory across restarts without clearing it, `CheckpointRegistry` will load stale snapshots from a previous run and roll back to them during training. Either pass a unique directory per run (include seed and timestamp) or call `shutil.rmtree(ckpt_dir, ignore_errors=True)` at the start of each run.
1003
+
1004
+ ### State dict key mismatches during weight averaging / spawning
1005
+
1006
+ `average_weights()` and `load_weights()` use PyTorch `state_dict` keys. If the architecture passed to a spawned agent shell differs from the one that was checkpointed (different layer names, sizes, or number of layers), `load_state_dict()` will raise a `RuntimeError` with a key mismatch message. The framework does not catch this — it is the user's responsibility to pass a compatible shell. The safest pattern is to use the same `agent_factory` for both the primary agent and all spawned variants.
1007
+
1008
+ ---
1009
+
929
1010
  ## Math & Science Reference
930
1011
 
931
1012
  ### SPSA Gradient Estimate (`SPSAOptimizer`)
@@ -1,8 +1,9 @@
1
1
  tensorflow>=2.18.0
2
- gymnasium>=1.0.0
2
+ gymnasium[box2d]>=1.0.0
3
3
  numpy>=1.24.0
4
4
  matplotlib>=3.7.0
5
5
  optuna>=3.0.0
6
+ swig>=4.4.1
6
7
 
7
8
  [all]
8
9
  torch>=2.0.0
File without changes