okb 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
okb/config.py ADDED
@@ -0,0 +1,661 @@
1
+ """Shared configuration for the knowledge base."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass, field
8
+ from pathlib import Path
9
+ from typing import Any
10
+ from urllib.parse import urlparse
11
+
12
+ import yaml
13
+
14
+
15
# Matches ${VAR} and ${VAR:-default}; group 1 is the name, group 2 the default (or None).
_ENV_VAR_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::-([^}]*))?\}")


def resolve_env_vars(value: Any) -> Any:
    """Recursively resolve ${ENV_VAR} references in config values.

    Supports (shell-style parameter expansion):
      - ${VAR}          - required; raises if VAR is not set (an empty value is used as-is)
      - ${VAR:-default} - optional; ``default`` is used when VAR is unset *or empty*,
                          matching the POSIX shell meaning of ``:-``

    Strings are substituted; dict values (not keys) and list items are processed
    recursively; any other value is returned unchanged. Inputs are not mutated.

    Args:
        value: Config value (string, dict, list, or scalar).

    Returns:
        Value with env vars resolved.

    Raises:
        ValueError: if a required ${VAR} reference is unset and has no default.
    """
    if isinstance(value, str):

        def replacer(match: re.Match) -> str:
            var_name = match.group(1)
            default = match.group(2)
            env_value = os.environ.get(var_name)
            if default is not None:
                # ":-" falls back on unset OR empty, exactly like the shell.
                return env_value if env_value else default
            if env_value is not None:
                return env_value
            raise ValueError(f"Environment variable ${var_name} not set and no default provided")

        return _ENV_VAR_PATTERN.sub(replacer, value)
    if isinstance(value, dict):
        return {k: resolve_env_vars(v) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_env_vars(v) for v in value]
    return value
48
+
49
+
50
@dataclass
class DatabaseConfig:
    """One named database entry the knowledge base can connect to."""

    name: str  # key of this entry under ``databases:`` in the config
    url: str  # connection URL, e.g. postgresql://user:pass@host:port/dbname
    managed: bool = True  # True: okb runs it via Docker; False: externally managed
    default: bool = False  # marks the database used when no name is given
    description: str | None = None  # free-text description offered to the LLM as context
    topics: list[str] | None = None  # topic keywords that help the LLM route queries

    @property
    def database_name(self) -> str:
        """Return the database name from the URL path, or ``self.name`` if the path is empty."""
        url_path = urlparse(self.url).path
        db_name = url_path.lstrip("/")
        if db_name:
            return db_name
        return self.name
66
+
67
+
68
def get_config_dir() -> Path:
    """Return okb's config directory: ``$XDG_CONFIG_HOME/okb``, else ``~/.config/okb``.

    An unset or empty ``XDG_CONFIG_HOME`` falls back to ``~/.config``.
    """
    base_dir = os.environ.get("XDG_CONFIG_HOME") or Path.home() / ".config"
    return Path(base_dir) / "okb"
74
+
75
+
76
def get_config_path() -> Path:
    """Return the location of the global config file (``config.yaml`` in the config dir)."""
    config_dir = get_config_dir()
    return config_dir.joinpath("config.yaml")
79
+
80
+
81
def load_config_file() -> dict[str, Any]:
    """Load the global config file (see ``get_config_path``) if it exists.

    Returns:
        The parsed top-level mapping, or ``{}`` if the file is missing or empty.

    Raises:
        ValueError: if the file's top level is not a mapping (e.g. a YAML list);
            the rest of the loader indexes it like a dict.
    """
    config_path = get_config_path()
    if not config_path.exists():
        return {}
    # Read as UTF-8 explicitly rather than the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {config_path} must contain a mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
88
+
89
+
90
def find_local_config(start_path: Path | None = None) -> Path | None:
    """Find the nearest .okbconf.yaml at or above start_path (default: CWD).

    The start directory is resolved to an absolute path. The search then walks up
    through every parent directory, including the filesystem root.

    Args:
        start_path: Directory to start from; defaults to ``Path.cwd()``.

    Returns:
        Path to the first ``.okbconf.yaml`` regular file found, or ``None``.
    """
    start = (start_path or Path.cwd()).resolve()
    # ``parents`` excludes ``start`` itself and ends with the root, so the root
    # directory is checked as well.
    for directory in (start, *start.parents):
        candidate = directory / ".okbconf.yaml"
        # is_file(): a directory with this name cannot be loaded; keep walking up.
        if candidate.is_file():
            return candidate
    return None
99
+
100
+
101
def load_local_config() -> dict[str, Any]:
    """Load the nearest .okbconf.yaml overlay (see ``find_local_config``) if present.

    Returns:
        The parsed top-level mapping, or ``{}`` if no overlay is found or it is empty.

    Raises:
        ValueError: if the overlay's top level is not a mapping; it is merged key
            by key into the global config, so anything else cannot be applied.
    """
    local_path = find_local_config()
    if not local_path:
        return {}
    # Read as UTF-8 explicitly rather than the platform's locale encoding.
    with open(local_path, encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Local config {local_path} must contain a mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
108
+
109
+
110
def merge_configs(base: dict, overlay: dict, path: str = "") -> dict:
    """Return a new dict with ``overlay`` merged onto ``base``; neither input is modified.

    Per key present in both: lists are concatenated (base items first), dicts are
    merged recursively, and any other overlay value replaces the base value. Keys
    only in ``overlay`` are added as-is.

    Args:
        base: The lower-precedence (global) config.
        overlay: The higher-precedence (local) config.
        path: Dotted prefix of the current nesting level, used in error messages.

    Raises:
        ValueError: if the overlay has a list or dict where the base has a
            different type (e.g., list vs dict).
    """
    merged = dict(base)
    for key, new_val in overlay.items():
        key_path = f"{path}.{key}" if path else key
        if key not in merged:
            merged[key] = new_val
            continue

        old_val = merged[key]
        # Only collection-valued overlays must match the base type; scalars replace.
        for expected_type, label in ((list, "list"), (dict, "dict")):
            if isinstance(new_val, expected_type) and not isinstance(old_val, expected_type):
                raise ValueError(
                    f"Config type mismatch at '{key_path}': "
                    f"global config has {type(old_val)}, local config has {label}"
                )

        if isinstance(new_val, list):
            merged[key] = old_val + new_val
        elif isinstance(new_val, dict):
            merged[key] = merge_configs(old_val, new_val, key_path)
        else:
            merged[key] = new_val
    return merged
141
+
142
+
143
# Default configuration values.
# Config.__post_init__ falls back to these for any setting missing from the config
# files, and get_default_config_yaml() dumps this dict with sort_keys=False, so the
# key order here is the order users see. The Config dataclass field defaults
# duplicate these values; keep the two in sync.
DEFAULTS = {
    "databases": {
        # At runtime only this entry's "url" is read, by Config's legacy single-URL
        # fallback (used when the config has no "databases" section). The URL's port
        # and password match docker.port / docker.password below.
        "default": {
            "url": "postgresql://knowledge:localdev@localhost:5433/knowledge_base",
            "default": True,
            "managed": True,
        },
    },
    # Docker container settings. Env overrides: OKB_DOCKER_PORT, OKB_CONTAINER_NAME,
    # OKB_VOLUME_NAME, OKB_DB_PASSWORD.
    "docker": {
        "port": 5433,
        "container_name": "okb-pgvector",
        "volume_name": "okb-pgvector-data",
        "password": "localdev",
    },
    # HTTP server bind address. Env overrides: OKB_HTTP_HOST, OKB_HTTP_PORT.
    "http": {
        "host": "127.0.0.1",
        "port": 8080,
    },
    "embedding": {
        "model_name": "nomic-ai/nomic-embed-text-v1.5",
        "dimension": 768,  # presumably must equal the model's output size - verify on model change
    },
    # NOTE(review): chunk_size/chunk_overlap look like token counts, converted to
    # characters via chars_per_token - confirm against the chunker.
    "chunking": {
        "chunk_size": 512,
        "chunk_overlap": 64,
        "chars_per_token": 4,
    },
    "search": {
        "default_limit": 5,
        "max_limit": 20,
        "min_similarity": 0.3,
    },
    "extensions": {
        "documents": [".md", ".txt", ".markdown", ".org", ".pdf", ".docx"],
        "code": [
            ".py",
            ".rb",
            ".js",
            ".ts",
            ".jsx",
            ".tsx",
            ".sql",
            ".sh",
            ".bash",
            ".fish",
            ".yaml",
            ".yml",
            ".toml",
            ".json",
            ".html",
            ".css",
            ".scss",
            ".go",
            ".rs",
            ".java",
            ".kt",
            ".c",
            ".cpp",
            ".h",
        ],
        # A user-supplied skip_directories list is appended to this one (Config
        # prepends these defaults) instead of replacing it.
        "skip_directories": [
            ".git",
            ".hg",
            ".svn",
            "vault",
            "node_modules",
            "__pycache__",
            ".venv",
            "venv",
            ".mypy_cache",
            ".pytest_cache",
            ".ruff_cache",
            "dist",
            "build",
            ".next",
            ".nuxt",
            "lib",
            "libs",
            "vendor",
            "third_party",
            "third-party",
            "external",
            "bower_components",
        ],
    },
    "security": {
        # Sensitive files - blocked by default. Entries look like glob-style filename
        # patterns (matching happens elsewhere - confirm the matcher). User-supplied
        # lists extend these defaults rather than replacing them.
        "block_patterns": [
            "id_rsa",
            "id_ed25519",
            "id_ecdsa",
            "id_dsa",
            "*.pem",
            "*.key",
            "*.p12",
            "*.pfx",
            ".env",
            ".env.*",
            "*credentials*",
            "*secret*",
            ".netrc",
            ".pgpass",
            ".my.cnf",
            "*_history",
            ".bash_history",
            ".zsh_history",
        ],
        # Low-value files - skipped for usefulness, not security. User-supplied lists
        # extend these defaults as well.
        "skip_patterns": [
            "*.min.js",
            "*.min.css",
            "*.bundle.js",
            "*.chunk.js",
            "*.map",
            "package-lock.json",
            "yarn.lock",
            "uv.lock",
            "Cargo.lock",
            "*.pyc",
            "*.pyo",
            "*.tmp",
            "*.tmp.*",
            ".#*",
            "*~",  # Temp/backup files
        ],
        "scan_content": True,
        "max_line_length_for_minified": 1000,
    },
    "plugins": {
        # API sources configuration
        # Example:
        # sources:
        #   github:
        #     enabled: true
        #     token: ${GITHUB_TOKEN}
        #     repos: [owner/repo1, owner/repo2]
        # ${VAR} references are resolved when a source's config is requested.
        "sources": {},
    },
    "llm": {
        # LLM provider configuration
        # provider: None = disabled, "claude" = Anthropic API
        # Env overrides: OKB_LLM_PROVIDER, OKB_LLM_MODEL, OKB_LLM_TIMEOUT.
        "provider": None,
        "model": "claude-haiku-4-5-20251001",
        "timeout": 30,
        "cache_responses": True,
        # Bedrock settings (when use_bedrock is True)
        "use_bedrock": False,
        "aws_region": "us-west-2",
    },
}
294
+
295
+
296
@dataclass
class Config:
    """Knowledge base configuration.

    ``__post_init__`` fills the settings from, in increasing precedence:
    ``DEFAULTS``, the global config file (``load_config_file``), the nearest
    ``.okbconf.yaml`` overlay merged on top of it (``merge_configs``), and, for the
    settings that read one, an ``OKB_*`` / ``KB_DATABASE_URL`` environment variable.
    """

    # Multiple databases support
    databases: dict[str, DatabaseConfig] = field(default_factory=dict)
    default_database: str | None = None

    # Local config overlay path (set in __post_init__ if found)
    local_config_path: Path | None = None

    # Docker
    docker_port: int = 5433
    docker_container_name: str = "okb-pgvector"
    docker_volume_name: str = "okb-pgvector-data"
    docker_password: str = "localdev"

    # HTTP server
    http_host: str = "127.0.0.1"
    http_port: int = 8080

    # Embedding model
    model_name: str = "nomic-ai/nomic-embed-text-v1.5"
    embedding_dim: int = 768

    # Chunking
    chunk_size: int = 512
    chunk_overlap: int = 64
    chars_per_token: int = 4

    # Search defaults
    default_limit: int = 5
    max_limit: int = 20
    min_similarity: float = 0.3

    # File types (loaded from config in __post_init__)
    document_extensions: frozenset[str] = field(default_factory=frozenset)
    code_extensions: frozenset[str] = field(default_factory=frozenset)
    skip_directories: frozenset[str] = field(default_factory=frozenset)

    # Security settings (loaded from config in __post_init__)
    block_patterns: list[str] = field(default_factory=list)
    skip_patterns: list[str] = field(default_factory=list)
    scan_content: bool = True
    max_line_length_for_minified: int = 1000

    # Plugin settings (loaded from config in __post_init__; ${ENV_VAR} references
    # are resolved lazily by get_source_config)
    plugin_sources: dict[str, dict] = field(default_factory=dict)

    # LLM settings (loaded from config in __post_init__)
    llm_provider: str | None = None
    llm_model: str = "claude-haiku-4-5-20251001"
    llm_timeout: int = 30
    llm_cache_responses: bool = True
    llm_use_bedrock: bool = False
    llm_aws_region: str = "us-west-2"

    def __post_init__(self):
        """Load configuration from the config files and environment.

        Values passed to the constructor for the individual settings are
        overwritten; loaded databases are added to the ``databases`` mapping.

        Raises:
            ValueError: for an invalid ``databases`` section (see
                ``_load_databases``), or as raised by the config loaders and
                ``merge_configs``.
        """
        file_config = load_config_file()

        # Merge the local overlay (.okbconf.yaml) on top of the global config.
        local_default_db: str | None = None
        local_path = find_local_config()
        if local_path:
            self.local_config_path = local_path
            local_config = load_local_config()
            file_config = merge_configs(file_config, local_config)
            # Applied after the databases are loaded so it wins over the global default.
            local_default_db = local_config.get("default_database")

        self._extend_default_lists(file_config)
        self._load_databases(file_config, local_default_db)

        # Docker settings (env var > file > default)
        docker_cfg = self._section(file_config, "docker")
        self.docker_port = int(
            os.environ.get("OKB_DOCKER_PORT", docker_cfg.get("port", DEFAULTS["docker"]["port"]))
        )
        self.docker_container_name = os.environ.get(
            "OKB_CONTAINER_NAME",
            docker_cfg.get("container_name", DEFAULTS["docker"]["container_name"]),
        )
        self.docker_volume_name = os.environ.get(
            "OKB_VOLUME_NAME",
            docker_cfg.get("volume_name", DEFAULTS["docker"]["volume_name"]),
        )
        self.docker_password = os.environ.get(
            "OKB_DB_PASSWORD",
            docker_cfg.get("password", DEFAULTS["docker"]["password"]),
        )

        # HTTP server settings (env var > file > default)
        http_cfg = self._section(file_config, "http")
        self.http_host = os.environ.get(
            "OKB_HTTP_HOST",
            http_cfg.get("host", DEFAULTS["http"]["host"]),
        )
        self.http_port = int(
            os.environ.get("OKB_HTTP_PORT", http_cfg.get("port", DEFAULTS["http"]["port"]))
        )

        # Embedding settings (file > default)
        embedding_cfg = self._section(file_config, "embedding")
        self.model_name = embedding_cfg.get("model_name", DEFAULTS["embedding"]["model_name"])
        self.embedding_dim = embedding_cfg.get("dimension", DEFAULTS["embedding"]["dimension"])

        # Chunking settings (file > default)
        chunking_cfg = self._section(file_config, "chunking")
        self.chunk_size = chunking_cfg.get("chunk_size", DEFAULTS["chunking"]["chunk_size"])
        self.chunk_overlap = chunking_cfg.get(
            "chunk_overlap", DEFAULTS["chunking"]["chunk_overlap"]
        )
        self.chars_per_token = chunking_cfg.get(
            "chars_per_token", DEFAULTS["chunking"]["chars_per_token"]
        )

        # Search settings (file > default)
        search_cfg = self._section(file_config, "search")
        self.default_limit = search_cfg.get("default_limit", DEFAULTS["search"]["default_limit"])
        self.max_limit = search_cfg.get("max_limit", DEFAULTS["search"]["max_limit"])
        self.min_similarity = search_cfg.get("min_similarity", DEFAULTS["search"]["min_similarity"])

        # Extension settings (file > default)
        ext_cfg = self._section(file_config, "extensions")
        self.document_extensions = frozenset(
            self._list_setting(ext_cfg, "extensions", "documents")
        )
        self.code_extensions = frozenset(self._list_setting(ext_cfg, "extensions", "code"))
        self.skip_directories = frozenset(
            self._list_setting(ext_cfg, "extensions", "skip_directories")
        )

        # Security settings (file > default); lists are copies, never DEFAULTS itself.
        security_cfg = self._section(file_config, "security")
        self.block_patterns = self._list_setting(security_cfg, "security", "block_patterns")
        self.skip_patterns = self._list_setting(security_cfg, "security", "skip_patterns")
        self.scan_content = security_cfg.get("scan_content", DEFAULTS["security"]["scan_content"])
        self.max_line_length_for_minified = security_cfg.get(
            "max_line_length_for_minified", DEFAULTS["security"]["max_line_length_for_minified"]
        )

        # Plugin settings. ${ENV_VAR} references are kept verbatim here and resolved
        # on demand by get_source_config(). "sources:" with no value loads as None.
        plugins_cfg = self._section(file_config, "plugins")
        self.plugin_sources = plugins_cfg.get("sources") or {}

        # LLM settings (env var > file > default for provider/model/timeout)
        llm_cfg = self._section(file_config, "llm")
        self.llm_provider = os.environ.get(
            "OKB_LLM_PROVIDER",
            llm_cfg.get("provider", DEFAULTS["llm"]["provider"]),
        )
        self.llm_model = os.environ.get(
            "OKB_LLM_MODEL",
            llm_cfg.get("model", DEFAULTS["llm"]["model"]),
        )
        self.llm_timeout = int(
            os.environ.get("OKB_LLM_TIMEOUT", llm_cfg.get("timeout", DEFAULTS["llm"]["timeout"]))
        )
        self.llm_cache_responses = llm_cfg.get(
            "cache_responses", DEFAULTS["llm"]["cache_responses"]
        )
        self.llm_use_bedrock = llm_cfg.get("use_bedrock", DEFAULTS["llm"]["use_bedrock"])
        self.llm_aws_region = llm_cfg.get("aws_region", DEFAULTS["llm"]["aws_region"])

    @staticmethod
    def _section(file_config: dict[str, Any], name: str) -> dict[str, Any]:
        """Return ``file_config[name]``; a missing or null section becomes ``{}``.

        A YAML key with no value (e.g. ``security:``) loads as ``None``.
        """
        return file_config.get(name) or {}

    @staticmethod
    def _list_setting(section_cfg: dict[str, Any], section: str, key: str) -> list:
        """Return a copy of ``section_cfg[key]``, or of ``DEFAULTS[section][key]`` if unset/null.

        Copying keeps code that mutates the returned list from mutating ``DEFAULTS``.
        """
        value = section_cfg.get(key)
        return list(DEFAULTS[section][key] if value is None else value)

    @staticmethod
    def _extend_default_lists(file_config: dict[str, Any]) -> None:
        """Prepend the DEFAULTS entries to user-supplied exclusion lists, in place.

        For ``extensions.skip_directories`` and ``security.block_patterns`` /
        ``security.skip_patterns``, a list given in the global or local config extends
        the built-in defaults instead of replacing them. This also applies when the
        global config file does not mention the list. Duplicates are dropped, keeping
        first-occurrence order.
        """
        list_fields_to_extend = (
            ("extensions", ("skip_directories",)),
            ("security", ("block_patterns", "skip_patterns")),
        )
        for section, keys in list_fields_to_extend:
            section_cfg = file_config.get(section)
            if not isinstance(section_cfg, dict):
                continue
            for key in keys:
                if key not in section_cfg:
                    continue
                user_list = section_cfg[key] or []  # "key:" with no value loads as None
                # dict.fromkeys de-duplicates while preserving order.
                section_cfg[key] = list(dict.fromkeys(DEFAULTS[section][key] + user_list))

    def _load_databases(self, file_config: dict[str, Any], local_default_db: str | None) -> None:
        """Populate ``databases`` and ``default_database``.

        Uses the multi-database ``databases:`` mapping when it is present and
        non-empty. Otherwise it falls back to the legacy single-URL format:
        ``KB_DATABASE_URL`` env var > ``database_url`` in the file > built-in default.
        A truthy ``local_default_db`` (from the local overlay) then overrides the
        default. When it names a known database, the per-database ``default``
        flags are updated to match.

        Raises:
            ValueError: if a database entry is not a mapping with a ``url`` key, or
                if more than one database is marked ``default: true``.
        """
        databases_cfg = file_config.get("databases")
        if databases_cfg:
            default_dbs = []
            for name, db_cfg in databases_cfg.items():
                if not isinstance(db_cfg, dict) or "url" not in db_cfg:
                    raise ValueError(f"Database '{name}' must be a mapping with a 'url' key")
                self.databases[name] = DatabaseConfig(
                    name=name,
                    url=db_cfg["url"],
                    managed=db_cfg.get("managed", True),
                    default=db_cfg.get("default", False),
                    description=db_cfg.get("description"),
                    topics=db_cfg.get("topics"),
                )
                if db_cfg.get("default"):
                    default_dbs.append(name)
            if len(default_dbs) > 1:
                raise ValueError(
                    f"Multiple databases marked as default: {default_dbs}. "
                    "Only one database can have 'default: true'."
                )
            if default_dbs:
                self.default_database = default_dbs[0]
            # If no default was marked, use the first database.
            if not self.default_database and self.databases:
                first_name = next(iter(self.databases))
                self.databases[first_name].default = True
                self.default_database = first_name
        else:
            legacy_url = os.environ.get(
                "KB_DATABASE_URL",
                file_config.get("database_url", DEFAULTS["databases"]["default"]["url"]),
            )
            self.databases["default"] = DatabaseConfig(
                name="default",
                url=legacy_url,
                managed=True,
                default=True,
            )
            self.default_database = "default"

        # The local config's default_database takes precedence over the global one.
        if local_default_db:
            self.default_database = local_default_db
            # An unknown name is kept as-is so get_database() reports it with the
            # list of available databases; a known one also moves the flags.
            if local_default_db in self.databases:
                for db_name, db in self.databases.items():
                    db.default = db_name == local_default_db

    def get_database(self, name: str | None = None) -> DatabaseConfig:
        """Get database config by name, or the default database if ``name`` is None.

        Raises:
            ValueError: if no name is given and there is no default, or if the name
                is not a configured database.
        """
        if name is None:
            name = self.default_database
        if name is None:
            raise ValueError("No database specified and no default configured")
        if name not in self.databases:
            raise ValueError(f"Unknown database: {name}. Available: {list(self.databases.keys())}")
        return self.databases[name]

    @property
    def db_url(self) -> str:
        """Backward compat: return default database URL."""
        return self.get_database().url

    @property
    def all_extensions(self) -> frozenset[str]:
        """Union of document and code extensions."""
        return self.document_extensions | self.code_extensions

    def should_skip_path(self, path: Path) -> bool:
        """Return True if any component of ``path`` is hidden (starts with ".") or a skip directory."""
        return any(part.startswith(".") or part in self.skip_directories for part in path.parts)

    def get_source_config(self, source_name: str) -> dict | None:
        """Get resolved config for a plugin source.

        Resolves ${ENV_VAR} references in the config values.
        Returns None if source not configured or disabled.

        Raises:
            ValueError: if a required ${ENV_VAR} reference cannot be resolved.
        """
        source_cfg = self.plugin_sources.get(source_name)
        if source_cfg is None:
            return None
        if not source_cfg.get("enabled", True):
            return None
        try:
            return resolve_env_vars(source_cfg)
        except ValueError as e:
            raise ValueError(f"Error resolving config for source '{source_name}': {e}") from e

    def list_enabled_sources(self) -> list[str]:
        """List plugin sources that get_source_config() would not treat as unconfigured/disabled."""
        return [
            name
            for name, cfg in self.plugin_sources.items()
            if cfg is not None and cfg.get("enabled", True)
        ]

    @staticmethod
    def _redact_url_password(url: str) -> str:
        """Return ``url`` with any password in its user-info replaced by ``***``.

        URLs without a password, or that ``urlparse`` rejects, are returned unchanged.
        """
        try:
            parsed = urlparse(url)
            password = parsed.password
        except ValueError:
            return url
        if not password:
            return url
        userinfo, _, hostinfo = parsed.netloc.rpartition("@")
        user = userinfo.split(":", 1)[0]
        return parsed._replace(netloc=f"{user}:***@{hostinfo}").geturl()

    @staticmethod
    def _mask_source(source_cfg: Any) -> Any:
        """Return a display copy of a plugin source config with its ``token`` masked."""
        if not isinstance(source_cfg, dict):
            return source_cfg
        masked = dict(source_cfg)
        if "token" in masked:
            masked["token"] = "***"
        return masked

    def to_dict(self) -> dict[str, Any]:
        """Convert config to a plain dictionary for display.

        Secrets are masked as ``***``: the Docker password, any password embedded in
        a database URL, and plugin source ``token`` values. Lists in the result are
        copies, so mutating the result does not change this config.
        """
        databases_dict: dict[str, Any] = {}
        for name, db_cfg in self.databases.items():
            db_dict: dict[str, Any] = {
                "url": self._redact_url_password(db_cfg.url),
                "managed": db_cfg.managed,
                "default": db_cfg.default,
            }
            if db_cfg.description:
                db_dict["description"] = db_cfg.description
            if db_cfg.topics:
                db_dict["topics"] = list(db_cfg.topics)
            databases_dict[name] = db_dict

        result: dict[str, Any] = {}
        if self.local_config_path:
            result["local_config"] = str(self.local_config_path)
        result["databases"] = databases_dict
        result.update(
            {
                "docker": {
                    "port": self.docker_port,
                    "container_name": self.docker_container_name,
                    "volume_name": self.docker_volume_name,
                    "password": "***" if self.docker_password else None,
                },
                "http": {
                    "host": self.http_host,
                    "port": self.http_port,
                },
                "embedding": {
                    "model_name": self.model_name,
                    "dimension": self.embedding_dim,
                },
                "chunking": {
                    "chunk_size": self.chunk_size,
                    "chunk_overlap": self.chunk_overlap,
                    "chars_per_token": self.chars_per_token,
                },
                "search": {
                    "default_limit": self.default_limit,
                    "max_limit": self.max_limit,
                    "min_similarity": self.min_similarity,
                },
                "extensions": {
                    "documents": sorted(self.document_extensions),
                    "code": sorted(self.code_extensions),
                    "skip_directories": sorted(self.skip_directories),
                },
                "security": {
                    "block_patterns": list(self.block_patterns),
                    "skip_patterns": list(self.skip_patterns),
                    "scan_content": self.scan_content,
                    "max_line_length_for_minified": self.max_line_length_for_minified,
                },
                "plugins": {
                    "sources": {
                        name: self._mask_source(cfg) for name, cfg in self.plugin_sources.items()
                    },
                },
                "llm": {
                    "provider": self.llm_provider,
                    "model": self.llm_model,
                    "timeout": self.llm_timeout,
                    "cache_responses": self.llm_cache_responses,
                    "use_bedrock": self.llm_use_bedrock,
                    "aws_region": self.llm_aws_region,
                },
            }
        )
        return result
653
+
654
+
655
def get_default_config_yaml() -> str:
    """Render ``DEFAULTS`` as block-style YAML, keeping the dict's own key order."""
    dump_options = {"default_flow_style": False, "sort_keys": False}
    return yaml.dump(DEFAULTS, **dump_options)
658
+
659
+
660
# Global config instance, built once when this module is imported. Constructing it
# reads the global config file, any .okbconf.yaml found at or above the current
# working directory, and the OKB_* / KB_DATABASE_URL environment variables (see
# Config.__post_init__), so config errors surface as exceptions at import time.
config = Config()