adaptive-memory-engine 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_memory_engine-0.1.6.dist-info/METADATA +228 -0
- adaptive_memory_engine-0.1.6.dist-info/RECORD +72 -0
- adaptive_memory_engine-0.1.6.dist-info/WHEEL +4 -0
- adaptive_memory_engine-0.1.6.dist-info/entry_points.txt +3 -0
- adaptive_memory_engine-0.1.6.dist-info/licenses/LICENSE +21 -0
- ame/__init__.py +1 -0
- ame/agent/__init__.py +1 -0
- ame/agent/mcp.py +474 -0
- ame/agent/memory_api.py +141 -0
- ame/agent/results.py +30 -0
- ame/bronze/schema.py +17 -0
- ame/bronze/store.py +38 -0
- ame/cli/__init__.py +1 -0
- ame/cli/main.py +903 -0
- ame/connectors/base.py +30 -0
- ame/connectors/contract.py +199 -0
- ame/connectors/github.py +66 -0
- ame/connectors/google.py +464 -0
- ame/connectors/google_oauth.py +156 -0
- ame/connectors/jira.py +66 -0
- ame/connectors/json_helpers.py +43 -0
- ame/connectors/markdown.py +116 -0
- ame/connectors/notion.py +59 -0
- ame/connectors/oauth_callback.py +102 -0
- ame/connectors/oauth_provider.py +250 -0
- ame/connectors/obsidian.py +19 -0
- ame/connectors/router.py +155 -0
- ame/connectors/slack.py +66 -0
- ame/connectors/slack_oauth.py +417 -0
- ame/connectors/sync_history.py +73 -0
- ame/context_budget.py +106 -0
- ame/core/config.py +77 -0
- ame/core/corpus.py +17 -0
- ame/core/errors.py +18 -0
- ame/core/paths.py +111 -0
- ame/core/state.py +57 -0
- ame/export/obsidian.py +123 -0
- ame/gold/builder.py +300 -0
- ame/gold/ontology.py +80 -0
- ame/gold/resolver.py +91 -0
- ame/gold/schema.py +40 -0
- ame/gold/store.py +45 -0
- ame/hardware/profiler.py +85 -0
- ame/hardware/tier.py +27 -0
- ame/hermes/__init__.py +3 -0
- ame/hermes/memory.py +209 -0
- ame/models/download.py +243 -0
- ame/models/ollama.py +60 -0
- ame/models/registry.py +101 -0
- ame/models/router.py +22 -0
- ame/pipeline.py +155 -0
- ame/query/diff.py +40 -0
- ame/query/engine.py +919 -0
- ame/query/memory_os.py +313 -0
- ame/query/mql.py +84 -0
- ame/query/multihop.py +264 -0
- ame/query/result.py +20 -0
- ame/sdk.py +52 -0
- ame/security.py +145 -0
- ame/silver/extractor.py +414 -0
- ame/silver/llm_extractor.py +181 -0
- ame/silver/prompts.py +56 -0
- ame/silver/rationale.py +140 -0
- ame/silver/schema.py +51 -0
- ame/silver/store.py +59 -0
- ame/storage/custom_kg.py +33 -0
- ame/storage/lightrag_adapter.py +362 -0
- ame/validation/confidence.py +5 -0
- ame/validation/grounding.py +10 -0
- ame/validation/type_gate.py +22 -0
- ame/writeback.py +173 -0
- memory/__init__.py +3 -0
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def read_json(path: Path) -> Any:
    """Read and decode the JSON document stored at *path*.

    The file is decoded as UTF-8. A leading UTF-8 byte-order mark (written by
    some Windows export tools) is skipped instead of making ``json.loads`` fail
    with "Unexpected UTF-8 BOM"; files without a BOM decode exactly as before.
    """
    return json.loads(path.read_text(encoding="utf-8-sig"))
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def as_list(value: Any) -> list[Any]:
    """Coerce a decoded JSON export into a list of records.

    - A list is returned unchanged (the same object).
    - For a dict, the first well-known collection key (``messages``, ``issues``,
      ``pull_requests``, ..., ``items``, checked in that order) whose value is a
      list wins and that list is returned; a dict without any such key is
      wrapped as ``[value]``.
    - Any other value yields ``[]``.
    """
    if isinstance(value, list):
        return value
    if not isinstance(value, dict):
        return []
    collection_keys = (
        "messages",
        "issues",
        "pull_requests",
        "prs",
        "discussions",
        "comments",
        "files",
        "documents",
        "threads",
        "events",
        "rows",
        "values",
        "items",
    )
    nested = next((value[key] for key in collection_keys if isinstance(value.get(key), list)), None)
    return [value] if nested is None else nested
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def first_present(row: dict[str, Any], *keys: str) -> Any:
    """Return the value of the first of *keys* that is set in *row* and not ``None``.

    Falsy values such as ``""`` or ``0`` count as present. Returns ``None`` when
    no key qualifies (including when no keys are given).
    """
    candidates = (row.get(key) for key in keys)
    return next((candidate for candidate in candidates if candidate is not None), None)
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import re
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from ame.bronze.schema import BronzeDocument
|
|
9
|
+
from ame.connectors.base import SourceRef
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Frontmatter block at the very start of a file, delimited by "---" lines;
# group 1 is the text between the delimiters.
FRONTMATTER_RE = re.compile(r"\A---\s*\n(.*?)\n---\s*\n", re.DOTALL)
# ATX heading line. Group 1 is the run of 1-6 "#" marks (its length is the
# level), group 2 the heading text (starts with a non-space character). The
# separator is limited to spaces/tabs so a bare "#" line can never swallow the
# following line as its title, and a whitespace-only title does not match.
HEADING_RE = re.compile(r"^(#{1,6})[ \t]+(\S.*)$", re.MULTILINE)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class MarkdownConnector:
    """Convert Markdown files into ``SourceRef`` / ``BronzeDocument`` records.

    A file with at least two ``#`` / ``##`` headings (ignoring headings inside
    fenced code blocks) is split into one ref per heading-delimited section;
    any other file becomes a single ref holding the whole text.
    """

    source_type = "markdown"

    # Opening/closing line of a fenced code block (``` or ~~~, at most 3 leading
    # spaces). Group 1 is the fence run, group 2 the rest of the line.
    _FENCE_RE = re.compile(r"^[ \t]{0,3}(`{3,}|~{3,})(.*)$", re.MULTILINE)

    def scan(self, path: Path) -> list[SourceRef]:
        """Return refs for *path*: a single file, or a directory searched recursively.

        Directory scans collect ``*.md`` and ``*.markdown`` files in sorted order.
        Section refs get ``source_id`` values of the form ``<file>#<section id>``;
        when a file repeats a section title, later occurrences get a ``-2``,
        ``-3``, ... suffix so every ref from that file has a distinct id.
        """
        root = path.expanduser().resolve()
        files = [root] if root.is_file() else sorted([*root.rglob("*.md"), *root.rglob("*.markdown")])
        refs: list[SourceRef] = []
        for file in files:
            source_id = str(file.relative_to(root) if root.is_dir() else file.name)
            content = file.read_text(encoding="utf-8")
            sections = self._sections(content)
            if len(sections) <= 1:
                refs.append(SourceRef(path=file, source_id=source_id, content=content, root_source_id=source_id))
                continue
            # Section ids already handed out for this file, so repeated titles stay distinct.
            issued: set[str] = set()
            for index, section in enumerate(sections):
                section_id = self._unique_section_id(self._safe_section_id(section["title"]), issued)
                refs.append(
                    SourceRef(
                        path=file,
                        source_id=f"{source_id}#{section_id}",
                        content=section["content"],
                        section_title=section["title"],
                        section_path=tuple(section["path"]),
                        section_level=section["level"],
                        section_index=index,
                        root_source_id=source_id,
                    )
                )
        return refs

    def load(self, corpus_id: str, ref: SourceRef) -> BronzeDocument:
        """Build the BronzeDocument for *ref*.

        Uses ``ref.content`` when it is set, otherwise reads ``ref.path``. The
        document id is derived from the SHA-256 of the content.
        NOTE(review): identical text in two files therefore yields the same id.
        Section refs also get their section fields copied into the metadata.
        """
        content = ref.content if ref.content is not None else ref.path.read_text(encoding="utf-8")
        digest = hashlib.sha256(content.encode("utf-8")).hexdigest()
        metadata = self._metadata(content, ref.path)
        if ref.section_title:
            metadata["title"] = ref.section_title
            metadata["section_title"] = ref.section_title
            metadata["section_path"] = list(ref.section_path)
            metadata["section_level"] = ref.section_level
            metadata["section_index"] = ref.section_index
        metadata["root_source_id"] = ref.root_source_id or ref.source_id.split("#", 1)[0]
        metadata["source_file"] = ref.root_source_id or ref.source_id.split("#", 1)[0]
        return BronzeDocument(
            id=f"bronze_{digest[:16]}",
            corpus_id=corpus_id,
            source_type=self.source_type,
            source_id=ref.source_id,
            content=content,
            metadata=metadata,
            content_hash=f"sha256:{digest}",
        )

    def _metadata(self, content: str, path: Path) -> dict[str, Any]:
        """Collect path, title, parsed frontmatter and heading titles for *content*.

        The title starts as the file stem. A frontmatter ``title`` replaces it.
        If the title still equals the stem, the first heading is used instead.
        Headings inside fenced code blocks are ignored.
        """
        metadata: dict[str, Any] = {"path": str(path), "title": path.stem}
        frontmatter = FRONTMATTER_RE.search(content)
        if frontmatter:
            metadata["frontmatter"] = self._parse_frontmatter(frontmatter.group(1))
            if title := metadata["frontmatter"].get("title"):
                metadata["title"] = title
        headings = self._outside_code_fences(content, list(HEADING_RE.finditer(content)))
        metadata["headings"] = [match.group(2).strip() for match in headings]
        if metadata["title"] == path.stem and metadata["headings"]:
            metadata["title"] = metadata["headings"][0]
        return metadata

    def _parse_frontmatter(self, text: str) -> dict[str, str]:
        """Parse simple ``key: value`` lines into a dict of strings.

        Lines without ``:`` are skipped; keys and values are whitespace-stripped
        and surrounding double quotes are removed from values. Nested YAML
        structures are not interpreted.
        """
        data: dict[str, str] = {}
        for line in text.splitlines():
            if ":" not in line:
                continue
            key, value = line.split(":", 1)
            data[key.strip()] = value.strip().strip('"')
        return data

    def _sections(self, content: str) -> list[dict]:
        """Split *content* at level-1/level-2 headings that are outside code fences.

        Returns ``[]`` when there are fewer than two such headings. Each section
        dict has ``title``, ``level``, ``path`` (the chain of enclosing heading
        titles, outermost first, ending with this one) and ``content``. Text
        before the first split heading (frontmatter, an intro paragraph) is kept
        at the start of the first section instead of being discarded.
        """
        matches = [
            match
            for match in self._outside_code_fences(content, list(HEADING_RE.finditer(content)))
            if len(match.group(1)) <= 2
        ]
        if len(matches) <= 1:
            return []
        sections: list[dict] = []
        path_by_level: dict[int, str] = {}
        for index, match in enumerate(matches):
            level = len(match.group(1))
            title = match.group(2).strip()
            path_by_level[level] = title
            # A heading closes every deeper heading that was open before it.
            for deeper in [known for known in path_by_level if known > level]:
                del path_by_level[deeper]
            start = 0 if index == 0 else match.start()
            end = matches[index + 1].start() if index + 1 < len(matches) else len(content)
            section_content = content[start:end].strip()
            if not section_content:
                continue
            sections.append(
                {
                    "title": title,
                    "level": level,
                    "path": [path_by_level[key] for key in sorted(path_by_level)],
                    "content": section_content,
                }
            )
        return sections

    def _outside_code_fences(self, content: str, matches: list[re.Match[str]]) -> list[re.Match[str]]:
        """Drop the heading *matches* that start inside a fenced code block of *content*.

        A fence closes on a line of the same character, at least as long as the
        opener, with nothing else after it. An unclosed fence runs to the end of
        the text. A backtick line whose remainder contains a backtick is not a
        fence (for example, inline code on one line).
        """
        spans: list[tuple[int, int]] = []
        open_marker: str | None = None
        open_start = 0
        for fence in self._FENCE_RE.finditer(content):
            marker, rest = fence.group(1), fence.group(2)
            if open_marker is None:
                if marker[0] == "`" and "`" in rest:
                    continue
                open_marker, open_start = marker, fence.start()
            elif marker[0] == open_marker[0] and len(marker) >= len(open_marker) and not rest.strip():
                spans.append((open_start, fence.end()))
                open_marker = None
        if open_marker is not None:
            spans.append((open_start, len(content)))
        if not spans:
            return matches
        return [match for match in matches if not any(start <= match.start() < end for start, end in spans)]

    def _unique_section_id(self, section_id: str, issued: set[str]) -> str:
        """Return *section_id*, or ``<id>-2``, ``<id>-3``, ... if already in *issued*.

        The returned id is added to *issued* (mutated in place).
        """
        candidate = section_id
        suffix = 2
        while candidate in issued:
            candidate = f"{section_id}-{suffix}"
            suffix += 1
        issued.add(candidate)
        return candidate

    def _safe_section_id(self, title: str) -> str:
        """Normalize a heading title for use after ``#`` in a source id.

        Collapses whitespace runs, trims the result, and replaces ``/`` with ``-``
        and ``:`` with `` -``.
        """
        return re.sub(r"\s+", " ", title).strip().replace("/", "-").replace(":", " -")
|
ame/connectors/notion.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from ame.bronze.schema import BronzeDocument
|
|
8
|
+
from ame.connectors.base import SourceRef
|
|
9
|
+
from ame.connectors.json_helpers import as_list, first_present, read_json
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class NotionConnector:
    """Ingest Notion page exports stored as JSON files.

    During :meth:`scan`, each page record is rendered to a small Markdown
    document: frontmatter (title, page id, url, last edited time), then a
    ``# title`` heading, then the body.
    """

    source_type = "notion"

    def scan(self, path: Path) -> list[SourceRef]:
        """Return one ref per page record in *path* (a JSON file or a directory of them).

        Records are located with ``as_list``; non-dict entries are skipped. The
        page id comes from ``id`` / ``page_id``, falling back to the file stem.
        """
        root = path.expanduser().resolve()
        files = [root] if root.is_file() else sorted(root.rglob("*.json"))
        refs: list[SourceRef] = []
        for file in files:
            for row in as_list(read_json(file)):
                if not isinstance(row, dict):
                    continue
                page_id = str(first_present(row, "id", "page_id") or file.stem)
                refs.append(SourceRef(path=file, source_id=f"notion:{page_id}", content=self._page_content(row, page_id)))
        return refs

    def load(self, corpus_id: str, ref: SourceRef) -> BronzeDocument:
        """Wrap a scanned page in a BronzeDocument whose id derives from the content hash.

        Uses the content rendered during :meth:`scan`. Only when the ref carries
        no content at all (``None``) is the file re-read, matching
        MarkdownConnector.load.
        """
        content = ref.content if ref.content is not None else ref.path.read_text(encoding="utf-8")
        digest = hashlib.sha256(content.encode("utf-8")).hexdigest()
        return BronzeDocument(
            id=f"bronze_{digest[:16]}",
            corpus_id=corpus_id,
            source_type=self.source_type,
            source_id=ref.source_id,
            content=content,
            metadata={"path": str(ref.path), "title": ref.source_id, "connector": "notion"},
            content_hash=f"sha256:{digest}",
        )

    def _page_content(self, row: dict[str, Any], page_id: str) -> str:
        """Render one Notion page record as Markdown with a frontmatter header.

        Each value interpolated into a frontmatter or heading line is collapsed
        to a single line first. This stops an embedded newline (e.g. a
        multi-line title containing ``---``) from ending the frontmatter early
        or injecting extra keys. The body is emitted unchanged.
        """
        title = self._single_line(str(first_present(row, "title", "name") or page_id))
        body = str(first_present(row, "content", "body", "text") or "")
        url = self._single_line(str(first_present(row, "url", "public_url") or ""))
        last_edited_time = self._single_line(str(first_present(row, "last_edited_time", "updated_at") or ""))
        return "\n".join(
            [
                "---",
                f"title: {title}",
                f"notion_page_id: {self._single_line(page_id)}",
                f"url: {url}",
                f"last_edited_time: {last_edited_time}",
                "---",
                "",
                f"# {title}",
                "",
                body,
                "",
            ]
        )

    @staticmethod
    def _single_line(value: str) -> str:
        """Join the lines of *value* with single spaces; single-line values are returned unchanged."""
        return " ".join(value.splitlines())
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import secrets
|
|
4
|
+
import urllib.parse
|
|
5
|
+
import webbrowser
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OAuthCallbackError(RuntimeError):
    """Raised by the local OAuth login flow.

    Covers an unusable redirect URI, a callback that reports an error or fails
    validation, and a login that times out without a callback.
    """
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class OAuthCallbackResult:
    """Values captured from an accepted OAuth redirect (immutable via ``frozen=True``)."""

    code: str  # authorization code taken from the callback's "code" query parameter
    state: str  # echoed "state" parameter, already compared with the expected value by the handler
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def new_oauth_state() -> str:
    """Generate an unguessable URL-safe OAuth ``state`` value (24 random bytes, 32 characters)."""
    entropy_bytes = 24
    return secrets.token_urlsafe(entropy_bytes)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def run_local_oauth_login(
    authorization_url: str,
    redirect_uri: str,
    expected_state: str,
    *,
    open_browser: bool = True,
    timeout_seconds: int = 180,
) -> OAuthCallbackResult:
    """Capture an OAuth authorization code with a loopback HTTP server.

    Binds the host and port of *redirect_uri* and optionally opens
    *authorization_url* in the browser. It then serves requests until the
    handler records a result or an error, or until *timeout_seconds* have
    elapsed in total. A request that the handler answers without recording
    anything (e.g. a browser fetching ``/favicon.ico``) does not end the wait.

    Raises:
        OAuthCallbackError: if *redirect_uri* is not an ``http`` URI on
            ``localhost``/``127.0.0.1`` with an explicit port, if the callback
            reported an error or failed validation, or if no callback arrived
            before the deadline.
    """
    import time

    parsed = urllib.parse.urlparse(redirect_uri)
    if parsed.scheme != "http" or parsed.hostname not in {"localhost", "127.0.0.1"}:
        raise OAuthCallbackError("Local OAuth login requires an http://localhost redirect URI")
    if not parsed.port:
        raise OAuthCallbackError("Local OAuth login redirect URI must include a port")

    server = _OAuthCallbackServer((parsed.hostname, parsed.port), _OAuthCallbackHandler)
    server.expected_state = expected_state
    # Published so the handler can tell the real callback from stray requests.
    server.expected_path = parsed.path or "/"  # type: ignore[attr-defined]
    deadline = time.monotonic() + timeout_seconds
    try:
        if open_browser:
            webbrowser.open(authorization_url)
        while server.result is None and server.error is None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # handle_request() returns after `timeout` seconds without a connection.
            server.timeout = remaining
            server.handle_request()
    finally:
        server.server_close()

    if server.error:
        raise OAuthCallbackError(server.error)
    if server.result is None:
        raise OAuthCallbackError("OAuth login timed out before receiving a callback")
    return server.result
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class _OAuthCallbackServer(HTTPServer):
    """HTTPServer carrying per-login state between run_local_oauth_login and the handler."""

    expected_state: str  # value the callback's "state" parameter must equal; set before serving
    result: OAuthCallbackResult | None = None  # set by the handler when a callback is accepted
    error: str | None = None  # set by the handler when a callback is rejected
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class _OAuthCallbackHandler(BaseHTTPRequestHandler):
|
|
62
|
+
def do_GET(self) -> None:
|
|
63
|
+
parsed = urllib.parse.urlparse(self.path)
|
|
64
|
+
query = urllib.parse.parse_qs(parsed.query)
|
|
65
|
+
state = _single(query, "state")
|
|
66
|
+
code = _single(query, "code")
|
|
67
|
+
error = _single(query, "error")
|
|
68
|
+
|
|
69
|
+
if error:
|
|
70
|
+
self.server.error = f"OAuth provider returned error: {error}" # type: ignore[attr-defined]
|
|
71
|
+
self._respond(400, "OAuth login failed. You can close this window.")
|
|
72
|
+
return
|
|
73
|
+
if state != self.server.expected_state: # type: ignore[attr-defined]
|
|
74
|
+
self.server.error = "OAuth callback state did not match" # type: ignore[attr-defined]
|
|
75
|
+
self._respond(400, "OAuth login state mismatch. You can close this window.")
|
|
76
|
+
return
|
|
77
|
+
if not code:
|
|
78
|
+
self.server.error = "OAuth callback did not include code" # type: ignore[attr-defined]
|
|
79
|
+
self._respond(400, "OAuth login did not include a code. You can close this window.")
|
|
80
|
+
return
|
|
81
|
+
|
|
82
|
+
self.server.result = OAuthCallbackResult(code=code, state=state) # type: ignore[attr-defined]
|
|
83
|
+
self._respond(200, "OAuth login complete. You can close this window.")
|
|
84
|
+
|
|
85
|
+
def log_message(self, format: str, *args) -> None:
|
|
86
|
+
return
|
|
87
|
+
|
|
88
|
+
def _respond(self, status: int, message: str) -> None:
|
|
89
|
+
payload = (
|
|
90
|
+
"<!doctype html><html><head><meta charset=\"utf-8\"><title>Adaptive Memory Engine</title></head>"
|
|
91
|
+
f"<body><h1>{message}</h1></body></html>"
|
|
92
|
+
).encode("utf-8")
|
|
93
|
+
self.send_response(status)
|
|
94
|
+
self.send_header("Content-Type", "text/html; charset=utf-8")
|
|
95
|
+
self.send_header("Content-Length", str(len(payload)))
|
|
96
|
+
self.end_headers()
|
|
97
|
+
self.wfile.write(payload)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _single(query: dict[str, list[str]], key: str) -> str:
|
|
101
|
+
values = query.get(key) or []
|
|
102
|
+
return values[0] if values else ""
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import json
|
|
5
|
+
import urllib.parse
|
|
6
|
+
import urllib.request
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Literal, Protocol
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field
|
|
12
|
+
|
|
13
|
+
from ame.core.paths import ame_home
|
|
14
|
+
from ame.security import token_vault
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Encoding of the token-request body: "json" sends a JSON document, "form"
# sends application/x-www-form-urlencoded (see UrlLibConnectedAppHttpClient.post_json).
TokenRequestStyle = Literal["form", "json"]
# Where client credentials go: "basic" uses an HTTP Basic Authorization header,
# "body" adds client_id/client_secret to the payload (see ConnectedAppOAuthClient.exchange_code).
ClientAuthStyle = Literal["body", "basic"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ConnectedAppOAuthError(RuntimeError):
    """Raised for connected-app OAuth failures.

    Covers an unsupported provider name, a failed or incomplete token exchange,
    and a missing stored token.
    """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OAuthProviderSpec(BaseModel):
    """Static description of one OAuth provider's endpoints and request conventions."""

    name: str  # canonical provider name; also names the token store file
    authorize_url: str  # browser authorization endpoint used by authorization_url()
    token_url: str  # endpoint exchange_code() posts the authorization code to
    default_redirect_uri: str  # used by default_config() when no redirect_uri is given
    default_scopes: list[str] = Field(default_factory=list)  # used by default_config() when scopes is None
    scope_separator: str = " "  # joins scopes into the authorization URL's "scope" parameter
    auth_extra: dict[str, str] = Field(default_factory=dict)  # extra query parameters for the authorization URL
    token_request_style: TokenRequestStyle = "form"  # JSON vs form-encoded token request body
    client_auth_style: ClientAuthStyle = "body"  # credentials in the body vs an HTTP Basic header
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ConnectedAppOAuthConfig(BaseModel):
    """OAuth client settings for one connected app provider."""

    provider: str  # provider name or alias; resolved with provider_spec() when no spec is supplied
    client_id: str = ""
    client_secret: str = ""  # sent in the token request body or a Basic header, per the provider spec
    redirect_uri: str  # sent both in the authorization URL and in the token request
    scopes: list[str] = Field(default_factory=list)  # requested scopes; fallback when the token response has no "scope"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ConnectedAppToken(BaseModel):
    """Token obtained from a connected app's code exchange, as persisted by ConnectedAppTokenStore."""

    provider: str  # canonical provider name
    account_id: str  # key under which the token is stored
    access_token: str
    refresh_token: str | None = None
    token_type: str = "Bearer"
    expires_in: int | None = None  # as reported by the provider; NOTE(review): presumably seconds, never converted to an absolute expiry
    scopes: list[str] = Field(default_factory=list)
    raw: dict[str, Any] = Field(default_factory=dict)  # token response minus access_token/refresh_token
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))  # timezone-aware UTC creation time
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ConnectedAppHttpClient(Protocol):
    """Structural interface for the HTTP transport used by ConnectedAppOAuthClient.

    Any object with a matching ``post_json`` can be passed as the client's
    ``http`` argument.
    """

    def post_json(
        self,
        url: str,
        data: dict[str, Any],
        headers: dict[str, str] | None = None,
        *,
        json_body: bool = False,
    ) -> dict[str, Any]:
        # POST *data* to *url* (JSON-encoded when json_body is true; the default
        # implementation form-encodes otherwise) and return the decoded JSON reply.
        ...
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class UrlLibConnectedAppHttpClient:
    """Default ConnectedAppHttpClient implementation built on ``urllib.request``."""

    def post_json(
        self,
        url: str,
        data: dict[str, Any],
        headers: dict[str, str] | None = None,
        *,
        json_body: bool = False,
    ) -> dict[str, Any]:
        """POST *data* to *url* and return the decoded JSON response.

        *data* is sent as a JSON document when *json_body* is true, otherwise as
        ``application/x-www-form-urlencoded``. Caller-supplied *headers* take
        precedence over the defaults set here. The request times out after 30 s.

        OAuth token endpoints report failures such as ``invalid_grant`` as an
        HTTP 4xx whose body is a JSON object with an ``error`` field (RFC 6749
        section 5.2). ``urlopen`` raises HTTPError for those, which would hide
        the reason from the caller. So when the error body is a JSON object
        containing ``error``, it is returned for the caller to inspect. Any
        other HTTP error is re-raised unchanged.
        """
        import urllib.error

        request_headers = dict(headers or {})
        if json_body:
            body = json.dumps(data).encode("utf-8")
            request_headers.setdefault("Content-Type", "application/json")
        else:
            body = urllib.parse.urlencode(data).encode("utf-8")
            request_headers.setdefault("Content-Type", "application/x-www-form-urlencoded")
        request_headers.setdefault("Accept", "application/json")
        request = urllib.request.Request(url, data=body, headers=request_headers, method="POST")
        try:
            with urllib.request.urlopen(request, timeout=30) as response:
                return json.loads(response.read().decode("utf-8"))
        except urllib.error.HTTPError as err:
            try:
                error_payload = json.loads(err.read().decode("utf-8"))
            except (ValueError, OSError):
                error_payload = None
            finally:
                err.close()
            if isinstance(error_payload, dict) and "error" in error_payload:
                return error_payload
            raise
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
# Built-in provider definitions, keyed by canonical name (the keys that
# provider_spec() resolves names and aliases to).
PROVIDER_SPECS: dict[str, OAuthProviderSpec] = {
    # Form-encoded token request with credentials in the body (the spec defaults).
    "github": OAuthProviderSpec(
        name="github",
        authorize_url="https://github.com/login/oauth/authorize",
        token_url="https://github.com/login/oauth/access_token",
        default_redirect_uri="http://localhost:8765/github/oauth/callback",
        default_scopes=["read:user"],
    ),
    # JSON token request with credentials in an HTTP Basic header; "owner" is
    # added to the authorization URL.
    "notion": OAuthProviderSpec(
        name="notion",
        authorize_url="https://api.notion.com/v1/oauth/authorize",
        token_url="https://api.notion.com/v1/oauth/token",
        default_redirect_uri="http://localhost:8765/notion/oauth/callback",
        default_scopes=[],
        auth_extra={"owner": "user"},
        token_request_style="json",
        client_auth_style="basic",
    ),
    # Atlassian (Jira): JSON token request, credentials in the body;
    # "audience" and "prompt" are added to the authorization URL.
    "jira": OAuthProviderSpec(
        name="jira",
        authorize_url="https://auth.atlassian.com/authorize",
        token_url="https://auth.atlassian.com/oauth/token",
        default_redirect_uri="http://localhost:8765/jira/oauth/callback",
        default_scopes=["read:jira-work", "read:jira-user", "offline_access"],
        auth_extra={"audience": "api.atlassian.com", "prompt": "consent"},
        token_request_style="json",
    ),
}
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def provider_spec(provider: str) -> OAuthProviderSpec:
    """Resolve a provider name or alias to its built-in OAuthProviderSpec.

    Matching is case-insensitive and treats ``_`` like ``-``. ``atlassian`` and
    ``jira-oauth`` map to ``jira``; ``github-oauth`` and ``notion-oauth`` map to
    their base names.

    Raises:
        ConnectedAppOAuthError: if the resolved name is not in ``PROVIDER_SPECS``.
    """
    canonical_names = {
        "atlassian": "jira",
        "github-oauth": "github",
        "notion-oauth": "notion",
        "jira-oauth": "jira",
    }
    key = provider.casefold().replace("_", "-")
    spec = PROVIDER_SPECS.get(canonical_names.get(key, key))
    if spec is None:
        raise ConnectedAppOAuthError(f"Unsupported connected app provider: {provider}")
    return spec
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class ConnectedAppOAuthClient:
    """Authorization-code OAuth client for one connected app provider."""

    def __init__(self, config: ConnectedAppOAuthConfig, spec: OAuthProviderSpec | None = None, http: ConnectedAppHttpClient | None = None):
        """Bind *config* to a provider spec and an HTTP transport.

        *spec* defaults to ``provider_spec(config.provider)`` and *http* to a
        UrlLibConnectedAppHttpClient.
        """
        self.config = config
        self.spec = spec or provider_spec(config.provider)
        self.http = http or UrlLibConnectedAppHttpClient()

    def authorization_url(self, state: str) -> str:
        """Build the browser authorization URL carrying *state*.

        Includes client id, redirect URI, ``response_type=code`` and the spec's
        ``auth_extra`` parameters. ``scope`` is added only when scopes are
        configured, joined with the spec's separator.
        """
        params = {
            "client_id": self.config.client_id,
            "redirect_uri": self.config.redirect_uri,
            "response_type": "code",
            "state": state,
            **self.spec.auth_extra,
        }
        if self.config.scopes:
            params["scope"] = self.spec.scope_separator.join(self.config.scopes)
        return f"{self.spec.authorize_url}?{urllib.parse.urlencode(params)}"

    def exchange_code(self, code: str, account_id: str = "default") -> ConnectedAppToken:
        """Exchange an authorization *code* at the provider's token endpoint.

        Client credentials go in an HTTP Basic header or in the payload, per
        ``spec.client_auth_style``. The payload is JSON or form-encoded, per
        ``spec.token_request_style``. The returned token's ``raw`` field holds
        the response without ``access_token``/``refresh_token``.

        Raises:
            ConnectedAppOAuthError: if the response has an ``error`` (its
                ``error_description``, when present, is included in the message)
                or lacks ``access_token``.
        """
        headers: dict[str, str] = {}
        payload = {
            "code": code,
            "redirect_uri": self.config.redirect_uri,
            "grant_type": "authorization_code",
        }
        if self.spec.client_auth_style == "basic":
            auth = f"{self.config.client_id}:{self.config.client_secret}".encode("utf-8")
            headers["Authorization"] = f"Basic {base64.b64encode(auth).decode('ascii')}"
        else:
            payload["client_id"] = self.config.client_id
            payload["client_secret"] = self.config.client_secret

        response = self.http.post_json(
            self.spec.token_url,
            payload,
            headers=headers,
            json_body=self.spec.token_request_style == "json",
        )
        if error := response.get("error"):
            # Providers put the human-readable reason in error_description.
            description = response.get("error_description")
            detail = f"{error} ({description})" if description else f"{error}"
            raise ConnectedAppOAuthError(f"{self.spec.name} OAuth exchange failed: {detail}")
        access_token = str(response.get("access_token") or "")
        if not access_token:
            raise ConnectedAppOAuthError(f"{self.spec.name} OAuth response did not include access_token")
        return ConnectedAppToken(
            provider=self.spec.name,
            account_id=account_id,
            access_token=access_token,
            refresh_token=response.get("refresh_token"),
            token_type=str(response.get("token_type") or "Bearer"),
            expires_in=response.get("expires_in"),
            scopes=_split_scopes(response.get("scope")) or list(self.config.scopes),
            raw={key: value for key, value in response.items() if key not in {"access_token", "refresh_token"}},
        )
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class ConnectedAppTokenStore:
    """Persist ConnectedAppToken records for one provider, keyed by account id.

    All records for the provider live in a single vault payload (a dict of
    account id -> serialized token). The default location is
    ``<ame_home>/tokens/<provider>.json``.
    """

    def __init__(self, provider: str, path: Path | None = None, backend: str = "file"):
        self.provider = provider_spec(provider).name
        self.path = path or ame_home() / "tokens" / f"{self.provider}.json"
        self.vault = token_vault(self.provider, self.path, backend=backend)  # type: ignore[arg-type]

    def save(self, token: ConnectedAppToken) -> ConnectedAppToken:
        """Store *token* under its account id, replacing any previous entry, and return it."""
        records = self._read()
        records[token.account_id] = token.model_dump(mode="json")
        self.vault.save(records)
        return token

    def load(self, account_id: str = "default") -> ConnectedAppToken:
        """Return the stored token for *account_id*.

        Raises:
            ConnectedAppOAuthError: if no token record exists for that account.
        """
        stored = self._read().get(account_id)
        if isinstance(stored, dict):
            return ConnectedAppToken.model_validate(stored)
        raise ConnectedAppOAuthError(f"{self.provider} token not found for account_id={account_id}")

    def revoke(self, account_id: str = "default") -> bool:
        """Forget the token for *account_id*; return whether one was stored.

        Deletes the vault entirely once no accounts remain; otherwise the
        remaining records are written back.
        """
        records = self._read()
        was_present = account_id in records
        records.pop(account_id, None)
        if not records:
            self.vault.delete()
        else:
            self.vault.save(records)
        return was_present

    def _read(self) -> dict[str, Any]:
        """Return the vault's current account-id -> token-record mapping."""
        return self.vault.load()
|
|
214
|
+
|
|
215
|
+
def _read(self) -> dict[str, Any]:
|
|
216
|
+
return self.vault.load()
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def exchange_and_save_connected_app_token(
    provider: str,
    code: str,
    config: ConnectedAppOAuthConfig,
    *,
    account_id: str = "default",
    store_path: Path | None = None,
    token_backend: str = "file",
    http: ConnectedAppHttpClient | None = None,
) -> ConnectedAppToken:
    """Exchange an authorization *code* for a token, persist it, and return it.

    The token is saved under *account_id* in the provider's
    ConnectedAppTokenStore, at *store_path* with *token_backend*.
    """
    resolved = provider_spec(provider)
    oauth_client = ConnectedAppOAuthClient(config, spec=resolved, http=http)
    token = oauth_client.exchange_code(code, account_id=account_id)
    token_store = ConnectedAppTokenStore(resolved.name, store_path, backend=token_backend)
    return token_store.save(token)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def default_config(
    provider: str,
    *,
    client_id: str = "",
    client_secret: str = "",
    redirect_uri: str | None = None,
    scopes: list[str] | None = None,
) -> ConnectedAppOAuthConfig:
    """Build a ConnectedAppOAuthConfig for *provider*, filling gaps from its spec.

    An empty or missing *redirect_uri* falls back to the spec's default redirect.
    *scopes* falls back to the spec's defaults only when it is ``None`` (an
    explicit empty list is kept). The scope list is always copied.
    """
    spec = provider_spec(provider)
    requested_scopes = spec.default_scopes if scopes is None else scopes
    return ConnectedAppOAuthConfig(
        provider=spec.name,
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=redirect_uri or spec.default_redirect_uri,
        scopes=list(requested_scopes),
    )
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def _split_scopes(value: Any) -> list[str]:
|
|
246
|
+
if not value:
|
|
247
|
+
return []
|
|
248
|
+
if isinstance(value, list):
|
|
249
|
+
return [str(item) for item in value if str(item).strip()]
|
|
250
|
+
return [item for item in str(value).replace(",", " ").split() if item]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
from ame.connectors.markdown import MarkdownConnector
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# [[target]] or [[target|alias]]; group 1 is everything between the brackets.
WIKILINK_RE = re.compile(r"\[\[([^\]]+)\]\]")
# Inline "#tag" not glued to a preceding word character. Group 1 is the tag
# name (ASCII letters, digits, "_", "/", "-"). The lookahead requires at least
# one non-digit, following Obsidian's rule that a purely numeric "#1984" is not
# a tag, so issue references such as "#123" are no longer collected.
TAG_RE = re.compile(r"(?<!\w)#(?=[A-Za-z0-9_/-]*[A-Za-z_/-])([A-Za-z0-9_/-]+)")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ObsidianConnector(MarkdownConnector):
    """Markdown connector for Obsidian vaults.

    Extends MarkdownConnector's metadata with ``wikilinks`` and ``tags``.
    ``wikilinks`` holds the sorted, de-duplicated ``[[target]]`` targets, with
    any ``|alias`` removed. ``tags`` holds the sorted, de-duplicated ``#tag``
    names.
    """

    source_type = "obsidian"

    def _metadata(self, content: str, path):
        metadata = super()._metadata(content, path)
        link_targets = set()
        for link in WIKILINK_RE.finditer(content):
            target, _separator, _alias = link.group(1).partition("|")
            link_targets.add(target.strip())
        metadata["wikilinks"] = sorted(link_targets)
        metadata["tags"] = sorted({tag.group(1) for tag in TAG_RE.finditer(content)})
        return metadata
|