bisheng-langchain 0.3.6.dev1__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -134,6 +134,8 @@ class TrainsetGenerator:
134
134
  chunk_size: int = 1024,
135
135
  seed: int = 42,
136
136
  prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
137
+ filter_lowquality_context: bool = False,
138
+ filter_lowquality_question: bool = False,
137
139
  answer_prompt: Optional[HumanMessagePromptTemplate] = ANSWER_FORMULATE,
138
140
  ) -> None:
139
141
  self.generator_llm = generator_llm
@@ -152,6 +154,8 @@ class TrainsetGenerator:
152
154
  self.threshold = 5.0
153
155
  self.rng = default_rng(seed)
154
156
  self.prompt = prompt
157
+ self.filter_lowquality_context = filter_lowquality_context
158
+ self.filter_lowquality_question = filter_lowquality_question
155
159
  if answer_prompt is None:
156
160
  answer_prompt = ANSWER_FORMULATE
157
161
  self.answer_prompt = answer_prompt
@@ -163,6 +167,8 @@ class TrainsetGenerator:
163
167
  chunk_size: int = 512,
164
168
  trainset_distribution: dict = DEFAULT_TRAIN_DISTRIBUTION,
165
169
  prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
170
+ filter_lowquality_context: bool = False,
171
+ filter_lowquality_question: bool = False,
166
172
  answer_prompt: Optional[PromptTemplate] = ANSWER_FORMULATE,
167
173
  ):
168
174
  generator_llm = llm
@@ -173,6 +179,8 @@ class TrainsetGenerator:
173
179
  chunk_size=chunk_size,
174
180
  trainset_distribution=trainset_distribution,
175
181
  prompt=prompt,
182
+ filter_lowquality_context=filter_lowquality_context,
183
+ filter_lowquality_question=filter_lowquality_question,
176
184
  answer_prompt=answer_prompt,
177
185
  )
178
186
 
@@ -316,14 +324,17 @@ class TrainsetGenerator:
316
324
  )
317
325
 
318
326
  text_chunk = " ".join([node.get_content() for node in nodes])
319
- score = self._filter_context(text_chunk)
320
- if not score:
321
- continue
327
+ if self.filter_lowquality_context:
328
+ score = self._filter_context(text_chunk)
329
+ if not score:
330
+ continue
322
331
  seed_question = self._seed_question(text_chunk)
323
332
 
324
333
  question = seed_question
325
- # is_valid_question = self._filter_question(question)
326
- is_valid_question = True
334
+ if self.filter_lowquality_question:
335
+ is_valid_question = self._filter_question(question)
336
+ else:
337
+ is_valid_question = True
327
338
  if is_valid_question:
328
339
  context = [text_chunk] * len(question.split("\n"))
329
340
  is_conv = len(context) > 1
@@ -361,6 +372,8 @@ class QAGenerationChainV2(Chain):
361
372
  llm: BaseLanguageModel,
362
373
  k: Optional[int] = None,
363
374
  chunk_size: int = 512,
375
+ filter_lowquality_context: bool = False,
376
+ filter_lowquality_question: bool = False,
364
377
  question_prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
365
378
  answer_prompt: Optional[HumanMessagePromptTemplate] = ANSWER_FORMULATE,
366
379
  **kwargs: Any,
@@ -377,8 +390,14 @@ class QAGenerationChainV2(Chain):
377
390
  Returns:
378
391
  a QAGenerationChain class
379
392
  """
380
- generator = TrainsetGenerator.from_default(llm, chunk_size=chunk_size, prompt=question_prompt,
381
- answer_prompt=answer_prompt)
393
+ generator = TrainsetGenerator.from_default(
394
+ llm,
395
+ chunk_size=chunk_size,
396
+ prompt=question_prompt,
397
+ answer_prompt=answer_prompt,
398
+ filter_lowquality_context=filter_lowquality_context,
399
+ filter_lowquality_question=filter_lowquality_question
400
+ )
382
401
  return cls(documents=documents, generator=generator, k=k, **kwargs)
383
402
 
384
403
  @property
@@ -405,14 +424,14 @@ class QAGenerationChainV2(Chain):
405
424
  dataset = self.generator.generate(documents=self.documents, train_size=self.k)
406
425
  df = dataset.to_pandas()
407
426
  qa_pairs = df.to_dict("records")
408
- qa = ''
427
+ qa = []
409
428
  for pair in qa_pairs:
410
- qa += json.dumps(
411
- {
412
- "question": pair["question"],
413
- "answer": pair["ground_truth"][0],
414
- "context": pair["ground_truth_context"][0],
415
- }, ensure_ascii=False)
429
+ qa.append({
430
+ "question": pair["question"],
431
+ "answer": pair["ground_truth"][0],
432
+ "context": pair["ground_truth_context"][0],
433
+ })
434
+ qa = f'```json\n{json.dumps(qa, ensure_ascii=False, indent=4)}\n```'
416
435
  return {self.output_key: qa}
417
436
 
418
437
  async def _acall(
@@ -0,0 +1,3 @@
1
+ from .redis import ConversationRedisMemory
2
+
3
+ __all__ = ['ConversationRedisMemory']
@@ -0,0 +1,104 @@
1
+ import json
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import redis
5
+ from langchain.memory.chat_memory import BaseChatMemory
6
+ from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage, get_buffer_string,
7
+ message_to_dict, messages_from_dict)
8
+ from langchain_core.pydantic_v1 import root_validator
9
+ from pydantic import Field
10
+
11
+
12
class ConversationRedisMemory(BaseChatMemory):
    """Conversation memory that keeps its message history in a Redis list.

    Each exchange passed to :meth:`save_context` is serialized to JSON and
    pushed onto the list stored under ``redis_prefix + session_id``. The
    ``buffer*`` accessors read that list back and rebuild the messages.
    """

    # Client created from ``redis_url`` by ``validate_environment``; the field
    # is declared with ``exclude=True``.
    redis_client: redis.Redis = Field(default=None, exclude=True)
    human_prefix: str = 'Human'
    ai_prefix: str = 'AI'
    session_id: str = 'session'
    memory_key: str = 'history'  #: :meta private:
    redis_url: str
    redis_prefix: str = 'redis_buffer_'
    # Passed to ``EXPIRE`` after every save; a falsy value skips the call.
    ttl: Optional[int] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Create the Redis client from ``redis_url``.

        Raises:
            ValueError: if ``redis_url`` is missing or empty.
        """
        redis_url = values.get('redis_url')
        if not redis_url:
            raise ValueError('Redis URL must be set')
        # The pool is capped at a single connection.
        pool = redis.ConnectionPool.from_url(redis_url, max_connections=1)
        values['redis_client'] = redis.StrictRedis(connection_pool=pool)
        return values

    def _redis_key(self) -> str:
        """Return the name of the Redis list holding this session's history."""
        return self.redis_prefix + self.session_id

    @property
    def buffer(self) -> Any:
        """Return the history as messages or as a string, per ``return_messages``."""
        return self.buffer_as_messages if self.return_messages else self.buffer_as_str

    async def abuffer(self) -> Any:
        """Async counterpart of :attr:`buffer`."""
        if self.return_messages:
            return await self.abuffer_as_messages()
        return await self.abuffer_as_str()

    def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
        """Render ``messages`` as one string using the configured prefixes."""
        return get_buffer_string(
            messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )

    @property
    def buffer_as_str(self) -> str:
        """Return the stored history rendered as a single string."""
        return self._buffer_as_str(self.buffer_as_messages)

    async def abuffer_as_str(self) -> str:
        """Async counterpart of :attr:`buffer_as_str`."""
        return self._buffer_as_str(await self.abuffer_as_messages())

    @property
    def buffer_as_messages(self) -> List[BaseMessage]:
        """Return the stored history as a list of messages, oldest first."""
        raw_items = self.redis_client.lrange(self._redis_key(), 0, -1)
        # Entries are written with LPUSH, so the list is newest-first; reverse
        # it to restore chronological order. ``json.loads`` accepts both bytes
        # and str, so this works whether or not the client decodes responses.
        items = [json.loads(raw) for raw in reversed(raw_items)]
        return messages_from_dict(items)

    async def abuffer_as_messages(self) -> List[BaseMessage]:
        """Async counterpart of :attr:`buffer_as_messages`.

        NOTE: the underlying Redis call is the synchronous one, so it is not
        awaited.
        """
        # The original evaluated the property and discarded the result, which
        # made every async accessor built on this method return ``None``.
        return self.buffer_as_messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the history buffer keyed by ``memory_key``."""
        return {self.memory_key: self.buffer}

    async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Async counterpart of :meth:`load_memory_variables`."""
        buffer = await self.abuffer()
        return {self.memory_key: buffer}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Append one human/AI exchange to the Redis list and refresh its TTL."""
        input_str, output_str = self._get_input_output(inputs, outputs)

        input_message_str = json.dumps(message_to_dict(HumanMessage(content=input_str)),
                                       ensure_ascii=False)
        output_message_str = json.dumps(message_to_dict(AIMessage(content=output_str)),
                                        ensure_ascii=False)

        key = self._redis_key()
        # One LPUSH with both values leaves the list in the same order as two
        # consecutive pushes (the AI message ends up at the head), but as a
        # single command no other writer can interleave between the pair.
        self.redis_client.lpush(key, input_message_str, output_message_str)
        if self.ttl:
            self.redis_client.expire(key, self.ttl)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bisheng-langchain
3
- Version: 0.3.6.dev1
3
+ Version: 0.3.7
4
4
  Summary: bisheng langchain modules
5
5
  Home-page: https://github.com/dataelement/bisheng
6
6
  Author: DataElem
@@ -23,7 +23,7 @@ bisheng_langchain/chains/conversational_retrieval/__init__.py,sha256=47DEQpj8HBS
23
23
  bisheng_langchain/chains/conversational_retrieval/base.py,sha256=XiqBqov6No-wTVCou6qyMT5p2JQgoQI7OLQOYH8XUos,5313
24
24
  bisheng_langchain/chains/qa_generation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  bisheng_langchain/chains/qa_generation/base.py,sha256=VYGmLDB0bnlDQ6T8ivLP55wwFbMo9HOzlPEDUuRx5fU,4148
26
- bisheng_langchain/chains/qa_generation/base_v2.py,sha256=ZtHEuNFwbE9txCGR3wx0oDAoj9V6bAxi3GXF8Z78cqQ,14580
26
+ bisheng_langchain/chains/qa_generation/base_v2.py,sha256=2F2kGe3ermJraQu4oC-m8vm_ENBy_Zi4uHrJDcSOeJw,15460
27
27
  bisheng_langchain/chains/qa_generation/prompt.py,sha256=4eJk9aDUYDN1qaaYRPy9EobCIncnwS8BbQaDFzzePtM,1944
28
28
  bisheng_langchain/chains/qa_generation/prompt_v2.py,sha256=sQLanA_iOnLqrUIwzfTOTANt-1vJ44CM54HFDU8Jo1Q,8938
29
29
  bisheng_langchain/chains/question_answering/__init__.py,sha256=_gOZMc-SWprK6xc-Jj64jcr9nc-G4YkZbEYwfJNq_bY,8795
@@ -108,6 +108,8 @@ bisheng_langchain/gpts/tools/get_current_time/tool.py,sha256=3uvk7Yu07qhZy1sBrFM
108
108
  bisheng_langchain/input_output/__init__.py,sha256=sW_GB7MlrHYsqY1Meb_LeimQqNsMz1gH-00Tqb2BUyM,153
109
109
  bisheng_langchain/input_output/input.py,sha256=I5YDmgbvvj1o2lO9wi8LE37wM0wP5jkhUREU32YrZMQ,1094
110
110
  bisheng_langchain/input_output/output.py,sha256=6U-az6-Cwz665C2YmcH3SYctWVjPFjmW8s70CA_qphk,11585
111
+ bisheng_langchain/memory/__init__.py,sha256=TNqe5l5BqUv4wh3_UH28fYPWQXGLBUYn6QJHsr7vanI,82
112
+ bisheng_langchain/memory/redis.py,sha256=paz72ic5BfLXY6lj2cEbCxrTb8KVMnKMZmG9q7uh_9s,4291
111
113
  bisheng_langchain/rag/__init__.py,sha256=Rm_cDxOJINt0H4bOeUo3JctPxaI6xKKXZcS-R_wkoGs,198
112
114
  bisheng_langchain/rag/bisheng_rag_chain.py,sha256=2GMDUPJaW-D7tpOQ9qPt2vGZwmcXBS0UrcibO7J2S1g,5999
113
115
  bisheng_langchain/rag/bisheng_rag_pipeline.py,sha256=neoBK3TtuQ07_WeuJCzYlvtsDQNepUa_68NT8VCgytw,13749
@@ -153,7 +155,7 @@ bisheng_langchain/vectorstores/__init__.py,sha256=zCZgDe7LyQ0iDkfcm5UJ5NxwKQSRHn
153
155
  bisheng_langchain/vectorstores/elastic_keywords_search.py,sha256=inZarhahRaesrvLqyeRCMQvHGAASY53opEVA0_o8S14,14901
154
156
  bisheng_langchain/vectorstores/milvus.py,sha256=xh7NokraKg_Xc9ofz0RVfJ_I36ftnprLJtV-1NfaeyQ,37162
155
157
  bisheng_langchain/vectorstores/retriever.py,sha256=hj4nAAl352EV_ANnU2OHJn7omCH3nBK82ydo14KqMH4,4353
156
- bisheng_langchain-0.3.6.dev1.dist-info/METADATA,sha256=KG32YRknnVoAxFzVKE_qMMQBjbhZen046fXQYyhXQvs,2476
157
- bisheng_langchain-0.3.6.dev1.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
158
- bisheng_langchain-0.3.6.dev1.dist-info/top_level.txt,sha256=Z6pPNyCo4ihyr9iqGQbH8sJiC4dAUwA_mAyGRQB5_Fs,18
159
- bisheng_langchain-0.3.6.dev1.dist-info/RECORD,,
158
+ bisheng_langchain-0.3.7.dist-info/METADATA,sha256=QmKT4P-W7klb8-YIRFq1Kqh8uHfq0454b9sOMgATjy4,2471
159
+ bisheng_langchain-0.3.7.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
160
+ bisheng_langchain-0.3.7.dist-info/top_level.txt,sha256=Z6pPNyCo4ihyr9iqGQbH8sJiC4dAUwA_mAyGRQB5_Fs,18
161
+ bisheng_langchain-0.3.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.44.0)
2
+ Generator: bdist_wheel (0.45.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5