haiku.rag-slim 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of haiku.rag-slim might be problematic.
- haiku/rag/__init__.py +0 -0
- haiku/rag/app.py +542 -0
- haiku/rag/chunker.py +65 -0
- haiku/rag/cli.py +466 -0
- haiku/rag/client.py +731 -0
- haiku/rag/config/__init__.py +74 -0
- haiku/rag/config/loader.py +94 -0
- haiku/rag/config/models.py +99 -0
- haiku/rag/embeddings/__init__.py +49 -0
- haiku/rag/embeddings/base.py +25 -0
- haiku/rag/embeddings/ollama.py +28 -0
- haiku/rag/embeddings/openai.py +26 -0
- haiku/rag/embeddings/vllm.py +29 -0
- haiku/rag/embeddings/voyageai.py +27 -0
- haiku/rag/graph/__init__.py +26 -0
- haiku/rag/graph/agui/__init__.py +53 -0
- haiku/rag/graph/agui/cli_renderer.py +135 -0
- haiku/rag/graph/agui/emitter.py +197 -0
- haiku/rag/graph/agui/events.py +254 -0
- haiku/rag/graph/agui/server.py +310 -0
- haiku/rag/graph/agui/state.py +34 -0
- haiku/rag/graph/agui/stream.py +86 -0
- haiku/rag/graph/common/__init__.py +5 -0
- haiku/rag/graph/common/models.py +42 -0
- haiku/rag/graph/common/nodes.py +265 -0
- haiku/rag/graph/common/prompts.py +46 -0
- haiku/rag/graph/common/utils.py +44 -0
- haiku/rag/graph/deep_qa/__init__.py +1 -0
- haiku/rag/graph/deep_qa/dependencies.py +27 -0
- haiku/rag/graph/deep_qa/graph.py +243 -0
- haiku/rag/graph/deep_qa/models.py +20 -0
- haiku/rag/graph/deep_qa/prompts.py +59 -0
- haiku/rag/graph/deep_qa/state.py +56 -0
- haiku/rag/graph/research/__init__.py +3 -0
- haiku/rag/graph/research/common.py +87 -0
- haiku/rag/graph/research/dependencies.py +151 -0
- haiku/rag/graph/research/graph.py +295 -0
- haiku/rag/graph/research/models.py +166 -0
- haiku/rag/graph/research/prompts.py +107 -0
- haiku/rag/graph/research/state.py +85 -0
- haiku/rag/logging.py +56 -0
- haiku/rag/mcp.py +245 -0
- haiku/rag/monitor.py +194 -0
- haiku/rag/qa/__init__.py +33 -0
- haiku/rag/qa/agent.py +93 -0
- haiku/rag/qa/prompts.py +60 -0
- haiku/rag/reader.py +135 -0
- haiku/rag/reranking/__init__.py +63 -0
- haiku/rag/reranking/base.py +13 -0
- haiku/rag/reranking/cohere.py +34 -0
- haiku/rag/reranking/mxbai.py +28 -0
- haiku/rag/reranking/vllm.py +44 -0
- haiku/rag/reranking/zeroentropy.py +59 -0
- haiku/rag/store/__init__.py +4 -0
- haiku/rag/store/engine.py +309 -0
- haiku/rag/store/models/__init__.py +4 -0
- haiku/rag/store/models/chunk.py +17 -0
- haiku/rag/store/models/document.py +17 -0
- haiku/rag/store/repositories/__init__.py +9 -0
- haiku/rag/store/repositories/chunk.py +442 -0
- haiku/rag/store/repositories/document.py +261 -0
- haiku/rag/store/repositories/settings.py +165 -0
- haiku/rag/store/upgrades/__init__.py +62 -0
- haiku/rag/store/upgrades/v0_10_1.py +64 -0
- haiku/rag/store/upgrades/v0_9_3.py +112 -0
- haiku/rag/utils.py +211 -0
- haiku_rag_slim-0.16.0.dist-info/METADATA +128 -0
- haiku_rag_slim-0.16.0.dist-info/RECORD +71 -0
- haiku_rag_slim-0.16.0.dist-info/WHEEL +4 -0
- haiku_rag_slim-0.16.0.dist-info/entry_points.txt +2 -0
- haiku_rag_slim-0.16.0.dist-info/licenses/LICENSE +7 -0
haiku/rag/config/__init__.py

```diff
@@ -0,0 +1,74 @@
+import os
+
+from haiku.rag.config.loader import (
+    find_config_file,
+    generate_default_config,
+    load_yaml_config,
+)
+from haiku.rag.config.models import (
+    AGUIConfig,
+    AppConfig,
+    EmbeddingsConfig,
+    LanceDBConfig,
+    MonitorConfig,
+    OllamaConfig,
+    ProcessingConfig,
+    ProvidersConfig,
+    QAConfig,
+    RerankingConfig,
+    ResearchConfig,
+    StorageConfig,
+    VLLMConfig,
+)
+
+__all__ = [
+    "Config",
+    "AGUIConfig",
+    "AppConfig",
+    "StorageConfig",
+    "MonitorConfig",
+    "LanceDBConfig",
+    "EmbeddingsConfig",
+    "RerankingConfig",
+    "QAConfig",
+    "ResearchConfig",
+    "ProcessingConfig",
+    "OllamaConfig",
+    "VLLMConfig",
+    "ProvidersConfig",
+    "find_config_file",
+    "load_yaml_config",
+    "generate_default_config",
+    "get_config",
+    "set_config",
+]
+
+# Global config instance - initially loads from default locations
+_config: AppConfig | None = None
+
+
+def _load_default_config() -> AppConfig:
+    """Load config from default locations (used at import time)."""
+    config_path = find_config_file(None)
+    if config_path:
+        yaml_data = load_yaml_config(config_path)
+        return AppConfig.model_validate(yaml_data)
+    return AppConfig()
+
+
+def set_config(config: AppConfig) -> None:
+    """Set the global config instance (used by CLI to override)."""
+    global _config
+    _config = config
+
+
+def get_config() -> AppConfig:
+    """Get the current config instance."""
+    global _config
+    if _config is None:
+        _config = _load_default_config()
+    return _config
+
+
+# Legacy compatibility - Config is the default instance
+Config = _load_default_config()
```
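The module keeps one process-wide `AppConfig`: `get_config()` lazily loads it from the default search path, and `set_config()` lets the CLI swap it out before anything reads it. A minimal usage sketch, assuming only that the wheel above is installed (the override values are illustrative):

```python
# Sketch: override the process-wide config, then read it back.
from haiku.rag.config import AppConfig, get_config, set_config

# Partial dicts validate: omitted sections keep their defaults.
override = AppConfig.model_validate(
    {"embeddings": {"provider": "openai", "model": "text-embedding-3-small", "vector_dim": 1536}}
)
set_config(override)

config = get_config()
print(config.embeddings.provider)  # "openai"
print(config.qa.model)             # "gpt-oss" (untouched default)
```

Note that the module-level `Config` is materialized once at import time, so code that grabbed it before `set_config()` ran keeps seeing the original instance; `get_config()` is the accessor that respects overrides.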
haiku/rag/config/loader.py

```diff
@@ -0,0 +1,94 @@
+import os
+from pathlib import Path
+
+import yaml
+
+
+def find_config_file(cli_path: Path | None = None) -> Path | None:
+    """Find the YAML config file using the search path.
+
+    Search order:
+    1. CLI-provided path (via HAIKU_RAG_CONFIG_PATH env var or parameter)
+    2. ./haiku.rag.yaml (current directory)
+    3. Platform-specific user config directory
+
+    Returns None if no config file is found.
+    """
+    # Check environment variable first (set by CLI --config flag)
+    if not cli_path:
+        env_path = os.getenv("HAIKU_RAG_CONFIG_PATH")
+        if env_path:
+            cli_path = Path(env_path)
+
+    if cli_path:
+        if cli_path.exists():
+            return cli_path
+        raise FileNotFoundError(f"Config file not found: {cli_path}")
+
+    cwd_config = Path.cwd() / "haiku.rag.yaml"
+    if cwd_config.exists():
+        return cwd_config
+
+    # Use same directory as data storage for config
+    from haiku.rag.utils import get_default_data_dir
+
+    data_dir = get_default_data_dir()
+    user_config = data_dir / "haiku.rag.yaml"
+    if user_config.exists():
+        return user_config
+
+    return None
+
+
+def load_yaml_config(path: Path) -> dict:
+    """Load and parse a YAML config file."""
+    with open(path) as f:
+        data = yaml.safe_load(f)
+    return data or {}
+
+
+def generate_default_config() -> dict:
+    """Generate a default YAML config structure with documentation."""
+    return {
+        "environment": "production",
+        "storage": {
+            "data_dir": "",
+            "vacuum_retention_seconds": 86400,
+        },
+        "monitor": {
+            "directories": [],
+            "ignore_patterns": [],
+            "include_patterns": [],
+        },
+        "lancedb": {"uri": "", "api_key": "", "region": ""},
+        "embeddings": {
+            "provider": "ollama",
+            "model": "qwen3-embedding",
+            "vector_dim": 4096,
+        },
+        "reranking": {"provider": "", "model": ""},
+        "qa": {"provider": "ollama", "model": "gpt-oss"},
+        "research": {"provider": "", "model": ""},
+        "processing": {
+            "chunk_size": 256,
+            "context_chunk_radius": 0,
+            "markdown_preprocessor": "",
+        },
+        "providers": {
+            "ollama": {"base_url": "http://localhost:11434"},
+            "vllm": {
+                "embeddings_base_url": "",
+                "rerank_base_url": "",
+                "qa_base_url": "",
+                "research_base_url": "",
+            },
+        },
+        "agui": {
+            "host": "0.0.0.0",
+            "port": 8000,
+            "cors_origins": ["*"],
+            "cors_credentials": True,
+            "cors_methods": ["GET", "POST", "OPTIONS"],
+            "cors_headers": ["*"],
+        },
+    }
```
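The resolution order is env var / CLI path first, then `./haiku.rag.yaml`, then the data directory. A small sketch exercising that precedence with a generated default config; the temporary path is illustrative:

```python
# Sketch: write a default config to a temp file and resolve it via the
# HAIKU_RAG_CONFIG_PATH branch, which wins over ./haiku.rag.yaml.
import os
import tempfile
from pathlib import Path

import yaml

from haiku.rag.config.loader import (
    find_config_file,
    generate_default_config,
    load_yaml_config,
)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "haiku.rag.yaml"
    path.write_text(yaml.safe_dump(generate_default_config()))

    os.environ["HAIKU_RAG_CONFIG_PATH"] = str(path)
    assert find_config_file(None) == path  # env var beats cwd and data dir
    print(load_yaml_config(path)["embeddings"]["model"])  # "qwen3-embedding"
    del os.environ["HAIKU_RAG_CONFIG_PATH"]
```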
haiku/rag/config/models.py

```diff
@@ -0,0 +1,99 @@
+from pathlib import Path
+
+from pydantic import BaseModel, Field
+
+from haiku.rag.utils import get_default_data_dir
+
+
+class StorageConfig(BaseModel):
+    data_dir: Path = Field(default_factory=get_default_data_dir)
+    vacuum_retention_seconds: int = 86400
+
+
+class MonitorConfig(BaseModel):
+    directories: list[Path] = []
+    ignore_patterns: list[str] = []
+    include_patterns: list[str] = []
+    delete_orphans: bool = False
+
+
+class LanceDBConfig(BaseModel):
+    uri: str = ""
+    api_key: str = ""
+    region: str = ""
+
+
+class EmbeddingsConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "qwen3-embedding"
+    vector_dim: int = 4096
+
+
+class RerankingConfig(BaseModel):
+    provider: str = ""
+    model: str = ""
+
+
+class QAConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "gpt-oss"
+    max_sub_questions: int = 3
+    max_iterations: int = 2
+    max_concurrency: int = 1
+
+
+class ResearchConfig(BaseModel):
+    provider: str = "ollama"
+    model: str = "gpt-oss"
+    max_iterations: int = 3
+    confidence_threshold: float = 0.8
+    max_concurrency: int = 1
+
+
+class ProcessingConfig(BaseModel):
+    chunk_size: int = 256
+    context_chunk_radius: int = 0
+    markdown_preprocessor: str = ""
+
+
+class OllamaConfig(BaseModel):
+    base_url: str = Field(
+        default_factory=lambda: __import__("os").environ.get(
+            "OLLAMA_BASE_URL", "http://localhost:11434"
+        )
+    )
+
+
+class VLLMConfig(BaseModel):
+    embeddings_base_url: str = ""
+    rerank_base_url: str = ""
+    qa_base_url: str = ""
+    research_base_url: str = ""
+
+
+class ProvidersConfig(BaseModel):
+    ollama: OllamaConfig = Field(default_factory=OllamaConfig)
+    vllm: VLLMConfig = Field(default_factory=VLLMConfig)
+
+
+class AGUIConfig(BaseModel):
+    host: str = "0.0.0.0"
+    port: int = 8000
+    cors_origins: list[str] = ["*"]
+    cors_credentials: bool = True
+    cors_methods: list[str] = ["GET", "POST", "OPTIONS"]
+    cors_headers: list[str] = ["*"]
+
+
+class AppConfig(BaseModel):
+    environment: str = "production"
+    storage: StorageConfig = Field(default_factory=StorageConfig)
+    monitor: MonitorConfig = Field(default_factory=MonitorConfig)
+    lancedb: LanceDBConfig = Field(default_factory=LanceDBConfig)
+    embeddings: EmbeddingsConfig = Field(default_factory=EmbeddingsConfig)
+    reranking: RerankingConfig = Field(default_factory=RerankingConfig)
+    qa: QAConfig = Field(default_factory=QAConfig)
+    research: ResearchConfig = Field(default_factory=ResearchConfig)
+    processing: ProcessingConfig = Field(default_factory=ProcessingConfig)
+    providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
+    agui: AGUIConfig = Field(default_factory=AGUIConfig)
```
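Every section is a nested `BaseModel` with a `default_factory`, so a partial dict validates into a fully populated tree; pydantic also copies mutable defaults such as the bare `[]` lists per instance, so they are safe to mutate. A sketch of the layering:

```python
# Sketch: only chunk_size is supplied; every other field falls back to
# the defaults declared on the models above.
from haiku.rag.config.models import AppConfig

cfg = AppConfig.model_validate({"processing": {"chunk_size": 512}})
print(cfg.processing.chunk_size)            # 512 (overridden)
print(cfg.processing.context_chunk_radius)  # 0 (default)
print(cfg.providers.ollama.base_url)        # $OLLAMA_BASE_URL or http://localhost:11434
```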
haiku/rag/embeddings/__init__.py

```diff
@@ -0,0 +1,49 @@
+from haiku.rag.config import AppConfig, Config
+from haiku.rag.embeddings.base import EmbedderBase
+from haiku.rag.embeddings.ollama import Embedder as OllamaEmbedder
+
+
+def get_embedder(config: AppConfig = Config) -> EmbedderBase:
+    """
+    Factory function to get the appropriate embedder based on the configuration.
+
+    Args:
+        config: Configuration to use. Defaults to global Config.
+
+    Returns:
+        An embedder instance configured according to the config.
+    """
+
+    if config.embeddings.provider == "ollama":
+        return OllamaEmbedder(
+            config.embeddings.model, config.embeddings.vector_dim, config
+        )
+
+    if config.embeddings.provider == "voyageai":
+        try:
+            from haiku.rag.embeddings.voyageai import Embedder as VoyageAIEmbedder
+        except ImportError:
+            raise ImportError(
+                "VoyageAI embedder requires the 'voyageai' package. "
+                "Please install haiku.rag with the 'voyageai' extra: "
+                "uv pip install haiku.rag[voyageai]"
+            )
+        return VoyageAIEmbedder(
+            config.embeddings.model, config.embeddings.vector_dim, config
+        )
+
+    if config.embeddings.provider == "openai":
+        from haiku.rag.embeddings.openai import Embedder as OpenAIEmbedder
+
+        return OpenAIEmbedder(
+            config.embeddings.model, config.embeddings.vector_dim, config
+        )
+
+    if config.embeddings.provider == "vllm":
+        from haiku.rag.embeddings.vllm import Embedder as VllmEmbedder
+
+        return VllmEmbedder(
+            config.embeddings.model, config.embeddings.vector_dim, config
+        )
+
+    raise ValueError(f"Unsupported embedding provider: {config.embeddings.provider}")
```
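The factory dispatches on `config.embeddings.provider` and imports provider modules lazily, so optional dependencies are only touched when selected. A sketch of driving it with an explicit config; it assumes an Ollama server is reachable with the model pulled:

```python
# Sketch: build an embedder from an explicit config instead of the
# module-level default. Requires a running Ollama server.
import asyncio

from haiku.rag.config import AppConfig
from haiku.rag.embeddings import get_embedder


async def main() -> None:
    config = AppConfig.model_validate(
        {"embeddings": {"provider": "ollama", "model": "qwen3-embedding", "vector_dim": 4096}}
    )
    embedder = get_embedder(config)  # ValueError for unknown providers
    vector = await embedder.embed("haiku.rag stores chunks in LanceDB")
    print(len(vector))  # expected to match config.embeddings.vector_dim


asyncio.run(main())
```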
haiku/rag/embeddings/base.py

```diff
@@ -0,0 +1,25 @@
+from typing import overload
+
+from haiku.rag.config import AppConfig, Config
+
+
+class EmbedderBase:
+    _model: str = Config.embeddings.model
+    _vector_dim: int = Config.embeddings.vector_dim
+    _config: AppConfig = Config
+
+    def __init__(self, model: str, vector_dim: int, config: AppConfig = Config):
+        self._model = model
+        self._vector_dim = vector_dim
+        self._config = config
+
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
+    async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
+        raise NotImplementedError(
+            "Embedder is an abstract class. Please implement the embed method in a subclass."
+        )
```
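Providers subclass `EmbedderBase` and implement the single `embed()` method behind both `@overload` signatures: a `str` in yields one vector, a `list[str]` yields one vector per item. A minimal sketch of a custom subclass; `FakeEmbedder` is hypothetical and just returns zero vectors:

```python
# Sketch: a do-nothing provider that satisfies the EmbedderBase contract.
from typing import overload

from haiku.rag.embeddings.base import EmbedderBase


class FakeEmbedder(EmbedderBase):
    @overload
    async def embed(self, text: str) -> list[float]: ...

    @overload
    async def embed(self, text: list[str]) -> list[list[float]]: ...

    async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
        # Mirror the real providers: one vector for str, a list for list[str].
        if isinstance(text, str):
            return [0.0] * self._vector_dim
        return [[0.0] * self._vector_dim for _ in text]
```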
haiku/rag/embeddings/ollama.py

```diff
@@ -0,0 +1,28 @@
+from typing import overload
+
+from openai import AsyncOpenAI
+
+from haiku.rag.embeddings.base import EmbedderBase
+
+
+class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
+    async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
+        client = AsyncOpenAI(
+            base_url=f"{self._config.providers.ollama.base_url}/v1", api_key="dummy"
+        )
+        if not text:
+            return []
+        response = await client.embeddings.create(
+            model=self._model,
+            input=text,
+        )
+        if isinstance(text, str):
+            return response.data[0].embedding
+        else:
+            return [item.embedding for item in response.data]
```
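This embedder talks to Ollama's OpenAI-compatible `/v1` endpoint, which is why it can reuse the `openai` client with a dummy API key; the openai and vllm embedders below follow the same pattern and differ only in base URL. A direct-instantiation sketch (assumes a local Ollama server serving the model):

```python
# Sketch: the same embed() call is polymorphic over str vs list[str].
import asyncio

from haiku.rag.config import Config
from haiku.rag.embeddings.ollama import Embedder


async def main() -> None:
    embedder = Embedder("qwen3-embedding", 4096, Config)
    one = await embedder.embed("a single chunk")
    many = await embedder.embed(["chunk a", "chunk b"])
    print(len(one), len(many))  # e.g. 4096 2


asyncio.run(main())
```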
haiku/rag/embeddings/openai.py

```diff
@@ -0,0 +1,26 @@
+from typing import overload
+
+from openai import AsyncOpenAI
+
+from haiku.rag.embeddings.base import EmbedderBase
+
+
+class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
+    async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
+        client = AsyncOpenAI()
+        if not text:
+            return []
+        response = await client.embeddings.create(
+            model=self._model,
+            input=text,
+        )
+        if isinstance(text, str):
+            return response.data[0].embedding
+        else:
+            return [item.embedding for item in response.data]
```
haiku/rag/embeddings/vllm.py

```diff
@@ -0,0 +1,29 @@
+from typing import overload
+
+from openai import AsyncOpenAI
+
+from haiku.rag.embeddings.base import EmbedderBase
+
+
+class Embedder(EmbedderBase):
+    @overload
+    async def embed(self, text: str) -> list[float]: ...
+
+    @overload
+    async def embed(self, text: list[str]) -> list[list[float]]: ...
+
+    async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
+        client = AsyncOpenAI(
+            base_url=f"{self._config.providers.vllm.embeddings_base_url}/v1",
+            api_key="dummy",
+        )
+        if not text:
+            return []
+        response = await client.embeddings.create(
+            model=self._model,
+            input=text,
+        )
+        if isinstance(text, str):
+            return response.data[0].embedding
+        else:
+            return [item.embedding for item in response.data]
```
haiku/rag/embeddings/voyageai.py

```diff
@@ -0,0 +1,27 @@
+try:
+    from typing import overload
+
+    from voyageai.client import Client  # type: ignore
+
+    from haiku.rag.embeddings.base import EmbedderBase
+
+    class Embedder(EmbedderBase):
+        @overload
+        async def embed(self, text: str) -> list[float]: ...
+
+        @overload
+        async def embed(self, text: list[str]) -> list[list[float]]: ...
+
+        async def embed(self, text: str | list[str]) -> list[float] | list[list[float]]:
+            client = Client()
+            if not text:
+                return []
+            if isinstance(text, str):
+                res = client.embed([text], model=self._model, output_dtype="float")
+                return res.embeddings[0]  # type: ignore[return-value]
+            else:
+                res = client.embed(text, model=self._model, output_dtype="float")
+                return res.embeddings  # type: ignore[return-value]
+
+except ImportError:
+    pass
```
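The whole module body sits inside `try/except ImportError`, so the module always imports cleanly but only defines `Embedder` when the optional `voyageai` extra is present; the factory's guarded import in `embeddings/__init__.py` then converts the missing name into the actionable error shown earlier. A sketch probing that behavior:

```python
# Sketch: the module imports either way; Embedder only exists when the
# 'voyageai' extra is installed.
import importlib

mod = importlib.import_module("haiku.rag.embeddings.voyageai")
print(hasattr(mod, "Embedder"))  # False without the voyageai package
```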
haiku/rag/graph/__init__.py

```diff
@@ -0,0 +1,26 @@
+"""Graph module for haiku.rag.
+
+This module contains all graph-related functionality including:
+- AG-UI protocol for graph streaming
+- Common graph utilities and models
+- Research graph implementation
+- Deep QA graph implementation
+"""
+
+from haiku.rag.graph.agui import (
+    AGUIConsoleRenderer,
+    AGUIEmitter,
+    create_agui_server,
+    stream_graph,
+)
+from haiku.rag.graph.deep_qa.graph import build_deep_qa_graph
+from haiku.rag.graph.research.graph import build_research_graph
+
+__all__ = [
+    "AGUIConsoleRenderer",
+    "AGUIEmitter",
+    "build_deep_qa_graph",
+    "build_research_graph",
+    "create_agui_server",
+    "stream_graph",
+]
```
haiku/rag/graph/agui/__init__.py

```diff
@@ -0,0 +1,53 @@
+"""Generic AG-UI protocol support for haiku.rag graphs."""
+
+from haiku.rag.graph.agui.cli_renderer import AGUIConsoleRenderer
+from haiku.rag.graph.agui.emitter import AGUIEmitter
+from haiku.rag.graph.agui.events import (
+    AGUIEvent,
+    emit_activity,
+    emit_activity_delta,
+    emit_run_error,
+    emit_run_finished,
+    emit_run_started,
+    emit_state_delta,
+    emit_state_snapshot,
+    emit_step_finished,
+    emit_step_started,
+    emit_text_message,
+    emit_text_message_content,
+    emit_text_message_end,
+    emit_text_message_start,
+)
+from haiku.rag.graph.agui.server import (
+    RunAgentInput,
+    create_agui_app,
+    create_agui_server,
+    format_sse_event,
+)
+from haiku.rag.graph.agui.state import compute_state_delta
+from haiku.rag.graph.agui.stream import stream_graph
+
+__all__ = [
+    "AGUIConsoleRenderer",
+    "AGUIEmitter",
+    "AGUIEvent",
+    "RunAgentInput",
+    "compute_state_delta",
+    "create_agui_app",
+    "create_agui_server",
+    "emit_activity",
+    "emit_activity_delta",
+    "emit_run_error",
+    "emit_run_finished",
+    "emit_run_started",
+    "emit_state_delta",
+    "emit_state_snapshot",
+    "emit_step_finished",
+    "emit_step_started",
+    "emit_text_message",
+    "emit_text_message_content",
+    "emit_text_message_end",
+    "emit_text_message_start",
+    "format_sse_event",
+    "stream_graph",
+]
```
haiku/rag/graph/agui/cli_renderer.py

```diff
@@ -0,0 +1,135 @@
+"""Generic CLI renderer for AG-UI events with Rich console output."""
+
+from collections.abc import AsyncIterator
+from typing import Any
+
+from rich.console import Console
+
+from haiku.rag.graph.agui.events import AGUIEvent
+
+
+class AGUIConsoleRenderer:
+    """Renders AG-UI events to Rich console with formatted output.
+
+    Generic renderer that processes AG-UI protocol events and renders them
+    with Rich formatting. Works with any graph that emits AG-UI events.
+    """
+
+    def __init__(self, console: Console | None = None):
+        """Initialize the renderer.
+
+        Args:
+            console: Optional Rich console instance (creates new one if not provided)
+        """
+        self.console = console or Console()
+
+    async def render(self, events: AsyncIterator[AGUIEvent]) -> Any | None:
+        """Process events and render to console, return final result.
+
+        Args:
+            events: Async iterator of AG-UI events
+
+        Returns:
+            The final result from RunFinished event, or None
+        """
+        result = None
+
+        async for event in events:
+            event_type = event.get("type")
+
+            if event_type == "RUN_STARTED":
+                self._render_run_started(event)
+            elif event_type == "RUN_FINISHED":
+                result = event.get("result")
+                self._render_run_finished()
+            elif event_type == "RUN_ERROR":
+                self._render_error(event)
+            elif event_type == "STEP_STARTED":
+                self._render_step_started(event)
+            elif event_type == "STEP_FINISHED":
+                self._render_step_finished(event)
+            elif event_type == "TEXT_MESSAGE_CHUNK":
+                self._render_text_message(event)
+            elif event_type == "TEXT_MESSAGE_START":
+                pass  # Start of streaming message, no output needed
+            elif event_type == "TEXT_MESSAGE_CONTENT":
+                self._render_text_content(event)
+            elif event_type == "TEXT_MESSAGE_END":
+                pass  # End of streaming message, no output needed
+            elif event_type == "STATE_SNAPSHOT":
+                self._render_state_snapshot(event)
+            elif event_type == "STATE_DELTA":
+                self._render_state_delta(event)
+            elif event_type == "ACTIVITY_SNAPSHOT":
+                self._render_activity(event)
+            elif event_type == "ACTIVITY_DELTA":
+                pass  # Activity deltas don't need separate rendering
+
+        return result
+
+    def _render_run_started(self, event: AGUIEvent) -> None:
+        """Render run start event."""
+        run_id = event.get("runId", "")
+        if run_id:
+            # Show shortened run ID (first 8 chars like our UUIDs)
+            short_id = run_id[:8] if len(run_id) > 8 else run_id
+            self.console.print(f"[bold green][RUN_STARTED][/bold green] Run {short_id}")
+
+    def _render_run_finished(self) -> None:
+        """Render run completion."""
+        self.console.print("[bold green][RUN_FINISHED][/bold green] Completed")
+
+    def _render_error(self, event: AGUIEvent) -> None:
+        """Render error event."""
+        message = event.get("message", "Unknown error")
+        self.console.print(f"[bold red][RUN_ERROR][/bold red] {message}")
+
+    def _render_step_started(self, event: AGUIEvent) -> None:
+        """Render step start event."""
+        step_name = event.get("stepName", "")
+        if step_name:
+            display_name = step_name.replace("_", " ").title()
+            self.console.print(
+                f"\n[bold cyan][STEP_STARTED][/bold cyan] {display_name}"
+            )
+
+    def _render_step_finished(self, event: AGUIEvent) -> None:
+        """Render step finish event."""
+        step_name = event.get("stepName", "")
+        if step_name:
+            display_name = step_name.replace("_", " ").title()
+            self.console.print(f"[cyan][STEP_FINISHED][/cyan] {display_name}")
+
+    def _render_text_message(self, event: AGUIEvent) -> None:
+        """Render complete text message."""
+        delta = event.get("delta", "")
+        self.console.print(f"[magenta][TEXT_MESSAGE][/magenta] {delta}")
+
+    def _render_text_content(self, event: AGUIEvent) -> None:
+        """Render streaming text content delta."""
+        delta = event.get("delta", "")
+        self.console.print(delta, end="")
+
+    def _render_activity(self, event: AGUIEvent) -> None:
+        """Render activity update."""
+        content = event.get("content", "")
+        if content:
+            self.console.print(f"[yellow][ACTIVITY][/yellow] {content}")
+
+    def _render_state_snapshot(self, event: AGUIEvent) -> None:
+        """Render full state snapshot."""
+        snapshot = event.get("snapshot")
+        if not snapshot:
+            return
+
+        self.console.print("[blue][STATE_SNAPSHOT][/blue]")
+        self.console.print(snapshot, style="dim")
+
+    def _render_state_delta(self, event: AGUIEvent) -> None:
+        """Render state delta operations."""
+        delta = event.get("delta", [])
+        if not delta:
+            return
+
+        self.console.print("[blue][STATE_DELTA][/blue]")
+        self.console.print(delta, style="dim")
```
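`render()` only relies on `event.get(...)`, so anything that yields dict-shaped AG-UI events can drive it, including a hand-built stream; no graph or server is required. A sketch with fabricated events (the payload values are illustrative):

```python
# Sketch: feed the renderer a synthetic event stream and collect the
# result carried by RUN_FINISHED.
import asyncio
from collections.abc import AsyncIterator

from haiku.rag.graph.agui.cli_renderer import AGUIConsoleRenderer


async def fake_events() -> AsyncIterator[dict]:
    yield {"type": "RUN_STARTED", "runId": "0f3b2c1d-demo"}
    yield {"type": "STEP_STARTED", "stepName": "search_documents"}
    yield {"type": "TEXT_MESSAGE_CONTENT", "delta": "partial answer..."}
    yield {"type": "STEP_FINISHED", "stepName": "search_documents"}
    yield {"type": "RUN_FINISHED", "result": {"answer": "done"}}


async def main() -> None:
    result = await AGUIConsoleRenderer().render(fake_events())
    print()        # newline after the streamed TEXT_MESSAGE_CONTENT delta
    print(result)  # {'answer': 'done'}


asyncio.run(main())
```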