model-library 0.1.6-py3-none-any.whl → 0.1.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. model_library/base/base.py +237 -62
  2. model_library/base/delegate_only.py +86 -9
  3. model_library/base/input.py +10 -7
  4. model_library/base/output.py +48 -0
  5. model_library/base/utils.py +56 -7
  6. model_library/config/alibaba_models.yaml +44 -57
  7. model_library/config/all_models.json +253 -126
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/openai_models.yaml +15 -23
  10. model_library/config/zai_models.yaml +24 -3
  11. model_library/exceptions.py +14 -77
  12. model_library/logging.py +6 -2
  13. model_library/providers/ai21labs.py +30 -14
  14. model_library/providers/alibaba.py +17 -8
  15. model_library/providers/amazon.py +119 -64
  16. model_library/providers/anthropic.py +184 -104
  17. model_library/providers/azure.py +22 -10
  18. model_library/providers/cohere.py +7 -7
  19. model_library/providers/deepseek.py +8 -8
  20. model_library/providers/fireworks.py +7 -8
  21. model_library/providers/google/batch.py +17 -13
  22. model_library/providers/google/google.py +130 -73
  23. model_library/providers/inception.py +7 -7
  24. model_library/providers/kimi.py +18 -8
  25. model_library/providers/minimax.py +30 -13
  26. model_library/providers/mistral.py +61 -35
  27. model_library/providers/openai.py +219 -93
  28. model_library/providers/openrouter.py +34 -0
  29. model_library/providers/perplexity.py +7 -7
  30. model_library/providers/together.py +7 -8
  31. model_library/providers/vals.py +16 -9
  32. model_library/providers/xai.py +157 -144
  33. model_library/providers/zai.py +38 -8
  34. model_library/register_models.py +4 -2
  35. model_library/registry_utils.py +39 -15
  36. model_library/retriers/__init__.py +0 -0
  37. model_library/retriers/backoff.py +73 -0
  38. model_library/retriers/base.py +225 -0
  39. model_library/retriers/token.py +427 -0
  40. model_library/retriers/utils.py +11 -0
  41. model_library/settings.py +1 -1
  42. model_library/utils.py +13 -35
  43. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
  44. model_library-0.1.8.dist-info/RECORD +70 -0
  45. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  46. model_library-0.1.6.dist-info/RECORD +0 -64
  47. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  48. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/base/base.py
@@ -1,9 +1,12 @@
+import hashlib
 import io
 import logging
+import threading
 import time
 import uuid
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable
+from math import ceil
 from pprint import pformat
 from typing import (
     Any,
@@ -13,8 +16,10 @@ from typing import (
     TypeVar,
 )
 
-from pydantic import model_serializer
+import tiktoken
+from pydantic import SecretStr, model_serializer
 from pydantic.main import BaseModel
+from tiktoken.core import Encoding
 from typing_extensions import override
 
 from model_library.base.batch import (
@@ -32,14 +37,15 @@ from model_library.base.output import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RateLimit,
 )
 from model_library.base.utils import (
     get_pretty_input_types,
+    serialize_for_tokenizing,
 )
-from model_library.exceptions import (
-    ImmediateRetryException,
-    retry_llm_call,
-)
+from model_library.retriers.backoff import ExponentialBackoffRetrier
+from model_library.retriers.base import BaseRetrier, R, RetrierType, retry_decorator
+from model_library.retriers.token import TokenRetrier
 from model_library.utils import truncate_str
 
 PydanticT = TypeVar("PydanticT", bound=BaseModel)
@@ -53,11 +59,18 @@ class ProviderConfig(BaseModel):
         return self.__dict__
 
 
-DEFAULT_MAX_TOKENS = 2048
+class TokenRetryParams(BaseModel):
+    input_modifier: float
+    output_modifier: float
+
+    use_dynamic_estimate: bool = True
+
+    limit: int
+    limit_refresh_seconds: Literal[60] = 60
 
 
 class LLMConfig(BaseModel):
-    max_tokens: int = DEFAULT_MAX_TOKENS
+    max_tokens: int | None = None
     temperature: float | None = None
     top_p: float | None = None
     top_k: int | None = None
@@ -72,11 +85,18 @@ class LLMConfig(BaseModel):
     native: bool = True
     provider_config: ProviderConfig | None = None
     registry_key: str | None = None
+    custom_api_key: SecretStr | None = None
 
 
-RetrierType = Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]
+class DelegateConfig(BaseModel):
+    base_url: str
+    api_key: SecretStr
 
-R = TypeVar("R")  # return type
+
+# shared across all subclasses and instances
+# hash(provider + api_key) -> client
+client_registry_lock = threading.Lock()
+client_registry: dict[tuple[str, str], Any] = {}
 
 
 class LLM(ABC):
@@ -85,6 +105,34 @@ class LLM(ABC):
     LLM call errors should be raised as exceptions
     """
 
+    @abstractmethod
+    def get_client(self, api_key: str | None = None) -> Any:
+        """
+        Returns the cached instance of the appropriate SDK client.
+        Subclasses should implement this method and:
+        - if api_key is provided, initialize their client and call assign_client(client).
+        - else return super().get_client()
+        """
+        global client_registry
+        return client_registry[self._client_registry_key]
+
+    def assign_client(self, client: object) -> None:
+        """Thread-safe assignment to the client registry"""
+        global client_registry
+
+        if self._client_registry_key not in client_registry:
+            with client_registry_lock:
+                if self._client_registry_key not in client_registry:
+                    client_registry[self._client_registry_key] = client
+
+    def has_client(self) -> bool:
+        return self._client_registry_key in client_registry
+
+    @abstractmethod
+    def _get_default_api_key(self) -> str:
+        """Return the api key from model_library.settings"""
+        ...
+
     def __init__(
         self,
         model_name: str,
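
Note: the get_client contract above caches one SDK client per (provider, hashed API key) pair. A minimal sketch of how a provider subclass might satisfy it; ExampleSDKClient and the settings attribute are hypothetical stand-ins, not part of the package:

    from typing import Any

    from typing_extensions import override

    from model_library import settings  # assumed import path for provider API keys
    from model_library.base.base import LLM


    class ExampleSDKClient:
        """Placeholder for a real provider SDK client."""

        def __init__(self, api_key: str) -> None:
            self.api_key = api_key


    class ExampleProviderModel(LLM):
        @override
        def _get_default_api_key(self) -> str:
            return settings.EXAMPLE_API_KEY  # hypothetical settings attribute

        @override
        def get_client(self, api_key: str | None = None) -> Any:
            if api_key is not None and not self.has_client():
                # first initialization publishes the client into the shared registry
                self.assign_client(ExampleSDKClient(api_key=api_key))
            # every caller reads back the cached instance
            return super().get_client()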
@@ -100,7 +148,7 @@ class LLM(ABC):
         config = config or LLMConfig()
         self._registry_key = config.registry_key
 
-        self.max_tokens: int = config.max_tokens
+        self.max_tokens: int | None = config.max_tokens
         self.temperature: float | None = config.temperature
         self.top_p: float | None = config.top_p
         self.top_k: int | None = config.top_k
@@ -128,21 +176,33 @@ class LLM(ABC):
         self.logger: logging.Logger = logging.getLogger(
             f"llm.{provider}.{model_name}<instance={self.instance_id}>"
         )
-        self.custom_retrier: Callable[..., RetrierType] | None = retry_llm_call
+        self.custom_retrier: RetrierType | None = None
+
+        self.token_retry_params = None
+        # set _client_registry_key after initializing delegate
+        if not self.native:
+            return
+
+        if config.custom_api_key:
+            raw_key = config.custom_api_key.get_secret_value()
+        else:
+            raw_key = self._get_default_api_key()
+
+        key_hash = hashlib.sha256(raw_key.encode()).hexdigest()
+        self._client_registry_key = (self.provider, key_hash)
+        self._client_registry_key_model_specific = (
+            f"{self.provider}.{self.model_name}",
+            key_hash,
+        )
+        self.get_client(api_key=raw_key)
 
     @override
     def __repr__(self):
         attrs = vars(self).copy()
         attrs.pop("logger", None)
         attrs.pop("custom_retrier", None)
-        attrs.pop("_key", None)
         return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
 
-    @abstractmethod
-    def get_client(self) -> object:
-        """Return the instance of the appropriate SDK client."""
-        ...
-
     @staticmethod
     async def timer_wrapper(func: Callable[[], Awaitable[R]]) -> tuple[R, float]:
         """
@@ -152,43 +212,6 @@ class LLM(ABC):
         result = await func()
         return result, round(time.perf_counter() - start, 4)
 
-    @staticmethod
-    async def immediate_retry_wrapper(
-        func: Callable[[], Awaitable[R]],
-        logger: logging.Logger,
-    ) -> R:
-        """
-        Retry the query immediately
-        """
-        MAX_IMMEDIATE_RETRIES = 10
-        retries = 0
-        while True:
-            try:
-                return await func()
-            except ImmediateRetryException as e:
-                if retries >= MAX_IMMEDIATE_RETRIES:
-                    logger.error(f"Query reached max immediate retries {retries}: {e}")
-                    raise Exception(
-                        f"Query reached max immediate retries {retries}: {e}"
-                    ) from e
-                retries += 1
-
-                logger.warning(
-                    f"Query retried immediately {retries}/{MAX_IMMEDIATE_RETRIES}: {e}"
-                )
-
-    @staticmethod
-    async def backoff_retry_wrapper(
-        func: Callable[..., Awaitable[R]],
-        backoff_retrier: RetrierType | None,
-    ) -> R:
-        """
-        Retry the query with backoff
-        """
-        if not backoff_retrier:
-            return await func()
-        return await backoff_retrier(func)()
-
     async def delegate_query(
         self,
         input: Sequence[InputItem],
@@ -273,15 +296,38 @@ class LLM(ABC):
             return await LLM.timer_wrapper(query_func)
 
         async def immediate_retry() -> tuple[QueryResult, float]:
-            return await LLM.immediate_retry_wrapper(timed_query, query_logger)
-
-        async def backoff_retry() -> tuple[QueryResult, float]:
-            backoff_retrier = (
-                self.custom_retrier(query_logger) if self.custom_retrier else None
-            )
-            return await LLM.backoff_retry_wrapper(immediate_retry, backoff_retrier)
+            return await BaseRetrier.immediate_retry_wrapper(timed_query, query_logger)
+
+        async def default_retry() -> tuple[QueryResult, float]:
+            if self.token_retry_params:
+                (
+                    estimate_input_tokens,
+                    estimate_output_tokens,
+                ) = await self.estimate_query_tokens(
+                    input,
+                    tools=tools,
+                    **kwargs,
+                )
+                retrier = TokenRetrier(
+                    logger=query_logger,
+                    client_registry_key=self._client_registry_key_model_specific,
+                    estimate_input_tokens=estimate_input_tokens,
+                    estimate_output_tokens=estimate_output_tokens,
+                    dynamic_estimate_instance_id=self.instance_id
+                    if self.token_retry_params.use_dynamic_estimate
+                    else None,
+                )
+            else:
+                retrier = ExponentialBackoffRetrier(logger=query_logger)
+            return await retry_decorator(retrier)(immediate_retry)()
+
+        run_with_retry = (
+            default_retry
+            if not self.custom_retrier
+            else self.custom_retrier(immediate_retry)
+        )
 
-        output, duration = await backoff_retry()
+        output, duration = await run_with_retry()
         output.metadata.duration_seconds = duration
         output.metadata.cost = await self._calculate_cost(output.metadata)
 
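
Note: a custom_retrier now replaces the default token/backoff stack entirely and is applied directly to the immediate-retry callable. A hedged sketch of a retrier matching that call shape (a decorator over a zero-argument async callable, per the removed RetrierType alias); model stands in for any LLM instance:

    import asyncio
    from collections.abc import Awaitable, Callable
    from typing import Any


    def retry_once(func: Callable[[], Awaitable[Any]]) -> Callable[[], Awaitable[Any]]:
        # hypothetical custom retrier: one extra attempt after a fixed delay
        async def wrapper() -> Any:
            try:
                return await func()
            except Exception:
                await asyncio.sleep(1)
                return await func()

        return wrapper


    model.custom_retrier = retry_once  # consumed as self.custom_retrier(immediate_retry)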
@@ -290,6 +336,16 @@ class LLM(ABC):
 
         return output
 
+    async def init_token_retry(self, token_retry_params: TokenRetryParams) -> None:
+        self.token_retry_params = token_retry_params
+        await TokenRetrier.init_remaining_tokens(
+            client_registry_key=self._client_registry_key_model_specific,
+            limit=self.token_retry_params.limit,
+            limit_refresh_seconds=self.token_retry_params.limit_refresh_seconds,
+            get_rate_limit_func=self.get_rate_limit,
+            logger=self.logger,
+        )
+
     async def _calculate_cost(
         self,
         metadata: QueryResultMetadata,
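
Note: a hypothetical call site for the new token-budget retry; the modifier and limit values below are illustrative, not defaults:

    params = TokenRetryParams(
        input_modifier=1.2,   # pad the tokenized prompt estimate by 20%
        output_modifier=0.5,  # expect output up to half the padded input size
        limit=2_000_000,      # provider tokens-per-minute budget
    )
    await model.init_token_retry(params)  # seeds the shared TokenRetrier budget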
@@ -379,6 +435,20 @@ class LLM(ABC):
         """
         ...
 
+    @abstractmethod
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """
+        Builds the body of the request to the model provider
+        Calls parse_input
+        """
+        ...
+
     @abstractmethod
     async def parse_input(
         self,
@@ -421,6 +491,111 @@ class LLM(ABC):
         """Upload a file to the model provider"""
         ...
 
+    async def get_rate_limit(self) -> RateLimit | None:
+        """Get the rate limit for the model provider"""
+        return None
+
+    async def estimate_query_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> tuple[int, int]:
+        """Pessimistically estimate the number of tokens required for a query"""
+        assert self.token_retry_params
+
+        # TODO: when passing in images and files, we really need to take that into account when calculating the output tokens!!
+
+        input_tokens = (
+            await self.count_tokens(input, history=[], tools=tools, **kwargs)
+            * self.token_retry_params.input_modifier
+        )
+
+        output_tokens = input_tokens * self.token_retry_params.output_modifier
+        return ceil(input_tokens), ceil(output_tokens)
+
+    async def get_encoding(self) -> Encoding:
+        """Get the appropriate tokenizer"""
+
+        model = self.model_name.lower()
+
+        if any(x in model for x in ["gpt-4o", "o1", "o3", "gpt-4.1", "gpt-5"]):
+            return tiktoken.get_encoding("o200k_base")
+        elif "gpt-4" in model or "gpt-3.5" in model:
+            try:
+                return tiktoken.encoding_for_model(self.model_name)
+            except KeyError:
+                return tiktoken.get_encoding("cl100k_base")
+        elif "claude" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        elif "gemini" in model:
+            return tiktoken.get_encoding("o200k_base")
+        elif "llama" in model or "mistral" in model:
+            return tiktoken.get_encoding("cl100k_base")
+        else:
+            return tiktoken.get_encoding("cl100k_base")
+
+    async def stringify_input(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> str:
+        input = [*history, *input]
+
+        system_prompt = kwargs.pop(
+            "system_prompt", ""
+        )  # TODO: refactor along with system prompt arg change
+
+        # special case if using a delegate
+        # don't inherit method override by default
+        if self.delegate:
+            parsed_input = await self.delegate.parse_input(input, **kwargs)
+            parsed_tools = await self.delegate.parse_tools(tools)
+        else:
+            parsed_input = await self.parse_input(input, **kwargs)
+            parsed_tools = await self.parse_tools(tools)
+
+        serialized_input = serialize_for_tokenizing(parsed_input)
+        serialized_tools = serialize_for_tokenizing(parsed_tools)
+
+        combined = f"{system_prompt}\n{serialized_input}\n{serialized_tools}"
+
+        return combined
+
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens for a query.
+        Combines parsed input and tools, then tokenizes the result.
+        """
+
+        if not input and not history:
+            return 0
+
+        if self.delegate:
+            encoding = await self.delegate.get_encoding()
+        else:
+            encoding = await self.get_encoding()
+        self.logger.debug(f"Token Count Encoding: {encoding}")
+
+        string_input = await self.stringify_input(
+            input, history=history, tools=tools, **kwargs
+        )
+
+        count = len(encoding.encode(string_input, disallowed_special=()))
+        self.logger.debug(f"Combined Token Count Input: {count}")
+        return count
+
     async def query_json(
         self,
         input: Sequence[InputItem],
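
Note: estimate_query_tokens reduces to simple arithmetic over count_tokens. With the illustrative modifiers from the earlier note (1.2 and 0.5) and a counted prompt of 1,000 tokens:

    from math import ceil

    counted = 1_000                         # count_tokens(...) over serialized input + tools
    input_est = ceil(counted * 1.2)         # 1200, with input_modifier=1.2
    output_est = ceil(counted * 1.2 * 0.5)  # 600: output_modifier scales the padded input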
model_library/base/delegate_only.py
@@ -13,6 +13,7 @@ from model_library.base import (
     QueryResult,
     ToolDefinition,
 )
+from model_library.base.base import DelegateConfig
 
 
 class DelegateOnlyException(Exception):
@@ -21,17 +22,51 @@ class DelegateOnlyException(Exception):
     delegate-only model.
     """
 
-    DEFAULT_MESSAGE: str = "This model supports only delegate-only functionality. Only the query() method should be used."
+    DEFAULT_MESSAGE: str = "This model is running in delegate-only mode, certain functionality is not supported."
 
     def __init__(self, message: str | None = None):
         super().__init__(message or DelegateOnlyException.DEFAULT_MESSAGE)
 
 
 class DelegateOnly(LLM):
-    @override
-    def get_client(self) -> None:
+    def _get_default_api_key(self) -> str:
         raise DelegateOnlyException()
 
+    @override
+    def get_client(self, api_key: str | None = None) -> None:
+        assert self.delegate
+        return self.delegate.get_client()
+
+    def init_delegate(
+        self,
+        config: LLMConfig | None,
+        delegate_config: DelegateConfig,
+        delegate_provider: Literal["openai", "anthropic"],
+        use_completions: bool = True,
+    ) -> None:
+        from model_library.providers.anthropic import AnthropicModel
+        from model_library.providers.openai import OpenAIModel
+
+        match delegate_provider:
+            case "openai":
+                self.delegate = OpenAIModel(
+                    model_name=self.model_name,
+                    provider=self.provider,
+                    config=config,
+                    use_completions=use_completions,
+                    delegate_config=delegate_config,
+                )
+            case "anthropic":
+                self.delegate = AnthropicModel(
+                    model_name=self.model_name,
+                    provider=self.provider,
+                    config=config,
+                    delegate_config=delegate_config,
+                )
+        self._client_registry_key_model_specific = (
+            self.delegate._client_registry_key_model_specific
+        )
+
     def __init__(
         self,
         model_name: str,
@@ -42,6 +77,11 @@ class DelegateOnly(LLM):
         config = config or LLMConfig()
         config.native = False
         super().__init__(model_name, provider, config=config)
+        config.native = True
+
+    def _get_extra_body(self) -> dict[str, Any]:
+        """Build extra body parameters for delegate-specific features."""
+        return {}
 
 
     @override
@@ -53,39 +93,57 @@ class DelegateOnly(LLM):
         **kwargs: object,
     ) -> QueryResult:
         assert self.delegate
-
         return await self.delegate_query(
-            input, tools=tools, query_logger=query_logger, **kwargs
+            input,
+            tools=tools,
+            query_logger=query_logger,
+            extra_body=self._get_extra_body(),
+            **kwargs,
         )
 
+    @override
+    async def build_body(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> dict[str, Any]:
+        assert self.delegate
+        return await self.delegate.build_body(input, tools=tools, **kwargs)
+
     @override
     async def parse_input(
         self,
         input: Sequence[InputItem],
         **kwargs: Any,
     ) -> Any:
-        raise DelegateOnlyException()
+        assert self.delegate
+        return await self.delegate.parse_input(input, **kwargs)
 
     @override
     async def parse_image(
         self,
         image: FileInput,
     ) -> Any:
-        raise DelegateOnlyException()
+        assert self.delegate
+        return await self.delegate.parse_image(image)
 
     @override
     async def parse_file(
         self,
         file: FileInput,
     ) -> Any:
-        raise DelegateOnlyException()
+        assert self.delegate
+        return await self.delegate.parse_file(file)
 
     @override
     async def parse_tools(
         self,
         tools: list[ToolDefinition],
     ) -> Any:
-        raise DelegateOnlyException()
+        assert self.delegate
+        return await self.delegate.parse_tools(tools)
 
     @override
     async def upload_file(
@@ -96,3 +154,22 @@ class DelegateOnly(LLM):
         type: Literal["image", "file"] = "file",
     ) -> FileWithId:
         raise DelegateOnlyException()
+
+    @override
+    async def get_rate_limit(self) -> Any:
+        assert self.delegate
+        return await self.delegate.get_rate_limit()
+
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        assert self.delegate
+        return await self.delegate.count_tokens(
+            input, history=history, tools=tools, **kwargs
+        )
model_library/base/input.py
@@ -74,8 +74,6 @@ class ToolCall(BaseModel):
 --- INPUT ---
 """
 
-RawResponse = Any
-
 
 class ToolInput(BaseModel):
     tools: list[ToolDefinition] = []
@@ -90,11 +88,16 @@ class TextInput(BaseModel):
     text: str
 
 
-RawInputItem = dict[
-    str, Any
-]  # to pass in, for example, a mock conversation with {"role": "user", "content": "Hello"}
+class RawResponse(BaseModel):
+    # used to store a received response
+    response: Any
+
+
+class RawInput(BaseModel):
+    # used to pass in anything provider specific (e.g. a mock conversation)
+    input: Any
 
 
 InputItem = (
-    TextInput | FileInput | ToolResult | RawInputItem | RawResponse
-)  # input item can either be a prompt, a file (image or file), a tool call result, raw input, or a previous response
+    TextInput | FileInput | ToolResult | RawInput | RawResponse
+)  # input item can either be a prompt, a file (image or file), a tool call result, a previous response, or raw input
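
Note: the mock-conversation example from the removed RawInputItem comment carries over to the new wrapper; a sketch, assuming the public query() entry point:

    result = await model.query(
        [RawInput(input={"role": "user", "content": "Hello"})]  # provider-specific turn
    )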
model_library/base/output.py
@@ -24,6 +24,11 @@ class Citation(BaseModel):
     index: int | None = None
     container_id: str | None = None
 
+    @override
+    def __repr__(self):
+        attrs = vars(self).copy()
+        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2)}\n)"
+
 
 class QueryResultExtras(BaseModel):
     citations: list[Citation] = Field(default_factory=list)
@@ -113,6 +118,48 @@ class QueryResultCost(BaseModel):
     )
 
 
+class RateLimit(BaseModel):
+    """Rate limit information"""
+
+    request_limit: int | None = None
+    request_remaining: int | None = None
+
+    token_limit: int | None = None
+    token_limit_input: int | None = None
+    token_limit_output: int | None = None
+
+    token_remaining: int | None = None
+    token_remaining_input: int | None = None
+    token_remaining_output: int | None = None
+
+    unix_timestamp: float
+    raw: Any
+
+    @computed_field
+    @property
+    def token_limit_total(self) -> int:
+        if self.token_limit:
+            return self.token_limit
+        else:
+            return (self.token_limit_input or 0) + (self.token_limit_output or 0)
+
+    @computed_field
+    @property
+    def token_remaining_total(self) -> int:
+        if self.token_remaining:
+            return self.token_remaining
+        else:
+            return (self.token_remaining_input or 0) + (
+                self.token_remaining_output or 0
+            )
+
+    @override
+    def __repr__(self):
+        attrs = vars(self).copy()
+        attrs.pop("raw", None)
+        return f"{self.__class__.__name__}(\n{pformat(attrs, indent=2, sort_dicts=False)}\n)"
+
+
 
 class QueryResultMetadata(BaseModel):
     """
@@ -126,6 +173,7 @@ class QueryResultMetadata(BaseModel):
     reasoning_tokens: int | None = None
     cache_read_tokens: int | None = None
     cache_write_tokens: int | None = None
+    extra: dict[str, Any] = {}
 
     @property
     def default_duration_seconds(self) -> float: