bohr-agent-sdk 0.1.111__py3-none-any.whl → 0.1.113__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: bohr-agent-sdk
3
- Version: 0.1.111
3
+ Version: 0.1.113
4
4
  Summary: SDK for scientific agents
5
5
  Home-page: https://github.com/dptech-corp/bohr-agent-sdk/
6
6
  Author: DP Technology
@@ -62,9 +62,9 @@ dp/agent/device/device/__init__.py,sha256=w7_1S16S1vWUq0RGl0GFgjq2vFkc5oNvy8cQTn
62
62
  dp/agent/device/device/device.py,sha256=9ZRIJth-4qMO-i-u_b_cO3d6a4eTbTQjPaxFsV_zEkc,9643
63
63
  dp/agent/device/device/types.py,sha256=JuxB-hjf1CjjvfBxCLwRAXVFlYS-nPEdiJpBWLFVCzo,1924
64
64
  dp/agent/server/__init__.py,sha256=rckaYd8pbYyB4ENEhgjXKeGMXjdnrgcJpdM1gu5u1Wc,508
65
- dp/agent/server/calculation_mcp_server.py,sha256=vvsf58aKbBtH0AqrG5_qhGqg5g2nEhmXgabDzZKpa6o,18534
65
+ dp/agent/server/calculation_mcp_server.py,sha256=a0hKNVz-WoUbL8y9GhDx1hO830frvmdtvXTBf7V40lI,18478
66
66
  dp/agent/server/preprocessor.py,sha256=XUWu7QOwo_sIDMYS2b1OTrM33EXEVH_73vk-ju1Ok8A,1264
67
- dp/agent/server/utils.py,sha256=ui3lca9EagcGqmYf8BKLsPARIzXxJ3jgN98yuEO3OSQ,1668
67
+ dp/agent/server/utils.py,sha256=cIKaAg8UaP5yMwvIVTgUVBjy-B3S16bEdnucUf4UDIM,2055
68
68
  dp/agent/server/executor/__init__.py,sha256=s95M5qKQk39Yi9qaVJZhk_nfj54quSf7EDghR3OCFUA,248
69
69
  dp/agent/server/executor/base_executor.py,sha256=nR2jI-wFvKoOk8QaK11pnSAkHj2MsE6uyzPWDx-vgJA,3018
70
70
  dp/agent/server/executor/dispatcher_executor.py,sha256=CZRxbVkLaDvStXhNaMKrKcx2Z0tPPVzIxkU1ufqWgYc,12081
@@ -75,8 +75,8 @@ dp/agent/server/storage/bohrium_storage.py,sha256=EsKX4dWWvZTn2TEhZv4zsvihfDK0mm
75
75
  dp/agent/server/storage/http_storage.py,sha256=KiySq7g9-iJr12XQCKKyJLn8wJoDnSRpQAR5_qPJ1ZU,1471
76
76
  dp/agent/server/storage/local_storage.py,sha256=t1wfjByjXew9ws3PuUxWxmZQ0-Wt1a6t4wmj3fW62GI,1352
77
77
  dp/agent/server/storage/oss_storage.py,sha256=pgjmi7Gir3Y5wkMDCvU4fvSls15fXT7Ax-h9MYHFPK0,3359
78
- bohr_agent_sdk-0.1.111.dist-info/METADATA,sha256=savZekEjjj6ToO6eqTYjuaEqgq7jJF-KH4Xzf_XSukM,11070
79
- bohr_agent_sdk-0.1.111.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
80
- bohr_agent_sdk-0.1.111.dist-info/entry_points.txt,sha256=5n5kneF5IbDQtoQ2WfF-QuBjDtsimJte9Rv9baSGgc0,86
81
- bohr_agent_sdk-0.1.111.dist-info/top_level.txt,sha256=87xLUDhu_1nQHoGLwlhJ6XlO7OsjILh6i1nX6ljFzDo,3
82
- bohr_agent_sdk-0.1.111.dist-info/RECORD,,
78
+ bohr_agent_sdk-0.1.113.dist-info/METADATA,sha256=66soACYq7kToAXVjiDX4A_jRkddvoSOLah9QPFrUBxs,11070
79
+ bohr_agent_sdk-0.1.113.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
80
+ bohr_agent_sdk-0.1.113.dist-info/entry_points.txt,sha256=5n5kneF5IbDQtoQ2WfF-QuBjDtsimJte9Rv9baSGgc0,86
81
+ bohr_agent_sdk-0.1.113.dist-info/top_level.txt,sha256=87xLUDhu_1nQHoGLwlhJ6XlO7OsjILh6i1nX6ljFzDo,3
82
+ bohr_agent_sdk-0.1.113.dist-info/RECORD,,
@@ -7,17 +7,18 @@ from copy import deepcopy
7
7
  from datetime import datetime
8
8
  from pathlib import Path
9
9
  from urllib.parse import urlparse
10
- from typing import Any, Literal, Optional, TypedDict, List, Dict
10
+ from typing import Annotated, Literal, Optional, List, Dict
11
11
 
12
- import mcp
13
12
  from mcp.server.fastmcp import FastMCP
14
13
  from mcp.server.fastmcp.utilities.context_injection import (
15
14
  find_context_parameter,
16
15
  )
17
16
  from mcp.server.fastmcp.utilities.func_metadata import (
18
- _get_typed_signature,
17
+ ArgModelBase,
19
18
  func_metadata,
20
19
  )
20
+ from mcp.server.sse import SseServerTransport
21
+ from pydantic import BaseModel, Field, create_model
21
22
  from starlette.responses import JSONResponse
22
23
  from starlette.routing import Route
23
24
 
@@ -65,18 +66,9 @@ def set_directory(workdir: str):
65
66
  os.chdir(cwd)
66
67
 
67
68
 
68
- def load_executor(executor):
69
- if not executor and os.path.exists("executor.json"):
70
- with open("executor.json", "r") as f:
71
- executor = json.load(f)
72
- return executor
73
-
74
-
75
- def load_storage(storage):
76
- if not storage and os.path.exists("storage.json"):
77
- with open("storage.json", "r") as f:
78
- storage = json.load(f)
79
- return storage
69
+ def load_job_info():
70
+ with open("job.json", "r") as f:
71
+ return json.load(f)
80
72
 
81
73
 
82
74
  def query_job_status(job_id: str, executor: Optional[dict] = None
@@ -90,7 +82,7 @@ def query_job_status(job_id: str, executor: Optional[dict] = None
90
82
  """
91
83
  trace_id, exec_id = job_id.split("/")
92
84
  with set_directory(trace_id):
93
- executor = load_executor(executor)
85
+ executor = load_job_info()["executor"] or executor
94
86
  _, executor = init_executor(executor)
95
87
  status = executor.query_status(exec_id)
96
88
  logger.info("Job %s status is %s" % (job_id, status))
@@ -105,7 +97,7 @@ def terminate_job(job_id: str, executor: Optional[dict] = None):
105
97
  """
106
98
  trace_id, exec_id = job_id.split("/")
107
99
  with set_directory(trace_id):
108
- executor = load_executor(executor)
100
+ executor = load_job_info()["executor"] or executor
109
101
  _, executor = init_executor(executor)
110
102
  executor.terminate(exec_id)
111
103
  logger.info("Job %s is terminated" % job_id)
@@ -215,7 +207,8 @@ def handle_output_artifacts(results, exec_id, storage):
215
207
  "storage_type": storage_type,
216
208
  "uri": uri,
217
209
  }
218
- elif isinstance(results[name], list) and all(isinstance(item, Path) for item in results[name]):
210
+ elif isinstance(results[name], list) and all(
211
+ isinstance(item, Path) for item in results[name]):
219
212
  new_uris = []
220
213
  for item in results[name]:
221
214
  key = storage.upload("%s/outputs/%s" % (exec_id, name),
@@ -235,7 +228,7 @@ def handle_output_artifacts(results, exec_id, storage):
235
228
  # MCP does not regard Any as serializable in Python 3.12
236
229
  # use Optional[Any] to work around
237
230
  def get_job_results(job_id: str, executor: Optional[dict] = None,
238
- storage: Optional[dict] = None) -> Optional[Any]:
231
+ storage: Optional[dict] = None):
239
232
  """
240
233
  Get results of a calculation job
241
234
  Args:
@@ -245,8 +238,9 @@ def get_job_results(job_id: str, executor: Optional[dict] = None,
245
238
  """
246
239
  trace_id, exec_id = job_id.split("/")
247
240
  with set_directory(trace_id):
248
- executor = load_executor(executor)
249
- storage = load_storage(storage)
241
+ job_info = load_job_info()
242
+ executor = job_info["executor"] or executor
243
+ storage = job_info["storage"] or storage
250
244
  _, executor = init_executor(executor)
251
245
  results = executor.get_results(exec_id)
252
246
  results, output_artifacts = handle_output_artifacts(
@@ -254,12 +248,55 @@ def get_job_results(job_id: str, executor: Optional[dict] = None,
254
248
  logger.info("Job %s result is %s" % (job_id, results))
255
249
  return JobResult(result=results, job_info={
256
250
  "output_artifacts": output_artifacts,
257
- })
251
+ }, tool_name=job_info["tool_name"])
252
+
253
+
254
+ annotation_map = {
255
+ Path: str,
256
+ Optional[Path]: Optional[str],
257
+ List[Path]: List[str],
258
+ Optional[List[Path]]: Optional[List[str]],
259
+ Dict[str, Path]: Dict[str, str],
260
+ Optional[Dict[str, Path]]: Optional[Dict[str, str]],
261
+ Dict[str, List[Path]]: Dict[str, List[str]],
262
+ Optional[Dict[str, List[Path]]]: Optional[Dict[str, List[str]]],
263
+ }
264
+
265
+
266
+ class SubmitResult(BaseModel):
267
+ job_id: str
268
+ extra_info: dict | None = None
269
+
270
+
271
+ def patch_mcp_close_connection():
272
+ _mock_orig_handle_post_message = SseServerTransport.handle_post_message
273
+
274
+ async def _mock_handle_post_message_with_close(self, scope, receive, send):
275
+ async def _send(message):
276
+ if message.get("type") == "http.response.start":
277
+ headers = list(message.get("headers", []))
278
+ headers = [
279
+ (name, value)
280
+ for name, value in headers
281
+ if name.lower() != b"connection"
282
+ ]
283
+ headers.append((b"connection", b"close"))
284
+ message["headers"] = headers
285
+ elif message.get("type") == "http.response.body":
286
+ message["more_body"] = False
287
+ await send(message)
288
+ await _mock_orig_handle_post_message(self, scope, receive, _send)
289
+
290
+ if not getattr(SseServerTransport.handle_post_message,
291
+ "__patched_close__", False):
292
+ SseServerTransport.handle_post_message = \
293
+ _mock_handle_post_message_with_close
294
+ SseServerTransport.handle_post_message.__patched_close__ = True
258
295
 
259
296
 
260
297
  class CalculationMCPServer:
261
298
  def __init__(self, *args, preprocess_func=None, fastmcp_mode=False,
262
- **kwargs):
299
+ patch_close_connection=False, **kwargs):
263
300
  """
264
301
  Args:
265
302
  preprocess_func: The preprocess function for all tools
@@ -267,72 +304,50 @@ class CalculationMCPServer:
267
304
  """
268
305
  self.preprocess_func = preprocess_func
269
306
  self.fastmcp_mode = fastmcp_mode
307
+ if patch_close_connection:
308
+ patch_mcp_close_connection()
270
309
  self.mcp = FastMCP(*args, **kwargs)
310
+ self.fn_metadata_map = {}
271
311
 
272
312
  def add_patched_tool(self, fn, new_fn, name, is_async=False, doc=None,
273
313
  override_return_annotation=False):
274
314
  """patch the metadata of the tool"""
275
315
  context_kwarg = find_context_parameter(fn)
276
-
277
- def _get_typed_signature_patched(call):
278
- """patch parameters"""
279
- typed_signature = _get_typed_signature(call)
280
- new_typed_signature = _get_typed_signature(new_fn)
281
- parameters = []
282
- for param in typed_signature.parameters.values():
283
- if param.annotation is Path:
284
- parameters.append(inspect.Parameter(
285
- name=param.name, default=param.default,
286
- annotation=str, kind=param.kind))
287
- elif param.annotation is Optional[Path]:
288
- parameters.append(inspect.Parameter(
289
- name=param.name, default=param.default,
290
- annotation=Optional[str], kind=param.kind))
291
- elif param.annotation is List[Path]:
292
- parameters.append(inspect.Parameter(
293
- name=param.name, default=param.default,
294
- annotation=List[str], kind=param.kind))
295
- elif param.annotation is Optional[List[Path]]:
296
- parameters.append(inspect.Parameter(
297
- name=param.name, default=param.default,
298
- annotation=Optional[List[str]], kind=param.kind))
299
- elif param.annotation is Dict[str, Path]:
300
- parameters.append(inspect.Parameter(
301
- name=param.name, default=param.default,
302
- annotation=Dict[str, str], kind=param.kind))
303
- elif param.annotation is Optional[Dict[str, Path]]:
304
- parameters.append(inspect.Parameter(
305
- name=param.name, default=param.default,
306
- annotation=Optional[Dict[str, str]], kind=param.kind))
307
- elif param.annotation is Dict[str, List[Path]]:
308
- parameters.append(inspect.Parameter(
309
- name=param.name, default=param.default,
310
- annotation=Dict[str, List[str]], kind=param.kind))
311
- elif param.annotation is Optional[Dict[str, List[Path]]]:
312
- parameters.append(inspect.Parameter(
313
- name=param.name, default=param.default,
314
- annotation=Optional[Dict[str, List[str]]], kind=param.kind))
315
- else:
316
- parameters.append(param)
317
- for param in new_typed_signature.parameters.values():
318
- if param.name != "kwargs":
319
- parameters.append(param)
320
- return inspect.Signature(
321
- parameters,
322
- return_annotation=(new_typed_signature.return_annotation
323
- if override_return_annotation
324
- else typed_signature.return_annotation))
325
-
326
- # Due to the frequent changes of MCP, we use a patching style here
327
- mcp.server.fastmcp.utilities.func_metadata._get_typed_signature = \
328
- _get_typed_signature_patched
329
316
  func_arg_metadata = func_metadata(
330
317
  fn,
331
318
  skip_names=[context_kwarg] if context_kwarg is not None else [],
332
- structured_output=None,
333
319
  )
334
- mcp.server.fastmcp.utilities.func_metadata._get_typed_signature = \
335
- _get_typed_signature
320
+ self.fn_metadata_map[name] = func_arg_metadata
321
+ model_params = {}
322
+ params = inspect.signature(fn, eval_str=True).parameters
323
+ for n, annotation in \
324
+ func_arg_metadata.arg_model.__annotations__.items():
325
+ param = params[n]
326
+ if param.annotation in annotation_map:
327
+ model_params[n] = Annotated[
328
+ (annotation_map[param.annotation], Field())]
329
+ else:
330
+ model_params[n] = annotation
331
+ if param.default is not inspect.Parameter.empty:
332
+ model_params[n] = (model_params[n], param.default)
333
+ for n, param in inspect.signature(new_fn).parameters.items():
334
+ if n == "kwargs":
335
+ continue
336
+ model_params[n] = Annotated[(param.annotation, Field())]
337
+ if param.default is not inspect.Parameter.empty:
338
+ model_params[n] = (model_params[n], param.default)
339
+
340
+ func_arg_metadata.arg_model = create_model(
341
+ f"{fn.__name__}Arguments",
342
+ __base__=ArgModelBase,
343
+ **model_params,
344
+ )
345
+ if override_return_annotation:
346
+ new_func_arg_metadata = func_metadata(new_fn)
347
+ func_arg_metadata.output_model = new_func_arg_metadata.output_model
348
+ func_arg_metadata.output_schema = \
349
+ new_func_arg_metadata.output_schema
350
+ func_arg_metadata.wrap_output = new_func_arg_metadata.wrap_output
336
351
  if self.fastmcp_mode and func_arg_metadata.wrap_output:
337
352
  # Only simulate behavior of fastmcp for output_schema
338
353
  func_arg_metadata.output_schema["x-fastmcp-wrap-result"] = True
@@ -341,16 +356,18 @@ class CalculationMCPServer:
341
356
  tool = Tool(
342
357
  fn=new_fn,
343
358
  name=name,
344
- description=doc or fn.__doc__,
359
+ description=doc or fn.__doc__ or "",
345
360
  parameters=parameters,
346
361
  fn_metadata=func_arg_metadata,
347
362
  is_async=is_async,
348
363
  context_kwarg=context_kwarg,
364
+ fn_metadata_map=self.fn_metadata_map,
349
365
  )
350
366
  self.mcp._tool_manager._tools[name] = tool
351
367
 
352
368
  def add_tool(self, fn, *args, **kwargs):
353
- tool = Tool.from_function(fn, *args, **kwargs)
369
+ tool = Tool.from_function(
370
+ fn, *args, fn_metadata_map=self.fn_metadata_map, **kwargs)
354
371
  self.mcp._tool_manager._tools[tool.name] = tool
355
372
  return tool
356
373
 
@@ -361,20 +378,20 @@ class CalculationMCPServer:
361
378
  def decorator(fn: Callable) -> Callable:
362
379
  def submit_job(executor: Optional[dict] = None,
363
380
  storage: Optional[dict] = None,
364
- **kwargs) -> TypedDict("results", {
365
- "job_id": str, "extra_info": Optional[dict]}):
381
+ **kwargs) -> SubmitResult:
366
382
  trace_id = datetime.today().strftime('%Y-%m-%d-%H:%M:%S.%f')
367
383
  logger.info("Job processing (Trace ID: %s)" % trace_id)
368
384
  with set_directory(trace_id):
369
385
  if preprocess_func is not None:
370
386
  executor, storage, kwargs = preprocess_func(
371
387
  executor, storage, kwargs)
372
- if executor:
373
- with open("executor.json", "w") as f:
374
- json.dump(executor, f, indent=4)
375
- if storage:
376
- with open("storage.json", "w") as f:
377
- json.dump(storage, f, indent=4)
388
+ job = {
389
+ "tool_name": fn.__name__,
390
+ "executor": executor,
391
+ "storage": storage,
392
+ }
393
+ with open("job.json", "w") as f:
394
+ json.dump(job, f, indent=4)
378
395
  kwargs, input_artifacts = handle_input_artifacts(
379
396
  fn, kwargs, storage)
380
397
  executor_type, executor = init_executor(executor)
@@ -382,10 +399,10 @@ class CalculationMCPServer:
382
399
  exec_id = res["job_id"]
383
400
  job_id = "%s/%s" % (trace_id, exec_id)
384
401
  logger.info("Job submitted (ID: %s)" % job_id)
385
- result = {
386
- "job_id": job_id,
387
- "extra_info": res.get("extra_info"),
388
- }
402
+ result = SubmitResult(
403
+ job_id=job_id,
404
+ extra_info=res.get("extra_info"),
405
+ )
389
406
  return JobResult(result=result, job_info={
390
407
  "trace_id": trace_id,
391
408
  "executor_type": executor_type,
dp/agent/server/utils.py CHANGED
@@ -21,6 +21,7 @@ def get_logger(name, level="INFO",
21
21
  class JobResult(BaseModel):
22
22
  result: Any
23
23
  job_info: dict
24
+ tool_name: str | None = None
24
25
 
25
26
 
26
27
  class Tool(mcp.server.fastmcp.tools.Tool):
@@ -28,13 +29,23 @@ class Tool(mcp.server.fastmcp.tools.Tool):
28
29
  Workaround MCP server cannot print traceback
29
30
  Add job info to first unstructured content
30
31
  """
32
+ fn_metadata_map: dict | None = None
33
+
34
+ @classmethod
35
+ def from_function(cls, *args, fn_metadata_map=None, **kwargs):
36
+ tool = super().from_function(*args, **kwargs)
37
+ tool.fn_metadata_map = fn_metadata_map
38
+ return tool
39
+
31
40
  async def run(self, *args, **kwargs):
32
41
  try:
33
42
  kwargs["convert_result"] = False
34
43
  result = await super().run(*args, **kwargs)
35
44
  if isinstance(result, JobResult):
36
45
  job_info = result.job_info
37
- result = self.fn_metadata.convert_result(result.result)
46
+ fn_metadata = self.fn_metadata_map.get(
47
+ result.tool_name, self.fn_metadata)
48
+ result = fn_metadata.convert_result(result.result)
38
49
  if isinstance(result, tuple) and len(result) == 2:
39
50
  unstructured_content, _ = result
40
51
  else: