synth-ai 0.2.4.dev5__py3-none-any.whl → 0.2.4.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/__init__.py +18 -9
- synth_ai/cli/__init__.py +10 -5
- synth_ai/cli/balance.py +22 -17
- synth_ai/cli/calc.py +2 -3
- synth_ai/cli/demo.py +3 -5
- synth_ai/cli/legacy_root_backup.py +58 -32
- synth_ai/cli/man.py +22 -19
- synth_ai/cli/recent.py +9 -8
- synth_ai/cli/root.py +58 -13
- synth_ai/cli/status.py +13 -6
- synth_ai/cli/traces.py +45 -21
- synth_ai/cli/watch.py +40 -37
- synth_ai/config/base_url.py +1 -3
- synth_ai/core/experiment.py +1 -2
- synth_ai/environments/__init__.py +2 -6
- synth_ai/environments/environment/artifacts/base.py +3 -1
- synth_ai/environments/environment/db/sqlite.py +1 -1
- synth_ai/environments/environment/registry.py +19 -20
- synth_ai/environments/environment/resources/sqlite.py +2 -3
- synth_ai/environments/environment/rewards/core.py +3 -2
- synth_ai/environments/environment/tools/__init__.py +6 -4
- synth_ai/environments/examples/crafter_classic/__init__.py +1 -1
- synth_ai/environments/examples/crafter_classic/engine.py +21 -17
- synth_ai/environments/examples/crafter_classic/engine_deterministic_patch.py +1 -0
- synth_ai/environments/examples/crafter_classic/engine_helpers/action_map.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_helpers/serialization.py +2 -1
- synth_ai/environments/examples/crafter_classic/engine_serialization_patch_v3.py +3 -2
- synth_ai/environments/examples/crafter_classic/environment.py +16 -15
- synth_ai/environments/examples/crafter_classic/taskset.py +2 -2
- synth_ai/environments/examples/crafter_classic/trace_hooks_v3.py +2 -3
- synth_ai/environments/examples/crafter_classic/world_config_patch_simple.py +2 -1
- synth_ai/environments/examples/crafter_custom/crafter/__init__.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/config.py +2 -2
- synth_ai/environments/examples/crafter_custom/crafter/env.py +1 -5
- synth_ai/environments/examples/crafter_custom/crafter/objects.py +1 -2
- synth_ai/environments/examples/crafter_custom/crafter/worldgen.py +1 -2
- synth_ai/environments/examples/crafter_custom/dataset_builder.py +5 -5
- synth_ai/environments/examples/crafter_custom/environment.py +13 -13
- synth_ai/environments/examples/crafter_custom/run_dataset.py +5 -5
- synth_ai/environments/examples/enron/art_helpers/email_search_tools.py +2 -2
- synth_ai/environments/examples/enron/art_helpers/local_email_db.py +5 -4
- synth_ai/environments/examples/enron/art_helpers/types_enron.py +2 -1
- synth_ai/environments/examples/enron/engine.py +18 -14
- synth_ai/environments/examples/enron/environment.py +12 -11
- synth_ai/environments/examples/enron/taskset.py +7 -7
- synth_ai/environments/examples/minigrid/__init__.py +6 -6
- synth_ai/environments/examples/minigrid/engine.py +6 -6
- synth_ai/environments/examples/minigrid/environment.py +6 -6
- synth_ai/environments/examples/minigrid/puzzle_loader.py +3 -2
- synth_ai/environments/examples/minigrid/taskset.py +13 -13
- synth_ai/environments/examples/nethack/achievements.py +1 -1
- synth_ai/environments/examples/nethack/engine.py +8 -7
- synth_ai/environments/examples/nethack/environment.py +10 -9
- synth_ai/environments/examples/nethack/helpers/__init__.py +8 -9
- synth_ai/environments/examples/nethack/helpers/action_mapping.py +1 -1
- synth_ai/environments/examples/nethack/helpers/nle_wrapper.py +2 -1
- synth_ai/environments/examples/nethack/helpers/observation_utils.py +1 -1
- synth_ai/environments/examples/nethack/helpers/recording_wrapper.py +3 -4
- synth_ai/environments/examples/nethack/helpers/trajectory_recorder.py +6 -5
- synth_ai/environments/examples/nethack/helpers/visualization/replay_viewer.py +5 -5
- synth_ai/environments/examples/nethack/helpers/visualization/visualizer.py +7 -6
- synth_ai/environments/examples/nethack/taskset.py +5 -5
- synth_ai/environments/examples/red/engine.py +9 -8
- synth_ai/environments/examples/red/engine_helpers/reward_components.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/__init__.py +7 -7
- synth_ai/environments/examples/red/engine_helpers/reward_library/adaptive_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/battle_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/composite_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/economy_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/efficiency_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/exploration_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/novelty_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pallet_town_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/pokemon_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/social_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/reward_library/story_rewards.py +2 -1
- synth_ai/environments/examples/red/engine_helpers/screen_analysis.py +3 -2
- synth_ai/environments/examples/red/engine_helpers/state_extraction.py +2 -1
- synth_ai/environments/examples/red/environment.py +18 -15
- synth_ai/environments/examples/red/taskset.py +5 -3
- synth_ai/environments/examples/sokoban/engine.py +16 -13
- synth_ai/environments/examples/sokoban/engine_helpers/room_utils.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/__init__.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/__init__.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/boxoban_env.py +7 -5
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/render_utils.py +1 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/room_utils.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_fixed_targets.py +3 -2
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_pull.py +2 -1
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_two_player.py +5 -4
- synth_ai/environments/examples/sokoban/engine_helpers/vendored/envs/sokoban_env_variations.py +1 -1
- synth_ai/environments/examples/sokoban/environment.py +15 -14
- synth_ai/environments/examples/sokoban/generate_verified_puzzles.py +5 -3
- synth_ai/environments/examples/sokoban/puzzle_loader.py +3 -2
- synth_ai/environments/examples/sokoban/taskset.py +13 -10
- synth_ai/environments/examples/tictactoe/engine.py +6 -6
- synth_ai/environments/examples/tictactoe/environment.py +8 -7
- synth_ai/environments/examples/tictactoe/taskset.py +6 -5
- synth_ai/environments/examples/verilog/engine.py +4 -3
- synth_ai/environments/examples/verilog/environment.py +11 -10
- synth_ai/environments/examples/verilog/taskset.py +14 -12
- synth_ai/environments/examples/wordle/__init__.py +29 -0
- synth_ai/environments/examples/wordle/engine.py +398 -0
- synth_ai/environments/examples/wordle/environment.py +159 -0
- synth_ai/environments/examples/wordle/helpers/generate_instances_wordfreq.py +75 -0
- synth_ai/environments/examples/wordle/taskset.py +230 -0
- synth_ai/environments/reproducibility/core.py +1 -1
- synth_ai/environments/reproducibility/tree.py +21 -21
- synth_ai/environments/service/app.py +11 -2
- synth_ai/environments/service/core_routes.py +137 -105
- synth_ai/environments/service/external_registry.py +1 -2
- synth_ai/environments/service/registry.py +1 -1
- synth_ai/environments/stateful/core.py +1 -2
- synth_ai/environments/stateful/engine.py +1 -1
- synth_ai/environments/tasks/api.py +4 -4
- synth_ai/environments/tasks/core.py +14 -12
- synth_ai/environments/tasks/filters.py +6 -4
- synth_ai/environments/tasks/utils.py +13 -11
- synth_ai/evals/base.py +2 -3
- synth_ai/experimental/synth_oss.py +4 -4
- synth_ai/learning/gateway.py +1 -3
- synth_ai/learning/prompts/banking77_injection_eval.py +168 -0
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +213 -0
- synth_ai/learning/prompts/mipro.py +282 -1
- synth_ai/learning/prompts/random_search.py +246 -0
- synth_ai/learning/prompts/run_mipro_banking77.py +172 -0
- synth_ai/learning/prompts/run_random_search_banking77.py +324 -0
- synth_ai/lm/__init__.py +5 -5
- synth_ai/lm/caching/ephemeral.py +9 -9
- synth_ai/lm/caching/handler.py +20 -20
- synth_ai/lm/caching/persistent.py +10 -10
- synth_ai/lm/config.py +3 -3
- synth_ai/lm/constants.py +7 -7
- synth_ai/lm/core/all.py +17 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +26 -41
- synth_ai/lm/core/main_v3.py +20 -10
- synth_ai/lm/core/vendor_clients.py +18 -17
- synth_ai/lm/injection.py +80 -0
- synth_ai/lm/overrides.py +206 -0
- synth_ai/lm/provider_support/__init__.py +1 -1
- synth_ai/lm/provider_support/anthropic.py +51 -24
- synth_ai/lm/provider_support/openai.py +51 -22
- synth_ai/lm/structured_outputs/handler.py +34 -32
- synth_ai/lm/structured_outputs/inject.py +24 -27
- synth_ai/lm/structured_outputs/rehabilitate.py +19 -15
- synth_ai/lm/tools/base.py +17 -16
- synth_ai/lm/unified_interface.py +17 -18
- synth_ai/lm/vendors/base.py +20 -18
- synth_ai/lm/vendors/core/anthropic_api.py +50 -25
- synth_ai/lm/vendors/core/gemini_api.py +31 -36
- synth_ai/lm/vendors/core/mistral_api.py +19 -19
- synth_ai/lm/vendors/core/openai_api.py +11 -10
- synth_ai/lm/vendors/openai_standard.py +144 -88
- synth_ai/lm/vendors/openai_standard_responses.py +74 -61
- synth_ai/lm/vendors/retries.py +9 -1
- synth_ai/lm/vendors/supported/custom_endpoint.py +26 -26
- synth_ai/lm/vendors/supported/deepseek.py +10 -10
- synth_ai/lm/vendors/supported/grok.py +8 -8
- synth_ai/lm/vendors/supported/ollama.py +2 -1
- synth_ai/lm/vendors/supported/openrouter.py +11 -9
- synth_ai/lm/vendors/synth_client.py +69 -63
- synth_ai/lm/warmup.py +8 -7
- synth_ai/tracing/__init__.py +22 -10
- synth_ai/tracing_v1/__init__.py +22 -20
- synth_ai/tracing_v3/__init__.py +7 -7
- synth_ai/tracing_v3/abstractions.py +56 -52
- synth_ai/tracing_v3/config.py +4 -2
- synth_ai/tracing_v3/db_config.py +6 -8
- synth_ai/tracing_v3/decorators.py +29 -30
- synth_ai/tracing_v3/examples/basic_usage.py +12 -12
- synth_ai/tracing_v3/hooks.py +21 -21
- synth_ai/tracing_v3/llm_call_record_helpers.py +85 -98
- synth_ai/tracing_v3/lm_call_record_abstractions.py +2 -4
- synth_ai/tracing_v3/migration_helper.py +3 -5
- synth_ai/tracing_v3/replica_sync.py +30 -32
- synth_ai/tracing_v3/session_tracer.py +35 -29
- synth_ai/tracing_v3/storage/__init__.py +1 -1
- synth_ai/tracing_v3/storage/base.py +8 -7
- synth_ai/tracing_v3/storage/config.py +4 -4
- synth_ai/tracing_v3/storage/factory.py +4 -4
- synth_ai/tracing_v3/storage/utils.py +9 -9
- synth_ai/tracing_v3/turso/__init__.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +9 -9
- synth_ai/tracing_v3/turso/manager.py +60 -48
- synth_ai/tracing_v3/turso/models.py +24 -19
- synth_ai/tracing_v3/utils.py +5 -5
- synth_ai/tui/__main__.py +1 -1
- synth_ai/tui/cli/query_experiments.py +2 -3
- synth_ai/tui/cli/query_experiments_v3.py +2 -3
- synth_ai/tui/dashboard.py +97 -86
- synth_ai/v0/tracing/abstractions.py +28 -28
- synth_ai/v0/tracing/base_client.py +9 -9
- synth_ai/v0/tracing/client_manager.py +7 -7
- synth_ai/v0/tracing/config.py +7 -7
- synth_ai/v0/tracing/context.py +6 -6
- synth_ai/v0/tracing/decorators.py +6 -5
- synth_ai/v0/tracing/events/manage.py +1 -1
- synth_ai/v0/tracing/events/store.py +5 -4
- synth_ai/v0/tracing/immediate_client.py +4 -5
- synth_ai/v0/tracing/local.py +3 -3
- synth_ai/v0/tracing/log_client_base.py +4 -5
- synth_ai/v0/tracing/retry_queue.py +5 -6
- synth_ai/v0/tracing/trackers.py +25 -25
- synth_ai/v0/tracing/upload.py +6 -0
- synth_ai/v0/tracing_v1/__init__.py +1 -1
- synth_ai/v0/tracing_v1/abstractions.py +28 -28
- synth_ai/v0/tracing_v1/base_client.py +9 -9
- synth_ai/v0/tracing_v1/client_manager.py +7 -7
- synth_ai/v0/tracing_v1/config.py +7 -7
- synth_ai/v0/tracing_v1/context.py +6 -6
- synth_ai/v0/tracing_v1/decorators.py +7 -6
- synth_ai/v0/tracing_v1/events/manage.py +1 -1
- synth_ai/v0/tracing_v1/events/store.py +5 -4
- synth_ai/v0/tracing_v1/immediate_client.py +4 -5
- synth_ai/v0/tracing_v1/local.py +3 -3
- synth_ai/v0/tracing_v1/log_client_base.py +4 -5
- synth_ai/v0/tracing_v1/retry_queue.py +5 -6
- synth_ai/v0/tracing_v1/trackers.py +25 -25
- synth_ai/v0/tracing_v1/upload.py +25 -24
- synth_ai/zyk/__init__.py +1 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/METADATA +2 -11
- synth_ai-0.2.4.dev7.dist-info/RECORD +299 -0
- synth_ai-0.2.4.dev5.dist-info/RECORD +0 -287
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.4.dev5.dist-info → synth_ai-0.2.4.dev7.dist-info}/top_level.txt +0 -0
synth_ai/environments/tasks/filters.py
CHANGED
@@ -1,6 +1,8 @@
-from …
+from collections.abc import Collection
 from dataclasses import dataclass
-from …
+from typing import Any
+
+from synth_ai.environments.tasks.core import TaskInstance, TaskInstanceMetadataFilter
 
 
 @dataclass
@@ -18,8 +20,8 @@ class ValueFilter(TaskInstanceMetadataFilter):
 @dataclass
 class RangeFilter(TaskInstanceMetadataFilter):
     key: str
-    min_val: …
-    max_val: …
+    min_val: float | None = None
+    max_val: float | None = None
 
     def __call__(self, instance: TaskInstance) -> bool:
         instance_value = getattr(instance.metadata, self.key, None)
synth_ai/environments/tasks/utils.py
CHANGED
@@ -2,17 +2,19 @@
 Utility functions and generic filters for taskset creation.
 """
 
-from …
+from collections.abc import Collection
+from typing import Any
 from uuid import UUID, uuid4
+
 from synth_ai.environments.tasks.core import (
-    TaskInstanceMetadataFilter,
-    TaskInstanceSet,
     SplitInfo,
     TaskInstance,
+    TaskInstanceMetadataFilter,
+    TaskInstanceSet,
 )
 
 
-def parse_or_new_uuid(raw_id: …
+def parse_or_new_uuid(raw_id: str | None) -> UUID:
     """
     Parse a raw ID string into a UUID, or generate a new one if invalid or missing.
     """
@@ -43,8 +45,8 @@ class RangeFilter(TaskInstanceMetadataFilter):
     def __init__(
         self,
         key: str,
-        min_value: …
-        max_value: …
+        min_value: float | None = None,
+        max_value: float | None = None,
     ):
         self.key = key
         self.min_value = min_value
@@ -62,15 +64,15 @@ class RangeFilter(TaskInstanceMetadataFilter):
 def make_taskset(
     name: str,
     description: str,
-    instances: …
-    val_filter: …
-    test_filter: …
+    instances: list[TaskInstance],
+    val_filter: TaskInstanceMetadataFilter | None = None,
+    test_filter: TaskInstanceMetadataFilter | None = None,
 ) -> TaskInstanceSet:
     """
     Assemble a TaskInstanceSet by applying optional validation and test filters.
     """
-    val_ids: …
-    test_ids: …
+    val_ids: set[Any] = set()
+    test_ids: set[Any] = set()
     if val_filter:
         val_ids = {inst.id for inst in instances if val_filter(inst)}
     if test_filter:
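For orientation, a minimal usage sketch of the updated signatures shown above (optional bounds on RangeFilter, optional filters on make_taskset). The metadata key "difficulty" and the empty instances list are illustrative placeholders, not values from this release.

from synth_ai.environments.tasks.filters import RangeFilter
from synth_ai.environments.tasks.utils import make_taskset

# Populate with TaskInstance objects elsewhere; "difficulty" is a hypothetical metadata field.
instances = []
val_filter = RangeFilter(key="difficulty", min_val=0.5, max_val=1.0)  # leaving max_val=None keeps the upper bound open

taskset = make_taskset(
    name="demo",
    description="Taskset assembled with an optional validation filter",
    instances=instances,
    val_filter=val_filter,
    test_filter=None,  # optional; per the hunk above, an omitted filter leaves that split's id set empty
)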
synth_ai/evals/base.py
CHANGED
@@ -1,9 +1,8 @@
-from typing import List
 
 
 class Judgement:
     def __init__(
-        self, criteria: str, score: float, reasoning: str = "", evidence: …
+        self, criteria: str, score: float, reasoning: str = "", evidence: list[str] = None
     ):
         self.criteria = criteria
         self.score = score
@@ -12,5 +11,5 @@ class Judgement:
 
 
 class BaseEval:
-    async def run(self, data: any) -> …
+    async def run(self, data: any) -> list[Judgement]:
         pass
synth_ai/experimental/synth_oss.py
CHANGED
@@ -1,5 +1,5 @@
-
-
+# ruff: noqa
+'''
 Synth OSS Integration Module
 
 This module provides integration with Synth's open-source inference and training APIs
@@ -336,7 +336,7 @@ Implementation sketch (backend == "synth")
 The method is a *no-op* for the default (OpenAI) backend so existing code keeps
 working.
 
-
+'''
 
 
 """
@@ -443,4 +443,4 @@ async def warmup(
 So: **the existing endpoint does not yet support GPU selection; we need to add
 the small change above on the `learning_v2` side and then LM.warmup can request
 specific GPUs.**
-"""
+"""
synth_ai/learning/gateway.py
CHANGED
synth_ai/learning/prompts/banking77_injection_eval.py
ADDED
@@ -0,0 +1,168 @@
+"""
+Banking77 in-context injection evals (async, not tests)
+
+Samples a handful of Banking77 prompts and evaluates multiple override
+contexts in parallel, printing simple accuracy for each.
+
+Usage
+- Keys in .env (GROQ_API_KEY, etc.)
+- Run: uv run -q python -m synth_ai.learning.prompts.banking77_injection_eval
+Optional env:
+- N_SAMPLES=20 (default)
+- MODEL=openai/gpt-oss-20b (default)
+- VENDOR=groq (default)
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import random
+from typing import Any
+
+from datasets import load_dataset
+from dotenv import load_dotenv
+from synth_ai.lm.core.main_v3 import LM, build_messages
+from synth_ai.lm.overrides import LMOverridesContext
+
+
+async def classify_one(lm: LM, text: str, label_names: list[str]) -> str:
+    labels_joined = ", ".join(label_names)
+    system_message = (
+        "You are an intent classifier for the Banking77 dataset. "
+        "Given a customer message, respond with exactly one label from the list. "
+        "Return only the label text with no extra words.\n\n"
+        f"Valid labels: {labels_joined}"
+    )
+    user_message = f"Message: {text}\nLabel:"
+    messages = build_messages(system_message, user_message, images_bytes=None, model_name=lm.model)
+    resp = await lm.respond_async(messages=messages)
+    return (resp.raw_response or "").strip()
+
+
+def choose_label(pred: str, label_names: list[str]) -> str:
+    norm_pred = pred.strip().lower()
+    label_lookup = {ln.lower(): ln for ln in label_names}
+    mapped = label_lookup.get(norm_pred)
+    if mapped is not None:
+        return mapped
+
+    # Fallback: choose the label with the highest naive token overlap
+    def score(cand: str) -> int:
+        c = cand.lower()
+        return sum(1 for w in c.split() if w in norm_pred)
+
+    return max(label_names, key=score)
+
+
+async def eval_context(
+    lm: LM,
+    items: list[tuple[str, str]],
+    label_names: list[str],
+    ctx_name: str,
+    specs: list[dict[str, Any]],
+) -> tuple[str, int, int]:
+    correct = 0
+    with LMOverridesContext(specs):
+        tasks = [classify_one(lm, text, label_names) for text, _ in items]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+    for (text, gold), pred in zip(items, results, strict=False):
+        if isinstance(pred, Exception):
+            # Treat exceptions as incorrect
+            continue
+        mapped = choose_label(pred, label_names)
+        correct += int(mapped == gold)
+    return (ctx_name, correct, len(items))
+
+
+async def main() -> None:
+    load_dotenv()
+
+    n = int(os.getenv("N_SAMPLES", "20"))
+    model = os.getenv("MODEL", "openai/gpt-oss-20b")
+    vendor = os.getenv("VENDOR", "groq")
+
+    lm = LM(model=model, vendor=vendor, temperature=0.0)
+
+    print("Loading Banking77 dataset (split='test')...")
+    ds = load_dataset("banking77", split="test")
+    label_names: list[str] = ds.features["label"].names  # type: ignore
+
+    idxs = random.sample(range(len(ds)), k=min(n, len(ds)))
+    items = [
+        (ds[i]["text"], label_names[int(ds[i]["label"])])  # (text, gold_label)
+        for i in idxs
+    ]
+
+    # Define a few override contexts to compare
+    contexts: list[dict[str, Any]] = [
+        {
+            "name": "baseline (no overrides)",
+            "overrides": [],
+        },
+        {
+            "name": "nonsense prompt injection (expected worse)",
+            "overrides": [
+                {
+                    "match": {"contains": "", "role": "user"},
+                    "injection_rules": [
+                        # Heavily corrupt user text by replacing vowels
+                        {"find": "a", "replace": "x"},
+                        {"find": "e", "replace": "x"},
+                        {"find": "i", "replace": "x"},
+                        {"find": "o", "replace": "x"},
+                        {"find": "u", "replace": "x"},
+                        {"find": "A", "replace": "X"},
+                        {"find": "E", "replace": "X"},
+                        {"find": "I", "replace": "X"},
+                        {"find": "O", "replace": "X"},
+                        {"find": "U", "replace": "X"},
+                    ],
+                }
+            ],
+        },
+        {
+            "name": "injection: atm->ATM, txn->transaction",
+            "overrides": [
+                {
+                    "match": {"contains": "atm", "role": "user"},
+                    "injection_rules": [
+                        {"find": "atm", "replace": "ATM"},
+                        {"find": "txn", "replace": "transaction"},
+                    ],
+                }
+            ],
+        },
+        {
+            "name": "params: temperature=0.0",
+            "overrides": [
+                {"match": {"contains": ""}, "params": {"temperature": 0.0}},
+            ],
+        },
+        {
+            "name": "model override: 20b->120b",
+            "overrides": [
+                {"match": {"contains": ""}, "params": {"model": "openai/gpt-oss-120b"}},
+            ],
+        },
+    ]
+
+    print(f"\nEvaluating {len(contexts)} contexts on {len(items)} Banking77 samples (async)...")
+
+    # Evaluate each context sequentially but batched (each context classifies in parallel)
+    results: list[tuple[str, int, int]] = []
+    for ctx in contexts:
+        name = ctx["name"]
+        specs = ctx["overrides"]
+        print(f"Evaluating: {name} ...")
+        res = await eval_context(lm, items, label_names, name, specs)
+        results.append(res)
+
+    print("\nResults:")
+    for name, correct, total in results:
+        acc = correct / total if total else 0.0
+        print(f"- {name}: {correct}/{total} correct ({acc:.2%})")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
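The new module above drives everything through LMOverridesContext. A condensed sketch of that pattern follows, using the same imports the file itself declares; the prompt text and spec values are illustrative only, and a working vendor key is assumed.

import asyncio

from synth_ai.lm.core.main_v3 import LM, build_messages
from synth_ai.lm.overrides import LMOverridesContext


async def demo() -> None:
    lm = LM(model="openai/gpt-oss-20b", vendor="groq", temperature=0.0)
    specs = [
        # Rewrite matching user text before the request is sent
        {"match": {"contains": "atm", "role": "user"}, "injection_rules": [{"find": "atm", "replace": "ATM"}]},
        # Pin sampling params for every call inside the context
        {"match": {"contains": ""}, "params": {"temperature": 0.0}},
    ]
    messages = build_messages(
        "Answer with one word.", "Is the atm open today?",
        images_bytes=None, model_name=lm.model,
    )
    with LMOverridesContext(specs):
        resp = await lm.respond_async(messages=messages)
    print((resp.raw_response or "").strip())


asyncio.run(demo())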
synth_ai/learning/prompts/hello_world_in_context_injection_ex.py
ADDED
@@ -0,0 +1,213 @@
+"""
+Hello World: Banking77 intent classification with in-context injection
+
+This script shows a minimal text-classification pipeline over the
+Hugging Face Banking77 dataset using the Synth LM interface. It also
+demonstrates a simple pre-send prompt-injection step as outlined in
+`synth_ai/learning/prompts/injection_plan.txt`.
+
+Notes
+- Network access is required to download the dataset and call the model.
+- Defaults to Groq with model `openai/gpt-oss-20b`.
+- Export your key: `export GROQ_API_KEY=...`
+- Override if needed: `export MODEL=openai/gpt-oss-20b VENDOR=groq`
+
+Run
+- `python -m synth_ai.learning.prompts.hello_world_in_context_injection_ex`
+
+What "in-context injection" means here
+- The script applies ordered substring replacements to the outgoing
+  `messages` array before calling the model. This mirrors the algorithm
+  described in `injection_plan.txt` without importing any non-existent
+  helper yet. You can adapt `INJECTION_RULES` to your needs.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import random
+
+from datasets import load_dataset
+
+# Use the v3 LM class present in this repo
+from synth_ai.lm.core.main_v3 import LM, build_messages
+
+# Use Overrides context to demonstrate matching by content
+from synth_ai.lm.overrides import LMOverridesContext
+from synth_ai.tracing_v3.abstractions import LMCAISEvent
+from synth_ai.tracing_v3.session_tracer import SessionTracer
+
+INJECTION_RULES = [
+    {"find": "accnt", "replace": "account"},
+    {"find": "atm", "replace": "ATM"},
+    {"find": "txn", "replace": "transaction"},
+]
+
+
+async def classify_sample(lm: LM, text: str, label_names: list[str]) -> str:
+    """Classify one Banking77 utterance and return the predicted label name."""
+    labels_joined = ", ".join(label_names)
+    system_message = (
+        "You are an intent classifier for the Banking77 dataset. "
+        "Given a customer message, respond with exactly one label from the list. "
+        "Return only the label text with no extra words.\n\n"
+        f"Valid labels: {labels_joined}"
+    )
+    user_message = f"Message: {text}\nLabel:"
+
+    # Build canonical messages; injection will be applied inside the vendor via context
+    messages = build_messages(system_message, user_message, images_bytes=None, model_name=lm.model)
+    resp = await lm.respond_async(messages=messages)
+    raw = (resp.raw_response or "").strip()
+    return raw
+
+
+async def main() -> None:
+    # Configurable model/provider via env, with sensible defaults
+    # Default to Groq hosting `openai/gpt-oss-20b`
+    model = os.getenv("MODEL", "openai/gpt-oss-20b")
+    vendor = os.getenv("VENDOR", "groq")
+
+    # Construct LM
+    lm = LM(model=model, vendor=vendor, temperature=0.0)
+
+    # Load Banking77 dataset
+    # Columns: {"text": str, "label": int}; label names at ds.features["label"].names
+    print("Loading Banking77 dataset (split='test')...")
+    ds = load_dataset("banking77", split="test")
+    label_names: list[str] = ds.features["label"].names  # type: ignore
+
+    # Sample a few items for a quick demo
+    n = int(os.getenv("N_SAMPLES", "8"))
+    idxs = random.sample(range(len(ds)), k=min(n, len(ds)))
+
+    correct = 0
+    # Apply overrides for all calls in this block (match by content)
+    overrides = [
+        {"match": {"contains": "atm", "role": "user"}, "injection_rules": INJECTION_RULES},
+        {"match": {"contains": "refund"}, "params": {"temperature": 0.0}},
+    ]
+    with LMOverridesContext(overrides):
+        for i, idx in enumerate(idxs, start=1):
+            text: str = ds[idx]["text"]  # type: ignore
+            gold_label_idx: int = int(ds[idx]["label"])  # type: ignore
+            gold_label = label_names[gold_label_idx]
+
+            try:
+                pred = await classify_sample(lm, text, label_names)
+            except Exception as e:
+                print(f"[{i}] Error calling model: {e}")
+                break
+
+            # Normalize and check exact match; if not exact, attempt a loose fallback
+            norm_pred = pred.strip().lower()
+            label_lookup = {ln.lower(): ln for ln in label_names}
+            pred_label = label_lookup.get(norm_pred)
+            if pred_label is None:
+                # Fallback: pick the label with highest substring overlap (very naive)
+                # This avoids extra deps; feel free to replace with a better matcher.
+                def score(cand: str) -> int:
+                    c = cand.lower()
+                    return sum(1 for w in c.split() if w in norm_pred)
+
+                pred_label = max(label_names, key=score)
+
+            is_correct = pred_label == gold_label
+            correct += int(is_correct)
+            print(
+                f"[{i}] text={text!r}\n  gold={gold_label}\n  pred={pred} -> mapped={pred_label} {'✅' if is_correct else '❌'}"
+            )
+
+    if idxs:
+        acc = correct / len(idxs)
+        print(f"\nSamples: {len(idxs)} | Correct: {correct} | Accuracy: {acc:.2%}")
+
+    # ------------------------------
+    # Integration tests (three paths)
+    # ------------------------------
+    print("\nRunning integration tests with in-context injection...")
+    test_text = "I used the atm to withdraw cash."
+
+    # 1) LM path with v3 tracing: verify substitution in traced messages
+    tracer = SessionTracer()
+    await tracer.start_session(metadata={"test": "lm_injection"})
+    await tracer.start_timestep(step_id="lm_test")
+    # Use a tracer-bound LM instance
+    lm_traced = LM(model=model, vendor=vendor, temperature=0.0, session_tracer=tracer)
+    with LMOverridesContext([{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]):
+        _ = await classify_sample(lm_traced, test_text, label_names)
+    # inspect trace
+    events = [
+        e
+        for e in (tracer.current_session.event_history if tracer.current_session else [])
+        if isinstance(e, LMCAISEvent)
+    ]
+    assert events, "No LMCAISEvent recorded by SessionTracer"
+    cr = events[-1].call_records[0]
+    traced_user = ""
+    for m in cr.input_messages:
+        if m.role == "user":
+            for part in m.parts:
+                if getattr(part, "type", None) == "text":
+                    traced_user += part.text or ""
+    assert "ATM" in traced_user, f"Expected substitution in traced prompt; got: {traced_user!r}"
+    print("LM path trace verified: substitution present in traced prompt.")
+    await tracer.end_timestep()
+    await tracer.end_session()
+
+    # 2) OpenAI wrapper path (AsyncOpenAI to Groq): ensure apply_injection is active
+    try:
+        import synth_ai.lm.provider_support.openai as _synth_openai_patch  # noqa: F401
+        from openai import AsyncOpenAI
+
+        base_url = os.getenv("OPENAI_BASE_URL", "https://api.groq.com/openai/v1")
+        api_key = os.getenv("OPENAI_API_KEY") or os.getenv("GROQ_API_KEY") or ""
+        client = AsyncOpenAI(base_url=base_url, api_key=api_key)
+        messages = [
+            {"role": "system", "content": "Echo user label."},
+            {"role": "user", "content": f"Please classify: {test_text}"},
+        ]
+        with LMOverridesContext(
+            [{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]
+        ):
+            _ = await client.chat.completions.create(
+                model=model, messages=messages, temperature=0
+            )
+        # Not all models echo input; instead, verify that our injected expectation matches
+        expected_user = messages[1]["content"].replace("atm", "ATM")
+        if messages[1]["content"] == expected_user:
+            print("OpenAI wrapper: input already normalized; skipping assertion.")
+        else:
+            print("OpenAI wrapper: sent message contains substitution expectation:", expected_user)
+    except Exception as e:
+        print("OpenAI wrapper test skipped due to error:", e)
+
+    # 3) Anthropic wrapper path (AsyncClient): ensure apply_injection is active
+    try:
+        import anthropic
+        import synth_ai.lm.provider_support.anthropic as _synth_anthropic_patch  # noqa: F401
+
+        a_model = os.getenv("ANTHROPIC_MODEL", "claude-3-5-haiku-20241022")
+        a_key = os.getenv("ANTHROPIC_API_KEY")
+        if a_key:
+            a_client = anthropic.AsyncClient(api_key=a_key)
+            with LMOverridesContext(
+                [{"match": {"contains": "atm"}, "injection_rules": INJECTION_RULES}]
+            ):
+                _ = await a_client.messages.create(
+                    model=a_model,
+                    system="Echo user label.",
+                    max_tokens=64,
+                    temperature=0,
+                    messages=[{"role": "user", "content": [{"type": "text", "text": test_text}]}],
+                )
+            print("Anthropic wrapper call completed (cannot reliably assert echo).")
+        else:
+            print("Anthropic wrapper test skipped: ANTHROPIC_API_KEY not set.")
+    except Exception as e:
+        print("Anthropic wrapper test skipped due to error:", e)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
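As a plain-Python illustration of the "ordered substring replacements" the docstring above describes: the sketch below applies INJECTION_RULES to user message content before sending. It is a standalone rendering for clarity, not the library's internal implementation.

INJECTION_RULES = [
    {"find": "accnt", "replace": "account"},
    {"find": "atm", "replace": "ATM"},
    {"find": "txn", "replace": "transaction"},
]


def apply_rules(messages: list[dict], rules: list[dict]) -> list[dict]:
    patched = []
    for msg in messages:
        content = msg.get("content")
        if msg.get("role") == "user" and isinstance(content, str):
            for rule in rules:  # rules are applied in order
                content = content.replace(rule["find"], rule["replace"])
            msg = {**msg, "content": content}
        patched.append(msg)
    return patched


example = [{"role": "user", "content": "my accnt got a txn fee at the atm"}]
print(apply_rules(example, INJECTION_RULES))
# -> [{'role': 'user', 'content': 'my account got a transaction fee at the ATM'}]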