h-adminsim 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. h_adminsim/__init__.py +5 -0
  2. h_adminsim/admin_staff.py +280 -0
  3. h_adminsim/assets/configs/data4primary.yaml +47 -0
  4. h_adminsim/assets/configs/data4secondary.yaml +47 -0
  5. h_adminsim/assets/configs/data4tertiary.yaml +47 -0
  6. h_adminsim/assets/country/address.json +141859 -0
  7. h_adminsim/assets/country/country_code.json +244 -0
  8. h_adminsim/assets/departments/department.json +85 -0
  9. h_adminsim/assets/departments/symptom.json +4530 -0
  10. h_adminsim/assets/fhir.schema.json +75253 -0
  11. h_adminsim/assets/names/firstname.txt +1219 -0
  12. h_adminsim/assets/names/lastname.txt +88799 -0
  13. h_adminsim/assets/prompts/cancel_patient_system.txt +38 -0
  14. h_adminsim/assets/prompts/intake_staff_task_user.txt +16 -0
  15. h_adminsim/assets/prompts/intake_supervisor_system.txt +8 -0
  16. h_adminsim/assets/prompts/intake_supervisor_user.txt +31 -0
  17. h_adminsim/assets/prompts/reschedule_patient_system.txt +38 -0
  18. h_adminsim/assets/prompts/schedule_patient_rejected_system.txt +42 -0
  19. h_adminsim/assets/prompts/schedule_patient_system.txt +36 -0
  20. h_adminsim/assets/prompts/schedule_staff_reasoning.txt +57 -0
  21. h_adminsim/assets/prompts/schedule_staff_sc_tool_calling.txt +13 -0
  22. h_adminsim/assets/prompts/schedule_staff_system.txt +10 -0
  23. h_adminsim/assets/prompts/schedule_staff_tool_calling.txt +41 -0
  24. h_adminsim/client/__init__.py +3 -0
  25. h_adminsim/client/google_client.py +209 -0
  26. h_adminsim/client/openai_client.py +199 -0
  27. h_adminsim/client/vllm_client.py +160 -0
  28. h_adminsim/environment/__init__.py +1 -0
  29. h_adminsim/environment/hospital.py +462 -0
  30. h_adminsim/environment/op_scheduling_simulation.py +1126 -0
  31. h_adminsim/pipeline/__init__.py +3 -0
  32. h_adminsim/pipeline/data_generator.py +192 -0
  33. h_adminsim/pipeline/evaluator.py +33 -0
  34. h_adminsim/pipeline/simulation.py +231 -0
  35. h_adminsim/registry/__init__.py +5 -0
  36. h_adminsim/registry/errors.py +89 -0
  37. h_adminsim/registry/models.py +126 -0
  38. h_adminsim/registry/phrases.py +10 -0
  39. h_adminsim/registry/pydantic_models.py +21 -0
  40. h_adminsim/registry/variables.py +9 -0
  41. h_adminsim/supervisor.py +182 -0
  42. h_adminsim/task/agent_task.py +900 -0
  43. h_adminsim/task/fhir_manager.py +222 -0
  44. h_adminsim/task/schedule_assign.py +151 -0
  45. h_adminsim/tools/__init__.py +5 -0
  46. h_adminsim/tools/agent_data_builder.py +124 -0
  47. h_adminsim/tools/data_converter.py +536 -0
  48. h_adminsim/tools/data_synthesizer.py +365 -0
  49. h_adminsim/tools/evaluator.py +258 -0
  50. h_adminsim/tools/sanity_checker.py +216 -0
  51. h_adminsim/tools/scheduling_rule.py +420 -0
  52. h_adminsim/utils/__init__.py +136 -0
  53. h_adminsim/utils/common_utils.py +698 -0
  54. h_adminsim/utils/fhir_utils.py +190 -0
  55. h_adminsim/utils/filesys_utils.py +135 -0
  56. h_adminsim/utils/image_preprocess_utils.py +188 -0
  57. h_adminsim/utils/random_utils.py +358 -0
  58. h_adminsim/version.txt +1 -0
  59. h_adminsim-1.0.0.dist-info/LICENSE +30 -0
  60. h_adminsim-1.0.0.dist-info/METADATA +494 -0
  61. h_adminsim-1.0.0.dist-info/RECORD +62 -0
  62. h_adminsim-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,365 @@
1
+ import os
2
+ import random
3
+ from tqdm import tqdm
4
+ from importlib import resources
5
+ from typing import Optional, Tuple
6
+ from decimal import Decimal, getcontext
7
+
8
+ from h_adminsim.task.schedule_assign import ScheduleAssigner
9
+ from h_adminsim.utils import Information, log, colorstr
10
+ from h_adminsim.utils.common_utils import *
11
+ from h_adminsim.utils.filesys_utils import json_load, txt_load, yaml_save, make_project_dir, json_save_fast
12
+ from h_adminsim.utils.random_utils import (
13
+ generate_random_prob,
14
+ generate_random_date,
15
+ generate_random_code,
16
+ generate_random_names,
17
+ generate_random_address,
18
+ generate_random_telecom,
19
+ generate_random_id_number,
20
+ generate_random_specialty,
21
+ generate_random_code_with_prob,
22
+ )
23
+
24
+
25
+
26
class DataSynthesizer:
    """Synthesize random but internally consistent hospital data.

    Produces, per hospital: metadata (dates, operating hours), departments,
    doctors (with specialties, capacities, and per-date schedules), and
    patients (with appointments, preferences, and demographics). Results are
    saved as JSON files under the project data directory.
    """

    def __init__(self, config):
        """
        Args:
            config: Project configuration object; `config.hospital_data`
                drives every random range used during synthesis.
        """
        # Initialize configuration, paths, and the save directory
        self.config = config
        self._n = self.config.hospital_data.hospital_n
        self._save_dir = make_project_dir(self.config)
        self._data_save_dir = self._save_dir / 'data'
        yaml_save(self._save_dir / 'args.yaml', self.config)
        os.makedirs(self._data_save_dir, exist_ok=True)
        # Decimal precision for the exact capacity/interval divisibility check below
        getcontext().prec = 10


    def synthesize(self,
                   return_obj: bool = False,
                   sanity_check: bool = False) -> Tuple[list[Information], list[Hospital]]:
        """
        Synthesize hospital data based on the configuration settings.

        Args:
            return_obj (bool, optional): Whether to return the hospital data object.
            sanity_check (bool, optional): If you want to check whether the generated data are compatible with the `Hospital` object,
                                           you can use this option. Implies `return_obj=True`.

        Raises:
            Exception: Re-raised if data synthesis fails at any step.

        Returns:
            Tuple[list[Information], list[Hospital]]: A tuple containing the synthesized hospital data as an Information object and a Hospital object.
                                                      The hospital objects are `None` when `return_obj` is False.
        """
        if sanity_check:
            # The round-trip check below needs the object form
            return_obj = True

        try:
            all_data, all_hospitals = list(), list()
            hospitals = DataSynthesizer.hospital_list_generator(self.config.hospital_data.hospital_n)
            for i, hospital in tqdm(enumerate(hospitals), desc='Synthesizing data..', total=len(hospitals)):
                data = DataSynthesizer.define_hospital_info(self.config, hospital)
                hospital_obj = convert_info_to_obj(data) if return_obj else None
                if sanity_check:
                    # Round-trip through the object form must be lossless
                    new_data = convert_obj_to_info(hospital_obj)
                    assert to_dict(data) == to_dict(new_data), colorstr('red', 'Sanity check failed: info -> obj -> info round-trip mismatch')
                json_save_fast(self._data_save_dir / f'hospital_{padded_int(i, len(str(self._n)))}.json', to_dict(data))
                all_data.append(data)
                all_hospitals.append(hospital_obj)
            log(f"Total {len(hospitals)} data synthesizing completed. Path: `{self._data_save_dir}`", color=True)
            return all_data, all_hospitals

        except Exception as e:
            log(f"Data synthesizing failed: {e}", level='error')
            # Bare `raise` preserves the original traceback
            raise


    @staticmethod
    def define_hospital_info(config, hospital_name: str) -> Information:
        """
        Define the synthetic hospital data, including its departments and doctors.

        Args:
            config: Configuration object containing hospital data settings.
            hospital_name (str): Name of the hospital to be defined.

        Raises:
            AssertionError: If the generated data fails internal consistency checks.

        Returns:
            Information: Synthetic data about the hospital.
        """
        # Define hospital metadata
        days = config.hospital_data.days
        dates = generate_date_range(
            generate_random_iso_date_between(
                str(config.hospital_data.start_date.min),
                str(config.hospital_data.start_date.max),
            ),
            days
        )
        interval_hour = float(config.hospital_data.interval_hour)
        start_hour = float(random.randint(config.hospital_data.start_hour.min, config.hospital_data.start_hour.max))
        end_hour = float(random.randint(config.hospital_data.end_hour.min, config.hospital_data.end_hour.max))
        operation_hour_per_day = int(end_hour - start_hour)
        department_n = random.randint(
            config.hospital_data.department_per_hospital.min,
            config.hospital_data.department_per_hospital.max
        )
        doctor_n_per_department = [random.randint(config.hospital_data.doctor_per_department.min, config.hospital_data.doctor_per_department.max)
                                   for _ in range(department_n)]
        doctor_n = sum(doctor_n_per_department)
        # Only keep capacities whose slot duration (1/c hours) aligns exactly with the
        # scheduling interval; Decimal avoids float modulo artifacts (e.g. 0.1 intervals)
        doctor_capacity_per_hour_list = [c for c in range(config.hospital_data.doctor_capacity_per_hour.min, config.hospital_data.doctor_capacity_per_hour.max + 1)
                                         if float(Decimal(str(1)) / Decimal(str(c)) % Decimal(str(interval_hour))) == 0]
        hospital_time_segments = convert_time_to_segment(start_hour, end_hour, interval_hour)
        metadata = Information(
            hospital_name=hospital_name,
            start_date=dates[0],
            end_date=dates[-1],
            days=days,
            department_num=department_n,
            doctor_num=doctor_n,
            time=Information(
                start_hour=start_hour,
                end_hour=end_hour,
                interval_hour=interval_hour
            )
        )

        # Define ScheduleAssigner class to randomly assign schedules to each doctor
        scheduler = ScheduleAssigner(start_hour, end_hour, interval_hour)

        # Define detailed hospital department, doctoral, and patient information
        department_info, doctor_info, patient_info = dict(), dict(), dict()
        departments = DataSynthesizer.department_list_generator(department_n)
        doctors = DataSynthesizer.name_list_generator(doctor_n, prefix='Dr. ')    # Doctor names are unique across all departments
        for department_data, doc_n in zip(departments, doctor_n_per_department):
            department, dep_code = department_data

            # Add department information
            department_info[department] = {'code': dep_code if dep_code else 'NA', 'doctor': []}

            # Add doctor information
            for _ in range(doc_n):
                doctor = doctors.pop()
                department_info[department]['doctor'].append(doctor)
                specialty, spe_code = generate_random_specialty(department)
                capacity_per_hour = random.choice(doctor_capacity_per_hour_list)
                working_days = random.randint(
                    config.hospital_data.working_days.min,
                    config.hospital_data.working_days.max
                )
                working_dates = sorted(random.sample(dates, working_days))
                doctor_info[doctor] = {
                    'department': department,
                    'specialty': {
                        'name': specialty,
                        'code': spe_code,
                    },
                    'schedule': {},
                    'capacity_per_hour': int(capacity_per_hour),
                    'capacity': int(capacity_per_hour * operation_hour_per_day * len(working_dates)),
                    'gender': generate_random_code('gender'),
                    'telecom': [{
                        'system': 'phone',
                        'value': generate_random_telecom(),
                        'use': generate_random_code('use')
                    }],
                    'birthDate': generate_random_date()
                }
                # Appointment length in scheduling segments for this doctor
                duration = int(1 / capacity_per_hour / interval_hour)

                # Generate doctor schedules and appointments based on the pre-defined days
                for date in dates:
                    # Working day case
                    if date in working_dates:
                        schedule_segments, schedule_times = scheduler(
                            generate_random_prob(
                                config.hospital_data.doctor_has_schedule_prob,
                                config.hospital_data.schedule_coverage_ratio.min,
                                config.hospital_data.schedule_coverage_ratio.max
                            )
                        )
                        doctor_info[doctor]['schedule'][date] = schedule_times
                    # Not working day case: the whole day is blocked out
                    else:
                        schedule_segments, schedule_times = scheduler(1)
                        doctor_info[doctor]['schedule'][date] = schedule_times

                    # Add patient information per doctor; patients can only book segments
                    # not already blocked by the doctor's own schedule
                    patient_segments = list(set(hospital_time_segments) - set(sum(schedule_segments, [])))
                    _, appointments = scheduler(
                        generate_random_prob(
                            1,
                            config.hospital_data.appointment_coverage_ratio.min,
                            config.hospital_data.appointment_coverage_ratio.max
                        ),
                        True,
                        patient_segments,
                        min_chunk_size=duration,
                        max_chunk_size=duration
                    )
                    patients = DataSynthesizer.name_list_generator(len(appointments))
                    for patient, appointment in zip(patients, appointments):
                        preference = generate_random_code_with_prob(
                            config.hospital_data.preference.type,
                            config.hospital_data.preference.probs
                        )
                        preference_rank = DataSynthesizer.second_preference_generator(preference)
                        symptom_level = generate_random_code_with_prob(
                            config.hospital_data.symptom.type,
                            config.hospital_data.symptom.probs
                        )
                        birth_date = generate_random_date()
                        patient_info[patient] = {
                            'department': department,
                            'attending_physician': doctor,
                            'date': date,
                            'schedule': appointment,
                            'preference': preference_rank,
                            'symptom_level': symptom_level,
                            'gender': generate_random_code('gender'),
                            'telecom': [{
                                'system': 'phone',
                                'value': generate_random_telecom(),
                                'use': generate_random_code('use')
                            }],
                            'birthDate': birth_date,
                            'identifier': [{
                                'value': generate_random_id_number(birth_date=birth_date),
                                'use': 'official'
                            }],
                            'address': [{
                                'type': 'postal',
                                'text': generate_random_address(),
                                'use': 'home'
                            }]
                        }

        # Finalize data structure
        data = Information(
            metadata=metadata,
            department=department_info,
            doctor=doctor_info,
            patient=patient_info,
        )

        # Data sanity check
        if len(data.department) != metadata.department_num:
            raise AssertionError(colorstr('red', 'Department number mismatch'))
        if len(data.department) != len(set(doc['department'] for doc in data.doctor.values())):
            raise AssertionError(colorstr('red', 'Department number mismatch'))
        if len(data.doctor) != metadata.doctor_num:
            raise AssertionError(colorstr('red', 'Doctor number mismatch'))
        if len(data.doctor) != sum(len(dept['doctor']) for dept in data.department.values()):
            raise AssertionError(colorstr('red', 'Doctor number mismatch'))

        return data


    @staticmethod
    def hospital_list_generator(hospital_n: int,
                                file_path: Optional[str] = None) -> list[str]:
        """
        Generate a list of hospital names based on the number of hospitals.

        Args:
            hospital_n (int): Number of hospitals to generate.
            file_path (Optional[str], optional): Path to a file containing hospital names. If provided, it will be used to load names.

        Returns:
            list[str]: List of hospital names in the format "hospital_001", "hospital_002", etc.,
                       or random names drawn from the file when `file_path` is given.
        """
        if file_path:
            # Names are loaded once and cached in the registry for subsequent calls
            if registry.HOSPITALS is None:
                registry.HOSPITALS = [word.capitalize() for word in txt_load(file_path).split('\n') if word.strip()]
            return [f"{random.choice(registry.HOSPITALS)}" for _ in range(hospital_n)]

        zfill_l = len(str(hospital_n))
        return [f"hospital_{padded_int(i, zfill_l)}" for i in range(hospital_n)]


    @staticmethod
    def department_list_generator(department_n: int,
                                  file_path: Optional[str] = None) -> list[Tuple[str, str]]:
        """
        Generate a list of department names based on the number of departments.

        Args:
            department_n (int): Number of departments to generate.
            file_path (Optional[str], optional): Path to a file containing department names. If provided, it will be used to load names.
                                                 Defaults to the bundled `department.json` asset.

        Raises:
            ValueError: If more departments are requested than are available in the file.

        Returns:
            list[Tuple[str, str]]: List of department names and their codes.
        """
        if file_path is None:
            file_path = str(resources.files("h_adminsim.assets.departments").joinpath("department.json"))

        if file_path:
            # Departments are loaded once and cached in the registry for subsequent calls.
            # NOTE(review): the cache is keyed globally, not per file path — a second call
            # with a different `file_path` reuses the first file's departments.
            if registry.DEPARTMENTS is None:
                specialty = json_load(file_path)['specialty']
                registry.DEPARTMENTS = [(k2, v2['code']) for v1 in specialty.values() for k2, v2 in v1['subspecialty'].items()]

            if department_n > len(registry.DEPARTMENTS):
                raise ValueError(f"Requested {department_n} departments, but only {len(registry.DEPARTMENTS)} available in {file_path}.")

            return random.sample(registry.DEPARTMENTS, department_n)

        zfill_l = len(str(department_n))
        return [(f"department_{padded_int(i, zfill_l)}", None) for i in range(department_n)]


    @staticmethod
    def name_list_generator(n: int,
                            first_name_file_path: Optional[str] = None,
                            last_name_file_path: Optional[str] = None,
                            prefix: Optional[str] = None) -> list[str]:
        """
        Generate a list of names.

        Args:
            n (int): Number of names to generate.
            first_name_file_path (Optional[str], optional): Path to a file containing first names. Defaults to the bundled asset.
            last_name_file_path (Optional[str], optional): Path to a file containing last names. Defaults to the bundled asset.
            prefix (Optional[str], optional): Prefix for the generated names (e.g. 'Dr. ').

        Raises:
            TypeError: If `prefix` is given but is not a string.

        Returns:
            list[str]: List of shuffled names.
        """
        if first_name_file_path is None:
            first_name_file_path = str(resources.files("h_adminsim.assets.names").joinpath("firstname.txt"))
        if last_name_file_path is None:
            last_name_file_path = str(resources.files("h_adminsim.assets.names").joinpath("lastname.txt"))

        if prefix is not None:
            # An explicit raise (instead of `assert`) survives `python -O`
            if not isinstance(prefix, str):
                log("`prefix` must be a string type", "error")
                raise TypeError("`prefix` must be a string type")
            names = [f'{prefix}{name}' for name in generate_random_names(n, first_name_file_path, last_name_file_path)]
        else:
            names = list(generate_random_names(n, first_name_file_path, last_name_file_path))
        random.shuffle(names)
        return names


    @staticmethod
    def second_preference_generator(preference: str) -> list[str]:
        """
        Generate a list of preferences based on the initial preference.

        Args:
            preference (str): First priority of preference.

        Returns:
            list[str]: List of preferences including first and second priority.
                       Unknown preferences are returned alone, without a second priority.
        """
        # The second priority is drawn uniformly from the two remaining options
        alternatives = {
            'doctor': ['asap', 'date'],
            'date': ['asap', 'doctor'],
            'asap': ['date', 'doctor'],
        }

        preference_list = [preference]
        if preference in alternatives:
            preference_list.append(random.choice(alternatives[preference]))

        return preference_list
@@ -0,0 +1,258 @@
1
+ import os
2
+ import numpy as np
3
+ from collections import Counter
4
+
5
+ from h_adminsim.utils import log, colorstr
6
+ from h_adminsim.utils.filesys_utils import get_files, json_load
7
+ from h_adminsim.utils.image_preprocess_utils import draw_fail_donut_subplots
8
+
9
+
10
+
11
class Evaluator:
    """Aggregate simulation result files under `path` and log evaluation reports.

    All report methods read `*_result.json` files (collected in `__init__`) and
    write human-readable summaries through `log`; none of them return values.
    """

    def __init__(self, path, human_eval=False):
        """
        Args:
            path: Directory containing the result files.
            human_eval (bool, optional): When True, also collect `.txt` files for
                `human_evaluation`. The `human_eval_files` attribute only exists
                in that case.
        """
        self.path = path
        self.files = get_files(self.path, '_result.json')
        if human_eval:
            self.human_eval_files = get_files(self.path, '.txt')

        # Dialog files are optional; default to an empty list so that
        # `calculate_avg_rounds` never hits an AttributeError
        try:
            self.dialog_files = get_files(self.path, '_dialog.json')
        except Exception:
            self.dialog_files = []


    def task_evaluation(self):
        """
        Perform macro- and micro-wise evaluation on the aggregated results.
        """
        aggregated_results = dict()
        for file in self.files:
            data = json_load(file)

            for task, value in data.items():
                if task not in aggregated_results:
                    aggregated_results[task] = {'status': [], 'status_code': []}

                aggregated_results[task]['status'].append(value['status'])
                aggregated_results[task]['status_code'].append(value['status_code'])

        # Macro-wise evaluation: per-file accuracy averaged over files
        log('--------------Macro-wise Evaluation--------------')
        for task, value in aggregated_results.items():
            # Each entry in `status` may be a bool or a list of bools; weigh lists by length
            accuracies = [sum(x if isinstance(x, bool) else sum(x) for x in status) / sum(1 if isinstance(x, bool) else len(x) for x in status) * 100 for status in value['status']]
            avg_accuracy = sum(accuracies) / len(accuracies)
            # Population standard deviation across files
            stdv = round((sum((x - avg_accuracy) ** 2 for x in accuracies) / len(accuracies)) ** 0.5, 2) if len(accuracies) > 1 else 0.0
            log(f'{colorstr(task):<27} | average accuracy: {colorstr("green", f"{avg_accuracy:.2f}% ± {stdv}")}, files: {len(accuracies)}')
            log(f'  - Individual accuracies: {", ".join([colorstr("green", f"{acc:.2f}%") for acc in accuracies])}')


        # Micro-wise evaluation: every individual case pooled across files
        log('')
        log('--------------Micro-wise Evaluation--------------')
        fail_data_dict = dict()
        for task, value in aggregated_results.items():
            status = [x for y in sum(value['status'], []) for x in (y if isinstance(y, (list, tuple)) else [y])]
            status_code = [x for y in sum(value['status_code'], []) for x in (y if isinstance(y, (list, tuple)) else [y])]
            accuracy = sum(status) / len(status) * 100
            # 'unexpected' codes are excluded from the error-rate breakdown
            failed_cases = [c for s, c in zip(status, status_code) if not s and 'unexpected' not in c]
            error_rate = (len(failed_cases) / len(status)) * 100
            log(f'{colorstr(task):<27} | accuracy: {colorstr("green", f"{accuracy:.2f}%")}, length: {sum(status)} / {len(status)}')
            log(f'{colorstr(task):<27} | Error   : {colorstr("red", f"{error_rate:.2f}%")}, length: {len(failed_cases)} / {len(status)}')

            if failed_cases:
                fail_summary = Counter(failed_cases)
                reschedule_fail_summary = Counter()

                # Fold "reschedule:<type>" codes into their base type while keeping
                # a separate count of how many came from reschedules
                for k, v in list(fail_summary.items()):
                    if k.startswith("reschedule:") and 'identify' not in k and 'unexpected' not in k:
                        norm_key = k.replace("reschedule:", "").strip()
                        fail_summary[norm_key] += v
                        reschedule_fail_summary[norm_key] += v
                        fail_summary.pop(k)

                for fail_type, count in fail_summary.items():
                    percent = (count / len(failed_cases)) * 100
                    reschedule_n = reschedule_fail_summary[fail_type] if fail_type in reschedule_fail_summary else 0
                    if reschedule_n:
                        log(f'  - Fail type {colorstr("red", fail_type):<30}: {count} (reschedule: {reschedule_n}) cases ({percent:.2f}%)')
                    else:
                        log(f'  - Fail type {colorstr("red", fail_type):<30}: {count} cases ({percent:.2f}%)')
                fail_data_dict[task] = failed_cases

        draw_fail_donut_subplots(fail_data_dict, os.path.join(self.path, 'fails.png'))


    def ipi_evaluation(self):
        """
        Micro-wise IPI performance evaluation on the aggregated results.
        """
        aggregated_results = dict()
        for file in self.files:
            data = json_load(file)

            if 'intake' not in aggregated_results:
                aggregated_results['intake'] = {'status': [], 'status_code': []}

            aggregated_results['intake']['status'].append(data['intake']['status'])
            aggregated_results['intake']['status_code'].append(data['intake']['status_code'])

        # Micro-wise evaluation
        log('')
        log('------------------IPI Evaluation-----------------')
        status = sum(aggregated_results['intake']['status'], [])
        status_code = sum(aggregated_results['intake']['status_code'], [])
        failed_cases = [c for s, c in zip(status, status_code) if not s]

        if failed_cases:
            if_err_count, ipi_err_count = 0, 0
            fail_summary = Counter(failed_cases)
            for fail_type, count in fail_summary.items():
                if fail_type in ['incorrect department and patient information', 'incorrect patient information']:
                    ipi_err_count += count
                elif fail_type in ['incorrect format']:
                    if_err_count += count

            if_percent = (if_err_count / len(status)) * 100
            ipi_percent = (ipi_err_count / len(status)) * 100
            log(f'  - Fail type {colorstr("red", "incorrect format"):<38}: {if_err_count} / {len(status)} ({if_percent:.2f}%)')
            log(f'  - Fail type {colorstr("red", "incorrect patient information"):<38}: {ipi_err_count} / {len(status)} ({ipi_percent:.2f}%)')


    def supervisor_evaluation(self):
        """
        Evaluate the supervisor's necessity to intervene in tasks.
        """
        aggregated_results = dict()
        for file in self.files:
            data = json_load(file)

            for task, value in data.items():
                if task not in aggregated_results:
                    aggregated_results[task] = {'status': [], 'trial': []}

                aggregated_results[task]['status'].append(value['status'])
                aggregated_results[task]['trial'].append(value['trial'])

        log('-----Supervisor (or feedback) Evaluation----')
        for task, value in aggregated_results.items():
            status = sum(value['status'], [])
            trial = sum(value['trial'], [])

            if task == 'intake':
                total_length = len(status)
                supervisor_effect_cnt, correct, error, tie = 0, 0, 0, 0
                for t in trial:
                    # A 'mismatch' marker means the supervisor changed the outcome;
                    # 'better'/'worse' grade the change, everything else is a tie
                    if 'mismatch' in t[0]:
                        supervisor_effect_cnt += 1
                        if 'better' in t[0]:
                            correct += 1
                        elif 'worse' in t[0]:
                            error += 1
                        else:
                            tie += 1

                correct_p = correct / supervisor_effect_cnt * 100 if supervisor_effect_cnt > 0 else 0
                error_p = error / supervisor_effect_cnt * 100 if supervisor_effect_cnt > 0 else 0
                tie_p = tie / supervisor_effect_cnt * 100 if supervisor_effect_cnt > 0 else 0
                log(f'{colorstr(task):<27} | length: {total_length}, effected: {supervisor_effect_cnt} ({(supervisor_effect_cnt/total_length)*100:.2f}%)')
                log(f'  - {colorstr("green", "correct")}: {correct} ({correct_p:.2f}%), {colorstr("red", "worse")}: {error} ({error_p:.2f}%), {colorstr("yellow", "tie")}: {tie} ({tie_p:.2f}%)')

            elif task == 'schedule':
                feedback_n = dict()
                total_length = len(status)
                supervisor_effect_cnt, correct, tie = 0, 0, 0
                for t in trial:
                    # More than one trial entry means at least one feedback round happened
                    if isinstance(t, list) and len(t) > 1:
                        supervisor_effect_cnt += 1
                        if t[-1] == 'pass':
                            correct += 1
                            feedback_n[len(t) - 1] = feedback_n.get(len(t) - 1, 0) + 1
                        else:
                            tie += 1

                desc = ', '.join([f'{f}-feedback: {n}' for f, n in sorted(feedback_n.items())])
                correct_p = correct / supervisor_effect_cnt * 100 if supervisor_effect_cnt > 0 else 0
                tie_p = tie / supervisor_effect_cnt * 100 if supervisor_effect_cnt > 0 else 0
                log(f'{colorstr(task):<27} | length: {total_length}, effected: {supervisor_effect_cnt} ({(supervisor_effect_cnt/total_length)*100:.2f}%)')
                log(f'  - {colorstr("green", "correct")}: {correct} ({correct_p:.2f}%), {colorstr("yellow", "tie")}: {tie} ({tie_p:.2f}%)')
                log(f'  - Feedback distribution: {desc}')


    def human_evaluation(self):
        """
        Aggregate and evaluate human evaluation results from text files.

        Each non-empty line is expected to be tab-separated:
        `arena<TAB>score_a<TAB>score_b<TAB>model_a<TAB>model_b`, where `arena`
        is 'A' when model A won (anything else counts as a win for model B).
        """
        scores = {'arena': dict(), 'score': dict()}
        all_lines = list()
        for file in self.human_eval_files:
            with open(file, 'r') as f:
                lines = f.readlines()
            all_lines.extend([line.strip() for line in lines if line.strip()])

        for line in all_lines:
            arena, score_a, score_b, model_a, model_b = line.split('\t')
            scores['arena'].setdefault(model_a, 0)
            scores['arena'].setdefault(model_b, 0)
            scores['score'].setdefault(model_a, [])
            scores['score'].setdefault(model_b, [])

            if arena == 'A':
                scores['arena'][model_a] += 1
            else:
                scores['arena'][model_b] += 1

            scores['score'][model_a].append(float(score_a))
            scores['score'][model_b].append(float(score_b))

        log('--------------Human Evaluation--------------')
        for model in scores['arena'].keys():
            arena_wins = scores['arena'][model]
            score_list = scores['score'][model]
            avg_score = sum(score_list) / len(score_list)
            stdv = round((sum((x - avg_score) ** 2 for x in score_list) / len(score_list)) ** 0.5, 2) if len(score_list) > 1 else 0.0
            log(f'{colorstr(model):<15} | Arena wins: {colorstr("green", str(arena_wins))}, Average score: {colorstr("green", f"{avg_score:.2f} ± {stdv}")}')


    def department_evaluation(self):
        """
        Evaluate solely department prediction accuracy.
        """
        aggregated_results = {'intake': {'gt': [], 'pred': [], 'status': []}}

        for file in self.files:
            data = json_load(file)
            aggregated_results['intake']['gt'].extend(data['intake']['gt'])
            aggregated_results['intake']['pred'].extend(data['intake']['pred'])
            aggregated_results['intake']['status'].extend(data['intake']['status'])

        gt = aggregated_results['intake']['gt']
        pred = aggregated_results['intake']['pred']
        status = aggregated_results['intake']['status']
        total_n, dept_err_n = len(gt), 0
        for g, p, s in zip(gt, pred, status):
            # Only failed cases can be department errors; a failed case whose
            # top-1 department is still in the ground-truth set is not counted
            if not s:
                gt_depts = g['department']
                pred_dept = p['department'][0]

                if pred_dept not in gt_depts:
                    dept_err_n += 1

        log('--------------Department Evaluation--------------')
        log(f'Error rate: {colorstr("red", f"{(dept_err_n/total_n)*100:.2f}%")}, length: {dept_err_n} / {total_n}')


    def calculate_avg_rounds(self):
        """
        Calculate average required intake rounds across all dialog files.
        """
        counts = list()
        for file in self.dialog_files:
            data = json_load(file)
            for dialog in data.values():
                # The first 'Staff: ' turn opens the dialog, so rounds = turns - 1
                counts.append(dialog.count('Staff: ') - 1)

        # Guard against an empty pool (no dialog files): np.mean([]) would be NaN
        if not counts:
            log('-----------------Average Rounds-----------------')
            log('Average Rounds: no dialog data found')
            return

        mean, stdv = np.mean(counts), np.std(counts)
        log('-----------------Average Rounds-----------------')
        log(f'Average Rounds: {mean:.2f} ± {stdv:.2f}')
+