ragbits-core 0.16.0__py3-none-any.whl → 1.4.0.dev202512021005__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ragbits/core/__init__.py +21 -2
  2. ragbits/core/audit/__init__.py +15 -157
  3. ragbits/core/audit/metrics/__init__.py +83 -0
  4. ragbits/core/audit/metrics/base.py +198 -0
  5. ragbits/core/audit/metrics/logfire.py +19 -0
  6. ragbits/core/audit/metrics/otel.py +65 -0
  7. ragbits/core/audit/traces/__init__.py +171 -0
  8. ragbits/core/audit/{base.py → traces/base.py} +9 -5
  9. ragbits/core/audit/{cli.py → traces/cli.py} +8 -4
  10. ragbits/core/audit/traces/logfire.py +18 -0
  11. ragbits/core/audit/{otel.py → traces/otel.py} +5 -8
  12. ragbits/core/config.py +15 -0
  13. ragbits/core/embeddings/__init__.py +2 -1
  14. ragbits/core/embeddings/base.py +19 -0
  15. ragbits/core/embeddings/dense/base.py +10 -1
  16. ragbits/core/embeddings/dense/fastembed.py +22 -1
  17. ragbits/core/embeddings/dense/litellm.py +37 -10
  18. ragbits/core/embeddings/dense/local.py +15 -1
  19. ragbits/core/embeddings/dense/noop.py +11 -1
  20. ragbits/core/embeddings/dense/vertex_multimodal.py +14 -1
  21. ragbits/core/embeddings/sparse/bag_of_tokens.py +47 -17
  22. ragbits/core/embeddings/sparse/base.py +10 -1
  23. ragbits/core/embeddings/sparse/fastembed.py +25 -2
  24. ragbits/core/llms/__init__.py +3 -3
  25. ragbits/core/llms/base.py +612 -88
  26. ragbits/core/llms/exceptions.py +27 -0
  27. ragbits/core/llms/litellm.py +408 -83
  28. ragbits/core/llms/local.py +180 -41
  29. ragbits/core/llms/mock.py +88 -23
  30. ragbits/core/prompt/__init__.py +2 -2
  31. ragbits/core/prompt/_cli.py +32 -19
  32. ragbits/core/prompt/base.py +105 -19
  33. ragbits/core/prompt/{discovery/prompt_discovery.py → discovery.py} +1 -1
  34. ragbits/core/prompt/exceptions.py +22 -6
  35. ragbits/core/prompt/prompt.py +180 -98
  36. ragbits/core/sources/__init__.py +2 -0
  37. ragbits/core/sources/azure.py +1 -1
  38. ragbits/core/sources/base.py +8 -1
  39. ragbits/core/sources/gcs.py +1 -1
  40. ragbits/core/sources/git.py +1 -1
  41. ragbits/core/sources/google_drive.py +595 -0
  42. ragbits/core/sources/hf.py +71 -31
  43. ragbits/core/sources/local.py +1 -1
  44. ragbits/core/sources/s3.py +1 -1
  45. ragbits/core/utils/config_handling.py +13 -2
  46. ragbits/core/utils/function_schema.py +220 -0
  47. ragbits/core/utils/helpers.py +22 -0
  48. ragbits/core/utils/lazy_litellm.py +44 -0
  49. ragbits/core/vector_stores/base.py +18 -1
  50. ragbits/core/vector_stores/chroma.py +28 -11
  51. ragbits/core/vector_stores/hybrid.py +1 -1
  52. ragbits/core/vector_stores/hybrid_strategies.py +21 -8
  53. ragbits/core/vector_stores/in_memory.py +13 -4
  54. ragbits/core/vector_stores/pgvector.py +123 -47
  55. ragbits/core/vector_stores/qdrant.py +15 -7
  56. ragbits/core/vector_stores/weaviate.py +440 -0
  57. {ragbits_core-0.16.0.dist-info → ragbits_core-1.4.0.dev202512021005.dist-info}/METADATA +22 -6
  58. ragbits_core-1.4.0.dev202512021005.dist-info/RECORD +79 -0
  59. {ragbits_core-0.16.0.dist-info → ragbits_core-1.4.0.dev202512021005.dist-info}/WHEEL +1 -1
  60. ragbits/core/prompt/discovery/__init__.py +0 -3
  61. ragbits/core/prompt/lab/__init__.py +0 -0
  62. ragbits/core/prompt/lab/app.py +0 -262
  63. ragbits_core-0.16.0.dist-info/RECORD +0 -72
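The file list shows the audit package split into traces and metrics subpackages (items 2-11), which is also why the base.py diff below imports from the new paths. The sketch below only restates those import paths and mirrors the record_metric keyword arguments that appear later in this diff; the metric value and the trace keyword arguments used here are illustrative assumptions, not taken from the package.

# Sketch of the reorganized audit imports implied by the file list and by the
# imports in base.py below. Keyword arguments to record_metric mirror the calls
# visible in the diff; the concrete values are made up for illustration.
from ragbits.core.audit.metrics import record_metric
from ragbits.core.audit.metrics.base import LLMMetric, MetricType
from ragbits.core.audit.traces import trace

with trace(name="example", model_name="my-model") as outputs:
    outputs.response = "ok"  # attributes set on `outputs` are recorded on the span
    record_metric(
        metric=LLMMetric.INPUT_TOKENS,
        value=42,
        metric_type=MetricType.HISTOGRAM,
        model="my-model",
    )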
ragbits/core/llms/base.py CHANGED
@@ -1,23 +1,42 @@
  import enum
+ import json
  from abc import ABC, abstractmethod
- from collections.abc import AsyncGenerator
- from typing import ClassVar, Generic, TypeVar, cast, overload
+ from collections.abc import AsyncGenerator, AsyncIterator, Callable, MutableSequence
+ from typing import ClassVar, Generic, Literal, TypeVar, Union, cast, overload

- from pydantic import BaseModel
+ from pydantic import BaseModel, Field, field_validator
+ from typing_extensions import deprecated

  from ragbits.core import llms
- from ragbits.core.audit import trace
+ from ragbits.core.audit.metrics import record_metric
+ from ragbits.core.audit.metrics.base import LLMMetric, MetricType
+ from ragbits.core.audit.traces import trace
  from ragbits.core.options import Options
  from ragbits.core.prompt.base import (
      BasePrompt,
      BasePromptWithParser,
      ChatFormat,
-     OutputT,
+     PromptOutputT,
      SimplePrompt,
  )
+ from ragbits.core.types import NOT_GIVEN, NotGiven
  from ragbits.core.utils.config_handling import ConfigurableComponent
+ from ragbits.core.utils.function_schema import convert_function_to_function_schema

- LLMClientOptionsT = TypeVar("LLMClientOptionsT", bound=Options)
+
+ class LLMOptions(Options):
+     """
+     Options for the LLM.
+     """
+
+     max_tokens: int | None | NotGiven = NOT_GIVEN
+     """The maximum number of tokens the LLM can use, if None, LLM will run forever"""
+
+
+ LLMClientOptionsT = TypeVar("LLMClientOptionsT", bound=LLMOptions)
+ Tool = Callable | dict
+ ToolChoiceWithCallable = Literal["none", "auto", "required"] | Tool
+ ToolChoice = Literal["none", "auto", "required"] | dict


  class LLMType(enum.Enum):
@@ -30,13 +49,209 @@ class LLMType(enum.Enum):
      STRUCTURED_OUTPUT = "structured_output"


- class LLMResponseWithMetadata(BaseModel, Generic[OutputT]):
+ class ToolCall(BaseModel):
+     """
+     A schema of tool call data
+     """
+
+     id: str
+     type: str
+     name: str
+     arguments: dict
+
+     @field_validator("arguments", mode="before")
+     def parse_tool_arguments(cls, tool_arguments: str) -> dict:
+         """
+         Parser for converting tool arguments from string representation to dict
+         """
+         parsed_arguments = json.loads(tool_arguments)
+         return parsed_arguments
+
+
+ class UsageItem(BaseModel):
+     """
+     A schema of token usage data
+     """
+
+     model: str
+
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+
+     estimated_cost: float
+
+
+ class Usage(BaseModel):
+     """
+     A schema of token usage data
+     """
+
+     requests: list[UsageItem] = Field(default_factory=list)
+
+     @classmethod
+     def new(
+         cls,
+         llm: "LLM",
+         prompt_tokens: int,
+         completion_tokens: int,
+         total_tokens: int,
+     ) -> "Usage":
+         """
+         Creates a new Usage object.
+
+         Args:
+             llm: The LLM instance.
+             prompt_tokens: The number of tokens in the prompt.
+             completion_tokens: The number of tokens in the completion.
+             total_tokens: The total number of tokens.
+         """
+         return cls(
+             requests=[
+                 UsageItem(
+                     model=llm.get_model_id(),
+                     prompt_tokens=prompt_tokens,
+                     completion_tokens=completion_tokens,
+                     total_tokens=total_tokens,
+                     estimated_cost=llm.get_estimated_cost(prompt_tokens, completion_tokens),
+                 )
+             ]
+         )
+
+     @property
+     def n_requests(self) -> int:
+         """
+         Returns the number of requests.
+         """
+         return len(self.requests)
+
+     @property
+     def estimated_cost(self) -> float:
+         """
+         Returns the estimated cost.
+         """
+         return sum(request.estimated_cost for request in self.requests)
+
+     @property
+     def prompt_tokens(self) -> int:
+         """
+         Returns the number of prompt tokens.
+         """
+         return sum(request.prompt_tokens for request in self.requests)
+
+     @property
+     def completion_tokens(self) -> int:
+         """
+         Returns the number of completion tokens.
+         """
+         return sum(request.completion_tokens for request in self.requests)
+
+     @property
+     def total_tokens(self) -> int:
+         """
+         Returns the total number of tokens.
+         """
+         return sum(request.total_tokens for request in self.requests)
+
+     @property
+     def model_breakdown(self) -> dict[str, "Usage"]:
+         """
+         Returns the model breakdown.
+         """
+         breakdown = {}
+         for request in self.requests:
+             if request.model not in breakdown:
+                 breakdown[request.model] = Usage(requests=[request])
+             else:
+                 breakdown[request.model] += request
+
+         return breakdown
+
+     def __add__(self, other: Union["Usage", "UsageItem"]) -> "Usage":
+         if isinstance(other, Usage):
+             return Usage(
+                 requests=self.requests + other.requests,
+             )
+
+         if isinstance(other, UsageItem):
+             return Usage(requests=self.requests + [other])
+
+         return NotImplemented
+
+     def __iadd__(self, other: Union["Usage", "UsageItem"]) -> "Usage":
+         if isinstance(other, Usage):
+             self.requests += other.requests
+             return self
+
+         if isinstance(other, UsageItem):
+             self.requests.append(other)
+             return self
+
+         return NotImplemented
+
+     def __repr__(self) -> str:
+         return (
+             f"Usage(n_requests={self.n_requests}, "
+             f"prompt_tokens={self.prompt_tokens}, "
+             f"completion_tokens={self.completion_tokens}, "
+             f"total_tokens={self.total_tokens}, "
+             f"estimated_cost={self.estimated_cost})"
+         )
+
+
+ class Reasoning(str):
+     """A class for reasoning streaming"""
+
+
+ class LLMResponseWithMetadata(BaseModel, Generic[PromptOutputT]):
      """
      A schema of output with metadata
      """

-     content: OutputT
-     metadata: dict
+     content: PromptOutputT
+     metadata: dict = {}
+     reasoning: str | None = None
+     tool_calls: list[ToolCall] | None = None
+     usage: Usage | None = None
+
+
+ T = TypeVar("T")
+
+
+ class LLMResultStreaming(AsyncIterator[T]):
+     """
+     An async iterator that will collect all yielded items by LLM.generate_streaming(). This object is returned
+     by `run_streaming`. It can be used in an `async for` loop to process items as they arrive. After the loop completes,
+     metadata is available as `metadata` attribute.
+     """
+
+     def __init__(self, generator: AsyncGenerator[T | LLMResponseWithMetadata]):
+         self._generator = generator
+         self.usage = Usage()
+
+     def __aiter__(self) -> AsyncIterator[T]:
+         return self
+
+     async def __anext__(self) -> T:
+         try:
+             item = await self._generator.__anext__()
+             match item:
+                 case str():
+                     pass
+                 case ToolCall():
+                     pass
+                 case LLMResponseWithMetadata():
+                     self.metadata: LLMResponseWithMetadata = item
+                     if item.usage:
+                         self.usage += item.usage
+                     raise StopAsyncIteration
+                 case Usage():
+                     self.usage += item
+                 case _:
+                     raise ValueError(f"Unexpected item: {item}")
+             return cast(T, item)
+         except StopAsyncIteration:
+             raise


  class LLM(ConfigurableComponent[LLMClientOptionsT], ABC):
@@ -66,6 +281,25 @@ class LLM(ConfigurableComponent[LLMClientOptionsT], ABC):
          if not hasattr(cls, "options_cls"):
              raise TypeError(f"Class {cls.__name__} is missing the 'options_cls' attribute")

+     @abstractmethod
+     def get_model_id(self) -> str:
+         """
+         Returns the model id.
+         """
+
+     @abstractmethod
+     def get_estimated_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
+         """
+         Returns the estimated cost of the LLM call.
+
+         Args:
+             prompt_tokens: The number of tokens in the prompt.
+             completion_tokens: The number of tokens in the completion.
+
+         Returns:
+             The estimated cost of the LLM call.
+         """
+
      def count_tokens(self, prompt: BasePrompt) -> int:  # noqa: PLR6301
          """
          Counts tokens in the prompt.
@@ -78,6 +312,19 @@ class LLM(ConfigurableComponent[LLMClientOptionsT], ABC):
          """
          return sum(len(message["content"]) for message in prompt.chat)

+     def get_token_id(self, token: str) -> int:
+         """
+         Gets token id.
+
+         Args:
+             token: The token to encode.
+
+         Returns:
+             The id for the given token.
+         """
+         raise NotImplementedError("Token id lookup is not supported by this model")
+
+     @deprecated("Use generate_with_metadata() instead")
      async def generate_raw(
          self,
          prompt: BasePrompt | str | ChatFormat,
@@ -102,51 +349,125 @@ class LLM(ConfigurableComponent[LLMClientOptionsT], ABC):
          if isinstance(prompt, str | list):
              prompt = SimplePrompt(prompt)

-         return await self._call(
-             prompt=prompt,
-             options=merged_options,
-             json_mode=prompt.json_mode,
-             output_schema=prompt.output_schema(),
-         )
+         response = (
+             await self._call(
+                 prompt=[prompt],
+                 options=merged_options,
+             )
+         )[0]
+
+         returned = {
+             "response": response["response"],
+             "throughput": response["throughput"],
+         }
+         for opt in ["tool_calls", "usage"]:
+             if opt in response:
+                 returned[opt] = response[opt]
+
+         return returned

      @overload
      async def generate(
          self,
-         prompt: BasePromptWithParser[OutputT],
+         prompt: BasePromptWithParser[PromptOutputT],
          *,
+         tools: None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> OutputT: ...
+     ) -> PromptOutputT: ...

      @overload
      async def generate(
          self,
-         prompt: BasePrompt,
+         prompt: BasePrompt | BasePromptWithParser[PromptOutputT],
          *,
+         tools: None = None,
+         tool_choice: None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> OutputT: ...
+     ) -> PromptOutputT: ...

      @overload
      async def generate(
          self,
-         prompt: str,
+         prompt: MutableSequence[BasePrompt | BasePromptWithParser[PromptOutputT]],
          *,
+         tools: None = None,
+         tool_choice: None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> str: ...
+     ) -> list[PromptOutputT]: ...
+
+     @overload
+     async def generate(
+         self,
+         prompt: BasePrompt | BasePromptWithParser[PromptOutputT],
+         *,
+         tools: list[Tool],
+         tool_choice: ToolChoiceWithCallable | None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> PromptOutputT | list[ToolCall]: ...

      @overload
      async def generate(
          self,
-         prompt: ChatFormat,
+         prompt: MutableSequence[BasePrompt | BasePromptWithParser[PromptOutputT]],
          *,
+         tools: list[Tool],
+         tool_choice: ToolChoiceWithCallable | None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> list[PromptOutputT | list[ToolCall]]: ...
+
+     @overload
+     async def generate(
+         self,
+         prompt: str | ChatFormat,
+         *,
+         tools: None = None,
+         tool_choice: None = None,
          options: LLMClientOptionsT | None = None,
      ) -> str: ...

+     @overload
      async def generate(
          self,
-         prompt: BasePrompt | str | ChatFormat,
+         prompt: MutableSequence[str | ChatFormat],
          *,
+         tools: None = None,
+         tool_choice: None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> OutputT:
+     ) -> list[str]: ...
+
+     @overload
+     async def generate(
+         self,
+         prompt: str | ChatFormat,
+         *,
+         tools: list[Tool],
+         tool_choice: ToolChoiceWithCallable | None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> str | list[ToolCall]: ...
+
+     @overload
+     async def generate(
+         self,
+         prompt: MutableSequence[str | ChatFormat],
+         *,
+         tools: list[Tool],
+         tool_choice: ToolChoiceWithCallable | None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> list[str | list[ToolCall]]: ...
+
+     async def generate(
+         self,
+         prompt: str
+         | ChatFormat
+         | BasePrompt
+         | BasePromptWithParser[PromptOutputT]
+         | MutableSequence[ChatFormat | str]
+         | MutableSequence[BasePrompt | BasePromptWithParser[PromptOutputT]],
+         *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> str | PromptOutputT | list[ToolCall] | list[list[ToolCall] | str] | list[str | PromptOutputT | list[ToolCall]]:
          """
          Prepares and sends a prompt to the LLM and returns the parsed response.

@@ -155,129 +476,328 @@ class LLM(ConfigurableComponent[LLMClientOptionsT], ABC):
                  - BasePrompt instance: Formatted prompt template with conversation
                  - str: Simple text prompt that will be sent as a user message
                  - ChatFormat: List of message dictionaries in OpenAI chat format
+                 - Iterable of any of the above (MutableSequence is only for typing purposes)
+             tools: Functions to be used as tools by the LLM.
+             tool_choice: Parameter that allows to control what tool is used. Can be one of:
+                 - "auto": let model decide if tool call is needed
+                 - "none": do not call tool
+                 - "required: enforce tool usage (model decides which one)
+                 - dict: tool dict corresponding to one of provided tools
+                 - Callable: one of provided tools
              options: Options to use for the LLM client.

          Returns:
-             Parsed response from LLM.
+             Parsed response(s) from LLM or list of tool calls.
          """
-         with trace(model_name=self.model_name, prompt=prompt, options=repr(options)) as outputs:
-             raw_response = await self.generate_raw(prompt, options=options)
-             if isinstance(prompt, BasePromptWithParser):
-                 response = await prompt.parse_response(raw_response["response"])
-             else:
-                 response = cast(OutputT, raw_response["response"])
-             raw_response["response"] = response
-             outputs.response = raw_response
-
-         return response
+         response = await self.generate_with_metadata(prompt, tools=tools, tool_choice=tool_choice, options=options)
+         if isinstance(response, list):
+             return [r.tool_calls if tools and r.tool_calls else r.content for r in response]
+         else:
+             return response.tool_calls if tools and response.tool_calls else response.content

      @overload
      async def generate_with_metadata(
          self,
-         prompt: BasePromptWithParser[OutputT],
+         prompt: BasePrompt | BasePromptWithParser[PromptOutputT],
          *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> LLMResponseWithMetadata[OutputT]: ...
+     ) -> LLMResponseWithMetadata[PromptOutputT]: ...

      @overload
      async def generate_with_metadata(
          self,
-         prompt: BasePrompt,
+         prompt: MutableSequence[BasePrompt | BasePromptWithParser[PromptOutputT]],
          *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> LLMResponseWithMetadata[OutputT]: ...
+     ) -> list[LLMResponseWithMetadata[PromptOutputT]]: ...

      @overload
      async def generate_with_metadata(
          self,
-         prompt: str,
+         prompt: str | ChatFormat,
          *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> LLMResponseWithMetadata[OutputT]: ...
+     ) -> LLMResponseWithMetadata[str]: ...

-     @overload
      @overload
      async def generate_with_metadata(
          self,
-         prompt: ChatFormat,
+         prompt: MutableSequence[str | ChatFormat],
          *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> LLMResponseWithMetadata[OutputT]: ...
+     ) -> list[LLMResponseWithMetadata[str]]: ...

      async def generate_with_metadata(
          self,
-         prompt: BasePrompt | str | ChatFormat,
+         prompt: str
+         | ChatFormat
+         | MutableSequence[str | ChatFormat]
+         | BasePrompt
+         | BasePromptWithParser[PromptOutputT]
+         | MutableSequence[BasePrompt | BasePromptWithParser[PromptOutputT]],
          *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> LLMResponseWithMetadata[OutputT]:
+     ) -> (
+         LLMResponseWithMetadata[str]
+         | list[LLMResponseWithMetadata[str]]
+         | LLMResponseWithMetadata[PromptOutputT]
+         | list[LLMResponseWithMetadata[PromptOutputT]]
+     ):
          """
          Prepares and sends a prompt to the LLM and returns response parsed to the
          output type of the prompt (if available).

          Args:
-             prompt: Formatted prompt template with conversation and optional response parsing configuration.
+             prompt: Can be one of:
+                 - BasePrompt instance: Formatted prompt template with conversation
+                 - str: Simple text prompt that will be sent as a user message
+                 - ChatFormat: List of message dictionaries in OpenAI chat format
+                 - Iterable of any of the above (MutableSequence is only for typing purposes)
+             tools: Functions to be used as tools by the LLM.
+             tool_choice: Parameter that allows to control what tool is used. Can be one of:
+                 - "auto": let model decide if tool call is needed
+                 - "none": do not call tool
+                 - "required: enforce tool usage (model decides which one)
+                 - dict: tool dict corresponding to one of provided tools
+                 - Callable: one of provided tools
              options: Options to use for the LLM client.

          Returns:
-             Text response from LLM with metadata.
+             ResponseWithMetadata object(s) with text response, list of tool calls and metadata information.
          """
-         with trace(model_name=self.model_name, prompt=prompt, options=repr(options)) as outputs:
-             response = await self.generate_raw(prompt, options=options)
-             content = response.pop("response")
-             if isinstance(prompt, BasePromptWithParser):
-                 content = await prompt.parse_response(content)
-             outputs.response = LLMResponseWithMetadata[type(content)](  # type: ignore
-                 content=content,
-                 metadata=response,
+         single_prompt = False
+         if isinstance(prompt, BasePrompt | str) or isinstance(prompt[0], dict):
+             single_prompt = True
+             prompt = [prompt]  # type: ignore
+
+         parsed_tools = (
+             [convert_function_to_function_schema(tool) if callable(tool) else tool for tool in tools] if tools else None
+         )
+         parsed_tool_choice = convert_function_to_function_schema(tool_choice) if callable(tool_choice) else tool_choice
+
+         prompts: list[BasePrompt] = [SimplePrompt(p) if isinstance(p, str | list) else p for p in prompt]  # type: ignore
+
+         merged_options = (self.default_options | options) if options else self.default_options
+
+         with trace(name="generate", model_name=self.model_name, prompt=prompts, options=repr(options)) as outputs:
+             results = await self._call(
+                 prompt=prompts,
+                 options=merged_options,
+                 tools=parsed_tools,
+                 tool_choice=parsed_tool_choice,
+             )
+
+             parsed_responses = []
+             for prompt, response in zip(prompts, results, strict=True):
+                 tool_calls = (
+                     [ToolCall.model_validate(tool_call) for tool_call in _tool_calls]
+                     if (_tool_calls := response.pop("tool_calls", None)) and tools
+                     else None
+                 )
+
+                 usage = None
+                 if usage_data := response.pop("usage", None):
+                     usage = Usage.new(
+                         llm=self,
+                         prompt_tokens=cast(int, usage_data.get("prompt_tokens")),
+                         completion_tokens=cast(int, usage_data.get("completion_tokens")),
+                         total_tokens=cast(int, usage_data.get("total_tokens")),
+                     )
+
+                 content = response.pop("response")
+                 reasoning = response.pop("reasoning", None)
+
+                 if isinstance(prompt, BasePromptWithParser) and content:
+                     content = await prompt.parse_response(content)
+
+                 response_with_metadata = LLMResponseWithMetadata[type(content)](  # type: ignore
+                     content=content,
+                     reasoning=reasoning,
+                     tool_calls=tool_calls,
+                     metadata=response,
+                     usage=usage,
+                 )
+                 parsed_responses.append(response_with_metadata)
+             outputs.response = parsed_responses
+
+             prompt_tokens = sum(r.usage.prompt_tokens for r in parsed_responses if r.usage)
+             outputs.prompt_tokens_batch = prompt_tokens
+             record_metric(
+                 metric=LLMMetric.INPUT_TOKENS,
+                 value=prompt_tokens,
+                 metric_type=MetricType.HISTOGRAM,
+                 model=self.model_name,
              )
-             return outputs.response

-     async def generate_streaming(
+             total_throughput = sum(r["throughput"] for r in results if "throughput" in r)
+             outputs.throughput_batch = total_throughput
+             record_metric(
+                 metric=LLMMetric.PROMPT_THROUGHPUT,
+                 value=total_throughput,
+                 metric_type=MetricType.HISTOGRAM,
+                 model=self.model_name,
+             )
+
+             total_tokens = sum(r.usage.total_tokens for r in parsed_responses if r.usage)
+             outputs.total_tokens_batch = total_tokens
+             record_metric(
+                 metric=LLMMetric.TOKEN_THROUGHPUT,
+                 value=total_tokens / total_throughput,
+                 metric_type=MetricType.HISTOGRAM,
+                 model=self.model_name,
+             )
+
+         if single_prompt:
+             return parsed_responses[0]
+
+         return parsed_responses
+
+     @overload
+     def generate_streaming(
          self,
-         prompt: BasePrompt | str | ChatFormat,
+         prompt: str | ChatFormat | BasePrompt,
+         *,
+         tools: None = None,
+         tool_choice: None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> LLMResultStreaming[str | Reasoning]: ...
+
+     @overload
+     def generate_streaming(
+         self,
+         prompt: str | ChatFormat | BasePrompt,
+         *,
+         tools: list[Tool],
+         tool_choice: ToolChoiceWithCallable | None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> LLMResultStreaming[str | Reasoning | ToolCall]: ...
+
+     def generate_streaming(
+         self,
+         prompt: str | ChatFormat | BasePrompt,
          *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
          options: LLMClientOptionsT | None = None,
-     ) -> AsyncGenerator[str, None]:
+     ) -> LLMResultStreaming:
          """
-         Prepares and sends a prompt to the LLM and streams the results.
+         This method returns an `LLMResultStreaming` object that can be asynchronously
+         iterated over. After the loop completes, metadata is available as `metadata` attribute.

          Args:
              prompt: Formatted prompt template with conversation.
-             options: Options to use for the LLM client.
+             tools: Functions to be used as tools by the LLM.
+             tool_choice: Parameter that allows to control what tool is used. Can be one of:
+                 - "auto": let model decide if tool call is needed
+                 - "none": do not call tool
+                 - "required: enforce tool usage (model decides which one)
+                 - dict: tool dict corresponding to one of provided tools
+                 - Callable: one of provided tools
+             options: Options to use for the LLM.

          Returns:
-             Response stream from LLM.
+             Response stream from LLM or list of tool calls.
          """
-         merged_options = (self.default_options | options) if options else self.default_options
+         return LLMResultStreaming(self._stream_internal(prompt, tools=tools, tool_choice=tool_choice, options=options))

-         if isinstance(prompt, str | list):
-             prompt = SimplePrompt(prompt)
+     async def _stream_internal(
+         self,
+         prompt: str | ChatFormat | BasePrompt,
+         *,
+         tools: list[Tool] | None = None,
+         tool_choice: ToolChoiceWithCallable | None = None,
+         options: LLMClientOptionsT | None = None,
+     ) -> AsyncGenerator[str | Reasoning | ToolCall | LLMResponseWithMetadata, None]:
+         with trace(model_name=self.model_name, prompt=prompt, options=repr(options)) as outputs:
+             merged_options = (self.default_options | options) if options else self.default_options
+             if isinstance(prompt, str | list):
+                 prompt = SimplePrompt(prompt)
+
+             parsed_tools = (
+                 [convert_function_to_function_schema(tool) if callable(tool) else tool for tool in tools]
+                 if tools
+                 else None
+             )
+             parsed_tool_choice = (
+                 convert_function_to_function_schema(tool_choice) if callable(tool_choice) else tool_choice
+             )
+             response = await self._call_streaming(
+                 prompt=prompt,
+                 options=merged_options,
+                 tools=parsed_tools,
+                 tool_choice=parsed_tool_choice,
+             )

-         response = await self._call_streaming(
-             prompt=prompt,
-             options=merged_options,
-             json_mode=prompt.json_mode,
-             output_schema=prompt.output_schema(),
-         )
-         async for text_piece in response:
-             yield text_piece
+             content = ""
+             reasoning = ""
+             tool_calls = []
+             usage_data = {}
+             async for chunk in response:
+                 if text := chunk.get("response"):
+                     if chunk.get("reasoning"):
+                         reasoning += text
+                         yield Reasoning(text)
+                     else:
+                         content += text
+                         yield text
+
+                 if tools and (_tool_calls := chunk.get("tool_calls")):
+                     for tool_call in _tool_calls:
+                         parsed_tool_call = ToolCall.model_validate(tool_call)
+                         tool_calls.append(parsed_tool_call)
+                         yield parsed_tool_call
+
+                 if usage_chunk := chunk.get("usage"):
+                     usage_data = usage_chunk
+
+             usage = None
+             if usage_data:
+                 usage = Usage.new(
+                     llm=self,
+                     prompt_tokens=cast(int, usage_data.get("prompt_tokens")),
+                     completion_tokens=cast(int, usage_data.get("completion_tokens")),
+                     total_tokens=cast(int, usage_data.get("total_tokens")),
+                 )
+
+             outputs.response = LLMResponseWithMetadata[type(content or None)](  # type: ignore
+                 content=content or None,
+                 reasoning=reasoning or None,
+                 tool_calls=tool_calls or None,
+                 usage=usage,
+             )
+
+             yield outputs.response

      @abstractmethod
      async def _call(
          self,
-         prompt: BasePrompt,
+         prompt: MutableSequence[BasePrompt],
          options: LLMClientOptionsT,
-         json_mode: bool = False,
-         output_schema: type[BaseModel] | dict | None = None,
-     ) -> dict:
+         tools: list[dict] | None = None,
+         tool_choice: ToolChoice | None = None,
+     ) -> list[dict]:
          """
          Calls LLM inference API.

          Args:
              prompt: Formatted prompt template with conversation.
-             options: Additional settings used by LLM.
-             json_mode: Force the response to be in JSON format.
-             output_schema: Schema for structured response (either Pydantic model or a JSON schema).
+             options: Additional settings used by the LLM.
+             tools: Functions to be used as tools by the LLM.
+             tool_choice: Parameter that allows to control what tool is used. Can be one of:
+                 - "auto": let model decide if tool call is needed
+                 - "none": do not call tool
+                 - "required: enforce tool usage (model decides which one)
+                 - dict: tool dict corresponding to one of provided tools

          Returns:
              Response dict from LLM.
@@ -288,18 +808,22 @@ class LLM(ConfigurableComponent[LLMClientOptionsT], ABC):
          self,
          prompt: BasePrompt,
          options: LLMClientOptionsT,
-         json_mode: bool = False,
-         output_schema: type[BaseModel] | dict | None = None,
-     ) -> AsyncGenerator[str, None]:
+         tools: list[dict] | None = None,
+         tool_choice: ToolChoice | None = None,
+     ) -> AsyncGenerator[dict, None]:
          """
          Calls LLM inference API with output streaming.

          Args:
              prompt: Formatted prompt template with conversation.
-             options: Additional settings used by LLM.
-             json_mode: Force the response to be in JSON format.
-             output_schema: Schema for structured response (either Pydantic model or a JSON schema).
+             options: Additional settings used by the LLM.
+             tools: Functions to be used as tools by the LLM.
+             tool_choice: Parameter that allows to control what tool is used. Can be one of:
+                 - "auto": let model decide if tool call is needed
+                 - "none": do not call tool
+                 - "required: enforce tool usage (model decides which one)
+                 - dict: tool dict corresponding to one of provided tools

          Returns:
-             Response dict stream from LLM.
+             Response dict stream from LLM.
          """