langroid 0.31.1__py3-none-any.whl → 0.33.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langroid-0.31.1.dist-info → langroid-0.33.3.dist-info}/METADATA +150 -124
- langroid-0.33.3.dist-info/RECORD +7 -0
- {langroid-0.31.1.dist-info → langroid-0.33.3.dist-info}/WHEEL +1 -1
- langroid-0.33.3.dist-info/entry_points.txt +4 -0
- pyproject.toml +317 -212
- langroid/__init__.py +0 -106
- langroid/agent/.chainlit/config.toml +0 -121
- langroid/agent/.chainlit/translations/bn.json +0 -231
- langroid/agent/.chainlit/translations/en-US.json +0 -229
- langroid/agent/.chainlit/translations/gu.json +0 -231
- langroid/agent/.chainlit/translations/he-IL.json +0 -231
- langroid/agent/.chainlit/translations/hi.json +0 -231
- langroid/agent/.chainlit/translations/kn.json +0 -231
- langroid/agent/.chainlit/translations/ml.json +0 -231
- langroid/agent/.chainlit/translations/mr.json +0 -231
- langroid/agent/.chainlit/translations/ta.json +0 -231
- langroid/agent/.chainlit/translations/te.json +0 -231
- langroid/agent/.chainlit/translations/zh-CN.json +0 -229
- langroid/agent/__init__.py +0 -41
- langroid/agent/base.py +0 -1981
- langroid/agent/batch.py +0 -398
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +0 -598
- langroid/agent/chat_agent.py +0 -1899
- langroid/agent/chat_document.py +0 -454
- langroid/agent/helpers.py +0 -0
- langroid/agent/junk +0 -13
- langroid/agent/openai_assistant.py +0 -882
- langroid/agent/special/__init__.py +0 -59
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +0 -656
- langroid/agent/special/arangodb/system_messages.py +0 -186
- langroid/agent/special/arangodb/tools.py +0 -107
- langroid/agent/special/arangodb/utils.py +0 -36
- langroid/agent/special/doc_chat_agent.py +0 -1466
- langroid/agent/special/lance_doc_chat_agent.py +0 -262
- langroid/agent/special/lance_rag/__init__.py +0 -9
- langroid/agent/special/lance_rag/critic_agent.py +0 -198
- langroid/agent/special/lance_rag/lance_rag_task.py +0 -82
- langroid/agent/special/lance_rag/query_planner_agent.py +0 -260
- langroid/agent/special/lance_tools.py +0 -61
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +0 -174
- langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -433
- langroid/agent/special/neo4j/system_messages.py +0 -120
- langroid/agent/special/neo4j/tools.py +0 -32
- langroid/agent/special/relevance_extractor_agent.py +0 -127
- langroid/agent/special/retriever_agent.py +0 -56
- langroid/agent/special/sql/__init__.py +0 -17
- langroid/agent/special/sql/sql_chat_agent.py +0 -654
- langroid/agent/special/sql/utils/__init__.py +0 -21
- langroid/agent/special/sql/utils/description_extractors.py +0 -190
- langroid/agent/special/sql/utils/populate_metadata.py +0 -85
- langroid/agent/special/sql/utils/system_message.py +0 -35
- langroid/agent/special/sql/utils/tools.py +0 -64
- langroid/agent/special/table_chat_agent.py +0 -263
- langroid/agent/structured_message.py +0 -9
- langroid/agent/task.py +0 -2093
- langroid/agent/tool_message.py +0 -393
- langroid/agent/tools/__init__.py +0 -38
- langroid/agent/tools/duckduckgo_search_tool.py +0 -50
- langroid/agent/tools/file_tools.py +0 -234
- langroid/agent/tools/google_search_tool.py +0 -39
- langroid/agent/tools/metaphor_search_tool.py +0 -67
- langroid/agent/tools/orchestration.py +0 -303
- langroid/agent/tools/recipient_tool.py +0 -235
- langroid/agent/tools/retrieval_tool.py +0 -32
- langroid/agent/tools/rewind_tool.py +0 -137
- langroid/agent/tools/segment_extract_tool.py +0 -41
- langroid/agent/typed_task.py +0 -19
- langroid/agent/xml_tool_message.py +0 -382
- langroid/agent_config.py +0 -0
- langroid/cachedb/__init__.py +0 -17
- langroid/cachedb/base.py +0 -58
- langroid/cachedb/momento_cachedb.py +0 -108
- langroid/cachedb/redis_cachedb.py +0 -153
- langroid/embedding_models/__init__.py +0 -39
- langroid/embedding_models/base.py +0 -74
- langroid/embedding_models/clustering.py +0 -189
- langroid/embedding_models/models.py +0 -461
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +0 -19
- langroid/embedding_models/protoc/embeddings_pb2.py +0 -33
- langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -50
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -79
- langroid/embedding_models/remote_embeds.py +0 -153
- langroid/exceptions.py +0 -65
- langroid/experimental/team-save.py +0 -391
- langroid/language_models/.chainlit/config.toml +0 -121
- langroid/language_models/.chainlit/translations/en-US.json +0 -231
- langroid/language_models/__init__.py +0 -53
- langroid/language_models/azure_openai.py +0 -153
- langroid/language_models/base.py +0 -678
- langroid/language_models/config.py +0 -18
- langroid/language_models/mock_lm.py +0 -124
- langroid/language_models/openai_gpt.py +0 -1923
- langroid/language_models/prompt_formatter/__init__.py +0 -16
- langroid/language_models/prompt_formatter/base.py +0 -40
- langroid/language_models/prompt_formatter/hf_formatter.py +0 -132
- langroid/language_models/prompt_formatter/llama2_formatter.py +0 -75
- langroid/language_models/utils.py +0 -147
- langroid/mytypes.py +0 -84
- langroid/parsing/__init__.py +0 -52
- langroid/parsing/agent_chats.py +0 -38
- langroid/parsing/code-parsing.md +0 -86
- langroid/parsing/code_parser.py +0 -121
- langroid/parsing/config.py +0 -0
- langroid/parsing/document_parser.py +0 -718
- langroid/parsing/image_text.py +0 -32
- langroid/parsing/para_sentence_split.py +0 -62
- langroid/parsing/parse_json.py +0 -155
- langroid/parsing/parser.py +0 -313
- langroid/parsing/repo_loader.py +0 -790
- langroid/parsing/routing.py +0 -36
- langroid/parsing/search.py +0 -275
- langroid/parsing/spider.py +0 -102
- langroid/parsing/table_loader.py +0 -94
- langroid/parsing/url_loader.py +0 -111
- langroid/parsing/url_loader_cookies.py +0 -73
- langroid/parsing/urls.py +0 -273
- langroid/parsing/utils.py +0 -373
- langroid/parsing/web_search.py +0 -155
- langroid/prompts/__init__.py +0 -9
- langroid/prompts/chat-gpt4-system-prompt.md +0 -68
- langroid/prompts/dialog.py +0 -17
- langroid/prompts/prompts_config.py +0 -5
- langroid/prompts/templates.py +0 -141
- langroid/pydantic_v1/__init__.py +0 -10
- langroid/pydantic_v1/main.py +0 -4
- langroid/utils/.chainlit/config.toml +0 -121
- langroid/utils/.chainlit/translations/en-US.json +0 -231
- langroid/utils/__init__.py +0 -19
- langroid/utils/algorithms/__init__.py +0 -3
- langroid/utils/algorithms/graph.py +0 -103
- langroid/utils/configuration.py +0 -98
- langroid/utils/constants.py +0 -30
- langroid/utils/docker.py +0 -37
- langroid/utils/git_utils.py +0 -252
- langroid/utils/globals.py +0 -49
- langroid/utils/llms/__init__.py +0 -0
- langroid/utils/llms/strings.py +0 -8
- langroid/utils/logging.py +0 -135
- langroid/utils/object_registry.py +0 -66
- langroid/utils/output/__init__.py +0 -20
- langroid/utils/output/citations.py +0 -41
- langroid/utils/output/printing.py +0 -99
- langroid/utils/output/status.py +0 -40
- langroid/utils/pandas_utils.py +0 -30
- langroid/utils/pydantic_utils.py +0 -602
- langroid/utils/system.py +0 -286
- langroid/utils/types.py +0 -93
- langroid/utils/web/__init__.py +0 -0
- langroid/utils/web/login.py +0 -83
- langroid/vector_store/__init__.py +0 -50
- langroid/vector_store/base.py +0 -357
- langroid/vector_store/chromadb.py +0 -214
- langroid/vector_store/lancedb.py +0 -401
- langroid/vector_store/meilisearch.py +0 -299
- langroid/vector_store/momento.py +0 -278
- langroid/vector_store/qdrant_cloud.py +0 -6
- langroid/vector_store/qdrantdb.py +0 -468
- langroid-0.31.1.dist-info/RECORD +0 -162
- {langroid-0.31.1.dist-info → langroid-0.33.3.dist-info/licenses}/LICENSE +0 -0
@@ -1,153 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
import logging
|
3
|
-
import os
|
4
|
-
from contextlib import AbstractContextManager, contextmanager
|
5
|
-
from typing import Any, Dict, List, TypeVar
|
6
|
-
|
7
|
-
import fakeredis
|
8
|
-
import redis
|
9
|
-
from dotenv import load_dotenv
|
10
|
-
|
11
|
-
from langroid.cachedb.base import CacheDB, CacheDBConfig
|
12
|
-
|
13
|
-
T = TypeVar("T", bound="RedisCache")
|
14
|
-
logger = logging.getLogger(__name__)
|
15
|
-
|
16
|
-
|
17
|
-
class RedisCacheConfig(CacheDBConfig):
    """Settings consumed by RedisCache."""

    # When True, RedisCache uses an in-process fakeredis client instead of
    # building a connection pool to a Redis server.
    fake: bool = False
|
21
|
-
|
22
|
-
|
23
|
-
class RedisCache(CacheDB):
    """Redis implementation of the CacheDB."""

    # Class-level flag so the "using fake redis client" warning is logged at
    # most once, regardless of how many RedisCache instances are created.
    _warned_password: bool = False

    def __init__(self, config: RedisCacheConfig):
        """
        Initialize a RedisCache with the given config.

        A fakeredis client is used when ``config.fake`` is set, or when any of
        the REDIS_PASSWORD, REDIS_HOST, REDIS_PORT environment variables is
        missing; otherwise a redis connection pool is created from them.

        Args:
            config (RedisCacheConfig): The configuration to use.
        """
        self.config = config
        load_dotenv()

        if self.config.fake:
            self.pool = fakeredis.FakeStrictRedis()  # type: ignore
        else:
            redis_password = os.getenv("REDIS_PASSWORD")
            redis_host = os.getenv("REDIS_HOST")
            redis_port = os.getenv("REDIS_PORT")
            if None in [redis_password, redis_host, redis_port]:
                if not RedisCache._warned_password:
                    logger.warning(
                        """REDIS_PASSWORD, REDIS_HOST, REDIS_PORT not set in .env file,
                        using fake redis client"""
                    )
                    RedisCache._warned_password = True
                self.pool = fakeredis.FakeStrictRedis()  # type: ignore
            else:
                self.pool = redis.ConnectionPool(  # type: ignore
                    host=redis_host,
                    port=redis_port,
                    password=redis_password,
                    max_connections=500,
                    socket_timeout=5,
                    socket_keepalive=True,
                    retry_on_timeout=True,
                    health_check_interval=30,
                )

    @contextmanager  # type: ignore
    def redis_client(self) -> AbstractContextManager[T]:  # type: ignore
        """Cleanly open and close a redis client, avoids max clients exceeded error"""
        if isinstance(self.pool, fakeredis.FakeStrictRedis):
            # The fake client is yielded as-is and is not closed afterwards.
            yield self.pool
        else:
            client: T = redis.Redis(connection_pool=self.pool)
            try:
                yield client
            finally:
                client.close()

    def close_all_connections(self) -> None:
        """Kill every client connection reported by the server's client list."""
        with self.redis_client() as client:  # type: ignore
            clients = client.client_list()
            for c in clients:
                client.client_kill(c["addr"])

    def clear(self) -> None:
        """Clear keys from current db."""
        with self.redis_client() as client:  # type: ignore
            client.flushdb()

    def clear_all(self) -> None:
        """Clear all keys from all dbs."""
        with self.redis_client() as client:  # type: ignore
            client.flushall()

    def store(self, key: str, value: Any) -> None:
        """
        Store a value associated with a key.

        The value is serialized with ``json.dumps`` before being stored.
        A redis connection error is logged and otherwise ignored.

        Args:
            key (str): The key under which to store the value.
            value (Any): The value to store.
        """
        with self.redis_client() as client:  # type: ignore
            try:
                client.set(key, json.dumps(value))
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, not storing key/value")
                return None

    def retrieve(self, key: str) -> Dict[str, Any] | str | None:
        """
        Retrieve the value associated with a key.

        Args:
            key (str): The key to retrieve the value for.

        Returns:
            dict|str|None: The JSON-decoded value associated with the key, or
                None when the stored value is missing/empty or a redis
                connection error occurs.
        """
        with self.redis_client() as client:  # type: ignore
            try:
                value = client.get(key)
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, returning None")
                return None
            return json.loads(value) if value else None

    def delete_keys(self, keys: List[str]) -> None:
        """
        Delete the keys from the cache.

        An empty list is a no-op. A redis connection error is logged and
        otherwise ignored.

        Args:
            keys (List[str]): The keys to delete.
        """
        if not keys:
            # DEL requires at least one key; calling client.delete() with no
            # names would be rejected by the server.
            return None
        with self.redis_client() as client:  # type: ignore
            try:
                client.delete(*keys)
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, not deleting keys")
                return None

    def delete_keys_pattern(self, pattern: str) -> None:
        """
        Delete the keys matching the pattern from the cache.

        A redis connection error is logged and otherwise ignored.

        Args:
            pattern (str): The pattern to match.
        """
        with self.redis_client() as client:  # type: ignore
            try:
                keys = client.keys(pattern)
                # Only issue the delete when something matched.
                if len(keys) > 0:
                    client.delete(*keys)
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, not deleting keys")
                return None
|
@@ -1,39 +0,0 @@
|
|
1
|
-
from . import base
|
2
|
-
from . import models
|
3
|
-
from . import remote_embeds
|
4
|
-
|
5
|
-
from .base import (
|
6
|
-
EmbeddingModel,
|
7
|
-
EmbeddingModelsConfig,
|
8
|
-
)
|
9
|
-
from .models import (
|
10
|
-
OpenAIEmbeddings,
|
11
|
-
OpenAIEmbeddingsConfig,
|
12
|
-
SentenceTransformerEmbeddings,
|
13
|
-
SentenceTransformerEmbeddingsConfig,
|
14
|
-
LlamaCppServerEmbeddings,
|
15
|
-
LlamaCppServerEmbeddingsConfig,
|
16
|
-
embedding_model,
|
17
|
-
)
|
18
|
-
from .remote_embeds import (
|
19
|
-
RemoteEmbeddingsConfig,
|
20
|
-
RemoteEmbeddings,
|
21
|
-
)
|
22
|
-
|
23
|
-
|
24
|
-
__all__ = [
|
25
|
-
"base",
|
26
|
-
"models",
|
27
|
-
"remote_embeds",
|
28
|
-
"EmbeddingModel",
|
29
|
-
"EmbeddingModelsConfig",
|
30
|
-
"OpenAIEmbeddings",
|
31
|
-
"OpenAIEmbeddingsConfig",
|
32
|
-
"SentenceTransformerEmbeddings",
|
33
|
-
"SentenceTransformerEmbeddingsConfig",
|
34
|
-
"LlamaCppServerEmbeddings",
|
35
|
-
"LlamaCppServerEmbeddingsConfig",
|
36
|
-
"embedding_model",
|
37
|
-
"RemoteEmbeddingsConfig",
|
38
|
-
"RemoteEmbeddings",
|
39
|
-
]
|
@@ -1,74 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
from abc import ABC, abstractmethod
|
3
|
-
|
4
|
-
import numpy as np
|
5
|
-
|
6
|
-
from langroid.mytypes import EmbeddingFunction
|
7
|
-
from langroid.pydantic_v1 import BaseSettings
|
8
|
-
|
9
|
-
logging.getLogger("openai").setLevel(logging.ERROR)
|
10
|
-
|
11
|
-
|
12
|
-
class EmbeddingModelsConfig(BaseSettings):
    """Base settings type accepted by `EmbeddingModel.create`."""

    model_type: str = "openai"
    # NOTE(review): presumably the embedding dimensionality, with 0 meaning
    # "not set" -- confirm against the concrete model classes.
    dims: int = 0
    # NOTE(review): units of the two limits below are not visible here
    # (tokens vs. characters, items per request) -- confirm before relying on them.
    context_length: int = 512
    batch_size: int = 512
|
17
|
-
|
18
|
-
|
19
|
-
class EmbeddingModel(ABC):
    """
    Abstract base class for an embedding model.
    """

    @classmethod
    def create(cls, config: EmbeddingModelsConfig) -> "EmbeddingModel":
        """
        Build the embedding model that corresponds to the type of `config`.

        Args:
            config (EmbeddingModelsConfig): an instance of one of the
                supported config classes imported below.

        Returns:
            EmbeddingModel: the model constructed from `config`.

        Raises:
            ValueError: if `config` matches none of the supported config
                classes.
        """
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import -- confirm).
        from langroid.embedding_models.models import (
            AzureOpenAIEmbeddings,
            AzureOpenAIEmbeddingsConfig,
            FastEmbedEmbeddings,
            FastEmbedEmbeddingsConfig,
            LlamaCppServerEmbeddings,
            LlamaCppServerEmbeddingsConfig,
            OpenAIEmbeddings,
            OpenAIEmbeddingsConfig,
            SentenceTransformerEmbeddings,
            SentenceTransformerEmbeddingsConfig,
        )
        from langroid.embedding_models.remote_embeds import (
            RemoteEmbeddings,
            RemoteEmbeddingsConfig,
        )

        # The checks run in this order and the first match wins, so the
        # order must be kept if any config class subclasses another.
        if isinstance(config, RemoteEmbeddingsConfig):
            return RemoteEmbeddings(config)
        elif isinstance(config, OpenAIEmbeddingsConfig):
            return OpenAIEmbeddings(config)
        elif isinstance(config, AzureOpenAIEmbeddingsConfig):
            return AzureOpenAIEmbeddings(config)
        elif isinstance(config, SentenceTransformerEmbeddingsConfig):
            return SentenceTransformerEmbeddings(config)
        elif isinstance(config, FastEmbedEmbeddingsConfig):
            return FastEmbedEmbeddings(config)
        elif isinstance(config, LlamaCppServerEmbeddingsConfig):
            return LlamaCppServerEmbeddings(config)
        else:
            # Report the class name itself; interpolating the uncalled
            # `__repr_name__` attribute would print a bound-method repr.
            raise ValueError(f"Unknown embedding config: {type(config).__name__}")

    @abstractmethod
    def embedding_fn(self) -> EmbeddingFunction:
        """Return the function that embeds a list of texts."""
        pass

    @property
    @abstractmethod
    def embedding_dims(self) -> int:
        """Return the embedding dimensionality as an int."""
        pass

    def similarity(self, text1: str, text2: str) -> float:
        """
        Compute cosine similarity between two texts.

        Both texts are embedded in a single call to the function returned by
        `embedding_fn`.

        Args:
            text1 (str): first text.
            text2 (str): second text.

        Returns:
            float: dot product of the two embeddings divided by the product
                of their norms.
        """
        [emb1, emb2] = self.embedding_fn()([text1, text2])
        return float(
            np.array(emb1)
            @ np.array(emb2)
            / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
        )
|
@@ -1,189 +0,0 @@
|
|
1
|
-
import logging
|
2
|
-
from collections import Counter
|
3
|
-
from typing import Callable, List, Tuple
|
4
|
-
|
5
|
-
import faiss
|
6
|
-
import numpy as np
|
7
|
-
from sklearn.cluster import DBSCAN
|
8
|
-
from sklearn.neighbors import NearestNeighbors
|
9
|
-
from sklearn.preprocessing import StandardScaler
|
10
|
-
|
11
|
-
from langroid.mytypes import Document
|
12
|
-
|
13
|
-
logging.getLogger("faiss").setLevel(logging.ERROR)
|
14
|
-
logging.getLogger("faiss-cpu").setLevel(logging.ERROR)
|
15
|
-
|
16
|
-
|
17
|
-
def find_optimal_clusters(X: np.ndarray, max_clusters: int, threshold=0.1) -> int:
    """
    Pick a cluster count for FAISS K-means with an elbow-style heuristic.

    K-means is trained for every cluster count from 1 up to `max_clusters`
    (capped at the number of rows of `X`) and the inertia of each run is
    recorded. The result is the first cluster count whose relative change in
    inertia to the next count is below `threshold`.

    Args:
        X (np.ndarray): A 2D NumPy array of data points, one row per point.
        max_clusters (int): The maximum number of clusters to try.
        threshold (float): Threshold for the relative change in inertia.
            Defaults to 0.1.

    Returns:
        int: The selected number of clusters; 1 when no relative change
            falls below `threshold`.
    """
    # There is no point in asking for more clusters than there are rows.
    upper = min(max_clusters, X.shape[0])
    dim = X.shape[1]

    inertias = []
    for n_centroids in range(1, upper + 1):
        km = faiss.Kmeans(dim, n_centroids, niter=20, verbose=False)
        km.train(X)
        # Squared distance from every point to every centroid,
        # shape (n_points, n_centroids).
        sq_dists = np.square(X[:, None] - km.centroids).sum(axis=-1)
        # Inertia: each point contributes its distance to the closest centroid.
        inertias.append(sq_dists.min(axis=-1).sum())

    # Relative change in inertia between consecutive cluster counts.
    relative_changes = [
        abs((after - before) / before)
        for before, after in zip(inertias, inertias[1:])
    ]

    # Position p compares (p + 1) clusters against (p + 2) clusters.
    for position, change in enumerate(relative_changes):
        if change < threshold:
            return position + 1

    return 1
|
56
|
-
|
57
|
-
|
58
|
-
def densest_clusters(
    embeddings: List[np.ndarray], k: int = 5
) -> List[Tuple[int, np.ndarray]]:
    """
    Find the top k densest clusters in the given list of embeddings using FAISS K-means.
    See here:
    'https://github.com/facebookresearch/faiss/wiki/Faiss-building-blocks%3A-clustering%
    2C-PCA%2C-quantization'

    Each cluster is represented by the embedding nearest to its centroid;
    clusters are ranked by that centroid-to-embedding distance, closest first.

    Args:
        embeddings (List[np.ndarray]): A list of embedding vectors.
        k (int, optional): The number of densest clusters to find. Defaults to 5.

    Returns:
        List[Tuple[int, np.ndarray]]: A list of (index, vector) pairs, where
            index is the position in `embeddings` of the representative
            vector of each selected cluster.
    """
    # Stack the embeddings into one 2D array, one row per embedding
    X = np.vstack(embeddings)

    # FAISS K-means clustering; never use more clusters than the elbow
    # heuristic suggests.
    ncentroids = find_optimal_clusters(X, max_clusters=2 * k, threshold=0.1)
    k = min(k, ncentroids)
    niter = 20
    verbose = True
    d = X.shape[1]
    kmeans = faiss.Kmeans(d, k, niter=niter, verbose=verbose)
    kmeans.train(X)

    # Get the cluster centroids
    centroids = kmeans.centroids

    # For each centroid, find its single nearest embedding: row i of
    # `distances` / `indices` belongs to centroid i, and `indices[i, 0]` is a
    # row number of X (i.e. a position in `embeddings`).
    nbrs = NearestNeighbors(n_neighbors=1, algorithm="auto").fit(X)
    distances, indices = nbrs.kneighbors(centroids)

    # Centroid numbers ordered by nearest-neighbor distance, closest first
    sorted_centroids_indices = np.argsort(distances, axis=0).flatten()

    # Select the top k centroids
    densest_centroids = sorted_centroids_indices[:k]

    # Map each selected centroid to its nearest embedding. The centroid
    # number must not be used to index `embeddings` directly: it identifies
    # a cluster, not a data point.
    representative_vectors = []
    for centroid_idx in densest_centroids:
        point_idx = int(indices[centroid_idx, 0])
        representative_vectors.append((point_idx, embeddings[point_idx]))

    return representative_vectors
|
106
|
-
|
107
|
-
|
108
|
-
def densest_clusters_DBSCAN(
    embeddings: np.ndarray, k: int = 10
) -> List[Tuple[int, np.ndarray]]:
    """
    Find the representative vector and corresponding index from each of the k densest
    clusters in the given embeddings.

    Clusters are found with DBSCAN on standardized embeddings and ranked by
    number of members. Points labelled as noise (-1) do not form a cluster
    and are excluded before the top k are chosen.

    Args:
        embeddings (np.ndarray): A NumPy array of shape (n, d), where n is the number
                                 of embedding vectors and d is their dimensionality.
        k (int): Number of densest clusters to find.

    Returns:
        List[Tuple[int, np.ndarray]]: A list of tuples containing the index and
                                      representative vector for each of the k densest
                                      clusters. Fewer than k tuples are returned
                                      when DBSCAN finds fewer than k clusters.
    """

    # Standardize the embeddings before clustering
    scaler = StandardScaler()
    embeddings_normalized = scaler.fit_transform(embeddings)

    # Choose a clustering algorithm (DBSCAN in this case)
    # Tune eps and min_samples for your use case
    dbscan = DBSCAN(eps=4, min_samples=5)

    # Apply the clustering algorithm
    cluster_labels = dbscan.fit_predict(embeddings_normalized)

    # Count members per cluster, leaving out the noise label (-1) so that it
    # cannot occupy one of the top-k slots.
    cluster_density = Counter(cluster_labels[cluster_labels != -1].tolist())

    # Sort clusters by their member count, largest first
    sorted_clusters = sorted(cluster_density.items(), key=lambda x: x[1], reverse=True)

    # Select top-k densest clusters
    top_k_clusters = sorted_clusters[:k]

    # Find a representative vector for each cluster: the member closest to
    # the cluster centroid, measured on the original (unscaled) embeddings.
    representatives = []
    for cluster_id, _ in top_k_clusters:
        indices = np.where(cluster_labels == cluster_id)[0]
        centroid = embeddings[indices].mean(axis=0)
        closest_index = indices[
            np.argmin(np.linalg.norm(embeddings[indices] - centroid, axis=1))
        ]
        representatives.append((closest_index, embeddings[closest_index]))

    return representatives
|
159
|
-
|
160
|
-
|
161
|
-
def densest_doc_clusters(
    docs: List[Document],
    k: int,
    embedding_fn: Callable[[List[str]], np.ndarray],
) -> List[Document]:
    """
    Find the documents corresponding to the representative vectors of the k densest
    clusters in the given list of documents.

    Args:
        docs (List[Document]): A list of Document instances; the "content"
            field of each is embedded.
        k (int): Number of densest clusters to find.
        embedding_fn (Callable[[List[str]], np.ndarray]): A function that is
            called once with the list of all document contents and returns
            their embeddings, one per document; the result is converted with
            `np.array`.

    Returns:
        List[Document]: A list of Document instances corresponding to the representative
            vectors of the k densest clusters. Empty when `docs` is empty.
    """
    if not docs:
        # Nothing to embed or cluster.
        return []

    # Embed all document contents in a single call
    embeddings = np.array(embedding_fn([doc.content for doc in docs]))

    # Find the densest clusters and their representative indices
    representative_indices_and_vectors = densest_clusters(embeddings, k)

    # Extract the corresponding documents; each pair is (index, vector)
    representative_docs = [docs[i] for i, _ in representative_indices_and_vectors]

    return representative_docs
|