vision-agent 0.2.120__py3-none-any.whl → 0.2.122__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -30,7 +30,7 @@ class BoilerplateCode:
30
30
  pre_code = [
31
31
  "from typing import *",
32
32
  "from vision_agent.utils.execute import CodeInterpreter",
33
- "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact",
33
+ "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, florence2_fine_tuning, use_florence2_fine_tuning",
34
34
  "artifacts = Artifacts('{remote_path}')",
35
35
  "artifacts.load('{remote_path}')",
36
36
  ]
@@ -76,11 +76,16 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
76
76
 
77
77
def run_code_action(
    code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
) -> Tuple[Execution, str]:
    """Execute agent-generated code in an isolated interpreter session.

    The code is first wrapped by ``BoilerplateCode.add_boilerplate`` so that the
    artifacts stored at ``artifact_remote_path`` are available to it.

    Returns:
        A pair ``(execution, observation)``. ``execution`` is the raw result of
        ``exec_isolation``. ``observation`` is the stringified logs, followed on a
        new line by the error when the execution reported one.
    """
    wrapped_code = BoilerplateCode.add_boilerplate(
        code, remote_path=artifact_remote_path
    )
    execution = code_interpreter.exec_isolation(wrapped_code)

    observation = str(execution.logs)
    if execution.error:
        observation = f"{observation}\n{execution.error}"
    return execution, observation
88
+
84
89
 
85
90
  def parse_execution(response: str) -> Optional[str]:
86
91
  code = None
@@ -192,7 +197,7 @@ class VisionAgent(Agent):
192
197
  artifacts = Artifacts(WORKSPACE / "artifacts.pkl")
193
198
 
194
199
  with CodeInterpreterFactory.new_instance(
195
- code_sandbox_runtime=self.code_sandbox_runtime
200
+ code_sandbox_runtime=self.code_sandbox_runtime,
196
201
  ) as code_interpreter:
197
202
  orig_chat = copy.deepcopy(chat)
198
203
  int_chat = copy.deepcopy(chat)
@@ -260,10 +265,9 @@ class VisionAgent(Agent):
260
265
  code_action = parse_execution(response["response"])
261
266
 
262
267
  if code_action is not None:
263
- result = run_code_action(
268
+ result, obs = run_code_action(
264
269
  code_action, code_interpreter, str(remote_artifacts_path)
265
270
  )
266
- obs = str(result.logs)
267
271
 
268
272
  if self.verbosity >= 1:
269
273
  _LOGGER.info(obs)
@@ -1,5 +1,4 @@
1
1
  import copy
2
- import difflib
3
2
  import logging
4
3
  import os
5
4
  import sys
@@ -29,6 +28,7 @@ from vision_agent.agent.vision_agent_coder_prompts import (
29
28
  USER_REQ,
30
29
  )
31
30
  from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
31
+ from vision_agent.tools.meta_tools import get_diff
32
32
  from vision_agent.utils import CodeInterpreterFactory, Execution
33
33
  from vision_agent.utils.execute import CodeInterpreter
34
34
  from vision_agent.utils.image_utils import b64_to_pil
@@ -63,14 +63,6 @@ class DefaultImports:
63
63
  return DefaultImports.to_code_string() + "\n\n" + code
64
64
 
65
65
 
66
- def get_diff(before: str, after: str) -> str:
67
- return "".join(
68
- difflib.unified_diff(
69
- before.splitlines(keepends=True), after.splitlines(keepends=True)
70
- )
71
- )
72
-
73
-
74
66
  def format_memory(memory: List[Dict[str, str]]) -> str:
75
67
  output_str = ""
76
68
  for i, m in enumerate(memory):
@@ -81,20 +81,19 @@ plan2:
81
81
  - Count the number of detected objects labeled as 'person'.
82
82
  plan3:
83
83
  - Load the image from the provided file path 'image.jpg'.
84
- - Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
84
+ - Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
85
85
 
86
86
  ```python
87
- from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
87
+ from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
88
88
  image = load_image("image.jpg")
89
89
  owl_v2_out = owl_v2("person", image)
90
90
 
91
91
  gsam_out = grounding_sam("person", image)
92
92
  gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
93
93
 
94
- loca_out = loca_zero_shot_counting(image)
95
- loca_out = loca_out["count"]
94
+ cgd_out = countgd_counting(image)
96
95
 
97
- final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
96
+ final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
98
97
  print(final_out)
99
98
  ```
100
99
  """
@@ -48,7 +48,7 @@ OBSERVATION:
48
48
  4| return dogs
49
49
  [End of artifact]
50
50
 
51
- AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
51
+ AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
52
52
 
53
53
  OBSERVATION:
54
54
  ----- stdout -----
@@ -75,7 +75,7 @@ OBSERVATION:
75
75
  4| return dogs
76
76
  [End of artifact]
77
77
 
78
- AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
78
+ AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
79
79
 
80
80
  OBSERVATION:
81
81
  ----- stdout -----
@@ -126,7 +126,7 @@ OBSERVATION:
126
126
  15| return count
127
127
  [End of artifact]
128
128
 
129
- AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code to get the output and write the visualization to the artifacts so the user can see it.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
129
+ AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
130
130
 
131
131
  OBSERVATION:
132
132
  ----- stdout -----
vision_agent/lmm/lmm.py CHANGED
@@ -286,9 +286,6 @@ class OpenAILMM(LMM):
286
286
 
287
287
  return lambda x: T.grounding_sam(params["prompt"], x)
288
288
 
289
- def generate_zero_shot_counter(self, question: str) -> Callable:
290
- return T.loca_zero_shot_counting
291
-
292
289
  def generate_image_qa_tool(self, question: str) -> Callable:
293
290
  return lambda x: T.git_vqa_v2(question, x)
294
291
 
@@ -37,10 +37,13 @@ from .tools import (
37
37
  load_image,
38
38
  loca_visual_prompt_counting,
39
39
  loca_zero_shot_counting,
40
+ countgd_counting,
41
+ countgd_example_based_counting,
40
42
  ocr,
41
43
  overlay_bounding_boxes,
42
44
  overlay_heat_map,
43
45
  overlay_segmentation_masks,
46
+ overlay_counting_results,
44
47
  owl_v2,
45
48
  save_image,
46
49
  save_json,
@@ -1,5 +1,7 @@
1
+ import difflib
1
2
  import os
2
3
  import pickle as pkl
4
+ import re
3
5
  import subprocess
4
6
  import tempfile
5
7
  from pathlib import Path
@@ -8,10 +10,13 @@ from typing import Any, Dict, List, Optional, Union
8
10
  from IPython.display import display
9
11
 
10
12
  import vision_agent as va
13
+ from vision_agent.clients.landing_public_api import LandingPublicAPI
11
14
  from vision_agent.lmm.types import Message
12
15
  from vision_agent.tools.tool_utils import get_tool_documentation
13
16
  from vision_agent.tools.tools import TOOL_DESCRIPTIONS
17
+ from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
14
18
  from vision_agent.utils.execute import Execution, MimeType
19
+ from vision_agent.utils.image_utils import convert_to_b64
15
20
 
16
21
  # These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent
17
22
 
@@ -99,13 +104,14 @@ class Artifacts:
99
104
 
100
105
  def show(self) -> str:
101
106
  """Shows the artifacts that have been loaded and their remote save paths."""
102
- out_str = "[Artifacts loaded]\n"
107
+ output_str = "[Artifacts loaded]\n"
103
108
  for k in self.artifacts.keys():
104
- out_str += (
109
+ output_str += (
105
110
  f"Artifact {k} loaded to {str(self.remote_save_path.parent / k)}\n"
106
111
  )
107
- out_str += "[End of artifacts]\n"
108
- return out_str
112
+ output_str += "[End of artifacts]\n"
113
+ print(output_str)
114
+ return output_str
109
115
 
110
116
  def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
111
117
  save_path = (
@@ -135,7 +141,12 @@ def format_lines(lines: List[str], start_idx: int) -> str:
135
141
 
136
142
 
137
143
  def view_lines(
138
- lines: List[str], line_num: int, window_size: int, name: str, total_lines: int
144
+ lines: List[str],
145
+ line_num: int,
146
+ window_size: int,
147
+ name: str,
148
+ total_lines: int,
149
+ print_output: bool = True,
139
150
  ) -> str:
140
151
  start = max(0, line_num - window_size)
141
152
  end = min(len(lines), line_num + window_size)
@@ -148,7 +159,9 @@ def view_lines(
148
159
  else f"[{len(lines) - end} more lines]"
149
160
  )
150
161
  )
151
- print(return_str)
162
+
163
+ if print_output:
164
+ print(return_str)
152
165
  return return_str
153
166
 
154
167
 
@@ -231,7 +244,7 @@ def edit_code_artifact(
231
244
  new_content_lines = [
232
245
  line if line.endswith("\n") else line + "\n" for line in new_content_lines
233
246
  ]
234
- lines = artifacts[name].splitlines()
247
+ lines = artifacts[name].splitlines(keepends=True)
235
248
  edited_lines = lines[:start] + new_content_lines + lines[end:]
236
249
 
237
250
  cur_line = start + len(content.split("\n")) // 2
@@ -261,13 +274,20 @@ def edit_code_artifact(
261
274
  DEFAULT_WINDOW_SIZE,
262
275
  name,
263
276
  total_lines,
277
+ print_output=False,
264
278
  )
265
279
  total_lines_edit = sum(1 for _ in edited_lines)
266
280
  edited_view = view_lines(
267
- edited_lines, cur_line, DEFAULT_WINDOW_SIZE, name, total_lines_edit
281
+ edited_lines,
282
+ cur_line,
283
+ DEFAULT_WINDOW_SIZE,
284
+ name,
285
+ total_lines_edit,
286
+ print_output=False,
268
287
  )
269
288
 
270
289
  error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
290
+ print(error_msg)
271
291
  return error_msg
272
292
 
273
293
  artifacts[name] = "".join(edited_lines)
@@ -390,6 +410,13 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
390
410
  return f"[Media {Path(local_path).name} saved]"
391
411
 
392
412
 
413
def list_artifacts(artifacts: Artifacts) -> str:
    """Lists all the artifacts that have been loaded into the artifacts object."""
    # ``Artifacts.show`` already prints the listing it returns. Printing it again
    # here would duplicate the listing in the stdout observation the agent sees.
    return artifacts.show()
418
+
419
+
393
420
  def get_tool_descriptions() -> str:
394
421
  """Returns a description of all the tools that `generate_vision_code` has access to.
395
422
  Helpful for answering questions about what types of vision tasks you can do with
@@ -397,6 +424,108 @@ def get_tool_descriptions() -> str:
397
424
  return TOOL_DESCRIPTIONS
398
425
 
399
426
 
427
def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
    """'florence2_fine_tuning' is a tool that fine-tunes florence2 to be able to
    detect objects in an image based on a given dataset. It returns the fine tuning
    job id.

    Parameters:
        bboxes (List[BboxInput]): A list of BboxInput containing the
            image path, labels and bounding boxes.
        task (str): The florence2 fine-tuning task. The options are
            'phrase_grounding'.

    Returns:
        str: The fine tuning job id, this id will be used to retrieve the fine
            tuned model.

    Example
    -------
        >>> fine_tuning_job_id = florence2_fine_tuning(
            [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
             {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
            "phrase_grounding"
        )
    """
    # Validate every raw dict against the BboxInput schema before uploading.
    bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
    # Raises KeyError for task names that are not members of PromptTask.
    task_type = PromptTask[task.upper()]
    # Each image is read from disk and encoded with convert_to_b64 so it can be
    # sent inline with its file name, labels and boxes.
    fine_tuning_request = [
        BboxInputBase64(
            image=convert_to_b64(bbox_input.image_path),
            filename=Path(bbox_input.image_path).name,
            labels=bbox_input.labels,
            bboxes=bbox_input.bboxes,
        )
        for bbox_input in bboxes_input
    ]
    landing_api = LandingPublicAPI()
    fine_tune_id = str(
        landing_api.launch_fine_tuning_job("florencev2", task_type, fine_tuning_request)
    )
    print(f"[Florence2 fine tuning id: {fine_tune_id}]")
    return fine_tune_id
466
+
467
+
468
def get_diff(before: str, after: str) -> str:
    """Return a unified diff (``difflib.unified_diff``) turning ``before`` into ``after``.

    If the last line of either text has no line terminator, a ``"\\n"`` is appended
    to it before diffing. Otherwise the final removed and added lines would be
    glued together in the joined output (``get_diff("a", "b")`` used to give
    ``"...-a+b"``). As a consequence, a change that only adds or removes a
    trailing newline is not reported.

    Returns:
        The diff as a single string, or ``""`` when the texts are equal.
    """

    def _lines(text: str) -> List[str]:
        lines = text.splitlines(keepends=True)
        # A line without a terminator splits back into exactly itself.
        if lines and lines[-1].splitlines() == [lines[-1]]:
            lines[-1] += "\n"
        return lines

    return "".join(difflib.unified_diff(_lines(before), _lines(after)))
474
+
475
+
476
def _find_closing_paren(code: str, open_idx: int) -> int:
    """Return the index of the bracket closing the ``(`` at ``open_idx``, or -1.

    Brackets inside string literals (including triple-quoted ones) and ``#``
    comments are ignored. This lets a nested call such as
    ``florence2_phrase_grounding("dog", load_image("a.jpg"))`` resolve to its
    outermost ``)``. Returns -1 if the call is never closed.
    """
    depth = 0
    i = open_idx
    n = len(code)
    while i < n:
        ch = code[i]
        if ch in "\"'":
            # Skip the whole string literal, honouring backslash escapes.
            quote = ch * 3 if code.startswith(ch * 3, i) else ch
            i += len(quote)
            while i < n and not code.startswith(quote, i):
                i += 2 if code[i] == "\\" else 1
            i += len(quote)
            continue
        if ch == "#":
            # Skip to the end of the comment line.
            end_of_line = code.find("\n", i)
            if end_of_line == -1:
                return -1
            i = end_of_line
            continue
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
            if depth == 0:
                return i
        i += 1
    return -1


def use_florence2_fine_tuning(
    artifacts: Artifacts, name: str, task: str, fine_tune_id: str
) -> str:
    """Replaces florence2 calls with the fine tuning id. This ensures that the code
    utilizes the fine-tuned florence2 model. Returns the diff between the original
    code and the new code.

    Parameters:
        artifacts (Artifacts): The artifacts object to edit the code from.
        name (str): The name of the artifact to edit.
        task (str): The task to fine tune the model for. The options are
            'phrase_grounding'.
        fine_tune_id (str): The fine tuning job id.

    Examples
    --------
        >>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
    """

    task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"}

    if name not in artifacts:
        output_str = f"[Artifact {name} does not exist]"
        print(output_str)
        return output_str

    # Normalise the task once so the support check and the lookup agree
    # (e.g. "Phrase_Grounding" must not pass one and fail the other).
    fn_name = task_to_fn.get(task.lower())
    if fn_name is None:
        raise ValueError(f"Task {task} is not supported.")

    code = artifacts[name]
    pieces: List[str] = []
    cursor = 0
    # \b avoids matching longer identifiers that merely end with fn_name.
    for match in re.finditer(rf"\b{re.escape(fn_name)}\(", code):
        if match.start() < cursor:
            # Inside the arguments of a call that was already rewritten.
            continue
        open_idx = match.end() - 1
        close_idx = _find_closing_paren(code, open_idx)
        if close_idx == -1:
            continue  # unbalanced call: leave it untouched
        # Strip surrounding whitespace and a trailing comma, so multi-line calls
        # such as ``fn(\n    a,\n    b,\n)`` stay valid after appending the id.
        args = code[open_idx + 1 : close_idx].strip().rstrip(",").rstrip()
        if not args:
            continue
        pieces.append(code[cursor : match.start()])
        pieces.append(f'{fn_name}({args}, "{fine_tune_id}")')
        cursor = close_idx + 1
    pieces.append(code[cursor:])
    new_code = "".join(pieces)

    if new_code == code:
        output_str = f"[Fine tuning task {task} function {fn_name} not found in code]"
        print(output_str)
        return output_str

    artifacts[name] = new_code

    diff = get_diff(code, new_code)
    print(diff)
    return diff
527
+
528
+
400
529
  META_TOOL_DOCSTRING = get_tool_documentation(
401
530
  [
402
531
  get_tool_descriptions,
@@ -406,5 +535,8 @@ META_TOOL_DOCSTRING = get_tool_documentation(
406
535
  generate_vision_code,
407
536
  edit_vision_code,
408
537
  write_media_artifact,
538
+ florence2_fine_tuning,
539
+ use_florence2_fine_tuning,
540
+ list_artifacts,
409
541
  ]
410
542
  )
@@ -1,6 +1,6 @@
1
+ import os
1
2
  import inspect
2
3
  import logging
3
- import os
4
4
  from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
5
5
 
6
6
  import pandas as pd
@@ -13,6 +13,7 @@ from urllib3.util.retry import Retry
13
13
  from vision_agent.utils.exceptions import RemoteToolCallFailed
14
14
  from vision_agent.utils.execute import Error, MimeType
15
15
  from vision_agent.utils.type_defs import LandingaiAPIKey
16
+ from vision_agent.tools.tools_types import BoundingBoxes
16
17
 
17
18
  _LOGGER = logging.getLogger(__name__)
18
19
  _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -34,61 +35,58 @@ def send_inference_request(
34
35
  files: Optional[List[Tuple[Any, ...]]] = None,
35
36
  v2: bool = False,
36
37
  metadata_payload: Optional[Dict[str, Any]] = None,
37
- ) -> Dict[str, Any]:
38
+ ) -> Any:
38
39
  # TODO: runtime_tag and function_name should be metadata_payload and now included
39
40
  # in the service payload
40
- try:
41
- if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
42
- payload["runtime_tag"] = runtime_tag
41
+ if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
42
+ payload["runtime_tag"] = runtime_tag
43
+
44
+ url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
45
+ if "TOOL_ENDPOINT_URL" in os.environ:
46
+ url = os.environ["TOOL_ENDPOINT_URL"]
47
+
48
+ headers = {"apikey": _LND_API_KEY}
49
+ if "TOOL_ENDPOINT_AUTH" in os.environ:
50
+ headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
51
+ headers.pop("apikey")
52
+
53
+ session = _create_requests_session(
54
+ url=url,
55
+ num_retry=3,
56
+ headers=headers,
57
+ )
43
58
 
44
- url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
45
- if "TOOL_ENDPOINT_URL" in os.environ:
46
- url = os.environ["TOOL_ENDPOINT_URL"]
59
+ function_name = "unknown"
60
+ if "function_name" in payload:
61
+ function_name = payload["function_name"]
62
+ elif metadata_payload is not None and "function_name" in metadata_payload:
63
+ function_name = metadata_payload["function_name"]
47
64
 
48
- tool_call_trace = ToolCallTrace(
49
- endpoint_url=url,
50
- request=payload,
51
- response={},
52
- error=None,
53
- )
54
- headers = {"apikey": _LND_API_KEY}
55
- if "TOOL_ENDPOINT_AUTH" in os.environ:
56
- headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
57
- headers.pop("apikey")
58
-
59
- session = _create_requests_session(
60
- url=url,
61
- num_retry=3,
62
- headers=headers,
63
- )
65
+ response = _call_post(url, payload, session, files, function_name)
64
66
 
65
- if files is not None:
66
- res = session.post(url, data=payload, files=files)
67
- else:
68
- res = session.post(url, json=payload)
69
- if res.status_code != 200:
70
- tool_call_trace.error = Error(
71
- name="RemoteToolCallFailed",
72
- value=f"{res.status_code} - {res.text}",
73
- traceback_raw=[],
74
- )
75
- _LOGGER.error(f"Request failed: {res.status_code} {res.text}")
76
- # TODO: function_name should be in metadata_payload
77
- function_name = "unknown"
78
- if "function_name" in payload:
79
- function_name = payload["function_name"]
80
- elif metadata_payload is not None and "function_name" in metadata_payload:
81
- function_name = metadata_payload["function_name"]
82
- raise RemoteToolCallFailed(function_name, res.status_code, res.text)
83
-
84
- resp = res.json()
85
- tool_call_trace.response = resp
86
- # TODO: consider making the response schema the same between below two sources
87
- return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore
88
- finally:
89
- trace = tool_call_trace.model_dump()
90
- trace["type"] = "tool_call"
91
- display({MimeType.APPLICATION_JSON: trace}, raw=True)
67
+ # TODO: consider making the response schema the same between below two sources
68
+ return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
69
+
70
+
71
def send_task_inference_request(
    payload: Dict[str, Any],
    task_name: str,
    files: Optional[List[Tuple[Any, ...]]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Any:
    """POST ``payload`` to the v2 ``task_name`` endpoint and return the response's ``data``.

    ``metadata["function_name"]`` (default ``"unknown"``) is passed to
    ``_call_post``, which uses it to label a ``RemoteToolCallFailed`` error. Unlike
    ``send_inference_request``, the URL is always built from ``_LND_API_URL_v2``
    and authenticated with the ``apikey`` header. The ``TOOL_ENDPOINT_*``
    overrides are not consulted.
    """
    function_name = "unknown"
    if metadata is not None:
        function_name = metadata.get("function_name", function_name)

    endpoint = f"{_LND_API_URL_v2}/{task_name}"
    session = _create_requests_session(
        url=endpoint,
        num_retry=3,
        headers={"apikey": _LND_API_KEY},
    )
    response = _call_post(endpoint, payload, session, files, function_name)
    return response["data"]
92
90
 
93
91
 
94
92
  def _create_requests_session(
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
195
193
  data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"
196
194
 
197
195
  return data
196
+
197
+
198
def _call_post(
    url: str,
    payload: dict[str, Any],
    session: Session,
    files: Optional[List[Tuple[Any, ...]]] = None,
    function_name: str = "unknown",
) -> Any:
    """POST ``payload`` to ``url`` and return the decoded JSON body.

    The payload is sent as form data when ``files`` is given, and as JSON otherwise.
    A ``tool_call`` trace containing the request and either the response or the
    error is passed to ``display`` when the call finishes, whether it succeeded or
    failed.

    Raises:
        RemoteToolCallFailed: if the server answers with a non-200 status code.
            Other exceptions (transport errors, invalid JSON) propagate unchanged
            after being recorded in the trace.
    """
    # Built before ``try`` so the ``finally`` block can always reference it.
    tool_call_trace = ToolCallTrace(
        endpoint_url=url,
        request=payload,
        response={},
        error=None,
    )
    try:
        if files is not None:
            response = session.post(url, data=payload, files=files)
        else:
            response = session.post(url, json=payload)

        if response.status_code != 200:
            tool_call_trace.error = Error(
                name="RemoteToolCallFailed",
                value=f"{response.status_code} - {response.text}",
                traceback_raw=[],
            )
            _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
            raise RemoteToolCallFailed(
                function_name, response.status_code, response.text
            )

        result = response.json()
        tool_call_trace.response = result
        return result
    except Exception as exc:
        # Record failures that did not come from a bad status code too, so the
        # emitted trace never shows a failed call with ``error=None``.
        if tool_call_trace.error is None:
            tool_call_trace.error = Error(
                name=type(exc).__name__, value=str(exc), traceback_raw=[]
            )
        raise
    finally:
        trace = tool_call_trace.model_dump()
        trace["type"] = "tool_call"
        display({MimeType.APPLICATION_JSON: trace}, raw=True)
236
+
237
+
238
def filter_bboxes_by_threshold(
    bboxes: BoundingBoxes, threshold: float
) -> BoundingBoxes:
    """Return a new list holding only the boxes whose ``score`` is >= ``threshold``.

    The relative order of the kept boxes is preserved.
    """
    return [bbox for bbox in bboxes if bbox.score >= threshold]
@@ -13,7 +13,7 @@ import cv2
13
13
  import numpy as np
14
14
  import requests
15
15
  from moviepy.editor import ImageSequenceClip
16
- from PIL import Image, ImageDraw, ImageFont
16
+ from PIL import Image, ImageDraw, ImageFont, ImageEnhance
17
17
  from pillow_heif import register_heif_opener # type: ignore
18
18
  from pytube import YouTube # type: ignore
19
19
 
@@ -24,14 +24,15 @@ from vision_agent.tools.tool_utils import (
24
24
  get_tools_df,
25
25
  get_tools_info,
26
26
  send_inference_request,
27
+ send_task_inference_request,
28
+ filter_bboxes_by_threshold,
27
29
  )
28
30
  from vision_agent.tools.tools_types import (
29
- BboxInput,
30
- BboxInputBase64,
31
31
  FineTuning,
32
- Florencev2FtRequest,
32
+ Florence2FtRequest,
33
33
  JobStatus,
34
34
  PromptTask,
35
+ ODResponseData,
35
36
  )
36
37
  from vision_agent.utils import extract_frames_from_video
37
38
  from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -455,7 +456,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
455
456
  "image": image_b64,
456
457
  "function_name": "loca_zero_shot_counting",
457
458
  }
458
- resp_data = send_inference_request(data, "loca", v2=True)
459
+ resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
459
460
  resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
460
461
  return resp_data
461
462
 
@@ -469,6 +470,8 @@ def loca_visual_prompt_counting(
469
470
 
470
471
  Parameters:
471
472
  image (np.ndarray): The image that contains lot of instances of a single object
473
+ visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
474
+ [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.
472
475
 
473
476
  Returns:
474
477
  Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -496,11 +499,109 @@ def loca_visual_prompt_counting(
496
499
  "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
497
500
  "function_name": "loca_visual_prompt_counting",
498
501
  }
499
- resp_data = send_inference_request(data, "loca", v2=True)
502
+ resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
500
503
  resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
501
504
  return resp_data
502
505
 
503
506
 
507
def countgd_counting(
    prompt: str,
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_counting' is a tool that can precisely count multiple instances of an
    object given a text prompt. It returns a list of bounding boxes with normalized
    coordinates, label names and associated confidence scores.

    Parameters:
        prompt (str): The object that needs to be counted.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_counting("flower", image)
        [
            {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    buffer_bytes = numpy_to_bytes(image)
    files = [("image", buffer_bytes)]
    # NOTE(review): ", " separators are rewritten to " ." (so "cat, dog" becomes
    # "cat .dog"); presumably the service splits prompts on "." — confirm.
    prompt = prompt.replace(", ", " .")
    payload = {"prompts": [prompt], "model": "countgd"}
    metadata = {"function_name": "countgd_counting"}
    resp_data = send_task_inference_request(
        payload, "text-to-object-detection", files=files, metadata=metadata
    )
    # Presumably one list of detections per frame; a single image is frame 0.
    bboxes_per_frame = resp_data[0]
    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
    return [bbox.model_dump() for bbox in filtered_bboxes]
551
+
552
+
553
def countgd_example_based_counting(
    visual_prompts: List[List[float]],
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_example_based_counting' is a tool that can precisely count multiple
    instances of an object given few visual example prompts. It returns a list of bounding
    boxes with normalized coordinates, label names and associated confidence scores.

    Parameters:
        visual_prompts (List[List[float]]): Bounding boxes of the object in format
            [xmin, ymin, xmax, ymax] with normalized coordinates between 0 and 1.
            Up to 3 bounding boxes can be provided.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_example_based_counting(
            visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
            image=image
        )
        [
            {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    buffer_bytes = numpy_to_bytes(image)
    files = [("image", buffer_bytes)]
    # The service receives pixel coordinates; callers pass normalized boxes.
    visual_prompts = [
        denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
    ]
    payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
    metadata = {"function_name": "countgd_example_based_counting"}
    resp_data = send_task_inference_request(
        payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
    )
    # Presumably one list of detections per frame; a single image is frame 0.
    bboxes_per_frame = resp_data[0]
    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
    return [bbox.model_dump() for bbox in filtered_bboxes]
603
+
604
+
504
605
  def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
505
606
  """'florence2_roberta_vqa' is a tool that takes an image and analyzes
506
607
  its contents, generates detailed captions and then tries to answer the given
@@ -646,7 +747,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
646
747
  "tool": "closed_set_image_classification",
647
748
  "function_name": "clip",
648
749
  }
649
- resp_data = send_inference_request(data, "tools")
750
+ resp_data: dict[str, Any] = send_inference_request(data, "tools")
650
751
  resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
651
752
  return resp_data
652
753
 
@@ -674,7 +775,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
674
775
  "tool": "image_classification",
675
776
  "function_name": "vit_image_classification",
676
777
  }
677
- resp_data = send_inference_request(data, "tools")
778
+ resp_data: dict[str, Any] = send_inference_request(data, "tools")
678
779
  resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
679
780
  return resp_data
680
781
 
@@ -701,7 +802,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
701
802
  "image": image_b64,
702
803
  "function_name": "vit_nsfw_classification",
703
804
  }
704
- resp_data = send_inference_request(data, "nsfw-classification", v2=True)
805
+ resp_data: dict[str, Any] = send_inference_request(
806
+ data, "nsfw-classification", v2=True
807
+ )
705
808
  resp_data["score"] = round(resp_data["score"], 4)
706
809
  return resp_data
707
810
 
@@ -762,7 +865,9 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s
762
865
  return answer[task] # type: ignore
763
866
 
764
867
 
765
- def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
868
+ def florence2_phrase_grounding(
869
+ prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
870
+ ) -> List[Dict[str, Any]]:
766
871
  """'florence2_phrase_grounding' is a tool that can detect multiple
767
872
  objects given a text prompt which can be object names or caption. You
768
873
  can optionally separate the object names in the text with commas. It returns a list
@@ -772,6 +877,8 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
772
877
  Parameters:
773
878
  prompt (str): The prompt to ground to the image.
774
879
  image (np.ndarray): The image to used to detect objects
880
+ fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
881
+ fine-tuned model ID here to use it.
775
882
 
776
883
  Returns:
777
884
  List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -790,14 +897,33 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
790
897
  """
791
898
  image_size = image.shape[:2]
792
899
  image_b64 = convert_to_b64(image)
793
- data = {
794
- "image": image_b64,
795
- "task": "<CAPTION_TO_PHRASE_GROUNDING>",
796
- "prompt": prompt,
797
- "function_name": "florence2_phrase_grounding",
798
- }
799
900
 
800
- detections = send_inference_request(data, "florence2", v2=True)
901
+ if fine_tune_id is not None:
902
+ landing_api = LandingPublicAPI()
903
+ status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
904
+ if status is not JobStatus.SUCCEEDED:
905
+ raise FineTuneModelIsNotReady(
906
+ f"Fine-tuned model {fine_tune_id} is not ready yet"
907
+ )
908
+
909
+ data_obj = Florence2FtRequest(
910
+ image=image_b64,
911
+ task=PromptTask.PHRASE_GROUNDING,
912
+ tool="florencev2_fine_tuning",
913
+ prompt=prompt,
914
+ fine_tuning=FineTuning(job_id=UUID(fine_tune_id)),
915
+ )
916
+ data = data_obj.model_dump(by_alias=True)
917
+ detections = send_inference_request(data, "tools", v2=False)
918
+ else:
919
+ data = {
920
+ "image": image_b64,
921
+ "task": "<CAPTION_TO_PHRASE_GROUNDING>",
922
+ "prompt": prompt,
923
+ "function_name": "florence2_phrase_grounding",
924
+ }
925
+ detections = send_inference_request(data, "florence2", v2=True)
926
+
801
927
  detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
802
928
  return_data = []
803
929
  for i in range(len(detections["bboxes"])):
@@ -1559,117 +1685,72 @@ def overlay_heat_map(
1559
1685
  return np.array(combined)
1560
1686
 
1561
1687
 
1562
- # TODO: add this function to the imports so that is picked in the agent
1563
- def florencev2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> UUID:
1564
- """'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able
1565
- to detect objects in an image based on a given dataset. It returns the fine
1566
- tuning job id.
1688
+ def overlay_counting_results(
1689
+ image: np.ndarray, instances: List[Dict[str, Any]]
1690
+ ) -> np.ndarray:
1691
+ """'overlay_counting_results' is a utility function that displays counting results on
1692
+ an image.
1567
1693
 
1568
1694
  Parameters:
1569
- bboxes (List[BboxInput]): A list of BboxInput containing the
1570
- image path, labels and bounding boxes.
1571
- task (PromptTask): The florencev2 fine-tuning task. The options are
1572
- CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
1695
+ image (np.ndarray): The image to display the bounding boxes on.
1696
+ instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
1697
+ box information of each instance
1573
1698
 
1574
1699
  Returns:
1575
- UUID: The fine tuning job id, this id will used to retrieve the fine
1576
- tuned model.
1700
+ np.ndarray: The image with the instance_id dislpayed
1577
1701
 
1578
1702
  Example
1579
1703
  -------
1580
- >>> fine_tuning_job_id = florencev2_fine_tuning(
1581
- [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
1582
- {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
1583
- "OBJECT_DETECTION"
1704
+ >>> image_with_bboxes = overlay_counting_results(
1705
+ image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
1584
1706
  )
1585
1707
  """
1586
- bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
1587
- task_input = PromptTask[task]
1588
- fine_tuning_request = [
1589
- BboxInputBase64(
1590
- image=convert_to_b64(bbox_input.image_path),
1591
- filename=bbox_input.image_path.split("/")[-1],
1592
- labels=bbox_input.labels,
1593
- bboxes=bbox_input.bboxes,
1594
- )
1595
- for bbox_input in bboxes_input
1596
- ]
1597
- landing_api = LandingPublicAPI()
1598
- return landing_api.launch_fine_tuning_job(
1599
- "florencev2", task_input, fine_tuning_request
1600
- )
1601
-
1708
+ pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
1709
+ color = (158, 218, 229)
1602
1710
 
1603
- # TODO: add this function to the imports so that is picked in the agent
1604
- def florencev2_fine_tuned_object_detection(
1605
- image: np.ndarray, prompt: str, model_id: UUID, task: str
1606
- ) -> List[Dict[str, Any]]:
1607
- """'florencev2_fine_tuned_object_detection' is a tool that uses a fine tuned model
1608
- to detect objects given a text prompt such as a phrase or class names separated by
1609
- commas. It returns a list of detected objects as labels and their location as
1610
- bounding boxes with score of 1.0.
1711
+ width, height = pil_image.size
1712
+ fontsize = max(10, int(min(width, height) / 80))
1713
+ pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
1714
+ draw = ImageDraw.Draw(pil_image)
1715
+ font = ImageFont.truetype(
1716
+ str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
1717
+ fontsize,
1718
+ )
1611
1719
 
1612
- Parameters:
1613
- image (np.ndarray): The image to used to detect objects.
1614
- prompt (str): The prompt to help find objects in the image.
1615
- model_id (UUID): The fine-tuned model id.
1616
- task (PromptTask): The florencev2 fine-tuning task. The options are
1617
- CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
1720
+ for i, elt in enumerate(instances):
1721
+ label = f"{i}"
1722
+ box = elt["bbox"]
1618
1723
 
1619
- Returns:
1620
- List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
1621
- bounding box of the detected objects with normalized coordinates between 0
1622
- and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
1623
- top-left and xmax and ymax are the coordinates of the bottom-right of the
1624
- bounding box. The scores are always 1.0 and cannot be thresholded
1724
+ # denormalize the box if it is normalized
1725
+ box = denormalize_bbox(box, (height, width))
1726
+ x0, y0, x1, y1 = box
1727
+ cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
1625
1728
 
1626
- Example
1627
- -------
1628
- >>> florencev2_fine_tuned_object_detection(
1629
- image,
1630
- 'person looking at a coyote',
1631
- UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83")
1729
+ text_box = draw.textbbox(
1730
+ (cx, cy), text=label, font=font, align="center", anchor="mm"
1632
1731
  )
1633
- [
1634
- {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
1635
- {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5},
1636
- ]
1637
- """
1638
- # check if job succeeded first
1639
- landing_api = LandingPublicAPI()
1640
- status = landing_api.check_fine_tuning_job(model_id)
1641
- if status is not JobStatus.SUCCEEDED:
1642
- raise FineTuneModelIsNotReady()
1643
-
1644
- task = PromptTask[task]
1645
- if task is PromptTask.OBJECT_DETECTION:
1646
- prompt = ""
1647
-
1648
- data_obj = Florencev2FtRequest(
1649
- image=convert_to_b64(image),
1650
- task=task,
1651
- tool="florencev2_fine_tuning",
1652
- prompt=prompt,
1653
- fine_tuning=FineTuning(job_id=model_id),
1654
- )
1655
- data = data_obj.model_dump(by_alias=True)
1656
- metadata_payload = {"function_name": "florencev2_fine_tuned_object_detection"}
1657
- detections = send_inference_request(
1658
- data, "tools", v2=False, metadata_payload=metadata_payload
1659
- )
1660
1732
 
1661
- detections = detections[task.value]
1662
- return_data = []
1663
- image_size = image.shape[:2]
1664
- for i in range(len(detections["bboxes"])):
1665
- return_data.append(
1666
- {
1667
- "score": 1.0,
1668
- "label": detections["labels"][i],
1669
- "bbox": normalize_bbox(detections["bboxes"][i], image_size),
1670
- }
1733
+ # Calculate the offset to center the text within the bounding box
1734
+ text_width = text_box[2] - text_box[0]
1735
+ text_height = text_box[3] - text_box[1]
1736
+ text_x0 = cx - text_width / 2
1737
+ text_y0 = cy - text_height / 2
1738
+ text_x1 = cx + text_width / 2
1739
+ text_y1 = cy + text_height / 2
1740
+
1741
+ # Draw the rectangle encapsulating the text
1742
+ draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
1743
+
1744
+ # Draw the text at the center of the bounding box
1745
+ draw.text(
1746
+ (text_x0, text_y0),
1747
+ label,
1748
+ fill="black",
1749
+ font=font,
1750
+ anchor="lt",
1671
1751
  )
1672
- return return_data
1752
+
1753
+ return np.array(pil_image)
1673
1754
 
1674
1755
 
1675
1756
  FUNCTION_TOOLS = [
@@ -1679,8 +1760,7 @@ FUNCTION_TOOLS = [
1679
1760
  clip,
1680
1761
  vit_image_classification,
1681
1762
  vit_nsfw_classification,
1682
- loca_zero_shot_counting,
1683
- loca_visual_prompt_counting,
1763
+ countgd_counting,
1684
1764
  florence2_image_caption,
1685
1765
  florence2_ocr,
1686
1766
  florence2_sam2_image,
@@ -1703,6 +1783,7 @@ UTIL_TOOLS = [
1703
1783
  overlay_bounding_boxes,
1704
1784
  overlay_segmentation_masks,
1705
1785
  overlay_heat_map,
1786
+ overlay_counting_results,
1706
1787
  ]
1707
1788
 
1708
1789
  TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
@@ -1720,5 +1801,6 @@ UTILITIES_DOCSTRING = get_tool_documentation(
1720
1801
  overlay_bounding_boxes,
1721
1802
  overlay_segmentation_masks,
1722
1803
  overlay_heat_map,
1804
+ overlay_counting_results,
1723
1805
  ]
1724
1806
  )
@@ -1,8 +1,8 @@
1
1
  from enum import Enum
2
- from typing import List, Optional, Tuple
3
2
  from uuid import UUID
3
+ from typing import List, Tuple, Optional, Union
4
4
 
5
- from pydantic import BaseModel, ConfigDict, Field, SerializationInfo, field_serializer
5
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo
6
6
 
7
7
 
8
8
  class BboxInput(BaseModel):
@@ -19,16 +19,9 @@ class BboxInputBase64(BaseModel):
19
19
 
20
20
 
21
21
  class PromptTask(str, Enum):
22
- """
23
- Valid task prompts options for the Florencev2 model.
24
- """
22
+ """Valid task prompts options for the Florence2 model."""
25
23
 
26
- CAPTION = "<CAPTION>"
27
- """"""
28
- CAPTION_TO_PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
29
- """"""
30
- OBJECT_DETECTION = "<OD>"
31
- """"""
24
+ PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
32
25
 
33
26
 
34
27
  class FineTuning(BaseModel):
@@ -41,7 +34,7 @@ class FineTuning(BaseModel):
41
34
  return str(job_id)
42
35
 
43
36
 
44
- class Florencev2FtRequest(BaseModel):
37
+ class Florence2FtRequest(BaseModel):
45
38
  model_config = ConfigDict(populate_by_name=True)
46
39
 
47
40
  image: str
@@ -82,3 +75,16 @@ class JobStatus(str, Enum):
82
75
  SUCCEEDED = "SUCCEEDED"
83
76
  FAILED = "FAILED"
84
77
  STOPPED = "STOPPED"
78
+
79
+
80
+ class ODResponseData(BaseModel):
81
+ label: str
82
+ score: float
83
+ bbox: Union[list[int], list[float]] = Field(alias="bounding_box")
84
+
85
+ model_config = ConfigDict(
86
+ populate_by_name=True,
87
+ )
88
+
89
+
90
+ BoundingBoxes = list[ODResponseData]
@@ -564,7 +564,13 @@ class LocalCodeInterpreter(CodeInterpreter):
564
564
  ) -> None:
565
565
  super().__init__(timeout=timeout)
566
566
  self.nb = nbformat.v4.new_notebook()
567
- self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
567
+ # Set the notebook execution path to the remote path
568
+ self.resources = {"metadata": {"path": str(self.remote_path)}}
569
+ self.nb_client = NotebookClient(
570
+ self.nb,
571
+ timeout=self.timeout,
572
+ resources=self.resources,
573
+ )
568
574
  _LOGGER.info(
569
575
  f"""Local code interpreter initialized
570
576
  Python version: {sys.version}
@@ -606,7 +612,9 @@ Timeout: {self.timeout}"""
606
612
  def restart_kernel(self) -> None:
607
613
  self.close()
608
614
  self.nb = nbformat.v4.new_notebook()
609
- self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
615
+ self.nb_client = NotebookClient(
616
+ self.nb, timeout=self.timeout, resources=self.resources
617
+ )
610
618
  sleep(1)
611
619
  self._new_kernel()
612
620
 
@@ -636,7 +644,7 @@ Timeout: {self.timeout}"""
636
644
  f.write(contents)
637
645
  _LOGGER.info(f"File ({file_path}) is uploaded to: {str(self.remote_path)}")
638
646
 
639
- return Path(self.remote_path / file_path)
647
+ return Path(self.remote_path / Path(file_path).name)
640
648
 
641
649
  def download_file(
642
650
  self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
@@ -672,7 +680,8 @@ class CodeInterpreterFactory:
672
680
 
673
681
  @staticmethod
674
682
  def new_instance(
675
- code_sandbox_runtime: Optional[str] = None, remote_path: Optional[str] = None
683
+ code_sandbox_runtime: Optional[str] = None,
684
+ remote_path: Optional[Union[str, Path]] = None,
676
685
  ) -> CodeInterpreter:
677
686
  if not code_sandbox_runtime:
678
687
  code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local")
@@ -181,7 +181,7 @@ def denormalize_bbox(
181
181
  raise ValueError("Bounding box must be of length 4.")
182
182
 
183
183
  arr = np.array(bbox)
184
- if np.all((arr >= 0) & (arr <= 1)):
184
+ if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
185
185
  x1, y1, x2, y2 = bbox
186
186
  x1 = round(x1 * image_size[1])
187
187
  y1 = round(y1 * image_size[0])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: vision-agent
3
- Version: 0.2.120
3
+ Version: 0.2.122
4
4
  Summary: Toolset for Vision Agent
5
5
  Author: Landing AI
6
6
  Author-email: dev@landing.ai
@@ -2,32 +2,32 @@ vision_agent/__init__.py,sha256=EAb4-f9iyuEYkBrX4ag1syM8Syx8118_t0R6_C34M9w,57
2
2
  vision_agent/agent/__init__.py,sha256=FRwiux1FGvGccetyUCtY46KP01fQteqorm-JtFepovI,176
3
3
  vision_agent/agent/agent.py,sha256=2cjIOxEuSJrqbfPXYoV0qER5ihXsPFCoEFJa4jpqan0,597
4
4
  vision_agent/agent/agent_utils.py,sha256=22LiPhkJlS5mVeo2dIi259pc2NgA7PGHRpcbnrtKo78,1930
5
- vision_agent/agent/vision_agent.py,sha256=IEyXT_JPCuWmBHdEnM1Wrsj7hmCe5pKLf0gnZFJTddI,11046
6
- vision_agent/agent/vision_agent_coder.py,sha256=DOTmDdGPxcI06Jp6yx4ekRMP0vhiVaK9B9Dl8UyJHeo,34396
7
- vision_agent/agent/vision_agent_coder_prompts.py,sha256=xIya1txRZM8qoQHAWTEkEFCL8L3iZD7QD09t3ZtdxSE,11305
8
- vision_agent/agent/vision_agent_prompts.py,sha256=0GliXFtBf32aPu2ClU63FI5ii5CTxWYsvrsmnnDp-gs,7134
5
+ vision_agent/agent/vision_agent.py,sha256=WM1_o0VAQokAKlDr-0lpFxCRwUm_eFfFNWP-wSNjo7s,11180
6
+ vision_agent/agent/vision_agent_coder.py,sha256=ujctkpmQkX2C6YXjlp7VLZFqSB00xwkGe-9swA8Gv8s,34240
7
+ vision_agent/agent/vision_agent_coder_prompts.py,sha256=Rg7-Ih7oFgFbHFFno0EHpaZEgm0SYj_nTdqqdp21YLo,11246
8
+ vision_agent/agent/vision_agent_prompts.py,sha256=K1nLo3XKQ-IqCom1TRwh3cMoGZNxNwEgZqf3uJ6eL18,7221
9
9
  vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  vision_agent/clients/http.py,sha256=k883i6M_4nl7zwwHSI-yP5sAgQZIDPM1nrKD6YFJ3Xs,2009
11
11
  vision_agent/clients/landing_public_api.py,sha256=rGtACkr8o5egDuMHQ5MBO4NuvsgPTp9Ew3rbq4R-vs0,1507
12
12
  vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
14
14
  vision_agent/lmm/__init__.py,sha256=YuUZRsMHdn8cMOv6iBU8yUqlIOLrbZQqZl9KPnofsHQ,103
15
- vision_agent/lmm/lmm.py,sha256=AYrZNdhghG293wd3aKZ1jK1lUm2NLWwALktbM4wNais,20862
15
+ vision_agent/lmm/lmm.py,sha256=H3a5V7c073-vXRJfQOblE2j_CsZkH1CNNRoQgLjJZuQ,20751
16
16
  vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
17
- vision_agent/tools/__init__.py,sha256=i7JOLxRaLdcY7-vCNOGAeOFMBfiAUIwWhnT32FO97VE,2201
18
- vision_agent/tools/meta_tools.py,sha256=Vu9WnKicGhafx9dPzDbQjQdcIzRCYYFPF68o79hDP-8,14616
17
+ vision_agent/tools/__init__.py,sha256=TILaqdFYicScvpnCXMxgBsFmSW22NQDIvucvEgo0etw,2289
18
+ vision_agent/tools/meta_tools.py,sha256=KeGiw2OtY8ARpGbtWjoNAoO1dwevt7LbCupaJX61MkE,18929
19
19
  vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
20
- vision_agent/tools/tool_utils.py,sha256=qMsb9d8QtpXGgF9rpPO2dA390BewKdYO68oWKDu-TGg,6504
21
- vision_agent/tools/tools.py,sha256=kbbMToAaHxl42dDEvyz9Mvtpqts0l0hGoC5YQQyozr8,59953
22
- vision_agent/tools/tools_types.py,sha256=iLWSirheC87fKQolIhx_O4Jk8Lv7DRiLuE8PJqLGiVQ,2216
20
+ vision_agent/tools/tool_utils.py,sha256=e_p-G2nwgWOpoaqpDitY3FJ6fFuTEg5GhDOD67wI2bE,7527
21
+ vision_agent/tools/tools.py,sha256=jOBsuN-spY_2TlvpahoRYGvyInhQDTPXXukx9q72lEU,63454
22
+ vision_agent/tools/tools_types.py,sha256=qs11HGLRXc9zytahBtG6TQxCh8Gigvn232at3jk54jI,2356
23
23
  vision_agent/utils/__init__.py,sha256=pWk0ktvR4aUEhuEIzSLM9kSgW4WDVqptdvOTeGLkJ6M,230
24
24
  vision_agent/utils/exceptions.py,sha256=booSPSuoULF7OXRr_YbC4dtKt6gM_HyiFQHBuaW86C4,2052
25
- vision_agent/utils/execute.py,sha256=Ap8Yx80spQq5f2QtKGx1MK03BR45mJKhlp1kfh-rIao,26751
26
- vision_agent/utils/image_utils.py,sha256=eNghu_2L8624jEXy8ZZS9OX46Mv0DT9bcvLForujwTs,9848
25
+ vision_agent/utils/execute.py,sha256=gc4R_0BKUrZyhiKvIxOpYuzQPYVWQEqxr3ANy1lJAw4,27037
26
+ vision_agent/utils/image_utils.py,sha256=UloC4byIQLM4CSCaH41SBciQ7X2OqKvsVvNOVKqIH_k,9856
27
27
  vision_agent/utils/sim.py,sha256=ebE9Cs00pVEDI1HMjAzUBk88tQQmc2U-yAzIDinnekU,5572
28
28
  vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
29
29
  vision_agent/utils/video.py,sha256=rNmU9KEIkZB5-EztZNlUiKYN0mm_55A_2VGUM0QpqLA,8779
30
- vision_agent-0.2.120.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
31
- vision_agent-0.2.120.dist-info/METADATA,sha256=-FuNdlrzt5cTK6Ou_HTTROGVvsIwP3trsB5Edt2St3o,12255
32
- vision_agent-0.2.120.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
33
- vision_agent-0.2.120.dist-info/RECORD,,
30
+ vision_agent-0.2.122.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
31
+ vision_agent-0.2.122.dist-info/METADATA,sha256=WMdLNPyKY4Ot6ifOzwXNDiVm2TsStY-l-ge8t72Ynhk,12255
32
+ vision_agent-0.2.122.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
33
+ vision_agent-0.2.122.dist-info/RECORD,,