kiln-ai 0.20.1__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kiln-ai might be problematic. Click here for more details.
- kiln_ai/adapters/__init__.py +6 -0
- kiln_ai/adapters/adapter_registry.py +43 -226
- kiln_ai/adapters/chunkers/__init__.py +13 -0
- kiln_ai/adapters/chunkers/base_chunker.py +42 -0
- kiln_ai/adapters/chunkers/chunker_registry.py +16 -0
- kiln_ai/adapters/chunkers/fixed_window_chunker.py +39 -0
- kiln_ai/adapters/chunkers/helpers.py +23 -0
- kiln_ai/adapters/chunkers/test_base_chunker.py +63 -0
- kiln_ai/adapters/chunkers/test_chunker_registry.py +28 -0
- kiln_ai/adapters/chunkers/test_fixed_window_chunker.py +346 -0
- kiln_ai/adapters/chunkers/test_helpers.py +75 -0
- kiln_ai/adapters/data_gen/test_data_gen_task.py +9 -3
- kiln_ai/adapters/embedding/__init__.py +0 -0
- kiln_ai/adapters/embedding/base_embedding_adapter.py +44 -0
- kiln_ai/adapters/embedding/embedding_registry.py +32 -0
- kiln_ai/adapters/embedding/litellm_embedding_adapter.py +199 -0
- kiln_ai/adapters/embedding/test_base_embedding_adapter.py +283 -0
- kiln_ai/adapters/embedding/test_embedding_registry.py +166 -0
- kiln_ai/adapters/embedding/test_litellm_embedding_adapter.py +1149 -0
- kiln_ai/adapters/eval/eval_runner.py +6 -2
- kiln_ai/adapters/eval/test_base_eval.py +1 -3
- kiln_ai/adapters/eval/test_g_eval.py +1 -1
- kiln_ai/adapters/extractors/__init__.py +18 -0
- kiln_ai/adapters/extractors/base_extractor.py +72 -0
- kiln_ai/adapters/extractors/encoding.py +20 -0
- kiln_ai/adapters/extractors/extractor_registry.py +44 -0
- kiln_ai/adapters/extractors/extractor_runner.py +112 -0
- kiln_ai/adapters/extractors/litellm_extractor.py +386 -0
- kiln_ai/adapters/extractors/test_base_extractor.py +244 -0
- kiln_ai/adapters/extractors/test_encoding.py +54 -0
- kiln_ai/adapters/extractors/test_extractor_registry.py +181 -0
- kiln_ai/adapters/extractors/test_extractor_runner.py +181 -0
- kiln_ai/adapters/extractors/test_litellm_extractor.py +1192 -0
- kiln_ai/adapters/fine_tune/test_dataset_formatter.py +2 -2
- kiln_ai/adapters/fine_tune/test_fireworks_tinetune.py +2 -6
- kiln_ai/adapters/fine_tune/test_together_finetune.py +2 -6
- kiln_ai/adapters/ml_embedding_model_list.py +192 -0
- kiln_ai/adapters/ml_model_list.py +382 -4
- kiln_ai/adapters/model_adapters/litellm_adapter.py +7 -69
- kiln_ai/adapters/model_adapters/test_litellm_adapter.py +1 -1
- kiln_ai/adapters/model_adapters/test_structured_output.py +3 -1
- kiln_ai/adapters/ollama_tools.py +69 -12
- kiln_ai/adapters/provider_tools.py +190 -46
- kiln_ai/adapters/rag/deduplication.py +49 -0
- kiln_ai/adapters/rag/progress.py +252 -0
- kiln_ai/adapters/rag/rag_runners.py +844 -0
- kiln_ai/adapters/rag/test_deduplication.py +195 -0
- kiln_ai/adapters/rag/test_progress.py +785 -0
- kiln_ai/adapters/rag/test_rag_runners.py +2376 -0
- kiln_ai/adapters/remote_config.py +80 -8
- kiln_ai/adapters/test_adapter_registry.py +579 -86
- kiln_ai/adapters/test_ml_embedding_model_list.py +429 -0
- kiln_ai/adapters/test_ml_model_list.py +212 -0
- kiln_ai/adapters/test_ollama_tools.py +340 -1
- kiln_ai/adapters/test_prompt_builders.py +1 -1
- kiln_ai/adapters/test_provider_tools.py +199 -8
- kiln_ai/adapters/test_remote_config.py +551 -56
- kiln_ai/adapters/vector_store/__init__.py +1 -0
- kiln_ai/adapters/vector_store/base_vector_store_adapter.py +83 -0
- kiln_ai/adapters/vector_store/lancedb_adapter.py +389 -0
- kiln_ai/adapters/vector_store/test_base_vector_store.py +160 -0
- kiln_ai/adapters/vector_store/test_lancedb_adapter.py +1841 -0
- kiln_ai/adapters/vector_store/test_vector_store_registry.py +199 -0
- kiln_ai/adapters/vector_store/vector_store_registry.py +33 -0
- kiln_ai/datamodel/__init__.py +16 -13
- kiln_ai/datamodel/basemodel.py +170 -1
- kiln_ai/datamodel/chunk.py +158 -0
- kiln_ai/datamodel/datamodel_enums.py +27 -0
- kiln_ai/datamodel/embedding.py +64 -0
- kiln_ai/datamodel/extraction.py +303 -0
- kiln_ai/datamodel/project.py +33 -1
- kiln_ai/datamodel/rag.py +79 -0
- kiln_ai/datamodel/test_attachment.py +649 -0
- kiln_ai/datamodel/test_basemodel.py +1 -1
- kiln_ai/datamodel/test_chunk_models.py +317 -0
- kiln_ai/datamodel/test_dataset_split.py +1 -1
- kiln_ai/datamodel/test_embedding_models.py +448 -0
- kiln_ai/datamodel/test_eval_model.py +6 -6
- kiln_ai/datamodel/test_extraction_chunk.py +206 -0
- kiln_ai/datamodel/test_extraction_model.py +470 -0
- kiln_ai/datamodel/test_rag.py +641 -0
- kiln_ai/datamodel/test_tool_id.py +81 -0
- kiln_ai/datamodel/test_vector_store.py +320 -0
- kiln_ai/datamodel/tool_id.py +22 -0
- kiln_ai/datamodel/vector_store.py +141 -0
- kiln_ai/tools/mcp_session_manager.py +4 -1
- kiln_ai/tools/rag_tools.py +157 -0
- kiln_ai/tools/test_mcp_session_manager.py +1 -1
- kiln_ai/tools/test_rag_tools.py +848 -0
- kiln_ai/tools/test_tool_registry.py +91 -2
- kiln_ai/tools/tool_registry.py +21 -0
- kiln_ai/utils/__init__.py +3 -0
- kiln_ai/utils/async_job_runner.py +62 -17
- kiln_ai/utils/config.py +2 -2
- kiln_ai/utils/env.py +15 -0
- kiln_ai/utils/filesystem.py +14 -0
- kiln_ai/utils/filesystem_cache.py +60 -0
- kiln_ai/utils/litellm.py +94 -0
- kiln_ai/utils/lock.py +100 -0
- kiln_ai/utils/mime_type.py +38 -0
- kiln_ai/utils/pdf_utils.py +38 -0
- kiln_ai/utils/test_async_job_runner.py +151 -35
- kiln_ai/utils/test_env.py +142 -0
- kiln_ai/utils/test_filesystem_cache.py +316 -0
- kiln_ai/utils/test_litellm.py +206 -0
- kiln_ai/utils/test_lock.py +185 -0
- kiln_ai/utils/test_mime_type.py +66 -0
- kiln_ai/utils/test_pdf_utils.py +73 -0
- kiln_ai/utils/test_uuid.py +111 -0
- kiln_ai/utils/test_validation.py +524 -0
- kiln_ai/utils/uuid.py +9 -0
- kiln_ai/utils/validation.py +90 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/METADATA +7 -1
- kiln_ai-0.21.0.dist-info/RECORD +211 -0
- kiln_ai-0.20.1.dist-info/RECORD +0 -138
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/WHEEL +0 -0
- {kiln_ai-0.20.1.dist-info → kiln_ai-0.21.0.dist-info}/licenses/LICENSE.txt +0 -0
|
@@ -160,8 +160,12 @@ class EvalRunner:
|
|
|
160
160
|
"""
|
|
161
161
|
jobs = self.collect_tasks()
|
|
162
162
|
|
|
163
|
-
runner = AsyncJobRunner(
|
|
164
|
-
|
|
163
|
+
runner = AsyncJobRunner(
|
|
164
|
+
concurrency=concurrency,
|
|
165
|
+
jobs=jobs,
|
|
166
|
+
run_job_fn=self.run_job,
|
|
167
|
+
)
|
|
168
|
+
async for progress in runner.run():
|
|
165
169
|
yield progress
|
|
166
170
|
|
|
167
171
|
async def run_job(self, job: EvalJob) -> bool:
|
|
@@ -307,9 +307,7 @@ async def test_run_method():
|
|
|
307
307
|
evaluator = EvalTester(eval_config, run_config.run_config())
|
|
308
308
|
|
|
309
309
|
# Run the evaluation
|
|
310
|
-
task_run, eval_scores,
|
|
311
|
-
"test input"
|
|
312
|
-
)
|
|
310
|
+
task_run, eval_scores, _ = await evaluator.run_task_and_eval("test input")
|
|
313
311
|
|
|
314
312
|
# Verify task run was created
|
|
315
313
|
assert task_run.input == "test input"
|
|
@@ -188,7 +188,7 @@ async def test_run_g_eval_e2e(
|
|
|
188
188
|
g_eval = GEval(test_eval_config, test_run_config)
|
|
189
189
|
|
|
190
190
|
# Run the evaluation
|
|
191
|
-
|
|
191
|
+
_, scores, intermediate_outputs = await g_eval.run_task_and_eval("chickens")
|
|
192
192
|
|
|
193
193
|
# Verify the evaluation results
|
|
194
194
|
assert isinstance(scores, dict)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
"""
|
|
2
|
+
File extractors for processing different document types.
|
|
3
|
+
|
|
4
|
+
This package provides a framework for extracting content from files
|
|
5
|
+
using different extraction methods.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from . import base_extractor, extractor_registry, extractor_runner, litellm_extractor
|
|
9
|
+
from .base_extractor import ExtractionInput, ExtractionOutput
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"ExtractionInput",
|
|
13
|
+
"ExtractionOutput",
|
|
14
|
+
"base_extractor",
|
|
15
|
+
"extractor_registry",
|
|
16
|
+
"extractor_runner",
|
|
17
|
+
"litellm_extractor",
|
|
18
|
+
]
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
from kiln_ai.datamodel.extraction import ExtractorConfig, OutputFormat
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Input to an extractor: a file on disk plus its declared MIME type.
# (Documented with comments rather than a docstring: pydantic would otherwise
# expose a class docstring as the JSON-schema description of this model.)
class ExtractionInput(BaseModel):
    # Path or string path to the file; callers wrap it with Path(...) before reading.
    path: Path | str = Field(description="The absolute path to the file to extract.")
    # Used to pick passthrough vs. model extraction and the request encoding.
    mime_type: str = Field(description="The mime type of the file.")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ExtractionOutput(BaseModel):
    """
    The output of an extraction. This is the data that will be saved to the data store.
    """

    # True when the file content was returned unchanged (passthrough MIME type)
    # instead of being produced by an extraction model.
    is_passthrough: bool = Field(
        default=False, description="Whether the extractor returned the file as is."
    )
    # Copied from the extractor config's output_format by the extractors in this package.
    content_format: OutputFormat = Field(
        description="The format of the extracted data."
    )
    # The extracted text (or the raw UTF-8 file text for passthrough).
    content: str = Field(description="The extracted data.")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BaseExtractor(ABC):
    """
    Abstract base for every extractor.

    Concrete extractors implement ``_extract``. Callers use ``extract``, which
    short-circuits passthrough MIME types (returning the file's UTF-8 text
    unchanged) and wraps any failure in a ``ValueError`` naming the input path.
    """

    def __init__(self, extractor_config: ExtractorConfig):
        self.extractor_config = extractor_config

    @abstractmethod
    async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
        """Run the extractor-specific extraction; provided by subclasses."""

    async def extract(
        self,
        extraction_input: ExtractionInput,
    ) -> ExtractionOutput:
        """
        Extract content from ``extraction_input``.

        Passthrough MIME types (per the extractor config) are read as UTF-8
        text and returned with ``is_passthrough=True``; everything else is
        delegated to ``_extract``.

        Raises:
            ValueError: wrapping any error raised while extracting, with the
                original exception chained as the cause.
        """
        try:
            if not self._should_passthrough(extraction_input.mime_type):
                return await self._extract(extraction_input)

            raw_text = Path(extraction_input.path).read_text(encoding="utf-8")
            return ExtractionOutput(
                is_passthrough=True,
                content=raw_text,
                content_format=self.extractor_config.output_format,
            )
        except Exception as err:
            raise ValueError(
                f"Error extracting {extraction_input.path}: {err}"
            ) from err

    def _should_passthrough(self, mime_type: str) -> bool:
        """Return True if ``mime_type`` is configured as passthrough (case-insensitive)."""
        passthrough_types = {
            candidate.lower()
            for candidate in self.extractor_config.passthrough_mimetypes
        }
        return mime_type.lower() in passthrough_types

    def output_format(self) -> OutputFormat:
        """Return the output format configured for this extractor."""
        return self.extractor_config.output_format
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def to_base64_url(mime_type: str, bytes: bytes) -> str:
    """Return a ``data:<mime_type>;base64,<payload>`` URL embedding ``bytes``.

    The parameter name ``bytes`` (which shadows the builtin) is kept for
    keyword-argument compatibility with existing callers.
    """
    payload = base64.b64encode(bytes).decode("utf-8")
    return f"data:{mime_type};base64,{payload}"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def from_base64_url(base64_url: str) -> bytes:
    """Decode the payload of a base64 ``data:`` URL back into raw bytes.

    Args:
        base64_url: A URL of the form ``data:<mime>;base64,<payload>``, e.g. as
            produced by ``to_base64_url``.

    Returns:
        The decoded payload bytes.

    Raises:
        ValueError: ``"Invalid base64 URL format"`` when the string does not
            start with ``data:`` or does not contain exactly one comma;
            ``"Failed to decode base64 data: ..."`` (chained to the underlying
            decode error) when the payload cannot be decoded.
    """
    if not base64_url.startswith("data:") or "," not in base64_url:
        raise ValueError("Invalid base64 URL format")

    parts = base64_url.split(",")
    if len(parts) != 2:
        raise ValueError("Invalid base64 URL format")

    # NOTE: the ";base64" marker in the header is not checked, and b64decode is
    # called without validate=True, so characters outside the base64 alphabet
    # are discarded rather than rejected.
    try:
        return base64.b64decode(parts[1])
    except Exception as e:
        # chain the original error so the decode failure's cause is preserved
        raise ValueError(f"Failed to decode base64 data: {e}") from e
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from kiln_ai.adapters.extractors.base_extractor import BaseExtractor
|
|
2
|
+
from kiln_ai.adapters.extractors.litellm_extractor import LitellmExtractor
|
|
3
|
+
from kiln_ai.adapters.ml_model_list import ModelProviderName
|
|
4
|
+
from kiln_ai.adapters.provider_tools import (
|
|
5
|
+
core_provider,
|
|
6
|
+
lite_llm_core_config_for_provider,
|
|
7
|
+
)
|
|
8
|
+
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType
|
|
9
|
+
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
|
|
10
|
+
from kiln_ai.utils.filesystem_cache import FilesystemCache
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def extractor_adapter_from_type(
|
|
14
|
+
extractor_type: ExtractorType,
|
|
15
|
+
extractor_config: ExtractorConfig,
|
|
16
|
+
filesystem_cache: FilesystemCache | None = None,
|
|
17
|
+
) -> BaseExtractor:
|
|
18
|
+
match extractor_type:
|
|
19
|
+
case ExtractorType.LITELLM:
|
|
20
|
+
try:
|
|
21
|
+
provider_enum = ModelProviderName(extractor_config.model_provider_name)
|
|
22
|
+
except ValueError:
|
|
23
|
+
raise ValueError(
|
|
24
|
+
f"Unsupported model provider name: {extractor_config.model_provider_name}. "
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
core_provider_name = core_provider(
|
|
28
|
+
extractor_config.model_name, provider_enum
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
provider_config = lite_llm_core_config_for_provider(core_provider_name)
|
|
32
|
+
if provider_config is None:
|
|
33
|
+
raise ValueError(
|
|
34
|
+
f"No configuration found for core provider: {core_provider_name.value}. "
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
return LitellmExtractor(
|
|
38
|
+
extractor_config,
|
|
39
|
+
provider_config,
|
|
40
|
+
filesystem_cache,
|
|
41
|
+
)
|
|
42
|
+
case _:
|
|
43
|
+
# type checking will catch missing cases
|
|
44
|
+
raise_exhaustive_enum_error(extractor_type)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import AsyncGenerator, Dict, List, Set
|
|
6
|
+
|
|
7
|
+
from kiln_ai.adapters.extractors.base_extractor import BaseExtractor, ExtractionInput
|
|
8
|
+
from kiln_ai.adapters.extractors.extractor_registry import extractor_adapter_from_type
|
|
9
|
+
from kiln_ai.datamodel.basemodel import ID_TYPE, KilnAttachmentModel
|
|
10
|
+
from kiln_ai.datamodel.extraction import (
|
|
11
|
+
Document,
|
|
12
|
+
Extraction,
|
|
13
|
+
ExtractionSource,
|
|
14
|
+
ExtractorConfig,
|
|
15
|
+
)
|
|
16
|
+
from kiln_ai.utils.async_job_runner import AsyncJobRunner, Progress
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class ExtractorJob:
    """A unit of work for ExtractorRunner: extract ``doc`` with ``extractor_config``."""

    # the document whose original file will be extracted
    doc: Document
    # the configuration selecting and parameterizing the extractor
    extractor_config: ExtractorConfig
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ExtractorRunner:
    """Extract every document with every extractor config that has not run on it yet.

    Jobs are executed concurrently through ``AsyncJobRunner``; each successful
    job saves an ``Extraction`` as a child of its document.
    """

    def __init__(
        self,
        documents: List[Document],
        extractor_configs: List[ExtractorConfig],
    ):
        """
        Args:
            documents: Documents to extract.
            extractor_configs: Extractor configs to apply to each document.

        Raises:
            ValueError: if ``extractor_configs`` is empty.
        """
        if len(extractor_configs) == 0:
            raise ValueError("Extractor runner requires at least one extractor config")

        self.documents = documents
        self.extractor_configs = extractor_configs

    def collect_jobs(self) -> List[ExtractorJob]:
        """Return one job per (config, document) pair without an existing extraction.

        Jobs are ordered by extractor config first, then by document order.
        """
        # extractor config id -> ids of documents already extracted with it,
        # so the same document is not re-run for the same extractor config
        already_extracted: Dict[ID_TYPE, Set[ID_TYPE]] = defaultdict(set)
        for document in self.documents:
            for extraction in document.extractions():
                already_extracted[extraction.extractor_config_id].add(document.id)

        return [
            ExtractorJob(doc=document, extractor_config=extractor_config)
            for extractor_config in self.extractor_configs
            for document in self.documents
            if document.id not in already_extracted[extractor_config.id]
        ]

    async def run(self, concurrency: int = 25) -> AsyncGenerator[Progress, None]:
        """Run all pending jobs, yielding progress updates from AsyncJobRunner."""
        jobs = self.collect_jobs()

        runner = AsyncJobRunner(
            concurrency=concurrency,
            jobs=jobs,
            run_job_fn=self.run_job,
        )
        async for progress in runner.run():
            yield progress

    async def run_job(self, job: ExtractorJob) -> bool:
        """Extract one document with one config and save the result.

        Returns:
            True on success. False on any error; the error is logged with its
            traceback rather than raised, so one failing document does not
            abort the other jobs.
        """
        try:
            # NOTE(review): no filesystem_cache is passed here, so per-page PDF
            # results are not cached by this runner — confirm that is intended.
            extractor = extractor_adapter_from_type(
                job.extractor_config.extractor_type,
                job.extractor_config,
            )
            if not isinstance(extractor, BaseExtractor):
                raise ValueError("Not able to create extractor from extractor config")

            if job.doc.path is None:
                raise ValueError("Document path is not set")

            output = await extractor.extract(
                extraction_input=ExtractionInput(
                    # the attachment path is resolved relative to the
                    # document's directory on disk
                    path=Path(
                        job.doc.original_file.attachment.resolve_path(
                            job.doc.path.parent
                        )
                    ),
                    mime_type=job.doc.original_file.mime_type,
                )
            )

            extraction = Extraction(
                parent=job.doc,
                extractor_config_id=job.extractor_config.id,
                output=KilnAttachmentModel.from_data(
                    data=output.content,
                    mime_type=output.content_format,
                ),
                source=ExtractionSource.PASSTHROUGH
                if output.is_passthrough
                else ExtractionSource.PROCESSED,
            )
            extraction.save_to_file()

            return True
        except Exception as e:
            # keep the traceback: the message alone is usually not enough to
            # diagnose provider or file errors
            logger.error(
                "Error running extraction job for dataset item %s: %s",
                job.doc.id,
                e,
                exc_info=True,
            )
            return False
|
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import hashlib
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, List
|
|
6
|
+
|
|
7
|
+
import litellm
|
|
8
|
+
from litellm.types.utils import Choices, ModelResponse
|
|
9
|
+
|
|
10
|
+
from kiln_ai.adapters.extractors.base_extractor import (
|
|
11
|
+
BaseExtractor,
|
|
12
|
+
ExtractionInput,
|
|
13
|
+
ExtractionOutput,
|
|
14
|
+
)
|
|
15
|
+
from kiln_ai.adapters.extractors.encoding import to_base64_url
|
|
16
|
+
from kiln_ai.adapters.ml_model_list import built_in_models_from_provider
|
|
17
|
+
from kiln_ai.adapters.provider_tools import LiteLlmCoreConfig
|
|
18
|
+
from kiln_ai.datamodel.datamodel_enums import ModelProviderName
|
|
19
|
+
from kiln_ai.datamodel.extraction import ExtractorConfig, ExtractorType, Kind
|
|
20
|
+
from kiln_ai.utils.filesystem_cache import FilesystemCache
|
|
21
|
+
from kiln_ai.utils.litellm import get_litellm_provider_info
|
|
22
|
+
from kiln_ai.utils.pdf_utils import split_pdf_into_pages
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def max_pdf_page_concurrency_for_model(model_name: str) -> int:
    """Return how many PDF pages may be extracted concurrently for ``model_name``.

    Sizing assumes each batch takes roughly 5s (likely more in practice).
    gemini-2.5-pro has the lowest Tier 1 rate limit (150 RPM), so it gets 2;
    other models allow at least 500 RPM on the lowest tier and get 5.
    """
    is_low_rate_limit_model = model_name == "gemini/gemini-2.5-pro"
    return 2 if is_low_rate_limit_model else 5
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
# MIME types LitellmExtractor accepts, grouped by Kind. The kind selects which
# prompt from the extractor config is sent with the file; a MIME type not listed
# here is rejected by LitellmExtractor._extract as unsupported.
MIME_TYPES_SUPPORTED = {
    Kind.DOCUMENT: [
        "application/pdf",
        "text/plain",
        "text/markdown",  # not officially listed, but works
        "text/html",
        "text/md",
        "text/csv",
    ],
    Kind.IMAGE: [
        "image/png",
        "image/jpeg",
        "image/jpg",
    ],
    Kind.VIDEO: [
        "video/mp4",
        "video/mov",  # the correct type is video/quicktime, but Google lists it as video/mov
        "video/quicktime",
    ],
    Kind.AUDIO: [
        "audio/wav",
        "audio/mpeg",  # this is the official MP3 mimetype, audio/mp3 is often used but not correct
        "audio/ogg",
    ],
}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def encode_file_litellm_format(path: Path, mime_type: str) -> dict[str, Any]:
    """Encode the file at ``path`` as a LiteLLM message content part.

    Documents, video and audio use the generic ``{"type": "file", ...}`` part;
    images use ``{"type": "image_url", ...}``. In both cases the file bytes are
    embedded as a base64 ``data:`` URL.

    Raises:
        ValueError: if ``mime_type`` is not one of the encodable types.
    """
    # There are different formats that LiteLLM supports, the docs are scattered
    # and incomplete:
    # - https://docs.litellm.ai/docs/completion/document_understanding#base64
    # - https://docs.litellm.ai/docs/completion/vision#explicitly-specify-image-type

    # this is the most generic format that seems to work for all / most mime types.
    # The document types listed here must cover every Kind.DOCUMENT entry in
    # MIME_TYPES_SUPPORTED (including "text/md"); otherwise a file the extractor
    # accepts as a document would be rejected below.
    if mime_type in (
        "application/pdf",
        "text/csv",
        "text/html",
        "text/markdown",
        "text/md",
        "text/plain",
    ) or mime_type.startswith(("video/", "audio/")):
        file_bytes = path.read_bytes()
        return {
            "type": "file",
            "file": {
                "file_data": to_base64_url(mime_type, file_bytes),
            },
        }

    # image has its own format (but also appears to work with the file format)
    if mime_type.startswith("image/"):
        image_bytes = path.read_bytes()
        return {
            "type": "image_url",
            "image_url": {
                "url": to_base64_url(mime_type, image_bytes),
            },
        }

    raise ValueError(f"Unsupported MIME type: {mime_type} for {path}")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class LitellmExtractor(BaseExtractor):
    """Extract text from files by prompting a multimodal model through LiteLLM.

    The input's MIME type is mapped to a ``Kind`` via ``MIME_TYPES_SUPPORTED``,
    and each kind uses its own prompt from the extractor config (document,
    video, audio, image). PDFs are split into single pages, extracted in small
    concurrent batches, and joined with blank lines. When a ``FilesystemCache``
    is supplied, successful page results are cached and reused on later runs.
    """

    def __init__(
        self,
        extractor_config: ExtractorConfig,
        litellm_core_config: LiteLlmCoreConfig,
        filesystem_cache: FilesystemCache | None = None,
    ):
        """
        Args:
            extractor_config: Must be a LITELLM extractor config with non-empty
                document, video, audio and image prompts.
            litellm_core_config: Connection settings (base URL, default headers,
                additional body options) added to every completion request.
            filesystem_cache: Optional cache for per-page PDF results.

        Raises:
            ValueError: if the extractor type is not LITELLM or any prompt is
                missing or empty.
        """
        if extractor_config.extractor_type != ExtractorType.LITELLM:
            raise ValueError(
                f"LitellmExtractor must be initialized with a litellm extractor_type config. Got {extractor_config.extractor_type}"
            )

        prompt_document = extractor_config.prompt_document()
        if prompt_document is None or prompt_document == "":
            raise ValueError(
                "properties.prompt_document is required for LitellmExtractor"
            )
        prompt_video = extractor_config.prompt_video()
        if prompt_video is None or prompt_video == "":
            raise ValueError("properties.prompt_video is required for LitellmExtractor")
        prompt_audio = extractor_config.prompt_audio()
        if prompt_audio is None or prompt_audio == "":
            raise ValueError("properties.prompt_audio is required for LitellmExtractor")
        prompt_image = extractor_config.prompt_image()
        if prompt_image is None or prompt_image == "":
            raise ValueError("properties.prompt_image is required for LitellmExtractor")

        self.filesystem_cache = filesystem_cache

        super().__init__(extractor_config)
        self.prompt_for_kind = {
            Kind.DOCUMENT: prompt_document,
            Kind.VIDEO: prompt_video,
            Kind.AUDIO: prompt_audio,
            Kind.IMAGE: prompt_image,
        }

        self.litellm_core_config = litellm_core_config

    def pdf_page_cache_key(self, pdf_path: Path, page_number: int) -> str:
        """
        Generate a cache key for a page of a PDF. The PDF path must be the full path to the PDF file,
        not the path to the page - since page path is temporary and changes on each run.

        The key is the extractor config id plus an MD5 digest of the resolved
        PDF path and page number (MD5 is used for keying, not security).

        Raises:
            ValueError: if the extractor config has no id.
        """
        if self.extractor_config.id is None:
            raise ValueError("Extractor config ID is required for PDF page cache key")

        raw_key = f"{pdf_path.resolve()}::{page_number}"
        digest = hashlib.md5(raw_key.encode("utf-8")).hexdigest()
        return f"{self.extractor_config.id}_{digest}"

    async def get_page_content_from_cache(
        self, pdf_path: Path, page_number: int
    ) -> str | None:
        """Return the cached text for a PDF page, or None on a miss.

        Cached bytes that are not valid UTF-8 are logged and treated as a miss.
        Returns None immediately when no cache is configured.
        """
        if self.filesystem_cache is None:
            return None

        page_bytes = await self.filesystem_cache.get(
            self.pdf_page_cache_key(pdf_path, page_number)
        )

        if page_bytes is not None:
            logger.debug(f"Cache hit for page {page_number} of {pdf_path}")
            try:
                return page_bytes.decode("utf-8")
            except UnicodeDecodeError:
                logger.warning(
                    "Cached bytes for page %s of %s are not valid UTF-8; treating as miss.",
                    page_number,
                    pdf_path,
                    exc_info=True,
                )

        logger.debug(f"Cache miss for page {page_number} of {pdf_path}")
        return None

    async def _extract_single_pdf_page(
        self, pdf_path: Path, page_path: Path, prompt: str, page_number: int
    ) -> str:
        """Extract one single-page PDF and cache the result when possible.

        ``pdf_path`` is the original PDF (used for the cache key); ``page_path``
        is the temporary single-page file that is actually sent to the model.

        Raises:
            RuntimeError: if the request fails or the response is not a
                ModelResponse with Choices.
            ValueError: if the model returns no text.
        """
        try:
            page_input = ExtractionInput(
                path=str(page_path), mime_type="application/pdf"
            )
            completion_kwargs = self._build_completion_kwargs(prompt, page_input)
            response = await litellm.acompletion(**completion_kwargs)
        except Exception as e:
            raise RuntimeError(
                f"Error extracting page {page_number} in file {page_path}: {e}"
            ) from e

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices for page {page_number}, got {type(response)}."
            )

        if response.choices[0].message.content is None:
            raise ValueError(
                f"No text returned from LiteLLM when extracting page {page_number}"
            )

        content = response.choices[0].message.content
        if not content:
            raise ValueError(
                f"No text returned from extraction model when extracting page {page_number} for {page_path}"
            )

        if self.filesystem_cache is not None:
            # we don't want to fail the whole extraction just because cache write fails
            # as that would block the whole flow
            try:
                logger.debug(f"Caching page {page_number} of {page_path} in cache")
                await self.filesystem_cache.set(
                    self.pdf_page_cache_key(pdf_path, page_number),
                    content.encode("utf-8"),
                )
            except Exception:
                logger.warning(
                    "Failed to cache page %s of %s; continuing without cache.",
                    page_number,
                    page_path,
                    exc_info=True,
                )

        return content

    async def _run_pdf_page_batch(
        self,
        jobs: list,
        page_indices: List[int],
        page_outcomes: List[str | Exception | None],
    ) -> None:
        """Await one batch of page coroutines and store each outcome at its page index."""
        extraction_results = await asyncio.gather(*jobs, return_exceptions=True)

        for page_index, extraction_result in zip(page_indices, extraction_results):
            # we let it continue even if there is an error - the success results will be cached
            # and can be reused on the next run
            if isinstance(extraction_result, Exception):
                page_outcomes[page_index] = extraction_result
            elif isinstance(extraction_result, str):
                page_outcomes[page_index] = extraction_result
            else:
                raise ValueError(
                    f"Unexpected type {type(extraction_result)} for page {page_index}"
                )

    async def _extract_pdf_page_by_page(self, pdf_path: Path, prompt: str) -> str:
        """Extract a PDF one page at a time and join the page texts with blank lines.

        Cached pages are reused; uncached pages are extracted in batches of at
        most ``max_pdf_page_concurrency_for_model(...)`` concurrent requests.

        Raises:
            RuntimeError: listing every page that failed, after all pages have
                been attempted (so successful pages are still cached).
        """
        async with split_pdf_into_pages(pdf_path) as page_paths:
            page_outcomes: List[str | Exception | None] = [None] * len(page_paths)

            pending_jobs: list = []
            pending_page_indices: List[int] = []  # page index of each pending job

            # we extract from each page individually and then combine the results
            # this ensures the model stays focused on the current page and does not
            # start summarizing the later pages
            for i, page_path in enumerate(page_paths):
                page_content = await self.get_page_content_from_cache(pdf_path, i)
                if page_content is not None:
                    page_outcomes[i] = page_content
                    continue

                pending_jobs.append(
                    self._extract_single_pdf_page(pdf_path, page_path, prompt, i)
                )
                pending_page_indices.append(i)

                if len(pending_jobs) >= max_pdf_page_concurrency_for_model(
                    self.litellm_model_slug()
                ):
                    await self._run_pdf_page_batch(
                        pending_jobs, pending_page_indices, page_outcomes
                    )
                    pending_jobs = []
                    pending_page_indices = []

            # Flush the last partial batch after the loop. This cannot live only
            # inside the loop: when the final page is a cache hit the loop
            # `continue`s past any in-loop check, and pages queued before it
            # would never be extracted (and their coroutines never awaited).
            if pending_jobs:
                await self._run_pdf_page_batch(
                    pending_jobs, pending_page_indices, page_outcomes
                )

            exceptions: list[tuple[int, Exception]] = [
                (page_index, result)
                for page_index, result in enumerate(page_outcomes)
                if isinstance(result, Exception)
            ]
            if len(exceptions) > 0:
                msg = f"Error extracting PDF {pdf_path}: "
                for page_index, exception in exceptions:
                    msg += f"Page {page_index}: {exception}\n"
                raise RuntimeError(msg)

            return "\n\n".join(
                [outcome for outcome in page_outcomes if isinstance(outcome, str)]
            )

    def _get_kind_from_mime_type(self, mime_type: str) -> Kind | None:
        """Return the Kind whose MIME list contains ``mime_type`` (exact match), else None."""
        for kind, mime_types in MIME_TYPES_SUPPORTED.items():
            if mime_type in mime_types:
                return kind
        return None

    def _build_completion_kwargs(
        self, prompt: str, extraction_input: ExtractionInput
    ) -> dict[str, Any]:
        """Build ``litellm.acompletion`` kwargs: one user message with the prompt and the file.

        ``base_url`` and ``default_headers`` from the core config are added when
        set; ``additional_body_options`` are merged last and can override any
        earlier key.
        """
        completion_kwargs = {
            "model": self.litellm_model_slug(),
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        encode_file_litellm_format(
                            Path(extraction_input.path), extraction_input.mime_type
                        ),
                    ],
                }
            ],
        }

        if self.litellm_core_config.base_url:
            completion_kwargs["base_url"] = self.litellm_core_config.base_url

        if self.litellm_core_config.default_headers:
            completion_kwargs["default_headers"] = (
                self.litellm_core_config.default_headers
            )

        if self.litellm_core_config.additional_body_options:
            completion_kwargs.update(self.litellm_core_config.additional_body_options)

        return completion_kwargs

    async def _extract(self, extraction_input: ExtractionInput) -> ExtractionOutput:
        """Extract content with the prompt for the input's kind.

        PDFs are processed page by page; all other supported types are sent in
        a single request.

        Raises:
            ValueError: for unsupported MIME types, a missing prompt, or a
                response with no content.
            RuntimeError: if the response is not a ModelResponse with Choices.
        """
        kind = self._get_kind_from_mime_type(extraction_input.mime_type)
        if kind is None:
            raise ValueError(
                f"Unsupported MIME type: {extraction_input.mime_type} for {extraction_input.path}"
            )

        prompt = self.prompt_for_kind.get(kind)
        if prompt is None:
            raise ValueError(f"No prompt found for kind: {kind}")

        # special handling for PDFs - process each page individually
        if extraction_input.mime_type == "application/pdf":
            content = await self._extract_pdf_page_by_page(
                Path(extraction_input.path), prompt
            )
            return ExtractionOutput(
                is_passthrough=False,
                content=content,
                content_format=self.extractor_config.output_format,
            )

        completion_kwargs = self._build_completion_kwargs(prompt, extraction_input)

        response = await litellm.acompletion(**completion_kwargs)

        if (
            not isinstance(response, ModelResponse)
            or not response.choices
            or len(response.choices) == 0
            or not isinstance(response.choices[0], Choices)
        ):
            raise RuntimeError(
                f"Expected ModelResponse with Choices, got {type(response)}."
            )

        # NOTE(review): unlike the per-page PDF path, an empty-string response is
        # accepted here; only None is rejected — confirm this is intended.
        if response.choices[0].message.content is None:
            raise ValueError("No text returned from LiteLLM when extracting document")

        return ExtractionOutput(
            is_passthrough=False,
            content=response.choices[0].message.content,
            content_format=self.extractor_config.output_format,
        )

    def litellm_model_slug(self) -> str:
        """Resolve the configured provider and model name to a LiteLLM model id.

        Raises:
            ValueError: if the model is not found among the built-in models for
                the provider (or the provider name is not a ModelProviderName).
        """
        kiln_model_provider = built_in_models_from_provider(
            ModelProviderName(self.extractor_config.model_provider_name),
            self.extractor_config.model_name,
        )

        if kiln_model_provider is None:
            raise ValueError(
                f"Model provider {self.extractor_config.model_provider_name} not found in the list of built-in models"
            )

        # need to translate into LiteLLM model slug
        litellm_provider_name = get_litellm_provider_info(
            kiln_model_provider,
        )

        return litellm_provider_name.litellm_model_id
|