retab 0.0.64__py3-none-any.whl → 0.0.66__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- retab/__init__.py +2 -2
- retab/client copy.py +693 -0
- retab/types/documents/extract.py +7 -4
- retab/types/inference_settings.py +1 -0
- retab/types/projects/Untitled-2.py +16671 -0
- retab/types/projects/model.py +13 -1
- retab/types/projects/v2.py +137 -0
- retab/types/schemas/__init__.py +5 -0
- retab/types/schemas/chat.py +491 -0
- retab/types/schemas/model.py +1747 -0
- {retab-0.0.64.dist-info → retab-0.0.66.dist-info}/METADATA +1 -1
- {retab-0.0.64.dist-info → retab-0.0.66.dist-info}/RECORD +14 -9
- {retab-0.0.64.dist-info → retab-0.0.66.dist-info}/WHEEL +0 -0
- {retab-0.0.64.dist-info → retab-0.0.66.dist-info}/top_level.txt +0 -0
retab/types/projects/model.py
CHANGED
|
@@ -6,6 +6,17 @@ from pydantic import BaseModel, Field, ConfigDict
|
|
|
6
6
|
|
|
7
7
|
from .documents import ProjectDocument
|
|
8
8
|
from .iterations import Iteration
|
|
9
|
+
from ..inference_settings import InferenceSettings
|
|
10
|
+
|
|
11
|
+
# Inference settings a project falls back to when it does not define its own.
default_inference_settings = InferenceSettings(
    # Model selection and sampling.
    model="auto-small",
    reasoning_effort="minimal",
    temperature=0.5,
    n_consensus=1,
    # Document rendering.
    modality="native",
    browser_canvas="A4",
    image_resolution_dpi=192,
)
|
|
9
20
|
|
|
10
21
|
class SheetsIntegration(BaseModel):
|
|
11
22
|
sheet_id: str
|
|
@@ -19,7 +30,7 @@ class BaseProject(BaseModel):
|
|
|
19
30
|
updated_at: datetime.datetime = Field(default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc))
|
|
20
31
|
sheets_integration: SheetsIntegration | None = None
|
|
21
32
|
validation_flags: dict[str, Any] | None = None
|
|
22
|
-
|
|
33
|
+
inference_settings: InferenceSettings = default_inference_settings
|
|
23
34
|
|
|
24
35
|
# Actual Object stored in DB
|
|
25
36
|
class Project(BaseProject):
|
|
@@ -40,6 +51,7 @@ class PatchProjectRequest(BaseModel):
|
|
|
40
51
|
json_schema: Optional[dict[str, Any]] = Field(default=None, description="The json schema of the project")
|
|
41
52
|
sheets_integration: SheetsIntegration | None = None
|
|
42
53
|
validation_flags: Optional[dict[str, Any]] = Field(default=None, description="The validation flags of the project")
|
|
54
|
+
inference_settings: Optional[InferenceSettings] = Field(default=None, description="The inference settings of the project")
|
|
43
55
|
|
|
44
56
|
class AddIterationFromJsonlRequest(BaseModel):
|
|
45
57
|
model_config = ConfigDict(extra="ignore")
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
from typing import Any, Optional
|
|
3
|
+
|
|
4
|
+
import nanoid # type: ignore
|
|
5
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
6
|
+
|
|
7
|
+
from ..inference_settings import InferenceSettings
|
|
8
|
+
from ..mime import MIMEData
|
|
9
|
+
from .predictions import PredictionData, PredictionMetadata
|
|
10
|
+
from .documents import ProjectDocument
|
|
11
|
+
|
|
12
|
+
from typing import Self
|
|
13
|
+
from pydantic import model_validator
|
|
14
|
+
|
|
15
|
+
class DatasetDocument(BaseModel):
    # A document of a dataset together with its ground-truth annotation.
    model_config = ConfigDict(extra="ignore")

    id: str = Field(
        default_factory=lambda: f"dataset_doc_{nanoid.generate()}",
        description="The ID of the document. Equal to mime_data.id but robust to the case where mime_data is a BaseMIMEData",
    )
    mime_data: MIMEData = Field(
        description="The mime data of the document. Can also be a BaseMIMEData, which is why we have this id field (to be able to identify the file, but id is equal to mime_data.id)"
    )
    annotation: dict[str, Any] = Field({}, description="The ground truth of the document")
    annotation_metadata: PredictionMetadata | None = Field(
        None,
        description="The metadata of the annotation when the annotation is a prediction",
    )
    validation_flags: Optional[dict[str, Any]] = None
|
|
22
|
+
|
|
23
|
+
# Inference settings used as the default by Dataset, PublishedConfig and DraftConfig.
default_inference_settings = InferenceSettings(
    n_consensus=1,
    temperature=0.5,
    model="auto-small",
    reasoning_effort="minimal",
    image_resolution_dpi=192,
    browser_canvas="A4",
    modality="native",
)
|
|
32
|
+
|
|
33
|
+
class Dataset(BaseModel):
    # A named collection of documents, plus the IDs of the iterations attached to it.
    id: str = Field(default_factory=lambda: f"dataset_{nanoid.generate()}")
    name: str = Field("", description="The name of the dataset")
    updated_at: datetime.datetime = Field(default_factory=lambda: datetime.datetime.now(datetime.timezone.utc))
    inference_settings: InferenceSettings = default_inference_settings
    documents: list[DatasetDocument] = Field(default_factory=lambda: [])
    iteration_ids: list[str] = Field(default_factory=lambda: [])
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SchemaOverrides(BaseModel):
    """Non-structural JSON Schema overrides, keyed by field path.

    Keys are dot-paths such as "address.street" or "items.*.price". Only metadata that
    does not change the structure of the schema can be overridden:

    - descriptionsOverride: field path -> JSON Schema "description" string
    - reasoningPromptsOverride: field path -> value for the schema key "X-ReasoningPrompt"
    """

    model_config = ConfigDict(extra="ignore")

    # Field path -> replacement "description" string (None means no overrides).
    descriptionsOverride: Optional[dict[str, str]] = None
    # Field path -> replacement "X-ReasoningPrompt" value (None means no overrides).
    reasoningPromptsOverride: Optional[dict[str, str]] = Field(default=None, description="Maps to X-ReasoningPrompt in schema")
|
|
52
|
+
|
|
53
|
+
class BaseIteration(BaseModel):
    # Fields shared by every iteration: identity, lineage, inference settings and schema overrides.
    model_config = ConfigDict(extra="ignore")

    id: str = Field(default_factory=lambda: f"eval_iter_{nanoid.generate()}")
    parent_id: str | None = Field(None, description="The ID of the parent iteration")
    inference_settings: InferenceSettings
    # Only overrides are persisted, never the full schema; override keys are dot-paths
    # such as "address.street" or "items.*.price".
    schema_overrides: SchemaOverrides = Field(
        description="Map of field path -> non-structural schema overrides (description, reasoning_prompt)",
        default_factory=SchemaOverrides,
    )
    updated_at: datetime.datetime = Field(
        description="The last update date of inference settings or schema overrides",
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc),
    )
|
|
66
|
+
|
|
67
|
+
class DraftIteration(BaseModel):
    # Draft state of an iteration; like the iteration itself, it stores schema overrides only.
    model_config = ConfigDict(extra="ignore")

    schema_overrides: SchemaOverrides = Field(default_factory=lambda: SchemaOverrides())
    updated_at: datetime.datetime = Field(
        description="The last update date of draft schema overrides",
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc),
    )
|
|
75
|
+
|
|
76
|
+
class Iteration(BaseIteration):
    # A stored iteration: per-document predictions plus an always-present draft.
    model_config = ConfigDict(extra="ignore")

    predictions: dict[str, PredictionData] = Field(
        description="The predictions of the iteration for all the documents",
        default_factory=dict,
    )
    draft: Optional[DraftIteration] = Field(None, description="The draft iteration of the iteration")

    @model_validator(mode="after")
    def set_draft_to_current_iteration(self) -> Self:
        # Guarantee that a draft exists after validation. When none was supplied, start a
        # fresh draft with empty schema overrides (it does not copy this iteration's overrides).
        if self.draft is not None:
            return self
        now = datetime.datetime.now(tz=datetime.timezone.utc)
        self.draft = DraftIteration(schema_overrides=SchemaOverrides(), updated_at=now)
        return self
|
|
90
|
+
|
|
91
|
+
class Evaluation(BaseModel):
    # Links a dataset (by ID) to the IDs of the iterations evaluated against it.
    id: str = Field(default_factory=lambda: f"eval_{nanoid.generate()}")
    updated_at: datetime.datetime = Field(default_factory=lambda: datetime.datetime.now(datetime.timezone.utc))
    dataset_id: str
    iteration_ids: list[str] = Field(default_factory=lambda: [])
|
|
96
|
+
|
|
97
|
+
class PublishedConfig(BaseModel):
    # Published configuration of a project: full JSON schema, inference settings and
    # human-in-the-loop criteria.
    inference_settings: InferenceSettings = default_inference_settings
    json_schema: dict[str, Any] = Field(description="The json schema of the project", default_factory=lambda: {})
    human_in_the_loop_criteria: list[str] = Field(default_factory=lambda: [])
|
|
101
|
+
|
|
102
|
+
class DraftConfig(BaseModel):
    # Draft configuration of a project; unlike PublishedConfig it keeps schema overrides
    # rather than a full JSON schema.
    inference_settings: InferenceSettings = default_inference_settings
    schema_overrides: SchemaOverrides = Field(default_factory=lambda: SchemaOverrides())
    human_in_the_loop_criteria: list[str] = Field(default_factory=lambda: [])
|
|
106
|
+
|
|
107
|
+
class Project(BaseModel):
    # A project: references its datasets and carries both a published and a draft configuration.
    model_config = ConfigDict(extra="ignore")

    id: str = Field(default_factory=lambda: f"project_{nanoid.generate()}")
    name: str = Field("", description="The name of the project")
    updated_at: datetime.datetime = Field(default_factory=lambda: datetime.datetime.now(datetime.timezone.utc))
    dataset_ids: list[str] = Field(default_factory=lambda: [])
    is_published: bool = Field(default=False)
    # Both configurations are required when constructing a Project.
    published_config: PublishedConfig
    draft_config: DraftConfig
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# Actual Object stored in DB
|
|
119
|
+
|
|
120
|
+
class CreateProjectRequest(BaseModel):
    # Payload for creating a project; both the name and the JSON schema are required.
    model_config = ConfigDict(extra="ignore")

    name: str = Field()
    json_schema: dict[str, Any] = Field()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# Partial-update payload for a project: every field is optional (None means "not provided").
# Written out by hand rather than generated with a partial-model helper such as
# convert_basemodel_to_partial_basemodel, for explicitness.
# NOTE(review): the v2 `Project` model has no top-level json_schema / validation_flags /
# inference_settings fields; confirm how these map onto draft_config / published_config.
class PatchProjectRequest(BaseModel):
    model_config = ConfigDict(extra="ignore")
    name: Optional[str] = Field(default=None, description="The name of the project")
    json_schema: Optional[dict[str, Any]] = Field(default=None, description="The json schema of the project")
    validation_flags: Optional[dict[str, Any]] = Field(default=None, description="The validation flags of the project")
    inference_settings: Optional[InferenceSettings] = Field(default=None, description="The inference settings of the project")
|
|
134
|
+
|
|
135
|
+
class AddIterationFromJsonlRequest(BaseModel):
    # Request carrying the GCS path of a JSONL file (presumably read to build an iteration — confirm with the handler).
    model_config = ConfigDict(extra="ignore")

    jsonl_gcs_path: str = Field()
|
retab/types/schemas/__init__.py
CHANGED
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import logging
|
|
3
|
+
from typing import List, Literal, Optional, Union, cast
|
|
4
|
+
import datetime
|
|
5
|
+
import json
|
|
6
|
+
|
|
7
|
+
from jiter import from_json
|
|
8
|
+
import requests
|
|
9
|
+
|
|
10
|
+
from anthropic.types.image_block_param import ImageBlockParam
|
|
11
|
+
from anthropic.types.message_param import MessageParam
|
|
12
|
+
from anthropic.types.text_block_param import TextBlockParam
|
|
13
|
+
from google.genai.types import BlobDict, ContentDict, ContentUnionDict, PartDict # type: ignore
|
|
14
|
+
from openai.types.chat.chat_completion_content_part_input_audio_param import ChatCompletionContentPartInputAudioParam
|
|
15
|
+
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
|
16
|
+
from openai.types.chat.chat_completion_content_part_image_param import ChatCompletionContentPartImageParam, ImageURL
|
|
17
|
+
from openai.types.chat.chat_completion_content_part_param import ChatCompletionContentPartParam
|
|
18
|
+
from openai.types.chat.chat_completion_content_part_text_param import ChatCompletionContentPartTextParam
|
|
19
|
+
from openai.types.chat.parsed_chat_completion import ParsedChatCompletionMessage
|
|
20
|
+
from openai.types.completion_usage import CompletionTokensDetails, CompletionUsage, PromptTokensDetails
|
|
21
|
+
from openai.types.responses.easy_input_message_param import EasyInputMessageParam
|
|
22
|
+
from openai.types.responses.response import Response
|
|
23
|
+
from openai.types.responses.response_input_image_param import ResponseInputImageParam
|
|
24
|
+
from openai.types.responses.response_input_message_content_list_param import ResponseInputMessageContentListParam
|
|
25
|
+
from openai.types.responses.response_input_param import ResponseInputItemParam
|
|
26
|
+
from openai.types.responses.response_input_text_param import ResponseInputTextParam
|
|
27
|
+
|
|
28
|
+
from retab.types.chat import ChatCompletionRetabMessage
|
|
29
|
+
from retab.types.documents.extract import RetabParsedChatCompletion, RetabParsedChoice
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Image media types recognized for Anthropic base64 image blocks; used to type the
# ``media_type`` of the image blocks built by ``convert_to_anthropic_format``.
MediaType = Literal["image/jpeg", "image/png", "image/gif", "image/webp"]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def convert_to_google_genai_format(messages: List[ChatCompletionRetabMessage]) -> tuple[str, list[ContentUnionDict]]:
    """
    Converts a list of ChatCompletionRetabMessage to a format compatible with the google.genai SDK.

    System/developer messages are pulled out into a single system-instruction string.
    Every other message becomes a ``ContentDict`` with role ``"user"`` (for user messages)
    or ``"model"`` (for any other role). Text parts and base64 ``data:image`` URLs are
    converted; remote image URLs, audio, files and unknown part types are skipped with a warning.

    Example:
        ```python
        from google import genai
        from google.genai import types

        client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
        system_instruction, contents = convert_to_google_genai_format(messages)
        response = client.models.generate_content(
            model="gemini-2.0-flash",
            contents=contents,
            config=types.GenerateContentConfig(system_instruction=system_instruction or None),
        )
        ```

    Args:
        messages (List[ChatCompletionRetabMessage]): List of chat messages.

    Returns:
        tuple[str, list[ContentUnionDict]]: The system instruction (empty string if there is
        none) and the list of contents for the google.genai SDK.

    Raises:
        AssertionError: If a system/developer message's content is not a string.
        ValueError: If more than one system/developer message is present.
    """
    system_message: str = ""
    formatted_content: list[ContentUnionDict] = []
    for message in messages:
        # -----------------------
        # Handle system message
        # -----------------------
        if message["role"] in ("system", "developer"):
            assert isinstance(message.get("content"), str), "System message content must be a string."
            if system_message != "":
                raise ValueError("Only one system message is allowed per chat.")
            system_message += cast(str, message.get("content", ""))
            continue

        parts: list[PartDict] = []
        message_content = message.get("content")
        if isinstance(message_content, str):
            # Direct string content becomes a single text part.
            parts.append(PartDict(text=message_content))
        elif isinstance(message_content, list):
            for part in message_content:
                if part["type"] == "text":
                    parts.append(PartDict(text=part["text"]))
                elif part["type"] == "image_url":
                    url = part["image_url"].get("url", "")  # type: ignore
                    if not url.startswith("data:image"):
                        # Only inline base64 images are converted; nothing is downloaded here.
                        logging.warning("Remote image URLs are not supported for google.genai; skipping image.")
                        continue
                    header, sep, base64_data = url.partition(";base64,")
                    if not sep:
                        logging.warning("Skipping data URL image that is not base64-encoded.")
                        continue
                    media_type = header.split("data:")[-1]  # "data:image/jpeg" -> "image/jpeg"
                    try:
                        image_bytes = base64.b64decode(base64_data)
                    except ValueError:  # binascii.Error is a subclass of ValueError
                        logging.warning("Skipping image with invalid base64 data.", exc_info=True)
                        continue
                    parts.append(PartDict(inline_data=BlobDict(data=image_bytes, mime_type=media_type)))
                else:
                    # input_audio, file and any other part type have no conversion yet.
                    logging.warning("Content part type '%s' is not supported for google.genai; skipping.", part["type"])

        formatted_content.append(ContentDict(parts=parts, role=("user" if message["role"] == "user" else "model")))

    return system_message, formatted_content
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def convert_to_anthropic_format(messages: List[ChatCompletionRetabMessage]) -> tuple[str, List[MessageParam]]:
    """
    Converts a list of ChatCompletionRetabMessage to a format compatible with the Anthropic SDK.

    System/developer messages are extracted into a single system prompt string. Text parts
    become ``text`` blocks; images become base64 ``image`` blocks (base64 data URLs are
    unpacked, remote URLs are downloaded and inlined). Audio parts, unknown part types and
    images that cannot be loaded are skipped with a warning.

    Args:
        messages (List[ChatCompletionRetabMessage]): List of chat messages.

    Returns:
        (system_message, formatted_messages):
            system_message (str):
                The system message if one was found, otherwise an empty string.
            formatted_messages (List[MessageParam]):
                A list of formatted messages ready for Anthropic.

    Raises:
        AssertionError: If a system/developer message's content is not a string.
        ValueError: If more than one system/developer message is present.
    """
    # Media types recognized for Anthropic image blocks (same values as ``MediaType``).
    supported_media_types = ("image/jpeg", "image/png", "image/gif", "image/webp")
    # Seconds to wait when downloading a remote image; without a timeout a stalled
    # server would block the caller indefinitely.
    fetch_timeout_s = 30

    formatted_messages: list[MessageParam] = []
    system_message: str = ""

    for message in messages:
        content_blocks: list[Union[TextBlockParam, ImageBlockParam]] = []

        # -----------------------
        # Handle system message
        # -----------------------
        if message["role"] in ("system", "developer"):
            assert isinstance(message.get("content"), str), "System message content must be a string."
            if system_message != "":
                raise ValueError("Only one system message is allowed per chat.")
            system_message += cast(str, message.get("content", ""))
            continue

        # -----------------------
        # Handle non-system roles
        # -----------------------
        content = message.get("content")
        if isinstance(content, str):
            # Direct string content is treated as a single text block
            content_blocks.append({"type": "text", "text": content})

        elif isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    text_part = cast(ChatCompletionContentPartTextParam, part)
                    content_blocks.append({"type": "text", "text": text_part["text"]})

                elif part["type"] == "input_audio":
                    logging.warning("Audio input is not supported yet.")

                elif part["type"] == "image_url":
                    image_part = cast(ChatCompletionContentPartImageParam, part)
                    image_url = image_part["image_url"]["url"]

                    if image_url.startswith("data:"):
                        # e.g. "data:image/jpeg;base64,<data>"
                        header, sep, base64_data = image_url.partition(";base64,")
                        if not sep:
                            logging.warning("Skipping data URL image that is not base64-encoded.")
                            continue
                        media_type = header.split("data:")[-1]  # => "image/jpeg"
                    else:
                        # Remote URL: fetch, encode, and derive the media type from the headers.
                        # NOTE(review): this downloads arbitrary URLs taken from message content; if
                        # messages can come from untrusted users, this is a server-side request forgery vector.
                        try:
                            r = requests.get(image_url, timeout=fetch_timeout_s)
                            r.raise_for_status()
                            # The header may carry parameters (e.g. "image/png; charset=binary"),
                            # so keep only the bare type; fall back to image/jpeg when absent.
                            content_type = r.headers.get("Content-Type", "image/jpeg").split(";")[0].strip().lower()
                            if content_type not in supported_media_types:
                                logging.warning(
                                    "Unrecognized Content-Type '%s' - defaulting to image/jpeg",
                                    content_type,
                                )
                                content_type = "image/jpeg"
                            media_type = content_type
                            base64_data = base64.b64encode(r.content).decode("utf-8")
                        except Exception:
                            # Best effort: a broken image should not fail the whole conversion.
                            logging.warning(
                                "Failed to load image from URL: %s",
                                image_url,
                                exc_info=True,
                                stack_info=True,
                            )
                            continue

                    content_blocks.append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": cast(MediaType, media_type),
                                "data": base64_data,
                            },
                        }
                    )

                else:
                    logging.warning("Content part type '%s' is not supported for Anthropic; skipping.", part["type"])

        formatted_messages.append(
            MessageParam(
                role=message["role"],  # type: ignore
                content=content_blocks,
            )
        )

    return system_message, formatted_messages
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def convert_from_anthropic_format(messages: list[MessageParam], system_prompt: str) -> list[ChatCompletionRetabMessage]:
    """
    Converts a list of Anthropic MessageParam to a list of ChatCompletionRetabMessage.

    The system prompt is always emitted first as a ``developer`` message (even when empty).
    Message content given as a plain string, or as a single text block, becomes string
    content. Otherwise text blocks and image blocks (``base64`` or ``url`` sources) become
    content parts; all other block types, and blocks that are not dicts, are skipped.

    Args:
        messages: Anthropic messages to convert.
        system_prompt: System prompt to place in the leading ``developer`` message.

    Returns:
        The converted messages, starting with the ``developer`` message.
    """
    formatted_messages: list[ChatCompletionRetabMessage] = [ChatCompletionRetabMessage(role="developer", content=system_prompt)]

    for message in messages:
        role = message["role"]
        content = message["content"]

        # Anthropic accepts a bare string as message content.
        if isinstance(content, str):
            formatted_messages.append(cast(ChatCompletionRetabMessage, {"role": role, "content": content}))
            continue

        # Content may be any iterable of blocks; materialize it once.
        content_blocks = list(content)

        if len(content_blocks) == 1 and isinstance(content_blocks[0], dict) and content_blocks[0].get("type") == "text":
            # Simple text message
            formatted_messages.append(cast(ChatCompletionRetabMessage, {"role": role, "content": content_blocks[0].get("text", "")}))
            continue

        # Message with multiple content parts or non-text content
        formatted_content: list[ChatCompletionContentPartParam] = []
        for block in content_blocks:
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                formatted_content.append(cast(ChatCompletionContentPartParam, {"type": "text", "text": block.get("text", "")}))
            elif block_type == "image":
                source = block.get("source", {})
                if not isinstance(source, dict):
                    continue
                if source.get("type") == "base64":
                    # Convert base64 image to data URL format
                    media_type = source.get("media_type", "image/jpeg")
                    data = source.get("data", "")
                    image_url = f"data:{media_type};base64,{data}"
                elif source.get("type") == "url" and source.get("url"):
                    image_url = cast(str, source.get("url"))
                else:
                    continue
                formatted_content.append(cast(ChatCompletionContentPartParam, {"type": "image_url", "image_url": {"url": image_url}}))

        formatted_messages.append(cast(ChatCompletionRetabMessage, {"role": role, "content": formatted_content}))

    return formatted_messages
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def convert_to_openai_completions_api_format(messages: List[ChatCompletionRetabMessage]) -> List[ChatCompletionMessageParam]:
    """Return *messages* typed as OpenAI Chat Completions message params.

    This is a typing-only conversion: no data is copied or transformed, and the very
    same list object is returned.
    """
    completions_messages = cast(List[ChatCompletionMessageParam], messages)
    return completions_messages
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def convert_from_openai_completions_api_format(messages: list[ChatCompletionMessageParam]) -> list[ChatCompletionRetabMessage]:
    """Return OpenAI Chat Completions *messages* typed as Retab chat messages.

    This is a typing-only conversion: the same list object is returned unchanged.
    """
    retab_messages = cast(list[ChatCompletionRetabMessage], messages)
    return retab_messages
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def separate_messages(
    messages: list[ChatCompletionRetabMessage],
) -> tuple[Optional[ChatCompletionRetabMessage], list[ChatCompletionRetabMessage], list[ChatCompletionRetabMessage]]:
    """
    Split a conversation into its system prompt, user turns and assistant turns.

    ``system`` and ``developer`` messages both count as the system prompt; if several are
    present, the last one wins. Messages with any other role (e.g. ``tool``) are dropped.
    Order is preserved within the user and assistant lists.

    Args:
        messages: Chat messages to split.

    Returns:
        Tuple of (system message or None, user messages, assistant messages).
    """
    system_msg: Optional[ChatCompletionRetabMessage] = None
    by_role: dict[str, list[ChatCompletionRetabMessage]] = {"user": [], "assistant": []}

    for msg in messages:
        role = msg["role"]
        if role in ("system", "developer"):
            system_msg = msg
            continue
        bucket = by_role.get(role)
        if bucket is not None:
            bucket.append(msg)

    return system_msg, by_role["user"], by_role["assistant"]
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def str_messages(messages: list[ChatCompletionRetabMessage], max_length: int = 100) -> str:
    """
    Converts a list of chat messages into a string representation with faithfully serialized structure.

    String content, text parts and image URLs are truncated to ``max_length`` characters
    (followed by ``...``). Every message is kept: a message whose content is neither a
    string nor a list (e.g. missing or None) is rendered with that content as-is. Content
    parts other than non-empty text and image parts are rendered as a ``{"type": ...}``
    placeholder, so their presence stays visible.

    Args:
        messages (list[ChatCompletionRetabMessage]): The list of chat messages.
        max_length (int): Maximum length for content before truncation.

    Returns:
        str: A string representation of the messages with applied truncation.
    """

    def truncate(text: str, max_len: int) -> str:
        """Truncate text to max_len with ellipsis."""
        return text if len(text) <= max_len else f"{text[:max_len]}..."

    serialized: list[dict] = []
    for message in messages:
        role = message["role"]
        content = message.get("content")

        if isinstance(content, str):
            serialized.append({"role": role, "content": truncate(content, max_length)})
        elif isinstance(content, list):
            truncated_content: list[dict] = []
            for part in content:
                part_type = part.get("type")
                if part_type == "text" and part.get("text"):
                    truncated_content.append({"type": "text", "text": truncate(part["text"], max_length)})
                elif part_type == "image_url" and part.get("image_url"):
                    image_url = part["image_url"].get("url", "unknown image")
                    truncated_content.append({"type": "image_url", "image_url": {"url": truncate(image_url, max_length)}})
                else:
                    # Keep a placeholder so audio/file/empty parts remain visible.
                    truncated_content.append({"type": part_type})
            serialized.append({"role": role, "content": truncated_content})
        else:
            # Content missing or None: keep the message instead of silently dropping it.
            serialized.append({"role": role, "content": content})

    return repr(serialized)
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def convert_to_openai_responses_api_format(messages: list[ChatCompletionRetabMessage]) -> list[ResponseInputItemParam]:
    """
    Converts a list of ChatCompletionRetabMessage to the OpenAI ResponseInputParam format.

    String content and text parts become ``input_text`` items and image parts become
    ``input_image`` items (``detail`` defaults to ``"high"``). Other part types are skipped
    with a logged warning. Messages whose content is neither a string nor a list are sent
    with empty content.

    Args:
        messages: List of Retab chat messages.

    Returns:
        Messages in OpenAI ResponseInputParam format for the Responses API.
    """
    formatted_messages: list[ResponseInputItemParam] = []

    for message in messages:
        role = message["role"]
        content = message.get("content")
        formatted_content: ResponseInputMessageContentListParam = []

        if isinstance(content, str):
            formatted_content.append(ResponseInputTextParam(text=content, type="input_text"))
        elif isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    formatted_content.append(ResponseInputTextParam(text=part["text"], type="input_text"))
                elif part["type"] == "image_url":
                    # "detail" is optional on chat-completions image parts.
                    detail = part["image_url"].get("detail", "high")
                    formatted_content.append(ResponseInputImageParam(image_url=part["image_url"]["url"], type="input_image", detail=detail))
                else:
                    logging.warning("Not supported content type: %s... Skipping...", part["type"])

        # Any role other than user/assistant/system/developer (e.g. "tool") is sent as "assistant".
        role_for_response = role if role in ("user", "assistant", "system", "developer") else "assistant"
        formatted_messages.append(EasyInputMessageParam(role=role_for_response, content=formatted_content, type="message"))

    return formatted_messages
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def convert_from_openai_responses_api_format(messages: list[ResponseInputItemParam]) -> list[ChatCompletionRetabMessage]:
    """
    Converts messages from OpenAI ResponseInputParam format to ChatCompletionRetabMessage format.

    Only message items (those carrying both ``role`` and ``content``) are converted; other
    input items (function calls, tool outputs, ...) are skipped with a logged warning.
    ``input_text`` and ``output_text`` parts become text parts. ``input_image`` parts that
    have an ``image_url`` become image parts (``detail`` defaults to ``"high"``). Any other
    part is skipped with a warning. The input items are not modified.

    Args:
        messages: Messages in OpenAI ResponseInputParam format.

    Returns:
        List of Retab chat messages.
    """
    formatted_messages: list[ChatCompletionRetabMessage] = []

    for message in messages:
        # Only message-like items have both a role and content.
        if "role" not in message or "content" not in message:
            logging.warning("Not supported message type: %s... Skipping...", message.get("type"))
            continue

        role = message["role"]
        content = message["content"]
        formatted_content: str | list[ChatCompletionContentPartParam]

        if isinstance(content, str):
            formatted_content = content
        else:
            formatted_content = []
            for part in content:
                part_type = part["type"]
                if part_type in ("input_text", "output_text"):
                    # output_text appears in assistant output messages fed back as input.
                    formatted_content.append(ChatCompletionContentPartTextParam(text=part["text"], type="text"))
                elif part_type == "input_image":
                    image_url = part.get("image_url")
                    if not image_url:
                        # e.g. images referenced only by file_id; chat completions needs a URL.
                        logging.warning("Skipping input_image without an image_url.")
                        continue
                    image_detail = part.get("detail") or "high"
                    formatted_content.append(ChatCompletionContentPartImageParam(image_url=ImageURL(url=image_url, detail=image_detail), type="image_url"))
                else:
                    logging.warning("Not supported content type: %s... Skipping...", part_type)

        formatted_messages.append(ChatCompletionRetabMessage(role=role, content=formatted_content))

    return formatted_messages
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def parse_openai_responses_response(response: Response) -> RetabParsedChatCompletion:
    """
    Convert an OpenAI Response (Responses API) to RetabParsedChatCompletion type.

    The aggregated ``output_text`` is parsed leniently as (possibly truncated) JSON and
    re-serialized as the content of a single assistant choice. Other fields are mapped
    as follows:

    - ``finish_reason`` is derived from ``incomplete_details``: ``"length"`` when the
      output-token limit was hit, ``"content_filter"`` when the output was filtered,
      otherwise ``"stop"``.
    - ``created`` is taken from the response's own ``created_at`` timestamp.
    - Token usage is mapped onto ``CompletionUsage`` when present.

    Args:
        response: Response from OpenAI Responses API

    Returns:
        Parsed response in RetabParsedChatCompletion format
    """
    usage: Optional[CompletionUsage] = None
    if response.usage:
        usage = CompletionUsage(
            prompt_tokens=response.usage.input_tokens,
            completion_tokens=response.usage.output_tokens,
            total_tokens=response.usage.total_tokens,
            prompt_tokens_details=PromptTokensDetails(
                cached_tokens=response.usage.input_tokens_details.cached_tokens,
            ),
            completion_tokens_details=CompletionTokensDetails(
                reasoning_tokens=response.usage.output_tokens_details.reasoning_tokens,
            ),
        )

    # An incomplete response must not be reported as a normal "stop".
    finish_reason: Literal["stop", "length", "content_filter"] = "stop"
    incomplete_reason = response.incomplete_details.reason if response.incomplete_details else None
    if incomplete_reason == "max_output_tokens":
        finish_reason = "length"
    elif incomplete_reason == "content_filter":
        finish_reason = "content_filter"

    # partial_mode lets output cut off mid-JSON (e.g. by the token limit) still parse.
    result_object = from_json(bytes(response.output_text, "utf-8"), partial_mode=True)

    choice = RetabParsedChoice(
        index=0,
        message=ParsedChatCompletionMessage(
            role="assistant",
            content=json.dumps(result_object),
        ),
        finish_reason=finish_reason,
    )

    return RetabParsedChatCompletion(
        id=response.id,
        choices=[choice],
        created=int(response.created_at),
        model=response.model,
        object="chat.completion",
        likelihoods={},
        usage=usage,
    )
|