vision-agent 0.2.120__py3-none-any.whl → 0.2.122__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vision_agent/agent/vision_agent.py +10 -6
- vision_agent/agent/vision_agent_coder.py +1 -9
- vision_agent/agent/vision_agent_coder_prompts.py +4 -5
- vision_agent/agent/vision_agent_prompts.py +3 -3
- vision_agent/lmm/lmm.py +0 -3
- vision_agent/tools/__init__.py +3 -0
- vision_agent/tools/meta_tools.py +140 -8
- vision_agent/tools/tool_utils.py +95 -51
- vision_agent/tools/tools.py +196 -114
- vision_agent/tools/tools_types.py +18 -12
- vision_agent/utils/execute.py +13 -4
- vision_agent/utils/image_utils.py +1 -1
- {vision_agent-0.2.120.dist-info → vision_agent-0.2.122.dist-info}/METADATA +1 -1
- {vision_agent-0.2.120.dist-info → vision_agent-0.2.122.dist-info}/RECORD +16 -16
- {vision_agent-0.2.120.dist-info → vision_agent-0.2.122.dist-info}/LICENSE +0 -0
- {vision_agent-0.2.120.dist-info → vision_agent-0.2.122.dist-info}/WHEEL +0 -0
vision_agent/agent/vision_agent.py
CHANGED
@@ -30,7 +30,7 @@ class BoilerplateCode:
     pre_code = [
         "from typing import *",
         "from vision_agent.utils.execute import CodeInterpreter",
-        "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact",
+        "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, florence2_fine_tuning, use_florence2_fine_tuning",
        "artifacts = Artifacts('{remote_path}')",
        "artifacts.load('{remote_path}')",
     ]
@@ -76,11 +76,16 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:

 def run_code_action(
     code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
-) -> Execution:
-    return code_interpreter.exec_isolation(
+) -> Tuple[Execution, str]:
+    result = code_interpreter.exec_isolation(
         BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path)
     )

+    obs = str(result.logs)
+    if result.error:
+        obs += f"\n{result.error}"
+    return result, obs
+

 def parse_execution(response: str) -> Optional[str]:
     code = None
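For illustration, a minimal sketch of the new calling convention, assuming a code interpreter is available and using an illustrative artifact path (only run_code_action and CodeInterpreterFactory come from the diff above):

```python
from vision_agent.agent.vision_agent import run_code_action
from vision_agent.utils import CodeInterpreterFactory

# Hypothetical snippet to execute; "artifacts.pkl" is an illustrative path.
code = "print('hello from the sandbox')"

with CodeInterpreterFactory.new_instance() as code_interpreter:
    # run_code_action now returns both the Execution object and the observation
    # string (logs plus any error text) instead of the Execution alone.
    result, obs = run_code_action(code, code_interpreter, "artifacts.pkl")
    print(obs)
```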
@@ -192,7 +197,7 @@ class VisionAgent(Agent):
         artifacts = Artifacts(WORKSPACE / "artifacts.pkl")

         with CodeInterpreterFactory.new_instance(
-            code_sandbox_runtime=self.code_sandbox_runtime
+            code_sandbox_runtime=self.code_sandbox_runtime,
         ) as code_interpreter:
             orig_chat = copy.deepcopy(chat)
             int_chat = copy.deepcopy(chat)
@@ -260,10 +265,9 @@ class VisionAgent(Agent):
                 code_action = parse_execution(response["response"])

                 if code_action is not None:
-                    result = run_code_action(
+                    result, obs = run_code_action(
                         code_action, code_interpreter, str(remote_artifacts_path)
                     )
-                    obs = str(result.logs)

                 if self.verbosity >= 1:
                     _LOGGER.info(obs)
vision_agent/agent/vision_agent_coder.py
CHANGED
@@ -1,5 +1,4 @@
 import copy
-import difflib
 import logging
 import os
 import sys
@@ -29,6 +28,7 @@ from vision_agent.agent.vision_agent_coder_prompts import (
     USER_REQ,
 )
 from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
+from vision_agent.tools.meta_tools import get_diff
 from vision_agent.utils import CodeInterpreterFactory, Execution
 from vision_agent.utils.execute import CodeInterpreter
 from vision_agent.utils.image_utils import b64_to_pil
@@ -63,14 +63,6 @@ class DefaultImports:
         return DefaultImports.to_code_string() + "\n\n" + code


-def get_diff(before: str, after: str) -> str:
-    return "".join(
-        difflib.unified_diff(
-            before.splitlines(keepends=True), after.splitlines(keepends=True)
-        )
-    )
-
-
 def format_memory(memory: List[Dict[str, str]]) -> str:
     output_str = ""
     for i, m in enumerate(memory):
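The get_diff helper moves out of vision_agent_coder.py and into meta_tools.py (see the meta_tools.py hunks below) so other meta tools can reuse it. A minimal sketch of what it produces, using the standard-library difflib call it wraps; the before/after snippets are purely illustrative:

```python
import difflib

# Hypothetical before/after code snippets.
before = "out = owl_v2('person', image)\n"
after = "out = countgd_counting('person', image)\n"

# get_diff in meta_tools.py wraps difflib.unified_diff like this:
diff = "".join(
    difflib.unified_diff(
        before.splitlines(keepends=True), after.splitlines(keepends=True)
    )
)
print(diff)  # prints a unified diff of the two snippets
```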
vision_agent/agent/vision_agent_coder_prompts.py
CHANGED
@@ -81,20 +81,19 @@ plan2:
 - Count the number of detected objects labeled as 'person'.
 plan3:
 - Load the image from the provided file path 'image.jpg'.
-- Use the '
+- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.

 ```python
-from vision_agent.tools import load_image, owl_v2, grounding_sam,
+from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
 image = load_image("image.jpg")
 owl_v2_out = owl_v2("person", image)

 gsam_out = grounding_sam("person", image)
 gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]

-
-loca_out = loca_out["count"]
+cgd_out = countgd_counting(image)

-final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "
+final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
 print(final_out)
 ```
 """
vision_agent/agent/vision_agent_prompts.py
CHANGED
@@ -48,7 +48,7 @@ OBSERVATION:
 4| return dogs
 [End of artifact]

-AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
+AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}

 OBSERVATION:
 ----- stdout -----
@@ -75,7 +75,7 @@ OBSERVATION:
 4| return dogs
 [End of artifact]

-AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
+AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}

 OBSERVATION:
 ----- stdout -----
@@ -126,7 +126,7 @@ OBSERVATION:
 15| return count
 [End of artifact]

-AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code
+AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}

 OBSERVATION:
 ----- stdout -----
vision_agent/lmm/lmm.py
CHANGED
@@ -286,9 +286,6 @@ class OpenAILMM(LMM):

         return lambda x: T.grounding_sam(params["prompt"], x)

-    def generate_zero_shot_counter(self, question: str) -> Callable:
-        return T.loca_zero_shot_counting
-
     def generate_image_qa_tool(self, question: str) -> Callable:
         return lambda x: T.git_vqa_v2(question, x)

vision_agent/tools/__init__.py
CHANGED
@@ -37,10 +37,13 @@ from .tools import (
     load_image,
     loca_visual_prompt_counting,
     loca_zero_shot_counting,
+    countgd_counting,
+    countgd_example_based_counting,
     ocr,
     overlay_bounding_boxes,
     overlay_heat_map,
     overlay_segmentation_masks,
+    overlay_counting_results,
     owl_v2,
     save_image,
     save_json,
vision_agent/tools/meta_tools.py
CHANGED
@@ -1,5 +1,7 @@
+import difflib
 import os
 import pickle as pkl
+import re
 import subprocess
 import tempfile
 from pathlib import Path
@@ -8,10 +10,13 @@ from typing import Any, Dict, List, Optional, Union
 from IPython.display import display

 import vision_agent as va
+from vision_agent.clients.landing_public_api import LandingPublicAPI
 from vision_agent.lmm.types import Message
 from vision_agent.tools.tool_utils import get_tool_documentation
 from vision_agent.tools.tools import TOOL_DESCRIPTIONS
+from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
 from vision_agent.utils.execute import Execution, MimeType
+from vision_agent.utils.image_utils import convert_to_b64

 # These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent

@@ -99,13 +104,14 @@ class Artifacts:

     def show(self) -> str:
         """Shows the artifacts that have been loaded and their remote save paths."""
-
+        output_str = "[Artifacts loaded]\n"
         for k in self.artifacts.keys():
-
+            output_str += (
                 f"Artifact {k} loaded to {str(self.remote_save_path.parent / k)}\n"
             )
-
-
+        output_str += "[End of artifacts]\n"
+        print(output_str)
+        return output_str

     def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
         save_path = (
@@ -135,7 +141,12 @@ def format_lines(lines: List[str], start_idx: int) -> str:


 def view_lines(
-    lines: List[str],
+    lines: List[str],
+    line_num: int,
+    window_size: int,
+    name: str,
+    total_lines: int,
+    print_output: bool = True,
 ) -> str:
     start = max(0, line_num - window_size)
     end = min(len(lines), line_num + window_size)
@@ -148,7 +159,9 @@ def view_lines(
         else f"[{len(lines) - end} more lines]"
     )
     )
-
+
+    if print_output:
+        print(return_str)
     return return_str


@@ -231,7 +244,7 @@ def edit_code_artifact(
     new_content_lines = [
         line if line.endswith("\n") else line + "\n" for line in new_content_lines
     ]
-    lines = artifacts[name].splitlines()
+    lines = artifacts[name].splitlines(keepends=True)
     edited_lines = lines[:start] + new_content_lines + lines[end:]

     cur_line = start + len(content.split("\n")) // 2
@@ -261,13 +274,20 @@ def edit_code_artifact(
         DEFAULT_WINDOW_SIZE,
         name,
         total_lines,
+        print_output=False,
     )
     total_lines_edit = sum(1 for _ in edited_lines)
     edited_view = view_lines(
-        edited_lines,
+        edited_lines,
+        cur_line,
+        DEFAULT_WINDOW_SIZE,
+        name,
+        total_lines_edit,
+        print_output=False,
     )

     error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
+    print(error_msg)
     return error_msg

     artifacts[name] = "".join(edited_lines)
@@ -390,6 +410,13 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
     return f"[Media {Path(local_path).name} saved]"


+def list_artifacts(artifacts: Artifacts) -> str:
+    """Lists all the artifacts that have been loaded into the artifacts object."""
+    output_str = artifacts.show()
+    print(output_str)
+    return output_str
+
+
 def get_tool_descriptions() -> str:
     """Returns a description of all the tools that `generate_vision_code` has access to.
     Helpful for answering questions about what types of vision tasks you can do with
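A minimal sketch of the new list_artifacts meta tool, assuming an artifact store created in memory (the file names and contents are illustrative only):

```python
from vision_agent.tools.meta_tools import Artifacts, list_artifacts

# Hypothetical artifact store; "artifacts.pkl" is an illustrative path.
artifacts = Artifacts("artifacts.pkl")
artifacts["code.py"] = "print('hello')\n"

# list_artifacts calls Artifacts.show(), which now both prints and returns the
# "[Artifacts loaded] ... [End of artifacts]" listing.
listing = list_artifacts(artifacts)
```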
@@ -397,6 +424,108 @@ def get_tool_descriptions() -> str:
     return TOOL_DESCRIPTIONS


+def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
+    """'florence2_fine_tuning' is a tool that fine-tune florence2 to be able to detect
+    objects in an image based on a given dataset. It returns the fine tuning job id.
+
+    Parameters:
+        bboxes (List[BboxInput]): A list of BboxInput containing the
+            image path, labels and bounding boxes.
+        task (str): The florencev2 fine-tuning task. The options are
+            'phrase_grounding'.
+
+    Returns:
+        UUID: The fine tuning job id, this id will used to retrieve the fine
+            tuned model.
+
+    Example
+    -------
+        >>> fine_tuning_job_id = florencev2_fine_tuning(
+            [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
+             {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
+            "phrase_grounding"
+        )
+    """
+    bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
+    task_type = PromptTask[task.upper()]
+    fine_tuning_request = [
+        BboxInputBase64(
+            image=convert_to_b64(bbox_input.image_path),
+            filename=Path(bbox_input.image_path).name,
+            labels=bbox_input.labels,
+            bboxes=bbox_input.bboxes,
+        )
+        for bbox_input in bboxes_input
+    ]
+    landing_api = LandingPublicAPI()
+    fine_tune_id = str(
+        landing_api.launch_fine_tuning_job("florencev2", task_type, fine_tuning_request)
+    )
+    print(f"[Florence2 fine tuning id: {fine_tune_id}]")
+    return fine_tune_id
+
+
+def get_diff(before: str, after: str) -> str:
+    return "".join(
+        difflib.unified_diff(
+            before.splitlines(keepends=True), after.splitlines(keepends=True)
+        )
+    )
+
+
+def use_florence2_fine_tuning(
+    artifacts: Artifacts, name: str, task: str, fine_tune_id: str
+) -> str:
+    """Replaces florence2 calls with the fine tuning id. This ensures that the code
+    utilizes the fined tuned florence2 model. Returns the diff between the original
+    code and the new code.
+
+    Parameters:
+        artifacts (Artifacts): The artifacts object to edit the code from.
+        name (str): The name of the artifact to edit.
+        task (str): The task to fine tune the model for. The options are
+            'phrase_grounding'.
+        fine_tune_id (str): The fine tuning job id.
+
+    Examples
+    --------
+        >>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
+    """
+
+    task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"}
+
+    if name not in artifacts:
+        output_str = f"[Artifact {name} does not exist]"
+        print(output_str)
+        return output_str
+
+    code = artifacts[name]
+    if task.lower() == "phrase_grounding":
+        pattern = r"florence2_phrase_grounding\(\s*([^\)]+)\)"
+
+        def replacer(match: re.Match) -> str:
+            arg = match.group(1)  # capture all initial arguments
+            return f'florence2_phrase_grounding({arg}, "{fine_tune_id}")'
+
+    else:
+        raise ValueError(f"Task {task} is not supported.")
+
+    new_code = re.sub(pattern, replacer, code)
+
+    if new_code == code:
+        output_str = (
+            f"[Fine tuning task {task} function {task_to_fn[task]} not found in code]"
+        )
+        print(output_str)
+        return output_str
+
+    artifacts[name] = new_code
+
+    diff = get_diff(code, new_code)
+    print(diff)
+    return diff
+
+
 META_TOOL_DOCSTRING = get_tool_documentation(
     [
         get_tool_descriptions,
@@ -406,5 +535,8 @@ META_TOOL_DOCSTRING = get_tool_documentation(
         generate_vision_code,
         edit_vision_code,
         write_media_artifact,
+        florence2_fine_tuning,
+        use_florence2_fine_tuning,
+        list_artifacts,
     ]
 )
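Taken together, the new meta tools support a fine-tune-then-patch workflow. A minimal sketch under the assumption that a labeled image ("toolheads.png" and the "screw" label are illustrative) and a generated artifact named "code.py" already exist, and that a LandingAI API key is configured:

```python
from vision_agent.tools.meta_tools import (
    Artifacts,
    florence2_fine_tuning,
    use_florence2_fine_tuning,
)

artifacts = Artifacts("artifacts.pkl")  # illustrative path
artifacts["code.py"] = (
    "dets = florence2_phrase_grounding('screw', image)\n"  # code to be patched
)

# Launch a phrase-grounding fine-tuning job; the job id comes back as a string.
job_id = florence2_fine_tuning(
    [{"image_path": "toolheads.png", "labels": ["screw"], "bboxes": [[370, 30, 560, 290]]}],
    "phrase_grounding",
)

# Rewrite florence2_phrase_grounding(...) calls in the artifact so they pass the
# fine-tune id, and print the resulting unified diff.
diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", job_id)
```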
vision_agent/tools/tool_utils.py
CHANGED
@@ -1,6 +1,6 @@
+import os
 import inspect
 import logging
-import os
 from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple

 import pandas as pd
@@ -13,6 +13,7 @@ from urllib3.util.retry import Retry
 from vision_agent.utils.exceptions import RemoteToolCallFailed
 from vision_agent.utils.execute import Error, MimeType
 from vision_agent.utils.type_defs import LandingaiAPIKey
+from vision_agent.tools.tools_types import BoundingBoxes

 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
@@ -34,61 +35,58 @@ def send_inference_request(
     files: Optional[List[Tuple[Any, ...]]] = None,
     v2: bool = False,
     metadata_payload: Optional[Dict[str, Any]] = None,
-) ->
+) -> Any:
     # TODO: runtime_tag and function_name should be metadata_payload and now included
     # in the service payload
-
-
-
+    if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
+        payload["runtime_tag"] = runtime_tag
+
+    url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
+    if "TOOL_ENDPOINT_URL" in os.environ:
+        url = os.environ["TOOL_ENDPOINT_URL"]
+
+    headers = {"apikey": _LND_API_KEY}
+    if "TOOL_ENDPOINT_AUTH" in os.environ:
+        headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
+        headers.pop("apikey")
+
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
+        headers=headers,
+    )

-
-
-
+    function_name = "unknown"
+    if "function_name" in payload:
+        function_name = payload["function_name"]
+    elif metadata_payload is not None and "function_name" in metadata_payload:
+        function_name = metadata_payload["function_name"]

-    tool_call_trace = ToolCallTrace(
-        endpoint_url=url,
-        request=payload,
-        response={},
-        error=None,
-    )
-    headers = {"apikey": _LND_API_KEY}
-    if "TOOL_ENDPOINT_AUTH" in os.environ:
-        headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
-        headers.pop("apikey")
-
-    session = _create_requests_session(
-        url=url,
-        num_retry=3,
-        headers=headers,
-    )
+    response = _call_post(url, payload, session, files, function_name)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    finally:
-        trace = tool_call_trace.model_dump()
-        trace["type"] = "tool_call"
-        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+    # TODO: consider making the response schema the same between below two sources
+    return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
+
+
+def send_task_inference_request(
+    payload: Dict[str, Any],
+    task_name: str,
+    files: Optional[List[Tuple[Any, ...]]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> Any:
+    url = f"{_LND_API_URL_v2}/{task_name}"
+    headers = {"apikey": _LND_API_KEY}
+    session = _create_requests_session(
+        url=url,
+        num_retry=3,
+        headers=headers,
+    )
+
+    function_name = "unknown"
+    if metadata is not None and "function_name" in metadata:
+        function_name = metadata["function_name"]
+    response = _call_post(url, payload, session, files, function_name)
+    return response["data"]


 def _create_requests_session(
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
         data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"

     return data
+
+
+def _call_post(
+    url: str,
+    payload: dict[str, Any],
+    session: Session,
+    files: Optional[List[Tuple[Any, ...]]] = None,
+    function_name: str = "unknown",
+) -> Any:
+    try:
+        tool_call_trace = ToolCallTrace(
+            endpoint_url=url,
+            request=payload,
+            response={},
+            error=None,
+        )
+
+        if files is not None:
+            response = session.post(url, data=payload, files=files)
+        else:
+            response = session.post(url, json=payload)
+
+        if response.status_code != 200:
+            tool_call_trace.error = Error(
+                name="RemoteToolCallFailed",
+                value=f"{response.status_code} - {response.text}",
+                traceback_raw=[],
+            )
+            _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
+            raise RemoteToolCallFailed(
+                function_name, response.status_code, response.text
+            )
+
+        result = response.json()
+        tool_call_trace.response = result
+        return result
+    finally:
+        trace = tool_call_trace.model_dump()
+        trace["type"] = "tool_call"
+        display({MimeType.APPLICATION_JSON: trace}, raw=True)
+
+
+def filter_bboxes_by_threshold(
+    bboxes: BoundingBoxes, threshold: float
+) -> BoundingBoxes:
+    return list(filter(lambda bbox: bbox.score >= threshold, bboxes))
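A minimal sketch of how the new filter_bboxes_by_threshold helper combines with the ODResponseData model added in tools_types.py below; the scores and boxes are illustrative:

```python
from vision_agent.tools.tool_utils import filter_bboxes_by_threshold
from vision_agent.tools.tools_types import ODResponseData

# Illustrative detections, shaped like the countgd endpoints' responses.
detections = [
    ODResponseData(label="flower", score=0.91, bounding_box=[0.1, 0.1, 0.3, 0.4]),
    ODResponseData(label="flower", score=0.12, bounding_box=[0.5, 0.5, 0.6, 0.7]),
]

# Keep only boxes at or above the 0.23 default threshold used by countgd_counting.
kept = filter_bboxes_by_threshold(detections, 0.23)
print([d.model_dump() for d in kept])  # only the 0.91 detection survives
```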
vision_agent/tools/tools.py
CHANGED
@@ -13,7 +13,7 @@ import cv2
 import numpy as np
 import requests
 from moviepy.editor import ImageSequenceClip
-from PIL import Image, ImageDraw, ImageFont
+from PIL import Image, ImageDraw, ImageFont, ImageEnhance
 from pillow_heif import register_heif_opener  # type: ignore
 from pytube import YouTube  # type: ignore

@@ -24,14 +24,15 @@ from vision_agent.tools.tool_utils import (
     get_tools_df,
     get_tools_info,
     send_inference_request,
+    send_task_inference_request,
+    filter_bboxes_by_threshold,
 )
 from vision_agent.tools.tools_types import (
-    BboxInput,
-    BboxInputBase64,
     FineTuning,
-
+    Florence2FtRequest,
     JobStatus,
     PromptTask,
+    ODResponseData,
 )
 from vision_agent.utils import extract_frames_from_video
 from vision_agent.utils.exceptions import FineTuneModelIsNotReady
@@ -455,7 +456,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "function_name": "loca_zero_shot_counting",
     }
-    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
     resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data

@@ -469,6 +470,8 @@ def loca_visual_prompt_counting(

     Parameters:
         image (np.ndarray): The image that contains lot of instances of a single object
+        visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
+            [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.

     Returns:
         Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -496,11 +499,109 @@ def loca_visual_prompt_counting(
         "bbox": list(map(int, denormalize_bbox(bbox, image_size))),
         "function_name": "loca_visual_prompt_counting",
     }
-    resp_data = send_inference_request(data, "loca", v2=True)
+    resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
     resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
     return resp_data


+def countgd_counting(
+    prompt: str,
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_counting' is a tool that can precisely count multiple instances of an
+    object given a text prompt. It returns a list of bounding boxes with normalized
+    coordinates, label names and associated confidence scores.
+
+    Parameters:
+        prompt (str): The object that needs to be counted.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+        >>> countgd_counting("flower", image)
+        [
+            {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+            {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5},
+            {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52},
+            {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58},
+        ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    prompt = prompt.replace(", ", " .")
+    payload = {"prompts": [prompt], "model": "countgd"}
+    metadata = {"function_name": "countgd_counting"}
+    resp_data = send_task_inference_request(
+        payload, "text-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+    return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
+def countgd_example_based_counting(
+    visual_prompts: List[List[float]],
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_example_based_counting' is a tool that can precisely count multiple
+    instances of an object given few visual example prompts. It returns a list of bounding
+    boxes with normalized coordinates, label names and associated confidence scores.
+
+    Parameters:
+        visual_prompts (List[List[float]]): Bounding boxes of the object in format
+            [xmin, ymin, xmax, ymax]. Upto 3 bounding boxes can be provided.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+        >>> countgd_example_based_counting(
+            visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
+            image=image
+        )
+        [
+            {'score': 0.49, 'label': 'object', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
+            {'score': 0.68, 'label': 'object', 'bounding_box': [0.2, 0.21, 0.45, 0.5},
+            {'score': 0.78, 'label': 'object', 'bounding_box': [0.3, 0.35, 0.48, 0.52},
+            {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58},
+        ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    visual_prompts = [
+        denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
+    ]
+    payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"}
+    metadata = {"function_name": "countgd_example_based_counting"}
+    resp_data = send_task_inference_request(
+        payload, "visual-prompts-to-object-detection", files=files, metadata=metadata
+    )
+    bboxes_per_frame = resp_data[0]
+    bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame]
+    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+    return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
 def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
     """'florence2_roberta_vqa' is a tool that takes an image and analyzes
     its contents, generates detailed captions and then tries to answer the given
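A minimal sketch of calling the new text-prompted counting tool from user code, assuming an example image file ("crowd.jpg" is illustrative) and network access to the hosted endpoint:

```python
from vision_agent.tools import load_image, countgd_counting

image = load_image("crowd.jpg")  # illustrative file name

# Returns a list of {'score', 'label', 'bbox'} dicts with normalized coordinates;
# detections below the default box_threshold of 0.23 are filtered out.
people = countgd_counting("person", image)
print(len(people), "people detected")
```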
@@ -646,7 +747,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
         "tool": "closed_set_image_classification",
         "function_name": "clip",
     }
-    resp_data = send_inference_request(data, "tools")
+    resp_data: dict[str, Any] = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data

@@ -674,7 +775,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
         "tool": "image_classification",
         "function_name": "vit_image_classification",
     }
-    resp_data = send_inference_request(data, "tools")
+    resp_data: dict[str, Any] = send_inference_request(data, "tools")
     resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
     return resp_data

@@ -701,7 +802,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
         "image": image_b64,
         "function_name": "vit_nsfw_classification",
     }
-    resp_data = send_inference_request(
+    resp_data: dict[str, Any] = send_inference_request(
+        data, "nsfw-classification", v2=True
+    )
     resp_data["score"] = round(resp_data["score"], 4)
     return resp_data

@@ -762,7 +865,9 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s
     return answer[task]  # type: ignore


-def florence2_phrase_grounding(
+def florence2_phrase_grounding(
+    prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
+) -> List[Dict[str, Any]]:
     """'florence2_phrase_grounding' is a tool that can detect multiple
     objects given a text prompt which can be object names or caption. You
     can optionally separate the object names in the text with commas. It returns a list
@@ -772,6 +877,8 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
     Parameters:
         prompt (str): The prompt to ground to the image.
         image (np.ndarray): The image to used to detect objects
+        fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
+            fine-tuned model ID here to use it.

     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
@@ -790,14 +897,33 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
     """
     image_size = image.shape[:2]
     image_b64 = convert_to_b64(image)
-    data = {
-        "image": image_b64,
-        "task": "<CAPTION_TO_PHRASE_GROUNDING>",
-        "prompt": prompt,
-        "function_name": "florence2_phrase_grounding",
-    }

-
+    if fine_tune_id is not None:
+        landing_api = LandingPublicAPI()
+        status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
+        if status is not JobStatus.SUCCEEDED:
+            raise FineTuneModelIsNotReady(
+                f"Fine-tuned model {fine_tune_id} is not ready yet"
+            )
+
+        data_obj = Florence2FtRequest(
+            image=image_b64,
+            task=PromptTask.PHRASE_GROUNDING,
+            tool="florencev2_fine_tuning",
+            prompt=prompt,
+            fine_tuning=FineTuning(job_id=UUID(fine_tune_id)),
+        )
+        data = data_obj.model_dump(by_alias=True)
+        detections = send_inference_request(data, "tools", v2=False)
+    else:
+        data = {
+            "image": image_b64,
+            "task": "<CAPTION_TO_PHRASE_GROUNDING>",
+            "prompt": prompt,
+            "function_name": "florence2_phrase_grounding",
+        }
+        detections = send_inference_request(data, "florence2", v2=True)
+
     detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
     return_data = []
     for i in range(len(detections["bboxes"])):
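A minimal sketch of calling florence2_phrase_grounding with and without a fine-tuned model, assuming the job id comes from an earlier florence2_fine_tuning run (the file name and id shown are placeholders):

```python
from vision_agent.tools.tools import load_image, florence2_phrase_grounding

image = load_image("workers.png")  # illustrative file name

# Default hosted florence2 model.
dets = florence2_phrase_grounding("helmet", image)

# Same call routed through a fine-tuned checkpoint; raises FineTuneModelIsNotReady
# if the fine-tuning job has not reached the SUCCEEDED state yet.
job_id = "23b3b022-5ebf-4798-9373-20ef36429abf"  # placeholder id from the docstring example
dets_ft = florence2_phrase_grounding("helmet", image, fine_tune_id=job_id)
```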
@@ -1559,117 +1685,72 @@ def overlay_heat_map(
     return np.array(combined)


-
-
-
-
-
+def overlay_counting_results(
+    image: np.ndarray, instances: List[Dict[str, Any]]
+) -> np.ndarray:
+    """'overlay_counting_results' is a utility function that displays counting results on
+    an image.

     Parameters:
-
-
-
-        CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
+        image (np.ndarray): The image to display the bounding boxes on.
+        instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
+            box information of each instance

     Returns:
-
-            tuned model.
+        np.ndarray: The image with the instance_id dislpayed

     Example
     -------
-    >>>
-        [{'
-        {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
-        "OBJECT_DETECTION"
+    >>> image_with_bboxes = overlay_counting_results(
+        image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
     )
     """
-
-
-    fine_tuning_request = [
-        BboxInputBase64(
-            image=convert_to_b64(bbox_input.image_path),
-            filename=bbox_input.image_path.split("/")[-1],
-            labels=bbox_input.labels,
-            bboxes=bbox_input.bboxes,
-        )
-        for bbox_input in bboxes_input
-    ]
-    landing_api = LandingPublicAPI()
-    return landing_api.launch_fine_tuning_job(
-        "florencev2", task_input, fine_tuning_request
-    )
-
+    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
+    color = (158, 218, 229)

-
-
-
-
-
-
-
-
+    width, height = pil_image.size
+    fontsize = max(10, int(min(width, height) / 80))
+    pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
+    draw = ImageDraw.Draw(pil_image)
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )

-
-
-
-        model_id (UUID): The fine-tuned model id.
-        task (PromptTask): The florencev2 fine-tuning task. The options are
-            CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
+    for i, elt in enumerate(instances):
+        label = f"{i}"
+        box = elt["bbox"]

-
-
-
-
-            top-left and xmax and ymax are the coordinates of the bottom-right of the
-            bounding box. The scores are always 1.0 and cannot be thresholded
+        # denormalize the box if it is normalized
+        box = denormalize_bbox(box, (height, width))
+        x0, y0, x1, y1 = box
+        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2

-
-
-    >>> florencev2_fine_tuned_object_detection(
-        image,
-        'person looking at a coyote',
-        UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83")
+        text_box = draw.textbbox(
+            (cx, cy), text=label, font=font, align="center", anchor="mm"
         )
-    [
-        {'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
-        {'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5},
-    ]
-    """
-    # check if job succeeded first
-    landing_api = LandingPublicAPI()
-    status = landing_api.check_fine_tuning_job(model_id)
-    if status is not JobStatus.SUCCEEDED:
-        raise FineTuneModelIsNotReady()
-
-    task = PromptTask[task]
-    if task is PromptTask.OBJECT_DETECTION:
-        prompt = ""
-
-    data_obj = Florencev2FtRequest(
-        image=convert_to_b64(image),
-        task=task,
-        tool="florencev2_fine_tuning",
-        prompt=prompt,
-        fine_tuning=FineTuning(job_id=model_id),
-    )
-    data = data_obj.model_dump(by_alias=True)
-    metadata_payload = {"function_name": "florencev2_fine_tuned_object_detection"}
-    detections = send_inference_request(
-        data, "tools", v2=False, metadata_payload=metadata_payload
-    )

-
-
-
-
-
-
-
-
-
-
+        # Calculate the offset to center the text within the bounding box
+        text_width = text_box[2] - text_box[0]
+        text_height = text_box[3] - text_box[1]
+        text_x0 = cx - text_width / 2
+        text_y0 = cy - text_height / 2
+        text_x1 = cx + text_width / 2
+        text_y1 = cy + text_height / 2
+
+        # Draw the rectangle encapsulating the text
+        draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
+
+        # Draw the text at the center of the bounding box
+        draw.text(
+            (text_x0, text_y0),
+            label,
+            fill="black",
+            font=font,
+            anchor="lt",
         )
-
+
+    return np.array(pil_image)


 FUNCTION_TOOLS = [
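A minimal sketch that chains the new counting and visualization utilities; the file names are illustrative and the save_image call is shown only for context:

```python
from vision_agent.tools import (
    countgd_counting,
    load_image,
    overlay_counting_results,
    save_image,
)

image = load_image("shelf.jpg")  # illustrative file name
instances = countgd_counting("bottle", image)

# Draws a numbered marker at the center of each detected instance.
viz = overlay_counting_results(image, instances)
save_image(viz, "shelf_counts.png")
```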
@@ -1679,8 +1760,7 @@ FUNCTION_TOOLS = [
     clip,
     vit_image_classification,
     vit_nsfw_classification,
-
-    loca_visual_prompt_counting,
+    countgd_counting,
     florence2_image_caption,
     florence2_ocr,
     florence2_sam2_image,
@@ -1703,6 +1783,7 @@ UTIL_TOOLS = [
     overlay_bounding_boxes,
     overlay_segmentation_masks,
     overlay_heat_map,
+    overlay_counting_results,
 ]

 TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
@@ -1720,5 +1801,6 @@ UTILITIES_DOCSTRING = get_tool_documentation(
     overlay_bounding_boxes,
     overlay_segmentation_masks,
     overlay_heat_map,
+    overlay_counting_results,
     ]
 )
vision_agent/tools/tools_types.py
CHANGED
@@ -1,8 +1,8 @@
 from enum import Enum
-from typing import List, Optional, Tuple
 from uuid import UUID
+from typing import List, Tuple, Optional, Union

-from pydantic import BaseModel, ConfigDict, Field,
+from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo


 class BboxInput(BaseModel):
@@ -19,16 +19,9 @@ class BboxInputBase64(BaseModel):


 class PromptTask(str, Enum):
-    """
-    Valid task prompts options for the Florencev2 model.
-    """
+    """Valid task prompts options for the Florence2 model."""

-
-    """"""
-    CAPTION_TO_PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
-    """"""
-    OBJECT_DETECTION = "<OD>"
-    """"""
+    PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"


 class FineTuning(BaseModel):
@@ -41,7 +34,7 @@ class FineTuning(BaseModel):
         return str(job_id)


-class
+class Florence2FtRequest(BaseModel):
     model_config = ConfigDict(populate_by_name=True)

     image: str
@@ -82,3 +75,16 @@ class JobStatus(str, Enum):
     SUCCEEDED = "SUCCEEDED"
     FAILED = "FAILED"
     STOPPED = "STOPPED"
+
+
+class ODResponseData(BaseModel):
+    label: str
+    score: float
+    bbox: Union[list[int], list[float]] = Field(alias="bounding_box")
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
+
+BoundingBoxes = list[ODResponseData]
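A short sketch of the new ODResponseData model's alias handling; the service side sends "bounding_box" while the rest of the code works with "bbox" (values are illustrative):

```python
from vision_agent.tools.tools_types import ODResponseData

# The endpoint returns "bounding_box"; populate_by_name also allows "bbox".
det = ODResponseData(**{"label": "person", "score": 0.9, "bounding_box": [0.1, 0.2, 0.3, 0.4]})

print(det.bbox)                       # [0.1, 0.2, 0.3, 0.4]
print(det.model_dump())               # uses the field name: {'label': ..., 'bbox': [...]}
print(det.model_dump(by_alias=True))  # uses the alias: {'label': ..., 'bounding_box': [...]}
```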
vision_agent/utils/execute.py
CHANGED
@@ -564,7 +564,13 @@ class LocalCodeInterpreter(CodeInterpreter):
     ) -> None:
         super().__init__(timeout=timeout)
         self.nb = nbformat.v4.new_notebook()
-
+        # Set the notebook execution path to the remote path
+        self.resources = {"metadata": {"path": str(self.remote_path)}}
+        self.nb_client = NotebookClient(
+            self.nb,
+            timeout=self.timeout,
+            resources=self.resources,
+        )
         _LOGGER.info(
             f"""Local code interpreter initialized
 Python version: {sys.version}
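For context, nbclient's NotebookClient accepts a resources mapping whose metadata path sets the working directory cells execute in, which is what LocalCodeInterpreter now uses to point execution at its remote path. A standalone sketch (the notebook content and /tmp path are illustrative):

```python
import nbformat
from nbclient import NotebookClient

nb = nbformat.v4.new_notebook()
nb.cells.append(nbformat.v4.new_code_cell("import os; print(os.getcwd())"))

# Execute the notebook with /tmp as the working directory, mirroring the
# resources={"metadata": {"path": ...}} usage in the diff above.
client = NotebookClient(nb, timeout=60, resources={"metadata": {"path": "/tmp"}})
client.execute()
```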
@@ -606,7 +612,9 @@ Timeout: {self.timeout}"""
     def restart_kernel(self) -> None:
         self.close()
         self.nb = nbformat.v4.new_notebook()
-        self.nb_client = NotebookClient(
+        self.nb_client = NotebookClient(
+            self.nb, timeout=self.timeout, resources=self.resources
+        )
         sleep(1)
         self._new_kernel()

@@ -636,7 +644,7 @@ Timeout: {self.timeout}"""
             f.write(contents)
         _LOGGER.info(f"File ({file_path}) is uploaded to: {str(self.remote_path)}")

-        return Path(self.remote_path / file_path)
+        return Path(self.remote_path / Path(file_path).name)

     def download_file(
         self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
@@ -672,7 +680,8 @@ class CodeInterpreterFactory:

     @staticmethod
     def new_instance(
-        code_sandbox_runtime: Optional[str] = None,
+        code_sandbox_runtime: Optional[str] = None,
+        remote_path: Optional[Union[str, Path]] = None,
     ) -> CodeInterpreter:
         if not code_sandbox_runtime:
             code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local")
vision_agent/utils/image_utils.py
CHANGED
@@ -181,7 +181,7 @@ def denormalize_bbox(
         raise ValueError("Bounding box must be of length 4.")

     arr = np.array(bbox)
-    if np.all((arr >= 0) & (arr <= 1)):
+    if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
         x1, y1, x2, y2 = bbox
         x1 = round(x1 * image_size[1])
         y1 = round(y1 * image_size[0])
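The change relaxes the "is this box normalized?" test to look only at xmin and ymin, so a box whose lower-right coordinate strays slightly past 1.0 is still treated as normalized and scaled. A small sketch of the check itself; the helper below is illustrative and not part of the package:

```python
import numpy as np

def looks_normalized(bbox: list[float]) -> bool:
    # Mirrors the updated check: only xmin/ymin are tested against [0, 1].
    arr = np.array(bbox)
    return bool(np.all((arr[:2] >= 0) & (arr[:2] <= 1)))

print(looks_normalized([0.1, 0.2, 0.9, 1.3]))  # True  -> will be denormalized
print(looks_normalized([12, 30, 120, 240]))    # False -> treated as pixel coordinates
```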
{vision_agent-0.2.120.dist-info → vision_agent-0.2.122.dist-info}/RECORD
CHANGED
@@ -2,32 +2,32 @@ vision_agent/__init__.py,sha256=EAb4-f9iyuEYkBrX4ag1syM8Syx8118_t0R6_C34M9w,57
 vision_agent/agent/__init__.py,sha256=FRwiux1FGvGccetyUCtY46KP01fQteqorm-JtFepovI,176
 vision_agent/agent/agent.py,sha256=2cjIOxEuSJrqbfPXYoV0qER5ihXsPFCoEFJa4jpqan0,597
 vision_agent/agent/agent_utils.py,sha256=22LiPhkJlS5mVeo2dIi259pc2NgA7PGHRpcbnrtKo78,1930
-vision_agent/agent/vision_agent.py,sha256=
-vision_agent/agent/vision_agent_coder.py,sha256=
-vision_agent/agent/vision_agent_coder_prompts.py,sha256=
-vision_agent/agent/vision_agent_prompts.py,sha256=
+vision_agent/agent/vision_agent.py,sha256=WM1_o0VAQokAKlDr-0lpFxCRwUm_eFfFNWP-wSNjo7s,11180
+vision_agent/agent/vision_agent_coder.py,sha256=ujctkpmQkX2C6YXjlp7VLZFqSB00xwkGe-9swA8Gv8s,34240
+vision_agent/agent/vision_agent_coder_prompts.py,sha256=Rg7-Ih7oFgFbHFFno0EHpaZEgm0SYj_nTdqqdp21YLo,11246
+vision_agent/agent/vision_agent_prompts.py,sha256=K1nLo3XKQ-IqCom1TRwh3cMoGZNxNwEgZqf3uJ6eL18,7221
 vision_agent/clients/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/clients/http.py,sha256=k883i6M_4nl7zwwHSI-yP5sAgQZIDPM1nrKD6YFJ3Xs,2009
 vision_agent/clients/landing_public_api.py,sha256=rGtACkr8o5egDuMHQ5MBO4NuvsgPTp9Ew3rbq4R-vs0,1507
 vision_agent/fonts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 vision_agent/fonts/default_font_ch_en.ttf,sha256=1YM0Z3XqLDjSNbF7ihQFSAIUdjF9m1rtHiNC_6QosTE,1594400
 vision_agent/lmm/__init__.py,sha256=YuUZRsMHdn8cMOv6iBU8yUqlIOLrbZQqZl9KPnofsHQ,103
-vision_agent/lmm/lmm.py,sha256=
+vision_agent/lmm/lmm.py,sha256=H3a5V7c073-vXRJfQOblE2j_CsZkH1CNNRoQgLjJZuQ,20751
 vision_agent/lmm/types.py,sha256=ZEXR_ptBL0ZwDMTDYkgxUCmSZFmBYPQd2jreNzr_8UY,221
-vision_agent/tools/__init__.py,sha256=
-vision_agent/tools/meta_tools.py,sha256=
+vision_agent/tools/__init__.py,sha256=TILaqdFYicScvpnCXMxgBsFmSW22NQDIvucvEgo0etw,2289
+vision_agent/tools/meta_tools.py,sha256=KeGiw2OtY8ARpGbtWjoNAoO1dwevt7LbCupaJX61MkE,18929
 vision_agent/tools/prompts.py,sha256=V1z4YJLXZuUl_iZ5rY0M5hHc_2tmMEUKr0WocXKGt4E,1430
-vision_agent/tools/tool_utils.py,sha256=
-vision_agent/tools/tools.py,sha256=
-vision_agent/tools/tools_types.py,sha256=
+vision_agent/tools/tool_utils.py,sha256=e_p-G2nwgWOpoaqpDitY3FJ6fFuTEg5GhDOD67wI2bE,7527
+vision_agent/tools/tools.py,sha256=jOBsuN-spY_2TlvpahoRYGvyInhQDTPXXukx9q72lEU,63454
+vision_agent/tools/tools_types.py,sha256=qs11HGLRXc9zytahBtG6TQxCh8Gigvn232at3jk54jI,2356
 vision_agent/utils/__init__.py,sha256=pWk0ktvR4aUEhuEIzSLM9kSgW4WDVqptdvOTeGLkJ6M,230
 vision_agent/utils/exceptions.py,sha256=booSPSuoULF7OXRr_YbC4dtKt6gM_HyiFQHBuaW86C4,2052
-vision_agent/utils/execute.py,sha256=
-vision_agent/utils/image_utils.py,sha256=
+vision_agent/utils/execute.py,sha256=gc4R_0BKUrZyhiKvIxOpYuzQPYVWQEqxr3ANy1lJAw4,27037
+vision_agent/utils/image_utils.py,sha256=UloC4byIQLM4CSCaH41SBciQ7X2OqKvsVvNOVKqIH_k,9856
 vision_agent/utils/sim.py,sha256=ebE9Cs00pVEDI1HMjAzUBk88tQQmc2U-yAzIDinnekU,5572
 vision_agent/utils/type_defs.py,sha256=BE12s3JNQy36QvauXHjwyeffVh5enfcvd4vTzSwvEZI,1384
 vision_agent/utils/video.py,sha256=rNmU9KEIkZB5-EztZNlUiKYN0mm_55A_2VGUM0QpqLA,8779
-vision_agent-0.2.
-vision_agent-0.2.
-vision_agent-0.2.
-vision_agent-0.2.
+vision_agent-0.2.122.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vision_agent-0.2.122.dist-info/METADATA,sha256=WMdLNPyKY4Ot6ifOzwXNDiVm2TsStY-l-ge8t72Ynhk,12255
+vision_agent-0.2.122.dist-info/WHEEL,sha256=7Z8_27uaHI_UZAc4Uox4PpBhQ9Y5_modZXWMxtUi4NU,88
+vision_agent-0.2.122.dist-info/RECORD,,
{vision_agent-0.2.120.dist-info → vision_agent-0.2.122.dist-info}/LICENSE
File without changes
{vision_agent-0.2.120.dist-info → vision_agent-0.2.122.dist-info}/WHEEL
File without changes