chunkr-ai 0.0.14__py3-none-any.whl → 0.0.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chunkr_ai/__init__.py +1 -1
- chunkr_ai/api/auth.py +4 -4
- chunkr_ai/api/base.py +58 -48
- chunkr_ai/api/chunkr.py +21 -20
- chunkr_ai/api/chunkr_async.py +26 -20
- chunkr_ai/api/chunkr_base.py +34 -27
- chunkr_ai/api/config.py +38 -14
- chunkr_ai/api/misc.py +51 -44
- chunkr_ai/api/protocol.py +5 -3
- chunkr_ai/api/schema.py +66 -58
- chunkr_ai/api/task.py +13 -16
- chunkr_ai/api/task_async.py +16 -7
- chunkr_ai/api/task_base.py +2 -1
- chunkr_ai/models.py +21 -22
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.15.dist-info}/METADATA +1 -1
- chunkr_ai-0.0.15.dist-info/RECORD +21 -0
- chunkr_ai-0.0.14.dist-info/RECORD +0 -21
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.15.dist-info}/LICENSE +0 -0
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.15.dist-info}/WHEEL +0 -0
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.15.dist-info}/top_level.txt +0 -0
chunkr_ai/api/config.py
CHANGED
@@ -3,28 +3,31 @@ from enum import Enum
|
|
3
3
|
from typing import Optional, List, Dict, Union, Type
|
4
4
|
from .schema import from_pydantic
|
5
5
|
|
6
|
+
|
6
7
|
class GenerationStrategy(str, Enum):
|
7
8
|
LLM = "LLM"
|
8
9
|
AUTO = "Auto"
|
9
10
|
|
11
|
+
|
10
12
|
class CroppingStrategy(str, Enum):
|
11
|
-
ALL = "All"
|
13
|
+
ALL = "All"
|
12
14
|
AUTO = "Auto"
|
13
15
|
|
16
|
+
|
14
17
|
class GenerationConfig(BaseModel):
|
15
18
|
html: Optional[GenerationStrategy] = None
|
16
19
|
llm: Optional[str] = None
|
17
20
|
markdown: Optional[GenerationStrategy] = None
|
18
21
|
crop_image: Optional[CroppingStrategy] = None
|
19
22
|
|
23
|
+
|
20
24
|
class SegmentProcessing(BaseModel):
|
21
|
-
model_config = ConfigDict(
|
22
|
-
|
23
|
-
alias_generator=str.title
|
24
|
-
)
|
25
|
-
|
25
|
+
model_config = ConfigDict(populate_by_name=True, alias_generator=str.title)
|
26
|
+
|
26
27
|
title: Optional[GenerationConfig] = Field(default=None, alias="Title")
|
27
|
-
section_header: Optional[GenerationConfig] = Field(
|
28
|
+
section_header: Optional[GenerationConfig] = Field(
|
29
|
+
default=None, alias="SectionHeader"
|
30
|
+
)
|
28
31
|
text: Optional[GenerationConfig] = Field(default=None, alias="Text")
|
29
32
|
list_item: Optional[GenerationConfig] = Field(default=None, alias="ListItem")
|
30
33
|
table: Optional[GenerationConfig] = Field(default=None, alias="Table")
|
@@ -36,38 +39,46 @@ class SegmentProcessing(BaseModel):
|
|
36
39
|
page_footer: Optional[GenerationConfig] = Field(default=None, alias="PageFooter")
|
37
40
|
page: Optional[GenerationConfig] = Field(default=None, alias="Page")
|
38
41
|
|
42
|
+
|
39
43
|
class ChunkProcessing(BaseModel):
|
40
44
|
target_length: Optional[int] = None
|
41
45
|
|
46
|
+
|
42
47
|
class Property(BaseModel):
|
43
48
|
name: str
|
44
49
|
prop_type: str
|
45
50
|
description: Optional[str] = None
|
46
51
|
default: Optional[str] = None
|
47
52
|
|
53
|
+
|
48
54
|
class JsonSchema(BaseModel):
|
49
55
|
title: str
|
50
56
|
properties: List[Property]
|
51
57
|
|
58
|
+
|
52
59
|
class OcrStrategy(str, Enum):
|
53
60
|
ALL = "All"
|
54
61
|
AUTO = "Auto"
|
55
|
-
|
62
|
+
|
63
|
+
|
56
64
|
class SegmentationStrategy(str, Enum):
|
57
65
|
LAYOUT_ANALYSIS = "LayoutAnalysis"
|
58
66
|
PAGE = "Page"
|
59
67
|
|
68
|
+
|
60
69
|
class BoundingBox(BaseModel):
|
61
70
|
left: float
|
62
71
|
top: float
|
63
72
|
width: float
|
64
73
|
height: float
|
65
74
|
|
75
|
+
|
66
76
|
class OCRResult(BaseModel):
|
67
77
|
bbox: BoundingBox
|
68
78
|
text: str
|
69
79
|
confidence: Optional[float]
|
70
80
|
|
81
|
+
|
71
82
|
class SegmentType(str, Enum):
|
72
83
|
CAPTION = "Caption"
|
73
84
|
FOOTNOTE = "Footnote"
|
@@ -82,6 +93,7 @@ class SegmentType(str, Enum):
|
|
82
93
|
TEXT = "Text"
|
83
94
|
TITLE = "Title"
|
84
95
|
|
96
|
+
|
85
97
|
class Segment(BaseModel):
|
86
98
|
bbox: BoundingBox
|
87
99
|
content: str
|
@@ -95,33 +107,40 @@ class Segment(BaseModel):
|
|
95
107
|
segment_id: str
|
96
108
|
segment_type: SegmentType
|
97
109
|
|
110
|
+
|
98
111
|
class Chunk(BaseModel):
|
99
112
|
chunk_id: str
|
100
113
|
chunk_length: int
|
101
114
|
segments: List[Segment]
|
102
115
|
|
116
|
+
|
103
117
|
class ExtractedJson(BaseModel):
|
104
118
|
data: Dict
|
105
119
|
|
120
|
+
|
106
121
|
class OutputResponse(BaseModel):
|
107
122
|
chunks: List[Chunk]
|
108
123
|
extracted_json: Optional[ExtractedJson] = Field(default=None)
|
109
124
|
|
125
|
+
|
110
126
|
class Model(str, Enum):
|
111
127
|
FAST = "Fast"
|
112
128
|
HIGH_QUALITY = "HighQuality"
|
113
129
|
|
130
|
+
|
114
131
|
class Configuration(BaseModel):
|
115
132
|
chunk_processing: Optional[ChunkProcessing] = Field(default=None)
|
116
133
|
expires_in: Optional[int] = Field(default=None)
|
117
134
|
high_resolution: Optional[bool] = Field(default=None)
|
118
|
-
json_schema: Optional[Union[JsonSchema, Type[BaseModel], BaseModel]] = Field(
|
135
|
+
json_schema: Optional[Union[JsonSchema, Type[BaseModel], BaseModel]] = Field(
|
136
|
+
default=None
|
137
|
+
)
|
119
138
|
model: Optional[Model] = Field(default=None)
|
120
139
|
ocr_strategy: Optional[OcrStrategy] = Field(default=None)
|
121
140
|
segment_processing: Optional[SegmentProcessing] = Field(default=None)
|
122
141
|
segmentation_strategy: Optional[SegmentationStrategy] = Field(default=None)
|
123
142
|
|
124
|
-
@model_validator(mode=
|
143
|
+
@model_validator(mode="before")
|
125
144
|
def map_deprecated_fields(cls, values: Dict) -> Dict:
|
126
145
|
if isinstance(values, dict) and "target_chunk_length" in values:
|
127
146
|
target_length = values.pop("target_chunk_length")
|
@@ -130,13 +149,18 @@ class Configuration(BaseModel):
|
|
130
149
|
values["chunk_processing"]["target_length"] = target_length
|
131
150
|
return values
|
132
151
|
|
133
|
-
@model_validator(mode=
|
134
|
-
def convert_json_schema(self) ->
|
135
|
-
if self.json_schema is not None and not isinstance(
|
136
|
-
|
152
|
+
@model_validator(mode="after")
|
153
|
+
def convert_json_schema(self) -> "Configuration":
|
154
|
+
if self.json_schema is not None and not isinstance(
|
155
|
+
self.json_schema, JsonSchema
|
156
|
+
):
|
157
|
+
if isinstance(self.json_schema, (BaseModel, type)) and issubclass(
|
158
|
+
getattr(self.json_schema, "__class__", type), BaseModel
|
159
|
+
):
|
137
160
|
self.json_schema = JsonSchema(**from_pydantic(self.json_schema))
|
138
161
|
return self
|
139
162
|
|
163
|
+
|
140
164
|
class Status(str, Enum):
|
141
165
|
STARTING = "Starting"
|
142
166
|
PROCESSING = "Processing"
|
chunkr_ai/api/misc.py
CHANGED
@@ -6,68 +6,74 @@ from PIL import Image
|
|
6
6
|
import requests
|
7
7
|
from typing import Union, Tuple, BinaryIO, Optional
|
8
8
|
|
9
|
-
|
10
|
-
|
11
|
-
) -> Tuple[str, BinaryIO]:
|
9
|
+
|
10
|
+
def prepare_file(file: Union[str, Path, BinaryIO, Image.Image]) -> Tuple[str, BinaryIO]:
|
12
11
|
"""Convert various file types into a tuple of (filename, file-like object)."""
|
13
12
|
# Handle URLs
|
14
|
-
if isinstance(file, str) and (
|
13
|
+
if isinstance(file, str) and (
|
14
|
+
file.startswith("http://") or file.startswith("https://")
|
15
|
+
):
|
15
16
|
response = requests.get(file)
|
16
17
|
response.raise_for_status()
|
17
|
-
|
18
|
+
|
18
19
|
# Try to get filename from Content-Disposition header first
|
19
20
|
filename = None
|
20
|
-
content_disposition = response.headers.get(
|
21
|
-
if content_disposition and
|
22
|
-
filename = content_disposition.split(
|
23
|
-
|
21
|
+
content_disposition = response.headers.get("Content-Disposition")
|
22
|
+
if content_disposition and "filename=" in content_disposition:
|
23
|
+
filename = content_disposition.split("filename=")[-1].strip("\"'")
|
24
|
+
|
24
25
|
# If no Content-Disposition, try to get clean filename from URL path
|
25
26
|
if not filename:
|
26
27
|
from urllib.parse import urlparse, unquote
|
28
|
+
|
27
29
|
parsed_url = urlparse(file)
|
28
30
|
path = unquote(parsed_url.path)
|
29
31
|
filename = Path(path).name if path else None
|
30
|
-
|
32
|
+
|
31
33
|
# Fallback to default name if we couldn't extract one
|
32
|
-
filename = filename or
|
33
|
-
|
34
|
+
filename = filename or "downloaded_file"
|
35
|
+
|
34
36
|
# Sanitize filename: remove invalid characters and limit length
|
35
37
|
import re
|
36
|
-
|
37
|
-
filename = re.sub(
|
38
|
-
|
39
|
-
|
40
|
-
|
38
|
+
|
39
|
+
filename = re.sub(
|
40
|
+
r'[<>:"/\\|?*%]', "_", filename
|
41
|
+
) # Replace invalid chars with underscore
|
42
|
+
filename = re.sub(r"\s+", "_", filename) # Replace whitespace with underscore
|
43
|
+
filename = filename.strip("._") # Remove leading/trailing dots and underscores
|
44
|
+
filename = filename[:255] # Limit length to 255 characters
|
45
|
+
|
41
46
|
file_obj = io.BytesIO(response.content)
|
42
47
|
return filename, file_obj
|
43
48
|
|
44
49
|
# Handle base64 strings
|
45
|
-
if isinstance(file, str) and
|
50
|
+
if isinstance(file, str) and "," in file and ";base64," in file:
|
46
51
|
try:
|
47
52
|
# Split header and data
|
48
|
-
header, base64_data = file.split(
|
53
|
+
header, base64_data = file.split(",", 1)
|
49
54
|
import base64
|
55
|
+
|
50
56
|
file_bytes = base64.b64decode(base64_data)
|
51
57
|
file_obj = io.BytesIO(file_bytes)
|
52
|
-
|
58
|
+
|
53
59
|
# Try to determine format from header
|
54
|
-
format =
|
55
|
-
mime_type = header.split(
|
56
|
-
|
60
|
+
format = "bin"
|
61
|
+
mime_type = header.split(":")[-1].split(";")[0].lower()
|
62
|
+
|
57
63
|
# Map MIME types to file extensions
|
58
64
|
mime_to_ext = {
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
65
|
+
"application/pdf": "pdf",
|
66
|
+
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
|
67
|
+
"application/msword": "doc",
|
68
|
+
"application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx",
|
69
|
+
"application/vnd.ms-powerpoint": "ppt",
|
70
|
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
71
|
+
"application/vnd.ms-excel": "xls",
|
72
|
+
"image/jpeg": "jpg",
|
73
|
+
"image/png": "png",
|
74
|
+
"image/jpg": "jpg",
|
69
75
|
}
|
70
|
-
|
76
|
+
|
71
77
|
if mime_type in mime_to_ext:
|
72
78
|
format = mime_to_ext[mime_type]
|
73
79
|
else:
|
@@ -82,36 +88,37 @@ def prepare_file(
|
|
82
88
|
path = Path(file).resolve()
|
83
89
|
if not path.exists():
|
84
90
|
raise FileNotFoundError(f"File not found: {file}")
|
85
|
-
return path.name, open(path,
|
91
|
+
return path.name, open(path, "rb")
|
86
92
|
|
87
93
|
# Handle PIL Images
|
88
94
|
if isinstance(file, Image.Image):
|
89
95
|
img_byte_arr = io.BytesIO()
|
90
|
-
format = file.format or
|
96
|
+
format = file.format or "PNG"
|
91
97
|
file.save(img_byte_arr, format=format)
|
92
98
|
img_byte_arr.seek(0)
|
93
99
|
return f"image.{format.lower()}", img_byte_arr
|
94
100
|
|
95
101
|
# Handle file-like objects
|
96
|
-
if hasattr(file,
|
102
|
+
if hasattr(file, "read") and hasattr(file, "seek"):
|
97
103
|
# Try to get the filename from the file object if possible
|
98
|
-
name =
|
104
|
+
name = (
|
105
|
+
getattr(file, "name", "document") if hasattr(file, "name") else "document"
|
106
|
+
)
|
99
107
|
return Path(name).name, file
|
100
108
|
|
101
109
|
raise TypeError(f"Unsupported file type: {type(file)}")
|
102
110
|
|
103
111
|
|
104
|
-
|
105
112
|
def prepare_upload_data(
|
106
113
|
file: Optional[Union[str, Path, BinaryIO, Image.Image]] = None,
|
107
|
-
config: Optional[Configuration] = None
|
114
|
+
config: Optional[Configuration] = None,
|
108
115
|
) -> dict:
|
109
116
|
"""Prepare files and data dictionaries for upload.
|
110
|
-
|
117
|
+
|
111
118
|
Args:
|
112
119
|
file: The file to upload
|
113
120
|
config: Optional configuration settings
|
114
|
-
|
121
|
+
|
115
122
|
Returns:
|
116
123
|
dict: (files dict) ready for upload
|
117
124
|
"""
|
@@ -123,6 +130,6 @@ def prepare_upload_data(
|
|
123
130
|
if config:
|
124
131
|
config_dict = config.model_dump(mode="json", exclude_none=True)
|
125
132
|
for key, value in config_dict.items():
|
126
|
-
files[key] = (None, json.dumps(value),
|
127
|
-
|
133
|
+
files[key] = (None, json.dumps(value), "application/json")
|
134
|
+
|
128
135
|
return files
|
chunkr_ai/api/protocol.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
from typing import Optional, runtime_checkable, Protocol
|
2
2
|
from requests import Session
|
3
|
-
from
|
3
|
+
from httpx import AsyncClient
|
4
|
+
|
4
5
|
|
5
6
|
@runtime_checkable
|
6
7
|
class ChunkrClientProtocol(Protocol):
|
7
8
|
"""Protocol defining the interface for Chunkr clients"""
|
9
|
+
|
8
10
|
url: str
|
9
11
|
_api_key: str
|
10
12
|
_session: Optional[Session] = None
|
11
|
-
_client: Optional[
|
13
|
+
_client: Optional[AsyncClient] = None
|
12
14
|
|
13
15
|
def get_api_key(self) -> str:
|
14
16
|
"""Get the API key"""
|
@@ -16,4 +18,4 @@ class ChunkrClientProtocol(Protocol):
|
|
16
18
|
|
17
19
|
def _headers(self) -> dict:
|
18
20
|
"""Return headers required for API requests"""
|
19
|
-
...
|
21
|
+
...
|
chunkr_ai/api/schema.py
CHANGED
@@ -2,17 +2,22 @@ from pydantic import BaseModel
|
|
2
2
|
from typing import Optional, List, Union, Type
|
3
3
|
import json
|
4
4
|
|
5
|
+
|
5
6
|
class Property(BaseModel):
|
6
7
|
name: str
|
7
8
|
prop_type: str
|
8
9
|
description: Optional[str] = None
|
9
10
|
default: Optional[str] = None
|
10
11
|
|
12
|
+
|
11
13
|
class JsonSchema(BaseModel):
|
12
14
|
title: str
|
13
15
|
properties: List[Property]
|
14
16
|
|
15
|
-
|
17
|
+
|
18
|
+
def from_pydantic(
|
19
|
+
pydantic: Union[BaseModel, Type[BaseModel]], current_depth: int = 0
|
20
|
+
) -> dict:
|
16
21
|
"""Convert a Pydantic model to a Chunk json schema."""
|
17
22
|
MAX_DEPTH = 5
|
18
23
|
model = pydantic if isinstance(pydantic, type) else pydantic.__class__
|
@@ -21,108 +26,111 @@ def from_pydantic(pydantic: Union[BaseModel, Type[BaseModel]], current_depth: in
|
|
21
26
|
|
22
27
|
def get_enum_description(details: dict) -> str:
|
23
28
|
"""Get description including enum values if they exist"""
|
24
|
-
description = details.get(
|
25
|
-
|
29
|
+
description = details.get("description", "")
|
30
|
+
|
26
31
|
# First check if this is a direct enum
|
27
|
-
if
|
28
|
-
enum_values = details[
|
29
|
-
enum_str =
|
32
|
+
if "enum" in details:
|
33
|
+
enum_values = details["enum"]
|
34
|
+
enum_str = "\nAllowed values:\n" + "\n".join(
|
35
|
+
f"- {val}" for val in enum_values
|
36
|
+
)
|
30
37
|
return f"{description}{enum_str}"
|
31
|
-
|
38
|
+
|
32
39
|
# Then check if it's a reference to an enum
|
33
|
-
if
|
34
|
-
ref_schema = resolve_ref(details[
|
35
|
-
if
|
36
|
-
enum_values = ref_schema[
|
37
|
-
enum_str =
|
40
|
+
if "$ref" in details:
|
41
|
+
ref_schema = resolve_ref(details["$ref"], schema.get("$defs", {}))
|
42
|
+
if "enum" in ref_schema:
|
43
|
+
enum_values = ref_schema["enum"]
|
44
|
+
enum_str = "\nAllowed values:\n" + "\n".join(
|
45
|
+
f"- {val}" for val in enum_values
|
46
|
+
)
|
38
47
|
return f"{description}{enum_str}"
|
39
|
-
|
48
|
+
|
40
49
|
return description
|
41
50
|
|
42
51
|
def resolve_ref(ref: str, definitions: dict) -> dict:
|
43
52
|
"""Resolve a $ref reference to its actual schema"""
|
44
|
-
if not ref.startswith(
|
53
|
+
if not ref.startswith("#/$defs/"):
|
45
54
|
return {}
|
46
|
-
ref_name = ref[len(
|
55
|
+
ref_name = ref[len("#/$defs/") :]
|
47
56
|
return definitions.get(ref_name, {})
|
48
57
|
|
49
58
|
def get_nested_schema(field_schema: dict, depth: int) -> dict:
|
50
59
|
if depth >= MAX_DEPTH:
|
51
60
|
return {}
|
52
|
-
|
61
|
+
|
53
62
|
# If there's a $ref, resolve it first
|
54
|
-
if
|
55
|
-
field_schema = resolve_ref(field_schema[
|
56
|
-
|
63
|
+
if "$ref" in field_schema:
|
64
|
+
field_schema = resolve_ref(field_schema["$ref"], schema.get("$defs", {}))
|
65
|
+
|
57
66
|
nested_props = {}
|
58
|
-
if field_schema.get(
|
59
|
-
for name, details in field_schema.get(
|
60
|
-
if details.get(
|
67
|
+
if field_schema.get("type") == "object":
|
68
|
+
for name, details in field_schema.get("properties", {}).items():
|
69
|
+
if details.get("type") == "object" or "$ref" in details:
|
61
70
|
ref_schema = details
|
62
|
-
if
|
63
|
-
ref_schema = resolve_ref(
|
71
|
+
if "$ref" in details:
|
72
|
+
ref_schema = resolve_ref(
|
73
|
+
details["$ref"], schema.get("$defs", {})
|
74
|
+
)
|
64
75
|
nested_schema = get_nested_schema(ref_schema, depth + 1)
|
65
76
|
nested_props[name] = {
|
66
|
-
|
67
|
-
|
68
|
-
|
77
|
+
"type": "object",
|
78
|
+
"description": get_enum_description(details),
|
79
|
+
"properties": nested_schema,
|
69
80
|
}
|
70
81
|
else:
|
71
82
|
nested_props[name] = {
|
72
|
-
|
73
|
-
|
83
|
+
"type": details.get("type", "string"),
|
84
|
+
"description": get_enum_description(details),
|
74
85
|
}
|
75
86
|
return nested_props
|
76
87
|
|
77
|
-
for name, details in schema.get(
|
88
|
+
for name, details in schema.get("properties", {}).items():
|
78
89
|
# Handle arrays
|
79
|
-
if details.get(
|
80
|
-
items = details.get(
|
81
|
-
if
|
82
|
-
items = resolve_ref(items[
|
83
|
-
|
90
|
+
if details.get("type") == "array":
|
91
|
+
items = details.get("items", {})
|
92
|
+
if "$ref" in items:
|
93
|
+
items = resolve_ref(items["$ref"], schema.get("$defs", {}))
|
94
|
+
|
84
95
|
# Get nested schema for array items
|
85
96
|
item_schema = get_nested_schema(items, current_depth)
|
86
97
|
description = get_enum_description(details)
|
87
|
-
|
98
|
+
|
88
99
|
if item_schema:
|
89
100
|
description = f"{description}\nList items schema:\n{json.dumps(item_schema, indent=2)}"
|
90
|
-
|
91
|
-
prop = Property(
|
92
|
-
name=name,
|
93
|
-
prop_type='list',
|
94
|
-
description=description
|
95
|
-
)
|
101
|
+
|
102
|
+
prop = Property(name=name, prop_type="list", description=description)
|
96
103
|
# Handle objects and references
|
97
|
-
elif details.get(
|
98
|
-
prop_type =
|
104
|
+
elif details.get("type") == "object" or "$ref" in details:
|
105
|
+
prop_type = "object"
|
99
106
|
ref_schema = details
|
100
|
-
if
|
101
|
-
ref_schema = resolve_ref(details[
|
102
|
-
|
107
|
+
if "$ref" in details:
|
108
|
+
ref_schema = resolve_ref(details["$ref"], schema.get("$defs", {}))
|
109
|
+
|
103
110
|
nested_schema = get_nested_schema(ref_schema, current_depth)
|
104
|
-
|
111
|
+
|
105
112
|
prop = Property(
|
106
113
|
name=name,
|
107
114
|
prop_type=prop_type,
|
108
115
|
description=get_enum_description(details),
|
109
|
-
properties=nested_schema
|
116
|
+
properties=nested_schema,
|
110
117
|
)
|
111
|
-
|
118
|
+
|
112
119
|
# Handle primitive types
|
113
120
|
else:
|
114
121
|
prop = Property(
|
115
122
|
name=name,
|
116
|
-
prop_type=details.get(
|
123
|
+
prop_type=details.get("type", "string"),
|
117
124
|
description=get_enum_description(details),
|
118
|
-
default=str(details.get(
|
125
|
+
default=str(details.get("default"))
|
126
|
+
if details.get("default") is not None
|
127
|
+
else None,
|
119
128
|
)
|
120
|
-
|
129
|
+
|
121
130
|
properties.append(prop)
|
122
|
-
|
131
|
+
|
123
132
|
json_schema = JsonSchema(
|
124
|
-
title=schema.get(
|
125
|
-
properties=properties
|
133
|
+
title=schema.get("title", model.__name__), properties=properties
|
126
134
|
)
|
127
|
-
|
128
|
-
return json_schema.model_dump(mode="json", exclude_none=True)
|
135
|
+
|
136
|
+
return json_schema.model_dump(mode="json", exclude_none=True)
|
chunkr_ai/api/task.py
CHANGED
@@ -3,24 +3,27 @@ from .misc import prepare_upload_data
|
|
3
3
|
from .task_base import TaskBase
|
4
4
|
import time
|
5
5
|
|
6
|
+
|
6
7
|
class TaskResponse(TaskBase):
|
7
8
|
def _poll_request(self) -> dict:
|
8
9
|
while True:
|
9
10
|
try:
|
10
11
|
if not self.task_url:
|
11
|
-
|
12
|
+
raise ValueError("Task URL not found in response")
|
12
13
|
if not self._client._session:
|
13
14
|
raise ValueError("Client session not found")
|
14
|
-
r = self._client._session.get(
|
15
|
+
r = self._client._session.get(
|
16
|
+
self.task_url, headers=self._client._headers()
|
17
|
+
)
|
15
18
|
r.raise_for_status()
|
16
19
|
return r.json()
|
17
20
|
except (ConnectionError, TimeoutError) as _:
|
18
21
|
print("Connection error while polling the task, retrying...")
|
19
22
|
time.sleep(0.5)
|
20
|
-
except Exception
|
23
|
+
except Exception:
|
21
24
|
raise
|
22
25
|
|
23
|
-
def poll(self) ->
|
26
|
+
def poll(self) -> "TaskResponse":
|
24
27
|
while True:
|
25
28
|
response = self._poll_request()
|
26
29
|
updated_task = TaskResponse(**response).with_client(self._client)
|
@@ -28,31 +31,28 @@ class TaskResponse(TaskBase):
|
|
28
31
|
if result := self._check_status():
|
29
32
|
return result
|
30
33
|
time.sleep(0.5)
|
31
|
-
|
32
|
-
def update(self, config: Configuration) ->
|
34
|
+
|
35
|
+
def update(self, config: Configuration) -> "TaskResponse":
|
33
36
|
if not self.task_url:
|
34
37
|
raise ValueError("Task URL not found")
|
35
38
|
if not self._client._session:
|
36
39
|
raise ValueError("Client session not found")
|
37
40
|
files = prepare_upload_data(None, config)
|
38
41
|
r = self._client._session.patch(
|
39
|
-
self.task_url,
|
40
|
-
files=files,
|
41
|
-
headers=self._client._headers()
|
42
|
+
self.task_url, files=files, headers=self._client._headers()
|
42
43
|
)
|
43
44
|
r.raise_for_status()
|
44
45
|
updated = TaskResponse(**r.json()).with_client(self._client)
|
45
46
|
self.__dict__.update(updated.__dict__)
|
46
47
|
return self.poll()
|
47
|
-
|
48
|
+
|
48
49
|
def cancel(self):
|
49
50
|
if not self.task_url:
|
50
51
|
raise ValueError("Task URL not found")
|
51
52
|
if not self._client._session:
|
52
53
|
raise ValueError("Client session not found")
|
53
54
|
r = self._client._session.get(
|
54
|
-
f"{self.task_url}/cancel",
|
55
|
-
headers=self._client._headers()
|
55
|
+
f"{self.task_url}/cancel", headers=self._client._headers()
|
56
56
|
)
|
57
57
|
r.raise_for_status()
|
58
58
|
self.poll()
|
@@ -62,8 +62,5 @@ class TaskResponse(TaskBase):
|
|
62
62
|
raise ValueError("Task URL not found")
|
63
63
|
if not self._client._session:
|
64
64
|
raise ValueError("Client session not found")
|
65
|
-
r = self._client._session.delete(
|
66
|
-
self.task_url,
|
67
|
-
headers=self._client._headers()
|
68
|
-
)
|
65
|
+
r = self._client._session.delete(self.task_url, headers=self._client._headers())
|
69
66
|
r.raise_for_status()
|