user-simulator 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- user_sim/__init__.py +0 -0
- user_sim/cli/__init__.py +0 -0
- user_sim/cli/gen_user_profile.py +34 -0
- user_sim/cli/init_project.py +65 -0
- user_sim/cli/sensei_chat.py +481 -0
- user_sim/cli/sensei_check.py +103 -0
- user_sim/cli/validation_check.py +143 -0
- user_sim/core/__init__.py +0 -0
- user_sim/core/ask_about.py +665 -0
- user_sim/core/data_extraction.py +260 -0
- user_sim/core/data_gathering.py +134 -0
- user_sim/core/interaction_styles.py +147 -0
- user_sim/core/role_structure.py +608 -0
- user_sim/core/user_simulator.py +302 -0
- user_sim/handlers/__init__.py +0 -0
- user_sim/handlers/asr_module.py +128 -0
- user_sim/handlers/html_parser_module.py +202 -0
- user_sim/handlers/image_recognition_module.py +139 -0
- user_sim/handlers/pdf_parser_module.py +123 -0
- user_sim/utils/__init__.py +0 -0
- user_sim/utils/config.py +47 -0
- user_sim/utils/cost_tracker.py +153 -0
- user_sim/utils/cost_tracker_v2.py +193 -0
- user_sim/utils/errors.py +15 -0
- user_sim/utils/exceptions.py +47 -0
- user_sim/utils/languages.py +78 -0
- user_sim/utils/register_management.py +62 -0
- user_sim/utils/show_logs.py +63 -0
- user_sim/utils/token_cost_calculator.py +338 -0
- user_sim/utils/url_management.py +60 -0
- user_sim/utils/utilities.py +568 -0
- user_simulator-0.1.0.dist-info/METADATA +733 -0
- user_simulator-0.1.0.dist-info/RECORD +37 -0
- user_simulator-0.1.0.dist-info/WHEEL +5 -0
- user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
- user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
- user_simulator-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,608 @@
|
|
1
|
+
import itertools
|
2
|
+
from pydantic import BaseModel, ValidationError, field_validator
|
3
|
+
from typing import List, Union, Dict, Optional
|
4
|
+
from .interaction_styles import *
|
5
|
+
from .ask_about import *
|
6
|
+
from user_sim.utils.exceptions import *
|
7
|
+
from user_sim.utils.languages import languages
|
8
|
+
from user_sim.utils import config
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from user_sim.handlers.image_recognition_module import init_vision_module
|
11
|
+
from .data_gathering import init_data_gathering_module
|
12
|
+
from .data_extraction import init_data_extraction_module
|
13
|
+
import logging
|
14
|
+
logger = logging.getLogger('Info Logger')
|
15
|
+
|
16
|
+
|
17
|
+
def replace_placeholders(phrase, variables):
    """Replace every ``{{name}}`` placeholder in *phrase*.

    Args:
        phrase: Text that may contain ``{{word}}`` placeholders.
        variables: Either a mapping from placeholder name to its value(s), or a
            single collection of values used for every placeholder.

    Returns:
        *phrase* with each placeholder replaced by its values joined with
        ``", "``. A name missing from the mapping (or mapped to None) becomes
        an empty string. A string or other non-iterable value is inserted as a
        single item instead of being split into characters or raising.
    """
    import re  # ``re`` is not imported explicitly at the top of this file

    def _items(value):
        # Text and scalars count as one item; real collections are iterated.
        if value is None:
            return []
        if isinstance(value, (str, bytes)):
            return [value]
        try:
            return list(value)
        except TypeError:
            return [value]

    def replacer(match):
        key = match.group(1)
        values = variables.get(key) if isinstance(variables, dict) else variables
        return ', '.join(map(str, _items(values)))

    return re.sub(r'\{\{(\w+)\}\}', replacer, phrase)
|
27
|
+
|
28
|
+
|
29
|
+
|
30
|
+
def list_to_str(list_of_strings):
    """Join phrases into a single space-separated string.

    ``None`` gives ``''``. A plain string is returned unchanged instead of
    being spread out character by character. Non-string items are converted
    with ``str`` rather than discarding the whole input. A non-iterable
    argument is logged and gives ``''``.
    """
    if list_of_strings is None:
        return ''
    if isinstance(list_of_strings, str):
        return list_of_strings
    try:
        return ' '.join(str(item) for item in list_of_strings)
    except TypeError as e:
        logger.warning(f"Could not join context phrases {list_of_strings!r}: {e}")
        return ''
|
39
|
+
|
40
|
+
|
41
|
+
class ConvFormat(BaseModel):
    """``llm.format`` settings: the conversation format of the simulated user."""

    # Format kind; ``"text"`` unless the profile says otherwise.
    type: Optional[str] = "text"
    # Optional format-specific configuration string.
    config: Optional[str] = None
|
44
|
+
|
45
|
+
class LLM(BaseModel):
    """``llm`` section of a profile: the model that drives the simulated user."""

    model: Optional[str] = "gpt-4o"
    # Model provider name; None when the profile does not set one.
    model_prov: Optional[str] = None
    temperature: Optional[float] = 0.8
    # Conversation format (text, speech, hybrid).
    format: Optional[ConvFormat] = ConvFormat()
|
50
|
+
|
51
|
+
class User(BaseModel):
    """``user`` section of a profile: who the simulated user is and what they want."""

    # May be null; RoleData.set_language maps None to English.
    language: Optional[str] = 'English'
    role: str
    # A list of phrases (optionally containing a personality dict), a
    # personality dict, or None. The default is an empty string, which is not
    # validated against this annotation.
    context: Optional[Union[List[Union[str, Dict]], Dict, None]] = ''
    goals: list
|
56
|
+
|
57
|
+
|
58
|
+
class ChatbotClass(BaseModel):
    """``chatbot`` section of a profile: settings about the chatbot under test."""

    # Presumably whether the chatbot opens the conversation — TODO confirm.
    is_starter: Optional[bool] = True
    # The chatbot's fallback message (required).
    fallback: str
    # Output definitions; their structure is interpreted elsewhere.
    output: list
|
62
|
+
|
63
|
+
|
64
|
+
class Conversation(BaseModel):
    """``conversation`` section of a profile: how many conversations and in what style."""

    # An int, or a "combinations(...)" expression resolved by RoleData.
    number: Union[int, str]
    max_cost: Optional[float] = 1_000_000_000
    goal_style: Dict
    interaction_style: list

    @field_validator('max_cost', mode='before')
    @classmethod
    def set_token_count_enabled(cls, value):
        """Enable token counting globally whenever ``max_cost`` is provided.

        Runs before type coercion on provided input only (pydantic does not
        validate defaults unless asked to), so the default value leaves
        ``config.token_count_enabled`` untouched. *value* is returned as is.
        """
        if value is None:
            return value
        config.token_count_enabled = True
        return value
|
76
|
+
|
77
|
+
class RoleDataModel(BaseModel):
    """Schema of a complete user profile, one field per top-level section."""

    test_name: str
    # Optional; an all-defaults LLM section is used when omitted.
    llm: Optional[LLM] = LLM()
    user: User
    chatbot: ChatbotClass
    conversation: Conversation
|
83
|
+
|
84
|
+
@dataclass()
class ValidationIssue:
    """One problem found while parsing a user profile."""

    # Name of the offending field (last element of the error location).
    field: str
    # Human-readable description of the problem.
    error: str
    # Name of the error/exception type.
    error_type: str
    # Dotted path of the field inside the profile, e.g. "user.language".
    location: str
|
90
|
+
|
91
|
+
class RoleData:
    """Parsed user-simulator profile together with the problems found in it.

    Built from an already-loaded profile mapping. Each top-level section
    (``test_name``, ``llm``, ``user``, ``chatbot``, ``conversation``) is parsed
    independently. Any failure is appended to ``self.errors`` as a
    ``ValidationIssue`` instead of aborting construction, so every problem in
    a profile can be reported at once (see ``get_errors``).

    Side effects: parsing writes to the shared ``config`` module (``model``,
    ``model_provider``, ``limit_cost``, ``limit_individual_cost`` and possibly
    ``token_count_enabled``).
    """

    def __init__(self, yaml_file, project_folder=None, personality_file=None, validation=False):
        """Parse the profile.

        Args:
            yaml_file: The loaded profile, read with ``.get`` (presumably the
                dict produced by ``read_yaml`` — TODO confirm at call sites).
            project_folder: Stored as ``self.project_folder``; not used here.
            personality_file: Optional path of a personality YAML whose
                ``context`` phrases are added when a list-style user context
                does not select a personality itself.
            validation: When True, invalid goals and goal styles are recorded
                in ``self.errors`` rather than raised. Combination matrices are
                also computed into ``self.combinations_dict``.
        """
        self.yaml = yaml_file
        self.validation = validation
        self.personality_file = personality_file
        self.project_folder = project_folder
        self.errors: List[ValidationIssue] = []
        # Exists even when the conversation section fails to parse.
        self.combinations_dict = {}

        # Test name
        self.test_name = None
        try:
            self.test_name = self.yaml.get('test_name')
        except Exception as e:
            self.collect_errors(e, prefix='test_name')

        # LLM
        self.model = self.model_provider = self.temperature = self.format_type = self.format_config = None
        try:
            self.llm = LLM(**self._section('llm'))
            self.model = config.model = self.llm.model
            self.model_provider = config.model_provider = self.llm.model_prov
            self.temperature = self.llm.temperature
            self.format_type = self.llm.format.type
            self.format_config = self.llm.format.config
        except Exception as e:
            self.collect_errors(e, prefix='llm')

        # The LLM-backed helper modules are only initialised when everything
        # parsed so far (test name and model settings) is error-free.
        if not self.errors:
            self.init_llm_modules()

        # User
        self.language = self.role = self.raw_context = self.context = self.ask_about = None
        try:
            self.user = User(**self._section('user'))
            self.language = self.set_language(self.user.language)
            self.role = self.user.role
            self.raw_context = self.user.context
            self.context = self.context_processor(self.raw_context)
            self.ask_about = self.get_ask_about()
        except Exception as e:
            self.collect_errors(e, prefix='user')

        # Chatbot
        self.is_starter = self.fallback = self.output = None
        try:
            self.chatbot = ChatbotClass(**self._section('chatbot'))
            self.is_starter = self.chatbot.is_starter
            self.fallback = self.chatbot.fallback
            self.output = self.chatbot.output
        except Exception as e:
            self.collect_errors(e, prefix='chatbot')

        # Conversation
        self.conversation_number = self.max_cost = self.goal_style = self.interaction_styles = None
        try:
            self.conversation = Conversation(**self._section('conversation'))
            self.conversation_number = self.get_conversation_number(self.conversation.number)
            self.max_cost = self.conversation.max_cost
            # pick_goal_style falls back to this global when the goal has no max_cost.
            config.limit_cost = self.max_cost
            self.goal_style = self.pick_goal_style(self.conversation.goal_style)
            self.interaction_styles = self.pick_interaction_style(self.conversation.interaction_style)
        except Exception as e:
            self.collect_errors(e, prefix='conversation')

    def _section(self, name):
        """Return section *name* of the profile; a missing or null section counts as empty."""
        return self.yaml.get(name) or {}

    def init_llm_modules(self):
        """Initialise the vision, data-gathering, data-extraction and any-list modules."""
        init_vision_module()
        init_data_gathering_module()
        init_data_extraction_module()
        init_any_list_module()
        # init_asr_module()

    def collect_errors(self, e: Exception, prefix=""):
        """Record *e* in ``self.errors``.

        A pydantic ``ValidationError`` is expanded into one issue per reported
        error, with its location prefixed by *prefix*. Any other exception
        becomes a single issue with field ``'unknown'`` located at *prefix*.
        """
        if not isinstance(e, ValidationError):
            self.errors.append(
                ValidationIssue(
                    field='unknown',
                    error=str(e),
                    error_type=type(e).__name__,
                    location=prefix
                )
            )
            return

        for err in e.errors():
            loc = err['loc']
            loc_path = '.'.join(str(part) for part in loc)
            # Model-level errors have an empty loc: avoid "prefix." paths and IndexError.
            full_path = '.'.join(part for part in (prefix, loc_path) if part)
            self.errors.append(
                ValidationIssue(
                    field=loc[-1] if loc else prefix,
                    error=err['msg'],
                    error_type=err['type'],
                    location=full_path
                )
            )

    def get_errors(self):
        """Return ``(errors, count)``.

        Each error is a dict with keys ``field`` (the dotted location),
        ``error`` and ``type``. A warning with the count is logged on every call.
        """
        error_list = [
            {"field": issue.location, "error": issue.error, "type": issue.error_type}
            for issue in self.errors
        ]
        logger.warning(f"\n{len(self.errors)} errors detected.\n")
        return error_list, len(self.errors)

    def get_ask_about(self):
        """Build the ``AskAboutClass`` for the user's goals.

        In validation mode a construction failure is recorded in
        ``self.errors`` and None is returned; otherwise the exception propagates.
        """
        if not self.validation:
            return AskAboutClass(self.user.goals)
        try:
            return AskAboutClass(self.user.goals)
        except Exception as e:
            self.errors.append(
                ValidationIssue(
                    field="goals",
                    error=str(e),
                    error_type=type(e).__name__,
                    location="user.goals"
                )
            )
            return None

    def set_language(self, lang):
        """Return *lang* if it is a supported language, otherwise ``"English"``.

        None silently selects English. An unsupported value is also recorded
        in ``self.errors``.
        """
        if lang is None:
            logger.info("Language parameter empty. Setting language to Default (English)")
            return "English"
        if lang in languages:
            logger.info(f"Language set to {lang}")
            return lang

        error = InvalidLanguageException(f'Invalid language input: {lang}. Setting language to default (English)')
        self.errors.append(
            ValidationIssue(
                field="language",
                error=str(error),
                error_type=type(error).__name__,
                location="user.language"
            )
        )
        return "English"

    def reset_attributes(self):
        """Recompute the per-conversation settings before the next conversation.

        Re-initialises the LLM helper modules, resets ``ask_about`` and
        recomputes fallback, context, goal style, language and interaction styles.
        """
        logger.info("Preparing attributes for next conversation...")
        self.init_llm_modules()
        self.fallback = self.chatbot.fallback
        self.context = self.context_processor(self.raw_context)
        self.ask_about.reset()  # presumably clears picked elements and phrases
        self.goal_style = self.pick_goal_style(self.conversation.goal_style)
        self.language = self.set_language(self.user.language)
        self.interaction_styles = self.pick_interaction_style(self.conversation.interaction_style)

    @staticmethod
    def list_to_dict_reformat(conv):
        """Merge a list of dicts into one dict (later keys override earlier ones)."""
        return {k: v for d in conv for k, v in d.items()}

    @staticmethod
    def _as_phrase_list(phrases):
        """Normalise personality phrases to a list; a lone string is one phrase."""
        if phrases is None:
            return []
        if isinstance(phrases, str):
            return [phrases]
        return list(phrases)

    def personality_extraction(self, context):
        """Load the context phrases of the personality named in *context*.

        *context* is a mapping whose ``personality`` entry names a personality
        YAML file (extension optional). The custom personalities folder (when
        configured and existing) is searched before the bundled
        ``config/personalities`` folder. On success ``self.personality`` is set
        to the file's ``name`` and its ``context`` value is returned. On any
        failure the problem is logged and ``['']`` is returned.
        """
        personality = context.get("personality") if isinstance(context, dict) else None
        if not personality:
            logger.error(f"Data for context is not a dictionary with a 'personality' key: {context}.")
            return ['']

        search_paths = []
        custom_folder = config.custom_personalities_folder
        if custom_folder and os.path.exists(custom_folder):
            search_paths.append(custom_folder)
        search_paths.append(os.path.join(config.root_path, "config", "personalities"))

        try:
            wanted, _ = os.path.splitext(personality)
            for folder in search_paths:
                if not os.path.isdir(folder):
                    continue
                for file in os.listdir(folder):
                    file_name, ext = os.path.splitext(file)
                    if file_name != wanted or ext not in ('.yml', '.yaml'):
                        continue
                    personality_data = read_yaml(os.path.join(folder, file))
                    try:
                        self.personality = personality_data["name"]
                        phrases = personality_data["context"]
                    except KeyError as missing:
                        raise InvalidFormat(f"Key {missing} not found in personality file.")
                    logger.info(f"Personality set to '{file_name}'")
                    return phrases

            logger.error(f"Couldn't find specified personality file: '{personality}'")
            return ['']
        except Exception as e:
            logger.error(e)
            return ['']

    @staticmethod
    def _build_combinations_dict(generators):
        """Summarise the combination matrix of every matrix-based variable generator.

        ``forward`` generators expand to the cartesian product of their
        matrix. ``pairwise`` generators use their matrix as is. Other types
        get an empty matrix.
        """
        summaries = []
        for generator in generators:
            if "matrix" not in generator:
                continue
            name = generator['name']
            if generator['type'] == 'forward':
                matrix = [list(p) for p in itertools.product(*generator['matrix'])]
            elif generator['type'] == 'pairwise':
                matrix = generator['matrix']
            else:
                matrix = []
            summaries.append({'name': name,
                              'matrix': matrix,
                              'combinations': len(matrix),
                              'type': generator['type']})
        return summaries

    def get_conversation_number(self, conversation):
        """Resolve ``conversation.number`` to a concrete conversation count.

        An int is returned as is. Otherwise the value must be ``combinations``,
        ``combinations(<sample>)`` or ``combinations(<sample>, forward|pairwise)``.
        The count is the forward, pairwise or larger of the two combination
        counts from ``self.ask_about``, optionally scaled by *sample* and rounded.
        ``0`` is returned when no combinations can be made or the value is not
        recognised; the latter is also recorded in ``self.errors``.
        In validation mode ``self.combinations_dict`` is filled as well.
        """
        if isinstance(conversation, int):
            logger.info(f"{conversation} conversations will be generated")
            return conversation

        comb_pattern = r'^combinations(?:\(([^,()\s]+)(?:,\s*([^()]+))?\))?$'
        match = re.match(comb_pattern, conversation.strip())

        # ask_about is None when the goals failed to parse.
        if self.validation and self.ask_about is not None:
            self.combinations_dict = self._build_combinations_dict(self.ask_about.var_generators)

        if not match:
            message = f"Conversation number can't be obtained due to unrecognized value: {conversation}"
            logger.error(message)
            self.errors.append(
                ValidationIssue(
                    field="number",
                    error=message,
                    error_type=InvalidFormat.__name__,
                    location="conversation.number"
                )
            )
            return 0

        sample = match.group(1)
        iter_function = match.group(2).strip() if match.group(2) else None

        if iter_function == "forward":
            conv_number = self.ask_about.forward_combinations
            if conv_number <= 0:
                logger.error("Conversation number set to 'forward_all_combinations' but no combinations can be made.")
                return 0
        elif iter_function == "pairwise":
            conv_number = self.ask_about.pairwise_combinations
            if conv_number <= 0:
                logger.error("Conversation number set to 'pairwise_all_combinations' but no combinations can be made.")
                return 0
        else:
            conv_number = max(self.ask_about.forward_combinations, self.ask_about.pairwise_combinations)
            if conv_number < 1:
                logger.error("Conversation number set to 'combinations' but no combinations can be made.")
                return 0

        if sample:
            conv_number = round(conv_number * float(sample))
        logger.info(f"{conv_number} conversations will be generated.")
        return conv_number

    def context_processor(self, context):
        """Flatten the user ``context`` setting into a single prompt string.

        Accepted forms:
          * a dict naming a personality: that personality's phrases;
          * a list of strings with no dict: the strings, followed by the
            phrases of ``self.personality_file`` when one was given;
          * a list of strings plus exactly one personality dict: the
            personality phrases followed by the strings. ``self.personality_file``
            is only used when the personality lookup returns no phrases at all;
          * None: treated as an empty list.

        More than one dict, or an item that is neither str nor dict, is
        recorded in ``self.errors`` and yields ``""``.
        """
        if isinstance(context, dict):
            return list_to_str(self._as_phrase_list(self.personality_extraction(context)))
        if context is None:
            context = []

        dict_count = sum(1 for item in context if isinstance(item, dict))
        if dict_count > 1:
            self.errors.append(
                ValidationIssue(
                    field="context",
                    error="Too many keys in context list.",
                    error_type=InvalidFormat.__name__,
                    location="user.context"
                )
            )
            return ""

        if dict_count == 0:
            parts = [list_to_str(context)]
            if self.personality_file is not None:
                personality = read_yaml(self.personality_file)
                parts.append(list_to_str(self._as_phrase_list(personality['context'])))
            # Separate custom and personality text so they are not glued together.
            return ' '.join(part for part in parts if part)

        custom_phrases = []
        personality_phrases = []
        for item in context:
            if isinstance(item, str):
                custom_phrases.append(item)
            elif isinstance(item, dict):
                personality_phrases = personality_phrases + self._as_phrase_list(self.personality_extraction(item))
            else:
                self.errors.append(
                    ValidationIssue(
                        field="context",
                        error=f"Invalid data type in context list: {type(item)}:{item}",
                        error_type=InvalidDataType.__name__,
                        location="user.context"
                    )
                )
                return ""

        # If no personality is given, fall back to the one passed on the command line.
        if not personality_phrases and self.personality_file is not None:
            personality = read_yaml(self.personality_file)
            personality_phrases = self._as_phrase_list(personality['context'])

        return list_to_str(personality_phrases + custom_phrases)

    def pick_goal_style(self, goal):
        """Translate the ``goal_style`` mapping into the form used by the simulator.

        Also sets ``config.limit_individual_cost`` from the goal's ``max_cost``
        (enabling token counting), or from ``config.limit_cost`` when absent.

        Returns:
            * ``(None, False)`` when *goal* is None;
            * ``('steps', n)`` for a fixed number of steps, 0 < n <= 20;
            * ``('random steps', k)`` with k drawn from 1..n, 1 <= n <= 20;
            * ``[style, export, limit]`` for ``all_answered`` / ``default``,
              with export defaulting to False and limit to 30;
            * ``""`` for an invalid goal in validation mode (the problem is
              recorded in ``self.errors``).

        Raises:
            NoCostException, OutOfLimitException, InvalidGoalException: for
            invalid goals outside validation mode.
        """
        def invalid(message, exc_type):
            # Validation mode collects the problem; otherwise it is fatal.
            if not self.validation:
                raise exc_type(message)
            self.errors.append(
                ValidationIssue(
                    field="goal_style",
                    error=message,
                    error_type=exc_type.__name__,
                    location="conversation.goal_style"
                )
            )
            return ""

        if goal is None:
            return goal, False

        if 'max_cost' in goal:
            if goal['max_cost'] > 0:
                config.limit_individual_cost = goal['max_cost']
                config.token_count_enabled = True
            else:
                return invalid(f"Goal cost can't be lower than or equal to 0: {goal['max_cost']}",
                               NoCostException)
        else:
            config.limit_individual_cost = config.limit_cost

        if 'steps' in goal:
            steps = goal['steps']
            if 0 < steps <= 20:
                return 'steps', steps
            return invalid(f"Goal steps higher than 20 steps or lower than 0 steps: {steps}",
                           OutOfLimitException)

        if 'all_answered' in goal or 'default' in goal:
            if not isinstance(goal, dict):
                return [goal, False, 30]
            style = 'all_answered' if 'all_answered' in goal else 'default'
            options = goal.get(style)
            if not isinstance(options, dict):
                options = {}
            return [style, options.get('export', False), options.get('limit', 30)]

        if 'random steps' in goal:
            limit = goal['random steps']
            if limit > 20:
                return invalid(f"Goal steps higher than 20 steps: {limit}", OutOfLimitException)
            if limit < 1:
                return invalid(f"Goal steps lower than 1 step: {limit}", OutOfLimitException)
            return 'random steps', random.randint(1, limit)

        return invalid(f"Invalid goal value: {goal}", InvalidGoalException)

    def get_interaction_metadata(self):
        """Return the ``get_metadata()`` of every selected interaction style."""
        return [style.get_metadata() for style in self.interaction_styles or []]

    def pick_interaction_style(self, interactions):
        """Build the interaction-style objects requested by *interactions*.

        Items are style names (keys of the table below) or
        ``{"change language": [...]}`` mappings. A list whose first item is
        ``{"random": [...]}`` selects a random non-empty subset of the listed
        styles. None or an empty list selects the default style. The first
        invalid entry is recorded in ``self.errors`` and stops parsing; styles
        collected before it are still returned.
        """
        inter_styles = {
            'long phrases': LongPhrases(),
            'change your mind': ChangeYourMind(),
            'change language': ChangeLanguage(self.language),
            'make spelling mistakes': MakeSpellingMistakes(),
            'single question': SingleQuestions(),
            'all questions': AllQuestions(),
            'default': Default()
        }

        def choice_styles(candidates):
            # Pick a random, non-empty subset of the candidate styles.
            count = random.randint(1, len(candidates))
            picked = random.sample(candidates, count)
            logger.info(f'interaction style count: {count}; style(s): {picked}')
            return picked

        def get_styles(requested):
            styles = []
            try:
                for inter in requested:
                    if isinstance(inter, dict):
                        key = next(iter(inter), None)
                        if key != "change language":
                            raise InvalidInteractionException(f"Invalid interaction: {inter}")
                        change_language = inter_styles[key]
                        change_language.languages_options = inter[key].copy()
                        change_language.change_language_flag = True
                        styles.append(change_language)
                    elif inter in inter_styles:
                        styles.append(inter_styles[inter])
                    else:
                        raise InvalidInteractionException(f"Invalid interaction: {inter}")
            except InvalidInteractionException as e:
                self.errors.append(
                    ValidationIssue(
                        field="interaction_style",
                        error=str(e),
                        error_type=type(e).__name__,
                        location="conversation.interaction_style"
                    )
                )
                logger.error(f"Error: {e}")
            return styles

        if not interactions:
            return [inter_styles['default']]

        first = interactions[0]
        if isinstance(first, dict) and 'random' in first:
            # TODO: accept 'random' only when it is the sole entry in the list.
            return get_styles(choice_styles(first['random']))

        return get_styles(interactions)

    def get_language(self):
        """Return the language instruction for the simulated user.

        The prompt of the first interaction style whose
        ``change_language_flag`` is set wins; otherwise
        ``"Please, talk in <self.language>"`` is returned.
        """
        for style in self.interaction_styles or []:
            if style.change_language_flag:
                return style.get_prompt()
        return f"Please, talk in {self.language}"
|