retab 0.0.68__py3-none-any.whl → 0.0.70__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- retab/client.py +3 -1
- retab/resources/documents/client.py +44 -138
- retab/resources/extractions/__init__.py +3 -0
- retab/resources/extractions/client.py +288 -0
- retab/resources/projects/client.py +7 -1
- retab/resources/schemas.py +0 -8
- retab/types/documents/create_messages.py +10 -12
- retab/types/documents/extract.py +16 -81
- retab/types/documents/parse.py +0 -2
- retab/types/extractions/__init__.py +0 -0
- retab/types/extractions/types.py +3 -0
- retab/types/inference_settings.py +6 -4
- retab/types/mime.py +4 -38
- retab/types/pagination.py +8 -0
- retab/types/projects/model.py +49 -36
- retab/types/schemas/generate.py +0 -4
- {retab-0.0.68.dist-info → retab-0.0.70.dist-info}/METADATA +1 -1
- {retab-0.0.68.dist-info → retab-0.0.70.dist-info}/RECORD +20 -18
- retab/client copy.py +0 -693
- retab/types/browser_canvas.py +0 -3
- {retab-0.0.68.dist-info → retab-0.0.70.dist-info}/WHEEL +0 -0
- {retab-0.0.68.dist-info → retab-0.0.70.dist-info}/top_level.txt +0 -0
retab/resources/extractions/client.py
ADDED
@@ -0,0 +1,288 @@
+import json
+from datetime import datetime
+from typing import Any, Dict, List, Literal
+
+from ..._resource import AsyncAPIResource, SyncAPIResource
+from ...types.standards import PreparedRequest
+from ...types.pagination import PaginatedList, PaginationOrder
+from ...types.extractions.types import HumanReviewStatus
+
+class ExtractionsMixin:
+    def prepare_list(
+        self,
+        before: str | None = None,
+        after: str | None = None,
+        limit: int = 10,
+        order: PaginationOrder = "desc",
+        origin_dot_type: str | None = None,
+        origin_dot_id: str | None = None,
+        from_date: datetime | None = None,
+        to_date: datetime | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        metadata: Dict[str, str] | None = None,
+        **extra_params: Any,
+    ) -> PreparedRequest:
+        """Prepare a request to list extractions with pagination and filtering."""
+        params = {
+            "before": before,
+            "after": after,
+            "limit": limit,
+            "order": order,
+            "origin_dot_type": origin_dot_type,
+            "origin_dot_id": origin_dot_id,
+            "from_date": from_date.isoformat() if from_date else None,
+            "to_date": to_date.isoformat() if to_date else None,
+            "human_review_status": human_review_status,
+            # Note: metadata must be JSON-serialized as the backend expects a JSON string
+            "metadata": json.dumps(metadata) if metadata else None,
+        }
+        if extra_params:
+            params.update(extra_params)
+        # Remove None values
+        params = {k: v for k, v in params.items() if v is not None}
+        return PreparedRequest(method="GET", url="/v1/extractions", params=params)
+
+    def prepare_download(
+        self,
+        order: Literal["asc", "desc"] = "desc",
+        origin_dot_id: str | None = None,
+        from_date: datetime | None = None,
+        to_date: datetime | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        metadata: Dict[str, str] | None = None,
+        format: Literal["jsonl", "csv", "xlsx"] = "jsonl",
+        **extra_params: Any,
+    ) -> PreparedRequest:
+        """Prepare a request to download extractions in various formats."""
+        params = {
+            "order": order,
+            "origin_dot_id": origin_dot_id,
+            "from_date": from_date.isoformat() if from_date else None,
+            "to_date": to_date.isoformat() if to_date else None,
+            "human_review_status": human_review_status,
+            # Note: metadata must be JSON-serialized as the backend expects a JSON string
+            "metadata": json.dumps(metadata) if metadata else None,
+            "format": format,
+        }
+        if extra_params:
+            params.update(extra_params)
+        params = {k: v for k, v in params.items() if v is not None}
+        return PreparedRequest(method="GET", url="/v1/extractions/download", params=params)
+
+    def prepare_update(
+        self,
+        extraction_id: str,
+        predictions: dict[str, Any] | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        json_schema: dict[str, Any] | None = None,
+        inference_settings: dict[str, Any] | None = None,
+        **extra_body: Any,
+    ) -> PreparedRequest:
+        """Prepare a request to update an extraction."""
+        data: dict[str, Any] = {}
+        if predictions is not None:
+            data["predictions"] = predictions
+        if human_review_status is not None:
+            data["human_review_status"] = human_review_status
+        if json_schema is not None:
+            data["json_schema"] = json_schema
+        if inference_settings is not None:
+            data["inference_settings"] = inference_settings
+        if extra_body:
+            data.update(extra_body)
+        return PreparedRequest(method="PATCH", url=f"/v1/extractions/{extraction_id}", data=data)
+
+    def prepare_get(self, extraction_id: str) -> PreparedRequest:
+        """Prepare a request to get an extraction by ID."""
+        return PreparedRequest(method="GET", url=f"/v1/extractions/{extraction_id}")
+
+    def prepare_delete(self, extraction_id: str) -> PreparedRequest:
+        """Prepare a request to delete an extraction by ID."""
+        return PreparedRequest(method="DELETE", url=f"/v1/extractions/{extraction_id}")
+
+class Extractions(SyncAPIResource, ExtractionsMixin):
+    """Extractions API wrapper"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def list(
+        self,
+        before: str | None = None,
+        after: str | None = None,
+        limit: int = 10,
+        order: PaginationOrder = "desc",
+        origin_dot_type: str | None = None,
+        origin_dot_id: str | None = None,
+        from_date: datetime | None = None,
+        to_date: datetime | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        metadata: Dict[str, str] | None = None,
+        **extra_params: Any,
+    ) -> PaginatedList:
+        """List extractions with pagination and filtering."""
+        request = self.prepare_list(
+            before=before,
+            after=after,
+            limit=limit,
+            order=order,
+            origin_dot_type=origin_dot_type,
+            origin_dot_id=origin_dot_id,
+            from_date=from_date,
+            to_date=to_date,
+            human_review_status=human_review_status,
+            metadata=metadata,
+            **extra_params,
+        )
+        response = self._client._prepared_request(request)
+        return PaginatedList(**response)
+
+
+    def download(
+        self,
+        order: Literal["asc", "desc"] = "desc",
+        origin_dot_id: str | None = None,
+        from_date: datetime | None = None,
+        to_date: datetime | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        metadata: Dict[str, str] | None = None,
+        format: Literal["jsonl", "csv", "xlsx"] = "jsonl",
+        **extra_params: Any,
+    ) -> dict[str, Any]:
+        """Download extractions in various formats. Returns download_url, filename, and expires_at."""
+        request = self.prepare_download(
+            order=order,
+            origin_dot_id=origin_dot_id,
+            from_date=from_date,
+            to_date=to_date,
+            human_review_status=human_review_status,
+            metadata=metadata,
+            format=format,
+            **extra_params,
+        )
+        return self._client._prepared_request(request)
+
+
+    def update(
+        self,
+        extraction_id: str,
+        predictions: dict[str, Any] | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        json_schema: dict[str, Any] | None = None,
+        inference_settings: dict[str, Any] | None = None,
+        **extra_body: Any,
+    ) -> dict[str, Any]:
+        """Update an extraction."""
+        request = self.prepare_update(
+            extraction_id=extraction_id,
+            predictions=predictions,
+            human_review_status=human_review_status,
+            json_schema=json_schema,
+            inference_settings=inference_settings,
+            **extra_body,
+        )
+        response = self._client._prepared_request(request)
+        return response
+
+    def get(self, extraction_id: str) -> dict[str, Any]:
+        """Get an extraction by ID."""
+        request = self.prepare_get(extraction_id)
+        return self._client._prepared_request(request)
+
+    def delete(self, extraction_id: str) -> None:
+        """Delete an extraction by ID."""
+        request = self.prepare_delete(extraction_id)
+        self._client._prepared_request(request)
+
+
+class AsyncExtractions(AsyncAPIResource, ExtractionsMixin):
+    """Async Extractions API wrapper"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    async def list(
+        self,
+        before: str | None = None,
+        after: str | None = None,
+        limit: int = 10,
+        order: PaginationOrder = "desc",
+        origin_dot_type: str | None = None,
+        origin_dot_id: str | None = None,
+        from_date: datetime | None = None,
+        to_date: datetime | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        metadata: Dict[str, str] | None = None,
+        **extra_params: Any,
+    ) -> PaginatedList:
+        """List extractions with pagination and filtering."""
+        request = self.prepare_list(
+            before=before,
+            after=after,
+            limit=limit,
+            order=order,
+            origin_dot_type=origin_dot_type,
+            origin_dot_id=origin_dot_id,
+            from_date=from_date,
+            to_date=to_date,
+            human_review_status=human_review_status,
+            metadata=metadata,
+            **extra_params,
+        )
+        response = await self._client._prepared_request(request)
+        return PaginatedList(**response)
+
+    async def download(
+        self,
+        order: Literal["asc", "desc"] = "desc",
+        origin_dot_id: str | None = None,
+        from_date: datetime | None = None,
+        to_date: datetime | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        metadata: Dict[str, str] | None = None,
+        format: Literal["jsonl", "csv", "xlsx"] = "jsonl",
+        **extra_params: Any,
+    ) -> dict[str, Any]:
+        """Download extractions in various formats. Returns download_url, filename, and expires_at."""
+        request = self.prepare_download(
+            order=order,
+            origin_dot_id=origin_dot_id,
+            from_date=from_date,
+            to_date=to_date,
+            human_review_status=human_review_status,
+            metadata=metadata,
+            format=format,
+            **extra_params,
+        )
+        return await self._client._prepared_request(request)
+
+    async def update(
+        self,
+        extraction_id: str,
+        predictions: dict[str, Any] | None = None,
+        human_review_status: HumanReviewStatus | None = None,
+        json_schema: dict[str, Any] | None = None,
+        inference_settings: dict[str, Any] | None = None,
+        **extra_body: Any,
+    ) -> dict[str, Any]:
+        """Update an extraction."""
+        request = self.prepare_update(
+            extraction_id=extraction_id,
+            predictions=predictions,
+            human_review_status=human_review_status,
+            json_schema=json_schema,
+            inference_settings=inference_settings,
+            **extra_body,
+        )
+        response = await self._client._prepared_request(request)
+        return response
+
+    async def get(self, extraction_id: str) -> dict[str, Any]:
+        """Get an extraction by ID."""
+        request = self.prepare_get(extraction_id)
+        return await self._client._prepared_request(request)
+
+    async def delete(self, extraction_id: str) -> None:
+        """Delete an extraction by ID."""
+        request = self.prepare_delete(extraction_id)
+        await self._client._prepared_request(request)
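The new resource follows the SDK's existing prepare/execute split: ExtractionsMixin builds PreparedRequest objects, and the Extractions/AsyncExtractions wrappers send them through the client. A minimal usage sketch, assuming the top-level client is exposed as Retab and mounts this resource at client.extractions; only the endpoint paths and parameters above are confirmed by the diff, while the attribute name, the "pending"/"approved" status values, and the extraction ID are illustrative:

from datetime import datetime, timedelta

from retab import Retab  # assumed top-level client

client = Retab()  # assumes the API key is picked up from the environment

# List the ten most recent extractions from the last week that still await review.
page = client.extractions.list(
    limit=10,
    order="desc",
    from_date=datetime.now() - timedelta(days=7),
    human_review_status="pending",  # hypothetical HumanReviewStatus value
    metadata={"batch": "invoices-2024"},  # serialized to a JSON string by prepare_list
)

# Export the same slice as CSV; the response carries a signed download URL.
export = client.extractions.download(format="csv", metadata={"batch": "invoices-2024"})
print(export["download_url"], export["expires_at"])

# Patch one extraction, then remove it.
client.extractions.update("extr_123", human_review_status="approved")  # hypothetical ID and value
client.extractions.delete("extr_123")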
retab/resources/projects/client.py
CHANGED
@@ -1,4 +1,5 @@
 import base64
+import json
 from io import IOBase
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence
@@ -89,6 +90,8 @@ class ProjectsMixin:
         n_consensus: int | None = None,
         seed: int | None = None,
         store: bool = True,
+        metadata: Dict[str, str] | None = None,
+        extraction_id: str | None = None,
         **extra_form: Any,
     ) -> PreparedRequest:
         """Prepare a request to extract documents from a project.
@@ -104,7 +107,7 @@
             n_consensus: Optional number of consensus extractions
             store: Whether to store the results
             seed: Optional seed for reproducibility
-
+            metadata: User-defined metadata for the extraction

         Returns:
             PreparedRequest: The prepared request
@@ -117,6 +120,7 @@
            raise ValueError("Provide either 'document' (single) or 'documents' (multiple), not both")

        # Prepare form data parameters
+       # Note: metadata must be JSON-serialized since httpx multipart forms only accept primitive types
        form_data = {
            "model": model,
            "temperature": temperature,
@@ -124,6 +128,8 @@
            "n_consensus": n_consensus,
            "seed": seed,
            "store": store,
+           "metadata": json.dumps(metadata) if metadata else None,
+           "extraction_id": extraction_id,
        }
        if extra_form:
            form_data.update(extra_form)
retab/resources/schemas.py
CHANGED
@@ -9,7 +9,6 @@ from .._resource import AsyncAPIResource, SyncAPIResource
 from ..utils.mime import prepare_mime_document_list
 from ..types.mime import MIMEData
 from ..types.schemas.generate import GenerateSchemaRequest
-from ..types.browser_canvas import BrowserCanvas
 from ..types.standards import PreparedRequest, FieldUnset


@@ -22,7 +21,6 @@ class SchemasMixin:
         temperature: float = FieldUnset,
         reasoning_effort: ChatCompletionReasoningEffort = FieldUnset,
         image_resolution_dpi: int = FieldUnset,
-        browser_canvas: BrowserCanvas = FieldUnset,
         **extra_body: Any,
     ) -> PreparedRequest:
         mime_documents = prepare_mime_document_list(documents)
@@ -40,8 +38,6 @@ class SchemasMixin:
             body["reasoning_effort"] = reasoning_effort
         if image_resolution_dpi is not FieldUnset:
             body["image_resolution_dpi"] = image_resolution_dpi
-        if browser_canvas is not FieldUnset:
-            body["browser_canvas"] = browser_canvas
         if extra_body:
             body.update(extra_body)

@@ -59,7 +55,6 @@ class Schemas(SyncAPIResource, SchemasMixin):
         temperature: float = FieldUnset,
         reasoning_effort: ChatCompletionReasoningEffort = FieldUnset,
         image_resolution_dpi: int = FieldUnset,
-        browser_canvas: BrowserCanvas = FieldUnset,
         **extra_body: Any,
     ) -> dict[str, Any]:
         """
@@ -84,7 +79,6 @@ class Schemas(SyncAPIResource, SchemasMixin):
             temperature=temperature,
             reasoning_effort=reasoning_effort,
             image_resolution_dpi=image_resolution_dpi,
-            browser_canvas=browser_canvas,
             **extra_body,
         )
         response = self._client._prepared_request(prepared_request)
@@ -100,7 +94,6 @@ class AsyncSchemas(AsyncAPIResource, SchemasMixin):
         temperature: float = FieldUnset,
         reasoning_effort: ChatCompletionReasoningEffort = FieldUnset,
         image_resolution_dpi: int = FieldUnset,
-        browser_canvas: BrowserCanvas = FieldUnset,
         **extra_body: Any,
     ) -> dict[str, Any]:
         """
@@ -125,7 +118,6 @@ class AsyncSchemas(AsyncAPIResource, SchemasMixin):
             temperature=temperature,
             reasoning_effort=reasoning_effort,
             image_resolution_dpi=image_resolution_dpi,
-            browser_canvas=browser_canvas,
             **extra_body,
         )
         response = await self._client._prepared_request(prepared_request)
retab/types/documents/create_messages.py
CHANGED
@@ -9,7 +9,6 @@ from pydantic import BaseModel, ConfigDict, Field, computed_field
 from ...utils.display import count_image_tokens, count_text_tokens
 from ..chat import ChatCompletionRetabMessage
 from ..mime import MIMEData
-from ..browser_canvas import BrowserCanvas
 MediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"]


@@ -23,9 +22,6 @@ class DocumentCreateMessageRequest(BaseModel):
     model_config = ConfigDict(extra="ignore")
     document: MIMEData = Field(description="The document to load.")
     image_resolution_dpi: int = Field(default=192, description="Resolution of the image sent to the LLM")
-    browser_canvas: BrowserCanvas = Field(
-        default="A4", description="Sets the size of the browser canvas for rendering documents in browser-based processing. Choose a size that matches the document type."
-    )
     model: str = Field(default="gemini-2.5-flash", description="The model to use for the document.")

 class DocumentCreateInputRequest(DocumentCreateMessageRequest):
@@ -55,11 +51,12 @@ class DocumentMessage(BaseModel):
         for msg in self.messages:
             role = msg.get("role", "user")
             msg_tokens = 0
+            content = msg.get("content")

-            if isinstance(msg["content"], str):
-                msg_tokens = count_text_tokens(msg["content"])
-            elif isinstance(msg["content"], list):
-                for content_item in msg["content"]:
+            if isinstance(content, str):
+                msg_tokens = count_text_tokens(content)
+            elif isinstance(content, list):
+                for content_item in content:
                     if isinstance(content_item, str):
                         msg_tokens += count_text_tokens(content_item)
                     elif isinstance(content_item, dict):
@@ -104,11 +101,12 @@
         results: list[str | PIL.Image.Image] = []

         for msg in self.messages:
-            if isinstance(msg["content"], str):
-                results.append(msg["content"])
+            content = msg.get("content")
+            if isinstance(content, str):
+                results.append(content)
                 continue
-            assert isinstance(msg["content"], list), "content must be a list or a string"
-            for content_item in msg["content"]:
+            assert isinstance(content, list), "content must be a list or a string"
+            for content_item in content:
                 if isinstance(content_item, str):
                     results.append(content_item)
                 else:
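The rewrite above hoists msg.get("content") into a local and dispatches on its type, since chat-completion message content may be a plain string or a list of parts. A self-contained sketch of that dispatch with a stand-in tokenizer; the real count_text_tokens counts with a model tokenizer, not whitespace:

from typing import Any

def count_text_tokens(text: str) -> int:
    # Stand-in for the tokenizer-based counter used in the SDK.
    return len(text.split())

def message_tokens(msg: dict[str, Any]) -> int:
    content = msg.get("content")
    tokens = 0
    if isinstance(content, str):
        tokens = count_text_tokens(content)
    elif isinstance(content, list):
        for item in content:
            if isinstance(item, str):
                tokens += count_text_tokens(item)
            elif isinstance(item, dict) and item.get("type") == "text":
                tokens += count_text_tokens(item.get("text", ""))
    return tokens

assert message_tokens({"role": "user", "content": "two words"}) == 2
assert message_tokens({"role": "user", "content": [{"type": "text", "text": "three short words"}]}) == 3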
retab/types/documents/extract.py
CHANGED
@@ -1,8 +1,7 @@
 import base64
-import datetime
 import json
 from typing import Any, Literal, Optional
-
+import datetime

 from openai.types.chat import ChatCompletionMessageParam
 from openai.types.chat.chat_completion import ChatCompletion
@@ -17,14 +16,13 @@ from openai.types.chat.parsed_chat_completion import ParsedChatCompletionMessage
 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
 from ..chat import ChatCompletionRetabMessage
 from ..mime import MIMEData
-from ..standards import
+from ..standards import StreamingBaseModel
 from ...utils.json_schema import filter_auxiliary_fields_json, convert_basemodel_to_partial_basemodel, convert_json_schema_to_basemodel, unflatten_dict
 from ..modality import Modality

 class DocumentExtractRequest(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
-    document: MIMEData = Field(
-    documents: list[MIMEData] = Field(default=[], description="Documents to be analyzed (preferred over document)")
+    document: MIMEData = Field(..., description="Document to be analyzed")
     image_resolution_dpi: int = Field(default=192, description="Resolution of the image sent to the LLM", ge=96, le=300)
     model: str = Field(..., description="Model used for chat completion")
     json_schema: dict[str, Any] = Field(..., description="JSON schema format used to validate the output data.")
@@ -37,9 +35,10 @@
     stream: bool = Field(default=False, description="If true, the extraction will be streamed to the user using the active WebSocket connection")
     seed: int | None = Field(default=None, description="Seed for the random number generator. If not provided, a random seed will be generated.", examples=[None])
     store: bool = Field(default=True, description="If true, the extraction will be stored in the database")
-    need_validation: bool = Field(default=False, description="If true, the extraction will be validated against the schema")
     modality: Modality = Field(default="native", description="The modality of the document to be analyzed")
     parallel_ocr_keys: Optional[dict[str, str]] = Field(default=None, description="If set, keys to be used for the extraction of long lists of data using Parallel OCR", examples=[{"properties": "ID", "products": "identity.id"}])
+    metadata: dict[str, str] = Field(default_factory=dict, description="User-defined metadata to associate with this extraction")
+    extraction_id: Optional[str] = Field(default=None, description="Extraction ID to use for this extraction. If not provided, a new ID will be generated.")

     # Add a model validator that rejects n_consensus > 1 if temperature is 0
     @field_validator("n_consensus")
@@ -48,28 +47,6 @@
             raise ValueError("n_consensus greater than 1 but temperature is 0")
         return v

-    @model_validator(mode="before")
-    def validate_document_or_documents(cls, data: Any) -> Any:
-        # Handle both dict and model instance cases
-        if isinstance(data, dict):
-            if data.get("documents"):  # If documents is set, it has higher priority than document
-                data["document"] = data["documents"][0]
-            elif data.get("document"):
-                data["documents"] = [data["document"]]
-            else:
-                raise ValueError("document or documents must be provided")
-        else:
-            # Handle model instance case
-            document = getattr(data, "document", None)
-            documents = getattr(data, "documents", None)
-            if documents:
-                setattr(data, "document", documents[0])
-            elif document:
-                setattr(data, "documents", [document])
-            else:
-                raise ValueError("document or documents must be provided")
-        return data
-

 class ConsensusModel(BaseModel):
     model: str = Field(description="Model name")
@@ -79,31 +56,16 @@
     )


-# For location of fields in the document (OCR)
-class FieldLocation(BaseModel):
-    label: str = Field(..., description="The label of the field")
-    value: str = Field(..., description="The extracted value of the field")
-    quote: str = Field(..., description="The quote of the field (verbatim from the document)")
-    file_id: str | None = Field(default=None, description="The ID of the file")
-    page: int | None = Field(default=None, description="The page number of the field (1-indexed)")
-    bbox_normalized: tuple[float, float, float, float] | None = Field(default=None, description="The normalized bounding box of the field")
-    score: float | None = Field(default=None, description="The score of the field")
-    match_level: Literal["token", "line", "block", "token-windows"] | None = Field(default=None, description="The level of the match (token, line, block, token-windows)")
-
-
 class RetabParsedChoice(ParsedChoice):
     # Adaptable ParsedChoice that allows None for the finish_reason
     finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call"] | None = None  # type: ignore
-    field_locations: dict[str, FieldLocation] | None = Field(default=None, description="The locations of the fields in the document, if available")
     key_mapping: dict[str, Optional[str]] | None = Field(default=None, description="Mapping of consensus keys to original model keys")


 LikelihoodsSource = Literal["consensus", "log_probs"]

-
 class RetabParsedChatCompletion(ParsedChatCompletion):
     model_config = ConfigDict(arbitrary_types_allowed=True, extra="ignore")
-
     extraction_id: str | None = None
     choices: list[RetabParsedChoice]  # type: ignore
     # Additional metadata fields
@@ -111,24 +73,8 @@ class RetabParsedChatCompletion(ParsedChatCompletion):
         default=None, description="Object defining the uncertainties of the fields extracted when using consensus. Follows the same structure as the extraction object."
     )

-    requires_human_review: bool = Field(default=False, description="
-    schema_validation_error: ErrorDetail | None = None
-    # Timestamps
-    request_at: datetime.datetime | None = Field(default=None, description="Timestamp of the request")
-    first_token_at: datetime.datetime | None = Field(default=None, description="Timestamp of the first token of the document. If non-streaming, set to last_token_at")
-    last_token_at: datetime.datetime | None = Field(default=None, description="Timestamp of the last token of the document")
-
-
+    requires_human_review: bool = Field(default=False, description="Flag indicating if the extraction requires human review")

-
-class UiResponse(Response):
-    extraction_id: str | None = None
-    # Additional metadata fields (UIForm)
-    likelihoods: Optional[dict[str, Any]] = Field(
-        default=None, description="Object defining the uncertainties of the fields extracted when using consensus. Follows the same structure as the extraction object."
-    )
-    schema_validation_error: ErrorDetail | None = None
-    # Timestamps
     request_at: datetime.datetime | None = Field(default=None, description="Timestamp of the request")
     first_token_at: datetime.datetime | None = Field(default=None, description="Timestamp of the first token of the document. If non-streaming, set to last_token_at")
     last_token_at: datetime.datetime | None = Field(default=None, description="Timestamp of the last token of the document")
@@ -187,7 +133,6 @@ class LogExtractionRequest(BaseModel):


 class LogExtractionResponse(BaseModel):
-    extraction_id: str | None = None  # None only in case of error
     status: Literal["success", "error"]
     error_message: str | None = None

@@ -208,7 +153,6 @@ class RetabParsedChoiceDeltaChunk(ChoiceDeltaChunk):
     flat_likelihoods: dict[str, float] = {}
     flat_parsed: dict[str, Any] = {}
     flat_deleted_keys: list[str] = []
-    field_locations: dict[str, list[FieldLocation]] | None = Field(default=None, description="The locations of the fields in the document, if available")
     is_valid_json: bool = False
     key_mapping: dict[str, Optional[str]] | None = Field(default=None, description="Mapping of consensus keys to original model keys")

@@ -218,16 +162,13 @@ class RetabParsedChoiceChunk(ChoiceChunk):


 class RetabParsedChatCompletionChunk(StreamingBaseModel, ChatCompletionChunk):
-    extraction_id: str | None = None
     choices: list[RetabParsedChoiceChunk]  # type: ignore
-
-
+
+    extraction_id: str | None = None
     request_at: datetime.datetime | None = Field(default=None, description="Timestamp of the request")
     first_token_at: datetime.datetime | None = Field(default=None, description="Timestamp of the first token of the document. If non-streaming, set to last_token_at")
     last_token_at: datetime.datetime | None = Field(default=None, description="Timestamp of the last token of the document")

-
-
     def chunk_accumulator(self, previous_cumulated_chunk: "RetabParsedChatCompletionChunk | None" = None) -> "RetabParsedChatCompletionChunk":
         """
         Accumulate the chunk into the state, returning a new RetabParsedChatCompletionChunk with the accumulated content that could be yielded alone to generate the same state.
@@ -249,7 +190,6 @@ class RetabParsedChatCompletionChunk(StreamingBaseModel, ChatCompletionChunk):
         # Get the current chunk missing content, flat_deleted_keys and is_valid_json
         acc_flat_deleted_keys = [safe_get_delta(self, i).flat_deleted_keys for i in range(max_choices)]
         acc_is_valid_json = [safe_get_delta(self, i).is_valid_json for i in range(max_choices)]
-        acc_field_locations = [safe_get_delta(self, i).field_locations for i in range(max_choices)]  # This is only present in the last chunk.
         # Delete from previous_cumulated_chunk.choices[i].delta.flat_parsed the keys that are in safe_get_delta(self, i).flat_deleted_keys
         for i in range(max_choices):
             previous_delta = safe_get_delta(previous_cumulated_chunk, i)
@@ -263,12 +203,8 @@ class RetabParsedChatCompletionChunk(StreamingBaseModel, ChatCompletionChunk):
         acc_key_mapping = [safe_get_delta(previous_cumulated_chunk, i).key_mapping or safe_get_delta(self, i).key_mapping for i in range(max_choices)]

         acc_content = [(safe_get_delta(previous_cumulated_chunk, i).content or "") + (safe_get_delta(self, i).content or "") for i in range(max_choices)]
-        first_token_at = self.first_token_at
-        last_token_at = self.last_token_at
-        request_at = self.request_at

         return RetabParsedChatCompletionChunk(
-            extraction_id=self.extraction_id,
             id=self.id,
             created=self.created,
             model=self.model,
@@ -281,7 +217,6 @@
                     flat_parsed=acc_flat_parsed[i],
                     flat_likelihoods=acc_flat_likelihoods[i],
                     flat_deleted_keys=acc_flat_deleted_keys[i],
-                    field_locations=acc_field_locations[i],
                     is_valid_json=acc_is_valid_json[i],
                     key_mapping=acc_key_mapping[i],
                 ),
@@ -289,10 +224,10 @@ class RetabParsedChatCompletionChunk(StreamingBaseModel, ChatCompletionChunk):
                 )
                 for i in range(max_choices)
             ],
-
-            request_at=request_at,
-            first_token_at=first_token_at,
-            last_token_at=last_token_at,
+            extraction_id=self.extraction_id,
+            request_at=self.request_at,
+            first_token_at=self.first_token_at,
+            last_token_at=self.last_token_at,
         )

     def to_completion(
@@ -313,8 +248,11 @@
         final_likelihoods = unflatten_dict(override_final_flat_likelihoods)

         return RetabParsedChatCompletion(
-            extraction_id=self.extraction_id,
             id=self.id,
+            extraction_id=self.extraction_id,
+            request_at=self.request_at,
+            first_token_at=self.first_token_at,
+            last_token_at=self.last_token_at,
             created=self.created,
             model=self.model,
             object="chat.completion",
@@ -334,9 +272,6 @@
             ],
             likelihoods=final_likelihoods,
             usage=self.usage,
-            request_at=self.request_at,
-            first_token_at=self.first_token_at,
-            last_token_at=self.last_token_at,
         )

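With the documents list and its before-mode validator gone, DocumentExtractRequest now takes exactly one required document, and callers tag runs through the new metadata and extraction_id fields instead. A construction sketch; the MIMEData field names are assumptions, since its definition is not part of this diff, and all other fields keep their defaults:

import base64

from retab.types.documents.extract import DocumentExtractRequest
from retab.types.mime import MIMEData

with open("invoice.pdf", "rb") as f:
    b64 = base64.b64encode(f.read()).decode()

request = DocumentExtractRequest(
    # filename/url are assumed MIMEData fields, for illustration only
    document=MIMEData(filename="invoice.pdf", url=f"data:application/pdf;base64,{b64}"),
    model="gemini-2.5-flash",
    json_schema={"type": "object", "properties": {"total": {"type": "number"}}},
    metadata={"batch": "invoices-2024"},  # new in 0.0.70
    extraction_id=None,  # new in 0.0.70: the server generates an ID when omitted
)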
retab/types/documents/parse.py
CHANGED
@@ -2,7 +2,6 @@ from typing import Literal
 from pydantic import BaseModel, ConfigDict, Field

 from ..mime import MIMEData, BaseMIMEData
-from ..browser_canvas import BrowserCanvas

 TableParsingFormat = Literal["markdown", "yaml", "html", "json"]

@@ -22,7 +21,6 @@ class ParseRequest(BaseModel):
     model: str = Field(default="gemini-2.5-flash", description="Model to use for parsing")
     table_parsing_format: TableParsingFormat = Field(default="html", description="Format for parsing tables")
     image_resolution_dpi: int = Field(default=192, description="DPI for image processing", ge=96, le=300)
-    browser_canvas: BrowserCanvas = Field(default="A4", description="Canvas size for document rendering")


 class ParseResult(BaseModel):
File without changes