h-adminsim 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- h_adminsim/__init__.py +5 -0
- h_adminsim/admin_staff.py +280 -0
- h_adminsim/assets/configs/data4primary.yaml +47 -0
- h_adminsim/assets/configs/data4secondary.yaml +47 -0
- h_adminsim/assets/configs/data4tertiary.yaml +47 -0
- h_adminsim/assets/country/address.json +141859 -0
- h_adminsim/assets/country/country_code.json +244 -0
- h_adminsim/assets/departments/department.json +85 -0
- h_adminsim/assets/departments/symptom.json +4530 -0
- h_adminsim/assets/fhir.schema.json +75253 -0
- h_adminsim/assets/names/firstname.txt +1219 -0
- h_adminsim/assets/names/lastname.txt +88799 -0
- h_adminsim/assets/prompts/cancel_patient_system.txt +38 -0
- h_adminsim/assets/prompts/intake_staff_task_user.txt +16 -0
- h_adminsim/assets/prompts/intake_supervisor_system.txt +8 -0
- h_adminsim/assets/prompts/intake_supervisor_user.txt +31 -0
- h_adminsim/assets/prompts/reschedule_patient_system.txt +38 -0
- h_adminsim/assets/prompts/schedule_patient_rejected_system.txt +42 -0
- h_adminsim/assets/prompts/schedule_patient_system.txt +36 -0
- h_adminsim/assets/prompts/schedule_staff_reasoning.txt +57 -0
- h_adminsim/assets/prompts/schedule_staff_sc_tool_calling.txt +13 -0
- h_adminsim/assets/prompts/schedule_staff_system.txt +10 -0
- h_adminsim/assets/prompts/schedule_staff_tool_calling.txt +41 -0
- h_adminsim/client/__init__.py +3 -0
- h_adminsim/client/google_client.py +209 -0
- h_adminsim/client/openai_client.py +199 -0
- h_adminsim/client/vllm_client.py +160 -0
- h_adminsim/environment/__init__.py +1 -0
- h_adminsim/environment/hospital.py +462 -0
- h_adminsim/environment/op_scheduling_simulation.py +1126 -0
- h_adminsim/pipeline/__init__.py +3 -0
- h_adminsim/pipeline/data_generator.py +192 -0
- h_adminsim/pipeline/evaluator.py +33 -0
- h_adminsim/pipeline/simulation.py +231 -0
- h_adminsim/registry/__init__.py +5 -0
- h_adminsim/registry/errors.py +89 -0
- h_adminsim/registry/models.py +126 -0
- h_adminsim/registry/phrases.py +10 -0
- h_adminsim/registry/pydantic_models.py +21 -0
- h_adminsim/registry/variables.py +9 -0
- h_adminsim/supervisor.py +182 -0
- h_adminsim/task/agent_task.py +900 -0
- h_adminsim/task/fhir_manager.py +222 -0
- h_adminsim/task/schedule_assign.py +151 -0
- h_adminsim/tools/__init__.py +5 -0
- h_adminsim/tools/agent_data_builder.py +124 -0
- h_adminsim/tools/data_converter.py +536 -0
- h_adminsim/tools/data_synthesizer.py +365 -0
- h_adminsim/tools/evaluator.py +258 -0
- h_adminsim/tools/sanity_checker.py +216 -0
- h_adminsim/tools/scheduling_rule.py +420 -0
- h_adminsim/utils/__init__.py +136 -0
- h_adminsim/utils/common_utils.py +698 -0
- h_adminsim/utils/fhir_utils.py +190 -0
- h_adminsim/utils/filesys_utils.py +135 -0
- h_adminsim/utils/image_preprocess_utils.py +188 -0
- h_adminsim/utils/random_utils.py +358 -0
- h_adminsim/version.txt +1 -0
- h_adminsim-1.0.0.dist-info/LICENSE +30 -0
- h_adminsim-1.0.0.dist-info/METADATA +494 -0
- h_adminsim-1.0.0.dist-info/RECORD +62 -0
- h_adminsim-1.0.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
from importlib import resources
|
|
5
|
+
from typing import Optional, Tuple
|
|
6
|
+
from decimal import Decimal, getcontext
|
|
7
|
+
|
|
8
|
+
from h_adminsim.task.schedule_assign import ScheduleAssigner
|
|
9
|
+
from h_adminsim.utils import Information, log, colorstr
|
|
10
|
+
from h_adminsim.utils.common_utils import *
|
|
11
|
+
from h_adminsim.utils.filesys_utils import json_load, txt_load, yaml_save, make_project_dir, json_save_fast
|
|
12
|
+
from h_adminsim.utils.random_utils import (
|
|
13
|
+
generate_random_prob,
|
|
14
|
+
generate_random_date,
|
|
15
|
+
generate_random_code,
|
|
16
|
+
generate_random_names,
|
|
17
|
+
generate_random_address,
|
|
18
|
+
generate_random_telecom,
|
|
19
|
+
generate_random_id_number,
|
|
20
|
+
generate_random_specialty,
|
|
21
|
+
generate_random_code_with_prob,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DataSynthesizer:
|
|
27
|
+
def __init__(self, config):
    """Set up configuration, output directories, and global decimal precision."""
    self.config = config
    self._n = config.hospital_data.hospital_n

    # Project layout: <save_dir>/args.yaml plus <save_dir>/data/ for per-hospital JSON
    self._save_dir = make_project_dir(config)
    self._data_save_dir = self._save_dir / 'data'
    os.makedirs(self._data_save_dir, exist_ok=True)
    yaml_save(self._save_dir / 'args.yaml', self.config)

    # Fixed precision for the Decimal arithmetic used in capacity/interval checks
    getcontext().prec = 10
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def synthesize(self,
               return_obj: bool = False,
               sanity_check: bool = False) -> Tuple[list[Information], list[Hospital]]:
    """
    Synthesize hospital data based on the configuration settings.

    Args:
        return_obj (bool, optional): Whether to also build and return `Hospital` objects.
        sanity_check (bool, optional): Round-trip every record through the `Hospital`
            object and assert nothing is lost (implies `return_obj=True`).

    Raises:
        e: Exception if data synthesis fails.

    Returns:
        Tuple[list[Information], list[Hospital]]: Synthesized data records and the
            matching `Hospital` objects (entries are None when `return_obj` is False).
    """
    # Sanity checking requires the object form
    return_obj = return_obj or sanity_check

    try:
        all_data, all_hospitals = [], []
        hospitals = DataSynthesizer.hospital_list_generator(self.config.hospital_data.hospital_n)
        pad_width = len(str(self._n))
        for i, hospital in tqdm(enumerate(hospitals), desc='Synthesizing data..', total=len(hospitals)):
            data = DataSynthesizer.define_hospital_info(self.config, hospital)
            hospital_obj = convert_info_to_obj(data) if return_obj else None
            if sanity_check:
                # info -> object -> info must be lossless
                assert to_dict(data) == to_dict(convert_obj_to_info(hospital_obj))
            json_save_fast(self._data_save_dir / f'hospital_{padded_int(i, pad_width)}.json', to_dict(data))
            all_data.append(data)
            all_hospitals.append(hospital_obj)
        log(f"Total {len(hospitals)} data synthesizing completed. Path: `{self._data_save_dir}`", color=True)
        return all_data, all_hospitals

    except Exception as e:
        log(f"Data synthesizing failed: {e}", level='error')
        raise e
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@staticmethod
def define_hospital_info(config, hospital_name: str) -> Information:
    """
    Define the synthetic hospital data, including its departments, doctors, and patients.

    Args:
        config: Configuration object containing hospital data settings.
        hospital_name (str): Name of the hospital to be defined.

    Raises:
        AssertionError: If the generated data fails the final consistency checks.

    Returns:
        Information: Synthetic data about the hospital.
    """
    # Define hospital metadata: simulated dates, operating hours, and head counts
    days = config.hospital_data.days
    dates = generate_date_range(
        generate_random_iso_date_between(
            str(config.hospital_data.start_date.min),
            str(config.hospital_data.start_date.max),
        ),
        days
    )
    interval_hour = float(config.hospital_data.interval_hour)
    start_hour = float(random.randint(config.hospital_data.start_hour.min, config.hospital_data.start_hour.max))
    end_hour = float(random.randint(config.hospital_data.end_hour.min, config.hospital_data.end_hour.max))
    operation_hour_per_day = int(end_hour - start_hour)
    department_n = random.randint(
        config.hospital_data.department_per_hospital.min,
        config.hospital_data.department_per_hospital.max
    )
    doctor_n_per_department = [random.randint(config.hospital_data.doctor_per_department.min, config.hospital_data.doctor_per_department.max)
                               for _ in range(department_n)]
    doctor_n = sum(doctor_n_per_department)
    # Keep only capacities whose per-visit duration (1/c hour) aligns exactly with the
    # scheduling interval; Decimal sidesteps float modulo rounding error.
    doctor_capacity_per_hour_list = [c for c in range(config.hospital_data.doctor_capacity_per_hour.min, config.hospital_data.doctor_capacity_per_hour.max + 1) \
                                     if float(Decimal(str(1))/Decimal(str(c)) % Decimal(str(interval_hour))) == 0]
    hospital_time_segments = convert_time_to_segment(start_hour, end_hour, interval_hour)
    metadata = Information(
        hospital_name=hospital_name,
        start_date=dates[0],
        end_date=dates[-1],
        days=days,
        department_num=department_n,
        doctor_num=doctor_n,
        time=Information(
            start_hour=start_hour,
            end_hour=end_hour,
            interval_hour=interval_hour
        )
    )

    # Define ScheduleAssigner class to randomly assign schedules to each doctor
    scheduler = ScheduleAssigner(start_hour, end_hour, interval_hour)

    # Define detailed hospital department, doctoral, and patient information
    department_info, doctor_info, patient_info = dict(), dict(), dict()
    departments = DataSynthesizer.department_list_generator(department_n)
    doctors = DataSynthesizer.name_list_generator(doctor_n, prefix='Dr. ') # Doctor names are unique across all departments
    for department_data, doc_n in zip(departments, doctor_n_per_department):
        department, dep_code = department_data

        # Add department information
        department_info[department] = {'code': dep_code if dep_code else 'NA', 'doctor': []}

        # Add doctor information
        for _ in range(doc_n):
            doctor = doctors.pop()
            department_info[department]['doctor'].append(doctor)
            specialty, spe_code = generate_random_specialty(department)
            capacity_per_hour = random.choice(doctor_capacity_per_hour_list)
            working_days = random.randint(
                config.hospital_data.working_days.min,
                config.hospital_data.working_days.max
            )
            # Randomly pick which of the simulated dates this doctor works
            working_dates = sorted(random.sample(dates, working_days))
            doctor_info[doctor] = {
                'department': department,
                'specialty': {
                    'name': specialty,
                    'code': spe_code,
                },
                'schedule': {},
                'capacity_per_hour': int(capacity_per_hour),
                # Total bookable visits across this doctor's working dates
                'capacity': int(capacity_per_hour * operation_hour_per_day * len(working_dates)),
                'gender': generate_random_code('gender'),
                'telecom': [{
                    'system': 'phone',
                    'value': generate_random_telecom(),
                    'use': generate_random_code('use')
                }],
                'birthDate': generate_random_date()
            }
            # Per-visit duration measured in schedule segments
            duration = int(1 / capacity_per_hour / interval_hour)

            # Generate doctor schedules and appointments based on the pre-defined days
            for date in dates:
                # Working day case
                if date in working_dates:
                    schedule_segments, schedule_times = scheduler(
                        generate_random_prob(
                            config.hospital_data.doctor_has_schedule_prob,
                            config.hospital_data.schedule_coverage_ratio.min,
                            config.hospital_data.schedule_coverage_ratio.max
                        )
                    )
                    doctor_info[doctor]['schedule'][date] = schedule_times
                # Not working day case
                else:
                    # Coverage of 1 blocks the whole day, leaving no free segments
                    schedule_segments, schedule_times = scheduler(1)
                    doctor_info[doctor]['schedule'][date] = schedule_times

                # Add patient information per doctor
                # Free segments = hospital opening segments minus the doctor's blocked segments
                patient_segments = list(set(hospital_time_segments) - set(sum(schedule_segments, [])))
                _, appointments = scheduler(
                    generate_random_prob(
                        1,
                        config.hospital_data.appointment_coverage_ratio.min,
                        config.hospital_data.appointment_coverage_ratio.max
                    ),
                    True,
                    patient_segments,
                    min_chunk_size=duration,
                    max_chunk_size=duration
                )
                # NOTE(review): patient names can collide across dates/doctors, in which
                # case the later entry silently overwrites the earlier one — confirm intended.
                patients = DataSynthesizer.name_list_generator(len(appointments))
                for patient, appointment in zip(patients, appointments):
                    preference = generate_random_code_with_prob(
                        config.hospital_data.preference.type,
                        config.hospital_data.preference.probs
                    )
                    preference_rank = DataSynthesizer.second_preference_generator(preference)
                    symptom_level = generate_random_code_with_prob(
                        config.hospital_data.symptom.type,
                        config.hospital_data.symptom.probs
                    )
                    birth_date = generate_random_date()
                    patient_info[patient] = {
                        'department': department,
                        'attending_physician': doctor,
                        'date': date,
                        'schedule': appointment,
                        'preference': preference_rank,
                        'symptom_level': symptom_level,
                        'gender': generate_random_code('gender'),
                        'telecom': [{
                            'system': 'phone',
                            'value': generate_random_telecom(),
                            'use': generate_random_code('use')
                        }],
                        'birthDate': birth_date,
                        'identifier': [{
                            # ID number is derived from the birth date
                            'value': generate_random_id_number(birth_date=birth_date),
                            'use': 'official'
                        }],
                        'address': [{
                            'type': 'postal',
                            'text': generate_random_address(),
                            'use': 'home'
                        }]
                    }

    # Finalize data structure
    data = Information(
        metadata=metadata,
        department=department_info,
        doctor=doctor_info,
        patient=patient_info,
    )

    # Data sanity check: counts must agree across metadata, departments, and doctors
    if len(data.department) != metadata.department_num:
        raise AssertionError(colorstr('red', 'Department number mismatch'))
    if len(data.department) != len(set(doc['department'] for doc in data.doctor.values())):
        raise AssertionError(colorstr('red', 'Department number mismatch'))
    if len(data.doctor) != metadata.doctor_num:
        raise AssertionError(colorstr('red', 'Doctor number mismatch'))
    if len(data.doctor) != sum(len(dept['doctor']) for dept in data.department.values()):
        raise AssertionError(colorstr('red', 'Doctor number mismatch'))

    return data
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@staticmethod
def hospital_list_generator(hospital_n: int,
                            file_path: Optional[str] = None) -> list[str]:
    """
    Generate a list of hospital names based on the number of hospitals.

    Args:
        hospital_n (int): Number of hospitals to generate.
        file_path (Optional[str], optional): Path to a file containing hospital names.
            If provided, names are sampled (with replacement) from this file.

    Returns:
        list[str]: List of hospital names in the format "hospital_001", etc.,
            when no file is supplied.
    """
    if not file_path:
        # Deterministic placeholder names, zero-padded to a uniform width
        width = len(str(hospital_n))
        return [f"hospital_{padded_int(idx, width)}" for idx in range(hospital_n)]

    # Load the name pool once and cache it in the shared registry
    if registry.HOSPITALS is None:
        raw_words = txt_load(file_path).split('\n')
        registry.HOSPITALS = [w.capitalize() for w in raw_words if w.strip()]
    return [random.choice(registry.HOSPITALS) for _ in range(hospital_n)]
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
@staticmethod
def department_list_generator(department_n: int,
                              file_path: Optional[str] = None) -> list[Tuple[str, str]]:
    """
    Generate a list of department names based on the number of departments.

    Args:
        department_n (int): Number of departments to generate.
        file_path (Optional[str], optional): Path to a JSON file containing department names.
            When omitted, the packaged `department.json` asset is used. Defaults to None.

    Raises:
        ValueError: If more departments are requested than are available in the file.

    Returns:
        list[Tuple[str, str]]: List of department names and their codes, sampled
            without replacement.
    """
    # PEP 8: compare against None with `is`, not `==`
    if file_path is None:
        file_path = str(resources.files("h_adminsim.assets.departments").joinpath("department.json"))

    if file_path:
        # Load and cache the department pool once in the shared registry.
        # NOTE(review): the cache ignores `file_path` on later calls — passing a
        # different file afterwards will not refresh `registry.DEPARTMENTS`.
        if registry.DEPARTMENTS is None:
            specialty = json_load(file_path)['specialty']
            registry.DEPARTMENTS = [(k2, v2['code']) for v1 in specialty.values() for k2, v2 in v1['subspecialty'].items()]

        if department_n > len(registry.DEPARTMENTS):
            raise ValueError(f"Requested {department_n} departments, but only {len(registry.DEPARTMENTS)} available in {file_path}.")

        return random.sample(registry.DEPARTMENTS, department_n)

    # Fallback (reached only for a falsy, non-None path such as ""): codeless placeholders
    zfill_l = len(str(department_n))
    return [(f"department_{padded_int(i, zfill_l)}", None) for i in range(department_n)]
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
@staticmethod
def name_list_generator(n: int,
                        first_name_file_path: Optional[str] = None,
                        last_name_file_path: Optional[str] = None,
                        prefix: Optional[str] = None) -> list[str]:
    """
    Generate a shuffled list of random names.

    Args:
        n (int): Number of names to generate.
        first_name_file_path (Optional[str], optional): Path to a file containing first names.
            Defaults to the packaged `firstname.txt` asset.
        last_name_file_path (Optional[str], optional): Path to a file containing last names.
            Defaults to the packaged `lastname.txt` asset.
        prefix (Optional[str], optional): Prefix prepended to every generated name (e.g. 'Dr. ').

    Returns:
        list[str]: Shuffled list of names.
    """
    # PEP 8: compare against None with `is` / `is not`
    if first_name_file_path is None:
        first_name_file_path = str(resources.files("h_adminsim.assets.names").joinpath("firstname.txt"))
    if last_name_file_path is None:
        last_name_file_path = str(resources.files("h_adminsim.assets.names").joinpath("lastname.txt"))

    if prefix is not None:
        # NOTE(review): `assert` is stripped under `python -O`; raising TypeError would be sturdier.
        assert isinstance(prefix, str), log("`prefix` must be a string type", "error")
        names = [f'{prefix}{name}' for name in generate_random_names(n, first_name_file_path, last_name_file_path)]
    else:
        # Plain copy — no per-element transformation needed
        names = list(generate_random_names(n, first_name_file_path, last_name_file_path))
    random.shuffle(names)
    return names
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
@staticmethod
|
|
343
|
+
def second_preference_generator(preference: str) -> list[str]:
|
|
344
|
+
"""
|
|
345
|
+
Generate a list of preferences based on the initial preference.
|
|
346
|
+
|
|
347
|
+
Args:
|
|
348
|
+
preference (str): First priority of preference.
|
|
349
|
+
|
|
350
|
+
Returns:
|
|
351
|
+
list[str]: List of preferences including first and second priority.
|
|
352
|
+
"""
|
|
353
|
+
preference_list = [preference]
|
|
354
|
+
|
|
355
|
+
if preference == 'doctor':
|
|
356
|
+
second_preference = random.choice(['asap', 'date'])
|
|
357
|
+
preference_list.append(second_preference)
|
|
358
|
+
elif preference == 'date':
|
|
359
|
+
second_preference = random.choice(['asap', 'doctor'])
|
|
360
|
+
preference_list.append(second_preference)
|
|
361
|
+
elif preference == 'asap':
|
|
362
|
+
second_preference = random.choice(['date', 'doctor'])
|
|
363
|
+
preference_list.append(second_preference)
|
|
364
|
+
|
|
365
|
+
return preference_list
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
from collections import Counter
|
|
4
|
+
|
|
5
|
+
from h_adminsim.utils import log, colorstr
|
|
6
|
+
from h_adminsim.utils.filesys_utils import get_files, json_load
|
|
7
|
+
from h_adminsim.utils.image_preprocess_utils import draw_fail_donut_subplots
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Evaluator:
|
|
12
|
+
def __init__(self, path, human_eval=False):
    """
    Collect evaluation artifact files found under `path`.

    Args:
        path: Root directory to search for evaluation artifacts.
        human_eval (bool, optional): Also collect `.txt` human-evaluation files.
    """
    self.path = path
    self.files = get_files(self.path, '_result.json')
    if human_eval:
        self.human_eval_files = get_files(self.path, '.txt')

    # Dialog files are optional. Narrow the previously bare `except:` and fall
    # back to an empty list so the attribute always exists (avoids AttributeError
    # in calculate_avg_rounds when no dialog files can be collected).
    try:
        self.dialog_files = get_files(self.path, '_dialog.json')
    except Exception:
        self.dialog_files = []
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def task_evaluation(self):
    """
    Perform macro-wise (per-file) and micro-wise (per-case) evaluation on the
    aggregated task results, then render a donut chart of failure types.
    """
    # Aggregate per-task status / status_code lists across all result files
    aggregated_results = dict()
    for file in self.files:
        data = json_load(file)

        for task, value in data.items():
            if not task in aggregated_results:
                aggregated_results[task] = {'status': [], 'status_code': []}

            aggregated_results[task]['status'].append(value['status'])
            aggregated_results[task]['status_code'].append(value['status_code'])

    # Macro-wise evaluation: accuracy per result file, then mean ± stdv across files
    log('--------------Macro-wise Evaluation--------------')
    for task, value in aggregated_results.items():
        # Each status entry is either a bool or a list of bools; normalize both cases
        accuracies = [sum(x if isinstance(x, bool) else sum(x) for x in status) / sum(1 if isinstance(x, bool) else len(x) for x in status) * 100 for status in value['status']]
        avg_accuracy = sum(accuracies) / len(accuracies)
        stdv = round((sum((x - avg_accuracy) ** 2 for x in accuracies) / len(accuracies)) ** 0.5, 2) if len(accuracies) > 1 else 0.0
        log(f'{colorstr(task):<27} | average accuracy: {colorstr("green", f"{avg_accuracy:.2f}% ± {stdv}")}, files: {len(accuracies)}')
        log(f' - Individual accuracies: {", ".join([colorstr("green", f"{acc:.2f}%") for acc in accuracies])}')


    # Micro-wise evaluation: pool every individual case across all files
    log('')
    log('--------------Micro-wise Evaluation--------------')
    fail_data_dict = dict()
    for task, value in aggregated_results.items():
        # Flatten nested [file][case] (and possibly list/tuple per case) into flat lists
        status = [x for y in sum(value['status'], []) for x in (y if isinstance(y, list) or isinstance(y, tuple) else [y])]
        status_code = [x for y in sum(value['status_code'], []) for x in (y if isinstance(y, list) or isinstance(y, tuple) else [y])]
        accuracy = sum(status) / len(status) * 100
        # 'unexpected' failures are excluded from the error breakdown
        failed_cases = [c for s, c in zip(status, status_code) if not s and 'unexpected' not in c]
        error_rate = (len(failed_cases) / len(status)) * 100
        log(f'{colorstr(task):<27} | accuracy: {colorstr("green", f"{accuracy:.2f}%")}, length: {sum(status)} / {len(status)}')
        log(f'{colorstr(task):<27} | Error : {colorstr("red", f"{error_rate:.2f}%")}, length: {len(failed_cases)} / {len(status)}')

        if failed_cases:
            fail_summary = Counter(failed_cases)
            reschedule_fail_summary = Counter()

            # Fold "reschedule:<type>" failures into their base type, tracking the
            # reschedule-specific count separately (iterate over a snapshot since
            # the Counter is mutated during the loop)
            for k, v in list(fail_summary.items()):
                if k.startswith("reschedule:") and 'identify' not in k and 'unexpected' not in k:
                    norm_key = k.replace("reschedule:", "").strip()
                    fail_summary[norm_key] += v
                    reschedule_fail_summary[norm_key] += v
                    fail_summary.pop(k)

            for fail_type, count in fail_summary.items():
                percent = (count / len(failed_cases)) * 100
                reschedule_n = reschedule_fail_summary[fail_type] if fail_type in reschedule_fail_summary else 0
                if reschedule_n:
                    log(f' - Fail type {colorstr("red", fail_type):<30}: {count} (reschedule: {reschedule_n}) cases ({percent:.2f}%)')
                else:
                    log(f' - Fail type {colorstr("red", fail_type):<30}: {count} cases ({percent:.2f}%)')
            fail_data_dict[task] = failed_cases

    draw_fail_donut_subplots(fail_data_dict, os.path.join(self.path, 'fails.png'))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def ipi_evaluation(self):
    """
    Micro-wise IPI performance evaluation on the aggregated intake results.
    """
    # Gather the intake status / status-code lists from every result file
    aggregated_results = dict()
    for result_file in self.files:
        payload = json_load(result_file)
        bucket = aggregated_results.setdefault('intake', {'status': [], 'status_code': []})
        bucket['status'].append(payload['intake']['status'])
        bucket['status_code'].append(payload['intake']['status_code'])

    # Micro-wise evaluation over the flattened case lists
    log('')
    log('------------------IPI Evaluation-----------------')
    status = sum(aggregated_results['intake']['status'], [])
    status_code = sum(aggregated_results['intake']['status_code'], [])
    failed_cases = [code for ok, code in zip(status, status_code) if not ok]

    if failed_cases:
        fail_summary = Counter(failed_cases)
        ipi_types = ('incorrect department and patient information', 'incorrect patient information')
        ipi_err_count = sum(n for t, n in fail_summary.items() if t in ipi_types)
        if_err_count = sum(n for t, n in fail_summary.items() if t == 'incorrect format')

        if_percent = (if_err_count / len(status)) * 100
        ipi_percent = (ipi_err_count / len(status)) * 100
        log(f' - Fail type {colorstr("red", "incorrect format"):<38}: {if_err_count} / {len(status)} ({if_percent:.2f}%)')
        log(f' - Fail type {colorstr("red", "incorrect patient information"):<38}: {ipi_err_count} / {len(status)} ({ipi_percent:.2f}%)')
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def supervisor_evaluation(self):
    """
    Evaluate the supervisor's necessity to intervene in tasks.

    Intake tasks count trials whose first entry mentions 'mismatch' (supervisor
    intervened) and bucket them as better/worse/tie; schedule tasks count trials
    with more than one entry (feedback rounds happened) and bucket by final verdict.
    """
    # Aggregate per-task status / trial lists across all result files
    aggregated_results = dict()
    for file in self.files:
        data = json_load(file)

        for task, value in data.items():
            if not task in aggregated_results:
                aggregated_results[task] = {'status': [], 'trial': []}

            aggregated_results[task]['status'].append(value['status'])
            aggregated_results[task]['trial'].append(value['trial'])

    log('-----Supervisor (or feedback) Evaluation----')
    for task, value in aggregated_results.items():
        status = sum(value['status'], [])
        trial = sum(value['trial'], [])

        if task == 'intake':
            total_length = len(status)
            supervisor_effect_cnt, correct, error, tie = 0, 0, 0, 0
            for t in trial:
                # NOTE(review): outcome keywords are matched as substrings of the
                # first trial entry — confirm against the result-file schema.
                if 'mismatch' in t[0]:
                    supervisor_effect_cnt += 1
                    if 'better' in t[0]:
                        correct += 1
                    elif 'worse' in t[0]:
                        error += 1
                    else:
                        tie += 1

            # Percentages are relative to the intervened cases only
            correct_p = correct/supervisor_effect_cnt*100 if supervisor_effect_cnt > 0 else 0
            error_p = error/supervisor_effect_cnt*100 if supervisor_effect_cnt > 0 else 0
            tie_p = tie/supervisor_effect_cnt*100 if supervisor_effect_cnt > 0 else 0
            log(f'{colorstr(task):<27} | length: {total_length}, effected: {supervisor_effect_cnt} ({(supervisor_effect_cnt/total_length)*100:.2f}%)')
            log(f' - {colorstr("green", "correct")}: {correct} ({correct_p:.2f}%), {colorstr("red", "worse")}: {error} ({error_p:.2f}%), {colorstr("yellow", "tie")}: {tie} ({tie_p:.2f}%)')

        elif task == 'schedule':
            feedback_n = dict()
            total_length = len(status)
            supervisor_effect_cnt, correct, tie = 0, 0, 0
            for t in trial:
                # More than one entry means at least one feedback round occurred
                if isinstance(t, list) and len(t) > 1:
                    supervisor_effect_cnt += 1
                    if t[-1] == 'pass':
                        correct += 1
                        # Histogram of how many feedback rounds were needed
                        feedback_n[len(t)-1] = feedback_n.setdefault(len(t)-1, 0) + 1
                    else:
                        tie += 1

            desc = ', '.join([f'{f}-feedback: {n}' for f, n in sorted(feedback_n.items())])
            correct_p = correct/supervisor_effect_cnt*100 if supervisor_effect_cnt > 0 else 0
            tie_p = tie/supervisor_effect_cnt*100 if supervisor_effect_cnt > 0 else 0
            log(f'{colorstr(task):<27} | length: {total_length}, effected: {supervisor_effect_cnt} ({(supervisor_effect_cnt/total_length)*100:.2f}%)')
            log(f' - {colorstr("green", "correct")}: {correct} ({correct_p:.2f}%), {colorstr("yellow", "tie")}: {tie} ({tie_p:.2f}%)')
            log(f' - Feedback distribution: {desc}')
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def human_evaluation(self):
    """
    Aggregate and evaluate human evaluation results from text files.

    Each non-empty line is expected to be tab-separated:
    arena_verdict, score_a, score_b, model_a, model_b.
    """
    scores = {'arena': dict(), 'score': dict()}
    all_lines = list()
    for file in self.human_eval_files:
        with open(file, 'r') as f:
            lines = f.readlines()
            # Drop blank lines and surrounding whitespace
            all_lines.extend([line.strip() for line in lines if line.strip()])

    for line in all_lines:
        arena, score_a, score_b, model_a, model_b = line.split('\t')
        # Ensure both models have entries before tallying
        scores['arena'].setdefault(model_a, 0)
        scores['arena'].setdefault(model_b, 0)
        scores['score'].setdefault(model_a, [])
        scores['score'].setdefault(model_b, [])

        # NOTE(review): any verdict other than 'A' (including a possible tie marker)
        # credits model B — confirm this matches the annotation guidelines.
        if arena == 'A':
            scores['arena'][model_a] += 1
        else:
            scores['arena'][model_b] += 1

        scores['score'][model_a].append(float(score_a))
        scores['score'][model_b].append(float(score_b))

    log('--------------Human Evaluation--------------')
    for model in scores['arena'].keys():
        arena_wins = scores['arena'][model]
        score_list = scores['score'][model]
        avg_score = sum(score_list) / len(score_list)
        # Population standard deviation; 0.0 when only a single score exists
        stdv = round((sum((x - avg_score) ** 2 for x in score_list) / len(score_list)) ** 0.5, 2) if len(score_list) > 1 else 0.0
        log(f'{colorstr(model):<15} | Arena wins: {colorstr("green", str(arena_wins))}, Average score: {colorstr("green", f"{avg_score:.2f} ± {stdv}")}')
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def department_evaluation(self):
    """
    Evaluate solely department prediction accuracy.

    A failed intake case counts as a department error only when the top predicted
    department is absent from the ground-truth department list.
    """
    gt, pred, status = [], [], []
    for result_file in self.files:
        intake = json_load(result_file)['intake']
        gt.extend(intake['gt'])
        pred.extend(intake['pred'])
        status.extend(intake['status'])

    total_n = len(gt)
    dept_err_n = sum(
        1 for g, p, ok in zip(gt, pred, status)
        if not ok and p['department'][0] not in g['department']
    )

    log('--------------Department Evaluation--------------')
    log(f'Error rate: {colorstr("red", f"{(dept_err_n/total_n)*100:.2f}%")}, length: {dept_err_n} / {total_n}')
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def calculate_avg_rounds(self):
    """
    Calculate average required intake rounds across all dialog files.

    A dialog's round count is its number of 'Staff: ' turns minus one
    (presumably excluding the opening staff turn — TODO confirm).
    """
    counts = [
        dialog.count('Staff: ') - 1
        for dialog_file in self.dialog_files
        for dialog in json_load(dialog_file).values()
    ]
    mean, stdv = np.mean(counts), np.std(counts)
    log('-----------------Average Rounds-----------------')
    log(f'Average Rounds: {mean:.2f} ± {stdv:.2f}')
|
|
258
|
+
|