kiln-ai 0.11.1__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff shows the differences between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of kiln-ai might be problematic.
- kiln_ai/adapters/adapter_registry.py +12 -13
- kiln_ai/adapters/data_gen/data_gen_task.py +18 -0
- kiln_ai/adapters/eval/base_eval.py +164 -0
- kiln_ai/adapters/eval/eval_runner.py +267 -0
- kiln_ai/adapters/eval/g_eval.py +367 -0
- kiln_ai/adapters/eval/registry.py +16 -0
- kiln_ai/adapters/eval/test_base_eval.py +324 -0
- kiln_ai/adapters/eval/test_eval_runner.py +640 -0
- kiln_ai/adapters/eval/test_g_eval.py +497 -0
- kiln_ai/adapters/eval/test_g_eval_data.py +4 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +4 -1
- kiln_ai/adapters/fine_tune/test_fireworks_finetune.py +1 -1
- kiln_ai/adapters/fine_tune/test_openai_finetune.py +1 -1
- kiln_ai/adapters/ml_model_list.py +141 -29
- kiln_ai/adapters/model_adapters/base_adapter.py +50 -35
- kiln_ai/adapters/model_adapters/langchain_adapters.py +27 -20
- kiln_ai/adapters/model_adapters/openai_compatible_config.py +0 -1
- kiln_ai/adapters/model_adapters/openai_model_adapter.py +93 -50
- kiln_ai/adapters/model_adapters/test_base_adapter.py +22 -13
- kiln_ai/adapters/model_adapters/test_langchain_adapter.py +7 -14
- kiln_ai/adapters/model_adapters/test_openai_model_adapter.py +55 -64
- kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +41 -19
- kiln_ai/adapters/model_adapters/test_structured_output.py +36 -30
- kiln_ai/adapters/ollama_tools.py +0 -1
- kiln_ai/adapters/prompt_builders.py +80 -42
- kiln_ai/adapters/repair/repair_task.py +9 -21
- kiln_ai/adapters/repair/test_repair_task.py +3 -3
- kiln_ai/adapters/run_output.py +3 -0
- kiln_ai/adapters/test_adapter_registry.py +10 -10
- kiln_ai/adapters/test_generate_docs.py +6 -6
- kiln_ai/adapters/test_ollama_tools.py +0 -1
- kiln_ai/adapters/test_prompt_adaptors.py +17 -14
- kiln_ai/adapters/test_prompt_builders.py +91 -31
- kiln_ai/datamodel/__init__.py +50 -952
- kiln_ai/datamodel/datamodel_enums.py +58 -0
- kiln_ai/datamodel/dataset_filters.py +114 -0
- kiln_ai/datamodel/dataset_split.py +170 -0
- kiln_ai/datamodel/eval.py +298 -0
- kiln_ai/datamodel/finetune.py +105 -0
- kiln_ai/datamodel/json_schema.py +6 -0
- kiln_ai/datamodel/project.py +23 -0
- kiln_ai/datamodel/prompt.py +37 -0
- kiln_ai/datamodel/prompt_id.py +83 -0
- kiln_ai/datamodel/strict_mode.py +24 -0
- kiln_ai/datamodel/task.py +181 -0
- kiln_ai/datamodel/task_output.py +321 -0
- kiln_ai/datamodel/task_run.py +164 -0
- kiln_ai/datamodel/test_basemodel.py +10 -11
- kiln_ai/datamodel/test_dataset_filters.py +71 -0
- kiln_ai/datamodel/test_dataset_split.py +32 -8
- kiln_ai/datamodel/test_datasource.py +3 -2
- kiln_ai/datamodel/test_eval_model.py +635 -0
- kiln_ai/datamodel/test_example_models.py +9 -13
- kiln_ai/datamodel/test_json_schema.py +23 -0
- kiln_ai/datamodel/test_models.py +2 -2
- kiln_ai/datamodel/test_prompt_id.py +129 -0
- kiln_ai/datamodel/test_task.py +159 -0
- kiln_ai/utils/config.py +6 -1
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/METADATA +37 -1
- kiln_ai-0.12.0.dist-info/RECORD +100 -0
- kiln_ai-0.11.1.dist-info/RECORD +0 -76
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.11.1.dist-info → kiln_ai-0.12.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/adapters/ml_model_list.py

@@ -43,6 +43,8 @@ class ModelFamily(str, Enum):
     mixtral = "mixtral"
     qwen = "qwen"
     deepseek = "deepseek"
+    dolphin = "dolphin"
+    grok = "grok"
 
 
 # Where models have instruct and raw versions, instruct is default and raw is specified
@@ -71,6 +73,8 @@ class ModelName(str, Enum):
     gemma_2_27b = "gemma_2_27b"
     claude_3_5_haiku = "claude_3_5_haiku"
     claude_3_5_sonnet = "claude_3_5_sonnet"
+    claude_3_7_sonnet = "claude_3_7_sonnet"
+    claude_3_7_sonnet_thinking = "claude_3_7_sonnet_thinking"
     gemini_1_5_flash = "gemini_1_5_flash"
     gemini_1_5_flash_8b = "gemini_1_5_flash_8b"
     gemini_1_5_pro = "gemini_1_5_pro"
@@ -88,6 +92,8 @@ class ModelName(str, Enum):
     deepseek_r1_distill_qwen_1p5b = "deepseek_r1_distill_qwen_1p5b"
     deepseek_r1_distill_qwen_7b = "deepseek_r1_distill_qwen_7b"
     deepseek_r1_distill_llama_8b = "deepseek_r1_distill_llama_8b"
+    dolphin_2_9_8x22b = "dolphin_2_9_8x22b"
+    grok_2 = "grok_2"
 
 
 class ModelParserID(str, Enum):
@@ -123,6 +129,15 @@ class KilnModelProvider(BaseModel):
     structured_output_mode: StructuredOutputMode = StructuredOutputMode.default
     parser: ModelParserID | None = None
     reasoning_capable: bool = False
+    supports_logprobs: bool = False
+
+    # TODO P1: Need a more generalized way to handle custom provider parameters.
+    # Making them quite declarative here for now, isolating provider specific logic
+    # to this file. Later I should be able to override anything in this file via config.
+    r1_openrouter_options: bool = False
+    require_openrouter_reasoning: bool = False
+    logprobs_openrouter_options: bool = False
+    openrouter_skip_required_parameters: bool = False
 
 
 class KilnModel(BaseModel):
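These new fields are plain booleans on KilnModelProvider, so each built-in model entry opts into provider-specific handling declaratively. A minimal illustrative fragment, written in the style of the built_in_models entries shown in the hunks below; the model slug is made up, and the inline comments are one reading of the flag names rather than documentation from this release:

```python
# Hypothetical provider entry (not from this release), combining the new flags.
KilnModelProvider(
    name=ModelProviderName.openrouter,
    provider_options={"model": "example-org/example-reasoning-model"},  # made-up slug
    structured_output_mode=StructuredOutputMode.json_instructions,
    reasoning_capable=True,
    supports_logprobs=True,             # adapter may request logprobs from this provider
    logprobs_openrouter_options=True,   # send OpenRouter's logprobs request options
    r1_openrouter_options=True,         # apply R1-specific OpenRouter parameters
    require_openrouter_reasoning=True,  # ask OpenRouter to return the reasoning stream
)
```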
@@ -155,11 +170,14 @@ built_in_models: List[KilnModel] = [
                 provider_options={"model": "gpt-4o-mini"},
                 provider_finetune_id="gpt-4o-mini-2024-07-18",
                 structured_output_mode=StructuredOutputMode.json_schema,
+                supports_logprobs=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "openai/gpt-4o-mini"},
                 structured_output_mode=StructuredOutputMode.json_schema,
+                supports_logprobs=True,
+                logprobs_openrouter_options=True,
             ),
         ],
     ),
@@ -174,11 +192,14 @@ built_in_models: List[KilnModel] = [
                 provider_options={"model": "gpt-4o"},
                 provider_finetune_id="gpt-4o-2024-08-06",
                 structured_output_mode=StructuredOutputMode.json_schema,
+                supports_logprobs=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 provider_options={"model": "openai/gpt-4o"},
                 structured_output_mode=StructuredOutputMode.json_schema,
+                supports_logprobs=True,
+                logprobs_openrouter_options=True,
             ),
         ],
     ),
@@ -190,7 +211,7 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                structured_output_mode=StructuredOutputMode.
+                structured_output_mode=StructuredOutputMode.json_instruction_and_object,
                 provider_options={"model": "anthropic/claude-3-5-haiku"},
             ),
         ],
@@ -203,51 +224,37 @@ built_in_models: List[KilnModel] = [
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                structured_output_mode=StructuredOutputMode.
+                structured_output_mode=StructuredOutputMode.json_instruction_and_object,
                 provider_options={"model": "anthropic/claude-3.5-sonnet"},
             ),
         ],
     ),
-    #
+    # Claude 3.7 Sonnet
     KilnModel(
-        family=ModelFamily.
-        name=ModelName.
-        friendly_name="
+        family=ModelFamily.claude,
+        name=ModelName.claude_3_7_sonnet,
+        friendly_name="Claude 3.7 Sonnet",
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                provider_options={"model": "deepseek/deepseek-chat"},
                 structured_output_mode=StructuredOutputMode.function_calling,
+                provider_options={"model": "anthropic/claude-3.7-sonnet"},
             ),
         ],
     ),
-    #
+    # Claude 3.7 Sonnet Thinking
    KilnModel(
-        family=ModelFamily.
-        name=ModelName.
-        friendly_name="
+        family=ModelFamily.claude,
+        name=ModelName.claude_3_7_sonnet_thinking,
+        friendly_name="Claude 3.7 Sonnet Thinking",
         providers=[
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
-                provider_options={"model": "
-                # No custom parser -- openrouter implemented it themselves
-                structured_output_mode=StructuredOutputMode.json_instructions,
+                provider_options={"model": "anthropic/claude-3.7-sonnet:thinking"},
                 reasoning_capable=True,
-
-            KilnModelProvider(
-                name=ModelProviderName.fireworks_ai,
-                provider_options={"model": "accounts/fireworks/models/deepseek-r1"},
-                parser=ModelParserID.r1_thinking,
-                structured_output_mode=StructuredOutputMode.json_instructions,
-                reasoning_capable=True,
-            ),
-            KilnModelProvider(
-                # I want your RAM
-                name=ModelProviderName.ollama,
-                provider_options={"model": "deepseek-r1:671b"},
-                parser=ModelParserID.r1_thinking,
+                # For reasoning models, we need to use json_instructions with OpenRouter
                 structured_output_mode=StructuredOutputMode.json_instructions,
-
+                require_openrouter_reasoning=True,
             ),
         ],
     ),
@@ -379,8 +386,11 @@ built_in_models: List[KilnModel] = [
             KilnModelProvider(
                 name=ModelProviderName.openrouter,
                 supports_data_gen=False,
-
+                # Need to not pass "strict=True" to the function call to get this to work with logprobs for some reason. Openrouter issue.
+                structured_output_mode=StructuredOutputMode.function_calling_weak,
                 provider_options={"model": "meta-llama/llama-3.1-70b-instruct"},
+                supports_logprobs=True,
+                logprobs_openrouter_options=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -819,6 +829,58 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # DeepSeek 3
+    KilnModel(
+        family=ModelFamily.deepseek,
+        name=ModelName.deepseek_3,
+        friendly_name="DeepSeek V3",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "deepseek/deepseek-chat"},
+                structured_output_mode=StructuredOutputMode.function_calling,
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={"model": "accounts/fireworks/models/deepseek-v3"},
+                structured_output_mode=StructuredOutputMode.json_mode,
+                supports_structured_output=True,
+                supports_data_gen=False,
+            ),
+        ],
+    ),
+    # DeepSeek R1
+    KilnModel(
+        family=ModelFamily.deepseek,
+        name=ModelName.deepseek_r1,
+        friendly_name="DeepSeek R1",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "deepseek/deepseek-r1"},
+                # No custom parser -- openrouter implemented it themselves
+                structured_output_mode=StructuredOutputMode.json_instructions,
+                reasoning_capable=True,
+                r1_openrouter_options=True,
+                require_openrouter_reasoning=True,
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.fireworks_ai,
+                provider_options={"model": "accounts/fireworks/models/deepseek-r1"},
+                parser=ModelParserID.r1_thinking,
+                structured_output_mode=StructuredOutputMode.json_instructions,
+                reasoning_capable=True,
+            ),
+            KilnModelProvider(
+                # I want your RAM
+                name=ModelProviderName.ollama,
+                provider_options={"model": "deepseek-r1:671b"},
+                parser=ModelParserID.r1_thinking,
+                structured_output_mode=StructuredOutputMode.json_instructions,
+                reasoning_capable=True,
+            ),
+        ],
+    ),
     # DeepSeek R1 Distill Qwen 32B
     KilnModel(
         family=ModelFamily.deepseek,
@@ -830,6 +892,8 @@ built_in_models: List[KilnModel] = [
                 reasoning_capable=True,
                 structured_output_mode=StructuredOutputMode.json_instructions,
                 provider_options={"model": "deepseek/deepseek-r1-distill-qwen-32b"},
+                r1_openrouter_options=True,
+                require_openrouter_reasoning=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -851,6 +915,8 @@ built_in_models: List[KilnModel] = [
                 reasoning_capable=True,
                 structured_output_mode=StructuredOutputMode.json_instructions,
                 provider_options={"model": "deepseek/deepseek-r1-distill-llama-70b"},
+                r1_openrouter_options=True,
+                require_openrouter_reasoning=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -874,6 +940,9 @@ built_in_models: List[KilnModel] = [
                 reasoning_capable=True,
                 structured_output_mode=StructuredOutputMode.json_instructions,
                 provider_options={"model": "deepseek/deepseek-r1-distill-qwen-14b"},
+                r1_openrouter_options=True,
+                require_openrouter_reasoning=True,
+                openrouter_skip_required_parameters=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -897,6 +966,9 @@ built_in_models: List[KilnModel] = [
                 reasoning_capable=True,
                 structured_output_mode=StructuredOutputMode.json_instructions,
                 provider_options={"model": "deepseek/deepseek-r1-distill-llama-8b"},
+                r1_openrouter_options=True,
+                require_openrouter_reasoning=True,
+                openrouter_skip_required_parameters=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -937,6 +1009,9 @@ built_in_models: List[KilnModel] = [
                 reasoning_capable=True,
                 structured_output_mode=StructuredOutputMode.json_instructions,
                 provider_options={"model": "deepseek/deepseek-r1-distill-qwen-1.5b"},
+                r1_openrouter_options=True,
+                require_openrouter_reasoning=True,
+                openrouter_skip_required_parameters=True,
             ),
             KilnModelProvider(
                 name=ModelProviderName.ollama,
@@ -948,4 +1023,41 @@ built_in_models: List[KilnModel] = [
             ),
         ],
     ),
+    # Dolphin 2.9 Mixtral 8x22B
+    KilnModel(
+        family=ModelFamily.dolphin,
+        name=ModelName.dolphin_2_9_8x22b,
+        friendly_name="Dolphin 2.9 8x22B",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.ollama,
+                structured_output_mode=StructuredOutputMode.json_schema,
+                supports_data_gen=True,
+                provider_options={"model": "dolphin-mixtral:8x22b"},
+            ),
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={
+                    "model": "cognitivecomputations/dolphin-mixtral-8x22b"
+                },
+                supports_data_gen=True,
+                structured_output_mode=StructuredOutputMode.json_instruction_and_object,
+            ),
+        ],
+    ),
+    # Grok 2
+    KilnModel(
+        family=ModelFamily.grok,
+        name=ModelName.grok_2,
+        friendly_name="Grok 2",
+        providers=[
+            KilnModelProvider(
+                name=ModelProviderName.openrouter,
+                provider_options={"model": "x-ai/grok-2-1212"},
+                supports_structured_output=True,
+                supports_data_gen=True,
+                structured_output_mode=StructuredOutputMode.json_schema,
+            ),
+        ],
+    ),
 ]
kiln_ai/adapters/model_adapters/base_adapter.py

@@ -5,7 +5,7 @@ from typing import Dict, Literal, Tuple
 
 from kiln_ai.adapters.ml_model_list import KilnModelProvider, StructuredOutputMode
 from kiln_ai.adapters.parsers.parser_registry import model_parser_from_id
-from kiln_ai.adapters.prompt_builders import
+from kiln_ai.adapters.prompt_builders import prompt_builder_from_id
 from kiln_ai.adapters.provider_tools import kiln_model_provider_from
 from kiln_ai.adapters.run_output import RunOutput
 from kiln_ai.datamodel import (
@@ -16,16 +16,21 @@ from kiln_ai.datamodel import (
     TaskRun,
 )
 from kiln_ai.datamodel.json_schema import validate_schema
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
 
 
 @dataclass
-class
-
-
-
-
-
+class AdapterConfig:
+    """
+    An adapter config is config options that do NOT impact the output of the model.
+
+    For example: if it's saved, of if we request additional data like logprobs.
+    """
+
+    allow_saving: bool = True
+    top_logprobs: int | None = None
+    default_tags: list[str] | None = None
 
 
 COT_FINAL_ANSWER_PROMPT = "Considering the above, return a final result."
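A minimal sketch of constructing the new AdapterConfig, using only the fields shown above; the values are illustrative, and the import path assumes the module shown in this section:

```python
from kiln_ai.adapters.model_adapters.base_adapter import AdapterConfig

# Illustrative values: keep these runs out of the saved dataset and request
# the top 5 logprobs per token (honored by OpenAI-compatible adapters only).
config = AdapterConfig(
    allow_saving=False,
    top_logprobs=5,
    default_tags=["eval"],
)
```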
@@ -47,35 +52,36 @@ class BaseAdapter(metaclass=ABCMeta):
 
     def __init__(
         self,
-
-
-        model_provider_name: str,
-        prompt_builder: BasePromptBuilder | None = None,
-        tags: list[str] | None = None,
+        run_config: RunConfig,
+        config: AdapterConfig | None = None,
     ):
-        self.
-        self.
-
-
-        self.default_tags = tags
-        self.model_name = model_name
-        self.model_provider_name = model_provider_name
+        self.run_config = run_config
+        self.prompt_builder = prompt_builder_from_id(
+            run_config.prompt_id, run_config.task
+        )
         self._model_provider: KilnModelProvider | None = None
 
+        self.output_schema = self.task().output_json_schema
+        self.input_schema = self.task().input_json_schema
+        self.base_adapter_config = config or AdapterConfig()
+
+    def task(self) -> Task:
+        return self.run_config.task
+
     def model_provider(self) -> KilnModelProvider:
         """
         Lazy load the model provider for this adapter.
         """
         if self._model_provider is not None:
             return self._model_provider
-        if not self.model_name or not self.model_provider_name:
+        if not self.run_config.model_name or not self.run_config.model_provider_name:
             raise ValueError("model_name and model_provider_name must be provided")
         self._model_provider = kiln_model_provider_from(
-            self.model_name, self.model_provider_name
+            self.run_config.model_name, self.run_config.model_provider_name
         )
         if not self._model_provider:
             raise ValueError(
-                f"model_provider_name {self.model_provider_name} not found for model {self.model_name}"
+                f"model_provider_name {self.run_config.model_provider_name} not found for model {self.run_config.model_name}"
             )
         return self._model_provider
 
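For orientation, a sketch of building the RunConfig that adapters now consume. Only the four fields referenced in the hunk above are shown; the task object, the model/provider strings, and the commented-out adapter class are placeholders, and the `from kiln_ai import datamodel` import mirrors how the LangChain adapter (further below) refers to PromptGenerators:

```python
from kiln_ai import datamodel
from kiln_ai.datamodel.task import RunConfig

# `my_task` is a hypothetical kiln_ai.datamodel.Task loaded elsewhere.
run_config = RunConfig(
    task=my_task,
    model_name="gpt_4o_mini",      # placeholder model name
    model_provider_name="openai",  # placeholder provider name
    prompt_id=datamodel.PromptGenerators.SIMPLE,
)
# A concrete BaseAdapter subclass would then be built from it, e.g.:
# adapter = SomeAdapter(run_config=run_config, config=AdapterConfig())
```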
@@ -85,7 +91,7 @@ class BaseAdapter(metaclass=ABCMeta):
         input_source: DataSource | None = None,
     ) -> Dict | str:
         result = await self.invoke(input, input_source)
-        if self.
+        if self.task().output_json_schema is None:
             return result.output.output
         else:
             return json.loads(result.output.output)
@@ -95,6 +101,14 @@ class BaseAdapter(metaclass=ABCMeta):
         input: Dict | str,
         input_source: DataSource | None = None,
     ) -> TaskRun:
+        run_output, _ = await self.invoke_returning_run_output(input, input_source)
+        return run_output
+
+    async def invoke_returning_run_output(
+        self,
+        input: Dict | str,
+        input_source: DataSource | None = None,
+    ) -> Tuple[TaskRun, RunOutput]:
         # validate input
         if self.input_schema is not None:
             if not isinstance(input, dict):
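A short usage sketch of the new method, to be run inside an async context; `adapter` stands for an already-constructed concrete adapter, and the input dict is made up:

```python
# Existing behaviour: invoke() still returns the TaskRun.
task_run = await adapter.invoke({"topic": "pirates"})

# Added in this release: also get the raw RunOutput (e.g. intermediate
# outputs) alongside the TaskRun.
task_run, run_output = await adapter.invoke_returning_run_output({"topic": "pirates"})
print(task_run.output.output)
print(run_output.intermediate_outputs)
```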
@@ -128,19 +142,23 @@ class BaseAdapter(metaclass=ABCMeta):
         run = self.generate_run(input, input_source, parsed_output)
 
         # Save the run if configured to do so, and we have a path to save to
-        if
+        if (
+            self.base_adapter_config.allow_saving
+            and Config.shared().autosave_runs
+            and self.task().path is not None
+        ):
             run.save_to_file()
         else:
             # Clear the ID to indicate it's not persisted
             run.id = None
 
-        return run
+        return run, run_output
 
     def has_structured_output(self) -> bool:
         return self.output_schema is not None
 
     @abstractmethod
-    def
+    def adapter_name(self) -> str:
         pass
 
     @abstractmethod
@@ -203,7 +221,7 @@ class BaseAdapter(metaclass=ABCMeta):
         )
 
         new_task_run = TaskRun(
-            parent=self.
+            parent=self.task(),
             input=input_str,
             input_source=input_source,
             output=TaskOutput(
@@ -215,7 +233,7 @@ class BaseAdapter(metaclass=ABCMeta):
                 ),
             ),
             intermediate_outputs=run_output.intermediate_outputs,
-            tags=self.default_tags or [],
+            tags=self.base_adapter_config.default_tags or [],
         )
 
         return new_task_run
@@ -224,12 +242,9 @@ class BaseAdapter(metaclass=ABCMeta):
         props = {}
 
         # adapter info
-
-        props["
-        props["
-        props["
-        props["prompt_builder_name"] = adapter_info.prompt_builder_name
-        if adapter_info.prompt_id is not None:
-            props["prompt_id"] = adapter_info.prompt_id
+        props["adapter_name"] = self.adapter_name()
+        props["model_name"] = self.run_config.model_name
+        props["model_provider"] = self.run_config.model_provider_name
+        props["prompt_id"] = self.run_config.prompt_id
 
         return props
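With this change the saved run's source properties come straight from the run config, and prompt_builder_name is no longer recorded while prompt_id always is. Roughly the resulting shape, with made-up values (only the adapter_name string is confirmed elsewhere in this diff):

```python
# Illustrative properties dict produced by the code above.
props = {
    "adapter_name": "kiln_langchain_adapter",
    "model_name": "gpt_4o_mini",            # placeholder
    "model_provider": "openai",             # placeholder
    "prompt_id": "simple_prompt_builder",   # placeholder prompt id value
}
```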
kiln_ai/adapters/model_adapters/langchain_adapters.py

@@ -20,9 +20,8 @@ from kiln_ai.adapters.ml_model_list import (
 )
 from kiln_ai.adapters.model_adapters.base_adapter import (
     COT_FINAL_ANSWER_PROMPT,
-
+    AdapterConfig,
     BaseAdapter,
-    BasePromptBuilder,
     RunOutput,
 )
 from kiln_ai.adapters.ollama_tools import (
@@ -30,6 +29,8 @@ from kiln_ai.adapters.ollama_tools import (
     ollama_base_url,
     ollama_model_installed,
 )
+from kiln_ai.datamodel import PromptId
+from kiln_ai.datamodel.task import RunConfig
 from kiln_ai.utils.config import Config
 from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
 
@@ -45,8 +46,8 @@ class LangchainAdapter(BaseAdapter):
         custom_model: BaseChatModel | None = None,
         model_name: str | None = None,
         provider: str | None = None,
-
-
+        prompt_id: PromptId | None = None,
+        base_adapter_config: AdapterConfig | None = None,
     ):
         if custom_model is not None:
             self._model = custom_model
@@ -78,12 +79,16 @@ class LangchainAdapter(BaseAdapter):
         if model_name is None:
             raise ValueError("model_name must be provided")
 
-
-            kiln_task,
+        run_config = RunConfig(
+            task=kiln_task,
             model_name=model_name,
             model_provider_name=provider,
-
-
+            prompt_id=prompt_id or datamodel.PromptGenerators.SIMPLE,
+        )
+
+        super().__init__(
+            run_config=run_config,
+            config=base_adapter_config,
         )
 
     async def model(self) -> LangChainModelType:
@@ -111,15 +116,15 @@ class LangchainAdapter(BaseAdapter):
                 f"model {self._model} does not support structured output, cannot use output_json_schema"
             )
         # Langchain expects title/description to be at top level, on top of json schema
-        output_schema = self.
+        output_schema = self.task().output_schema()
         if output_schema is None:
             raise ValueError(
-                f"output_json_schema is not valid json: {self.
+                f"output_json_schema is not valid json: {self.task().output_json_schema}"
             )
         output_schema["title"] = "task_response"
         output_schema["description"] = "A response from the task"
         with_structured_output_options = self.get_structured_output_options(
-            self.model_name, self.model_provider_name
+            self.run_config.model_name, self.run_config.model_provider_name
         )
         self._model = self._model.with_structured_output(
             output_schema,
@@ -129,6 +134,11 @@ class LangchainAdapter(BaseAdapter):
         return self._model
 
     async def _run(self, input: Dict | str) -> RunOutput:
+        if self.base_adapter_config.top_logprobs is not None:
+            raise ValueError(
+                "Kiln's Langchain adapter does not support logprobs/top_logprobs. Select a model from an OpenAI compatible provider (openai, openrouter, etc) instead."
+            )
+
         provider = self.model_provider()
         model = await self.model()
         chain = model
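Illustrative only: with the guard above, a LangChain-backed run configured to collect logprobs fails fast. The keyword arguments mirror the constructor signature shown earlier in this file; the task value is a placeholder, and passing it as the first positional argument is an assumption not confirmed by this diff:

```python
# Hypothetical setup; `my_task` is a kiln_ai.datamodel.Task.
adapter = LangchainAdapter(
    my_task,
    model_name="llama_3_1_8b",  # placeholder model name
    provider="ollama",          # placeholder provider
    base_adapter_config=AdapterConfig(top_logprobs=5),
)
# adapter._run(...) now raises ValueError, directing callers to an
# OpenAI-compatible provider (openai, openrouter, ...) for logprobs.
```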
@@ -191,14 +201,8 @@ class LangchainAdapter(BaseAdapter):
             intermediate_outputs=intermediate_outputs,
         )
 
-    def
-        return
-            model_name=self.model_name,
-            model_provider=self.model_provider_name,
-            adapter_name="kiln_langchain_adapter",
-            prompt_builder_name=self.prompt_builder.__class__.prompt_builder_name(),
-            prompt_id=self.prompt_builder.prompt_id(),
-        )
+    def adapter_name(self) -> str:
+        return "kiln_langchain_adapter"
 
     def _munge_response(self, response: Dict) -> Dict:
         # Mistral Large tool calling format is a bit different. Convert to standard format.
@@ -220,6 +224,9 @@ class LangchainAdapter(BaseAdapter):
         options = {}
         # We may need to add some provider specific logic here if providers use different names for the same mode, but everyone is copying openai for now
         match provider.structured_output_mode:
+            case StructuredOutputMode.function_calling_weak:
+                # Langchaing doesn't handle weak/strict separately
+                options["method"] = "function_calling"
             case StructuredOutputMode.function_calling:
                 options["method"] = "function_calling"
             case StructuredOutputMode.json_mode:
@@ -246,7 +253,7 @@ class LangchainAdapter(BaseAdapter):
 
     async def langchain_model_from(self) -> BaseChatModel:
         provider = self.model_provider()
-        return await langchain_model_from_provider(provider, self.model_name)
+        return await langchain_model_from_provider(provider, self.run_config.model_name)
 
 
 async def langchain_model_from_provider(