synth-ai 0.2.3__py3-none-any.whl → 0.2.4.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/compound/cais.py +0 -0
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +115 -1
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth.py +3 -3
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/test_crafter_react_agent_lm_synth_v2_backup.py +3 -3
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/run_rollouts_for_models_and_compare_v3.py +4 -4
- synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/test_crafter_react_agent_openai_v2_backup.py +3 -3
- synth_ai/environments/examples/crafter_classic/agent_demos/example_v3_usage.py +1 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/crafter_custom/environment.py +1 -1
- synth_ai/environments/service/core_routes.py +1 -1
- synth_ai/learning/prompts/mipro.py +8 -0
- synth_ai/lm/core/main_v3.py +219 -158
- synth_ai/tracing_v3/__init__.py +2 -2
- synth_ai/tracing_v3/abstractions.py +62 -17
- synth_ai/tracing_v3/hooks.py +1 -1
- synth_ai/tracing_v3/llm_call_record_helpers.py +350 -0
- synth_ai/tracing_v3/lm_call_record_abstractions.py +257 -0
- synth_ai/tracing_v3/session_tracer.py +5 -5
- synth_ai/tracing_v3/tests/test_concurrent_operations.py +1 -1
- synth_ai/tracing_v3/tests/test_llm_call_records.py +672 -0
- synth_ai/tracing_v3/tests/test_session_tracer.py +43 -9
- synth_ai/tracing_v3/tests/test_turso_manager.py +1 -1
- synth_ai/tracing_v3/turso/manager.py +10 -3
- synth_ai/tracing_v3/turso/models.py +1 -0
- {synth_ai-0.2.3.dist-info → synth_ai-0.2.4.dev2.dist-info}/METADATA +3 -2
- {synth_ai-0.2.3.dist-info → synth_ai-0.2.4.dev2.dist-info}/RECORD +30 -26
- {synth_ai-0.2.3.dist-info → synth_ai-0.2.4.dev2.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.3.dist-info → synth_ai-0.2.4.dev2.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.3.dist-info → synth_ai-0.2.4.dev2.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.3.dist-info → synth_ai-0.2.4.dev2.dist-info}/top_level.txt +0 -0
File without changes
|
@@ -251,6 +251,111 @@ class FinetuningDataExtractorV3:
|
|
251
251
|
|
252
252
|
return qualifying_sessions
|
253
253
|
|
254
|
+
async def extract_openai_format_from_call_records(self, session_ids: List[str], min_reward: float = 0.0) -> List[Dict[str, Any]]:
|
255
|
+
"""Extract training data in OpenAI format from call_records in LMCAISEvents.
|
256
|
+
|
257
|
+
This is the new method that uses the detailed LLM interaction data stored
|
258
|
+
in call_records instead of relying on separate message records.
|
259
|
+
"""
|
260
|
+
training_data = []
|
261
|
+
|
262
|
+
for session_id in session_ids:
|
263
|
+
# Get LM CAIS events with call_records from the proper column
|
264
|
+
events_query = """
|
265
|
+
SELECT e.call_records, st.turn_number
|
266
|
+
FROM events e
|
267
|
+
LEFT JOIN session_timesteps st ON e.timestep_id = st.id
|
268
|
+
WHERE e.session_id = :session_id
|
269
|
+
AND e.event_type = 'cais'
|
270
|
+
AND e.call_records IS NOT NULL
|
271
|
+
ORDER BY COALESCE(st.turn_number, e.message_time), e.id
|
272
|
+
"""
|
273
|
+
|
274
|
+
events_df = await self.db_manager.query_traces(events_query, {"session_id": session_id})
|
275
|
+
|
276
|
+
if len(events_df) == 0:
|
277
|
+
# Fall back to old method if no call_records
|
278
|
+
continue
|
279
|
+
|
280
|
+
# Extract messages from call_records
|
281
|
+
all_messages = []
|
282
|
+
|
283
|
+
for _, row in events_df.iterrows():
|
284
|
+
call_records_json = row['call_records']
|
285
|
+
if not call_records_json:
|
286
|
+
continue
|
287
|
+
|
288
|
+
# Parse the call_records JSON directly from the column
|
289
|
+
try:
|
290
|
+
import json
|
291
|
+
if isinstance(call_records_json, str):
|
292
|
+
call_records = json.loads(call_records_json)
|
293
|
+
else:
|
294
|
+
call_records = call_records_json
|
295
|
+
|
296
|
+
# Process each call record
|
297
|
+
for record in call_records:
|
298
|
+
# Extract input messages
|
299
|
+
for msg in record.get('input_messages', []):
|
300
|
+
role = msg.get('role', 'user')
|
301
|
+
parts = msg.get('parts', [])
|
302
|
+
|
303
|
+
# Combine text parts
|
304
|
+
text_content = []
|
305
|
+
for part in parts:
|
306
|
+
if part.get('type') == 'text' and part.get('text'):
|
307
|
+
text_content.append(part['text'])
|
308
|
+
|
309
|
+
if text_content:
|
310
|
+
content = ' '.join(text_content)
|
311
|
+
if role == 'system' and not any(m['role'] == 'system' for m in all_messages):
|
312
|
+
all_messages.insert(0, {"role": "system", "content": content})
|
313
|
+
elif role != 'system':
|
314
|
+
all_messages.append({"role": role, "content": content})
|
315
|
+
|
316
|
+
# Extract output messages
|
317
|
+
for msg in record.get('output_messages', []):
|
318
|
+
role = msg.get('role', 'assistant')
|
319
|
+
parts = msg.get('parts', [])
|
320
|
+
|
321
|
+
# Combine text parts
|
322
|
+
text_content = []
|
323
|
+
for part in parts:
|
324
|
+
if part.get('type') == 'text' and part.get('text'):
|
325
|
+
text_content.append(part['text'])
|
326
|
+
|
327
|
+
if text_content:
|
328
|
+
content = ' '.join(text_content)
|
329
|
+
all_messages.append({"role": role, "content": content})
|
330
|
+
|
331
|
+
except Exception as e:
|
332
|
+
print(f"Error parsing call_records for session {session_id}: {e}")
|
333
|
+
continue
|
334
|
+
|
335
|
+
# Only include if we have a complete conversation
|
336
|
+
if len(all_messages) > 1:
|
337
|
+
# Get total reward for this session
|
338
|
+
reward_query = """
|
339
|
+
SELECT COALESCE(SUM(reward), 0) as total_reward
|
340
|
+
FROM events
|
341
|
+
WHERE session_id = :session_id
|
342
|
+
AND event_type = 'environment'
|
343
|
+
AND reward IS NOT NULL
|
344
|
+
"""
|
345
|
+
reward_df = await self.db_manager.query_traces(reward_query, {"session_id": session_id})
|
346
|
+
total_reward = reward_df.iloc[0]['total_reward'] if len(reward_df) > 0 else 0
|
347
|
+
|
348
|
+
training_data.append({
|
349
|
+
"messages": all_messages,
|
350
|
+
"metadata": {
|
351
|
+
"session_id": session_id,
|
352
|
+
"total_reward": float(total_reward),
|
353
|
+
"source": "call_records" # Mark that this came from call_records
|
354
|
+
}
|
355
|
+
})
|
356
|
+
|
357
|
+
return training_data
|
358
|
+
|
254
359
|
async def extract_openai_format(self, session_ids: List[str], min_reward: float = 0.0) -> List[Dict[str, Any]]:
|
255
360
|
"""Extract training data in OpenAI format from filtered sessions."""
|
256
361
|
training_data = []
|
@@ -440,10 +545,19 @@ async def filter_traces_from_turso(
|
|
440
545
|
|
441
546
|
# Extract training data
|
442
547
|
if mode == "trajectory":
|
443
|
-
|
548
|
+
# Try new method first (using call_records)
|
549
|
+
training_data = await extractor.extract_openai_format_from_call_records(
|
444
550
|
session_ids=filtered_sessions,
|
445
551
|
min_reward=min_reward
|
446
552
|
)
|
553
|
+
|
554
|
+
# If no data from call_records, fall back to old method
|
555
|
+
if not training_data:
|
556
|
+
print("No call_records found, falling back to message-based extraction...")
|
557
|
+
training_data = await extractor.extract_openai_format(
|
558
|
+
session_ids=filtered_sessions,
|
559
|
+
min_reward=min_reward
|
560
|
+
)
|
447
561
|
else: # window mode
|
448
562
|
# For window mode, we need to implement window extraction
|
449
563
|
# For now, use trajectory mode
|
@@ -70,7 +70,7 @@ from synth_ai.lm.config import SynthConfig
|
|
70
70
|
# Import session tracer for v3 tracing
|
71
71
|
from synth_ai.tracing_v3 import SessionTracer
|
72
72
|
from synth_ai.tracing_v3.abstractions import (
|
73
|
-
|
73
|
+
SessionEventMarkovBlanketMessage, TimeRecord,
|
74
74
|
RuntimeEvent, EnvironmentEvent, LMCAISEvent
|
75
75
|
)
|
76
76
|
# create_experiment_context will be defined as a helper function below
|
@@ -255,7 +255,7 @@ async def retry_http_request(client: AsyncClient, method: str, url: str, **kwarg
|
|
255
255
|
raise last_exception
|
256
256
|
|
257
257
|
|
258
|
-
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) ->
|
258
|
+
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) -> SessionEventMarkovBlanketMessage:
|
259
259
|
"""Create a message with origin system ID embedded in content."""
|
260
260
|
# Map custom message types to valid v3 message types
|
261
261
|
type_mapping = {
|
@@ -267,7 +267,7 @@ def create_message(content: Any, message_type: str, origin_system_id: Any, turn:
|
|
267
267
|
"tool_result": "tool_result"
|
268
268
|
}
|
269
269
|
|
270
|
-
return
|
270
|
+
return SessionEventMarkovBlanketMessage(
|
271
271
|
content=json.dumps({
|
272
272
|
"origin_system_id": str(origin_system_id),
|
273
273
|
"payload": content
|
@@ -67,7 +67,7 @@ from synth_ai.lm.config import SynthConfig
|
|
67
67
|
|
68
68
|
# Import session tracer for v2 tracing
|
69
69
|
from synth_ai.tracing_v2.session_tracer import (
|
70
|
-
SessionTracer,
|
70
|
+
SessionTracer, SessionEventMarkovBlanketMessage, TimeRecord,
|
71
71
|
RuntimeEvent, EnvironmentEvent, LMCAISEvent
|
72
72
|
)
|
73
73
|
from synth_ai.tracing_v2.utils import create_experiment_context
|
@@ -175,9 +175,9 @@ async def retry_http_request(client: AsyncClient, method: str, url: str, **kwarg
|
|
175
175
|
raise last_exception
|
176
176
|
|
177
177
|
|
178
|
-
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) ->
|
178
|
+
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) -> SessionEventMarkovBlanketMessage:
|
179
179
|
"""Create a message with origin system ID embedded in content."""
|
180
|
-
return
|
180
|
+
return SessionEventMarkovBlanketMessage(
|
181
181
|
content={
|
182
182
|
"origin_system_id": str(origin_system_id),
|
183
183
|
"payload": content
|
@@ -52,7 +52,7 @@ from synth_ai.lm.core.main_v3 import LM
|
|
52
52
|
# Import session tracer for v3 tracing
|
53
53
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
54
54
|
from synth_ai.tracing_v3.abstractions import (
|
55
|
-
|
55
|
+
SessionEventMarkovBlanketMessage, TimeRecord,
|
56
56
|
RuntimeEvent, EnvironmentEvent
|
57
57
|
)
|
58
58
|
# from synth_ai.tracing_v3.utils import create_experiment_context # Not needed
|
@@ -176,9 +176,9 @@ def compress_observation_for_trace(obs: dict[str, Any]) -> str:
|
|
176
176
|
return f"{{\"error\": \"{str(e)}\"}}"
|
177
177
|
|
178
178
|
|
179
|
-
def create_message(content: str, message_type: str, system_id: str, turn: int) ->
|
180
|
-
"""Create a
|
181
|
-
return
|
179
|
+
def create_message(content: str, message_type: str, system_id: str, turn: int) -> SessionEventMarkovBlanketMessage:
|
180
|
+
"""Create a SessionEventMarkovBlanketMessage with metadata."""
|
181
|
+
return SessionEventMarkovBlanketMessage(
|
182
182
|
content=content,
|
183
183
|
message_type=message_type,
|
184
184
|
metadata={"system_id": system_id, "turn": turn},
|
@@ -71,7 +71,7 @@ import numpy as np
|
|
71
71
|
|
72
72
|
# Import session tracer for CAIS event capture
|
73
73
|
from synth_ai.tracing_v2.session_tracer import (
|
74
|
-
SessionTracer,
|
74
|
+
SessionTracer, SessionEventMarkovBlanketMessage, TimeRecord,
|
75
75
|
RuntimeEvent, EnvironmentEvent
|
76
76
|
)
|
77
77
|
from synth_ai.tracing_v2.abstractions import CAISEvent
|
@@ -150,9 +150,9 @@ except ImportError:
|
|
150
150
|
|
151
151
|
|
152
152
|
# Create a proper message structure with origin_system_id
|
153
|
-
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) ->
|
153
|
+
def create_message(content: Any, message_type: str, origin_system_id: Any, turn: int) -> SessionEventMarkovBlanketMessage:
|
154
154
|
"""Create a message with origin system ID embedded in content."""
|
155
|
-
return
|
155
|
+
return SessionEventMarkovBlanketMessage(
|
156
156
|
content={
|
157
157
|
"origin_system_id": str(origin_system_id),
|
158
158
|
"payload": content
|
@@ -18,7 +18,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent.parent)
|
|
18
18
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
19
19
|
from synth_ai.tracing_v3.abstractions import (
|
20
20
|
RuntimeEvent, EnvironmentEvent, LMCAISEvent,
|
21
|
-
TimeRecord,
|
21
|
+
TimeRecord, SessionEventMarkovBlanketMessage
|
22
22
|
)
|
23
23
|
from synth_ai.tracing_v3.turso.manager import AsyncSQLTraceManager
|
24
24
|
from synth_ai.tracing_v3.decorators import set_session_id, set_turn_number
|
@@ -13,7 +13,7 @@ from synth_ai.environments.examples.crafter_classic.config_logging import safe_c
|
|
13
13
|
# Import tracing abstractions
|
14
14
|
from synth_ai.tracing_v3.abstractions import (
|
15
15
|
RuntimeEvent,
|
16
|
-
|
16
|
+
SessionEventMarkovBlanketMessage,
|
17
17
|
TimeRecord,
|
18
18
|
)
|
19
19
|
|