edsl 0.1.42__py3-none-any.whl → 0.1.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/jobs/JobsPrompts.py CHANGED
@@ -1,3 +1,5 @@
+import time
+import logging
 from typing import List, TYPE_CHECKING
 
 from edsl.results.Dataset import Dataset
@@ -14,6 +16,7 @@ if TYPE_CHECKING:
     from edsl.jobs.FetchInvigilator import FetchInvigilator
     from edsl.data.CacheEntry import CacheEntry
 
+logger = logging.getLogger(__name__)
 
 class JobsPrompts:
     def __init__(self, jobs: "Jobs"):
@@ -22,6 +25,8 @@ class JobsPrompts:
         self.scenarios = jobs.scenarios
         self.survey = jobs.survey
         self._price_lookup = None
+        self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
+        self._scenario_lookup = {scenario: idx for idx, scenario in enumerate(self.scenarios)}
 
     @property
     def price_lookup(self):
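The two lookup dictionaries added to __init__ here are what allow a later hunk to replace the per-prompt self.agents.index(...) and self.scenarios.index(...) calls, which scan the whole list for every prompt, with constant-time dictionary lookups. A minimal sketch of the idea (illustrative only, not edsl code; it assumes the objects used as keys are hashable, which the real Agent and Scenario objects must be for these dictionaries to work):

    # Before: every lookup walks the list from the front, O(n) per call.
    items = [f"agent_{i}" for i in range(10_000)]
    idx_slow = items.index("agent_9999")

    # After: build the index map once, then each lookup is O(1).
    lookup = {item: idx for idx, item in enumerate(items)}
    idx_fast = lookup["agent_9999"]

    assert idx_slow == idx_fast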
@@ -49,25 +54,53 @@ class JobsPrompts:
         models = []
         costs = []
         cache_keys = []
+
         for interview_index, interview in enumerate(interviews):
+            logger.info(f"Processing interview {interview_index} of {len(interviews)}")
+            interview_start = time.time()
+
+            # Fetch invigilators timing
+            invig_start = time.time()
             invigilators = [
                 FetchInvigilator(interview)(question)
                 for question in interview.survey.questions
             ]
+            invig_end = time.time()
+            logger.debug(f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s")
+
+            # Process prompts timing
+            prompts_start = time.time()
             for _, invigilator in enumerate(invigilators):
+                # Get prompts timing
+                get_prompts_start = time.time()
                 prompts = invigilator.get_prompts()
+                get_prompts_end = time.time()
+                logger.debug(f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s")
+
                 user_prompt = prompts["user_prompt"]
                 system_prompt = prompts["system_prompt"]
                 user_prompts.append(user_prompt)
                 system_prompts.append(system_prompt)
-                agent_index = self.agents.index(invigilator.agent)
+
+                # Index lookups timing
+                index_start = time.time()
+                agent_index = self._agent_lookup[invigilator.agent]
                 agent_indices.append(agent_index)
                 interview_indices.append(interview_index)
-                scenario_index = self.scenarios.index(invigilator.scenario)
+                scenario_index = self._scenario_lookup[invigilator.scenario]
                 scenario_indices.append(scenario_index)
+                index_end = time.time()
+                logger.debug(f"Time taken for index lookups: {index_end - index_start:.4f}s")
+
+                # Model and question name assignment timing
+                assign_start = time.time()
                 models.append(invigilator.model.model)
                 question_names.append(invigilator.question.question_name)
+                assign_end = time.time()
+                logger.debug(f"Time taken for assignments: {assign_end - assign_start:.4f}s")
 
+                # Cost estimation timing
+                cost_start = time.time()
                 prompt_cost = self.estimate_prompt_cost(
                     system_prompt=system_prompt,
                     user_prompt=user_prompt,
@@ -75,16 +108,34 @@ class JobsPrompts:
                     inference_service=invigilator.model._inference_service_,
                     model=invigilator.model.model,
                 )
+                cost_end = time.time()
+                logger.debug(f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s")
                 costs.append(prompt_cost["cost_usd"])
 
+                # Cache key generation timing
+                cache_key_gen_start = time.time()
                 cache_key = CacheEntry.gen_key(
                     model=invigilator.model.model,
                     parameters=invigilator.model.parameters,
                     system_prompt=system_prompt,
                     user_prompt=user_prompt,
-                    iteration=0,  # TODO how to handle when there are multiple iterations?
+                    iteration=0,
                 )
+                cache_key_gen_end = time.time()
                 cache_keys.append(cache_key)
+                logger.debug(f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s")
+                logger.debug("-" * 50)  # Separator between iterations
+
+            prompts_end = time.time()
+            logger.info(f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s")
+
+            interview_end = time.time()
+            logger.info(f"Overall time taken for interview: {interview_end - interview_start:.4f}s")
+            logger.info("Time breakdown:")
+            logger.info(f" Invigilators: {invig_end - invig_start:.4f}s")
+            logger.info(f" Prompts processing: {prompts_end - prompts_start:.4f}s")
+            logger.info(f" Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s")
+
         d = Dataset(
             [
                 {"user_prompt": user_prompts},
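All of the timing instrumentation added above goes through the module logger created with logging.getLogger(__name__), so it is silent unless logging is configured by the caller. A minimal way to surface it, using only the standard library:

    import logging

    # Show the per-interview INFO summaries and the per-step DEBUG timings
    # emitted by edsl/jobs/JobsPrompts.py.
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger("edsl.jobs.JobsPrompts").setLevel(logging.DEBUG)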
@@ -228,6 +228,40 @@ class JobsRemoteInferenceHandler:
         results.results_uuid = results_uuid
         return results
 
+    def _attempt_fetch_job(
+        self,
+        job_info: RemoteJobInfo,
+        remote_job_data_fetcher: Callable,
+        object_fetcher: Callable,
+    ) -> Union[None, "Results", Literal["continue"]]:
+        """Makes one attempt to fetch and process a remote job's status and results."""
+        remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+        status = remote_job_data.get("status")
+
+        if status == "cancelled":
+            self._handle_cancelled_job(job_info)
+            return None
+
+        elif status == "failed" or status == "completed":
+            if status == "failed":
+                self._handle_failed_job(job_info, remote_job_data)
+
+            results_uuid = remote_job_data.get("results_uuid")
+            if results_uuid:
+                results = self._fetch_results_and_log(
+                    job_info=job_info,
+                    results_uuid=results_uuid,
+                    remote_job_data=remote_job_data,
+                    object_fetcher=object_fetcher,
+                )
+                return results
+            else:
+                return None
+
+        else:
+            self._sleep_for_a_bit(job_info, status)
+            return "continue"
+
     def poll_remote_inference_job(
         self,
         job_info: RemoteJobInfo,
@@ -242,31 +276,13 @@ class JobsRemoteInferenceHandler:
 
         job_in_queue = True
         while job_in_queue:
-            remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
-            status = remote_job_data.get("status")
-
-            if status == "cancelled":
-                self._handle_cancelled_job(job_info)
-                return None
-
-            elif status == "failed" or status == "completed":
-                if status == "failed":
-                    self._handle_failed_job(job_info, remote_job_data)
-
-                results_uuid = remote_job_data.get("results_uuid")
-                if results_uuid:
-                    results = self._fetch_results_and_log(
-                        job_info=job_info,
-                        results_uuid=results_uuid,
-                        remote_job_data=remote_job_data,
-                        object_fetcher=object_fetcher,
-                    )
-                    return results
-                else:
-                    return None
-
-            else:
-                self._sleep_for_a_bit(job_info, status)
+            result = self._attempt_fetch_job(
+                job_info,
+                remote_job_data_fetcher,
+                object_fetcher
+            )
+            if result != "continue":
+                return result
 
     async def create_and_poll_remote_job(
         self,
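The poll loop is now a thin wrapper around _attempt_fetch_job: the literal string "continue" acts as a sentinel meaning the job is still in progress and should be polled again, while any other return value (a Results object or None) ends the loop. A stripped-down sketch of the same control flow with placeholder names, to make the sentinel pattern explicit:

    from typing import Literal, Union

    def attempt(status: str) -> Union[None, dict, Literal["continue"]]:
        # One polling attempt: terminal statuses produce a result (or None),
        # anything else means "try again later".
        if status == "completed":
            return {"status": "completed"}
        if status in ("failed", "cancelled"):
            return None
        return "continue"

    def poll(statuses):
        for status in statuses:  # stands in for the while-loop above
            result = attempt(status)
            if result != "continue":
                return result

    print(poll(["queued", "running", "completed"]))  # {'status': 'completed'}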
@@ -96,6 +96,36 @@ class BucketCollection(UserDict):
         else:
             self[model] = self.services_to_buckets[self.models_to_services[model.model]]
 
+    def update_from_key_lookup(self, key_lookup: "KeyLookup") -> None:
+        """Updates the bucket collection rates based on model RPM/TPM from KeyLookup"""
+
+        for model_name, service in self.models_to_services.items():
+            if service in key_lookup and not self.infinity_buckets:
+
+                if key_lookup[service].rpm is not None:
+                    new_rps = key_lookup[service].rpm / 60.0
+                    new_requests_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="requests",
+                        capacity=new_rps,
+                        refill_rate=new_rps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].requests_bucket = (
+                        new_requests_bucket
+                    )
+
+                if key_lookup[service].tpm is not None:
+                    new_tps = key_lookup[service].tpm / 60.0
+                    new_tokens_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="tokens",
+                        capacity=new_tps,
+                        refill_rate=new_tps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].tokens_bucket = new_tokens_bucket
+
     def visualize(self) -> dict:
         """Visualize the token and request buckets for each model."""
         plots = {}
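update_from_key_lookup swaps in fresh request and token buckets whenever the key lookup carries explicit limits for a service, converting the per-minute RPM/TPM figures into per-second capacities and refill rates (limit / 60.0). The conversion itself is just:

    rpm = 600        # requests per minute, as stored on a KeyLookup entry (example values)
    tpm = 150_000    # tokens per minute

    new_rps = rpm / 60.0   # 10.0 requests per second -> bucket capacity and refill rate
    new_tps = tpm / 60.0   # 2500.0 tokens per second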
@@ -32,6 +32,7 @@ class RunParameters(Base):
     remote_inference_results_visibility: Optional[VisibilityType] = "unlisted"
     skip_retry: bool = False
     raise_validation_errors: bool = False
+    background: bool = False
     disable_remote_cache: bool = False
     disable_remote_inference: bool = False
     job_uuid: Optional[str] = None
@@ -61,7 +61,14 @@ class KeyLookupBuilder:
     DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE"))
     DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE"))
 
-    def __init__(self, fetch_order: Optional[tuple[str]] = None):
+    def __init__(
+        self,
+        fetch_order: Optional[tuple[str]] = None,
+        coop: Optional["Coop"] = None,
+    ):
+        from edsl.coop import Coop
+
+        # Fetch order goes from lowest priority to highest priority
         if fetch_order is None:
             self.fetch_order = ("config", "env")
         else:
@@ -70,6 +77,11 @@ class KeyLookupBuilder:
         if not isinstance(self.fetch_order, tuple):
             raise ValueError("fetch_order must be a tuple")
 
+        if coop is None:
+            self.coop = Coop()
+        else:
+            self.coop = coop
+
         self.limit_data = {}
         self.key_data = {}
         self.id_data = {}
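Threading an optional coop argument through __init__ makes the Coop client an injectable dependency, so the rate-limit fetch can be exercised without network access. A hedged sketch with a hand-rolled stub (StubCoop is illustrative and not part of edsl; it only needs the one method the builder calls, shown in a later hunk):

    class StubCoop:
        def fetch_rate_limit_config_vars(self):
            return {"EDSL_SERVICE_RPM_OPENAI": "60"}

    builder = KeyLookupBuilder(coop=StubCoop())
    builder._coop_key_value_pairs()   # {'EDSL_SERVICE_RPM_OPENAI': '60'}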
@@ -131,7 +143,8 @@ class KeyLookupBuilder:
                 service=service,
                 rpm=self.DEFAULT_RPM,
                 tpm=self.DEFAULT_TPM,
-                source="default",
+                rpm_source="default",
+                tpm_source="default",
             )
 
         if limit_entry.rpm is None:
@@ -145,7 +158,8 @@ class KeyLookupBuilder:
             tpm=int(limit_entry.tpm),
             api_id=api_id,
             token_source=api_key_entry.source,
-            limit_source=limit_entry.source,
+            rpm_source=limit_entry.rpm_source,
+            tpm_source=limit_entry.tpm_source,
             id_source=id_source,
         )
 
@@ -156,10 +170,7 @@ class KeyLookupBuilder:
         return dict(list(os.environ.items()))
 
     def _coop_key_value_pairs(self):
-        from edsl.coop import Coop
-
-        c = Coop()
-        return dict(list(c.fetch_rate_limit_config_vars().items()))
+        return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
 
     def _config_key_value_pairs(self):
         from edsl.config import CONFIG
@@ -169,7 +180,7 @@ class KeyLookupBuilder:
     @staticmethod
     def extract_service(key: str) -> str:
         """Extract the service and limit type from the key"""
-        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_")
+        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1)
        return service_raw.lower(), limit_type.lower()
 
     def get_key_value_pairs(self) -> dict:
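The split("_", 1) change matters for service names that themselves contain underscores: without a maxsplit, the two-name unpacking raises ValueError as soon as a third underscore-separated piece appears. Illustration with a hypothetical key (the concrete service names are not listed in this diff):

    key = "EDSL_SERVICE_RPM_DEEP_INFRA"          # hypothetical service containing an underscore
    stripped = key.replace("EDSL_SERVICE_", "")  # 'RPM_DEEP_INFRA'

    # Old: stripped.split("_") -> ['RPM', 'DEEP', 'INFRA'] -> ValueError on unpacking.
    # New: split only on the first underscore.
    limit_type, service_raw = stripped.split("_", 1)
    print(service_raw.lower(), limit_type.lower())   # deep_infra rpm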
@@ -187,17 +198,17 @@ class KeyLookupBuilder:
             d[k] = (v, source)
         return d
 
-    def _entry_type(self, key, value) -> str:
+    def _entry_type(self, key: str) -> str:
         """Determine the type of entry from a key.
 
         >>> builder = KeyLookupBuilder()
-        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI", "60")
+        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI")
         'limit'
-        >>> builder._entry_type("OPENAI_API_KEY", "sk-1234")
+        >>> builder._entry_type("OPENAI_API_KEY")
         'api_key'
-        >>> builder._entry_type("AWS_ACCESS_KEY_ID", "AKIA1234")
+        >>> builder._entry_type("AWS_ACCESS_KEY_ID")
         'api_id'
-        >>> builder._entry_type("UNKNOWN_KEY", "value")
+        >>> builder._entry_type("UNKNOWN_KEY")
         'unknown'
         """
         if key.startswith("EDSL_SERVICE_"):
@@ -243,11 +254,13 @@ class KeyLookupBuilder:
         service, limit_type = self.extract_service(key)
         if service in self.limit_data:
             setattr(self.limit_data[service], limit_type.lower(), value)
+            setattr(self.limit_data[service], f"{limit_type}_source", source)
         else:
             new_limit_entry = LimitEntry(
-                service=service, rpm=None, tpm=None, source=source
+                service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None
             )
             setattr(new_limit_entry, limit_type.lower(), value)
+            setattr(new_limit_entry, f"{limit_type}_source", source)
             self.limit_data[service] = new_limit_entry
 
     def _add_api_key(self, key: str, value: str, source: str) -> None:
@@ -265,13 +278,27 @@ class KeyLookupBuilder:
         else:
             self.key_data[service].append(new_entry)
 
-    def process_key_value_pairs(self) -> None:
-        """Process all key-value pairs from the configured sources."""
-        for key, value_pair in self.get_key_value_pairs().items():
+    def update_from_dict(self, d: dict) -> None:
+        """
+        Update data from a dictionary of key-value pairs.
+        Each key is a key name, and each value is a tuple of (value, source).
+
+        >>> builder = KeyLookupBuilder()
+        >>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")})
+        >>> 'sk-1234' == builder.key_data["openai"][-1].value
+        True
+        >>> 'custodial_keys' == builder.key_data["openai"][-1].source
+        True
+        """
+        for key, value_pair in d.items():
             value, source = value_pair
-            if (entry_type := self._entry_type(key, value)) == "limit":
+            if self._entry_type(key) == "limit":
                 self._add_limit(key, value, source)
-            elif entry_type == "api_key":
+            elif self._entry_type(key) == "api_key":
                 self._add_api_key(key, value, source)
-            elif entry_type == "api_id":
+            elif self._entry_type(key) == "api_id":
                 self._add_id(key, value, source)
+
+    def process_key_value_pairs(self) -> None:
+        """Process all key-value pairs from the configured sources."""
+        self.update_from_dict(self.get_key_value_pairs())
@@ -40,18 +40,23 @@ class LimitEntry:
     60
     >>> limit.tpm
     100000
-    >>> limit.source
+    >>> limit.rpm_source
     'config'
+    >>> limit.tpm_source
+    'env'
     """
 
     service: str
     rpm: int
     tpm: int
-    source: Optional[str] = None
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None
 
     @classmethod
     def example(cls):
-        return LimitEntry(service="openai", rpm=60, tpm=100000, source="config")
+        return LimitEntry(
+            service="openai", rpm=60, tpm=100000, rpm_source="config", tpm_source="env"
+        )
 
 
 @dataclass
@@ -108,7 +113,8 @@ class LanguageModelInput:
     tpm: int
     api_id: Optional[str] = None
     token_source: Optional[str] = None
-    limit_source: Optional[str] = None
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None
     id_source: Optional[str] = None
 
     def to_dict(self):
edsl/prompts/Prompt.py CHANGED
@@ -10,6 +10,48 @@ from edsl.Base import PersistenceMixin, RepresentationMixin
 
 MAX_NESTING = 100
 
+from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
+from functools import lru_cache
+
+class PreserveUndefined(Undefined):
+    def __str__(self):
+        return "{{ " + str(self._undefined_name) + " }}"
+
+# Create environment once at module level
+_env = Environment(undefined=PreserveUndefined)
+
+@lru_cache(maxsize=1024)
+def _compile_template(text: str):
+    return _env.from_string(text)
+
+@lru_cache(maxsize=1024)
+def _find_template_variables(template: str) -> list[str]:
+    """Find and return the template variables."""
+    ast = _env.parse(template)
+    return list(meta.find_undeclared_variables(ast))
+
+def _make_hashable(value):
+    """Convert unhashable types to hashable ones."""
+    if isinstance(value, list):
+        return tuple(_make_hashable(item) for item in value)
+    if isinstance(value, dict):
+        return frozenset((k, _make_hashable(v)) for k, v in value.items())
+    return value
+
+@lru_cache(maxsize=1024)
+def _cached_render(text: str, frozen_replacements: frozenset) -> str:
+    """Cached version of template rendering with frozen replacements."""
+    # Print cache info on every call
+    cache_info = _cached_render.cache_info()
+    print(f"\t\t\t\t\t Cache status - hits: {cache_info.hits}, misses: {cache_info.misses}, current size: {cache_info.currsize}")
+
+    # Convert back to dict with original types for rendering
+    replacements = {k: v for k, v in frozen_replacements}
+
+    template = _compile_template(text)
+    result = template.render(replacements)
+
+    return result
 
 class Prompt(PersistenceMixin, RepresentationMixin):
     """Class for creating a prompt to be used in a survey."""
@@ -145,33 +187,8 @@ class Prompt(PersistenceMixin, RepresentationMixin):
         return f'Prompt(text="""{self.text}""")'
 
     def template_variables(self) -> list[str]:
-        """Return the the variables in the template.
-
-        Example:
-
-        >>> p = Prompt("Hello, {{person}}")
-        >>> p.template_variables()
-        ['person']
-
-        """
-        return self._template_variables(self.text)
-
-    @staticmethod
-    def _template_variables(template: str) -> list[str]:
-        """Find and return the template variables.
-
-        :param template: The template to find the variables in.
-
-        """
-        from jinja2 import Environment, meta, Undefined
-
-        class PreserveUndefined(Undefined):
-            def __str__(self):
-                return "{{ " + str(self._undefined_name) + " }}"
-
-        env = Environment(undefined=PreserveUndefined)
-        ast = env.parse(template)
-        return list(meta.find_undeclared_variables(ast))
+        """Return the variables in the template."""
+        return _find_template_variables(self.text)
 
     def undefined_template_variables(self, replacement_dict: dict):
         """Return the variables in the template that are not in the replacement_dict.
@@ -239,45 +256,39 @@ class Prompt(PersistenceMixin, RepresentationMixin):
         return self
 
     @staticmethod
-    def _render(
-        text: str, primary_replacement, **additional_replacements
-    ) -> "PromptBase":
-        """Render the template text with variables replaced from the provided named dictionaries.
-
-        :param text: The text to render.
-        :param primary_replacement: The primary replacement dictionary.
-        :param additional_replacements: Additional replacement dictionaries.
-
-        Allows for nested variable resolution up to a specified maximum nesting depth.
-
-        Example:
-
-        >>> codebook = {"age": "Age"}
-        >>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
-        >>> p.render({"name": "John", "age": 44}, codebook=codebook)
-        Prompt(text=\"""You are an agent named John. Age: 44\""")
-        """
-        from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
-
-        class PreserveUndefined(Undefined):
-            def __str__(self):
-                return "{{ " + str(self._undefined_name) + " }}"
-
-        env = Environment(undefined=PreserveUndefined)
+    def _render(text: str, primary_replacement, **additional_replacements) -> "PromptBase":
+        """Render the template text with variables replaced."""
+        import time
+
+        # if there are no replacements, return the text
+        if not primary_replacement and not additional_replacements:
+            return text
+
         try:
+            variables = _find_template_variables(text)
+
+            if not variables:  # if there are no variables, return the text
+                return text
+
+            # Combine all replacements
+            all_replacements = {**primary_replacement, **additional_replacements}
+
             previous_text = None
+            current_text = text
+            iteration = 0
+
             for _ in range(MAX_NESTING):
-                # breakpoint()
-                rendered_text = env.from_string(text).render(
-                    primary_replacement, **additional_replacements
-                )
-                if rendered_text == previous_text:
-                    # No more changes, so return the rendered text
+                iteration += 1
+
+                template = _compile_template(current_text)
+                rendered_text = template.render(all_replacements)
+
+                if rendered_text == current_text:
                     return rendered_text
-                previous_text = text
-                text = rendered_text
+
+                previous_text = current_text
+                current_text = rendered_text
 
-            # If the loop exits without returning, it indicates too much nesting
             raise TemplateRenderError(
                 "Too much nesting - you created an infinite loop here, pal"
             )
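Two things are happening in the rewritten _render. First, templates are compiled through the module-level, lru_cache-backed helpers added at the top of the file, with PreserveUndefined keeping unknown variables in the output instead of silently dropping them. Second, rendering is repeated until the output stops changing (up to MAX_NESTING passes), which is what lets replacement values themselves contain template variables. A small self-contained check of both behaviours, using plain jinja2 and mirroring the helpers above (illustrative, not the edsl module itself):

    from functools import lru_cache
    from jinja2 import Environment, Undefined

    class PreserveUndefined(Undefined):
        def __str__(self):
            # Re-emit the placeholder instead of Jinja2's default empty string.
            return "{{ " + str(self._undefined_name) + " }}"

    _env = Environment(undefined=PreserveUndefined)

    @lru_cache(maxsize=1024)
    def _compile_template(text: str):
        return _env.from_string(text)

    # Unknown variables survive the render.
    print(_compile_template("Hello {{ name }}, {{ missing }}").render({"name": "John"}))
    # -> Hello John, {{ missing }}

    # Nested replacements resolve over successive passes, stopping at a fixed point.
    replacements = {"greeting": "Hello {{ name }}", "name": "World"}
    current = "{{ greeting }}!"
    for _ in range(100):  # MAX_NESTING plays this role in the code above
        rendered = _compile_template(current).render(replacements)
        if rendered == current:
            break
        current = rendered
    print(current)  # -> Hello World!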
@@ -331,6 +342,58 @@ class Prompt(PersistenceMixin, RepresentationMixin):
         """Return an example of the prompt."""
         return cls(cls.default_instructions)
 
+    def get_prompts(self) -> Dict[str, Any]:
+        """Get the prompts for the question."""
+        start = time.time()
+
+        # Build all the components
+        instr_start = time.time()
+        agent_instructions = self.agent_instructions_prompt
+        instr_end = time.time()
+        logger.debug(f"Time taken for agent instructions: {instr_end - instr_start:.4f}s")
+
+        persona_start = time.time()
+        agent_persona = self.agent_persona_prompt
+        persona_end = time.time()
+        logger.debug(f"Time taken for agent persona: {persona_end - persona_start:.4f}s")
+
+        q_instr_start = time.time()
+        question_instructions = self.question_instructions_prompt
+        q_instr_end = time.time()
+        logger.debug(f"Time taken for question instructions: {q_instr_end - q_instr_start:.4f}s")
+
+        memory_start = time.time()
+        prior_question_memory = self.prior_question_memory_prompt
+        memory_end = time.time()
+        logger.debug(f"Time taken for prior question memory: {memory_end - memory_start:.4f}s")
+
+        # Get components dict
+        components = {
+            "agent_instructions": agent_instructions.text,
+            "agent_persona": agent_persona.text,
+            "question_instructions": question_instructions.text,
+            "prior_question_memory": prior_question_memory.text,
+        }
+
+        # Use PromptPlan's get_prompts method
+        plan_start = time.time()
+        prompts = self.prompt_plan.get_prompts(**components)
+        plan_end = time.time()
+        logger.debug(f"Time taken for prompt processing: {plan_end - plan_start:.4f}s")
+
+        # Handle file keys if present
+        if hasattr(self, 'question_file_keys') and self.question_file_keys:
+            files_start = time.time()
+            files_list = []
+            for key in self.question_file_keys:
+                files_list.append(self.scenario[key])
+            prompts["files_list"] = files_list
+            files_end = time.time()
+            logger.debug(f"Time taken for file key processing: {files_end - files_start:.4f}s")
+
+        end = time.time()
+        logger.debug(f"Total time in get_prompts: {end - start:.4f}s")
+        return prompts
 
 if __name__ == "__main__":
     print("Running doctests...")