user-simulator 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- user_sim/__init__.py +0 -0
- user_sim/cli/__init__.py +0 -0
- user_sim/cli/gen_user_profile.py +34 -0
- user_sim/cli/init_project.py +65 -0
- user_sim/cli/sensei_chat.py +481 -0
- user_sim/cli/sensei_check.py +103 -0
- user_sim/cli/validation_check.py +143 -0
- user_sim/core/__init__.py +0 -0
- user_sim/core/ask_about.py +665 -0
- user_sim/core/data_extraction.py +260 -0
- user_sim/core/data_gathering.py +134 -0
- user_sim/core/interaction_styles.py +147 -0
- user_sim/core/role_structure.py +608 -0
- user_sim/core/user_simulator.py +302 -0
- user_sim/handlers/__init__.py +0 -0
- user_sim/handlers/asr_module.py +128 -0
- user_sim/handlers/html_parser_module.py +202 -0
- user_sim/handlers/image_recognition_module.py +139 -0
- user_sim/handlers/pdf_parser_module.py +123 -0
- user_sim/utils/__init__.py +0 -0
- user_sim/utils/config.py +47 -0
- user_sim/utils/cost_tracker.py +153 -0
- user_sim/utils/cost_tracker_v2.py +193 -0
- user_sim/utils/errors.py +15 -0
- user_sim/utils/exceptions.py +47 -0
- user_sim/utils/languages.py +78 -0
- user_sim/utils/register_management.py +62 -0
- user_sim/utils/show_logs.py +63 -0
- user_sim/utils/token_cost_calculator.py +338 -0
- user_sim/utils/url_management.py +60 -0
- user_sim/utils/utilities.py +568 -0
- user_simulator-0.1.0.dist-info/METADATA +733 -0
- user_simulator-0.1.0.dist-info/RECORD +37 -0
- user_simulator-0.1.0.dist-info/WHEEL +5 -0
- user_simulator-0.1.0.dist-info/entry_points.txt +6 -0
- user_simulator-0.1.0.dist-info/licenses/LICENSE.txt +21 -0
- user_simulator-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,665 @@
|
|
1
|
+
from user_sim.utils.utilities import *
|
2
|
+
from user_sim.utils.exceptions import *
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import logging
|
6
|
+
import random
|
7
|
+
|
8
|
+
from user_sim.utils import config
|
9
|
+
from allpairspy import AllPairs
|
10
|
+
from langchain_core.prompts import ChatPromptTemplate
|
11
|
+
from user_sim.utils.utilities import init_model
|
12
|
+
from user_sim.utils.token_cost_calculator import calculate_cost, max_input_tokens_allowed, max_output_tokens_allowed
|
13
|
+
|
14
|
+
# Module-level LLM state; both names are rebound by init_any_list_module().
model, llm = "", None

# Shared logger used by every helper in this module.
logger = logging.getLogger('Info Logger')
|
18
|
+
|
19
|
+
|
20
|
+
def init_any_list_module():
    """Bind the module-level ``model`` and ``llm`` to whatever ``init_model()`` returns."""
    global model, llm
    initialised = init_model()
    model, llm = initialised
|
24
|
+
|
25
|
+
|
26
|
+
class VarGenerators:
    """Create one value generator per variable of an ``ask_about`` section.

    ``variable_list`` is a list of dicts with the keys ``name``, ``data``,
    ``function`` and ``dependence``.  The ``function`` string selects how
    values are drawn from ``data``:

    * ``default()`` (or an empty function) -- always the whole data list;
    * ``random()``, ``random(<n>)``, ``random(rand)`` -- random picks;
    * ``another()``, ``another(<n>)`` -- shuffled data, walked without repeats;
    * ``forward(...)`` -- handled by :class:`ForwardMatrixGenerator`;
    * ``pairwise()`` -- handled by :class:`PairwiseMatrixGenerator`.

    After construction ``generator_list`` holds dicts with at least the keys
    ``name`` and ``generator``; ``forward_combinations`` and
    ``pairwise_combinations`` hold the counts reported by the two helpers.
    """

    def __init__(self, variable_list):
        self.forward_combinations = 0
        self.pairwise_combinations = 0
        self.variable_list = variable_list
        self.generator_list = self.create_generator_list()

    class ForwardMatrixGenerator:
        """Collect ``forward()`` variables and expose them as cycling generators."""

        def __init__(self):
            # Every variable dict received through add_forward().
            self.forward_function_list = []
            # (slave, master) pairs, e.g.
            # [(size, toppings), (toppings, drink), (drink, None)]
            self.dependence_tuple_list = []
            # Result of build_sequence(dependence_tuple_list).
            self.dependent_list = []
            # Names of forward variables that take part in no dependence.
            self.independent_list = []
            # One list of data lists per entry of dependent_list.
            self.item_matrix = []

        def get_matrix(self, dependent_variable_list):
            """Rebuild ``item_matrix`` with the data lists of every listed variable.

            For each entry of ``dependent_variable_list`` (an iterable of
            variable names) the data of the registered forward variables
            with those names is collected, in that order.
            """
            self.item_matrix.clear()
            for dependence in dependent_variable_list:
                self.item_matrix.append([
                    forward['data']
                    for variable in dependence
                    for forward in self.forward_function_list
                    if variable == forward['name']
                ])

        def add_forward(self, forward_variable):
            """Register a forward variable and refresh the dependence bookkeeping.

            ``forward_variable`` is a dict with the keys ``name``, ``data``,
            ``function`` and ``dependence``.
            """
            self.forward_function_list.append(forward_variable)
            name = forward_variable['name']
            master = forward_variable['dependence']

            if master:
                self.dependence_tuple_list.append((name, master))
                # The master may have been registered earlier as independent;
                # move it into the chain.  A membership test replaces the old
                # "remove while iterating the same list" loop.
                if master in self.independent_list:
                    self.independent_list.remove(master)
                    self.dependence_tuple_list.append((master, None))
            elif any(name in dependence for dependence in self.dependence_tuple_list):
                # Somebody already depends on this variable: it closes a chain.
                self.dependence_tuple_list.append((name, None))
            else:
                self.independent_list.append(name)

            if self.dependence_tuple_list:
                # build_sequence comes from user_sim.utils.utilities; presumably
                # it orders the pairs into chains of names -- confirm there.
                self.dependent_list = build_sequence(self.dependence_tuple_list)
                self.get_matrix(self.dependent_list)

        @staticmethod
        def combination_generator(matrix):
            """Cycle endlessly through every combination of ``matrix``.

            ``matrix`` is a list of value lists.  Combinations are produced
            odometer-style: the last list advances fastest and the sequence
            restarts after the final combination.  An empty ``matrix`` yields
            ``[]`` forever.
            """
            if not matrix:
                while True:
                    yield []
            else:
                lengths = [len(lst) for lst in matrix]
                indices = [0] * len(matrix)
                while True:
                    yield [matrix[i][indices[i]] for i in range(len(matrix))]
                    # Advance from the last position, carrying to the left.
                    i = len(matrix) - 1
                    while i >= 0:
                        indices[i] += 1
                        if indices[i] < lengths[i]:
                            break
                        indices[i] = 0
                        i -= 1

        def get_combinations(self):
            """Return the largest combination count among the chains (0 if none)."""
            if not self.item_matrix:
                return 0
            totals = []
            for matrix in self.item_matrix:
                total = 1
                for sublist in matrix:
                    total *= len(sublist)
                totals.append(total)
            return max(totals)

        @staticmethod
        def forward_generator(value_list):
            """Cycle through ``value_list`` in order, yielding ``[value]`` each time.

            Raises:
                EmptyListExcept: if ``value_list`` is empty (the endless loop
                    would otherwise spin without ever yielding).
            """
            if not value_list:
                raise EmptyListExcept('Data list is empty.')
            while True:
                for sample in value_list:
                    yield [sample]

        def get_generator_list(self):
            """Return the generator dicts for independent and dependent variables."""
            function_map = {function['name']: function['data']
                            for function in self.forward_function_list}

            independent_generators = [
                {'name': name,
                 'generator': self.forward_generator(function_map[name]),
                 'type': "forward"}
                for name in self.independent_list if name in function_map
            ]

            dependent_generators = [
                {'name': val,
                 'generator': self.combination_generator(self.item_matrix[index]),
                 'type': 'forward',
                 'matrix': self.item_matrix[index]}
                for index, val in enumerate(self.dependent_list)
            ]
            return independent_generators + dependent_generators

    class PairwiseMatrixGenerator:
        """Collect ``pairwise()`` variables and pre-compute their combinations."""

        def __init__(self):
            self.pairwise_function_list = []   # variable dicts, largest data first
            self.pairwise_variable_list = []   # their names, same order
            self.parameters_matrix = []        # their data lists, same order
            self.item_matrix = []              # pre-computed combinations
            self.combinations = 0              # len(item_matrix)

        def add_pairwise(self, pairwise_variable):
            """Register a pairwise variable and recompute the combinations.

            ``pairwise_variable`` is a dict with the keys ``name``, ``data``,
            ``function`` and ``dependence``.
            """
            self.pairwise_function_list.append(pairwise_variable)
            # Largest value lists first; the sort is stable, as before.
            self.pairwise_function_list.sort(key=lambda d: len(d['data']), reverse=True)
            self.pairwise_variable_list = [function['name']
                                           for function in self.pairwise_function_list]
            self.parameters_matrix = [function['data']
                                      for function in self.pairwise_function_list]

            if len(self.parameters_matrix) > 1:
                self.item_matrix = list(AllPairs(self.parameters_matrix))
            else:
                # AllPairs is only called with two or more parameters; a lone
                # variable simply enumerates its own values.
                self.item_matrix = [[value] for value in self.parameters_matrix[0]]
            self.combinations = len(self.item_matrix)

        def pairwise_generator(self, pairwise_matrix):
            """Yield each pre-computed combination of ``pairwise_matrix`` once.

            Unlike the forward generators this one is finite: it stops after
            the last row.
            """
            for values in pairwise_matrix:
                yield values

        def get_generator_list(self):
            """Return a list with the single pairwise generator dict, or ``[]``."""
            if not self.pairwise_function_list:
                return []

            names = self.pairwise_variable_list
            return [{
                # With one variable the plain name is exposed so the
                # generator is consumed like any single-variable generator.
                'name': names if len(names) > 1 else names[0],
                'generator': self.pairwise_generator(self.item_matrix),
                'type': 'pairwise',
                'matrix': self.item_matrix,
            }]

        def get_combinations(self):
            """Return the number of pre-computed combinations."""
            return self.combinations

    def create_generator_list(self):
        """Build the generator dicts for ``self.variable_list``.

        Also stores the forward and pairwise combination counts on the
        instance.

        Raises:
            InvalidFormat: the function string is malformed, or ``random`` /
                ``another`` received an argument they do not support.
            InvalidGenerator: the function name is unknown.
        """
        generator_list = []
        my_forward = self.ForwardMatrixGenerator()
        my_pairwise = self.PairwiseMatrixGenerator()
        pattern = re.compile(r'(\w+)\((\w*)\)')

        for variable in self.variable_list:
            name = variable['name']
            data = variable['data']
            function = variable['function']

            if not function or function == 'default()':
                generator_list.append({'name': name,
                                       'generator': self.default_generator(data)})
                continue

            match = pattern.search(function)
            if not match:
                raise InvalidFormat(f"an invalid function format was used: {function}")

            handler_name = match.group(1)
            count = match.group(2) or ''

            if handler_name == 'random':
                if count == '':
                    generator = self.random_choice_generator(data)
                elif count.isdigit():
                    generator = self.random_choice_count_generator(data, int(count))
                elif count == 'rand':
                    generator = self.random_choice_random_count_generator(data)
                else:
                    # Previously the variable was dropped without a word.
                    raise InvalidFormat(f"an invalid function format was used: {function}")
                generator_list.append({'name': name, 'generator': generator})

            elif handler_name == 'forward':
                my_forward.add_forward(variable)

            elif handler_name == 'pairwise':
                my_pairwise.add_pairwise(variable)

            elif handler_name == 'another':
                if count == '':
                    generator = self.another_generator(data)
                elif count.isdigit():
                    generator = self.another_count_generator(data, int(count))
                else:
                    # Previously the variable was dropped without a word.
                    raise InvalidFormat(f"an invalid function format was used: {function}")
                generator_list.append({'name': name, 'generator': generator})

            else:
                raise InvalidGenerator(f'Invalid generator function: {handler_name}')

        generators = (generator_list
                      + my_forward.get_generator_list()
                      + my_pairwise.get_generator_list())
        self.forward_combinations = my_forward.get_combinations()
        self.pairwise_combinations = my_pairwise.get_combinations()
        return generators

    @staticmethod
    def default_generator(data):
        """Yield ``[data]`` (the whole data object wrapped in a list) forever."""
        while True:
            yield [data]

    @staticmethod
    def random_choice_generator(data):
        """Yield a one-element list holding a random item of ``data``, forever."""
        while True:
            yield [random.choice(data)]

    @staticmethod
    def random_choice_count_generator(data, count):
        """Yield random samples (no repeats) of ``count`` items, capped at ``len(data)``."""
        while True:
            yield random.sample(data, min(count, len(data)))

    @staticmethod
    def random_choice_random_count_generator(data):
        """Yield random samples (no repeats) whose size is itself random in 1..len(data)."""
        while True:
            count = random.randint(1, len(data))
            yield random.sample(data, count)

    @staticmethod
    def another_count_generator(data, count):
        """Yield consecutive chunks of ``count`` items from a shuffled copy of ``data``.

        The copy is reshuffled every time it has been fully consumed; the
        last chunk of a pass may be shorter.

        Raises:
            InvalidFormat: ``count`` is smaller than 1.
            EmptyListExcept: ``data`` is empty.
        """
        if count < 1:
            raise InvalidFormat(f"an invalid function format was used: another({count})")
        if not data:
            raise EmptyListExcept('Data list is empty.')
        while True:
            copy_list = data[:]
            random.shuffle(copy_list)
            for i in range(0, len(copy_list), count):
                yield copy_list[i:i + count]

    @staticmethod
    def another_generator(data):
        """Yield ``[item]`` for every item of a shuffled copy of ``data``, reshuffling per pass.

        Raises:
            EmptyListExcept: ``data`` is empty.
        """
        if not data:
            raise EmptyListExcept('Data list is empty.')
        while True:
            copy_list = data[:]
            random.shuffle(copy_list)
            for sample in copy_list:
                yield [sample]
|
288
|
+
|
289
|
+
|
290
|
+
def reorder_variables(entries):
    """Return a new list with chained ``forward(...)`` variables moved to the front.

    Every entry is a dict with at least the keys ``name`` and ``function``.
    An entry whose function is ``forward(<master>)`` is linked to another
    forward entry named ``<master>``; linked entries are emitted first, in
    the order the links are discovered, followed by all remaining entries in
    their original order.  ``entries`` itself is not modified.
    """
    forward_call = re.compile(r'forward\((.*?)\)')

    # (slave, master) for every forward entry; master is '' for "forward()".
    links = []
    for entry in entries:
        # "or ''" tolerates a missing (None) function instead of crashing.
        match = forward_call.search(entry['function'] or '')
        if match:
            links.append((entry['name'], match.group(1)))

    # Keep a link together with the link of the variable it points to.
    chained = []
    for link in links:
        for other in links:
            if link[1] == other[0]:
                chained.append(link)
                chained.append(other)
    chained = list(dict.fromkeys(chained))  # drop repeats, keep first-seen order

    # Positions already emitted; tracking them (rather than removing from a
    # copy) cannot fail when one name is matched more than once.
    moved = set()
    ordered = []
    for slave, _master in chained:
        for position, entry in enumerate(entries):
            if entry['name'] == slave and position not in moved:
                moved.add(position)
                ordered.append(entry)

    ordered.extend(entry for position, entry in enumerate(entries)
                   if position not in moved)
    return ordered
|
331
|
+
|
332
|
+
|
333
|
+
def dependency_error_check(variable_list):
    """Ensure that every variable used as a dependence is itself a ``forward`` variable.

    Args:
        variable_list: dicts with the keys ``name``, ``function`` and
            ``dependence``.

    Raises:
        InvalidFormat: the function of a depended-upon variable does not
            look like ``name(argument)``.
        InvalidDependence: a depended-upon variable uses a function other
            than ``forward``.
    """
    pattern = re.compile(r'(\w+)\((\w*)\)')

    for slave in variable_list:
        for master in variable_list:
            if slave['dependence'] != master['name']:
                continue

            match = pattern.search(master['function'] or '')
            if match is None:
                # Used to surface as AttributeError on ``match.group``.
                raise InvalidFormat(
                    f"an invalid function format was used: {master['function']}")

            function = match.group(1)
            if function != 'forward':
                raise InvalidDependence(
                    f"the following function doesn't admit dependence: {function}()")
|
342
|
+
|
343
|
+
|
344
|
+
def check_circular_dependency(items):
    """Raise if the ``dependence`` links between ``items`` form a cycle.

    Args:
        items: dicts with the keys ``name`` and ``dependence``; a
            ``dependence`` of ``None`` (or one naming an unknown variable)
            ends the chain.

    Raises:
        Exception: with the message ``Circular dependency detected: a -> b -> a``
            listing the walked path.
    """
    dependencies = {item['name']: item['dependence'] for item in items}

    # Every variable has at most one dependence, so each chain is a simple
    # walk; doing it iteratively avoids hitting the recursion limit.
    cleared = set()  # names already proven to be cycle-free
    for start in dependencies:
        path = []
        on_path = set()  # same content as ``path``, for O(1) membership tests
        node = start
        while node is not None and node in dependencies and node not in cleared:
            if node in on_path:
                # str() keeps the message buildable for non-string names.
                cycle = ' -> '.join(str(step) for step in path + [node])
                raise Exception(f"Circular dependency detected: {cycle}")
            path.append(node)
            on_path.add(node)
            node = dependencies[node]
        cleared.update(path)
|
368
|
+
|
369
|
+
|
370
|
+
class AskAboutClass:
    """Turn an ``ask_about`` list into phrases with their variables filled in.

    ``data`` mixes two kinds of items: plain strings, which are phrase
    templates with ``{{variable}}`` placeholders, and one-key dicts
    ``{name: {'data': ..., 'type': ..., 'function': ...}}`` declaring the
    variables.

    Attributes set by ``__init__``:
        variable_list: normalised variable dicts (see ``get_variables``).
        str_list: the phrase templates, untouched.
        var_generators / forward_combinations / pairwise_combinations:
            taken from ``VarGenerators``.
        phrases: working copy of ``str_list`` that gets substituted.
        picked_elements: ``{name: value}`` dicts recording the picks.
    """

    def __init__(self, data):
        self.variable_list = self.get_variables(data)
        self.str_list = self.get_phrases(data)
        (self.var_generators,
         self.forward_combinations,
         self.pairwise_combinations) = self.variable_generator(self.variable_list)
        self.phrases = self.str_list.copy()
        self.picked_elements = []

    @staticmethod
    def validate_type_format(data_list, data_type):
        """Return True when every item fully matches the regex of ``data_type``.

        The regex is read from ``config.types_dict[data_type]["format"]`` and
        passed through ``normalize_regex_pattern`` before compiling.
        """
        type_format = config.types_dict[data_type]["format"]
        regex = re.compile(normalize_regex_pattern(type_format))
        return all(regex.fullmatch(item) for item in data_list)

    @staticmethod
    def _load_data_list(content_data):
        """Resolve a variable's ``data`` field into its raw value container."""
        if isinstance(content_data, dict) and 'file' in content_data:
            # Personalized function living in a user file.
            path = content_data['file']
            function = content_data['function_name']
            if 'args' in content_data:
                return execute_list_function(path, function, content_data['args'])
            return execute_list_function(path, function)

        if isinstance(content_data, dict) and 'date' in content_data:
            return get_date_list(content_data['date'])

        if content_data:
            return content_data
        raise EmptyListExcept('Data list is empty.')

    @staticmethod
    def _int_values(data_list):
        """Validate an explicit int list, or expand a ``min``/``max``[/``step``] range."""
        if isinstance(data_list, list):
            for i in data_list:
                if type(i) is not int:
                    raise InvalidDataType(f'The following item is not an integer: {i}')
            if not data_list:
                raise EmptyListExcept('Data list is empty.')
            return data_list

        if isinstance(data_list, dict) and 'min' in data_list:
            has_step = 'step' in data_list
            bounds = ['min', 'max', 'step'] if has_step else ['min', 'max']
            if not all(isinstance(data_list[key], int) for key in bounds):
                raise InvalidDataType('Some of the range function parameters are not integers.')
            values = np.arange(*[data_list[key] for key in bounds]).tolist()
            if has_step:
                # Only the stepped form adds the upper bound (arange excludes it).
                values.append(data_list['max'])
            return values

        raise InvalidFormat('Data follows an invalid format.')

    @staticmethod
    def _float_values(data_list):
        """Validate an explicit number list, or expand a ``step``/``linspace`` range."""
        if isinstance(data_list, list):
            for i in data_list:
                if not isinstance(i, (int, float)):
                    raise InvalidDataType(f'The following item is not a number: {i}')
            if not data_list:
                raise EmptyListExcept('Data list is empty.')
            return data_list

        if isinstance(data_list, dict) and 'min' in data_list:
            if 'step' in data_list:
                values = np.arange(data_list['min'], data_list['max'],
                                   data_list['step']).tolist()
                values.append(data_list['max'])
                return values
            if 'linspace' in data_list:
                return np.linspace(data_list['min'], data_list['max'],
                                   data_list['linspace']).tolist()
            raise MissingStepDefinition(
                '"step" or "linspace" parameter missing. A step separation must be defined.')

        raise InvalidFormat('Data follows an invalid format.')

    @staticmethod
    def _parse_dependence(function):
        """Return the variable name inside ``name(argument)``, or ``None``.

        An empty argument, ``rand`` and plain digits are counts, not
        variable names, so they yield ``None``.
        """
        match = re.search(r'(\w+)\((\w*)\)', function)
        if not match:
            return None
        argument = match.group(2) or ''
        if argument and argument != 'rand' and not argument.isdigit():
            return argument
        return None

    def get_variables(self, data):
        """Build the normalised variable dicts from the dict items of ``data``.

        Returns:
            A list of ``{'name', 'data', 'function', 'dependence'}`` dicts,
            reordered by ``reorder_variables`` and checked for invalid or
            circular dependences.

        Raises:
            EmptyListExcept, InvalidDataType, InvalidFormat,
            MissingStepDefinition, InvalidItemType: on malformed definitions.
        """
        variables = []

        for item in data:
            if not isinstance(item, dict):
                continue

            var_name = list(item.keys())[0]
            content = item[var_name]
            content_data = content['data'].copy()
            data_list = self._load_data_list(content_data)

            # Split "any(...)" requests from literal items.  The resolved
            # list is inspected (not the raw ``data`` field) so values coming
            # from a file function or a date generator are kept too.
            any_list = []
            item_list = []
            if isinstance(data_list, list):
                for value in data_list:
                    if isinstance(value, str) and 'any(' in value:
                        any_list.append(value)
                    else:
                        item_list.append(value)

            data_type = content['type']
            if data_type == 'string':
                for i in item_list:
                    if type(i) is not str:
                        raise InvalidDataType(f'The following item is not a string: {i}')
                output_data_list = self.get_any_items(item_list, any_list, "string")
                if not output_data_list:
                    raise EmptyListExcept('Data list is empty.')
            elif data_type == 'int':
                output_data_list = self._int_values(data_list)
            elif data_type == 'float':
                output_data_list = self._float_values(data_list)
            elif data_type in list(config.types_dict.keys()):
                output_data_list = self.get_any_items(item_list, any_list, data_type)
                if not self.validate_type_format(output_data_list, data_type):
                    raise InvalidItemType('Invalid data type for variable list.')
            else:
                raise InvalidItemType('Invalid data type for variable list.')

            # A missing or empty function means "default()"; written back
            # because the input dict was updated this way before as well.
            function = content.get('function') or 'default()'
            content['function'] = function
            dependence = self._parse_dependence(function)

            logger.info(f"{var_name}: {output_data_list}")

            variables.append({'name': var_name,
                              'data': output_data_list,
                              'function': function,
                              'dependence': dependence})

        reordered_variables = reorder_variables(variables)
        dependency_error_check(reordered_variables)
        check_circular_dependency(reordered_variables)
        return reordered_variables

    @staticmethod
    def get_phrases(data):
        """Return the string items of ``data`` (the phrase templates), in order."""
        return [item for item in data if isinstance(item, str)]

    @staticmethod
    def variable_generator(variables):
        """Return ``(generator_list, forward_combinations, pairwise_combinations)``."""
        generators = VarGenerators(variables)
        return (generators.generator_list,
                generators.forward_combinations,
                generators.pairwise_combinations)

    @staticmethod
    def get_any_items(item_list, any_list, data_type):
        """Extend ``item_list`` with values the LLM generates for each ``any(...)`` request.

        Args:
            item_list: literal values; copied, never modified.
            any_list: strings containing ``any(<description>)``.
            data_type: ``"string"``, ``"int"``, ``"float"`` or a key of
                ``config.types_dict`` (its description and format are then
                added to the request).

        Returns:
            A list.  When the module is not initialised or the token limit
            is exceeded, the values gathered so far are returned.
        """
        output_list = item_list.copy()
        if not any_list:
            return output_list

        response_format = {
            "title": "List_of_values",
            "description": "A list of string values.",
            "type": "object",
            "properties": {
                "answer": {
                    "type": "array",
                    "items": {"type": "string"}
                }
            },
            "required": ["answer"],
            "additionalProperties": False
        }

        for data in any_list:
            content = re.findall(r'any\((.*?)\)', data)

            if data_type not in ("string", "float", "int"):  # custom type: add its prompts
                type_yaml = config.types_dict.get(data_type)
                type_description = f"The type of data is described as follows: {type_yaml['type_description']}"
                type_format = f"Data follows the following format as a regular expression: {type_yaml['format']}"
                content = f"{content}.{type_description}. {type_format}"

            if llm is None:
                logger.error("data gathering module not initialized.")
                # Keep the return type a list (this used to return "").
                return output_list

            system = "You are a helpful assistant that creates a list of whatever the user asks."
            message = f"A list of any of these: {content}. Avoid putting any of these: {output_list}"
            prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])
            input_message = system + message

            if max_input_tokens_allowed(input_message, model_used=config.model):
                logger.error("Token limit was surpassed")
                return output_list

            if config.token_count_enabled:
                llm.max_tokens = max_output_tokens_allowed(model)

            structured_llm = llm.with_structured_output(response_format)
            response = (prompt | structured_llm).invoke({"input": message})

            try:
                output_data = list(response["answer"])
                ls_to_str = ", ".join(output_data)
                calculate_cost(input_message, ls_to_str, model=model, module="goals_any_list")
            except Exception:
                # ``response`` is indexed like a mapping above, so log it
                # as-is; the old handler read ``response.choices`` and failed
                # itself.  Nothing is added: a ``None`` placeholder would
                # become a selectable value.
                logger.error(f"Truncated data in message: {response}")
                output_data = []

            output_list += output_data

        return output_list

    def picked_element_already_in_list(self, match, value):
        """Record ``{variable: value}`` in ``picked_elements`` unless the variable is there already.

        ``match`` is a regex match whose group 1 is the variable name.
        """
        element_list = [list(element.keys())[0] for element in self.picked_elements]
        if match.group(1) not in element_list:
            self.picked_elements.append({match.group(1): value})

    def replace_variables(self, generator):
        """Draw the next value(s) from ``generator`` and substitute them into ``self.phrases``.

        ``generator`` is a dict with the keys ``name`` and ``generator``.
        When ``name`` is a list, the drawn values are zipped with the names
        and every ``{{name}}`` placeholder is replaced by its own value.
        Otherwise the drawn value is a list that is joined with ``", "`` and
        replaces every ``{{name}}`` placeholder.
        """
        name = generator['name']

        if isinstance(name, list):  # several variables drawn together
            values = next(generator['generator'])
            mapped_combinations = dict(zip(name, values))
            self.picked_elements.extend(
                {key: value} for key, value in mapped_combinations.items())

            def replace_variable(match):
                # Unknown placeholders are left exactly as written.
                return str(mapped_combinations.get(match.group(1), match.group(0)))

            self.phrases = [re.sub(r'\{\{(\w+)\}\}', replace_variable, phrase)
                            for phrase in self.phrases]
            return

        value = next(generator['generator'])
        pattern = re.compile(r'\{\{(.*?)\}\}')
        for index, text in enumerate(self.phrases):
            for match in pattern.finditer(text):
                if match.group(1) == name:
                    self.picked_element_already_in_list(match, value)
                    replacement = ', '.join([str(v) for v in value])
                    # str.replace swaps every occurrence, so one hit is enough.
                    self.phrases[index] = text.replace(match.group(0), replacement)
                    break

    def ask_about_processor(self):
        """Apply every variable generator once and return the resulting phrases."""
        for generator in self.var_generators:
            self.replace_variables(generator)
        return self.phrases

    def prompt(self):
        """Return the processed phrases rendered by ``list_to_phrase(phrases, True)``."""
        phrases = self.ask_about_processor()
        return list_to_phrase(phrases, True)

    def reset(self):
        """Forget the picks and restore the phrase templates."""
        self.picked_elements = []
        self.phrases = self.str_list.copy()
|