retab 0.0.36-py3-none-any.whl → 0.0.37-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {uiform → retab}/_utils/ai_models.py +2 -2
- {uiform → retab}/_utils/benchmarking.py +15 -16
- {uiform → retab}/_utils/chat.py +9 -14
- {uiform → retab}/_utils/display.py +0 -3
- {uiform → retab}/_utils/json_schema.py +9 -14
- {uiform → retab}/_utils/mime.py +11 -14
- {uiform → retab}/_utils/responses.py +9 -3
- {uiform → retab}/_utils/stream_context_managers.py +1 -1
- {uiform → retab}/_utils/usage/usage.py +28 -28
- {uiform → retab}/client.py +32 -31
- {uiform → retab}/resources/consensus/client.py +17 -36
- {uiform → retab}/resources/consensus/completions.py +24 -47
- {uiform → retab}/resources/consensus/completions_stream.py +26 -38
- {uiform → retab}/resources/consensus/responses.py +31 -80
- {uiform → retab}/resources/consensus/responses_stream.py +31 -79
- {uiform → retab}/resources/documents/client.py +59 -45
- {uiform → retab}/resources/documents/extractions.py +181 -90
- {uiform → retab}/resources/evals.py +56 -43
- retab/resources/evaluations/__init__.py +3 -0
- retab/resources/evaluations/client.py +301 -0
- retab/resources/evaluations/documents.py +233 -0
- retab/resources/evaluations/iterations.py +452 -0
- {uiform → retab}/resources/files.py +2 -2
- {uiform → retab}/resources/jsonlUtils.py +220 -216
- retab/resources/models.py +73 -0
- retab/resources/processors/automations/client.py +244 -0
- {uiform → retab}/resources/processors/automations/endpoints.py +77 -118
- retab/resources/processors/automations/links.py +294 -0
- {uiform → retab}/resources/processors/automations/logs.py +30 -19
- {uiform → retab}/resources/processors/automations/mailboxes.py +136 -174
- retab/resources/processors/automations/outlook.py +337 -0
- {uiform → retab}/resources/processors/automations/tests.py +22 -25
- {uiform → retab}/resources/processors/client.py +179 -164
- {uiform → retab}/resources/schemas.py +78 -66
- {uiform → retab}/resources/secrets/external_api_keys.py +1 -5
- retab/resources/secrets/webhook.py +64 -0
- {uiform → retab}/resources/usage.py +39 -2
- {uiform → retab}/types/ai_models.py +13 -13
- {uiform → retab}/types/automations/cron.py +19 -12
- {uiform → retab}/types/automations/endpoints.py +7 -4
- {uiform → retab}/types/automations/links.py +7 -3
- {uiform → retab}/types/automations/mailboxes.py +9 -9
- {uiform → retab}/types/automations/outlook.py +15 -11
- retab/types/browser_canvas.py +3 -0
- {uiform → retab}/types/chat.py +2 -2
- {uiform → retab}/types/completions.py +9 -12
- retab/types/consensus.py +19 -0
- {uiform → retab}/types/db/annotations.py +3 -3
- {uiform → retab}/types/db/files.py +8 -6
- {uiform → retab}/types/documents/create_messages.py +18 -20
- {uiform → retab}/types/documents/extractions.py +69 -24
- {uiform → retab}/types/evals.py +5 -5
- retab/types/evaluations/__init__.py +31 -0
- retab/types/evaluations/documents.py +30 -0
- retab/types/evaluations/iterations.py +112 -0
- retab/types/evaluations/model.py +73 -0
- retab/types/events.py +79 -0
- {uiform → retab}/types/extractions.py +33 -10
- retab/types/inference_settings.py +15 -0
- retab/types/jobs/base.py +54 -0
- retab/types/jobs/batch_annotation.py +12 -0
- {uiform → retab}/types/jobs/evaluation.py +1 -2
- {uiform → retab}/types/logs.py +37 -34
- retab/types/metrics.py +32 -0
- {uiform → retab}/types/mime.py +22 -20
- {uiform → retab}/types/modalities.py +10 -10
- retab/types/predictions.py +19 -0
- {uiform → retab}/types/schemas/enhance.py +4 -2
- {uiform → retab}/types/schemas/evaluate.py +7 -4
- {uiform → retab}/types/schemas/generate.py +6 -3
- {uiform → retab}/types/schemas/layout.py +1 -1
- {uiform → retab}/types/schemas/object.py +13 -14
- {uiform → retab}/types/schemas/templates.py +1 -3
- {uiform → retab}/types/secrets/external_api_keys.py +0 -1
- {uiform → retab}/types/standards.py +18 -1
- {retab-0.0.36.dist-info → retab-0.0.37.dist-info}/METADATA +7 -6
- retab-0.0.37.dist-info/RECORD +107 -0
- retab-0.0.37.dist-info/top_level.txt +1 -0
- retab-0.0.36.dist-info/RECORD +0 -96
- retab-0.0.36.dist-info/top_level.txt +0 -1
- uiform/_utils/benchmarking copy.py +0 -588
- uiform/resources/models.py +0 -45
- uiform/resources/processors/automations/client.py +0 -78
- uiform/resources/processors/automations/links.py +0 -356
- uiform/resources/processors/automations/outlook.py +0 -444
- uiform/resources/secrets/webhook.py +0 -62
- uiform/types/consensus.py +0 -10
- uiform/types/events.py +0 -76
- uiform/types/jobs/base.py +0 -150
- uiform/types/jobs/batch_annotation.py +0 -22
- {uiform → retab}/__init__.py +0 -0
- {uiform → retab}/_resource.py +0 -0
- {uiform → retab}/_utils/__init__.py +0 -0
- {uiform → retab}/_utils/usage/__init__.py +0 -0
- {uiform → retab}/py.typed +0 -0
- {uiform → retab}/resources/__init__.py +0 -0
- {uiform → retab}/resources/consensus/__init__.py +0 -0
- {uiform → retab}/resources/documents/__init__.py +0 -0
- {uiform → retab}/resources/finetuning.py +0 -0
- {uiform → retab}/resources/openai_example.py +0 -0
- {uiform → retab}/resources/processors/__init__.py +0 -0
- {uiform → retab}/resources/processors/automations/__init__.py +0 -0
- {uiform → retab}/resources/prompt_optimization.py +0 -0
- {uiform → retab}/resources/secrets/__init__.py +0 -0
- {uiform → retab}/resources/secrets/client.py +0 -0
- {uiform → retab}/types/__init__.py +0 -0
- {uiform → retab}/types/automations/__init__.py +0 -0
- {uiform → retab}/types/automations/webhooks.py +0 -0
- {uiform → retab}/types/db/__init__.py +0 -0
- {uiform → retab}/types/documents/__init__.py +0 -0
- {uiform → retab}/types/documents/correct_orientation.py +0 -0
- {uiform → retab}/types/jobs/__init__.py +0 -0
- {uiform → retab}/types/jobs/finetune.py +0 -0
- {uiform → retab}/types/jobs/prompt_optimization.py +0 -0
- {uiform → retab}/types/jobs/webcrawl.py +0 -0
- {uiform → retab}/types/pagination.py +0 -0
- {uiform → retab}/types/schemas/__init__.py +0 -0
- {uiform → retab}/types/secrets/__init__.py +0 -0
- {retab-0.0.36.dist-info → retab-0.0.37.dist-info}/WHEEL +0 -0
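The dominant change in this release is the top-level package rename from `uiform` to `retab` (note the swapped `top_level.txt` entries and `RECORD` files above). A minimal migration sketch follows; only the module paths are confirmed by this diff, so the client class names are an assumption:

```python
# Hypothetical before/after for the uiform -> retab rename (0.0.36 -> 0.0.37).
# The class names below are assumptions; this diff only shows module paths.

# 0.0.36:
# from uiform import UiForm
# client = UiForm(api_key="sk-...")

# 0.0.37 -- same resource tree (client.documents, client.processors, ...),
# new top-level package:
from retab import Retab

client = Retab(api_key="sk-...")
```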
The per-file diff below is {uiform → retab}/resources/jsonlUtils.py (+220 -216); long lines in the removed column were truncated by the diff viewer and are left truncated here.

```diff
@@ -8,12 +8,13 @@ import time
 from concurrent.futures import ThreadPoolExecutor
 from io import IOBase
 from pathlib import Path
-from typing import IO, Any,
+from typing import IO, Any, Optional, TypedDict
 
 from anthropic import Anthropic
 from openai import OpenAI
 from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
 from pydantic import BaseModel
+from pydantic_core import PydanticUndefined
 from tqdm import tqdm
 
 from .._resource import AsyncAPIResource, SyncAPIResource
@@ -24,6 +25,7 @@ from .._utils.json_schema import load_json_schema
 from ..types.chat import ChatCompletionUiformMessage
 from ..types.modalities import Modality
 from ..types.schemas.object import Schema
+from ..types.browser_canvas import BrowserCanvas
 
 
 class FinetuningJSON(BaseModel):
@@ -31,7 +33,6 @@ class FinetuningJSON(BaseModel):
 
 
 FinetuningJSONL = list[FinetuningJSON]
-from typing import TypedDict
 
 
 class BatchJSONLResponseFormat(TypedDict):
@@ -106,9 +107,9 @@ class BatchJSONLResponse(BaseModel):
 
 class BaseDatasetsMixin:
     def _dump_training_set(self, training_set: list[dict[str, Any]], dataset_path: Path | str) -> None:
-        with open(dataset_path,
+        with open(dataset_path, "w", encoding="utf-8") as file:
             for entry in training_set:
-                file.write(json.dumps(entry) +
+                file.write(json.dumps(entry) + "\n")
 
 
 class Datasets(SyncAPIResource, BaseDatasetsMixin):
@@ -138,8 +139,8 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         json_schema: dict[str, Any] | Path | str,
         document_annotation_pairs_paths: list[dict[str, Path | str]],
         dataset_path: Path | str,
-        image_resolution_dpi: int
-        browser_canvas:
+        image_resolution_dpi: int = PydanticUndefined,  # type: ignore[assignment]
+        browser_canvas: BrowserCanvas = PydanticUndefined,  # type: ignore[assignment]
         modality: Modality = "native",
     ) -> None:
         """Save document-annotation pairs to a JSONL training set.
```
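The two new parameter defaults above use `PydanticUndefined` as a sentinel: a caller who omits the argument falls through to whatever default the downstream pydantic model defines, instead of a value hard-coded in the SDK signature (hence the `# type: ignore[assignment]`). A minimal sketch of that pattern, with a hypothetical model and default rather than retab's actual code:

```python
from pydantic import BaseModel
from pydantic_core import PydanticUndefined


class MessageOptions(BaseModel):
    image_resolution_dpi: int = 96  # hypothetical default, owned by the model


def create_messages(image_resolution_dpi: int = PydanticUndefined) -> MessageOptions:  # type: ignore[assignment]
    # Forward the argument only when the caller actually supplied one, so the
    # model's own default applies otherwise.
    kwargs = {}
    if image_resolution_dpi is not PydanticUndefined:
        kwargs["image_resolution_dpi"] = image_resolution_dpi
    return MessageOptions(**kwargs)


print(create_messages().image_resolution_dpi)     # 96: model default
print(create_messages(300).image_resolution_dpi)  # 300: explicit override
```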
```diff
@@ -153,16 +154,18 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         json_schema = load_json_schema(json_schema)
         schema_obj = Schema(json_schema=json_schema)
 
-        with open(dataset_path,
+        with open(dataset_path, "w", encoding="utf-8") as file:
             for pair_paths in tqdm(document_annotation_pairs_paths, desc="Processing pairs", position=0):
-                document_message = self._client.documents.create_messages(
+                document_message = self._client.documents.create_messages(
+                    document=pair_paths["document_fpath"], modality=modality, image_resolution_dpi=image_resolution_dpi, browser_canvas=browser_canvas
+                )
 
-                with open(pair_paths[
+                with open(pair_paths["annotation_fpath"], "r") as f:
                     annotation = json.loads(f.read())
                     assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}
 
                 entry = {"messages": schema_obj.messages + document_message.messages + [assistant_message]}
-                file.write(json.dumps(entry) +
+                file.write(json.dumps(entry) + "\n")
 
     def change_schema(
         self,
```
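The hunk above shows `save()` now forwarding `image_resolution_dpi` and `browser_canvas` into `documents.create_messages`. From the caller's side that surfaces roughly as in the sketch below; the argument values are illustrative, and the `Retab` client class name is an assumption not shown in this diff:

```python
from retab import Retab  # class name assumed; only the package path appears in this diff

client = Retab(api_key="sk-...")

# Mirrors the internal call in save(): document-rendering options are now
# forwarded per call instead of being fixed by the SDK.
doc_msg = client.documents.create_messages(
    document="invoice.pdf",    # illustrative document path
    modality="native",
    image_resolution_dpi=150,  # illustrative; omit to use the library default
    browser_canvas="A4",       # illustrative BrowserCanvas value
)
```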
```diff
@@ -193,8 +196,8 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         assert isinstance(target_path, Path) or isinstance(target_path, str)
 
         # Use a temporary file to write the updated content
-        with tempfile.NamedTemporaryFile(
-            with open(input_dataset_path,
+        with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as temp_file:
+            with open(input_dataset_path, "r", encoding="utf-8") as infile:
                 for line in infile:
                     entry = json.loads(line)
                     messages = entry.get("messages", [])
@@ -208,7 +211,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
                     updated_entry = {"messages": updated_messages}
 
                     # Write the updated entry to the temporary file
-                    temp_file.write(json.dumps(updated_entry) +
+                    temp_file.write(json.dumps(updated_entry) + "\n")
 
         # Replace the original file with the temporary file
         shutil.move(temp_file.name, target_path)
@@ -241,19 +244,19 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         for pair_paths in tqdm(pairs_paths):
             document_messages: list[ChatCompletionUiformMessage] = []
 
-            if isinstance(pair_paths[
-                document_message = self._client.documents.create_messages(document=pair_paths[
+            if isinstance(pair_paths["document_fpath"], str) or isinstance(pair_paths["document_fpath"], Path):
+                document_message = self._client.documents.create_messages(document=pair_paths["document_fpath"], modality=modality)
                 document_messages.extend(document_message.messages)
 
             else:
-                assert isinstance(pair_paths[
-                for document_fpath in pair_paths[
+                assert isinstance(pair_paths["document_fpath"], list)
+                for document_fpath in pair_paths["document_fpath"]:
                     document_message = self._client.documents.create_messages(document=document_fpath, modality=modality)
                     document_messages.extend(document_message.messages)
 
             # Use context manager to properly close the file
-            assert isinstance(pair_paths[
-            with open(pair_paths[
+            assert isinstance(pair_paths["annotation_fpath"], Path) or isinstance(pair_paths["annotation_fpath"], str)
+            with open(pair_paths["annotation_fpath"], "r") as f:
                 annotation = json.loads(f.read())
                 assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}
 
@@ -389,6 +392,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
                 raise ValueError(f"Invalid file path: {doc_path}")
             hash_str = hashlib.md5(doc_path.as_posix().encode()).hexdigest()
         elif isinstance(doc, IO):
+            doc_path = Path(doc.name) or "unknown_file"
             file_bytes = doc.read()
             hash_str = hashlib.md5(file_bytes).hexdigest()
             doc.seek(0)
@@ -408,7 +412,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         annotation_path = Path(root_dir) / f"annotations_{hash_str}.json"
         annotation_path.parent.mkdir(parents=True, exist_ok=True)
 
-        with open(annotation_path,
+        with open(annotation_path, "w", encoding="utf-8") as f:
             json.dump(string_json, f, ensure_ascii=False, indent=2)
 
         return {"document_fpath": str(doc_path), "annotation_fpath": str(annotation_path)}
@@ -442,176 +446,176 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         # Generate final training set from all results
         self.save(json_schema=json_schema, document_annotation_pairs_paths=pairs_paths, dataset_path=dataset_path)
 
-    def eval(
-        ...  # old parameter list (lines 446-453) truncated by the diff viewer
-    ) -> ComparisonMetrics:
-        ...  # old body (lines 455-566) truncated by the diff viewer
-    def benchmark(
-        ...  # old parameter list (lines 568-576) truncated by the diff viewer
-    ) -> list[BenchmarkMetrics]:
-        ...  # old body (lines 578-614) truncated by the diff viewer
+    # def eval(
+    #     self,
+    #     json_schema: dict[str, Any] | Path | str,
+    #     dataset_path: str | Path,
+    #     model: str = "gpt-4o-2024-08-06",
+    #     temperature: float = 0.0,
+    #     batch_size: int = 5,
+    #     max_concurrent: int = 3,
+    #     display: bool = True,
+    # ) -> ComparisonMetrics:
+    #     """Evaluate model performance on a test dataset.
+
+    #     Args:
+    #         json_schema: JSON schema defining the expected data structure
+    #         dataset_path: Path to the JSONL file containing test examples
+    #         model: The model to use for benchmarking
+    #         temperature: Model temperature setting (0-1)
+    #         batch_size: Number of examples to process in each batch
+    #         max_concurrent: Maximum number of concurrent API calls
+    #     """
+
+    #     json_schema = load_json_schema(json_schema)
+    #     assert_valid_model_extraction(model)
+    #     schema_obj = Schema(json_schema=json_schema)
+
+    #     # Initialize appropriate client
+    #     client, provider = self._initialize_model_client(model)
+
+    #     # Read all lines from the JSONL file
+    #     with open(dataset_path, "r") as f:
+    #         lines = [json.loads(line) for line in f]
+
+    #     extraction_analyses: list[ExtractionAnalysis] = []
+    #     total_batches = (len(lines) + batch_size - 1) // batch_size
+
+    #     # Create main progress bar for batches
+    #     batch_pbar = tqdm(total=total_batches, desc="Processing batches", position=0)
+
+    #     # Track running metrics
+    #     class RunningMetrics(BaseModel):
+    #         model: str
+    #         accuracy: float
+    #         levenshtein: float
+    #         jaccard: float
+    #         false_positive: float
+    #         mismatched: float
+    #         processed: int
+
+    #     running_metrics: RunningMetrics = RunningMetrics(
+    #         model=model,
+    #         accuracy=0.0,
+    #         levenshtein=0.0,
+    #         jaccard=0.0,
+    #         false_positive=0.0,
+    #         mismatched=0.0,
+    #         processed=0,  # number of processed examples - used in the loop to compute the running averages
+    #     )
+
+    #     # def update_running_metrics(analysis: ExtractionAnalysis) -> None:
+    #     #     comparison = normalized_comparison_metrics([analysis])
+    #     #     running_metrics.processed += 1
+    #     #     n = running_metrics.processed
+    #     #     # Update running averages
+    #     #     running_metrics.accuracy = (running_metrics.accuracy * (n - 1) + comparison.accuracy) / n
+    #     #     running_metrics.levenshtein = (running_metrics.levenshtein * (n - 1) + comparison.levenshtein_similarity) / n
+    #     #     running_metrics.jaccard = (running_metrics.jaccard * (n - 1) + comparison.jaccard_similarity) / n
+    #     #     running_metrics.false_positive = (running_metrics.false_positive * (n - 1) + comparison.false_positive_rate) / n
+    #     #     running_metrics.mismatched = (running_metrics.mismatched * (n - 1) + comparison.mismatched_value_rate) / n
+    #     #     # Update progress bar description
+    #     #     batch_pbar.set_description(
+    #     #         f"Processing batches | Model: {running_metrics.model} | Acc: {running_metrics.accuracy:.2f} | "
+    #     #         f"Lev: {running_metrics.levenshtein:.2f} | "
+    #     #         f"IOU: {running_metrics.jaccard:.2f} | "
+    #     #         f"FP: {running_metrics.false_positive:.2f} | "
+    #     #         f"Mism: {running_metrics.mismatched:.2f}"
+    #     #     )
+
+    #     # def process_example(jsonline: dict) -> ExtractionAnalysis | None:
+    #     #     line_number = jsonline["line_number"]
+    #     #     try:
+    #     #         messages = jsonline["messages"]
+    #     #         ground_truth = json.loads(messages[-1]["content"])
+    #     #         inference_messages = messages[:-1]
+
+    #     #         # Use _get_model_completion instead of duplicating provider-specific logic
+    #     #         string_json = self._get_model_completion(client=client, provider=provider, model=model, temperature=temperature, messages=inference_messages, schema_obj=schema_obj)
+
+    #     #         prediction = json.loads(string_json)
+    #     #         analysis = ExtractionAnalysis(
+    #     #             ground_truth=ground_truth,
+    #     #             prediction=prediction,
+    #     #         )
+    #     #         update_running_metrics(analysis)
+    #     #         return analysis
+    #     #     except Exception as e:
+    #     #         print(f"\nWarning: Failed to process line number {line_number}: {str(e)}")
+    #     #         return None
+
+    #     # with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
+    #     #     # Split entries into batches
+    #     #     for batch_idx in range(0, len(lines), batch_size):
+    #     #         batch = lines[batch_idx : batch_idx + batch_size]
+
+    #     #         # Submit and process batch
+    #     #         futures = [executor.submit(process_example, entry | {"line_number": batch_idx * batch_size + i}) for i, entry in enumerate(batch)]
+    #     #         for future in futures:
+    #     #             result = future.result()
+    #     #             if result is not None:
+    #     #                 extraction_analyses.append(result)
+
+    #     #         batch_pbar.update(1)
+
+    #     # batch_pbar.close()
+
+    #     # # Analyze error patterns across all examples
+    #     # analysis = normalized_comparison_metrics(extraction_analyses)
+
+    #     # if display:
+    #     #     plot_comparison_metrics(analysis=analysis, top_n=10)
+
+    #     # return analysis
+
+    # def benchmark(
+    #     self,
+    #     json_schema: dict[str, Any] | Path | str,
+    #     dataset_path: str | Path,
+    #     models: list[str],
+    #     temperature: float = 0.0,
+    #     batch_size: int = 5,
+    #     max_concurrent: int = 3,
+    #     print: bool = True,
+    #     verbose: bool = False,
+    # ) -> list[BenchmarkMetrics]:
+    #     """Benchmark multiple models on a test dataset.
+
+    #     Args:
+    #         json_schema: JSON schema defining the expected data structure
+    #         dataset_path: Path to the JSONL file containing test examples
+    #         models: List of models to benchmark
+    #         temperature: Model temperature setting (0-1)
+    #         batch_size: Number of examples to process in each batch
+    #         max_concurrent: Maximum number of concurrent API calls
+    #         print: Whether to print the metrics
+    #         verbose: Whether to print all the metrics of all the function calls
+
+    #     Returns:
+    #         Dictionary mapping model names to their evaluation metrics
+    #     """
+    #     results: list[BenchmarkMetrics] = []
+
+    #     for model in models:
+    #         metrics: ComparisonMetrics = self.eval(
+    #             json_schema=json_schema, dataset_path=dataset_path, model=model, temperature=temperature, batch_size=batch_size, max_concurrent=max_concurrent, display=verbose
+    #         )
+    #         results.append(
+    #             BenchmarkMetrics(
+    #                 ai_model=model,
+    #                 accuracy=metrics.accuracy,
+    #                 levenshtein_similarity=metrics.levenshtein_similarity,
+    #                 jaccard_similarity=metrics.jaccard_similarity,
+    #                 false_positive_rate=metrics.false_positive_rate,
+    #                 false_negative_rate=metrics.false_negative_rate,
+    #                 mismatched_value_rate=metrics.mismatched_value_rate,
+    #             )
+    #         )
+
+    #     if print:
+    #         display_benchmark_metrics(results)
+
+    #     return results
 
     def update_annotations(
         self,
@@ -642,7 +646,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         client, provider = self._initialize_model_client(model)
 
         # Read all lines from the JSONL file
-        with open(old_dataset_path,
+        with open(old_dataset_path, "r") as f:
             lines = [json.loads(line) for line in f]
 
         updated_entries = []
@@ -651,13 +655,13 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         batch_pbar = tqdm(total=total_batches, desc="Processing batches", position=0)
 
         def process_entry(entry: dict) -> dict:
-            messages = entry[
+            messages = entry["messages"]
             system_message, user_messages, assistant_messages = separate_messages(messages)
             system_and_user_messages = messages[:-1]
 
             previous_annotation_message: ChatCompletionUiformMessage = {
                 "role": "user",
-                "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + messages[-1][
+                "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + messages[-1]["content"],
             }
 
             string_json = self._get_model_completion(
@@ -691,9 +695,9 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
 
         batch_pbar.close()
 
-        with open(new_dataset_path,
+        with open(new_dataset_path, "w") as f:
             for entry in updated_entries:
-                f.write(json.dumps(entry) +
+                f.write(json.dumps(entry) + "\n")
 
         #########################
         ##### BATCH METHODS #####
@@ -722,7 +726,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         schema_obj = Schema(json_schema=loaded_json_schema)
         assert_valid_model_extraction(model)
 
-        with open(batch_requests_path,
+        with open(batch_requests_path, "w", encoding="utf-8") as f:
             for i, doc in tqdm(enumerate(documents)):
                 # Create document messages
                 doc_msg = self._client.documents.create_messages(
@@ -744,7 +748,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
                 }
 
                 # Write the request as a JSON line
-                f.write(json.dumps(request) +
+                f.write(json.dumps(request) + "\n")
 
     def save_batch_update_annotation_requests(
         self,
@@ -768,18 +772,18 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         schema_obj = Schema(json_schema=loaded_json_schema)
 
         # Read existing annotations
-        with open(old_dataset_path,
+        with open(old_dataset_path, "r") as f:
             entries = [json.loads(line) for line in f]
 
         # Create new JSONL with update requests
-        with open(batch_requests_path,
+        with open(batch_requests_path, "w", encoding="utf-8") as f:
             for i, entry in enumerate(entries):
-                existing_messages = entry[
+                existing_messages = entry["messages"]
                 system_and_user_messages = existing_messages[:-1]
 
                 previous_annotation_message: ChatCompletionMessageParam = {
                     "role": "user",
-                    "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + existing_messages[-1][
+                    "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + existing_messages[-1]["content"],
                 }
 
                 # Construct the request object
@@ -798,7 +802,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
                 request: BatchJSONL = {"custom_id": f"request-{i}", "method": "POST", "url": "/v1/chat/completions", "body": body}
 
                 # Write the request as a JSON line
-                f.write(json.dumps(request) +
+                f.write(json.dumps(request) + "\n")
 
     def build_dataset_from_batch_results(
         self,
@@ -806,27 +810,27 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
         batch_results_path: str | Path,
         dataset_results_path: str | Path,
     ) -> None:
-        with open(batch_requests_path,
+        with open(batch_requests_path, "r") as f:
            input_lines: list[BatchJSONL] = [json.loads(line) for line in f]
-        with open(batch_results_path,
-            batch_results_lines: list[BatchJSONLResponse] = [
+        with open(batch_results_path, "r") as f:
+            batch_results_lines: list[BatchJSONLResponse] = [BatchJSONLResponse.model_validate_json(line) for line in f]
 
         assert len(input_lines) == len(batch_results_lines), "Input and batch results must have the same number of lines"
 
        for input_line, batch_result in zip(input_lines, batch_results_lines):
-            messages = input_line[
+            messages = input_line["body"]["messages"]
 
             # Filter out messages containing the old annotation reference to remove messages that come from "update annotation"
-            if isinstance(messages[-1].get(
-                if re.search(r
+            if isinstance(messages[-1].get("content"), str):
+                if re.search(r"Here is an old annotation using a different schema\. Use it as a reference to update the annotation:", str(messages[-1].get("content", ""))):
                    print("found keyword")
-                    input_line[
+                    input_line["body"]["messages"] = messages[:-1]
 
-            input_line[
+            input_line["body"]["messages"].append(batch_result.response.body.choices[0].message)
 
-        with open(dataset_results_path,
+        with open(dataset_results_path, "w") as f:
             for input_line in input_lines:
-                f.write(json.dumps({
+                f.write(json.dumps({"messages": input_line["body"]["messages"]}) + "\n")
 
         print(f"Dataset saved to {dataset_results_path}")
 
```
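In `build_dataset_from_batch_results` above, raw parsing of the results file is replaced with `BatchJSONLResponse.model_validate_json(line)`, so each JSONL line is parsed and validated into a typed model in one step. A generic sketch of that pydantic v2 pattern, using a stand-in model rather than retab's actual schema:

```python
from pydantic import BaseModel


class LineRecord(BaseModel):  # stand-in for BatchJSONLResponse
    custom_id: str
    status_code: int


jsonl = '{"custom_id": "request-0", "status_code": 200}\n' '{"custom_id": "request-1", "status_code": 200}\n'

# model_validate_json parses the JSON string and validates field types in one
# step, raising pydantic.ValidationError on malformed or mistyped lines.
records = [LineRecord.model_validate_json(line) for line in jsonl.splitlines()]
print(records[0].custom_id)  # request-0
```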
```diff
@@ -849,9 +853,9 @@ class AsyncDatasets(AsyncAPIResource, BaseDatasetsMixin):
         training_set = []
 
         for pair_paths in tqdm(pairs_paths):
-            document_message = await self._client.documents.create_messages(document=pair_paths[
+            document_message = await self._client.documents.create_messages(document=pair_paths["document_fpath"], modality=modality)
 
-            with open(pair_paths[
+            with open(pair_paths["annotation_fpath"], "r") as f:
                 annotation = json.loads(f.read())
                 assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}
 
@@ -917,7 +921,7 @@ class AsyncDatasets(AsyncAPIResource, BaseDatasetsMixin):
 
         annotation_path.parent.mkdir(parents=True, exist_ok=True)
 
-        with open(annotation_path,
+        with open(annotation_path, "w", encoding="utf-8") as f:
             json.dump(result.choices[0].message.content, f, ensure_ascii=False, indent=2)
 
         return {"document_fpath": str(doc_path), "annotation_fpath": str(annotation_path)}
@@ -954,7 +958,7 @@ class AsyncDatasets(AsyncAPIResource, BaseDatasetsMixin):
 
         annotation_path.parent.mkdir(parents=True, exist_ok=True)
 
-        with open(annotation_path,
+        with open(annotation_path, "w", encoding="utf-8") as f:
             json.dump(result.choices[0].message.content, f, ensure_ascii=False, indent=2)
 
         return {
```