prompture 0.0.29.dev8__py3-none-any.whl → 0.0.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prompture/__init__.py +264 -23
- prompture/_version.py +34 -0
- prompture/agent.py +924 -0
- prompture/agent_types.py +156 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_agent.py +880 -0
- prompture/async_conversation.py +789 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +193 -0
- prompture/async_groups.py +551 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +826 -0
- prompture/core.py +894 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +187 -0
- prompture/driver.py +206 -5
- prompture/drivers/__init__.py +175 -67
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +123 -0
- prompture/drivers/async_claude_driver.py +113 -0
- prompture/drivers/async_google_driver.py +316 -0
- prompture/drivers/async_grok_driver.py +97 -0
- prompture/drivers/async_groq_driver.py +90 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +148 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +135 -0
- prompture/drivers/async_openai_driver.py +102 -0
- prompture/drivers/async_openrouter_driver.py +102 -0
- prompture/drivers/async_registry.py +133 -0
- prompture/drivers/azure_driver.py +42 -9
- prompture/drivers/claude_driver.py +257 -34
- prompture/drivers/google_driver.py +295 -42
- prompture/drivers/grok_driver.py +35 -32
- prompture/drivers/groq_driver.py +33 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +97 -19
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +168 -23
- prompture/drivers/openai_driver.py +184 -9
- prompture/drivers/openrouter_driver.py +37 -25
- prompture/drivers/registry.py +306 -0
- prompture/drivers/vision_helpers.py +153 -0
- prompture/field_definitions.py +106 -96
- prompture/group_types.py +147 -0
- prompture/groups.py +530 -0
- prompture/image.py +180 -0
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/persistence.py +254 -0
- prompture/persona.py +482 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/serialization.py +218 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +19 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- prompture-0.0.38.dev2.dist-info/METADATA +369 -0
- prompture-0.0.38.dev2.dist-info/RECORD +77 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/METADATA +0 -368
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.38.dev2.dist-info}/top_level.txt +0 -0
prompture/field_definitions.py
CHANGED
|
@@ -10,19 +10,21 @@ Features:
|
|
|
10
10
|
- Pydantic integration via field_from_registry()
|
|
11
11
|
- Clean registration API with register_field() and add_field_definition()
|
|
12
12
|
"""
|
|
13
|
+
|
|
13
14
|
import collections.abc
|
|
14
15
|
import threading
|
|
15
|
-
import
|
|
16
|
-
from
|
|
17
|
-
|
|
16
|
+
from datetime import date, datetime
|
|
17
|
+
from typing import Any, Optional, Union
|
|
18
|
+
|
|
18
19
|
from pydantic import Field
|
|
19
20
|
|
|
21
|
+
|
|
20
22
|
# Template variable providers
|
|
21
|
-
def _get_template_variables() ->
|
|
23
|
+
def _get_template_variables() -> dict[str, Any]:
|
|
22
24
|
"""Get current template variables for field definitions."""
|
|
23
25
|
now = datetime.now()
|
|
24
26
|
today = date.today()
|
|
25
|
-
|
|
27
|
+
|
|
26
28
|
return {
|
|
27
29
|
"current_year": now.year,
|
|
28
30
|
"current_date": today.isoformat(),
|
|
@@ -30,30 +32,32 @@ def _get_template_variables() -> Dict[str, Any]:
|
|
|
30
32
|
"current_timestamp": int(now.timestamp()),
|
|
31
33
|
"current_month": now.month,
|
|
32
34
|
"current_day": now.day,
|
|
33
|
-
"current_weekday": now.strftime("%A"),
|
|
35
|
+
"current_weekday": now.strftime("%A"), # e.g. "Monday"
|
|
34
36
|
"current_iso_week": now.isocalendar().week, # ISO week number
|
|
35
37
|
}
|
|
36
38
|
|
|
37
|
-
|
|
39
|
+
|
|
40
|
+
def _apply_templates(text: str, custom_vars: Optional[dict[str, Any]] = None) -> str:
|
|
38
41
|
"""Apply template variable substitution to a text string."""
|
|
39
42
|
if not isinstance(text, str):
|
|
40
43
|
return text
|
|
41
|
-
|
|
44
|
+
|
|
42
45
|
variables = _get_template_variables()
|
|
43
46
|
if custom_vars:
|
|
44
47
|
variables.update(custom_vars)
|
|
45
|
-
|
|
48
|
+
|
|
46
49
|
# Simple template replacement
|
|
47
50
|
result = text
|
|
48
51
|
for key, value in variables.items():
|
|
49
52
|
placeholder = f"{{{{{key}}}}}"
|
|
50
53
|
result = result.replace(placeholder, str(value))
|
|
51
|
-
|
|
54
|
+
|
|
52
55
|
return result
|
|
53
56
|
|
|
57
|
+
|
|
54
58
|
# Thread-safe global registry
|
|
55
59
|
_registry_lock = threading.Lock()
|
|
56
|
-
_global_registry:
|
|
60
|
+
_global_registry: dict[str, dict[str, Any]] = {}
|
|
57
61
|
|
|
58
62
|
# Base field definitions dictionary containing all supported fields
|
|
59
63
|
BASE_FIELD_DEFINITIONS = {
|
|
@@ -79,7 +83,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
79
83
|
"default": None,
|
|
80
84
|
"nullable": True,
|
|
81
85
|
},
|
|
82
|
-
|
|
83
86
|
# Contact Information
|
|
84
87
|
"email": {
|
|
85
88
|
"type": str,
|
|
@@ -102,7 +105,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
102
105
|
"default": "",
|
|
103
106
|
"nullable": True,
|
|
104
107
|
},
|
|
105
|
-
|
|
106
108
|
# Professional Information
|
|
107
109
|
"occupation": {
|
|
108
110
|
"type": str,
|
|
@@ -125,7 +127,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
125
127
|
"default": 0,
|
|
126
128
|
"nullable": True,
|
|
127
129
|
},
|
|
128
|
-
|
|
129
130
|
# Metadata Fields
|
|
130
131
|
"source": {
|
|
131
132
|
"type": str,
|
|
@@ -148,7 +149,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
148
149
|
"default": 0.0,
|
|
149
150
|
"nullable": False,
|
|
150
151
|
},
|
|
151
|
-
|
|
152
152
|
# Location Fields
|
|
153
153
|
"city": {
|
|
154
154
|
"type": str,
|
|
@@ -185,7 +185,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
185
185
|
"default": "",
|
|
186
186
|
"nullable": True,
|
|
187
187
|
},
|
|
188
|
-
|
|
189
188
|
# Demographic Fields
|
|
190
189
|
"gender": {
|
|
191
190
|
"type": str,
|
|
@@ -215,7 +214,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
215
214
|
"default": "",
|
|
216
215
|
"nullable": True,
|
|
217
216
|
},
|
|
218
|
-
|
|
219
217
|
# Education Fields
|
|
220
218
|
"education_level": {
|
|
221
219
|
"type": str,
|
|
@@ -238,7 +236,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
238
236
|
"default": None,
|
|
239
237
|
"nullable": True,
|
|
240
238
|
},
|
|
241
|
-
|
|
242
239
|
# Financial Fields
|
|
243
240
|
"salary": {
|
|
244
241
|
"type": float,
|
|
@@ -261,7 +258,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
261
258
|
"default": None,
|
|
262
259
|
"nullable": True,
|
|
263
260
|
},
|
|
264
|
-
|
|
265
261
|
# Social Media Fields
|
|
266
262
|
"sentiment": {
|
|
267
263
|
"type": str,
|
|
@@ -292,7 +288,6 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
292
288
|
"default": "",
|
|
293
289
|
"nullable": True,
|
|
294
290
|
},
|
|
295
|
-
|
|
296
291
|
# Enum Fields for Task Management
|
|
297
292
|
"priority": {
|
|
298
293
|
"type": str,
|
|
@@ -325,28 +320,32 @@ BASE_FIELD_DEFINITIONS = {
|
|
|
325
320
|
"enum": ["formal", "informal", "optimistic", "pessimistic"],
|
|
326
321
|
"default": "formal",
|
|
327
322
|
"nullable": True,
|
|
328
|
-
}
|
|
323
|
+
},
|
|
329
324
|
}
|
|
330
325
|
|
|
326
|
+
|
|
331
327
|
def _initialize_registry() -> None:
|
|
332
328
|
"""Initialize the global registry with base field definitions."""
|
|
333
329
|
with _registry_lock:
|
|
334
330
|
if not _global_registry:
|
|
335
331
|
_global_registry.update(BASE_FIELD_DEFINITIONS)
|
|
336
332
|
|
|
333
|
+
|
|
337
334
|
# Initialize registry on import
|
|
338
335
|
_initialize_registry()
|
|
339
336
|
|
|
340
337
|
# Type hints for field definition structure
|
|
341
338
|
FieldType = Union[type, str]
|
|
342
|
-
FieldDefinition =
|
|
343
|
-
FieldDefinitions =
|
|
339
|
+
FieldDefinition = dict[str, Union[FieldType, str, Any, bool]]
|
|
340
|
+
FieldDefinitions = dict[str, FieldDefinition]
|
|
344
341
|
|
|
345
342
|
# Maintain backward compatibility
|
|
346
343
|
FIELD_DEFINITIONS = _global_registry
|
|
347
344
|
|
|
348
|
-
|
|
349
|
-
|
|
345
|
+
|
|
346
|
+
def get_field_definition(
|
|
347
|
+
field_name: str, apply_templates: bool = True, custom_template_vars: Optional[dict[str, Any]] = None
|
|
348
|
+
) -> Optional[FieldDefinition]:
|
|
350
349
|
"""
|
|
351
350
|
Retrieve the definition for a specific field from the global registry.
|
|
352
351
|
|
|
@@ -360,22 +359,23 @@ def get_field_definition(field_name: str, apply_templates: bool = True,
|
|
|
360
359
|
"""
|
|
361
360
|
with _registry_lock:
|
|
362
361
|
field_def = _global_registry.get(field_name)
|
|
363
|
-
|
|
362
|
+
|
|
364
363
|
if field_def is None:
|
|
365
364
|
return None
|
|
366
|
-
|
|
365
|
+
|
|
367
366
|
# Make a copy to avoid modifying the original
|
|
368
367
|
result = field_def.copy()
|
|
369
|
-
|
|
368
|
+
|
|
370
369
|
if apply_templates:
|
|
371
370
|
# Apply templates to string values
|
|
372
371
|
for key, value in result.items():
|
|
373
372
|
if isinstance(value, str):
|
|
374
373
|
result[key] = _apply_templates(value, custom_template_vars)
|
|
375
|
-
|
|
374
|
+
|
|
376
375
|
return result
|
|
377
376
|
|
|
378
|
-
|
|
377
|
+
|
|
378
|
+
def get_required_fields() -> list[str]:
|
|
379
379
|
"""
|
|
380
380
|
Get a list of all required (non-nullable) fields.
|
|
381
381
|
|
|
@@ -384,11 +384,11 @@ def get_required_fields() -> List[str]:
|
|
|
384
384
|
"""
|
|
385
385
|
with _registry_lock:
|
|
386
386
|
return [
|
|
387
|
-
field_name for field_name, definition in _global_registry.items()
|
|
388
|
-
if not definition.get('nullable', True)
|
|
387
|
+
field_name for field_name, definition in _global_registry.items() if not definition.get("nullable", True)
|
|
389
388
|
]
|
|
390
389
|
|
|
391
|
-
|
|
390
|
+
|
|
391
|
+
def get_field_names() -> list[str]:
|
|
392
392
|
"""
|
|
393
393
|
Get a list of all defined field names.
|
|
394
394
|
|
|
@@ -398,163 +398,167 @@ def get_field_names() -> List[str]:
|
|
|
398
398
|
with _registry_lock:
|
|
399
399
|
return list(_global_registry.keys())
|
|
400
400
|
|
|
401
|
+
|
|
401
402
|
def register_field(field_name: str, field_definition: FieldDefinition) -> None:
|
|
402
403
|
"""
|
|
403
404
|
Register a single field definition in the global registry.
|
|
404
|
-
|
|
405
|
+
|
|
405
406
|
Args:
|
|
406
407
|
field_name (str): Name of the field
|
|
407
408
|
field_definition (FieldDefinition): Field definition dictionary
|
|
408
|
-
|
|
409
|
+
|
|
409
410
|
Raises:
|
|
410
411
|
ValueError: If field definition is invalid
|
|
411
412
|
"""
|
|
412
413
|
from .tools import validate_field_definition
|
|
413
|
-
|
|
414
|
+
|
|
414
415
|
if not validate_field_definition(field_definition):
|
|
415
416
|
raise ValueError(f"Invalid field definition for '{field_name}'")
|
|
416
|
-
|
|
417
|
+
|
|
417
418
|
with _registry_lock:
|
|
418
419
|
_global_registry[field_name] = field_definition.copy()
|
|
419
420
|
|
|
421
|
+
|
|
420
422
|
def add_field_definition(field_name: str, field_definition: FieldDefinition) -> None:
|
|
421
423
|
"""
|
|
422
424
|
Add a field definition to the global registry (alias for register_field).
|
|
423
|
-
|
|
425
|
+
|
|
424
426
|
Args:
|
|
425
427
|
field_name (str): Name of the field
|
|
426
428
|
field_definition (FieldDefinition): Field definition dictionary
|
|
427
429
|
"""
|
|
428
430
|
register_field(field_name, field_definition)
|
|
429
431
|
|
|
430
|
-
|
|
432
|
+
|
|
433
|
+
def add_field_definitions(field_definitions: dict[str, FieldDefinition]) -> None:
|
|
431
434
|
"""
|
|
432
435
|
Add multiple field definitions to the global registry.
|
|
433
|
-
|
|
436
|
+
|
|
434
437
|
Args:
|
|
435
438
|
field_definitions (Dict[str, FieldDefinition]): Dictionary of field definitions
|
|
436
439
|
"""
|
|
437
440
|
for field_name, field_def in field_definitions.items():
|
|
438
441
|
register_field(field_name, field_def)
|
|
439
442
|
|
|
440
|
-
|
|
441
|
-
|
|
443
|
+
|
|
444
|
+
def field_from_registry(
|
|
445
|
+
field_name: str, apply_templates: bool = True, custom_template_vars: Optional[dict[str, Any]] = None
|
|
446
|
+
) -> Field:
|
|
442
447
|
"""
|
|
443
448
|
Create a Pydantic Field from a field definition in the global registry.
|
|
444
|
-
|
|
449
|
+
|
|
445
450
|
Args:
|
|
446
451
|
field_name (str): Name of the field in the registry
|
|
447
452
|
apply_templates (bool): Whether to apply template variable substitution
|
|
448
453
|
custom_template_vars (Optional[Dict[str, Any]]): Custom template variables
|
|
449
|
-
|
|
454
|
+
|
|
450
455
|
Returns:
|
|
451
456
|
pydantic.Field: Configured Pydantic Field object
|
|
452
|
-
|
|
457
|
+
|
|
453
458
|
Raises:
|
|
454
459
|
KeyError: If field_name is not found in the registry
|
|
455
460
|
"""
|
|
456
461
|
field_def = get_field_definition(field_name, apply_templates, custom_template_vars)
|
|
457
|
-
|
|
462
|
+
|
|
458
463
|
if field_def is None:
|
|
459
464
|
raise KeyError(f"Field '{field_name}' not found in registry. Available fields: {', '.join(get_field_names())}")
|
|
460
|
-
|
|
465
|
+
|
|
461
466
|
# Extract Pydantic Field parameters
|
|
462
|
-
default_value = field_def.get(
|
|
463
|
-
description = field_def.get(
|
|
464
|
-
instructions = field_def.get(
|
|
465
|
-
|
|
467
|
+
default_value = field_def.get("default")
|
|
468
|
+
description = field_def.get("description", f"Extract the {field_name} from the text.")
|
|
469
|
+
instructions = field_def.get("instructions", "")
|
|
470
|
+
|
|
466
471
|
# Handle enum fields
|
|
467
|
-
enum_values = field_def.get(
|
|
472
|
+
enum_values = field_def.get("enum")
|
|
468
473
|
if enum_values:
|
|
469
474
|
# Enhance description with enum constraint information
|
|
470
475
|
enum_str = "', '".join(str(v) for v in enum_values)
|
|
471
476
|
enhanced_instructions = f"{instructions}. Must be one of: '{enum_str}'"
|
|
472
477
|
enhanced_description = f"{description}. Allowed values: {enum_str}"
|
|
473
|
-
|
|
478
|
+
|
|
474
479
|
# Create json_schema_extra with enum constraint
|
|
475
|
-
json_schema_extra = {
|
|
476
|
-
|
|
477
|
-
"instructions": enhanced_instructions
|
|
478
|
-
}
|
|
479
|
-
|
|
480
|
+
json_schema_extra = {"enum": enum_values, "instructions": enhanced_instructions}
|
|
481
|
+
|
|
480
482
|
# Handle nullable/required logic with enum
|
|
481
|
-
if field_def.get(
|
|
483
|
+
if field_def.get("nullable", True) and default_value is not None:
|
|
482
484
|
return Field(default=default_value, description=enhanced_description, json_schema_extra=json_schema_extra)
|
|
483
|
-
elif field_def.get(
|
|
485
|
+
elif field_def.get("nullable", True):
|
|
484
486
|
return Field(default=None, description=enhanced_description, json_schema_extra=json_schema_extra)
|
|
485
487
|
else:
|
|
486
488
|
return Field(description=enhanced_description, json_schema_extra=json_schema_extra)
|
|
487
|
-
|
|
489
|
+
|
|
488
490
|
# Handle non-enum fields (original logic)
|
|
489
|
-
if field_def.get(
|
|
491
|
+
if field_def.get("nullable", True) and default_value is not None:
|
|
490
492
|
# Optional field with default
|
|
491
493
|
return Field(default=default_value, description=description)
|
|
492
|
-
elif field_def.get(
|
|
494
|
+
elif field_def.get("nullable", True):
|
|
493
495
|
# Optional field without default (None)
|
|
494
496
|
return Field(default=None, description=description)
|
|
495
497
|
else:
|
|
496
498
|
# Required field
|
|
497
499
|
return Field(description=description)
|
|
498
500
|
|
|
501
|
+
|
|
499
502
|
def validate_enum_value(field_name: str, value: Any) -> bool:
|
|
500
503
|
"""
|
|
501
504
|
Validate that a value is in the allowed enum list for a field.
|
|
502
|
-
|
|
505
|
+
|
|
503
506
|
Args:
|
|
504
507
|
field_name (str): Name of the field in the registry
|
|
505
508
|
value (Any): Value to validate
|
|
506
|
-
|
|
509
|
+
|
|
507
510
|
Returns:
|
|
508
511
|
bool: True if value is valid for the enum field, False otherwise
|
|
509
512
|
"""
|
|
510
513
|
field_def = get_field_definition(field_name, apply_templates=False)
|
|
511
|
-
|
|
514
|
+
|
|
512
515
|
if field_def is None:
|
|
513
516
|
return False
|
|
514
|
-
|
|
515
|
-
enum_values = field_def.get(
|
|
517
|
+
|
|
518
|
+
enum_values = field_def.get("enum")
|
|
516
519
|
if not enum_values:
|
|
517
520
|
# Not an enum field, so any value is valid
|
|
518
521
|
return True
|
|
519
|
-
|
|
522
|
+
|
|
520
523
|
# Check if value is in the allowed enum list
|
|
521
524
|
return value in enum_values
|
|
522
525
|
|
|
526
|
+
|
|
523
527
|
def normalize_enum_value(field_name: str, value: Any, case_sensitive: bool = True) -> Any:
|
|
524
528
|
"""
|
|
525
529
|
Normalize and validate an enum value for a field.
|
|
526
|
-
|
|
530
|
+
|
|
527
531
|
Args:
|
|
528
532
|
field_name (str): Name of the field in the registry
|
|
529
533
|
value (Any): Value to normalize
|
|
530
534
|
case_sensitive (bool): Whether to perform case-sensitive matching
|
|
531
|
-
|
|
535
|
+
|
|
532
536
|
Returns:
|
|
533
537
|
Any: Normalized value if valid, original value otherwise
|
|
534
|
-
|
|
538
|
+
|
|
535
539
|
Raises:
|
|
536
540
|
ValueError: If value is not in the allowed enum list
|
|
537
541
|
"""
|
|
538
542
|
field_def = get_field_definition(field_name, apply_templates=False)
|
|
539
|
-
|
|
543
|
+
|
|
540
544
|
if field_def is None:
|
|
541
545
|
raise KeyError(f"Field '{field_name}' not found in registry")
|
|
542
|
-
|
|
543
|
-
enum_values = field_def.get(
|
|
546
|
+
|
|
547
|
+
enum_values = field_def.get("enum")
|
|
544
548
|
if not enum_values:
|
|
545
549
|
# Not an enum field, return as-is
|
|
546
550
|
return value
|
|
547
|
-
|
|
551
|
+
|
|
548
552
|
# Convert value to string for comparison
|
|
549
553
|
str_value = str(value) if value is not None else None
|
|
550
|
-
|
|
554
|
+
|
|
551
555
|
if str_value is None:
|
|
552
556
|
# Handle nullable fields
|
|
553
|
-
if field_def.get(
|
|
557
|
+
if field_def.get("nullable", True):
|
|
554
558
|
return None
|
|
555
559
|
else:
|
|
556
560
|
raise ValueError(f"Field '{field_name}' does not allow null values")
|
|
557
|
-
|
|
561
|
+
|
|
558
562
|
# Case-sensitive matching
|
|
559
563
|
if case_sensitive:
|
|
560
564
|
if str_value in enum_values:
|
|
@@ -563,28 +567,30 @@ def normalize_enum_value(field_name: str, value: Any, case_sensitive: bool = Tru
|
|
|
563
567
|
f"Invalid value '{str_value}' for field '{field_name}'. "
|
|
564
568
|
f"Must be one of: {', '.join(repr(v) for v in enum_values)}"
|
|
565
569
|
)
|
|
566
|
-
|
|
570
|
+
|
|
567
571
|
# Case-insensitive matching
|
|
568
572
|
lower_value = str_value.lower()
|
|
569
573
|
for enum_val in enum_values:
|
|
570
574
|
if str(enum_val).lower() == lower_value:
|
|
571
575
|
return enum_val
|
|
572
|
-
|
|
576
|
+
|
|
573
577
|
raise ValueError(
|
|
574
578
|
f"Invalid value '{str_value}' for field '{field_name}'. "
|
|
575
579
|
f"Must be one of: {', '.join(repr(v) for v in enum_values)}"
|
|
576
580
|
)
|
|
577
581
|
|
|
578
|
-
|
|
582
|
+
|
|
583
|
+
def get_registry_snapshot() -> dict[str, FieldDefinition]:
|
|
579
584
|
"""
|
|
580
585
|
Get a snapshot of the current global registry.
|
|
581
|
-
|
|
586
|
+
|
|
582
587
|
Returns:
|
|
583
588
|
Dict[str, FieldDefinition]: Copy of the current registry
|
|
584
589
|
"""
|
|
585
590
|
with _registry_lock:
|
|
586
591
|
return _global_registry.copy()
|
|
587
592
|
|
|
593
|
+
|
|
588
594
|
def clear_registry() -> None:
|
|
589
595
|
"""
|
|
590
596
|
Clear all field definitions from the global registry.
|
|
@@ -593,6 +599,7 @@ def clear_registry() -> None:
|
|
|
593
599
|
with _registry_lock:
|
|
594
600
|
_global_registry.clear()
|
|
595
601
|
|
|
602
|
+
|
|
596
603
|
def reset_registry() -> None:
|
|
597
604
|
"""
|
|
598
605
|
Reset the global registry to contain only the base field definitions.
|
|
@@ -601,22 +608,24 @@ def reset_registry() -> None:
|
|
|
601
608
|
_global_registry.clear()
|
|
602
609
|
_global_registry.update(BASE_FIELD_DEFINITIONS)
|
|
603
610
|
|
|
611
|
+
|
|
604
612
|
# For backward compatibility, keep the old FIELD_DEFINITIONS reference
|
|
605
613
|
# but make it point to the global registry
|
|
606
614
|
def _get_field_definitions():
|
|
607
615
|
"""Backward compatibility getter for FIELD_DEFINITIONS."""
|
|
608
616
|
return get_registry_snapshot()
|
|
609
617
|
|
|
618
|
+
|
|
610
619
|
# Create a property-like access to maintain backward compatibility
|
|
611
620
|
class _FieldDefinitionsProxy(dict, collections.abc.MutableMapping):
|
|
612
621
|
"""Proxy class to maintain backward compatibility with FIELD_DEFINITIONS."""
|
|
613
|
-
|
|
622
|
+
|
|
614
623
|
def __getitem__(self, key):
|
|
615
624
|
return get_field_definition(key)
|
|
616
|
-
|
|
625
|
+
|
|
617
626
|
def __setitem__(self, key, value):
|
|
618
627
|
register_field(key, value)
|
|
619
|
-
|
|
628
|
+
|
|
620
629
|
def __delitem__(self, key):
|
|
621
630
|
"""Remove a field from the registry."""
|
|
622
631
|
with _registry_lock:
|
|
@@ -624,37 +633,38 @@ class _FieldDefinitionsProxy(dict, collections.abc.MutableMapping):
|
|
|
624
633
|
del _global_registry[key]
|
|
625
634
|
else:
|
|
626
635
|
raise KeyError(f"Field '{key}' not found in registry")
|
|
627
|
-
|
|
636
|
+
|
|
628
637
|
def __contains__(self, key):
|
|
629
638
|
return key in get_field_names()
|
|
630
|
-
|
|
639
|
+
|
|
631
640
|
def __iter__(self):
|
|
632
641
|
return iter(get_field_names())
|
|
633
|
-
|
|
642
|
+
|
|
634
643
|
def keys(self):
|
|
635
644
|
return get_field_names()
|
|
636
|
-
|
|
645
|
+
|
|
637
646
|
def values(self):
|
|
638
647
|
with _registry_lock:
|
|
639
648
|
return list(_global_registry.values())
|
|
640
|
-
|
|
649
|
+
|
|
641
650
|
def items(self):
|
|
642
651
|
with _registry_lock:
|
|
643
652
|
return list(_global_registry.items())
|
|
644
|
-
|
|
653
|
+
|
|
645
654
|
def __len__(self):
|
|
646
655
|
with _registry_lock:
|
|
647
656
|
return len(_global_registry)
|
|
648
|
-
|
|
657
|
+
|
|
649
658
|
def get(self, key, default=None):
|
|
650
659
|
field_def = get_field_definition(key)
|
|
651
660
|
return field_def if field_def is not None else default
|
|
652
|
-
|
|
661
|
+
|
|
653
662
|
def update(self, other):
|
|
654
|
-
if hasattr(other,
|
|
663
|
+
if hasattr(other, "items"):
|
|
655
664
|
add_field_definitions(dict(other.items()))
|
|
656
665
|
else:
|
|
657
666
|
add_field_definitions(dict(other))
|
|
658
667
|
|
|
668
|
+
|
|
659
669
|
# Replace FIELD_DEFINITIONS with the proxy for backward compatibility
|
|
660
|
-
FIELD_DEFINITIONS = _FieldDefinitionsProxy()
|
|
670
|
+
FIELD_DEFINITIONS = _FieldDefinitionsProxy()
|
prompture/group_types.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
"""Shared types for multi-agent group coordination.
|
|
2
|
+
|
|
3
|
+
Defines enums, dataclasses, and callbacks used by
|
|
4
|
+
:class:`~prompture.groups.SequentialGroup`,
|
|
5
|
+
:class:`~prompture.async_groups.ParallelGroup`, and related classes.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import enum
|
|
11
|
+
import json
|
|
12
|
+
import time
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ErrorPolicy(enum.Enum):
|
|
19
|
+
"""How a group handles agent failures."""
|
|
20
|
+
|
|
21
|
+
fail_fast = "fail_fast"
|
|
22
|
+
continue_on_error = "continue_on_error"
|
|
23
|
+
retry_failed = "retry_failed"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
|
|
27
|
+
class GroupStep:
|
|
28
|
+
"""Record of a single agent execution within a group run."""
|
|
29
|
+
|
|
30
|
+
agent_name: str
|
|
31
|
+
step_type: str = "agent_run"
|
|
32
|
+
timestamp: float = 0.0
|
|
33
|
+
duration_ms: float = 0.0
|
|
34
|
+
usage_delta: dict[str, Any] = field(default_factory=dict)
|
|
35
|
+
error: str | None = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
|
|
39
|
+
class AgentError:
|
|
40
|
+
"""Captures a failed agent execution within a group."""
|
|
41
|
+
|
|
42
|
+
agent_name: str
|
|
43
|
+
error: Exception
|
|
44
|
+
error_message: str = ""
|
|
45
|
+
output_key: str | None = None
|
|
46
|
+
|
|
47
|
+
def __post_init__(self) -> None:
|
|
48
|
+
if not self.error_message:
|
|
49
|
+
self.error_message = str(self.error)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
|
|
53
|
+
class GroupResult:
|
|
54
|
+
"""Outcome of a group execution.
|
|
55
|
+
|
|
56
|
+
Attributes:
|
|
57
|
+
agent_results: Mapping of agent name/key to their :class:`AgentResult`.
|
|
58
|
+
aggregate_usage: Combined token/cost totals across all agent runs.
|
|
59
|
+
shared_state: Final state dict after all agents have written outputs.
|
|
60
|
+
elapsed_ms: Wall-clock duration of the group run.
|
|
61
|
+
timeline: Ordered list of :class:`GroupStep` records.
|
|
62
|
+
errors: List of :class:`AgentError` for any failed agents.
|
|
63
|
+
success: ``True`` if no errors occurred.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
agent_results: dict[str, Any] = field(default_factory=dict)
|
|
67
|
+
aggregate_usage: dict[str, Any] = field(default_factory=dict)
|
|
68
|
+
shared_state: dict[str, Any] = field(default_factory=dict)
|
|
69
|
+
elapsed_ms: float = 0.0
|
|
70
|
+
timeline: list[GroupStep] = field(default_factory=list)
|
|
71
|
+
errors: list[AgentError] = field(default_factory=list)
|
|
72
|
+
success: bool = True
|
|
73
|
+
|
|
74
|
+
def export(self) -> dict[str, Any]:
|
|
75
|
+
"""Return a JSON-serializable dict representation."""
|
|
76
|
+
return {
|
|
77
|
+
"agent_results": {
|
|
78
|
+
k: {
|
|
79
|
+
"output_text": getattr(v, "output_text", str(v)),
|
|
80
|
+
"usage": getattr(v, "run_usage", {}),
|
|
81
|
+
}
|
|
82
|
+
for k, v in self.agent_results.items()
|
|
83
|
+
},
|
|
84
|
+
"aggregate_usage": self.aggregate_usage,
|
|
85
|
+
"shared_state": self.shared_state,
|
|
86
|
+
"elapsed_ms": self.elapsed_ms,
|
|
87
|
+
"timeline": [
|
|
88
|
+
{
|
|
89
|
+
"agent_name": s.agent_name,
|
|
90
|
+
"step_type": s.step_type,
|
|
91
|
+
"timestamp": s.timestamp,
|
|
92
|
+
"duration_ms": s.duration_ms,
|
|
93
|
+
"usage_delta": s.usage_delta,
|
|
94
|
+
"error": s.error,
|
|
95
|
+
}
|
|
96
|
+
for s in self.timeline
|
|
97
|
+
],
|
|
98
|
+
"errors": [
|
|
99
|
+
{
|
|
100
|
+
"agent_name": e.agent_name,
|
|
101
|
+
"error_message": e.error_message,
|
|
102
|
+
"output_key": e.output_key,
|
|
103
|
+
}
|
|
104
|
+
for e in self.errors
|
|
105
|
+
],
|
|
106
|
+
"success": self.success,
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
def save(self, path: str) -> None:
|
|
110
|
+
"""Write the exported dict to a JSON file."""
|
|
111
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
112
|
+
json.dump(self.export(), f, indent=2, default=str)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass
|
|
116
|
+
class GroupCallbacks:
|
|
117
|
+
"""Observability callbacks for group execution."""
|
|
118
|
+
|
|
119
|
+
on_agent_start: Callable[[str, str], None] | None = None
|
|
120
|
+
on_agent_complete: Callable[[str, Any], None] | None = None
|
|
121
|
+
on_agent_error: Callable[[str, Exception], None] | None = None
|
|
122
|
+
on_state_update: Callable[[str, Any], None] | None = None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _aggregate_usage(*sessions: dict[str, Any]) -> dict[str, Any]:
|
|
126
|
+
"""Merge multiple usage summary dicts into one aggregate."""
|
|
127
|
+
agg: dict[str, Any] = {
|
|
128
|
+
"prompt_tokens": 0,
|
|
129
|
+
"completion_tokens": 0,
|
|
130
|
+
"total_tokens": 0,
|
|
131
|
+
"total_cost": 0.0,
|
|
132
|
+
"call_count": 0,
|
|
133
|
+
"errors": 0,
|
|
134
|
+
}
|
|
135
|
+
for s in sessions:
|
|
136
|
+
agg["prompt_tokens"] += s.get("prompt_tokens", 0)
|
|
137
|
+
agg["completion_tokens"] += s.get("completion_tokens", 0)
|
|
138
|
+
agg["total_tokens"] += s.get("total_tokens", 0)
|
|
139
|
+
agg["total_cost"] += s.get("total_cost", 0.0)
|
|
140
|
+
agg["call_count"] += s.get("call_count", 0)
|
|
141
|
+
agg["errors"] += s.get("errors", 0)
|
|
142
|
+
return agg
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _now_ms() -> float:
|
|
146
|
+
"""Current time in milliseconds (perf_counter-based)."""
|
|
147
|
+
return time.perf_counter() * 1000
|