vision-agent 0.2.120__py3-none-any.whl → 0.2.122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -30,7 +30,7 @@ class BoilerplateCode:
  pre_code = [
  "from typing import *",
  "from vision_agent.utils.execute import CodeInterpreter",
- "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact",
+ "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, florence2_fine_tuning, use_florence2_fine_tuning",
  "artifacts = Artifacts('{remote_path}')",
  "artifacts.load('{remote_path}')",
  ]
@@ -76,11 +76,16 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:

  def run_code_action(
  code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
- ) -> Execution:
- return code_interpreter.exec_isolation(
+ ) -> Tuple[Execution, str]:
+ result = code_interpreter.exec_isolation(
  BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path)
  )

+ obs = str(result.logs)
+ if result.error:
+ obs += f"\n{result.error}"
+ return result, obs
+

  def parse_execution(response: str) -> Optional[str]:
  code = None
@@ -192,7 +197,7 @@ class VisionAgent(Agent):
  artifacts = Artifacts(WORKSPACE / "artifacts.pkl")

  with CodeInterpreterFactory.new_instance(
- code_sandbox_runtime=self.code_sandbox_runtime
+ code_sandbox_runtime=self.code_sandbox_runtime,
  ) as code_interpreter:
  orig_chat = copy.deepcopy(chat)
  int_chat = copy.deepcopy(chat)
@@ -260,10 +265,9 @@ class VisionAgent(Agent):
  code_action = parse_execution(response["response"])

  if code_action is not None:
- result = run_code_action(
+ result, obs = run_code_action(
  code_action, code_interpreter, str(remote_artifacts_path)
  )
- obs = str(result.logs)

  if self.verbosity >= 1:
  _LOGGER.info(obs)
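For orientation, a hedged sketch (not part of the diff) of how the new `(Execution, str)` return value of `run_code_action` is consumed; the code string and artifact path are illustrative:

```python
from vision_agent.agent.vision_agent import run_code_action
from vision_agent.utils import CodeInterpreterFactory

# Sketch: run_code_action now returns the Execution object plus a ready-made
# observation string (logs with any error text appended), so callers no longer
# rebuild the observation themselves.
with CodeInterpreterFactory.new_instance() as code_interpreter:
    result, obs = run_code_action(
        "print(get_tool_descriptions())",  # illustrative code_action string
        code_interpreter,
        "artifacts.pkl",                   # illustrative artifact_remote_path
    )
    print(obs)
```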
@@ -1,5 +1,4 @@
  import copy
- import difflib
  import logging
  import os
  import sys
@@ -29,6 +28,7 @@ from vision_agent.agent.vision_agent_coder_prompts import (
  USER_REQ,
  )
  from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
+ from vision_agent.tools.meta_tools import get_diff
  from vision_agent.utils import CodeInterpreterFactory, Execution
  from vision_agent.utils.execute import CodeInterpreter
  from vision_agent.utils.image_utils import b64_to_pil
@@ -63,14 +63,6 @@ class DefaultImports:
  return DefaultImports.to_code_string() + "\n\n" + code


- def get_diff(before: str, after: str) -> str:
- return "".join(
- difflib.unified_diff(
- before.splitlines(keepends=True), after.splitlines(keepends=True)
- )
- )
-
-
  def format_memory(memory: List[Dict[str, str]]) -> str:
  output_str = ""
  for i, m in enumerate(memory):
@@ -81,20 +81,19 @@ plan2:
  - Count the number of detected objects labeled as 'person'.
  plan3:
  - Load the image from the provided file path 'image.jpg'.
- - Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
+ - Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.

  ```python
- from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
+ from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
  image = load_image("image.jpg")
  owl_v2_out = owl_v2("person", image)

  gsam_out = grounding_sam("person", image)
  gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]

- loca_out = loca_zero_shot_counting(image)
- loca_out = loca_out["count"]
+ cgd_out = countgd_counting(image)

- final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
+ final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
  print(final_out)
  ```
  """
@@ -48,7 +48,7 @@ OBSERVATION:
  4| return dogs
  [End of artifact]

- AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
+ AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}

  OBSERVATION:
  ----- stdout -----
@@ -75,7 +75,7 @@ OBSERVATION:
  4| return dogs
  [End of artifact]

- AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
+ AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}

  OBSERVATION:
  ----- stdout -----
@@ -126,7 +126,7 @@ OBSERVATION:
  15| return count
  [End of artifact]

- AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code to get the output and write the visualization to the artifacts so the user can see it.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
+ AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}

  OBSERVATION:
  ----- stdout -----
vision_agent/lmm/lmm.py CHANGED
@@ -286,9 +286,6 @@ class OpenAILMM(LMM):

  return lambda x: T.grounding_sam(params["prompt"], x)

- def generate_zero_shot_counter(self, question: str) -> Callable:
- return T.loca_zero_shot_counting
-
  def generate_image_qa_tool(self, question: str) -> Callable:
  return lambda x: T.git_vqa_v2(question, x)

@@ -37,10 +37,13 @@ from .tools import (
  load_image,
  loca_visual_prompt_counting,
  loca_zero_shot_counting,
+ countgd_counting,
+ countgd_example_based_counting,
  ocr,
  overlay_bounding_boxes,
  overlay_heat_map,
  overlay_segmentation_masks,
+ overlay_counting_results,
  owl_v2,
  save_image,
  save_json,
@@ -1,5 +1,7 @@
+ import difflib
  import os
  import pickle as pkl
+ import re
  import subprocess
  import tempfile
  from pathlib import Path
@@ -8,10 +10,13 @@ from typing import Any, Dict, List, Optional, Union
  from IPython.display import display

  import vision_agent as va
+ from vision_agent.clients.landing_public_api import LandingPublicAPI
  from vision_agent.lmm.types import Message
  from vision_agent.tools.tool_utils import get_tool_documentation
  from vision_agent.tools.tools import TOOL_DESCRIPTIONS
+ from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
  from vision_agent.utils.execute import Execution, MimeType
+ from vision_agent.utils.image_utils import convert_to_b64

  # These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

@@ -99,13 +104,14 @@ class Artifacts:

  def show(self) -> str:
  """Shows the artifacts that have been loaded and their remote save paths."""
- out_str = "[Artifacts loaded]\n"
+ output_str = "[Artifacts loaded]\n"
  for k in self.artifacts.keys():
- out_str += (
+ output_str += (
  f"Artifact {k} loaded to {str(self.remote_save_path.parent / k)}\n"
  )
- out_str += "[End of artifacts]\n"
- return out_str
+ output_str += "[End of artifacts]\n"
+ print(output_str)
+ return output_str

  def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
  save_path = (
@@ -135,7 +141,12 @@ def format_lines(lines: List[str], start_idx: int) -> str:


  def view_lines(
- lines: List[str], line_num: int, window_size: int, name: str, total_lines: int
+ lines: List[str],
+ line_num: int,
+ window_size: int,
+ name: str,
+ total_lines: int,
+ print_output: bool = True,
  ) -> str:
  start = max(0, line_num - window_size)
  end = min(len(lines), line_num + window_size)
@@ -148,7 +159,9 @@ def view_lines(
  else f"[{len(lines) - end} more lines]"
  )
  )
- print(return_str)
+
+ if print_output:
+ print(return_str)
  return return_str


@@ -231,7 +244,7 @@ def edit_code_artifact(
  new_content_lines = [
  line if line.endswith("\n") else line + "\n" for line in new_content_lines
  ]
- lines = artifacts[name].splitlines()
+ lines = artifacts[name].splitlines(keepends=True)
  edited_lines = lines[:start] + new_content_lines + lines[end:]

  cur_line = start + len(content.split("\n")) // 2
@@ -261,13 +274,20 @@ def edit_code_artifact(
  DEFAULT_WINDOW_SIZE,
  name,
  total_lines,
+ print_output=False,
  )
  total_lines_edit = sum(1 for _ in edited_lines)
  edited_view = view_lines(
- edited_lines, cur_line, DEFAULT_WINDOW_SIZE, name, total_lines_edit
+ edited_lines,
+ cur_line,
+ DEFAULT_WINDOW_SIZE,
+ name,
+ total_lines_edit,
+ print_output=False,
  )

  error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
+ print(error_msg)
  return error_msg

  artifacts[name] = "".join(edited_lines)
@@ -390,6 +410,13 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
  return f"[Media {Path(local_path).name} saved]"


+ def list_artifacts(artifacts: Artifacts) -> str:
+ """Lists all the artifacts that have been loaded into the artifacts object."""
+ output_str = artifacts.show()
+ print(output_str)
+ return output_str
+
+

  def get_tool_descriptions() -> str:
  """Returns a description of all the tools that `generate_vision_code` has access to.
@@ -397,6 +424,108 @@ def get_tool_descriptions() -> str:
  Helpful for answering questions about what types of vision tasks you can do with
  return TOOL_DESCRIPTIONS

+ def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
+ """'florence2_fine_tuning' is a tool that fine-tune florence2 to be able to detect
+ objects in an image based on a given dataset. It returns the fine tuning job id.
+
+ Parameters:
+ bboxes (List[BboxInput]): A list of BboxInput containing the
+ image path, labels and bounding boxes.
+ task (str): The florencev2 fine-tuning task. The options are
+ 'phrase_grounding'.
+
+ Returns:
+ UUID: The fine tuning job id, this id will used to retrieve the fine
+ tuned model.
+
+ Example
+ -------
+ >>> fine_tuning_job_id = florencev2_fine_tuning(
+ [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
+ {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
+ "phrase_grounding"
+ )
+ """
+ bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
+ task_type = PromptTask[task.upper()]
+ fine_tuning_request = [
+ BboxInputBase64(
+ image=convert_to_b64(bbox_input.image_path),
+ filename=Path(bbox_input.image_path).name,
+ labels=bbox_input.labels,
+ bboxes=bbox_input.bboxes,
+ )
+ for bbox_input in bboxes_input
+ ]
+ landing_api = LandingPublicAPI()
+ fine_tune_id = str(
+ landing_api.launch_fine_tuning_job("florencev2", task_type, fine_tuning_request)
+ )
+ print(f"[Florence2 fine tuning id: {fine_tune_id}]")
+ return fine_tune_id
+
+
+ def get_diff(before: str, after: str) -> str:
+ return "".join(
+ difflib.unified_diff(
+ before.splitlines(keepends=True), after.splitlines(keepends=True)
+ )
+ )
+
+
+ def use_florence2_fine_tuning(
+ artifacts: Artifacts, name: str, task: str, fine_tune_id: str
+ ) -> str:
+ """Replaces florence2 calls with the fine tuning id. This ensures that the code
+ utilizes the fined tuned florence2 model. Returns the diff between the original
+ code and the new code.
+
+ Parameters:
+ artifacts (Artifacts): The artifacts object to edit the code from.
+ name (str): The name of the artifact to edit.
+ task (str): The task to fine tune the model for. The options are
+ 'phrase_grounding'.
+ fine_tune_id (str): The fine tuning job id.
+
+ Examples
+ --------
+ >>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
+ """
+
+ task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"}
+
+ if name not in artifacts:
+ output_str = f"[Artifact {name} does not exist]"
+ print(output_str)
+ return output_str
+
+ code = artifacts[name]
+ if task.lower() == "phrase_grounding":
+ pattern = r"florence2_phrase_grounding\(\s*([^\)]+)\)"
+
+ def replacer(match: re.Match) -> str:
+ arg = match.group(1)  # capture all initial arguments
+ return f'florence2_phrase_grounding({arg}, "{fine_tune_id}")'
+
+ else:
+ raise ValueError(f"Task {task} is not supported.")
+
+ new_code = re.sub(pattern, replacer, code)
+
+ if new_code == code:
+ output_str = (
+ f"[Fine tuning task {task} function {task_to_fn[task]} not found in code]"
+ )
+ print(output_str)
+ return output_str
+
+ artifacts[name] = new_code
+
+ diff = get_diff(code, new_code)
+ print(diff)
+ return diff
+
+
  META_TOOL_DOCSTRING = get_tool_documentation(
  [
  get_tool_descriptions,
@@ -406,5 +535,8 @@ META_TOOL_DOCSTRING = get_tool_documentation(
  generate_vision_code,
  edit_vision_code,
  write_media_artifact,
+ florence2_fine_tuning,
+ use_florence2_fine_tuning,
+ list_artifacts,
  ]
  )
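For context, a hedged end-to-end sketch of the two new meta tools (image paths, labels and artifact names are illustrative; the fine-tuning job must reach SUCCEEDED before the returned id is usable):

```python
from vision_agent.tools.meta_tools import (
    Artifacts,
    florence2_fine_tuning,
    use_florence2_fine_tuning,
)

# Sketch: launch a phrase-grounding fine-tuning job from labeled boxes ...
bboxes = [
    {"image_path": "screws_1.png", "labels": ["screw"], "bboxes": [[370, 30, 560, 290]]},
    {"image_path": "screws_2.png", "labels": ["screw"], "bboxes": [[120, 0, 300, 170]]},
]
fine_tune_id = florence2_fine_tuning(bboxes, "phrase_grounding")

# ... then rewrite florence2_phrase_grounding(...) calls inside a code artifact
# so they pass the fine-tuned model id; the return value is a unified diff.
artifacts = Artifacts("artifacts.pkl")  # illustrative local artifacts store
artifacts["code.py"] = 'dets = florence2_phrase_grounding("screw", image)\n'
diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", fine_tune_id)
print(diff)
```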
@@ -1,6 +1,6 @@
+ import os
  import inspect
  import logging
- import os
  from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

  import pandas as pd
@@ -13,6 +13,7 @@ from urllib3.util.retry import Retry
  from vision_agent.utils.exceptions import RemoteToolCallFailed
  from vision_agent.utils.execute import Error, MimeType
  from vision_agent.utils.type_defs import LandingaiAPIKey
+ from vision_agent.tools.tools_types import BoundingBoxes

  _LOGGER = logging.getLogger(__name__)
  _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -34,61 +35,58 @@ def send_inference_request(
  files: Optional[List[Tuple[Any, ...]]] = None,
  v2: bool = False,
  metadata_payload: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Any]:
+ ) -> Any:
  # TODO: runtime_tag and function_name should be metadata_payload and now included
  # in the service payload
- try:
- if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
- payload["runtime_tag"] = runtime_tag
+ if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
+ payload["runtime_tag"] = runtime_tag
+
+ url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
+ if "TOOL_ENDPOINT_URL" in os.environ:
+ url = os.environ["TOOL_ENDPOINT_URL"]
+
+ headers = {"apikey": _LND_API_KEY}
+ if "TOOL_ENDPOINT_AUTH" in os.environ:
+ headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
+ headers.pop("apikey")
+
+ session = _create_requests_session(
+ url=url,
+ num_retry=3,
+ headers=headers,
+ )

- url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
- if "TOOL_ENDPOINT_URL" in os.environ:
- url = os.environ["TOOL_ENDPOINT_URL"]
+ function_name = "unknown"
+ if "function_name" in payload:
+ function_name = payload["function_name"]
+ elif metadata_payload is not None and "function_name" in metadata_payload:
+ function_name = metadata_payload["function_name"]

- tool_call_trace = ToolCallTrace(
- endpoint_url=url,
- request=payload,
- response={},
- error=None,
- )
- headers = {"apikey": _LND_API_KEY}
- if "TOOL_ENDPOINT_AUTH" in os.environ:
- headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
- headers.pop("apikey")
-
- session = _create_requests_session(
- url=url,
- num_retry=3,
- headers=headers,
- )
+ response = _call_post(url, payload, session, files, function_name)

- if files is not None:
- res = session.post(url, data=payload, files=files)
- else:
- res = session.post(url, json=payload)
- if res.status_code != 200:
- tool_call_trace.error = Error(
- name="RemoteToolCallFailed",
- value=f"{res.status_code} - {res.text}",
- traceback_raw=[],
- )
- _LOGGER.error(f"Request failed: {res.status_code} {res.text}")
- # TODO: function_name should be in metadata_payload
- function_name = "unknown"
- if "function_name" in payload:
- function_name = payload["function_name"]
- elif metadata_payload is not None and "function_name" in metadata_payload:
- function_name = metadata_payload["function_name"]
- raise RemoteToolCallFailed(function_name, res.status_code, res.text)
-
- resp = res.json()
- tool_call_trace.response = resp
- # TODO: consider making the response schema the same between below two sources
- return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
- finally:
- trace = tool_call_trace.model_dump()
- trace["type"] = "tool_call"
- display({MimeType.APPLICATION_JSON: trace}, raw=True)
+ # TODO: consider making the response schema the same between below two sources
+ return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
+
+
+ def send_task_inference_request(
+ payload: Dict[str, Any],
+ task_name: str,
+ files: Optional[List[Tuple[Any, ...]]] = None,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> Any:
+ url = f"{_LND_API_URL_v2}/{task_name}"
+ headers = {"apikey": _LND_API_KEY}
+ session = _create_requests_session(
+ url=url,
+ num_retry=3,
+ headers=headers,
+ )
+
+ function_name = "unknown"
+ if metadata is not None and "function_name" in metadata:
+ function_name = metadata["function_name"]
+ response = _call_post(url, payload, session, files, function_name)
+ return response["data"]


  def _create_requests_session(
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
  data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"

  return data
+
+
+ def _call_post(
+ url: str,
+ payload: dict[str, Any],
+ session: Session,
+ files: Optional[List[Tuple[Any, ...]]] = None,
+ function_name: str = "unknown",
+ ) -> Any:
+ try:
+ tool_call_trace = ToolCallTrace(
+ endpoint_url=url,
+ request=payload,
+ response={},
+ error=None,
+ )
+
+ if files is not None:
+ response = session.post(url, data=payload, files=files)
+ else:
+ response = session.post(url, json=payload)
+
+ if response.status_code != 200:
+ tool_call_trace.error = Error(
+ name="RemoteToolCallFailed",
+ value=f"{response.status_code} - {response.text}",
+ traceback_raw=[],
+ )
+ _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
+ raise RemoteToolCallFailed(
+ function_name, response.status_code, response.text
+ )
+
+ result = response.json()
+ tool_call_trace.response = result
+ return result
+ finally:
+ trace = tool_call_trace.model_dump()
+ trace["type"] = "tool_call"
+ display({MimeType.APPLICATION_JSON: trace}, raw=True)
+
+
+ def filter_bboxes_by_threshold(
+ bboxes: BoundingBoxes, threshold: float
+ ) -> BoundingBoxes:
+ return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
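A small sketch of the new `filter_bboxes_by_threshold` helper, using the `ODResponseData` model added in `tools_types.py` further down (scores are invented):

```python
from vision_agent.tools.tool_utils import filter_bboxes_by_threshold
from vision_agent.tools.tools_types import ODResponseData

# Sketch: drop detections whose score is below the threshold.
bboxes = [
    ODResponseData(label="flower", score=0.18, bounding_box=[0.10, 0.11, 0.35, 0.40]),
    ODResponseData(label="flower", score=0.68, bounding_box=[0.20, 0.21, 0.45, 0.50]),
]
kept = filter_bboxes_by_threshold(bboxes, 0.23)
print([b.model_dump() for b in kept])  # only the 0.68 detection remains
```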
@@ -13,7 +13,7 @@ import cv2
  import numpy as np
  import requests
  from moviepy.editor import ImageSequenceClip
- from PIL import Image, ImageDraw, ImageFont
+ from PIL import Image, ImageDraw, ImageFont, ImageEnhance
  from pillow_heif import register_heif_opener  # type: ignore
  from pytube import YouTube  # type: ignore

@@ -24,14 +24,15 @@ from vision_agent.tools.tool_utils import (
  get_tools_df,
  get_tools_info,
  send_inference_request,
+ send_task_inference_request,
+ filter_bboxes_by_threshold,
  )
  from vision_agent.tools.tools_types import (
- BboxInput,
- BboxInputBase64,
  FineTuning,
- Florencev2FtRequest,
+ Florence2FtRequest,
  JobStatus,
  PromptTask,
+ ODResponseData,
  )
  from vision_agent.utils import extract_frames_from_video
  from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -455,7 +456,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
  "image": image_b64,
  "function_name": "loca_zero_shot_counting",
  }
- resp_data = send_inference_request(data, "loca", v2=True)
+ resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
  resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
  return resp_data

@@ -469,6 +470,8 @@ def loca_visual_prompt_counting(

  Parameters:
  image (np.ndarray): The image that contains lot of instances of a single object
+ visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
+ [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.

  Returns:
  Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -496,11 +499,109 @@ def loca_visual_prompt_counting(
  "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
  "function_name": "loca_visual_prompt_counting",
  }
- resp_data = send_inference_request(data, "loca", v2=True)
+ resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
  resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
  return resp_data


+ def countgd_counting(
+ prompt: str,
+ image: np.ndarray,
+ box_threshold: float = 0.23,
+ ) -> List[Dict[str, Any]]:
+ """'countgd_counting' is a tool that can precisely count multiple instances of an
+ object given a text prompt. It returns a list of bounding boxes with normalized
+ coordinates, label names and associated confidence scores.
+
+ Parameters:
+ prompt (str): The object that needs to be counted.
+ image (np.ndarray): The image that contains multiple instances of the object.
+ box_threshold (float, optional): The threshold for detection. Defaults
+ to 0.23.
+
+ Returns:
+ List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+ bounding box of the detected objects with normalized coordinates between 0
+ and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+ top-left and xmax and ymax are the coordinates of the bottom-right of the
+ bounding box.
+
+ Example
+ -------
+ >>> countgd_counting("flower", image)
+ [
+ {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+ {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5},
+ {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52},
+ {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58},
+ ]
+ """
+ buffer_bytes = numpy_to_bytes(image)
+ files = [("image", buffer_bytes)]
+ prompt = prompt.replace(", ", " .")
+ payload = {"prompts": [prompt], "model": "countgd"}
+ metadata = {"function_name": "countgd_counting"}
+ resp_data = send_task_inference_request(
+ payload, "text-to-object-detection", files=files, metadata=metadata
+ )
+ bboxes_per_frame = resp_data[0]
+ bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+ filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+ return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
+ def countgd_example_based_counting(
+ visual_prompts: List[List[float]],
+ image: np.ndarray,
+ box_threshold: float = 0.23,
+ ) -> List[Dict[str, Any]]:
+ """'countgd_example_based_counting' is a tool that can precisely count multiple
+ instances of an object given few visual example prompts. It returns a list of bounding
+ boxes with normalized coordinates, label names and associated confidence scores.
+
+ Parameters:
+ visual_prompts (List[List[float]]): Bounding boxes of the object in format
+ [xmin, ymin, xmax, ymax]. Upto 3 bounding boxes can be provided.
+ image (np.ndarray): The image that contains multiple instances of the object.
+ box_threshold (float, optional): The threshold for detection. Defaults
+ to 0.23.
+
+ Returns:
+ List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+ bounding box of the detected objects with normalized coordinates between 0
+ and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+ top-left and xmax and ymax are the coordinates of the bottom-right of the
+ bounding box.
+
+ Example
+ -------
+ >>> countgd_example_based_counting(
+ visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
+ image=image
+ )
+ [
+ {'score': 0.49, 'label': 'object', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
+ {'score': 0.68, 'label': 'object', 'bounding_box': [0.2, 0.21, 0.45, 0.5},
+ {'score': 0.78, 'label': 'object', 'bounding_box': [0.3, 0.35, 0.48, 0.52},
+ {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58},
+ ]
+ """
+ buffer_bytes = numpy_to_bytes(image)
+ files = [("image", buffer_bytes)]
+ visual_prompts = [
+ denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
+ ]
+ payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
+ metadata = {"function_name": "countgd_example_based_counting"}
+ resp_data = send_task_inference_request(
+ payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
+ )
+ bboxes_per_frame = resp_data[0]
+ bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+ filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+ return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
  def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
  """'florence2_roberta_vqa' is a tool that takes an image and analyzes
  its contents, generates detailed captions and then tries to answer the given
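A minimal usage sketch of the new `countgd_counting` tool (hypothetical `people.jpg`; not part of the diff):

```python
import vision_agent.tools as T

# Sketch: text-prompted CountGD counting; boxes below the default 0.23 score
# threshold are filtered out before the list is returned.
image = T.load_image("people.jpg")
detections = T.countgd_counting("person", image)
print(f"counted {len(detections)} people")
```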
@@ -646,7 +747,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
  "tool": "closed_set_image_classification",
  "function_name": "clip",
  }
- resp_data = send_inference_request(data, "tools")
+ resp_data: dict[str, Any] = send_inference_request(data, "tools")
  resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
  return resp_data

@@ -674,7 +775,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
  "tool": "image_classification",
  "function_name": "vit_image_classification",
  }
- resp_data = send_inference_request(data, "tools")
+ resp_data: dict[str, Any] = send_inference_request(data, "tools")
  resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
  return resp_data

@@ -701,7 +802,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
  "image": image_b64,
  "function_name": "vit_nsfw_classification",
  }
- resp_data = send_inference_request(data, "nsfw-classification", v2=True)
+ resp_data: dict[str, Any] = send_inference_request(
+ data, "nsfw-classification", v2=True
+ )
  resp_data["score"] = round(resp_data["score"], 4)
  return resp_data

@@ -762,7 +865,9 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s
  return answer[task]  # type: ignore


- def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
+ def florence2_phrase_grounding(
+ prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
+ ) -> List[Dict[str, Any]]:
  """'florence2_phrase_grounding' is a tool that can detect multiple
  objects given a text prompt which can be object names or caption. You
  can optionally separate the object names in the text with commas. It returns a list
@@ -772,6 +877,8 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
  Parameters:
  prompt (str): The prompt to ground to the image.
  image (np.ndarray): The image to used to detect objects
+ fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
+ fine-tuned model ID here to use it.

  Returns:
  List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -790,14 +897,33 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
  """
  image_size = image.shape[:2]
  image_b64 = convert_to_b64(image)
- data = {
- "image": image_b64,
- "task": "<CAPTION_TO_PHRASE_GROUNDING>",
- "prompt": prompt,
- "function_name": "florence2_phrase_grounding",
- }

- detections = send_inference_request(data, "florence2", v2=True)
+ if fine_tune_id is not None:
+ landing_api = LandingPublicAPI()
+ status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
+ if status is not JobStatus.SUCCEEDED:
+ raise FineTuneModelIsNotReady(
+ f"Fine-tuned model {fine_tune_id} is not ready yet"
+ )
+
+ data_obj = Florence2FtRequest(
+ image=image_b64,
+ task=PromptTask.PHRASE_GROUNDING,
+ tool="florencev2_fine_tuning",
+ prompt=prompt,
+ fine_tuning=FineTuning(job_id=UUID(fine_tune_id)),
+ )
+ data = data_obj.model_dump(by_alias=True)
+ detections = send_inference_request(data, "tools", v2=False)
+ else:
+ data = {
+ "image": image_b64,
+ "task": "<CAPTION_TO_PHRASE_GROUNDING>",
+ "prompt": prompt,
+ "function_name": "florence2_phrase_grounding",
+ }
+ detections = send_inference_request(data, "florence2", v2=True)
+

  detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
  return_data = []
@@ -1559,117 +1685,72 @@ def overlay_heat_map(
  return np.array(combined)


- # TODO: add this function to the imports so that is picked in the agent
- def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
- """'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able
- to detect objects in an image based on a given dataset. It returns the fine
- tuning job id.
+ def overlay_counting_results(
+ image: np.ndarray, instances: List[Dict[str, Any]]
+ ) -> np.ndarray:
+ """'overlay_counting_results' is a utility function that displays counting results on
+ an image.

  Parameters:
- bboxes (List[BboxInput]): A list of BboxInput containing the
- image path, labels and bounding boxes.
- task (PromptTask): The florencev2 fine-tuning task. The options are
- CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
+ image (np.ndarray): The image to display the bounding boxes on.
+ instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
+ box information of each instance

  Returns:
- UUID: The fine tuning job id, this id will used to retrieve the fine
- tuned model.
+ np.ndarray: The image with the instance_id dislpayed

  Example
  -------
- >>> fine_tuning_job_id = florencev2_fine_tuning(
- [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
- {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
- "OBJECT_DETECTION"
+ >>> image_with_bboxes = overlay_counting_results(
+ image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
  )
  """
- bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
- task_input = PromptTask[task]
- fine_tuning_request = [
- BboxInputBase64(
- image=convert_to_b64(bbox_input.image_path),
- filename=bbox_input.image_path.split("/")[-1],
- labels=bbox_input.labels,
- bboxes=bbox_input.bboxes,
- )
- for bbox_input in bboxes_input
- ]
- landing_api = LandingPublicAPI()
- return landing_api.launch_fine_tuning_job(
- "florencev2", task_input, fine_tuning_request
- )
-
+ pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
+ color = (158, 218, 229)

- # TODO: add this function to the imports so that is picked in the agent
- def florencev2_fine_tuned_object_detection(
- image: np.ndarray, prompt: str, model_id: UUID, task: str
- ) -> List[Dict[str, Any]]:
- """'florencev2_fine_tuned_object_detection' is a tool that uses a fine tuned model
- to detect objects given a text prompt such as a phrase or class names separated by
- commas. It returns a list of detected objects as labels and their location as
- bounding boxes with score of 1.0.
+ width, height = pil_image.size
+ fontsize = max(10, int(min(width, height) / 80))
+ pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
+ draw = ImageDraw.Draw(pil_image)
+ font = ImageFont.truetype(
+ str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+ fontsize,
+ )

- Parameters:
- image (np.ndarray): The image to used to detect objects.
- prompt (str): The prompt to help find objects in the image.
- model_id (UUID): The fine-tuned model id.
- task (PromptTask): The florencev2 fine-tuning task. The options are
- CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
+ for i, elt in enumerate(instances):
+ label = f"{i}"
+ box = elt["bbox"]

- Returns:
- List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
- bounding box of the detected objects with normalized coordinates between 0
- and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
- top-left and xmax and ymax are the coordinates of the bottom-right of the
- bounding box. The scores are always 1.0 and cannot be thresholded
+ # denormalize the box if it is normalized
+ box = denormalize_bbox(box, (height, width))
+ x0, y0, x1, y1 = box
+ cx, cy = (x0 + x1) / 2, (y0 + y1) / 2

- Example
- -------
- >>> florencev2_fine_tuned_object_detection(
- image,
- 'person looking at a coyote',
- UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83")
+ text_box = draw.textbbox(
+ (cx, cy), text=label, font=font, align="center", anchor="mm"
  )
- [
- {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
- {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5},
- ]
- """
- # check if job succeeded first
- landing_api = LandingPublicAPI()
- status = landing_api.check_fine_tuning_job(model_id)
- if status is not JobStatus.SUCCEEDED:
- raise FineTuneModelIsNotReady()
-
- task = PromptTask[task]
- if task is PromptTask.OBJECT_DETECTION:
- prompt = ""
-
- data_obj = Florencev2FtRequest(
- image=convert_to_b64(image),
- task=task,
- tool="florencev2_fine_tuning",
- prompt=prompt,
- fine_tuning=FineTuning(job_id=model_id),
- )
- data = data_obj.model_dump(by_alias=True)
- metadata_payload = {"function_name": "florencev2_fine_tuned_object_detection"}
- detections = send_inference_request(
- data, "tools", v2=False, metadata_payload=metadata_payload
- )

- detections = detections[task.value]
- return_data = []
- image_size = image.shape[:2]
- for i in range(len(detections["bboxes"])):
- return_data.append(
- {
- "score": 1.0,
- "label": detections["labels"][i],
- "bbox": normalize_bbox(detections["bboxes"][i], image_size),
- }
+ # Calculate the offset to center the text within the bounding box
+ text_width = text_box[2] - text_box[0]
+ text_height = text_box[3] - text_box[1]
+ text_x0 = cx - text_width / 2
+ text_y0 = cy - text_height / 2
+ text_x1 = cx + text_width / 2
+ text_y1 = cy + text_height / 2
+
+ # Draw the rectangle encapsulating the text
+ draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
+
+ # Draw the text at the center of the bounding box
+ draw.text(
+ (text_x0, text_y0),
+ label,
+ fill="black",
+ font=font,
+ anchor="lt",
  )
- return return_data
+
+ return np.array(pil_image)


  FUNCTION_TOOLS = [
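The new counting pieces compose naturally; a hedged sketch (hypothetical `flowers.jpg`, not part of the diff):

```python
import vision_agent.tools as T

# Sketch: count with CountGD, then burn numbered markers into the image.
image = T.load_image("flowers.jpg")  # hypothetical local image
instances = T.countgd_counting("flower", image)
viz = T.overlay_counting_results(image, instances)
print(f"{len(instances)} instances marked; overlay shape: {viz.shape}")
```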
@@ -1679,8 +1760,7 @@ FUNCTION_TOOLS = [
  clip,
  vit_image_classification,
  vit_nsfw_classification,
- loca_zero_shot_counting,
- loca_visual_prompt_counting,
+ countgd_counting,
  florence2_image_caption,
  florence2_ocr,
  florence2_sam2_image,
@@ -1703,6 +1783,7 @@ UTIL_TOOLS = [
  overlay_bounding_boxes,
  overlay_segmentation_masks,
  overlay_heat_map,
+ overlay_counting_results,
  ]

  TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
@@ -1720,5 +1801,6 @@ UTILITIES_DOCSTRING = get_tool_documentation(
  overlay_bounding_boxes,
  overlay_segmentation_masks,
  overlay_heat_map,
+ overlay_counting_results,
  ]
  )
@@ -1,8 +1,8 @@
  from enum import Enum
- from typing import List, Optional, Tuple
  from uuid import UUID
+ from typing import List, Tuple, Optional, Union

- from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, field_serializer
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo


  class BboxInput(BaseModel):
@@ -19,16 +19,9 @@ class BboxInputBase64(BaseModel):


  class PromptTask(str, Enum):
- """
- Valid task prompts options for the Florencev2 model.
- """
+ """Valid task prompts options for the Florence2 model."""

- CAPTION = "<CAPTION>"
- """"""
- CAPTION_TO_PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
- """"""
- OBJECT_DETECTION = "<OD>"
- """"""
+ PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"


  class FineTuning(BaseModel):
@@ -41,7 +34,7 @@ class FineTuning(BaseModel):
  return str(job_id)


- class Florencev2FtRequest(BaseModel):
+ class Florence2FtRequest(BaseModel):
  model_config = ConfigDict(populate_by_name=True)

  image: str
@@ -82,3 +75,16 @@ class JobStatus(str, Enum):
  SUCCEEDED = "SUCCEEDED"
  FAILED = "FAILED"
  STOPPED = "STOPPED"
+
+
+ class ODResponseData(BaseModel):
+ label: str
+ score: float
+ bbox: Union[list[int], list[float]] = Field(alias="bounding_box")
+
+ model_config = ConfigDict(
+ populate_by_name=True,
+ )
+
+
+ BoundingBoxes = list[ODResponseData]
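For reference, a small sketch of the new `ODResponseData` model and its `bounding_box` alias (values invented):

```python
from vision_agent.tools.tools_types import ODResponseData

# The service returns "bounding_box"; the model exposes it as .bbox, and
# populate_by_name=True means bbox= is accepted as well.
det = ODResponseData(label="person", score=0.91, bounding_box=[0.1, 0.2, 0.4, 0.6])
print(det.bbox)          # [0.1, 0.2, 0.4, 0.6]
print(det.model_dump())  # keys use the field name: 'bbox'
```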
@@ -564,7 +564,13 @@ class LocalCodeInterpreter(CodeInterpreter):
  ) -> None:
  super().__init__(timeout=timeout)
  self.nb = nbformat.v4.new_notebook()
- self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
+ # Set the notebook execution path to the remote path
+ self.resources = {"metadata": {"path": str(self.remote_path)}}
+ self.nb_client = NotebookClient(
+ self.nb,
+ timeout=self.timeout,
+ resources=self.resources,
+ )
  _LOGGER.info(
  f"""Local code interpreter initialized
  Python version: {sys.version}
@@ -606,7 +612,9 @@ Timeout: {self.timeout}"""
  def restart_kernel(self) -> None:
  self.close()
  self.nb = nbformat.v4.new_notebook()
- self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
+ self.nb_client = NotebookClient(
+ self.nb, timeout=self.timeout, resources=self.resources
+ )
  sleep(1)
  self._new_kernel()

@@ -636,7 +644,7 @@ Timeout: {self.timeout}"""
  f.write(contents)
  _LOGGER.info(f"File ({file_path}) is uploaded to: {str(self.remote_path)}")

- return Path(self.remote_path / file_path)
+ return Path(self.remote_path / Path(file_path).name)

  def download_file(
  self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
@@ -672,7 +680,8 @@ class CodeInterpreterFactory:

  @staticmethod
  def new_instance(
- code_sandbox_runtime: Optional[str] = None, remote_path: Optional[str] = None
+ code_sandbox_runtime: Optional[str] = None,
+ remote_path: Optional[Union[str, Path]] = None,
  ) -> CodeInterpreter:
  if not code_sandbox_runtime:
  code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local")
@@ -181,7 +181,7 @@ def denormalize_bbox(
  raise ValueError("Bounding box must be of length 4.")

  arr = np.array(bbox)
- if np.all((arr >= 0) & (arr <= 1)):
+ if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
  x1, y1, x2, y2 = bbox
  x1 = round(x1 * image_size[1])
  y1 = round(y1 * image_size[0])
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: vision-agent
- Version: 0.2.120
+ Version: 0.2.122
  Summary: Toolset for Vision Agent
  Author: Landing AI
  Author-email: dev@landing.ai
@@ -2,32 +2,32 @@ vision_agent/__init__.py,sha256=EAb4-f9iyuEYkBrX4ag1syM8Syx8118_t0R6_C34M9w,57
  vision_agent/agent/__init__.py,sha256=FRwiux1FGvGccetyUCtY46KP01fQteqorm-JtFepovI,176
  vision_agent/agent/agent.py,sha256=2cjIOxEuSJrqbfPXYoV0qER5ihXsPFCoEFJa4jpqan0,597
  vision_agent/agent/agent_utils.py,sha256=22LiPhkJlS5mVeo2dIi259pc2NgA7PGHRpcbnrtKo78,1930
- vision_agent/agent/vision_agent.py,sha256=IEyXT_JPCuWmBHdEnM1Wrsj7hmCe5pKLf0gnZFJTddI,11046
- vision_agent/agent/vision_agent_coder.py,sha256=DOTmDdGPxcI06Jp6yx4ekRMP0vhiVaK9B9Dl8UyJHeo,34396
- vision_agent/agent/vision_agent_coder_prompts.py,sha256=xIya1txRZM8qoQHAWTEkEFCL8L3iZD7QD09t3ZtdxSE,11305
- vision_agent/agent/vision_agent_prompts.py,sha256=0GliXFtBf32aPu2ClU63FI5ii5CTxWYsvrsmnnDp-gs,7134
+ vision_agent/agent/vision_agent.py,sha256=WM1_o0VAQokAKlDr-0lpFxCRwUm_eFfFNWP-wSNjo7s,11180
+ vision_agent/agent/vision_agent_coder.py,sha256=ujctkpmQkX2C6YXjlp7VLZFqSB00xwkGe-9swA8Gv8s,34240
+ vision_agent/agent/vision_agent_coder_prompts.py,sha256=Rg7-Ih7oFgFbHFFno0EHpaZEgm0SYj_nTdqqdp21YLo,11246
+ vision_agent/agent/vision_agent_prompts.py,sha256=K1nLo3XKQ-IqCom1TRwh3cMoGZNxNwEgZqf3uJ6eL18,7221
  vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  vision_agent/clients/http.py,sha256=k883i6M_4nl7zwwHSI-yP5sAgQZIDPM1nrKD6YFJ3Xs,2009
  vision_agent/clients/landing_public_api.py,sha256=rGtACkr8o5egDuMHQ5MBO4NuvsgPTp9Ew3rbq4R-vs0,1507
  vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
  vision_agent/lmm/__init__.py,sha256=YuUZRsMHdn8cMOv6iBU8yUqlIOLrbZQqZl9KPnofsHQ,103
- vision_agent/lmm/lmm.py,sha256=AYrZNdhghG293wd3aKZ1jK1lUm2NLWwALktbM4wNais,20862
+ vision_agent/lmm/lmm.py,sha256=H3a5V7c073-vXRJfQOblE2j_CsZkH1CNNRoQgLjJZuQ,20751
  vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
- vision_agent/tools/__init__.py,sha256=i7JOLxRaLdcY7-vCNOGAeOFMBfiAUIwWhnT32FO97VE,2201
- vision_agent/tools/meta_tools.py,sha256=Vu9WnKicGhafx9dPzDbQjQdcIzRCYYFPF68o79hDP-8,14616
+ vision_agent/tools/__init__.py,sha256=TILaqdFYicScvpnCXMxgBsFmSW22NQDIvucvEgo0etw,2289
+ vision_agent/tools/meta_tools.py,sha256=KeGiw2OtY8ARpGbtWjoNAoO1dwevt7LbCupaJX61MkE,18929
  vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
- vision_agent/tools/tool_utils.py,sha256=qMsb9d8QtpXGgF9rpPO2dA390BewKdYO68oWKDu-TGg,6504
- vision_agent/tools/tools.py,sha256=kbbMToAaHxl42dDEvyz9Mvtpqts0l0hGoC5YQQyozr8,59953
- vision_agent/tools/tools_types.py,sha256=iLWSirheC87fKQolIhx_O4Jk8Lv7DRiLuE8PJqLGiVQ,2216
+ vision_agent/tools/tool_utils.py,sha256=e_p-G2nwgWOpoaqpDitY3FJ6fFuTEg5GhDOD67wI2bE,7527
+ vision_agent/tools/tools.py,sha256=jOBsuN-spY_2TlvpahoRYGvyInhQDTPXXukx9q72lEU,63454
+ vision_agent/tools/tools_types.py,sha256=qs11HGLRXc9zytahBtG6TQxCh8Gigvn232at3jk54jI,2356
  vision_agent/utils/__init__.py,sha256=pWk0ktvR4aUEhuEIzSLM9kSgW4WDVqptdvOTeGLkJ6M,230
  vision_agent/utils/exceptions.py,sha256=booSPSuoULF7OXRr_YbC4dtKt6gM_HyiFQHBuaW86C4,2052
- vision_agent/utils/execute.py,sha256=Ap8Yx80spQq5f2QtKGx1MK03BR45mJKhlp1kfh-rIao,26751
- vision_agent/utils/image_utils.py,sha256=eNghu_2L8624jEXy8ZZS9OX46Mv0DT9bcvLForujwTs,9848
+ vision_agent/utils/execute.py,sha256=gc4R_0BKUrZyhiKvIxOpYuzQPYVWQEqxr3ANy1lJAw4,27037
+ vision_agent/utils/image_utils.py,sha256=UloC4byIQLM4CSCaH41SBciQ7X2OqKvsVvNOVKqIH_k,9856
  vision_agent/utils/sim.py,sha256=ebE9Cs00pVEDI1HMjAzUBk88tQQmc2U-yAzIDinnekU,5572
  vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
  vision_agent/utils/video.py,sha256=rNmU9KEIkZB5-EztZNlUiKYN0mm_55A_2VGUM0QpqLA,8779
- vision_agent-0.2.120.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- vision_agent-0.2.120.dist-info/METADATA,sha256=-FuNdlrzt5cTK6Ou_HTTROGVvsIwP3trsB5Edt2St3o,12255
- vision_agent-0.2.120.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
- vision_agent-0.2.120.dist-info/RECORD,,
+ vision_agent-0.2.122.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ vision_agent-0.2.122.dist-info/METADATA,sha256=WMdLNPyKY4Ot6ifOzwXNDiVm2TsStY-l-ge8t72Ynhk,12255
+ vision_agent-0.2.122.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+ vision_agent-0.2.122.dist-info/RECORD,,