ragaai-catalyst 2.1.5b20__py3-none-any.whl → 2.1.5b22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragaai_catalyst/dataset.py +54 -1
- ragaai_catalyst/synthetic_data_generation.py +39 -6
- ragaai_catalyst/tracers/agentic_tracing/tracers/agent_tracer.py +28 -18
- ragaai_catalyst/tracers/agentic_tracing/tracers/base.py +3 -1
- ragaai_catalyst/tracers/agentic_tracing/tracers/custom_tracer.py +17 -7
- ragaai_catalyst/tracers/agentic_tracing/tracers/llm_tracer.py +106 -16
- ragaai_catalyst/tracers/agentic_tracing/tracers/main_tracer.py +1 -2
- ragaai_catalyst/tracers/agentic_tracing/tracers/tool_tracer.py +17 -6
- ragaai_catalyst/tracers/agentic_tracing/upload/upload_trace_metric.py +6 -5
- ragaai_catalyst/tracers/agentic_tracing/utils/file_name_tracker.py +21 -2
- ragaai_catalyst/tracers/agentic_tracing/utils/llm_utils.py +30 -11
- ragaai_catalyst/tracers/agentic_tracing/utils/model_costs.json +1204 -484
- ragaai_catalyst/tracers/agentic_tracing/utils/span_attributes.py +35 -0
- ragaai_catalyst/tracers/agentic_tracing/utils/trace_utils.py +0 -32
- ragaai_catalyst/tracers/distributed.py +7 -3
- ragaai_catalyst/tracers/tracer.py +25 -8
- ragaai_catalyst/tracers/utils/langchain_tracer_extraction_logic.py +5 -4
- {ragaai_catalyst-2.1.5b20.dist-info → ragaai_catalyst-2.1.5b22.dist-info}/METADATA +2 -2
- {ragaai_catalyst-2.1.5b20.dist-info → ragaai_catalyst-2.1.5b22.dist-info}/RECORD +22 -22
- {ragaai_catalyst-2.1.5b20.dist-info → ragaai_catalyst-2.1.5b22.dist-info}/LICENSE +0 -0
- {ragaai_catalyst-2.1.5b20.dist-info → ragaai_catalyst-2.1.5b22.dist-info}/WHEEL +0 -0
- {ragaai_catalyst-2.1.5b20.dist-info → ragaai_catalyst-2.1.5b22.dist-info}/top_level.txt +0 -0
ragaai_catalyst/dataset.py
CHANGED
@@ -9,6 +9,10 @@ import pandas as pd
|
|
9
9
|
logger = logging.getLogger(__name__)
|
10
10
|
get_token = RagaAICatalyst.get_token
|
11
11
|
|
12
|
+
# Job status constants
|
13
|
+
JOB_STATUS_FAILED = "failed"
|
14
|
+
JOB_STATUS_IN_PROGRESS = "in_progress"
|
15
|
+
JOB_STATUS_COMPLETED = "success"
|
12
16
|
|
13
17
|
class Dataset:
|
14
18
|
BASE_URL = None
|
@@ -18,6 +22,7 @@ class Dataset:
|
|
18
22
|
self.project_name = project_name
|
19
23
|
self.num_projects = 99999
|
20
24
|
Dataset.BASE_URL = RagaAICatalyst.BASE_URL
|
25
|
+
self.jobId = None
|
21
26
|
headers = {
|
22
27
|
"Authorization": f'Bearer {os.getenv("RAGAAI_CATALYST_TOKEN")}',
|
23
28
|
}
|
@@ -219,7 +224,6 @@ class Dataset:
|
|
219
224
|
try:
|
220
225
|
|
221
226
|
put_csv_response = put_csv_to_presignedUrl(url)
|
222
|
-
print(put_csv_response)
|
223
227
|
if put_csv_response.status_code not in (200, 201):
|
224
228
|
raise ValueError('Unable to put csv to the presignedUrl')
|
225
229
|
except Exception as e:
|
@@ -269,6 +273,7 @@ class Dataset:
|
|
269
273
|
raise ValueError('Unable to upload csv')
|
270
274
|
else:
|
271
275
|
print(upload_csv_response['message'])
|
276
|
+
self.jobId = upload_csv_response['data']['jobId']
|
272
277
|
except Exception as e:
|
273
278
|
logger.error(f"Error in create_from_csv: {e}")
|
274
279
|
raise
|
@@ -436,6 +441,7 @@ class Dataset:
|
|
436
441
|
response_data = response.json()
|
437
442
|
if response_data.get('success', False):
|
438
443
|
print(f"{response_data['message']}")
|
444
|
+
self.jobId = response_data['data']['jobId']
|
439
445
|
else:
|
440
446
|
raise ValueError(response_data.get('message', 'Failed to add rows'))
|
441
447
|
|
@@ -594,6 +600,7 @@ class Dataset:
|
|
594
600
|
|
595
601
|
if response_data.get('success', False):
|
596
602
|
print(f"Column '{column_name}' added successfully to dataset '{dataset_name}'")
|
603
|
+
self.jobId = response_data['data']['jobId']
|
597
604
|
else:
|
598
605
|
raise ValueError(response_data.get('message', 'Failed to add column'))
|
599
606
|
|
@@ -601,3 +608,49 @@ class Dataset:
|
|
601
608
|
print(f"Error adding column: {e}")
|
602
609
|
raise
|
603
610
|
|
611
|
+
def get_status(self):
|
612
|
+
headers = {
|
613
|
+
'Content-Type': 'application/json',
|
614
|
+
"Authorization": f"Bearer {os.getenv('RAGAAI_CATALYST_TOKEN')}",
|
615
|
+
'X-Project-Id': str(self.project_id),
|
616
|
+
}
|
617
|
+
try:
|
618
|
+
response = requests.get(
|
619
|
+
f'{Dataset.BASE_URL}/job/status',
|
620
|
+
headers=headers,
|
621
|
+
timeout=30)
|
622
|
+
response.raise_for_status()
|
623
|
+
if response.json()["success"]:
|
624
|
+
|
625
|
+
status_json = [item["status"] for item in response.json()["data"]["content"] if item["id"]==self.jobId]
|
626
|
+
status_json = status_json[0]
|
627
|
+
if status_json == "Failed":
|
628
|
+
print("Job failed. No results to fetch.")
|
629
|
+
return JOB_STATUS_FAILED
|
630
|
+
elif status_json == "In Progress":
|
631
|
+
print(f"Job in progress. Please wait while the job completes.\nVisit Job Status: {Dataset.BASE_URL.removesuffix('/api')}/projects/job-status?projectId={self.project_id} to track")
|
632
|
+
return JOB_STATUS_IN_PROGRESS
|
633
|
+
elif status_json == "Completed":
|
634
|
+
print(f"Job completed. Fetching results.\nVisit Job Status: {Dataset.BASE_URL.removesuffix('/api')}/projects/job-status?projectId={self.project_id} to check")
|
635
|
+
return JOB_STATUS_COMPLETED
|
636
|
+
else:
|
637
|
+
logger.error(f"Unknown status received: {status_json}")
|
638
|
+
return JOB_STATUS_FAILED
|
639
|
+
else:
|
640
|
+
logger.error("Request was not successful")
|
641
|
+
return JOB_STATUS_FAILED
|
642
|
+
except requests.exceptions.HTTPError as http_err:
|
643
|
+
logger.error(f"HTTP error occurred: {http_err}")
|
644
|
+
return JOB_STATUS_FAILED
|
645
|
+
except requests.exceptions.ConnectionError as conn_err:
|
646
|
+
logger.error(f"Connection error occurred: {conn_err}")
|
647
|
+
return JOB_STATUS_FAILED
|
648
|
+
except requests.exceptions.Timeout as timeout_err:
|
649
|
+
logger.error(f"Timeout error occurred: {timeout_err}")
|
650
|
+
return JOB_STATUS_FAILED
|
651
|
+
except requests.exceptions.RequestException as req_err:
|
652
|
+
logger.error(f"An error occurred: {req_err}")
|
653
|
+
return JOB_STATUS_FAILED
|
654
|
+
except Exception as e:
|
655
|
+
logger.error(f"An unexpected error occurred: {e}")
|
656
|
+
return JOB_STATUS_FAILED
|
@@ -8,7 +8,9 @@ import markdown
|
|
8
8
|
import pandas as pd
|
9
9
|
import json
|
10
10
|
from litellm import completion
|
11
|
+
import litellm
|
11
12
|
from tqdm import tqdm
|
13
|
+
import tiktoken
|
12
14
|
# import internal_api_completion
|
13
15
|
# import proxy_call
|
14
16
|
from .internal_api_completion import api_completion as internal_api_completion
|
@@ -48,13 +50,18 @@ class SyntheticDataGeneration:
|
|
48
50
|
Raises:
|
49
51
|
ValueError: If an invalid provider is specified or API key is missing.
|
50
52
|
"""
|
53
|
+
text_validity = self.validate_input(text)
|
54
|
+
if text_validity:
|
55
|
+
raise ValueError(text_validity)
|
56
|
+
|
51
57
|
BATCH_SIZE = 5 # Optimal batch size for maintaining response quality
|
52
58
|
provider = model_config.get("provider")
|
53
59
|
model = model_config.get("model")
|
54
60
|
api_base = model_config.get("api_base")
|
61
|
+
api_version = model_config.get("api_version")
|
55
62
|
|
56
63
|
# Initialize the appropriate client based on provider
|
57
|
-
self._initialize_client(provider, api_key, api_base, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
|
64
|
+
self._initialize_client(provider, api_key, api_base, api_version, internal_llm_proxy=kwargs.get("internal_llm_proxy", None))
|
58
65
|
|
59
66
|
# Initialize progress bar
|
60
67
|
pbar = tqdm(total=n, desc="Generating QA pairs")
|
@@ -88,7 +95,7 @@ class SyntheticDataGeneration:
|
|
88
95
|
pbar.update(len(batch_df))
|
89
96
|
|
90
97
|
except Exception as e:
|
91
|
-
print(f"Batch generation failed
|
98
|
+
print(f"Batch generation failed:{str(e)}")
|
92
99
|
|
93
100
|
if any(error in str(e) for error in FAILURE_CASES):
|
94
101
|
raise Exception(f"{e}")
|
@@ -139,7 +146,7 @@ class SyntheticDataGeneration:
|
|
139
146
|
|
140
147
|
return final_df
|
141
148
|
|
142
|
-
def _initialize_client(self, provider, api_key, api_base=None, internal_llm_proxy=None):
|
149
|
+
def _initialize_client(self, provider, api_key, api_base=None, api_version=None, internal_llm_proxy=None):
|
143
150
|
"""Initialize the appropriate client based on provider."""
|
144
151
|
if not provider:
|
145
152
|
raise ValueError("Model configuration must be provided with a valid provider and model.")
|
@@ -158,7 +165,17 @@ class SyntheticDataGeneration:
|
|
158
165
|
if api_key is None and os.getenv("OPENAI_API_KEY") is None and internal_llm_proxy is None:
|
159
166
|
raise ValueError("API key must be provided for OpenAI.")
|
160
167
|
openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
161
|
-
|
168
|
+
|
169
|
+
elif provider == "azure":
|
170
|
+
if api_key is None and os.getenv("AZURE_API_KEY") is None and internal_llm_proxy is None:
|
171
|
+
raise ValueError("API key must be provided for Azure.")
|
172
|
+
litellm.api_key = api_key or os.getenv("AZURE_API_KEY")
|
173
|
+
if api_base is None and os.getenv("AZURE_API_BASE") is None and internal_llm_proxy is None:
|
174
|
+
raise ValueError("API Base must be provided for Azure.")
|
175
|
+
litellm.api_base = api_base or os.getenv("AZURE_API_BASE")
|
176
|
+
if api_version is None and os.getenv("AZURE_API_VERSION") is None and internal_llm_proxy is None:
|
177
|
+
raise ValueError("API version must be provided for Azure.")
|
178
|
+
litellm.api_version = api_version or os.getenv("AZURE_API_VERSION")
|
162
179
|
else:
|
163
180
|
raise ValueError(f"Provider is not recognized.")
|
164
181
|
|
@@ -189,7 +206,15 @@ class SyntheticDataGeneration:
|
|
189
206
|
kwargs=kwargs
|
190
207
|
)
|
191
208
|
|
209
|
+
def validate_input(self,text):
|
192
210
|
|
211
|
+
if not text.strip():
|
212
|
+
return 'Empty Text provided for qna generation. Please provide valid text'
|
213
|
+
encoding = tiktoken.encoding_for_model("gpt-4")
|
214
|
+
tokens = encoding.encode(text)
|
215
|
+
if len(tokens)<5:
|
216
|
+
return 'Very Small Text provided for qna generation. Please provide longer text'
|
217
|
+
return False
|
193
218
|
|
194
219
|
|
195
220
|
def _get_system_message(self, question_type, n):
|
@@ -274,10 +299,14 @@ class SyntheticDataGeneration:
|
|
274
299
|
# Add optional parameters if they exist in model_config
|
275
300
|
if "api_base" in model_config:
|
276
301
|
completion_params["api_base"] = model_config["api_base"]
|
302
|
+
if "api_version" in model_config:
|
303
|
+
completion_params["api_version"] = model_config["api_version"]
|
277
304
|
if "max_tokens" in model_config:
|
278
305
|
completion_params["max_tokens"] = model_config["max_tokens"]
|
279
306
|
if "temperature" in model_config:
|
280
307
|
completion_params["temperature"] = model_config["temperature"]
|
308
|
+
if 'provider' in model_config:
|
309
|
+
completion_params['model'] = f'{model_config["provider"]}/{model_config["model"]}'
|
281
310
|
|
282
311
|
# Make the API call using LiteLLM
|
283
312
|
try:
|
@@ -318,9 +347,13 @@ class SyntheticDataGeneration:
|
|
318
347
|
list_start_index = data.find('[') # Find the index of the first '['
|
319
348
|
substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
|
320
349
|
data = substring_data
|
321
|
-
|
350
|
+
elif provider == "azure":
|
351
|
+
data = response.choices[0].message.content.replace('\n', '')
|
352
|
+
list_start_index = data.find('[') # Find the index of the first '['
|
353
|
+
substring_data = data[list_start_index:] if list_start_index != -1 else data # Slice from the list start
|
354
|
+
data = substring_data
|
322
355
|
else:
|
323
|
-
raise ValueError("Invalid provider. Choose 'groq', 'gemini', or 'openai'.")
|
356
|
+
raise ValueError("Invalid provider. Choose 'groq', 'gemini', 'azure' or 'openai'.")
|
324
357
|
try:
|
325
358
|
json_data = json.loads(data)
|
326
359
|
return pd.DataFrame(json_data)
|
@@ -101,7 +101,10 @@ class AgentTracerMixin:
|
|
101
101
|
original_init = target.__init__
|
102
102
|
|
103
103
|
def wrapped_init(self, *args, **kwargs):
|
104
|
-
|
104
|
+
gt = kwargs.get("gt") if kwargs else None
|
105
|
+
if gt is not None:
|
106
|
+
span = self.span(name)
|
107
|
+
span.add_gt(gt)
|
105
108
|
# Set agent context before initializing
|
106
109
|
component_id = str(uuid.uuid4())
|
107
110
|
hash_id = top_level_hash_id
|
@@ -159,7 +162,10 @@ class AgentTracerMixin:
|
|
159
162
|
@self.file_tracker.trace_decorator
|
160
163
|
@functools.wraps(method)
|
161
164
|
def wrapped_method(self, *args, **kwargs):
|
162
|
-
|
165
|
+
gt = kwargs.get("gt") if kwargs else None
|
166
|
+
if gt is not None:
|
167
|
+
span = tracer.span(name)
|
168
|
+
span.add_gt(gt)
|
163
169
|
# Set this agent as current during method execution
|
164
170
|
token = tracer.current_agent_id.set(
|
165
171
|
self._agent_component_id
|
@@ -247,6 +253,7 @@ class AgentTracerMixin:
|
|
247
253
|
agent_type,
|
248
254
|
version,
|
249
255
|
capabilities,
|
256
|
+
top_level_hash_id,
|
250
257
|
*args,
|
251
258
|
**kwargs,
|
252
259
|
)
|
@@ -256,10 +263,9 @@ class AgentTracerMixin:
|
|
256
263
|
return decorator
|
257
264
|
|
258
265
|
def _trace_sync_agent_execution(
|
259
|
-
self, func, name, agent_type, version, capabilities, *args, **kwargs
|
266
|
+
self, func, name, agent_type, version, capabilities, top_level_hash_id, *args, **kwargs
|
260
267
|
):
|
261
|
-
|
262
|
-
hash_id = str(uuid.uuid4())
|
268
|
+
hash_id = top_level_hash_id
|
263
269
|
|
264
270
|
"""Synchronous version of agent tracing"""
|
265
271
|
if not self.is_active:
|
@@ -275,7 +281,10 @@ class AgentTracerMixin:
|
|
275
281
|
component_id = str(uuid.uuid4())
|
276
282
|
|
277
283
|
# Extract ground truth if present
|
278
|
-
ground_truth = kwargs.pop("gt"
|
284
|
+
ground_truth = kwargs.pop("gt") if kwargs else None
|
285
|
+
if ground_truth is not None:
|
286
|
+
span = self.span(name)
|
287
|
+
span.add_gt(ground_truth)
|
279
288
|
|
280
289
|
# Get parent agent ID if exists
|
281
290
|
parent_agent_id = self.current_agent_id.get()
|
@@ -293,7 +302,7 @@ class AgentTracerMixin:
|
|
293
302
|
|
294
303
|
try:
|
295
304
|
# Execute the agent
|
296
|
-
result = func(*args, **kwargs)
|
305
|
+
result = self.file_tracker.trace_wrapper(func)(*args, **kwargs)
|
297
306
|
|
298
307
|
# Calculate resource usage
|
299
308
|
end_memory = psutil.Process().memory_info().rss
|
@@ -320,9 +329,6 @@ class AgentTracerMixin:
|
|
320
329
|
children=children,
|
321
330
|
parent_id=parent_agent_id,
|
322
331
|
)
|
323
|
-
# Add ground truth to component data if present
|
324
|
-
if ground_truth is not None:
|
325
|
-
agent_component["data"]["gt"] = ground_truth
|
326
332
|
|
327
333
|
# Add this component as a child to parent's children list
|
328
334
|
parent_children.append(agent_component)
|
@@ -398,7 +404,10 @@ class AgentTracerMixin:
|
|
398
404
|
component_id = str(uuid.uuid4())
|
399
405
|
|
400
406
|
# Extract ground truth if present
|
401
|
-
ground_truth = kwargs.pop("gt"
|
407
|
+
ground_truth = kwargs.pop("gt") if kwargs else None
|
408
|
+
if ground_truth is not None:
|
409
|
+
span = self.span(name)
|
410
|
+
span.add_gt(ground_truth)
|
402
411
|
|
403
412
|
# Get parent agent ID if exists
|
404
413
|
parent_agent_id = self.current_agent_id.get()
|
@@ -414,7 +423,7 @@ class AgentTracerMixin:
|
|
414
423
|
|
415
424
|
try:
|
416
425
|
# Execute the agent
|
417
|
-
result = await func(*args, **kwargs)
|
426
|
+
result = await self.file_tracker.trace_wrapper(func)(*args, **kwargs)
|
418
427
|
|
419
428
|
# Calculate resource usage
|
420
429
|
end_memory = psutil.Process().memory_info().rss
|
@@ -441,10 +450,6 @@ class AgentTracerMixin:
|
|
441
450
|
parent_id=parent_agent_id,
|
442
451
|
)
|
443
452
|
|
444
|
-
# Add ground truth to component data if present
|
445
|
-
if ground_truth is not None:
|
446
|
-
agent_component["data"]["gt"] = ground_truth
|
447
|
-
|
448
453
|
# Add this component as a child to parent's children list
|
449
454
|
parent_children.append(agent_component)
|
450
455
|
self.agent_children.set(parent_children)
|
@@ -576,8 +581,13 @@ class AgentTracerMixin:
|
|
576
581
|
"interactions": interactions,
|
577
582
|
}
|
578
583
|
|
579
|
-
if self.
|
580
|
-
|
584
|
+
if name in self.span_attributes_dict:
|
585
|
+
span_gt = self.span_attributes_dict[name].gt
|
586
|
+
if span_gt is not None:
|
587
|
+
component["data"]["gt"] = span_gt
|
588
|
+
span_context = self.span_attributes_dict[name].context
|
589
|
+
if span_context:
|
590
|
+
component["data"]["context"] = span_context
|
581
591
|
|
582
592
|
# Reset the SpanAttributes context variable
|
583
593
|
self.span_attributes_dict[kwargs["name"]] = SpanAttributes(kwargs["name"])
|
@@ -83,6 +83,7 @@ class BaseTracer:
|
|
83
83
|
self.tracking_thread = None
|
84
84
|
self.tracking = False
|
85
85
|
self.system_monitor = None
|
86
|
+
self.gt = None
|
86
87
|
|
87
88
|
def _get_system_info(self) -> SystemInfo:
|
88
89
|
return self.system_monitor.get_system_info()
|
@@ -249,7 +250,8 @@ class BaseTracer:
|
|
249
250
|
|
250
251
|
# Format interactions and add to trace
|
251
252
|
interactions = self.format_interactions()
|
252
|
-
trace_data["workflow"] = interactions["workflow"]
|
253
|
+
# trace_data["workflow"] = interactions["workflow"]
|
254
|
+
cleaned_trace_data["workflow"] = interactions["workflow"]
|
253
255
|
|
254
256
|
with open(filepath, "w") as f:
|
255
257
|
json.dump(cleaned_trace_data, f, cls=TracerJSONEncoder, indent=2)
|
@@ -45,7 +45,10 @@ class CustomTracerMixin:
|
|
45
45
|
@functools.wraps(func)
|
46
46
|
async def async_wrapper(*args, **kwargs):
|
47
47
|
async_wrapper.metadata = metadata
|
48
|
-
|
48
|
+
gt = kwargs.get('gt') if kwargs else None
|
49
|
+
if gt is not None:
|
50
|
+
span = self.span(name)
|
51
|
+
span.add_gt(gt)
|
49
52
|
return await self._trace_custom_execution(
|
50
53
|
func, name or func.__name__, custom_type, version, trace_variables, *args, **kwargs
|
51
54
|
)
|
@@ -54,7 +57,10 @@ class CustomTracerMixin:
|
|
54
57
|
@functools.wraps(func)
|
55
58
|
def sync_wrapper(*args, **kwargs):
|
56
59
|
sync_wrapper.metadata = metadata
|
57
|
-
|
60
|
+
gt = kwargs.get('gt') if kwargs else None
|
61
|
+
if gt is not None:
|
62
|
+
span = self.span(name)
|
63
|
+
span.add_gt(gt)
|
58
64
|
return self._trace_sync_custom_execution(
|
59
65
|
func, name or func.__name__, custom_type, version, trace_variables, *args, **kwargs
|
60
66
|
)
|
@@ -98,7 +104,7 @@ class CustomTracerMixin:
|
|
98
104
|
|
99
105
|
try:
|
100
106
|
# Execute the function
|
101
|
-
result = func(*args, **kwargs)
|
107
|
+
result = self.file_tracker.trace_wrapper(func)(*args, **kwargs)
|
102
108
|
|
103
109
|
# Calculate resource usage
|
104
110
|
end_time = datetime.now().astimezone().isoformat()
|
@@ -186,7 +192,7 @@ class CustomTracerMixin:
|
|
186
192
|
|
187
193
|
try:
|
188
194
|
# Execute the function
|
189
|
-
result = await func(*args, **kwargs)
|
195
|
+
result = await self.file_tracker.trace_wrapper(func)(*args, **kwargs)
|
190
196
|
|
191
197
|
# Calculate resource usage
|
192
198
|
end_time = datetime.now().astimezone().isoformat()
|
@@ -284,9 +290,13 @@ class CustomTracerMixin:
|
|
284
290
|
"interactions": interactions
|
285
291
|
}
|
286
292
|
|
287
|
-
if self.
|
288
|
-
|
289
|
-
|
293
|
+
if kwargs["name"] in self.span_attributes_dict:
|
294
|
+
span_gt = self.span_attributes_dict[kwargs["name"]].gt
|
295
|
+
if span_gt is not None:
|
296
|
+
component["data"]["gt"] = span_gt
|
297
|
+
span_context = self.span_attributes_dict[kwargs["name"]].context
|
298
|
+
if span_context:
|
299
|
+
component["data"]["context"] = span_context
|
290
300
|
return component
|
291
301
|
|
292
302
|
def start_component(self, component_id):
|
@@ -12,6 +12,7 @@ import contextvars
|
|
12
12
|
import traceback
|
13
13
|
import importlib
|
14
14
|
import sys
|
15
|
+
from litellm import model_cost
|
15
16
|
|
16
17
|
from ..utils.llm_utils import (
|
17
18
|
extract_model_name,
|
@@ -24,8 +25,7 @@ from ..utils.llm_utils import (
|
|
24
25
|
extract_llm_output,
|
25
26
|
num_tokens_from_messages
|
26
27
|
)
|
27
|
-
from ..utils.
|
28
|
-
from ..utils.unique_decorator import generate_unique_hash_simple
|
28
|
+
from ..utils.unique_decorator import generate_unique_hash
|
29
29
|
from ..utils.file_name_tracker import TrackName
|
30
30
|
from ..utils.span_attributes import SpanAttributes
|
31
31
|
import logging
|
@@ -44,7 +44,7 @@ class LLMTracerMixin:
|
|
44
44
|
self.file_tracker = TrackName()
|
45
45
|
self.patches = []
|
46
46
|
try:
|
47
|
-
self.model_costs =
|
47
|
+
self.model_costs = model_cost
|
48
48
|
except Exception as e:
|
49
49
|
self.model_costs = {
|
50
50
|
"default": {"input_cost_per_token": 0.0, "output_cost_per_token": 0.0}
|
@@ -97,6 +97,10 @@ class LLMTracerMixin:
|
|
97
97
|
self.patch_langchain_google_methods(sys.modules["langchain_google_vertexai"])
|
98
98
|
if "langchain_google_genai" in sys.modules:
|
99
99
|
self.patch_langchain_google_methods(sys.modules["langchain_google_genai"])
|
100
|
+
if "langchain_openai" in sys.modules:
|
101
|
+
self.patch_langchain_openai_methods(sys.modules["langchain_openai"])
|
102
|
+
if "langchain_anthropic" in sys.modules:
|
103
|
+
self.patch_langchain_anthropic_methods(sys.modules["langchain_anthropic"])
|
100
104
|
|
101
105
|
# Register hooks for future imports with availability checks
|
102
106
|
if self.check_package_available("vertexai"):
|
@@ -130,6 +134,15 @@ class LLMTracerMixin:
|
|
130
134
|
wrapt.register_post_import_hook(
|
131
135
|
self.patch_langchain_google_methods, "langchain_google_genai"
|
132
136
|
)
|
137
|
+
|
138
|
+
if self.check_package_available("langchain_openai"):
|
139
|
+
wrapt.register_post_import_hook(
|
140
|
+
self.patch_langchain_openai_methods, "langchain_openai"
|
141
|
+
)
|
142
|
+
if self.check_package_available("langchain_anthropic"):
|
143
|
+
wrapt.register_post_import_hook(
|
144
|
+
self.patch_langchain_anthropic_methods, "langchain_anthropic"
|
145
|
+
)
|
133
146
|
|
134
147
|
def instrument_user_interaction_calls(self):
|
135
148
|
"""Enable user interaction instrumentation for LLM calls"""
|
@@ -154,6 +167,42 @@ class LLMTracerMixin:
|
|
154
167
|
except Exception as e:
|
155
168
|
# Log the error but continue execution
|
156
169
|
print(f"Warning: Failed to patch OpenAI methods: {str(e)}")
|
170
|
+
|
171
|
+
def patch_langchain_openai_methods(self, module):
|
172
|
+
try:
|
173
|
+
if hasattr(module, 'ChatOpenAI'):
|
174
|
+
client_class = getattr(module, "ChatOpenAI")
|
175
|
+
|
176
|
+
if hasattr(client_class, "invoke"):
|
177
|
+
self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.invoke")
|
178
|
+
elif hasattr(client_class, "run"):
|
179
|
+
self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.run")
|
180
|
+
if hasattr(module, 'AsyncChatOpenAI'):
|
181
|
+
if hasattr(client_class, "ainvoke"):
|
182
|
+
self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.ainvoke")
|
183
|
+
elif hasattr(client_class, "arun"):
|
184
|
+
self.wrap_langchain_openai_method(client_class, f"{client_class.__name__}.arun")
|
185
|
+
except Exception as e:
|
186
|
+
# Log the error but continue execution
|
187
|
+
print(f"Warning: Failed to patch OpenAI methods: {str(e)}")
|
188
|
+
|
189
|
+
def patch_langchain_anthropic_methods(self, module):
|
190
|
+
try:
|
191
|
+
if hasattr(module, 'ChatAnthropic'):
|
192
|
+
client_class = getattr(module, "ChatAnthropic")
|
193
|
+
if hasattr(client_class, "invoke"):
|
194
|
+
self.wrap_langchain_anthropic_method(client_class, f"{client_class.__name__}.invoke")
|
195
|
+
if hasattr(client_class, "ainvoke"):
|
196
|
+
self.wrap_langchain_anthropic_method(client_class, f"{client_class.__name__}.ainvoke")
|
197
|
+
if hasattr(module, 'AsyncChatAnthropic'):
|
198
|
+
async_client_class = getattr(module, "AsyncChatAnthropic")
|
199
|
+
if hasattr(async_client_class, "ainvoke"):
|
200
|
+
self.wrap_langchain_anthropic_method(async_client_class, f"{async_client_class.__name__}.ainvoke")
|
201
|
+
if hasattr(async_client_class, "arun"):
|
202
|
+
self.wrap_langchain_anthropic_method(async_client_class, f"{async_client_class.__name__}.arun")
|
203
|
+
except Exception as e:
|
204
|
+
# Log the error but continue execution
|
205
|
+
print(f"Warning: Failed to patch Anthropic methods: {str(e)}")
|
157
206
|
|
158
207
|
def patch_openai_beta_methods(self, openai_module):
|
159
208
|
"""
|
@@ -293,7 +342,6 @@ class LLMTracerMixin:
|
|
293
342
|
return await self.trace_llm_call(
|
294
343
|
original_create, *args, **kwargs
|
295
344
|
)
|
296
|
-
|
297
345
|
client_self.chat.completions.create = wrapped_create
|
298
346
|
else:
|
299
347
|
# Patch sync methods for OpenAI
|
@@ -305,10 +353,39 @@ class LLMTracerMixin:
|
|
305
353
|
return self.trace_llm_call_sync(
|
306
354
|
original_create, *args, **kwargs
|
307
355
|
)
|
308
|
-
|
309
356
|
client_self.chat.completions.create = wrapped_create
|
310
357
|
|
311
358
|
setattr(client_class, "__init__", patched_init)
|
359
|
+
|
360
|
+
def wrap_langchain_openai_method(self, client_class, method_name):
|
361
|
+
method = method_name.split(".")[-1]
|
362
|
+
original_init = getattr(client_class, method)
|
363
|
+
|
364
|
+
@functools.wraps(original_init)
|
365
|
+
def patched_init(*args, **kwargs):
|
366
|
+
# Check if this is AsyncOpenAI or OpenAI
|
367
|
+
is_async = "AsyncChatOpenAI" in client_class.__name__
|
368
|
+
|
369
|
+
if is_async:
|
370
|
+
return self.trace_llm_call(original_init, *args, **kwargs)
|
371
|
+
else:
|
372
|
+
return self.trace_llm_call_sync(original_init, *args, **kwargs)
|
373
|
+
|
374
|
+
setattr(client_class, method, patched_init)
|
375
|
+
|
376
|
+
def wrap_langchain_anthropic_method(self, client_class, method_name):
|
377
|
+
original_init = getattr(client_class, method_name)
|
378
|
+
|
379
|
+
@functools.wraps(original_init)
|
380
|
+
def patched_init(*args, **kwargs):
|
381
|
+
is_async = "AsyncChatAnthropic" in client_class.__name__
|
382
|
+
|
383
|
+
if is_async:
|
384
|
+
return self.trace_llm_call(original_init, *args, **kwargs)
|
385
|
+
else:
|
386
|
+
return self.trace_llm_call_sync(original_init, *args, **kwargs)
|
387
|
+
|
388
|
+
setattr(client_class, method_name, patched_init)
|
312
389
|
|
313
390
|
def wrap_anthropic_client_methods(self, client_class):
|
314
391
|
original_init = client_class.__init__
|
@@ -475,8 +552,13 @@ class LLMTracerMixin:
|
|
475
552
|
"interactions": interactions,
|
476
553
|
}
|
477
554
|
|
478
|
-
if self.
|
479
|
-
|
555
|
+
if name in self.span_attributes_dict:
|
556
|
+
span_gt = self.span_attributes_dict[name].gt
|
557
|
+
if span_gt is not None:
|
558
|
+
component["data"]["gt"] = span_gt
|
559
|
+
span_context = self.span_attributes_dict[name].context
|
560
|
+
if span_context:
|
561
|
+
component["data"]["context"] = span_context
|
480
562
|
|
481
563
|
# Reset the SpanAttributes context variable
|
482
564
|
self.span_attributes_dict[name] = SpanAttributes(name)
|
@@ -503,7 +585,7 @@ class LLMTracerMixin:
|
|
503
585
|
start_time = datetime.now().astimezone()
|
504
586
|
start_memory = psutil.Process().memory_info().rss
|
505
587
|
component_id = str(uuid.uuid4())
|
506
|
-
hash_id =
|
588
|
+
hash_id = generate_unique_hash(original_func, args, kwargs)
|
507
589
|
|
508
590
|
# Start tracking network calls for this component
|
509
591
|
self.start_component(component_id)
|
@@ -605,7 +687,7 @@ class LLMTracerMixin:
|
|
605
687
|
|
606
688
|
start_time = datetime.now().astimezone()
|
607
689
|
component_id = str(uuid.uuid4())
|
608
|
-
hash_id =
|
690
|
+
hash_id = generate_unique_hash(original_func, args, kwargs)
|
609
691
|
|
610
692
|
# Start tracking network calls for this component
|
611
693
|
self.start_component(component_id)
|
@@ -745,12 +827,14 @@ class LLMTracerMixin:
|
|
745
827
|
@self.file_tracker.trace_decorator
|
746
828
|
@functools.wraps(func)
|
747
829
|
async def async_wrapper(*args, **kwargs):
|
748
|
-
|
830
|
+
gt = kwargs.get("gt") if kwargs else None
|
831
|
+
if gt is not None:
|
832
|
+
span = self.span(name)
|
833
|
+
span.add_gt(gt)
|
749
834
|
self.current_llm_call_name.set(name)
|
750
835
|
if not self.is_active:
|
751
836
|
return await func(*args, **kwargs)
|
752
837
|
|
753
|
-
hash_id = generate_unique_hash_simple(func)
|
754
838
|
component_id = str(uuid.uuid4())
|
755
839
|
parent_agent_id = self.current_agent_id.get()
|
756
840
|
self.start_component(component_id)
|
@@ -777,8 +861,13 @@ class LLMTracerMixin:
|
|
777
861
|
if (name is not None) or (name != ""):
|
778
862
|
llm_component["name"] = name
|
779
863
|
|
780
|
-
if self.
|
781
|
-
|
864
|
+
if name in self.span_attributes_dict:
|
865
|
+
span_gt = self.span_attributes_dict[name].gt
|
866
|
+
if span_gt is not None:
|
867
|
+
llm_component["data"]["gt"] = span_gt
|
868
|
+
span_context = self.span_attributes_dict[name].context
|
869
|
+
if span_context:
|
870
|
+
llm_component["data"]["context"] = span_context
|
782
871
|
|
783
872
|
if error_info:
|
784
873
|
llm_component["error"] = error_info["error"]
|
@@ -811,13 +900,14 @@ class LLMTracerMixin:
|
|
811
900
|
@self.file_tracker.trace_decorator
|
812
901
|
@functools.wraps(func)
|
813
902
|
def sync_wrapper(*args, **kwargs):
|
814
|
-
|
903
|
+
gt = kwargs.get("gt") if kwargs else None
|
904
|
+
if gt is not None:
|
905
|
+
span = self.span(name)
|
906
|
+
span.add_gt(gt)
|
815
907
|
self.current_llm_call_name.set(name)
|
816
908
|
if not self.is_active:
|
817
909
|
return func(*args, **kwargs)
|
818
910
|
|
819
|
-
hash_id = generate_unique_hash_simple(func)
|
820
|
-
|
821
911
|
component_id = str(uuid.uuid4())
|
822
912
|
parent_agent_id = self.current_agent_id.get()
|
823
913
|
self.start_component(component_id)
|