glean-indexing-sdk 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glean/indexing/__init__.py +56 -0
- glean/indexing/common/__init__.py +15 -0
- glean/indexing/common/batch_processor.py +31 -0
- glean/indexing/common/content_formatter.py +46 -0
- glean/indexing/common/glean_client.py +18 -0
- glean/indexing/common/metrics.py +54 -0
- glean/indexing/common/mocks.py +20 -0
- glean/indexing/connectors/__init__.py +21 -0
- glean/indexing/connectors/base_connector.py +60 -0
- glean/indexing/connectors/base_data_client.py +35 -0
- glean/indexing/connectors/base_datasource_connector.py +314 -0
- glean/indexing/connectors/base_people_connector.py +154 -0
- glean/indexing/connectors/base_streaming_data_client.py +39 -0
- glean/indexing/connectors/base_streaming_datasource_connector.py +184 -0
- glean/indexing/models.py +45 -0
- glean/indexing/observability/__init__.py +19 -0
- glean/indexing/observability/observability.py +262 -0
- glean/indexing/py.typed +1 -0
- glean/indexing/testing/__init__.py +13 -0
- glean/indexing/testing/connector_test_harness.py +53 -0
- glean/indexing/testing/mock_data_source.py +47 -0
- glean/indexing/testing/mock_glean_client.py +69 -0
- glean/indexing/testing/response_validator.py +52 -0
- glean_indexing_sdk-0.0.3.dist-info/METADATA +482 -0
- glean_indexing_sdk-0.0.3.dist-info/RECORD +27 -0
- glean_indexing_sdk-0.0.3.dist-info/WHEEL +4 -0
- glean_indexing_sdk-0.0.3.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""Observability infrastructure for Glean connectors."""
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
# Type variable for decorated classes
|
|
12
|
+
T = TypeVar("T")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ConnectorObservability:
    """
    Centralized observability for connector operations.

    Tracks metrics, performance, and provides structured logging.
    """

    def __init__(self, connector_name: str):
        """Initialize observability state for a named connector.

        Args:
            connector_name: Name used to tag log lines for this connector.
        """
        self.connector_name = connector_name
        # Counters default to 0 so increment_counter works on unseen keys.
        self.metrics: Dict[str, Any] = defaultdict(int)
        # Maps operation name -> start timestamp for in-flight timers.
        self.timers: Dict[str, float] = {}
        self.start_time: Optional[float] = None

    def start_execution(self) -> None:
        """Mark the start of connector execution."""
        self.start_time = time.time()
        logger.info(f"[{self.connector_name}] Execution started")

    def end_execution(self) -> Optional[float]:
        """Mark the end of connector execution.

        Returns:
            The total execution duration in seconds, or None when
            start_execution() was never called.
        """
        # Compare to None explicitly: a 0.0 timestamp is falsy but valid.
        if self.start_time is not None:
            duration = time.time() - self.start_time
            self.metrics["total_execution_time"] = duration
            logger.info(f"[{self.connector_name}] Execution completed in {duration:.2f}s")
            return duration
        return None

    def record_metric(self, key: str, value: Any) -> None:
        """Record a custom metric (overwrites any previous value for the key)."""
        self.metrics[key] = value
        logger.debug(f"[{self.connector_name}] Metric recorded: {key}={value}")

    def increment_counter(self, key: str, value: int = 1) -> None:
        """Increment a counter metric by ``value`` (defaults to 1)."""
        self.metrics[key] += value

    def start_timer(self, operation: str) -> None:
        """Start timing an operation."""
        self.timers[operation] = time.time()

    def end_timer(self, operation: str) -> Optional[float]:
        """End timing an operation and record the duration.

        Args:
            operation: The operation name passed to start_timer().

        Returns:
            The elapsed seconds, or None when no matching start_timer()
            call exists.
        """
        # pop() removes the in-flight timer and handles the missing-key case.
        started = self.timers.pop(operation, None)
        if started is None:
            return None
        duration = time.time() - started
        self.record_metric(f"{operation}_duration", duration)
        return duration

    def get_metrics_summary(self) -> Dict[str, Any]:
        """Get a plain-dict snapshot of all collected metrics."""
        return dict(self.metrics)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def with_observability(
    exclude_methods: Optional[List[str]] = None,
    include_args: bool = False,
    include_return: bool = False,
) -> Callable[[type], type]:
    """
    Class decorator that adds comprehensive logging to all public methods.

    Args:
        exclude_methods: List of method names to exclude from logging
        include_args: Whether to log method arguments
        include_return: Whether to log return values

    Returns:
        Decorated class with enhanced logging
    """
    excluded = ["__init__", "__str__", "__repr__"] if exclude_methods is None else exclude_methods

    def decorator(cls: type) -> type:
        def build_wrapper(target: Callable[..., Any]) -> Callable[..., Any]:
            # Explicitly excluded methods pass through untouched.
            if target.__name__ in excluded:
                return target

            @functools.wraps(target)
            def instrumented(self, *args: Any, **kwargs: Any) -> Any:
                method_name = target.__name__
                class_name = self.__class__.__name__

                # Announce the call, optionally with its arguments.
                if include_args:
                    logger.info(
                        f"[{class_name}] {method_name} started with args={args}, kwargs={kwargs}"
                    )
                else:
                    logger.info(f"[{class_name}] {method_name} started")

                began = time.time()

                try:
                    outcome = target(self, *args, **kwargs)
                except Exception as e:
                    elapsed = time.time() - began
                    logger.error(f"[{class_name}] {method_name} failed after {elapsed:.3f}s: {e}")

                    # Count the failure when the instance carries an observability sink.
                    if hasattr(self, "_observability"):
                        self._observability.increment_counter(f"{method_name}_errors")

                    raise

                elapsed = time.time() - began

                # Announce successful completion, optionally with the result.
                if include_return:
                    logger.info(
                        f"[{class_name}] {method_name} completed in {elapsed:.3f}s with result={outcome}"
                    )
                else:
                    logger.info(f"[{class_name}] {method_name} completed in {elapsed:.3f}s")

                if hasattr(self, "_observability"):
                    self._observability.record_metric(f"{method_name}_duration", elapsed)

                return outcome

            return instrumented

        # Snapshot the public callables first, then rebind the wrapped versions.
        public_members = [
            (name, value)
            for name, value in cls.__dict__.items()
            if callable(value) and not name.startswith("_")
        ]
        for name, value in public_members:
            setattr(cls, name, build_wrapper(value))

        return cls

    return decorator
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def track_crawl_progress(method: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator that tracks crawling progress and item counts.

    Expects the method to return a sequence and increments crawl metrics.
    """

    @functools.wraps(method)
    def tracked(self, *args: Any, **kwargs: Any) -> Any:
        outcome = method(self, *args, **kwargs)

        # Only sized results (anything supporting len()) contribute to metrics;
        # generators and other unsized results pass through silently.
        if hasattr(outcome, "__len__"):
            count = len(outcome)
            if hasattr(self, "_observability"):
                self._observability.increment_counter("items_processed", count)
                self._observability.increment_counter("total_items_crawled", count)
            logger.info(f"Processed {count} items in {method.__name__}")

        return outcome

    return tracked
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class PerformanceTracker:
    """
    Context manager for tracking performance of operations.
    """

    def __init__(
        self,
        operation_name: str,
        observability: Optional["ConnectorObservability"] = None,
    ):
        """Initialize the tracker.

        Args:
            operation_name: Label used in log lines and metric keys.
            observability: Optional sink that receives duration/error metrics.
        """
        self.operation_name = operation_name
        self.observability = observability
        self.start_time: Optional[float] = None

    def __enter__(self) -> "PerformanceTracker":
        self.start_time = time.time()
        logger.info(f"Starting operation: {self.operation_name}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Compare to None explicitly: a 0.0 timestamp is falsy but valid.
        if self.start_time is not None:
            duration = time.time() - self.start_time

            if exc_type is None:
                logger.info(f"Operation '{self.operation_name}' completed in {duration:.3f}s")
            else:
                logger.error(
                    f"Operation '{self.operation_name}' failed after {duration:.3f}s: {exc_val}"
                )

            if self.observability:
                self.observability.record_metric(f"{self.operation_name}_duration", duration)
                if exc_type is not None:
                    self.observability.increment_counter(f"{self.operation_name}_errors")
        # Returning None (falsy) propagates any exception to the caller.
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class ProgressCallback:
    """
    Callback interface for tracking connector progress.
    """

    def __init__(self, total_items: Optional[int] = None):
        """Initialize progress tracking.

        Args:
            total_items: Expected total item count, when known; enables
                percentage reporting in update().
        """
        self.total_items = total_items
        self.processed_items = 0
        self.start_time = time.time()

    def _rate(self) -> float:
        """Items/sec since construction, safe against zero elapsed time."""
        elapsed = time.time() - self.start_time
        # Coarse clocks can report 0 elapsed time for an immediate call;
        # guard against ZeroDivisionError.
        return self.processed_items / elapsed if elapsed > 0 else 0.0

    def update(self, items_processed: int) -> None:
        """Update progress with number of items processed.

        Args:
            items_processed: Count of newly processed items to add.
        """
        self.processed_items += items_processed

        if self.total_items:
            progress_pct = (self.processed_items / self.total_items) * 100
            logger.info(
                f"Progress: {self.processed_items}/{self.total_items} ({progress_pct:.1f}%) - "
                f"Rate: {self._rate():.1f} items/sec"
            )
        else:
            logger.info(
                f"Progress: {self.processed_items} items processed - "
                f"Rate: {self._rate():.1f} items/sec"
            )

    def complete(self) -> None:
        """Mark progress as complete and log the overall throughput."""
        elapsed = time.time() - self.start_time
        logger.info(
            f"Processing complete: {self.processed_items} items in {elapsed:.2f}s "
            f"(avg rate: {self._rate():.1f} items/sec)"
        )
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def setup_connector_logging(
    connector_name: str, log_level: str = "INFO", log_format: Optional[str] = None
):
    """
    Set up standardized logging for a connector.

    Args:
        connector_name: Name of the connector for log identification
        log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_format: Custom log format string
    """
    # Default format embeds the connector name so every line is attributable.
    fmt = log_format
    if fmt is None:
        fmt = f"%(asctime)s - {connector_name} - %(name)s - %(levelname)s - %(message)s"

    resolved_level = getattr(logging, log_level.upper())
    stream_handler = logging.StreamHandler()
    # Add file handler if needed
    # logging.FileHandler(f"{connector_name}.log")

    logging.basicConfig(
        level=resolved_level,
        format=fmt,
        handlers=[stream_handler],
    )

    logger.info(f"Logging configured for connector: {connector_name}")
|
glean/indexing/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Testing utilities for Glean connectors."""
|
|
2
|
+
|
|
3
|
+
from glean.indexing.testing.connector_test_harness import ConnectorTestHarness
|
|
4
|
+
from glean.indexing.testing.mock_data_source import MockDataSource
|
|
5
|
+
from glean.indexing.testing.mock_glean_client import MockGleanClient
|
|
6
|
+
from glean.indexing.testing.response_validator import ResponseValidator
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"ConnectorTestHarness",
|
|
10
|
+
"MockDataSource",
|
|
11
|
+
"MockGleanClient",
|
|
12
|
+
"ResponseValidator",
|
|
13
|
+
]
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Test harness for running and validating connectors."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from unittest.mock import patch
|
|
5
|
+
|
|
6
|
+
from glean.indexing.connectors import BaseConnector, BaseDatasourceConnector, BasePeopleConnector
|
|
7
|
+
from glean.indexing.testing.mock_glean_client import MockGleanClient
|
|
8
|
+
from glean.indexing.testing.response_validator import ResponseValidator
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ConnectorTestHarness:
    """Test harness for connectors that works with the new dependency injection pattern."""

    def __init__(self, connector: BaseConnector):
        """Initialize the ConnectorTestHarness.

        Args:
            connector: The connector to test.
        """
        self.connector = connector
        self.validator = ResponseValidator()
        self.mock_client = MockGleanClient(self.validator)

    def run(self) -> None:
        """Run the connector."""
        logger.info(f"Running test harness for connector '{self.connector.name}'")

        # Start from a clean slate so repeated runs do not accumulate results.
        self.validator.reset()

        # Patch the api_client to return our mock client
        datasource_patch = patch("glean.indexing.connectors.base_datasource_connector.api_client")
        people_patch = patch("glean.indexing.connectors.base_people_connector.api_client")
        with datasource_patch as mock_api_client:
            with people_patch as mock_people_api_client:
                # Both patched factories act as context managers yielding the mock.
                mock_api_client.return_value.__enter__.return_value = self.mock_client
                mock_people_api_client.return_value.__enter__.return_value = self.mock_client

                # Run the connector for any supported type
                if isinstance(self.connector, (BaseDatasourceConnector, BasePeopleConnector)):
                    self.connector.index_data()
                else:
                    raise ValueError(f"Unsupported connector type: {type(self.connector)}")

    def get_validator(self) -> ResponseValidator:
        """Get the response validator for checking results."""
        return self.validator
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""Mock data source for testing connectors."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MockDataSource:
    """Mock data source for testing."""

    def __init__(
        self,
        all_items: Optional[List[Dict[str, Any]]] = None,
        modified_items: Optional[List[Dict[str, Any]]] = None,
    ):
        """Initialize the MockDataSource.

        Args:
            all_items: Items to return for get_all_items.
            modified_items: Items to return for get_modified_items.
        """
        # Falsy inputs (None or empty) are replaced with a fresh empty list.
        self.all_items = all_items if all_items else []
        self.modified_items = modified_items if modified_items else []

    def get_all_items(self) -> List[Dict[str, Any]]:
        """Get all items.

        Returns:
            A list of all items.
        """
        logger.info(f"MockDataSource.get_all_items() returning {len(self.all_items)} items")
        return self.all_items

    def get_modified_items(self, since: str) -> List[Dict[str, Any]]:
        """Get modified items.

        Args:
            since: Timestamp to filter by.

        Returns:
            A list of modified items.
        """
        logger.info(
            f"MockDataSource.get_modified_items(since={since}) returning {len(self.modified_items)} items"
        )
        return self.modified_items
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Mock Glean API client for testing."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from glean.api_client.models import DocumentDefinition, EmployeeInfoDefinition
|
|
7
|
+
from glean.indexing.testing.response_validator import ResponseValidator
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class MockGleanClient:
    """Mock Glean API client for testing that matches the new GleanClient interface."""

    def __init__(self, validator: ResponseValidator):
        """Initialize the MockGleanClient.

        Args:
            validator: Validator to record posted items.
        """
        self.validator = validator

    def index_documents(
        self,
        datasource: str,
        documents: List[DocumentDefinition],
        upload_id: Optional[str] = None,
        is_first_page: bool = True,
        is_last_page: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
        """Mock method for indexing documents (new interface).

        Args:
            datasource: The datasource name.
            documents: The documents to index.
            upload_id: Optional upload ID for batch tracking
            is_first_page: Whether this is the first page of a multi-page upload
            is_last_page: Whether this is the last page of a multi-page upload
            **kwargs: Additional parameters

        Returns:
            Mock API response
        """
        count = len(documents)
        logger.info(f"Mock indexing {count} documents to datasource '{datasource}'")
        # Record everything posted so tests can assert on it afterwards.
        self.validator.documents_posted.extend(documents)
        return {"status": "success", "indexed": count}

    def index_employees(self, employees: List[EmployeeInfoDefinition], **kwargs) -> Dict[str, Any]:
        """Mock method for indexing employees (new interface).

        Args:
            employees: The employees to index.
            **kwargs: Additional parameters

        Returns:
            Mock API response
        """
        count = len(employees)
        logger.info(f"Mock indexing {count} employees")
        self.validator.employees_posted.extend(employees)
        return {"status": "success", "indexed": count}

    def batch_index_documents(self, datasource: str, documents: List[DocumentDefinition]) -> None:
        """Legacy method for indexing documents."""
        # Delegates to the new-interface method with default paging flags.
        self.index_documents(datasource=datasource, documents=documents)

    def bulk_index_employees(self, employees: List[EmployeeInfoDefinition]) -> None:
        """Legacy method for indexing employees."""
        # Delegates to the new-interface method.
        self.index_employees(employees=employees)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Response validator for testing connector outputs."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import List, Optional
|
|
5
|
+
|
|
6
|
+
from glean.api_client.models import DocumentDefinition, EmployeeInfoDefinition
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ResponseValidator:
    """Validator for connector responses."""

    def __init__(self):
        """Initialize the ResponseValidator."""
        # Accumulators filled in by MockGleanClient during harness runs.
        self.documents_posted: List[DocumentDefinition] = []
        self.employees_posted: List[EmployeeInfoDefinition] = []

    def assert_documents_posted(self, count: Optional[int] = None) -> None:
        """Assert that documents were posted.

        Args:
            count: Optional expected count of documents.
        """
        posted = len(self.documents_posted)
        if count is None:
            # No expectation given: any non-empty result passes.
            assert posted > 0, "No documents were posted"
        else:
            assert posted == count, (
                f"Expected {count} documents to be posted, but got {posted}"
            )

        logger.info(f"Validated {posted} documents posted")

    def assert_employees_posted(self, count: Optional[int] = None) -> None:
        """Assert that employees were posted.

        Args:
            count: Optional expected count of employees.
        """
        posted = len(self.employees_posted)
        if count is None:
            assert posted > 0, "No employees were posted"
        else:
            assert posted == count, (
                f"Expected {count} employees to be posted, but got {posted}"
            )

        logger.info(f"Validated {posted} employees posted")

    def reset(self) -> None:
        """Reset the validator state."""
        self.documents_posted.clear()
        self.employees_posted.clear()
|