ragaai-catalyst 2.1.5b25__py3-none-any.whl → 2.1.5b26__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -37,7 +37,7 @@ def init_tracing(
     secret_key: str = None,
     base_url: str = None,
     tracer: Tracer = None,
-    catalyst: RagaAICatalyst = None,
+    catalyst: RagaAICatalyst = None,
     **kwargs
 ) -> None:
     """Initialize distributed tracing.
@@ -50,37 +50,20 @@ def init_tracing(
         base_url: RagaAI Catalyst API base URL
         tracer: Existing Tracer instance
         catalyst: Existing RagaAICatalyst instance
-        **kwargs: Additional tracer configuration
+        **kwargs: Additional tracer parameters
     """
     global _global_tracer, _global_catalyst

     with _tracer_lock:
         if tracer and catalyst:
-            _global_tracer = tracer
-            _global_catalyst = catalyst
+            if isinstance(tracer, Tracer) and isinstance(catalyst, RagaAICatalyst):
+                _global_tracer = tracer
+                _global_catalyst = catalyst
+            else:
+                raise ValueError("Both Tracer and Catalyst objects must be instances of Tracer and RagaAICatalyst, respectively.")
         else:
-            # Use env vars as fallback
-            access_key = access_key or os.getenv("RAGAAI_CATALYST_ACCESS_KEY")
-            secret_key = secret_key or os.getenv("RAGAAI_CATALYST_SECRET_KEY")
-            base_url = base_url or os.getenv("RAGAAI_CATALYST_BASE_URL")
+            raise ValueError("Both Tracer and Catalyst objects must be provided.")

-            if not all([access_key, secret_key]):
-                raise ValueError(
-                    "Missing required credentials. Either provide access_key and secret_key "
-                    "or set RAGAAI_CATALYST_ACCESS_KEY and RAGAAI_CATALYST_SECRET_KEY environment variables."
-                )
-
-            _global_catalyst = RagaAICatalyst(
-                access_key=access_key,
-                secret_key=secret_key,
-                base_url=base_url
-            )
-
-            _global_tracer = Tracer(
-                project_name=project_name,
-                dataset_name=dataset_name,
-                **kwargs
-            )

 def trace_agent(name: str = None, agent_type: str = "generic", version: str = "1.0.0", **kwargs):
     """Decorator for tracing agent functions."""
@@ -162,7 +145,7 @@ def trace_llm(name: str = None, model: str = None, **kwargs):

             try:
                 # Just execute the function within the current span
-                result = await tracer.file_tracker.trace_wrapper(func)(*args, **kwargs)
+                result = await func(*args, **kwargs)
                 return result
             finally:
                 # Reset using the stored token
@@ -180,7 +163,7 @@ def trace_llm(name: str = None, model: str = None, **kwargs):

             try:
                 # Just execute the function within the current span
-                result = tracer.file_tracker.trace_wrapper(func)(*args, **kwargs)
+                result = func(*args, **kwargs)
                 return result
             finally:
                 # Reset using the stored token
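
Both the async and sync paths of trace_llm now call the wrapped function directly instead of routing it through tracer.file_tracker.trace_wrapper. A hedged sketch of the resulting decorator shape for the async path (the context variable and its name are illustrative; only the try/finally token pattern comes from the diff):

import contextvars
from functools import wraps

# Illustrative stand-in for the tracer's span-name context variable.
_span_name = contextvars.ContextVar("span_name", default=None)

def trace_llm(name: str = None, model: str = None, **kwargs):
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            token = _span_name.set(name or func.__name__)
            try:
                # Just execute the function within the current span
                return await func(*args, **kwargs)
            finally:
                # Reset using the stored token
                _span_name.reset(token)
        return async_wrapper
    return decorator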
@@ -159,7 +159,7 @@ class LangchainTracer(BaseCallbackHandler):
         else:
             asyncio.run(self._async_save_trace(force))

-    def _create_safe_wrapper(self, original_func, component_name):
+    def _create_safe_wrapper(self, original_func, component_name, method_name):
         """Create a safely wrapped version of an original function with enhanced error handling"""

         @wraps(original_func)
@@ -209,7 +209,59 @@ class LangchainTracer(BaseCallbackHandler):

                 # Fallback to calling the original function without modifications
                 return original_func(*args, **kwargs)
+
+        @wraps(original_func)
+        def wrapped_invoke(*args, **kwargs):
+            if not self._active:
+                return original_func(*args, **kwargs)
+
+            try:
+                # Deep copy kwargs to avoid modifying the original
+                kwargs_copy = kwargs.copy() if kwargs is not None else {}
+
+                # Handle different calling conventions
+                if 'config' not in kwargs_copy:
+                    kwargs_copy['config'] = {'callbacks': [self]}
+                elif 'callbacks' not in kwargs_copy['config']:
+                    kwargs_copy['config']['callbacks'] = [self]
+                elif self not in kwargs_copy['config']['callbacks']:
+                    kwargs_copy['config']['callbacks'].append(self)
+
+                # Store model name if available
+                if component_name in ["OpenAI", "ChatOpenAI_LangchainOpenAI", "ChatOpenAI_ChatModels",
+                                      "ChatVertexAI", "VertexAI", "ChatGoogleGenerativeAI", "ChatAnthropic",
+                                      "ChatLiteLLM", "ChatBedrock", "AzureChatOpenAI", "ChatAnthropicVertex"]:
+                    instance = args[0] if args else None
+                    model_name = kwargs.get('model_name') or kwargs.get('model') or kwargs.get('model_id')

+                    if instance and model_name:
+                        self.model_names[id(instance)] = model_name
+
+                # Try different method signatures
+                try:
+                    # First, try calling with modified kwargs
+                    return original_func(*args, **kwargs_copy)
+                except TypeError:
+                    # If that fails, try without kwargs
+                    try:
+                        return original_func(*args)
+                    except Exception as e:
+                        # If all else fails, use original call
+                        logger.error(f"Failed to invoke {component_name} with modified callbacks: {e}")
+                        return original_func(*args, **kwargs)
+
+            except Exception as e:
+                # Log any errors that occur during the function call
+                logger.error(f"Error in {component_name} wrapper: {e}")
+
+                # Record the error using the tracer's error handling method
+                self.on_error(e, context=f"wrapper_{component_name}")
+
+                # Fallback to calling the original function without modifications
+                return original_func(*args, **kwargs)
+
+        if method_name == 'invoke':
+            return wrapped_invoke
         return wrapped

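The net effect of wrapped_invoke is that a plain qa_chain.invoke({"query": ...}) behaves as if the caller had passed config={"callbacks": [tracer]} explicitly. Note that kwargs.copy() above is a shallow copy despite the comment, so when the caller already supplies a callbacks list, the wrapper appends to that list in place. A standalone sketch of just the injection step (inject_callbacks is a hypothetical helper, and it assumes list-style callbacks):

def inject_callbacks(kwargs: dict, handler) -> dict:
    """Return a copy of kwargs with handler added to config['callbacks']."""
    kwargs = dict(kwargs)                          # shallow copy, as above
    config = kwargs.setdefault('config', {})
    callbacks = config.setdefault('callbacks', [])
    if handler not in callbacks:
        callbacks.append(handler)                  # mutates an existing list in place
    return kwargs
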
@@ -287,6 +339,7 @@ class LangchainTracer(BaseCallbackHandler):
             from langchain.chains import create_retrieval_chain, RetrievalQA
             components_to_patch["RetrievalQA"] = (RetrievalQA, "from_chain_type")
             components_to_patch["create_retrieval_chain"] = (create_retrieval_chain, None)
+            components_to_patch['RetrievalQA.invoke'] = (RetrievalQA, 'invoke')
         except ImportError:
             logger.debug("Langchain chains not available for patching")

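The patch table maps a registry name to a (component, method_name) pair, and method_name selects the patching strategy applied in the next hunk. The entries registered here cover three shapes:

components_to_patch = {
    "RetrievalQA": (RetrievalQA, "from_chain_type"),           # classmethod constructor
    "create_retrieval_chain": (create_retrieval_chain, None),  # module-level function
    "RetrievalQA.invoke": (RetrievalQA, "invoke"),             # instance method, new in b26
}
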
@@ -295,20 +348,20 @@ class LangchainTracer(BaseCallbackHandler):
                 if method_name == "__init__":
                     original = component.__init__
                     self._original_inits[name] = original
-                    component.__init__ = self._create_safe_wrapper(original, name)
+                    component.__init__ = self._create_safe_wrapper(original, name, method_name)
                 elif method_name:
                     original = getattr(component, method_name)
                     self._original_methods[name] = original
                     if isinstance(original, classmethod):
                         wrapped = classmethod(
-                            self._create_safe_wrapper(original.__func__, name)
+                            self._create_safe_wrapper(original.__func__, name, method_name)
                         )
                     else:
-                        wrapped = self._create_safe_wrapper(original, name)
+                        wrapped = self._create_safe_wrapper(original, name, method_name)
                     setattr(component, method_name, wrapped)
                 else:
                     self._original_methods[name] = component
-                    globals()[name] = self._create_safe_wrapper(component, name)
+                    globals()[name] = self._create_safe_wrapper(component, name, method_name)
             except Exception as e:
                 logger.error(f"Error patching {name}: {e}")
                 self.on_error(e, context=f"patch_{name}")
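
The method saves each original callable in _original_inits or _original_methods before replacing it, so the patched library can later be restored. A minimal, generic sketch of that save-wrap-restore pattern (the Patcher class is illustrative, not the tracer's actual bookkeeping):

import functools

class Patcher:
    def __init__(self):
        self._originals = {}

    def patch(self, owner, method_name, on_call):
        # Read via __dict__ so classmethod/staticmethod descriptors survive a restore.
        original = owner.__dict__[method_name]
        self._originals[(owner, method_name)] = original
        func = original.__func__ if isinstance(original, classmethod) else original

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            on_call(method_name)           # e.g. attach tracing before the real call
            return func(*args, **kwargs)

        setattr(owner, method_name,
                classmethod(wrapped) if isinstance(original, classmethod) else wrapped)

    def restore(self):
        for (owner, method_name), original in self._originals.items():
            setattr(owner, method_name, original)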
@@ -354,7 +407,7 @@ class LangchainTracer(BaseCallbackHandler):
             elif name == "ChatOpenAI_ChatModels":
                 from langchain.chat_models import ChatOpenAI as ChatOpenAI_ChatModels
                 imported_components[name] = ChatOpenAI_ChatModels
-            elif name in ["RetrievalQA", "create_retrieval_chain"]:
+            elif name in ["RetrievalQA", "create_retrieval_chain", 'RetrievalQA.invoke']:
                 from langchain.chains import create_retrieval_chain, RetrievalQA
                 imported_components["RetrievalQA"] = RetrievalQA
                 imported_components["create_retrieval_chain"] = create_retrieval_chain
@@ -0,0 +1,424 @@
+from configparser import InterpolationMissingOptionError
+import json
+from datetime import datetime
+from typing import Any, Optional, Dict, List, ClassVar
+from pydantic import Field
+# from treelib import Tree
+
+from llama_index.core.instrumentation.span import SimpleSpan
+from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler
+from llama_index.core.instrumentation.events import BaseEvent
+from llama_index.core.instrumentation.event_handlers import BaseEventHandler
+from llama_index.core.instrumentation import get_dispatcher
+from llama_index.core.instrumentation.span_handlers import SimpleSpanHandler
+
+from llama_index.core.instrumentation.events.agent import (
+    AgentChatWithStepStartEvent,
+    AgentChatWithStepEndEvent,
+    AgentRunStepStartEvent,
+    AgentRunStepEndEvent,
+    AgentToolCallEvent,
+)
+from llama_index.core.instrumentation.events.chat_engine import (
+    StreamChatErrorEvent,
+    StreamChatDeltaReceivedEvent,
+)
+from llama_index.core.instrumentation.events.embedding import (
+    EmbeddingStartEvent,
+    EmbeddingEndEvent,
+)
+from llama_index.core.instrumentation.events.llm import (
+    LLMPredictEndEvent,
+    LLMPredictStartEvent,
+    LLMStructuredPredictEndEvent,
+    LLMStructuredPredictStartEvent,
+    LLMCompletionEndEvent,
+    LLMCompletionStartEvent,
+    LLMChatEndEvent,
+    LLMChatStartEvent,
+    LLMChatInProgressEvent,
+)
+from llama_index.core.instrumentation.events.query import (
+    QueryStartEvent,
+    QueryEndEvent,
+)
+from llama_index.core.instrumentation.events.rerank import (
+    ReRankStartEvent,
+    ReRankEndEvent,
+)
+from llama_index.core.instrumentation.events.retrieval import (
+    RetrievalStartEvent,
+    RetrievalEndEvent,
+)
+from llama_index.core.instrumentation.events.span import (
+    SpanDropEvent,
+)
+from llama_index.core.instrumentation.events.synthesis import (
+    SynthesizeStartEvent,
+    SynthesizeEndEvent,
+    GetResponseEndEvent,
+    GetResponseStartEvent,
+)
+
+import uuid
+
+from .utils.extraction_logic_llama_index import extract_llama_index_data
+from .utils.convert_llama_instru_callback import convert_llamaindex_instrumentation_to_callback
+
+class EventHandler(BaseEventHandler):
+    """Example event handler.
+
+    This event handler is an example of how to create a custom event handler.
+
+    In general, logged events are treated as single events in a point in time,
+    that link to a span. The span is a collection of events that are related to
+    a single task. The span is identified by a unique span_id.
+
+    While events are independent, there is some hierarchy.
+    For example, in query_engine.query() call with a reranker attached:
+    - QueryStartEvent
+    - RetrievalStartEvent
+    - EmbeddingStartEvent
+    - EmbeddingEndEvent
+    - RetrievalEndEvent
+    - RerankStartEvent
+    - RerankEndEvent
+    - SynthesizeStartEvent
+    - GetResponseStartEvent
+    - LLMPredictStartEvent
+    - LLMChatStartEvent
+    - LLMChatEndEvent
+    - LLMPredictEndEvent
+    - GetResponseEndEvent
+    - SynthesizeEndEvent
+    - QueryEndEvent
+    """
+
+    events: List[BaseEvent] = []
+    current_trace: List[Dict[str, Any]] = []  # Store events for the current trace
+
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Class name."""
+        return "EventHandler"
+
+    def handle(self, event: BaseEvent) -> None:
+        """Logic for handling event."""
+        # print("-----------------------")
+        # # all events have these attributes
+        # print(event.id_)
+        # print(event.timestamp)
+        # print(event.span_id)
+
+        # Prepare event details dictionary
+        event_details = {
+            "id": event.id_,
+            "timestamp": event.timestamp,
+            "span_id": event.span_id,
+            "event_type": event.class_name(),
+        }
+
+        # event specific attributes
+        # print(f"Event type: {event.class_name()}")
+        if isinstance(event, AgentRunStepStartEvent):
+            event_details.update({
+                "task_id": event.task_id,
+                "step": event.step,
+                "input": event.input,
+            })
+        if isinstance(event, AgentRunStepEndEvent):
+            event_details.update({
+                "step_output": event.step_output,
+            })
+        if isinstance(event, AgentChatWithStepStartEvent):
+            event_details.update({
+                "user_msg": event.user_msg,
+            })
+        if isinstance(event, AgentChatWithStepEndEvent):
+            event_details.update({
+                "response": event.response,
+            })
+        if isinstance(event, AgentToolCallEvent):
+            event_details.update({
+                "arguments": event.arguments,
+                "tool_name": event.tool.name,
+                "tool_description": event.tool.description,
+                "tool_openai": event.tool.to_openai_tool(),
+            })
+        if isinstance(event, StreamChatDeltaReceivedEvent):
+            event_details.update({
+                "delta": event.delta,
+            })
+        if isinstance(event, StreamChatErrorEvent):
+            event_details.update({
+                "exception": event.exception,
+            })
+        if isinstance(event, EmbeddingStartEvent):
+            event_details.update({
+                "model_dict": event.model_dict,
+            })
+        if isinstance(event, EmbeddingEndEvent):
+            event_details.update({
+                "chunks": event.chunks,
+                "embeddings": event.embeddings[0][:5],
+            })
+        if isinstance(event, LLMPredictStartEvent):
+            event_details.update({
+                "template": event.template,
+                "template_args": event.template_args,
+            })
+        if isinstance(event, LLMPredictEndEvent):
+            event_details.update({
+                "output": event.output,
+            })
+        if isinstance(event, LLMStructuredPredictStartEvent):
+            event_details.update({
+                "template": event.template,
+                "template_args": event.template_args,
+                "output_cls": event.output_cls,
+            })
+        if isinstance(event, LLMStructuredPredictEndEvent):
+            event_details.update({
+                "output": event.output,
+            })
+        if isinstance(event, LLMCompletionStartEvent):
+            event_details.update({
+                "model_dict": event.model_dict,
+                "prompt": event.prompt,
+                "additional_kwargs": event.additional_kwargs,
+            })
+        if isinstance(event, LLMCompletionEndEvent):
+            event_details.update({
+                "response": event.response,
+                "prompt": event.prompt,
+            })
+        if isinstance(event, LLMChatInProgressEvent):
+            event_details.update({
+                "messages": event.messages,
+                "response": event.response,
+            })
+        if isinstance(event, LLMChatStartEvent):
+            event_details.update({
+                "messages": event.messages,
+                "additional_kwargs": event.additional_kwargs,
+                "model_dict": event.model_dict,
+            })
+        if isinstance(event, LLMChatEndEvent):
+            event_details.update({
+                "messages": event.messages,
+                "response": event.response,
+            })
+        if isinstance(event, RetrievalStartEvent):
+            event_details.update({
+                "str_or_query_bundle": event.str_or_query_bundle,
+            })
+        if isinstance(event, RetrievalEndEvent):
+            event_details.update({
+                "str_or_query_bundle": event.str_or_query_bundle,
+                "nodes": event.nodes,
+                "text": event.nodes[0].text
+            })
+        if isinstance(event, ReRankStartEvent):
+            event_details.update({
+                "query": event.query,
+                "nodes": event.nodes,
+                "top_n": event.top_n,
+                "model_name": event.model_name,
+            })
+        if isinstance(event, ReRankEndEvent):
+            event_details.update({
+                "nodes": event.nodes,
+            })
+        if isinstance(event, QueryStartEvent):
+            event_details.update({
+                "query": event.query,
+            })
+        if isinstance(event, QueryEndEvent):
+            event_details.update({
+                "response": event.response,
+                "query": event.query,
+            })
+        if isinstance(event, SpanDropEvent):
+            event_details.update({
+                "err_str": event.err_str,
+            })
+        if isinstance(event, SynthesizeStartEvent):
+            event_details.update({
+                "query": event.query,
+            })
+        if isinstance(event, SynthesizeEndEvent):
+            event_details.update({
+                "response": event.response,
+                "query": event.query,
+            })
+        if isinstance(event, GetResponseStartEvent):
+            event_details.update({
+                "query_str": event.query_str,
+            })
+
+        # Append event details to current_trace
+        self.current_trace.append(event_details)
+
+        self.events.append(event)
+
+    def _get_events_by_span(self) -> Dict[str, List[BaseEvent]]:
+        events_by_span: Dict[str, List[BaseEvent]] = {}
+        for event in self.events:
+            if event.span_id in events_by_span:
+                events_by_span[event.span_id].append(event)
+            else:
+                events_by_span[event.span_id] = [event]
+        return events_by_span
+
+    # def _get_event_span_trees(self) -> List[Tree]:
+    #     events_by_span = self._get_events_by_span()
+
+    #     trees = []
+    #     tree = Tree()
+
+    #     for span, sorted_events in events_by_span.items():
+    #         # create root node i.e. span node
+    #         tree.create_node(
+    #             tag=f"{span} (SPAN)",
+    #             identifier=span,
+    #             parent=None,
+    #             data=sorted_events[0].timestamp,
+    #         )
+
+    #         for event in sorted_events:
+    #             tree.create_node(
+    #                 tag=f"{event.class_name()}: {event.id_}",
+    #                 identifier=event.id_,
+    #                 parent=event.span_id,
+    #                 data=event.timestamp,
+    #             )
+
+    #         trees.append(tree)
+    #         tree = Tree()
+    #     return trees
+
+    # def print_event_span_trees(self) -> None:
+    #     """Method for viewing trace trees."""
+    #     trees = self._get_event_span_trees()
+    #     for tree in trees:
+    #         print(
+    #             tree.show(
+    #                 stdout=False, sorting=True, key=lambda node: node.data
+    #             )
+    #         )
+    #         print("")
+
+
+
+class SpanHandler(BaseSpanHandler[SimpleSpan]):
+    # span_dict = {}
+    span_dict: ClassVar[Dict[str, List[SimpleSpan]]] = {}
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Class name."""
+        return "SpanHandler"
+
+    def new_span(
+        self,
+        id_: str,
+        bound_args: Any,
+        instance: Optional[Any] = None,
+        parent_span_id: Optional[str] = None,
+        tags: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Optional[SimpleSpan]:
+        """Create a span."""
+        # logic for creating a new MyCustomSpan
+        if id_ not in self.span_dict:
+            self.span_dict[id_] = []
+        self.span_dict[id_].append(
+            SimpleSpan(id_=id_, parent_id=parent_span_id)
+        )
+
+    def prepare_to_exit_span(
+        self,
+        id_: str,
+        bound_args: Any,
+        instance: Optional[Any] = None,
+        result: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Logic for preparing to exit a span."""
+        pass
+        # if id in self.span_dict:
+        #     return self.span_dict[id].pop()
+
+    def prepare_to_drop_span(
+        self,
+        id_: str,
+        bound_args: Any,
+        instance: Optional[Any] = None,
+        err: Optional[BaseException] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Logic for preparing to drop a span."""
+        pass
+        # if id in self.span_dict:
+        #     return self.span_dict[id].pop()
+
+
+
+class LlamaIndexInstrumentationTracer:
+    def __init__(self, user_detail):
+        """Initialize the LlamaIndexTracer with handlers but don't start tracing yet."""
+        # Initialize the root dispatcher
+        self.root_dispatcher = get_dispatcher()
+
+        # Initialize handlers
+        self.json_event_handler = EventHandler()
+        self.span_handler = SpanHandler()
+        self.simple_span_handler = SimpleSpanHandler()
+
+        self.is_tracing = False  # Flag to check if tracing is active
+
+        self.user_detail = user_detail
+
+    def start(self):
+        """Start tracing by registering handlers."""
+        if self.is_tracing:
+            print("Tracing is already active.")
+            return
+
+        # Register handlers
+        self.root_dispatcher.add_span_handler(self.span_handler)
+        self.root_dispatcher.add_span_handler(self.simple_span_handler)
+        self.root_dispatcher.add_event_handler(self.json_event_handler)
+
+        self.is_tracing = True
+        print("Tracing started.")
+
+    def stop(self):
+        """Stop tracing by unregistering handlers."""
+        if not self.is_tracing:
+            print("Tracing is not active.")
+            return
+
+        # Write current_trace to a JSON file
+        final_traces = {
+            "project_id": self.user_detail["project_id"],
+            "trace_id": str(uuid.uuid4()),
+            "session_id": None,
+            "trace_type": "llamaindex",
+            "metadata": self.user_detail["trace_user_detail"]["metadata"],
+            "pipeline": self.user_detail["trace_user_detail"]["pipeline"],
+            "traces": self.json_event_handler.current_trace,
+
+        }
+
+        with open('new_llamaindex_traces.json', 'w') as f:
+            json.dump([final_traces], f, default=str, indent=4)
+
+        llamaindex_instrumentation_data = extract_llama_index_data([final_traces])
+        converted_back_to_callback = convert_llamaindex_instrumentation_to_callback(llamaindex_instrumentation_data)
+
+        # Just indicate tracing is stopped
+        self.is_tracing = False
+        print("Tracing stopped.")
+        return converted_back_to_callback
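
This whole file is new in 2.1.5b26: it registers event and span handlers on LlamaIndex's root dispatcher, accumulates per-event details into a trace, and on stop() writes new_llamaindex_traces.json and converts the result into the package's callback trace format. A hypothetical usage sketch (the user_detail shape is inferred from the keys stop() reads; the index setup is illustrative):

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

user_detail = {
    "project_id": "my-project-id",                     # placeholder
    "trace_user_detail": {"metadata": {}, "pipeline": []},
}

tracer = LlamaIndexInstrumentationTracer(user_detail)
tracer.start()   # registers the span/event handlers

index = VectorStoreIndex.from_documents(SimpleDirectoryReader("data").load_data())
response = index.as_query_engine().query("What is this document about?")

callback_trace = tracer.stop()   # writes new_llamaindex_traces.json, returns callback-format trace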