edsl 0.1.42__py3-none-any.whl → 0.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +1 -1
- edsl/agents/PromptConstructor.py +92 -21
- edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
- edsl/agents/prompt_helpers.py +2 -2
- edsl/coop/coop.py +65 -19
- edsl/enums.py +1 -2
- edsl/exceptions/coop.py +4 -0
- edsl/inference_services/AvailableModelFetcher.py +4 -1
- edsl/jobs/Jobs.py +54 -35
- edsl/jobs/JobsPrompts.py +54 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
- edsl/jobs/buckets/BucketCollection.py +30 -0
- edsl/jobs/data_structures.py +1 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
- edsl/language_models/key_management/models.py +10 -4
- edsl/prompts/Prompt.py +124 -61
- edsl/questions/descriptors.py +32 -18
- edsl/questions/question_base_gen_mixin.py +1 -0
- edsl/results/DatasetExportMixin.py +35 -6
- edsl/results/Results.py +179 -1
- edsl/results/ResultsGGMixin.py +117 -60
- edsl/scenarios/Scenario.py +33 -0
- edsl/scenarios/ScenarioList.py +22 -3
- edsl/scenarios/ScenarioListPdfMixin.py +9 -3
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/METADATA +3 -4
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/RECORD +29 -29
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/LICENSE +0 -0
- {edsl-0.1.42.dist-info → edsl-0.1.43.dist-info}/WHEEL +0 -0
edsl/jobs/JobsPrompts.py
CHANGED
@@ -1,3 +1,5 @@
+import time
+import logging
 from typing import List, TYPE_CHECKING

 from edsl.results.Dataset import Dataset
@@ -14,6 +16,7 @@ if TYPE_CHECKING:
     from edsl.jobs.FetchInvigilator import FetchInvigilator
     from edsl.data.CacheEntry import CacheEntry

+logger = logging.getLogger(__name__)

 class JobsPrompts:
     def __init__(self, jobs: "Jobs"):
@@ -22,6 +25,8 @@ class JobsPrompts:
         self.scenarios = jobs.scenarios
         self.survey = jobs.survey
         self._price_lookup = None
+        self._agent_lookup = {agent: idx for idx, agent in enumerate(self.agents)}
+        self._scenario_lookup = {scenario: idx for idx, scenario in enumerate(self.scenarios)}

     @property
     def price_lookup(self):
@@ -49,25 +54,53 @@ class JobsPrompts:
         models = []
         costs = []
         cache_keys = []
+
         for interview_index, interview in enumerate(interviews):
+            logger.info(f"Processing interview {interview_index} of {len(interviews)}")
+            interview_start = time.time()
+
+            # Fetch invigilators timing
+            invig_start = time.time()
             invigilators = [
                 FetchInvigilator(interview)(question)
                 for question in interview.survey.questions
             ]
+            invig_end = time.time()
+            logger.debug(f"Time taken to fetch invigilators: {invig_end - invig_start:.4f}s")
+
+            # Process prompts timing
+            prompts_start = time.time()
             for _, invigilator in enumerate(invigilators):
+                # Get prompts timing
+                get_prompts_start = time.time()
                 prompts = invigilator.get_prompts()
+                get_prompts_end = time.time()
+                logger.debug(f"Time taken to get prompts: {get_prompts_end - get_prompts_start:.4f}s")
+
                 user_prompt = prompts["user_prompt"]
                 system_prompt = prompts["system_prompt"]
                 user_prompts.append(user_prompt)
                 system_prompts.append(system_prompt)
-                agent_index = self.agents.index(invigilator.agent)
+
+                # Index lookups timing
+                index_start = time.time()
+                agent_index = self._agent_lookup[invigilator.agent]
                 agent_indices.append(agent_index)
                 interview_indices.append(interview_index)
-                scenario_index = self.scenarios.index(invigilator.scenario)
+                scenario_index = self._scenario_lookup[invigilator.scenario]
                 scenario_indices.append(scenario_index)
+                index_end = time.time()
+                logger.debug(f"Time taken for index lookups: {index_end - index_start:.4f}s")
+
+                # Model and question name assignment timing
+                assign_start = time.time()
                 models.append(invigilator.model.model)
                 question_names.append(invigilator.question.question_name)
+                assign_end = time.time()
+                logger.debug(f"Time taken for assignments: {assign_end - assign_start:.4f}s")

+                # Cost estimation timing
+                cost_start = time.time()
                 prompt_cost = self.estimate_prompt_cost(
                     system_prompt=system_prompt,
                     user_prompt=user_prompt,
@@ -75,16 +108,34 @@ class JobsPrompts:
                     inference_service=invigilator.model._inference_service_,
                     model=invigilator.model.model,
                 )
+                cost_end = time.time()
+                logger.debug(f"Time taken to estimate prompt cost: {cost_end - cost_start:.4f}s")
                 costs.append(prompt_cost["cost_usd"])

+                # Cache key generation timing
+                cache_key_gen_start = time.time()
                 cache_key = CacheEntry.gen_key(
                     model=invigilator.model.model,
                     parameters=invigilator.model.parameters,
                     system_prompt=system_prompt,
                     user_prompt=user_prompt,
-                    iteration=0,
+                    iteration=0,
                 )
+                cache_key_gen_end = time.time()
                 cache_keys.append(cache_key)
+                logger.debug(f"Time taken to generate cache key: {cache_key_gen_end - cache_key_gen_start:.4f}s")
+                logger.debug("-" * 50)  # Separator between iterations
+
+            prompts_end = time.time()
+            logger.info(f"Time taken to process prompts: {prompts_end - prompts_start:.4f}s")
+
+            interview_end = time.time()
+            logger.info(f"Overall time taken for interview: {interview_end - interview_start:.4f}s")
+            logger.info("Time breakdown:")
+            logger.info(f"  Invigilators: {invig_end - invig_start:.4f}s")
+            logger.info(f"  Prompts processing: {prompts_end - prompts_start:.4f}s")
+            logger.info(f"  Other overhead: {(interview_end - interview_start) - ((invig_end - invig_start) + (prompts_end - prompts_start)):.4f}s")
+
         d = Dataset(
             [
                 {"user_prompt": user_prompts},
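Beyond the timing instrumentation, the substantive change in JobsPrompts is replacing repeated list .index() scans with lookup dictionaries built once in __init__. A minimal sketch of that pattern, assuming the indexed items are hashable; the class and names below are illustrative, not edsl API:

# Build the index map once; each per-item lookup is then O(1)
# instead of an O(n) list scan on every loop iteration.
class PositionIndex:
    def __init__(self, items):
        # Assumes items are hashable, since they become dict keys.
        self._lookup = {item: idx for idx, item in enumerate(items)}

    def find_position(self, item) -> int:
        return self._lookup[item]

items = ["agent_a", "agent_b", "agent_c"]
index = PositionIndex(items)
assert index.find_position("agent_c") == 2  # no repeated list.index() calls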
edsl/jobs/JobsRemoteInferenceHandler.py
CHANGED
@@ -228,6 +228,40 @@ class JobsRemoteInferenceHandler:
         results.results_uuid = results_uuid
         return results

+    def _attempt_fetch_job(
+        self,
+        job_info: RemoteJobInfo,
+        remote_job_data_fetcher: Callable,
+        object_fetcher: Callable,
+    ) -> Union[None, "Results", Literal["continue"]]:
+        """Makes one attempt to fetch and process a remote job's status and results."""
+        remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
+        status = remote_job_data.get("status")
+
+        if status == "cancelled":
+            self._handle_cancelled_job(job_info)
+            return None
+
+        elif status == "failed" or status == "completed":
+            if status == "failed":
+                self._handle_failed_job(job_info, remote_job_data)
+
+            results_uuid = remote_job_data.get("results_uuid")
+            if results_uuid:
+                results = self._fetch_results_and_log(
+                    job_info=job_info,
+                    results_uuid=results_uuid,
+                    remote_job_data=remote_job_data,
+                    object_fetcher=object_fetcher,
+                )
+                return results
+            else:
+                return None
+
+        else:
+            self._sleep_for_a_bit(job_info, status)
+            return "continue"
+
     def poll_remote_inference_job(
         self,
         job_info: RemoteJobInfo,
@@ -242,31 +276,13 @@ class JobsRemoteInferenceHandler:

         job_in_queue = True
         while job_in_queue:
-            remote_job_data = remote_job_data_fetcher(job_info.job_uuid)
-            status = remote_job_data.get("status")
-
-            if status == "cancelled":
-                self._handle_cancelled_job(job_info)
-                return None
-
-            elif status == "failed" or status == "completed":
-                if status == "failed":
-                    self._handle_failed_job(job_info, remote_job_data)
-
-                results_uuid = remote_job_data.get("results_uuid")
-                if results_uuid:
-                    results = self._fetch_results_and_log(
-                        job_info=job_info,
-                        results_uuid=results_uuid,
-                        remote_job_data=remote_job_data,
-                        object_fetcher=object_fetcher,
-                    )
-                    return results
-                else:
-                    return None
-
-            else:
-                self._sleep_for_a_bit(job_info, status)
+            result = self._attempt_fetch_job(
+                job_info,
+                remote_job_data_fetcher,
+                object_fetcher
+            )
+            if result != "continue":
+                return result

     async def create_and_poll_remote_job(
         self,
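The polling refactor moves the loop body into _attempt_fetch_job, which returns either a terminal value or the sentinel string "continue". A minimal sketch of that sentinel-loop pattern, using hypothetical names (attempt, poll, fetch_status) rather than the edsl API:

# One attempt either finishes the job (return a terminal value) or signals
# the caller to keep polling by returning the sentinel "continue".
import time
from typing import Union, Literal

def attempt(job_id: str, fetch_status) -> Union[None, dict, Literal["continue"]]:
    status = fetch_status(job_id)
    if status.get("state") in ("completed", "failed", "cancelled"):
        return status            # terminal: caller returns this value
    time.sleep(1.0)              # back off before the next attempt
    return "continue"            # sentinel: keep looping

def poll(job_id: str, fetch_status):
    while True:
        result = attempt(job_id, fetch_status)
        if result != "continue":
            return result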
edsl/jobs/buckets/BucketCollection.py
CHANGED
@@ -96,6 +96,36 @@ class BucketCollection(UserDict):
         else:
             self[model] = self.services_to_buckets[self.models_to_services[model.model]]

+    def update_from_key_lookup(self, key_lookup: "KeyLookup") -> None:
+        """Updates the bucket collection rates based on model RPM/TPM from KeyLookup"""
+
+        for model_name, service in self.models_to_services.items():
+            if service in key_lookup and not self.infinity_buckets:
+
+                if key_lookup[service].rpm is not None:
+                    new_rps = key_lookup[service].rpm / 60.0
+                    new_requests_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="requests",
+                        capacity=new_rps,
+                        refill_rate=new_rps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].requests_bucket = (
+                        new_requests_bucket
+                    )
+
+                if key_lookup[service].tpm is not None:
+                    new_tps = key_lookup[service].tpm / 60.0
+                    new_tokens_bucket = TokenBucket(
+                        bucket_name=service,
+                        bucket_type="tokens",
+                        capacity=new_tps,
+                        refill_rate=new_tps,
+                        remote_url=self.remote_url,
+                    )
+                    self.services_to_buckets[service].tokens_bucket = new_tokens_bucket
+
     def visualize(self) -> dict:
         """Visualize the token and request buckets for each model."""
         plots = {}
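The new update_from_key_lookup method converts per-minute limits (RPM/TPM) from the key lookup into per-second refill rates by dividing by 60 before rebuilding each service's bucket. A minimal sketch of that conversion, using an illustrative SimpleBucket rather than edsl's TokenBucket:

# RPM -> requests-per-second conversion, then a fresh bucket with matching
# capacity and refill rate. SimpleBucket is illustrative, not the edsl class.
from dataclasses import dataclass

@dataclass
class SimpleBucket:
    capacity: float      # maximum burst size
    refill_rate: float   # tokens added per second

def bucket_from_rpm(rpm: int) -> SimpleBucket:
    rps = rpm / 60.0  # per-minute limit expressed per second
    return SimpleBucket(capacity=rps, refill_rate=rps)

bucket = bucket_from_rpm(600)
assert bucket.refill_rate == 10.0  # 600 requests/minute = 10 requests/second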
edsl/jobs/data_structures.py
CHANGED
@@ -32,6 +32,7 @@ class RunParameters(Base):
     remote_inference_results_visibility: Optional[VisibilityType] = "unlisted"
     skip_retry: bool = False
     raise_validation_errors: bool = False
+    background: bool = False
     disable_remote_cache: bool = False
     disable_remote_inference: bool = False
     job_uuid: Optional[str] = None
edsl/language_models/key_management/KeyLookupBuilder.py
CHANGED
@@ -61,7 +61,14 @@ class KeyLookupBuilder:
     DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE"))
     DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE"))

-    def __init__(
+    def __init__(
+        self,
+        fetch_order: Optional[tuple[str]] = None,
+        coop: Optional["Coop"] = None,
+    ):
+        from edsl.coop import Coop
+
+        # Fetch order goes from lowest priority to highest priority
         if fetch_order is None:
             self.fetch_order = ("config", "env")
         else:
@@ -70,6 +77,11 @@ class KeyLookupBuilder:
         if not isinstance(self.fetch_order, tuple):
             raise ValueError("fetch_order must be a tuple")

+        if coop is None:
+            self.coop = Coop()
+        else:
+            self.coop = coop
+
         self.limit_data = {}
         self.key_data = {}
         self.id_data = {}
@@ -131,7 +143,8 @@ class KeyLookupBuilder:
             service=service,
             rpm=self.DEFAULT_RPM,
             tpm=self.DEFAULT_TPM,
-
+            rpm_source="default",
+            tpm_source="default",
         )

         if limit_entry.rpm is None:
@@ -145,7 +158,8 @@ class KeyLookupBuilder:
             tpm=int(limit_entry.tpm),
             api_id=api_id,
             token_source=api_key_entry.source,
-
+            rpm_source=limit_entry.rpm_source,
+            tpm_source=limit_entry.tpm_source,
             id_source=id_source,
         )

@@ -156,10 +170,7 @@ class KeyLookupBuilder:
         return dict(list(os.environ.items()))

     def _coop_key_value_pairs(self):
-
-
-        c = Coop()
-        return dict(list(c.fetch_rate_limit_config_vars().items()))
+        return dict(list(self.coop.fetch_rate_limit_config_vars().items()))

     def _config_key_value_pairs(self):
         from edsl.config import CONFIG
@@ -169,7 +180,7 @@ class KeyLookupBuilder:
     @staticmethod
     def extract_service(key: str) -> str:
         """Extract the service and limit type from the key"""
-        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_")
+        limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1)
         return service_raw.lower(), limit_type.lower()

     def get_key_value_pairs(self) -> dict:
@@ -187,17 +198,17 @@ class KeyLookupBuilder:
             d[k] = (v, source)
         return d

-    def _entry_type(self, key
+    def _entry_type(self, key: str) -> str:
         """Determine the type of entry from a key.

         >>> builder = KeyLookupBuilder()
-        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI"
+        >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI")
         'limit'
-        >>> builder._entry_type("OPENAI_API_KEY"
+        >>> builder._entry_type("OPENAI_API_KEY")
         'api_key'
-        >>> builder._entry_type("AWS_ACCESS_KEY_ID"
+        >>> builder._entry_type("AWS_ACCESS_KEY_ID")
         'api_id'
-        >>> builder._entry_type("UNKNOWN_KEY"
+        >>> builder._entry_type("UNKNOWN_KEY")
         'unknown'
         """
         if key.startswith("EDSL_SERVICE_"):
@@ -243,11 +254,13 @@ class KeyLookupBuilder:
         service, limit_type = self.extract_service(key)
         if service in self.limit_data:
             setattr(self.limit_data[service], limit_type.lower(), value)
+            setattr(self.limit_data[service], f"{limit_type}_source", source)
         else:
             new_limit_entry = LimitEntry(
-                service=service, rpm=None, tpm=None,
+                service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None
             )
             setattr(new_limit_entry, limit_type.lower(), value)
+            setattr(new_limit_entry, f"{limit_type}_source", source)
             self.limit_data[service] = new_limit_entry

     def _add_api_key(self, key: str, value: str, source: str) -> None:
@@ -265,13 +278,27 @@ class KeyLookupBuilder:
         else:
             self.key_data[service].append(new_entry)

-    def
-    """
-
+    def update_from_dict(self, d: dict) -> None:
+        """
+        Update data from a dictionary of key-value pairs.
+        Each key is a key name, and each value is a tuple of (value, source).
+
+        >>> builder = KeyLookupBuilder()
+        >>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")})
+        >>> 'sk-1234' == builder.key_data["openai"][-1].value
+        True
+        >>> 'custodial_keys' == builder.key_data["openai"][-1].source
+        True
+        """
+        for key, value_pair in d.items():
             value, source = value_pair
-            if
+            if self._entry_type(key) == "limit":
                 self._add_limit(key, value, source)
-            elif
+            elif self._entry_type(key) == "api_key":
                 self._add_api_key(key, value, source)
-            elif
+            elif self._entry_type(key) == "api_id":
                 self._add_id(key, value, source)
+
+    def process_key_value_pairs(self) -> None:
+        """Process all key-value pairs from the configured sources."""
+        self.update_from_dict(self.get_key_value_pairs())
edsl/language_models/key_management/models.py
CHANGED
@@ -40,18 +40,23 @@ class LimitEntry:
    60
    >>> limit.tpm
    100000
-    >>> limit.
+    >>> limit.rpm_source
    'config'
+    >>> limit.tpm_source
+    'env'
    """

    service: str
    rpm: int
    tpm: int
-
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None

    @classmethod
    def example(cls):
-        return LimitEntry(
+        return LimitEntry(
+            service="openai", rpm=60, tpm=100000, rpm_source="config", tpm_source="env"
+        )


 @dataclass
@@ -108,7 +113,8 @@ class LanguageModelInput:
    tpm: int
    api_id: Optional[str] = None
    token_source: Optional[str] = None
-
+    rpm_source: Optional[str] = None
+    tpm_source: Optional[str] = None
    id_source: Optional[str] = None

    def to_dict(self):
edsl/prompts/Prompt.py
CHANGED
@@ -10,6 +10,48 @@ from edsl.Base import PersistenceMixin, RepresentationMixin

 MAX_NESTING = 100

+from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
+from functools import lru_cache
+
+class PreserveUndefined(Undefined):
+    def __str__(self):
+        return "{{ " + str(self._undefined_name) + " }}"
+
+# Create environment once at module level
+_env = Environment(undefined=PreserveUndefined)
+
+@lru_cache(maxsize=1024)
+def _compile_template(text: str):
+    return _env.from_string(text)
+
+@lru_cache(maxsize=1024)
+def _find_template_variables(template: str) -> list[str]:
+    """Find and return the template variables."""
+    ast = _env.parse(template)
+    return list(meta.find_undeclared_variables(ast))
+
+def _make_hashable(value):
+    """Convert unhashable types to hashable ones."""
+    if isinstance(value, list):
+        return tuple(_make_hashable(item) for item in value)
+    if isinstance(value, dict):
+        return frozenset((k, _make_hashable(v)) for k, v in value.items())
+    return value
+
+@lru_cache(maxsize=1024)
+def _cached_render(text: str, frozen_replacements: frozenset) -> str:
+    """Cached version of template rendering with frozen replacements."""
+    # Print cache info on every call
+    cache_info = _cached_render.cache_info()
+    print(f"\t\t\t\t\t Cache status - hits: {cache_info.hits}, misses: {cache_info.misses}, current size: {cache_info.currsize}")
+
+    # Convert back to dict with original types for rendering
+    replacements = {k: v for k, v in frozen_replacements}
+
+    template = _compile_template(text)
+    result = template.render(replacements)
+
+    return result

 class Prompt(PersistenceMixin, RepresentationMixin):
     """Class for creating a prompt to be used in a survey."""
@@ -145,33 +187,8 @@ class Prompt(PersistenceMixin, RepresentationMixin):
         return f'Prompt(text="""{self.text}""")'

     def template_variables(self) -> list[str]:
-        """Return the
-
-        Example:
-
-        >>> p = Prompt("Hello, {{person}}")
-        >>> p.template_variables()
-        ['person']
-
-        """
-        return self._template_variables(self.text)
-
-    @staticmethod
-    def _template_variables(template: str) -> list[str]:
-        """Find and return the template variables.
-
-        :param template: The template to find the variables in.
-
-        """
-        from jinja2 import Environment, meta, Undefined
-
-        class PreserveUndefined(Undefined):
-            def __str__(self):
-                return "{{ " + str(self._undefined_name) + " }}"
-
-        env = Environment(undefined=PreserveUndefined)
-        ast = env.parse(template)
-        return list(meta.find_undeclared_variables(ast))
+        """Return the variables in the template."""
+        return _find_template_variables(self.text)

     def undefined_template_variables(self, replacement_dict: dict):
         """Return the variables in the template that are not in the replacement_dict.
@@ -239,45 +256,39 @@ class Prompt(PersistenceMixin, RepresentationMixin):
         return self

     @staticmethod
-    def _render(
-        text
-
-
-
-
-
-
-
-        Allows for nested variable resolution up to a specified maximum nesting depth.
-
-        Example:
-
-        >>> codebook = {"age": "Age"}
-        >>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
-        >>> p.render({"name": "John", "age": 44}, codebook=codebook)
-        Prompt(text=\"""You are an agent named John. Age: 44\""")
-        """
-        from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
-
-        class PreserveUndefined(Undefined):
-            def __str__(self):
-                return "{{ " + str(self._undefined_name) + " }}"
-
-        env = Environment(undefined=PreserveUndefined)
+    def _render(text: str, primary_replacement, **additional_replacements) -> "PromptBase":
+        """Render the template text with variables replaced."""
+        import time
+
+        # if there are no replacements, return the text
+        if not primary_replacement and not additional_replacements:
+            return text
+
         try:
+            variables = _find_template_variables(text)
+
+            if not variables:  # if there are no variables, return the text
+                return text
+
+            # Combine all replacements
+            all_replacements = {**primary_replacement, **additional_replacements}
+
             previous_text = None
+            current_text = text
+            iteration = 0
+
             for _ in range(MAX_NESTING):
-
-
-
-                )
-
-
+                iteration += 1
+
+                template = _compile_template(current_text)
+                rendered_text = template.render(all_replacements)
+
+                if rendered_text == current_text:
                     return rendered_text
-
-
+
+                previous_text = current_text
+                current_text = rendered_text

-            # If the loop exits without returning, it indicates too much nesting
             raise TemplateRenderError(
                 "Too much nesting - you created an infinite loop here, pal"
             )
@@ -331,6 +342,58 @@
         """Return an example of the prompt."""
         return cls(cls.default_instructions)

+    def get_prompts(self) -> Dict[str, Any]:
+        """Get the prompts for the question."""
+        start = time.time()
+
+        # Build all the components
+        instr_start = time.time()
+        agent_instructions = self.agent_instructions_prompt
+        instr_end = time.time()
+        logger.debug(f"Time taken for agent instructions: {instr_end - instr_start:.4f}s")
+
+        persona_start = time.time()
+        agent_persona = self.agent_persona_prompt
+        persona_end = time.time()
+        logger.debug(f"Time taken for agent persona: {persona_end - persona_start:.4f}s")
+
+        q_instr_start = time.time()
+        question_instructions = self.question_instructions_prompt
+        q_instr_end = time.time()
+        logger.debug(f"Time taken for question instructions: {q_instr_end - q_instr_start:.4f}s")
+
+        memory_start = time.time()
+        prior_question_memory = self.prior_question_memory_prompt
+        memory_end = time.time()
+        logger.debug(f"Time taken for prior question memory: {memory_end - memory_start:.4f}s")
+
+        # Get components dict
+        components = {
+            "agent_instructions": agent_instructions.text,
+            "agent_persona": agent_persona.text,
+            "question_instructions": question_instructions.text,
+            "prior_question_memory": prior_question_memory.text,
+        }
+
+        # Use PromptPlan's get_prompts method
+        plan_start = time.time()
+        prompts = self.prompt_plan.get_prompts(**components)
+        plan_end = time.time()
+        logger.debug(f"Time taken for prompt processing: {plan_end - plan_start:.4f}s")
+
+        # Handle file keys if present
+        if hasattr(self, 'question_file_keys') and self.question_file_keys:
+            files_start = time.time()
+            files_list = []
+            for key in self.question_file_keys:
+                files_list.append(self.scenario[key])
+            prompts["files_list"] = files_list
+            files_end = time.time()
+            logger.debug(f"Time taken for file key processing: {files_end - files_start:.4f}s")
+
+        end = time.time()
+        logger.debug(f"Total time in get_prompts: {end - start:.4f}s")
+        return prompts

 if __name__ == "__main__":
     print("Running doctests...")