bisheng-langchain 0.3.6.dev1__py3-none-any.whl → 0.3.7.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/chains/qa_generation/base_v2.py +33 -14
- bisheng_langchain/memory/__init__.py +3 -0
- bisheng_langchain/memory/redis.py +104 -0
- {bisheng_langchain-0.3.6.dev1.dist-info → bisheng_langchain-0.3.7.dev2.dist-info}/METADATA +1 -1
- {bisheng_langchain-0.3.6.dev1.dist-info → bisheng_langchain-0.3.7.dev2.dist-info}/RECORD +7 -5
- {bisheng_langchain-0.3.6.dev1.dist-info → bisheng_langchain-0.3.7.dev2.dist-info}/WHEEL +1 -1
- {bisheng_langchain-0.3.6.dev1.dist-info → bisheng_langchain-0.3.7.dev2.dist-info}/top_level.txt +0 -0
@@ -134,6 +134,8 @@ class TrainsetGenerator:
|
|
134
134
|
chunk_size: int = 1024,
|
135
135
|
seed: int = 42,
|
136
136
|
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
|
137
|
+
filter_lowquality_context: bool = False,
|
138
|
+
filter_lowquality_question: bool = False,
|
137
139
|
answer_prompt: Optional[HumanMessagePromptTemplate] = ANSWER_FORMULATE,
|
138
140
|
) -> None:
|
139
141
|
self.generator_llm = generator_llm
|
@@ -152,6 +154,8 @@ class TrainsetGenerator:
|
|
152
154
|
self.threshold = 5.0
|
153
155
|
self.rng = default_rng(seed)
|
154
156
|
self.prompt = prompt
|
157
|
+
self.filter_lowquality_context = filter_lowquality_context
|
158
|
+
self.filter_lowquality_question = filter_lowquality_question
|
155
159
|
if answer_prompt is None:
|
156
160
|
answer_prompt = ANSWER_FORMULATE
|
157
161
|
self.answer_prompt = answer_prompt
|
@@ -163,6 +167,8 @@ class TrainsetGenerator:
|
|
163
167
|
chunk_size: int = 512,
|
164
168
|
trainset_distribution: dict = DEFAULT_TRAIN_DISTRIBUTION,
|
165
169
|
prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
|
170
|
+
filter_lowquality_context: bool = False,
|
171
|
+
filter_lowquality_question: bool = False,
|
166
172
|
answer_prompt: Optional[PromptTemplate] = ANSWER_FORMULATE,
|
167
173
|
):
|
168
174
|
generator_llm = llm
|
@@ -173,6 +179,8 @@ class TrainsetGenerator:
|
|
173
179
|
chunk_size=chunk_size,
|
174
180
|
trainset_distribution=trainset_distribution,
|
175
181
|
prompt=prompt,
|
182
|
+
filter_lowquality_context=filter_lowquality_context,
|
183
|
+
filter_lowquality_question=filter_lowquality_question,
|
176
184
|
answer_prompt=answer_prompt,
|
177
185
|
)
|
178
186
|
|
@@ -316,14 +324,17 @@ class TrainsetGenerator:
|
|
316
324
|
)
|
317
325
|
|
318
326
|
text_chunk = " ".join([node.get_content() for node in nodes])
|
319
|
-
|
320
|
-
|
321
|
-
|
327
|
+
if self.filter_lowquality_context:
|
328
|
+
score = self._filter_context(text_chunk)
|
329
|
+
if not score:
|
330
|
+
continue
|
322
331
|
seed_question = self._seed_question(text_chunk)
|
323
332
|
|
324
333
|
question = seed_question
|
325
|
-
|
326
|
-
|
334
|
+
if self.filter_lowquality_question:
|
335
|
+
is_valid_question = self._filter_question(question)
|
336
|
+
else:
|
337
|
+
is_valid_question = True
|
327
338
|
if is_valid_question:
|
328
339
|
context = [text_chunk] * len(question.split("\n"))
|
329
340
|
is_conv = len(context) > 1
|
@@ -361,6 +372,8 @@ class QAGenerationChainV2(Chain):
|
|
361
372
|
llm: BaseLanguageModel,
|
362
373
|
k: Optional[int] = None,
|
363
374
|
chunk_size: int = 512,
|
375
|
+
filter_lowquality_context: bool = False,
|
376
|
+
filter_lowquality_question: bool = False,
|
364
377
|
question_prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
|
365
378
|
answer_prompt: Optional[HumanMessagePromptTemplate] = ANSWER_FORMULATE,
|
366
379
|
**kwargs: Any,
|
@@ -377,8 +390,14 @@ class QAGenerationChainV2(Chain):
|
|
377
390
|
Returns:
|
378
391
|
a QAGenerationChain class
|
379
392
|
"""
|
380
|
-
generator = TrainsetGenerator.from_default(
|
381
|
-
|
393
|
+
generator = TrainsetGenerator.from_default(
|
394
|
+
llm,
|
395
|
+
chunk_size=chunk_size,
|
396
|
+
prompt=question_prompt,
|
397
|
+
answer_prompt=answer_prompt,
|
398
|
+
filter_lowquality_context=filter_lowquality_context,
|
399
|
+
filter_lowquality_question=filter_lowquality_question
|
400
|
+
)
|
382
401
|
return cls(documents=documents, generator=generator, k=k, **kwargs)
|
383
402
|
|
384
403
|
@property
|
@@ -405,14 +424,14 @@ class QAGenerationChainV2(Chain):
|
|
405
424
|
dataset = self.generator.generate(documents=self.documents, train_size=self.k)
|
406
425
|
df = dataset.to_pandas()
|
407
426
|
qa_pairs = df.to_dict("records")
|
408
|
-
qa =
|
427
|
+
qa = []
|
409
428
|
for pair in qa_pairs:
|
410
|
-
qa
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
429
|
+
qa.append({
|
430
|
+
"question": pair["question"],
|
431
|
+
"answer": pair["ground_truth"][0],
|
432
|
+
"context": pair["ground_truth_context"][0],
|
433
|
+
})
|
434
|
+
qa = f'```json\n{json.dumps(qa, ensure_ascii=False, indent=4)}\n```'
|
416
435
|
return {self.output_key: qa}
|
417
436
|
|
418
437
|
async def _acall(
|
@@ -0,0 +1,104 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
4
|
+
import redis
|
5
|
+
from langchain.memory.chat_memory import BaseChatMemory
|
6
|
+
from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage, get_buffer_string,
|
7
|
+
message_to_dict, messages_from_dict)
|
8
|
+
from langchain_core.pydantic_v1 import root_validator
|
9
|
+
from pydantic import Field
|
10
|
+
|
11
|
+
|
12
|
+
class ConversationRedisMemory(BaseChatMemory):
    """Conversation memory that keeps the message history in a Redis list.

    Each turn is stored as two JSON-serialised message dicts (human, then AI)
    pushed onto the head of the list at key ``redis_prefix + session_id``.
    Reading the history fetches the whole list and reverses it so the oldest
    message comes first.
    """

    # Client built from ``redis_url`` by ``validate_environment``; excluded
    # from serialisation.
    redis_client: redis.Redis = Field(default=None, exclude=True)
    # Role prefixes used when the history is rendered as a single string.
    human_prefix: str = 'Human'
    ai_prefix: str = 'AI'
    # Combined with ``redis_prefix`` to form the Redis key of this conversation.
    session_id: str = 'session'
    memory_key: str = 'history'  #: :meta private:
    # Connection URL handed to ``redis.ConnectionPool.from_url``.
    redis_url: str
    redis_prefix: str = 'redis_buffer_'
    # Expiry in seconds re-applied to the key after every save; falsy disables it.
    ttl: Optional[int] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Create the Redis client from ``redis_url``.

        Raises:
            ValueError: if ``redis_url`` is missing or empty.
        """
        redis_url = values.get('redis_url')
        if not redis_url:
            raise ValueError('Redis URL must be set')
        pool = redis.ConnectionPool.from_url(redis_url, max_connections=1)
        values['redis_client'] = redis.StrictRedis(connection_pool=pool)
        return values

    @property
    def buffer(self) -> Any:
        """Return the history as messages or as a string, per ``return_messages``."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    async def abuffer(self) -> Any:
        """Async counterpart of ``buffer``."""
        return (await self.abuffer_as_messages()
                if self.return_messages else await self.abuffer_as_str())

    def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
        """Render ``messages`` as one string using the configured role prefixes."""
        return get_buffer_string(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def buffer_as_str(self) -> str:
        """Expose the history as a string, regardless of ``return_messages``."""
        return self._buffer_as_str(self.buffer_as_messages)

    async def abuffer_as_str(self) -> str:
        """Async counterpart of ``buffer_as_str``."""
        messages = await self.abuffer_as_messages()
        return self._buffer_as_str(messages)

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Expose the history as a list of messages, oldest first."""
        key = self.redis_prefix + self.session_id
        redis_value = self.redis_client.lrange(key, 0, -1)
        # Entries are LPUSHed, so the list head is the newest message; reverse
        # it to restore chronological order.  The client returns ``bytes``
        # unless it was configured with ``decode_responses``, in which case
        # the entries are already ``str`` and must not be decoded again.
        items = [
            json.loads(m.decode('utf-8') if isinstance(m, (bytes, bytearray)) else m)
            for m in redis_value[::-1]
        ]
        return messages_from_dict(items)

    async def abuffer_as_messages(self) -> List[BaseMessage]:
        """Async counterpart of ``buffer_as_messages``.

        NOTE: this delegates to the synchronous property, so the Redis call
        itself still blocks.
        """
        # The original evaluated the property but dropped the result, so
        # async callers always received ``None``.
        return self.buffer_as_messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the history buffer under ``memory_key``."""
        return {self.memory_key: self.buffer}

    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Async counterpart of ``load_memory_variables``."""
        buffer = await self.abuffer()
        return {self.memory_key: buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Append one human/AI exchange to the Redis list and refresh its TTL."""
        input_str, output_str = self._get_input_output(inputs, outputs)

        input_message_str = json.dumps(message_to_dict(HumanMessage(content=input_str)),
                                       ensure_ascii=False)
        output_message_str = json.dumps(message_to_dict(AIMessage(content=output_str)),
                                        ensure_ascii=False)
        key = self.redis_prefix + self.session_id
        # A multi-value LPUSH inserts the values left to right at the head,
        # giving the same order as two separate pushes (AI message ends up at
        # the head) while keeping the pair together in a single command.
        self.redis_client.lpush(key, input_message_str, output_message_str)
        if self.ttl:
            self.redis_client.expire(key, self.ttl)
|
@@ -23,7 +23,7 @@ bisheng_langchain/chains/conversational_retrieval/__init__.py,sha256=47DEQpj8HBS
|
|
23
23
|
bisheng_langchain/chains/conversational_retrieval/base.py,sha256=XiqBqov6No-wTVCou6qyMT5p2JQgoQI7OLQOYH8XUos,5313
|
24
24
|
bisheng_langchain/chains/qa_generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
25
25
|
bisheng_langchain/chains/qa_generation/base.py,sha256=VYGmLDB0bnlDQ6T8ivLP55wwFbMo9HOzlPEDUuRx5fU,4148
|
26
|
-
bisheng_langchain/chains/qa_generation/base_v2.py,sha256=
|
26
|
+
bisheng_langchain/chains/qa_generation/base_v2.py,sha256=2F2kGe3ermJraQu4oC-m8vm_ENBy_Zi4uHrJDcSOeJw,15460
|
27
27
|
bisheng_langchain/chains/qa_generation/prompt.py,sha256=4eJk9aDUYDN1qaaYRPy9EobCIncnwS8BbQaDFzzePtM,1944
|
28
28
|
bisheng_langchain/chains/qa_generation/prompt_v2.py,sha256=sQLanA_iOnLqrUIwzfTOTANt-1vJ44CM54HFDU8Jo1Q,8938
|
29
29
|
bisheng_langchain/chains/question_answering/__init__.py,sha256=_gOZMc-SWprK6xc-Jj64jcr9nc-G4YkZbEYwfJNq_bY,8795
|
@@ -108,6 +108,8 @@ bisheng_langchain/gpts/tools/get_current_time/tool.py,sha256=3uvk7Yu07qhZy1sBrFM
|
|
108
108
|
bisheng_langchain/input_output/__init__.py,sha256=sW_GB7MlrHYsqY1Meb_LeimQqNsMz1gH-00Tqb2BUyM,153
|
109
109
|
bisheng_langchain/input_output/input.py,sha256=I5YDmgbvvj1o2lO9wi8LE37wM0wP5jkhUREU32YrZMQ,1094
|
110
110
|
bisheng_langchain/input_output/output.py,sha256=6U-az6-Cwz665C2YmcH3SYctWVjPFjmW8s70CA_qphk,11585
|
111
|
+
bisheng_langchain/memory/__init__.py,sha256=TNqe5l5BqUv4wh3_UH28fYPWQXGLBUYn6QJHsr7vanI,82
|
112
|
+
bisheng_langchain/memory/redis.py,sha256=paz72ic5BfLXY6lj2cEbCxrTb8KVMnKMZmG9q7uh_9s,4291
|
111
113
|
bisheng_langchain/rag/__init__.py,sha256=Rm_cDxOJINt0H4bOeUo3JctPxaI6xKKXZcS-R_wkoGs,198
|
112
114
|
bisheng_langchain/rag/bisheng_rag_chain.py,sha256=2GMDUPJaW-D7tpOQ9qPt2vGZwmcXBS0UrcibO7J2S1g,5999
|
113
115
|
bisheng_langchain/rag/bisheng_rag_pipeline.py,sha256=neoBK3TtuQ07_WeuJCzYlvtsDQNepUa_68NT8VCgytw,13749
|
@@ -153,7 +155,7 @@ bisheng_langchain/vectorstores/__init__.py,sha256=zCZgDe7LyQ0iDkfcm5UJ5NxwKQSRHn
|
|
153
155
|
bisheng_langchain/vectorstores/elastic_keywords_search.py,sha256=inZarhahRaesrvLqyeRCMQvHGAASY53opEVA0_o8S14,14901
|
154
156
|
bisheng_langchain/vectorstores/milvus.py,sha256=xh7NokraKg_Xc9ofz0RVfJ_I36ftnprLJtV-1NfaeyQ,37162
|
155
157
|
bisheng_langchain/vectorstores/retriever.py,sha256=hj4nAAl352EV_ANnU2OHJn7omCH3nBK82ydo14KqMH4,4353
|
156
|
-
bisheng_langchain-0.3.
|
157
|
-
bisheng_langchain-0.3.
|
158
|
-
bisheng_langchain-0.3.
|
159
|
-
bisheng_langchain-0.3.
|
158
|
+
bisheng_langchain-0.3.7.dev2.dist-info/METADATA,sha256=rPLG8c2G8ZAOn3mjAcIP4evhXJbe-CMeUQc9gtuIdCc,2476
|
159
|
+
bisheng_langchain-0.3.7.dev2.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
160
|
+
bisheng_langchain-0.3.7.dev2.dist-info/top_level.txt,sha256=Z6pPNyCo4ihyr9iqGQbH8sJiC4dAUwA_mAyGRQB5_Fs,18
|
161
|
+
bisheng_langchain-0.3.7.dev2.dist-info/RECORD,,
|
{bisheng_langchain-0.3.6.dev1.dist-info → bisheng_langchain-0.3.7.dev2.dist-info}/top_level.txt
RENAMED
File without changes
|