edsl 0.1.31.dev2__py3-none-any.whl → 0.1.31.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.31.dev2"
+ __version__ = "0.1.31.dev4"
@@ -18,7 +18,12 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
      """An invigilator that uses an AI model to answer questions."""

      async def async_answer_question(self) -> AgentResponseDict:
-         """Answer a question using the AI model."""
+         """Answer a question using the AI model.
+
+         >>> i = InvigilatorAI.example()
+         >>> i.answer_question()
+         {'message': '{"answer": "SPAM!"}'}
+         """
          params = self.get_prompts() | {"iteration": self.iteration}
          raw_response = await self.async_get_response(**params)
          data = {
@@ -29,6 +34,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
              "raw_model_response": raw_response["raw_model_response"],
          }
          response = self._format_raw_response(**data)
+         #breakpoint()
          return AgentResponseDict(**response)

      async def async_get_response(
@@ -38,7 +44,8 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
          iteration: int = 0,
          encoded_image=None,
      ) -> dict:
-         """Call the LLM and gets a response. Used in the `answer_question` method."""
+         """Call the LLM and gets a response. Used in the `answer_question` method.
+         """
          try:
              params = {
                  "user_prompt": user_prompt.text,
@@ -97,7 +104,6 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
          answer = question._translate_answer_code_to_answer(
              response["answer"], combined_dict
          )
-         # breakpoint()
          data = {
              "answer": answer,
              "comment": response.get(
edsl/config.py CHANGED
@@ -65,6 +65,10 @@ CONFIG_MAP = {
      # "default": None,
      # "info": "This env var holds your Anthropic API key (https://www.anthropic.com/).",
      # },
+     # "GROQ_API_KEY": {
+     # "default": None,
+     # "info": "This env var holds your GROQ API key (https://console.groq.com/login).",
+     # },
  }


edsl/coop/coop.py CHANGED
@@ -465,6 +465,7 @@ class Coop:
          description: Optional[str] = None,
          status: RemoteJobStatus = "queued",
          visibility: Optional[VisibilityType] = "unlisted",
+         iterations: Optional[int] = 1,
      ) -> dict:
          """
          Send a remote inference job to the server.
@@ -473,6 +474,7 @@
          :param optional description: A description for this entry in the remote cache.
          :param status: The status of the job. Should be 'queued', unless you are debugging.
          :param visibility: The visibility of the cache entry.
+         :param iterations: The number of times to run each interview.

          >>> job = Jobs.example()
          >>> coop.remote_inference_create(job=job, description="My job")
@@ -488,6 +490,7 @@
                  ),
                  "description": description,
                  "status": status,
+                 "iterations": iterations,
                  "visibility": visibility,
                  "version": self._edsl_version,
              },
@@ -498,6 +501,7 @@
              "uuid": response_json.get("jobs_uuid"),
              "description": response_json.get("description"),
              "status": response_json.get("status"),
+             "iterations": response_json.get("iterations"),
              "visibility": response_json.get("visibility"),
              "version": self._edsl_version,
          }
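The new `iterations` argument is passed straight through to the request payload and echoed back in the response, as the hunks above show. A minimal usage sketch (the import paths and client setup are assumptions, not part of this diff; it presumes Coop credentials are already configured):

    from edsl.coop import Coop
    from edsl.jobs.Jobs import Jobs

    coop = Coop()
    job = Jobs.example()
    # ask the server to run each interview in the job three times
    info = coop.remote_inference_create(job=job, description="My job", iterations=3)
    print(info["status"], info["iterations"])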
edsl/enums.py CHANGED
@@ -59,7 +59,7 @@ class InferenceServiceType(EnumWithChecks):
      GOOGLE = "google"
      TEST = "test"
      ANTHROPIC = "anthropic"
-
+     GROQ = "groq"

  service_to_api_keyname = {
      InferenceServiceType.BEDROCK.value: "TBD",
@@ -69,6 +69,7 @@ service_to_api_keyname = {
      InferenceServiceType.GOOGLE.value: "GOOGLE_API_KEY",
      InferenceServiceType.TEST.value: "TBD",
      InferenceServiceType.ANTHROPIC.value: "ANTHROPIC_API_KEY",
+     InferenceServiceType.GROQ.value: "GROQ_API_KEY",
  }


edsl/inference_services/DeepInfraService.py CHANGED
@@ -2,102 +2,16 @@ import aiohttp
  import json
  import requests
  from typing import Any, List
- from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+ #from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models import LanguageModel

+ from edsl.inference_services.OpenAIService import OpenAIService

- class DeepInfraService(InferenceServiceABC):
+ class DeepInfraService(OpenAIService):
      """DeepInfra service class."""

      _inference_service_ = "deep_infra"
      _env_key_name_ = "DEEP_INFRA_API_KEY"
-
+     _base_url_ = "https://api.deepinfra.com/v1/openai"
      _models_list_cache: List[str] = []

-     @classmethod
-     def available(cls):
-         text_models = cls.full_details_available()
-         return [m["model_name"] for m in text_models]
-
-     @classmethod
-     def full_details_available(cls, verbose=False):
-         if not cls._models_list_cache:
-             url = "https://api.deepinfra.com/models/list"
-             response = requests.get(url)
-             if response.status_code == 200:
-                 text_generation_models = [
-                     r for r in response.json() if r["type"] == "text-generation"
-                 ]
-                 cls._models_list_cache = text_generation_models
-
-                 from rich import print_json
-                 import json
-
-                 if verbose:
-                     print_json(json.dumps(text_generation_models))
-                 return text_generation_models
-             else:
-                 return f"Failed to fetch data: Status code {response.status_code}"
-         else:
-             return cls._models_list_cache
-
-     @classmethod
-     def create_model(cls, model_name: str, model_class_name=None) -> LanguageModel:
-         base_url = "https://api.deepinfra.com/v1/inference/"
-         if model_class_name is None:
-             model_class_name = cls.to_class_name(model_name)
-         url = f"{base_url}{model_name}"
-
-         class LLM(LanguageModel):
-             _inference_service_ = cls._inference_service_
-             _model_ = model_name
-             _parameters_ = {
-                 "temperature": 0.7,
-                 "top_p": 0.2,
-                 "top_k": 0.1,
-                 "max_new_tokens": 512,
-                 "stopSequences": [],
-             }
-
-             async def async_execute_model_call(
-                 self, user_prompt: str, system_prompt: str = ""
-             ) -> dict[str, Any]:
-                 self.url = url
-                 headers = {
-                     "Content-Type": "application/json",
-                     "Authorization": f"bearer {self.api_token}",
-                 }
-                 # don't mess w/ the newlines
-                 data = {
-                     "input": f"""
-                 [INST]<<SYS>>
-                 {system_prompt}
-                 <<SYS>>{user_prompt}[/INST]
-                 """,
-                     "stream": False,
-                     "temperature": self.temperature,
-                     "top_p": self.top_p,
-                     "top_k": self.top_k,
-                     "max_new_tokens": self.max_new_tokens,
-                 }
-                 async with aiohttp.ClientSession() as session:
-                     async with session.post(
-                         self.url, headers=headers, data=json.dumps(data)
-                     ) as response:
-                         raw_response_text = await response.text()
-                         return json.loads(raw_response_text)
-
-             def parse_response(self, raw_response: dict[str, Any]) -> str:
-                 if "results" not in raw_response:
-                     raise Exception(
-                         f"Deep Infra response does not contain 'results' key: {raw_response}"
-                     )
-                 if "generated_text" not in raw_response["results"][0]:
-                     raise Exception(
-                         f"Deep Infra response does not contain 'generate_text' key: {raw_response['results'][0]}"
-                     )
-                 return raw_response["results"][0]["generated_text"]
-
-         LLM.__name__ = model_class_name
-
-         return LLM
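The rewrite above reduces DeepInfraService to a thin subclass of OpenAIService that only overrides the service name, the API-key environment variable, and the base URL, instead of hand-rolling its own HTTP calls. As a rough sketch of that pattern (the class name, environment variable, and URL below are hypothetical placeholders, not part of the package), any other OpenAI-compatible backend could be wired up the same way:

    from typing import List
    from edsl.inference_services.OpenAIService import OpenAIService

    class ExampleCompatibleService(OpenAIService):  # hypothetical service
        _inference_service_ = "example_service"
        _env_key_name_ = "EXAMPLE_API_KEY"
        _base_url_ = "https://api.example.invalid/v1/openai"  # placeholder URL
        _models_list_cache: List[str] = []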
edsl/inference_services/GroqService.py ADDED
@@ -0,0 +1,19 @@
+ from typing import Any, List
+ from edsl.inference_services.OpenAIService import OpenAIService
+
+ import groq
+
+
+ class GroqService(OpenAIService):
+     """DeepInfra service class."""
+
+     _inference_service_ = "groq"
+     _env_key_name_ = "GROQ_API_KEY"
+
+     _sync_client_ = groq.Groq
+     _async_client_ = groq.AsyncGroq
+
+     #_base_url_ = "https://api.deepinfra.com/v1/openai"
+     _base_url_ = None
+     _models_list_cache: List[str] = []
+
edsl/inference_services/OpenAIService.py CHANGED
@@ -1,6 +1,8 @@
  from typing import Any, List
  import re
- from openai import AsyncOpenAI
+ import os
+ #from openai import AsyncOpenAI
+ import openai

  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models import LanguageModel
@@ -12,6 +14,22 @@ class OpenAIService(InferenceServiceABC):

      _inference_service_ = "openai"
      _env_key_name_ = "OPENAI_API_KEY"
+     _base_url_ = None
+
+     _sync_client_ = openai.OpenAI
+     _async_client_ = openai.AsyncOpenAI
+
+     @classmethod
+     def sync_client(cls):
+         return cls._sync_client_(
+             api_key = os.getenv(cls._env_key_name_),
+             base_url = cls._base_url_)
+
+     @classmethod
+     def async_client(cls):
+         return cls._async_client_(
+             api_key = os.getenv(cls._env_key_name_),
+             base_url = cls._base_url_)

      # TODO: Make this a coop call
      model_exclude_list = [
@@ -31,16 +49,24 @@ class OpenAIService(InferenceServiceABC):
      ]
      _models_list_cache: List[str] = []

+     @classmethod
+     def get_model_list(cls):
+         raw_list = cls.sync_client().models.list()
+         if hasattr(raw_list, "data"):
+             return raw_list.data
+         else:
+             return raw_list
+
      @classmethod
      def available(cls) -> List[str]:
-         from openai import OpenAI
+         #from openai import OpenAI

          if not cls._models_list_cache:
              try:
-                 client = OpenAI()
+                 #client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
                  cls._models_list_cache = [
                      m.id
-                     for m in client.models.list()
+                     for m in cls.get_model_list()
                      if m.id not in cls.model_exclude_list
                  ]
              except Exception as e:
@@ -78,15 +104,24 @@ class OpenAIService(InferenceServiceABC):
                  "top_logprobs": 3,
              }

+             def sync_client(self):
+                 return cls.sync_client()
+
+             def async_client(self):
+                 return cls.async_client()
+
              @classmethod
              def available(cls) -> list[str]:
-                 client = openai.OpenAI()
-                 return client.models.list()
-
+                 #import openai
+                 #client = openai.OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                 #return client.models.list()
+                 return cls.sync_client().models.list()
+
              def get_headers(self) -> dict[str, Any]:
-                 from openai import OpenAI
+                 #from openai import OpenAI

-                 client = OpenAI()
+                 #client = OpenAI(api_key = os.getenv(cls._env_key_name_), base_url = cls._base_url_)
+                 client = self.sync_client()
                  response = client.chat.completions.with_raw_response.create(
                      messages=[
                          {
@@ -124,8 +159,8 @@ class OpenAIService(InferenceServiceABC):
                  encoded_image=None,
              ) -> dict[str, Any]:
                  """Calls the OpenAI API and returns the API response."""
-                 content = [{"type": "text", "text": user_prompt}]
                  if encoded_image:
+                     content = [{"type": "text", "text": user_prompt}]
                      content.append(
                          {
                              "type": "image_url",
@@ -134,21 +169,28 @@ class OpenAIService(InferenceServiceABC):
                              },
                          }
                      )
-                 self.client = AsyncOpenAI()
-                 response = await self.client.chat.completions.create(
-                     model=self.model,
-                     messages=[
+                 else:
+                     content = user_prompt
+                 # self.client = AsyncOpenAI(
+                 #     api_key = os.getenv(cls._env_key_name_),
+                 #     base_url = cls._base_url_
+                 # )
+                 client = self.async_client()
+                 params = {
+                     "model": self.model,
+                     "messages": [
                          {"role": "system", "content": system_prompt},
                          {"role": "user", "content": content},
                      ],
-                     temperature=self.temperature,
-                     max_tokens=self.max_tokens,
-                     top_p=self.top_p,
-                     frequency_penalty=self.frequency_penalty,
-                     presence_penalty=self.presence_penalty,
-                     logprobs=self.logprobs,
-                     top_logprobs=self.top_logprobs if self.logprobs else None,
-                 )
+                     "temperature": self.temperature,
+                     "max_tokens": self.max_tokens,
+                     "top_p": self.top_p,
+                     "frequency_penalty": self.frequency_penalty,
+                     "presence_penalty": self.presence_penalty,
+                     "logprobs": self.logprobs,
+                     "top_logprobs": self.top_logprobs if self.logprobs else None,
+                 }
+                 response = await client.chat.completions.create(**params)
                  return response.model_dump()

              @staticmethod
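The `sync_client()` / `async_client()` classmethods above centralize client construction: a subclass can swap in a different client class (as GroqService does with `groq.Groq`) or point `_base_url_` at another endpoint while the request logic stays shared. Roughly, for the base class the call resolves to the following (an equivalent sketch, not a new API):

    import os
    import openai

    # approximately what OpenAIService.sync_client() now returns
    client = openai.OpenAI(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=None,  # None falls back to the default OpenAI endpoint
    )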
@@ -6,7 +6,8 @@ from edsl.inference_services.OpenAIService import OpenAIService
  from edsl.inference_services.AnthropicService import AnthropicService
  from edsl.inference_services.DeepInfraService import DeepInfraService
  from edsl.inference_services.GoogleService import GoogleService
+ from edsl.inference_services.GroqService import GroqService

  default = InferenceServicesCollection(
-     [OpenAIService, AnthropicService, DeepInfraService, GoogleService]
+     [OpenAIService, AnthropicService, DeepInfraService, GoogleService, GroqService]
  )
edsl/jobs/Jobs.py CHANGED
@@ -475,6 +475,7 @@ class Jobs(Base):
              self,
              description=remote_inference_description,
              status="queued",
+             iterations=n,
          )
          time_queued = datetime.now().strftime("%m/%d/%Y %I:%M:%S %p")
          job_uuid = remote_job_creation_data.get("uuid")
@@ -629,9 +630,9 @@ class Jobs(Base):
          results = JobsRunnerAsyncio(self).run(*args, **kwargs)
          return results

-     async def run_async(self, cache=None, **kwargs):
+     async def run_async(self, cache=None, n=1, **kwargs):
          """Run the job asynchronously."""
-         results = await JobsRunnerAsyncio(self).run_async(cache=cache, **kwargs)
+         results = await JobsRunnerAsyncio(self).run_async(cache=cache, n=n, **kwargs)
          return results

      def all_question_parameters(self):
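With the change above, `Jobs.run_async` forwards `n` to `JobsRunnerAsyncio.run_async`, so repeated interviews are now supported on the asynchronous path as well. A minimal sketch (the import path mirrors edsl/jobs/Jobs.py; running it assumes valid model credentials or a populated cache, since the example job calls its configured model):

    import asyncio
    from edsl.jobs.Jobs import Jobs

    job = Jobs.example()
    # run each interview in the example job twice
    results = asyncio.run(job.run_async(n=2))
    print(len(results))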
@@ -204,12 +204,13 @@ class InterviewTaskBuildingMixin:
          return skip

      async def _attempt_to_answer_question(
-         self, invigilator: InvigilatorBase, task: asyncio.Task
-     ) -> AgentResponseDict:
+         self, invigilator: 'InvigilatorBase', task: asyncio.Task
+     ) -> 'AgentResponseDict':
          """Attempt to answer the question, and handle exceptions.

          :param invigilator: the invigilator that will answer the question.
          :param task: the task that is being run.
+
          """
          try:
              return await asyncio.wait_for(
@@ -13,6 +13,35 @@ from edsl.jobs.tasks.TaskHistory import TaskHistory
  from edsl.jobs.buckets.BucketCollection import BucketCollection
  from edsl.utilities.decorators import jupyter_nb_handler

+ import time
+ import functools
+
+ def cache_with_timeout(timeout):
+     def decorator(func):
+         cached_result = {}
+         last_computation_time = [0]  # Using list to store mutable value
+
+         @functools.wraps(func)
+         def wrapper(*args, **kwargs):
+             current_time = time.time()
+             if (current_time - last_computation_time[0]) >= timeout:
+                 cached_result['value'] = func(*args, **kwargs)
+                 last_computation_time[0] = current_time
+             return cached_result['value']
+
+         return wrapper
+     return decorator
+
+ #from queue import Queue
+ from collections import UserList
+
+ class StatusTracker(UserList):
+     def __init__(self, total_tasks: int):
+         self.total_tasks = total_tasks
+         super().__init__()
+
+     def current_status(self):
+         return print(f"Completed: {len(self.data)} of {self.total_tasks}", end = "\r")

  class JobsRunnerAsyncio(JobsRunnerStatusMixin):
      """A class for running a collection of interviews asynchronously.
@@ -43,7 +72,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):

          :param n: how many times to run each interview
          :param debug:
-         :param stop_on_exception:
+         :param stop_on_exception: Whether to stop the interview if an exception is raised
+         :param sidecar_model: a language model to use in addition to the interview's model
+         :param total_interviews: A list of interviews to run can be provided instead.
          """
          tasks = []
          if total_interviews:
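The `cache_with_timeout` helper added at the top of this file memoizes a function's most recent return value and only recomputes once `timeout` seconds have elapsed; it is used later in this file to throttle `generate_table()` to roughly one rebuild per second. A small behavioral sketch (the decorator body repeats the earlier hunk; `now` is a hypothetical example function):

    import functools
    import time

    def cache_with_timeout(timeout):  # same helper as in the hunk above
        def decorator(func):
            cached_result = {}
            last_computation_time = [0]

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                current_time = time.time()
                if (current_time - last_computation_time[0]) >= timeout:
                    cached_result["value"] = func(*args, **kwargs)
                    last_computation_time[0] = current_time
                return cached_result["value"]

            return wrapper
        return decorator

    @cache_with_timeout(1)
    def now():  # hypothetical example function
        return time.time()

    first, second = now(), now()  # second call lands within 1s, so the cached value is returned
    assert first == second
    time.sleep(1.1)
    assert now() != first  # timeout elapsed, so the function runs again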
@@ -87,15 +118,18 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
              ) # set the cache for the first interview
              self.total_interviews.append(interview)

-     async def run_async(self, cache=None) -> Results:
+     async def run_async(self, cache=None, n=1) -> Results:
          from edsl.results.Results import Results

+         #breakpoint()
+         #tracker = StatusTracker(total_tasks=len(self.interviews))
+
          if cache is None:
              self.cache = Cache()
          else:
              self.cache = cache
          data = []
-         async for result in self.run_async_generator(cache=self.cache):
+         async for result in self.run_async_generator(cache=self.cache, n=n):
              data.append(result)
          return Results(survey=self.jobs.survey, data=data)

@@ -201,91 +235,67 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
          self.sidecar_model = sidecar_model

          from edsl.results.Results import Results
+         from rich.live import Live
+         from rich.console import Console

-         if not progress_bar:
-             # print("Running without progress bar")
-             with cache as c:
+         @cache_with_timeout(1)
+         def generate_table():
+             return self.status_table(self.results, self.elapsed_time)

-                 async def process_results():
-                     """Processes results from interviews."""
-                     async for result in self.run_async_generator(
-                         n=n,
-                         debug=debug,
-                         stop_on_exception=stop_on_exception,
-                         cache=c,
-                         sidecar_model=sidecar_model,
-                     ):
-                         self.results.append(result)
-                     self.completed = True
-
-                 await asyncio.gather(process_results())
-
-             results = Results(survey=self.jobs.survey, data=self.results)
-         else:
-             # print("Running with progress bar")
-             from rich.live import Live
-             from rich.console import Console
-
-             def generate_table():
-                 return self.status_table(self.results, self.elapsed_time)
-
-             @contextmanager
-             def no_op_cm():
-                 """A no-op context manager with a dummy update method."""
-                 yield DummyLive()
-
-             class DummyLive:
-                 def update(self, *args, **kwargs):
-                     """A dummy update method that does nothing."""
-                     pass
-
-             progress_bar_context = (
-                 Live(generate_table(), console=console, refresh_per_second=5)
-                 if progress_bar
-                 else no_op_cm()
-             )
+         async def process_results(cache, progress_bar_context = None):
+             """Processes results from interviews."""
+             async for result in self.run_async_generator(
+                 n=n,
+                 debug=debug,
+                 stop_on_exception=stop_on_exception,
+                 cache=cache,
+                 sidecar_model=sidecar_model,
+             ):
+                 self.results.append(result)
+                 if progress_bar_context:
+                     progress_bar_context.update(generate_table())
+             self.completed = True
+
+         async def update_progress_bar(progress_bar_context):
+             """Updates the progress bar at fixed intervals."""
+             if progress_bar_context is None:
+                 return
+
+             while True:
+                 progress_bar_context.update(generate_table())
+                 await asyncio.sleep(0.1)  # Update interval
+                 if self.completed:
+                     break
+
+         @contextmanager
+         def conditional_context(condition, context_manager):
+             if condition:
+                 with context_manager as cm:
+                     yield cm
+             else:
+                 yield
+
+         with conditional_context(progress_bar, Live(generate_table(), console=console, refresh_per_second=1)) as progress_bar_context:

              with cache as c:
-                 with progress_bar_context as live:
-
-                     async def update_progress_bar():
-                         """Updates the progress bar at fixed intervals."""
-                         while True:
-                             live.update(generate_table())
-                             await asyncio.sleep(0.00001)  # Update interval
-                             if self.completed:
-                                 break
-
-                     async def process_results():
-                         """Processes results from interviews."""
-                         async for result in self.run_async_generator(
-                             n=n,
-                             debug=debug,
-                             stop_on_exception=stop_on_exception,
-                             cache=c,
-                             sidecar_model=sidecar_model,
-                         ):
-                             self.results.append(result)
-                             live.update(generate_table())
-                         self.completed = True
-
-                     progress_task = asyncio.create_task(update_progress_bar())
-
-                     try:
-                         await asyncio.gather(process_results(), progress_task)
-                     except asyncio.CancelledError:
+
+                 progress_task = asyncio.create_task(update_progress_bar(progress_bar_context))
+
+                 try:
+                     await asyncio.gather(progress_task, process_results(cache = c, progress_bar_context = progress_bar_context))
+                 except asyncio.CancelledError:
                      pass
-                     finally:
-                         progress_task.cancel()  # Cancel the progress_task when process_results is done
-                         await progress_task
+                 finally:
+                     progress_task.cancel()  # Cancel the progress_task when process_results is done
+                     await progress_task

-                     await asyncio.sleep(1)  # short delay to show the final status
+             await asyncio.sleep(1)  # short delay to show the final status

-                     # one more update
-                     live.update(generate_table())
+             if progress_bar_context:
+                 progress_bar_context.update(generate_table())

-             results = Results(survey=self.jobs.survey, data=self.results)

+         results = Results(survey=self.jobs.survey, data=self.results)
          task_history = TaskHistory(self.total_interviews, include_traceback=False)
          results.task_history = task_history
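For reference, the `conditional_context` helper introduced in the last hunk either enters the supplied context manager and yields its value, or yields nothing, so the run loop can use a single code path whether or not a rich Live display is active. A standalone sketch of that behavior, using io.StringIO as a stand-in for the Live object:

    import io
    from contextlib import contextmanager

    @contextmanager
    def conditional_context(condition, context_manager):  # same shape as the helper above
        if condition:
            with context_manager as cm:
                yield cm
        else:
            yield

    with conditional_context(True, io.StringIO("x")) as cm:
        print(type(cm).__name__)  # StringIO: the wrapped manager was entered
    with conditional_context(False, io.StringIO("x")) as cm:
        print(cm)  # None: callers simply test `if cm:` as process_results does

Note that the second argument is evaluated either way, which matches how run() builds the Live table up front even when the progress bar is disabled.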