kiln-ai 0.12.0__py3-none-any.whl → 0.13.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of kiln-ai has been flagged as potentially problematic.

Files changed (49)
  1. kiln_ai/adapters/__init__.py +4 -0
  2. kiln_ai/adapters/adapter_registry.py +157 -28
  3. kiln_ai/adapters/eval/__init__.py +28 -0
  4. kiln_ai/adapters/eval/eval_runner.py +4 -1
  5. kiln_ai/adapters/eval/g_eval.py +19 -3
  6. kiln_ai/adapters/eval/test_base_eval.py +1 -0
  7. kiln_ai/adapters/eval/test_eval_runner.py +1 -0
  8. kiln_ai/adapters/eval/test_g_eval.py +13 -7
  9. kiln_ai/adapters/fine_tune/base_finetune.py +16 -2
  10. kiln_ai/adapters/fine_tune/finetune_registry.py +2 -0
  11. kiln_ai/adapters/fine_tune/fireworks_finetune.py +8 -1
  12. kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +19 -0
  13. kiln_ai/adapters/fine_tune/test_together_finetune.py +533 -0
  14. kiln_ai/adapters/fine_tune/together_finetune.py +327 -0
  15. kiln_ai/adapters/ml_model_list.py +638 -155
  16. kiln_ai/adapters/model_adapters/__init__.py +2 -4
  17. kiln_ai/adapters/model_adapters/base_adapter.py +14 -11
  18. kiln_ai/adapters/model_adapters/litellm_adapter.py +391 -0
  19. kiln_ai/adapters/model_adapters/litellm_config.py +13 -0
  20. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -0
  21. kiln_ai/adapters/model_adapters/test_structured_output.py +23 -5
  22. kiln_ai/adapters/ollama_tools.py +3 -2
  23. kiln_ai/adapters/parsers/r1_parser.py +19 -14
  24. kiln_ai/adapters/parsers/test_r1_parser.py +17 -5
  25. kiln_ai/adapters/provider_tools.py +52 -60
  26. kiln_ai/adapters/repair/test_repair_task.py +3 -3
  27. kiln_ai/adapters/run_output.py +1 -1
  28. kiln_ai/adapters/test_adapter_registry.py +17 -20
  29. kiln_ai/adapters/test_generate_docs.py +2 -2
  30. kiln_ai/adapters/test_prompt_adaptors.py +30 -19
  31. kiln_ai/adapters/test_provider_tools.py +27 -82
  32. kiln_ai/datamodel/basemodel.py +2 -0
  33. kiln_ai/datamodel/datamodel_enums.py +2 -0
  34. kiln_ai/datamodel/json_schema.py +1 -1
  35. kiln_ai/datamodel/task_output.py +13 -6
  36. kiln_ai/datamodel/test_basemodel.py +9 -0
  37. kiln_ai/datamodel/test_datasource.py +19 -0
  38. kiln_ai/utils/config.py +46 -0
  39. kiln_ai/utils/dataset_import.py +232 -0
  40. kiln_ai/utils/test_dataset_import.py +596 -0
  41. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/METADATA +51 -7
  42. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/RECORD +44 -41
  43. kiln_ai/adapters/model_adapters/langchain_adapters.py +0 -309
  44. kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -10
  45. kiln_ai/adapters/model_adapters/openai_model_adapter.py +0 -289
  46. kiln_ai/adapters/model_adapters/test_langchain_adapter.py +0 -343
  47. kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +0 -216
  48. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/WHEEL +0 -0
  49. {kiln_ai-0.12.0.dist-info → kiln_ai-0.13.2.dist-info}/licenses/LICENSE.txt +0 -0
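
The headline change in this release is the adapter consolidation: the LangChain and OpenAI-compatible adapters (items 43-47 above) are removed in favor of a single LiteLLM-backed adapter (items 18-20). The fixtures in the new test file shown immediately below reveal the constructor surface. The following is a minimal construction sketch inferred from those tests rather than from documented API; the endpoint, model, header, and key values are placeholders, and `task` is assumed to be a kiln_ai.datamodel.Task built the same way as the mock_task fixture below.

    from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
    from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
    from kiln_ai.adapters.model_adapters.litellm_config import LiteLlmConfig

    # Field names mirror the `config` fixture in test_litellm_adapter.py below.
    config = LiteLlmConfig(
        base_url="https://api.example.com",  # placeholder; the tests show custom providers require an explicit base URL
        model_name="some-model",             # placeholder model id
        provider_name="openrouter",
        default_headers={"X-Example": "example"},
        additional_body_options={"api_key": "YOUR_KEY"},
    )

    adapter = LiteLlmAdapter(
        config=config,
        kiln_task=task,
        prompt_id="simple_prompt_builder",
        base_adapter_config=AdapterConfig(default_tags=["example"]),
    )

The structured-output tests further down use the higher-level adapter_for_task(task, model_name=..., provider=...) entry point instead, which appears to wrap this construction for the built-in providers (adapter_registry.py, item 2, is heavily reworked in this release).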
kiln_ai/adapters/model_adapters/test_litellm_adapter.py
@@ -0,0 +1,407 @@
+import json
+from unittest.mock import Mock, patch
+
+import pytest
+
+from kiln_ai.adapters.ml_model_list import ModelProviderName, StructuredOutputMode
+from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig
+from kiln_ai.adapters.model_adapters.litellm_adapter import LiteLlmAdapter
+from kiln_ai.adapters.model_adapters.litellm_config import (
+    LiteLlmConfig,
+)
+from kiln_ai.datamodel import Project, Task
+
+
+@pytest.fixture
+def mock_task(tmp_path):
+    # Create a project first since Task requires a parent
+    project_path = tmp_path / "test_project" / "project.kiln"
+    project_path.parent.mkdir()
+
+    project = Project(name="Test Project", path=str(project_path))
+    project.save_to_file()
+
+    schema = {
+        "type": "object",
+        "properties": {"test": {"type": "string"}},
+    }
+
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=project,
+        output_json_schema=json.dumps(schema),
+    )
+    task.save_to_file()
+    return task
+
+
+@pytest.fixture
+def config():
+    return LiteLlmConfig(
+        base_url="https://api.test.com",
+        model_name="test-model",
+        provider_name="openrouter",
+        default_headers={"X-Test": "test"},
+        additional_body_options={"api_key": "test_key"},
+    )
+
+
+def test_initialization(config, mock_task):
+    adapter = LiteLlmAdapter(
+        config=config,
+        kiln_task=mock_task,
+        prompt_id="simple_prompt_builder",
+        base_adapter_config=AdapterConfig(default_tags=["test-tag"]),
+    )
+
+    assert adapter.config == config
+    assert adapter.run_config.task == mock_task
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+    assert adapter.base_adapter_config.default_tags == ["test-tag"]
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.config.additional_body_options["api_key"] == "test_key"
+    assert adapter._api_base == config.base_url
+    assert adapter._headers == config.default_headers
+
+
+def test_adapter_info(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    assert adapter.adapter_name() == "kiln_openai_compatible_adapter"
+
+    assert adapter.run_config.model_name == config.model_name
+    assert adapter.run_config.model_provider_name == config.provider_name
+    assert adapter.run_config.prompt_id == "simple_prompt_builder"
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_unstructured(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock has_structured_output to return False
+    with patch.object(adapter, "has_structured_output", return_value=False):
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.json_mode,
+        StructuredOutputMode.json_instruction_and_object,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_json_mode(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert options == {"response_format": {"type": "json_object"}}
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.default,
+        StructuredOutputMode.function_calling,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_function_calling(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = mode
+
+        options = await adapter.response_format_options()
+        assert "tools" in options
+        # full tool structure validated below
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        StructuredOutputMode.json_custom_instructions,
+        StructuredOutputMode.json_instructions,
+    ],
+)
+@pytest.mark.asyncio
+async def test_response_format_options_json_instructions(config, mock_task, mode):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_instructions
+        )
+        options = await adapter.response_format_options()
+        assert options == {}
+
+
+@pytest.mark.asyncio
+async def test_response_format_options_json_schema(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    with (
+        patch.object(adapter, "has_structured_output", return_value=True),
+        patch.object(adapter, "model_provider") as mock_provider,
+    ):
+        mock_provider.return_value.structured_output_mode = (
+            StructuredOutputMode.json_schema
+        )
+        options = await adapter.response_format_options()
+        assert options == {
+            "response_format": {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "task_response",
+                    "schema": mock_task.output_schema(),
+                },
+            }
+        }
+
+
+def test_tool_call_params_weak(config, mock_task):
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=False)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }
+
+
+def test_tool_call_params_strict(config, mock_task):
+    config.provider_name = "openai"
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    params = adapter.tool_call_params(strict=True)
+    expected_schema = mock_task.output_schema()
+    expected_schema["additionalProperties"] = False
+
+    assert params == {
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "task_response",
+                    "parameters": expected_schema,
+                    "strict": True,
+                },
+            }
+        ],
+        "tool_choice": {
+            "type": "function",
+            "function": {"name": "task_response"},
+        },
+    }
+
+
+@pytest.mark.parametrize(
+    "provider_name,expected_prefix",
+    [
+        (ModelProviderName.openrouter, "openrouter"),
+        (ModelProviderName.openai, "openai"),
+        (ModelProviderName.groq, "groq"),
+        (ModelProviderName.anthropic, "anthropic"),
+        (ModelProviderName.ollama, "openai"),
+        (ModelProviderName.gemini_api, "gemini"),
+        (ModelProviderName.fireworks_ai, "fireworks_ai"),
+        (ModelProviderName.amazon_bedrock, "bedrock"),
+        (ModelProviderName.azure_openai, "azure"),
+        (ModelProviderName.huggingface, "huggingface"),
+        (ModelProviderName.vertex, "vertex_ai"),
+        (ModelProviderName.together_ai, "together_ai"),
+    ],
+)
+def test_litellm_model_id_standard_providers(
+    config, mock_task, provider_name, expected_prefix
+):
+    """Test litellm_model_id for standard providers"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method to return a provider with the specified name
+    mock_provider = Mock()
+    mock_provider.name = provider_name
+    mock_provider.model_id = "test-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        model_id = adapter.litellm_model_id()
+
+    assert model_id == f"{expected_prefix}/test-model"
+    # Verify caching works
+    assert adapter._litellm_model_id == model_id
+
+
+@pytest.mark.parametrize(
+    "provider_name",
+    [
+        ModelProviderName.openai_compatible,
+        ModelProviderName.kiln_custom_registry,
+        ModelProviderName.kiln_fine_tune,
+    ],
+)
+def test_litellm_model_id_custom_providers(config, mock_task, provider_name):
+    """Test litellm_model_id for custom providers that require a base URL"""
+    config.base_url = "https://api.custom.com"
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method
+    mock_provider = Mock()
+    mock_provider.name = provider_name
+    mock_provider.model_id = "custom-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        model_id = adapter.litellm_model_id()
+
+    # Custom providers should use "openai" as the provider name
+    assert model_id == "openai/custom-model"
+    assert adapter._litellm_model_id == model_id
+
+
+def test_litellm_model_id_custom_provider_no_base_url(config, mock_task):
+    """Test litellm_model_id raises error for custom providers without base URL"""
+    config.base_url = None
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method
+    mock_provider = Mock()
+    mock_provider.name = ModelProviderName.openai_compatible
+    mock_provider.model_id = "custom-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with pytest.raises(ValueError, match="Explicit Base URL is required"):
+            adapter.litellm_model_id()
+
+
+def test_litellm_model_id_no_model_id(config, mock_task):
+    """Test litellm_model_id raises error when provider has no model_id"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Mock the model_provider method to return a provider with no model_id
+    mock_provider = Mock()
+    mock_provider.name = ModelProviderName.openai
+    mock_provider.model_id = None
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with pytest.raises(ValueError, match="Model ID is required"):
+            adapter.litellm_model_id()
+
+
+def test_litellm_model_id_caching(config, mock_task):
+    """Test that litellm_model_id caches the result"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Set the cached value directly
+    adapter._litellm_model_id = "cached-value"
+
+    # The method should return the cached value without calling model_provider
+    with patch.object(adapter, "model_provider") as mock_model_provider:
+        model_id = adapter.litellm_model_id()
+
+    assert model_id == "cached-value"
+    mock_model_provider.assert_not_called()
+
+
+def test_litellm_model_id_unknown_provider(config, mock_task):
+    """Test litellm_model_id raises error for unknown provider"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+
+    # Create a mock provider with an unknown name
+    mock_provider = Mock()
+    mock_provider.name = "unknown_provider"  # Not in ModelProviderName enum
+    mock_provider.model_id = "test-model"
+
+    with patch.object(adapter, "model_provider", return_value=mock_provider):
+        with patch(
+            "kiln_ai.adapters.model_adapters.litellm_adapter.raise_exhaustive_enum_error"
+        ) as mock_raise_error:
+            mock_raise_error.side_effect = Exception("Test error")
+
+            with pytest.raises(Exception, match="Test error"):
+                adapter.litellm_model_id()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "top_logprobs,response_format,extra_body",
+    [
+        (None, {}, {}),  # Basic case
+        (5, {}, {}),  # With logprobs
+        (
+            None,
+            {"response_format": {"type": "json_object"}},
+            {},
+        ),  # With response format
+        (
+            3,
+            {"tools": [{"type": "function"}]},
+            {"reasoning_effort": 0.8},
+        ),  # Combined options
+    ],
+)
+async def test_build_completion_kwargs(
+    config, mock_task, top_logprobs, response_format, extra_body
+):
+    """Test build_completion_kwargs with various configurations"""
+    adapter = LiteLlmAdapter(config=config, kiln_task=mock_task)
+    mock_provider = Mock()
+    messages = [{"role": "user", "content": "Hello"}]
+
+    with (
+        patch.object(adapter, "model_provider", return_value=mock_provider),
+        patch.object(adapter, "litellm_model_id", return_value="openai/test-model"),
+        patch.object(adapter, "build_extra_body", return_value=extra_body),
+        patch.object(adapter, "response_format_options", return_value=response_format),
+    ):
+        kwargs = await adapter.build_completion_kwargs(
+            mock_provider, messages, top_logprobs
+        )
+
+    # Verify core functionality
+    assert kwargs["model"] == "openai/test-model"
+    assert kwargs["messages"] == messages
+    assert kwargs["api_base"] == config.base_url
+
+    # Verify optional parameters
+    if top_logprobs is not None:
+        assert kwargs["logprobs"] is True
+        assert kwargs["top_logprobs"] == top_logprobs
+    else:
+        assert "logprobs" not in kwargs
+        assert "top_logprobs" not in kwargs
+
+    # Verify response format is included
+    for key, value in response_format.items():
+        assert kwargs[key] == value
+
+    # Verify extra body is included
+    for key, value in extra_body.items():
+        assert kwargs[key] == value
kiln_ai/adapters/model_adapters/test_structured_output.py
@@ -66,7 +66,8 @@ async def test_mock_unstructred_response(tmp_path):
 
     # don't error on valid response
     adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
-    answer = await adapter.invoke_returning_raw("You are a mock, send me the response!")
+    run = await adapter.invoke("You are a mock, send me the response!")
+    answer = json.loads(run.output.output)
     assert answer["setup"] == "asdf"
     assert answer["punchline"] == "asdf"
 
@@ -76,9 +77,12 @@ async def test_mock_unstructred_response(tmp_path):
     answer = await adapter.invoke("You are a mock, send me the response!")
 
     adapter = MockAdapter(task, response="string instead of dict")
-    with pytest.raises(RuntimeError):
+    with pytest.raises(
+        ValueError,
+        match="This task requires JSON output but the model didn't return valid JSON",
+    ):
        # Not a structed response so should error
-        answer = await adapter.invoke("You are a mock, send me the response!")
+        run = await adapter.invoke("You are a mock, send me the response!")
 
    # Should error, expecting a string, not a dict
    project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -143,7 +147,8 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     task = build_structured_output_test_task(tmp_path)
     a = adapter_for_task(task, model_name=model_name, provider=provider)
     try:
-        parsed = await a.invoke_returning_raw("Cows")  # a joke about cows
+        run = await a.invoke("Cows")  # a joke about cows
+        parsed = json.loads(run.output.output)
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -162,6 +167,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     assert rating >= 0
     assert rating <= 10
 
+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
+
 
 def build_structured_input_test_task(tmp_path: Path):
     project = datamodel.Project(name="test", path=tmp_path / "test.kiln")
@@ -220,7 +231,8 @@ async def run_structured_input_task(
     await a.invoke({"a": 1, "b": 2, "d": 3})
 
     try:
-        response = await a.invoke_returning_raw({"a": 2, "b": 2, "c": 2})
+        run = await a.invoke({"a": 2, "b": 2, "c": 2})
+        response = run.output.output
     except ValueError as e:
         if str(e) == "Failed to connect to Ollama. Ensure Ollama is running.":
             pytest.skip(
@@ -241,6 +253,12 @@
     assert a.run_config.model_name == model_name
     assert a.run_config.model_provider_name == provider
 
+    # Check reasoning models
+    assert a._model_provider is not None
+    if a._model_provider.reasoning_capable:
+        assert "reasoning" in run.intermediate_outputs
+        assert isinstance(run.intermediate_outputs["reasoning"], str)
+
 
 @pytest.mark.paid
 async def test_structured_input_gpt_4o_mini(tmp_path):
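
Across these hunks the calling convention changes: invoke_returning_raw is gone, invoke now returns a run object whose output.output holds the raw (JSON) string for structured tasks, and reasoning-capable models surface their chain of thought in intermediate_outputs. A minimal sketch of the new pattern, assuming adapter_for_task is importable from kiln_ai.adapters.adapter_registry (inferred from the file list above) and using placeholder model and provider names:

    import json

    from kiln_ai.adapters.adapter_registry import adapter_for_task  # import path assumed

    async def run_once(task):
        # Placeholder model/provider pair; any supported combination should work.
        adapter = adapter_for_task(task, model_name="some_model", provider="openai")
        run = await adapter.invoke("Cows")
        parsed = json.loads(run.output.output)  # structured output arrives as a JSON string
        reasoning = (run.intermediate_outputs or {}).get("reasoning")  # set for reasoning-capable models
        return parsed, reasoning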
kiln_ai/adapters/ollama_tools.py
@@ -38,6 +38,7 @@ async def ollama_online() -> bool:
 
 class OllamaConnection(BaseModel):
     message: str
+    version: str | None = None
     supported_models: List[str]
     untested_models: List[str] = Field(default_factory=list)
 
@@ -49,7 +50,7 @@ class OllamaConnection(BaseModel):
 def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
     # Build a list of models we support for Ollama from the built-in model list
     supported_ollama_models = [
-        provider.provider_options["model"]
+        provider.model_id
         for model in built_in_models
         for provider in model.providers
         if provider.name == ModelProviderName.ollama
@@ -60,7 +61,7 @@ def parse_ollama_tags(tags: Any) -> OllamaConnection | None:
             alias
             for model in built_in_models
             for provider in model.providers
-            for alias in provider.provider_options.get("model_aliases", [])
+            for alias in provider.ollama_model_aliases or []
         ]
     )
 
kiln_ai/adapters/parsers/r1_parser.py
@@ -20,21 +20,33 @@ class R1ThinkingParser(BaseParser):
         Raises:
             ValueError: If response format is invalid (missing tags, multiple tags, or no content after closing tag)
         """
+
+        # The upstream providers (litellm, openrouter, fireworks) all keep changing their response formats, sometimes adding reasoning parsing where it didn't previously exist.
+        # If they do it already, great just return. If not we parse it ourselves. Not ideal, but better than upstream changes breaking the app.
+        if (
+            original_output.intermediate_outputs is not None
+            and "reasoning" in original_output.intermediate_outputs
+        ):
+            return original_output
+
         # This parser only works for strings
         if not isinstance(original_output.output, str):
             raise ValueError("Response must be a string for R1 parser")
 
         # Strip whitespace and validate basic structure
         cleaned_response = original_output.output.strip()
-        if not cleaned_response.startswith(self.START_TAG):
-            raise ValueError("Response must start with <think> tag")
 
         # Find the thinking tags
-        think_start = cleaned_response.find(self.START_TAG)
         think_end = cleaned_response.find(self.END_TAG)
+        if think_end == -1:
+            raise ValueError("Missing </think> tag")
 
-        if think_start == -1 or think_end == -1:
-            raise ValueError("Missing thinking tags")
+        think_tag_start = cleaned_response.find(self.START_TAG)
+        if think_tag_start == -1:
+            # We allow no start <think>, thinking starts on first char. QwQ does this.
+            think_start = 0
+        else:
+            think_start = think_tag_start + len(self.START_TAG)
 
         # Check for multiple tags
         if (
@@ -44,9 +56,7 @@ class R1ThinkingParser(BaseParser):
             raise ValueError("Multiple thinking tags found")
 
         # Extract thinking content
-        thinking_content = cleaned_response[
-            think_start + len(self.START_TAG) : think_end
-        ].strip()
+        thinking_content = cleaned_response[think_start:think_end].strip()
 
         # Extract result (everything after </think>)
         result = cleaned_response[think_end + len(self.END_TAG) :].strip()
@@ -54,16 +64,11 @@ class R1ThinkingParser(BaseParser):
         if not result or len(result) == 0:
             raise ValueError("No content found after </think> tag")
 
-        # Parse JSON if needed
-        output = result
-        if self.structured_output:
-            output = parse_json_string(result)
-
         # Add thinking content to intermediate outputs if it exists
         intermediate_outputs = original_output.intermediate_outputs or {}
         intermediate_outputs["reasoning"] = thinking_content
 
         return RunOutput(
-            output=output,
+            output=result,
             intermediate_outputs=intermediate_outputs,
         )
kiln_ai/adapters/parsers/test_r1_parser.py
@@ -19,6 +19,16 @@ def test_valid_response(parser):
     assert parsed.output == "This is the result"
 
 
+def test_already_parsed_response(parser):
+    response = RunOutput(
+        output="This is the result",
+        intermediate_outputs={"reasoning": "This is thinking content"},
+    )
+    parsed = parser.parse_output(response)
+    assert parsed.intermediate_outputs["reasoning"] == "This is thinking content"
+    assert parsed.output == "This is the result"
+
+
 def test_response_with_whitespace(parser):
     response = RunOutput(
         output="""
@@ -37,14 +47,16 @@ def test_response_with_whitespace(parser):
 
 
 def test_missing_start_tag(parser):
-    with pytest.raises(ValueError, match="Response must start with <think> tag"):
-        parser.parse_output(
-            RunOutput(output="Some content</think>result", intermediate_outputs=None)
-        )
+    parsed = parser.parse_output(
+        RunOutput(output="Some content</think>result", intermediate_outputs=None)
+    )
+
+    assert parsed.intermediate_outputs["reasoning"] == "Some content"
+    assert parsed.output == "result"
 
 
 def test_missing_end_tag(parser):
-    with pytest.raises(ValueError, match="Missing thinking tags"):
+    with pytest.raises(ValueError, match="Missing </think> tag"):
         parser.parse_output(
             RunOutput(output="<think>Some content", intermediate_outputs=None)
         )