casual-llm 0.1.0__tar.gz → 0.2.0__tar.gz
This diff compares publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.
- {casual_llm-0.1.0/src/casual_llm.egg-info → casual_llm-0.2.0}/PKG-INFO +1 -1
- {casual_llm-0.1.0 → casual_llm-0.2.0}/pyproject.toml +1 -1
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/__init__.py +1 -1
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/providers/base.py +19 -2
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/providers/ollama.py +26 -3
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/providers/openai.py +32 -2
- {casual_llm-0.1.0 → casual_llm-0.2.0/src/casual_llm.egg-info}/PKG-INFO +1 -1
- {casual_llm-0.1.0 → casual_llm-0.2.0}/tests/test_providers.py +245 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/LICENSE +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/README.md +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/setup.cfg +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/config.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/message_converters/__init__.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/message_converters/ollama.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/message_converters/openai.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/messages.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/providers/__init__.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/py.typed +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/tool_converters/__init__.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/tool_converters/ollama.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/tool_converters/openai.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/tools.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/usage.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm.egg-info/SOURCES.txt +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm.egg-info/dependency_links.txt +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm.egg-info/requires.txt +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm.egg-info/top_level.txt +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/tests/test_messages.py +0 -0
- {casual_llm-0.1.0 → casual_llm-0.2.0}/tests/test_tools.py +0 -0
{casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/__init__.py
@@ -7,7 +7,7 @@ A simple, protocol-based library for working with different LLM providers
 Part of the casual-* ecosystem of lightweight AI tools.
 """
 
-__version__ = "0.1.0"
+__version__ = "0.2.0"
 
 # Model configuration
 from casual_llm.config import ModelConfig, Provider

{casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/providers/base.py
@@ -9,6 +9,8 @@ from __future__ import annotations
 
 from typing import Protocol, Literal
 
+from pydantic import BaseModel
+
 from casual_llm.messages import ChatMessage, AssistantMessage
 from casual_llm.tools import Tool
 from casual_llm.usage import Usage
@@ -37,7 +39,7 @@ class LLMProvider(Protocol):
     async def chat(
         self,
         messages: list[ChatMessage],
-        response_format: Literal["json", "text"] = "text",
+        response_format: Literal["json", "text"] | type[BaseModel] = "text",
         max_tokens: int | None = None,
         tools: list[Tool] | None = None,
         temperature: float | None = None,
@@ -47,7 +49,9 @@ class LLMProvider(Protocol):
 
         Args:
             messages: List of ChatMessage (UserMessage, AssistantMessage, SystemMessage, etc.)
-            response_format: Expected response format
+            response_format: Expected response format. Can be "json", "text", or a Pydantic
+                BaseModel class for JSON Schema-based structured output. When a Pydantic model
+                is provided, the LLM will be instructed to return JSON matching the schema.
             max_tokens: Maximum tokens to generate (optional)
             tools: List of tools available for the LLM to call (optional)
             temperature: Temperature for this request (optional, overrides instance temperature)
@@ -57,6 +61,19 @@ class LLMProvider(Protocol):
 
         Raises:
             Provider-specific exceptions (httpx.HTTPError, openai.OpenAIError, etc.)
+
+        Examples:
+            >>> from pydantic import BaseModel
+            >>>
+            >>> class PersonInfo(BaseModel):
+            ...     name: str
+            ...     age: int
+            >>>
+            >>> # Pass Pydantic model for structured output
+            >>> response = await provider.chat(
+            ...     messages=[UserMessage(content="Tell me about a person")],
+            ...     response_format=PersonInfo  # Pass the class, not an instance
+            ... )
         """
         ...
 
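The protocol change above defines the new parameter but leaves parsing of the returned JSON to the caller. A minimal usage sketch, reusing the illustrative PersonInfo model from the docstring example (the extract_person helper is hypothetical, not part of the package):

from pydantic import BaseModel

from casual_llm.messages import UserMessage


class PersonInfo(BaseModel):
    name: str
    age: int


async def extract_person(provider) -> PersonInfo:
    # provider is any LLMProvider implementation (OllamaProvider, OpenAIProvider, ...)
    response = await provider.chat(
        messages=[UserMessage(content="Tell me about a person")],
        response_format=PersonInfo,  # pass the class, not an instance
    )
    # The provider only constrains the output to the schema; turning the returned
    # JSON text back into the model is up to the caller (Pydantic v2 API).
    return PersonInfo.model_validate_json(response.content)
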
{casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/providers/ollama.py
@@ -7,6 +7,7 @@ from __future__ import annotations
 import logging
 from typing import Any, Literal
 from ollama import AsyncClient
+from pydantic import BaseModel
 
 from casual_llm.messages import ChatMessage, AssistantMessage
 from casual_llm.tools import Tool
@@ -69,7 +70,7 @@ class OllamaProvider:
     async def chat(
         self,
         messages: list[ChatMessage],
-        response_format: Literal["json", "text"] = "text",
+        response_format: Literal["json", "text"] | type[BaseModel] = "text",
         max_tokens: int | None = None,
         tools: list[Tool] | None = None,
         temperature: float | None = None,
@@ -79,7 +80,10 @@ class OllamaProvider:
 
         Args:
             messages: Conversation messages (ChatMessage format)
-            response_format: "json" for
+            response_format: "json" for JSON output, "text" for plain text, or a Pydantic
+                BaseModel class for JSON Schema-based structured output. When a Pydantic
+                model is provided, the LLM will be instructed to return JSON matching the
+                schema.
             max_tokens: Maximum tokens to generate (optional)
             tools: List of tools available for the LLM to call (optional)
             temperature: Temperature for this request (optional, overrides instance temperature)
@@ -90,6 +94,19 @@ class OllamaProvider:
         Raises:
             ResponseError: If the request could not be fulfilled
             RequestError: If the request was invalid
+
+        Examples:
+            >>> from pydantic import BaseModel
+            >>>
+            >>> class PersonInfo(BaseModel):
+            ...     name: str
+            ...     age: int
+            >>>
+            >>> # Pass Pydantic model for structured output
+            >>> response = await provider.chat(
+            ...     messages=[UserMessage(content="Tell me about a person")],
+            ...     response_format=PersonInfo  # Pass the class, not an instance
+            ... )
         """
         # Convert messages to Ollama format using converter
         chat_messages = convert_messages_to_ollama(messages)
@@ -113,9 +130,15 @@ class OllamaProvider:
             "options": options,
         }
 
-        #
+        # Handle response_format: "json", "text", or Pydantic model class
         if response_format == "json":
            request_kwargs["format"] = "json"
+        elif isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            # Extract JSON Schema from Pydantic model and pass directly to format
+            schema = response_format.model_json_schema()
+            request_kwargs["format"] = schema
+            logger.debug(f"Using JSON Schema from Pydantic model: {response_format.__name__}")
+        # "text" is the default - no format parameter needed
 
         # Add tools if provided
         if tools:
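For Ollama, the new branch forwards the dict returned by model_json_schema() as the chat request's format value instead of the literal string "json". A sketch of what that schema contains for the PersonInfo example, assuming Pydantic v2's default schema generation:

from pydantic import BaseModel

# Illustrative model, mirroring the docstring example above.
class PersonInfo(BaseModel):
    name: str
    age: int

schema = PersonInfo.model_json_schema()
# With Pydantic v2 this is a plain dict roughly like:
# {
#     "title": "PersonInfo",
#     "type": "object",
#     "properties": {
#         "name": {"title": "Name", "type": "string"},
#         "age": {"title": "Age", "type": "integer"},
#     },
#     "required": ["name", "age"],
# }
# OllamaProvider passes this dict as request_kwargs["format"] instead of "json".
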
{casual_llm-0.1.0 → casual_llm-0.2.0}/src/casual_llm/providers/openai.py
@@ -7,6 +7,7 @@ from __future__ import annotations
 import logging
 from typing import Literal, Any
 from openai import AsyncOpenAI
+from pydantic import BaseModel
 
 from casual_llm.messages import ChatMessage, AssistantMessage
 from casual_llm.tools import Tool
@@ -82,7 +83,7 @@ class OpenAIProvider:
     async def chat(
         self,
         messages: list[ChatMessage],
-        response_format: Literal["json", "text"] = "text",
+        response_format: Literal["json", "text"] | type[BaseModel] = "text",
         max_tokens: int | None = None,
         tools: list[Tool] | None = None,
         temperature: float | None = None,
@@ -92,7 +93,10 @@ class OpenAIProvider:
 
         Args:
             messages: Conversation messages (ChatMessage format)
-            response_format: "json" for
+            response_format: "json" for JSON output, "text" for plain text, or a Pydantic
+                BaseModel class for JSON Schema-based structured output. When a Pydantic
+                model is provided, the LLM will be instructed to return JSON matching the
+                schema.
             max_tokens: Maximum tokens to generate (optional)
             tools: List of tools available for the LLM to call (optional)
             temperature: Temperature for this request (optional, overrides instance temperature)
@@ -102,6 +106,19 @@ class OpenAIProvider:
 
         Raises:
             openai.OpenAIError: If request fails
+
+        Examples:
+            >>> from pydantic import BaseModel
+            >>>
+            >>> class PersonInfo(BaseModel):
+            ...     name: str
+            ...     age: int
+            >>>
+            >>> # Pass Pydantic model for structured output
+            >>> response = await provider.chat(
+            ...     messages=[UserMessage(content="Tell me about a person")],
+            ...     response_format=PersonInfo  # Pass the class, not an instance
+            ... )
         """
         # Convert messages to OpenAI format using converter
         chat_messages = convert_messages_to_openai(messages)
@@ -120,8 +137,21 @@ class OpenAIProvider:
         if temp is not None:
             request_kwargs["temperature"] = temp
 
+        # Handle response_format: "json", "text", or Pydantic model class
         if response_format == "json":
             request_kwargs["response_format"] = {"type": "json_object"}
+        elif isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            # Extract JSON Schema from Pydantic model
+            schema = response_format.model_json_schema()
+            request_kwargs["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_format.__name__,
+                    "schema": schema,
+                },
+            }
+            logger.debug(f"Using JSON Schema from Pydantic model: {response_format.__name__}")
+        # "text" is the default - no response_format needed
 
         if max_tokens:
             request_kwargs["max_tokens"] = max_tokens
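Where the Ollama client accepts the schema dict directly, the OpenAI branch wraps it in the Chat Completions json_schema envelope. A sketch of the parameter it builds, again using the illustrative PersonInfo model (not part of the diff):

from pydantic import BaseModel

# Illustrative model, mirroring the docstring example above.
class PersonInfo(BaseModel):
    name: str
    age: int

# Envelope built by the new elif branch for the Chat Completions API.
response_format_param = {
    "type": "json_schema",
    "json_schema": {
        "name": PersonInfo.__name__,               # "PersonInfo"
        "schema": PersonInfo.model_json_schema(),  # plain JSON Schema dict
    },
}
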
{casual_llm-0.1.0 → casual_llm-0.2.0}/tests/test_providers.py
@@ -3,12 +3,38 @@ Tests for LLM provider implementations.
 """
 
 import pytest
+from pydantic import BaseModel
 from unittest.mock import AsyncMock, MagicMock, patch
 from casual_llm.config import ModelConfig, Provider
 from casual_llm.providers import OllamaProvider, create_provider
 from casual_llm.messages import UserMessage, AssistantMessage, SystemMessage
 from casual_llm.usage import Usage
 
+
+# Test Pydantic models for JSON Schema tests
+class PersonInfo(BaseModel):
+    """Simple Pydantic model for testing"""
+
+    name: str
+    age: int
+
+
+class Address(BaseModel):
+    """Nested model for testing complex schemas"""
+
+    street: str
+    city: str
+    zip_code: str
+
+
+class PersonWithAddress(BaseModel):
+    """Pydantic model with nested structure for testing"""
+
+    name: str
+    age: int
+    address: Address
+
+
 # Try to import OpenAI provider - may not be available
 try:
     from casual_llm.providers import OpenAIProvider
@@ -187,6 +213,110 @@ class TestOllamaProvider:
         assert usage.completion_tokens == 20
         assert usage.total_tokens == 30
 
+    @pytest.mark.asyncio
+    async def test_json_schema_response_format(self, provider):
+        """Test that Pydantic model is correctly converted to JSON Schema for Ollama"""
+        mock_response = MagicMock()
+        mock_response.message.content = '{"name": "Alice", "age": 30}'
+        mock_response.message.tool_calls = None
+
+        mock_chat = AsyncMock(return_value=mock_response)
+
+        with patch("ollama.AsyncClient.chat", new=mock_chat):
+            messages = [UserMessage(content="Give me person info")]
+
+            result = await provider.chat(messages, response_format=PersonInfo)
+
+            assert isinstance(result, AssistantMessage)
+            assert '{"name": "Alice", "age": 30}' in result.content
+
+            # Verify the format parameter contains the JSON Schema
+            call_kwargs = mock_chat.call_args.kwargs
+            assert "format" in call_kwargs
+            schema = call_kwargs["format"]
+
+            # Verify it's a dict (JSON Schema), not a string
+            assert isinstance(schema, dict)
+
+            # Verify schema has expected properties
+            assert "properties" in schema
+            assert "name" in schema["properties"]
+            assert "age" in schema["properties"]
+            assert schema["properties"]["name"]["type"] == "string"
+            assert schema["properties"]["age"]["type"] == "integer"
+
+    @pytest.mark.asyncio
+    async def test_json_schema_nested_pydantic_model(self, provider):
+        """Test that complex nested Pydantic models work correctly"""
+        mock_response = MagicMock()
+        mock_response.message.content = '{"name": "Bob", "age": 25, "address": {"street": "123 Main St", "city": "NYC", "zip_code": "10001"}}'
+        mock_response.message.tool_calls = None
+
+        mock_chat = AsyncMock(return_value=mock_response)
+
+        with patch("ollama.AsyncClient.chat", new=mock_chat):
+            messages = [UserMessage(content="Give me person with address")]
+
+            result = await provider.chat(messages, response_format=PersonWithAddress)
+
+            assert isinstance(result, AssistantMessage)
+
+            # Verify the format parameter contains the nested JSON Schema
+            call_kwargs = mock_chat.call_args.kwargs
+            assert "format" in call_kwargs
+            schema = call_kwargs["format"]
+
+            # Verify it's a dict with properties
+            assert isinstance(schema, dict)
+            assert "properties" in schema
+
+            # Verify nested structure is present (either through $defs or inline)
+            # Pydantic v2 uses $defs for nested models
+            if "$defs" in schema:
+                assert "Address" in schema["$defs"]
+
+    @pytest.mark.asyncio
+    async def test_backward_compat_json_format(self, provider):
+        """Test that existing 'json' format still works (backward compatibility)"""
+        mock_response = MagicMock()
+        mock_response.message.content = '{"status": "ok"}'
+        mock_response.message.tool_calls = None
+
+        mock_chat = AsyncMock(return_value=mock_response)
+
+        with patch("ollama.AsyncClient.chat", new=mock_chat):
+            messages = [UserMessage(content="Give me JSON")]
+
+            result = await provider.chat(messages, response_format="json")
+
+            assert isinstance(result, AssistantMessage)
+            assert '{"status": "ok"}' in result.content
+
+            # Verify format is set to "json" string (not a schema dict)
+            call_kwargs = mock_chat.call_args.kwargs
+            assert call_kwargs["format"] == "json"
+
+    @pytest.mark.asyncio
+    async def test_backward_compat_text_format(self, provider):
+        """Test that existing 'text' format still works (backward compatibility)"""
+        mock_response = MagicMock()
+        mock_response.message.content = "Plain text response"
+        mock_response.message.tool_calls = None
+
+        mock_chat = AsyncMock(return_value=mock_response)
+
+        with patch("ollama.AsyncClient.chat", new=mock_chat):
+            messages = [UserMessage(content="Give me text")]
+
+            result = await provider.chat(messages, response_format="text")
+
+            assert isinstance(result, AssistantMessage)
+            assert result.content == "Plain text response"
+
+            # Verify no format parameter is set for text
+            call_kwargs = mock_chat.call_args.kwargs
+            assert "format" not in call_kwargs
+
 
 @pytest.mark.skipif(not OPENAI_AVAILABLE, reason="OpenAI provider not installed")
 class TestOpenAIProvider:
@@ -362,6 +492,121 @@ class TestOpenAIProvider:
         assert usage.completion_tokens == 25
         assert usage.total_tokens == 40
 
+    @pytest.mark.asyncio
+    async def test_json_schema_response_format(self, provider):
+        """Test that Pydantic model is correctly converted to JSON Schema for OpenAI"""
+        mock_completion = MagicMock()
+        mock_completion.choices = [
+            MagicMock(message=MagicMock(content='{"name": "Alice", "age": 30}'))
+        ]
+
+        mock_create = AsyncMock(return_value=mock_completion)
+
+        with patch.object(provider.client.chat.completions, "create", new=mock_create):
+            messages = [UserMessage(content="Give me person info")]
+
+            result = await provider.chat(messages, response_format=PersonInfo)
+
+            assert isinstance(result, AssistantMessage)
+            assert '{"name": "Alice", "age": 30}' in result.content
+
+            # Verify the response_format parameter contains the JSON Schema structure
+            call_kwargs = mock_create.call_args.kwargs
+            assert "response_format" in call_kwargs
+            response_format = call_kwargs["response_format"]
+
+            # Verify OpenAI json_schema format structure
+            assert response_format["type"] == "json_schema"
+            assert "json_schema" in response_format
+            assert response_format["json_schema"]["name"] == "PersonInfo"
+            assert "schema" in response_format["json_schema"]
+
+            # Verify schema has expected properties
+            schema = response_format["json_schema"]["schema"]
+            assert "properties" in schema
+            assert "name" in schema["properties"]
+            assert "age" in schema["properties"]
+            assert schema["properties"]["name"]["type"] == "string"
+            assert schema["properties"]["age"]["type"] == "integer"
+
+    @pytest.mark.asyncio
+    async def test_json_schema_nested_pydantic_model(self, provider):
+        """Test that complex nested Pydantic models work correctly"""
+        mock_completion = MagicMock()
+        mock_completion.choices = [
+            MagicMock(
+                message=MagicMock(
+                    content='{"name": "Bob", "age": 25, "address": {"street": "123 Main St", "city": "NYC", "zip_code": "10001"}}'
+                )
+            )
+        ]
+
+        mock_create = AsyncMock(return_value=mock_completion)
+
+        with patch.object(provider.client.chat.completions, "create", new=mock_create):
+            messages = [UserMessage(content="Give me person with address")]
+
+            result = await provider.chat(messages, response_format=PersonWithAddress)
+
+            assert isinstance(result, AssistantMessage)
+
+            # Verify the response_format parameter contains the nested JSON Schema
+            call_kwargs = mock_create.call_args.kwargs
+            assert "response_format" in call_kwargs
+            response_format = call_kwargs["response_format"]
+
+            # Verify OpenAI json_schema format structure
+            assert response_format["type"] == "json_schema"
+            assert response_format["json_schema"]["name"] == "PersonWithAddress"
+
+            schema = response_format["json_schema"]["schema"]
+            assert "properties" in schema
+
+            # Verify nested structure is present (either through $defs or inline)
+            # Pydantic v2 uses $defs for nested models
+            if "$defs" in schema:
+                assert "Address" in schema["$defs"]
+
+    @pytest.mark.asyncio
+    async def test_backward_compat_json_format(self, provider):
+        """Test that existing 'json' format still works (backward compatibility)"""
+        mock_completion = MagicMock()
+        mock_completion.choices = [MagicMock(message=MagicMock(content='{"status": "ok"}'))]
+
+        mock_create = AsyncMock(return_value=mock_completion)
+
+        with patch.object(provider.client.chat.completions, "create", new=mock_create):
+            messages = [UserMessage(content="Give me JSON")]
+
+            result = await provider.chat(messages, response_format="json")
+
+            assert isinstance(result, AssistantMessage)
+            assert '{"status": "ok"}' in result.content
+
+            # Verify response_format is set to json_object (not json_schema)
+            call_kwargs = mock_create.call_args.kwargs
+            assert call_kwargs["response_format"] == {"type": "json_object"}
+
+    @pytest.mark.asyncio
+    async def test_backward_compat_text_format(self, provider):
+        """Test that existing 'text' format still works (backward compatibility)"""
+        mock_completion = MagicMock()
+        mock_completion.choices = [MagicMock(message=MagicMock(content="Plain text response"))]
+
+        mock_create = AsyncMock(return_value=mock_completion)
+
+        with patch.object(provider.client.chat.completions, "create", new=mock_create):
+            messages = [UserMessage(content="Give me text")]
+
+            result = await provider.chat(messages, response_format="text")
+
+            assert isinstance(result, AssistantMessage)
+            assert result.content == "Plain text response"
+
+            # Verify no response_format parameter is set for text
+            call_kwargs = mock_create.call_args.kwargs
+            assert "response_format" not in call_kwargs
+
 
 class TestCreateProviderFactory:
     """Tests for create_provider() factory function"""
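The nested-model tests only check for "$defs" conditionally because the schema layout depends on the Pydantic version. For orientation, a sketch of what Pydantic v2 typically generates for the nested test model (this is Pydantic behavior, not something defined by this package):

from pydantic import BaseModel


class Address(BaseModel):
    street: str
    city: str
    zip_code: str


class PersonWithAddress(BaseModel):
    name: str
    age: int
    address: Address


schema = PersonWithAddress.model_json_schema()
# Pydantic v2 places the nested model under "$defs" and references it:
#   schema["$defs"]["Address"]               -> the Address object schema
#   schema["properties"]["address"]["$ref"]  -> "#/$defs/Address"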