qtype-0.0.11-py3-none-any.whl → qtype-0.0.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. qtype/application/converters/tools_from_api.py +476 -11
  2. qtype/application/converters/tools_from_module.py +37 -13
  3. qtype/application/converters/types.py +17 -3
  4. qtype/application/facade.py +17 -20
  5. qtype/commands/convert.py +36 -2
  6. qtype/commands/generate.py +48 -0
  7. qtype/commands/run.py +1 -0
  8. qtype/commands/serve.py +11 -1
  9. qtype/commands/validate.py +8 -11
  10. qtype/commands/visualize.py +0 -3
  11. qtype/dsl/model.py +190 -4
  12. qtype/dsl/validator.py +2 -1
  13. qtype/interpreter/api.py +5 -1
  14. qtype/interpreter/batch/file_sink_source.py +162 -0
  15. qtype/interpreter/batch/flow.py +1 -1
  16. qtype/interpreter/batch/sql_source.py +3 -6
  17. qtype/interpreter/batch/step.py +12 -1
  18. qtype/interpreter/batch/utils.py +8 -9
  19. qtype/interpreter/step.py +2 -2
  20. qtype/interpreter/steps/tool.py +194 -28
  21. qtype/interpreter/ui/404/index.html +1 -1
  22. qtype/interpreter/ui/404.html +1 -1
  23. qtype/interpreter/ui/_next/static/chunks/393-8fd474427f8e19ce.js +36 -0
  24. qtype/interpreter/ui/_next/static/chunks/{964-ed4ab073db645007.js → 964-2b041321a01cbf56.js} +1 -1
  25. qtype/interpreter/ui/_next/static/chunks/app/{layout-5ccbc44fd528d089.js → layout-a05273ead5de2c41.js} +1 -1
  26. qtype/interpreter/ui/_next/static/chunks/app/page-7e26b6156cfb55d3.js +1 -0
  27. qtype/interpreter/ui/_next/static/chunks/{main-6d261b6c5d6fb6c2.js → main-e26b9cb206da2cac.js} +1 -1
  28. qtype/interpreter/ui/_next/static/chunks/webpack-08642e441b39b6c2.js +1 -0
  29. qtype/interpreter/ui/_next/static/css/b40532b0db09cce3.css +3 -0
  30. qtype/interpreter/ui/_next/static/media/4cf2300e9c8272f7-s.p.woff2 +0 -0
  31. qtype/interpreter/ui/index.html +1 -1
  32. qtype/interpreter/ui/index.txt +4 -4
  33. qtype/loader.py +8 -2
  34. qtype/semantic/generate.py +6 -2
  35. qtype/semantic/model.py +132 -77
  36. qtype/semantic/visualize.py +24 -6
  37. {qtype-0.0.11.dist-info → qtype-0.0.13.dist-info}/METADATA +4 -2
  38. {qtype-0.0.11.dist-info → qtype-0.0.13.dist-info}/RECORD +44 -43
  39. qtype/interpreter/ui/_next/static/chunks/736-7fc606e244fedcb1.js +0 -36
  40. qtype/interpreter/ui/_next/static/chunks/app/page-c72e847e888e549d.js +0 -1
  41. qtype/interpreter/ui/_next/static/chunks/webpack-8289c17c67827f22.js +0 -1
  42. qtype/interpreter/ui/_next/static/css/a262c53826df929b.css +0 -3
  43. qtype/interpreter/ui/_next/static/media/569ce4b8f30dc480-s.p.woff2 +0 -0
  44. /qtype/interpreter/ui/_next/static/{OT8QJQW3J70VbDWWfrEMT → nUaw6_IwRwPqkzwe5s725}/_buildManifest.js +0 -0
  45. /qtype/interpreter/ui/_next/static/{OT8QJQW3J70VbDWWfrEMT → nUaw6_IwRwPqkzwe5s725}/_ssgManifest.js +0 -0
  46. {qtype-0.0.11.dist-info → qtype-0.0.13.dist-info}/WHEEL +0 -0
  47. {qtype-0.0.11.dist-info → qtype-0.0.13.dist-info}/entry_points.txt +0 -0
  48. {qtype-0.0.11.dist-info → qtype-0.0.13.dist-info}/licenses/LICENSE +0 -0
  49. {qtype-0.0.11.dist-info → qtype-0.0.13.dist-info}/top_level.txt +0 -0
qtype/commands/convert.py CHANGED
@@ -9,13 +9,48 @@ import logging
 from pathlib import Path
 
 from qtype.application.facade import QTypeFacade
+from qtype.dsl.model import Application, ToolList
 
 logger = logging.getLogger(__name__)
 
 
 def convert_api(args: argparse.Namespace) -> None:
     """Convert API specification to qtype format."""
-    raise NotImplementedError("API conversion is not implemented yet.")
+    from qtype.application.converters.tools_from_api import tools_from_api
+
+    try:
+        api_name, auths, tools, types = tools_from_api(args.api_spec)
+        if not tools:
+            raise ValueError(
+                f"No tools found from the API specification: {args.api_spec}"
+            )
+        if not auths and not types:
+            doc = ToolList(
+                root=list(tools),
+            )
+        else:
+            doc: Application | ToolList = Application(
+                id=api_name,
+                description=f"Tools created from API specification {args.api_spec}",
+                tools=list(tools),
+                types=types,
+                auths=auths,
+            )
+        # Use facade to convert to YAML format
+        facade = QTypeFacade()
+        content = facade.convert_document(doc)
+
+        # Write to file or stdout
+        if args.output:
+            output_path = Path(args.output)
+            output_path.write_text(content, encoding="utf-8")
+            logger.info(f"✅ Converted tools saved to {output_path}")
+        else:
+            print(content)
+
+    except Exception as e:
+        logger.error(f"❌ Conversion failed: {e}")
+        raise
 
 
 def convert_module(args: argparse.Namespace) -> None:
@@ -23,7 +58,6 @@ def convert_module(args: argparse.Namespace) -> None:
     from qtype.application.converters.tools_from_module import (
        tools_from_module,
    )
-    from qtype.dsl.model import Application, ToolList
 
     try:
        tools, types = tools_from_module(args.module_path)
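For orientation, a minimal sketch of driving the new convert_api entry point directly (the spec path and output name are placeholders, not from this diff; the attribute names mirror the args the command reads above):

    import argparse
    from qtype.commands.convert import convert_api

    # api_spec and output are the two attributes convert_api consumes.
    args = argparse.Namespace(api_spec="openapi.yaml", output="tools.qtype.yaml")
    convert_api(args)  # writes the converted YAML to tools.qtype.yaml

When the spec yields neither auth providers nor custom types, the command emits a bare ToolList document; otherwise it wraps the tools, types, and auths in a full Application keyed by the API's name.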
qtype/commands/generate.py CHANGED
@@ -84,8 +84,56 @@ def generate_schema(args: argparse.Namespace) -> None:
         'output' attribute specifying the output file path.
     """
     schema = Document.model_json_schema()
+
     # Add the $schema property to indicate JSON Schema version
     schema["$schema"] = "http://json-schema.org/draft-07/schema#"
+
+    # Add custom YAML tag definitions for QType loader features
+    if "$defs" not in schema:
+        schema["$defs"] = {}
+
+    # Define custom YAML tags used by QType loader
+    schema["$defs"]["qtype_include_tag"] = {
+        "type": "string",
+        "pattern": "^!include\\s+.+",
+        "description": "Include external YAML file using QType's !include tag",
+    }
+
+    schema["$defs"]["qtype_include_raw_tag"] = {
+        "type": "string",
+        "pattern": "^!include_raw\\s+.+",
+        "description": "Include raw text file using QType's !include_raw tag",
+    }
+
+    schema["$defs"]["qtype_env_var"] = {
+        "type": "string",
+        "pattern": "^.*\\$\\{[^}:]+(?::[^}]*)?\\}.*$",
+        "description": "String with environment variable substitution using ${VAR_NAME} or ${VAR_NAME:default} syntax",
+    }
+
+    # Add these custom patterns to string types throughout the schema
+    def add_custom_patterns(obj):
+        if isinstance(obj, dict):
+            if obj.get("type") == "string" and "anyOf" not in obj:
+                # Add anyOf to allow either regular strings or custom tag patterns
+                original_obj = obj.copy()
+                obj.clear()
+                obj["anyOf"] = [
+                    original_obj,
+                    {"$ref": "#/$defs/qtype_include_tag"},
+                    {"$ref": "#/$defs/qtype_include_raw_tag"},
+                    {"$ref": "#/$defs/qtype_env_var"},
+                ]
+            else:
+                for value in obj.values():
+                    add_custom_patterns(value)
+        elif isinstance(obj, list):
+            for item in obj:
+                add_custom_patterns(item)
+
+    # Apply custom patterns to the schema
+    add_custom_patterns(schema)
+
     output = json.dumps(schema, indent=2)
     output_path: Optional[str] = getattr(args, "output", None)
     if output_path:
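To see what add_custom_patterns does, consider a single string-typed property (illustrative input, not from the diff):

    schema = {"properties": {"path": {"type": "string"}}}
    add_custom_patterns(schema)
    # schema["properties"]["path"] is now roughly:
    # {"anyOf": [{"type": "string"},
    #            {"$ref": "#/$defs/qtype_include_tag"},
    #            {"$ref": "#/$defs/qtype_include_raw_tag"},
    #            {"$ref": "#/$defs/qtype_env_var"}]}

Every plain string slot in the generated schema thus also accepts the !include, !include_raw, and ${VAR}/${VAR:default} forms that the QType loader understands.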
qtype/commands/run.py CHANGED
@@ -112,6 +112,7 @@ def run_flow(args: Any) -> None:
         logger.error(f"❌ Execution failed: {e}")
     except Exception as e:
         logger.error(f"❌ Unexpected error: {e}", exc_info=True)
+        pass
 
 
 def parser(subparsers: argparse._SubParsersAction) -> None:
qtype/commands/serve.py CHANGED
@@ -31,6 +31,7 @@ def serve(args: Any) -> None:
     logger.info(f"Loading and validating spec: {spec_path}")
 
     semantic_model, type_registry = facade.load_semantic_model(spec_path)
+    facade.telemetry(semantic_model)
     logger.info(f"✅ Successfully loaded spec: {spec_path}")
 
     # Import APIExecutor and create the FastAPI app
@@ -43,8 +44,17 @@ def serve(args: Any) -> None:
 
     logger.info(f"Starting server for: {name}")
     api_executor = APIExecutor(semantic_model)
+
+    # Create server info for OpenAPI spec
+    servers = [
+        {
+            "url": f"http://{args.host}:{args.port}",
+            "description": "Development server",
+        }
+    ]
+
     fastapi_app = api_executor.create_app(
-        name=name, ui_enabled=not args.disable_ui
+        name=name, ui_enabled=not args.disable_ui, servers=servers
     )
 
     # Start the server
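The servers list flows through create_app into FastAPI's standard servers keyword (see the qtype/interpreter/api.py diff below), so the generated OpenAPI document advertises the serving URL. A minimal sketch of the effect, with a placeholder host and port:

    from fastapi import FastAPI

    # Equivalent of what serve() now passes through create_app()
    app = FastAPI(title="QType API", servers=[
        {"url": "http://127.0.0.1:8000", "description": "Development server"},
    ])
    # The document at /openapi.json will include a matching "servers" entry.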
qtype/commands/validate.py CHANGED
@@ -10,6 +10,7 @@ import sys
 from pathlib import Path
 from typing import Any
 
+from qtype import dsl
 from qtype.application.facade import QTypeFacade
 from qtype.base.exceptions import LoadError, SemanticError, ValidationError
 
@@ -31,16 +32,11 @@ def main(args: Any) -> None:
 
     try:
         # Use the facade for validation - it will raise exceptions on errors
-        loaded_data = facade.load_and_validate(spec_path)
+        loaded_data, custom_types = facade.load_dsl_document(spec_path)
+        if isinstance(loaded_data, dsl.Application):
+            loaded_data, custom_types = facade.load_semantic_model(spec_path)
         logger.info("✅ Validation successful - document is valid.")
 
-        # If printing is requested, load and print the document
-        if args.print:
-            try:
-                print(loaded_data.model_dump_json(indent=2, exclude_none=True))  # type: ignore
-            except Exception as e:
-                logger.warning(f"Could not print document: {e}")
-
     except LoadError as e:
         logger.error(f"❌ Failed to load document: {e}")
         sys.exit(1)
@@ -50,9 +46,10 @@ def main(args: Any) -> None:
     except SemanticError as e:
         logger.error(f"❌ Semantic validation failed: {e}")
         sys.exit(1)
-    except Exception as e:
-        logger.error(f"❌ Unexpected error during validation: {e}")
-        sys.exit(1)
+
+    # If printing is requested, load and print the document
+    if args.print:
+        logging.info(facade.convert_document(loaded_data))  # type: ignore
 
 
 def parser(subparsers: argparse._SubParsersAction) -> None:
qtype/commands/visualize.py CHANGED
@@ -68,9 +68,6 @@ def main(args: Any) -> None:
     except ValidationError as e:
         logger.error(f"❌ Visualization failed: {e}")
         exit(1)
-    except Exception as e:
-        logger.error(f"❌ Unexpected error: {e}")
-        exit(1)
 
 
 def parser(subparsers: argparse._SubParsersAction) -> None:
qtype/dsl/model.py CHANGED
@@ -45,6 +45,33 @@ def _resolve_variable_type(
         return parsed_type
 
     # --- Case 1: The type is a string ---
+    # Check if it's a list type (e.g., "list[text]")
+    if parsed_type.startswith("list[") and parsed_type.endswith("]"):
+        # Extract the element type from "list[element_type]"
+        element_type_str = parsed_type[5:-1]  # Remove "list[" and "]"
+
+        # Recursively resolve the element type
+        element_type = _resolve_variable_type(
+            element_type_str, custom_type_registry
+        )
+
+        # Allow both primitive types and custom types (but no nested lists)
+        if isinstance(element_type, PrimitiveTypeEnum):
+            return ListType(element_type=element_type)
+        elif isinstance(element_type, str):
+            # This is a custom type reference - store as string for later resolution
+            return ListType(element_type=element_type)
+        elif element_type in DOMAIN_CLASSES.values():
+            # Domain class - store its name as string reference
+            for name, cls in DOMAIN_CLASSES.items():
+                if cls == element_type:
+                    return ListType(element_type=name)
+            return ListType(element_type=str(element_type))
+        else:
+            raise ValueError(
+                f"List element type must be a primitive type or custom type reference, got: {element_type}"
+            )
+
     # Try to resolve it as a primitive type first.
     try:
         return PrimitiveTypeEnum(parsed_type)
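Expected behavior of the new branch, assuming a "text" member on PrimitiveTypeEnum (consistent with the type="text" usage elsewhere in this diff; "MyType" is a hypothetical custom type):

    _resolve_variable_type("list[text]", {})
    # -> ListType(element_type=PrimitiveTypeEnum.text)
    _resolve_variable_type("list[MyType]", {"MyType": ...})
    # -> ListType(element_type="MyType"), resolved later against the registry
    _resolve_variable_type("list[list[text]]", {})
    # -> raises ValueError: nested lists fall through every accepted case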
@@ -107,12 +134,56 @@ class CustomType(StrictBaseModel):
     properties: dict[str, str]
 
 
+class ToolParameter(BaseModel):
+    """Defines a tool input or output parameter with type and optional flag."""
+
+    type: VariableType | str
+    optional: bool = Field(
+        default=False, description="Whether this parameter is optional"
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def resolve_type(cls, data: Any, info: ValidationInfo) -> Any:
+        """
+        This validator runs during the main validation pass. It uses the
+        context to resolve string-based type references.
+        """
+        if (
+            isinstance(data, dict)
+            and "type" in data
+            and isinstance(data["type"], str)
+        ):
+            # Get the registry of custom types from the validation context.
+            custom_types = (info.context or {}).get("custom_types", {})
+            resolved = _resolve_variable_type(data["type"], custom_types)
+            data["type"] = resolved
+        return data
+
+
+class ListType(BaseModel):
+    """Represents a list type with a specific element type."""
+
+    element_type: PrimitiveTypeEnum | str = Field(
+        ...,
+        description="Type of elements in the list (primitive type or custom type reference)",
+    )
+
+    def __str__(self) -> str:
+        """String representation for list type."""
+        if isinstance(self.element_type, PrimitiveTypeEnum):
+            return f"list[{self.element_type.value}]"
+        else:
+            return f"list[{self.element_type}]"
+
+
 VariableType = (
     PrimitiveTypeEnum
     | Type[Embedding]
     | Type[ChatMessage]
     | Type[ChatContent]
     | Type[BaseModel]
+    | ListType
 )
 
 
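A sketch of ToolParameter resolution via Pydantic's validation context (the "custom_types" key matches the validator above; the payload itself is invented):

    param = ToolParameter.model_validate(
        {"type": "list[text]", "optional": True},
        context={"custom_types": {}},
    )
    assert str(param.type) == "list[text]"  # ListType.__str__ round-trips the spelling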
@@ -238,15 +309,24 @@ class Condition(Step):
         return self
 
 
-class Tool(Step, ABC):
+class Tool(StrictBaseModel, ABC):
     """
     Base class for callable functions or external operations available to the model or as a step in a flow.
     """
 
+    id: str = Field(..., description="Unique ID of this component.")
     name: str = Field(..., description="Name of the tool function.")
     description: str = Field(
         ..., description="Description of what the tool does."
     )
+    inputs: dict[str, ToolParameter] | None = Field(
+        default=None,
+        description="Input parameters required by this tool.",
+    )
+    outputs: dict[str, ToolParameter] | None = Field(
+        default=None,
+        description="Output parameters produced by this tool.",
+    )
 
 
 class PythonFunctionTool(Tool):
@@ -277,6 +357,10 @@ class APITool(Tool):
         default=None,
         description="Optional HTTP headers to include in the request.",
     )
+    parameters: dict[str, ToolParameter] | None = Field(
+        default=None,
+        description="Output parameters produced by this tool.",
+    )
 
 
 class LLMInference(Step):
@@ -376,6 +460,23 @@ class Decoder(Step):
         return self
 
 
+class Invoke(Step):
+    """Invokes a tool with input and output bindings."""
+
+    tool: ToolType | str = Field(
+        ...,
+        description="Tool to invoke.",
+    )
+    input_bindings: dict[str, str] = Field(
+        ...,
+        description="Mapping from step input IDs to tool input parameter names.",
+    )
+    output_bindings: dict[str, str] = Field(
+        ...,
+        description="Mapping from tool output parameter names to step output IDs.",
+    )
+
+
 #
 # ---------------- Observability and Authentication Components ----------------
 #
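An illustrative Invoke step in dict form (field names come from the model above; the tool id and binding names are invented for the example):

    invoke_step = {
        "tool": "get_weather",                      # ToolType instance or string reference
        "input_bindings": {"city": "location"},     # step input id -> tool input parameter
        "output_bindings": {"forecast": "result"},  # tool output parameter -> step output id
    }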
@@ -400,6 +501,13 @@ class APIKeyAuthProvider(AuthorizationProvider):
     )
 
 
+class BearerTokenAuthProvider(AuthorizationProvider):
+    """Bearer token authentication provider."""
+
+    type: Literal["bearer_token"] = "bearer_token"
+    token: str = Field(..., description="Bearer token for authentication.")
+
+
 class OAuth2AuthProvider(AuthorizationProvider):
     """OAuth2 authentication provider."""
 
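Illustrative provider data (the "bearer_token" discriminator comes from the model above; the id and token values are placeholders, with the token pulled from an environment variable per the loader's ${VAR} syntax):

    bearer_auth = {
        "id": "my-api-auth",
        "type": "bearer_token",
        "token": "${API_TOKEN}",
    }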
@@ -594,6 +702,38 @@ class SQLSource(Source):
         return self
 
 
+class FileSource(Source):
+    """File source that reads data from a file using fsspec-compatible URIs."""
+
+    path: str | None = Field(
+        default=None,
+        description="fsspec-compatible URI to read from. If None, expects 'path' input variable.",
+    )
+
+    @model_validator(mode="after")
+    def validate_file_source(self) -> "FileSource":
+        """Validate that either path is specified or 'path' input variable exists."""
+        if self.path is None:
+            # Check if 'path' input variable exists
+            if self.inputs is None:
+                raise ValueError(
+                    "FileSource must either specify 'path' field or have a 'path' input variable."
+                )
+
+            path_input_exists = any(
+                (isinstance(inp, Variable) and inp.id == "path")
+                or (isinstance(inp, str) and inp == "path")
+                for inp in self.inputs
+            )
+
+            if not path_input_exists:
+                raise ValueError(
+                    "FileSource must either specify 'path' field or have a 'path' input variable."
+                )
+
+        return self
+
+
 class Sink(Step):
     """Base class for data sinks"""
 
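The validator admits exactly two shapes, sketched here in dict form (ids and URIs are invented):

    static_source = {"id": "read-data", "path": "s3://bucket/data.parquet"}
    dynamic_source = {"id": "read-data", "inputs": ["path"]}  # path bound at runtime
    # Omitting both the path field and a 'path' input raises the ValueError above.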
@@ -606,6 +746,47 @@ class Sink(Step):
     )
 
 
+class FileSink(Sink):
+    """File sink that writes data to a file using fsspec-compatible URIs."""
+
+    path: str | None = Field(
+        default=None,
+        description="fsspec-compatible URI to write to. If None, expects 'path' input variable.",
+    )
+
+    @model_validator(mode="after")
+    def validate_file_sink(self) -> "FileSink":
+        """Validate that either path is specified or 'path' input variable exists."""
+        # Ensure user does not set any output variables
+        if self.outputs is not None and len(self.outputs) > 0:
+            raise ValueError(
+                "FileSink outputs are automatically generated. Do not specify outputs."
+            )
+
+        # Automatically set the output variable
+        self.outputs = [Variable(id=f"{self.id}-file-uri", type="text")]
+
+        if self.path is None:
+            # Check if 'path' input variable exists
+            if self.inputs is None:
+                raise ValueError(
+                    "FileSink must either specify 'path' field or have a 'path' input variable."
+                )
+
+            path_input_exists = any(
+                (isinstance(inp, Variable) and inp.id == "path")
+                or (isinstance(inp, str) and inp == "path")
+                for inp in self.inputs
+            )
+
+            if not path_input_exists:
+                raise ValueError(
+                    "FileSink must either specify 'path' field or have a 'path' input variable."
+                )
+
+        return self
+
+
 #
 # ---------------- Retrieval Augmented Generation Components ----------------
 #
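FileSink generates its single output automatically: a sink with id "write-data" exposes a "write-data-file-uri" variable of type "text" downstream, and declaring outputs by hand raises the ValueError above. Sketch (id and URI invented):

    sink = {"id": "write-data", "path": "s3://bucket/out.parquet"}
    # After validation: sink.outputs == [Variable(id="write-data-file-uri", type="text")]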
@@ -704,11 +885,15 @@ ToolType = Union[
 ]
 
 # Create a union type for all source types
-SourceType = Union[SQLSource,]
+SourceType = Union[
+    FileSource,
+    SQLSource,
+]
 
 # Create a union type for all authorization provider types
 AuthProviderType = Union[
     APIKeyAuthProvider,
+    BearerTokenAuthProvider,
     AWSAuthProvider,
     OAuth2AuthProvider,
 ]
@@ -716,15 +901,16 @@ AuthProviderType = Union[
 # Create a union type for all step types
 StepType = Union[
     Agent,
-    APITool,
     Condition,
     Decoder,
     DocumentSearch,
+    FileSink,
+    FileSource,
     Flow,
     IndexUpsert,
+    Invoke,
     LLMInference,
     PromptTemplate,
-    PythonFunctionTool,
     SQLSource,
     Sink,
     VectorSearch,
qtype/dsl/validator.py CHANGED
@@ -21,7 +21,8 @@ class DuplicateComponentError(QTypeValidationError):
         existing_obj: qtype.dsl.domain_types.StrictBaseModel,
     ):
         super().__init__(
-            f"Duplicate component with ID '{obj_id}' found:\n{found_obj.model_dump_json()}\nAlready exists:\n{existing_obj.model_dump_json()}"
+            f'Duplicate component with ID "{obj_id}" found.'
+            # f"Duplicate component with ID \"{obj_id}\" found:\n{found_obj.model_dump_json()}\nAlready exists:\n{existing_obj.model_dump_json()}"
         )
 
 
qtype/interpreter/api.py CHANGED
@@ -37,6 +37,7 @@ class APIExecutor:
         name: str | None = None,
         ui_enabled: bool = True,
         fast_api_args: dict | None = None,
+        servers: list[dict] | None = None,
     ) -> FastAPI:
         """Create FastAPI app with dynamic endpoints."""
         if fast_api_args is None:
@@ -45,6 +46,10 @@ class APIExecutor:
             "redoc_url": "/redoc",
         }
 
+        # Add servers to FastAPI kwargs if provided
+        if servers is not None:
+            fast_api_args["servers"] = servers
+
         app = FastAPI(title=name or "QType API", **fast_api_args)
 
         # Serve static UI files if they exist
@@ -158,7 +163,6 @@ class APIExecutor:
                 status_code=400,
                 detail=f"Required input '{var.id}' not provided",
             )
-        return flow_copy
         # Execute the flow
         result_vars = execute_flow(flow_copy)
 
qtype/interpreter/batch/file_sink_source.py ADDED
@@ -0,0 +1,162 @@
+from typing import Any, Tuple
+
+import fsspec  # type: ignore[import-untyped]
+import pandas as pd
+
+from qtype.base.exceptions import InterpreterError
+from qtype.interpreter.batch.types import BatchConfig, ErrorMode
+from qtype.interpreter.batch.utils import reconcile_results_and_errors
+from qtype.semantic.model import FileSink, FileSource
+
+
+def execute_file_source(
+    step: FileSource,
+    inputs: pd.DataFrame,
+    batch_config: BatchConfig,
+    **kwargs: dict[Any, Any],
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Executes a FileSource step to read data from a file using fsspec.
+
+    Args:
+        step: The FileSource step to execute.
+        inputs: Input DataFrame (may contain path variable).
+        batch_config: Configuration for batch processing.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        A tuple containing two DataFrames:
+        - The first DataFrame contains the successfully read data.
+        - The second DataFrame contains rows that encountered errors with an 'error' column.
+    """
+    output_columns = {output.id for output in step.outputs}
+
+    results = []
+    errors = []
+
+    # FileSource has cardinality 'many', so it reads once and produces multiple output rows
+    # We process each input row (which might have different paths) separately
+    for _, row in inputs.iterrows():
+        try:
+            file_path = step.path if step.path else row.get("path")
+            if not file_path:
+                raise InterpreterError(
+                    f"No path specified for {type(step).__name__}. "
+                    "Either set the 'path' field or provide a 'path' input variable."
+                )
+
+            # Use fsspec to open the file and read with pandas
+            with fsspec.open(file_path, "rb") as file_handle:
+                df = pd.read_parquet(file_handle)  # type: ignore[arg-type]
+
+            # Filter to only the expected output columns if they exist
+            if output_columns and len(df) > 0:
+                available_columns = set(df.columns)
+                missing_columns = output_columns - available_columns
+                if missing_columns:
+                    raise InterpreterError(
+                        f"File {file_path} missing expected columns: {', '.join(missing_columns)}. "
+                        f"Available columns: {', '.join(available_columns)}"
+                    )
+                df = df[[col for col in df.columns if col in output_columns]]
+
+            results.append(df)
+
+        except Exception as e:
+            if batch_config.error_mode == ErrorMode.FAIL:
+                raise e
+
+            # If there's an error, add it to the errors list
+            error_df = pd.DataFrame([{"error": str(e)}])
+            errors.append(error_df)
+
+    return reconcile_results_and_errors(results, errors)
+
+
+def execute_file_sink(
+    step: FileSink,
+    inputs: pd.DataFrame,
+    batch_config: BatchConfig,
+    **kwargs: dict[Any, Any],
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """Executes a FileSink step to write data to a file using fsspec.
+
+    Args:
+        step: The FileSink step to execute.
+        inputs: Input DataFrame containing data to write.
+        batch_config: Configuration for batch processing.
+        **kwargs: Additional keyword arguments.
+
+    Returns:
+        A tuple containing two DataFrames:
+        - The first DataFrame contains success indicators.
+        - The second DataFrame contains rows that encountered errors with an 'error' column.
+    """
+    # this is enforced by the dsl, but we'll check here to confirm
+    if len(step.outputs) > 1:
+        raise InterpreterError(
+            f"There should only be one output variable for {type(step).__name__}."
+        )
+    output_column_name = step.outputs[0].id
+
+    # make a list of all file paths
+    try:
+        if step.path:
+            file_paths = [step.path] * len(inputs)
+        else:
+            if "path" not in inputs.columns:
+                raise InterpreterError(
+                    f"No path specified for {type(step).__name__}. "
+                    "Either set the 'path' field or provide a 'path' input variable."
+                )
+            file_paths = inputs["path"].tolist()
+    except Exception as e:
+        if batch_config.error_mode == ErrorMode.FAIL:
+            raise e
+        # If we can't get the path, we can't proceed
+        return pd.DataFrame(), pd.DataFrame([{"error": str(e)}])
+
+    # Check if all paths are the same
+    unique_paths = list(set(file_paths))
+
+    if len(unique_paths) == 1:
+        # All rows write to the same file - process as one batch
+        file_path = unique_paths[0]
+
+        try:
+            # Use fsspec to write the parquet file
+            input_columns = [i.id for i in step.inputs]
+            with fsspec.open(file_path, "wb") as file_handle:
+                inputs[input_columns].to_parquet(file_handle, index=False)  # type: ignore[arg-type]
+
+            inputs[output_column_name] = file_path
+            return inputs, pd.DataFrame()
+
+        except Exception as e:
+            if batch_config.error_mode == ErrorMode.FAIL:
+                raise e
+
+            # If there's an error, return error for all rows
+            error_df = pd.DataFrame([{"error": str(e)}])
+            return pd.DataFrame(), error_df
+
+    else:
+        # Multiple unique paths - split inputs and process recursively
+        all_results = []
+        all_errors = []
+
+        for unique_path in unique_paths:
+            # Create mask for rows with this path
+            path_mask = [p == unique_path for p in file_paths]
+            sliced_inputs = inputs[path_mask].copy()
+
+            # Recursively call execute_file_sink with the sliced DataFrame
+            results, errors = execute_file_sink(
+                step, sliced_inputs, batch_config, **kwargs
+            )
+
+            if len(results) > 0:
+                all_results.append(results)
+            if len(errors) > 0:
+                all_errors.append(errors)
+
+        return reconcile_results_and_errors(all_results, all_errors)
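A hypothetical driver for the new reader (the parquet path is a placeholder, source_step stands in for a semantic FileSource whose construction is outside this diff, and BatchConfig is assumed to accept error_mode as the code above suggests):

    import pandas as pd
    from qtype.interpreter.batch.file_sink_source import execute_file_source
    from qtype.interpreter.batch.types import BatchConfig, ErrorMode

    inputs = pd.DataFrame([{"path": "data/example.parquet"}])
    rows, errors = execute_file_source(
        source_step, inputs, BatchConfig(error_mode=ErrorMode.FAIL)
    )
    # rows: columns filtered to the step's declared outputs.
    # With ErrorMode.FAIL any read error re-raises; other modes (not shown in
    # this diff) would collect failures into the second DataFrame instead.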