edsl 0.1.31__py3-none-any.whl → 0.1.31.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +2 -7
  3. edsl/agents/PromptConstructionMixin.py +4 -9
  4. edsl/config.py +0 -4
  5. edsl/conjure/Conjure.py +0 -6
  6. edsl/coop/coop.py +0 -4
  7. edsl/data/CacheHandler.py +4 -3
  8. edsl/enums.py +0 -2
  9. edsl/inference_services/DeepInfraService.py +91 -6
  10. edsl/inference_services/InferenceServicesCollection.py +8 -13
  11. edsl/inference_services/OpenAIService.py +21 -64
  12. edsl/inference_services/registry.py +1 -2
  13. edsl/jobs/Jobs.py +5 -29
  14. edsl/jobs/buckets/TokenBucket.py +4 -12
  15. edsl/jobs/interviews/Interview.py +9 -31
  16. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +33 -49
  17. edsl/jobs/interviews/interview_exception_tracking.py +10 -68
  18. edsl/jobs/runners/JobsRunnerAsyncio.py +81 -112
  19. edsl/jobs/runners/JobsRunnerStatusData.py +237 -0
  20. edsl/jobs/runners/JobsRunnerStatusMixin.py +35 -291
  21. edsl/jobs/tasks/TaskCreators.py +2 -8
  22. edsl/jobs/tasks/TaskHistory.py +1 -145
  23. edsl/language_models/LanguageModel.py +32 -49
  24. edsl/language_models/registry.py +0 -4
  25. edsl/questions/QuestionMultipleChoice.py +1 -1
  26. edsl/questions/QuestionNumerical.py +1 -0
  27. edsl/results/DatasetExportMixin.py +3 -12
  28. edsl/scenarios/Scenario.py +0 -14
  29. edsl/scenarios/ScenarioList.py +2 -15
  30. edsl/scenarios/ScenarioListExportMixin.py +4 -15
  31. edsl/scenarios/ScenarioListPdfMixin.py +0 -3
  32. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/METADATA +1 -2
  33. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/RECORD +35 -37
  34. edsl/inference_services/GroqService.py +0 -18
  35. edsl/jobs/interviews/InterviewExceptionEntry.py +0 -101
  36. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/LICENSE +0 -0
  37. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/WHEEL +0 -0
edsl/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.31"
+ __version__ = "0.1.31.dev2"
edsl/agents/Invigilator.py CHANGED
@@ -18,12 +18,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
      """An invigilator that uses an AI model to answer questions."""
 
      async def async_answer_question(self) -> AgentResponseDict:
-         """Answer a question using the AI model.
-
-         >>> i = InvigilatorAI.example()
-         >>> i.answer_question()
-         {'message': '{"answer": "SPAM!"}'}
-         """
+         """Answer a question using the AI model."""
          params = self.get_prompts() | {"iteration": self.iteration}
          raw_response = await self.async_get_response(**params)
          data = {
@@ -34,7 +29,6 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
              "raw_model_response": raw_response["raw_model_response"],
          }
          response = self._format_raw_response(**data)
-         # breakpoint()
          return AgentResponseDict(**response)
 
      async def async_get_response(
@@ -103,6 +97,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
          answer = question._translate_answer_code_to_answer(
              response["answer"], combined_dict
          )
+         # breakpoint()
          data = {
              "answer": answer,
              "comment": response.get(
edsl/agents/PromptConstructionMixin.py CHANGED
@@ -281,17 +281,12 @@ class PromptConstructorMixin:
          if "question_options" in question_data:
              if isinstance(self.question.data["question_options"], str):
                  from jinja2 import Environment, meta
-
                  env = Environment()
-                 parsed_content = env.parse(self.question.data["question_options"])
-                 question_option_key = list(
-                     meta.find_undeclared_variables(parsed_content)
-                 )[0]
-                 question_data["question_options"] = self.scenario.get(
-                     question_option_key
-                 )
+                 parsed_content = env.parse(self.question.data['question_options'])
+                 question_option_key = list(meta.find_undeclared_variables(parsed_content))[0]
+                 question_data["question_options"] = self.scenario.get(question_option_key)
 
-         # breakpoint()
+         #breakpoint()
          rendered_instructions = question_prompt.render(
              question_data | self.scenario | d | {"agent": self.agent}
          )
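
Both sides of this hunk rely on jinja2's meta.find_undeclared_variables to discover which scenario key a templated question_options string points at. A minimal standalone sketch of that lookup (the template string and variable name here are illustrative, not from the package):

    from jinja2 import Environment, meta

    # A templated question_options value, as it might appear in question data.
    template_str = "{{ fruit_options }}"

    env = Environment()
    parsed = env.parse(template_str)

    # find_undeclared_variables returns the set of variables the template expects;
    # the mixin takes the first one as the scenario key to look up.
    keys = list(meta.find_undeclared_variables(parsed))
    print(keys)  # ['fruit_options']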
edsl/config.py CHANGED
@@ -65,10 +65,6 @@ CONFIG_MAP = {
      # "default": None,
      # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
      # },
-     # "GROQ_API_KEY": {
-     #     "default": None,
-     #     "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
-     # },
  }
 
 
edsl/conjure/Conjure.py CHANGED
@@ -35,12 +35,6 @@ class Conjure:
      # The __init__ method in Conjure won't be called because __new__ returns a different class instance.
      pass
 
-     @classmethod
-     def example(cls):
-         from edsl.conjure.InputData import InputDataABC
-
-         return InputDataABC.example()
-
 
  if __name__ == "__main__":
      pass
edsl/coop/coop.py CHANGED
@@ -465,7 +465,6 @@ class Coop:
          description: Optional[str] = None,
          status: RemoteJobStatus = "queued",
          visibility: Optional[VisibilityType] = "unlisted",
-         iterations: Optional[int] = 1,
      ) -> dict:
          """
          Send a remote inference job to the server.
@@ -474,7 +473,6 @@ class Coop:
          :param optional description: A description for this entry in the remote cache.
          :param status: The status of the job. Should be 'queued', unless you are debugging.
          :param visibility: The visibility of the cache entry.
-         :param iterations: The number of times to run each interview.
 
          >>> job = Jobs.example()
          >>> coop.remote_inference_create(job=job, description="My job")
@@ -490,7 +488,6 @@
                  ),
                  "description": description,
                  "status": status,
-                 "iterations": iterations,
                  "visibility": visibility,
                  "version": self._edsl_version,
              },
@@ -501,7 +498,6 @@
              "uuid": response_json.get("jobs_uuid"),
              "description": response_json.get("description"),
              "status": response_json.get("status"),
-             "iterations": response_json.get("iterations"),
              "visibility": response_json.get("visibility"),
              "version": self._edsl_version,
          }
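
For context, a hedged sketch of a caller after this change: the iterations argument is simply no longer passed (the Coop/Jobs setup mirrors the docstring's own doctest; zero-argument construction is assumed):

    from edsl.coop.coop import Coop
    from edsl.jobs.Jobs import Jobs  # import path taken from this diff's file list

    coop = Coop()
    job = Jobs.example()

    # 0.1.31 accepted iterations=...; in 0.1.31.dev2 the call omits it entirely.
    info = coop.remote_inference_create(job=job, description="My job")
    print(info["uuid"], info["status"])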
edsl/data/CacheHandler.py CHANGED
edsl/data/CacheHandler.py CHANGED
@@ -41,7 +41,7 @@ class CacheHandler:
          old_data = self.from_old_sqlite_cache()
          self.cache.add_from_dict(old_data)
 
-     def create_cache_directory(self, notify = False) -> None:
+     def create_cache_directory(self) -> None:
          """
          Create the cache directory if one is required and it does not exist.
          """
@@ -49,8 +49,9 @@ class CacheHandler:
          dir_path = os.path.dirname(path)
          if dir_path and not os.path.exists(dir_path):
              os.makedirs(dir_path)
-             if notify:
-                 print(f"Created cache directory: {dir_path}")
+             import warnings
+
+             warnings.warn(f"Created cache directory: {dir_path}")
 
      def gen_cache(self) -> Cache:
          """
edsl/enums.py CHANGED
@@ -59,7 +59,6 @@ class InferenceServiceType(EnumWithChecks):
      GOOGLE = "google"
      TEST = "test"
      ANTHROPIC = "anthropic"
-     GROQ = "groq"
 
 
  service_to_api_keyname = {
@@ -70,7 +69,6 @@ service_to_api_keyname = {
      InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
      InferenceServiceType.TEST.value: "TBD",
      InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
-     InferenceServiceType.GROQ.value: "GROQ_API_KEY",
  }
 
 
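This mapping is what key-resolution code reads, so after this change a lookup for "groq" falls through. A minimal sketch of resolving an API key name, using only names shown in the hunk:

    import os
    from edsl.enums import InferenceServiceType, service_to_api_keyname

    # Resolve which env var holds the key for a service, then read it.
    keyname = service_to_api_keyname[InferenceServiceType.GOOGLE.value]  # "GOOGLE_API_KEY"
    api_key = os.getenv(keyname)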
edsl/inference_services/DeepInfraService.py CHANGED
@@ -2,17 +2,102 @@ import aiohttp
  import json
  import requests
  from typing import Any, List
-
- # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+ from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models import LanguageModel
 
- from edsl.inference_services.OpenAIService import OpenAIService
-
 
- class DeepInfraService(OpenAIService):
+ class DeepInfraService(InferenceServiceABC):
      """DeepInfra service class."""
 
      _inference_service_ = "deep_infra"
      _env_key_name_ = "DEEP_INFRA_API_KEY"
-     _base_url_ = "https://api.deepinfra.com/v1/openai"
+
      _models_list_cache: List[str] = []
+
+     @classmethod
+     def available(cls):
+         text_models = cls.full_details_available()
+         return [m["model_name"] for m in text_models]
+
+     @classmethod
+     def full_details_available(cls, verbose=False):
+         if not cls._models_list_cache:
+             url = "https://api.deepinfra.com/models/list"
+             response = requests.get(url)
+             if response.status_code == 200:
+                 text_generation_models = [
+                     r for r in response.json() if r["type"] == "text-generation"
+                 ]
+                 cls._models_list_cache = text_generation_models
+
+                 from rich import print_json
+                 import json
+
+                 if verbose:
+                     print_json(json.dumps(text_generation_models))
+                 return text_generation_models
+             else:
+                 return f"Failed to fetch data: Status code {response.status_code}"
+         else:
+             return cls._models_list_cache
+
+     @classmethod
+     def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
+         base_url = "https://api.deepinfra.com/v1/inference/"
+         if model_class_name is None:
+             model_class_name = cls.to_class_name(model_name)
+         url = f"{base_url}{model_name}"
+
+         class LLM(LanguageModel):
+             _inference_service_ = cls._inference_service_
+             _model_ = model_name
+             _parameters_ = {
+                 "temperature": 0.7,
+                 "top_p": 0.2,
+                 "top_k": 0.1,
+                 "max_new_tokens": 512,
+                 "stopSequences": [],
+             }
+
+             async def async_execute_model_call(
+                 self, user_prompt: str, system_prompt: str = ""
+             ) -> dict[str, Any]:
+                 self.url = url
+                 headers = {
+                     "Content-Type": "application/json",
+                     "Authorization": f"bearer {self.api_token}",
+                 }
+                 # don't mess w/ the newlines
+                 data = {
+                     "input": f"""
+                 [INST]<<SYS>>
+                 {system_prompt}
+                 <<SYS>>{user_prompt}[/INST]
+                 """,
+                     "stream": False,
+                     "temperature": self.temperature,
+                     "top_p": self.top_p,
+                     "top_k": self.top_k,
+                     "max_new_tokens": self.max_new_tokens,
+                 }
+                 async with aiohttp.ClientSession() as session:
+                     async with session.post(
+                         self.url, headers=headers, data=json.dumps(data)
+                     ) as response:
+                         raw_response_text = await response.text()
+                         return json.loads(raw_response_text)
+
+             def parse_response(self, raw_response: dict[str, Any]) -> str:
+                 if "results" not in raw_response:
+                     raise Exception(
+                         f"Deep Infra response does not contain 'results' key: {raw_response}"
+                     )
+                 if "generated_text" not in raw_response["results"][0]:
+                     raise Exception(
+                         f"Deep Infra response does not contain 'generate_text' key: {raw_response['results'][0]}"
+                     )
+                 return raw_response["results"][0]["generated_text"]
+
+         LLM.__name__ = model_class_name
+
+         return LLM
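
Note that the restored create_model returns a dynamically built LanguageModel subclass, not an instance. A hedged usage sketch (the model name is illustrative; how the returned class is instantiated depends on LanguageModel's constructor, which this diff does not show):

    from edsl.inference_services.DeepInfraService import DeepInfraService

    # create_model returns a class, renamed after the model via to_class_name().
    LLMClass = DeepInfraService.create_model("meta-llama/Llama-2-13b-chat-hf")
    print(LLMClass.__name__)      # derived class name
    print(LLMClass._parameters_)  # defaults: temperature=0.7, top_p=0.2, ...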
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -15,19 +15,18 @@ class InferenceServicesCollection:
          cls.added_models[service_name].append(model_name)
 
      @staticmethod
-     def _get_service_available(service, warn: bool = False) -> list[str]:
+     def _get_service_available(service) -> list[str]:
          from_api = True
          try:
              service_models = service.available()
          except Exception as e:
-             if warn:
-                 warnings.warn(
-                     f"""Error getting models for {service._inference_service_}.
-                     Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
-                     See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
-                     Relying on cache.""",
-                     UserWarning,
-                 )
+             warnings.warn(
+                 f"""Error getting models for {service._inference_service_}.
+                 Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
+                 See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+                 Relying on cache.""",
+                 UserWarning,
+             )
              from edsl.inference_services.models_available_cache import models_available
 
              service_models = models_available.get(service._inference_service_, [])
@@ -61,8 +60,4 @@ class InferenceServicesCollection:
          if service_name is None or service_name == service._inference_service_:
              return service.create_model(model_name)
 
-     # if model_name == "test":
-     #     from edsl.language_models import LanguageModel
-     #     return LanguageModel(test = True)
-
      raise Exception(f"Model {model_name} not found in any of the services")
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,9 +1,6 @@
  from typing import Any, List
  import re
- import os
-
- # from openai import AsyncOpenAI
- import openai
+ from openai import AsyncOpenAI
 
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models import LanguageModel
@@ -15,22 +15,6 @@ class OpenAIService(InferenceServiceABC):
 
      _inference_service_ = "openai"
      _env_key_name_ = "OPENAI_API_KEY"
-     _base_url_ = None
-
-     _sync_client_ = openai.OpenAI
-     _async_client_ = openai.AsyncOpenAI
-
-     @classmethod
-     def sync_client(cls):
-         return cls._sync_client_(
-             api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-         )
-
-     @classmethod
-     def async_client(cls):
-         return cls._async_client_(
-             api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
-         )
 
      # TODO: Make this a coop call
      model_exclude_list = [
@@ -50,24 +31,16 @@ class OpenAIService(InferenceServiceABC):
      ]
      _models_list_cache: List[str] = []
 
-     @classmethod
-     def get_model_list(cls):
-         raw_list = cls.sync_client().models.list()
-         if hasattr(raw_list, "data"):
-             return raw_list.data
-         else:
-             return raw_list
-
      @classmethod
      def available(cls) -> List[str]:
-         # from openai import OpenAI
+         from openai import OpenAI
 
          if not cls._models_list_cache:
              try:
-                 # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                 client = OpenAI()
                  cls._models_list_cache = [
                      m.id
-                     for m in cls.get_model_list()
+                     for m in client.models.list()
                      if m.id not in cls.model_exclude_list
                  ]
              except Exception as e:
@@ -105,24 +78,15 @@ class OpenAIService(InferenceServiceABC):
                  "top_logprobs": 3,
              }
 
-             def sync_client(self):
-                 return cls.sync_client()
-
-             def async_client(self):
-                 return cls.async_client()
-
              @classmethod
              def available(cls) -> list[str]:
-                 # import openai
-                 # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
-                 # return client.models.list()
-                 return cls.sync_client().models.list()
+                 client = openai.OpenAI()
+                 return client.models.list()
 
              def get_headers(self) -> dict[str, Any]:
-                 # from openai import OpenAI
+                 from openai import OpenAI
 
-                 # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
-                 client = self.sync_client()
+                 client = OpenAI()
                  response = client.chat.completions.with_raw_response.create(
                      messages=[
                          {
@@ -160,8 +124,8 @@ class OpenAIService(InferenceServiceABC):
                  encoded_image=None,
              ) -> dict[str, Any]:
                  """Calls the OpenAI API and returns the API response."""
+                 content = [{"type": "text", "text": user_prompt}]
                  if encoded_image:
-                     content = [{"type": "text", "text": user_prompt}]
                      content.append(
                          {
                              "type": "image_url",
@@ -170,28 +134,21 @@ class OpenAIService(InferenceServiceABC):
                              },
                          }
                      )
-                 else:
-                     content = user_prompt
-                 # self.client = AsyncOpenAI(
-                 #     api_key = os.getenv(cls._env_key_name_),
-                 #     base_url = cls._base_url_
-                 # )
-                 client = self.async_client()
-                 params = {
-                     "model": self.model,
-                     "messages": [
+                 self.client = AsyncOpenAI()
+                 response = await self.client.chat.completions.create(
+                     model=self.model,
+                     messages=[
                          {"role": "system", "content": system_prompt},
                          {"role": "user", "content": content},
                      ],
-                     "temperature": self.temperature,
-                     "max_tokens": self.max_tokens,
-                     "top_p": self.top_p,
-                     "frequency_penalty": self.frequency_penalty,
-                     "presence_penalty": self.presence_penalty,
-                     "logprobs": self.logprobs,
-                     "top_logprobs": self.top_logprobs if self.logprobs else None,
-                 }
-                 response = await client.chat.completions.create(**params)
+                     temperature=self.temperature,
+                     max_tokens=self.max_tokens,
+                     top_p=self.top_p,
+                     frequency_penalty=self.frequency_penalty,
+                     presence_penalty=self.presence_penalty,
+                     logprobs=self.logprobs,
+                     top_logprobs=self.top_logprobs if self.logprobs else None,
+                 )
                  return response.model_dump()
 
      @staticmethod
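
The _base_url_/sync_client/async_client machinery removed above is what let subclasses such as DeepInfraService point the OpenAI SDK at a compatible endpoint. A condensed sketch of that pattern, assuming the env-var and URL handling mirror the deleted lines (class names here are hypothetical):

    import os
    import openai

    class OpenAICompatibleService:
        _env_key_name_ = "OPENAI_API_KEY"
        _base_url_ = None  # None means the official OpenAI endpoint

        @classmethod
        def async_client(cls) -> openai.AsyncOpenAI:
            return openai.AsyncOpenAI(
                api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
            )

    class DeepInfraLike(OpenAICompatibleService):
        # A subclass only overrides the key name and endpoint.
        _env_key_name_ = "DEEP_INFRA_API_KEY"
        _base_url_ = "https://api.deepinfra.com/v1/openai"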
edsl/inference_services/registry.py CHANGED
@@ -6,8 +6,7 @@ from edsl.inference_services.OpenAIService import OpenAIService
  from edsl.inference_services.AnthropicService import AnthropicService
  from edsl.inference_services.DeepInfraService import DeepInfraService
  from edsl.inference_services.GoogleService import GoogleService
- from edsl.inference_services.GroqService import GroqService
 
  default = InferenceServicesCollection(
-     [OpenAIService, AnthropicService, DeepInfraService, GoogleService, GroqService]
+     [OpenAIService, AnthropicService, DeepInfraService, GoogleService]
  )
edsl/jobs/Jobs.py CHANGED
@@ -319,11 +319,7 @@ class Jobs(Base):
          self.scenarios = self.scenarios or [Scenario()]
          for agent, scenario, model in product(self.agents, self.scenarios, self.models):
              yield Interview(
-                 survey=self.survey,
-                 agent=agent,
-                 scenario=scenario,
-                 model=model,
-                 skip_retry=self.skip_retry,
+                 survey=self.survey, agent=agent, scenario=scenario, model=model
              )
 
      def create_bucket_collection(self) -> BucketCollection:
@@ -413,12 +409,6 @@ class Jobs(Base):
          if warn:
              warnings.warn(message)
 
-     @property
-     def skip_retry(self):
-         if not hasattr(self, "_skip_retry"):
-             return False
-         return self._skip_retry
-
      def run(
          self,
          n: int = 1,
@@ -433,7 +423,6 @@ class Jobs(Base):
          print_exceptions=True,
          remote_cache_description: Optional[str] = None,
          remote_inference_description: Optional[str] = None,
-         skip_retry: bool = False,
      ) -> Results:
          """
          Runs the Job: conducts Interviews and returns their results.
@@ -452,7 +441,6 @@ class Jobs(Base):
          from edsl.coop.coop import Coop
 
          self._check_parameters()
-         self._skip_retry = skip_retry
 
          if batch_mode is not None:
              raise NotImplementedError(
@@ -487,7 +475,6 @@ class Jobs(Base):
              self,
              description=remote_inference_description,
              status="queued",
-             iterations=n,
          )
          time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
          job_uuid = remote_job_creation_data.get("uuid")
@@ -642,9 +629,9 @@ class Jobs(Base):
          results = JobsRunnerAsyncio(self).run(*args, **kwargs)
          return results
 
-     async def run_async(self, cache=None, n=1, **kwargs):
+     async def run_async(self, cache=None, **kwargs):
          """Run the job asynchronously."""
-         results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
+         results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
          return results
 
      def all_question_parameters(self):
@@ -724,10 +711,7 @@ class Jobs(Base):
      #######################
      @classmethod
      def example(
-         cls,
-         throw_exception_probability: int = 0,
-         randomize: bool = False,
-         test_model=False,
+         cls, throw_exception_probability: int = 0, randomize: bool = False
      ) -> Jobs:
          """Return an example Jobs instance.
 
@@ -745,11 +729,6 @@ class Jobs(Base):
 
          addition = "" if not randomize else str(uuid4())
 
-         if test_model:
-             from edsl.language_models import LanguageModel
-
-             m = LanguageModel.example(test_model=True)
-
          # (status, question, period)
          agent_answers = {
              ("Joyful", "how_feeling", "morning"): "OK",
@@ -797,10 +776,7 @@ class Jobs(Base):
                  Scenario({"period": "afternoon"}),
              ]
          )
-         if test_model:
-             job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
-         else:
-             job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
+         job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
 
          return job
 
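Taken together, these hunks roll back the skip_retry and test_model plumbing, so call sites written against 0.1.31 need a small adjustment. A hedged before/after sketch (running it requires configured API keys, which this diff does not cover):

    from edsl.jobs.Jobs import Jobs

    job = Jobs.example()

    # 0.1.31:
    # results = job.run(skip_retry=True)   # flag was threaded into each Interview
    # 0.1.31.dev2: the parameter is gone.
    results = job.run()

    # run_async likewise drops its n= parameter in 0.1.31.dev2:
    # results = await job.run_async(cache=my_cache)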
edsl/jobs/buckets/TokenBucket.py CHANGED
@@ -100,9 +100,7 @@ class TokenBucket:
          available_tokens = min(self.capacity, self.tokens + refill_amount)
          return max(0, requested_tokens - available_tokens) / self.refill_rate
 
-     async def get_tokens(
-         self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
-     ) -> None:
+     async def get_tokens(self, amount: Union[int, float] = 1) -> None:
          """Wait for the specified number of tokens to become available.
 
 
@@ -118,20 +116,14 @@ class TokenBucket:
          True
 
          >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
-         >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
+         >>> asyncio.run(bucket.get_tokens(11))
          Traceback (most recent call last):
          ...
          ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
-         >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
          """
          if amount > self.capacity:
-             if not cheat_bucket_capacity:
-                 msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
-                 raise ValueError(msg)
-             else:
-                 self.tokens = 0  # clear the bucket but let it go through
-                 return
-
+             msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
+             raise ValueError(msg)
          while self.tokens < amount:
              self.refill()
              await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
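
The behavioral difference in this hunk is easy to miss: in 0.1.31, an over-capacity request under the default cheat_bucket_capacity=True drained the bucket and returned instead of raising. A sketch of both sides, mirroring the doctest above:

    import asyncio
    from edsl.jobs.buckets.TokenBucket import TokenBucket

    bucket = TokenBucket(bucket_name="test", bucket_type="test",
                         capacity=10, refill_rate=1)

    # 0.1.31.dev2: any request above capacity raises ValueError.
    try:
        asyncio.run(bucket.get_tokens(11))
    except ValueError as e:
        print(e)

    # 0.1.31 only raised when cheating was disabled:
    #   asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))  # ValueError
    #   asyncio.run(bucket.get_tokens(11))  # emptied the bucket, then returned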
edsl/jobs/interviews/Interview.py CHANGED
@@ -14,8 +14,8 @@ from edsl.jobs.tasks.TaskCreators import TaskCreators
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
  from edsl.jobs.interviews.interview_exception_tracking import (
      InterviewExceptionCollection,
+     InterviewExceptionEntry,
  )
- from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
  from edsl.jobs.interviews.retry_management import retry_strategy
  from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
  from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
@@ -44,7 +44,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
          iteration: int = 0,
          cache: Optional["Cache"] = None,
          sidecar_model: Optional["LanguageModel"] = None,
-         skip_retry=False,
      ):
          """Initialize the Interview instance.
 
@@ -88,7 +87,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
          self.task_creators = TaskCreators()  # tracks the task creators
          self.exceptions = InterviewExceptionCollection()
          self._task_status_log_dict = InterviewStatusLog()
-         self.skip_retry = skip_retry
 
          # dictionary mapping question names to their index in the survey.
          self.to_index = {
@@ -96,30 +94,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
              for index, question_name in enumerate(self.survey.question_names)
          }
 
-     def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
-         """Return a dictionary representation of the Interview instance.
-         This is just for hashing purposes.
-
-         >>> i = Interview.example()
-         >>> hash(i)
-         1646262796627658719
-         """
-         d = {
-             "agent": self.agent._to_dict(),
-             "survey": self.survey._to_dict(),
-             "scenario": self.scenario._to_dict(),
-             "model": self.model._to_dict(),
-             "iteration": self.iteration,
-             "exceptions": {},
-         }
-         if include_exceptions:
-             d["exceptions"] = self.exceptions.to_dict()
-
-     def __hash__(self) -> int:
-         from edsl.utilities.utilities import dict_hash
-
-         return dict_hash(self._to_dict())
-
      async def async_conduct_interview(
          self,
          *,
@@ -160,7 +134,8 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
          <BLANKLINE>
 
          >>> i.exceptions
-         {'q0': ...
+         {'q0': [{'exception': "Exception('This is a test error')", 'time': ..., 'traceback': ...
+
          >>> i = Interview.example()
          >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
          Traceback (most recent call last):
@@ -229,9 +204,13 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
          {}
          >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
          >>> i.exceptions
-         {'q0': ...
+         {'q0': [{'exception': "Exception('An exception occurred.')", 'time': ..., 'traceback': 'NoneType: None\\n'}]}
          """
-         exception_entry = InterviewExceptionEntry(exception)
+         exception_entry = InterviewExceptionEntry(
+             exception=repr(exception),
+             time=time.time(),
+             traceback=traceback.format_exc(),
+         )
          self.exceptions.add(task.get_name(), exception_entry)
 
      @property
@@ -272,7 +251,6 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
          model=self.model,
          iteration=iteration,
          cache=cache,
-         skip_retry=self.skip_retry,
      )
 
      @classmethod
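
The _record_exception change reverts InterviewExceptionEntry to a plain record built from keyword fields. A minimal standalone sketch of recording an exception the 0.1.31.dev2 way, using only names shown in this diff (the "q0" key is taken from the doctest):

    import time
    import traceback
    from edsl.jobs.interviews.interview_exception_tracking import (
        InterviewExceptionCollection,
        InterviewExceptionEntry,
    )

    exceptions = InterviewExceptionCollection()

    try:
        raise Exception("This is a test error")
    except Exception as exc:
        entry = InterviewExceptionEntry(
            exception=repr(exc),
            time=time.time(),
            traceback=traceback.format_exc(),
        )
        exceptions.add("q0", entry)  # keyed by task/question name, per the hunk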