h-adminsim 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- h_adminsim/__init__.py +5 -0
- h_adminsim/admin_staff.py +280 -0
- h_adminsim/assets/configs/data4primary.yaml +47 -0
- h_adminsim/assets/configs/data4secondary.yaml +47 -0
- h_adminsim/assets/configs/data4tertiary.yaml +47 -0
- h_adminsim/assets/country/address.json +141859 -0
- h_adminsim/assets/country/country_code.json +244 -0
- h_adminsim/assets/departments/department.json +85 -0
- h_adminsim/assets/departments/symptom.json +4530 -0
- h_adminsim/assets/fhir.schema.json +75253 -0
- h_adminsim/assets/names/firstname.txt +1219 -0
- h_adminsim/assets/names/lastname.txt +88799 -0
- h_adminsim/assets/prompts/cancel_patient_system.txt +38 -0
- h_adminsim/assets/prompts/intake_staff_task_user.txt +16 -0
- h_adminsim/assets/prompts/intake_supervisor_system.txt +8 -0
- h_adminsim/assets/prompts/intake_supervisor_user.txt +31 -0
- h_adminsim/assets/prompts/reschedule_patient_system.txt +38 -0
- h_adminsim/assets/prompts/schedule_patient_rejected_system.txt +42 -0
- h_adminsim/assets/prompts/schedule_patient_system.txt +36 -0
- h_adminsim/assets/prompts/schedule_staff_reasoning.txt +57 -0
- h_adminsim/assets/prompts/schedule_staff_sc_tool_calling.txt +13 -0
- h_adminsim/assets/prompts/schedule_staff_system.txt +10 -0
- h_adminsim/assets/prompts/schedule_staff_tool_calling.txt +41 -0
- h_adminsim/client/__init__.py +3 -0
- h_adminsim/client/google_client.py +209 -0
- h_adminsim/client/openai_client.py +199 -0
- h_adminsim/client/vllm_client.py +160 -0
- h_adminsim/environment/__init__.py +1 -0
- h_adminsim/environment/hospital.py +462 -0
- h_adminsim/environment/op_scheduling_simulation.py +1126 -0
- h_adminsim/pipeline/__init__.py +3 -0
- h_adminsim/pipeline/data_generator.py +192 -0
- h_adminsim/pipeline/evaluator.py +33 -0
- h_adminsim/pipeline/simulation.py +231 -0
- h_adminsim/registry/__init__.py +5 -0
- h_adminsim/registry/errors.py +89 -0
- h_adminsim/registry/models.py +126 -0
- h_adminsim/registry/phrases.py +10 -0
- h_adminsim/registry/pydantic_models.py +21 -0
- h_adminsim/registry/variables.py +9 -0
- h_adminsim/supervisor.py +182 -0
- h_adminsim/task/agent_task.py +900 -0
- h_adminsim/task/fhir_manager.py +222 -0
- h_adminsim/task/schedule_assign.py +151 -0
- h_adminsim/tools/__init__.py +5 -0
- h_adminsim/tools/agent_data_builder.py +124 -0
- h_adminsim/tools/data_converter.py +536 -0
- h_adminsim/tools/data_synthesizer.py +365 -0
- h_adminsim/tools/evaluator.py +258 -0
- h_adminsim/tools/sanity_checker.py +216 -0
- h_adminsim/tools/scheduling_rule.py +420 -0
- h_adminsim/utils/__init__.py +136 -0
- h_adminsim/utils/common_utils.py +698 -0
- h_adminsim/utils/fhir_utils.py +190 -0
- h_adminsim/utils/filesys_utils.py +135 -0
- h_adminsim/utils/image_preprocess_utils.py +188 -0
- h_adminsim/utils/random_utils.py +358 -0
- h_adminsim/version.txt +1 -0
- h_adminsim-1.0.0.dist-info/LICENSE +30 -0
- h_adminsim-1.0.0.dist-info/METADATA +494 -0
- h_adminsim-1.0.0.dist-info/RECORD +62 -0
- h_adminsim-1.0.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,900 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import random
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from decimal import getcontext
|
|
6
|
+
from importlib import resources
|
|
7
|
+
from typing import Tuple, Union, Optional
|
|
8
|
+
from dotenv import load_dotenv, find_dotenv
|
|
9
|
+
|
|
10
|
+
from patientsim import PatientAgent
|
|
11
|
+
from patientsim import AdminStaffAgent as IntakeAdminStaffAgent
|
|
12
|
+
from patientsim.environment import OPSimulation as OPFVIntakeSimulation
|
|
13
|
+
|
|
14
|
+
from h_adminsim import SupervisorAgent
|
|
15
|
+
from h_adminsim import AdminStaffAgent as SchedulingAdminStaffAgent
|
|
16
|
+
from h_adminsim.environment.hospital import HospitalEnvironment
|
|
17
|
+
from h_adminsim.environment import OPScehdulingSimulation as OPFVScheduleSimulation
|
|
18
|
+
from h_adminsim.tools.sanity_checker import SanityChecker
|
|
19
|
+
from h_adminsim.tools import DataConverter, SchedulingRule
|
|
20
|
+
from h_adminsim.registry import STATUS_CODES, PREFERENCE_PHRASE_PATIENT
|
|
21
|
+
from h_adminsim.utils import colorstr, log
|
|
22
|
+
from h_adminsim.utils.fhir_utils import *
|
|
23
|
+
from h_adminsim.utils.common_utils import *
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FirstVisitOutpatientTask:
    """
    Base class for first-visit outpatient tasks.

    Accumulates per-agent API token usage and decides, per model name, whether the
    model is called through a hosted API (Gemini/GPT) or through a vLLM endpoint.
    """

    # Substrings (matched case-insensitively) that mark a model as hosted-API served.
    # NOTE(review): open-weight names containing "gpt" (e.g. "gpt-oss") also match and
    # will not be routed to vLLM — confirm this is intended.
    _API_MODEL_KEYWORDS = ('gemini', 'gpt')

    def __init__(self):
        # One bucket per agent role; each holds per-call token counts.
        self.token_stats = {
            'patient_token': {'input': [], 'output': [], 'reasoning': []},
            'admin_staff_token': {'input': [], 'output': [], 'reasoning': []},
            'supervisor_token': {'input': [], 'output': [], 'reasoning': []},
        }


    def save_token_data(self,
                        patient_token: Optional[dict] = None,
                        admin_staff_token: Optional[dict] = None,
                        supervisor_token: Optional[dict] = None):
        """
        Save the API token usage data.

        Each usage dict must provide list-valued ``prompt_tokens`` and ``completion_tokens``
        and may provide ``reasoning_tokens``; the lists are appended to ``self.token_stats``.
        Falsy arguments (``None`` or ``{}``) are skipped.

        Args:
            patient_token (Optional[dict], optional): Patient token information. Defaults to None.
            admin_staff_token (Optional[dict], optional): Administration staff token information. Defaults to None.
            supervisor_token (Optional[dict], optional): Supervisor token information. Defaults to None.
        """
        usages = (
            ('patient_token', patient_token),
            ('admin_staff_token', admin_staff_token),
            ('supervisor_token', supervisor_token),
        )
        for role, usage in usages:
            if not usage:
                continue
            stats = self.token_stats[role]
            stats['input'].extend(usage['prompt_tokens'])
            stats['output'].extend(usage['completion_tokens'])
            if 'reasoning_tokens' in usage:
                stats['reasoning'].extend(usage['reasoning_tokens'])


    def _init_task_models(self, model: str, vllm_endpoint: Optional[str] = None) -> Tuple[str, Optional[str], bool]:
        """
        Initialize the model for the task.

        Args:
            model (str): The model name.
            vllm_endpoint (Optional[str], optional): The VLLM endpoint URL. Defaults to None.

        Returns:
            Tuple[str, Optional[str], bool]: The model name, the VLLM endpoint URL (``None`` for
                Gemini/GPT models, whose endpoint argument is ignored), and the vLLM usage flag.

        Raises:
            ValueError: If a non-Gemini/GPT model is given without a vLLM endpoint.
        """
        lowered = model.lower()
        if any(keyword in lowered for keyword in self._API_MODEL_KEYWORDS):
            return model, None, False

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if vllm_endpoint is None:
            raise ValueError('VLLM endpoint must be provided for non-Gemini/GPT models.')
        return model, vllm_endpoint, True
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class OutpatientFirstIntake(FirstVisitOutpatientTask):
    """
    First-visit outpatient intake task.

    A patient agent converses with an intake administration staff agent, whose final
    message is parsed as the department decision. Patient demographics plus a second
    department opinion are then produced either by a supervisor agent (when one is
    supplied) or by the staff agent answering a final user prompt. The combined
    prediction is scored with ``SanityChecker.intake_check``.
    """

    # Exact key set a parsed patient-information record must have.
    _REQUIRED_INFO_KEYS = frozenset({'name', 'gender', 'phone_number', 'personal_id', 'address', 'department'})

    def __init__(self,
                 patient_model: str,
                 admin_staff_model: str,
                 supervisor_agent: Optional[SupervisorAgent] = None,
                 intake_max_inference: int = 5,
                 max_retries: int = 8,
                 admin_staff_last_task_user_prompt_path: Optional[str] = None,
                 patient_vllm_endpoint: Optional[str] = None,
                 admin_staff_vllm_endpoint: Optional[str] = None):
        """
        Args:
            patient_model (str): Model name for the patient agent.
            admin_staff_model (str): Model name for the intake staff agent.
            supervisor_agent (Optional[SupervisorAgent], optional): When a ``SupervisorAgent`` instance is
                given, it performs the final information extraction; otherwise the staff agent does. Defaults to None.
            intake_max_inference (int, optional): Passed as ``max_inferences`` to the staff agent and the
                intake simulation. Defaults to 5.
            max_retries (int, optional): Passed to ``run_with_retry`` for every LLM call. Defaults to 8.
            admin_staff_last_task_user_prompt_path (Optional[str], optional): Custom user prompt file for the
                staff agent's extraction step; ignored (with a warning) when a supervisor is used. Defaults to None.
            patient_vllm_endpoint (Optional[str], optional): vLLM endpoint for a non-Gemini/GPT patient model.
            admin_staff_vllm_endpoint (Optional[str], optional): vLLM endpoint for a non-Gemini/GPT staff model.
        """
        super().__init__()

        # Initialize variables
        self.name = 'intake'
        self.patient_model, self.patient_vllm_endpoint, self.patient_use_vllm \
            = self._init_task_models(patient_model, patient_vllm_endpoint)
        self.admin_staff_model, self.admin_staff_vllm_endpoint, self.admin_staff_use_vllm \
            = self._init_task_models(admin_staff_model, admin_staff_vllm_endpoint)
        self.use_supervisor = isinstance(supervisor_agent, SupervisorAgent)
        self.supervisor_client = supervisor_agent if self.use_supervisor else None
        task_mechanism = 'Staff + Supervisor' if self.use_supervisor else 'Staff'
        self.max_inferences = intake_max_inference
        self.max_retries = max_retries
        self._init_last_task_prompt(admin_staff_last_task_user_prompt_path)
        # gpt-5 models get a low reasoning effort; other models get no extra kwargs.
        self.patient_reasoning_kwargs = {'reasoning_effort': 'low'} if 'gpt-5' in self.patient_model.lower() else {}
        self.staff_reasoning_kwargs = {'reasoning_effort': 'low'} if 'gpt-5' in self.admin_staff_model.lower() else {}
        log(f'Patient intake tasks are conducted by {colorstr(task_mechanism)}')


    def _init_last_task_prompt(self, admin_staff_last_task_user_prompt_path: Optional[str] = None) -> None:
        """
        Initialize ``self.last_task_user_prompt``, the user prompt for the administration staff agent's last task.

        In supervisor mode the supervisor performs the extraction, so no prompt is loaded and
        ``self.last_task_user_prompt`` is set to ``None``.

        Args:
            admin_staff_last_task_user_prompt_path (Optional[str], optional): Path to a custom user prompt file.
                If not provided, the packaged default user prompt is used. Defaults to None.

        Raises:
            FileNotFoundError: If the specified user prompt file does not exist.
        """
        if self.use_supervisor:
            if admin_staff_last_task_user_prompt_path:
                log('The admin_staff_last_task_user_prompt_path setting is ignored when using supervisor model.', 'warning')
            self.last_task_user_prompt = None
            return

        if not admin_staff_last_task_user_prompt_path:
            prompt_file_name = 'intake_staff_task_user.txt'
            file_path = resources.files("h_adminsim.assets.prompts").joinpath(prompt_file_name)
            self.last_task_user_prompt = file_path.read_text(encoding='utf-8')
            return

        if not os.path.isfile(admin_staff_last_task_user_prompt_path):
            raise FileNotFoundError(colorstr("red", f"User prompt file not found: {admin_staff_last_task_user_prompt_path}"))
        with open(admin_staff_last_task_user_prompt_path, 'r', encoding='utf-8') as f:
            self.last_task_user_prompt = f.read()


    @staticmethod
    def postprocessing_department(text: str) -> str:
        """
        Post-processing method of text output, especially for the department decision.

        Extracts the text after ``Answer: <number>.`` up to the end of that line.

        Args:
            text (str): Text input.

        Returns:
            str: The extracted department, or ``'wrong'`` if ``text`` is not a string or has no answer line.
        """
        import re

        if not isinstance(text, str):
            return 'wrong'
        match = re.search(r'Answer:\s*\d+\.\s*(.+)', text)
        return match.group(1) if match else 'wrong'


    @staticmethod
    def postprocessing_information(text: str) -> Union[str, dict]:
        """
        Post-processing method of text output, especially for the patient information extraction.

        A string input is parsed as JSON, preferring the contents of a ```json fenced block if present.
        Non-string input is validated as-is.

        Args:
            text (str): Text input.

        Returns:
            Union[str, dict]: The parsed dictionary if it is a dict whose keys are exactly
                name, gender, phone_number, personal_id, address and department; otherwise ``str(text)``.
        """
        import re

        if isinstance(text, str):
            match = re.search(r'```json\s*(\{.*?\})\s*```', text, re.DOTALL)
            json_str = match.group(1) if match else text
            try:
                text_dict = json.loads(json_str)
            except (ValueError, RecursionError):
                return text
        else:
            text_dict = text

        # Basic sanity check; explicit so it is not stripped under `python -O`.
        if isinstance(text_dict, dict) and text_dict.keys() == OutpatientFirstIntake._REQUIRED_INFO_KEYS:
            return text_dict
        return str(text)


    def _department_decision(self, prediction_department: str, prediction_supervison: Union[str, dict], gt_department: str) -> Tuple[str, list[str]]:
        """
        Determine the final department decision by considering both
        the interaction agent result and the supervisor agent result.

        Note: when ``prediction_supervison`` is a dict, its ``'department'`` entry is popped
        (the dict is mutated) so that only patient fields remain.

        Args:
            prediction_department (str): The department predicted by the interaction agent.
            prediction_supervison (Union[str, dict]): The supervisor agent's result.
                If this is a dictionary, it should contain a 'department' field.
            gt_department (str): The ground truth department(s) for the patient, tested with ``in``.

        Returns:
            Tuple[str, list[str]]: The final department (the supervisor's when available, otherwise
                ``prediction_department``) and a one-element list describing how the two predictions compared.
        """
        try:
            sup_department = prediction_supervison.pop('department')
            if prediction_department == sup_department:
                trial = ['match']
            elif prediction_department in gt_department and sup_department not in gt_department:
                trial = ['mismatch - worse']
            elif prediction_department not in gt_department and sup_department in gt_department:
                trial = ['mismatch - better']
            else:
                trial = ['mismatch - both wrong']
            return sup_department, trial
        except (AttributeError, KeyError, TypeError):
            # Supervisor output was not a usable dict (e.g. an unparsed string).
            return prediction_department, ['supervisor error']


    @staticmethod
    def _patient_background(constraint: dict) -> Tuple[str, str]:
        """
        Build the patient's medical history and diagnosis strings from the task constraint.

        Args:
            constraint (dict): The ``test_data['constraint']`` mapping. Reads ``symptom_level`` and,
                for ``'with_history'``, ``symptom['disease']``.

        Returns:
            Tuple[str, str]: ``(medical_history, diagnosis)``.

        Raises:
            ValueError: If ``symptom_level`` is neither ``'simple'`` nor ``'with_history'``.
        """
        symptom_level = constraint['symptom_level']
        if symptom_level == 'simple':
            return ("None. This is the patient's first visit.",
                    "Unknown for now, as this is the patient's first visit to the hospital.")
        if symptom_level == 'with_history':
            disease = constraint['symptom']['disease']
            return f"Diagnosed with {disease} at a primary or secondary hospital.", disease
        # Fail fast: continuing would leave the history/diagnosis undefined.
        raise ValueError("Patient's symptom level must be either 'simple' or 'with_history'.")


    def __call__(self, data_pair: Tuple[dict, dict], agent_test_data: dict, agent_results: dict, environment, verbose: bool = False) -> dict:
        """
        Estimates the most appropriate medical department for each patient using an LLM agent.

        Args:
            data_pair (Tuple[dict, dict]): A pair of ground truth and patient data for agent simulation.
            agent_test_data (dict): A dictionary containing test data for a single hospital.
                Expected to include:
                    - 'department': Dictionary of available departments.
            agent_results (dict): Placeholder for compatibility; not used in this method.
            environment (HospitalEnvironment): Placeholder for compatibility; not used in this method.
            verbose (bool): Whether logging the each result or not.

        Returns:
            dict: The dictionary from ``init_result_dict()``, with one entry appended to each of:
                - 'gt': Ground-truth patient information and departments.
                - 'pred': Predicted patient information and department.
                - 'status': Whether the prediction passed the sanity check.
                - 'status_code': Status code explaining the status.
                - 'trial': How the staff and supervisor department predictions compared.
                - 'dialog': The preprocessed intake conversation.

        Raises:
            ValueError: If the patient's ``symptom_level`` is neither ``'simple'`` nor ``'with_history'``.
        """
        gt, test_data = data_pair
        departments = list(agent_test_data['department'].keys())
        results = init_result_dict()
        sanity_checker = SanityChecker()

        # Append a ground truth
        name, gender, birth_date = gt['patient'], gt['gender'], gt['birthDate']
        telecom = gt['telecom'][0]['value']
        personal_id = gt['identifier'][0]['value']
        address = gt['address'][0]['text']
        gt_data = {
            'patient': {
                'name': name,
                'gender': gender,
                'phone_number': telecom,
                'personal_id': personal_id,
                'address': address,
            },
            'department': gt['department']
        }
        results['gt'].append(gt_data)

        # LLM call: Conversation and department decision
        constraint = test_data['constraint']
        department_candidates = constraint['symptom']['department']
        medical_history, diagnosis = self._patient_background(constraint)

        # Simulation patient intake
        patient_agent = PatientAgent(
            self.patient_model,
            'outpatient',
            lang_proficiency_level='B',
            recall_level='no_history' if constraint['symptom_level'] == 'simple' else 'high',
            use_vllm=self.patient_use_vllm,
            vllm_endpoint=self.patient_vllm_endpoint,
            department=department_candidates,
            name=name,
            birth_date=birth_date,
            gender=gender,
            telecom=telecom,
            personal_id=personal_id,
            address=address,
            medical_history=medical_history,
            diagnosis=diagnosis,
            chiefcomplaint=constraint['symptom']['symptom'],
            temperature=0 if 'gpt-5' not in self.patient_model.lower() else 1
        )
        admin_staff_agent = IntakeAdminStaffAgent(
            self.admin_staff_model,
            departments,
            max_inferences=self.max_inferences,
            use_vllm=self.admin_staff_use_vllm,
            vllm_endpoint=self.admin_staff_vllm_endpoint,
            temperature=0 if 'gpt-5' not in self.admin_staff_model.lower() else 1
        )
        sim_environment = OPFVIntakeSimulation(patient_agent, admin_staff_agent, max_inferences=self.max_inferences)
        output = run_with_retry(
            sim_environment.simulate,
            verbose=False,
            patient_kwargs=self.patient_reasoning_kwargs,
            staff_kwargs=self.staff_reasoning_kwargs,
            max_retries=self.max_retries,
        )
        dialogs, patient_token, admin_staff_token = output['dialog_history'], output.get('patient_token_usage'), output.get('admin_staff_token_usage')
        # The staff agent's final turn carries the department answer.
        prediction_department = OutpatientFirstIntake.postprocessing_department(dialogs[-1]['content'])

        # LLM call: Agent which should extract demographic information of the patient and evaluate the department decision result
        dialogs = preprocess_dialog(dialogs)

        if self.use_supervisor:
            user_prompt = self.supervisor_client.user_prompt_template.format(
                CONVERSATION=dialogs,
                DEPARTMENTS=''.join([f'{i+1}. {department}\n' for i, department in enumerate(departments)])
            )
            prediction_supervision = run_with_retry(
                self.supervisor_client,
                user_prompt,
                using_multi_turn=False,
                verbose=False,
                max_retries=self.max_retries,
            )
        else:
            prediction_supervision = run_with_retry(
                admin_staff_agent,
                self.last_task_user_prompt,
                verbose=False,
                max_retries=self.max_retries,
                **self.staff_reasoning_kwargs,
            )

        prediction_supervision = OutpatientFirstIntake.postprocessing_information(prediction_supervision)

        # Append token data
        # NOTE(review): if `token_usages` accumulates across calls on the same supervisor client,
        # earlier usage is re-added for every patient — confirm it is reset per call.
        self.save_token_data(
            patient_token,
            admin_staff_token,
            supervisor_token=self.supervisor_client.client.token_usages if self.use_supervisor else {}
        )

        # Sanity check (pops 'department' out of a dict-valued prediction_supervision)
        department, trial = self._department_decision(prediction_department, prediction_supervision, gt['department'])
        prediction = {'patient': prediction_supervision, 'department': [department]}
        status, status_code = sanity_checker.intake_check(
            prediction=prediction,
            gt=gt_data,
            conversations=dialogs
        )

        if verbose:
            log(f'GT    : {gt_data}')
            log(f'Pred  : {prediction}')
            log(f'Status: {status_code}\n\n\n')

        # Append results
        results['pred'].append(prediction)
        results['status'].append(status)
        results['status_code'].append(status_code)
        results['trial'].append(trial)
        results['dialog'].append(dialogs)

        return results
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
class OutpatientFirstScheduling(FirstVisitOutpatientTask):
|
|
373
|
+
def __init__(self,
             patient_model: str,
             admin_staff_model: str,
             schedule_cancellation_prob: float = 0.05,
             request_early_schedule_prob: float = 0.1,
             preference_rejection_prob: float = 0.3,
             preference_rejection_prob_decay: float = 0.5,
             fhir_integration: bool = False,
             scheduling_max_inference: int = 5,
             scheduling_strategy: str = 'tool_calling',
             max_retries: int = 8,
             patient_vllm_endpoint: Optional[str] = None,
             admin_staff_vllm_endpoint: Optional[str] = None):
    """
    Configure the first-visit outpatient scheduling task.

    Args:
        patient_model (str): Model name for the simulated patient agent.
        admin_staff_model (str): Model name for the scheduling administration staff agent.
        schedule_cancellation_prob (float, optional): Stored as ``self.schedule_cancellation_prob``. Defaults to 0.05.
        request_early_schedule_prob (float, optional): Stored as ``self.request_early_schedule_prob``. Defaults to 0.1.
        preference_rejection_prob (float, optional): Passed on to the scheduling simulation. Defaults to 0.3.
        preference_rejection_prob_decay (float, optional): Passed on to the scheduling simulation. Defaults to 0.5.
        fhir_integration (bool, optional): Passed on to the scheduling simulation. Defaults to False.
        scheduling_max_inference (int, optional): Stored as ``self.max_inferences``. Defaults to 5.
        scheduling_strategy (str, optional): Either ``'reasoning'`` or ``'tool_calling'``. Defaults to ``'tool_calling'``.
        max_retries (int, optional): Retry budget for LLM calls. Defaults to 8.
        patient_vllm_endpoint (Optional[str], optional): vLLM endpoint for a non-Gemini/GPT patient model.
        admin_staff_vllm_endpoint (Optional[str], optional): vLLM endpoint for a non-Gemini/GPT staff model.

    Raises:
        ValueError: If ``scheduling_strategy`` is not ``'reasoning'`` or ``'tool_calling'``.
    """
    super().__init__()

    # Reject an unknown strategy before any agent/client is constructed.
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if scheduling_strategy not in ('reasoning', 'tool_calling'):
        raise ValueError('Scheduling strategy must be either `reasoning` or `tool_calling`.')

    # Initialize variables: decimal precision on the current context, and .env values (overriding existing ones)
    getcontext().prec = 10
    load_dotenv(find_dotenv(usecwd=True), override=True)
    self.name = 'schedule'
    self.patient_model, self.patient_vllm_endpoint, self.patient_use_vllm \
        = self._init_task_models(patient_model, patient_vllm_endpoint)
    self.admin_staff_model, self.admin_staff_vllm_endpoint, self.admin_staff_use_vllm \
        = self._init_task_models(admin_staff_model, admin_staff_vllm_endpoint)

    # Initialize scheduling methods and a staff agent
    self.admin_staff_agent = SchedulingAdminStaffAgent(
        target_task='first_outpatient_scheduling',
        model=self.admin_staff_model,
        use_vllm=self.admin_staff_use_vllm,
        vllm_endpoint=self.admin_staff_vllm_endpoint,
        temperature=0 if 'gpt-5' not in self.admin_staff_model.lower() else 1
    )

    # Scheduling parameters
    self.schedule_cancellation_prob = schedule_cancellation_prob
    self.request_early_schedule_prob = request_early_schedule_prob
    self.preference_rejection_prob = preference_rejection_prob
    self.preference_rejection_prob_decay = preference_rejection_prob_decay

    # Others
    self.fhir_integration = fhir_integration
    self.max_retries = max_retries
    self.max_inferences = scheduling_max_inference
    self.scheduling_strategy = scheduling_strategy

    # Packaged patient system prompts for each request type
    prompt_dir = resources.files("h_adminsim.assets.prompts")
    self.schedule_patient_system_prompt_path = str(prompt_dir.joinpath('schedule_patient_system.txt'))
    self.cancel_patient_system_prompt_path = str(prompt_dir.joinpath('cancel_patient_system.txt'))
    self.reschedule_patient_system_prompt_path = str(prompt_dir.joinpath('reschedule_patient_system.txt'))

    # gpt-5 models get a low reasoning effort; other models get no extra kwargs.
    self.patient_reasoning_kwargs = {'reasoning_effort': 'low'} if 'gpt-5' in self.patient_model.lower() else {}
    self.staff_reasoning_kwargs = {'reasoning_effort': 'low'} if 'gpt-5' in self.admin_staff_model.lower() else {}
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _init_simulation(self,
                     system_prompt_path: str,
                     environment: HospitalEnvironment,
                     additional_patient_conditions: Optional[dict] = None) -> OPFVScheduleSimulation:
    """
    Initialize an outpatient first-visit scheduling simulation.

    A fresh patient agent is built for every call; the staff agent (``self.admin_staff_agent``)
    is shared. ``self._metadata``, ``self._department_data`` and ``self.sanity_checker`` are read
    here but are not set in ``__init__`` — presumably assigned elsewhere before this runs; verify.

    Args:
        system_prompt_path (str): Path to the system prompt used to initialize the patient agent.
        environment (HospitalEnvironment): Hospital environment configuration for the simulation.
        additional_patient_conditions (Optional[dict], optional): Additional patient-specific conditions
            for simulation control. ``None`` (the default) means no extra conditions.

    Returns:
        OPFVScheduleSimulation: Configured outpatient scheduling simulation instance.
    """
    # Avoid a shared mutable default: a fresh dict per call.
    if additional_patient_conditions is None:
        additional_patient_conditions = {}

    patient_agent = PatientAgent(
        self.patient_model,
        'outpatient',
        use_vllm=self.patient_use_vllm,
        vllm_endpoint=self.patient_vllm_endpoint,
        system_prompt_path=system_prompt_path,
        log_verbose=False,
        additional_patient_conditions=additional_patient_conditions,
        temperature=0 if 'gpt-5' not in self.patient_model.lower() else 1
    )
    sim_environment = OPFVScheduleSimulation(
        patient_agent=patient_agent,
        admin_staff_agent=self.admin_staff_agent,
        metadata=self._metadata,
        department_data=self._department_data,
        environment=environment,
        scheduling_strategy=self.scheduling_strategy,
        preference_rejection_prob=self.preference_rejection_prob,
        preference_rejection_prob_decay=self.preference_rejection_prob_decay,
        fhir_integration=self.fhir_integration,
        sanity_checker=self.sanity_checker,
    )
    return sim_environment
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
def get_intake_information(self, gt: dict, agent_results: dict, doctor_information: dict) -> Tuple[dict, str, bool]:
    """
    Extracts the patient information and predicted department from intake agent results.
    If predictions are not available for this patient, falls back to ground truth labels.

    Args:
        gt (dict): Ground truth data of a patient.
        agent_results (dict): A dictionary that may contain intake results under the key 'intake',
            with parallel 'gt', 'pred' and 'status' lists.
        doctor_information (dict): Dictionary of doctor data including their existing schedules.
            Each key is a doctor's name, and each value includes a 'department' field.

    Returns:
        Tuple[dict, str, bool]: Patient information, the determined department (predicted or
            ground truth), and its sanity status (always True for the ground-truth fallback).

    Raises:
        ValueError: If, in the fallback path, the attending physician's department is not one
            of the ground-truth departments.
    """
    # Prediction results are existing case: look up this patient's record by exact name match.
    # Explicit lookup instead of assert-after-loop, so a missing patient can never pick up
    # another patient's prediction (asserts are stripped under `python -O`).
    try:
        intake = agent_results['intake']
        match_idx = next(
            (i for i, intake_gt in enumerate(intake['gt']) if gt['patient'] == intake_gt['patient']['name']),
            None,
        )
        if match_idx is not None:
            prediction = intake['pred'][match_idx]
            return prediction['patient'], prediction['department'][0], intake['status'][match_idx]
    except (KeyError, IndexError, TypeError):
        # Missing or malformed intake results: fall back to the ground truth below.
        pass

    # Loading from the ground truth
    log('The predicted department is not given. The ground truth value will be used.', 'warning')
    patient_info = {
        'name': gt['patient'],
        'gender': gt['gender'],
        'phone_number': gt['telecom'][0]['value'],
        'personal_id': gt['identifier'][0]['value'],
        'address': gt['address'][0]['text'],
    }
    department = doctor_information[gt['attending_physician']]['department']
    if department not in gt['department']:
        raise ValueError(
            f"Attending physician's department {department!r} is not among the ground-truth departments {gt['department']!r}."
        )
    return patient_info, department, True
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def cancellation_request(self,
                         doctor_information: dict,
                         environment: HospitalEnvironment,
                         idx: Optional[int] = None,
                         verbose: bool = False) -> Tuple[dict, Optional[dict]]:
    """
    Simulate a patient asking to cancel one of the booked appointments.

    When ``idx`` is not given, a random schedule whose status is ``'scheduled'`` is chosen.
    If the cancellation result's first status is anything other than ``False``, the waiting
    list is updated via ``automatic_waiting_list_update`` and its non-empty per-key results
    are appended (as tuples) to the cancellation result.

    Args:
        doctor_information (dict): A dictionary containing information about the doctor(s) involved,
            including availability and other relevant details.
        environment (HospitalEnvironment): Hospital environment.
        idx (int, optional): Specific patient schedule index; a negative value means nothing to cancel.
        verbose (bool, optional): Whether logging the each result or not. Defaults to False.

    Returns:
        Tuple[dict, Optional[dict]]: Updated doctor information and the result dictionary, or ``None``
            as the result when there was no schedule to cancel.
    """
    if idx is None:
        booked = [pos for pos, entry in enumerate(environment.patient_schedules) if entry['status'] == 'scheduled']
        idx = random.choice(booked) if len(booked) else -1

    # Nothing to cancel
    if idx < 0:
        return doctor_information, None

    # Ground-truth schedule the patient will ask to cancel
    target = environment.patient_schedules[idx]
    patient_conditions = {
        'patient_name': target['patient'],
        'doctor_name': target['attending_physician'],
        'date': target['date'],
        'start_time': hour_to_hhmmss(target['schedule'][0])
    }
    sim_environment = self._init_simulation(
        system_prompt_path=self.cancel_patient_system_prompt_path,
        environment=environment,
        additional_patient_conditions=patient_conditions
    )

    # Run the cancellation dialogue
    doctor_information, result_dict = run_with_retry(
        sim_environment.canceling_simulate,
        gt_idx=idx,
        doctor_information=doctor_information,
        patient_schedules=environment.patient_schedules,
        verbose=verbose,
        max_inferences=self.max_inferences,
        patient_kwargs=self.patient_reasoning_kwargs,
        staff_kwargs=self.staff_reasoning_kwargs,
        max_retries=self.max_retries,
    )

    # An explicit False means the cancellation failed; leave the waiting list alone.
    if result_dict['status'][0] is False:
        return doctor_information, result_dict

    # The freed slot may be offered to someone on the waiting list.
    doctor_information, waitlist_result = self.automatic_waiting_list_update(
        sim_environment=sim_environment,
        environment=environment,
        doctor_information=doctor_information,
    )
    for key in result_dict:
        if len(waitlist_result[key]):
            result_dict[key].append(tuple(waitlist_result[key]))

    return doctor_information, result_dict
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def rescheduling_request(self,
                         doctor_information: dict,
                         environment: HospitalEnvironment,
                         idx: Optional[int] = None,
                         verbose: bool = False) -> Tuple[dict, Optional[dict]]:
    """
    Simulate a patient request to reschedule an existing appointment.

    A patient schedule is selected (``idx`` if given, otherwise a random schedule whose
    status is ``'scheduled'``). A rescheduling dialogue is then run through
    ``sim_environment.rescheduling_simulate``. If the simulation status is not ``False`` and
    the prediction is a schedule dict, the new slot is added to the attending physician's
    calendar and written to ``environment`` via ``update_env``.

    Args:
        doctor_information (dict): A dictionary containing information about the doctor(s) involved,
                                   including availability and other relevant details.
        environment (HospitalEnvironment): Hospital environment.
        idx (int, optional): Specific patient schedule index. If None, a random ``'scheduled'``
                             schedule is chosen. A negative index (also used when no candidate
                             exists) means no request is simulated.
        verbose (bool, optional): Whether logging the each result or not. Defaults to False.

    Returns:
        Tuple[dict, Optional[dict]]: Updated doctor information and the result dictionary of the
            rescheduling simulation. The second element is None in three cases: nothing was
            simulated (no candidate schedule), the schedule already has a waiting-list entry,
            or the simulation status is False.
    """
    if idx is None:
        candidate_idx = [
            i for i, schedule in enumerate(environment.patient_schedules)
            if schedule['status'] == 'scheduled'
        ]
        idx = random.choice(candidate_idx) if candidate_idx else -1

    # Nothing to reschedule.
    if idx < 0:
        return doctor_information, None

    requested_schedule = environment.patient_schedules[idx]

    # Skip schedules that already have a waiting-list entry. The second element of each
    # waiting-list entry is compared against the schedule.
    if any(requested_schedule == entry[1] for entry in environment.waiting_list):
        return doctor_information, None

    # Ground-truth schedule the patient asks to move
    patient = requested_schedule['patient']
    doctor = requested_schedule['attending_physician']
    appointment_date = requested_schedule['date']
    time_slot = requested_schedule['schedule']

    # Initialize simulation environment for rescheduling request
    sim_environment = self._init_simulation(
        system_prompt_path=self.reschedule_patient_system_prompt_path,
        environment=environment,
        additional_patient_conditions={
            'patient_name': patient,
            'doctor_name': doctor,
            'date': appointment_date,
            'start_time': hour_to_hhmmss(time_slot[0]),
        }
    )

    # Rescheduling request simulation
    doctor_information, result_dict = run_with_retry(
        sim_environment.rescheduling_simulate,
        gt_idx=idx,
        doctor_information=doctor_information,
        patient_schedules=environment.patient_schedules,
        verbose=verbose,
        max_inferences=self.max_inferences,
        patient_kwargs=self.patient_reasoning_kwargs,
        staff_kwargs=self.staff_reasoning_kwargs,
        max_retries=self.max_retries,
    )

    # Only an explicit False is a failure; other values cover the no-GT case and the correct case.
    if result_dict['status'][0] is False:
        return doctor_information, None

    new_schedule = result_dict['pred'][0]
    # A prediction can be a fallback string instead of a schedule dict. In that case
    # `'patient' in new_schedule` is a substring test, and the dict indexing below would
    # raise TypeError. Require a dict explicitly.
    if isinstance(new_schedule, dict) and 'patient' in new_schedule:
        day_slots = doctor_information[new_schedule['attending_physician']]['schedule'][new_schedule['date']]
        day_slots.append(new_schedule['schedule'])
        day_slots.sort()
        self.update_env(
            status=True,
            prediction=new_schedule,
            environment=environment,
        )

    return doctor_information, result_dict
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
def automatic_waiting_list_update(self,
                                  sim_environment: OPFVScheduleSimulation,
                                  environment: HospitalEnvironment,
                                  doctor_information: Optional[dict] = None) -> Tuple[dict, dict]:
    """
    Try to move patients from the waiting list into freed slots.

    Each item yielded by ``sim_environment.automatic_waiting_list_update`` carries three things:
    the current doctor information, a per-attempt result dictionary, and the original schedule.
    For a successful attempt, the new slot is booked into the physician's calendar, pushed to
    ``environment`` and logged. The 'gt', 'pred', 'status', 'status_code' and 'dialog' entries
    of every attempt, successful or not, are accumulated.

    Args:
        sim_environment (OPFVScheduleSimulation): The simulation environment used for scheduling.
        environment (HospitalEnvironment): Hospital environment.
        doctor_information (Optional[dict], optional): A dictionary containing information about the doctor(s) involved,
                                                       including availability and other relevant details. Defaults to None.

    Returns:
        Tuple[dict, dict]: The latest doctor information and the accumulated result dictionary.
    """
    merged = init_result_dict()
    attempts = sim_environment.automatic_waiting_list_update(
        doctor_information=doctor_information,
        **self.staff_reasoning_kwargs,
    )

    for attempt in attempts:
        doctor_information, attempt_result = attempt['doctor_information'], attempt['result_dict']

        if attempt_result['status'][0]:
            booked, original = attempt_result['pred'][0], attempt['original']
            day_slots = doctor_information[booked['attending_physician']]['schedule'][booked['date']]
            day_slots.append(booked['schedule'])
            day_slots.sort()
            self.update_env(
                status=True,
                prediction=booked,
                environment=environment,
            )
            log(f'{colorstr("[RESCHEDULED]")}: {original} is rescheduled to {booked}')

        # Keep every attempt's outcome, including failed ones.
        for key in ('gt', 'pred', 'status', 'status_code', 'dialog'):
            merged[key].extend(attempt_result[key])

    return doctor_information, merged
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def update_env(self,
               status: bool,
               prediction: Union[dict, str],
               environment: HospitalEnvironment,
               patient_information: Optional[dict] = None):
    """
    Forward a scheduling outcome to the hospital environment, attaching FHIR resources when enabled.

    FHIR resources are built only when ``status`` is truthy and ``self.fhir_integration`` is
    enabled. An Appointment resource is always built in that case. A Patient resource is built
    only when ``patient_information`` is given. Resources that are not built are passed as None.

    Args:
        status (bool): Whether the scheduling task was successful. If True, FHIR resources may be updated.
        prediction (Union[dict, str]): The predicted scheduling result (e.g., patient schedule information).
        environment (HospitalEnvironment): The environment instance to be updated (must implement `update_env`).
        patient_information (Optional[dict], optional): Patient-related predicted (or GT) information used to
                                                        build the FHIR Patient resource. Defaults to None.
    """
    patient_resource = None
    appointment_resource = None

    if status and self.fhir_integration:
        if patient_information is not None:
            patient_payload = {
                'metadata': deepcopy(self._metadata),
                'department': deepcopy(self._department_data),
                'patient': {
                    prediction['patient']: {
                        'department': prediction['department'],
                        'gender': patient_information['gender'],
                        'telecom': [{'system': 'phone', 'value': patient_information['phone_number'], 'use': 'mobile'}],
                        'birthDate': personal_id_to_birth_date(patient_information['personal_id']),
                        'identifier': [{'value': patient_information['personal_id'], 'use': 'official'}],
                        'address': [{'type': 'postal', 'text': patient_information['address'], 'use': 'home'}],
                    }
                },
            }
            # data_to_patient returns a sequence; only its first resource is used.
            patient_resource = DataConverter.data_to_patient(patient_payload)[0]

        appointment_payload = {
            'metadata': deepcopy(self._metadata),
            'department': deepcopy(self._department_data),
            'information': deepcopy(prediction),
        }
        appointment_resource = DataConverter.get_fhir_appointment(data=appointment_payload)

    environment.update_env(
        status=status,
        patient_schedule=prediction,
        fhir_resources={'Patient': patient_resource, 'Appointment': appointment_resource},
    )
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def __call__(self, data_pair: Tuple[dict, dict], agent_test_data: dict, agent_results: dict, environment, verbose: bool = False) -> dict:
    """
    Run the scheduling simulation for one patient, followed by optional cancellation and
    rescheduling events.

    The intake information (patient info and department) is resolved first. If it fails the
    sanity check, a single failed entry with status code ``STATUS_CODES['preceding']`` is
    recorded and returned.

    Otherwise the steps are:
      1. The regular scheduling dialogue is simulated.
      2. The doctor calendar and the environment are updated.
      3. A cancellation request is simulated with probability ``self.schedule_cancellation_prob``.
      4. A rescheduling request is simulated with probability ``self.request_early_schedule_prob``.
    Every simulated task appends its entries to the returned result dictionary.

    Args:
        data_pair (Tuple[dict, dict]): A pair of ground truth and patient data for agent simulation.
                                       Only the ground-truth element is read here.
        agent_test_data (dict): Dictionary containing test data and metadata for a single hospital.
            Keys read here:
                - 'metadata': A dict containing start_hour, end_hour, and interval_hour under 'time'.
                - 'department': Department data.
                - 'doctor': A dictionary of doctor profiles with department and schedule info
                  (used when FHIR integration is disabled).
            The 'doctor' entry is overwritten with the updated doctor information.
        agent_results (dict): Optional dictionary containing prior department predictions.
                              Used to extract department-level guidance per patient. Can be empty.
        environment (HospitalEnvironment): Hospital environment instance to manage patient schedules.
        verbose (bool, optional): Whether logging the each result or not.

    Returns:
        dict: A dictionary with the keys:
            - 'gt': Ground-truth entries for each simulated task.
            - 'pred': List of predicted results (either valid dict or fallback string).
            - 'status': List of booleans indicating whether each prediction passed sanity checks.
            - 'status_code': List of status codes explaining each status.
            - 'dialog': Dialog records from the simulations (not appended on the early-return
              path for failed intake information).
    """
    gt, _ = data_pair
    self._metadata = agent_test_data.get('metadata')
    self._department_data = agent_test_data.get('department')
    time_config = self._metadata.get('time')
    self._START_HOUR = time_config.get('start_hour')
    self._END_HOUR = time_config.get('end_hour')
    self._TIME_UNIT = time_config.get('interval_hour')
    self.sanity_checker = SanityChecker(self._START_HOUR, self._END_HOUR, self._TIME_UNIT)
    doctor_information = environment.get_general_doctor_info_from_fhir() if self.fhir_integration else agent_test_data.get('doctor')
    patient_info, department, sanity = self.get_intake_information(gt, agent_results, doctor_information)
    self.rules = SchedulingRule(self._metadata, self._department_data, environment, self.fhir_integration)
    results = init_result_dict()

    # Make scheduling GT list (one entry per patient preference)
    gt_data = [
        {
            'patient': patient_info['name'] if sanity else gt.get('patient'),
            'department': department if sanity else gt.get('department'),
            'preference': preference,
            'preferred_doctor': gt.get('attending_physician') if preference == 'doctor' else "Doesn't matter",
            'valid_from': gt.get('valid_from') if preference == 'date' else None,
        } for preference in gt.get('preference')
    ]

    # If the preceding department data is wrong, record a failure and stop.
    if not sanity:
        results['gt'].append(gt_data)
        results['pred'].append({})
        results['status'].append(False)
        results['status_code'].append(STATUS_CODES['preceding'])
        # NOTE(review): no 'dialog' entry is appended here, so results['dialog'] ends up shorter
        # than the other lists on this path. Confirm whether downstream code expects equal lengths.
        return results

    # Built only after the sanity check. The failed-intake path above deliberately avoids
    # relying on patient_info, and this list is only needed by the scheduling simulation.
    staff_known_data = [
        {
            'patient': patient_info['name'],
            'department': department,
            'patient_intention': None,
        } for _ in range(len(gt_data))
    ]

    #################################################### Regular Scheduling Simulation ####################################################
    # Initialize the simulation environment using the first preference data
    preference = gt_data[0].get('preference')
    if preference == 'date':
        preference_desc = PREFERENCE_PHRASE_PATIENT[preference].format(date=gt_data[0].get('valid_from'))
    else:
        preference_desc = PREFERENCE_PHRASE_PATIENT[preference]
    sim_environment = self._init_simulation(
        system_prompt_path=self.schedule_patient_system_prompt_path,
        environment=environment,
        additional_patient_conditions={
            'preference': preference,
            'preference_desc': preference_desc,
            'preferred_doctor': gt_data[0]['preferred_doctor'],
        }
    )

    # Simulate the main scheduling task
    doctor_information, result_dict = run_with_retry(
        sim_environment.scheduling_simulate,
        gt_data=gt_data,
        staff_known_data=staff_known_data,
        doctor_information=doctor_information,
        verbose=verbose,
        patient_kwargs=self.patient_reasoning_kwargs,
        staff_kwargs=self.staff_reasoning_kwargs,
        max_retries=self.max_retries,
    )

    prediction, status, status_code = \
        result_dict['pred'][0], result_dict['status'][0], result_dict['status_code'][0]

    if verbose:
        log(f'Pred : {prediction}')
        log(f'Status: {status}')
        log(f'Final Status: {status_code}\n\n\n')

    # Book the predicted slot in the doctor's calendar only when the prediction is valid
    if status:
        day_slots = doctor_information[prediction['attending_physician']]['schedule'][prediction['date']]
        day_slots.append(prediction['schedule'])
        day_slots.sort()

    # The environment receives the outcome with its real status, successful or not.
    self.update_env(
        status=status,
        prediction=prediction,
        environment=environment,
        patient_information=patient_info,
    )
    agent_test_data['doctor'] = doctor_information

    # Append results
    for key, values in result_dict.items():
        results[key] += values
    #######################################################################################################################################

    # Other events: cancellation first, then early-rescheduling.
    # Each event draws its own random number, in this order.
    side_events = (
        (self.schedule_cancellation_prob, self.cancellation_request),
        (self.request_early_schedule_prob, self.rescheduling_request),
    )
    for probability, request in side_events:
        if random.random() < probability:
            doctor_information, event_result = request(
                doctor_information=doctor_information,
                environment=environment,
                verbose=verbose,
            )
            if event_result is not None:
                agent_test_data['doctor'] = doctor_information
                for key in ('gt', 'pred', 'status', 'status_code', 'dialog'):
                    results[key].extend(event_result[key])

                if verbose:
                    log(f'Pred : {event_result["pred"]}')
                    log(f'Status: {event_result["status"]}')
                    log(f'Final Status: {event_result["status_code"]}\n\n\n')

    return results
|