lightrag-hku 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Gustavo Ye
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,39 @@
1
+ Metadata-Version: 2.1
2
+ Name: lightrag-hku
3
+ Version: 0.0.1
4
+ Summary: LightRAG: Simple and Fast Retrieval-Augmented Generation
5
+ Home-page: https://github.com/HKUDS/LightRAG
6
+ Author: ZiruiGuo
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: MIT License
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Python: >=3.9
11
+ Description-Content-Type: text/markdown
12
+ License-File: LICENSE
13
+ Requires-Dist: openai
14
+ Requires-Dist: tiktoken
15
+ Requires-Dist: networkx
16
+ Requires-Dist: graspologic
17
+ Requires-Dist: nano-vectordb
18
+ Requires-Dist: hnswlib
19
+ Requires-Dist: xxhash
20
+ Requires-Dist: tenacity
21
+
22
+ # LightRAG: Simple and Fast Retrieval-Augmented Generation
23
+
24
+ ## Citation
25
+
26
+ ```
27
+ @article{guo2024lightrag,
28
+ title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
29
+ author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
30
+ year={2024},
31
+ eprint={},
32
+ archivePrefix={arXiv},
33
+ primaryClass={cs.IR}
34
+ }
35
+ ```
36
+
37
+ ## Acknowledgement
38
+
39
+ The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
@@ -0,0 +1,18 @@
1
+ # LightRAG: Simple and Fast Retrieval-Augmented Generation
2
+
3
+ ## Citation
4
+
5
+ ```
6
+ @article{guo2024lightrag,
7
+ title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
8
+ author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
9
+ year={2024},
10
+ eprint={},
11
+ archivePrefix={arXiv},
12
+ primaryClass={cs.IR}
13
+ }
14
+ ```
15
+
16
+ ## Acknowledgement
17
+
18
+ The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
@@ -0,0 +1,5 @@
1
+ from .lightrag import LightRAG, QueryParam
2
+
3
+ __version__ = "0.0.1"
4
+ __author__ = "Zirui Guo"
5
+ __url__ = "https://github.com/HKUDS/GraphEdit"
@@ -0,0 +1,116 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import TypedDict, Union, Literal, Generic, TypeVar
3
+
4
+ import numpy as np
5
+
6
+ from .utils import EmbeddingFunc
7
+
8
class TextChunkSchema(TypedDict):
    """Shape of a single text-chunk record.

    Keys:
        tokens: Token count of the chunk.
        content: The chunk text.
        full_doc_id: Key of the document the chunk belongs to.
        chunk_order_index: Position of the chunk within its document
            (presumably 0-based -- confirm against the chunker).
    """

    tokens: int
    content: str
    full_doc_id: str
    chunk_order_index: int


# Value type parameter for the generic key-value storage interface.
T = TypeVar("T")
14
+
15
@dataclass
class QueryParam:
    """Options for a single LightRAG query.

    Note that the hybrid retrieval mode is spelled ``"hybird"`` throughout
    the package; that exact spelling is the accepted value.
    """

    # Retrieval strategy: "local", "global", "hybird" (sic) or "naive".
    mode: Literal["local", "global", "hybird", "naive"] = "global"
    # Presumably: return only the retrieved context, skipping answer generation.
    only_need_context: bool = False
    # Free-text description of the desired answer format.
    response_type: str = "Multiple Paragraphs"
    # Number of candidates to retrieve (presumably forwarded as a vector-store top_k).
    top_k: int = 60
    # Token budgets for the individual context sections.
    max_token_for_text_unit: int = 4000
    max_token_for_global_context: int = 4000
    max_token_for_local_context: int = 4000
24
+
25
+
26
@dataclass
class StorageNameSpace:
    """Common base of every storage backend.

    Attributes:
        namespace: Name identifying this storage instance.
        global_config: Configuration mapping shared with the owning pipeline.
    """

    namespace: str
    global_config: dict

    async def index_done_callback(self):
        """Hook run after indexing so a backend can commit pending writes.

        The base implementation does nothing.
        """
        return None

    async def query_done_callback(self):
        """Hook run after querying so a backend can commit pending writes.

        The base implementation does nothing.
        """
        return None
38
+
39
@dataclass
class BaseVectorStorage(StorageNameSpace):
    """Abstract vector store interface; concrete backends override both methods.

    Attributes:
        embedding_func: Function presumably used to embed record contents
            and query strings.
        meta_fields: Names of record fields presumably stored as metadata
            next to each vector (empty by default).
    """

    embedding_func: EmbeddingFunc
    meta_fields: set = field(default_factory=set)

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Return records matching ``query``, presumably at most ``top_k``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    async def upsert(self, data: dict[str, dict]):
        """Insert or update records, using each key as the record id.

        The value's 'content' field is what gets embedded. If embedding_func
        is None, the value's 'embedding' field is used instead.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
52
+
53
@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
    """Abstract key-value store whose values have type ``T``.

    Every method raises NotImplementedError here and must be provided by a
    concrete backend.
    """

    async def all_keys(self) -> list[str]:
        """Return all stored keys."""
        raise NotImplementedError()

    async def get_by_id(self, id: str) -> Union[T, None]:
        """Return the value stored under ``id``, or None when it is absent."""
        raise NotImplementedError()

    async def get_by_ids(
        self, ids: list[str], fields: Union[set[str], None] = None
    ) -> list[Union[T, None]]:
        """Return one entry per id (None for missing ones).

        ``fields`` presumably limits each returned value to those keys.
        """
        raise NotImplementedError()

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return the keys from ``data`` that do not exist in the store yet."""
        raise NotImplementedError()

    async def upsert(self, data: dict[str, T]):
        """Insert or overwrite every key/value pair of ``data``."""
        raise NotImplementedError()

    async def drop(self):
        """Presumably discard all stored data."""
        raise NotImplementedError()
75
+
76
+
77
@dataclass
class BaseGraphStorage(StorageNameSpace):
    """Abstract graph store addressed by string node ids.

    Each method raises NotImplementedError here; concrete backends supply
    the real behavior.
    """

    async def has_node(self, node_id: str) -> bool:
        """Return whether a node with ``node_id`` exists."""
        raise NotImplementedError()

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return whether an edge between the two nodes exists."""
        raise NotImplementedError()

    async def node_degree(self, node_id: str) -> int:
        """Return the degree of ``node_id``."""
        raise NotImplementedError()

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return a degree value for the edge (definition left to the backend)."""
        raise NotImplementedError()

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the node's attribute dict, or None if it is absent."""
        raise NotImplementedError()

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the edge's attribute dict, or None if it is absent."""
        raise NotImplementedError()

    async def get_node_edges(
        self, source_node_id: str
    ) -> Union[list[tuple[str, str]], None]:
        """Return the (source, target) id pairs of the node's edges, or None."""
        raise NotImplementedError()

    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
        """Create or update a node with the given attributes."""
        raise NotImplementedError()

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ):
        """Create or update an edge with the given attributes."""
        raise NotImplementedError()

    async def clustering(self, algorithm: str):
        """Run the named clustering algorithm over the graph."""
        raise NotImplementedError()

    async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
        """Node-embedding hook; unsupported (always raises)."""
        raise NotImplementedError("Node embedding is not used in lightrag.")
@@ -0,0 +1,289 @@
1
+ import asyncio
2
+ import os
3
+ from dataclasses import asdict, dataclass, field
4
+ from datetime import datetime
5
+ from functools import partial
6
+ from typing import Type, cast
7
+
8
+ from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
9
+ from .operate import (
10
+ chunking_by_token_size,
11
+ extract_entities,
12
+ local_query,
13
+ global_query,
14
+ hybird_query,
15
+ naive_query,
16
+ )
17
+
18
+ from .storage import (
19
+ JsonKVStorage,
20
+ NanoVectorDBStorage,
21
+ NetworkXStorage,
22
+ )
23
+ from .utils import (
24
+ EmbeddingFunc,
25
+ compute_mdhash_id,
26
+ limit_async_func_call,
27
+ convert_response_to_json,
28
+ logger,
29
+ set_logger,
30
+ )
31
+ from .base import (
32
+ BaseGraphStorage,
33
+ BaseKVStorage,
34
+ BaseVectorStorage,
35
+ StorageNameSpace,
36
+ QueryParam,
37
+ )
38
+
39
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return an open event loop installed as the current thread's loop.

    Reuses the thread's current loop when one exists and is not closed.
    Otherwise it creates a new loop and installs it with
    ``asyncio.set_event_loop``.
    """
    try:
        # If there is already an event loop, use it.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # If in a sub-thread, create a new event loop.
        logger.info("Creating a new event loop in a sub-thread.")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop
    if loop.is_closed():
        # get_event_loop() happily returns a loop that was closed earlier, and
        # run_until_complete() on it fails with "Event loop is closed".
        # Replace it with a fresh loop instead.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop
49
+
50
@dataclass
class LightRAG:
    """LightRAG pipeline: store documents and answer queries over them.

    ``insert``/``ainsert`` chunk new documents and write them to the storages.
    ``query``/``aquery`` dispatch to the retrieval helper selected by
    ``QueryParam.mode``. The storages are created in ``__post_init__`` from the
    ``*_cls`` fields, and each one receives ``asdict(self)`` as its
    ``global_config``.
    """

    # Directory holding the log file (and presumably the storages' data).
    # The default timestamp separates time fields with '-' because ':' is not
    # allowed in Windows file names.
    working_dir: str = field(
        default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
    )

    # text chunking (passed to chunking_by_token_size in ainsert)
    chunk_token_size: int = 1200
    chunk_overlap_token_size: int = 100
    tiktoken_model_name: str = "gpt-4o-mini"

    # entity extraction (not read by this class; only reaches helpers through
    # the asdict(self) global_config)
    entity_extract_max_gleaning: int = 1
    entity_summary_to_max_tokens: int = 500

    # node embedding
    node_embedding_algorithm: str = "node2vec"
    node2vec_params: dict = field(
        default_factory=lambda: {
            "dimensions": 1536,
            "num_walks": 10,
            "walk_length": 40,
            "window_size": 2,
            "iterations": 3,
            "random_seed": 3,
        }
    )

    # text embedding (wrapped with a concurrency limit in __post_init__)
    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
    embedding_batch_num: int = 32
    embedding_func_max_async: int = 16

    # LLM (wrapped with the response cache and a concurrency limit in __post_init__)
    llm_model_func: callable = gpt_4o_mini_complete
    llm_model_max_token_size: int = 32768
    llm_model_max_async: int = 16

    # storage
    key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
    vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
    # Not passed to vector_db_storage_cls directly; it only reaches the
    # backends as part of global_config.
    vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
    graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
    enable_llm_cache: bool = True

    # extension
    addon_params: dict = field(default_factory=dict)
    convert_response_to_json_func: callable = convert_response_to_json

    def __post_init__(self):
        """Create the working directory and logger, then build every storage."""
        # Create the working directory BEFORE configuring the logger: the log
        # file lives inside it. If set_logger opens a file handler (confirm in
        # utils), it would fail on a directory that does not exist yet.
        created_working_dir = not os.path.exists(self.working_dir)
        if created_working_dir:
            os.makedirs(self.working_dir, exist_ok=True)

        log_file = os.path.join(self.working_dir, "lightrag.log")
        set_logger(log_file)
        logger.info(f"Logger initialized for working directory: {self.working_dir}")

        _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
        logger.debug(f"LightRAG init with param:\n {_print_config}\n")

        if created_working_dir:
            logger.info(f"Creating working directory {self.working_dir}")

        self.full_docs = self.key_string_value_json_storage_cls(
            namespace="full_docs", global_config=asdict(self)
        )
        self.text_chunks = self.key_string_value_json_storage_cls(
            namespace="text_chunks", global_config=asdict(self)
        )
        # None when caching is disabled; the *_done hooks skip it in that case.
        self.llm_response_cache = (
            self.key_string_value_json_storage_cls(
                namespace="llm_response_cache", global_config=asdict(self)
            )
            if self.enable_llm_cache
            else None
        )
        self.chunk_entity_relation_graph = self.graph_storage_cls(
            namespace="chunk_entity_relation", global_config=asdict(self)
        )
        # Cap concurrent embedding calls; the vector stores below get the
        # wrapped function (and see it in their global_config).
        self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
            self.embedding_func
        )
        self.entities_vdb = self.vector_db_storage_cls(
            namespace="entities",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"entity_name"},
        )
        self.relationships_vdb = self.vector_db_storage_cls(
            namespace="relationships",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
            meta_fields={"src_id", "tgt_id"},
        )
        self.chunks_vdb = self.vector_db_storage_cls(
            namespace="chunks",
            global_config=asdict(self),
            embedding_func=self.embedding_func,
        )

        # Bind the response cache (possibly None) as `hashing_kv` and cap
        # the number of concurrent LLM calls.
        self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
            partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
        )

    def insert(self, string_or_strings):
        """Synchronous wrapper around :meth:`ainsert`."""
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.ainsert(string_or_strings))

    async def ainsert(self, string_or_strings):
        """Chunk new documents and write them to the storages.

        Args:
            string_or_strings: A single document string or an iterable of them.

        Documents and chunks are keyed by ``compute_mdhash_id`` of their
        (stripped) text, so content that is already stored is skipped.
        Whitespace-only documents are ignored. Storage commit hooks run via
        ``_insert_done`` on every exit path.
        """
        try:
            if isinstance(string_or_strings, str):
                string_or_strings = [string_or_strings]

            # ---------- new docs
            new_docs = {}
            for raw_doc in string_or_strings:
                content = raw_doc.strip()
                if not content:
                    # Whitespace-only input has nothing to chunk or index.
                    continue
                new_docs[compute_mdhash_id(content, prefix="doc-")] = {"content": content}
            if not new_docs:
                logger.warning("No non-empty docs to insert")
                return
            _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
            new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
            if not new_docs:
                logger.warning("All docs are already in the storage")
                return
            logger.info(f"[New Docs] inserting {len(new_docs)} docs")

            # ---------- chunking
            # Chunk ids hash only the chunk text, so identical chunks from
            # different documents collapse to one entry; the last document
            # processed wins its full_doc_id.
            inserting_chunks = {}
            for doc_key, doc in new_docs.items():
                chunks = {
                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
                        **dp,
                        "full_doc_id": doc_key,
                    }
                    for dp in chunking_by_token_size(
                        doc["content"],
                        overlap_token_size=self.chunk_overlap_token_size,
                        max_token_size=self.chunk_token_size,
                        tiktoken_model=self.tiktoken_model_name,
                    )
                }
                inserting_chunks.update(chunks)
            _add_chunk_keys = await self.text_chunks.filter_keys(
                list(inserting_chunks.keys())
            )
            inserting_chunks = {
                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
            }
            if not inserting_chunks:
                logger.warning("All chunks are already in the storage")
                return
            logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")

            await self.chunks_vdb.upsert(inserting_chunks)

            # NOTE(review): `extract_entities` is imported at module level but
            # never called, so this method never writes to
            # chunk_entity_relation_graph, entities_vdb or relationships_vdb.
            # The "local", "global" and "hybird" query modes read exactly
            # those stores. Confirm whether the entity-extraction step is
            # missing here.

            # ---------- commit upsertings and indexing
            await self.full_docs.upsert(new_docs)
            await self.text_chunks.upsert(inserting_chunks)
        finally:
            # Commit every storage, including after early returns and errors.
            await self._insert_done()

    async def _insert_done(self):
        """Invoke index_done_callback on every storage that exists."""
        tasks = []
        for storage_inst in [
            self.full_docs,
            self.text_chunks,
            self.llm_response_cache,
            self.entities_vdb,
            self.relationships_vdb,
            self.chunks_vdb,
            self.chunk_entity_relation_graph,
        ]:
            if storage_inst is None:
                # llm_response_cache is None when enable_llm_cache is False.
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)

    def query(self, query: str, param: "QueryParam | None" = None):
        """Synchronous wrapper around :meth:`aquery`.

        ``param`` defaults to a fresh ``QueryParam()`` per call.
        """
        loop = always_get_an_event_loop()
        return loop.run_until_complete(self.aquery(query, param))

    async def aquery(self, query: str, param: "QueryParam | None" = None):
        """Answer ``query`` with the retrieval mode chosen by ``param.mode``.

        Args:
            query: The user question.
            param: Query options. When omitted, a new ``QueryParam()`` is
                created for this call instead of sharing one default instance
                across all calls.

        Returns:
            Whatever the selected ``*_query`` helper returns.

        Raises:
            ValueError: if ``param.mode`` is not "local", "global", "hybird"
                or "naive".
        """
        if param is None:
            param = QueryParam()
        try:
            if param.mode == "local":
                response = await local_query(
                    query,
                    self.chunk_entity_relation_graph,
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            elif param.mode == "global":
                response = await global_query(
                    query,
                    self.chunk_entity_relation_graph,
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            elif param.mode == "hybird":
                response = await hybird_query(
                    query,
                    self.chunk_entity_relation_graph,
                    self.entities_vdb,
                    self.relationships_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            elif param.mode == "naive":
                response = await naive_query(
                    query,
                    self.chunks_vdb,
                    self.text_chunks,
                    param,
                    asdict(self),
                )
            else:
                raise ValueError(f"Unknown mode {param.mode}")
        finally:
            # Persist the LLM response cache even when the query fails,
            # mirroring ainsert's try/finally.
            await self._query_done()
        return response

    async def _query_done(self):
        """Commit the LLM response cache after a query.

        NOTE(review): this calls index_done_callback, not query_done_callback.
        The base query_done_callback is a no-op, so index_done_callback is
        presumably what writes the cache out. Confirm before changing.
        """
        tasks = []
        for storage_inst in [self.llm_response_cache]:
            if storage_inst is None:
                continue
            tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
        await asyncio.gather(*tasks)
288
+
289
+
@@ -0,0 +1,88 @@
1
+ import os
2
+ import numpy as np
3
+ from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
4
+ from tenacity import (
5
+ retry,
6
+ stop_after_attempt,
7
+ wait_exponential,
8
+ retry_if_exception_type,
9
+ )
10
+
11
+ from .base import BaseKVStorage
12
+ from .utils import compute_args_hash, wrap_embedding_func_with_attrs
13
+
14
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # NOTE(review): openai's exported `Timeout` looks like the httpx timeout
    # configuration type rather than an exception. Request timeouts
    # (APITimeoutError) subclass APIConnectionError and are already retried.
    # Confirm before relying on it.
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
    model, prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Run an OpenAI chat completion, serving it from a KV cache when possible.

    Args:
        model: Chat model name passed to ``chat.completions.create``.
        prompt: Content of the final user message.
        system_prompt: Optional system message placed first.
        history_messages: Optional list of message dicts inserted between the
            system message and the user prompt. Defaults to no history.
        **kwargs: Forwarded to ``chat.completions.create``. The ``hashing_kv``
            entry (a BaseKVStorage or None) is removed first and used as the
            response cache, keyed by ``compute_args_hash(model, messages)``.

    Returns:
        The content of the first choice's message.
    """
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Build the client only on a cache miss. Constructing AsyncOpenAI reads
    # credentials from the environment and can fail without them, which
    # previously made cached answers unreachable.
    openai_async_client = AsyncOpenAI()
    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    content = response.choices[0].message.content

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": content, "model": model}})
    return content
44
+
45
async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete ``prompt`` with the "gpt-4o" model via openai_complete_if_cache.

    ``history_messages`` defaults to no history. A fresh list is used per call
    instead of a shared mutable default argument. Extra keyword arguments
    (including ``hashing_kv``) are forwarded unchanged.
    """
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=[] if history_messages is None else history_messages,
        **kwargs,
    )
55
+
56
+
57
async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete ``prompt`` with the "gpt-4o-mini" model via openai_complete_if_cache.

    ``history_messages`` defaults to no history. A fresh list is used per call
    instead of a shared mutable default argument. Extra keyword arguments
    (including ``hashing_kv``) are forwarded unchanged.
    """
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=[] if history_messages is None else history_messages,
        **kwargs,
    )
67
+
68
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with OpenAI's "text-embedding-3-small" model.

    Returns a NumPy array with one row per embedding in the response. The
    decorator declares those rows as 1536-dimensional.
    """
    client = AsyncOpenAI()
    result = await client.embeddings.create(
        model="text-embedding-3-small",
        input=texts,
        encoding_format="float",
    )
    vectors = [item.embedding for item in result.data]
    return np.array(vectors)
80
+
81
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Manual smoke test: sends one prompt through the OpenAI API.
        answer = await gpt_4o_mini_complete("How are you?")
        print(answer)

    asyncio.run(_demo())