textpolicy 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. textpolicy/__init__.py +52 -0
  2. textpolicy/__main__.py +8 -0
  3. textpolicy/algorithms/__init__.py +54 -0
  4. textpolicy/algorithms/grpo.py +642 -0
  5. textpolicy/algorithms/gspo.py +582 -0
  6. textpolicy/buffer/__init__.py +23 -0
  7. textpolicy/buffer/buffer.py +244 -0
  8. textpolicy/buffer/episode.py +383 -0
  9. textpolicy/buffer/sampling.py +438 -0
  10. textpolicy/buffer/storage.py +255 -0
  11. textpolicy/cli.py +67 -0
  12. textpolicy/environment/__init__.py +79 -0
  13. textpolicy/environment/base.py +110 -0
  14. textpolicy/environment/environment.py +46 -0
  15. textpolicy/environment/factory.py +103 -0
  16. textpolicy/environment/gym.py +106 -0
  17. textpolicy/environment/task_suites.py +51 -0
  18. textpolicy/environment/text_generation.py +789 -0
  19. textpolicy/environment/vectorized.py +253 -0
  20. textpolicy/generation/__init__.py +62 -0
  21. textpolicy/generation/lora.py +411 -0
  22. textpolicy/generation/mlx_generation.py +557 -0
  23. textpolicy/generation/reload.py +253 -0
  24. textpolicy/rewards/__init__.py +137 -0
  25. textpolicy/rewards/adapters.py +387 -0
  26. textpolicy/rewards/basic.py +214 -0
  27. textpolicy/rewards/integrated_system.py +338 -0
  28. textpolicy/rewards/mlx_batch_processor.py +447 -0
  29. textpolicy/rewards/registry.py +293 -0
  30. textpolicy/rewards/rollout_rewards.py +410 -0
  31. textpolicy/rewards/verifiers.py +369 -0
  32. textpolicy/rollout/__init__.py +44 -0
  33. textpolicy/rollout/aggregator.py +145 -0
  34. textpolicy/rollout/base.py +108 -0
  35. textpolicy/rollout/rollout.py +142 -0
  36. textpolicy/rollout/runner.py +280 -0
  37. textpolicy/rollout/strategy.py +208 -0
  38. textpolicy/rollout/worker.py +194 -0
  39. textpolicy/training/__init__.py +14 -0
  40. textpolicy/training/metrics.py +242 -0
  41. textpolicy/training/rollout_manager.py +78 -0
  42. textpolicy/training/trainer.py +684 -0
  43. textpolicy/utils/__init__.py +40 -0
  44. textpolicy/utils/benchmarking.py +489 -0
  45. textpolicy/utils/data.py +60 -0
  46. textpolicy/utils/debug.py +170 -0
  47. textpolicy/utils/environment.py +349 -0
  48. textpolicy/utils/logging/__init__.py +22 -0
  49. textpolicy/utils/logging/base.py +48 -0
  50. textpolicy/utils/logging/console.py +61 -0
  51. textpolicy/utils/logging/factory.py +133 -0
  52. textpolicy/utils/logging/multi.py +83 -0
  53. textpolicy/utils/logging/tensorboard.py +65 -0
  54. textpolicy/utils/logging/wandb.py +72 -0
  55. textpolicy/utils/memory.py +118 -0
  56. textpolicy/utils/performance.py +464 -0
  57. textpolicy/utils/timing.py +171 -0
  58. textpolicy/validate.py +101 -0
  59. textpolicy/validation/__init__.py +13 -0
  60. textpolicy/validation/logprob_validation.py +315 -0
  61. textpolicy-0.1.0.dist-info/METADATA +99 -0
  62. textpolicy-0.1.0.dist-info/RECORD +66 -0
  63. textpolicy-0.1.0.dist-info/entry_points.txt +2 -0
  64. textpolicy-0.0.1.dist-info/METADATA +0 -10
  65. textpolicy-0.0.1.dist-info/RECORD +0 -6
  66. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/WHEEL +0 -0
  67. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/licenses/LICENSE +0 -0
  68. {textpolicy-0.0.1.dist-info → textpolicy-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,133 @@
1
+ # textpolicy/utils/logging/factory.py
2
+ """
3
+ Factory functions for creating logger instances.
4
+ """
5
+
6
+ from typing import Optional, List
7
+ from .base import Logger
8
+ from .wandb import WandbLogger
9
+ from .tensorboard import TensorboardLogger
10
+ from .console import ConsoleLogger
11
+ from .multi import MultiLogger
12
+
13
+
14
def create_logger(
    logger_type: str = "console",
    **kwargs
) -> Logger:
    """
    Build a single logger backend by name.

    Args:
        logger_type: Type of logger ("wandb", "tensorboard", "console")
        **kwargs: Backend-specific constructor parameters

    Returns:
        Logger instance

    Raises:
        ValueError: If logger_type is unknown, or a required backend
            parameter is missing

    Examples:
        # Console logger
        logger = create_logger("console", verbose=True)

        # Wandb logger
        logger = create_logger("wandb", project_name="my-project", run_name="test")

        # TensorBoard logger
        logger = create_logger("tensorboard", log_dir="./logs")
    """
    # Guard-clause dispatch: each branch validates its own required
    # parameter before constructing the backend.
    if logger_type == "wandb":
        if "project_name" not in kwargs:
            raise ValueError("project_name is required for wandb logger")
        # pop() runs before **kwargs unpacks, so project_name is not
        # passed twice to the constructor.
        return WandbLogger(kwargs.pop("project_name"), **kwargs)

    if logger_type == "tensorboard":
        if "log_dir" not in kwargs:
            raise ValueError("log_dir is required for tensorboard logger")
        # TensorboardLogger only accepts log_dir; extra kwargs are ignored.
        return TensorboardLogger(kwargs.pop("log_dir"))

    if logger_type == "console":
        return ConsoleLogger(**kwargs)

    supported = ["wandb", "tensorboard", "console"]
    raise ValueError(f"Unknown logger type: {logger_type}. Available: {supported}")
56
+
57
+
58
def create_multi_logger(
    configs: List[dict]
) -> MultiLogger:
    """
    Create a MultiLogger from a list of logger configurations.

    Args:
        configs: List of dictionaries, each with a "type" key naming the
            backend plus any backend-specific parameters. The dicts are
            not modified.

    Returns:
        MultiLogger instance wrapping one logger per config

    Raises:
        ValueError: If a config is missing the "type" key, or names an
            unknown logger type

    Example:
        logger = create_multi_logger([
            {"type": "console", "verbose": True},
            {"type": "wandb", "project_name": "my-project"}
        ])
    """
    loggers = []
    for config in configs:
        # Copy before popping: the previous implementation popped "type"
        # directly off the caller's dict, mutating it and making the same
        # configs list unusable for a second call.
        params = dict(config)
        if "type" not in params:
            raise ValueError(f"Logger config missing 'type' key: {config}")
        logger_type = params.pop("type")
        loggers.append(create_logger(logger_type, **params))

    return MultiLogger(loggers)
83
+
84
+
85
def create_auto_logger(
    project_name: Optional[str] = None,
    log_dir: Optional[str] = None,
    console: bool = True
) -> Logger:
    """
    Automatically create appropriate logger based on available dependencies.

    Priority order:
    1. Wandb (if project_name provided and wandb available)
    2. TensorBoard (if log_dir provided and tensorboard available)
    3. Console (always available)

    Args:
        project_name: Wandb project name (enables wandb if available)
        log_dir: TensorBoard log directory (enables tensorboard if available)
        console: Whether to include console logging

    Returns:
        Logger instance (MultiLogger if multiple backends, single Logger otherwise)
    """
    backends: List[Logger] = []

    # Each optional backend raises ImportError from its constructor when
    # the underlying package is missing; treat that as "skip".
    if project_name:
        try:
            backends.append(WandbLogger(project_name))
        except ImportError:
            print("Warning: wandb not available, skipping")

    if log_dir:
        try:
            backends.append(TensorboardLogger(log_dir))
        except ImportError:
            print("Warning: tensorboard not available, skipping")

    if console:
        backends.append(ConsoleLogger(verbose=True))

    # Collapse to the simplest logger that covers what was collected.
    if not backends:
        # console=False and no optional backend succeeded — still return
        # something usable rather than None.
        return ConsoleLogger(verbose=True)
    if len(backends) == 1:
        return backends[0]
    return MultiLogger(backends)
@@ -0,0 +1,83 @@
1
+ # textpolicy/utils/logging/multi.py
2
+ """
3
+ Multi-logger for combining multiple logging backends.
4
+ """
5
+
6
+ from typing import Dict, List
7
+ from .base import Logger
8
+
9
+
10
class MultiLogger(Logger):
    """
    Fan out logging calls to several backends behind one Logger interface.

    Features:
    - Log to multiple backends simultaneously
    - Graceful error handling (one logger failure doesn't stop others)
    - Unified interface for complex logging setups

    Example:
        # Log to both wandb and console
        logger = MultiLogger([
            WandbLogger("my-project"),
            ConsoleLogger(verbose=True)
        ])
    """

    def __init__(self, loggers: List[Logger]):
        """
        Store the list of backends to fan out to.

        Args:
            loggers: List of Logger instances to combine

        Raises:
            ValueError: If no loggers provided
        """
        if not loggers:
            raise ValueError("At least one logger must be provided")
        self.loggers = loggers

    def log_metrics(self, metrics: Dict[str, float], step: int):
        """
        Forward training metrics to every backend.

        A failing backend is reported with a warning and skipped so the
        remaining backends still receive the metrics.

        Args:
            metrics: Training metrics dictionary
            step: Training step number
        """
        for backend in self.loggers:
            try:
                backend.log_metrics(metrics, step)
            except Exception as e:
                print(f"Warning: Logger {type(backend).__name__} failed: {e}")

    def log_evaluation(self, metrics: Dict[str, float], step: int):
        """
        Forward evaluation metrics to every backend.

        A failing backend is reported with a warning and skipped so the
        remaining backends still receive the metrics.

        Args:
            metrics: Evaluation metrics dictionary
            step: Training step when evaluation was performed
        """
        for backend in self.loggers:
            try:
                backend.log_evaluation(metrics, step)
            except Exception as e:
                print(f"Warning: Logger {type(backend).__name__} failed: {e}")

    def finish(self):
        """
        Finish every backend, continuing past individual failures.
        """
        for backend in self.loggers:
            try:
                backend.finish()
            except Exception as e:
                print(f"Warning: Logger {type(backend).__name__} finish failed: {e}")
@@ -0,0 +1,65 @@
1
+ # textpolicy/utils/logging/tensorboard.py
2
+ """
3
+ TensorBoard logging integration.
4
+ """
5
+
6
+ from typing import Dict
7
+ from .base import Logger
8
+
9
+
10
class TensorboardLogger(Logger):
    """
    TensorBoard integration for local experiment visualization.

    Features:
    - Local scalar metric visualization
    - Histogram and distribution tracking
    - Image and model graph visualization
    - No external service dependency

    Requires: pip install tensorboard
    """

    def __init__(self, log_dir: str):
        """
        Initialize TensorBoard logging.

        Args:
            log_dir: Directory to store TensorBoard log files

        Raises:
            ImportError: If tensorboard is not installed
        """
        # Import lazily so the package works without the optional
        # tensorboard dependency installed.
        try:
            from torch.utils.tensorboard import SummaryWriter  # type: ignore
            self.writer = SummaryWriter(log_dir)
        except ImportError:
            raise ImportError(
                "tensorboard not installed. Install with: pip install tensorboard"
            )

    def log_metrics(self, metrics: Dict[str, float], step: int):
        """
        Log training metrics to TensorBoard with 'train/' prefix.

        Args:
            metrics: Training metrics dictionary
            step: Training step number
        """
        for name, scalar in metrics.items():
            self.writer.add_scalar("train/" + name, scalar, step)

    def log_evaluation(self, metrics: Dict[str, float], step: int):
        """
        Log evaluation metrics to TensorBoard with 'eval/' prefix.

        Args:
            metrics: Evaluation metrics dictionary
            step: Training step when evaluation was performed
        """
        for name, scalar in metrics.items():
            self.writer.add_scalar("eval/" + name, scalar, step)

    def finish(self):
        """Close TensorBoard writer and flush remaining data."""
        self.writer.close()
@@ -0,0 +1,72 @@
1
+ # textpolicy/utils/logging/wandb.py
2
+ """
3
+ Weights & Biases (wandb) logging integration.
4
+ """
5
+
6
+ from typing import Dict, Optional
7
+ from .base import Logger
8
+
9
+
10
class WandbLogger(Logger):
    """
    Weights & Biases integration for experiment tracking.

    Features:
    - Automatic experiment organization with projects
    - Real-time metric visualization
    - Hyperparameter tracking
    - Model artifact management

    Requires: pip install wandb
    """

    def __init__(self, project_name: str, run_name: Optional[str] = None, **kwargs):
        """
        Initialize wandb logging.

        Args:
            project_name: Wandb project name for organization
            run_name: Optional run name (auto-generated if None)
            **kwargs: Additional wandb.init() parameters (tags, config, etc.)

        Raises:
            ImportError: If wandb is not installed
        """
        # Import lazily so the package works without the optional
        # wandb dependency installed.
        try:
            import wandb  # type: ignore
            self.wandb = wandb
            self.run = wandb.init(project=project_name, name=run_name, **kwargs)
        except ImportError:
            raise ImportError(
                "wandb not installed. Install with: pip install wandb"
            )

    def log_metrics(self, metrics: Dict[str, float], step: int):
        """
        Log training metrics to wandb with 'train/' prefix.

        Args:
            metrics: Training metrics dictionary
            step: Training step number
        """
        self.wandb.log({f"train/{name}": value for name, value in metrics.items()}, step=step)

    def log_evaluation(self, metrics: Dict[str, float], step: int):
        """
        Log evaluation metrics to wandb with 'eval/' prefix.

        Args:
            metrics: Evaluation metrics dictionary
            step: Training step when evaluation was performed
        """
        self.wandb.log({f"eval/{name}": value for name, value in metrics.items()}, step=step)

    def finish(self):
        """Finish wandb run and upload final data."""
        self.wandb.finish()
@@ -0,0 +1,118 @@
1
+ # textpolicy/utils/memory.py
2
+ """
3
+ Memory monitoring utilities for TextPolicy.
4
+ """
5
+
6
+ import gc
7
+ from typing import Dict, Optional
8
+ try:
9
+ import mlx.core as mx # type: ignore
10
+ MLX_AVAILABLE = True
11
+ except ImportError:
12
+ MLX_AVAILABLE = False
13
+
14
+
15
def get_memory_stats() -> Dict[str, float]:
    """
    Get current memory usage statistics.

    Returns:
        Dictionary with memory statistics in MB:
        - mlx_memory_mb / mlx_peak_mb: MLX active/peak memory (0.0 on error;
          keys absent entirely when MLX is not installed)
        - python_memory_mb / python_virtual_mb: process RSS/VMS via psutil
          (0.0 when psutil is not installed)
    """
    stats: Dict[str, float] = {}

    # MLX memory usage (Apple Silicon GPU/ANE)
    if MLX_AVAILABLE:
        try:
            # NOTE(review): mx.metal.* memory accessors are deprecated in
            # newer MLX releases in favor of mx.get_active_memory() /
            # mx.get_peak_memory() — confirm against the pinned MLX version.
            stats["mlx_memory_mb"] = mx.metal.get_active_memory() / 1024 / 1024
            stats["mlx_peak_mb"] = mx.metal.get_peak_memory() / 1024 / 1024
        except Exception as e:
            print(f"Error getting MLX memory stats: {e}")
            stats["mlx_memory_mb"] = 0.0
            stats["mlx_peak_mb"] = 0.0

    # Python process memory usage (best-effort; psutil is optional)
    try:
        import psutil  # type: ignore
        # Take a single memory_info() snapshot: the original called it
        # twice, doing two syscalls and sampling rss/vms at different
        # instants.
        mem = psutil.Process().memory_info()
        stats["python_memory_mb"] = mem.rss / 1024 / 1024
        stats["python_virtual_mb"] = mem.vms / 1024 / 1024
    except ImportError:
        stats["python_memory_mb"] = 0.0
        stats["python_virtual_mb"] = 0.0

    return stats
46
+
47
+
48
def clear_memory():
    """
    Clear memory caches and run garbage collection.

    Useful for freeing memory between training runs or evaluations.
    Best-effort: MLX cache-clearing failures are reported but never raised.
    """
    # Python garbage collection
    gc.collect()

    # MLX memory cleanup (removed a dead `pass` that followed the print)
    if MLX_AVAILABLE:
        try:
            mx.metal.clear_cache()
        except Exception as e:
            print(f"Error clearing MLX memory: {e}")
64
+
65
+
66
class MemoryMonitor:
    """
    Monitor memory usage during training.

    Features:
    - Track peak memory usage
    - Automatic memory alerts
    - Integration with logging systems
    """

    def __init__(self, alert_threshold_mb: float = 8000):
        """
        Initialize memory monitor.

        Args:
            alert_threshold_mb: Memory usage threshold for alerts (default 8GB)
        """
        self.alert_threshold = alert_threshold_mb
        self.peak_stats = {}

    def check_memory(self, step: Optional[int] = None) -> Dict[str, float]:
        """
        Check current memory usage and update peaks.

        Args:
            step: Optional training step for logging

        Returns:
            Current memory statistics
        """
        snapshot = get_memory_stats()

        # Record new highs for every reported statistic.
        for name, value in snapshot.items():
            previous = self.peak_stats.get(name)
            if previous is None or value > previous:
                self.peak_stats[name] = value

        # Combined MLX + Python usage drives the alert threshold.
        combined = snapshot.get("mlx_memory_mb", 0) + snapshot.get("python_memory_mb", 0)
        if combined > self.alert_threshold:
            print(f"Memory alert: {combined:.1f}MB (threshold: {self.alert_threshold:.1f}MB)")
            if step is not None:
                print(f"  At training step: {step}")

        return snapshot

    def get_peak_stats(self) -> Dict[str, float]:
        """Get peak memory usage statistics."""
        return self.peak_stats.copy()

    def reset_peaks(self):
        """Reset peak memory tracking."""
        self.peak_stats.clear()