datarobot-genai 0.2.22__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datarobot_genai/drmcp/core/dr_mcp_server.py +0 -3
- datarobot_genai/drmcp/core/mcp_instance.py +37 -103
- datarobot_genai/drmcp/core/tool_filter.py +10 -1
- datarobot_genai/drmcp/tools/clients/confluence.py +93 -1
- datarobot_genai/drmcp/tools/clients/gdrive.py +255 -0
- datarobot_genai/drmcp/tools/confluence/tools.py +67 -0
- datarobot_genai/drmcp/tools/gdrive/tools.py +66 -0
- datarobot_genai/drmcp/tools/predictive/project.py +45 -27
- datarobot_genai/drmcp/tools/predictive/training.py +160 -151
- {datarobot_genai-0.2.22.dist-info → datarobot_genai-0.2.26.dist-info}/METADATA +1 -1
- {datarobot_genai-0.2.22.dist-info → datarobot_genai-0.2.26.dist-info}/RECORD +15 -16
- datarobot_genai/drmcp/core/mcp_server_tools.py +0 -129
- {datarobot_genai-0.2.22.dist-info → datarobot_genai-0.2.26.dist-info}/WHEEL +0 -0
- {datarobot_genai-0.2.22.dist-info → datarobot_genai-0.2.26.dist-info}/entry_points.txt +0 -0
- {datarobot_genai-0.2.22.dist-info → datarobot_genai-0.2.26.dist-info}/licenses/AUTHORS +0 -0
- {datarobot_genai-0.2.22.dist-info → datarobot_genai-0.2.26.dist-info}/licenses/LICENSE +0 -0
|
@@ -31,9 +31,6 @@ from .dynamic_prompts.register import register_prompts_from_datarobot_prompt_man
|
|
|
31
31
|
from .dynamic_tools.deployment.register import register_tools_of_datarobot_deployments
|
|
32
32
|
from .logging import MCPLogging
|
|
33
33
|
from .mcp_instance import mcp
|
|
34
|
-
from .mcp_server_tools import get_all_available_tags # noqa # pylint: disable=unused-import
|
|
35
|
-
from .mcp_server_tools import get_tool_info_by_name # noqa # pylint: disable=unused-import
|
|
36
|
-
from .mcp_server_tools import list_tools_by_tags # noqa # pylint: disable=unused-import
|
|
37
34
|
from .memory_management.manager import MemoryManager
|
|
38
35
|
from .routes import register_routes
|
|
39
36
|
from .routes_utils import prefix_mount_path
|
|
@@ -16,20 +16,18 @@ import logging
|
|
|
16
16
|
from collections.abc import Callable
|
|
17
17
|
from functools import wraps
|
|
18
18
|
from typing import Any
|
|
19
|
-
from typing import overload
|
|
19
|
+
from typing import TypedDict
|
|
20
20
|
|
|
21
21
|
from fastmcp import Context
|
|
22
22
|
from fastmcp import FastMCP
|
|
23
23
|
from fastmcp.exceptions import NotFoundError
|
|
24
24
|
from fastmcp.prompts.prompt import Prompt
|
|
25
25
|
from fastmcp.server.dependencies import get_context
|
|
26
|
-
from fastmcp.tools import FunctionTool
|
|
27
26
|
from fastmcp.tools import Tool
|
|
28
|
-
from fastmcp.utilities.types import NotSet
|
|
29
|
-
from fastmcp.utilities.types import NotSetT
|
|
30
27
|
from mcp.types import AnyFunction
|
|
31
28
|
from mcp.types import Tool as MCPTool
|
|
32
29
|
from mcp.types import ToolAnnotations
|
|
30
|
+
from typing_extensions import Unpack
|
|
33
31
|
|
|
34
32
|
from .config import MCPServerConfig
|
|
35
33
|
from .config import get_config
|
|
@@ -120,86 +118,6 @@ class TaggedFastMCP(FastMCP):
|
|
|
120
118
|
"In stateless mode, clients will see changes on next request."
|
|
121
119
|
)
|
|
122
120
|
|
|
123
|
-
@overload
|
|
124
|
-
def tool(
|
|
125
|
-
self,
|
|
126
|
-
name_or_fn: AnyFunction,
|
|
127
|
-
*,
|
|
128
|
-
name: str | None = None,
|
|
129
|
-
title: str | None = None,
|
|
130
|
-
description: str | None = None,
|
|
131
|
-
tags: set[str] | None = None,
|
|
132
|
-
output_schema: dict[str, Any] | None | NotSetT = NotSet,
|
|
133
|
-
annotations: ToolAnnotations | dict[str, Any] | None = None,
|
|
134
|
-
exclude_args: list[str] | None = None,
|
|
135
|
-
meta: dict[str, Any] | None = None,
|
|
136
|
-
enabled: bool | None = None,
|
|
137
|
-
) -> FunctionTool: ...
|
|
138
|
-
|
|
139
|
-
@overload
|
|
140
|
-
def tool(
|
|
141
|
-
self,
|
|
142
|
-
name_or_fn: str | None = None,
|
|
143
|
-
*,
|
|
144
|
-
name: str | None = None,
|
|
145
|
-
title: str | None = None,
|
|
146
|
-
description: str | None = None,
|
|
147
|
-
tags: set[str] | None = None,
|
|
148
|
-
output_schema: dict[str, Any] | None | NotSetT = NotSet,
|
|
149
|
-
annotations: ToolAnnotations | dict[str, Any] | None = None,
|
|
150
|
-
exclude_args: list[str] | None = None,
|
|
151
|
-
meta: dict[str, Any] | None = None,
|
|
152
|
-
enabled: bool | None = None,
|
|
153
|
-
) -> Callable[[AnyFunction], FunctionTool]: ...
|
|
154
|
-
|
|
155
|
-
def tool(
|
|
156
|
-
self,
|
|
157
|
-
name_or_fn: str | Callable[..., Any] | None = None,
|
|
158
|
-
*,
|
|
159
|
-
name: str | None = None,
|
|
160
|
-
title: str | None = None,
|
|
161
|
-
description: str | None = None,
|
|
162
|
-
tags: set[str] | None = None,
|
|
163
|
-
output_schema: dict[str, Any] | None | NotSetT = NotSet,
|
|
164
|
-
annotations: ToolAnnotations | dict[str, Any] | None = None,
|
|
165
|
-
exclude_args: list[str] | None = None,
|
|
166
|
-
meta: dict[str, Any] | None = None,
|
|
167
|
-
enabled: bool | None = None,
|
|
168
|
-
**kwargs: Any,
|
|
169
|
-
) -> Callable[[AnyFunction], FunctionTool] | FunctionTool:
|
|
170
|
-
"""
|
|
171
|
-
Extend tool decorator that supports tags and other annotations, while remaining
|
|
172
|
-
signature-compatible with FastMCP.tool to avoid recursion issues with partials.
|
|
173
|
-
"""
|
|
174
|
-
if isinstance(annotations, dict):
|
|
175
|
-
annotations = ToolAnnotations(**annotations)
|
|
176
|
-
|
|
177
|
-
# Ensure tags are available both via native fastmcp `tags` and inside annotations
|
|
178
|
-
if tags is not None:
|
|
179
|
-
tags_ = sorted(tags)
|
|
180
|
-
if annotations is None:
|
|
181
|
-
annotations = ToolAnnotations() # type: ignore[call-arg]
|
|
182
|
-
annotations.tags = tags_ # type: ignore[attr-defined, union-attr]
|
|
183
|
-
else:
|
|
184
|
-
# At this point, annotations is ToolAnnotations (not dict)
|
|
185
|
-
assert isinstance(annotations, ToolAnnotations)
|
|
186
|
-
annotations.tags = tags_ # type: ignore[attr-defined]
|
|
187
|
-
|
|
188
|
-
return super().tool(
|
|
189
|
-
name_or_fn,
|
|
190
|
-
name=name,
|
|
191
|
-
title=title,
|
|
192
|
-
description=description,
|
|
193
|
-
tags=tags,
|
|
194
|
-
output_schema=output_schema
|
|
195
|
-
if output_schema is not None
|
|
196
|
-
else kwargs.get("output_schema"),
|
|
197
|
-
annotations=annotations,
|
|
198
|
-
exclude_args=exclude_args,
|
|
199
|
-
meta=meta,
|
|
200
|
-
enabled=enabled,
|
|
201
|
-
)
|
|
202
|
-
|
|
203
121
|
async def list_tools(
|
|
204
122
|
self, tags: list[str] | None = None, match_all: bool = False
|
|
205
123
|
) -> list[MCPTool]:
|
|
@@ -371,16 +289,37 @@ mcp = TaggedFastMCP(
|
|
|
371
289
|
)
|
|
372
290
|
|
|
373
291
|
|
|
292
|
+
class ToolKwargs(TypedDict, total=False):
|
|
293
|
+
"""Keyword arguments passed through to FastMCP's mcp.tool() decorator.
|
|
294
|
+
|
|
295
|
+
All parameters are optional and forwarded directly to FastMCP tool registration.
|
|
296
|
+
See FastMCP documentation for full details on each parameter.
|
|
297
|
+
"""
|
|
298
|
+
|
|
299
|
+
name: str | None
|
|
300
|
+
title: str | None
|
|
301
|
+
description: str | None
|
|
302
|
+
icons: list[Any] | None
|
|
303
|
+
tags: set[str] | None
|
|
304
|
+
output_schema: dict[str, Any] | None
|
|
305
|
+
annotations: Any | None
|
|
306
|
+
exclude_args: list[str] | None
|
|
307
|
+
meta: dict[str, Any] | None
|
|
308
|
+
enabled: bool | None
|
|
309
|
+
|
|
310
|
+
|
|
374
311
|
def dr_core_mcp_tool(
|
|
375
|
-
|
|
376
|
-
description: str | None = None,
|
|
377
|
-
tags: set[str] | None = None,
|
|
312
|
+
**kwargs: Unpack[ToolKwargs],
|
|
378
313
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
379
|
-
"""Combine decorator that includes mcp.tool() and dr_mcp_extras()."""
|
|
314
|
+
"""Combine decorator that includes mcp.tool() and dr_mcp_extras().
|
|
315
|
+
|
|
316
|
+
All keyword arguments are passed through to FastMCP's mcp.tool() decorator.
|
|
317
|
+
See ToolKwargs for available parameters.
|
|
318
|
+
"""
|
|
380
319
|
|
|
381
320
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
382
321
|
instrumented = dr_mcp_extras()(func)
|
|
383
|
-
mcp.tool(
|
|
322
|
+
mcp.tool(**kwargs)(instrumented)
|
|
384
323
|
return instrumented
|
|
385
324
|
|
|
386
325
|
return decorator
|
|
@@ -413,27 +352,23 @@ async def memory_aware_wrapper(func: Callable[..., Any], *args: Any, **kwargs: A
|
|
|
413
352
|
|
|
414
353
|
|
|
415
354
|
def dr_mcp_tool(
|
|
416
|
-
|
|
417
|
-
description: str | None = None,
|
|
418
|
-
tags: set[str] | None = None,
|
|
355
|
+
**kwargs: Unpack[ToolKwargs],
|
|
419
356
|
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
420
357
|
"""Combine decorator that includes mcp.tool(), dr_mcp_extras(), and capture memory ids from
|
|
421
358
|
the request headers if they exist.
|
|
422
359
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
description: Tool description
|
|
426
|
-
tags: Optional set of tags to apply to the tool
|
|
360
|
+
All keyword arguments are passed through to FastMCP's mcp.tool() decorator.
|
|
361
|
+
See ToolKwargs for available parameters.
|
|
427
362
|
"""
|
|
428
363
|
|
|
429
364
|
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
|
430
365
|
@wraps(func)
|
|
431
|
-
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
432
|
-
return await memory_aware_wrapper(func, *args, **kwargs)
|
|
366
|
+
async def wrapper(*args: Any, **inner_kwargs: Any) -> Any:
|
|
367
|
+
return await memory_aware_wrapper(func, *args, **inner_kwargs)
|
|
433
368
|
|
|
434
369
|
# Apply the MCP decorators
|
|
435
370
|
instrumented = dr_mcp_extras()(wrapper)
|
|
436
|
-
mcp.tool(
|
|
371
|
+
mcp.tool(**kwargs)(instrumented)
|
|
437
372
|
return instrumented
|
|
438
373
|
|
|
439
374
|
return decorator
|
|
@@ -488,11 +423,10 @@ async def register_tools(
|
|
|
488
423
|
# Apply dr_mcp_extras to the memory-aware function
|
|
489
424
|
wrapped_fn = dr_mcp_extras()(memory_aware_fn)
|
|
490
425
|
|
|
491
|
-
# Create annotations
|
|
492
|
-
annotations =
|
|
493
|
-
if tags is not None:
|
|
494
|
-
annotations.tags = tags # type: ignore[attr-defined]
|
|
426
|
+
# Create annotations only when additional metadata is required
|
|
427
|
+
annotations: ToolAnnotations | None = None # type: ignore[assignment]
|
|
495
428
|
if deployment_id is not None:
|
|
429
|
+
annotations = ToolAnnotations() # type: ignore[call-arg]
|
|
496
430
|
annotations.deployment_id = deployment_id # type: ignore[attr-defined]
|
|
497
431
|
|
|
498
432
|
tool = Tool.from_function(
|
|
@@ -41,7 +41,7 @@ def filter_tools_by_tags(
|
|
|
41
41
|
filtered_tools = []
|
|
42
42
|
|
|
43
43
|
for tool in tools:
|
|
44
|
-
tool_tags =
|
|
44
|
+
tool_tags = get_tool_tags(tool)
|
|
45
45
|
|
|
46
46
|
if not tool_tags:
|
|
47
47
|
continue
|
|
@@ -68,9 +68,18 @@ def get_tool_tags(tool: Tool | MCPTool) -> list[str]:
|
|
|
68
68
|
-------
|
|
69
69
|
List of tags for the tool
|
|
70
70
|
"""
|
|
71
|
+
# Primary: native FastMCP meta location
|
|
72
|
+
if hasattr(tool, "meta") and getattr(tool, "meta"):
|
|
73
|
+
fastmcp_meta = tool.meta.get("_fastmcp", {})
|
|
74
|
+
meta_tags = fastmcp_meta.get("tags", [])
|
|
75
|
+
if isinstance(meta_tags, list):
|
|
76
|
+
return meta_tags
|
|
77
|
+
|
|
78
|
+
# Fallback: annotations.tags (for compatibility during transition)
|
|
71
79
|
if tool.annotations and hasattr(tool.annotations, "tags"):
|
|
72
80
|
tags = getattr(tool.annotations, "tags", [])
|
|
73
81
|
return tags if isinstance(tags, list) else []
|
|
82
|
+
|
|
74
83
|
return []
|
|
75
84
|
|
|
76
85
|
|
|
@@ -50,6 +50,7 @@ class ConfluencePage(BaseModel):
|
|
|
50
50
|
space_id: str = Field(..., description="Space ID where the page resides")
|
|
51
51
|
space_key: str | None = Field(None, description="Space key (if available)")
|
|
52
52
|
body: str = Field(..., description="Page content in storage format (HTML-like)")
|
|
53
|
+
version: int = Field(..., description="Current version number of the page")
|
|
53
54
|
|
|
54
55
|
def as_flat_dict(self) -> dict[str, Any]:
|
|
55
56
|
"""Return a flat dictionary representation of the page."""
|
|
@@ -59,6 +60,7 @@ class ConfluencePage(BaseModel):
|
|
|
59
60
|
"space_id": self.space_id,
|
|
60
61
|
"space_key": self.space_key,
|
|
61
62
|
"body": self.body,
|
|
63
|
+
"version": self.version,
|
|
62
64
|
}
|
|
63
65
|
|
|
64
66
|
|
|
@@ -111,7 +113,7 @@ class ConfluenceClient:
|
|
|
111
113
|
At the moment of creating this client, official Confluence SDK is not supporting async.
|
|
112
114
|
"""
|
|
113
115
|
|
|
114
|
-
EXPAND_FIELDS = "body.storage,space"
|
|
116
|
+
EXPAND_FIELDS = "body.storage,space,version"
|
|
115
117
|
|
|
116
118
|
def __init__(self, access_token: str) -> None:
|
|
117
119
|
"""
|
|
@@ -164,6 +166,8 @@ class ConfluenceClient:
|
|
|
164
166
|
space = data.get("space", {})
|
|
165
167
|
space_key = space.get("key") if isinstance(space, dict) else None
|
|
166
168
|
space_id = space.get("id", "") if isinstance(space, dict) else data.get("spaceId", "")
|
|
169
|
+
version_data = data.get("version", {})
|
|
170
|
+
version_number = version_data.get("number", 1) if isinstance(version_data, dict) else 1
|
|
167
171
|
|
|
168
172
|
return ConfluencePage(
|
|
169
173
|
page_id=str(data.get("id", "")),
|
|
@@ -171,6 +175,7 @@ class ConfluenceClient:
|
|
|
171
175
|
space_id=str(space_id),
|
|
172
176
|
space_key=space_key,
|
|
173
177
|
body=body_content,
|
|
178
|
+
version=version_number,
|
|
174
179
|
)
|
|
175
180
|
|
|
176
181
|
async def get_page_by_id(self, page_id: str) -> ConfluencePage:
|
|
@@ -339,6 +344,93 @@ class ConfluenceClient:
|
|
|
339
344
|
|
|
340
345
|
return self._parse_response(response.json())
|
|
341
346
|
|
|
347
|
+
async def update_page(
|
|
348
|
+
self,
|
|
349
|
+
page_id: str,
|
|
350
|
+
new_body_content: str,
|
|
351
|
+
version_number: int,
|
|
352
|
+
) -> ConfluencePage:
|
|
353
|
+
"""
|
|
354
|
+
Update the content of an existing Confluence page.
|
|
355
|
+
|
|
356
|
+
Args:
|
|
357
|
+
page_id: The ID of the page to update
|
|
358
|
+
new_body_content: The new content in Confluence Storage Format (XML) or raw text
|
|
359
|
+
version_number: The current version number of the page (for optimistic locking).
|
|
360
|
+
The update will increment this by 1.
|
|
361
|
+
|
|
362
|
+
Returns
|
|
363
|
+
-------
|
|
364
|
+
ConfluencePage with the updated page data including the new version number
|
|
365
|
+
|
|
366
|
+
Raises
|
|
367
|
+
------
|
|
368
|
+
ConfluenceError: If page not found (404), version conflict (409),
|
|
369
|
+
permission denied (403), invalid content (400),
|
|
370
|
+
or rate limited (429)
|
|
371
|
+
httpx.HTTPStatusError: If the API request fails with unexpected status
|
|
372
|
+
"""
|
|
373
|
+
cloud_id = await self._get_cloud_id()
|
|
374
|
+
url = f"{ATLASSIAN_API_BASE}/ex/confluence/{cloud_id}/wiki/rest/api/content/{page_id}"
|
|
375
|
+
|
|
376
|
+
try:
|
|
377
|
+
current_page = await self.get_page_by_id(page_id)
|
|
378
|
+
title_to_use = current_page.title
|
|
379
|
+
except ConfluenceError as e:
|
|
380
|
+
if e.status_code == 404:
|
|
381
|
+
raise ConfluenceError(
|
|
382
|
+
f"Page with ID '{page_id}' not found: cannot fetch existing title",
|
|
383
|
+
status_code=404,
|
|
384
|
+
)
|
|
385
|
+
raise
|
|
386
|
+
|
|
387
|
+
payload: dict[str, Any] = {
|
|
388
|
+
"type": "page",
|
|
389
|
+
"title": title_to_use,
|
|
390
|
+
"body": {
|
|
391
|
+
"storage": {
|
|
392
|
+
"value": new_body_content,
|
|
393
|
+
"representation": "storage",
|
|
394
|
+
}
|
|
395
|
+
},
|
|
396
|
+
"version": {
|
|
397
|
+
"number": version_number + 1,
|
|
398
|
+
},
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
response = await self._client.put(url, json=payload)
|
|
402
|
+
|
|
403
|
+
if response.status_code == HTTPStatus.NOT_FOUND:
|
|
404
|
+
error_msg = self._extract_error_message(response)
|
|
405
|
+
raise ConfluenceError(
|
|
406
|
+
f"Page with ID '{page_id}' not found: {error_msg}",
|
|
407
|
+
status_code=404,
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
if response.status_code == HTTPStatus.CONFLICT:
|
|
411
|
+
error_msg = self._extract_error_message(response)
|
|
412
|
+
raise ConfluenceError(
|
|
413
|
+
f"Version conflict: the page has been modified since version {version_number}. "
|
|
414
|
+
f"Please fetch the latest version and retry. Details: {error_msg}",
|
|
415
|
+
status_code=409,
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
if response.status_code == HTTPStatus.FORBIDDEN:
|
|
419
|
+
raise ConfluenceError(
|
|
420
|
+
f"Permission denied: you don't have access to update page '{page_id}'",
|
|
421
|
+
status_code=403,
|
|
422
|
+
)
|
|
423
|
+
|
|
424
|
+
if response.status_code == HTTPStatus.BAD_REQUEST:
|
|
425
|
+
error_msg = self._extract_error_message(response)
|
|
426
|
+
raise ConfluenceError(f"Invalid request: {error_msg}", status_code=400)
|
|
427
|
+
|
|
428
|
+
if response.status_code == HTTPStatus.TOO_MANY_REQUESTS:
|
|
429
|
+
raise ConfluenceError("Rate limit exceeded. Please try again later.", status_code=429)
|
|
430
|
+
|
|
431
|
+
response.raise_for_status()
|
|
432
|
+
return self._parse_response(response.json())
|
|
433
|
+
|
|
342
434
|
def _parse_comment_response(self, data: dict, page_id: str) -> ConfluenceComment:
|
|
343
435
|
"""Parse API response into ConfluenceComment."""
|
|
344
436
|
body_content = ""
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
"""Google Drive API Client and utilities for OAuth."""
|
|
16
16
|
|
|
17
|
+
import io
|
|
17
18
|
import logging
|
|
18
19
|
from typing import Annotated
|
|
19
20
|
from typing import Any
|
|
@@ -24,6 +25,7 @@ from fastmcp.exceptions import ToolError
|
|
|
24
25
|
from pydantic import BaseModel
|
|
25
26
|
from pydantic import ConfigDict
|
|
26
27
|
from pydantic import Field
|
|
28
|
+
from pypdf import PdfReader
|
|
27
29
|
|
|
28
30
|
from datarobot_genai.drmcp.core.auth import get_access_token
|
|
29
31
|
|
|
@@ -37,6 +39,23 @@ DEFAULT_ORDER = "modifiedTime desc"
|
|
|
37
39
|
MAX_PAGE_SIZE = 100
|
|
38
40
|
LIMIT = 500
|
|
39
41
|
|
|
42
|
+
GOOGLE_WORKSPACE_EXPORT_MIMES: dict[str, str] = {
|
|
43
|
+
"application/vnd.google-apps.document": "text/markdown",
|
|
44
|
+
"application/vnd.google-apps.spreadsheet": "text/csv",
|
|
45
|
+
"application/vnd.google-apps.presentation": "text/plain",
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
BINARY_MIME_PREFIXES = (
|
|
49
|
+
"image/",
|
|
50
|
+
"audio/",
|
|
51
|
+
"video/",
|
|
52
|
+
"application/zip",
|
|
53
|
+
"application/octet-stream",
|
|
54
|
+
"application/vnd.google-apps.drawing",
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
PDF_MIME_TYPE = "application/pdf"
|
|
58
|
+
|
|
40
59
|
|
|
41
60
|
async def get_gdrive_access_token() -> str | ToolError:
|
|
42
61
|
"""
|
|
@@ -116,6 +135,35 @@ class PaginatedResult(BaseModel):
|
|
|
116
135
|
next_page_token: str | None = None
|
|
117
136
|
|
|
118
137
|
|
|
138
|
+
class GoogleDriveFileContent(BaseModel):
|
|
139
|
+
"""Content retrieved from a Google Drive file."""
|
|
140
|
+
|
|
141
|
+
id: str
|
|
142
|
+
name: str
|
|
143
|
+
mime_type: str
|
|
144
|
+
content: str
|
|
145
|
+
original_mime_type: str
|
|
146
|
+
was_exported: bool = False
|
|
147
|
+
size: int | None = None
|
|
148
|
+
web_view_link: str | None = None
|
|
149
|
+
|
|
150
|
+
def as_flat_dict(self) -> dict[str, Any]:
|
|
151
|
+
"""Return a flat dictionary representation of the file content."""
|
|
152
|
+
result: dict[str, Any] = {
|
|
153
|
+
"id": self.id,
|
|
154
|
+
"name": self.name,
|
|
155
|
+
"mimeType": self.mime_type,
|
|
156
|
+
"content": self.content,
|
|
157
|
+
"originalMimeType": self.original_mime_type,
|
|
158
|
+
"wasExported": self.was_exported,
|
|
159
|
+
}
|
|
160
|
+
if self.size is not None:
|
|
161
|
+
result["size"] = self.size
|
|
162
|
+
if self.web_view_link is not None:
|
|
163
|
+
result["webViewLink"] = self.web_view_link
|
|
164
|
+
return result
|
|
165
|
+
|
|
166
|
+
|
|
119
167
|
class GoogleDriveClient:
|
|
120
168
|
"""Client for interacting with Google Drive API."""
|
|
121
169
|
|
|
@@ -344,6 +392,213 @@ class GoogleDriveClient:
|
|
|
344
392
|
logger.debug(f"Auto-formatted query '{query}' to '{formatted_query}'")
|
|
345
393
|
return formatted_query
|
|
346
394
|
|
|
395
|
+
@staticmethod
|
|
396
|
+
def _is_binary_mime_type(mime_type: str) -> bool:
|
|
397
|
+
"""Check if MIME type indicates binary content that's not useful for LLM consumption.
|
|
398
|
+
|
|
399
|
+
Args:
|
|
400
|
+
mime_type: The MIME type to check.
|
|
401
|
+
|
|
402
|
+
Returns
|
|
403
|
+
-------
|
|
404
|
+
True if the MIME type is considered binary, False otherwise.
|
|
405
|
+
"""
|
|
406
|
+
return any(mime_type.startswith(prefix) for prefix in BINARY_MIME_PREFIXES)
|
|
407
|
+
|
|
408
|
+
async def get_file_metadata(self, file_id: str) -> GoogleDriveFile:
|
|
409
|
+
"""Get file metadata from Google Drive.
|
|
410
|
+
|
|
411
|
+
Args:
|
|
412
|
+
file_id: The ID of the file to get metadata for.
|
|
413
|
+
|
|
414
|
+
Returns
|
|
415
|
+
-------
|
|
416
|
+
GoogleDriveFile with file metadata.
|
|
417
|
+
|
|
418
|
+
Raises
|
|
419
|
+
------
|
|
420
|
+
GoogleDriveError: If the file is not found or access is denied.
|
|
421
|
+
"""
|
|
422
|
+
params = {"fields": SUPPORTED_FIELDS_STR}
|
|
423
|
+
response = await self._client.get(f"/{file_id}", params=params)
|
|
424
|
+
|
|
425
|
+
if response.status_code == 404:
|
|
426
|
+
raise GoogleDriveError(f"File with ID '{file_id}' not found.")
|
|
427
|
+
if response.status_code == 403:
|
|
428
|
+
raise GoogleDriveError(f"Permission denied: you don't have access to file '{file_id}'.")
|
|
429
|
+
if response.status_code == 429:
|
|
430
|
+
raise GoogleDriveError("Rate limit exceeded. Please try again later.")
|
|
431
|
+
|
|
432
|
+
response.raise_for_status()
|
|
433
|
+
return GoogleDriveFile.from_api_response(response.json())
|
|
434
|
+
|
|
435
|
+
async def _export_workspace_file(self, file_id: str, export_mime_type: str) -> str:
|
|
436
|
+
"""Export a Google Workspace file to the specified format.
|
|
437
|
+
|
|
438
|
+
Args:
|
|
439
|
+
file_id: The ID of the Google Workspace file.
|
|
440
|
+
export_mime_type: The MIME type to export to (e.g., 'text/markdown').
|
|
441
|
+
|
|
442
|
+
Returns
|
|
443
|
+
-------
|
|
444
|
+
The exported content as a string.
|
|
445
|
+
|
|
446
|
+
Raises
|
|
447
|
+
------
|
|
448
|
+
GoogleDriveError: If export fails.
|
|
449
|
+
"""
|
|
450
|
+
response = await self._client.get(
|
|
451
|
+
f"/{file_id}/export",
|
|
452
|
+
params={"mimeType": export_mime_type},
|
|
453
|
+
)
|
|
454
|
+
|
|
455
|
+
if response.status_code == 404:
|
|
456
|
+
raise GoogleDriveError(f"File with ID '{file_id}' not found.")
|
|
457
|
+
if response.status_code == 403:
|
|
458
|
+
raise GoogleDriveError(
|
|
459
|
+
f"Permission denied: you don't have access to export file '{file_id}'."
|
|
460
|
+
)
|
|
461
|
+
if response.status_code == 400:
|
|
462
|
+
raise GoogleDriveError(
|
|
463
|
+
f"Cannot export file '{file_id}' to format '{export_mime_type}'. "
|
|
464
|
+
"The file may not support this export format."
|
|
465
|
+
)
|
|
466
|
+
if response.status_code == 429:
|
|
467
|
+
raise GoogleDriveError("Rate limit exceeded. Please try again later.")
|
|
468
|
+
|
|
469
|
+
response.raise_for_status()
|
|
470
|
+
return response.text
|
|
471
|
+
|
|
472
|
+
async def _download_file(self, file_id: str) -> str:
|
|
473
|
+
"""Download a regular file's content from Google Drive as text."""
|
|
474
|
+
content = await self._download_file_bytes(file_id)
|
|
475
|
+
return content.decode("utf-8")
|
|
476
|
+
|
|
477
|
+
async def _download_file_bytes(self, file_id: str) -> bytes:
|
|
478
|
+
"""Download a file's content as bytes from Google Drive.
|
|
479
|
+
|
|
480
|
+
Args:
|
|
481
|
+
file_id: The ID of the file to download.
|
|
482
|
+
|
|
483
|
+
Returns
|
|
484
|
+
-------
|
|
485
|
+
The file content as bytes.
|
|
486
|
+
|
|
487
|
+
Raises
|
|
488
|
+
------
|
|
489
|
+
GoogleDriveError: If download fails.
|
|
490
|
+
"""
|
|
491
|
+
response = await self._client.get(
|
|
492
|
+
f"/{file_id}",
|
|
493
|
+
params={"alt": "media"},
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
if response.status_code == 404:
|
|
497
|
+
raise GoogleDriveError(f"File with ID '{file_id}' not found.")
|
|
498
|
+
if response.status_code == 403:
|
|
499
|
+
raise GoogleDriveError(
|
|
500
|
+
f"Permission denied: you don't have access to download file '{file_id}'."
|
|
501
|
+
)
|
|
502
|
+
if response.status_code == 429:
|
|
503
|
+
raise GoogleDriveError("Rate limit exceeded. Please try again later.")
|
|
504
|
+
|
|
505
|
+
response.raise_for_status()
|
|
506
|
+
return response.content
|
|
507
|
+
|
|
508
|
+
def _extract_text_from_pdf(self, pdf_bytes: bytes) -> str:
|
|
509
|
+
"""Extract text from PDF bytes using pypdf.
|
|
510
|
+
|
|
511
|
+
Args:
|
|
512
|
+
pdf_bytes: The PDF file content as bytes.
|
|
513
|
+
|
|
514
|
+
Returns
|
|
515
|
+
-------
|
|
516
|
+
Extracted text from the PDF.
|
|
517
|
+
|
|
518
|
+
Raises
|
|
519
|
+
------
|
|
520
|
+
GoogleDriveError: If PDF text extraction fails.
|
|
521
|
+
"""
|
|
522
|
+
try:
|
|
523
|
+
reader = PdfReader(io.BytesIO(pdf_bytes))
|
|
524
|
+
text_parts = []
|
|
525
|
+
for page in reader.pages:
|
|
526
|
+
page_text = page.extract_text()
|
|
527
|
+
if page_text:
|
|
528
|
+
text_parts.append(page_text)
|
|
529
|
+
return "\n\n".join(text_parts)
|
|
530
|
+
except Exception as e:
|
|
531
|
+
raise GoogleDriveError(f"Failed to extract text from PDF: {e}")
|
|
532
|
+
|
|
533
|
+
async def read_file_content(
|
|
534
|
+
self, file_id: str, target_format: str | None = None
|
|
535
|
+
) -> GoogleDriveFileContent:
|
|
536
|
+
"""Read the content of a file from Google Drive.
|
|
537
|
+
|
|
538
|
+
Google Workspace files (Docs, Sheets, Slides) are automatically exported to
|
|
539
|
+
LLM-readable formats:
|
|
540
|
+
- Google Docs -> Markdown (text/markdown)
|
|
541
|
+
- Google Sheets -> CSV (text/csv)
|
|
542
|
+
- Google Slides -> Plain text (text/plain)
|
|
543
|
+
- PDF files -> Extracted text (text/plain)
|
|
544
|
+
|
|
545
|
+
Regular text files are downloaded directly.
|
|
546
|
+
Binary files (images, videos, etc.) will raise an error.
|
|
547
|
+
|
|
548
|
+
Args:
|
|
549
|
+
file_id: The ID of the file to read.
|
|
550
|
+
target_format: Optional MIME type to export Google Workspace files to.
|
|
551
|
+
If not specified, uses sensible defaults. Has no effect on non-Workspace files.
|
|
552
|
+
|
|
553
|
+
Returns
|
|
554
|
+
-------
|
|
555
|
+
GoogleDriveFileContent with the file content and metadata.
|
|
556
|
+
|
|
557
|
+
Raises
|
|
558
|
+
------
|
|
559
|
+
GoogleDriveError: If the file cannot be read (not found, permission denied,
|
|
560
|
+
binary file, etc.).
|
|
561
|
+
"""
|
|
562
|
+
file_metadata = await self.get_file_metadata(file_id)
|
|
563
|
+
original_mime_type = file_metadata.mime_type
|
|
564
|
+
|
|
565
|
+
if self._is_binary_mime_type(original_mime_type):
|
|
566
|
+
raise GoogleDriveError(
|
|
567
|
+
f"Binary files are not supported for reading. "
|
|
568
|
+
f"File '{file_metadata.name}' has MIME type '{original_mime_type}'."
|
|
569
|
+
)
|
|
570
|
+
|
|
571
|
+
if original_mime_type == GOOGLE_DRIVE_FOLDER_MIME:
|
|
572
|
+
raise GoogleDriveError(
|
|
573
|
+
f"Cannot read content of a folder. '{file_metadata.name}' is a folder, not a file."
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
was_exported = False
|
|
577
|
+
if original_mime_type in GOOGLE_WORKSPACE_EXPORT_MIMES:
|
|
578
|
+
export_mime = target_format or GOOGLE_WORKSPACE_EXPORT_MIMES[original_mime_type]
|
|
579
|
+
content = await self._export_workspace_file(file_id, export_mime)
|
|
580
|
+
result_mime_type = export_mime
|
|
581
|
+
was_exported = True
|
|
582
|
+
elif original_mime_type == PDF_MIME_TYPE:
|
|
583
|
+
pdf_bytes = await self._download_file_bytes(file_id)
|
|
584
|
+
content = self._extract_text_from_pdf(pdf_bytes)
|
|
585
|
+
result_mime_type = "text/plain"
|
|
586
|
+
was_exported = True
|
|
587
|
+
else:
|
|
588
|
+
content = await self._download_file(file_id)
|
|
589
|
+
result_mime_type = original_mime_type
|
|
590
|
+
|
|
591
|
+
return GoogleDriveFileContent(
|
|
592
|
+
id=file_metadata.id,
|
|
593
|
+
name=file_metadata.name,
|
|
594
|
+
mime_type=result_mime_type,
|
|
595
|
+
content=content,
|
|
596
|
+
original_mime_type=original_mime_type,
|
|
597
|
+
was_exported=was_exported,
|
|
598
|
+
size=file_metadata.size,
|
|
599
|
+
web_view_link=file_metadata.web_view_link,
|
|
600
|
+
)
|
|
601
|
+
|
|
347
602
|
async def __aenter__(self) -> "GoogleDriveClient":
|
|
348
603
|
"""Async context manager entry."""
|
|
349
604
|
return self
|