edsl 0.1.31.dev4__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff compares two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (50)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +3 -4
  3. edsl/agents/PromptConstructionMixin.py +35 -15
  4. edsl/config.py +11 -1
  5. edsl/conjure/Conjure.py +6 -0
  6. edsl/data/CacheHandler.py +3 -4
  7. edsl/enums.py +4 -0
  8. edsl/exceptions/general.py +10 -8
  9. edsl/inference_services/AwsBedrock.py +110 -0
  10. edsl/inference_services/AzureAI.py +197 -0
  11. edsl/inference_services/DeepInfraService.py +4 -3
  12. edsl/inference_services/GroqService.py +3 -4
  13. edsl/inference_services/InferenceServicesCollection.py +13 -8
  14. edsl/inference_services/OllamaService.py +18 -0
  15. edsl/inference_services/OpenAIService.py +23 -18
  16. edsl/inference_services/models_available_cache.py +31 -0
  17. edsl/inference_services/registry.py +13 -1
  18. edsl/jobs/Jobs.py +100 -19
  19. edsl/jobs/buckets/TokenBucket.py +12 -4
  20. edsl/jobs/interviews/Interview.py +31 -9
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
  22. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -34
  23. edsl/jobs/interviews/interview_exception_tracking.py +68 -10
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +36 -15
  25. edsl/jobs/runners/JobsRunnerStatusMixin.py +81 -51
  26. edsl/jobs/tasks/TaskCreators.py +1 -1
  27. edsl/jobs/tasks/TaskHistory.py +145 -1
  28. edsl/language_models/LanguageModel.py +58 -43
  29. edsl/language_models/registry.py +2 -2
  30. edsl/questions/QuestionBudget.py +0 -1
  31. edsl/questions/QuestionCheckBox.py +0 -1
  32. edsl/questions/QuestionExtract.py +0 -1
  33. edsl/questions/QuestionFreeText.py +2 -9
  34. edsl/questions/QuestionList.py +0 -1
  35. edsl/questions/QuestionMultipleChoice.py +1 -2
  36. edsl/questions/QuestionNumerical.py +0 -1
  37. edsl/questions/QuestionRank.py +0 -1
  38. edsl/results/DatasetExportMixin.py +33 -3
  39. edsl/scenarios/Scenario.py +14 -0
  40. edsl/scenarios/ScenarioList.py +216 -13
  41. edsl/scenarios/ScenarioListExportMixin.py +15 -4
  42. edsl/scenarios/ScenarioListPdfMixin.py +3 -0
  43. edsl/surveys/Rule.py +5 -2
  44. edsl/surveys/Survey.py +84 -1
  45. edsl/surveys/SurveyQualtricsImport.py +213 -0
  46. edsl/utilities/utilities.py +31 -0
  47. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/METADATA +4 -1
  48. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/RECORD +50 -45
  49. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
  50. {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,12 +1,14 @@
 from typing import Any, List
 import re
 import os
-#from openai import AsyncOpenAI
+
+# from openai import AsyncOpenAI
 import openai
 
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models import LanguageModel
 from edsl.inference_services.rate_limits_cache import rate_limits
+from edsl.utilities.utilities import fix_partial_correct_response
 
 
 class OpenAIService(InferenceServiceABC):
@@ -18,18 +20,18 @@ class OpenAIService(InferenceServiceABC):
 
     _sync_client_ = openai.OpenAI
     _async_client_ = openai.AsyncOpenAI
-
+
     @classmethod
     def sync_client(cls):
         return cls._sync_client_(
-            api_key = os.getenv(cls._env_key_name_),
-            base_url = cls._base_url_)
-
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )
+
     @classmethod
     def async_client(cls):
         return cls._async_client_(
-            api_key = os.getenv(cls._env_key_name_),
-            base_url = cls._base_url_)
+            api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+        )
 
     # TODO: Make this a coop call
     model_exclude_list = [
@@ -59,14 +61,14 @@ class OpenAIService(InferenceServiceABC):
 
     @classmethod
     def available(cls) -> List[str]:
-        #from openai import OpenAI
+        # from openai import OpenAI
 
         if not cls._models_list_cache:
             try:
-                #client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                 cls._models_list_cache = [
                     m.id
-                    for m in cls.get_model_list()
+                    for m in cls.get_model_list()
                     if m.id not in cls.model_exclude_list
                 ]
             except Exception as e:
@@ -106,21 +108,21 @@ class OpenAIService(InferenceServiceABC):
 
            def sync_client(self):
                return cls.sync_client()
-
+
            def async_client(self):
                return cls.async_client()
 
            @classmethod
            def available(cls) -> list[str]:
-                #import openai
-                #client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
-                #return client.models.list()
+                # import openai
+                # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                # return client.models.list()
                return cls.sync_client().models.list()
-
+
            def get_headers(self) -> dict[str, Any]:
-                #from openai import OpenAI
+                # from openai import OpenAI
 
-                #client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                client = self.sync_client()
                response = client.chat.completions.with_raw_response.create(
                    messages=[
@@ -172,7 +174,7 @@ class OpenAIService(InferenceServiceABC):
                else:
                    content = user_prompt
                # self.client = AsyncOpenAI(
-                #     api_key = os.getenv(cls._env_key_name_),
+                #     api_key = os.getenv(cls._env_key_name_),
                #     base_url = cls._base_url_
                # )
                client = self.async_client()
@@ -206,6 +208,9 @@ class OpenAIService(InferenceServiceABC):
                if match:
                    return match.group(1)
                else:
+                    out = fix_partial_correct_response(response)
+                    if "error" not in out:
+                        response = out["extracted_json"]
                    return response
 
        LLM.__name__ = "LanguageModel"
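Note on the parse_response change above: when the model's raw reply does not match the expected pattern, the code now falls back to fix_partial_correct_response and, if the helper returns no "error" key, substitutes its "extracted_json" value. The helper itself is added to edsl/utilities/utilities.py in this release (see the files-changed list); the stub below is only a hypothetical stand-in sketching the contract the caller relies on, not the shipped implementation.

    import json

    def fix_partial_correct_response(response: str) -> dict:
        # Hypothetical stand-in: locate the outermost JSON object in raw text.
        start, end = response.find("{"), response.rfind("}")
        if start == -1 or end <= start:
            return {"error": "no JSON object found"}
        candidate = response[start : end + 1]
        try:
            json.loads(candidate)  # validate only; the caller keeps the string
        except json.JSONDecodeError:
            return {"error": "invalid JSON"}
        return {"extracted_json": candidate}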
edsl/inference_services/models_available_cache.py CHANGED
@@ -66,4 +66,35 @@ models_available = {
         "openchat/openchat_3.5",
     ],
     "google": ["gemini-pro"],
+    "bedrock": [
+        "amazon.titan-tg1-large",
+        "amazon.titan-text-lite-v1",
+        "amazon.titan-text-express-v1",
+        "ai21.j2-grande-instruct",
+        "ai21.j2-jumbo-instruct",
+        "ai21.j2-mid",
+        "ai21.j2-mid-v1",
+        "ai21.j2-ultra",
+        "ai21.j2-ultra-v1",
+        "anthropic.claude-instant-v1",
+        "anthropic.claude-v2:1",
+        "anthropic.claude-v2",
+        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
+        "anthropic.claude-3-opus-20240229-v1:0",
+        "anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "cohere.command-text-v14",
+        "cohere.command-r-v1:0",
+        "cohere.command-r-plus-v1:0",
+        "cohere.command-light-text-v14",
+        "meta.llama3-8b-instruct-v1:0",
+        "meta.llama3-70b-instruct-v1:0",
+        "meta.llama3-1-8b-instruct-v1:0",
+        "meta.llama3-1-70b-instruct-v1:0",
+        "meta.llama3-1-405b-instruct-v1:0",
+        "mistral.mistral-7b-instruct-v0:2",
+        "mistral.mixtral-8x7b-instruct-v0:1",
+        "mistral.mistral-large-2402-v1:0",
+        "mistral.mistral-large-2407-v1:0",
+    ],
 }
edsl/inference_services/registry.py CHANGED
@@ -7,7 +7,19 @@ from edsl.inference_services.AnthropicService import AnthropicService
 from edsl.inference_services.DeepInfraService import DeepInfraService
 from edsl.inference_services.GoogleService import GoogleService
 from edsl.inference_services.GroqService import GroqService
+from edsl.inference_services.AwsBedrock import AwsBedrockService
+from edsl.inference_services.AzureAI import AzureAIService
+from edsl.inference_services.OllamaService import OllamaService
 
 default = InferenceServicesCollection(
-    [OpenAIService, AnthropicService, DeepInfraService, GoogleService, GroqService]
+    [
+        OpenAIService,
+        AnthropicService,
+        DeepInfraService,
+        GoogleService,
+        GroqService,
+        AwsBedrockService,
+        AzureAIService,
+        OllamaService,
+    ]
 )
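With AwsBedrockService, AzureAIService, and OllamaService registered in the default collection, the models they expose become reachable by name. A hedged usage sketch, assuming edsl's public Model factory (the model name below is taken from the bedrock list added to models_available_cache.py in this release):

    from edsl import Model

    # Bedrock, Azure, and Ollama entries should now appear alongside the
    # previously registered services.
    print(Model.available())
    m = Model("anthropic.claude-3-haiku-20240307-v1:0")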
edsl/jobs/Jobs.py CHANGED
@@ -39,6 +39,8 @@ class Jobs(Base):
 
         self.__bucket_collection = None
 
+    # these setters and getters are used to ensure that the agents, models, and scenarios are stored as AgentList, ModelList, and ScenarioList objects
+
     @property
     def models(self):
         return self._models
@@ -119,7 +121,9 @@
         - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
         - models: new models overwrite old models.
         """
-        passed_objects = self._turn_args_to_list(args)
+        passed_objects = self._turn_args_to_list(
+            args
+        )  # objects can also be passed comma-separated
 
         current_objects, objects_key = self._get_current_objects_of_this_type(
             passed_objects[0]
@@ -176,17 +180,27 @@
         from edsl.agents.Agent import Agent
         from edsl.scenarios.Scenario import Scenario
         from edsl.scenarios.ScenarioList import ScenarioList
+        from edsl.language_models.ModelList import ModelList
 
         if isinstance(object, Agent):
             return AgentList
         elif isinstance(object, Scenario):
             return ScenarioList
+        elif isinstance(object, ModelList):
+            return ModelList
         else:
             return list
 
     @staticmethod
     def _turn_args_to_list(args):
-        """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments."""
+        """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.
+
+        Example:
+
+        >>> Jobs._turn_args_to_list([1,2,3])
+        [1, 2, 3]
+
+        """
 
         def did_user_pass_a_sequence(args):
             """Return True if the user passed a sequence, False otherwise.
@@ -209,7 +223,7 @@
         return container_class(args)
 
     def _get_current_objects_of_this_type(
-        self, object: Union[Agent, Scenario, LanguageModel]
+        self, object: Union["Agent", "Scenario", "LanguageModel"]
     ) -> tuple[list, str]:
         from edsl.agents.Agent import Agent
         from edsl.scenarios.Scenario import Scenario
@@ -292,7 +306,11 @@
 
     @classmethod
     def from_interviews(cls, interview_list):
-        """Return a Jobs instance from a list of interviews."""
+        """Return a Jobs instance from a list of interviews.
+
+        This is useful when you have, say, a list of failed interviews and you want to create
+        a new job with only those interviews.
+        """
         survey = interview_list[0].survey
         # get all the models
         models = list(set([interview.model for interview in interview_list]))
@@ -308,6 +326,8 @@
         Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
         This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
         with us filling in defaults.
+
+
         """
         # if no agents, models, or scenarios are set, set them to defaults
         from edsl.agents.Agent import Agent
@@ -319,7 +339,11 @@
         self.scenarios = self.scenarios or [Scenario()]
         for agent, scenario, model in product(self.agents, self.scenarios, self.models):
             yield Interview(
-                survey=self.survey, agent=agent, scenario=scenario, model=model
+                survey=self.survey,
+                agent=agent,
+                scenario=scenario,
+                model=model,
+                skip_retry=self.skip_retry,
             )
 
     def create_bucket_collection(self) -> BucketCollection:
@@ -359,10 +383,16 @@
         return links
 
     def __hash__(self):
-        """Allow the model to be used as a key in a dictionary."""
+        """Allow the model to be used as a key in a dictionary.
+
+        >>> from edsl.jobs import Jobs
+        >>> hash(Jobs.example())
+        846655441787442972
+
+        """
         from edsl.utilities.utilities import dict_hash
 
-        return dict_hash(self.to_dict())
+        return dict_hash(self._to_dict())
 
     def _output(self, message) -> None:
         """Check if a Job is verbose. If so, print the message."""
@@ -390,11 +420,27 @@
         Traceback (most recent call last):
         ...
         ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
+
+        >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
+        >>> s = Scenario({'ugly_question': "B"})
+        >>> j = Jobs(survey = Survey(questions=[q])).by(s)
+        >>> j._check_parameters()
+        Traceback (most recent call last):
+        ...
+        ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
         """
         survey_parameters: set = self.survey.parameters
         scenario_parameters: set = self.scenarios.parameters
 
-        msg1, msg2 = None, None
+        msg0, msg1, msg2 = None, None, None
+
+        # look for key issues
+        if intersection := set(self.scenarios.parameters) & set(
+            self.survey.question_names
+        ):
+            msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
+
+            raise ValueError(msg0)
 
         if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
             msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
@@ -409,6 +455,12 @@
         if warn:
             warnings.warn(message)
 
+    @property
+    def skip_retry(self):
+        if not hasattr(self, "_skip_retry"):
+            return False
+        return self._skip_retry
+
     def run(
         self,
         n: int = 1,
@@ -423,6 +475,7 @@
         print_exceptions=True,
         remote_cache_description: Optional[str] = None,
         remote_inference_description: Optional[str] = None,
+        skip_retry: bool = False,
     ) -> Results:
         """
         Runs the Job: conducts Interviews and returns their results.
@@ -441,6 +494,7 @@
         from edsl.coop.coop import Coop
 
         self._check_parameters()
+        self._skip_retry = skip_retry
 
         if batch_mode is not None:
             raise NotImplementedError(
@@ -631,12 +685,16 @@
         return results
 
     async def run_async(self, cache=None, n=1, **kwargs):
-        """Run the job asynchronously."""
+        """Run asynchronously."""
         results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
         return results
 
     def all_question_parameters(self):
-        """Return all the fields in the questions in the survey."""
+        """Return all the fields in the questions in the survey.
+        >>> from edsl.jobs import Jobs
+        >>> Jobs.example().all_question_parameters()
+        {'period'}
+        """
         return set.union(*[question.parameters for question in self.survey.questions])
 
     #######################
@@ -677,15 +735,19 @@
     #######################
     # Serialization methods
     #######################
+
+    def _to_dict(self):
+        return {
+            "survey": self.survey._to_dict(),
+            "agents": [agent._to_dict() for agent in self.agents],
+            "models": [model._to_dict() for model in self.models],
+            "scenarios": [scenario._to_dict() for scenario in self.scenarios],
+        }
+
     @add_edsl_version
     def to_dict(self) -> dict:
         """Convert the Jobs instance to a dictionary."""
-        return {
-            "survey": self.survey.to_dict(),
-            "agents": [agent.to_dict() for agent in self.agents],
-            "models": [model.to_dict() for model in self.models],
-            "scenarios": [scenario.to_dict() for scenario in self.scenarios],
-        }
+        return self._to_dict()
 
     @classmethod
     @remove_edsl_version
@@ -704,7 +766,13 @@
         )
 
     def __eq__(self, other: Jobs) -> bool:
-        """Return True if the Jobs instance is equal to another Jobs instance."""
+        """Return True if the Jobs instance is equal to another Jobs instance.
+
+        >>> from edsl.jobs import Jobs
+        >>> Jobs.example() == Jobs.example()
+        True
+
+        """
         return self.to_dict() == other.to_dict()
 
     #######################
@@ -712,11 +780,16 @@
     #######################
     @classmethod
     def example(
-        cls, throw_exception_probability: int = 0, randomize: bool = False
+        cls,
+        throw_exception_probability: float = 0.0,
+        randomize: bool = False,
+        test_model=False,
     ) -> Jobs:
         """Return an example Jobs instance.
 
         :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
+        :param randomize: whether to randomize the job by adding a random string to the period
+        :param test_model: whether to use a test model
 
         >>> Jobs.example()
         Jobs(...)
@@ -730,6 +803,11 @@
 
         addition = "" if not randomize else str(uuid4())
 
+        if test_model:
+            from edsl.language_models import LanguageModel
+
+            m = LanguageModel.example(test_model=True)
+
         # (status, question, period)
         agent_answers = {
             ("Joyful", "how_feeling", "morning"): "OK",
@@ -777,7 +855,10 @@
                 Scenario({"period": "afternoon"}),
             ]
         )
-        job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
+        if test_model:
+            job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
+        else:
+            job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
 
         return job
 
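Taken together, the Jobs changes thread a skip_retry flag from run() down to each Interview and let example() wire in a canned test model. A hedged end-to-end sketch using only the parameters visible in this diff:

    from edsl.jobs import Jobs

    # test_model=True swaps in LanguageModel.example(test_model=True), so no
    # real API calls should be made; skip_retry is forwarded to every Interview.
    job = Jobs.example(test_model=True, throw_exception_probability=0.1)
    results = job.run(n=1, print_exceptions=True, skip_retry=True)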
edsl/jobs/buckets/TokenBucket.py CHANGED
@@ -100,7 +100,9 @@ class TokenBucket:
         available_tokens = min(self.capacity, self.tokens + refill_amount)
         return max(0, requested_tokens - available_tokens) / self.refill_rate
 
-    async def get_tokens(self, amount: Union[int, float] = 1) -> None:
+    async def get_tokens(
+        self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
+    ) -> None:
         """Wait for the specified number of tokens to become available.
 
 
@@ -116,14 +118,20 @@ class TokenBucket:
         True
 
         >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
-        >>> asyncio.run(bucket.get_tokens(11))
+        >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
         Traceback (most recent call last):
         ...
         ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
+        >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
         """
         if amount > self.capacity:
-            msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
-            raise ValueError(msg)
+            if not cheat_bucket_capacity:
+                msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
+                raise ValueError(msg)
+            else:
+                self.tokens = 0  # clear the bucket but let it go through
+                return
+
         while self.tokens < amount:
             self.refill()
             await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
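The cheat_bucket_capacity default turns over-capacity requests from a hard error into a drain-and-proceed policy; the doctests above pin both behaviors. Restated as a plain script:

    import asyncio
    from edsl.jobs.buckets.TokenBucket import TokenBucket

    bucket = TokenBucket(
        bucket_name="test", bucket_type="test", capacity=10, refill_rate=1
    )

    # Old behavior is still available by opting out:
    # asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))  # ValueError

    # New default: the bucket is emptied and the oversized request proceeds.
    asyncio.run(bucket.get_tokens(11))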
edsl/jobs/interviews/Interview.py CHANGED
@@ -14,8 +14,8 @@ from edsl.jobs.tasks.TaskCreators import TaskCreators
 from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
 from edsl.jobs.interviews.interview_exception_tracking import (
     InterviewExceptionCollection,
-    InterviewExceptionEntry,
 )
+from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
 from edsl.jobs.interviews.retry_management import retry_strategy
 from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
 from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
@@ -44,6 +44,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         iteration: int = 0,
         cache: Optional["Cache"] = None,
         sidecar_model: Optional["LanguageModel"] = None,
+        skip_retry=False,
     ):
         """Initialize the Interview instance.
 
@@ -87,6 +88,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         self.task_creators = TaskCreators()  # tracks the task creators
         self.exceptions = InterviewExceptionCollection()
         self._task_status_log_dict = InterviewStatusLog()
+        self.skip_retry = skip_retry
 
         # dictionary mapping question names to their index in the survey.
         self.to_index = {
@@ -94,6 +96,30 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
             for index, question_name in enumerate(self.survey.question_names)
         }
 
+    def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
+        """Return a dictionary representation of the Interview instance.
+        This is just for hashing purposes.
+
+        >>> i = Interview.example()
+        >>> hash(i)
+        1646262796627658719
+        """
+        d = {
+            "agent": self.agent._to_dict(),
+            "survey": self.survey._to_dict(),
+            "scenario": self.scenario._to_dict(),
+            "model": self.model._to_dict(),
+            "iteration": self.iteration,
+            "exceptions": {},
+        }
+        if include_exceptions:
+            d["exceptions"] = self.exceptions.to_dict()
+
+    def __hash__(self) -> int:
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self._to_dict())
+
     async def async_conduct_interview(
         self,
         *,
@@ -134,8 +160,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         <BLANKLINE>
 
         >>> i.exceptions
-        {'q0': [{'exception': "Exception('This is a test error')", 'time': ..., 'traceback': ...
-
+        {'q0': ...
         >>> i = Interview.example()
         >>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
         Traceback (most recent call last):
@@ -204,13 +229,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
         {}
         >>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
        >>> i.exceptions
-        {'q0': [{'exception': "Exception('An exception occurred.')", 'time': ..., 'traceback': 'NoneType: None\\n'}]}
+        {'q0': ...
         """
-        exception_entry = InterviewExceptionEntry(
-            exception=repr(exception),
-            time=time.time(),
-            traceback=traceback.format_exc(),
-        )
+        exception_entry = InterviewExceptionEntry(exception)
         self.exceptions.add(task.get_name(), exception_entry)
 
     @property
@@ -251,6 +272,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
                 model=self.model,
                 iteration=iteration,
                 cache=cache,
+                skip_retry=self.skip_retry,
             )
 
     @classmethod
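Two observations on the Interview changes: _record_exception now hands the raw exception object to InterviewExceptionEntry, which timestamps and formats it itself, and Interview gains __hash__ via _to_dict()/dict_hash. As released, _to_dict falls through without an explicit return, so dict_hash receives None; the doctest value above reflects that behavior. A minimal hashing sketch, mirroring the doctest:

    from edsl.jobs.interviews.Interview import Interview

    i = Interview.example()
    print(hash(i))  # 1646262796627658719 per the doctest above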
edsl/jobs/interviews/InterviewExceptionEntry.py ADDED
@@ -0,0 +1,101 @@
+import traceback
+import datetime
+import time
+from collections import UserDict
+
+# traceback=traceback.format_exc(),
+# traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
+# traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
+
+
+class InterviewExceptionEntry:
+    """Class to record an exception that occurred during the interview.
+
+    >>> entry = InterviewExceptionEntry.example()
+    >>> entry.to_dict()['exception']
+    "ValueError('An error occurred.')"
+    """
+
+    def __init__(self, exception: Exception, traceback_format="html"):
+        self.time = datetime.datetime.now().isoformat()
+        self.exception = exception
+        self.traceback_format = traceback_format
+
+    def __getitem__(self, key):
+        # Support dict-like access obj['a']
+        return str(getattr(self, key))
+
+    @classmethod
+    def example(cls):
+        try:
+            raise ValueError("An error occurred.")
+        except Exception as e:
+            entry = InterviewExceptionEntry(e)
+        return entry
+
+    @property
+    def traceback(self):
+        """Return the exception as HTML."""
+        if self.traceback_format == "html":
+            return self.html_traceback
+        else:
+            return self.text_traceback
+
+    @property
+    def text_traceback(self):
+        """
+        >>> entry = InterviewExceptionEntry.example()
+        >>> entry.text_traceback
+        'Traceback (most recent call last):...'
+        """
+        e = self.exception
+        tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        return tb_str
+
+    @property
+    def html_traceback(self):
+        from rich.console import Console
+        from rich.table import Table
+        from rich.traceback import Traceback
+
+        from io import StringIO
+
+        html_output = StringIO()
+
+        console = Console(file=html_output, record=True)
+
+        tb = Traceback.from_exception(
+            type(self.exception),
+            self.exception,
+            self.exception.__traceback__,
+            show_locals=True,
+        )
+        console.print(tb)
+        return html_output.getvalue()
+
+    def to_dict(self) -> dict:
+        """Return the exception as a dictionary.
+
+        >>> entry = InterviewExceptionEntry.example()
+        >>> entry.to_dict()['exception']
+        "ValueError('An error occurred.')"
+
+        """
+        return {
+            "exception": repr(self.exception),
+            "time": self.time,
+            "traceback": self.traceback,
+        }
+
+    def push(self):
+        from edsl import Coop
+
+        coop = Coop()
+        results = coop.error_create(self.to_dict())
+        return results
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
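The new class's own doctests show the intended round-trip: capture an exception, then serialize it with an ISO-8601 timestamp and, by default, a rich-rendered HTML traceback. A short usage sketch:

    from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry

    entry = InterviewExceptionEntry.example()
    d = entry.to_dict()
    print(d["exception"])  # "ValueError('An error occurred.')"
    print(d["time"])       # ISO-8601 timestamp

    # Plain-text traceback instead of the default HTML rendering:
    text_entry = InterviewExceptionEntry(entry.exception, traceback_format="text")
    print(text_entry.traceback)  # 'Traceback (most recent call last): ...'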