textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- textpolicy/__init__.py +52 -0
- textpolicy/__main__.py +8 -0
- textpolicy/algorithms/__init__.py +54 -0
- textpolicy/algorithms/grpo.py +642 -0
- textpolicy/algorithms/gspo.py +582 -0
- textpolicy/buffer/__init__.py +23 -0
- textpolicy/buffer/buffer.py +244 -0
- textpolicy/buffer/episode.py +383 -0
- textpolicy/buffer/sampling.py +438 -0
- textpolicy/buffer/storage.py +255 -0
- textpolicy/cli.py +67 -0
- textpolicy/environment/__init__.py +79 -0
- textpolicy/environment/base.py +110 -0
- textpolicy/environment/environment.py +46 -0
- textpolicy/environment/factory.py +103 -0
- textpolicy/environment/gym.py +106 -0
- textpolicy/environment/task_suites.py +51 -0
- textpolicy/environment/text_generation.py +789 -0
- textpolicy/environment/vectorized.py +253 -0
- textpolicy/generation/__init__.py +62 -0
- textpolicy/generation/lora.py +411 -0
- textpolicy/generation/mlx_generation.py +557 -0
- textpolicy/generation/reload.py +253 -0
- textpolicy/rewards/__init__.py +137 -0
- textpolicy/rewards/adapters.py +387 -0
- textpolicy/rewards/basic.py +214 -0
- textpolicy/rewards/integrated_system.py +338 -0
- textpolicy/rewards/mlx_batch_processor.py +447 -0
- textpolicy/rewards/registry.py +293 -0
- textpolicy/rewards/rollout_rewards.py +410 -0
- textpolicy/rewards/verifiers.py +369 -0
- textpolicy/rollout/__init__.py +44 -0
- textpolicy/rollout/aggregator.py +145 -0
- textpolicy/rollout/base.py +108 -0
- textpolicy/rollout/rollout.py +142 -0
- textpolicy/rollout/runner.py +280 -0
- textpolicy/rollout/strategy.py +208 -0
- textpolicy/rollout/worker.py +194 -0
- textpolicy/training/__init__.py +14 -0
- textpolicy/training/metrics.py +242 -0
- textpolicy/training/rollout_manager.py +78 -0
- textpolicy/training/trainer.py +684 -0
- textpolicy/utils/__init__.py +40 -0
- textpolicy/utils/benchmarking.py +489 -0
- textpolicy/utils/data.py +60 -0
- textpolicy/utils/debug.py +170 -0
- textpolicy/utils/environment.py +349 -0
- textpolicy/utils/logging/__init__.py +22 -0
- textpolicy/utils/logging/base.py +48 -0
- textpolicy/utils/logging/console.py +61 -0
- textpolicy/utils/logging/factory.py +133 -0
- textpolicy/utils/logging/multi.py +83 -0
- textpolicy/utils/logging/tensorboard.py +65 -0
- textpolicy/utils/logging/wandb.py +72 -0
- textpolicy/utils/memory.py +118 -0
- textpolicy/utils/performance.py +464 -0
- textpolicy/utils/timing.py +171 -0
- textpolicy/validate.py +101 -0
- textpolicy/validation/__init__.py +13 -0
- textpolicy/validation/logprob_validation.py +315 -0
- textpolicy-0.1.0.dist-info/METADATA +99 -0
- textpolicy-0.1.0.dist-info/RECORD +66 -0
- textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
- textpolicy-0.0.1.dist-info/METADATA +0 -10
- textpolicy-0.0.1.dist-info/RECORD +0 -6
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
textpolicy/cli.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Minimal CLI entry point for TextPolicy.
|
|
3
|
+
|
|
4
|
+
Design goals:
|
|
5
|
+
- Keep it tiny and dependency-free (argparse only)
|
|
6
|
+
- Provide a single high-signal command for students: `validate`
|
|
7
|
+
- Exit non-zero on failure for CI integration
|
|
8
|
+
|
|
9
|
+
This CLI is intentionally small; a config-driven runner can be added later.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import json
|
|
16
|
+
import sys
|
|
17
|
+
from typing import Any, Dict
|
|
18
|
+
|
|
19
|
+
from .validate import validate_installation
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _cmd_validate(args: argparse.Namespace) -> int:
    """Validate the installation and report the outcome.

    Delegates to the programmatic ``validate_installation()`` so the CLI and
    the library share one code path. When ``--json`` is given the validator
    is called with ``verbose=False`` and the report is printed as indented
    JSON instead.

    Args:
        args: Parsed CLI arguments; only ``args.json`` is read.

    Returns:
        0 when the report's ``"status"`` entry equals ``"ok"``, otherwise 1.
    """
    want_json = args.json
    report: Dict[str, Any] = validate_installation(verbose=not want_json)
    if want_json:
        print(json.dumps(report, indent=2))
    # A non-zero exit code lets CI/automation detect a failed validation.
    if report.get("status") == "ok":
        return 0
    return 1
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``textpolicy`` argument parser.

    The parser exposes a single ``validate`` subcommand with an optional
    ``--json`` flag. The subcommand stores its handler under ``func`` so the
    caller can dispatch on the parsed namespace.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser(
        prog="textpolicy",
        description="TextPolicy command-line tools",
    )
    commands = root.add_subparsers(dest="command")

    validate_cmd = commands.add_parser(
        "validate",
        help="Validate installation and environment",
    )
    validate_cmd.add_argument(
        "--json",
        action="store_true",
        help="Output machine-readable JSON report",
    )
    # The handler travels with the parsed namespace as `func`.
    validate_cmd.set_defaults(func=_cmd_validate)

    return root
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def main(argv: list[str] | None = None) -> int:
    """Run the CLI and return its exit code.

    Args:
        argv: Argument list without the program name. ``None`` means
            ``sys.argv[1:]``; an empty list is treated as ``["validate"]``.

    Returns:
        The selected command handler's return value, or 2 (after printing
        help) when the parsed namespace carries no handler.
    """
    parser = build_parser()
    cli_args = sys.argv[1:] if argv is None else argv
    # An empty command line behaves like `textpolicy validate`.
    namespace = parser.parse_args(cli_args or ["validate"])
    try:
        handler = namespace.func
    except AttributeError:
        # No subcommand attached a handler: show usage and signal misuse.
        parser.print_help()
        return 2
    return handler(namespace)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|
|
67
|
+
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# textpolicy/environment/__init__.py
|
|
2
|
+
"""
|
|
3
|
+
Modular environment implementation for MLX-RL.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .environment import (
|
|
7
|
+
# Base classes
|
|
8
|
+
Environment,
|
|
9
|
+
EnvironmentAdapter,
|
|
10
|
+
|
|
11
|
+
# Adapters
|
|
12
|
+
GymAdapter,
|
|
13
|
+
|
|
14
|
+
# Factory functions
|
|
15
|
+
create_environment,
|
|
16
|
+
register_environment,
|
|
17
|
+
list_registered_environments,
|
|
18
|
+
is_gymnasium_available,
|
|
19
|
+
|
|
20
|
+
# Registry
|
|
21
|
+
ENVIRONMENT_REGISTRY
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from .vectorized import (
|
|
25
|
+
# Vectorized environment classes
|
|
26
|
+
VectorizedEnvironment,
|
|
27
|
+
VectorizedCollector,
|
|
28
|
+
|
|
29
|
+
# Factory function
|
|
30
|
+
make_vectorized_env,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Task suite registry for text generation environments
|
|
34
|
+
from .task_suites import (
|
|
35
|
+
register_task_suite,
|
|
36
|
+
list_task_suites,
|
|
37
|
+
get_task_suite,
|
|
38
|
+
)
|
|
39
|
+
# Re-export text generation environments and helpers for public API access
|
|
40
|
+
from .text_generation import (
|
|
41
|
+
TextGenerationEnvironment,
|
|
42
|
+
TextGenerationEnv,
|
|
43
|
+
create_text_generation_test_env,
|
|
44
|
+
validate_learning_progress,
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
__all__ = [
|
|
48
|
+
# Base classes
|
|
49
|
+
"Environment",
|
|
50
|
+
"EnvironmentAdapter",
|
|
51
|
+
|
|
52
|
+
# Adapters
|
|
53
|
+
"GymAdapter",
|
|
54
|
+
|
|
55
|
+
# Vectorized environment classes
|
|
56
|
+
"VectorizedEnvironment",
|
|
57
|
+
"VectorizedCollector",
|
|
58
|
+
|
|
59
|
+
# Factory functions
|
|
60
|
+
"create_environment",
|
|
61
|
+
"register_environment",
|
|
62
|
+
"list_registered_environments",
|
|
63
|
+
"is_gymnasium_available",
|
|
64
|
+
"make_vectorized_env",
|
|
65
|
+
|
|
66
|
+
# Task suite registry
|
|
67
|
+
"register_task_suite",
|
|
68
|
+
"list_task_suites",
|
|
69
|
+
"get_task_suite",
|
|
70
|
+
|
|
71
|
+
# Registry
|
|
72
|
+
"ENVIRONMENT_REGISTRY",
|
|
73
|
+
|
|
74
|
+
# Text generation environments and helpers
|
|
75
|
+
"TextGenerationEnvironment",
|
|
76
|
+
"TextGenerationEnv",
|
|
77
|
+
"create_text_generation_test_env",
|
|
78
|
+
"validate_learning_progress",
|
|
79
|
+
]
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# textpolicy/environment/base.py
|
|
2
|
+
"""
|
|
3
|
+
Base environment interface and protocols for MLX-RL.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Any, Dict, Tuple, Protocol
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Environment(ABC):
    """
    Unified environment interface for all environment types.

    Subclasses must implement ``reset``, ``step`` and the
    ``observation_space`` / ``action_space`` properties. ``clone``,
    ``render`` and ``close`` ship with default implementations that
    subclasses may override.
    """

    @abstractmethod
    def reset(self) -> Tuple[Any, Dict[str, Any]]:
        """
        Reset environment and return initial observation.

        Returns:
            Tuple of (observation, info) following gymnasium API
        """
        pass

    @abstractmethod
    def step(self, action: Any) -> Dict[str, Any]:
        """
        Take action and return step result.

        Args:
            action: Action to take in the environment

        Returns:
            Dict with keys: observation, reward, terminated, truncated, info
        """
        pass

    @property
    @abstractmethod
    def observation_space(self) -> Any:
        """
        Observation space specification.

        Returns:
            Space object describing valid observations
        """
        pass

    @property
    @abstractmethod
    def action_space(self) -> Any:
        """
        Action space specification.

        Returns:
            Space object describing valid actions
        """
        pass

    def clone(self) -> 'Environment':
        """
        Create a clone of this environment for multiprocessing.

        The default implementation returns ``copy.deepcopy(self)``; it does
        not raise. Subclasses should override it when a deep copy is not an
        appropriate way to duplicate them.

        Returns:
            New, independent instance of the same environment
        """
        # Imported here rather than at module level, as in the original code.
        import copy
        return copy.deepcopy(self)

    def render(self, mode: str = "human") -> Any:
        """
        Render the environment.

        The default implementation is a no-op and returns None.

        Args:
            mode: Rendering mode (e.g., "human", "rgb_array")

        Returns:
            Rendered output (depends on mode); None for the default
        """
        # Default implementation does nothing
        pass

    def close(self):
        """Clean up environment resources (no-op by default)."""
        # Default implementation does nothing
        pass
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class EnvironmentAdapter(Protocol):
    """
    Structural type for environment adapters.

    An adapter wraps an external environment API (gym, dm_env, etc.) and
    presents it through the unified Environment interface.
    """

    def __init__(self, env_spec: Any, **kwargs):
        """Build the adapter from an environment specification."""
        ...

    def clone(self) -> Environment:
        """Return a clone of the adapter for multiprocessing use."""
        ...
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# textpolicy/environment/environment.py
|
|
2
|
+
"""
|
|
3
|
+
Unified environment module that coordinates all environment components.
|
|
4
|
+
|
|
5
|
+
Environment system for MLX-RL.
|
|
6
|
+
|
|
7
|
+
This module provides a unified interface for different environment types
|
|
8
|
+
(gymnasium, dm_env, custom environments) with adapter patterns for compatibility.
|
|
9
|
+
|
|
10
|
+
Design principles:
|
|
11
|
+
1. Unified Environment interface for all environment types
|
|
12
|
+
2. Adapter pattern for external environment libraries
|
|
13
|
+
3. Factory system for easy environment creation and registration
|
|
14
|
+
4. Support for multiprocessing through cloning
|
|
15
|
+
5. Extensible design for adding new environment types
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
# Import all components for unified API
|
|
19
|
+
from .base import Environment, EnvironmentAdapter
|
|
20
|
+
from .gym import GymAdapter
|
|
21
|
+
from .factory import (
|
|
22
|
+
create_environment,
|
|
23
|
+
register_environment,
|
|
24
|
+
list_registered_environments,
|
|
25
|
+
is_gymnasium_available,
|
|
26
|
+
ENVIRONMENT_REGISTRY
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Re-export everything for compatibility
|
|
30
|
+
__all__ = [
|
|
31
|
+
# Base classes
|
|
32
|
+
"Environment",
|
|
33
|
+
"EnvironmentAdapter",
|
|
34
|
+
|
|
35
|
+
# Adapters
|
|
36
|
+
"GymAdapter",
|
|
37
|
+
|
|
38
|
+
# Factory functions
|
|
39
|
+
"create_environment",
|
|
40
|
+
"register_environment",
|
|
41
|
+
"list_registered_environments",
|
|
42
|
+
"is_gymnasium_available",
|
|
43
|
+
|
|
44
|
+
# Registry
|
|
45
|
+
"ENVIRONMENT_REGISTRY"
|
|
46
|
+
]
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# textpolicy/environment/factory.py
|
|
2
|
+
"""
|
|
3
|
+
Environment factory and registration system.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Dict, Callable, Union
|
|
7
|
+
from .base import Environment
|
|
8
|
+
from .gym import GymAdapter
|
|
9
|
+
|
|
10
|
+
# Environment registry for dynamic loading
|
|
11
|
+
ENVIRONMENT_REGISTRY: Dict[str, Callable] = {}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def register_environment(name: str, factory_func: Callable):
    """
    Add an environment factory to the registry.

    An existing entry under the same (lowercased) name is overwritten.

    Args:
        name: Environment name; stored lowercased
        factory_func: Callable that returns an Environment instance
    """
    registry_key = name.lower()
    ENVIRONMENT_REGISTRY[registry_key] = factory_func
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def create_environment(env_spec: Union[str, Environment], **kwargs) -> Environment:
    """
    Resolve an environment specification into an Environment.

    A string is looked up (lowercased) in ENVIRONMENT_REGISTRY; when no
    factory is registered under that name it is passed unchanged to
    GymAdapter. An Environment instance is returned as-is.

    Args:
        env_spec: Either environment name string or Environment instance
        **kwargs: Forwarded to the registered factory or to GymAdapter

    Returns:
        Environment instance

    Raises:
        ValueError: kwargs were given together with an Environment instance
        TypeError: env_spec is neither a str nor an Environment

    Examples:
        # Create from string (uses GymAdapter)
        env = create_environment("CartPole-v1")

        # Create with custom parameters
        env = create_environment("LunarLander-v2", continuous=True)

        # Pass through existing environment
        env = create_environment(my_custom_env)
    """
    if isinstance(env_spec, str):
        try:
            factory = ENVIRONMENT_REGISTRY[env_spec.lower()]
        except KeyError:
            # Not registered: fall back to gymnasium with the original spelling.
            return GymAdapter(env_spec, **kwargs)
        return factory(**kwargs)

    if not isinstance(env_spec, Environment):
        raise TypeError(f"env_spec must be str or Environment, got {type(env_spec)}")

    # An existing instance cannot be reconfigured, so extra kwargs are an error.
    if kwargs:
        raise ValueError("Cannot pass kwargs when env_spec is already an Environment instance")
    return env_spec
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def list_registered_environments() -> list[str]:
    """
    Return the names currently present in the environment registry.

    Returns:
        New list of registry keys, in the registry's iteration order
    """
    return list(ENVIRONMENT_REGISTRY)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def is_gymnasium_available() -> bool:
    """
    Check if gymnasium is available for import.

    Performs a real import attempt. The previous body had no import inside
    its ``try`` block, so ``ImportError`` could never be raised and the
    function always returned True.

    Returns:
        True if gymnasium can be imported, False otherwise
    """
    # Local import keeps this fix independent of the module's import block.
    import importlib

    try:
        # import_module (rather than a bare `import`) is not flagged as an
        # unused import by linters, so it will not be stripped again.
        importlib.import_module("gymnasium")
    except ImportError:
        return False
    return True
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Register some common environments by default
|
|
91
|
+
def _register_defaults():
    """Register the built-in gymnasium aliases when gymnasium is available."""
    if not is_gymnasium_available():
        return

    def _gym_factory(gym_id):
        # One closure per alias so each factory keeps its own gymnasium id.
        def _make(**kwargs):
            return GymAdapter(gym_id, **kwargs)

        return _make

    aliases = (
        ("cartpole", "CartPole-v1"),
        ("cartpole-v1", "CartPole-v1"),
        ("lunarlander", "LunarLander-v2"),
        ("lunarlander-v2", "LunarLander-v2"),
    )
    for alias, gym_id in aliases:
        register_environment(alias, _gym_factory(gym_id))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Register defaults on import
|
|
103
|
+
_register_defaults()
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# textpolicy/environment/gym.py
|
|
2
|
+
"""
|
|
3
|
+
Gymnasium environment adapter for MLX-RL.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict, Tuple
|
|
7
|
+
import gymnasium as gym
|
|
8
|
+
from .base import Environment
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GymAdapter(Environment):
    """
    Adapter for gymnasium environments.

    Wraps an environment created with ``gym.make()`` and exposes it through
    the unified Environment interface. ``reset`` accepts either a bare
    observation or an ``(observation, info)`` tuple from the wrapped
    environment; ``step`` expects the five-value
    ``(obs, reward, terminated, truncated, info)`` result.
    """

    def __init__(self, env_name: str, **kwargs):
        """
        Initialize gymnasium environment.

        Args:
            env_name: Name of the gymnasium environment (e.g., "CartPole-v1")
            **kwargs: Additional arguments passed to gym.make()
        """
        # Name and kwargs are kept so clone() can rebuild an equivalent env.
        self.env_name = env_name
        self.env_kwargs = kwargs
        self.env = gym.make(env_name, **kwargs)

    def clone(self) -> 'GymAdapter':
        """
        Create a new instance of the same environment.

        The clone is built from the stored name and kwargs via a fresh
        ``gym.make()`` call rather than by copying the wrapped environment,
        so each multiprocessing worker gets its own instance.

        Returns:
            New GymAdapter instance with same configuration
        """
        return GymAdapter(self.env_name, **self.env_kwargs)

    def reset(self) -> Tuple[Any, Dict[str, Any]]:
        """
        Reset environment and return initial observation.

        Handles both the newer API, where ``reset()`` returns an
        ``(observation, info)`` tuple, and the old API, where it returns
        only the observation (an empty info dict is supplied in that case).

        Returns:
            Tuple of (observation, info)
        """
        result = self.env.reset()
        if isinstance(result, tuple):
            return result  # New API: (obs, info)
        else:
            return result, {}  # Old API: just obs

    def step(self, action: Any) -> Dict[str, Any]:
        """
        Take action and return step result in unified format.

        Args:
            action: Action to take (format depends on action space)

        Returns:
            Dictionary with observation, reward, terminated, truncated, info
        """
        # Unpacks exactly five values; a four-value (old-style) result would
        # raise ValueError here.
        obs, reward, terminated, truncated, info = self.env.step(action)
        return {
            "observation": obs,
            "reward": reward,
            "terminated": terminated,
            "truncated": truncated,
            "info": info
        }

    @property
    def observation_space(self):
        """Get observation space from underlying gymnasium environment."""
        return self.env.observation_space

    @property
    def action_space(self):
        """Get action space from underlying gymnasium environment."""
        return self.env.action_space

    def render(self, mode: str = "human") -> Any:
        """
        Render the environment.

        Args:
            mode: Accepted for compatibility with ``Environment.render`` but
                not forwarded; the wrapped environment's ``render()`` is
                called without arguments.

        Returns:
            Whatever the wrapped environment's ``render()`` returns
        """
        return self.env.render()

    def close(self):
        """Close the underlying gymnasium environment."""
        self.env.close()

    def __repr__(self) -> str:
        """String representation for debugging."""
        return f"GymAdapter(env_name='{self.env_name}', kwargs={self.env_kwargs})"
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task suite registry for TextGenerationEnvironment.
|
|
3
|
+
|
|
4
|
+
This module provides a minimal registry that maps suite names to loader
|
|
5
|
+
functions returning lists of TextGenerationTask instances. It enables
|
|
6
|
+
customizable evaluation suites without hardcoding them in the environment.
|
|
7
|
+
|
|
8
|
+
Notes:
|
|
9
|
+
- Keep this registry lightweight and dependency-free.
|
|
10
|
+
- Default suites are registered from text_generation.py to avoid import cycles.
|
|
11
|
+
- A file-backed loader (JSON/YAML) can be added later if needed.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from typing import Callable, Dict, List, Any, Optional
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Registry type: name -> callable that returns a List[TextGenerationTask]
|
|
20
|
+
_TASK_SUITE_REGISTRY: Dict[str, Callable[[], List[Any]]] = {}
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def register_task_suite(name: str, loader: Callable[[], List[Any]]) -> None:
    """
    Store a task-suite loader under the given name.

    An existing entry with the same name is overwritten.

    Args:
        name: Suite identifier (e.g., "basic", "challenging").
        loader: Callable that returns a list of TextGenerationTask instances.

    Raises:
        TypeError: If ``loader`` is not callable.
    """
    if callable(loader):
        _TASK_SUITE_REGISTRY[name] = loader
        return
    raise TypeError("loader must be callable and return a list of tasks")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_task_suite(name: str) -> Optional[List[Any]]:
    """
    Look up a registered suite and invoke its loader.

    The loader is called on every lookup; its result is not cached here.

    Returns:
        The loader's return value, or None when no suite is registered
        under ``name``.
    """
    suite_loader = _TASK_SUITE_REGISTRY.get(name)
    return None if suite_loader is None else suite_loader()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def list_task_suites() -> List[str]:
    """Return the registered task suite names in sorted order."""
    return sorted(_TASK_SUITE_REGISTRY)
|
|
51
|
+
|