langroid 0.1.85__py3-none-any.whl → 0.1.219__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. langroid/__init__.py +95 -0
  2. langroid/agent/__init__.py +40 -0
  3. langroid/agent/base.py +222 -91
  4. langroid/agent/batch.py +264 -0
  5. langroid/agent/callbacks/chainlit.py +608 -0
  6. langroid/agent/chat_agent.py +247 -101
  7. langroid/agent/chat_document.py +41 -4
  8. langroid/agent/openai_assistant.py +842 -0
  9. langroid/agent/special/__init__.py +50 -0
  10. langroid/agent/special/doc_chat_agent.py +837 -141
  11. langroid/agent/special/lance_doc_chat_agent.py +258 -0
  12. langroid/agent/special/lance_rag/__init__.py +9 -0
  13. langroid/agent/special/lance_rag/critic_agent.py +136 -0
  14. langroid/agent/special/lance_rag/lance_rag_task.py +80 -0
  15. langroid/agent/special/lance_rag/query_planner_agent.py +180 -0
  16. langroid/agent/special/lance_tools.py +44 -0
  17. langroid/agent/special/neo4j/__init__.py +0 -0
  18. langroid/agent/special/neo4j/csv_kg_chat.py +174 -0
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py +370 -0
  20. langroid/agent/special/neo4j/utils/__init__.py +0 -0
  21. langroid/agent/special/neo4j/utils/system_message.py +46 -0
  22. langroid/agent/special/relevance_extractor_agent.py +127 -0
  23. langroid/agent/special/retriever_agent.py +32 -198
  24. langroid/agent/special/sql/__init__.py +11 -0
  25. langroid/agent/special/sql/sql_chat_agent.py +47 -23
  26. langroid/agent/special/sql/utils/__init__.py +22 -0
  27. langroid/agent/special/sql/utils/description_extractors.py +95 -46
  28. langroid/agent/special/sql/utils/populate_metadata.py +28 -21
  29. langroid/agent/special/table_chat_agent.py +43 -9
  30. langroid/agent/task.py +475 -122
  31. langroid/agent/tool_message.py +75 -13
  32. langroid/agent/tools/__init__.py +13 -0
  33. langroid/agent/tools/duckduckgo_search_tool.py +66 -0
  34. langroid/agent/tools/google_search_tool.py +11 -0
  35. langroid/agent/tools/metaphor_search_tool.py +67 -0
  36. langroid/agent/tools/recipient_tool.py +16 -29
  37. langroid/agent/tools/run_python_code.py +60 -0
  38. langroid/agent/tools/sciphi_search_rag_tool.py +79 -0
  39. langroid/agent/tools/segment_extract_tool.py +36 -0
  40. langroid/cachedb/__init__.py +9 -0
  41. langroid/cachedb/base.py +22 -2
  42. langroid/cachedb/momento_cachedb.py +26 -2
  43. langroid/cachedb/redis_cachedb.py +78 -11
  44. langroid/embedding_models/__init__.py +34 -0
  45. langroid/embedding_models/base.py +21 -2
  46. langroid/embedding_models/models.py +120 -18
  47. langroid/embedding_models/protoc/embeddings.proto +19 -0
  48. langroid/embedding_models/protoc/embeddings_pb2.py +33 -0
  49. langroid/embedding_models/protoc/embeddings_pb2.pyi +50 -0
  50. langroid/embedding_models/protoc/embeddings_pb2_grpc.py +79 -0
  51. langroid/embedding_models/remote_embeds.py +153 -0
  52. langroid/language_models/__init__.py +45 -0
  53. langroid/language_models/azure_openai.py +80 -27
  54. langroid/language_models/base.py +117 -12
  55. langroid/language_models/config.py +5 -0
  56. langroid/language_models/openai_assistants.py +3 -0
  57. langroid/language_models/openai_gpt.py +558 -174
  58. langroid/language_models/prompt_formatter/__init__.py +15 -0
  59. langroid/language_models/prompt_formatter/base.py +4 -6
  60. langroid/language_models/prompt_formatter/hf_formatter.py +135 -0
  61. langroid/language_models/utils.py +18 -21
  62. langroid/mytypes.py +25 -8
  63. langroid/parsing/__init__.py +46 -0
  64. langroid/parsing/document_parser.py +260 -63
  65. langroid/parsing/image_text.py +32 -0
  66. langroid/parsing/parse_json.py +143 -0
  67. langroid/parsing/parser.py +122 -59
  68. langroid/parsing/repo_loader.py +114 -52
  69. langroid/parsing/search.py +68 -63
  70. langroid/parsing/spider.py +3 -2
  71. langroid/parsing/table_loader.py +44 -0
  72. langroid/parsing/url_loader.py +59 -11
  73. langroid/parsing/urls.py +85 -37
  74. langroid/parsing/utils.py +298 -4
  75. langroid/parsing/web_search.py +73 -0
  76. langroid/prompts/__init__.py +11 -0
  77. langroid/prompts/chat-gpt4-system-prompt.md +68 -0
  78. langroid/prompts/prompts_config.py +1 -1
  79. langroid/utils/__init__.py +17 -0
  80. langroid/utils/algorithms/__init__.py +3 -0
  81. langroid/utils/algorithms/graph.py +103 -0
  82. langroid/utils/configuration.py +36 -5
  83. langroid/utils/constants.py +4 -0
  84. langroid/utils/globals.py +2 -2
  85. langroid/utils/logging.py +2 -5
  86. langroid/utils/output/__init__.py +21 -0
  87. langroid/utils/output/printing.py +47 -1
  88. langroid/utils/output/status.py +33 -0
  89. langroid/utils/pandas_utils.py +30 -0
  90. langroid/utils/pydantic_utils.py +616 -2
  91. langroid/utils/system.py +98 -0
  92. langroid/vector_store/__init__.py +40 -0
  93. langroid/vector_store/base.py +203 -6
  94. langroid/vector_store/chromadb.py +59 -32
  95. langroid/vector_store/lancedb.py +463 -0
  96. langroid/vector_store/meilisearch.py +10 -7
  97. langroid/vector_store/momento.py +262 -0
  98. langroid/vector_store/qdrantdb.py +104 -22
  99. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/METADATA +329 -149
  100. langroid-0.1.219.dist-info/RECORD +127 -0
  101. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/WHEEL +1 -1
  102. langroid/agent/special/recipient_validator_agent.py +0 -157
  103. langroid/parsing/json.py +0 -64
  104. langroid/utils/web/selenium_login.py +0 -36
  105. langroid-0.1.85.dist-info/RECORD +0 -94
  106. /langroid/{scripts → agent/callbacks}/__init__.py +0 -0
  107. {langroid-0.1.85.dist-info → langroid-0.1.219.dist-info}/LICENSE +0 -0
langroid/parsing/search.py CHANGED
@@ -7,10 +7,8 @@ See tests for examples: tests/main/test_string_search.py
 """
 
 import difflib
-import re
 from typing import List, Tuple
 
-import nltk
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 from nltk.tokenize import RegexpTokenizer
@@ -19,10 +17,13 @@ from thefuzz import fuzz, process
 
 from langroid.mytypes import Document
 
+from .utils import download_nltk_resource
+
 
 def find_fuzzy_matches_in_docs(
     query: str,
     docs: List[Document],
+    docs_clean: List[Document],
     k: int,
     words_before: int | None = None,
     words_after: int | None = None,
@@ -48,58 +49,53 @@ def find_fuzzy_matches_in_docs(
         return []
     best_matches = process.extract(
         query,
-        [d.content for d in docs],
+        [d.content for d in docs_clean],
         limit=k,
         scorer=fuzz.partial_ratio,
     )
 
     real_matches = [m for m, score in best_matches if score > 50]
-
-    results = []
-    for match in real_matches:
-        words = match.split()
-        for doc in docs:
-            if match in doc.content:
-                words_in_text = doc.content.split()
-                first_word_idx = next(
-                    (
-                        i
-                        for i, word in enumerate(words_in_text)
-                        if word.startswith(words[0])
-                    ),
-                    -1,
-                )
-                if words_before is None:
-                    words_before = len(words_in_text)
-                if words_after is None:
-                    words_after = len(words_in_text)
-                if first_word_idx != -1:
-                    start_idx = max(0, first_word_idx - words_before)
-                    end_idx = min(
-                        len(words_in_text),
-                        first_word_idx + len(words) + words_after,
-                    )
-                    doc_match = Document(
-                        content=" ".join(words_in_text[start_idx:end_idx]),
-                        metadata=doc.metadata,
-                    )
-                    results.append(doc_match)
+    # find the original docs that corresponding to the matches
+    orig_doc_matches = []
+    for i, m in enumerate(real_matches):
+        for j, doc_clean in enumerate(docs_clean):
+            if m in doc_clean.content:
+                orig_doc_matches.append(docs[j])
                 break
+    if words_after is None and words_before is None:
+        return orig_doc_matches
+    if len(orig_doc_matches) == 0:
+        return []
+    if set(orig_doc_matches[0].__fields__) != {"content", "metadata"}:
+        # If there are fields beyond just content and metadata,
+        # we do NOT want to create new document objects with content fields
+        # based on words_before and words_after, since we don't know how to
+        # set those other fields.
+        return orig_doc_matches
+
+    contextual_matches = []
+    for match in orig_doc_matches:
+        choice_text = match.content
+        contexts = []
+        while choice_text != "":
+            context, start_pos, end_pos = get_context(
+                query, choice_text, words_before, words_after
+            )
+            if context == "" or end_pos == 0:
+                break
+            contexts.append(context)
+            words = choice_text.split()
+            end_pos = min(end_pos, len(words))
+            choice_text = " ".join(words[end_pos:])
+        if len(contexts) > 0:
+            contextual_matches.append(
+                Document(
+                    content=" ... ".join(contexts),
+                    metadata=match.metadata,
+                )
+            )
 
-    return results
-
-
-# Ensure NLTK resources are available
-def download_nltk_resources() -> None:
-    resources = ["punkt", "wordnet", "stopwords"]
-    for resource in resources:
-        try:
-            nltk.data.find(resource)
-        except LookupError:
-            nltk.download(resource)
-
-
-download_nltk_resources()
+    return contextual_matches
 
 
 def preprocess_text(text: str) -> str:
@@ -117,6 +113,10 @@ def preprocess_text(text: str) -> str:
     Returns:
         str: The preprocessed text.
     """
+    # Ensure the NLTK resources are available
+    for resource in ["punkt", "wordnet", "stopwords"]:
+        download_nltk_resource(resource)
+
     # Lowercase the text
     text = text.lower()
 
@@ -179,7 +179,7 @@ def get_context(
     text: str,
     words_before: int | None = 100,
    words_after: int | None = 100,
-) -> str:
+) -> Tuple[str, int, int]:
     """
     Returns a portion of text containing the best approximate match of the query,
     including b words before and a words after the match.
@@ -193,7 +193,9 @@
     Returns:
         str: A string containing b words before, the match, and a words after
             the best approximate match position of the query in the text. If no
-            match is found, returns "No match found".
+            match is found, returns empty string.
+        int: The start position of the match in the text.
+        int: The end position of the match in the text.
 
     Example:
     >>> get_context("apple", "The quick brown fox jumps over the apple.", 3, 2)
@@ -201,26 +203,29 @@
     """
     if words_after is None and words_before is None:
         # return entire text since we're not asked to return a bounded context
-        return text
+        return text, 0, 0
+
+    # make sure there is a good enough match to the query
+    if fuzz.partial_ratio(query, text) < 40:
+        return "", 0, 0
 
     sequence_matcher = difflib.SequenceMatcher(None, text, query)
     match = sequence_matcher.find_longest_match(0, len(text), 0, len(query))
 
     if match.size == 0:
-        return "No match found"
-
-    words = re.findall(r"\b\w+\b", text)
-    if words_after is None:
-        words_after = len(words)
-    if words_before is None:
-        words_before = len(words)
-    start_word_pos = len(re.findall(r"\b\w+\b", text[: match.a]))
-    start_pos = max(0, start_word_pos - words_before)
-    end_pos = min(
-        len(words), start_word_pos + words_after + len(re.findall(r"\b\w+\b", query))
-    )
+        return "", 0, 0
+
+    segments = text.split()
+    n_segs = len(segments)
+
+    start_segment_pos = len(text[: match.a].split())
+
+    words_before = words_before or n_segs
+    words_after = words_after or n_segs
+    start_pos = max(0, start_segment_pos - words_before)
+    end_pos = min(len(segments), start_segment_pos + words_after + len(query.split()))
 
-    return " ".join(words[start_pos:end_pos])
+    return " ".join(segments[start_pos:end_pos]), start_pos, end_pos
 
 
 def eliminate_near_duplicates(passages: List[str], threshold: float = 0.8) -> List[str]:
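
For orientation, here is a minimal usage sketch of the reworked get_context API shown in the hunks above. It assumes langroid 0.1.219 is installed; the sample query and text are illustrative only.

from langroid.parsing.search import get_context

text = "The quick brown fox jumps over the lazy dog near the old apple tree."
# get_context now returns (context, start_pos, end_pos) instead of a bare string;
# an empty context (or end_pos == 0) signals that no adequate match was found.
context, start, end = get_context("apple tree", text, words_before=3, words_after=2)
if context == "" or end == 0:
    print("no adequate match")
else:
    print(f"context={context!r}, word positions {start}..{end}")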
langroid/parsing/spider.py CHANGED
@@ -4,6 +4,7 @@ from urllib.parse import urlparse
 from pydispatch import dispatcher
 from scrapy import signals
 from scrapy.crawler import CrawlerRunner
+from scrapy.http import Response
 from scrapy.linkextractors import LinkExtractor
 from scrapy.spiders import CrawlSpider, Rule
 from twisted.internet import defer, reactor
@@ -30,7 +31,7 @@ class DomainSpecificSpider(CrawlSpider):  # type: ignore
         self.k = k
         self.visited_urls: Set[str] = set()
 
-    def parse_item(self, response):  # type: ignore
+    def parse_item(self, response: Response):  # type: ignore
         """Extracts URLs that are within the same domain.
 
         Args:
@@ -57,7 +58,7 @@ def scrapy_fetch_urls(url: str, k: int = 20) -> List[str]:
     """
     urls = []
 
-    def _collect_urls(spider, reason):
+    def _collect_urls(spider):
         """Handler for the spider_closed signal. Collects the visited URLs."""
         nonlocal urls
         urls.extend(list(spider.visited_urls))
langroid/parsing/table_loader.py CHANGED
@@ -1,4 +1,5 @@
 from csv import Sniffer
+from typing import List
 
 import pandas as pd
 
@@ -48,3 +49,46 @@ def read_tabular_data(path_or_url: str, sep: None | str = None) -> pd.DataFrame:
             "Unable to read data. "
             "Please ensure it is correctly formatted. Error: " + str(e)
         )
+
+
+def describe_dataframe(
+    df: pd.DataFrame, filter_fields: List[str] = [], n_vals: int = 10
+) -> str:
+    """
+    Generates a description of the columns in the dataframe,
+    along with a listing of up to `n_vals` unique values for each column.
+    Intended to be used to insert into an LLM context so it can generate
+    appropriate queries or filters on the df.
+
+    Args:
+        df (pd.DataFrame): The dataframe to describe.
+        filter_fields (list): A list of fields that can be used for filtering.
+            When non-empty, the values-list will be restricted to these.
+        n_vals (int): How many unique values to show for each column.
+
+    Returns:
+        str: A description of the dataframe.
+    """
+    description = []
+    for column in df.columns.to_list():
+        unique_values = df[column].dropna().unique()
+        unique_count = len(unique_values)
+        if column not in filter_fields:
+            values_desc = f"{unique_count} unique values"
+        else:
+            if unique_count > n_vals:
+                displayed_values = unique_values[:n_vals]
+                more_count = unique_count - n_vals
+                values_desc = f" Values - {displayed_values}, ... {more_count} more"
+            else:
+                values_desc = f" Values - {unique_values}"
+        col_type = "string" if df[column].dtype == "object" else df[column].dtype
+        col_desc = f"* {column} ({col_type}); {values_desc}"
+        description.append(col_desc)
+
+    all_cols = "\n".join(description)
+
+    return f"""
+        Name of each field, its type and unique values (up to {n_vals}):
+        {all_cols}
+        """
langroid/parsing/url_loader.py CHANGED
@@ -1,6 +1,9 @@
 import logging
+import os
+from tempfile import NamedTemporaryFile
 from typing import List, no_type_check
 
+import requests
 import trafilatura
 from trafilatura.downloads import (
     add_to_compressed_dict,
@@ -9,7 +12,7 @@ from trafilatura.downloads import (
 )
 
 from langroid.mytypes import DocMetaData, Document
-from langroid.parsing.document_parser import DocumentParser
+from langroid.parsing.document_parser import DocumentParser, ImagePdfParser
 from langroid.parsing.parser import Parser, ParsingConfig
 
 logging.getLogger("trafilatura").setLevel(logging.ERROR)
@@ -44,20 +47,65 @@ class URLLoader:
                 sleep_time=5,
             )
             for url, result in buffered_downloads(buffer, threads):
-                if url.lower().endswith(".pdf") or url.lower().endswith(".docx"):
+                if (
+                    url.lower().endswith(".pdf")
+                    or url.lower().endswith(".docx")
+                    or url.lower().endswith(".doc")
+                ):
                     doc_parser = DocumentParser.create(
                         url,
                         self.parser.config,
                     )
-                    docs.extend(doc_parser.get_doc_chunks())
+                    new_chunks = doc_parser.get_doc_chunks()
+                    if len(new_chunks) == 0:
+                        # If the document is empty, try to extract images
+                        img_parser = ImagePdfParser(url, self.parser.config)
+                        new_chunks = img_parser.get_doc_chunks()
+                    docs.extend(new_chunks)
                 else:
-                    text = trafilatura.extract(
-                        result,
-                        no_fallback=False,
-                        favor_recall=True,
-                    )
-                    if text is not None and text != "":
-                        docs.append(
-                            Document(content=text, metadata=DocMetaData(source=url))
+                    # Try to detect content type and handle accordingly
+                    headers = requests.head(url).headers
+                    content_type = headers.get("Content-Type", "").lower()
+                    temp_file_suffix = None
+                    if "application/pdf" in content_type:
+                        temp_file_suffix = ".pdf"
+                    elif (
+                        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+                        in content_type
+                    ):
+                        temp_file_suffix = ".docx"
+                    elif "application/msword" in content_type:
+                        temp_file_suffix = ".doc"
+
+                    if temp_file_suffix:
+                        # Download the document content
+                        response = requests.get(url)
+                        with NamedTemporaryFile(
+                            delete=False, suffix=temp_file_suffix
+                        ) as temp_file:
+                            temp_file.write(response.content)
+                            temp_file_path = temp_file.name
+                        # Process the downloaded document
+                        doc_parser = DocumentParser.create(
+                            temp_file_path, self.parser.config
+                        )
+                        docs.extend(doc_parser.get_doc_chunks())
+                        # Clean up the temporary file
+                        os.remove(temp_file_path)
+                    else:
+                        text = trafilatura.extract(
+                            result,
+                            no_fallback=False,
+                            favor_recall=True,
                         )
+                        if (
+                            text is None
+                            and result is not None
+                            and isinstance(result, str)
+                        ):
+                            text = result
+                        if text is not None and text != "":
+                            docs.append(
+                                Document(content=text, metadata=DocMetaData(source=url))
+                            )
         return docs
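
The loader's new fallback for extension-less URLs keys off the Content-Type header, as the hunk above shows. Below is a hypothetical standalone sketch of that mapping; sniff_suffix and SUFFIX_BY_CONTENT_TYPE are illustrative names, not part of langroid.

import requests

# Content-Type values handled by the loader, mapped to temp-file suffixes
SUFFIX_BY_CONTENT_TYPE = {
    "application/pdf": ".pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
    "application/msword": ".doc",
}


def sniff_suffix(url: str) -> str | None:
    """Return a file suffix inferred from the URL's Content-Type header, else None."""
    content_type = requests.head(url).headers.get("Content-Type", "").lower()
    for ctype, suffix in SUFFIX_BY_CONTENT_TYPE.items():
        if ctype in content_type:
            return suffix
    return None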
langroid/parsing/urls.py CHANGED
@@ -4,7 +4,7 @@ import tempfile
 import urllib.parse
 import urllib.robotparser
 from typing import List, Optional, Set, Tuple
-from urllib.parse import urljoin
+from urllib.parse import urldefrag, urljoin, urlparse
 
 import fire
 import requests
@@ -14,8 +14,6 @@ from rich import print
 from rich.prompt import Prompt
 from trafilatura.spider import focused_crawler
 
-from langroid.parsing.spider import scrapy_fetch_urls
-
 logger = logging.getLogger(__name__)
 
 
@@ -86,7 +84,15 @@ def get_list_from_user(
             url = input_str
             input_str = Prompt.ask("[blue] How many new URLs to crawl?", default="0")
             max_urls = int(input_str) + 1
-            tot_urls = scrapy_fetch_urls(url, k=max_urls)
+            tot_urls = list(find_urls(url, max_links=max_urls, max_depth=2))
+            tot_urls_str = "\n".join(tot_urls)
+            print(
+                f"""
+                Found these {len(tot_urls)} links upto depth 2:
+                {tot_urls_str}
+                """
+            )
+
             input_set.update(tot_urls)
         else:
             input_set.add(input_str.strip())
@@ -106,32 +112,42 @@ def is_url(s: str) -> bool:
         return False
 
 
-def get_urls_and_paths(inputs: List[str]) -> Tuple[List[str], List[str]]:
+def get_urls_paths_bytes_indices(
+    inputs: List[str | bytes],
+) -> Tuple[List[int], List[int], List[int]]:
     """
-    Given a list of inputs, return a list of URLs and a list of paths.
+    Given a list of inputs, return a
+    list of indices of URLs, list of indices of paths, list of indices of byte-contents.
     Args:
-        inputs: list of strings
+        inputs: list of strings or bytes
     Returns:
-        list of URLs, list of paths
+        list of Indices of URLs,
+        list of indices of paths,
+        list of indices of byte-contents
     """
     urls = []
     paths = []
-    for item in inputs:
+    byte_list = []
+    for i, item in enumerate(inputs):
+        if isinstance(item, bytes):
+            byte_list.append(i)
+            continue
         try:
-            m = Url(url=parse_obj_as(HttpUrl, item))
-            urls.append(str(m.url))
+            Url(url=parse_obj_as(HttpUrl, item))
+            urls.append(i)
         except ValidationError:
             if os.path.exists(item):
-                paths.append(item)
+                paths.append(i)
            else:
                logger.warning(f"{item} is neither a URL nor a path.")
-    return urls, paths
+    return urls, paths, byte_list
 
 
 def crawl_url(url: str, max_urls: int = 1) -> List[str]:
     """
     Crawl starting at the url and return a list of URLs to be parsed,
     up to a maximum of `max_urls`.
+    This has not been tested to work as intended. Ignore.
     """
     if max_urls == 1:
         # no need to crawl, just return the original list
@@ -161,6 +177,7 @@ def crawl_url(url: str, max_urls: int = 1) -> List[str]:
         )
         if to_visit is None:
             break
+
     if known_urls is None:
         return [url]
     final_urls = [s.strip() for s in known_urls]
@@ -169,46 +186,77 @@ def crawl_url(url: str, max_urls: int = 1) -> List[str]:
 
 
 def find_urls(
     url: str = "https://en.wikipedia.org/wiki/Generative_pre-trained_transformer",
+    max_links: int = 20,
     visited: Optional[Set[str]] = None,
     depth: int = 0,
     max_depth: int = 2,
+    match_domain: bool = True,
 ) -> Set[str]:
     """
     Recursively find all URLs on a given page.
+
     Args:
-        url:
-        visited:
-        depth:
-        max_depth:
+        url (str): The URL to start from.
+        max_links (int): The maximum number of links to find.
+        visited (set): A set of URLs that have already been visited.
+        depth (int): The current depth of the recursion.
+        max_depth (int): The maximum depth of the recursion.
+        match_domain (bool): Whether to only return URLs that are on the same domain.
 
     Returns:
-
+        set: A set of URLs found on the page.
     """
+
     if visited is None:
         visited = set()
-    visited.add(url)
 
-    try:
-        response = requests.get(url)
-        response.raise_for_status()
-    except (
-        requests.exceptions.HTTPError,
-        requests.exceptions.RequestException,
-    ):
-        print(f"Failed to fetch '{url}'")
+    if url in visited or depth > max_depth:
         return visited
 
-    soup = BeautifulSoup(response.content, "html.parser")
-    links = soup.find_all("a", href=True)
-
-    urls = [urljoin(url, link["href"]) for link in links]  # Construct full URLs
-
-    if depth < max_depth:
-        for link_url in urls:
-            if link_url not in visited:
-                find_urls(link_url, visited, depth + 1, max_depth)
+    visited.add(url)
+    base_domain = urlparse(url).netloc
 
-    return visited
+    try:
+        response = requests.get(url, timeout=5)
+        response.raise_for_status()
+        soup = BeautifulSoup(response.text, "html.parser")
+        links = [urljoin(url, a["href"]) for a in soup.find_all("a", href=True)]
+
+        # Defrag links: discard links that are to portions of same page
+        defragged_links = list(set(urldefrag(link).url for link in links))
+
+        # Filter links based on domain matching requirement
+        domain_matching_links = [
+            link for link in defragged_links if urlparse(link).netloc == base_domain
+        ]
+
+        # ensure url is first, since below we are taking first max_links urls
+        domain_matching_links = [url] + [x for x in domain_matching_links if x != url]
+
+        # If found links exceed max_links, return immediately
+        if len(domain_matching_links) >= max_links:
+            return set(domain_matching_links[:max_links])
+
+        for link in domain_matching_links:
+            if len(visited) >= max_links:
+                break
+
+            if link not in visited:
+                visited.update(
+                    find_urls(
+                        link,
+                        max_links,
+                        visited,
+                        depth + 1,
+                        max_depth,
+                        match_domain,
+                    )
+                )
+
+    except (requests.RequestException, Exception) as e:
+        print(f"Error fetching {url}. Error: {e}")
+
+    return set(list(visited)[:max_links])
 
 
 def org_user_from_github(url: str) -> str:
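
To close out the urls.py changes, a brief usage sketch of the renamed helper and the extended find_urls signature. It assumes langroid 0.1.219 is installed; the inputs and URLs below are placeholders.

from langroid.parsing.urls import find_urls, get_urls_paths_bytes_indices

# Mixed inputs: the helper now returns indices into the list, not the items,
# and byte-contents get their own bucket.
inputs = [
    "https://example.com",  # a URL
    "notes.txt",            # assumed to be an existing local path
    b"raw document bytes",  # in-memory content
]
url_idxs, path_idxs, byte_idxs = get_urls_paths_bytes_indices(inputs)
print(url_idxs, path_idxs, byte_idxs)  # e.g. [0] [1] [2]

# find_urls now caps results at max_links and (by default) stays on the start domain.
links = find_urls("https://example.com", max_links=5, max_depth=1)
print(f"found {len(links)} links")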