kiln-ai 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiln_ai/adapters/__init__.py +2 -0
- kiln_ai/adapters/adapter_registry.py +22 -44
- kiln_ai/adapters/chat/__init__.py +8 -0
- kiln_ai/adapters/chat/chat_formatter.py +233 -0
- kiln_ai/adapters/chat/test_chat_formatter.py +131 -0
- kiln_ai/adapters/data_gen/data_gen_prompts.py +121 -36
- kiln_ai/adapters/data_gen/data_gen_task.py +49 -36
- kiln_ai/adapters/data_gen/test_data_gen_task.py +330 -40
- kiln_ai/adapters/eval/base_eval.py +7 -6
- kiln_ai/adapters/eval/eval_runner.py +9 -2
- kiln_ai/adapters/eval/g_eval.py +40 -17
- kiln_ai/adapters/eval/test_base_eval.py +174 -17
- kiln_ai/adapters/eval/test_eval_runner.py +3 -0
- kiln_ai/adapters/eval/test_g_eval.py +116 -5
- kiln_ai/adapters/fine_tune/base_finetune.py +3 -8
- kiln_ai/adapters/fine_tune/dataset_formatter.py +135 -273
- kiln_ai/adapters/fine_tune/test_base_finetune.py +10 -10
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +287 -353
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +3 -3
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +6 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +1 -0
- kiln_ai/adapters/fine_tune/test_vertex_finetune.py +6 -11
- kiln_ai/adapters/fine_tune/together_finetune.py +13 -2
- kiln_ai/adapters/ml_model_list.py +370 -84
- kiln_ai/adapters/model_adapters/base_adapter.py +73 -26
- kiln_ai/adapters/model_adapters/litellm_adapter.py +88 -97
- kiln_ai/adapters/model_adapters/litellm_config.py +3 -2
- kiln_ai/adapters/model_adapters/test_base_adapter.py +235 -61
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +104 -21
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -0
- kiln_ai/adapters/model_adapters/test_structured_output.py +44 -12
- kiln_ai/adapters/parsers/parser_registry.py +0 -2
- kiln_ai/adapters/parsers/r1_parser.py +0 -1
- kiln_ai/adapters/prompt_builders.py +0 -16
- kiln_ai/adapters/provider_tools.py +27 -9
- kiln_ai/adapters/remote_config.py +66 -0
- kiln_ai/adapters/repair/repair_task.py +1 -6
- kiln_ai/adapters/repair/test_repair_task.py +24 -3
- kiln_ai/adapters/test_adapter_registry.py +88 -28
- kiln_ai/adapters/test_ml_model_list.py +176 -0
- kiln_ai/adapters/test_prompt_adaptors.py +17 -7
- kiln_ai/adapters/test_prompt_builders.py +3 -16
- kiln_ai/adapters/test_provider_tools.py +69 -20
- kiln_ai/adapters/test_remote_config.py +100 -0
- kiln_ai/datamodel/__init__.py +0 -2
- kiln_ai/datamodel/datamodel_enums.py +38 -13
- kiln_ai/datamodel/eval.py +32 -0
- kiln_ai/datamodel/finetune.py +12 -8
- kiln_ai/datamodel/task.py +68 -7
- kiln_ai/datamodel/task_output.py +0 -2
- kiln_ai/datamodel/task_run.py +0 -2
- kiln_ai/datamodel/test_basemodel.py +2 -1
- kiln_ai/datamodel/test_dataset_split.py +0 -8
- kiln_ai/datamodel/test_eval_model.py +146 -4
- kiln_ai/datamodel/test_models.py +33 -10
- kiln_ai/datamodel/test_task.py +168 -2
- kiln_ai/utils/config.py +3 -2
- kiln_ai/utils/dataset_import.py +1 -1
- kiln_ai/utils/logging.py +166 -0
- kiln_ai/utils/test_config.py +23 -0
- kiln_ai/utils/test_dataset_import.py +30 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/METADATA +2 -2
- kiln_ai-0.18.0.dist-info/RECORD +115 -0
- kiln_ai-0.16.0.dist-info/RECORD +0 -108
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.16.0.dist-info → kiln_ai-0.18.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/__init__.py
CHANGED
@@ -17,6 +17,7 @@ The eval submodule contains the code for evaluating the performance of a model.
 """

 from . import (
+    chat,
     data_gen,
     eval,
     fine_tune,
@@ -28,6 +29,7 @@ from . import (

 __all__ = [
     "model_adapters",
+    "chat",
     "data_gen",
     "fine_tune",
     "ml_model_list",
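The new chat submodule is now exported from kiln_ai.adapters. A quick import check, a minimal sketch assuming the chat/__init__.py listed above re-exports the symbols that the new tests import:

from kiln_ai.adapters import chat
from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter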
kiln_ai/adapters/adapter_registry.py
CHANGED
@@ -7,31 +7,33 @@ from kiln_ai.adapters.model_adapters.litellm_adapter import (
     LiteLlmAdapter,
     LiteLlmConfig,
 )
-from kiln_ai.adapters.provider_tools import
-
+from kiln_ai.adapters.provider_tools import (
+    core_provider,
+    lite_llm_config_for_openai_compatible,
+)
+from kiln_ai.datamodel.task import RunConfigProperties
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error


 def adapter_for_task(
     kiln_task: datamodel.Task,
-
-    provider: ModelProviderName,
-    prompt_id: PromptId | None = None,
+    run_config_properties: RunConfigProperties,
     base_adapter_config: AdapterConfig | None = None,
 ) -> BaseAdapter:
     # Get the provider to run. For things like the fine-tune provider, we want to run the underlying provider
-    core_provider_name = core_provider(
+    core_provider_name = core_provider(
+        run_config_properties.model_name, run_config_properties.model_provider_name
+    )

     match core_provider_name:
         case ModelProviderName.openrouter:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
                 config=LiteLlmConfig(
-
+                    run_config_properties=run_config_properties,
                     base_url=getenv("OPENROUTER_BASE_URL")
                     or "https://openrouter.ai/api/v1",
-                    provider_name=provider,
                     default_headers={
                         "HTTP-Referer": "https://getkiln.ai/openrouter",
                         "X-Title": "KilnAI",
@@ -40,38 +42,32 @@ def adapter_for_task(
                         "api_key": Config.shared().open_router_api_key,
                     },
                 ),
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
             )
         case ModelProviderName.openai:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().open_ai_api_key,
                     },
                 ),
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
             )
         case ModelProviderName.openai_compatible:
-            config =
+            config = lite_llm_config_for_openai_compatible(run_config_properties)
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
                 config=config,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
             )
         case ModelProviderName.groq:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().groq_api_key,
                     },
@@ -80,11 +76,9 @@ def adapter_for_task(
         case ModelProviderName.amazon_bedrock:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "aws_access_key_id": Config.shared().bedrock_access_key,
                         "aws_secret_access_key": Config.shared().bedrock_secret_key,
@@ -99,11 +93,9 @@ def adapter_for_task(
             )
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     # Set the Ollama base URL for 2 reasons:
                     # 1. To use the correct base URL
                     # 2. We use Ollama's OpenAI compatible API (/v1), and don't just let litellm use the Ollama API. We use more advanced features like json_schema.
@@ -117,11 +109,9 @@ def adapter_for_task(
         case ModelProviderName.fireworks_ai:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().fireworks_api_key,
                     },
@@ -130,11 +120,9 @@ def adapter_for_task(
         case ModelProviderName.anthropic:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().anthropic_api_key,
                     },
@@ -143,11 +131,9 @@ def adapter_for_task(
         case ModelProviderName.gemini_api:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().gemini_api_key,
                     },
@@ -156,11 +142,9 @@ def adapter_for_task(
         case ModelProviderName.vertex:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "vertex_project": Config.shared().vertex_project_id,
                         "vertex_location": Config.shared().vertex_location,
@@ -170,11 +154,9 @@ def adapter_for_task(
         case ModelProviderName.together_ai:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().together_api_key,
                     },
@@ -183,12 +165,10 @@ def adapter_for_task(
         case ModelProviderName.azure_openai:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
                     base_url=Config.shared().azure_openai_endpoint,
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().azure_openai_api_key,
                         "api_version": "2025-02-01-preview",
@@ -198,11 +178,9 @@ def adapter_for_task(
         case ModelProviderName.huggingface:
             return LiteLlmAdapter(
                 kiln_task=kiln_task,
-                prompt_id=prompt_id,
                 base_adapter_config=base_adapter_config,
                 config=LiteLlmConfig(
-
-                    provider_name=provider,
+                    run_config_properties=run_config_properties,
                     additional_body_options={
                         "api_key": Config.shared().huggingface_api_key,
                     },
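Call sites change accordingly: instead of passing model name, provider, and prompt id as separate arguments, callers now hand adapter_for_task a single RunConfigProperties. Below is a minimal sketch of the new call shape; only model_name and model_provider_name appear in the diff above, so the remaining fields and all values are illustrative assumptions, not confirmed API.

from kiln_ai import datamodel
from kiln_ai.adapters.adapter_registry import adapter_for_task
from kiln_ai.datamodel.task import RunConfigProperties

# Illustrative task; real projects normally create tasks inside a saved Project.
task = datamodel.Task(name="summarize", instruction="Summarize the input text.")

run_config = RunConfigProperties(
    model_name="llama_3_1_8b",             # read by core_provider() in the diff above
    model_provider_name="groq",            # read by core_provider() in the diff above
    prompt_id="simple_prompt_builder",     # assumed field, replacing the old prompt_id argument
    structured_output_mode="json_schema",  # assumed field
)

adapter = adapter_for_task(task, run_config_properties=run_config)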
kiln_ai/adapters/chat/chat_formatter.py
ADDED (+233)
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional

from kiln_ai.datamodel.datamodel_enums import ChatStrategy
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error

COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."


@dataclass
class ChatMessage:
    role: Literal["system", "assistant", "user"]
    content: Optional[str]


@dataclass
class ChatTurn:
    """
    All data needed to send a chat turn to the model.
    """

    messages: List[ChatMessage]
    final_call: bool


class ChatFormatter(ABC):
    def __init__(
        self,
        system_message: str,
        user_input: str | Dict,
        thinking_instructions: str | None = None,
    ) -> None:
        self.system_message = system_message
        self.user_input = user_input
        self.thinking_instructions = thinking_instructions
        self._messages: List[ChatMessage] = []
        self._state = "start"
        self._intermediate_outputs: Dict[str, str] = {}

    @property
    def messages(self) -> List[ChatMessage]:
        return list(self._messages)

    def message_dicts(self) -> List[dict[str, str | None]]:
        return [{"role": m.role, "content": m.content} for m in self._messages]

    def intermediate_outputs(self) -> Dict[str, str]:
        """Get the intermediate outputs from the chat formatter."""
        return self._intermediate_outputs

    @abstractmethod
    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        """Advance the conversation and return the next messages if any."""
        raise NotImplementedError


class SingleTurnFormatter(ChatFormatter):
    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        if self._state == "start":
            msgs = [
                ChatMessage("system", self.system_message),
                ChatMessage("user", format_user_message(self.user_input)),
            ]
            self._state = "awaiting_final"
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=True)

        if self._state == "awaiting_final":
            if previous_output is None:
                raise ValueError("previous_output required for final step")
            self._messages.append(ChatMessage("assistant", previous_output))
            self._state = "done"
            return None

        return None


class TwoMessageCotLegacyFormatter(ChatFormatter):
    def __init__(
        self,
        system_message: str,
        user_input: str | Dict,
        thinking_instructions: str | None,
    ) -> None:
        super().__init__(system_message, user_input, thinking_instructions)
        if self.thinking_instructions is None:
            raise ValueError(
                "thinking_instructions are required when strategy is final_and_intermediate"
            )

    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        if self._state == "start":
            msgs = [
                ChatMessage("system", self.system_message),
                ChatMessage("user", format_user_message(self.user_input)),
                ChatMessage("system", self.thinking_instructions),
            ]
            self._state = "awaiting_thinking"
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=False)

        if self._state == "awaiting_thinking":
            if previous_output is None:
                raise ValueError("previous_output required for thinking step")
            msgs = [
                ChatMessage("assistant", previous_output),
                ChatMessage("user", COT_FINAL_ANSWER_PROMPT),
            ]
            self._intermediate_outputs["chain_of_thought"] = previous_output
            self._state = "awaiting_final"
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=True)

        if self._state == "awaiting_final":
            if previous_output is None:
                raise ValueError("previous_output required for final step")
            self._messages.append(ChatMessage("assistant", previous_output))
            self._state = "done"
            return None

        return None


class TwoMessageCotFormatter(ChatFormatter):
    def __init__(
        self,
        system_message: str,
        user_input: str | Dict,
        thinking_instructions: str | None,
    ) -> None:
        super().__init__(system_message, user_input, thinking_instructions)
        if self.thinking_instructions is None:
            raise ValueError(
                "thinking_instructions are required when strategy is final_and_intermediate"
            )

    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        if self._state == "start":
            # User message combines the input and the thinking instructions
            formatted_user_message = format_user_message(self.user_input)
            user_message = f"The input is:\n<user_input>\n{formatted_user_message}\n</user_input>\n\n{self.thinking_instructions}"

            msgs = [
                ChatMessage("system", self.system_message),
                ChatMessage("user", user_message),
            ]
            self._state = "awaiting_thinking"
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=False)

        if self._state == "awaiting_thinking":
            if previous_output is None:
                raise ValueError("previous_output required for thinking step")
            msgs = [
                ChatMessage("assistant", previous_output),
                ChatMessage("user", COT_FINAL_ANSWER_PROMPT),
            ]
            self._intermediate_outputs["chain_of_thought"] = previous_output
            self._state = "awaiting_final"
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=True)

        if self._state == "awaiting_final":
            if previous_output is None:
                raise ValueError("previous_output required for final step")
            self._messages.append(ChatMessage("assistant", previous_output))
            self._state = "done"
            return None

        return None


class SingleTurnR1ThinkingFormatter(ChatFormatter):
    def next_turn(self, previous_output: str | None = None) -> Optional[ChatTurn]:
        if self._state == "start":
            msgs = [
                ChatMessage("system", self.system_message),
                ChatMessage("user", format_user_message(self.user_input)),
            ]
            self._state = "awaiting_final"
            self._messages.extend(msgs)
            return ChatTurn(messages=msgs, final_call=True)

        if self._state == "awaiting_final":
            if previous_output is None:
                raise ValueError("previous_output required for final step")
            self._messages.append(ChatMessage("assistant", previous_output))
            self._state = "done"
            return None

        return None


def get_chat_formatter(
    strategy: ChatStrategy,
    system_message: str,
    user_input: str | Dict,
    thinking_instructions: str | None = None,
) -> ChatFormatter:
    match strategy:
        case ChatStrategy.single_turn:
            return SingleTurnFormatter(system_message, user_input)
        case ChatStrategy.two_message_cot_legacy:
            return TwoMessageCotLegacyFormatter(
                system_message, user_input, thinking_instructions
            )
        case ChatStrategy.two_message_cot:
            return TwoMessageCotFormatter(
                system_message, user_input, thinking_instructions
            )
        case ChatStrategy.single_turn_r1_thinking:
            return SingleTurnR1ThinkingFormatter(system_message, user_input)
        case _:
            raise_exhaustive_enum_error(strategy)


def format_user_message(input: Dict | str) -> str:
    """Build a user message from the input.

    Args:
        input (Union[Dict, str]): The input to format into a message.

    Returns:
        str: The formatted user message.
    """
    if isinstance(input, dict):
        return json.dumps(input, ensure_ascii=False)

    return input
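Because each formatter is a small state machine, a caller drives it by feeding the model's previous reply back into next_turn() until it returns None. A minimal, hypothetical driver loop (fake_model is a stand-in, not part of the package):

from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter


def fake_model(messages: list[dict]) -> str:
    # Stand-in for a real model call; a real adapter would send `messages` to an LLM.
    return "model reply"


formatter = get_chat_formatter(
    strategy=ChatStrategy.two_message_cot,
    system_message="You are a helpful assistant.",
    user_input={"question": "What is 2 + 2?"},
    thinking_instructions="Think step by step before answering.",
)

output = None
turn = formatter.next_turn()
while turn is not None:
    # message_dicts() always reflects the full conversation so far
    output = fake_model(formatter.message_dicts())
    turn = formatter.next_turn(output)

print(output)                            # reply to the final (final_call=True) turn
print(formatter.intermediate_outputs())  # {"chain_of_thought": "..."} for COT strategies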
kiln_ai/adapters/chat/test_chat_formatter.py
ADDED (+131)
from kiln_ai.adapters.chat import ChatStrategy, get_chat_formatter
from kiln_ai.adapters.chat.chat_formatter import (
    COT_FINAL_ANSWER_PROMPT,
    format_user_message,
)


def test_chat_formatter_final_only():
    expected = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "test input"},
        {"role": "assistant", "content": "test output"},
    ]

    formatter = get_chat_formatter(
        strategy=ChatStrategy.single_turn,
        system_message="system message",
        user_input="test input",
    )

    first = formatter.next_turn()
    assert [m.__dict__ for m in first.messages] == expected[:2]
    assert first.final_call
    assert formatter.intermediate_outputs() == {}

    assert formatter.next_turn("test output") is None
    assert formatter.message_dicts() == expected
    assert formatter.intermediate_outputs() == {}


def test_chat_formatter_final_and_intermediate():
    expected = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "test input"},
        {"role": "system", "content": "thinking instructions"},
        {"role": "assistant", "content": "thinking output"},
        {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
        {"role": "assistant", "content": "test output"},
    ]

    formatter = get_chat_formatter(
        strategy=ChatStrategy.two_message_cot_legacy,
        system_message="system message",
        user_input="test input",
        thinking_instructions="thinking instructions",
    )

    first = formatter.next_turn()
    assert [m.__dict__ for m in first.messages] == expected[:3]
    assert not first.final_call
    assert formatter.intermediate_outputs() == {}

    second = formatter.next_turn("thinking output")
    assert [m.__dict__ for m in second.messages] == expected[3:5]
    assert second.final_call
    assert formatter.intermediate_outputs() == {"chain_of_thought": "thinking output"}

    assert formatter.next_turn("test output") is None
    assert formatter.message_dicts() == expected
    assert formatter.intermediate_outputs() == {"chain_of_thought": "thinking output"}


def test_chat_formatter_two_message_cot():
    user_message = "The input is:\n<user_input>\ntest input\n</user_input>\n\nthinking instructions"
    expected = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": "thinking output"},
        {"role": "user", "content": COT_FINAL_ANSWER_PROMPT},
        {"role": "assistant", "content": "test output"},
    ]

    formatter = get_chat_formatter(
        strategy=ChatStrategy.two_message_cot,
        system_message="system message",
        user_input="test input",
        thinking_instructions="thinking instructions",
    )

    first = formatter.next_turn()
    assert [m.__dict__ for m in first.messages] == expected[:2]
    assert not first.final_call
    assert formatter.intermediate_outputs() == {}

    second = formatter.next_turn("thinking output")
    assert [m.__dict__ for m in second.messages] == expected[2:4]
    assert second.final_call
    assert formatter.intermediate_outputs() == {"chain_of_thought": "thinking output"}

    assert formatter.next_turn("test output") is None
    assert formatter.message_dicts() == expected
    assert formatter.intermediate_outputs() == {"chain_of_thought": "thinking output"}


def test_chat_formatter_r1_style():
    thinking_output = "<think>thinking</think> answer"
    expected = [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "test input"},
        {"role": "assistant", "content": thinking_output},
    ]

    formatter = get_chat_formatter(
        strategy=ChatStrategy.single_turn_r1_thinking,
        system_message="system message",
        user_input="test input",
    )

    first = formatter.next_turn()
    assert [m.__dict__ for m in first.messages] == expected[:2]
    assert first.final_call

    assert formatter.next_turn(thinking_output) is None
    assert formatter.message_dicts() == expected
    assert formatter.intermediate_outputs() == {}


def test_format_user_message():
    # String
    assert format_user_message("test input") == "test input"
    # JSON, preserving order
    assert (
        format_user_message({"test": "input", "a": "b"})
        == '{"test": "input", "a": "b"}'
    )


def test_simple_prompt_builder_structured_input_non_ascii():
    input = {"key": "你好👋"}
    user_msg = format_user_message(input)
    assert "你好👋" in user_msg