ragbits-core 0.16.0__py3-none-any.whl → 1.4.0.dev202512021005__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. ragbits/core/__init__.py +21 -2
  2. ragbits/core/audit/__init__.py +15 -157
  3. ragbits/core/audit/metrics/__init__.py +83 -0
  4. ragbits/core/audit/metrics/base.py +198 -0
  5. ragbits/core/audit/metrics/logfire.py +19 -0
  6. ragbits/core/audit/metrics/otel.py +65 -0
  7. ragbits/core/audit/traces/__init__.py +171 -0
  8. ragbits/core/audit/{base.py → traces/base.py} +9 -5
  9. ragbits/core/audit/{cli.py → traces/cli.py} +8 -4
  10. ragbits/core/audit/traces/logfire.py +18 -0
  11. ragbits/core/audit/{otel.py → traces/otel.py} +5 -8
  12. ragbits/core/config.py +15 -0
  13. ragbits/core/embeddings/__init__.py +2 -1
  14. ragbits/core/embeddings/base.py +19 -0
  15. ragbits/core/embeddings/dense/base.py +10 -1
  16. ragbits/core/embeddings/dense/fastembed.py +22 -1
  17. ragbits/core/embeddings/dense/litellm.py +37 -10
  18. ragbits/core/embeddings/dense/local.py +15 -1
  19. ragbits/core/embeddings/dense/noop.py +11 -1
  20. ragbits/core/embeddings/dense/vertex_multimodal.py +14 -1
  21. ragbits/core/embeddings/sparse/bag_of_tokens.py +47 -17
  22. ragbits/core/embeddings/sparse/base.py +10 -1
  23. ragbits/core/embeddings/sparse/fastembed.py +25 -2
  24. ragbits/core/llms/__init__.py +3 -3
  25. ragbits/core/llms/base.py +612 -88
  26. ragbits/core/llms/exceptions.py +27 -0
  27. ragbits/core/llms/litellm.py +408 -83
  28. ragbits/core/llms/local.py +180 -41
  29. ragbits/core/llms/mock.py +88 -23
  30. ragbits/core/prompt/__init__.py +2 -2
  31. ragbits/core/prompt/_cli.py +32 -19
  32. ragbits/core/prompt/base.py +105 -19
  33. ragbits/core/prompt/{discovery/prompt_discovery.py → discovery.py} +1 -1
  34. ragbits/core/prompt/exceptions.py +22 -6
  35. ragbits/core/prompt/prompt.py +180 -98
  36. ragbits/core/sources/__init__.py +2 -0
  37. ragbits/core/sources/azure.py +1 -1
  38. ragbits/core/sources/base.py +8 -1
  39. ragbits/core/sources/gcs.py +1 -1
  40. ragbits/core/sources/git.py +1 -1
  41. ragbits/core/sources/google_drive.py +595 -0
  42. ragbits/core/sources/hf.py +71 -31
  43. ragbits/core/sources/local.py +1 -1
  44. ragbits/core/sources/s3.py +1 -1
  45. ragbits/core/utils/config_handling.py +13 -2
  46. ragbits/core/utils/function_schema.py +220 -0
  47. ragbits/core/utils/helpers.py +22 -0
  48. ragbits/core/utils/lazy_litellm.py +44 -0
  49. ragbits/core/vector_stores/base.py +18 -1
  50. ragbits/core/vector_stores/chroma.py +28 -11
  51. ragbits/core/vector_stores/hybrid.py +1 -1
  52. ragbits/core/vector_stores/hybrid_strategies.py +21 -8
  53. ragbits/core/vector_stores/in_memory.py +13 -4
  54. ragbits/core/vector_stores/pgvector.py +123 -47
  55. ragbits/core/vector_stores/qdrant.py +15 -7
  56. ragbits/core/vector_stores/weaviate.py +440 -0
  57. {ragbits_core-0.16.0.dist-info → ragbits_core-1.4.0.dev202512021005.dist-info}/METADATA +22 -6
  58. ragbits_core-1.4.0.dev202512021005.dist-info/RECORD +79 -0
  59. {ragbits_core-0.16.0.dist-info → ragbits_core-1.4.0.dev202512021005.dist-info}/WHEEL +1 -1
  60. ragbits/core/prompt/discovery/__init__.py +0 -3
  61. ragbits/core/prompt/lab/__init__.py +0 -0
  62. ragbits/core/prompt/lab/app.py +0 -262
  63. ragbits_core-0.16.0.dist-info/RECORD +0 -72
ragbits/core/prompt/base.py
@@ -1,11 +1,12 @@
+ import json
  from abc import ABCMeta, abstractmethod
  from typing import Any, Generic

  from pydantic import BaseModel
- from typing_extensions import TypeVar
+ from typing_extensions import Self, TypeVar

  ChatFormat = list[dict[str, Any]]
- OutputT = TypeVar("OutputT", default=str)
+ PromptOutputT = TypeVar("PromptOutputT", default=str)


  class BasePrompt(metaclass=ABCMeta):
@@ -14,7 +15,6 @@ class BasePrompt(metaclass=ABCMeta):
      """

      @property
-     @abstractmethod
      def chat(self) -> ChatFormat:
          """
          Returns the conversation in the standard OpenAI chat format.
@@ -22,6 +22,9 @@ class BasePrompt(metaclass=ABCMeta):
          Returns:
              ChatFormat: A list of dictionaries, each containing the role and content of a message.
          """
+         if not hasattr(self, "_conversation_history"):
+             self._conversation_history: list[dict[str, Any]] = []
+         return self._conversation_history

      @property
      def json_mode(self) -> bool:
@@ -46,15 +49,108 @@
          """
          return []

+     def list_pdfs(self) -> list[str]:  # noqa: PLR6301
+         """
+         Returns the PDFs in form of URLs or base64 encoded strings.
+
+         Returns:
+             list of PDFs
+         """
+         return []
+
+     def add_assistant_message(self, message: str | PromptOutputT) -> Self:
+         """
+         Add an assistant message to the conversation history.
+
+         Args:
+             message (str): The assistant message content.
+
+         Returns:
+             Prompt[PromptInputT, PromptOutputT]: The current prompt instance to allow chaining.
+         """
+         if not hasattr(self, "_conversation_history"):
+             self._conversation_history = []
+
+         if isinstance(message, BaseModel):
+             message = message.model_dump_json()
+         self._conversation_history.append({"role": "assistant", "content": str(message)})
+         return self
+
+     def add_tool_use_message(
+         self,
+         id: str,
+         name: str,
+         arguments: dict,
+         result: Any,  # noqa: ANN401
+     ) -> Self:
+         """
+         Add tool call messages to the conversation history.

- class BasePromptWithParser(Generic[OutputT], BasePrompt, metaclass=ABCMeta):
+         Args:
+             id (str): The id of the tool call.
+             name (str): The name of the tool.
+             arguments (dict): The arguments of the tool.
+             result (any): The tool call result.
+
+         Returns:
+             Prompt[PromptInputT, PromptOutputT]: The current prompt instance to allow chaining.
+         """
+         if not hasattr(self, "_conversation_history"):
+             self._conversation_history = []
+
+         self._conversation_history.extend(
+             [
+                 {
+                     "role": "assistant",
+                     "content": None,
+                     "tool_calls": [
+                         {
+                             "id": id,
+                             "type": "function",
+                             "function": {
+                                 "name": name,
+                                 "arguments": json.dumps(arguments),
+                             },
+                         }
+                     ],
+                 },
+                 {
+                     "role": "tool",
+                     "tool_call_id": id,
+                     "content": str(result),
+                 },
+             ]
+         )
+
+         return self
+
+     def add_user_message(self, message: str | dict[str, Any] | list[dict[str, Any]]) -> Self:
+         """
+         Add a user message to the conversation history.
+
+         Args:
+             message: The user message content. Can be:
+                 - A string: Used directly as content
+                 - A dictionary: With format {"type": "text", "text": "message"} or image content
+
+         Returns:
+             Prompt: The current prompt instance to allow chaining.
+         """
+         if not hasattr(self, "_conversation_history"):
+             self._conversation_history = []
+
+         self._conversation_history.append({"role": "user", "content": message})
+         return self
+
+
+ class BasePromptWithParser(Generic[PromptOutputT], BasePrompt, metaclass=ABCMeta):
      """
      Base class for prompts that know how to parse the output from the LLM to their specific
      output type.
      """

      @abstractmethod
-     async def parse_response(self, response: str) -> OutputT:
+     async def parse_response(self, response: str) -> PromptOutputT:
          """
          Parse the response from the LLM to the desired output type.
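
A minimal usage sketch for the history helpers added above, using SimplePrompt (defined later in this file) as the concrete class; the tool id, name, arguments, and result are illustrative:

    from ragbits.core.prompt.base import SimplePrompt

    prompt = SimplePrompt("What is the weather in Paris?")
    prompt.add_tool_use_message(
        id="call_1",
        name="get_weather",
        arguments={"city": "Paris"},
        result={"temperature_c": 18},
    ).add_assistant_message("It is 18 degrees Celsius in Paris.")
    # prompt.chat now ends with an assistant entry carrying "tool_calls", a "tool"
    # message holding the stringified result, and the final assistant reply.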
 
@@ -62,7 +158,7 @@ class BasePromptWithParser(Generic[OutputT], BasePrompt, metaclass=ABCMeta):
              response (str): The response from the LLM.

          Returns:
-             OutputT: The parsed response.
+             PromptOutputT_co: The parsed response.

          Raises:
              ResponseParsingError: If the response cannot be parsed.
@@ -75,16 +171,6 @@ class SimplePrompt(BasePrompt):
      """

      def __init__(self, content: str | ChatFormat) -> None:
-         self._content = content
-
-     @property
-     def chat(self) -> ChatFormat:
-         """
-         Returns the conversation in the chat format.
-
-         Returns:
-             ChatFormat: A list of dictionaries, each containing the role and content of a message.
-         """
-         if isinstance(self._content, str):
-             return [{"role": "user", "content": self._content}]
-         return self._content
+         self._conversation_history: list[dict[str, Any]] = (
+             [{"role": "user", "content": content}] if isinstance(content, str) else content
+         )
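
Since SimplePrompt now seeds _conversation_history directly, the inherited chat property needs no override. A short sketch of the two accepted content forms:

    from ragbits.core.prompt.base import SimplePrompt

    p1 = SimplePrompt("Summarize this document.")
    # p1.chat == [{"role": "user", "content": "Summarize this document."}]

    p2 = SimplePrompt([
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Summarize this document."},
    ])
    # p2.chat returns the provided list unchanged.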
ragbits/core/prompt/discovery.py
@@ -4,7 +4,7 @@ import os
  from pathlib import Path
  from typing import Any, get_origin

- from ragbits.core.audit import trace
+ from ragbits.core.audit.traces import trace
  from ragbits.core.config import core_config
  from ragbits.core.prompt import Prompt
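
The audit package now splits tracing from metrics, so imports move from ragbits.core.audit to ragbits.core.audit.traces. Only the import line is shown by this diff; the with-block below follows ragbits' trace context-manager convention and should be treated as an assumption:

    from ragbits.core.audit.traces import trace

    with trace("discover_prompts") as outputs:
        # attributes set on `outputs` are recorded on the finished span
        outputs.prompt_count = 3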
 
ragbits/core/prompt/exceptions.py
@@ -8,12 +8,28 @@ class PromptError(Exception):
          self.message = message


- class PromptWithImagesOfInvalidFormat(PromptError):
+ class PromptWithAttachmentOfUnknownFormat(PromptError):
      """
-     Raised when there is an image attached to the prompt that is not in the correct format.
+     Raised when there is a file with an unknown format attached to the prompt.
      """

-     def __init__(
-         self, message: str = "Invalid format of image in prompt detected. Use one of supported OpenAI mime types"
-     ) -> None:
-         super().__init__(message)
+     def __init__(self) -> None:
+         super().__init__("Could not determine MIME type for the attachment file")
+
+
+ class PromptWithAttachmentOfUnsupportedFormat(PromptError):
+     """
+     Raised when there is a file with an unsupported format attached to the prompt.
+     """
+
+     def __init__(self, mime_type: str) -> None:
+         super().__init__(f"Unsupported MIME type for the attachment file: {mime_type}")
+
+
+ class PromptWithEmptyAttachment(PromptError):
+     """
+     Raised when there is an empty file attached to the prompt.
+     """
+
+     def __init__(self) -> None:
+         super().__init__("Attachment must have either bytes data or URL provided")
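
A hedged sketch of where the three new exceptions surface: they are raised by Prompt.create_message_with_attachment (shown later in this diff); the Attachment import path and the local file name are assumptions:

    from ragbits.core.prompt import Prompt
    from ragbits.core.prompt.exceptions import (
        PromptWithAttachmentOfUnknownFormat,
        PromptWithAttachmentOfUnsupportedFormat,
        PromptWithEmptyAttachment,
    )
    from ragbits.core.prompt.prompt import Attachment

    try:
        part = Prompt.create_message_with_attachment(Attachment(data=open("blob.bin", "rb").read()))
    except PromptWithEmptyAttachment:
        ...  # neither bytes data nor a URL was provided
    except PromptWithAttachmentOfUnknownFormat:
        ...  # no MIME type was given and none could be detected
    except PromptWithAttachmentOfUnsupportedFormat as exc:
        ...  # exc.message names the offending MIME type, e.g. audio/mpeg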
ragbits/core/prompt/prompt.py
@@ -1,24 +1,38 @@
  import asyncio
  import base64
- import imghdr
+ import mimetypes
  import textwrap
+ import warnings
  from abc import ABCMeta
  from collections.abc import Awaitable, Callable
  from typing import Any, Generic, cast, get_args, get_origin, overload

+ import filetype
  from jinja2 import Environment, Template, meta
  from pydantic import BaseModel
  from typing_extensions import TypeVar, get_original_bases

- from ragbits.core.prompt.base import BasePromptWithParser, ChatFormat, OutputT
- from ragbits.core.prompt.exceptions import PromptWithImagesOfInvalidFormat
+ from ragbits.core.prompt.base import BasePromptWithParser, ChatFormat, PromptOutputT
+ from ragbits.core.prompt.exceptions import (
+     PromptWithAttachmentOfUnknownFormat,
+     PromptWithAttachmentOfUnsupportedFormat,
+     PromptWithEmptyAttachment,
+ )
  from ragbits.core.prompt.parsers import DEFAULT_PARSERS, build_pydantic_parser

- InputT = TypeVar("InputT", bound=BaseModel | None)
- FewShotExample = tuple[str | InputT, str | OutputT]
+ PromptInputT = TypeVar("PromptInputT", bound=BaseModel | None)
+ FewShotExample = tuple[str | PromptInputT, str | PromptOutputT]


- class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=ABCMeta):
+ class Attachment(BaseModel):
+     """Represents an attachment that can be passed to a LLM."""
+
+     url: str | None = None
+     data: bytes | None = None
+     mime_type: str | None = None
+
+
+ class Prompt(Generic[PromptInputT, PromptOutputT], BasePromptWithParser[PromptOutputT], metaclass=ABCMeta):
      """
      Generic class for prompts. It contains the system and user prompts, and additional messages.
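
The new Attachment model replaces bare str/bytes images in prompt inputs. A minimal construction sketch (URL and bytes illustrative):

    from ragbits.core.prompt.prompt import Attachment

    logo = Attachment(url="https://example.com/logo.png")  # MIME type later guessed from the extension
    report = Attachment(data=b"%PDF-1.7 ...", mime_type="application/pdf")  # explicit MIME type, no detection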
 
@@ -31,15 +45,15 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=

      # Additional messages to be added to the conversation after the system prompt,
      # pairs of user message and assistant response
-     few_shots: list[FewShotExample[InputT, OutputT]] = []
+     few_shots: list[FewShotExample[PromptInputT, PromptOutputT]] = []

      # function that parses the response from the LLM to specific output type
      # if not provided, the class tries to set it automatically based on the output type
-     response_parser: Callable[[str], OutputT | Awaitable[OutputT]]
+     response_parser: Callable[[str], PromptOutputT | Awaitable[PromptOutputT]]

      # Automatically set in __init_subclass__
-     input_type: type[InputT] | None
-     output_type: type[OutputT]
+     input_type: type[PromptInputT] | None
+     output_type: type[PromptOutputT]
      system_prompt_template: Template | None
      user_prompt_template: Template
      image_input_fields: list[str] | None = None
@@ -72,7 +86,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
          return Template(template)

      @classmethod
-     def _render_template(cls, template: Template, input_data: InputT | None) -> str:
+     def _render_template(cls, template: Template, input_data: PromptInputT | None) -> str:
          # Workaround for not being able to use `input is not None`
          # because of mypy issue: https://github.com/python/mypy/issues/12622
          context = {}
@@ -81,29 +95,39 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
          return template.render(**context)

      @classmethod
-     def _get_images_from_input_data(cls, input_data: InputT | None | str) -> list[bytes | str]:
-         images: list[bytes | str] = []
+     def _get_attachments_from_input_data(cls, input_data: PromptInputT | None | str) -> list[Attachment]:
+         attachments: list[Attachment] = []
+
          if isinstance(input_data, BaseModel):
+             # to support backward compatibility with the old image_input_fields:
              image_input_fields = cls.image_input_fields or []
              for field in image_input_fields:
-                 images_for_field = getattr(input_data, field)
-                 if images_for_field:
-                     if isinstance(images_for_field, list | tuple):
-                         images.extend(images_for_field)
-                     else:
-                         images.append(images_for_field)
-         return images
+                 if image_for_field := getattr(input_data, field):
+                     iter_image = [image_for_field] if isinstance(image_for_field, (str | bytes)) else image_for_field
+                     attachments.extend(
+                         [
+                             Attachment(url=image) if isinstance(image, str) else Attachment(data=image)
+                             for image in iter_image
+                         ]
+                     )
+             for value in input_data.__dict__.values():
+                 if isinstance(value, Attachment):
+                     attachments.append(value)
+                 elif isinstance(value, list):
+                     attachments.extend([item for item in value if isinstance(item, Attachment)])
+
+         return attachments

      @classmethod
      def _format_message(cls, message: str) -> str:
          return textwrap.dedent(message).strip()

      @classmethod
-     def _detect_response_parser(cls) -> Callable[[str], OutputT | Awaitable[OutputT]]:
+     def _detect_response_parser(cls) -> Callable[[str], PromptOutputT | Awaitable[PromptOutputT]]:
          if hasattr(cls, "response_parser") and cls.response_parser is not None:
              return cls.response_parser
          if issubclass(cls.output_type, BaseModel):
-             return cast(Callable[[str], OutputT], build_pydantic_parser(cls.output_type))
+             return cast(Callable[[str], PromptOutputT], build_pydantic_parser(cls.output_type))
          if cls.output_type in DEFAULT_PARSERS:
              return DEFAULT_PARSERS[cls.output_type]
          raise ValueError(f"Response parser not provided for output type {cls.output_type}")
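
With the rewrite above, any Attachment-typed field of the input model is collected automatically, scalar or list, while image_input_fields survives only for backward compatibility. A sketch of an input model this method would handle (field names are assumptions):

    from pydantic import BaseModel

    from ragbits.core.prompt.prompt import Attachment

    class TicketInput(BaseModel):
        question: str
        screenshot: Attachment | None = None  # picked up by the __dict__ scan
        manuals: list[Attachment] = []  # every Attachment inside a list is collected too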
@@ -123,28 +147,58 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
          return super().__init_subclass__(**kwargs)

      @overload
-     def __init__(self: "Prompt[None, OutputT]") -> None: ...
+     def __init__(
+         self: "Prompt[None, PromptOutputT]", input_data: None = None, history: ChatFormat | None = None
+     ) -> None: ...

      @overload
-     def __init__(self: "Prompt[InputT, OutputT]", input_data: InputT) -> None: ...
+     def __init__(
+         self: "Prompt[PromptInputT, PromptOutputT]", input_data: PromptInputT, history: ChatFormat | None = None
+     ) -> None: ...
+
+     def __init__(self, input_data: PromptInputT | None = None, history: ChatFormat | None = None) -> None:
+         """
+         Initialize the Prompt instance.

-     def __init__(self, *args: Any, **kwargs: Any) -> None:
-         input_data = args[0] if args else kwargs.get("input_data")
+         Args:
+             input_data: The input data to render the prompt templates with. Must be a Pydantic model
+                 instance if the prompt has an input type defined. If None and input_type is defined,
+                 a ValueError will be raised.
+             history: Optional conversation history to initialize the prompt with. If provided,
+                 should be in the standard OpenAI chat format.
+
+         Raises:
+             ValueError: If input_data is None when input_type is defined, or if input_data
+                 is a string instead of a Pydantic model.
+         """
          if self.input_type and input_data is None:
              raise ValueError("Input data must be provided")

+         if isinstance(input_data, str):
+             raise ValueError("Input data must be of pydantic model type")
+
+         if self.image_input_fields:
+             warnings.warn(
+                 message="The 'image_input_fields' attribute is deprecated. "
+                 "Use 'Attachment' objects in the prompt input instead.",
+                 category=UserWarning,
+                 stacklevel=2,
+             )
+
          self.rendered_system_prompt = (
              self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None
          )
-         self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data)
-         self.images = self._get_images_from_input_data(input_data)
+         self.attachments = self._get_attachments_from_input_data(input_data)

          # Additional few shot examples that can be added dynamically using methods
          # (in opposite to the static `few_shots` attribute which is defined in the class)
-         self._instance_few_shots: list[FewShotExample[InputT, OutputT]] = []
+         self._instance_few_shots: list[FewShotExample[PromptInputT, PromptOutputT]] = []

          # Additional conversation history that can be added dynamically using methods
-         self._conversation_history: list[dict[str, Any]] = []
+         self._conversation_history: list[dict[str, Any]] = history or []
+
+         self.add_user_message(input_data or self._render_template(self.user_prompt_template, input_data))
+         self.rendered_user_prompt = self.chat[-1]["content"]
          super().__init__()

      @property
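
The constructor now also accepts prior turns via history (standard OpenAI chat format) and appends the freshly rendered user message after them. A minimal sketch; QAInput, QAPrompt, and the system_prompt/user_prompt class attributes follow the usual ragbits Prompt-subclass convention and are assumptions here:

    from pydantic import BaseModel

    from ragbits.core.prompt import Prompt

    class QAInput(BaseModel):
        question: str

    class QAPrompt(Prompt[QAInput, str]):
        system_prompt = "Answer concisely."
        user_prompt = "{{ question }}"

    prior_turns = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi! How can I help?"},
    ]
    prompt = QAPrompt(QAInput(question="What changed in 1.4.0?"), history=prior_turns)
    # prompt.chat -> system message, the two prior turns, then the rendered user message.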
@@ -155,12 +209,6 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
          Returns:
              ChatFormat: A list of dictionaries, each containing the role and content of a message.
          """
-         user_content = (
-             [{"type": "text", "text": self.rendered_user_prompt}]
-             + [self._create_message_with_image(image) for image in self.images]
-             if self.images
-             else self.rendered_user_prompt
-         )
          chat = [
              *(
                  [{"role": "system", "content": self.rendered_system_prompt}]
@@ -168,23 +216,24 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
                  else []
              ),
              *self.list_few_shots(),
-             {"role": "user", "content": user_content},
              *self._conversation_history,
          ]
          return chat

-     def add_few_shot(self, user_message: str | InputT, assistant_message: str | OutputT) -> "Prompt[InputT, OutputT]":
+     def add_few_shot(
+         self, user_message: str | PromptInputT, assistant_message: str | PromptOutputT
+     ) -> "Prompt[PromptInputT, PromptOutputT]":
          """
          Add a few-shot example to the conversation.

          Args:
-             user_message (str | InputT): The raw user message or input data that will be rendered using the
+             user_message (str | PromptInputT): The raw user message or input data that will be rendered using the
                  user prompt template.
-             assistant_message (str | OutputT): The raw assistant response or output data that will be cast to a string
-                 or in case of a Pydantic model, to JSON.
+             assistant_message (str | PromptOutputT): The raw assistant response or output data that will be cast to a
+                 string or in case of a Pydantic model, to JSON.

          Returns:
-             Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
+             Prompt[PromptInputT, PromptOutputT]: The current prompt instance in order to allow chaining.
          """
          self._instance_few_shots.append((user_message, assistant_message))
          return self
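
add_few_shot itself only changes its TypeVar spelling; a quick chaining sketch, reusing the hypothetical QAPrompt from the earlier sketch:

    prompt = QAPrompt(QAInput(question="Capital of France?"))
    prompt.add_few_shot("Capital of Poland?", "Warsaw").add_few_shot("Capital of Spain?", "Madrid")
    # list_few_shots() renders each pair as a user/assistant message pair placed before the history.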
@@ -201,13 +250,14 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
          for user_message, assistant_message in self.few_shots + self._instance_few_shots:
              if not isinstance(user_message, str):
                  rendered_text_message = self._render_template(self.user_prompt_template, user_message)
-                 images_in_input_data = self._get_images_from_input_data(user_message)
-                 if images_in_input_data:
-                     user_content = [{"type": "text", "text": rendered_text_message}] + [
-                         self._create_message_with_image(image) for image in images_in_input_data
-                     ]
-                 else:
-                     user_content = rendered_text_message
+                 input_attachments = self._get_attachments_from_input_data(user_message)
+
+                 user_parts: list[dict[str, Any]] = [{"type": "text", "text": rendered_text_message}]
+                 for attachment in input_attachments:
+                     user_parts.append(self.create_message_with_attachment(attachment))
+
+                 user_content = user_parts if len(user_parts) > 1 else rendered_text_message
+
              else:
                  user_content = user_message

@@ -220,56 +270,38 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
              result.append({"role": "assistant", "content": assistant_content})
          return result

-     def add_user_message(self, message: str | dict[str, Any] | InputT) -> "Prompt[InputT, OutputT]":
+     def add_user_message(self, message: str | dict[str, Any] | PromptInputT) -> "Prompt[PromptInputT, PromptOutputT]":  # type: ignore
          """
          Add a user message to the conversation history.

          Args:
-             message (str | dict[str, Any] | InputT): The user message content. Can be:
+             message (str | dict[str, Any] | PromptInputT): The user message content. Can be:
                  - A string: Used directly as content
                  - A dictionary: With format {"type": "text", "text": "message"} or image content
-                 - An InputT model: Will be rendered using the user prompt template
+                 - An PromptInputT model: Will be rendered using the user prompt template

          Returns:
-             Prompt[InputT, OutputT]: The current prompt instance to allow chaining.
+             Prompt[PromptInputT, PromptOutputT]: The current prompt instance to allow chaining.
          """
-         content: str | list[dict[str, Any]] | dict[str, Any] | InputT
+         content: str | list[dict[str, Any]] | dict[str, Any]

          if isinstance(message, BaseModel):
-             # Type checking to ensure we're passing InputT to the methods
-             input_model: InputT = cast(InputT, message)
+             # Type checking to ensure we're passing PromptInputT to the methods
+             input_model: PromptInputT = cast(PromptInputT, message)

              # Render the message using the template if it's an input model
              rendered_text = self._render_template(self.user_prompt_template, input_model)
-             images_in_input = self._get_images_from_input_data(input_model)
+             input_attachments = self._get_attachments_from_input_data(input_model)

-             if images_in_input:
-                 content = [{"type": "text", "text": rendered_text}] + [
-                     self._create_message_with_image(image) for image in images_in_input
-                 ]
-             else:
-                 content = rendered_text
-         else:
-             # Use the message directly if it's a string or dict
-             content = message
-
-         self._conversation_history.append({"role": "user", "content": content})
-         return self
-
-     def add_assistant_message(self, message: str | OutputT) -> "Prompt[InputT, OutputT]":
-         """
-         Add an assistant message to the conversation history.
+             content_list: list[dict[str, Any]] = [{"type": "text", "text": rendered_text}]
+             for attachment in input_attachments:
+                 content_list.append(self.create_message_with_attachment(attachment))

-         Args:
-             message (str): The assistant message content.
+             content = content_list if len(content_list) > 1 else rendered_text
+         else:
+             content = cast(str | dict[str, Any], message)

-         Returns:
-             Prompt[InputT, OutputT]: The current prompt instance to allow chaining.
-         """
-         if isinstance(message, BaseModel):
-             message = message.model_dump_json()
-         self._conversation_history.append({"role": "assistant", "content": str(message)})
-         return self
+         return super().add_user_message(content)

      def list_images(self) -> list[str]:
          """
@@ -281,25 +313,75 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
          return [
              content["image_url"]["url"]
              for message in self.chat
+             if message["content"]
              for content in message["content"]
              if isinstance(message["content"], list) and content["type"] == "image_url"
          ]

+     def list_pdfs(self) -> list[str]:  # noqa: PLR6301
+         """
+         Returns the PDFs in form of URLs or base64 encoded strings.
+
+         Returns:
+             list of PDFs
+         """
+         return [
+             content["file"].get("file_id") or content["file"]["file_data"]
+             for message in self.chat
+             if message["content"]
+             for content in message["content"]
+             if isinstance(message["content"], list) and content["type"] == "file"
+         ]
+
      @staticmethod
-     def _create_message_with_image(image: str | bytes) -> dict:
-         if isinstance(image, bytes):
-             image_type = imghdr.what(None, image)
-             if not image_type:
-                 raise PromptWithImagesOfInvalidFormat()
-             image_url = f"data:image/{image_type};base64,{base64.b64encode(image).decode('utf-8')}"
-         else:
-             image_url = image
-         return {
-             "type": "image_url",
-             "image_url": {
-                 "url": image_url,
-             },
-         }
+     def create_message_with_attachment(attachment: Attachment) -> dict[str, Any]:
+         """
+         Create a message with an attachment in the OpenAI chat format.
+
+         Args:
+             attachment (Attachment): The attachment to include in the message.
+
+         Returns:
+             dict[str, Any]: A dictionary representing the message with the attachment.
+         """
+         if not (attachment.data or attachment.url):
+             raise PromptWithEmptyAttachment()
+
+         def get_mime_type() -> str:
+             if attachment.mime_type:
+                 return attachment.mime_type
+             if attachment.data:
+                 detected = filetype.guess(attachment.data)
+                 if detected:
+                     return detected.mime
+             if attachment.url:
+                 guessed_type, _ = mimetypes.guess_type(attachment.url)
+                 if guessed_type:
+                     return guessed_type
+             raise PromptWithAttachmentOfUnknownFormat()
+
+         def encode_data_url(data: bytes, mime: str) -> str:
+             return f"data:{mime};base64,{base64.b64encode(data).decode('utf-8')}"
+
+         mime_type = get_mime_type()
+
+         if mime_type.startswith("image/"):
+             return {
+                 "type": "image_url",
+                 "image_url": {
+                     "url": attachment.url or encode_data_url(attachment.data, mime_type)  # type: ignore[arg-type]
+                 },
+             }
+
+         if mime_type == "application/pdf":
+             return {
+                 "type": "file",
+                 "file": {"file_id": attachment.url}
+                 if attachment.url
+                 else {"file_data": encode_data_url(attachment.data, mime_type)},  # type: ignore[arg-type]
+             }
+
+         raise PromptWithAttachmentOfUnsupportedFormat(mime_type)
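
The two message-part shapes create_message_with_attachment emits, following the branches above (URL, file contents, and the truncated base64 payload are illustrative):

    from ragbits.core.prompt import Prompt
    from ragbits.core.prompt.prompt import Attachment

    image_part = Prompt.create_message_with_attachment(Attachment(url="https://example.com/cat.png"))
    # -> {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}

    pdf_part = Prompt.create_message_with_attachment(
        Attachment(data=open("report.pdf", "rb").read(), mime_type="application/pdf")
    )
    # -> {"type": "file", "file": {"file_data": "data:application/pdf;base64,..."}}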

      def output_schema(self) -> dict | type[BaseModel] | None:
          """
@@ -321,7 +403,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
          """
          return issubclass(self.output_type, BaseModel)

-     async def parse_response(self, response: str) -> OutputT:
+     async def parse_response(self, response: str) -> PromptOutputT:
          """
          Parse the response from the LLM to the desired output type.

@@ -329,7 +411,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
              response (str): The response from the LLM.

          Returns:
-             OutputT: The parsed response.
+             PromptOutputT: The parsed response.

          Raises:
              ResponseParsingError: If the response cannot be parsed.
ragbits/core/sources/__init__.py
@@ -6,11 +6,13 @@ from ragbits.core.sources.hf import HuggingFaceSource
  from ragbits.core.sources.local import LocalFileSource
  from ragbits.core.sources.s3 import S3Source
  from ragbits.core.sources.web import WebSource
+ from ragbits.core.sources.google_drive import GoogleDriveSource

  __all__ = [
      "AzureBlobStorageSource",
      "GCSSource",
      "GitSource",
+     "GoogleDriveSource",
      "HuggingFaceSource",
      "LocalFileSource",
      "S3Source",