edsl 0.1.31.dev4__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +3 -4
- edsl/agents/PromptConstructionMixin.py +35 -15
- edsl/config.py +11 -1
- edsl/conjure/Conjure.py +6 -0
- edsl/data/CacheHandler.py +3 -4
- edsl/enums.py +4 -0
- edsl/exceptions/general.py +10 -8
- edsl/inference_services/AwsBedrock.py +110 -0
- edsl/inference_services/AzureAI.py +197 -0
- edsl/inference_services/DeepInfraService.py +4 -3
- edsl/inference_services/GroqService.py +3 -4
- edsl/inference_services/InferenceServicesCollection.py +13 -8
- edsl/inference_services/OllamaService.py +18 -0
- edsl/inference_services/OpenAIService.py +23 -18
- edsl/inference_services/models_available_cache.py +31 -0
- edsl/inference_services/registry.py +13 -1
- edsl/jobs/Jobs.py +100 -19
- edsl/jobs/buckets/TokenBucket.py +12 -4
- edsl/jobs/interviews/Interview.py +31 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -34
- edsl/jobs/interviews/interview_exception_tracking.py +68 -10
- edsl/jobs/runners/JobsRunnerAsyncio.py +36 -15
- edsl/jobs/runners/JobsRunnerStatusMixin.py +81 -51
- edsl/jobs/tasks/TaskCreators.py +1 -1
- edsl/jobs/tasks/TaskHistory.py +145 -1
- edsl/language_models/LanguageModel.py +58 -43
- edsl/language_models/registry.py +2 -2
- edsl/questions/QuestionBudget.py +0 -1
- edsl/questions/QuestionCheckBox.py +0 -1
- edsl/questions/QuestionExtract.py +0 -1
- edsl/questions/QuestionFreeText.py +2 -9
- edsl/questions/QuestionList.py +0 -1
- edsl/questions/QuestionMultipleChoice.py +1 -2
- edsl/questions/QuestionNumerical.py +0 -1
- edsl/questions/QuestionRank.py +0 -1
- edsl/results/DatasetExportMixin.py +33 -3
- edsl/scenarios/Scenario.py +14 -0
- edsl/scenarios/ScenarioList.py +216 -13
- edsl/scenarios/ScenarioListExportMixin.py +15 -4
- edsl/scenarios/ScenarioListPdfMixin.py +3 -0
- edsl/surveys/Rule.py +5 -2
- edsl/surveys/Survey.py +84 -1
- edsl/surveys/SurveyQualtricsImport.py +213 -0
- edsl/utilities/utilities.py +31 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/METADATA +4 -1
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/RECORD +50 -45
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev4.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
@@ -1,12 +1,14 @@
|
|
1
1
|
from typing import Any, List
|
2
2
|
import re
|
3
3
|
import os
|
4
|
-
|
4
|
+
|
5
|
+
# from openai import AsyncOpenAI
|
5
6
|
import openai
|
6
7
|
|
7
8
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
8
9
|
from edsl.language_models import LanguageModel
|
9
10
|
from edsl.inference_services.rate_limits_cache import rate_limits
|
11
|
+
from edsl.utilities.utilities import fix_partial_correct_response
|
10
12
|
|
11
13
|
|
12
14
|
class OpenAIService(InferenceServiceABC):
|
@@ -18,18 +20,18 @@ class OpenAIService(InferenceServiceABC):
|
|
18
20
|
|
19
21
|
_sync_client_ = openai.OpenAI
|
20
22
|
_async_client_ = openai.AsyncOpenAI
|
21
|
-
|
23
|
+
|
22
24
|
@classmethod
|
23
25
|
def sync_client(cls):
|
24
26
|
return cls._sync_client_(
|
25
|
-
api_key
|
26
|
-
|
27
|
-
|
27
|
+
api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
|
28
|
+
)
|
29
|
+
|
28
30
|
@classmethod
|
29
31
|
def async_client(cls):
|
30
32
|
return cls._async_client_(
|
31
|
-
api_key
|
32
|
-
|
33
|
+
api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
|
34
|
+
)
|
33
35
|
|
34
36
|
# TODO: Make this a coop call
|
35
37
|
model_exclude_list = [
|
@@ -59,14 +61,14 @@ class OpenAIService(InferenceServiceABC):
|
|
59
61
|
|
60
62
|
@classmethod
|
61
63
|
def available(cls) -> List[str]:
|
62
|
-
#from openai import OpenAI
|
64
|
+
# from openai import OpenAI
|
63
65
|
|
64
66
|
if not cls._models_list_cache:
|
65
67
|
try:
|
66
|
-
#client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
68
|
+
# client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
67
69
|
cls._models_list_cache = [
|
68
70
|
m.id
|
69
|
-
for m in cls.get_model_list()
|
71
|
+
for m in cls.get_model_list()
|
70
72
|
if m.id not in cls.model_exclude_list
|
71
73
|
]
|
72
74
|
except Exception as e:
|
@@ -106,21 +108,21 @@ class OpenAIService(InferenceServiceABC):
|
|
106
108
|
|
107
109
|
def sync_client(self):
|
108
110
|
return cls.sync_client()
|
109
|
-
|
111
|
+
|
110
112
|
def async_client(self):
|
111
113
|
return cls.async_client()
|
112
114
|
|
113
115
|
@classmethod
|
114
116
|
def available(cls) -> list[str]:
|
115
|
-
#import openai
|
116
|
-
#client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
117
|
-
#return client.models.list()
|
117
|
+
# import openai
|
118
|
+
# client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
119
|
+
# return client.models.list()
|
118
120
|
return cls.sync_client().models.list()
|
119
|
-
|
121
|
+
|
120
122
|
def get_headers(self) -> dict[str, Any]:
|
121
|
-
#from openai import OpenAI
|
123
|
+
# from openai import OpenAI
|
122
124
|
|
123
|
-
#client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
125
|
+
# client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
124
126
|
client = self.sync_client()
|
125
127
|
response = client.chat.completions.with_raw_response.create(
|
126
128
|
messages=[
|
@@ -172,7 +174,7 @@ class OpenAIService(InferenceServiceABC):
|
|
172
174
|
else:
|
173
175
|
content = user_prompt
|
174
176
|
# self.client = AsyncOpenAI(
|
175
|
-
# api_key = os.getenv(cls._env_key_name_),
|
177
|
+
# api_key = os.getenv(cls._env_key_name_),
|
176
178
|
# base_url = cls._base_url_
|
177
179
|
# )
|
178
180
|
client = self.async_client()
|
@@ -206,6 +208,9 @@ class OpenAIService(InferenceServiceABC):
|
|
206
208
|
if match:
|
207
209
|
return match.group(1)
|
208
210
|
else:
|
211
|
+
out = fix_partial_correct_response(response)
|
212
|
+
if "error" not in out:
|
213
|
+
response = out["extracted_json"]
|
209
214
|
return response
|
210
215
|
|
211
216
|
LLM.__name__ = "LanguageModel"
|
@@ -66,4 +66,35 @@ models_available = {
|
|
66
66
|
"openchat/openchat_3.5",
|
67
67
|
],
|
68
68
|
"google": ["gemini-pro"],
|
69
|
+
"bedrock": [
|
70
|
+
"amazon.titan-tg1-large",
|
71
|
+
"amazon.titan-text-lite-v1",
|
72
|
+
"amazon.titan-text-express-v1",
|
73
|
+
"ai21.j2-grande-instruct",
|
74
|
+
"ai21.j2-jumbo-instruct",
|
75
|
+
"ai21.j2-mid",
|
76
|
+
"ai21.j2-mid-v1",
|
77
|
+
"ai21.j2-ultra",
|
78
|
+
"ai21.j2-ultra-v1",
|
79
|
+
"anthropic.claude-instant-v1",
|
80
|
+
"anthropic.claude-v2:1",
|
81
|
+
"anthropic.claude-v2",
|
82
|
+
"anthropic.claude-3-sonnet-20240229-v1:0",
|
83
|
+
"anthropic.claude-3-haiku-20240307-v1:0",
|
84
|
+
"anthropic.claude-3-opus-20240229-v1:0",
|
85
|
+
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
86
|
+
"cohere.command-text-v14",
|
87
|
+
"cohere.command-r-v1:0",
|
88
|
+
"cohere.command-r-plus-v1:0",
|
89
|
+
"cohere.command-light-text-v14",
|
90
|
+
"meta.llama3-8b-instruct-v1:0",
|
91
|
+
"meta.llama3-70b-instruct-v1:0",
|
92
|
+
"meta.llama3-1-8b-instruct-v1:0",
|
93
|
+
"meta.llama3-1-70b-instruct-v1:0",
|
94
|
+
"meta.llama3-1-405b-instruct-v1:0",
|
95
|
+
"mistral.mistral-7b-instruct-v0:2",
|
96
|
+
"mistral.mixtral-8x7b-instruct-v0:1",
|
97
|
+
"mistral.mistral-large-2402-v1:0",
|
98
|
+
"mistral.mistral-large-2407-v1:0",
|
99
|
+
],
|
69
100
|
}
|
@@ -7,7 +7,19 @@ from edsl.inference_services.AnthropicService import AnthropicService
|
|
7
7
|
from edsl.inference_services.DeepInfraService import DeepInfraService
|
8
8
|
from edsl.inference_services.GoogleService import GoogleService
|
9
9
|
from edsl.inference_services.GroqService import GroqService
|
10
|
+
from edsl.inference_services.AwsBedrock import AwsBedrockService
|
11
|
+
from edsl.inference_services.AzureAI import AzureAIService
|
12
|
+
from edsl.inference_services.OllamaService import OllamaService
|
10
13
|
|
11
14
|
default = InferenceServicesCollection(
|
12
|
-
[
|
15
|
+
[
|
16
|
+
OpenAIService,
|
17
|
+
AnthropicService,
|
18
|
+
DeepInfraService,
|
19
|
+
GoogleService,
|
20
|
+
GroqService,
|
21
|
+
AwsBedrockService,
|
22
|
+
AzureAIService,
|
23
|
+
OllamaService,
|
24
|
+
]
|
13
25
|
)
|
edsl/jobs/Jobs.py
CHANGED
@@ -39,6 +39,8 @@ class Jobs(Base):
|
|
39
39
|
|
40
40
|
self.__bucket_collection = None
|
41
41
|
|
42
|
+
# these setters and getters are used to ensure that the agents, models, and scenarios are stored as AgentList, ModelList, and ScenarioList objects
|
43
|
+
|
42
44
|
@property
|
43
45
|
def models(self):
|
44
46
|
return self._models
|
@@ -119,7 +121,9 @@ class Jobs(Base):
|
|
119
121
|
- scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
|
120
122
|
- models: new models overwrite old models.
|
121
123
|
"""
|
122
|
-
passed_objects = self._turn_args_to_list(
|
124
|
+
passed_objects = self._turn_args_to_list(
|
125
|
+
args
|
126
|
+
) # objects can also be passed comma-separated
|
123
127
|
|
124
128
|
current_objects, objects_key = self._get_current_objects_of_this_type(
|
125
129
|
passed_objects[0]
|
@@ -176,17 +180,27 @@ class Jobs(Base):
|
|
176
180
|
from edsl.agents.Agent import Agent
|
177
181
|
from edsl.scenarios.Scenario import Scenario
|
178
182
|
from edsl.scenarios.ScenarioList import ScenarioList
|
183
|
+
from edsl.language_models.ModelList import ModelList
|
179
184
|
|
180
185
|
if isinstance(object, Agent):
|
181
186
|
return AgentList
|
182
187
|
elif isinstance(object, Scenario):
|
183
188
|
return ScenarioList
|
189
|
+
elif isinstance(object, ModelList):
|
190
|
+
return ModelList
|
184
191
|
else:
|
185
192
|
return list
|
186
193
|
|
187
194
|
@staticmethod
|
188
195
|
def _turn_args_to_list(args):
|
189
|
-
"""Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.
|
196
|
+
"""Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.
|
197
|
+
|
198
|
+
Example:
|
199
|
+
|
200
|
+
>>> Jobs._turn_args_to_list([1,2,3])
|
201
|
+
[1, 2, 3]
|
202
|
+
|
203
|
+
"""
|
190
204
|
|
191
205
|
def did_user_pass_a_sequence(args):
|
192
206
|
"""Return True if the user passed a sequence, False otherwise.
|
@@ -209,7 +223,7 @@ class Jobs(Base):
|
|
209
223
|
return container_class(args)
|
210
224
|
|
211
225
|
def _get_current_objects_of_this_type(
|
212
|
-
self, object: Union[Agent, Scenario, LanguageModel]
|
226
|
+
self, object: Union["Agent", "Scenario", "LanguageModel"]
|
213
227
|
) -> tuple[list, str]:
|
214
228
|
from edsl.agents.Agent import Agent
|
215
229
|
from edsl.scenarios.Scenario import Scenario
|
@@ -292,7 +306,11 @@ class Jobs(Base):
|
|
292
306
|
|
293
307
|
@classmethod
|
294
308
|
def from_interviews(cls, interview_list):
|
295
|
-
"""Return a Jobs instance from a list of interviews.
|
309
|
+
"""Return a Jobs instance from a list of interviews.
|
310
|
+
|
311
|
+
This is useful when you have, say, a list of failed interviews and you want to create
|
312
|
+
a new job with only those interviews.
|
313
|
+
"""
|
296
314
|
survey = interview_list[0].survey
|
297
315
|
# get all the models
|
298
316
|
models = list(set([interview.model for interview in interview_list]))
|
@@ -308,6 +326,8 @@ class Jobs(Base):
|
|
308
326
|
Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
|
309
327
|
This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
|
310
328
|
with us filling in defaults.
|
329
|
+
|
330
|
+
|
311
331
|
"""
|
312
332
|
# if no agents, models, or scenarios are set, set them to defaults
|
313
333
|
from edsl.agents.Agent import Agent
|
@@ -319,7 +339,11 @@ class Jobs(Base):
|
|
319
339
|
self.scenarios = self.scenarios or [Scenario()]
|
320
340
|
for agent, scenario, model in product(self.agents, self.scenarios, self.models):
|
321
341
|
yield Interview(
|
322
|
-
survey=self.survey,
|
342
|
+
survey=self.survey,
|
343
|
+
agent=agent,
|
344
|
+
scenario=scenario,
|
345
|
+
model=model,
|
346
|
+
skip_retry=self.skip_retry,
|
323
347
|
)
|
324
348
|
|
325
349
|
def create_bucket_collection(self) -> BucketCollection:
|
@@ -359,10 +383,16 @@ class Jobs(Base):
|
|
359
383
|
return links
|
360
384
|
|
361
385
|
def __hash__(self):
|
362
|
-
"""Allow the model to be used as a key in a dictionary.
|
386
|
+
"""Allow the model to be used as a key in a dictionary.
|
387
|
+
|
388
|
+
>>> from edsl.jobs import Jobs
|
389
|
+
>>> hash(Jobs.example())
|
390
|
+
846655441787442972
|
391
|
+
|
392
|
+
"""
|
363
393
|
from edsl.utilities.utilities import dict_hash
|
364
394
|
|
365
|
-
return dict_hash(self.
|
395
|
+
return dict_hash(self._to_dict())
|
366
396
|
|
367
397
|
def _output(self, message) -> None:
|
368
398
|
"""Check if a Job is verbose. If so, print the message."""
|
@@ -390,11 +420,27 @@ class Jobs(Base):
|
|
390
420
|
Traceback (most recent call last):
|
391
421
|
...
|
392
422
|
ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
|
423
|
+
|
424
|
+
>>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
|
425
|
+
>>> s = Scenario({'ugly_question': "B"})
|
426
|
+
>>> j = Jobs(survey = Survey(questions=[q])).by(s)
|
427
|
+
>>> j._check_parameters()
|
428
|
+
Traceback (most recent call last):
|
429
|
+
...
|
430
|
+
ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
|
393
431
|
"""
|
394
432
|
survey_parameters: set = self.survey.parameters
|
395
433
|
scenario_parameters: set = self.scenarios.parameters
|
396
434
|
|
397
|
-
msg1, msg2 = None, None
|
435
|
+
msg0, msg1, msg2 = None, None, None
|
436
|
+
|
437
|
+
# look for key issues
|
438
|
+
if intersection := set(self.scenarios.parameters) & set(
|
439
|
+
self.survey.question_names
|
440
|
+
):
|
441
|
+
msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
|
442
|
+
|
443
|
+
raise ValueError(msg0)
|
398
444
|
|
399
445
|
if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
|
400
446
|
msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
|
@@ -409,6 +455,12 @@ class Jobs(Base):
|
|
409
455
|
if warn:
|
410
456
|
warnings.warn(message)
|
411
457
|
|
458
|
+
@property
|
459
|
+
def skip_retry(self):
|
460
|
+
if not hasattr(self, "_skip_retry"):
|
461
|
+
return False
|
462
|
+
return self._skip_retry
|
463
|
+
|
412
464
|
def run(
|
413
465
|
self,
|
414
466
|
n: int = 1,
|
@@ -423,6 +475,7 @@ class Jobs(Base):
|
|
423
475
|
print_exceptions=True,
|
424
476
|
remote_cache_description: Optional[str] = None,
|
425
477
|
remote_inference_description: Optional[str] = None,
|
478
|
+
skip_retry: bool = False,
|
426
479
|
) -> Results:
|
427
480
|
"""
|
428
481
|
Runs the Job: conducts Interviews and returns their results.
|
@@ -441,6 +494,7 @@ class Jobs(Base):
|
|
441
494
|
from edsl.coop.coop import Coop
|
442
495
|
|
443
496
|
self._check_parameters()
|
497
|
+
self._skip_retry = skip_retry
|
444
498
|
|
445
499
|
if batch_mode is not None:
|
446
500
|
raise NotImplementedError(
|
@@ -631,12 +685,16 @@ class Jobs(Base):
|
|
631
685
|
return results
|
632
686
|
|
633
687
|
async def run_async(self, cache=None, n=1, **kwargs):
|
634
|
-
"""Run
|
688
|
+
"""Run asynchronously."""
|
635
689
|
results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
|
636
690
|
return results
|
637
691
|
|
638
692
|
def all_question_parameters(self):
|
639
|
-
"""Return all the fields in the questions in the survey.
|
693
|
+
"""Return all the fields in the questions in the survey.
|
694
|
+
>>> from edsl.jobs import Jobs
|
695
|
+
>>> Jobs.example().all_question_parameters()
|
696
|
+
{'period'}
|
697
|
+
"""
|
640
698
|
return set.union(*[question.parameters for question in self.survey.questions])
|
641
699
|
|
642
700
|
#######################
|
@@ -677,15 +735,19 @@ class Jobs(Base):
|
|
677
735
|
#######################
|
678
736
|
# Serialization methods
|
679
737
|
#######################
|
738
|
+
|
739
|
+
def _to_dict(self):
|
740
|
+
return {
|
741
|
+
"survey": self.survey._to_dict(),
|
742
|
+
"agents": [agent._to_dict() for agent in self.agents],
|
743
|
+
"models": [model._to_dict() for model in self.models],
|
744
|
+
"scenarios": [scenario._to_dict() for scenario in self.scenarios],
|
745
|
+
}
|
746
|
+
|
680
747
|
@add_edsl_version
|
681
748
|
def to_dict(self) -> dict:
|
682
749
|
"""Convert the Jobs instance to a dictionary."""
|
683
|
-
return
|
684
|
-
"survey": self.survey.to_dict(),
|
685
|
-
"agents": [agent.to_dict() for agent in self.agents],
|
686
|
-
"models": [model.to_dict() for model in self.models],
|
687
|
-
"scenarios": [scenario.to_dict() for scenario in self.scenarios],
|
688
|
-
}
|
750
|
+
return self._to_dict()
|
689
751
|
|
690
752
|
@classmethod
|
691
753
|
@remove_edsl_version
|
@@ -704,7 +766,13 @@ class Jobs(Base):
|
|
704
766
|
)
|
705
767
|
|
706
768
|
def __eq__(self, other: Jobs) -> bool:
|
707
|
-
"""Return True if the Jobs instance is equal to another Jobs instance.
|
769
|
+
"""Return True if the Jobs instance is equal to another Jobs instance.
|
770
|
+
|
771
|
+
>>> from edsl.jobs import Jobs
|
772
|
+
>>> Jobs.example() == Jobs.example()
|
773
|
+
True
|
774
|
+
|
775
|
+
"""
|
708
776
|
return self.to_dict() == other.to_dict()
|
709
777
|
|
710
778
|
#######################
|
@@ -712,11 +780,16 @@ class Jobs(Base):
|
|
712
780
|
#######################
|
713
781
|
@classmethod
|
714
782
|
def example(
|
715
|
-
cls,
|
783
|
+
cls,
|
784
|
+
throw_exception_probability: float = 0.0,
|
785
|
+
randomize: bool = False,
|
786
|
+
test_model=False,
|
716
787
|
) -> Jobs:
|
717
788
|
"""Return an example Jobs instance.
|
718
789
|
|
719
790
|
:param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
|
791
|
+
:param randomize: whether to randomize the job by adding a random string to the period
|
792
|
+
:param test_model: whether to use a test model
|
720
793
|
|
721
794
|
>>> Jobs.example()
|
722
795
|
Jobs(...)
|
@@ -730,6 +803,11 @@ class Jobs(Base):
|
|
730
803
|
|
731
804
|
addition = "" if not randomize else str(uuid4())
|
732
805
|
|
806
|
+
if test_model:
|
807
|
+
from edsl.language_models import LanguageModel
|
808
|
+
|
809
|
+
m = LanguageModel.example(test_model=True)
|
810
|
+
|
733
811
|
# (status, question, period)
|
734
812
|
agent_answers = {
|
735
813
|
("Joyful", "how_feeling", "morning"): "OK",
|
@@ -777,7 +855,10 @@ class Jobs(Base):
|
|
777
855
|
Scenario({"period": "afternoon"}),
|
778
856
|
]
|
779
857
|
)
|
780
|
-
|
858
|
+
if test_model:
|
859
|
+
job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
|
860
|
+
else:
|
861
|
+
job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
|
781
862
|
|
782
863
|
return job
|
783
864
|
|
edsl/jobs/buckets/TokenBucket.py
CHANGED
@@ -100,7 +100,9 @@ class TokenBucket:
|
|
100
100
|
available_tokens = min(self.capacity, self.tokens + refill_amount)
|
101
101
|
return max(0, requested_tokens - available_tokens) / self.refill_rate
|
102
102
|
|
103
|
-
async def get_tokens(
|
103
|
+
async def get_tokens(
|
104
|
+
self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
|
105
|
+
) -> None:
|
104
106
|
"""Wait for the specified number of tokens to become available.
|
105
107
|
|
106
108
|
|
@@ -116,14 +118,20 @@ class TokenBucket:
|
|
116
118
|
True
|
117
119
|
|
118
120
|
>>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
|
119
|
-
>>> asyncio.run(bucket.get_tokens(11))
|
121
|
+
>>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
|
120
122
|
Traceback (most recent call last):
|
121
123
|
...
|
122
124
|
ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
|
125
|
+
>>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
|
123
126
|
"""
|
124
127
|
if amount > self.capacity:
|
125
|
-
|
126
|
-
|
128
|
+
if not cheat_bucket_capacity:
|
129
|
+
msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
|
130
|
+
raise ValueError(msg)
|
131
|
+
else:
|
132
|
+
self.tokens = 0 # clear the bucket but let it go through
|
133
|
+
return
|
134
|
+
|
127
135
|
while self.tokens < amount:
|
128
136
|
self.refill()
|
129
137
|
await asyncio.sleep(0.01) # Sleep briefly to prevent busy waiting
|
@@ -14,8 +14,8 @@ from edsl.jobs.tasks.TaskCreators import TaskCreators
|
|
14
14
|
from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
|
15
15
|
from edsl.jobs.interviews.interview_exception_tracking import (
|
16
16
|
InterviewExceptionCollection,
|
17
|
-
InterviewExceptionEntry,
|
18
17
|
)
|
18
|
+
from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
|
19
19
|
from edsl.jobs.interviews.retry_management import retry_strategy
|
20
20
|
from edsl.jobs.interviews.InterviewTaskBuildingMixin import InterviewTaskBuildingMixin
|
21
21
|
from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
|
@@ -44,6 +44,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
44
44
|
iteration: int = 0,
|
45
45
|
cache: Optional["Cache"] = None,
|
46
46
|
sidecar_model: Optional["LanguageModel"] = None,
|
47
|
+
skip_retry=False,
|
47
48
|
):
|
48
49
|
"""Initialize the Interview instance.
|
49
50
|
|
@@ -87,6 +88,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
87
88
|
self.task_creators = TaskCreators() # tracks the task creators
|
88
89
|
self.exceptions = InterviewExceptionCollection()
|
89
90
|
self._task_status_log_dict = InterviewStatusLog()
|
91
|
+
self.skip_retry = skip_retry
|
90
92
|
|
91
93
|
# dictionary mapping question names to their index in the survey.
|
92
94
|
self.to_index = {
|
@@ -94,6 +96,30 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
94
96
|
for index, question_name in enumerate(self.survey.question_names)
|
95
97
|
}
|
96
98
|
|
99
|
+
def _to_dict(self, include_exceptions=False) -> dict[str, Any]:
|
100
|
+
"""Return a dictionary representation of the Interview instance.
|
101
|
+
This is just for hashing purposes.
|
102
|
+
|
103
|
+
>>> i = Interview.example()
|
104
|
+
>>> hash(i)
|
105
|
+
1646262796627658719
|
106
|
+
"""
|
107
|
+
d = {
|
108
|
+
"agent": self.agent._to_dict(),
|
109
|
+
"survey": self.survey._to_dict(),
|
110
|
+
"scenario": self.scenario._to_dict(),
|
111
|
+
"model": self.model._to_dict(),
|
112
|
+
"iteration": self.iteration,
|
113
|
+
"exceptions": {},
|
114
|
+
}
|
115
|
+
if include_exceptions:
|
116
|
+
d["exceptions"] = self.exceptions.to_dict()
|
117
|
+
|
118
|
+
def __hash__(self) -> int:
|
119
|
+
from edsl.utilities.utilities import dict_hash
|
120
|
+
|
121
|
+
return dict_hash(self._to_dict())
|
122
|
+
|
97
123
|
async def async_conduct_interview(
|
98
124
|
self,
|
99
125
|
*,
|
@@ -134,8 +160,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
134
160
|
<BLANKLINE>
|
135
161
|
|
136
162
|
>>> i.exceptions
|
137
|
-
{'q0':
|
138
|
-
|
163
|
+
{'q0': ...
|
139
164
|
>>> i = Interview.example()
|
140
165
|
>>> result, _ = asyncio.run(i.async_conduct_interview(stop_on_exception = True))
|
141
166
|
Traceback (most recent call last):
|
@@ -204,13 +229,9 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
204
229
|
{}
|
205
230
|
>>> i._record_exception(i.tasks[0], Exception("An exception occurred."))
|
206
231
|
>>> i.exceptions
|
207
|
-
{'q0':
|
232
|
+
{'q0': ...
|
208
233
|
"""
|
209
|
-
exception_entry = InterviewExceptionEntry(
|
210
|
-
exception=repr(exception),
|
211
|
-
time=time.time(),
|
212
|
-
traceback=traceback.format_exc(),
|
213
|
-
)
|
234
|
+
exception_entry = InterviewExceptionEntry(exception)
|
214
235
|
self.exceptions.add(task.get_name(), exception_entry)
|
215
236
|
|
216
237
|
@property
|
@@ -251,6 +272,7 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):
|
|
251
272
|
model=self.model,
|
252
273
|
iteration=iteration,
|
253
274
|
cache=cache,
|
275
|
+
skip_retry=self.skip_retry,
|
254
276
|
)
|
255
277
|
|
256
278
|
@classmethod
|
@@ -0,0 +1,101 @@
|
|
1
|
+
import traceback
|
2
|
+
import datetime
|
3
|
+
import time
|
4
|
+
from collections import UserDict
|
5
|
+
|
6
|
+
# traceback=traceback.format_exc(),
|
7
|
+
# traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
|
8
|
+
# traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
|
9
|
+
|
10
|
+
|
11
|
+
class InterviewExceptionEntry:
|
12
|
+
"""Class to record an exception that occurred during the interview.
|
13
|
+
|
14
|
+
>>> entry = InterviewExceptionEntry.example()
|
15
|
+
>>> entry.to_dict()['exception']
|
16
|
+
"ValueError('An error occurred.')"
|
17
|
+
"""
|
18
|
+
|
19
|
+
def __init__(self, exception: Exception, traceback_format="html"):
|
20
|
+
self.time = datetime.datetime.now().isoformat()
|
21
|
+
self.exception = exception
|
22
|
+
self.traceback_format = traceback_format
|
23
|
+
|
24
|
+
def __getitem__(self, key):
|
25
|
+
# Support dict-like access obj['a']
|
26
|
+
return str(getattr(self, key))
|
27
|
+
|
28
|
+
@classmethod
|
29
|
+
def example(cls):
|
30
|
+
try:
|
31
|
+
raise ValueError("An error occurred.")
|
32
|
+
except Exception as e:
|
33
|
+
entry = InterviewExceptionEntry(e)
|
34
|
+
return entry
|
35
|
+
|
36
|
+
@property
|
37
|
+
def traceback(self):
|
38
|
+
"""Return the exception as HTML."""
|
39
|
+
if self.traceback_format == "html":
|
40
|
+
return self.html_traceback
|
41
|
+
else:
|
42
|
+
return self.text_traceback
|
43
|
+
|
44
|
+
@property
|
45
|
+
def text_traceback(self):
|
46
|
+
"""
|
47
|
+
>>> entry = InterviewExceptionEntry.example()
|
48
|
+
>>> entry.text_traceback
|
49
|
+
'Traceback (most recent call last):...'
|
50
|
+
"""
|
51
|
+
e = self.exception
|
52
|
+
tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
53
|
+
return tb_str
|
54
|
+
|
55
|
+
@property
|
56
|
+
def html_traceback(self):
|
57
|
+
from rich.console import Console
|
58
|
+
from rich.table import Table
|
59
|
+
from rich.traceback import Traceback
|
60
|
+
|
61
|
+
from io import StringIO
|
62
|
+
|
63
|
+
html_output = StringIO()
|
64
|
+
|
65
|
+
console = Console(file=html_output, record=True)
|
66
|
+
|
67
|
+
tb = Traceback.from_exception(
|
68
|
+
type(self.exception),
|
69
|
+
self.exception,
|
70
|
+
self.exception.__traceback__,
|
71
|
+
show_locals=True,
|
72
|
+
)
|
73
|
+
console.print(tb)
|
74
|
+
return html_output.getvalue()
|
75
|
+
|
76
|
+
def to_dict(self) -> dict:
|
77
|
+
"""Return the exception as a dictionary.
|
78
|
+
|
79
|
+
>>> entry = InterviewExceptionEntry.example()
|
80
|
+
>>> entry.to_dict()['exception']
|
81
|
+
"ValueError('An error occurred.')"
|
82
|
+
|
83
|
+
"""
|
84
|
+
return {
|
85
|
+
"exception": repr(self.exception),
|
86
|
+
"time": self.time,
|
87
|
+
"traceback": self.traceback,
|
88
|
+
}
|
89
|
+
|
90
|
+
def push(self):
|
91
|
+
from edsl import Coop
|
92
|
+
|
93
|
+
coop = Coop()
|
94
|
+
results = coop.error_create(self.to_dict())
|
95
|
+
return results
|
96
|
+
|
97
|
+
|
98
|
+
if __name__ == "__main__":
|
99
|
+
import doctest
|
100
|
+
|
101
|
+
doctest.testmod(optionflags=doctest.ELLIPSIS)
|