kiln-ai 0.7.1.tar.gz → 0.8.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic.
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/PKG-INFO +1 -1
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/adapter_registry.py +2 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/base_adapter.py +6 -1
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/langchain_adapters.py +5 -1
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/ml_model_list.py +9 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/provider_tools.py +48 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/test_provider_tools.py +95 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/__init__.py +113 -14
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/basemodel.py +3 -9
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_dataset_split.py +1 -1
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_models.py +49 -0
- kiln_ai-0.8.1/kiln_ai/datamodel/test_output_rating.py +456 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/utils/config.py +28 -9
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/utils/test_config.py +48 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/pyproject.toml +1 -1
- kiln_ai-0.7.1/kiln_ai/datamodel/test_output_rating.py +0 -89
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/.gitignore +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/.python-version +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/LICENSE.txt +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/README.md +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/index.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/adapters/base_adapter.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/adapters/langchain_adapters.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/adapters/ml_model_list.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/adapters/prompt_builders.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/adapters/repair/repair_task.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/adapters/repair.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/adapters.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/datamodel/basemodel.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/datamodel/json_schema.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/datamodel.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/utils/config.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/utils/formatting.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai/utils.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/kiln_ai.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/core_library_docs/search.js +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/index.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/base_adapter.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/data_gen/data_gen_task.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/data_gen.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/fine_tune/base_finetune.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/fine_tune/dataset_formatter.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/fine_tune/dataset_split.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/fine_tune/finetune_registry.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/fine_tune/openai_finetune.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/fine_tune.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/langchain_adapters.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/ml_model_list.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/prompt_builders.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/repair/repair_task.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters/repair.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/adapters.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/datamodel/basemodel.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/datamodel/json_schema.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/datamodel.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/utils/config.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/utils/formatting.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai/utils.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/kiln_ai.html +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/docs/kiln_core_docs/search.js +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/__init__.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/__init__.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/data_gen/__init__.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/data_gen/data_gen_prompts.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/data_gen/data_gen_task.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/data_gen/test_data_gen_task.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/__init__.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/base_finetune.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/dataset_formatter.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/finetune_registry.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/fireworks_finetune.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/openai_finetune.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/test_base_finetune.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/test_dataset_formatter.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/fine_tune/test_openai_finetune.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/ollama_tools.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/prompt_builders.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/repair/__init__.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/repair/repair_task.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/repair/test_repair_task.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/test_langchain_adapter.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/test_ollama_tools.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/test_prompt_adaptors.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/test_prompt_builders.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/test_saving_adapter_results.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/adapters/test_structured_output.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/json_schema.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/model_cache.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/registry.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_basemodel.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_datasource.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_example_models.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_json_schema.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_model_cache.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_nested_save.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/datamodel/test_registry.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/utils/__init__.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/utils/formatting.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/utils/name_generator.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/kiln_ai/utils/test_name_geneator.py +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/setup.cfg +0 -0
- {kiln_ai-0.7.1 → kiln_ai-0.8.1}/uv.lock +0 -0
```diff
--- kiln_ai-0.7.1/kiln_ai/adapters/adapter_registry.py
+++ kiln_ai-0.8.1/kiln_ai/adapters/adapter_registry.py
@@ -9,6 +9,7 @@ def adapter_for_task(
     model_name: str | None = None,
     provider: str | None = None,
     prompt_builder: BasePromptBuilder | None = None,
+    tags: list[str] | None = None,
 ) -> BaseAdapter:
     # We use langchain for everything right now, but can add any others here
     return LangchainAdapter(
@@ -16,4 +17,5 @@ def adapter_for_task(
         model_name=model_name,
         provider=provider,
         prompt_builder=prompt_builder,
+        tags=tags,
     )
```
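These two hunks thread a new `tags` argument from `adapter_for_task` through to the adapter. A minimal sketch of a caller, assuming `task` is an existing `kiln_ai.datamodel.Task` and using illustrative model/provider values (only the parameters shown above come from the diff):

```python
from kiln_ai.adapters.adapter_registry import adapter_for_task

# `task` is an existing kiln_ai.datamodel.Task; model/provider values are illustrative.
adapter = adapter_for_task(
    task,
    model_name="llama_3_3_70b",
    provider="groq",
    tags=["nightly_eval", "v2_prompt"],
)
# Per the base_adapter.py change below, each TaskRun saved by this adapter
# will carry tags=["nightly_eval", "v2_prompt"].
```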
```diff
--- kiln_ai-0.7.1/kiln_ai/adapters/base_adapter.py
+++ kiln_ai-0.8.1/kiln_ai/adapters/base_adapter.py
@@ -45,12 +45,16 @@ class BaseAdapter(metaclass=ABCMeta):
     """
 
     def __init__(
-        self,
+        self,
+        kiln_task: Task,
+        prompt_builder: BasePromptBuilder | None = None,
+        tags: list[str] | None = None,
     ):
         self.prompt_builder = prompt_builder or SimplePromptBuilder(kiln_task)
         self.kiln_task = kiln_task
         self.output_schema = self.kiln_task.output_json_schema
         self.input_schema = self.kiln_task.input_json_schema
+        self.default_tags = tags
 
     async def invoke_returning_raw(
         self,
@@ -148,6 +152,7 @@ class BaseAdapter(metaclass=ABCMeta):
             ),
         ),
         intermediate_outputs=run_output.intermediate_outputs,
+        tags=self.default_tags or [],
     )
 
     exclude_fields = {
```
```diff
--- kiln_ai-0.7.1/kiln_ai/adapters/langchain_adapters.py
+++ kiln_ai-0.8.1/kiln_ai/adapters/langchain_adapters.py
@@ -39,8 +39,9 @@ class LangchainAdapter(BaseAdapter):
         model_name: str | None = None,
         provider: str | None = None,
         prompt_builder: BasePromptBuilder | None = None,
+        tags: list[str] | None = None,
     ):
-        super().__init__(kiln_task, prompt_builder=prompt_builder)
+        super().__init__(kiln_task, prompt_builder=prompt_builder, tags=tags)
         if custom_model is not None:
             self._model = custom_model
 
@@ -198,6 +199,9 @@ async def langchain_model_from_provider(
     if provider.name == ModelProviderName.openai:
         api_key = Config.shared().open_ai_api_key
         return ChatOpenAI(**provider.provider_options, openai_api_key=api_key)  # type: ignore[arg-type]
+    elif provider.name == ModelProviderName.openai_compatible:
+        # See provider_tools.py for how base_url, key and other parameters are set
+        return ChatOpenAI(**provider.provider_options)  # type: ignore[arg-type]
     elif provider.name == ModelProviderName.groq:
         api_key = Config.shared().groq_api_key
         if api_key is None:
```
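The new `openai_compatible` branch forwards whatever `provider_options` were assembled in provider_tools.py straight into ChatOpenAI. A minimal sketch of that call (the import path and the option values are assumptions; only the option keys appear in the diff):

```python
from langchain_openai import ChatOpenAI  # assumed import path

# Shape of the provider_options built by openai_compatible_provider_model() below.
provider_options = {
    "model": "my-model",                            # illustrative model id
    "api_key": "placeholder-key",                   # some self-hosted servers accept any key
    "openai_api_base": "http://localhost:8000/v1",  # hypothetical base URL
}
llm = ChatOpenAI(**provider_options)
```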
```diff
--- kiln_ai-0.7.1/kiln_ai/adapters/ml_model_list.py
+++ kiln_ai-0.8.1/kiln_ai/adapters/ml_model_list.py
@@ -23,6 +23,7 @@ class ModelProviderName(str, Enum):
     fireworks_ai = "fireworks_ai"
     kiln_fine_tune = "kiln_fine_tune"
     kiln_custom_registry = "kiln_custom_registry"
+    openai_compatible = "openai_compatible"
 
 
 class ModelFamily(str, Enum):
@@ -522,6 +523,12 @@ built_in_models: List[KilnModel] = [
                 }
             },
         ),
+        KilnModelProvider(
+            name=ModelProviderName.groq,
+            supports_structured_output=True,
+            supports_data_gen=True,
+            provider_options={"model": "llama-3.3-70b-versatile"},
+        ),
         KilnModelProvider(
             name=ModelProviderName.ollama,
             provider_options={"model": "llama3.3"},
@@ -530,6 +537,8 @@ built_in_models: List[KilnModel] = [
             name=ModelProviderName.fireworks_ai,
             # Finetuning not live yet
             # provider_finetune_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
+            supports_structured_output=True,
+            supports_data_gen=True,
             provider_options={
                 "model": "accounts/fireworks/models/llama-v3p3-70b-instruct"
             },
```
```diff
--- kiln_ai-0.7.1/kiln_ai/adapters/provider_tools.py
+++ kiln_ai-0.8.1/kiln_ai/adapters/provider_tools.py
@@ -108,6 +108,9 @@ async def kiln_model_provider_from(
     if provider_name == ModelProviderName.kiln_fine_tune:
         return finetune_provider_model(name)
 
+    if provider_name == ModelProviderName.openai_compatible:
+        return openai_compatible_provider_model(name)
+
     built_in_model = await builtin_model_from(name, provider_name)
     if built_in_model:
         return built_in_model
@@ -136,6 +139,45 @@ async def kiln_model_provider_from(
 finetune_cache: dict[str, KilnModelProvider] = {}
 
 
+def openai_compatible_provider_model(
+    model_id: str,
+) -> KilnModelProvider:
+    try:
+        openai_provider_name, model_id = model_id.split("::")
+    except Exception:
+        raise ValueError(f"Invalid openai compatible model ID: {model_id}")
+
+    openai_compatible_providers = Config.shared().openai_compatible_providers or []
+    provider = next(
+        filter(
+            lambda p: p.get("name") == openai_provider_name, openai_compatible_providers
+        ),
+        None,
+    )
+    if provider is None:
+        raise ValueError(f"OpenAI compatible provider {openai_provider_name} not found")
+
+    # API key optional some providers don't use it
+    api_key = provider.get("api_key")
+    base_url = provider.get("base_url")
+    if base_url is None:
+        raise ValueError(
+            f"OpenAI compatible provider {openai_provider_name} has no base URL"
+        )
+
+    return KilnModelProvider(
+        name=ModelProviderName.openai_compatible,
+        provider_options={
+            "model": model_id,
+            "api_key": api_key,
+            "openai_api_base": base_url,
+        },
+        supports_structured_output=False,
+        supports_data_gen=False,
+        untested_model=True,
+    )
+
+
 def finetune_provider_model(
     model_id: str,
 ) -> KilnModelProvider:
```
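The new helper resolves models addressed as `<provider name>::<model id>` against the `openai_compatible_providers` list in the shared config. A short sketch of the expected entry shape and the resulting provider, mirroring the test fixture further down (names, URL and model id are illustrative):

```python
from unittest.mock import Mock, patch

from kiln_ai.adapters.provider_tools import openai_compatible_provider_model

# Entry shape read by the code above: only "name", "base_url" and the
# optional "api_key" keys are used.
providers = [
    {"name": "my_server", "base_url": "http://localhost:8000/v1", "api_key": "placeholder"}
]

with patch("kiln_ai.adapters.provider_tools.Config.shared") as shared:
    shared.return_value = Mock(openai_compatible_providers=providers)
    provider = openai_compatible_provider_model("my_server::qwen2.5-7b")

print(provider.provider_options)
# {'model': 'qwen2.5-7b', 'api_key': 'placeholder', 'openai_api_base': 'http://localhost:8000/v1'}
```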
```diff
--- kiln_ai-0.7.1/kiln_ai/adapters/provider_tools.py
+++ kiln_ai-0.8.1/kiln_ai/adapters/provider_tools.py
@@ -228,6 +270,8 @@ def provider_name_from_id(id: str) -> str:
             return "Fireworks AI"
         case ModelProviderName.kiln_custom_registry:
             return "Custom Models"
+        case ModelProviderName.openai_compatible:
+            return "OpenAI Compatible"
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_error(enum_id)
@@ -266,6 +310,10 @@ def provider_options_for_custom_model(
             raise ValueError(
                 "Fine tuned models should populate provider options via another path"
             )
+        case ModelProviderName.openai_compatible:
+            raise ValueError(
+                "OpenAI compatible models should populate provider options via another path"
+            )
         case _:
             # triggers pyright warning if I miss a case
             raise_exhaustive_error(enum_id)
```
```diff
--- kiln_ai-0.7.1/kiln_ai/adapters/test_provider_tools.py
+++ kiln_ai-0.8.1/kiln_ai/adapters/test_provider_tools.py
@@ -15,6 +15,7 @@ from kiln_ai.adapters.provider_tools import (
     finetune_provider_model,
     get_model_and_provider,
     kiln_model_provider_from,
+    openai_compatible_provider_model,
     provider_enabled,
     provider_name_from_id,
     provider_options_for_custom_model,
@@ -64,6 +65,25 @@ def mock_finetune():
     yield mock
 
 
+@pytest.fixture
+def mock_shared_config():
+    with patch("kiln_ai.adapters.provider_tools.Config.shared") as mock:
+        config = Mock()
+        config.openai_compatible_providers = [
+            {
+                "name": "test_provider",
+                "base_url": "https://api.test.com",
+                "api_key": "test-key",
+            },
+            {
+                "name": "no_key_provider",
+                "base_url": "https://api.nokey.com",
+            },
+        ]
+        mock.return_value = config
+        yield mock
+
+
 def test_check_provider_warnings_no_warning(mock_config):
     mock_config.return_value = "some_value"
 
@@ -529,3 +549,78 @@ def test_finetune_provider_model_fireworks_provider(
     assert provider.adapter_options == {
         "langchain": {"with_structured_output_options": {"method": "json_mode"}}
     }
+
+
+def test_openai_compatible_provider_model_success(mock_shared_config):
+    """Test successful creation of an OpenAI compatible provider"""
+    model_id = "test_provider::gpt-4"
+
+    provider = openai_compatible_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai_compatible
+    assert provider.provider_options == {
+        "model": "gpt-4",
+        "api_key": "test-key",
+        "openai_api_base": "https://api.test.com",
+    }
+    assert provider.supports_structured_output is False
+    assert provider.supports_data_gen is False
+    assert provider.untested_model is True
+
+
+def test_openai_compatible_provider_model_no_api_key(mock_shared_config):
+    """Test provider creation without API key (should work as some providers don't require it)"""
+    model_id = "no_key_provider::gpt-4"
+
+    provider = openai_compatible_provider_model(model_id)
+
+    assert provider.name == ModelProviderName.openai_compatible
+    assert provider.provider_options == {
+        "model": "gpt-4",
+        "api_key": None,
+        "openai_api_base": "https://api.nokey.com",
+    }
+
+
+def test_openai_compatible_provider_model_invalid_id():
+    """Test handling of invalid model ID format"""
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("invalid-id-format")
+    assert (
+        str(exc_info.value) == "Invalid openai compatible model ID: invalid-id-format"
+    )
+
+
+def test_openai_compatible_provider_model_no_providers(mock_shared_config):
+    """Test handling when no providers are configured"""
+    mock_shared_config.return_value.openai_compatible_providers = None
+
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("test_provider::gpt-4")
+    assert str(exc_info.value) == "OpenAI compatible provider test_provider not found"
+
+
+def test_openai_compatible_provider_model_provider_not_found(mock_shared_config):
+    """Test handling of non-existent provider"""
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("unknown_provider::gpt-4")
+    assert (
+        str(exc_info.value) == "OpenAI compatible provider unknown_provider not found"
+    )
+
+
+def test_openai_compatible_provider_model_no_base_url(mock_shared_config):
+    """Test handling of provider without base URL"""
+    mock_shared_config.return_value.openai_compatible_providers = [
+        {
+            "name": "test_provider",
+            "api_key": "test-key",
+        }
+    ]
+
+    with pytest.raises(ValueError) as exc_info:
+        openai_compatible_provider_model("test_provider::gpt-4")
+    assert (
+        str(exc_info.value)
+        == "OpenAI compatible provider test_provider has no base URL"
+    )
```
```diff
--- kiln_ai-0.7.1/kiln_ai/datamodel/__init__.py
+++ kiln_ai-0.8.1/kiln_ai/datamodel/__init__.py
@@ -49,9 +49,15 @@ __all__ = [
     "DataSource",
     "DataSourceType",
     "DataSourceProperty",
+    "Finetune",
+    "FineTuneStatusType",
     "TaskOutputRatingType",
     "TaskRequirement",
     "TaskDeterminism",
+    "DatasetSplitDefinition",
+    "DatasetSplit",
+    "RequirementRating",
+    "TaskRequirement",
     "strict_mode",
     "set_strict_mode",
 ]
@@ -85,30 +91,71 @@ class TaskOutputRatingType(str, Enum):
     """Defines the types of rating systems available for task outputs."""
 
     five_star = "five_star"
+    pass_fail = "pass_fail"
+    pass_fail_critical = "pass_fail_critical"
     custom = "custom"
 
 
+class RequirementRating(BaseModel):
+    """Rating for a specific requirement within a task output."""
+
+    value: float = Field(
+        description="The rating value. Interpretation depends on rating type"
+    )
+    type: TaskOutputRatingType = Field(description="The type of rating")
+
+
 class TaskOutputRating(KilnBaseModel):
     """
     A rating for a task output, including an overall rating and ratings for each requirement.
 
-
+    Supports:
+    - five_star: 1-5 star ratings
+    - pass_fail: boolean pass/fail (1.0 = pass, 0.0 = fail)
+    - pass_fail_critical: tri-state (1.0 = pass, 0.0 = fail, -1.0 = critical fail)
     """
 
     type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)
     value: float | None = Field(
-        description="The
+        description="The rating value. Interpretation depends on rating type:\n- five_star: 1-5 stars\n- pass_fail: 1.0 (pass) or 0.0 (fail)\n- pass_fail_critical: 1.0 (pass), 0.0 (fail), or -1.0 (critical fail)",
         default=None,
     )
-    requirement_ratings: Dict[ID_TYPE,
+    requirement_ratings: Dict[ID_TYPE, RequirementRating] = Field(
         default={},
-        description="The ratings of the requirements of the task.
+        description="The ratings of the requirements of the task.",
     )
 
+    # Previously we stored rating values as a dict of floats, but now we store them as RequirementRating objects.
+    @model_validator(mode="before")
+    def upgrade_old_format(cls, data: dict) -> dict:
+        if not isinstance(data, dict):
+            return data
+
+        # Check if we have the old format (dict of floats)
+        req_ratings = data.get("requirement_ratings", {})
+        if req_ratings and all(
+            isinstance(v, (int, float)) for v in req_ratings.values()
+        ):
+            # Convert each float to a RequirementRating object
+            # all ratings are five star at the point we used this format
+            data["requirement_ratings"] = {
+                k: {"value": v, "type": TaskOutputRatingType.five_star}
+                for k, v in req_ratings.items()
+            }
+
+        return data
+
     # Used to select high quality outputs for example selection (MultiShotPromptBuilder, etc)
     def is_high_quality(self) -> bool:
+        if self.value is None:
+            return False
+
         if self.type == TaskOutputRatingType.five_star:
-            return self.value
+            return self.value >= 4
+        elif self.type == TaskOutputRatingType.pass_fail:
+            return self.value == 1.0
+        elif self.type == TaskOutputRatingType.pass_fail_critical:
+            return self.value == 1.0
         return False
 
     @model_validator(mode="after")
```
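A short sketch of the new rating shapes and the legacy-format upgrade, using only the fields shown in this hunk (the requirement-ID keys are placeholders):

```python
from kiln_ai.datamodel import TaskOutputRating, TaskOutputRatingType

# Per-requirement ratings now carry their own rating type.
rating = TaskOutputRating(
    type=TaskOutputRatingType.five_star,
    value=5.0,
    requirement_ratings={
        "req_1": {"value": 1.0, "type": TaskOutputRatingType.pass_fail},
        "req_2": {"value": 4.0, "type": TaskOutputRatingType.five_star},
    },
)

# Old-format data (plain floats) is upgraded by the upgrade_old_format
# validator and treated as five_star ratings.
legacy = TaskOutputRating(value=4.0, requirement_ratings={"req_1": 3.0})
assert legacy.requirement_ratings["req_1"].type == TaskOutputRatingType.five_star
assert legacy.is_high_quality()  # five_star value >= 4
```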
```diff
--- kiln_ai-0.7.1/kiln_ai/datamodel/__init__.py
+++ kiln_ai-0.8.1/kiln_ai/datamodel/__init__.py
@@ -116,24 +163,61 @@ class TaskOutputRating(KilnBaseModel):
         if self.type not in TaskOutputRatingType:
             raise ValueError(f"Invalid rating type: {self.type}")
 
-
-
-
-
-
+        # Overall rating is optional
+        if self.value is not None:
+            self._validate_rating(self.type, self.value, "overall rating")
+
+        for req_id, req_rating in self.requirement_ratings.items():
+            self._validate_rating(
+                req_rating.type,
+                req_rating.value,
+                f"requirement rating for req ID: {req_id}",
+            )
 
         return self
 
-    def
-
+    def _validate_rating(
+        self, type: TaskOutputRatingType, rating: float | None, rating_name: str
+    ) -> None:
+        if type == TaskOutputRatingType.five_star:
+            self._validate_five_star(rating, rating_name)
+        elif type == TaskOutputRatingType.pass_fail:
+            self._validate_pass_fail(rating, rating_name)
+        elif type == TaskOutputRatingType.pass_fail_critical:
+            self._validate_pass_fail_critical(rating, rating_name)
+
+    def _validate_five_star(self, rating: float | None, rating_name: str) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
             raise ValueError(
-                f"{rating_name.capitalize()} of type five_star must be an integer value (1
+                f"{rating_name.capitalize()} of type five_star must be an integer value (1-5)"
             )
         if rating < 1 or rating > 5:
             raise ValueError(
                 f"{rating_name.capitalize()} of type five_star must be between 1 and 5 stars"
             )
+
+    def _validate_pass_fail(self, rating: float | None, rating_name: str) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail must be an integer value (0 or 1)"
+            )
+        if rating not in [0, 1]:
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail must be 0 (fail) or 1 (pass)"
+            )
+
+    def _validate_pass_fail_critical(
+        self, rating: float | None, rating_name: str
+    ) -> None:
+        if rating is None or not isinstance(rating, float) or not rating.is_integer():
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail_critical must be an integer value (-1, 0, or 1)"
+            )
+        if rating not in [-1, 0, 1]:
+            raise ValueError(
+                f"{rating_name.capitalize()} of type pass_fail_critical must be -1 (critical fail), 0 (fail), or 1 (pass)"
+            )
+
 
 class TaskOutput(KilnBaseModel):
     """
```
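The allowed values per rating type, as enforced by the validators above (illustrative check):

```python
from kiln_ai.datamodel import TaskOutputRating, TaskOutputRatingType

TaskOutputRating(type=TaskOutputRatingType.pass_fail, value=1.0)             # pass
TaskOutputRating(type=TaskOutputRatingType.pass_fail_critical, value=-1.0)   # critical fail

try:
    TaskOutputRating(type=TaskOutputRatingType.pass_fail, value=0.5)
except ValueError as err:
    # Rejected: pass_fail ratings must be an integer value, 0 (fail) or 1 (pass).
    print(err)
```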
```diff
--- kiln_ai-0.7.1/kiln_ai/datamodel/__init__.py
+++ kiln_ai-0.8.1/kiln_ai/datamodel/__init__.py
@@ -381,6 +465,10 @@ class TaskRun(KilnParentedModel):
         default=None,
         description="Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.",
     )
+    tags: List[str] = Field(
+        default=[],
+        description="Tags for the task run. Tags are used to categorize task runs for filtering and reporting.",
+    )
 
     def parent_task(self) -> Task | None:
         if not isinstance(self.parent, Task):
@@ -442,6 +530,16 @@ class TaskRun(KilnParentedModel):
             raise ValueError("input_source is required when strict mode is enabled")
         return self
 
+    @model_validator(mode="after")
+    def validate_tags(self) -> Self:
+        for tag in self.tags:
+            if not tag:
+                raise ValueError("Tags cannot be empty strings")
+            if " " in tag:
+                raise ValueError("Tags cannot contain spaces. Try underscores.")
+
+        return self
+
 
 # Define the type alias for clarity
 DatasetFilter = Callable[[TaskRun], bool]
@@ -602,7 +700,7 @@ class TaskRequirement(BaseModel):
     Defines a specific requirement that should be met by task outputs.
 
     Includes an identifier, name, description, instruction for meeting the requirement,
-    and
+    priority level, and rating type (five_star, pass_fail, pass_fail_critical, custom).
     """
 
     id: ID_TYPE = ID_FIELD
@@ -610,6 +708,7 @@ class TaskRequirement(BaseModel):
     description: str | None = Field(default=None)
     instruction: str = Field(min_length=1)
     priority: Priority = Field(default=Priority.p2)
+    type: TaskOutputRatingType = Field(default=TaskOutputRatingType.five_star)
 
 
 class TaskDeterminism(str, Enum):
```
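A requirement can now declare how it is rated, with `five_star` remaining the default. A minimal sketch, assuming the usual name/instruction fields of `TaskRequirement` (values illustrative):

```python
from kiln_ai.datamodel import TaskOutputRatingType, TaskRequirement

req = TaskRequirement(
    name="No PII",
    instruction="The output must not contain personally identifiable information.",
    type=TaskOutputRatingType.pass_fail,  # rated pass/fail instead of 1-5 stars
)
```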
```diff
--- kiln_ai-0.7.1/kiln_ai/datamodel/basemodel.py
+++ kiln_ai-0.8.1/kiln_ai/datamodel/basemodel.py
@@ -142,14 +142,8 @@ class KilnBaseModel(BaseModel):
         # modified time of file for cache invalidation. From file descriptor so it's atomic w read.
         mtime_ns = os.fstat(file.fileno()).st_mtime_ns
         file_data = file.read()
-        # TODO P2 perf: parsing the JSON twice here.
-        # Once for model_type, once for model. Can't call model_validate with parsed json because enum types break; they get strings instead of enums.
         parsed_json = json.loads(file_data)
-        m = cls.
-            file_data,
-            strict=True,
-            context={"loading_from_file": True},
-        )
+        m = cls.model_validate(parsed_json, context={"loading_from_file": True})
         if not isinstance(m, cls):
             raise ValueError(f"Loaded model is not of type {cls.__name__}")
         m._loaded_from_file = True
@@ -471,7 +465,7 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
         validation_errors = []
 
         try:
-            instance = cls.model_validate(data
+            instance = cls.model_validate(data)
             if path is not None:
                 instance.path = path
             if parent is not None and isinstance(instance, KilnParentedModel):
@@ -499,7 +493,7 @@ class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
         parent_type._validate_nested(**kwargs)
     elif issubclass(parent_type, KilnParentedModel):
         # Root node
-        subinstance = parent_type.model_validate(value
+        subinstance = parent_type.model_validate(value)
         if instance is not None:
             subinstance.parent = instance
         if save:
```
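The rewritten load path parses the JSON once and hands it to `model_validate` together with a `loading_from_file` context flag. A generic pydantic illustration (not kiln code) of how validators can read such a context:

```python
from pydantic import BaseModel, ValidationInfo, model_validator

class Example(BaseModel):
    name: str

    @model_validator(mode="after")
    def check(self, info: ValidationInfo) -> "Example":
        # Validators see the context passed to model_validate, so a flag like
        # {"loading_from_file": True} can relax checks when re-loading saved data.
        if info.context and info.context.get("loading_from_file"):
            return self
        # ...stricter checks for newly created objects would go here...
        return self

Example.model_validate({"name": "x"}, context={"loading_from_file": True})
```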
```diff
--- kiln_ai-0.7.1/kiln_ai/datamodel/test_models.py
+++ kiln_ai-0.8.1/kiln_ai/datamodel/test_models.py
@@ -439,3 +439,52 @@ def test_task_output_source_validation(tmp_path):
     assert os.path.exists(task_missing_output_source)
     task_run = TaskRun.load_from_file(task_missing_output_source)
     assert task_run.output.source is None
+
+
+def test_task_run_tags_validation():
+    # Setup basic output for TaskRun creation
+    output = TaskOutput(
+        output="test output",
+        source=DataSource(
+            type=DataSourceType.synthetic,
+            properties={
+                "model_name": "test-model",
+                "model_provider": "test-provider",
+                "adapter_name": "test-adapter",
+            },
+        ),
+    )
+
+    # Test 1: Valid tags should work
+    task_run = TaskRun(
+        input="test input",
+        output=output,
+        tags=["test_tag", "another_tag", "tag123"],
+    )
+    assert task_run.tags == ["test_tag", "another_tag", "tag123"]
+
+    # Test 2: Empty list of tags should work
+    task_run = TaskRun(
+        input="test input",
+        output=output,
+        tags=[],
+    )
+    assert task_run.tags == []
+
+    # Test 3: Empty string tag should fail
+    with pytest.raises(ValueError) as exc_info:
+        TaskRun(
+            input="test input",
+            output=output,
+            tags=["valid_tag", ""],
+        )
+    assert "Tags cannot be empty strings" in str(exc_info.value)
+
+    # Test 4: Tag with spaces should fail
+    with pytest.raises(ValueError) as exc_info:
+        TaskRun(
+            input="test input",
+            output=output,
+            tags=["valid_tag", "invalid tag"],
+        )
+    assert "Tags cannot contain spaces. Try underscores." in str(exc_info.value)
```