edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (53)
  1. edsl/Base.py +60 -31
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Agent.py +18 -9
  4. edsl/agents/AgentList.py +59 -8
  5. edsl/agents/Invigilator.py +18 -7
  6. edsl/agents/InvigilatorBase.py +0 -19
  7. edsl/agents/PromptConstructor.py +5 -4
  8. edsl/config.py +8 -0
  9. edsl/coop/coop.py +74 -7
  10. edsl/data/Cache.py +27 -2
  11. edsl/data/CacheEntry.py +8 -3
  12. edsl/data/RemoteCacheSync.py +0 -19
  13. edsl/enums.py +2 -0
  14. edsl/inference_services/GoogleService.py +7 -15
  15. edsl/inference_services/PerplexityService.py +163 -0
  16. edsl/inference_services/registry.py +2 -0
  17. edsl/jobs/Jobs.py +88 -548
  18. edsl/jobs/JobsChecks.py +147 -0
  19. edsl/jobs/JobsPrompts.py +268 -0
  20. edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
  21. edsl/jobs/interviews/Interview.py +11 -11
  22. edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
  23. edsl/jobs/runners/JobsRunnerStatus.py +0 -2
  24. edsl/jobs/tasks/TaskHistory.py +15 -16
  25. edsl/language_models/LanguageModel.py +44 -84
  26. edsl/language_models/ModelList.py +47 -1
  27. edsl/language_models/registry.py +57 -4
  28. edsl/prompts/Prompt.py +8 -3
  29. edsl/questions/QuestionBase.py +20 -16
  30. edsl/questions/QuestionExtract.py +3 -4
  31. edsl/questions/question_registry.py +36 -6
  32. edsl/results/CSSParameterizer.py +108 -0
  33. edsl/results/Dataset.py +146 -15
  34. edsl/results/DatasetExportMixin.py +231 -217
  35. edsl/results/DatasetTree.py +134 -4
  36. edsl/results/Result.py +18 -9
  37. edsl/results/Results.py +145 -51
  38. edsl/results/TableDisplay.py +198 -0
  39. edsl/results/table_display.css +78 -0
  40. edsl/scenarios/FileStore.py +187 -13
  41. edsl/scenarios/Scenario.py +61 -4
  42. edsl/scenarios/ScenarioJoin.py +127 -0
  43. edsl/scenarios/ScenarioList.py +237 -62
  44. edsl/surveys/Survey.py +16 -2
  45. edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
  46. edsl/surveys/instructions/Instruction.py +12 -0
  47. edsl/templates/error_reporting/interview_details.html +3 -3
  48. edsl/templates/error_reporting/interviews.html +18 -9
  49. edsl/utilities/utilities.py +15 -0
  50. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
  51. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
  52. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
  53. {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0

edsl/jobs/runners/JobsRunnerAsyncio.py

@@ -4,6 +4,7 @@ import asyncio
 import threading
 import warnings
 from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator, Type
+from uuid import UUID
 from collections import UserList
 
 from edsl.results.Results import Results
@@ -36,6 +37,8 @@ class JobsRunnerAsyncio:
     The Jobs object is a collection of interviews that are to be run.
     """
 
+    MAX_CONCURRENT_DEFAULT = 500
+
     def __init__(self, jobs: "Jobs"):
         self.jobs = jobs
         self.interviews: List["Interview"] = jobs.interviews()
@@ -43,6 +46,53 @@ class JobsRunnerAsyncio:
         self.total_interviews: List["Interview"] = []
         self._initialized = threading.Event()
 
+        from edsl.config import CONFIG
+
+        self.MAX_CONCURRENT = int(CONFIG.get("EDSL_MAX_CONCURRENT_TASKS"))
+        # print(f"MAX_CONCURRENT: {self.MAX_CONCURRENT}")
+
+    # async def run_async_generator(
+    #     self,
+    #     cache: Cache,
+    #     n: int = 1,
+    #     stop_on_exception: bool = False,
+    #     sidecar_model: Optional[LanguageModel] = None,
+    #     total_interviews: Optional[List["Interview"]] = None,
+    #     raise_validation_errors: bool = False,
+    # ) -> AsyncGenerator["Result", None]:
+    #     """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
+
+    #     Completed tasks are yielded as they are completed.
+
+    #     :param n: how many times to run each interview
+    #     :param stop_on_exception: Whether to stop the interview if an exception is raised
+    #     :param sidecar_model: a language model to use in addition to the interview's model
+    #     :param total_interviews: A list of interviews to run can be provided instead.
+    #     :param raise_validation_errors: Whether to raise validation errors
+    #     """
+    #     tasks = []
+    #     if total_interviews:  # was already passed in total interviews
+    #         self.total_interviews = total_interviews
+    #     else:
+    #         self.total_interviews = list(
+    #             self._populate_total_interviews(n=n)
+    #         )  # Populate self.total_interviews before creating tasks
+    #     self._initialized.set()  # Signal that we're ready
+
+    #     for interview in self.total_interviews:
+    #         interviewing_task = self._build_interview_task(
+    #             interview=interview,
+    #             stop_on_exception=stop_on_exception,
+    #             sidecar_model=sidecar_model,
+    #             raise_validation_errors=raise_validation_errors,
+    #         )
+    #         tasks.append(asyncio.create_task(interviewing_task))
+
+    #     for task in asyncio.as_completed(tasks):
+    #         result = await task
+    #         self.jobs_runner_status.add_completed_interview(result)
+    #         yield result
+
     async def run_async_generator(
         self,
         cache: Cache,
@@ -52,9 +102,10 @@ class JobsRunnerAsyncio:
         total_interviews: Optional[List["Interview"]] = None,
         raise_validation_errors: bool = False,
     ) -> AsyncGenerator["Result", None]:
-        """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
+        """Creates and processes tasks asynchronously, yielding results as they complete.
 
-        Completed tasks are yielded as they are completed.
+        Tasks are created and processed in a streaming fashion rather than building the full list upfront.
+        Results are yielded as soon as they are available.
 
         :param n: how many times to run each interview
        :param stop_on_exception: Whether to stop the interview if an exception is raised
@@ -62,29 +113,70 @@ class JobsRunnerAsyncio:
         :param total_interviews: A list of interviews to run can be provided instead.
         :param raise_validation_errors: Whether to raise validation errors
         """
-        tasks = []
-        if total_interviews:  # was already passed in total interviews
+        # Initialize interviews iterator
+        if total_interviews:
+            interviews_iter = iter(total_interviews)
             self.total_interviews = total_interviews
         else:
-            self.total_interviews = list(
-                self._populate_total_interviews(n=n)
-            )  # Populate self.total_interviews before creating tasks
+            interviews_iter = self._populate_total_interviews(n=n)
+            self.total_interviews = list(interviews_iter)
+            interviews_iter = iter(self.total_interviews)  # Create fresh iterator
 
         self._initialized.set()  # Signal that we're ready
 
-        for interview in self.total_interviews:
-            interviewing_task = self._build_interview_task(
-                interview=interview,
-                stop_on_exception=stop_on_exception,
-                sidecar_model=sidecar_model,
-                raise_validation_errors=raise_validation_errors,
-            )
-            tasks.append(asyncio.create_task(interviewing_task))
+        # Keep track of active tasks
+        active_tasks = set()
 
-        for task in asyncio.as_completed(tasks):
-            result = await task
-            self.jobs_runner_status.add_completed_interview(result)
-            yield result
+        try:
+            while True:
+                # Add new tasks if we're below max_concurrent and there are more interviews
+                while len(active_tasks) < self.MAX_CONCURRENT:
+                    try:
+                        interview = next(interviews_iter)
+                        task = asyncio.create_task(
+                            self._build_interview_task(
+                                interview=interview,
+                                stop_on_exception=stop_on_exception,
+                                sidecar_model=sidecar_model,
+                                raise_validation_errors=raise_validation_errors,
+                            )
+                        )
+                        active_tasks.add(task)
+                        # Add callback to remove task from set when done
+                        task.add_done_callback(active_tasks.discard)
+                    except StopIteration:
+                        break
+
+                if not active_tasks:
+                    break
+
+                # Wait for next completed task
+                done, _ = await asyncio.wait(
+                    active_tasks, return_when=asyncio.FIRST_COMPLETED
+                )
+
+                # Process completed tasks
+                for task in done:
+                    try:
+                        result = await task
+                        self.jobs_runner_status.add_completed_interview(result)
+                        yield result
+                    except Exception as e:
+                        if stop_on_exception:
+                            # Cancel remaining tasks
+                            for t in active_tasks:
+                                if not t.done():
+                                    t.cancel()
+                            raise
+                        else:
+                            # Log error and continue
+                            # logger.error(f"Task failed with error: {e}")
+                            continue
+        finally:
+            # Ensure we cancel any remaining tasks if we exit early
+            for task in active_tasks:
+                if not task.done():
+                    task.cancel()
 
     def _populate_total_interviews(
         self, n: int = 1
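
Note on the pattern above: the old implementation created one asyncio task per interview up front, so a large job scheduled every interview at once; the rewrite keeps at most MAX_CONCURRENT tasks in flight and tops the window up as tasks finish. A minimal, self-contained sketch of the same sliding-window idiom (generic names, not edsl code; the diff uses add_done_callback where this sketch reuses the pending set returned by asyncio.wait):

import asyncio

async def run_bounded(coro_factories, max_concurrent: int = 500):
    """Yield results as they complete, with at most max_concurrent tasks in flight."""
    work = iter(coro_factories)
    active = set()
    while True:
        # Top up the window until we hit the cap or run out of work
        while len(active) < max_concurrent:
            try:
                active.add(asyncio.create_task(next(work)()))
            except StopIteration:
                break
        if not active:
            break
        # Resume as soon as any task finishes, then refill the window
        done, active = await asyncio.wait(active, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            yield task.result()

async def main():
    async def job(i: int) -> int:
        await asyncio.sleep(0.01)
        return i

    factories = [lambda i=i: job(i) for i in range(20)]
    print(sorted([r async for r in run_bounded(factories, max_concurrent=5)]))

asyncio.run(main())

Either bookkeeping style works; the important part is that asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED) lets new work be scheduled as soon as any task completes.
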
@@ -168,20 +260,20 @@ class JobsRunnerAsyncio:
 
         prompt_dictionary = {}
         for answer_key_name in answer_key_names:
-            prompt_dictionary[answer_key_name + "_user_prompt"] = (
-                question_name_to_prompts[answer_key_name]["user_prompt"]
-            )
-            prompt_dictionary[answer_key_name + "_system_prompt"] = (
-                question_name_to_prompts[answer_key_name]["system_prompt"]
-            )
+            prompt_dictionary[
+                answer_key_name + "_user_prompt"
+            ] = question_name_to_prompts[answer_key_name]["user_prompt"]
+            prompt_dictionary[
+                answer_key_name + "_system_prompt"
+            ] = question_name_to_prompts[answer_key_name]["system_prompt"]
 
         raw_model_results_dictionary = {}
         cache_used_dictionary = {}
         for result in valid_results:
             question_name = result.question_name
-            raw_model_results_dictionary[question_name + "_raw_model_response"] = (
-                result.raw_model_response
-            )
+            raw_model_results_dictionary[
+                question_name + "_raw_model_response"
+            ] = result.raw_model_response
             raw_model_results_dictionary[question_name + "_cost"] = result.cost
             one_use_buys = (
                 "NA"
@@ -245,11 +337,25 @@ class JobsRunnerAsyncio:
         if len(results.task_history.indices) > 5:
             msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
 
-        print(msg)
-        # this is where exceptions are opening up
+        import sys
+
+        print(msg, file=sys.stderr)
+        from edsl.config import CONFIG
+
+        if CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "True":
+            open_in_browser = True
+        elif CONFIG.get("EDSL_OPEN_EXCEPTION_REPORT_URL") == "False":
+            open_in_browser = False
+        else:
+            raise Exception(
+                "EDSL_OPEN_EXCEPTION_REPORT_URL", "must be either True or False"
+            )
+
+        # print("open_in_browser", open_in_browser)
+
         filepath = results.task_history.html(
             cta="Open report to see details.",
-            open_in_browser=True,
+            open_in_browser=open_in_browser,
             return_link=True,
         )
 
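Note: two behaviors that were hard-coded are now configuration-driven: the concurrency cap (EDSL_MAX_CONCURRENT_TASKS, read in __init__ above) and whether the exception report opens a browser tab (EDSL_OPEN_EXCEPTION_REPORT_URL, which must be the string "True" or "False"). Assuming edsl's CONFIG resolves these keys from environment variables, as the EDSL_API_TIMEOUT usage elsewhere in this diff suggests, a sketch of setting them before running a job:

import os

# Illustrative values; set these before edsl's CONFIG is loaded.
os.environ["EDSL_MAX_CONCURRENT_TASKS"] = "100"          # cap on in-flight interview tasks
os.environ["EDSL_OPEN_EXCEPTION_REPORT_URL"] = "False"   # keep the report from auto-opening
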
@@ -279,6 +385,7 @@ class JobsRunnerAsyncio:
         progress_bar: bool = False,
         sidecar_model: Optional[LanguageModel] = None,
         jobs_runner_status: Optional[Type[JobsRunnerStatusBase]] = None,
+        job_uuid: Optional[UUID] = None,
         print_exceptions: bool = True,
         raise_validation_errors: bool = False,
     ) -> "Coroutine":
@@ -297,13 +404,11 @@ class JobsRunnerAsyncio:
 
         if jobs_runner_status is not None:
             self.jobs_runner_status = jobs_runner_status(
-                self, n=n, endpoint_url=endpoint_url
+                self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
             )
         else:
             self.jobs_runner_status = JobsRunnerStatus(
-                self,
-                n=n,
-                endpoint_url=endpoint_url,
+                self, n=n, endpoint_url=endpoint_url, job_uuid=job_uuid
             )
 
         stop_event = threading.Event()

edsl/jobs/runners/JobsRunnerStatus.py

@@ -239,7 +239,6 @@ class JobsRunnerStatusBase(ABC):
         return stat_definitions[stat_name]()
 
     def update_progress(self, stop_event):
-
         while not stop_event.is_set():
             self.send_status_update()
             time.sleep(self.refresh_rate)
@@ -248,7 +247,6 @@
 
 
 class JobsRunnerStatus(JobsRunnerStatusBase):
-
     @property
     def create_url(self) -> str:
         return f"{self.base_url}/api/v0/local-job"

edsl/jobs/tasks/TaskHistory.py

@@ -8,7 +8,12 @@ from edsl.jobs.tasks.task_status_enum import TaskStatus
 
 
 class TaskHistory:
-    def __init__(self, interviews: List["Interview"], include_traceback: bool = False):
+    def __init__(
+        self,
+        interviews: List["Interview"],
+        include_traceback: bool = False,
+        max_interviews: int = 10,
+    ):
         """
         The structure of a TaskHistory exception
 
@@ -22,6 +27,7 @@ class TaskHistory:
         self.include_traceback = include_traceback
 
         self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
+        self.max_interviews = max_interviews
 
     @classmethod
     def example(cls):
@@ -75,13 +81,6 @@ class TaskHistory:
 
     def to_dict(self, add_edsl_version=True):
         """Return the TaskHistory as a dictionary."""
-        # return {
-        #     "exceptions": [
-        #         e.to_dict(include_traceback=self.include_traceback)
-        #         for e in self.exceptions
-        #     ],
-        #     "indices": self.indices,
-        # }
         d = {
             "interviews": [
                 i.to_dict(add_edsl_version=add_edsl_version)
@@ -124,10 +123,11 @@ class TaskHistory:
 
     def _repr_html_(self):
         """Return an HTML representation of the TaskHistory."""
-        from edsl.utilities.utilities import data_to_html
+        d = self.to_dict(add_edsl_version=False)
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate
 
-        newdata = self.to_dict()["exceptions"]
-        return data_to_html(newdata, replace_new_lines=True)
+        return tabulate(data, headers=["keys", "values"], tablefmt="html")
 
     def show_exceptions(self, tracebacks=False):
         """Print the exceptions."""
@@ -257,8 +257,6 @@ class TaskHistory:
             for question_name, exceptions in interview.exceptions.items():
                 for exception in exceptions:
                     exception_type = exception.exception.__class__.__name__
-                    # exception_type = exception["exception"]
-                    # breakpoint()
                     if exception_type in exceptions_by_type:
                         exceptions_by_type[exception_type] += 1
                     else:
@@ -345,9 +343,9 @@ class TaskHistory:
 
         env = Environment(loader=TemplateLoader("edsl", "templates/error_reporting"))
 
-        # Load and render a template
+        # Get current memory usage at this point
+
         template = env.get_template("base.html")
-        # rendered_template = template.render(your_data=your_data)
 
         # Render the template with data
         output = template.render(
@@ -361,6 +359,7 @@ class TaskHistory:
             exceptions_by_model=self.exceptions_by_model,
             exceptions_by_service=self.exceptions_by_service,
             models_used=models_used,
+            max_interviews=self.max_interviews,
         )
         return output
 
@@ -370,7 +369,7 @@ class TaskHistory:
         return_link=False,
         css=None,
         cta="Open Report in New Tab",
-        open_in_browser=True,
+        open_in_browser=False,
     ):
         """Return an HTML report."""
 

edsl/language_models/LanguageModel.py

@@ -41,17 +41,19 @@ from edsl.data_transfer_models import (
     AgentResponseDict,
 )
 
+if TYPE_CHECKING:
+    from edsl.data.Cache import Cache
+    from edsl.scenarios.FileStore import FileStore
+    from edsl.questions.QuestionBase import QuestionBase
 
 from edsl.config import CONFIG
 from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
-from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
-from edsl.language_models.repair import repair
-from edsl.enums import InferenceServiceType
-from edsl.Base import RichPrintingMixin, PersistenceMixin
-from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
-from edsl.exceptions.language_models import LanguageModelBadResponseError
+from edsl.utilities.decorators import remove_edsl_version
 
+from edsl.Base import PersistenceMixin
+from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
 from edsl.language_models.KeyLookup import KeyLookup
+from edsl.exceptions.language_models import LanguageModelBadResponseError
 
 TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
 
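Note: Cache, FileStore, and QuestionBase move under a TYPE_CHECKING guard, so they are visible to type checkers for the annotations used later in this file (cache: Cache, files_list: Optional[List[FileStore]]) without being imported at runtime, which avoids circular imports. The general pattern, sketched outside edsl with a hypothetical module:

from __future__ import annotations  # postpone annotation evaluation (PEP 563)
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; no runtime import, no import cycle
    from mypackage.cache import Cache  # hypothetical module, for illustration

def warm(cache: Cache) -> None:
    # At runtime the annotation is just a string, so Cache need not be importable here
    ...
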
@@ -116,29 +118,11 @@ def handle_key_error(func):
 
 
 class LanguageModel(
-    RichPrintingMixin, PersistenceMixin, ABC, metaclass=RegisterLanguageModelsMeta
+    PersistenceMixin,
+    ABC,
+    metaclass=RegisterLanguageModelsMeta,
 ):
-    """ABC for LLM subclasses.
-
-    TODO:
-
-    1) Need better, more descriptive names for functions
-
-    get_model_response_no_cache (currently called async_execute_model_call)
-
-    get_model_response (currently called async_get_raw_response; uses cache & adds tracking info)
-    Calls:
-        - async_execute_model_call
-        - _updated_model_response_with_tracking
-
-    get_answer (currently called async_get_response)
-    This parses out the answer block and does some error-handling.
-    Calls:
-        - async_get_raw_response
-        - parse_response
-
-
-    """
+    """ABC for Language Models."""
 
     _model_ = None
     key_sequence = (
@@ -196,7 +180,7 @@ class LanguageModel(
         system_prompt = "You are a helpful agent pretending to be a human."
         return self.execute_model_call(user_prompt, system_prompt)
 
-    def set_key_lookup(self, key_lookup: KeyLookup):
+    def set_key_lookup(self, key_lookup: KeyLookup) -> None:
         del self._api_token
         self.key_lookup = key_lookup
 
@@ -211,10 +195,14 @@ class LanguageModel(
     def __getitem__(self, key):
         return getattr(self, key)
 
-    def _repr_html_(self):
-        from edsl.utilities.utilities import data_to_html
+    def _repr_html_(self) -> str:
+        d = {"model": self.model}
+        d.update(self.parameters)
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate
 
-        return data_to_html(self.to_dict())
+        table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
+        return f"<pre>{table}</pre>"
 
     def hello(self, verbose=False):
         """Runs a simple test to check if the model is working."""
@@ -235,7 +223,6 @@ class LanguageModel(
         This method is used to check if the model has a valid API key.
         """
         from edsl.enums import service_to_api_keyname
-        import os
 
         if self._model_ == "test":
             return True
@@ -248,9 +235,9 @@ class LanguageModel(
         """Allow the model to be used as a key in a dictionary."""
         from edsl.utilities.utilities import dict_hash
 
-        return dict_hash(self.to_dict())
+        return dict_hash(self.to_dict(add_edsl_version=False))
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         """Check is two models are the same.
 
         >>> m1 = LanguageModel.example()
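
Note: __hash__ now hashes the version-free serialization, so an otherwise identical model no longer changes hash (and hence set/dict and equality behavior) when only the recorded edsl_version differs. dict_hash is an edsl internal; a generic stand-in showing the idea:

import hashlib, json

def dict_hash(d: dict) -> int:
    """Deterministic hash of a JSON-serializable dict (illustrative stand-in)."""
    return int.from_bytes(
        hashlib.sha256(json.dumps(d, sort_keys=True).encode()).digest()[:8], "big"
    )

serialized = {"model": "test", "parameters": {"temperature": 0.5}}
stamped = {**serialized, "edsl_version": "0.1.38.dev4"}
# Hashing the unstamped dict keeps hashes stable across releases
assert dict_hash(serialized) != dict_hash(stamped)
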
@@ -278,15 +265,11 @@ class LanguageModel(
     @property
     def RPM(self):
         """Model's requests-per-minute limit."""
-        # self._set_rate_limits()
-        # return self._safety_factor * self.__rate_limits["rpm"]
         return self._rpm
 
     @property
     def TPM(self):
         """Model's tokens-per-minute limit."""
-        # self._set_rate_limits()
-        # return self._safety_factor * self.__rate_limits["tpm"]
         return self._tpm
 
     @property
@@ -314,8 +297,6 @@ class LanguageModel(
         >>> LanguageModel._overide_default_parameters(passed_parameter_dict={"temperature": 0.5}, default_parameter_dict={"temperature":0.9, "max_tokens": 1000})
         {'temperature': 0.5, 'max_tokens': 1000}
         """
-        # parameters = dict({})
-
         # this is the case when data is loaded from a dict after serialization
         if "parameters" in passed_parameter_dict:
             passed_parameter_dict = passed_parameter_dict["parameters"]
@@ -429,9 +410,10 @@ class LanguageModel(
         self,
         user_prompt: str,
         system_prompt: str,
-        cache: "Cache",
+        cache: Cache,
         iteration: int = 0,
-        files_list=None,
+        files_list: Optional[List[FileStore]] = None,
+        invigilator=None,
     ) -> ModelResponse:
         """Handle caching of responses.
 
@@ -455,7 +437,6 @@ class LanguageModel(
 
         if files_list:
             files_hash = "+".join([str(hash(file)) for file in files_list])
-            # print(f"Files hash: {files_hash}")
             user_prompt_with_hashes = user_prompt + f" {files_hash}"
         else:
             user_prompt_with_hashes = user_prompt
@@ -481,9 +462,7 @@ class LanguageModel(
             "user_prompt": user_prompt,
             "system_prompt": system_prompt,
             "files_list": files_list,
-            # **({"encoded_image": encoded_image} if encoded_image else {}),
         }
-        # response = await f(**params)
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
         new_cache_key = cache.store(
             **cache_call_params, response=response
@@ -504,11 +483,9 @@ class LanguageModel(
         _async_get_intended_model_call_outcome
     )
 
-    # get_raw_response = sync_wrapper(async_get_raw_response)
-
     def simple_ask(
         self,
-        question: "QuestionBase",
+        question: QuestionBase,
         system_prompt="You are a helpful agent pretending to be a human.",
         top_logprobs=2,
     ):
@@ -523,9 +500,10 @@ class LanguageModel(
         self,
         user_prompt: str,
         system_prompt: str,
-        cache: "Cache",
+        cache: Cache,
         iteration: int = 1,
-        files_list: Optional[List["File"]] = None,
+        files_list: Optional[List[FileStore]] = None,
+        **kwargs,
     ) -> dict:
         """Get response, parse, and return as string.
 
@@ -543,6 +521,9 @@ class LanguageModel(
             "cache": cache,
             "files_list": files_list,
         }
+        if "invigilator" in kwargs:
+            params.update({"invigilator": kwargs["invigilator"]})
+
         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
         model_outputs = await self._async_get_intended_model_call_outcome(**params)
         edsl_dict = self.parse_response(model_outputs.response)
@@ -553,8 +534,6 @@ class LanguageModel(
         )
         return agent_response_dict
 
-        # return await self._async_prepare_response(model_call_outcome, cache=cache)
-
     get_response = sync_wrapper(async_get_response)
 
     def cost(self, raw_response: dict[str, Any]) -> Union[float, str]:
@@ -604,10 +583,7 @@ class LanguageModel(
 
         return input_cost + output_cost
 
-    #######################
-    # SERIALIZATION METHODS
-    #######################
-    def to_dict(self, add_edsl_version=True) -> dict[str, Any]:
+    def to_dict(self, add_edsl_version: bool = True) -> dict[str, Any]:
         """Convert instance to a dictionary
 
         >>> m = LanguageModel.example()
@@ -629,18 +605,8 @@ class LanguageModel(
         from edsl.language_models.registry import get_model_class
 
         model_class = get_model_class(data["model"])
-        # data["use_cache"] = True
         return model_class(**data)
 
-    #######################
-    # DUNDER METHODS
-    #######################
-    def print(self):
-        from rich import print_json
-        import json
-
-        print_json(json.dumps(self.to_dict()))
-
     def __repr__(self) -> str:
         """Return a string representation of the object."""
         param_string = ", ".join(
@@ -654,33 +620,21 @@ class LanguageModel(
 
     def __add__(self, other_model: Type[LanguageModel]) -> Type[LanguageModel]:
         """Combine two models into a single model (other_model takes precedence over self)."""
-        print(
+        import warnings
+
+        warnings.warn(
             f"""Warning: one model is replacing another. If you want to run both models, use a single `by` e.g.,
             by(m1, m2, m3) not by(m1).by(m2).by(m3)."""
         )
         return other_model or self
 
-    def rich_print(self):
-        """Display an object as a table."""
-        from rich.table import Table
-
-        table = Table(title="Language Model")
-        table.add_column("Attribute", style="bold")
-        table.add_column("Value")
-
-        to_display = self.__dict__.copy()
-        for attr_name, attr_value in to_display.items():
-            table.add_row(attr_name, repr(attr_value))
-
-        return table
-
     @classmethod
     def example(
         cls,
         test_model: bool = False,
         canned_response: str = "Hello world",
         throw_exception: bool = False,
-    ):
+    ) -> LanguageModel:
         """Return a default instance of the class.
 
         >>> from edsl.language_models import LanguageModel
@@ -691,11 +645,17 @@ class LanguageModel(
         >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
         >>> q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True).select('example').first()
         'WOWZA!'
+        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!", throw_exception = True)
+        >>> r = q.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True, print_exceptions = True)
+        Exception report saved to ...
+        Also see: ...
         """
         from edsl import Model
 
         if test_model:
-            m = Model("test", canned_response=canned_response)
+            m = Model(
+                "test", canned_response=canned_response, throw_exception=throw_exception
+            )
             return m
         else:
             return Model(skip_api_key_check=True)
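
Note: the expanded doctest above doubles as usage documentation for the canned-response test model, including the new throw_exception flag. A compact version of the happy path, using the top-level imports that appear in edsl's own doctests:

from edsl import Model, QuestionFreeText

# Canned-response test model: no API key required
m = Model("test", canned_response="WOWZA!")
q = QuestionFreeText(question_text="What is your name?", question_name="example")
results = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True)
print(results.select("example").first())  # prints WOWZA!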