edsl 0.1.29.dev2__py3-none-any.whl → 0.1.29.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/jobs/Jobs.py CHANGED
@@ -312,10 +312,6 @@ class Jobs(Base):
312
312
  # if no agents, models, or scenarios are set, set them to defaults
313
313
  self.agents = self.agents or [Agent()]
314
314
  self.models = self.models or [Model()]
315
- # if remote, set all the models to remote
316
- if hasattr(self, "remote") and self.remote:
317
- for model in self.models:
318
- model.remote = True
319
315
  self.scenarios = self.scenarios or [Scenario()]
320
316
  for agent, scenario, model in product(self.agents, self.scenarios, self.models):
321
317
  yield Interview(
@@ -368,14 +364,14 @@ class Jobs(Base):
368
364
  if self.verbose:
369
365
  print(message)
370
366
 
371
- def _check_parameters(self, strict=False, warn = True) -> None:
367
+ def _check_parameters(self, strict=False, warn=False) -> None:
372
368
  """Check if the parameters in the survey and scenarios are consistent.
373
369
 
374
370
  >>> from edsl import QuestionFreeText
375
371
  >>> q = QuestionFreeText(question_text = "{{poo}}", question_name = "ugly_question")
376
372
  >>> j = Jobs(survey = Survey(questions=[q]))
377
373
  >>> with warnings.catch_warnings(record=True) as w:
378
- ... j._check_parameters()
374
+ ... j._check_parameters(warn = True)
379
375
  ... assert len(w) == 1
380
376
  ... assert issubclass(w[-1].category, UserWarning)
381
377
  ... assert "The following parameters are in the survey but not in the scenarios" in str(w[-1].message)
@@ -413,15 +409,13 @@ class Jobs(Base):
413
409
  progress_bar: bool = False,
414
410
  stop_on_exception: bool = False,
415
411
  cache: Union[Cache, bool] = None,
416
- remote: bool = (
417
- False if os.getenv("DEFAULT_RUN_MODE", "local") == "local" else True
418
- ),
419
412
  check_api_keys: bool = False,
420
413
  sidecar_model: Optional[LanguageModel] = None,
421
414
  batch_mode: Optional[bool] = None,
422
415
  verbose: bool = False,
423
416
  print_exceptions=True,
424
417
  remote_cache_description: Optional[str] = None,
418
+ remote_inference_description: Optional[str] = None,
425
419
  ) -> Results:
426
420
  """
427
421
  Runs the Job: conducts Interviews and returns their results.
@@ -431,11 +425,11 @@ class Jobs(Base):
431
425
  :param progress_bar: shows a progress bar
432
426
  :param stop_on_exception: stops the job if an exception is raised
433
427
  :param cache: a cache object to store results
434
- :param remote: run the job remotely
435
428
  :param check_api_keys: check if the API keys are valid
436
429
  :param batch_mode: run the job in batch mode i.e., no expecation of interaction with the user
437
430
  :param verbose: prints messages
438
431
  :param remote_cache_description: specifies a description for this group of entries in the remote cache
432
+ :param remote_inference_description: specifies a description for the remote inference job
439
433
  """
440
434
  from edsl.coop.coop import Coop
441
435
 
@@ -446,21 +440,33 @@ class Jobs(Base):
446
440
  "Batch mode is deprecated. Please update your code to not include 'batch_mode' in the 'run' method."
447
441
  )
448
442
 
449
- self.remote = remote
450
443
  self.verbose = verbose
451
444
 
452
445
  try:
453
446
  coop = Coop()
454
- remote_cache = coop.edsl_settings["remote_caching"]
447
+ user_edsl_settings = coop.edsl_settings
448
+ remote_cache = user_edsl_settings["remote_caching"]
449
+ remote_inference = user_edsl_settings["remote_inference"]
455
450
  except Exception:
456
451
  remote_cache = False
452
+ remote_inference = False
457
453
 
458
- if self.remote:
459
- ## TODO: This should be a coop check
460
- if os.getenv("EXPECTED_PARROT_API_KEY", None) is None:
461
- raise MissingRemoteInferenceError()
454
+ if remote_inference:
455
+ self._output("Remote inference activated. Sending job to server...")
456
+ if remote_cache:
457
+ self._output(
458
+ "Remote caching activated. The remote cache will be used for this job."
459
+ )
462
460
 
463
- if not self.remote:
461
+ remote_job_data = coop.remote_inference_create(
462
+ self,
463
+ description=remote_inference_description,
464
+ status="queued",
465
+ )
466
+ self._output("Job sent!")
467
+ self._output(remote_job_data)
468
+ return remote_job_data
469
+ else:
464
470
  if check_api_keys:
465
471
  for model in self.models + [Model()]:
466
472
  if not model.has_valid_api_key():
@@ -44,6 +44,7 @@ class InterviewTaskBuildingMixin:
44
44
  scenario=self.scenario,
45
45
  model=self.model,
46
46
  debug=debug,
47
+ survey=self.survey,
47
48
  memory_plan=self.survey.memory_plan,
48
49
  current_answers=self.answers,
49
50
  iteration=self.iteration,
@@ -56,6 +56,37 @@ class Notebook(Base):
56
56
 
57
57
  self.name = name or self.default_name
58
58
 
59
+ @classmethod
60
+ def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
61
+ # Read the script file
62
+ with open(path, "r") as script_file:
63
+ script_content = script_file.read()
64
+
65
+ # Create a new Jupyter notebook
66
+ nb = nbformat.v4.new_notebook()
67
+
68
+ # Add the script content to the first cell
69
+ first_cell = nbformat.v4.new_code_cell(script_content)
70
+ nb.cells.append(first_cell)
71
+
72
+ # Create a Notebook instance with the notebook data
73
+ notebook_instance = cls(nb)
74
+
75
+ return notebook_instance
76
+
77
+ @classmethod
78
+ def from_current_script(cls) -> "Notebook":
79
+ import inspect
80
+ import os
81
+
82
+ # Get the path to the current file
83
+ current_frame = inspect.currentframe()
84
+ caller_frame = inspect.getouterframes(current_frame, 2)
85
+ current_file_path = os.path.abspath(caller_frame[1].filename)
86
+
87
+ # Use from_script to create the notebook
88
+ return cls.from_script(current_file_path)
89
+
59
90
  def __eq__(self, other):
60
91
  """
61
92
  Check if two Notebooks are equal.
edsl/prompts/Prompt.py CHANGED
@@ -1,12 +1,16 @@
1
- """Class for creating prompts to be used in a survey."""
2
-
3
1
  from __future__ import annotations
4
2
  from typing import Optional
5
3
  from abc import ABC
6
4
  from typing import Any, List
7
5
 
8
6
  from rich.table import Table
9
- from jinja2 import Template, Environment, meta, TemplateSyntaxError
7
+ from jinja2 import Template, Environment, meta, TemplateSyntaxError, Undefined
8
+
9
+
10
+ class PreserveUndefined(Undefined):
11
+ def __str__(self):
12
+ return "{{ " + self._undefined_name + " }}"
13
+
10
14
 
11
15
  from edsl.exceptions.prompts import TemplateRenderError
12
16
  from edsl.prompts.prompt_config import (
@@ -35,6 +39,10 @@ class PromptBase(
35
39
 
36
40
  return data_to_html(self.to_dict())
37
41
 
42
+ def __len__(self):
43
+ """Return the length of the prompt text."""
44
+ return len(self.text)
45
+
38
46
  @classmethod
39
47
  def prompt_attributes(cls) -> List[str]:
40
48
  """Return the prompt class attributes."""
@@ -75,10 +83,10 @@ class PromptBase(
75
83
  >>> p = Prompt("Hello, {{person}}")
76
84
  >>> p2 = Prompt("How are you?")
77
85
  >>> p + p2
78
- Prompt(text='Hello, {{person}}How are you?')
86
+ Prompt(text=\"""Hello, {{person}}How are you?\""")
79
87
 
80
88
  >>> p + "How are you?"
81
- Prompt(text='Hello, {{person}}How are you?')
89
+ Prompt(text=\"""Hello, {{person}}How are you?\""")
82
90
  """
83
91
  if isinstance(other_prompt, str):
84
92
  return self.__class__(self.text + other_prompt)
@@ -114,7 +122,7 @@ class PromptBase(
114
122
  Example:
115
123
  >>> p = Prompt("Hello, {{person}}")
116
124
  >>> p
117
- Prompt(text='Hello, {{person}}')
125
+ Prompt(text=\"""Hello, {{person}}\""")
118
126
  """
119
127
  return f'Prompt(text="""{self.text}""")'
120
128
 
@@ -137,7 +145,7 @@ class PromptBase(
137
145
  :param template: The template to find the variables in.
138
146
 
139
147
  """
140
- env = Environment()
148
+ env = Environment(undefined=PreserveUndefined)
141
149
  ast = env.parse(template)
142
150
  return list(meta.find_undeclared_variables(ast))
143
151
 
@@ -186,13 +194,16 @@ class PromptBase(
186
194
 
187
195
  >>> p = Prompt("Hello, {{person}}")
188
196
  >>> p.render({"person": "John"})
189
- 'Hello, John'
197
+ Prompt(text=\"""Hello, John\""")
190
198
 
191
199
  >>> p.render({"person": "Mr. {{last_name}}", "last_name": "Horton"})
192
- 'Hello, Mr. Horton'
200
+ Prompt(text=\"""Hello, Mr. Horton\""")
193
201
 
194
202
  >>> p.render({"person": "Mr. {{last_name}}", "last_name": "Ho{{letter}}ton"}, max_nesting = 1)
195
- 'Hello, Mr. Horton'
203
+ Prompt(text=\"""Hello, Mr. Ho{{ letter }}ton\""")
204
+
205
+ >>> p.render({"person": "Mr. {{last_name}}"})
206
+ Prompt(text=\"""Hello, Mr. {{ last_name }}\""")
196
207
  """
197
208
  new_text = self._render(
198
209
  self.text, primary_replacement, **additional_replacements
@@ -216,12 +227,13 @@ class PromptBase(
216
227
  >>> codebook = {"age": "Age"}
217
228
  >>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
218
229
  >>> p.render({"name": "John", "age": 44}, codebook=codebook)
219
- 'You are an agent named John. Age: 44'
230
+ Prompt(text=\"""You are an agent named John. Age: 44\""")
220
231
  """
232
+ env = Environment(undefined=PreserveUndefined)
221
233
  try:
222
234
  previous_text = None
223
235
  for _ in range(MAX_NESTING):
224
- rendered_text = Template(text).render(
236
+ rendered_text = env.from_string(text).render(
225
237
  primary_replacement, **additional_replacements
226
238
  )
227
239
  if rendered_text == previous_text:
@@ -258,7 +270,7 @@ class PromptBase(
258
270
  >>> p = Prompt("Hello, {{person}}")
259
271
  >>> p2 = Prompt.from_dict(p.to_dict())
260
272
  >>> p2
261
- Prompt(text='Hello, {{person}}')
273
+ Prompt(text=\"""Hello, {{person}}\""")
262
274
 
263
275
  """
264
276
  class_name = data["class_name"]
@@ -290,6 +302,12 @@ class Prompt(PromptBase):
290
302
  component_type = ComponentTypes.GENERIC
291
303
 
292
304
 
305
+ if __name__ == "__main__":
306
+ print("Running doctests...")
307
+ import doctest
308
+
309
+ doctest.testmod()
310
+
293
311
  from edsl.prompts.library.question_multiple_choice import *
294
312
  from edsl.prompts.library.agent_instructions import *
295
313
  from edsl.prompts.library.agent_persona import *
@@ -302,9 +320,3 @@ from edsl.prompts.library.question_numerical import *
302
320
  from edsl.prompts.library.question_rank import *
303
321
  from edsl.prompts.library.question_extract import *
304
322
  from edsl.prompts.library.question_list import *
305
-
306
-
307
- if __name__ == "__main__":
308
- import doctest
309
-
310
- doctest.testmod()
@@ -173,15 +173,16 @@ class QuestionBase(
173
173
  def add_model_instructions(
174
174
  self, *, instructions: str, model: Optional[str] = None
175
175
  ) -> None:
176
- """Add model-specific instructions for the question.
176
+ """Add model-specific instructions for the question that override the default instructions.
177
177
 
178
178
  :param instructions: The instructions to add. This is typically a jinja2 template.
179
179
  :param model: The language model for this instruction.
180
180
 
181
181
  >>> from edsl.questions import QuestionFreeText
182
182
  >>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite color?")
183
- >>> q.add_model_instructions(instructions = "Answer in valid JSON like so {'answer': 'comment: <>}", model = "gpt3")
184
-
183
+ >>> q.add_model_instructions(instructions = "{{question_text}}. Answer in valid JSON like so {'answer': 'comment: <>}", model = "gpt3")
184
+ >>> q.get_instructions(model = "gpt3")
185
+ Prompt(text=\"""{{question_text}}. Answer in valid JSON like so {'answer': 'comment: <>}\""")
185
186
  """
186
187
  from edsl import Model
187
188
 
@@ -201,6 +202,13 @@ class QuestionBase(
201
202
  """Get the mathcing question-answering instructions for the question.
202
203
 
203
204
  :param model: The language model to use.
205
+
206
+ >>> from edsl import QuestionFreeText
207
+ >>> QuestionFreeText.example().get_instructions()
208
+ Prompt(text=\"""You are being asked the following question: {{question_text}}
209
+ Return a valid JSON formatted like this:
210
+ {"answer": "<put free text answer here>"}
211
+ \""")
204
212
  """
205
213
  from edsl.prompts.Prompt import Prompt
206
214
 
@@ -293,7 +301,16 @@ class QuestionBase(
293
301
  print_json(json.dumps(self.to_dict()))
294
302
 
295
303
  def __call__(self, just_answer=True, model=None, agent=None, **kwargs):
296
- """Call the question."""
304
+ """Call the question.
305
+
306
+ >>> from edsl.language_models import LanguageModel
307
+ >>> m = LanguageModel.example(canned_response = "Yo, what's up?", test_model = True)
308
+ >>> from edsl import QuestionFreeText
309
+ >>> q = QuestionFreeText(question_name = "color", question_text = "What is your favorite color?")
310
+ >>> q(model = m)
311
+ "Yo, what's up?"
312
+
313
+ """
297
314
  survey = self.to_survey()
298
315
  results = survey(model=model, agent=agent, **kwargs)
299
316
  if just_answer:
@@ -304,7 +321,6 @@ class QuestionBase(
304
321
  async def run_async(self, just_answer=True, model=None, agent=None, **kwargs):
305
322
  """Call the question."""
306
323
  survey = self.to_survey()
307
- ## asyncio.run(survey.async_call());
308
324
  results = await survey.run_async(model=model, agent=agent, **kwargs)
309
325
  if just_answer:
310
326
  return results.select(f"answer.{self.question_name}").first()
@@ -383,29 +399,34 @@ class QuestionBase(
383
399
  s = Survey([self, other])
384
400
  return s
385
401
 
386
- def to_survey(self):
402
+ def to_survey(self) -> "Survey":
387
403
  """Turn a single question into a survey."""
388
404
  from edsl.surveys.Survey import Survey
389
405
 
390
406
  s = Survey([self])
391
407
  return s
392
408
 
393
- def run(self, *args, **kwargs):
409
+ def run(self, *args, **kwargs) -> "Results":
394
410
  """Turn a single question into a survey and run it."""
395
411
  from edsl.surveys.Survey import Survey
396
412
 
397
413
  s = self.to_survey()
398
414
  return s.run(*args, **kwargs)
399
415
 
400
- def by(self, *args):
401
- """Turn a single question into a survey and run it."""
416
+ def by(self, *args) -> "Jobs":
417
+ """Turn a single question into a survey and then a Job."""
402
418
  from edsl.surveys.Survey import Survey
403
419
 
404
420
  s = Survey([self])
405
421
  return s.by(*args)
406
422
 
407
- def human_readable(self):
408
- """Print the question in a human readable format."""
423
+ def human_readable(self) -> str:
424
+ """Print the question in a human readable format.
425
+
426
+ >>> from edsl.questions import QuestionFreeText
427
+ >>> QuestionFreeText.example().human_readable()
428
+ 'Question Type: free_text\\nQuestion: How are you?'
429
+ """
409
430
  lines = []
410
431
  lines.append(f"Question Type: {self.question_type}")
411
432
  lines.append(f"Question: {self.question_text}")
@@ -1,10 +1,10 @@
1
1
  """This module provides a factory class for creating question objects."""
2
2
 
3
3
  import textwrap
4
- from typing import Union
4
+ from uuid import UUID
5
+ from typing import Any, Optional, Union
6
+
5
7
 
6
- from edsl.exceptions import QuestionSerializationError
7
- from edsl.exceptions import QuestionCreationValidationError
8
8
  from edsl.questions.QuestionBase import RegisterQuestionsMeta
9
9
 
10
10
 
@@ -60,46 +60,35 @@ class Question(metaclass=Meta):
60
60
  return q.example()
61
61
 
62
62
  @classmethod
63
- def pull(cls, id_or_url: str):
63
+ def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
64
64
  """Pull the object from coop."""
65
65
  from edsl.coop import Coop
66
66
 
67
- c = Coop()
68
- if c.url in id_or_url:
69
- id = id_or_url.split("/")[-1]
70
- else:
71
- id = id_or_url
72
- from edsl.questions.QuestionBase import QuestionBase
73
-
74
- return c._get_base(QuestionBase, id)
67
+ coop = Coop()
68
+ return coop.get(uuid, url, "question")
75
69
 
76
70
  @classmethod
77
- def delete(cls, id_or_url: str):
71
+ def delete(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
78
72
  """Delete the object from coop."""
79
73
  from edsl.coop import Coop
80
74
 
81
- c = Coop()
82
- if c.url in id_or_url:
83
- id = id_or_url.split("/")[-1]
84
- else:
85
- id = id_or_url
86
- from edsl.questions.QuestionBase import QuestionBase
87
-
88
- return c._delete_base(QuestionBase, id)
75
+ coop = Coop()
76
+ return coop.delete(uuid, url)
89
77
 
90
78
  @classmethod
91
- def update(cls, id_or_url: str, visibility: str):
92
- """Update the object on coop."""
79
+ def patch(
80
+ cls,
81
+ uuid: Optional[Union[str, UUID]] = None,
82
+ url: Optional[str] = None,
83
+ description: Optional[str] = None,
84
+ value: Optional[Any] = None,
85
+ visibility: Optional[str] = None,
86
+ ):
87
+ """Patch the object on coop."""
93
88
  from edsl.coop import Coop
94
89
 
95
- c = Coop()
96
- if c.url in id_or_url:
97
- id = id_or_url.split("/")[-1]
98
- else:
99
- id = id_or_url
100
- from edsl.questions.QuestionBase import QuestionBase
101
-
102
- return c._update_base(QuestionBase, id, visibility)
90
+ coop = Coop()
91
+ return coop.patch(uuid, url, description, value, visibility)
103
92
 
104
93
  @classmethod
105
94
  def available(cls, show_class_names: bool = False) -> Union[list, dict]:
@@ -8,5 +8,5 @@ class Settings:
8
8
  MAX_EXPRESSION_CONSTRAINT_LENGTH = 1000
9
9
  MAX_NUM_OPTIONS = 200
10
10
  MIN_NUM_OPTIONS = 2
11
- MAX_OPTION_LENGTH = 1000
11
+ MAX_OPTION_LENGTH = 10000
12
12
  MAX_QUESTION_LENGTH = 100000
edsl/results/Dataset.py CHANGED
@@ -78,6 +78,28 @@ class Dataset(UserList, ResultsExportMixin):
78
78
 
79
79
  return get_values(self.data[0])[0]
80
80
 
81
+ def select(self, *keys):
82
+ """Return a new dataset with only the selected keys.
83
+
84
+ :param keys: The keys to select.
85
+
86
+ >>> d = Dataset([{'a.b':[1,2,3,4]}, {'c.d':[5,6,7,8]}])
87
+ >>> d.select('a.b')
88
+ Dataset([{'a.b': [1, 2, 3, 4]}])
89
+
90
+ >>> d.select('a.b', 'c.d')
91
+ Dataset([{'a.b': [1, 2, 3, 4]}, {'c.d': [5, 6, 7, 8]}])
92
+ """
93
+ if isinstance(keys, str):
94
+ keys = [keys]
95
+
96
+ new_data = []
97
+ for observation in self.data:
98
+ observation_key = list(observation.keys())[0]
99
+ if observation_key in keys:
100
+ new_data.append(observation)
101
+ return Dataset(new_data)
102
+
81
103
  def _repr_html_(self) -> str:
82
104
  """Return an HTML representation of the dataset."""
83
105
  from edsl.utilities.utilities import data_to_html
@@ -223,6 +245,15 @@ class Dataset(UserList, ResultsExportMixin):
223
245
 
224
246
  return Dataset(new_data)
225
247
 
248
+ @classmethod
249
+ def example(self):
250
+ """Return an example dataset.
251
+
252
+ >>> Dataset.example()
253
+ Dataset([{'a': [1, 2, 3, 4]}, {'b': [4, 3, 2, 1]}])
254
+ """
255
+ return Dataset([{"a": [1, 2, 3, 4]}, {"b": [4, 3, 2, 1]}])
256
+
226
257
 
227
258
  if __name__ == "__main__":
228
259
  import doctest
edsl/results/Results.py CHANGED
@@ -165,13 +165,7 @@ class Results(UserList, Mixins, Base):
165
165
  )
166
166
 
167
167
  def __repr__(self) -> str:
168
- # return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
169
- return f"""Results object
170
- Size: {len(self.data)}.
171
- Survey questions: {[q.question_name for q in self.survey.questions]}.
172
- Created columns: {self.created_columns}
173
- Hash: {hash(self)}
174
- """
168
+ return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
175
169
 
176
170
  def _repr_html_(self) -> str:
177
171
  json_str = json.dumps(self.to_dict()["data"], indent=4)
@@ -759,7 +753,10 @@ class Results(UserList, Mixins, Base):
759
753
 
760
754
  def sort_by(self, *columns: str, reverse: bool = False) -> Results:
761
755
  import warnings
762
- warnings.warn("sort_by is deprecated. Use order_by instead.", DeprecationWarning)
756
+
757
+ warnings.warn(
758
+ "sort_by is deprecated. Use order_by instead.", DeprecationWarning
759
+ )
763
760
  return self.order_by(*columns, reverse=reverse)
764
761
 
765
762
  def order_by(self, *columns: str, reverse: bool = False) -> Results:
@@ -800,6 +797,7 @@ class Results(UserList, Mixins, Base):
800
797
  │ Great │
801
798
  └──────────────┘
802
799
  """
800
+
803
801
  def to_numeric_if_possible(v):
804
802
  try:
805
803
  return float(v)
@@ -13,7 +13,10 @@ class ResultsToolsMixin:
13
13
  progress_bar=False,
14
14
  print_exceptions=False,
15
15
  ) -> list:
16
- values = self.shuffle(seed=seed).select(field).to_list()[:max_values]
16
+ values = [
17
+ str(txt)[:1000]
18
+ for txt in self.shuffle(seed=seed).select(field).to_list()[:max_values]
19
+ ]
17
20
  from edsl import ScenarioList
18
21
 
19
22
  q = QuestionList(
@@ -119,7 +119,7 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
119
119
 
120
120
  return ScenarioList(random.sample(self.data, n))
121
121
 
122
- def expand(self, expand_field: str, number_field = False) -> ScenarioList:
122
+ def expand(self, expand_field: str, number_field=False) -> ScenarioList:
123
123
  """Expand the ScenarioList by a field.
124
124
 
125
125
  Example:
@@ -137,7 +137,7 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
137
137
  new_scenario = scenario.copy()
138
138
  new_scenario[expand_field] = value
139
139
  if number_field:
140
- new_scenario[expand_field + '_number'] = index + 1
140
+ new_scenario[expand_field + "_number"] = index + 1
141
141
  new_scenarios.append(new_scenario)
142
142
  return ScenarioList(new_scenarios)
143
143
 
@@ -192,7 +192,7 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
192
192
 
193
193
  def get_sort_key(scenario: Any) -> tuple:
194
194
  return tuple(scenario[field] for field in fields)
195
-
195
+
196
196
  return ScenarioList(sorted(self, key=get_sort_key, reverse=reverse))
197
197
 
198
198
  def filter(self, expression: str) -> ScenarioList:
@@ -343,6 +343,20 @@ class ScenarioList(Base, UserList, ScenarioListPdfMixin, ResultsExportMixin):
343
343
  """
344
344
  return cls([Scenario(row) for row in df.to_dict(orient="records")])
345
345
 
346
+ def to_key_value(self, field, value=None) -> Union[dict, set]:
347
+ """Return the set of values in the field.
348
+
349
+ Example:
350
+
351
+ >>> s = ScenarioList([Scenario({'name': 'Alice'}), Scenario({'name': 'Bob'})])
352
+ >>> s.to_key_value('name') == {'Alice', 'Bob'}
353
+ True
354
+ """
355
+ if value is None:
356
+ return {scenario[field] for scenario in self}
357
+ else:
358
+ return {scenario[field]: scenario[value] for scenario in self}
359
+
346
360
  @classmethod
347
361
  def from_csv(cls, filename: str) -> ScenarioList:
348
362
  """Create a ScenarioList from a CSV file.
edsl/study/Study.py CHANGED
@@ -472,18 +472,12 @@ class Study:
472
472
  coop.create(self, description=self.description)
473
473
 
474
474
  @classmethod
475
- def pull(cls, id_or_url: Union[str, UUID], exec_profile=None):
475
+ def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
476
476
  """Pull the object from coop."""
477
477
  from edsl.coop import Coop
478
478
 
479
- if id_or_url.startswith("http"):
480
- uuid_value = id_or_url.split("/")[-1]
481
- else:
482
- uuid_value = id_or_url
483
-
484
- c = Coop()
485
-
486
- return c._get_base(cls, uuid_value, exec_profile=exec_profile)
479
+ coop = Coop()
480
+ return coop.get(uuid, url, "study")
487
481
 
488
482
  def __repr__(self):
489
483
  return f"""Study(name = "{self.name}", description = "{self.description}", objects = {self.objects}, cache = {self.cache}, filename = "{self.filename}", coop = {self.coop}, use_study_cache = {self.use_study_cache}, overwrite_on_change = {self.overwrite_on_change})"""