ragbits-core 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
File without changes
@@ -0,0 +1,19 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
class Embeddings(ABC):
    """
    Common interface implemented by every embedding-model client.
    """

    @abstractmethod
    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Computes embeddings for the provided strings.

        Args:
            data: Strings that should be embedded.

        Returns:
            Embeddings computed for the given strings.
        """
@@ -0,0 +1,36 @@
1
class EmbeddingError(Exception):
    """
    Root of the exception hierarchy raised by the embedding clients.
    """

    def __init__(self, message: str) -> None:
        # Keep the text available as an attribute in addition to `args`.
        self.message = message
        super().__init__(message)
9
+
10
+
11
class EmbeddingConnectionError(EmbeddingError):
    """
    Signals that the embedding API could not be reached.
    """

    def __init__(self, message: str = "Connection error.") -> None:
        # Only the default message differs from the base class.
        super().__init__(message)
18
+
19
+
20
class EmbeddingStatusError(EmbeddingError):
    """
    Raised when an API response has a status code of 4xx or 5xx.
    """

    def __init__(self, message: str, status_code: int) -> None:
        """
        Args:
            message: Human-readable description of the failure.
            status_code: HTTP status code returned by the API.
        """
        super().__init__(message)
        self.status_code = status_code

    def __reduce__(self) -> tuple:
        """
        Makes the exception picklable and copyable.

        Only `message` is forwarded to `Exception.__init__`, so the default reduce
        implementation would rebuild the instance with a single argument and fail
        on the missing `status_code`. Both constructor arguments are supplied here.
        """
        return (type(self), (self.message, self.status_code), self.__dict__)
28
+
29
+
30
class EmbeddingResponseError(EmbeddingError):
    """
    Signals that the API response did not match the expected schema.
    """

    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
        # Only the default message differs from the base class.
        super().__init__(message)
@@ -0,0 +1,85 @@
1
+ from typing import Optional
2
+
3
+ try:
4
+ import litellm
5
+
6
+ HAS_LITELLM = True
7
+ except ImportError:
8
+ HAS_LITELLM = False
9
+
10
+ from ragbits.core.embeddings.base import Embeddings
11
+ from ragbits.core.embeddings.exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError
12
+
13
+
14
class LiteLLMEmbeddings(Embeddings):
    """
    Client for creating text embeddings using LiteLLM API.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        options: Optional[dict] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
    ) -> None:
        """
        Constructs the LiteLLMEmbeddings client.

        Args:
            model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)
                to be used. Default is "text-embedding-3-small".
            options: Additional options to pass to the LiteLLM API.
            api_base: The API endpoint you want to call the model with.
            api_key: API key to be used. If not specified, an environment variable will be used,
                for more information, follow the instructions for your specific vendor in the
                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
            api_version: The API version for the call.

        Raises:
            ImportError: If the 'litellm' extra requirements are not installed.
        """
        if not HAS_LITELLM:
            raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models")

        super().__init__()
        self.model = model
        self.options = options or {}
        self.api_base = api_base
        self.api_key = api_key
        self.api_version = api_version

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Creates embeddings for the given strings.

        Args:
            data: List of strings to get embeddings for.

        Returns:
            List of embeddings for the given strings. An empty input yields an empty list
            without calling the API.

        Raises:
            EmbeddingConnectionError: If there is a connection error with the embedding API.
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        if not data:
            # Nothing to embed: skip the network round trip instead of sending an empty request.
            return []

        try:
            response = await litellm.aembedding(
                input=data,
                model=self.model,
                api_base=self.api_base,
                api_key=self.api_key,
                api_version=self.api_version,
                **self.options,
            )
        except litellm.openai.APIConnectionError as exc:
            raise EmbeddingConnectionError() from exc
        except litellm.openai.APIStatusError as exc:
            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
        except litellm.openai.APIResponseValidationError as exc:
            raise EmbeddingResponseError() from exc

        # Each response item carries its vector under the "embedding" key.
        return [embedding["embedding"] for embedding in response.data]
@@ -0,0 +1,81 @@
1
+ from typing import Iterator, Optional
2
+
3
+ try:
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers import AutoModel, AutoTokenizer
7
+
8
+ HAS_LOCAL_EMBEDDINGS = True
9
+ except ImportError:
10
+ HAS_LOCAL_EMBEDDINGS = False
11
+
12
+ from ragbits.core.embeddings.base import Embeddings
13
+
14
+
15
class LocalEmbeddings(Embeddings):
    """
    Class for interaction with any encoder available in HuggingFace.
    """

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Constructs a new local embeddings instance and loads the model and tokenizer.

        Args:
            model_name: Name of the model to use.
            api_key: The API key for Hugging Face authentication.

        Raises:
            ImportError: If the 'local' extra requirements are not installed.
        """
        if not HAS_LOCAL_EMBEDDINGS:
            raise ImportError("You need to install the 'local' extra requirements to use local embeddings models")

        super().__init__()

        self.hf_api_key = api_key
        self.model_name = model_name

        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key)

    async def embed_text(self, data: list[str], batch_size: int = 1) -> list[list[float]]:
        """
        Embeds the given strings with the local encoder, processing them in batches.

        Each batch is tokenized (padded and truncated to the tokenizer's maximum length),
        passed through the model without gradient tracking, mean-pooled over the attention
        mask and L2-normalized.

        Args:
            data: List of strings to get embeddings for.
            batch_size: Number of strings encoded per forward pass. Must be at least 1.

        Returns:
            List of embeddings for the given strings.

        Raises:
            ValueError: If `batch_size` is smaller than 1.
        """
        if batch_size < 1:
            # A zero step makes range() fail and a negative one would silently drop all the data.
            raise ValueError("batch_size must be a positive integer")

        embeddings = []
        for batch in self._batch(data, batch_size):
            batch_dict = self.tokenizer(
                batch, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**batch_dict)
                batch_embeddings = self._average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
                batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            embeddings.extend(batch_embeddings.to("cpu").tolist())

        torch.cuda.empty_cache()
        return embeddings

    @staticmethod
    def _batch(data: list[str], batch_size: int) -> Iterator[list[str]]:
        """
        Yields consecutive slices of `data` holding at most `batch_size` items each.
        """
        for start in range(0, len(data), batch_size):
            # Slicing clamps at the end of the list, so the last batch may be shorter.
            yield data[start : start + batch_size]

    @staticmethod
    def _average_pool(last_hidden_states: "torch.Tensor", attention_mask: "torch.Tensor") -> "torch.Tensor":
        """
        Averages hidden states along dim 1, ignoring positions masked out by `attention_mask`.

        The annotations are strings so that this module can still be imported when torch
        is not installed (the import guard above only sets a flag in that case).
        """
        # presumably (batch, seq, hidden) states and a (batch, seq) mask - confirm against the model output
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
@@ -0,0 +1,5 @@
1
+ from .base import LLM
2
+ from .litellm import LiteLLM
3
+ from .local import LocalLLM
4
+
5
+ __all__ = ["LLM", "LiteLLM", "LocalLLM"]
@@ -0,0 +1,121 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import cached_property
3
+ from typing import Generic, Optional, Type, cast, overload
4
+
5
+ from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, OutputT
6
+
7
+ from .clients.base import LLMClient, LLMClientOptions, LLMOptions
8
+
9
+
10
class LLM(Generic[LLMClientOptions], ABC):
    """
    Abstract class for interaction with Large Language Model.
    """

    _options_cls: Type[LLMClientOptions]

    def __init__(self, model_name: str, default_options: Optional[LLMOptions] = None) -> None:
        """
        Constructs a new LLM instance.

        Args:
            model_name: Name of the model to be used.
            default_options: Default options to be used. When not provided, a new instance
                of the class' `_options_cls` is created.
        """
        self.model_name = model_name
        self.default_options = default_options or self._options_cls()

    def __init_subclass__(cls, **kwargs: object) -> None:
        """
        Validates each subclass when it is defined.

        Raises:
            TypeError: If the subclass is missing the '_options_cls' attribute.
        """
        # Chain to the next hook in the MRO; without this call the subclass hook of
        # typing.Generic would never run for classes derived from LLM.
        super().__init_subclass__(**kwargs)
        if not hasattr(cls, "_options_cls"):
            raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute")

    @cached_property
    @abstractmethod
    def client(self) -> LLMClient:
        """
        Client for the LLM.
        """

    def count_tokens(self, prompt: BasePrompt) -> int:
        """
        Counts tokens in the prompt.

        This base implementation sums the length of the content of every chat message,
        i.e. it counts characters rather than model-specific tokens.

        Args:
            prompt: Formatted prompt template with conversation and response parsing configuration.

        Returns:
            Number of tokens in the prompt.
        """
        return sum(len(message["content"]) for message in prompt.chat)

    async def generate_raw(
        self,
        prompt: BasePrompt,
        *,
        options: Optional[LLMOptions] = None,
    ) -> str:
        """
        Prepares and sends a prompt to the LLM and returns the raw response (without parsing).

        Args:
            prompt: Formatted prompt template with conversation.
            options: Options to use for the LLM client. They are merged over the default
                options of this instance.

        Returns:
            Raw text response from LLM.
        """
        merged_options = self.default_options if options is None else (self.default_options | options)

        response = await self.client.call(
            conversation=prompt.chat,
            options=merged_options,
            json_mode=prompt.json_mode,
            output_schema=prompt.output_schema(),
        )

        return response

    @overload
    async def generate(
        self,
        prompt: BasePromptWithParser[OutputT],
        *,
        options: Optional[LLMOptions] = None,
    ) -> OutputT:
        ...

    @overload
    async def generate(
        self,
        prompt: BasePrompt,
        *,
        options: Optional[LLMOptions] = None,
    ) -> OutputT:
        ...

    async def generate(
        self,
        prompt: BasePrompt,
        *,
        options: Optional[LLMOptions] = None,
    ) -> OutputT:
        """
        Prepares and sends a prompt to the LLM and returns response parsed to the
        output type of the prompt (if available).

        Args:
            prompt: Formatted prompt template with conversation and optional response parsing configuration.
            options: Options to use for the LLM client.

        Returns:
            The parsed response when the prompt provides a parser, otherwise the raw text response.
        """
        response = await self.generate_raw(prompt, options=options)

        if isinstance(prompt, BasePromptWithParser):
            return prompt.parse_response(response)

        # No parser available: the raw string is handed back unchanged.
        return cast(OutputT, response)
@@ -0,0 +1,12 @@
1
+ from .base import LLMClient, LLMOptions
2
+ from .litellm import LiteLLMClient, LiteLLMOptions
3
+ from .local import LocalLLMClient, LocalLLMOptions
4
+
5
+ __all__ = [
6
+ "LLMClient",
7
+ "LLMOptions",
8
+ "LiteLLMClient",
9
+ "LiteLLMOptions",
10
+ "LocalLLMClient",
11
+ "LocalLLMOptions",
12
+ ]
@@ -0,0 +1,86 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import asdict, dataclass
3
+ from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from ragbits.core.prompt import ChatFormat
8
+
9
+ from ..types import NotGiven
10
+
11
+ LLMClientOptions = TypeVar("LLMClientOptions", bound="LLMOptions")
12
+
13
+
14
@dataclass
class LLMOptions(ABC):
    """
    Abstract base dataclass describing the options available for an LLM call.
    """

    # Value substituted for unset options in `dict()`.
    _not_given: ClassVar[Any] = None

    def __or__(self, other: "LLMOptions") -> "LLMOptions":
        """
        Combines two option sets; a value from `other` wins unless it is a NotGiven instance.
        """
        base_values = asdict(self)
        override_values = asdict(other)

        merged = {}
        for name, current in base_values.items():
            if isinstance(override_values.get(name), NotGiven):
                # The override left this option unset, keep our own value.
                merged[name] = current
            else:
                merged[name] = override_values.get(name, current)

        return self.__class__(**merged)

    def dict(self) -> Dict[str, Any]:
        """
        Converts the options to a plain dictionary.
        Values that are None or NotGiven are replaced with the class' `_not_given` sentinel.

        Returns:
            Dictionary with one entry per option field.
        """
        result = {}
        for name, value in asdict(self).items():
            unset = value is None or isinstance(value, NotGiven)
            result[name] = self._not_given if unset else value
        return result
51
+
52
+
53
class LLMClient(Generic[LLMClientOptions], ABC):
    """
    Base class of the clients that talk directly to an LLM.
    """

    def __init__(self, model_name: str) -> None:
        """
        Creates the client.

        Args:
            model_name: Name of the model the client should use.
        """
        self.model_name = model_name

    @abstractmethod
    async def call(
        self,
        conversation: ChatFormat,
        options: LLMClientOptions,
        json_mode: bool = False,
        output_schema: Optional[Type[BaseModel] | Dict] = None,
    ) -> str:
        """
        Runs inference through the LLM API.

        Args:
            conversation: Chat history so far, as a list of dicts with "role" and "content" keys.
            options: Additional settings used by the LLM.
            json_mode: Whether the response should be forced into JSON format.
            output_schema: Schema of a structured response (a Pydantic model or a JSON schema).

        Returns:
            The LLM response as a string.
        """
@@ -0,0 +1,36 @@
1
class LLMError(Exception):
    """
    Root of the exception hierarchy raised by the LLM clients.
    """

    def __init__(self, message: str) -> None:
        # Keep the text available as an attribute in addition to `args`.
        self.message = message
        super().__init__(message)
9
+
10
+
11
class LLMConnectionError(LLMError):
    """
    Signals that the LLM API could not be reached.
    """

    def __init__(self, message: str = "Connection error.") -> None:
        # Only the default message differs from the base class.
        super().__init__(message)
18
+
19
+
20
class LLMStatusError(LLMError):
    """
    Raised when an API response has a status code of 4xx or 5xx.
    """

    def __init__(self, message: str, status_code: int) -> None:
        """
        Args:
            message: Human-readable description of the failure.
            status_code: HTTP status code returned by the API.
        """
        super().__init__(message)
        self.status_code = status_code

    def __reduce__(self) -> tuple:
        """
        Makes the exception picklable and copyable.

        Only `message` is forwarded to `Exception.__init__`, so the default reduce
        implementation would rebuild the instance with a single argument and fail
        on the missing `status_code`. Both constructor arguments are supplied here.
        """
        return (type(self), (self.message, self.status_code), self.__dict__)
28
+
29
+
30
class LLMResponseError(LLMError):
    """
    Signals that the API response did not match the expected schema.
    """

    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
        # Only the default message differs from the base class.
        super().__init__(message)
@@ -0,0 +1,129 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, List, Optional, Type, Union
3
+
4
+ from pydantic import BaseModel
5
+
6
+ try:
7
+ import litellm
8
+
9
+ HAS_LITELLM = True
10
+ except ImportError:
11
+ HAS_LITELLM = False
12
+
13
+
14
+ from ragbits.core.prompt import ChatFormat
15
+
16
+ from ..types import NOT_GIVEN, NotGiven
17
+ from .base import LLMClient, LLMOptions
18
+ from .exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
19
+
20
+
21
@dataclass
class LiteLLMOptions(LLMOptions):
    """
    Options accepted by the LiteLLM client for a single LLM call.
    See the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/input) for their meaning.
    """

    # Every option defaults to NOT_GIVEN, i.e. it is left unset unless the caller provides a value.
    frequency_penalty: Union[float, None, NotGiven] = NOT_GIVEN
    max_tokens: Union[int, None, NotGiven] = NOT_GIVEN
    n: Union[int, None, NotGiven] = NOT_GIVEN
    presence_penalty: Union[float, None, NotGiven] = NOT_GIVEN
    seed: Union[int, None, NotGiven] = NOT_GIVEN
    stop: Union[str, List[str], None, NotGiven] = NOT_GIVEN
    temperature: Union[float, None, NotGiven] = NOT_GIVEN
    top_p: Union[float, None, NotGiven] = NOT_GIVEN
    mock_response: Union[str, None, NotGiven] = NOT_GIVEN
37
+
38
+
39
class LiteLLMClient(LLMClient[LiteLLMOptions]):
    """
    LLM client built on LiteLLM, which supports calls to 100+ LLMs APIs, including OpenAI,
    Anthropic, VertexAI, Hugging Face and others.
    """

    _options_cls = LiteLLMOptions

    def __init__(
        self,
        model_name: str,
        *,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        use_structured_output: bool = False,
    ) -> None:
        """
        Creates the LiteLLM client.

        Args:
            model_name: Name of the model to use.
            base_url: Base URL of the LLM API.
            api_key: API key used to authenticate with the LLM API.
            api_version: API version of the LLM API.
            use_structured_output: Whether to request a structured output from the model. Default is False.

        Raises:
            ImportError: If the 'litellm' extra requirements are not installed.
        """
        if not HAS_LITELLM:
            raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM models")

        super().__init__(model_name)
        self.base_url = base_url
        self.api_key = api_key
        self.api_version = api_version
        self.use_structured_output = use_structured_output

    async def call(
        self,
        conversation: ChatFormat,
        options: LiteLLMOptions,
        json_mode: bool = False,
        output_schema: Optional[Type[BaseModel] | Dict] = None,
    ) -> str:
        """
        Sends the conversation to the LLM endpoint and returns its answer.

        Args:
            conversation: Chat history so far, as a list of dicts with "role" and "content" keys.
            options: Additional settings used by the LLM.
            json_mode: Force the response to be in JSON format.
            output_schema: Output schema for requesting a specific response format.
                Only used if the client has been initialized with `use_structured_output=True`.

        Returns:
            Response string from LLM.

        Raises:
            LLMConnectionError: If there is a connection error with the LLM API.
            LLMStatusError: If the LLM API returns an error status code.
            LLMResponseError: If the LLM API response is invalid.
        """
        supported = litellm.get_supported_openai_params(model=self.model_name)
        # A response format is only requested when the model reports support for it.
        format_supported = supported is not None and "response_format" in supported

        response_format = None
        if format_supported and self.use_structured_output and output_schema is not None:
            # A structured-output schema takes precedence over plain JSON mode.
            response_format = output_schema
        elif format_supported and json_mode:
            response_format = {"type": "json_object"}

        try:
            response = await litellm.acompletion(
                messages=conversation,
                model=self.model_name,
                base_url=self.base_url,
                api_key=self.api_key,
                api_version=self.api_version,
                response_format=response_format,
                **options.dict(),
            )
        except litellm.openai.APIConnectionError as exc:
            raise LLMConnectionError() from exc
        except litellm.openai.APIStatusError as exc:
            raise LLMStatusError(exc.message, exc.status_code) from exc
        except litellm.openai.APIResponseValidationError as exc:
            raise LLMResponseError() from exc

        first_choice = response.choices[0]
        return first_choice.message.content