bisheng-langchain 0.3.1.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. bisheng_langchain/chains/__init__.py +4 -1
  2. bisheng_langchain/chains/qa_generation/__init__.py +0 -0
  3. bisheng_langchain/chains/qa_generation/base.py +128 -0
  4. bisheng_langchain/chains/qa_generation/base_v2.py +413 -0
  5. bisheng_langchain/chains/qa_generation/prompt.py +53 -0
  6. bisheng_langchain/chains/qa_generation/prompt_v2.py +155 -0
  7. bisheng_langchain/document_loaders/elem_unstrcutured_loader.py +36 -9
  8. bisheng_langchain/document_loaders/parsers/ellm_client.py +7 -9
  9. bisheng_langchain/document_loaders/universal_kv.py +4 -3
  10. bisheng_langchain/gpts/tools/api_tools/openapi.py +7 -7
  11. bisheng_langchain/rag/__init__.py +2 -0
  12. bisheng_langchain/rag/bisheng_rag_chain.py +164 -0
  13. bisheng_langchain/rag/bisheng_rag_pipeline_v2.py +8 -2
  14. bisheng_langchain/rag/bisheng_rag_tool.py +47 -24
  15. bisheng_langchain/rag/config/baseline_caibao_v2.yaml +1 -1
  16. bisheng_langchain/rag/config/baseline_v2.yaml +3 -2
  17. bisheng_langchain/rag/prompts/prompt.py +1 -1
  18. bisheng_langchain/rag/qa_corpus/qa_generator.py +1 -1
  19. bisheng_langchain/rag/scoring/ragas_score.py +2 -2
  20. bisheng_langchain/rag/utils.py +27 -4
  21. bisheng_langchain/sql/__init__.py +3 -0
  22. bisheng_langchain/sql/base.py +120 -0
  23. bisheng_langchain/text_splitter.py +1 -1
  24. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/METADATA +3 -1
  25. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/RECORD +27 -20
  26. bisheng_langchain/rag/bisheng_rag_pipeline_v2_cohere_raw_prompting.py +0 -376
  27. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/WHEEL +0 -0
  28. {bisheng_langchain-0.3.1.1.dist-info → bisheng_langchain-0.3.2.dist-info}/top_level.txt +0 -0
bisheng_langchain/chains/__init__.py
@@ -5,10 +5,13 @@ from bisheng_langchain.chains.retrieval.retrieval_chain import RetrievalChain
  from bisheng_langchain.chains.router.multi_rule import MultiRuleChain
  from bisheng_langchain.chains.router.rule_router import RuleBasedRouter
  from bisheng_langchain.chains.transform import TransformChain
+ from bisheng_langchain.chains.qa_generation.base import QAGenerationChain
+ from bisheng_langchain.chains.qa_generation.base_v2 import QAGenerationChainV2

  from .loader_output import LoaderOutputChain

  __all__ = [
      'StuffDocumentsChain', 'LoaderOutputChain', 'AutoGenChain', 'RuleBasedRouter',
-     'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain', 'TransformChain'
+     'MultiRuleChain', 'RetrievalChain', 'ConversationalRetrievalChain', 'TransformChain',
+     'QAGenerationChain', 'QAGenerationChainV2'
  ]
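The net effect of this hunk: QAGenerationChain and QAGenerationChainV2 join the package's public chain API. A minimal sketch of the new import surface (names taken directly from the updated __all__; the implementations follow in the diffs below):

    from bisheng_langchain.chains import QAGenerationChain, QAGenerationChainV2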
bisheng_langchain/chains/qa_generation/__init__.py
File without changes
bisheng_langchain/chains/qa_generation/base.py
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+
+ import json
+ import re
+ import logging
+ from langchain.docstore.document import Document
+ from typing import Any, Dict, List, Optional
+
+ from langchain_core.callbacks import CallbackManagerForChainRun
+ from langchain_core.language_models import BaseLanguageModel
+ from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
+ from langchain_core.pydantic_v1 import Field
+ from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
+
+ from langchain.chains.base import Chain
+ from langchain.chains.llm import LLMChain
+ from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR, CHAT_PROMPT, PROMPT
+
+ logger = logging.getLogger(__name__)
+
+
+ def parse_json(input_str: str) -> str:
+     match = re.search(r'```(json)?(.*)```', input_str, re.DOTALL)
+     if match is None:
+         out_str = input_str
+     else:
+         out_str = match.group(2)
+
+     out_str = out_str.strip()
+     out_str = out_str.replace('```', '')
+     return out_str
+
+
+ class QAGenerationChain(Chain):
+     """Base class for question-answer generation chains."""
+
+     documents: List[Document]
+     llm_chain: LLMChain
+     """LLM Chain that generates responses from user input and context."""
+     k: Optional[int] = None
+     """Number of questions to generate."""
+     text_splitter: TextSplitter = Field(
+         default=RecursiveCharacterTextSplitter(
+             separators=["\n\n", "\n", " ", ""],
+             chunk_size=1000,
+             chunk_overlap=100,
+         )
+     )
+     """Text splitter that splits the input into chunks."""
+     input_key: str = "begin"
+     """Key of the input to the chain."""
+     output_key: str = "questions"
+     """Key of the output of the chain."""
+
+     @classmethod
+     def from_llm(
+         cls,
+         documents: List[Document],
+         llm: BaseLanguageModel,
+         k: Optional[int] = None,
+         chunk_size: int = 512,
+         prompt: Optional[ChatPromptTemplate] = CHAT_PROMPT,
+         **kwargs: Any,
+     ) -> QAGenerationChain:
+         """
+         Create a QAGenerationChain from a language model.
+
+         Args:
+             llm: a language model
+             prompt: a prompt template
+             **kwargs: additional arguments
+
+         Returns:
+             a QAGenerationChain instance
+         """
+         _prompt = PROMPT_SELECTOR.get_prompt(llm) if prompt is None else prompt
+         chain = LLMChain(llm=llm, prompt=_prompt)
+         text_splitter = RecursiveCharacterTextSplitter(
+             separators=["\n\n", "\n", " ", ""],
+             chunk_size=chunk_size,
+             chunk_overlap=50,
+         )
+         return cls(documents=documents, llm_chain=chain, k=k, text_splitter=text_splitter, **kwargs)
+
+     @property
+     def _chain_type(self) -> str:
+         raise NotImplementedError
+
+     @property
+     def input_keys(self) -> List[str]:
+         return [self.input_key]
+
+     @property
+     def output_keys(self) -> List[str]:
+         return [self.output_key]
+
+     def _call(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, List]:
+         contents = [doc.page_content for doc in self.documents]
+         contents = '\n\n'.join(contents)
+         docs = self.text_splitter.create_documents([contents])
+         # len(qa) = min(len(docs), self.k)
+         logger.info(f"Split {len(docs)} documents. Gen qa num: min({len(docs)}, {self.k}).")
+         qa = ''
+         qa_i = 0
+         for doc in docs:
+             try:
+                 results = self.llm_chain.generate([{"text": doc.page_content}], run_manager=run_manager)
+                 res = results.generations[0]
+                 qa += res[0].text
+                 qa_i += 1
+             except Exception as e:
+                 logger.error(f"Failed to parse response. Error: {e}")
+                 continue
+             if self.k is not None and qa_i >= self.k:
+                 break
+         return {self.output_key: qa}
+
+     async def _acall(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, List]:
+         output = self._call(inputs, run_manager)
+         return output
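For orientation, a usage sketch of the chain added above: it joins every input document, re-splits the text with the configured splitter, and asks the LLM for one JSON question/answer pair per chunk until k pairs are collected. The chat model and sample document below are hypothetical placeholders, not part of this diff:

    # Sketch only: ChatOpenAI and the sample text are illustrative stand-ins.
    from langchain_openai import ChatOpenAI
    from langchain.docstore.document import Document
    from bisheng_langchain.chains import QAGenerationChain

    docs = [Document(page_content='Bisheng is an open-source platform for LLM application development...')]
    chain = QAGenerationChain.from_llm(documents=docs, llm=ChatOpenAI(), k=3, chunk_size=512)
    result = chain({'begin': ''})   # input key "begin" is required but its value is unused
    print(result['questions'])      # one concatenated string of JSON question/answer pairs

Note that _acall simply delegates to the synchronous _call, so the async path adds no concurrency.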
bisheng_langchain/chains/qa_generation/base_v2.py
@@ -0,0 +1,413 @@
+ from __future__ import annotations
+
+ import re
+ import json
+ import logging
+ import typing as t
+ import warnings
+ from typing import Any, Dict, List, Optional
+ from collections import defaultdict, namedtuple
+ from dataclasses import dataclass
+ from langchain_core.callbacks import CallbackManagerForChainRun
+ from langchain_core.language_models import BaseLanguageModel
+
+ try:
+     from llama_index.node_parser import SimpleNodeParser
+     from llama_index.readers.schema import Document as LlamaindexDocument
+     from llama_index.schema import BaseNode
+ except ImportError:
+     raise ImportError(
+         "llama_index must be installed to use this function. "
+         "Please install it with `pip install llama_index`."
+     )
+ import numpy as np
+ import numpy.testing as npt
+ import pandas as pd
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.docstore.document import Document
+ # from langchain.schema.document import Document as LangchainDocument
+ from langchain.chains.base import Chain
+ from numpy.random import default_rng
+ from tqdm import tqdm
+ from .prompt_v2 import (
+     SEED_QUESTION_CHAT_PROMPT,
+     SCORE_CONTEXT_CHAT_PROMPT,
+     FILTER_QUESTION_CHAT_PROMPT,
+     ANSWER_FORMULATE,
+ )
+ from .base import parse_json
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_as_score(text):
+     """
+     Validate the given text and return it as a score.
+     """
+     pattern = r"^[\d.]+$"
+     if not re.match(pattern, text):
+         warnings.warn("Invalid score")
+         score = 0.0
+     else:
+         score = eval(text)
+     return score
+
+
+ def load_as_json(text):
+     """
+     Validate the given text and return it as JSON.
+     """
+
+     try:
+         return json.loads(parse_json(text))
+     except ValueError as e:
+         warnings.warn(f"Invalid json: {e}")
+
+     return {}
+
+
+ DEFAULT_TRAIN_DISTRIBUTION = {
+     "simple": 1.0,
+     "reasoning": 0.0,
+     "multi_context": 0.0,
+     "conditional": 0.0,
+ }
+
+ DataRow = namedtuple(
+     "DataRow",
+     [
+         "question",
+         "ground_truth_context",
+         "ground_truth",
+         "question_type",
+         "episode_done",
+     ],
+ )
+
+
+ @dataclass
+ class TrainDataset:
+     """
+     TrainDataset class
+     """
+
+     train_data: t.List[DataRow]
+
+     def to_pandas(self) -> pd.DataFrame:
+         data_samples = []
+         for data in self.train_data:
+             data = {
+                 "question": data.question,
+                 "ground_truth_context": data.ground_truth_context,
+                 "ground_truth": data.ground_truth,
+                 "question_type": data.question_type,
+                 "episode_done": data.episode_done,
+             }
+             data_samples.append(data)
+         return pd.DataFrame.from_records(data_samples)
+
+
+ class TrainsetGenerator:
+     """
+     Ragas Train Set Generator
+
+     Attributes
+     ----------
+     generator_llm: BaseLanguageModel
+         LLM used for all the generator operations in the TrainGeneration paradigm.
+     critic_llm: BaseLanguageModel
+         LLM used for all the filtering and scoring operations in the TrainGeneration
+         paradigm.
+     chunk_size: int
+         The chunk size of nodes created from data.
+     trainset_distribution: dict
+         Distribution of different types of questions to be generated from the given
+         set of documents. Defaults to {"simple": 1.0, "reasoning": 0.0,
+         "multi_context": 0.0, "conditional": 0.0}.
+     """
+
+     def __init__(
+         self,
+         generator_llm: BaseLanguageModel,
+         critic_llm: BaseLanguageModel,
+         trainset_distribution: t.Optional[t.Dict[str, float]] = None,
+         chunk_size: int = 1024,
+         seed: int = 42,
+         prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
+     ) -> None:
+         self.generator_llm = generator_llm
+         self.critic_llm = critic_llm
+         trainset_distribution = trainset_distribution or DEFAULT_TRAIN_DISTRIBUTION
+         npt.assert_almost_equal(
+             1,
+             sum(trainset_distribution.values()),
+             err_msg="Sum of distribution should be 1",
+         )
+
+         probs = np.cumsum(list(trainset_distribution.values()))
+         types = trainset_distribution.keys()
+         self.trainset_distribution = dict(zip(types, probs))
+         self.chunk_size = chunk_size
+         self.threshold = 5.0
+         self.rng = default_rng(seed)
+         self.prompt = prompt
+
+     @classmethod
+     def from_default(
+         cls,
+         llm: BaseLanguageModel,
+         chunk_size: int = 512,
+         trainset_distribution: dict = DEFAULT_TRAIN_DISTRIBUTION,
+         prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
+     ):
+         generator_llm = llm
+         critic_llm = llm
+         return cls(
+             generator_llm=generator_llm,
+             critic_llm=critic_llm,
+             chunk_size=chunk_size,
+             trainset_distribution=trainset_distribution,
+             prompt=prompt,
+         )
+
+     def _get_evolve_type(self) -> str:
+         """
+         Decides question evolution type based on probability
+         """
+         prob = self.rng.uniform(0, 1)
+         return next(
+             (
+                 key
+                 for key in self.trainset_distribution.keys()
+                 if prob <= self.trainset_distribution[key]
+             ),
+             "simple",
+         )
+
+     def _filter_context(self, context: str) -> bool:
+         """
+         context: str
+             The input context
+
+         Checks whether the context has information worthy of framing a question
+         """
+         prompt = SCORE_CONTEXT_CHAT_PROMPT.format_prompt(context=context)
+         results = self.critic_llm(prompt.to_messages())
+         output = results.content
+         score = load_as_score(output)
+         print('context score:', score)
+         return score >= self.threshold
+
+     def _seed_question(self, context: str) -> str:
+         if self.prompt is None:
+             prompt = SEED_QUESTION_CHAT_PROMPT.format_prompt(context=context)
+         else:
+             prompt = self.prompt.format_prompt(context=context)
+         results = self.generator_llm(prompt.to_messages())
+         return results.content
+
+     def _filter_question(self, question: str) -> bool:
+         prompt = FILTER_QUESTION_CHAT_PROMPT.format_prompt(question=question)
+         results = self.critic_llm(prompt.to_messages())
+         results = results.content
+         json_results = load_as_json(results)
+         print('filter question:', question, json_results)
+         return json_results.get("verdict") != "No"
+
+     def _qc_template(self, prompt, question, context) -> str:
+         human_prompt = prompt.format(question=question, context=context)
+         prompt = ChatPromptTemplate.from_messages([human_prompt])
+         results = self.generator_llm(prompt.messages)
+         return results.content
+
+     def _generate_answer(self, question: str, context: t.List[str]) -> t.List[str]:
+         return [
+             self._qc_template(ANSWER_FORMULATE, qstn, context[i])
+             for i, qstn in enumerate(question.split("\n"))
+         ]
+
+     def _remove_nodes(
+         self, available_indices: t.List[BaseNode], node_idx: t.List
+     ) -> t.List[BaseNode]:
+         for idx in node_idx:
+             available_indices.remove(idx)
+         return available_indices
+
+     def _generate_doc_nodes_map(
+         self, document_nodes: t.List[BaseNode]
+     ) -> t.Dict[str, t.List[BaseNode]]:
+         doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list)
+         for node in document_nodes:
+             if node.ref_doc_id:
+                 doc_nodes_map[node.ref_doc_id].append(node)
+
+         return doc_nodes_map  # type: ignore
+
+     def _get_neighbour_node(
+         self, node: BaseNode, related_nodes: t.List[BaseNode]
+     ) -> t.List[BaseNode]:
+         if len(related_nodes) < 2:
+             warnings.warn("No neighbour nodes exist")
+             return [node]
+         idx = related_nodes.index(node)
+         ids = [idx - 1, idx] if idx == (len(related_nodes) - 1) else [idx, idx + 1]
+         return [related_nodes[idx] for idx in ids]
+
+     def generate(
+         self,
+         documents: t.List[LlamaindexDocument] | t.List[Document],
+         train_size: int,
+     ) -> TrainDataset:
+         if not isinstance(documents[0], (LlamaindexDocument, Document)):
+             raise ValueError(
+                 "Trainset generation only supports LlamaindexDocuments or Documents"  # noqa
+             )
+
+         if isinstance(documents[0], Document):
+             # cast to LangchainDocument since it's the only case here
+             documents = t.cast(t.List[Document], documents)
+             documents = [
+                 LlamaindexDocument.from_langchain_format(doc) for doc in documents
+             ]
+         # Convert documents into nodes
+         node_parser = SimpleNodeParser.from_defaults(
+             chunk_size=self.chunk_size, chunk_overlap=0, include_metadata=True
+         )
+         documents = t.cast(t.List[LlamaindexDocument], documents)
+         document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
+             documents=documents
+         )
+         # # maximum 1 seed question per node
+         # if train_size > len(document_nodes):
+         #     raise ValueError(
+         #         """Maximum possible number of samples exceeded,
+         #         reduce train_size or add more documents"""
+         #     )
+
+         available_nodes = document_nodes
+         doc_nodes_map = self._generate_doc_nodes_map(document_nodes)
+         count_neighbours = sum(len(val) > 1 for _, val in doc_nodes_map.items())
+         if count_neighbours < len(documents) // 2:
+             warnings.warn("Most documents are too short")
+
+         count = 0
+         samples = []
+         pbar = tqdm(total=train_size)
+         while count < train_size and available_nodes != []:
+             print(count, train_size, len(available_nodes))
+             evolve_type = self._get_evolve_type()
+             curr_node = self.rng.choice(np.array(available_nodes), size=1)[0]
+             available_nodes = self._remove_nodes(available_nodes, [curr_node])
+
+             neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]
+
+             # Append multiple nodes randomly to remove chunking bias
+             size = self.rng.integers(1, 3)
+             nodes = (
+                 self._get_neighbour_node(curr_node, neighbor_nodes)
+                 if size > 1 and evolve_type != "multi_context"
+                 else [curr_node]
+             )
+
+             text_chunk = " ".join([node.get_content() for node in nodes])
+             score = self._filter_context(text_chunk)
+             if not score:
+                 continue
+             seed_question = self._seed_question(text_chunk)
+
+             question = seed_question
+             # is_valid_question = self._filter_question(question)
+             is_valid_question = True
+             if is_valid_question:
+                 context = [text_chunk] * len(question.split("\n"))
+                 is_conv = len(context) > 1
+                 answer = self._generate_answer(question, context)
+                 for i, (qstn, ctx, ans) in enumerate(
+                     zip(question.split("\n"), context, answer)
+                 ):
+                     episode_done = False if is_conv and i == 0 else True
+                     samples.append(
+                         DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
+                     )
+                     count += 1
+                     pbar.update(1)
+
+         return TrainDataset(train_data=samples)
+
+
+ class QAGenerationChainV2(Chain):
+     """Base class for question-answer generation chains."""
+
+     documents: List[Document]
+     generator: TrainsetGenerator
+     """Trainset generator that produces question/answer pairs from the documents."""
+     k: Optional[int] = None
+     """Number of questions to generate."""
+     input_key: str = "begin"
+     """Key of the input to the chain."""
+     output_key: str = "questions"
+     """Key of the output of the chain."""
+
+     @classmethod
+     def from_llm(
+         cls,
+         documents: List[Document],
+         llm: BaseLanguageModel,
+         k: Optional[int] = None,
+         chunk_size: int = 512,
+         prompt: Optional[ChatPromptTemplate] = SEED_QUESTION_CHAT_PROMPT,
+         **kwargs: Any,
+     ) -> QAGenerationChainV2:
+         """
+         Create a QAGenerationChainV2 from a language model.
+
+         Args:
+             llm: a language model
+             prompt: a prompt template
+             **kwargs: additional arguments
+
+         Returns:
+             a QAGenerationChainV2 instance
+         """
+         generator = TrainsetGenerator.from_default(llm, chunk_size=chunk_size, prompt=prompt)
+         return cls(documents=documents, generator=generator, k=k, **kwargs)
+
+     @property
+     def _chain_type(self) -> str:
+         raise NotImplementedError
+
+     @property
+     def input_keys(self) -> List[str]:
+         return [self.input_key]
+
+     @property
+     def output_keys(self) -> List[str]:
+         return [self.output_key]
+
+     def _call(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, List]:
+         for doc in self.documents:
+             doc.metadata = {}
+         if self.k is None:
+             self.k = 1000
+         dataset = self.generator.generate(documents=self.documents, train_size=self.k)
+         df = dataset.to_pandas()
+         qa_pairs = df.to_dict("records")
+         qa = ''
+         for pair in qa_pairs:
+             qa += json.dumps(
+                 {
+                     "question": pair["question"],
+                     "answer": pair["ground_truth"][0],
+                 }, ensure_ascii=False)
+         return {self.output_key: qa}
+
+     async def _acall(
+         self,
+         inputs: Dict[str, Any],
+         run_manager: Optional[CallbackManagerForChainRun] = None,
+     ) -> Dict[str, List]:
+         output = self._call(inputs, run_manager)
+         return output
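A corresponding sketch for the V2 chain, assuming llama_index is installed and again using a placeholder chat model. Under the hood it drives the ragas-style TrainsetGenerator: document nodes are sampled, their context is scored by the critic LLM against the fixed threshold of 5.0, surviving chunks are seeded into questions and answered, and document metadata is cleared up front (k falls back to 1000 when unset):

    # Sketch only: ChatOpenAI is an illustrative stand-in for any chat model.
    from langchain_openai import ChatOpenAI
    from langchain.docstore.document import Document
    from bisheng_langchain.chains import QAGenerationChainV2

    docs = [Document(page_content='...a long source text to mine for QA pairs...')]
    chain = QAGenerationChainV2.from_llm(documents=docs, llm=ChatOpenAI(), k=10, chunk_size=512)
    result = chain({'begin': ''})
    print(result['questions'])  # concatenated {"question": ..., "answer": ...} objects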
bisheng_langchain/chains/qa_generation/prompt.py
@@ -0,0 +1,53 @@
+ # flake8: noqa
+ from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
+ from langchain_core.prompts.chat import (
+     ChatPromptTemplate,
+     HumanMessagePromptTemplate,
+     SystemMessagePromptTemplate,
+ )
+ from langchain_core.prompts.prompt import PromptTemplate
+
+ templ1 = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
+ Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
+ When coming up with this question/answer pair, you must respond in the following format and in the same language as the text:
+ ```
+ {{
+     "question": "$YOUR_QUESTION_HERE",
+     "answer": "$THE_ANSWER_HERE"
+ }}
+ ```
+
+ Everything between the ``` must be valid json.
+ """
+ templ2 = """Please come up with a question/answer pair, in the specified JSON format, for the following text:
+ ----------------
+ {text}"""
+ CHAT_PROMPT = ChatPromptTemplate.from_messages(
+     [
+         SystemMessagePromptTemplate.from_template(templ1),
+         HumanMessagePromptTemplate.from_template(templ2),
+     ]
+ )
+
+
+ templ = """You are a smart assistant designed to help high school teachers come up with reading comprehension questions.
+ Given a piece of text, you must come up with a question and answer pair that can be used to test a student's reading comprehension abilities.
+ When coming up with this question/answer pair, you must respond in the following format and in the same language as the text:
+ ```
+ {{
+     "question": "$YOUR_QUESTION_HERE",
+     "answer": "$THE_ANSWER_HERE"
+ }}
+ ```
+
+ Everything between the ``` must be valid json.
+
+ Please come up with a question/answer pair, in the specified JSON format, for the following text:
+ ----------------
+ {text}"""
+ PROMPT = PromptTemplate.from_template(templ)
+
+
+ PROMPT_SELECTOR = ConditionalPromptSelector(
+     default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
+ )
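PROMPT_SELECTOR follows LangChain's ConditionalPromptSelector convention: chat models receive CHAT_PROMPT, anything else the single-string PROMPT. A small sketch of the dispatch, with `llm` standing in for any BaseLanguageModel:

    from bisheng_langchain.chains.qa_generation.prompt import PROMPT_SELECTOR

    def select_prompt(llm):
        # Mirrors QAGenerationChain.from_llm when prompt is None:
        # is_chat_model(llm) -> CHAT_PROMPT, otherwise PROMPT.
        return PROMPT_SELECTOR.get_prompt(llm)

In this release the selector is effectively dormant: QAGenerationChain.from_llm defaults prompt to CHAT_PROMPT, so the selector only runs when the caller passes prompt=None explicitly.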