chunkr-ai 0.0.45__tar.gz → 0.0.47__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {chunkr_ai-0.0.45/src/chunkr_ai.egg-info → chunkr_ai-0.0.47}/PKG-INFO +1 -1
  2. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/pyproject.toml +1 -1
  3. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/auth.py +1 -0
  4. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/chunkr.py +17 -11
  5. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/chunkr_base.py +12 -6
  6. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/configuration.py +98 -9
  7. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/decorators.py +7 -10
  8. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/misc.py +10 -6
  9. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/task_response.py +41 -18
  10. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/models.py +4 -0
  11. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47/src/chunkr_ai.egg-info}/PKG-INFO +1 -1
  12. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/tests/test_chunkr.py +163 -8
  13. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/LICENSE +0 -0
  14. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/README.md +0 -0
  15. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/setup.cfg +0 -0
  16. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/__init__.py +0 -0
  17. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/__init__.py +0 -0
  18. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai/api/protocol.py +0 -0
  19. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai.egg-info/SOURCES.txt +0 -0
  20. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai.egg-info/dependency_links.txt +0 -0
  21. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai.egg-info/requires.txt +0 -0
  22. {chunkr_ai-0.0.45 → chunkr_ai-0.0.47}/src/chunkr_ai.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: chunkr-ai
- Version: 0.0.45
+ Version: 0.0.47
  Summary: Python client for Chunkr: open source document intelligence
  Author-email: Ishaan Kapoor <ishaan@lumina.sh>
  License: MIT License
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "chunkr-ai"
- version = "0.0.45"
+ version = "0.0.47"
  authors = [{"name" = "Ishaan Kapoor", "email" = "ishaan@lumina.sh"}]
  description = "Python client for Chunkr: open source document intelligence"
  readme = "README.md"
@@ -1,5 +1,6 @@
  class HeadersMixin:
      """Mixin class for handling authorization headers"""
+     _api_key: str = ""

      def get_api_key(self) -> str:
          """Get the API key"""
@@ -1,12 +1,13 @@
  from pathlib import Path
  from PIL import Image
- from typing import Union, BinaryIO, Optional
+ from typing import Union, BinaryIO, Optional, cast, Awaitable

  from .configuration import Configuration
  from .decorators import anywhere, ensure_client, retry_on_429
  from .misc import prepare_upload_data
  from .task_response import TaskResponse
  from .chunkr_base import ChunkrBase
+ from .protocol import ChunkrClientProtocol

  class Chunkr(ChunkrBase):
      """Chunkr API client that works in both sync and async contexts"""
@@ -16,17 +17,17 @@ class Chunkr(ChunkrBase):
      async def upload(
          self,
          file: Union[str, Path, BinaryIO, Image.Image],
-         config: Configuration = None,
+         config: Optional[Configuration] = None,
          filename: Optional[str] = None,
      ) -> TaskResponse:
-         task = await self.create_task(file, config, filename)
-         return await task.poll()
+         task = await cast(Awaitable[TaskResponse], self.create_task(file, config, filename))
+         return await cast(Awaitable[TaskResponse], task.poll())

      @anywhere()
      @ensure_client()
      async def update(self, task_id: str, config: Configuration) -> TaskResponse:
-         task = await self.update_task(task_id, config)
-         return await task.poll()
+         task = await cast(Awaitable[TaskResponse], self.update_task(task_id, config))
+         return await cast(Awaitable[TaskResponse], task.poll())

      @anywhere()
      @ensure_client()
@@ -34,30 +35,32 @@ class Chunkr(ChunkrBase):
      async def create_task(
          self,
          file: Union[str, Path, BinaryIO, Image.Image],
-         config: Configuration = None,
+         config: Optional[Configuration] = None,
          filename: Optional[str] = None,
      ) -> TaskResponse:
          """Create a new task with the given file and configuration."""
          data = await prepare_upload_data(file, filename, config)
+         assert self._client is not None
          r = await self._client.post(
              f"{self.url}/api/v1/task/parse", json=data, headers=self._headers()
          )
          r.raise_for_status()
-         return TaskResponse(**r.json()).with_client(self, True, False)
+         return TaskResponse(**r.json()).with_client(cast(ChunkrClientProtocol, self), True, False)

      @anywhere()
      @ensure_client()
      @retry_on_429()
-     async def update_task(self, task_id: str, config: Configuration) -> TaskResponse:
+     async def update_task(self, task_id: str, config: Optional[Configuration] = None) -> TaskResponse:
          """Update an existing task with new configuration."""
          data = await prepare_upload_data(None, None, config)
+         assert self._client is not None
          r = await self._client.patch(
              f"{self.url}/api/v1/task/{task_id}/parse",
              json=data,
              headers=self._headers(),
          )
          r.raise_for_status()
-         return TaskResponse(**r.json()).with_client(self, True, False)
+         return TaskResponse(**r.json()).with_client(cast(ChunkrClientProtocol, self), True, False)

      @anywhere()
      @ensure_client()
@@ -66,17 +69,19 @@ class Chunkr(ChunkrBase):
              "base64_urls": str(base64_urls).lower(),
              "include_chunks": str(include_chunks).lower()
          }
+         assert self._client is not None
          r = await self._client.get(
              f"{self.url}/api/v1/task/{task_id}",
              params=params,
              headers=self._headers()
          )
          r.raise_for_status()
-         return TaskResponse(**r.json()).with_client(self, include_chunks, base64_urls)
+         return TaskResponse(**r.json()).with_client(cast(ChunkrClientProtocol, self), include_chunks, base64_urls)

      @anywhere()
      @ensure_client()
      async def delete_task(self, task_id: str) -> None:
+         assert self._client is not None
          r = await self._client.delete(
              f"{self.url}/api/v1/task/{task_id}", headers=self._headers()
          )
@@ -85,6 +90,7 @@ class Chunkr(ChunkrBase):
      @anywhere()
      @ensure_client()
      async def cancel_task(self, task_id: str) -> None:
+         assert self._client is not None
          r = await self._client.get(
              f"{self.url}/api/v1/task/{task_id}/cancel", headers=self._headers()
          )
@@ -18,17 +18,23 @@ class ChunkrBase(HeadersMixin):
          raise_on_failure: Whether to raise an exception if the task fails. Defaults to False.
      """

-     def __init__(self, url: str = None, api_key: str = None, raise_on_failure: bool = False):
+     url: str
+     _api_key: str
+     raise_on_failure: bool
+     _client: Optional[httpx.AsyncClient]
+
+     def __init__(self, url: Optional[str] = None, api_key: Optional[str] = None, raise_on_failure: bool = False):
          load_dotenv(override=True)
          self.url = url or os.getenv("CHUNKR_URL") or "https://api.chunkr.ai"
-         self._api_key = api_key or os.getenv("CHUNKR_API_KEY")
+         _api_key = api_key or os.getenv("CHUNKR_API_KEY")
          self.raise_on_failure = raise_on_failure

-         if not self._api_key:
+         if not _api_key:
              raise ValueError(
                  "API key must be provided either directly, in .env file, or as CHUNKR_API_KEY environment variable. You can get an api key at: https://www.chunkr.ai"
              )

+         self._api_key = _api_key
          self.url = self.url.rstrip("/")
          self._client = httpx.AsyncClient()

@@ -36,7 +42,7 @@ class ChunkrBase(HeadersMixin):
      def upload(
          self,
          file: Union[str, Path, BinaryIO, Image.Image],
-         config: Configuration = None,
+         config: Optional[Configuration] = None,
          filename: Optional[str] = None,
      ) -> TaskResponse:
          """Upload a file and wait for processing to complete.
@@ -90,7 +96,7 @@
      def create_task(
          self,
          file: Union[str, Path, BinaryIO, Image.Image],
-         config: Configuration = None,
+         config: Optional[Configuration] = None,
          filename: Optional[str] = None,
      ) -> TaskResponse:
          """Upload a file for processing and immediately return the task response. It will not wait for processing to complete. To wait for the full processing to complete, use `task.poll()`.
@@ -127,7 +133,7 @@

      @abstractmethod
      def update_task(
-         self, task_id: str, config: Configuration
+         self, task_id: str, config: Optional[Configuration] = None
      ) -> TaskResponse:
          """Update a task by its ID and immediately return the task response. It will not wait for processing to complete. To wait for the full processing to complete, use `task.poll()`.

@@ -1,7 +1,7 @@
  from pydantic import BaseModel, Field, ConfigDict
  from enum import Enum
  from typing import Any, List, Optional, Union
- from pydantic import field_validator
+ from pydantic import field_validator, field_serializer

  class GenerationStrategy(str, Enum):
      LLM = "LLM"
@@ -65,11 +65,7 @@ class TokenizerType(BaseModel):
              return f"string:{self.string_value}"
          return ""

-     model_config = ConfigDict(
-         json_encoders={
-             'TokenizerType': lambda v: v.model_dump()
-         }
-     )
+     model_config = ConfigDict()

      def model_dump(self, **kwargs):
          if self.enum_value is not None:
@@ -85,10 +81,13 @@ class ChunkProcessing(BaseModel):

      model_config = ConfigDict(
          arbitrary_types_allowed=True,
-         json_encoders={
-             TokenizerType: lambda v: v.model_dump()
-         }
      )
+
+     @field_serializer('tokenizer')
+     def serialize_tokenizer(self, tokenizer: Optional[TokenizerType], _info):
+         if tokenizer is None:
+             return None
+         return tokenizer.model_dump()

      @field_validator('tokenizer', mode='before')
      def validate_tokenizer(cls, v):
@@ -130,6 +129,95 @@ class ErrorHandlingStrategy(str, Enum):
      FAIL = "Fail"
      CONTINUE = "Continue"

+ class FallbackStrategy(BaseModel):
+     type: str
+     model_id: Optional[str] = None
+
+     @classmethod
+     def none(cls) -> "FallbackStrategy":
+         return cls(type="None")
+
+     @classmethod
+     def default(cls) -> "FallbackStrategy":
+         return cls(type="Default")
+
+     @classmethod
+     def model(cls, model_id: str) -> "FallbackStrategy":
+         return cls(type="Model", model_id=model_id)
+
+     def __str__(self) -> str:
+         if self.type == "Model":
+             return f"Model({self.model_id})"
+         return self.type
+
+     def model_dump(self, **kwargs):
+         if self.type == "Model":
+             return {"Model": self.model_id}
+         return self.type
+
+     @field_validator('type')
+     def validate_type(cls, v):
+         if v not in ["None", "Default", "Model"]:
+             raise ValueError(f"Invalid fallback strategy: {v}")
+         return v
+
+     model_config = ConfigDict()
+
+     @classmethod
+     def model_validate(cls, obj):
+         # Handle string values like "None" or "Default"
+         if isinstance(obj, str):
+             if obj in ["None", "Default"]:
+                 return cls(type=obj)
+             # Try to parse as Enum value if it's not a direct match
+             try:
+                 return cls(type=obj)
+             except ValueError:
+                 pass  # Let it fall through to normal validation
+
+         # Handle dictionary format like {"Model": "model-id"}
+         elif isinstance(obj, dict) and len(obj) == 1:
+             if "Model" in obj:
+                 return cls(type="Model", model_id=obj["Model"])
+
+         # Fall back to normal validation
+         return super().model_validate(obj)
+
+ class LlmProcessing(BaseModel):
+     model_id: Optional[str] = None
+     fallback_strategy: FallbackStrategy = Field(default_factory=FallbackStrategy.default)
+     max_completion_tokens: Optional[int] = None
+     temperature: float = 0.0
+
+     model_config = ConfigDict()
+
+     @field_serializer('fallback_strategy')
+     def serialize_fallback_strategy(self, fallback_strategy: FallbackStrategy, _info):
+         return fallback_strategy.model_dump()
+
+     @field_validator('fallback_strategy', mode='before')
+     def validate_fallback_strategy(cls, v):
+         if isinstance(v, str):
+             if v == "None":
+                 return FallbackStrategy.none()
+             elif v == "Default":
+                 return FallbackStrategy.default()
+             # Try to parse as a model ID if it's not None or Default
+             try:
+                 return FallbackStrategy.model(v)
+             except ValueError:
+                 pass  # Let it fall through to normal validation
+         # Handle dictionary format like {"Model": "model-id"}
+         elif isinstance(v, dict) and len(v) == 1:
+             if "Model" in v:
+                 return FallbackStrategy.model(v["Model"])
+             elif "None" in v or v.get("None") is None:
+                 return FallbackStrategy.none()
+             elif "Default" in v or v.get("Default") is None:
+                 return FallbackStrategy.default()
+
+         return v
+
  class BoundingBox(BaseModel):
      left: float
      top: float
@@ -199,6 +287,7 @@ class Configuration(BaseModel):
      segment_processing: Optional[SegmentProcessing] = None
      segmentation_strategy: Optional[SegmentationStrategy] = None
      pipeline: Optional[Pipeline] = None
+     llm_processing: Optional[LlmProcessing] = None

  class OutputConfiguration(Configuration):
      input_file_url: Optional[str] = None
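
For reference, a minimal usage sketch of the `llm_processing` option introduced in 0.0.47, assuming the exports added in models.py and exercised in the test suite below; the model IDs are illustrative:

```python
from chunkr_ai.models import Configuration, FallbackStrategy, LlmProcessing

# Prefer a specific model and fall back to another model if it fails.
config = Configuration(
    llm_processing=LlmProcessing(
        model_id="gemini-pro-2.5",  # illustrative model ID
        fallback_strategy=FallbackStrategy.model("claude-3.7-sonnet"),
        max_completion_tokens=1500,
        temperature=0.3,
    ),
)

# FallbackStrategy serializes to the API's tagged format (see test_fallback_strategy_serialization):
assert FallbackStrategy.none().model_dump() == "None"
assert FallbackStrategy.default().model_dump() == "Default"
assert FallbackStrategy.model("gpt-4.1").model_dump() == {"Model": "gpt-4.1"}
```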
@@ -13,10 +13,7 @@ P = ParamSpec('P')

  _sync_loop = None

- @overload
- def anywhere() -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Union[Awaitable[T], T]]]: ...
-
- def anywhere():
+ def anywhere() -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Union[Awaitable[T], T]]]:
      """Decorator that allows an async function to run anywhere - sync or async context."""
      def decorator(async_func: Callable[P, Awaitable[T]]) -> Callable[P, Union[Awaitable[T], T]]:
          @functools.wraps(async_func)
@@ -42,22 +39,22 @@ def anywhere():
          return wrapper
      return decorator

- def ensure_client() -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
+ def ensure_client() -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
      """Decorator that ensures a valid httpx.AsyncClient exists before executing the method"""
-     def decorator(async_func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
+     def decorator(async_func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
          @functools.wraps(async_func)
-         async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
+         async def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
              if not self._client or self._client.is_closed:
                  self._client = httpx.AsyncClient()
              return await async_func(self, *args, **kwargs)
          return wrapper
      return decorator

- def require_task() -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
+ def require_task() -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
      """Decorator that ensures task has required attributes and valid client before execution"""
-     def decorator(async_func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
+     def decorator(async_func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
          @functools.wraps(async_func)
-         async def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
+         async def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
              if not self.task_url:
                  raise ValueError("Task URL not found")
              if not self._client:
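
A rough sketch of what the `anywhere()` decorator provides for client methods such as `Chunkr.upload`: outside an event loop the call blocks and returns the result, while inside a running loop it returns an awaitable. The import path and file path below are illustrative, and the sketch assumes `CHUNKR_API_KEY` is set in the environment:

```python
from chunkr_ai.api.chunkr import Chunkr

# Synchronous context: no event loop is running, so the decorated coroutine
# is driven to completion and the TaskResponse is returned directly.
def sync_example():
    chunkr = Chunkr()  # reads CHUNKR_URL / CHUNKR_API_KEY from the environment
    task = chunkr.upload("tests/files/test.pdf")
    print(task.status)

# Asynchronous context: the same call returns an awaitable instead.
async def async_example():
    chunkr = Chunkr()
    task = await chunkr.upload("tests/files/test.pdf")
    print(task.status)
```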
@@ -30,14 +30,18 @@ async def prepare_file(file: Union[str, Path, BinaryIO, Image.Image]) -> Tuple[O
      if isinstance(file, str):
          if file.startswith(('http://', 'https://')):
              return None, file
-         try:
-             base64.b64decode(file)
-             return None, file
-         except:
+         # Try to handle as a file path first
+         path = Path(file)
+         if path.exists():
+             # It's a valid file path, convert to Path object and continue processing
+             file = path
+         else:
+             # If not a valid file path, try treating as base64
              try:
-                 file = Path(file)
+                 base64.b64decode(file)
+                 return None, file
              except:
-                 raise ValueError("File must be a valid path, URL, or base64 string")
+                 raise ValueError(f"File not found: {file} and it's not a valid base64 string")

      # Handle file paths - convert to base64
      if isinstance(file, Path):
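
The new string handling in `prepare_file` resolves inputs in the order URL, then existing file path, then base64 payload. A simplified sketch of that resolution order (the helper name is hypothetical; the real function also converts paths and readers into upload payloads):

```python
import base64
from pathlib import Path

def classify_string_input(file: str) -> str:
    """Hypothetical helper mirroring the 0.0.47 resolution order in prepare_file."""
    if file.startswith(("http://", "https://")):
        return "url"
    if Path(file).exists():
        return "path"  # as of 0.0.47, the filesystem is checked before base64
    try:
        base64.b64decode(file)
        return "base64"
    except Exception:
        raise ValueError(f"File not found: {file} and it's not a valid base64 string")
```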
@@ -1,5 +1,5 @@
  from datetime import datetime
- from typing import TypeVar, Optional, Generic
+ from typing import Optional, cast, Awaitable, Union
  from pydantic import BaseModel, PrivateAttr
  import asyncio
  import json
@@ -11,9 +11,7 @@ from .protocol import ChunkrClientProtocol
  from .misc import prepare_upload_data
  from .decorators import anywhere, require_task, retry_on_429

- T = TypeVar("T", bound="TaskResponse")
-
- class TaskResponse(BaseModel, Generic[T]):
+ class TaskResponse(BaseModel):
      configuration: OutputConfiguration
      created_at: datetime
      expires_at: Optional[datetime] = None
@@ -28,13 +26,13 @@ class TaskResponse(BaseModel, Generic[T]):
      _base64_urls: bool = False
      _client: Optional[ChunkrClientProtocol] = PrivateAttr(default=None)

-     def with_client(self, client: ChunkrClientProtocol, include_chunks: bool = False, base64_urls: bool = False) -> T:
+     def with_client(self, client: ChunkrClientProtocol, include_chunks: bool = False, base64_urls: bool = False) -> "TaskResponse":
          self._client = client
          self.include_chunks = include_chunks
          self._base64_urls = base64_urls
          return self

-     def _check_status(self) -> Optional[T]:
+     def _check_status(self) -> Optional["TaskResponse"]:
          """Helper method to check task status and handle completion/failure"""
          if self.status == "Failed":
              if getattr(self._client, 'raise_on_failure', True):
@@ -47,6 +45,11 @@ class TaskResponse(BaseModel, Generic[T]):
      @require_task()
      async def _poll_request(self) -> dict:
          try:
+             if not self._client:
+                 raise ValueError("Chunkr client protocol is not initialized")
+             if not self._client._client or self._client._client.is_closed:
+                 raise ValueError("httpx client is not open")
+             assert self.task_url is not None
              r = await self._client._client.get(
                  self.task_url, headers=self._client._headers()
              )
@@ -64,10 +67,12 @@ class TaskResponse(BaseModel, Generic[T]):
              raise e

      @anywhere()
-     async def poll(self) -> T:
+     async def poll(self) -> "TaskResponse":
          """Poll the task for completion."""
          while True:
              j = await self._poll_request()
+             if not self._client:
+                 raise ValueError("Chunkr client protocol is not initialized")
              updated = TaskResponse(**j).with_client(self._client)
              self.__dict__.update(updated.__dict__)
              if res := self._check_status():
@@ -77,9 +82,14 @@
      @anywhere()
      @require_task()
      @retry_on_429()
-     async def update(self, config: Configuration) -> T:
+     async def update(self, config: Configuration) -> "TaskResponse":
          """Update the task configuration."""
          data = await prepare_upload_data(None, None, config)
+         if not self._client:
+             raise ValueError("Chunkr client protocol is not initialized")
+         if not self._client._client or self._client._client.is_closed:
+             raise ValueError("httpx client is not open")
+         assert self.task_url is not None
          r = await self._client._client.patch(
              f"{self.task_url}/parse",
              json=data,
@@ -88,12 +98,17 @@
          r.raise_for_status()
          updated = TaskResponse(**r.json()).with_client(self._client)
          self.__dict__.update(updated.__dict__)
-         return await self.poll()
+         return cast(TaskResponse, self.poll())

      @anywhere()
      @require_task()
-     async def delete(self) -> T:
+     async def delete(self) -> "TaskResponse":
          """Delete the task."""
+         if not self._client:
+             raise ValueError("Chunkr client protocol is not initialized")
+         if not self._client._client or self._client._client.is_closed:
+             raise ValueError("httpx client is not open")
+         assert self.task_url is not None
          r = await self._client._client.delete(
              self.task_url, headers=self._client._headers()
          )
@@ -102,15 +117,20 @@

      @anywhere()
      @require_task()
-     async def cancel(self) -> T:
+     async def cancel(self) -> "TaskResponse":
          """Cancel the task."""
+         if not self._client:
+             raise ValueError("Chunkr client protocol is not initialized")
+         if not self._client._client or self._client._client.is_closed:
+             raise ValueError("httpx client is not open")
+         assert self.task_url is not None
          r = await self._client._client.get(
              f"{self.task_url}/cancel", headers=self._client._headers()
          )
          r.raise_for_status()
-         return await self.poll()
+         return cast(TaskResponse, self.poll())

-     def _write_to_file(self, content: str | dict, output_file: str, is_json: bool = False) -> None:
+     def _write_to_file(self, content: Union[str, dict], output_file: Optional[str], is_json: bool = False) -> None:
          """Helper method to write content to a file

          Args:
@@ -131,9 +151,12 @@
              if is_json:
                  json.dump(content, f, cls=DateTimeEncoder, indent=2)
              else:
-                 f.write(content)
+                 if isinstance(content, str):
+                     f.write(content)
+                 else:
+                     raise ValueError("Content is not a string")

-     def html(self, output_file: str = None) -> str:
+     def html(self, output_file: Optional[str] = None) -> str:
          """Get the full HTML of the task

          Args:
@@ -143,7 +166,7 @@
          self._write_to_file(content, output_file)
          return content

-     def markdown(self, output_file: str = None) -> str:
+     def markdown(self, output_file: Optional[str] = None) -> str:
          """Get the full markdown of the task

          Args:
@@ -153,7 +176,7 @@
          self._write_to_file(content, output_file)
          return content

-     def content(self, output_file: str = None) -> str:
+     def content(self, output_file: Optional[str] = None) -> str:
          """Get the full content of the task

          Args:
@@ -163,7 +186,7 @@
          self._write_to_file(content, output_file)
          return content

-     def json(self, output_file: str = None) -> dict:
+     def json(self, output_file: Optional[str] = None) -> dict:
          """Get the full task data as JSON

          Args:
@@ -6,8 +6,10 @@ from .api.configuration import (
      CroppingStrategy,
      EmbedSource,
      ErrorHandlingStrategy,
+     FallbackStrategy,
      GenerationStrategy,
      GenerationConfig,
+     LlmProcessing,
      Model,
      OCRResult,
      OcrStrategy,
@@ -31,8 +33,10 @@ __all__ = [
      "CroppingStrategy",
      "EmbedSource",
      "ErrorHandlingStrategy",
+     "FallbackStrategy",
      "GenerationConfig",
      "GenerationStrategy",
+     "LlmProcessing",
      "Model",
      "OCRResult",
      "OcrStrategy",
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: chunkr-ai
- Version: 0.0.45
+ Version: 0.0.47
  Summary: Python client for Chunkr: open source document intelligence
  Author-email: Ishaan Kapoor <ishaan@lumina.sh>
  License: MIT License
@@ -18,12 +18,22 @@ from chunkr_ai.models import (
      EmbedSource,
      ErrorHandlingStrategy,
      Tokenizer,
+     LlmProcessing,
+     FallbackStrategy,
  )

  @pytest.fixture
  def sample_path():
      return Path("tests/files/test.pdf")

+ @pytest.fixture
+ def sample_absolute_path_str():
+     return "tests/files/test.pdf"
+
+ @pytest.fixture
+ def sample_relative_path_str():
+     return "./tests/files/test.pdf"
+
  @pytest.fixture
  def sample_image():
      return Image.open("tests/files/test.jpg")
@@ -41,7 +51,7 @@ def client():
  def markdown_embed_config():
      return Configuration(
          segment_processing=SegmentProcessing(
-             page=GenerationConfig(
+             Page=GenerationConfig(
                  html=GenerationStrategy.LLM,
                  markdown=GenerationStrategy.LLM,
                  embed_sources=[EmbedSource.MARKDOWN]
@@ -53,7 +63,7 @@ def html_embed_config():
  def html_embed_config():
      return Configuration(
          segment_processing=SegmentProcessing(
-             page=GenerationConfig(
+             Page=GenerationConfig(
                  html=GenerationStrategy.LLM,
                  markdown=GenerationStrategy.LLM,
                  embed_sources=[EmbedSource.HTML]
@@ -65,7 +75,7 @@ def multiple_embed_config():
  def multiple_embed_config():
      return Configuration(
          segment_processing=SegmentProcessing(
-             page=GenerationConfig(
+             Page=GenerationConfig(
                  html=GenerationStrategy.LLM,
                  markdown=GenerationStrategy.LLM,
                  llm="Generate a summary of this content",
@@ -113,7 +123,7 @@ def xlm_roberta_with_html_content_config():
              tokenizer=Tokenizer.XLM_ROBERTA_BASE
          ),
          segment_processing=SegmentProcessing(
-             page=GenerationConfig(
+             Page=GenerationConfig(
                  html=GenerationStrategy.LLM,
                  markdown=GenerationStrategy.LLM,
                  embed_sources=[EmbedSource.HTML, EmbedSource.CONTENT]
@@ -121,6 +131,39 @@ def xlm_roberta_with_html_content_config():
          ),
      )

+ @pytest.fixture
+ def none_fallback_config():
+     return Configuration(
+         llm_processing=LlmProcessing(
+             model_id="gemini-pro-2.5",
+             fallback_strategy=FallbackStrategy.none(),
+             max_completion_tokens=500,
+             temperature=0.2
+         ),
+     )
+
+ @pytest.fixture
+ def default_fallback_config():
+     return Configuration(
+         llm_processing=LlmProcessing(
+             model_id="gemini-pro-2.5",
+             fallback_strategy=FallbackStrategy.default(),
+             max_completion_tokens=1000,
+             temperature=0.5
+         ),
+     )
+
+ @pytest.fixture
+ def model_fallback_config():
+     return Configuration(
+         llm_processing=LlmProcessing(
+             model_id="gemini-pro-2.5",
+             fallback_strategy=FallbackStrategy.model("claude-3.7-sonnet"),
+             max_completion_tokens=2000,
+             temperature=0.7
+         ),
+     )
+
  @pytest.mark.asyncio
  async def test_send_file_path(client, sample_path):
      response = await client.upload(sample_path)
@@ -128,6 +171,20 @@ async def test_send_file_path(client, sample_path):
      assert response.status == "Succeeded"
      assert response.output is not None

+ @pytest.mark.asyncio
+ async def test_send_file_path_str(client, sample_absolute_path_str):
+     response = await client.upload(sample_absolute_path_str)
+     assert response.task_id is not None
+     assert response.status == "Succeeded"
+     assert response.output is not None
+
+ @pytest.mark.asyncio
+ async def test_send_file_relative_path_str(client, sample_relative_path_str):
+     response = await client.upload(sample_relative_path_str)
+     assert response.task_id is not None
+     assert response.status == "Succeeded"
+     assert response.output is not None
+
  @pytest.mark.asyncio
  async def test_send_file_url(client, sample_url):
      response = await client.upload(sample_url)
@@ -136,7 +193,7 @@ async def test_send_file_url(client, sample_url):
      assert response.output is not None

  @pytest.mark.asyncio
- async def test_send_file_path_str(client, sample_path):
+ async def test_send_file_path_as_str(client, sample_path):
      response = await client.upload(str(sample_path))
      assert response.task_id is not None
      assert response.status == "Succeeded"
@@ -205,7 +262,7 @@ async def test_page_llm_html(client, sample_path):
          Configuration(
              segmentation_strategy=SegmentationStrategy.PAGE,
              segment_processing=SegmentProcessing(
-                 page=GenerationConfig(html=GenerationStrategy.LLM)
+                 Page=GenerationConfig(html=GenerationStrategy.LLM)
              ),
          ),
      )
@@ -218,7 +275,7 @@ async def test_page_llm(client, sample_path):
      configuration = Configuration(
          segmentation_strategy=SegmentationStrategy.PAGE,
          segment_processing=SegmentProcessing(
-             page=GenerationConfig(
+             Page=GenerationConfig(
                  html=GenerationStrategy.LLM, markdown=GenerationStrategy.LLM
              )
          ),
@@ -297,7 +354,7 @@ async def test_pipeline_type_azure(client, sample_path):
      assert response.output is not None

  @pytest.mark.asyncio
- async def test_pipeline_type_azure(client, sample_path):
+ async def test_pipeline_type_chunkr(client, sample_path):
      response = await client.upload(sample_path, Configuration(pipeline=Pipeline.CHUNKR))
      assert response.task_id is not None
      assert response.status == "Succeeded"
@@ -451,3 +508,101 @@ async def test_error_handling_continue(client, sample_path):
      assert response.task_id is not None
      assert response.status == "Succeeded"
      assert response.output is not None
+
+ @pytest.mark.asyncio
+ async def test_llm_processing_none_fallback(client, sample_path, none_fallback_config):
+     response = await client.upload(sample_path, none_fallback_config)
+     assert response.task_id is not None
+     assert response.status == "Succeeded"
+     assert response.output is not None
+     assert response.configuration.llm_processing is not None
+     assert response.configuration.llm_processing.model_id == "gemini-pro-2.5"
+     assert str(response.configuration.llm_processing.fallback_strategy) == "None"
+     assert response.configuration.llm_processing.max_completion_tokens == 500
+     assert response.configuration.llm_processing.temperature == 0.2
+
+ @pytest.mark.asyncio
+ async def test_llm_processing_default_fallback(client, sample_path, default_fallback_config):
+     response = await client.upload(sample_path, default_fallback_config)
+     assert response.task_id is not None
+     assert response.status == "Succeeded"
+     assert response.output is not None
+     assert response.configuration.llm_processing is not None
+     assert response.configuration.llm_processing.model_id == "gemini-pro-2.5"
+     # The service may resolve Default to an actual model
+     assert response.configuration.llm_processing.fallback_strategy is not None
+     assert response.configuration.llm_processing.max_completion_tokens == 1000
+     assert response.configuration.llm_processing.temperature == 0.5
+
+ @pytest.mark.asyncio
+ async def test_llm_processing_model_fallback(client, sample_path, model_fallback_config):
+     response = await client.upload(sample_path, model_fallback_config)
+     assert response.task_id is not None
+     assert response.status == "Succeeded"
+     assert response.output is not None
+     assert response.configuration.llm_processing is not None
+     assert response.configuration.llm_processing.model_id == "gemini-pro-2.5"
+     assert str(response.configuration.llm_processing.fallback_strategy) == "Model(claude-3.7-sonnet)"
+     assert response.configuration.llm_processing.max_completion_tokens == 2000
+     assert response.configuration.llm_processing.temperature == 0.7
+
+ @pytest.mark.asyncio
+ async def test_llm_custom_model(client, sample_path):
+     config = Configuration(
+         llm_processing=LlmProcessing(
+             model_id="claude-3.7-sonnet",  # Using a model from models.yaml
+             fallback_strategy=FallbackStrategy.none(),
+             max_completion_tokens=1500,
+             temperature=0.3
+         ),
+     )
+     response = await client.upload(sample_path, config)
+     assert response.task_id is not None
+     assert response.status == "Succeeded"
+     assert response.output is not None
+     assert response.configuration.llm_processing is not None
+     assert response.configuration.llm_processing.model_id == "claude-3.7-sonnet"
+
+ @pytest.mark.asyncio
+ async def test_fallback_strategy_serialization():
+     # Test that FallbackStrategy objects serialize correctly
+     none_strategy = FallbackStrategy.none()
+     default_strategy = FallbackStrategy.default()
+     model_strategy = FallbackStrategy.model("gpt-4.1")
+
+     assert none_strategy.model_dump() == "None"
+     assert default_strategy.model_dump() == "Default"
+     assert model_strategy.model_dump() == {"Model": "gpt-4.1"}
+
+     # Test string representation
+     assert str(none_strategy) == "None"
+     assert str(default_strategy) == "Default"
+     assert str(model_strategy) == "Model(gpt-4.1)"
+
+ @pytest.mark.asyncio
+ async def test_combined_config_with_llm_and_other_settings(client, sample_path):
+     # Test combining LLM settings with other configuration options
+     config = Configuration(
+         llm_processing=LlmProcessing(
+             model_id="qwen-2.5-vl-7b-instruct",
+             fallback_strategy=FallbackStrategy.model("gemini-flash-2.0"),
+             temperature=0.4
+         ),
+         segmentation_strategy=SegmentationStrategy.PAGE,
+         segment_processing=SegmentProcessing(
+             Page=GenerationConfig(
+                 html=GenerationStrategy.LLM,
+                 markdown=GenerationStrategy.LLM
+             )
+         ),
+         chunk_processing=ChunkProcessing(target_length=1024)
+     )
+
+     response = await client.upload(sample_path, config)
+     assert response.task_id is not None
+     assert response.status == "Succeeded"
+     assert response.output is not None
+     assert response.configuration.llm_processing is not None
+     assert response.configuration.llm_processing.model_id == "qwen-2.5-vl-7b-instruct"
+     assert response.configuration.segmentation_strategy == SegmentationStrategy.PAGE
+     assert response.configuration.chunk_processing.target_length == 1024