vision-agent 0.2.120__tar.gz → 0.2.122__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {vision_agent-0.2.120 → vision_agent-0.2.122}/PKG-INFO +1 -1
- {vision_agent-0.2.120 → vision_agent-0.2.122}/pyproject.toml +1 -1
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/vision_agent.py +10 -6
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/vision_agent_coder.py +1 -9
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/vision_agent_coder_prompts.py +4 -5
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/vision_agent_prompts.py +3 -3
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/lmm/lmm.py +0 -3
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/tools/__init__.py +3 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/tools/meta_tools.py +140 -8
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/tools/tool_utils.py +95 -51
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/tools/tools.py +196 -114
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/tools/tools_types.py +18 -12
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/utils/execute.py +13 -4
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/utils/image_utils.py +1 -1
- {vision_agent-0.2.120 → vision_agent-0.2.122}/LICENSE +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/README.md +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/agent.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/agent_utils.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/clients/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/clients/http.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/clients/landing_public_api.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/fonts/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/fonts/default_font_ch_en.ttf +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/lmm/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/lmm/types.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/tools/prompts.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/utils/__init__.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/utils/exceptions.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/utils/sim.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/utils/type_defs.py +0 -0
- {vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/utils/video.py +0 -0
@@ -30,7 +30,7 @@ class BoilerplateCode:
|
|
30
30
|
pre_code = [
|
31
31
|
"from typing import *",
|
32
32
|
"from vision_agent.utils.execute import CodeInterpreter",
|
33
|
-
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact",
|
33
|
+
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, florence2_fine_tuning, use_florence2_fine_tuning",
|
34
34
|
"artifacts = Artifacts('{remote_path}')",
|
35
35
|
"artifacts.load('{remote_path}')",
|
36
36
|
]
|
@@ -76,11 +76,16 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
|
|
76
76
|
|
77
77
|
def run_code_action(
    code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
) -> Tuple[Execution, str]:
    """Execute `code` in isolation and build the observation text for it.

    The code is wrapped with `BoilerplateCode.add_boilerplate` (using
    `artifact_remote_path` as the remote path) before being handed to
    `code_interpreter.exec_isolation`.

    Returns:
        Tuple[Execution, str]: The execution result and the observation string,
            which is the stringified logs followed by the error on a new line
            when the execution reported one.
    """
    wrapped_code = BoilerplateCode.add_boilerplate(
        code, remote_path=artifact_remote_path
    )
    execution = code_interpreter.exec_isolation(wrapped_code)

    # Logs always come first; the error (if any) is appended on its own line.
    observation_parts = [str(execution.logs)]
    if execution.error:
        observation_parts.append(f"\n{execution.error}")
    return execution, "".join(observation_parts)
|
88
|
+
|
84
89
|
|
85
90
|
def parse_execution(response: str) -> Optional[str]:
|
86
91
|
code = None
|
@@ -192,7 +197,7 @@ class VisionAgent(Agent):
|
|
192
197
|
artifacts = Artifacts(WORKSPACE / "artifacts.pkl")
|
193
198
|
|
194
199
|
with CodeInterpreterFactory.new_instance(
|
195
|
-
code_sandbox_runtime=self.code_sandbox_runtime
|
200
|
+
code_sandbox_runtime=self.code_sandbox_runtime,
|
196
201
|
) as code_interpreter:
|
197
202
|
orig_chat = copy.deepcopy(chat)
|
198
203
|
int_chat = copy.deepcopy(chat)
|
@@ -260,10 +265,9 @@ class VisionAgent(Agent):
|
|
260
265
|
code_action = parse_execution(response["response"])
|
261
266
|
|
262
267
|
if code_action is not None:
|
263
|
-
result = run_code_action(
|
268
|
+
result, obs = run_code_action(
|
264
269
|
code_action, code_interpreter, str(remote_artifacts_path)
|
265
270
|
)
|
266
|
-
obs = str(result.logs)
|
267
271
|
|
268
272
|
if self.verbosity >= 1:
|
269
273
|
_LOGGER.info(obs)
|
@@ -1,5 +1,4 @@
|
|
1
1
|
import copy
|
2
|
-
import difflib
|
3
2
|
import logging
|
4
3
|
import os
|
5
4
|
import sys
|
@@ -29,6 +28,7 @@ from vision_agent.agent.vision_agent_coder_prompts import (
|
|
29
28
|
USER_REQ,
|
30
29
|
)
|
31
30
|
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
|
31
|
+
from vision_agent.tools.meta_tools import get_diff
|
32
32
|
from vision_agent.utils import CodeInterpreterFactory, Execution
|
33
33
|
from vision_agent.utils.execute import CodeInterpreter
|
34
34
|
from vision_agent.utils.image_utils import b64_to_pil
|
@@ -63,14 +63,6 @@ class DefaultImports:
|
|
63
63
|
return DefaultImports.to_code_string() + "\n\n" + code
|
64
64
|
|
65
65
|
|
66
|
-
def get_diff(before: str, after: str) -> str:
|
67
|
-
return "".join(
|
68
|
-
difflib.unified_diff(
|
69
|
-
before.splitlines(keepends=True), after.splitlines(keepends=True)
|
70
|
-
)
|
71
|
-
)
|
72
|
-
|
73
|
-
|
74
66
|
def format_memory(memory: List[Dict[str, str]]) -> str:
|
75
67
|
output_str = ""
|
76
68
|
for i, m in enumerate(memory):
|
{vision_agent-0.2.120 → vision_agent-0.2.122}/vision_agent/agent/vision_agent_coder_prompts.py
RENAMED
@@ -81,20 +81,19 @@ plan2:
|
|
81
81
|
- Count the number of detected objects labeled as 'person'.
|
82
82
|
plan3:
|
83
83
|
- Load the image from the provided file path 'image.jpg'.
|
84
|
-
- Use the '
|
84
|
+
- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
|
85
85
|
|
86
86
|
```python
|
87
|
-
from vision_agent.tools import load_image, owl_v2, grounding_sam,
|
87
|
+
from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
|
88
88
|
image = load_image("image.jpg")
|
89
89
|
owl_v2_out = owl_v2("person", image)
|
90
90
|
|
91
91
|
gsam_out = grounding_sam("person", image)
|
92
92
|
gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
|
93
93
|
|
94
|
-
|
95
|
-
loca_out = loca_out["count"]
|
94
|
+
cgd_out = countgd_counting("person", image)
|
96
95
|
|
97
|
-
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "
|
96
|
+
final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
|
98
97
|
print(final_out)
|
99
98
|
```
|
100
99
|
"""
|
@@ -48,7 +48,7 @@ OBSERVATION:
|
|
48
48
|
4| return dogs
|
49
49
|
[End of artifact]
|
50
50
|
|
51
|
-
AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
|
51
|
+
AGENT: {"thoughts": "I have generated the code to detect the dogs in the image, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/example/workspace/dog.jpg'))</execute_python>", "let_user_respond": false}
|
52
52
|
|
53
53
|
OBSERVATION:
|
54
54
|
----- stdout -----
|
@@ -75,7 +75,7 @@ OBSERVATION:
|
|
75
75
|
4| return dogs
|
76
76
|
[End of artifact]
|
77
77
|
|
78
|
-
AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
|
78
|
+
AGENT: {"thoughts": "I have edited the code to detect only one dog, I must now run the code and print the results to get the output.", "response": "<execute_python>from dog_detector import detect_dogs\n print(detect_dogs('/path/to/images/dog.jpg'))</execute_python>", "let_user_respond": false}
|
79
79
|
|
80
80
|
OBSERVATION:
|
81
81
|
----- stdout -----
|
@@ -126,7 +126,7 @@ OBSERVATION:
|
|
126
126
|
15| return count
|
127
127
|
[End of artifact]
|
128
128
|
|
129
|
-
AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code
|
129
|
+
AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visualization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
|
130
130
|
|
131
131
|
OBSERVATION:
|
132
132
|
----- stdout -----
|
@@ -286,9 +286,6 @@ class OpenAILMM(LMM):
|
|
286
286
|
|
287
287
|
return lambda x: T.grounding_sam(params["prompt"], x)
|
288
288
|
|
289
|
-
def generate_zero_shot_counter(self, question: str) -> Callable:
|
290
|
-
return T.loca_zero_shot_counting
|
291
|
-
|
292
289
|
def generate_image_qa_tool(self, question: str) -> Callable:
|
293
290
|
return lambda x: T.git_vqa_v2(question, x)
|
294
291
|
|
@@ -37,10 +37,13 @@ from .tools import (
|
|
37
37
|
load_image,
|
38
38
|
loca_visual_prompt_counting,
|
39
39
|
loca_zero_shot_counting,
|
40
|
+
countgd_counting,
|
41
|
+
countgd_example_based_counting,
|
40
42
|
ocr,
|
41
43
|
overlay_bounding_boxes,
|
42
44
|
overlay_heat_map,
|
43
45
|
overlay_segmentation_masks,
|
46
|
+
overlay_counting_results,
|
44
47
|
owl_v2,
|
45
48
|
save_image,
|
46
49
|
save_json,
|
@@ -1,5 +1,7 @@
|
|
1
|
+
import difflib
|
1
2
|
import os
|
2
3
|
import pickle as pkl
|
4
|
+
import re
|
3
5
|
import subprocess
|
4
6
|
import tempfile
|
5
7
|
from pathlib import Path
|
@@ -8,10 +10,13 @@ from typing import Any, Dict, List, Optional, Union
|
|
8
10
|
from IPython.display import display
|
9
11
|
|
10
12
|
import vision_agent as va
|
13
|
+
from vision_agent.clients.landing_public_api import LandingPublicAPI
|
11
14
|
from vision_agent.lmm.types import Message
|
12
15
|
from vision_agent.tools.tool_utils import get_tool_documentation
|
13
16
|
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
|
17
|
+
from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
|
14
18
|
from vision_agent.utils.execute import Execution, MimeType
|
19
|
+
from vision_agent.utils.image_utils import convert_to_b64
|
15
20
|
|
16
21
|
# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent
|
17
22
|
|
@@ -99,13 +104,14 @@ class Artifacts:
|
|
99
104
|
|
100
105
|
def show(self) -> str:
    """Prints and returns a listing of the loaded artifacts and their remote paths.

    Each artifact key is reported as living next to `self.remote_save_path`
    (i.e. in its parent directory). The listing is wrapped in
    '[Artifacts loaded]' / '[End of artifacts]' marker lines.
    """
    entries = [
        f"Artifact {key} loaded to {str(self.remote_save_path.parent / key)}\n"
        for key in self.artifacts.keys()
    ]
    output_str = "".join(["[Artifacts loaded]\n", *entries, "[End of artifacts]\n"])
    print(output_str)
    return output_str
|
109
115
|
|
110
116
|
def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
|
111
117
|
save_path = (
|
@@ -135,7 +141,12 @@ def format_lines(lines: List[str], start_idx: int) -> str:
|
|
135
141
|
|
136
142
|
|
137
143
|
def view_lines(
|
138
|
-
lines: List[str],
|
144
|
+
lines: List[str],
|
145
|
+
line_num: int,
|
146
|
+
window_size: int,
|
147
|
+
name: str,
|
148
|
+
total_lines: int,
|
149
|
+
print_output: bool = True,
|
139
150
|
) -> str:
|
140
151
|
start = max(0, line_num - window_size)
|
141
152
|
end = min(len(lines), line_num + window_size)
|
@@ -148,7 +159,9 @@ def view_lines(
|
|
148
159
|
else f"[{len(lines) - end} more lines]"
|
149
160
|
)
|
150
161
|
)
|
151
|
-
|
162
|
+
|
163
|
+
if print_output:
|
164
|
+
print(return_str)
|
152
165
|
return return_str
|
153
166
|
|
154
167
|
|
@@ -231,7 +244,7 @@ def edit_code_artifact(
|
|
231
244
|
new_content_lines = [
|
232
245
|
line if line.endswith("\n") else line + "\n" for line in new_content_lines
|
233
246
|
]
|
234
|
-
lines = artifacts[name].splitlines()
|
247
|
+
lines = artifacts[name].splitlines(keepends=True)
|
235
248
|
edited_lines = lines[:start] + new_content_lines + lines[end:]
|
236
249
|
|
237
250
|
cur_line = start + len(content.split("\n")) // 2
|
@@ -261,13 +274,20 @@ def edit_code_artifact(
|
|
261
274
|
DEFAULT_WINDOW_SIZE,
|
262
275
|
name,
|
263
276
|
total_lines,
|
277
|
+
print_output=False,
|
264
278
|
)
|
265
279
|
total_lines_edit = sum(1 for _ in edited_lines)
|
266
280
|
edited_view = view_lines(
|
267
|
-
edited_lines,
|
281
|
+
edited_lines,
|
282
|
+
cur_line,
|
283
|
+
DEFAULT_WINDOW_SIZE,
|
284
|
+
name,
|
285
|
+
total_lines_edit,
|
286
|
+
print_output=False,
|
268
287
|
)
|
269
288
|
|
270
289
|
error_msg += f"\n[This is how your edit would have looked like if applied]\n{edited_view}\n\n[This is the original code before your edit]\n{original_view}"
|
290
|
+
print(error_msg)
|
271
291
|
return error_msg
|
272
292
|
|
273
293
|
artifacts[name] = "".join(edited_lines)
|
@@ -390,6 +410,13 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
|
|
390
410
|
return f"[Media {Path(local_path).name} saved]"
|
391
411
|
|
392
412
|
|
413
|
+
def list_artifacts(artifacts: "Artifacts") -> str:
    """Lists all the artifacts that have been loaded into the artifacts object.

    Parameters:
        artifacts (Artifacts): The artifacts object to list.

    Returns:
        str: The listing produced by `artifacts.show()`.
    """
    # `Artifacts.show()` prints the listing itself before returning it, so
    # printing here as well would emit the same text twice on stdout.
    return artifacts.show()
|
418
|
+
|
419
|
+
|
393
420
|
def get_tool_descriptions() -> str:
|
394
421
|
"""Returns a description of all the tools that `generate_vision_code` has access to.
|
395
422
|
Helpful for answering questions about what types of vision tasks you can do with
|
@@ -397,6 +424,108 @@ def get_tool_descriptions() -> str:
|
|
397
424
|
return TOOL_DESCRIPTIONS
|
398
425
|
|
399
426
|
|
427
|
+
def florence2_fine_tuning(bboxes: List[Dict[str, Any]], task: str) -> str:
    """'florence2_fine_tuning' is a tool that fine-tunes florence2 to be able to detect
    objects in an image based on a given dataset. It returns the fine tuning job id.

    Parameters:
        bboxes (List[Dict[str, Any]]): A list of dictionaries, each containing the
            image path, labels and bounding boxes of one annotated image.
        task (str): The florence2 fine-tuning task. The options are
            'phrase_grounding'.

    Returns:
        str: The fine tuning job id, this id will be used to retrieve the fine
            tuned model.

    Example
    -------
        >>> fine_tuning_job_id = florence2_fine_tuning(
            [{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[370, 30, 560, 290]]},
             {'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
             "phrase_grounding"
        )
    """
    # Validate every raw dictionary into a BboxInput before any image is read.
    bboxes_input = [BboxInput.model_validate(bbox) for bbox in bboxes]
    # The task is looked up by its upper-cased name on PromptTask
    # (presumably an Enum, in which case an unknown task raises KeyError).
    task_type = PromptTask[task.upper()]
    # The request carries each image inline as base64 plus its bare file name.
    fine_tuning_request = [
        BboxInputBase64(
            image=convert_to_b64(bbox_input.image_path),
            filename=Path(bbox_input.image_path).name,
            labels=bbox_input.labels,
            bboxes=bbox_input.bboxes,
        )
        for bbox_input in bboxes_input
    ]
    landing_api = LandingPublicAPI()
    fine_tune_id = str(
        landing_api.launch_fine_tuning_job("florencev2", task_type, fine_tuning_request)
    )
    # Printed so the id also shows up in the captured stdout of the caller.
    print(f"[Florence2 fine tuning id: {fine_tune_id}]")
    return fine_tune_id
|
466
|
+
|
467
|
+
|
468
|
+
def get_diff(before: str, after: str) -> str:
    """Return the unified diff between `before` and `after` as a single string.

    Both inputs are split into lines with their line endings kept, so the
    joined diff preserves the original newlines. Identical inputs give "".
    """
    return "".join(
        difflib.unified_diff(
            before.splitlines(keepends=True), after.splitlines(keepends=True)
        )
    )


def _find_call_end(code: str, open_idx: int) -> int:
    """Return the index of the ')' closing the '(' at `open_idx`, or -1 if none.

    Parentheses inside string literals and `#` comments are ignored so that
    nested calls and arguments such as "a) b" do not end the call early.
    """
    depth = 0
    quote = None  # the delimiter of the string literal we are inside, if any
    i = open_idx
    while i < len(code):
        ch = code[i]
        if quote is not None:
            if ch == "\\":
                i += 2  # skip the escaped character
                continue
            if code.startswith(quote, i):
                i += len(quote)
                quote = None
                continue
        elif ch in "\"'":
            triple = code[i : i + 3]
            quote = triple if triple in ('"""', "'''") else ch
            i += len(quote)
            continue
        elif ch == "#":
            newline = code.find("\n", i)
            if newline == -1:
                return -1
            i = newline
            continue
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return i
        i += 1
    return -1


def use_florence2_fine_tuning(
    artifacts: "Artifacts", name: str, task: str, fine_tune_id: str
) -> str:
    """Replaces florence2 calls with the fine tuning id. This ensures that the code
    utilizes the fined tuned florence2 model. Returns the diff between the original
    code and the new code.

    Parameters:
        artifacts (Artifacts): The artifacts object to edit the code from.
        name (str): The name of the artifact to edit.
        task (str): The task to fine tune the model for. The options are
            'phrase_grounding'.
        fine_tune_id (str): The fine tuning job id.

    Examples
    --------
        >>> diff = use_florence2_fine_tuning(artifacts, "code.py", "phrase_grounding", "23b3b022-5ebf-4798-9373-20ef36429abf")
    """

    task_to_fn = {"phrase_grounding": "florence2_phrase_grounding"}

    if name not in artifacts:
        output_str = f"[Artifact {name} does not exist]"
        print(output_str)
        return output_str

    # Normalise once so the support check and the function lookup agree on
    # mixed-case task names.
    task_key = task.lower()
    if task_key not in task_to_fn:
        raise ValueError(f"Task {task} is not supported.")
    fn_name = task_to_fn[task_key]

    code = artifacts[name]
    call_start = re.compile(rf"\b{re.escape(fn_name)}\(")

    pieces = []
    cursor = 0
    for match in call_start.finditer(code):
        if match.start() < cursor:
            # This call sits inside the arguments of one we already rewrote.
            continue
        open_idx = match.end() - 1
        close_idx = _find_call_end(code, open_idx)
        if close_idx == -1:
            continue  # unbalanced parentheses: leave the text untouched
        args = code[open_idx + 1 : close_idx].strip()
        if args.endswith(","):
            # Drop a trailing comma so appending the id does not yield ",,".
            args = args[:-1].rstrip()
        if not args:
            continue
        pieces.append(code[cursor : match.start()])
        pieces.append(f'{fn_name}({args}, "{fine_tune_id}")')
        cursor = close_idx + 1
    pieces.append(code[cursor:])
    new_code = "".join(pieces)

    if new_code == code:
        output_str = f"[Fine tuning task {task} function {fn_name} not found in code]"
        print(output_str)
        return output_str

    artifacts[name] = new_code

    diff = get_diff(code, new_code)
    print(diff)
    return diff
|
527
|
+
|
528
|
+
|
400
529
|
META_TOOL_DOCSTRING = get_tool_documentation(
|
401
530
|
[
|
402
531
|
get_tool_descriptions,
|
@@ -406,5 +535,8 @@ META_TOOL_DOCSTRING = get_tool_documentation(
|
|
406
535
|
generate_vision_code,
|
407
536
|
edit_vision_code,
|
408
537
|
write_media_artifact,
|
538
|
+
florence2_fine_tuning,
|
539
|
+
use_florence2_fine_tuning,
|
540
|
+
list_artifacts,
|
409
541
|
]
|
410
542
|
)
|
@@ -1,6 +1,6 @@
|
|
1
|
+
import os
|
1
2
|
import inspect
|
2
3
|
import logging
|
3
|
-
import os
|
4
4
|
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
|
5
5
|
|
6
6
|
import pandas as pd
|
@@ -13,6 +13,7 @@ from urllib3.util.retry import Retry
|
|
13
13
|
from vision_agent.utils.exceptions import RemoteToolCallFailed
|
14
14
|
from vision_agent.utils.execute import Error, MimeType
|
15
15
|
from vision_agent.utils.type_defs import LandingaiAPIKey
|
16
|
+
from vision_agent.tools.tools_types import BoundingBoxes
|
16
17
|
|
17
18
|
_LOGGER = logging.getLogger(__name__)
|
18
19
|
_LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key)
|
@@ -34,61 +35,58 @@ def send_inference_request(
|
|
34
35
|
files: Optional[List[Tuple[Any, ...]]] = None,
|
35
36
|
v2: bool = False,
|
36
37
|
metadata_payload: Optional[Dict[str, Any]] = None,
|
37
|
-
) ->
|
38
|
+
) -> Any:
|
38
39
|
# TODO: runtime_tag and function_name should be metadata_payload and now included
|
39
40
|
# in the service payload
|
40
|
-
|
41
|
-
|
42
|
-
|
41
|
+
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
|
42
|
+
payload["runtime_tag"] = runtime_tag
|
43
|
+
|
44
|
+
url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}"
|
45
|
+
if "TOOL_ENDPOINT_URL" in os.environ:
|
46
|
+
url = os.environ["TOOL_ENDPOINT_URL"]
|
47
|
+
|
48
|
+
headers = {"apikey": _LND_API_KEY}
|
49
|
+
if "TOOL_ENDPOINT_AUTH" in os.environ:
|
50
|
+
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
|
51
|
+
headers.pop("apikey")
|
52
|
+
|
53
|
+
session = _create_requests_session(
|
54
|
+
url=url,
|
55
|
+
num_retry=3,
|
56
|
+
headers=headers,
|
57
|
+
)
|
43
58
|
|
44
|
-
|
45
|
-
|
46
|
-
|
59
|
+
function_name = "unknown"
|
60
|
+
if "function_name" in payload:
|
61
|
+
function_name = payload["function_name"]
|
62
|
+
elif metadata_payload is not None and "function_name" in metadata_payload:
|
63
|
+
function_name = metadata_payload["function_name"]
|
47
64
|
|
48
|
-
|
49
|
-
endpoint_url=url,
|
50
|
-
request=payload,
|
51
|
-
response={},
|
52
|
-
error=None,
|
53
|
-
)
|
54
|
-
headers = {"apikey": _LND_API_KEY}
|
55
|
-
if "TOOL_ENDPOINT_AUTH" in os.environ:
|
56
|
-
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
|
57
|
-
headers.pop("apikey")
|
58
|
-
|
59
|
-
session = _create_requests_session(
|
60
|
-
url=url,
|
61
|
-
num_retry=3,
|
62
|
-
headers=headers,
|
63
|
-
)
|
65
|
+
response = _call_post(url, payload, session, files, function_name)
|
64
66
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
finally:
|
89
|
-
trace = tool_call_trace.model_dump()
|
90
|
-
trace["type"] = "tool_call"
|
91
|
-
display({MimeType.APPLICATION_JSON: trace}, raw=True)
|
67
|
+
# TODO: consider making the response schema the same between below two sources
|
68
|
+
return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"]
|
69
|
+
|
70
|
+
|
71
|
+
def send_task_inference_request(
    payload: Dict[str, Any],
    task_name: str,
    files: Optional[List[Tuple[Any, ...]]] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Any:
    """POST `payload` to the v2 task endpoint and return the response's "data" field.

    Parameters:
        payload (Dict[str, Any]): The request body passed to `_call_post`.
        task_name (str): Appended to the v2 API base URL to form the endpoint.
        files (Optional[List[Tuple[Any, ...]]]): Optional files forwarded to
            `_call_post` together with the payload.
        metadata (Optional[Dict[str, Any]]): Optional metadata; only its
            "function_name" entry is read, falling back to "unknown".

    Returns:
        Any: The value stored under "data" in the response from `_call_post`.
    """
    endpoint = f"{_LND_API_URL_v2}/{task_name}"
    session = _create_requests_session(
        url=endpoint,
        num_retry=3,
        headers={"apikey": _LND_API_KEY},
    )

    # The function name is only forwarded so a failure can be attributed to a tool.
    function_name = (metadata or {}).get("function_name", "unknown")
    return _call_post(endpoint, payload, session, files, function_name)["data"]
|
92
90
|
|
93
91
|
|
94
92
|
def _create_requests_session(
|
@@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]:
|
|
195
193
|
data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}"
|
196
194
|
|
197
195
|
return data
|
196
|
+
|
197
|
+
|
198
|
+
def _call_post(
    url: str,
    payload: dict[str, Any],
    session: Session,
    files: Optional[List[Tuple[Any, ...]]] = None,
    function_name: str = "unknown",
) -> Any:
    """POST `payload` to `url`, record a tool-call trace, and return the decoded JSON.

    Parameters:
        url (str): The endpoint to post to.
        payload (dict[str, Any]): Sent as form data when `files` is given,
            otherwise as the JSON body.
        session (Session): The session used to issue the request.
        files (Optional[List[Tuple[Any, ...]]]): Optional files to upload.
        function_name (str): Name reported in `RemoteToolCallFailed` on failure.

    Returns:
        Any: The JSON-decoded response body.

    Raises:
        RemoteToolCallFailed: If the response status code is not 200.
    """
    # Create the trace before entering `try`: the `finally` clause reads it, and
    # if its construction failed inside the `try` the resulting unbound-local
    # error would mask the real exception.
    tool_call_trace = ToolCallTrace(
        endpoint_url=url,
        request=payload,
        response={},
        error=None,
    )
    try:
        if files is not None:
            response = session.post(url, data=payload, files=files)
        else:
            response = session.post(url, json=payload)

        if response.status_code != 200:
            tool_call_trace.error = Error(
                name="RemoteToolCallFailed",
                value=f"{response.status_code} - {response.text}",
                traceback_raw=[],
            )
            _LOGGER.error(f"Request failed: {response.status_code} {response.text}")
            raise RemoteToolCallFailed(
                function_name, response.status_code, response.text
            )

        result = response.json()
        tool_call_trace.response = result
        return result
    finally:
        # The trace is displayed whether the call succeeded or raised.
        trace = tool_call_trace.model_dump()
        trace["type"] = "tool_call"
        display({MimeType.APPLICATION_JSON: trace}, raw=True)
|
236
|
+
|
237
|
+
|
238
|
+
def filter_bboxes_by_threshold(
    bboxes: BoundingBoxes, threshold: float
) -> BoundingBoxes:
    """Return the boxes whose `score` is at least `threshold`, in their original order."""
    return [bbox for bbox in bboxes if bbox.score >= threshold]
|
@@ -13,7 +13,7 @@ import cv2
|
|
13
13
|
import numpy as np
|
14
14
|
import requests
|
15
15
|
from moviepy.editor import ImageSequenceClip
|
16
|
-
from PIL import Image, ImageDraw, ImageFont
|
16
|
+
from PIL import Image, ImageDraw, ImageFont, ImageEnhance
|
17
17
|
from pillow_heif import register_heif_opener # type: ignore
|
18
18
|
from pytube import YouTube # type: ignore
|
19
19
|
|
@@ -24,14 +24,15 @@ from vision_agent.tools.tool_utils import (
|
|
24
24
|
get_tools_df,
|
25
25
|
get_tools_info,
|
26
26
|
send_inference_request,
|
27
|
+
send_task_inference_request,
|
28
|
+
filter_bboxes_by_threshold,
|
27
29
|
)
|
28
30
|
from vision_agent.tools.tools_types import (
|
29
|
-
BboxInput,
|
30
|
-
BboxInputBase64,
|
31
31
|
FineTuning,
|
32
|
-
|
32
|
+
Florence2FtRequest,
|
33
33
|
JobStatus,
|
34
34
|
PromptTask,
|
35
|
+
ODResponseData,
|
35
36
|
)
|
36
37
|
from vision_agent.utils import extract_frames_from_video
|
37
38
|
from vision_agent.utils.exceptions import FineTuneModelIsNotReady
|
@@ -455,7 +456,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
|
|
455
456
|
"image": image_b64,
|
456
457
|
"function_name": "loca_zero_shot_counting",
|
457
458
|
}
|
458
|
-
resp_data = send_inference_request(data, "loca", v2=True)
|
459
|
+
resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
|
459
460
|
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
|
460
461
|
return resp_data
|
461
462
|
|
@@ -469,6 +470,8 @@ def loca_visual_prompt_counting(
|
|
469
470
|
|
470
471
|
Parameters:
|
471
472
|
image (np.ndarray): The image that contains lot of instances of a single object
|
473
|
+
visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
|
474
|
+
[xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.
|
472
475
|
|
473
476
|
Returns:
|
474
477
|
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
|
@@ -496,11 +499,109 @@ def loca_visual_prompt_counting(
|
|
496
499
|
"bbox": list(map(int, denormalize_bbox(bbox, image_size))),
|
497
500
|
"function_name": "loca_visual_prompt_counting",
|
498
501
|
}
|
499
|
-
resp_data = send_inference_request(data, "loca", v2=True)
|
502
|
+
resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True)
|
500
503
|
resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8)
|
501
504
|
return resp_data
|
502
505
|
|
503
506
|
|
507
|
+
def countgd_counting(
    prompt: str,
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_counting' is a tool that can precisely count multiple instances of an
    object given a text prompt. It returns a list of bounding boxes with normalized
    coordinates, label names and associated confidence scores.

    Parameters:
        prompt (str): The object that needs to be counted.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_counting("flower", image)
        [
            {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    image_bytes = numpy_to_bytes(image)
    # Comma-separated prompts are rewritten with " ." as the separator before sending.
    countgd_prompt = prompt.replace(", ", " .")
    detections = send_task_inference_request(
        {"prompts": [countgd_prompt], "model": "countgd"},
        "text-to-object-detection",
        files=[("image", image_bytes)],
        metadata={"function_name": "countgd_counting"},
    )
    # Only the first entry of the response is used (a single image is sent).
    first_frame = [ODResponseData(**bbox) for bbox in detections[0]]
    kept = filter_bboxes_by_threshold(first_frame, box_threshold)
    return [bbox.model_dump() for bbox in kept]
|
551
|
+
|
552
|
+
|
553
|
+
def countgd_example_based_counting(
    visual_prompts: List[List[float]],
    image: np.ndarray,
    box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
    """'countgd_example_based_counting' is a tool that can precisely count multiple
    instances of an object given few visual example prompts. It returns a list of bounding
    boxes with normalized coordinates, label names and associated confidence scores.

    Parameters:
        visual_prompts (List[List[float]]): Bounding boxes of the object in format
            [xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
        image (np.ndarray): The image that contains multiple instances of the object.
        box_threshold (float, optional): The threshold for detection. Defaults
            to 0.23.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
            bounding box of the detected objects with normalized coordinates between 0
            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
            top-left and xmax and ymax are the coordinates of the bottom-right of the
            bounding box.

    Example
    -------
        >>> countgd_example_based_counting(
            visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]],
            image=image
        )
        [
            {'score': 0.49, 'label': 'object', 'bounding_box': [0.1, 0.11, 0.35, 0.4]},
            {'score': 0.68, 'label': 'object', 'bounding_box': [0.2, 0.21, 0.45, 0.5]},
            {'score': 0.78, 'label': 'object', 'bounding_box': [0.3, 0.35, 0.48, 0.52]},
            {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58]},
        ]
    """
    image_bytes = numpy_to_bytes(image)
    # The prompts arrive normalized; they are denormalized against the image's
    # first two shape dimensions before being serialised into the payload.
    image_size = image.shape[:2]
    denormalized_prompts = [
        denormalize_bbox(bbox, image_size) for bbox in visual_prompts
    ]
    detections = send_task_inference_request(
        {"visual_prompts": json.dumps(denormalized_prompts), "model": "countgd"},
        "visual-prompts-to-object-detection",
        files=[("image", image_bytes)],
        metadata={"function_name": "countgd_example_based_counting"},
    )
    # Only the first entry of the response is used (a single image is sent).
    first_frame = [ODResponseData(**bbox) for bbox in detections[0]]
    kept = filter_bboxes_by_threshold(first_frame, box_threshold)
    return [bbox.model_dump() for bbox in kept]
|
603
|
+
|
604
|
+
|
504
605
|
def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
|
505
606
|
"""'florence2_roberta_vqa' is a tool that takes an image and analyzes
|
506
607
|
its contents, generates detailed captions and then tries to answer the given
|
@@ -646,7 +747,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
|
|
646
747
|
"tool": "closed_set_image_classification",
|
647
748
|
"function_name": "clip",
|
648
749
|
}
|
649
|
-
resp_data = send_inference_request(data, "tools")
|
750
|
+
resp_data: dict[str, Any] = send_inference_request(data, "tools")
|
650
751
|
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
|
651
752
|
return resp_data
|
652
753
|
|
@@ -674,7 +775,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
|
|
674
775
|
"tool": "image_classification",
|
675
776
|
"function_name": "vit_image_classification",
|
676
777
|
}
|
677
|
-
resp_data = send_inference_request(data, "tools")
|
778
|
+
resp_data: dict[str, Any] = send_inference_request(data, "tools")
|
678
779
|
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
|
679
780
|
return resp_data
|
680
781
|
|
@@ -701,7 +802,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
|
|
701
802
|
"image": image_b64,
|
702
803
|
"function_name": "vit_nsfw_classification",
|
703
804
|
}
|
704
|
-
resp_data = send_inference_request(
|
805
|
+
resp_data: dict[str, Any] = send_inference_request(
|
806
|
+
data, "nsfw-classification", v2=True
|
807
|
+
)
|
705
808
|
resp_data["score"] = round(resp_data["score"], 4)
|
706
809
|
return resp_data
|
707
810
|
|
@@ -762,7 +865,9 @@ def florence2_image_caption(image: np.ndarray, detail_caption: bool = True) -> s
|
|
762
865
|
return answer[task] # type: ignore
|
763
866
|
|
764
867
|
|
765
|
-
def florence2_phrase_grounding(
|
868
|
+
def florence2_phrase_grounding(
|
869
|
+
prompt: str, image: np.ndarray, fine_tune_id: Optional[str] = None
|
870
|
+
) -> List[Dict[str, Any]]:
|
766
871
|
"""'florence2_phrase_grounding' is a tool that can detect multiple
|
767
872
|
objects given a text prompt which can be object names or caption. You
|
768
873
|
can optionally separate the object names in the text with commas. It returns a list
|
@@ -772,6 +877,8 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
|
|
772
877
|
Parameters:
|
773
878
|
prompt (str): The prompt to ground to the image.
|
774
879
|
image (np.ndarray): The image used to detect objects.
|
880
|
+
fine_tune_id (Optional[str]): If you have a fine-tuned model, you can pass the
|
881
|
+
fine-tuned model ID here to use it.
|
775
882
|
|
776
883
|
Returns:
|
777
884
|
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
|
@@ -790,14 +897,33 @@ def florence2_phrase_grounding(prompt: str, image: np.ndarray) -> List[Dict[str,
|
|
790
897
|
"""
|
791
898
|
image_size = image.shape[:2]
|
792
899
|
image_b64 = convert_to_b64(image)
|
793
|
-
data = {
|
794
|
-
"image": image_b64,
|
795
|
-
"task": "<CAPTION_TO_PHRASE_GROUNDING>",
|
796
|
-
"prompt": prompt,
|
797
|
-
"function_name": "florence2_phrase_grounding",
|
798
|
-
}
|
799
900
|
|
800
|
-
|
901
|
+
if fine_tune_id is not None:
|
902
|
+
landing_api = LandingPublicAPI()
|
903
|
+
status = landing_api.check_fine_tuning_job(UUID(fine_tune_id))
|
904
|
+
if status is not JobStatus.SUCCEEDED:
|
905
|
+
raise FineTuneModelIsNotReady(
|
906
|
+
f"Fine-tuned model {fine_tune_id} is not ready yet"
|
907
|
+
)
|
908
|
+
|
909
|
+
data_obj = Florence2FtRequest(
|
910
|
+
image=image_b64,
|
911
|
+
task=PromptTask.PHRASE_GROUNDING,
|
912
|
+
tool="florencev2_fine_tuning",
|
913
|
+
prompt=prompt,
|
914
|
+
fine_tuning=FineTuning(job_id=UUID(fine_tune_id)),
|
915
|
+
)
|
916
|
+
data = data_obj.model_dump(by_alias=True)
|
917
|
+
detections = send_inference_request(data, "tools", v2=False)
|
918
|
+
else:
|
919
|
+
data = {
|
920
|
+
"image": image_b64,
|
921
|
+
"task": "<CAPTION_TO_PHRASE_GROUNDING>",
|
922
|
+
"prompt": prompt,
|
923
|
+
"function_name": "florence2_phrase_grounding",
|
924
|
+
}
|
925
|
+
detections = send_inference_request(data, "florence2", v2=True)
|
926
|
+
|
801
927
|
detections = detections["<CAPTION_TO_PHRASE_GROUNDING>"]
|
802
928
|
return_data = []
|
803
929
|
for i in range(len(detections["bboxes"])):
|
@@ -1559,117 +1685,72 @@ def overlay_heat_map(
|
|
1559
1685
|
return np.array(combined)
|
1560
1686
|
|
1561
1687
|
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1688
|
+
def overlay_counting_results(
|
1689
|
+
image: np.ndarray, instances: List[Dict[str, Any]]
|
1690
|
+
) -> np.ndarray:
|
1691
|
+
"""'overlay_counting_results' is a utility function that displays counting results on
|
1692
|
+
an image.
|
1567
1693
|
|
1568
1694
|
Parameters:
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1572
|
-
CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
|
1695
|
+
image (np.ndarray): The image to display the bounding boxes on.
|
1696
|
+
instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
|
1697
|
+
box information of each instance
|
1573
1698
|
|
1574
1699
|
Returns:
|
1575
|
-
|
1576
|
-
tuned model.
|
1700
|
+
np.ndarray: The image with the instance_id displayed
|
1577
1701
|
|
1578
1702
|
Example
|
1579
1703
|
-------
|
1580
|
-
>>>
|
1581
|
-
[{'
|
1582
|
-
{'image_path': 'filename.png', 'labels': ['screw'], 'bboxes': [[120, 0, 300, 170]]}],
|
1583
|
-
"OBJECT_DETECTION"
|
1704
|
+
>>> image_with_bboxes = overlay_counting_results(
|
1705
|
+
image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
|
1584
1706
|
)
|
1585
1707
|
"""
|
1586
|
-
|
1587
|
-
|
1588
|
-
fine_tuning_request = [
|
1589
|
-
BboxInputBase64(
|
1590
|
-
image=convert_to_b64(bbox_input.image_path),
|
1591
|
-
filename=bbox_input.image_path.split("/")[-1],
|
1592
|
-
labels=bbox_input.labels,
|
1593
|
-
bboxes=bbox_input.bboxes,
|
1594
|
-
)
|
1595
|
-
for bbox_input in bboxes_input
|
1596
|
-
]
|
1597
|
-
landing_api = LandingPublicAPI()
|
1598
|
-
return landing_api.launch_fine_tuning_job(
|
1599
|
-
"florencev2", task_input, fine_tuning_request
|
1600
|
-
)
|
1601
|
-
|
1708
|
+
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
|
1709
|
+
color = (158, 218, 229)
|
1602
1710
|
|
1603
|
-
|
1604
|
-
|
1605
|
-
|
1606
|
-
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
|
1711
|
+
width, height = pil_image.size
|
1712
|
+
fontsize = max(10, int(min(width, height) / 80))
|
1713
|
+
pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
|
1714
|
+
draw = ImageDraw.Draw(pil_image)
|
1715
|
+
font = ImageFont.truetype(
|
1716
|
+
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
|
1717
|
+
fontsize,
|
1718
|
+
)
|
1611
1719
|
|
1612
|
-
|
1613
|
-
|
1614
|
-
|
1615
|
-
model_id (UUID): The fine-tuned model id.
|
1616
|
-
task (PromptTask): The florencev2 fine-tuning task. The options are
|
1617
|
-
CAPTION, CAPTION_TO_PHRASE_GROUNDING and OBJECT_DETECTION.
|
1720
|
+
for i, elt in enumerate(instances):
|
1721
|
+
label = f"{i}"
|
1722
|
+
box = elt["bbox"]
|
1618
1723
|
|
1619
|
-
|
1620
|
-
|
1621
|
-
|
1622
|
-
|
1623
|
-
top-left and xmax and ymax are the coordinates of the bottom-right of the
|
1624
|
-
bounding box. The scores are always 1.0 and cannot be thresholded
|
1724
|
+
# denormalize the box if it is normalized
|
1725
|
+
box = denormalize_bbox(box, (height, width))
|
1726
|
+
x0, y0, x1, y1 = box
|
1727
|
+
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
|
1625
1728
|
|
1626
|
-
|
1627
|
-
|
1628
|
-
>>> florencev2_fine_tuned_object_detection(
|
1629
|
-
image,
|
1630
|
-
'person looking at a coyote',
|
1631
|
-
UUID("381cd5f9-5dc4-472d-9260-f3bb89d31f83")
|
1729
|
+
text_box = draw.textbbox(
|
1730
|
+
(cx, cy), text=label, font=font, align="center", anchor="mm"
|
1632
1731
|
)
|
1633
|
-
[
|
1634
|
-
{'score': 1.0, 'label': 'person', 'bbox': [0.1, 0.11, 0.35, 0.4]},
|
1635
|
-
{'score': 1.0, 'label': 'coyote', 'bbox': [0.34, 0.21, 0.85, 0.5},
|
1636
|
-
]
|
1637
|
-
"""
|
1638
|
-
# check if job succeeded first
|
1639
|
-
landing_api = LandingPublicAPI()
|
1640
|
-
status = landing_api.check_fine_tuning_job(model_id)
|
1641
|
-
if status is not JobStatus.SUCCEEDED:
|
1642
|
-
raise FineTuneModelIsNotReady()
|
1643
|
-
|
1644
|
-
task = PromptTask[task]
|
1645
|
-
if task is PromptTask.OBJECT_DETECTION:
|
1646
|
-
prompt = ""
|
1647
|
-
|
1648
|
-
data_obj = Florencev2FtRequest(
|
1649
|
-
image=convert_to_b64(image),
|
1650
|
-
task=task,
|
1651
|
-
tool="florencev2_fine_tuning",
|
1652
|
-
prompt=prompt,
|
1653
|
-
fine_tuning=FineTuning(job_id=model_id),
|
1654
|
-
)
|
1655
|
-
data = data_obj.model_dump(by_alias=True)
|
1656
|
-
metadata_payload = {"function_name": "florencev2_fine_tuned_object_detection"}
|
1657
|
-
detections = send_inference_request(
|
1658
|
-
data, "tools", v2=False, metadata_payload=metadata_payload
|
1659
|
-
)
|
1660
1732
|
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
1665
|
-
|
1666
|
-
|
1667
|
-
|
1668
|
-
|
1669
|
-
|
1670
|
-
|
1733
|
+
# Calculate the offset to center the text within the bounding box
|
1734
|
+
text_width = text_box[2] - text_box[0]
|
1735
|
+
text_height = text_box[3] - text_box[1]
|
1736
|
+
text_x0 = cx - text_width / 2
|
1737
|
+
text_y0 = cy - text_height / 2
|
1738
|
+
text_x1 = cx + text_width / 2
|
1739
|
+
text_y1 = cy + text_height / 2
|
1740
|
+
|
1741
|
+
# Draw the rectangle encapsulating the text
|
1742
|
+
draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
|
1743
|
+
|
1744
|
+
# Draw the text at the center of the bounding box
|
1745
|
+
draw.text(
|
1746
|
+
(text_x0, text_y0),
|
1747
|
+
label,
|
1748
|
+
fill="black",
|
1749
|
+
font=font,
|
1750
|
+
anchor="lt",
|
1671
1751
|
)
|
1672
|
-
|
1752
|
+
|
1753
|
+
return np.array(pil_image)
|
1673
1754
|
|
1674
1755
|
|
1675
1756
|
FUNCTION_TOOLS = [
|
@@ -1679,8 +1760,7 @@ FUNCTION_TOOLS = [
|
|
1679
1760
|
clip,
|
1680
1761
|
vit_image_classification,
|
1681
1762
|
vit_nsfw_classification,
|
1682
|
-
|
1683
|
-
loca_visual_prompt_counting,
|
1763
|
+
countgd_counting,
|
1684
1764
|
florence2_image_caption,
|
1685
1765
|
florence2_ocr,
|
1686
1766
|
florence2_sam2_image,
|
@@ -1703,6 +1783,7 @@ UTIL_TOOLS = [
|
|
1703
1783
|
overlay_bounding_boxes,
|
1704
1784
|
overlay_segmentation_masks,
|
1705
1785
|
overlay_heat_map,
|
1786
|
+
overlay_counting_results,
|
1706
1787
|
]
|
1707
1788
|
|
1708
1789
|
TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
|
@@ -1720,5 +1801,6 @@ UTILITIES_DOCSTRING = get_tool_documentation(
|
|
1720
1801
|
overlay_bounding_boxes,
|
1721
1802
|
overlay_segmentation_masks,
|
1722
1803
|
overlay_heat_map,
|
1804
|
+
overlay_counting_results,
|
1723
1805
|
]
|
1724
1806
|
)
|
@@ -1,8 +1,8 @@
|
|
1
1
|
from enum import Enum
|
2
|
-
from typing import List, Optional, Tuple
|
3
2
|
from uuid import UUID
|
3
|
+
from typing import List, Tuple, Optional, Union
|
4
4
|
|
5
|
-
from pydantic import BaseModel, ConfigDict, Field,
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo
|
6
6
|
|
7
7
|
|
8
8
|
class BboxInput(BaseModel):
|
@@ -19,16 +19,9 @@ class BboxInputBase64(BaseModel):
|
|
19
19
|
|
20
20
|
|
21
21
|
class PromptTask(str, Enum):
|
22
|
-
"""
|
23
|
-
Valid task prompts options for the Florencev2 model.
|
24
|
-
"""
|
22
|
+
"""Valid task prompts options for the Florence2 model."""
|
25
23
|
|
26
|
-
|
27
|
-
""""""
|
28
|
-
CAPTION_TO_PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
|
29
|
-
""""""
|
30
|
-
OBJECT_DETECTION = "<OD>"
|
31
|
-
""""""
|
24
|
+
PHRASE_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
|
32
25
|
|
33
26
|
|
34
27
|
class FineTuning(BaseModel):
|
@@ -41,7 +34,7 @@ class FineTuning(BaseModel):
|
|
41
34
|
return str(job_id)
|
42
35
|
|
43
36
|
|
44
|
-
class
|
37
|
+
class Florence2FtRequest(BaseModel):
|
45
38
|
model_config = ConfigDict(populate_by_name=True)
|
46
39
|
|
47
40
|
image: str
|
@@ -82,3 +75,16 @@ class JobStatus(str, Enum):
|
|
82
75
|
SUCCEEDED = "SUCCEEDED"
|
83
76
|
FAILED = "FAILED"
|
84
77
|
STOPPED = "STOPPED"
|
78
|
+
|
79
|
+
|
80
|
+
class ODResponseData(BaseModel):
|
81
|
+
label: str
|
82
|
+
score: float
|
83
|
+
bbox: Union[list[int], list[float]] = Field(alias="bounding_box")
|
84
|
+
|
85
|
+
model_config = ConfigDict(
|
86
|
+
populate_by_name=True,
|
87
|
+
)
|
88
|
+
|
89
|
+
|
90
|
+
BoundingBoxes = list[ODResponseData]
|
@@ -564,7 +564,13 @@ class LocalCodeInterpreter(CodeInterpreter):
|
|
564
564
|
) -> None:
|
565
565
|
super().__init__(timeout=timeout)
|
566
566
|
self.nb = nbformat.v4.new_notebook()
|
567
|
-
|
567
|
+
# Set the notebook execution path to the remote path
|
568
|
+
self.resources = {"metadata": {"path": str(self.remote_path)}}
|
569
|
+
self.nb_client = NotebookClient(
|
570
|
+
self.nb,
|
571
|
+
timeout=self.timeout,
|
572
|
+
resources=self.resources,
|
573
|
+
)
|
568
574
|
_LOGGER.info(
|
569
575
|
f"""Local code interpreter initialized
|
570
576
|
Python version: {sys.version}
|
@@ -606,7 +612,9 @@ Timeout: {self.timeout}"""
|
|
606
612
|
def restart_kernel(self) -> None:
|
607
613
|
self.close()
|
608
614
|
self.nb = nbformat.v4.new_notebook()
|
609
|
-
self.nb_client = NotebookClient(
|
615
|
+
self.nb_client = NotebookClient(
|
616
|
+
self.nb, timeout=self.timeout, resources=self.resources
|
617
|
+
)
|
610
618
|
sleep(1)
|
611
619
|
self._new_kernel()
|
612
620
|
|
@@ -636,7 +644,7 @@ Timeout: {self.timeout}"""
|
|
636
644
|
f.write(contents)
|
637
645
|
_LOGGER.info(f"File ({file_path}) is uploaded to: {str(self.remote_path)}")
|
638
646
|
|
639
|
-
return Path(self.remote_path / file_path)
|
647
|
+
return Path(self.remote_path / Path(file_path).name)
|
640
648
|
|
641
649
|
def download_file(
|
642
650
|
self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
|
@@ -672,7 +680,8 @@ class CodeInterpreterFactory:
|
|
672
680
|
|
673
681
|
@staticmethod
|
674
682
|
def new_instance(
|
675
|
-
code_sandbox_runtime: Optional[str] = None,
|
683
|
+
code_sandbox_runtime: Optional[str] = None,
|
684
|
+
remote_path: Optional[Union[str, Path]] = None,
|
676
685
|
) -> CodeInterpreter:
|
677
686
|
if not code_sandbox_runtime:
|
678
687
|
code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local")
|
@@ -181,7 +181,7 @@ def denormalize_bbox(
|
|
181
181
|
raise ValueError("Bounding box must be of length 4.")
|
182
182
|
|
183
183
|
arr = np.array(bbox)
|
184
|
-
if np.all((arr >= 0) & (arr <= 1)):
|
184
|
+
if np.all((arr[:2] >= 0) & (arr[:2] <= 1)):
|
185
185
|
x1, y1, x2, y2 = bbox
|
186
186
|
x1 = round(x1 * image_size[1])
|
187
187
|
y1 = round(y1 * image_size[0])
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|