lightrag-hku 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lightrag_hku-0.0.1/LICENSE +21 -0
- lightrag_hku-0.0.1/PKG-INFO +39 -0
- lightrag_hku-0.0.1/README.md +18 -0
- lightrag_hku-0.0.1/lightrag/__init__.py +5 -0
- lightrag_hku-0.0.1/lightrag/base.py +116 -0
- lightrag_hku-0.0.1/lightrag/lightrag.py +289 -0
- lightrag_hku-0.0.1/lightrag/llm.py +88 -0
- lightrag_hku-0.0.1/lightrag/operate.py +944 -0
- lightrag_hku-0.0.1/lightrag/prompt.py +256 -0
- lightrag_hku-0.0.1/lightrag/storage.py +246 -0
- lightrag_hku-0.0.1/lightrag/utils.py +165 -0
- lightrag_hku-0.0.1/lightrag_hku.egg-info/PKG-INFO +39 -0
- lightrag_hku-0.0.1/lightrag_hku.egg-info/SOURCES.txt +16 -0
- lightrag_hku-0.0.1/lightrag_hku.egg-info/dependency_links.txt +1 -0
- lightrag_hku-0.0.1/lightrag_hku.egg-info/requires.txt +8 -0
- lightrag_hku-0.0.1/lightrag_hku.egg-info/top_level.txt +1 -0
- lightrag_hku-0.0.1/setup.cfg +4 -0
- lightrag_hku-0.0.1/setup.py +39 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Gustavo Ye
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: lightrag-hku
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: LightRAG: Simple and Fast Retrieval-Augmented Generation
|
|
5
|
+
Home-page: https://github.com/HKUDS/LightRAG
|
|
6
|
+
Author: ZiruiGuo
|
|
7
|
+
Classifier: Programming Language :: Python :: 3
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Operating System :: OS Independent
|
|
10
|
+
Requires-Python: >=3.9
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
License-File: LICENSE
|
|
13
|
+
Requires-Dist: openai
|
|
14
|
+
Requires-Dist: tiktoken
|
|
15
|
+
Requires-Dist: networkx
|
|
16
|
+
Requires-Dist: graspologic
|
|
17
|
+
Requires-Dist: nano-vectordb
|
|
18
|
+
Requires-Dist: hnswlib
|
|
19
|
+
Requires-Dist: xxhash
|
|
20
|
+
Requires-Dist: tenacity
|
|
21
|
+
|
|
22
|
+
# LightRAG: Simple and Fast Retrieval-Augmented Generation
|
|
23
|
+
|
|
24
|
+
## Citation
|
|
25
|
+
|
|
26
|
+
```
|
|
27
|
+
@article{guo2024lightrag,
|
|
28
|
+
title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
|
|
29
|
+
author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
|
|
30
|
+
year={2024},
|
|
31
|
+
eprint={},
|
|
32
|
+
archivePrefix={arXiv},
|
|
33
|
+
primaryClass={cs.IR}
|
|
34
|
+
}
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Acknowledgement
|
|
38
|
+
|
|
39
|
+
The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# LightRAG: Simple and Fast Retrieval-Augmented Generation
|
|
2
|
+
|
|
3
|
+
## Citation
|
|
4
|
+
|
|
5
|
+
```
|
|
6
|
+
@article{guo2024lightrag,
|
|
7
|
+
title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
|
|
8
|
+
author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
|
|
9
|
+
year={2024},
|
|
10
|
+
eprint={},
|
|
11
|
+
archivePrefix={arXiv},
|
|
12
|
+
primaryClass={cs.IR}
|
|
13
|
+
}
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
## Acknowledgement
|
|
17
|
+
|
|
18
|
+
The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import TypedDict, Union, Literal, Generic, TypeVar
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from .utils import EmbeddingFunc
|
|
7
|
+
|
|
8
|
+
class TextChunkSchema(TypedDict):
    """Shape of one stored text chunk (all keys required)."""

    # Token count of ``content``.
    tokens: int
    # The chunk's text.
    content: str
    # Id of the full document the chunk was cut from.
    full_doc_id: str
    # Position of this chunk within its source document.
    chunk_order_index: int


# Generic value type for key/value storages.
T = TypeVar("T")
|
|
14
|
+
|
|
15
|
+
@dataclass
class QueryParam:
    """Options for a single query, passed to ``LightRAG.aquery`` and on to the
    query operators.

    Field order defines the positional constructor signature; keep it stable.
    """

    # Retrieval strategy. "hybird" (sic) is the exact string LightRAG.aquery
    # dispatches on, so the spelling must not be "corrected" here alone.
    mode: Literal["local", "global", "hybird", "naive"] = "global"
    # Presumably asks the operators to return the built context instead of an
    # LLM answer -- confirm in operate.py.
    only_need_context: bool = False
    # Free-text description of the desired answer format.
    response_type: str = "Multiple Paragraphs"
    # Number of top retrieval results to request (presumably per vector search).
    top_k: int = 60
    # Token budgets for the different parts of the retrieved context.
    max_token_for_text_unit: int = 4000
    max_token_for_global_context: int = 4000
    max_token_for_local_context: int = 4000
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class StorageNameSpace:
    """Common base for every storage: a namespace plus the global config.

    The two lifecycle hooks are no-ops here; back-ends override them to
    persist their state.
    """

    # Logical name of this storage (e.g. "full_docs", "entities").
    namespace: str
    # Snapshot of the owning LightRAG configuration.
    global_config: dict

    async def index_done_callback(self):
        """Hook invoked after an indexing pass; commit pending writes here."""

    async def query_done_callback(self):
        """Hook invoked after a query; commit pending writes here."""
|
|
38
|
+
|
|
39
|
+
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Interface for a vector store keyed by string ids.

    Concrete back-ends override ``query`` and ``upsert``; the versions here
    only raise ``NotImplementedError``.
    """

    # Callable used to turn record content into vectors.
    embedding_func: EmbeddingFunc
    # Names of extra value fields to keep next to each vector
    # (presumably returned by ``query`` -- confirm in the back-end).
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Return stored records relevant to ``query`` (at most ``top_k``, presumably)."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, dict]):
        """Insert or update records keyed by id.

        The ``'content'`` field of each value is what gets embedded; when
        ``embedding_func`` is None the value's ``'embedding'`` field is used
        instead.
        """
        raise NotImplementedError
|
|
52
|
+
|
|
53
|
+
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Interface for a string-keyed key/value store holding ``T`` values.

    Every method raises ``NotImplementedError``; back-ends override them.
    """

    async def all_keys(self) -> list[str]:
        """Return every stored key."""
        raise NotImplementedError

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Return the value stored under ``id``, or None when it is absent."""
        raise NotImplementedError

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        """Batch variant of ``get_by_id``; missing keys yield None entries.

        ``fields``, when given, presumably limits which fields of each value
        are returned -- confirm against the concrete back-end.
        """
        raise NotImplementedError

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the keys from ``data`` that are not stored yet."""
        raise NotImplementedError

    async def upsert(self, data: dict[str, T]):
        """Insert or overwrite the given key -> value pairs."""
        raise NotImplementedError

    async def drop(self):
        """Discard the stored data (exact semantics defined by the back-end)."""
        raise NotImplementedError
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Interface for the entity/relation graph store.

    Nodes and edges are addressed by string ids and carry string-valued
    attribute dicts. Every method raises ``NotImplementedError``.
    """

    async def has_node(self, node_id: str) -> bool:
        """Return whether ``node_id`` exists."""
        raise NotImplementedError

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return whether an edge joins the two nodes."""
        raise NotImplementedError

    async def node_degree(self, node_id: str) -> int:
        """Return the degree of ``node_id``."""
        raise NotImplementedError

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return a degree score for the edge ``src_id``-``tgt_id`` (back-end defined)."""
        raise NotImplementedError

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the node's attribute dict, or None if it is absent."""
        raise NotImplementedError

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the edge's attribute dict, or None if it is absent."""
        raise NotImplementedError

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        """Return the edges touching ``source_node_id`` as (src, tgt) pairs, or None."""
        raise NotImplementedError

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Create or update a node with ``node_data``."""
        raise NotImplementedError

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Create or update an edge with ``edge_data``."""
        raise NotImplementedError

    async def clustering(self, algorithm: str):
        """Run the named clustering ``algorithm`` over the graph."""
        raise NotImplementedError

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Node embedding hook; intentionally unsupported in LightRAG."""
        raise NotImplementedError("Node embedding is not used in lightrag.")
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
from dataclasses import asdict, dataclass, field
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from functools import partial
|
|
6
|
+
from typing import Type, cast
|
|
7
|
+
|
|
8
|
+
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
|
|
9
|
+
from .operate import (
|
|
10
|
+
chunking_by_token_size,
|
|
11
|
+
extract_entities,
|
|
12
|
+
local_query,
|
|
13
|
+
global_query,
|
|
14
|
+
hybird_query,
|
|
15
|
+
naive_query,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from .storage import (
|
|
19
|
+
JsonKVStorage,
|
|
20
|
+
NanoVectorDBStorage,
|
|
21
|
+
NetworkXStorage,
|
|
22
|
+
)
|
|
23
|
+
from .utils import (
|
|
24
|
+
EmbeddingFunc,
|
|
25
|
+
compute_mdhash_id,
|
|
26
|
+
limit_async_func_call,
|
|
27
|
+
convert_response_to_json,
|
|
28
|
+
logger,
|
|
29
|
+
set_logger,
|
|
30
|
+
)
|
|
31
|
+
from .base import (
|
|
32
|
+
BaseGraphStorage,
|
|
33
|
+
BaseKVStorage,
|
|
34
|
+
BaseVectorStorage,
|
|
35
|
+
StorageNameSpace,
|
|
36
|
+
QueryParam,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return an event loop for the current thread that can still run work.

    Uses whatever ``asyncio.get_event_loop()`` returns when it is still open.
    Otherwise it creates a new loop and installs it as the thread's current
    loop. Otherwise means either no current loop exists (``RuntimeError``,
    e.g. in a worker thread) or the existing loop has been closed.

    The closed-loop check matters because ``run_until_complete`` on a closed
    loop raises ``RuntimeError``; previously such a loop was returned as-is.
    """
    try:
        # If there is already an event loop, use it.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # If in a sub-thread, create a new event loop.
        logger.info("Creating a new event loop in a sub-thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop

    if loop.is_closed():
        # A previously used loop was closed and can no longer run
        # coroutines, so replace it with a fresh one.
        logger.info("Current event loop is closed; creating a new one.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
|
|
49
|
+
|
|
50
|
+
@dataclass
class LightRAG:
    """Top-level LightRAG facade: owns every storage and exposes insert/query.

    ``__post_init__`` instantiates each storage from the ``*_cls`` fields,
    handing each an ``asdict(self)`` snapshot as its global config. It also
    wraps the embedding and LLM callables with concurrency limits.
    """

    # Directory holding the log file and all persisted storages. The default
    # is a fresh timestamped folder; ':' is deliberately avoided because it
    # is not a legal path character on Windows.
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
    )

    # text chunking
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # entity extraction
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    # (The original literal listed "num_walks" twice; the duplicate is gone,
    # the resulting dict is unchanged.)
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # text embedding
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM
    llm_model_func: callable = gpt_4o_mini_complete
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16

    # storage
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
    enable_llm_cache: bool = True

    # extension
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        # Make sure the working directory exists *before* the logger is
        # pointed at a file inside it; set_logger presumably attaches a file
        # handler, which cannot open a file in a missing directory.
        is_new_dir = not os.path.exists(self.working_dir)
        os.makedirs(self.working_dir, exist_ok=True)

        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        if is_new_dir:
            logger.info(f"Creating working directory {self.working_dir}")

        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )

        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )

        # None disables LLM response caching entirely.
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )
        # Cap concurrent embedding calls; the vector stores below receive the
        # wrapped callable.
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        # Bind the response cache into every LLM call and cap concurrency.
        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
        )

    def insert(self, string_or_strings):
        """Synchronous wrapper around :meth:`ainsert`."""
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """Insert one document (``str``) or several (iterable of ``str``).

        Each document is keyed by ``compute_mdhash_id`` of its stripped text.
        Only documents not yet in ``full_docs`` are chunked, and only chunks
        not yet in ``text_chunks`` are embedded into ``chunks_vdb`` and
        persisted. Storage commit callbacks always run on exit.
        """
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]
            # ---------- new docs
            new_docs = {
                compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
                for c in string_or_strings
            }
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not new_docs:
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # ---------- chunking
            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not inserting_chunks:
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            await self.chunks_vdb.upsert(inserting_chunks)
            # NOTE(review): `extract_entities` is imported by this module but
            # never called here. As a result, insertion does not populate
            # `entities_vdb`, `relationships_vdb` or
            # `chunk_entity_relation_graph`, which aquery hands to the
            # local/global/hybird operators. Confirm this is intended.

            # ---------- commit upsertings and indexing
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            await self._insert_done()

    async def _insert_done(self):
        """Run ``index_done_callback`` on every configured storage concurrently."""
        storages = [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in storages
            if storage_inst is not None  # llm_response_cache may be disabled
        ]
        await asyncio.gather(*tasks)

    def query(self, query: str, param: "QueryParam | None" = None):
        """Synchronous wrapper around :meth:`aquery`.

        ``param`` defaults to a fresh ``QueryParam()`` per call rather than a
        single instance shared by every call.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: "QueryParam | None" = None):
        """Answer ``query`` using the retrieval strategy in ``param.mode``.

        Raises:
            ValueError: if ``param.mode`` is not one of the supported modes.
        """
        if param is None:
            param = QueryParam()
        try:
            if param.mode == "local":
                response = await local_query(
                    query,
                    self.chunk_entity_relation_graph,
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            elif param.mode == "global":
                response = await global_query(
                    query,
                    self.chunk_entity_relation_graph,
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            elif param.mode == "hybird":
                response = await hybird_query(
                    query,
                    self.chunk_entity_relation_graph,
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            elif param.mode == "naive":
                response = await naive_query(
                    query,
                    self.chunks_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            else:
                raise ValueError(f"Unknown mode {param.mode}")
        finally:
            # Persist the LLM cache even if the query failed part-way, for
            # consistency with ainsert's commit-in-finally.
            await self._query_done()
        return response

    async def _query_done(self):
        """Flush the LLM response cache (if enabled) after a query."""
        tasks = [
            cast(StorageNameSpace, storage_inst).index_done_callback()
            for storage_inst in [self.llm_response_cache]
            if storage_inst is not None
        ]
        await asyncio.gather(*tasks)
|
|
288
|
+
|
|
289
|
+
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
|
4
|
+
from tenacity import (
|
|
5
|
+
retry,
|
|
6
|
+
stop_after_attempt,
|
|
7
|
+
wait_exponential,
|
|
8
|
+
retry_if_exception_type,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from .base import BaseKVStorage
|
|
12
|
+
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
|
13
|
+
|
|
14
|
+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Run an OpenAI chat completion, consulting an optional KV cache first.

    Args:
        model: OpenAI chat model name.
        prompt: user message content.
        system_prompt: optional system message prepended to the conversation.
        history_messages: prior chat messages inserted between the system
            and user messages. Defaults to none; the list is only read,
            never mutated.
        **kwargs: forwarded to ``chat.completions.create``, except
            ``hashing_kv`` (a ``BaseKVStorage`` or None), which is popped and
            used as the response cache.

    Returns:
        The assistant message content, either cached or freshly generated.

    Retried up to 3 attempts on rate-limit, connection and timeout errors.
    """
    if history_messages is None:
        history_messages = []
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        # NOTE(review): the cache key covers model + messages only, so calls
        # differing solely in sampling kwargs share one cache entry.
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Create the client only when a request is actually needed, so cache hits
    # do not depend on client construction (e.g. API-key configuration).
    openai_async_client = AsyncOpenAI()
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    content = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": content, "model": model}})
    return content
|
|
44
|
+
|
|
45
|
+
async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete ``prompt`` with ``gpt-4o`` via :func:`openai_complete_if_cache`.

    ``history_messages`` defaults to an empty history; a fresh list is used
    per call instead of a shared mutable default. Extra ``kwargs`` (including
    ``hashing_kv``) are forwarded unchanged.
    """
    if history_messages is None:
        history_messages = []
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete ``prompt`` with ``gpt-4o-mini`` via :func:`openai_complete_if_cache`.

    ``history_messages`` defaults to an empty history; a fresh list is used
    per call instead of a shared mutable default. Extra ``kwargs`` (including
    ``hashing_kv``) are forwarded unchanged.
    """
    if history_messages is None:
        history_messages = []
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
|
|
67
|
+
|
|
68
|
+
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with OpenAI's ``text-embedding-3-small`` model.

    Returns a NumPy array built from the embeddings in ``response.data``
    order. Retried up to 3 attempts on rate-limit, connection and timeout
    errors.
    """
    client = AsyncOpenAI()
    result = await client.embeddings.create(
        input=texts,
        model="text-embedding-3-small",
        encoding_format="float",
    )
    vectors = [item.embedding for item in result.data]
    return np.array(vectors)
|
|
80
|
+
|
|
81
|
+
if __name__ == "__main__":
    import asyncio

    async def main():
        """Smoke test: send one prompt to gpt-4o-mini and print the reply."""
        answer = await gpt_4o_mini_complete('How are you?')
        print(answer)

    asyncio.run(main())
|