qtype-0.0.9-py3-none-any.whl → qtype-0.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. qtype/application/__init__.py +12 -0
  2. qtype/application/commons/__init__.py +7 -0
  3. qtype/{converters → application/converters}/tools_from_module.py +2 -2
  4. qtype/{converters → application/converters}/types.py +0 -33
  5. qtype/{dsl/document.py → application/documentation.py} +2 -0
  6. qtype/application/facade.py +160 -0
  7. qtype/base/__init__.py +14 -0
  8. qtype/base/exceptions.py +49 -0
  9. qtype/base/logging.py +39 -0
  10. qtype/base/types.py +29 -0
  11. qtype/commands/convert.py +64 -49
  12. qtype/commands/generate.py +59 -4
  13. qtype/commands/run.py +109 -72
  14. qtype/commands/serve.py +42 -28
  15. qtype/commands/validate.py +25 -42
  16. qtype/commands/visualize.py +51 -37
  17. qtype/dsl/__init__.py +9 -0
  18. qtype/dsl/base_types.py +8 -0
  19. qtype/dsl/custom_types.py +6 -4
  20. qtype/dsl/model.py +185 -50
  21. qtype/dsl/validator.py +9 -4
  22. qtype/interpreter/api.py +96 -40
  23. qtype/interpreter/auth/__init__.py +3 -0
  24. qtype/interpreter/auth/aws.py +234 -0
  25. qtype/interpreter/auth/cache.py +67 -0
  26. qtype/interpreter/auth/generic.py +103 -0
  27. qtype/interpreter/batch/flow.py +95 -0
  28. qtype/interpreter/batch/sql_source.py +95 -0
  29. qtype/interpreter/batch/step.py +63 -0
  30. qtype/interpreter/batch/types.py +41 -0
  31. qtype/interpreter/batch/utils.py +179 -0
  32. qtype/interpreter/conversions.py +21 -10
  33. qtype/interpreter/resource_cache.py +4 -2
  34. qtype/interpreter/steps/decoder.py +13 -9
  35. qtype/interpreter/steps/llm_inference.py +7 -9
  36. qtype/interpreter/steps/prompt_template.py +1 -1
  37. qtype/interpreter/streaming_helpers.py +3 -3
  38. qtype/interpreter/typing.py +47 -11
  39. qtype/interpreter/ui/404/index.html +1 -1
  40. qtype/interpreter/ui/404.html +1 -1
  41. qtype/interpreter/ui/index.html +1 -1
  42. qtype/interpreter/ui/index.txt +1 -1
  43. qtype/loader.py +9 -15
  44. qtype/semantic/generate.py +91 -39
  45. qtype/semantic/model.py +183 -52
  46. qtype/semantic/resolver.py +4 -4
  47. {qtype-0.0.9.dist-info → qtype-0.0.11.dist-info}/METADATA +5 -1
  48. {qtype-0.0.9.dist-info → qtype-0.0.11.dist-info}/RECORD +58 -44
  49. qtype/commons/generate.py +0 -93
  50. qtype/semantic/errors.py +0 -4
  51. /qtype/{commons → application/commons}/tools.py +0 -0
  52. /qtype/{commons → application/converters}/__init__.py +0 -0
  53. /qtype/{converters → application/converters}/tools_from_api.py +0 -0
  54. /qtype/{converters → interpreter/batch}/__init__.py +0 -0
  55. /qtype/interpreter/ui/_next/static/{uMm4B0RSTGhXxgH3rTfwc → OT8QJQW3J70VbDWWfrEMT}/_buildManifest.js +0 -0
  56. /qtype/interpreter/ui/_next/static/{uMm4B0RSTGhXxgH3rTfwc → OT8QJQW3J70VbDWWfrEMT}/_ssgManifest.js +0 -0
  57. {qtype-0.0.9.dist-info → qtype-0.0.11.dist-info}/WHEEL +0 -0
  58. {qtype-0.0.9.dist-info → qtype-0.0.11.dist-info}/entry_points.txt +0 -0
  59. {qtype-0.0.9.dist-info → qtype-0.0.11.dist-info}/licenses/LICENSE +0 -0
  60. {qtype-0.0.9.dist-info → qtype-0.0.11.dist-info}/top_level.txt +0 -0
qtype/interpreter/auth/aws.py (new file)
@@ -0,0 +1,234 @@
+ """
+ AWS authentication context manager for QType interpreter.
+
+ This module provides a context manager for creating boto3 sessions using
+ AWSAuthProvider configuration from the semantic model.
+ """
+
+ from __future__ import annotations
+
+ from contextlib import contextmanager
+ from typing import Any, Generator
+
+ import boto3  # type: ignore[import-untyped]
+ from botocore.exceptions import (  # type: ignore[import-untyped]
+     ClientError,
+     NoCredentialsError,
+ )
+
+ from qtype.interpreter.auth.cache import cache_auth, get_cached_auth
+ from qtype.semantic.model import AWSAuthProvider
+
+
+ class AWSAuthenticationError(Exception):
+     """Raised when AWS authentication fails."""
+
+     pass
+
+
+ def _is_session_valid(session: boto3.Session) -> bool:
+     """
+     Check if a boto3 session is still valid by testing credential access.
+
+     Args:
+         session: The boto3 session to validate
+
+     Returns:
+         bool: True if the session is valid, False otherwise
+     """
+     try:
+         credentials = session.get_credentials()
+         if credentials is None:
+             return False
+
+         # For temporary credentials, check if they're still valid
+         if hasattr(credentials, "token") and credentials.token:
+             # Create a test STS client to verify the credentials
+             sts_client = session.client("sts")
+             sts_client.get_caller_identity()
+
+         return True
+     except (ClientError, NoCredentialsError):
+         return False
+     except Exception:
+         # Any other exception means the session is likely invalid
+         return False
+
+
+ @contextmanager
+ def aws(aws_provider: AWSAuthProvider) -> Generator[boto3.Session, None, None]:
+     """
+     Create a boto3 Session using AWS authentication provider configuration.
+
+     This context manager creates a boto3 Session based on the authentication
+     method specified in the AWSAuthProvider. Sessions are cached using an LRU
+     cache to avoid recreating them unnecessarily. The cache size can be
+     configured via the AUTH_CACHE_MAX_SIZE environment variable (default: 128).
+
+     It supports:
+     - Direct credentials (access key + secret key + optional session token)
+     - AWS profiles from shared credentials/config files
+     - Role assumption (with optional external ID and MFA)
+     - Environment-based authentication (when no explicit credentials provided)
+
+     Caching behavior:
+     - Sessions are cached based on the AWSAuthProvider configuration
+     - Cached sessions are validated before reuse to check for expiration
+     - Invalid or expired sessions are evicted and recreated
+
+     Args:
+         aws_provider: AWSAuthProvider instance containing authentication configuration
+
+     Yields:
+         boto3.Session: Configured boto3 session ready for creating AWS service clients
+
+     Raises:
+         AWSAuthenticationError: When authentication fails or configuration is invalid
+
+     Example:
+         ```python
+         from qtype.semantic.model import AWSAuthProvider
+         from qtype.interpreter.auth.aws import aws
+
+         aws_auth = AWSAuthProvider(
+             id="my-aws-auth",
+             type="aws",
+             access_key_id="AKIA...",
+             secret_access_key="...",
+             region="us-east-1"
+         )
+
+         with aws(aws_auth) as session:
+             athena_client = session.client("athena")
+             s3_client = session.client("s3")
+         ```
+     """
+     try:
+         # Check cache first - use provider object directly as cache key
+         cached_session = get_cached_auth(aws_provider)
+
+         if cached_session is not None and _is_session_valid(cached_session):
+             # Cache hit with valid session
+             yield cached_session
+             return
+
+         # Cache miss or invalid session - create new session
+         session = _create_session(aws_provider)
+
+         # Validate the session by attempting to get credentials
+         credentials = session.get_credentials()
+         if credentials is None:
+             raise AWSAuthenticationError(
+                 f"Failed to obtain AWS credentials for provider '{aws_provider.id}'"
+             )
+
+         # Cache the valid session using provider object as key
+         cache_auth(aws_provider, session)
+
+         yield session
+
+     except (ClientError, NoCredentialsError) as e:
+         raise AWSAuthenticationError(
+             f"AWS authentication failed for provider '{aws_provider.id}': {e}"
+         ) from e
+     except Exception as e:
+         raise AWSAuthenticationError(
+             f"Unexpected error during AWS authentication for provider '{aws_provider.id}': {e}"
+         ) from e
+
+
+ def _create_session(aws_provider: AWSAuthProvider) -> boto3.Session:
+     """
+     Create a boto3 Session based on the AWS provider configuration.
+
+     Args:
+         aws_provider: AWSAuthProvider with authentication details
+
+     Returns:
+         boto3.Session: Configured session
+
+     Raises:
+         AWSAuthenticationError: If configuration is invalid
+     """
+     session_kwargs: dict[str, Any] = {}
+
+     # Add region if specified
+     if aws_provider.region:
+         session_kwargs["region_name"] = aws_provider.region
+
+     # Handle different authentication methods
+     if aws_provider.profile_name:
+         # Use AWS profile from shared credentials/config files
+         session_kwargs["profile_name"] = aws_provider.profile_name
+
+     elif aws_provider.access_key_id and aws_provider.secret_access_key:
+         # Use direct credentials
+         session_kwargs["aws_access_key_id"] = aws_provider.access_key_id
+         session_kwargs["aws_secret_access_key"] = (
+             aws_provider.secret_access_key
+         )
+
+         if aws_provider.session_token:
+             session_kwargs["aws_session_token"] = aws_provider.session_token
+
+     # Create the base session
+     session = boto3.Session(**session_kwargs)
+
+     # Handle role assumption if specified
+     if aws_provider.role_arn:
+         session = _assume_role_session(session, aws_provider)
+
+     return session
+
+
+ def _assume_role_session(
+     base_session: boto3.Session, aws_provider: AWSAuthProvider
+ ) -> boto3.Session:
+     """
+     Create a new session by assuming an IAM role.
+
+     Args:
+         base_session: The base session to use for assuming the role
+         aws_provider: AWSAuthProvider with role configuration
+
+     Returns:
+         boto3.Session: New session with assumed role credentials
+
+     Raises:
+         AWSAuthenticationError: If role assumption fails
+     """
+     if not aws_provider.role_arn:
+         raise AWSAuthenticationError(
+             "role_arn is required for role assumption"
+         )
+
+     try:
+         sts_client = base_session.client("sts")
+
+         # Prepare AssumeRole parameters
+         assume_role_params: dict[str, Any] = {
+             "RoleArn": aws_provider.role_arn,
+             "RoleSessionName": aws_provider.role_session_name
+             or f"qtype-session-{aws_provider.id}",
+         }
+
+         if aws_provider.external_id:
+             assume_role_params["ExternalId"] = aws_provider.external_id
+
+         # Assume the role
+         response = sts_client.assume_role(**assume_role_params)
+         credentials = response["Credentials"]
+
+         # Create new session with temporary credentials
+         return boto3.Session(
+             aws_access_key_id=credentials["AccessKeyId"],
+             aws_secret_access_key=credentials["SecretAccessKey"],
+             aws_session_token=credentials["SessionToken"],
+             region_name=aws_provider.region or base_session.region_name,
+         )
+
+     except ClientError as e:
+         error_code = e.response.get("Error", {}).get("Code", "Unknown")
+         raise AWSAuthenticationError(
+             f"Failed to assume role '{aws_provider.role_arn}': {error_code} - {e}"
+         ) from e
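Reviewer note: the docstring example above only exercises direct credentials. A minimal sketch of the role-assumption path follows; the constructor field names (`profile_name`, `role_arn`, `role_session_name`, `external_id`) are inferred from the attributes `_create_session` and `_assume_role_session` read, not from documented API, and the call would reach AWS if run.

```python
from qtype.interpreter.auth.aws import aws
from qtype.semantic.model import AWSAuthProvider

# Hypothetical configuration; the base session comes from a shared
# profile, then _assume_role_session swaps in temporary role credentials.
role_auth = AWSAuthProvider(
    id="analytics-role",
    type="aws",
    profile_name="base-profile",
    role_arn="arn:aws:iam::123456789012:role/AnalyticsRole",
    role_session_name="qtype-analytics",
    external_id="my-external-id",  # optional, forwarded as ExternalId
    region="us-east-1",
)

# First use builds and caches the session; later uses revalidate the
# cached one via STS get_caller_identity before reusing it.
with aws(role_auth) as session:
    print(session.client("sts").get_caller_identity()["Account"])
```

Because the provider object itself is the cache key, two identical configurations share one cached session.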
qtype/interpreter/auth/cache.py (new file)
@@ -0,0 +1,67 @@
+ """
+ Authorization cache for QType interpreter.
+
+ This module provides a shared LRU cache for authorization sessions and tokens
+ across different authentication providers (AWS, OAuth2, API keys, etc.).
+ """
+
+ from __future__ import annotations
+
+ import os
+ from typing import Any
+
+ from cachetools import LRUCache
+
+ # Global LRU cache for authorization sessions with configurable size
+ _AUTH_CACHE_MAX_SIZE = int(os.environ.get("AUTH_CACHE_MAX_SIZE", 128))
+ _AUTHORIZATION_CACHE: LRUCache[Any, Any] = LRUCache(
+     maxsize=_AUTH_CACHE_MAX_SIZE
+ )
+
+
+ def get_cached_auth(auth_provider: Any) -> Any | None:
+     """
+     Get a cached authorization session for the given provider.
+
+     Args:
+         auth_provider: Authorization provider instance (must be hashable)
+
+     Returns:
+         Cached session/token or None if not found
+     """
+     return _AUTHORIZATION_CACHE.get(auth_provider)
+
+
+ def cache_auth(auth_provider: Any, session: Any) -> None:
+     """
+     Cache an authorization session for the given provider.
+
+     Args:
+         auth_provider: Authorization provider instance (must be hashable)
+         session: Session or token to cache
+     """
+     _AUTHORIZATION_CACHE[auth_provider] = session
+
+
+ def clear_auth_cache() -> None:
+     """
+     Clear all cached authorization sessions.
+
+     This can be useful for testing or when credential configurations change.
+     """
+     _AUTHORIZATION_CACHE.clear()
+
+
+ def get_cache_info() -> dict[str, Any]:
+     """
+     Get information about the current state of the authorization cache.
+
+     Returns:
+         Dictionary with cache statistics and configuration
+     """
+     return {
+         "max_size": _AUTH_CACHE_MAX_SIZE,
+         "current_size": len(_AUTHORIZATION_CACHE),
+         "hits": getattr(_AUTHORIZATION_CACHE, "hits", 0),
+         "misses": getattr(_AUTHORIZATION_CACHE, "misses", 0),
+     }
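The cache module is provider-agnostic; a short usage sketch of its API, under the assumption that any hashable object can serve as a key (a string stands in for a provider instance here):

```python
from qtype.interpreter.auth.cache import (
    cache_auth,
    clear_auth_cache,
    get_cache_info,
    get_cached_auth,
)

provider_key = "my-aws-auth"  # stand-in for a hashable provider object

if get_cached_auth(provider_key) is None:
    session = object()  # placeholder for an expensive-to-build session
    cache_auth(provider_key, session)

print(get_cache_info())  # e.g. {'max_size': 128, 'current_size': 1, ...}
clear_auth_cache()       # e.g. after rotating credentials in tests
```

Note that `hits` and `misses` fall back to 0 via `getattr`, since `cachetools.LRUCache` does not track them itself.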
qtype/interpreter/auth/generic.py (new file)
@@ -0,0 +1,103 @@
+ """
+ Generic authorization context manager for QType interpreter.
+
+ This module provides a unified context manager that can handle any
+ AuthorizationProvider type and return the appropriate session or provider
+ instance.
+ """
+
+ from __future__ import annotations
+
+ from contextlib import contextmanager
+ from typing import Generator
+
+ import boto3  # type: ignore[import-untyped]
+
+ from qtype.interpreter.auth.aws import aws
+ from qtype.semantic.model import (
+     APIKeyAuthProvider,
+     AuthorizationProvider,
+     AWSAuthProvider,
+     OAuth2AuthProvider,
+ )
+
+
+ class UnsupportedAuthProviderError(Exception):
+     """Raised when an unsupported authorization provider type is used."""
+
+     pass
+
+
+ @contextmanager
+ def auth(
+     auth_provider: AuthorizationProvider,
+ ) -> Generator[boto3.Session | APIKeyAuthProvider, None, None]:
+     """
+     Create an appropriate session or provider instance based on the auth provider type.
+
+     This context manager dispatches to the appropriate authentication handler
+     based on the type of AuthorizationProvider:
+     - AWSAuthProvider: Returns a configured boto3.Session
+     - APIKeyAuthProvider: Returns the provider instance (contains the API key)
+     - OAuth2AuthProvider: Raises NotImplementedError (not yet supported)
+
+     Args:
+         auth_provider: AuthorizationProvider instance of any supported type
+
+     Yields:
+         boto3.Session | APIKeyAuthProvider: The appropriate session or provider instance
+
+     Raises:
+         UnsupportedAuthProviderError: When an unsupported provider type is used
+         NotImplementedError: When OAuth2AuthProvider is used (not yet implemented)
+
+     Example:
+         ```python
+         from qtype.semantic.model import AWSAuthProvider, APIKeyAuthProvider
+         from qtype.interpreter.auth.generic import auth
+
+         # AWS provider - returns boto3.Session
+         aws_auth = AWSAuthProvider(
+             id="my-aws-auth",
+             type="aws",
+             access_key_id="AKIA...",
+             secret_access_key="...",
+             region="us-east-1"
+         )
+
+         with auth(aws_auth) as session:
+             s3_client = session.client("s3")
+
+         # API Key provider - returns the provider itself
+         api_auth = APIKeyAuthProvider(
+             id="my-api-auth",
+             type="api_key",
+             api_key="sk-...",
+             host="api.openai.com"
+         )
+
+         with auth(api_auth) as provider:
+             headers = {"Authorization": f"Bearer {provider.api_key}"}
+         ```
+     """
+     if isinstance(auth_provider, AWSAuthProvider):
+         # Use AWS-specific context manager
+         with aws(auth_provider) as session:
+             yield session
+
+     elif isinstance(auth_provider, APIKeyAuthProvider):
+         # For API key providers, just return the provider itself.
+         # The caller can access provider.api_key and provider.host.
+         yield auth_provider
+
+     elif isinstance(auth_provider, OAuth2AuthProvider):
+         # OAuth2 not yet implemented
+         raise NotImplementedError(
+             f"OAuth2 authentication is not yet implemented for provider '{auth_provider.id}'"
+         )
+
+     else:
+         # Unknown provider type
+         raise UnsupportedAuthProviderError(
+             f"Unsupported authorization provider type: {type(auth_provider).__name__} "
+             f"for provider '{auth_provider.id}'"
+         )
qtype/interpreter/batch/flow.py (new file)
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import Any, Tuple
+
+ import pandas as pd
+
+ from qtype.interpreter.batch.step import batch_execute_step
+ from qtype.interpreter.batch.types import BatchConfig
+ from qtype.interpreter.batch.utils import reconcile_results_and_errors
+ from qtype.semantic.model import Flow, Sink
+
+ logger = logging.getLogger(__name__)
+
+
+ def batch_execute_flow(
+     flow: Flow,
+     inputs: pd.DataFrame,
+     batch_config: BatchConfig,
+     **kwargs: dict[Any, Any],
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """Executes a flow in a batch context.
+
+     Args:
+         flow: The flow to execute.
+         inputs: The input rows fed to the first step.
+         batch_config: The batch configuration to use.
+         **kwargs: Additional keyword arguments to pass to the flow.
+
+     Returns:
+         A tuple of (results, errors) DataFrames: the outputs of the final
+         step and the error rows accumulated across all steps.
+     """
+
+     previous_outputs = inputs
+
+     all_errors = []
+
+     # Iterate over each step in the flow
+     for step in flow.steps:
+         results: list[pd.DataFrame] = []
+         errors: list[pd.DataFrame] = []
+
+         if isinstance(step, Sink):
+             # Send the entire batch to the sink
+             batch_results, batch_errors = batch_execute_step(
+                 step, previous_outputs, batch_config
+             )
+             results.append(batch_results)
+             if len(batch_errors) > 0:
+                 errors.append(batch_errors)
+         else:
+             # Batch the current data into dataframes of max size batch_size
+             batch_size = batch_config.batch_size
+             for start in range(0, len(previous_outputs), batch_size):
+                 end = start + batch_size
+                 batch = previous_outputs.iloc[start:end]
+                 # Execute the step with the current batch
+                 batch_results, batch_errors = batch_execute_step(
+                     step, batch, batch_config
+                 )
+
+                 results.append(batch_results)
+                 if len(batch_errors) > 0:
+                     errors.append(batch_errors)
+
+         previous_outputs, errors_df = reconcile_results_and_errors(
+             results, errors
+         )
+
+         if len(errors_df):
+             all_errors.append(errors_df)
+             if batch_config.write_errors_to:
+                 output_file = (
+                     f"{batch_config.write_errors_to}/{step.id}.errors.parquet"
+                 )
+                 try:
+                     errors_df.to_parquet(
+                         output_file, engine="pyarrow", compression="snappy"
+                     )
+                     logger.info(
+                         f"Saved errors for step {step.id} to {output_file}"
+                     )
+                 except Exception as e:
+                     logger.warning(
+                         f"Could not save errors for step {step.id} to {output_file}",
+                         exc_info=e,
+                         stack_info=True,
+                     )
+
+     # Return the last step's results and the accumulated errors
+     rv_errors = (
+         pd.concat(all_errors, ignore_index=True)
+         if len(all_errors)
+         else pd.DataFrame({})
+     )
+     return previous_outputs, rv_errors
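The chunking in `batch_execute_flow` is plain positional slicing: non-Sink steps see at most `batch_size` rows per call, and the final chunk may be smaller. A standalone sketch of the same pattern, independent of the qtype models:

```python
import pandas as pd

df = pd.DataFrame({"x": range(10)})
batch_size = 4

# Mirrors the inner loop above: start indices step by batch_size,
# and iloc slicing past the end simply yields a shorter final chunk.
for start in range(0, len(df), batch_size):
    batch = df.iloc[start:start + batch_size]
    print(start, len(batch))  # prints: 0 4, 4 4, 8 2
```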
qtype/interpreter/batch/sql_source.py (new file)
@@ -0,0 +1,95 @@
+ from typing import Any, Tuple
+
+ import boto3  # type: ignore[import-untyped]
+ import pandas as pd
+ import sqlalchemy
+ from sqlalchemy import create_engine
+ from sqlalchemy.exc import SQLAlchemyError
+
+ from qtype.base.exceptions import InterpreterError
+ from qtype.interpreter.auth.generic import auth
+ from qtype.interpreter.batch.types import BatchConfig, ErrorMode
+ from qtype.interpreter.batch.utils import (
+     reconcile_results_and_errors,
+     validate_inputs,
+ )
+ from qtype.semantic.model import SQLSource
+
+
+ def to_output_columns(
+     df: pd.DataFrame, output_columns: set[str]
+ ) -> pd.DataFrame:
+     """Filters the DataFrame to only include specified output columns.
+
+     Args:
+         df: The input DataFrame.
+         output_columns: A set of column names to retain in the DataFrame.
+
+     Returns:
+         A DataFrame containing only the specified output columns.
+     """
+     if len(df) == 0:
+         return df
+     missing = output_columns - set(df.columns)
+     if missing:
+         raise InterpreterError(
+             f"SQL result was missing expected columns: {','.join(missing)}; "
+             f"it has columns: {','.join(df.columns)}"
+         )
+
+     return df[[col for col in df.columns if col in output_columns]]
+
+
+ def execute_sql_source(
+     step: SQLSource,
+     inputs: pd.DataFrame,
+     batch_config: BatchConfig,
+     **kwargs: dict[Any, Any],
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """Executes a SQLSource step to retrieve data from a SQL database.
+
+     Args:
+         step: The SQLSource step to execute.
+         inputs: Input rows whose matching columns supply named query parameters.
+         batch_config: Configuration controlling error handling.
+
+     Returns:
+         A tuple containing two DataFrames:
+         - The first DataFrame contains the successfully retrieved data.
+         - The second DataFrame contains rows that encountered errors, with an 'error' column.
+     """
+     validate_inputs(inputs, step)
+
+     # Resolve credentials (if any) and create a database engine
+     connect_args = {}
+     if step.auth:
+         with auth(step.auth) as creds:
+             if isinstance(creds, boto3.Session):
+                 connect_args["session"] = creds
+     engine = create_engine(step.connection, connect_args=connect_args)
+
+     output_columns = {output.id for output in step.outputs}
+
+     results = []
+     errors = []
+     step_inputs = {i.id for i in step.inputs}
+     for _, row in inputs.iterrows():
+         try:
+             # Make a dictionary of column_name: value from the row
+             params = {col: row[col] for col in row.index if col in step_inputs}
+             # Execute the query and fetch the results into a DataFrame
+             with engine.connect() as connection:
+                 result = connection.execute(
+                     sqlalchemy.text(step.query),
+                     parameters=params if len(params) else None,
+                 )
+                 df = pd.DataFrame(
+                     result.fetchall(), columns=list(result.keys())
+                 )
+             df = to_output_columns(df, output_columns)
+             results.append(df)
+         except SQLAlchemyError as e:
+             if batch_config.error_mode == ErrorMode.FAIL:
+                 raise e
+             # On error, record the error message and continue
+             error_df = pd.DataFrame([{"error": str(e)}])
+             errors.append(error_df)
+
+     return reconcile_results_and_errors(results, errors)
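`execute_sql_source` binds each input row's matching columns as named parameters to `sqlalchemy.text`, issuing one query per row. A self-contained sketch of that pattern against in-memory SQLite, assuming SQLAlchemy 2.x:

```python
import pandas as pd
import sqlalchemy
from sqlalchemy import create_engine

# SQLAlchemy reuses a single connection for sqlite :memory:, so the
# table created here survives across connect() calls.
engine = create_engine("sqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(sqlalchemy.text("CREATE TABLE t (id INTEGER, name TEXT)"))
    conn.execute(sqlalchemy.text("INSERT INTO t VALUES (1, 'a'), (2, 'b')"))

inputs = pd.DataFrame({"id": [1, 2]})
query = "SELECT name FROM t WHERE id = :id"  # :id matches an input column

# One parameterized query per input row, as in the loop above
for _, row in inputs.iterrows():
    with engine.connect() as conn:
        result = conn.execute(sqlalchemy.text(query), {"id": int(row["id"])})
        print(pd.DataFrame(result.fetchall(), columns=list(result.keys())))
```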
qtype/interpreter/batch/step.py (new file)
@@ -0,0 +1,63 @@
+ from functools import partial
+ from typing import Any, Tuple
+
+ import pandas as pd
+
+ from qtype.interpreter.batch.sql_source import execute_sql_source
+ from qtype.interpreter.batch.types import BatchConfig
+ from qtype.interpreter.batch.utils import (
+     batch_iterator,
+     single_step_adapter,
+     validate_inputs,
+ )
+ from qtype.interpreter.exceptions import InterpreterError
+ from qtype.semantic.model import (
+     Condition,
+     Decoder,
+     Flow,
+     PromptTemplate,
+     Search,
+     SQLSource,
+     Step,
+     Tool,
+ )
+
+ SINGLE_WRAP_STEPS = {Decoder, Condition, PromptTemplate, Search, Tool}
+
+
+ def batch_execute_step(
+     step: Step,
+     inputs: pd.DataFrame,
+     batch_config: BatchConfig,
+     **kwargs: dict[str, Any],
+ ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+     """
+     Executes a given step in a batch processing pipeline.
+
+     Args:
+         step (Step): The step to be executed.
+         inputs (pd.DataFrame): The input data for the step.
+         batch_config (BatchConfig): Configuration for batch processing.
+         **kwargs: Additional keyword arguments.
+
+     Returns:
+         Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the output results and any rows that returned errors.
+     """
+
+     validate_inputs(inputs, step)
+
+     if isinstance(step, Flow):
+         from qtype.interpreter.batch.flow import batch_execute_flow
+
+         return batch_execute_flow(step, inputs, batch_config, **kwargs)
+     elif isinstance(step, SQLSource):
+         return execute_sql_source(step, inputs, batch_config, **kwargs)
+     elif type(step) in SINGLE_WRAP_STEPS:
+         # Wrap single-row steps so they can be applied across the batch
+         return batch_iterator(
+             f=partial(single_step_adapter, step=step),
+             batch=inputs,
+             batch_config=batch_config,
+         )
+     # TODO: implement batching for multi-row steps. For example, llm inference can be sped up in batch...
+     else:
+         raise InterpreterError(f"Unsupported step type: {type(step).__name__}")
qtype/interpreter/batch/types.py (new file)
@@ -0,0 +1,41 @@
+ from __future__ import annotations
+
+ from enum import Enum
+
+ from pydantic import BaseModel, Field
+
+
+ class ErrorMode(str, Enum):
+     """Error handling mode for batch processing."""
+
+     FAIL = "fail"
+     DROP = "drop"
+
+
+ class BatchConfig(BaseModel):
+     """Configuration for batch execution.
+
+     Attributes:
+         num_workers: Number of async workers for batch operations.
+         batch_size: Maximum number of rows to send to a step at a time.
+         error_mode: Error handling mode for batch processing.
+         write_errors_to: If error mode is DROP, the directory error rows are written to.
+     """
+
+     num_workers: int = Field(
+         default=4,
+         description="Number of async workers for batch operations",
+         gt=0,
+     )
+     batch_size: int = Field(
+         default=512,
+         description="Max number of rows to send to a step at a time",
+         gt=0,
+     )
+     error_mode: ErrorMode = Field(
+         default=ErrorMode.FAIL,
+         description="Error handling mode for batch processing",
+     )
+     write_errors_to: str | None = Field(
+         default=None,
+         description="If error mode is DROP, the errors for any step are saved to this directory",
+     )
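`BatchConfig` is a plain pydantic model, so batch behavior is configured at the call site. A usage sketch, assuming pydantic v2 semantics (where `ValidationError` subclasses `ValueError`); the error directory is hypothetical:

```python
from qtype.interpreter.batch.types import BatchConfig, ErrorMode

# Defaults: 4 workers, 512-row batches, fail on the first error
config = BatchConfig()

# Drop failing rows instead, persisting them per step as parquet files
tolerant = BatchConfig(
    batch_size=256,
    error_mode=ErrorMode.DROP,
    write_errors_to="/tmp/qtype-errors",  # hypothetical output directory
)

# Field constraints are enforced at construction: gt=0 rejects this
try:
    BatchConfig(batch_size=0)
except ValueError as e:
    print("rejected:", type(e).__name__)
```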