model_library-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. model_library/__init__.py +23 -0
  2. model_library/base.py +814 -0
  3. model_library/config/ai21labs_models.yaml +99 -0
  4. model_library/config/alibaba_models.yaml +91 -0
  5. model_library/config/all_models.json +13479 -0
  6. model_library/config/amazon_models.yaml +276 -0
  7. model_library/config/anthropic_models.yaml +370 -0
  8. model_library/config/cohere_models.yaml +177 -0
  9. model_library/config/deepseek_models.yaml +47 -0
  10. model_library/config/dummy_model.yaml +38 -0
  11. model_library/config/fireworks_models.yaml +228 -0
  12. model_library/config/google_models.yaml +516 -0
  13. model_library/config/inception_models.yaml +24 -0
  14. model_library/config/kimi_models.yaml +34 -0
  15. model_library/config/mistral_models.yaml +143 -0
  16. model_library/config/openai_models.yaml +783 -0
  17. model_library/config/perplexity_models.yaml +91 -0
  18. model_library/config/together_models.yaml +866 -0
  19. model_library/config/xai_models.yaml +266 -0
  20. model_library/config/zai_models.yaml +65 -0
  21. model_library/exceptions.py +288 -0
  22. model_library/file_utils.py +114 -0
  23. model_library/model_utils.py +26 -0
  24. model_library/providers/ai21labs.py +193 -0
  25. model_library/providers/alibaba.py +147 -0
  26. model_library/providers/amazon.py +367 -0
  27. model_library/providers/anthropic.py +419 -0
  28. model_library/providers/azure.py +43 -0
  29. model_library/providers/cohere.py +100 -0
  30. model_library/providers/deepseek.py +115 -0
  31. model_library/providers/fireworks.py +133 -0
  32. model_library/providers/google/__init__.py +4 -0
  33. model_library/providers/google/batch.py +299 -0
  34. model_library/providers/google/google.py +467 -0
  35. model_library/providers/inception.py +102 -0
  36. model_library/providers/kimi.py +102 -0
  37. model_library/providers/mistral.py +299 -0
  38. model_library/providers/openai.py +924 -0
  39. model_library/providers/perplexity.py +101 -0
  40. model_library/providers/together.py +249 -0
  41. model_library/providers/vals.py +307 -0
  42. model_library/providers/xai.py +332 -0
  43. model_library/providers/zai.py +102 -0
  44. model_library/py.typed +0 -0
  45. model_library/register_models.py +385 -0
  46. model_library/registry_utils.py +202 -0
  47. model_library/settings.py +34 -0
  48. model_library/utils.py +151 -0
  49. model_library-0.1.0.dist-info/METADATA +268 -0
  50. model_library-0.1.0.dist-info/RECORD +53 -0
  51. model_library-0.1.0.dist-info/WHEEL +5 -0
  52. model_library-0.1.0.dist-info/licenses/LICENSE +21 -0
  53. model_library-0.1.0.dist-info/top_level.txt +1 -0
model_library/base.py ADDED
@@ -0,0 +1,814 @@
import io
import logging
import time
import uuid
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from pprint import pformat
from typing import (
    TYPE_CHECKING,
    Annotated,
    Any,
    Callable,
    Literal,
    Mapping,
    Sequence,
    TypeVar,
    cast,
)

from pydantic import computed_field, field_validator, model_serializer
from pydantic.fields import Field
from pydantic.main import BaseModel
from typing_extensions import override

from model_library.exceptions import (
    ImmediateRetryException,
    retry_llm_call,
)
from model_library.utils import sum_optional, truncate_str

PydanticT = TypeVar("PydanticT", bound=BaseModel)

DEFAULT_MAX_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 1

if TYPE_CHECKING:
    from model_library.providers.openai import OpenAIModel

"""
--- FILES ---
"""


class FileBase(BaseModel):
    type: Literal["image", "file"]
    name: str
    mime: str

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        if "base64" in attrs:
            attrs["base64"] = truncate_str(attrs["base64"])
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"


class FileWithBase64(FileBase):
    append_type: Literal["base64"] = "base64"
    base64: str


class FileWithUrl(FileBase):
    append_type: Literal["url"] = "url"
    url: str


class FileWithId(FileBase):
    append_type: Literal["file_id"] = "file_id"
    file_id: str


FileInput = Annotated[
    FileWithBase64 | FileWithUrl | FileWithId,
    Field(discriminator="append_type"),
]
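# --- Editor's illustrative sketch, not part of base.py -----------------------
# `append_type` is a Pydantic discriminator: validating a plain dict against
# the FileInput union picks the matching variant. TypeAdapter is the standard
# Pydantic v2 entry point for an Annotated union like this.
from pydantic import TypeAdapter  # sketch-only import

_file_input = TypeAdapter(FileInput).validate_python(
    {
        "type": "image",
        "name": "chart.png",
        "mime": "image/png",
        "append_type": "url",
        "url": "https://example.com/chart.png",
    }
)
assert isinstance(_file_input, FileWithUrl)
# -----------------------------------------------------------------------------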


"""
--- TOOLS ---
"""


class ToolBody(BaseModel):
    name: str
    description: str
    properties: dict[str, Any]
    required: list[str]
    kwargs: dict[str, Any] = {}


class ToolDefinition(BaseModel):
    name: str  # acts as a key
    body: ToolBody | Any


class ToolCall(BaseModel):
    id: str
    call_id: str | None = None
    name: str
    args: dict[str, Any] | str
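# --- Editor's illustrative sketch, not part of base.py -----------------------
# A tool is declared once and keyed by name; `properties`/`required` follow a
# JSON-Schema parameter style that provider adapters translate in parse_tools.
# The exact schema shape a given provider expects is an assumption here.
_example_tool = ToolDefinition(
    name="get_weather",
    body=ToolBody(
        name="get_weather",
        description="Look up the current weather for a city",
        properties={"city": {"type": "string"}},
        required=["city"],
    ),
)
# -----------------------------------------------------------------------------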


"""
--- INPUT ---
"""

RawResponse = Any


class ToolInput(BaseModel):
    tools: list[ToolDefinition] = []


class ToolResult(BaseModel):
    tool_call: ToolCall
    result: Any


class TextInput(BaseModel):
    text: str


RawInputItem = dict[
    str, Any
]  # to pass in, for example, a mock conversation with {"role": "user", "content": "Hello"}


InputItem = (
    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
)  # an input item can be a prompt, a file (image or file), a tool call result, raw input, or a previous response


"""
--- OUTPUT ---
"""


class Citation(BaseModel):
    type: str | None = None
    title: str | None = None
    url: str | None = None
    start_index: int | None = None
    end_index: int | None = None
    file_id: str | None = None
    filename: str | None = None
    index: int | None = None
    container_id: str | None = None


class QueryResultExtras(BaseModel):
    citations: list[Citation] = Field(default_factory=list)


class QueryResultCost(BaseModel):
    """
    Cost information for a query.
    Includes total cost and a structured breakdown.
    """

    input: float
    output: float
    reasoning: float | None = None
    cache_read: float | None = None
    cache_write: float | None = None

    @computed_field
    @property
    def total(self) -> float:
        return sum(
            filter(
                None,
                [
                    self.input,
                    self.output,
                    self.reasoning,
                    self.cache_read,
                    self.cache_write,
                ],
            )
        )

    @override
    def __repr__(self):
        use_cents = self.total < 1

        def format_cost(value: float | None):
            if value is None:
                return None
            return f"{value * 100:.3f} cents" if use_cents else f"${value:.2f}"

        return (
            f"{format_cost(self.total)} "
            + f"(uncached input: {format_cost(self.input)} | output: {format_cost(self.output)} | reasoning: {format_cost(self.reasoning)} | cache_read: {format_cost(self.cache_read)} | cache_write: {format_cost(self.cache_write)})"
        )
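# --- Editor's illustrative sketch, not part of base.py -----------------------
# `total` sums whichever components are set; filter(None, ...) drops both None
# and zero-valued entries. With $0.002 input, $0.004 output, $0.001 cache_read:
_cost = QueryResultCost(input=0.002, output=0.004, cache_read=0.001)
assert abs(_cost.total - 0.007) < 1e-9
# repr(_cost) starts with "0.700 cents" since the total is under $1.
# -----------------------------------------------------------------------------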


class QueryResultMetadata(BaseModel):
    """
    Metadata for a query: token usage and timing.
    """

    cost: QueryResultCost | None = None  # set post query
    duration_seconds: float | None = None  # set post query
    in_tokens: int = 0
    out_tokens: int = 0
    reasoning_tokens: int | None = None
    cache_read_tokens: int | None = None
    cache_write_tokens: int | None = None

    @property
    def default_duration_seconds(self) -> float:
        return self.duration_seconds or 0

    def __add__(self, other: "QueryResultMetadata") -> "QueryResultMetadata":
        return QueryResultMetadata(
            in_tokens=self.in_tokens + other.in_tokens,
            out_tokens=self.out_tokens + other.out_tokens,
            reasoning_tokens=sum_optional(
                self.reasoning_tokens, other.reasoning_tokens
            ),
            cache_read_tokens=sum_optional(
                self.cache_read_tokens, other.cache_read_tokens
            ),
            cache_write_tokens=sum_optional(
                self.cache_write_tokens, other.cache_write_tokens
            ),
            duration_seconds=self.default_duration_seconds
            + other.default_duration_seconds,
        )

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"


class QueryResult(BaseModel):
    """
    Result of a query
    Contains the text, reasoning, metadata, tool calls, and history
    """

    output_text: str | None = None
    reasoning: str | None = None
    metadata: QueryResultMetadata = Field(default_factory=QueryResultMetadata)
    tool_calls: list[ToolCall] = Field(default_factory=list)
    history: list[InputItem] = Field(default_factory=list)
    extras: QueryResultExtras = Field(default_factory=QueryResultExtras)
    raw: dict[str, Any] = Field(default_factory=dict)

    @property
    def output_text_str(self) -> str:
        return self.output_text or ""

    @field_validator("reasoning", mode="before")
    def default_reasoning(cls, v: str | None):
        return None if not v else v  # make reasoning None if empty

    @property
    def search_results(self) -> Any | None:
        """Expose provider-supplied search metadata without additional processing."""
        raw_dict = cast(dict[str, Any], getattr(self, "raw", {}))
        raw_candidate = raw_dict.get("search_results")
        if raw_candidate is not None:
            return raw_candidate

        return _get_from_history(self.history, "search_results")

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        ordered_attrs = {
            "output_text": truncate_str(attrs.pop("output_text", None), 400),
            "reasoning": truncate_str(attrs.pop("reasoning", None), 400),
            "metadata": attrs.pop("metadata", None),
        }
        if self.tool_calls:
            ordered_attrs["tool_calls"] = self.tool_calls
        return f"{self.__class__.__name__}(\n{pformat(ordered_attrs, indent=2, sort_dicts=False)}\n)"


def _get_from_history(history: Sequence[InputItem], key: str) -> Any | None:
    for item in reversed(history):
        value = getattr(item, key, None)
        if value is not None:
            return value

        extra = getattr(item, "model_extra", None)
        if isinstance(extra, Mapping):
            value = cast(Mapping[str, Any], extra).get(key)
            if value is not None:
                return value

    return None


class ProviderConfig(BaseModel):
    """Base class for provider-specific configs. Do not use directly."""

    @model_serializer(mode="plain")
    def serialize_actual(self):
        return self.__dict__


class LLMConfig(BaseModel):
    max_tokens: int = DEFAULT_MAX_TOKENS
    temperature: float | None = None
    top_p: float | None = None
    reasoning: bool = False
    reasoning_effort: str | None = None
    supports_images: bool = False
    supports_files: bool = False
    supports_videos: bool = False
    supports_batch: bool = False
    supports_temperature: bool = True
    supports_tools: bool = False
    native: bool = True
    provider_config: ProviderConfig | None = None
    registry_key: str | None = None


RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]

R = TypeVar("R")  # return type
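# --- Editor's illustrative sketch, not part of base.py -----------------------
# A RetrierType is a decorator over an async callable; LLM.custom_retrier
# (below) is a factory that builds one from a logger. retry_llm_call in
# model_library/exceptions.py fills that role in this package; this naive
# stand-in only shows the expected shape.
def _naive_retrier(logger: logging.Logger) -> RetrierType:
    def decorate(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
        async def wrapped(*args: Any, **kwargs: Any) -> Any:
            last_exc: Exception | None = None
            for attempt in range(3):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:  # a real retrier filters retryable errors
                    last_exc = e
                    logger.warning(f"attempt {attempt + 1}/3 failed: {e}")
            assert last_exc is not None
            raise last_exc

        return wrapped

    return decorate
# -----------------------------------------------------------------------------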


class LLM(ABC):
    """
    Base class for all LLMs
    LLM call errors should be raised as exceptions
    """

    def __init__(
        self,
        model_name: str,
        provider: str,
        *,
        config: LLMConfig | None = None,
    ):
        self.instance_id = uuid.uuid4().hex[:8]

        self.provider: str = provider
        self.model_name: str = model_name

        config = config or LLMConfig()
        self._registry_key = config.registry_key

        if config.provider_config:
            self.provider_config = config.provider_config

        self.max_tokens: int = config.max_tokens
        self.temperature: float | None = config.temperature
        self.top_p: float | None = config.top_p

        self.reasoning: bool = config.reasoning
        self.reasoning_effort: str | None = config.reasoning_effort

        self.supports_files: bool = config.supports_files
        self.supports_videos: bool = config.supports_videos
        self.supports_images: bool = config.supports_images
        self.supports_batch: bool = config.supports_batch
        self.supports_temperature: bool = config.supports_temperature
        self.supports_tools: bool = config.supports_tools

        self.native: bool = config.native
        self.delegate: "OpenAIModel | None" = None
        self.batch: LLMBatchMixin | None = None

        self.logger: logging.Logger = logging.getLogger(
            f"llm.{provider}.{model_name}<instance={self.instance_id}>"
        )
        self.custom_retrier: Callable[..., RetrierType] | None = retry_llm_call

    @override
    def __repr__(self):
        attrs = vars(self).copy()
        attrs.pop("logger", None)
        attrs.pop("custom_retrier", None)
        attrs.pop("_key", None)
        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"

    @abstractmethod
    def get_client(self) -> object:
        """Return the instance of the appropriate SDK client."""
        ...

    @staticmethod
    async def timer_wrapper(func: Callable[[], Awaitable[R]]) -> tuple[R, float]:
        """
        Time the query
        """
        start = time.perf_counter()
        result = await func()
        return result, round(time.perf_counter() - start, 4)

    @staticmethod
    async def immediate_retry_wrapper(
        func: Callable[[], Awaitable[R]],
        logger: logging.Logger,
    ) -> R:
        """
        Retry the query immediately
        """
        MAX_IMMEDIATE_RETRIES = 10
        retries = 0
        while True:
            try:
                return await func()
            except ImmediateRetryException as e:
                if retries >= MAX_IMMEDIATE_RETRIES:
                    logger.error(f"Query reached max immediate retries {retries}: {e}")
                    raise Exception(
                        f"Query reached max immediate retries {retries}: {e}"
                    ) from e
                retries += 1

                logger.warning(
                    f"Query retried immediately {retries}/{MAX_IMMEDIATE_RETRIES}: {e}"
                )

    @staticmethod
    async def backoff_retry_wrapper(
        func: Callable[..., Awaitable[R]],
        backoff_retrier: RetrierType | None,
    ) -> R:
        """
        Retry the query with backoff
        """
        if not backoff_retrier:
            return await func()
        return await backoff_retrier(func)()

    async def delegate_query(
        self,
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition] = [],
        **kwargs: object,
    ) -> QueryResult:
        if not self.delegate:
            raise Exception("Delegate not set")
        return await self.delegate._query_impl(input, tools=tools, **kwargs)  # pyright: ignore[reportPrivateUsage]

    async def query(
        self,
        input: Sequence[InputItem] | str,
        *,
        history: Sequence[InputItem] = [],
        tools: list[ToolDefinition] = [],
        # for backwards compatibility
        files: list[FileInput] = [],
        images: list[FileInput] = [],
        **kwargs: object,
    ) -> QueryResult:
        """
        Query the model
        Join input with history
        Log, Time, and Retry
        """
        # format str input
        if isinstance(input, str):
            input = [TextInput(text=input)]

        # prepend files and images to input
        input = [*files, *images, *input]

        # format input info
        item_info = f"--- input ({len(input)}): {get_pretty_input_types(input)}\n"
        if history:
            item_info += (
                f"--- history({len(history)}): {get_pretty_input_types(history)}\n"
            )

        # format tool info
        tool_results = [t for t in input if isinstance(t, ToolResult)]
        tool_names = [tool.name for tool in tools or []]

        tool_info = (
            f"--- tools ({len(tools)}): {tool_names}\n"
            + f"--- tool results ({len(tool_results)}): "
            + f"{[{tool.tool_call.name: truncate_str(str(tool.result))} for tool in tool_results]}\n"
            if tools
            else ""
        )

        short_kwargs = {k: truncate_str(repr(v)) for k, v in kwargs.items()}

        # join input with history
        input = [*history, *input]

        # unique logger for the query
        query_id = uuid.uuid4().hex[:14]
        query_logger = logging.getLogger(f"{self.logger.name}<query={query_id}>")

        query_logger.info(
            "Query started:\n" + item_info + tool_info + f"--- kwargs: {short_kwargs}\n"
        )

        async def query_func() -> QueryResult:
            return await self._query_impl(input, tools=tools, **kwargs)

        async def timed_query() -> tuple[QueryResult, float]:
            return await LLM.timer_wrapper(query_func)

        async def immediate_retry() -> tuple[QueryResult, float]:
            return await LLM.immediate_retry_wrapper(timed_query, query_logger)

        async def backoff_retry() -> tuple[QueryResult, float]:
            backoff_retrier = (
                self.custom_retrier(query_logger) if self.custom_retrier else None
            )
            return await LLM.backoff_retry_wrapper(immediate_retry, backoff_retrier)

        output, duration = await backoff_retry()
        output.metadata.duration_seconds = duration
        output.metadata.cost = await self._calculate_cost(output.metadata)

        query_logger.info(f"Query completed: {repr(output)}")

        return output
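    # --- Editor's illustrative sketch, not part of base.py -------------------
    # query() nests its helpers as
    #   backoff_retry_wrapper(immediate_retry_wrapper(timer_wrapper(query_func)))
    # so a call on a concrete provider subclass (hypothetical name) would read:
    #
    #   model = SomeProviderModel("model-id", "provider", config=LLMConfig())
    #   result = await model.query("Hello", tools=[_example_tool])
    #   print(result.output_text_str, repr(result.metadata))
    # --------------------------------------------------------------------------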

    async def query_json(
        self,
        input: Sequence[InputItem],
        pydantic_model: type[PydanticT],
        **kwargs: object,
    ) -> PydanticT:
        """Query the model with a JSON response format using a Pydantic model.

        This is a convenience method that is not implemented for all providers.
        Only the OpenAI and Google providers currently support this method.

        Args:
            input: Input items (text, files, etc.)
            pydantic_model: Pydantic model class defining the expected response structure
            **kwargs: Additional arguments passed to the query method

        Returns:
            Instance of the pydantic_model with the model's response

        Raises:
            NotImplementedError: If the provider does not support structured JSON output
        """
        raise NotImplementedError(
            f"query_json is not implemented for {self.__class__.__name__}. "
            f"Only OpenAI and Google providers currently support this method."
        )
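    # --- Editor's illustrative sketch, not part of base.py -------------------
    # On a provider that implements query_json (per the docstring, OpenAI or
    # Google), usage would look like this; WeatherReport is a made-up schema:
    #
    #   class WeatherReport(BaseModel):
    #       city: str
    #       temperature_c: float
    #
    #   report = await model.query_json(
    #       [TextInput(text="Weather in Paris?")], WeatherReport
    #   )
    #   assert isinstance(report, WeatherReport)
    # --------------------------------------------------------------------------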

    async def _calculate_cost(
        self,
        metadata: QueryResultMetadata,
        batch: bool = False,
        bill_reasoning: bool = True,
    ) -> QueryResultCost | None:
        """Calculate cost for a query"""
        from model_library.registry_utils import get_model_cost

        if not self._registry_key:
            self.logger.warning("Model has no registry key, skipping cost calculation")
            return None

        costs = get_model_cost(self._registry_key)
        if not costs:
            return None

        MILLION = 1_000_000

        # base input and output
        if costs.input is None or costs.output is None:
            raise Exception("Base costs not set")
        input_cost = costs.input
        output_cost = costs.output

        # apply fixed values or discounts/markup
        # applied before other price changes
        cache_read_cost, cache_write_cost = None, None
        if metadata.cache_read_tokens or metadata.cache_write_tokens:
            if not costs.cache:
                raise Exception("Cache costs not set")
            cache_read_cost, cache_write_cost = costs.cache.get_costs(
                input_cost, output_cost
            )

        # costs for long context
        total_in = metadata.in_tokens + (metadata.cache_read_tokens or 0)
        if costs.context and total_in > costs.context.threshold:
            input_cost, output_cost = costs.context.get_costs(
                input_cost,
                output_cost,
                total_in,
            )
            if costs.context.cache:
                cache_read_cost, cache_write_cost = costs.context.cache.get_costs(
                    input_cost, output_cost
                )

        # costs for batching
        if batch:
            if not costs.batch:
                raise Exception("Batch costs not set")
            input_cost, output_cost = costs.batch.get_costs(input_cost, output_cost)

        return QueryResultCost(
            input=input_cost * metadata.in_tokens / MILLION,
            output=output_cost * metadata.out_tokens / MILLION,
            reasoning=output_cost * metadata.reasoning_tokens / MILLION
            if metadata.reasoning_tokens is not None and bill_reasoning
            else None,
            cache_read=cache_read_cost * metadata.cache_read_tokens / MILLION
            if metadata.cache_read_tokens is not None and cache_read_cost
            else None,
            cache_write=cache_write_cost * metadata.cache_write_tokens / MILLION
            if metadata.cache_write_tokens is not None and cache_write_cost
            else None,
        )
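    # --- Editor's illustrative sketch, not part of base.py -------------------
    # Registry rates are per million tokens. With assumed rates of $3/M input
    # and $15/M output, a call using 10,000 input and 2,000 output tokens costs
    #   input  = 3  * 10_000 / 1_000_000 = $0.03
    #   output = 15 *  2_000 / 1_000_000 = $0.03
    # for a computed total of $0.06.
    # --------------------------------------------------------------------------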

    @abstractmethod
    async def _query_impl(
        self,
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition],
        **kwargs: object,  # TODO: pass in query logger
    ) -> QueryResult:
        """
        Query the model with input
        Input can consist of text, images, files, or model-specific raw responses
        Optionally pass in tools
        Kwargs will be passed to the model call (apart from exceptions like system_prompt)
        Images and files should be preprocessed according to what the model supports:
        - base64
        - url
        - file_id
        """
        ...

    @abstractmethod
    async def parse_input(
        self,
        input: Sequence[InputItem],
        **kwargs: Any,
    ) -> Any:
        """
        Parses input into the appropriate format for the model
        Handles prompts, images, and files
        Handles history and tool call results
        Calls
        - parse_image
        - parse_file
        """
        ...

    @abstractmethod
    async def parse_image(self, image: FileInput) -> Any:
        """Parse an image into the appropriate format for the model"""
        ...

    @abstractmethod
    async def parse_file(self, file: FileInput) -> Any:
        """Parse a file into the appropriate format for the model"""
        ...

    @abstractmethod
    async def parse_tools(self, tools: list[ToolDefinition]) -> Any:
        """Parse tools into the appropriate format for the model"""
        ...

    @abstractmethod
    async def upload_file(
        self,
        name: str,
        mime: str,
        bytes: io.BytesIO,
        type: Literal["image", "file"] = "file",
    ) -> FileWithId:
        """Upload a file to the model provider"""
        ...


class BatchResult(BaseModel):
    custom_id: str
    output: QueryResult
    error_message: str | None = None


class LLMBatchMixin(ABC):
    @abstractmethod
    async def create_batch_query_request(
        self,
        custom_id: str,
        input: Sequence[InputItem],
        **kwargs: object,
    ) -> dict[str, Any]:
        """Return a single query request.

        The batch API sends out a batch of query requests to various endpoints.

        For example, OpenAI can send requests to the /v1/responses or
        /v1/chat/completions endpoints.

        This method creates one such request.
        """
        ...

    @abstractmethod
    async def batch_query(
        self,
        batch_name: str,
        requests: list[dict[str, Any]],
    ) -> str:
        """
        Batch query the model
        Returns:
            str: batch_id
        Raises:
            Exception: If failed to batch query
        """
        ...

    @abstractmethod
    async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
        """
        Returns results for batch
        Raises:
            Exception: If failed to get results
        """
        ...

    @abstractmethod
    async def get_batch_progress(self, batch_id: str) -> int:
        """
        Returns number of completed requests for batch
        Raises:
            Exception: If failed to get progress
        """
        ...

    @abstractmethod
    async def cancel_batch_request(self, batch_id: str) -> None:
        """
        Cancels batch
        Raises:
            Exception: If failed to cancel
        """
        ...

    @abstractmethod
    async def get_batch_status(
        self,
        batch_id: str,
    ) -> str:
        """
        Returns batch status
        Raises:
            Exception: If failed to get status
        """
        ...

    @classmethod
    @abstractmethod
    def is_batch_status_completed(
        cls,
        batch_status: str,
    ) -> bool:
        """
        Returns whether the batch status is completed

        A completed state is any state that is final and not in-progress
        Example: failed | cancelled | expired | completed

        An incomplete state is any state that is not completed
        Example: in_progress | pending | running
        """
        ...

    @classmethod
    @abstractmethod
    def is_batch_status_failed(
        cls,
        batch_status: str,
    ) -> bool:
        """Returns whether the batch status is failed"""
        ...

    @classmethod
    @abstractmethod
    def is_batch_status_cancelled(
        cls,
        batch_status: str,
    ) -> bool:
        """Returns whether the batch status is cancelled"""
        ...
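# --- Editor's illustrative sketch, not part of base.py -----------------------
# A provider's batch mixin typically maps its status strings onto these
# predicates; the status values below mirror the docstring examples above.
class _ExampleBatchStatuses:
    _FINAL = {"failed", "cancelled", "expired", "completed"}

    @classmethod
    def is_batch_status_completed(cls, batch_status: str) -> bool:
        return batch_status in cls._FINAL

    @classmethod
    def is_batch_status_failed(cls, batch_status: str) -> bool:
        return batch_status == "failed"

    @classmethod
    def is_batch_status_cancelled(cls, batch_status: str) -> bool:
        return batch_status == "cancelled"
# -----------------------------------------------------------------------------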


def get_pretty_input_types(input: Sequence["InputItem"]) -> str:
    # for logging
    def process_item(item: "InputItem"):
        match item:
            case TextInput():
                return truncate_str(repr(item))
            case FileBase():  # FileInput
                return repr(item)
            case ToolResult():
                return repr(item)
            case dict():
                item = cast(RawInputItem, item)
                return repr(item)
            case _:
                # RawResponse
                return repr(item)

    processed_items = [f" {process_item(item)}" for item in input]
    return "\n" + "\n".join(processed_items) if processed_items else ""