langroid 0.31.1__py3-none-any.whl → 0.33.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. {langroid-0.31.1.dist-info → langroid-0.33.3.dist-info}/METADATA +150 -124
  2. langroid-0.33.3.dist-info/RECORD +7 -0
  3. {langroid-0.31.1.dist-info → langroid-0.33.3.dist-info}/WHEEL +1 -1
  4. langroid-0.33.3.dist-info/entry_points.txt +4 -0
  5. pyproject.toml +317 -212
  6. langroid/__init__.py +0 -106
  7. langroid/agent/.chainlit/config.toml +0 -121
  8. langroid/agent/.chainlit/translations/bn.json +0 -231
  9. langroid/agent/.chainlit/translations/en-US.json +0 -229
  10. langroid/agent/.chainlit/translations/gu.json +0 -231
  11. langroid/agent/.chainlit/translations/he-IL.json +0 -231
  12. langroid/agent/.chainlit/translations/hi.json +0 -231
  13. langroid/agent/.chainlit/translations/kn.json +0 -231
  14. langroid/agent/.chainlit/translations/ml.json +0 -231
  15. langroid/agent/.chainlit/translations/mr.json +0 -231
  16. langroid/agent/.chainlit/translations/ta.json +0 -231
  17. langroid/agent/.chainlit/translations/te.json +0 -231
  18. langroid/agent/.chainlit/translations/zh-CN.json +0 -229
  19. langroid/agent/__init__.py +0 -41
  20. langroid/agent/base.py +0 -1981
  21. langroid/agent/batch.py +0 -398
  22. langroid/agent/callbacks/__init__.py +0 -0
  23. langroid/agent/callbacks/chainlit.py +0 -598
  24. langroid/agent/chat_agent.py +0 -1899
  25. langroid/agent/chat_document.py +0 -454
  26. langroid/agent/helpers.py +0 -0
  27. langroid/agent/junk +0 -13
  28. langroid/agent/openai_assistant.py +0 -882
  29. langroid/agent/special/__init__.py +0 -59
  30. langroid/agent/special/arangodb/__init__.py +0 -0
  31. langroid/agent/special/arangodb/arangodb_agent.py +0 -656
  32. langroid/agent/special/arangodb/system_messages.py +0 -186
  33. langroid/agent/special/arangodb/tools.py +0 -107
  34. langroid/agent/special/arangodb/utils.py +0 -36
  35. langroid/agent/special/doc_chat_agent.py +0 -1466
  36. langroid/agent/special/lance_doc_chat_agent.py +0 -262
  37. langroid/agent/special/lance_rag/__init__.py +0 -9
  38. langroid/agent/special/lance_rag/critic_agent.py +0 -198
  39. langroid/agent/special/lance_rag/lance_rag_task.py +0 -82
  40. langroid/agent/special/lance_rag/query_planner_agent.py +0 -260
  41. langroid/agent/special/lance_tools.py +0 -61
  42. langroid/agent/special/neo4j/__init__.py +0 -0
  43. langroid/agent/special/neo4j/csv_kg_chat.py +0 -174
  44. langroid/agent/special/neo4j/neo4j_chat_agent.py +0 -433
  45. langroid/agent/special/neo4j/system_messages.py +0 -120
  46. langroid/agent/special/neo4j/tools.py +0 -32
  47. langroid/agent/special/relevance_extractor_agent.py +0 -127
  48. langroid/agent/special/retriever_agent.py +0 -56
  49. langroid/agent/special/sql/__init__.py +0 -17
  50. langroid/agent/special/sql/sql_chat_agent.py +0 -654
  51. langroid/agent/special/sql/utils/__init__.py +0 -21
  52. langroid/agent/special/sql/utils/description_extractors.py +0 -190
  53. langroid/agent/special/sql/utils/populate_metadata.py +0 -85
  54. langroid/agent/special/sql/utils/system_message.py +0 -35
  55. langroid/agent/special/sql/utils/tools.py +0 -64
  56. langroid/agent/special/table_chat_agent.py +0 -263
  57. langroid/agent/structured_message.py +0 -9
  58. langroid/agent/task.py +0 -2093
  59. langroid/agent/tool_message.py +0 -393
  60. langroid/agent/tools/__init__.py +0 -38
  61. langroid/agent/tools/duckduckgo_search_tool.py +0 -50
  62. langroid/agent/tools/file_tools.py +0 -234
  63. langroid/agent/tools/google_search_tool.py +0 -39
  64. langroid/agent/tools/metaphor_search_tool.py +0 -67
  65. langroid/agent/tools/orchestration.py +0 -303
  66. langroid/agent/tools/recipient_tool.py +0 -235
  67. langroid/agent/tools/retrieval_tool.py +0 -32
  68. langroid/agent/tools/rewind_tool.py +0 -137
  69. langroid/agent/tools/segment_extract_tool.py +0 -41
  70. langroid/agent/typed_task.py +0 -19
  71. langroid/agent/xml_tool_message.py +0 -382
  72. langroid/agent_config.py +0 -0
  73. langroid/cachedb/__init__.py +0 -17
  74. langroid/cachedb/base.py +0 -58
  75. langroid/cachedb/momento_cachedb.py +0 -108
  76. langroid/cachedb/redis_cachedb.py +0 -153
  77. langroid/embedding_models/__init__.py +0 -39
  78. langroid/embedding_models/base.py +0 -74
  79. langroid/embedding_models/clustering.py +0 -189
  80. langroid/embedding_models/models.py +0 -461
  81. langroid/embedding_models/protoc/__init__.py +0 -0
  82. langroid/embedding_models/protoc/embeddings.proto +0 -19
  83. langroid/embedding_models/protoc/embeddings_pb2.py +0 -33
  84. langroid/embedding_models/protoc/embeddings_pb2.pyi +0 -50
  85. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +0 -79
  86. langroid/embedding_models/remote_embeds.py +0 -153
  87. langroid/exceptions.py +0 -65
  88. langroid/experimental/team-save.py +0 -391
  89. langroid/language_models/.chainlit/config.toml +0 -121
  90. langroid/language_models/.chainlit/translations/en-US.json +0 -231
  91. langroid/language_models/__init__.py +0 -53
  92. langroid/language_models/azure_openai.py +0 -153
  93. langroid/language_models/base.py +0 -678
  94. langroid/language_models/config.py +0 -18
  95. langroid/language_models/mock_lm.py +0 -124
  96. langroid/language_models/openai_gpt.py +0 -1923
  97. langroid/language_models/prompt_formatter/__init__.py +0 -16
  98. langroid/language_models/prompt_formatter/base.py +0 -40
  99. langroid/language_models/prompt_formatter/hf_formatter.py +0 -132
  100. langroid/language_models/prompt_formatter/llama2_formatter.py +0 -75
  101. langroid/language_models/utils.py +0 -147
  102. langroid/mytypes.py +0 -84
  103. langroid/parsing/__init__.py +0 -52
  104. langroid/parsing/agent_chats.py +0 -38
  105. langroid/parsing/code-parsing.md +0 -86
  106. langroid/parsing/code_parser.py +0 -121
  107. langroid/parsing/config.py +0 -0
  108. langroid/parsing/document_parser.py +0 -718
  109. langroid/parsing/image_text.py +0 -32
  110. langroid/parsing/para_sentence_split.py +0 -62
  111. langroid/parsing/parse_json.py +0 -155
  112. langroid/parsing/parser.py +0 -313
  113. langroid/parsing/repo_loader.py +0 -790
  114. langroid/parsing/routing.py +0 -36
  115. langroid/parsing/search.py +0 -275
  116. langroid/parsing/spider.py +0 -102
  117. langroid/parsing/table_loader.py +0 -94
  118. langroid/parsing/url_loader.py +0 -111
  119. langroid/parsing/url_loader_cookies.py +0 -73
  120. langroid/parsing/urls.py +0 -273
  121. langroid/parsing/utils.py +0 -373
  122. langroid/parsing/web_search.py +0 -155
  123. langroid/prompts/__init__.py +0 -9
  124. langroid/prompts/chat-gpt4-system-prompt.md +0 -68
  125. langroid/prompts/dialog.py +0 -17
  126. langroid/prompts/prompts_config.py +0 -5
  127. langroid/prompts/templates.py +0 -141
  128. langroid/pydantic_v1/__init__.py +0 -10
  129. langroid/pydantic_v1/main.py +0 -4
  130. langroid/utils/.chainlit/config.toml +0 -121
  131. langroid/utils/.chainlit/translations/en-US.json +0 -231
  132. langroid/utils/__init__.py +0 -19
  133. langroid/utils/algorithms/__init__.py +0 -3
  134. langroid/utils/algorithms/graph.py +0 -103
  135. langroid/utils/configuration.py +0 -98
  136. langroid/utils/constants.py +0 -30
  137. langroid/utils/docker.py +0 -37
  138. langroid/utils/git_utils.py +0 -252
  139. langroid/utils/globals.py +0 -49
  140. langroid/utils/llms/__init__.py +0 -0
  141. langroid/utils/llms/strings.py +0 -8
  142. langroid/utils/logging.py +0 -135
  143. langroid/utils/object_registry.py +0 -66
  144. langroid/utils/output/__init__.py +0 -20
  145. langroid/utils/output/citations.py +0 -41
  146. langroid/utils/output/printing.py +0 -99
  147. langroid/utils/output/status.py +0 -40
  148. langroid/utils/pandas_utils.py +0 -30
  149. langroid/utils/pydantic_utils.py +0 -602
  150. langroid/utils/system.py +0 -286
  151. langroid/utils/types.py +0 -93
  152. langroid/utils/web/__init__.py +0 -0
  153. langroid/utils/web/login.py +0 -83
  154. langroid/vector_store/__init__.py +0 -50
  155. langroid/vector_store/base.py +0 -357
  156. langroid/vector_store/chromadb.py +0 -214
  157. langroid/vector_store/lancedb.py +0 -401
  158. langroid/vector_store/meilisearch.py +0 -299
  159. langroid/vector_store/momento.py +0 -278
  160. langroid/vector_store/qdrant_cloud.py +0 -6
  161. langroid/vector_store/qdrantdb.py +0 -468
  162. langroid-0.31.1.dist-info/RECORD +0 -162
  163. {langroid-0.31.1.dist-info → langroid-0.33.3.dist-info/licenses}/LICENSE +0 -0
@@ -1,153 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- from contextlib import AbstractContextManager, contextmanager
5
- from typing import Any, Dict, List, TypeVar
6
-
7
- import fakeredis
8
- import redis
9
- from dotenv import load_dotenv
10
-
11
- from langroid.cachedb.base import CacheDB, CacheDBConfig
12
-
13
# Type variable bound to RedisCache; it is used in the redis_client annotation.
T = TypeVar("T", bound="RedisCache")

# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
15
-
16
-
17
class RedisCacheConfig(CacheDBConfig):
    """
    Settings for `RedisCache`.

    Attributes:
        fake: when True, `RedisCache` always uses a `fakeredis.FakeStrictRedis`
            client and never reads the REDIS_* environment variables.
    """

    # Force the fakeredis client (no real Redis server needed).
    fake: bool = False
21
-
22
-
23
class RedisCache(CacheDB):
    """Redis implementation of the CacheDB."""

    # Class-level flag: the "falling back to fake redis" warning is logged only
    # once, no matter how many RedisCache instances are created.
    _warned_password: bool = False

    def __init__(self, config: RedisCacheConfig):
        """
        Initialize a RedisCache with the given config.

        `.env` is loaded first. A `fakeredis.FakeStrictRedis` client is used
        when `config.fake` is True, or when any of REDIS_PASSWORD, REDIS_HOST or
        REDIS_PORT is unset. Otherwise a `redis.ConnectionPool` is created from
        those variables.

        Args:
            config (RedisCacheConfig): The configuration to use.
        """
        self.config = config
        load_dotenv()

        if self.config.fake:
            self.pool = fakeredis.FakeStrictRedis()  # type: ignore
        else:
            redis_password = os.getenv("REDIS_PASSWORD")
            redis_host = os.getenv("REDIS_HOST")
            redis_port = os.getenv("REDIS_PORT")
            if None in [redis_password, redis_host, redis_port]:
                if not RedisCache._warned_password:
                    logger.warning(
                        """REDIS_PASSWORD, REDIS_HOST, REDIS_PORT not set in .env file,
                        using fake redis client"""
                    )
                    RedisCache._warned_password = True
                self.pool = fakeredis.FakeStrictRedis()  # type: ignore
            else:
                self.pool = redis.ConnectionPool(  # type: ignore
                    host=redis_host,
                    port=redis_port,
                    password=redis_password,
                    max_connections=500,
                    socket_timeout=5,
                    socket_keepalive=True,
                    retry_on_timeout=True,
                    health_check_interval=30,
                )

    @contextmanager  # type: ignore
    def redis_client(self) -> AbstractContextManager[T]:  # type: ignore
        """
        Yield a redis client for the duration of a `with` block.

        With a real connection pool, a new `redis.Redis` client is created on
        the shared pool and closed on exit. This avoids "max clients exceeded"
        errors. The fake client is shared: it is yielded as-is and not closed.
        """
        if isinstance(self.pool, fakeredis.FakeStrictRedis):
            yield self.pool
        else:
            client: T = redis.Redis(connection_pool=self.pool)
            try:
                yield client
            finally:
                client.close()

    def close_all_connections(self) -> None:
        """
        Kill every client connection reported by the server's client list.

        NOTE(review): this presumably includes the connection used to issue
        the command itself; confirm this is intended.
        """
        with self.redis_client() as client:  # type: ignore
            clients = client.client_list()
            for c in clients:
                client.client_kill(c["addr"])

    def clear(self) -> None:
        """Clear keys from current db."""
        with self.redis_client() as client:  # type: ignore
            client.flushdb()

    def clear_all(self) -> None:
        """Clear all keys from all dbs."""
        with self.redis_client() as client:  # type: ignore
            client.flushall()

    def store(self, key: str, value: Any) -> None:
        """
        Store a value associated with a key.

        The value is serialized with `json.dumps`. A connection error is
        logged and the value is not stored.

        Args:
            key (str): The key under which to store the value.
            value (Any): The value to store (must be JSON-serializable).
        """
        with self.redis_client() as client:  # type: ignore
            try:
                client.set(key, json.dumps(value))
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, not storing key/value")
                return None

    def retrieve(self, key: str) -> Dict[str, Any] | str | None:
        """
        Retrieve the value associated with a key.

        Args:
            key (str): The key to retrieve the value for.

        Returns:
            dict|str|None: The JSON-decoded value associated with the key.
                Returns None if the key is missing or on a connection error.
        """
        with self.redis_client() as client:  # type: ignore
            try:
                value = client.get(key)
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, returning None")
                return None
            return json.loads(value) if value else None

    def delete_keys(self, keys: List[str]) -> None:
        """
        Delete the keys from the cache.

        An empty list is a no-op.

        Args:
            keys (List[str]): The keys to delete.
        """
        if not keys:
            # Redis DEL needs at least one key. Calling it with none makes the
            # server reply with a "wrong number of arguments" error.
            return None
        with self.redis_client() as client:  # type: ignore
            try:
                client.delete(*keys)
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, not deleting keys")
                return None

    def delete_keys_pattern(self, pattern: str) -> None:
        """
        Delete the keys matching the pattern from the cache.

        Matching keys are collected with incremental SCAN rather than KEYS.
        KEYS walks the whole keyspace in one blocking call.

        Args:
            pattern (str): The glob-style pattern to match (e.g. "prefix*").
        """
        with self.redis_client() as client:  # type: ignore
            try:
                # Collect all matches first, then delete, so deletion does not
                # interfere with the ongoing cursor iteration.
                keys = list(client.scan_iter(match=pattern))
                if len(keys) > 0:
                    client.delete(*keys)
            except redis.exceptions.ConnectionError:
                logger.warning("Redis connection error, not deleting keys")
                return None
@@ -1,39 +0,0 @@
1
- from . import base
2
- from . import models
3
- from . import remote_embeds
4
-
5
- from .base import (
6
- EmbeddingModel,
7
- EmbeddingModelsConfig,
8
- )
9
- from .models import (
10
- OpenAIEmbeddings,
11
- OpenAIEmbeddingsConfig,
12
- SentenceTransformerEmbeddings,
13
- SentenceTransformerEmbeddingsConfig,
14
- LlamaCppServerEmbeddings,
15
- LlamaCppServerEmbeddingsConfig,
16
- embedding_model,
17
- )
18
- from .remote_embeds import (
19
- RemoteEmbeddingsConfig,
20
- RemoteEmbeddings,
21
- )
22
-
23
-
24
# Public API of the embedding_models package: its three submodules, followed by
# the names re-exported from each of them above.
__all__ = [
    # submodules
    "base",
    "models",
    "remote_embeds",
    # re-exported from .base
    "EmbeddingModel",
    "EmbeddingModelsConfig",
    # re-exported from .models
    "OpenAIEmbeddings",
    "OpenAIEmbeddingsConfig",
    "SentenceTransformerEmbeddings",
    "SentenceTransformerEmbeddingsConfig",
    "LlamaCppServerEmbeddings",
    "LlamaCppServerEmbeddingsConfig",
    "embedding_model",
    # re-exported from .remote_embeds
    "RemoteEmbeddingsConfig",
    "RemoteEmbeddings",
]
@@ -1,74 +0,0 @@
1
- import logging
2
- from abc import ABC, abstractmethod
3
-
4
- import numpy as np
5
-
6
- from langroid.mytypes import EmbeddingFunction
7
- from langroid.pydantic_v1 import BaseSettings
8
-
9
# Quieten the "openai" logger: only ERROR and above are emitted.
logging.getLogger(name="openai").setLevel(level=logging.ERROR)
10
-
11
-
12
class EmbeddingModelsConfig(BaseSettings):
    """
    Base settings shared by the embedding-model configs.

    NOTE(review): `EmbeddingModel.create` picks the model by config *class*,
    not by `model_type`. The role of `model_type` is presumably descriptive;
    confirm with its users.
    """

    # Name of the embedding backend (default "openai").
    model_type: str = "openai"
    # Embedding dimensionality; 0 by default.
    dims: int = 0
    # Context length setting (default 512).
    context_length: int = 512
    # Batch size setting (default 512).
    batch_size: int = 512
17
-
18
-
19
class EmbeddingModel(ABC):
    """
    Abstract base class for an embedding model.

    Concrete subclasses implement `embedding_fn` and the `embedding_dims`
    property. `create` builds the concrete model matching a config object.
    """

    @classmethod
    def create(cls, config: "EmbeddingModelsConfig") -> "EmbeddingModel":
        """
        Instantiate the embedding model corresponding to the type of `config`.

        Args:
            config (EmbeddingModelsConfig): config of one of the supported
                models (remote, OpenAI, Azure OpenAI, SentenceTransformer,
                FastEmbed, llama.cpp server).

        Returns:
            EmbeddingModel: the concrete model built from `config`.

        Raises:
            ValueError: if `config` is not an instance of any supported
                config class.
        """
        # Imported inside the method. NOTE(review): presumably to avoid a
        # circular import with the modules that subclass EmbeddingModel.
        from langroid.embedding_models.models import (
            AzureOpenAIEmbeddings,
            AzureOpenAIEmbeddingsConfig,
            FastEmbedEmbeddings,
            FastEmbedEmbeddingsConfig,
            LlamaCppServerEmbeddings,
            LlamaCppServerEmbeddingsConfig,
            OpenAIEmbeddings,
            OpenAIEmbeddingsConfig,
            SentenceTransformerEmbeddings,
            SentenceTransformerEmbeddingsConfig,
        )
        from langroid.embedding_models.remote_embeds import (
            RemoteEmbeddings,
            RemoteEmbeddingsConfig,
        )

        # The isinstance checks run in this order: the first matching branch
        # wins for configs that inherit from more than one of these classes.
        if isinstance(config, RemoteEmbeddingsConfig):
            return RemoteEmbeddings(config)
        elif isinstance(config, OpenAIEmbeddingsConfig):
            return OpenAIEmbeddings(config)
        elif isinstance(config, AzureOpenAIEmbeddingsConfig):
            return AzureOpenAIEmbeddings(config)
        elif isinstance(config, SentenceTransformerEmbeddingsConfig):
            return SentenceTransformerEmbeddings(config)
        elif isinstance(config, FastEmbedEmbeddingsConfig):
            return FastEmbedEmbeddings(config)
        elif isinstance(config, LlamaCppServerEmbeddingsConfig):
            return LlamaCppServerEmbeddings(config)
        else:
            # Report the config's class name. (`__repr_name__` is a method and
            # must not be interpolated uncalled.)
            raise ValueError(f"Unknown embedding config: {type(config).__name__}")

    @abstractmethod
    def embedding_fn(self) -> "EmbeddingFunction":
        """Return a callable that maps a list of texts to their embeddings."""
        pass

    @property
    @abstractmethod
    def embedding_dims(self) -> int:
        """Embedding dimensionality (implemented by subclasses)."""
        pass

    def similarity(self, text1: str, text2: str) -> float:
        """
        Compute the cosine similarity between the embeddings of two texts.

        Both texts are embedded in a single `embedding_fn` call. If either
        embedding has zero norm, the result is NaN (numpy division by zero).

        Args:
            text1 (str): first text.
            text2 (str): second text.

        Returns:
            float: cosine similarity of the two embeddings.
        """
        emb1, emb2 = self.embedding_fn()([text1, text2])
        v1 = np.asarray(emb1)
        v2 = np.asarray(emb2)
        return float(v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2)))
@@ -1,189 +0,0 @@
1
- import logging
2
- from collections import Counter
3
- from typing import Callable, List, Tuple
4
-
5
- import faiss
6
- import numpy as np
7
- from sklearn.cluster import DBSCAN
8
- from sklearn.neighbors import NearestNeighbors
9
- from sklearn.preprocessing import StandardScaler
10
-
11
- from langroid.mytypes import Document
12
-
13
# Quieten faiss logging under both logger names it may use: only ERROR and
# above are emitted.
for _faiss_logger_name in ("faiss", "faiss-cpu"):
    logging.getLogger(_faiss_logger_name).setLevel(logging.ERROR)
del _faiss_logger_name
15
-
16
-
17
def find_optimal_clusters(
    X: np.ndarray, max_clusters: int, threshold: float = 0.1
) -> int:
    """
    Find the optimal number of clusters for FAISS K-means using the Elbow Method.

    K-means is run for every cluster count from 1 to `max_clusters`, capped at
    the number of rows in `X`. For each count it records the inertia: the sum
    of squared distances from each point to its nearest centroid. The result
    is the first cluster count after which the relative drop in inertia falls
    below `threshold`. It is also the first count whose inertia is already 0,
    since more clusters cannot reduce it further.

    Args:
        X (np.ndarray): A 2D NumPy array of data points, one per row.
        max_clusters (int): The maximum number of clusters to try.
        threshold (float): Threshold for the rate of change in inertia values.
            Defaults to 0.1.

    Returns:
        int: The optimal number of clusters (1 if no elbow is found).
    """
    inertias = []
    max_clusters = min(max_clusters, X.shape[0])

    for nclusters in range(1, max_clusters + 1):
        kmeans = faiss.Kmeans(X.shape[1], nclusters, niter=20, verbose=False)
        kmeans.train(X)
        centroids = kmeans.centroids
        # Squared distance from every point to every centroid: (n, nclusters).
        distances = np.sum(np.square(X[:, None] - centroids), axis=-1)
        inertias.append(float(np.sum(np.min(distances, axis=-1))))

    # inertias[i] belongs to i + 1 clusters; compare each with the next count.
    for i, (prev, curr) in enumerate(zip(inertias, inertias[1:])):
        if prev == 0:
            # The data is already fit exactly. Dividing by 0 here would give a
            # NaN that never compares below `threshold`, so the elbow would be
            # missed.
            return i + 1
        if abs((curr - prev) / prev) < threshold:
            return i + 1

    return 1
56
-
57
-
58
def densest_clusters(
    embeddings: List[np.ndarray], k: int = 5
) -> List[Tuple[int, np.ndarray]]:
    """
    Find the top k densest clusters in the given list of embeddings using FAISS K-means.
    See here:
    'https://github.com/facebookresearch/faiss/wiki/Faiss-building-blocks%3A-clustering%
    2C-PCA%2C-quantization'

    Each cluster is represented by the embedding nearest to its centroid.
    Clusters are ranked by the distance between the centroid and that nearest
    embedding, smallest first.

    Args:
        embeddings (List[np.ndarray]): A list of embedding vectors.
        k (int, optional): The number of densest clusters to find. Defaults to 5.

    Returns:
        List[Tuple[int, np.ndarray]]: (index, vector) pairs. `index` is the
            position in `embeddings` of a representative vector. When two
            centroids share the same nearest embedding it is returned once, so
            fewer than k pairs may be returned.
    """
    X = np.vstack(embeddings)

    # FAISS K-means clustering, with k capped by the elbow estimate.
    ncentroids = find_optimal_clusters(X, max_clusters=2 * k, threshold=0.1)
    k = min(k, ncentroids)
    kmeans = faiss.Kmeans(X.shape[1], k, niter=20, verbose=False)
    kmeans.train(X)
    centroids = kmeans.centroids

    # For each centroid: distance to, and row index in X of, its nearest
    # embedding. Both arrays have shape (k, 1).
    nbrs = NearestNeighbors(n_neighbors=1, algorithm="auto").fit(X)
    distances, indices = nbrs.kneighbors(centroids)

    # Centroid ordinals (0..k-1), closest-to-data first.
    densest_centroids = np.argsort(distances[:, 0])[:k]

    # Map each centroid ordinal to the row index of its nearest embedding.
    # Centroid ordinals are NOT positions in `embeddings`.
    representative_vectors: List[Tuple[int, np.ndarray]] = []
    seen = set()
    for c in densest_centroids:
        idx = int(indices[c, 0])
        if idx in seen:
            continue
        seen.add(idx)
        representative_vectors.append((idx, embeddings[idx]))

    return representative_vectors
106
-
107
-
108
def densest_clusters_DBSCAN(
    embeddings: np.ndarray,
    k: int = 10,
    *,
    eps: float = 4,
    min_samples: int = 5,
) -> List[Tuple[int, np.ndarray]]:
    """
    Find the representative vector and corresponding index from each of the k densest
    clusters in the given embeddings.

    Embeddings are standardized per feature and clustered with DBSCAN. Clusters
    are ranked by member count; DBSCAN's noise label (-1) is excluded before
    ranking. Each cluster is represented by the member closest to the cluster's
    mean, with the mean computed on the original, unstandardized embeddings.

    Args:
        embeddings (np.ndarray): A NumPy array of shape (n, d), where n is the number
                                 of embedding vectors and d is their dimensionality.
        k (int): Number of densest clusters to find.
        eps (float): DBSCAN neighbourhood radius, in standardized units.
            Defaults to 4.
        min_samples (int): DBSCAN core-point minimum. Defaults to 5.

    Returns:
        List[Tuple[int, np.ndarray]]: A list of tuples containing the index and
                                      representative vector for each of the k densest
                                      clusters. There are fewer than k tuples if
                                      DBSCAN finds fewer clusters.
    """
    embeddings = np.asarray(embeddings)

    # Standardize each feature before clustering.
    embeddings_normalized = StandardScaler().fit_transform(embeddings)
    cluster_labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(
        embeddings_normalized
    )

    # Cluster sizes without the noise label, so noise never takes a slot.
    cluster_sizes = Counter(cluster_labels)
    cluster_sizes.pop(-1, None)

    representatives = []
    # most_common keeps first-seen order among ties (like a stable sort).
    for cluster_id, _ in cluster_sizes.most_common(k):
        members = np.flatnonzero(cluster_labels == cluster_id)
        centroid = embeddings[members].mean(axis=0)
        closest_index = members[
            np.argmin(np.linalg.norm(embeddings[members] - centroid, axis=1))
        ]
        representatives.append((closest_index, embeddings[closest_index]))

    return representatives
159
-
160
-
161
def densest_doc_clusters(
    docs: List[Document],
    k: int,
    embedding_fn: Callable[[List[str]], np.ndarray],
) -> List[Document]:
    """
    Find the documents corresponding to the representative vectors of the k densest
    clusters in the given list of documents.

    Args:
        docs (List[Document]): A list of Document instances, each with a
            `content` field that is embedded.
        k (int): Number of densest clusters to find.
        embedding_fn (Callable[[List[str]], np.ndarray]): A function that maps a
            list of strings to their embedding vectors, one per string. It is
            called once with all document contents.

    Returns:
        List[Document]: A list of Document instances corresponding to the representative
                        vectors of the k densest clusters. Returns an empty list when
                        `docs` is empty.
    """
    if not docs:
        # Nothing to cluster. np.vstack would fail on an empty input.
        return []

    # Embed all document contents in a single call.
    embeddings = np.array(embedding_fn([doc.content for doc in docs]))

    # (index, vector) pairs; each index is a position in `docs`.
    representative_indices_and_vectors = densest_clusters(embeddings, k)

    return [docs[i] for i, _ in representative_indices_and_vectors]