langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/__init__.py +95 -0
- langroid/agent/__init__.py +40 -0
- langroid/agent/base.py +222 -91
- langroid/agent/batch.py +264 -0
- langroid/agent/callbacks/chainlit.py +608 -0
- langroid/agent/chat_agent.py +247 -101
- langroid/agent/chat_document.py +41 -4
- langroid/agent/openai_assistant.py +842 -0
- langroid/agent/special/__init__.py +50 -0
- langroid/agent/special/doc_chat_agent.py +837 -141
- langroid/agent/special/lance_doc_chat_agent.py +258 -0
- langroid/agent/special/lance_rag/__init__.py +9 -0
- langroid/agent/special/lance_rag/critic_agent.py +136 -0
- langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
- langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
- langroid/agent/special/lance_tools.py +44 -0
- langroid/agent/special/neo4j/__init__.py +0 -0
- langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
- langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
- langroid/agent/special/neo4j/utils/__init__.py +0 -0
- langroid/agent/special/neo4j/utils/system_message.py +46 -0
- langroid/agent/special/relevance_extractor_agent.py +127 -0
- langroid/agent/special/retriever_agent.py +32 -198
- langroid/agent/special/sql/__init__.py +11 -0
- langroid/agent/special/sql/sql_chat_agent.py +47 -23
- langroid/agent/special/sql/utils/__init__.py +22 -0
- langroid/agent/special/sql/utils/description_extractors.py +95 -46
- langroid/agent/special/sql/utils/populate_metadata.py +28 -21
- langroid/agent/special/table_chat_agent.py +43 -9
- langroid/agent/task.py +475 -122
- langroid/agent/tool_message.py +75 -13
- langroid/agent/tools/__init__.py +13 -0
- langroid/agent/tools/duckduckgo_search_tool.py +66 -0
- langroid/agent/tools/google_search_tool.py +11 -0
- langroid/agent/tools/metaphor_search_tool.py +67 -0
- langroid/agent/tools/recipient_tool.py +16 -29
- langroid/agent/tools/run_python_code.py +60 -0
- langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
- langroid/agent/tools/segment_extract_tool.py +36 -0
- langroid/cachedb/__init__.py +9 -0
- langroid/cachedb/base.py +22 -2
- langroid/cachedb/momento_cachedb.py +26 -2
- langroid/cachedb/redis_cachedb.py +78 -11
- langroid/embedding_models/__init__.py +34 -0
- langroid/embedding_models/base.py +21 -2
- langroid/embedding_models/models.py +120 -18
- langroid/embedding_models/protoc/embeddings.proto +19 -0
- langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
- langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
- langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
- langroid/embedding_models/remote_embeds.py +153 -0
- langroid/language_models/__init__.py +45 -0
- langroid/language_models/azure_openai.py +80 -27
- langroid/language_models/base.py +117 -12
- langroid/language_models/config.py +5 -0
- langroid/language_models/openai_assistants.py +3 -0
- langroid/language_models/openai_gpt.py +558 -174
- langroid/language_models/prompt_formatter/__init__.py +15 -0
- langroid/language_models/prompt_formatter/base.py +4 -6
- langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
- langroid/language_models/utils.py +18 -21
- langroid/mytypes.py +25 -8
- langroid/parsing/__init__.py +46 -0
- langroid/parsing/document_parser.py +260 -63
- langroid/parsing/image_text.py +32 -0
- langroid/parsing/parse_json.py +143 -0
- langroid/parsing/parser.py +122 -59
- langroid/parsing/repo_loader.py +114 -52
- langroid/parsing/search.py +68 -63
- langroid/parsing/spider.py +3 -2
- langroid/parsing/table_loader.py +44 -0
- langroid/parsing/url_loader.py +59 -11
- langroid/parsing/urls.py +85 -37
- langroid/parsing/utils.py +298 -4
- langroid/parsing/web_search.py +73 -0
- langroid/prompts/__init__.py +11 -0
- langroid/prompts/chat-gpt4-system-prompt.md +68 -0
- langroid/prompts/prompts_config.py +1 -1
- langroid/utils/__init__.py +17 -0
- langroid/utils/algorithms/__init__.py +3 -0
- langroid/utils/algorithms/graph.py +103 -0
- langroid/utils/configuration.py +36 -5
- langroid/utils/constants.py +4 -0
- langroid/utils/globals.py +2 -2
- langroid/utils/logging.py +2 -5
- langroid/utils/output/__init__.py +21 -0
- langroid/utils/output/printing.py +47 -1
- langroid/utils/output/status.py +33 -0
- langroid/utils/pandas_utils.py +30 -0
- langroid/utils/pydantic_utils.py +616 -2
- langroid/utils/system.py +98 -0
- langroid/vector_store/__init__.py +40 -0
- langroid/vector_store/base.py +203 -6
- langroid/vector_store/chromadb.py +59 -32
- langroid/vector_store/lancedb.py +463 -0
- langroid/vector_store/meilisearch.py +10 -7
- langroid/vector_store/momento.py +262 -0
- langroid/vector_store/qdrantdb.py +104 -22
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
- langroid-0.1.219.dist-info/RECORD +127 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
- langroid/agent/special/recipient_validator_agent.py +0 -157
- langroid/parsing/json.py +0 -64
- langroid/utils/web/selenium_login.py +0 -36
- langroid-0.1.85.dist-info/RECORD +0 -94
- /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
- {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/parsing/search.py
CHANGED
@@ -7,10 +7,8 @@ See tests for examples: tests/main/test_string_search.py
 """
 
 import difflib
-import re
 from typing import List, Tuple
 
-import nltk
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 from nltk.tokenize import RegexpTokenizer
@@ -19,10 +17,13 @@ from thefuzz import fuzz, process
 
 from langroid.mytypes import Document
 
+from .utils import download_nltk_resource
+
 
 def find_fuzzy_matches_in_docs(
     query: str,
     docs: List[Document],
+    docs_clean: List[Document],
     k: int,
     words_before: int | None = None,
     words_after: int | None = None,
@@ -48,58 +49,53 @@ def find_fuzzy_matches_in_docs(
         return []
     best_matches = process.extract(
         query,
-        [d.content for d in docs],
+        [d.content for d in docs_clean],
         limit=k,
         scorer=fuzz.partial_ratio,
     )
 
     real_matches = [m for m, score in best_matches if score > 50]
-                words_in_text = doc.content.split()
-                first_word_idx = next(
-                    (
-                        i
-                        for i, word in enumerate(words_in_text)
-                        if word.startswith(words[0])
-                    ),
-                    -1,
-                )
-                if words_before is None:
-                    words_before = len(words_in_text)
-                if words_after is None:
-                    words_after = len(words_in_text)
-                if first_word_idx != -1:
-                    start_idx = max(0, first_word_idx - words_before)
-                    end_idx = min(
-                        len(words_in_text),
-                        first_word_idx + len(words) + words_after,
-                    )
-                    doc_match = Document(
-                        content=" ".join(words_in_text[start_idx:end_idx]),
-                        metadata=doc.metadata,
-                    )
-                    results.append(doc_match)
+    # find the original docs that corresponding to the matches
+    orig_doc_matches = []
+    for i, m in enumerate(real_matches):
+        for j, doc_clean in enumerate(docs_clean):
+            if m in doc_clean.content:
+                orig_doc_matches.append(docs[j])
                 break
+    if words_after is None and words_before is None:
+        return orig_doc_matches
+    if len(orig_doc_matches) == 0:
+        return []
+    if set(orig_doc_matches[0].__fields__) != {"content", "metadata"}:
+        # If there are fields beyond just content and metadata,
+        # we do NOT want to create new document objects with content fields
+        # based on words_before and words_after, since we don't know how to
+        # set those other fields.
+        return orig_doc_matches
+
+    contextual_matches = []
+    for match in orig_doc_matches:
+        choice_text = match.content
+        contexts = []
+        while choice_text != "":
+            context, start_pos, end_pos = get_context(
+                query, choice_text, words_before, words_after
+            )
+            if context == "" or end_pos == 0:
+                break
+            contexts.append(context)
+            words = choice_text.split()
+            end_pos = min(end_pos, len(words))
+            choice_text = " ".join(words[end_pos:])
+        if len(contexts) > 0:
+            contextual_matches.append(
+                Document(
+                    content=" ... ".join(contexts),
+                    metadata=match.metadata,
+                )
+            )
 
-    return results
-
-
-# Ensure NLTK resources are available
-def download_nltk_resources() -> None:
-    resources = ["punkt", "wordnet", "stopwords"]
-    for resource in resources:
-        try:
-            nltk.data.find(resource)
-        except LookupError:
-            nltk.download(resource)
-
-
-download_nltk_resources()
+    return contextual_matches
 
 
 def preprocess_text(text: str) -> str:
@@ -117,6 +113,10 @@ def preprocess_text(text: str) -> str:
     Returns:
         str: The preprocessed text.
     """
+    # Ensure the NLTK resources are available
+    for resource in ["punkt", "wordnet", "stopwords"]:
+        download_nltk_resource(resource)
+
     # Lowercase the text
     text = text.lower()
 
@@ -179,7 +179,7 @@ def get_context(
     text: str,
     words_before: int | None = 100,
     words_after: int | None = 100,
-) -> str:
+) -> Tuple[str, int, int]:
     """
     Returns a portion of text containing the best approximate match of the query,
     including b words before and a words after the match.
@@ -193,7 +193,9 @@ def get_context(
     Returns:
         str: A string containing b words before, the match, and a words after
             the best approximate match position of the query in the text. If no
-            match is found, returns
+            match is found, returns empty string.
+        int: The start position of the match in the text.
+        int: The end position of the match in the text.
 
     Example:
     >>> get_context("apple", "The quick brown fox jumps over the apple.", 3, 2)
@@ -201,26 +203,29 @@ def get_context(
     """
     if words_after is None and words_before is None:
         # return entire text since we're not asked to return a bounded context
-        return text
+        return text, 0, 0
+
+    # make sure there is a good enough match to the query
+    if fuzz.partial_ratio(query, text) < 40:
+        return "", 0, 0
 
     sequence_matcher = difflib.SequenceMatcher(None, text, query)
     match = sequence_matcher.find_longest_match(0, len(text), 0, len(query))
 
     if match.size == 0:
-        return ""
+        return "", 0, 0
+
+    segments = text.split()
+    n_segs = len(segments)
+
+    start_segment_pos = len(text[: match.a].split())
+
+    words_before = words_before or n_segs
+    words_after = words_after or n_segs
+    start_pos = max(0, start_segment_pos - words_before)
+    end_pos = min(len(segments), start_segment_pos + words_after + len(query.split()))
 
-    return " ".join(
+    return " ".join(segments[start_pos:end_pos]), start_pos, end_pos
 
 
 def eliminate_near_duplicates(passages: List[str], threshold: float = 0.8) -> List[str]:
langroid/parsing/spider.py
CHANGED
@@ -4,6 +4,7 @@ from urllib.parse import urlparse
 from pydispatch import dispatcher
 from scrapy import signals
 from scrapy.crawler import CrawlerRunner
+from scrapy.http import Response
 from scrapy.linkextractors import LinkExtractor
 from scrapy.spiders import CrawlSpider, Rule
 from twisted.internet import defer, reactor
@@ -30,7 +31,7 @@ class DomainSpecificSpider(CrawlSpider): # type: ignore
         self.k = k
         self.visited_urls: Set[str] = set()
 
-    def parse_item(self, response): # type: ignore
+    def parse_item(self, response: Response): # type: ignore
         """Extracts URLs that are within the same domain.
 
         Args:
@@ -57,7 +58,7 @@ def scrapy_fetch_urls(url: str, k: int = 20) -> List[str]:
     """
     urls = []
 
-    def _collect_urls(spider
+    def _collect_urls(spider):
         """Handler for the spider_closed signal. Collects the visited URLs."""
         nonlocal urls
         urls.extend(list(spider.visited_urls))
langroid/parsing/table_loader.py
CHANGED
@@ -1,4 +1,5 @@
 from csv import Sniffer
+from typing import List
 
 import pandas as pd
 
@@ -48,3 +49,46 @@ def read_tabular_data(path_or_url: str, sep: None | str = None) -> pd.DataFrame:
             "Unable to read data. "
             "Please ensure it is correctly formatted. Error: " + str(e)
         )
+
+
+def describe_dataframe(
+    df: pd.DataFrame, filter_fields: List[str] = [], n_vals: int = 10
+) -> str:
+    """
+    Generates a description of the columns in the dataframe,
+    along with a listing of up to `n_vals` unique values for each column.
+    Intended to be used to insert into an LLM context so it can generate
+    appropriate queries or filters on the df.
+
+    Args:
+        df (pd.DataFrame): The dataframe to describe.
+        filter_fields (list): A list of fields that can be used for filtering.
+            When non-empty, the values-list will be restricted to these.
+        n_vals (int): How many unique values to show for each column.
+
+    Returns:
+        str: A description of the dataframe.
+    """
+    description = []
+    for column in df.columns.to_list():
+        unique_values = df[column].dropna().unique()
+        unique_count = len(unique_values)
+        if column not in filter_fields:
+            values_desc = f"{unique_count} unique values"
+        else:
+            if unique_count > n_vals:
+                displayed_values = unique_values[:n_vals]
+                more_count = unique_count - n_vals
+                values_desc = f" Values - {displayed_values}, ... {more_count} more"
+            else:
+                values_desc = f" Values - {unique_values}"
+        col_type = "string" if df[column].dtype == "object" else df[column].dtype
+        col_desc = f"* {column} ({col_type}); {values_desc}"
+        description.append(col_desc)
+
+    all_cols = "\n".join(description)
+
+    return f"""
+        Name of each field, its type and unique values (up to {n_vals}):
+        {all_cols}
+        """
langroid/parsing/url_loader.py
CHANGED
@@ -1,6 +1,9 @@
 import logging
+import os
+from tempfile import NamedTemporaryFile
 from typing import List, no_type_check
 
+import requests
 import trafilatura
 from trafilatura.downloads import (
     add_to_compressed_dict,
@@ -9,7 +12,7 @@ from trafilatura.downloads import (
 )
 
 from langroid.mytypes import DocMetaData, Document
-from langroid.parsing.document_parser import DocumentParser
+from langroid.parsing.document_parser import DocumentParser, ImagePdfParser
 from langroid.parsing.parser import Parser, ParsingConfig
 
 logging.getLogger("trafilatura").setLevel(logging.ERROR)
@@ -44,20 +47,65 @@ class URLLoader:
             sleep_time=5,
         )
         for url, result in buffered_downloads(buffer, threads):
+            if (
+                url.lower().endswith(".pdf")
+                or url.lower().endswith(".docx")
+                or url.lower().endswith(".doc")
+            ):
                 doc_parser = DocumentParser.create(
                     url,
                     self.parser.config,
                 )
+                new_chunks = doc_parser.get_doc_chunks()
+                if len(new_chunks) == 0:
+                    # If the document is empty, try to extract images
+                    img_parser = ImagePdfParser(url, self.parser.config)
+                    new_chunks = img_parser.get_doc_chunks()
+                docs.extend(new_chunks)
             else:
+                # Try to detect content type and handle accordingly
+                headers = requests.head(url).headers
+                content_type = headers.get("Content-Type", "").lower()
+                temp_file_suffix = None
+                if "application/pdf" in content_type:
+                    temp_file_suffix = ".pdf"
+                elif (
+                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+                    in content_type
+                ):
+                    temp_file_suffix = ".docx"
+                elif "application/msword" in content_type:
+                    temp_file_suffix = ".doc"
+
+                if temp_file_suffix:
+                    # Download the document content
+                    response = requests.get(url)
+                    with NamedTemporaryFile(
+                        delete=False, suffix=temp_file_suffix
+                    ) as temp_file:
+                        temp_file.write(response.content)
+                        temp_file_path = temp_file.name
+                    # Process the downloaded document
+                    doc_parser = DocumentParser.create(
+                        temp_file_path, self.parser.config
+                    )
+                    docs.extend(doc_parser.get_doc_chunks())
+                    # Clean up the temporary file
+                    os.remove(temp_file_path)
+                else:
+                    text = trafilatura.extract(
+                        result,
+                        no_fallback=False,
+                        favor_recall=True,
                     )
+                    if (
+                        text is None
+                        and result is not None
+                        and isinstance(result, str)
+                    ):
+                        text = result
+                    if text is not None and text != "":
+                        docs.append(
+                            Document(content=text, metadata=DocMetaData(source=url))
+                        )
         return docs
langroid/parsing/urls.py
CHANGED
@@ -4,7 +4,7 @@ import tempfile
 import urllib.parse
 import urllib.robotparser
 from typing import List, Optional, Set, Tuple
-from urllib.parse import urljoin
+from urllib.parse import urldefrag, urljoin, urlparse
 
 import fire
 import requests
@@ -14,8 +14,6 @@ from rich import print
 from rich.prompt import Prompt
 from trafilatura.spider import focused_crawler
 
-from langroid.parsing.spider import scrapy_fetch_urls
-
 logger = logging.getLogger(__name__)
 
 
@@ -86,7 +84,15 @@ def get_list_from_user(
             url = input_str
             input_str = Prompt.ask("[blue] How many new URLs to crawl?", default="0")
             max_urls = int(input_str) + 1
-            tot_urls =
+            tot_urls = list(find_urls(url, max_links=max_urls, max_depth=2))
+            tot_urls_str = "\n".join(tot_urls)
+            print(
+                f"""
+                Found these {len(tot_urls)} links upto depth 2:
+                {tot_urls_str}
+                """
+            )
+
             input_set.update(tot_urls)
         else:
             input_set.add(input_str.strip())
@@ -106,32 +112,42 @@ def is_url(s: str) -> bool:
         return False
 
 
+def get_urls_paths_bytes_indices(
+    inputs: List[str | bytes],
+) -> Tuple[List[int], List[int], List[int]]:
     """
-    Given a list of inputs, return a
+    Given a list of inputs, return a
+    list of indices of URLs, list of indices of paths, list of indices of byte-contents.
     Args:
-        inputs: list of strings
+        inputs: list of strings or bytes
     Returns:
+        list of Indices of URLs,
+        list of indices of paths,
+        list of indices of byte-contents
     """
     urls = []
     paths = []
+    byte_list = []
+    for i, item in enumerate(inputs):
+        if isinstance(item, bytes):
+            byte_list.append(i)
+            continue
         try:
-            urls.append(
+            Url(url=parse_obj_as(HttpUrl, item))
+            urls.append(i)
         except ValidationError:
             if os.path.exists(item):
-                paths.append(
+                paths.append(i)
             else:
                 logger.warning(f"{item} is neither a URL nor a path.")
-    return urls, paths
+    return urls, paths, byte_list
 
 
 def crawl_url(url: str, max_urls: int = 1) -> List[str]:
     """
     Crawl starting at the url and return a list of URLs to be parsed,
     up to a maximum of `max_urls`.
+    This has not been tested to work as intended. Ignore.
     """
     if max_urls == 1:
         # no need to crawl, just return the original list
@@ -161,6 +177,7 @@ def crawl_url(url: str, max_urls: int = 1) -> List[str]:
         )
         if to_visit is None:
             break
+
     if known_urls is None:
         return [url]
     final_urls = [s.strip() for s in known_urls]
@@ -169,46 +186,77 @@ def crawl_url(url: str, max_urls: int = 1) -> List[str]:
 
 def find_urls(
     url: str = "https://en.wikipedia.org/wiki/Generative_pre-trained_transformer",
+    max_links: int = 20,
     visited: Optional[Set[str]] = None,
     depth: int = 0,
     max_depth: int = 2,
+    match_domain: bool = True,
 ) -> Set[str]:
     """
     Recursively find all URLs on a given page.
+
     Args:
-        url:
+        url (str): The URL to start from.
+        max_links (int): The maximum number of links to find.
+        visited (set): A set of URLs that have already been visited.
+        depth (int): The current depth of the recursion.
+        max_depth (int): The maximum depth of the recursion.
+        match_domain (bool): Whether to only return URLs that are on the same domain.
 
     Returns:
+        set: A set of URLs found on the page.
     """
+
     if visited is None:
         visited = set()
-    visited.add(url)
 
-        response = requests.get(url)
-        response.raise_for_status()
-    except (
-        requests.exceptions.HTTPError,
-        requests.exceptions.RequestException,
-    ):
-        print(f"Failed to fetch '{url}'")
+    if url in visited or depth > max_depth:
         return visited
 
-    urls = [urljoin(url, link["href"]) for link in links] # Construct full URLs
-    if depth < max_depth:
-        for link_url in urls:
-            if link_url not in visited:
-                find_urls(link_url, visited, depth + 1, max_depth)
+    visited.add(url)
+    base_domain = urlparse(url).netloc
 
+    try:
+        response = requests.get(url, timeout=5)
+        response.raise_for_status()
+        soup = BeautifulSoup(response.text, "html.parser")
+        links = [urljoin(url, a["href"]) for a in soup.find_all("a", href=True)]
+
+        # Defrag links: discard links that are to portions of same page
+        defragged_links = list(set(urldefrag(link).url for link in links))
+
+        # Filter links based on domain matching requirement
+        domain_matching_links = [
+            link for link in defragged_links if urlparse(link).netloc == base_domain
+        ]
+
+        # ensure url is first, since below we are taking first max_links urls
+        domain_matching_links = [url] + [x for x in domain_matching_links if x != url]
+
+        # If found links exceed max_links, return immediately
+        if len(domain_matching_links) >= max_links:
+            return set(domain_matching_links[:max_links])
+
+        for link in domain_matching_links:
+            if len(visited) >= max_links:
+                break
+
+            if link not in visited:
+                visited.update(
+                    find_urls(
+                        link,
+                        max_links,
+                        visited,
+                        depth + 1,
+                        max_depth,
+                        match_domain,
+                    )
+                )
+
+    except (requests.RequestException, Exception) as e:
+        print(f"Error fetching {url}. Error: {e}")
+
+    return set(list(visited)[:max_links])
 
 
 def org_user_from_github(url: str) -> str:
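
With these changes, `find_urls` becomes a bounded, same-domain crawler: it defrags links, filters them to the starting domain, and stops once `max_links` URLs are collected. A usage sketch based on the updated signature shown above (the starting URL is just the function's default argument):

```python
from langroid.parsing.urls import find_urls

# Collect at most 10 same-domain links, following pages up to 2 levels deep.
links = find_urls(
    "https://en.wikipedia.org/wiki/Generative_pre-trained_transformer",
    max_links=10,
    max_depth=2,
)
for link in sorted(links):
    print(link)
```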
|