edsl 0.1.31.dev3__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +7 -2
  3. edsl/agents/PromptConstructionMixin.py +35 -15
  4. edsl/config.py +15 -1
  5. edsl/conjure/Conjure.py +6 -0
  6. edsl/coop/coop.py +4 -0
  7. edsl/data/CacheHandler.py +3 -4
  8. edsl/enums.py +5 -0
  9. edsl/exceptions/general.py +10 -8
  10. edsl/inference_services/AwsBedrock.py +110 -0
  11. edsl/inference_services/AzureAI.py +197 -0
  12. edsl/inference_services/DeepInfraService.py +6 -91
  13. edsl/inference_services/GroqService.py +18 -0
  14. edsl/inference_services/InferenceServicesCollection.py +13 -8
  15. edsl/inference_services/OllamaService.py +18 -0
  16. edsl/inference_services/OpenAIService.py +68 -21
  17. edsl/inference_services/models_available_cache.py +31 -0
  18. edsl/inference_services/registry.py +14 -1
  19. edsl/jobs/Jobs.py +103 -21
  20. edsl/jobs/buckets/TokenBucket.py +12 -4
  21. edsl/jobs/interviews/Interview.py +31 -9
  22. edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
  23. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +49 -33
  24. edsl/jobs/interviews/interview_exception_tracking.py +68 -10
  25. edsl/jobs/runners/JobsRunnerAsyncio.py +112 -81
  26. edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
  27. edsl/jobs/runners/JobsRunnerStatusMixin.py +291 -35
  28. edsl/jobs/tasks/TaskCreators.py +8 -2
  29. edsl/jobs/tasks/TaskHistory.py +145 -1
  30. edsl/language_models/LanguageModel.py +62 -41
  31. edsl/language_models/registry.py +4 -0
  32. edsl/questions/QuestionBudget.py +0 -1
  33. edsl/questions/QuestionCheckBox.py +0 -1
  34. edsl/questions/QuestionExtract.py +0 -1
  35. edsl/questions/QuestionFreeText.py +2 -9
  36. edsl/questions/QuestionList.py +0 -1
  37. edsl/questions/QuestionMultipleChoice.py +1 -2
  38. edsl/questions/QuestionNumerical.py +0 -1
  39. edsl/questions/QuestionRank.py +0 -1
  40. edsl/results/DatasetExportMixin.py +33 -3
  41. edsl/scenarios/Scenario.py +14 -0
  42. edsl/scenarios/ScenarioList.py +216 -13
  43. edsl/scenarios/ScenarioListExportMixin.py +15 -4
  44. edsl/scenarios/ScenarioListPdfMixin.py +3 -0
  45. edsl/surveys/Rule.py +5 -2
  46. edsl/surveys/Survey.py +84 -1
  47. edsl/surveys/SurveyQualtricsImport.py +213 -0
  48. edsl/utilities/utilities.py +31 -0
  49. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/METADATA +5 -1
  50. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/RECORD +52 -46
  51. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/LICENSE +0 -0
  52. {edsl-0.1.31.dev3.dist-info → edsl-0.1.32.dist-info}/WHEEL +0 -0
edsl/inference_services/GroqService.py ADDED
@@ -0,0 +1,18 @@
+ from typing import Any, List
+ from edsl.inference_services.OpenAIService import OpenAIService
+
+ import groq
+
+
+ class GroqService(OpenAIService):
+     """Groq service class."""
+
+     _inference_service_ = "groq"
+     _env_key_name_ = "GROQ_API_KEY"
+
+     _sync_client_ = groq.Groq
+     _async_client_ = groq.AsyncGroq
+
+     # _base_url_ = "https://api.deepinfra.com/v1/openai"
+     _base_url_ = None
+     _models_list_cache: List[str] = []
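GroqService can stay this thin because OpenAIService (diffed below) now owns all of the client plumbing. A minimal sketch of the same pattern for any other OpenAI-compatible provider; the service name, env var, and URL here are invented for illustration:

```python
# Sketch of the subclassing pattern GroqService introduces: an
# OpenAI-compatible provider only needs to override the client classes,
# env-var name, and base URL. "MyCompatService" is hypothetical.
from typing import List

import openai

from edsl.inference_services.OpenAIService import OpenAIService


class MyCompatService(OpenAIService):
    """Hypothetical OpenAI-compatible service."""

    _inference_service_ = "mycompat"
    _env_key_name_ = "MYCOMPAT_API_KEY"

    _sync_client_ = openai.OpenAI        # stock OpenAI clients ...
    _async_client_ = openai.AsyncOpenAI

    _base_url_ = "https://api.mycompat.example/v1"  # ... aimed at another host
    _models_list_cache: List[str] = []
```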
edsl/inference_services/InferenceServicesCollection.py CHANGED
@@ -15,18 +15,19 @@ class InferenceServicesCollection:
          cls.added_models[service_name].append(model_name)
  
      @staticmethod
-     def _get_service_available(service) -> list[str]:
+     def _get_service_available(service, warn: bool = False) -> list[str]:
          from_api = True
          try:
              service_models = service.available()
          except Exception as e:
-             warnings.warn(
-                 f"""Error getting models for {service._inference_service_}.
-                 Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
-                 See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
-                 Relying on cache.""",
-                 UserWarning,
-             )
+             if warn:
+                 warnings.warn(
+                     f"""Error getting models for {service._inference_service_}.
+                     Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
+                     See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
+                     Relying on cache.""",
+                     UserWarning,
+                 )
              from edsl.inference_services.models_available_cache import models_available
  
              service_models = models_available.get(service._inference_service_, [])
@@ -60,4 +61,8 @@ class InferenceServicesCollection:
              if service_name is None or service_name == service._inference_service_:
                  return service.create_model(model_name)
  
+         # if model_name == "test":
+         #     from edsl.language_models import LanguageModel
+         #     return LanguageModel(test = True)
+
          raise Exception(f"Model {model_name} not found in any of the services")
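With `warn` defaulting to `False`, a failed `service.available()` call now falls back to the bundled cache silently; passing `warn=True` restores the old `UserWarning`. A sketch of exercising the fallback directly (illustrative only):

```python
from edsl.inference_services.InferenceServicesCollection import (
    InferenceServicesCollection,
)
from edsl.inference_services.OpenAIService import OpenAIService

# If service.available() raises (e.g. no OPENAI_API_KEY is set), the method
# now returns the cached model list quietly instead of warning.
models = InferenceServicesCollection._get_service_available(OpenAIService)
print(models[:3])
```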
edsl/inference_services/OllamaService.py ADDED
@@ -0,0 +1,18 @@
+ import aiohttp
+ import json
+ import requests
+ from typing import Any, List
+
+ # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+ from edsl.language_models import LanguageModel
+
+ from edsl.inference_services.OpenAIService import OpenAIService
+
+
+ class OllamaService(OpenAIService):
+     """Ollama service class."""
+
+     _inference_service_ = "ollama"
+     _env_key_name_ = "DEEP_INFRA_API_KEY"
+     _base_url_ = "http://localhost:11434/v1"
+     _models_list_cache: List[str] = []
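What this configuration amounts to, sketched with the `openai` client directly. It assumes a local Ollama server on its default port with a `llama3` model pulled; the key value is arbitrary since Ollama's OpenAI-compatible endpoint ignores it:

```python
import openai

client = openai.OpenAI(
    api_key="unused",                      # Ollama ignores the key
    base_url="http://localhost:11434/v1",  # same _base_url_ as the service
)
response = client.chat.completions.create(
    model="llama3",  # assumes this model has been pulled locally
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```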
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,10 +1,14 @@
  from typing import Any, List
  import re
- from openai import AsyncOpenAI
+ import os
+
+ # from openai import AsyncOpenAI
+ import openai
  
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models import LanguageModel
  from edsl.inference_services.rate_limits_cache import rate_limits
+ from edsl.utilities.utilities import fix_partial_correct_response
  
  
  class OpenAIService(InferenceServiceABC):
@@ -12,6 +16,22 @@ class OpenAIService(InferenceServiceABC):
  
      _inference_service_ = "openai"
      _env_key_name_ = "OPENAI_API_KEY"
+     _base_url_ = None
+
+     _sync_client_ = openai.OpenAI
+     _async_client_ = openai.AsyncOpenAI
+
+     @classmethod
+     def sync_client(cls):
+         return cls._sync_client_(
+             api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+         )
+
+     @classmethod
+     def async_client(cls):
+         return cls._async_client_(
+             api_key=os.getenv(cls._env_key_name_), base_url=cls._base_url_
+         )
  
      # TODO: Make this a coop call
      model_exclude_list = [
@@ -31,16 +51,24 @@ class OpenAIService(InferenceServiceABC):
      ]
      _models_list_cache: List[str] = []
  
+     @classmethod
+     def get_model_list(cls):
+         raw_list = cls.sync_client().models.list()
+         if hasattr(raw_list, "data"):
+             return raw_list.data
+         else:
+             return raw_list
+
      @classmethod
      def available(cls) -> List[str]:
-         from openai import OpenAI
+         # from openai import OpenAI
  
          if not cls._models_list_cache:
              try:
-                 client = OpenAI()
+                 # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                  cls._models_list_cache = [
                      m.id
-                     for m in client.models.list()
+                     for m in cls.get_model_list()
                      if m.id not in cls.model_exclude_list
                  ]
              except Exception as e:
@@ -78,15 +106,24 @@ class OpenAIService(InferenceServiceABC):
                  "top_logprobs": 3,
              }
  
+             def sync_client(self):
+                 return cls.sync_client()
+
+             def async_client(self):
+                 return cls.async_client()
+
              @classmethod
              def available(cls) -> list[str]:
-                 client = openai.OpenAI()
-                 return client.models.list()
+                 # import openai
+                 # client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                 # return client.models.list()
+                 return cls.sync_client().models.list()
  
              def get_headers(self) -> dict[str, Any]:
-                 from openai import OpenAI
+                 # from openai import OpenAI
  
-                 client = OpenAI()
+                 # client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                 client = self.sync_client()
                  response = client.chat.completions.with_raw_response.create(
                      messages=[
                          {
@@ -124,8 +161,8 @@ class OpenAIService(InferenceServiceABC):
                  encoded_image=None,
              ) -> dict[str, Any]:
                  """Calls the OpenAI API and returns the API response."""
-                 content = [{"type": "text", "text": user_prompt}]
                  if encoded_image:
+                     content = [{"type": "text", "text": user_prompt}]
                      content.append(
                          {
                              "type": "image_url",
@@ -134,21 +171,28 @@ class OpenAIService(InferenceServiceABC):
                          },
                      }
                  )
-                 self.client = AsyncOpenAI()
-                 response = await self.client.chat.completions.create(
-                     model=self.model,
-                     messages=[
+                 else:
+                     content = user_prompt
+                 # self.client = AsyncOpenAI(
+                 #     api_key = os.getenv(cls._env_key_name_),
+                 #     base_url = cls._base_url_
+                 # )
+                 client = self.async_client()
+                 params = {
+                     "model": self.model,
+                     "messages": [
                          {"role": "system", "content": system_prompt},
                          {"role": "user", "content": content},
                      ],
-                     temperature=self.temperature,
-                     max_tokens=self.max_tokens,
-                     top_p=self.top_p,
-                     frequency_penalty=self.frequency_penalty,
-                     presence_penalty=self.presence_penalty,
-                     logprobs=self.logprobs,
-                     top_logprobs=self.top_logprobs if self.logprobs else None,
-                 )
+                     "temperature": self.temperature,
+                     "max_tokens": self.max_tokens,
+                     "top_p": self.top_p,
+                     "frequency_penalty": self.frequency_penalty,
+                     "presence_penalty": self.presence_penalty,
+                     "logprobs": self.logprobs,
+                     "top_logprobs": self.top_logprobs if self.logprobs else None,
+                 }
+                 response = await client.chat.completions.create(**params)
                  return response.model_dump()
  
              @staticmethod
@@ -164,6 +208,9 @@ class OpenAIService(InferenceServiceABC):
                  if match:
                      return match.group(1)
                  else:
+                     out = fix_partial_correct_response(response)
+                     if "error" not in out:
+                         response = out["extracted_json"]
                      return response
  
          LLM.__name__ = "LanguageModel"
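The last hunk pins down the contract of the new `fix_partial_correct_response` helper (added in `edsl/utilities/utilities.py`, +31 lines, not shown in this diff): it takes the raw response text and returns a dict carrying either an `"extracted_json"` payload or an `"error"` key. A guessed sketch of that contract, not the actual implementation:

```python
import json


def fix_partial_correct_response_sketch(text: str) -> dict:
    """Recover the first brace-balanced JSON object from text (illustrative)."""
    start = text.find("{")
    if start == -1:
        return {"error": "no JSON object found"}
    depth = 0
    for i, ch in enumerate(text[start:], start):
        depth += {"{": 1, "}": -1}.get(ch, 0)
        if depth == 0:  # matching close brace reached
            candidate = text[start : i + 1]
            try:
                json.loads(candidate)  # validate before returning
                return {"extracted_json": candidate}
            except json.JSONDecodeError:
                return {"error": "candidate is not valid JSON"}
    return {"error": "unterminated JSON object"}
```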
edsl/inference_services/models_available_cache.py CHANGED
@@ -66,4 +66,35 @@ models_available = {
          "openchat/openchat_3.5",
      ],
      "google": ["gemini-pro"],
+     "bedrock": [
+         "amazon.titan-tg1-large",
+         "amazon.titan-text-lite-v1",
+         "amazon.titan-text-express-v1",
+         "ai21.j2-grande-instruct",
+         "ai21.j2-jumbo-instruct",
+         "ai21.j2-mid",
+         "ai21.j2-mid-v1",
+         "ai21.j2-ultra",
+         "ai21.j2-ultra-v1",
+         "anthropic.claude-instant-v1",
+         "anthropic.claude-v2:1",
+         "anthropic.claude-v2",
+         "anthropic.claude-3-sonnet-20240229-v1:0",
+         "anthropic.claude-3-haiku-20240307-v1:0",
+         "anthropic.claude-3-opus-20240229-v1:0",
+         "anthropic.claude-3-5-sonnet-20240620-v1:0",
+         "cohere.command-text-v14",
+         "cohere.command-r-v1:0",
+         "cohere.command-r-plus-v1:0",
+         "cohere.command-light-text-v14",
+         "meta.llama3-8b-instruct-v1:0",
+         "meta.llama3-70b-instruct-v1:0",
+         "meta.llama3-1-8b-instruct-v1:0",
+         "meta.llama3-1-70b-instruct-v1:0",
+         "meta.llama3-1-405b-instruct-v1:0",
+         "mistral.mistral-7b-instruct-v0:2",
+         "mistral.mixtral-8x7b-instruct-v0:1",
+         "mistral.mistral-large-2402-v1:0",
+         "mistral.mistral-large-2407-v1:0",
+     ],
  }
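This cache is what `_get_service_available` falls back to when a live model listing fails, keyed by each service's `_inference_service_` string:

```python
from edsl.inference_services.models_available_cache import models_available

# The fallback lookup used in the InferenceServicesCollection hunk above.
bedrock_models = models_available.get("bedrock", [])
print(len(bedrock_models))  # 29 Bedrock model ids are added in this release
```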
edsl/inference_services/registry.py CHANGED
@@ -6,7 +6,20 @@ from edsl.inference_services.OpenAIService import OpenAIService
  from edsl.inference_services.AnthropicService import AnthropicService
  from edsl.inference_services.DeepInfraService import DeepInfraService
  from edsl.inference_services.GoogleService import GoogleService
+ from edsl.inference_services.GroqService import GroqService
+ from edsl.inference_services.AwsBedrock import AwsBedrockService
+ from edsl.inference_services.AzureAI import AzureAIService
+ from edsl.inference_services.OllamaService import OllamaService
  
  default = InferenceServicesCollection(
-     [OpenAIService, AnthropicService, DeepInfraService, GoogleService]
+     [
+         OpenAIService,
+         AnthropicService,
+         DeepInfraService,
+         GoogleService,
+         GroqService,
+         AwsBedrockService,
+         AzureAIService,
+         OllamaService,
+     ]
  )
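A quick way to confirm that all eight services are registered; the `services` attribute name is an assumption about what the `InferenceServicesCollection` constructor does with the list passed above:

```python
from edsl.inference_services.registry import default

for service in default.services:  # "services" is assumed to hold the list
    print(service._inference_service_)  # e.g. "openai", "groq", "ollama", ...
```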
edsl/jobs/Jobs.py CHANGED
@@ -39,6 +39,8 @@ class Jobs(Base):
  
          self.__bucket_collection = None
  
+     # these setters and getters are used to ensure that the agents, models, and scenarios are stored as AgentList, ModelList, and ScenarioList objects
+
      @property
      def models(self):
          return self._models
@@ -119,7 +121,9 @@
          - scenarios: traits of new scenarios are combined with traits of old existing. New scenarios will overwrite overlapping traits, and do not increase the number of scenarios in the instance
          - models: new models overwrite old models.
          """
-         passed_objects = self._turn_args_to_list(args)
+         passed_objects = self._turn_args_to_list(
+             args
+         )  # objects can also be passed comma-separated
  
          current_objects, objects_key = self._get_current_objects_of_this_type(
              passed_objects[0]
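The comma-separated form mentioned in the new comment looks like this in practice (a sketch; same-typed objects passed in one `.by()` call are combined, per the docstring above):

```python
from edsl import Scenario, Survey
from edsl.questions import QuestionFreeText

q = QuestionFreeText(
    question_name="mood", question_text="How do you feel in the {{ period }}?"
)
# Two Scenarios passed comma-separated in a single .by() call:
job = Survey(questions=[q]).by(
    Scenario({"period": "morning"}), Scenario({"period": "evening"})
)
```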
@@ -176,17 +180,27 @@
          from edsl.agents.Agent import Agent
          from edsl.scenarios.Scenario import Scenario
          from edsl.scenarios.ScenarioList import ScenarioList
+         from edsl.language_models.ModelList import ModelList
  
          if isinstance(object, Agent):
              return AgentList
          elif isinstance(object, Scenario):
              return ScenarioList
+         elif isinstance(object, ModelList):
+             return ModelList
          else:
              return list
  
      @staticmethod
      def _turn_args_to_list(args):
-         """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments."""
+         """Return a list of the first argument if it is a sequence, otherwise returns a list of all the arguments.
+
+         Example:
+
+         >>> Jobs._turn_args_to_list([1,2,3])
+         [1, 2, 3]
+
+         """
  
          def did_user_pass_a_sequence(args):
              """Return True if the user passed a sequence, False otherwise.
@@ -209,7 +223,7 @@
          return container_class(args)
  
      def _get_current_objects_of_this_type(
-         self, object: Union[Agent, Scenario, LanguageModel]
+         self, object: Union["Agent", "Scenario", "LanguageModel"]
      ) -> tuple[list, str]:
          from edsl.agents.Agent import Agent
          from edsl.scenarios.Scenario import Scenario
@@ -292,7 +306,11 @@
  
      @classmethod
      def from_interviews(cls, interview_list):
-         """Return a Jobs instance from a list of interviews."""
+         """Return a Jobs instance from a list of interviews.
+
+         This is useful when you have, say, a list of failed interviews and you want to create
+         a new job with only those interviews.
+         """
          survey = interview_list[0].survey
          # get all the models
          models = list(set([interview.model for interview in interview_list]))
@@ -308,6 +326,8 @@
          Note that this sets the agents, model and scenarios if they have not been set. This is a side effect of the method.
          This is useful because a user can create a job without setting the agents, models, or scenarios, and the job will still run,
          with us filling in defaults.
+
+
          """
          # if no agents, models, or scenarios are set, set them to defaults
          from edsl.agents.Agent import Agent
@@ -319,7 +339,11 @@
          self.scenarios = self.scenarios or [Scenario()]
          for agent, scenario, model in product(self.agents, self.scenarios, self.models):
              yield Interview(
-                 survey=self.survey, agent=agent, scenario=scenario, model=model
+                 survey=self.survey,
+                 agent=agent,
+                 scenario=scenario,
+                 model=model,
+                 skip_retry=self.skip_retry,
              )
  
      def create_bucket_collection(self) -> BucketCollection:
@@ -359,10 +383,16 @@
          return links
  
      def __hash__(self):
-         """Allow the model to be used as a key in a dictionary."""
+         """Allow the model to be used as a key in a dictionary.
+
+         >>> from edsl.jobs import Jobs
+         >>> hash(Jobs.example())
+         846655441787442972
+
+         """
          from edsl.utilities.utilities import dict_hash
  
-         return dict_hash(self.to_dict())
+         return dict_hash(self._to_dict())
  
      def _output(self, message) -> None:
          """Check if a Job is verbose. If so, print the message."""
@@ -390,11 +420,27 @@
          Traceback (most recent call last):
          ...
          ValueError: The following parameters are in the scenarios but not in the survey: {'plop'}
+
+         >>> q = QuestionFreeText(question_text = "Hello", question_name = "ugly_question")
+         >>> s = Scenario({'ugly_question': "B"})
+         >>> j = Jobs(survey = Survey(questions=[q])).by(s)
+         >>> j._check_parameters()
+         Traceback (most recent call last):
+         ...
+         ValueError: The following names are in both the survey question_names and the scenario keys: {'ugly_question'}. This will create issues.
          """
          survey_parameters: set = self.survey.parameters
          scenario_parameters: set = self.scenarios.parameters
  
-         msg1, msg2 = None, None
+         msg0, msg1, msg2 = None, None, None
+
+         # look for key issues
+         if intersection := set(self.scenarios.parameters) & set(
+             self.survey.question_names
+         ):
+             msg0 = f"The following names are in both the survey question_names and the scenario keys: {intersection}. This will create issues."
+
+             raise ValueError(msg0)
  
          if in_survey_but_not_in_scenarios := survey_parameters - scenario_parameters:
              msg1 = f"The following parameters are in the survey but not in the scenarios: {in_survey_but_not_in_scenarios}"
@@ -409,6 +455,12 @@
          if warn:
              warnings.warn(message)
  
+     @property
+     def skip_retry(self):
+         if not hasattr(self, "_skip_retry"):
+             return False
+         return self._skip_retry
+
      def run(
          self,
          n: int = 1,
@@ -423,6 +475,7 @@
          print_exceptions=True,
          remote_cache_description: Optional[str] = None,
          remote_inference_description: Optional[str] = None,
+         skip_retry: bool = False,
      ) -> Results:
          """
          Runs the Job: conducts Interviews and returns their results.
@@ -441,6 +494,7 @@
          from edsl.coop.coop import Coop
  
          self._check_parameters()
+         self._skip_retry = skip_retry
  
          if batch_mode is not None:
              raise NotImplementedError(
@@ -475,6 +529,7 @@
                  self,
                  description=remote_inference_description,
                  status="queued",
+                 iterations=n,
              )
              time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
              job_uuid = remote_job_creation_data.get("uuid")
@@ -629,13 +684,17 @@
          results = JobsRunnerAsyncio(self).run(*args, **kwargs)
          return results
  
-     async def run_async(self, cache=None, **kwargs):
-         """Run the job asynchronously."""
-         results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
+     async def run_async(self, cache=None, n=1, **kwargs):
+         """Run asynchronously."""
+         results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
          return results
  
      def all_question_parameters(self):
-         """Return all the fields in the questions in the survey."""
+         """Return all the fields in the questions in the survey.
+         >>> from edsl.jobs import Jobs
+         >>> Jobs.example().all_question_parameters()
+         {'period'}
+         """
          return set.union(*[question.parameters for question in self.survey.questions])
  
      #######################
@@ -676,15 +735,19 @@
      #######################
      # Serialization methods
      #######################
+
+     def _to_dict(self):
+         return {
+             "survey": self.survey._to_dict(),
+             "agents": [agent._to_dict() for agent in self.agents],
+             "models": [model._to_dict() for model in self.models],
+             "scenarios": [scenario._to_dict() for scenario in self.scenarios],
+         }
+
      @add_edsl_version
      def to_dict(self) -> dict:
          """Convert the Jobs instance to a dictionary."""
-         return {
-             "survey": self.survey.to_dict(),
-             "agents": [agent.to_dict() for agent in self.agents],
-             "models": [model.to_dict() for model in self.models],
-             "scenarios": [scenario.to_dict() for scenario in self.scenarios],
-         }
+         return self._to_dict()
  
      @classmethod
      @remove_edsl_version
@@ -703,7 +766,13 @@
          )
  
      def __eq__(self, other: Jobs) -> bool:
-         """Return True if the Jobs instance is equal to another Jobs instance."""
+         """Return True if the Jobs instance is equal to another Jobs instance.
+
+         >>> from edsl.jobs import Jobs
+         >>> Jobs.example() == Jobs.example()
+         True
+
+         """
          return self.to_dict() == other.to_dict()
  
      #######################
@@ -711,11 +780,16 @@
      #######################
      @classmethod
      def example(
-         cls, throw_exception_probability: int = 0, randomize: bool = False
+         cls,
+         throw_exception_probability: float = 0.0,
+         randomize: bool = False,
+         test_model=False,
      ) -> Jobs:
          """Return an example Jobs instance.
  
          :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
+         :param randomize: whether to randomize the job by adding a random string to the period
+         :param test_model: whether to use a test model
  
          >>> Jobs.example()
          Jobs(...)
@@ -729,6 +803,11 @@
  
          addition = "" if not randomize else str(uuid4())
  
+         if test_model:
+             from edsl.language_models import LanguageModel
+
+             m = LanguageModel.example(test_model=True)
+
          # (status, question, period)
          agent_answers = {
              ("Joyful", "how_feeling", "morning"): "OK",
@@ -776,7 +855,10 @@
                  Scenario({"period": "afternoon"}),
              ]
          )
-         job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
+         if test_model:
+             job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
+         else:
+             job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
  
          return job
  
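Taken together, `example(test_model=True)` and the `skip_retry` plumbing allow the whole run path to be exercised offline. A sketch, assuming the canned test model answers without network access and that the cache and remote arguments can be left at their defaults:

```python
from edsl.jobs import Jobs

job = Jobs.example(throw_exception_probability=0.1, test_model=True)
results = job.run(skip_retry=True, print_exceptions=True)
```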
edsl/jobs/buckets/TokenBucket.py CHANGED
@@ -100,7 +100,9 @@ class TokenBucket:
          available_tokens = min(self.capacity, self.tokens + refill_amount)
          return max(0, requested_tokens - available_tokens) / self.refill_rate
  
-     async def get_tokens(self, amount: Union[int, float] = 1) -> None:
+     async def get_tokens(
+         self, amount: Union[int, float] = 1, cheat_bucket_capacity=True
+     ) -> None:
          """Wait for the specified number of tokens to become available.
  
  
@@ -116,14 +118,20 @@ class TokenBucket:
          True
  
          >>> bucket = TokenBucket(bucket_name="test", bucket_type="test", capacity=10, refill_rate=1)
-         >>> asyncio.run(bucket.get_tokens(11))
+         >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
          Traceback (most recent call last):
          ...
          ValueError: Requested amount exceeds bucket capacity. Bucket capacity: 10, requested amount: 11. As the bucket never overflows, the requested amount will never be available.
+         >>> asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=True))
          """
          if amount > self.capacity:
-             msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
-             raise ValueError(msg)
+             if not cheat_bucket_capacity:
+                 msg = f"Requested amount exceeds bucket capacity. Bucket capacity: {self.capacity}, requested amount: {amount}. As the bucket never overflows, the requested amount will never be available."
+                 raise ValueError(msg)
+             else:
+                 self.tokens = 0  # clear the bucket but let it go through
+                 return
+
          while self.tokens < amount:
              self.refill()
              await asyncio.sleep(0.01)  # Sleep briefly to prevent busy waiting
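The new escape hatch, exercised directly (mirrors the doctest above):

```python
import asyncio

from edsl.jobs.buckets.TokenBucket import TokenBucket

bucket = TokenBucket(
    bucket_name="api", bucket_type="tokens", capacity=10, refill_rate=1
)
asyncio.run(bucket.get_tokens(11))  # default now cheats: drains to 0 and returns

try:
    asyncio.run(bucket.get_tokens(11, cheat_bucket_capacity=False))
except ValueError as e:
    print(e)  # "Requested amount exceeds bucket capacity. ..."
```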