h-adminsim 1.0.0 (h_adminsim-1.0.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. h_adminsim/__init__.py +5 -0
  2. h_adminsim/admin_staff.py +280 -0
  3. h_adminsim/assets/configs/data4primary.yaml +47 -0
  4. h_adminsim/assets/configs/data4secondary.yaml +47 -0
  5. h_adminsim/assets/configs/data4tertiary.yaml +47 -0
  6. h_adminsim/assets/country/address.json +141859 -0
  7. h_adminsim/assets/country/country_code.json +244 -0
  8. h_adminsim/assets/departments/department.json +85 -0
  9. h_adminsim/assets/departments/symptom.json +4530 -0
  10. h_adminsim/assets/fhir.schema.json +75253 -0
  11. h_adminsim/assets/names/firstname.txt +1219 -0
  12. h_adminsim/assets/names/lastname.txt +88799 -0
  13. h_adminsim/assets/prompts/cancel_patient_system.txt +38 -0
  14. h_adminsim/assets/prompts/intake_staff_task_user.txt +16 -0
  15. h_adminsim/assets/prompts/intake_supervisor_system.txt +8 -0
  16. h_adminsim/assets/prompts/intake_supervisor_user.txt +31 -0
  17. h_adminsim/assets/prompts/reschedule_patient_system.txt +38 -0
  18. h_adminsim/assets/prompts/schedule_patient_rejected_system.txt +42 -0
  19. h_adminsim/assets/prompts/schedule_patient_system.txt +36 -0
  20. h_adminsim/assets/prompts/schedule_staff_reasoning.txt +57 -0
  21. h_adminsim/assets/prompts/schedule_staff_sc_tool_calling.txt +13 -0
  22. h_adminsim/assets/prompts/schedule_staff_system.txt +10 -0
  23. h_adminsim/assets/prompts/schedule_staff_tool_calling.txt +41 -0
  24. h_adminsim/client/__init__.py +3 -0
  25. h_adminsim/client/google_client.py +209 -0
  26. h_adminsim/client/openai_client.py +199 -0
  27. h_adminsim/client/vllm_client.py +160 -0
  28. h_adminsim/environment/__init__.py +1 -0
  29. h_adminsim/environment/hospital.py +462 -0
  30. h_adminsim/environment/op_scheduling_simulation.py +1126 -0
  31. h_adminsim/pipeline/__init__.py +3 -0
  32. h_adminsim/pipeline/data_generator.py +192 -0
  33. h_adminsim/pipeline/evaluator.py +33 -0
  34. h_adminsim/pipeline/simulation.py +231 -0
  35. h_adminsim/registry/__init__.py +5 -0
  36. h_adminsim/registry/errors.py +89 -0
  37. h_adminsim/registry/models.py +126 -0
  38. h_adminsim/registry/phrases.py +10 -0
  39. h_adminsim/registry/pydantic_models.py +21 -0
  40. h_adminsim/registry/variables.py +9 -0
  41. h_adminsim/supervisor.py +182 -0
  42. h_adminsim/task/agent_task.py +900 -0
  43. h_adminsim/task/fhir_manager.py +222 -0
  44. h_adminsim/task/schedule_assign.py +151 -0
  45. h_adminsim/tools/__init__.py +5 -0
  46. h_adminsim/tools/agent_data_builder.py +124 -0
  47. h_adminsim/tools/data_converter.py +536 -0
  48. h_adminsim/tools/data_synthesizer.py +365 -0
  49. h_adminsim/tools/evaluator.py +258 -0
  50. h_adminsim/tools/sanity_checker.py +216 -0
  51. h_adminsim/tools/scheduling_rule.py +420 -0
  52. h_adminsim/utils/__init__.py +136 -0
  53. h_adminsim/utils/common_utils.py +698 -0
  54. h_adminsim/utils/fhir_utils.py +190 -0
  55. h_adminsim/utils/filesys_utils.py +135 -0
  56. h_adminsim/utils/image_preprocess_utils.py +188 -0
  57. h_adminsim/utils/random_utils.py +358 -0
  58. h_adminsim/version.txt +1 -0
  59. h_adminsim-1.0.0.dist-info/LICENSE +30 -0
  60. h_adminsim-1.0.0.dist-info/METADATA +494 -0
  61. h_adminsim-1.0.0.dist-info/RECORD +62 -0
  62. h_adminsim-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,216 @@
1
+ from decimal import Decimal
2
+ from typing import Tuple, Union, Optional
3
+
4
+ from h_adminsim.registry import STATUS_CODES
5
+ from h_adminsim.environment.hospital import HospitalEnvironment
6
+ from h_adminsim.utils.common_utils import *
7
+
8
+
9
+
10
class SanityChecker:
    """
    Sanity checks that grade model predictions for the intake and scheduling tasks.

    Hours are mapped to discrete time segments with ``convert_time_to_segment`` using
    the window ``start_hour``..``end_hour`` and the slot size ``time_unit``.
    """

    def __init__(self,
                 start_hour: Optional[float] = None,
                 end_hour: Optional[float] = None,
                 time_unit: Optional[float] = None,):
        """
        Args:
            start_hour (Optional[float]): Earliest hour a predicted appointment may start.
            end_hour (Optional[float]): Latest hour a predicted appointment may end.
            time_unit (Optional[float]): Slot granularity in hours.
        """
        # Initialization
        self._START_HOUR = start_hour
        self._END_HOUR = end_hour
        self._TIME_UNIT = time_unit

    def intake_check(self,
                     prediction: dict,
                     gt: dict,
                     conversations: str) -> Tuple[bool, str]:
        """
        Performs a sanity check on the predicted patient information and department against the ground truth.

        Args:
            prediction (dict): The output generated by the model. Expected to contain:
                - 'patient': dict of patient demographic information (e.g., name, birth date, gender)
                - 'department': the predicted department; only element ``[0]`` is compared,
                  by membership in ``gt['department']``
            gt (dict): The ground truth data. Expected to contain:
                - 'patient': dict of correct patient demographic information
                - 'department': the correct department(s)
            conversations (str): The full conversation text between the patient and administration staff.

        Returns:
            Tuple[bool, str]:
                - bool: True if the prediction passes all sanity checks, False otherwise
                - str: Status code indicating the type of check passed or failed.
                  A prediction that is not a dict, has no 'patient' dict, or has an empty or
                  missing 'department' is reported with the 'format' status code.
        """
        ############################ Check the prediction format #############################
        if not isinstance(prediction, dict) or not isinstance(prediction.get('patient'), dict):
            return False, STATUS_CODES['format']  # Could not be parsed as a dictionary
        if not prediction.get('department'):
            return False, STATUS_CODES['format']  # No department was predicted

        ############################ Incomplete simulation case #############################
        # Every ground-truth patient attribute must appear somewhere in the dialogue;
        # otherwise the simulated patient never disclosed it and the sample cannot be graded.
        lowered_conversation = conversations.lower()
        if not all(str(value).lower() in lowered_conversation for value in gt['patient'].values()):
            return False, STATUS_CODES['simulation']

        ############################ Check with the ground truth #############################
        # NOTE(review): ``[0]`` takes the first element of the predicted department. If the model
        # returns a plain string this is its first character (a substring test) — confirm it is a list.
        wrong_department = prediction['department'][0] not in gt['department']
        wrong_info = prediction['patient'] != gt['patient']
        if wrong_department and wrong_info:
            return False, STATUS_CODES['department & patient']
        if wrong_department:
            return False, STATUS_CODES['department']
        if wrong_info:
            return False, STATUS_CODES['patient']

        return True, STATUS_CODES['correct']

    def __check_is_earliest(self,
                            prediction: dict,
                            gt_patient_condition: dict,
                            doctor_information: dict,
                            environment: HospitalEnvironment) -> bool:
        """
        Check if the predicted schedule is the earliest possible option.

        The prediction is *not* the earliest when some free slot exists that
        - is long enough for the doctor's ``outpatient_duration``,
        - starts strictly after ``environment.current_time``,
        - respects the patient's preference (same doctor when ``preference == 'doctor'``,
          not before ``valid_from`` when it is set), and
        - lies on an earlier date, or on the same date at an earlier segment.

        Every feasible start position inside a free run is examined, not only the run's first
        segment. A run that began before the current time can still contain an earlier valid slot.

        Args:
            prediction (dict): Predicted schedule information including doctor, start time, end time, and date.
            gt_patient_condition (dict): Ground truth patient conditions used only for sanity checks.
            doctor_information (dict): Dictionary containing doctors' schedules and availability.
            environment (HospitalEnvironment): Environment object containing current time and UTC offset.

        Returns:
            bool: True if the predicted schedule is the earliest available, False otherwise.
        """
        # Ground-truth constraints
        department = gt_patient_condition['department']
        preference_type = gt_patient_condition['preference']
        valid_from = gt_patient_condition['valid_from']
        fixed_schedules = environment.get_doctor_schedule(doctor_information=doctor_information,
                                                          department=department)['doctor']

        # Predicted slot
        pred_doctor_name = next(iter(prediction['schedule']))
        pred_slot = prediction['schedule'][pred_doctor_name]
        pred_date = pred_slot['date']
        pred_first_segment = convert_time_to_segment(self._START_HOUR,
                                                     self._END_HOUR,
                                                     self._TIME_UNIT,
                                                     [pred_slot['start'], pred_slot['end']])[0]
        current_time = environment.current_time
        utc_offset = environment._utc_offset

        # Number of segments in a full day window (loop invariant)
        n_day_segments = len(convert_time_to_segment(self._START_HOUR, self._END_HOUR, self._TIME_UNIT))

        for doctor_name, doctor_schedule in fixed_schedules.items():
            # A patient who asked for a specific doctor only competes for that doctor's slots
            if preference_type == 'doctor' and doctor_name != pred_doctor_name:
                continue

            min_time_slot_n = int(Decimal(str(doctor_schedule['outpatient_duration'])) / Decimal(str(self._TIME_UNIT)))
            for date, schedule in doctor_schedule['schedule'].items():
                # date > pred_date case: cannot be earlier than the prediction
                if compare_iso_time(date, pred_date):
                    continue

                # valid_from > date case (preference == 'date' case)
                if valid_from and compare_iso_time(valid_from, date):
                    continue

                same_date = date == pred_date
                earlier_date = compare_iso_time(pred_date, date)
                if not (same_date or earlier_date):
                    continue

                occupied = set()
                for fixed_slot in schedule:
                    occupied.update(convert_time_to_segment(self._START_HOUR,
                                                            self._END_HOUR,
                                                            self._TIME_UNIT,
                                                            fixed_slot))
                free_time = [s for s in range(n_day_segments) if s not in occupied]
                if not free_time:
                    continue

                for run in group_consecutive_segments(free_time):
                    if len(run) < min_time_slot_n:
                        continue
                    # Every start position that still leaves room for a full appointment
                    for segment in run[:len(run) - min_time_slot_n + 1]:
                        if same_date and segment >= pred_first_segment:
                            continue
                        slot_start, _ = convert_segment_to_time(self._START_HOUR,
                                                                self._END_HOUR,
                                                                self._TIME_UNIT,
                                                                [segment])
                        slot_start_iso = get_iso_time(slot_start, date, utc_offset=utc_offset)
                        if compare_iso_time(slot_start_iso, current_time):
                            return False
        return True

    def schedule_check(self,
                       prediction: Union[str, dict],
                       gt_patient_condition: dict,
                       doctor_information: dict,
                       environment: HospitalEnvironment) -> Tuple[bool, str]:
        """
        Validate a predicted appointment.

        The checks cover structure, time validity, department, overlap with the doctor's
        existing schedule, patient preferences and "earliest available" optimality. This
        method only checks; it does not modify ``doctor_information`` or the environment.

        Args:
            prediction (Union[str, dict]): The predicted allocation result. This is either a string
                (if parsing failed) or a dict whose 'schedule' maps exactly one doctor's name to
                {'start': float, 'end': float, 'date': str}.
            gt_patient_condition (dict): Ground truth patient conditions used only for sanity checks
                ('department', 'preference', optionally 'preferred_doctor' / 'valid_from').
            doctor_information (dict): Dictionary of doctor data including their existing schedules.
                Each key is a doctor's name; each value includes 'schedule', 'department'
                and 'capacity_per_hour'.
            environment (HospitalEnvironment): Supplies the current time, UTC offset and doctor schedules.

        Returns:
            Tuple[bool, str]:
                - A boolean indicating whether the prediction passed all sanity checks.
                - A string explaining its status.
        """
        ############################ Check the prediction format #############################
        if not isinstance(prediction, dict):
            return False, STATUS_CODES['format']  # Could not be parsed as a dictionary
        pred_schedule = prediction.get('schedule')
        if not isinstance(pred_schedule, dict) or not pred_schedule:
            return False, STATUS_CODES['format']  # Schedule allocation missing or empty
        if len(pred_schedule) > 1:
            return False, STATUS_CODES['conflict']['physician']  # Allocated more than one doctor; cannot determine target

        ################## Check the predicted schedule type and validities ##################
        # Explicit checks rather than ``assert``: assertions are stripped under ``python -O``,
        # which would silently accept invalid schedules.
        try:
            pred_doctor_name = next(iter(pred_schedule))
            pred_slot = pred_schedule[pred_doctor_name]
            if not isinstance(pred_slot, dict):
                return False, STATUS_CODES['format']  # Slot is not a {'start', 'end', 'date'} mapping
            start = pred_slot['start']
            end = pred_slot['end']
            date = pred_slot['date']
            doctor = doctor_information[pred_doctor_name]
            fixed_schedules = doctor['schedule']

            # Types and bounds are validated before any time conversion is attempted
            if not (isinstance(start, float) and isinstance(end, float) and isinstance(date, str)):
                return False, STATUS_CODES['schedule']
            if not (start < end and start >= self._START_HOUR and end <= self._END_HOUR):
                return False, STATUS_CODES['schedule']
            if date not in fixed_schedules:
                return False, STATUS_CODES['schedule']
            start_iso_time = get_iso_time(start, date, utc_offset=environment._utc_offset)
            if not compare_iso_time(start_iso_time, environment.current_time):
                return False, STATUS_CODES['schedule']  # Slot does not start after the current time
            if gt_patient_condition['department'] != doctor['department']:
                return False, STATUS_CODES['schedule']  # Doctor belongs to another department

            # Duration mismatched case: one appointment must last exactly 1 / capacity_per_hour hours
            expected_duration = float(Decimal('1') / Decimal(str(doctor['capacity_per_hour'])))
            if expected_duration != float(Decimal(str(end)) - Decimal(str(start))):
                return False, STATUS_CODES['duration']

        except KeyError:
            return False, STATUS_CODES['format']  # Schedule allocation missing or doctor not found

        ####################### Check the duplication of the schedules #######################
        prediction_schedule_segments = convert_time_to_segment(self._START_HOUR,
                                                               self._END_HOUR,
                                                               self._TIME_UNIT,
                                                               [start, end])
        fixed_schedule_segments = set()
        for fixed_slot in fixed_schedules[date]:
            fixed_schedule_segments.update(convert_time_to_segment(self._START_HOUR,
                                                                   self._END_HOUR,
                                                                   self._TIME_UNIT,
                                                                   fixed_slot))

        if fixed_schedule_segments.intersection(prediction_schedule_segments):
            return False, STATUS_CODES['conflict']['time']  # Overlaps with an existing schedule

        ####################### Check the patient's preferences #######################
        if gt_patient_condition['preference'] == 'doctor':
            if gt_patient_condition.get('preferred_doctor') != pred_doctor_name:
                return False, STATUS_CODES['preference']['physician']

        if gt_patient_condition['preference'] == 'date':
            valid_from = gt_patient_condition.get('valid_from')
            # Same guard as the earliest-slot check: no constraint when valid_from is absent
            if valid_from and compare_iso_time(valid_from, date):
                return False, STATUS_CODES['preference']['date']

        is_earliest = self.__check_is_earliest(
            prediction,
            gt_patient_condition,
            doctor_information,
            environment,
        )

        if not is_earliest:
            return False, STATUS_CODES['preference']['asap']

        return True, STATUS_CODES['correct']
@@ -0,0 +1,420 @@
1
+ from copy import deepcopy
2
+ from decimal import Decimal
3
+ from typing import Optional
4
+ from langchain.tools import tool
5
+ from langchain.agents import AgentExecutor
6
+
7
+ from .data_converter import DataConverter
8
+ from h_adminsim.registry import STATUS_CODES
9
+ from h_adminsim.utils import log
10
+ from h_adminsim.utils.fhir_utils import *
11
+ from h_adminsim.utils.common_utils import (
12
+ group_consecutive_segments,
13
+ convert_segment_to_time,
14
+ convert_time_to_segment,
15
+ init_result_dict,
16
+ compare_iso_time,
17
+ get_iso_time,
18
+ )
19
+
20
+
21
+
22
class SchedulingRule:
    """
    Rule-based slot search and appointment bookkeeping used by the scheduling tools.

    Candidate slots are encoded as ``"<doctor>;;;<iso_start_time>"`` strings. Only slots
    starting strictly after the current time captured at construction are produced.
    """

    def __init__(self,
                 metadata: dict,
                 department_data: dict,
                 environment,
                 fhir_intergration: bool = False):
        """
        Args:
            metadata (dict): Must contain 'time' with 'start_hour', 'end_hour' and 'interval_hour'.
            department_data (dict): Department data forwarded to FHIR conversion on cancellation.
            environment: Hospital environment providing 'current_time', '_utc_offset',
                'delete_fhir' and 'schedule_cancel_event'.
            fhir_intergration (bool): Whether cancellations are also deleted from FHIR.
                The parameter keeps its historical spelling for keyword-argument compatibility.
        """
        self.environment = environment
        # Snapshot of the environment clock at construction time
        self._current_time = self.environment.current_time
        self._utc_offset = self.environment._utc_offset
        self._metadata = metadata
        self._department_data = department_data
        self._START_HOUR = self._metadata.get('time').get('start_hour')
        self._END_HOUR = self._metadata.get('time').get('end_hour')
        self._TIME_UNIT = self._metadata.get('time').get('interval_hour')
        self.current_time = environment.current_time
        self.fhir_integration = fhir_intergration

    def _min_time_slot_n(self, schedule_info: dict) -> int:
        """Return how many time segments one outpatient appointment of this doctor occupies."""
        return int(Decimal(str(schedule_info['outpatient_duration'])) / Decimal(str(self._TIME_UNIT)))

    def _free_slot_starts(self, date: str, schedule: list, min_time_slot_n: int) -> list[str]:
        """
        Return ISO start times of every free slot on ``date`` that fits ``min_time_slot_n``
        consecutive segments and starts strictly after the current time.

        Args:
            date (str): Date key of the schedule.
            schedule (list): Already-booked time ranges on that date.
            min_time_slot_n (int): Number of consecutive free segments an appointment needs.

        Returns:
            list[str]: ISO timestamps of feasible slot starts, in segment order.
        """
        occupied = set()
        for fixed_slot in schedule:
            occupied.update(convert_time_to_segment(self._START_HOUR,
                                                    self._END_HOUR,
                                                    self._TIME_UNIT,
                                                    fixed_slot))
        n_day_segments = len(convert_time_to_segment(self._START_HOUR, self._END_HOUR, self._TIME_UNIT))
        free_time = [s for s in range(n_day_segments) if s not in occupied]
        if not free_time:
            return []

        starts = []
        for run in group_consecutive_segments(free_time):
            if len(run) < min_time_slot_n:
                continue
            # Every start position that still leaves room for a full appointment
            for segment in run[:len(run) - min_time_slot_n + 1]:
                slot_start, _ = convert_segment_to_time(self._START_HOUR, self._END_HOUR, self._TIME_UNIT, [segment])
                slot_start_iso = get_iso_time(slot_start, date, utc_offset=self._utc_offset)
                if compare_iso_time(slot_start_iso, self._current_time):
                    starts.append(slot_start_iso)
        return starts

    def _collect_candidates(self, schedule_infos: dict, keep_date=None) -> list[str]:
        """
        Enumerate candidate slots for the given doctors.

        Args:
            schedule_infos (dict): Doctor name -> info with 'schedule' and 'outpatient_duration'.
            keep_date (callable, optional): Predicate on the date key; dates for which it
                returns False are skipped. ``None`` keeps every date.

        Returns:
            list[str]: Deduplicated "doctor;;;iso_time" candidates (unordered).
        """
        candidate_schedules = set()
        for doctor, schedule_info in schedule_infos.items():
            min_time_slot_n = self._min_time_slot_n(schedule_info)
            for date in sorted(schedule_info['schedule']):
                if keep_date is not None and not keep_date(date):
                    continue
                for iso_start in self._free_slot_starts(date, schedule_info['schedule'][date], min_time_slot_n):
                    candidate_schedules.add(f"{doctor};;;{iso_start}")
        return list(candidate_schedules)

    def physician_filter(self, filtered_doctor_information: dict, preferred_doctor: str) -> list[str]:
        """
        Filter schedules by preferred doctor.

        Args:
            filtered_doctor_information (dict): Filtered doctor information after department filtering.
            preferred_doctor (str): The identifier of the preferred doctor.

        Returns:
            list[str]: Deduplicated "doctor;;;iso_time" candidates for the preferred doctor.

        Raises:
            KeyError: If ``preferred_doctor`` is not in ``filtered_doctor_information['doctor']``.
        """
        schedule_info = filtered_doctor_information['doctor'][preferred_doctor]
        return self._collect_candidates({preferred_doctor: schedule_info})

    def date_filter(self, filtered_doctor_information: dict, valid_date: str) -> list[str]:
        """
        Filter schedules by valid date.

        Args:
            filtered_doctor_information (dict): Filtered doctor information after department filtering.
            valid_date (str): The valid date from which to consider schedules.

        Returns:
            list[str]: Deduplicated "doctor;;;iso_time" candidates on dates not before ``valid_date``.
        """
        return self._collect_candidates(filtered_doctor_information['doctor'],
                                        keep_date=lambda date: not compare_iso_time(valid_date, date))

    def get_all(self, filtered_doctor_information: dict) -> list[str]:
        """
        Get all candidate schedules without any filtering.

        Args:
            filtered_doctor_information (dict): Filtered doctor information after department filtering.

        Returns:
            list[str]: Deduplicated "doctor;;;iso_time" candidates for every doctor and date.
        """
        return self._collect_candidates(filtered_doctor_information['doctor'])

    def find_idx(self, patient_schedule_list: list[dict], patient_name: str, doctor_name: str, date: str) -> int:
        """
        Identify the index of the appointment corresponding to the patient's request
        (e.g., cancellation or modification) from the patient's schedule list.

        Only entries with status 'scheduled' are considered. Patient and doctor names are
        compared case-insensitively, and the date must match exactly.

        Args:
            patient_schedule_list (list[dict]): A list of the patient's scheduled appointments.
                Each item contains 'status', 'patient', 'attending_physician' and 'date'.
            patient_name (str): Name of the patient making the request.
            doctor_name (str): Name of the doctor associated with the target appointment.
            date (str): Date of the target appointment (YYYY-MM-DD).

        Returns:
            int: Index of the first matching appointment, or -1 if none matches.
        """
        for idx, patient_schedule in enumerate(patient_schedule_list):
            if patient_schedule['status'] == 'scheduled' and \
                    patient_schedule['patient'].lower() == patient_name.lower() and \
                    patient_schedule['attending_physician'].lower() == doctor_name.lower() \
                    and patient_schedule['date'] == date:
                return idx
        return -1

    def find_earliest_time(self, schedules: list[str], delimiter: str = ';;;') -> dict:
        """
        Find the earliest schedule from the list of schedules.

        Slots not strictly after the current time are ignored. Ties (same ISO time) are all kept.

        Args:
            schedules (list[str]): A list of schedules in the format "doctor;;;iso_time".
            delimiter (str, optional): The delimiter used to split doctor and iso_time. Defaults to ';;;'.

        Returns:
            dict: {'doctor': [...], 'schedule': [...]} with parallel lists of the earliest doctor(s)
                and their ISO start time(s); both are empty if no slot qualifies.
        """
        earliest_doctor, earliest_time = list(), list()

        for schedule in schedules:
            doctor, iso_time = schedule.split(delimiter)

            # skip when the slot is earlier than the current time
            if not compare_iso_time(iso_time, self.current_time):
                continue

            if not len(earliest_doctor):
                earliest_doctor.append(doctor)
                earliest_time.append(iso_time)
                continue

            # Append if the iso_time is the same as the already appended one
            if earliest_time[0] == iso_time:
                earliest_doctor.append(doctor)
                earliest_time.append(iso_time)

            elif compare_iso_time(earliest_time[0], iso_time):
                earliest_doctor = [doctor]
                earliest_time = [iso_time]

        return {'doctor': earliest_doctor, 'schedule': earliest_time}

    def cancel_schedule(self,
                        idx: int,
                        doctor_info: dict,
                        cancelled_schedule: dict) -> dict:
        """
        Cancel the schedule in doctor_info, in FHIR (when enabled) and in the environment.

        Args:
            idx (int): The index of the appointment to be cancelled.
            doctor_info (dict): The doctor information containing schedules (mutated in place).
            cancelled_schedule (dict): The schedule details to be cancelled
                ('attending_physician', 'date', 'schedule').

        Returns:
            dict: The same ``doctor_info`` object, after cancellation.

        Raises:
            ValueError: If the time range is not present in the doctor's schedule for that date.
        """
        doctor, date, time = cancelled_schedule['attending_physician'], cancelled_schedule['date'], cancelled_schedule['schedule']
        schedule_list = doctor_info[doctor]['schedule'][date]

        # Remove from doctor_information
        schedule_list.remove(time)  # In-place logic

        # Remove from FHIR
        if self.fhir_integration:
            fhir_appointment = DataConverter.get_fhir_appointment(data={'metadata': deepcopy(self._metadata),
                                                                        'department': deepcopy(self._department_data),
                                                                        'information': deepcopy(cancelled_schedule)})
            self.environment.delete_fhir({'Appointment': fhir_appointment})

        # Remove from environment patient_schedules
        self.environment.schedule_cancel_event(idx, True)

        return doctor_info
251
+
252
+
253
def create_tools(rule: SchedulingRule,
                 doctor_info: dict,
                 patient_schedule_list: Optional[list[dict]] = None,
                 gt_idx: Optional[int] = None,
                 only_schedule_tool: bool = False) -> list[tool]:
    """
    Build the LangChain tools exposed to the scheduling agent.

    Note: each inner tool's docstring is used by ``@tool`` as the tool description shown
    to the model, so those docstrings are part of runtime behavior.

    Args:
        rule (SchedulingRule): Rule engine used to search, identify and cancel appointments.
        doctor_info (dict): Doctor information passed to the rule's filters and to cancellation.
        patient_schedule_list (Optional[list[dict]]): Patient appointments searched by the
            cancel/reschedule tools. ``None`` is treated as an empty list.
        gt_idx (Optional[int]): Ground-truth index of the target appointment. When given, the
            cancel/reschedule tools score the identified index against it and only act on a match.
        only_schedule_tool (bool): If True, return only the three slot-search tools.

    Returns:
        list[tool]: The tool objects, in a fixed order.
    """
    prefix = 'Dr.'

    def _with_prefix(doctor_name: str) -> str:
        # Normalize names given without the title to the "Dr. <name>" form
        return doctor_name if prefix in doctor_name else f'{prefix} {doctor_name}'

    @tool
    def physician_filter_tool(preferred_doctor: str) -> dict:
        """
        Return the earliest available schedule for a preferred doctor.

        Args:
            preferred_doctor: Name of the preferred doctor

        Returns:
            dict: The earliest physician-filtered time slot and its information.
        """
        log(f'[TOOL CALL] physician_filter_tool | preferred_doctor={preferred_doctor}', color=True)
        schedules = rule.physician_filter(doctor_info, _with_prefix(preferred_doctor))
        return rule.find_earliest_time(schedules)

    @tool
    def date_filter_tool(valid_date: str) -> dict:
        """
        Return the earliest available schedule after a specific date.

        Args:
            valid_date: Date in YYYY-MM-DD format.

        Returns:
            dict: The earliest date-filtered time slot and its information.
        """
        log(f'[TOOL CALL] date_filter_tool | valid_date={valid_date}', color=True)
        schedules = rule.date_filter(doctor_info, valid_date)
        return rule.find_earliest_time(schedules)

    @tool
    def get_all_time_tool() -> dict:
        """
        Return the earliest available schedule among the all available time slots.

        Returns:
            dict: The earliest time slot and its information.
        """
        log('[TOOL CALL] get_all_time_tool', color=True)
        schedules = rule.get_all(doctor_info)
        return rule.find_earliest_time(schedules)

    @tool
    def cancel_tool(patient_name: str, doctor_name: str, date: str) -> dict:
        """
        Identify the index of the appointment to be cancelled from the patient's schedule list.

        Args:
            patient_name (str): Name of the patient requesting the cancellation.
            doctor_name (str): Name of the doctor for the appointment to be cancelled.
            date (str): Date of the appointment to be cancelled (YYYY-MM-DD).

        Returns:
            dict: A dictionary containing the cancelled_schedule, result_dict, and updated_doctor_info.
        """
        log(f'[TOOL CALL] cancel_tool | patient_name={patient_name}, doctor_name={doctor_name}, date={date}', color=True)
        result_dict, updated_doctor_info, cancelled_schedule = init_result_dict(), None, None
        schedules = patient_schedule_list or []
        index = rule.find_idx(schedules, patient_name, _with_prefix(doctor_name), date)

        # Update result_dict
        if gt_idx is None:
            status = None
            result_dict['gt'].append({'cancel': None})
            result_dict['pred'].append({'cancel': index})
            result_dict['status'].append(None)
            result_dict['status_code'].append(None)
        else:
            status = index == gt_idx
            status_code = STATUS_CODES['correct'] if status else STATUS_CODES['cancel']['identify']
            result_dict['gt'].append({'cancel': gt_idx})
            result_dict['pred'].append({'cancel': index})
            result_dict['status'].append(status)
            result_dict['status_code'].append(status_code)

        # Update the schedule only when an appointment was found AND the cancellation is correct
        # (or there is no gt_idx). find_idx returns -1 on no match; schedules[-1] would
        # silently cancel the patient's last appointment.
        if index != -1 and (gt_idx is None or status):
            cancelled_schedule = schedules[index]
            updated_doctor_info = rule.cancel_schedule(index, doctor_info, cancelled_schedule)

        return {'cancelled_schedule': cancelled_schedule, 'result_dict': result_dict, 'updated_doctor_info': updated_doctor_info}

    @tool
    def reschedule_tool(patient_name: str, doctor_name: str, date: str) -> dict:
        """
        Identify the index of the appointment to be rescheduled from the patient's schedule list.

        Args:
            patient_name (str): Name of the patient requesting the rescheduling.
            doctor_name (str): Name of the doctor for the appointment to be rescheduled.
            date (str): Date of the original appointment to be rescheduled (YYYY-MM-DD).

        Returns:
            dict: A dictionary containing the original_schedule and result_dict.
        """
        log(f'[TOOL CALL] reschedule_tool | patient_name={patient_name}, doctor_name={doctor_name}, date={date}', color=True)
        result_dict, original_schedule = init_result_dict(), None
        schedules = patient_schedule_list or []
        index = rule.find_idx(schedules, patient_name, _with_prefix(doctor_name), date)

        # Update result_dict
        if gt_idx is None:
            status = None
            result_dict['gt'].append({'reschedule': None})
            result_dict['pred'].append({'reschedule': index})
            result_dict['status'].append(None)
            result_dict['status_code'].append(None)
        else:
            status = index == gt_idx
            status_code = STATUS_CODES['correct'] if status else STATUS_CODES['reschedule']['identify']
            result_dict['gt'].append({'reschedule': gt_idx})
            result_dict['pred'].append({'reschedule': index})
            result_dict['status'].append(status)
            result_dict['status_code'].append(status_code)

        # -1 means "no matching appointment"; never use it as a list index
        if index != -1 and (gt_idx is None or status):
            original_schedule = schedules[index]

        return {'original_schedule': original_schedule, 'result_dict': result_dict}

    if only_schedule_tool:
        return [physician_filter_tool, date_filter_tool, get_all_time_tool]
    return [physician_filter_tool, date_filter_tool, get_all_time_tool, cancel_tool, reschedule_tool]
390
+
391
+
392
+
393
def scheduling_tool_calling(client: AgentExecutor,
                            user_prompt: str,
                            history: Optional[list] = None) -> dict:
    """
    Make an appointment using tool-calling agent.

    Args:
        client (AgentExecutor): The agent executor to handle tool calls.
        user_prompt (str): User prompt used for tool calling.
        history (list, optional): A list of LangChain HumanMessage and AIMessage objects.
            ``None`` (the default) means an empty history. A fresh list is created per call
            instead of sharing one mutable default across calls.

    Returns:
        dict: {'type', 'result', 'raw'} where
            - 'type' is 'tool' if at least one intermediate step was returned, else 'text';
            - 'result' is the observation of the FIRST intermediate step (``steps[0][1]``) for
              'tool', or the agent's final 'output' string (possibly '') for 'text';
            - 'raw' is the full response returned by ``client.invoke``.
        NOTE(review): 'intermediate_steps' is presumably only present when the executor was
        built with ``return_intermediate_steps=True`` — confirm at construction site.
    """
    inputs = {
        "input": user_prompt,
        "chat_history": [] if history is None else history,
    }
    response = client.invoke(inputs)
    steps = response.get("intermediate_steps") or []

    if steps:
        tool_output = steps[0][1]
        return {"type": "tool", "result": tool_output, "raw": response}

    # No tool call happened
    text = response.get("output") or ""
    return {"type": "text", "result": text, "raw": response}