chunkr-ai 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- chunkr_ai/__init__.py +1 -1
- chunkr_ai/api/auth.py +4 -4
- chunkr_ai/api/base.py +58 -48
- chunkr_ai/api/chunkr.py +21 -20
- chunkr_ai/api/chunkr_async.py +26 -20
- chunkr_ai/api/chunkr_base.py +34 -27
- chunkr_ai/api/config.py +41 -14
- chunkr_ai/api/misc.py +52 -44
- chunkr_ai/api/protocol.py +5 -3
- chunkr_ai/api/schema.py +66 -58
- chunkr_ai/api/task.py +13 -16
- chunkr_ai/api/task_async.py +16 -7
- chunkr_ai/api/task_base.py +4 -1
- chunkr_ai/models.py +23 -22
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/METADATA +1 -1
- chunkr_ai-0.0.16.dist-info/RECORD +21 -0
- chunkr_ai-0.0.14.dist-info/RECORD +0 -21
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/LICENSE +0 -0
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/WHEEL +0 -0
- {chunkr_ai-0.0.14.dist-info → chunkr_ai-0.0.16.dist-info}/top_level.txt +0 -0
chunkr_ai/api/config.py
CHANGED
@@ -3,28 +3,31 @@ from enum import Enum
|
|
3
3
|
from typing import Optional, List, Dict, Union, Type
|
4
4
|
from .schema import from_pydantic
|
5
5
|
|
6
|
+
|
6
7
|
class GenerationStrategy(str, Enum):
|
7
8
|
LLM = "LLM"
|
8
9
|
AUTO = "Auto"
|
9
10
|
|
11
|
+
|
10
12
|
class CroppingStrategy(str, Enum):
|
11
|
-
ALL = "All"
|
13
|
+
ALL = "All"
|
12
14
|
AUTO = "Auto"
|
13
15
|
|
16
|
+
|
14
17
|
class GenerationConfig(BaseModel):
|
15
18
|
html: Optional[GenerationStrategy] = None
|
16
19
|
llm: Optional[str] = None
|
17
20
|
markdown: Optional[GenerationStrategy] = None
|
18
21
|
crop_image: Optional[CroppingStrategy] = None
|
19
22
|
|
23
|
+
|
20
24
|
class SegmentProcessing(BaseModel):
|
21
|
-
model_config = ConfigDict(
|
22
|
-
|
23
|
-
alias_generator=str.title
|
24
|
-
)
|
25
|
-
|
25
|
+
model_config = ConfigDict(populate_by_name=True, alias_generator=str.title)
|
26
|
+
|
26
27
|
title: Optional[GenerationConfig] = Field(default=None, alias="Title")
|
27
|
-
section_header: Optional[GenerationConfig] = Field(
|
28
|
+
section_header: Optional[GenerationConfig] = Field(
|
29
|
+
default=None, alias="SectionHeader"
|
30
|
+
)
|
28
31
|
text: Optional[GenerationConfig] = Field(default=None, alias="Text")
|
29
32
|
list_item: Optional[GenerationConfig] = Field(default=None, alias="ListItem")
|
30
33
|
table: Optional[GenerationConfig] = Field(default=None, alias="Table")
|
@@ -36,38 +39,46 @@ class SegmentProcessing(BaseModel):
|
|
36
39
|
page_footer: Optional[GenerationConfig] = Field(default=None, alias="PageFooter")
|
37
40
|
page: Optional[GenerationConfig] = Field(default=None, alias="Page")
|
38
41
|
|
42
|
+
|
39
43
|
class ChunkProcessing(BaseModel):
|
40
44
|
target_length: Optional[int] = None
|
41
45
|
|
46
|
+
|
42
47
|
class Property(BaseModel):
|
43
48
|
name: str
|
44
49
|
prop_type: str
|
45
50
|
description: Optional[str] = None
|
46
51
|
default: Optional[str] = None
|
47
52
|
|
53
|
+
|
48
54
|
class JsonSchema(BaseModel):
|
49
55
|
title: str
|
50
56
|
properties: List[Property]
|
51
57
|
|
58
|
+
|
52
59
|
class OcrStrategy(str, Enum):
|
53
60
|
ALL = "All"
|
54
61
|
AUTO = "Auto"
|
55
|
-
|
62
|
+
|
63
|
+
|
56
64
|
class SegmentationStrategy(str, Enum):
|
57
65
|
LAYOUT_ANALYSIS = "LayoutAnalysis"
|
58
66
|
PAGE = "Page"
|
59
67
|
|
68
|
+
|
60
69
|
class BoundingBox(BaseModel):
|
61
70
|
left: float
|
62
71
|
top: float
|
63
72
|
width: float
|
64
73
|
height: float
|
65
74
|
|
75
|
+
|
66
76
|
class OCRResult(BaseModel):
|
67
77
|
bbox: BoundingBox
|
68
78
|
text: str
|
69
79
|
confidence: Optional[float]
|
70
80
|
|
81
|
+
|
71
82
|
class SegmentType(str, Enum):
|
72
83
|
CAPTION = "Caption"
|
73
84
|
FOOTNOTE = "Footnote"
|
@@ -82,6 +93,7 @@ class SegmentType(str, Enum):
|
|
82
93
|
TEXT = "Text"
|
83
94
|
TITLE = "Title"
|
84
95
|
|
96
|
+
|
85
97
|
class Segment(BaseModel):
|
86
98
|
bbox: BoundingBox
|
87
99
|
content: str
|
@@ -95,33 +107,43 @@ class Segment(BaseModel):
|
|
95
107
|
segment_id: str
|
96
108
|
segment_type: SegmentType
|
97
109
|
|
110
|
+
|
98
111
|
class Chunk(BaseModel):
|
99
112
|
chunk_id: str
|
100
113
|
chunk_length: int
|
101
114
|
segments: List[Segment]
|
102
115
|
|
116
|
+
|
103
117
|
class ExtractedJson(BaseModel):
|
104
118
|
data: Dict
|
105
119
|
|
120
|
+
|
106
121
|
class OutputResponse(BaseModel):
|
107
122
|
chunks: List[Chunk]
|
108
123
|
extracted_json: Optional[ExtractedJson] = Field(default=None)
|
109
124
|
|
125
|
+
|
110
126
|
class Model(str, Enum):
|
111
127
|
FAST = "Fast"
|
112
128
|
HIGH_QUALITY = "HighQuality"
|
113
129
|
|
130
|
+
class PipelineType(str, Enum):
|
131
|
+
AZURE = "Azure"
|
132
|
+
|
114
133
|
class Configuration(BaseModel):
|
115
134
|
chunk_processing: Optional[ChunkProcessing] = Field(default=None)
|
116
135
|
expires_in: Optional[int] = Field(default=None)
|
117
136
|
high_resolution: Optional[bool] = Field(default=None)
|
118
|
-
json_schema: Optional[Union[JsonSchema, Type[BaseModel], BaseModel]] = Field(
|
137
|
+
json_schema: Optional[Union[JsonSchema, Type[BaseModel], BaseModel]] = Field(
|
138
|
+
default=None
|
139
|
+
)
|
119
140
|
model: Optional[Model] = Field(default=None)
|
120
141
|
ocr_strategy: Optional[OcrStrategy] = Field(default=None)
|
121
142
|
segment_processing: Optional[SegmentProcessing] = Field(default=None)
|
122
143
|
segmentation_strategy: Optional[SegmentationStrategy] = Field(default=None)
|
144
|
+
pipeline: Optional[PipelineType] = Field(default=None)
|
123
145
|
|
124
|
-
@model_validator(mode=
|
146
|
+
@model_validator(mode="before")
|
125
147
|
def map_deprecated_fields(cls, values: Dict) -> Dict:
|
126
148
|
if isinstance(values, dict) and "target_chunk_length" in values:
|
127
149
|
target_length = values.pop("target_chunk_length")
|
@@ -130,13 +152,18 @@ class Configuration(BaseModel):
|
|
130
152
|
values["chunk_processing"]["target_length"] = target_length
|
131
153
|
return values
|
132
154
|
|
133
|
-
@model_validator(mode=
|
134
|
-
def convert_json_schema(self) ->
|
135
|
-
if self.json_schema is not None and not isinstance(
|
136
|
-
|
155
|
+
@model_validator(mode="after")
|
156
|
+
def convert_json_schema(self) -> "Configuration":
|
157
|
+
if self.json_schema is not None and not isinstance(
|
158
|
+
self.json_schema, JsonSchema
|
159
|
+
):
|
160
|
+
if isinstance(self.json_schema, (BaseModel, type)) and issubclass(
|
161
|
+
getattr(self.json_schema, "__class__", type), BaseModel
|
162
|
+
):
|
137
163
|
self.json_schema = JsonSchema(**from_pydantic(self.json_schema))
|
138
164
|
return self
|
139
165
|
|
166
|
+
|
140
167
|
class Status(str, Enum):
|
141
168
|
STARTING = "Starting"
|
142
169
|
PROCESSING = "Processing"
|
chunkr_ai/api/misc.py
CHANGED
@@ -3,71 +3,78 @@ import io
|
|
3
3
|
import json
|
4
4
|
from pathlib import Path
|
5
5
|
from PIL import Image
|
6
|
+
from pydantic import BaseModel
|
6
7
|
import requests
|
7
8
|
from typing import Union, Tuple, BinaryIO, Optional
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
) -> Tuple[str, BinaryIO]:
|
10
|
+
|
11
|
+
def prepare_file(file: Union[str, Path, BinaryIO, Image.Image]) -> Tuple[str, BinaryIO]:
|
12
12
|
"""Convert various file types into a tuple of (filename, file-like object)."""
|
13
13
|
# Handle URLs
|
14
|
-
if isinstance(file, str) and (
|
14
|
+
if isinstance(file, str) and (
|
15
|
+
file.startswith("http://") or file.startswith("https://")
|
16
|
+
):
|
15
17
|
response = requests.get(file)
|
16
18
|
response.raise_for_status()
|
17
|
-
|
19
|
+
|
18
20
|
# Try to get filename from Content-Disposition header first
|
19
21
|
filename = None
|
20
|
-
content_disposition = response.headers.get(
|
21
|
-
if content_disposition and
|
22
|
-
filename = content_disposition.split(
|
23
|
-
|
22
|
+
content_disposition = response.headers.get("Content-Disposition")
|
23
|
+
if content_disposition and "filename=" in content_disposition:
|
24
|
+
filename = content_disposition.split("filename=")[-1].strip("\"'")
|
25
|
+
|
24
26
|
# If no Content-Disposition, try to get clean filename from URL path
|
25
27
|
if not filename:
|
26
28
|
from urllib.parse import urlparse, unquote
|
29
|
+
|
27
30
|
parsed_url = urlparse(file)
|
28
31
|
path = unquote(parsed_url.path)
|
29
32
|
filename = Path(path).name if path else None
|
30
|
-
|
33
|
+
|
31
34
|
# Fallback to default name if we couldn't extract one
|
32
|
-
filename = filename or
|
33
|
-
|
35
|
+
filename = filename or "downloaded_file"
|
36
|
+
|
34
37
|
# Sanitize filename: remove invalid characters and limit length
|
35
38
|
import re
|
36
|
-
|
37
|
-
filename = re.sub(
|
38
|
-
|
39
|
-
|
40
|
-
|
39
|
+
|
40
|
+
filename = re.sub(
|
41
|
+
r'[<>:"/\\|?*%]', "_", filename
|
42
|
+
) # Replace invalid chars with underscore
|
43
|
+
filename = re.sub(r"\s+", "_", filename) # Replace whitespace with underscore
|
44
|
+
filename = filename.strip("._") # Remove leading/trailing dots and underscores
|
45
|
+
filename = filename[:255] # Limit length to 255 characters
|
46
|
+
|
41
47
|
file_obj = io.BytesIO(response.content)
|
42
48
|
return filename, file_obj
|
43
49
|
|
44
50
|
# Handle base64 strings
|
45
|
-
if isinstance(file, str) and
|
51
|
+
if isinstance(file, str) and "," in file and ";base64," in file:
|
46
52
|
try:
|
47
53
|
# Split header and data
|
48
|
-
header, base64_data = file.split(
|
54
|
+
header, base64_data = file.split(",", 1)
|
49
55
|
import base64
|
56
|
+
|
50
57
|
file_bytes = base64.b64decode(base64_data)
|
51
58
|
file_obj = io.BytesIO(file_bytes)
|
52
|
-
|
59
|
+
|
53
60
|
# Try to determine format from header
|
54
|
-
format =
|
55
|
-
mime_type = header.split(
|
56
|
-
|
61
|
+
format = "bin"
|
62
|
+
mime_type = header.split(":")[-1].split(";")[0].lower()
|
63
|
+
|
57
64
|
# Map MIME types to file extensions
|
58
65
|
mime_to_ext = {
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
66
|
+
"application/pdf": "pdf",
|
67
|
+
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
|
68
|
+
"application/msword": "doc",
|
69
|
+
"application/vnd.openxmlformats-officedocument.presentationml.presentation": "pptx",
|
70
|
+
"application/vnd.ms-powerpoint": "ppt",
|
71
|
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
|
72
|
+
"application/vnd.ms-excel": "xls",
|
73
|
+
"image/jpeg": "jpg",
|
74
|
+
"image/png": "png",
|
75
|
+
"image/jpg": "jpg",
|
69
76
|
}
|
70
|
-
|
77
|
+
|
71
78
|
if mime_type in mime_to_ext:
|
72
79
|
format = mime_to_ext[mime_type]
|
73
80
|
else:
|
@@ -82,36 +89,37 @@ def prepare_file(
|
|
82
89
|
path = Path(file).resolve()
|
83
90
|
if not path.exists():
|
84
91
|
raise FileNotFoundError(f"File not found: {file}")
|
85
|
-
return path.name, open(path,
|
92
|
+
return path.name, open(path, "rb")
|
86
93
|
|
87
94
|
# Handle PIL Images
|
88
95
|
if isinstance(file, Image.Image):
|
89
96
|
img_byte_arr = io.BytesIO()
|
90
|
-
format = file.format or
|
97
|
+
format = file.format or "PNG"
|
91
98
|
file.save(img_byte_arr, format=format)
|
92
99
|
img_byte_arr.seek(0)
|
93
100
|
return f"image.{format.lower()}", img_byte_arr
|
94
101
|
|
95
102
|
# Handle file-like objects
|
96
|
-
if hasattr(file,
|
103
|
+
if hasattr(file, "read") and hasattr(file, "seek"):
|
97
104
|
# Try to get the filename from the file object if possible
|
98
|
-
name =
|
105
|
+
name = (
|
106
|
+
getattr(file, "name", "document") if hasattr(file, "name") else "document"
|
107
|
+
)
|
99
108
|
return Path(name).name, file
|
100
109
|
|
101
110
|
raise TypeError(f"Unsupported file type: {type(file)}")
|
102
111
|
|
103
112
|
|
104
|
-
|
105
113
|
def prepare_upload_data(
|
106
114
|
file: Optional[Union[str, Path, BinaryIO, Image.Image]] = None,
|
107
|
-
config: Optional[Configuration] = None
|
115
|
+
config: Optional[Configuration] = None,
|
108
116
|
) -> dict:
|
109
117
|
"""Prepare files and data dictionaries for upload.
|
110
|
-
|
118
|
+
|
111
119
|
Args:
|
112
120
|
file: The file to upload
|
113
121
|
config: Optional configuration settings
|
114
|
-
|
122
|
+
|
115
123
|
Returns:
|
116
124
|
dict: (files dict) ready for upload
|
117
125
|
"""
|
@@ -123,6 +131,6 @@ def prepare_upload_data(
|
|
123
131
|
if config:
|
124
132
|
config_dict = config.model_dump(mode="json", exclude_none=True)
|
125
133
|
for key, value in config_dict.items():
|
126
|
-
files[key] = (None, json.dumps(value),
|
127
|
-
|
134
|
+
files[key] = (None, json.dumps(value), "application/json")
|
135
|
+
|
128
136
|
return files
|
chunkr_ai/api/protocol.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
from typing import Optional, runtime_checkable, Protocol
|
2
2
|
from requests import Session
|
3
|
-
from
|
3
|
+
from httpx import AsyncClient
|
4
|
+
|
4
5
|
|
5
6
|
@runtime_checkable
|
6
7
|
class ChunkrClientProtocol(Protocol):
|
7
8
|
"""Protocol defining the interface for Chunkr clients"""
|
9
|
+
|
8
10
|
url: str
|
9
11
|
_api_key: str
|
10
12
|
_session: Optional[Session] = None
|
11
|
-
_client: Optional[
|
13
|
+
_client: Optional[AsyncClient] = None
|
12
14
|
|
13
15
|
def get_api_key(self) -> str:
|
14
16
|
"""Get the API key"""
|
@@ -16,4 +18,4 @@ class ChunkrClientProtocol(Protocol):
|
|
16
18
|
|
17
19
|
def _headers(self) -> dict:
|
18
20
|
"""Return headers required for API requests"""
|
19
|
-
...
|
21
|
+
...
|
chunkr_ai/api/schema.py
CHANGED
@@ -2,17 +2,22 @@ from pydantic import BaseModel
|
|
2
2
|
from typing import Optional, List, Union, Type
|
3
3
|
import json
|
4
4
|
|
5
|
+
|
5
6
|
class Property(BaseModel):
|
6
7
|
name: str
|
7
8
|
prop_type: str
|
8
9
|
description: Optional[str] = None
|
9
10
|
default: Optional[str] = None
|
10
11
|
|
12
|
+
|
11
13
|
class JsonSchema(BaseModel):
|
12
14
|
title: str
|
13
15
|
properties: List[Property]
|
14
16
|
|
15
|
-
|
17
|
+
|
18
|
+
def from_pydantic(
|
19
|
+
pydantic: Union[BaseModel, Type[BaseModel]], current_depth: int = 0
|
20
|
+
) -> dict:
|
16
21
|
"""Convert a Pydantic model to a Chunk json schema."""
|
17
22
|
MAX_DEPTH = 5
|
18
23
|
model = pydantic if isinstance(pydantic, type) else pydantic.__class__
|
@@ -21,108 +26,111 @@ def from_pydantic(pydantic: Union[BaseModel, Type[BaseModel]], current_depth: in
|
|
21
26
|
|
22
27
|
def get_enum_description(details: dict) -> str:
|
23
28
|
"""Get description including enum values if they exist"""
|
24
|
-
description = details.get(
|
25
|
-
|
29
|
+
description = details.get("description", "")
|
30
|
+
|
26
31
|
# First check if this is a direct enum
|
27
|
-
if
|
28
|
-
enum_values = details[
|
29
|
-
enum_str =
|
32
|
+
if "enum" in details:
|
33
|
+
enum_values = details["enum"]
|
34
|
+
enum_str = "\nAllowed values:\n" + "\n".join(
|
35
|
+
f"- {val}" for val in enum_values
|
36
|
+
)
|
30
37
|
return f"{description}{enum_str}"
|
31
|
-
|
38
|
+
|
32
39
|
# Then check if it's a reference to an enum
|
33
|
-
if
|
34
|
-
ref_schema = resolve_ref(details[
|
35
|
-
if
|
36
|
-
enum_values = ref_schema[
|
37
|
-
enum_str =
|
40
|
+
if "$ref" in details:
|
41
|
+
ref_schema = resolve_ref(details["$ref"], schema.get("$defs", {}))
|
42
|
+
if "enum" in ref_schema:
|
43
|
+
enum_values = ref_schema["enum"]
|
44
|
+
enum_str = "\nAllowed values:\n" + "\n".join(
|
45
|
+
f"- {val}" for val in enum_values
|
46
|
+
)
|
38
47
|
return f"{description}{enum_str}"
|
39
|
-
|
48
|
+
|
40
49
|
return description
|
41
50
|
|
42
51
|
def resolve_ref(ref: str, definitions: dict) -> dict:
|
43
52
|
"""Resolve a $ref reference to its actual schema"""
|
44
|
-
if not ref.startswith(
|
53
|
+
if not ref.startswith("#/$defs/"):
|
45
54
|
return {}
|
46
|
-
ref_name = ref[len(
|
55
|
+
ref_name = ref[len("#/$defs/") :]
|
47
56
|
return definitions.get(ref_name, {})
|
48
57
|
|
49
58
|
def get_nested_schema(field_schema: dict, depth: int) -> dict:
|
50
59
|
if depth >= MAX_DEPTH:
|
51
60
|
return {}
|
52
|
-
|
61
|
+
|
53
62
|
# If there's a $ref, resolve it first
|
54
|
-
if
|
55
|
-
field_schema = resolve_ref(field_schema[
|
56
|
-
|
63
|
+
if "$ref" in field_schema:
|
64
|
+
field_schema = resolve_ref(field_schema["$ref"], schema.get("$defs", {}))
|
65
|
+
|
57
66
|
nested_props = {}
|
58
|
-
if field_schema.get(
|
59
|
-
for name, details in field_schema.get(
|
60
|
-
if details.get(
|
67
|
+
if field_schema.get("type") == "object":
|
68
|
+
for name, details in field_schema.get("properties", {}).items():
|
69
|
+
if details.get("type") == "object" or "$ref" in details:
|
61
70
|
ref_schema = details
|
62
|
-
if
|
63
|
-
ref_schema = resolve_ref(
|
71
|
+
if "$ref" in details:
|
72
|
+
ref_schema = resolve_ref(
|
73
|
+
details["$ref"], schema.get("$defs", {})
|
74
|
+
)
|
64
75
|
nested_schema = get_nested_schema(ref_schema, depth + 1)
|
65
76
|
nested_props[name] = {
|
66
|
-
|
67
|
-
|
68
|
-
|
77
|
+
"type": "object",
|
78
|
+
"description": get_enum_description(details),
|
79
|
+
"properties": nested_schema,
|
69
80
|
}
|
70
81
|
else:
|
71
82
|
nested_props[name] = {
|
72
|
-
|
73
|
-
|
83
|
+
"type": details.get("type", "string"),
|
84
|
+
"description": get_enum_description(details),
|
74
85
|
}
|
75
86
|
return nested_props
|
76
87
|
|
77
|
-
for name, details in schema.get(
|
88
|
+
for name, details in schema.get("properties", {}).items():
|
78
89
|
# Handle arrays
|
79
|
-
if details.get(
|
80
|
-
items = details.get(
|
81
|
-
if
|
82
|
-
items = resolve_ref(items[
|
83
|
-
|
90
|
+
if details.get("type") == "array":
|
91
|
+
items = details.get("items", {})
|
92
|
+
if "$ref" in items:
|
93
|
+
items = resolve_ref(items["$ref"], schema.get("$defs", {}))
|
94
|
+
|
84
95
|
# Get nested schema for array items
|
85
96
|
item_schema = get_nested_schema(items, current_depth)
|
86
97
|
description = get_enum_description(details)
|
87
|
-
|
98
|
+
|
88
99
|
if item_schema:
|
89
100
|
description = f"{description}\nList items schema:\n{json.dumps(item_schema, indent=2)}"
|
90
|
-
|
91
|
-
prop = Property(
|
92
|
-
name=name,
|
93
|
-
prop_type='list',
|
94
|
-
description=description
|
95
|
-
)
|
101
|
+
|
102
|
+
prop = Property(name=name, prop_type="list", description=description)
|
96
103
|
# Handle objects and references
|
97
|
-
elif details.get(
|
98
|
-
prop_type =
|
104
|
+
elif details.get("type") == "object" or "$ref" in details:
|
105
|
+
prop_type = "object"
|
99
106
|
ref_schema = details
|
100
|
-
if
|
101
|
-
ref_schema = resolve_ref(details[
|
102
|
-
|
107
|
+
if "$ref" in details:
|
108
|
+
ref_schema = resolve_ref(details["$ref"], schema.get("$defs", {}))
|
109
|
+
|
103
110
|
nested_schema = get_nested_schema(ref_schema, current_depth)
|
104
|
-
|
111
|
+
|
105
112
|
prop = Property(
|
106
113
|
name=name,
|
107
114
|
prop_type=prop_type,
|
108
115
|
description=get_enum_description(details),
|
109
|
-
properties=nested_schema
|
116
|
+
properties=nested_schema,
|
110
117
|
)
|
111
|
-
|
118
|
+
|
112
119
|
# Handle primitive types
|
113
120
|
else:
|
114
121
|
prop = Property(
|
115
122
|
name=name,
|
116
|
-
prop_type=details.get(
|
123
|
+
prop_type=details.get("type", "string"),
|
117
124
|
description=get_enum_description(details),
|
118
|
-
default=str(details.get(
|
125
|
+
default=str(details.get("default"))
|
126
|
+
if details.get("default") is not None
|
127
|
+
else None,
|
119
128
|
)
|
120
|
-
|
129
|
+
|
121
130
|
properties.append(prop)
|
122
|
-
|
131
|
+
|
123
132
|
json_schema = JsonSchema(
|
124
|
-
title=schema.get(
|
125
|
-
properties=properties
|
133
|
+
title=schema.get("title", model.__name__), properties=properties
|
126
134
|
)
|
127
|
-
|
128
|
-
return json_schema.model_dump(mode="json", exclude_none=True)
|
135
|
+
|
136
|
+
return json_schema.model_dump(mode="json", exclude_none=True)
|
chunkr_ai/api/task.py
CHANGED
@@ -3,24 +3,27 @@ from .misc import prepare_upload_data
|
|
3
3
|
from .task_base import TaskBase
|
4
4
|
import time
|
5
5
|
|
6
|
+
|
6
7
|
class TaskResponse(TaskBase):
|
7
8
|
def _poll_request(self) -> dict:
|
8
9
|
while True:
|
9
10
|
try:
|
10
11
|
if not self.task_url:
|
11
|
-
|
12
|
+
raise ValueError("Task URL not found in response")
|
12
13
|
if not self._client._session:
|
13
14
|
raise ValueError("Client session not found")
|
14
|
-
r = self._client._session.get(
|
15
|
+
r = self._client._session.get(
|
16
|
+
self.task_url, headers=self._client._headers()
|
17
|
+
)
|
15
18
|
r.raise_for_status()
|
16
19
|
return r.json()
|
17
20
|
except (ConnectionError, TimeoutError) as _:
|
18
21
|
print("Connection error while polling the task, retrying...")
|
19
22
|
time.sleep(0.5)
|
20
|
-
except Exception
|
23
|
+
except Exception:
|
21
24
|
raise
|
22
25
|
|
23
|
-
def poll(self) ->
|
26
|
+
def poll(self) -> "TaskResponse":
|
24
27
|
while True:
|
25
28
|
response = self._poll_request()
|
26
29
|
updated_task = TaskResponse(**response).with_client(self._client)
|
@@ -28,31 +31,28 @@ class TaskResponse(TaskBase):
|
|
28
31
|
if result := self._check_status():
|
29
32
|
return result
|
30
33
|
time.sleep(0.5)
|
31
|
-
|
32
|
-
def update(self, config: Configuration) ->
|
34
|
+
|
35
|
+
def update(self, config: Configuration) -> "TaskResponse":
|
33
36
|
if not self.task_url:
|
34
37
|
raise ValueError("Task URL not found")
|
35
38
|
if not self._client._session:
|
36
39
|
raise ValueError("Client session not found")
|
37
40
|
files = prepare_upload_data(None, config)
|
38
41
|
r = self._client._session.patch(
|
39
|
-
self.task_url,
|
40
|
-
files=files,
|
41
|
-
headers=self._client._headers()
|
42
|
+
self.task_url, files=files, headers=self._client._headers()
|
42
43
|
)
|
43
44
|
r.raise_for_status()
|
44
45
|
updated = TaskResponse(**r.json()).with_client(self._client)
|
45
46
|
self.__dict__.update(updated.__dict__)
|
46
47
|
return self.poll()
|
47
|
-
|
48
|
+
|
48
49
|
def cancel(self):
|
49
50
|
if not self.task_url:
|
50
51
|
raise ValueError("Task URL not found")
|
51
52
|
if not self._client._session:
|
52
53
|
raise ValueError("Client session not found")
|
53
54
|
r = self._client._session.get(
|
54
|
-
f"{self.task_url}/cancel",
|
55
|
-
headers=self._client._headers()
|
55
|
+
f"{self.task_url}/cancel", headers=self._client._headers()
|
56
56
|
)
|
57
57
|
r.raise_for_status()
|
58
58
|
self.poll()
|
@@ -62,8 +62,5 @@ class TaskResponse(TaskBase):
|
|
62
62
|
raise ValueError("Task URL not found")
|
63
63
|
if not self._client._session:
|
64
64
|
raise ValueError("Client session not found")
|
65
|
-
r = self._client._session.delete(
|
66
|
-
self.task_url,
|
67
|
-
headers=self._client._headers()
|
68
|
-
)
|
65
|
+
r = self._client._session.delete(self.task_url, headers=self._client._headers())
|
69
66
|
r.raise_for_status()
|