vellum-workflow-server 1.7.10__py3-none-any.whl → 1.9.6.post1__py3-none-any.whl

This diff shows the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vellum-workflow-server
-Version: 1.7.10
+Version: 1.9.6.post1
 Summary:
 License: AGPL
 Requires-Python: >=3.9.0,<4
@@ -29,7 +29,7 @@ Requires-Dist: pyjwt (==2.10.0)
 Requires-Dist: python-dotenv (==1.0.1)
 Requires-Dist: retrying (==1.3.4)
 Requires-Dist: sentry-sdk[flask] (==2.20.0)
-Requires-Dist: vellum-ai (==1.7.10)
+Requires-Dist: vellum-ai (==1.9.6)
 Description-Content-Type: text/markdown

 # Vellum Workflow Runner Server
@@ -1,22 +1,24 @@
 workflow_server/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workflow_server/api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-workflow_server/api/auth_middleware.py,sha256=IlZaCiwZ5nwQqk5sYQorvOFj7lt0p1ZSSEqUxfiFaW0,2458
+workflow_server/api/auth_middleware.py,sha256=qHsMZFYtZyaR7dvkJ6xI6ZYdYw2-2y8X4JyfIQqQTSw,2495
 workflow_server/api/healthz_view.py,sha256=itiRvBDBXncrw8Kbbc73UZLwqMAhgHOR3uSre_dAfgY,404
+workflow_server/api/status_view.py,sha256=Jah8dBAVL4uOcRfsjKAOyfVONFyk9HQjXeRfjcIqhmA,514
 workflow_server/api/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workflow_server/api/tests/test_input_display_mapping.py,sha256=drBZqMudFyB5wgiUOcMgRXz7E7ge-Qgxbstw4E4f0zE,2211
-workflow_server/api/tests/test_workflow_view.py,sha256=_9SHdK1t-wwWILZFrnoYe5NhkoLazJnt7K1oH4bnJU0,30355
-workflow_server/api/tests/test_workflow_view_stream_workflow_route.py,sha256=GjWbwSmjxb691LmKqaCsAItkcb3QJaHw_OG_H5ETWdc,37432
-workflow_server/api/workflow_view.py,sha256=tSepIRP0LcJUKae_wSHGw6P5Ho8K20_YTj3J-9ZL6bs,20463
+workflow_server/api/tests/test_workflow_view.py,sha256=81kAHpijNp0rvb3ZjvceB5uFEriVWPeWHnK78-xoeTc,32343
+workflow_server/api/tests/test_workflow_view_stream_workflow_route.py,sha256=Yrp_DlLbbwZJe5WRLwdlFT17R8CQoCK9-jlQ1jUT_eM,40377
+workflow_server/api/workflow_view.py,sha256=XSVfHYgsy2k_QqTiue9Xx438Z4qprHbD0PeS8JI04sY,24547
 workflow_server/code_exec_runner.py,sha256=DLNNrinCRbnkSvlqVvSZ1wv_etI7r_kKAXNPGMj3jBk,2196
-workflow_server/config.py,sha256=cUdI_lEovV7e7lwCkGJ1eM9R4OZVJw5R5zT1eG1SzQQ,2122
+workflow_server/config.py,sha256=I4hfTsjIbHxoSKylPCjKnrysPV0jO5nfRKwpKvEcfAE,2193
 workflow_server/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workflow_server/core/cancel_workflow.py,sha256=QcEeYUIrxq4pub-z9BlGi5fLI3gVRml-56rMCW7j5Hc,2212
 workflow_server/core/events.py,sha256=24MA66DVQuaLJJcZrS8IL1Zq4Ohi9CoouKZ5VgoH3Cs,1402
-workflow_server/core/executor.py,sha256=rMpVgP3PqMnjHVYeuTdHvbYGwJW5C9O5TFGW926jFM0,16106
+workflow_server/core/executor.py,sha256=lP69l8ATeSe88DOUPIO5mmwq1iuQ-02smw7Tr471wTY,17754
 workflow_server/core/utils.py,sha256=si0NB4Suurc-mn8NYdn59xM9CkPrfOP1aWEVrZvifDI,1929
-workflow_server/core/workflow_executor_context.py,sha256=Q2R0T2KkYZ1z52v8erDMysJfxSODbzPhDxBxX--k4Zw,3202
-workflow_server/server.py,sha256=lhHPmK1PhRZd6eCkj1C0acK3YwaApZgoPHghMChw0fc,1461
-workflow_server/start.py,sha256=xSIobowtSLoZI86bbMkmEw3pqJHQaFdDyNffk4kGYL8,2544
+workflow_server/core/workflow_executor_context.py,sha256=uUlFF2PIfFzIzhHS25mpvO4wO97UWqQVArg7zC2xVcM,3490
+workflow_server/logging_config.py,sha256=Hvx1t8uhqMMinl-5qcef7ufUvzs6x14VRnCb7YZxEAg,1206
+workflow_server/server.py,sha256=pBl0OQmrLE-PbTDwTgsVmxgz_Ai3TVhFRaMnr6PX6Yk,1849
+workflow_server/start.py,sha256=Ams5ycqVbBorC7s6EI95BYzjpxzlo5mQbBnMNOkJS0w,2753
 workflow_server/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 workflow_server/utils/exit_handler.py,sha256=_FacDVi4zc3bfTA3D2mJsISePlJ8jpLrnGVo5-xZQFs,743
 workflow_server/utils/log_proxy.py,sha256=nugi6fOgAYKX2X9DIc39TG366rsmmDUPoEtG3gzma_Y,3088
@@ -28,7 +30,7 @@ workflow_server/utils/tests/test_sentry_integration.py,sha256=14PfuW8AaQNNtqLmBs
 workflow_server/utils/tests/test_system_utils.py,sha256=_4GwXvVvU5BrATxUEWwQIPg0bzQXMWBtiBmjP8MTxJM,4314
 workflow_server/utils/tests/test_utils.py,sha256=0Nq6du8o-iBtTrip9_wgHES53JSiJbVdSXaBnPobw3s,6930
 workflow_server/utils/utils.py,sha256=m7iMJtor5SQLWu7jlJw-X5Q3nmbq69BCxTMv6qnFYrA,4835
-vellum_workflow_server-1.7.10.dist-info/METADATA,sha256=0jfzSrst_oJpWuB5jFawlvAeVZIZWvaBn4t-O2G1M_I,2269
-vellum_workflow_server-1.7.10.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-vellum_workflow_server-1.7.10.dist-info/entry_points.txt,sha256=uB_0yPkr7YV6RhEXzvFReUM8P4OQBlVXD6TN6eb9-oc,277
-vellum_workflow_server-1.7.10.dist-info/RECORD,,
+vellum_workflow_server-1.9.6.post1.dist-info/METADATA,sha256=69SUWrSyCFeBTIuN7NVAhHUTUBaNtUAgz8SrHFlEOag,2273
+vellum_workflow_server-1.9.6.post1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+vellum_workflow_server-1.9.6.post1.dist-info/entry_points.txt,sha256=uB_0yPkr7YV6RhEXzvFReUM8P4OQBlVXD6TN6eb9-oc,277
+vellum_workflow_server-1.9.6.post1.dist-info/RECORD,,
workflow_server/api/auth_middleware.py CHANGED
@@ -5,7 +5,7 @@ from flask import Flask, Request, Response
 import jwt
 from jwt import ExpiredSignatureError
 
-from workflow_server.config import IS_VPC, NAMESPACE, VEMBDA_PUBLIC_KEY, is_development
+from workflow_server.config import IS_ASYNC_MODE, IS_VPC, NAMESPACE, VEMBDA_PUBLIC_KEY, is_development
 
 
 class AuthMiddleware:
@@ -15,7 +15,7 @@ class AuthMiddleware:
     def __call__(self, environ: Dict[str, Any], start_response: Any) -> Any:
        try:
            request = Request(environ)
-            if not request.path.startswith("/healthz") and not is_development() and not IS_VPC:
+            if not request.path.startswith("/healthz") and not is_development() and not IS_VPC and not IS_ASYNC_MODE:
                token = request.headers.get("X-Vembda-Signature")
                if token:
                    decoded = jwt.decode(token, VEMBDA_PUBLIC_KEY, algorithms=["RS256"])
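
With this change, async-mode deployments skip `X-Vembda-Signature` verification entirely. A minimal sketch of the resulting guard, condensed into a standalone predicate (`should_verify_signature` is our illustrative name, not part of the package):

```python
# Hypothetical helper that mirrors the middleware's condition above.
def should_verify_signature(path: str, is_dev: bool, is_vpc: bool, is_async: bool) -> bool:
    # JWT verification runs only when no bypass condition holds: health checks,
    # local development, VPC deployments, and (new in this release) async mode.
    return not path.startswith("/healthz") and not is_dev and not is_vpc and not is_async


assert should_verify_signature("/workflow/stream", False, False, False) is True
assert should_verify_signature("/workflow/stream", False, False, True) is False  # async mode bypasses auth
```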
workflow_server/api/status_view.py ADDED
@@ -0,0 +1,19 @@
+from typing import Tuple
+
+from flask import Blueprint, Response, jsonify
+
+from workflow_server.config import CONCURRENCY
+from workflow_server.utils.system_utils import get_active_process_count
+
+bp = Blueprint("status", __name__)
+
+
+@bp.route("/is_available", methods=["GET"])
+def is_available() -> Tuple[Response, int]:
+    resp = jsonify(
+        available=get_active_process_count() < CONCURRENCY,
+        process_count=get_active_process_count(),
+        max_concurrency=CONCURRENCY,
+    )
+
+    return resp, 200
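
The new blueprint is registered under the `/status` prefix (see the `server.py` change below), so this route is served at `/status/is_available`. A minimal usage sketch, assuming the server is reachable at `localhost:8000` (host and port are assumptions, not taken from this diff):

```python
import requests

resp = requests.get("http://localhost:8000/status/is_available")
data = resp.json()
# Expected shape, per the handler above:
# {"available": <bool>, "process_count": <int>, "max_concurrency": <CONCURRENCY>}
if data["available"]:
    print("server has capacity for another workflow run")
```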
workflow_server/api/tests/test_workflow_view.py CHANGED
@@ -537,6 +537,57 @@ def test_serialize_route__with_invalid_workspace_api_key():
     assert "exec_config" in response.json
 
 
+def test_serialize_route__with_is_new_server_header():
+    """
+    Tests that the serialize route returns the is_new_server header.
+    """
+    # GIVEN a Flask application
+    flask_app = create_app()
+
+    workflow_files = {
+        "__init__.py": "",
+        "workflow.py": (
+            "from vellum.workflows import BaseWorkflow\n\n"
+            "class Workflow(BaseWorkflow):\n"
+            "    class Outputs(BaseWorkflow.Outputs):\n"
+            "        foo = 'hello'\n"
+        ),
+    }
+
+    # WHEN we make a request with is_new_server=True
+    with flask_app.test_client() as test_client:
+        response = test_client.post("/workflow/serialize", json={"files": workflow_files, "is_new_server": True})
+
+    # THEN we should get a successful response
+    assert response.status_code == 200
+
+    # AND the response should contain the is_new_server header set to true
+    assert "X-Vellum-Is-New-Server" in response.headers
+    assert response.headers["X-Vellum-Is-New-Server"] == "true"
+
+    # WHEN we make a request with is_new_server=False
+    with flask_app.test_client() as test_client:
+        response = test_client.post("/workflow/serialize", json={"files": workflow_files, "is_new_server": False})
+
+    # THEN we should get a successful response
+    assert response.status_code == 200
+
+    # AND the response should contain the is_new_server header set to false
+    assert "X-Vellum-Is-New-Server" in response.headers
+    assert response.headers["X-Vellum-Is-New-Server"] == "false"
+
+    # WHEN we make a request without is_new_server
+    with flask_app.test_client() as test_client:
+        response = test_client.post("/workflow/serialize", json={"files": workflow_files})
+
+    # THEN we should get a successful response
+    assert response.status_code == 200
+
+    # AND the response should contain the is_new_server header set to false (default)
+    assert "X-Vellum-Is-New-Server" in response.headers
+    assert response.headers["X-Vellum-Is-New-Server"] == "false"
+
+
 def test_stream_node_route__with_node_id():
     """
     Tests that the stream-node endpoint works with node_id.
workflow_server/api/tests/test_workflow_view_stream_workflow_route.py CHANGED
@@ -5,6 +5,7 @@ import io
 import json
 from queue import Empty
 import re
+import time
 from unittest import mock
 from uuid import uuid4
 
@@ -131,8 +132,11 @@ class Workflow(BaseWorkflow):
         },
     }
 
-    # WHEN we call the stream route
-    status_code, events = both_stream_types(request_body)
+    with mock.patch("builtins.open", mock.mock_open(read_data="104857600")):
+        # WHEN we call the stream route
+        ts_ns = time.time_ns()
+        request_body["vembda_service_initiated_timestamp"] = ts_ns
+        status_code, events = both_stream_types(request_body)
 
     # THEN we get a 200 response
     assert status_code == 200, events
@@ -164,9 +168,37 @@ class Workflow(BaseWorkflow):
     assert isinstance(display_context["workflow_inputs"], dict)
     assert isinstance(display_context["workflow_outputs"], dict)
     assert "foo" in display_context["workflow_outputs"]
+
+    # AND the initiated event should have server_metadata with version info and memory usage
+    assert "server_metadata" in events[1]["body"], events[1]["body"]
+    server_metadata = events[1]["body"]["server_metadata"]
+    assert server_metadata is not None, "server_metadata should not be None"
+    assert "server_version" in server_metadata
+    assert "sdk_version" in server_metadata
+    assert "memory_usage_mb" in server_metadata
+    assert isinstance(server_metadata["memory_usage_mb"], (int, float))
+    assert "is_new_server" in server_metadata
+    assert server_metadata["is_new_server"] is False
+
+    # AND the initiated event should have initiated_latency within a reasonable range
+    assert "initiated_latency" in server_metadata, "initiated_latency should be present in server_metadata"
+    initiated_latency = server_metadata["initiated_latency"]
+    assert isinstance(initiated_latency, int), "initiated_latency should be an integer (nanoseconds)"
+    # Latency should be positive and less than 60 seconds (60_000_000_000 nanoseconds) for CI
+    assert (
+        0 < initiated_latency < 60_000_000_000
+    ), f"initiated_latency should be between 0 and 60 seconds, got {initiated_latency} ns"
+
     assert events[2]["name"] == "workflow.execution.fulfilled", events[2]
     assert events[2]["body"]["workflow_definition"]["module"] == ["test", "workflow"]
 
+    # AND the fulfilled event should have server_metadata with memory usage
+    assert "server_metadata" in events[2]["body"], events[2]["body"]
+    fulfilled_metadata = events[2]["body"]["server_metadata"]
+    assert fulfilled_metadata is not None, "fulfilled server_metadata should not be None"
+    assert "memory_usage_mb" in fulfilled_metadata
+    assert isinstance(fulfilled_metadata["memory_usage_mb"], (int, float))
+
     assert events[3] == {
         "id": mock.ANY,
         "trace_id": events[0]["trace_id"],
@@ -366,9 +398,15 @@ class State(BaseState):
 def test_stream_workflow_route__bad_indent_in_inputs_file(both_stream_types):
     # GIVEN a valid request body
     span_id = uuid4()
+    trace_id = uuid4()
+    parent_span_id = uuid4()
     request_body = {
         "timeout": 360,
         "execution_id": str(span_id),
+        "execution_context": {
+            "trace_id": str(trace_id),
+            "parent_context": {"span_id": str(parent_span_id)},
+        },
         "inputs": [
             {"name": "foo", "type": "STRING", "value": "hello"},
         ],
@@ -405,7 +443,7 @@ from vellum.workflows.inputs import BaseInputs
 
     assert events[0] == {
         "id": mock.ANY,
-        "trace_id": mock.ANY,
+        "trace_id": str(trace_id),
         "span_id": str(span_id),
         "timestamp": mock.ANY,
         "api_version": "2024-10-25",
@@ -419,9 +457,19 @@ from vellum.workflows.inputs import BaseInputs
     }
 
     assert events[1]["name"] == "workflow.execution.initiated"
+    assert events[1]["trace_id"] == str(trace_id), "workflow initiated event should use request trace_id"
+    assert events[1]["parent"] is not None, "workflow initiated event should have parent context"
+    assert events[1]["parent"]["span_id"] == str(
+        parent_span_id
+    ), "workflow initiated event parent should match request parent_context"
 
     assert events[2]["name"] == "workflow.execution.rejected"
+    assert events[2]["trace_id"] == str(trace_id), "workflow rejected event should use request trace_id"
     assert events[2]["span_id"] == events[1]["span_id"]
+    assert events[2]["parent"] is not None, "workflow rejected event should have parent context"
+    assert events[2]["parent"]["span_id"] == str(
+        parent_span_id
+    ), "workflow rejected event parent should match request parent_context"
     assert (
         "Syntax Error raised while loading Workflow: "
         "unexpected indent (inputs.py, line 3)" in events[2]["body"]["error"]["message"]
@@ -653,7 +701,10 @@ class TimeoutWorkflow(BaseWorkflow):
 
     assert "workflow.execution.rejected" in event_names, "Should emit workflow.execution.rejected on timeout"
     workflow_execution_rejected = next(e for e in events if e["name"] == "workflow.execution.rejected")
-    assert "vellum/workflows/runner/runner.py" in workflow_execution_rejected["body"]["stacktrace"]
+    assert workflow_execution_rejected["body"]["error"]["code"] == "WORKFLOW_TIMEOUT"
+    # TODO: Uncomment once version 1.8.1 is released
+    # assert "stacktrace" in workflow_execution_rejected["body"]
+    # assert "vellum/workflows/runner/runner.py" in workflow_execution_rejected["body"]["stacktrace"]
 
     assert "vembda.execution.fulfilled" in event_names
     vembda_fulfilled = next(e for e in events if e["name"] == "vembda.execution.fulfilled")
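
These tests pin down the timestamp contract introduced in this release: the caller stamps the request body with `time.time_ns()`, and the initiated event's `server_metadata` reports `initiated_latency` as the event time minus that stamp, in nanoseconds. A sketch of the arithmetic:

```python
import time

# Stamped into the request body by the caller (as in the test above):
vembda_service_initiated_timestamp = time.time_ns()

# ... later, the server converts the event timestamp to nanoseconds and subtracts:
event_ts_ns = time.time_ns()  # stand-in for the initiated event's timestamp
initiated_latency = event_ts_ns - vembda_service_initiated_timestamp

assert 0 < initiated_latency < 60_000_000_000  # under 60s, matching the test's CI bound
```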
workflow_server/api/workflow_view.py CHANGED
@@ -8,6 +8,7 @@ import os
 import pkgutil
 from queue import Empty
 import sys
+import threading
 import time
 import traceback
 from uuid import uuid4
@@ -71,7 +72,6 @@ WORKFLOW_INITIATION_TIMEOUT_SECONDS = 60
 @bp.route("/stream", methods=["POST"])
 def stream_workflow_route() -> Response:
     data = request.get_json()
-
     try:
         context = WorkflowExecutorContext.model_validate(data)
     except ValidationError as e:
@@ -84,28 +84,7 @@ def stream_workflow_route() -> Response:
             content_type="application/json",
         )
 
-    logger.info(
-        f"Starting Workflow Server Request, trace ID: {context.trace_id}, "
-        f"process count: {get_active_process_count()}, process wrapper: {ENABLE_PROCESS_WRAPPER}"
-    )
-
-    # Create this event up here so timestamps are fully from the start to account for any unknown overhead
-    vembda_initiated_event = VembdaExecutionInitiatedEvent(
-        id=uuid4(),
-        timestamp=datetime.now(),
-        trace_id=context.trace_id,
-        span_id=context.execution_id,
-        body=VembdaExecutionInitiatedBody.model_validate(get_version()),
-        parent=None,
-    )
-
-    process_output_queue: Queue[Union[str, dict]] = Queue()
-
-    headers = {
-        "X-Vellum-SDK-Version": vembda_initiated_event.body.sdk_version,
-        "X-Vellum-Server-Version": vembda_initiated_event.body.server_version,
-        "X-Vellum-Events-Emitted": str(is_events_emitting_enabled(context)),
-    }
+    headers = _get_headers(context)
 
     # We can exceed the concurrency count currently with long running workflows due to a knative issue. So here
     # if we detect a memory problem just exit us early
@@ -122,6 +101,183 @@ def stream_workflow_route() -> Response:
             headers=headers,
         )
 
+    start_workflow_state = _start_workflow(context)
+    if isinstance(start_workflow_state, Response):
+        return start_workflow_state
+
+    workflow_events, vembda_initiated_event, process, span_id, headers = start_workflow_state
+
+    def generator() -> Generator[str, None, None]:
+        try:
+            yield "\n"
+            yield vembda_initiated_event.model_dump_json()
+            yield "\n"
+            for row in workflow_events:
+                yield "\n"
+                if isinstance(row, dict):
+                    dump = json.dumps(row)
+                    yield dump
+                else:
+                    yield row
+            yield "\n"
+            # Sometimes the connections get hung after they finish with the vembda fulfilled event
+            # if it happens during a knative scale down event. So we emit an END string so that
+            # we don't have to do string compares on all the events for performance.
+            yield "\n"
+            yield "END"
+            yield "\n"
+
+            logger.info(
+                f"Workflow stream completed, execution ID: {span_id}, process count: {get_active_process_count()}"
+            )
+        except GeneratorExit:
+            # These can happen either from Vembda disconnects (possibly from predict disconnects) or
+            # from knative activator gateway timeouts which are caused by idleTimeout or responseStartSeconds
+            # being exceeded.
+            app.logger.error(
+                "Client disconnected in the middle of the Workflow Stream",
+                extra={
+                    "sentry_tags": {
+                        "server_version": vembda_initiated_event.body.server_version,
+                        "sdk_version": vembda_initiated_event.body.sdk_version,
+                    }
+                },
+            )
+            return
+        except Exception as e:
+            logger.exception("Error during workflow response stream generator", extra={"error": e})
+            yield "\n"
+            yield "END"
+            yield "\n"
+            return
+        finally:
+            if ENABLE_PROCESS_WRAPPER:
+                try:
+                    if process and process.is_alive():
+                        process.kill()
+                    if process:
+                        increment_process_count(-1)
+                        remove_active_span_id(span_id)
+                except Exception as e:
+                    logger.error("Failed to kill process", e)
+            else:
+                increment_process_count(-1)
+                remove_active_span_id(span_id)
+
+    resp = Response(
+        stream_with_context(generator()),
+        status=200,
+        content_type="application/x-ndjson",
+        headers=headers,
+    )
+    return resp
+
+
+@bp.route("/async-exec", methods=["POST"])
+def async_exec_workflow() -> Response:
+    data = request.get_json()
+    try:
+        context = WorkflowExecutorContext.model_validate(data)
+    except ValidationError as e:
+        error_message = e.errors()[0]["msg"]
+        error_location = e.errors()[0]["loc"]
+
+        # TODO need to convert this to a vembda event so that trigger'd execs can be notified
+        # can either do it here in the workflow server or
+        return Response(
+            json.dumps({"detail": f"Invalid context: {error_message} at {error_location}"}),
+            status=400,
+            content_type="application/json",
+        )
+
+    # Reject back to the queue handler if we're low on memory here, though maybe we should update the is_available
+    # route to look at memory too. Don't send this response as an event. Though we might want some logic to catch
+    # if they have a workflow server that can never start a workflow because the base image uses so much memory.
+    if not wait_for_available_process():
+        return Response(
+            json.dumps({"detail": f"Server resources low. Process count: {get_active_process_count()}"}),
+            status=429,
+            content_type="application/json",
+        )
+
+    def run_workflow_background() -> None:
+        process: Optional[Process] = None
+        span_id: Optional[str] = None
+
+        try:
+            start_workflow_result = _start_workflow(context)
+            if isinstance(start_workflow_result, Response):
+                # TODO same here, should return this response as an event or it will get yeeted to the nether
+                # return start_workflow_result
+                return
+
+            workflow_events, vembda_initiated_event, process, span_id, headers = start_workflow_result
+
+            for _ in workflow_events:
+                # This is way inefficient in process mode since we're just having the main proc stream the events
+                # to nowhere wasting memory I/O and cpu.
+                continue
+            logger.info(
+                f"Workflow async exec completed, execution ID: {span_id}, process count: {get_active_process_count()}"
+            )
+        except Exception as e:
+            logger.exception("Error during workflow async background worker", e)
+        finally:
+            if ENABLE_PROCESS_WRAPPER:
+                try:
+                    if process and process.is_alive():
+                        process.kill()
+                    if process:
+                        increment_process_count(-1)
+                    if span_id:
+                        remove_active_span_id(span_id)
+                except Exception as e:
+                    logger.error("Failed to kill process", e)
+            else:
+                increment_process_count(-1)
+                if span_id:
+                    remove_active_span_id(span_id)
+
+    thread = threading.Thread(target=run_workflow_background)
+    thread.start()
+
+    return Response(
+        json.dumps({"success": True}),
+        status=200,
+        content_type="application/json",
+    )
+
+
+def _start_workflow(
+    context: WorkflowExecutorContext,
+) -> Union[
+    Response,
+    tuple[
+        Iterator[Union[str, dict]],
+        VembdaExecutionInitiatedEvent,
+        Optional[Process],
+        str,
+        dict[str, str],
+    ],
+]:
+    headers = _get_headers(context)
+    logger.info(
+        f"Starting Workflow Server Request, trace ID: {context.trace_id}, "
+        f"process count: {get_active_process_count()}, process wrapper: {ENABLE_PROCESS_WRAPPER}"
+    )
+
+    # Create this event up here so timestamps are fully from the start to account for any unknown overhead
+    version_data = get_version()
+    vembda_initiated_event = VembdaExecutionInitiatedEvent(
+        id=uuid4(),
+        timestamp=datetime.now(),
+        trace_id=context.trace_id,
+        span_id=context.execution_id,
+        body=VembdaExecutionInitiatedBody.model_validate(version_data),
+        parent=None,
+    )
+
+    output_queue: Queue[Union[str, dict]] = Queue()
     cancel_signal = MultiprocessingEvent()
     timeout_signal = MultiprocessingEvent()
 
@@ -130,7 +286,7 @@ def stream_workflow_route() -> Response:
     try:
         process = stream_workflow_process_timeout(
             executor_context=context,
-            queue=process_output_queue,
+            queue=output_queue,
             cancel_signal=cancel_signal,
             timeout_signal=timeout_signal,
         )
@@ -138,10 +294,10 @@ def stream_workflow_route() -> Response:
     except Exception as e:
         logger.exception(e)
 
-        process_output_queue.put(create_vembda_rejected_event(context, traceback.format_exc()))
+        output_queue.put(create_vembda_rejected_event(context, traceback.format_exc()))
 
     try:
-        first_item = process_output_queue.get(timeout=WORKFLOW_INITIATION_TIMEOUT_SECONDS)
+        first_item = output_queue.get(timeout=WORKFLOW_INITIATION_TIMEOUT_SECONDS)
     except Empty:
         logger.error("Request timed out trying to initiate the Workflow")
 
@@ -290,72 +446,9 @@ def stream_workflow_route() -> Response:
                 break
             yield event
 
-    workflow_events = process_events(process_output_queue)
-
-    def generator() -> Generator[str, None, None]:
-        try:
-            yield "\n"
-            yield vembda_initiated_event.model_dump_json()
-            yield "\n"
-            for row in workflow_events:
-                yield "\n"
-                if isinstance(row, dict):
-                    dump = json.dumps(row)
-                    yield dump
-                else:
-                    yield row
-            yield "\n"
-            # Sometimes the connections get hung after they finish with the vembda fulfilled event
-            # if it happens during a knative scale down event. So we emit an END string so that
-            # we don't have to do string compares on all the events for performance.
-            yield "\n"
-            yield "END"
-            yield "\n"
-
-            logger.info(
-                f"Workflow stream completed, execution ID: {span_id}, process count: {get_active_process_count()}"
-            )
-        except GeneratorExit:
-            # These can happen either from Vembda disconnects (possibly from predict disconnects) or
-            # from knative activator gateway timeouts which are caused by idleTimeout or responseStartSeconds
-            # being exceeded.
-            app.logger.error(
-                "Client disconnected in the middle of the Workflow Stream",
-                extra={
-                    "sentry_tags": {
-                        "server_version": vembda_initiated_event.body.server_version,
-                        "sdk_version": vembda_initiated_event.body.sdk_version,
-                    }
-                },
-            )
-            return
-        except Exception as e:
-            logger.exception("Error during workflow response stream generator", extra={"error": e})
-            yield "\n"
-            yield "END"
-            yield "\n"
-            return
-        finally:
-            if ENABLE_PROCESS_WRAPPER:
-                try:
-                    if process and process.is_alive():
-                        process.kill()
-                    if process:
-                        increment_process_count(-1)
-                        remove_active_span_id(span_id)
-                except Exception as e:
-                    logger.error("Failed to kill process", e)
-            else:
-                increment_process_count(-1)
-                remove_active_span_id(span_id)
+    workflow_events = process_events(output_queue)
 
-    resp = Response(
-        stream_with_context(generator()),
-        status=200,
-        content_type="application/x-ndjson",
-        headers=headers,
-    )
-    return resp
+    return workflow_events, vembda_initiated_event, process, span_id, headers
 
 
 @bp.route("/stream-node", methods=["POST"])
@@ -374,12 +467,13 @@ def stream_node_route() -> Response:
     )
 
     # Create this event up here so timestamps are fully from the start to account for any unknown overhead
+    version_data = get_version()
     vembda_initiated_event = VembdaExecutionInitiatedEvent(
         id=uuid4(),
         timestamp=datetime.now(),
         trace_id=context.trace_id,
         span_id=context.execution_id,
-        body=VembdaExecutionInitiatedBody.model_validate(get_version()),
+        body=VembdaExecutionInitiatedBody.model_validate(version_data),
         parent=None,
     )
 
@@ -433,6 +527,7 @@ def serialize_route() -> Response:
 
     files = data.get("files", {})
     workspace_api_key = data.get("workspace_api_key")
+    is_new_server = data.get("is_new_server", False)
 
     if not files:
         return Response(
@@ -446,6 +541,11 @@ def serialize_route() -> Response:
     # Generate a unique namespace for this serialization request
     namespace = get_random_namespace()
     virtual_finder = VirtualFileFinder(files, namespace)
+
+    headers = {
+        "X-Vellum-Is-New-Server": str(is_new_server).lower(),
+    }
+
     try:
         sys.meta_path.append(virtual_finder)
         result = BaseWorkflowDisplay.serialize_module(namespace, client=client, dry_run=True)
@@ -454,6 +554,7 @@ def serialize_route() -> Response:
             json.dumps(result.model_dump()),
             status=200,
             content_type="application/json",
+            headers=headers,
         )
 
     except WorkflowInitializationException as e:
@@ -463,6 +564,7 @@ def serialize_route() -> Response:
             json.dumps({"detail": error_message}),
             status=400,
            content_type="application/json",
+            headers=headers,
        )
 
     except Exception as e:
@@ -471,6 +573,7 @@ def serialize_route() -> Response:
             json.dumps({"detail": f"Serialization failed: {str(e)}"}),
             status=500,
             content_type="application/json",
+            headers=headers,
         )
 
     finally:
@@ -553,3 +656,12 @@ def startup_error_generator(
         },
     )
     return
+
+
+def _get_headers(context: WorkflowExecutorContext) -> dict[str, Union[str, Any]]:
+    headers = {
+        "X-Vellum-SDK-Version": get_version()["sdk_version"],
+        "X-Vellum-Server-Version": get_version()["server_version"],
+        "X-Vellum-Events-Emitted": str(is_events_emitting_enabled(context)),
+    }
+    return headers
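
A hedged sketch of calling the new fire-and-forget endpoint. The body is validated against `WorkflowExecutorContext` just like `/workflow/stream`, so the keys below mirror the stream tests in this diff; the exact required payload and the host/port are assumptions:

```python
import requests

body = {
    "timeout": 360,
    "execution_id": "00000000-0000-0000-0000-000000000000",
    "inputs": [{"name": "foo", "type": "STRING", "value": "hello"}],
    # plus the module files, as in the stream route's request bodies
}
resp = requests.post("http://localhost:8000/workflow/async-exec", json=body)
# 200 {"success": true} once the background thread is started;
# 429 when wait_for_available_process() reports the server is saturated.
print(resp.status_code, resp.json())
```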
workflow_server/config.py CHANGED
@@ -42,6 +42,8 @@ LOCAL_WORKFLOW_MODULE = os.getenv("LOCAL_WORKFLOW_MODULE")
 # The deployment name to match against when using local mode so you can still run your normal workflow
 LOCAL_DEPLOYMENT = os.getenv("LOCAL_DEPLOYMENT")
 
+IS_ASYNC_MODE = os.getenv("IS_ASYNC_MODE", "false").lower() == "true"
+
 
 def is_development() -> bool:
     return os.getenv("FLASK_ENV", "local") == "local"
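
A sketch of the parsing rule: the flag is evaluated once at import time, so it must be present in the environment before the server process starts.

```python
import os

os.environ["IS_ASYNC_MODE"] = "True"  # any casing of "true" enables async mode

is_async_mode = os.getenv("IS_ASYNC_MODE", "false").lower() == "true"
assert is_async_mode is True
```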
workflow_server/core/executor.py CHANGED
@@ -1,5 +1,4 @@
-from datetime import datetime
-import importlib
+from datetime import datetime, timezone
 from io import StringIO
 import json
 import logging
@@ -18,10 +17,12 @@ from vellum_ee.workflows.display.utils.events import event_enricher
 from vellum_ee.workflows.server.virtual_file_loader import VirtualFileFinder
 
 from vellum.workflows import BaseWorkflow
+from vellum.workflows.context import execution_context
 from vellum.workflows.emitters.base import BaseWorkflowEmitter
 from vellum.workflows.emitters.vellum_emitter import VellumEmitter
 from vellum.workflows.events.exception_handling import stream_initialization_exception
 from vellum.workflows.events.types import BaseEvent
+from vellum.workflows.events.workflow import WorkflowEvent
 from vellum.workflows.exceptions import WorkflowInitializationException
 from vellum.workflows.inputs import BaseInputs
 from vellum.workflows.nodes import BaseNode
@@ -30,8 +31,9 @@ from vellum.workflows.resolvers.base import BaseWorkflowResolver
 from vellum.workflows.resolvers.resolver import VellumResolver
 from vellum.workflows.state.context import WorkflowContext
 from vellum.workflows.state.store import EmptyStore
+from vellum.workflows.triggers import BaseTrigger
 from vellum.workflows.types import CancelSignal
-from vellum.workflows.workflows.event_filters import all_workflow_event_filter
+from vellum.workflows.workflows.event_filters import workflow_sandbox_event_filter
 from workflow_server.config import LOCAL_DEPLOYMENT, LOCAL_WORKFLOW_MODULE
 from workflow_server.core.cancel_workflow import CancelWorkflowWatcherThread
 from workflow_server.core.events import (
@@ -51,6 +53,8 @@ from workflow_server.core.workflow_executor_context import (
     WorkflowExecutorContext,
 )
 from workflow_server.utils.log_proxy import redirect_log
+from workflow_server.utils.system_utils import get_memory_in_use_mb
+from workflow_server.utils.utils import get_version
 
 logger = logging.getLogger(__name__)
 
@@ -146,7 +150,21 @@ def stream_workflow(
     cancel_watcher_kill_switch = ThreadingEvent()
     try:
         workflow, namespace = _create_workflow(executor_context)
-        workflow_inputs = _get_workflow_inputs(executor_context, workflow.__class__)
+
+        trigger_id = executor_context.trigger_id
+
+        inputs_or_trigger = workflow.deserialize_trigger(trigger_id=trigger_id, inputs=executor_context.inputs)
+
+        # Determine whether we have inputs or a trigger
+        if isinstance(inputs_or_trigger, BaseInputs):
+            workflow_inputs = inputs_or_trigger
+            trigger = None
+        elif isinstance(inputs_or_trigger, BaseTrigger):
+            workflow_inputs = None
+            trigger = inputs_or_trigger
+        else:
+            workflow_inputs = None
+            trigger = None
 
         workflow_state = (
             workflow.deserialize_state(
@@ -167,14 +185,22 @@ def stream_workflow(
             inputs=workflow_inputs,
             state=workflow_state,
             node_output_mocks=node_output_mocks,
-            event_filter=all_workflow_event_filter,
+            event_filter=workflow_sandbox_event_filter,
             cancel_signal=cancel_signal,
             entrypoint_nodes=[executor_context.node_id] if executor_context.node_id else None,
             previous_execution_id=executor_context.previous_execution_id,
+            timeout=executor_context.timeout,
+            trigger=trigger,
+            execution_id=executor_context.workflow_span_id,
         )
     except WorkflowInitializationException as e:
         cancel_watcher_kill_switch.set()
-        initialization_exception_stream = stream_initialization_exception(e)
+
+        with execution_context(
+            parent_context=executor_context.execution_context.parent_context,
+            trace_id=executor_context.execution_context.trace_id,
+        ):
+            initialization_exception_stream = stream_initialization_exception(e)
 
         def _stream_generator() -> Generator[dict[str, Any], Any, None]:
             for event in initialization_exception_stream:
@@ -401,10 +427,55 @@ def get_random_namespace() -> str:
     return "workflow_tmp_" + "".join(random.choice(string.ascii_letters + string.digits) for i in range(14))
 
 
+def _enrich_event(event: WorkflowEvent, executor_context: Optional[BaseExecutorContext] = None) -> WorkflowEvent:
+    """
+    Enrich an event with metadata based on the event type.
+
+    For initiated events, include server and SDK versions.
+    For fulfilled events with WORKFLOW_DEPLOYMENT parent, include memory usage.
+    """
+    metadata: Optional[dict] = None
+
+    try:
+        is_deployment = event.parent and event.parent.type in ("WORKFLOW_DEPLOYMENT", "EXTERNAL")
+
+        if event.name == "workflow.execution.initiated" and is_deployment:
+            metadata = {
+                **get_version(),
+            }
+
+            memory_mb = get_memory_in_use_mb()
+            if memory_mb is not None:
+                metadata["memory_usage_mb"] = memory_mb
+
+            if executor_context is not None:
+                metadata["is_new_server"] = executor_context.is_new_server
+
+                if executor_context.vembda_service_initiated_timestamp is not None and event.timestamp is not None:
+                    event_ts = event.timestamp
+                    if event_ts.tzinfo is None:
+                        event_ts = event_ts.replace(tzinfo=timezone.utc)
+                    event_ts_ns = int(event_ts.timestamp() * 1_000_000_000)
+                    initiated_latency = event_ts_ns - executor_context.vembda_service_initiated_timestamp
+                    metadata["initiated_latency"] = initiated_latency
+        elif event.name == "workflow.execution.fulfilled" and is_deployment:
+            metadata = {}
+            memory_mb = get_memory_in_use_mb()
+            if memory_mb is not None:
+                metadata["memory_usage_mb"] = memory_mb
+    except Exception:
+        pass
+
+    vellum_client = executor_context.vellum_client if executor_context else None
+    return event_enricher(event, vellum_client, metadata=metadata)
+
+
 def _dump_event(event: BaseEvent, executor_context: BaseExecutorContext) -> dict:
     module_base = executor_context.module.split(".")
+
     dump = event.model_dump(
-        mode="json", context={"event_enricher": lambda event: event_enricher(event, executor_context.vellum_client)}
+        mode="json",
+        context={"event_enricher": lambda event: _enrich_event(event, executor_context)},
     )
     if dump["name"] in {
         "workflow.execution.initiated",
@@ -426,38 +497,3 @@ def _dump_event(event: BaseEvent, executor_context: BaseExecutorContext) -> dict
         dump["body"]["node_definition"]["module"] = module_base + dump["body"]["node_definition"]["module"][1:]
 
     return dump
-
-
-def _get_workflow_inputs(
-    executor_context: BaseExecutorContext, workflow_class: Type[BaseWorkflow]
-) -> Optional[BaseInputs]:
-    if not executor_context.inputs:
-        return None
-
-    if not executor_context.files.get("inputs.py"):
-        return None
-
-    namespace = _get_file_namespace(executor_context)
-    inputs_module_path = f"{namespace}.inputs"
-    try:
-        inputs_module = importlib.import_module(inputs_module_path)
-    except Exception as e:
-        raise WorkflowInitializationException(
-            message=f"Failed to initialize workflow inputs: {e}",
-            workflow_definition=workflow_class,
-        ) from e
-
-    if not hasattr(inputs_module, "Inputs"):
-        raise WorkflowInitializationException(
-            message=f"Inputs module {inputs_module_path} does not have a required Inputs class",
-            workflow_definition=workflow_class,
-        )
-
-    if not issubclass(inputs_module.Inputs, BaseInputs):
-        raise WorkflowInitializationException(
-            message=f"""The class {inputs_module_path}.Inputs was expected to be a subclass of BaseInputs, \
-but found {inputs_module.Inputs.__class__.__name__}""",
-            workflow_definition=workflow_class,
-        )
-
-    return inputs_module.Inputs(**executor_context.inputs)
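
For reference, an illustrative shape of the `server_metadata` that `_enrich_event` attaches; the keys come from the function above, while the concrete values here are made up:

```python
initiated_metadata = {
    "server_version": "1.9.6.post1",      # from get_version()
    "sdk_version": "1.9.6",
    "memory_usage_mb": 512.3,             # from get_memory_in_use_mb()
    "is_new_server": False,
    "initiated_latency": 1_250_000_000,   # nanoseconds since the vembda stamp
}

fulfilled_metadata = {"memory_usage_mb": 530.1}  # fulfilled events carry only memory usage
```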
workflow_server/core/workflow_executor_context.py CHANGED
@@ -35,6 +35,12 @@ class BaseExecutorContext(UniversalBaseModel):
     environment_variables: Optional[dict[str, str]] = None
     previous_execution_id: Optional[UUID] = None
     feature_flags: Optional[dict[str, bool]] = None
+    is_new_server: bool = False
+    trigger_id: Optional[UUID] = None
+    # The actual 'execution id' of the workflow that we pass into the workflow
+    # when running in async mode.
+    workflow_span_id: Optional[UUID] = None
+    vembda_service_initiated_timestamp: Optional[int] = None
 
     @field_validator("inputs", mode="before")
     @classmethod
workflow_server/logging_config.py ADDED
@@ -0,0 +1,39 @@
+from datetime import datetime
+import json
+import logging
+
+
+class GCPJsonFormatter(logging.Formatter):
+    """
+    Custom JSON formatter for Google Cloud Platform logging.
+
+    Outputs logs in JSON format with a 'severity' field that GCP Cloud Logging
+    can properly parse. This ensures INFO logs show up as INFO in GCP instead of ERROR.
+
+    See: https://cloud.google.com/logging/docs/structured-logging
+    """
+
+    SEVERITY_MAP = {
+        "DEBUG": "DEBUG",
+        "INFO": "INFO",
+        "WARNING": "WARNING",
+        "ERROR": "ERROR",
+        "CRITICAL": "CRITICAL",
+    }
+
+    def format(self, record: logging.LogRecord) -> str:
+        log_obj = {
+            "severity": self.SEVERITY_MAP.get(record.levelname, "DEFAULT"),
+            "message": record.getMessage(),
+            "timestamp": datetime.utcfromtimestamp(record.created).isoformat() + "Z",
+            "logging.googleapis.com/sourceLocation": {
+                "file": record.pathname,
+                "line": record.lineno,
+                "function": record.funcName,
+            },
+        }
+
+        if record.exc_info:
+            log_obj["exception"] = self.formatException(record.exc_info)
+
+        return json.dumps(log_obj)
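
A quick, self-contained way to see the formatter in action; the printed JSON below is illustrative and trimmed:

```python
import logging

from workflow_server.logging_config import GCPJsonFormatter

handler = logging.StreamHandler()
handler.setFormatter(GCPJsonFormatter())
logger = logging.getLogger("demo")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.info("workflow started")
# {"severity": "INFO", "message": "workflow started", "timestamp": "...Z",
#  "logging.googleapis.com/sourceLocation": {"file": "...", "line": ..., "function": "<module>"}}
```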
workflow_server/server.py CHANGED
@@ -5,8 +5,10 @@ from flask import Flask
 
 from workflow_server.api.auth_middleware import AuthMiddleware
 from workflow_server.api.healthz_view import bp as healthz_bp
+from workflow_server.api.status_view import bp as status_bp
 from workflow_server.api.workflow_view import bp as workflow_bp
 from workflow_server.config import is_development
+from workflow_server.logging_config import GCPJsonFormatter
 from workflow_server.utils.sentry import init_sentry
 from workflow_server.utils.utils import get_version
 
@@ -14,11 +16,18 @@ from workflow_server.utils.utils import get_version
 # enable_log_proxy()
 
 logger = logging.getLogger(__name__)
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S",
-)
+
+if is_development():
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+else:
+    handler = logging.StreamHandler()
+    handler.setFormatter(GCPJsonFormatter())
+    logging.root.addHandler(handler)
+    logging.root.setLevel(logging.INFO)
 
 
 def create_app() -> Flask:
@@ -40,6 +49,7 @@ def create_app() -> Flask:
 
     # Register blueprints
     app.register_blueprint(healthz_bp, url_prefix="/healthz")
+    app.register_blueprint(status_bp, url_prefix="/status")
     app.register_blueprint(workflow_bp, url_prefix="/workflow")
 
     logger.info(is_development())
workflow_server/start.py CHANGED
@@ -33,6 +33,7 @@ class CustomGunicornLogger(glogging.Logger):
         logger = logging.getLogger("gunicorn.access")
         logger.addFilter(HealthCheckFilter())
         logger.addFilter(SignalFilter())
+        logger.addFilter(StatusIsAvailableFilter())
 
 
 class HealthCheckFilter(logging.Filter):
@@ -45,6 +46,11 @@ class SignalFilter(logging.Filter):
         return "SIGTERM" not in record.getMessage()
 
 
+class StatusIsAvailableFilter(logging.Filter):
+    def filter(self, record: Any) -> bool:
+        return "/status/is_available" not in record.getMessage()
+
+
 def start() -> None:
     if not is_development():
         start_oom_killer_worker()