edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- edsl/Base.py +60 -31
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +18 -9
- edsl/agents/AgentList.py +59 -8
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/config.py +8 -0
- edsl/coop/coop.py +74 -7
- edsl/data/Cache.py +27 -2
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +88 -548
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
- edsl/jobs/runners/JobsRunnerStatus.py +0 -2
- edsl/jobs/tasks/TaskHistory.py +15 -16
- edsl/language_models/LanguageModel.py +44 -84
- edsl/language_models/ModelList.py +47 -1
- edsl/language_models/registry.py +57 -4
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +20 -16
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +18 -9
- edsl/results/Results.py +145 -51
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +61 -4
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +237 -62
- edsl/surveys/Survey.py +16 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/Instruction.py +12 -0
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
@@ -4,6 +4,7 @@ import asyncio
 import threading
 import warnings
 from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
+from uuid import UUID
 from collections import UserList

 from edsl.results.Results import Results
@@ -36,6 +37,8 @@ class JobsRunnerAsyncio:
     The Jobs object is a collection of interviews that are to be run.
     """

+    MAX_CONCURRENT_DEFAULT = 500
+
     def __init__(self, jobs: "Jobs"):
         self.jobs = jobs
         self.interviews: List["Interview"] = jobs.interviews()
@@ -43,6 +46,53 @@ class JobsRunnerAsyncio:
         self.total_interviews: List["Interview"] = []
         self._initialized = threading.Event()

+        from edsl.config import CONFIG
+
+        self.MAX_CONCURRENT = int(CONFIG.get("EDSL_MAX_CONCURRENT_TASKS"))
+        # print(f"MAX_CONCURRENT: {self.MAX_CONCURRENT}")
+
+    # async def run_async_generator(
+    #     self,
+    #     cache: Cache,
+    #     n: int = 1,
+    #     stop_on_exception: bool = False,
+    #     sidecar_model: Optional[LanguageModel] = None,
+    #     total_interviews: Optional[List["Interview"]] = None,
+    #     raise_validation_errors: bool = False,
+    # ) -> AsyncGenerator["Result", None]:
+    #     """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
+
+    #     Completed tasks are yielded as they are completed.
+
+    #     :param n: how many times to run each interview
+    #     :param stop_on_exception: Whether to stop the interview if an exception is raised
+    #     :param sidecar_model: a language model to use in addition to the interview's model
+    #     :param total_interviews: A list of interviews to run can be provided instead.
+    #     :param raise_validation_errors: Whether to raise validation errors
+    #     """
+    #     tasks = []
+    #     if total_interviews:  # was already passed in total interviews
+    #         self.total_interviews = total_interviews
+    #     else:
+    #         self.total_interviews = list(
+    #             self._populate_total_interviews(n=n)
+    #         )  # Populate self.total_interviews before creating tasks
+    #     self._initialized.set()  # Signal that we're ready
+
+    #     for interview in self.total_interviews:
+    #         interviewing_task = self._build_interview_task(
+    #             interview=interview,
+    #             stop_on_exception=stop_on_exception,
+    #             sidecar_model=sidecar_model,
+    #             raise_validation_errors=raise_validation_errors,
+    #         )
+    #         tasks.append(asyncio.create_task(interviewing_task))
+
+    #     for task in asyncio.as_completed(tasks):
+    #         result = await task
+    #         self.jobs_runner_status.add_completed_interview(result)
+    #         yield result
+
     async def run_async_generator(
         self,
         cache: Cache,
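Note on the constructor change above: the concurrency cap is now read from edsl's `CONFIG` at construction time, with the new class attribute `MAX_CONCURRENT_DEFAULT = 500` as the apparent fallback. A minimal sketch of the same pattern, reading a plain environment variable in place of the `CONFIG` object (the helper name is hypothetical):

```python
import os

MAX_CONCURRENT_DEFAULT = 500  # class-level default, mirroring the diff

def get_max_concurrent() -> int:
    # EDSL_MAX_CONCURRENT_TASKS is the setting name used in the diff;
    # os.environ stands in for edsl's CONFIG.get(...) here.
    return int(os.environ.get("EDSL_MAX_CONCURRENT_TASKS", MAX_CONCURRENT_DEFAULT))

print(get_max_concurrent())  # 500 unless the environment overrides it
```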
@@ -52,9 +102,10 @@ class JobsRunnerAsyncio:
         total_interviews: Optional[List["Interview"]] = None,
         raise_validation_errors: bool = False,
     ) -> AsyncGenerator["Result", None]:
-        """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
+        """Creates and processes tasks asynchronously, yielding results as they complete.

-        Completed tasks are yielded as they are completed.
+        Tasks are created and processed in a streaming fashion rather than building the full list upfront.
+        Results are yielded as soon as they are available.

         :param n: how many times to run each interview
         :param stop_on_exception: Whether to stop the interview if an exception is raised
@@ -62,29 +113,70 @@ class JobsRunnerAsyncio:
         :param total_interviews: A list of interviews to run can be provided instead.
         :param raise_validation_errors: Whether to raise validation errors
         """
-        tasks = []
-        if total_interviews:
+        # Initialize interviews iterator
+        if total_interviews:
+            interviews_iter = iter(total_interviews)
             self.total_interviews = total_interviews
         else:
-            self.total_interviews = list(
-                self._populate_total_interviews(n=n)
-            )  # Populate self.total_interviews before creating tasks
+            interviews_iter = self._populate_total_interviews(n=n)
+            self.total_interviews = list(interviews_iter)
+            interviews_iter = iter(self.total_interviews)  # Create fresh iterator

         self._initialized.set()  # Signal that we're ready

-        for interview in self.total_interviews:
-            interviewing_task = self._build_interview_task(
-                interview=interview,
-                stop_on_exception=stop_on_exception,
-                sidecar_model=sidecar_model,
-                raise_validation_errors=raise_validation_errors,
-            )
-            tasks.append(asyncio.create_task(interviewing_task))
+        # Keep track of active tasks
+        active_tasks = set()

-        for task in asyncio.as_completed(tasks):
-            result = await task
-            self.jobs_runner_status.add_completed_interview(result)
-            yield result
+        try:
+            while True:
+                # Add new tasks if we're below max_concurrent and there are more interviews
+                while len(active_tasks) < self.MAX_CONCURRENT:
+                    try:
+                        interview = next(interviews_iter)
+                        task = asyncio.create_task(
+                            self._build_interview_task(
+                                interview=interview,
+                                stop_on_exception=stop_on_exception,
+                                sidecar_model=sidecar_model,
+                                raise_validation_errors=raise_validation_errors,
+                            )
+                        )
+                        active_tasks.add(task)
+                        # Add callback to remove task from set when done
+                        task.add_done_callback(active_tasks.discard)
+                    except StopIteration:
+                        break
+
+                if not active_tasks:
+                    break
+
+                # Wait for next completed task
+                done, _ = await asyncio.wait(
+                    active_tasks, return_when=asyncio.FIRST_COMPLETED
+                )
+
+                # Process completed tasks
+                for task in done:
+                    try:
+                        result = await task
+                        self.jobs_runner_status.add_completed_interview(result)
+                        yield result
+                    except Exception as e:
+                        if stop_on_exception:
+                            # Cancel remaining tasks
+                            for t in active_tasks:
+                                if not t.done():
+                                    t.cancel()
+                            raise
+                        else:
+                            # Log error and continue
+                            # logger.error(f"Task failed with error: {e}")
+                            continue
+        finally:
+            # Ensure we cancel any remaining tasks if we exit early
+            for task in active_tasks:
+                if not task.done():
+                    task.cancel()

     def _populate_total_interviews(
         self, n: int = 1
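This hunk is the core of the change: rather than creating a task per interview up front and draining `asyncio.as_completed`, the runner keeps a bounded pool that is topped up from an iterator, so at most `MAX_CONCURRENT` interviews are in flight at once. A self-contained sketch of that scheduling pattern (illustrative names, not edsl's API):

```python
import asyncio
from typing import AsyncGenerator, Awaitable, Callable, Iterable

async def run_bounded(
    jobs: Iterable[int],
    worker: Callable[[int], Awaitable[int]],
    max_concurrent: int = 3,
) -> AsyncGenerator[int, None]:
    jobs_iter = iter(jobs)
    active: set = set()
    try:
        while True:
            # Top up the pool until it hits the cap or the iterator runs dry.
            while len(active) < max_concurrent:
                try:
                    job = next(jobs_iter)
                except StopIteration:
                    break
                task = asyncio.create_task(worker(job))
                active.add(task)
                task.add_done_callback(active.discard)
            if not active:
                break
            # Yield each result as soon as any in-flight task completes.
            done, _ = await asyncio.wait(active, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                yield await task
    finally:
        # Cancel whatever is still running if the consumer exits early.
        for task in active:
            if not task.done():
                task.cancel()

async def main() -> None:
    async def worker(n: int) -> int:
        await asyncio.sleep(0.01)
        return n * n

    async for result in run_bounded(range(10), worker):
        print(result)

asyncio.run(main())
```

One design note: because `add_done_callback(active.discard)` prunes finished tasks from the pool, the `done` set returned by `asyncio.wait` is where results are collected, which is why the rewrite awaits tasks there instead of using `as_completed`.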
@@ -168,20 +260,20 @@ class JobsRunnerAsyncio:

         prompt_dictionary = {}
         for answer_key_name in answer_key_names:
-            prompt_dictionary[
-
-
-            prompt_dictionary[
-
-
+            prompt_dictionary[
+                answer_key_name + "_user_prompt"
+            ] = question_name_to_prompts[answer_key_name]["user_prompt"]
+            prompt_dictionary[
+                answer_key_name + "_system_prompt"
+            ] = question_name_to_prompts[answer_key_name]["system_prompt"]

         raw_model_results_dictionary = {}
         cache_used_dictionary = {}
         for result in valid_results:
             question_name = result.question_name
-            raw_model_results_dictionary[
-
-
+            raw_model_results_dictionary[
+                question_name + "_raw_model_response"
+            ] = result.raw_model_response
             raw_model_results_dictionary[question_name + "_cost"] = result.cost
             one_use_buys = (
                 "NA"
@@ -245,11 +337,25 @@ class JobsRunnerAsyncio:
             if len(results.task_history.indices) > 5:
                 msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"

-
-
+            import sys
+
+            print(msg, file=sys.stderr)
+            from edsl.config import CONFIG
+
+            if CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "True":
+                open_in_browser = True
+            elif CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "False":
+                open_in_browser = False
+            else:
+                raise Exception(
+                    "EDSL_OPEN_EXCEPTION_REPORT_URL", "must be either True or False"
+                )
+
+            # print("open_in_browser", open_in_browser)
+
             filepath = results.task_history.html(
                 cta="Open report to see details.",
-                open_in_browser=
+                open_in_browser=open_in_browser,
                 return_link=True,
             )

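Two observations on the new flag handling. First, only the exact strings "True" and "False" are accepted; anything else aborts the run. Second, `raise Exception("EDSL_OPEN_EXCEPTION_REPORT_URL", "must be either True or False")` passes two positional arguments, so the message renders as a tuple. A sketch of the same check with a single formatted message (the environment read and the default are assumptions of this sketch, not edsl's CONFIG semantics):

```python
import os

def open_report_in_browser() -> bool:
    # Strict parse: the setting must be exactly "True" or "False".
    value = os.environ.get("EDSL_OPEN_EXCEPTION_REPORT_URL", "False")
    if value == "True":
        return True
    if value == "False":
        return False
    raise ValueError(
        f"EDSL_OPEN_EXCEPTION_REPORT_URL must be 'True' or 'False', got {value!r}"
    )

print(open_report_in_browser())
```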
@@ -279,6 +385,7 @@ class JobsRunnerAsyncio:
         progress_bar: bool = False,
         sidecar_model: Optional[LanguageModel] = None,
         jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
+        job_uuid: Optional[UUID] = None,
         print_exceptions: bool = True,
         raise_validation_errors: bool = False,
     ) -> "Coroutine":
@@ -297,13 +404,11 @@ class JobsRunnerAsyncio:

         if jobs_runner_status is not None:
             self.jobs_runner_status = jobs_runner_status(
-                self, n=n, endpoint_url=endpoint_url
+                self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
             )
         else:
             self.jobs_runner_status = JobsRunnerStatus(
-                self,
-                n=n,
-                endpoint_url=endpoint_url,
+                self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
             )

         stop_event = threading.Event()
edsl/jobs/runners/JobsRunnerStatus.py
CHANGED
@@ -239,7 +239,6 @@ class JobsRunnerStatusBase(ABC):
         return stat_definitions[stat_name]()

     def update_progress(self, stop_event):
-
         while not stop_event.is_set():
             self.send_status_update()
             time.sleep(self.refresh_rate)
@@ -248,7 +247,6 @@ class JobsRunnerStatusBase(ABC):


 class JobsRunnerStatus(JobsRunnerStatusBase):
-
     @property
     def create_url(self) -> str:
         return f"{self.base_url}/api/v0/local-job"
edsl/jobs/tasks/TaskHistory.py
CHANGED
@@ -8,7 +8,12 @@ from edsl.jobs.tasks.task_status_enum import TaskStatus


 class TaskHistory:
-    def __init__(
+    def __init__(
+        self,
+        interviews: List["Interview"],
+        include_traceback: bool = False,
+        max_interviews: int = 10,
+    ):
         """
         The structure of a TaskHistory exception

@@ -22,6 +27,7 @@ class TaskHistory:
         self.include_traceback = include_traceback

         self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
+        self.max_interviews = max_interviews

     @classmethod
     def example(cls):
@@ -75,13 +81,6 @@ class TaskHistory:

     def to_dict(self, add_edsl_version=True):
         """Return the TaskHistory as a dictionary."""
-        # return {
-        #     "exceptions": [
-        #         e.to_dict(include_traceback=self.include_traceback)
-        #         for e in self.exceptions
-        #     ],
-        #     "indices": self.indices,
-        # }
         d = {
             "interviews": [
                 i.to_dict(add_edsl_version=add_edsl_version)
@@ -124,10 +123,11 @@ class TaskHistory:

     def _repr_html_(self):
         """Return an HTML representation of the TaskHistory."""
-
+        d = self.to_dict(add_edsl_version=False)
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate

-
-        return data_to_html(newdata, replace_new_lines=True)
+        return tabulate(data, headers=["keys", "values"], tablefmt="html")

     def show_exceptions(self, tracebacks=False):
         """Print the exceptions."""
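Both this `_repr_html_` and the `LanguageModel._repr_html_` below replace a custom `data_to_html` helper with `tabulate`'s HTML output. A minimal sketch of the pattern, assuming the `tabulate` package is installed (the class and sample data are illustrative):

```python
from tabulate import tabulate

class KeyValueView:
    """Any object whose _repr_html_ returns HTML renders as a table in Jupyter."""

    def __init__(self, mapping: dict):
        self.mapping = mapping

    def _repr_html_(self) -> str:
        # Two-column keys/values layout, as in the diff.
        data = [[k, v] for k, v in self.mapping.items()]
        return tabulate(data, headers=["keys", "values"], tablefmt="html")

print(KeyValueView({"interviews": 3, "include_traceback": False})._repr_html_())
```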
@@ -257,8 +257,6 @@ class TaskHistory:
             for question_name, exceptions in interview.exceptions.items():
                 for exception in exceptions:
                     exception_type = exception.exception.__class__.__name__
-                    # exception_type = exception["exception"]
-                    # breakpoint()
                     if exception_type in exceptions_by_type:
                         exceptions_by_type[exception_type] += 1
                     else:
@@ -345,9 +343,9 @@ class TaskHistory:

         env = Environment(loader=TemplateLoader("edsl", "templates/error_reporting"))

-        #
+        # Get current memory usage at this point
+
         template = env.get_template("base.html")
-        # rendered_template = template.render(your_data=your_data)

         # Render the template with data
         output = template.render(
@@ -361,6 +359,7 @@ class TaskHistory:
             exceptions_by_model=self.exceptions_by_model,
             exceptions_by_service=self.exceptions_by_service,
             models_used=models_used,
+            max_interviews=self.max_interviews,
         )
         return output

@@ -370,7 +369,7 @@ class TaskHistory:
         return_link=False,
         css=None,
         cta="Open Report in New Tab",
-        open_in_browser=
+        open_in_browser=False,
     ):
         """Return an HTML report."""

edsl/language_models/LanguageModel.py
CHANGED
@@ -41,17 +41,19 @@ from edsl.data_transfer_models import (
     AgentResponseDict,
 )

+if TYPE_CHECKING:
+    from edsl.data.Cache import Cache
+    from edsl.scenarios.FileStore import FileStore
+    from edsl.questions.QuestionBase import QuestionBase

 from edsl.config import CONFIG
 from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
-from edsl.utilities.decorators import
-from edsl.language_models.repair import repair
-from edsl.enums import InferenceServiceType
-from edsl.Base import RichPrintingMixin, PersistenceMixin
-from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
-from edsl.exceptions.language_models import LanguageModelBadResponseError
+from edsl.utilities.decorators import remove_edsl_version

+from edsl.Base import PersistenceMixin
+from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
 from edsl.language_models.KeyLookup import KeyLookup
+from edsl.exceptions.language_models import LanguageModelBadResponseError

 TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))

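The `Cache`, `FileStore`, and `QuestionBase` imports move under `TYPE_CHECKING`, so they are visible to static type checkers but never executed at runtime, which avoids import cycles and import-time cost. A self-contained sketch of the idiom (the edsl module path is only referenced, never imported at runtime, so this runs without edsl installed):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; skipped at runtime.
    from edsl.data.Cache import Cache

def describe(cache: Cache) -> str:
    # With deferred annotations, "Cache" stays a string at runtime.
    return f"got a cache of type {type(cache).__name__}"

print(describe(object()))  # a type checker would flag this call; runtime does not
```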
@@ -116,29 +118,11 @@ def handle_key_error(func):


 class LanguageModel(
-    RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterLanguageModelsMeta
+    PersistenceMixin,
+    ABC,
+    metaclass=RegisterLanguageModelsMeta,
 ):
-    """ABC for
-
-    TODO:
-
-    1) Need better, more descriptive names for functions
-
-    get_model_response_no_cache (currently called async_execute_model_call)
-
-    get_model_response (currently called async_get_raw_response; uses cache & adds tracking info)
-    Calls:
-        - async_execute_model_call
-        - _updated_model_response_with_tracking
-
-    get_answer (currently called async_get_response)
-    This parses out the answer block and does some error-handling.
-    Calls:
-        - async_get_raw_response
-        - parse_response
-
-
-    """
+    """ABC for Language Models."""

     _model_ = None
     key_sequence = (
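The class now inherits `PersistenceMixin` and `ABC` only (dropping `RichPrintingMixin`, consistent with the `rich_print` removal further down) while keeping `RegisterLanguageModelsMeta` as its metaclass. For readers unfamiliar with the metaclass-registry idiom that name implies, a generic sketch (a reconstruction of the pattern, not edsl's actual metaclass):

```python
from abc import ABC, ABCMeta

class RegisterMeta(ABCMeta):
    """Derives from ABCMeta so it composes with ABC; records each subclass."""

    registry: dict = {}

    def __new__(mcls, name, bases, namespace):
        cls = super().__new__(mcls, name, bases, namespace)
        if ABC not in bases:  # skip the abstract root class itself
            mcls.registry[name] = cls
        return cls

class LanguageModelBase(ABC, metaclass=RegisterMeta):
    pass

class TestModel(LanguageModelBase):
    pass

print(RegisterMeta.registry)  # {'TestModel': <class '__main__.TestModel'>}
```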
@@ -196,7 +180,7 @@ class LanguageModel(
         system_prompt = "You are a helpful agent pretending to be a human."
         return self.execute_model_call(user_prompt, system_prompt)

-    def set_key_lookup(self, key_lookup: KeyLookup):
+    def set_key_lookup(self, key_lookup: KeyLookup) -> None:
         del self._api_token
         self.key_lookup = key_lookup

@@ -211,10 +195,14 @@ class LanguageModel(
     def __getitem__(self, key):
         return getattr(self, key)

-    def _repr_html_(self):
-
+    def _repr_html_(self) -> str:
+        d = {"model": self.model}
+        d.update(self.parameters)
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate

-
+        table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
+        return f"<pre>{table}</pre>"

     def hello(self, verbose=False):
         """Runs a simple test to check if the model is working."""
@@ -235,7 +223,6 @@ class LanguageModel(
         This method is used to check if the model has a valid API key.
         """
         from edsl.enums import service_to_api_keyname
-        import os

         if self._model_ == "test":
             return True
@@ -248,9 +235,9 @@ class LanguageModel(
         """Allow the model to be used as a key in a dictionary."""
         from edsl.utilities.utilities import dict_hash

-        return dict_hash(self.to_dict())
+        return dict_hash(self.to_dict(add_edsl_version=False))

-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         """Check is two models are the same.

         >>> m1 = LanguageModel.example()
|
|
278
265
|
@property
|
279
266
|
def RPM(self):
|
280
267
|
"""Model's requests-per-minute limit."""
|
281
|
-
# self._set_rate_limits()
|
282
|
-
# return self._safety_factor * self.__rate_limits["rpm"]
|
283
268
|
return self._rpm
|
284
269
|
|
285
270
|
@property
|
286
271
|
def TPM(self):
|
287
272
|
"""Model's tokens-per-minute limit."""
|
288
|
-
# self._set_rate_limits()
|
289
|
-
# return self._safety_factor * self.__rate_limits["tpm"]
|
290
273
|
return self._tpm
|
291
274
|
|
292
275
|
@property
|
@@ -314,8 +297,6 @@ class LanguageModel(
|
|
314
297
|
>>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
|
315
298
|
{'temperature': 0.5, 'max_tokens': 1000}
|
316
299
|
"""
|
317
|
-
# parameters = dict({})
|
318
|
-
|
319
300
|
# this is the case when data is loaded from a dict after serialization
|
320
301
|
if "parameters" in passed_parameter_dict:
|
321
302
|
passed_parameter_dict = passed_parameter_dict["parameters"]
|
@@ -429,9 +410,10 @@ class LanguageModel(
         self,
         user_prompt: str,
         system_prompt: str,
-        cache:
+        cache: Cache,
         iteration: int = 0,
-        files_list=None,
+        files_list: Optional[List[FileStore]] = None,
+        invigilator=None,
     ) -> ModelResponse:
         """Handle caching of responses.

@@ -455,7 +437,6 @@ class LanguageModel(

         if files_list:
             files_hash = "+".join([str(hash(file)) for file in files_list])
-            # print(f"Files hash: {files_hash}")
             user_prompt_with_hashes = user_prompt + f" {files_hash}"
         else:
             user_prompt_with_hashes = user_prompt
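The context above shows how attachments participate in caching: file hashes are joined with "+" and appended to the user prompt before the cache lookup, so identical text with different files cannot collide. A hedged sketch of the idea (the SHA-256 keying below is an assumption; edsl's actual key derivation may differ):

```python
import hashlib

def user_prompt_with_hashes(user_prompt: str, file_hashes: list) -> str:
    # Fold attachment hashes into the prompt text, as in the diff.
    if file_hashes:
        return user_prompt + " " + "+".join(str(h) for h in file_hashes)
    return user_prompt

def cache_key(system_prompt: str, user_prompt: str) -> str:
    return hashlib.sha256(f"{system_prompt}\n{user_prompt}".encode()).hexdigest()

print(cache_key("You are helpful.",
                user_prompt_with_hashes("Describe this image.", ["ab12"])))
```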
@@ -481,9 +462,7 @@ class LanguageModel(
             "user_prompt": user_prompt,
             "system_prompt": system_prompt,
             "files_list": files_list,
-            # **({"encoded_image": encoded_image} if encoded_image else {}),
         }
-        # response = await f(**params)
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
         new_cache_key = cache.store(
             **cache_call_params, response=response
@@ -504,11 +483,9 @@ class LanguageModel(
         _async_get_intended_model_call_outcome
     )

-    # get_raw_response = sync_wrapper(async_get_raw_response)
-
     def simple_ask(
         self,
-        question:
+        question: QuestionBase,
         system_prompt="You are a helpful agent pretending to be a human.",
         top_logprobs=2,
     ):
@@ -523,9 +500,10 @@ class LanguageModel(
         self,
         user_prompt: str,
         system_prompt: str,
-        cache:
+        cache: Cache,
         iteration: int = 1,
-        files_list: Optional[List[
+        files_list: Optional[List[FileStore]] = None,
+        **kwargs,
     ) -> dict:
         """Get response, parse, and return as string.

@@ -543,6 +521,9 @@ class LanguageModel(
             "cache": cache,
             "files_list": files_list,
         }
+        if "invigilator" in kwargs:
+            params.update({"invigilator": kwargs["invigilator"]})
+
         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
         model_outputs = await self._async_get_intended_model_call_outcome(**params)
         edsl_dict = self.parse_response(model_outputs.response)
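A minimal sketch of the optional-kwarg forwarding added here: `invigilator` enters the params dict only when the caller supplies it, so existing call sites are unaffected. The function body below is illustrative, not edsl's implementation:

```python
import asyncio

async def get_response(user_prompt: str, **kwargs) -> dict:
    params = {"user_prompt": user_prompt}
    if "invigilator" in kwargs:
        params.update({"invigilator": kwargs["invigilator"]})
    return params

print(asyncio.run(get_response("Hello")))
print(asyncio.run(get_response("Hello", invigilator="InvigilatorAI")))
```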
@@ -553,8 +534,6 @@ class LanguageModel(
         )
         return agent_response_dict

-        # return await self._async_prepare_response(model_call_outcome, cache=cache)
-
     get_response = sync_wrapper(async_get_response)

     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
@@ -604,10 +583,7 @@ class LanguageModel(

         return input_cost + output_cost

-    #######################
-    # SERIALIZATION METHODS
-    #######################
-    def to_dict(self, add_edsl_version=True) -> dict[str, Any]:
+    def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
         """Convert instance to a dictionary

         >>> m = LanguageModel.example()
@@ -629,18 +605,8 @@ class LanguageModel(
         from edsl.language_models.registry import get_model_class

         model_class = get_model_class(data["model"])
-        # data["use_cache"] = True
         return model_class(**data)

-    #######################
-    # DUNDER METHODS
-    #######################
-    def print(self):
-        from rich import print_json
-        import json
-
-        print_json(json.dumps(self.to_dict()))
-
     def __repr__(self) -> str:
         """Return a string representation of the object."""
         param_string = ", ".join(
@@ -654,33 +620,21 @@ class LanguageModel(

     def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
         """Combine two models into a single model (other_model takes precedence over self)."""
-        warnings.warn(
+        import warnings
+
+        warnings.warn(
             f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
             by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
         )
         return other_model or self

-    def rich_print(self):
-        """Display an object as a table."""
-        from rich.table import Table
-
-        table = Table(title="Language Model")
-        table.add_column("Attribute", style="bold")
-        table.add_column("Value")
-
-        to_display = self.__dict__.copy()
-        for attr_name, attr_value in to_display.items():
-            table.add_row(attr_name, repr(attr_value))
-
-        return table
-
     @classmethod
     def example(
         cls,
         test_model: bool = False,
         canned_response: str = "Hello world",
         throw_exception: bool = False,
-    ):
+    ) -> LanguageModel:
         """Return a default instance of the class.

         >>> from edsl.language_models import LanguageModel
@@ -691,11 +645,17 @@ class LanguageModel(
         >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
         >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
         'WOWZA!'
+        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
+        >>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
+        Exception report saved to ...
+        Also see: ...
         """
         from edsl import Model

         if test_model:
-            m = Model(
+            m = Model(
+                "test", canned_response=canned_response, throw_exception=throw_exception
+            )
             return m
         else:
             return Model(skip_api_key_check=True)