mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic. Click here for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,459 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import hashlib
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import uuid
|
|
10
|
+
from typing import Any, Dict, List, Mapping, Optional, Sequence, TypedDict
|
|
11
|
+
|
|
12
|
+
from mantisdk.algorithm.gepa.lib.core.adapter import EvaluationBatch, GEPAAdapter
|
|
13
|
+
|
|
14
|
+
from mantisdk.adapter import TraceAdapter
|
|
15
|
+
from mantisdk.reward import find_final_reward, get_rewards_from_span, find_reward_spans
|
|
16
|
+
from mantisdk.store.base import LightningStore
|
|
17
|
+
from mantisdk.types import NamedResources, PromptTemplate, Rollout, Span, TracingConfig
|
|
18
|
+
|
|
19
|
+
from .tracing import GEPATracingContext
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Data instance for Mantisdk, expressed via the functional TypedDict form.
# ``input`` carries the raw task payload passed to ``enqueue_rollout``;
# ``id`` is an optional stable identifier used for validation-batch detection.
MantisdkDataInst = TypedDict(
    "MantisdkDataInst",
    {
        "input": Dict[str, Any],
        "id": Optional[str],
    },
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MantisdkTrajectory(TypedDict):
    """Trajectory for Mantisdk.

    Captures the data needed for reflection:

    - original_input: The task input data
    - assistant_response: The final LLM response
    - feedback: Evaluation feedback including score and status
    - spans: Raw spans for detailed analysis
    """

    rollout_id: str  # LightningStore rollout identifier returned by enqueue_rollout
    original_input: Dict[str, Any]  # The task input
    assistant_response: str  # The final LLM response
    feedback: str  # Feedback including score and error info
    spans: List[Span]  # Raw spans queried for this rollout (detailed analysis)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Rollout output for Mantisdk, expressed via the functional TypedDict form.
MantisdkRolloutOutput = TypedDict(
    "MantisdkRolloutOutput",
    {
        # Scalar reward extracted from the rollout's spans (defaulted to 0.0
        # upstream when no final reward is found).
        "final_reward": Optional[float],
        # Terminal rollout status string (e.g. "succeeded" / "failed").
        "status": str,
    },
)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# GEPA consumes reflective records keyed by these exact, space-containing
# strings, so the functional TypedDict form is required (class syntax cannot
# declare a field named "Generated Outputs").
_REFLECTIVE_RECORD_FIELDS = {
    "Inputs": str,
    "Generated Outputs": str,
    "Feedback": str,
}
MantisdkReflectiveRecord = TypedDict(
    "MantisdkReflectiveRecord", _REFLECTIVE_RECORD_FIELDS
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class MantisdkGEPAAdapter(
    GEPAAdapter[MantisdkDataInst, MantisdkTrajectory, MantisdkRolloutOutput]
):
    """Adapter to bridge GEPA with Mantisdk.

    This adapter:
    1. Evaluates candidates by creating rollouts in the LightningStore
    2. Waits for runners to execute the agent
    3. Collects spans and rewards
    4. Builds reflective datasets for GEPA's reflection mechanism
    """

    def __init__(
        self,
        store: LightningStore,
        loop: asyncio.AbstractEventLoop,
        resource_name: str,
        adapter: TraceAdapter,
        llm_proxy_resource: Any = None,
        rollout_batch_timeout: float = 600.0,
        tracing_config: Optional[TracingConfig] = None,
        tracing_context: Optional[GEPATracingContext] = None,
    ) -> None:
        """Initialize the adapter.

        Args:
            store: The LightningStore instance.
            loop: The asyncio loop where the store and other async components are running.
            resource_name: The name of the resource to update (e.g., "prompt_template").
            adapter: The TraceAdapter to convert spans to messages (optional, for debugging).
            llm_proxy_resource: LLM resource from the proxy to include in resources.
            rollout_batch_timeout: Timeout for waiting for rollouts in seconds.
            tracing_config: Optional tracing configuration from the algorithm.
            tracing_context: Optional GEPA tracing context for detailed execution tracking.
        """
        self.store = store
        self.loop = loop
        # NOTE(review): resource_name is stored but never read by any method in
        # this file — presumably consumed by external callers; TODO confirm.
        self.resource_name = resource_name
        # NOTE(review): adapter is stored but never read here either; docstring
        # says it exists "for debugging".
        self.adapter = adapter
        self.llm_proxy_resource = llm_proxy_resource
        self.rollout_batch_timeout = rollout_batch_timeout
        self.tracing_config = tracing_config
        self.tracing_context = tracing_context
|
|
109
|
+
|
|
110
|
+
def evaluate(
|
|
111
|
+
self,
|
|
112
|
+
batch: List[MantisdkDataInst],
|
|
113
|
+
candidate: Dict[str, str],
|
|
114
|
+
capture_traces: bool = False,
|
|
115
|
+
) -> EvaluationBatch[MantisdkTrajectory, MantisdkRolloutOutput]:
|
|
116
|
+
"""Evaluate a candidate on a batch of data instances.
|
|
117
|
+
|
|
118
|
+
This method bridges the synchronous GEPA call to the asynchronous Mantisdk execution.
|
|
119
|
+
"""
|
|
120
|
+
# Run the async evaluation in the main loop and wait for result
|
|
121
|
+
future = asyncio.run_coroutine_threadsafe(
|
|
122
|
+
self._evaluate_async(batch, candidate, capture_traces), self.loop
|
|
123
|
+
)
|
|
124
|
+
return future.result(timeout=self.rollout_batch_timeout + 60) # Add buffer to timeout
|
|
125
|
+
|
|
126
|
+
    async def _evaluate_async(
        self,
        batch: List[MantisdkDataInst],
        candidate: Dict[str, str],
        capture_traces: bool = False,
    ) -> EvaluationBatch[MantisdkTrajectory, MantisdkRolloutOutput]:
        """Asynchronous evaluation logic.

        Publishes the candidate as a new resource version, enqueues one rollout
        per batch item, waits for runners to finish, then folds the resulting
        spans/rewards into GEPA's ``EvaluationBatch``. If waiting fails, every
        item is reported as a zero-score failure instead of raising.
        """
        # 1. Build resources from candidate: each candidate component becomes
        # an f-string PromptTemplate under the same key.
        resources: NamedResources = {}
        for key, value in candidate.items():
            resources[key] = PromptTemplate(template=value, engine="f-string")

        # Add the LLM resource if available
        if self.llm_proxy_resource is not None:
            resources["llm"] = self.llm_proxy_resource

        # Create a unique resource version for this evaluation
        candidate_version = f"gepa-{uuid.uuid4().hex[:8]}"
        await self.store.update_resources(candidate_version, resources)

        # 2. Enqueue rollouts for each item in the batch
        rollout_ids = []
        batch_inputs: Dict[str, Dict[str, Any]] = {}  # Map rollout_id -> original input

        # Prepare metadata from tracing config with detailed context
        metadata = None
        if self.tracing_config and self.tracing_context:
            # Generate candidate hash for tracking (fingerprint only — md5 is
            # not used here for any security purpose).
            candidate_hash = hashlib.md5(str(candidate).encode()).hexdigest()[:8]
            self.tracing_context.set_candidate(candidate_hash)

            # Get batch item IDs for validation detection.
            # NOTE(review): .get("id", str(i)) yields None when the key exists
            # with value None — confirm is_validation_batch tolerates that.
            item_ids = [item.get("id", str(i)) for i, item in enumerate(batch)]

            # Detect if this is a validation batch
            if self.tracing_context.is_validation_batch(item_ids):
                self.tracing_context.set_phase("validation-eval")
            else:
                # Register training items on first batch of generation
                if self.tracing_context.batch_count == 0:
                    self.tracing_context.register_training_items(item_ids)
                self.tracing_context.set_phase("train-eval")

            # Get batch ID and build GEPA-specific tags
            batch_id = self.tracing_context.next_batch()
            gepa_tags = [
                f"gen-{self.tracing_context.generation}",
                f"candidate-{candidate_hash}",
                batch_id,
            ]
            metadata = self.tracing_config.to_detailed_metadata(
                phase=self.tracing_context.phase,
                extra_tags=gepa_tags,
            )
            logger.debug(
                f"Batch evaluation: session={self.tracing_context.session_id}, "
                f"phase={self.tracing_context.phase}, "
                f"gen={self.tracing_context.generation}, candidate={candidate_hash}, {batch_id}"
            )
        elif self.tracing_config:
            # Fallback to simple metadata if no context
            metadata = self.tracing_config.to_metadata("train")

        for item in batch:
            task_input = item["input"]
            res = await self.store.enqueue_rollout(
                input=task_input,
                mode="train",
                resources_id=candidate_version,
                metadata=metadata,
            )
            rollout_ids.append(res.rollout_id)
            batch_inputs[res.rollout_id] = task_input

        # 3. Wait for completion
        try:
            completed_rollouts = await self.store.wait_for_rollouts(
                rollout_ids=rollout_ids, timeout=self.rollout_batch_timeout
            )
        except Exception as e:
            logger.error(f"Error waiting for rollouts: {e}")
            # Return failure results for all items so GEPA can continue
            # instead of crashing the optimization run.
            outputs: List[MantisdkRolloutOutput] = []
            trajectories: Optional[List[MantisdkTrajectory]] = [] if capture_traces else None
            scores: List[float] = []

            for rollout_id in rollout_ids:
                outputs.append({"final_reward": 0.0, "status": "failed"})
                scores.append(0.0)
                if capture_traces:
                    trajectories.append({
                        "rollout_id": rollout_id,
                        "original_input": batch_inputs.get(rollout_id, {}),
                        "assistant_response": "",
                        "feedback": f"Rollout failed with error: {e}",
                        "spans": [],
                    })

            return EvaluationBatch(
                outputs=outputs,
                trajectories=trajectories,
                scores=scores,
            )

        # 4. Collect results
        outputs: List[MantisdkRolloutOutput] = []
        trajectories: Optional[List[MantisdkTrajectory]] = [] if capture_traces else None
        scores: List[float] = []

        for rollout in completed_rollouts:
            # Query spans for this rollout
            spans = await self.store.query_spans(rollout.rollout_id)

            # Find final reward; missing rewards count as 0.0 so the score
            # list stays aligned with the batch.
            final_reward_val = find_final_reward(spans)
            if final_reward_val is None:
                final_reward_val = 0.0

            outputs.append({
                "final_reward": final_reward_val,
                "status": rollout.status,
            })
            scores.append(final_reward_val)

            if capture_traces:
                # Extract assistant response from spans
                assistant_response = self._extract_assistant_response(spans)

                # Build feedback string
                feedback = self._build_feedback(
                    final_reward_val,
                    rollout.status,
                    spans,
                    batch_inputs.get(rollout.rollout_id, {})
                )

                trajectories.append({
                    "rollout_id": rollout.rollout_id,
                    "original_input": batch_inputs.get(rollout.rollout_id, {}),
                    "assistant_response": assistant_response,
                    "feedback": feedback,
                    "spans": spans,
                })

        return EvaluationBatch(
            outputs=outputs,
            trajectories=trajectories,
            scores=scores,
        )
|
|
275
|
+
|
|
276
|
+
def _extract_assistant_response(self, spans: List[Span]) -> str:
|
|
277
|
+
"""Extract the final assistant response from spans."""
|
|
278
|
+
responses = []
|
|
279
|
+
|
|
280
|
+
for span in spans:
|
|
281
|
+
attrs = span.attributes if hasattr(span, 'attributes') else {}
|
|
282
|
+
name = span.name if hasattr(span, 'name') else ""
|
|
283
|
+
|
|
284
|
+
# Look for LLM completion spans
|
|
285
|
+
if "chat" in name.lower() or "completion" in name.lower() or "llm" in name.lower():
|
|
286
|
+
# Try different attribute keys for response content
|
|
287
|
+
response_keys = [
|
|
288
|
+
"gen_ai.completion.0.content",
|
|
289
|
+
"gen_ai.response.content",
|
|
290
|
+
"response.content",
|
|
291
|
+
"completion",
|
|
292
|
+
]
|
|
293
|
+
for key in response_keys:
|
|
294
|
+
if key in attrs and attrs[key]:
|
|
295
|
+
responses.append(str(attrs[key]))
|
|
296
|
+
break
|
|
297
|
+
|
|
298
|
+
if responses:
|
|
299
|
+
return responses[-1] # Return the last (final) response
|
|
300
|
+
return "No assistant response found in spans"
|
|
301
|
+
|
|
302
|
+
def _build_feedback(
|
|
303
|
+
self,
|
|
304
|
+
reward: float,
|
|
305
|
+
status: str,
|
|
306
|
+
spans: List[Span],
|
|
307
|
+
task_input: Dict[str, Any]
|
|
308
|
+
) -> str:
|
|
309
|
+
"""Build detailed feedback for reflection.
|
|
310
|
+
|
|
311
|
+
For evaluation tasks, extracts human_score and llm_score from reward dimensions
|
|
312
|
+
to provide specific feedback about prediction errors.
|
|
313
|
+
"""
|
|
314
|
+
parts = []
|
|
315
|
+
|
|
316
|
+
# Try to extract multi-dimensional reward data (human_score, llm_score)
|
|
317
|
+
human_score = None
|
|
318
|
+
llm_score = None
|
|
319
|
+
|
|
320
|
+
reward_spans = find_reward_spans(spans)
|
|
321
|
+
for span in reward_spans:
|
|
322
|
+
rewards = get_rewards_from_span(span)
|
|
323
|
+
for r in rewards:
|
|
324
|
+
if r.name == "human_score":
|
|
325
|
+
human_score = r.value
|
|
326
|
+
elif r.name == "llm_score":
|
|
327
|
+
llm_score = r.value
|
|
328
|
+
|
|
329
|
+
# If we have both scores, provide detailed evaluation feedback
|
|
330
|
+
if human_score is not None and llm_score is not None:
|
|
331
|
+
diff = llm_score - human_score
|
|
332
|
+
if abs(diff) < 0.1:
|
|
333
|
+
parts.append(f"Good prediction: LLM scored {llm_score:.2f}, human scored {human_score:.2f} (close match).")
|
|
334
|
+
elif diff > 0:
|
|
335
|
+
parts.append(f"OVERESTIMATED: LLM scored {llm_score:.2f} but human scored {human_score:.2f}. "
|
|
336
|
+
f"The prompt caused the LLM to rate this {diff:.2f} points too high. "
|
|
337
|
+
"Consider adding stricter criteria or examples.")
|
|
338
|
+
else:
|
|
339
|
+
parts.append(f"UNDERESTIMATED: LLM scored {llm_score:.2f} but human scored {human_score:.2f}. "
|
|
340
|
+
f"The prompt caused the LLM to rate this {abs(diff):.2f} points too low. "
|
|
341
|
+
"Consider broadening the criteria or recognizing subtle details.")
|
|
342
|
+
else:
|
|
343
|
+
# Fallback to generic feedback
|
|
344
|
+
if reward >= 0.9:
|
|
345
|
+
parts.append(f"The task was completed successfully with a high score of {reward:.2f}.")
|
|
346
|
+
elif reward >= 0.5:
|
|
347
|
+
parts.append(f"The task was partially successful with a score of {reward:.2f}. There is room for improvement.")
|
|
348
|
+
else:
|
|
349
|
+
parts.append(f"The task failed or performed poorly with a score of {reward:.2f}. The prompt needs significant improvement.")
|
|
350
|
+
|
|
351
|
+
# Add status info
|
|
352
|
+
if status != "succeeded":
|
|
353
|
+
parts.append(f"Rollout status: {status}.")
|
|
354
|
+
|
|
355
|
+
# Look for error information in spans
|
|
356
|
+
for span in spans:
|
|
357
|
+
attrs = span.attributes if hasattr(span, 'attributes') else {}
|
|
358
|
+
if "error" in str(attrs).lower():
|
|
359
|
+
parts.append(f"Error detected in execution.")
|
|
360
|
+
break
|
|
361
|
+
|
|
362
|
+
# Add task context
|
|
363
|
+
if task_input:
|
|
364
|
+
if "expected_choice" in task_input:
|
|
365
|
+
parts.append(f"Expected answer: {task_input['expected_choice']}")
|
|
366
|
+
# Include humanScore from task if available (for debugging)
|
|
367
|
+
if "humanScore" in task_input:
|
|
368
|
+
parts.append(f"Expected human score: {task_input['humanScore']:.2f}")
|
|
369
|
+
|
|
370
|
+
return " ".join(parts)
|
|
371
|
+
|
|
372
|
+
def make_reflective_dataset(
|
|
373
|
+
self,
|
|
374
|
+
candidate: Dict[str, str],
|
|
375
|
+
eval_batch: EvaluationBatch[MantisdkTrajectory, MantisdkRolloutOutput],
|
|
376
|
+
components_to_update: List[str],
|
|
377
|
+
) -> Mapping[str, Sequence[Mapping[str, Any]]]:
|
|
378
|
+
"""Create a reflective dataset for GEPA's reflection mechanism.
|
|
379
|
+
|
|
380
|
+
Returns data in GEPA's expected format:
|
|
381
|
+
{
|
|
382
|
+
"component_name": [
|
|
383
|
+
{
|
|
384
|
+
"Inputs": str, # Task input description
|
|
385
|
+
"Generated Outputs": str, # Model outputs
|
|
386
|
+
"Feedback": str # Performance feedback
|
|
387
|
+
},
|
|
388
|
+
...
|
|
389
|
+
]
|
|
390
|
+
}
|
|
391
|
+
"""
|
|
392
|
+
reflective_data: Dict[str, List[MantisdkReflectiveRecord]] = {
|
|
393
|
+
component: [] for component in components_to_update
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
trajectories = eval_batch.trajectories
|
|
397
|
+
if trajectories is None:
|
|
398
|
+
logger.warning("No trajectories available for building reflective dataset")
|
|
399
|
+
return reflective_data
|
|
400
|
+
|
|
401
|
+
for i, traj in enumerate(trajectories):
|
|
402
|
+
output = eval_batch.outputs[i]
|
|
403
|
+
|
|
404
|
+
# Format the input as a readable string
|
|
405
|
+
input_str = self._format_input(traj["original_input"])
|
|
406
|
+
|
|
407
|
+
# Get the assistant response
|
|
408
|
+
generated_output = traj["assistant_response"]
|
|
409
|
+
|
|
410
|
+
# Get the feedback
|
|
411
|
+
feedback = traj["feedback"]
|
|
412
|
+
|
|
413
|
+
record: MantisdkReflectiveRecord = {
|
|
414
|
+
"Inputs": input_str,
|
|
415
|
+
"Generated Outputs": generated_output,
|
|
416
|
+
"Feedback": feedback,
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
# Add to each component being updated
|
|
420
|
+
for component in components_to_update:
|
|
421
|
+
reflective_data[component].append(record)
|
|
422
|
+
|
|
423
|
+
# Validate we have data
|
|
424
|
+
for component in components_to_update:
|
|
425
|
+
if len(reflective_data[component]) == 0:
|
|
426
|
+
logger.warning(f"No reflective records for component {component}")
|
|
427
|
+
|
|
428
|
+
return reflective_data
|
|
429
|
+
|
|
430
|
+
def _format_input(self, task_input: Dict[str, Any]) -> str:
|
|
431
|
+
"""Format task input as a readable string for reflection."""
|
|
432
|
+
if not task_input:
|
|
433
|
+
return "No input available"
|
|
434
|
+
|
|
435
|
+
# If it's a RoomSelectionTask-like structure
|
|
436
|
+
if "task_input" in task_input:
|
|
437
|
+
inner = task_input["task_input"]
|
|
438
|
+
parts = []
|
|
439
|
+
if "date" in inner:
|
|
440
|
+
parts.append(f"Date: {inner['date']}")
|
|
441
|
+
if "time" in inner:
|
|
442
|
+
parts.append(f"Time: {inner['time']}")
|
|
443
|
+
if "duration_min" in inner:
|
|
444
|
+
parts.append(f"Duration: {inner['duration_min']} minutes")
|
|
445
|
+
if "attendees" in inner:
|
|
446
|
+
parts.append(f"Attendees: {inner['attendees']}")
|
|
447
|
+
if "needs" in inner:
|
|
448
|
+
parts.append(f"Needs: {', '.join(inner['needs']) if inner['needs'] else 'none'}")
|
|
449
|
+
if "accessible_required" in inner:
|
|
450
|
+
parts.append(f"Accessible required: {inner['accessible_required']}")
|
|
451
|
+
return "; ".join(parts) if parts else json.dumps(task_input)
|
|
452
|
+
|
|
453
|
+
# For other dict structures, try to format nicely
|
|
454
|
+
try:
|
|
455
|
+
if len(task_input) <= 5:
|
|
456
|
+
return "; ".join(f"{k}: {v}" for k, v in task_input.items())
|
|
457
|
+
return json.dumps(task_input, indent=2)
|
|
458
|
+
except Exception:
|
|
459
|
+
return str(task_input)
|