edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- edsl/Base.py +9 -3
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +6 -3
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
- edsl/config.py +26 -34
- edsl/coop/coop.py +11 -2
- edsl/data_transfer_models.py +27 -73
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +1 -1
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/OpenAIService.py +7 -4
- edsl/inference_services/TestService.py +24 -15
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +18 -8
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +115 -47
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +26 -31
- edsl/language_models/registry.py +13 -9
- edsl/questions/QuestionBase.py +64 -16
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +11 -26
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +6 -5
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +86 -21
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +13 -0
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +10 -1
- edsl/surveys/RuleCollection.py +19 -3
- edsl/surveys/Survey.py +7 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerStatus.py
ADDED
@@ -0,0 +1,331 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass, asdict
+
+from typing import List, DefaultDict, Optional, Type, Literal
+from collections import UserDict, defaultdict
+
+from rich.text import Text
+from rich.box import SIMPLE
+from rich.table import Table
+from rich.live import Live
+from rich.panel import Panel
+from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
+from rich.layout import Layout
+from rich.console import Group
+from rich import box
+
+from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
+from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+from edsl.jobs.tokens.TokenUsage import TokenUsage
+from edsl.enums import get_token_pricing
+from edsl.jobs.tasks.task_status_enum import TaskStatus
+
+InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
+
+from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
+from edsl.jobs.interviews.InterviewStatisticsCollection import (
+    InterviewStatisticsCollection,
+)
+from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+
+
+@dataclass
+class ModelInfo:
+    model_name: str
+    TPM_limit_k: float
+    RPM_limit_k: float
+    num_tasks_waiting: int
+    token_usage_info: dict
+
+
+@dataclass
+class ModelTokenUsageStats:
+    token_usage_type: str
+    details: List[dict]
+    cost: str
+
+
+class Stats:
+    def elapsed_time(self):
+        InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
+
+
+class JobsRunnerStatus:
+    def __init__(
+        self, jobs_runner: "JobsRunnerAsyncio", n: int, refresh_rate: float = 0.25
+    ):
+        self.jobs_runner = jobs_runner
+        self.start_time = time.time()
+        self.completed_interviews = []
+        self.refresh_rate = refresh_rate
+        self.statistics = [
+            "elapsed_time",
+            "total_interviews_requested",
+            "completed_interviews",
+            # "percent_complete",
+            "average_time_per_interview",
+            # "task_remaining",
+            "estimated_time_remaining",
+            "exceptions",
+            "unfixed_exceptions",
+            "throughput",
+        ]
+        self.num_total_interviews = n * len(self.jobs_runner.interviews)
+
+        self.distinct_models = list(
+            set(i.model.model for i in self.jobs_runner.interviews)
+        )
+
+        self.completed_interview_by_model = defaultdict(list)
+
+    def add_completed_interview(self, result):
+        self.completed_interviews.append(result.interview_hash)
+
+        relevant_model = result.model.model
+        self.completed_interview_by_model[relevant_model].append(result.interview_hash)
+
+    def _compute_statistic(self, stat_name: str):
+        completed_tasks = self.completed_interviews
+        elapsed_time = time.time() - self.start_time
+        interviews = self.jobs_runner.total_interviews
+
+        stat_definitions = {
+            "elapsed_time": lambda: InterviewStatistic(
+                "elapsed_time", value=elapsed_time, digits=1, units="sec."
+            ),
+            "total_interviews_requested": lambda: InterviewStatistic(
+                "total_interviews_requested", value=len(interviews), units=""
+            ),
+            "completed_interviews": lambda: InterviewStatistic(
+                "completed_interviews", value=len(completed_tasks), units=""
+            ),
+            "percent_complete": lambda: InterviewStatistic(
+                "percent_complete",
+                value=(
+                    len(completed_tasks) / len(interviews) * 100
+                    if len(interviews) > 0
+                    else 0
+                ),
+                digits=1,
+                units="%",
+            ),
+            "average_time_per_interview": lambda: InterviewStatistic(
+                "average_time_per_interview",
+                value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
+                digits=2,
+                units="sec.",
+            ),
+            "task_remaining": lambda: InterviewStatistic(
+                "task_remaining", value=len(interviews) - len(completed_tasks), units=""
+            ),
+            "estimated_time_remaining": lambda: InterviewStatistic(
+                "estimated_time_remaining",
+                value=(
+                    (len(interviews) - len(completed_tasks))
+                    * (elapsed_time / len(completed_tasks))
+                    if len(completed_tasks) > 0
+                    else 0
+                ),
+                digits=1,
+                units="sec.",
+            ),
+            "exceptions": lambda: InterviewStatistic(
+                "exceptions",
+                value=sum(len(i.exceptions) for i in interviews),
+                units="",
+            ),
+            "unfixed_exceptions": lambda: InterviewStatistic(
+                "unfixed_exceptions",
+                value=sum(i.exceptions.num_unfixed() for i in interviews),
+                units="",
+            ),
+            "throughput": lambda: InterviewStatistic(
+                "throughput",
+                value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
+                digits=2,
+                units="interviews/sec.",
+            ),
+        }
+        return stat_definitions[stat_name]()
+
+    def create_progress_bar(self):
+        return Progress(
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            TextColumn("{task.completed}/{task.total}"),
+        )
+
+    def generate_model_queues_table(self):
+        table = Table(show_header=False, box=box.SIMPLE)
+        table.add_column("Info", style="cyan")
+        table.add_column("Value", style="magenta")
+        # table.add_row("Bucket collection", str(self.jobs_runner.bucket_collection))
+        for model, bucket in self.jobs_runner.bucket_collection.items():
+            table.add_row(Text(model.model, style="bold blue"), "")
+            bucket_types = ["requests_bucket", "tokens_bucket"]
+            for bucket_type in bucket_types:
+                table.add_row(Text(" " + bucket_type, style="green"), "")
+                # table.add_row(
+                #     f" Current level (capacity = {round(getattr(bucket, bucket_type).capacity, 3)})",
+                #     str(round(getattr(bucket, bucket_type).tokens, 3)),
+                # )
+                num_requests = getattr(bucket, bucket_type).num_requests
+                num_released = getattr(bucket, bucket_type).num_released
+                tokens_returned = getattr(bucket, bucket_type).tokens_returned
+                # table.add_row(
+                #     f" Requested",
+                #     str(num_requests),
+                # )
+                # table.add_row(
+                #     f" Completed",
+                #     str(num_released),
+                # )
+                table.add_row(
+                    " Completed vs. Requested", f"{num_released} vs. {num_requests}"
+                )
+                table.add_row(
+                    " Added tokens (from cache)",
+                    str(tokens_returned),
+                )
+                if bucket_type == "tokens_bucket":
+                    rate_name = "TPM"
+                else:
+                    rate_name = "RPM"
+                target_rate = round(getattr(bucket, bucket_type).target_rate, 1)
+                table.add_row(
+                    f" Empirical {rate_name} (target = {target_rate})",
+                    str(round(getattr(bucket, bucket_type).get_throughput(), 0)),
+                )
+
+        return table
+
+    def generate_layout(self):
+        progress = self.create_progress_bar()
+        task_ids = []
+        for model in self.distinct_models:
+            task_id = progress.add_task(
+                f"[cyan]{model}...",
+                total=int(self.num_total_interviews / len(self.distinct_models)),
+            )
+            task_ids.append((model, task_id))
+
+        progress_height = min(5, 2 + len(self.distinct_models))
+        layout = Layout()
+
+        # Create the top row with only the progress panel
+        layout.split_column(
+            Layout(
+                Panel(
+                    progress,
+                    title="Interview Progress",
+                    border_style="cyan",
+                    box=box.ROUNDED,
+                ),
+                name="progress",
+                size=progress_height,  # Adjusted size
+            ),
+            Layout(name="bottom_row"),  # Adjusted size
+        )
+
+        # Split the bottom row into two columns for metrics and model queues
+        layout["bottom_row"].split_row(
+            Layout(
+                Panel(
+                    self.generate_metrics_table(),
+                    title="Metrics",
+                    border_style="magenta",
+                    box=box.ROUNDED,
+                ),
+                name="metrics",
+            ),
+            Layout(
+                Panel(
+                    self.generate_model_queues_table(),
+                    title="Model Queues",
+                    border_style="yellow",
+                    box=box.ROUNDED,
+                ),
+                name="model_queues",
+            ),
+        )
+
+        return layout, progress, task_ids
+
+    def generate_metrics_table(self):
+        table = Table(show_header=True, header_style="bold magenta", box=box.SIMPLE)
+        table.add_column("Metric", style="cyan", no_wrap=True)
+        table.add_column("Value", justify="right")
+
+        for stat_name in self.statistics:
+            pretty_name, value = list(self._compute_statistic(stat_name).items())[0]
+            # breakpoint()
+            table.add_row(pretty_name, value)
+        return table
+
+    def update_progress(self):
+        layout, progress, task_ids = self.generate_layout()
+
+        with Live(
+            layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
+        ) as live:
+            while len(self.completed_interviews) < len(
+                self.jobs_runner.total_interviews
+            ):
+                completed_tasks = len(self.completed_interviews)
+                total_tasks = len(self.jobs_runner.total_interviews)
+
+                for model, task_id in task_ids:
+                    completed_tasks = len(self.completed_interview_by_model[model])
+                    progress.update(
+                        task_id,
+                        completed=completed_tasks,
+                        description=f"[cyan]Conducting interviews for {model}...",
+                    )
+
+                layout["metrics"].update(
+                    Panel(
+                        self.generate_metrics_table(),
+                        title="Metrics",
+                        border_style="magenta",
+                        box=box.ROUNDED,
+                    )
+                )
+                layout["model_queues"].update(
+                    Panel(
+                        self.generate_model_queues_table(),
+                        title="Final Model Queues",
+                        border_style="yellow",
+                        box=box.ROUNDED,
+                    )
+                )
+
+                time.sleep(self.refresh_rate)
+
+            # Final update
+            for model, task_id in task_ids:
+                completed_tasks = len(self.completed_interview_by_model[model])
+                progress.update(
+                    task_id,
+                    completed=completed_tasks,
+                    description=f"[cyan]Conducting interviews for {model}...",
+                )
+
+            layout["metrics"].update(
+                Panel(
+                    self.generate_metrics_table(),
+                    title="Final Metrics",
+                    border_style="magenta",
+                    box=box.ROUNDED,
+                )
+            )
+            live.update(layout)
+            time.sleep(1)  # Show final state for 1 second
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -50,6 +50,18 @@ class TaskHistory:
         """
         return [i.exceptions for k, i in self._interviews.items() if i.exceptions != {}]
 
+    @property
+    def unfixed_exceptions(self):
+        """
+        >>> len(TaskHistory.example().unfixed_exceptions)
+        4
+        """
+        return [
+            i.exceptions
+            for k, i in self._interviews.items()
+            if i.exceptions.num_unfixed() > 0
+        ]
+
     @property
     def indices(self):
         return [k for k, i in self._interviews.items() if i.exceptions != {}]
@@ -78,6 +90,11 @@ class TaskHistory:
         """
         return len(self.exceptions) > 0
 
+    @property
+    def has_unfixed_exceptions(self) -> bool:
+        """Return True if there are any exceptions."""
+        return len(self.unfixed_exceptions) > 0
+
     def _repr_html_(self):
         """Return an HTML representation of the TaskHistory."""
         from edsl.utilities.utilities import data_to_html
edsl/language_models/LanguageModel.py
CHANGED
@@ -164,20 +164,20 @@ class LanguageModel(
         None  # This should be something like ["choices", 0, "message", "content"]
     )
     __rate_limits = None
-    __default_rate_limits = {
-        "rpm": 10_000,
-        "tpm": 2_000_000,
-    }  # TODO: Use the OpenAI Teir 1 rate limits
     _safety_factor = 0.8
 
-    def __init__(
+    def __init__(
+        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+    ):
         """Initialize the LanguageModel."""
         self.model = getattr(self, "_model_", None)
         default_parameters = getattr(self, "_parameters_", None)
         parameters = self._overide_default_parameters(kwargs, default_parameters)
         self.parameters = parameters
         self.remote = False
+        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
 
+        # self._rpm / _tpm comes from the class
         if rpm is not None:
            self._rpm = rpm
 
@@ -286,35 +286,40 @@ class LanguageModel(
         >>> m.RPM
         100
         """
-
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+        # self._set_rate_limits(rpm=rpm, tpm=tpm)
 
-    def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-
+    # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+    #     """Set the rate limits for the model.
 
-
-
-
-
+    #     If the model does not have rate limits, use the default rate limits."""
+    #     if rpm is not None and tpm is not None:
+    #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+    #         return
 
-
-
-
-
-
+    #     if self.__rate_limits is None:
+    #         if hasattr(self, "get_rate_limits"):
+    #             self.__rate_limits = self.get_rate_limits()
+    #         else:
+    #             self.__rate_limits = self.__default_rate_limits
 
     @property
     def RPM(self):
         """Model's requests-per-minute limit."""
         # self._set_rate_limits()
         # return self._safety_factor * self.__rate_limits["rpm"]
-        return self.
+        return self._rpm
 
     @property
     def TPM(self):
         """Model's tokens-per-minute limit."""
         # self._set_rate_limits()
         # return self._safety_factor * self.__rate_limits["tpm"]
-        return self.
+        return self._tpm
 
     @property
     def rpm(self):
@@ -332,17 +337,6 @@ class LanguageModel(
     def tpm(self, value):
         self._tpm = value
 
-    @property
-    def TPM(self):
-        """Model's tokens-per-minute limit.
-
-        >>> m = LanguageModel.example()
-        >>> m.TPM > 0
-        True
-        """
-        self._set_rate_limits()
-        return self._safety_factor * self.__rate_limits["tpm"]
-
     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
         """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -471,12 +465,13 @@ class LanguageModel(
         if encoded_image:
             # the image has is appended to the user_prompt for hash-lookup purposes
             image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
+            user_prompt += f" {image_hash}"
 
         cache_call_params = {
             "model": str(self.model),
             "parameters": self.parameters,
             "system_prompt": system_prompt,
-            "user_prompt": user_prompt
+            "user_prompt": user_prompt,
             "iteration": iteration,
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
edsl/language_models/registry.py
CHANGED
@@ -2,10 +2,10 @@ import textwrap
 from random import random
 from edsl.config import CONFIG
 
-if "EDSL_DEFAULT_MODEL" not in CONFIG:
-
-else:
-
+# if "EDSL_DEFAULT_MODEL" not in CONFIG:
+#     default_model = "test"
+# else:
+#     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
 
 
 def get_model_class(model_name, registry=None):
@@ -33,20 +33,24 @@ class Meta(type):
 
 
 class Model(metaclass=Meta):
-    default_model =
+    default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
 
-    def __new__(
+    def __new__(
+        cls, model_name=None, registry=None, service_name=None, *args, **kwargs
+    ):
         # Map index to the respective subclass
         if model_name is None:
-            model_name =
+            model_name = (
+                cls.default_model
+            )  # when model_name is None, use the default model, set in the config file
         from edsl.inference_services.registry import default
 
         registry = registry or default
 
-        if isinstance(model_name, int):
+        if isinstance(model_name, int):  # can refer to a model by index
             model_name = cls.available(name_only=True)[model_name]
 
-        factory = registry.create_model_factory(model_name)
+        factory = registry.create_model_factory(model_name, service_name=service_name)
         return factory(*args, **kwargs)
 
     @classmethod
edsl/questions/QuestionBase.py
CHANGED
@@ -75,8 +75,7 @@ class QuestionBase(
         if not hasattr(self, "_fake_data_factory"):
             from polyfactory.factories.pydantic_factory import ModelFactory
 
-            class FakeData(ModelFactory[self.response_model]):
-                ...
+            class FakeData(ModelFactory[self.response_model]): ...
 
             self._fake_data_factory = FakeData
         return self._fake_data_factory
@@ -103,13 +102,8 @@ class QuestionBase(
         """Validate the answer.
         >>> from edsl.exceptions import QuestionAnswerValidationError
         >>> from edsl import QuestionFreeText as Q
-        >>> Q.example()._validate_answer({'answer': 'Hello'})
-        {'answer': 'Hello', 'generated_tokens':
-        >>> Q.example()._validate_answer({'shmanswer': 1})
-        Traceback (most recent call last):
-        ...
-        edsl.exceptions.questions.QuestionAnswerValidationError:...
-        ...
+        >>> Q.example()._validate_answer({'answer': 'Hello', 'generated_tokens': 'Hello'})
+        {'answer': 'Hello', 'generated_tokens': 'Hello'}
         """
 
         return self.response_validator.validate(answer)
@@ -471,6 +465,7 @@ class QuestionBase(
         self,
         scenario: Optional[dict] = None,
         agent: Optional[dict] = {},
+        answers: Optional[dict] = None,
         include_question_name: bool = False,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -482,6 +477,17 @@ class QuestionBase(
         if scenario is None:
             scenario = {}
 
+        prior_answers_dict = {}
+
+        if isinstance(answers, dict):
+            for key, value in answers.items():
+                if not key.endswith("_comment") and not key.endswith(
+                    "_generated_tokens"
+                ):
+                    prior_answers_dict[key] = {"answer": value}
+
+        # breakpoint()
+
         base_template = """
         <div id="{{ question_name }}" class="survey_question" data-type="{{ question_type }}">
         {% if include_question_name %}
@@ -501,13 +507,40 @@ class QuestionBase(
 
         base_template = Template(base_template)
 
-
-        "
-        "
-
-
-
+        context = {
+            "scenario": scenario,
+            "agent": agent,
+        } | prior_answers_dict
+
+        # Render the question text
+        try:
+            question_text = Template(self.question_text).render(context)
+        except Exception as e:
+            print(
+                f"Error rendering question: question_text = {self.question_text}, error = {e}"
+            )
+            question_text = self.question_text
+
+        try:
+            question_content = Template(question_content).render(context)
+        except Exception as e:
+            print(
+                f"Error rendering question: question_content = {question_content}, error = {e}"
+            )
+            question_content = question_content
+
+        try:
+            params = {
+                "question_name": self.question_name,
+                "question_text": question_text,
+                "question_type": self.question_type,
+                "question_content": question_content,
+                "include_question_name": include_question_name,
+            }
+        except Exception as e:
+            raise ValueError(
+                f"Error rendering question: params = {params}, error = {e}"
+            )
         rendered_html = base_template.render(**params)
 
         if iframe:
@@ -526,6 +559,21 @@ class QuestionBase(
 
         return rendered_html
 
+    @classmethod
+    def example_model(cls):
+        from edsl import Model
+
+        q = cls.example()
+        m = Model("test", canned_response=cls._simulate_answer(q)["answer"])
+
+        return m
+
+    @classmethod
+    def example_results(cls):
+        m = cls.example_model()
+        q = cls.example()
+        return q.by(m).run(cache=False)
+
     def rich_print(self):
         """Print the question in a rich format."""
         from rich.table import Table