polos-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. polos/__init__.py +105 -0
  2. polos/agents/__init__.py +7 -0
  3. polos/agents/agent.py +746 -0
  4. polos/agents/conversation_history.py +121 -0
  5. polos/agents/stop_conditions.py +280 -0
  6. polos/agents/stream.py +635 -0
  7. polos/core/__init__.py +0 -0
  8. polos/core/context.py +143 -0
  9. polos/core/state.py +26 -0
  10. polos/core/step.py +1380 -0
  11. polos/core/workflow.py +1192 -0
  12. polos/features/__init__.py +0 -0
  13. polos/features/events.py +456 -0
  14. polos/features/schedules.py +110 -0
  15. polos/features/tracing.py +605 -0
  16. polos/features/wait.py +82 -0
  17. polos/llm/__init__.py +9 -0
  18. polos/llm/generate.py +152 -0
  19. polos/llm/providers/__init__.py +5 -0
  20. polos/llm/providers/anthropic.py +615 -0
  21. polos/llm/providers/azure.py +42 -0
  22. polos/llm/providers/base.py +196 -0
  23. polos/llm/providers/fireworks.py +41 -0
  24. polos/llm/providers/gemini.py +40 -0
  25. polos/llm/providers/groq.py +40 -0
  26. polos/llm/providers/openai.py +1021 -0
  27. polos/llm/providers/together.py +40 -0
  28. polos/llm/stream.py +183 -0
  29. polos/middleware/__init__.py +0 -0
  30. polos/middleware/guardrail.py +148 -0
  31. polos/middleware/guardrail_executor.py +253 -0
  32. polos/middleware/hook.py +164 -0
  33. polos/middleware/hook_executor.py +104 -0
  34. polos/runtime/__init__.py +0 -0
  35. polos/runtime/batch.py +87 -0
  36. polos/runtime/client.py +841 -0
  37. polos/runtime/queue.py +42 -0
  38. polos/runtime/worker.py +1365 -0
  39. polos/runtime/worker_server.py +249 -0
  40. polos/tools/__init__.py +0 -0
  41. polos/tools/tool.py +587 -0
  42. polos/types/__init__.py +23 -0
  43. polos/types/types.py +116 -0
  44. polos/utils/__init__.py +27 -0
  45. polos/utils/agent.py +27 -0
  46. polos/utils/client_context.py +41 -0
  47. polos/utils/config.py +12 -0
  48. polos/utils/output_schema.py +311 -0
  49. polos/utils/retry.py +47 -0
  50. polos/utils/serializer.py +167 -0
  51. polos/utils/tracing.py +27 -0
  52. polos/utils/worker_singleton.py +40 -0
  53. polos_sdk-0.1.0.dist-info/METADATA +650 -0
  54. polos_sdk-0.1.0.dist-info/RECORD +55 -0
  55. polos_sdk-0.1.0.dist-info/WHEEL +4 -0
polos/features/tracing.py ADDED
@@ -0,0 +1,605 @@
1
+ """OpenTelemetry tracing support for Polos workflows."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import os
7
+ import threading
8
+ from collections.abc import Sequence
9
+ from datetime import datetime, timezone
10
+ from typing import Any
11
+
12
+ import httpx
13
+
14
+ from ..utils.client_context import get_client_or_raise
15
+
16
# OpenTelemetry is an optional dependency. When it imports cleanly the real
# API/SDK names are bound and OTELEMETRY_AVAILABLE is True. Otherwise minimal
# stand-ins are defined for the names that appear outside OTELEMETRY_AVAILABLE
# guards (e.g. the ``Span`` annotations on get_current_span/extract_traceparent).
# Names that are NOT stubbed (context, trace, TracerProvider, BatchSpanProcessor,
# RandomIdGenerator, TraceContextTextMapPropagator) are only referenced behind
# ``if OTELEMETRY_AVAILABLE`` checks in this module.
try:
    from opentelemetry import context, trace
    from opentelemetry.sdk.trace import RandomIdGenerator, ReadableSpan, TracerProvider
    from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter, SpanExportResult
    from opentelemetry.trace import SpanKind, Status, StatusCode
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
    from opentelemetry.trace.span import Span

    OTELEMETRY_AVAILABLE = True
except ImportError:
    OTELEMETRY_AVAILABLE = False

    # Create no-op types for when OpenTelemetry is not available
    class Span:
        """Placeholder for ``opentelemetry.trace.span.Span`` (used in annotations)."""

        pass

    class Status:
        """Placeholder for ``opentelemetry.trace.Status``."""

        pass

    class StatusCode:
        """Placeholder for ``opentelemetry.trace.StatusCode`` with string members."""

        OK = "OK"
        ERROR = "ERROR"

    # These names are only used by code guarded by OTELEMETRY_AVAILABLE; bind them
    # to None so the module namespace stays the same shape either way.
    SpanKind = None
    SpanExporter = None
    SpanExportResult = None
    ReadableSpan = None
43
+
44
logger = logging.getLogger(__name__)

# Module-wide OpenTelemetry state. All three start unset and are filled in
# lazily: the provider and tracer by initialize_otel(), the propagator by
# get_propagator().
_tracer_provider = _tracer = _propagator = None
50
+
51
+
52
def get_tracer():
    """Return the module tracer, running ``initialize_otel()`` if none is set yet.

    While the module-level ``_tracer`` is still ``None`` (for example because
    OpenTelemetry is not installed and ``initialize_otel`` leaves it unset),
    every call retries the initialization and then returns whatever
    ``_tracer`` holds, which may still be ``None``.
    """
    if _tracer is not None:
        return _tracer
    initialize_otel()
    return _tracer
58
+
59
+
60
def get_propagator():
    """Return the shared W3C trace-context propagator.

    The ``TraceContextTextMapPropagator`` is created on first use and cached in
    the module-level ``_propagator``. Returns ``None`` when OpenTelemetry is not
    installed.
    """
    global _propagator

    if not OTELEMETRY_AVAILABLE:
        return None

    if _propagator is None:
        # First request: build and cache the propagator.
        _propagator = TraceContextTextMapPropagator()
    return _propagator
68
+
69
+
70
if OTELEMETRY_AVAILABLE:

    class DatabaseSpanExporter(SpanExporter):
        """Span exporter that uploads finished spans to the Polos API in batches.

        ``export`` is called from the ``BatchSpanProcessor``'s thread, while the
        upload itself uses ``httpx.AsyncClient``. The exporter therefore owns a
        dedicated asyncio event loop running in a daemon thread and submits each
        batch to it with ``asyncio.run_coroutine_threadsafe``.
        """

        # Attribute keys whose JSON-encoded values are lifted into dedicated
        # fields of the exported record instead of the generic ``attributes`` map.
        _INPUT_KEYS = frozenset(
            ("step.input", "workflow.input", "agent.input", "tool.input", "llm.input")
        )
        _OUTPUT_KEYS = frozenset(
            ("step.output", "workflow.output", "agent.output", "tool.output", "llm.output")
        )
        _ERROR_KEYS = frozenset(
            ("step.error", "workflow.error", "agent.error", "tool.error", "llm.error")
        )
        _INITIAL_STATE_KEYS = frozenset(
            ("workflow.initial_state", "agent.initial_state", "tool.initial_state")
        )
        _FINAL_STATE_KEYS = frozenset(
            ("workflow.final_state", "agent.final_state", "tool.final_state")
        )

        # Span-name prefix -> span_type, checked in this order.
        _SPAN_TYPE_PREFIXES = (
            ("workflow.", "workflow"),
            ("agent.", "agent"),
            ("tool.", "tool"),
            ("step.", "step"),
        )

        def __init__(self):
            """Create the exporter and start its background event loop."""
            self.loop = None
            self.loop_thread = None
            self._loop_ready = threading.Event()
            self._start_event_loop()

        def _start_event_loop(self):
            """Start the dedicated event loop in a daemon thread and wait up to 5 s for it."""

            def run_event_loop():
                """Thread target: create, publish and run the exporter's loop forever."""
                self.loop = asyncio.new_event_loop()
                asyncio.set_event_loop(self.loop)
                self._loop_ready.set()
                self.loop.run_forever()

            self.loop_thread = threading.Thread(
                target=run_event_loop, daemon=True, name="span-exporter-loop"
            )
            self.loop_thread.start()
            if not self._loop_ready.wait(timeout=5):
                # export() treats a missing loop as "nothing to do", so make the
                # resulting silent span loss visible in the logs.
                logger.warning(
                    "Span exporter event loop did not start within 5 seconds; "
                    "spans may be dropped until it does"
                )

        def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
            """Export spans to the database in one batch, blocking up to 30 seconds.

            Args:
                spans: Sequence of finished OpenTelemetry spans to export.

            Returns:
                ``SpanExportResult.SUCCESS`` when the batch was stored, or when there
                is nothing to send / no event loop yet. ``SpanExportResult.FAILURE``
                when conversion, scheduling or the upload fails, or the upload
                times out.
            """
            # Before Python 3.11, concurrent.futures.TimeoutError (raised by
            # Future.result) is NOT the builtin TimeoutError, so catch it by this name.
            import concurrent.futures

            if not spans or not self.loop:
                return SpanExportResult.SUCCESS

            try:
                span_data_list = [self._span_to_dict(span) for span in spans]
                future = asyncio.run_coroutine_threadsafe(
                    self._store_spans_batch(span_data_list), self.loop
                )
            except Exception as e:
                logger.warning(f"Failed to export spans: {e}")
                return SpanExportResult.FAILURE

            try:
                future.result(timeout=30)
            except concurrent.futures.TimeoutError:
                # Cancel the still-running upload so it does not keep running
                # behind later batches on the exporter loop.
                future.cancel()
                logger.error("Span export timed out after 30 seconds")
                return SpanExportResult.FAILURE
            except Exception as e:
                logger.error(f"Span export failed: {e}")
                return SpanExportResult.FAILURE
            return SpanExportResult.SUCCESS

        @staticmethod
        def _load_json_attribute(value):
            """Parse a JSON-encoded attribute value; return None if it cannot be parsed."""
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                return None

        @staticmethod
        def _ns_to_timestamp(nanoseconds):
            """Convert a nanosecond epoch timestamp to an ISO-8601 UTC string."""
            return format_timestamp(datetime.fromtimestamp(nanoseconds / 1e9, tz=timezone.utc))

        def _span_to_dict(self, span) -> dict[str, Any]:
            """Convert an OpenTelemetry span to the database record format.

            Args:
                span: OpenTelemetry span object.

            Returns:
                Dictionary with trace/span/parent ids (hex), name, span_type,
                remaining attributes (stringified), events, parsed
                input/output/error/initial_state/final_state, and ISO timestamps.
            """
            span_context = span.get_span_context()
            trace_id = format(span_context.trace_id, "032x") if span_context else None
            span_id = format(span_context.span_id, "016x") if span_context else None

            # The parent may be exposed as a SpanContext (with span_id) or as an
            # object offering span_context(); a missing parent means a root span.
            parent_span_id = None
            try:
                if hasattr(span, "parent") and span.parent:
                    parent_ctx = span.parent
                    if hasattr(parent_ctx, "span_id"):
                        parent_span_id = format(parent_ctx.span_id, "016x")
                    elif hasattr(parent_ctx, "span_context"):
                        parent_span_context = parent_ctx.span_context()
                        if parent_span_context and parent_span_context.is_valid:
                            parent_span_id = format(parent_span_context.span_id, "016x")
            except Exception:
                # Parent extraction is best effort; treat failures as "no parent".
                pass

            attributes = {}
            input_data = None
            output_data = None
            initial_state = None
            final_state = None
            error_from_attributes = None

            if hasattr(span, "attributes"):
                for key, value in span.attributes.items():
                    if value is None:
                        # None-valued special keys are kept as plain attributes.
                        attributes[key] = None
                    elif key in self._INPUT_KEYS:
                        input_data = self._load_json_attribute(value)
                    elif key in self._OUTPUT_KEYS:
                        output_data = self._load_json_attribute(value)
                    elif key in self._ERROR_KEYS:
                        error_from_attributes = self._load_json_attribute(value)
                    elif key in self._INITIAL_STATE_KEYS:
                        initial_state = self._load_json_attribute(value)
                    elif key in self._FINAL_STATE_KEYS:
                        final_state = self._load_json_attribute(value)
                    else:
                        attributes[key] = str(value)

            events_data = []
            if hasattr(span, "events") and span.events:
                for event in span.events:
                    event_name = event.name if hasattr(event, "name") else str(event)
                    event_timestamp = (
                        self._ns_to_timestamp(event.timestamp)
                        if hasattr(event, "timestamp")
                        else None
                    )
                    event_attributes = {}
                    if hasattr(event, "attributes") and event.attributes:
                        for key, value in event.attributes.items():
                            event_attributes[key] = str(value) if value is not None else None
                    events_data.append(
                        {
                            "name": event_name,
                            "timestamp": event_timestamp,
                            "attributes": event_attributes if event_attributes else None,
                        }
                    )

            # Prefer the structured error recorded in attributes; fall back to the
            # span status description for ERROR spans.
            status = span.status
            error_data = None
            if error_from_attributes:
                error_data = error_from_attributes
            elif status and status.status_code == StatusCode.ERROR:
                error_data = {
                    "message": status.description or "Unknown error",
                    "error_type": "Error",
                }

            started_at = self._ns_to_timestamp(span.start_time)
            ended_at = self._ns_to_timestamp(span.end_time) if span.end_time else None

            for prefix, kind in self._SPAN_TYPE_PREFIXES:
                if span.name.startswith(prefix):
                    span_type = kind
                    break
            else:
                span_type = attributes.get("span_type", "custom")

            return {
                "trace_id": trace_id,
                "span_id": span_id,
                "parent_span_id": parent_span_id,
                "name": span.name,
                "span_type": span_type,
                "attributes": attributes if attributes else None,
                "events": events_data if events_data else None,
                "input": input_data,
                "output": output_data,
                "error": error_data,
                "initial_state": initial_state,
                "final_state": final_state,
                "started_at": started_at,
                "ended_at": ended_at,
            }

        async def _store_spans_batch(self, spans: list[dict[str, Any]]) -> None:
            """POST a batch of serialized spans to ``{api_url}/internal/spans/batch``.

            This runs in the exporter's dedicated event loop, so it cannot reuse the
            worker's HTTP client (bound to a different event loop) and opens a new
            ``httpx.AsyncClient`` instead.

            Args:
                spans: List of span dictionaries produced by ``_span_to_dict``.

            Raises:
                Exception: errors from the client lookup or the request, including
                    ``httpx.HTTPStatusError`` for non-2xx responses, propagate so
                    ``export`` can report ``SpanExportResult.FAILURE``.
            """
            polos_client = get_client_or_raise()
            api_url = polos_client.api_url
            headers = polos_client._get_headers()

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{api_url}/internal/spans/batch",
                    json={"spans": spans},
                    headers=headers,
                )
                response.raise_for_status()

        def shutdown(self):
            """Stop the background event loop and close it.

            Does nothing unless the loop is currently running. Pending tasks are
            cancelled before closing. If the loop thread has not stopped within
            5 seconds the loop is left open, since closing a running event loop
            raises ``RuntimeError``.
            """
            loop = self.loop
            if not (loop and loop.is_running()):
                return

            try:
                loop.call_soon_threadsafe(loop.stop)
                if self.loop_thread and self.loop_thread.is_alive():
                    self.loop_thread.join(timeout=5)
            except Exception as e:
                logger.warning(f"Error during exporter shutdown: {e}")

            if loop.is_closed():
                return
            if loop.is_running():
                logger.warning(
                    "Span exporter event loop did not stop within 5 seconds; leaving it open"
                )
                return

            try:
                pending = asyncio.all_tasks(loop)
                for task in pending:
                    task.cancel()
                if pending:
                    # Give the cancelled tasks one pass on the loop to unwind.
                    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            except Exception:
                # Best effort: the loop is closed below regardless.
                pass
            finally:
                loop.close()

else:

    class DatabaseSpanExporter:
        """No-op stand-in used when OpenTelemetry is not installed."""

        def __init__(self):
            pass

        def export(self, spans):
            return None

        def shutdown(self):
            pass
404
+
405
+
406
if OTELEMETRY_AVAILABLE:

    class DeterministicTraceIdGenerator(RandomIdGenerator):
        """Trace-ID generator that honours a ``"polos.trace_id"`` context value.

        ``generate_trace_id`` takes no context argument, so it inspects
        ``context.get_current()``. A context carrying the key must therefore be the
        current (attached) context to take effect. When the key is absent, a random
        ID from ``RandomIdGenerator`` is returned.
        """

        def generate_trace_id(self) -> int:
            active_ctx = context.get_current()
            pinned_trace_id = context.get_value("polos.trace_id", context=active_ctx)
            if pinned_trace_id is None:
                # No deterministic ID requested: defer to the random generator.
                return super().generate_trace_id()
            logger.debug(f"Using deterministic trace_id from context: {pinned_trace_id:032x}")
            return pinned_trace_id

else:

    class DeterministicTraceIdGenerator:
        """Stand-in used when OpenTelemetry is not installed; always yields 0."""

        def generate_trace_id(self) -> int:
            return 0
428
+
429
+
430
def initialize_otel():
    """Initialize the module-level OpenTelemetry provider and tracer.

    Outcomes:
      * OpenTelemetry not installed: log a warning and leave ``_tracer`` unset.
      * ``POLOS_OTEL_ENABLED`` (default ``"true"``) is anything other than
        ``"true"`` case-insensitively: ``_tracer`` becomes a ``NoOpTracer``.
      * Otherwise:
          - build a ``TracerProvider`` using ``DeterministicTraceIdGenerator``;
          - attach a ``BatchSpanProcessor`` over ``DatabaseSpanExporter``;
          - register the provider globally;
          - set ``_tracer`` from that provider, named by
            ``POLOS_OTEL_SERVICE_NAME`` (default ``"polos"``).
      * Any exception during setup is logged and ``_tracer`` falls back to a
        ``NoOpTracer``.
    """
    global _tracer_provider, _tracer

    if not OTELEMETRY_AVAILABLE:
        logger.warning("OpenTelemetry not available. Tracing disabled.")
        return

    try:
        if os.getenv("POLOS_OTEL_ENABLED", "true").lower() != "true":
            _tracer = trace.NoOpTracer()
            return

        # Custom IdGenerator so workflow spans can use deterministic trace IDs.
        _tracer_provider = TracerProvider(id_generator=DeterministicTraceIdGenerator())

        # MVP: database storage is the only span sink.
        db_exporter = DatabaseSpanExporter()
        _tracer_provider.add_span_processor(BatchSpanProcessor(db_exporter))

        # TODO: optionally add an OTLPSpanExporter (via a BatchSpanProcessor) when
        # POLOS_OTEL_ENDPOINT is configured.

        trace.set_tracer_provider(_tracer_provider)

        service_name = os.getenv("POLOS_OTEL_SERVICE_NAME", "polos")
        # Take the tracer from our own provider instead of the global API:
        # set_tracer_provider() will not replace a provider the application
        # registered earlier, and trace.get_tracer() would then return a tracer
        # that bypasses the deterministic ID generator and the database exporter.
        _tracer = _tracer_provider.get_tracer(service_name)

        logger.info("OpenTelemetry initialized with database exporter")

    except Exception as e:
        # Tracing must never break the caller: degrade to a no-op tracer.
        logger.warning(f"Failed to initialize OpenTelemetry: {e}. Tracing disabled.")
        _tracer = trace.NoOpTracer()
471
+
472
+
473
def get_current_span() -> Span | None:
    """Return ``trace.get_current_span()``, or ``None`` without OpenTelemetry."""
    return trace.get_current_span() if OTELEMETRY_AVAILABLE else None
478
+
479
+
480
def extract_traceparent(span: Span) -> str | None:
    """Serialize a span's trace context as a W3C ``traceparent`` header value.

    Used to hand trace context across process boundaries.

    Args:
        span: Span whose context should be propagated.

    Returns:
        The ``traceparent`` string injected by the trace-context propagator.
        Returns ``None`` in any of these cases:
          * OpenTelemetry is missing;
          * ``span`` is ``None``;
          * the span context is not valid;
          * injection fails (logged as a warning).
    """
    if span is None or not OTELEMETRY_AVAILABLE:
        return None

    try:
        propagator = get_propagator()
        if propagator is not None:
            span_context = span.get_span_context()
            if span_context and span_context.is_valid:
                from opentelemetry.trace import set_span_in_context

                headers = {}
                # Put the span into a fresh context so the propagator can read it.
                propagator.inject(headers, context=set_span_in_context(span))
                return headers.get("traceparent")
    except Exception as e:
        logger.warning(f"Failed to extract traceparent: {e}")

    return None
511
+
512
+
513
def create_context_from_traceparent(traceparent: str):
    """Create an OpenTelemetry context whose parent span comes from a traceparent.

    The W3C trace-context propagator parses the header. For a well-formed value
    it returns a context holding a non-recording span with the remote span
    context, so spans started under the returned context become its children.

    For a malformed or all-zero header the propagator does not return ``None``.
    It returns a context without a valid parent span, which would silently start
    a new, unrelated trace. That case is detected here and reported as ``None``.

    Args:
        traceparent: W3C traceparent string
            (format: version-trace_id-parent_span_id-flags).

    Returns:
        OpenTelemetry context with the parent span set, or ``None`` if
        OpenTelemetry is missing, ``traceparent`` is empty, or extraction fails.
    """
    if not OTELEMETRY_AVAILABLE or not traceparent:
        return None

    try:
        propagator = get_propagator()
        if propagator is None:
            logger.warning("Propagator is None, cannot extract trace context")
            return None

        carrier = {"traceparent": traceparent}
        extracted_context = propagator.extract(carrier)

        if extracted_context is None:
            logger.warning(f"Propagator.extract returned None for traceparent: {traceparent}")
            return None

        # A rejected header yields a context with no (valid) span in it.
        parent_span_context = trace.get_current_span(extracted_context).get_span_context()
        if not parent_span_context.is_valid:
            logger.warning(f"No valid span context in traceparent: {traceparent}")
            return None

        return extracted_context
    except Exception as e:
        logger.warning(
            f"Failed to create context from traceparent '{traceparent}': {e}", exc_info=True
        )

    return None
552
+
553
+
554
def create_context_with_trace_id(trace_id: int):
    """Create an OpenTelemetry context that pins the trace ID of the next root span.

    The returned context derives from the current one, with two changes:
      * the active span is replaced by the invalid span, so the next span started
        in it is a ROOT span (no parent);
      * ``trace_id`` is stored under ``"polos.trace_id"``, where
        ``DeterministicTraceIdGenerator`` picks it up.

    The generator reads the *current* context, so the returned context must be
    attached for the pinned ID to take effect.

    Args:
        trace_id: Trace ID as integer (128 bits).

    Returns:
        The configured context, or ``None`` if OpenTelemetry is missing or the
        context could not be built.
    """
    if not OTELEMETRY_AVAILABLE:
        return None

    try:
        ctx = context.get_current()
        # Detach any active span. Otherwise the SDK treats it as the parent:
        # the next span would inherit that span's trace ID (the ID generator is
        # only consulted for spans without a valid parent) and would not be a root.
        ctx = trace.set_span_in_context(trace.INVALID_SPAN, ctx)
        ctx = context.set_value("polos.trace_id", trace_id, ctx)
        logger.debug(f"Set trace_id in context: {trace_id:032x}")
        return ctx

    except Exception as e:
        logger.warning(f"Failed to create context with trace ID: {e}")

    return None
582
+
583
+
584
def generate_trace_id_from_execution_id(execution_id: str) -> int:
    """Generate a deterministic trace ID from execution_id.

    Dashes are removed and the remaining 32 hexadecimal digits are read as one
    128-bit integer. A UUID therefore maps to the trace ID with the same bits.

    Args:
        execution_id: Execution ID string (UUID format).

    Returns:
        Trace ID as a positive integer (128 bits).

    Raises:
        ValueError: if, after removing dashes, the value is not exactly 32 hex
            digits, or if it is all zeros (the all-zero trace ID is invalid in
            W3C Trace Context / OpenTelemetry).
    """
    hex_str = execution_id.replace("-", "")
    # int(s, 16) alone also accepts a sign, a "0x" prefix, underscores, surrounding
    # whitespace and non-ASCII digits, which would yield wrong or negative IDs.
    if len(hex_str) != 32 or not all(ch in "0123456789abcdefABCDEF" for ch in hex_str):
        raise ValueError(f"Invalid execution_id format: {execution_id}")
    trace_id = int(hex_str, 16)
    if trace_id == 0:
        raise ValueError(f"Invalid execution_id format: {execution_id}")
    return trace_id
599
+
600
+
601
def format_timestamp(dt: datetime) -> str:
    """Render ``dt`` as an ISO-8601 string, treating naive datetimes as UTC.

    Timezone-aware values keep their own offset.
    """
    aware = dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
    return aware.isoformat()
polos/features/wait.py ADDED
@@ -0,0 +1,82 @@
1
+ """
2
+ Wait API for pausing workflow execution and resuming later.
3
+
4
+ This allows workflows to wait for time periods or subworkflows without consuming compute resources.
5
+ """
6
+
7
+ from datetime import datetime, timedelta, timezone
8
+
9
+ import httpx
10
+
11
+ from ..utils.client_context import get_client_or_raise
12
+
13
+
14
async def _set_waiting(
    execution_id: str,
    wait_until: datetime | None,
    wait_type: str,
    step_key: str,
    wait_topic: str | None = None,
    expires_at: datetime | None = None,
) -> None:
    """Put an execution into the waiting state via the internal API.

    Sends a POST to ``{api_url}/internal/executions/{execution_id}/wait``.
    ``wait_until`` and ``expires_at`` are serialized with ``isoformat()`` when
    set and sent as ``None`` otherwise.

    Raises:
        httpx.HTTPStatusError: if the API answers with a non-2xx status.
    """
    polos = get_client_or_raise()
    base_url = polos.api_url
    request_headers = polos._get_headers()
    endpoint = f"{base_url}/internal/executions/{execution_id}/wait"

    async with httpx.AsyncClient() as http:
        body = {
            "wait_until": wait_until.isoformat() if wait_until else None,
            "wait_type": wait_type,
            "step_key": step_key,
            "wait_topic": wait_topic,
            "expires_at": expires_at.isoformat() if expires_at else None,
        }
        resp = await http.post(endpoint, json=body, headers=request_headers)
        resp.raise_for_status()
40
+
41
+
42
+ async def _get_wait_time(
43
+ seconds: float | None = None,
44
+ minutes: float | None = None,
45
+ hours: float | None = None,
46
+ days: float | None = None,
47
+ weeks: float | None = None,
48
+ ):
49
+ # Calculate wait_until datetime using proper date arithmetic
50
+ # Use timezone-aware UTC datetime
51
+ now = datetime.now(timezone.utc)
52
+ wait_until = now
53
+
54
+ if seconds:
55
+ wait_until = wait_until + timedelta(seconds=seconds)
56
+ if minutes:
57
+ wait_until = wait_until + timedelta(minutes=minutes)
58
+ if hours:
59
+ wait_until = wait_until + timedelta(hours=hours)
60
+ if days:
61
+ wait_until = wait_until + timedelta(days=days)
62
+ if weeks:
63
+ wait_until = wait_until + timedelta(weeks=weeks)
64
+
65
+ # Calculate total seconds for threshold check
66
+ total_seconds = (wait_until - now).total_seconds()
67
+
68
+ return total_seconds, wait_until
69
+
70
+
71
+ class WaitException(BaseException):
72
+ """
73
+ Exception raised when workflow execution must pause to wait.
74
+
75
+ This is used internally for checkpointing and should not be
76
+ caught by user code (inherits from BaseException to prevent this).
77
+ """
78
+
79
+ def __init__(self, reason: str, wait_data: dict | None = None):
80
+ self.reason = reason
81
+ self.wait_data = wait_data or {}
82
+ super().__init__(reason)
polos/llm/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """LLM generation and streaming functions."""
2
+
3
+ from .generate import _llm_generate
4
+ from .stream import _llm_stream
5
+
6
# Names re-exported by ``from polos.llm import *``. Both keep a leading underscore,
# so they are presumably meant for internal SDK use — confirm before widening.
__all__ = ["_llm_generate", "_llm_stream"]