qtype 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qtype/application/__init__.py +12 -0
- qtype/application/commons/__init__.py +7 -0
- qtype/{converters → application/converters}/tools_from_module.py +2 -2
- qtype/application/converters/types.py +33 -0
- qtype/{dsl/document.py → application/documentation.py} +2 -0
- qtype/application/facade.py +160 -0
- qtype/base/__init__.py +14 -0
- qtype/base/exceptions.py +49 -0
- qtype/base/logging.py +39 -0
- qtype/base/types.py +29 -0
- qtype/commands/convert.py +64 -49
- qtype/commands/generate.py +59 -4
- qtype/commands/run.py +109 -72
- qtype/commands/serve.py +42 -28
- qtype/commands/validate.py +25 -42
- qtype/commands/visualize.py +51 -37
- qtype/dsl/__init__.py +9 -0
- qtype/dsl/base_types.py +8 -0
- qtype/dsl/custom_types.py +6 -4
- qtype/dsl/model.py +185 -50
- qtype/dsl/validator.py +9 -4
- qtype/interpreter/api.py +96 -40
- qtype/interpreter/auth/__init__.py +3 -0
- qtype/interpreter/auth/aws.py +234 -0
- qtype/interpreter/auth/cache.py +67 -0
- qtype/interpreter/auth/generic.py +103 -0
- qtype/interpreter/batch/flow.py +95 -0
- qtype/interpreter/batch/sql_source.py +95 -0
- qtype/interpreter/batch/step.py +63 -0
- qtype/interpreter/batch/types.py +41 -0
- qtype/interpreter/batch/utils.py +179 -0
- qtype/interpreter/conversions.py +21 -10
- qtype/interpreter/resource_cache.py +4 -2
- qtype/interpreter/steps/decoder.py +13 -9
- qtype/interpreter/steps/llm_inference.py +7 -9
- qtype/interpreter/steps/prompt_template.py +1 -1
- qtype/interpreter/streaming_helpers.py +3 -3
- qtype/interpreter/typing.py +47 -11
- qtype/interpreter/ui/404/index.html +1 -1
- qtype/interpreter/ui/404.html +1 -1
- qtype/interpreter/ui/index.html +1 -1
- qtype/interpreter/ui/index.txt +1 -1
- qtype/loader.py +15 -16
- qtype/semantic/generate.py +91 -39
- qtype/semantic/model.py +183 -52
- qtype/semantic/resolver.py +4 -4
- {qtype-0.0.10.dist-info → qtype-0.0.12.dist-info}/METADATA +5 -1
- {qtype-0.0.10.dist-info → qtype-0.0.12.dist-info}/RECORD +58 -44
- qtype/commons/generate.py +0 -93
- qtype/converters/types.py +0 -66
- qtype/semantic/errors.py +0 -4
- /qtype/{commons → application/commons}/tools.py +0 -0
- /qtype/{commons → application/converters}/__init__.py +0 -0
- /qtype/{converters → application/converters}/tools_from_api.py +0 -0
- /qtype/{converters → interpreter/batch}/__init__.py +0 -0
- /qtype/interpreter/ui/_next/static/{Jb2murBlt2XkN6punrQbE → OT8QJQW3J70VbDWWfrEMT}/_buildManifest.js +0 -0
- /qtype/interpreter/ui/_next/static/{Jb2murBlt2XkN6punrQbE → OT8QJQW3J70VbDWWfrEMT}/_ssgManifest.js +0 -0
- {qtype-0.0.10.dist-info → qtype-0.0.12.dist-info}/WHEEL +0 -0
- {qtype-0.0.10.dist-info → qtype-0.0.12.dist-info}/entry_points.txt +0 -0
- {qtype-0.0.10.dist-info → qtype-0.0.12.dist-info}/licenses/LICENSE +0 -0
- {qtype-0.0.10.dist-info → qtype-0.0.12.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AWS authentication context manager for QType interpreter.
|
|
3
|
+
|
|
4
|
+
This module provides a context manager for creating boto3 sessions using
|
|
5
|
+
AWSAuthProvider configuration from the semantic model.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from contextlib import contextmanager
|
|
11
|
+
from typing import Any, Generator
|
|
12
|
+
|
|
13
|
+
import boto3 # type: ignore[import-untyped]
|
|
14
|
+
from botocore.exceptions import ( # type: ignore[import-untyped]
|
|
15
|
+
ClientError,
|
|
16
|
+
NoCredentialsError,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from qtype.interpreter.auth.cache import cache_auth, get_cached_auth
|
|
20
|
+
from qtype.semantic.model import AWSAuthProvider
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AWSAuthenticationError(Exception):
    """Raised when AWS authentication fails.

    Wraps lower-level boto3/botocore failures so callers can catch a single
    exception type for any authentication problem.
    """
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_session_valid(session: boto3.Session) -> bool:
    """Return True when *session* can still produce usable AWS credentials.

    Temporary credentials (those carrying a session token) are verified by
    issuing an STS ``get_caller_identity`` call; long-lived credentials are
    considered valid as long as they resolve at all.

    Args:
        session: The boto3 session to check.

    Returns:
        bool: True if the session's credentials are usable, False otherwise.
    """
    try:
        creds = session.get_credentials()
        if creds is None:
            return False

        # Temporary credentials expire; probe STS to confirm they still work.
        if getattr(creds, "token", None):
            session.client("sts").get_caller_identity()

        return True
    except (ClientError, NoCredentialsError):
        return False
    except Exception:
        # Any other failure: treat the session as unusable.
        return False
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@contextmanager
def aws(aws_provider: AWSAuthProvider) -> Generator[boto3.Session, None, None]:
    """Yield a boto3 Session configured from *aws_provider*.

    Supported authentication methods:
      - explicit credentials (access key + secret key + optional session token)
      - a named profile from the shared AWS config/credentials files
      - role assumption (with optional external ID), layered on the above
      - the ambient environment, when nothing explicit is configured

    Sessions are cached per provider configuration (LRU; size set by the
    ``AUTH_CACHE_MAX_SIZE`` environment variable, default 128). A cached
    session is re-validated before reuse and rebuilt when it has expired.

    Args:
        aws_provider: AWSAuthProvider instance containing authentication configuration

    Yields:
        boto3.Session: Configured boto3 session ready for creating AWS service clients

    Raises:
        AWSAuthenticationError: When authentication fails or configuration is invalid

    Example:
        ```python
        from qtype.semantic.model import AWSAuthProvider
        from qtype.interpreter.auth.aws import aws

        aws_auth = AWSAuthProvider(
            id="my-aws-auth",
            type="aws",
            access_key_id="AKIA...",
            secret_access_key="...",
            region="us-east-1",
        )

        with aws(aws_auth) as session:
            athena_client = session.client("athena")
            s3_client = session.client("s3")
        ```
    """
    try:
        # Fast path: reuse a previously built session when it is still valid.
        existing = get_cached_auth(aws_provider)
        if existing is not None and _is_session_valid(existing):
            yield existing
            return

        # Cache miss (or stale entry): build a fresh session.
        fresh = _create_session(aws_provider)

        # A session that cannot resolve credentials is useless; fail loudly.
        if fresh.get_credentials() is None:
            raise AWSAuthenticationError(
                f"Failed to obtain AWS credentials for provider '{aws_provider.id}'"
            )

        cache_auth(aws_provider, fresh)
        yield fresh

    except (ClientError, NoCredentialsError) as e:
        raise AWSAuthenticationError(
            f"AWS authentication failed for provider '{aws_provider.id}': {e}"
        ) from e
    except Exception as e:
        raise AWSAuthenticationError(
            f"Unexpected error during AWS authentication for provider '{aws_provider.id}': {e}"
        ) from e
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _create_session(aws_provider: AWSAuthProvider) -> boto3.Session:
    """Build a boto3 Session from the provider's configuration.

    Precedence: a named profile wins over explicit keys; explicit keys win
    over the ambient environment. Role assumption, when configured, is
    applied on top of whichever base session results.

    Args:
        aws_provider: AWSAuthProvider with authentication details

    Returns:
        boto3.Session: Configured session

    Raises:
        AWSAuthenticationError: If configuration is invalid
    """
    kwargs: dict[str, Any] = {}

    if aws_provider.region:
        kwargs["region_name"] = aws_provider.region

    if aws_provider.profile_name:
        # Profile from the shared AWS credentials/config files.
        kwargs["profile_name"] = aws_provider.profile_name
    elif aws_provider.access_key_id and aws_provider.secret_access_key:
        # Explicit static credentials.
        kwargs["aws_access_key_id"] = aws_provider.access_key_id
        kwargs["aws_secret_access_key"] = aws_provider.secret_access_key
        if aws_provider.session_token:
            kwargs["aws_session_token"] = aws_provider.session_token

    base = boto3.Session(**kwargs)

    # Optionally exchange the base credentials for assumed-role credentials.
    if aws_provider.role_arn:
        return _assume_role_session(base, aws_provider)
    return base
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _assume_role_session(
    base_session: boto3.Session, aws_provider: AWSAuthProvider
) -> boto3.Session:
    """Exchange *base_session*'s credentials for assumed-role credentials.

    Args:
        base_session: Session whose credentials are used to call STS
        aws_provider: AWSAuthProvider with role configuration

    Returns:
        boto3.Session: New session with assumed role credentials

    Raises:
        AWSAuthenticationError: If role assumption fails
    """
    if not aws_provider.role_arn:
        raise AWSAuthenticationError(
            "role_arn is required for role assumption"
        )

    try:
        params: dict[str, Any] = {
            "RoleArn": aws_provider.role_arn,
            # Fall back to a deterministic session name derived from the id.
            "RoleSessionName": aws_provider.role_session_name
            or f"qtype-session-{aws_provider.id}",
        }
        if aws_provider.external_id:
            params["ExternalId"] = aws_provider.external_id

        sts = base_session.client("sts")
        temp_creds = sts.assume_role(**params)["Credentials"]

        return boto3.Session(
            aws_access_key_id=temp_creds["AccessKeyId"],
            aws_secret_access_key=temp_creds["SecretAccessKey"],
            aws_session_token=temp_creds["SessionToken"],
            # Prefer the provider's region; otherwise inherit the base one.
            region_name=aws_provider.region or base_session.region_name,
        )

    except ClientError as e:
        error_code = e.response.get("Error", {}).get("Code", "Unknown")
        raise AWSAuthenticationError(
            f"Failed to assume role '{aws_provider.role_arn}': {error_code} - {e}"
        ) from e
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Authorization cache for QType interpreter.
|
|
3
|
+
|
|
4
|
+
This module provides a shared LRU cache for authorization sessions and tokens
|
|
5
|
+
across different authentication providers (AWS, OAuth2, API keys, etc.).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from cachetools import LRUCache
|
|
14
|
+
|
|
15
|
+
# Global LRU cache for authorization sessions with configurable size
|
|
16
|
+
_AUTH_CACHE_MAX_SIZE = int(os.environ.get("AUTH_CACHE_MAX_SIZE", 128))
|
|
17
|
+
_AUTHORIZATION_CACHE: LRUCache[Any, Any] = LRUCache(
|
|
18
|
+
maxsize=_AUTH_CACHE_MAX_SIZE
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_cached_auth(auth_provider: Any) -> Any | None:
    """Look up a previously cached session/token for *auth_provider*.

    Args:
        auth_provider: Authorization provider instance (must be hashable)

    Returns:
        The cached session/token, or None on a cache miss.
    """
    return _AUTHORIZATION_CACHE.get(auth_provider)
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def cache_auth(auth_provider: Any, session: Any) -> None:
    """Store *session* in the shared LRU cache, keyed by *auth_provider*.

    Args:
        auth_provider: Authorization provider instance (must be hashable)
        session: Session or token to cache
    """
    _AUTHORIZATION_CACHE[auth_provider] = session
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def clear_auth_cache() -> None:
    """Drop every cached authorization session.

    Useful in tests, or after credential configurations change at runtime.
    """
    _AUTHORIZATION_CACHE.clear()
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def get_cache_info() -> dict[str, Any]:
    """Report the authorization cache's configuration and current state.

    Returns:
        Dictionary with cache statistics and configuration. ``hits`` and
        ``misses`` fall back to 0 when the underlying cache does not track
        them (a plain ``cachetools.LRUCache`` does not expose counters).
    """
    info: dict[str, Any] = {
        "max_size": _AUTH_CACHE_MAX_SIZE,
        "current_size": len(_AUTHORIZATION_CACHE),
    }
    # Counter attributes exist only on some cachetools subclasses.
    info["hits"] = getattr(_AUTHORIZATION_CACHE, "hits", 0)
    info["misses"] = getattr(_AUTHORIZATION_CACHE, "misses", 0)
    return info
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Generic authorization context manager for QType interpreter.
|
|
3
|
+
|
|
4
|
+
This module provides a unified context manager that can handle any AuthorizationProvider
|
|
5
|
+
type and return the appropriate session or provider instance.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from contextlib import contextmanager
|
|
11
|
+
from typing import Generator
|
|
12
|
+
|
|
13
|
+
import boto3 # type: ignore[import-untyped]
|
|
14
|
+
|
|
15
|
+
from qtype.interpreter.auth.aws import aws
|
|
16
|
+
from qtype.semantic.model import (
|
|
17
|
+
APIKeyAuthProvider,
|
|
18
|
+
AuthorizationProvider,
|
|
19
|
+
AWSAuthProvider,
|
|
20
|
+
OAuth2AuthProvider,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class UnsupportedAuthProviderError(Exception):
    """Raised when an authorization provider type has no handler."""
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@contextmanager
def auth(
    auth_provider: AuthorizationProvider,
) -> Generator[boto3.Session | APIKeyAuthProvider, None, None]:
    """Dispatch to the appropriate authentication handler for *auth_provider*.

    Handlers by provider type:
      - AWSAuthProvider: yields a configured ``boto3.Session``
      - APIKeyAuthProvider: yields the provider itself (it carries the key)
      - OAuth2AuthProvider: not yet supported

    Args:
        auth_provider: AuthorizationProvider instance of any supported type

    Yields:
        boto3.Session | APIKeyAuthProvider: The appropriate session or provider instance

    Raises:
        UnsupportedAuthProviderError: When an unsupported provider type is used
        NotImplementedError: When OAuth2AuthProvider is used (not yet implemented)

    Example:
        ```python
        # AWS provider - returns boto3.Session
        with auth(aws_auth) as session:
            s3_client = session.client("s3")

        # API Key provider - returns the provider itself
        with auth(api_auth) as provider:
            headers = {"Authorization": f"Bearer {provider.api_key}"}
        ```
    """
    if isinstance(auth_provider, AWSAuthProvider):
        # Delegate to the AWS-specific context manager.
        with aws(auth_provider) as session:
            yield session
        return

    if isinstance(auth_provider, APIKeyAuthProvider):
        # The provider object itself exposes api_key/host to the caller.
        yield auth_provider
        return

    if isinstance(auth_provider, OAuth2AuthProvider):
        raise NotImplementedError(
            f"OAuth2 authentication is not yet implemented for provider '{auth_provider.id}'"
        )

    raise UnsupportedAuthProviderError(
        f"Unsupported authorization provider type: {type(auth_provider).__name__} "
        f"for provider '{auth_provider.id}'"
    )
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Tuple
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from qtype.interpreter.batch.step import batch_execute_step
|
|
9
|
+
from qtype.interpreter.batch.types import BatchConfig
|
|
10
|
+
from qtype.interpreter.batch.utils import reconcile_results_and_errors
|
|
11
|
+
from qtype.semantic.model import Flow, Sink
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def batch_execute_flow(
    flow: Flow,
    inputs: pd.DataFrame,
    batch_config: BatchConfig,
    **kwargs: dict[Any, Any],
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Executes a flow in a batch context.

    Each step consumes the previous step's output. Sink steps receive the
    whole frame at once; every other step is fed chunks of at most
    ``batch_config.batch_size`` rows.

    Args:
        flow: The flow to execute.
        inputs: Input rows for the flow's first step.
        batch_config: The batch configuration to use.
        **kwargs: Additional keyword arguments to pass to the flow.

    Returns:
        A tuple of (outputs of the final step, all error rows accumulated
        across steps — empty DataFrame when no step produced errors).
    """
    previous_outputs = inputs
    all_errors: list[pd.DataFrame] = []

    # Iterate over each step in the flow
    for step in flow.steps:
        results: list[pd.DataFrame] = []
        errors: list[pd.DataFrame] = []

        if isinstance(step, Sink):
            # Send the entire batch to the sink
            batch_results, batch_errors = batch_execute_step(
                step, previous_outputs, batch_config
            )
            results.append(batch_results)
            # BUG FIX: was `> 1`, which silently dropped single-row error frames.
            if len(batch_errors) > 0:
                errors.append(batch_errors)
        else:
            # batch the current data into dataframes of max size batch_size
            batch_size = batch_config.batch_size
            for start in range(0, len(previous_outputs), batch_size):
                batch = previous_outputs.iloc[start : start + batch_size]
                # Execute the step with the current batch
                batch_results, batch_errors = batch_execute_step(
                    step, batch, batch_config
                )
                results.append(batch_results)
                # BUG FIX: same off-by-one as the Sink branch above.
                if len(batch_errors) > 0:
                    errors.append(batch_errors)

        previous_outputs, errors_df = reconcile_results_and_errors(
            results, errors
        )

        if len(errors_df):
            all_errors.append(errors_df)
            if batch_config.write_errors_to:
                _write_step_errors(step.id, errors_df, batch_config)

    # Return the last step's results and all accumulated errors
    rv_errors = (
        pd.concat(all_errors, ignore_index=True)
        if all_errors
        else pd.DataFrame({})
    )
    return previous_outputs, rv_errors


def _write_step_errors(
    step_id: str, errors_df: pd.DataFrame, batch_config: BatchConfig
) -> None:
    """Best-effort persistence of one step's error rows as a parquet file.

    Failures to write are logged and swallowed so a bad error sink does not
    abort the flow (matches the original best-effort behavior).
    """
    output_file = f"{batch_config.write_errors_to}/{step_id}.errors.parquet"
    try:
        errors_df.to_parquet(
            output_file, engine="pyarrow", compression="snappy"
        )
        # Consistency fix: use the module-level logger, not the root logger.
        logger.info(f"Saved errors for step {step_id} to {output_file}")
    except Exception as e:
        logger.warning(
            f"Could not save errors step {step_id} to {output_file}",
            exc_info=e,
            stack_info=True,
        )
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from typing import Any, Tuple
|
|
2
|
+
|
|
3
|
+
import boto3 # type: ignore[import-untyped]
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import sqlalchemy
|
|
6
|
+
from sqlalchemy import create_engine
|
|
7
|
+
from sqlalchemy.exc import SQLAlchemyError
|
|
8
|
+
|
|
9
|
+
from qtype.base.exceptions import InterpreterError
|
|
10
|
+
from qtype.interpreter.auth.generic import auth
|
|
11
|
+
from qtype.interpreter.batch.types import BatchConfig, ErrorMode
|
|
12
|
+
from qtype.interpreter.batch.utils import (
|
|
13
|
+
reconcile_results_and_errors,
|
|
14
|
+
validate_inputs,
|
|
15
|
+
)
|
|
16
|
+
from qtype.semantic.model import SQLSource
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def to_output_columns(
    df: pd.DataFrame, output_columns: set[str]
) -> pd.DataFrame:
    """Project *df* onto the expected output columns.

    Args:
        df: The input DataFrame.
        output_columns: Column names to retain.

    Returns:
        A DataFrame holding only the requested columns, in the order they
        appear in *df*. Frames with zero rows are passed through unchanged.

    Raises:
        InterpreterError: If any requested column is absent from *df*.
    """
    # Zero-row results skip validation entirely (original behavior).
    if len(df) == 0:
        return df

    missing = output_columns.difference(df.columns)
    if missing:
        raise InterpreterError(
            f"SQL Result was missing expected columns: {','.join(missing)}, it has columns: {','.join(df.columns)}"
        )

    keep = [c for c in df.columns if c in output_columns]
    return df[keep]
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def execute_sql_source(
    step: SQLSource,
    inputs: pd.DataFrame,
    batch_config: BatchConfig,
    **kwargs: dict[Any, Any],
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Executes a SQLSource step to retrieve data from a SQL database.

    Runs ``step.query`` once per input row, binding the row's values (for
    columns named in ``step.inputs``) as query parameters, and concatenates
    the per-row result frames.

    Args:
        step: The SQLSource step to execute.
        inputs: Input rows; each row provides one set of query parameters.
        batch_config: Batch configuration; controls error handling mode.

    Returns:
        A tuple containing two DataFrames:
        - The first DataFrame contains the successfully retrieved data.
        - The second DataFrame contains rows that encountered errors with an 'error' column.

    Raises:
        SQLAlchemyError: Re-raised on query failure when error_mode is FAIL.
    """
    # Ensure the input frame carries the columns the step declares it needs.
    validate_inputs(inputs, step)

    connect_args = {}
    if step.auth:
        with auth(step.auth) as creds:
            # Only AWS auth contributes a connect arg; API-key providers are ignored here.
            if isinstance(creds, boto3.Session):
                connect_args["session"] = creds
    # NOTE(review): the boto3 session captured above is used by the engine
    # after the auth context has exited — confirm the session object remains
    # valid outside the `with auth(...)` block.
    engine = create_engine(step.connection, connect_args=connect_args)

    # Columns callers expect in the result, keyed by variable id.
    output_columns = {output.id for output in step.outputs}

    results = []
    errors = []
    step_inputs = {i.id for i in step.inputs}
    # One query execution per input row.
    for _, row in inputs.iterrows():
        try:
            # Make a dictionary of column_name: value from row
            params = {col: row[col] for col in row.index if col in step_inputs}
            # Execute the query and fetch the results into a DataFrame
            with engine.connect() as connection:
                result = connection.execute(
                    sqlalchemy.text(step.query),
                    # Pass None (not {}) when the query takes no parameters.
                    parameters=params if len(params) else None,
                )
                df = pd.DataFrame(
                    result.fetchall(), columns=list(result.keys())
                )
                # Validate and project onto the declared output columns.
                df = to_output_columns(df, output_columns)
                results.append(df)
        except SQLAlchemyError as e:
            # FAIL mode aborts on the first database error.
            if batch_config.error_mode == ErrorMode.FAIL:
                raise e
            # If there's an error, return an empty DataFrame and the error message
            error_df = pd.DataFrame([{"error": str(e)}])
            errors.append(error_df)

    return reconcile_results_and_errors(results, errors)
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from typing import Any, Tuple
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
|
|
6
|
+
from qtype.interpreter.batch.sql_source import execute_sql_source
|
|
7
|
+
from qtype.interpreter.batch.types import BatchConfig
|
|
8
|
+
from qtype.interpreter.batch.utils import (
|
|
9
|
+
batch_iterator,
|
|
10
|
+
single_step_adapter,
|
|
11
|
+
validate_inputs,
|
|
12
|
+
)
|
|
13
|
+
from qtype.interpreter.exceptions import InterpreterError
|
|
14
|
+
from qtype.semantic.model import (
|
|
15
|
+
Condition,
|
|
16
|
+
Decoder,
|
|
17
|
+
Flow,
|
|
18
|
+
PromptTemplate,
|
|
19
|
+
Search,
|
|
20
|
+
SQLSource,
|
|
21
|
+
Step,
|
|
22
|
+
Tool,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Step classes whose batch execution is a per-row wrap of the single-item
# executor (via single_step_adapter); see batch_execute_step below.
SINGLE_WRAP_STEPS = {Decoder, Condition, PromptTemplate, Search, Tool}
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def batch_execute_step(
    step: Step,
    inputs: pd.DataFrame,
    batch_config: BatchConfig,
    **kwargs: dict[str, Any],
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Executes a given step in a batch processing pipeline.

    Args:
        step (Step): The step to be executed.
        inputs (pd.DataFrame): The input data for the step.
        batch_config (BatchConfig): Configuration for batch processing.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the output results and any rows that returned errors.

    Raises:
        InterpreterError: If the step type has no batch execution support.
    """
    validate_inputs(inputs, step)

    if isinstance(step, Flow):
        # Imported lazily to break the circular import with batch.flow.
        from qtype.interpreter.batch.flow import batch_execute_flow

        return batch_execute_flow(step, inputs, batch_config, **kwargs)
    elif isinstance(step, SQLSource):
        return execute_sql_source(step, inputs, batch_config, **kwargs)
    elif isinstance(step, tuple(SINGLE_WRAP_STEPS)):
        # BUG FIX: the original tested `step in SINGLE_WRAP_STEPS`, which
        # compares the step *instance* against the class objects in the set
        # and is never true — so these step types always fell through to the
        # InterpreterError below. isinstance() against the classes is the
        # intended check.
        return batch_iterator(
            f=partial(single_step_adapter, step=step),
            batch=inputs,
            batch_config=batch_config,
        )
    # TODO: implement batching for multi-row steps. For example, llm inference can be sped up in batch...
    else:
        raise InterpreterError(f"Unsupported step type: {type(step).__name__}")
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ErrorMode(str, Enum):
    """Error handling mode for batch processing."""

    # Raise immediately on the first error encountered.
    FAIL = "fail"
    # Drop failing rows and continue; errors are collected separately.
    DROP = "drop"
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BatchConfig(BaseModel):
    """Configuration for batch execution.

    Attributes:
        num_workers: Number of async workers for batch operations.
        batch_size: Maximum number of rows to send to a step at a time.
        error_mode: Error handling mode for batch processing.
        write_errors_to: Optional directory for persisting dropped-row
            errors; only meaningful when error_mode is DROP.
    """

    num_workers: int = Field(
        default=4,
        description="Number of async workers for batch operations",
        gt=0,
    )
    batch_size: int = Field(
        default=512,
        description="Max number of rows to send to a step at a time",
        gt=0,
    )
    error_mode: ErrorMode = Field(
        default=ErrorMode.FAIL,
        description="Error handling mode for batch processing",
    )
    # Ignored unless error_mode is DROP (FAIL raises before errors are saved).
    write_errors_to: str | None = Field(
        default=None,
        description="If error mode is DROP, the errors for any step are saved to this directory",
    )