edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +24 -14
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +28 -6
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
- edsl/agents/prompt_helpers.py +129 -0
- edsl/config.py +26 -34
- edsl/coop/coop.py +14 -4
- edsl/data_transfer_models.py +26 -73
- edsl/enums.py +2 -0
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +10 -6
- edsl/inference_services/TestService.py +34 -16
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +109 -18
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +130 -49
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
- edsl/jobs/runners/JobsRunnerStatus.py +332 -0
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +36 -38
- edsl/language_models/registry.py +13 -9
- edsl/language_models/utilities.py +5 -2
- edsl/questions/QuestionBase.py +74 -16
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +13 -24
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +11 -6
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +178 -34
- edsl/scenarios/Scenario.py +76 -37
- edsl/scenarios/ScenarioList.py +19 -2
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +34 -1
- edsl/surveys/RuleCollection.py +55 -5
- edsl/surveys/Survey.py +189 -10
- edsl/surveys/base.py +4 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerStatus.py
ADDED
@@ -0,0 +1,332 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass, asdict
+
+from typing import List, DefaultDict, Optional, Type, Literal
+from collections import UserDict, defaultdict
+
+from rich.text import Text
+from rich.box import SIMPLE
+from rich.table import Table
+from rich.live import Live
+from rich.panel import Panel
+from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
+from rich.layout import Layout
+from rich.console import Group
+from rich import box
+
+from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
+from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+from edsl.jobs.tokens.TokenUsage import TokenUsage
+from edsl.enums import get_token_pricing
+from edsl.jobs.tasks.task_status_enum import TaskStatus
+
+InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
+
+from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
+from edsl.jobs.interviews.InterviewStatisticsCollection import (
+    InterviewStatisticsCollection,
+)
+from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
+
+
+@dataclass
+class ModelInfo:
+    model_name: str
+    TPM_limit_k: float
+    RPM_limit_k: float
+    num_tasks_waiting: int
+    token_usage_info: dict
+
+
+@dataclass
+class ModelTokenUsageStats:
+    token_usage_type: str
+    details: List[dict]
+    cost: str
+
+
+class Stats:
+    def elapsed_time(self):
+        InterviewStatistic("elapsed_time", value=elapsed_time, digits=1, units="sec.")
+
+
+class JobsRunnerStatus:
+    def __init__(
+        self, jobs_runner: "JobsRunnerAsyncio", n: int, refresh_rate: float = 0.25
+    ):
+        self.jobs_runner = jobs_runner
+        self.start_time = time.time()
+        self.completed_interviews = []
+        self.refresh_rate = refresh_rate
+        self.statistics = [
+            "elapsed_time",
+            "total_interviews_requested",
+            "completed_interviews",
+            # "percent_complete",
+            "average_time_per_interview",
+            # "task_remaining",
+            "estimated_time_remaining",
+            "exceptions",
+            "unfixed_exceptions",
+            "throughput",
+        ]
+        self.num_total_interviews = n * len(self.jobs_runner.interviews)
+
+        self.distinct_models = list(
+            set(i.model.model for i in self.jobs_runner.interviews)
+        )
+
+        self.completed_interview_by_model = defaultdict(list)
+
+    def add_completed_interview(self, result):
+        self.completed_interviews.append(result.interview_hash)
+
+        relevant_model = result.model.model
+        self.completed_interview_by_model[relevant_model].append(result.interview_hash)
+
+    def _compute_statistic(self, stat_name: str):
+        completed_tasks = self.completed_interviews
+        elapsed_time = time.time() - self.start_time
+        interviews = self.jobs_runner.total_interviews
+
+        stat_definitions = {
+            "elapsed_time": lambda: InterviewStatistic(
+                "elapsed_time", value=elapsed_time, digits=1, units="sec."
+            ),
+            "total_interviews_requested": lambda: InterviewStatistic(
+                "total_interviews_requested", value=len(interviews), units=""
+            ),
+            "completed_interviews": lambda: InterviewStatistic(
+                "completed_interviews", value=len(completed_tasks), units=""
+            ),
+            "percent_complete": lambda: InterviewStatistic(
+                "percent_complete",
+                value=(
+                    len(completed_tasks) / len(interviews) * 100
+                    if len(interviews) > 0
+                    else 0
+                ),
+                digits=1,
+                units="%",
+            ),
+            "average_time_per_interview": lambda: InterviewStatistic(
+                "average_time_per_interview",
+                value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
+                digits=2,
+                units="sec.",
+            ),
+            "task_remaining": lambda: InterviewStatistic(
+                "task_remaining", value=len(interviews) - len(completed_tasks), units=""
+            ),
+            "estimated_time_remaining": lambda: InterviewStatistic(
+                "estimated_time_remaining",
+                value=(
+                    (len(interviews) - len(completed_tasks))
+                    * (elapsed_time / len(completed_tasks))
+                    if len(completed_tasks) > 0
+                    else 0
+                ),
+                digits=1,
+                units="sec.",
+            ),
+            "exceptions": lambda: InterviewStatistic(
+                "exceptions",
+                value=sum(len(i.exceptions) for i in interviews),
+                units="",
+            ),
+            "unfixed_exceptions": lambda: InterviewStatistic(
+                "unfixed_exceptions",
+                value=sum(i.exceptions.num_unfixed() for i in interviews),
+                units="",
+            ),
+            "throughput": lambda: InterviewStatistic(
+                "throughput",
+                value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
+                digits=2,
+                units="interviews/sec.",
+            ),
+        }
+        return stat_definitions[stat_name]()
+
+    def create_progress_bar(self):
+        return Progress(
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            TextColumn("{task.completed}/{task.total}"),
+        )
+
+    def generate_model_queues_table(self):
+        table = Table(show_header=False, box=box.SIMPLE)
+        table.add_column("Info", style="cyan")
+        table.add_column("Value", style="magenta")
+        # table.add_row("Bucket collection", str(self.jobs_runner.bucket_collection))
+        for model, bucket in self.jobs_runner.bucket_collection.items():
+            table.add_row(Text(model.model, style="bold blue"), "")
+            bucket_types = ["requests_bucket", "tokens_bucket"]
+            for bucket_type in bucket_types:
+                table.add_row(Text(" " + bucket_type, style="green"), "")
+                # table.add_row(
+                #     f" Current level (capacity = {round(getattr(bucket, bucket_type).capacity, 3)})",
+                #     str(round(getattr(bucket, bucket_type).tokens, 3)),
+                # )
+                num_requests = getattr(bucket, bucket_type).num_requests
+                num_released = getattr(bucket, bucket_type).num_released
+                tokens_returned = getattr(bucket, bucket_type).tokens_returned
+                # table.add_row(
+                #     f" Requested",
+                #     str(num_requests),
+                # )
+                # table.add_row(
+                #     f" Completed",
+                #     str(num_released),
+                # )
+                table.add_row(
+                    " Completed vs. Requested", f"{num_released} vs. {num_requests}"
+                )
+                table.add_row(
+                    " Added tokens (from cache)",
+                    str(tokens_returned),
+                )
+                if bucket_type == "tokens_bucket":
+                    rate_name = "TPM"
+                else:
+                    rate_name = "RPM"
+                target_rate = round(getattr(bucket, bucket_type).target_rate, 1)
+                table.add_row(
+                    f" Empirical {rate_name} (target = {target_rate})",
+                    str(round(getattr(bucket, bucket_type).get_throughput(), 0)),
+                )
+
+        return table
+
+    def generate_layout(self):
+        progress = self.create_progress_bar()
+        task_ids = []
+        for model in self.distinct_models:
+            task_id = progress.add_task(
+                f"[cyan]{model}...",
+                total=int(self.num_total_interviews / len(self.distinct_models)),
+            )
+            task_ids.append((model, task_id))
+
+        progress_height = min(5, 2 + len(self.distinct_models))
+        layout = Layout()
+
+        # Create the top row with only the progress panel
+        layout.split_column(
+            Layout(
+                Panel(
+                    progress,
+                    title="Interview Progress",
+                    border_style="cyan",
+                    box=box.ROUNDED,
+                ),
+                name="progress",
+                size=progress_height,  # Adjusted size
+            ),
+            Layout(name="bottom_row"),  # Adjusted size
+        )
+
+        # Split the bottom row into two columns for metrics and model queues
+        layout["bottom_row"].split_row(
+            Layout(
+                Panel(
+                    self.generate_metrics_table(),
+                    title="Metrics",
+                    border_style="magenta",
+                    box=box.ROUNDED,
+                ),
+                name="metrics",
+            ),
+            Layout(
+                Panel(
+                    self.generate_model_queues_table(),
+                    title="Model Queues",
+                    border_style="yellow",
+                    box=box.ROUNDED,
+                ),
+                name="model_queues",
+            ),
+        )
+
+        return layout, progress, task_ids
+
+    def generate_metrics_table(self):
+        table = Table(show_header=True, header_style="bold magenta", box=box.SIMPLE)
+        table.add_column("Metric", style="cyan", no_wrap=True)
+        table.add_column("Value", justify="right")
+
+        for stat_name in self.statistics:
+            pretty_name, value = list(self._compute_statistic(stat_name).items())[0]
+            # breakpoint()
+            table.add_row(pretty_name, value)
+        return table
+
+    def update_progress(self, stop_event):
+        layout, progress, task_ids = self.generate_layout()
+
+        with Live(
+            layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
+        ) as live:
+            while (
+                len(self.completed_interviews) < len(self.jobs_runner.total_interviews)
+                and not stop_event.is_set()
+            ):
+                completed_tasks = len(self.completed_interviews)
+                total_tasks = len(self.jobs_runner.total_interviews)
+
+                for model, task_id in task_ids:
+                    completed_tasks = len(self.completed_interview_by_model[model])
+                    progress.update(
+                        task_id,
+                        completed=completed_tasks,
+                        description=f"[cyan]Conducting interviews for {model}...",
+                    )
+
+                layout["metrics"].update(
+                    Panel(
+                        self.generate_metrics_table(),
+                        title="Metrics",
+                        border_style="magenta",
+                        box=box.ROUNDED,
+                    )
+                )
+                layout["model_queues"].update(
+                    Panel(
+                        self.generate_model_queues_table(),
+                        title="Final Model Queues",
+                        border_style="yellow",
+                        box=box.ROUNDED,
+                    )
+                )
+
+                time.sleep(self.refresh_rate)
+
+            # Final update
+            for model, task_id in task_ids:
+                completed_tasks = len(self.completed_interview_by_model[model])
+                progress.update(
+                    task_id,
+                    completed=completed_tasks,
+                    description=f"[cyan]Conducting interviews for {model}...",
+                )
+
+            layout["metrics"].update(
+                Panel(
+                    self.generate_metrics_table(),
+                    title="Final Metrics",
+                    border_style="magenta",
+                    box=box.ROUNDED,
+                )
+            )
+            live.update(layout)
+            time.sleep(1)  # Show final state for 1 second
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
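
For orientation, a minimal sketch of how this new status display might be driven. It assumes `runner` is a JobsRunnerAsyncio instance whose interviews, total_interviews, and bucket_collection are already populated (JobsRunnerAsyncio normally wires this up itself); the variable names are illustrative, not taken from this diff.

    import threading

    from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus

    # `runner` is an assumed, already-configured JobsRunnerAsyncio instance.
    status = JobsRunnerStatus(jobs_runner=runner, n=1, refresh_rate=0.25)

    stop_event = threading.Event()
    progress_thread = threading.Thread(
        target=status.update_progress, args=(stop_event,), daemon=True
    )
    progress_thread.start()

    # ... the runner calls status.add_completed_interview(result) as results arrive ...

    stop_event.set()        # tells update_progress to exit its refresh loop
    progress_thread.join()  # the Rich Live display renders a final state before closing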
edsl/jobs/tasks/QuestionTaskCreator.py
CHANGED
@@ -156,19 +156,6 @@ class QuestionTaskCreator(UserList):
         self.tokens_bucket.turbo_mode_off()
         self.requests_bucket.turbo_mode_off()
 
-        # breakpoint()
-        # _ = results.pop("cached_response", None)
-
-        # tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
-
-        # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
-        # usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
-        # prompt_tokens = usage.get("prompt_tokens", 0)
-        # completion_tokens = usage.get("completion_tokens", 0)
-        # tracker.add_tokens(
-        #     prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
-        # )
-
         return results
 
     @classmethod
@@ -249,6 +236,7 @@ class QuestionTaskCreator(UserList):
                 f"Required tasks failed for {self.question.question_name}"
             ) from e
 
+        # this only runs if all the dependencies are successful
         return await self._run_focal_task()
 
 
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -50,6 +50,18 @@ class TaskHistory:
         """
         return [i.exceptions for k, i in self._interviews.items() if i.exceptions != {}]
 
+    @property
+    def unfixed_exceptions(self):
+        """
+        >>> len(TaskHistory.example().unfixed_exceptions)
+        4
+        """
+        return [
+            i.exceptions
+            for k, i in self._interviews.items()
+            if i.exceptions.num_unfixed() > 0
+        ]
+
     @property
     def indices(self):
         return [k for k, i in self._interviews.items() if i.exceptions != {}]
@@ -78,6 +90,11 @@ class TaskHistory:
         """
         return len(self.exceptions) > 0
 
+    @property
+    def has_unfixed_exceptions(self) -> bool:
+        """Return True if there are any exceptions."""
+        return len(self.unfixed_exceptions) > 0
+
     def _repr_html_(self):
         """Return an HTML representation of the TaskHistory."""
         from edsl.utilities.utilities import data_to_html
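
A quick sketch of how the new properties behave, anchored on the doctest added above (TaskHistory.example() carries four unfixed exceptions):

    from edsl.jobs.tasks.TaskHistory import TaskHistory

    th = TaskHistory.example()
    assert th.has_exceptions                 # pre-existing check: any exceptions at all
    assert th.has_unfixed_exceptions         # new: at least one exception was never fixed
    assert len(th.unfixed_exceptions) == 4   # matches the doctest in unfixed_exceptions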
edsl/language_models/LanguageModel.py
CHANGED
@@ -164,20 +164,20 @@ class LanguageModel(
         None  # This should be something like ["choices", 0, "message", "content"]
     )
     __rate_limits = None
-    __default_rate_limits = {
-        "rpm": 10_000,
-        "tpm": 2_000_000,
-    }  # TODO: Use the OpenAI Teir 1 rate limits
     _safety_factor = 0.8
 
-    def __init__(
+    def __init__(
+        self, tpm=None, rpm=None, omit_system_prompt_if_empty_string=True, **kwargs
+    ):
         """Initialize the LanguageModel."""
         self.model = getattr(self, "_model_", None)
         default_parameters = getattr(self, "_parameters_", None)
         parameters = self._overide_default_parameters(kwargs, default_parameters)
         self.parameters = parameters
         self.remote = False
+        self.omit_system_prompt_if_empty = omit_system_prompt_if_empty_string
 
+        # self._rpm / _tpm comes from the class
         if rpm is not None:
             self._rpm = rpm
 
@@ -286,35 +286,40 @@
         >>> m.RPM
         100
         """
-
+        if rpm is not None:
+            self._rpm = rpm
+        if tpm is not None:
+            self._tpm = tpm
+        return None
+        # self._set_rate_limits(rpm=rpm, tpm=tpm)
 
-    def _set_rate_limits(self, rpm=None, tpm=None) -> None:
-
+    # def _set_rate_limits(self, rpm=None, tpm=None) -> None:
+    #     """Set the rate limits for the model.
 
-
-
-
-
+    #     If the model does not have rate limits, use the default rate limits."""
+    #     if rpm is not None and tpm is not None:
+    #         self.__rate_limits = {"rpm": rpm, "tpm": tpm}
+    #         return
 
-
-
-
-
-
+    #     if self.__rate_limits is None:
+    #         if hasattr(self, "get_rate_limits"):
+    #             self.__rate_limits = self.get_rate_limits()
+    #         else:
+    #             self.__rate_limits = self.__default_rate_limits
 
     @property
     def RPM(self):
         """Model's requests-per-minute limit."""
         # self._set_rate_limits()
         # return self._safety_factor * self.__rate_limits["rpm"]
-        return self.
+        return self._rpm
 
     @property
     def TPM(self):
         """Model's tokens-per-minute limit."""
         # self._set_rate_limits()
         # return self._safety_factor * self.__rate_limits["tpm"]
-        return self.
+        return self._tpm
 
     @property
     def rpm(self):
@@ -332,17 +337,6 @@
     def tpm(self, value):
         self._tpm = value
 
-    @property
-    def TPM(self):
-        """Model's tokens-per-minute limit.
-
-        >>> m = LanguageModel.example()
-        >>> m.TPM > 0
-        True
-        """
-        self._set_rate_limits()
-        return self._safety_factor * self.__rate_limits["tpm"]
-
     @staticmethod
     def _overide_default_parameters(passed_parameter_dict, default_parameter_dict):
         """Return a dictionary of parameters, with passed parameters taking precedence over defaults.
@@ -446,7 +440,7 @@
         system_prompt: str,
         cache: "Cache",
         iteration: int = 0,
-
+        files_list=None,
    ) -> ModelResponse:
        """Handle caching of responses.
 
@@ -468,15 +462,18 @@
        >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
        ModelResponse(...)"""
 
-        if
-
-
+        if files_list:
+            files_hash = "+".join([str(hash(file)) for file in files_list])
+            # print(f"Files hash: {files_hash}")
+            user_prompt_with_hashes = user_prompt + f" {files_hash}"
+        else:
+            user_prompt_with_hashes = user_prompt
 
        cache_call_params = {
            "model": str(self.model),
            "parameters": self.parameters,
            "system_prompt": system_prompt,
-            "user_prompt":
+            "user_prompt": user_prompt_with_hashes,
            "iteration": iteration,
        }
        cached_response, cache_key = cache.fetch(**cache_call_params)
@@ -492,7 +489,8 @@
            params = {
                "user_prompt": user_prompt,
                "system_prompt": system_prompt,
-
+                "files_list": files_list
+                # **({"encoded_image": encoded_image} if encoded_image else {}),
            }
            # response = await f(**params)
            response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
@@ -536,7 +534,7 @@
        system_prompt: str,
        cache: "Cache",
        iteration: int = 1,
-
+        files_list: Optional[List["File"]] = None,
    ) -> dict:
        """Get response, parse, and return as string.
 
@@ -552,7 +550,7 @@
            "system_prompt": system_prompt,
            "iteration": iteration,
            "cache": cache,
-
+            "files_list": files_list,
        }
        model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
        model_outputs = await self._async_get_intended_model_call_outcome(**params)
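
Net effect of the LanguageModel changes: rate limits are now plain instance attributes rather than lookups against a default rate-limit table, and file attachments are folded into the cache key. A rough sketch, using the LanguageModel.example() helper that appears in the (now removed) TPM doctest; the rpm setter is assumed to mirror the tpm setter shown above:

    from edsl.language_models.LanguageModel import LanguageModel

    m = LanguageModel.example()
    m.rpm = 100          # the rpm/tpm setters write straight to self._rpm / self._tpm
    assert m.RPM == 100  # RPM/TPM simply return the stored values

    # When files_list is passed, each file's hash is appended to the user prompt before
    # the cache lookup, so identical text with different attachments gets a new cache key.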
edsl/language_models/registry.py
CHANGED
@@ -2,10 +2,10 @@ import textwrap
 from random import random
 from edsl.config import CONFIG
 
-if "EDSL_DEFAULT_MODEL" not in CONFIG:
-
-else:
-
+# if "EDSL_DEFAULT_MODEL" not in CONFIG:
+#     default_model = "test"
+# else:
+#     default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
 
 
 def get_model_class(model_name, registry=None):
@@ -33,20 +33,24 @@ class Meta(type):
 
 
 class Model(metaclass=Meta):
-    default_model =
+    default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
 
-    def __new__(
+    def __new__(
+        cls, model_name=None, registry=None, service_name=None, *args, **kwargs
+    ):
         # Map index to the respective subclass
         if model_name is None:
-            model_name =
+            model_name = (
+                cls.default_model
+            )  # when model_name is None, use the default model, set in the config file
         from edsl.inference_services.registry import default
 
         registry = registry or default
 
-        if isinstance(model_name, int):
+        if isinstance(model_name, int):  # can refer to a model by index
             model_name = cls.available(name_only=True)[model_name]
 
-        factory = registry.create_model_factory(model_name)
+        factory = registry.create_model_factory(model_name, service_name=service_name)
         return factory(*args, **kwargs)
 
     @classmethod
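
The new service_name argument is passed straight through to registry.create_model_factory(), so a model name can be pinned to a particular inference service. A short sketch (the model and service names below are illustrative, not taken from this diff):

    from edsl.language_models.registry import Model

    m_default = Model()    # no name: falls back to CONFIG's EDSL_DEFAULT_MODEL
    m_by_index = Model(0)  # an int still indexes into Model.available(name_only=True)
    m_pinned = Model("gpt-4o", service_name="openai")  # resolved via the named service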
edsl/language_models/utilities.py
CHANGED
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any
+from typing import Any, Optional, List
 from edsl import Survey
 from edsl.config import CONFIG
 from edsl.enums import InferenceServiceType
@@ -40,7 +40,10 @@ def create_language_model(
         _tpm = 1000000000000
 
         async def async_execute_model_call(
-            self,
+            self,
+            user_prompt: str,
+            system_prompt: str,
+            files_list: Optional[List[Any]] = None,
         ) -> dict[str, Any]:
             question_number = int(
                 user_prompt.split("XX")[1]
|