kiln-ai 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kiln-ai might be problematic. Click here for more details.

Files changed (158)
  1. kiln_ai/adapters/__init__.py +8 -2
  2. kiln_ai/adapters/adapter_registry.py +43 -208
  3. kiln_ai/adapters/chat/chat_formatter.py +8 -12
  4. kiln_ai/adapters/chat/test_chat_formatter.py +6 -2
  5. kiln_ai/adapters/chunkers/__init__.py +13 -0
  6. kiln_ai/adapters/chunkers/base_chunker.py +42 -0
  7. kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
  8. kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
  9. kiln_ai/adapters/chunkers/helpers.py +23 -0
  10. kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
  11. kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
  12. kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
  13. kiln_ai/adapters/chunkers/test_helpers.py +75 -0
  14. kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
  15. kiln_ai/adapters/docker_model_runner_tools.py +119 -0
  16. kiln_ai/adapters/embedding/__init__.py +0 -0
  17. kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
  18. kiln_ai/adapters/embedding/embedding_registry.py +32 -0
  19. kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
  20. kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
  21. kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
  22. kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
  23. kiln_ai/adapters/eval/base_eval.py +2 -2
  24. kiln_ai/adapters/eval/eval_runner.py +9 -3
  25. kiln_ai/adapters/eval/g_eval.py +2 -2
  26. kiln_ai/adapters/eval/test_base_eval.py +2 -4
  27. kiln_ai/adapters/eval/test_g_eval.py +4 -5
  28. kiln_ai/adapters/extractors/__init__.py +18 -0
  29. kiln_ai/adapters/extractors/base_extractor.py +72 -0
  30. kiln_ai/adapters/extractors/encoding.py +20 -0
  31. kiln_ai/adapters/extractors/extractor_registry.py +44 -0
  32. kiln_ai/adapters/extractors/extractor_runner.py +112 -0
  33. kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
  34. kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
  35. kiln_ai/adapters/extractors/test_encoding.py +54 -0
  36. kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
  37. kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
  38. kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
  39. kiln_ai/adapters/fine_tune/__init__.py +1 -1
  40. kiln_ai/adapters/fine_tune/openai_finetune.py +14 -4
  41. kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
  42. kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
  43. kiln_ai/adapters/fine_tune/test_openai_finetune.py +108 -111
  44. kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
  45. kiln_ai/adapters/ml_embedding_model_list.py +192 -0
  46. kiln_ai/adapters/ml_model_list.py +761 -37
  47. kiln_ai/adapters/model_adapters/base_adapter.py +51 -21
  48. kiln_ai/adapters/model_adapters/litellm_adapter.py +380 -138
  49. kiln_ai/adapters/model_adapters/test_base_adapter.py +193 -17
  50. kiln_ai/adapters/model_adapters/test_litellm_adapter.py +407 -2
  51. kiln_ai/adapters/model_adapters/test_litellm_adapter_tools.py +1103 -0
  52. kiln_ai/adapters/model_adapters/test_saving_adapter_results.py +5 -5
  53. kiln_ai/adapters/model_adapters/test_structured_output.py +113 -5
  54. kiln_ai/adapters/ollama_tools.py +69 -12
  55. kiln_ai/adapters/parsers/__init__.py +1 -1
  56. kiln_ai/adapters/provider_tools.py +205 -47
  57. kiln_ai/adapters/rag/deduplication.py +49 -0
  58. kiln_ai/adapters/rag/progress.py +252 -0
  59. kiln_ai/adapters/rag/rag_runners.py +844 -0
  60. kiln_ai/adapters/rag/test_deduplication.py +195 -0
  61. kiln_ai/adapters/rag/test_progress.py +785 -0
  62. kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
  63. kiln_ai/adapters/remote_config.py +80 -8
  64. kiln_ai/adapters/repair/test_repair_task.py +12 -9
  65. kiln_ai/adapters/run_output.py +3 -0
  66. kiln_ai/adapters/test_adapter_registry.py +657 -85
  67. kiln_ai/adapters/test_docker_model_runner_tools.py +305 -0
  68. kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
  69. kiln_ai/adapters/test_ml_model_list.py +251 -1
  70. kiln_ai/adapters/test_ollama_tools.py +340 -1
  71. kiln_ai/adapters/test_prompt_adaptors.py +13 -6
  72. kiln_ai/adapters/test_prompt_builders.py +1 -1
  73. kiln_ai/adapters/test_provider_tools.py +254 -8
  74. kiln_ai/adapters/test_remote_config.py +651 -58
  75. kiln_ai/adapters/vector_store/__init__.py +1 -0
  76. kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
  77. kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
  78. kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
  79. kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
  80. kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
  81. kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
  82. kiln_ai/datamodel/__init__.py +39 -34
  83. kiln_ai/datamodel/basemodel.py +170 -1
  84. kiln_ai/datamodel/chunk.py +158 -0
  85. kiln_ai/datamodel/datamodel_enums.py +28 -0
  86. kiln_ai/datamodel/embedding.py +64 -0
  87. kiln_ai/datamodel/eval.py +1 -1
  88. kiln_ai/datamodel/external_tool_server.py +298 -0
  89. kiln_ai/datamodel/extraction.py +303 -0
  90. kiln_ai/datamodel/json_schema.py +25 -10
  91. kiln_ai/datamodel/project.py +40 -1
  92. kiln_ai/datamodel/rag.py +79 -0
  93. kiln_ai/datamodel/registry.py +0 -15
  94. kiln_ai/datamodel/run_config.py +62 -0
  95. kiln_ai/datamodel/task.py +2 -77
  96. kiln_ai/datamodel/task_output.py +6 -1
  97. kiln_ai/datamodel/task_run.py +41 -0
  98. kiln_ai/datamodel/test_attachment.py +649 -0
  99. kiln_ai/datamodel/test_basemodel.py +4 -4
  100. kiln_ai/datamodel/test_chunk_models.py +317 -0
  101. kiln_ai/datamodel/test_dataset_split.py +1 -1
  102. kiln_ai/datamodel/test_embedding_models.py +448 -0
  103. kiln_ai/datamodel/test_eval_model.py +6 -6
  104. kiln_ai/datamodel/test_example_models.py +175 -0
  105. kiln_ai/datamodel/test_external_tool_server.py +691 -0
  106. kiln_ai/datamodel/test_extraction_chunk.py +206 -0
  107. kiln_ai/datamodel/test_extraction_model.py +470 -0
  108. kiln_ai/datamodel/test_rag.py +641 -0
  109. kiln_ai/datamodel/test_registry.py +8 -3
  110. kiln_ai/datamodel/test_task.py +15 -47
  111. kiln_ai/datamodel/test_tool_id.py +320 -0
  112. kiln_ai/datamodel/test_vector_store.py +320 -0
  113. kiln_ai/datamodel/tool_id.py +105 -0
  114. kiln_ai/datamodel/vector_store.py +141 -0
  115. kiln_ai/tools/__init__.py +8 -0
  116. kiln_ai/tools/base_tool.py +82 -0
  117. kiln_ai/tools/built_in_tools/__init__.py +13 -0
  118. kiln_ai/tools/built_in_tools/math_tools.py +124 -0
  119. kiln_ai/tools/built_in_tools/test_math_tools.py +204 -0
  120. kiln_ai/tools/mcp_server_tool.py +95 -0
  121. kiln_ai/tools/mcp_session_manager.py +246 -0
  122. kiln_ai/tools/rag_tools.py +157 -0
  123. kiln_ai/tools/test_base_tools.py +199 -0
  124. kiln_ai/tools/test_mcp_server_tool.py +457 -0
  125. kiln_ai/tools/test_mcp_session_manager.py +1585 -0
  126. kiln_ai/tools/test_rag_tools.py +848 -0
  127. kiln_ai/tools/test_tool_registry.py +562 -0
  128. kiln_ai/tools/tool_registry.py +85 -0
  129. kiln_ai/utils/__init__.py +3 -0
  130. kiln_ai/utils/async_job_runner.py +62 -17
  131. kiln_ai/utils/config.py +24 -2
  132. kiln_ai/utils/env.py +15 -0
  133. kiln_ai/utils/filesystem.py +14 -0
  134. kiln_ai/utils/filesystem_cache.py +60 -0
  135. kiln_ai/utils/litellm.py +94 -0
  136. kiln_ai/utils/lock.py +100 -0
  137. kiln_ai/utils/mime_type.py +38 -0
  138. kiln_ai/utils/open_ai_types.py +94 -0
  139. kiln_ai/utils/pdf_utils.py +38 -0
  140. kiln_ai/utils/project_utils.py +17 -0
  141. kiln_ai/utils/test_async_job_runner.py +151 -35
  142. kiln_ai/utils/test_config.py +138 -1
  143. kiln_ai/utils/test_env.py +142 -0
  144. kiln_ai/utils/test_filesystem_cache.py +316 -0
  145. kiln_ai/utils/test_litellm.py +206 -0
  146. kiln_ai/utils/test_lock.py +185 -0
  147. kiln_ai/utils/test_mime_type.py +66 -0
  148. kiln_ai/utils/test_open_ai_types.py +131 -0
  149. kiln_ai/utils/test_pdf_utils.py +73 -0
  150. kiln_ai/utils/test_uuid.py +111 -0
  151. kiln_ai/utils/test_validation.py +524 -0
  152. kiln_ai/utils/uuid.py +9 -0
  153. kiln_ai/utils/validation.py +90 -0
  154. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +12 -5
  155. kiln_ai-0.21.0.dist-info/RECORD +211 -0
  156. kiln_ai-0.19.0.dist-info/RECORD +0 -115
  157. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
  158. {kiln_ai-0.19.0.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,316 @@
1
+ import tempfile
2
+ from pathlib import Path
3
+ from unittest.mock import AsyncMock, patch
4
+
5
+ import pytest
6
+
7
+ from kiln_ai.utils.filesystem_cache import FilesystemCache, TemporaryFilesystemCache
8
+
9
+
10
class TestFilesystemCache:
    """Exercise FilesystemCache key validation and its async get/set behaviour."""

    @pytest.fixture
    def temp_dir(self):
        """Yield a scratch directory as a Path; it is deleted when the test ends."""
        with tempfile.TemporaryDirectory() as scratch:
            yield Path(scratch)

    @pytest.fixture
    def cache(self, temp_dir):
        """Return a FilesystemCache rooted at the per-test scratch directory."""
        return FilesystemCache(temp_dir)

    def test_init(self, temp_dir):
        """The constructor stores the directory and exposes a key validator."""
        fresh = FilesystemCache(temp_dir)
        assert fresh.cache_dir_path == temp_dir
        assert fresh.validate_key is not None

    def test_get_path_valid_key(self, cache):
        """A valid key resolves to a file directly inside the cache directory."""
        name = "test_file"
        assert cache.get_path(name) == cache.cache_dir_path / name

    def test_get_path_invalid_key_empty(self, cache):
        """An empty key is rejected."""
        with pytest.raises(ValueError):
            cache.get_path("")

    def test_get_path_invalid_key_too_long(self, cache):
        """A 121-character key is rejected (the limit is expected to be 120)."""
        oversized = "x" * 121
        with pytest.raises(ValueError):
            cache.get_path(oversized)

    def test_get_path_invalid_key_special_chars(self, cache):
        """A key containing a forward slash is rejected."""
        with pytest.raises(ValueError):
            cache.get_path("invalid/key")

    async def test_get_nonexistent_file(self, cache):
        """Reading a key that was never written yields None."""
        assert await cache.get("nonexistent") is None

    async def test_get_existing_file(self, cache, temp_dir):
        """A file placed on disk by hand is returned by get()."""
        name = "test"
        payload = b"Hello, World!"
        (temp_dir / name).write_bytes(payload)

        assert await cache.get(name) == payload

    async def test_get_file_read_error(self, cache, temp_dir):
        """A read failure is reported as a miss (None) rather than raised."""
        name = "test"
        (temp_dir / name).write_bytes(b"test")

        # Make every anyio.Path.read_bytes call fail while the patch is active.
        failing_read = patch(
            "anyio.Path.read_bytes",
            new_callable=AsyncMock,
            side_effect=IOError("Read error"),
        )
        with failing_read:
            outcome = await cache.get(name)
            assert outcome is None

    async def test_set_valid_data(self, cache, temp_dir):
        """set() writes the bytes and returns the path it wrote to."""
        name = "test"
        payload = b"Hello, World!"
        target = temp_dir / name

        written_to = await cache.set(name, payload)

        assert written_to == target
        assert target.exists()
        assert target.read_bytes() == payload

    async def test_set_invalid_key_empty(self, cache):
        """set() rejects an empty key."""
        with pytest.raises(ValueError):
            await cache.set("", b"content")

    async def test_set_invalid_key_too_long(self, cache):
        """set() rejects a 121-character key."""
        oversized = "x" * 121
        with pytest.raises(ValueError):
            await cache.set(oversized, b"content")

    async def test_set_invalid_key_special_chars(self, cache):
        """set() rejects a key containing a forward slash."""
        with pytest.raises(ValueError):
            await cache.set("invalid/key", b"content")

    async def test_set_overwrites_existing(self, cache, temp_dir):
        """A second set() on the same key replaces the file contents."""
        name = "test"
        before = b"Original content"
        after = b"New content"
        on_disk = temp_dir / name

        # First write.
        await cache.set(name, before)
        assert on_disk.read_bytes() == before

        # Second write to the same key.
        await cache.set(name, after)
        assert on_disk.read_bytes() == after

    async def test_set_creates_parent_directory(self, cache, temp_dir):
        """Store a flat key whose name merely resembles a sub-directory path.

        Real directory keys are rejected by the key validator, so an
        underscore stands in for the slash here.
        """
        name = "subdir_test"
        payload = b"Test content"
        target = temp_dir / name

        written_to = await cache.set(name, payload)

        assert written_to == target
        assert target.exists()
        assert target.read_bytes() == payload

    async def test_set_with_nested_directories(self, cache, temp_dir):
        """Store a flat key whose name resembles a deeply nested path.

        As above, underscores replace slashes because directory keys are not
        accepted by the key validator.
        """
        name = "level1_level2_level3_test"
        payload = b"Deeply nested content"
        target = temp_dir / name

        written_to = await cache.set(name, payload)

        assert written_to == target
        assert target.exists()
        assert target.read_bytes() == payload

    async def test_directory_paths_not_supported(self, cache):
        """Both set() and get() refuse keys that contain a forward slash."""
        slashed = "subdir/test"

        with pytest.raises(ValueError, match="Name is invalid"):
            await cache.set(slashed, b"content")

        with pytest.raises(ValueError, match="Name is invalid"):
            await cache.get(slashed)

    async def test_roundtrip_get_set(self, cache):
        """Bytes written with set() come back unchanged from get()."""
        name = "roundtrip"
        payload = b"Roundtrip test content"

        await cache.set(name, payload)

        assert await cache.get(name) == payload

    async def test_multiple_files(self, cache, temp_dir):
        """Several keys can coexist without clobbering one another."""
        payloads = {
            "file1": b"Content 1",
            "file2": b"Content 2",
            "subdir_file3": b"Content 3",  # underscore, not a slash
        }

        # Write everything first ...
        for name, blob in payloads.items():
            await cache.set(name, blob)

        # ... then read everything back.
        for name, blob in payloads.items():
            fetched = await cache.get(name)
            assert fetched == blob

    async def test_empty_bytes(self, cache):
        """A zero-length payload round-trips as b"" (not as a miss)."""
        name = "empty"
        payload = b""

        await cache.set(name, payload)

        assert await cache.get(name) == payload

    async def test_large_content(self, cache):
        """A 10,000-byte payload round-trips intact."""
        name = "large"
        payload = b"x" * 10000

        await cache.set(name, payload)

        assert await cache.get(name) == payload

    @pytest.mark.parametrize(
        "unicode_text",
        [
            "Simple ASCII text",
            "中文文本测试",
            "Mixed 中文 and English text",
            "Emojis: 🎉🚀💻🔥",
            "Complex: 你好世界! Hello 世界! 🌍 This is 测试 with 中文, emojis 🚀, and English.",
            "Special chars: ñáéíóú àèìòù çüöä",
            "Math symbols: ∑∆π∫∞±≤≥≠",
            "Currency: €£¥$₹₽",
        ],
    )
    async def test_unicode_text_retrieval_integrity(self, cache, unicode_text):
        """UTF-8 encoded text survives a set()/get() cycle byte-for-byte."""
        name = "unicode_integrity"
        encoded = unicode_text.encode("utf-8")

        await cache.set(name, encoded)
        fetched = await cache.get(name)

        assert fetched == encoded
        assert fetched.decode("utf-8") == unicode_text

    async def test_key_overwrite_behavior(self, cache):
        """Successive writes to one key each fully replace the previous value."""
        name = "overwrite_test"

        # Round 1: plain ASCII.
        first_text = "Initial content"
        first_blob = first_text.encode("utf-8")
        await cache.set(name, first_blob)

        fetched = await cache.get(name)
        assert fetched == first_blob
        assert fetched.decode("utf-8") == first_text

        # Round 2: multi-byte text replaces the ASCII payload.
        second_text = "New content with 中文 and emojis 🚀"
        second_blob = second_text.encode("utf-8")
        await cache.set(name, second_blob)

        fetched = await cache.get(name)
        assert fetched == second_blob
        assert fetched.decode("utf-8") == second_text
        assert fetched != first_blob

        # Round 3: an empty payload replaces everything.
        blank = b""
        await cache.set(name, blank)

        fetched = await cache.get(name)
        assert fetched == blank
        assert fetched.decode("utf-8") == ""

    @pytest.mark.parametrize(
        "invalid_key",
        [
            "",
            "x" * 121,
            "invalid/key",
            "invalid\\key",
            "invalid:key",
            "invalid*key",
            "invalid?key",
            "invalid<key",
            "invalid>key",
            "invalid|key",
        ],
    )
    async def test_invalid_keys(self, cache, invalid_key):
        """Every malformed key is rejected by both get_path() and set()."""
        with pytest.raises(ValueError):
            cache.get_path(invalid_key)

        with pytest.raises(ValueError):
            await cache.set(invalid_key, b"content")
265
+
266
+
267
class TestTemporaryFilesystemCache:
    """Checks for the temp-directory-backed cache wrapper and its shared() accessor."""

    def test_temporary_cache_creation(self):
        """A new instance owns an existing, prefixed temp directory and wraps it."""
        wrapper = TemporaryFilesystemCache()
        backing_dir = wrapper._cache_temp_dir

        # The backing directory must have been created on disk.
        assert backing_dir is not None
        assert Path(backing_dir).exists()

        # Only the final path component carries the prefix, not the full path.
        assert Path(backing_dir).name.startswith("kiln_cache_")

        # The wrapped cache must point at that same directory.
        inner = wrapper.filesystem_cache
        assert isinstance(inner, FilesystemCache)
        assert inner.cache_dir_path == Path(backing_dir)

    async def test_multiple_instances_share_same_cache(self):
        """Repeated shared() calls hand back one object with one directory."""
        first = TemporaryFilesystemCache.shared()
        second = TemporaryFilesystemCache.shared()

        # Identity, not just equality.
        assert first is second
        assert first.cache_dir_path == second.cache_dir_path

        # A value written through one handle is readable through the other.
        name = "shared_test"
        payload = b"shared content"
        await first.set(name, payload)

        assert await second.get(name) == payload

    def test_cache_directory_naming(self):
        """The temp directory name is the prefix plus a non-empty suffix."""
        prefix = "kiln_cache_"
        dir_name = Path(TemporaryFilesystemCache()._cache_temp_dir).name

        assert dir_name.startswith(prefix)

        # Something must follow the prefix, and the name must be one component.
        assert len(dir_name) > len(prefix)
        assert "/" not in dir_name
@@ -0,0 +1,206 @@
1
+ import pytest
2
+
3
+ from kiln_ai.adapters.ml_embedding_model_list import KilnEmbeddingModelProvider
4
+ from kiln_ai.adapters.ml_model_list import KilnModelProvider
5
+ from kiln_ai.datamodel.datamodel_enums import ModelProviderName
6
+ from kiln_ai.utils.litellm import LitellmProviderInfo, get_litellm_provider_info
7
+
8
+
9
class TestGetLitellmProviderInfo:
    """Test cases for get_litellm_provider_info function"""

    # Rows are (Kiln provider, expected litellm provider name, expected
    # is_custom flag). Kept in one place so the chat-model and embedding-model
    # mapping tests always check the same table and cannot drift apart.
    _PROVIDER_MAPPINGS = [
        (ModelProviderName.openrouter, "openrouter", False),
        (ModelProviderName.openai, "openai", False),
        (ModelProviderName.groq, "groq", False),
        (ModelProviderName.anthropic, "anthropic", False),
        (ModelProviderName.gemini_api, "gemini", False),
        (ModelProviderName.fireworks_ai, "fireworks_ai", False),
        (ModelProviderName.amazon_bedrock, "bedrock", False),
        (ModelProviderName.azure_openai, "azure", False),
        (ModelProviderName.huggingface, "huggingface", False),
        (ModelProviderName.vertex, "vertex_ai", False),
        (ModelProviderName.together_ai, "together_ai", False),
        (ModelProviderName.ollama, "openai", True),
        (ModelProviderName.openai_compatible, "openai", True),
        (ModelProviderName.kiln_custom_registry, "openai", True),
        (ModelProviderName.kiln_fine_tune, "openai", True),
    ]

    # Providers expected to be routed through litellm's "openai" provider
    # with is_custom set.
    _CUSTOM_PROVIDERS = [
        ModelProviderName.ollama,
        ModelProviderName.openai_compatible,
        ModelProviderName.kiln_custom_registry,
        ModelProviderName.kiln_fine_tune,
    ]

    # Sample of providers expected to keep their own litellm provider name.
    _NON_CUSTOM_PROVIDERS = [
        (ModelProviderName.openai, "openai"),
        (ModelProviderName.anthropic, "anthropic"),
        (ModelProviderName.groq, "groq"),
        (ModelProviderName.gemini_api, "gemini"),
    ]

    @pytest.fixture
    def sample_model_id(self):
        """Sample model ID for testing"""
        return "test-model-id"

    @pytest.fixture
    def embedding_provider(self, sample_model_id):
        """Sample KilnEmbeddingModelProvider for testing"""
        return KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
            n_dimensions=1536,
        )

    @pytest.fixture
    def model_provider(self, sample_model_id):
        """Sample KilnModelProvider for testing"""
        return KilnModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
        )

    @pytest.mark.parametrize(
        "model_id",
        [
            None,
            "",
        ],
    )
    def test_missing_model_id_raises_error(self, model_id):
        """Test that missing model_id raises ValueError"""
        provider = KilnModelProvider(
            name=ModelProviderName.openai,
            model_id=model_id,
        )

        with pytest.raises(
            ValueError, match="Model ID is required for OpenAI compatible models"
        ):
            get_litellm_provider_info(provider)

    @pytest.mark.parametrize(
        "provider_name,expected_litellm_name,expected_is_custom",
        _PROVIDER_MAPPINGS,
    )
    def test_provider_mappings_with_model_provider(
        self, provider_name, expected_litellm_name, expected_is_custom, sample_model_id
    ):
        """Test provider name mappings for KilnModelProvider"""
        provider = KilnModelProvider(
            name=provider_name,
            model_id=sample_model_id,
        )

        result = get_litellm_provider_info(provider)

        assert isinstance(result, LitellmProviderInfo)
        assert result.provider_name == expected_litellm_name
        assert result.is_custom == expected_is_custom
        assert result.litellm_model_id == f"{expected_litellm_name}/{sample_model_id}"

    @pytest.mark.parametrize(
        "provider_name,expected_litellm_name,expected_is_custom",
        _PROVIDER_MAPPINGS,
    )
    def test_provider_mappings_with_embedding_provider(
        self, provider_name, expected_litellm_name, expected_is_custom, sample_model_id
    ):
        """Test provider name mappings for KilnEmbeddingModelProvider"""
        provider = KilnEmbeddingModelProvider(
            name=provider_name,
            model_id=sample_model_id,
            n_dimensions=1536,
        )

        result = get_litellm_provider_info(provider)

        assert isinstance(result, LitellmProviderInfo)
        assert result.provider_name == expected_litellm_name
        assert result.is_custom == expected_is_custom
        assert result.litellm_model_id == f"{expected_litellm_name}/{sample_model_id}"

    # Parametrized rather than looped so every provider is reported as its own
    # test case and one failure does not mask the others.
    @pytest.mark.parametrize("provider_name", _CUSTOM_PROVIDERS)
    def test_custom_providers_use_openai_format(self, provider_name, sample_model_id):
        """Test that custom providers use 'openai' as the litellm provider name"""
        provider = KilnModelProvider(
            name=provider_name,
            model_id=sample_model_id,
        )

        result = get_litellm_provider_info(provider)

        assert result.provider_name == "openai"
        assert result.is_custom is True
        assert result.litellm_model_id == f"openai/{sample_model_id}"

    # Parametrized for the same reason as the custom-provider test above.
    @pytest.mark.parametrize("provider_name,expected_name", _NON_CUSTOM_PROVIDERS)
    def test_non_custom_providers_use_correct_format(
        self, provider_name, expected_name, sample_model_id
    ):
        """Test that non-custom providers use their actual provider names"""
        provider = KilnModelProvider(
            name=provider_name,
            model_id=sample_model_id,
        )

        result = get_litellm_provider_info(provider)

        assert result.provider_name == expected_name
        assert result.is_custom is False
        assert result.litellm_model_id == f"{expected_name}/{sample_model_id}"

    def test_litellm_model_id_format(self, embedding_provider):
        """Test that litellm_model_id follows the correct format"""
        result = get_litellm_provider_info(embedding_provider)

        # Expected shape is "<litellm provider name>/<model id>".
        expected_format = f"{result.provider_name}/{embedding_provider.model_id}"
        assert result.litellm_model_id == expected_format

    def test_return_type_structure(self, model_provider):
        """Test that the return type has all expected fields"""
        result = get_litellm_provider_info(model_provider)

        assert hasattr(result, "provider_name")
        assert hasattr(result, "is_custom")
        assert hasattr(result, "litellm_model_id")

        assert isinstance(result.provider_name, str)
        assert isinstance(result.is_custom, bool)
        assert isinstance(result.litellm_model_id, str)

    def test_works_with_both_provider_types(self, sample_model_id):
        """Test that function works with both KilnModelProvider and KilnEmbeddingModelProvider"""
        model_provider = KilnModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
        )

        embedding_provider = KilnEmbeddingModelProvider(
            name=ModelProviderName.openai,
            model_id=sample_model_id,
            n_dimensions=1536,
        )

        model_result = get_litellm_provider_info(model_provider)
        embedding_result = get_litellm_provider_info(embedding_provider)

        # Results should be identical for same provider name and model ID
        assert model_result.provider_name == embedding_result.provider_name
        assert model_result.is_custom == embedding_result.is_custom
        assert model_result.litellm_model_id == embedding_result.litellm_model_id