retab 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- retab-0.0.35.dist-info/METADATA +417 -0
- retab-0.0.35.dist-info/RECORD +111 -0
- retab-0.0.35.dist-info/WHEEL +5 -0
- retab-0.0.35.dist-info/top_level.txt +1 -0
- uiform/__init__.py +4 -0
- uiform/_resource.py +28 -0
- uiform/_utils/__init__.py +0 -0
- uiform/_utils/ai_models.py +100 -0
- uiform/_utils/benchmarking copy.py +588 -0
- uiform/_utils/benchmarking.py +485 -0
- uiform/_utils/chat.py +332 -0
- uiform/_utils/display.py +443 -0
- uiform/_utils/json_schema.py +2161 -0
- uiform/_utils/mime.py +168 -0
- uiform/_utils/responses.py +163 -0
- uiform/_utils/stream_context_managers.py +52 -0
- uiform/_utils/usage/__init__.py +0 -0
- uiform/_utils/usage/usage.py +300 -0
- uiform/client.py +701 -0
- uiform/py.typed +0 -0
- uiform/resources/__init__.py +0 -0
- uiform/resources/consensus/__init__.py +3 -0
- uiform/resources/consensus/client.py +114 -0
- uiform/resources/consensus/completions.py +252 -0
- uiform/resources/consensus/completions_stream.py +278 -0
- uiform/resources/consensus/responses.py +325 -0
- uiform/resources/consensus/responses_stream.py +373 -0
- uiform/resources/deployments/__init__.py +9 -0
- uiform/resources/deployments/client.py +78 -0
- uiform/resources/deployments/endpoints.py +322 -0
- uiform/resources/deployments/links.py +452 -0
- uiform/resources/deployments/logs.py +211 -0
- uiform/resources/deployments/mailboxes.py +496 -0
- uiform/resources/deployments/outlook.py +531 -0
- uiform/resources/deployments/tests.py +158 -0
- uiform/resources/documents/__init__.py +3 -0
- uiform/resources/documents/client.py +255 -0
- uiform/resources/documents/extractions.py +441 -0
- uiform/resources/evals.py +812 -0
- uiform/resources/files.py +24 -0
- uiform/resources/finetuning.py +62 -0
- uiform/resources/jsonlUtils.py +1046 -0
- uiform/resources/models.py +45 -0
- uiform/resources/openai_example.py +22 -0
- uiform/resources/processors/__init__.py +3 -0
- uiform/resources/processors/automations/__init__.py +9 -0
- uiform/resources/processors/automations/client.py +78 -0
- uiform/resources/processors/automations/endpoints.py +317 -0
- uiform/resources/processors/automations/links.py +356 -0
- uiform/resources/processors/automations/logs.py +211 -0
- uiform/resources/processors/automations/mailboxes.py +435 -0
- uiform/resources/processors/automations/outlook.py +444 -0
- uiform/resources/processors/automations/tests.py +158 -0
- uiform/resources/processors/client.py +474 -0
- uiform/resources/prompt_optimization.py +76 -0
- uiform/resources/schemas.py +369 -0
- uiform/resources/secrets/__init__.py +9 -0
- uiform/resources/secrets/client.py +20 -0
- uiform/resources/secrets/external_api_keys.py +109 -0
- uiform/resources/secrets/webhook.py +62 -0
- uiform/resources/usage.py +271 -0
- uiform/types/__init__.py +0 -0
- uiform/types/ai_models.py +645 -0
- uiform/types/automations/__init__.py +0 -0
- uiform/types/automations/cron.py +58 -0
- uiform/types/automations/endpoints.py +21 -0
- uiform/types/automations/links.py +28 -0
- uiform/types/automations/mailboxes.py +60 -0
- uiform/types/automations/outlook.py +68 -0
- uiform/types/automations/webhooks.py +21 -0
- uiform/types/chat.py +8 -0
- uiform/types/completions.py +93 -0
- uiform/types/consensus.py +10 -0
- uiform/types/db/__init__.py +0 -0
- uiform/types/db/annotations.py +24 -0
- uiform/types/db/files.py +36 -0
- uiform/types/deployments/__init__.py +0 -0
- uiform/types/deployments/cron.py +59 -0
- uiform/types/deployments/endpoints.py +28 -0
- uiform/types/deployments/links.py +36 -0
- uiform/types/deployments/mailboxes.py +67 -0
- uiform/types/deployments/outlook.py +76 -0
- uiform/types/deployments/webhooks.py +21 -0
- uiform/types/documents/__init__.py +0 -0
- uiform/types/documents/correct_orientation.py +13 -0
- uiform/types/documents/create_messages.py +226 -0
- uiform/types/documents/extractions.py +297 -0
- uiform/types/evals.py +207 -0
- uiform/types/events.py +76 -0
- uiform/types/extractions.py +85 -0
- uiform/types/jobs/__init__.py +0 -0
- uiform/types/jobs/base.py +150 -0
- uiform/types/jobs/batch_annotation.py +22 -0
- uiform/types/jobs/evaluation.py +133 -0
- uiform/types/jobs/finetune.py +6 -0
- uiform/types/jobs/prompt_optimization.py +41 -0
- uiform/types/jobs/webcrawl.py +6 -0
- uiform/types/logs.py +231 -0
- uiform/types/mime.py +257 -0
- uiform/types/modalities.py +68 -0
- uiform/types/pagination.py +6 -0
- uiform/types/schemas/__init__.py +0 -0
- uiform/types/schemas/enhance.py +53 -0
- uiform/types/schemas/evaluate.py +55 -0
- uiform/types/schemas/generate.py +32 -0
- uiform/types/schemas/layout.py +58 -0
- uiform/types/schemas/object.py +631 -0
- uiform/types/schemas/templates.py +107 -0
- uiform/types/secrets/__init__.py +0 -0
- uiform/types/secrets/external_api_keys.py +22 -0
- uiform/types/standards.py +39 -0
@@ -0,0 +1,1046 @@ uiform/resources/jsonlUtils.py

```python
import asyncio
import hashlib
import json
import re
import shutil
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from io import IOBase
from pathlib import Path
from typing import Any, Literal, Optional, TypedDict

from anthropic import Anthropic
from openai import OpenAI
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from pydantic import BaseModel
from tqdm import tqdm

from .._resource import AsyncAPIResource, SyncAPIResource
from .._utils.ai_models import assert_valid_model_extraction, find_provider_from_model
from .._utils.chat import convert_to_anthropic_format, convert_to_openai_format, separate_messages
from .._utils.display import Metrics, display_metrics, process_dataset_and_compute_metrics
from .._utils.json_schema import load_json_schema
from ..types.chat import ChatCompletionUiformMessage
from ..types.modalities import Modality
from ..types.schemas.object import Schema

# The eval/benchmark methods below use these helpers without importing them in the
# published file; .._utils.benchmarking is the assumed source, inferred from the
# package's file listing rather than stated in this diff.
from .._utils.benchmarking import (
    BenchmarkMetrics,
    ComparisonMetrics,
    ExtractionAnalysis,
    display_benchmark_metrics,
    normalized_comparison_metrics,
    plot_comparison_metrics,
)


class FinetuningJSON(BaseModel):
    messages: list[ChatCompletionUiformMessage]


FinetuningJSONL = list[FinetuningJSON]


class BatchJSONLResponseFormat(TypedDict):
    type: str
    json_schema: dict[str, Any]


class BatchJSONLBody(TypedDict):
    model: str
    messages: list[ChatCompletionMessageParam]
    temperature: float
    response_format: BatchJSONLResponseFormat


class BatchJSONL(TypedDict):
    custom_id: str
    method: str
    url: str
    body: BatchJSONLBody


class BatchJSONLResponseUsageTokenDetails(BaseModel):
    cached_tokens: int
    audio_tokens: int


class BatchJSONLResponseUsageCompletionDetails(BaseModel):
    reasoning_tokens: int
    audio_tokens: int
    accepted_prediction_tokens: int
    rejected_prediction_tokens: int


class BatchJSONLResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_tokens_details: BatchJSONLResponseUsageTokenDetails
    completion_tokens_details: BatchJSONLResponseUsageCompletionDetails


class BatchJSONLResponseChoice(BaseModel):
    index: int
    message: ChatCompletionMessageParam
    logprobs: None | Any
    finish_reason: str


class BatchJSONLResponseBody(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: list[BatchJSONLResponseChoice]
    usage: BatchJSONLResponseUsage
    service_tier: str
    system_fingerprint: str


class BatchJSONLResponseInner(BaseModel):
    status_code: int
    request_id: str
    body: BatchJSONLResponseBody


class BatchJSONLResponse(BaseModel):
    id: str
    custom_id: str
    response: BatchJSONLResponseInner
    error: None | str
```
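Parsed, one line of a batch-requests file matches the `BatchJSONL` shape defined above. A minimal sketch, given those definitions (the model, messages, and schema here are placeholders, not values from this package):

```python
request: BatchJSONL = {
    "custom_id": "request-0",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Extract the invoice fields."}],  # placeholder messages
        "temperature": 0.0,
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "invoice", "schema": {"type": "object"}, "strict": True},  # placeholder schema
        },
    },
}
```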
```python
class BaseDatasetsMixin:
    def _dump_training_set(self, training_set: list[dict[str, Any]], dataset_path: Path | str) -> None:
        with open(dataset_path, 'w', encoding='utf-8') as file:
            for entry in training_set:
                file.write(json.dumps(entry) + '\n')


class Datasets(SyncAPIResource, BaseDatasetsMixin):
    """Datasets API wrapper"""

    # TODO : Maybe at some point we could add some visualization methods... but the multimodality makes it hard... # client.datasets.plot.tsne()... # client.datasets.plot.umap()...
    def pprint(self, dataset_path: Path | str, input_token_price: Optional[float] = None, output_token_price: Optional[float] = None) -> Metrics:
        """Print a summary of the contents and statistics of a JSONL file.

        This method analyzes the JSONL file and displays various metrics and statistics
        about the dataset contents.

        Inspired by: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233

        Args:
            dataset_path: Path to the JSONL file to analyze
            input_token_price: Price per input token
            output_token_price: Price per output (completion) token
        """
        computed_metrics = process_dataset_and_compute_metrics(dataset_path)
        display_metrics(computed_metrics, input_token_price=input_token_price, output_token_price=output_token_price)
        return computed_metrics
```
```python
    def save(
        self,
        json_schema: dict[str, Any] | Path | str,
        document_annotation_pairs_paths: list[dict[str, Path | str]],
        dataset_path: Path | str,
        image_resolution_dpi: int | None = None,
        browser_canvas: Literal['A3', 'A4', 'A5'] | None = None,
        modality: Modality = "native",
    ) -> None:
        """Save document-annotation pairs to a JSONL training set.

        Args:
            json_schema: The JSON schema for validation, can be a dict, Path, or string
            document_annotation_pairs_paths: List of dictionaries of the form
                {"document_fpath": Path | str, "annotation_fpath": Path | str}
            dataset_path: Output path for the JSONL training file
            image_resolution_dpi: Optional image resolution forwarded to document processing
            browser_canvas: Optional canvas size ('A3'/'A4'/'A5') forwarded to document processing
            modality: The modality to use for document processing ("native" by default)
        """
        json_schema = load_json_schema(json_schema)
        schema_obj = Schema(json_schema=json_schema)

        with open(dataset_path, 'w', encoding='utf-8') as file:
            for pair_paths in tqdm(document_annotation_pairs_paths, desc="Processing pairs", position=0):
                document_message = self._client.documents.create_messages(
                    document=pair_paths['document_fpath'], modality=modality, image_resolution_dpi=image_resolution_dpi, browser_canvas=browser_canvas
                )

                with open(pair_paths['annotation_fpath'], 'r') as f:
                    annotation = json.loads(f.read())
                assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}

                entry = {"messages": schema_obj.messages + document_message.messages + [assistant_message]}
                file.write(json.dumps(entry) + '\n')
```
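A usage sketch for `save`. The `UiForm` constructor name, the `.datasets` attribute, and all paths are hypothetical; the actual entry point lives in `uiform/client.py`, which this diff does not show expanded:

```python
from uiform import UiForm  # hypothetical client name; see uiform/client.py

uiclient = UiForm()  # assumes provider API keys are configured on the client
uiclient.datasets.save(
    json_schema="schemas/invoice.json",  # placeholder schema path
    document_annotation_pairs_paths=[
        {"document_fpath": "docs/invoice_001.pdf", "annotation_fpath": "annotations/invoice_001.json"},
    ],
    dataset_path="training_set.jsonl",
)
```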
```python
    def change_schema(
        self,
        input_dataset_path: Path | str,
        json_schema: dict[str, Any] | Path | str,
        output_dataset_path: None | Path | str = None,
        inplace: bool = False,
    ) -> None:
        """Change the system prompt in a dataset to match a new schema.

        Args:
            input_dataset_path: Path to the input JSONL dataset file
            json_schema: The new JSON schema for validation, can be a dict, Path, or string
            output_dataset_path: Path to the output JSONL dataset file
            inplace: If True, overwrite the input dataset with the changes
        """
        json_schema = load_json_schema(json_schema)
        schema_obj = Schema(json_schema=json_schema)

        # Determine the path to write to
        if inplace:
            assert output_dataset_path is None, "Cannot specify both inplace=True and an output_dataset_path"
            target_path = input_dataset_path
        else:
            assert output_dataset_path, "Cannot save the file if inplace=False and output_dataset_path is None"
            target_path = output_dataset_path

        assert isinstance(target_path, (Path, str))

        # Use a temporary file to write the updated content
        with tempfile.NamedTemporaryFile('w', delete=False, encoding='utf-8') as temp_file:
            with open(input_dataset_path, 'r', encoding='utf-8') as infile:
                for line in infile:
                    entry = json.loads(line)
                    messages = entry.get("messages", [])

                    # Remove the existing system prompt if present
                    if messages and messages[0].get("role") in ("developer", "system"):
                        messages = messages[1:]

                    # Add the new system prompt from schema_obj.messages
                    updated_messages = schema_obj.messages + messages
                    updated_entry = {"messages": updated_messages}

                    # Write the updated entry to the temporary file
                    temp_file.write(json.dumps(updated_entry) + '\n')

        # Replace the original file with the temporary file
        shutil.move(temp_file.name, target_path)
```
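Continuing the hypothetical client from the sketch above, re-targeting an existing dataset at a revised schema looks like:

```python
# Rewrite each entry's system prompt against the new schema, overwriting the file in place.
# Paths are placeholders.
uiclient.datasets.change_schema(
    input_dataset_path="training_set.jsonl",
    json_schema="schemas/invoice_v2.json",
    inplace=True,
)
```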
```python
    def stich_and_save(
        self,
        json_schema: dict[str, Any] | Path | str,
        pairs_paths: list[dict[str, Path | str | list[Path | str] | list[str] | list[Path]]],
        dataset_path: Path | str,
        modality: Modality = "native",
    ) -> None:
        """Stitch multiple documents and their annotations into a JSONL training set.

        This method processes and combines multiple documents into messages, creating document-annotation
        pairs that are saved to a JSONL file. Each document is processed according to the specified
        modality and combined with its corresponding annotation.

        Args:
            json_schema: The JSON schema for validation, can be a dict, Path, or string
            pairs_paths: List of dictionaries containing document and annotation file paths
            dataset_path: Output path for the JSONL training file
            modality: The modality to use for document processing ("native" by default)
        """
        json_schema = load_json_schema(json_schema)
        schema_obj = Schema(json_schema=json_schema)
        training_set = []

        for pair_paths in tqdm(pairs_paths):
            document_messages: list[ChatCompletionUiformMessage] = []

            if isinstance(pair_paths['document_fpath'], (str, Path)):
                document_message = self._client.documents.create_messages(document=pair_paths['document_fpath'], modality=modality)
                document_messages.extend(document_message.messages)
            else:
                assert isinstance(pair_paths['document_fpath'], list)
                for document_fpath in pair_paths['document_fpath']:
                    document_message = self._client.documents.create_messages(document=document_fpath, modality=modality)
                    document_messages.extend(document_message.messages)

            # Use a context manager to properly close the annotation file
            assert isinstance(pair_paths['annotation_fpath'], (Path, str))
            with open(pair_paths['annotation_fpath'], 'r') as f:
                annotation = json.loads(f.read())
            assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}

            # Add the complete message set as an entry
            training_set.append({"messages": schema_obj.messages + document_messages + [assistant_message]})

        self._dump_training_set(training_set, dataset_path)

    #########################################
    ##### ENDPOINTS THAT MAKE LLM CALLS #####
    #########################################

    def _initialize_model_client(self, model: str) -> tuple[OpenAI | Anthropic, str]:
        """Initialize the appropriate client based on the model provider.

        Args:
            model: The model identifier string

        Returns:
            A tuple of (client instance, provider type string)
        """
        provider = find_provider_from_model(model)

        if provider == "OpenAI":
            return OpenAI(api_key=self._client.headers["OpenAI-Api-Key"]), provider
        elif provider == "xAI":
            return OpenAI(api_key=self._client.headers["XAI-Api-Key"], base_url="https://api.x.ai/v1"), provider
        elif provider == "Gemini":
            return OpenAI(
                api_key=self._client.headers["Gemini-Api-Key"],
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            ), provider
        else:
            assert provider == "Anthropic", f"Unsupported model: {model}"
            return Anthropic(api_key=self._client.headers["Anthropic-Api-Key"]), provider

    def _get_model_completion(
        self,
        client: OpenAI | Anthropic,
        provider: str,
        model: str,
        temperature: float,
        messages: list[ChatCompletionUiformMessage],
        schema_obj: Schema,
    ) -> str:
        """Get a completion from the appropriate model provider.

        Args:
            client: The initialized client instance
            provider: The provider type string
            model: The model identifier
            temperature: Temperature setting for generation
            messages: The messages to send to the model
            schema_obj: The schema object containing format information

        Returns:
            The completion string in JSON format
        """
        if provider in ["OpenAI", "xAI"]:
            assert isinstance(client, OpenAI)
            completion = client.chat.completions.create(
                model=model,
                temperature=temperature,
                messages=convert_to_openai_format(messages),
                response_format={"type": "json_schema", "json_schema": {"name": schema_obj.id, "schema": schema_obj.inference_json_schema, "strict": True}},
            )
            assert completion.choices[0].message.content is not None
            return completion.choices[0].message.content

        elif provider == "Gemini":
            assert isinstance(client, OpenAI)
            gemini_completion = client.chat.completions.create(
                model=model,
                temperature=temperature,
                messages=convert_to_openai_format(messages),
                response_format={"type": "json_schema", "json_schema": {"name": schema_obj.id, "schema": schema_obj.inference_gemini_json_schema, "strict": True}},
            )
            assert gemini_completion.choices[0].message.content is not None
            return gemini_completion.choices[0].message.content

        else:  # Anthropic
            assert isinstance(client, Anthropic)
            system_messages, other_messages = convert_to_anthropic_format(messages)
            # Note: this branch hardcodes "claude-3-5-sonnet-20241022" and ignores the `model` argument.
            anthropic_completion = client.messages.create(
                model="claude-3-5-sonnet-20241022", max_tokens=8192, temperature=temperature, system=system_messages, messages=other_messages
            )
            from anthropic.types.text_block import TextBlock

            assert isinstance(anthropic_completion.content[0], TextBlock)
            return anthropic_completion.content[0].text
```
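Every provider branch above funnels through a JSON-schema-constrained completion. Stripped of the package's helpers, the OpenAI-side call reduces to roughly the following standalone sketch (the schema is a placeholder, not one produced by `Schema`):

```python
from openai import OpenAI

oai = OpenAI()  # reads OPENAI_API_KEY from the environment
completion = oai.chat.completions.create(
    model="gpt-4o-2024-08-06",
    temperature=0.0,
    messages=[{"role": "user", "content": "Return the total amount due."}],  # placeholder message
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "total",
            "schema": {
                "type": "object",
                "properties": {"total_due": {"type": "number"}},
                "required": ["total_due"],
                "additionalProperties": False,  # strict mode requires a closed object schema
            },
            "strict": True,
        },
    },
)
print(completion.choices[0].message.content)  # a JSON string conforming to the schema
```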
```python
    def annotate(
        self,
        json_schema: dict[str, Any] | Path | str,
        documents: list[Path | str | IOBase],
        dataset_path: Path,
        model: str = "gpt-4o-2024-08-06",
        temperature: float = 0.0,
        batch_size: int = 5,
        max_concurrent: int = 3,
        root_dir: Path = Path("annotations"),
        modality: Modality = "native",
    ) -> None:
        """Generate annotations from document files or in-memory documents
        and create a JSONL training set in one go.

        Args:
            json_schema: The JSON schema for validation
            documents: List of documents; each can be a Path/str or an IOBase object
            dataset_path: Output path for the JSONL training file
            model: The model to use for processing
            temperature: Model temperature (0-1)
            batch_size: Number of examples to process in each batch
            max_concurrent: Maximum number of concurrent API calls
            root_dir: Where to store the per-document JSON annotations
            modality: The modality to use for document processing ("native" by default)
        """
        json_schema = load_json_schema(json_schema)
        assert_valid_model_extraction(model)

        client, provider = self._initialize_model_client(model)
        schema_obj = Schema(json_schema=json_schema)

        def process_example(doc: Path | str | IOBase) -> dict[str, Any]:
            """
            Process a single document (either a file path or an in-memory file-like object).
            Returns a dict with pointers to the original doc and the stored annotation JSON.
            """
            if isinstance(doc, (str, Path)):
                doc_path = Path(doc)
                if not doc_path.is_file():
                    raise ValueError(f"Invalid file path: {doc_path}")
                doc_identifier = str(doc_path)
                hash_str = hashlib.md5(doc_path.as_posix().encode()).hexdigest()
            elif isinstance(doc, IOBase):  # isinstance against typing.IO does not work at runtime; check the io base class
                file_bytes = doc.read()
                hash_str = hashlib.md5(file_bytes).hexdigest()
                doc.seek(0)
                # In-memory documents have no path; fall back to their name attribute
                # (the published code returned an undefined doc_path on this branch).
                doc_identifier = getattr(doc, "name", "uploaded_file")
            else:
                raise ValueError(f"Unsupported document type: {type(doc)}")

            doc_msg = self._client.documents.create_messages(
                document=doc,
                modality=modality,
            )

            # Use _get_model_completion instead of duplicating provider-specific logic
            string_json = self._get_model_completion(
                client=client, provider=provider, model=model, temperature=temperature, messages=schema_obj.messages + doc_msg.messages, schema_obj=schema_obj
            )

            annotation_path = Path(root_dir) / f"annotations_{hash_str}.json"
            annotation_path.parent.mkdir(parents=True, exist_ok=True)

            with open(annotation_path, 'w', encoding='utf-8') as f:
                # string_json is already a JSON string, so json.dump stores it double-encoded
                json.dump(string_json, f, ensure_ascii=False, indent=2)

            return {"document_fpath": doc_identifier, "annotation_fpath": str(annotation_path)}

        # Make sure the output directory exists
        Path(root_dir).mkdir(parents=True, exist_ok=True)

        pairs_paths: list[dict[str, Path | str]] = []
        with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
            futures = []
            # Split documents into batches
            for batch in tqdm([documents[i : i + batch_size] for i in range(0, len(documents), batch_size)], desc="Processing batches"):
                # Submit a batch of tasks
                batch_futures = []
                for doc in batch:
                    try:
                        future = executor.submit(process_example, doc)
                        batch_futures.append(future)
                    except Exception as e:
                        print(f"Error submitting document for processing: {e}")
                futures.extend(batch_futures)

                # Wait for the batch to finish (rate limit)
                for future in batch_futures:
                    try:
                        pair = future.result()
                        pairs_paths.append(pair)
                    except Exception as e:
                        print(f"Error processing example: {e}")

        # Generate the final training set from all results
        self.save(json_schema=json_schema, document_annotation_pairs_paths=pairs_paths, dataset_path=dataset_path)
```
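A usage sketch, again with the hypothetical client and placeholder paths:

```python
from pathlib import Path

uiclient.datasets.annotate(
    json_schema="schemas/invoice.json",      # placeholder path
    documents=["docs/a.pdf", "docs/b.pdf"],  # placeholder documents
    dataset_path=Path("training_set.jsonl"),
    model="gpt-4o-2024-08-06",
    batch_size=5,      # documents are submitted in batches of this size
    max_concurrent=3,  # upper bound on worker threads
)
```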
```python
    def eval(
        self,
        json_schema: dict[str, Any] | Path | str,
        dataset_path: str | Path,
        model: str = "gpt-4o-2024-08-06",
        temperature: float = 0.0,
        batch_size: int = 5,
        max_concurrent: int = 3,
        display: bool = True,
    ) -> ComparisonMetrics:
        """Evaluate model performance on a test dataset.

        Args:
            json_schema: JSON schema defining the expected data structure
            dataset_path: Path to the JSONL file containing test examples
            model: The model to use for benchmarking
            temperature: Model temperature setting (0-1)
            batch_size: Number of examples to process in each batch
            max_concurrent: Maximum number of concurrent API calls
            display: Whether to plot the comparison metrics once evaluation finishes
        """
        json_schema = load_json_schema(json_schema)
        assert_valid_model_extraction(model)
        schema_obj = Schema(json_schema=json_schema)

        # Initialize the appropriate client
        client, provider = self._initialize_model_client(model)

        # Read all lines from the JSONL file
        with open(dataset_path, 'r') as f:
            lines = [json.loads(line) for line in f]

        extraction_analyses: list[ExtractionAnalysis] = []
        total_batches = (len(lines) + batch_size - 1) // batch_size

        # Create the main progress bar for batches
        batch_pbar = tqdm(total=total_batches, desc="Processing batches", position=0)

        # Track running metrics
        class RunningMetrics(BaseModel):
            model: str
            accuracy: float
            levenshtein: float
            jaccard: float
            false_positive: float
            mismatched: float
            processed: int

        running_metrics: RunningMetrics = RunningMetrics(
            model=model,
            accuracy=0.0,
            levenshtein=0.0,
            jaccard=0.0,
            false_positive=0.0,
            mismatched=0.0,
            processed=0,  # number of processed examples, used below to compute the running averages
        )

        def update_running_metrics(analysis: ExtractionAnalysis) -> None:
            comparison = normalized_comparison_metrics([analysis])
            running_metrics.processed += 1
            n = running_metrics.processed
            # Update the running averages incrementally: mean_n = (mean_{n-1} * (n - 1) + x_n) / n
            running_metrics.accuracy = (running_metrics.accuracy * (n - 1) + comparison.accuracy) / n
            running_metrics.levenshtein = (running_metrics.levenshtein * (n - 1) + comparison.levenshtein_similarity) / n
            running_metrics.jaccard = (running_metrics.jaccard * (n - 1) + comparison.jaccard_similarity) / n
            running_metrics.false_positive = (running_metrics.false_positive * (n - 1) + comparison.false_positive_rate) / n
            running_metrics.mismatched = (running_metrics.mismatched * (n - 1) + comparison.mismatched_value_rate) / n
            # Update the progress bar description
            batch_pbar.set_description(
                f"Processing batches | Model: {running_metrics.model} | Acc: {running_metrics.accuracy:.2f} | "
                f"Lev: {running_metrics.levenshtein:.2f} | "
                f"IOU: {running_metrics.jaccard:.2f} | "
                f"FP: {running_metrics.false_positive:.2f} | "
                f"Mism: {running_metrics.mismatched:.2f}"
            )

        def process_example(jsonline: dict) -> ExtractionAnalysis | None:
            line_number = jsonline['line_number']
            try:
                messages = jsonline['messages']
                ground_truth = json.loads(messages[-1]['content'])
                inference_messages = messages[:-1]

                # Use _get_model_completion instead of duplicating provider-specific logic
                string_json = self._get_model_completion(client=client, provider=provider, model=model, temperature=temperature, messages=inference_messages, schema_obj=schema_obj)

                prediction = json.loads(string_json)
                analysis = ExtractionAnalysis(
                    ground_truth=ground_truth,
                    prediction=prediction,
                )
                update_running_metrics(analysis)
                return analysis
            except Exception as e:
                print(f"\nWarning: Failed to process line number {line_number}: {str(e)}")
                return None

        with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
            # Split entries into batches
            for batch_idx in range(0, len(lines), batch_size):
                batch = lines[batch_idx : batch_idx + batch_size]

                # Submit and process the batch. batch_idx already steps by batch_size,
                # so the absolute line number is batch_idx + i (the published code
                # multiplied by batch_size again, overshooting the real line numbers).
                futures = [executor.submit(process_example, entry | {"line_number": batch_idx + i}) for i, entry in enumerate(batch)]
                for future in futures:
                    result = future.result()
                    if result is not None:
                        extraction_analyses.append(result)

                batch_pbar.update(1)

        batch_pbar.close()

        # Analyze error patterns across all examples
        analysis = normalized_comparison_metrics(extraction_analyses)

        if display:
            plot_comparison_metrics(analysis=analysis, top_n=10)

        return analysis
```
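The running averages use the standard incremental mean, mean_n = (mean_{n-1} * (n - 1) + x_n) / n, so the progress bar can report metrics without keeping per-example values around. Calling it, with the same hypothetical client and placeholder paths:

```python
metrics = uiclient.datasets.eval(
    json_schema="schemas/invoice.json",
    dataset_path="test_set.jsonl",
    model="gpt-4o-2024-08-06",
)
print(metrics.accuracy, metrics.levenshtein_similarity, metrics.jaccard_similarity)
```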
```python
    def benchmark(
        self,
        json_schema: dict[str, Any] | Path | str,
        dataset_path: str | Path,
        models: list[str],
        temperature: float = 0.0,
        batch_size: int = 5,
        max_concurrent: int = 3,
        print: bool = True,  # note: shadows the builtin within this method
        verbose: bool = False,
    ) -> list[BenchmarkMetrics]:
        """Benchmark multiple models on a test dataset.

        Args:
            json_schema: JSON schema defining the expected data structure
            dataset_path: Path to the JSONL file containing test examples
            models: List of models to benchmark
            temperature: Model temperature setting (0-1)
            batch_size: Number of examples to process in each batch
            max_concurrent: Maximum number of concurrent API calls
            print: Whether to display the aggregated metrics table
            verbose: Whether to plot the comparison metrics of each underlying eval call

        Returns:
            A list of BenchmarkMetrics, one entry per evaluated model
        """
        results: list[BenchmarkMetrics] = []

        for model in models:
            metrics: ComparisonMetrics = self.eval(
                json_schema=json_schema, dataset_path=dataset_path, model=model, temperature=temperature, batch_size=batch_size, max_concurrent=max_concurrent, display=verbose
            )
            results.append(
                BenchmarkMetrics(
                    ai_model=model,
                    accuracy=metrics.accuracy,
                    levenshtein_similarity=metrics.levenshtein_similarity,
                    jaccard_similarity=metrics.jaccard_similarity,
                    false_positive_rate=metrics.false_positive_rate,
                    false_negative_rate=metrics.false_negative_rate,
                    mismatched_value_rate=metrics.mismatched_value_rate,
                )
            )

        if print:
            display_benchmark_metrics(results)

        return results
```
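`benchmark` is a thin loop over `eval`. A sketch, same hypothetical client and placeholder paths:

```python
results = uiclient.datasets.benchmark(
    json_schema="schemas/invoice.json",
    dataset_path="test_set.jsonl",
    models=["gpt-4o-2024-08-06", "gpt-4o-mini"],
)
# print=True (the default) renders the aggregated table via display_benchmark_metrics.
```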
```python
    def update_annotations(
        self,
        json_schema: dict[str, Any] | Path | str,
        old_dataset_path: str | Path,
        new_dataset_path: str | Path,
        model: str = "gpt-4o-2024-08-06",
        temperature: float = 0.0,
        batch_size: int = 5,
        max_concurrent: int = 3,
    ) -> None:
        """Update annotations in a JSONL file using a new model.

        Args:
            json_schema: The JSON schema for validation
            old_dataset_path: Path to the JSONL file to update
            new_dataset_path: Path for saving the updated annotations
            model: The model to use for the new annotations
            temperature: Model temperature (0-1)
            batch_size: Number of examples to process in each batch
            max_concurrent: Maximum number of concurrent API calls
        """
        json_schema = load_json_schema(json_schema)
        assert_valid_model_extraction(model)
        schema_obj = Schema(json_schema=json_schema)

        # Initialize the appropriate client
        client, provider = self._initialize_model_client(model)

        # Read all lines from the JSONL file
        with open(old_dataset_path, 'r') as f:
            lines = [json.loads(line) for line in f]

        updated_entries = []
        total_batches = (len(lines) + batch_size - 1) // batch_size

        batch_pbar = tqdm(total=total_batches, desc="Processing batches", position=0)

        def process_entry(entry: dict) -> dict:
            messages = entry['messages']
            system_message, user_messages, assistant_messages = separate_messages(messages)
            system_and_user_messages = messages[:-1]

            previous_annotation_message: ChatCompletionUiformMessage = {
                "role": "user",
                "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + messages[-1]['content'],
            }

            string_json = self._get_model_completion(
                client=client,
                provider=provider,
                model=model,
                temperature=temperature,
                messages=schema_obj.messages + user_messages + [previous_annotation_message],
                schema_obj=schema_obj,
            )

            return {"messages": system_and_user_messages + [{"role": "assistant", "content": string_json}]}

        with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
            futures = []
            for batch_idx in range(0, len(lines), batch_size):
                batch = lines[batch_idx : batch_idx + batch_size]

                # batch_idx already steps by batch_size, so the absolute line number is batch_idx + i
                batch_futures = [executor.submit(process_entry, entry | {"line_number": batch_idx + i}) for i, entry in enumerate(batch)]
                futures.extend(batch_futures)

                for future in batch_futures:
                    try:
                        result = future.result()
                        updated_entries.append(result)
                    except Exception as e:
                        print(f"Error processing example: {e}")

                batch_pbar.update(1)
                time.sleep(1)

        batch_pbar.close()

        with open(new_dataset_path, 'w') as f:
            for entry in updated_entries:
                f.write(json.dumps(entry) + '\n')
```
```python
    #########################
    ##### BATCH METHODS #####
    #########################

    def save_batch_annotate_requests(
        self,
        json_schema: dict[str, Any] | Path | str,
        documents: list[Path | str | IOBase],
        batch_requests_path: Path,
        model: str = "gpt-4o-mini",
        temperature: float = 0.0,
        modality: Modality = "native",
    ) -> None:
        """Create a JSONL file containing requests for the OpenAI batch processing API.

        Args:
            json_schema: The JSON schema for validation
            documents: List of documents to process
            batch_requests_path: Output path for the JSONL requests file
            model: The model to use for processing
            temperature: Model temperature (0-1)
            modality: The modality to use for document processing
        """
        loaded_json_schema = load_json_schema(json_schema)
        schema_obj = Schema(json_schema=loaded_json_schema)
        assert_valid_model_extraction(model)

        with open(batch_requests_path, 'w', encoding='utf-8') as f:
            for i, doc in tqdm(enumerate(documents)):
                # Create document messages
                doc_msg = self._client.documents.create_messages(
                    document=doc,
                    modality=modality,
                )

                # Construct the request object
                request: BatchJSONL = {
                    "custom_id": f"request-{i}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": model,
                        "messages": schema_obj.openai_messages + doc_msg.openai_messages,
                        "temperature": temperature,
                        "response_format": {"type": "json_schema", "json_schema": {"name": schema_obj.id, "schema": schema_obj.inference_json_schema, "strict": True}},
                    },
                }

                # Write the request as a JSON line
                f.write(json.dumps(request) + '\n')
```
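The file written above is what OpenAI's Batch API consumes. A sketch of the submission round-trip with the official `openai` client (parameters per OpenAI's documented Batch API; paths are placeholders):

```python
from openai import OpenAI

oai = OpenAI()  # reads OPENAI_API_KEY from the environment

# Upload the requests file, then start a batch against the chat completions endpoint.
batch_file = oai.files.create(file=open("batch_requests.jsonl", "rb"), purpose="batch")
batch = oai.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Later: poll until completed, then download the results for build_dataset_from_batch_results().
batch = oai.batches.retrieve(batch.id)
if batch.status == "completed" and batch.output_file_id:
    oai.files.content(batch.output_file_id).write_to_file("batch_results.jsonl")
```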
```python
    def save_batch_update_annotation_requests(
        self,
        json_schema: dict[str, Any] | Path | str,
        old_dataset_path: str | Path,
        batch_requests_path: str | Path,
        model: str = "gpt-4o-mini",
        temperature: float = 0.0,
    ) -> None:
        """Create a JSONL file containing requests for the OpenAI batch processing API to update annotations.

        Args:
            json_schema: The JSON schema for validation
            old_dataset_path: Path to the JSONL file to update
            batch_requests_path: Output path for the updated JSONL file
            model: The model to use for processing
            temperature: Model temperature (0-1)
        """
        loaded_json_schema = load_json_schema(json_schema)
        schema_obj = Schema(json_schema=loaded_json_schema)

        # Read the existing annotations
        with open(old_dataset_path, 'r') as f:
            entries = [json.loads(line) for line in f]

        # Create a new JSONL with update requests
        with open(batch_requests_path, 'w', encoding='utf-8') as f:
            for i, entry in enumerate(entries):
                existing_messages = entry['messages']
                system_and_user_messages = existing_messages[:-1]

                previous_annotation_message: ChatCompletionMessageParam = {
                    "role": "user",
                    "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + existing_messages[-1]['content'],
                }

                # Construct the request object
                response_format: BatchJSONLResponseFormat = {
                    "type": "json_schema",
                    "json_schema": {"name": schema_obj.id, "schema": schema_obj.inference_json_schema, "strict": True},
                }

                body: BatchJSONLBody = {
                    "model": model,
                    "messages": schema_obj.openai_messages + system_and_user_messages + [previous_annotation_message],
                    "temperature": temperature,
                    "response_format": response_format,
                }

                request: BatchJSONL = {"custom_id": f"request-{i}", "method": "POST", "url": "/v1/chat/completions", "body": body}

                # Write the request as a JSON line
                f.write(json.dumps(request) + '\n')
```
```python
    def build_dataset_from_batch_results(
        self,
        batch_requests_path: str | Path,
        batch_results_path: str | Path,
        dataset_results_path: str | Path,
    ) -> None:
        with open(batch_requests_path, 'r') as f:
            input_lines: list[BatchJSONL] = [json.loads(line) for line in f]
        with open(batch_results_path, 'r') as f:
            # Raw dicts with the BatchJSONLResponse shape (json.loads does not build pydantic models)
            batch_results_lines: list[dict[str, Any]] = [json.loads(line) for line in f]

        assert len(input_lines) == len(batch_results_lines), "Input and batch results must have the same number of lines"

        for input_line, batch_result in zip(input_lines, batch_results_lines):
            messages = input_line['body']['messages']

            # Drop the trailing "old annotation" reference message added by update-annotation requests
            if isinstance(messages[-1].get('content'), str):
                if re.search(r'Here is an old annotation using a different schema\. Use it as a reference to update the annotation:', str(messages[-1].get('content', ''))):
                    input_line['body']['messages'] = messages[:-1]

            input_line['body']['messages'].append(batch_result['response']['body']['choices'][0]['message'])

        with open(dataset_results_path, 'w') as f:
            for input_line in input_lines:
                f.write(json.dumps({'messages': input_line['body']['messages']}) + '\n')

        print(f"Dataset saved to {dataset_results_path}")
```
```python
    #############################
    ##### END BATCH METHODS #####
    #############################


class AsyncDatasets(AsyncAPIResource, BaseDatasetsMixin):
    """Asynchronous wrapper for Datasets using thread execution."""

    async def save(
        self,
        json_schema: dict[str, Any] | Path | str,
        pairs_paths: list[dict[str, Path | str]],
        jsonl_path: Path | str,
        modality: Modality = "native",
    ) -> None:
        json_schema = load_json_schema(json_schema)
        training_set = []

        for pair_paths in tqdm(pairs_paths):
            document_message = await self._client.documents.create_messages(document=pair_paths['document_fpath'], modality=modality)

            with open(pair_paths['annotation_fpath'], 'r') as f:
                annotation = json.loads(f.read())
            assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}

            # Note: unlike the synchronous Datasets.save, no schema/system messages are prepended here
            training_set.append({"messages": document_message.messages + [assistant_message]})

        self._dump_training_set(training_set, jsonl_path)

    async def annotate(
        self,
        json_schema: dict[str, Any] | Path | str,
        documents: list[Path | str | IOBase],
        jsonl_path: str | Path,
        model: str = "gpt-4o-2024-08-06",
        temperature: float = 0.0,
        batch_size: int = 5,
        max_concurrent: int = 3,
        root_dir: Path = Path("annotations"),
        modality: Modality = "native",
    ) -> None:
        """Generate annotations from document files or in-memory documents
        and create a JSONL training set in one go.

        Args:
            json_schema: The JSON schema for validation
            documents: List of documents; each can be a Path/str or an IOBase object
            jsonl_path: Output path for the JSONL training file
            model: The model to use for processing
            temperature: Model temperature (0-1)
            batch_size: Number of examples to process in each batch
            max_concurrent: Maximum number of concurrent API calls
            root_dir: Where to store the per-document JSON annotations
            modality: The modality to use for document processing ("native" by default)
        """
        json_schema = load_json_schema(json_schema)
        assert_valid_model_extraction(model)

        async def process_example(doc: Path | str | IOBase, semaphore: asyncio.Semaphore) -> dict[str, Any]:
            """
            Process a single document (either a file path or an in-memory file-like object).
            Returns a dict with pointers to the original doc and the stored annotation JSON.
            """
            if isinstance(doc, (str, Path)):
                # Handle a path or string
                doc_path = Path(doc)
                if not doc_path.is_file():
                    raise ValueError(f"Invalid file path: {doc_path}")

                # Extract the results
                async with semaphore:
                    result = await self._client.documents.extractions.parse(
                        json_schema=json_schema,
                        document=doc_path,  # pass the actual Path to .parse
                        model=model,
                        temperature=temperature,
                        modality=modality,
                    )
                if result.choices[0].message.content is None:
                    print(f"Failed to extract content from {doc_path}")
                    return {"document_fpath": str(doc_path), "annotation_fpath": None}
                # Generate a unique filename for the annotation
                hash_str = hashlib.md5(doc_path.as_posix().encode()).hexdigest()
                annotation_path = Path(root_dir) / f"annotations_{hash_str}.json"

                annotation_path.parent.mkdir(parents=True, exist_ok=True)

                with open(annotation_path, 'w', encoding='utf-8') as f:
                    # message.content is already a JSON string, so json.dump stores it double-encoded
                    json.dump(result.choices[0].message.content, f, ensure_ascii=False, indent=2)

                return {"document_fpath": str(doc_path), "annotation_fpath": str(annotation_path)}

            elif isinstance(doc, IOBase):  # isinstance against typing.IO does not work at runtime; check the io base class
                # Handle an in-memory file-like object
                # 1) Read the file content (but be careful with the read pointer!)
                file_bytes = doc.read()

                # 2) Attempt to get a name; default to "uploaded_file" if none
                doc_name = getattr(doc, "name", "uploaded_file")

                # 3) Reset the file pointer so the same object can be read again below
                doc.seek(0)

                # 4) Call parse with the same doc object
                async with semaphore:
                    result = await self._client.documents.extractions.parse(
                        json_schema=json_schema,
                        document=doc,  # pass the IO object directly
                        model=model,
                        temperature=temperature,
                        modality=modality,
                    )

                if result.choices[0].message.content is None:
                    print(f"Failed to extract content from {doc_name}")
                    return {"document_fpath": doc_name, "annotation_fpath": None}

                # 5) Create a unique hash from the content
                hash_str = hashlib.md5(file_bytes).hexdigest()
                annotation_path = Path(root_dir) / f"annotations_{hash_str}.json"

                annotation_path.parent.mkdir(parents=True, exist_ok=True)

                with open(annotation_path, 'w', encoding='utf-8') as f:
                    json.dump(result.choices[0].message.content, f, ensure_ascii=False, indent=2)

                return {
                    "document_fpath": doc_name,  # or "in_memory_file"
                    "annotation_fpath": str(annotation_path),
                }

            else:
                raise ValueError(f"Unsupported document type: {type(doc)}")

        # Make sure the output directory exists
        Path(root_dir).mkdir(parents=True, exist_ok=True)

        pairs_paths: list[dict[str, Path | str]] = []
        futures = []
        semaphore = asyncio.Semaphore(max_concurrent)
        # Schedule every document; the semaphore (not the batching) bounds concurrency,
        # and each coroutine must be awaited exactly once, after the loop
        for batch in tqdm([documents[i : i + batch_size] for i in range(0, len(documents), batch_size)], desc="Processing batches"):
            # Submit a batch of tasks
            for doc in batch:
                futures.append(process_example(doc, semaphore))
        pairs_paths = await asyncio.gather(*futures)

        # Generate the final training set from all results
        await self.save(json_schema=json_schema, pairs_paths=pairs_paths, jsonl_path=jsonl_path)

    async def benchmark(self, **kwargs: Any) -> None:
        """Benchmark model performance on a test dataset.

        Args:
            json_schema: JSON schema defining the expected data structure
            jsonl_path: Path to the JSONL file containing test examples
            model: The AI model to use for benchmarking
            temperature: Model temperature setting (0-1)

        Raises:
            NotImplementedError: This method is not implemented yet
        """
        # TODO
        raise NotImplementedError("Benchmarking is not implemented yet")

    async def filter(self, **kwargs: Any) -> None:
        """Filter examples from a JSONL file based on specified parameters.

        Args:
            json_schema: JSON schema defining the data structure
            jsonl_path: Path to the JSONL file to filter
            output_path: Optional path for the filtered output
            inplace: Whether to modify the file in place
            filter_parameters: Parameters to filter examples by (e.g., {"confidence": 0.8})

        Note:
            Filter parameters can include:
            - Number of tokens
            - Modality
            - Other custom parameters
        """
        raise NotImplementedError("Filtering is not implemented yet")

    async def print(self, jsonl_path: Path, output_path: Path = Path("annotations")) -> None:
        """Print a summary of the contents and statistics of a JSONL file.

        This method analyzes the JSONL file and displays various metrics and statistics
        about the dataset contents.

        Args:
            jsonl_path: Path to the JSONL file to analyze
            output_path: Directory where to save any generated reports
        """
        raise NotImplementedError("Printing is not implemented yet")

    async def stitch(self, **kwargs: Any) -> None:
        """Stitch annotations from a list of MIMEData objects into a single MIMEData object.

        This method combines multiple MIMEData annotations into a single object to avoid
        nested list structures (list[list[MIMEData]]) and maintain a simpler list[MIMEData] structure.

        Args:
            json_schema: The JSON schema for validation
            jsonl_path: Path to the JSONL file
            output_path: Optional path for the output file
            inplace: Whether to modify the file in place
            filter_parameters: Optional parameters for filtering
            modality: The modality to use for processing
        """
        raise NotImplementedError("Stitching is not implemented yet")
```