langroid 0.33.6__py3-none-any.whl → 0.33.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +106 -0
- langroid/agent/__init__.py +41 -0
- langroid/agent/base.py +1983 -0
- langroid/agent/batch.py +398 -0
- langroid/agent/callbacks/__init__.py +0 -0
- langroid/agent/callbacks/chainlit.py +598 -0
- langroid/agent/chat_agent.py +1899 -0
- langroid/agent/chat_document.py +454 -0
- langroid/agent/openai_assistant.py +882 -0
- langroid/agent/special/__init__.py +59 -0
- langroid/agent/special/arangodb/__init__.py +0 -0
- langroid/agent/special/arangodb/arangodb_agent.py +656 -0
- langroid/agent/special/arangodb/system_messages.py +186 -0
- langroid/agent/special/arangodb/tools.py +107 -0
- langroid/agent/special/arangodb/utils.py +36 -0
- langroid/agent/special/doc_chat_agent.py +1466 -0
- langroid/agent/special/lance_doc_chat_agent.py +262 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +198 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +82 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +260 -0
- langroid/agent/special/lance_tools.py +61 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +433 -0
- langroid/agent/special/neo4j/system_messages.py +120 -0
- langroid/agent/special/neo4j/tools.py +32 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +56 -0
- langroid/agent/special/sql/__init__.py +17 -0
- langroid/agent/special/sql/sql_chat_agent.py +654 -0
- langroid/agent/special/sql/utils/__init__.py +21 -0
- langroid/agent/special/sql/utils/description_extractors.py +190 -0
- langroid/agent/special/sql/utils/populate_metadata.py +85 -0
- langroid/agent/special/sql/utils/system_message.py +35 -0
- langroid/agent/special/sql/utils/tools.py +64 -0
- langroid/agent/special/table_chat_agent.py +263 -0
- langroid/agent/task.py +2095 -0
- langroid/agent/tool_message.py +393 -0
- langroid/agent/tools/__init__.py +38 -0
- langroid/agent/tools/duckduckgo_search_tool.py +50 -0
- langroid/agent/tools/file_tools.py +234 -0
- langroid/agent/tools/google_search_tool.py +39 -0
- langroid/agent/tools/metaphor_search_tool.py +68 -0
- langroid/agent/tools/orchestration.py +303 -0
- langroid/agent/tools/recipient_tool.py +235 -0
- langroid/agent/tools/retrieval_tool.py +32 -0
- langroid/agent/tools/rewind_tool.py +137 -0
- langroid/agent/tools/segment_extract_tool.py +41 -0
- langroid/agent/xml_tool_message.py +382 -0
- langroid/cachedb/__init__.py +17 -0
- langroid/cachedb/base.py +58 -0
- langroid/cachedb/momento_cachedb.py +108 -0
- langroid/cachedb/redis_cachedb.py +153 -0
- langroid/embedding_models/__init__.py +39 -0
- langroid/embedding_models/base.py +74 -0
- langroid/embedding_models/models.py +461 -0
- langroid/embedding_models/protoc/__init__.py +0 -0
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/exceptions.py +71 -0
- langroid/language_models/__init__.py +53 -0
- langroid/language_models/azure_openai.py +153 -0
- langroid/language_models/base.py +678 -0
- langroid/language_models/config.py +18 -0
- langroid/language_models/mock_lm.py +124 -0
- langroid/language_models/openai_gpt.py +1964 -0
- langroid/language_models/prompt_formatter/__init__.py +16 -0
- langroid/language_models/prompt_formatter/base.py +40 -0
- langroid/language_models/prompt_formatter/hf_formatter.py +132 -0
- langroid/language_models/prompt_formatter/llama2_formatter.py +75 -0
- langroid/language_models/utils.py +151 -0
- langroid/mytypes.py +84 -0
- langroid/parsing/__init__.py +52 -0
- langroid/parsing/agent_chats.py +38 -0
- langroid/parsing/code_parser.py +121 -0
- langroid/parsing/document_parser.py +718 -0
- langroid/parsing/para_sentence_split.py +62 -0
- langroid/parsing/parse_json.py +155 -0
- langroid/parsing/parser.py +313 -0
- langroid/parsing/repo_loader.py +790 -0
- langroid/parsing/routing.py +36 -0
- langroid/parsing/search.py +275 -0
- langroid/parsing/spider.py +102 -0
- langroid/parsing/table_loader.py +94 -0
- langroid/parsing/url_loader.py +111 -0
- langroid/parsing/urls.py +273 -0
- langroid/parsing/utils.py +373 -0
- langroid/parsing/web_search.py +156 -0
- langroid/prompts/__init__.py +9 -0
- langroid/prompts/dialog.py +17 -0
- langroid/prompts/prompts_config.py +5 -0
- langroid/prompts/templates.py +141 -0
- langroid/pydantic_v1/__init__.py +10 -0
- langroid/pydantic_v1/main.py +4 -0
- langroid/utils/__init__.py +19 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +98 -0
- langroid/utils/constants.py +30 -0
- langroid/utils/git_utils.py +252 -0
- langroid/utils/globals.py +49 -0
- langroid/utils/logging.py +135 -0
- langroid/utils/object_registry.py +66 -0
- langroid/utils/output/__init__.py +20 -0
- langroid/utils/output/citations.py +41 -0
- langroid/utils/output/printing.py +99 -0
- langroid/utils/output/status.py +40 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +602 -0
- langroid/utils/system.py +286 -0
- langroid/utils/types.py +93 -0
- langroid/vector_store/__init__.py +50 -0
- langroid/vector_store/base.py +359 -0
- langroid/vector_store/chromadb.py +214 -0
- langroid/vector_store/lancedb.py +406 -0
- langroid/vector_store/meilisearch.py +299 -0
- langroid/vector_store/momento.py +278 -0
- langroid/vector_store/qdrantdb.py +468 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/METADATA +95 -94
- langroid-0.33.7.dist-info/RECORD +127 -0
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/WHEEL +1 -1
- langroid-0.33.6.dist-info/RECORD +0 -7
- langroid-0.33.6.dist-info/entry_points.txt +0 -4
- pyproject.toml +0 -356
- {langroid-0.33.6.dist-info → langroid-0.33.7.dist-info}/licenses/LICENSE +0 -0
langroid/parsing/urls.py
ADDED
@@ -0,0 +1,273 @@
import logging
import os
import tempfile
import urllib.parse
import urllib.robotparser
from typing import List, Optional, Set, Tuple
from urllib.parse import urldefrag, urljoin, urlparse

import fire
import requests
from bs4 import BeautifulSoup
from rich import print
from rich.prompt import Prompt
from trafilatura.spider import focused_crawler

from langroid.pydantic_v1 import BaseModel, HttpUrl, ValidationError, parse_obj_as

logger = logging.getLogger(__name__)


def url_to_tempfile(url: str) -> str:
    """
    Fetch content from the given URL and save it to a temporary local file.

    Args:
        url (str): The URL of the content to fetch.

    Returns:
        str: The path to the temporary file where the content is saved.

    Raises:
        HTTPError: If there's any issue fetching the content.
    """

    response = requests.get(url)
    response.raise_for_status()  # Raise an exception for HTTP errors

    # Create a temporary file and write the content
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as temp_file:
        temp_file.write(response.content)
        return temp_file.name


def get_user_input(msg: str, color: str = "blue") -> str:
    """
    Prompt the user for input.
    Args:
        msg: printed prompt
        color: color of the prompt
    Returns:
        user input
    """
    color_str = f"[{color}]{msg} " if color else msg + " "
    print(color_str, end="")
    return input("")


def get_list_from_user(
    prompt: str = "Enter input (type 'done' or hit return to finish)",
    n: int | None = None,
) -> List[str]:
    """
    Prompt the user for inputs.
    Args:
        prompt: printed prompt
        n: how many inputs to prompt for. If None, then prompt until done, otherwise
            quit after n inputs.
    Returns:
        list of input strings
    """
    # Create an empty set to store the URLs.
    input_set = set()

    # Use a while loop to continuously ask the user for URLs.
    for _ in range(n or 1000):
        # Prompt the user for input.
        input_str = Prompt.ask(f"[blue]{prompt}")

        # Check if the user wants to exit the loop.
        if input_str.lower() == "done" or input_str == "":
            break

        # if it is a URL, ask how many to crawl
        if is_url(input_str):
            url = input_str
            input_str = Prompt.ask("[blue] How many new URLs to crawl?", default="0")
            max_urls = int(input_str) + 1
            tot_urls = list(find_urls(url, max_links=max_urls, max_depth=2))
            tot_urls_str = "\n".join(tot_urls)
            print(
                f"""
                Found these {len(tot_urls)} links upto depth 2:
                {tot_urls_str}
                """
            )

            input_set.update(tot_urls)
        else:
            input_set.add(input_str.strip())

    return list(input_set)


class Url(BaseModel):
    url: HttpUrl


def is_url(s: str) -> bool:
    try:
        Url(url=parse_obj_as(HttpUrl, s))
        return True
    except ValidationError:
        return False


def get_urls_paths_bytes_indices(
    inputs: List[str | bytes],
) -> Tuple[List[int], List[int], List[int]]:
    """
    Given a list of inputs, return a
    list of indices of URLs, list of indices of paths, list of indices of byte-contents.
    Args:
        inputs: list of strings or bytes
    Returns:
        list of indices of URLs,
        list of indices of paths,
        list of indices of byte-contents
    """
    urls = []
    paths = []
    byte_list = []
    for i, item in enumerate(inputs):
        if isinstance(item, bytes):
            byte_list.append(i)
            continue
        try:
            Url(url=parse_obj_as(HttpUrl, item))
            urls.append(i)
        except ValidationError:
            if os.path.exists(item):
                paths.append(i)
            else:
                logger.warning(f"{item} is neither a URL nor a path.")
    return urls, paths, byte_list


def crawl_url(url: str, max_urls: int = 1) -> List[str]:
    """
    Crawl starting at the url and return a list of URLs to be parsed,
    up to a maximum of `max_urls`.
    This has not been tested to work as intended. Ignore.
    """
    if max_urls == 1:
        # no need to crawl, just return the original list
        return [url]

    to_visit = None
    known_urls = None

    # Create a RobotFileParser object
    robots = urllib.robotparser.RobotFileParser()
    while True:
        if known_urls is not None and len(known_urls) >= max_urls:
            break
        # Set the RobotFileParser object to the website's robots.txt file
        robots.set_url(url + "/robots.txt")
        robots.read()

        if robots.can_fetch("*", url):
            # Start or resume the crawl
            to_visit, known_urls = focused_crawler(
                url,
                max_seen_urls=max_urls,
                max_known_urls=max_urls,
                todo=to_visit,
                known_links=known_urls,
                rules=robots,
            )
        if to_visit is None:
            break

    if known_urls is None:
        return [url]
    final_urls = [s.strip() for s in known_urls]
    return list(final_urls)[:max_urls]


def find_urls(
    url: str = "https://en.wikipedia.org/wiki/Generative_pre-trained_transformer",
    max_links: int = 20,
    visited: Optional[Set[str]] = None,
    depth: int = 0,
    max_depth: int = 2,
    match_domain: bool = True,
) -> Set[str]:
    """
    Recursively find all URLs on a given page.

    Args:
        url (str): The URL to start from.
        max_links (int): The maximum number of links to find.
        visited (set): A set of URLs that have already been visited.
        depth (int): The current depth of the recursion.
        max_depth (int): The maximum depth of the recursion.
        match_domain (bool): Whether to only return URLs that are on the same domain.

    Returns:
        set: A set of URLs found on the page.
    """

    if visited is None:
        visited = set()

    if url in visited or depth > max_depth:
        return visited

    visited.add(url)
    base_domain = urlparse(url).netloc

    try:
        response = requests.get(url, timeout=5)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        links = [urljoin(url, a["href"]) for a in soup.find_all("a", href=True)]

        # Defrag links: discard links that are to portions of same page
        defragged_links = list(set(urldefrag(link).url for link in links))

        # Filter links based on domain matching requirement
        domain_matching_links = [
            link for link in defragged_links if urlparse(link).netloc == base_domain
        ]

        # ensure url is first, since below we are taking first max_links urls
        domain_matching_links = [url] + [x for x in domain_matching_links if x != url]

        # If found links exceed max_links, return immediately
        if len(domain_matching_links) >= max_links:
            return set(domain_matching_links[:max_links])

        for link in domain_matching_links:
            if len(visited) >= max_links:
                break

            if link not in visited:
                visited.update(
                    find_urls(
                        link,
                        max_links,
                        visited,
                        depth + 1,
                        max_depth,
                        match_domain,
                    )
                )

    except (requests.RequestException, Exception) as e:
        print(f"Error fetching {url}. Error: {e}")

    return set(list(visited)[:max_links])


def org_user_from_github(url: str) -> str:
    parsed = urllib.parse.urlparse(url)
    org, user = parsed.path.lstrip("/").split("/")
    return f"{org}-{user}"


if __name__ == "__main__":
    # Example usage
    found_urls = set(fire.Fire(find_urls))
    for url in found_urls:
        print(url)
langroid/parsing/utils.py
ADDED
@@ -0,0 +1,373 @@
import difflib
import logging
import random
import re
from functools import cache
from itertools import islice
from typing import Iterable, List, Sequence, TypeVar

import nltk
from faker import Faker

from langroid.mytypes import Document
from langroid.parsing.document_parser import DocumentType
from langroid.parsing.parser import Parser, ParsingConfig
from langroid.parsing.repo_loader import RepoLoader
from langroid.parsing.url_loader import URLLoader
from langroid.parsing.urls import get_urls_paths_bytes_indices

Faker.seed(23)
random.seed(43)

logger = logging.getLogger(__name__)


# Ensures the NLTK resource is available
@cache
def download_nltk_resource(resource: str) -> None:
    try:
        nltk.data.find(resource)
    except LookupError:
        nltk.download(resource, quiet=True)


# Download punkt_tab resource at module import
download_nltk_resource("punkt_tab")
download_nltk_resource("gutenberg")

T = TypeVar("T")


def batched(iterable: Iterable[T], n: int) -> Iterable[Sequence[T]]:
    """Batch data into tuples of length n. The last batch may be shorter."""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


def generate_random_sentences(k: int) -> str:
    # Load the sample text

    from nltk.corpus import gutenberg

    text = gutenberg.raw("austen-emma.txt")

    # Split the text into sentences
    sentences = nltk.tokenize.sent_tokenize(text)

    # Generate k random sentences
    random_sentences = random.choices(sentences, k=k)
    return " ".join(random_sentences)


def generate_random_text(num_sentences: int) -> str:
    fake = Faker()
    text = ""
    for _ in range(num_sentences):
        text += fake.sentence() + " "
    return text


def closest_string(query: str, string_list: List[str]) -> str:
    """Find the closest match to the query in a list of strings.

    This function is case-insensitive and ignores leading and trailing whitespace.
    If no match is found, it returns 'No match found'.

    Args:
        query (str): The string to match.
        string_list (List[str]): The list of strings to search.

    Returns:
        str: The closest match to the query from the list, or 'No match found'
            if no match is found.
    """
    # Create a dictionary where the keys are the standardized strings and
    # the values are the original strings.
    str_dict = {s.lower().strip(): s for s in string_list}

    # Standardize the query and find the closest match in the list of keys.
    closest_match = difflib.get_close_matches(
        query.lower().strip(), str_dict.keys(), n=1
    )

    # Retrieve the original string from the value in the dictionary.
    original_closest_match = (
        str_dict[closest_match[0]] if closest_match else "No match found"
    )

    return original_closest_match


def split_paragraphs(text: str) -> List[str]:
    """
    Split the input text into paragraphs using "\n\n" as the delimiter.

    Args:
        text (str): The input text.

    Returns:
        list: A list of paragraphs.
    """
    # Split based on a newline, followed by spaces/tabs, then another newline.
    paras = re.split(r"\n[ \t]*\n", text)
    return [para.strip() for para in paras if para.strip()]


def split_newlines(text: str) -> List[str]:
    """
    Split the input text into lines using "\n" as the delimiter.

    Args:
        text (str): The input text.

    Returns:
        list: A list of lines.
    """
    lines = re.split(r"\n", text)
    return [line.strip() for line in lines if line.strip()]


def number_segments(s: str, granularity: int = 1) -> str:
    """
    Number the segments in a given text, preserving paragraph structure.
    A segment is a sequence of `len` consecutive "sentences", where a "sentence"
    is either a normal sentence, or if there isn't enough punctuation to properly
    identify sentences, then we use a pseudo-sentence via heuristics (split by newline
    or failing that, just split every 40 words). The goal here is simply to number
    segments at a reasonable granularity so the LLM can identify relevant segments,
    in the RelevanceExtractorAgent.

    Args:
        s (str): The input text.
        granularity (int): The number of sentences in a segment.
            If this is -1, then the entire text is treated as a single segment,
            and is numbered as <#1#>.

    Returns:
        str: The text with segments numbered in the style <#1#>, <#2#> etc.

    Example:
        >>> number_segments("Hello world! How are you? Have a good day.")
        '<#1#> Hello world! <#2#> How are you? <#3#> Have a good day.'
    """
    if granularity < 0:
        return "<#1#> " + s
    numbered_text = []
    count = 0

    paragraphs = split_paragraphs(s)
    for paragraph in paragraphs:
        sentences = nltk.sent_tokenize(paragraph)
        # Some docs are problematic (e.g. resumes) and have no (or too few) periods,
        # so we can't split usefully into sentences.
        # We try a series of heuristics to split into sentences,
        # until the avg num words per sentence is less than 40.
        avg_words_per_sentence = sum(
            len(nltk.word_tokenize(sentence)) for sentence in sentences
        ) / len(sentences)
        if avg_words_per_sentence > 40:
            sentences = split_newlines(paragraph)
            avg_words_per_sentence = sum(
                len(nltk.word_tokenize(sentence)) for sentence in sentences
            ) / len(sentences)
            if avg_words_per_sentence > 40:
                # Still too long, just split on every 40 words
                sentences = []
                for sentence in nltk.sent_tokenize(paragraph):
                    words = nltk.word_tokenize(sentence)
                    for i in range(0, len(words), 40):
                        # if there are less than 20 words left after this,
                        # just add them to the last sentence and break
                        if len(words) - i < 20:
                            sentences.append(" ".join(words[i:]))
                            break
                        else:
                            sentences.append(" ".join(words[i : i + 40]))
        for i, sentence in enumerate(sentences):
            num = count // granularity + 1
            number_prefix = f"<#{num}#>" if count % granularity == 0 else ""
            sentence = f"{number_prefix} {sentence}"
            count += 1
            sentences[i] = sentence
        numbered_paragraph = " ".join(sentences)
        numbered_text.append(numbered_paragraph)

    return " \n\n ".join(numbered_text)


def number_sentences(s: str) -> str:
    return number_segments(s, granularity=1)


def parse_number_range_list(specs: str) -> List[int]:
    """
    Parse a specs string like "3,5,7-10" into a list of integers.

    Args:
        specs (str): A string containing segment numbers and/or ranges
            (e.g., "3,5,7-10").

    Returns:
        List[int]: List of segment numbers.

    Example:
        >>> parse_number_range_list("3,5,7-10")
        [3, 5, 7, 8, 9, 10]
    """
    spec_indices = set()  # type: ignore
    for part in specs.split(","):
        # some weak LLMs may generate <#1#> instead of 1, so extract just the digits
        # or the "-"
        part = "".join(char for char in part if char.isdigit() or char == "-")
        if "-" in part:
            start, end = map(int, part.split("-"))
            spec_indices.update(range(start, end + 1))
        else:
            spec_indices.add(int(part))

    return sorted(list(spec_indices))


def strip_k(s: str, k: int = 2) -> str:
    """
    Strip any leading and trailing whitespaces from the input text beyond length k.
    This is useful for removing leading/trailing whitespaces from a text while
    preserving paragraph structure.

    Args:
        s (str): The input text.
        k (int): The number of leading and trailing whitespaces to retain.

    Returns:
        str: The text with leading and trailing whitespaces removed beyond length k.
    """

    # Count leading and trailing whitespaces
    leading_count = len(s) - len(s.lstrip())
    trailing_count = len(s) - len(s.rstrip())

    # Determine how many whitespaces to retain
    leading_keep = min(leading_count, k)
    trailing_keep = min(trailing_count, k)

    # Use slicing to get the desired output
    return s[leading_count - leading_keep : len(s) - (trailing_count - trailing_keep)]


def clean_whitespace(text: str) -> str:
    """Remove extra whitespace from the input text, while preserving
    paragraph structure.
    """
    paragraphs = split_paragraphs(text)
    cleaned_paragraphs = [" ".join(p.split()) for p in paragraphs if p]
    return "\n\n".join(cleaned_paragraphs)  # Join the cleaned paragraphs.


def extract_numbered_segments(s: str, specs: str) -> str:
    """
    Extract specified segments from a numbered text, preserving paragraph structure.

    Args:
        s (str): The input text containing numbered segments.
        specs (str): A string containing segment numbers and/or ranges
            (e.g., "3,5,7-10").

    Returns:
        str: Extracted segments, keeping original paragraph structures.

    Example:
        >>> text = "(1) Hello world! (2) How are you? (3) Have a good day."
        >>> extract_numbered_segments(text, "1,3")
        'Hello world! Have a good day.'
    """
    # Use the helper function to get the list of indices from specs
    if specs.strip() == "":
        return ""
    spec_indices = parse_number_range_list(specs)

    # Regular expression to identify numbered segments like
    # <#1#> Hello world! This is me. <#2#> How are you? <#3#> Have a good day.
    # Note we match any character between segment markers, including newlines.
    segment_pattern = re.compile(r"<#(\d+)#>([\s\S]*?)(?=<#\d+#>|$)")

    # Split the text into paragraphs while preserving their boundaries
    paragraphs = split_paragraphs(s)

    extracted_paragraphs = []

    for paragraph in paragraphs:
        segments_with_numbers = segment_pattern.findall(paragraph)

        # Extract the desired segments from this paragraph
        extracted_segments = [
            segment
            for num, segment in segments_with_numbers
            if int(num) in spec_indices
        ]

        # If we extracted any segments from this paragraph,
        # join them and append to results
        if extracted_segments:
            extracted_paragraphs.append(" ".join(extracted_segments))

    return "\n\n".join(extracted_paragraphs)


def extract_content_from_path(
    path: bytes | str | List[bytes | str],
    parsing: ParsingConfig,
    doc_type: str | DocumentType | None = None,
) -> str | List[str]:
    """
    Extract the content from a file path or URL, or a list of file paths or URLs.

    Args:
        path (bytes | str | List[str]): The file path or URL, or a list of file paths or
            URLs, or bytes content. The bytes option is meant to support cases
            where upstream code may have already loaded the content (e.g., from a
            database or API) and we want to avoid having to copy the content to a
            temporary file.
        parsing (ParsingConfig): The parsing configuration.
        doc_type (str | DocumentType | None): The document type if known.
            If multiple paths are given, this MUST apply to ALL docs.

    Returns:
        str | List[str]: The extracted content if a single file path or URL is provided,
            or a list of extracted contents if a
            list of file paths or URLs is provided.
    """
    if isinstance(path, str) or isinstance(path, bytes):
        paths = [path]
    elif isinstance(path, list) and len(path) == 0:
        return ""
    else:
        paths = path

    url_idxs, path_idxs, byte_idxs = get_urls_paths_bytes_indices(paths)
    urls = [paths[i] for i in url_idxs]
    path_list = [paths[i] for i in path_idxs]
    byte_list = [paths[i] for i in byte_idxs]
    path_list.extend(byte_list)
    parser = Parser(parsing)
    docs: List[Document] = []
    try:
        if len(urls) > 0:
            loader = URLLoader(urls=urls, parser=parser)  # type: ignore
            docs = loader.load()
        if len(path_list) > 0:
            for p in path_list:
                path_docs = RepoLoader.get_documents(
                    p, parser=parser, doc_type=doc_type
                )
                docs.extend(path_docs)
    except Exception as e:
        logger.warning(f"Error loading path {paths}: {e}")
        return ""
    if len(docs) == 1:
        return docs[0].content
    else:
        return [d.content for d in docs]