model-library 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_library/base/base.py +11 -6
- model_library/base/output.py +54 -0
- model_library/base/utils.py +3 -2
- model_library/config/ai21labs_models.yaml +1 -0
- model_library/config/all_models.json +300 -37
- model_library/config/anthropic_models.yaml +26 -3
- model_library/config/google_models.yaml +49 -0
- model_library/config/openai_models.yaml +0 -9
- model_library/config/together_models.yaml +1 -0
- model_library/config/xai_models.yaml +63 -3
- model_library/exceptions.py +6 -2
- model_library/file_utils.py +1 -1
- model_library/providers/anthropic.py +2 -6
- model_library/providers/google/google.py +35 -29
- model_library/providers/openai.py +8 -2
- model_library/providers/together.py +18 -211
- model_library/register_models.py +0 -2
- {model_library-0.1.2.dist-info → model_library-0.1.3.dist-info}/METADATA +2 -3
- {model_library-0.1.2.dist-info → model_library-0.1.3.dist-info}/RECORD +22 -22
- {model_library-0.1.2.dist-info → model_library-0.1.3.dist-info}/WHEEL +0 -0
- {model_library-0.1.2.dist-info → model_library-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.2.dist-info → model_library-0.1.3.dist-info}/top_level.txt +0 -0
model_library/base/base.py
CHANGED
|
@@ -218,6 +218,10 @@ class LLM(ABC):
|
|
|
218
218
|
Join input with history
|
|
219
219
|
Log, Time, and Retry
|
|
220
220
|
"""
|
|
221
|
+
|
|
222
|
+
# verbose on debug
|
|
223
|
+
verbose = self.logger.isEnabledFor(logging.DEBUG)
|
|
224
|
+
|
|
221
225
|
# format str input
|
|
222
226
|
if isinstance(input, str):
|
|
223
227
|
input = [TextInput(text=input)]
|
|
@@ -226,11 +230,11 @@ class LLM(ABC):
|
|
|
226
230
|
input = [*files, *images, *input]
|
|
227
231
|
|
|
228
232
|
# format input info
|
|
229
|
-
item_info =
|
|
233
|
+
item_info = (
|
|
234
|
+
f"--- input ({len(input)}): {get_pretty_input_types(input, verbose)}\n"
|
|
235
|
+
)
|
|
230
236
|
if history:
|
|
231
|
-
item_info += (
|
|
232
|
-
f"--- history({len(history)}): {get_pretty_input_types(history)}\n"
|
|
233
|
-
)
|
|
237
|
+
item_info += f"--- history({len(history)}): {get_pretty_input_types(history, verbose)}\n"
|
|
234
238
|
|
|
235
239
|
# format tool info
|
|
236
240
|
tool_results = [t for t in input if isinstance(t, ToolResult)]
|
|
@@ -251,7 +255,7 @@ class LLM(ABC):
|
|
|
251
255
|
|
|
252
256
|
# unique logger for the query
|
|
253
257
|
query_id = uuid.uuid4().hex[:14]
|
|
254
|
-
query_logger =
|
|
258
|
+
query_logger = self.logger.getChild(f"query={query_id}")
|
|
255
259
|
|
|
256
260
|
query_logger.info(
|
|
257
261
|
"Query started:\n" + item_info + tool_info + f"--- kwargs: {short_kwargs}\n"
|
|
@@ -277,6 +281,7 @@ class LLM(ABC):
|
|
|
277
281
|
output.metadata.cost = await self._calculate_cost(output.metadata)
|
|
278
282
|
|
|
279
283
|
query_logger.info(f"Query completed: {repr(output)}")
|
|
284
|
+
query_logger.debug(output.model_dump(exclude={"history", "raw"}))
|
|
280
285
|
|
|
281
286
|
return output
|
|
282
287
|
|
|
@@ -316,7 +321,7 @@ class LLM(ABC):
|
|
|
316
321
|
)
|
|
317
322
|
|
|
318
323
|
# costs for long context
|
|
319
|
-
total_in = metadata.
|
|
324
|
+
total_in = metadata.total_input_tokens
|
|
320
325
|
if costs.context and total_in > costs.context.threshold:
|
|
321
326
|
input_cost, output_cost = costs.context.get_costs(
|
|
322
327
|
input_cost,
|
model_library/base/output.py
CHANGED
|
@@ -59,6 +59,33 @@ class QueryResultCost(BaseModel):
|
|
|
59
59
|
)
|
|
60
60
|
)
|
|
61
61
|
|
|
62
|
+
@computed_field
|
|
63
|
+
@property
|
|
64
|
+
def total_input(self) -> float:
|
|
65
|
+
return sum(
|
|
66
|
+
filter(
|
|
67
|
+
None,
|
|
68
|
+
[
|
|
69
|
+
self.input,
|
|
70
|
+
self.cache_read,
|
|
71
|
+
self.cache_write,
|
|
72
|
+
],
|
|
73
|
+
)
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
@computed_field
|
|
77
|
+
@property
|
|
78
|
+
def total_output(self) -> float:
|
|
79
|
+
return sum(
|
|
80
|
+
filter(
|
|
81
|
+
None,
|
|
82
|
+
[
|
|
83
|
+
self.output,
|
|
84
|
+
self.reasoning,
|
|
85
|
+
],
|
|
86
|
+
)
|
|
87
|
+
)
|
|
88
|
+
|
|
62
89
|
@override
|
|
63
90
|
def __repr__(self):
|
|
64
91
|
use_cents = self.total < 1
|
|
@@ -92,6 +119,33 @@ class QueryResultMetadata(BaseModel):
|
|
|
92
119
|
def default_duration_seconds(self) -> float:
|
|
93
120
|
return self.duration_seconds or 0
|
|
94
121
|
|
|
122
|
+
@computed_field
|
|
123
|
+
@property
|
|
124
|
+
def total_input_tokens(self) -> int:
|
|
125
|
+
return sum(
|
|
126
|
+
filter(
|
|
127
|
+
None,
|
|
128
|
+
[
|
|
129
|
+
self.in_tokens,
|
|
130
|
+
self.cache_read_tokens,
|
|
131
|
+
self.cache_write_tokens,
|
|
132
|
+
],
|
|
133
|
+
)
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
@computed_field
|
|
137
|
+
@property
|
|
138
|
+
def total_output_tokens(self) -> int:
|
|
139
|
+
return sum(
|
|
140
|
+
filter(
|
|
141
|
+
None,
|
|
142
|
+
[
|
|
143
|
+
self.out_tokens,
|
|
144
|
+
self.reasoning_tokens,
|
|
145
|
+
],
|
|
146
|
+
)
|
|
147
|
+
)
|
|
148
|
+
|
|
95
149
|
def __add__(self, other: "QueryResultMetadata") -> "QueryResultMetadata":
|
|
96
150
|
return QueryResultMetadata(
|
|
97
151
|
in_tokens=self.in_tokens + other.in_tokens,
|
model_library/base/utils.py
CHANGED
|
@@ -21,12 +21,13 @@ def sum_optional(a: int | None, b: int | None) -> int | None:
|
|
|
21
21
|
return (a or 0) + (b or 0)
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def get_pretty_input_types(input: Sequence["InputItem"]) -> str:
|
|
24
|
+
def get_pretty_input_types(input: Sequence["InputItem"], verbose: bool = False) -> str:
|
|
25
25
|
# for logging
|
|
26
26
|
def process_item(item: "InputItem"):
|
|
27
27
|
match item:
|
|
28
28
|
case TextInput():
|
|
29
|
-
|
|
29
|
+
item_str = repr(item)
|
|
30
|
+
return item_str if verbose else truncate_str(item_str)
|
|
30
31
|
case FileBase(): # FileInput
|
|
31
32
|
return repr(item)
|
|
32
33
|
case ToolResult():
|