ragaai-catalyst 2.1.7.5b5__py3-none-any.whl → 2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,206 +11,32 @@ logging_level = (
11
11
  logger.setLevel(logging.DEBUG) if os.getenv("DEBUG") == "1" else logging.INFO
12
12
  )
13
13
 
14
- def rag_trace_json_converter(input_trace, custom_model_cost, trace_id, user_details, tracer_type,user_context):
14
+ def rag_trace_json_converter(input_trace, custom_model_cost, trace_id, user_details, tracer_type, user_context):
15
15
  trace_aggregate = {}
16
- def get_prompt(input_trace):
17
- try:
18
- if tracer_type == "langchain":
19
- for span in input_trace:
20
- try:
21
- # First check if there's a user message in any of the input messages
22
- attributes = span.get("attributes", {})
23
-
24
- # Look for user role in any of the input messages
25
- if attributes:
26
- for key, value in attributes.items():
27
- try:
28
- if key.startswith("llm.input_messages.") and key.endswith(".message.role") and value == "user":
29
- # Extract the message number
30
- message_num = key.split(".")[2]
31
- # Construct the content key
32
- content_key = f"llm.input_messages.{message_num}.message.content"
33
- if content_key in attributes:
34
- return attributes.get(content_key)
35
- except Exception as e:
36
- logger.warning(f"Error processing attribute key-value pair: {str(e)}")
37
- continue
38
-
39
- for key, value in attributes.items():
40
- try:
41
- if key.startswith("llm.prompts") and isinstance(value, list):
42
- human_message = None
43
- for message in value:
44
- if isinstance(message, str):
45
- human_index = message.find("Human:")
46
- if human_index != -1:
47
- human_message = message[human_index:].replace("Human:", "")
48
- break
49
- return human_message if human_message else value
50
- except Exception as e:
51
- logger.warning(f"Error processing attribute key-value pair for prompt: {str(e)}")
52
- continue
53
- except Exception as e:
54
- logger.warning(f"Error processing span for prompt extraction: {str(e)}")
55
- continue
56
-
57
- for span in input_trace:
58
- try:
59
- # If no user message found, check for specific span types
60
- if span["name"] == "LLMChain":
61
- try:
62
- input_value = span["attributes"].get("input.value", "{}")
63
- return json.loads(input_value).get("question", "")
64
- except json.JSONDecodeError:
65
- logger.warning(f"Invalid JSON in LLMChain input.value: {input_value}")
66
- continue
67
- elif span["name"] == "RetrievalQA":
68
- return span["attributes"].get("input.value", "")
69
- elif span["name"] == "VectorStoreRetriever":
70
- return span["attributes"].get("input.value", "")
71
- except Exception as e:
72
- logger.warning(f"Error processing span for fallback prompt extraction: {str(e)}")
73
- continue
74
-
75
- # If we've gone through all spans and found nothing
76
- logger.warning("No user message found in any span")
77
- logger.warning("Returning empty string for prompt.")
78
- return ""
79
-
80
- logger.error("Prompt not found in the trace")
81
- return None
82
- except Exception as e:
83
- logger.error(f"Error while extracting prompt from trace: {str(e)}")
84
- return None
85
-
86
- def get_response(input_trace):
87
- try:
88
- if tracer_type == "langchain":
89
- for span in input_trace:
90
- try:
91
- attributes = span.get("attributes", {})
92
- if attributes:
93
- for key, value in attributes.items():
94
- try:
95
- if key.startswith("llm.output_messages.") and key.endswith(".message.content"):
96
- return value
97
- except Exception as e:
98
- logger.warning(f"Error processing attribute key-value pair for response: {str(e)}")
99
- continue
100
-
101
- for key, value in attributes.items():
102
- try:
103
- if key.startswith("output.value"):
104
- try:
105
- output_json = json.loads(value)
106
- if "generations" in output_json and isinstance(output_json.get("generations"), list) and len(output_json.get("generations")) > 0:
107
- if isinstance(output_json.get("generations")[0], list) and len(output_json.get("generations")[0]) > 0:
108
- first_generation = output_json.get("generations")[0][0]
109
- if "text" in first_generation:
110
- return first_generation["text"]
111
- except json.JSONDecodeError:
112
- logger.warning(f"Invalid JSON in output.value: {value}")
113
- continue
114
- except Exception as e:
115
- logger.warning(f"Error processing attribute key-value pair for response: {str(e)}")
116
- continue
117
- except Exception as e:
118
- logger.warning(f"Error processing span for response extraction: {str(e)}")
119
- continue
120
-
121
- for span in input_trace:
122
- try:
123
- if span["name"] == "LLMChain":
124
- try:
125
- output_value = span["attributes"].get("output.value", "")
126
- if output_value:
127
- return json.loads(output_value)
128
- return ""
129
- except json.JSONDecodeError:
130
- logger.warning(f"Invalid JSON in LLMChain output.value: {output_value}")
131
- continue
132
- elif span["name"] == "RetrievalQA":
133
- return span["attributes"].get("output.value", "")
134
- elif span["name"] == "VectorStoreRetriever":
135
- return span["attributes"].get("output.value", "")
136
- except Exception as e:
137
- logger.warning(f"Error processing span for fallback response extraction: {str(e)}")
138
- continue
139
-
140
- logger.warning("No response found in any span")
141
- return ""
142
-
143
- logger.error("Response not found in the trace")
144
- return None
145
- except Exception as e:
146
- logger.error(f"Error while extracting response from trace: {str(e)}")
147
- return None
148
-
149
- def get_context(input_trace):
150
- try:
151
- if user_context and user_context.strip():
152
- return user_context
153
- elif tracer_type == "langchain":
154
- for span in input_trace:
155
- try:
156
- if span["name"] == "VectorStoreRetriever":
157
- return span["attributes"].get("retrieval.documents.1.document.content", "")
158
- except Exception as e:
159
- logger.warning(f"Error processing span for context extraction: {str(e)}")
160
- continue
161
-
162
- logger.warning("Context not found in the trace")
163
- return ""
164
- except Exception as e:
165
- logger.error(f"Error while extracting context from trace: {str(e)}")
166
- return ""
167
-
168
- def get_span_errors(input_trace):
169
- try:
170
- if tracer_type == "langchain":
171
- span_errors = {}
172
- for span in input_trace:
173
- try:
174
- if "status" in span.keys() and span.get("status", {}).get("status_code", "").lower() == "error":
175
- span_errors[f"{span['name']}"] = span["status"]
176
- except:
177
- logger.error(f"Error fetching status from span")
178
- return span_errors
179
- except:
180
- logger.error(f"Error in get_span_errors")
181
- return None
182
-
183
-
184
-
185
-
186
-
187
- prompt = get_prompt(input_trace)
188
- response = get_response(input_trace)
189
- context = get_context(input_trace)
190
- error = get_span_errors(input_trace)
16
+ input_trace = add_span_hash_id(input_trace)
17
+ prompt = get_prompt(input_trace, tracer_type)
18
+ response = get_response(input_trace, tracer_type)
19
+ context = get_context(input_trace, tracer_type, user_context)
20
+ error = get_span_errors(input_trace, tracer_type)
191
21
 
192
22
  if tracer_type == "langchain":
193
23
  trace_aggregate["tracer_type"] = "langchain"
194
- else:
24
+ elif tracer_type == "llamaindex":
195
25
  trace_aggregate["tracer_type"] = "llamaindex"
196
26
 
197
- trace_aggregate['trace_id'] = trace_id
198
- trace_aggregate['session_id'] = None
27
+ trace_aggregate['id'] = trace_id
28
+ trace_aggregate['trace_name'] = user_details.get("dataset_name", "")
29
+ trace_aggregate['project_name'] = user_details.get("project_name", "")
30
+ trace_aggregate["start_time"] = input_trace[0].get("start_time", "")
31
+ trace_aggregate["end_time"] = input_trace[-1].get("end_time", "")
199
32
  trace_aggregate["metadata"] = user_details.get("trace_user_detail", {}).get("metadata")
200
33
  trace_aggregate["pipeline"] = user_details.get("trace_user_detail", {}).get("pipeline")
34
+ trace_aggregate["replays"] = {"source": None}
201
35
 
202
- trace_aggregate["data"] = {}
203
- trace_aggregate["data"]["prompt"] = prompt
204
- trace_aggregate["data"]["response"] = response
205
- trace_aggregate["data"]["context"] = context
206
- trace_aggregate["error"] = error
207
-
36
+ trace_aggregate["data"] = [{"spans": input_trace, "start_time": trace_aggregate["start_time"], "end_time": trace_aggregate["end_time"]}]
208
37
  if tracer_type == "langchain":
209
38
  additional_metadata = get_additional_metadata(input_trace, custom_model_cost, model_cost, prompt, response)
210
- else:
211
- additional_metadata = get_additional_metadata(input_trace, custom_model_cost, model_cost)
212
39
 
213
- trace_aggregate["metadata"] = user_details.get("trace_user_detail", {}).get("metadata")
214
40
  trace_aggregate["metadata"].update(additional_metadata)
215
41
  trace_aggregate["metadata"]["error"] = f"{error}"
216
42
  additional_metadata["error"] = error if error else None
@@ -366,4 +192,212 @@ def num_tokens_from_messages(model, message):
366
192
 
367
193
  except Exception as e:
368
194
  logger.error(f"Unexpected error in token counting: {str(e)}")
369
- return 0
195
+ return 0
196
+
197
+
198
+ def get_prompt(input_trace, tracer_type):
199
+ try:
200
+ if tracer_type == "langchain":
201
+ for span in input_trace:
202
+ try:
203
+ attributes = span.get("attributes", {})
204
+
205
+ if attributes:
206
+ for key, value in attributes.items():
207
+ try:
208
+ if key.startswith("llm.input_messages.") and key.endswith(".message.role") and value == "user":
209
+ message_num = key.split(".")[2]
210
+ content_key = f"llm.input_messages.{message_num}.message.content"
211
+ if content_key in attributes:
212
+ return attributes.get(content_key)
213
+ except Exception as e:
214
+ logger.warning(f"Error processing attribute key-value pair: {str(e)}")
215
+ continue
216
+
217
+ for key, value in attributes.items():
218
+ try:
219
+ if key.startswith("llm.prompts") and isinstance(value, list):
220
+ human_message = None
221
+ for message in value:
222
+ if isinstance(message, str):
223
+ human_index = message.find("Human:")
224
+ if human_index != -1:
225
+ human_message = message[human_index:].replace("Human:", "")
226
+ break
227
+ return human_message if human_message else value
228
+ except Exception as e:
229
+ logger.warning(f"Error processing attribute key-value pair for prompt: {str(e)}")
230
+ continue
231
+ except Exception as e:
232
+ logger.warning(f"Error processing span for prompt extraction: {str(e)}")
233
+ continue
234
+
235
+ for span in input_trace:
236
+ try:
237
+ if span["name"] == "LLMChain":
238
+ try:
239
+ input_value = span["attributes"].get("input.value", "{}")
240
+ return json.loads(input_value).get("question", "")
241
+ except json.JSONDecodeError:
242
+ logger.warning(f"Invalid JSON in LLMChain input.value: {input_value}")
243
+ continue
244
+ elif span["name"] == "RetrievalQA":
245
+ return span["attributes"].get("input.value", "")
246
+ elif span["name"] == "VectorStoreRetriever":
247
+ return span["attributes"].get("input.value", "")
248
+ except Exception as e:
249
+ logger.warning(f"Error processing span for fallback prompt extraction: {str(e)}")
250
+ continue
251
+
252
+ logger.warning("No user message found in any span")
253
+ logger.warning("Returning empty string for prompt.")
254
+ return ""
255
+ elif tracer_type == "llamaindex":
256
+ for span in input_trace:
257
+ if span["name"] == "BaseQueryEngine.query":
258
+ return span["attributes"]["input.value"]
259
+ elif "query_bundle" in span["attributes"].get("input.value", ""):
260
+ try:
261
+ query_data = json.loads(span["attributes"]["input.value"])
262
+ if "query_bundle" in query_data:
263
+ return query_data["query_bundle"]["query_str"]
264
+ except json.JSONDecodeError:
265
+ logger.error("Failed to parse query_bundle JSON")
266
+ logger.error("Prompt not found in the trace")
267
+ return None
268
+ except Exception as e:
269
+ logger.error(f"Error while extracting prompt from trace: {str(e)}")
270
+ return None
271
+
272
+ def get_response(input_trace, tracer_type):
273
+ try:
274
+ if tracer_type == "langchain":
275
+ for span in input_trace:
276
+ try:
277
+ attributes = span.get("attributes", {})
278
+ if attributes:
279
+ for key, value in attributes.items():
280
+ try:
281
+ if key.startswith("llm.output_messages.") and key.endswith(".message.content"):
282
+ return value
283
+ except Exception as e:
284
+ logger.warning(f"Error processing attribute key-value pair for response: {str(e)}")
285
+ continue
286
+
287
+ for key, value in attributes.items():
288
+ try:
289
+ if key.startswith("output.value"):
290
+ try:
291
+ output_json = json.loads(value)
292
+ if "generations" in output_json and isinstance(output_json.get("generations"), list) and len(output_json.get("generations")) > 0:
293
+ if isinstance(output_json.get("generations")[0], list) and len(output_json.get("generations")[0]) > 0:
294
+ first_generation = output_json.get("generations")[0][0]
295
+ if "text" in first_generation:
296
+ return first_generation["text"]
297
+ except json.JSONDecodeError:
298
+ logger.warning(f"Invalid JSON in output.value: {value}")
299
+ continue
300
+ except Exception as e:
301
+ logger.warning(f"Error processing attribute key-value pair for response: {str(e)}")
302
+ continue
303
+ except Exception as e:
304
+ logger.warning(f"Error processing span for response extraction: {str(e)}")
305
+ continue
306
+
307
+ for span in input_trace:
308
+ try:
309
+ if span["name"] == "LLMChain":
310
+ try:
311
+ output_value = span["attributes"].get("output.value", "")
312
+ if output_value:
313
+ return json.loads(output_value)
314
+ return ""
315
+ except json.JSONDecodeError:
316
+ logger.warning(f"Invalid JSON in LLMChain output.value: {output_value}")
317
+ continue
318
+ elif span["name"] == "RetrievalQA":
319
+ return span["attributes"].get("output.value", "")
320
+ elif span["name"] == "VectorStoreRetriever":
321
+ return span["attributes"].get("output.value", "")
322
+ except Exception as e:
323
+ logger.warning(f"Error processing span for fallback response extraction: {str(e)}")
324
+ continue
325
+
326
+ logger.warning("No response found in any span")
327
+ return ""
328
+ elif tracer_type == "llamaindex":
329
+ for span in input_trace:
330
+ if span["name"] == "BaseQueryEngine.query":
331
+ return span["attributes"]["output.value"]
332
+ logger.error("Response not found in the trace")
333
+ return None
334
+ except Exception as e:
335
+ logger.error(f"Error while extracting response from trace: {str(e)}")
336
+ return None
337
+
338
+ def get_context(input_trace, tracer_type, user_context):
339
+ try:
340
+ if user_context and user_context.strip():
341
+ return user_context
342
+ elif tracer_type == "langchain":
343
+ for span in input_trace:
344
+ try:
345
+ if span["name"] == "VectorStoreRetriever":
346
+ return span["attributes"].get("retrieval.documents.1.document.content", "")
347
+ except Exception as e:
348
+ logger.warning(f"Error processing span for context extraction: {str(e)}")
349
+ continue
350
+ elif tracer_type == "llamaindex":
351
+ for span in input_trace:
352
+ try:
353
+ if span["name"] == "BaseRetriever.retrieve":
354
+ return span["attributes"]["retrieval.documents.1.document.content"]
355
+ except Exception as e:
356
+ logger.warning(f"Error processing span for context extraction: {str(e)}")
357
+ continue
358
+ logger.warning("Context not found in the trace")
359
+ return ""
360
+ except Exception as e:
361
+ logger.error(f"Error while extracting context from trace: {str(e)}")
362
+ return ""
363
+
364
+ def get_span_errors(input_trace, tracer_type):
365
+ try:
366
+ if tracer_type == "langchain":
367
+ span_errors = {}
368
+ for span in input_trace:
369
+ try:
370
+ if "status" in span.keys() and span.get("status", {}).get("status_code", "").lower() == "error":
371
+ span_errors[f"{span['name']}"] = span["status"]
372
+ except:
373
+ logger.error(f"Error fetching status from span")
374
+ return span_errors
375
+ except:
376
+ logger.error(f"Error in get_span_errors")
377
+ return None
378
+
379
+ def add_span_hash_id(input_trace):
380
+ """
381
+ Add hash IDs to spans and track name occurrences.
382
+
383
+ Args:
384
+ input_trace (dict): The input trace containing spans
385
+
386
+ Returns:
387
+ dict: Modified trace with hash IDs and name occurrences added to spans
388
+ """
389
+ import uuid
390
+ from collections import defaultdict
391
+
392
+ name_counts = defaultdict(int)
393
+
394
+ for span in input_trace:
395
+ if "name" in span:
396
+ # Add hash ID
397
+ span["hash_id"] = str(uuid.uuid4())
398
+
399
+ # Track and update name occurrences
400
+ span["name_occurrences"] = name_counts[span["name"]]
401
+ name_counts[span["name"]] += 1
402
+
403
+ return input_trace