lexsi-sdk 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. lexsi_sdk/__init__.py +5 -0
  2. lexsi_sdk/client/__init__.py +0 -0
  3. lexsi_sdk/client/client.py +176 -0
  4. lexsi_sdk/common/__init__.py +0 -0
  5. lexsi_sdk/common/config/.env.prod +3 -0
  6. lexsi_sdk/common/constants.py +143 -0
  7. lexsi_sdk/common/enums.py +8 -0
  8. lexsi_sdk/common/environment.py +49 -0
  9. lexsi_sdk/common/monitoring.py +81 -0
  10. lexsi_sdk/common/trigger.py +75 -0
  11. lexsi_sdk/common/types.py +122 -0
  12. lexsi_sdk/common/utils.py +93 -0
  13. lexsi_sdk/common/validation.py +110 -0
  14. lexsi_sdk/common/xai_uris.py +197 -0
  15. lexsi_sdk/core/__init__.py +0 -0
  16. lexsi_sdk/core/agent.py +62 -0
  17. lexsi_sdk/core/alert.py +56 -0
  18. lexsi_sdk/core/case.py +618 -0
  19. lexsi_sdk/core/dashboard.py +131 -0
  20. lexsi_sdk/core/guardrails/__init__.py +0 -0
  21. lexsi_sdk/core/guardrails/guard_template.py +299 -0
  22. lexsi_sdk/core/guardrails/guardrail_autogen.py +554 -0
  23. lexsi_sdk/core/guardrails/guardrails_langgraph.py +525 -0
  24. lexsi_sdk/core/guardrails/guardrails_openai.py +541 -0
  25. lexsi_sdk/core/guardrails/openai_runner.py +1328 -0
  26. lexsi_sdk/core/model_summary.py +110 -0
  27. lexsi_sdk/core/organization.py +549 -0
  28. lexsi_sdk/core/project.py +5131 -0
  29. lexsi_sdk/core/synthetic.py +387 -0
  30. lexsi_sdk/core/text.py +595 -0
  31. lexsi_sdk/core/tracer.py +208 -0
  32. lexsi_sdk/core/utils.py +36 -0
  33. lexsi_sdk/core/workspace.py +325 -0
  34. lexsi_sdk/core/wrapper.py +766 -0
  35. lexsi_sdk/core/xai.py +306 -0
  36. lexsi_sdk/version.py +34 -0
  37. lexsi_sdk-0.1.16.dist-info/METADATA +100 -0
  38. lexsi_sdk-0.1.16.dist-info/RECORD +40 -0
  39. lexsi_sdk-0.1.16.dist-info/WHEEL +5 -0
  40. lexsi_sdk-0.1.16.dist-info/top_level.txt +1 -0
@@ -0,0 +1,766 @@
+ import json
+ import time
+ import functools
+ from typing import Callable, Optional
+ import inspect
+ import uuid
+
+ import botocore.client
+ import httpx
+ import replicate
+ from anthropic import Anthropic
+ from google import genai
+ from groq import Groq
+ from huggingface_hub import InferenceClient
+ from mistralai import Mistral
+ from openai import OpenAI
+ from pydantic import BaseModel
+ from together import Together
+ from xai_sdk import Client
+
+ from lexsi_sdk.client.client import APIClient
+ from lexsi_sdk.common.environment import Environment
+ from lexsi_sdk.common.xai_uris import CASE_INFO_TEXT_URI, GENERATE_TEXT_CASE_STREAM_URI, GENERATE_TEXT_CASE_URI
+
+
+ class Wrapper:
+     """Wraps SDK clients to add Lexsi tracing, logging, and guardrails."""
+
+     def __init__(self, project_name, api_client):
+         """Store project context for downstream wrapper calls.
+
+         :param project_name: Name of the Lexsi project.
+         :param api_client: Initialized API client used for telemetry calls.
+         """
+         self.project_name = project_name
+         self.api_client = api_client
+
+     def add_message(self, session_id, trace_id, input_data, output_data, metadata, duration):
+         """Persist a message with timing metadata to Lexsi tracing.
+
+         :param session_id: Session identifier returned by tracing APIs.
+         :param trace_id: Trace identifier for the current LLM call chain.
+         :param input_data: Raw input payload sent to the model.
+         :param output_data: Model output payload.
+         :param metadata: Any additional metadata to persist.
+         :param duration: End-to-end latency for the operation.
+         :return: API response from the tracing endpoint.
+         """
+         payload = {
+             "project_name": self.project_name,
+             "session_id": session_id,
+             "trace_id": trace_id,
+             "input_data": input_data,
+             "output_data": output_data,
+             "metadata": metadata,
+             "duration": duration,
+         }
+         return self.api_client.post("sessions/add_session_message", payload=payload)
+
+     async def async_add_trace_details(self, session_id, trace_id, component, input_data, metadata, output_data=None, function_to_run=None):
+         """Create an async trace entry, optionally executing and recording a coroutine.
+
+         :param session_id: Session identifier returned by tracing APIs.
+         :param trace_id: Trace identifier for the current call.
+         :param component: Logical component name (Input, LLM, Guardrails, etc.).
+         :param input_data: Payload being traced.
+         :param metadata: Extra metadata to attach to the trace.
+         :param output_data: Optional output to store; computed from the result if omitted.
+         :param function_to_run: Awaitable to execute and trace around.
+         :return: API response or the result of the wrapped coroutine.
+         """
+         start_time = time.perf_counter()
+         result = None
+         if function_to_run:
+             result = await function_to_run()
+         duration = time.perf_counter() - start_time
+         if not output_data and result is not None:
+             output_data = result
+             if isinstance(result, BaseModel):
+                 output_data = result.model_dump()
+         payload = {
+             "project_name": self.project_name,
+             "trace_id": trace_id,
+             "session_id": session_id,
+             "component": component,
+             "input_data": input_data,
+             "output_data": output_data,
+             "metadata": metadata,
+             "duration": duration,
+         }
+         res = self.api_client.post("traces/add_trace", payload=payload)
+         if function_to_run:
+             # Guardrail runs report success/details; surface the failure details
+             # instead of the raw result when a guardrail rejects the payload.
+             if component in ["Input Guardrails", "Output Guardrails"]:
+                 if not result.get("success", True):
+                     return result.get("details")
+             return result
+         return res
+
+     def add_trace_details(self, session_id, trace_id, component, input_data, metadata, is_grok=False, output_data=None, function_to_run=None):
+         """Create a trace entry for synchronous flows, executing an optional callable.
+
+         :param session_id: Session identifier returned by tracing APIs.
+         :param trace_id: Trace identifier for the current call.
+         :param component: Logical component name (Input, LLM, Guardrails, etc.).
+         :param input_data: Payload being traced.
+         :param metadata: Extra metadata to attach to the trace.
+         :param is_grok: Whether the result is from Grok and needs special handling.
+         :param output_data: Optional output to store; computed from the result if omitted.
+         :param function_to_run: Callable to execute and trace around.
+         :return: API response or the result of the wrapped callable.
+         """
+         start_time = time.perf_counter()
+         result = None
+         if function_to_run:
+             result = function_to_run()
+             if is_grok:
+                 # Grok responses are not JSON-serializable; flatten them into a plain dict.
+                 result = {
+                     "id": result.id,
+                     "content": result.content,
+                     "reasoning_content": result.reasoning_content,
+                     "system_fingerprint": result.system_fingerprint,
+                     "usage": {
+                         "completion_tokens": result.usage.completion_tokens,
+                         "prompt_tokens": result.usage.prompt_tokens,
+                         "total_tokens": result.usage.total_tokens,
+                         "prompt_text_tokens": result.usage.prompt_text_tokens,
+                         "reasoning_tokens": result.usage.reasoning_tokens,
+                         "cached_prompt_text_tokens": result.usage.cached_prompt_text_tokens,
+                     },
+                 }
+         duration = time.perf_counter() - start_time
+         if not output_data and result is not None:
+             output_data = result
+             if isinstance(result, BaseModel):
+                 output_data = result.model_dump()
+
+         payload = {
+             "project_name": self.project_name,
+             "trace_id": trace_id,
+             "session_id": session_id,
+             "component": component,
+             "input_data": input_data,
+             "output_data": output_data,
+             "metadata": metadata,
+             "duration": duration,
+         }
+         res = self.api_client.post("traces/add_trace", payload=payload)
+         if function_to_run:
+             if component in ["Input Guardrails", "Output Guardrails"]:
+                 if not result.get("success", True):
+                     return result.get("details")
+             return result
+         return res
+
+     def run_guardrails(self, input_data, trace_id, session_id, model_name, apply_on):
+         """Invoke server-side guardrails for the provided input/output payload.
+
+         :param input_data: Payload to validate with guardrails.
+         :param trace_id: Trace identifier for this run.
+         :param session_id: Session identifier for this run.
+         :param model_name: Model identifier to supply to guardrails.
+         :param apply_on: Whether guardrails apply to "input" or "output".
+         :return: API response from guardrails execution.
+         """
+         payload = {
+             "trace_id": trace_id,
+             "session_id": session_id,
+             "input_data": input_data,
+             "model_name": model_name,
+             "project_name": self.project_name,
+             "apply_on": apply_on,
+         }
+         return self.api_client.post("v2/ai-models/run_guardrails", payload=payload)
+
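Inferred from the guardrail branches above (not an authoritative schema): run_guardrails is expected to return a JSON object carrying a success flag and details, and add_trace_details short-circuits to the details when success is false. A minimal sketch of that contract:

    # Hypothetical guardrail response; field names inferred from the calls above.
    guardrail_res = {"success": False, "details": "blocked by input guardrail"}
    if not guardrail_res.get("success", True):
        # The wrapper returns these details instead of the LLM output.
        result = guardrail_res.get("details")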
+     def _get_wrapper(self, original_method: Callable, method_name: str, provider: str, session_id: Optional[str] = None, chat=None, **extra_kwargs) -> Callable:
+         """Return a callable that wraps LLM SDK methods with Lexsi telemetry.
+
+         :param original_method: SDK method being wrapped.
+         :param method_name: Identifier describing the wrapped method.
+         :param provider: Name of the model provider (OpenAI, Anthropic, etc.).
+         :param session_id: Optional session id to reuse.
+         :param chat: Optional chat object for Grok/XAI.
+         :param extra_kwargs: Extra kwargs to inject into wrapped calls.
+         :return: Wrapped callable preserving the original signature.
+         """
+         if inspect.iscoroutinefunction(original_method):
+             @functools.wraps(original_method)
+             async def async_wrapper(*args, **kwargs):
+                 """Async wrapper around an SDK method to add tracing and guardrails."""
+                 total_start_time = time.perf_counter()
+                 trace_id = str(uuid.uuid4())
+                 # The provider name, not the per-call model kwarg, identifies the model to guardrails.
+                 model_name = provider
+                 input_data = kwargs.get("messages")
+
+                 trace_res = self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=session_id,
+                     component="Input",
+                     input_data=input_data,
+                     output_data=input_data,
+                     metadata={},
+                 )
+                 id_session = trace_res.get("details", {}).get("session_id")
+
+                 self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     component="Input Guardrails",
+                     input_data=input_data,
+                     metadata={},
+                     function_to_run=lambda: self.run_guardrails(
+                         session_id=id_session,
+                         trace_id=trace_id,
+                         input_data=input_data,
+                         model_name=model_name,
+                         apply_on="input",
+                     ),
+                 )
+
+                 result = await self.async_add_trace_details(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     component="LLM",
+                     input_data=input_data,
+                     metadata=kwargs,
+                     function_to_run=lambda: original_method(*args, **kwargs),
+                 )
+
+                 output_data = result.choices[0].message.content
+
+                 self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     component="Output Guardrails",
+                     input_data=output_data,
+                     metadata={},
+                     function_to_run=lambda: self.run_guardrails(
+                         session_id=id_session,
+                         trace_id=trace_id,
+                         model_name=model_name,
+                         input_data=output_data,
+                         apply_on="output",
+                     ),
+                 )
+
+                 self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     component="Output",
+                     input_data=input_data,
+                     output_data=output_data,
+                     metadata={},
+                 )
+
+                 self.add_message(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     input_data=input_data,
+                     output_data=output_data,
+                     duration=time.perf_counter() - total_start_time,
+                     metadata={},
+                 )
+
+                 return result
+             return async_wrapper
+         else:
+             @functools.wraps(original_method)
+             def wrapper(*args, **kwargs):
+                 """Sync wrapper around an SDK method to add tracing and guardrails."""
+                 total_start_time = time.perf_counter()
+                 trace_id = str(uuid.uuid4())
+                 model_name = None
+                 input_data = None
+
+                 if extra_kwargs:
+                     kwargs.update(extra_kwargs)
+                 # Pick the input payload based on the wrapped method; the provider name
+                 # (not the per-call model kwarg) identifies the model to guardrails.
+                 if method_name == "client.chat.completions.create":  # OpenAI-style chat completions
+                     input_data = kwargs.get("messages")
+                     model_name = provider
+                 elif method_name == "client.responses.create":  # OpenAI Responses API
+                     input_data = kwargs.get("input")
+                     model_name = provider
+                 elif method_name == "client.messages.create":  # Anthropic Messages API
+                     input_data = kwargs.get("messages")
+                     model_name = provider
+                 elif method_name == "client.models.generate_content":  # Gemini
+                     input_data = kwargs.get("contents")
+                     model_name = provider
+                 elif method_name == "client.chat.complete":  # Mistral
+                     input_data = kwargs.get("messages")
+                     model_name = provider
+                 elif method_name == "client.chat_completion":  # Hugging Face InferenceClient
+                     input_data = kwargs.get("messages")
+                     model_name = provider
+                 elif method_name == "client.generate_text_case":  # LexsiModels
+                     input_data = kwargs.get("prompt")
+                     model_name = kwargs.get("model_name")
+                 elif method_name == "client.run":  # Replicate
+                     input_data = kwargs.get("input")
+                     model_name = provider
+                 elif method_name == "client.converse":  # Bedrock
+                     input_data = kwargs.get("messages")
+                     model_name = provider
+                 elif method_name == "chat.sample":  # xAI Grok
+                     input_data = chat.messages[1].content[0].text
+                     model_name = provider
+                 else:
+                     input_data = kwargs
+                     model_name = None
+
+                 trace_res = self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=session_id,
+                     component="Input",
+                     input_data=input_data,
+                     output_data=input_data,
+                     metadata={},
+                 )
+                 id_session = trace_res.get("details", {}).get("session_id")
+
+                 self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     component="Input Guardrails",
+                     input_data=input_data,
+                     metadata={},
+                     function_to_run=lambda: self.run_guardrails(
+                         session_id=id_session,
+                         trace_id=trace_id,
+                         input_data=input_data,
+                         model_name=model_name,
+                         apply_on="input",
+                     ),
+                 )
+
+                 if method_name == "client.generate_text_case":
+                     kwargs["session_id"] = id_session
+                     kwargs["trace_id"] = trace_id
+
+                 if method_name == "chat.sample":
+                     # For XAI, trace only the JSON-serializable kwargs stored at chat creation.
+                     sanitized_kwargs = {}
+                     chat_obj = chat if chat is not None else (args[0] if args else None)
+                     if chat_obj is not None and hasattr(chat_obj, "_original_kwargs"):
+                         for key, value in chat_obj._original_kwargs.items():
+                             try:
+                                 json.dumps(value)
+                                 sanitized_kwargs[key] = value
+                             except (TypeError, ValueError):
+                                 sanitized_kwargs[key] = str(value)
+
+                     result = self.add_trace_details(
+                         trace_id=trace_id,
+                         session_id=id_session,
+                         component="LLM",
+                         is_grok=True,
+                         input_data=input_data,
+                         metadata=sanitized_kwargs,
+                         function_to_run=lambda: original_method(*args),
+                     )
+                 else:
+                     result = self.add_trace_details(
+                         trace_id=trace_id,
+                         session_id=id_session,
+                         component="LLM",
+                         input_data=input_data,
+                         metadata=kwargs,
+                         function_to_run=lambda: original_method(*args, **kwargs),
+                     )
+
+                 # Extract the output text based on the wrapped method.
+                 if method_name == "chat.sample":  # xAI Grok (already flattened to a dict)
+                     output_data = result
+                 elif method_name == "client.converse":  # Bedrock
+                     output_data = result["output"]["message"]["content"][-1]["text"]
+                 elif method_name == "client.responses.create":
+                     output_data = result.output_text
+                 elif method_name == "client.messages.create":  # Anthropic Messages API
+                     output_data = result.content[0].text
+                 elif method_name == "client.models.generate_content":  # Gemini
+                     output_data = result.text
+                 elif method_name in ("client.chat.complete", "client.chat.complete_async"):  # Mistral
+                     output_data = result.choices[0].message.content
+                 elif method_name == "client.generate_text_case":  # LexsiModels
+                     output_data = result.get("details", {}).get("result", {}).get("output")
+                 elif method_name == "client.run":  # Replicate
+                     output_data = result
+                 elif method_name in ("client.chat.completions.create", "client.chat_completion"):  # OpenAI-style
+                     output_data = result.choices[0].message.content
+                 else:
+                     output_data = result
+
+                 self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     component="Output Guardrails",
+                     input_data=output_data,
+                     metadata={},
+                     function_to_run=lambda: self.run_guardrails(
+                         session_id=id_session,
+                         trace_id=trace_id,
+                         model_name=model_name,
+                         input_data=output_data,
+                         apply_on="output",
+                     ),
+                 )
+
+                 self.add_trace_details(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     component="Output",
+                     input_data=input_data,
+                     output_data=output_data,
+                     metadata={},
+                 )
+
+                 metadata = {}
+                 if method_name == "client.generate_text_case":
+                     # Token counts live in the audit trail; fall back to decoded lengths.
+                     tokens = result.get("details", {}).get("result", {}).get("audit_trail", {}).get("tokens", {})
+                     input_tokens = tokens.get("input_tokens") or tokens.get("input_decoded_length") or 0
+                     output_tokens = tokens.get("output_tokens") or tokens.get("output_decoded_length") or 0
+                     metadata = {
+                         "case_id": result.get("details", {}).get("case_id"),
+                         "input_tokens": input_tokens,
+                         "output_tokens": output_tokens,
+                         "total_tokens": input_tokens + output_tokens,
+                     }
+                 self.add_message(
+                     trace_id=trace_id,
+                     session_id=id_session,
+                     input_data=input_data,
+                     output_data=output_data,
+                     duration=time.perf_counter() - total_start_time,
+                     metadata=metadata,
+                 )
+
+                 return result
+             return wrapper
+
+
+ class LexsiModels:
+     """Convenience wrapper for Lexsi hosted text models."""
+
+     def __init__(self, project, api_client: APIClient):
+         """Bind project and API client for model operations.
+
+         :param project: Project instance owning the model.
+         :param api_client: API client with auth configured.
+         """
+         self.project = project
+         self.api_client = api_client
+
+     def generate_text_case(
+         self,
+         model_name: str,
+         prompt: str,
+         instance_type: str = "xsmall",
+         serverless_instance_type: str = "gova-2",
+         explainability_method: list = ["DLB"],
+         explain_model: bool = False,
+         trace_id: str = None,
+         session_id: str = None,
+         min_tokens: int = 100,
+         max_tokens: int = 500,
+         stream: bool = False,
+     ):
+         """Generate an explainable text case using a hosted Lexsi model.
+
+         :param model_name: Name of the deployed text model.
+         :param prompt: Input prompt for generation.
+         :param instance_type: Dedicated instance type (if applicable).
+         :param serverless_instance_type: Serverless instance flavor.
+         :param explainability_method: Methods to compute explanations with.
+         :param explain_model: Whether to explain the model behavior.
+         :param trace_id: Optional existing trace id.
+         :param session_id: Optional existing session id.
+         :param min_tokens: Minimum tokens to generate.
+         :param max_tokens: Maximum tokens to generate.
+         :param stream: Whether to stream responses.
+         :return: API response with generation details.
+         """
+         payload = {
+             "session_id": session_id,
+             "trace_id": trace_id,
+             "project_name": self.project.project_name,
+             "model_name": model_name,
+             "input_text": prompt,
+             "instance_type": instance_type,
+             "serverless_instance_type": serverless_instance_type,
+             "explainability_method": explainability_method,
+             "explain_model": explain_model,
+             "max_tokens": max_tokens,
+             "min_tokens": min_tokens,
+             "stream": stream,
+         }
+
543
+
544
+ # if stream:
545
+ # env = Environment()
546
+ # url = env.get_base_url() + "/" + GENERATE_TEXT_CASE_STREAM_URI
547
+ # with requests.post(
548
+ # url,
549
+ # headers={**self.api_client.headers, "Accept": "text/event-stream"},
550
+ # json=payload,
551
+ # stream=True,
552
+ # ) as response:
553
+ # response.raise_for_status()
554
+
555
+ # buffer = ""
556
+ # for line in response.iter_lines(decode_unicode=True):
557
+ # if not line or line.strip() == "[DONE]":
558
+ # continue
559
+
560
+ # if line.startswith("data:"):
561
+ # line = line[len("data:"):].strip()
562
+ # try:
563
+ # event = json.loads(line)
564
+ # text_piece = event.get("text", "")
565
+ # except Exception as e:
566
+ # text_piece = line
567
+ # buffer += text_piece
568
+ # print(text_piece, end="", flush=True)
569
+ # response = {"details": {"result": {"output": buffer}}}
570
+ # payload = {
571
+ # "session_id": session_id,
572
+ # "trace_id": trace_id,
573
+ # "project_name": self.project.project_name
574
+ # }
575
+ # res = self.api_client.post(CASE_INFO_TEXT_URI, payload)
576
+ # return res
577
+ # else:
578
+ # #return "Text case generation is not supported for this modality type"
579
+ # res = self.api_client.post(GENERATE_TEXT_CASE_URI, payload)
580
+ # if not res.get("success"):
581
+ # raise Exception(res.get("details"))
582
+ # return res
583
+
+         if stream:
+             env = Environment()
+             url = env.get_base_url() + "/" + GENERATE_TEXT_CASE_STREAM_URI
+
+             headers = {
+                 **self.api_client.headers,
+                 "Accept": "text/event-stream",
+             }
+
+             # Use a client so we can enable HTTP/2 and connection reuse if needed
+             with httpx.Client(http2=True, timeout=None) as client:
+                 with client.stream(
+                     "POST",
+                     url,
+                     headers=headers,
+                     json=payload,
+                 ) as response:
+                     response.raise_for_status()
+
+                     buffer = ""
+                     for line in response.iter_lines():
+                         if not line:
+                             continue
+
+                         # httpx may return str or bytes depending on encoding
+                         if isinstance(line, bytes):
+                             line = line.decode("utf-8", errors="ignore")
+
+                         line = line.strip()
+                         if not line or line == "[DONE]":
+                             continue
+
+                         if line.startswith("data:"):
+                             line = line[len("data:"):].strip()
+
+                         try:
+                             event = json.loads(line)
+                             text_piece = event.get("text", "")
+                         except Exception:
+                             # Fallback: treat the raw line as text content
+                             text_piece = line
+
+                         buffer += text_piece
+                         print(text_piece, end="", flush=True)
+
+             # After the stream finishes, fetch the case info for this session/trace.
+             payload_case = {
+                 "session_id": session_id,
+                 "trace_id": trace_id,
+                 "project_name": self.project.project_name,
+             }
+             return self.api_client.post(CASE_INFO_TEXT_URI, payload_case)
+
+         else:
+             res = self.api_client.post(GENERATE_TEXT_CASE_URI, payload)
+             if not res.get("success"):
+                 raise Exception(res.get("details"))
+             return res
+
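A minimal usage sketch for LexsiModels.generate_text_case, assuming a project object obtained from the SDK's project APIs; the model name and prompt are placeholders:

    models = LexsiModels(project=project, api_client=project.api_client)
    res = models.generate_text_case(
        model_name="my-text-model",  # placeholder deployment name
        prompt="Summarize the policy document.",
        explain_model=True,
    )
    # Non-streaming responses carry the generated text under details.result.output.
    print(res.get("details", {}).get("result", {}).get("output"))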
+ def monitor(project, client, session_id=None):
+     """Attach tracing wrappers to supported SDK clients.
+
+     :param project: Project instance providing API client and name.
+     :param client: SDK client instance to instrument.
+     :param session_id: Optional session id to reuse.
+     :return: The same client instance with wrapped methods.
+     """
+     wrapper = Wrapper(project_name=project.project_name, api_client=project.api_client)
+     if isinstance(client, OpenAI):
+         models = project.models()["model_name"].to_list()
+         if "OpenAI" not in models:
+             raise Exception("OpenAI Model Not Initialized")
+         client.chat.completions.create = wrapper._get_wrapper(
+             original_method=client.chat.completions.create,
+             method_name="client.chat.completions.create",
+             session_id=session_id,
+             provider="OpenAI",
+         )
+         client.responses.create = wrapper._get_wrapper(
+             original_method=client.responses.create,
+             method_name="client.responses.create",
+             session_id=session_id,
+             provider="OpenAI",
+         )
+     elif isinstance(client, Anthropic):
+         client.messages.create = wrapper._get_wrapper(
+             original_method=client.messages.create,
+             method_name="client.messages.create",
+             session_id=session_id,
+             provider="Anthropic",
+         )
+     elif isinstance(client, genai.Client):
+         client.models.generate_content = wrapper._get_wrapper(
+             original_method=client.models.generate_content,
+             method_name="client.models.generate_content",
+             session_id=session_id,
+             provider="Gemini",
+         )
+     elif isinstance(client, Groq):
+         client.chat.completions.create = wrapper._get_wrapper(
+             original_method=client.chat.completions.create,
+             method_name="client.chat.completions.create",
+             session_id=session_id,
+             provider="Groq",
+         )
+     elif isinstance(client, Together):
+         client.chat.completions.create = wrapper._get_wrapper(
+             original_method=client.chat.completions.create,
+             method_name="client.chat.completions.create",
+             session_id=session_id,
+             provider="Together",
+         )
+     elif isinstance(client, InferenceClient):
+         client.chat_completion = wrapper._get_wrapper(
+             original_method=client.chat_completion,
+             method_name="client.chat_completion",
+             session_id=session_id,
+             provider="HuggingFace",
+         )
+     elif isinstance(client, replicate.Client) or client is replicate:
+         client.run = wrapper._get_wrapper(
+             original_method=client.run,
+             method_name="client.run",
+             session_id=session_id,
+             provider="Replicate",
+         )
+     elif isinstance(client, Mistral):
+         client.chat.complete = wrapper._get_wrapper(
+             original_method=client.chat.complete,
+             method_name="client.chat.complete",
+             session_id=session_id,
+             provider="Mistral",
+         )
+         client.chat.complete_async = wrapper._get_wrapper(
+             original_method=client.chat.complete_async,
+             method_name="client.chat.complete_async",
+             session_id=session_id,
+             provider="Mistral",
+         )
+     elif isinstance(client, botocore.client.BaseClient):
+         client.converse = wrapper._get_wrapper(
+             original_method=client.converse,
+             method_name="client.converse",
+             session_id=session_id,
+             provider="AWS Bedrock",
+         )
+     elif isinstance(client, Client):  # XAI Client
+         # Wrap chat.create so every chat it returns has an instrumented sample method.
+         original_chat_create = client.chat.create
+
+         def wrapped_chat_create(*args, **kwargs):
+             """Wrap chat creation to instrument the returned chat object.
+
+             :param args: Positional args forwarded to chat.create.
+             :param kwargs: Keyword args forwarded to chat.create.
+             :return: Wrapped chat object with an instrumented sample method.
+             """
+             chat = original_chat_create(*args, **kwargs)
+             # Keep the creation kwargs so the wrapper can log them as metadata.
+             chat._original_kwargs = kwargs
+             chat.sample = wrapper._get_wrapper(
+                 chat=chat,
+                 original_method=chat.sample,
+                 method_name="chat.sample",
+                 session_id=session_id,
+                 provider="Grok",
+             )
+             return chat
+
+         client.chat.create = wrapped_chat_create
+
+     elif isinstance(client, LexsiModels):
+         client.generate_text_case = wrapper._get_wrapper(
+             original_method=client.generate_text_case,
+             method_name="client.generate_text_case",
+             session_id=session_id,
+             provider="Lexsi",
+         )
+     else:
+         raise Exception("Not a valid SDK to monitor")
+     return client
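A minimal sketch of instrumenting an OpenAI client with monitor, assuming a project object obtained from the SDK and an OpenAI model already initialized in that project (names are placeholders):

    from openai import OpenAI

    client = monitor(project=project, client=OpenAI())
    # Each call now records Input, Input Guardrails, LLM, Output Guardrails,
    # and Output traces before the response is returned.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)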