bohr-agent-sdk 0.1.103__py3-none-any.whl → 0.1.105__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in the supported public registries, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: bohr-agent-sdk
- Version: 0.1.103
+ Version: 0.1.105
  Summary: SDK for scientific agents
  Home-page: https://github.com/dptech-corp/bohr-agent-sdk/
  Author: DP Technology
@@ -19,7 +19,7 @@ Classifier: Programming Language :: Python :: 3.13
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  Requires-Dist: click>=8.0.0
- Requires-Dist: mcp
+ Requires-Dist: mcp>=1.17.0
  Requires-Dist: python-dotenv>=1.0.0
  Requires-Dist: typing-extensions>=4.8.0
  Requires-Dist: dpdispatcher>=0.6.8
@@ -30,6 +30,7 @@ Requires-Dist: paho-mqtt>=2.1.0
  Requires-Dist: redis>=6.2.0
  Requires-Dist: twine>=6.1.0
  Requires-Dist: build>=1.2.2.post1
+ Requires-Dist: cloudpickle==2.2.0
  Requires-Dist: watchdog>=6.0.0
  Requires-Dist: fastapi>=0.116.0
  Requires-Dist: bohrium-open-sdk
@@ -60,21 +60,21 @@ dp/agent/device/device/__init__.py,sha256=w7_1S16S1vWUq0RGl0GFgjq2vFkc5oNvy8cQTn
  dp/agent/device/device/device.py,sha256=9ZRIJth-4qMO-i-u_b_cO3d6a4eTbTQjPaxFsV_zEkc,9643
  dp/agent/device/device/types.py,sha256=JuxB-hjf1CjjvfBxCLwRAXVFlYS-nPEdiJpBWLFVCzo,1924
  dp/agent/server/__init__.py,sha256=rckaYd8pbYyB4ENEhgjXKeGMXjdnrgcJpdM1gu5u1Wc,508
- dp/agent/server/calculation_mcp_server.py,sha256=eClRP7A-t5hMGyTh81KC3GAKjSPNJIylOrOKyzqwo8o,11459
+ dp/agent/server/calculation_mcp_server.py,sha256=iRFOdgTxySMGk7ZaSseNssEp-A7zT5cW1Ym2_MIKnG4,12602
  dp/agent/server/preprocessor.py,sha256=XUWu7QOwo_sIDMYS2b1OTrM33EXEVH_73vk-ju1Ok8A,1264
- dp/agent/server/utils.py,sha256=8jgYZEW4XBp86AF2Km6QkwHltBmrnS-soTpHov7ZEJw,4501
+ dp/agent/server/utils.py,sha256=ui3lca9EagcGqmYf8BKLsPARIzXxJ3jgN98yuEO3OSQ,1668
  dp/agent/server/executor/__init__.py,sha256=s95M5qKQk39Yi9qaVJZhk_nfj54quSf7EDghR3OCFUA,248
  dp/agent/server/executor/base_executor.py,sha256=EFJBsYVYAvuRbiLAbLOwLTw3h7ScjN025xnSP4uJHrQ,2052
- dp/agent/server/executor/dispatcher_executor.py,sha256=urpzmKH_tBOgblBdJEa3y8eEhXqUDrdcdWCnUdJpfZk,9420
+ dp/agent/server/executor/dispatcher_executor.py,sha256=wUJEmCrLzckwinOd8Caf5H9TN-YVfHLpfYgkVFd2OC0,10828
  dp/agent/server/executor/local_executor.py,sha256=wYCclNZFkLb3v7KpW1nCnupO8piBES-esYlDAuz86zk,6120
  dp/agent/server/storage/__init__.py,sha256=Sgsyp5hb0_hhIGugAPfQFzBHt_854rS_MuMuE3sn8Gs,389
  dp/agent/server/storage/base_storage.py,sha256=728-oNG6N8isV95gZVnyi4vTznJPJhSjxw9Gl5Y_y5o,2356
  dp/agent/server/storage/bohrium_storage.py,sha256=EsKX4dWWvZTn2TEhZv4zsvihfDK0mmPFecrln-Ytk40,10488
- dp/agent/server/storage/http_storage.py,sha256=w0lY95wQqKmjXTGFRhEG2hLu8GBFwgqG8ocm5lJ_fYc,1470
+ dp/agent/server/storage/http_storage.py,sha256=KiySq7g9-iJr12XQCKKyJLn8wJoDnSRpQAR5_qPJ1ZU,1471
  dp/agent/server/storage/local_storage.py,sha256=t1wfjByjXew9ws3PuUxWxmZQ0-Wt1a6t4wmj3fW62GI,1352
  dp/agent/server/storage/oss_storage.py,sha256=pgjmi7Gir3Y5wkMDCvU4fvSls15fXT7Ax-h9MYHFPK0,3359
- bohr_agent_sdk-0.1.103.dist-info/METADATA,sha256=hds24-rnglRxU193mo5VosDHUGte8Zk_Cmu12umIQF0,10184
- bohr_agent_sdk-0.1.103.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- bohr_agent_sdk-0.1.103.dist-info/entry_points.txt,sha256=5n5kneF5IbDQtoQ2WfF-QuBjDtsimJte9Rv9baSGgc0,86
- bohr_agent_sdk-0.1.103.dist-info/top_level.txt,sha256=87xLUDhu_1nQHoGLwlhJ6XlO7OsjILh6i1nX6ljFzDo,3
- bohr_agent_sdk-0.1.103.dist-info/RECORD,,
+ bohr_agent_sdk-0.1.105.dist-info/METADATA,sha256=tNAvWoQAxoKaWU1pfKhK8t7F-r_YyD0GZvVdX4F-OMo,10226
+ bohr_agent_sdk-0.1.105.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ bohr_agent_sdk-0.1.105.dist-info/entry_points.txt,sha256=5n5kneF5IbDQtoQ2WfF-QuBjDtsimJte9Rv9baSGgc0,86
+ bohr_agent_sdk-0.1.105.dist-info/top_level.txt,sha256=87xLUDhu_1nQHoGLwlhJ6XlO7OsjILh6i1nX6ljFzDo,3
+ bohr_agent_sdk-0.1.105.dist-info/RECORD,,
dp/agent/server/calculation_mcp_server.py CHANGED
@@ -6,17 +6,23 @@ from copy import deepcopy
  from datetime import datetime
  from pathlib import Path
  from urllib.parse import urlparse
- from typing import Literal, Optional, get_origin
+ from typing import Any, Literal, Optional, TypedDict
 
+ import mcp
  from mcp.server.fastmcp import FastMCP
- from mcp.server.fastmcp.server import Context
- from mcp.server.fastmcp.utilities.func_metadata import _get_typed_signature
+ from mcp.server.fastmcp.utilities.context_injection import (
+     find_context_parameter,
+ )
+ from mcp.server.fastmcp.utilities.func_metadata import (
+     _get_typed_signature,
+     func_metadata,
+ )
  from starlette.responses import JSONResponse
  from starlette.routing import Route
 
  from .executor import executor_dict
  from .storage import storage_dict
- from .utils import get_logger, get_metadata, convert_to_content, Tool
+ from .utils import get_logger, JobResult, Tool
  logger = get_logger(__name__)
 
 
@@ -133,7 +139,7 @@ def handle_output_artifacts(results, exec_id, storage):
 
 
  def get_job_results(job_id: str, executor: Optional[dict] = None,
-                     storage: Optional[dict] = None) -> dict:
+                     storage: Optional[dict] = None) -> Any:
      """
      Get results of a calculation job
      Args:
@@ -148,63 +154,80 @@ def get_job_results(job_id: str, executor: Optional[dict] = None,
      results, output_artifacts = handle_output_artifacts(
          results, exec_id, storage)
      logger.info("Job %s result is %s" % (job_id, results))
-     return convert_to_content(results, job_info={
+     return JobResult(result=results, job_info={
          "output_artifacts": output_artifacts,
      })
 
 
  class CalculationMCPServer:
-     def __init__(self, *args, preprocess_func=None, **kwargs):
+     def __init__(self, *args, preprocess_func=None, fastmcp_mode=False,
+                  **kwargs):
+         """
+         Args:
+             preprocess_func: The preprocess function for all tools
+             fastmcp_mode: compatible for fastmcp.FastMCP
+         """
          self.preprocess_func = preprocess_func
+         self.fastmcp_mode = fastmcp_mode
          self.mcp = FastMCP(*args, **kwargs)
 
-     def add_patched_tool(self, fn, new_fn, name, is_async=False, doc=None):
-         # patch the metadata of the tool
-         context_kwarg = None
-         sig = inspect.signature(fn)
-         for param_name, param in sig.parameters.items():
-             if get_origin(param.annotation) is not None:
-                 continue
-             if issubclass(param.annotation, Context):
-                 context_kwarg = param_name
-                 break
-         # combine parameters
-         parameters = []
-         for param in _get_typed_signature(fn).parameters.values():
-             if param.annotation is Path:
-                 parameters.append(inspect.Parameter(
-                     name=param.name, default=param.default,
-                     annotation=str, kind=param.kind))
-             elif param.annotation is Optional[Path]:
-                 parameters.append(inspect.Parameter(
-                     name=param.name, default=param.default,
-                     annotation=Optional[str], kind=param.kind))
-             else:
-                 parameters.append(param)
-         for param in _get_typed_signature(new_fn).parameters.values():
-             if param.name != "kwargs":
-                 parameters.append(param)
-         func_arg_metadata = get_metadata(
-             name,
-             parameters=parameters,
+     def add_patched_tool(self, fn, new_fn, name, is_async=False, doc=None,
+                          override_return_annotation=False):
+         """patch the metadata of the tool"""
+         context_kwarg = find_context_parameter(fn)
+ 
+         def _get_typed_signature_patched(call):
+             """patch parameters"""
+             typed_signature = _get_typed_signature(call)
+             new_typed_signature = _get_typed_signature(new_fn)
+             parameters = []
+             for param in typed_signature.parameters.values():
+                 if param.annotation is Path:
+                     parameters.append(inspect.Parameter(
+                         name=param.name, default=param.default,
+                         annotation=str, kind=param.kind))
+                 elif param.annotation is Optional[Path]:
+                     parameters.append(inspect.Parameter(
+                         name=param.name, default=param.default,
+                         annotation=Optional[str], kind=param.kind))
+                 else:
+                     parameters.append(param)
+             for param in new_typed_signature.parameters.values():
+                 if param.name != "kwargs":
+                     parameters.append(param)
+             return inspect.Signature(
+                 parameters,
+                 return_annotation=(new_typed_signature.return_annotation
+                                    if override_return_annotation
+                                    else typed_signature.return_annotation))
+ 
+         # Due to the frequent changes of MCP, we use a patching style here
+         mcp.server.fastmcp.utilities.func_metadata._get_typed_signature = \
+             _get_typed_signature_patched
+         func_arg_metadata = func_metadata(
+             fn,
              skip_names=[context_kwarg] if context_kwarg is not None else [],
-             globalns=getattr(fn, "__globals__", {})
+             structured_output=None,
          )
-         json_schema = func_arg_metadata.arg_model.model_json_schema()
+         mcp.server.fastmcp.utilities.func_metadata._get_typed_signature = \
+             _get_typed_signature
+         if self.fastmcp_mode and func_arg_metadata.wrap_output:
+             # Only simulate behavior of fastmcp for output_schema
+             func_arg_metadata.output_schema["x-fastmcp-wrap-result"] = True
+         parameters = func_arg_metadata.arg_model.model_json_schema(
+             by_alias=True)
          tool = Tool(
              fn=new_fn,
              name=name,
              description=doc or fn.__doc__,
-             parameters=json_schema,
+             parameters=parameters,
              fn_metadata=func_arg_metadata,
              is_async=is_async,
              context_kwarg=context_kwarg,
-             annotations=None,
          )
          self.mcp._tool_manager._tools[name] = tool
 
      def add_tool(self, fn, *args, **kwargs):
-         self.mcp.add_tool(fn, *args, **kwargs)
          tool = Tool.from_function(fn, *args, **kwargs)
          self.mcp._tool_manager._tools[tool.name] = tool
          return tool
@@ -215,7 +238,9 @@ class CalculationMCPServer:
 
          def decorator(fn: Callable) -> Callable:
              def submit_job(executor: Optional[dict] = None,
-                            storage: Optional[dict] = None, **kwargs):
+                            storage: Optional[dict] = None,
+                            **kwargs) -> TypedDict("results", {
+                                "job_id": str, "extra_info": Optional[dict]}):
                  trace_id = datetime.today().strftime('%Y-%m-%d-%H:%M:%S.%f')
                  logger.info("Job processing (Trace ID: %s)" % trace_id)
                  with set_directory(trace_id):
@@ -233,7 +258,7 @@
                          "job_id": job_id,
                          "extra_info": res.get("extra_info"),
                      }
-                     return convert_to_content(result, job_info={
+                     return JobResult(result=result, job_info={
                          "trace_id": trace_id,
                          "executor_type": executor_type,
                          "job_id": job_id,
@@ -263,7 +288,7 @@ class CalculationMCPServer:
                      logger.info("Job %s result is %s" % (job_id, results))
                      await context.log(level="info", message="Job %s result is"
                                        " %s" % (job_id, results))
-                     return convert_to_content(results, job_info={
+                     return JobResult(result=results, job_info={
                          "trace_id": trace_id,
                          "executor_type": executor_type,
                          "job_id": job_id,
@@ -273,8 +298,9 @@
                      })
 
              self.add_patched_tool(fn, run_job, fn.__name__, is_async=True)
-             self.add_patched_tool(fn, submit_job, "submit_" + fn.__name__,
-                                   doc="Submit a job")
+             self.add_patched_tool(
+                 fn, submit_job, "submit_" + fn.__name__, doc="Submit a job",
+                 override_return_annotation=True)
              self.add_tool(query_job_status)
              self.add_tool(terminate_job)
              self.add_tool(get_job_results)
@@ -284,9 +310,10 @@
      def run(self, **kwargs):
          if os.environ.get("DP_AGENT_RUNNING_MODE") in ["1", "true"]:
              return
-         async def health_check(request) :
+
+         async def health_check(request):
              return JSONResponse({"status": "ok"})
-
+
          self.mcp._custom_starlette_routes.append(
              Route(
                  "/health",
dp/agent/server/executor/dispatcher_executor.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 
  import jsonpickle
  from dpdispatcher import Machine, Resources, Task, Submission
+ from dpdispatcher.utils.job_status import JobStatus
 
  from .base_executor import BaseExecutor
  from .... import __path__
@@ -33,6 +34,50 @@ def get_source_code(fn):
      return "".join(pre_lines + source_lines) + "\n"
 
 
+ def get_func_def_script(fn):
+     script = ""
+     packages = []
+     fn_name = fn.__name__
+     module_name = fn.__module__
+     module = sys.modules[module_name]
+     if getattr(module, fn_name, None) is not fn:
+         # cannot import from module, maybe a local function
+         import cloudpickle
+         packages.extend(cloudpickle.__path__)
+         script += "import cloudpickle\n"
+         script += "%s = cloudpickle.loads(%s)\n" % \
+             (fn_name, cloudpickle.dumps(fn))
+     elif module_name in ["__main__", "__mp_main__"]:
+         if hasattr(module, "__file__"):
+             name = os.path.splitext(os.path.basename(module.__file__))[0]
+             if getattr(module, "__package__", None):
+                 package = module.__package__
+                 package_name = package.split('.')[0]
+                 module = importlib.import_module(package_name)
+                 packages.extend(module.__path__)
+                 script += "from %s.%s import %s\n" % (
+                     package, name, fn_name)
+             else:
+                 packages.append(module.__file__)
+                 script += "from %s import %s\n" % (name, fn_name)
+         else:
+             # cannot get file of __main__, maybe in the interactive mode
+             import cloudpickle
+             packages.extend(cloudpickle.__path__)
+             script += "import cloudpickle\n"
+             script += "%s = cloudpickle.loads(%s)\n" % \
+                 (fn_name, cloudpickle.dumps(fn))
+     else:
+         package_name = module_name.split('.')[0]
+         module = importlib.import_module(package_name)
+         if hasattr(module, "__path__"):
+             packages.extend(module.__path__)
+         elif hasattr(module, "__file__"):
+             packages.append(module.__file__)
+         script += "from %s import %s\n" % (module_name, fn_name)
+     return script, packages
+
+
  class DispatcherExecutor(BaseExecutor):
      def __init__(
              self,
@@ -86,34 +131,19 @@ class DispatcherExecutor(BaseExecutor):
      def submit(self, fn, kwargs):
          script = ""
          fn_name = fn.__name__
-         module_name = fn.__module__
-         import_func_line = None
-         if module_name in ["__main__", "__mp_main__"]:
-             module = sys.modules[module_name]
-             if hasattr(module, "__file__"):
-                 self.python_packages.append(module.__file__)
-                 name = os.path.splitext(os.path.basename(module.__file__))[0]
-                 import_func_line = "from %s import %s\n" % (name, fn_name)
-             else:
-                 script += get_source_code(fn)
-         else:
-             package_name = module_name.split('.')[0]
-             module = importlib.import_module(package_name)
-             if hasattr(module, "__path__"):
-                 self.python_packages.extend(module.__path__)
-             elif hasattr(module, "__file__"):
-                 self.python_packages.append(module.__file__)
-             import_func_line = "from %s import %s\n" % (module_name, fn_name)
+         func_def_script, packages = get_func_def_script(fn)
+         self.python_packages.extend(packages)
 
          script += "import asyncio, jsonpickle, os\n"
          script += "from pathlib import Path\n\n"
          script += "if __name__ == \"__main__\":\n"
          script += "    cwd = os.getcwd()\n"
-         script += "    kwargs = jsonpickle.loads(r'''%s''')\n" % \
-             jsonpickle.dumps(kwargs)
+         script += "    kwargs = jsonpickle.loads(%s)\n" % repr(
+             jsonpickle.dumps(kwargs))
          script += "    try:\n"
-         if import_func_line is not None:
-             script += "        " + import_func_line
+         for line in func_def_script.splitlines():
+             if line:
+                 script += "        " + line + "\n"
          if inspect.iscoroutinefunction(fn):
              script += "        results = asyncio.run(%s(**kwargs))\n" % fn_name
          else:
@@ -198,17 +228,23 @@
          submission = Submission.deserialize(
              submission_dict=json.loads(content))
          submission.update_submission_state()
-         if not submission.check_all_finished():
+         if not submission.check_all_finished() and not any(
+                 job.job_state in [JobStatus.terminated, JobStatus.unknown,
+                                   JobStatus.unsubmitted]
+                 for job in submission.belonging_jobs):
              return "Running"
          try:
              submission.run_submission(exit_on_submit=True)
          except Exception as e:
              logger.error(e)
              return "Failed"
-         if os.path.isfile("results.txt"):
-             return "Succeeded"
+         if submission.check_all_finished():
+             if os.path.isfile("results.txt"):
+                 return "Succeeded"
+             else:
+                 return "Failed"
          else:
-             return "Failed"
+             return "Running"
 
      def terminate(self, job_id):
          machine = Machine.load_from_dict(self.machine)
dp/agent/server/storage/http_storage.py CHANGED
@@ -51,4 +51,4 @@ class HTTPStorage(BaseStorage):
 
 
  class HTTPSStorage(HTTPStorage):
-     scheme = "https"
\ No newline at end of file
+     scheme = "https"
dp/agent/server/utils.py CHANGED
@@ -1,26 +1,10 @@
- import inspect
  import logging
  import traceback
- from collections.abc import Sequence
- from typing import Annotated, Any, List, Optional
+ from typing import Any
 
- import jsonpickle
  import mcp
- from mcp.server.fastmcp.exceptions import InvalidSignature
- from mcp.server.fastmcp.utilities.func_metadata import (
-     ArgModelBase,
-     _get_typed_annotation,
-     FuncMetadata,
- )
- from mcp.server.fastmcp.utilities.types import Image
- from mcp.types import (
-     EmbeddedResource,
-     ImageContent,
-     TextContent,
- )
- from pydantic import Field, WithJsonSchema, create_model
- from pydantic.fields import FieldInfo
- from pydantic_core import PydanticUndefined
+ from mcp.types import TextContent
+ from pydantic import BaseModel
 
 
  def get_logger(name, level="INFO",
@@ -34,110 +18,34 @@ def get_logger(name, level="INFO",
      return logger
 
 
- def get_metadata(
-     func_name: str,
-     parameters: List[inspect.Parameter],
-     skip_names: Sequence[str] = (),
-     globalns: dict = {},
- ) -> FuncMetadata:
-     dynamic_pydantic_model_params: dict[str, Any] = {}
-     for param in parameters:
-         if param.name.startswith("_"):
-             raise InvalidSignature(
-                 f"Parameter {param.name} of {func_name} cannot start with '_'"
-             )
-         if param.name in skip_names:
-             continue
-         annotation = param.annotation
-
-         # `x: None` / `x: None = None`
-         if annotation is None:
-             annotation = Annotated[
-                 None,
-                 Field(
-                     default=param.default
-                     if param.default is not inspect.Parameter.empty
-                     else PydanticUndefined
-                 ),
-             ]
-
-         # Untyped field
-         if annotation is inspect.Parameter.empty:
-             annotation = Annotated[
-                 Any,
-                 Field(),
-                 # 🤷
-                 WithJsonSchema({"title": param.name, "type": "string"}),
-             ]
-
-         field_info = FieldInfo.from_annotated_attribute(
-             _get_typed_annotation(annotation, globalns),
-             param.default
-             if param.default is not inspect.Parameter.empty
-             else PydanticUndefined,
-         )
-         dynamic_pydantic_model_params[param.name] = (
-             field_info.annotation, field_info)
-         continue
-
-     arguments_model = create_model(
-         f"{func_name}Arguments",
-         **dynamic_pydantic_model_params,
-         __base__=ArgModelBase,
-     )
-     resp = FuncMetadata(arg_model=arguments_model)
-     return resp
-
-
- def convert_to_content(
-     result: Any,
-     job_info: Optional[dict] = None,
- ) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
-     """Convert a result to a sequence of content objects."""
-     other_contents = []
-     if isinstance(result, Image):
-         other_contents.append(result.to_image_content())
-         result = None
-
-     if isinstance(result, TextContent | ImageContent | EmbeddedResource):
-         other_contents.append(result)
-         result = None
-
-     if isinstance(result, list | tuple):
-         for item in result.copy():
-             if isinstance(item, Image):
-                 other_contents.append(item.to_image_content())
-                 result.remove(item)
-             elif isinstance(
-                     result, TextContent | ImageContent | EmbeddedResource):
-                 other_contents.append(item)
-                 result.remove(item)
-
-     if isinstance(result, dict):
-         for key, value in list(result.items()):
-             if isinstance(value, Image):
-                 other_contents.append(value.to_image_content())
-                 del result[key]
-             elif isinstance(
-                     value, TextContent | ImageContent | EmbeddedResource):
-                 other_contents.append(value)
-                 del result[key]
-
-     if not isinstance(result, str):
-         result = jsonpickle.dumps(result)
-
-     return [TextContent(type="text", text=result, job_info=job_info)] \
-         + other_contents
+ class JobResult(BaseModel):
+     result: Any
+     job_info: dict
 
 
  class Tool(mcp.server.fastmcp.tools.Tool):
      """
      Workaround MCP server cannot print traceback
-     Remove this if MCP has proper support
+     Add job info to first unstructured content
      """
      async def run(self, *args, **kwargs):
          try:
-             return await super().run(*args, **kwargs)
+             kwargs["convert_result"] = False
+             result = await super().run(*args, **kwargs)
+             if isinstance(result, JobResult):
+                 job_info = result.job_info
+                 result = self.fn_metadata.convert_result(result.result)
+                 if isinstance(result, tuple) and len(result) == 2:
+                     unstructured_content, _ = result
+                 else:
+                     unstructured_content = result
+                 if len(unstructured_content) == 0:
+                     unstructured_content.append(
+                         TextContent(type="text", text="null"))
+                 unstructured_content[0].job_info = job_info
+             else:
+                 result = self.fn_metadata.convert_result(result)
+             return result
          except Exception as e:
              traceback.print_exc()
              raise e