user_simulator-0.1.0-py3-none-any.whl
- user_sim/__init__.py +0 -0
- user_sim/cli/__init__.py +0 -0
- user_sim/cli/gen_user_profile.py +34 -0
- user_sim/cli/init_project.py +65 -0
- user_sim/cli/sensei_chat.py +481 -0
- user_sim/cli/sensei_check.py +103 -0
- user_sim/cli/validation_check.py +143 -0
- user_sim/core/__init__.py +0 -0
- user_sim/core/ask_about.py +665 -0
- user_sim/core/data_extraction.py +260 -0
- user_sim/core/data_gathering.py +134 -0
- user_sim/core/interaction_styles.py +147 -0
- user_sim/core/role_structure.py +608 -0
- user_sim/core/user_simulator.py +302 -0
- user_sim/handlers/__init__.py +0 -0
- user_sim/handlers/asr_module.py +128 -0
- user_sim/handlers/html_parser_module.py +202 -0
- user_sim/handlers/image_recognition_module.py +139 -0
- user_sim/handlers/pdf_parser_module.py +123 -0
- user_sim/utils/__init__.py +0 -0
- user_sim/utils/config.py +47 -0
- user_sim/utils/cost_tracker.py +153 -0
- user_sim/utils/cost_tracker_v2.py +193 -0
- user_sim/utils/errors.py +15 -0
- user_sim/utils/exceptions.py +47 -0
- user_sim/utils/languages.py +78 -0
- user_sim/utils/register_management.py +62 -0
- user_sim/utils/show_logs.py +63 -0
- user_sim/utils/token_cost_calculator.py +338 -0
- user_sim/utils/url_management.py +60 -0
- user_sim/utils/utilities.py +568 -0
- user_simulator-0.1.0.dist-info/METADATA +733 -0
- user_simulator-0.1.0.dist-info/RECORD +37 -0
- user_simulator-0.1.0.dist-info/WHEEL +5 -0
- user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
- user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
- user_simulator-0.1.0.dist-info/top_level.txt +1 -0
user_sim/core/data_extraction.py
@@ -0,0 +1,260 @@
import re
import logging
from dateutil import parser  # kept for the commented-out strict time/date parsing below
from langchain_core.prompts import ChatPromptTemplate
from user_sim.utils.token_cost_calculator import calculate_cost
from user_sim.utils import config
from user_sim.utils.utilities import init_model
from datetime import date


model = ""
llm = None
logger = logging.getLogger('Info Logger')


def init_data_extraction_module():
    global model
    global llm
    model, llm = init_model()


class DataExtraction:

    def __init__(self, conversation, variable_name, dtype, description):
        self.model = "gpt-4o-mini"
        self.message = f"{conversation['interaction']}"
        self.dtype = dtype
        self.variable = variable_name
        self.description = description
        self.system = """
        You're an assistant that analyzes a conversation between a user and a chatbot.
        Your objective is to test the chatbot's capabilities by extracting the information
        only if the chatbot provides or verifies it. Output only the requested data; if you
        can't find it, output None.
        """

    @staticmethod
    def data_process(text, dtype):
        """Cast the extracted text to the declared data type, falling back to str."""
        logger.info(f'input text on data process for casting: {text}')

        if text is None or text == 'null':
            return text
        try:
            if dtype == 'int':
                return int(text)
            elif dtype == 'float':
                return float(text)
            elif dtype == 'money':
                return text
            elif dtype == 'str':
                return str(text)
            elif dtype == 'bool':
                return bool(text)
            elif dtype == 'time':
                # time = parser.parse(text).time().strftime("%H:%M:%S")
                return str(text)
            elif dtype == 'date':
                # date = parser.parse(text).date()
                return str(text)
            else:
                return text
        except ValueError as e:
            logger.warning(f"Error in casting: {e}. Returning str(text).")
            return str(text)

    def get_data_prompt(self, dtype):
        """Map a declared dtype to a JSON-schema type and a formatting instruction."""
        time_format = "hh:mm:ss"
        date_format = "month/day/year"
        todays_date = date.today()
        if "time(" in dtype:
            match = re.findall(r'\((.*?)\)', dtype)
            if match:
                time_format = match[0]  # re.findall returns a list; keep the first capture
                dtype = "time"
        if "date(" in dtype:
            match = re.findall(r'\((.*?)\)', dtype)
            if match:
                date_format = match[0]
                dtype = "date"

        data_type = {'int': 'integer',
                     'float': 'number',
                     'string': 'string',
                     'time': 'string',
                     'bool': 'boolean',
                     'date': 'string',
                     'list': 'array'}

        data_format = {'int': '',
                       'float': '',
                       'string': "Extract and display concisely only the requested information "
                                 "without including additional context",
                       'time': f'Output just the time data (not date) following strictly this format: {time_format}',
                       'bool': '',
                       'list': 'Output only the content to list.',
                       'date': f'''
                       Output just the date data (not time) following strictly this format: {date_format}.
                       If you're given a relative date, for example "tomorrow", "yesterday" or "in two days",
                       keep in mind that today is {todays_date}.
                       '''}

        prompt_type = data_type.get(dtype)
        d_format = data_format.get(dtype)
        return prompt_type, d_format

    def static_extraction(self, dtype, dformat, list_dtype):
        """Extract a single value using a fixed JSON schema built from dtype/dformat."""
        parsed_input_message = self.system + self.message

        description = f"{self.description}. {dformat}"

        prompt = ChatPromptTemplate.from_messages([("system", self.system + description), ("human", "{input}")])

        if dtype == "array":
            answer = {
                "type": [dtype, 'null'],
                "items": {
                    "type": list_dtype
                }
            }
        else:
            answer = {
                "type": [dtype, 'null'],
            }

        response_format = {
            "title": "Data_extraction",
            "description": description,
            "type": "object",
            "properties": {
                "answer": answer
            },
            "required": ['answer'],
            "additionalProperties": False,
        }

        structured_llm = llm.with_structured_output(response_format)
        prompted_structured_llm = prompt | structured_llm
        response = prompted_structured_llm.invoke({"input": self.message})

        output_message = response["answer"]
        if config.token_count_enabled:
            calculate_cost(parsed_input_message, output_message, model=self.model, module="data_extraction")

        return output_message

    def dynamic_extraction(self, extraction, llm_output):
        """Extract several named fields at once, with a schema built from a custom-type spec."""
        field_definitions = {
            key: ([self.get_data_prompt(extraction[key]["type"])[0], "null"],
                  extraction[key]["description"])
            for key in extraction
        }

        if llm_output is None:
            logger.warning("Couldn't get an answer from static extraction.")
            llm_output = "none"

        message = llm_output

        parsed_input_message = self.system + message
        properties = {}
        required = []

        for field_name, (field_type, field_description) in field_definitions.items():
            properties[field_name] = {
                "type": field_type,
                "description": field_description
            }
            required.append(field_name)

        response_format = {
            "title": "data_extraction",
            "description": "The data you want to extract",
            "type": "object",
            "properties": properties,
            "required": required,
            "additionalProperties": False,
        }

        prompt = ChatPromptTemplate.from_messages([("system", self.system), ("human", "{input}")])

        structured_llm = llm.with_structured_output(response_format)
        prompted_structured_llm = prompt | structured_llm
        response = prompted_structured_llm.invoke({"input": message})

        llm_output = response
        parsed_output_message = str(response)
        if config.token_count_enabled:
            calculate_cost(parsed_input_message, parsed_output_message, model=self.model, module="data_extraction")

        return llm_output

    def get_data_extraction(self):
        custom_types_names = list(config.types_dict.keys())
        if llm is None:
            logger.error("data extraction module not initialized.")
            return {"output": None}

        list_dtype = None
        # If the data type is custom
        if self.dtype in custom_types_names:
            type_yaml = config.types_dict.get(self.dtype, "string")
            dformat = f"Data should be strictly outputted following regular expression pattern: {type_yaml['format']}"
            if isinstance(type_yaml["extraction"], dict):
                dtype = self.get_data_prompt("string")
                static_output = self.static_extraction(dtype[0], dformat, list_dtype)
                llm_output = self.dynamic_extraction(type_yaml["extraction"], static_output)
                return {self.variable: llm_output}
            else:
                dtype = self.get_data_prompt(type_yaml["extraction"])
                llm_output = self.static_extraction(dtype[0], dformat, list_dtype)
                logger.info(f'LLM output for data extraction: {llm_output}')
                return {self.variable: llm_output}

        # If the data type is predefined
        else:
            if "list" in self.dtype:
                pattern = r'(\w+)\[(.*?)\]'  # e.g. "list[int]" -> outer type and element type
                match = re.match(pattern, self.dtype)
                if match:
                    list_name = match.group(1)
                    content = match.group(2)
                    dtype = self.get_data_prompt(list_name)[0]
                    list_dtype = self.get_data_prompt(content)[0]
                    dformat = self.get_data_prompt(list_name)[1]
                else:
                    logger.error("Invalid structure on list for output data. Using 'string' by default.")
                    dtype = self.get_data_prompt('string')[0]
                    dformat = self.get_data_prompt('string')[1]
            else:
                dtype, dformat = self.get_data_prompt(self.dtype)

            if dtype is None:
                logger.warning(f"Data type {self.dtype} is not supported. Using 'string' by default.")
                dtype = 'string'

            if dformat is None:
                logger.warning(f"Data format for {self.dtype} is not supported. Using default format.")
                dformat = "Extract and display concisely only the requested information without including additional context"

            llm_output = self.static_extraction(dtype, dformat, list_dtype)

            logger.info(f'LLM output for data extraction: {llm_output}')
            # text = llm_output['answer']
            data = self.data_process(llm_output, self.dtype)
            return {self.variable: data}
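For context, a minimal usage sketch of the extraction flow above (not part of the package). The conversation dict's shape mirrors what __init__ reads from conversation['interaction']; the variable name, dtype string, and description are hypothetical:

from user_sim.core.data_extraction import DataExtraction, init_data_extraction_module

init_data_extraction_module()  # binds the module-level model/llm globals via init_model()

conversation = {"interaction": [
    {"user": "What time does the museum open tomorrow?"},
    {"chatbot": "The museum opens at 09:30 tomorrow."},
]}

extractor = DataExtraction(conversation, "opening_time", "time(hh:mm)", "the museum opening time")
result = extractor.get_data_extraction()  # e.g. {'opening_time': '09:30'}

Because "time(hh:mm)" carries a parenthesized format, get_data_prompt() rewrites the dtype to "time" and threads "hh:mm" into the formatting instruction sent to the model.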
user_sim/core/data_gathering.py
@@ -0,0 +1,134 @@
import ast
import logging
import re

import pandas as pd
from langchain_core.prompts import ChatPromptTemplate

from user_sim.utils import config
from user_sim.utils.exceptions import *
from user_sim.utils.token_cost_calculator import calculate_cost, max_output_tokens_allowed, max_input_tokens_allowed
from user_sim.utils.utilities import init_model

model = ""
llm = None
logger = logging.getLogger('Info Logger')


def init_data_gathering_module():
    global model
    global llm
    model, llm = init_model()


def extract_dict(in_val):
    """Return the first {...} block found in a string, or None."""
    reg_ex = r'\{[^{}]*\}'
    coincidence = re.search(reg_ex, in_val, re.DOTALL)

    if coincidence:
        return coincidence.group(0)
    else:
        return None


def to_dict(in_val):
    try:
        dictionary = ast.literal_eval(extract_dict(in_val))
    except (BadDictionaryGeneration, ValueError) as e:
        logger.error(f"Bad dictionary generation: {e}. Setting empty dictionary value.")
        dictionary = {}
    return dictionary


class ChatbotAssistant:
    def __init__(self, ask_about):
        self.verification_description = "the following has been answered, confirmed or provided by the chatbot:"
        self.data_description = """the piece of the conversation where the following has been answered
        or confirmed by the assistant. Don't consider the user's interactions:"""
        self.properties = self.process_ask_about(ask_about)
        self.system_message = """You are a helpful assistant that detects when a query in a conversation
        has been answered, confirmed or provided by the chatbot."""
        self.messages = ""
        self.gathering_register = {}

    def process_ask_about(self, ask_about):
        """Build one JSON-schema property (verification flag + supporting data) per ask_about item."""
        properties = {}

        for ab in ask_about:
            properties[ab.replace(' ', '_')] = {
                "type": "object",
                "properties": {
                    "verification": {
                        "type": "boolean",
                        "description": f"{self.verification_description} {ab}"
                    },
                    "data": {
                        "type": ["string", "null"],
                        "description": f"{self.data_description} {ab}"
                    }
                },
                "required": ["verification", "data"],
                "additionalProperties": False
            }
        return properties

    def add_message(self, history):
        # Adds the chat history from user_simulator's "self.conversation_history" directly.
        text = ""
        for entry in history['interaction']:
            for speaker, message in entry.items():
                text += f"{speaker}: {message}\n"

        self.messages = text
        self.gathering_register = self.create_dataframe()

    def get_json(self):
        response_format = {
            "title": "data_gathering",
            "type": "object",
            "description": "The information to check.",
            "properties": self.properties,
            "required": list(self.properties.keys()),
            "additionalProperties": False
        }

        parsed_input_message = self.messages + self.verification_description + self.data_description

        if llm is None:
            logger.error("data gathering module not initialized.")
            return "Empty data"

        if max_input_tokens_allowed(parsed_input_message, model):
            logger.error("Input token limit was exceeded.")
            return None

        if config.token_count_enabled:
            llm.max_tokens = max_output_tokens_allowed(model)

        prompt = ChatPromptTemplate.from_messages([("system", self.system_message), ("human", "{input}")])
        structured_llm = llm.with_structured_output(response_format)
        prompted_structured_llm = prompt | structured_llm

        try:
            response = prompted_structured_llm.invoke({"input": self.messages})
            parsed_output_message = str(response)
        except Exception as e:
            logger.error(f"Truncated data in message: {e}")
            response = parsed_output_message = None
        if config.token_count_enabled:
            # Note: the cost is logged under the "data_extraction" module tag.
            calculate_cost(parsed_input_message, parsed_output_message, model=config.model, module="data_extraction")
        return response

    def create_dataframe(self):
        data_dict = self.get_json()
        if data_dict is None:
            df = self.gathering_register
        else:
            try:
                df = pd.DataFrame.from_dict(data_dict, orient='index')
            except Exception as e:
                logger.error(f"{e}. data_dict: {data_dict}. Retrieving data frame from gathering_register")
                df = self.gathering_register
        return df
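A hypothetical usage sketch for ChatbotAssistant (not part of the package). The history structure is an assumption based on what add_message() iterates over, and the module must be initialized first so the llm global is set:

from user_sim.core.data_gathering import ChatbotAssistant, init_data_gathering_module

init_data_gathering_module()

assistant = ChatbotAssistant(ask_about=["ticket price", "opening hours"])
history = {"interaction": [
    {"user": "How much is a ticket?"},
    {"chatbot": "A standard ticket costs 12 euros."},
]}
assistant.add_message(history)       # builds the transcript and refreshes gathering_register
print(assistant.gathering_register)  # DataFrame: one row per ask_about item, with 'verification' and 'data'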
user_sim/core/interaction_styles.py
@@ -0,0 +1,147 @@
import random
import logging

logger = logging.getLogger('Info Logger')


def find_instance(instances, i_class):
    for instance in instances:
        if isinstance(instance, i_class):
            return instance
    return None


def create_instance(class_list, interaction_styles):
    instances = []
    for class_info in class_list:
        class_name = class_info['clase']  # the spec uses the (Spanish) key 'clase' for the class name
        args = class_info.get('args', [])
        kwargs = class_info.get('kwargs', {})
        if class_name in interaction_styles:
            instance = interaction_styles[class_name](*args, **kwargs)
            instances.append(instance)
        else:
            raise ValueError(f"Couldn't find {class_name} in interaction list.")
    return instances


class InteractionStyle:

    def __init__(self, inter_type):
        self.inter_type = inter_type
        self.change_language_flag = False
        self.languages_options = []

    def get_prompt(self):
        return

    def get_metadata(self):
        return


class LongPhrases(InteractionStyle):
    def __init__(self):
        super().__init__(inter_type='long phrases')

    def get_prompt(self):
        return "use very long phrases to write anything. "

    def get_metadata(self):
        return self.inter_type


class ChangeYourMind(InteractionStyle):
    def __init__(self):
        super().__init__(inter_type='change your mind')

    def get_prompt(self):
        return "eventually, change your mind about any information you provided. "

    def get_metadata(self):
        return self.inter_type


class ChangeLanguage(InteractionStyle):
    # TODO: add chance variable with *args
    def __init__(self, default_language):
        super().__init__(inter_type='change language')
        self.default_language = default_language
        self.languages_list = []
        self.chance = 0.3

    def get_prompt(self):
        lang = self.language(self.chance)
        prompt = f"""Please, always talk in {lang}, even if the assistant tells you that it doesn't understand
        or you had a conversation in another language before. """
        return prompt

    def language(self, chance=0.3):
        # With probability `chance`, pick a random language from languages_options;
        # otherwise fall back to the default language.
        chance = chance * 100
        rand_number = random.randint(1, 100)
        if rand_number <= chance:
            lang = random.choice(self.languages_options)
            logger.info(f'Language was set to {lang}')
            self.languages_list.append(lang)
            return lang
        else:
            self.languages_list.append(self.default_language)
            logger.info(f'Language was set to default ({self.default_language})')
            return self.default_language

    def reset_language_list(self):
        self.languages_list.clear()

    def get_metadata(self):
        language_list = self.languages_list.copy()
        self.reset_language_list()
        return {'change languages': language_list}


class MakeSpellingMistakes(InteractionStyle):
    def __init__(self):
        super().__init__(inter_type='make spelling mistakes')

    def get_prompt(self):
        prompt = """
        please, make several spelling mistakes during the conversation. Minimum 5 typos per
        sentence if possible.
        """
        return prompt

    def get_metadata(self):
        return self.inter_type


class SingleQuestions(InteractionStyle):
    def __init__(self):
        super().__init__(inter_type='single questions')

    def get_prompt(self):
        return "ask only one question per interaction. "

    def get_metadata(self):
        return self.inter_type


class AllQuestions(InteractionStyle):
    # TODO: all questions should only get questions from ask_about
    def __init__(self):
        super().__init__(inter_type='all questions')

    def get_prompt(self):
        return "ask everything you have to ask in one sentence. "

    def get_metadata(self):
        return self.inter_type


class Default(InteractionStyle):
    def __init__(self):
        super().__init__(inter_type='default')

    def get_prompt(self):
        return "Ask about one or two things per interaction, don't ask everything you want to know in one sentence."

    def get_metadata(self):
        return self.inter_type