cyphersmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cyphersmith/__init__.py +16 -0
- cyphersmith/__main__.py +4 -0
- cyphersmith/cli.py +303 -0
- cyphersmith/context.py +56 -0
- cyphersmith/generator.py +373 -0
- cyphersmith/llm.py +162 -0
- cyphersmith/models.py +108 -0
- cyphersmith/neo4j_client.py +414 -0
- cyphersmith/progress.py +66 -0
- cyphersmith/prompts.py +129 -0
- cyphersmith/safety.py +54 -0
- cyphersmith/validation.py +167 -0
- cyphersmith-0.1.0.dist-info/METADATA +135 -0
- cyphersmith-0.1.0.dist-info/RECORD +18 -0
- cyphersmith-0.1.0.dist-info/WHEEL +5 -0
- cyphersmith-0.1.0.dist-info/entry_points.txt +2 -0
- cyphersmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- cyphersmith-0.1.0.dist-info/top_level.txt +1 -0
cyphersmith/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Public API for the cyphersmith package."""
|
|
2
|
+
|
|
3
|
+
from .generator import CypherGenerator
|
|
4
|
+
from .models import CypherGenerationResult, LLMConfig, Neo4jCredentials
|
|
5
|
+
from .neo4j_client import Neo4jClient, get_neo4j_client
|
|
6
|
+
from .progress import TerminalProgressReporter
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"CypherGenerationResult",
|
|
10
|
+
"CypherGenerator",
|
|
11
|
+
"LLMConfig",
|
|
12
|
+
"Neo4jClient",
|
|
13
|
+
"Neo4jCredentials",
|
|
14
|
+
"TerminalProgressReporter",
|
|
15
|
+
"get_neo4j_client",
|
|
16
|
+
]
|
cyphersmith/__main__.py
ADDED
cyphersmith/cli.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import getpass
|
|
5
|
+
import json
|
|
6
|
+
import sys
|
|
7
|
+
from typing import Sequence
|
|
8
|
+
|
|
9
|
+
from .generator import CypherGenerator
|
|
10
|
+
from .models import LLMConfig, Neo4jCredentials
|
|
11
|
+
from .progress import TerminalProgressReporter
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def main(argv: Sequence[str] | None = None) -> int:
    """Parse *argv* and run the selected sub-command.

    Returns the handler's exit code, or 2 after printing the help text when
    no known sub-command was selected.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # "setup" and "interactive" are aliases of "chat".
    handlers = {
        "ask": _run_ask,
        "chat": _run_chat,
        "setup": _run_chat,
        "interactive": _run_chat,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        return 2
    return handler(args)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with its ``ask`` and ``chat`` sub-commands.

    ``chat`` is also registered under the ``setup`` and ``interactive``
    aliases. Each sub-parser sets a ``command`` default on the namespace.
    """

    def add_progress_flags(target: argparse.ArgumentParser) -> None:
        # Both sub-commands expose the same progress/colour switches.
        target.add_argument(
            "--no-progress",
            action="store_true",
            help="Disable intermediate progress output.",
        )
        target.add_argument(
            "--no-color",
            action="store_true",
            help="Disable ANSI colors in progress output.",
        )

    root = argparse.ArgumentParser(
        prog="cypher-generator",
        description="Generate, validate, and execute read-only Cypher against Neo4j.",
    )
    commands = root.add_subparsers(dest="command")

    # --- ask: one-shot question -------------------------------------------
    ask_parser = commands.add_parser(
        "ask", help="Ask a natural-language graph question."
    )
    ask_parser.add_argument("query", help="Natural-language graph question.")
    ask_options = (
        ("--neo4j-uri", {"dest": "neo4j_uri"}),
        ("--neo4j-user", {"dest": "neo4j_user"}),
        ("--neo4j-password", {"dest": "neo4j_password"}),
        ("--neo4j-database", {"dest": "neo4j_database", "default": None}),
        ("--model", {"required": True, "help": "LiteLLM model name."}),
        ("--temperature", {"type": float, "default": 0}),
        ("--timeout", {"type": float, "default": None}),
        ("--api-key", {"dest": "api_key", "default": None}),
        ("--api-base", {"dest": "api_base", "default": None}),
        ("--api-version", {"dest": "api_version", "default": None}),
        ("--business-context", {"dest": "business_context", "default": None}),
        ("--max-validation-retries", {"type": int, "default": 2}),
    )
    for flag, settings in ask_options:
        ask_parser.add_argument(flag, **settings)
    add_progress_flags(ask_parser)
    ask_parser.add_argument(
        "--pretty", action="store_true", help="Pretty-print JSON output."
    )
    ask_parser.set_defaults(command="ask")

    # --- chat: interactive setup followed by a question loop ---------------
    chat_parser = commands.add_parser(
        "chat",
        aliases=["setup", "interactive"],
        help="Prompt for setup once, then ask questions in a loop.",
    )
    chat_parser.add_argument(
        "--pretty", action="store_true", help="Pretty-print result JSON."
    )
    chat_parser.add_argument("--max-validation-retries", type=int, default=2)
    add_progress_flags(chat_parser)
    chat_parser.set_defaults(command="chat")

    return root
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _run_ask(args: argparse.Namespace) -> int:
    """Answer the single question in *args* and print the result as JSON.

    Returns 1 when the result carries an error, otherwise 0.
    """
    reporter = _build_progress_reporter(args)
    credentials = Neo4jCredentials(
        uri=args.neo4j_uri,
        username=args.neo4j_user,
        password=args.neo4j_password,
        database=args.neo4j_database,
    )
    llm_settings = LLMConfig(
        model=args.model,
        temperature=args.temperature,
        timeout=args.timeout,
        api_key=args.api_key,
        api_base=args.api_base,
        api_version=args.api_version,
    )
    generator = CypherGenerator(
        neo4j=credentials,
        llm=llm_settings,
        business_context_path=args.business_context,
        max_validation_retries=args.max_validation_retries,
        progress_reporter=reporter,
    )

    outcome = generator.ask(args.query)
    indent = 2 if args.pretty else None
    serialized = json.dumps(
        outcome.model_dump(mode="json"),
        indent=indent,
        ensure_ascii=False,
    )
    print(serialized)

    if outcome.error:
        return 1
    return 0
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _run_chat(args: argparse.Namespace) -> int:
    """Prompt for LLM/Neo4j settings once, then answer questions in a loop.

    The loop ends with exit code 0 on an exit word or when input is closed
    or interrupted (EOF / Ctrl-C). Blank questions are ignored.
    """
    print("Cypher Generator interactive setup")
    print("Press Enter to accept defaults. Type /exit in the question loop to stop.")

    # Prompt order is part of the user experience: LLM, Neo4j, then context.
    llm_settings = _prompt_llm_config()
    credentials = _prompt_neo4j_credentials()
    context_path = _prompt_optional(
        "Business context file path (.txt/.md/.yaml/.json), blank for none",
        default="",
    )

    generator = CypherGenerator(
        neo4j=credentials,
        llm=llm_settings,
        business_context_path=context_path or None,
        max_validation_retries=args.max_validation_retries,
        progress_reporter=_build_progress_reporter(args),
    )

    exit_words = {"/exit", "/quit", "exit", "quit", "q"}
    print("\nSetup complete. Ask a question, or type /exit to quit.")
    while True:
        try:
            raw = input("\nQuestion> ")
        except (EOFError, KeyboardInterrupt):
            print()
            return 0

        question = raw.strip()
        if question.lower() in exit_words:
            return 0
        if question:
            answer = generator.ask(question)
            _print_interactive_result(answer, pretty=args.pretty)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _prompt_llm_config() -> LLMConfig:
    """Interactively choose an LLM provider and collect its settings.

    Prints a numbered provider menu, asks for a selection (default 1), then
    prompts for the values that provider needs. Temperature is always 0.
    """
    menu = (
        ("OpenAI", "openai"),
        ("Azure OpenAI", "azure"),
        ("Anthropic", "anthropic"),
        ("Google Gemini", "gemini"),
        ("Groq", "groq"),
        ("Custom LiteLLM model string", "custom"),
    )
    # Providers that only need a model name and a mandatory API key:
    # slug -> (default model, API-key prompt).
    key_only = {
        "openai": ("openai/gpt-5", "OpenAI API key"),
        "anthropic": ("anthropic/claude-3-5-sonnet-latest", "Anthropic API key"),
        "gemini": ("gemini/gemini-2.5-pro", "Gemini API key"),
        "groq": ("groq/llama-3.3-70b-versatile", "Groq API key"),
    }

    print("\nLLM provider")
    for position, entry in enumerate(menu, start=1):
        print(f" {position}. {entry[0]}")

    selection = _prompt_choice("Select provider", default=1, maximum=len(menu))
    slug = menu[selection - 1][1]

    if slug in key_only:
        default_model, key_label = key_only[slug]
        chosen_model = _prompt_optional("Model", default=default_model)
        secret = _prompt_secret(key_label, required=True)
        return LLMConfig(model=chosen_model, api_key=secret, temperature=0)

    if slug == "azure":
        deployment = _prompt_required("Azure deployment name, without azure/ prefix")
        secret = _prompt_secret("Azure OpenAI API key", required=True)
        endpoint = _prompt_required(
            "Azure API base, e.g. https://your-resource.openai.azure.com"
        )
        version = _prompt_optional("Azure API version", default="2024-10-21")
        return LLMConfig(
            model=f"azure/{deployment}",
            api_key=secret,
            api_base=endpoint,
            api_version=version,
            temperature=0,
        )

    # Custom provider: everything except the model string is optional, and
    # blank answers are passed on as None.
    chosen_model = _prompt_required(
        "LiteLLM model string, e.g. openai/gpt-5 or azure/my-deployment"
    )
    secret = _prompt_secret("API key, blank to use provider environment variables")
    endpoint = _prompt_optional("API base, blank unless provider needs it", default="")
    version = _prompt_optional(
        "API version, blank unless provider needs it",
        default="",
    )
    return LLMConfig(
        model=chosen_model,
        api_key=secret or None,
        api_base=endpoint or None,
        api_version=version or None,
        temperature=0,
    )
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _prompt_neo4j_credentials() -> Neo4jCredentials:
    """Interactively collect the Neo4j URI, username, password and database.

    Every field except the password has a default; the password is read
    through the secret prompt and is required.
    """
    print("\nNeo4j connection")
    # The dict literal is evaluated top to bottom, which fixes the prompt order.
    answers = {
        "uri": _prompt_optional("Neo4j URI", default="bolt://localhost:7687"),
        "username": _prompt_optional("Neo4j username", default="neo4j"),
        "password": _prompt_secret("Neo4j password", required=True),
        "database": _prompt_optional("Neo4j database", default="neo4j"),
    }
    return Neo4jCredentials(**answers)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _prompt_choice(prompt: str, *, default: int, maximum: int) -> int:
    """Ask for a menu number between 1 and *maximum* until one is given.

    A blank answer returns *default*. Non-numeric or out-of-range answers
    print a hint on stderr and the prompt is repeated.
    """
    while True:
        answer = input(f"{prompt} [{default}]: ").strip()
        if not answer:
            return default

        try:
            number = int(answer)
        except ValueError:
            number = None

        if number is not None and 1 <= number <= maximum:
            return number
        # One shared rejection path for "not a number" and "out of range".
        print(f"Enter a number from 1 to {maximum}.", file=sys.stderr)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _prompt_optional(prompt: str, *, default: str = "") -> str:
    """Ask once and return the stripped answer, or *default* when it is blank.

    A non-empty default is shown in square brackets after the prompt.
    """
    if default:
        label = f"{prompt} [{default}]: "
    else:
        label = f"{prompt}: "
    answer = input(label).strip()
    return answer if answer else default
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def _prompt_required(prompt: str) -> str:
    """Repeat the prompt until a non-blank answer is given and return it stripped.

    Each blank answer prints a reminder on stderr.
    """
    while not (answer := input(f"{prompt}: ").strip()):
        print("This value is required.", file=sys.stderr)
    return answer
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def _prompt_secret(prompt: str, *, required: bool = False) -> str:
    """Read a value through ``getpass`` and return it stripped.

    When *required* is true, blank answers print a reminder on stderr and the
    prompt is repeated; otherwise a blank answer is returned as ``""``.
    """
    while True:
        secret = getpass.getpass(f"{prompt}: ").strip()
        if required and not secret:
            print("This value is required.", file=sys.stderr)
            continue
        return secret
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _print_interactive_result(result: object, *, pretty: bool) -> None:
    """Print one interactive answer: Cypher, records, validation and attempts.

    Attributes are read with ``getattr`` so missing or falsy ones fall back to
    placeholders. The error, if any, goes to stderr; the rest goes to stdout.
    """
    cypher_text = getattr(result, "cypher", "") or "(none)"
    json_sections = (
        ("\nResults:", getattr(result, "records", []) or []),
        ("\nValidation:", getattr(result, "validation", {}) or {}),
    )
    failure = getattr(result, "error", None)
    attempt_count = getattr(result, "attempts", 0)
    indent = 2 if pretty else None

    print("\nCypher:")
    print(cypher_text)

    for heading, payload in json_sections:
        print(heading)
        print(json.dumps(payload, indent=indent, ensure_ascii=False))

    print(f"\nAttempts: {attempt_count}")
    if failure:
        print(f"Error: {failure}", file=sys.stderr)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _build_progress_reporter(
    args: argparse.Namespace,
) -> TerminalProgressReporter | None:
    """Return a progress reporter for *args*, or ``None`` when ``no_progress`` is set.

    Colour is enabled unless ``no_color`` is set. Both flags are read with
    ``getattr`` and treated as false when absent from the namespace.
    """
    progress_disabled = getattr(args, "no_progress", False)
    if progress_disabled:
        return None

    color_disabled = getattr(args, "no_color", False)
    return TerminalProgressReporter(use_color=not color_disabled)
|
cyphersmith/context.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
SUPPORTED_CONTEXT_EXTENSIONS = {".txt", ".md", ".yaml", ".yml", ".json"}


def load_business_context(path: str | Path | None) -> str:
    """Load optional business context from a text, markdown, YAML, or JSON file.

    Args:
        path: Location of the context file, or ``None`` for "no context".
            A leading ``~`` is expanded.

    Returns:
        The context as text: JSON is re-serialised with two-space indentation,
        YAML is re-dumped in block style, and other files are returned
        stripped. An empty string is returned when *path* is ``None``, the
        file is blank, or the document is empty (YAML or JSON ``null``).

    Raises:
        ValueError: The extension is unsupported, the file is not valid
            UTF-8, or the JSON/YAML content cannot be parsed.
        FileNotFoundError: The file does not exist.
        OSError: The file could not be read for another reason.
    """
    if path is None:
        return ""

    context_path = Path(path).expanduser()
    suffix = context_path.suffix.lower()
    if suffix not in SUPPORTED_CONTEXT_EXTENSIONS:
        allowed = ", ".join(sorted(SUPPORTED_CONTEXT_EXTENSIONS))
        raise ValueError(
            f"Unsupported business context extension '{suffix}'. "
            f"Expected one of: {allowed}."
        )

    try:
        # "utf-8-sig" decodes plain UTF-8 and also drops a leading BOM, which
        # json.loads would otherwise reject.
        text = context_path.read_text(encoding="utf-8-sig")
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Business context file not found: {context_path}") from exc
    except UnicodeDecodeError as exc:
        # UnicodeDecodeError is a ValueError subclass, so callers that catch
        # ValueError still work; the message now names the offending file.
        raise ValueError(
            f"Business context file {context_path} is not valid UTF-8: {exc}"
        ) from exc
    except OSError as exc:
        raise OSError(f"Failed to read business context file {context_path}: {exc}") from exc

    if not text.strip():
        return ""

    if suffix == ".json":
        try:
            data = json.loads(text)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON business context at {context_path}: {exc}") from exc
        if data is None:
            # Treat JSON null like an empty YAML document: no context.
            return ""
        return json.dumps(data, indent=2, ensure_ascii=False)

    if suffix in {".yaml", ".yml"}:
        try:
            data = yaml.safe_load(text)
        except yaml.YAMLError as exc:
            raise ValueError(f"Invalid YAML business context at {context_path}: {exc}") from exc
        if data is None:
            return ""
        return yaml.safe_dump(
            data,
            sort_keys=False,
            allow_unicode=True,
            default_flow_style=False,
        ).strip()

    return text.strip()
|