kiln-ai 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic; see the package's registry page for more details.

Files changed (40):
  1. kiln_ai/adapters/__init__.py +2 -0
  2. kiln_ai/adapters/adapter_registry.py +19 -0
  3. kiln_ai/adapters/data_gen/test_data_gen_task.py +29 -21
  4. kiln_ai/adapters/fine_tune/__init__.py +14 -0
  5. kiln_ai/adapters/fine_tune/base_finetune.py +186 -0
  6. kiln_ai/adapters/fine_tune/dataset_formatter.py +187 -0
  7. kiln_ai/adapters/fine_tune/finetune_registry.py +11 -0
  8. kiln_ai/adapters/fine_tune/fireworks_finetune.py +308 -0
  9. kiln_ai/adapters/fine_tune/openai_finetune.py +205 -0
  10. kiln_ai/adapters/fine_tune/test_base_finetune.py +290 -0
  11. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +342 -0
  12. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +455 -0
  13. kiln_ai/adapters/fine_tune/test_openai_finetune.py +503 -0
  14. kiln_ai/adapters/langchain_adapters.py +103 -13
  15. kiln_ai/adapters/ml_model_list.py +218 -304
  16. kiln_ai/adapters/ollama_tools.py +114 -0
  17. kiln_ai/adapters/provider_tools.py +295 -0
  18. kiln_ai/adapters/repair/test_repair_task.py +6 -11
  19. kiln_ai/adapters/test_langchain_adapter.py +46 -18
  20. kiln_ai/adapters/test_ollama_tools.py +42 -0
  21. kiln_ai/adapters/test_prompt_adaptors.py +7 -5
  22. kiln_ai/adapters/test_provider_tools.py +312 -0
  23. kiln_ai/adapters/test_structured_output.py +22 -43
  24. kiln_ai/datamodel/__init__.py +235 -22
  25. kiln_ai/datamodel/basemodel.py +30 -0
  26. kiln_ai/datamodel/registry.py +31 -0
  27. kiln_ai/datamodel/test_basemodel.py +29 -1
  28. kiln_ai/datamodel/test_dataset_split.py +234 -0
  29. kiln_ai/datamodel/test_example_models.py +12 -0
  30. kiln_ai/datamodel/test_models.py +91 -1
  31. kiln_ai/datamodel/test_registry.py +96 -0
  32. kiln_ai/utils/config.py +9 -0
  33. kiln_ai/utils/name_generator.py +125 -0
  34. kiln_ai/utils/test_name_geneator.py +47 -0
  35. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/METADATA +4 -2
  36. kiln_ai-0.7.0.dist-info/RECORD +56 -0
  37. kiln_ai/adapters/test_ml_model_list.py +0 -181
  38. kiln_ai-0.6.1.dist-info/RECORD +0 -37
  39. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/WHEEL +0 -0
  40. {kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
kiln_ai/datamodel/test_models.py CHANGED
@@ -6,6 +6,7 @@ from pydantic import ValidationError
6
6
  from kiln_ai.datamodel import (
7
7
  DataSource,
8
8
  DataSourceType,
9
+ Finetune,
9
10
  Project,
10
11
  Task,
11
12
  TaskOutput,
@@ -27,7 +28,7 @@ def test_project_file(tmp_path):
27
28
 
28
29
  @pytest.fixture
29
30
  def test_task_file(tmp_path):
30
- test_file_path = tmp_path / "task.json"
31
+ test_file_path = tmp_path / "task.kiln"
31
32
  data = {
32
33
  "v": 1,
33
34
  "name": "Test Task",
@@ -225,3 +226,92 @@ def test_task_run_intermediate_outputs():
225
226
  "cot": "chain of thought output",
226
227
  "draft": "draft output",
227
228
  }
229
+
230
+
231
+ def test_finetune_basic():
232
+ # Test basic initialization
233
+ finetune = Finetune(
234
+ name="test-finetune",
235
+ provider="openai",
236
+ base_model_id="gpt-3.5-turbo",
237
+ dataset_split_id="dataset-123",
238
+ train_split_name="train",
239
+ system_message="Test system message",
240
+ )
241
+ assert finetune.name == "test-finetune"
242
+ assert finetune.provider == "openai"
243
+ assert finetune.base_model_id == "gpt-3.5-turbo"
244
+ assert finetune.dataset_split_id == "dataset-123"
245
+ assert finetune.train_split_name == "train"
246
+ assert finetune.provider_id is None
247
+ assert finetune.parameters == {}
248
+ assert finetune.description is None
249
+
250
+
251
+ def test_finetune_full():
252
+ # Test with all fields populated
253
+ finetune = Finetune(
254
+ name="test-finetune",
255
+ description="Test description",
256
+ provider="openai",
257
+ base_model_id="gpt-3.5-turbo",
258
+ provider_id="ft-abc123",
259
+ dataset_split_id="dataset-123",
260
+ train_split_name="train",
261
+ system_message="Test system message",
262
+ parameters={
263
+ "epochs": 3,
264
+ "learning_rate": 0.1,
265
+ "batch_size": 4,
266
+ "use_fp16": True,
267
+ "model_suffix": "-v1",
268
+ },
269
+ )
270
+ assert finetune.description == "Test description"
271
+ assert finetune.provider_id == "ft-abc123"
272
+ assert finetune.parameters == {
273
+ "epochs": 3,
274
+ "learning_rate": 0.1,
275
+ "batch_size": 4,
276
+ "use_fp16": True,
277
+ "model_suffix": "-v1",
278
+ }
279
+ assert finetune.system_message == "Test system message"
280
+
281
+
282
+ def test_finetune_parent_task():
283
+ # Test parent_task() method
284
+ task = Task(name="Test Task", instruction="Test instruction")
285
+ finetune = Finetune(
286
+ name="test-finetune",
287
+ provider="openai",
288
+ base_model_id="gpt-3.5-turbo",
289
+ parent=task,
290
+ dataset_split_id="dataset-123",
291
+ train_split_name="train",
292
+ system_message="Test system message",
293
+ )
294
+
295
+ assert finetune.parent_task() == task
296
+
297
+ # Test with no parent
298
+ finetune_no_parent = Finetune(
299
+ name="test-finetune",
300
+ provider="openai",
301
+ base_model_id="gpt-3.5-turbo",
302
+ dataset_split_id="dataset-123",
303
+ train_split_name="train",
304
+ system_message="Test system message",
305
+ )
306
+ assert finetune_no_parent.parent_task() is None
307
+
308
+
309
+ def test_finetune_parameters_validation():
310
+ # Test that parameters only accept valid types
311
+ with pytest.raises(ValidationError):
312
+ Finetune(
313
+ name="test-finetune",
314
+ provider="openai",
315
+ base_model_id="gpt-3.5-turbo",
316
+ parameters={"invalid": [1, 2, 3]}, # Lists are not allowed
317
+ )
kiln_ai/datamodel/test_registry.py ADDED
@@ -0,0 +1,96 @@
1
+ from unittest.mock import Mock, patch
2
+
3
+ import pytest
4
+
5
+ from kiln_ai.datamodel import Project
6
+ from kiln_ai.datamodel.registry import all_projects, project_from_id
7
+
8
+
9
+ @pytest.fixture
10
+ def mock_config():
11
+ with patch("kiln_ai.datamodel.registry.Config") as mock:
12
+ config_instance = Mock()
13
+ mock.shared.return_value = config_instance
14
+ yield config_instance
15
+
16
+
17
+ @pytest.fixture
18
+ def mock_project():
19
+ def create_mock_project(project_id: str = "test-id"):
20
+ project = Mock(spec=Project)
21
+ project.id = project_id
22
+ return project
23
+
24
+ return create_mock_project
25
+
26
+
27
+ def test_all_projects_empty(mock_config):
28
+ mock_config.projects = None
29
+ assert all_projects() == []
30
+
31
+
32
+ def test_all_projects_success(mock_config, mock_project):
33
+ mock_config.projects = ["path1", "path2"]
34
+
35
+ project1 = mock_project("project1")
36
+ project2 = mock_project("project2")
37
+
38
+ with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
39
+ mock_load.side_effect = [project1, project2]
40
+
41
+ result = all_projects()
42
+
43
+ assert len(result) == 2
44
+ assert result[0] == project1
45
+ assert result[1] == project2
46
+ mock_load.assert_any_call("path1")
47
+ mock_load.assert_any_call("path2")
48
+
49
+
50
+ def test_all_projects_with_errors(mock_config, mock_project):
51
+ mock_config.projects = ["path1", "path2", "path3"]
52
+
53
+ project1 = mock_project("project1")
54
+ project3 = mock_project("project3")
55
+
56
+ with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
57
+ mock_load.side_effect = [project1, Exception("File not found"), project3]
58
+
59
+ result = all_projects()
60
+
61
+ assert len(result) == 2
62
+ assert result[0] == project1
63
+ assert result[1] == project3
64
+
65
+
66
+ def test_project_from_id_not_found(mock_config):
67
+ mock_config.projects = None
68
+ assert project_from_id("any-id") is None
69
+
70
+
71
+ def test_project_from_id_success(mock_config, mock_project):
72
+ mock_config.projects = ["path1", "path2"]
73
+
74
+ project1 = mock_project("project1")
75
+ project2 = mock_project("project2")
76
+
77
+ with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
78
+ mock_load.side_effect = [project1, project2]
79
+
80
+ result = project_from_id("project2")
81
+
82
+ assert result == project2
83
+
84
+
85
+ def test_project_from_id_with_errors(mock_config, mock_project):
86
+ mock_config.projects = ["path1", "path2", "path3"]
87
+
88
+ project1 = mock_project("project1")
89
+ project3 = mock_project("target-id")
90
+
91
+ with patch("kiln_ai.datamodel.Project.load_from_file") as mock_load:
92
+ mock_load.side_effect = [project1, Exception("File not found"), project3]
93
+
94
+ result = project_from_id("target-id")
95
+
96
+ assert result == project3
kiln_ai/utils/config.py CHANGED
@@ -67,6 +67,15 @@ class Config:
67
67
  env_var="OPENROUTER_API_KEY",
68
68
  sensitive=True,
69
69
  ),
70
+ "fireworks_api_key": ConfigProperty(
71
+ str,
72
+ env_var="FIREWORKS_API_KEY",
73
+ sensitive=True,
74
+ ),
75
+ "fireworks_account_id": ConfigProperty(
76
+ str,
77
+ env_var="FIREWORKS_ACCOUNT_ID",
78
+ ),
70
79
  "projects": ConfigProperty(
71
80
  list,
72
81
  default_lambda=lambda: [],
kiln_ai/utils/name_generator.py ADDED
@@ -0,0 +1,125 @@
1
+ from random import choice
2
+ from typing import List
3
+
4
+ ADJECTIVES: List[str] = [
5
+ "Curious",
6
+ "Playful",
7
+ "Mighty",
8
+ "Gentle",
9
+ "Clever",
10
+ "Brave",
11
+ "Cosmic",
12
+ "Dancing",
13
+ "Electric",
14
+ "Fierce",
15
+ "Glowing",
16
+ "Hidden",
17
+ "Infinite",
18
+ "Jolly",
19
+ "Magical",
20
+ "Ancient",
21
+ "Blazing",
22
+ "Celestial",
23
+ "Dazzling",
24
+ "Emerald",
25
+ "Floating",
26
+ "Graceful",
27
+ "Harmonious",
28
+ "Icy",
29
+ "Jade",
30
+ "Kinetic",
31
+ "Luminous",
32
+ "Mystic",
33
+ "Noble",
34
+ "Opal",
35
+ "Peaceful",
36
+ "Quantum",
37
+ "Radiant",
38
+ "Silent",
39
+ "Thundering",
40
+ "Untamed",
41
+ "Vibrant",
42
+ "Whispering",
43
+ "Xenial",
44
+ "Yearning",
45
+ "Zealous",
46
+ "Astral",
47
+ "Boundless",
48
+ "Crimson",
49
+ "Divine",
50
+ "Ethereal",
51
+ "Fabled",
52
+ "Golden",
53
+ "Heroic",
54
+ "Imperial",
55
+ ]
56
+
57
+ NOUNS: List[str] = [
58
+ "Penguin",
59
+ "Dragon",
60
+ "Phoenix",
61
+ "Tiger",
62
+ "Dolphin",
63
+ "Mountain",
64
+ "River",
65
+ "Forest",
66
+ "Cloud",
67
+ "Star",
68
+ "Crystal",
69
+ "Garden",
70
+ "Ocean",
71
+ "Falcon",
72
+ "Wizard",
73
+ "Aurora",
74
+ "Badger",
75
+ "Comet",
76
+ "Dryad",
77
+ "Eagle",
78
+ "Fox",
79
+ "Griffin",
80
+ "Harbor",
81
+ "Island",
82
+ "Jaguar",
83
+ "Knight",
84
+ "Lion",
85
+ "Mermaid",
86
+ "Nebula",
87
+ "Owl",
88
+ "Panther",
89
+ "Quasar",
90
+ "Raven",
91
+ "Serpent",
92
+ "Tempest",
93
+ "Unicorn",
94
+ "Valley",
95
+ "Wolf",
96
+ "Sphinx",
97
+ "Yeti",
98
+ "Zenith",
99
+ "Archer",
100
+ "Beacon",
101
+ "Cascade",
102
+ "Dreamer",
103
+ "Echo",
104
+ "Flame",
105
+ "Glacier",
106
+ "Horizon",
107
+ "Ivy",
108
+ ]
109
+
110
+
111
+ def generate_memorable_name() -> str:
112
+ """
113
+ Generates a memorable two-word name combining a random adjective and noun.
114
+
115
+ Returns:
116
+ str: A memorable name in the format "Adjective Noun"
117
+
118
+ Example:
119
+ >>> generate_memorable_name()
120
+ 'Cosmic Dragon'
121
+ """
122
+ adjective = choice(ADJECTIVES)
123
+ noun = choice(NOUNS)
124
+
125
+ return f"{adjective} {noun}"
kiln_ai/utils/test_name_geneator.py ADDED
@@ -0,0 +1,47 @@
1
+ from kiln_ai.utils.name_generator import ADJECTIVES, NOUNS, generate_memorable_name
2
+
3
+
4
+ def test_generate_memorable_name_format():
5
+ """Test that generated name follows the expected format."""
6
+ name = generate_memorable_name()
7
+
8
+ # Check that we get exactly two words
9
+ words = name.split()
10
+ assert len(words) == 2
11
+
12
+ # Check that first word is an adjective and second word is a noun
13
+ assert words[0] in ADJECTIVES
14
+ assert words[1] in NOUNS
15
+
16
+
17
+ def test_generate_memorable_name_randomness():
18
+ """Test that the function generates different names."""
19
+ names = {generate_memorable_name() for _ in range(100)}
20
+
21
+ # With 50 adjectives and 50 nouns, we should get multiple unique combinations
22
+ # in 100 tries. Using 50 as a reasonable lower bound.
23
+ assert len(names) > 50
24
+
25
+
26
+ def test_generate_memorable_name_string_type():
27
+ """Test that the generated name is a string."""
28
+ name = generate_memorable_name()
29
+ assert isinstance(name, str)
30
+
31
+
32
+ def test_word_lists_not_empty():
33
+ """Test that our word lists contain entries."""
34
+ assert len(ADJECTIVES) > 0
35
+ assert len(NOUNS) > 0
36
+
37
+
38
+ def test_word_lists_are_strings():
39
+ """Test that all entries in word lists are strings."""
40
+ assert all(isinstance(word, str) for word in ADJECTIVES)
41
+ assert all(isinstance(word, str) for word in NOUNS)
42
+
43
+
44
+ def test_word_lists_no_duplicates():
45
+ """Test that word lists don't contain duplicates."""
46
+ assert len(ADJECTIVES) == len(set(ADJECTIVES))
47
+ assert len(NOUNS) == len(set(NOUNS))
{kiln_ai-0.6.1.dist-info → kiln_ai-0.7.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: kiln-ai
3
- Version: 0.6.1
3
+ Version: 0.7.0
4
4
  Summary: Kiln AI
5
5
  Project-URL: Homepage, https://getkiln.ai
6
6
  Project-URL: Repository, https://github.com/Kiln-AI/kiln
@@ -17,10 +17,12 @@ Requires-Python: >=3.10
17
17
  Requires-Dist: coverage>=7.6.4
18
18
  Requires-Dist: jsonschema>=4.23.0
19
19
  Requires-Dist: langchain-aws>=0.2.4
20
+ Requires-Dist: langchain-fireworks>=0.2.5
20
21
  Requires-Dist: langchain-groq>=0.2.0
21
22
  Requires-Dist: langchain-ollama>=0.2.0
22
23
  Requires-Dist: langchain-openai>=0.2.4
23
24
  Requires-Dist: langchain>=0.3.5
25
+ Requires-Dist: openai>=1.53.0
24
26
  Requires-Dist: pdoc>=15.0.0
25
27
  Requires-Dist: pydantic>=2.9.2
26
28
  Requires-Dist: pytest-cov>=6.0.0
@@ -28,7 +30,7 @@ Requires-Dist: pyyaml>=6.0.2
28
30
  Requires-Dist: typing-extensions>=4.12.2
29
31
  Description-Content-Type: text/markdown
30
32
 
31
- # kiln_ai
33
+ # Kiln AI Core Library
32
34
 
33
35
  <p align="center">
34
36
  <picture>
kiln_ai-0.7.0.dist-info/RECORD ADDED
@@ -0,0 +1,56 @@
1
+ kiln_ai/__init__.py,sha256=Sc4z8LRVFMwJUoc_DPVUriSXTZ6PO9MaJ80PhRbKyB8,34
2
+ kiln_ai/adapters/__init__.py,sha256=8-YlnTh3gsaPeEArFVLIqGE7-tbssI42fub4OQBp_DA,970
3
+ kiln_ai/adapters/adapter_registry.py,sha256=EnB0rUIZ0KbBd2nxkNjwUqOpldwqPDyJ9LzIQoDl2GU,634
4
+ kiln_ai/adapters/base_adapter.py,sha256=E_RfXxzEhW-i066xOhZdPuTM7OPKQv70hDpfMsxfYEs,6145
5
+ kiln_ai/adapters/langchain_adapters.py,sha256=NeTZ8WbQTnVu8rtFX6AwkdjFj2ihyhe_vxNxM-_v2yE,10584
6
+ kiln_ai/adapters/ml_model_list.py,sha256=jU0Qd3B8vDLTeZcNnM2amF0-fWwLUBk7eayY27YAGDU,24196
7
+ kiln_ai/adapters/ollama_tools.py,sha256=qja3W4Ubtlw1R1E1_YgSexvtWZ7OanLldYfOdFGTMJ4,3521
8
+ kiln_ai/adapters/prompt_builders.py,sha256=Mdu-f1mC9hWIDwoF7Qwd9F99GDx6oNGvtEZN-SrOsNM,10325
9
+ kiln_ai/adapters/provider_tools.py,sha256=BVNhXNtFOG-qz1mj1Q94KG2MfZo-rLoH2pqwKt9Ldi0,10209
10
+ kiln_ai/adapters/test_langchain_adapter.py,sha256=-6ZUI94jsgDsLcuCVyhx9gC50vcC2oLGF-Wa3rGcZzI,5672
11
+ kiln_ai/adapters/test_ollama_tools.py,sha256=2KwYVaj3ySV3ld-z51TCGbJEMdb3MZj2eoEicIWz3Q4,2552
12
+ kiln_ai/adapters/test_prompt_adaptors.py,sha256=Mc0oSYgDLxfP2u3GVR_iDWaYctTQ8Ug1u6UGvWA90lM,7494
13
+ kiln_ai/adapters/test_prompt_builders.py,sha256=sU0bSBZa9Y4Q-mmkDf3HbQ0MNSWk5o9bC9sNgtnBokk,14598
14
+ kiln_ai/adapters/test_provider_tools.py,sha256=LRH7QvleIEVA6Nkvfq79R6tDGPcI8QtmOqe0LQhQmFw,9948
15
+ kiln_ai/adapters/test_saving_adapter_results.py,sha256=SYYh2xY1zmeKhFHfWAuEY4pEiLd8SitSV5ewGOTmaOI,6447
16
+ kiln_ai/adapters/test_structured_output.py,sha256=9Mgng-HOXiZ_WcJG5cpMWhtsdJt8Rn-7qIouBWvWVoU,9324
17
+ kiln_ai/adapters/data_gen/__init__.py,sha256=QTZWaf7kq5BorhPvexJfwDEKmjRmIbhwW9ei8LW2SIs,276
18
+ kiln_ai/adapters/data_gen/data_gen_prompts.py,sha256=kudjHnAz7L3q0k_NLyTlaIV7M0uRFrxXNcfcnjOE2uc,5810
19
+ kiln_ai/adapters/data_gen/data_gen_task.py,sha256=vwjC47YDrsl4GtBJpK6FWh07TGd8CalhZOX4p4YBX8w,5904
20
+ kiln_ai/adapters/data_gen/test_data_gen_task.py,sha256=TC_n1iWgfLp87q7eNE3ZunVCuk_J25vfw-ohi2qtnp0,9668
21
+ kiln_ai/adapters/fine_tune/__init__.py,sha256=DxdTR60chwgck1aEoVYWyfWi6Ed2ZkdJj0lar-SEAj4,257
22
+ kiln_ai/adapters/fine_tune/base_finetune.py,sha256=-3hyWZXImJomaZeAME6mxbjifQDAn7hwlgTm8VVkxkg,5861
23
+ kiln_ai/adapters/fine_tune/dataset_formatter.py,sha256=DzmUaCaUalTYaX2aNtnb_oucb5ZghI13RDVwtxECMUU,6340
24
+ kiln_ai/adapters/fine_tune/finetune_registry.py,sha256=H1B-opCTlIyd9JlIFTKsY_ctxUX9ziEc49_gnmg1SZg,483
25
+ kiln_ai/adapters/fine_tune/fireworks_finetune.py,sha256=B5o_-A0_Y_QYtgUXZWhKAjR1MeCXvZWz5scZZuK3pMg,13303
26
+ kiln_ai/adapters/fine_tune/openai_finetune.py,sha256=WJKczDN7CA1TJnIokzZu7hbcZiOv9JIRA1scv1zDe8o,8312
27
+ kiln_ai/adapters/fine_tune/test_base_finetune.py,sha256=YOCdQCL5Q0kpBiaU3hccafknCg0kIFRyp16lttR2Io0,9843
28
+ kiln_ai/adapters/fine_tune/test_dataset_formatter.py,sha256=7atbHb4kFtgSmHQMNrSnNpH2ZO8drpnfwKWCsx1p8mM,11127
29
+ kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py,sha256=Y6r5BxsevFeEUHJikfFLeeG6fbPvLOxQpqIMpn-SpvU,15272
30
+ kiln_ai/adapters/fine_tune/test_openai_finetune.py,sha256=EF-f0JbVaPiVXF0eBYbwTKdi5thA45s-XbVB0iUBI00,16629
31
+ kiln_ai/adapters/repair/__init__.py,sha256=dOO9MEpEhjiwzDVFg3MNfA2bKMPlax9iekDatpTkX8E,217
32
+ kiln_ai/adapters/repair/repair_task.py,sha256=VXvX1l9AYDE_GV0i3S_vPThltJoQlCFVCCHV9m-QA7k,3297
33
+ kiln_ai/adapters/repair/test_repair_task.py,sha256=JBcyqyQYWniiUo4FSle9kUEsnbTsl5JN1LTRN1SRnrE,7940
34
+ kiln_ai/datamodel/__init__.py,sha256=Fp8sdRrRLBt8alWc93jiHApv8uwTz_DHRSMcviWo2Ok,22484
35
+ kiln_ai/datamodel/basemodel.py,sha256=OedsYQ0-RJ8k4Zl1VslG1xmbJu76WxT7IWQ_wlibuKc,18657
36
+ kiln_ai/datamodel/json_schema.py,sha256=l4BIq1ItLHgcSHqsqDOchegLLHY48U4yR0SP2aMb4i0,2449
37
+ kiln_ai/datamodel/registry.py,sha256=XwGFXJFKZtOpR1Z9ven6SftggfADdZRm8TFxCEVtfUQ,957
38
+ kiln_ai/datamodel/test_basemodel.py,sha256=S2-BaqVEaEqzk7otuKdGw3i5ERT2m8MSZHlk4Dy0nyc,10466
39
+ kiln_ai/datamodel/test_dataset_split.py,sha256=aBjHVyTdt4mWXEKBkvvchEEZSj8jUwhXRZ37LbBxTi4,7265
40
+ kiln_ai/datamodel/test_datasource.py,sha256=GAiZz31qezVVPwFqnt8wHMu15WvtlV89jw8C1Ue6YNI,3165
41
+ kiln_ai/datamodel/test_example_models.py,sha256=9Jhc0bvbM4hCjJGiQNgWH5rwyIsGuneAD8h4o1P3zAY,20356
42
+ kiln_ai/datamodel/test_json_schema.py,sha256=vdLnTQxxrcmuSrf6iOmkrmpfh7JnxqIw4B4dbDAAcZ4,3199
43
+ kiln_ai/datamodel/test_models.py,sha256=iqwE4iW695BVoDnOSg-HUiXOuuzOZIfsvdaYht7qBm4,9441
44
+ kiln_ai/datamodel/test_nested_save.py,sha256=xciCddqvPyKyoyjC5Lx_3Kh1t4LJv1xYRAPazR3SRcs,5588
45
+ kiln_ai/datamodel/test_output_rating.py,sha256=iw7fVUAPORA-0-VFiikZV3NDycGFaFMHSX1a38t_aQA,2647
46
+ kiln_ai/datamodel/test_registry.py,sha256=PhS4anLi5Bf_023obuTlO5DALhtPB8WIc_bX12Yg6Po,2705
47
+ kiln_ai/utils/__init__.py,sha256=PTD0MwBCKAMIOGsTAwsFaJOusTJJoRFTfOGqRvCaU-E,142
48
+ kiln_ai/utils/config.py,sha256=voYp5NwESKm4ZdTAq6MttQCNcLYEA_1w5zB2Eth9-gg,5803
49
+ kiln_ai/utils/formatting.py,sha256=VtB9oag0lOGv17dwT7OPX_3HzBfaU9GsLH-iLete0yM,97
50
+ kiln_ai/utils/name_generator.py,sha256=v26TgpCwQbhQFcZvzgjZvURinjrOyyFhxpsI6NQrHKc,1914
51
+ kiln_ai/utils/test_config.py,sha256=pTYItz5WD15rTRdxKE7vszXF_mb-dik2qrFWzkVemEY,7671
52
+ kiln_ai/utils/test_name_geneator.py,sha256=9-hSTBshyakqlPbFnNcggwLrL7lcPTitauBYHg9jFWI,1513
53
+ kiln_ai-0.7.0.dist-info/METADATA,sha256=V3_VmqXJrKcHyWnfcU7xPoIFhXbbDLm04hVciUJn_f8,3064
54
+ kiln_ai-0.7.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
55
+ kiln_ai-0.7.0.dist-info/licenses/LICENSE.txt,sha256=_NA5pnTYgRRr4qH6lE3X-TuZJ8iRcMUi5ASoGr-lEx8,1209
56
+ kiln_ai-0.7.0.dist-info/RECORD,,
kiln_ai/adapters/test_ml_model_list.py DELETED
@@ -1,181 +0,0 @@
1
- import json
2
- from unittest.mock import patch
3
-
4
- import pytest
5
-
6
- from kiln_ai.adapters.ml_model_list import (
7
- ModelName,
8
- ModelProviderName,
9
- OllamaConnection,
10
- check_provider_warnings,
11
- get_model_and_provider,
12
- ollama_model_supported,
13
- parse_ollama_tags,
14
- provider_name_from_id,
15
- provider_warnings,
16
- )
17
-
18
-
19
- @pytest.fixture
20
- def mock_config():
21
- with patch("kiln_ai.adapters.ml_model_list.get_config_value") as mock:
22
- yield mock
23
-
24
-
25
- def test_check_provider_warnings_no_warning(mock_config):
26
- mock_config.return_value = "some_value"
27
-
28
- # This should not raise an exception
29
- check_provider_warnings(ModelProviderName.amazon_bedrock)
30
-
31
-
32
- def test_check_provider_warnings_missing_key(mock_config):
33
- mock_config.return_value = None
34
-
35
- with pytest.raises(ValueError) as exc_info:
36
- check_provider_warnings(ModelProviderName.amazon_bedrock)
37
-
38
- assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
39
- exc_info.value
40
- )
41
-
42
-
43
- def test_check_provider_warnings_unknown_provider():
44
- # This should not raise an exception, as no settings are required for unknown providers
45
- check_provider_warnings("unknown_provider")
46
-
47
-
48
- @pytest.mark.parametrize(
49
- "provider_name",
50
- [
51
- ModelProviderName.amazon_bedrock,
52
- ModelProviderName.openrouter,
53
- ModelProviderName.groq,
54
- ModelProviderName.openai,
55
- ],
56
- )
57
- def test_check_provider_warnings_all_providers(mock_config, provider_name):
58
- mock_config.return_value = None
59
-
60
- with pytest.raises(ValueError) as exc_info:
61
- check_provider_warnings(provider_name)
62
-
63
- assert provider_warnings[provider_name].message in str(exc_info.value)
64
-
65
-
66
- def test_check_provider_warnings_partial_keys_set(mock_config):
67
- def mock_get(key):
68
- return "value" if key == "bedrock_access_key" else None
69
-
70
- mock_config.side_effect = mock_get
71
-
72
- with pytest.raises(ValueError) as exc_info:
73
- check_provider_warnings(ModelProviderName.amazon_bedrock)
74
-
75
- assert provider_warnings[ModelProviderName.amazon_bedrock].message in str(
76
- exc_info.value
77
- )
78
-
79
-
80
- def test_provider_name_from_id_unknown_provider():
81
- assert (
82
- provider_name_from_id("unknown_provider")
83
- == "Unknown provider: unknown_provider"
84
- )
85
-
86
-
87
- def test_provider_name_from_id_case_sensitivity():
88
- assert (
89
- provider_name_from_id(ModelProviderName.amazon_bedrock.upper())
90
- == "Unknown provider: AMAZON_BEDROCK"
91
- )
92
-
93
-
94
- @pytest.mark.parametrize(
95
- "provider_id, expected_name",
96
- [
97
- (ModelProviderName.amazon_bedrock, "Amazon Bedrock"),
98
- (ModelProviderName.openrouter, "OpenRouter"),
99
- (ModelProviderName.groq, "Groq"),
100
- (ModelProviderName.ollama, "Ollama"),
101
- (ModelProviderName.openai, "OpenAI"),
102
- ],
103
- )
104
- def test_provider_name_from_id_parametrized(provider_id, expected_name):
105
- assert provider_name_from_id(provider_id) == expected_name
106
-
107
-
108
- def test_parse_ollama_tags_no_models():
109
- json_response = '{"models":[{"name":"phi3.5:latest","model":"phi3.5:latest","modified_at":"2024-10-02T12:04:35.191519822-04:00","size":2176178843,"digest":"61819fb370a3c1a9be6694869331e5f85f867a079e9271d66cb223acb81d04ba","details":{"parent_model":"","format":"gguf","family":"phi3","families":["phi3"],"parameter_size":"3.8B","quantization_level":"Q4_0"}},{"name":"gemma2:2b","model":"gemma2:2b","modified_at":"2024-09-09T16:46:38.64348929-04:00","size":1629518495,"digest":"8ccf136fdd5298f3ffe2d69862750ea7fb56555fa4d5b18c04e3fa4d82ee09d7","details":{"parent_model":"","format":"gguf","family":"gemma2","families":["gemma2"],"parameter_size":"2.6B","quantization_level":"Q4_0"}},{"name":"llama3.1:latest","model":"llama3.1:latest","modified_at":"2024-09-01T17:19:43.481523695-04:00","size":4661230720,"digest":"f66fc8dc39ea206e03ff6764fcc696b1b4dfb693f0b6ef751731dd4e6269046e","details":{"parent_model":"","format":"gguf","family":"llama","families":["llama"],"parameter_size":"8.0B","quantization_level":"Q4_0"}}]}'
110
- tags = json.loads(json_response)
111
- print(json.dumps(tags, indent=2))
112
- conn = parse_ollama_tags(tags)
113
- assert "phi3.5:latest" in conn.models
114
- assert "gemma2:2b" in conn.models
115
- assert "llama3.1:latest" in conn.models
116
-
117
-
118
- def test_ollama_model_supported():
119
- conn = OllamaConnection(
120
- models=["phi3.5:latest", "gemma2:2b", "llama3.1:latest"], message="Connected"
121
- )
122
- assert ollama_model_supported(conn, "phi3.5:latest")
123
- assert ollama_model_supported(conn, "phi3.5")
124
- assert ollama_model_supported(conn, "gemma2:2b")
125
- assert ollama_model_supported(conn, "llama3.1:latest")
126
- assert ollama_model_supported(conn, "llama3.1")
127
- assert not ollama_model_supported(conn, "unknown_model")
128
-
129
-
130
- def test_get_model_and_provider_valid():
131
- # Test with a known valid model and provider combination
132
- model, provider = get_model_and_provider(
133
- ModelName.phi_3_5, ModelProviderName.ollama
134
- )
135
-
136
- assert model is not None
137
- assert provider is not None
138
- assert model.name == ModelName.phi_3_5
139
- assert provider.name == ModelProviderName.ollama
140
- assert provider.provider_options["model"] == "phi3.5"
141
-
142
-
143
- def test_get_model_and_provider_invalid_model():
144
- # Test with an invalid model name
145
- model, provider = get_model_and_provider(
146
- "nonexistent_model", ModelProviderName.ollama
147
- )
148
-
149
- assert model is None
150
- assert provider is None
151
-
152
-
153
- def test_get_model_and_provider_invalid_provider():
154
- # Test with a valid model but invalid provider
155
- model, provider = get_model_and_provider(ModelName.phi_3_5, "nonexistent_provider")
156
-
157
- assert model is None
158
- assert provider is None
159
-
160
-
161
- def test_get_model_and_provider_valid_model_wrong_provider():
162
- # Test with a valid model but a provider that doesn't support it
163
- model, provider = get_model_and_provider(
164
- ModelName.phi_3_5, ModelProviderName.amazon_bedrock
165
- )
166
-
167
- assert model is None
168
- assert provider is None
169
-
170
-
171
- def test_get_model_and_provider_multiple_providers():
172
- # Test with a model that has multiple providers
173
- model, provider = get_model_and_provider(
174
- ModelName.llama_3_1_70b, ModelProviderName.groq
175
- )
176
-
177
- assert model is not None
178
- assert provider is not None
179
- assert model.name == ModelName.llama_3_1_70b
180
- assert provider.name == ModelProviderName.groq
181
- assert provider.provider_options["model"] == "llama-3.1-70b-versatile"