vision-agent 0.2.113__tar.gz → 0.2.115__tar.gz

Files changed (34)
  1. {vision_agent-0.2.113 → vision_agent-0.2.115}/PKG-INFO +49 -16
  2. {vision_agent-0.2.113 → vision_agent-0.2.115}/README.md +48 -15
  3. {vision_agent-0.2.113 → vision_agent-0.2.115}/pyproject.toml +1 -1
  4. vision_agent-0.2.115/vision_agent/agent/__init__.py +7 -0
  5. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/agent/agent_utils.py +25 -2
  6. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/agent/vision_agent_coder.py +69 -7
  7. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/lmm/lmm.py +40 -21
  8. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/utils/__init__.py +1 -1
  9. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/utils/execute.py +1 -1
  10. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/utils/sim.py +49 -9
  11. vision_agent-0.2.113/vision_agent/agent/__init__.py +0 -3
  12. {vision_agent-0.2.113 → vision_agent-0.2.115}/LICENSE +0 -0
  13. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/__init__.py +0 -0
  14. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/agent/agent.py +0 -0
  15. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/agent/vision_agent.py +0 -0
  16. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/agent/vision_agent_coder_prompts.py +0 -0
  17. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/agent/vision_agent_prompts.py +0 -0
  18. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/clients/__init__.py +0 -0
  19. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/clients/http.py +0 -0
  20. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/clients/landing_public_api.py +0 -0
  21. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/fonts/__init__.py +0 -0
  22. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  23. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/lmm/__init__.py +0 -0
  24. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/lmm/types.py +0 -0
  25. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/tools/__init__.py +0 -0
  26. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/tools/meta_tools.py +0 -0
  27. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/tools/prompts.py +0 -0
  28. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/tools/tool_utils.py +0 -0
  29. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/tools/tools.py +0 -0
  30. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/tools/tools_types.py +0 -0
  31. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/utils/exceptions.py +0 -0
  32. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/utils/image_utils.py +0 -0
  33. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/utils/type_defs.py +0 -0
  34. {vision_agent-0.2.113 → vision_agent-0.2.115}/vision_agent/utils/video.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vision-agent
- Version: 0.2.113
+ Version: 0.2.115
  Summary: Toolset for Vision Agent
  Author: Landing AI
  Author-email: dev@landing.ai
@@ -208,20 +208,18 @@ result = agent.chat_with_workflow(conv)

  ### Tools
  There are a variety of tools for the model or the user to use. Some are executed locally
- while others are hosted for you. You can also ask an LMM directly to build a tool for
- you. For example:
+ while others are hosted for you. You can easily access them yourself, for example if
+ you want to run `owl_v2` and visualize the output you can run:

  ```python
- >>> import vision_agent as va
- >>> lmm = va.lmm.OpenAILMM()
- >>> detector = lmm.generate_detector("Can you build a jar detector for me?")
- >>> detector(va.tools.load_image("jar.jpg"))
- [{"labels": ["jar",],
- "scores": [0.99],
- "bboxes": [
- [0.58, 0.2, 0.72, 0.45],
- ]
- }]
+ import vision_agent.tools as T
+ import matplotlib.pyplot as plt
+
+ image = T.load_image("dogs.jpg")
+ dets = T.owl_v2("dogs", image)
+ viz = T.overlay_bounding_boxes(image, dets)
+ plt.imshow(viz)
+ plt.show()
  ```

  You can also add custom tools to the agent:
@@ -254,6 +252,41 @@ function. Make sure the documentation is in the same format above with descripti
  `Parameters:`, `Returns:`, and `Example\n-------`. You can find an example use case
  [here](examples/custom_tools/) as this is what the agent uses to pick and use the tool.

+ ## Additional LLMs
+ ### Ollama
+ We also provide a `VisionAgentCoder` that uses Ollama. To get started you must download
+ a few models:
+
+ ```bash
+ ollama pull llama3.1
+ ollama pull mxbai-embed-large
+ ```
+
+ `llama3.1` is used for the `OllamaLMM` in `OllamaVisionAgentCoder`. Normally we would
+ use an actual LMM such as `llava`, but `llava` cannot handle the long context lengths
+ required by the agent. Since `llama3.1` cannot handle images you may see some
+ performance degradation. `mxbai-embed-large` is the embedding model used to look up
+ tools. You can use it just like you would use `VisionAgentCoder`:
+
+ ```python
+ >>> import vision_agent as va
+ >>> agent = va.agent.OllamaVisionAgentCoder()
+ >>> agent("Count the apples in the image", media="apples.jpg")
+ ```
+ > WARNING: VisionAgent doesn't work well unless the underlying LMM is sufficiently powerful. Do not expect good results or even working code with smaller models like Llama 3.1 8B.
+
+ ### Azure OpenAI
+ We also provide an `AzureVisionAgentCoder` that uses Azure OpenAI models. To get started
+ follow the Azure Setup section below. You can use it just like you would use
+ `VisionAgentCoder`:
+
+ ```python
+ >>> import vision_agent as va
+ >>> agent = va.agent.AzureVisionAgentCoder()
+ >>> agent("Count the apples in the image", media="apples.jpg")
+ ```
+
+
  ### Azure Setup
  If you want to use Azure OpenAI models, you need to have two OpenAI model deployments:

@@ -292,7 +325,7 @@ agent = va.agent.AzureVisionAgentCoder()
  2. Follow the instructions to purchase and manage your API credits.
  3. Ensure your API key is correctly configured in your project settings.

- Failure to have sufficient API credits may result in limited or no functionality for the features that rely on the OpenAI API.
-
- For more details on managing your API usage and credits, please refer to the OpenAI API documentation.
+ Failure to have sufficient API credits may result in limited or no functionality for
+ the features that rely on the OpenAI API. For more details on managing your API usage
+ and credits, please refer to the OpenAI API documentation.

README.md
@@ -168,20 +168,18 @@ result = agent.chat_with_workflow(conv)

  ### Tools
  There are a variety of tools for the model or the user to use. Some are executed locally
- while others are hosted for you. You can also ask an LMM directly to build a tool for
- you. For example:
+ while others are hosted for you. You can easily access them yourself, for example if
+ you want to run `owl_v2` and visualize the output you can run:

  ```python
- >>> import vision_agent as va
- >>> lmm = va.lmm.OpenAILMM()
- >>> detector = lmm.generate_detector("Can you build a jar detector for me?")
- >>> detector(va.tools.load_image("jar.jpg"))
- [{"labels": ["jar",],
- "scores": [0.99],
- "bboxes": [
- [0.58, 0.2, 0.72, 0.45],
- ]
- }]
+ import vision_agent.tools as T
+ import matplotlib.pyplot as plt
+
+ image = T.load_image("dogs.jpg")
+ dets = T.owl_v2("dogs", image)
+ viz = T.overlay_bounding_boxes(image, dets)
+ plt.imshow(viz)
+ plt.show()
  ```

  You can also add custom tools to the agent:
@@ -214,6 +212,41 @@ function. Make sure the documentation is in the same format above with descripti
  `Parameters:`, `Returns:`, and `Example\n-------`. You can find an example use case
  [here](examples/custom_tools/) as this is what the agent uses to pick and use the tool.

+ ## Additional LLMs
+ ### Ollama
+ We also provide a `VisionAgentCoder` that uses Ollama. To get started you must download
+ a few models:
+
+ ```bash
+ ollama pull llama3.1
+ ollama pull mxbai-embed-large
+ ```
+
+ `llama3.1` is used for the `OllamaLMM` in `OllamaVisionAgentCoder`. Normally we would
+ use an actual LMM such as `llava`, but `llava` cannot handle the long context lengths
+ required by the agent. Since `llama3.1` cannot handle images you may see some
+ performance degradation. `mxbai-embed-large` is the embedding model used to look up
+ tools. You can use it just like you would use `VisionAgentCoder`:
+
+ ```python
+ >>> import vision_agent as va
+ >>> agent = va.agent.OllamaVisionAgentCoder()
+ >>> agent("Count the apples in the image", media="apples.jpg")
+ ```
+ > WARNING: VisionAgent doesn't work well unless the underlying LMM is sufficiently powerful. Do not expect good results or even working code with smaller models like Llama 3.1 8B.
+
+ ### Azure OpenAI
+ We also provide an `AzureVisionAgentCoder` that uses Azure OpenAI models. To get started
+ follow the Azure Setup section below. You can use it just like you would use
+ `VisionAgentCoder`:
+
+ ```python
+ >>> import vision_agent as va
+ >>> agent = va.agent.AzureVisionAgentCoder()
+ >>> agent("Count the apples in the image", media="apples.jpg")
+ ```
+
+
  ### Azure Setup
  If you want to use Azure OpenAI models, you need to have two OpenAI model deployments:

@@ -252,6 +285,6 @@
  2. Follow the instructions to purchase and manage your API credits.
  3. Ensure your API key is correctly configured in your project settings.

- Failure to have sufficient API credits may result in limited or no functionality for the features that rely on the OpenAI API.
-
- For more details on managing your API usage and credits, please refer to the OpenAI API documentation.
+ Failure to have sufficient API credits may result in limited or no functionality for
+ the features that rely on the OpenAI API. For more details on managing your API usage
+ and credits, please refer to the OpenAI API documentation.
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "vision-agent"
- version = "0.2.113"
+ version = "0.2.115"
  description = "Toolset for Vision Agent"
  authors = ["Landing AI <dev@landing.ai>"]
  readme = "README.md"
vision_agent/agent/__init__.py (new file)
@@ -0,0 +1,7 @@
+ from .agent import Agent
+ from .vision_agent import VisionAgent
+ from .vision_agent_coder import (
+     AzureVisionAgentCoder,
+     OllamaVisionAgentCoder,
+     VisionAgentCoder,
+ )
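
With the rewritten `__init__.py`, all three coders are exported from `vision_agent.agent`. A minimal sketch of what that enables, assuming vision-agent 0.2.115 is installed (the selection dict is illustrative, not part of the package):

```python
# The new module exports make every coder importable from one place.
from vision_agent.agent import (
    AzureVisionAgentCoder,
    OllamaVisionAgentCoder,
    VisionAgentCoder,
)

# Hypothetical convenience mapping for picking a backend at runtime.
CODERS = {
    "openai": VisionAgentCoder,
    "azure": AzureVisionAgentCoder,
    "ollama": OllamaVisionAgentCoder,
}
agent = CODERS["openai"]()  # every constructor argument is optional
```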
vision_agent/agent/agent_utils.py
@@ -1,9 +1,24 @@
  import json
  import logging
+ import re
  import sys
- from typing import Any, Dict
+ from typing import Any, Dict, Optional

  logging.basicConfig(stream=sys.stdout)
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
+     json_pattern = r"\{.*\}"
+     match = re.search(json_pattern, json_str, re.DOTALL)
+     if match:
+         json_str = match.group()
+         try:
+             json_dict = json.loads(json_str)
+             return json_dict  # type: ignore
+         except json.JSONDecodeError:
+             return None
+     return None


  def extract_json(json_str: str) -> Dict[str, Any]:
@@ -18,8 +33,16 @@ def extract_json(json_str: str) -> Dict[str, Any]:
          json_str = json_str[json_str.find("```") + len("```") :]
          # get the last ``` not one from an intermediate string
          json_str = json_str[: json_str.find("}```")]
+     try:
+         json_dict = json.loads(json_str)
+     except json.JSONDecodeError as e:
+         json_dict = _extract_sub_json(json_str)
+         if json_dict is not None:
+             return json_dict  # type: ignore
+         error_msg = f"Could not extract JSON from the given str: {json_str}"
+         _LOGGER.exception(error_msg)
+         raise ValueError(error_msg) from e

-     json_dict = json.loads(json_str)
      return json_dict  # type: ignore


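
For context on the `extract_json` change above: when the cleaned string still fails to parse, the new `_extract_sub_json` helper pulls the outermost `{...}` span out of the surrounding text before giving up. A standalone mirror of that helper, reproduced here only to illustrate the fallback behavior:

```python
import json
import re
from typing import Any, Dict, Optional


def extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
    # Grab the widest {...} span (DOTALL lets it cross newlines) and try to parse it.
    match = re.search(r"\{.*\}", json_str, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            return None
    return None


# An LLM reply that wraps the JSON in prose no longer defeats extraction.
print(extract_sub_json('Sure! Here is the plan: {"best_plan": "plan1"} Hope that helps.'))
# -> {'best_plan': 'plan1'}
```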
vision_agent/agent/vision_agent_coder.py
@@ -28,11 +28,11 @@ from vision_agent.agent.vision_agent_coder_prompts import (
      TEST_PLANS,
      USER_REQ,
  )
- from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
+ from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
  from vision_agent.utils import CodeInterpreterFactory, Execution
  from vision_agent.utils.execute import CodeInterpreter
  from vision_agent.utils.image_utils import b64_to_pil
- from vision_agent.utils.sim import AzureSim, Sim
+ from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
  from vision_agent.utils.video import play_video

  logging.basicConfig(stream=sys.stdout)
@@ -267,7 +267,11 @@ def pick_plan(
              pass
          count += 1

-     if best_plan is None:
+     if (
+         best_plan is None
+         or "best_plan" not in best_plan
+         or ("best_plan" in best_plan and best_plan["best_plan"] not in plans)
+     ):
          best_plan = {"best_plan": list(plans.keys())[0]}

      if verbosity >= 1:
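
To make the intent of the widened `pick_plan` guard concrete, here is a small illustrative sketch (not the package's function) applying the same fallback to a few hypothetical model responses:

```python
from typing import Any, Dict, Optional

plans = {"plan1": "...", "plan2": "...", "plan3": "..."}


def validate_best_plan(best_plan: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # Same guard as the diff: fall back to the first plan when the model's pick
    # is missing, malformed, or names a plan that was never proposed.
    if (
        best_plan is None
        or "best_plan" not in best_plan
        or best_plan["best_plan"] not in plans
    ):
        best_plan = {"best_plan": list(plans.keys())[0]}
    return best_plan


print(validate_best_plan({"best_plan": "plan2"}))  # kept as-is
print(validate_best_plan({"plan": "plan2"}))       # missing key -> plan1
print(validate_best_plan({"best_plan": "plan9"}))  # unknown plan -> plan1
```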
@@ -589,8 +593,8 @@ class VisionAgentCoder(Agent):

      Example
      -------
-         >>> from vision_agent.agent import VisionAgentCoder
-         >>> agent = VisionAgentCoder()
+         >>> import vision_agent as va
+         >>> agent = va.agent.VisionAgentCoder()
          >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
      """

@@ -857,6 +861,64 @@ class VisionAgentCoder(Agent):
              self.report_progress_callback(data)


+ class OllamaVisionAgentCoder(VisionAgentCoder):
+     """VisionAgentCoder that uses Ollama models for planning, coding, testing.
+
+     Pre-requisites:
+     1. Run ollama pull llama3.1 for the LLM
+     2. Run ollama pull mxbai-embed-large for the embedding similarity model
+
+     Technically you should use a VLM such as llava but llava is not able to handle the
+     context length and crashes.
+
+     Example
+     -------
+         >>> import vision_agent as va
+         >>> agent = va.agent.OllamaVisionAgentCoder()
+         >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
+     """
+
+     def __init__(
+         self,
+         planner: Optional[LMM] = None,
+         coder: Optional[LMM] = None,
+         tester: Optional[LMM] = None,
+         debugger: Optional[LMM] = None,
+         tool_recommender: Optional[Sim] = None,
+         verbosity: int = 0,
+         report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
+     ) -> None:
+         super().__init__(
+             planner=(
+                 OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True)
+                 if planner is None
+                 else planner
+             ),
+             coder=(
+                 OllamaLMM(model_name="llama3.1", temperature=0.0)
+                 if coder is None
+                 else coder
+             ),
+             tester=(
+                 OllamaLMM(model_name="llama3.1", temperature=0.0)
+                 if tester is None
+                 else tester
+             ),
+             debugger=(
+                 OllamaLMM(model_name="llama3.1", temperature=0.0, json_mode=True)
+                 if debugger is None
+                 else debugger
+             ),
+             tool_recommender=(
+                 OllamaSim(T.TOOLS_DF, sim_key="desc")
+                 if tool_recommender is None
+                 else tool_recommender
+             ),
+             verbosity=verbosity,
+             report_progress_callback=report_progress_callback,
+         )
+
+
  class AzureVisionAgentCoder(VisionAgentCoder):
  """VisionAgentCoder that uses Azure OpenAI APIs for planning, coding, testing.

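
Because each component above defaults to an `OllamaLMM` only when the corresponding argument is `None`, individual pieces can be swapped out. A minimal usage sketch, assuming a local Ollama server with the models from the README already pulled (the `codellama` override is purely illustrative):

```python
import vision_agent as va
from vision_agent.lmm import OllamaLMM

# Default setup: llama3.1 for planning/coding/testing/debugging,
# mxbai-embed-large (via OllamaSim) for tool lookup.
agent = va.agent.OllamaVisionAgentCoder(verbosity=1)

# Swap in a different locally pulled model for code generation only.
agent = va.agent.OllamaVisionAgentCoder(
    coder=OllamaLMM(model_name="codellama", temperature=0.0),
    verbosity=1,
)
code = agent("Count the apples in the image", media="apples.jpg")
```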
@@ -866,8 +928,8 @@ class AzureVisionAgentCoder(VisionAgentCoder):

      Example
      -------
-         >>> from vision_agent import AzureVisionAgentCoder
-         >>> agent = AzureVisionAgentCoder()
+         >>> import vision_agent as va
+         >>> agent = va.agent.AzureVisionAgentCoder()
          >>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
      """

vision_agent/lmm/lmm.py
@@ -330,12 +330,28 @@ class OllamaLMM(LMM):
          model_name: str = "llava",
          base_url: Optional[str] = "http://localhost:11434/api",
          json_mode: bool = False,
+         num_ctx: int = 128_000,
          **kwargs: Any,
      ):
+         """Initializes the Ollama LMM. kwargs are passed as 'options' to the model.
+         More information on options can be found here
+         https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
+
+         Parameters:
+             model_name (str): The ollama name of the model.
+             base_url (str): The base URL of the Ollama API.
+             json_mode (bool): Whether to use JSON mode.
+             num_ctx (int): The context length for the model.
+             kwargs (Any): Additional options to pass to the model.
+         """
+
          self.url = base_url
          self.model_name = model_name
-         self.json_mode = json_mode
-         self.kwargs = kwargs
+         self.kwargs = {"options": kwargs}
+
+         if json_mode:
+             self.kwargs["format"] = "json"  # type: ignore
+         self.kwargs["options"]["num_cxt"] = num_ctx

      def __call__(
          self,
@@ -369,13 +385,14 @@ class OllamaLMM(LMM):
          url = f"{self.url}/chat"
          model = self.model_name
          messages = fixed_chat
-         data = {"model": model, "messages": messages}
+         data: Dict[str, Any] = {"model": model, "messages": messages}

          tmp_kwargs = self.kwargs | kwargs
          data.update(tmp_kwargs)
-         json_data = json.dumps(data)
          if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

+             json_data = json.dumps(data)
+
              def f() -> Iterator[Optional[str]]:
                  with requests.post(url, data=json_data, stream=True) as stream:
                      if stream.status_code != 200:
@@ -392,13 +409,14 @@ class OllamaLMM(LMM):

              return f()
          else:
-             stream = requests.post(url, data=json_data)
-             if stream.status_code != 200:
-                 raise ValueError(
-                     f"Request failed with status code {stream.status_code}"
-                 )
-             stream = stream.json()
-             return stream["message"]["content"]  # type: ignore
+             data["stream"] = False
+             json_data = json.dumps(data)
+             resp = requests.post(url, data=json_data)
+
+             if resp.status_code != 200:
+                 raise ValueError(f"Request failed with status code {resp.status_code}")
+             resp = resp.json()
+             return resp["message"]["content"]  # type: ignore

      def generate(
          self,
@@ -408,7 +426,7 @@ class OllamaLMM(LMM):
      ) -> Union[str, Iterator[Optional[str]]]:

          url = f"{self.url}/generate"
-         data = {
+         data: Dict[str, Any] = {
              "model": self.model_name,
              "prompt": prompt,
              "images": [],
@@ -416,13 +434,14 @@ class OllamaLMM(LMM):

          if media and len(media) > 0:
              for m in media:
-                 data["images"].append(encode_media(m))  # type: ignore
+                 data["images"].append(encode_media(m))

          tmp_kwargs = self.kwargs | kwargs
          data.update(tmp_kwargs)
-         json_data = json.dumps(data)
          if "stream" in tmp_kwargs and tmp_kwargs["stream"]:

+             json_data = json.dumps(data)
+
              def f() -> Iterator[Optional[str]]:
                  with requests.post(url, data=json_data, stream=True) as stream:
                      if stream.status_code != 200:
@@ -439,15 +458,15 @@ class OllamaLMM(LMM):

              return f()
          else:
-             stream = requests.post(url, data=json_data)
+             data["stream"] = False
+             json_data = json.dumps(data)
+             resp = requests.post(url, data=json_data)

-             if stream.status_code != 200:
-                 raise ValueError(
-                     f"Request failed with status code {stream.status_code}"
-                 )
+             if resp.status_code != 200:
+                 raise ValueError(f"Request failed with status code {resp.status_code}")

-             stream = stream.json()
-             return stream["response"]  # type: ignore
+             resp = resp.json()
+             return resp["response"]  # type: ignore


  class ClaudeSonnetLMM(LMM):
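
With the constructor change above, extra keyword arguments to `OllamaLMM` are forwarded to Ollama under `options`, JSON mode maps to `format="json"`, and non-streaming requests now send `"stream": false` explicitly. A brief usage sketch, assuming an Ollama server on the default port with `llava` pulled (`top_k` is just one example of a valid Ollama option):

```python
from vision_agent.lmm import OllamaLMM

# kwargs such as top_k end up in the request's "options" field alongside num_ctx.
lmm = OllamaLMM(model_name="llava", json_mode=False, num_ctx=32_000, top_k=20)

# Non-streaming generate call against /api/generate; pass stream=True for an iterator.
print(lmm.generate("Describe this image", media=["dogs.jpg"]))
```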
vision_agent/utils/__init__.py
@@ -6,5 +6,5 @@ from .execute import (
      Logs,
      Result,
  )
- from .sim import AzureSim, Sim, load_sim, merge_sim
+ from .sim import AzureSim, OllamaSim, Sim, load_sim, merge_sim
  from .video import extract_frames_from_video
vision_agent/utils/execute.py
@@ -532,7 +532,7 @@ print(f"Vision Agent version: {va_version}")"""

      @staticmethod
      def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl:  # type: ignore
-         template_name = os.environ.get("E2B_TEMPLATE_NAME", "nx3fagq7sgdliww9cvm3")
+         template_name = os.environ.get("E2B_TEMPLATE_NAME", "va-sandbox")
          _LOGGER.info(
              f"Creating a new E2BCodeInterpreter using template: {template_name}"
          )
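
The default E2B sandbox template above changes from a raw template id to the named `va-sandbox` template; it can still be overridden before the interpreter is created. A small sketch of the override (both values are placeholders for your own E2B configuration, and `CodeInterpreterFactory.new_instance` is assumed as the entry point that eventually reaches this code path):

```python
import os

# Must be set before the E2B-backed interpreter is constructed.
os.environ["E2B_TEMPLATE_NAME"] = "my-custom-template"  # placeholder template name
os.environ["E2B_API_KEY"] = "<your-e2b-api-key>"        # placeholder credential

from vision_agent.utils import CodeInterpreterFactory

interpreter = CodeInterpreterFactory.new_instance()
```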
vision_agent/utils/sim.py
@@ -1,20 +1,21 @@
  import os
  from functools import lru_cache
  from pathlib import Path
- from typing import Dict, List, Optional, Sequence, Union
+ from typing import Callable, Dict, List, Optional, Sequence, Union

  import numpy as np
  import pandas as pd
- from openai import AzureOpenAI, Client, OpenAI
+ import requests
+ from openai import AzureOpenAI, OpenAI
  from scipy.spatial.distance import cosine  # type: ignore


  @lru_cache(maxsize=512)
  def get_embedding(
-     client: Client, text: str, model: str = "text-embedding-3-small"
+     emb_call: Callable[[List[str]], List[float]], text: str
  ) -> List[float]:
      text = text.replace("\n", " ")
-     return client.embeddings.create(input=[text], model=model).data[0].embedding
+     return emb_call([text])


  class Sim:
@@ -35,14 +36,19 @@ class Sim:
              model: str: The model to use for embeddings.
          """
          self.df = df
-         self.client = OpenAI(api_key=api_key)
+         client = OpenAI(api_key=api_key)
+         self.emb_call = (
+             lambda text: client.embeddings.create(input=text, model=model)
+             .data[0]
+             .embedding
+         )
          self.model = model
          if "embs" not in df.columns and sim_key is None:
              raise ValueError("key is required if no column 'embs' is present.")

          if sim_key is not None:
              self.df["embs"] = self.df[sim_key].apply(
-                 lambda x: get_embedding(self.client, x, model=self.model)
+                 lambda x: get_embedding(self.emb_call, x)
              )

      def save(self, sim_file: Union[str, Path]) -> None:
@@ -70,7 +76,7 @@ class Sim:
              Sequence[Dict]: The top k most similar items.
          """

-         embedding = get_embedding(self.client, query, model=self.model)
+         embedding = get_embedding(self.emb_call, query)
          self.df["sim"] = self.df.embs.apply(lambda x: 1 - cosine(x, embedding))
          res = self.df.sort_values("sim", ascending=False).head(k)
          if thresh is not None:
@@ -105,17 +111,51 @@ class AzureSim(Sim):
              )

          self.df = df
-         self.client = AzureOpenAI(
+         client = AzureOpenAI(
              api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
          )
+         self.emb_call = (
+             lambda text: client.embeddings.create(input=text, model=model)
+             .data[0]
+             .embedding
+         )

          self.model = model
+         if "embs" not in df.columns and sim_key is None:
+             raise ValueError("key is required if no column 'embs' is present.")
+
+         if sim_key is not None:
+             self.df["embs"] = self.df[sim_key].apply(lambda x: get_embedding(client, x))
+
+
+ class OllamaSim(Sim):
+     def __init__(
+         self,
+         df: pd.DataFrame,
+         sim_key: Optional[str] = None,
+         model_name: Optional[str] = None,
+         base_url: Optional[str] = None,
+     ) -> None:
+         self.df = df
+         if base_url is None:
+             base_url = "http://localhost:11434/api/embeddings"
+         if model_name is None:
+             model_name = "mxbai-embed-large"
+
+         def emb_call(text: List[str]) -> List[float]:
+             resp = requests.post(
+                 base_url, json={"prompt": text[0], "model": model_name}
+             )
+             return resp.json()["embedding"]  # type: ignore
+
+         self.emb_call = emb_call
+
          if "embs" not in df.columns and sim_key is None:
              raise ValueError("key is required if no column 'embs' is present.")

          if sim_key is not None:
              self.df["embs"] = self.df[sim_key].apply(
-                 lambda x: get_embedding(self.client, x, model=self.model)
+                 lambda x: get_embedding(emb_call, x)
              )


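
`OllamaSim` offers the same tool-lookup interface as `Sim`/`AzureSim` but computes embeddings through the local Ollama embeddings endpoint; `OllamaVisionAgentCoder` wires it up with `T.TOOLS_DF` and `sim_key="desc"`. A minimal sketch with a toy DataFrame, assuming `mxbai-embed-large` has been pulled and that the query method is `top_k` as in the base class:

```python
import pandas as pd
from vision_agent.utils.sim import OllamaSim

# Toy stand-in for the package's tool table (T.TOOLS_DF).
df = pd.DataFrame(
    {
        "name": ["owl_v2", "ocr", "depth_estimation"],
        "desc": [
            "detect objects in an image given a text prompt",
            "read text from an image",
            "estimate per-pixel depth from a single image",
        ],
    }
)

sim = OllamaSim(df, sim_key="desc")  # embeds each description via the Ollama API
print(sim.top_k("find all the dogs in the photo", k=2))
```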
vision_agent-0.2.113/vision_agent/agent/__init__.py (removed)
@@ -1,3 +0,0 @@
- from .agent import Agent
- from .vision_agent import VisionAgent
- from .vision_agent_coder import AzureVisionAgentCoder, VisionAgentCoder