prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. prompture/__init__.py +264 -23
  2. prompture/_version.py +34 -0
  3. prompture/agent.py +924 -0
  4. prompture/agent_types.py +156 -0
  5. prompture/aio/__init__.py +74 -0
  6. prompture/async_agent.py +880 -0
  7. prompture/async_conversation.py +789 -0
  8. prompture/async_core.py +803 -0
  9. prompture/async_driver.py +193 -0
  10. prompture/async_groups.py +551 -0
  11. prompture/cache.py +469 -0
  12. prompture/callbacks.py +55 -0
  13. prompture/cli.py +63 -4
  14. prompture/conversation.py +826 -0
  15. prompture/core.py +894 -263
  16. prompture/cost_mixin.py +51 -0
  17. prompture/discovery.py +187 -0
  18. prompture/driver.py +206 -5
  19. prompture/drivers/__init__.py +175 -67
  20. prompture/drivers/airllm_driver.py +109 -0
  21. prompture/drivers/async_airllm_driver.py +26 -0
  22. prompture/drivers/async_azure_driver.py +123 -0
  23. prompture/drivers/async_claude_driver.py +113 -0
  24. prompture/drivers/async_google_driver.py +316 -0
  25. prompture/drivers/async_grok_driver.py +97 -0
  26. prompture/drivers/async_groq_driver.py +90 -0
  27. prompture/drivers/async_hugging_driver.py +61 -0
  28. prompture/drivers/async_lmstudio_driver.py +148 -0
  29. prompture/drivers/async_local_http_driver.py +44 -0
  30. prompture/drivers/async_ollama_driver.py +135 -0
  31. prompture/drivers/async_openai_driver.py +102 -0
  32. prompture/drivers/async_openrouter_driver.py +102 -0
  33. prompture/drivers/async_registry.py +133 -0
  34. prompture/drivers/azure_driver.py +42 -9
  35. prompture/drivers/claude_driver.py +257 -34
  36. prompture/drivers/google_driver.py +295 -42
  37. prompture/drivers/grok_driver.py +35 -32
  38. prompture/drivers/groq_driver.py +33 -26
  39. prompture/drivers/hugging_driver.py +6 -6
  40. prompture/drivers/lmstudio_driver.py +97 -19
  41. prompture/drivers/local_http_driver.py +6 -6
  42. prompture/drivers/ollama_driver.py +168 -23
  43. prompture/drivers/openai_driver.py +184 -9
  44. prompture/drivers/openrouter_driver.py +37 -25
  45. prompture/drivers/registry.py +306 -0
  46. prompture/drivers/vision_helpers.py +153 -0
  47. prompture/field_definitions.py +106 -96
  48. prompture/group_types.py +147 -0
  49. prompture/groups.py +530 -0
  50. prompture/image.py +180 -0
  51. prompture/logging.py +80 -0
  52. prompture/model_rates.py +217 -0
  53. prompture/persistence.py +254 -0
  54. prompture/persona.py +482 -0
  55. prompture/runner.py +49 -47
  56. prompture/scaffold/__init__.py +1 -0
  57. prompture/scaffold/generator.py +84 -0
  58. prompture/scaffold/templates/Dockerfile.j2 +12 -0
  59. prompture/scaffold/templates/README.md.j2 +41 -0
  60. prompture/scaffold/templates/config.py.j2 +21 -0
  61. prompture/scaffold/templates/env.example.j2 +8 -0
  62. prompture/scaffold/templates/main.py.j2 +86 -0
  63. prompture/scaffold/templates/models.py.j2 +40 -0
  64. prompture/scaffold/templates/requirements.txt.j2 +5 -0
  65. prompture/serialization.py +218 -0
  66. prompture/server.py +183 -0
  67. prompture/session.py +117 -0
  68. prompture/settings.py +19 -1
  69. prompture/tools.py +219 -267
  70. prompture/tools_schema.py +254 -0
  71. prompture/validator.py +3 -3
  72. prompture-0.0.38.dev2.dist-info/METADATA +369 -0
  73. prompture-0.0.38.dev2.dist-info/RECORD +77 -0
  74. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
  75. prompture-0.0.29.dev8.dist-info/METADATA +0 -368
  76. prompture-0.0.29.dev8.dist-info/RECORD +0 -27
  77. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
  78. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
  79. {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
@@ -10,19 +10,21 @@ Features:
10
10
  - Pydantic integration via field_from_registry()
11
11
  - Clean registration API with register_field() and add_field_definition()
12
12
  """
13
+
13
14
  import collections.abc
14
15
  import threading
15
- import warnings
16
- from datetime import datetime, date
17
- from typing import Dict, Any, Union, Optional, List, Literal, get_args
16
+ from datetime import date, datetime
17
+ from typing import Any, Optional, Union
18
+
18
19
  from pydantic import Field
19
20
 
21
+
20
22
  # Template variable providers
21
- def _get_template_variables() -> Dict[str, Any]:
23
+ def _get_template_variables() -> dict[str, Any]:
22
24
  """Get current template variables for field definitions."""
23
25
  now = datetime.now()
24
26
  today = date.today()
25
-
27
+
26
28
  return {
27
29
  "current_year": now.year,
28
30
  "current_date": today.isoformat(),
@@ -30,30 +32,32 @@ def _get_template_variables() -> Dict[str, Any]:
30
32
  "current_timestamp": int(now.timestamp()),
31
33
  "current_month": now.month,
32
34
  "current_day": now.day,
33
- "current_weekday": now.strftime("%A"), # e.g. "Monday"
35
+ "current_weekday": now.strftime("%A"), # e.g. "Monday"
34
36
  "current_iso_week": now.isocalendar().week, # ISO week number
35
37
  }
36
38
 
def _apply_templates(text: str, custom_vars: Optional[dict[str, Any]] = None) -> str:
    """Substitute ``{{name}}`` placeholders in *text* with template values.

    Built-in values come from ``_get_template_variables()``; entries from
    *custom_vars* are layered on top and win on key collisions. Each value
    is converted with ``str()`` before substitution. Non-string input is
    returned unchanged.
    """
    if not isinstance(text, str):
        return text

    substitutions = _get_template_variables()
    if custom_vars:
        substitutions.update(custom_vars)

    rendered = text
    for name, replacement in substitutions.items():
        # Placeholders use double curly braces, e.g. "{{current_year}}".
        rendered = rendered.replace("{{" + name + "}}", str(replacement))
    return rendered
53
56
 
57
+
54
58
  # Thread-safe global registry
55
59
  _registry_lock = threading.Lock()
56
- _global_registry: Dict[str, Dict[str, Any]] = {}
60
+ _global_registry: dict[str, dict[str, Any]] = {}
57
61
 
58
62
  # Base field definitions dictionary containing all supported fields
59
63
  BASE_FIELD_DEFINITIONS = {
@@ -79,7 +83,6 @@ BASE_FIELD_DEFINITIONS = {
79
83
  "default": None,
80
84
  "nullable": True,
81
85
  },
82
-
83
86
  # Contact Information
84
87
  "email": {
85
88
  "type": str,
@@ -102,7 +105,6 @@ BASE_FIELD_DEFINITIONS = {
102
105
  "default": "",
103
106
  "nullable": True,
104
107
  },
105
-
106
108
  # Professional Information
107
109
  "occupation": {
108
110
  "type": str,
@@ -125,7 +127,6 @@ BASE_FIELD_DEFINITIONS = {
125
127
  "default": 0,
126
128
  "nullable": True,
127
129
  },
128
-
129
130
  # Metadata Fields
130
131
  "source": {
131
132
  "type": str,
@@ -148,7 +149,6 @@ BASE_FIELD_DEFINITIONS = {
148
149
  "default": 0.0,
149
150
  "nullable": False,
150
151
  },
151
-
152
152
  # Location Fields
153
153
  "city": {
154
154
  "type": str,
@@ -185,7 +185,6 @@ BASE_FIELD_DEFINITIONS = {
185
185
  "default": "",
186
186
  "nullable": True,
187
187
  },
188
-
189
188
  # Demographic Fields
190
189
  "gender": {
191
190
  "type": str,
@@ -215,7 +214,6 @@ BASE_FIELD_DEFINITIONS = {
215
214
  "default": "",
216
215
  "nullable": True,
217
216
  },
218
-
219
217
  # Education Fields
220
218
  "education_level": {
221
219
  "type": str,
@@ -238,7 +236,6 @@ BASE_FIELD_DEFINITIONS = {
238
236
  "default": None,
239
237
  "nullable": True,
240
238
  },
241
-
242
239
  # Financial Fields
243
240
  "salary": {
244
241
  "type": float,
@@ -261,7 +258,6 @@ BASE_FIELD_DEFINITIONS = {
261
258
  "default": None,
262
259
  "nullable": True,
263
260
  },
264
-
265
261
  # Social Media Fields
266
262
  "sentiment": {
267
263
  "type": str,
@@ -292,7 +288,6 @@ BASE_FIELD_DEFINITIONS = {
292
288
  "default": "",
293
289
  "nullable": True,
294
290
  },
295
-
296
291
  # Enum Fields for Task Management
297
292
  "priority": {
298
293
  "type": str,
@@ -325,28 +320,32 @@ BASE_FIELD_DEFINITIONS = {
325
320
  "enum": ["formal", "informal", "optimistic", "pessimistic"],
326
321
  "default": "formal",
327
322
  "nullable": True,
328
- }
323
+ },
329
324
  }
330
325
 
326
+
def _initialize_registry() -> None:
    """Seed the global registry with the built-in field definitions.

    Only acts when the registry is empty, so repeated calls are no-ops.
    Each base definition is copied before it is stored. Entries of the live
    registry can be handed out and edited in place by callers, so without
    the copy those edits would silently change ``BASE_FIELD_DEFINITIONS``
    itself, which is the source ``reset_registry`` restores from.
    """
    with _registry_lock:
        if not _global_registry:
            _global_registry.update(
                {name: definition.copy() for name, definition in BASE_FIELD_DEFINITIONS.items()}
            )
336
332
 
333
+
# Seed the registry as soon as the module is imported.
_initialize_registry()

# Aliases describing one field definition and a name -> definition mapping.
FieldType = Union[type, str]
FieldDefinition = dict[str, Union[FieldType, str, Any, bool]]
FieldDefinitions = dict[str, FieldDefinition]

# Legacy module-level name bound to the live registry dict (kept for
# backward compatibility).
FIELD_DEFINITIONS = _global_registry
347
344
 
348
- def get_field_definition(field_name: str, apply_templates: bool = True,
349
- custom_template_vars: Optional[Dict[str, Any]] = None) -> Optional[FieldDefinition]:
345
+
346
+ def get_field_definition(
347
+ field_name: str, apply_templates: bool = True, custom_template_vars: Optional[dict[str, Any]] = None
348
+ ) -> Optional[FieldDefinition]:
350
349
  """
351
350
  Retrieve the definition for a specific field from the global registry.
352
351
 
@@ -360,22 +359,23 @@ def get_field_definition(field_name: str, apply_templates: bool = True,
360
359
  """
361
360
  with _registry_lock:
362
361
  field_def = _global_registry.get(field_name)
363
-
362
+
364
363
  if field_def is None:
365
364
  return None
366
-
365
+
367
366
  # Make a copy to avoid modifying the original
368
367
  result = field_def.copy()
369
-
368
+
370
369
  if apply_templates:
371
370
  # Apply templates to string values
372
371
  for key, value in result.items():
373
372
  if isinstance(value, str):
374
373
  result[key] = _apply_templates(value, custom_template_vars)
375
-
374
+
376
375
  return result
377
376
 
378
- def get_required_fields() -> List[str]:
377
+
378
+ def get_required_fields() -> list[str]:
379
379
  """
380
380
  Get a list of all required (non-nullable) fields.
381
381
 
@@ -384,11 +384,11 @@ def get_required_fields() -> List[str]:
384
384
  """
385
385
  with _registry_lock:
386
386
  return [
387
- field_name for field_name, definition in _global_registry.items()
388
- if not definition.get('nullable', True)
387
+ field_name for field_name, definition in _global_registry.items() if not definition.get("nullable", True)
389
388
  ]
390
389
 
391
- def get_field_names() -> List[str]:
390
+
391
+ def get_field_names() -> list[str]:
392
392
  """
393
393
  Get a list of all defined field names.
394
394
 
@@ -398,163 +398,167 @@ def get_field_names() -> List[str]:
398
398
  with _registry_lock:
399
399
  return list(_global_registry.keys())
400
400
 
401
+
def register_field(field_name: str, field_definition: FieldDefinition) -> None:
    """
    Add or replace one field definition in the global registry.

    The definition is validated first. A shallow copy is stored, so later
    changes to the caller's dict do not reach the registry.

    Args:
        field_name (str): Key under which the definition is stored
        field_definition (FieldDefinition): Field definition dictionary

    Raises:
        ValueError: If ``validate_field_definition`` rejects the definition
    """
    # Function-level import, presumably to avoid an import cycle with .tools.
    from .tools import validate_field_definition

    is_valid = validate_field_definition(field_definition)
    if not is_valid:
        raise ValueError(f"Invalid field definition for '{field_name}'")

    stored = field_definition.copy()
    with _registry_lock:
        _global_registry[field_name] = stored
419
420
 
421
+
def add_field_definition(field_name: str, field_definition: FieldDefinition) -> None:
    """
    Alias of ``register_field``, kept for API compatibility.

    Args:
        field_name (str): Name of the field
        field_definition (FieldDefinition): Field definition dictionary

    Raises:
        ValueError: Propagated from ``register_field`` for an invalid definition
    """
    return register_field(field_name, field_definition)
429
431
 
def add_field_definitions(field_definitions: dict[str, FieldDefinition]) -> None:
    """
    Register every entry of a name -> definition mapping.

    Entries go through ``register_field`` one at a time, in the mapping's
    iteration order. If a definition is invalid, the ``ValueError`` is raised
    after the earlier entries have already been stored.

    Args:
        field_definitions (dict[str, FieldDefinition]): Mapping of field names to definitions
    """
    for name, definition in field_definitions.items():
        register_field(name, definition)
439
442
 
def _append_sentence(text: Any, sentence: str) -> str:
    """Join *sentence* onto *text*, inserting '. ' only when *text* lacks end punctuation."""
    base = str(text).rstrip() if text else ""
    if not base:
        # Avoid producing a dangling ". <sentence>" for empty text.
        return sentence
    if base.endswith((".", "!", "?")):
        # Avoid doubled punctuation such as "...level.. Allowed values".
        return f"{base} {sentence}"
    return f"{base}. {sentence}"


def field_from_registry(
    field_name: str, apply_templates: bool = True, custom_template_vars: Optional[dict[str, Any]] = None
) -> Field:
    """
    Create a Pydantic Field from a field definition in the global registry.

    Args:
        field_name (str): Name of the field in the registry
        apply_templates (bool): Whether to apply template variable substitution
        custom_template_vars (Optional[dict[str, Any]]): Custom template variables

    Returns:
        pydantic.Field: Configured Pydantic Field object. Nullable fields
        (the default) get the registry ``default`` value, or ``None`` when it
        is absent. Non-nullable fields are required. Fields with a non-empty
        ``enum`` list also carry ``json_schema_extra={"enum", "instructions"}``,
        and their description lists the allowed values.

    Raises:
        KeyError: If field_name is not found in the registry
    """
    field_def = get_field_definition(field_name, apply_templates, custom_template_vars)

    if field_def is None:
        raise KeyError(f"Field '{field_name}' not found in registry. Available fields: {', '.join(get_field_names())}")

    default_value = field_def.get("default")
    description = field_def.get("description", f"Extract the {field_name} from the text.")
    instructions = field_def.get("instructions", "")

    field_kwargs: dict[str, Any] = {"description": description}

    enum_values = field_def.get("enum")
    if enum_values:
        # Surface the enum constraint both in the description and in the JSON schema.
        enum_str = "', '".join(str(v) for v in enum_values)
        field_kwargs["description"] = _append_sentence(description, f"Allowed values: {enum_str}")
        field_kwargs["json_schema_extra"] = {
            "enum": enum_values,
            "instructions": _append_sentence(instructions, f"Must be one of: '{enum_str}'"),
        }

    if field_def.get("nullable", True):
        # default_value is None when the definition has no "default" key.
        field_kwargs["default"] = default_value
    # Omitting "default" makes the field required in Pydantic.
    return Field(**field_kwargs)
498
500
 
501
+
def validate_enum_value(field_name: str, value: Any) -> bool:
    """
    Check whether *value* is acceptable for the named registry field.

    Args:
        field_name (str): Name of the field in the registry
        value (Any): Candidate value

    Returns:
        bool: ``False`` for an unknown field, ``True`` for a field without a
        non-empty ``enum`` list, and otherwise whether *value* is in that list
    """
    definition = get_field_definition(field_name, apply_templates=False)
    if definition is None:
        return False

    allowed = definition.get("enum")
    # Fields without an enum constraint accept any value.
    return value in allowed if allowed else True
522
525
 
526
+
523
527
  def normalize_enum_value(field_name: str, value: Any, case_sensitive: bool = True) -> Any:
524
528
  """
525
529
  Normalize and validate an enum value for a field.
526
-
530
+
527
531
  Args:
528
532
  field_name (str): Name of the field in the registry
529
533
  value (Any): Value to normalize
530
534
  case_sensitive (bool): Whether to perform case-sensitive matching
531
-
535
+
532
536
  Returns:
533
537
  Any: Normalized value if valid, original value otherwise
534
-
538
+
535
539
  Raises:
536
540
  ValueError: If value is not in the allowed enum list
537
541
  """
538
542
  field_def = get_field_definition(field_name, apply_templates=False)
539
-
543
+
540
544
  if field_def is None:
541
545
  raise KeyError(f"Field '{field_name}' not found in registry")
542
-
543
- enum_values = field_def.get('enum')
546
+
547
+ enum_values = field_def.get("enum")
544
548
  if not enum_values:
545
549
  # Not an enum field, return as-is
546
550
  return value
547
-
551
+
548
552
  # Convert value to string for comparison
549
553
  str_value = str(value) if value is not None else None
550
-
554
+
551
555
  if str_value is None:
552
556
  # Handle nullable fields
553
- if field_def.get('nullable', True):
557
+ if field_def.get("nullable", True):
554
558
  return None
555
559
  else:
556
560
  raise ValueError(f"Field '{field_name}' does not allow null values")
557
-
561
+
558
562
  # Case-sensitive matching
559
563
  if case_sensitive:
560
564
  if str_value in enum_values:
@@ -563,28 +567,30 @@ def normalize_enum_value(field_name: str, value: Any, case_sensitive: bool = Tru
563
567
  f"Invalid value '{str_value}' for field '{field_name}'. "
564
568
  f"Must be one of: {', '.join(repr(v) for v in enum_values)}"
565
569
  )
566
-
570
+
567
571
  # Case-insensitive matching
568
572
  lower_value = str_value.lower()
569
573
  for enum_val in enum_values:
570
574
  if str(enum_val).lower() == lower_value:
571
575
  return enum_val
572
-
576
+
573
577
  raise ValueError(
574
578
  f"Invalid value '{str_value}' for field '{field_name}'. "
575
579
  f"Must be one of: {', '.join(repr(v) for v in enum_values)}"
576
580
  )
577
581
 
def get_registry_snapshot() -> dict[str, FieldDefinition]:
    """
    Get a snapshot of the current global registry.

    Both the outer mapping and each field definition dict are copied, so
    callers can add, remove or edit keys in the snapshot without touching
    the live registry. Values nested inside a definition (e.g. an ``enum``
    list) are still shared.

    Returns:
        dict[str, FieldDefinition]: Copy of the current registry
    """
    with _registry_lock:
        return {name: definition.copy() for name, definition in _global_registry.items()}
587
592
 
593
+
588
594
  def clear_registry() -> None:
589
595
  """
590
596
  Clear all field definitions from the global registry.
@@ -593,6 +599,7 @@ def clear_registry() -> None:
593
599
  with _registry_lock:
594
600
  _global_registry.clear()
595
601
 
602
+
596
603
  def reset_registry() -> None:
597
604
  """
598
605
  Reset the global registry to contain only the base field definitions.
@@ -601,22 +608,24 @@ def reset_registry() -> None:
601
608
  _global_registry.clear()
602
609
  _global_registry.update(BASE_FIELD_DEFINITIONS)
603
610
 
611
+
# Backward-compatibility getter kept alongside the legacy FIELD_DEFINITIONS
# name; it reads from the global registry.
def _get_field_definitions():
    """Return the result of ``get_registry_snapshot()`` for legacy FIELD_DEFINITIONS users."""
    snapshot = get_registry_snapshot()
    return snapshot
609
617
 
618
+
610
619
  # Create a property-like access to maintain backward compatibility
611
620
  class _FieldDefinitionsProxy(dict, collections.abc.MutableMapping):
612
621
  """Proxy class to maintain backward compatibility with FIELD_DEFINITIONS."""
613
-
622
+
614
623
  def __getitem__(self, key):
615
624
  return get_field_definition(key)
616
-
625
+
617
626
  def __setitem__(self, key, value):
618
627
  register_field(key, value)
619
-
628
+
620
629
  def __delitem__(self, key):
621
630
  """Remove a field from the registry."""
622
631
  with _registry_lock:
@@ -624,37 +633,38 @@ class _FieldDefinitionsProxy(dict, collections.abc.MutableMapping):
624
633
  del _global_registry[key]
625
634
  else:
626
635
  raise KeyError(f"Field '{key}' not found in registry")
627
-
636
+
628
637
  def __contains__(self, key):
629
638
  return key in get_field_names()
630
-
639
+
631
640
  def __iter__(self):
632
641
  return iter(get_field_names())
633
-
642
+
634
643
  def keys(self):
635
644
  return get_field_names()
636
-
645
+
637
646
  def values(self):
638
647
  with _registry_lock:
639
648
  return list(_global_registry.values())
640
-
649
+
641
650
  def items(self):
642
651
  with _registry_lock:
643
652
  return list(_global_registry.items())
644
-
653
+
645
654
  def __len__(self):
646
655
  with _registry_lock:
647
656
  return len(_global_registry)
648
-
657
+
649
658
  def get(self, key, default=None):
650
659
  field_def = get_field_definition(key)
651
660
  return field_def if field_def is not None else default
652
-
661
+
653
662
  def update(self, other):
654
- if hasattr(other, 'items'):
663
+ if hasattr(other, "items"):
655
664
  add_field_definitions(dict(other.items()))
656
665
  else:
657
666
  add_field_definitions(dict(other))
658
667
 
668
+
# Rebind the legacy FIELD_DEFINITIONS name to a registry-backed proxy so
# dict-style access keeps working for older callers.
FIELD_DEFINITIONS = _FieldDefinitionsProxy()
@@ -0,0 +1,147 @@
1
+ """Shared types for multi-agent group coordination.
2
+
3
+ Defines enums, dataclasses, and callbacks used by
4
+ :class:`~prompture.groups.SequentialGroup`,
5
+ :class:`~prompture.async_groups.ParallelGroup`, and related classes.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import enum
11
+ import json
12
+ import time
13
+ from collections.abc import Callable
14
+ from dataclasses import dataclass, field
15
+ from typing import Any
16
+
17
+
class ErrorPolicy(enum.Enum):
    """Policy selecting how a group reacts when one of its agents fails.

    Each member's value is identical to its name, so a policy round-trips
    through a plain string, e.g. ``ErrorPolicy("fail_fast")``. This class
    only declares the policies; enforcing them is up to the groups that
    consume it.
    """

    fail_fast = "fail_fast"
    continue_on_error = "continue_on_error"
    retry_failed = "retry_failed"
24
+
25
+
@dataclass
class GroupStep:
    """One timeline entry describing a single agent execution in a group run."""

    # Name of the agent this step belongs to.
    agent_name: str
    # Kind of step; defaults to a plain agent run.
    step_type: str = "agent_run"
    # Time the step was recorded; 0.0 until the caller sets it.
    timestamp: float = 0.0
    # How long the step took, in milliseconds.
    duration_ms: float = 0.0
    # Usage counters attributed to this step (fresh dict per instance).
    usage_delta: dict[str, Any] = field(default_factory=dict)
    # Error text when the step failed, otherwise None.
    error: str | None = None
36
+
37
+
@dataclass
class AgentError:
    """Record of an agent that raised during a group run.

    Attributes:
        agent_name: Name of the failing agent.
        error: The exception that was raised.
        error_message: Human-readable message. When left empty it is derived
            from ``error``. If ``str(error)`` is itself empty (e.g. a bare
            ``ValueError()``), the exception class name is used instead, so
            the message is never blank.
        output_key: Optional output key associated with the agent.
    """

    agent_name: str
    error: Exception
    error_message: str = ""
    output_key: str | None = None

    def __post_init__(self) -> None:
        if not self.error_message:
            # str() of an argument-less exception is "", which would leave
            # exported error records without any description.
            self.error_message = str(self.error) or type(self.error).__name__
50
+
51
+
@dataclass
class GroupResult:
    """Aggregate outcome of running a group of agents.

    Attributes:
        agent_results: Result object per agent name/key (typically an
            :class:`AgentResult`).
        aggregate_usage: Token/cost totals combined over every agent run.
        shared_state: State dict as it stood when the run finished.
        elapsed_ms: Wall-clock duration of the whole run.
        timeline: :class:`GroupStep` records in execution order.
        errors: :class:`AgentError` records for agents that failed.
        success: ``True`` when no errors occurred.
    """

    agent_results: dict[str, Any] = field(default_factory=dict)
    aggregate_usage: dict[str, Any] = field(default_factory=dict)
    shared_state: dict[str, Any] = field(default_factory=dict)
    elapsed_ms: float = 0.0
    timeline: list[GroupStep] = field(default_factory=list)
    errors: list[AgentError] = field(default_factory=list)
    success: bool = True

    def export(self) -> dict[str, Any]:
        """Build a plain-dict view of this result for JSON serialization.

        Each agent result is reduced to its ``output_text`` (falling back to
        ``str(result)``) and its ``run_usage`` (falling back to ``{}``).
        Errors are represented by their message only; the exception object
        itself is not exported.
        """

        def summarize_result(result: Any) -> dict[str, Any]:
            return {
                "output_text": getattr(result, "output_text", str(result)),
                "usage": getattr(result, "run_usage", {}),
            }

        def summarize_step(step: GroupStep) -> dict[str, Any]:
            return {
                "agent_name": step.agent_name,
                "step_type": step.step_type,
                "timestamp": step.timestamp,
                "duration_ms": step.duration_ms,
                "usage_delta": step.usage_delta,
                "error": step.error,
            }

        def summarize_error(failure: AgentError) -> dict[str, Any]:
            return {
                "agent_name": failure.agent_name,
                "error_message": failure.error_message,
                "output_key": failure.output_key,
            }

        exported: dict[str, Any] = {}
        exported["agent_results"] = {key: summarize_result(res) for key, res in self.agent_results.items()}
        exported["aggregate_usage"] = self.aggregate_usage
        exported["shared_state"] = self.shared_state
        exported["elapsed_ms"] = self.elapsed_ms
        exported["timeline"] = [summarize_step(step) for step in self.timeline]
        exported["errors"] = [summarize_error(failure) for failure in self.errors]
        exported["success"] = self.success
        return exported

    def save(self, path: str) -> None:
        """Write :meth:`export` to *path* as indented UTF-8 JSON.

        Values ``json`` cannot encode natively are written via ``str()``.
        """
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(self.export(), handle, indent=2, default=str)
113
+
114
+
@dataclass
class GroupCallbacks:
    """Optional observer hooks for group execution.

    Every hook defaults to ``None`` (disabled). The annotations fix each
    hook's arity and argument types; what the arguments mean is decided by
    the group code that invokes them.
    """

    # Receives (str, str) - presumably agent name and its input; confirm with callers.
    on_agent_start: Callable[[str, str], None] | None = None
    # Receives (str, Any) - presumably agent name and its result.
    on_agent_complete: Callable[[str, Any], None] | None = None
    # Receives (str, Exception) - presumably agent name and the raised error.
    on_agent_error: Callable[[str, Exception], None] | None = None
    # Receives (str, Any) - presumably a shared-state key and its new value.
    on_state_update: Callable[[str, Any], None] | None = None
123
+
124
+
125
+ def _aggregate_usage(*sessions: dict[str, Any]) -> dict[str, Any]:
126
+ """Merge multiple usage summary dicts into one aggregate."""
127
+ agg: dict[str, Any] = {
128
+ "prompt_tokens": 0,
129
+ "completion_tokens": 0,
130
+ "total_tokens": 0,
131
+ "total_cost": 0.0,
132
+ "call_count": 0,
133
+ "errors": 0,
134
+ }
135
+ for s in sessions:
136
+ agg["prompt_tokens"] += s.get("prompt_tokens", 0)
137
+ agg["completion_tokens"] += s.get("completion_tokens", 0)
138
+ agg["total_tokens"] += s.get("total_tokens", 0)
139
+ agg["total_cost"] += s.get("total_cost", 0.0)
140
+ agg["call_count"] += s.get("call_count", 0)
141
+ agg["errors"] += s.get("errors", 0)
142
+ return agg
143
+
144
+
145
+ def _now_ms() -> float:
146
+ """Current time in milliseconds (perf_counter-based)."""
147
+ return time.perf_counter() * 1000