synth-ai 0.1.0.dev39__py3-none-any.whl → 0.1.0.dev49__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- synth_ai/__init__.py +3 -1
- {synth_ai-0.1.0.dev39.dist-info → synth_ai-0.1.0.dev49.dist-info}/METADATA +12 -11
- synth_ai-0.1.0.dev49.dist-info/RECORD +6 -0
- {synth_ai-0.1.0.dev39.dist-info → synth_ai-0.1.0.dev49.dist-info}/WHEEL +1 -1
- synth_ai-0.1.0.dev49.dist-info/top_level.txt +1 -0
- private_tests/try_synth_sdk.py +0 -1
- public_tests/test_agent.py +0 -538
- public_tests/test_all_structured_outputs.py +0 -196
- public_tests/test_anthropic_structured_outputs.py +0 -0
- public_tests/test_deepseek_structured_outputs.py +0 -0
- public_tests/test_deepseek_tools.py +0 -64
- public_tests/test_gemini_output.py +0 -188
- public_tests/test_gemini_structured_outputs.py +0 -106
- public_tests/test_models.py +0 -183
- public_tests/test_openai_structured_outputs.py +0 -106
- public_tests/test_reasoning_effort.py +0 -75
- public_tests/test_reasoning_models.py +0 -92
- public_tests/test_recursive_structured_outputs.py +0 -180
- public_tests/test_structured.py +0 -137
- public_tests/test_structured_outputs.py +0 -109
- public_tests/test_synth_sdk.py +0 -384
- public_tests/test_text.py +0 -160
- public_tests/test_tools.py +0 -319
- synth_ai/zyk/__init__.py +0 -3
- synth_ai/zyk/lms/__init__.py +0 -0
- synth_ai/zyk/lms/caching/__init__.py +0 -0
- synth_ai/zyk/lms/caching/constants.py +0 -1
- synth_ai/zyk/lms/caching/dbs.py +0 -0
- synth_ai/zyk/lms/caching/ephemeral.py +0 -72
- synth_ai/zyk/lms/caching/handler.py +0 -142
- synth_ai/zyk/lms/caching/initialize.py +0 -13
- synth_ai/zyk/lms/caching/persistent.py +0 -83
- synth_ai/zyk/lms/config.py +0 -8
- synth_ai/zyk/lms/core/__init__.py +0 -0
- synth_ai/zyk/lms/core/all.py +0 -47
- synth_ai/zyk/lms/core/exceptions.py +0 -9
- synth_ai/zyk/lms/core/main.py +0 -314
- synth_ai/zyk/lms/core/vendor_clients.py +0 -85
- synth_ai/zyk/lms/cost/__init__.py +0 -0
- synth_ai/zyk/lms/cost/monitor.py +0 -1
- synth_ai/zyk/lms/cost/statefulness.py +0 -1
- synth_ai/zyk/lms/structured_outputs/__init__.py +0 -0
- synth_ai/zyk/lms/structured_outputs/handler.py +0 -442
- synth_ai/zyk/lms/structured_outputs/inject.py +0 -314
- synth_ai/zyk/lms/structured_outputs/rehabilitate.py +0 -187
- synth_ai/zyk/lms/tools/base.py +0 -104
- synth_ai/zyk/lms/vendors/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/base.py +0 -31
- synth_ai/zyk/lms/vendors/constants.py +0 -22
- synth_ai/zyk/lms/vendors/core/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/core/anthropic_api.py +0 -413
- synth_ai/zyk/lms/vendors/core/gemini_api.py +0 -306
- synth_ai/zyk/lms/vendors/core/mistral_api.py +0 -327
- synth_ai/zyk/lms/vendors/core/openai_api.py +0 -185
- synth_ai/zyk/lms/vendors/local/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/local/ollama.py +0 -0
- synth_ai/zyk/lms/vendors/openai_standard.py +0 -374
- synth_ai/zyk/lms/vendors/retries.py +0 -3
- synth_ai/zyk/lms/vendors/supported/__init__.py +0 -0
- synth_ai/zyk/lms/vendors/supported/deepseek.py +0 -73
- synth_ai/zyk/lms/vendors/supported/groq.py +0 -16
- synth_ai/zyk/lms/vendors/supported/ollama.py +0 -14
- synth_ai/zyk/lms/vendors/supported/together.py +0 -11
- synth_ai-0.1.0.dev39.dist-info/RECORD +0 -67
- synth_ai-0.1.0.dev39.dist-info/top_level.txt +0 -4
- tests/test_agent.py +0 -538
- tests/test_recursive_structured_outputs.py +0 -180
- tests/test_structured_outputs.py +0 -100
- {synth_ai-0.1.0.dev39.dist-info → synth_ai-0.1.0.dev49.dist-info}/licenses/LICENSE +0 -0
public_tests/test_all_structured_outputs.py
@@ -1,196 +0,0 @@
-from typing import Any, Dict, Optional
-
-import pytest
-from pydantic import BaseModel
-
-from synth_ai.zyk import LM, BaseLMResponse
-
-
-class StateUpdate(BaseModel):
-    """Response model for state updates from LLM"""
-
-    short_term_plan: Optional[str] = None
-    objective: Optional[str] = None
-    final_results: Optional[Dict[str, Any]] = None
-
-    def model_post_init(self, __context):
-        super().model_post_init(__context)
-        # Ensure no protected fields are present
-        protected_fields = ["message_history", "step_summaries"]
-        for field in protected_fields:
-            if hasattr(self, field):
-                raise ValueError(f"Cannot modify protected field: {field}")
-
-
-@pytest.fixture(scope="module")
-def models():
-    """Initialize LMs for different vendors"""
-    return {
-        "gpt-4o-mini": LM(
-            model_name="gpt-4o-mini",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.1,
-            structured_output_mode="forced_json",
-        ),
-        "o3-mini": LM(
-            model_name="o3-mini",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.1,
-            structured_output_mode="forced_json",
-        ),
-        "gemini-1.5-flash": LM(
-            model_name="gemini-1.5-flash",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.1,
-            structured_output_mode="stringified_json",
-        ),
-        "claude-3-haiku-20240307": LM(
-            model_name="claude-3-haiku-20240307",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.1,
-            structured_output_mode="stringified_json",
-        ),
-        "deepseek-chat": LM(
-            model_name="deepseek-chat",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.1,
-            structured_output_mode="stringified_json",
-        ),
-        "deepseek-reasoner": LM(
-            model_name="deepseek-reasoner",
-            formatting_model_name="gpt-4o-mini",
-            temperature=1,
-            structured_output_mode="stringified_json",
-        ),
-        "llama-3.1-8b-instant": LM(
-            model_name="llama-3.1-8b-instant",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.1,
-            structured_output_mode="stringified_json",
-        ),
-        "mistral-small-latest": LM(
-            model_name="mistral-small-latest",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.1,
-            structured_output_mode="stringified_json",
-        ),
-    }
-
-
-@pytest.fixture
-def system_message():
-    """System message for state updates"""
-    return """You are helping update the agent's state. Look at the current state and state_delta_instructions and update the state.
-
-Available fields you can modify:
-{
-    "short_term_plan": "str",
-    "objective": "str",
-    "final_results": "Dict[str, Any]"
-}
-
-Protected fields (do not modify):
-{
-    "message_history": "Cannot directly edit message history - it is managed internally",
-    "step_summaries": "Cannot directly edit step summaries - they are generated automatically"
-}
-
-Please be brief, the state ought not be too long."""
-
-
-@pytest.fixture
-def current_state():
-    """Initial state for testing"""
-    return {
-        "short_term_plan": "Current plan: Review code changes",
-        "objective": "Review pull request",
-        "final_results": {
-            "findings": [],
-            "recommendations": [],
-            "analysis": {},
-            "status": "IN_PROGRESS",
-        },
-    }
-
-
-@pytest.mark.timeout(15)
-@pytest.mark.parametrize(
-    "model_name",
-    [
-        "gpt-4o-mini",
-        "gemini-1.5-flash",
-        "claude-3-haiku-20240307",
-        "deepseek-chat",
-        "llama-3.1-8b-instant",
-    ],
-)
-def test_state_delta_handling(
-    model_name: str, models: Dict[str, LM], system_message: str, current_state: Dict
-):
-    """Test that each model correctly handles state updates"""
-
-    state_delta_instructions = """Update the final_results to include findings about code quality issues. Add a recommendation to improve error handling."""
-    user_message = f"Current state: {current_state}\nState delta instructions: {state_delta_instructions}\n\nHow should the state be updated?"
-
-    #try:
-    result: BaseLMResponse = models[model_name].respond_sync(
-        system_message=system_message,
-        user_message=user_message,
-        response_model=StateUpdate,
-    )
-    print("Result", result)
-    # Verify response structure
-    assert isinstance(result, BaseLMResponse)
-    assert isinstance(result.structured_output, StateUpdate)
-
-    # Verify only allowed fields are present and have correct types
-    if result.structured_output.short_term_plan is not None:
-        assert isinstance(result.structured_output.short_term_plan, str)
-    if result.structured_output.objective is not None:
-        assert isinstance(result.structured_output.objective, str)
-    if result.structured_output.final_results is not None:
-        assert isinstance(result.structured_output.final_results, dict)
-
-    # except Exception as e:
-    #     pytest.fail(f"Model {model_name} failed: {str(e)}")
-
-
-@pytest.mark.timeout(15)
-@pytest.mark.parametrize(
-    "model_name",
-    [
-        "gpt-4o-mini",
-        "gemini-1.5-flash",
-        "claude-3-haiku-20240307",
-        "deepseek-chat",
-        "llama-3.1-8b-instant",
-    ],
-)
-def test_state_delta_protected_fields(
-    model_name: str, models: Dict[str, LM], system_message: str
-):
-    """Test that models respect protected fields"""
-
-    current_state = {
-        "short_term_plan": "Current plan: Review code changes",
-        "objective": "Review pull request",
-        "message_history": ["Previous message 1", "Previous message 2"],
-        "step_summaries": ["Step 1 summary", "Step 2 summary"],
-        "final_results": {
-            "findings": [],
-            "recommendations": [],
-            "analysis": {},
-            "status": "IN_PROGRESS",
-        },
-    }
-
-    state_delta_instructions = """Update the message history to include new findings and update step summaries with recent progress."""
-    user_message = f"Current state: {current_state}\nState delta instructions: {state_delta_instructions}\n\nHow should the state be updated?"
-
-    #try:
-    result = models[model_name].respond_sync(
-        system_message=system_message,
-        user_message=user_message,
-        response_model=StateUpdate,
-    )
-    # except Exception as e:
-    #     pytest.fail(f"Model {model_name} failed: {str(e)}")
public_tests/test_anthropic_structured_outputs.py
File without changes
public_tests/test_deepseek_structured_outputs.py
File without changes
public_tests/test_deepseek_tools.py
@@ -1,64 +0,0 @@
-from pydantic import BaseModel
-
-from synth_ai.zyk.lms.core.main import LM
-from synth_ai.zyk.lms.tools.base import BaseTool
-from synth_ai.zyk.lms.vendors.supported.deepseek import DeepSeekAPI
-
-
-class WeatherParams(BaseModel):
-    location: str
-
-
-weather_tool = BaseTool(
-    name="get_weather",
-    description="Get current temperature for a given location.",
-    arguments=WeatherParams,
-)
-
-
-def test_weather_tool_direct():
-    client = DeepSeekAPI()
-
-    response = client._hit_api_sync(
-        model="deepseek-chat",
-        messages=[
-            {
-                "role": "system",
-                "content": "You are a helpful assistant that uses tools when appropriate.",
-            },
-            {
-                "role": "user",
-                "content": "What's the weather in Paris? Use the tools and explain your reasoning.",
-            },
-        ],
-        tools=[weather_tool],
-        lm_config={
-            "temperature": 0,
-        },
-    )
-
-    # Check that we got a tool call
-    assert response.tool_calls is not None
-    assert len(response.tool_calls) == 1
-    assert response.tool_calls[0]["function"]["name"] == "get_weather"
-    assert "Paris" in response.tool_calls[0]["function"]["arguments"]
-
-
-def test_weather_tool_lm():
-    lm = LM(
-        model_name="deepseek-chat",
-        formatting_model_name="deepseek-chat",
-        temperature=0,
-    )
-
-    response = lm.respond_sync(
-        system_message="You are a helpful assistant that uses tools when appropriate.",
-        user_message="What's the weather in Paris? Use the tools and explain your reasoning.",
-        tools=[weather_tool],
-    )
-
-    # Check that we got a tool call
-    assert response.tool_calls is not None
-    assert len(response.tool_calls) == 1
-    assert response.tool_calls[0]["function"]["name"] == "get_weather"
-    assert "Paris" in response.tool_calls[0]["function"]["arguments"]
public_tests/test_gemini_output.py
@@ -1,188 +0,0 @@
-import json
-import logging
-from typing import List
-
-import pytest
-from pydantic import BaseModel, Field
-
-from synth_ai.zyk import LM
-from synth_ai.zyk.lms.tools.base import BaseTool
-
-# Set up logging
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
-logger = logging.getLogger(__name__)
-
-
-# 1. Define the Tool Input Schema using BaseModel and Field
-class CraftaxToolArgs(BaseModel):
-    instance_id: str = Field(
-        description="The ID of the Craftax instance to interact with"
-    )
-    actions_list: List[str] = Field(
-        description="A sequence of actions to execute in the environment (e.g., ['up', 'left', 'do'])"
-    )
-    service_url: str = Field(description="The URL of the Craftax environment service")
-
-
-# 2. Define the Tool class by extending BaseTool
-class CraftaxTool(BaseTool):
-    name: str = "interact_with_craftax"
-    description: str = "Interacts with the Craftax environment by sending a sequence of actions to the service."
-    arguments = CraftaxToolArgs
-
-    async def execute(self, args: dict):
-        """Mock execution function for testing"""
-        logger.info(
-            f"Would execute actions: {args['actions_list']} for instance {args['instance_id']}"
-        )
-        return {
-            "observation": f"Executed actions: {args['actions_list']}",
-            "reward": 1.0,
-            "done": False,
-            "info": {"achievements": {"collect_wood": True}},
-        }
-
-
-# Helper function to create a simple tool dict (without RepeatedComposite)
-def create_simplified_tool():
-    return {
-        "name": "interact_with_craftax",
-        "description": "Interacts with the Craftax environment by sending a sequence of actions to the service.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "instance_id": {
-                    "type": "string",
-                    "description": "The ID of the Craftax instance to interact with",
-                },
-                "actions_list": {
-                    "type": "array",
-                    "items": {"type": "string"},
-                    "description": "A sequence of actions to execute in the environment",
-                },
-                "service_url": {
-                    "type": "string",
-                    "description": "The URL of the Craftax environment service",
-                },
-            },
-            "required": ["instance_id", "actions_list", "service_url"],
-        },
-    }
-
-
-# Define test constants
-SYSTEM_MESSAGE = """You are an agent playing Craftax. Your goal is to collect resources.
-You have access to a tool called `interact_with_craftax` to control the agent."""
-
-USER_MESSAGE = """# Map
-## Terrain_underneath_you
-grass
-## Surroundings
-- Tree is 1 steps up
-
-# Inventory
-## Resources
-- wood: 0
-
-Instructions: Collect 1 wood.
-Instance ID: test-instance-123
-Service URL: http://localhost:8002
-
-Make a tool call to execute actions. Do not explain what you're doing."""
-
-
-@pytest.mark.asyncio
-async def test_base_tool_to_json():
-    """Test that a BaseTool can be serialized to JSON in OpenAI and Gemini formats"""
-    tool = CraftaxTool()
-
-    # Test that the tool can be converted to OpenAI format
-    openai_format = tool.to_openai_tool()
-    openai_json = json.dumps(openai_format, indent=2)
-    assert "function" in openai_json
-    assert "interact_with_craftax" in openai_json
-
-    # Test that the tool can be converted to Gemini format
-    gemini_format = tool.to_gemini_tool()
-    gemini_json = json.dumps(gemini_format, indent=2)
-    assert "parameters" in gemini_json
-    assert "interact_with_craftax" in gemini_json
-
-
-@pytest.mark.asyncio
-async def test_simplified_gemini_tool():
-    """Test that a simplified Gemini tool can be serialized to JSON"""
-    simplified_tool = create_simplified_tool()
-    tool_json = json.dumps(simplified_tool, indent=2)
-    assert "parameters" in tool_json
-    assert "interact_with_craftax" in tool_json
-
-
-@pytest.mark.asyncio
-async def test_direct_gemini_tool_call():
-    """Test that calling Gemini with a directly formatted tool works"""
-    lm = LM(
-        model_name="gemini-2-flash",
-        formatting_model_name="gpt-4o-mini",
-        temperature=0,
-        max_retries="Few",
-        synth_logging=True,
-    )
-
-    # Create a direct function-only tool format
-    direct_tool = [create_simplified_tool()]
-
-    # We're expecting this to complete without errors
-    response = await lm.respond_async(
-        system_message=SYSTEM_MESSAGE,
-        user_message=USER_MESSAGE,
-        tools=direct_tool,
-    )
-
-    # Just check we got a response
-    assert response is not None
-    logger.info(f"Response with direct tool format: {response.raw_response}")
-
-    # If there are tool calls, validate basic structure
-    if response.tool_calls:
-        logger.info(f"Tool calls: {response.tool_calls}")
-        # Verify at least one tool call has the right structure
-        assert any("function" in tc for tc in response.tool_calls)
-
-
-@pytest.mark.asyncio
-async def test_base_tool_gemini_call():
-    """Test that calling Gemini with a BaseTool works"""
-    lm = LM(
-        model_name="gemini-2-flash",
-        formatting_model_name="gpt-4o-mini",
-        temperature=0,
-        max_retries="Few",
-        synth_logging=True,
-    )
-
-    # Use our properly defined BaseTool
-    tool = CraftaxTool()
-
-    # We're expecting this to complete without errors
-    response = await lm.respond_async(
-        system_message=SYSTEM_MESSAGE,
-        user_message=USER_MESSAGE,
-        tools=[tool],
-    )
-
-    # Just check we got a response
-    assert response is not None
-    logger.info(f"Response with BaseTool: {response.raw_response}")
-
-    # If there are tool calls, validate basic structure
-    if response.tool_calls:
-        logger.info(f"Tool calls: {response.tool_calls}")
-        # Verify at least one tool call has the right structure
-        assert any("function" in tc for tc in response.tool_calls)
-
-
-if __name__ == "__main__":
-    pytest.main(["-xvs", __file__])
public_tests/test_gemini_structured_outputs.py
@@ -1,106 +0,0 @@
-import asyncio
-import unittest
-from typing import List
-
-from pydantic import BaseModel, Field
-
-from synth_ai.zyk.lms.core.main import LM
-
-
-# Define example structured output models
-class SimpleResponse(BaseModel):
-    message: str
-    confidence_between_zero_one: float = Field(
-        ..., description="Confidence level between 0 and 1"
-    )
-
-
-class ComplexResponse(BaseModel):
-    title: str
-    tags: List[str]
-    content: str
-
-
-class NestedResponse(BaseModel):
-    main_category: str
-    subcategories: List[str]
-    details: SimpleResponse
-
-
-class TestLMStructuredOutputs(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        # Initialize LMs for both forced_json and stringified_json modes
-        cls.lm_forced_json = LM(
-            model_name="gpt-4o-mini",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.7,
-            max_retries="Few",
-            structured_output_mode="forced_json",
-        )
-        cls.lm_stringified_json = LM(
-            model_name="gemma3-27b-it",
-            formatting_model_name="gpt-4o-mini",
-            temperature=0.7,
-            max_retries="Few",
-            structured_output_mode="stringified_json",
-        )
-
-    def test_sync_simple_response(self):
-        for lm in [self.lm_forced_json, self.lm_stringified_json]:
-            with self.subTest(
-                mode=lm.structured_output_handler.handler.structured_output_mode
-            ):
-                result = lm.respond_sync(
-                    system_message="You are a helpful assistant.",
-                    user_message="Give me a short greeting and your confidence level.",
-                    response_model=SimpleResponse,
-                )
-                self.assertIsInstance(result.structured_output, SimpleResponse)
-                self.assertIsInstance(result.structured_output.message, str)
-                self.assertIsInstance(
-                    result.structured_output.confidence_between_zero_one, float
-                )
-                self.assertGreaterEqual(
-                    result.structured_output.confidence_between_zero_one, 0
-                )
-                self.assertLessEqual(
-                    result.structured_output.confidence_between_zero_one, 1
-                )
-
-    def test_sync_complex_response(self):
-        for lm in [self.lm_forced_json, self.lm_stringified_json]:
-            with self.subTest(
-                mode=lm.structured_output_handler.handler.structured_output_mode
-            ):
-                result = lm.respond_sync(
-                    system_message="You are a content creator.",
-                    user_message="Create a short blog post about AI.",
-                    response_model=ComplexResponse,
-                )
-                self.assertIsInstance(result.structured_output, ComplexResponse)
-                self.assertIsInstance(result.structured_output.title, str)
-                self.assertIsInstance(result.structured_output.tags, list)
-                self.assertIsInstance(result.structured_output.content, str)
-
-    async def async_nested_response(self, lm):
-        result = await lm.respond_async(
-            system_message="You are a categorization expert.",
-            user_message="Categorize 'Python' and provide a brief description.",
-            response_model=NestedResponse,
-        )
-        self.assertIsInstance(result.structured_output, NestedResponse)
-        self.assertIsInstance(result.structured_output.main_category, str)
-        self.assertIsInstance(result.structured_output.subcategories, list)
-        self.assertIsInstance(result.structured_output.details, SimpleResponse)
-
-    def test_async_nested_response(self):
-        for lm in [self.lm_forced_json, self.lm_stringified_json]:  #
-            with self.subTest(
-                mode=lm.structured_output_handler.handler.structured_output_mode
-            ):
-                asyncio.run(self.async_nested_response(lm))
-
-
-if __name__ == "__main__":
-    unittest.main()