edsl 0.1.29.dev3__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. edsl/Base.py +18 -18
  2. edsl/__init__.py +23 -23
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +79 -41
  5. edsl/agents/AgentList.py +26 -26
  6. edsl/agents/Invigilator.py +19 -2
  7. edsl/agents/InvigilatorBase.py +15 -10
  8. edsl/agents/PromptConstructionMixin.py +342 -100
  9. edsl/agents/descriptors.py +2 -1
  10. edsl/base/Base.py +289 -0
  11. edsl/config.py +2 -1
  12. edsl/conjure/InputData.py +39 -8
  13. edsl/conversation/car_buying.py +1 -1
  14. edsl/coop/coop.py +187 -150
  15. edsl/coop/utils.py +43 -75
  16. edsl/data/Cache.py +41 -18
  17. edsl/data/CacheEntry.py +6 -7
  18. edsl/data/SQLiteDict.py +11 -3
  19. edsl/data_transfer_models.py +4 -0
  20. edsl/jobs/Answers.py +15 -1
  21. edsl/jobs/Jobs.py +108 -49
  22. edsl/jobs/buckets/ModelBuckets.py +14 -2
  23. edsl/jobs/buckets/TokenBucket.py +32 -5
  24. edsl/jobs/interviews/Interview.py +99 -79
  25. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +19 -24
  26. edsl/jobs/runners/JobsRunnerAsyncio.py +16 -16
  27. edsl/jobs/tasks/QuestionTaskCreator.py +10 -6
  28. edsl/jobs/tasks/TaskHistory.py +4 -3
  29. edsl/language_models/LanguageModel.py +17 -17
  30. edsl/language_models/ModelList.py +1 -1
  31. edsl/language_models/repair.py +8 -7
  32. edsl/notebooks/Notebook.py +47 -10
  33. edsl/prompts/Prompt.py +31 -19
  34. edsl/questions/QuestionBase.py +38 -13
  35. edsl/questions/QuestionBudget.py +5 -6
  36. edsl/questions/QuestionCheckBox.py +7 -3
  37. edsl/questions/QuestionExtract.py +5 -3
  38. edsl/questions/QuestionFreeText.py +7 -5
  39. edsl/questions/QuestionFunctional.py +34 -5
  40. edsl/questions/QuestionList.py +3 -4
  41. edsl/questions/QuestionMultipleChoice.py +68 -12
  42. edsl/questions/QuestionNumerical.py +4 -3
  43. edsl/questions/QuestionRank.py +5 -3
  44. edsl/questions/__init__.py +4 -3
  45. edsl/questions/descriptors.py +46 -4
  46. edsl/questions/question_registry.py +20 -31
  47. edsl/questions/settings.py +1 -1
  48. edsl/results/Dataset.py +31 -0
  49. edsl/results/DatasetExportMixin.py +570 -0
  50. edsl/results/Result.py +66 -70
  51. edsl/results/Results.py +160 -68
  52. edsl/results/ResultsDBMixin.py +7 -3
  53. edsl/results/ResultsExportMixin.py +22 -537
  54. edsl/results/ResultsGGMixin.py +3 -3
  55. edsl/results/ResultsToolsMixin.py +5 -5
  56. edsl/scenarios/FileStore.py +299 -0
  57. edsl/scenarios/Scenario.py +16 -24
  58. edsl/scenarios/ScenarioList.py +42 -17
  59. edsl/scenarios/ScenarioListExportMixin.py +32 -0
  60. edsl/scenarios/ScenarioListPdfMixin.py +2 -1
  61. edsl/scenarios/__init__.py +1 -0
  62. edsl/study/Study.py +8 -16
  63. edsl/surveys/MemoryPlan.py +11 -4
  64. edsl/surveys/Survey.py +88 -17
  65. edsl/surveys/SurveyExportMixin.py +4 -2
  66. edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
  67. edsl/tools/plotting.py +4 -2
  68. edsl/utilities/__init__.py +21 -21
  69. edsl/utilities/interface.py +66 -45
  70. edsl/utilities/utilities.py +11 -13
  71. {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/METADATA +11 -10
  72. {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/RECORD +74 -71
  73. {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/WHEEL +1 -1
  74. edsl-0.1.29.dev3.dist-info/entry_points.txt +0 -3
  75. {edsl-0.1.29.dev3.dist-info → edsl-0.1.30.dist-info}/LICENSE +0 -0
edsl/data/Cache.py CHANGED
@@ -7,17 +7,10 @@ import json
7
7
  import os
8
8
  import warnings
9
9
  from typing import Optional, Union
10
-
11
- from edsl.config import CONFIG
12
- from edsl.data.CacheEntry import CacheEntry
13
- from edsl.data.SQLiteDict import SQLiteDict
14
10
  from edsl.Base import Base
11
+ from edsl.data.CacheEntry import CacheEntry
15
12
  from edsl.utilities.utilities import dict_hash
16
-
17
- from edsl.utilities.decorators import (
18
- add_edsl_version,
19
- remove_edsl_version,
20
- )
13
+ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
21
14
 
22
15
 
23
16
  class Cache(Base):
@@ -38,9 +31,10 @@ class Cache(Base):
38
31
  self,
39
32
  *,
40
33
  filename: Optional[str] = None,
41
- data: Optional[Union[SQLiteDict, dict]] = None,
34
+ data: Optional[Union["SQLiteDict", dict]] = None,
42
35
  immediate_write: bool = True,
43
36
  method=None,
37
+ verbose=False,
44
38
  ):
45
39
  """
46
40
  Create two dictionaries to store the cache data.
@@ -59,6 +53,7 @@ class Cache(Base):
59
53
  self.new_entries = {}
60
54
  self.new_entries_to_write_later = {}
61
55
  self.coop = None
56
+ self.verbose = verbose
62
57
 
63
58
  self.filename = filename
64
59
  if filename and data:
@@ -104,6 +99,8 @@ class Cache(Base):
104
99
 
105
100
  def _perform_checks(self):
106
101
  """Perform checks on the cache."""
102
+ from edsl.data.CacheEntry import CacheEntry
103
+
107
104
  if any(not isinstance(value, CacheEntry) for value in self.data.values()):
108
105
  raise Exception("Not all values are CacheEntry instances")
109
106
  if self.method is not None:
@@ -120,7 +117,7 @@ class Cache(Base):
120
117
  system_prompt: str,
121
118
  user_prompt: str,
122
119
  iteration: int,
123
- ) -> Union[None, str]:
120
+ ) -> tuple(Union[None, str], str):
124
121
  """
125
122
  Fetch a value (LLM output) from the cache.
126
123
 
@@ -133,11 +130,13 @@ class Cache(Base):
133
130
  Return None if the response is not found.
134
131
 
135
132
  >>> c = Cache()
136
- >>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1) is None
133
+ >>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1)[0] is None
137
134
  True
138
135
 
139
136
 
140
137
  """
138
+ from edsl.data.CacheEntry import CacheEntry
139
+
141
140
  key = CacheEntry.gen_key(
142
141
  model=model,
143
142
  parameters=parameters,
@@ -147,8 +146,13 @@ class Cache(Base):
147
146
  )
148
147
  entry = self.data.get(key, None)
149
148
  if entry is not None:
149
+ if self.verbose:
150
+ print(f"Cache hit for key: {key}")
150
151
  self.fetched_data[key] = entry
151
- return None if entry is None else entry.output
152
+ else:
153
+ if self.verbose:
154
+ print(f"Cache miss for key: {key}")
155
+ return None if entry is None else entry.output, key
152
156
 
153
157
  def store(
154
158
  self,
@@ -171,6 +175,7 @@ class Cache(Base):
171
175
  * If `immediate_write` is True , the key-value pair is added to `self.data`
172
176
  * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
173
177
  """
178
+
174
179
  entry = CacheEntry(
175
180
  model=model,
176
181
  parameters=parameters,
@@ -188,13 +193,14 @@ class Cache(Base):
188
193
  return key
189
194
 
190
195
  def add_from_dict(
191
- self, new_data: dict[str, CacheEntry], write_now: Optional[bool] = True
196
+ self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
192
197
  ) -> None:
193
198
  """
194
199
  Add entries to the cache from a dictionary.
195
200
 
196
201
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
197
202
  """
203
+
198
204
  for key, value in new_data.items():
199
205
  if key in self.data:
200
206
  if value != self.data[key]:
@@ -231,6 +237,8 @@ class Cache(Base):
231
237
 
232
238
  :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
233
239
  """
240
+ from edsl.data.SQLiteDict import SQLiteDict
241
+
234
242
  db = SQLiteDict(db_path)
235
243
  new_data = {}
236
244
  for key, value in db.items():
@@ -242,6 +250,8 @@ class Cache(Base):
242
250
  """
243
251
  Construct a Cache from a SQLite database.
244
252
  """
253
+ from edsl.data.SQLiteDict import SQLiteDict
254
+
245
255
  return cls(data=SQLiteDict(db_path))
246
256
 
247
257
  @classmethod
@@ -268,6 +278,8 @@ class Cache(Base):
268
278
  * If `db_path` is provided, the cache will be stored in an SQLite database.
269
279
  """
270
280
  # if a file doesn't exist at jsonfile, throw an error
281
+ from edsl.data.SQLiteDict import SQLiteDict
282
+
271
283
  if not os.path.exists(jsonlfile):
272
284
  raise FileNotFoundError(f"File {jsonlfile} not found")
273
285
 
@@ -286,6 +298,8 @@ class Cache(Base):
286
298
  """
287
299
  ## TODO: Check to make sure not over-writing (?)
288
300
  ## Should be added to SQLiteDict constructor (?)
301
+ from edsl.data.SQLiteDict import SQLiteDict
302
+
289
303
  new_data = SQLiteDict(db_path)
290
304
  for key, value in self.data.items():
291
305
  new_data[key] = value
@@ -340,6 +354,9 @@ class Cache(Base):
340
354
  for key, entry in self.new_entries_to_write_later.items():
341
355
  self.data[key] = entry
342
356
 
357
+ if self.filename:
358
+ self.write(self.filename)
359
+
343
360
  ####################
344
361
  # DUNDER / USEFUL
345
362
  ####################
@@ -456,12 +473,18 @@ class Cache(Base):
456
473
  webbrowser.open("file://" + filepath)
457
474
 
458
475
  @classmethod
459
- def example(cls) -> Cache:
476
+ def example(cls, randomize: bool = False) -> Cache:
460
477
  """
461
- Return an example Cache.
462
- The example Cache has one entry.
478
+ Returns an example Cache instance.
479
+
480
+ :param randomize: If True, uses CacheEntry's randomize method.
463
481
  """
464
- return cls(data={CacheEntry.example().key: CacheEntry.example()})
482
+ return cls(
483
+ data={
484
+ CacheEntry.example(randomize).key: CacheEntry.example(),
485
+ CacheEntry.example(randomize).key: CacheEntry.example(),
486
+ }
487
+ )
465
488
 
466
489
 
467
490
  if __name__ == "__main__":
edsl/data/CacheEntry.py CHANGED
@@ -2,11 +2,8 @@ from __future__ import annotations
2
2
  import json
3
3
  import datetime
4
4
  import hashlib
5
- import random
6
5
  from typing import Optional
7
-
8
-
9
- # TODO: Timestamp should probably be float?
6
+ from uuid import uuid4
10
7
 
11
8
 
12
9
  class CacheEntry:
@@ -151,10 +148,12 @@ class CacheEntry:
151
148
  @classmethod
152
149
  def example(cls, randomize: bool = False) -> CacheEntry:
153
150
  """
154
- Returns a CacheEntry example.
151
+ Returns an example CacheEntry instance.
152
+
153
+ :param randomize: If True, adds a random string to the system prompt.
155
154
  """
156
- # if random, create a random number for 0-100
157
- addition = "" if not randomize else str(random.randint(0, 1000))
155
+ # if random, create a uuid
156
+ addition = "" if not randomize else str(uuid4())
158
157
  return CacheEntry(
159
158
  model="gpt-3.5-turbo",
160
159
  parameters={"temperature": 0.5},
edsl/data/SQLiteDict.py CHANGED
@@ -1,9 +1,7 @@
1
1
  from __future__ import annotations
2
2
  import json
3
- from sqlalchemy import create_engine
4
- from sqlalchemy.exc import SQLAlchemyError
5
- from sqlalchemy.orm import sessionmaker
6
3
  from typing import Any, Generator, Optional, Union
4
+
7
5
  from edsl.config import CONFIG
8
6
  from edsl.data.CacheEntry import CacheEntry
9
7
  from edsl.data.orm import Base, Data
@@ -25,10 +23,16 @@ class SQLiteDict:
25
23
  >>> import os; os.unlink(temp_db_path) # Clean up the temp file after the test
26
24
 
27
25
  """
26
+ from sqlalchemy.exc import SQLAlchemyError
27
+ from sqlalchemy.orm import sessionmaker
28
+ from sqlalchemy import create_engine
29
+
28
30
  self.db_path = db_path or CONFIG.get("EDSL_DATABASE_PATH")
29
31
  if not self.db_path.startswith("sqlite:///"):
30
32
  self.db_path = f"sqlite:///{self.db_path}"
31
33
  try:
34
+ from edsl.data.orm import Base, Data
35
+
32
36
  self.engine = create_engine(self.db_path, echo=False, future=True)
33
37
  Base.metadata.create_all(self.engine)
34
38
  self.Session = sessionmaker(bind=self.engine)
@@ -55,6 +59,8 @@ class SQLiteDict:
55
59
  if not isinstance(value, CacheEntry):
56
60
  raise ValueError(f"Value must be a CacheEntry object (got {type(value)}).")
57
61
  with self.Session() as db:
62
+ from edsl.data.orm import Base, Data
63
+
58
64
  db.merge(Data(key=key, value=json.dumps(value.to_dict())))
59
65
  db.commit()
60
66
 
@@ -69,6 +75,8 @@ class SQLiteDict:
69
75
  True
70
76
  """
71
77
  with self.Session() as db:
78
+ from edsl.data.orm import Base, Data
79
+
72
80
  value = db.query(Data).filter_by(key=key).first()
73
81
  if not value:
74
82
  raise KeyError(f"Key '{key}' not found.")
@@ -17,6 +17,8 @@ class AgentResponseDict(UserDict):
17
17
  cached_response=None,
18
18
  raw_model_response=None,
19
19
  simple_model_raw_response=None,
20
+ cache_used=None,
21
+ cache_key=None,
20
22
  ):
21
23
  """Initialize the AgentResponseDict object."""
22
24
  usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
@@ -30,5 +32,7 @@ class AgentResponseDict(UserDict):
30
32
  "cached_response": cached_response,
31
33
  "raw_model_response": raw_model_response,
32
34
  "simple_model_raw_response": simple_model_raw_response,
35
+ "cache_used": cache_used,
36
+ "cache_key": cache_key,
33
37
  }
34
38
  )
edsl/jobs/Answers.py CHANGED
@@ -8,7 +8,15 @@ class Answers(UserDict):
8
8
  """Helper class to hold the answers to a survey."""
9
9
 
10
10
  def add_answer(self, response, question) -> None:
11
- """Add a response to the answers dictionary."""
11
+ """Add a response to the answers dictionary.
12
+
13
+ >>> from edsl import QuestionFreeText
14
+ >>> q = QuestionFreeText.example()
15
+ >>> answers = Answers()
16
+ >>> answers.add_answer({"answer": "yes"}, q)
17
+ >>> answers[q.question_name]
18
+ 'yes'
19
+ """
12
20
  answer = response.get("answer")
13
21
  comment = response.pop("comment", None)
14
22
  # record the answer
@@ -42,3 +50,9 @@ class Answers(UserDict):
42
50
  table.add_row(attr_name, repr(attr_value))
43
51
 
44
52
  return table
53
+
54
+
55
+ if __name__ == "__main__":
56
+ import doctest
57
+
58
+ doctest.testmod()
edsl/jobs/Jobs.py CHANGED
@@ -1,30 +1,15 @@
1
1
  # """The Jobs class is a collection of agents, scenarios and models and one survey."""
2
2
  from __future__ import annotations
3
- import os
4
3
  import warnings
5
4
  from itertools import product
6
5
  from typing import Optional, Union, Sequence, Generator
7
- from edsl import Model
8
- from edsl.agents import Agent
9
- from edsl.agents.AgentList import AgentList
6
+
10
7
  from edsl.Base import Base
11
- from edsl.data.Cache import Cache
12
- from edsl.data.CacheHandler import CacheHandler
13
- from edsl.results.Dataset import Dataset
14
8
 
15
- from edsl.exceptions.jobs import MissingRemoteInferenceError
16
9
  from edsl.exceptions import MissingAPIKeyError
17
10
  from edsl.jobs.buckets.BucketCollection import BucketCollection
18
11
  from edsl.jobs.interviews.Interview import Interview
19
- from edsl.language_models import LanguageModel
20
- from edsl.results import Results
21
- from edsl.scenarios import Scenario
22
- from edsl import ScenarioList
23
- from edsl.surveys import Survey
24
12
  from edsl.jobs.runners.JobsRunnerAsyncio import JobsRunnerAsyncio
25
-
26
- from edsl.language_models.ModelList import ModelList
27
-
28
13
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
29
14
 
30
15
 
@@ -37,10 +22,10 @@ class Jobs(Base):
37
22
 
38
23
  def __init__(
39
24
  self,
40
- survey: Survey,
41
- agents: Optional[list[Agent]] = None,
42
- models: Optional[list[LanguageModel]] = None,
43
- scenarios: Optional[list[Scenario]] = None,
25
+ survey: "Survey",
26
+ agents: Optional[list["Agent"]] = None,
27
+ models: Optional[list["LanguageModel"]] = None,
28
+ scenarios: Optional[list["Scenario"]] = None,
44
29
  ):
45
30
  """Initialize a Jobs instance.
46
31
 
@@ -50,8 +35,8 @@ class Jobs(Base):
50
35
  :param scenarios: a list of scenarios
51
36
  """
52
37
  self.survey = survey
53
- self.agents: AgentList = agents
54
- self.scenarios: ScenarioList = scenarios
38
+ self.agents: "AgentList" = agents
39
+ self.scenarios: "ScenarioList" = scenarios
55
40
  self.models = models
56
41
 
57
42
  self.__bucket_collection = None
@@ -62,6 +47,8 @@ class Jobs(Base):
62
47
 
63
48
  @models.setter
64
49
  def models(self, value):
50
+ from edsl import ModelList
51
+
65
52
  if value:
66
53
  if not isinstance(value, ModelList):
67
54
  self._models = ModelList(value)
@@ -76,6 +63,8 @@ class Jobs(Base):
76
63
 
77
64
  @agents.setter
78
65
  def agents(self, value):
66
+ from edsl import AgentList
67
+
79
68
  if value:
80
69
  if not isinstance(value, AgentList):
81
70
  self._agents = AgentList(value)
@@ -90,6 +79,8 @@ class Jobs(Base):
90
79
 
91
80
  @scenarios.setter
92
81
  def scenarios(self, value):
82
+ from edsl import ScenarioList
83
+
93
84
  if value:
94
85
  if not isinstance(value, ScenarioList):
95
86
  self._scenarios = ScenarioList(value)
@@ -101,10 +92,10 @@ class Jobs(Base):
101
92
  def by(
102
93
  self,
103
94
  *args: Union[
104
- Agent,
105
- Scenario,
106
- LanguageModel,
107
- Sequence[Union[Agent, Scenario, LanguageModel]],
95
+ "Agent",
96
+ "Scenario",
97
+ "LanguageModel",
98
+ Sequence[Union["Agent", "Scenario", "LanguageModel"]],
108
99
  ],
109
100
  ) -> Jobs:
110
101
  """
@@ -144,7 +135,7 @@ class Jobs(Base):
144
135
  setattr(self, objects_key, new_objects) # update the job
145
136
  return self
146
137
 
147
- def prompts(self) -> Dataset:
138
+ def prompts(self) -> "Dataset":
148
139
  """Return a Dataset of prompts that will be used.
149
140
 
150
141
 
@@ -160,6 +151,7 @@ class Jobs(Base):
160
151
  user_prompts = []
161
152
  system_prompts = []
162
153
  scenario_indices = []
154
+ from edsl.results.Dataset import Dataset
163
155
 
164
156
  for interview_index, interview in enumerate(interviews):
165
157
  invigilators = list(interview._build_invigilators(debug=False))
@@ -182,7 +174,10 @@ class Jobs(Base):
182
174
 
183
175
  @staticmethod
184
176
  def _get_container_class(object):
185
- from edsl import AgentList
177
+ from edsl.agents.AgentList import AgentList
178
+ from edsl.agents.Agent import Agent
179
+ from edsl.scenarios.Scenario import Scenario
180
+ from edsl.scenarios.ScenarioList import ScenarioList
186
181
 
187
182
  if isinstance(object, Agent):
188
183
  return AgentList
@@ -218,6 +213,10 @@ class Jobs(Base):
218
213
  def _get_current_objects_of_this_type(
219
214
  self, object: Union[Agent, Scenario, LanguageModel]
220
215
  ) -> tuple[list, str]:
216
+ from edsl.agents.Agent import Agent
217
+ from edsl.scenarios.Scenario import Scenario
218
+ from edsl.language_models.LanguageModel import LanguageModel
219
+
221
220
  """Return the current objects of the same type as the first argument.
222
221
 
223
222
  >>> from edsl.jobs import Jobs
@@ -246,6 +245,9 @@ class Jobs(Base):
246
245
  @staticmethod
247
246
  def _get_empty_container_object(object):
248
247
  from edsl import AgentList
248
+ from edsl import Agent
249
+ from edsl import Scenario
250
+ from edsl import ScenarioList
249
251
 
250
252
  if isinstance(object, Agent):
251
253
  return AgentList([])
@@ -310,12 +312,12 @@ class Jobs(Base):
310
312
  with us filling in defaults.
311
313
  """
312
314
  # if no agents, models, or scenarios are set, set them to defaults
315
+ from edsl.agents.Agent import Agent
316
+ from edsl.language_models.registry import Model
317
+ from edsl.scenarios.Scenario import Scenario
318
+
313
319
  self.agents = self.agents or [Agent()]
314
320
  self.models = self.models or [Model()]
315
- # if remote, set all the models to remote
316
- if hasattr(self, "remote") and self.remote:
317
- for model in self.models:
318
- model.remote = True
319
321
  self.scenarios = self.scenarios or [Scenario()]
320
322
  for agent, scenario, model in product(self.agents, self.scenarios, self.models):
321
323
  yield Interview(
@@ -329,6 +331,7 @@ class Jobs(Base):
329
331
  These buckets are used to track API calls and token usage.
330
332
 
331
333
  >>> from edsl.jobs import Jobs
334
+ >>> from edsl import Model
332
335
  >>> j = Jobs.example().by(Model(temperature = 1), Model(temperature = 0.5))
333
336
  >>> bc = j.create_bucket_collection()
334
337
  >>> bc
@@ -368,14 +371,16 @@ class Jobs(Base):
368
371
  if self.verbose:
369
372
  print(message)
370
373
 
371
- def _check_parameters(self, strict=False, warn = True) -> None:
374
+ def _check_parameters(self, strict=False, warn=False) -> None:
372
375
  """Check if the parameters in the survey and scenarios are consistent.
373
376
 
374
377
  >>> from edsl import QuestionFreeText
378
+ >>> from edsl import Survey
379
+ >>> from edsl import Scenario
375
380
  >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
376
381
  >>> j = Jobs(survey = Survey(questions=[q]))
377
382
  >>> with warnings.catch_warnings(record=True) as w:
378
- ... j._check_parameters()
383
+ ... j._check_parameters(warn = True)
379
384
  ... assert len(w) == 1
380
385
  ... assert issubclass(w[-1].category, UserWarning)
381
386
  ... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
@@ -413,15 +418,13 @@ class Jobs(Base):
413
418
  progress_bar: bool = False,
414
419
  stop_on_exception: bool = False,
415
420
  cache: Union[Cache, bool] = None,
416
- remote: bool = (
417
- False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
418
- ),
419
421
  check_api_keys: bool = False,
420
422
  sidecar_model: Optional[LanguageModel] = None,
421
423
  batch_mode: Optional[bool] = None,
422
424
  verbose: bool = False,
423
425
  print_exceptions=True,
424
426
  remote_cache_description: Optional[str] = None,
427
+ remote_inference_description: Optional[str] = None,
425
428
  ) -> Results:
426
429
  """
427
430
  Runs the Job: conducts Interviews and returns their results.
@@ -431,11 +434,11 @@ class Jobs(Base):
431
434
  :param progress_bar: shows a progress bar
432
435
  :param stop_on_exception: stops the job if an exception is raised
433
436
  :param cache: a cache object to store results
434
- :param remote: run the job remotely
435
437
  :param check_api_keys: check if the API keys are valid
436
438
  :param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
437
439
  :param verbose: prints messages
438
440
  :param remote_cache_description: specifies a description for this group of entries in the remote cache
441
+ :param remote_inference_description: specifies a description for the remote inference job
439
442
  """
440
443
  from edsl.coop.coop import Coop
441
444
 
@@ -446,21 +449,57 @@ class Jobs(Base):
446
449
  "Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
447
450
  )
448
451
 
449
- self.remote = remote
450
452
  self.verbose = verbose
451
453
 
452
454
  try:
453
455
  coop = Coop()
454
- remote_cache = coop.edsl_settings["remote_caching"]
456
+ user_edsl_settings = coop.edsl_settings
457
+ remote_cache = user_edsl_settings["remote_caching"]
458
+ remote_inference = user_edsl_settings["remote_inference"]
455
459
  except Exception:
456
460
  remote_cache = False
461
+ remote_inference = False
462
+
463
+ if remote_inference:
464
+ from edsl.agents.Agent import Agent
465
+ from edsl.language_models.registry import Model
466
+ from edsl.results.Result import Result
467
+ from edsl.results.Results import Results
468
+ from edsl.scenarios.Scenario import Scenario
469
+ from edsl.surveys.Survey import Survey
470
+
471
+ self._output("Remote inference activated. Sending job to server...")
472
+ if remote_cache:
473
+ self._output(
474
+ "Remote caching activated. The remote cache will be used for this job."
475
+ )
457
476
 
458
- if self.remote:
459
- ## TODO: This should be a coop check
460
- if os.getenv("EXPECTED_PARROT_API_KEY", None) is None:
461
- raise MissingRemoteInferenceError()
462
-
463
- if not self.remote:
477
+ remote_job_data = coop.remote_inference_create(
478
+ self,
479
+ description=remote_inference_description,
480
+ status="queued",
481
+ )
482
+ self._output("Job sent!")
483
+ # Create mock results object to store job data
484
+ results = Results(
485
+ survey=Survey(),
486
+ data=[
487
+ Result(
488
+ agent=Agent.example(),
489
+ scenario=Scenario.example(),
490
+ model=Model(),
491
+ iteration=1,
492
+ answer={"info": "Remote job details"},
493
+ )
494
+ ],
495
+ )
496
+ results.add_columns_from_dict([remote_job_data])
497
+ if self.verbose:
498
+ results.select(["info", "uuid", "status", "version"]).print(
499
+ format="rich"
500
+ )
501
+ return results
502
+ else:
464
503
  if check_api_keys:
465
504
  for model in self.models + [Model()]:
466
505
  if not model.has_valid_api_key():
@@ -471,8 +510,12 @@ class Jobs(Base):
471
510
 
472
511
  # handle cache
473
512
  if cache is None:
513
+ from edsl.data.CacheHandler import CacheHandler
514
+
474
515
  cache = CacheHandler().get_cache()
475
516
  if cache is False:
517
+ from edsl.data.Cache import Cache
518
+
476
519
  cache = Cache()
477
520
 
478
521
  if not remote_cache:
@@ -624,6 +667,11 @@ class Jobs(Base):
624
667
  @remove_edsl_version
625
668
  def from_dict(cls, data: dict) -> Jobs:
626
669
  """Creates a Jobs instance from a dictionary."""
670
+ from edsl import Survey
671
+ from edsl.agents.Agent import Agent
672
+ from edsl.language_models.LanguageModel import LanguageModel
673
+ from edsl.scenarios.Scenario import Scenario
674
+
627
675
  return cls(
628
676
  survey=Survey.from_dict(data["survey"]),
629
677
  agents=[Agent.from_dict(agent) for agent in data["agents"]],
@@ -639,7 +687,9 @@ class Jobs(Base):
639
687
  # Example methods
640
688
  #######################
641
689
  @classmethod
642
- def example(cls, throw_exception_probability=0) -> Jobs:
690
+ def example(
691
+ cls, throw_exception_probability: int = 0, randomize: bool = False
692
+ ) -> Jobs:
643
693
  """Return an example Jobs instance.
644
694
 
645
695
  :param throw_exception_probability: the probability that an exception will be thrown when answering a question. This is useful for testing error handling.
@@ -649,8 +699,12 @@ class Jobs(Base):
649
699
 
650
700
  """
651
701
  import random
702
+ from uuid import uuid4
652
703
  from edsl.questions import QuestionMultipleChoice
653
- from edsl import Agent
704
+ from edsl.agents.Agent import Agent
705
+ from edsl.scenarios.Scenario import Scenario
706
+
707
+ addition = "" if not randomize else str(uuid4())
654
708
 
655
709
  # (status, question, period)
656
710
  agent_answers = {
@@ -689,10 +743,15 @@ class Jobs(Base):
689
743
  question_options=["Good", "Great", "OK", "Terrible"],
690
744
  question_name="how_feeling_yesterday",
691
745
  )
746
+ from edsl import Survey, ScenarioList
747
+
692
748
  base_survey = Survey(questions=[q1, q2])
693
749
 
694
750
  scenario_list = ScenarioList(
695
- [Scenario({"period": "morning"}), Scenario({"period": "afternoon"})]
751
+ [
752
+ Scenario({"period": f"morning{addition}"}),
753
+ Scenario({"period": "afternoon"}),
754
+ ]
696
755
  )
697
756
  job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
698
757
 
@@ -1,4 +1,4 @@
1
- from edsl.jobs.buckets.TokenBucket import TokenBucket
1
+ # from edsl.jobs.buckets.TokenBucket import TokenBucket
2
2
 
3
3
 
4
4
  class ModelBuckets:
@@ -8,7 +8,7 @@ class ModelBuckets:
8
8
  A request is one call to the service. The number of tokens required for a request depends on parameters.
9
9
  """
10
10
 
11
- def __init__(self, requests_bucket: TokenBucket, tokens_bucket: TokenBucket):
11
+ def __init__(self, requests_bucket: "TokenBucket", tokens_bucket: "TokenBucket"):
12
12
  """Initialize the model buckets.
13
13
 
14
14
  The requests bucket captures requests per unit of time.
@@ -25,9 +25,21 @@ class ModelBuckets:
25
25
  tokens_bucket=self.tokens_bucket + other.tokens_bucket,
26
26
  )
27
27
 
28
+ def turbo_mode_on(self):
29
+ """Set the refill rate to infinity for both buckets."""
30
+ self.requests_bucket.turbo_mode_on()
31
+ self.tokens_bucket.turbo_mode_on()
32
+
33
+ def turbo_mode_off(self):
34
+ """Restore the refill rate to its original value for both buckets."""
35
+ self.requests_bucket.turbo_mode_off()
36
+ self.tokens_bucket.turbo_mode_off()
37
+
28
38
  @classmethod
29
39
  def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
30
40
  """Create a bucket with infinite capacity and refill rate."""
41
+ from edsl.jobs.buckets.TokenBucket import TokenBucket
42
+
31
43
  return cls(
32
44
  requests_bucket=TokenBucket(
33
45
  bucket_name=model_name,