edsl 0.1.38__py3-none-any.whl → 0.1.38.dev2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- edsl/Base.py +31 -60
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +9 -18
- edsl/agents/AgentList.py +8 -59
- edsl/agents/Invigilator.py +7 -18
- edsl/agents/InvigilatorBase.py +19 -0
- edsl/agents/PromptConstructor.py +4 -5
- edsl/config.py +0 -8
- edsl/coop/coop.py +7 -74
- edsl/data/Cache.py +2 -27
- edsl/data/CacheEntry.py +3 -8
- edsl/data/RemoteCacheSync.py +19 -0
- edsl/enums.py +0 -2
- edsl/inference_services/GoogleService.py +15 -7
- edsl/inference_services/registry.py +0 -2
- edsl/jobs/Jobs.py +548 -88
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +35 -140
- edsl/jobs/runners/JobsRunnerStatus.py +2 -0
- edsl/jobs/tasks/TaskHistory.py +16 -15
- edsl/language_models/LanguageModel.py +84 -44
- edsl/language_models/ModelList.py +1 -47
- edsl/language_models/registry.py +4 -57
- edsl/prompts/Prompt.py +3 -8
- edsl/questions/QuestionBase.py +16 -20
- edsl/questions/QuestionExtract.py +4 -3
- edsl/questions/question_registry.py +6 -36
- edsl/results/Dataset.py +15 -146
- edsl/results/DatasetExportMixin.py +217 -231
- edsl/results/DatasetTree.py +4 -134
- edsl/results/Result.py +9 -18
- edsl/results/Results.py +51 -145
- edsl/scenarios/FileStore.py +13 -187
- edsl/scenarios/Scenario.py +4 -61
- edsl/scenarios/ScenarioList.py +62 -237
- edsl/surveys/Survey.py +2 -16
- edsl/surveys/SurveyFlowVisualizationMixin.py +9 -67
- edsl/surveys/instructions/Instruction.py +0 -12
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +9 -18
- edsl/utilities/utilities.py +0 -15
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/METADATA +1 -2
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/RECORD +45 -53
- edsl/inference_services/PerplexityService.py +0 -163
- edsl/jobs/JobsChecks.py +0 -147
- edsl/jobs/JobsPrompts.py +0 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +0 -239
- edsl/results/CSSParameterizer.py +0 -108
- edsl/results/TableDisplay.py +0 -198
- edsl/results/table_display.css +0 -78
- edsl/scenarios/ScenarioJoin.py +0 -127
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/WHEEL +0 -0
edsl/jobs/JobsPrompts.py
DELETED
@@ -1,268 +0,0 @@
from typing import List, TYPE_CHECKING

from edsl.results.Dataset import Dataset

if TYPE_CHECKING:
    from edsl.jobs import Jobs

# from edsl.jobs.interviews.Interview import Interview
# from edsl.results.Dataset import Dataset
# from edsl.agents.AgentList import AgentList
# from edsl.scenarios.ScenarioList import ScenarioList
# from edsl.surveys.Survey import Survey


class JobsPrompts:
    def __init__(self, jobs: "Jobs"):
        self.interviews = jobs.interviews()
        self.agents = jobs.agents
        self.scenarios = jobs.scenarios
        self.survey = jobs.survey
        self._price_lookup = None

    @property
    def price_lookup(self):
        if self._price_lookup is None:
            from edsl import Coop

            c = Coop()
            self._price_lookup = c.fetch_prices()
        return self._price_lookup

    def prompts(self) -> "Dataset":
        """Return a Dataset of prompts that will be used.

        >>> from edsl.jobs import Jobs
        >>> Jobs.example().prompts()
        Dataset(...)
        """
        interviews = self.interviews
        interview_indices = []
        question_names = []
        user_prompts = []
        system_prompts = []
        scenario_indices = []
        agent_indices = []
        models = []
        costs = []

        for interview_index, interview in enumerate(interviews):
            invigilators = [
                interview._get_invigilator(question)
                for question in self.survey.questions
            ]
            for _, invigilator in enumerate(invigilators):
                prompts = invigilator.get_prompts()
                user_prompt = prompts["user_prompt"]
                system_prompt = prompts["system_prompt"]
                user_prompts.append(user_prompt)
                system_prompts.append(system_prompt)
                agent_index = self.agents.index(invigilator.agent)
                agent_indices.append(agent_index)
                interview_indices.append(interview_index)
                scenario_index = self.scenarios.index(invigilator.scenario)
                scenario_indices.append(scenario_index)
                models.append(invigilator.model.model)
                question_names.append(invigilator.question.question_name)

                prompt_cost = self.estimate_prompt_cost(
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    price_lookup=self.price_lookup,
                    inference_service=invigilator.model._inference_service_,
                    model=invigilator.model.model,
                )
                costs.append(prompt_cost["cost_usd"])

        d = Dataset(
            [
                {"user_prompt": user_prompts},
                {"system_prompt": system_prompts},
                {"interview_index": interview_indices},
                {"question_name": question_names},
                {"scenario_index": scenario_indices},
                {"agent_index": agent_indices},
                {"model": models},
                {"estimated_cost": costs},
            ]
        )
        return d

    @staticmethod
    def estimate_prompt_cost(
        system_prompt: str,
        user_prompt: str,
        price_lookup: dict,
        inference_service: str,
        model: str,
    ) -> dict:
        """Estimates the cost of a prompt. Takes piping into account."""
        import math

        def get_piping_multiplier(prompt: str):
            """Returns 2 if a prompt includes Jinja braces, and 1 otherwise."""

            if "{{" in prompt and "}}" in prompt:
                return 2
            return 1

        # Look up prices per token
        key = (inference_service, model)

        try:
            relevant_prices = price_lookup[key]

            service_input_token_price = float(
                relevant_prices["input"]["service_stated_token_price"]
            )
            service_input_token_qty = float(
                relevant_prices["input"]["service_stated_token_qty"]
            )
            input_price_per_token = service_input_token_price / service_input_token_qty

            service_output_token_price = float(
                relevant_prices["output"]["service_stated_token_price"]
            )
            service_output_token_qty = float(
                relevant_prices["output"]["service_stated_token_qty"]
            )
            output_price_per_token = (
                service_output_token_price / service_output_token_qty
            )

        except KeyError:
            # A KeyError is likely to occur if we cannot retrieve prices (the price_lookup dict is empty)
            # Use a sensible default

            import warnings

            warnings.warn(
                "Price data could not be retrieved. Using default estimates for input and output token prices. Input: $0.15 / 1M tokens; Output: $0.60 / 1M tokens"
            )
            input_price_per_token = 0.00000015  # $0.15 / 1M tokens
            output_price_per_token = 0.00000060  # $0.60 / 1M tokens

        # Compute the number of characters (double if the question involves piping)
        user_prompt_chars = len(str(user_prompt)) * get_piping_multiplier(
            str(user_prompt)
        )
        system_prompt_chars = len(str(system_prompt)) * get_piping_multiplier(
            str(system_prompt)
        )

        # Convert into tokens (1 token approx. equals 4 characters)
        input_tokens = (user_prompt_chars + system_prompt_chars) // 4

        output_tokens = math.ceil(0.75 * input_tokens)

        cost = (
            input_tokens * input_price_per_token
            + output_tokens * output_price_per_token
        )

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost_usd": cost,
        }

    def estimate_job_cost_from_external_prices(
        self, price_lookup: dict, iterations: int = 1
    ) -> dict:
        """
        Estimates the cost of a job according to the following assumptions:

        - 1 token = 4 characters.
        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.

        price_lookup is an external pricing dictionary.
        """

        import pandas as pd

        interviews = self.interviews
        data = []
        for interview in interviews:
            invigilators = [
                interview._get_invigilator(question)
                for question in self.survey.questions
            ]
            for invigilator in invigilators:
                prompts = invigilator.get_prompts()

                # By this point, agent and scenario data has already been added to the prompts
                user_prompt = prompts["user_prompt"]
                system_prompt = prompts["system_prompt"]
                inference_service = invigilator.model._inference_service_
                model = invigilator.model.model

                prompt_cost = self.estimate_prompt_cost(
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    price_lookup=price_lookup,
                    inference_service=inference_service,
                    model=model,
                )

                data.append(
                    {
                        "user_prompt": user_prompt,
                        "system_prompt": system_prompt,
                        "estimated_input_tokens": prompt_cost["input_tokens"],
                        "estimated_output_tokens": prompt_cost["output_tokens"],
                        "estimated_cost_usd": prompt_cost["cost_usd"],
                        "inference_service": inference_service,
                        "model": model,
                    }
                )

        df = pd.DataFrame.from_records(data)

        df = (
            df.groupby(["inference_service", "model"])
            .agg(
                {
                    "estimated_cost_usd": "sum",
                    "estimated_input_tokens": "sum",
                    "estimated_output_tokens": "sum",
                }
            )
            .reset_index()
        )
        df["estimated_cost_usd"] = df["estimated_cost_usd"] * iterations
        df["estimated_input_tokens"] = df["estimated_input_tokens"] * iterations
        df["estimated_output_tokens"] = df["estimated_output_tokens"] * iterations

        estimated_costs_by_model = df.to_dict("records")

        estimated_total_cost = sum(
            model["estimated_cost_usd"] for model in estimated_costs_by_model
        )
        estimated_total_input_tokens = sum(
            model["estimated_input_tokens"] for model in estimated_costs_by_model
        )
        estimated_total_output_tokens = sum(
            model["estimated_output_tokens"] for model in estimated_costs_by_model
        )

        output = {
            "estimated_total_cost_usd": estimated_total_cost,
            "estimated_total_input_tokens": estimated_total_input_tokens,
            "estimated_total_output_tokens": estimated_total_output_tokens,
            "model_costs": estimated_costs_by_model,
        }

        return output

    def estimate_job_cost(self, iterations: int = 1) -> dict:
        """
        Estimates the cost of a job according to the following assumptions:

        - 1 token = 4 characters.
        - For each prompt, output tokens = input tokens * 0.75, rounded up to the nearest integer.

        Fetches prices from Coop.
        """
        return self.estimate_job_cost_from_external_prices(
            price_lookup=self.price_lookup, iterations=iterations
        )
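For reference, the cost heuristic deleted above reduces to simple arithmetic: character counts are doubled when a prompt contains Jinja piping braces, divided by 4 to approximate tokens, and output tokens are assumed to be 0.75x input, rounded up. A minimal standalone sketch of that heuristic (the function name is illustrative; the default prices mirror the fallback values in the code above):

import math

def estimate_cost(system_prompt: str, user_prompt: str,
                  input_price_per_token: float = 0.15 / 1_000_000,
                  output_price_per_token: float = 0.60 / 1_000_000) -> float:
    # Double the character count when a prompt contains Jinja piping braces.
    def effective_chars(prompt: str) -> int:
        return len(prompt) * (2 if "{{" in prompt and "}}" in prompt else 1)

    # ~4 characters per token; output assumed to be 0.75x input, rounded up.
    input_tokens = (effective_chars(user_prompt) + effective_chars(system_prompt)) // 4
    output_tokens = math.ceil(0.75 * input_tokens)
    return input_tokens * input_price_per_token + output_tokens * output_price_per_token

# An 80-character system prompt plus a 400-character user prompt:
# (400 + 80) // 4 = 120 input tokens, ceil(0.75 * 120) = 90 output tokens,
# 120 * $0.15/1M + 90 * $0.60/1M = $0.000072.
print(estimate_cost("s" * 80, "u" * 400))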
edsl/jobs/JobsRemoteInferenceHandler.py
DELETED
@@ -1,239 +0,0 @@
from typing import Optional, Union, Literal
import requests
import sys
from edsl.exceptions.coop import CoopServerResponseError

# from edsl.enums import VisibilityType
from edsl.results import Results


class JobsRemoteInferenceHandler:
    def __init__(self, jobs, verbose=False, poll_interval=3):
        """
        >>> from edsl.jobs import Jobs
        >>> jh = JobsRemoteInferenceHandler(Jobs.example(), verbose=True)
        >>> jh.use_remote_inference(True)
        False
        >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "failed"}) # doctest: +NORMALIZE_WHITESPACE
        Job failed.
        ...
        >>> jh._poll_remote_inference_job({'uuid':1234}, testing_simulated_response={"status": "completed"}) # doctest: +NORMALIZE_WHITESPACE
        Job completed and Results stored on Coop: None.
        Results(...)
        """
        self.jobs = jobs
        self.verbose = verbose
        self.poll_interval = poll_interval

        self._remote_job_creation_data = None
        self._job_uuid = None

    @property
    def remote_job_creation_data(self):
        return self._remote_job_creation_data

    @property
    def job_uuid(self):
        return self._job_uuid

    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
        if disable_remote_inference:
            return False
        if not disable_remote_inference:
            try:
                from edsl import Coop

                user_edsl_settings = Coop().edsl_settings
                return user_edsl_settings.get("remote_inference", False)
            except requests.ConnectionError:
                pass
            except CoopServerResponseError as e:
                pass

        return False

    def create_remote_inference_job(
        self,
        iterations: int = 1,
        remote_inference_description: Optional[str] = None,
        remote_inference_results_visibility: Optional["VisibilityType"] = "unlisted",
        verbose=False,
    ):
        """ """
        from edsl.config import CONFIG
        from edsl.coop.coop import Coop
        from rich import print as rich_print

        coop = Coop()
        print("Remote inference activated. Sending job to server...")
        remote_job_creation_data = coop.remote_inference_create(
            self.jobs,
            description=remote_inference_description,
            status="queued",
            iterations=iterations,
            initial_results_visibility=remote_inference_results_visibility,
        )
        job_uuid = remote_job_creation_data.get("uuid")
        print(f"Job sent to server. (Job uuid={job_uuid}).")

        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")
        progress_bar_url = f"{expected_parrot_url}/home/remote-job-progress/{job_uuid}"

        rich_print(
            f"View job progress here: [#38bdf8][link={progress_bar_url}]{progress_bar_url}[/link][/#38bdf8]"
        )

        self._remote_job_creation_data = remote_job_creation_data
        self._job_uuid = job_uuid
        # return remote_job_creation_data

    @staticmethod
    def check_status(job_uuid):
        from edsl.coop.coop import Coop

        coop = Coop()
        return coop.remote_inference_get(job_uuid)

    def poll_remote_inference_job(self):
        return self._poll_remote_inference_job(
            self.remote_job_creation_data, verbose=self.verbose
        )

    def _poll_remote_inference_job(
        self,
        remote_job_creation_data: dict,
        verbose=False,
        poll_interval: Optional[float] = None,
        testing_simulated_response: Optional[dict] = None,
    ) -> Union[Results, None]:
        import time
        from datetime import datetime
        from edsl.config import CONFIG
        from edsl.coop.coop import Coop

        if poll_interval is None:
            poll_interval = self.poll_interval

        expected_parrot_url = CONFIG.get("EXPECTED_PARROT_URL")

        job_uuid = remote_job_creation_data.get("uuid")
        coop = Coop()

        if testing_simulated_response is not None:
            remote_job_data_fetcher = lambda job_uuid: testing_simulated_response
            object_fetcher = (
                lambda results_uuid, expected_object_type: Results.example()
            )
        else:
            remote_job_data_fetcher = coop.remote_inference_get
            object_fetcher = coop.get

        job_in_queue = True
        while job_in_queue:
            remote_job_data = remote_job_data_fetcher(job_uuid)
            status = remote_job_data.get("status")
            if status == "cancelled":
                print("\r" + " " * 80 + "\r", end="")
                print("Job cancelled by the user.")
                print(
                    f"See {expected_parrot_url}/home/remote-inference for more details."
                )
                return None
            elif status == "failed":
                print("\r" + " " * 80 + "\r", end="")
                # write to stderr
                latest_error_report_url = remote_job_data.get("latest_error_report_url")
                if latest_error_report_url:
                    print("Job failed.")
                    print(
                        f"Your job generated exceptions. Details on these exceptions can be found in the following report: {latest_error_report_url}"
                    )
                    print(
                        f"Need support? Post a message at the Expected Parrot Discord channel (https://discord.com/invite/mxAYkjfy9m) or send an email to info@expectedparrot.com."
                    )
                else:
                    print("Job failed.")
                    print(
                        f"See {expected_parrot_url}/home/remote-inference for more details."
                    )
                return None
            elif status == "completed":
                results_uuid = remote_job_data.get("results_uuid")
                results_url = remote_job_data.get("results_url")
                results = object_fetcher(results_uuid, expected_object_type="results")
                print("\r" + " " * 80 + "\r", end="")
                print(f"Job completed and Results stored on Coop: {results_url}.")
                return results
            else:
                duration = poll_interval
                time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
                frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
                start_time = time.time()
                i = 0
                while time.time() - start_time < duration:
                    print(
                        f"\r{frames[i % len(frames)]} Job status: {status} - last update: {time_checked}",
                        end="",
                        flush=True,
                    )
                    time.sleep(0.1)
                    i += 1

    def use_remote_inference(self, disable_remote_inference: bool) -> bool:
        if disable_remote_inference:
            return False
        if not disable_remote_inference:
            try:
                from edsl import Coop

                user_edsl_settings = Coop().edsl_settings
                return user_edsl_settings.get("remote_inference", False)
            except requests.ConnectionError:
                pass
            except CoopServerResponseError as e:
                pass

        return False

    async def create_and_poll_remote_job(
        self,
        iterations: int = 1,
        remote_inference_description: Optional[str] = None,
        remote_inference_results_visibility: Optional[
            Literal["private", "public", "unlisted"]
        ] = "unlisted",
    ) -> Union[Results, None]:
        """
        Creates and polls a remote inference job asynchronously.
        Reuses existing synchronous methods but runs them in an async context.

        :param iterations: Number of times to run each interview
        :param remote_inference_description: Optional description for the remote job
        :param remote_inference_results_visibility: Visibility setting for results
        :return: Results object if successful, None if job fails or is cancelled
        """
        import asyncio
        from functools import partial

        # Create job using existing method
        loop = asyncio.get_event_loop()
        remote_job_creation_data = await loop.run_in_executor(
            None,
            partial(
                self.create_remote_inference_job,
                iterations=iterations,
                remote_inference_description=remote_inference_description,
                remote_inference_results_visibility=remote_inference_results_visibility,
            ),
        )

        # Poll using existing method but with async sleep
        return await loop.run_in_executor(
            None, partial(self.poll_remote_inference_job, remote_job_creation_data)
        )


if __name__ == "__main__":
    import doctest

    doctest.testmod(optionflags=doctest.ELLIPSIS)
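Pieced together from the methods above, the deleted handler's call sequence appears to have been: check whether remote inference is enabled, create the job (which stores the creation data and job uuid on the handler), then poll until a terminal status. A sketch of that flow, inferred from the code above rather than from documented edsl API:

from edsl.jobs import Jobs

handler = JobsRemoteInferenceHandler(Jobs.example(), verbose=True)
if handler.use_remote_inference(disable_remote_inference=False):
    # Sends the job to Coop and stores remote_job_creation_data / job_uuid.
    handler.create_remote_inference_job(iterations=1)
    # Blocks until the job is completed, failed, or cancelled; returns
    # a Results object on success, None otherwise.
    results = handler.poll_remote_inference_job()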
edsl/results/CSSParameterizer.py
DELETED
@@ -1,108 +0,0 @@
import re
from typing import Dict, Set, Optional


class CSSParameterizer:
    """A utility class to parameterize CSS with custom properties (variables)."""

    def __init__(self, css_content: str):
        """
        Initialize with CSS content to be parameterized.

        Args:
            css_content (str): The CSS content containing var() declarations
        """
        self.css_content = css_content
        self._extract_variables()

    def _extract_variables(self) -> None:
        """Extract all CSS custom properties (variables) from the CSS content."""
        # Find all var(...) declarations in the CSS
        var_pattern = r"var\((--[a-zA-Z0-9-]+)\)"
        self.variables = set(re.findall(var_pattern, self.css_content))

    def _validate_parameters(self, parameters: Dict[str, str]) -> Set[str]:
        """
        Validate the provided parameters against the CSS variables.

        Args:
            parameters (Dict[str, str]): Dictionary of variable names and their values

        Returns:
            Set[str]: Set of missing variables
        """
        # Convert parameter keys to CSS variable format if they don't already have --
        formatted_params = {
            f"--{k}" if not k.startswith("--") else k for k in parameters.keys()
        }

        # print("Variables from CSS:", self.variables)
        # print("Formatted parameters:", formatted_params)

        # Find missing and extra variables
        missing_vars = self.variables - formatted_params
        extra_vars = formatted_params - self.variables

        if extra_vars:
            print(f"Warning: Found unused parameters: {extra_vars}")

        return missing_vars

    def generate_root(self, **parameters: str) -> Optional[str]:
        """
        Generate a :root block with the provided parameters.

        Args:
            **parameters: Keyword arguments where keys are variable names and values are their values

        Returns:
            str: Generated :root block with variables, or None if validation fails

        Example:
            >>> css = "body { height: var(--bodyHeight); }"
            >>> parameterizer = CSSParameterizer(css)
            >>> parameterizer.apply_parameters({'bodyHeight':"100vh"})
            ':root {\\n  --bodyHeight: 100vh;\\n}\\n\\nbody { height: var(--bodyHeight); }'
        """
        missing_vars = self._validate_parameters(parameters)

        if missing_vars:
            print(f"Error: Missing required variables: {missing_vars}")
            return None

        # Format parameters with -- prefix if not present
        formatted_params = {
            f"--{k}" if not k.startswith("--") else k: v for k, v in parameters.items()
        }

        # Generate the :root block
        root_block = [":root {"]
        for var_name, value in sorted(formatted_params.items()):
            if var_name in self.variables:
                root_block.append(f"  {var_name}: {value};")
        root_block.append("}")

        return "\n".join(root_block)

    def apply_parameters(self, parameters: dict) -> Optional[str]:
        """
        Generate the complete CSS with the :root block and original CSS content.

        Args:
            **parameters: Keyword arguments where keys are variable names and values are their values

        Returns:
            str: Complete CSS with :root block and original content, or None if validation fails
        """
        root_block = self.generate_root(**parameters)
        if root_block is None:
            return None

        return f"{root_block}\n\n{self.css_content}"


# Example usage
if __name__ == "__main__":
    import doctest

    doctest.testmod()
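The class's own doctest shows the intended round trip; a slightly fuller usage sketch with two variables (the CSS and values are illustrative):

css = "body { height: var(--bodyHeight); width: var(--bodyWidth); }"
parameterizer = CSSParameterizer(css)
print(parameterizer.apply_parameters({"bodyHeight": "100vh", "bodyWidth": "80vw"}))
# :root {
#   --bodyHeight: 100vh;
#   --bodyWidth: 80vw;
# }
#
# body { height: var(--bodyHeight); width: var(--bodyWidth); }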