retab 0.0.35-py3-none-any.whl → 0.0.37-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (134)
  1. {uiform → retab}/_utils/ai_models.py +2 -2
  2. {uiform → retab}/_utils/benchmarking.py +15 -16
  3. {uiform → retab}/_utils/chat.py +9 -14
  4. {uiform → retab}/_utils/display.py +0 -3
  5. {uiform → retab}/_utils/json_schema.py +9 -14
  6. {uiform → retab}/_utils/mime.py +11 -14
  7. {uiform → retab}/_utils/responses.py +9 -3
  8. {uiform → retab}/_utils/stream_context_managers.py +1 -1
  9. {uiform → retab}/_utils/usage/usage.py +28 -28
  10. {uiform → retab}/client.py +32 -31
  11. {uiform → retab}/resources/consensus/client.py +17 -36
  12. {uiform → retab}/resources/consensus/completions.py +24 -47
  13. {uiform → retab}/resources/consensus/completions_stream.py +26 -38
  14. {uiform → retab}/resources/consensus/responses.py +31 -80
  15. {uiform → retab}/resources/consensus/responses_stream.py +31 -79
  16. {uiform → retab}/resources/documents/client.py +59 -45
  17. {uiform → retab}/resources/documents/extractions.py +181 -90
  18. {uiform → retab}/resources/evals.py +56 -43
  19. retab/resources/evaluations/__init__.py +3 -0
  20. retab/resources/evaluations/client.py +301 -0
  21. retab/resources/evaluations/documents.py +233 -0
  22. retab/resources/evaluations/iterations.py +452 -0
  23. {uiform → retab}/resources/files.py +2 -2
  24. {uiform → retab}/resources/jsonlUtils.py +220 -216
  25. retab/resources/models.py +73 -0
  26. retab/resources/processors/automations/client.py +244 -0
  27. {uiform → retab}/resources/processors/automations/endpoints.py +77 -118
  28. retab/resources/processors/automations/links.py +294 -0
  29. {uiform → retab}/resources/processors/automations/logs.py +30 -19
  30. {uiform → retab}/resources/processors/automations/mailboxes.py +136 -174
  31. retab/resources/processors/automations/outlook.py +337 -0
  32. {uiform → retab}/resources/processors/automations/tests.py +22 -25
  33. {uiform → retab}/resources/processors/client.py +179 -164
  34. {uiform → retab}/resources/schemas.py +78 -66
  35. {uiform → retab}/resources/secrets/external_api_keys.py +1 -5
  36. retab/resources/secrets/webhook.py +64 -0
  37. {uiform → retab}/resources/usage.py +39 -2
  38. {uiform → retab}/types/ai_models.py +13 -13
  39. {uiform → retab}/types/automations/cron.py +19 -12
  40. {uiform → retab}/types/automations/endpoints.py +7 -4
  41. {uiform → retab}/types/automations/links.py +7 -3
  42. {uiform → retab}/types/automations/mailboxes.py +9 -9
  43. {uiform → retab}/types/automations/outlook.py +15 -11
  44. retab/types/browser_canvas.py +3 -0
  45. {uiform → retab}/types/chat.py +2 -2
  46. {uiform → retab}/types/completions.py +9 -12
  47. retab/types/consensus.py +19 -0
  48. {uiform → retab}/types/db/annotations.py +3 -3
  49. {uiform → retab}/types/db/files.py +8 -6
  50. {uiform → retab}/types/documents/create_messages.py +18 -20
  51. {uiform → retab}/types/documents/extractions.py +69 -24
  52. {uiform → retab}/types/evals.py +5 -5
  53. retab/types/evaluations/__init__.py +31 -0
  54. retab/types/evaluations/documents.py +30 -0
  55. retab/types/evaluations/iterations.py +112 -0
  56. retab/types/evaluations/model.py +73 -0
  57. retab/types/events.py +79 -0
  58. {uiform → retab}/types/extractions.py +33 -10
  59. retab/types/inference_settings.py +15 -0
  60. retab/types/jobs/base.py +54 -0
  61. retab/types/jobs/batch_annotation.py +12 -0
  62. {uiform → retab}/types/jobs/evaluation.py +1 -2
  63. {uiform → retab}/types/logs.py +37 -34
  64. retab/types/metrics.py +32 -0
  65. {uiform → retab}/types/mime.py +22 -20
  66. {uiform → retab}/types/modalities.py +10 -10
  67. retab/types/predictions.py +19 -0
  68. {uiform → retab}/types/schemas/enhance.py +4 -2
  69. {uiform → retab}/types/schemas/evaluate.py +7 -4
  70. {uiform → retab}/types/schemas/generate.py +6 -3
  71. {uiform → retab}/types/schemas/layout.py +1 -1
  72. {uiform → retab}/types/schemas/object.py +13 -14
  73. {uiform → retab}/types/schemas/templates.py +1 -3
  74. {uiform → retab}/types/secrets/external_api_keys.py +0 -1
  75. {uiform → retab}/types/standards.py +18 -1
  76. {retab-0.0.35.dist-info → retab-0.0.37.dist-info}/METADATA +7 -6
  77. retab-0.0.37.dist-info/RECORD +107 -0
  78. retab-0.0.37.dist-info/top_level.txt +1 -0
  79. retab-0.0.35.dist-info/RECORD +0 -111
  80. retab-0.0.35.dist-info/top_level.txt +0 -1
  81. uiform/_utils/benchmarking copy.py +0 -588
  82. uiform/resources/deployments/__init__.py +0 -9
  83. uiform/resources/deployments/client.py +0 -78
  84. uiform/resources/deployments/endpoints.py +0 -322
  85. uiform/resources/deployments/links.py +0 -452
  86. uiform/resources/deployments/logs.py +0 -211
  87. uiform/resources/deployments/mailboxes.py +0 -496
  88. uiform/resources/deployments/outlook.py +0 -531
  89. uiform/resources/deployments/tests.py +0 -158
  90. uiform/resources/models.py +0 -45
  91. uiform/resources/processors/automations/client.py +0 -78
  92. uiform/resources/processors/automations/links.py +0 -356
  93. uiform/resources/processors/automations/outlook.py +0 -444
  94. uiform/resources/secrets/webhook.py +0 -62
  95. uiform/types/consensus.py +0 -10
  96. uiform/types/deployments/cron.py +0 -59
  97. uiform/types/deployments/endpoints.py +0 -28
  98. uiform/types/deployments/links.py +0 -36
  99. uiform/types/deployments/mailboxes.py +0 -67
  100. uiform/types/deployments/outlook.py +0 -76
  101. uiform/types/deployments/webhooks.py +0 -21
  102. uiform/types/events.py +0 -76
  103. uiform/types/jobs/base.py +0 -150
  104. uiform/types/jobs/batch_annotation.py +0 -22
  105. uiform/types/secrets/__init__.py +0 -0
  106. {uiform → retab}/__init__.py +0 -0
  107. {uiform → retab}/_resource.py +0 -0
  108. {uiform → retab}/_utils/__init__.py +0 -0
  109. {uiform → retab}/_utils/usage/__init__.py +0 -0
  110. {uiform → retab}/py.typed +0 -0
  111. {uiform → retab}/resources/__init__.py +0 -0
  112. {uiform → retab}/resources/consensus/__init__.py +0 -0
  113. {uiform → retab}/resources/documents/__init__.py +0 -0
  114. {uiform → retab}/resources/finetuning.py +0 -0
  115. {uiform → retab}/resources/openai_example.py +0 -0
  116. {uiform → retab}/resources/processors/__init__.py +0 -0
  117. {uiform → retab}/resources/processors/automations/__init__.py +0 -0
  118. {uiform → retab}/resources/prompt_optimization.py +0 -0
  119. {uiform → retab}/resources/secrets/__init__.py +0 -0
  120. {uiform → retab}/resources/secrets/client.py +0 -0
  121. {uiform → retab}/types/__init__.py +0 -0
  122. {uiform → retab}/types/automations/__init__.py +0 -0
  123. {uiform → retab}/types/automations/webhooks.py +0 -0
  124. {uiform → retab}/types/db/__init__.py +0 -0
  125. {uiform/types/deployments → retab/types/documents}/__init__.py +0 -0
  126. {uiform → retab}/types/documents/correct_orientation.py +0 -0
  127. {uiform/types/documents → retab/types/jobs}/__init__.py +0 -0
  128. {uiform → retab}/types/jobs/finetune.py +0 -0
  129. {uiform → retab}/types/jobs/prompt_optimization.py +0 -0
  130. {uiform → retab}/types/jobs/webcrawl.py +0 -0
  131. {uiform → retab}/types/pagination.py +0 -0
  132. {uiform/types/jobs → retab/types/schemas}/__init__.py +0 -0
  133. {uiform/types/schemas → retab/types/secrets}/__init__.py +0 -0
  134. {retab-0.0.35.dist-info → retab-0.0.37.dist-info}/WHEEL +0 -0
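The dominant change in this release is the rename of the top-level package from uiform to retab (every {uiform → retab} entry above), plus new evaluations, models, automations and jobs modules under the retab tree. A minimal migration sketch, assuming only the package name changed; the module path is taken from the file list above, and Modality is a name that appears in the diff below, but any other import migrates the same way:

    # retab 0.0.35: the wheel installed a top-level package named uiform
    from uiform.types.modalities import Modality

    # retab 0.0.37: the same module now ships under retab
    from retab.types.modalities import Modality

The top_level.txt changes above (a retab entry added, the uiform entry removed) indicate that 0.0.37 no longer installs a uiform package, so imports must be updated rather than aliased. The hunks reproduced below are evidently from item 24, {uiform → retab}/resources/jsonlUtils.py (+220 -216).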
@@ -8,12 +8,13 @@ import time
  from concurrent.futures import ThreadPoolExecutor
  from io import IOBase
  from pathlib import Path
- from typing import IO, Any, Literal, Optional
+ from typing import IO, Any, Optional, TypedDict

  from anthropic import Anthropic
  from openai import OpenAI
  from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
  from pydantic import BaseModel
+ from pydantic_core import PydanticUndefined
  from tqdm import tqdm

  from .._resource import AsyncAPIResource, SyncAPIResource
@@ -24,6 +25,7 @@ from .._utils.json_schema import load_json_schema
  from ..types.chat import ChatCompletionUiformMessage
  from ..types.modalities import Modality
  from ..types.schemas.object import Schema
+ from ..types.browser_canvas import BrowserCanvas


  class FinetuningJSON(BaseModel):
@@ -31,7 +33,6 @@ class FinetuningJSON(BaseModel):


  FinetuningJSONL = list[FinetuningJSON]
- from typing import TypedDict


  class BatchJSONLResponseFormat(TypedDict):
@@ -106,9 +107,9 @@ class BatchJSONLResponse(BaseModel):

  class BaseDatasetsMixin:
  def _dump_training_set(self, training_set: list[dict[str, Any]], dataset_path: Path | str) -> None:
- with open(dataset_path, 'w', encoding='utf-8') as file:
+ with open(dataset_path, "w", encoding="utf-8") as file:
  for entry in training_set:
- file.write(json.dumps(entry) + '\n')
+ file.write(json.dumps(entry) + "\n")


  class Datasets(SyncAPIResource, BaseDatasetsMixin):
@@ -138,8 +139,8 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  json_schema: dict[str, Any] | Path | str,
  document_annotation_pairs_paths: list[dict[str, Path | str]],
  dataset_path: Path | str,
- image_resolution_dpi: int | None = None,
- browser_canvas: Literal['A3', 'A4', 'A5'] | None = None,
+ image_resolution_dpi: int = PydanticUndefined, # type: ignore[assignment]
+ browser_canvas: BrowserCanvas = PydanticUndefined, # type: ignore[assignment]
  modality: Modality = "native",
  ) -> None:
  """Save document-annotation pairs to a JSONL training set.
@@ -153,16 +154,18 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  json_schema = load_json_schema(json_schema)
  schema_obj = Schema(json_schema=json_schema)

- with open(dataset_path, 'w', encoding='utf-8') as file:
+ with open(dataset_path, "w", encoding="utf-8") as file:
  for pair_paths in tqdm(document_annotation_pairs_paths, desc="Processing pairs", position=0):
- document_message = self._client.documents.create_messages(document=pair_paths['document_fpath'], modality=modality, image_resolution_dpi=image_resolution_dpi, browser_canvas=browser_canvas)
+ document_message = self._client.documents.create_messages(
+ document=pair_paths["document_fpath"], modality=modality, image_resolution_dpi=image_resolution_dpi, browser_canvas=browser_canvas
+ )

- with open(pair_paths['annotation_fpath'], 'r') as f:
+ with open(pair_paths["annotation_fpath"], "r") as f:
  annotation = json.loads(f.read())
  assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}

  entry = {"messages": schema_obj.messages + document_message.messages + [assistant_message]}
- file.write(json.dumps(entry) + '\n')
+ file.write(json.dumps(entry) + "\n")

  def change_schema(
  self,
@@ -193,8 +196,8 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  assert isinstance(target_path, Path) or isinstance(target_path, str)

  # Use a temporary file to write the updated content
- with tempfile.NamedTemporaryFile('w', delete=False, encoding='utf-8') as temp_file:
- with open(input_dataset_path, 'r', encoding='utf-8') as infile:
+ with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as temp_file:
+ with open(input_dataset_path, "r", encoding="utf-8") as infile:
  for line in infile:
  entry = json.loads(line)
  messages = entry.get("messages", [])
@@ -208,7 +211,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  updated_entry = {"messages": updated_messages}

  # Write the updated entry to the temporary file
- temp_file.write(json.dumps(updated_entry) + '\n')
+ temp_file.write(json.dumps(updated_entry) + "\n")

  # Replace the original file with the temporary file
  shutil.move(temp_file.name, target_path)
@@ -241,19 +244,19 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  for pair_paths in tqdm(pairs_paths):
  document_messages: list[ChatCompletionUiformMessage] = []

- if isinstance(pair_paths['document_fpath'], str) or isinstance(pair_paths['document_fpath'], Path):
- document_message = self._client.documents.create_messages(document=pair_paths['document_fpath'], modality=modality)
+ if isinstance(pair_paths["document_fpath"], str) or isinstance(pair_paths["document_fpath"], Path):
+ document_message = self._client.documents.create_messages(document=pair_paths["document_fpath"], modality=modality)
  document_messages.extend(document_message.messages)

  else:
- assert isinstance(pair_paths['document_fpath'], list)
- for document_fpath in pair_paths['document_fpath']:
+ assert isinstance(pair_paths["document_fpath"], list)
+ for document_fpath in pair_paths["document_fpath"]:
  document_message = self._client.documents.create_messages(document=document_fpath, modality=modality)
  document_messages.extend(document_message.messages)

  # Use context manager to properly close the file
- assert isinstance(pair_paths['annotation_fpath'], Path) or isinstance(pair_paths['annotation_fpath'], str)
- with open(pair_paths['annotation_fpath'], 'r') as f:
+ assert isinstance(pair_paths["annotation_fpath"], Path) or isinstance(pair_paths["annotation_fpath"], str)
+ with open(pair_paths["annotation_fpath"], "r") as f:
  annotation = json.loads(f.read())
  assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}

@@ -389,6 +392,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  raise ValueError(f"Invalid file path: {doc_path}")
  hash_str = hashlib.md5(doc_path.as_posix().encode()).hexdigest()
  elif isinstance(doc, IO):
+ doc_path = Path(doc.name) or "unknown_file"
  file_bytes = doc.read()
  hash_str = hashlib.md5(file_bytes).hexdigest()
  doc.seek(0)
@@ -408,7 +412,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  annotation_path = Path(root_dir) / f"annotations_{hash_str}.json"
  annotation_path.parent.mkdir(parents=True, exist_ok=True)

- with open(annotation_path, 'w', encoding='utf-8') as f:
+ with open(annotation_path, "w", encoding="utf-8") as f:
  json.dump(string_json, f, ensure_ascii=False, indent=2)

  return {"document_fpath": str(doc_path), "annotation_fpath": str(annotation_path)}
@@ -442,176 +446,176 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  # Generate final training set from all results
  self.save(json_schema=json_schema, document_annotation_pairs_paths=pairs_paths, dataset_path=dataset_path)

- def eval(
- self,
- json_schema: dict[str, Any] | Path | str,
- dataset_path: str | Path,
- model: str = "gpt-4o-2024-08-06",
- temperature: float = 0.0,
- batch_size: int = 5,
- max_concurrent: int = 3,
- display: bool = True,
- ) -> ComparisonMetrics:
- """Evaluate model performance on a test dataset.
-
- Args:
- json_schema: JSON schema defining the expected data structure
- dataset_path: Path to the JSONL file containing test examples
- model: The model to use for benchmarking
- temperature: Model temperature setting (0-1)
- batch_size: Number of examples to process in each batch
- max_concurrent: Maximum number of concurrent API calls
- """
-
- json_schema = load_json_schema(json_schema)
- assert_valid_model_extraction(model)
- schema_obj = Schema(json_schema=json_schema)
-
- # Initialize appropriate client
- client, provider = self._initialize_model_client(model)
-
- # Read all lines from the JSONL file
- with open(dataset_path, 'r') as f:
- lines = [json.loads(line) for line in f]
-
- extraction_analyses: list[ExtractionAnalysis] = []
- total_batches = (len(lines) + batch_size - 1) // batch_size
-
- # Create main progress bar for batches
- batch_pbar = tqdm(total=total_batches, desc="Processing batches", position=0)
-
- # Track running metrics
- class RunningMetrics(BaseModel):
- model: str
- accuracy: float
- levenshtein: float
- jaccard: float
- false_positive: float
- mismatched: float
- processed: int
-
- running_metrics: RunningMetrics = RunningMetrics(
- model=model,
- accuracy=0.0,
- levenshtein=0.0,
- jaccard=0.0,
- false_positive=0.0,
- mismatched=0.0,
- processed=0, # number of processed examples - used in the loop to compute the running averages
- )
-
- def update_running_metrics(analysis: ExtractionAnalysis) -> None:
- comparison = normalized_comparison_metrics([analysis])
- running_metrics.processed += 1
- n = running_metrics.processed
- # Update running averages
- running_metrics.accuracy = (running_metrics.accuracy * (n - 1) + comparison.accuracy) / n
- running_metrics.levenshtein = (running_metrics.levenshtein * (n - 1) + comparison.levenshtein_similarity) / n
- running_metrics.jaccard = (running_metrics.jaccard * (n - 1) + comparison.jaccard_similarity) / n
- running_metrics.false_positive = (running_metrics.false_positive * (n - 1) + comparison.false_positive_rate) / n
- running_metrics.mismatched = (running_metrics.mismatched * (n - 1) + comparison.mismatched_value_rate) / n
- # Update progress bar description
- batch_pbar.set_description(
- f"Processing batches | Model: {running_metrics.model} | Acc: {running_metrics.accuracy:.2f} | "
- f"Lev: {running_metrics.levenshtein:.2f} | "
- f"IOU: {running_metrics.jaccard:.2f} | "
- f"FP: {running_metrics.false_positive:.2f} | "
- f"Mism: {running_metrics.mismatched:.2f}"
- )
-
- def process_example(jsonline: dict) -> ExtractionAnalysis | None:
- line_number = jsonline['line_number']
- try:
- messages = jsonline['messages']
- ground_truth = json.loads(messages[-1]['content'])
- inference_messages = messages[:-1]
-
- # Use _get_model_completion instead of duplicating provider-specific logic
- string_json = self._get_model_completion(client=client, provider=provider, model=model, temperature=temperature, messages=inference_messages, schema_obj=schema_obj)
-
- prediction = json.loads(string_json)
- analysis = ExtractionAnalysis(
- ground_truth=ground_truth,
- prediction=prediction,
- )
- update_running_metrics(analysis)
- return analysis
- except Exception as e:
- print(f"\nWarning: Failed to process line number {line_number}: {str(e)}")
- return None
-
- with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
- # Split entries into batches
- for batch_idx in range(0, len(lines), batch_size):
- batch = lines[batch_idx : batch_idx + batch_size]
-
- # Submit and process batch
- futures = [executor.submit(process_example, entry | {"line_number": batch_idx * batch_size + i}) for i, entry in enumerate(batch)]
- for future in futures:
- result = future.result()
- if result is not None:
- extraction_analyses.append(result)
-
- batch_pbar.update(1)
-
- batch_pbar.close()
-
- # Analyze error patterns across all examples
- analysis = normalized_comparison_metrics(extraction_analyses)
-
- if display:
- plot_comparison_metrics(analysis=analysis, top_n=10)
-
- return analysis
-
- def benchmark(
- self,
- json_schema: dict[str, Any] | Path | str,
- dataset_path: str | Path,
- models: list[str],
- temperature: float = 0.0,
- batch_size: int = 5,
- max_concurrent: int = 3,
- print: bool = True,
- verbose: bool = False,
- ) -> list[BenchmarkMetrics]:
- """Benchmark multiple models on a test dataset.
-
- Args:
- json_schema: JSON schema defining the expected data structure
- dataset_path: Path to the JSONL file containing test examples
- models: List of models to benchmark
- temperature: Model temperature setting (0-1)
- batch_size: Number of examples to process in each batch
- max_concurrent: Maximum number of concurrent API calls
- print: Whether to print the metrics
- verbose: Whether to print all the metrics of all the function calls
-
- Returns:
- Dictionary mapping model names to their evaluation metrics
- """
- results: list[BenchmarkMetrics] = []
-
- for model in models:
- metrics: ComparisonMetrics = self.eval(
- json_schema=json_schema, dataset_path=dataset_path, model=model, temperature=temperature, batch_size=batch_size, max_concurrent=max_concurrent, display=verbose
- )
- results.append(
- BenchmarkMetrics(
- ai_model=model,
- accuracy=metrics.accuracy,
- levenshtein_similarity=metrics.levenshtein_similarity,
- jaccard_similarity=metrics.jaccard_similarity,
- false_positive_rate=metrics.false_positive_rate,
- false_negative_rate=metrics.false_negative_rate,
- mismatched_value_rate=metrics.mismatched_value_rate,
- )
- )
-
- if print:
- display_benchmark_metrics(results)
-
- return results
+ # def eval(
+ # self,
+ # json_schema: dict[str, Any] | Path | str,
+ # dataset_path: str | Path,
+ # model: str = "gpt-4o-2024-08-06",
+ # temperature: float = 0.0,
+ # batch_size: int = 5,
+ # max_concurrent: int = 3,
+ # display: bool = True,
+ # ) -> ComparisonMetrics:
+ # """Evaluate model performance on a test dataset.
+
+ # Args:
+ # json_schema: JSON schema defining the expected data structure
+ # dataset_path: Path to the JSONL file containing test examples
+ # model: The model to use for benchmarking
+ # temperature: Model temperature setting (0-1)
+ # batch_size: Number of examples to process in each batch
+ # max_concurrent: Maximum number of concurrent API calls
+ # """
+
+ # json_schema = load_json_schema(json_schema)
+ # assert_valid_model_extraction(model)
+ # schema_obj = Schema(json_schema=json_schema)
+
+ # # Initialize appropriate client
+ # client, provider = self._initialize_model_client(model)
+
+ # # Read all lines from the JSONL file
+ # with open(dataset_path, "r") as f:
+ # lines = [json.loads(line) for line in f]
+
+ # extraction_analyses: list[ExtractionAnalysis] = []
+ # total_batches = (len(lines) + batch_size - 1) // batch_size
+
+ # # Create main progress bar for batches
+ # batch_pbar = tqdm(total=total_batches, desc="Processing batches", position=0)
+
+ # # Track running metrics
+ # class RunningMetrics(BaseModel):
+ # model: str
+ # accuracy: float
+ # levenshtein: float
+ # jaccard: float
+ # false_positive: float
+ # mismatched: float
+ # processed: int
+
+ # running_metrics: RunningMetrics = RunningMetrics(
+ # model=model,
+ # accuracy=0.0,
+ # levenshtein=0.0,
+ # jaccard=0.0,
+ # false_positive=0.0,
+ # mismatched=0.0,
+ # processed=0, # number of processed examples - used in the loop to compute the running averages
+ # )
+
+ # # def update_running_metrics(analysis: ExtractionAnalysis) -> None:
+ # # comparison = normalized_comparison_metrics([analysis])
+ # # running_metrics.processed += 1
+ # # n = running_metrics.processed
+ # # # Update running averages
+ # # running_metrics.accuracy = (running_metrics.accuracy * (n - 1) + comparison.accuracy) / n
+ # # running_metrics.levenshtein = (running_metrics.levenshtein * (n - 1) + comparison.levenshtein_similarity) / n
+ # # running_metrics.jaccard = (running_metrics.jaccard * (n - 1) + comparison.jaccard_similarity) / n
+ # # running_metrics.false_positive = (running_metrics.false_positive * (n - 1) + comparison.false_positive_rate) / n
+ # # running_metrics.mismatched = (running_metrics.mismatched * (n - 1) + comparison.mismatched_value_rate) / n
+ # # # Update progress bar description
+ # # batch_pbar.set_description(
+ # # f"Processing batches | Model: {running_metrics.model} | Acc: {running_metrics.accuracy:.2f} | "
+ # # f"Lev: {running_metrics.levenshtein:.2f} | "
+ # # f"IOU: {running_metrics.jaccard:.2f} | "
+ # # f"FP: {running_metrics.false_positive:.2f} | "
+ # # f"Mism: {running_metrics.mismatched:.2f}"
+ # # )

+ # # def process_example(jsonline: dict) -> ExtractionAnalysis | None:
+ # # line_number = jsonline["line_number"]
+ # # try:
+ # # messages = jsonline["messages"]
+ # # ground_truth = json.loads(messages[-1]["content"])
+ # # inference_messages = messages[:-1]
+
+ # # # Use _get_model_completion instead of duplicating provider-specific logic
+ # # string_json = self._get_model_completion(client=client, provider=provider, model=model, temperature=temperature, messages=inference_messages, schema_obj=schema_obj)
+
+ # # prediction = json.loads(string_json)
+ # # analysis = ExtractionAnalysis(
+ # # ground_truth=ground_truth,
+ # # prediction=prediction,
+ # # )
+ # # update_running_metrics(analysis)
+ # # return analysis
+ # # except Exception as e:
+ # # print(f"\nWarning: Failed to process line number {line_number}: {str(e)}")
+ # # return None
+
+ # # with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
+ # # # Split entries into batches
+ # # for batch_idx in range(0, len(lines), batch_size):
+ # # batch = lines[batch_idx : batch_idx + batch_size]
+
+ # # # Submit and process batch
+ # # futures = [executor.submit(process_example, entry | {"line_number": batch_idx * batch_size + i}) for i, entry in enumerate(batch)]
+ # # for future in futures:
+ # # result = future.result()
+ # # if result is not None:
+ # # extraction_analyses.append(result)
+
+ # # batch_pbar.update(1)
+
+ # # batch_pbar.close()
+
+ # # # Analyze error patterns across all examples
+ # # analysis = normalized_comparison_metrics(extraction_analyses)
+
+ # # if display:
+ # # plot_comparison_metrics(analysis=analysis, top_n=10)
+
+ # # return analysis
+
+ # def benchmark(
+ # self,
+ # json_schema: dict[str, Any] | Path | str,
+ # dataset_path: str | Path,
+ # models: list[str],
+ # temperature: float = 0.0,
+ # batch_size: int = 5,
+ # max_concurrent: int = 3,
+ # print: bool = True,
+ # verbose: bool = False,
+ # ) -> list[BenchmarkMetrics]:
+ # """Benchmark multiple models on a test dataset.
+
+ # Args:
+ # json_schema: JSON schema defining the expected data structure
+ # dataset_path: Path to the JSONL file containing test examples
+ # models: List of models to benchmark
+ # temperature: Model temperature setting (0-1)
+ # batch_size: Number of examples to process in each batch
+ # max_concurrent: Maximum number of concurrent API calls
+ # print: Whether to print the metrics
+ # verbose: Whether to print all the metrics of all the function calls
+
+ # Returns:
+ # Dictionary mapping model names to their evaluation metrics
+ # """
+ # results: list[BenchmarkMetrics] = []
+
+ # for model in models:
+ # metrics: ComparisonMetrics = self.eval(
+ # json_schema=json_schema, dataset_path=dataset_path, model=model, temperature=temperature, batch_size=batch_size, max_concurrent=max_concurrent, display=verbose
+ # )
+ # results.append(
+ # BenchmarkMetrics(
+ # ai_model=model,
+ # accuracy=metrics.accuracy,
+ # levenshtein_similarity=metrics.levenshtein_similarity,
+ # jaccard_similarity=metrics.jaccard_similarity,
+ # false_positive_rate=metrics.false_positive_rate,
+ # false_negative_rate=metrics.false_negative_rate,
+ # mismatched_value_rate=metrics.mismatched_value_rate,
+ # )
+ # )
+
+ # if print:
+ # display_benchmark_metrics(results)
+
+ # return results

  def update_annotations(
  self,
@@ -642,7 +646,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  client, provider = self._initialize_model_client(model)

  # Read all lines from the JSONL file
- with open(old_dataset_path, 'r') as f:
+ with open(old_dataset_path, "r") as f:
  lines = [json.loads(line) for line in f]

  updated_entries = []
@@ -651,13 +655,13 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  batch_pbar = tqdm(total=total_batches, desc="Processing batches", position=0)

  def process_entry(entry: dict) -> dict:
- messages = entry['messages']
+ messages = entry["messages"]
  system_message, user_messages, assistant_messages = separate_messages(messages)
  system_and_user_messages = messages[:-1]

  previous_annotation_message: ChatCompletionUiformMessage = {
  "role": "user",
- "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + messages[-1]['content'],
+ "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + messages[-1]["content"],
  }

  string_json = self._get_model_completion(
@@ -691,9 +695,9 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):

  batch_pbar.close()

- with open(new_dataset_path, 'w') as f:
+ with open(new_dataset_path, "w") as f:
  for entry in updated_entries:
- f.write(json.dumps(entry) + '\n')
+ f.write(json.dumps(entry) + "\n")

  #########################
  ##### BATCH METHODS #####
@@ -722,7 +726,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  schema_obj = Schema(json_schema=loaded_json_schema)
  assert_valid_model_extraction(model)

- with open(batch_requests_path, 'w', encoding='utf-8') as f:
+ with open(batch_requests_path, "w", encoding="utf-8") as f:
  for i, doc in tqdm(enumerate(documents)):
  # Create document messages
  doc_msg = self._client.documents.create_messages(
@@ -744,7 +748,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  }

  # Write the request as a JSON line
- f.write(json.dumps(request) + '\n')
+ f.write(json.dumps(request) + "\n")

  def save_batch_update_annotation_requests(
  self,
@@ -768,18 +772,18 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  schema_obj = Schema(json_schema=loaded_json_schema)

  # Read existing annotations
- with open(old_dataset_path, 'r') as f:
+ with open(old_dataset_path, "r") as f:
  entries = [json.loads(line) for line in f]

  # Create new JSONL with update requests
- with open(batch_requests_path, 'w', encoding='utf-8') as f:
+ with open(batch_requests_path, "w", encoding="utf-8") as f:
  for i, entry in enumerate(entries):
- existing_messages = entry['messages']
+ existing_messages = entry["messages"]
  system_and_user_messages = existing_messages[:-1]

  previous_annotation_message: ChatCompletionMessageParam = {
  "role": "user",
- "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + existing_messages[-1]['content'],
+ "content": "Here is an old annotation using a different schema. Use it as a reference to update the annotation: " + existing_messages[-1]["content"],
  }

  # Construct the request object
@@ -798,7 +802,7 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  request: BatchJSONL = {"custom_id": f"request-{i}", "method": "POST", "url": "/v1/chat/completions", "body": body}

  # Write the request as a JSON line
- f.write(json.dumps(request) + '\n')
+ f.write(json.dumps(request) + "\n")

  def build_dataset_from_batch_results(
  self,
@@ -806,27 +810,27 @@ class Datasets(SyncAPIResource, BaseDatasetsMixin):
  batch_results_path: str | Path,
  dataset_results_path: str | Path,
  ) -> None:
- with open(batch_requests_path, 'r') as f:
+ with open(batch_requests_path, "r") as f:
  input_lines: list[BatchJSONL] = [json.loads(line) for line in f]
- with open(batch_results_path, 'r') as f:
- batch_results_lines: list[BatchJSONLResponse] = [json.loads(line) for line in f]
+ with open(batch_results_path, "r") as f:
+ batch_results_lines: list[BatchJSONLResponse] = [BatchJSONLResponse.model_validate_json(line) for line in f]

  assert len(input_lines) == len(batch_results_lines), "Input and batch results must have the same number of lines"

  for input_line, batch_result in zip(input_lines, batch_results_lines):
- messages = input_line['body']['messages']
+ messages = input_line["body"]["messages"]

  # Filter out messages containing the old annotation reference to remove messages that come from "update annotation"
- if isinstance(messages[-1].get('content'), str):
- if re.search(r'Here is an old annotation using a different schema\. Use it as a reference to update the annotation:', str(messages[-1].get('content', ''))):
+ if isinstance(messages[-1].get("content"), str):
+ if re.search(r"Here is an old annotation using a different schema\. Use it as a reference to update the annotation:", str(messages[-1].get("content", ""))):
  print("found keyword")
- input_line['body']['messages'] = messages[:-1]
+ input_line["body"]["messages"] = messages[:-1]

- input_line['body']['messages'].append(batch_result['response']['body']['choices'][0]['message'])
+ input_line["body"]["messages"].append(batch_result.response.body.choices[0].message)

- with open(dataset_results_path, 'w') as f:
+ with open(dataset_results_path, "w") as f:
  for input_line in input_lines:
- f.write(json.dumps({'messages': input_line['body']['messages']}) + '\n')
+ f.write(json.dumps({"messages": input_line["body"]["messages"]}) + "\n")

  print(f"Dataset saved to {dataset_results_path}")

@@ -849,9 +853,9 @@ class AsyncDatasets(AsyncAPIResource, BaseDatasetsMixin):
  training_set = []

  for pair_paths in tqdm(pairs_paths):
- document_message = await self._client.documents.create_messages(document=pair_paths['document_fpath'], modality=modality)
+ document_message = await self._client.documents.create_messages(document=pair_paths["document_fpath"], modality=modality)

- with open(pair_paths['annotation_fpath'], 'r') as f:
+ with open(pair_paths["annotation_fpath"], "r") as f:
  annotation = json.loads(f.read())
  assistant_message = {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False, indent=2)}

@@ -917,7 +921,7 @@ class AsyncDatasets(AsyncAPIResource, BaseDatasetsMixin):

  annotation_path.parent.mkdir(parents=True, exist_ok=True)

- with open(annotation_path, 'w', encoding='utf-8') as f:
+ with open(annotation_path, "w", encoding="utf-8") as f:
  json.dump(result.choices[0].message.content, f, ensure_ascii=False, indent=2)

  return {"document_fpath": str(doc_path), "annotation_fpath": str(annotation_path)}
@@ -954,7 +958,7 @@ class AsyncDatasets(AsyncAPIResource, BaseDatasetsMixin):

  annotation_path.parent.mkdir(parents=True, exist_ok=True)

- with open(annotation_path, 'w', encoding='utf-8') as f:
+ with open(annotation_path, "w", encoding="utf-8") as f:
  json.dump(result.choices[0].message.content, f, ensure_ascii=False, indent=2)

  return {