lean-explore 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,8 @@ communication with the backend Lean Explore search engine API for
7
7
  performing searches and retrieving detailed information.
8
8
  """
9
9
 
10
- from typing import List, Optional
10
+ import asyncio
11
+ from typing import List, Optional, Union, overload
11
12
 
12
13
  import httpx
13
14
 
@@ -44,81 +45,172 @@ class Client:
44
45
  self.timeout: float = timeout
45
46
  self._headers: dict = {"Authorization": f"Bearer {self.api_key}"}
46
47
 
48
+ async def _fetch_one_search(
49
+ self,
50
+ client: httpx.AsyncClient,
51
+ query: str,
52
+ package_filters: Optional[List[str]],
53
+ ) -> APISearchResponse:
54
+ """Coroutine to fetch a single search result.
55
+
56
+ Args:
57
+ client: An active httpx.AsyncClient instance.
58
+ query: The search query string.
59
+ package_filters: An optional list of package names.
60
+
61
+ Returns:
62
+ An APISearchResponse object.
63
+ """
64
+ endpoint = f"{self.base_url}/search"
65
+ params = {"q": query}
66
+ if package_filters:
67
+ params["pkg"] = package_filters
68
+
69
+ response = await client.get(endpoint, params=params, headers=self._headers)
70
+ response.raise_for_status()
71
+ return APISearchResponse(**response.json())
72
+
73
+ @overload
47
74
  async def search(
48
75
  self, query: str, package_filters: Optional[List[str]] = None
49
- ) -> APISearchResponse:
76
+ ) -> APISearchResponse: ...
77
+
78
+ @overload
79
+ async def search(
80
+ self, query: List[str], package_filters: Optional[List[str]] = None
81
+ ) -> List[APISearchResponse]: ...
82
+
83
+ async def search(
84
+ self,
85
+ query: Union[str, List[str]],
86
+ package_filters: Optional[List[str]] = None,
87
+ ) -> Union[APISearchResponse, List[APISearchResponse]]:
50
88
  """Performs a search for statement groups via the API.
51
89
 
90
+ This method can handle a single query string or a list of query strings.
91
+ When a list is provided, requests are sent concurrently.
92
+
52
93
  Args:
53
- query: The search query string.
94
+ query: The search query string or a list of query strings.
54
95
  package_filters: An optional list of package names to filter the
55
- search by.
96
+ search by. This filter is applied to all queries.
56
97
 
57
98
  Returns:
58
- An APISearchResponse object containing the search results and
59
- associated metadata.
99
+ An APISearchResponse object if a single query was provided, or a
100
+ list of APISearchResponse objects if a list of queries was provided.
60
101
 
61
102
  Raises:
62
103
  httpx.HTTPStatusError: If the API returns an HTTP error status (4xx or 5xx).
63
104
  httpx.RequestError: For network-related issues or other request errors.
64
105
  """
65
- endpoint = f"{self.base_url}/search"
66
- params = {"q": query}
67
- if package_filters:
68
- params["pkg"] = package_filters
106
+ was_single_query = isinstance(query, str)
107
+ queries = [query] if was_single_query else query
69
108
 
70
109
  async with httpx.AsyncClient(timeout=self.timeout) as client:
71
- response = await client.get(endpoint, params=params, headers=self._headers)
72
- response.raise_for_status()
73
- return APISearchResponse(**response.json())
74
-
75
- async def get_by_id(self, group_id: int) -> Optional[APISearchResultItem]:
110
+ tasks = [
111
+ self._fetch_one_search(client, q, package_filters) for q in queries
112
+ ]
113
+ results = await asyncio.gather(*tasks)
114
+
115
+ if was_single_query:
116
+ return results[0]
117
+ return results
118
+
119
+ async def _fetch_one_by_id(
120
+ self, client: httpx.AsyncClient, group_id: int
121
+ ) -> Optional[APISearchResultItem]:
122
+ endpoint = f"{self.base_url}/statement_groups/{group_id}"
123
+ response = await client.get(endpoint, headers=self._headers)
124
+ if response.status_code == 404:
125
+ return None
126
+ response.raise_for_status()
127
+ return APISearchResultItem(**response.json())
128
+
129
+ @overload
130
+ async def get_by_id(self, group_id: int) -> Optional[APISearchResultItem]: ...
131
+
132
+ @overload
133
+ async def get_by_id(
134
+ self, group_id: List[int]
135
+ ) -> List[Optional[APISearchResultItem]]: ...
136
+
137
+ async def get_by_id(
138
+ self, group_id: Union[int, List[int]]
139
+ ) -> Union[Optional[APISearchResultItem], List[Optional[APISearchResultItem]]]:
76
140
  """Retrieves a specific statement group by its unique ID via the API.
77
141
 
78
142
  Args:
79
- group_id: The unique identifier of the statement group.
143
+ group_id: The unique identifier of the statement group, or a list of IDs.
80
144
 
81
145
  Returns:
82
- An APISearchResultItem object if the statement group is found,
83
- otherwise None if a 404 error is received.
146
+ An APISearchResultItem object if a single ID was found, None if it was
147
+ not found. A list of Optional[APISearchResultItem] if a list of
148
+ IDs was provided.
84
149
 
85
150
  Raises:
86
151
  httpx.HTTPStatusError: If the API returns an HTTP error status
87
- other than 404 (e.g., 401, 403, 5xx).
152
+ other than 404 (e.g., 401, 403, 5xx).
88
153
  httpx.RequestError: For network-related issues or other request errors.
89
154
  """
90
- endpoint = f"{self.base_url}/statement_groups/{group_id}"
155
+ was_single_id = isinstance(group_id, int)
156
+ group_ids = [group_id] if was_single_id else group_id
157
+
91
158
  async with httpx.AsyncClient(timeout=self.timeout) as client:
92
- response = await client.get(endpoint, headers=self._headers)
93
- if response.status_code == 404:
94
- return None
95
- response.raise_for_status()
96
- return APISearchResultItem(**response.json())
159
+ tasks = [self._fetch_one_by_id(client, g_id) for g_id in group_ids]
160
+ results = await asyncio.gather(*tasks)
161
+
162
+ if was_single_id:
163
+ return results[0]
164
+ return results
97
165
 
98
- async def get_dependencies(self, group_id: int) -> Optional[APICitationsResponse]:
166
+ async def _fetch_one_dependencies(
167
+ self, client: httpx.AsyncClient, group_id: int
168
+ ) -> Optional[APICitationsResponse]:
169
+ endpoint = f"{self.base_url}/statement_groups/{group_id}/dependencies"
170
+ response = await client.get(endpoint, headers=self._headers)
171
+ if response.status_code == 404:
172
+ return None
173
+ response.raise_for_status()
174
+ return APICitationsResponse(**response.json())
175
+
176
+ @overload
177
+ async def get_dependencies(
178
+ self, group_id: int
179
+ ) -> Optional[APICitationsResponse]: ...
180
+
181
+ @overload
182
+ async def get_dependencies(
183
+ self, group_id: List[int]
184
+ ) -> List[Optional[APICitationsResponse]]: ...
185
+
186
+ async def get_dependencies(
187
+ self, group_id: Union[int, List[int]]
188
+ ) -> Union[Optional[APICitationsResponse], List[Optional[APICitationsResponse]]]:
99
189
  """Retrieves the dependencies (citations) for a specific statement group.
100
190
 
101
- This method fetches the statement groups that the specified 'group_id'
102
- depends on (i.e., cites).
191
+ This method fetches the statement groups that the specified 'group_id'(s)
192
+ depend on (i.e., cite).
103
193
 
104
194
  Args:
105
- group_id: The unique identifier of the statement group for which
106
- to fetch dependencies.
195
+ group_id: The unique identifier of the statement group, or a list of IDs.
107
196
 
108
197
  Returns:
109
- An APICitationsResponse object containing the list of dependencies
110
- (cited items) if the source statement group is found. Returns None
111
- if the source statement group itself is not found (receives a 404).
198
+ An APICitationsResponse object if a single ID was provided. A list
199
+ of Optional[APICitationsResponse] if a list of IDs was provided.
200
+ None is returned for IDs that are not found.
112
201
 
113
202
  Raises:
114
203
  httpx.HTTPStatusError: If the API returns an HTTP error status
115
- other than 404 (e.g., 401, 403, 5xx).
204
+ other than 404 (e.g., 401, 403, 5xx).
116
205
  httpx.RequestError: For network-related issues or other request errors.
117
206
  """
118
- endpoint = f"{self.base_url}/statement_groups/{group_id}/dependencies"
207
+ was_single_id = isinstance(group_id, int)
208
+ group_ids = [group_id] if was_single_id else group_id
209
+
119
210
  async with httpx.AsyncClient(timeout=self.timeout) as client:
120
- response = await client.get(endpoint, headers=self._headers)
121
- if response.status_code == 404:
122
- return None
123
- response.raise_for_status()
124
- return APICitationsResponse(**response.json())
211
+ tasks = [self._fetch_one_dependencies(client, g_id) for g_id in group_ids]
212
+ results = await asyncio.gather(*tasks)
213
+
214
+ if was_single_id:
215
+ return results[0]
216
+ return results
lean_explore/cli/agent.py CHANGED
@@ -520,19 +520,27 @@ async def _run_agent_session(
520
520
  "to the user.\n"
521
521
  "**Output:** CLI-friendly (plain text, simple lists). "
522
522
  "NO complex Markdown/LaTeX.\n\n"
523
+ "**Tool Usage & Efficiency:**\n"
524
+ "* The `search`, `get_by_id`, and `get_dependencies` tools can "
525
+ "all accept a list of inputs (queries or integer IDs) to "
526
+ "perform batch operations. This is highly efficient. For "
527
+ "example, `search(query=['query 1', 'query 2'])` or "
528
+ "`get_by_id(group_id=[123, 456])`.\n"
529
+ "* Always prefer making one batch call over multiple single "
530
+ "calls.\n\n"
523
531
  "**Packages:** Use exact top-level names for filters (Batteries, "
524
532
  "Init, Lean, Mathlib, PhysLean, Std). Map subpackage mentions "
525
533
  "to top-level (e.g., 'Mathlib.Analysis' -> 'Mathlib').\n\n"
526
534
  "**Core Workflow:**\n"
527
535
  "1. **Search & Analyze:**\n"
528
536
  " * Execute multiple distinct `search` queries for each user "
529
- "request (e.g., using full statements, rephrasing). Set `limit` "
537
+ "request by passing a list of queries to the tool. Set `limit` "
530
538
  ">= 10 for each search.\n"
531
539
  " * From all search results, select the statement(s) most "
532
540
  "helpful to the user.\n"
533
- " * For each selected statement, **MUST** use "
534
- "`get_dependencies` to understand its context before "
535
- "explaining.\n\n"
541
+ " * For each selected statement, use `get_dependencies` to "
542
+ "understand its context. Do this efficiently by collecting all "
543
+ "relevant IDs and passing them in a single list call.\n\n"
536
544
  "2. **Explain Results (Conversational & CLI-Friendly):**\n"
537
545
  " * Briefly state your search approach (e.g., 'I looked into X "
538
546
  "in Mathlib...').\n"
@@ -543,15 +551,14 @@ async def _run_agent_session(
543
551
  "`informal_description`, `statement_text`).\n"
544
552
  " * Provide the full Lean code (`statement_text`).\n"
545
553
  " * Explain key dependencies (what they are, their role, "
546
- "using `statement_text` or `display_statement_text` from "
547
- "`get_dependencies` output).\n"
554
+ "using `statement_text` from `get_dependencies` output).\n"
548
555
  "3. **Specific User Follow-ups (If Asked):**\n"
549
- " * **`get_by_id`:** For a specific ID, provide: ID, Lean name, "
550
- "statement text, source/line, docstring, informal description "
551
- "(structured CLI format).\n"
552
- " * **`get_dependencies` (Direct Request):** For all "
553
- "dependencies of an ID, list: ID, Lean name, statement "
554
- "text/summary. State total count.\n\n"
556
+ " * **`get_by_id`:** For one or more IDs, provide: ID, "
557
+ "Lean name, statement text, source/line, docstring, informal "
558
+ "description (structured CLI format).\n"
559
+ " * **`get_dependencies` (Direct Request):** For one or more "
560
+ "IDs, list dependencies for each: ID, Lean name, statement "
561
+ "text/summary. State total count per ID.\n\n"
555
562
  "Always be concise, helpful, and clear."
556
563
  ),
557
564
  mcp_servers=[server_instance],
@@ -9,7 +9,7 @@ data assets (SQLite database, FAISS index, and embedding models).
9
9
 
10
10
  import logging
11
11
  import time
12
- from typing import List, Optional
12
+ from typing import List, Optional, Union, overload
13
13
 
14
14
  import faiss # For type hinting if needed
15
15
  from sentence_transformers import SentenceTransformer # For type hinting if needed
@@ -69,10 +69,10 @@ class Service:
69
69
 
70
70
  Raises:
71
71
  FileNotFoundError: If essential data files (DB, FAISS index, map)
72
- are not found at their expected locations.
72
+ are not found at their expected locations.
73
73
  RuntimeError: If the embedding model fails to load or if other
74
- critical initialization steps (like database connection
75
- after file checks) fail.
74
+ critical initialization steps (like database connection
75
+ after file checks) fail.
76
76
  """
77
77
  logger.info("Initializing local Service...")
78
78
  try:
@@ -210,43 +210,25 @@ class Service:
210
210
  informal_description=sg_orm.informal_description,
211
211
  )
212
212
 
213
- def search(
213
+ def _perform_one_search(
214
214
  self,
215
215
  query: str,
216
- package_filters: Optional[List[str]] = None,
217
- limit: Optional[int] = None,
216
+ package_filters: Optional[List[str]],
217
+ limit: Optional[int],
218
218
  ) -> APISearchResponse:
219
- """Performs a local search for statement groups.
219
+ """Helper to perform and package a single local search.
220
220
 
221
221
  Args:
222
222
  query: The search query string.
223
223
  package_filters: An optional list of package names to filter results by.
224
224
  limit: An optional limit on the number of results to return.
225
- If None, defaults.DEFAULT_RESULTS_LIMIT is used.
226
225
 
227
226
  Returns:
228
- An APISearchResponse object containing search results and metadata.
229
-
230
- Raises:
231
- RuntimeError: If service not properly initialized (e.g., assets missing).
232
- Exception: Propagates exceptions from `perform_search`.
227
+ An APISearchResponse for the given query.
233
228
  """
234
229
  start_time = time.time()
235
230
  actual_limit = limit if limit is not None else self.default_results_limit
236
231
 
237
- if (
238
- self.embedding_model is None
239
- or self.faiss_index is None
240
- or self.text_chunk_id_map is None
241
- ):
242
- logger.error(
243
- "Search service assets not loaded. Service may not have initialized "
244
- "correctly."
245
- )
246
- raise RuntimeError(
247
- "Search service assets not loaded. Please ensure data has been fetched."
248
- )
249
-
250
232
  with self.SessionLocal() as session:
251
233
  try:
252
234
  ranked_results_orm = perform_search(
@@ -291,102 +273,207 @@ class Service:
291
273
  processing_time_ms=processing_time_ms,
292
274
  )
293
275
 
294
- def get_by_id(self, group_id: int) -> Optional[APISearchResultItem]:
276
+ @overload
277
+ def search(
278
+ self,
279
+ query: str,
280
+ package_filters: Optional[List[str]] = None,
281
+ limit: Optional[int] = None,
282
+ ) -> APISearchResponse: ...
283
+
284
+ @overload
285
+ def search(
286
+ self,
287
+ query: List[str],
288
+ package_filters: Optional[List[str]] = None,
289
+ limit: Optional[int] = None,
290
+ ) -> List[APISearchResponse]: ...
291
+
292
+ def search(
293
+ self,
294
+ query: Union[str, List[str]],
295
+ package_filters: Optional[List[str]] = None,
296
+ limit: Optional[int] = None,
297
+ ) -> Union[APISearchResponse, List[APISearchResponse]]:
298
+ """Performs a local search for statement groups.
299
+
300
+ This method can handle a single query string or a list of query strings.
301
+ When a list is provided, searches are performed serially.
302
+
303
+ Args:
304
+ query: The search query string or a list of query strings.
305
+ package_filters: An optional list of package names to filter results by.
306
+ limit: An optional limit on the number of results to return.
307
+ If None, defaults.DEFAULT_RESULTS_LIMIT is used.
308
+
309
+ Returns:
310
+ An APISearchResponse object if a single query was provided, or a
311
+ list of APISearchResponse objects if a list of queries was provided.
312
+
313
+ Raises:
314
+ RuntimeError: If service not properly initialized (e.g., assets missing).
315
+ Exception: Propagates exceptions from `perform_search`.
316
+ """
317
+ if (
318
+ self.embedding_model is None
319
+ or self.faiss_index is None
320
+ or self.text_chunk_id_map is None
321
+ ):
322
+ logger.error(
323
+ "Search service assets not loaded. Service may not have initialized "
324
+ "correctly."
325
+ )
326
+ raise RuntimeError(
327
+ "Search service assets not loaded. Please ensure data has been fetched."
328
+ )
329
+
330
+ was_single_query = isinstance(query, str)
331
+ queries = [query] if was_single_query else query
332
+ results = []
333
+
334
+ for q in queries:
335
+ response = self._perform_one_search(q, package_filters, limit)
336
+ results.append(response)
337
+
338
+ if was_single_query:
339
+ return results[0]
340
+ return results
341
+
342
+ @overload
343
+ def get_by_id(self, group_id: int) -> Optional[APISearchResultItem]: ...
344
+
345
+ @overload
346
+ def get_by_id(self, group_id: List[int]) -> List[Optional[APISearchResultItem]]: ...
347
+
348
+ def get_by_id(
349
+ self, group_id: Union[int, List[int]]
350
+ ) -> Union[Optional[APISearchResultItem], List[Optional[APISearchResultItem]]]:
295
351
  """Retrieves a specific statement group by its ID from local data.
296
352
 
297
353
  Args:
298
- group_id: The unique identifier of the statement group.
354
+ group_id: The unique identifier of the statement group, or a list of IDs.
299
355
 
300
356
  Returns:
301
- An APISearchResultItem if found, otherwise None.
357
+ An APISearchResultItem if a single ID was found, None if not found.
358
+ A list of Optional[APISearchResultItem] if a list of IDs was provided.
302
359
  """
360
+ was_single_id = isinstance(group_id, int)
361
+ group_ids = [group_id] if was_single_id else group_id
362
+ results = []
363
+
303
364
  with self.SessionLocal() as session:
304
- try:
305
- stmt_group_orm = (
306
- session.query(StatementGroup)
307
- .options(joinedload(StatementGroup.primary_declaration))
308
- .filter(StatementGroup.id == group_id)
309
- .first()
310
- )
311
- if stmt_group_orm:
312
- return self._serialize_sg_to_api_item(stmt_group_orm)
313
- return None
314
- except SQLAlchemyError as e:
315
- logger.error(
316
- f"Database error in get_by_id for group_id {group_id}: {e}",
317
- exc_info=True,
318
- )
319
- return None
320
- except Exception as e:
321
- logger.error(
322
- f"Unexpected error in get_by_id for group_id {group_id}: {e}",
323
- exc_info=True,
324
- )
325
- return None
365
+ for g_id in group_ids:
366
+ try:
367
+ stmt_group_orm = (
368
+ session.query(StatementGroup)
369
+ .options(joinedload(StatementGroup.primary_declaration))
370
+ .filter(StatementGroup.id == g_id)
371
+ .first()
372
+ )
373
+ if stmt_group_orm:
374
+ results.append(self._serialize_sg_to_api_item(stmt_group_orm))
375
+ else:
376
+ results.append(None)
377
+ except SQLAlchemyError as e:
378
+ logger.error(
379
+ f"Database error in get_by_id for group_id {g_id}: {e}",
380
+ exc_info=True,
381
+ )
382
+ results.append(None)
383
+ except Exception as e:
384
+ logger.error(
385
+ f"Unexpected error in get_by_id for group_id {g_id}: {e}",
386
+ exc_info=True,
387
+ )
388
+ results.append(None)
326
389
 
327
- def get_dependencies(self, group_id: int) -> Optional[APICitationsResponse]:
390
+ if was_single_id:
391
+ return results[0]
392
+ return results
393
+
394
+ @overload
395
+ def get_dependencies(self, group_id: int) -> Optional[APICitationsResponse]: ...
396
+
397
+ @overload
398
+ def get_dependencies(
399
+ self, group_id: List[int]
400
+ ) -> List[Optional[APICitationsResponse]]: ...
401
+
402
+ def get_dependencies(
403
+ self, group_id: Union[int, List[int]]
404
+ ) -> Union[Optional[APICitationsResponse], List[Optional[APICitationsResponse]]]:
328
405
  """Retrieves citations for a specific statement group from local data.
329
406
 
330
407
  Citations are the statement groups that the specified group_id depends on.
331
408
 
332
409
  Args:
333
- group_id: The unique identifier of the statement group for which
334
- to fetch citations.
410
+ group_id: The unique ID of the source group, or a list of IDs.
335
411
 
336
412
  Returns:
337
- An APICitationsResponse object if the source group is found and has
338
- citations, or an APICitationsResponse with an empty list if no
339
- citations, otherwise None if the source group itself is not found or
340
- a DB error occurs.
413
+ An APICitationsResponse if a single ID was provided, or a list of
414
+ Optional[APICitationsResponse] if a list of IDs was given. Returns
415
+ None for IDs that are not found or cause an error.
341
416
  """
417
+ was_single_id = isinstance(group_id, int)
418
+ group_ids = [group_id] if was_single_id else group_id
419
+ results = []
420
+
342
421
  with self.SessionLocal() as session:
343
- try:
344
- source_group_exists = (
345
- session.query(StatementGroup.id)
346
- .filter(StatementGroup.id == group_id)
347
- .first()
348
- )
349
- if not source_group_exists:
350
- logger.warning(
351
- f"Source statement group ID {group_id} not found for "
352
- "dependency lookup."
422
+ for g_id in group_ids:
423
+ try:
424
+ source_group_exists = (
425
+ session.query(StatementGroup.id)
426
+ .filter(StatementGroup.id == g_id)
427
+ .first()
353
428
  )
354
- return None
355
-
356
- cited_target_groups_orm = (
357
- session.query(StatementGroup)
358
- .join(
359
- StatementGroupDependency,
360
- StatementGroup.id
361
- == StatementGroupDependency.target_statement_group_id,
429
+ if not source_group_exists:
430
+ logger.warning(
431
+ f"Source statement group ID {g_id} not found for "
432
+ "dependency lookup."
433
+ )
434
+ results.append(None)
435
+ continue
436
+
437
+ cited_target_groups_orm = (
438
+ session.query(StatementGroup)
439
+ .join(
440
+ StatementGroupDependency,
441
+ StatementGroup.id
442
+ == StatementGroupDependency.target_statement_group_id,
443
+ )
444
+ .filter(
445
+ StatementGroupDependency.source_statement_group_id == g_id
446
+ )
447
+ .options(joinedload(StatementGroup.primary_declaration))
448
+ .all()
362
449
  )
363
- .filter(
364
- StatementGroupDependency.source_statement_group_id == group_id
365
- )
366
- .options(joinedload(StatementGroup.primary_declaration))
367
- .all()
368
- )
369
450
 
370
- citations_api_items = [
371
- self._serialize_sg_to_api_item(sg_orm)
372
- for sg_orm in cited_target_groups_orm
373
- ]
451
+ citations_api_items = [
452
+ self._serialize_sg_to_api_item(sg_orm)
453
+ for sg_orm in cited_target_groups_orm
454
+ ]
455
+
456
+ results.append(
457
+ APICitationsResponse(
458
+ source_group_id=g_id,
459
+ citations=citations_api_items,
460
+ count=len(citations_api_items),
461
+ )
462
+ )
463
+ except SQLAlchemyError as e:
464
+ logger.error(
465
+ f"Database error in get_dependencies for group_id {g_id}: {e}",
466
+ exc_info=True,
467
+ )
468
+ results.append(None)
469
+ except Exception as e:
470
+ logger.error(
471
+ f"Unexpected error in get_dependencies for "
472
+ f"group_id {g_id}: {e}",
473
+ exc_info=True,
474
+ )
475
+ results.append(None)
374
476
 
375
- return APICitationsResponse(
376
- source_group_id=group_id,
377
- citations=citations_api_items,
378
- count=len(citations_api_items),
379
- )
380
- except SQLAlchemyError as e:
381
- logger.error(
382
- f"Database error in get_dependencies for group_id {group_id}: {e}",
383
- exc_info=True,
384
- )
385
- return None
386
- except Exception as e:
387
- logger.error(
388
- f"Unexpected error in get_dependencies for "
389
- f"group_id {group_id}: {e}",
390
- exc_info=True,
391
- )
392
- return None
477
+ if was_single_id:
478
+ return results[0]
479
+ return results
lean_explore/mcp/tools.py CHANGED
@@ -10,7 +10,7 @@ made available through the MCP application context.
10
10
 
11
11
  import asyncio # Needed for asyncio.iscoroutinefunction
12
12
  import logging
13
- from typing import Any, Dict, List, Optional
13
+ from typing import Any, Dict, List, Optional, Union
14
14
 
15
15
  from mcp.server.fastmcp import Context as MCPContext
16
16
 
@@ -82,161 +82,189 @@ def _prepare_mcp_result_item(backend_item: APISearchResultItem) -> APISearchResu
82
82
  @mcp_app.tool()
83
83
  async def search(
84
84
  ctx: MCPContext,
85
- query: str,
85
+ query: Union[str, List[str]],
86
86
  package_filters: Optional[List[str]] = None,
87
87
  limit: int = 10,
88
- ) -> Dict[str, Any]:
89
- """Searches Lean statement groups by a query string.
88
+ ) -> List[Dict[str, Any]]:
89
+ """Searches Lean statement groups by a query string or list of strings.
90
90
 
91
91
  This tool allows for filtering by package names and limits the number
92
- of results returned.
92
+ of results returned per query.
93
93
 
94
94
  Args:
95
95
  ctx: The MCP context, providing access to shared resources like the
96
96
  backend service.
97
- query: The search query string. For example, "continuous function" or
98
- "prime number theorem".
97
+ query: A single search query string or a list of query strings. For
98
+ example, "continuous function" or ["prime number theorem",
99
+ "fundamental theorem of arithmetic"].
99
100
  package_filters: An optional list of package names to filter the search
100
101
  results by. For example, `["Mathlib.Analysis",
101
102
  "Mathlib.Order"]`. If None or empty, no package filter
102
103
  is applied.
103
- limit: The maximum number of search results to return from this tool.
104
+ limit: The maximum number of search results to return per query.
104
105
  Defaults to 10. Must be a positive integer.
105
106
 
106
107
  Returns:
107
- A dictionary corresponding to the APISearchResponse model, containing
108
- the search results (potentially truncated by the `limit` parameter of
109
- this tool), and metadata about the search operation. The
110
- `display_statement_text` field within each result item is omitted.
108
+ A list of dictionaries, where each dictionary corresponds to the
109
+ APISearchResponse model. Each response contains the search results
110
+ for a single query. The `display_statement_text` field within each
111
+ result item is omitted.
111
112
  """
112
113
  backend = await _get_backend_from_context(ctx)
113
114
  logger.info(
114
- f"MCP Tool 'search' called with query: '{query}', "
115
+ f"MCP Tool 'search' called with query/queries: '{query}', "
115
116
  f"packages: {package_filters}, tool_limit: {limit}"
116
117
  )
117
118
 
118
119
  if not hasattr(backend, "search"):
119
120
  logger.error("Backend service does not have a 'search' method.")
120
- # This should ideally return a structured error for MCP if possible.
121
- # For now, FastMCP will convert this RuntimeError.
122
121
  raise RuntimeError("Search functionality not available on configured backend.")
123
122
 
124
- tool_limit = max(1, limit) # Ensure limit is at least 1 for slicing
125
- api_response_pydantic: Optional[APISearchResponse]
123
+ tool_limit = max(1, limit)
124
+ backend_responses: Union[APISearchResponse, List[APISearchResponse]]
126
125
 
127
126
  # Conditionally await based on the backend's search method type
128
127
  if asyncio.iscoroutinefunction(backend.search):
129
- api_response_pydantic = await backend.search(
130
- query=query,
131
- package_filters=package_filters,
132
- # The backend.search method uses its own internal default for limit
133
- # if None is passed, or the passed limit.
134
- # The MCP tool will truncate the results later using tool_limit.
135
- )
136
- else:
137
- api_response_pydantic = backend.search(
128
+ backend_responses = await backend.search(
138
129
  query=query, package_filters=package_filters
139
130
  )
131
+ else:
132
+ backend_responses = backend.search(query=query, package_filters=package_filters)
140
133
 
141
- if not api_response_pydantic:
142
- logger.warning("Backend search returned None, responding with empty results.")
143
- empty_response = APISearchResponse(
144
- query=query,
145
- packages_applied=package_filters or [],
146
- results=[],
147
- count=0,
148
- total_candidates_considered=0,
149
- processing_time_ms=0,
134
+ # Normalize to a list for consistent processing, handling None from backend.
135
+ if backend_responses is None:
136
+ responses_list = []
137
+ else:
138
+ responses_list = (
139
+ [backend_responses]
140
+ if isinstance(backend_responses, APISearchResponse)
141
+ else backend_responses
150
142
  )
151
- return empty_response.model_dump(exclude_none=True)
152
143
 
153
- actual_backend_results = api_response_pydantic.results
154
-
155
- mcp_results_list = []
156
- for backend_item in actual_backend_results[:tool_limit]: # Apply MCP tool's limit
157
- mcp_results_list.append(_prepare_mcp_result_item(backend_item))
158
-
159
- final_mcp_response = APISearchResponse(
160
- query=api_response_pydantic.query,
161
- packages_applied=api_response_pydantic.packages_applied,
162
- results=mcp_results_list,
163
- count=len(mcp_results_list), # Count is after this tool's truncation
164
- total_candidates_considered=api_response_pydantic.total_candidates_considered,
165
- processing_time_ms=api_response_pydantic.processing_time_ms,
166
- )
144
+ final_mcp_responses = []
145
+
146
+ for response_pydantic in responses_list:
147
+ if not response_pydantic:
148
+ logger.warning("A backend search returned None; skipping this response.")
149
+ continue
150
+
151
+ actual_backend_results = response_pydantic.results
152
+ mcp_results_list = []
153
+ for backend_item in actual_backend_results[:tool_limit]:
154
+ mcp_results_list.append(_prepare_mcp_result_item(backend_item))
155
+
156
+ final_mcp_response = APISearchResponse(
157
+ query=response_pydantic.query,
158
+ packages_applied=response_pydantic.packages_applied,
159
+ results=mcp_results_list,
160
+ count=len(mcp_results_list),
161
+ total_candidates_considered=response_pydantic.total_candidates_considered,
162
+ processing_time_ms=response_pydantic.processing_time_ms,
163
+ )
164
+ final_mcp_responses.append(final_mcp_response.model_dump(exclude_none=True))
167
165
 
168
- return final_mcp_response.model_dump(exclude_none=True)
166
+ return final_mcp_responses
169
167
 
170
168
 
171
169
  @mcp_app.tool()
172
- async def get_by_id(ctx: MCPContext, group_id: int) -> Optional[Dict[str, Any]]:
173
- """Retrieves a specific statement group by its unique identifier.
170
+ async def get_by_id(
171
+ ctx: MCPContext, group_id: Union[int, List[int]]
172
+ ) -> List[Optional[Dict[str, Any]]]:
173
+ """Retrieves specific statement groups by their unique identifier(s).
174
174
 
175
- The `display_statement_text` field is omitted from the response.
175
+ The `display_statement_text` field is omitted from the response. This tool
176
+ always returns a list of results.
176
177
 
177
178
  Args:
178
179
  ctx: The MCP context, providing access to the backend service.
179
- group_id: The unique integer identifier of the statement group to retrieve.
180
- For example, `12345`.
180
+ group_id: A single unique integer identifier or a list of identifiers
181
+ of the statement group(s) to retrieve. For example, `12345` or
182
+ `[12345, 67890]`.
181
183
 
182
184
  Returns:
183
- A dictionary corresponding to the APISearchResultItem model if a
184
- statement group with the given ID is found (with
185
- `display_statement_text` omitted). Returns None (which will be
186
- serialized as JSON null by MCP) if no such group exists.
185
+ A list of dictionaries, where each dictionary corresponds to the
186
+ APISearchResultItem model. If an ID is not found, its corresponding
187
+ entry in the list will be None (serialized as JSON null by MCP).
187
188
  """
188
189
  backend = await _get_backend_from_context(ctx)
189
- logger.info(f"MCP Tool 'get_by_id' called for group_id: {group_id}")
190
+ logger.info(f"MCP Tool 'get_by_id' called for group_id(s): {group_id}")
190
191
 
191
- backend_item: Optional[APISearchResultItem]
192
+ backend_items: Union[
193
+ Optional[APISearchResultItem], List[Optional[APISearchResultItem]]
194
+ ]
192
195
  if asyncio.iscoroutinefunction(backend.get_by_id):
193
- backend_item = await backend.get_by_id(group_id=group_id)
196
+ backend_items = await backend.get_by_id(group_id=group_id)
194
197
  else:
195
- backend_item = backend.get_by_id(group_id=group_id)
198
+ backend_items = backend.get_by_id(group_id=group_id)
199
+
200
+ # Normalize to a list for consistent return type
201
+ items_list = (
202
+ [backend_items] if not isinstance(backend_items, list) else backend_items
203
+ )
204
+
205
+ mcp_items = []
206
+ for item in items_list:
207
+ if item:
208
+ mcp_item = _prepare_mcp_result_item(item)
209
+ mcp_items.append(mcp_item.model_dump(exclude_none=True))
210
+ else:
211
+ mcp_items.append(None)
196
212
 
197
- if backend_item:
198
- mcp_item = _prepare_mcp_result_item(backend_item)
199
- return mcp_item.model_dump(exclude_none=True)
200
- return None
213
+ return mcp_items
201
214
 
202
215
 
203
216
  @mcp_app.tool()
204
- async def get_dependencies(ctx: MCPContext, group_id: int) -> Optional[Dict[str, Any]]:
205
- """Retrieves the direct dependencies (citations) for a specific statement group.
217
+ async def get_dependencies(
218
+ ctx: MCPContext, group_id: Union[int, List[int]]
219
+ ) -> List[Optional[Dict[str, Any]]]:
220
+ """Retrieves direct dependencies (citations) for specific statement group(s).
206
221
 
207
222
  The `display_statement_text` field within each cited item is omitted
208
- from the response.
223
+ from the response. This tool always returns a list of results.
209
224
 
210
225
  Args:
211
226
  ctx: The MCP context, providing access to the backend service.
212
- group_id: The unique integer identifier of the statement group for which
213
- to fetch its direct dependencies. For example, `12345`.
227
+ group_id: A single unique integer identifier or a list of identifiers for
228
+ the statement group(s) for which to fetch direct dependencies.
229
+ For example, `12345` or `[12345, 67890]`.
214
230
 
215
231
  Returns:
216
- A dictionary corresponding to the APICitationsResponse model, which
217
- contains a list of cited statement groups (each with
218
- `display_statement_text` omitted), if the source group_id
219
- is found and has dependencies. Returns None (serialized as JSON null
220
- by MCP) if the source group is not found or has no dependencies.
232
+ A list of dictionaries, where each dictionary corresponds to the
233
+ APICitationsResponse model. If a source group ID is not found or has
234
+ no dependencies, its corresponding entry will be None.
221
235
  """
222
236
  backend = await _get_backend_from_context(ctx)
223
- logger.info(f"MCP Tool 'get_dependencies' called for group_id: {group_id}")
237
+ logger.info(f"MCP Tool 'get_dependencies' called for group_id(s): {group_id}")
224
238
 
225
- backend_response: Optional[APICitationsResponse]
239
+ backend_responses: Union[
240
+ Optional[APICitationsResponse], List[Optional[APICitationsResponse]]
241
+ ]
226
242
  if asyncio.iscoroutinefunction(backend.get_dependencies):
227
- backend_response = await backend.get_dependencies(group_id=group_id)
243
+ backend_responses = await backend.get_dependencies(group_id=group_id)
228
244
  else:
229
- backend_response = backend.get_dependencies(group_id=group_id)
245
+ backend_responses = backend.get_dependencies(group_id=group_id)
230
246
 
231
- if backend_response:
232
- mcp_citations_list = []
233
- for backend_item in backend_response.citations:
234
- mcp_citations_list.append(_prepare_mcp_result_item(backend_item))
235
-
236
- final_mcp_response = APICitationsResponse(
237
- source_group_id=backend_response.source_group_id,
238
- citations=mcp_citations_list,
239
- count=len(mcp_citations_list),
240
- )
241
- return final_mcp_response.model_dump(exclude_none=True)
242
- return None
247
+ # Normalize to a list for consistent return type
248
+ responses_list = (
249
+ [backend_responses]
250
+ if not isinstance(backend_responses, list)
251
+ else backend_responses
252
+ )
253
+ final_mcp_responses = []
254
+
255
+ for response in responses_list:
256
+ if response:
257
+ mcp_citations_list = []
258
+ for backend_item in response.citations:
259
+ mcp_citations_list.append(_prepare_mcp_result_item(backend_item))
260
+
261
+ final_response = APICitationsResponse(
262
+ source_group_id=response.source_group_id,
263
+ citations=mcp_citations_list,
264
+ count=len(mcp_citations_list),
265
+ )
266
+ final_mcp_responses.append(final_response.model_dump(exclude_none=True))
267
+ else:
268
+ final_mcp_responses.append(None)
269
+
270
+ return final_mcp_responses
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: lean-explore
3
- Version: 0.2.1
4
- Summary: A project to explore and rank Lean mathematical declarations.
3
+ Version: 0.3.0
4
+ Summary: A search engine for Lean 4 declarations.
5
5
  Author-email: Justin Asher <justinchadwickasher@gmail.com>
6
6
  License: Apache License
7
7
  Version 2.0, January 2004
@@ -213,15 +213,13 @@ Classifier: Intended Audience :: Developers
213
213
  Classifier: Intended Audience :: Science/Research
214
214
  Classifier: License :: OSI Approved :: Apache Software License
215
215
  Classifier: Programming Language :: Python :: 3
216
- Classifier: Programming Language :: Python :: 3.8
217
- Classifier: Programming Language :: Python :: 3.9
218
216
  Classifier: Programming Language :: Python :: 3.10
219
217
  Classifier: Programming Language :: Python :: 3.11
220
218
  Classifier: Programming Language :: Python :: 3.12
221
219
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
222
220
  Classifier: Topic :: Scientific/Engineering :: Mathematics
223
221
  Classifier: Topic :: Text Processing :: Indexing
224
- Requires-Python: >=3.8
222
+ Requires-Python: >=3.10
225
223
  Description-Content-Type: text/markdown
226
224
  License-File: LICENSE
227
225
  Requires-Dist: sqlalchemy>=2.0
@@ -241,7 +239,28 @@ Requires-Dist: tqdm>=4.60
241
239
  Requires-Dist: requests>=2.25.0
242
240
  Dynamic: license-file
243
241
 
244
- # LeanExplore
242
+ <h1 align="center">
243
+ LeanExplore
244
+ </h1>
245
+
246
+ <h3 align="center">
247
+ A search engine for Lean 4 declarations
248
+ </h3>
249
+
250
+ <p align="center">
251
+ <a href="https://pypi.org/project/lean-explore/">
252
+ <img src="https://img.shields.io/pypi/v/lean-explore.svg" alt="PyPI version" />
253
+ </a>
254
+ <a href="https://github.com/justincasher/lean-explore/blob/main/LeanExplore.pdf">
255
+ <img src="https://img.shields.io/badge/Paper-PDF-blue.svg" alt="Read the Paper" />
256
+ </a>
257
+ <a href="https://github.com/justincasher/lean-explore/commits/main">
258
+ <img src="https://img.shields.io/github/last-commit/justincasher/lean-explore" alt="last update" />
259
+ </a>
260
+ <a href="https://github.com/justincasher/lean-explore/blob/main/LICENSE">
261
+ <img src="https://img.shields.io/github/license/justincasher/lean-explore.svg" alt="license" />
262
+ </a>
263
+ </p>
245
264
 
246
265
  A search engine for Lean 4 declarations. This project provides tools and resources for exploring the Lean 4 ecosystem.
247
266
 
@@ -251,6 +270,7 @@ The current indexed projects include:
251
270
 
252
271
  * Batteries
253
272
  * Lean
273
+ * Init
254
274
  * Mathlib
255
275
  * PhysLean
256
276
  * Std
@@ -1,26 +1,26 @@
1
1
  lean_explore/__init__.py,sha256=LK4g9wj7jCilTUfQcdQAxDf3F2GjMwzQZJYgnQ8ciGo,38
2
2
  lean_explore/defaults.py,sha256=IJw6od-y0grYbwiDJ5ewNZI4u0j0dCCu_AXCDwWLHuA,4459
3
3
  lean_explore/api/__init__.py,sha256=LK4g9wj7jCilTUfQcdQAxDf3F2GjMwzQZJYgnQ8ciGo,38
4
- lean_explore/api/client.py,sha256=AgZG7pUY53Tl1WOhJgUdT0yxa_O1sHsals0pnjRD-Pc,4839
4
+ lean_explore/api/client.py,sha256=rvSIDbyqGl2I5b214VBBfT_UM9CvaHLa6DElsnUbi9E,7848
5
5
  lean_explore/cli/__init__.py,sha256=LK4g9wj7jCilTUfQcdQAxDf3F2GjMwzQZJYgnQ8ciGo,38
6
- lean_explore/cli/agent.py,sha256=jf1ebnViAqtKcZAGArqBf9YKHPbsOTpffOXr0Cd2M3Q,29933
6
+ lean_explore/cli/agent.py,sha256=BaC5uoK5HrySBJCB0aGcgLOE1N-UYlmbWz2hUcNUk44,30509
7
7
  lean_explore/cli/config_utils.py,sha256=RyIaDNP1UpUQZoy7HfaZ_JOXUgtzUP51Zrq_s6q7urY,16639
8
8
  lean_explore/cli/data_commands.py,sha256=mTBqFU7-fF4ZBGzCmNawZA_eHy0jyEMLlBEDEBXpxwY,21462
9
9
  lean_explore/cli/main.py,sha256=ZdbXy8x2VQ--JARqJMa9iFnrOhOCLcVgjpWhXkxj80o,24323
10
10
  lean_explore/local/__init__.py,sha256=LK4g9wj7jCilTUfQcdQAxDf3F2GjMwzQZJYgnQ8ciGo,38
11
11
  lean_explore/local/search.py,sha256=ZW8rKJ2riT6RRi6ngo8SylxQ_5jQbsipuv84kqpiwc4,40930
12
- lean_explore/local/service.py,sha256=7VT_njCiKlMqQItJ-Uy2aDraLttulAAgPhRp2BMWnSk,16166
12
+ lean_explore/local/service.py,sha256=AQAbYZ9tr3Yd_ED4weEnbRDwvkh7_0E-ERy1C1Abjlg,19292
13
13
  lean_explore/mcp/__init__.py,sha256=LK4g9wj7jCilTUfQcdQAxDf3F2GjMwzQZJYgnQ8ciGo,38
14
14
  lean_explore/mcp/app.py,sha256=XG6zTAaBRbdV1Ep_Za2JmifEmkYFKBmcGrdizCLlH-s,3808
15
15
  lean_explore/mcp/server.py,sha256=pzhLNGfTxelZvQ7ZJrWW0cbNH4MCwhviV24Y-yfQa0c,8666
16
- lean_explore/mcp/tools.py,sha256=Lri2GNIKNCLZNpNfvgvI600w3-0gaJGdPCFhhd7WVuk,9580
16
+ lean_explore/mcp/tools.py,sha256=L1U76Xg1nh3mRzq_zuEltVE2R-rsm7m6i4DFPkhqS48,10263
17
17
  lean_explore/shared/__init__.py,sha256=LK4g9wj7jCilTUfQcdQAxDf3F2GjMwzQZJYgnQ8ciGo,38
18
18
  lean_explore/shared/models/__init__.py,sha256=LK4g9wj7jCilTUfQcdQAxDf3F2GjMwzQZJYgnQ8ciGo,38
19
19
  lean_explore/shared/models/api.py,sha256=jejNDpgj-cu0KZTqkuOjM0useN4EvhvNB19lFFAOV94,4635
20
20
  lean_explore/shared/models/db.py,sha256=JYfIBnPrHZO2j7gHAVMlw9WSqVC2NinCG5KuBzdQWyk,16099
21
- lean_explore-0.2.1.dist-info/licenses/LICENSE,sha256=l4QLw1kIvEOjUktmmKm4dycK1E249Qs2s2AQTYbMXpY,11354
22
- lean_explore-0.2.1.dist-info/METADATA,sha256=5EWx5NniczmS6ApKVvoHj1RfgNC6eO9JgOIyZyNA1SY,15611
23
- lean_explore-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
24
- lean_explore-0.2.1.dist-info/entry_points.txt,sha256=JXl2Mo3BRX4jAU-Nxg_CWJR790pB_oi5qnt3Pv5iZnk,58
25
- lean_explore-0.2.1.dist-info/top_level.txt,sha256=h51BKWrFvB7iym-IlaNAAHX5MZfA8Gmg-aDuXGo0fQ8,13
26
- lean_explore-0.2.1.dist-info/RECORD,,
21
+ lean_explore-0.3.0.dist-info/licenses/LICENSE,sha256=l4QLw1kIvEOjUktmmKm4dycK1E249Qs2s2AQTYbMXpY,11354
22
+ lean_explore-0.3.0.dist-info/METADATA,sha256=gJTIosn6cuK8s1x0Q8XD4s5RZydzow_jUZFlZsKUpIM,16304
23
+ lean_explore-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
24
+ lean_explore-0.3.0.dist-info/entry_points.txt,sha256=JXl2Mo3BRX4jAU-Nxg_CWJR790pB_oi5qnt3Pv5iZnk,58
25
+ lean_explore-0.3.0.dist-info/top_level.txt,sha256=h51BKWrFvB7iym-IlaNAAHX5MZfA8Gmg-aDuXGo0fQ8,13
26
+ lean_explore-0.3.0.dist-info/RECORD,,