edsl 0.1.31.dev2__py3-none-any.whl → 0.1.31.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +9 -3
- edsl/config.py +4 -0
- edsl/coop/coop.py +4 -0
- edsl/enums.py +2 -1
- edsl/inference_services/DeepInfraService.py +4 -90
- edsl/inference_services/GroqService.py +19 -0
- edsl/inference_services/OpenAIService.py +64 -22
- edsl/inference_services/registry.py +2 -1
- edsl/jobs/Jobs.py +3 -2
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +3 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +89 -79
- edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
- edsl/jobs/runners/JobsRunnerStatusMixin.py +264 -38
- edsl/jobs/tasks/TaskCreators.py +8 -2
- edsl/language_models/LanguageModel.py +7 -1
- edsl/language_models/registry.py +4 -0
- {edsl-0.1.31.dev2.dist-info → edsl-0.1.31.dev4.dist-info}/METADATA +2 -1
- {edsl-0.1.31.dev2.dist-info → edsl-0.1.31.dev4.dist-info}/RECORD +21 -20
- {edsl-0.1.31.dev2.dist-info → edsl-0.1.31.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.31.dev2.dist-info → edsl-0.1.31.dev4.dist-info}/WHEEL +0 -0
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.31.dev2"
|
1
|
+
__version__ = "0.1.31.dev4"
|
edsl/agents/Invigilator.py
CHANGED
@@ -18,7 +18,12 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
18
18
|
"""An invigilator that uses an AI model to answer questions."""
|
19
19
|
|
20
20
|
async def async_answer_question(self) -> AgentResponseDict:
|
21
|
-
"""Answer a question using the AI model.
|
21
|
+
"""Answer a question using the AI model.
|
22
|
+
|
23
|
+
>>> i = InvigilatorAI.example()
|
24
|
+
>>> i.answer_question()
|
25
|
+
{'message': '{"answer": "SPAM!"}'}
|
26
|
+
"""
|
22
27
|
params = self.get_prompts() | {"iteration": self.iteration}
|
23
28
|
raw_response = await self.async_get_response(**params)
|
24
29
|
data = {
|
@@ -29,6 +34,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
29
34
|
"raw_model_response": raw_response["raw_model_response"],
|
30
35
|
}
|
31
36
|
response = self._format_raw_response(**data)
|
37
|
+
#breakpoint()
|
32
38
|
return AgentResponseDict(**response)
|
33
39
|
|
34
40
|
async def async_get_response(
|
@@ -38,7 +44,8 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
38
44
|
iteration: int = 0,
|
39
45
|
encoded_image=None,
|
40
46
|
) -> dict:
|
41
|
-
"""Call the LLM and gets a response. Used in the `answer_question` method.
|
47
|
+
"""Call the LLM and gets a response. Used in the `answer_question` method.
|
48
|
+
"""
|
42
49
|
try:
|
43
50
|
params = {
|
44
51
|
"user_prompt": user_prompt.text,
|
@@ -97,7 +104,6 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
97
104
|
answer = question._translate_answer_code_to_answer(
|
98
105
|
response["answer"], combined_dict
|
99
106
|
)
|
100
|
-
# breakpoint()
|
101
107
|
data = {
|
102
108
|
"answer": answer,
|
103
109
|
"comment": response.get(
|
edsl/config.py
CHANGED
@@ -65,6 +65,10 @@ CONFIG_MAP = {
|
|
65
65
|
# "default": None,
|
66
66
|
# "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
|
67
67
|
# },
|
68
|
+
# "GROQ_API_KEY": {
|
69
|
+
# "default": None,
|
70
|
+
# "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
|
71
|
+
# },
|
68
72
|
}
|
69
73
|
|
70
74
|
|
edsl/coop/coop.py
CHANGED
@@ -465,6 +465,7 @@ class Coop:
|
|
465
465
|
description: Optional[str] = None,
|
466
466
|
status: RemoteJobStatus = "queued",
|
467
467
|
visibility: Optional[VisibilityType] = "unlisted",
|
468
|
+
iterations: Optional[int] = 1,
|
468
469
|
) -> dict:
|
469
470
|
"""
|
470
471
|
Send a remote inference job to the server.
|
@@ -473,6 +474,7 @@ class Coop:
|
|
473
474
|
:param optional description: A description for this entry in the remote cache.
|
474
475
|
:param status: The status of the job. Should be 'queued', unless you are debugging.
|
475
476
|
:param visibility: The visibility of the cache entry.
|
477
|
+
:param iterations: The number of times to run each interview.
|
476
478
|
|
477
479
|
>>> job = Jobs.example()
|
478
480
|
>>> coop.remote_inference_create(job=job, description="My job")
|
@@ -488,6 +490,7 @@ class Coop:
|
|
488
490
|
),
|
489
491
|
"description": description,
|
490
492
|
"status": status,
|
493
|
+
"iterations": iterations,
|
491
494
|
"visibility": visibility,
|
492
495
|
"version": self._edsl_version,
|
493
496
|
},
|
@@ -498,6 +501,7 @@ class Coop:
|
|
498
501
|
"uuid": response_json.get("jobs_uuid"),
|
499
502
|
"description": response_json.get("description"),
|
500
503
|
"status": response_json.get("status"),
|
504
|
+
"iterations": response_json.get("iterations"),
|
501
505
|
"visibility": response_json.get("visibility"),
|
502
506
|
"version": self._edsl_version,
|
503
507
|
}
|
edsl/enums.py
CHANGED
@@ -59,7 +59,7 @@ class InferenceServiceType(EnumWithChecks):
|
|
59
59
|
GOOGLE = "google"
|
60
60
|
TEST = "test"
|
61
61
|
ANTHROPIC = "anthropic"
|
62
|
-
|
62
|
+
GROQ = "groq"
|
63
63
|
|
64
64
|
service_to_api_keyname = {
|
65
65
|
InferenceServiceType.BEDROCK.value: "TBD",
|
@@ -69,6 +69,7 @@ service_to_api_keyname = {
|
|
69
69
|
InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
|
70
70
|
InferenceServiceType.TEST.value: "TBD",
|
71
71
|
InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
|
72
|
+
InferenceServiceType.GROQ.value: "GROQ_API_KEY",
|
72
73
|
}
|
73
74
|
|
74
75
|
|
@@ -2,102 +2,16 @@ import aiohttp
|
|
2
2
|
import json
|
3
3
|
import requests
|
4
4
|
from typing import Any, List
|
5
|
-
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
5
|
+
#from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
6
|
from edsl.language_models import LanguageModel
|
7
7
|
|
8
|
+
from edsl.inference_services.OpenAIService import OpenAIService
|
8
9
|
|
9
|
-
class DeepInfraService(InferenceServiceABC):
|
10
|
+
class DeepInfraService(OpenAIService):
|
10
11
|
"""DeepInfra service class."""
|
11
12
|
|
12
13
|
_inference_service_ = "deep_infra"
|
13
14
|
_env_key_name_ = "DEEP_INFRA_API_KEY"
|
14
|
-
|
15
|
+
_base_url_ = "https://api.deepinfra.com/v1/openai"
|
15
16
|
_models_list_cache: List[str] = []
|
16
17
|
|
17
|
-
@classmethod
|
18
|
-
def available(cls):
|
19
|
-
text_models = cls.full_details_available()
|
20
|
-
return [m["model_name"] for m in text_models]
|
21
|
-
|
22
|
-
@classmethod
|
23
|
-
def full_details_available(cls, verbose=False):
|
24
|
-
if not cls._models_list_cache:
|
25
|
-
url = "https://api.deepinfra.com/models/list"
|
26
|
-
response = requests.get(url)
|
27
|
-
if response.status_code == 200:
|
28
|
-
text_generation_models = [
|
29
|
-
r for r in response.json() if r["type"] == "text-generation"
|
30
|
-
]
|
31
|
-
cls._models_list_cache = text_generation_models
|
32
|
-
|
33
|
-
from rich import print_json
|
34
|
-
import json
|
35
|
-
|
36
|
-
if verbose:
|
37
|
-
print_json(json.dumps(text_generation_models))
|
38
|
-
return text_generation_models
|
39
|
-
else:
|
40
|
-
return f"Failed to fetch data: Status code {response.status_code}"
|
41
|
-
else:
|
42
|
-
return cls._models_list_cache
|
43
|
-
|
44
|
-
@classmethod
|
45
|
-
def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
|
46
|
-
base_url = "https://api.deepinfra.com/v1/inference/"
|
47
|
-
if model_class_name is None:
|
48
|
-
model_class_name = cls.to_class_name(model_name)
|
49
|
-
url = f"{base_url}{model_name}"
|
50
|
-
|
51
|
-
class LLM(LanguageModel):
|
52
|
-
_inference_service_ = cls._inference_service_
|
53
|
-
_model_ = model_name
|
54
|
-
_parameters_ = {
|
55
|
-
"temperature": 0.7,
|
56
|
-
"top_p": 0.2,
|
57
|
-
"top_k": 0.1,
|
58
|
-
"max_new_tokens": 512,
|
59
|
-
"stopSequences": [],
|
60
|
-
}
|
61
|
-
|
62
|
-
async def async_execute_model_call(
|
63
|
-
self, user_prompt: str, system_prompt: str = ""
|
64
|
-
) -> dict[str, Any]:
|
65
|
-
self.url = url
|
66
|
-
headers = {
|
67
|
-
"Content-Type": "application/json",
|
68
|
-
"Authorization": f"bearer {self.api_token}",
|
69
|
-
}
|
70
|
-
# don't mess w/ the newlines
|
71
|
-
data = {
|
72
|
-
"input": f"""
|
73
|
-
[INST]<<SYS>>
|
74
|
-
{system_prompt}
|
75
|
-
<<SYS>>{user_prompt}[/INST]
|
76
|
-
""",
|
77
|
-
"stream": False,
|
78
|
-
"temperature": self.temperature,
|
79
|
-
"top_p": self.top_p,
|
80
|
-
"top_k": self.top_k,
|
81
|
-
"max_new_tokens": self.max_new_tokens,
|
82
|
-
}
|
83
|
-
async with aiohttp.ClientSession() as session:
|
84
|
-
async with session.post(
|
85
|
-
self.url, headers=headers, data=json.dumps(data)
|
86
|
-
) as response:
|
87
|
-
raw_response_text = await response.text()
|
88
|
-
return json.loads(raw_response_text)
|
89
|
-
|
90
|
-
def parse_response(self, raw_response: dict[str, Any]) -> str:
|
91
|
-
if "results" not in raw_response:
|
92
|
-
raise Exception(
|
93
|
-
f"Deep Infra response does not contain 'results' key: {raw_response}"
|
94
|
-
)
|
95
|
-
if "generated_text" not in raw_response["results"][0]:
|
96
|
-
raise Exception(
|
97
|
-
f"Deep Infra response does not contain 'generate_text' key: {raw_response['results'][0]}"
|
98
|
-
)
|
99
|
-
return raw_response["results"][0]["generated_text"]
|
100
|
-
|
101
|
-
LLM.__name__ = model_class_name
|
102
|
-
|
103
|
-
return LLM
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import Any, List
|
2
|
+
from edsl.inference_services.OpenAIService import OpenAIService
|
3
|
+
|
4
|
+
import groq
|
5
|
+
|
6
|
+
|
7
|
+
class GroqService(OpenAIService):
|
8
|
+
"""DeepInfra service class."""
|
9
|
+
|
10
|
+
_inference_service_ = "groq"
|
11
|
+
_env_key_name_ = "GROQ_API_KEY"
|
12
|
+
|
13
|
+
_sync_client_ = groq.Groq
|
14
|
+
_async_client_ = groq.AsyncGroq
|
15
|
+
|
16
|
+
#_base_url_ = "https://api.deepinfra.com/v1/openai"
|
17
|
+
_base_url_ = None
|
18
|
+
_models_list_cache: List[str] = []
|
19
|
+
|
@@ -1,6 +1,8 @@
|
|
1
1
|
from typing import Any, List
|
2
2
|
import re
|
3
|
-
|
3
|
+
import os
|
4
|
+
#from openai import AsyncOpenAI
|
5
|
+
import openai
|
4
6
|
|
5
7
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
6
8
|
from edsl.language_models import LanguageModel
|
@@ -12,6 +14,22 @@ class OpenAIService(InferenceServiceABC):
|
|
12
14
|
|
13
15
|
_inference_service_ = "openai"
|
14
16
|
_env_key_name_ = "OPENAI_API_KEY"
|
17
|
+
_base_url_ = None
|
18
|
+
|
19
|
+
_sync_client_ = openai.OpenAI
|
20
|
+
_async_client_ = openai.AsyncOpenAI
|
21
|
+
|
22
|
+
@classmethod
|
23
|
+
def sync_client(cls):
|
24
|
+
return cls._sync_client_(
|
25
|
+
api_key = os.getenv(cls._env_key_name_),
|
26
|
+
base_url = cls._base_url_)
|
27
|
+
|
28
|
+
@classmethod
|
29
|
+
def async_client(cls):
|
30
|
+
return cls._async_client_(
|
31
|
+
api_key = os.getenv(cls._env_key_name_),
|
32
|
+
base_url = cls._base_url_)
|
15
33
|
|
16
34
|
# TODO: Make this a coop call
|
17
35
|
model_exclude_list = [
|
@@ -31,16 +49,24 @@ class OpenAIService(InferenceServiceABC):
|
|
31
49
|
]
|
32
50
|
_models_list_cache: List[str] = []
|
33
51
|
|
52
|
+
@classmethod
|
53
|
+
def get_model_list(cls):
|
54
|
+
raw_list = cls.sync_client().models.list()
|
55
|
+
if hasattr(raw_list, "data"):
|
56
|
+
return raw_list.data
|
57
|
+
else:
|
58
|
+
return raw_list
|
59
|
+
|
34
60
|
@classmethod
|
35
61
|
def available(cls) -> List[str]:
|
36
|
-
from openai import OpenAI
|
62
|
+
#from openai import OpenAI
|
37
63
|
|
38
64
|
if not cls._models_list_cache:
|
39
65
|
try:
|
40
|
-
client = OpenAI()
|
66
|
+
#client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
41
67
|
cls._models_list_cache = [
|
42
68
|
m.id
|
43
|
-
for m in client.models.list()
|
69
|
+
for m in cls.get_model_list()
|
44
70
|
if m.id not in cls.model_exclude_list
|
45
71
|
]
|
46
72
|
except Exception as e:
|
@@ -78,15 +104,24 @@ class OpenAIService(InferenceServiceABC):
|
|
78
104
|
"top_logprobs": 3,
|
79
105
|
}
|
80
106
|
|
107
|
+
def sync_client(self):
|
108
|
+
return cls.sync_client()
|
109
|
+
|
110
|
+
def async_client(self):
|
111
|
+
return cls.async_client()
|
112
|
+
|
81
113
|
@classmethod
|
82
114
|
def available(cls) -> list[str]:
|
83
|
-
|
84
|
-
|
85
|
-
|
115
|
+
#import openai
|
116
|
+
#client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
117
|
+
#return client.models.list()
|
118
|
+
return cls.sync_client().models.list()
|
119
|
+
|
86
120
|
def get_headers(self) -> dict[str, Any]:
|
87
|
-
from openai import OpenAI
|
121
|
+
#from openai import OpenAI
|
88
122
|
|
89
|
-
client = OpenAI()
|
123
|
+
#client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
|
124
|
+
client = self.sync_client()
|
90
125
|
response = client.chat.completions.with_raw_response.create(
|
91
126
|
messages=[
|
92
127
|
{
|
@@ -124,8 +159,8 @@ class OpenAIService(InferenceServiceABC):
|
|
124
159
|
encoded_image=None,
|
125
160
|
) -> dict[str, Any]:
|
126
161
|
"""Calls the OpenAI API and returns the API response."""
|
127
|
-
content = [{"type": "text", "text": user_prompt}]
|
128
162
|
if encoded_image:
|
163
|
+
content = [{"type": "text", "text": user_prompt}]
|
129
164
|
content.append(
|
130
165
|
{
|
131
166
|
"type": "image_url",
|
@@ -134,21 +169,28 @@ class OpenAIService(InferenceServiceABC):
|
|
134
169
|
},
|
135
170
|
}
|
136
171
|
)
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
172
|
+
else:
|
173
|
+
content = user_prompt
|
174
|
+
# self.client = AsyncOpenAI(
|
175
|
+
# api_key = os.getenv(cls._env_key_name_),
|
176
|
+
# base_url = cls._base_url_
|
177
|
+
# )
|
178
|
+
client = self.async_client()
|
179
|
+
params = {
|
180
|
+
"model": self.model,
|
181
|
+
"messages": [
|
141
182
|
{"role": "system", "content": system_prompt},
|
142
183
|
{"role": "user", "content": content},
|
143
184
|
],
|
144
|
-
temperature
|
145
|
-
max_tokens
|
146
|
-
top_p
|
147
|
-
frequency_penalty
|
148
|
-
presence_penalty
|
149
|
-
logprobs
|
150
|
-
top_logprobs
|
151
|
-
|
185
|
+
"temperature": self.temperature,
|
186
|
+
"max_tokens": self.max_tokens,
|
187
|
+
"top_p": self.top_p,
|
188
|
+
"frequency_penalty": self.frequency_penalty,
|
189
|
+
"presence_penalty": self.presence_penalty,
|
190
|
+
"logprobs": self.logprobs,
|
191
|
+
"top_logprobs": self.top_logprobs if self.logprobs else None,
|
192
|
+
}
|
193
|
+
response = await client.chat.completions.create(**params)
|
152
194
|
return response.model_dump()
|
153
195
|
|
154
196
|
@staticmethod
|
@@ -6,7 +6,8 @@ from edsl.inference_services.OpenAIService import OpenAIService
|
|
6
6
|
from edsl.inference_services.AnthropicService import AnthropicService
|
7
7
|
from edsl.inference_services.DeepInfraService import DeepInfraService
|
8
8
|
from edsl.inference_services.GoogleService import GoogleService
|
9
|
+
from edsl.inference_services.GroqService import GroqService
|
9
10
|
|
10
11
|
default = InferenceServicesCollection(
|
11
|
-
[OpenAIService, AnthropicService, DeepInfraService, GoogleService]
|
12
|
+
[OpenAIService, AnthropicService, DeepInfraService, GoogleService, GroqService]
|
12
13
|
)
|
edsl/jobs/Jobs.py
CHANGED
@@ -475,6 +475,7 @@ class Jobs(Base):
|
|
475
475
|
self,
|
476
476
|
description=remote_inference_description,
|
477
477
|
status="queued",
|
478
|
+
iterations=n,
|
478
479
|
)
|
479
480
|
time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
|
480
481
|
job_uuid = remote_job_creation_data.get("uuid")
|
@@ -629,9 +630,9 @@ class Jobs(Base):
|
|
629
630
|
results = JobsRunnerAsyncio(self).run(*args, **kwargs)
|
630
631
|
return results
|
631
632
|
|
632
|
-
async def run_async(self, cache=None, **kwargs):
|
633
|
+
async def run_async(self, cache=None, n=1, **kwargs):
|
633
634
|
"""Run the job asynchronously."""
|
634
|
-
results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
|
635
|
+
results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
|
635
636
|
return results
|
636
637
|
|
637
638
|
def all_question_parameters(self):
|
@@ -204,12 +204,13 @@ class InterviewTaskBuildingMixin:
|
|
204
204
|
return skip
|
205
205
|
|
206
206
|
async def _attempt_to_answer_question(
|
207
|
-
self, invigilator: InvigilatorBase, task: asyncio.Task
|
208
|
-
) -> AgentResponseDict:
|
207
|
+
self, invigilator: 'InvigilatorBase', task: asyncio.Task
|
208
|
+
) -> 'AgentResponseDict':
|
209
209
|
"""Attempt to answer the question, and handle exceptions.
|
210
210
|
|
211
211
|
:param invigilator: the invigilator that will answer the question.
|
212
212
|
:param task: the task that is being run.
|
213
|
+
|
213
214
|
"""
|
214
215
|
try:
|
215
216
|
return await asyncio.wait_for(
|
@@ -13,6 +13,35 @@ from edsl.jobs.tasks.TaskHistory import TaskHistory
|
|
13
13
|
from edsl.jobs.buckets.BucketCollection import BucketCollection
|
14
14
|
from edsl.utilities.decorators import jupyter_nb_handler
|
15
15
|
|
16
|
+
import time
|
17
|
+
import functools
|
18
|
+
|
19
|
+
def cache_with_timeout(timeout):
|
20
|
+
def decorator(func):
|
21
|
+
cached_result = {}
|
22
|
+
last_computation_time = [0] # Using list to store mutable value
|
23
|
+
|
24
|
+
@functools.wraps(func)
|
25
|
+
def wrapper(*args, **kwargs):
|
26
|
+
current_time = time.time()
|
27
|
+
if (current_time - last_computation_time[0]) >= timeout:
|
28
|
+
cached_result['value'] = func(*args, **kwargs)
|
29
|
+
last_computation_time[0] = current_time
|
30
|
+
return cached_result['value']
|
31
|
+
|
32
|
+
return wrapper
|
33
|
+
return decorator
|
34
|
+
|
35
|
+
#from queue import Queue
|
36
|
+
from collections import UserList
|
37
|
+
|
38
|
+
class StatusTracker(UserList):
|
39
|
+
def __init__(self, total_tasks: int):
|
40
|
+
self.total_tasks = total_tasks
|
41
|
+
super().__init__()
|
42
|
+
|
43
|
+
def current_status(self):
|
44
|
+
return print(f"Completed: {len(self.data)} of {self.total_tasks}", end = "\r")
|
16
45
|
|
17
46
|
class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
18
47
|
"""A class for running a collection of interviews asynchronously.
|
@@ -43,7 +72,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
43
72
|
|
44
73
|
:param n: how many times to run each interview
|
45
74
|
:param debug:
|
46
|
-
:param stop_on_exception:
|
75
|
+
:param stop_on_exception: Whether to stop the interview if an exception is raised
|
76
|
+
:param sidecar_model: a language model to use in addition to the interview's model
|
77
|
+
:param total_interviews: A list of interviews to run can be provided instead.
|
47
78
|
"""
|
48
79
|
tasks = []
|
49
80
|
if total_interviews:
|
@@ -87,15 +118,18 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
87
118
|
) # set the cache for the first interview
|
88
119
|
self.total_interviews.append(interview)
|
89
120
|
|
90
|
-
async def run_async(self, cache=None) -> Results:
|
121
|
+
async def run_async(self, cache=None, n=1) -> Results:
|
91
122
|
from edsl.results.Results import Results
|
92
123
|
|
124
|
+
#breakpoint()
|
125
|
+
#tracker = StatusTracker(total_tasks=len(self.interviews))
|
126
|
+
|
93
127
|
if cache is None:
|
94
128
|
self.cache = Cache()
|
95
129
|
else:
|
96
130
|
self.cache = cache
|
97
131
|
data = []
|
98
|
-
async for result in self.run_async_generator(cache=self.cache):
|
132
|
+
async for result in self.run_async_generator(cache=self.cache, n=n):
|
99
133
|
data.append(result)
|
100
134
|
return Results(survey=self.jobs.survey, data=data)
|
101
135
|
|
@@ -201,91 +235,67 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
|
|
201
235
|
self.sidecar_model = sidecar_model
|
202
236
|
|
203
237
|
from edsl.results.Results import Results
|
238
|
+
from rich.live import Live
|
239
|
+
from rich.console import Console
|
204
240
|
|
205
|
-
|
206
|
-
|
207
|
-
|
241
|
+
@cache_with_timeout(1)
|
242
|
+
def generate_table():
|
243
|
+
return self.status_table(self.results, self.elapsed_time)
|
208
244
|
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
Live(generate_table(), console=console, refresh_per_second=5)
|
244
|
-
if progress_bar
|
245
|
-
else no_op_cm()
|
246
|
-
)
|
245
|
+
async def process_results(cache, progress_bar_context = None):
|
246
|
+
"""Processes results from interviews."""
|
247
|
+
async for result in self.run_async_generator(
|
248
|
+
n=n,
|
249
|
+
debug=debug,
|
250
|
+
stop_on_exception=stop_on_exception,
|
251
|
+
cache=cache,
|
252
|
+
sidecar_model=sidecar_model,
|
253
|
+
):
|
254
|
+
self.results.append(result)
|
255
|
+
if progress_bar_context:
|
256
|
+
progress_bar_context.update(generate_table())
|
257
|
+
self.completed = True
|
258
|
+
|
259
|
+
async def update_progress_bar(progress_bar_context):
|
260
|
+
"""Updates the progress bar at fixed intervals."""
|
261
|
+
if progress_bar_context is None:
|
262
|
+
return
|
263
|
+
|
264
|
+
while True:
|
265
|
+
progress_bar_context.update(generate_table())
|
266
|
+
await asyncio.sleep(0.1) # Update interval
|
267
|
+
if self.completed:
|
268
|
+
break
|
269
|
+
|
270
|
+
@contextmanager
|
271
|
+
def conditional_context(condition, context_manager):
|
272
|
+
if condition:
|
273
|
+
with context_manager as cm:
|
274
|
+
yield cm
|
275
|
+
else:
|
276
|
+
yield
|
277
|
+
|
278
|
+
with conditional_context(progress_bar, Live(generate_table(), console=console, refresh_per_second=1)) as progress_bar_context:
|
247
279
|
|
248
280
|
with cache as c:
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
await asyncio.sleep(0.00001) # Update interval
|
256
|
-
if self.completed:
|
257
|
-
break
|
258
|
-
|
259
|
-
async def process_results():
|
260
|
-
"""Processes results from interviews."""
|
261
|
-
async for result in self.run_async_generator(
|
262
|
-
n=n,
|
263
|
-
debug=debug,
|
264
|
-
stop_on_exception=stop_on_exception,
|
265
|
-
cache=c,
|
266
|
-
sidecar_model=sidecar_model,
|
267
|
-
):
|
268
|
-
self.results.append(result)
|
269
|
-
live.update(generate_table())
|
270
|
-
self.completed = True
|
271
|
-
|
272
|
-
progress_task = asyncio.create_task(update_progress_bar())
|
273
|
-
|
274
|
-
try:
|
275
|
-
await asyncio.gather(process_results(), progress_task)
|
276
|
-
except asyncio.CancelledError:
|
281
|
+
|
282
|
+
progress_task = asyncio.create_task(update_progress_bar(progress_bar_context))
|
283
|
+
|
284
|
+
try:
|
285
|
+
await asyncio.gather(progress_task, process_results(cache = c, progress_bar_context = progress_bar_context))
|
286
|
+
except asyncio.CancelledError:
|
277
287
|
pass
|
278
|
-
|
279
|
-
|
280
|
-
|
288
|
+
finally:
|
289
|
+
progress_task.cancel() # Cancel the progress_task when process_results is done
|
290
|
+
await progress_task
|
281
291
|
|
282
|
-
|
292
|
+
await asyncio.sleep(1) # short delay to show the final status
|
283
293
|
|
284
|
-
|
285
|
-
|
294
|
+
if progress_bar_context:
|
295
|
+
progress_bar_context.update(generate_table())
|
286
296
|
|
287
|
-
results = Results(survey=self.jobs.survey, data=self.results)
|
288
297
|
|
298
|
+
results = Results(survey=self.jobs.survey, data=self.results)
|
289
299
|
task_history = TaskHistory(self.total_interviews, include_traceback=False)
|
290
300
|
results.task_history = task_history
|
291
301
|
|