cyphersmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,16 @@
1
+ """Public API for the cyphersmith package."""
2
+
3
+ from .generator import CypherGenerator
4
+ from .models import CypherGenerationResult, LLMConfig, Neo4jCredentials
5
+ from .neo4j_client import Neo4jClient, get_neo4j_client
6
+ from .progress import TerminalProgressReporter
7
+
8
# Public names re-exported at package level; ``from cyphersmith import *``
# binds exactly these.
__all__ = [
    "CypherGenerationResult", "CypherGenerator", "LLMConfig",
    "Neo4jClient", "Neo4jCredentials", "TerminalProgressReporter",
    "get_neo4j_client",
]
@@ -0,0 +1,4 @@
1
+ from .cli import main
2
+
3
# Support ``python -m cyphersmith``: run the CLI and use its integer return
# value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)
cyphersmith/cli.py ADDED
@@ -0,0 +1,303 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import getpass
5
+ import json
6
+ import sys
7
+ from typing import Sequence
8
+
9
+ from .generator import CypherGenerator
10
+ from .models import LLMConfig, Neo4jCredentials
11
+ from .progress import TerminalProgressReporter
12
+
13
+
14
def main(argv: Sequence[str] | None = None) -> int:
    """Entry point for the ``cypher-generator`` command line.

    Parses *argv* (``sys.argv[1:]`` when ``None``) and dispatches to the
    selected subcommand handler.

    Returns:
        The handler's exit status, or ``2`` after printing help when no
        subcommand was given.
    """
    parser = _build_parser()
    namespace = parser.parse_args(argv)

    # "setup" and "interactive" are aliases of "chat"; every alias name is
    # accepted in case argparse reports the alias instead of the "chat" default.
    handlers = {
        "ask": _run_ask,
        "chat": _run_chat,
        "setup": _run_chat,
        "interactive": _run_chat,
    }
    handler = handlers.get(namespace.command)
    if handler is None:
        parser.print_help()
        return 2
    return handler(namespace)
25
+
26
+
27
+ def _build_parser() -> argparse.ArgumentParser:
28
+ parser = argparse.ArgumentParser(
29
+ prog="cypher-generator",
30
+ description="Generate, validate, and execute read-only Cypher against Neo4j.",
31
+ )
32
+ subparsers = parser.add_subparsers(dest="command")
33
+
34
+ ask = subparsers.add_parser("ask", help="Ask a natural-language graph question.")
35
+ ask.add_argument("query", help="Natural-language graph question.")
36
+ ask.add_argument("--neo4j-uri", dest="neo4j_uri")
37
+ ask.add_argument("--neo4j-user", dest="neo4j_user")
38
+ ask.add_argument("--neo4j-password", dest="neo4j_password")
39
+ ask.add_argument("--neo4j-database", dest="neo4j_database", default=None)
40
+ ask.add_argument("--model", required=True, help="LiteLLM model name.")
41
+ ask.add_argument("--temperature", type=float, default=0)
42
+ ask.add_argument("--timeout", type=float, default=None)
43
+ ask.add_argument("--api-key", dest="api_key", default=None)
44
+ ask.add_argument("--api-base", dest="api_base", default=None)
45
+ ask.add_argument("--api-version", dest="api_version", default=None)
46
+ ask.add_argument("--business-context", dest="business_context", default=None)
47
+ ask.add_argument("--max-validation-retries", type=int, default=2)
48
+ ask.add_argument(
49
+ "--no-progress",
50
+ action="store_true",
51
+ help="Disable intermediate progress output.",
52
+ )
53
+ ask.add_argument(
54
+ "--no-color",
55
+ action="store_true",
56
+ help="Disable ANSI colors in progress output.",
57
+ )
58
+ ask.add_argument("--pretty", action="store_true", help="Pretty-print JSON output.")
59
+ ask.set_defaults(command="ask")
60
+
61
+ chat = subparsers.add_parser(
62
+ "chat",
63
+ aliases=["setup", "interactive"],
64
+ help="Prompt for setup once, then ask questions in a loop.",
65
+ )
66
+ chat.add_argument("--pretty", action="store_true", help="Pretty-print result JSON.")
67
+ chat.add_argument("--max-validation-retries", type=int, default=2)
68
+ chat.add_argument(
69
+ "--no-progress",
70
+ action="store_true",
71
+ help="Disable intermediate progress output.",
72
+ )
73
+ chat.add_argument(
74
+ "--no-color",
75
+ action="store_true",
76
+ help="Disable ANSI colors in progress output.",
77
+ )
78
+ chat.set_defaults(command="chat")
79
+ return parser
80
+
81
+
82
def _run_ask(args: argparse.Namespace) -> int:
    """Handle the ``ask`` subcommand: answer one question and print JSON.

    Builds a ``CypherGenerator`` from the parsed flags, prints the result's
    JSON-mode dump to stdout (2-space indented with ``--pretty``), and
    returns ``1`` when the result carries an error, otherwise ``0``.
    """
    reporter = _build_progress_reporter(args)
    credentials = Neo4jCredentials(
        uri=args.neo4j_uri,
        username=args.neo4j_user,
        password=args.neo4j_password,
        database=args.neo4j_database,
    )
    llm_settings = LLMConfig(
        model=args.model,
        temperature=args.temperature,
        timeout=args.timeout,
        api_key=args.api_key,
        api_base=args.api_base,
        api_version=args.api_version,
    )
    generator = CypherGenerator(
        neo4j=credentials,
        llm=llm_settings,
        business_context_path=args.business_context,
        max_validation_retries=args.max_validation_retries,
        progress_reporter=reporter,
    )

    result = generator.ask(args.query)
    serialized = result.model_dump(mode="json")
    indent = 2 if args.pretty else None
    print(json.dumps(serialized, indent=indent, ensure_ascii=False))
    return 0 if not result.error else 1
113
+
114
+
115
def _run_chat(args: argparse.Namespace) -> int:
    """Run the interactive ``chat`` / ``setup`` / ``interactive`` session.

    Prompts once for the LLM provider, the Neo4j connection and an optional
    business-context file, builds a ``CypherGenerator``, then answers
    questions read from stdin. The session ends on one of the exit words
    (``/exit``, ``/quit``, ``exit``, ``quit``, ``q``) or on EOF/Ctrl-C at the
    question prompt.

    A failure while answering one question (including Ctrl-C during a slow
    call) is reported on stderr and the loop continues. The configuration the
    user typed is kept.

    Returns:
        ``0`` when the question loop ends normally.
        ``1`` when setup is abandoned (EOF/Ctrl-C during the prompts) or the
        generator cannot be built from the answers.
    """
    print("Cypher Generator interactive setup")
    print("Press Enter to accept defaults. Type /exit in the question loop to stop.")

    try:
        llm_config = _prompt_llm_config()
        neo4j_credentials = _prompt_neo4j_credentials()
        business_context = _prompt_optional(
            "Business context file path (.txt/.md/.yaml/.json), blank for none",
            default="",
        )
    except (EOFError, KeyboardInterrupt):
        # Leaving during setup is not a crash: end the open prompt line and
        # report it instead of dumping a traceback.
        print()
        print("Setup cancelled.", file=sys.stderr)
        return 1

    try:
        generator = CypherGenerator(
            neo4j=neo4j_credentials,
            llm=llm_config,
            business_context_path=business_context or None,
            max_validation_retries=args.max_validation_retries,
            progress_reporter=_build_progress_reporter(args),
        )
    except (ValueError, OSError) as exc:
        # NOTE(review): presumably the generator loads the business-context
        # file here; load_business_context reports bad paths/contents as
        # ValueError or OSError. Confirm.
        print(f"Setup failed: {exc}", file=sys.stderr)
        return 1

    exit_words = {"/exit", "/quit", "exit", "quit", "q"}
    print("\nSetup complete. Ask a question, or type /exit to quit.")
    while True:
        try:
            question = input("\nQuestion> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            return 0

        if question.lower() in exit_words:
            return 0
        if not question:
            continue

        try:
            result = generator.ask(question)
        except KeyboardInterrupt:
            # Ctrl-C during a slow LLM/Neo4j round trip cancels only this
            # question, REPL-style, instead of killing the whole session.
            print("\nQuestion cancelled.", file=sys.stderr)
            continue
        except Exception as exc:  # noqa: BLE001 - interactive top-level boundary
            # Report and keep the session (and its typed-in setup) alive.
            print(f"Error: {type(exc).__name__}: {exc}", file=sys.stderr)
            continue
        _print_interactive_result(result, pretty=args.pretty)
149
+
150
+
151
def _prompt_llm_config() -> LLMConfig:
    """Interactively pick an LLM provider and collect its settings.

    Shows a numbered provider menu (default 1, OpenAI) and prompts for that
    provider's model or deployment and API key.
    - Preset providers require a non-blank API key.
    - The custom option accepts any LiteLLM model string, a blank key, and an
      optional API base/version.
    Every returned config uses temperature 0.
    """
    # (menu label, internal provider key), in display order.
    menu = (
        ("OpenAI", "openai"),
        ("Azure OpenAI", "azure"),
        ("Anthropic", "anthropic"),
        ("Google Gemini", "gemini"),
        ("Groq", "groq"),
        ("Custom LiteLLM model string", "custom"),
    )
    # Providers needing only a model (with a default) plus an API key:
    # provider key -> (default model, API key prompt).
    simple_presets = {
        "openai": ("openai/gpt-5", "OpenAI API key"),
        "anthropic": ("anthropic/claude-3-5-sonnet-latest", "Anthropic API key"),
        "gemini": ("gemini/gemini-2.5-pro", "Gemini API key"),
        "groq": ("groq/llama-3.3-70b-versatile", "Groq API key"),
    }

    print("\nLLM provider")
    for number, entry in enumerate(menu, start=1):
        print(f" {number}. {entry[0]}")

    choice = _prompt_choice("Select provider", default=1, maximum=len(menu))
    provider = menu[choice - 1][1]

    if provider in simple_presets:
        default_model, key_prompt = simple_presets[provider]
        model = _prompt_optional("Model", default=default_model)
        api_key = _prompt_secret(key_prompt, required=True)
        return LLMConfig(model=model, api_key=api_key, temperature=0)

    if provider == "azure":
        deployment = _prompt_required("Azure deployment name, without azure/ prefix")
        api_key = _prompt_secret("Azure OpenAI API key", required=True)
        api_base = _prompt_required(
            "Azure API base, e.g. https://your-resource.openai.azure.com"
        )
        api_version = _prompt_optional("Azure API version", default="2024-10-21")
        return LLMConfig(
            model=f"azure/{deployment}",
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            temperature=0,
        )

    # "custom": blank answers become None so provider environment defaults apply.
    model = _prompt_required(
        "LiteLLM model string, e.g. openai/gpt-5 or azure/my-deployment"
    )
    api_key = _prompt_secret("API key, blank to use provider environment variables")
    api_base = _prompt_optional("API base, blank unless provider needs it", default="")
    api_version = _prompt_optional(
        "API version, blank unless provider needs it",
        default="",
    )
    return LLMConfig(
        model=model,
        api_key=api_key or None,
        api_base=api_base or None,
        api_version=api_version or None,
        temperature=0,
    )
224
+
225
+
226
def _prompt_neo4j_credentials() -> Neo4jCredentials:
    """Ask for Neo4j connection details and wrap them in ``Neo4jCredentials``.

    Blank answers for URI, username and database fall back to
    ``bolt://localhost:7687``, ``neo4j`` and ``neo4j``. The password is read
    via ``_prompt_secret`` with ``required=True``.
    """
    print("\nNeo4j connection")
    # Dict displays evaluate left to right, so the prompts appear in this order.
    answers = {
        "uri": _prompt_optional("Neo4j URI", default="bolt://localhost:7687"),
        "username": _prompt_optional("Neo4j username", default="neo4j"),
        "password": _prompt_secret("Neo4j password", required=True),
        "database": _prompt_optional("Neo4j database", default="neo4j"),
    }
    return Neo4jCredentials(**answers)
238
+
239
+
240
+ def _prompt_choice(prompt: str, *, default: int, maximum: int) -> int:
241
+ while True:
242
+ raw = input(f"{prompt} [{default}]: ").strip()
243
+ if not raw:
244
+ return default
245
+ try:
246
+ value = int(raw)
247
+ except ValueError:
248
+ print(f"Enter a number from 1 to {maximum}.", file=sys.stderr)
249
+ continue
250
+ if 1 <= value <= maximum:
251
+ return value
252
+ print(f"Enter a number from 1 to {maximum}.", file=sys.stderr)
253
+
254
+
255
+ def _prompt_optional(prompt: str, *, default: str = "") -> str:
256
+ suffix = f" [{default}]" if default else ""
257
+ value = input(f"{prompt}{suffix}: ").strip()
258
+ return value or default
259
+
260
+
261
+ def _prompt_required(prompt: str) -> str:
262
+ while True:
263
+ value = input(f"{prompt}: ").strip()
264
+ if value:
265
+ return value
266
+ print("This value is required.", file=sys.stderr)
267
+
268
+
269
def _prompt_secret(prompt: str, *, required: bool = False) -> str:
    """Read a value without echo using ``getpass.getpass``.

    Surrounding whitespace is stripped. With *required*, a blank answer prints
    a notice to stderr and the prompt repeats. Otherwise a blank answer is
    returned as ``""``.
    """
    # NOTE(review): the strip also removes intentional leading/trailing spaces
    # from passwords; confirm that is acceptable for every caller.
    label = f"{prompt}: "
    secret = getpass.getpass(label).strip()
    while required and not secret:
        print("This value is required.", file=sys.stderr)
        secret = getpass.getpass(label).strip()
    return secret
275
+
276
+
277
+ def _print_interactive_result(result: object, *, pretty: bool) -> None:
278
+ cypher = getattr(result, "cypher", "") or ""
279
+ records = getattr(result, "records", []) or []
280
+ validation = getattr(result, "validation", {}) or {}
281
+ error = getattr(result, "error", None)
282
+ attempts = getattr(result, "attempts", 0)
283
+
284
+ print("\nCypher:")
285
+ print(cypher or "(none)")
286
+
287
+ print("\nResults:")
288
+ print(json.dumps(records, indent=2 if pretty else None, ensure_ascii=False))
289
+
290
+ print("\nValidation:")
291
+ print(json.dumps(validation, indent=2 if pretty else None, ensure_ascii=False))
292
+
293
+ print(f"\nAttempts: {attempts}")
294
+ if error:
295
+ print(f"Error: {error}", file=sys.stderr)
296
+
297
+
298
+ def _build_progress_reporter(
299
+ args: argparse.Namespace,
300
+ ) -> TerminalProgressReporter | None:
301
+ if getattr(args, "no_progress", False):
302
+ return None
303
+ return TerminalProgressReporter(use_color=not getattr(args, "no_color", False))
cyphersmith/context.py ADDED
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import yaml
7
+
8
SUPPORTED_CONTEXT_EXTENSIONS = {".txt", ".md", ".yaml", ".yml", ".json"}


def load_business_context(path: str | Path | None) -> str:
    """Load optional business context from a text, markdown, YAML, or JSON file.

    Args:
        path: File to read (``~`` is expanded). ``None`` or a blank string
            means "no context".

    Returns:
        The context as text, with surrounding whitespace stripped:
        - JSON is re-serialized with 2-space indentation.
        - YAML is re-dumped in block style with key order preserved.
        - ``.txt`` / ``.md`` content is returned as-is.
        An empty file, or a YAML document that loads as ``None``, yields ``""``.

    Raises:
        ValueError: The extension is unsupported, the file is not valid UTF-8,
            or its JSON/YAML content does not parse.
        FileNotFoundError: The file does not exist.
        OSError: The file exists but cannot be read.
    """
    # A blank path (e.g. an empty ``--business-context ""``) means no context
    # rather than an "unsupported extension ''" error.
    if path is None or (isinstance(path, str) and not path.strip()):
        return ""

    context_path = Path(path).expanduser()
    suffix = context_path.suffix.lower()
    if suffix not in SUPPORTED_CONTEXT_EXTENSIONS:
        allowed = ", ".join(sorted(SUPPORTED_CONTEXT_EXTENSIONS))
        raise ValueError(
            f"Unsupported business context extension '{suffix}'. "
            f"Expected one of: {allowed}."
        )

    try:
        # "utf-8-sig" decodes plain UTF-8 identically but drops a leading byte
        # order mark. Some editors add one; otherwise json.loads rejects the
        # file and "\ufeff" (not stripped by str.strip) leaks into text context.
        text = context_path.read_text(encoding="utf-8-sig")
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Business context file not found: {context_path}") from exc
    except UnicodeDecodeError as exc:
        # Name the offending file, like the JSON/YAML parse errors below.
        raise ValueError(
            f"Business context file {context_path} is not valid UTF-8: {exc}"
        ) from exc
    except OSError as exc:
        raise OSError(f"Failed to read business context file {context_path}: {exc}") from exc

    if not text.strip():
        return ""

    if suffix == ".json":
        try:
            data = json.loads(text)
        except json.JSONDecodeError as exc:
            raise ValueError(f"Invalid JSON business context at {context_path}: {exc}") from exc
        return json.dumps(data, indent=2, ensure_ascii=False)

    if suffix in {".yaml", ".yml"}:
        try:
            data = yaml.safe_load(text)
        except yaml.YAMLError as exc:
            raise ValueError(f"Invalid YAML business context at {context_path}: {exc}") from exc
        if data is None:
            return ""
        return yaml.safe_dump(
            data,
            sort_keys=False,
            allow_unicode=True,
            default_flow_style=False,
        ).strip()

    return text.strip()