user-simulator 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. user_sim/__init__.py +0 -0
  2. user_sim/cli/__init__.py +0 -0
  3. user_sim/cli/gen_user_profile.py +34 -0
  4. user_sim/cli/init_project.py +65 -0
  5. user_sim/cli/sensei_chat.py +481 -0
  6. user_sim/cli/sensei_check.py +103 -0
  7. user_sim/cli/validation_check.py +143 -0
  8. user_sim/core/__init__.py +0 -0
  9. user_sim/core/ask_about.py +665 -0
  10. user_sim/core/data_extraction.py +260 -0
  11. user_sim/core/data_gathering.py +134 -0
  12. user_sim/core/interaction_styles.py +147 -0
  13. user_sim/core/role_structure.py +608 -0
  14. user_sim/core/user_simulator.py +302 -0
  15. user_sim/handlers/__init__.py +0 -0
  16. user_sim/handlers/asr_module.py +128 -0
  17. user_sim/handlers/html_parser_module.py +202 -0
  18. user_sim/handlers/image_recognition_module.py +139 -0
  19. user_sim/handlers/pdf_parser_module.py +123 -0
  20. user_sim/utils/__init__.py +0 -0
  21. user_sim/utils/config.py +47 -0
  22. user_sim/utils/cost_tracker.py +153 -0
  23. user_sim/utils/cost_tracker_v2.py +193 -0
  24. user_sim/utils/errors.py +15 -0
  25. user_sim/utils/exceptions.py +47 -0
  26. user_sim/utils/languages.py +78 -0
  27. user_sim/utils/register_management.py +62 -0
  28. user_sim/utils/show_logs.py +63 -0
  29. user_sim/utils/token_cost_calculator.py +338 -0
  30. user_sim/utils/url_management.py +60 -0
  31. user_sim/utils/utilities.py +568 -0
  32. user_simulator-0.1.0.dist-info/METADATA +733 -0
  33. user_simulator-0.1.0.dist-info/RECORD +37 -0
  34. user_simulator-0.1.0.dist-info/WHEEL +5 -0
  35. user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
  36. user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
  37. user_simulator-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,608 @@
1
import itertools
import logging
import os
import random
import re
from typing import List, Union, Dict, Optional

from pydantic import BaseModel, ValidationError, field_validator

# NOTE: the wildcard imports below presumably supply several names used in this
# module (e.g. AskAboutClass, read_yaml, init_any_list_module, the interaction
# style classes and the custom exceptions) -- confirm. The explicit imports that
# follow them keep their original position so they still take precedence over
# anything the wildcards export.
from .interaction_styles import *
from .ask_about import *
from user_sim.utils.exceptions import *
from user_sim.utils.languages import languages
from user_sim.utils import config
from dataclasses import dataclass
from user_sim.handlers.image_recognition_module import init_vision_module
from .data_gathering import init_data_gathering_module
from .data_extraction import init_data_extraction_module
14
# Module logger; shares the 'Info Logger' name so its configuration is shared
# with every other module that requests the same logger name.
logger = logging.getLogger(name='Info Logger')
15
+
16
+
17
def replace_placeholders(phrase, variables):
    """Substitute every ``{{name}}`` placeholder found in *phrase*.

    Args:
        phrase: Text that may contain ``{{word}}`` placeholders.
        variables: Either a dict mapping placeholder names to iterables of
            values, or any other iterable of values.

    Returns:
        The phrase with each placeholder replaced. With a dict, ``{{name}}``
        becomes ``", ".join(str(v) for v in variables[name])`` (an empty
        string when the name is missing). With any other iterable, every
        placeholder becomes all of *variables* joined with ", ".
    """
    placeholder = re.compile(r'\{\{(\w+)\}\}')

    def _substitute(found):
        name = found.group(1)
        values = variables.get(name, []) if isinstance(variables, dict) else variables
        return ', '.join(str(value) for value in values)

    return placeholder.sub(_substitute, phrase)
27
+
28
+
29
+
30
def list_to_str(list_of_strings):
    """Join a sequence of phrases into one space-separated string.

    Args:
        list_of_strings: An iterable of strings, a single string, or None.

    Returns:
        ``''`` for None. A plain string is returned unchanged; previously
        ``' '.join`` split it into space-separated characters. Otherwise the
        items are joined with single spaces. If they cannot be joined (e.g.
        non-string items or a non-iterable), ``''`` is returned, as before.
    """
    if list_of_strings is None:
        return ''
    if isinstance(list_of_strings, str):
        # ' '.join('abc') would produce 'a b c'.
        return list_of_strings
    try:
        return ' '.join(list_of_strings)
    except Exception as e:
        # Best-effort by design: callers expect a string, never an exception.
        logger.debug(f"list_to_str could not join {list_of_strings!r}: {e}")
        return ''
39
+
40
+
41
class ConvFormat(BaseModel):
    """Schema of the ``llm.format`` YAML sub-section."""

    # Conversation format type; "text" unless overridden.
    type: Union[str, None] = "text"
    # Optional format-specific configuration value.
    config: Union[str, None] = None
44
+
45
class LLM(BaseModel):
    """Schema of the ``llm`` YAML section (the user-simulator's own model)."""

    # Model name; "gpt-4o" unless overridden.
    model: Union[str, None] = "gpt-4o"
    # Optional model provider identifier.
    model_prov: Union[str, None] = None
    # Sampling temperature.
    temperature: Union[float, None] = 0.8
    # Conversation format (the original author lists text, speech, hybrid).
    format: Union[ConvFormat, None] = ConvFormat()
50
+
51
class User(BaseModel):
    """Schema of the ``user`` YAML section."""

    # Language of the simulated user; "English" unless overridden.
    # (Optional[str] is the same type as the former Optional[Union[str, None]].)
    language: Optional[str] = 'English'
    # Role text for the simulated user.
    role: str
    # Free-text phrases, optionally mixed with a personality mapping, or a
    # single mapping. NOTE: the '' default is not validated against this
    # type; downstream processing treats it as an empty context.
    context: Optional[Union[List[Union[str, Dict]], Dict]] = ''
    # Goal definitions for the conversation.
    goals: list
56
+
57
+
58
class ChatbotClass(BaseModel):
    """Schema of the ``chatbot`` YAML section."""

    # Presumably whether the chatbot opens the conversation -- TODO confirm.
    is_starter: Union[bool, None] = True
    # The chatbot's fallback text.
    fallback: str
    # Output definitions (a list; structure not visible in this module).
    output: list
62
+
63
+
64
class Conversation(BaseModel):
    """Schema of the ``conversation`` YAML section."""

    # An int, or a textual spec such as "combinations(...)".
    number: Union[int, str]
    # Global cost limit; defaults to a very large value (10**9).
    max_cost: Union[float, None] = 10**9
    # Goal-style mapping (e.g. steps / all_answered / random steps).
    goal_style: Dict
    # List of interaction-style entries.
    interaction_style: list

    @field_validator('max_cost', mode='before')
    @classmethod
    def set_token_count_enabled(cls, value):
        """Turn on token counting whenever a non-null max_cost is validated.

        Side effect: sets the global ``config.token_count_enabled`` flag.
        The value is passed through unchanged. (By default pydantic does not
        run validators on the field default.)
        """
        if value is None:
            return value
        config.token_count_enabled = True
        return value
76
+
77
class RoleDataModel(BaseModel):
    """Whole-profile schema combining every YAML section.

    NOTE(review): not instantiated by this module; RoleData validates each
    section on its own. Kept for external users -- confirm whether still needed.
    """

    test_name: str
    llm: Union[LLM, None] = LLM()
    user: User
    chatbot: ChatbotClass
    conversation: Conversation
83
+
84
@dataclass
class ValidationIssue:
    """One problem found while parsing a role (user-profile) YAML."""

    # Name of the offending field.
    field: str
    # Human-readable error message.
    error: str
    # Name of the error / exception type.
    error_type: str
    # Dotted location inside the YAML, e.g. "user.goals".
    location: str
90
+
91
class RoleData:
    """Parsed user-profile ("role") YAML for one simulated user.

    Each top-level section (``test_name``, ``llm``, ``user``, ``chatbot``,
    ``conversation``) is parsed independently. Failures are collected as
    ValidationIssue objects in ``self.errors`` instead of aborting, so a
    single pass reports every problem in the profile. Attributes belonging to
    a section that failed keep their ``None`` placeholders.
    """

    def __init__(self, yaml_file, project_folder=None, personality_file=None, validation=False):
        """Parse *yaml_file* (the already-loaded YAML mapping) into attributes.

        Args:
            yaml_file: Loaded profile; ``.get(section)`` is called on it.
            project_folder: Stored unchanged on the instance.
            personality_file: Optional path to a personality YAML whose
                ``context`` phrases are used when the profile's own context
                names no personality.
            validation: When True, several checks record a ValidationIssue
                instead of raising.
        """
        self.yaml = yaml_file
        self.validation = validation
        self.personality_file = personality_file
        self.project_folder = project_folder
        self.errors: List[ValidationIssue] = []
        # Set up front so the attribute exists even when the conversation
        # section fails to validate.
        self.combinations_dict = {}

        # Test name
        self.test_name = None
        try:
            self.test_name = self.yaml.get('test_name')
        except Exception as e:
            self.collect_errors(e, prefix='test_name')

        # LLM
        self.model = self.model_provider = self.temperature = self.format_type = self.format_config = None
        try:
            self.llm = LLM(**self._section('llm'))
            self.model = config.model = self.llm.model
            self.model_provider = config.model_provider = self.llm.model_prov
            self.temperature = self.llm.temperature
            self.format_type = self.llm.format.type
            self.format_config = self.llm.format.config
        except Exception as e:
            self.collect_errors(e, prefix='llm')

        # The LLM-backed helper modules are only initialised when nothing has
        # failed so far.
        if not self.errors:
            self.init_llm_modules()

        # User
        self.language = self.role = self.raw_context = self.context = self.ask_about = None
        try:
            self.user = User(**self._section('user'))
            self.language = self.set_language(self.user.language)
            self.role = self.user.role
            self.raw_context = self.user.context
            self.context = self.context_processor(self.raw_context)
            self.ask_about = self.get_ask_about()
        except Exception as e:
            self.collect_errors(e, prefix='user')

        # Chatbot
        self.is_starter = self.fallback = self.output = None
        try:
            self.chatbot = ChatbotClass(**self._section('chatbot'))
            self.is_starter = self.chatbot.is_starter
            self.fallback = self.chatbot.fallback
            self.output = self.chatbot.output
        except Exception as e:
            self.collect_errors(e, prefix='chatbot')

        # Conversation
        self.conversation_number = self.max_cost = self.goal_style = self.interaction_styles = None
        try:
            self.conversation = Conversation(**self._section('conversation'))
            self.conversation_number = self.get_conversation_number(self.conversation.number)
            self.max_cost = self.conversation.max_cost
            config.limit_cost = self.max_cost
            self.goal_style = self.pick_goal_style(self.conversation.goal_style)
            self.interaction_styles = self.pick_interaction_style(self.conversation.interaction_style)
        except Exception as e:
            self.collect_errors(e, prefix='conversation')

    def _section(self, name):
        """Return YAML section *name* as keyword arguments.

        A missing key or an explicit ``null`` value both yield ``{}``; the
        latter used to fail with "argument after ** must be a mapping".
        """
        section = self.yaml.get(name)
        return {} if section is None else section

    def init_llm_modules(self):
        """Initialise the vision, data-gathering, data-extraction and any-list modules."""
        init_vision_module()
        init_data_gathering_module()
        init_data_extraction_module()
        init_any_list_module()
        # NOTE: ASR module initialisation (init_asr_module) is currently disabled.

    def collect_errors(self, e: Exception, prefix=""):
        """Record exception *e* as one or more ValidationIssue entries.

        A pydantic ValidationError contributes one issue per reported error,
        located at ``<prefix>.<error loc>``. Any other exception becomes a
        single issue with field 'unknown', located at *prefix*.
        """
        if isinstance(e, ValidationError):
            for err in e.errors():
                loc = err['loc']
                loc_path = '.'.join(str(part) for part in loc)
                full_path = '.'.join(part for part in (prefix, loc_path) if part)
                self.errors.append(
                    ValidationIssue(
                        # Model-level errors can have an empty loc.
                        field=loc[-1] if loc else (prefix or 'unknown'),
                        error=err['msg'],
                        error_type=err['type'],
                        location=full_path
                    )
                )
        else:
            self.errors.append(
                ValidationIssue(
                    field='unknown',
                    error=str(e),
                    error_type=type(e).__name__,
                    location=prefix
                )
            )

    def get_errors(self):
        """Return ``(list_of_error_dicts, error_count)`` and log the count.

        Each dict has keys "field" (the dotted location), "error" and "type".
        """
        error_list = [
            {
                "field": issue.location,
                "error": issue.error,
                "type": issue.error_type
            }
            for issue in self.errors
        ]
        logger.warning(f"\n{len(self.errors)} errors detected.\n")
        return error_list, len(self.errors)

    def get_ask_about(self):
        """Build the AskAboutClass for the user's goals.

        In validation mode a construction failure is recorded as an issue at
        "user.goals" and None is returned; otherwise the exception propagates.
        """
        if self.validation:
            try:
                return AskAboutClass(self.user.goals)
            except Exception as e:
                self.errors.append(ValidationIssue(
                    field="goals",
                    error=str(e),
                    error_type=type(e).__name__,
                    location="user.goals"
                ))
                return None
        return AskAboutClass(self.user.goals)

    def set_language(self, lang):
        """Return *lang* if it is a supported language, else "English".

        None silently selects English. An unsupported value records an
        InvalidLanguageException issue at "user.language" and falls back to
        English.
        """
        if lang is None:
            logger.info("Language parameter empty. Setting language to Default (English)")
            return "English"
        if lang in languages:
            logger.info(f"Language set to {lang}")
            return lang
        exc = InvalidLanguageException(f'Invalid language input: {lang}. Setting language to default (English)')
        self.errors.append(ValidationIssue(
            field="language",
            error=str(exc),
            error_type=type(exc).__name__,
            location="user.language"
        ))
        return "English"

    def reset_attributes(self):
        """Re-initialise per-conversation state before the next conversation."""
        logger.info(f"Preparing attributes for next conversation...")
        self.init_llm_modules()
        self.fallback = self.chatbot.fallback
        self.context = self.context_processor(self.raw_context)
        self.ask_about.reset()
        self.goal_style = self.pick_goal_style(self.conversation.goal_style)
        self.language = self.set_language(self.user.language)
        # Must follow set_language: ChangeLanguage is built from self.language.
        self.interaction_styles = self.pick_interaction_style(self.conversation.interaction_style)

    @staticmethod
    def list_to_dict_reformat(conv):
        """Merge a list of dicts into one dict (later keys win)."""
        return {k: v for d in conv for k, v in d.items()}

    @staticmethod
    def _as_phrase_list(phrases):
        """Normalise a personality ``context`` value (None, str or iterable) to a list."""
        if phrases is None:
            return []
        if isinstance(phrases, str):
            return [phrases]
        return list(phrases)

    def personality_extraction(self, context):
        """Load the personality named by ``context['personality']``.

        Searches ``config.custom_personalities_folder`` (when set and present),
        then ``<config.root_path>/config/personalities``, for a ``.yml`` or
        ``.yaml`` file whose base name matches the personality (any extension
        on the requested name is ignored). On success, stores the file's
        ``name`` in ``self.personality`` and returns its ``context`` value.
        On any failure, logs an error and returns ``['']``.
        """
        personality = context.get("personality") if isinstance(context, dict) else None
        if not personality:
            logger.error(f"Data for context is not a dictionary with context key: {context}.")
            return ['']

        path_list = []
        custom_folder = config.custom_personalities_folder
        if custom_folder and os.path.exists(custom_folder):
            path_list.append(custom_folder)
        path_list.append(os.path.join(config.root_path, "config", "personalities"))

        try:
            clean_personality, _ = os.path.splitext(personality)
            for path in path_list:
                for file in os.listdir(path):
                    file_name, ext = os.path.splitext(file)
                    if file_name != clean_personality or ext not in ('.yml', '.yaml'):
                        continue
                    personality_data = read_yaml(os.path.join(path, file))
                    try:
                        self.personality = personality_data["name"]
                        logger.info(f"Personality set to '{file_name}'")
                        return personality_data['context']
                    except KeyError:
                        raise InvalidFormat(f"Key 'context' not found in personality file.")

            logger.error(f"Couldn't find specified personality file: '{personality}'")
            return ['']
        except Exception as e:
            logger.error(e)
            return ['']

    def _combination_count(self, attr):
        """Return ``self.ask_about.<attr>``, or 0 when ask_about is unavailable."""
        if self.ask_about is None:
            return 0
        return getattr(self.ask_about, attr)

    def _build_combinations_summary(self):
        """Describe the combinations produced by each matrix-based variable generator."""
        summary = []
        generators = [] if self.ask_about is None else self.ask_about.var_generators
        for generator in generators:
            if "matrix" not in generator:
                continue
            name = generator['name']
            if generator['type'] == 'forward':
                matrix = [list(p) for p in itertools.product(*generator['matrix'])]
            elif generator['type'] == 'pairwise':
                matrix = generator['matrix']
            else:
                matrix = []
            summary.append({'name': name,
                            'matrix': matrix,
                            'combinations': len(matrix),
                            'type': generator['type']})
        return summary

    def get_conversation_number(self, conversation):
        """Resolve ``conversation.number`` into a count of conversations.

        Accepts an int, a string of decimal digits, or
        ``combinations[(sample[, forward|pairwise])]``. The combinations form
        uses the forward / pairwise / max combination counts of ``ask_about``,
        optionally scaled by *sample*. Returns 0 (logging an error) when no
        combinations exist. An unrecognised value records an issue at
        "conversation.number". In validation mode ``self.combinations_dict``
        is filled with a per-generator combinations summary.
        """
        if isinstance(conversation, int):
            logger.info(f"{conversation} conversations will be generated")
            return conversation

        text = str(conversation).strip()
        if text.isdecimal():
            # A quoted number in the YAML stays a str under Union[int, str].
            count = int(text)
            logger.info(f"{count} conversations will be generated")
            return count

        comb_pattern = r'^combinations(?:\(([^,()\s]+)(?:,\s*([^()]+))?\))?$'
        match = re.match(comb_pattern, text)

        if self.validation:
            self.combinations_dict = self._build_combinations_summary()

        if not match:
            message = f"Conversation number can't be obtained due tu unrecognized value: {conversation}"
            logger.error(message)
            self.errors.append(ValidationIssue(
                field="number",
                error=message,
                error_type=InvalidFormat.__name__,
                location="conversation.number"
            ))
            return 0

        sample = match.group(1)
        iter_function = (match.group(2) or '').strip() or None

        if iter_function == "forward":
            conv_number = self._combination_count('forward_combinations')
            if conv_number <= 0:
                logger.error("Conversation number set to 'forward_all_combinations' but no combinations can be made.")
                return 0
        elif iter_function == "pairwise":
            conv_number = self._combination_count('pairwise_combinations')
            if conv_number <= 0:
                logger.error("Conversation number set to 'pairwise_all_combinations' but no combinations can be made.")
                return 0
        else:
            conv_number = max(self._combination_count('forward_combinations'),
                              self._combination_count('pairwise_combinations'))
            if conv_number < 1:
                logger.error("Conversation number set to 'combinations' but no combinations can be made.")
                return 0

        if sample:
            conv_number = round(conv_number * float(sample))
        logger.info(f"{conv_number} conversations will be generated.")
        return conv_number

    def context_processor(self, context):
        """Build the user-context string from the profile's ``context`` value.

        - A dict is a personality reference resolved via personality_extraction.
        - A list may hold strings plus at most one personality dict. With a
          dict, personality phrases come first; without one, custom phrases
          come first, followed by ``self.personality_file`` phrases (if given).
          With a dict that yields no phrases, ``self.personality_file`` is
          used instead.
        - None or a plain string is treated as an (optionally single-item) list.

        Phrases are joined with single spaces. A malformed list records an
        issue at "user.context" and returns ''.
        """
        if isinstance(context, dict):
            return list_to_str(self.personality_extraction(context))

        if context is None:
            context = []
        elif isinstance(context, str):
            context = [context] if context else []

        dict_count = sum(1 for item in context if isinstance(item, dict))
        if dict_count > 1:
            self.errors.append(ValidationIssue(
                field="context",
                error="Too many keys in context list.",
                error_type=InvalidFormat.__name__,
                location="user.context"
            ))
            return ""

        custom_phrases = []
        personality_phrases = []
        for item in context:
            if isinstance(item, str):
                custom_phrases.append(item)
            elif isinstance(item, dict):
                personality_phrases.extend(self._as_phrase_list(self.personality_extraction(item)))
            else:
                self.errors.append(ValidationIssue(
                    field="context",
                    error=f"Invalid data type in context list: {type(item)}:{item}",
                    error_type=InvalidDataType.__name__,
                    location="user.context"
                ))
                return ""

        # If no personality is given, use the one passed on the command line.
        if not personality_phrases and self.personality_file is not None:
            personality = read_yaml(self.personality_file)
            personality_phrases = self._as_phrase_list(personality['context'])

        if dict_count:
            return list_to_str(personality_phrases + custom_phrases)
        # Previously these were concatenated with no separator between them.
        return list_to_str(custom_phrases + personality_phrases)

    def _goal_style_error(self, message, exc_type):
        """Record an invalid goal_style (validation mode) or raise *exc_type*."""
        if not self.validation:
            raise exc_type(message)
        self.errors.append(ValidationIssue(
            field="goal_style",
            error=message,
            error_type=exc_type.__name__,
            location="conversation.goal_style"
        ))
        return ""

    def pick_goal_style(self, goal):
        """Interpret ``conversation.goal_style``.

        A positive ``max_cost`` sets ``config.limit_individual_cost`` and
        enables token counting. Without that key, the global
        ``config.limit_cost`` is used.

        Returns:
            ``(None, False)`` when *goal* is None;
            ``('steps', n)`` for 0 < n <= 20;
            ``('random steps', k)`` with k drawn from [1, n], 0 < n <= 20;
            ``[key, export, limit]`` for 'all_answered' / 'default', where
            export defaults to False and limit to 30;
            ``""`` in validation mode when the value is invalid (an issue is
            recorded). Outside validation mode, invalid values raise
            NoCostException, OutOfLimitException or InvalidGoalException.
        """
        if goal is None:
            return goal, False

        if 'max_cost' in goal:
            if goal['max_cost'] > 0:
                config.limit_individual_cost = goal['max_cost']
                config.token_count_enabled = True
            else:
                return self._goal_style_error(
                    f"Goal cost can't be lower than or equal to 0: {goal['max_cost']}", NoCostException)
        else:
            config.limit_individual_cost = config.limit_cost

        # The style name is returned explicitly; the first dict key could be
        # 'max_cost' depending on YAML key order.
        if 'steps' in goal:
            steps = goal['steps']
            if 0 < steps <= 20:
                return 'steps', steps
            return self._goal_style_error(
                f"Goal steps higher than 20 steps or lower than 0 steps: {steps}", OutOfLimitException)

        if 'all_answered' in goal or 'default' in goal:
            if not isinstance(goal, dict):
                return [goal, False, 30]
            key = 'all_answered' if 'all_answered' in goal else 'default'
            options = goal[key] if isinstance(goal[key], dict) else {}
            return [key, options.get('export', False), options.get('limit', 30)]

        if 'random steps' in goal:
            upper = goal['random steps']
            # Same bounds as 'steps'; randint(1, upper) needs upper >= 1.
            if 0 < upper <= 20:
                return 'random steps', random.randint(1, upper)
            return self._goal_style_error(
                f"Goal steps higher than 20 steps or lower than 0 steps: {upper}", OutOfLimitException)

        return self._goal_style_error(f"Invalid goal value: {goal}", InvalidGoalException)

    def get_interaction_metadata(self):
        """Return ``get_metadata()`` of every selected interaction style."""
        return [style.get_metadata() for style in self.interaction_styles]

    def pick_interaction_style(self, interactions):
        """Turn ``conversation.interaction_style`` into style instances.

        None or an empty list selects the Default style. A list whose first
        element is ``{'random': [...]}`` picks a random non-empty subset of
        those styles. Otherwise each entry is a style name or a
        ``{'change language': [...]}`` mapping. The first unknown entry
        records an issue, is logged and stops processing; the styles resolved
        so far are returned.
        """
        inter_styles = {
            'long phrases': LongPhrases(),
            'change your mind': ChangeYourMind(),
            'change language': ChangeLanguage(self.language),
            'make spelling mistakes': MakeSpellingMistakes(),
            'single question': SingleQuestions(),
            'all questions': AllQuestions(),
            'default': Default()
        }

        def choice_styles(candidates):
            count = random.randint(1, len(candidates))
            chosen = random.sample(candidates, count)
            logger.info(f'interaction style count: {count}; style(s): {chosen}')
            return chosen

        def get_styles(entries):
            resolved = []
            try:
                for entry in entries:
                    if isinstance(entry, dict):
                        name = next(iter(entry), None)
                        if name != "change language":
                            raise InvalidInteractionException(f"Invalid interaction: {entry}")
                        style = inter_styles[name]
                        style.languages_options = entry.get(name).copy()
                        style.change_language_flag = True
                        resolved.append(style)
                    elif isinstance(entry, str) and entry in inter_styles:
                        resolved.append(inter_styles[entry])
                    else:
                        raise InvalidInteractionException(f"Invalid interaction: {entry}")
            except InvalidInteractionException as e:
                self.errors.append(ValidationIssue(
                    field="interaction_style",
                    error=str(e),
                    error_type=type(e).__name__,
                    location="conversation.interaction_style"
                ))
                logger.error(f"Error: {e}")
            return resolved

        if not interactions:
            return [inter_styles['default']]

        first = interactions[0]
        if isinstance(first, dict) and 'random' in first:
            # TODO: only admit 'random' when it is the sole entry in the list.
            return get_styles(choice_styles(first['random']))
        return get_styles(interactions)

    def get_language(self):
        """Return the language instruction for the simulated user.

        The prompt of the first style with ``change_language_flag`` set wins;
        otherwise "Please, talk in <language>".
        """
        for style in self.interaction_styles:
            if style.change_language_flag:
                return style.get_prompt()
        return f"Please, talk in {self.language}"