chunkr-ai 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chunkr_ai/__init__.py +1 -1
- chunkr_ai/api/auth.py +4 -4
- chunkr_ai/api/base.py +58 -48
- chunkr_ai/api/chunkr.py +21 -20
- chunkr_ai/api/chunkr_async.py +26 -20
- chunkr_ai/api/chunkr_base.py +34 -27
- chunkr_ai/api/config.py +41 -14
- chunkr_ai/api/misc.py +52 -44
- chunkr_ai/api/protocol.py +5 -3
- chunkr_ai/api/schema.py +66 -58
- chunkr_ai/api/task.py +13 -16
- chunkr_ai/api/task_async.py +16 -7
- chunkr_ai/api/task_base.py +4 -1
- chunkr_ai/models.py +23 -22
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/METADATA +1 -1
- chunkr_ai-0.0.16.dist-info/RECORD +21 -0
- chunkr_ai-0.0.14.dist-info/RECORD +0 -21
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/LICENSE +0 -0
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/WHEEL +0 -0
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/top_level.txt +0 -0
chunkr_ai/api/config.py
CHANGED
@@ -3,28 +3,31 @@ from enum import Enum
|
|
3
3
|
from typing import Optional, List, Dict, Union, Type
|
4
4
|
from .schema import from_pydantic
|
5
5
|
|
6
|
+
|
6
7
|
class GenerationStrategy(str, Enum):
|
7
8
|
LLM = "LLM"
|
8
9
|
AUTO = "Auto"
|
9
10
|
|
11
|
+
|
10
12
|
class CroppingStrategy(str, Enum):
|
11
|
-
ALL = "All"
|
13
|
+
ALL = "All"
|
12
14
|
AUTO = "Auto"
|
13
15
|
|
16
|
+
|
14
17
|
class GenerationConfig(BaseModel):
|
15
18
|
html: Optional[GenerationStrategy] = None
|
16
19
|
llm: Optional[str] = None
|
17
20
|
markdown: Optional[GenerationStrategy] = None
|
18
21
|
crop_image: Optional[CroppingStrategy] = None
|
19
22
|
|
23
|
+
|
20
24
|
class SegmentProcessing(BaseModel):
|
21
|
-
model_config = ConfigDict(
|
22
|
-
|
23
|
-
alias_generator=str.title
|
24
|
-
)
|
25
|
-
|
25
|
+
model_config = ConfigDict(populate_by_name=True, alias_generator=str.title)
|
26
|
+
|
26
27
|
title: Optional[GenerationConfig] = Field(default=None, alias="Title")
|
27
|
-
section_header: Optional[GenerationConfig] = Field(
|
28
|
+
section_header: Optional[GenerationConfig] = Field(
|
29
|
+
default=None, alias="SectionHeader"
|
30
|
+
)
|
28
31
|
text: Optional[GenerationConfig] = Field(default=None, alias="Text")
|
29
32
|
list_item: Optional[GenerationConfig] = Field(default=None, alias="ListItem")
|
30
33
|
table: Optional[GenerationConfig] = Field(default=None, alias="Table")
|
@@ -36,38 +39,46 @@ class SegmentProcessing(BaseModel):
|
|
36
39
|
page_footer: Optional[GenerationConfig] = Field(default=None, alias="PageFooter")
|
37
40
|
page: Optional[GenerationConfig] = Field(default=None, alias="Page")
|
38
41
|
|
42
|
+
|
39
43
|
class ChunkProcessing(BaseModel):
|
40
44
|
target_length: Optional[int] = None
|
41
45
|
|
46
|
+
|
42
47
|
class Property(BaseModel):
|
43
48
|
name: str
|
44
49
|
prop_type: str
|
45
50
|
description: Optional[str] = None
|
46
51
|
default: Optional[str] = None
|
47
52
|
|
53
|
+
|
48
54
|
class JsonSchema(BaseModel):
|
49
55
|
title: str
|
50
56
|
properties: List[Property]
|
51
57
|
|
58
|
+
|
52
59
|
class OcrStrategy(str, Enum):
|
53
60
|
ALL = "All"
|
54
61
|
AUTO = "Auto"
|
55
|
-
|
62
|
+
|
63
|
+
|
56
64
|
class SegmentationStrategy(str, Enum):
|
57
65
|
LAYOUT_ANALYSIS = "LayoutAnalysis"
|
58
66
|
PAGE = "Page"
|
59
67
|
|
68
|
+
|
60
69
|
class BoundingBox(BaseModel):
|
61
70
|
left: float
|
62
71
|
top: float
|
63
72
|
width: float
|
64
73
|
height: float
|
65
74
|
|
75
|
+
|
66
76
|
class OCRResult(BaseModel):
|
67
77
|
bbox: BoundingBox
|
68
78
|
text: str
|
69
79
|
confidence: Optional[float]
|
70
80
|
|
81
|
+
|
71
82
|
class SegmentType(str, Enum):
|
72
83
|
CAPTION = "Caption"
|
73
84
|
FOOTNOTE = "Footnote"
|
@@ -82,6 +93,7 @@ class SegmentType(str, Enum):
|
|
82
93
|
TEXT = "Text"
|
83
94
|
TITLE = "Title"
|
84
95
|
|
96
|
+
|
85
97
|
class Segment(BaseModel):
|
86
98
|
bbox: BoundingBox
|
87
99
|
content: str
|
@@ -95,33 +107,43 @@ class Segment(BaseModel):
|
|
95
107
|
segment_id: str
|
96
108
|
segment_type: SegmentType
|
97
109
|
|
110
|
+
|
98
111
|
class Chunk(BaseModel):
|
99
112
|
chunk_id: str
|
100
113
|
chunk_length: int
|
101
114
|
segments: List[Segment]
|
102
115
|
|
116
|
+
|
103
117
|
class ExtractedJson(BaseModel):
|
104
118
|
data: Dict
|
105
119
|
|
120
|
+
|
106
121
|
class OutputResponse(BaseModel):
|
107
122
|
chunks: List[Chunk]
|
108
123
|
extracted_json: Optional[ExtractedJson] = Field(default=None)
|
109
124
|
|
125
|
+
|
110
126
|
class Model(str, Enum):
|
111
127
|
FAST = "Fast"
|
112
128
|
HIGH_QUALITY = "HighQuality"
|
113
129
|
|
130
|
+
class PipelineType(str, Enum):
|
131
|
+
AZURE = "Azure"
|
132
|
+
|
114
133
|
class Configuration(BaseModel):
|
115
134
|
chunk_processing: Optional[ChunkProcessing] = Field(default=None)
|
116
135
|
expires_in: Optional[int] = Field(default=None)
|
117
136
|
high_resolution: Optional[bool] = Field(default=None)
|
118
|
-
json_schema: Optional[Union[JsonSchema, Type[BaseModel], BaseModel]] = Field(
|
137
|
+
json_schema: Optional[Union[JsonSchema, Type[BaseModel], BaseModel]] = Field(
|
138
|
+
default=None
|
139
|
+
)
|
119
140
|
model: Optional[Model] = Field(default=None)
|
120
141
|
ocr_strategy: Optional[OcrStrategy] = Field(default=None)
|
121
142
|
segment_processing: Optional[SegmentProcessing] = Field(default=None)
|
122
143
|
segmentation_strategy: Optional[SegmentationStrategy] = Field(default=None)
|
144
|
+
pipeline: Optional[PipelineType] = Field(default=None)
|
123
145
|
|
124
|
-
@model_validator(mode=
|
146
|
+
@model_validator(mode="before")
|
125
147
|
def map_deprecated_fields(cls, values: Dict) -> Dict:
|
126
148
|
if isinstance(values, dict) and "target_chunk_length" in values:
|
127
149
|
target_length = values.pop("target_chunk_length")
|
@@ -130,13 +152,18 @@ class Configuration(BaseModel):
|
|
130
152
|
values["chunk_processing"]["target_length"] = target_length
|
131
153
|
return values
|
132
154
|
|
133
|
-
@model_validator(mode=
|
134
|
-
def convert_json_schema(self) ->
|
135
|
-
if self.json_schema is not None and not isinstance(
|
136
|
-
|
155
|
+
@model_validator(mode="after")
|
156
|
+
def convert_json_schema(self) -> "Configuration":
|
157
|
+
if self.json_schema is not None and not isinstance(
|
158
|
+
self.json_schema, JsonSchema
|
159
|
+
):
|
160
|
+
if isinstance(self.json_schema, (BaseModel, type)) and issubclass(
|
161
|
+
getattr(self.json_schema, "__class__", type), BaseModel
|
162
|
+
):
|
137
163
|
self.json_schema = JsonSchema(**from_pydantic(self.json_schema))
|
138
164
|
return self
|
139
165
|
|
166
|
+
|
140
167
|
class Status(str, Enum):
|
141
168
|
STARTING = "Starting"
|
142
169
|
PROCESSING = "Processing"
|
chunkr_ai/api/misc.py
CHANGED
@@ -3,71 +3,78 @@ import io
|
|
3
3
|
import json
|
4
4
|
from pathlib import Path
|
5
5
|
from PIL import Image
|
6
|
+
from pydantic import BaseModel
|
6
7
|
import requests
|
7
8
|
from typing import Union, Tuple, BinaryIO, Optional
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
) -> Tuple[str, BinaryIO]:
|
10
|
+
|
11
|
+
def prepare_file(file: Union[str, Path, BinaryIO, Image.Image]) -> Tuple[str, BinaryIO]:
|
12
12
|
"""Convert various file types into a tuple of (filename, file-like object)."""
|
13
13
|
# Handle URLs
|
14
|
-
if isinstance(file, str) and (
|
14
|
+
if isinstance(file, str) and (
|
15
|
+
file.startswith("http://") or file.startswith("https://")
|
16
|
+
):
|
15
17
|
response = requests.get(file)
|
16
18
|
response.raise_for_status()
|
17
|
-
|
19
|
+
|
18
20
|
# Try to get filename from Content-Disposition header first
|
19
21
|
filename = None
|
20
|
-
content_disposition = response.headers.get(
|
21
|
-
if content_disposition and
|
22
|
-
filename = content_disposition.split(
|
23
|
-
|
22
|
+
content_disposition = response.headers.get("Content-Disposition")
|
23
|
+
if content_disposition and "filename=" in content_disposition:
|
24
|
+
filename = content_disposition.split("filename=")[-1].strip("\"'")
|
25
|
+
|
24
26
|
# If no Content-Disposition, try to get clean filename from URL path
|
25
27
|
if not filename:
|
26
28
|
from urllib.parse import urlparse, unquote
|
29
|
+
|
27
30
|
parsed_url = urlparse(file)
|
28
31
|
path = unquote(parsed_url.path)
|
29
32
|
filename = Path(path).name if path else None
|
30
|
-
|
33
|
+
|
31
34
|
# Fallback to default name if we couldn't extract one
|
32
|
-
filename = filename or
|
33
|
-
|
35
|
+
filename = filename or "downloaded_file"
|
36
|
+
|
34
37
|
# Sanitize filename: remove invalid characters and limit length
|
35
38
|
import re
|
36
|
-
|
37
|
-
filename = re.sub(
|
38
|
-
|
39
|
-
|
40
|
-
|
39
|
+
|
40
|
+
filename = re.sub(
|
41
|
+
r'[<>:"/\\|?*%]', "_", filename
|
42
|
+
) # Replace invalid chars with underscore
|
43
|
+
filename = re.sub(r"\s+", "_", filename) # Replace whitespace with underscore
|
44
|
+
filename = filename.strip("._") # Remove leading/trailing dots and underscores
|
45
|
+
filename = filename[:255] # Limit length to 255 characters
|
46
|
+
|
41
47
|
file_obj = io.BytesIO(response.content)
|
42
48
|
return filename, file_obj
|
43
49
|
|
44
50
|
# Handle base64 strings
|
45
|
-
if isinstance(file, str) and
|
51
|
+
if isinstance(file, str) and "," in file and ";base64," in file:
|
46
52
|
try:
|
47
53
|
# Split header and data
|
48
|
-
header, base64_data = file.split(
|
54
|
+
header, base64_data = file.split(",", 1)
|
49
55
|
import base64
|
56
|
+
|
50
57
|
file_bytes = base64.b64decode(base64_data)
|
51
58
|
file_obj = io.BytesIO(file_bytes)
|
52
|
-
|
59
|
+
|
53
60
|
# Try to determine format from header
|
54
|
-
format =
|
55
|
-
mime_type = header.split(
|
56
|
-
|
61
|
+
format = "bin"
|
62
|
+
mime_type = header.split(":")[-1].split(";")[0].lower()
|
63
|
+
|
57
64
|
# Map MIME types to file extensions
|
58
65
|
mime_to_ext = {
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
66
|
+
"application/pdf": "pdf",
|
67
|
+
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
|
68
|
+
"application/msword": "doc",
|
69
|
+
"application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx",
|
70
|
+
"application/vnd.ms-powerpoint": "ppt",
|
71
|
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
72
|
+
"application/vnd.ms-excel": "xls",
|
73
|
+
"image/jpeg": "jpg",
|
74
|
+
"image/png": "png",
|
75
|
+
"image/jpg": "jpg",
|
69
76
|
}
|
70
|
-
|
77
|
+
|
71
78
|
if mime_type in mime_to_ext:
|
72
79
|
format = mime_to_ext[mime_type]
|
73
80
|
else:
|
@@ -82,36 +89,37 @@ def prepare_file(
|
|
82
89
|
path = Path(file).resolve()
|
83
90
|
if not path.exists():
|
84
91
|
raise FileNotFoundError(f"File not found: {file}")
|
85
|
-
return path.name, open(path,
|
92
|
+
return path.name, open(path, "rb")
|
86
93
|
|
87
94
|
# Handle PIL Images
|
88
95
|
if isinstance(file, Image.Image):
|
89
96
|
img_byte_arr = io.BytesIO()
|
90
|
-
format = file.format or
|
97
|
+
format = file.format or "PNG"
|
91
98
|
file.save(img_byte_arr, format=format)
|
92
99
|
img_byte_arr.seek(0)
|
93
100
|
return f"image.{format.lower()}", img_byte_arr
|
94
101
|
|
95
102
|
# Handle file-like objects
|
96
|
-
if hasattr(file,
|
103
|
+
if hasattr(file, "read") and hasattr(file, "seek"):
|
97
104
|
# Try to get the filename from the file object if possible
|
98
|
-
name =
|
105
|
+
name = (
|
106
|
+
getattr(file, "name", "document") if hasattr(file, "name") else "document"
|
107
|
+
)
|
99
108
|
return Path(name).name, file
|
100
109
|
|
101
110
|
raise TypeError(f"Unsupported file type: {type(file)}")
|
102
111
|
|
103
112
|
|
104
|
-
|
105
113
|
def prepare_upload_data(
|
106
114
|
file: Optional[Union[str, Path, BinaryIO, Image.Image]] = None,
|
107
|
-
config: Optional[Configuration] = None
|
115
|
+
config: Optional[Configuration] = None,
|
108
116
|
) -> dict:
|
109
117
|
"""Prepare files and data dictionaries for upload.
|
110
|
-
|
118
|
+
|
111
119
|
Args:
|
112
120
|
file: The file to upload
|
113
121
|
config: Optional configuration settings
|
114
|
-
|
122
|
+
|
115
123
|
Returns:
|
116
124
|
dict: (files dict) ready for upload
|
117
125
|
"""
|
@@ -123,6 +131,6 @@ def prepare_upload_data(
|
|
123
131
|
if config:
|
124
132
|
config_dict = config.model_dump(mode="json", exclude_none=True)
|
125
133
|
for key, value in config_dict.items():
|
126
|
-
files[key] = (None, json.dumps(value),
|
127
|
-
|
134
|
+
files[key] = (None, json.dumps(value), "application/json")
|
135
|
+
|
128
136
|
return files
|
chunkr_ai/api/protocol.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
from typing import Optional, runtime_checkable, Protocol
|
2
2
|
from requests import Session
|
3
|
-
from
|
3
|
+
from httpx import AsyncClient
|
4
|
+
|
4
5
|
|
5
6
|
@runtime_checkable
|
6
7
|
class ChunkrClientProtocol(Protocol):
|
7
8
|
"""Protocol defining the interface for Chunkr clients"""
|
9
|
+
|
8
10
|
url: str
|
9
11
|
_api_key: str
|
10
12
|
_session: Optional[Session] = None
|
11
|
-
_client: Optional[
|
13
|
+
_client: Optional[AsyncClient] = None
|
12
14
|
|
13
15
|
def get_api_key(self) -> str:
|
14
16
|
"""Get the API key"""
|
@@ -16,4 +18,4 @@ class ChunkrClientProtocol(Protocol):
|
|
16
18
|
|
17
19
|
def _headers(self) -> dict:
|
18
20
|
"""Return headers required for API requests"""
|
19
|
-
...
|
21
|
+
...
|
chunkr_ai/api/schema.py
CHANGED
@@ -2,17 +2,22 @@ from pydantic import BaseModel
|
|
2
2
|
from typing import Optional, List, Union, Type
|
3
3
|
import json
|
4
4
|
|
5
|
+
|
5
6
|
class Property(BaseModel):
|
6
7
|
name: str
|
7
8
|
prop_type: str
|
8
9
|
description: Optional[str] = None
|
9
10
|
default: Optional[str] = None
|
10
11
|
|
12
|
+
|
11
13
|
class JsonSchema(BaseModel):
|
12
14
|
title: str
|
13
15
|
properties: List[Property]
|
14
16
|
|
15
|
-
|
17
|
+
|
18
|
+
def from_pydantic(
|
19
|
+
pydantic: Union[BaseModel, Type[BaseModel]], current_depth: int = 0
|
20
|
+
) -> dict:
|
16
21
|
"""Convert a Pydantic model to a Chunk json schema."""
|
17
22
|
MAX_DEPTH = 5
|
18
23
|
model = pydantic if isinstance(pydantic, type) else pydantic.__class__
|
@@ -21,108 +26,111 @@ def from_pydantic(pydantic: Union[BaseModel, Type[BaseModel]], current_depth: in
|
|
21
26
|
|
22
27
|
def get_enum_description(details: dict) -> str:
|
23
28
|
"""Get description including enum values if they exist"""
|
24
|
-
description = details.get(
|
25
|
-
|
29
|
+
description = details.get("description", "")
|
30
|
+
|
26
31
|
# First check if this is a direct enum
|
27
|
-
if
|
28
|
-
enum_values = details[
|
29
|
-
enum_str =
|
32
|
+
if "enum" in details:
|
33
|
+
enum_values = details["enum"]
|
34
|
+
enum_str = "\nAllowed values:\n" + "\n".join(
|
35
|
+
f"- {val}" for val in enum_values
|
36
|
+
)
|
30
37
|
return f"{description}{enum_str}"
|
31
|
-
|
38
|
+
|
32
39
|
# Then check if it's a reference to an enum
|
33
|
-
if
|
34
|
-
ref_schema = resolve_ref(details[
|
35
|
-
if
|
36
|
-
enum_values = ref_schema[
|
37
|
-
enum_str =
|
40
|
+
if "$ref" in details:
|
41
|
+
ref_schema = resolve_ref(details["$ref"], schema.get("$defs", {}))
|
42
|
+
if "enum" in ref_schema:
|
43
|
+
enum_values = ref_schema["enum"]
|
44
|
+
enum_str = "\nAllowed values:\n" + "\n".join(
|
45
|
+
f"- {val}" for val in enum_values
|
46
|
+
)
|
38
47
|
return f"{description}{enum_str}"
|
39
|
-
|
48
|
+
|
40
49
|
return description
|
41
50
|
|
42
51
|
def resolve_ref(ref: str, definitions: dict) -> dict:
|
43
52
|
"""Resolve a $ref reference to its actual schema"""
|
44
|
-
if not ref.startswith(
|
53
|
+
if not ref.startswith("#/$defs/"):
|
45
54
|
return {}
|
46
|
-
ref_name = ref[len(
|
55
|
+
ref_name = ref[len("#/$defs/") :]
|
47
56
|
return definitions.get(ref_name, {})
|
48
57
|
|
49
58
|
def get_nested_schema(field_schema: dict, depth: int) -> dict:
|
50
59
|
if depth >= MAX_DEPTH:
|
51
60
|
return {}
|
52
|
-
|
61
|
+
|
53
62
|
# If there's a $ref, resolve it first
|
54
|
-
if
|
55
|
-
field_schema = resolve_ref(field_schema[
|
56
|
-
|
63
|
+
if "$ref" in field_schema:
|
64
|
+
field_schema = resolve_ref(field_schema["$ref"], schema.get("$defs", {}))
|
65
|
+
|
57
66
|
nested_props = {}
|
58
|
-
if field_schema.get(
|
59
|
-
for name, details in field_schema.get(
|
60
|
-
if details.get(
|
67
|
+
if field_schema.get("type") == "object":
|
68
|
+
for name, details in field_schema.get("properties", {}).items():
|
69
|
+
if details.get("type") == "object" or "$ref" in details:
|
61
70
|
ref_schema = details
|
62
|
-
if
|
63
|
-
ref_schema = resolve_ref(
|
71
|
+
if "$ref" in details:
|
72
|
+
ref_schema = resolve_ref(
|
73
|
+
details["$ref"], schema.get("$defs", {})
|
74
|
+
)
|
64
75
|
nested_schema = get_nested_schema(ref_schema, depth + 1)
|
65
76
|
nested_props[name] = {
|
66
|
-
|
67
|
-
|
68
|
-
|
77
|
+
"type": "object",
|
78
|
+
"description": get_enum_description(details),
|
79
|
+
"properties": nested_schema,
|
69
80
|
}
|
70
81
|
else:
|
71
82
|
nested_props[name] = {
|
72
|
-
|
73
|
-
|
83
|
+
"type": details.get("type", "string"),
|
84
|
+
"description": get_enum_description(details),
|
74
85
|
}
|
75
86
|
return nested_props
|
76
87
|
|
77
|
-
for name, details in schema.get(
|
88
|
+
for name, details in schema.get("properties", {}).items():
|
78
89
|
# Handle arrays
|
79
|
-
if details.get(
|
80
|
-
items = details.get(
|
81
|
-
if
|
82
|
-
items = resolve_ref(items[
|
83
|
-
|
90
|
+
if details.get("type") == "array":
|
91
|
+
items = details.get("items", {})
|
92
|
+
if "$ref" in items:
|
93
|
+
items = resolve_ref(items["$ref"], schema.get("$defs", {}))
|
94
|
+
|
84
95
|
# Get nested schema for array items
|
85
96
|
item_schema = get_nested_schema(items, current_depth)
|
86
97
|
description = get_enum_description(details)
|
87
|
-
|
98
|
+
|
88
99
|
if item_schema:
|
89
100
|
description = f"{description}\nList items schema:\n{json.dumps(item_schema, indent=2)}"
|
90
|
-
|
91
|
-
prop = Property(
|
92
|
-
name=name,
|
93
|
-
prop_type='list',
|
94
|
-
description=description
|
95
|
-
)
|
101
|
+
|
102
|
+
prop = Property(name=name, prop_type="list", description=description)
|
96
103
|
# Handle objects and references
|
97
|
-
elif details.get(
|
98
|
-
prop_type =
|
104
|
+
elif details.get("type") == "object" or "$ref" in details:
|
105
|
+
prop_type = "object"
|
99
106
|
ref_schema = details
|
100
|
-
if
|
101
|
-
ref_schema = resolve_ref(details[
|
102
|
-
|
107
|
+
if "$ref" in details:
|
108
|
+
ref_schema = resolve_ref(details["$ref"], schema.get("$defs", {}))
|
109
|
+
|
103
110
|
nested_schema = get_nested_schema(ref_schema, current_depth)
|
104
|
-
|
111
|
+
|
105
112
|
prop = Property(
|
106
113
|
name=name,
|
107
114
|
prop_type=prop_type,
|
108
115
|
description=get_enum_description(details),
|
109
|
-
properties=nested_schema
|
116
|
+
properties=nested_schema,
|
110
117
|
)
|
111
|
-
|
118
|
+
|
112
119
|
# Handle primitive types
|
113
120
|
else:
|
114
121
|
prop = Property(
|
115
122
|
name=name,
|
116
|
-
prop_type=details.get(
|
123
|
+
prop_type=details.get("type", "string"),
|
117
124
|
description=get_enum_description(details),
|
118
|
-
default=str(details.get(
|
125
|
+
default=str(details.get("default"))
|
126
|
+
if details.get("default") is not None
|
127
|
+
else None,
|
119
128
|
)
|
120
|
-
|
129
|
+
|
121
130
|
properties.append(prop)
|
122
|
-
|
131
|
+
|
123
132
|
json_schema = JsonSchema(
|
124
|
-
title=schema.get(
|
125
|
-
properties=properties
|
133
|
+
title=schema.get("title", model.__name__), properties=properties
|
126
134
|
)
|
127
|
-
|
128
|
-
return json_schema.model_dump(mode="json", exclude_none=True)
|
135
|
+
|
136
|
+
return json_schema.model_dump(mode="json", exclude_none=True)
|
chunkr_ai/api/task.py
CHANGED
@@ -3,24 +3,27 @@ from .misc import prepare_upload_data
|
|
3
3
|
from .task_base import TaskBase
|
4
4
|
import time
|
5
5
|
|
6
|
+
|
6
7
|
class TaskResponse(TaskBase):
|
7
8
|
def _poll_request(self) -> dict:
|
8
9
|
while True:
|
9
10
|
try:
|
10
11
|
if not self.task_url:
|
11
|
-
|
12
|
+
raise ValueError("Task URL not found in response")
|
12
13
|
if not self._client._session:
|
13
14
|
raise ValueError("Client session not found")
|
14
|
-
r = self._client._session.get(
|
15
|
+
r = self._client._session.get(
|
16
|
+
self.task_url, headers=self._client._headers()
|
17
|
+
)
|
15
18
|
r.raise_for_status()
|
16
19
|
return r.json()
|
17
20
|
except (ConnectionError, TimeoutError) as _:
|
18
21
|
print("Connection error while polling the task, retrying...")
|
19
22
|
time.sleep(0.5)
|
20
|
-
except Exception
|
23
|
+
except Exception:
|
21
24
|
raise
|
22
25
|
|
23
|
-
def poll(self) ->
|
26
|
+
def poll(self) -> "TaskResponse":
|
24
27
|
while True:
|
25
28
|
response = self._poll_request()
|
26
29
|
updated_task = TaskResponse(**response).with_client(self._client)
|
@@ -28,31 +31,28 @@ class TaskResponse(TaskBase):
|
|
28
31
|
if result := self._check_status():
|
29
32
|
return result
|
30
33
|
time.sleep(0.5)
|
31
|
-
|
32
|
-
def update(self, config: Configuration) ->
|
34
|
+
|
35
|
+
def update(self, config: Configuration) -> "TaskResponse":
|
33
36
|
if not self.task_url:
|
34
37
|
raise ValueError("Task URL not found")
|
35
38
|
if not self._client._session:
|
36
39
|
raise ValueError("Client session not found")
|
37
40
|
files = prepare_upload_data(None, config)
|
38
41
|
r = self._client._session.patch(
|
39
|
-
self.task_url,
|
40
|
-
files=files,
|
41
|
-
headers=self._client._headers()
|
42
|
+
self.task_url, files=files, headers=self._client._headers()
|
42
43
|
)
|
43
44
|
r.raise_for_status()
|
44
45
|
updated = TaskResponse(**r.json()).with_client(self._client)
|
45
46
|
self.__dict__.update(updated.__dict__)
|
46
47
|
return self.poll()
|
47
|
-
|
48
|
+
|
48
49
|
def cancel(self):
|
49
50
|
if not self.task_url:
|
50
51
|
raise ValueError("Task URL not found")
|
51
52
|
if not self._client._session:
|
52
53
|
raise ValueError("Client session not found")
|
53
54
|
r = self._client._session.get(
|
54
|
-
f"{self.task_url}/cancel",
|
55
|
-
headers=self._client._headers()
|
55
|
+
f"{self.task_url}/cancel", headers=self._client._headers()
|
56
56
|
)
|
57
57
|
r.raise_for_status()
|
58
58
|
self.poll()
|
@@ -62,8 +62,5 @@ class TaskResponse(TaskBase):
|
|
62
62
|
raise ValueError("Task URL not found")
|
63
63
|
if not self._client._session:
|
64
64
|
raise ValueError("Client session not found")
|
65
|
-
r = self._client._session.delete(
|
66
|
-
self.task_url,
|
67
|
-
headers=self._client._headers()
|
68
|
-
)
|
65
|
+
r = self._client._session.delete(self.task_url, headers=self._client._headers())
|
69
66
|
r.raise_for_status()
|