edsl 0.1.49__py3-none-any.whl → 0.1.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +124 -53
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +21 -21
- edsl/agents/agent_list.py +2 -5
- edsl/agents/exceptions.py +119 -5
- edsl/base/__init__.py +10 -35
- edsl/base/base_class.py +71 -36
- edsl/base/base_exception.py +204 -0
- edsl/base/data_transfer_models.py +1 -1
- edsl/base/exceptions.py +94 -0
- edsl/buckets/__init__.py +15 -1
- edsl/buckets/bucket_collection.py +3 -4
- edsl/buckets/exceptions.py +107 -0
- edsl/buckets/model_buckets.py +1 -2
- edsl/buckets/token_bucket.py +11 -6
- edsl/buckets/token_bucket_api.py +27 -12
- edsl/buckets/token_bucket_client.py +9 -7
- edsl/caching/cache.py +12 -4
- edsl/caching/cache_entry.py +10 -9
- edsl/caching/exceptions.py +113 -7
- edsl/caching/remote_cache_sync.py +6 -7
- edsl/caching/sql_dict.py +20 -14
- edsl/cli.py +43 -0
- edsl/config/__init__.py +1 -1
- edsl/config/config_class.py +32 -6
- edsl/conversation/Conversation.py +8 -4
- edsl/conversation/car_buying.py +1 -3
- edsl/conversation/exceptions.py +58 -0
- edsl/conversation/mug_negotiation.py +2 -8
- edsl/coop/__init__.py +28 -6
- edsl/coop/coop.py +120 -29
- edsl/coop/coop_functions.py +1 -1
- edsl/coop/ep_key_handling.py +1 -1
- edsl/coop/exceptions.py +188 -9
- edsl/coop/price_fetcher.py +5 -8
- edsl/coop/utils.py +4 -6
- edsl/dataset/__init__.py +5 -4
- edsl/dataset/dataset.py +177 -86
- edsl/dataset/dataset_operations_mixin.py +98 -76
- edsl/dataset/dataset_tree.py +11 -7
- edsl/dataset/display/table_display.py +0 -2
- edsl/dataset/display/table_renderers.py +6 -4
- edsl/dataset/exceptions.py +125 -0
- edsl/dataset/file_exports.py +18 -11
- edsl/dataset/r/ggplot.py +13 -6
- edsl/display/__init__.py +27 -0
- edsl/display/core.py +147 -0
- edsl/display/plugin.py +189 -0
- edsl/display/utils.py +52 -0
- edsl/inference_services/__init__.py +9 -1
- edsl/inference_services/available_model_cache_handler.py +1 -1
- edsl/inference_services/available_model_fetcher.py +5 -6
- edsl/inference_services/data_structures.py +10 -7
- edsl/inference_services/exceptions.py +132 -1
- edsl/inference_services/inference_service_abc.py +2 -2
- edsl/inference_services/inference_services_collection.py +2 -6
- edsl/inference_services/registry.py +4 -3
- edsl/inference_services/service_availability.py +4 -3
- edsl/inference_services/services/anthropic_service.py +4 -1
- edsl/inference_services/services/aws_bedrock.py +13 -12
- edsl/inference_services/services/azure_ai.py +12 -10
- edsl/inference_services/services/deep_infra_service.py +1 -4
- edsl/inference_services/services/deep_seek_service.py +1 -5
- edsl/inference_services/services/google_service.py +7 -3
- edsl/inference_services/services/groq_service.py +1 -1
- edsl/inference_services/services/mistral_ai_service.py +4 -2
- edsl/inference_services/services/ollama_service.py +1 -1
- edsl/inference_services/services/open_ai_service.py +7 -5
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +8 -7
- edsl/inference_services/services/together_ai_service.py +2 -3
- edsl/inference_services/services/xai_service.py +1 -1
- edsl/instructions/__init__.py +1 -1
- edsl/instructions/change_instruction.py +7 -5
- edsl/instructions/exceptions.py +61 -0
- edsl/instructions/instruction.py +6 -2
- edsl/instructions/instruction_collection.py +6 -4
- edsl/instructions/instruction_handler.py +12 -15
- edsl/interviews/ReportErrors.py +0 -3
- edsl/interviews/__init__.py +9 -2
- edsl/interviews/answering_function.py +11 -13
- edsl/interviews/exception_tracking.py +15 -8
- edsl/interviews/exceptions.py +79 -0
- edsl/interviews/interview.py +33 -30
- edsl/interviews/interview_status_dictionary.py +4 -2
- edsl/interviews/interview_status_log.py +2 -1
- edsl/interviews/interview_task_manager.py +5 -5
- edsl/interviews/request_token_estimator.py +5 -2
- edsl/interviews/statistics.py +3 -4
- edsl/invigilators/__init__.py +7 -1
- edsl/invigilators/exceptions.py +79 -0
- edsl/invigilators/invigilator_base.py +0 -1
- edsl/invigilators/invigilators.py +9 -13
- edsl/invigilators/prompt_constructor.py +1 -5
- edsl/invigilators/prompt_helpers.py +8 -4
- edsl/invigilators/question_instructions_prompt_builder.py +1 -1
- edsl/invigilators/question_option_processor.py +9 -5
- edsl/invigilators/question_template_replacements_builder.py +3 -2
- edsl/jobs/__init__.py +42 -5
- edsl/jobs/async_interview_runner.py +25 -23
- edsl/jobs/check_survey_scenario_compatibility.py +11 -10
- edsl/jobs/data_structures.py +8 -5
- edsl/jobs/exceptions.py +177 -8
- edsl/jobs/fetch_invigilator.py +1 -1
- edsl/jobs/jobs.py +74 -69
- edsl/jobs/jobs_checks.py +6 -7
- edsl/jobs/jobs_component_constructor.py +4 -4
- edsl/jobs/jobs_pricing_estimation.py +4 -3
- edsl/jobs/jobs_remote_inference_logger.py +5 -4
- edsl/jobs/jobs_runner_asyncio.py +3 -4
- edsl/jobs/jobs_runner_status.py +8 -9
- edsl/jobs/remote_inference.py +27 -24
- edsl/jobs/results_exceptions_handler.py +10 -7
- edsl/key_management/__init__.py +3 -1
- edsl/key_management/exceptions.py +62 -0
- edsl/key_management/key_lookup.py +1 -1
- edsl/key_management/key_lookup_builder.py +37 -14
- edsl/key_management/key_lookup_collection.py +2 -0
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/exceptions.py +302 -14
- edsl/language_models/language_model.py +9 -8
- edsl/language_models/model.py +4 -4
- edsl/language_models/model_list.py +1 -1
- edsl/language_models/price_manager.py +1 -1
- edsl/language_models/raw_response_handler.py +14 -9
- edsl/language_models/registry.py +17 -21
- edsl/language_models/repair.py +0 -6
- edsl/language_models/unused/fake_openai_service.py +0 -1
- edsl/load_plugins.py +69 -0
- edsl/logger.py +146 -0
- edsl/notebooks/__init__.py +24 -1
- edsl/notebooks/exceptions.py +82 -0
- edsl/notebooks/notebook.py +7 -3
- edsl/notebooks/notebook_to_latex.py +1 -2
- edsl/plugins/__init__.py +63 -0
- edsl/plugins/built_in/export_example.py +50 -0
- edsl/plugins/built_in/pig_latin.py +67 -0
- edsl/plugins/cli.py +372 -0
- edsl/plugins/cli_typer.py +283 -0
- edsl/plugins/exceptions.py +31 -0
- edsl/plugins/hookspec.py +51 -0
- edsl/plugins/plugin_host.py +128 -0
- edsl/plugins/plugin_manager.py +633 -0
- edsl/plugins/plugins_registry.py +168 -0
- edsl/prompts/__init__.py +24 -1
- edsl/prompts/exceptions.py +107 -5
- edsl/prompts/prompt.py +15 -7
- edsl/questions/HTMLQuestion.py +5 -11
- edsl/questions/Quick.py +0 -1
- edsl/questions/__init__.py +6 -4
- edsl/questions/answer_validator_mixin.py +318 -323
- edsl/questions/compose_questions.py +3 -3
- edsl/questions/descriptors.py +11 -50
- edsl/questions/exceptions.py +278 -22
- edsl/questions/loop_processor.py +7 -5
- edsl/questions/prompt_templates/question_list.jinja +3 -0
- edsl/questions/question_base.py +46 -19
- edsl/questions/question_base_gen_mixin.py +2 -2
- edsl/questions/question_base_prompts_mixin.py +13 -7
- edsl/questions/question_budget.py +503 -98
- edsl/questions/question_check_box.py +660 -160
- edsl/questions/question_dict.py +345 -194
- edsl/questions/question_extract.py +401 -61
- edsl/questions/question_free_text.py +80 -14
- edsl/questions/question_functional.py +119 -9
- edsl/questions/{derived/question_likert_five.py → question_likert_five.py} +2 -2
- edsl/questions/{derived/question_linear_scale.py → question_linear_scale.py} +3 -4
- edsl/questions/question_list.py +275 -28
- edsl/questions/question_matrix.py +643 -96
- edsl/questions/question_multiple_choice.py +219 -51
- edsl/questions/question_numerical.py +361 -32
- edsl/questions/question_rank.py +401 -124
- edsl/questions/question_registry.py +7 -5
- edsl/questions/{derived/question_top_k.py → question_top_k.py} +3 -3
- edsl/questions/{derived/question_yes_no.py → question_yes_no.py} +3 -4
- edsl/questions/register_questions_meta.py +2 -2
- edsl/questions/response_validator_abc.py +13 -15
- edsl/questions/response_validator_factory.py +10 -12
- edsl/questions/templates/dict/answering_instructions.jinja +1 -0
- edsl/questions/templates/rank/question_presentation.jinja +1 -1
- edsl/results/__init__.py +1 -1
- edsl/results/exceptions.py +141 -7
- edsl/results/report.py +1 -2
- edsl/results/result.py +11 -9
- edsl/results/results.py +480 -321
- edsl/results/results_selector.py +8 -4
- edsl/scenarios/PdfExtractor.py +2 -2
- edsl/scenarios/construct_download_link.py +69 -35
- edsl/scenarios/directory_scanner.py +33 -14
- edsl/scenarios/document_chunker.py +1 -1
- edsl/scenarios/exceptions.py +238 -14
- edsl/scenarios/file_methods.py +1 -1
- edsl/scenarios/file_store.py +7 -3
- edsl/scenarios/handlers/__init__.py +17 -0
- edsl/scenarios/handlers/docx_file_store.py +0 -5
- edsl/scenarios/handlers/pdf_file_store.py +0 -1
- edsl/scenarios/handlers/pptx_file_store.py +0 -5
- edsl/scenarios/handlers/py_file_store.py +0 -1
- edsl/scenarios/handlers/sql_file_store.py +1 -4
- edsl/scenarios/handlers/sqlite_file_store.py +0 -1
- edsl/scenarios/handlers/txt_file_store.py +1 -1
- edsl/scenarios/scenario.py +1 -3
- edsl/scenarios/scenario_list.py +179 -27
- edsl/scenarios/scenario_list_pdf_tools.py +1 -0
- edsl/scenarios/scenario_selector.py +0 -1
- edsl/surveys/__init__.py +3 -4
- edsl/surveys/dag/__init__.py +4 -2
- edsl/surveys/descriptors.py +1 -1
- edsl/surveys/edit_survey.py +1 -0
- edsl/surveys/exceptions.py +165 -9
- edsl/surveys/memory/__init__.py +5 -3
- edsl/surveys/memory/memory_management.py +1 -0
- edsl/surveys/memory/memory_plan.py +6 -15
- edsl/surveys/rules/__init__.py +5 -3
- edsl/surveys/rules/rule.py +1 -2
- edsl/surveys/rules/rule_collection.py +1 -1
- edsl/surveys/survey.py +12 -24
- edsl/surveys/survey_css.py +3 -3
- edsl/surveys/survey_export.py +6 -3
- edsl/surveys/survey_flow_visualization.py +10 -1
- edsl/surveys/survey_simulator.py +2 -1
- edsl/tasks/__init__.py +23 -1
- edsl/tasks/exceptions.py +72 -0
- edsl/tasks/question_task_creator.py +3 -3
- edsl/tasks/task_creators.py +1 -3
- edsl/tasks/task_history.py +8 -10
- edsl/tasks/task_status_log.py +1 -2
- edsl/tokens/__init__.py +29 -1
- edsl/tokens/exceptions.py +37 -0
- edsl/tokens/interview_token_usage.py +3 -2
- edsl/tokens/token_usage.py +4 -3
- edsl/utilities/__init__.py +21 -1
- edsl/utilities/decorators.py +1 -2
- edsl/utilities/markdown_to_docx.py +2 -2
- edsl/utilities/markdown_to_pdf.py +1 -1
- edsl/utilities/repair_functions.py +0 -1
- edsl/utilities/restricted_python.py +0 -1
- edsl/utilities/template_loader.py +2 -3
- edsl/utilities/utilities.py +8 -29
- {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/METADATA +32 -2
- edsl-0.1.51.dist-info/RECORD +365 -0
- edsl-0.1.51.dist-info/entry_points.txt +3 -0
- edsl/dataset/smart_objects.py +0 -96
- edsl/exceptions/BaseException.py +0 -21
- edsl/exceptions/__init__.py +0 -54
- edsl/exceptions/configuration.py +0 -16
- edsl/exceptions/general.py +0 -34
- edsl/questions/derived/__init__.py +0 -0
- edsl/study/ObjectEntry.py +0 -173
- edsl/study/ProofOfWork.py +0 -113
- edsl/study/SnapShot.py +0 -80
- edsl/study/Study.py +0 -520
- edsl/study/__init__.py +0 -6
- edsl/utilities/interface.py +0 -135
- edsl-0.1.49.dist-info/RECORD +0 -347
- {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/LICENSE +0 -0
- {edsl-0.1.49.dist-info → edsl-0.1.51.dist-info}/WHEEL +0 -0
edsl/config/config_class.py
CHANGED
@@ -3,10 +3,18 @@
|
|
3
3
|
import os
|
4
4
|
import platformdirs
|
5
5
|
from dotenv import load_dotenv, find_dotenv
|
6
|
-
from
|
7
|
-
|
8
|
-
|
9
|
-
)
|
6
|
+
from ..base import BaseException
|
7
|
+
import logging
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
|
11
|
+
class InvalidEnvironmentVariableError(BaseException):
|
12
|
+
"""Raised when an environment variable is invalid."""
|
13
|
+
pass
|
14
|
+
|
15
|
+
class MissingEnvironmentVariableError(BaseException):
|
16
|
+
"""Raised when an expected environment variable is missing."""
|
17
|
+
pass
|
10
18
|
|
11
19
|
cache_dir = platformdirs.user_cache_dir("edsl")
|
12
20
|
os.makedirs(cache_dir, exist_ok=True)
|
@@ -50,6 +58,10 @@ CONFIG_MAP = {
|
|
50
58
|
"default": "True",
|
51
59
|
"info": "This config var determines whether to fetch prices for tokens used in remote inference",
|
52
60
|
},
|
61
|
+
"EDSL_LOG_LEVEL": {
|
62
|
+
"default": "ERROR",
|
63
|
+
"info": "This config var determines the logging level for the EDSL package (DEBUG, INFO, WARNING, ERROR, CRITICAL).",
|
64
|
+
},
|
53
65
|
"EDSL_MAX_ATTEMPTS": {
|
54
66
|
"default": "5",
|
55
67
|
"info": "This config var determines the maximum number of times to retry a failed API call.",
|
@@ -86,9 +98,11 @@ class Config:
|
|
86
98
|
|
87
99
|
def __init__(self):
|
88
100
|
"""Initialize the Config class."""
|
101
|
+
logger.debug("Initializing Config class")
|
89
102
|
self._set_run_mode()
|
90
103
|
self._load_dotenv()
|
91
104
|
self._set_env_vars()
|
105
|
+
logger.info(f"Config initialized with run mode: {self.EDSL_RUN_MODE}")
|
92
106
|
|
93
107
|
def show_path_to_dot_env(self):
|
94
108
|
print(find_dotenv(usecwd=True))
|
@@ -101,7 +115,12 @@ class Config:
|
|
101
115
|
default = CONFIG_MAP.get("EDSL_RUN_MODE").get("default")
|
102
116
|
if run_mode is None:
|
103
117
|
run_mode = default
|
118
|
+
logger.debug(f"EDSL_RUN_MODE not set, using default: {default}")
|
119
|
+
else:
|
120
|
+
logger.debug(f"EDSL_RUN_MODE set to: {run_mode}")
|
121
|
+
|
104
122
|
if run_mode not in EDSL_RUN_MODES:
|
123
|
+
logger.error(f"Invalid EDSL_RUN_MODE: {run_mode}")
|
105
124
|
raise InvalidEnvironmentVariableError(
|
106
125
|
f"Value `{run_mode}` is not allowed for EDSL_RUN_MODE."
|
107
126
|
)
|
@@ -149,12 +168,19 @@ class Config:
|
|
149
168
|
"""
|
150
169
|
Returns the value of an environment variable.
|
151
170
|
"""
|
171
|
+
logger.debug(f"Getting config value for: {env_var}")
|
172
|
+
|
152
173
|
if env_var not in CONFIG_MAP:
|
174
|
+
logger.error(f"Invalid environment variable requested: {env_var}")
|
153
175
|
raise InvalidEnvironmentVariableError(f"{env_var} is not a valid env var. ")
|
154
176
|
elif env_var not in self.__dict__:
|
155
177
|
info = CONFIG_MAP[env_var].get("info")
|
178
|
+
logger.error(f"Missing environment variable: {env_var}")
|
156
179
|
raise MissingEnvironmentVariableError(f"{env_var} is not set. {info}")
|
157
|
-
|
180
|
+
|
181
|
+
value = self.__dict__.get(env_var)
|
182
|
+
logger.debug(f"Config value for {env_var}: {value}")
|
183
|
+
return value
|
158
184
|
|
159
185
|
def __iter__(self):
|
160
186
|
"""Iterate over the environment variables."""
|
@@ -174,4 +200,4 @@ class Config:
|
|
174
200
|
|
175
201
|
# Note: Python modules are singletons. As such, once this module is imported
|
176
202
|
# the same instance of it is reused across the application.
|
177
|
-
CONFIG = Config()
|
203
|
+
CONFIG = Config()
|
@@ -1,13 +1,16 @@
|
|
1
1
|
from collections import UserList
|
2
2
|
import asyncio
|
3
3
|
import inspect
|
4
|
-
from typing import Optional, Callable
|
5
|
-
from .. import
|
4
|
+
from typing import Optional, Callable, TYPE_CHECKING
|
5
|
+
from .. import QuestionFreeText, Results, AgentList, ScenarioList, Scenario, Model
|
6
6
|
from ..questions import QuestionBase
|
7
7
|
from ..results.Result import Result
|
8
8
|
from jinja2 import Template
|
9
9
|
from ..caching import Cache
|
10
10
|
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from ..language_models.model import Model
|
13
|
+
|
11
14
|
from .next_speaker_utilities import (
|
12
15
|
default_turn_taking_generator,
|
13
16
|
speaker_closure,
|
@@ -71,7 +74,7 @@ class Conversation:
|
|
71
74
|
conversation_index: Optional[int] = None,
|
72
75
|
cache=None,
|
73
76
|
disable_remote_inference=False,
|
74
|
-
default_model: Optional[
|
77
|
+
default_model: Optional[Model] = None,
|
75
78
|
):
|
76
79
|
self.disable_remote_inference = disable_remote_inference
|
77
80
|
self.per_round_message_template = per_round_message_template
|
@@ -120,7 +123,8 @@ What do you say next?"""
|
|
120
123
|
per_round_message_template
|
121
124
|
and "{{ round_message }}" not in next_statement_question.question_text
|
122
125
|
):
|
123
|
-
|
126
|
+
from .exceptions import ConversationValueError
|
127
|
+
raise ConversationValueError(
|
124
128
|
"If you pass in a per_round_message_template, you must include {{ round_message }} in the question_text."
|
125
129
|
)
|
126
130
|
|
edsl/conversation/car_buying.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
|
-
from .. import Agent, AgentList, QuestionFreeText
|
2
|
-
from .. import Cache
|
1
|
+
from .. import Agent, AgentList, QuestionFreeText, Cache, QuestionList
|
3
2
|
from .Conversation import Conversation, ConversationList
|
4
3
|
|
5
4
|
a1 = Agent(
|
@@ -46,7 +45,6 @@ q = QuestionFreeText(
|
|
46
45
|
question_name="car_brand",
|
47
46
|
)
|
48
47
|
|
49
|
-
from .. import QuestionList
|
50
48
|
|
51
49
|
q_actors = QuestionList(
|
52
50
|
question_text="""This was a conversation about buying a car: {{ transcript }}.
|
@@ -0,0 +1,58 @@
|
|
1
|
+
"""
|
2
|
+
Exceptions for the conversation module.
|
3
|
+
|
4
|
+
This module defines custom exceptions for the conversation module,
|
5
|
+
including errors for invalid participant configurations, agent interaction
|
6
|
+
failures, and conversation state errors.
|
7
|
+
"""
|
8
|
+
|
9
|
+
from ..base import BaseException
|
10
|
+
|
11
|
+
|
12
|
+
class ConversationError(BaseException):
|
13
|
+
"""
|
14
|
+
Base exception class for all conversation-related errors.
|
15
|
+
|
16
|
+
This is the parent class for all exceptions related to conversation
|
17
|
+
operations, including agent communication, turn management, and
|
18
|
+
participant configuration.
|
19
|
+
"""
|
20
|
+
relevant_doc = "https://docs.expectedparrot.com/"
|
21
|
+
|
22
|
+
|
23
|
+
class ConversationValueError(ConversationError):
|
24
|
+
"""
|
25
|
+
Exception raised when an invalid value is provided to a conversation.
|
26
|
+
|
27
|
+
This exception occurs when attempting to create or modify a conversation
|
28
|
+
with invalid values, such as:
|
29
|
+
- Invalid participant configurations
|
30
|
+
- Inappropriate agent parameters
|
31
|
+
- Incompatible conversation settings
|
32
|
+
|
33
|
+
Examples:
|
34
|
+
```python
|
35
|
+
# Attempting to add an invalid participant to a conversation
|
36
|
+
conversation.add_participant(None) # Raises ConversationValueError
|
37
|
+
```
|
38
|
+
"""
|
39
|
+
relevant_doc = "https://docs.expectedparrot.com/"
|
40
|
+
|
41
|
+
|
42
|
+
class ConversationStateError(ConversationError):
|
43
|
+
"""
|
44
|
+
Exception raised when the conversation is in an invalid state.
|
45
|
+
|
46
|
+
This exception occurs when attempting to perform an operation that
|
47
|
+
is incompatible with the current state of the conversation, such as:
|
48
|
+
- Ending a conversation that hasn't started
|
49
|
+
- Starting a conversation that's already in progress
|
50
|
+
- Accessing a participant that doesn't exist
|
51
|
+
|
52
|
+
Examples:
|
53
|
+
```python
|
54
|
+
# Attempting to get the next speaker when the conversation is empty
|
55
|
+
empty_conversation.next_speaker() # Raises ConversationStateError
|
56
|
+
```
|
57
|
+
"""
|
58
|
+
relevant_doc = "https://docs.expectedparrot.com/"
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from .. import Agent, AgentList, QuestionYesNo, QuestionNumerical
|
2
|
+
from .Conversation import Conversation, ConversationList
|
3
3
|
|
4
4
|
|
5
5
|
def bargaining_pairs(alice_valuation, bob_valuation):
|
@@ -43,12 +43,6 @@ results.select("conversation_index", "index", "agent_name", "dialogue").print(
|
|
43
43
|
format="rich"
|
44
44
|
)
|
45
45
|
|
46
|
-
from edsl import (
|
47
|
-
QuestionFreeText,
|
48
|
-
QuestionMultipleChoice,
|
49
|
-
QuestionYesNo,
|
50
|
-
QuestionNumerical,
|
51
|
-
)
|
52
46
|
|
53
47
|
q_deal = QuestionYesNo(
|
54
48
|
question_text="""This was a negotiation: {{ transcript }}.
|
edsl/coop/__init__.py
CHANGED
@@ -8,18 +8,40 @@ This module enables EDSL to interact with cloud-based resources for enhanced fun
|
|
8
8
|
3. Caching of interview results for improved performance and cost savings
|
9
9
|
4. API key management and authentication
|
10
10
|
5. Price and model availability information
|
11
|
+
6. Plugin registry and discovery
|
11
12
|
|
12
13
|
The primary interface is the Coop class, which serves as a client for the
|
13
14
|
Expected Parrot API. Most users will only need to interact with the Coop class directly.
|
14
15
|
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
16
|
+
Examples:
|
17
|
+
|
18
|
+
```python
|
19
|
+
from edsl.coop import Coop
|
20
|
+
coop = Coop() # Uses API key from environment or stored location
|
21
|
+
survey = my_survey.push() # Uploads survey to Expected Parrot
|
22
|
+
job_info = coop.remote_inference_create(my_job) # Creates remote job
|
23
|
+
|
24
|
+
# Working with plugins
|
25
|
+
from edsl.coop import get_available_plugins
|
26
|
+
plugins = get_available_plugins()
|
27
|
+
plugin_names = [p.name for p in plugins]
|
28
|
+
```
|
20
29
|
"""
|
21
30
|
|
22
31
|
from .utils import EDSLObject, ObjectType, VisibilityType, ObjectRegistry
|
23
32
|
from .coop import Coop
|
24
33
|
from .exceptions import CoopServerResponseError
|
25
|
-
|
34
|
+
|
35
|
+
__all__ = [
|
36
|
+
"Coop",
|
37
|
+
"EDSLObject",
|
38
|
+
"ObjectType",
|
39
|
+
"VisibilityType",
|
40
|
+
"ObjectRegistry",
|
41
|
+
"CoopServerResponseError",
|
42
|
+
"AvailablePlugin",
|
43
|
+
"get_available_plugins",
|
44
|
+
"search_plugins",
|
45
|
+
"get_plugin_details",
|
46
|
+
"PluginRegistryError"
|
47
|
+
]
|
edsl/coop/coop.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import aiohttp
|
2
|
+
import base64
|
2
3
|
import json
|
3
4
|
import requests
|
4
5
|
|
@@ -140,7 +141,7 @@ class Coop(CoopFunctionsMixin):
|
|
140
141
|
if self.api_key:
|
141
142
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
142
143
|
else:
|
143
|
-
headers["Authorization"] =
|
144
|
+
headers["Authorization"] = "Bearer None"
|
144
145
|
return headers
|
145
146
|
|
146
147
|
def _send_server_request(
|
@@ -149,7 +150,7 @@ class Coop(CoopFunctionsMixin):
|
|
149
150
|
method: str,
|
150
151
|
payload: Optional[dict[str, Any]] = None,
|
151
152
|
params: Optional[dict[str, Any]] = None,
|
152
|
-
timeout: Optional[float] =
|
153
|
+
timeout: Optional[float] = 10,
|
153
154
|
) -> requests.Response:
|
154
155
|
"""
|
155
156
|
Send a request to the server and return the response.
|
@@ -159,7 +160,7 @@ class Coop(CoopFunctionsMixin):
|
|
159
160
|
if payload is None:
|
160
161
|
timeout = 40
|
161
162
|
elif (
|
162
|
-
method.upper() == "POST"
|
163
|
+
(method.upper() == "POST" or method.upper() == "PATCH")
|
163
164
|
and "json_string" in payload
|
164
165
|
and payload.get("json_string") is not None
|
165
166
|
):
|
@@ -179,7 +180,9 @@ class Coop(CoopFunctionsMixin):
|
|
179
180
|
timeout=timeout,
|
180
181
|
)
|
181
182
|
else:
|
182
|
-
|
183
|
+
from .exceptions import CoopInvalidMethodError
|
184
|
+
|
185
|
+
raise CoopInvalidMethodError(f"Invalid {method=}.")
|
183
186
|
except requests.ConnectionError:
|
184
187
|
raise requests.ConnectionError(f"Could not connect to the server at {url}.")
|
185
188
|
|
@@ -226,7 +229,8 @@ class Coop(CoopFunctionsMixin):
|
|
226
229
|
"""
|
227
230
|
# Get EDSL version from header
|
228
231
|
# breakpoint()
|
229
|
-
|
232
|
+
# Commented out as currently unused
|
233
|
+
# server_edsl_version = response.headers.get("X-EDSL-Version")
|
230
234
|
|
231
235
|
# if server_edsl_version:
|
232
236
|
# if self._user_version_is_outdated(
|
@@ -266,7 +270,7 @@ class Coop(CoopFunctionsMixin):
|
|
266
270
|
|
267
271
|
print("\n✨ API key retrieved.")
|
268
272
|
|
269
|
-
if
|
273
|
+
if self.ep_key_handler.ask_to_store(api_key):
|
270
274
|
pass
|
271
275
|
else:
|
272
276
|
path_to_env = write_api_key_to_env(api_key)
|
@@ -299,13 +303,19 @@ class Coop(CoopFunctionsMixin):
|
|
299
303
|
message = root.find("Message").text
|
300
304
|
details = root.find("Details").text
|
301
305
|
except Exception:
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
"
|
306
|
+
from .exceptions import CoopServerResponseError
|
307
|
+
|
308
|
+
raise CoopServerResponseError(
|
309
|
+
f"Server returned status code {response.status_code}. "
|
310
|
+
f"XML response could not be decoded. "
|
311
|
+
f"The server response was: {response.text}"
|
306
312
|
)
|
307
313
|
|
308
|
-
|
314
|
+
from .exceptions import CoopServerResponseError
|
315
|
+
|
316
|
+
raise CoopServerResponseError(
|
317
|
+
f"An error occurred: {code} - {message} - {details}"
|
318
|
+
)
|
309
319
|
|
310
320
|
def _poll_for_api_key(
|
311
321
|
self, edsl_auth_token: str, timeout: int = 120
|
@@ -432,6 +442,23 @@ class Coop(CoopFunctionsMixin):
|
|
432
442
|
else:
|
433
443
|
return None
|
434
444
|
|
445
|
+
def _scenario_is_file_store(self, scenario_dict: dict) -> bool:
|
446
|
+
"""
|
447
|
+
Check if the scenario object is a valid FileStore.
|
448
|
+
|
449
|
+
Matches keys in the scenario dict against the expected keys for a FileStore.
|
450
|
+
"""
|
451
|
+
file_store_keys = [
|
452
|
+
"path",
|
453
|
+
"base64_string",
|
454
|
+
"binary",
|
455
|
+
"suffix",
|
456
|
+
"mime_type",
|
457
|
+
"external_locations",
|
458
|
+
"extracted_text",
|
459
|
+
]
|
460
|
+
return all(key in scenario_dict.keys() for key in file_store_keys)
|
461
|
+
|
435
462
|
def create(
|
436
463
|
self,
|
437
464
|
object: EDSLObject,
|
@@ -471,21 +498,30 @@ class Coop(CoopFunctionsMixin):
|
|
471
498
|
>>> print(result["url"]) # URL to access the survey
|
472
499
|
"""
|
473
500
|
object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
|
501
|
+
object_dict = object.to_dict()
|
502
|
+
if object_type == "scenario" and self._scenario_is_file_store(object_dict):
|
503
|
+
file_store_metadata = {
|
504
|
+
"suffix": object_dict["suffix"],
|
505
|
+
"mime_type": object_dict["mime_type"],
|
506
|
+
}
|
507
|
+
else:
|
508
|
+
file_store_metadata = None
|
474
509
|
response = self._send_server_request(
|
475
|
-
uri=
|
510
|
+
uri="api/v0/object",
|
476
511
|
method="POST",
|
477
512
|
payload={
|
478
513
|
"description": description,
|
479
514
|
"alias": alias,
|
480
515
|
"json_string": (
|
481
516
|
json.dumps(
|
482
|
-
|
517
|
+
object_dict,
|
483
518
|
default=self._json_handle_none,
|
484
519
|
)
|
485
520
|
if object_type != "scenario"
|
486
521
|
else ""
|
487
522
|
),
|
488
523
|
"object_type": object_type,
|
524
|
+
"file_store_metadata": file_store_metadata,
|
489
525
|
"visibility": visibility,
|
490
526
|
"version": self._edsl_version,
|
491
527
|
},
|
@@ -495,19 +531,57 @@ class Coop(CoopFunctionsMixin):
|
|
495
531
|
|
496
532
|
if object_type == "scenario":
|
497
533
|
json_data = json.dumps(
|
498
|
-
|
534
|
+
object_dict,
|
499
535
|
default=self._json_handle_none,
|
500
536
|
)
|
501
537
|
headers = {"Content-Type": "application/json"}
|
502
538
|
if response_json.get("upload_signed_url"):
|
503
539
|
signed_url = response_json.get("upload_signed_url")
|
504
540
|
else:
|
505
|
-
|
541
|
+
from .exceptions import CoopResponseError
|
542
|
+
|
543
|
+
raise CoopResponseError("No signed url was provided received")
|
506
544
|
|
507
545
|
response = requests.put(
|
508
546
|
signed_url, data=json_data.encode(), headers=headers
|
509
547
|
)
|
510
548
|
self._resolve_gcs_response(response)
|
549
|
+
|
550
|
+
file_store_upload_signed_url = response_json.get(
|
551
|
+
"file_store_upload_signed_url"
|
552
|
+
)
|
553
|
+
if file_store_metadata and not file_store_upload_signed_url:
|
554
|
+
from .exceptions import CoopResponseError
|
555
|
+
|
556
|
+
raise CoopResponseError("No file store signed url provided.")
|
557
|
+
elif file_store_metadata:
|
558
|
+
headers = {"Content-Type": file_store_metadata["mime_type"]}
|
559
|
+
# Lint json files prior to upload
|
560
|
+
if file_store_metadata["suffix"] == "json":
|
561
|
+
file_store_bytes = base64.b64decode(object_dict["base64_string"])
|
562
|
+
pretty_json_string = json.dumps(
|
563
|
+
json.loads(file_store_bytes), indent=4
|
564
|
+
)
|
565
|
+
byte_data = pretty_json_string.encode("utf-8")
|
566
|
+
# Lint python files prior to upload
|
567
|
+
elif file_store_metadata["suffix"] == "py":
|
568
|
+
import black
|
569
|
+
|
570
|
+
file_store_bytes = base64.b64decode(object_dict["base64_string"])
|
571
|
+
python_string = file_store_bytes.decode("utf-8")
|
572
|
+
formatted_python_string = black.format_str(
|
573
|
+
python_string, mode=black.Mode()
|
574
|
+
)
|
575
|
+
byte_data = formatted_python_string.encode("utf-8")
|
576
|
+
else:
|
577
|
+
byte_data = base64.b64decode(object_dict["base64_string"])
|
578
|
+
response = requests.put(
|
579
|
+
file_store_upload_signed_url,
|
580
|
+
data=byte_data,
|
581
|
+
headers=headers,
|
582
|
+
)
|
583
|
+
self._resolve_gcs_response(response)
|
584
|
+
|
511
585
|
owner_username = response_json.get("owner_username")
|
512
586
|
object_alias = response_json.get("alias")
|
513
587
|
|
@@ -519,7 +593,6 @@ class Coop(CoopFunctionsMixin):
|
|
519
593
|
"uuid": response_json.get("uuid"),
|
520
594
|
"version": self._edsl_version,
|
521
595
|
"visibility": response_json.get("visibility"),
|
522
|
-
"upload_signed_url": response_json.get("upload_signed_url", None),
|
523
596
|
}
|
524
597
|
|
525
598
|
def get(
|
@@ -566,13 +639,13 @@ class Coop(CoopFunctionsMixin):
|
|
566
639
|
|
567
640
|
if obj_uuid:
|
568
641
|
response = self._send_server_request(
|
569
|
-
uri=
|
642
|
+
uri="api/v0/object",
|
570
643
|
method="GET",
|
571
644
|
params={"uuid": obj_uuid},
|
572
645
|
)
|
573
646
|
else:
|
574
647
|
response = self._send_server_request(
|
575
|
-
uri=
|
648
|
+
uri="api/v0/object/alias",
|
576
649
|
method="GET",
|
577
650
|
params={"owner_username": owner_username, "alias": alias},
|
578
651
|
)
|
@@ -586,7 +659,11 @@ class Coop(CoopFunctionsMixin):
|
|
586
659
|
json_string = object_data.text
|
587
660
|
object_type = response.json().get("object_type")
|
588
661
|
if expected_object_type and object_type != expected_object_type:
|
589
|
-
|
662
|
+
from .exceptions import CoopObjectTypeError
|
663
|
+
|
664
|
+
raise CoopObjectTypeError(
|
665
|
+
f"Expected {expected_object_type=} but got {object_type=}"
|
666
|
+
)
|
590
667
|
edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
|
591
668
|
object = edsl_class.from_dict(json.loads(json_string))
|
592
669
|
return object
|
@@ -597,7 +674,7 @@ class Coop(CoopFunctionsMixin):
|
|
597
674
|
"""
|
598
675
|
edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
|
599
676
|
response = self._send_server_request(
|
600
|
-
uri=
|
677
|
+
uri="api/v0/objects",
|
601
678
|
method="GET",
|
602
679
|
params={"type": object_type},
|
603
680
|
)
|
@@ -677,7 +754,9 @@ class Coop(CoopFunctionsMixin):
|
|
677
754
|
and value is None
|
678
755
|
and alias is None
|
679
756
|
):
|
680
|
-
|
757
|
+
from .exceptions import CoopPatchError
|
758
|
+
|
759
|
+
raise CoopPatchError("Nothing to patch.")
|
681
760
|
|
682
761
|
obj_uuid, owner_username, obj_alias = self._resolve_uuid_or_alias(url_or_uuid)
|
683
762
|
|
@@ -808,7 +887,9 @@ class Coop(CoopFunctionsMixin):
|
|
808
887
|
[CacheEntry(...), CacheEntry(...), ...]
|
809
888
|
"""
|
810
889
|
if job_uuid is None:
|
811
|
-
|
890
|
+
from .exceptions import CoopValueError
|
891
|
+
|
892
|
+
raise CoopValueError("Must provide a job_uuid.")
|
812
893
|
response = self._send_server_request(
|
813
894
|
uri="api/v0/remote-cache/get-many-by-job",
|
814
895
|
method="POST",
|
@@ -836,7 +917,9 @@ class Coop(CoopFunctionsMixin):
|
|
836
917
|
[CacheEntry(...), CacheEntry(...), ...]
|
837
918
|
"""
|
838
919
|
if select_keys is None or len(select_keys) == 0:
|
839
|
-
|
920
|
+
from .exceptions import CoopValueError
|
921
|
+
|
922
|
+
raise CoopValueError("Must provide a non-empty list of select_keys.")
|
840
923
|
response = self._send_server_request(
|
841
924
|
uri="api/v0/remote-cache/get-many-by-key",
|
842
925
|
method="POST",
|
@@ -1099,7 +1182,9 @@ class Coop(CoopFunctionsMixin):
|
|
1099
1182
|
... print(f"Results available at: {job_status['results_url']}")
|
1100
1183
|
"""
|
1101
1184
|
if job_uuid is None and results_uuid is None:
|
1102
|
-
|
1185
|
+
from .exceptions import CoopValueError
|
1186
|
+
|
1187
|
+
raise CoopValueError("Either job_uuid or results_uuid must be provided.")
|
1103
1188
|
elif job_uuid is not None:
|
1104
1189
|
params = {"job_uuid": job_uuid}
|
1105
1190
|
else:
|
@@ -1136,7 +1221,7 @@ class Coop(CoopFunctionsMixin):
|
|
1136
1221
|
"latest_error_report_uuid": latest_error_report_uuid,
|
1137
1222
|
"latest_error_report_url": latest_error_report_url,
|
1138
1223
|
"status": data.get("status"),
|
1139
|
-
"reason": data.get("
|
1224
|
+
"reason": data.get("latest_failure_reason"),
|
1140
1225
|
"credits_consumed": data.get("price"),
|
1141
1226
|
"version": data.get("version"),
|
1142
1227
|
}
|
@@ -1173,7 +1258,9 @@ class Coop(CoopFunctionsMixin):
|
|
1173
1258
|
elif isinstance(input, Survey):
|
1174
1259
|
job = Jobs(survey=input)
|
1175
1260
|
else:
|
1176
|
-
|
1261
|
+
from .exceptions import CoopTypeError
|
1262
|
+
|
1263
|
+
raise CoopTypeError("Input must be either a Job or a Survey.")
|
1177
1264
|
|
1178
1265
|
response = self._send_server_request(
|
1179
1266
|
uri="api/v0/remote-inference/cost",
|
@@ -1215,7 +1302,7 @@ class Coop(CoopFunctionsMixin):
|
|
1215
1302
|
)
|
1216
1303
|
survey_uuid = survey_details.get("uuid")
|
1217
1304
|
response = self._send_server_request(
|
1218
|
-
uri=
|
1305
|
+
uri="api/v0/projects/create-from-survey",
|
1219
1306
|
method="POST",
|
1220
1307
|
payload={"project_name": project_name, "survey_uuid": str(survey_uuid)},
|
1221
1308
|
)
|
@@ -1308,7 +1395,9 @@ class Coop(CoopFunctionsMixin):
|
|
1308
1395
|
elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
|
1309
1396
|
return {}
|
1310
1397
|
else:
|
1311
|
-
|
1398
|
+
from .exceptions import CoopValueError
|
1399
|
+
|
1400
|
+
raise CoopValueError(
|
1312
1401
|
"Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
|
1313
1402
|
)
|
1314
1403
|
|
@@ -1464,7 +1553,9 @@ class Coop(CoopFunctionsMixin):
|
|
1464
1553
|
api_key = self._poll_for_api_key(edsl_auth_token)
|
1465
1554
|
|
1466
1555
|
if api_key is None:
|
1467
|
-
|
1556
|
+
from .exceptions import CoopTimeoutError
|
1557
|
+
|
1558
|
+
raise CoopTimeoutError("Timed out waiting for login. Please try again.")
|
1468
1559
|
|
1469
1560
|
path_to_env = write_api_key_to_env(api_key)
|
1470
1561
|
print("\n✨ API key retrieved and written to .env file at the following path:")
|
edsl/coop/coop_functions.py
CHANGED
edsl/coop/ep_key_handling.py
CHANGED
@@ -70,7 +70,7 @@ class ExpectedParrotKeyHandler:
|
|
70
70
|
|
71
71
|
def ok_to_ask_to_store(self):
|
72
72
|
"""Check if it's okay to ask the user to store the key."""
|
73
|
-
from
|
73
|
+
from ..config import CONFIG
|
74
74
|
|
75
75
|
if CONFIG.get("EDSL_RUN_MODE") != "production":
|
76
76
|
return False
|