ragbits-core 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits_core-0.0.1/.gitignore +97 -0
- ragbits_core-0.0.1/PKG-INFO +30 -0
- ragbits_core-0.0.1/README.md +1 -0
- ragbits_core-0.0.1/examples/llm_example.py +62 -0
- ragbits_core-0.0.1/examples/prompt_example.py +52 -0
- ragbits_core-0.0.1/pyproject.toml +68 -0
- ragbits_core-0.0.1/src/py.typed +0 -0
- ragbits_core-0.0.1/src/ragbits/core/__init__.py +0 -0
- ragbits_core-0.0.1/src/ragbits/core/embeddings/__init__.py +0 -0
- ragbits_core-0.0.1/src/ragbits/core/embeddings/base.py +19 -0
- ragbits_core-0.0.1/src/ragbits/core/embeddings/exceptions.py +36 -0
- ragbits_core-0.0.1/src/ragbits/core/embeddings/litellm.py +85 -0
- ragbits_core-0.0.1/src/ragbits/core/embeddings/local.py +81 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/__init__.py +5 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/base.py +121 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/clients/__init__.py +12 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/clients/base.py +86 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/clients/exceptions.py +36 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/clients/litellm.py +129 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/clients/local.py +103 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/litellm.py +85 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/local.py +71 -0
- ragbits_core-0.0.1/src/ragbits/core/llms/types.py +31 -0
- ragbits_core-0.0.1/src/ragbits/core/prompt/__init__.py +3 -0
- ragbits_core-0.0.1/src/ragbits/core/prompt/base.py +60 -0
- ragbits_core-0.0.1/src/ragbits/core/prompt/parsers.py +130 -0
- ragbits_core-0.0.1/src/ragbits/core/prompt/prompt.py +192 -0
- ragbits_core-0.0.1/src/ragbits/py.typed +0 -0
- ragbits_core-0.0.1/tests/unit/__init__.py +0 -0
- ragbits_core-0.0.1/tests/unit/llms/__init__.py +0 -0
- ragbits_core-0.0.1/tests/unit/llms/test_litellm.py +142 -0
- ragbits_core-0.0.1/tests/unit/prompts/__init__.py +0 -0
- ragbits_core-0.0.1/tests/unit/prompts/test_parsers.py +156 -0
- ragbits_core-0.0.1/tests/unit/prompts/test_prompt.py +200 -0
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Directories
|
|
2
|
+
.vscode/
|
|
3
|
+
.idea/
|
|
4
|
+
.neptune/
|
|
5
|
+
.pytest_cache/
|
|
6
|
+
.mypy_cache/
|
|
7
|
+
venv/
|
|
8
|
+
__pycache__/
|
|
9
|
+
**.egg-info/
|
|
10
|
+
|
|
11
|
+
# Byte-compiled / optimized / DLL files
|
|
12
|
+
__pycache__/
|
|
13
|
+
*.py[cod]
|
|
14
|
+
*$py.class
|
|
15
|
+
|
|
16
|
+
# C extensions
|
|
17
|
+
*.so
|
|
18
|
+
|
|
19
|
+
# Distribution / packaging
|
|
20
|
+
.Python
|
|
21
|
+
env/
|
|
22
|
+
build/
|
|
23
|
+
develop-eggs/
|
|
24
|
+
dist/
|
|
25
|
+
downloads/
|
|
26
|
+
eggs/
|
|
27
|
+
.eggs/
|
|
28
|
+
lib/
|
|
29
|
+
lib64/
|
|
30
|
+
parts/
|
|
31
|
+
sdist/
|
|
32
|
+
var/
|
|
33
|
+
*.egg-info/
|
|
34
|
+
.installed.cfg
|
|
35
|
+
*.egg
|
|
36
|
+
|
|
37
|
+
# Sphinx documentation
|
|
38
|
+
docs/_build/
|
|
39
|
+
public/
|
|
40
|
+
# autogenerated package license table
|
|
41
|
+
docs/licenses_table.rst
|
|
42
|
+
|
|
43
|
+
# license dump file
|
|
44
|
+
licenses.txt
|
|
45
|
+
|
|
46
|
+
# File formats
|
|
47
|
+
*.onnx
|
|
48
|
+
*.pyc
|
|
49
|
+
*.pt
|
|
50
|
+
*.pth
|
|
51
|
+
*.pkl
|
|
52
|
+
*.mar
|
|
53
|
+
*.torchscript
|
|
54
|
+
**/.ipynb_checkpoints
|
|
55
|
+
**/dist/
|
|
56
|
+
**/checkpoints/
|
|
57
|
+
**/outputs/
|
|
58
|
+
**/multirun/
|
|
59
|
+
|
|
60
|
+
# Other env files
|
|
61
|
+
.python-version
|
|
62
|
+
pyvenv.cfg
|
|
63
|
+
pip-selfcheck.json
|
|
64
|
+
|
|
65
|
+
# Unit test / coverage reports
|
|
66
|
+
htmlcov/
|
|
67
|
+
.tox/
|
|
68
|
+
.coverage
|
|
69
|
+
.coverage.*
|
|
70
|
+
.cache
|
|
71
|
+
nosetests.xml
|
|
72
|
+
coverage.xml
|
|
73
|
+
*,cover
|
|
74
|
+
.hypothesis/
|
|
75
|
+
|
|
76
|
+
# dotenv
|
|
77
|
+
.env
|
|
78
|
+
|
|
79
|
+
# coverage and pytest reports
|
|
80
|
+
coverage.xml
|
|
81
|
+
report.xml
|
|
82
|
+
|
|
83
|
+
# CMake
|
|
84
|
+
cmake-build-*/
|
|
85
|
+
|
|
86
|
+
# Terraform
|
|
87
|
+
**/.terraform.lock.hcl
|
|
88
|
+
**/.terraform
|
|
89
|
+
|
|
90
|
+
# benchmarks
|
|
91
|
+
benchmarks/sql/data/
|
|
92
|
+
|
|
93
|
+
# mkdocs generated files
|
|
94
|
+
site/
|
|
95
|
+
|
|
96
|
+
# build artifacts
|
|
97
|
+
dist/
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: ragbits-core
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Building blocks for rapid development of GenAI applications
|
|
5
|
+
Author-email: "deepsense.ai" <contact@deepsense.ai>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Keywords: GenAI,Generative AI,LLMs,Large Language Models,Prompt Management,RAG,Retrieval Augmented Generation
|
|
8
|
+
Classifier: Development Status :: 1 - Planning
|
|
9
|
+
Classifier: Environment :: Console
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Natural Language :: English
|
|
13
|
+
Classifier: Operating System :: OS Independent
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
19
|
+
Requires-Python: >=3.10
|
|
20
|
+
Requires-Dist: jinja2>=3.1.4
|
|
21
|
+
Requires-Dist: pydantic>=2.9.1
|
|
22
|
+
Provides-Extra: litellm
|
|
23
|
+
Requires-Dist: litellm~=1.46.0; extra == 'litellm'
|
|
24
|
+
Provides-Extra: local
|
|
25
|
+
Requires-Dist: numpy~=1.24.0; extra == 'local'
|
|
26
|
+
Requires-Dist: torch~=2.2.1; extra == 'local'
|
|
27
|
+
Requires-Dist: transformers~=4.44.2; extra == 'local'
|
|
28
|
+
Description-Content-Type: text/markdown
|
|
29
|
+
|
|
30
|
+
# Ragbits Core
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Ragbits Core
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# requires-python = ">=3.10"
|
|
3
|
+
# dependencies = [
|
|
4
|
+
# "ragbits[litellm]",
|
|
5
|
+
# ]
|
|
6
|
+
# ///
|
|
7
|
+
import asyncio
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from ragbits.core.llms.litellm import LiteLLM
|
|
12
|
+
from ragbits.core.prompt import Prompt
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LoremPromptInput(BaseModel):
    """
    Input format for the `JokePrompt` example prompt.

    NOTE(review): the name says "LoremPrompt" but this model is consumed by
    `JokePrompt` below — the class appears to have been copied from the
    prompt_example script.
    """

    # Topic the generated joke should be about.
    theme: str
    # Whether the model is allowed to use puns (drives the Jinja branch
    # in the system prompt).
    pun_allowed: bool = False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LoremPromptOutput(BaseModel):
    """
    Output format for the `JokePrompt` example prompt.

    NOTE(review): the name says "LoremPrompt" but this model is the declared
    output of `JokePrompt` below.
    """

    # The generated joke text.
    joke: str
    # Category the model assigned to the joke.
    joke_category: str
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class JokePrompt(Prompt[LoremPromptInput, LoremPromptOutput]):
    """
    A prompt that generates jokes.

    Both prompt strings are Jinja templates rendered with the fields of the
    input model: `pun_allowed` selects the system-prompt branch and `theme`
    fills the user prompt.
    """

    # The {% if not pun_allowed %} branch toggles whether puns are permitted.
    system_prompt = """
    You are a joke generator. The jokes you generate should be funny and not offensive. {% if not pun_allowed %}Also, make sure
    that the jokes do not contain any puns.{% else %}You can use any type of joke, even if it contains puns.{% endif %}

    Respond as json with two fields: joke and joke_category.
    """

    user_prompt = """
    theme: {{ theme }}
    """
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
async def main():
    """
    Example of using the LiteLLM client with a Prompt class. Requires the
    OPENAI_API_KEY environment variable to be set.
    """
    # Structured output makes the client request a response matching the
    # prompt's declared output model.
    client = LiteLLM("gpt-4o-2024-08-06", use_structured_output=True)
    joke_prompt = JokePrompt(LoremPromptInput(theme="software developers", pun_allowed=True))
    result = await client.generate(joke_prompt)
    print(f"The LLM generated a (hopefully) funny {result.joke_category} joke:")
    print(result.joke)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Run the async example only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# /// script
|
|
2
|
+
# requires-python = ">=3.10"
|
|
3
|
+
# dependencies = [
|
|
4
|
+
# "ragbits",
|
|
5
|
+
# ]
|
|
6
|
+
# ///
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
|
|
9
|
+
from ragbits.core.prompt import Prompt
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LoremPromptInput(BaseModel):
    """
    Input format for the LoremPrompt.
    """

    # Theme that steers the generated pseudo-Latin vocabulary.
    theme: str
    # Toggles the safe-for-work instruction branch in the system prompt.
    nsfw_allowed: bool = False
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LoremPromptOutput(BaseModel):
    """
    Output format for the LoremPrompt.
    """

    # The generated Lorem Ipsum text.
    text: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LoremPrompt(Prompt[LoremPromptInput, LoremPromptOutput]):
    """
    A prompt that generates Lorem Ipsum text.

    Both prompt strings are Jinja templates rendered with the fields of
    `LoremPromptInput`: `nsfw_allowed` selects the system-prompt branch and
    `theme` fills the user prompt.
    """

    # Fixed typo in the instruction text: "vocablurary" -> "vocabulary".
    system_prompt = """
    You are a helpful Lorem Ipsum generator. The kind of vocabulary that you use besides "Lorem Ipsum" depends
    on the theme provided by the user. Make sure it is latin and not too long. {% if not nsfw_allowed %}Also, make sure
    that the text is safe for work.{% else %}You can use any text, even if it is not safe for work.{% endif %}
    """

    user_prompt = """
    theme: {{ theme }}
    """
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Demo: build the prompt for a sample input, seed the conversation with a
# canned assistant reply, then show the rendered chat and the output schema.
if __name__ == "__main__":
    lorem_prompt = LoremPrompt(LoremPromptInput(theme="business"))
    lorem_prompt.add_assistant_message("Lorem Ipsum biznessum dolor copy machinum yearly reportum")
    print("CHAT:")
    print(lorem_prompt.chat)
    print()
    print("OUTPUT MODEL:")
    print(lorem_prompt.output_schema())
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "ragbits-core"
|
|
3
|
+
version = "0.0.1"
|
|
4
|
+
description = "Building blocks for rapid development of GenAI applications"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.10"
|
|
7
|
+
license = "MIT"
|
|
8
|
+
authors = [
|
|
9
|
+
{ name = "deepsense.ai", email = "contact@deepsense.ai"}
|
|
10
|
+
]
|
|
11
|
+
keywords = [
|
|
12
|
+
"Retrieval Augmented Generation",
|
|
13
|
+
"RAG",
|
|
14
|
+
"Large Language Models",
|
|
15
|
+
"LLMs",
|
|
16
|
+
"Generative AI",
|
|
17
|
+
"GenAI",
|
|
18
|
+
"Prompt Management"
|
|
19
|
+
]
|
|
20
|
+
classifiers = [
|
|
21
|
+
"Development Status :: 1 - Planning",
|
|
22
|
+
"Environment :: Console",
|
|
23
|
+
"Intended Audience :: Science/Research",
|
|
24
|
+
"License :: OSI Approved :: MIT License",
|
|
25
|
+
"Natural Language :: English",
|
|
26
|
+
"Operating System :: OS Independent",
|
|
27
|
+
"Programming Language :: Python :: 3.10",
|
|
28
|
+
"Programming Language :: Python :: 3.11",
|
|
29
|
+
"Programming Language :: Python :: 3.12",
|
|
30
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
31
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
32
|
+
]
|
|
33
|
+
dependencies = [
|
|
34
|
+
"jinja2>=3.1.4",
|
|
35
|
+
"pydantic>=2.9.1"
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
[project.optional-dependencies]
|
|
39
|
+
litellm = [
|
|
40
|
+
"litellm~=1.46.0",
|
|
41
|
+
]
|
|
42
|
+
local = [
|
|
43
|
+
"torch~=2.2.1",
|
|
44
|
+
"transformers~=4.44.2",
|
|
45
|
+
"numpy~=1.24.0"
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
[tool.uv]
|
|
49
|
+
dev-dependencies = [
|
|
50
|
+
"pre-commit~=3.8.0",
|
|
51
|
+
"pytest~=8.3.3",
|
|
52
|
+
"pytest-cov~=5.0.0",
|
|
53
|
+
"pytest-asyncio~=0.24.0",
|
|
54
|
+
"pip-licenses>=4.0.0,<5.0.0"
|
|
55
|
+
]
|
|
56
|
+
|
|
57
|
+
[build-system]
|
|
58
|
+
requires = ["hatchling"]
|
|
59
|
+
build-backend = "hatchling.build"
|
|
60
|
+
|
|
61
|
+
[tool.hatch.metadata]
|
|
62
|
+
allow-direct-references = true
|
|
63
|
+
|
|
64
|
+
[tool.hatch.build.targets.wheel]
|
|
65
|
+
packages = ["src/ragbits"]
|
|
66
|
+
|
|
67
|
+
[tool.pytest.ini_options]
|
|
68
|
+
asyncio_mode = "auto"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Embeddings(ABC):
    """
    Base interface for clients that turn text into vector embeddings.
    """

    @abstractmethod
    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Compute an embedding vector for every string in ``data``.

        Args:
            data: Strings to embed.

        Returns:
            One embedding (a list of floats) per input string, in order.
        """
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
class EmbeddingError(Exception):
    """
    Root of the embedding-client exception hierarchy.

    The human-readable message is passed to ``Exception`` and additionally
    kept on the instance as ``message``.
    """

    def __init__(self, message: str) -> None:
        super().__init__(message)
        # Expose the message as an attribute for callers that prefer it
        # over str(exc).
        self.message = message
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EmbeddingConnectionError(EmbeddingError):
    """
    Raised when the embedding API cannot be reached.
    """

    def __init__(self, message: str = "Connection error.") -> None:
        # Default message keeps call sites terse: EmbeddingConnectionError().
        super().__init__(message)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EmbeddingStatusError(EmbeddingError):
    """
    Raised when an API response has a status code of 4xx or 5xx.

    The failing HTTP status code is kept on the instance as ``status_code``.
    """

    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message)
        # Preserved so callers can branch on the HTTP status.
        self.status_code = status_code
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class EmbeddingResponseError(EmbeddingError):
    """
    Raised when an API response has an invalid schema.
    """

    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
        # Default message keeps call sites terse: EmbeddingResponseError().
        super().__init__(message)
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
try:
|
|
4
|
+
import litellm
|
|
5
|
+
|
|
6
|
+
HAS_LITELLM = True
|
|
7
|
+
except ImportError:
|
|
8
|
+
HAS_LITELLM = False
|
|
9
|
+
|
|
10
|
+
from ragbits.core.embeddings.base import Embeddings
|
|
11
|
+
from ragbits.core.embeddings.exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class LiteLLMEmbeddings(Embeddings):
    """
    Client for creating text embeddings using LiteLLM API.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        options: Optional[dict] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
    ) -> None:
        """
        Constructs the LiteLLMEmbeddings client.

        Args:
            model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
                to be used. Default is "text-embedding-3-small".
            options: Additional options to pass to the LiteLLM API.
            api_base: The API endpoint you want to call the model with.
            api_key: API key to be used. If not specified, an environment variable will be used,
                for more information, follow the instructions for your specific vendor in the\
                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
            api_version: The API version for the call.

        Raises:
            ImportError: If the 'litellm' extra requirements are not installed.
        """
        # litellm is an optional dependency; fail fast with an actionable hint.
        if not HAS_LITELLM:
            raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models")

        super().__init__()
        self.model = model
        # Extra keyword arguments forwarded verbatim to litellm.aembedding.
        self.options = options or {}
        self.api_base = api_base
        self.api_key = api_key
        self.api_version = api_version

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Creates embeddings for the given strings.

        Args:
            data: List of strings to get embeddings for.

        Returns:
            List of embeddings for the given strings.

        Raises:
            EmbeddingConnectionError: If there is a connection error with the embedding API.
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """

        try:
            response = await litellm.aembedding(
                input=data,
                model=self.model,
                api_base=self.api_base,
                api_key=self.api_key,
                api_version=self.api_version,
                **self.options,
            )
        # Translate the OpenAI-style errors raised by litellm into the
        # package's own embedding exception hierarchy.
        except litellm.openai.APIConnectionError as exc:
            raise EmbeddingConnectionError() from exc
        except litellm.openai.APIStatusError as exc:
            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
        except litellm.openai.APIResponseValidationError as exc:
            raise EmbeddingResponseError() from exc

        # response.data holds one {"embedding": [...], ...} entry per input.
        return [embedding["embedding"] for embedding in response.data]
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
from typing import Iterator, Optional
|
|
2
|
+
|
|
3
|
+
try:
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from transformers import AutoModel, AutoTokenizer
|
|
7
|
+
|
|
8
|
+
HAS_LOCAL_EMBEDDINGS = True
|
|
9
|
+
except ImportError:
|
|
10
|
+
HAS_LOCAL_EMBEDDINGS = False
|
|
11
|
+
|
|
12
|
+
from ragbits.core.embeddings.base import Embeddings
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class LocalEmbeddings(Embeddings):
    """
    Class for interaction with any encoder available in HuggingFace.
    """

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Constructs a new LocalEmbeddings instance.

        Args:
            model_name: Name of the model to use.
            api_key: The API key for Hugging Face authentication.

        Raises:
            ImportError: If the 'local' extra requirements are not installed.
        """
        # torch/transformers are optional; fail fast with an actionable hint.
        if not HAS_LOCAL_EMBEDDINGS:
            raise ImportError("You need to install the 'local' extra requirements to use local embeddings models")

        super().__init__()

        self.hf_api_key = api_key
        self.model_name = model_name

        # Prefer GPU when available; the model and tokenized batches are
        # moved to the same device below.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key)

    async def embed_text(self, data: list[str], batch_size: int = 1) -> list[list[float]]:
        """
        Calls the appropriate encoder endpoint with the given data and options.

        Args:
            data: List of strings to get embeddings for.
            batch_size: Batch size.

        Returns:
            List of embeddings for the given strings.
        """
        embeddings = []
        for batch in self._batch(data, batch_size):
            # Tokenize with padding/truncation to the model's max length so the
            # batch forms a rectangular tensor.
            batch_dict = self.tokenizer(
                batch, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)
            # Inference only — no gradients needed.
            with torch.no_grad():
                outputs = self.model(**batch_dict)
                # Masked mean-pool over token positions, then L2-normalize.
                batch_embeddings = self._average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
                batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            embeddings.extend(batch_embeddings.to("cpu").tolist())

        # Release cached GPU memory after the loop (no-op when CUDA was
        # never used — TODO confirm on CPU-only setups).
        torch.cuda.empty_cache()
        return embeddings

    @staticmethod
    def _batch(data: list[str], batch_size: int) -> Iterator[list[str]]:
        # Yield consecutive slices of at most batch_size items.
        length = len(data)
        for ndx in range(0, length, batch_size):
            yield data[ndx : min(ndx + batch_size, length)]

    @staticmethod
    def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Zero out padding positions, then average the remaining token
        # states per sequence (sum / number of real tokens).
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from functools import cached_property
|
|
3
|
+
from typing import Generic, Optional, Type, cast, overload
|
|
4
|
+
|
|
5
|
+
from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, OutputT
|
|
6
|
+
|
|
7
|
+
from .clients.base import LLMClient, LLMClientOptions, LLMOptions
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class LLM(Generic[LLMClientOptions], ABC):
    """
    Abstract class for interaction with Large Language Model.
    """

    # Concrete subclasses must set this to their options dataclass; enforced
    # in __init_subclass__ and used to build default options in __init__.
    _options_cls: Type[LLMClientOptions]

    def __init__(self, model_name: str, default_options: Optional[LLMOptions] = None) -> None:
        """
        Constructs a new LLM instance.

        Args:
            model_name: Name of the model to be used.
            default_options: Default options to be used.

        Raises:
            TypeError: If the subclass is missing the '_options_cls' attribute.
        """
        self.model_name = model_name
        self.default_options = default_options or self._options_cls()

    def __init_subclass__(cls) -> None:
        # Fail at class-definition time (not first use) when a subclass
        # forgets to declare its options class.
        if not hasattr(cls, "_options_cls"):
            raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute")

    # cached_property: the concrete client is constructed once per LLM
    # instance on first access; abstractmethod forces subclasses to provide it.
    @cached_property
    @abstractmethod
    def client(self) -> LLMClient:
        """
        Client for the LLM.
        """

    def count_tokens(self, prompt: BasePrompt) -> int:
        """
        Counts tokens in the prompt.

        Args:
            prompt: Formatted prompt template with conversation and response parsing configuration.

        Returns:
            Number of tokens in the prompt.
        """
        # NOTE: this base implementation counts characters of the message
        # contents, not real model tokens — presumably a cheap approximation
        # that concrete subclasses override with a tokenizer; confirm.
        return sum(len(message["content"]) for message in prompt.chat)

    async def generate_raw(
        self,
        prompt: BasePrompt,
        *,
        options: Optional[LLMOptions] = None,
    ) -> str:
        """
        Prepares and sends a prompt to the LLM and returns the raw response (without parsing).

        Args:
            prompt: Formatted prompt template with conversation.
            options: Options to use for the LLM client.

        Returns:
            Raw text response from LLM.
        """
        # Explicit per-call options are merged over the defaults via the
        # options' `|` operator.
        options = (self.default_options | options) if options else self.default_options

        response = await self.client.call(
            conversation=prompt.chat,
            options=options,
            json_mode=prompt.json_mode,
            output_schema=prompt.output_schema(),
        )

        return response

    @overload
    async def generate(
        self,
        prompt: BasePromptWithParser[OutputT],
        *,
        options: Optional[LLMOptions] = None,
    ) -> OutputT:
        ...

    @overload
    async def generate(
        self,
        prompt: BasePrompt,
        *,
        options: Optional[LLMOptions] = None,
    ) -> OutputT:
        ...

    async def generate(
        self,
        prompt: BasePrompt,
        *,
        options: Optional[LLMOptions] = None,
    ) -> OutputT:
        """
        Prepares and sends a prompt to the LLM and returns response parsed to the
        output type of the prompt (if available).

        Args:
            prompt: Formatted prompt template with conversation and optional response parsing configuration.
            options: Options to use for the LLM client.

        Returns:
            Text response from LLM.
        """
        response = await self.generate_raw(prompt, options=options)

        # Prompts that declare a parser get their raw text parsed into the
        # prompt's output type; otherwise the raw string is returned as-is.
        if isinstance(prompt, BasePromptWithParser):
            return prompt.parse_response(response)

        return cast(OutputT, response)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .base import LLMClient, LLMOptions
|
|
2
|
+
from .litellm import LiteLLMClient, LiteLLMOptions
|
|
3
|
+
from .local import LocalLLMClient, LocalLLMOptions
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"LLMClient",
|
|
7
|
+
"LLMOptions",
|
|
8
|
+
"LiteLLMClient",
|
|
9
|
+
"LiteLLMOptions",
|
|
10
|
+
"LocalLLMClient",
|
|
11
|
+
"LocalLLMOptions",
|
|
12
|
+
]
|