stageflow-core 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stageflow/__init__.py +297 -0
- stageflow/agent/__init__.py +19 -0
- stageflow/auth/__init__.py +56 -0
- stageflow/auth/context.py +119 -0
- stageflow/auth/errors.py +87 -0
- stageflow/auth/events.py +118 -0
- stageflow/auth/interceptors.py +340 -0
- stageflow/auth/tenant.py +317 -0
- stageflow/cli/__init__.py +19 -0
- stageflow/cli/__main__.py +130 -0
- stageflow/cli/lint.py +492 -0
- stageflow/context/__init__.py +50 -0
- stageflow/context/bag.py +131 -0
- stageflow/context/context_snapshot.py +453 -0
- stageflow/context/conversation.py +137 -0
- stageflow/context/enrichments.py +196 -0
- stageflow/context/extensions.py +125 -0
- stageflow/context/identity.py +112 -0
- stageflow/context/output_bag.py +329 -0
- stageflow/context/types.py +27 -0
- stageflow/core/__init__.py +19 -0
- stageflow/core/stage_context.py +136 -0
- stageflow/core/stage_enums.py +28 -0
- stageflow/core/stage_output.py +115 -0
- stageflow/core/stage_protocol.py +40 -0
- stageflow/core/timer.py +44 -0
- stageflow/events/__init__.py +42 -0
- stageflow/events/sink.py +319 -0
- stageflow/extensions.py +161 -0
- stageflow/helpers/__init__.py +104 -0
- stageflow/helpers/analytics.py +548 -0
- stageflow/helpers/guardrails.py +551 -0
- stageflow/helpers/memory.py +326 -0
- stageflow/helpers/mocks.py +843 -0
- stageflow/helpers/providers.py +230 -0
- stageflow/helpers/run_utils.py +582 -0
- stageflow/helpers/streaming.py +651 -0
- stageflow/observability/__init__.py +239 -0
- stageflow/observability/tracing.py +466 -0
- stageflow/observability/wide_events.py +191 -0
- stageflow/pipeline/__init__.py +67 -0
- stageflow/pipeline/builder.py +336 -0
- stageflow/pipeline/cancellation.py +332 -0
- stageflow/pipeline/dag.py +735 -0
- stageflow/pipeline/interceptors.py +556 -0
- stageflow/pipeline/interfaces.py +182 -0
- stageflow/pipeline/pipeline.py +173 -0
- stageflow/pipeline/registry.py +88 -0
- stageflow/pipeline/spec.py +118 -0
- stageflow/pipeline/subpipeline.py +682 -0
- stageflow/projector/__init__.py +18 -0
- stageflow/projector/service.py +135 -0
- stageflow/protocols.py +272 -0
- stageflow/py.typed +0 -0
- stageflow/stages/__init__.py +32 -0
- stageflow/stages/context.py +329 -0
- stageflow/stages/errors.py +384 -0
- stageflow/stages/inputs.py +308 -0
- stageflow/stages/ports.py +176 -0
- stageflow/stages/result.py +41 -0
- stageflow/testing.py +402 -0
- stageflow/tools/__init__.py +140 -0
- stageflow/tools/adapters.py +112 -0
- stageflow/tools/approval.py +337 -0
- stageflow/tools/base.py +106 -0
- stageflow/tools/definitions.py +264 -0
- stageflow/tools/diff.py +488 -0
- stageflow/tools/errors.py +176 -0
- stageflow/tools/events.py +207 -0
- stageflow/tools/executor.py +127 -0
- stageflow/tools/executor_v2.py +558 -0
- stageflow/tools/registry.py +302 -0
- stageflow/tools/undo.py +183 -0
- stageflow/utils/__init__.py +8 -0
- stageflow/utils/frozen.py +130 -0
- stageflow_core-0.2.0.dist-info/METADATA +252 -0
- stageflow_core-0.2.0.dist-info/RECORD +78 -0
- stageflow_core-0.2.0.dist-info/WHEEL +4 -0
stageflow/__init__.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
1
|
+
"""Stageflow - DAG-based pipeline orchestration framework.
|
|
2
|
+
|
|
3
|
+
This package provides a framework for building observable, composable
|
|
4
|
+
stage pipelines with parallel execution, cancellation, and interceptors.
|
|
5
|
+
|
|
6
|
+
Core Components:
|
|
7
|
+
- Stage: Protocol for pipeline stage implementations
|
|
8
|
+
- Pipeline: Fluent builder for composing stages into DAGs
|
|
9
|
+
- StageGraph: DAG executor with parallel execution
|
|
10
|
+
- Interceptors: Middleware for cross-cutting concerns
|
|
11
|
+
- EventSink: Protocol for event persistence
|
|
12
|
+
|
|
13
|
+
Stage Kinds:
|
|
14
|
+
- TRANSFORM: Data transformation stages (STT, TTS, LLM)
|
|
15
|
+
- ENRICH: Context enrichment stages (Profile, Memory)
|
|
16
|
+
- ROUTE: Routing decision stages (Router)
|
|
17
|
+
- GUARD: Guardrail/validation stages
|
|
18
|
+
- WORK: Side-effect stages (Persist, Assessment)
|
|
19
|
+
- AGENT: Agentic/coaching stages
|
|
20
|
+
|
|
21
|
+
Example:
|
|
22
|
+
from stageflow import Pipeline, Stage, StageOutput, StageKind
|
|
23
|
+
|
|
24
|
+
class MyStage:
|
|
25
|
+
name = "my_stage"
|
|
26
|
+
kind = StageKind.TRANSFORM
|
|
27
|
+
|
|
28
|
+
async def execute(self, ctx):
|
|
29
|
+
return StageOutput.ok(result="done")
|
|
30
|
+
|
|
31
|
+
pipeline = Pipeline().with_stage("my", MyStage, StageKind.TRANSFORM)
|
|
32
|
+
graph = pipeline.build()
|
|
33
|
+
results = await graph.run(ctx)
|
|
34
|
+
|
|
35
|
+
Extension System:
|
|
36
|
+
Stageflow provides a generic extension system for application-specific data.
|
|
37
|
+
Use ContextSnapshot.extensions dict to store application data:
|
|
38
|
+
|
|
39
|
+
snapshot = ContextSnapshot(
|
|
40
|
+
...
|
|
41
|
+
extensions={"skills": {"active_skill_ids": ["python"]}}
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
For type-safe extensions, use the ExtensionRegistry in stageflow.extensions.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
# Core stage types
|
|
48
|
+
# CLI and linting
|
|
49
|
+
from stageflow.cli.lint import (
|
|
50
|
+
DependencyIssue,
|
|
51
|
+
DependencyLintResult,
|
|
52
|
+
IssueSeverity,
|
|
53
|
+
lint_pipeline,
|
|
54
|
+
lint_pipeline_file,
|
|
55
|
+
)
|
|
56
|
+
from stageflow.core import (
|
|
57
|
+
PipelineTimer,
|
|
58
|
+
Stage,
|
|
59
|
+
StageArtifact,
|
|
60
|
+
StageContext,
|
|
61
|
+
StageEvent,
|
|
62
|
+
StageKind,
|
|
63
|
+
StageOutput,
|
|
64
|
+
StageStatus,
|
|
65
|
+
create_stage_context,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# Events
|
|
69
|
+
from stageflow.events import (
|
|
70
|
+
EventSink,
|
|
71
|
+
LoggingEventSink,
|
|
72
|
+
NoOpEventSink,
|
|
73
|
+
clear_event_sink,
|
|
74
|
+
get_event_sink,
|
|
75
|
+
set_event_sink,
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
# Extensions
|
|
79
|
+
from stageflow.extensions import (
|
|
80
|
+
ExtensionHelper,
|
|
81
|
+
ExtensionRegistry,
|
|
82
|
+
TypedExtension,
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# Observability protocols
|
|
86
|
+
from stageflow.observability import (
|
|
87
|
+
CircuitBreaker,
|
|
88
|
+
CircuitBreakerOpenError,
|
|
89
|
+
PipelineRunLogger,
|
|
90
|
+
ProviderCallLogger,
|
|
91
|
+
error_summary_to_stages_patch,
|
|
92
|
+
error_summary_to_string,
|
|
93
|
+
get_circuit_breaker,
|
|
94
|
+
summarize_pipeline_error,
|
|
95
|
+
)
|
|
96
|
+
from stageflow.pipeline.dag import (
|
|
97
|
+
StageExecutionError,
|
|
98
|
+
StageGraph,
|
|
99
|
+
StageSpec,
|
|
100
|
+
)
|
|
101
|
+
|
|
102
|
+
# Interceptors
|
|
103
|
+
from stageflow.pipeline.interceptors import (
|
|
104
|
+
BaseInterceptor,
|
|
105
|
+
ChildTrackerMetricsInterceptor,
|
|
106
|
+
CircuitBreakerInterceptor,
|
|
107
|
+
ErrorAction,
|
|
108
|
+
InterceptorContext,
|
|
109
|
+
InterceptorResult,
|
|
110
|
+
LoggingInterceptor,
|
|
111
|
+
MetricsInterceptor,
|
|
112
|
+
TimeoutInterceptor,
|
|
113
|
+
TracingInterceptor,
|
|
114
|
+
get_default_interceptors,
|
|
115
|
+
run_with_interceptors,
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
# Pipeline types
|
|
119
|
+
from stageflow.pipeline.pipeline import (
|
|
120
|
+
Pipeline,
|
|
121
|
+
UnifiedStageSpec,
|
|
122
|
+
)
|
|
123
|
+
from stageflow.pipeline.registry import (
|
|
124
|
+
PipelineRegistry,
|
|
125
|
+
pipeline_registry,
|
|
126
|
+
)
|
|
127
|
+
from stageflow.pipeline.spec import (
|
|
128
|
+
CycleDetectedError,
|
|
129
|
+
PipelineValidationError,
|
|
130
|
+
)
|
|
131
|
+
from stageflow.pipeline.subpipeline import (
|
|
132
|
+
ChildRunTracker,
|
|
133
|
+
MaxDepthExceededError,
|
|
134
|
+
SubpipelineResult,
|
|
135
|
+
SubpipelineSpawner,
|
|
136
|
+
get_child_tracker,
|
|
137
|
+
get_subpipeline_spawner,
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
# Projector
|
|
141
|
+
from stageflow.projector.service import (
|
|
142
|
+
WSMessageProjector,
|
|
143
|
+
WSMetadata,
|
|
144
|
+
WSOutboundMessage,
|
|
145
|
+
WSStatusUpdatePayload,
|
|
146
|
+
_coerce_uuid_str,
|
|
147
|
+
)
|
|
148
|
+
from stageflow.projector.service import (
|
|
149
|
+
WSMessageProjector as ProjectorService, # Backward compatibility
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
# Protocols
|
|
153
|
+
from stageflow.protocols import (
|
|
154
|
+
ConfigProvider,
|
|
155
|
+
CorrelationIds,
|
|
156
|
+
RunStore,
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
# Context types
|
|
160
|
+
from stageflow.stages.context import (
|
|
161
|
+
PipelineContext,
|
|
162
|
+
extract_service,
|
|
163
|
+
)
|
|
164
|
+
from stageflow.stages.inputs import (
|
|
165
|
+
StageInputs,
|
|
166
|
+
create_stage_inputs,
|
|
167
|
+
)
|
|
168
|
+
from stageflow.stages.ports import (
|
|
169
|
+
AudioPorts,
|
|
170
|
+
CorePorts,
|
|
171
|
+
LLMPorts,
|
|
172
|
+
create_audio_ports,
|
|
173
|
+
create_core_ports,
|
|
174
|
+
create_llm_ports,
|
|
175
|
+
)
|
|
176
|
+
from stageflow.stages.result import (
|
|
177
|
+
StageError,
|
|
178
|
+
StageResult,
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
# Testing utilities (optional)
|
|
182
|
+
from stageflow.testing import (
|
|
183
|
+
create_test_snapshot,
|
|
184
|
+
create_test_stage_context,
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
# Public API of the ``stageflow`` package, i.e. the names bound by
# ``from stageflow import *``.
#
# Every entry must be a name that this module actually imports or defines.
# ``import *`` raises AttributeError on the first listed name that is missing.
# Names such as LinearPipeline, create_test_context,
# create_test_pipeline_context, WideEventEmitter, emit_stage_wide_event and
# emit_pipeline_wide_event are not imported here, so they are not exported.
# Import them above before listing them. Each name appears exactly once.
__all__ = [
    # Core stage types
    "Stage",
    "StageKind",
    "StageStatus",
    "StageOutput",
    "StageContext",
    "StageArtifact",
    "StageEvent",
    "StageError",
    "StageResult",
    # Context utilities
    "create_stage_context",
    # Timer
    "PipelineTimer",
    # Pipeline types
    "Pipeline",
    "UnifiedStageSpec",
    # DAG types
    "StageExecutionError",
    "StageGraph",
    "StageSpec",
    # Registry
    "PipelineRegistry",
    "pipeline_registry",
    # Context types
    "PipelineContext",
    "extract_service",
    # Testing utilities
    "create_test_snapshot",
    "create_test_stage_context",
    # Interceptors
    "BaseInterceptor",
    "InterceptorResult",
    "InterceptorContext",
    "ErrorAction",
    "LoggingInterceptor",
    "MetricsInterceptor",
    "ChildTrackerMetricsInterceptor",
    "TracingInterceptor",
    "CircuitBreakerInterceptor",
    "TimeoutInterceptor",
    "get_default_interceptors",
    "run_with_interceptors",
    # Events
    "EventSink",
    "NoOpEventSink",
    "LoggingEventSink",
    "get_event_sink",
    "set_event_sink",
    "clear_event_sink",
    # Protocols
    "RunStore",
    "ConfigProvider",
    "CorrelationIds",
    # Observability
    "CircuitBreaker",
    "CircuitBreakerOpenError",
    "PipelineRunLogger",
    "ProviderCallLogger",
    "summarize_pipeline_error",
    "error_summary_to_string",
    "error_summary_to_stages_patch",
    "get_circuit_breaker",
    # Extensions
    "ExtensionRegistry",
    "ExtensionHelper",
    "TypedExtension",
    # Stage inputs/ports
    "StageInputs",
    "create_stage_inputs",
    "CorePorts",
    "LLMPorts",
    "AudioPorts",
    "create_core_ports",
    "create_llm_ports",
    "create_audio_ports",
    # Pipeline validation
    "CycleDetectedError",
    "PipelineValidationError",
    # Subpipeline
    "SubpipelineSpawner",
    "SubpipelineResult",
    "ChildRunTracker",
    "MaxDepthExceededError",
    "get_child_tracker",
    "get_subpipeline_spawner",
    # CLI and linting
    "DependencyIssue",
    "DependencyLintResult",
    "IssueSeverity",
    "lint_pipeline",
    "lint_pipeline_file",
    # Projector
    "WSMessageProjector",
    "ProjectorService",
    "WSMetadata",
    "WSOutboundMessage",
    "WSStatusUpdatePayload",
    # Underscore-prefixed helper. It must be listed explicitly,
    # because ``import *`` skips private names otherwise.
    "_coerce_uuid_str",
]
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Agent package containing ContextSnapshot and related types."""
|
|
2
|
+
|
|
3
|
+
from stageflow.context import (
|
|
4
|
+
ContextSnapshot,
|
|
5
|
+
DocumentEnrichment,
|
|
6
|
+
MemoryEnrichment,
|
|
7
|
+
Message,
|
|
8
|
+
ProfileEnrichment,
|
|
9
|
+
RoutingDecision,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
# Public re-exports: the context types this package imports from
# ``stageflow.context``.
__all__ = [
    "ContextSnapshot",
    "Message",
    "RoutingDecision",
    "ProfileEnrichment",
    "MemoryEnrichment",
    "DocumentEnrichment",
]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Stageflow auth module - authentication and authorization types."""
|
|
2
|
+
|
|
3
|
+
from stageflow.auth.context import AuthContext, OrgContext
|
|
4
|
+
from stageflow.auth.errors import (
|
|
5
|
+
AuthenticationError,
|
|
6
|
+
CrossTenantAccessError,
|
|
7
|
+
InvalidTokenError,
|
|
8
|
+
MissingClaimsError,
|
|
9
|
+
TokenExpiredError,
|
|
10
|
+
)
|
|
11
|
+
from stageflow.auth.events import (
|
|
12
|
+
AuthFailureEvent,
|
|
13
|
+
AuthLoginEvent,
|
|
14
|
+
TenantAccessDeniedEvent,
|
|
15
|
+
)
|
|
16
|
+
from stageflow.auth.interceptors import (
|
|
17
|
+
AuthInterceptor,
|
|
18
|
+
JwtValidator,
|
|
19
|
+
MockJwtValidator,
|
|
20
|
+
OrgEnforcementInterceptor,
|
|
21
|
+
)
|
|
22
|
+
from stageflow.auth.tenant import (
|
|
23
|
+
TenantAwareLogger,
|
|
24
|
+
TenantContext,
|
|
25
|
+
TenantIsolationError,
|
|
26
|
+
TenantIsolationValidator,
|
|
27
|
+
clear_current_tenant,
|
|
28
|
+
get_current_tenant,
|
|
29
|
+
require_tenant,
|
|
30
|
+
set_current_tenant,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
# Public API of ``stageflow.auth``.
# These names are re-exported from the context, errors, events,
# interceptors and tenant submodules imported by this package.
__all__ = [
    "AuthContext",
    "AuthenticationError",
    "AuthFailureEvent",
    "AuthInterceptor",
    "AuthLoginEvent",
    "CrossTenantAccessError",
    "InvalidTokenError",
    "JwtValidator",
    "MissingClaimsError",
    "MockJwtValidator",
    "OrgContext",
    "OrgEnforcementInterceptor",
    "TenantAccessDeniedEvent",
    "TenantAwareLogger",
    "TenantContext",
    "TenantIsolationError",
    "TenantIsolationValidator",
    "TokenExpiredError",
    "clear_current_tenant",
    "get_current_tenant",
    "require_tenant",
    "set_current_tenant",
]
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Authentication and organization context types.
|
|
2
|
+
|
|
3
|
+
This module provides immutable dataclasses for representing
|
|
4
|
+
authenticated user context and organization context.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Literal
|
|
11
|
+
from uuid import UUID
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True, slots=True)
|
|
15
|
+
class AuthContext:
|
|
16
|
+
"""Authenticated user context from JWT validation.
|
|
17
|
+
|
|
18
|
+
Immutable dataclass containing user identity and authorization
|
|
19
|
+
information extracted from a validated JWT token.
|
|
20
|
+
|
|
21
|
+
Attributes:
|
|
22
|
+
user_id: Unique user identifier
|
|
23
|
+
email: User's email address (optional)
|
|
24
|
+
org_id: Organization/tenant identifier (optional for personal accounts)
|
|
25
|
+
roles: Tuple of role names assigned to the user
|
|
26
|
+
session_id: Current session identifier
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
user_id: UUID
|
|
30
|
+
session_id: UUID
|
|
31
|
+
email: str | None = None
|
|
32
|
+
org_id: UUID | None = None
|
|
33
|
+
roles: tuple[str, ...] = field(default_factory=tuple)
|
|
34
|
+
|
|
35
|
+
def has_role(self, role: str) -> bool:
|
|
36
|
+
"""Check if user has a specific role.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
role: Role name to check
|
|
40
|
+
|
|
41
|
+
Returns:
|
|
42
|
+
True if user has the role
|
|
43
|
+
"""
|
|
44
|
+
return role in self.roles
|
|
45
|
+
|
|
46
|
+
def is_admin(self) -> bool:
|
|
47
|
+
"""Check if user has admin privileges.
|
|
48
|
+
|
|
49
|
+
Returns:
|
|
50
|
+
True if user has 'admin' or 'org_admin' role
|
|
51
|
+
"""
|
|
52
|
+
return self.has_role("admin") or self.has_role("org_admin")
|
|
53
|
+
|
|
54
|
+
@property
|
|
55
|
+
def is_authenticated(self) -> bool:
|
|
56
|
+
"""Check if this represents an authenticated user.
|
|
57
|
+
|
|
58
|
+
Returns:
|
|
59
|
+
Always True for a valid AuthContext
|
|
60
|
+
"""
|
|
61
|
+
return True
|
|
62
|
+
|
|
63
|
+
def __repr__(self) -> str:
|
|
64
|
+
"""Return string representation hiding sensitive email."""
|
|
65
|
+
email_display = f"{self.email[:3]}***" if self.email else None
|
|
66
|
+
return (
|
|
67
|
+
f"AuthContext(user_id={self.user_id!r}, "
|
|
68
|
+
f"email={email_display!r}, "
|
|
69
|
+
f"org_id={self.org_id!r}, "
|
|
70
|
+
f"roles={self.roles!r})"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
PlanTier = Literal["starter", "pro", "enterprise"]
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass(frozen=True, slots=True)
|
|
78
|
+
class OrgContext:
|
|
79
|
+
"""Organization context with plan and feature information.
|
|
80
|
+
|
|
81
|
+
Immutable dataclass containing organization metadata,
|
|
82
|
+
subscription tier, and enabled features.
|
|
83
|
+
|
|
84
|
+
Attributes:
|
|
85
|
+
org_id: Organization identifier
|
|
86
|
+
tenant_id: Tenant identifier (may differ from org_id in multi-tenant setups)
|
|
87
|
+
plan_tier: Subscription tier level
|
|
88
|
+
features: Tuple of enabled feature flags
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
org_id: UUID
|
|
92
|
+
tenant_id: UUID | None = None
|
|
93
|
+
plan_tier: PlanTier = "starter"
|
|
94
|
+
features: tuple[str, ...] = field(default_factory=tuple)
|
|
95
|
+
|
|
96
|
+
def has_feature(self, feature: str) -> bool:
|
|
97
|
+
"""Check if organization has a specific feature enabled.
|
|
98
|
+
|
|
99
|
+
Args:
|
|
100
|
+
feature: Feature name to check
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
True if feature is enabled
|
|
104
|
+
"""
|
|
105
|
+
return feature in self.features
|
|
106
|
+
|
|
107
|
+
def __repr__(self) -> str:
|
|
108
|
+
return (
|
|
109
|
+
f"OrgContext(org_id={self.org_id!r}, "
|
|
110
|
+
f"plan_tier={self.plan_tier!r}, "
|
|
111
|
+
f"features={self.features!r})"
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Public names of this module: the two context dataclasses and the
# PlanTier literal alias.
__all__ = [
    "AuthContext",
    "OrgContext",
    "PlanTier",
]
|
stageflow/auth/errors.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Authentication and authorization exceptions.
|
|
2
|
+
|
|
3
|
+
This module provides exception classes for authentication failures,
|
|
4
|
+
token validation errors, and cross-tenant access violations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AuthenticationError(Exception):
    """Root of the authentication exception hierarchy.

    Catching this type handles every auth failure in one place. Each
    instance carries a machine-readable ``code`` next to its message.

    Args:
        message: Human-readable description, passed to ``Exception``.
        code: Short error code; defaults to ``"auth_error"``.
    """

    def __init__(self, message: str, code: str = "auth_error") -> None:
        self.code = code
        super().__init__(message)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TokenExpiredError(AuthenticationError):
    """Signals that a JWT token has passed its expiry time.

    The error code is always ``"token_expired"``.

    Attributes:
        expired_at: ISO timestamp at which the token expired, if known.
    """

    def __init__(
        self,
        message: str = "Token has expired",
        expired_at: str | None = None,
    ) -> None:
        self.expired_at = expired_at
        super().__init__(message, code="token_expired")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class InvalidTokenError(AuthenticationError):
    """Signals that a JWT token is invalid or malformed.

    The error code is always ``"invalid_token"``.

    Attributes:
        reason: Optional finer-grained cause, e.g. ``"invalid_signature"``
            or ``"malformed"``.
    """

    def __init__(
        self,
        message: str = "Invalid token",
        reason: str | None = None,
    ) -> None:
        self.reason = reason
        super().__init__(message, code="invalid_token")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MissingClaimsError(AuthenticationError):
    """Signals that one or more required JWT claims are absent.

    The error code is always ``"missing_claims"``.

    Attributes:
        missing_claims: Names of the claims that were not present. Always a
            list; ``None`` or an empty input becomes a new empty list.
    """

    def __init__(
        self,
        message: str = "Required claims are missing",
        missing_claims: list[str] | None = None,
    ) -> None:
        super().__init__(message, code="missing_claims")
        self.missing_claims = missing_claims if missing_claims else []
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class CrossTenantAccessError(AuthenticationError):
    """Signals an attempt to reach resources that belong to another tenant.

    The error code is always ``"cross_tenant_access"``.

    Attributes:
        user_org_id: Organization ID of the requesting user, if known.
        resource_org_id: Organization ID that owns the resource, if known.
    """

    def __init__(
        self,
        message: str = "Cross-tenant access denied",
        user_org_id: str | None = None,
        resource_org_id: str | None = None,
    ) -> None:
        super().__init__(message, code="cross_tenant_access")
        self.user_org_id, self.resource_org_id = user_org_id, resource_org_id
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Public names of this module: the base authentication error and its
# four specialized subclasses.
__all__ = [
    "AuthenticationError",
    "CrossTenantAccessError",
    "InvalidTokenError",
    "MissingClaimsError",
    "TokenExpiredError",
]
|
stageflow/auth/events.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""Authentication audit event types.
|
|
2
|
+
|
|
3
|
+
This module defines structured event types for authentication
|
|
4
|
+
and authorization audit logging.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from datetime import UTC, datetime
|
|
11
|
+
from typing import Any
|
|
12
|
+
from uuid import UUID
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True, slots=True)
|
|
16
|
+
class AuthLoginEvent:
|
|
17
|
+
"""Event emitted on successful authentication.
|
|
18
|
+
|
|
19
|
+
Attributes:
|
|
20
|
+
user_id: Authenticated user ID
|
|
21
|
+
org_id: Organization ID (if applicable)
|
|
22
|
+
session_id: New session ID
|
|
23
|
+
timestamp: When authentication occurred
|
|
24
|
+
request_id: Request correlation ID
|
|
25
|
+
pipeline_run_id: Pipeline run correlation ID
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
user_id: UUID
|
|
29
|
+
session_id: UUID
|
|
30
|
+
org_id: UUID | None = None
|
|
31
|
+
timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat())
|
|
32
|
+
request_id: UUID | None = None
|
|
33
|
+
pipeline_run_id: UUID | None = None
|
|
34
|
+
|
|
35
|
+
def to_dict(self) -> dict[str, Any]:
|
|
36
|
+
"""Convert to dictionary for event emission."""
|
|
37
|
+
return {
|
|
38
|
+
"type": "auth.login",
|
|
39
|
+
"user_id": str(self.user_id),
|
|
40
|
+
"org_id": str(self.org_id) if self.org_id else None,
|
|
41
|
+
"session_id": str(self.session_id),
|
|
42
|
+
"timestamp": self.timestamp,
|
|
43
|
+
"request_id": str(self.request_id) if self.request_id else None,
|
|
44
|
+
"pipeline_run_id": str(self.pipeline_run_id) if self.pipeline_run_id else None,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass(frozen=True, slots=True)
|
|
49
|
+
class AuthFailureEvent:
|
|
50
|
+
"""Event emitted on authentication failure.
|
|
51
|
+
|
|
52
|
+
Attributes:
|
|
53
|
+
reason: Failure reason code (e.g., "token_expired", "invalid_signature")
|
|
54
|
+
ip_address: Client IP address
|
|
55
|
+
user_agent: Client user agent string
|
|
56
|
+
timestamp: When failure occurred
|
|
57
|
+
request_id: Request correlation ID
|
|
58
|
+
user_id: User ID if known (e.g., from expired token)
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
reason: str
|
|
62
|
+
ip_address: str | None = None
|
|
63
|
+
user_agent: str | None = None
|
|
64
|
+
timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat())
|
|
65
|
+
request_id: UUID | None = None
|
|
66
|
+
user_id: UUID | None = None
|
|
67
|
+
|
|
68
|
+
def to_dict(self) -> dict[str, Any]:
|
|
69
|
+
"""Convert to dictionary for event emission."""
|
|
70
|
+
return {
|
|
71
|
+
"type": "auth.failure",
|
|
72
|
+
"reason": self.reason,
|
|
73
|
+
"ip_address": self.ip_address,
|
|
74
|
+
"user_agent": self.user_agent,
|
|
75
|
+
"timestamp": self.timestamp,
|
|
76
|
+
"request_id": str(self.request_id) if self.request_id else None,
|
|
77
|
+
"user_id": str(self.user_id) if self.user_id else None,
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass(frozen=True, slots=True)
|
|
82
|
+
class TenantAccessDeniedEvent:
|
|
83
|
+
"""Event emitted on cross-tenant access attempt.
|
|
84
|
+
|
|
85
|
+
Attributes:
|
|
86
|
+
user_org_id: User's organization ID
|
|
87
|
+
resource_org_id: Resource's organization ID
|
|
88
|
+
user_id: User who attempted access
|
|
89
|
+
timestamp: When violation occurred
|
|
90
|
+
request_id: Request correlation ID
|
|
91
|
+
pipeline_run_id: Pipeline run correlation ID
|
|
92
|
+
"""
|
|
93
|
+
|
|
94
|
+
user_org_id: UUID
|
|
95
|
+
resource_org_id: UUID
|
|
96
|
+
user_id: UUID | None = None
|
|
97
|
+
timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat())
|
|
98
|
+
request_id: UUID | None = None
|
|
99
|
+
pipeline_run_id: UUID | None = None
|
|
100
|
+
|
|
101
|
+
def to_dict(self) -> dict[str, Any]:
|
|
102
|
+
"""Convert to dictionary for event emission."""
|
|
103
|
+
return {
|
|
104
|
+
"type": "tenant.access_denied",
|
|
105
|
+
"user_org_id": str(self.user_org_id),
|
|
106
|
+
"resource_org_id": str(self.resource_org_id),
|
|
107
|
+
"user_id": str(self.user_id) if self.user_id else None,
|
|
108
|
+
"timestamp": self.timestamp,
|
|
109
|
+
"request_id": str(self.request_id) if self.request_id else None,
|
|
110
|
+
"pipeline_run_id": str(self.pipeline_run_id) if self.pipeline_run_id else None,
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Public names of this module: the three audit event dataclasses.
__all__ = [
    "AuthFailureEvent",
    "AuthLoginEvent",
    "TenantAccessDeniedEvent",
]
|