lm-deluge 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of lm-deluge might be problematic.

@@ -0,0 +1,28 @@
+ def image_generation_openai():
+     # TODO: handle result properly
+     return {"type": "image_generation"}
+
+
+ def code_interpreter_openai(container: dict | None = None):
+     if container is None:
+         container = {"type": "auto"}
+     return {"type": "code_interpreter", "container": container}
+
+
+ def local_shell_openai():
+     return {"type": "local_shell"}
+
+
+ def web_search_openai():
+     return {"type": "web_search_preview"}
+
+
+ def computer_use_openai(
+     display_width: int = 1024, display_height: int = 768, environment: str = "browser"
+ ):
+     return {
+         "type": "computer_use_preview",
+         "display_width": display_width,
+         "display_height": display_height,
+         "environment": environment,
+     }
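The helpers in this new module return plain provider-format tool dicts. Because the client's `tools` parameter now accepts `list[Tool | dict | MCPServer]` (see the lm_deluge/client.py changes below), those dicts can be passed through unchanged. A minimal sketch, assuming an import path for the new module (its file name is not visible in this diff) and an already-configured LLMClient:

from lm_deluge.client import LLMClient

# Hypothetical path: the new file's name is not shown in this diff.
from lm_deluge.built_in_tools.openai import code_interpreter_openai, web_search_openai


async def ask_with_builtin_tools(client: LLMClient, question: str) -> str | None:
    # Plain dicts are accepted alongside Tool and MCPServer objects.
    responses = await client.process_prompts_async(
        [question],
        tools=[web_search_openai(), code_interpreter_openai()],
        use_responses_api=True,  # these tool types are OpenAI Responses API features
    )
    return responses[0].completion if responses[0] else None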
lm_deluge/client.py CHANGED
@@ -1,4 +1,5 @@
  import asyncio
+ import random
  from typing import Any, Literal, Self, Sequence, overload
 
  import numpy as np
@@ -13,12 +14,12 @@ from lm_deluge.batches import (
      wait_for_batch_completion_async,
  )
  from lm_deluge.prompt import CachePattern, Conversation, prompts_to_conversations
- from lm_deluge.tool import Tool
+ from lm_deluge.tool import MCPServer, Tool
 
- from .api_requests import create_api_request
- from .api_requests.base import APIRequestBase, APIResponse, deduplicate_responses
+ from .api_requests.base import APIResponse
  from .config import SamplingParams
- from .models import registry
+ from .models import APIModel, registry
+ from .request_context import RequestContext
  from .tracker import StatusTracker
 
  # from .cache import LevelDBCache, SqliteCache
@@ -135,9 +136,7 @@ class LLMClient(BaseModel):
              print(
                  "WARNING: using top_logprobs can result in very large outputs. consider limiting max_new_tokens."
              )
-             if not all(
-                 registry[model].get("supports_logprobs") for model in self.models
-             ):
+             if not all(registry[model].supports_logprobs for model in self.models):
                  raise ValueError(
                      "logprobs can only be enabled if all models support it."
                  )
@@ -174,6 +173,110 @@ class LLMClient(BaseModel):
          model_idx = np.random.choice(range(len(self.models)), p=self.model_weights)
          return self.models[model_idx], self.sampling_params[model_idx]
 
+     def _select_different_model(self, current_model: str):
+         """Select a model different from the provided one."""
+         other_models = [m for m in self.models if m != current_model]
+         if not other_models:
+             # No other models available, return current
+             return current_model, self.sampling_params[self.models.index(current_model)]
+
+         # Get weights for other models
+         other_indices = [self.models.index(m) for m in other_models]
+         weights = [self.model_weights[idx] for idx in other_indices]
+         weights = [w / sum(weights) for w in weights]  # type: ignore
+
+         model_idx = np.random.choice(range(len(other_models)), p=weights)
+         chosen_model = other_models[model_idx]
+         chosen_sp = self.sampling_params[self.models.index(chosen_model)]
+         return chosen_model, chosen_sp
+
+     async def _wait_for_capacity(self, num_tokens: int, tracker: StatusTracker):
+         while True:
+             if tracker.check_capacity(num_tokens):
+                 tracker.set_limiting_factor(None)
+                 return
+
+             if tracker.seconds_to_pause > 0:
+                 await asyncio.sleep(tracker.seconds_to_pause)
+             else:
+                 await asyncio.sleep(random.random())
+
+     async def _execute_request(self, context: RequestContext) -> APIResponse:
+         """Create and send a single API request using the provided context."""
+         model_obj = APIModel.from_registry(context.model_name)
+         request = model_obj.make_request(context)
+         response = await request.execute_once()
+         return response
+
+     async def process_single_request(
+         self, context: RequestContext, retry_queue: asyncio.Queue | None = None
+     ) -> APIResponse:
+         """Handle caching and single HTTP call for a request. Failed requests go to retry queue."""
+         # Check cache first
+         if self.cache:
+             cached = self.cache.get(context.prompt)
+             if cached:
+                 cached.local_cache_hit = True
+                 if context.status_tracker:
+                     context.status_tracker.task_succeeded(context.task_id)
+                 return cached
+
+         # Execute single request
+         assert context.status_tracker
+         context.status_tracker.update_pbar()
+         response = await self._execute_request(context)
+
+         # Handle successful response
+         if not response.is_error:
+             context.status_tracker.task_succeeded(context.task_id)
+             # Cache successful responses immediately
+             if self.cache and response.completion:
+                 self.cache.put(context.prompt, response)
+             # Call callback if provided
+             context.maybe_callback(response, context.status_tracker)
+             return response
+
+         # Handle error response - add to retry queue if available
+         if retry_queue and context.attempts_left > 1:
+             # Decide whether to retry with a different model
+             if response.retry_with_different_model and len(self.models) > 1:
+                 # Switch to different model for retry
+                 new_model, new_sp = self._select_different_model(context.model_name)
+                 retry_context = context.copy(
+                     model_name=new_model,
+                     sampling_params=new_sp,
+                     attempts_left=context.attempts_left - 1,
+                 )
+             else:
+                 # Retry with same model
+                 retry_context = context.copy(attempts_left=context.attempts_left - 1)
+
+             # Print error message for debugging
+             error_msg = (
+                 f"Error task {context.task_id}. Model: {response.model_internal}"
+             )
+             if response.status_code:
+                 error_msg += f" Code: {response.status_code},"
+             error_msg += f" Message: {response.error_message}. Retrying..."
+             print(error_msg)
+
+             # Add to retry queue for later processing
+             await retry_queue.put(retry_context)
+             return response  # Return the error response for now
+
+         # No retries left or no retry queue - final failure
+         context.status_tracker.task_failed(context.task_id)
+         context.maybe_callback(response, context.status_tracker)
+
+         # Print final error message
+         error_msg = f"Error task {context.task_id}. Model: {response.model_internal}"
+         if response.status_code:
+             error_msg += f" Code: {response.status_code},"
+         error_msg += f" Message: {response.error_message}. Giving up."
+         print(error_msg)
+
+         return response
+
      @overload
      async def process_prompts_async(
          self,
@@ -181,11 +284,8 @@ class LLMClient(BaseModel):
          *,
          return_completions_only: Literal[True],
          show_progress: bool = ...,
-         tools: list[Tool] | None = ...,
+         tools: list[Tool | dict | MCPServer] | None = ...,
          cache: CachePattern | None = ...,
-         computer_use: bool = ...,
-         display_width: int = ...,
-         display_height: int = ...,
          use_responses_api: bool = ...,
      ) -> list[str | None]: ...
 
@@ -196,11 +296,8 @@ class LLMClient(BaseModel):
          *,
          return_completions_only: Literal[False] = ...,
          show_progress: bool = ...,
-         tools: list[Tool] | None = ...,
+         tools: list[Tool | dict | MCPServer] | None = ...,
          cache: CachePattern | None = ...,
-         computer_use: bool = ...,
-         display_width: int = ...,
-         display_height: int = ...,
          use_responses_api: bool = ...,
      ) -> list[APIResponse | None]: ...
 
@@ -210,147 +307,117 @@ class LLMClient(BaseModel):
          *,
          return_completions_only: bool = False,
          show_progress: bool = True,
-         tools: list[Tool] | None = None,
+         tools: list[Tool | dict | MCPServer] | None = None,
          cache: CachePattern | None = None,
-         computer_use: bool = False,
-         display_width: int = 1024,
-         display_height: int = 768,
          use_responses_api: bool = False,
      ) -> list[APIResponse | None] | list[str | None] | dict[str, int]:
-         # if prompts are not Conversations, convert them.
+         # Convert prompts to Conversations - no upfront cache checking for dynamic caching!
          prompts = prompts_to_conversations(prompts)
-         ids = np.arange(len(prompts))
-
-         # if using cache, check for cached completions
-         if self.cache:
-             cached_results = [self.cache.get(prompt) for prompt in prompts]
-             cache_hit_ids = [
-                 id for id, res in zip(ids, cached_results) if res is not None
-             ]
-             cache_hit_results = [res for res in cached_results if res is not None]
-             assert len(cache_hit_ids) == len(
-                 cache_hit_results
-             ), "Cache hit ids and results must be the same length."
-             remaining_ids = np.array([i for i in ids if i not in cache_hit_ids])
-             remaining_prompts = [prompts[i] for i in remaining_ids]
-             print(
-                 f"{len(cache_hit_ids)} cache hits; {len(remaining_ids)} prompts remaining."
-             )
-
-         else:
-             cache_hit_ids = []
-             cache_hit_results = []
-             remaining_prompts = prompts
-             remaining_ids = ids
-
+         ids = list(range(len(prompts)))
          results: list[APIResponse | None] = [None for _ in range(len(prompts))]
-         if len(remaining_prompts) > 0:
-             # Create StatusTracker with integrated progress bar
-             tracker = StatusTracker(
-                 max_requests_per_minute=self.max_requests_per_minute,
-                 max_tokens_per_minute=self.max_tokens_per_minute,
-                 max_concurrent_requests=self.max_concurrent_requests,
-                 use_progress_bar=show_progress,
-                 progress_bar_total=len(prompts),
-                 progress_bar_disable=not show_progress,
-                 use_rich=show_progress,  # Disable Rich if progress is disabled
-             )
 
-             # Initialize progress bar and update with cache hits
-             tracker.init_progress_bar()
-             if len(cache_hit_ids) > 0:
-                 tracker.update_pbar(len(cache_hit_ids))
-
-             if isinstance(ids, np.ndarray):
-                 ids = ids.tolist()  # pyright: ignore
-
-             # calculate dynamically so we don't throttle RPM
-             seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
-             next_request = None  # variable to hold the next request to call
-             prompts_not_finished = True
-             prompts_iter = iter(zip(ids, prompts))
-             requests: list[APIRequestBase] = []
-             assert tracker.retry_queue, "retry queue not initialized"
-             while True:
-                 # get next request (if one is not already waiting for capacity)
-                 retry_request = False
-                 if next_request is None:
-                     if not tracker.retry_queue.empty():
-                         next_request = tracker.retry_queue.get_nowait()
-                         retry_request = True
-                         print(f"Retrying request {next_request.task_id}.")
-                     elif prompts_not_finished:
+         # Create StatusTracker
+         tracker = StatusTracker(
+             max_requests_per_minute=self.max_requests_per_minute,
+             max_tokens_per_minute=self.max_tokens_per_minute,
+             max_concurrent_requests=self.max_concurrent_requests,
+             use_progress_bar=show_progress,
+             progress_bar_total=len(prompts),
+             progress_bar_disable=not show_progress,
+             use_rich=show_progress,
+         )
+
+         tracker.init_progress_bar()
+
+         # Create retry queue for failed requests
+         retry_queue: asyncio.Queue[RequestContext] = asyncio.Queue()
+
+         # Calculate sleep time for rate limiting
+         seconds_to_sleep_each_loop = (60.0 * 0.9) / tracker.max_requests_per_minute
+
+         # Main dispatch loop - using original pattern but with all prompts
+         next_context = None  # Persist across iterations like original
+         prompts_not_finished = True
+         prompts_iter = iter(zip(ids, prompts))
+
+         while True:
+             # Get next context (retry or new) - only if we don't already have one waiting
+             retry_request = False
+             if next_context is None:
+                 if not retry_queue.empty():
+                     next_context = retry_queue.get_nowait()
+                     retry_request = True
+                     print(f"Retrying request {next_context.task_id}.")
+                 elif prompts_not_finished:
+                     try:
+                         task_id, prompt = next(prompts_iter)
+                         model, sampling_params = self._select_model()
+                         assert isinstance(prompt, Conversation)
+                         next_context = RequestContext(
+                             task_id=task_id,
+                             model_name=model,
+                             prompt=prompt,
+                             sampling_params=sampling_params,
+                             attempts_left=self.max_attempts,
+                             request_timeout=self.request_timeout,
+                             status_tracker=tracker,
+                             tools=tools,
+                             cache=cache,
+                             use_responses_api=use_responses_api,
+                         )
+                     except StopIteration:
+                         prompts_not_finished = False
+
+             # Update capacity - original logic
+             tracker.update_capacity()
+
+             # Dispatch if capacity available - original logic
+             if next_context:
+                 if tracker.check_capacity(next_context.num_tokens, retry=retry_request):
+                     tracker.set_limiting_factor(None)
+
+                     # Launch simplified request processing
+                     async def process_and_store(ctx: RequestContext):
                          try:
-                             # get new request
-                             id, prompt = next(prompts_iter)
-                             # select model
-                             model, sampling_params = self._select_model()
-
-                             next_request = create_api_request(
-                                 task_id=id,
-                                 model_name=model,
-                                 prompt=prompt,  # type: ignore
-                                 request_timeout=self.request_timeout,
-                                 attempts_left=self.max_attempts,
-                                 status_tracker=tracker,
-                                 results_arr=requests,
-                                 sampling_params=sampling_params,
-                                 all_model_names=self.models,
-                                 all_sampling_params=self.sampling_params,
-                                 tools=tools,
-                                 cache=cache,
-                                 computer_use=computer_use,
-                                 display_width=display_width,
-                                 display_height=display_height,
-                                 use_responses_api=use_responses_api,
+                             response = await self.process_single_request(
+                                 ctx, retry_queue
                              )
-                             requests.append(next_request)
-
-                         except StopIteration:
-                             prompts_not_finished = False
-                             # print("API requests finished, only retries remain.")
-
-                 # update available capacity
-                 tracker.update_capacity()
-
-                 # if enough capacity available, call API
-                 if next_request:
-                     next_request_tokens = next_request.num_tokens
-                     if tracker.check_capacity(next_request_tokens, retry=retry_request):
-                         tracker.set_limiting_factor(None)
-                         # call API (attempts_left will be decremented in handle_error if it fails)
-                         asyncio.create_task(next_request.call_api())
-                         next_request = None  # reset next_request to empty
-                 # update pbar status
-                 tracker.update_pbar()
-
-                 # if all tasks are finished, break
-                 if tracker.num_tasks_in_progress == 0:
-                     break
-
-                 # main loop sleeps briefly so concurrent tasks can run
-                 await asyncio.sleep(seconds_to_sleep_each_loop)
-
-                 # if a rate limit error was hit recently, pause to cool down
-                 if tracker.seconds_to_pause > 0:
-                     await asyncio.sleep(tracker.seconds_to_pause)
-                     print(f"Pausing {tracker.seconds_to_pause}s to cool down.")
-
-             # after finishing, log final status
-             tracker.log_final_status()
-
-             # deduplicate results by id
-             api_results = deduplicate_responses(requests)
-             for res in api_results:
-                 results[res.id] = res
-                 # set to cache if result has a completion
-                 if self.cache and res.completion:
-                     self.cache.put(prompts[res.id], res)
+                             results[ctx.task_id] = response
+                         except Exception as e:
+                             # Create an error response for validation errors and other exceptions
+                             from .api_requests.response import APIResponse
+
+                             error_response = APIResponse(
+                                 id=ctx.task_id,
+                                 model_internal=ctx.model_name,
+                                 prompt=ctx.prompt,
+                                 sampling_params=ctx.sampling_params,
+                                 status_code=None,
+                                 is_error=True,
+                                 error_message=str(e),
+                             )
+                             results[ctx.task_id] = error_response
+                             # Mark task as completed so the main loop can finish
+                             if ctx.status_tracker:
+                                 ctx.status_tracker.task_failed(ctx.task_id)
+
+                     asyncio.create_task(process_and_store(next_context))
+                     next_context = None  # Reset after successful dispatch
+
+             # Update progress - original logic
+             tracker.update_pbar()
+
+             # Check completion - original logic
+             if (
+                 tracker.num_tasks_in_progress == 0
+                 and not prompts_not_finished
+                 and retry_queue.empty()
+             ):
+                 break
 
-         # add cache hits back in
-         for id, res in zip(cache_hit_ids, cache_hit_results):
-             res.cache_hit = True
-             results[id] = res
+             # Sleep - original logic
+             await asyncio.sleep(seconds_to_sleep_each_loop + tracker.seconds_to_pause)
+         tracker.log_final_status()
 
          if return_completions_only:
              return [r.completion if r is not None else None for r in results]
@@ -363,7 +430,7 @@ class LLMClient(BaseModel):
          *,
          return_completions_only: bool = False,
          show_progress=True,
-         tools: list[Tool] | None = None,
+         tools: list[Tool | dict | MCPServer] | None = None,
          cache: CachePattern | None = None,
      ):
          return asyncio.run(
@@ -376,7 +443,11 @@ class LLMClient(BaseModel):
              )
          )
 
-     async def stream(self, prompt: str | Conversation, tools: list[Tool] | None = None):
+     async def stream(
+         self,
+         prompt: str | Conversation,
+         tools: list[Tool | dict | MCPServer] | None = None,
+     ):
          model, sampling_params = self._select_model()
          if isinstance(prompt, str):
              prompt = Conversation.user(prompt)
@@ -387,6 +458,89 @@ class LLMClient(BaseModel):
          # final item
          return item
 
+     async def run_agent_loop(
+         self,
+         conversation: str | Conversation,
+         *,
+         tools: list[Tool | dict] | None = None,
+         max_rounds: int = 5,
+         show_progress: bool = False,
+     ) -> tuple[Conversation, APIResponse]:
+         """Run a simple agent loop until no more tool calls are returned.
+
+         The provided ``conversation`` will be mutated and returned alongside the
+         final ``APIResponse`` from the model. ``tools`` may include ``Tool``
+         instances or built‑in tool dictionaries.
+         """
+
+         if isinstance(conversation, str):
+             conversation = Conversation.user(conversation)
+
+         last_response: APIResponse | None = None
+
+         for _ in range(max_rounds):
+             responses = await self.process_prompts_async(
+                 [conversation],
+                 tools=tools,  # type: ignore
+                 return_completions_only=False,
+                 show_progress=show_progress,
+             )
+
+             last_response = responses[0]
+             if last_response is None or last_response.content is None:
+                 break
+
+             conversation.add(last_response.content)
+
+             tool_calls = last_response.content.tool_calls
+             if not tool_calls:
+                 break
+
+             for call in tool_calls:
+                 tool_obj = None
+                 if tools:
+                     for t in tools:
+                         if isinstance(t, Tool) and t.name == call.name:
+                             tool_obj = t
+                             break
+
+                 if isinstance(tool_obj, Tool) and tool_obj.run is not None:
+                     try:
+                         result = await tool_obj.acall(**call.arguments)
+                     except Exception as e:  # pragma: no cover - best effort
+                         result = f"Error: {e}"
+                 else:
+                     result = f"Tool {call.name} not found"
+
+                 if not isinstance(result, (str, dict, list)):
+                     result = str(result)
+
+                 conversation.add_tool_result(call.id, result)  # type: ignore
+
+         if last_response is None:
+             raise RuntimeError("model did not return a response")
+
+         return conversation, last_response
+
+     def run_agent_loop_sync(
+         self,
+         conversation: str | Conversation,
+         *,
+         tools: list[Tool | dict | MCPServer] | None = None,
+         max_rounds: int = 5,
+         show_progress: bool = False,
+     ) -> tuple[Conversation, APIResponse]:
+         """Synchronous wrapper for :meth:`run_agent_loop`."""
+
+         return asyncio.run(
+             self.run_agent_loop(
+                 conversation,
+                 tools=tools,  # type: ignore
+                 max_rounds=max_rounds,
+                 show_progress=show_progress,
+             )
+         )
+
      async def submit_batch_job(
          self,
          prompts: Sequence[str | list[dict] | Conversation],
@@ -409,7 +563,7 @@ class LLMClient(BaseModel):
          if len(self.models) != 1:
              raise ValueError("Batch jobs can only be submitted with a single model.")
          model = self.models[0]
-         api_spec = registry[model].get("api_spec", None)
+         api_spec = registry[model].api_spec
 
          if api_spec == "openai":
              return await submit_batches_oa(model, self.sampling_params[0], prompts)
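The new run_agent_loop / run_agent_loop_sync methods added above repeatedly call process_prompts_async, execute any returned tool calls through Tool.acall, append the results to the conversation, and stop when the model returns no tool calls or max_rounds is reached. A usage sketch; Tool.from_function is an assumption about the Tool API, which is not part of this diff:

from lm_deluge.client import LLMClient
from lm_deluge.tool import Tool


def get_weather(city: str) -> str:
    # Toy tool used only for illustration.
    return f"It is sunny in {city}."


def demo(client: LLMClient) -> None:
    # Assumed constructor; build the Tool however lm_deluge actually expects.
    weather_tool = Tool.from_function(get_weather)
    conversation, final_response = client.run_agent_loop_sync(
        "What's the weather in Paris?",
        tools=[weather_tool],
        max_rounds=3,
    )
    print(final_response.completion)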
lm_deluge/image.py CHANGED
@@ -1,6 +1,5 @@
  import os
  from contextlib import contextmanager
- from functools import cached_property
  import io
  import requests
  from PIL import Image as PILImage  # type: ignore
@@ -18,6 +17,8 @@ class Image:
      media_type: str | None = None  # inferred if None
      detail: Literal["low", "high", "auto"] = "auto"
      type: str = field(init=False, default="image")
+     _fingerprint_cache: str | None = field(init=False, default=None)
+     _size_cache: tuple[int, int] | None = field(init=False, default=None)
 
      @classmethod
      def from_pdf(
@@ -95,12 +96,14 @@ class Image:
              if img:
                  img.close()
 
-     @cached_property
+     @property
      def size(self) -> tuple[int, int]:
-         with self._image() as img:
-             return img.size
+         if self._size_cache is None:
+             with self._image() as img:
+                 self._size_cache = img.size
+         return self._size_cache
 
-     @cached_property
+     @property
      def num_pixels(self) -> int:
          return self.size[0] * self.size[1]
 
@@ -143,11 +146,13 @@ class Image:
          new_width = int(new_height / height * width)
          return self._resize((new_width, new_height))
 
-     @cached_property
+     @property
      def fingerprint(self) -> str:
          # return base64 of a very small version of the image
-         small_image = self._resize_longer(max_size=48)  # longer side = 48px
-         return base64.b64encode(small_image).decode("utf-8")
+         if self._fingerprint_cache is None:
+             small_image = self._resize_longer(max_size=48)  # longer side = 48px
+             self._fingerprint_cache = base64.b64encode(small_image).decode("utf-8")
+         return self._fingerprint_cache
 
      def resize(self, max_size: int) -> None:
          """
@@ -1,9 +1,12 @@
+ import asyncio
  import io
  import json
- from ..prompt import Conversation
- import asyncio
- from ..client import LLMClient
  from typing import Any
+
+ from lm_deluge.file import File
+
+ from ..client import LLMClient
+ from ..prompt import Conversation
  from ..util.json import load_json
 
  try:
@@ -62,6 +65,15 @@ async def extract_async(
          + 'like `{"error": "The document is not relevant to the schema."}`.'
      )
 
+     file_prompt = (
+         f"Given the attached {document_name} file, extract the {object_name}information "
+         + "from it according to the following JSON schema:\n\n```json\n"
+         + json.dumps(schema_dict, indent=2)
+         + "Return the extracted information as JSON, no explanation required. "
+         + f"If the {document_name} seems to be totally irrelevant to the schema (not just incomplete), you may return a JSON object "
+         + 'like `{"error": "The document is not relevant to the schema."}`.'
+     )
+
      prompts = []
      for input in inputs:
          if isinstance(input, str):
@@ -74,8 +86,15 @@ async def extract_async(
              prompts.append(
                  Conversation.user(text=image_only_prompt, image=buffer.getvalue())
              )
+         elif isinstance(input, File):
+             data = input.data
+             if isinstance(data, io.BytesIO):
+                 data = data.getvalue()
+             prompts.append(Conversation.user(text=file_prompt, file=data))
          else:
-             raise ValueError("inputs must be a list of strings or PIL images.")
+             raise ValueError(
+                 "inputs must be a list of strings or PIL images or a File object."
+             )
 
      if return_prompts:
          return prompts
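The new File branch above boils down to one conversion per input: pull raw bytes out of the File and attach them to a user turn along with the file prompt. A standalone sketch of that step (only the .data attribute of File is visible in this diff; everything else shown here comes from the changed code):

import io

from lm_deluge.file import File
from lm_deluge.prompt import Conversation


def file_to_prompt(file: File, prompt_text: str) -> Conversation:
    # Same normalization as the new elif branch in extract_async: BytesIO -> bytes.
    data = file.data
    if isinstance(data, io.BytesIO):
        data = data.getvalue()
    return Conversation.user(text=prompt_text, file=data)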
@@ -0,0 +1 @@
+ # NOT IMPLEMENTED YET!