model-library 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. model_library/__init__.py +7 -3
  2. model_library/base/__init__.py +7 -0
  3. model_library/{base.py → base/base.py} +47 -423
  4. model_library/base/batch.py +121 -0
  5. model_library/base/delegate_only.py +94 -0
  6. model_library/base/input.py +100 -0
  7. model_library/base/output.py +175 -0
  8. model_library/base/utils.py +42 -0
  9. model_library/config/all_models.json +164 -2
  10. model_library/config/anthropic_models.yaml +4 -0
  11. model_library/config/deepseek_models.yaml +3 -1
  12. model_library/config/openai_models.yaml +48 -0
  13. model_library/exceptions.py +2 -0
  14. model_library/logging.py +30 -0
  15. model_library/providers/__init__.py +0 -0
  16. model_library/providers/ai21labs.py +2 -0
  17. model_library/providers/alibaba.py +16 -78
  18. model_library/providers/amazon.py +3 -0
  19. model_library/providers/anthropic.py +213 -2
  20. model_library/providers/azure.py +2 -0
  21. model_library/providers/cohere.py +14 -80
  22. model_library/providers/deepseek.py +14 -90
  23. model_library/providers/fireworks.py +17 -81
  24. model_library/providers/google/google.py +22 -20
  25. model_library/providers/inception.py +15 -83
  26. model_library/providers/kimi.py +15 -83
  27. model_library/providers/mistral.py +2 -0
  28. model_library/providers/openai.py +2 -0
  29. model_library/providers/perplexity.py +12 -79
  30. model_library/providers/together.py +2 -0
  31. model_library/providers/vals.py +2 -0
  32. model_library/providers/xai.py +2 -0
  33. model_library/providers/zai.py +15 -83
  34. model_library/register_models.py +75 -55
  35. model_library/registry_utils.py +5 -5
  36. model_library/utils.py +3 -28
  37. {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/METADATA +36 -7
  38. model_library-0.1.2.dist-info/RECORD +61 -0
  39. model_library-0.1.0.dist-info/RECORD +0 -53
  40. {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/WHEEL +0 -0
  41. {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/licenses/LICENSE +0 -0
  42. {model_library-0.1.0.dist-info → model_library-0.1.2.dist-info}/top_level.txt +0 -0
model_library/__init__.py CHANGED
@@ -1,8 +1,11 @@
- from model_library.settings import ModelLibrarySettings
  from model_library.base import LLM, LLMConfig
+ from model_library.logging import set_logging
+ from model_library.settings import ModelLibrarySettings
 
  model_library_settings: ModelLibrarySettings = ModelLibrarySettings()
 
+ set_logging()
+
 
  def model(model_str: str, override_config: LLMConfig | None = None) -> LLM:
      from model_library.registry_utils import get_registry_model
@@ -10,14 +13,15 @@ def model(model_str: str, override_config: LLMConfig | None = None) -> LLM:
      return get_registry_model(model_str, override_config)
 
 
- def raw_model(model_str: str, override_config: LLMConfig | None = None) -> LLM:
+ def raw_model(model_str: str, config: LLMConfig | None = None) -> LLM:
      from model_library.registry_utils import get_raw_model
 
-     return get_raw_model(model_str, config=override_config)
+     return get_raw_model(model_str, config=config)
 
 
  __all__ = [
      "model_library_settings",
      "model",
      "raw_model",
+     "set_logging",
  ]
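The net effect of these two hunks: set_logging() now runs at import time, and raw_model()'s parameter is renamed from override_config to config. A minimal usage sketch (the registry key string below is hypothetical, not taken from this diff):

```python
import model_library
from model_library.base import LLMConfig

# set_logging() has already been called at import time (new in 0.1.2).
llm = model_library.model("openai/gpt-4o-mini")  # hypothetical registry key

# raw_model() now takes `config` (renamed from `override_config` in 0.1.0).
raw = model_library.raw_model(
    "openai/gpt-4o-mini", config=LLMConfig(max_tokens=512)
)
```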
model_library/base/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # ruff: noqa: F403,F401
+ from model_library.base.base import *
+ from model_library.base.batch import *
+ from model_library.base.delegate_only import *
+ from model_library.base.input import *
+ from model_library.base.output import *
+ from model_library.base.utils import *
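This new package initializer re-exports everything from the split-out submodules, so imports that worked against the old monolithic model_library/base.py should keep resolving. A quick sketch, assuming the submodules export the names shown elsewhere in this diff:

```python
# Old flat imports keep working through the wildcard re-exports:
from model_library.base import LLM, LLMConfig, QueryResult

# The new explicit submodule paths are also available:
from model_library.base.batch import BatchResult, LLMBatchMixin
from model_library.base.output import QueryResultCost
```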
model_library/{base.py → base/base.py} RENAMED
@@ -7,293 +7,46 @@ from collections.abc import Awaitable
  from pprint import pformat
  from typing import (
      TYPE_CHECKING,
-     Annotated,
      Any,
      Callable,
      Literal,
-     Mapping,
      Sequence,
      TypeVar,
-     cast,
  )
 
- from pydantic import computed_field, field_validator, model_serializer
- from pydantic.fields import Field
+ from pydantic import model_serializer
  from pydantic.main import BaseModel
  from typing_extensions import override
 
+ from model_library.base.batch import (
+     LLMBatchMixin,
+ )
+ from model_library.base.input import (
+     FileInput,
+     FileWithId,
+     InputItem,
+     TextInput,
+     ToolDefinition,
+     ToolResult,
+ )
+ from model_library.base.output import (
+     QueryResult,
+     QueryResultCost,
+     QueryResultMetadata,
+ )
+ from model_library.base.utils import (
+     get_pretty_input_types,
+ )
  from model_library.exceptions import (
      ImmediateRetryException,
      retry_llm_call,
  )
- from model_library.utils import sum_optional, truncate_str
-
- PydanticT = TypeVar("PydanticT", bound=BaseModel)
-
- DEFAULT_MAX_TOKENS = 2048
- DEFAULT_TEMPERATURE = 0.7
- DEFAULT_TOP_P = 1
+ from model_library.utils import truncate_str
 
  if TYPE_CHECKING:
      from model_library.providers.openai import OpenAIModel
 
- """
- --- FILES ---
- """
-
-
- class FileBase(BaseModel):
-     type: Literal["image", "file"]
-     name: str
-     mime: str
-
-     @override
-     def __repr__(self):
-         attrs = vars(self).copy()
-         if "base64" in attrs:
-             attrs["base64"] = truncate_str(attrs["base64"])
-         return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
-
-
- class FileWithBase64(FileBase):
-     append_type: Literal["base64"] = "base64"
-     base64: str
-
-
- class FileWithUrl(FileBase):
-     append_type: Literal["url"] = "url"
-     url: str
-
-
- class FileWithId(FileBase):
-     append_type: Literal["file_id"] = "file_id"
-     file_id: str
-
-
- FileInput = Annotated[
-     FileWithBase64 | FileWithUrl | FileWithId,
-     Field(discriminator="append_type"),
- ]
-
-
- """
- --- TOOLS ---
- """
-
-
- class ToolBody(BaseModel):
-     name: str
-     description: str
-     properties: dict[str, Any]
-     required: list[str]
-     kwargs: dict[str, Any] = {}
-
-
- class ToolDefinition(BaseModel):
-     name: str  # acts as a key
-     body: ToolBody | Any
-
-
- class ToolCall(BaseModel):
-     id: str
-     call_id: str | None = None
-     name: str
-     args: dict[str, Any] | str
-
-
- """
- --- INPUT ---
- """
-
- RawResponse = Any
-
-
- class ToolInput(BaseModel):
-     tools: list[ToolDefinition] = []
-
-
- class ToolResult(BaseModel):
-     tool_call: ToolCall
-     result: Any
-
-
- class TextInput(BaseModel):
-     text: str
-
-
- RawInputItem = dict[
-     str, Any
- ]  # to pass in, for example, a mock conversation with {"role": "user", "content": "Hello"}
-
-
- InputItem = (
-     TextInput | FileInput | ToolResult | RawInputItem | RawResponse
- )  # input item can either be a prompt, a file (image or file), a tool call result, raw input, or a previous response
-
-
- """
- --- OUTPUT ---
- """
-
-
- class Citation(BaseModel):
-     type: str | None = None
-     title: str | None = None
-     url: str | None = None
-     start_index: int | None = None
-     end_index: int | None = None
-     file_id: str | None = None
-     filename: str | None = None
-     index: int | None = None
-     container_id: str | None = None
-
-
- class QueryResultExtras(BaseModel):
-     citations: list[Citation] = Field(default_factory=list)
-
-
- class QueryResultCost(BaseModel):
-     """
-     Cost information for a query
-     Includes total cost and a structured breakdown.
-     """
-
-     input: float
-     output: float
-     reasoning: float | None = None
-     cache_read: float | None = None
-     cache_write: float | None = None
-
-     @computed_field
-     @property
-     def total(self) -> float:
-         return sum(
-             filter(
-                 None,
-                 [
-                     self.input,
-                     self.output,
-                     self.reasoning,
-                     self.cache_read,
-                     self.cache_write,
-                 ],
-             )
-         )
-
-     @override
-     def __repr__(self):
-         use_cents = self.total < 1
-
-         def format_cost(value: float | None):
-             if value is None:
-                 return None
-             return f"{value * 100:.3f} cents" if use_cents else f"${value:.2f}"
-
-         return (
-             f"{format_cost(self.total)} "
-             + f"(uncached input: {format_cost(self.input)} | output: {format_cost(self.output)} | reasoning: {format_cost(self.reasoning)} | cache_read: {format_cost(self.cache_read)} | cache_write: {format_cost(self.cache_write)})"
-         )
-
-
- class QueryResultMetadata(BaseModel):
-     """
-     Metadata for a query: token usage and timing.
-
-     """
-
-     cost: QueryResultCost | None = None  # set post query
-     duration_seconds: float | None = None  # set post query
-     in_tokens: int = 0
-     out_tokens: int = 0
-     reasoning_tokens: int | None = None
-     cache_read_tokens: int | None = None
-     cache_write_tokens: int | None = None
-
-     @property
-     def default_duration_seconds(self) -> float:
-         return self.duration_seconds or 0
-
-     def __add__(self, other: "QueryResultMetadata") -> "QueryResultMetadata":
-         return QueryResultMetadata(
-             in_tokens=self.in_tokens + other.in_tokens,
-             out_tokens=self.out_tokens + other.out_tokens,
-             reasoning_tokens=sum_optional(
-                 self.reasoning_tokens, other.reasoning_tokens
-             ),
-             cache_read_tokens=sum_optional(
-                 self.cache_read_tokens, other.cache_read_tokens
-             ),
-             cache_write_tokens=sum_optional(
-                 self.cache_write_tokens, other.cache_write_tokens
-             ),
-             duration_seconds=self.default_duration_seconds
-             + other.default_duration_seconds,
-         )
-
-     @override
-     def __repr__(self):
-         attrs = vars(self).copy()
-         return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
-
-
- class QueryResult(BaseModel):
-     """
-     Result of a query
-     Contains the text, reasoning, metadata, tool calls, and history
-     """
-
-     output_text: str | None = None
-     reasoning: str | None = None
-     metadata: QueryResultMetadata = Field(default_factory=QueryResultMetadata)
-     tool_calls: list[ToolCall] = Field(default_factory=list)
-     history: list[InputItem] = Field(default_factory=list)
-     extras: QueryResultExtras = Field(default_factory=QueryResultExtras)
-     raw: dict[str, Any] = Field(default_factory=dict)
-
-     @property
-     def output_text_str(self) -> str:
-         return self.output_text or ""
-
-     @field_validator("reasoning", mode="before")
-     def default_reasoning(cls, v: str | None):
-         return None if not v else v  # make reasoning None if empty
-
-     @property
-     def search_results(self) -> Any | None:
-         """Expose provider-supplied search metadata without additional processing."""
-         raw_dict = cast(dict[str, Any], getattr(self, "raw", {}))
-         raw_candidate = raw_dict.get("search_results")
-         if raw_candidate is not None:
-             return raw_candidate
-
-         return _get_from_history(self.history, "search_results")
-
-     @override
-     def __repr__(self):
-         attrs = vars(self).copy()
-         ordered_attrs = {
-             "output_text": truncate_str(attrs.pop("output_text", None), 400),
-             "reasoning": truncate_str(attrs.pop("reasoning", None), 400),
-             "metadata": attrs.pop("metadata", None),
-         }
-         if self.tool_calls:
-             ordered_attrs["tool_calls"] = self.tool_calls
-         return f"{self.__class__.__name__}(\n{pformat(ordered_attrs, indent=2, sort_dicts=False)}\n)"
-
-
- def _get_from_history(history: Sequence[InputItem], key: str) -> Any | None:
-     for item in reversed(history):
-         value = getattr(item, key, None)
-         if value is not None:
-             return value
-
-         extra = getattr(item, "model_extra", None)
-         if isinstance(extra, Mapping):
-             value = cast(Mapping[str, Any], extra).get(key)
-             if value is not None:
-                 return value
-
-     return None
+ PydanticT = TypeVar("PydanticT", bound=BaseModel)
 
 
  class ProviderConfig(BaseModel):
@@ -304,6 +57,9 @@ class ProviderConfig(BaseModel):
          return self.__dict__
 
 
+ DEFAULT_MAX_TOKENS = 2048
+
+
  class LLMConfig(BaseModel):
      max_tokens: int = DEFAULT_MAX_TOKENS
      temperature: float | None = None
@@ -347,9 +103,6 @@ class LLM(ABC):
          config = config or LLMConfig()
          self._registry_key = config.registry_key
 
-         if config.provider_config:
-             self.provider_config = config.provider_config
-
          self.max_tokens: int = config.max_tokens
          self.temperature: float | None = config.temperature
          self.top_p: float | None = config.top_p
@@ -368,6 +121,12 @@ class LLM(ABC):
          self.delegate: "OpenAIModel | None" = None
          self.batch: LLMBatchMixin | None = None
 
+         if config.provider_config:
+             if isinstance(
+                 config.provider_config, type(getattr(self, "provider_config"))
+             ):
+                 self.provider_config = config.provider_config
+
          self.logger: logging.Logger = logging.getLogger(
              f"llm.{provider}.{model_name}<instance={self.instance_id}>"
          )
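Where 0.1.0 assigned config.provider_config unconditionally (previous hunk), 0.1.2 adopts it only when it is an instance of the type of the model's existing provider_config attribute, so one provider's config cannot silently overwrite another's. A sketch of the guard's effect with hypothetical config classes (none of these names appear in the diff):

```python
from pydantic import BaseModel


class AnthropicConfig(BaseModel):  # hypothetical provider config
    thinking: bool = False


class OpenAIConfig(BaseModel):  # hypothetical provider config
    parallel_tool_calls: bool = True


class FakeAnthropicModel:  # stand-in for an LLM subclass
    provider_config = AnthropicConfig()

    def apply(self, supplied: BaseModel) -> None:
        # Mirrors the 0.1.2 __init__ guard shown above.
        if supplied:
            if isinstance(supplied, type(getattr(self, "provider_config"))):
                self.provider_config = supplied


m = FakeAnthropicModel()
m.apply(OpenAIConfig())  # ignored: not an AnthropicConfig
m.apply(AnthropicConfig(thinking=True))  # adopted
```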
@@ -521,33 +280,6 @@ class LLM(ABC):
 
          return output
 
-     async def query_json(
-         self,
-         input: Sequence[InputItem],
-         pydantic_model: type[PydanticT],
-         **kwargs: object,
-     ) -> PydanticT:
-         """Query the model with JSON response format using Pydantic model.
-
-         This is a convenience method that is not implemented for all providers.
-         Only OpenAI and Google providers currently support this method.
-
-         Args:
-             input: Input items (text, files, etc.)
-             pydantic_model: Pydantic model class defining the expected response structure
-             **kwargs: Additional arguments passed to the query method
-
-         Returns:
-             Instance of the pydantic_model with the model's response
-
-         Raises:
-             NotImplementedError: If the provider does not support structured JSON output
-         """
-         raise NotImplementedError(
-             f"query_json is not implemented for {self.__class__.__name__}. "
-             f"Only OpenAI and Google providers currently support this method."
-         )
-
      async def _calculate_cost(
          self,
          metadata: QueryResultMetadata,
@@ -678,137 +410,29 @@ class LLM(ABC):
          """Upload a file to the model provider"""
          ...
 
-
- class BatchResult(BaseModel):
-     custom_id: str
-     output: QueryResult
-     error_message: str | None = None
-
-
- class LLMBatchMixin(ABC):
-     @abstractmethod
-     async def create_batch_query_request(
+     async def query_json(
          self,
-         custom_id: str,
          input: Sequence[InputItem],
+         pydantic_model: type[PydanticT],
          **kwargs: object,
-     ) -> dict[str, Any]:
-         """Return a single query request
-
-         The batch api sends out a batch of query requests to various endpoints.
+     ) -> PydanticT:
+         """Query the model with JSON response format using Pydantic model.
 
-         For example OpenAI can send requests to /v1/responses or /v1/chat/completions endpoints.
+         This is a convenience method that is not implemented for all providers.
+         Only OpenAI and Google providers currently support this method.
 
-         This method creates a query request for such endpoints.
-         """
-         ...
+         Args:
+             input: Input items (text, files, etc.)
+             pydantic_model: Pydantic model class defining the expected response structure
+             **kwargs: Additional arguments passed to the query method
 
-     @abstractmethod
-     async def batch_query(
-         self,
-         batch_name: str,
-         requests: list[dict[str, Any]],
-     ) -> str:
-         """
-         Batch query the model
          Returns:
-             str: batch_id
-         Raises:
-             Exception: If failed to batch query
-         """
-         ...
-
-     @abstractmethod
-     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
-         """
-         Returns results for batch
-         Raises:
-             Exception: If failed to get results
-         """
-         ...
-
-     @abstractmethod
-     async def get_batch_progress(self, batch_id: str) -> int:
-         """
-         Returns number of completed requests for batch
-         Raises:
-             Exception: If failed to get progress
-         """
-         ...
-
-     @abstractmethod
-     async def cancel_batch_request(self, batch_id: str) -> None:
-         """
-         Cancels batch
-         Raises:
-             Exception: If failed to cancel
-         """
-         ...
+             Instance of the pydantic_model with the model's response
 
-     @abstractmethod
-     async def get_batch_status(
-         self,
-         batch_id: str,
-     ) -> str:
-         """
-         Returns batch status
          Raises:
-             Exception: If failed to get status
-         """
-         ...
-
-     @classmethod
-     @abstractmethod
-     def is_batch_status_completed(
-         cls,
-         batch_status: str,
-     ) -> bool:
-         """
-         Returns if batch status is completed
-
-         A completed state is any state that is final and not in-progress
-         Example: failed | cancelled | expired | completed
-
-         An incomplete state is any state that is not completed
-         Example: in_progress | pending | running
+             NotImplementedError: If the provider does not support structured JSON output
          """
-         ...
-
-     @classmethod
-     @abstractmethod
-     def is_batch_status_failed(
-         cls,
-         batch_status: str,
-     ) -> bool:
-         """Returns if batch status is failed"""
-         ...
-
-     @classmethod
-     @abstractmethod
-     def is_batch_status_cancelled(
-         cls,
-         batch_status: str,
-     ) -> bool:
-         """Returns if batch status is cancelled"""
-         ...
-
-
- def get_pretty_input_types(input: Sequence["InputItem"]) -> str:
-     # for logging
-     def process_item(item: "InputItem"):
-         match item:
-             case TextInput():
-                 return truncate_str(repr(item))
-             case FileBase():  # FileInput
-                 return repr(item)
-             case ToolResult():
-                 return repr(item)
-             case dict():
-                 item = cast(RawInputItem, item)
-                 return repr(item)
-             case _:
-                 # RawResponse
-                 return repr(item)
-
-     processed_items = [f" {process_item(item)}" for item in input]
-     return "\n" + "\n".join(processed_items) if processed_items else ""
+         raise NotImplementedError(
+             f"query_json is not implemented for {self.__class__.__name__}. "
+             f"Only OpenAI and Google providers currently support this method."
+         )
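query_json itself is unchanged apart from moving to the end of the class. Per its docstring, only the OpenAI and Google providers implement it; a hedged call-site sketch (the registry key and response schema are illustrative):

```python
import asyncio

from pydantic import BaseModel

import model_library
from model_library.base import TextInput


class CityInfo(BaseModel):  # illustrative response schema
    city: str
    population: int


async def main() -> None:
    # Must resolve to a provider that implements query_json (OpenAI or Google);
    # the key string is hypothetical.
    llm = model_library.model("openai/gpt-4o-mini")
    result = await llm.query_json(
        [TextInput(text="Largest city in France, as JSON.")],
        pydantic_model=CityInfo,
    )
    print(result.city, result.population)


asyncio.run(main())
```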
model_library/base/batch.py ADDED
@@ -0,0 +1,121 @@
+ from abc import ABC, abstractmethod
+ from typing import Any, Sequence
+
+ from pydantic import BaseModel
+
+ from model_library.base.input import InputItem
+ from model_library.base.output import QueryResult
+
+
+ class BatchResult(BaseModel):
+     custom_id: str
+     output: QueryResult
+     error_message: str | None = None
+
+
+ class LLMBatchMixin(ABC):
+     @abstractmethod
+     async def create_batch_query_request(
+         self,
+         custom_id: str,
+         input: Sequence[InputItem],
+         **kwargs: object,
+     ) -> dict[str, Any]:
+         """Return a single query request
+
+         The batch api sends out a batch of query requests to various endpoints.
+
+         For example OpenAI can send requests to /v1/responses or /v1/chat/completions endpoints.
+
+         This method creates a query request for such endpoints.
+         """
+         ...
+
+     @abstractmethod
+     async def batch_query(
+         self,
+         batch_name: str,
+         requests: list[dict[str, Any]],
+     ) -> str:
+         """
+         Batch query the model
+         Returns:
+             str: batch_id
+         Raises:
+             Exception: If failed to batch query
+         """
+         ...
+
+     @abstractmethod
+     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
+         """
+         Returns results for batch
+         Raises:
+             Exception: If failed to get results
+         """
+         ...
+
+     @abstractmethod
+     async def get_batch_progress(self, batch_id: str) -> int:
+         """
+         Returns number of completed requests for batch
+         Raises:
+             Exception: If failed to get progress
+         """
+         ...
+
+     @abstractmethod
+     async def cancel_batch_request(self, batch_id: str) -> None:
+         """
+         Cancels batch
+         Raises:
+             Exception: If failed to cancel
+         """
+         ...
+
+     @abstractmethod
+     async def get_batch_status(
+         self,
+         batch_id: str,
+     ) -> str:
+         """
+         Returns batch status
+         Raises:
+             Exception: If failed to get status
+         """
+         ...
+
+     @classmethod
+     @abstractmethod
+     def is_batch_status_completed(
+         cls,
+         batch_status: str,
+     ) -> bool:
+         """
+         Returns if batch status is completed
+
+         A completed state is any state that is final and not in-progress
+         Example: failed | cancelled | expired | completed
+
+         An incomplete state is any state that is not completed
+         Example: in_progress | pending | running
+         """
+         ...
+
+     @classmethod
+     @abstractmethod
+     def is_batch_status_failed(
+         cls,
+         batch_status: str,
+     ) -> bool:
+         """Returns if batch status is failed"""
+         ...
+
+     @classmethod
+     @abstractmethod
+     def is_batch_status_cancelled(
+         cls,
+         batch_status: str,
+     ) -> bool:
+         """Returns if batch status is cancelled"""
+         ...
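The mixin fixes only the contract; each provider supplies the implementation behind its LLM.batch attribute (None when batch is unsupported, per base.py above). A hedged sketch of driving the whole lifecycle (the registry key and polling interval are assumptions):

```python
import asyncio

import model_library


async def run_batch() -> None:
    llm = model_library.model("openai/gpt-4o-mini")  # hypothetical registry key
    assert llm.batch is not None, "provider has no batch support"

    # Build one request per prompt; raw dict inputs are valid InputItems.
    requests = [
        await llm.batch.create_batch_query_request(
            custom_id=f"req-{i}",
            input=[{"role": "user", "content": f"Question {i}"}],
        )
        for i in range(3)
    ]
    batch_id = await llm.batch.batch_query("demo-batch", requests)

    # Poll until the provider reports a final state
    # (completed / failed / cancelled / expired).
    while not llm.batch.is_batch_status_completed(
        await llm.batch.get_batch_status(batch_id)
    ):
        await asyncio.sleep(30)

    for result in await llm.batch.get_batch_results(batch_id):
        print(result.custom_id, result.error_message or result.output.output_text)


asyncio.run(run_batch())
```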