synth-ai 0.4.1__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (153)
  1. synth_ai/__init__.py +13 -13
  2. synth_ai/cli/__init__.py +6 -15
  3. synth_ai/cli/commands/eval/__init__.py +6 -15
  4. synth_ai/cli/commands/eval/config.py +338 -0
  5. synth_ai/cli/commands/eval/core.py +236 -1091
  6. synth_ai/cli/commands/eval/runner.py +704 -0
  7. synth_ai/cli/commands/eval/validation.py +44 -117
  8. synth_ai/cli/commands/filter/core.py +7 -7
  9. synth_ai/cli/commands/filter/validation.py +2 -2
  10. synth_ai/cli/commands/smoke/core.py +7 -17
  11. synth_ai/cli/commands/status/__init__.py +1 -64
  12. synth_ai/cli/commands/status/client.py +50 -151
  13. synth_ai/cli/commands/status/config.py +3 -83
  14. synth_ai/cli/commands/status/errors.py +4 -13
  15. synth_ai/cli/commands/status/subcommands/__init__.py +2 -8
  16. synth_ai/cli/commands/status/subcommands/config.py +13 -0
  17. synth_ai/cli/commands/status/subcommands/files.py +18 -63
  18. synth_ai/cli/commands/status/subcommands/jobs.py +28 -311
  19. synth_ai/cli/commands/status/subcommands/models.py +18 -62
  20. synth_ai/cli/commands/status/subcommands/runs.py +16 -63
  21. synth_ai/cli/commands/status/subcommands/session.py +67 -172
  22. synth_ai/cli/commands/status/subcommands/summary.py +24 -32
  23. synth_ai/cli/commands/status/subcommands/utils.py +41 -0
  24. synth_ai/cli/commands/status/utils.py +16 -107
  25. synth_ai/cli/commands/train/__init__.py +18 -20
  26. synth_ai/cli/commands/train/errors.py +3 -3
  27. synth_ai/cli/commands/train/prompt_learning_validation.py +15 -16
  28. synth_ai/cli/commands/train/validation.py +7 -7
  29. synth_ai/cli/commands/train/{judge_schemas.py → verifier_schemas.py} +33 -34
  30. synth_ai/cli/commands/train/verifier_validation.py +235 -0
  31. synth_ai/cli/demo_apps/demo_task_apps/math/config.toml +0 -1
  32. synth_ai/cli/demo_apps/demo_task_apps/math/modal_task_app.py +2 -6
  33. synth_ai/cli/demo_apps/math/config.toml +0 -1
  34. synth_ai/cli/demo_apps/math/modal_task_app.py +2 -6
  35. synth_ai/cli/demo_apps/mipro/task_app.py +25 -47
  36. synth_ai/cli/lib/apps/task_app.py +12 -13
  37. synth_ai/cli/lib/task_app_discovery.py +6 -6
  38. synth_ai/cli/lib/train_cfgs.py +10 -10
  39. synth_ai/cli/task_apps/__init__.py +11 -0
  40. synth_ai/cli/task_apps/commands.py +7 -15
  41. synth_ai/core/env.py +12 -1
  42. synth_ai/core/errors.py +1 -2
  43. synth_ai/core/integrations/cloudflare.py +209 -33
  44. synth_ai/core/tracing_v3/abstractions.py +46 -0
  45. synth_ai/data/__init__.py +3 -30
  46. synth_ai/data/enums.py +1 -20
  47. synth_ai/data/rewards.py +100 -3
  48. synth_ai/products/graph_evolve/__init__.py +1 -2
  49. synth_ai/products/graph_evolve/config.py +16 -16
  50. synth_ai/products/graph_evolve/converters/__init__.py +3 -3
  51. synth_ai/products/graph_evolve/converters/openai_sft.py +7 -7
  52. synth_ai/products/graph_evolve/examples/hotpotqa/config.toml +1 -1
  53. synth_ai/products/graph_gepa/__init__.py +23 -0
  54. synth_ai/products/graph_gepa/converters/__init__.py +19 -0
  55. synth_ai/products/graph_gepa/converters/openai_sft.py +29 -0
  56. synth_ai/sdk/__init__.py +45 -35
  57. synth_ai/sdk/api/eval/__init__.py +33 -0
  58. synth_ai/sdk/api/eval/job.py +732 -0
  59. synth_ai/sdk/api/research_agent/__init__.py +276 -66
  60. synth_ai/sdk/api/train/builders.py +181 -0
  61. synth_ai/sdk/api/train/cli.py +41 -33
  62. synth_ai/sdk/api/train/configs/__init__.py +6 -4
  63. synth_ai/sdk/api/train/configs/prompt_learning.py +127 -33
  64. synth_ai/sdk/api/train/configs/rl.py +264 -16
  65. synth_ai/sdk/api/train/configs/sft.py +165 -1
  66. synth_ai/sdk/api/train/graph_validators.py +12 -12
  67. synth_ai/sdk/api/train/graphgen.py +169 -51
  68. synth_ai/sdk/api/train/graphgen_models.py +95 -45
  69. synth_ai/sdk/api/train/local_api.py +10 -0
  70. synth_ai/sdk/api/train/pollers.py +36 -0
  71. synth_ai/sdk/api/train/prompt_learning.py +390 -60
  72. synth_ai/sdk/api/train/rl.py +41 -5
  73. synth_ai/sdk/api/train/sft.py +2 -0
  74. synth_ai/sdk/api/train/task_app.py +20 -0
  75. synth_ai/sdk/api/train/validators.py +17 -17
  76. synth_ai/sdk/graphs/completions.py +239 -33
  77. synth_ai/sdk/{judging/schemas.py → graphs/verifier_schemas.py} +23 -23
  78. synth_ai/sdk/learning/__init__.py +35 -5
  79. synth_ai/sdk/learning/context_learning_client.py +531 -0
  80. synth_ai/sdk/learning/context_learning_types.py +294 -0
  81. synth_ai/sdk/learning/prompt_learning_client.py +1 -1
  82. synth_ai/sdk/learning/prompt_learning_types.py +2 -1
  83. synth_ai/sdk/learning/rl/__init__.py +0 -4
  84. synth_ai/sdk/learning/rl/contracts.py +0 -4
  85. synth_ai/sdk/localapi/__init__.py +40 -0
  86. synth_ai/sdk/localapi/apps/__init__.py +28 -0
  87. synth_ai/sdk/localapi/client.py +10 -0
  88. synth_ai/sdk/localapi/contracts.py +10 -0
  89. synth_ai/sdk/localapi/helpers.py +519 -0
  90. synth_ai/sdk/localapi/rollouts.py +93 -0
  91. synth_ai/sdk/localapi/server.py +29 -0
  92. synth_ai/sdk/localapi/template.py +49 -0
  93. synth_ai/sdk/streaming/handlers.py +6 -6
  94. synth_ai/sdk/streaming/streamer.py +10 -6
  95. synth_ai/sdk/task/__init__.py +18 -5
  96. synth_ai/sdk/task/apps/__init__.py +37 -1
  97. synth_ai/sdk/task/client.py +9 -1
  98. synth_ai/sdk/task/config.py +6 -11
  99. synth_ai/sdk/task/contracts.py +137 -95
  100. synth_ai/sdk/task/in_process.py +32 -22
  101. synth_ai/sdk/task/in_process_runner.py +9 -4
  102. synth_ai/sdk/task/rubrics/__init__.py +2 -3
  103. synth_ai/sdk/task/rubrics/loaders.py +4 -4
  104. synth_ai/sdk/task/rubrics/strict.py +3 -4
  105. synth_ai/sdk/task/server.py +76 -16
  106. synth_ai/sdk/task/trace_correlation_helpers.py +190 -139
  107. synth_ai/sdk/task/validators.py +34 -49
  108. synth_ai/sdk/training/__init__.py +7 -16
  109. synth_ai/sdk/tunnels/__init__.py +118 -0
  110. synth_ai/sdk/tunnels/cleanup.py +83 -0
  111. synth_ai/sdk/tunnels/ports.py +120 -0
  112. synth_ai/sdk/tunnels/tunneled_api.py +363 -0
  113. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/METADATA +71 -4
  114. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/RECORD +118 -128
  115. synth_ai/cli/commands/baseline/__init__.py +0 -12
  116. synth_ai/cli/commands/baseline/core.py +0 -636
  117. synth_ai/cli/commands/baseline/list.py +0 -94
  118. synth_ai/cli/commands/eval/errors.py +0 -81
  119. synth_ai/cli/commands/status/formatters.py +0 -164
  120. synth_ai/cli/commands/status/subcommands/pricing.py +0 -23
  121. synth_ai/cli/commands/status/subcommands/usage.py +0 -203
  122. synth_ai/cli/commands/train/judge_validation.py +0 -305
  123. synth_ai/cli/usage.py +0 -159
  124. synth_ai/data/specs.py +0 -36
  125. synth_ai/sdk/api/research_agent/cli.py +0 -428
  126. synth_ai/sdk/api/research_agent/config.py +0 -357
  127. synth_ai/sdk/api/research_agent/job.py +0 -717
  128. synth_ai/sdk/baseline/__init__.py +0 -25
  129. synth_ai/sdk/baseline/config.py +0 -209
  130. synth_ai/sdk/baseline/discovery.py +0 -216
  131. synth_ai/sdk/baseline/execution.py +0 -154
  132. synth_ai/sdk/judging/__init__.py +0 -15
  133. synth_ai/sdk/judging/base.py +0 -24
  134. synth_ai/sdk/judging/client.py +0 -191
  135. synth_ai/sdk/judging/types.py +0 -42
  136. synth_ai/sdk/research_agent/__init__.py +0 -34
  137. synth_ai/sdk/research_agent/container_builder.py +0 -328
  138. synth_ai/sdk/research_agent/container_spec.py +0 -198
  139. synth_ai/sdk/research_agent/defaults.py +0 -34
  140. synth_ai/sdk/research_agent/results_collector.py +0 -69
  141. synth_ai/sdk/specs/__init__.py +0 -46
  142. synth_ai/sdk/specs/dataclasses.py +0 -149
  143. synth_ai/sdk/specs/loader.py +0 -144
  144. synth_ai/sdk/specs/serializer.py +0 -199
  145. synth_ai/sdk/specs/validation.py +0 -250
  146. synth_ai/sdk/tracing/__init__.py +0 -39
  147. synth_ai/sdk/usage/__init__.py +0 -37
  148. synth_ai/sdk/usage/client.py +0 -171
  149. synth_ai/sdk/usage/models.py +0 -261
  150. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/WHEEL +0 -0
  151. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/entry_points.txt +0 -0
  152. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/licenses/LICENSE +0 -0
  153. {synth_ai-0.4.1.dist-info → synth_ai-0.4.4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,9 @@
1
1
  """First-class SDK API for prompt learning (MIPRO and GEPA).
2
2
 
3
+ **Status:** Alpha
4
+
5
+ Note: MIPRO is Experimental, GEPA is Alpha.
6
+
3
7
  This module provides high-level abstractions for running prompt optimization jobs
4
8
  both via CLI (`uvx synth-ai train`) and programmatically in Python scripts.
5
9
 
@@ -8,13 +12,17 @@ Example CLI usage:
8
12
 
9
13
  Example SDK usage:
10
14
  from synth_ai.sdk.api.train.prompt_learning import PromptLearningJob
11
-
12
- job = PromptLearningJob.from_config("my_config.toml")
15
+
16
+ job = PromptLearningJob.from_dict(config_dict, api_key="sk_live_...")
13
17
  job.submit()
14
- result = job.poll_until_complete()
15
- print(f"Best score: {result['best_score']}")
18
+ result = job.poll_until_complete(progress=True) # Built-in progress printing
19
+
20
+ if result.succeeded:
21
+ print(f"Best score: {result.best_score}")
22
+ else:
23
+ print(f"Failed: {result.error}")
16
24
 
17
- For domain-specific judging, you can use **Verifier Graphs**. See `PromptLearningJudgeConfig`
25
+ For domain-specific verification, you can use **Verifier Graphs**. See `PromptLearningVerifierConfig`
18
26
  in `synth_ai.sdk.api.train.configs.prompt_learning` for configuration details.
19
27
  """
20
28
 
@@ -22,38 +30,186 @@ from __future__ import annotations
22
30
 
23
31
  import asyncio
24
32
  import os
25
- from dataclasses import dataclass
33
+ import time
34
+ from dataclasses import dataclass, field
35
+ from enum import Enum
26
36
  from pathlib import Path
27
- from typing import Any, Callable, Dict, Optional
37
+ from typing import Any, Callable, Dict, List, Optional
28
38
 
29
39
  from synth_ai.core.telemetry import log_info
30
40
 
31
- from .builders import PromptLearningBuildResult, build_prompt_learning_payload
41
+
42
+ class JobStatus(str, Enum):
43
+ """Status of a prompt learning job."""
44
+
45
+ PENDING = "pending"
46
+ QUEUED = "queued"
47
+ RUNNING = "running"
48
+ SUCCEEDED = "succeeded"
49
+ FAILED = "failed"
50
+ CANCELLED = "cancelled"
51
+
52
+ @classmethod
53
+ def from_string(cls, status: str) -> "JobStatus":
54
+ """Convert string to JobStatus, defaulting to PENDING for unknown values."""
55
+ try:
56
+ return cls(status.lower())
57
+ except ValueError:
58
+ return cls.PENDING
59
+
60
+ @property
61
+ def is_terminal(self) -> bool:
62
+ """Whether this status is terminal (job won't change further)."""
63
+ return self in (JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.CANCELLED)
64
+
65
+ @property
66
+ def is_success(self) -> bool:
67
+ """Whether this status indicates success."""
68
+ return self == JobStatus.SUCCEEDED
69
+
70
+
71
+ @dataclass
72
+ class PromptLearningResult:
73
+ """Typed result from a prompt learning job.
74
+
75
+ Provides clean accessors for common fields instead of raw dict access.
76
+
77
+ Example:
78
+ >>> result = job.poll_until_complete()
79
+ >>> if result.succeeded:
80
+ ... print(f"Best score: {result.best_score}")
81
+ ... print(f"Best prompt: {result.best_prompt[:100]}...")
82
+ >>> else:
83
+ ... print(f"Failed: {result.error}")
84
+ """
85
+
86
+ job_id: str
87
+ status: JobStatus
88
+ best_score: Optional[float] = None
89
+ best_prompt: Optional[str] = None
90
+ error: Optional[str] = None
91
+ raw: Dict[str, Any] = field(default_factory=dict)
92
+
93
+ @classmethod
94
+ def from_response(cls, job_id: str, data: Dict[str, Any]) -> "PromptLearningResult":
95
+ """Create result from API response dict."""
96
+ status_str = data.get("status", "pending")
97
+ status = JobStatus.from_string(status_str)
98
+
99
+ # Extract best score from various field names (backward compat)
100
+ best_score = (
101
+ data.get("best_score")
102
+ or data.get("best_reward")
103
+ or data.get("best_train_score")
104
+ or data.get("best_train_reward")
105
+ )
106
+
107
+ return cls(
108
+ job_id=job_id,
109
+ status=status,
110
+ best_score=best_score,
111
+ best_prompt=data.get("best_prompt"),
112
+ error=data.get("error"),
113
+ raw=data,
114
+ )
115
+
116
+ @property
117
+ def succeeded(self) -> bool:
118
+ """Whether the job succeeded."""
119
+ return self.status.is_success
120
+
121
+ @property
122
+ def failed(self) -> bool:
123
+ """Whether the job failed."""
124
+ return self.status == JobStatus.FAILED
125
+
126
+ @property
127
+ def is_terminal(self) -> bool:
128
+ """Whether the job has reached a terminal state."""
129
+ return self.status.is_terminal
130
+
131
+ from .builders import (
132
+ PromptLearningBuildResult,
133
+ build_prompt_learning_payload,
134
+ build_prompt_learning_payload_from_mapping,
135
+ )
32
136
  from .pollers import JobPoller, PollOutcome
33
- from .task_app import check_task_app_health
34
- from .utils import ensure_api_base, http_post
137
+ from .local_api import check_local_api_health
138
+ from .utils import ensure_api_base, http_get, http_post
35
139
 
36
140
 
37
141
  @dataclass
38
142
  class PromptLearningJobConfig:
39
- """Configuration for a prompt learning job."""
40
-
41
- config_path: Path
143
+ """Configuration for a prompt learning job.
144
+
145
+ This dataclass holds all the configuration needed to submit and run
146
+ a prompt learning job (MIPRO or GEPA optimization).
147
+
148
+ Supports two modes:
149
+ 1. **File-based**: Provide `config_path` pointing to a TOML file
150
+ 2. **Programmatic**: Provide `config_dict` with the configuration directly
151
+
152
+ Both modes go through the same `PromptLearningConfig` Pydantic validation.
153
+
154
+ Attributes:
155
+ config_path: Path to the TOML configuration file. Mutually exclusive with config_dict.
156
+ config_dict: Dictionary with prompt learning configuration. Mutually exclusive with config_path.
157
+ Should have the same structure as the TOML file (with 'prompt_learning' section).
158
+ backend_url: Base URL of the Synth API backend (e.g., "https://api.usesynth.ai").
159
+ api_key: Synth API key for authentication.
160
+ task_app_api_key: API key for authenticating with the Local API.
161
+ allow_experimental: If True, allows use of experimental models.
162
+ overrides: Dictionary of config overrides.
163
+
164
+ Example (file-based):
165
+ >>> config = PromptLearningJobConfig(
166
+ ... config_path=Path("my_config.toml"),
167
+ ... backend_url="https://api.usesynth.ai",
168
+ ... api_key="sk_live_...",
169
+ ... )
170
+
171
+ Example (programmatic):
172
+ >>> config = PromptLearningJobConfig(
173
+ ... config_dict={
174
+ ... "prompt_learning": {
175
+ ... "algorithm": "gepa",
176
+ ... "task_app_url": "https://tunnel.example.com",
177
+ ... "policy": {"model": "gpt-4o-mini", "provider": "openai"},
178
+ ... "gepa": {...},
179
+ ... }
180
+ ... },
181
+ ... backend_url="https://api.usesynth.ai",
182
+ ... api_key="sk_live_...",
183
+ ... )
184
+ """
185
+
42
186
  backend_url: str
43
187
  api_key: str
188
+ config_path: Optional[Path] = None
189
+ config_dict: Optional[Dict[str, Any]] = None
44
190
  task_app_api_key: Optional[str] = None
45
191
  allow_experimental: Optional[bool] = None
46
192
  overrides: Optional[Dict[str, Any]] = None
47
-
193
+
48
194
  def __post_init__(self) -> None:
49
195
  """Validate configuration."""
50
- if not self.config_path.exists():
196
+ # Must provide exactly one of config_path or config_dict
197
+ has_path = self.config_path is not None
198
+ has_dict = self.config_dict is not None
199
+
200
+ if has_path and has_dict:
201
+ raise ValueError("Provide either config_path OR config_dict, not both")
202
+ if not has_path and not has_dict:
203
+ raise ValueError("Either config_path or config_dict is required")
204
+
205
+ if has_path and not self.config_path.exists():
51
206
  raise FileNotFoundError(f"Config file not found: {self.config_path}")
207
+
52
208
  if not self.backend_url:
53
209
  raise ValueError("backend_url is required")
54
210
  if not self.api_key:
55
211
  raise ValueError("api_key is required")
56
-
212
+
57
213
  # Get task_app_api_key from environment if not provided
58
214
  if not self.task_app_api_key:
59
215
  self.task_app_api_key = os.environ.get("ENVIRONMENT_API_KEY")
@@ -184,9 +340,100 @@ class PromptLearningJob:
184
340
  allow_experimental=allow_experimental,
185
341
  overrides=overrides or {},
186
342
  )
187
-
343
+
188
344
  return cls(config)
189
-
345
+
346
+ @classmethod
347
+ def from_dict(
348
+ cls,
349
+ config_dict: Dict[str, Any],
350
+ backend_url: Optional[str] = None,
351
+ api_key: Optional[str] = None,
352
+ task_app_api_key: Optional[str] = None,
353
+ allow_experimental: Optional[bool] = None,
354
+ overrides: Optional[Dict[str, Any]] = None,
355
+ skip_health_check: bool = False,
356
+ ) -> PromptLearningJob:
357
+ """Create a job from a configuration dictionary (programmatic use).
358
+
359
+ This allows creating prompt learning jobs without a TOML file, enabling
360
+ programmatic use in notebooks, scripts, and applications.
361
+
362
+ The config_dict should have the same structure as a TOML file:
363
+ ```python
364
+ {
365
+ "prompt_learning": {
366
+ "algorithm": "gepa",
367
+ "task_app_url": "https://...",
368
+ "policy": {"model": "gpt-4o-mini", "provider": "openai"},
369
+ "gepa": {...},
370
+ }
371
+ }
372
+ ```
373
+
374
+ Args:
375
+ config_dict: Configuration dictionary with 'prompt_learning' section
376
+ backend_url: Backend API URL (defaults to env or production)
377
+ api_key: API key (defaults to SYNTH_API_KEY env var)
378
+ task_app_api_key: Task app API key (defaults to ENVIRONMENT_API_KEY env var)
379
+ allow_experimental: Allow experimental models
380
+ overrides: Config overrides
381
+ skip_health_check: If True, skip task app health check before submission
382
+
383
+ Returns:
384
+ PromptLearningJob instance
385
+
386
+ Raises:
387
+ ValueError: If required config is missing or invalid
388
+
389
+ Example:
390
+ >>> job = PromptLearningJob.from_dict(
391
+ ... config_dict={
392
+ ... "prompt_learning": {
393
+ ... "algorithm": "gepa",
394
+ ... "task_app_url": "https://tunnel.example.com",
395
+ ... "policy": {"model": "gpt-4o-mini", "provider": "openai"},
396
+ ... "gepa": {
397
+ ... "rollout": {"budget": 50, "max_concurrent": 5},
398
+ ... "evaluation": {"train_seeds": [1, 2, 3], "val_seeds": [4, 5]},
399
+ ... "population": {"num_generations": 2, "children_per_generation": 2},
400
+ ... },
401
+ ... }
402
+ ... },
403
+ ... api_key="sk_live_...",
404
+ ... )
405
+ >>> job_id = job.submit()
406
+ """
407
+ import os
408
+
409
+ from synth_ai.core.env import get_backend_from_env
410
+
411
+ # Resolve backend URL
412
+ if not backend_url:
413
+ backend_url = os.environ.get("BACKEND_BASE_URL", "").strip()
414
+ if not backend_url:
415
+ base, _ = get_backend_from_env()
416
+ backend_url = f"{base}/api" if not base.endswith("/api") else base
417
+
418
+ # Resolve API key
419
+ if not api_key:
420
+ api_key = os.environ.get("SYNTH_API_KEY")
421
+ if not api_key:
422
+ raise ValueError(
423
+ "api_key is required (provide explicitly or set SYNTH_API_KEY env var)"
424
+ )
425
+
426
+ config = PromptLearningJobConfig(
427
+ config_dict=config_dict,
428
+ backend_url=backend_url,
429
+ api_key=api_key,
430
+ task_app_api_key=task_app_api_key,
431
+ allow_experimental=allow_experimental,
432
+ overrides=overrides or {},
433
+ )
434
+
435
+ return cls(config, skip_health_check=skip_health_check)
436
+
190
437
  @classmethod
191
438
  def from_job_id(
192
439
  cls,
@@ -223,33 +470,59 @@ class PromptLearningJob:
223
470
  "api_key is required (provide explicitly or set SYNTH_API_KEY env var)"
224
471
  )
225
472
 
226
- # Create minimal config (we don't need the config file for resuming)
473
+ # Create minimal config (we don't need the config for resuming - use empty dict)
474
+ # The config_dict is never used when resuming since we have the job_id
227
475
  config = PromptLearningJobConfig(
228
- config_path=Path("/dev/null"), # Dummy path
476
+ config_dict={"prompt_learning": {"_resumed": True}}, # Placeholder for resume mode
229
477
  backend_url=backend_url,
230
478
  api_key=api_key,
231
479
  )
232
-
480
+
233
481
  return cls(config, job_id=job_id)
234
482
 
235
483
  def _build_payload(self) -> PromptLearningBuildResult:
236
- """Build the job payload from config."""
484
+ """Build the job payload from config.
485
+
486
+ Supports both file-based (config_path) and programmatic (config_dict) modes.
487
+ Both modes route through the same PromptLearningConfig Pydantic validation.
488
+ """
237
489
  if self._build_result is None:
238
- if not self.config.config_path.exists() or self.config.config_path.name == "/dev/null":
239
- raise RuntimeError(
240
- "Cannot build payload: config_path is required for new jobs. "
241
- "Use from_job_id() to resume an existing job."
242
- )
243
-
244
490
  overrides = self.config.overrides or {}
245
491
  overrides["backend"] = self.config.backend_url
246
-
247
- self._build_result = build_prompt_learning_payload(
248
- config_path=self.config.config_path,
249
- task_url=None, # Force using TOML only
250
- overrides=overrides,
251
- allow_experimental=self.config.allow_experimental,
252
- )
492
+ # Pass task_app_api_key to builder via overrides
493
+ if self.config.task_app_api_key:
494
+ overrides["task_app_api_key"] = self.config.task_app_api_key
495
+
496
+ # Route to appropriate builder based on config mode
497
+ if self.config.config_dict is not None:
498
+ # Programmatic mode: use dict-based builder
499
+ self._build_result = build_prompt_learning_payload_from_mapping(
500
+ raw_config=self.config.config_dict,
501
+ task_url=None,
502
+ overrides=overrides,
503
+ allow_experimental=self.config.allow_experimental,
504
+ source_label="PromptLearningJob.from_dict",
505
+ )
506
+ elif self.config.config_path is not None:
507
+ # File-based mode: use path-based builder
508
+ if not self.config.config_path.exists():
509
+ raise RuntimeError(
510
+ f"Config file not found: {self.config.config_path}. "
511
+ "Use from_dict() for programmatic config or from_job_id() to resume."
512
+ )
513
+
514
+ self._build_result = build_prompt_learning_payload(
515
+ config_path=self.config.config_path,
516
+ task_url=None,
517
+ overrides=overrides,
518
+ allow_experimental=self.config.allow_experimental,
519
+ )
520
+ else:
521
+ raise RuntimeError(
522
+ "Cannot build payload: either config_path or config_dict is required. "
523
+ "Use from_config() for file-based config, from_dict() for programmatic config, "
524
+ "or from_job_id() to resume an existing job."
525
+ )
253
526
  return self._build_result
254
527
 
255
528
  def submit(self) -> str:
@@ -262,7 +535,11 @@ class PromptLearningJob:
262
535
  RuntimeError: If job submission fails
263
536
  ValueError: If task app health check fails
264
537
  """
265
- ctx: Dict[str, Any] = {"config_path": str(self.config.config_path)}
538
+ # Log context based on config mode
539
+ if self.config.config_path is not None:
540
+ ctx: Dict[str, Any] = {"config_path": str(self.config.config_path)}
541
+ else:
542
+ ctx = {"config_mode": "programmatic"}
266
543
  log_info("PromptLearningJob.submit invoked", ctx=ctx)
267
544
  if self._job_id:
268
545
  raise RuntimeError(f"Job already submitted: {self._job_id}")
@@ -271,7 +548,7 @@ class PromptLearningJob:
271
548
 
272
549
  # Health check (skip if _skip_health_check is set - useful for tunnels with DNS delay)
273
550
  if not self._skip_health_check:
274
- health = check_task_app_health(build.task_url, self.config.task_app_api_key or "")
551
+ health = check_local_api_health(build.task_url, self.config.task_app_api_key or "")
275
552
  if not health.ok:
276
553
  raise ValueError(f"Task app health check failed: {health.detail}")
277
554
 
@@ -351,40 +628,92 @@ class PromptLearningJob:
351
628
  *,
352
629
  timeout: float = 3600.0,
353
630
  interval: float = 5.0,
631
+ progress: bool = False,
354
632
  on_status: Optional[Callable[[Dict[str, Any]], None]] = None,
355
- ) -> Dict[str, Any]:
633
+ ) -> PromptLearningResult:
356
634
  """Poll job until it reaches a terminal state.
357
-
635
+
358
636
  Args:
359
637
  timeout: Maximum seconds to wait
360
638
  interval: Seconds between poll attempts
361
- on_status: Optional callback called on each status update
362
-
639
+ progress: If True, print status updates during polling (useful for notebooks)
640
+ on_status: Optional callback called on each status update (for custom progress handling)
641
+
363
642
  Returns:
364
- Final job status dictionary
365
-
643
+ PromptLearningResult with typed status, best_score, etc.
644
+
366
645
  Raises:
367
646
  RuntimeError: If job hasn't been submitted yet
368
647
  TimeoutError: If timeout is exceeded
648
+
649
+ Example:
650
+ >>> result = job.poll_until_complete(progress=True)
651
+ [00:15] running | score: 0.72
652
+ [00:30] running | score: 0.78
653
+ [00:45] succeeded | score: 0.85
654
+ >>> result.succeeded
655
+ True
656
+ >>> result.best_score
657
+ 0.85
369
658
  """
370
659
  if not self._job_id:
371
660
  raise RuntimeError("Job not yet submitted. Call submit() first.")
372
-
373
- poller = PromptLearningJobPoller(
374
- base_url=self.config.backend_url,
375
- api_key=self.config.api_key,
376
- interval=interval,
377
- timeout=timeout,
378
- )
379
-
380
- outcome = poller.poll_job(self._job_id) # type: ignore[arg-type] # We check None above
381
-
382
- payload = dict(outcome.payload) if isinstance(outcome.payload, dict) else {}
383
-
384
- if on_status:
385
- on_status(payload)
386
-
387
- return payload
661
+
662
+ job_id = self._job_id
663
+ base_url = ensure_api_base(self.config.backend_url)
664
+ headers = {
665
+ "Authorization": f"Bearer {self.config.api_key}",
666
+ "Content-Type": "application/json",
667
+ }
668
+
669
+ start_time = time.time()
670
+ elapsed = 0.0
671
+ last_data: Dict[str, Any] = {}
672
+
673
+ while elapsed <= timeout:
674
+ try:
675
+ # Fetch job status
676
+ url = f"{base_url}/prompt-learning/online/jobs/{job_id}"
677
+ resp = http_get(url, headers=headers)
678
+ data = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
679
+ last_data = dict(data) if isinstance(data, dict) else {}
680
+
681
+ status = JobStatus.from_string(last_data.get("status", "pending"))
682
+ best_score = (
683
+ last_data.get("best_score")
684
+ or last_data.get("best_reward")
685
+ or last_data.get("best_train_score")
686
+ or last_data.get("best_train_reward")
687
+ )
688
+
689
+ # Progress output
690
+ if progress:
691
+ mins, secs = divmod(int(elapsed), 60)
692
+ score_str = f"score: {best_score:.2f}" if best_score is not None else "score: --"
693
+ print(f"[{mins:02d}:{secs:02d}] {status.value} | {score_str}")
694
+
695
+ # Callback for custom handling
696
+ if on_status:
697
+ on_status(last_data)
698
+
699
+ # Check terminal state
700
+ if status.is_terminal:
701
+ return PromptLearningResult.from_response(job_id, last_data)
702
+
703
+ except Exception as exc:
704
+ if progress:
705
+ print(f"[poll] error: {exc}")
706
+ log_info("poll request failed", ctx={"error": str(exc), "job_id": job_id})
707
+
708
+ time.sleep(interval)
709
+ elapsed = time.time() - start_time
710
+
711
+ # Timeout reached
712
+ if progress:
713
+ print(f"[poll] timeout after {timeout:.0f}s")
714
+
715
+ # Return with whatever data we have, status will indicate not complete
716
+ return PromptLearningResult.from_response(job_id, last_data)
388
717
 
389
718
  def get_results(self) -> Dict[str, Any]:
390
719
  """Get job results (prompts, scores, etc.).
@@ -463,8 +792,9 @@ class PromptLearningJob:
463
792
 
464
793
 
465
794
  __all__ = [
795
+ "JobStatus",
466
796
  "PromptLearningJob",
467
797
  "PromptLearningJobConfig",
468
798
  "PromptLearningJobPoller",
799
+ "PromptLearningResult",
469
800
  ]
470
-
@@ -1,5 +1,7 @@
1
1
  """First-class SDK API for reinforcement learning (RL/GSPO).
2
2
 
3
+ **Status:** Experimental
4
+
3
5
  This module provides high-level abstractions for running RL training jobs
4
6
  both via CLI (`uvx synth-ai train --type rl`) and programmatically in Python scripts.
5
7
 
@@ -32,14 +34,49 @@ from synth_ai.core.telemetry import log_info
32
34
 
33
35
  from .builders import RLBuildResult, build_rl_payload
34
36
  from .pollers import RLJobPoller
35
- from .task_app import check_task_app_health
37
+ from .local_api import check_local_api_health
36
38
  from .utils import ensure_api_base, http_post
37
39
 
38
40
 
39
41
  @dataclass
40
42
  class RLJobConfig:
41
- """Configuration for an RL training job."""
42
-
43
+ """Configuration for an RL training job.
44
+
45
+ This dataclass holds all the configuration needed to submit and run
46
+ a reinforcement learning training job (GSPO, GRPO, PPO, etc.).
47
+
48
+ Attributes:
49
+ config_path: Path to the TOML configuration file that defines the
50
+ RL training task, including model settings, training hyperparameters,
51
+ reward configuration, and Local API URL.
52
+ backend_url: Base URL of the Synth API backend (e.g.,
53
+ "https://api.usesynth.ai"). Can also be set via BACKEND_BASE_URL
54
+ environment variable.
55
+ api_key: Synth API key for authentication. Can also be set via
56
+ SYNTH_API_KEY environment variable.
57
+ task_app_url: URL of the Local API that serves rollout environments.
58
+ Can be set via TASK_APP_URL env var if not provided.
59
+ (Alias: also known as "task app URL" in older documentation)
60
+ task_app_api_key: API key for authenticating with the Local API.
61
+ Defaults to ENVIRONMENT_API_KEY env var if not provided.
62
+ (Alias: also known as "task app API key" in older documentation)
63
+ allow_experimental: If True, allows use of experimental models and
64
+ features. Defaults to None (uses config file setting).
65
+ overrides: Dictionary of config overrides that take precedence over
66
+ values in the TOML file. Useful for programmatic customization.
67
+ idempotency_key: Optional key for idempotent job submission. If provided,
68
+ submitting the same key twice will return the existing job instead
69
+ of creating a new one.
70
+
71
+ Example:
72
+ >>> config = RLJobConfig(
73
+ ... config_path=Path("rl_config.toml"),
74
+ ... backend_url="https://api.usesynth.ai",
75
+ ... api_key="sk_live_...",
76
+ ... task_app_url="https://my-task-app.example.com",
77
+ ... )
78
+ """
79
+
43
80
  config_path: Path
44
81
  backend_url: str
45
82
  api_key: str
@@ -282,7 +319,7 @@ class RLJob:
282
319
  # Health check (skip if _skip_health_check is set - useful for tunnels with DNS delay)
283
320
  if not self._skip_health_check:
284
321
  task_app_key = self.config.task_app_api_key or ""
285
- health = check_task_app_health(build.task_url, task_app_key)
322
+ health = check_local_api_health(build.task_url, task_app_key)
286
323
  if not health.ok:
287
324
  raise ValueError(f"Task app health check failed: {health.detail}")
288
325
 
@@ -439,4 +476,3 @@ __all__ = [
439
476
  "RLJob",
440
477
  "RLJobConfig",
441
478
  ]
442
-
@@ -1,5 +1,7 @@
1
1
  """First-class SDK API for SFT (Supervised Fine-Tuning).
2
2
 
3
+ **Status:** Experimental
4
+
3
5
  This module provides high-level abstractions for running SFT jobs
4
6
  both via CLI (`uvx synth-ai train`) and programmatically in Python scripts.
5
7
 
@@ -21,6 +21,11 @@ class TaskAppHealth:
21
21
  detail: str | None = None
22
22
 
23
23
 
24
+ @dataclass(slots=True)
25
+ class LocalAPIHealth(TaskAppHealth):
26
+ """Alias for TaskAppHealth with LocalAPI naming."""
27
+
28
+
24
29
  def _resolve_hostname_with_explicit_resolvers(hostname: str) -> str:
25
30
  """
26
31
  Resolve hostname using explicit resolvers (1.1.1.1, 8.8.8.8) first,
@@ -245,6 +250,19 @@ def check_task_app_health(base_url: str, api_key: str, *, timeout: float = 10.0,
245
250
  )
246
251
 
247
252
 
253
+ def check_local_api_health(
254
+ base_url: str, api_key: str, *, timeout: float = 10.0, max_retries: int = 5
255
+ ) -> LocalAPIHealth:
256
+ """Alias for check_task_app_health with LocalAPI naming."""
257
+ health = check_task_app_health(base_url, api_key, timeout=timeout, max_retries=max_retries)
258
+ return LocalAPIHealth(
259
+ ok=health.ok,
260
+ health_status=health.health_status,
261
+ task_info_status=health.task_info_status,
262
+ detail=health.detail,
263
+ )
264
+
265
+
248
266
  @dataclass(slots=True)
249
267
  class ModalSecret:
250
268
  name: str
@@ -323,9 +341,11 @@ __all__ = [
323
341
  "ModalApp",
324
342
  "ModalSecret",
325
343
  "check_task_app_health",
344
+ "check_local_api_health",
326
345
  "format_modal_apps",
327
346
  "format_modal_secrets",
328
347
  "get_modal_secret_value",
329
348
  "list_modal_apps",
330
349
  "list_modal_secrets",
350
+ "LocalAPIHealth",
331
351
  ]