vision-agent 0.2.219__tar.gz → 0.2.221__tar.gz

Files changed (47)
  1. {vision_agent-0.2.219 → vision_agent-0.2.221}/PKG-INFO +5 -9
  2. {vision_agent-0.2.219 → vision_agent-0.2.221}/README.md +4 -8
  3. {vision_agent-0.2.219 → vision_agent-0.2.221}/pyproject.toml +1 -1
  4. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/.sim_tools/df.csv +21 -3
  5. vision_agent-0.2.221/vision_agent/.sim_tools/embs.npy +0 -0
  6. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_coder_v2.py +3 -3
  7. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_planner_prompts_v2.py +4 -3
  8. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/__init__.py +1 -1
  9. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/planner_tools.py +4 -5
  10. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/tools.py +29 -18
  11. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/__init__.py +0 -1
  12. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/execute.py +2 -2
  13. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/image_utils.py +1 -1
  14. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/sim.py +44 -3
  15. vision_agent-0.2.219/vision_agent/.sim_tools/embs.npy +0 -0
  16. {vision_agent-0.2.219 → vision_agent-0.2.221}/LICENSE +0 -0
  17. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/__init__.py +0 -0
  18. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/README.md +0 -0
  19. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/__init__.py +0 -0
  20. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/agent.py +0 -0
  21. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/agent_utils.py +0 -0
  22. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/types.py +0 -0
  23. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent.py +0 -0
  24. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_coder.py +0 -0
  25. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_coder_prompts.py +0 -0
  26. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_coder_prompts_v2.py +0 -0
  27. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_planner.py +0 -0
  28. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_planner_prompts.py +0 -0
  29. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_planner_v2.py +0 -0
  30. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_prompts.py +0 -0
  31. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_prompts_v2.py +0 -0
  32. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_v2.py +0 -0
  33. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/clients/__init__.py +0 -0
  34. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/clients/http.py +0 -0
  35. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/clients/landing_public_api.py +0 -0
  36. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/fonts/__init__.py +0 -0
  37. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
  38. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/lmm/__init__.py +0 -0
  39. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/lmm/lmm.py +0 -0
  40. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/lmm/types.py +0 -0
  41. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/meta_tools.py +0 -0
  42. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/prompts.py +0 -0
  43. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/tool_utils.py +0 -0
  44. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/tools_types.py +0 -0
  45. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/exceptions.py +0 -0
  46. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/type_defs.py +0 -0
  47. {vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/video.py +0 -0

{vision_agent-0.2.219 → vision_agent-0.2.221}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vision-agent
- Version: 0.2.219
+ Version: 0.2.221
  Summary: Toolset for Vision Agent
  Author: Landing AI
  Author-email: dev@landing.ai
@@ -89,18 +89,15 @@ To get started with the python library, you can install it using pip:
  pip install vision-agent
  ```

- Ensure you have both an Anthropic key and an OpenAI API key and set in your environment
- variables (if you are using Azure OpenAI please see the Azure setup section):
-
  ```bash
  export ANTHROPIC_API_KEY="your-api-key"
- export OPENAI_API_KEY="your-api-key"
  ```

  ---
  **NOTE**
- You must have both Anthropic and OpenAI API keys set in your environment variables to
- use VisionAgent. If you don't have an Anthropic key you can use Ollama as a backend.
+ You must have the Anthropic API key set in your environment variables to use
+ VisionAgent. If you don't have an Anthropic key you can use another provider like
+ OpenAI or Ollama.
  ---

  #### Chatting with VisionAgent
@@ -161,8 +158,7 @@ Anthropic/OpenAI models.
  ### Chatting and Message Formats
  `VisionAgent` is an agent that can chat with you and call other tools or agents to
  write vision code for you. You can interact with it like you would ChatGPT or any other
- chatbot. The agent uses Clause-3.5 for it's LMM and OpenAI for embeddings for searching
- for tools.
+ chatbot. The agent uses Clause-3.5 for it's LMM.

  The message format is:
  ```json

{vision_agent-0.2.219 → vision_agent-0.2.221}/README.md
@@ -44,18 +44,15 @@ To get started with the python library, you can install it using pip:
  pip install vision-agent
  ```

- Ensure you have both an Anthropic key and an OpenAI API key and set in your environment
- variables (if you are using Azure OpenAI please see the Azure setup section):
-
  ```bash
  export ANTHROPIC_API_KEY="your-api-key"
- export OPENAI_API_KEY="your-api-key"
  ```

  ---
  **NOTE**
- You must have both Anthropic and OpenAI API keys set in your environment variables to
- use VisionAgent. If you don't have an Anthropic key you can use Ollama as a backend.
+ You must have the Anthropic API key set in your environment variables to use
+ VisionAgent. If you don't have an Anthropic key you can use another provider like
+ OpenAI or Ollama.
  ---

  #### Chatting with VisionAgent
@@ -116,8 +113,7 @@ Anthropic/OpenAI models.
  ### Chatting and Message Formats
  `VisionAgent` is an agent that can chat with you and call other tools or agents to
  write vision code for you. You can interact with it like you would ChatGPT or any other
- chatbot. The agent uses Clause-3.5 for it's LMM and OpenAI for embeddings for searching
- for tools.
+ chatbot. The agent uses Clause-3.5 for it's LMM.

  The message format is:
  ```json

{vision_agent-0.2.219 → vision_agent-0.2.221}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "vision-agent"
- version = "0.2.219"
+ version = "0.2.221"
  description = "Toolset for Vision Agent"
  authors = ["Landing AI <dev@landing.ai>"]
  readme = "README.md"

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/.sim_tools/df.csv
@@ -460,19 +460,37 @@ desc,doc,name
  -------
  >>> document_analysis(image)
  {'pages':
- [{'bbox': [0, 0, 1700, 2200],
- 'chunks': [{'bbox': [1371, 75, 1503, 112],
+ [{'bbox': [0, 0, 1.0, 1.0],
+ 'chunks': [{'bbox': [0.8, 0.1, 1.0, 0.2],
  'label': 'page_header',
  'order': 75
  'caption': 'Annual Report 2024',
  'summary': 'This annual report summarizes ...' },
- {'bbox': [201, 1119, 1497, 1647],
+ {'bbox': [0.2, 0.9, 0.9, 1.0],
  'label': table',
  'order': 1119,
  'caption': [{'Column 1': 'Value 1', 'Column 2': 'Value 2'},
  'summary': 'This table illustrates a trend of ...'},
  ],
  ",document_extraction
+ "'document_qa' is a tool that can answer any questions about arbitrary documents, presentations, or tables. It's very useful for document QA tasks, you can ask it a specific question or ask it to return a JSON object answering multiple questions about the document.","document_qa(prompt: str, image: numpy.ndarray) -> str:
+ 'document_qa' is a tool that can answer any questions about arbitrary documents,
+ presentations, or tables. It's very useful for document QA tasks, you can ask it a
+ specific question or ask it to return a JSON object answering multiple questions
+ about the document.
+
+ Parameters:
+ prompt (str): The question to be answered about the document image.
+ image (np.ndarray): The document image to analyze.
+
+ Returns:
+ str: The answer to the question based on the document's context.
+
+ Example
+ -------
+ >>> document_qa(image, question)
+ 'The answer to the question ...'
+ ",document_qa
  'video_temporal_localization' will run qwen2vl on each chunk_length_frames value selected for the video. It can detect multiple objects independently per chunk_length_frames given a text prompt such as a referring expression but does not track objects across frames. It returns a list of floats with a value of 1.0 if the objects are found in a given chunk_length_frames of the video.,"video_temporal_localization(prompt: str, frames: List[numpy.ndarray], model: str = 'qwen2vl', chunk_length_frames: Optional[int] = 2) -> List[float]:
  'video_temporal_localization' will run qwen2vl on each chunk_length_frames
  value selected for the video. It can detect multiple objects independently per
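
The registry entry above documents the new `document_qa` tool alongside `document_extraction`. As a quick illustration of the documented signature, a minimal usage sketch (the file name and question are hypothetical, and Landing AI API credentials are assumed to be configured):

```python
# Minimal sketch of calling the new document_qa tool per the signature
# documented above: document_qa(prompt: str, image: np.ndarray) -> str.
# The image path and question are placeholders.
from vision_agent.tools import document_qa, load_image

image = load_image("annual_report_page1.png")  # hypothetical document image
answer = document_qa("What is the total revenue reported?", image)
print(answer)
```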

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_coder_v2.py
@@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
  from rich.console import Console
  from rich.markup import escape

- import vision_agent.tools as T
+ import vision_agent.tools.tools as T
  from vision_agent.agent import AgentCoder, AgentPlanner
  from vision_agent.agent.agent_utils import (
  DefaultImports,
@@ -34,7 +34,7 @@ from vision_agent.utils.execute import (
  CodeInterpreterFactory,
  Execution,
  )
- from vision_agent.utils.sim import Sim
+ from vision_agent.utils.sim import Sim, get_tool_recommender

  _CONSOLE = Console()

@@ -316,7 +316,7 @@ class VisionAgentCoderV2(AgentCoder):
  elif isinstance(tool_recommender, Sim):
  self.tool_recommender = tool_recommender
  else:
- self.tool_recommender = T.get_tool_recommender()
+ self.tool_recommender = get_tool_recommender()

  self.verbose = verbose
  self.code_sandbox_runtime = code_sandbox_runtime
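
With this change the default tool recommender for `VisionAgentCoderV2` comes from `vision_agent.utils.sim.get_tool_recommender()` instead of the helper previously exposed in `vision_agent.tools.tools`. A minimal construction sketch, assuming the remaining constructor defaults are sufficient:

```python
# Sketch of the two construction paths after this change: with no
# tool_recommender argument the agent falls back to the cached recommender
# from vision_agent.utils.sim; an explicit Sim instance is still accepted.
from vision_agent.agent.vision_agent_coder_v2 import VisionAgentCoderV2
from vision_agent.utils.sim import get_tool_recommender

coder_default = VisionAgentCoderV2()  # resolves get_tool_recommender() internally
coder_custom = VisionAgentCoderV2(tool_recommender=get_tool_recommender())
```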

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/agent/vision_agent_planner_prompts_v2.py
@@ -440,16 +440,17 @@ PICK_PLAN = """
  """

  CATEGORIZE_TOOL_REQUEST = """
- You are given a task: {task} from the user. Your task is to extract the type of category this task belongs to, it can be one or more of the following:
+ You are given a task: "{task}" from the user. You must extract the type of category this task belongs to, it can be one or more of the following:
  - "object detection and counting" - detecting objects or counting objects from a text prompt in an image or video.
  - "classification" - classifying objects in an image given a text prompt.
  - "segmentation" - segmenting objects in an image or video given a text prompt.
  - "OCR" - extracting text from an image.
  - "VQA" - answering questions about an image or video, can also be used for text extraction.
+ - "DocQA" - answering questions about a document or extracting information from a document.
  - "video object tracking" - tracking objects in a video.
  - "depth and pose estimation" - estimating the depth or pose of objects in an image.

- Return the category or categories (comma separated) inside tags <category># your categories here</category>.
+ Return the category or categories (comma separated) inside tags <category># your categories here</category>. If you are unsure about a task, it is better to include more categories than less.
  """

  TEST_TOOLS = """
@@ -473,7 +474,7 @@ TEST_TOOLS = """
  {examples}

  **Instructions**:
- 1. List all the tools under **Tools** and the user request. Write a program to load the media and call every tool in parallel and print it's output along with other relevant information.
+ 1. List all the tools under **Tools** and the user request. Write a program to load the media and call the most relevant tools in parallel and print it's output along with other relevant information.
  2. Create a dictionary where the keys are the tool name and the values are the tool outputs. Remove numpy arrays from the printed dictionary.
  3. Your test case MUST run only on the given images which are {media}
  4. Print this final dictionary.
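
For context, the updated `CATEGORIZE_TOOL_REQUEST` prompt is consumed by `run_tool_testing` in `planner_tools.py` (next section). A small sketch of how it gets filled in, using a hypothetical task string:

```python
# Sketch of formatting the updated CATEGORIZE_TOOL_REQUEST prompt before it is
# sent to the LMM; the task string is a placeholder. A well-formed reply would
# wrap comma-separated categories in <category> tags, e.g.
# "<category>OCR, DocQA</category>".
from vision_agent.agent.vision_agent_planner_prompts_v2 import CATEGORIZE_TOOL_REQUEST

prompt = CATEGORIZE_TOOL_REQUEST.format(task="pull the line items out of this invoice")
print(prompt)
```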

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/__init__.py
@@ -43,7 +43,6 @@ from .tools import (
  flux_image_inpainting,
  generate_pose_image,
  get_tool_documentation,
- get_tool_recommender,
  gpt4o_image_vqa,
  gpt4o_video_vqa,
  load_image,
@@ -63,6 +62,7 @@ from .tools import (
  save_json,
  save_video,
  siglip_classification,
+ stella_embeddings,
  template_match,
  video_temporal_localization,
  vit_image_classification,

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/planner_tools.py
@@ -32,6 +32,7 @@ from vision_agent.utils.execute import (
  MimeType,
  )
  from vision_agent.utils.image_utils import convert_to_b64
+ from vision_agent.utils.sim import get_tool_recommender

  TOOL_FUNCTIONS = {tool.__name__: tool for tool in T.TOOLS}

@@ -116,13 +117,11 @@ def run_tool_testing(
  query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
  category = extract_tag(query, "category") # type: ignore
  if category is None:
- category = task
+ query = task
  else:
- category = (
- f"I need models from the {category.strip()} category of tools. {task}"
- )
+ query = f"{category.strip()}. {task}"

- tool_docs = T.get_tool_recommender().top_k(category, k=10, thresh=0.2)
+ tool_docs = get_tool_recommender().top_k(query, k=5, thresh=0.3)
  if exclude_tools is not None and len(exclude_tools) > 0:
  cleaned_tool_docs = []
  for tool_doc in tool_docs:
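
The retrieval query is now the LMM's category answer prepended to the raw task, and the index lookup was tightened from `k=10, thresh=0.2` to `k=5, thresh=0.3`. A sketch of the new flow with placeholder task and category values (it assumes `top_k` returns rows from the tool DataFrame, which carries a `name` column):

```python
# Sketch of the new retrieval query built in run_tool_testing: the category
# string is prepended to the task before querying the cached tool index.
# Task and category values are placeholders.
from vision_agent.utils.sim import get_tool_recommender

task = "count the cars in the parking lot"
category = "object detection and counting"  # hypothetical LMM answer
query = f"{category.strip()}. {task}" if category else task

tool_docs = get_tool_recommender().top_k(query, k=5, thresh=0.3)
print([doc["name"] for doc in tool_docs])  # assumes rows expose a 'name' field
```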

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/tools/tools.py
@@ -7,7 +7,6 @@ import urllib.request
  from base64 import b64encode
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from enum import Enum
- from functools import lru_cache
  from importlib import resources
  from pathlib import Path
  from typing import Any, Dict, List, Optional, Tuple, Union, cast
@@ -49,7 +48,6 @@ from vision_agent.utils.image_utils import (
  rle_decode,
  rle_decode_array,
  )
- from vision_agent.utils.sim import Sim, load_cached_sim
  from vision_agent.utils.video import (
  extract_frames_from_video,
  frames_to_bytes,
@@ -85,11 +83,6 @@ _OCR_URL = "https://app.landing.ai/ocr/v1/detect-text"
  _LOGGER = logging.getLogger(__name__)


- @lru_cache(maxsize=1)
- def get_tool_recommender() -> Sim:
- return load_cached_sim(TOOLS_DF)
-
-
  def _display_tool_trace(
  function_name: str,
  request: Dict[str, Any],
@@ -410,7 +403,7 @@ def owl_v2_video(
  _display_tool_trace(
  owl_v2_video.__name__,
  payload,
- detections[0],
+ detections,
  files,
  )
  return bboxes_formatted
@@ -2178,13 +2171,14 @@ def document_qa(
  prompt: str,
  image: np.ndarray,
  ) -> str:
- """'document_qa' is a tool that can answer any questions about arbitrary
- images of documents or presentations. It answers by analyzing the contextual document data
- and then using a model to answer specific questions. It returns text as an answer to the question.
+ """'document_qa' is a tool that can answer any questions about arbitrary documents,
+ presentations, or tables. It's very useful for document QA tasks, you can ask it a
+ specific question or ask it to return a JSON object answering multiple questions
+ about the document.

  Parameters:
- prompt (str): The question to be answered about the document image
- image (np.ndarray): The document image to analyze
+ prompt (str): The question to be answered about the document image.
+ image (np.ndarray): The document image to analyze.

  Returns:
  str: The answer to the question based on the document's context.
@@ -2203,7 +2197,7 @@ def document_qa(
  "model": "document-analysis",
  }

- data: dict[str, Any] = send_inference_request(
+ data: Dict[str, Any] = send_inference_request(
  payload=payload,
  endpoint_name="document-analysis",
  files=files,
@@ -2225,10 +2219,10 @@ def document_qa(
  data = normalize(data)

  prompt = f"""
- Document Context:
- {data}\n
- Question: {prompt}\n
- Please provide a clear, concise answer using only the information from the document. If the answer is not definitively contained in the document, say "I cannot find the answer in the provided document."
+ Document Context:
+ {data}\n
+ Question: {prompt}\n
+ Answer the question directly using only the information from the document, do not answer with any additional text besides the answer. If the answer is not definitively contained in the document, say "I cannot find the answer in the provided document."
  """

  lmm = AnthropicLMM()
@@ -2245,6 +2239,22 @@ def document_qa(
  return llm_output


+ def stella_embeddings(prompts: List[str]) -> List[np.ndarray]:
+ payload = {
+ "input": prompts,
+ "model": "stella1.5b",
+ }
+
+ data: Dict[str, Any] = send_inference_request(
+ payload=payload,
+ endpoint_name="embeddings",
+ v2=True,
+ metadata_payload={"function_name": "get_embeddings"},
+ is_form=True,
+ )
+ return [d["embedding"] for d in data] # type: ignore
+
+
  # Utility and visualization functions


@@ -2781,6 +2791,7 @@ FUNCTION_TOOLS = [
  qwen2_vl_images_vqa,
  qwen2_vl_video_vqa,
  document_extraction,
+ document_qa,
  video_temporal_localization,
  flux_image_inpainting,
  siglip_classification,
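
`stella_embeddings` is new public surface in `vision_agent.tools`; it posts the prompts to a hosted "embeddings" endpoint (model `stella1.5b`), so network access and Landing AI credentials are assumed. A minimal sketch:

```python
# Sketch of calling the new stella_embeddings tool directly; the prompts are
# sent to the hosted embeddings endpoint, so an API key is assumed to be
# configured in the environment.
from vision_agent.tools import stella_embeddings

vecs = stella_embeddings(["segment the road signs", "read the text on the receipt"])
print(len(vecs))    # one embedding per prompt
print(vecs[0][:8])  # leading values; dimensionality depends on the hosted model
```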

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/__init__.py
@@ -7,4 +7,3 @@ from .execute import (
  Result,
  )
  from .sim import AzureSim, OllamaSim, Sim, load_sim, merge_sim
- from .video import extract_frames_from_video, video_writer

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/execute.py
@@ -28,10 +28,10 @@ from nbclient import __version__ as nbclient_version
  from nbclient.exceptions import CellTimeoutError, DeadKernelError
  from nbclient.util import run_sync
  from nbformat.v4 import new_code_cell
+ from opentelemetry.context import get_current
+ from opentelemetry.trace import SpanKind, Status, StatusCode, get_tracer
  from pydantic import BaseModel, field_serializer
  from typing_extensions import Self
- from opentelemetry.trace import get_tracer, Status, StatusCode, SpanKind
- from opentelemetry.context import get_current

  from vision_agent.utils.exceptions import (
  RemoteSandboxCreationError,

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/image_utils.py
@@ -11,7 +11,7 @@ import numpy as np
  from PIL import Image, ImageDraw, ImageFont
  from PIL.Image import Image as ImageType

- from vision_agent.utils import extract_frames_from_video
+ from vision_agent.utils.video import extract_frames_from_video

  COLORS = [
  (158, 218, 229),

{vision_agent-0.2.219 → vision_agent-0.2.221}/vision_agent/utils/sim.py
@@ -12,6 +12,13 @@ import requests
  from openai import AzureOpenAI, OpenAI
  from scipy.spatial.distance import cosine # type: ignore

+ from vision_agent.tools.tools import TOOLS_DF, stella_embeddings
+
+
+ @lru_cache(maxsize=1)
+ def get_tool_recommender() -> "Sim":
+ return load_cached_sim(TOOLS_DF)
+

  @lru_cache(maxsize=512)
  def get_embedding(
@@ -27,13 +34,13 @@ def load_cached_sim(
  cached_dir_full_path = str(resources.files("vision_agent") / cached_dir)
  if os.path.exists(cached_dir_full_path):
  if tools_df is not None:
- if Sim.check_load(cached_dir_full_path, tools_df):
+ if StellaSim.check_load(cached_dir_full_path, tools_df):
  # don't pass sim_key to loaded Sim object or else it will re-calculate embeddings
- return Sim.load(cached_dir_full_path)
+ return StellaSim.load(cached_dir_full_path)
  if os.path.exists(cached_dir_full_path):
  shutil.rmtree(cached_dir_full_path)

- sim = Sim(tools_df, sim_key=sim_key)
+ sim = StellaSim(tools_df, sim_key=sim_key)
  sim.save(cached_dir_full_path)
  return sim

@@ -214,6 +221,40 @@ class OllamaSim(Sim):
  )


+ class StellaSim(Sim):
+ def __init__(
+ self,
+ df: pd.DataFrame,
+ sim_key: Optional[str] = None,
+ ) -> None:
+ self.df = df
+
+ def emb_call(text: List[str]) -> List[float]:
+ return stella_embeddings(text)[0] # type: ignore
+
+ self.emb_call = emb_call
+
+ if "embs" not in df.columns and sim_key is None:
+ raise ValueError("key is required if no column 'embs' is present.")
+
+ if sim_key is not None:
+ self.df["embs"] = self.df[sim_key].apply(
+ lambda x: get_embedding(emb_call, x)
+ )
+
+ @staticmethod
+ def load(
+ load_dir: Union[str, Path],
+ api_key: Optional[str] = None,
+ model: str = "stella1.5b",
+ ) -> "StellaSim":
+ load_dir = Path(load_dir)
+ df = pd.read_csv(load_dir / "df.csv")
+ embs = np.load(load_dir / "embs.npy")
+ df["embs"] = list(embs)
+ return StellaSim(df)
+
+
  def merge_sim(sim1: Sim, sim2: Sim) -> Sim:
  return Sim(pd.concat([sim1.df, sim2.df], ignore_index=True))

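
`StellaSim` reuses the `Sim` machinery but swaps the embedding call for `stella_embeddings`, and the `lru_cache`d `get_tool_recommender()` factory now lives here. A sketch of loading the packaged index the way `load_cached_sim` does (the `.sim_tools` path and the expected columns are assumptions based on the shipped `df.csv`/`embs.npy` files):

```python
# Sketch of loading the packaged tool index: StellaSim.load reads df.csv and
# embs.npy from the cache directory and attaches the vectors as an 'embs'
# column; get_tool_recommender() memoizes one instance per process.
from importlib import resources

from vision_agent.utils.sim import StellaSim, get_tool_recommender

cache_dir = str(resources.files("vision_agent") / ".sim_tools")
sim = StellaSim.load(cache_dir)
print(sim.df.columns.tolist())        # expected: desc, doc, name, embs

recommender = get_tool_recommender()  # cached instance backed by the same files
print(type(recommender).__name__)     # expected: "StellaSim"
```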