textpolicy 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. textpolicy/__init__.py +53 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +797 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.1.dist-info/METADATA +109 -0
  62. textpolicy-0.1.1.dist-info/RECORD +66 -0
  63. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/WHEEL +1 -1
  64. textpolicy-0.1.1.dist-info/entry_points.txt +2 -0
  65. textpolicy-0.0.1.dist-info/METADATA +0 -10
  66. textpolicy-0.0.1.dist-info/RECORD +0 -6
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.1.dist-info}/top_level.txt +0 -0
textpolicy/cli.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ Minimal CLI entry point for TextPolicy.
3
+
4
+ Design goals:
5
+ - Keep it tiny and dependency-free (argparse only)
6
+ - Provide a single high-signal command for students: `validate`
7
+ - Exit non-zero on failure for CI integration
8
+
9
+ This CLI is intentionally small; a config-driven runner can be added later.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import sys
17
+ from typing import Any, Dict
18
+
19
+ from .validate import validate_installation
20
+
21
+
22
def _cmd_validate(args: argparse.Namespace) -> int:
    """Handle the `validate` subcommand.

    Delegates to validate_installation(), passing ``verbose=False`` when
    ``--json`` was given; in that case the report is printed as indented JSON.

    Returns:
        0 when the report's "status" entry equals "ok", otherwise 1.
    """
    want_json = args.json
    report: Dict[str, Any] = validate_installation(verbose=not want_json)
    if want_json:
        rendered = json.dumps(report, indent=2)
        print(rendered)
    # A non-zero exit code lets CI/automation detect a failed validation
    if report.get("status") == "ok":
        return 0
    return 1
32
+
33
+
34
def build_parser() -> argparse.ArgumentParser:
    """Assemble the argparse parser for the `textpolicy` command.

    Exactly one subcommand, `validate`, is registered. It accepts an optional
    `--json` flag and is wired to `_cmd_validate` through the `func` default.
    The chosen subcommand name is stored on the namespace as `command`.
    """
    root = argparse.ArgumentParser(
        prog="textpolicy",
        description="TextPolicy command-line tools",
    )
    commands = root.add_subparsers(dest="command")

    validate_cmd = commands.add_parser(
        "validate",
        help="Validate installation and environment",
    )
    validate_cmd.set_defaults(func=_cmd_validate)
    validate_cmd.add_argument(
        "--json",
        action="store_true",
        help="Output machine-readable JSON report",
    )

    return root
48
+
49
+
50
def main(argv: list[str] | None = None) -> int:
    """Run the TextPolicy CLI and return its exit code.

    Args:
        argv: Argument list without the program name. ``None`` means
            ``sys.argv[1:]``; an empty list is treated as ``["validate"]``.

    Returns:
        The selected handler's return value, or 2 (after printing help) when
        the parsed namespace carries no ``func`` handler.
    """
    cli = build_parser()
    tokens = sys.argv[1:] if argv is None else argv
    # An empty command line is shorthand for `textpolicy validate`
    namespace = cli.parse_args(tokens or ["validate"])
    if hasattr(namespace, "func"):
        return namespace.func(namespace)
    cli.print_help()
    return 2
63
+
64
+
65
+ if __name__ == "__main__":
66
+ raise SystemExit(main())
67
+
@@ -0,0 +1,79 @@
1
+ # textpolicy/environment/__init__.py
2
+ """
3
+ Modular environment implementation for MLX-RL.
4
+ """
5
+
6
+ from .environment import (
7
+ # Base classes
8
+ Environment,
9
+ EnvironmentAdapter,
10
+
11
+ # Adapters
12
+ GymAdapter,
13
+
14
+ # Factory functions
15
+ create_environment,
16
+ register_environment,
17
+ list_registered_environments,
18
+ is_gymnasium_available,
19
+
20
+ # Registry
21
+ ENVIRONMENT_REGISTRY
22
+ )
23
+
24
+ from .vectorized import (
25
+ # Vectorized environment classes
26
+ VectorizedEnvironment,
27
+ VectorizedCollector,
28
+
29
+ # Factory function
30
+ make_vectorized_env,
31
+ )
32
+
33
+ # Task suite registry for text generation environments
34
+ from .task_suites import (
35
+ register_task_suite,
36
+ list_task_suites,
37
+ get_task_suite,
38
+ )
39
+ # Re-export text generation environments and helpers for public API access
40
+ from .text_generation import (
41
+ TextGenerationEnvironment,
42
+ TextGenerationEnv,
43
+ create_text_generation_test_env,
44
+ validate_learning_progress,
45
+ )
46
+
47
+ __all__ = [
48
+ # Base classes
49
+ "Environment",
50
+ "EnvironmentAdapter",
51
+
52
+ # Adapters
53
+ "GymAdapter",
54
+
55
+ # Vectorized environment classes
56
+ "VectorizedEnvironment",
57
+ "VectorizedCollector",
58
+
59
+ # Factory functions
60
+ "create_environment",
61
+ "register_environment",
62
+ "list_registered_environments",
63
+ "is_gymnasium_available",
64
+ "make_vectorized_env",
65
+
66
+ # Task suite registry
67
+ "register_task_suite",
68
+ "list_task_suites",
69
+ "get_task_suite",
70
+
71
+ # Registry
72
+ "ENVIRONMENT_REGISTRY",
73
+
74
+ # Text generation environments and helpers
75
+ "TextGenerationEnvironment",
76
+ "TextGenerationEnv",
77
+ "create_text_generation_test_env",
78
+ "validate_learning_progress",
79
+ ]
@@ -0,0 +1,110 @@
1
+ # textpolicy/environment/base.py
2
+ """
3
+ Base environment interface and protocols for MLX-RL.
4
+ """
5
+
6
+ from abc import ABC, abstractmethod
7
+ from typing import Any, Dict, Tuple, Protocol
8
+
9
+
10
class Environment(ABC):
    """
    Unified environment interface for all environment types.

    Subclasses must implement reset(), step(), observation_space and
    action_space. clone(), render() and close() ship with default
    implementations that subclasses may override.
    """

    @abstractmethod
    def reset(self) -> Tuple[Any, Dict[str, Any]]:
        """
        Reset environment and return initial observation.

        Returns:
            Tuple of (observation, info) following gymnasium API
        """
        pass

    @abstractmethod
    def step(self, action: Any) -> Dict[str, Any]:
        """
        Take action and return step result.

        Args:
            action: Action to take in the environment

        Returns:
            Dict with keys: observation, reward, terminated, truncated, info
        """
        pass

    @property
    @abstractmethod
    def observation_space(self) -> Any:
        """
        Observation space specification.

        Returns:
            Space object describing valid observations
        """
        pass

    @property
    @abstractmethod
    def action_space(self) -> Any:
        """
        Action space specification.

        Returns:
            Space object describing valid actions
        """
        pass

    def clone(self) -> 'Environment':
        """
        Create a clone of this environment for multiprocessing.

        The default implementation returns ``copy.deepcopy(self)``.
        Subclasses whose state cannot be deep-copied, or that prefer to
        rebuild themselves from their configuration, should override this.

        Returns:
            New instance of the same environment
        """
        import copy
        return copy.deepcopy(self)

    def render(self, mode: str = "human") -> Any:
        """
        Render the environment.

        The default implementation does nothing and returns None.

        Args:
            mode: Rendering mode (e.g., "human", "rgb_array")

        Returns:
            Rendered output (depends on mode)
        """
        pass

    def close(self):
        """Clean up environment resources (no-op by default)."""
        pass
94
+
95
+
96
class EnvironmentAdapter(Protocol):
    """
    Structural type describing an environment adapter.

    An adapter wraps an external environment API (gym, dm_env, etc.)
    and exposes it through the unified Environment interface.
    """

    def __init__(self, env_spec: Any, **kwargs):
        """Build the adapter from an environment specification and options."""
        ...

    def clone(self) -> Environment:
        """Return a clone of the adapter for use in multiprocessing."""
        ...
@@ -0,0 +1,46 @@
1
+ # textpolicy/environment/environment.py
2
+ """
3
+ Unified environment module that coordinates all environment components.
4
+
5
+ Environment system for MLX-RL.
6
+
7
+ This module provides a unified interface for different environment types
8
+ (gymnasium, dm_env, custom environments) with adapter patterns for compatibility.
9
+
10
+ Design principles:
11
+ 1. Unified Environment interface for all environment types
12
+ 2. Adapter pattern for external environment libraries
13
+ 3. Factory system for easy environment creation and registration
14
+ 4. Support for multiprocessing through cloning
15
+ 5. Extensible design for adding new environment types
16
+ """
17
+
18
+ # Import all components for unified API
19
+ from .base import Environment, EnvironmentAdapter
20
+ from .gym import GymAdapter
21
+ from .factory import (
22
+ create_environment,
23
+ register_environment,
24
+ list_registered_environments,
25
+ is_gymnasium_available,
26
+ ENVIRONMENT_REGISTRY
27
+ )
28
+
29
+ # Re-export everything for compatibility
30
+ __all__ = [
31
+ # Base classes
32
+ "Environment",
33
+ "EnvironmentAdapter",
34
+
35
+ # Adapters
36
+ "GymAdapter",
37
+
38
+ # Factory functions
39
+ "create_environment",
40
+ "register_environment",
41
+ "list_registered_environments",
42
+ "is_gymnasium_available",
43
+
44
+ # Registry
45
+ "ENVIRONMENT_REGISTRY"
46
+ ]
@@ -0,0 +1,103 @@
1
+ # textpolicy/environment/factory.py
2
+ """
3
+ Environment factory and registration system.
4
+ """
5
+
6
+ from typing import Dict, Callable, Union
7
+ from .base import Environment
8
+ from .gym import GymAdapter
9
+
10
+ # Environment registry for dynamic loading
11
+ ENVIRONMENT_REGISTRY: Dict[str, Callable] = {}
12
+
13
+
14
def register_environment(name: str, factory_func: Callable):
    """
    Add a factory to the environment registry.

    The name is lowercased before it is stored; an existing entry under the
    same lowercased name is replaced.

    Args:
        name: Environment name (will be lowercased)
        factory_func: Function that returns Environment instance
    """
    registry_key = name.lower()
    ENVIRONMENT_REGISTRY[registry_key] = factory_func
23
+
24
+
25
def create_environment(env_spec: Union[str, Environment], **kwargs) -> Environment:
    """
    Resolve an environment specification into an Environment.

    A string is lowercased and looked up in ENVIRONMENT_REGISTRY; when no
    factory is registered under that key, the original (non-lowercased)
    string is handed to GymAdapter. An Environment instance is returned
    unchanged.

    Args:
        env_spec: Either environment name string or Environment instance
        **kwargs: Additional arguments for environment creation

    Returns:
        Environment instance

    Raises:
        ValueError: If kwargs are given together with an Environment instance
        TypeError: If env_spec is neither a str nor an Environment

    Examples:
        # Create from string (uses GymAdapter)
        env = create_environment("CartPole-v1")

        # Create with custom parameters
        env = create_environment("LunarLander-v2", continuous=True)

        # Pass through existing environment
        env = create_environment(my_custom_env)
    """
    if isinstance(env_spec, str):
        try:
            factory = ENVIRONMENT_REGISTRY[env_spec.lower()]
        except KeyError:
            # Unregistered names fall back to the gymnasium adapter
            return GymAdapter(env_spec, **kwargs)
        return factory(**kwargs)

    if isinstance(env_spec, Environment):
        # Nothing to construct, so extra options would be silently lost
        if kwargs:
            raise ValueError("Cannot pass kwargs when env_spec is already an Environment instance")
        return env_spec

    raise TypeError(f"env_spec must be str or Environment, got {type(env_spec)}")
65
+
66
+
67
def list_registered_environments() -> list[str]:
    """
    Return the keys currently stored in the environment registry.

    Returns:
        List of registered environment names
    """
    return [*ENVIRONMENT_REGISTRY]
75
+
76
+
77
def is_gymnasium_available() -> bool:
    """
    Check if gymnasium is available for import.

    The check is a real import attempt: an ImportError raised while
    importing gymnasium is reported as "not available".

    Returns:
        True if gymnasium can be imported, False otherwise
    """
    try:
        # The import is the probe itself; keep it even though the name is unused.
        import gymnasium  # noqa: F401
    except ImportError:
        return False
    return True
88
+
89
+
90
+ # Register some common environments by default
91
def _register_defaults():
    """Register the built-in gymnasium aliases when gymnasium is available."""

    def _gym_factory(env_id):
        # Bind env_id per alias so each factory builds its own environment id
        def _make(**kwargs):
            return GymAdapter(env_id, **kwargs)
        return _make

    if not is_gymnasium_available():
        return

    default_aliases = (
        ("cartpole", "CartPole-v1"),
        ("cartpole-v1", "CartPole-v1"),
        ("lunarlander", "LunarLander-v2"),
        ("lunarlander-v2", "LunarLander-v2"),
    )
    for alias, env_id in default_aliases:
        register_environment(alias, _gym_factory(env_id))
100
+
101
+
102
+ # Register defaults on import
103
+ _register_defaults()
@@ -0,0 +1,106 @@
1
+ # textpolicy/environment/gym.py
2
+ """
3
+ Gymnasium environment adapter for MLX-RL.
4
+ """
5
+
6
+ from typing import Any, Dict, Tuple
7
+ import gymnasium as gym
8
+ from .base import Environment
9
+
10
+
11
class GymAdapter(Environment):
    """
    Adapter for gymnasium environments.

    Wraps an environment created with gym.make() behind the unified
    Environment interface. reset() tolerates environments that return a bare
    observation instead of an (observation, info) tuple; step() expects the
    five-value (obs, reward, terminated, truncated, info) result.
    """

    def __init__(self, env_name: str, **kwargs):
        """
        Initialize gymnasium environment.

        Args:
            env_name: Name of the gymnasium environment (e.g., "CartPole-v1")
            **kwargs: Additional arguments passed to gym.make()
        """
        # Name and kwargs are kept so clone() can rebuild the same environment
        self.env_name = env_name
        self.env_kwargs = kwargs
        self.env = gym.make(env_name, **kwargs)

    def clone(self) -> 'GymAdapter':
        """
        Create a new instance of the same environment.

        The clone is built from the stored name and kwargs through a fresh
        gym.make() call, so it does not share the current environment state.

        Returns:
            New GymAdapter instance with same configuration
        """
        return GymAdapter(self.env_name, **self.env_kwargs)

    def reset(self) -> Tuple[Any, Dict[str, Any]]:
        """
        Reset environment and return initial observation.

        A tuple result from the underlying reset() is returned as-is
        (treated as (observation, info)); any other result is treated as a
        bare observation and paired with an empty info dict.

        Returns:
            Tuple of (observation, info)
        """
        result = self.env.reset()
        if isinstance(result, tuple):
            return result  # Tuple result: already (obs, info)
        else:
            return result, {}  # Bare observation: supply an empty info dict

    def step(self, action: Any) -> Dict[str, Any]:
        """
        Take action and return step result in unified format.

        Args:
            action: Action to take (format depends on action space)

        Returns:
            Dictionary with observation, reward, terminated, truncated, info
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        return {
            "observation": obs,
            "reward": reward,
            "terminated": terminated,
            "truncated": truncated,
            "info": info
        }

    @property
    def observation_space(self):
        """Get observation space from underlying gymnasium environment."""
        return self.env.observation_space

    @property
    def action_space(self):
        """Get action space from underlying gymnasium environment."""
        return self.env.action_space

    def render(self, mode: str = "human") -> Any:
        """
        Render the environment.

        Args:
            mode: Accepted for interface compatibility only; it is not
                forwarded, and the underlying render() is called without
                arguments. NOTE(review): with gymnasium the render mode is
                presumably chosen via render_mode in the gym.make() kwargs;
                confirm.

        Returns:
            Whatever the underlying environment's render() returns
        """
        return self.env.render()

    def close(self):
        """Close the underlying gymnasium environment."""
        self.env.close()

    def __repr__(self) -> str:
        """String representation for debugging."""
        return f"GymAdapter(env_name='{self.env_name}', kwargs={self.env_kwargs})"
@@ -0,0 +1,51 @@
1
+ """
2
+ Task suite registry for TextGenerationEnvironment.
3
+
4
+ This module provides a minimal registry that maps suite names to loader
5
+ functions returning lists of TextGenerationTask instances. It enables
6
+ customizable evaluation suites without hardcoding them in the environment.
7
+
8
+ Notes:
9
+ - Keep this registry lightweight and dependency-free.
10
+ - Default suites are registered from text_generation.py to avoid import cycles.
11
+ - A file-backed loader (JSON/YAML) can be added later if needed.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from typing import Callable, Dict, List, Any, Optional
17
+
18
+
19
+ # Registry type: name -> callable that returns a List[TextGenerationTask]
20
+ _TASK_SUITE_REGISTRY: Dict[str, Callable[[], List[Any]]] = {}
21
+
22
+
23
def register_task_suite(name: str, loader: Callable[[], List[Any]]) -> None:
    """
    Store a task-suite loader in the registry under the given name.

    Args:
        name: Suite identifier (e.g., "basic", "challenging").
        loader: Callable that returns a list of TextGenerationTask instances.

    Raises:
        TypeError: If loader is not callable.
    """
    if callable(loader):
        _TASK_SUITE_REGISTRY[name] = loader
        return
    raise TypeError("loader must be callable and return a list of tasks")
34
+
35
+
36
def get_task_suite(name: str) -> Optional[List[Any]]:
    """
    Look up a task suite by name and invoke its loader.

    Returns:
        The loader's result, or None when no loader is stored under name.
    """
    loader = _TASK_SUITE_REGISTRY.get(name)
    return None if loader is None else loader()
46
+
47
+
48
def list_task_suites() -> List[str]:
    """Return the registered task suite names in sorted order."""
    return sorted(_TASK_SUITE_REGISTRY)
51
+