datahub-agent-context 1.3.1.8 (datahub_agent_context-1.3.1.8-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. datahub_agent_context/__init__.py +25 -0
  2. datahub_agent_context/_version.py +16 -0
  3. datahub_agent_context/context.py +97 -0
  4. datahub_agent_context/langchain_tools/__init__.py +8 -0
  5. datahub_agent_context/langchain_tools/builder.py +127 -0
  6. datahub_agent_context/mcp_tools/__init__.py +46 -0
  7. datahub_agent_context/mcp_tools/_token_estimator.py +71 -0
  8. datahub_agent_context/mcp_tools/base.py +325 -0
  9. datahub_agent_context/mcp_tools/descriptions.py +299 -0
  10. datahub_agent_context/mcp_tools/documents.py +473 -0
  11. datahub_agent_context/mcp_tools/domains.py +246 -0
  12. datahub_agent_context/mcp_tools/entities.py +349 -0
  13. datahub_agent_context/mcp_tools/get_me.py +99 -0
  14. datahub_agent_context/mcp_tools/gql/__init__.py +13 -0
  15. datahub_agent_context/mcp_tools/gql/document_search.gql +114 -0
  16. datahub_agent_context/mcp_tools/gql/document_semantic_search.gql +111 -0
  17. datahub_agent_context/mcp_tools/gql/entity_details.gql +1682 -0
  18. datahub_agent_context/mcp_tools/gql/queries.gql +51 -0
  19. datahub_agent_context/mcp_tools/gql/query_entity.gql +37 -0
  20. datahub_agent_context/mcp_tools/gql/read_documents.gql +16 -0
  21. datahub_agent_context/mcp_tools/gql/search.gql +242 -0
  22. datahub_agent_context/mcp_tools/helpers.py +448 -0
  23. datahub_agent_context/mcp_tools/lineage.py +698 -0
  24. datahub_agent_context/mcp_tools/owners.py +318 -0
  25. datahub_agent_context/mcp_tools/queries.py +191 -0
  26. datahub_agent_context/mcp_tools/search.py +239 -0
  27. datahub_agent_context/mcp_tools/structured_properties.py +447 -0
  28. datahub_agent_context/mcp_tools/tags.py +296 -0
  29. datahub_agent_context/mcp_tools/terms.py +295 -0
  30. datahub_agent_context/py.typed +2 -0
  31. datahub_agent_context-1.3.1.8.dist-info/METADATA +233 -0
  32. datahub_agent_context-1.3.1.8.dist-info/RECORD +34 -0
  33. datahub_agent_context-1.3.1.8.dist-info/WHEEL +5 -0
  34. datahub_agent_context-1.3.1.8.dist-info/top_level.txt +1 -0
datahub_agent_context/mcp_tools/helpers.py
@@ -0,0 +1,448 @@
+ """Helper functions for MCP tools."""
+
+ import html
+ import logging
+ import os
+ import re
+ from typing import Any, Callable, Generator, Iterator, List, Optional, TypeVar
+
+ import jmespath  # type: ignore[import-untyped]
+
+ from datahub.ingestion.graph.client import DataHubGraph
+ from datahub.metadata.urns import DatasetUrn, SchemaFieldUrn, Urn
+ from datahub_agent_context.mcp_tools._token_estimator import TokenCountEstimator
+ from datahub_agent_context.mcp_tools.base import _is_datahub_cloud
+
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T")
+
+ DESCRIPTION_LENGTH_HARD_LIMIT = 1000
+ QUERY_LENGTH_HARD_LIMIT = 5000
+
+ # Maximum token count for tool responses to prevent context window issues
+ TOOL_RESPONSE_TOKEN_LIMIT = int(os.getenv("TOOL_RESPONSE_TOKEN_LIMIT", "80000"))
+
+ # Per-entity schema token budget for field truncation
+ ENTITY_SCHEMA_TOKEN_BUDGET = int(os.getenv("ENTITY_SCHEMA_TOKEN_BUDGET", "16000"))
+
+
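Note that the two token budgets above are read from the environment once, at import time, so any override has to be exported before this module is first imported. A minimal sketch (the override values are arbitrary, for illustration only):

    import os

    # Must run before datahub_agent_context.mcp_tools.helpers is first imported.
    os.environ["TOOL_RESPONSE_TOKEN_LIMIT"] = "40000"
    os.environ["ENTITY_SCHEMA_TOKEN_BUDGET"] = "8000"

    from datahub_agent_context.mcp_tools import helpers

    assert helpers.TOOL_RESPONSE_TOKEN_LIMIT == 40000
    assert helpers.ENTITY_SCHEMA_TOKEN_BUDGET == 8000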
+ def sanitize_html_content(text: str) -> str:
+     """Remove HTML tags and decode HTML entities from text.
+
+     Uses a bounded regex pattern to prevent ReDoS (Regular Expression Denial of Service)
+     attacks. The pattern limits matching to tags with at most 100 characters between < and >,
+     which prevents backtracking on malicious input like "<" followed by millions of characters
+     without a closing ">".
+     """
+     if not text:
+         return text
+
+     # Use bounded regex to prevent ReDoS (max 100 chars between < and >)
+     text = re.sub(r"<[^<>]{0,100}>", "", text)
+
+     # Decode HTML entities
+     text = html.unescape(text)
+
+     return text.strip()
+
+
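For illustration, a doctest-style example of the sanitizer's effect (the input string is invented):

    >>> sanitize_html_content("<p>Orders placed &amp; shipped <b>daily</b></p>")
    'Orders placed & shipped daily'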
+ def truncate_with_ellipsis(text: str, max_length: int, suffix: str = "...") -> str:
+     """Truncate text to max_length and add suffix if truncated."""
+     if not text or len(text) <= max_length:
+         return text
+
+     # Account for suffix length
+     actual_max = max_length - len(suffix)
+     return text[:actual_max] + suffix
+
+
+ def sanitize_markdown_content(text: str) -> str:
+     """Remove markdown-style embeds that contain encoded data from text, but preserve alt text."""
+     if not text:
+         return text
+
+     # Remove markdown embeds with data URLs (base64 encoded content) but preserve alt text
+     # Pattern: ![alt text](data:image/type;base64,encoded_data) -> alt text
+     text = re.sub(r"!\[([^\]]*)\]\(data:[^)]+\)", r"\1", text)
+
+     return text.strip()
+
+
+ def sanitize_and_truncate_description(text: str, max_length: int) -> str:
+     """Sanitize HTML content and truncate to specified length."""
+     if not text:
+         return text
+
+     try:
+         # First sanitize HTML content
+         sanitized = sanitize_html_content(text)
+
+         # Then sanitize markdown content (preserving alt text)
+         sanitized = sanitize_markdown_content(sanitized)
+
+         # Then truncate if needed
+         return truncate_with_ellipsis(sanitized, max_length)
+     except Exception as e:
+         logger.warning(f"Error sanitizing and truncating description: {e}")
+         return text[:max_length] if len(text) > max_length else text
+
+
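Putting the three helpers together via sanitize_and_truncate_description (the description text is invented; note that the "..." suffix counts toward max_length):

    desc = "<h1>Revenue</h1> Daily revenue rollup ![chart](data:image/png;base64,iVBORw0KGgo)"
    sanitize_and_truncate_description(desc, max_length=30)
    # -> 'Revenue Daily revenue rollu...'
    # HTML tags are stripped, the data-URL embed collapses to its alt text ("chart"),
    # and the result is cut to 30 characters including the "..." suffix.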
+ def truncate_descriptions(
+     data: dict | list, max_length: int = DESCRIPTION_LENGTH_HARD_LIMIT
+ ) -> None:
+     """Recursively truncates values of keys named 'description' in a dictionary in place."""
+     if isinstance(data, dict):
+         for key, value in data.items():
+             if key == "description" and isinstance(value, str):
+                 data[key] = sanitize_and_truncate_description(value, max_length)
+             elif isinstance(value, (dict, list)):
+                 truncate_descriptions(value)
+     elif isinstance(data, list):
+         for item in data:
+             truncate_descriptions(item)
+
+
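A small, hypothetical payload to show the in-place recursion (the helper returns None; also note the recursive calls above fall back to the default DESCRIPTION_LENGTH_HARD_LIMIT rather than the caller's max_length):

    payload = {
        "description": "<p>Primary fact table for orders &amp; refunds, refreshed hourly</p>",
        "schemaMetadata": {
            "fields": [{"fieldPath": "order_id", "description": "Unique order key. " * 200}],
        },
    }
    truncate_descriptions(payload, max_length=40)
    # payload["description"] -> 'Primary fact table for orders & refun...'
    # The nested field description is truncated too, but against the 1000-char default.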
+ def truncate_query(query: str) -> str:
+     """Truncate a SQL query if it exceeds the maximum length."""
+     return truncate_with_ellipsis(
+         query, QUERY_LENGTH_HARD_LIMIT, suffix="... [truncated]"
+     )
+
+
+ def inject_urls_for_urns(
+     graph: DataHubGraph, response: Any, json_paths: List[str]
+ ) -> None:
+     """Inject URLs for URNs in the response at specified JSON paths (in place).
+
+     Only works for DataHub Cloud instances.
+
+     Args:
+         graph: DataHubGraph instance
+         response: Response dict to modify in place
+         json_paths: List of JMESPath expressions to find URNs
+     """
+     if not _is_datahub_cloud(graph):
+         return
+
+     for path in json_paths:
+         for item in jmespath.search(path, response) if path else [response]:
+             if isinstance(item, dict) and item.get("urn"):
+                 # Update item in place with url, ensuring that urn and url are first.
+                 new_item = {"urn": item["urn"], "url": graph.url_for(item["urn"])}
+                 new_item.update({k: v for k, v in item.items() if k != "urn"})
+                 item.clear()
+                 item.update(new_item)
+
+
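A rough sketch of a call site (the response shape and the JMESPath expression are invented; graph is assumed to be a DataHubGraph connected to a DataHub Cloud instance, since the helper is a no-op otherwise):

    response = {
        "searchResults": [
            {"entity": {"urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,db.orders,PROD)", "name": "orders"}},
        ]
    }
    inject_urls_for_urns(graph, response, json_paths=["searchResults[].entity"])
    # Each matched entity dict is rewritten in place so that "urn" and "url" come first:
    # {"urn": "urn:li:dataset:(...)", "url": "https://<your-instance>/...", "name": "orders"}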
+ def maybe_convert_to_schema_field_urn(urn: str, column: Optional[str]) -> str:
+     """Convert a dataset URN to a schema field URN if column is provided.
+
+     Args:
+         urn: Dataset URN
+         column: Optional column name
+
+     Returns:
+         SchemaField URN if column provided, otherwise original URN
+
+     Raises:
+         ValueError: If column provided but URN is not a dataset URN
+     """
+     if column:
+         maybe_dataset_urn = Urn.from_string(urn)
+         if not isinstance(maybe_dataset_urn, DatasetUrn):
+             raise ValueError(
+                 f"Input urn should be a dataset urn if column is provided, but got {urn}."
+             )
+         urn = str(SchemaFieldUrn(maybe_dataset_urn, column))
+     return urn
+
+
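For example (the dataset URN is made up; the schemaField encoding shown follows the standard DataHub URN format):

    dataset_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,db.orders,PROD)"
    maybe_convert_to_schema_field_urn(dataset_urn, "user_id")
    # -> 'urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:snowflake,db.orders,PROD),user_id)'
    maybe_convert_to_schema_field_urn(dataset_urn, None)
    # -> the dataset URN, unchanged
    maybe_convert_to_schema_field_urn("urn:li:glossaryTerm:finance.revenue", "user_id")
    # -> raises ValueError, since a column only makes sense on a dataset URN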
+ def _select_results_within_budget(
+     results: Iterator[T],
+     fetch_entity: Callable[[T], dict],
+     max_results: int = 10,
+     token_budget: Optional[int] = None,
+ ) -> Generator[T, None, None]:
+     """Generator that yields results within token budget.
+
+     Generic helper that works for any result structure. Caller provides a function
+     to extract/clean entity for token counting (can mutate the result).
+
+     Yields results until:
+     - max_results reached, OR
+     - token_budget would be exceeded (and we have at least 1 result)
+
+     Args:
+         results: Iterator of result objects of any type T (memory efficient)
+         fetch_entity: Function that extracts entity dict from result for token counting.
+             Can mutate the result to clean/update entity in place.
+             Signature: T -> dict (entity for token counting)
+             Example: lambda r: (r.__setitem__("entity", clean(r["entity"])), r["entity"])[1]
+         max_results: Maximum number of results to return
+         token_budget: Token budget (defaults to 90% of TOOL_RESPONSE_TOKEN_LIMIT)
+
+     Yields:
+         Original result objects of type T (possibly mutated by fetch_entity)
+     """
+     if token_budget is None:
+         # Use 90% of limit as safety buffer:
+         # - Token estimation is approximate, not exact
+         # - Response wrapper adds overhead
+         # - Better to return fewer results that fit than exceed limit
+         token_budget = int(TOOL_RESPONSE_TOKEN_LIMIT * 0.9)
+
+     total_tokens = 0
+     results_count = 0
+
+     # Consume iterator up to max_results
+     for i, result in enumerate(results):
+         if i >= max_results:
+             break
+         # Extract (and possibly clean) entity using caller's lambda
+         # Note: fetch_entity may mutate result to clean/update entity in place
+         entity = fetch_entity(result)
+
+         # Estimate token cost
+         entity_tokens = TokenCountEstimator.estimate_dict_tokens(entity)
+
+         # Check if adding this entity would exceed budget
+         if total_tokens + entity_tokens > token_budget:
+             if results_count == 0:
+                 # Always yield at least 1 result
+                 logger.warning(
+                     f"First result ({entity_tokens:,} tokens) exceeds budget ({token_budget:,}), "
+                     "yielding it anyway"
+                 )
+                 yield result  # Yield original result structure
+                 results_count += 1
+                 total_tokens += entity_tokens
+             else:
+                 # Have at least 1 result, stop here to stay within budget
+                 logger.info(
+                     f"Stopping at {results_count} results (next would exceed {token_budget:,} token budget)"
+                 )
+                 break
+         else:
+             yield result  # Yield original result structure
+             results_count += 1
+             total_tokens += entity_tokens
+
+     logger.info(
+         f"Selected {results_count} results using {total_tokens:,} tokens "
+         f"(budget: {token_budget:,})"
+     )
+
+
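A hedged usage sketch mirroring the lambda pattern from the docstring (raw_results and clean_entity are placeholders invented for this example, not names from this package):

    raw_results = iter(
        [{"entity": {"urn": f"urn:li:dataset:example_{i}", "notes": "n" * 200}} for i in range(50)]
    )

    def clean_entity(entity: dict) -> dict:
        # Stand-in for whatever per-entity cleanup the caller applies before token counting.
        return {k: v for k, v in entity.items() if v is not None}

    selected = list(
        _select_results_within_budget(
            raw_results,
            fetch_entity=lambda r: (r.__setitem__("entity", clean_entity(r["entity"])), r["entity"])[1],
            max_results=10,
        )
    )
    # At most 10 results come back, fewer if their estimated token cost would exceed the
    # ~90% TOOL_RESPONSE_TOKEN_LIMIT budget; each yielded result's entity is cleaned in place.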
+ def clean_get_entities_response(
+     raw_response: dict,
+     *,
+     sort_fn: Optional[Callable[[List[dict]], Iterator[dict]]] = None,
+     offset: int = 0,
+     limit: Optional[int] = None,
+ ) -> dict:
+     """Clean and optimize entity responses for LLM consumption.
+
+     Performs several transformations to reduce token usage while preserving essential information:
+
+     1. **Clean GraphQL artifacts**: Removes __typename, null values, empty objects/arrays
+        (via clean_gql_response)
+
+     2. **Schema field processing** (if schemaMetadata.fields exists):
+        - Sorts fields using sort_fn (defaults to _sort_fields_by_priority)
+        - Cleans each field to keep only essential properties (fieldPath, type, description, etc.)
+        - Merges editableSchemaMetadata into fields with "edited*" prefix (editedDescription,
+          editedTags, editedGlossaryTerms) - only included when they differ from system values
+        - Applies pagination (offset/limit) with token budget constraint
+        - Field selection stops when EITHER limit is reached OR ENTITY_SCHEMA_TOKEN_BUDGET is exceeded
+        - Adds schemaFieldsTruncated metadata when fields are cut
+
+     3. **Remove duplicates**: Deletes editableSchemaMetadata after merging into schemaMetadata
+
+     4. **Truncate view definitions**: Limits SQL view logic to QUERY_LENGTH_HARD_LIMIT
+
+     The result is optimized for LLM tool responses: reduced token usage, no duplication,
+     clear distinction between system-generated and user-curated content.
+
+     Args:
+         raw_response: Raw entity dict from GraphQL query
+         sort_fn: Optional custom function to sort fields. If None, uses _sort_fields_by_priority.
+             Should take a list of field dicts and return an iterator of sorted fields.
+         offset: Number of fields to skip after sorting (default: 0)
+         limit: Maximum number of fields to include after offset (default: None = unlimited)
+
+     Returns:
+         Cleaned entity dict optimized for LLM consumption
+     """
+     from datahub_agent_context.mcp_tools.base import clean_gql_response
+
+     response = clean_gql_response(raw_response)
+
+     if response and (schema_metadata := response.get("schemaMetadata")):
+         # Remove empty platformSchema to reduce response clutter
+         if platform_schema := schema_metadata.get("platformSchema"):
+             schema_value = platform_schema.get("schema")
+             if not schema_value or schema_value == "":
+                 del schema_metadata["platformSchema"]
+
+         # Process schema fields with sorting and budget constraint
+         if fields := schema_metadata.get("fields"):
+             # Use custom sort function if provided, otherwise sort by priority
+             sorted_fields = iter(fields) if sort_fn is None else sort_fn(fields)
+
+             # Apply offset/limit with token budget constraint
+             selected_fields: list[dict] = []
+             total_schema_tokens = 0
+             fields_truncated = False
+
+             for i, field in enumerate(sorted_fields):
+                 # Skip fields before offset
+                 if i < offset:
+                     continue
+
+                 # Check limit
+                 if limit is not None and len(selected_fields) >= limit:
+                     fields_truncated = True
+                     break
+
+                 # Estimate tokens for this field
+                 field_tokens = TokenCountEstimator.estimate_dict_tokens(field)
+
+                 # Check token budget
+                 if total_schema_tokens + field_tokens > ENTITY_SCHEMA_TOKEN_BUDGET:
+                     if len(selected_fields) == 0:
+                         # Always include at least one field
+                         logger.warning(
+                             f"First field ({field_tokens:,} tokens) exceeds schema budget "
+                             f"({ENTITY_SCHEMA_TOKEN_BUDGET:,}), including it anyway"
+                         )
+                         selected_fields.append(field)
+                         total_schema_tokens += field_tokens
+                     fields_truncated = True
+                     break
+
+                 selected_fields.append(field)
+                 total_schema_tokens += field_tokens
+
+             schema_metadata["fields"] = selected_fields
+
+             if fields_truncated:
+                 schema_metadata["schemaFieldsTruncated"] = True
+                 logger.info(
+                     f"Truncated schema fields: showing {len(selected_fields)} of {len(fields)} fields "
+                     f"({total_schema_tokens:,} tokens)"
+                 )
+
+     # Remove editableSchemaMetadata if present (already merged into fields)
+     if response and "editableSchemaMetadata" in response:
+         del response["editableSchemaMetadata"]
+
+     # Truncate view definitions
+     if response and (view_properties := response.get("viewProperties")):
+         if logic := view_properties.get("logic"):
+             view_properties["logic"] = truncate_query(logic)
+
+     return response
+
+
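A sketch of paginating a wide schema with this helper (raw_entity is a simplified stand-in for a real GraphQL entity payload, and the sketch assumes clean_gql_response does not mutate its input; if it does, re-fetch the entity between pages):

    raw_entity = {
        "urn": "urn:li:dataset:(urn:li:dataPlatform:snowflake,db.orders,PROD)",
        "__typename": "Dataset",
        "schemaMetadata": {
            "fields": [{"fieldPath": f"col_{i}", "type": "STRING"} for i in range(500)],
        },
    }

    page_1 = clean_get_entities_response(raw_entity, offset=0, limit=100)
    # page_1["schemaMetadata"]["fields"] holds at most 100 fields (fewer if
    # ENTITY_SCHEMA_TOKEN_BUDGET is hit first) and schemaFieldsTruncated is set to True.
    page_2 = clean_get_entities_response(raw_entity, offset=100, limit=100)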
+ def _extract_lineage_columns_from_paths(search_results: List[dict]) -> List[dict]:
+     """Extract column information from paths field for column-level lineage results.
+
+     When querying column-level lineage (e.g., get_lineage(urn, column="user_id")),
+     the GraphQL response returns DATASET entities (not individual columns) with a
+     'paths' field containing the column-level lineage chains.
+
+     Each path shows the column flow, e.g.:
+         source_table.user_id -> intermediate_table.uid -> target_table.customer_id
+
+     The LAST entity in each path is a SchemaFieldEntity representing a column in the
+     target dataset. This function extracts those column names into a 'lineageColumns' field.
+
+     Args:
+         search_results: List of lineage search results where entities are DATASET
+
+     Returns:
+         Same list with 'lineageColumns' field added to each result:
+         - entity: Dataset entity (unchanged)
+         - lineageColumns: List of unique column names (fieldPath) from path endpoints
+         - degree: Degree value (unchanged)
+         - paths: Removed to reduce response size (column info extracted to lineageColumns)
+
+     Example transformation:
+         Input: [
+             {
+                 entity: {type: "DATASET", name: "target_table"},
+                 paths: [
+                     {path: [
+                         {type: "SCHEMA_FIELD", fieldPath: "user_id"},
+                         {type: "SCHEMA_FIELD", fieldPath: "customer_id"}  # <- target column
+                     ]},
+                     {path: [
+                         {type: "SCHEMA_FIELD", fieldPath: "user_id"},
+                         {type: "SCHEMA_FIELD", fieldPath: "uid"}  # <- another target column
+                     ]}
+                 ],
+                 degree: 1
+             }
+         ]
+         Output: [
+             {
+                 entity: {type: "DATASET", name: "target_table"},
+                 lineageColumns: ["customer_id", "uid"],
+                 degree: 1
+             }
+         ]
+     """
+     if not search_results:
+         return search_results
+
+     # Check if this is column-level lineage by looking for paths
+     has_column_paths = any(
+         result.get("paths") and len(result.get("paths", [])) > 0
+         for result in search_results
+     )
+
+     if not has_column_paths:
+         # Not column-level lineage (or no paths available), return as-is
+         return search_results
+
+     processed_results = []
+     for result in search_results:
+         paths = result.get("paths", [])
+
+         if not paths:
+             # No paths for this result, keep as-is
+             processed_results.append(result)
+             continue
+
+         # Extract column names from the LAST entity in each path
+         lineage_columns = []
+         for path_obj in paths:
+             path = path_obj.get("path", [])
+             if not path:
+                 continue
+
+             # Get the last entity in the path (target column)
+             last_entity = path[-1]
+             if last_entity.get("type") == "SCHEMA_FIELD":
+                 field_path = last_entity.get("fieldPath")
+                 if field_path and field_path not in lineage_columns:
+                     lineage_columns.append(field_path)
+
+         # Create new result with lineageColumns
+         new_result = {
+             "entity": result["entity"],
+             "degree": result.get("degree", 0),
+         }
+
+         if lineage_columns:
+             new_result["lineageColumns"] = lineage_columns
+
+         # Keep other fields that might exist (explored, truncatedChildren, etc.)
+         for key in ["explored", "truncatedChildren", "ignoredAsHop"]:
+             if key in result:
+                 new_result[key] = result[key]
+
+         processed_results.append(new_result)
+
+     return processed_results