bitfab-py 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bitfab/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ """Bitfab client for provider-based API calls."""
2
+
3
+ from bitfab.client import (
4
+ AllowedEnvVars,
5
+ Bitfab,
6
+ BitfabFunction,
7
+ CurrentSpan,
8
+ CurrentTrace,
9
+ SpanType,
10
+ flush_traces,
11
+ get_current_span,
12
+ get_current_trace,
13
+ )
14
+ from bitfab.replay import ReplayItem, ReplayResult
15
+
16
# BitfabTracingProcessor is re-exported only when bitfab.tracing can be
# imported, i.e. when the optional openai-agents dependency is available.
try:
    from bitfab.tracing import (
        BitfabOpenAITracingProcessor as BitfabTracingProcessor,
    )
except ImportError:
    # openai-agents not installed, skip tracing processor export
    __all__ = [
        "AllowedEnvVars",
        "Bitfab",
        "BitfabFunction",
        "CurrentSpan",
        "CurrentTrace",
        "ReplayItem",
        "ReplayResult",
        "SpanType",
        "flush_traces",
        "get_current_span",
        "get_current_trace",
    ]
else:
    # Import succeeded: publish the processor alongside the core API.
    __all__ = [
        "AllowedEnvVars",
        "Bitfab",
        "BitfabFunction",
        "BitfabTracingProcessor",
        "CurrentSpan",
        "CurrentTrace",
        "ReplayItem",
        "ReplayResult",
        "SpanType",
        "flush_traces",
        "get_current_span",
        "get_current_trace",
    ]
bitfab/baml.py ADDED
@@ -0,0 +1,611 @@
1
+ """BAML execution utilities for the Bitfab Python SDK.
2
+
3
+ This module provides functions to execute BAML prompts dynamically on the client side.
4
+ """
5
+
6
+ import contextlib
7
+ import json
8
+ import logging
9
+ import re
10
+ from dataclasses import dataclass
11
+ from typing import Any, Optional
12
+
13
+ from pydantic import BaseModel, create_model
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Allowed environment variable keys for LLM providers
18
+ ALLOWED_ENV_KEYS = ["OPENAI_API_KEY"]
19
+
20
+
21
@dataclass
class BamlExecutionResult:
    """Outcome of running a BAML function, plus the serialized collector.

    Attributes:
        result: The value produced by the function (any type).
        raw_collector: Dictionary form of the execution collector, or None
            when no collector data is attached.
    """

    # Value returned by the executed function.
    result: Any
    # Collector snapshot; defaults to None when nothing was captured.
    raw_collector: Optional[dict[str, Any]] = None
27
+
28
+
29
class ProviderDefinition:
    """A provider entry as supplied by the server.

    Attributes:
        provider: Provider identifier.
        api_key_env: Name of the environment variable holding the API key.
        models: Model descriptors; each is a string-to-string mapping.
    """

    def __init__(self, provider: str, api_key_env: str, models: list[dict[str, str]]):
        # Plain value holder: the arguments are stored unchanged.
        self.models = models
        self.api_key_env = api_key_env
        self.provider = provider
36
+
37
+
38
def filter_env_vars(env_vars: dict[str, str]) -> dict[str, str]:
    """Return only the entries of *env_vars* whose key is allow-listed.

    Keeps unrelated (possibly sensitive) environment variables from being
    handed to the BAML runtime.

    Args:
        env_vars: Environment variables dictionary

    Returns:
        A new dictionary holding just the keys listed in ALLOWED_ENV_KEYS
        that are present in the input
    """
    return {
        allowed: env_vars[allowed]
        for allowed in ALLOWED_ENV_KEYS
        if allowed in env_vars
    }
54
+
55
+
56
def parse_baml_class_to_pydantic(
    baml_source: str, class_name: str
) -> Optional[type[BaseModel]]:
    """Parse a BAML class definition and create a Pydantic model.

    Only fields declared with a primitive type (``string``, ``int``,
    ``float``, ``bool``, each optionally suffixed with ``?``) are recognised;
    other declarations in the class body are not turned into model fields.

    Args:
        baml_source: The BAML source code
        class_name: The name of the class to parse

    Returns:
        A dynamically created Pydantic model, or None if the class is not
        found, no primitive field is found, or model creation fails
    """
    # Find the class definition; the body is everything up to the first "}".
    class_pattern = rf"class\s+{re.escape(class_name)}\s*\{{([^}}]+)\}}"
    class_match = re.search(class_pattern, baml_source, re.DOTALL)

    if not class_match:
        return None

    # Blank out quoted strings and // comments before scanning for fields.
    # Otherwise prose such as @description("returns a string") would be read
    # as a field named "a" of type string, yielding a bogus required field.
    class_body = re.sub(r'"(?:\\.|[^"\\])*"', '""', class_match.group(1))
    class_body = re.sub(r"//[^\n]*", "", class_body)

    # Map BAML types to Python types (built once, not per field).
    type_map = {
        "string": str,
        "int": int,
        "float": float,
        "bool": bool,
    }

    # Parse fields: field_name type? @description("...")
    field_pattern = r"(\w+)\s+(string|int|float|bool)(\?)?"
    fields = {}

    for match in re.finditer(field_pattern, class_body):
        field_name = match.group(1)
        field_type = type_map[match.group(2)]
        is_optional = match.group(3) == "?"

        if is_optional:
            # Optional fields default to None
            fields[field_name] = (Optional[field_type], None)
        else:
            # Ellipsis marks the field as required
            fields[field_name] = (field_type, ...)

    if not fields:
        return None

    # Create Pydantic model dynamically
    try:
        return create_model(class_name, **fields)
    except Exception as e:
        logger.warning(f"Failed to create Pydantic model for {class_name}: {e}")
        return None
110
+
111
+
112
def extract_function_name(baml_source: str) -> Optional[str]:
    """Return the name of the first ``function`` declared in BAML source.

    Args:
        baml_source: BAML source code

    Returns:
        The identifier following the first ``function`` keyword that is
        followed by an opening parenthesis, or None when there is none
    """
    declaration = re.search(r"function\s+(\w+)\s*\(", baml_source)
    if declaration is None:
        return None
    return declaration.group(1)
123
+
124
+
125
@dataclass
class BamlParameterType:
    """One parameter of a BAML function signature.

    Attributes:
        name: Parameter name.
        type: Declared type text.
        is_optional: Whether the parameter is optional.
    """

    # Parameter identifier as written in the signature.
    name: str
    # Type text as written in the signature.
    type: str
    # Optionality flag supplied by whoever builds the record.
    is_optional: bool
132
+
133
+
134
def extract_function_parameters(baml_source: str) -> list[BamlParameterType]:
    """Read the parameter list of the first BAML function signature.

    The result is used to coerce inputs to the types the function expects.

    Args:
        baml_source: The BAML source code

    Returns:
        One BamlParameterType per ``name: type`` parameter, in declaration
        order; an empty list when no signature or no parameters are found
    """
    # Signature shape: function Name(param1: type1, param2: type2?) -> ReturnType
    signature = re.search(r"function\s+\w+\s*\(([^)]*)\)\s*->", baml_source)
    raw_params = signature.group(1).strip() if signature else ""
    if not raw_params:
        return []

    parsed: list[BamlParameterType] = []

    # Commas inside angle brackets (e.g. Map<string, int>) do not split.
    for chunk in _split_parameters(raw_params):
        declaration = re.match(r"^(\w+)\s*:\s*(.+)$", chunk.strip())
        if declaration is None:
            # Empty or unrecognised piece: skip it.
            continue

        type_text = declaration.group(2).strip()
        optional = type_text.endswith("?")
        if optional:
            # Drop the trailing "?" so only the bare type is stored.
            type_text = type_text[:-1]

        parsed.append(
            BamlParameterType(
                name=declaration.group(1), type=type_text, is_optional=optional
            )
        )

    return parsed
177
+
178
+
179
+ def _split_parameters(params_string: str) -> list[str]:
180
+ """Split parameter string by commas, respecting nested angle brackets."""
181
+ parts: list[str] = []
182
+ current = ""
183
+ depth = 0
184
+
185
+ for char in params_string:
186
+ if char == "<":
187
+ depth += 1
188
+ current += char
189
+ elif char == ">":
190
+ depth -= 1
191
+ current += char
192
+ elif char == "," and depth == 0:
193
+ parts.append(current)
194
+ current = ""
195
+ else:
196
+ current += char
197
+
198
+ if current.strip():
199
+ parts.append(current)
200
+
201
+ return parts
202
+
203
+
204
+ def _coerce_to_type(value: str, expected_type: str) -> Any:
205
+ """Coerce a single string value to the expected BAML type.
206
+
207
+ Returns the coerced value, or the original string if coercion fails.
208
+ """
209
+ # String type - keep as is
210
+ if expected_type == "string":
211
+ return value
212
+
213
+ # Integer type
214
+ if expected_type == "int":
215
+ try:
216
+ return int(value)
217
+ except ValueError:
218
+ return value
219
+
220
+ # Float type
221
+ if expected_type == "float":
222
+ try:
223
+ return float(value)
224
+ except ValueError:
225
+ return value
226
+
227
+ # Boolean type
228
+ if expected_type == "bool":
229
+ lower = value.lower()
230
+ if lower == "true":
231
+ return True
232
+ if lower == "false":
233
+ return False
234
+ return value
235
+
236
+ # Array types (e.g., string[], int[])
237
+ if expected_type.endswith("[]"):
238
+ try:
239
+ parsed = json.loads(value)
240
+ if isinstance(parsed, list):
241
+ return parsed
242
+ except (json.JSONDecodeError, ValueError):
243
+ pass
244
+ return value
245
+
246
+ # Complex types (objects, classes, maps) - try JSON parse
247
+ try:
248
+ return json.loads(value)
249
+ except (json.JSONDecodeError, ValueError):
250
+ return value
251
+
252
+
253
def coerce_inputs(
    inputs: dict[str, Any], expected_types: Optional[dict[str, str]] = None
) -> dict[str, Any]:
    """Convert string inputs to the types declared for their parameters.

    Only string values that have a non-empty entry in *expected_types* are
    converted; non-string values and strings without type information are
    copied through unchanged.

    Args:
        inputs: Input dictionary from web UI (all values are strings)
        expected_types: Map of parameter names to their expected BAML types

    Returns:
        A new dictionary with the same keys and coerced values
    """
    type_hints = expected_types or {}

    def convert(name: str, raw: Any) -> Any:
        # Only strings are candidates for coercion.
        hint = type_hints.get(name) if isinstance(raw, str) else None
        return _coerce_to_type(raw, hint) if hint else raw

    return {name: convert(name, raw) for name, raw in inputs.items()}
287
+
288
+
289
def format_provider(provider: str) -> str:
    """Return the display-cased form of a provider name.

    Known providers use a fixed spelling; any other name is capitalized
    (first character upper-cased, the rest lower-cased).

    Args:
        provider: Provider name (e.g., "openai")

    Returns:
        Formatted provider name (e.g., "OpenAI")
    """
    known_spellings = {
        "openai": "OpenAI",
        "anthropic": "Anthropic",
        "google": "Google",
    }
    if provider in known_spellings:
        return known_spellings[provider]
    return provider.capitalize()
304
+
305
+
306
def format_model(model: str) -> str:
    """Turn a model name into the model part of a BAML client identifier.

    Every "gpt-" occurrence becomes "GPT", then dots and hyphens become
    underscores.

    Args:
        model: Model name (e.g., "gpt-5-mini")

    Returns:
        Formatted model name (e.g., "GPT5_mini")
    """
    # The "gpt-" rewrite must run first, while its hyphen is still present.
    with_prefix = model.replace("gpt-", "GPT")
    return with_prefix.translate(str.maketrans(".-", "__"))
320
+
321
+
322
def get_client_name(provider: str, model: str) -> str:
    """Build the BAML client name for a provider/model pair.

    Args:
        provider: Provider name
        model: Model name

    Returns:
        The formatted provider and formatted model joined by an underscore
        (e.g., "OpenAI_GPT4_1_mini")
    """
    provider_part = format_provider(provider)
    model_part = format_model(model)
    return "_".join((provider_part, model_part))
333
+
334
+
335
def generate_client_definitions(providers: list[ProviderDefinition]) -> str:
    """Generate BAML client definition strings.

    BamlRuntime requires clients to be defined in source for parsing.

    One ``client<llm>`` block is produced for every model of every provider,
    in the order the providers and their models are given.

    Args:
        providers: List of provider definitions

    Returns:
        BAML client definitions as a string, separated by blank lines
        (an empty string when there are no providers or models)
    """
    definitions = []

    for provider_def in providers:
        for model in provider_def.models:
            # Only the "model" key of each model entry is read here.
            client_name = get_client_name(provider_def.provider, model["model"])
            # Doubled braces are literal braces in the generated BAML; the API
            # key is emitted as an env.<NAME> reference, not as a value.
            definitions.append(
                f"""client<llm> {client_name} {{
provider {provider_def.provider}
options {{
model "{model["model"]}"
api_key env.{provider_def.api_key_env}
}}
}}"""
            )

    return "\n\n".join(definitions)
362
+
363
+
364
def with_default_clients(baml_source: str, providers: list[ProviderDefinition]) -> str:
    """Put the generated client definitions in front of a BAML source.

    The source is returned untouched when it already contains the text
    ``client<llm> OpenAI_``; that substring is the only thing checked.

    Args:
        baml_source: BAML source code
        providers: List of provider definitions

    Returns:
        BAML source with client definitions
    """
    already_has_clients = "client<llm> OpenAI_" in baml_source
    if not already_has_clients:
        generated = generate_client_definitions(providers)
        # A blank line separates the generated clients from the caller's source.
        baml_source = generated + "\n\n" + baml_source
    return baml_source
379
+
380
+
381
+ def _obj_to_dict(obj: Any, depth: int = 0, max_depth: int = 5) -> Any:
382
+ """Recursively convert an object to a JSON-serializable dict."""
383
+ if depth > max_depth:
384
+ return f"<max depth reached: {type(obj).__name__}>"
385
+
386
+ if obj is None or isinstance(obj, (str, int, float, bool)):
387
+ return obj
388
+
389
+ if isinstance(obj, (list, tuple)):
390
+ return [_obj_to_dict(item, depth + 1, max_depth) for item in obj]
391
+
392
+ if isinstance(obj, dict):
393
+ return {k: _obj_to_dict(v, depth + 1, max_depth) for k, v in obj.items()}
394
+
395
+ # For objects, try to extract their attributes
396
+ result = {"__type__": type(obj).__name__}
397
+ for attr in dir(obj):
398
+ if attr.startswith("_"):
399
+ continue
400
+ try:
401
+ value = getattr(obj, attr)
402
+ if callable(value):
403
+ continue
404
+ result[attr] = _obj_to_dict(value, depth + 1, max_depth)
405
+ except Exception as e:
406
+ result[attr] = f"<error: {e}>"
407
+
408
+ return result
409
+
410
+
411
def _strip_markdown_fence(text: str) -> str:
    """Return *text* trimmed of whitespace and of a wrapping ``` code fence."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped

    # Drop the opening fence line (``` or ```json).
    lines = stripped.split("\n")[1:]
    # Drop the closing fence line when present.
    if lines and lines[-1].strip() == "```":
        lines = lines[:-1]
    return "\n".join(lines).strip()


def _extract_parsed_result(result: Any) -> Any:
    """Return the parsed value held by a BAML function result, or None.

    Tries ``result.value`` first (called when it is callable, used directly
    otherwise), then falls back to decoding ``unstable_internal_repr()``.
    """
    parsed_result = None

    # Resolve `value` exactly once. Probing it as a method and then again as
    # a property would hand back the bound method object itself whenever the
    # call raised or returned None.
    with contextlib.suppress(Exception):
        accessor = result.value
        parsed_result = accessor() if callable(accessor) else accessor

    if parsed_result is not None:
        return parsed_result

    try:
        internal_data = json.loads(result.unstable_internal_repr())

        # The structure should be: {"Success": {"content": <parsed_object>, ...}}
        if "Success" in internal_data:
            success_data = internal_data["Success"]

            if "content" in success_data:
                content = success_data["content"]

                if isinstance(content, str):
                    # String content may be JSON, possibly inside a code fence.
                    try:
                        parsed_result = json.loads(_strip_markdown_fence(content))
                    except json.JSONDecodeError:
                        # Content might not be JSON, use as-is
                        parsed_result = content
                else:
                    # Content is already parsed
                    parsed_result = content
    except Exception as e:
        logger.error(f"Failed to extract result from unstable_internal_repr: {e}")

    return parsed_result


async def run_function_with_baml(
    baml_source: str,
    inputs: dict[str, Any],
    providers: list[dict[str, Any]],
    env_vars: dict[str, str],
) -> BamlExecutionResult:
    """Run the BAML function with the given inputs using the BAML runtime directly.

    The first function found in *baml_source* is executed. Its result is
    returned as an instance of a Pydantic model built from the function's
    return class when that is possible, otherwise as a dict, otherwise as the
    raw parsed value.

    Note: This requires the baml-py package to be installed.

    Args:
        baml_source: The BAML source code containing the function
        inputs: Named arguments to pass to the function
        providers: Available provider definitions; each dict must provide the
            keys "provider", "apiKeyEnv" and "models"
        env_vars: Environment variables for API keys (only OPENAI_API_KEY is allowed)

    Returns:
        BamlExecutionResult containing the result and execution metadata

    Raises:
        ImportError: If baml-py is not installed
        ValueError: If no function found in BAML source
        RuntimeError: If BAML function execution failed or no parsed result
            could be obtained
    """
    try:
        from baml_py import BamlRuntime, Collector
    except ImportError as err:
        raise ImportError(
            "baml-py is required for local execution. Install it with: pip install baml-py"
        ) from err

    # Extract function name from the BAML source
    function_name = extract_function_name(baml_source)
    if not function_name:
        raise ValueError("No function found in BAML source")

    # Convert provider dicts to ProviderDefinition objects
    provider_objs = [
        ProviderDefinition(p["provider"], p["apiKeyEnv"], p["models"])
        for p in providers
    ]

    # Add default client definitions (runtime needs them for parsing)
    full_source = with_default_clients(baml_source, provider_objs)

    # Filter env vars to only allowed keys
    filtered_env_vars = filter_env_vars(env_vars)

    # Create runtime from the in-memory source with env vars.
    # NOTE(review): the first argument looks like a root path label for the
    # in-memory files - confirm against the baml-py API.
    runtime = BamlRuntime.from_files(
        "/tmp/baml_runtime", {"source.baml": full_source}, filtered_env_vars
    )

    # Create context manager
    ctx = runtime.create_context_manager()

    # Create collector to capture execution metadata
    collector = Collector("bitfab-collector")

    # Extract expected parameter types from BAML source
    params = extract_function_parameters(baml_source)
    expected_types = {p.name: p.type for p in params}

    # Coerce inputs from strings to proper types based on BAML signature
    args = coerce_inputs(inputs, expected_types)

    # Call the function with all required arguments
    # Signature: call_function(function_name, args, ctx, tb, cb, collectors, env_vars, tags)
    result = await runtime.call_function(
        function_name,
        args,
        ctx,
        None,  # tb (TypeBuilder)
        None,  # cb (ClientRegistry)
        [collector],  # collectors - capture execution data
        filtered_env_vars,
        {},  # tags
    )

    if not result.is_ok():
        raise RuntimeError("BAML function execution failed")

    # Serialize the collector to a dict for the server to parse.
    # Best effort: a failure here must not lose the function result.
    raw_collector = None
    try:
        raw_collector = _obj_to_dict(collector)
    except Exception as e:
        logger.warning(f"Failed to serialize collector: {e}")

    # Extract the parsed result from the BAML result object
    parsed_result = _extract_parsed_result(result)
    if parsed_result is None:
        raise RuntimeError("Failed to get parsed result from BAML: no method worked")

    # Convert the parsed result to a dict if it's not already
    if isinstance(parsed_result, dict):
        result_dict = parsed_result
    else:
        # If it's a Pydantic model or other object, try to convert to dict
        try:
            if hasattr(parsed_result, "model_dump"):
                result_dict = parsed_result.model_dump()
            elif hasattr(parsed_result, "dict"):
                result_dict = parsed_result.dict()
            elif hasattr(parsed_result, "__dict__"):
                result_dict = parsed_result.__dict__
            else:
                # Fallback: just return the result as-is
                return BamlExecutionResult(
                    result=parsed_result, raw_collector=raw_collector
                )
        except Exception as e:
            logger.warning(
                f"Failed to convert parsed result to dict: {e}, returning as-is"
            )
            return BamlExecutionResult(
                result=parsed_result, raw_collector=raw_collector
            )

    # Try to extract the return type from the function definition
    # Pattern: function FunctionName(...) -> ReturnType {
    return_type_match = re.search(
        rf"function\s+{re.escape(function_name)}\s*\([^)]*\)\s*->\s*(\w+)",
        baml_source,
    )

    if return_type_match:
        return_type_name = return_type_match.group(1)

        # Try to create a Pydantic model from the BAML class definition
        pydantic_model = parse_baml_class_to_pydantic(baml_source, return_type_name)

        if pydantic_model:
            # Return a Pydantic model instance
            try:
                return BamlExecutionResult(
                    result=pydantic_model(**result_dict),
                    raw_collector=raw_collector,
                )
            except Exception as e:
                logger.warning(
                    f"Failed to instantiate Pydantic model: {e}, returning dict"
                )
                return BamlExecutionResult(
                    result=result_dict, raw_collector=raw_collector
                )

    # If we couldn't create a Pydantic model, return the dict
    return BamlExecutionResult(result=result_dict, raw_collector=raw_collector)