bisheng-langchain 0.3.1.2__py3-none-any.whl → 0.3.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bisheng_langchain/chains/__init__.py +4 -1
- bisheng_langchain/chains/qa_generation/__init__.py +0 -0
- bisheng_langchain/chains/qa_generation/base.py +128 -0
- bisheng_langchain/chains/qa_generation/base_v2.py +413 -0
- bisheng_langchain/chains/qa_generation/prompt.py +53 -0
- bisheng_langchain/chains/qa_generation/prompt_v2.py +155 -0
- bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +36 -9
- bisheng_langchain/document_loaders/parsers/ellm_client.py +7 -9
- bisheng_langchain/document_loaders/universal_kv.py +4 -3
- bisheng_langchain/gpts/tools/api_tools/openapi.py +7 -7
- bisheng_langchain/rag/__init__.py +2 -0
- bisheng_langchain/rag/bisheng_rag_chain.py +164 -0
- bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +8 -2
- bisheng_langchain/rag/bisheng_rag_tool.py +47 -24
- bisheng_langchain/rag/config/baseline_caibao_v2.yaml +1 -1
- bisheng_langchain/rag/config/baseline_v2.yaml +3 -2
- bisheng_langchain/rag/prompts/prompt.py +1 -1
- bisheng_langchain/rag/qa_corpus/qa_generator.py +1 -1
- bisheng_langchain/rag/scoring/ragas_score.py +2 -2
- bisheng_langchain/rag/utils.py +27 -4
- bisheng_langchain/sql/__init__.py +3 -0
- bisheng_langchain/sql/base.py +120 -0
- bisheng_langchain/text_splitter.py +1 -1
- {bisheng_langchain-0.3.1.2.dist-info → bisheng_langchain-0.3.2.1.dist-info}/METADATA +3 -1
- {bisheng_langchain-0.3.1.2.dist-info → bisheng_langchain-0.3.2.1.dist-info}/RECORD +27 -20
- bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +0 -376
- {bisheng_langchain-0.3.1.2.dist-info → bisheng_langchain-0.3.2.1.dist-info}/WHEEL +0 -0
- {bisheng_langchain-0.3.1.2.dist-info → bisheng_langchain-0.3.2.1.dist-info}/top_level.txt +0 -0
@@ -5,10 +5,13 @@ from bisheng_langchain.chains.retrieval.retrieval_chain import RetrievalChain
|
|
5
5
|
from bisheng_langchain.chains.router.multi_rule import MultiRuleChain
|
6
6
|
from bisheng_langchain.chains.router.rule_router import RuleBasedRouter
|
7
7
|
from bisheng_langchain.chains.transform import TransformChain
|
8
|
+
from bisheng_langchain.chains.qa_generation.base import QAGenerationChain
|
9
|
+
from bisheng_langchain.chains.qa_generation.base_v2 import QAGenerationChainV2
|
8
10
|
|
9
11
|
from .loader_output import LoaderOutputChain
|
10
12
|
|
11
13
|
__all__ = [
|
12
14
|
'StuffDocumentsChain', 'LoaderOutputChain', 'AutoGenChain', 'RuleBasedRouter',
|
13
|
-
'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain', 'TransformChain'
|
15
|
+
'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain', 'TransformChain',
|
16
|
+
'QAGenerationChain', 'QAGenerationChainV2'
|
14
17
|
]
|
File without changes
|
@@ -0,0 +1,128 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import re
|
5
|
+
import logging
|
6
|
+
from langchain.docstore.document import Document
|
7
|
+
from typing import Any, Dict, List, Optional
|
8
|
+
|
9
|
+
from langchain_core.callbacks import CallbackManagerForChainRun
|
10
|
+
from langchain_core.language_models import BaseLanguageModel
|
11
|
+
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
|
12
|
+
from langchain_core.pydantic_v1 import Field
|
13
|
+
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
14
|
+
|
15
|
+
from langchain.chains.base import Chain
|
16
|
+
from langchain.chains.llm import LLMChain
|
17
|
+
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR, CHAT_PROMPT, PROMPT
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
|
22
|
+
def parse_json(input_str: str) -> str:
    """Extract the JSON payload from an LLM response.

    If ``input_str`` contains a Markdown code fence (optionally tagged
    ``json`` in any letter case), the text between the first opening fence
    and the last closing fence is used; otherwise the whole string is used.
    The result is stripped of surrounding whitespace and of any remaining
    triple backticks.

    Args:
        input_str: raw model output.

    Returns:
        The candidate JSON text. It is not validated here.
    """
    # Greedy ``.*`` spans from the first fence to the last one. IGNORECASE so
    # that a ```JSON / ```Json tag is not left in front of the payload.
    match = re.search(r'```(json)?(.*)```', input_str, re.DOTALL | re.IGNORECASE)
    out_str = input_str if match is None else match.group(2)
    return out_str.strip().replace('```', '')
|
32
|
+
|
33
|
+
|
34
|
+
class QAGenerationChain(Chain):
    """Generate question/answer pairs from a fixed set of documents.

    The ``page_content`` of all ``documents`` is joined with blank lines,
    re-split with ``text_splitter``, and each resulting chunk is sent once
    through ``llm_chain``. At most ``k`` chunks produce output (all chunks
    when ``k`` is None). The raw LLM outputs are concatenated without a
    separator and returned under ``output_key``.
    """

    documents: List[Document]
    llm_chain: LLMChain
    """LLM Chain that generates responses from user input and context."""
    k: Optional[int] = None
    """Number of questions to generate."""
    text_splitter: TextSplitter = Field(
        default=RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", " ", ""],
            chunk_size=1000,
            chunk_overlap=100,
        )
    )
    """Text splitter that splits the input into chunks."""
    input_key: str = "begin"
    """Key of the input to the chain."""
    output_key: str = "questions"
    """Key of the output of the chain."""

    @classmethod
    def from_llm(
        cls,
        documents: List[Document],
        llm: BaseLanguageModel,
        k: Optional[int] = None,
        chunk_size: int = 512,
        prompt: Optional[ChatPromptTemplate] = CHAT_PROMPT,
        **kwargs: Any,
    ) -> QAGenerationChain:
        """
        Create a QAGenerationChain from a language model.

        Args:
            documents: documents to generate question/answer pairs from
            llm: a language model
            k: maximum number of chunks to generate a pair for (None = all)
            chunk_size: chunk size of the splitter (overlap is fixed to 50)
            prompt: a prompt template; when None, one is chosen for ``llm``
                by PROMPT_SELECTOR
            **kwargs: additional arguments passed to the chain constructor

        Returns:
            a QAGenerationChain instance
        """
        # NOTE(review): CHAT_PROMPT / PROMPT_SELECTOR are imported from
        # ``langchain.chains.qa_generation.prompt``, not from this package's
        # own ``prompt`` module. Confirm which prompt set is intended.
        _prompt = PROMPT_SELECTOR.get_prompt(llm) if prompt is None else prompt
        chain = LLMChain(llm=llm, prompt=_prompt)
        text_splitter = RecursiveCharacterTextSplitter(
            separators=["\n\n", "\n", " ", ""],
            chunk_size=chunk_size,
            chunk_overlap=50,
        )
        return cls(documents=documents, llm_chain=chain, k=k, text_splitter=text_splitter, **kwargs)

    @property
    def _chain_type(self) -> str:
        raise NotImplementedError

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the LLM over each chunk and return the concatenated raw outputs.

        ``inputs`` is not read; the chain works only on ``self.documents``.
        Chunks whose LLM call raises are logged and skipped; they do not
        count towards ``k``.
        """
        contents = '\n\n'.join(doc.page_content for doc in self.documents)
        docs = self.text_splitter.create_documents([contents])
        # len(qa) = min(len(docs), self.k)
        logger.info(f"Split {len(docs)} documents. Gen qa num: min({len(docs)}, {self.k}).")
        outputs: List[str] = []
        for doc in docs:
            # Check the quota before calling the LLM. Checking it only after a
            # successful call meant k == 0 still produced one QA pair.
            if self.k is not None and len(outputs) >= self.k:
                break
            try:
                results = self.llm_chain.generate([{"text": doc.page_content}], run_manager=run_manager)
                outputs.append(results.generations[0][0].text)
            except Exception as e:
                # Best effort: one failing chunk must not abort the whole run.
                logger.error(f"Failed to generate QA for chunk. Error: {e}")
                continue
        return {self.output_key: ''.join(outputs)}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        # Delegates to the synchronous implementation.
        # NOTE(review): ``run_manager`` here is presumably an async callback
        # manager passed on to sync code. Confirm this is acceptable.
        output = self._call(inputs, run_manager)
        return output
|
@@ -0,0 +1,413 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import re
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
import typing as t
|
7
|
+
import warnings
|
8
|
+
from typing import Any, Dict, List, Optional
|
9
|
+
from collections import defaultdict, namedtuple
|
10
|
+
from dataclasses import dataclass
|
11
|
+
from langchain_core.callbacks import CallbackManagerForChainRun
|
12
|
+
from langchain_core.language_models import BaseLanguageModel
|
13
|
+
|
14
|
+
try:
|
15
|
+
from llama_index.node_parser import SimpleNodeParser
|
16
|
+
from llama_index.readers.schema import Document as LlamaindexDocument
|
17
|
+
from llama_index.schema import BaseNode
|
18
|
+
except ImportError:
|
19
|
+
raise ImportError(
|
20
|
+
"llama_index must be installed to use this function. "
|
21
|
+
"Please, install it with `pip install llama_index`."
|
22
|
+
)
|
23
|
+
import numpy as np
|
24
|
+
import numpy.testing as npt
|
25
|
+
import pandas as pd
|
26
|
+
from langchain.prompts import ChatPromptTemplate
|
27
|
+
from langchain.docstore.document import Document
|
28
|
+
# from langchain.schema.document import Document as LangchainDocument
|
29
|
+
from langchain.chains.base import Chain
|
30
|
+
from numpy.random import default_rng
|
31
|
+
from tqdm import tqdm
|
32
|
+
from .prompt_v2 import (
|
33
|
+
SEED_QUESTION_CHAT_PROMPT,
|
34
|
+
SCORE_CONTEXT_CHAT_PROMPT,
|
35
|
+
FILTER_QUESTION_CHAT_PROMPT,
|
36
|
+
ANSWER_FORMULATE,
|
37
|
+
)
|
38
|
+
from .base import parse_json
|
39
|
+
|
40
|
+
logger = logging.getLogger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
def load_as_score(text):
    """
    Validate ``text`` and return it as a numeric score.

    Surrounding whitespace is ignored. Text that is not a plain decimal
    number made only of digits and dots (e.g. ``"7"``, ``"7.5"``) emits an
    "Invalid score" warning and yields ``0.0``.
    """
    candidate = text.strip()
    # Parse with float() instead of eval(). The old regex-then-eval approach
    # raised SyntaxError on regex-passing inputs such as "1.2.3", "." or "05".
    if re.fullmatch(r"[\d.]+", candidate):
        try:
            return float(candidate)
        except ValueError:
            pass
    warnings.warn("Invalid score")
    return 0.0
|
54
|
+
|
55
|
+
|
56
|
+
def load_as_json(text):
    """
    Decode ``text`` as JSON after unwrapping any Markdown code fence.

    Returns the decoded value. When decoding fails, an "Invalid json"
    warning is emitted and an empty dict is returned instead.
    """
    try:
        decoded = json.loads(parse_json(text))
    except ValueError as err:
        warnings.warn(f"Invalid json: {err}")
        decoded = {}
    return decoded
|
67
|
+
|
68
|
+
|
69
|
+
# Default probability per question type: every sample is "simple". The key
# order matters because TrainsetGenerator turns the values into cumulative
# probabilities in this order.
DEFAULT_TRAIN_DISTRIBUTION = dict(
    simple=1.0,
    reasoning=0.0,
    multi_context=0.0,
    conditional=0.0,
)

# One generated sample, as built by TrainsetGenerator.generate: the question,
# the list of contexts and the list of answers it came from, the question
# type, and an episode_done flag.
DataRow = namedtuple(
    "DataRow",
    "question ground_truth_context ground_truth question_type episode_done",
)
|
86
|
+
|
87
|
+
|
88
|
+
@dataclass
class TrainDataset:
    """
    Container for the rows produced by ``TrainsetGenerator.generate``.
    """

    train_data: t.List[DataRow]

    def to_pandas(self) -> pd.DataFrame:
        """Return the rows as a DataFrame with one column per DataRow field."""
        records = [
            {
                "question": row.question,
                "ground_truth_context": row.ground_truth_context,
                "ground_truth": row.ground_truth,
                "question_type": row.question_type,
                "episode_done": row.episode_done,
            }
            for row in self.train_data
        ]
        return pd.DataFrame.from_records(records)
|
108
|
+
|
109
|
+
|
110
|
+
class TrainsetGenerator:
    """
    Ragas-style train set generator.

    Attributes
    ----------
    generator_llm: BaseLanguageModel
        LLM used to write seed questions and their answers.
    critic_llm: BaseLanguageModel
        LLM used to score contexts. It is also used by ``_filter_question``,
        which ``generate`` currently does not call.
    chunk_size: int
        The chunk size of nodes created from data.
    trainset_distribution : dict
        Probability of each question type; the values must sum to 1.
        Defaults to ``DEFAULT_TRAIN_DISTRIBUTION``. It is stored internally as
        cumulative probabilities.
    """

    def __init__(
        self,
        generator_llm: BaseLanguageModel,
        critic_llm: BaseLanguageModel,
        trainset_distribution: t.Optional[t.Dict[str, float]] = None,
        chunk_size: int = 1024,
        seed: int = 42,
        prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
    ) -> None:
        self.generator_llm = generator_llm
        self.critic_llm = critic_llm
        trainset_distribution = trainset_distribution or DEFAULT_TRAIN_DISTRIBUTION
        npt.assert_almost_equal(
            1,
            sum(trainset_distribution.values()),
            err_msg="Sum of distribution should be 1",
        )

        # Cumulative probabilities let _get_evolve_type pick a type with a
        # single uniform draw.
        probs = np.cumsum(list(trainset_distribution.values()))
        types = trainset_distribution.keys()
        self.trainset_distribution = dict(zip(types, probs))
        self.chunk_size = chunk_size
        # Minimum critic score a context needs to be used (see _filter_context).
        self.threshold = 5.0
        self.rng = default_rng(seed)
        self.prompt = prompt

    @classmethod
    def from_default(
        cls,
        llm: BaseLanguageModel,
        chunk_size: int = 512,
        trainset_distribution: dict = DEFAULT_TRAIN_DISTRIBUTION,
        prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
    ):
        """Build a generator that uses ``llm`` as both generator and critic."""
        generator_llm = llm
        critic_llm = llm
        return cls(
            generator_llm=generator_llm,
            critic_llm=critic_llm,
            chunk_size=chunk_size,
            trainset_distribution=trainset_distribution,
            prompt=prompt,
        )

    def _get_evolve_type(self) -> str:
        """
        Decides question evolution type based on probability
        """
        prob = self.rng.uniform(0, 1)
        return next(
            (
                key
                for key in self.trainset_distribution.keys()
                if prob <= self.trainset_distribution[key]
            ),
            "simple",
        )

    def _filter_context(self, context: str) -> bool:
        """
        Ask the critic LLM to score ``context``.

        context: str
            The input context

        Returns True when the score is at least ``self.threshold``. An
        unparsable score counts as 0.0.
        """
        prompt = SCORE_CONTEXT_CHAT_PROMPT.format_prompt(context=context)
        results = self.critic_llm(prompt.to_messages())
        score = load_as_score(results.content)
        logger.debug('context score: %s', score)
        return score >= self.threshold

    def _seed_question(self, context: str) -> str:
        """Ask the generator LLM for seed question(s) about ``context``."""
        if self.prompt is None:
            prompt = SEED_QUESTION_CHAT_PROMPT.format_prompt(context=context)
        else:
            prompt = self.prompt.format_prompt(context=context)
        results = self.generator_llm(prompt.to_messages())
        return results.content

    def _filter_question(self, question: str) -> bool:
        """Return False only when the critic LLM's JSON verdict is "No"."""
        prompt = FILTER_QUESTION_CHAT_PROMPT.format_prompt(question=question)
        results = self.critic_llm(prompt.to_messages())
        json_results = load_as_json(results.content)
        logger.debug('filter question: %s %s', question, json_results)
        return json_results.get("verdict") != "No"

    def _qc_template(self, prompt, question, context) -> str:
        """Format ``prompt`` with question/context and return the generator's reply."""
        human_prompt = prompt.format(question=question, context=context)
        prompt = ChatPromptTemplate.from_messages([human_prompt])
        results = self.generator_llm(prompt.messages)
        return results.content

    def _generate_answer(self, question: str, context: t.List[str]) -> t.List[str]:
        """Answer each newline-separated line of ``question`` using ``context[i]``."""
        return [
            self._qc_template(ANSWER_FORMULATE, qstn, context[i])
            for i, qstn in enumerate(question.split("\n"))
        ]

    def _remove_nodes(
        self, available_indices: t.List[BaseNode], node_idx: t.List
    ) -> t.List[BaseNode]:
        """Remove ``node_idx`` entries from ``available_indices`` in place and return it."""
        for idx in node_idx:
            available_indices.remove(idx)
        return available_indices

    def _generate_doc_nodes_map(
        self, document_nodes: t.List[BaseNode]
    ) -> t.Dict[str, t.List[BaseNode]]:
        """Group nodes by ``ref_doc_id``; nodes without one are left out."""
        doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list)
        for node in document_nodes:
            if node.ref_doc_id:
                doc_nodes_map[node.ref_doc_id].append(node)

        return doc_nodes_map  # type: ignore

    def _get_neighbour_node(
        self, node: BaseNode, related_nodes: t.List[BaseNode]
    ) -> t.List[BaseNode]:
        """Return ``node`` plus its next sibling (its previous one if it is last).

        Falls back to ``[node]`` with a warning when there are fewer than two
        related nodes.
        """
        if len(related_nodes) < 2:
            warnings.warn("No neighbors exists")
            return [node]
        idx = related_nodes.index(node)
        ids = [idx - 1, idx] if idx == (len(related_nodes) - 1) else [idx, idx + 1]
        return [related_nodes[i] for i in ids]

    def generate(
        self,
        documents: t.List[LlamaindexDocument] | t.List[Document],
        train_size: int,
    ) -> TrainDataset:
        """
        Generate up to ``train_size`` seed questions, with answers, from ``documents``.

        Documents are split into nodes of ``chunk_size``. Nodes are drawn at
        random without replacement; each drawn node, optionally joined with a
        neighbouring node of the same document, is scored by the critic LLM.
        Contexts that pass get a seed question and answers. Generation stops
        after ``train_size`` seed questions or when no nodes are left.

        Raises:
            ValueError: if ``documents`` is empty or does not hold LlamaIndex
                or LangChain documents.
        """
        if not documents:
            raise ValueError("Trainset generation requires at least one document")
        if not isinstance(documents[0], (LlamaindexDocument, Document)):
            raise ValueError(
                "Trainset Generatation only supports LlamaindexDocuments or Documents"  # noqa
            )

        if isinstance(documents[0], Document):
            # cast to LangchainDocument since its the only case here
            documents = t.cast(t.List[Document], documents)
            documents = [
                LlamaindexDocument.from_langchain_format(doc) for doc in documents
            ]
        # Convert documents into nodes
        node_parser = SimpleNodeParser.from_defaults(
            chunk_size=self.chunk_size, chunk_overlap=0, include_metadata=True
        )
        documents = t.cast(t.List[LlamaindexDocument], documents)
        document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
            documents=documents
        )

        available_nodes = document_nodes
        doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
        count_neighbours = sum(len(val) > 1 for val in doc_nodes_map.values())
        if count_neighbours < len(documents) // 2:
            warnings.warn("Most documents are too short")

        count = 0
        samples = []
        # The context manager closes the progress bar even if an LLM call raises.
        with tqdm(total=train_size) as pbar:
            while count < train_size and available_nodes:
                logger.debug(
                    "generated %s/%s, %s nodes left", count, train_size, len(available_nodes)
                )
                evolve_type = self._get_evolve_type()
                curr_node = self.rng.choice(np.array(available_nodes), size=1)[0]
                available_nodes = self._remove_nodes(available_nodes, [curr_node])

                # Look siblings up by ref_doc_id, the key doc_nodes_map was built
                # with. A node without a source document then gets no neighbours
                # instead of raising AttributeError on ``source_node.node_id``.
                neighbor_nodes = doc_nodes_map[curr_node.ref_doc_id]

                # Append multiple nodes randomly to remove chunking bias
                size = self.rng.integers(1, 3)
                nodes = (
                    self._get_neighbour_node(curr_node, neighbor_nodes)
                    if size > 1 and evolve_type != "multi_context"
                    else [curr_node]
                )

                text_chunk = " ".join([node.get_content() for node in nodes])
                if not self._filter_context(text_chunk):
                    continue
                question = self._seed_question(text_chunk)

                # Question filtering (self._filter_question) is currently disabled,
                # so every seed question is accepted. A single seed reply may hold
                # several newline-separated questions; each gets the same context
                # and its own answer.
                questions = question.split("\n")
                context = [text_chunk] * len(questions)
                is_conv = len(context) > 1
                answer = self._generate_answer(question, context)
                for i, (qstn, ctx, ans) in enumerate(zip(questions, context, answer)):
                    # Only the first row of a multi-question reply has episode_done=False.
                    episode_done = not (is_conv and i == 0)
                    samples.append(
                        DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
                    )
                count += 1
                pbar.update(1)

        return TrainDataset(train_data=samples)
|
335
|
+
|
336
|
+
|
337
|
+
class QAGenerationChainV2(Chain):
    """Question/answer generation chain backed by a ``TrainsetGenerator``.

    All ``documents`` are handed to the generator. Every generated sample is
    serialized as a ``{"question": ..., "answer": ...}`` JSON object, and the
    objects are concatenated without a separator and returned under
    ``output_key``.
    """

    documents: List[Document]
    generator: TrainsetGenerator
    """Generator that produces the question/answer samples."""
    k: Optional[int] = None
    """Number of questions to generate (1000 when None)."""
    input_key: str = "begin"
    """Key of the input to the chain."""
    output_key: str = "questions"
    """Key of the output of the chain."""

    @classmethod
    def from_llm(
        cls,
        documents: List[Document],
        llm: BaseLanguageModel,
        k: Optional[int] = None,
        chunk_size: int = 512,
        prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
        **kwargs: Any,
    ) -> QAGenerationChainV2:
        """
        Create a QAGenerationChainV2 from a language model.

        Args:
            documents: documents to generate question/answer pairs from
            llm: a language model, used both as generator and critic
            k: number of seed questions to generate (1000 when None)
            chunk_size: chunk size used when splitting documents into nodes
            prompt: prompt template used to write seed questions
            **kwargs: additional arguments passed to the chain constructor

        Returns:
            a QAGenerationChainV2 instance
        """
        generator = TrainsetGenerator.from_default(llm, chunk_size=chunk_size, prompt=prompt)
        return cls(documents=documents, generator=generator, k=k, **kwargs)

    @property
    def _chain_type(self) -> str:
        raise NotImplementedError

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate samples and return them as concatenated JSON objects.

        ``inputs`` is not read; the chain works only on ``self.documents``.
        """
        # Generate from metadata-free copies. Clearing ``doc.metadata`` in place
        # silently altered the caller's documents. (Presumably metadata is
        # dropped so it does not end up in the node text / chunk budget; TODO
        # confirm.)
        documents = [Document(page_content=doc.page_content) for doc in self.documents]
        # Apply the default for this call only instead of persisting it on self.k.
        train_size = 1000 if self.k is None else self.k
        dataset = self.generator.generate(documents=documents, train_size=train_size)
        qa_pairs = dataset.to_pandas().to_dict("records")
        qa = ''.join(
            json.dumps(
                {
                    "question": pair["question"],
                    "answer": pair["ground_truth"][0],
                },
                ensure_ascii=False,
            )
            for pair in qa_pairs
        )
        return {self.output_key: qa}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        # Delegates to the synchronous implementation.
        output = self._call(inputs, run_manager)
        return output
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# flake8: noqa
|
2
|
+
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
|
3
|
+
from langchain_core.prompts.chat import (
|
4
|
+
ChatPromptTemplate,
|
5
|
+
HumanMessagePromptTemplate,
|
6
|
+
SystemMessagePromptTemplate,
|
7
|
+
)
|
8
|
+
from langchain_core.prompts.prompt import PromptTemplate
|
9
|
+
|
10
|
+
# System message for chat models. It asks for one question/answer pair as
# JSON, written in the same language as the source text. Doubled braces
# render as literal braces when the template is formatted.
templ1 = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
When coming up with this question/answer pair, you must respond in the following format and in same language as the text:
```
{{
    "question": "$YOUR_QUESTION_HERE",
    "answer": "$THE_ANSWER_HERE"
}}
```

Everything between the ``` must be valid json.
"""
# Human message for chat models; ``{text}`` is the chunk to build a pair from.
templ2 = """Please come up with a question/answer pair, in the specified JSON format, for the following text:
----------------
{text}"""
# Chat prompt: system instructions (templ1) followed by the text (templ2).
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(templ1),
        HumanMessagePromptTemplate.from_template(templ2),
    ]
)


# Single-string variant of the same instructions plus ``{text}``, for
# non-chat (completion) models.
templ = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
When coming up with this question/answer pair, you must respond in the following format and in same language as the text:
```
{{
    "question": "$YOUR_QUESTION_HERE",
    "answer": "$THE_ANSWER_HERE"
}}
```

Everything between the ``` must be valid json.

Please come up with a question/answer pair, in the specified JSON format, for the following text:
----------------
{text}"""
PROMPT = PromptTemplate.from_template(templ)


# Picks CHAT_PROMPT when ``is_chat_model`` holds for the LLM, PROMPT otherwise.
PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
)
|