bisheng-langchain 0.3.5.dev1__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/memory/__init__.py +3 -0
- bisheng_langchain/memory/redis.py +104 -0
- bisheng_langchain/rag/init_retrievers/keyword_retriever.py +1 -1
- bisheng_langchain/vectorstores/elastic_keywords_search.py +45 -1
- {bisheng_langchain-0.3.5.dev1.dist-info → bisheng_langchain-0.3.7.dist-info}/METADATA +1 -1
- {bisheng_langchain-0.3.5.dev1.dist-info → bisheng_langchain-0.3.7.dist-info}/RECORD +8 -6
- {bisheng_langchain-0.3.5.dev1.dist-info → bisheng_langchain-0.3.7.dist-info}/WHEEL +1 -1
- {bisheng_langchain-0.3.5.dev1.dist-info → bisheng_langchain-0.3.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,104 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
4
|
+
import redis
|
5
|
+
from langchain.memory.chat_memory import BaseChatMemory
|
6
|
+
from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage, get_buffer_string,
|
7
|
+
message_to_dict, messages_from_dict)
|
8
|
+
from langchain_core.pydantic_v1 import root_validator
|
9
|
+
from pydantic import Field
|
10
|
+
|
11
|
+
|
12
|
+
class ConversationRedisMemory(BaseChatMemory):
    """Conversation memory that persists chat turns in a Redis list.

    Every call to ``save_context`` pushes a serialized human message and a
    serialized AI message onto the head of the Redis list
    ``redis_prefix + session_id`` (via ``LPUSH``), so the list is stored
    newest-first. Readers reverse it to recover chronological order.
    """

    redis_client: redis.Redis = Field(default=None, exclude=True)
    human_prefix: str = 'Human'
    ai_prefix: str = 'AI'
    session_id: str = 'session'
    memory_key: str = 'history'  #: :meta private:
    redis_url: str
    redis_prefix: str = 'redis_buffer_'
    # Expiry in seconds, refreshed on every save; a falsy value disables expiry.
    ttl: Optional[int] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Create the Redis client from ``redis_url``.

        Raises:
            ValueError: if ``redis_url`` is missing or empty.
        """
        redis_url = values.get('redis_url')
        if not redis_url:
            raise ValueError('Redis URL must be set')
        pool = redis.ConnectionPool.from_url(redis_url, max_connections=1)
        values['redis_client'] = redis.StrictRedis(connection_pool=pool)
        return values

    def _get_redis_key(self) -> str:
        """Return the Redis list key holding this session's history."""
        return self.redis_prefix + self.session_id

    @property
    def buffer(self) -> Any:
        """History as messages or as a string, depending on ``return_messages``."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    async def abuffer(self) -> Any:
        """Async variant of ``buffer``."""
        return (await self.abuffer_as_messages()
                if self.return_messages else await self.abuffer_as_str())

    def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
        """Render ``messages`` with the configured human/AI prefixes."""
        return get_buffer_string(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def buffer_as_str(self) -> str:
        """Expose the buffer as a string in case return_messages is True."""
        messages = self.buffer_as_messages
        return self._buffer_as_str(messages)

    async def abuffer_as_str(self) -> str:
        """Async variant of ``buffer_as_str`` (uses the blocking Redis client)."""
        messages = await self.abuffer_as_messages()
        return self._buffer_as_str(messages)

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Return the stored history as a list of messages, oldest first."""
        raw_items = self.redis_client.lrange(self._get_redis_key(), 0, -1)
        # LPUSH stores newest first, so reverse to restore chronological order.
        # json.loads accepts both bytes and str, so this works whether or not
        # the client was configured with decode_responses.
        items = [json.loads(raw) for raw in reversed(raw_items)]
        return messages_from_dict(items)

    async def abuffer_as_messages(self) -> List[BaseMessage]:
        """Async variant of ``buffer_as_messages`` (uses the blocking Redis client)."""
        return self.buffer_as_messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer."""
        return {self.memory_key: self.buffer}

    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return key-value pairs given the text input to the chain."""
        buffer = await self.abuffer()
        return {self.memory_key: buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Append one human/AI turn to the Redis history.

        Both messages are pushed, and the TTL refreshed, in a single
        transactional pipeline. This prevents concurrent saves to the same
        session from interleaving halves of different turns.
        """
        input_str, output_str = self._get_input_output(inputs, outputs)

        input_message_str = json.dumps(message_to_dict(HumanMessage(content=input_str)),
                                       ensure_ascii=False)
        output_message_str = json.dumps(message_to_dict(AIMessage(content=output_str)),
                                        ensure_ascii=False)
        key = self._get_redis_key()
        with self.redis_client.pipeline() as pipe:
            # Human first, then AI: after LPUSH the AI message sits nearer the
            # head, which the reversal in buffer_as_messages turns back into
            # Human -> AI order.
            pipe.lpush(key, input_message_str)
            pipe.lpush(key, output_message_str)
            if self.ttl:
                pipe.expire(key, self.ttl)
            pipe.execute()

    async def asave_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Async variant of ``save_context``; always writes to Redis.

        NOTE(review): overridden because the inherited async implementation
        (in the langchain versions reviewed) writes to ``chat_memory`` rather
        than Redis — confirm against the pinned langchain version.
        """
        self.save_context(inputs, outputs)

    def clear(self) -> None:
        """Delete this session's history from Redis (and the inherited buffer)."""
        super().clear()
        self.redis_client.delete(self._get_redis_key())

    async def aclear(self) -> None:
        """Async variant of ``clear`` (uses the blocking Redis client)."""
        self.clear()
|
@@ -16,7 +16,7 @@ from langchain.text_splitter import TextSplitter
|
|
16
16
|
|
17
17
|
|
18
18
|
class KeywordRetriever(BaseRetriever):
|
19
|
-
keyword_store:
|
19
|
+
keyword_store: ElasticKeywordsSearch
|
20
20
|
text_splitter: TextSplitter
|
21
21
|
search_type: str = 'similarity'
|
22
22
|
search_kwargs: dict = Field(default_factory=dict)
|
@@ -13,6 +13,7 @@ from langchain.llms.base import BaseLLM
|
|
13
13
|
from langchain.prompts.prompt import PromptTemplate
|
14
14
|
from langchain.utils import get_from_dict_or_env
|
15
15
|
from langchain.vectorstores.base import VectorStore
|
16
|
+
from loguru import logger
|
16
17
|
|
17
18
|
if TYPE_CHECKING:
|
18
19
|
from elasticsearch import Elasticsearch # noqa: F401
|
@@ -326,6 +327,49 @@ class ElasticKeywordsSearch(VectorStore, ABC):
|
|
326
327
|
response = client.search(index=index_name, body={'query': script_query, 'size': size})
|
327
328
|
return response
|
328
329
|
|
329
|
-
def
|
330
|
+
def delete_index(self, **kwargs: Any) -> None:
    """Drop the entire Elasticsearch index named by ``self.index_name``.

    Any keyword arguments are accepted and ignored.
    """
    # TODO: Check if this can be done in bulk
    indices_api = self.client.indices
    indices_api.delete(index=self.index_name)
|
333
|
+
|
334
|
+
def delete(
    self,
    ids: Optional[List[str]] = None,
    refresh_indices: Optional[bool] = True,
    **kwargs: Any,
) -> Optional[bool]:
    """Delete documents from the Elasticsearch index.

    Args:
        ids: List of ids of documents to delete. Must not be None.
        refresh_indices: Whether to refresh the index
            after deleting documents. Defaults to True.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        True if a bulk delete request was sent, False if ``ids`` was empty.

    Raises:
        ImportError: if the ``elasticsearch`` package is not installed.
        ValueError: if ``ids`` is None.
        BulkIndexError: if Elasticsearch reports failures other than 404
            (404s, i.e. ids that do not exist, are ignored).
    """
    try:
        from elasticsearch.helpers import BulkIndexError, bulk
    except ImportError:
        raise ImportError('Could not import elasticsearch python package. '
                          'Please install it with `pip install elasticsearch`.')

    if ids is None:
        raise ValueError('ids must be provided.')

    actions = [{'_op_type': 'delete', '_index': self.index_name, '_id': _id} for _id in ids]

    if not actions:
        logger.debug('No texts to delete from index')
        return False

    try:
        # bulk() returns (success_count, ...); 404s are ignored and not
        # counted as successes, so this reflects documents actually removed.
        deleted, _ = bulk(self.client, actions, refresh=refresh_indices, ignore_status=404)
    except BulkIndexError as e:
        logger.error(f'Error deleting texts: {e}')
        # Bulk error items are keyed by their op type ('delete' here), not
        # 'index', so take the first item's single value.
        first_item = e.errors[0] if e.errors else {}
        first_info = next(iter(first_item.values()), {}) if isinstance(first_item, dict) else {}
        first_error = first_info.get('error', {}) if isinstance(first_info, dict) else {}
        reason = first_error.get('reason') if isinstance(first_error, dict) else first_error
        logger.error(f'First error reason: {reason}')
        raise

    logger.debug(f'Deleted {deleted} texts from index')
    return True
|
@@ -108,6 +108,8 @@ bisheng_langchain/gpts/tools/get_current_time/tool.py,sha256=3uvk7Yu07qhZy1sBrFM
|
|
108
108
|
bisheng_langchain/input_output/__init__.py,sha256=sW_GB7MlrHYsqY1Meb_LeimQqNsMz1gH-00Tqb2BUyM,153
|
109
109
|
bisheng_langchain/input_output/input.py,sha256=I5YDmgbvvj1o2lO9wi8LE37wM0wP5jkhUREU32YrZMQ,1094
|
110
110
|
bisheng_langchain/input_output/output.py,sha256=6U-az6-Cwz665C2YmcH3SYctWVjPFjmW8s70CA_qphk,11585
|
111
|
+
bisheng_langchain/memory/__init__.py,sha256=TNqe5l5BqUv4wh3_UH28fYPWQXGLBUYn6QJHsr7vanI,82
|
112
|
+
bisheng_langchain/memory/redis.py,sha256=paz72ic5BfLXY6lj2cEbCxrTb8KVMnKMZmG9q7uh_9s,4291
|
111
113
|
bisheng_langchain/rag/__init__.py,sha256=Rm_cDxOJINt0H4bOeUo3JctPxaI6xKKXZcS-R_wkoGs,198
|
112
114
|
bisheng_langchain/rag/bisheng_rag_chain.py,sha256=2GMDUPJaW-D7tpOQ9qPt2vGZwmcXBS0UrcibO7J2S1g,5999
|
113
115
|
bisheng_langchain/rag/bisheng_rag_pipeline.py,sha256=neoBK3TtuQ07_WeuJCzYlvtsDQNepUa_68NT8VCgytw,13749
|
@@ -126,7 +128,7 @@ bisheng_langchain/rag/config/baseline_s2b_mix.yaml,sha256=rkPfzU2-mvjRrZ0zMHaQsn
|
|
126
128
|
bisheng_langchain/rag/config/baseline_v2.yaml,sha256=RP-DwIRIS_ZK8ixbXi2Z28rKqHD56pWmr2o2WWIwq3Y,2382
|
127
129
|
bisheng_langchain/rag/init_retrievers/__init__.py,sha256=qpLLAuqZPtumTlJj17Ie5AbDDmiUiDxYefg_pumqu-c,218
|
128
130
|
bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py,sha256=oRKZZpxlLQAtsubIcAXeXpf1a9h6Pt6uOtNTLeD2jps,2362
|
129
|
-
bisheng_langchain/rag/init_retrievers/keyword_retriever.py,sha256=
|
131
|
+
bisheng_langchain/rag/init_retrievers/keyword_retriever.py,sha256=NRT0fBx6HFR7j9IbRl_NBuqF7hnL-9v5GCqHpgnrfPQ,2523
|
130
132
|
bisheng_langchain/rag/init_retrievers/mix_retriever.py,sha256=Whxq4kjNPLsxnHcVo60usdFFwLTCD-1jO38q08LXkVQ,4653
|
131
133
|
bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py,sha256=RQ7QLEOOhBrkw-EimXVJqIGa96D-KkNDik2h9hzg9fU,3805
|
132
134
|
bisheng_langchain/rag/prompts/__init__.py,sha256=IUCq9gzqGQN_6IDk0D_F5t3mOUI_KbmSzYnnXoX4VKE,223
|
@@ -150,10 +152,10 @@ bisheng_langchain/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG
|
|
150
152
|
bisheng_langchain/utils/azure_dalle_image_generator.py,sha256=96-_nO4hDSwyPE4rSYop5SgJ-U9CE2un4bTdW0E5RGU,6582
|
151
153
|
bisheng_langchain/utils/requests.py,sha256=vWGKyNTxApVeaVdKxqACfIT1Q8wMy-jC3kUv2Ce9Mzc,8688
|
152
154
|
bisheng_langchain/vectorstores/__init__.py,sha256=zCZgDe7LyQ0iDkfcm5UJ5NxwKQSRHnqrsjx700Fy11M,213
|
153
|
-
bisheng_langchain/vectorstores/elastic_keywords_search.py,sha256=
|
155
|
+
bisheng_langchain/vectorstores/elastic_keywords_search.py,sha256=inZarhahRaesrvLqyeRCMQvHGAASY53opEVA0_o8S14,14901
|
154
156
|
bisheng_langchain/vectorstores/milvus.py,sha256=xh7NokraKg_Xc9ofz0RVfJ_I36ftnprLJtV-1NfaeyQ,37162
|
155
157
|
bisheng_langchain/vectorstores/retriever.py,sha256=hj4nAAl352EV_ANnU2OHJn7omCH3nBK82ydo14KqMH4,4353
|
156
|
-
bisheng_langchain-0.3.
|
157
|
-
bisheng_langchain-0.3.
|
158
|
-
bisheng_langchain-0.3.
|
159
|
-
bisheng_langchain-0.3.
|
158
|
+
bisheng_langchain-0.3.7.dist-info/METADATA,sha256=QmKT4P-W7klb8-YIRFq1Kqh8uHfq0454b9sOMgATjy4,2471
|
159
|
+
bisheng_langchain-0.3.7.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
160
|
+
bisheng_langchain-0.3.7.dist-info/top_level.txt,sha256=Z6pPNyCo4ihyr9iqGQbH8sJiC4dAUwA_mAyGRQB5_Fs,18
|
161
|
+
bisheng_langchain-0.3.7.dist-info/RECORD,,
|
File without changes
|