wt-runner 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wt_runner/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ """wt-runner: FastAPI application for workflow execution.
2
+
3
+ This package provides a FastAPI web service for executing workflows using
4
+ wt-invokers. It includes endpoints for:
5
+ - Running workflows with various configurations
6
+ - Processing Pub/Sub messages
7
+ - Retrieving workflow metadata and schemas
8
+ - Converting between parameter formats
9
+ """
10
+
11
+ from wt_runner.app import app
12
+ from wt_runner.testing import Case, CaseRunner
13
+ from wt_runner.tracing import (
14
+ TraceContextHeaders,
15
+ attach_context,
16
+ build_context_headers,
17
+ configure_tracer,
18
+ )
19
+
20
+ try:
21
+ from wt_runner._version import __version__, __version_tuple__
22
+ except ImportError:
23
+ __version__ = "unknown"
24
+ __version_tuple__ = (0, 0, 0)
25
+
26
+ __all__ = [
27
+ "app",
28
+ "Case",
29
+ "CaseRunner",
30
+ "configure_tracer",
31
+ "attach_context",
32
+ "build_context_headers",
33
+ "TraceContextHeaders",
34
+ "__version__",
35
+ "__version_tuple__",
36
+ ]
wt_runner/_version.py ADDED
@@ -0,0 +1,34 @@
1
+ # file generated by setuptools-scm
2
+ # don't change, don't track in version control
3
+
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
12
+
13
+ TYPE_CHECKING = False
14
+ if TYPE_CHECKING:
15
+ from typing import Tuple
16
+ from typing import Union
17
+
18
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
20
+ else:
21
+ VERSION_TUPLE = object
22
+ COMMIT_ID = object
23
+
24
+ version: str
25
+ __version__: str
26
+ __version_tuple__: VERSION_TUPLE
27
+ version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
30
+
31
+ __version__ = version = '0.1.3'
32
+ __version_tuple__ = version_tuple = (0, 1, 3)
33
+
34
+ __commit_id__ = commit_id = None
wt_runner/app.py ADDED
@@ -0,0 +1,758 @@
1
+ """FastAPI application for workflow execution."""
2
+
3
+ import base64
4
+ import binascii
5
+ import json
6
+ import logging
7
+ import os
8
+ import traceback
9
+ from collections.abc import AsyncIterator
10
+ from contextlib import asynccontextmanager
11
+ from dataclasses import asdict, dataclass
12
+ from importlib.metadata import PackageNotFoundError, version
13
+ from io import StringIO
14
+ from pathlib import Path
15
+ from typing import Any, Literal
16
+ from urllib.parse import urlparse
17
+
18
+ import ruamel.yaml
19
+ from fastapi import (
20
+ Depends,
21
+ FastAPI,
22
+ Header,
23
+ HTTPException,
24
+ Query,
25
+ Request,
26
+ Response,
27
+ status,
28
+ )
29
+ from fastapi.middleware.cors import CORSMiddleware
30
+ from fastapi.middleware.gzip import GZipMiddleware
31
+ from fastapi.responses import JSONResponse
32
+ from opentelemetry import trace as otel_trace
33
+ from pydantic import BaseModel, Field, SecretStr
34
+ from rattler import MatchSpec
35
+ from wt_invokers import (
36
+ AbstractInvoker,
37
+ CloudBatchInvoker,
38
+ LocalSubprocessInvoker,
39
+ )
40
+
41
+ from wt_runner.tracing import (
42
+ TraceContextHeaders,
43
+ attach_context,
44
+ build_context_headers,
45
+ configure_tracer,
46
+ make_otel_console_exporter_file_dst_kws,
47
+ )
48
+
49
+ # Optional imports for ecoscope integration
50
+ try:
51
+ from ecoscope_eda_core.messages.commands import ( # type: ignore[import-untyped,import-not-found,unused-ignore]
52
+ InvokerType as EcoscopeInvokerType,
53
+ )
54
+ from ecoscope_eda_core.messages.commands import (
55
+ RunWorkflow,
56
+ RunWorkflowParams,
57
+ )
58
+ from ecoscope_eda_core.workflows import ( # type: ignore[import-untyped,import-not-found,unused-ignore]
59
+ get_results_json as ecoscope_get_results_json,
60
+ )
61
+
62
+ HAS_ECOSCOPE = True
63
+ InvokerType = EcoscopeInvokerType
64
+ except ImportError:
65
+ HAS_ECOSCOPE = False
66
+ # Define InvokerType locally when ecoscope_eda_core is not available
67
+ InvokerType = Literal[
68
+ "BlockingLocalSubprocessInvoker",
69
+ "AsyncLocalSubprocessInvoker",
70
+ "CloudBatchInvoker",
71
+ ]
72
+ RunWorkflow = None
73
+ RunWorkflowParams = None
74
+ ecoscope_get_results_json = None
75
+
76
+ import obstore
77
+
78
# Invoker registry mapping invoker names to classes. Both "local" entries map
# to the same class; the distinct names only signal the caller's intended
# blocking/async usage of the invoker.
INVOKERS: dict[str, type[AbstractInvoker]] = {
    "BlockingLocalSubprocessInvoker": LocalSubprocessInvoker,
    "AsyncLocalSubprocessInvoker": LocalSubprocessInvoker,
    "CloudBatchInvoker": CloudBatchInvoker,
}

# Service title; also used to look up the installed package version.
TITLE = "wt-runner"
# User-facing message attached to invoker waits that exceed their timeout.
TIMEOUT_EXPIRED_ERROR_MSG = (
    "The workflow timed out. Consider reducing the amount of data being processed."
)
# Slightly under Pub/Sub's push acknowledgement deadline (10 minutes), leaving
# headroom to cancel and report errors before the message would be redelivered.
PUBSUB_ACK_MAX_TIMEOUT = 570  # seconds
91
+
92
async def get_results_json(results_url: str) -> dict[str, Any]:
    """Fetch the workflow result document stored at ``results_url``.

    Delegates to ecoscope_eda_core's implementation when that package is
    installed; otherwise reads ``result.json`` from the object store directly.

    Args:
        results_url: URL or path to results

    Returns:
        Results dictionary

    Raises:
        RuntimeError: If results cannot be retrieved
    """
    if HAS_ECOSCOPE and ecoscope_get_results_json is not None:
        via_ecoscope: dict[str, Any] = await ecoscope_get_results_json(results_url)
        return via_ecoscope

    # Fallback path: read `result.json` straight out of the object store.
    store = obstore.store.from_url(results_url)
    response = await store.get_async("result.json")
    payload = bytes(await response.bytes_async())
    parsed: dict[str, Any] = json.loads(payload)
    return parsed
114
+
115
+
116
+ def get_otel_exporter() -> Literal["console", "gcp"] | None:
117
+ """Get OpenTelemetry exporter type from environment.
118
+
119
+ Returns:
120
+ Exporter type or None
121
+ """
122
+ value = os.environ.get("ECOSCOPE_WORKFLOWS_OTEL_EXPORTER")
123
+ if value == "console":
124
+ return "console"
125
+ if value == "gcp":
126
+ return "gcp"
127
+ return None
128
+
129
+
130
def get_otel_console_exporter_dst() -> Literal["stdout", "file"]:
    """Read the console exporter destination from the environment.

    Returns:
        ``"stdout"`` when explicitly requested; ``"file"`` for any other
        value (including when the env var is unset).
    """
    configured = os.environ.get("ECOSCOPE_WORKFLOWS_OTEL_CONSOLE_EXPORTER_DST", "file")
    return "stdout" if configured == "stdout" else "file"
140
+
141
+
142
+ def get_otel_console_exporter_file_dst_target_dir() -> str | None:
143
+ """Get console exporter file destination directory from environment.
144
+
145
+ Returns:
146
+ Target directory path or None
147
+ """
148
+ return os.environ.get("ECOSCOPE_WORKFLOWS_OTEL_CONSOLE_EXPORTER_FILE_DST_TARGET_DIR")
149
+
150
+
151
@dataclass
class SpanAttributes:
    """Key/value attributes attached to tracing spans."""

    # Unique identifier of the workflow run being traced.
    workflow_run_id: str
    # Class name of the invoker executing the workflow.
    invoker_type: str
157
+
158
+
159
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan context manager for startup/shutdown.

    On startup, configures the OpenTelemetry tracer from environment
    settings. No teardown work is currently required on shutdown.

    Args:
        app: FastAPI application instance

    Raises:
        RuntimeError: If the console/file exporter is selected but no target
            directory is configured.
    """
    exporter = get_otel_exporter()
    exporter_kws: dict[str, Any] = {}
    if exporter == "console" and get_otel_console_exporter_dst() == "file":
        target_dir = get_otel_console_exporter_file_dst_target_dir()
        if not target_dir:
            raise RuntimeError(
                "If OTEL_EXPORTER is 'console' with the destination as 'file', "
                "then OTEL_CONSOLE_EXPORTER_FILE_DST_TARGET_DIR must be set via the "
                "env var 'ECOSCOPE_WORKFLOWS_OTEL_CONSOLE_EXPORTER_FILE_DST_TARGET_DIR'."
            )
        exporter_kws |= make_otel_console_exporter_file_dst_kws(Path(target_dir))
    configure_tracer(
        name=app.title,
        version=app.version,
        exporter=exporter,
        exporter_kws=exporter_kws,
    )
    yield
    # Nothing to clean up on shutdown.
187
+
188
+
189
# Resolve the installed package version for the OpenAPI metadata; fall back
# to "unknown" when running from a source tree without installed metadata.
try:
    _version = version(TITLE)
except PackageNotFoundError:
    _version = "unknown"

app = FastAPI(title=TITLE, version=_version, lifespan=lifespan)
# Allow cross-origin POSTs from any origin (the only CORS-exposed method);
# NOTE(review): allow_origins=["*"] with allow_credentials=True is permissive —
# confirm this is intentional for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["POST"],
    allow_headers=["*"],
)
# Compress responses larger than 1 KB (result payloads can be sizeable).
app.add_middleware(GZipMiddleware, minimum_size=1000)
203
+
204
+
205
class Lithops(BaseModel):
    """Top-level Lithops executor settings."""

    # Compute backend: local processes or GCP Cloud Run.
    backend: Literal["localhost", "gcp_cloudrun"] = "localhost"
    # Storage backend: local filesystem or Google Cloud Storage.
    storage: Literal["localhost", "gcp_storage"] = "localhost"
    log_level: str = "DEBUG"
    data_limit: int = 256
212
+
213
+
214
class GCP(BaseModel):
    """Google Cloud Platform configuration."""

    region: str = "us-central1"
    # Placeholder default; real value comes from GOOGLE_APPLICATION_CREDENTIALS.
    credentials_path: str = "placeholder"
219
+
220
+
221
class GCPCloudRun(BaseModel):
    """Google Cloud Run runtime settings."""

    # Placeholder default; real value comes from LITHOPS_GCP_CLOUDRUN_RUNTIME.
    runtime: str = "placeholder"
    runtime_cpu: int = 2
    runtime_memory: int = 1000
227
+
228
+
229
class LithopsConfig(BaseModel):
    """Complete Lithops configuration document, serialized to YAML for invokers."""

    lithops: Lithops = Field(default_factory=Lithops)
    gcp: GCP | None = None
    gcp_cloudrun: GCPCloudRun | None = None
235
+
236
+
237
class ResponseModel(BaseModel):
    """Envelope returned by the run endpoint: a result, or error details."""

    result: dict[str, Any] | None = None
    # Error message when the workflow failed.
    error: str | None = None
    # Traceback text accompanying `error`, for debugging.
    trace: str | None = None
243
+
244
+
245
+ @app.get("/", status_code=200)
246
+ def health_check() -> dict[str, str]:
247
+ """Health check endpoint.
248
+
249
+ Returns:
250
+ Status dictionary
251
+ """
252
+ return {"status": "ok"}
253
+
254
+
255
def resolve_matchspec(
    matchspec: str | None = Query(None, description="Matchspec for the workflow."),
) -> MatchSpec:
    """Resolve the workflow matchspec from the query param or env override.

    The env var ``ECOSCOPE_WORKFLOWS_MATCHSPEC_OVERRIDE`` takes precedence
    over the query parameter when set.

    Args:
        matchspec: Rattler matchspec string

    Returns:
        Parsed MatchSpec object

    Raises:
        ValueError: If no matchspec is provided by either source
    """
    override = os.environ.get("ECOSCOPE_WORKFLOWS_MATCHSPEC_OVERRIDE")
    effective = override if override else matchspec
    if not effective:
        raise ValueError("Query param `matchspec` is required.")
    return MatchSpec(effective)
274
+
275
+
276
async def resolve_invoker(
    invoker_type: str = Query("BlockingLocalSubprocessInvoker"),
    matchspec: MatchSpec = Depends(resolve_matchspec),
) -> AbstractInvoker:
    """Instantiate (and install, if needed) the requested invoker.

    Args:
        invoker_type: Registry key identifying the invoker class
        matchspec: Workflow matchspec

    Returns:
        Configured and installed invoker instance

    Raises:
        ValueError: If unknown invoker type specified
    """
    if invoker_type not in INVOKERS:
        raise ValueError(f"Unknown invoker name: {invoker_type}")

    instance = INVOKERS[invoker_type](matchspec=matchspec)
    # Ensure the underlying workflow environment is present before use.
    if not await instance.is_installed():
        await instance.install()
    return instance
301
+
302
+
303
def resolve_results_url(
    results_url: str = Query(..., description="Results URL for the workflow."),
) -> str:
    """Normalize the results location to a URL.

    Absolute local paths are converted to ``file://`` URIs; anything that
    already has a scheme is passed through unchanged.

    Args:
        results_url: URL or local path for results

    Returns:
        Normalized results URL

    Raises:
        ValueError: If given a relative local path
    """
    if urlparse(results_url).scheme:
        return results_url
    local_path = Path(results_url)
    if not local_path.is_absolute():
        raise ValueError("Results URL must be an absolute local path or a URL with scheme.")
    return local_path.as_uri()
323
+
324
+
325
+ @app.post("/", status_code=200, response_model=ResponseModel)
326
+ async def run(
327
+ # service response
328
+ response: Response,
329
+ # user (http) inputs
330
+ params: dict[str, Any],
331
+ execution_mode: Literal["async", "sequential"],
332
+ mock_io: bool,
333
+ results_url: str = Depends(resolve_results_url),
334
+ data_connections_env_vars: dict[str, SecretStr] | None = None,
335
+ lithops_config: LithopsConfig | None = None,
336
+ invoker: AbstractInvoker = Depends(resolve_invoker),
337
+ workflow_run_id: str = Query("", description="Unique ID for the workflow run."),
338
+ timeout: float | None = Query(
339
+ None,
340
+ description="Timeout for the workflow in seconds. Defaults to null; i.e., no timeout.",
341
+ ),
342
+ docker_image_uri: str | None = Query(None, description="Docker image URI for the workflow."),
343
+ traceparent: str | None = Header(
344
+ None,
345
+ description="Traceparent header; Cf. https://www.w3.org/TR/trace-context/.",
346
+ ),
347
+ tracestate: str | None = Header(
348
+ None, description="Tracestate header; Cf. https://www.w3.org/TR/trace-context/."
349
+ ),
350
+ ) -> dict[str, Any] | JSONResponse:
351
+ """Run a workflow with the specified parameters.
352
+
353
+ Args:
354
+ response: FastAPI response object
355
+ params: Workflow parameters
356
+ execution_mode: Execution mode (async or sequential)
357
+ mock_io: Whether to mock I/O operations
358
+ results_url: URL for storing results
359
+ data_connections_env_vars: Environment variables for data connections
360
+ lithops_config: Lithops configuration for async execution
361
+ invoker: Workflow invoker instance
362
+ workflow_run_id: Unique run identifier
363
+ timeout: Timeout in seconds
364
+ docker_image_uri: Docker image URI
365
+ traceparent: W3C traceparent header
366
+ tracestate: W3C tracestate header
367
+
368
+ Returns:
369
+ Workflow execution result
370
+ """
371
+ tracer = otel_trace.get_tracer(__name__)
372
+ if traceparent:
373
+ attach_context(traceparent, tracestate)
374
+ span_attributes = SpanAttributes(
375
+ workflow_run_id=workflow_run_id,
376
+ invoker_type=type(invoker).__name__,
377
+ )
378
+ with tracer.start_as_current_span(
379
+ "run-endpoint",
380
+ attributes=asdict(span_attributes),
381
+ ):
382
+ yaml = ruamel.yaml.YAML(typ="safe")
383
+ extra_env: dict[str, str] = {}
384
+ if data_connections_env_vars:
385
+ extra_env |= {k: v.get_secret_value() for k, v in data_connections_env_vars.items()}
386
+ trace_context = build_context_headers()
387
+ for k, v in trace_context.items():
388
+ extra_env[k.upper()] = str(v)
389
+ config_text_stream = StringIO()
390
+ yaml.dump(params, config_text_stream)
391
+ lithops_kws = {}
392
+ if execution_mode == "async":
393
+ lithops_config = LithopsConfig() if not lithops_config else lithops_config
394
+ lithops_text_stream = StringIO()
395
+ yaml.dump(lithops_config.model_dump(), lithops_text_stream)
396
+ lithops_kws = {"lithops_config_text": lithops_text_stream.getvalue()}
397
+ try:
398
+ await invoker.run(
399
+ workflow_run_id=workflow_run_id,
400
+ config_text=config_text_stream.getvalue(),
401
+ results_url=results_url,
402
+ execution_mode=execution_mode,
403
+ mock_io=mock_io,
404
+ extra_env=extra_env,
405
+ otel_exporter=get_otel_exporter(),
406
+ otel_console_exporter_dst=get_otel_console_exporter_dst(),
407
+ **lithops_kws,
408
+ docker_image_uri=docker_image_uri,
409
+ )
410
+ if invoker.is_waitable:
411
+ await invoker.wait(timeout=timeout, error_msg=TIMEOUT_EXPIRED_ERROR_MSG)
412
+ result = await get_results_json(results_url)
413
+ else:
414
+ result = {"result": {}, "error": None, "trace": None}
415
+ return JSONResponse(content=result, status_code=status.HTTP_202_ACCEPTED)
416
+ except Exception as e:
417
+ trace = traceback.format_exc()
418
+ result = {"error": str(e), "trace": trace}
419
+
420
+ if not isinstance(result, dict):
421
+ raise RuntimeError(f"Unexpected {result = }. Expected dict.")
422
+
423
+ if result.get("result") is None and result.get("error") is not None:
424
+ response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
425
+
426
+ return result
427
+
428
+
429
@app.post(
    "/run-from-pubsub",
    summary="Processes RunWorkflow messages from Pub/Sub",
    status_code=200,
)
async def run_from_pubsub(
    request: Request,
) -> dict[str, Any]:
    """Process RunWorkflow messages from Google Cloud Pub/Sub.

    Note: Requires ecoscope_eda_core to be installed.

    Errors are returned as 200-status payloads (and uploaded to the results
    store) rather than raised, so Pub/Sub does not retry unrecoverable
    messages.

    Args:
        request: FastAPI request object containing Pub/Sub message

    Returns:
        Status dictionary

    Raises:
        HTTPException: If ecoscope_eda_core is not available
    """
    if not HAS_ECOSCOPE:
        raise HTTPException(
            status_code=501,
            detail="Pub/Sub endpoint requires ecoscope_eda_core to be installed",
        )

    try:  # Extract the payload from the PubSub message
        command_payload = await extract_payload_from_pubsub_request(request)
        invoker_params, trace_context = prepare_invoker_parameters(command_payload)
    except (binascii.Error, json.JSONDecodeError, ValueError) as e:
        # handle invalid payload errors to avoid 500 errors,
        # since it doesn't make sense to let GCP retry those
        trace = traceback.format_exc()
        error_msg = f"Error extracting data from PubSub message: {type(e).__name__}: {e}"
        logging.exception(error_msg)
        return {
            "status": "error",
            "error": error_msg,
            "trace": trace,
        }  # Error details are returned for local debugging

    tracer = otel_trace.get_tracer(__name__)
    # Continue the producer's trace when a traceparent was forwarded.
    if trace_context and (traceparent := trace_context.get("traceparent")) is not None:
        attach_context(traceparent, tracestate=trace_context.get("tracestate"))
    span_attributes = SpanAttributes(
        workflow_run_id=invoker_params.get("workflow_run_id", ""),
        invoker_type=command_payload.invoker_type,
    )
    with tracer.start_as_current_span(
        "run-from-pubsub-endpoint",
        attributes=asdict(span_attributes),
    ):
        # Propagate the (now-current) trace context to the invoked process.
        trace_context = build_context_headers()
        invoker_params["extra_env"] |= {k.upper(): v for k, v in trace_context.items()}
        invoker_params["otel_exporter"] = get_otel_exporter()
        invoker_params["otel_console_exporter_dst"] = get_otel_console_exporter_dst()
        try:  # Resolve the invoker
            match_spec_obj = resolve_matchspec(matchspec=command_payload.match_spec)
            invoker = await resolve_invoker(
                invoker_type=command_payload.invoker_type, matchspec=match_spec_obj
            )
        except ValueError as e:
            trace = traceback.format_exc()
            error = {"error": str(e), "trace": trace}
            await upload_error_to_gcs(
                error_details=error, results_url=invoker_params["results_url"]
            )
            return {"status": "error", **error}

        try:
            await invoker.run(**invoker_params)
            if invoker.is_waitable:
                timeout = command_payload.invoker_kwargs.get("timeout", PUBSUB_ACK_MAX_TIMEOUT)
                # Maximum timeout when running from PubSub is 10 minutes.
                # It's set to a little bit less to have time to cancel and handle the error.
                # BUGFIX: previously `min(timeout, max(timeout, PUBSUB_ACK_MAX_TIMEOUT))`,
                # which always evaluated to `timeout` and never enforced the cap.
                timeout = min(timeout, PUBSUB_ACK_MAX_TIMEOUT)
                exit_code = await invoker.wait(timeout=timeout, error_msg=TIMEOUT_EXPIRED_ERROR_MSG)
                if exit_code != 0:
                    raise RuntimeError(f"Workflow invoker failed with exit code {exit_code}.")
        except Exception as e:
            trace = traceback.format_exc()
            error = {"error": f"{type(e).__name__}: {e}", "trace": trace}
            await upload_error_to_gcs(
                error_details=error, results_url=invoker_params["results_url"]
            )
            return {"status": "error", **error}

        return {"status": "processed"}
518
+
519
+
520
async def extract_payload_from_pubsub_request(
    request: Request,
) -> RunWorkflowParams:
    """Decode a Pub/Sub push request into workflow parameters.

    The push envelope carries a base64-encoded JSON ``RunWorkflow`` command
    under ``message.data``.

    Args:
        request: FastAPI request object

    Returns:
        Parsed workflow parameters

    Raises:
        json.JSONDecodeError: If JSON is invalid
        binascii.Error: If base64 decoding fails
    """
    envelope = await request.json()
    message = envelope.get("message", {})
    raw = base64.b64decode(message.get("data", "{}").encode("utf-8"))
    command = RunWorkflow.model_validate(json.loads(raw))
    return command.payload
541
+
542
+
543
def prepare_invoker_parameters(
    command_payload: RunWorkflowParams,
) -> tuple[dict[str, Any], TraceContextHeaders]:
    """Build invoker keyword arguments from a Pub/Sub command payload.

    Pops the well-known keys out of ``invoker_kwargs``; whatever remains is
    forwarded to the invoker untouched.

    Args:
        command_payload: Workflow parameters from Pub/Sub message

    Returns:
        Tuple of (invoker_params, trace_context)
    """
    kwargs = command_payload.invoker_kwargs
    run_id = kwargs.pop("workflow_run_id", "")
    results_url = kwargs.pop("results_url", None)
    workflow_params = kwargs.pop("params", {})
    data_connections_env_vars = kwargs.pop("data_connections_env_vars", {})
    # at minimum, should contain `traceparent`, optionally `tracestate`
    trace_context = kwargs.pop("trace_context", None)
    execution_mode = kwargs.pop("execution_mode", "sequential")
    mock_io = kwargs.pop("mock_io", False)

    # Serialize the workflow params to YAML config text for the invoker CLI.
    yaml = ruamel.yaml.YAML(typ="safe")
    config_stream = StringIO()
    yaml.dump(workflow_params, config_stream)

    lithops_kws: dict[str, Any] = {}
    if execution_mode == "async":
        # Async execution needs a Lithops config; use defaults.
        lithops_stream = StringIO()
        yaml.dump(LithopsConfig().model_dump(), lithops_stream)
        lithops_kws = {"lithops_config_text": lithops_stream.getvalue()}

    invoker_params = {
        "workflow_run_id": run_id,
        "config_text": config_stream.getvalue(),
        "results_url": results_url,
        "execution_mode": execution_mode,
        "mock_io": mock_io,
        "extra_env": data_connections_env_vars,
    }
    # Remaining kwargs are forwarded to the invoker as-is.
    return invoker_params | lithops_kws | kwargs, trace_context
586
+
587
+
588
async def upload_error_to_gcs(error_details: dict[str, Any], results_url: str) -> None:
    """Persist error details as ``result.json`` at the results location.

    Args:
        error_details: Error information dictionary
        results_url: URL for storing results
    """
    store = obstore.store.from_url(results_url)
    payload = json.dumps(error_details).encode("utf-8")
    await store.put_async("result.json", payload)
599
+
600
+
601
+ async def _get_metadata_attribute(
602
+ attr: str,
603
+ invoker: AbstractInvoker,
604
+ ) -> dict[str, Any]:
605
+ """Get a metadata attribute for the workflow.
606
+
607
+ Args:
608
+ attr: Attribute name to retrieve
609
+ invoker: Invoker instance
610
+
611
+ Returns:
612
+ Metadata as dictionary
613
+
614
+ Raises:
615
+ RuntimeError: If attribute retrieval or parsing fails
616
+ """
617
+ out = await invoker.check_output(f"get {attr}".split())
618
+ if not out:
619
+ raise RuntimeError(f"Failed to get {attr}.")
620
+ try:
621
+ as_json: dict[str, Any] = json.loads(out)
622
+ except json.JSONDecodeError as e:
623
+ raise RuntimeError(f"Failed to parse rjsf from str: {out}") from e
624
+ return as_json
625
+
626
+
627
+ @app.get("/rjsf", status_code=200)
628
+ async def rjsf(invoker: AbstractInvoker = Depends(resolve_invoker)) -> dict[str, Any]:
629
+ """Get the React JSON Schema Form schema for the workflow.
630
+
631
+ Args:
632
+ invoker: Invoker instance
633
+
634
+ Returns:
635
+ RJSF schema dictionary
636
+ """
637
+ return await _get_metadata_attribute("rjsf", invoker)
638
+
639
+
640
+ @app.get("/data-connection-property-names", status_code=200)
641
+ async def data_connection_property_names(
642
+ invoker: AbstractInvoker = Depends(resolve_invoker),
643
+ ) -> dict[str, Any]:
644
+ """Get the data connection property names for the workflow.
645
+
646
+ Args:
647
+ invoker: Invoker instance
648
+
649
+ Returns:
650
+ Data connection property names
651
+ """
652
+ return await _get_metadata_attribute("data-connection-property-names", invoker)
653
+
654
+
655
+ async def _convert(
656
+ from_: str,
657
+ to: str,
658
+ json_: str,
659
+ invoker: AbstractInvoker,
660
+ ) -> dict[str, Any] | list[dict[str, Any]]:
661
+ """Convert between params and formdata, and visa-versa.
662
+
663
+ Args:
664
+ from_: Source format
665
+ to: Target format
666
+ json_: JSON string to convert
667
+ invoker: Invoker instance
668
+
669
+ Returns:
670
+ Converted data as dictionary, or list of dicts for validation errors
671
+
672
+ Raises:
673
+ RuntimeError: If conversion or parsing fails
674
+ """
675
+ cmd = f"convert --from {from_} --to {to}"
676
+ out = await invoker.check_output(cmd.split(), stdin=json_)
677
+ if not out:
678
+ raise RuntimeError(f"Failed to convert {from_} to {to} for '{json_}'.")
679
+ try:
680
+ as_json: dict[str, Any] | list[dict[str, Any]] = json.loads(out)
681
+ except json.JSONDecodeError as e:
682
+ raise RuntimeError(f"Failed to parse rjsf from str: {out}") from e
683
+ return as_json
684
+
685
+
686
+ def _is_422(json_: dict[str, Any] | list[dict[str, Any]]) -> bool:
687
+ """Check if the json is a 422 validation error.
688
+
689
+ Args:
690
+ json_: JSON data to check
691
+
692
+ Returns:
693
+ True if data represents a 422 error
694
+ """
695
+ return (
696
+ isinstance(json_, list)
697
+ and len(json_) > 0
698
+ and all(isinstance(e, dict) for e in json_)
699
+ and all(set(e) == {"type", "loc", "msg", "input", "url"} for e in json_)
700
+ )
701
+
702
+
703
+ @app.post("/formdata-to-params", status_code=200)
704
+ async def validate_formdata(
705
+ formdata: dict[str, Any], invoker: AbstractInvoker = Depends(resolve_invoker)
706
+ ) -> dict[str, Any]:
707
+ """Convert and validate form data to workflow parameters.
708
+
709
+ Args:
710
+ formdata: Form data dictionary
711
+ invoker: Invoker instance
712
+
713
+ Returns:
714
+ Validated parameters dictionary
715
+
716
+ Raises:
717
+ HTTPException: If validation fails (422 error)
718
+ """
719
+ outjson = await _convert(
720
+ from_="formdata",
721
+ to="params",
722
+ json_=json.dumps(formdata),
723
+ invoker=invoker,
724
+ )
725
+ if _is_422(outjson):
726
+ raise HTTPException(status_code=422, detail=outjson)
727
+ # At this point, outjson is not a 422 error list, so it's a dict
728
+ assert isinstance(outjson, dict)
729
+ return outjson
730
+
731
+
732
+ @app.post("/params-to-formdata", status_code=200)
733
+ async def generate_nested_params(
734
+ params: dict[str, Any], invoker: AbstractInvoker = Depends(resolve_invoker)
735
+ ) -> dict[str, Any]:
736
+ """Convert workflow parameters to form data format.
737
+
738
+ Args:
739
+ params: Parameters dictionary
740
+ invoker: Invoker instance
741
+
742
+ Returns:
743
+ Form data dictionary
744
+
745
+ Raises:
746
+ HTTPException: If conversion fails (422 error)
747
+ """
748
+ outjson = await _convert(
749
+ from_="params",
750
+ to="formdata",
751
+ json_=json.dumps(params),
752
+ invoker=invoker,
753
+ )
754
+ if _is_422(outjson):
755
+ raise HTTPException(status_code=422, detail=outjson)
756
+ # At this point, outjson is not a 422 error list, so it's a dict
757
+ assert isinstance(outjson, dict)
758
+ return outjson
wt_runner/py.typed ADDED
File without changes
wt_runner/testing.py ADDED
@@ -0,0 +1,140 @@
1
+ """Testing utilities for workflow test cases.
2
+
3
+ Provides Case (Pydantic model) and CaseRunner (dataclass) for running
4
+ workflow test cases via either the FastAPI application or CLI.
5
+ """
6
+
7
+ import asyncio
8
+ import os
9
+ import traceback
10
+ import uuid
11
+ from dataclasses import dataclass
12
+ from io import StringIO
13
+ from pathlib import Path
14
+ from typing import Any, Literal
15
+
16
+ import ruamel.yaml
17
+ from fastapi.testclient import TestClient
18
+ from pydantic import BaseModel
19
+ from rattler import MatchSpec
20
+ from wt_invokers.local import LocalSubprocessInvoker
21
+
22
+ from .app import get_results_json
23
+ from .tracing import OTelConsoleExporterDst, OtelExporterChoice
24
+
25
+
26
class Case(BaseModel):
    """A single workflow test case, typically loaded from `test-cases.yaml`.

    Args:
        name: Human-readable name of the test case.
        description: Description of what the test case covers.
        params: Workflow parameters to pass.
        raises: Whether the test case is expected to raise an error.
        expected_status_code: Expected HTTP status code (default 200).
    """

    name: str
    description: str
    params: dict[str, Any]
    raises: bool = False
    expected_status_code: int = 200
42
+
43
+
44
+ ExecutionMode = Literal["async", "sequential"] # TODO: move to executors module
45
+
46
+
47
@dataclass
class CaseRunner:
    """Run a single test case for a workflow via either the FastAPI application or CLI.

    Args:
        execution_mode: The execution mode to test. One of "async" or "sequential".
        mock_io: Whether or not to mock IO with 3rd party services.
        case: The test case to run. Test cases are defined by the `test-cases.yaml` file.
        results_subdir: The temporary directory to use for the test.
        traceparent: The traceparent header to propagate tracing context. Optional.
        otel_exporter: The OpenTelemetry exporter to use. Optional. One of "console", or "gcp".
        otel_console_exporter_dst: The destination for the console exporter.
            One of "stdout" or "file".
    """

    execution_mode: ExecutionMode
    mock_io: bool
    case: Case
    results_subdir: Path
    traceparent: str | None = None
    otel_exporter: OtelExporterChoice | None = "console"
    otel_console_exporter_dst: OTelConsoleExporterDst = "file"

    def run_app(
        self, app: Any, data_connections_env_vars: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Run a single test case for a workflow via the FastAPI application.

        Asserts that the response status matches the case's expected status
        before returning the body.

        Args:
            app: The fastapi.App instance.
            data_connections_env_vars: Optional environment variables for data connections.

        Returns:
            Response JSON as a dictionary.
        """
        json_ = {
            "params": self.case.params,
            "data_connections_env_vars": data_connections_env_vars or {},
        }
        query_params = {
            "execution_mode": self.execution_mode,
            "mock_io": self.mock_io,
            # Passed as a plain absolute path; the app's `resolve_results_url`
            # dependency converts it to a file:// URI server-side.
            "results_url": self.results_subdir.absolute().as_posix(),
        }
        headers = {"Content-Type": "application/json"}
        if self.traceparent:
            headers["traceparent"] = self.traceparent
        # Use TestClient as a context manager so the app lifespan (tracer
        # configuration) runs before the request.
        with TestClient(app) as client:
            response = client.post("/", json=json_, params=query_params, headers=headers)
            assert response.status_code == self.case.expected_status_code, (
                f"Test failed with {response.status_code = }, "
                f"which differs from {self.case.expected_status_code = }; "
                f"{response.text =}"
            )
            result: dict[str, Any] = response.json()
            return result

    def run_cli(self, matchspec: MatchSpec) -> dict[str, Any]:
        """Run a single test case for a workflow via the CLI.

        Drives a LocalSubprocessInvoker directly (bypassing HTTP) and reads
        the result document back from the results directory.

        Args:
            matchspec: The matchspec of the workflow to run.

        Returns:
            Results dictionary.
        """
        invoker = LocalSubprocessInvoker(matchspec=matchspec, cwd=os.getcwd())
        # Serialize the case params to YAML config text for the invoker CLI.
        yaml = ruamel.yaml.YAML(typ="safe")
        config_text_stream = StringIO()
        yaml.dump(self.case.params, config_text_stream)

        async def _run() -> dict[str, Any]:
            try:
                await invoker.run(
                    workflow_run_id=uuid.uuid4().hex,
                    config_text=config_text_stream.getvalue(),
                    results_url=self.results_subdir.as_uri(),
                    execution_mode=self.execution_mode,
                    mock_io=self.mock_io,
                    extra_env=({"TRACEPARENT": self.traceparent} if self.traceparent else None),
                    otel_exporter=self.otel_exporter,
                    otel_console_exporter_dst=self.otel_console_exporter_dst,
                )
                # Hard 5-minute cap for test runs.
                await invoker.wait(timeout=300)
                result = await get_results_json(self.results_subdir.as_uri())
            except Exception as e:
                # Mirror the app's behavior: surface the error in-band.
                trace = traceback.format_exc()
                result = {"error": str(e), "trace": trace}

            if not isinstance(result, dict):
                raise RuntimeError(f"Unexpected {result = }. Expected dict.")
            return result

        return asyncio.run(_run())
wt_runner/tracing.py ADDED
@@ -0,0 +1,164 @@
1
+ """Basic OpenTelemetry tracing setup for Google Cloud Trace.
2
+
3
+ Note this is adapted from https://github.com/PADAS/cdip-routing.
4
+ """
5
+
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any, Literal, TypedDict
9
+
10
+ from opentelemetry import context, propagate, trace
11
+ from opentelemetry.propagate import set_global_textmap
12
+ from opentelemetry.sdk.resources import Resource
13
+ from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
14
+ from opentelemetry.sdk.trace.export import (
15
+ BatchSpanProcessor,
16
+ ConsoleSpanExporter,
17
+ SpanExporter,
18
+ )
19
+ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
20
+
21
+ # Optional GCP exporter
22
+ try:
23
+ from opentelemetry.exporter.cloud_trace import (
24
+ CloudTraceSpanExporter, # type: ignore[import-not-found,unused-ignore]
25
+ )
26
+
27
+ HAS_GCP_EXPORTER = True
28
+ except ImportError:
29
+ CloudTraceSpanExporter = None # type: ignore[misc,assignment,unused-ignore]
30
+ HAS_GCP_EXPORTER = False
31
+
32
# Span exporter backends accepted by configure_tracer().
OtelExporterChoice = Literal["console", "gcp"]
# Destination choices for the console exporter; presumably consumed by the
# invoker/exporter setup elsewhere — confirm against callers.
OTelConsoleExporterDst = Literal["stdout", "file"]
34
+
35
+
36
def otel_span_formatter(span: ReadableSpan) -> str:
    """Format an OTEL span as an unindented JSON line.

    Args:
        span: The span to format

    Returns:
        Formatted span as compact JSON, terminated by a single "\\n"
    """
    # Use "\n" rather than os.linesep: the console exporter writes through a
    # text-mode file handle (see make_otel_console_exporter_file_dst_kws),
    # and text-mode newline translation would turn os.linesep ("\r\n" on
    # Windows) into "\r\r\n", corrupting the JSONL output.
    result: str = span.to_json(indent=None) + "\n"
    return result
47
+
48
+
49
def make_otel_console_exporter_file_dst_kws(target_dir: Path) -> dict[str, Any]:
    """Build ConsoleSpanExporter kwargs that append traces to a file.

    This opinionated configuration:
    1. Ensures the target directory exists (creating if necessary)
    2. Opens a file `otel_traces.jsonl` in the target directory for appending
    3. Uses line buffering for immediate writes
    4. Uses unindented JSON formatter for easier parsing

    Note the returned dict holds an open file handle; the exporter (or caller)
    owns closing it.

    Args:
        target_dir: Directory to write traces to

    Returns:
        Dictionary of kwargs for ConsoleSpanExporter

    Raises:
        ValueError: If target_dir exists but is not a directory
    """
    if target_dir.exists():
        if not target_dir.is_dir():
            raise ValueError(f"Target dir {target_dir} exists but is not a directory")
    else:
        target_dir.mkdir(parents=True, exist_ok=True)
    # buffering=1 -> line-buffered, so each span line is flushed immediately.
    trace_file = (target_dir / "otel_traces.jsonl").open("a", buffering=1)
    return {"out": trace_file, "formatter": otel_span_formatter}
76
+
77
+
78
def configure_tracer(
    name: str,
    version: str = "",
    exporter: OtelExporterChoice | None = None,
    exporter_kws: dict[str, Any] | None = None,
) -> None:
    """Configure OpenTelemetry tracer with specified exporter.

    Installs a global TracerProvider tagged with the given service name and
    version, optionally attaching a span exporter behind a batch processor.

    Args:
        name: Service name for the tracer
        version: Service version (optional)
        exporter: Type of exporter to use (console or gcp), None for no exporter
        exporter_kws: Additional kwargs for the exporter

    Raises:
        ValueError: If unknown exporter type specified
        RuntimeError: If GCP exporter is requested but not available
    """
    provider = TracerProvider(
        resource=Resource.create(
            {
                "service.name": name,
                "service.version": version,
            }
        )
    )
    if exporter:
        kws = exporter_kws or {}
        span_exporter: SpanExporter
        if exporter == "console":
            span_exporter = ConsoleSpanExporter(**kws)
        elif exporter == "gcp":
            if not HAS_GCP_EXPORTER:
                raise RuntimeError(
                    "GCP exporter requested but opentelemetry-exporter-gcp-trace "
                    "is not installed. Install with: pip install wt-runner[tracing]"
                )
            span_exporter = CloudTraceSpanExporter(**kws)  # type: ignore[no-untyped-call,unused-ignore]
        else:
            raise ValueError(f"Unknown exporter: {exporter}")

        # BatchSpanProcessor buffers spans and sends them in batches in a
        # background thread. The default parameters are sensible, but can be
        # tweaked to optimize your performance
        provider.add_span_processor(BatchSpanProcessor(span_exporter))
    trace.set_tracer_provider(provider)
126
+
127
+
128
class TraceContextHeaders(TypedDict, total=False):
    """W3C Trace Context headers.

    Both keys are optional (``total=False``); either may be absent when there
    is nothing to propagate.

    See: https://www.w3.org/TR/trace-context/
    """

    # W3C "traceparent" header value identifying the parent span.
    traceparent: str
    # W3C "tracestate" header: vendor-specific state carried with the trace.
    tracestate: str
136
+
137
+
138
def build_context_headers() -> TraceContextHeaders:
    """Serialize the active OpenTelemetry context into W3C trace headers.

    Returns:
        Mapping with ``traceparent`` (and optionally ``tracestate``); may be
        empty when there is no active context to propagate.
    """
    carrier: TraceContextHeaders = {}
    propagate.inject(carrier)
    return carrier
147
+
148
+
149
def attach_context(traceparent: str, tracestate: str | None = None) -> None:
    """Make the given W3C trace headers the current OpenTelemetry context.

    Args:
        traceparent: W3C traceparent header value
        tracestate: W3C tracestate header value (optional)
    """
    carrier: dict[str, str] = {"traceparent": traceparent}
    if tracestate:
        carrier["tracestate"] = tracestate
    context.attach(propagate.extract(carrier=carrier))
161
+
162
+
163
# Module-import side effect: install the default W3C Trace Context propagator
# (i.e. `traceparent` header) as the global textmap propagator, so the
# propagate.inject/extract calls above use it.
set_global_textmap(TraceContextTextMapPropagator())
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.4
2
+ Name: wt-runner
3
+ Version: 0.1.3
4
+ Summary: FastAPI application for workflow execution using wt-invokers
5
+ License: BSD-3-Clause
6
+ Classifier: Development Status :: 3 - Alpha
7
+ Classifier: Intended Audience :: Developers
8
+ Classifier: License :: OSI Approved :: BSD License
9
+ Classifier: Programming Language :: Python :: 3
10
+ Classifier: Programming Language :: Python :: 3.13
11
+ Requires-Python: <3.16,>=3.13
12
+ Requires-Dist: wt-contracts<1.0.0,>=0.1.0
13
+ Requires-Dist: wt-invokers<1.0.0,>=0.1.0
14
+ Requires-Dist: fastapi>=0.100.0
15
+ Requires-Dist: uvicorn>=0.20.0
16
+ Requires-Dist: pydantic<3.0.0,>=2.0.0
17
+ Requires-Dist: py-rattler>=0.8.0
18
+ Requires-Dist: ruamel.yaml>=0.18.0
19
+ Requires-Dist: opentelemetry-api>=1.0.0
20
+ Requires-Dist: opentelemetry-sdk>=1.0.0
21
+ Requires-Dist: obstore>=0.6.0
22
+ Provides-Extra: gcp
23
+ Requires-Dist: opentelemetry-sdk<2,>=1.37.0; extra == "gcp"
24
+ Requires-Dist: opentelemetry-exporter-gcp-trace<2,>=1.9.0; extra == "gcp"
25
+ Requires-Dist: gcloud-aio-pubsub<7,>=6.1.0; extra == "gcp"
@@ -0,0 +1,10 @@
1
+ wt_runner/__init__.py,sha256=NKQBohYUXiv4kheM3Cyy0zn9exUJaLVRTgdEP-MZamY,899
2
+ wt_runner/_version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
3
+ wt_runner/app.py,sha256=d_8qZkbZoKE-5ye00LURGkwrY386Z3uTH9KaFt6tI0Y,24467
4
+ wt_runner/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ wt_runner/testing.py,sha256=ec5__H9QFJtzvtO4pkSvAV0giUur6_Xrl9UiMJ6MMSY,5131
6
+ wt_runner/tracing.py,sha256=-mt8MMGo9S74O2tCTXhdCPCw8BcvWSmS4u2RwnzLnAs,5354
7
+ wt_runner-0.1.3.dist-info/METADATA,sha256=EEfkH_8xhcaaGLH2vGmXwnsd_5JWO0F3IHr3MKzMFLE,988
8
+ wt_runner-0.1.3.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
9
+ wt_runner-0.1.3.dist-info/top_level.txt,sha256=ujeMrgee-Be9X1QZegBuCKNWZ2NgYnsHI-VzrxXW70c,10
10
+ wt_runner-0.1.3.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1 @@
1
+ wt_runner