pangea-sdk 6.5.0-py3-none-any.whl → 6.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,9 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import Generic, Literal, Optional, overload
-
-from typing_extensions import TypeVar
+from typing import Any, Generic, Literal, Optional, overload
 
+from pangea._typing import T
 from pangea.config import PangeaConfig
 from pangea.response import APIRequestModel, APIResponseModel, PangeaResponse, PangeaResponseResult
 from pangea.services.base import ServiceBase
@@ -21,10 +20,15 @@ PiiEntityAction = Literal["disabled", "report", "block", "mask", "partial_maskin
 
 
 class Message(APIRequestModel):
-    role: str
+    role: Optional[str] = None
     content: str
 
 
+class McpToolsMessage(APIRequestModel):
+    role: Literal["tools"]
+    content: list[dict[str, Any]]
+
+
 class CodeDetectionOverride(APIRequestModel):
     disabled: Optional[bool] = None
     action: Optional[Literal["report", "block"]] = None
@@ -276,12 +280,9 @@ class CodeDetectionResult(APIResponseModel):
     """The action taken by this Detector"""
 
 
-_T = TypeVar("_T")
-
-
-class TextGuardDetector(APIResponseModel, Generic[_T]):
+class TextGuardDetector(APIResponseModel, Generic[T]):
     detected: Optional[bool] = None
-    data: Optional[_T] = None
+    data: Optional[T] = None
 
 
 class TextGuardDetectors(APIResponseModel):
@@ -301,6 +302,11 @@ class TextGuardDetectors(APIResponseModel):
     topic: Optional[TextGuardDetector[TopicDetectionResult]] = None
 
 
+class PromptMessage(APIResponseModel):
+    role: str
+    content: str
+
+
 class TextGuardResult(PangeaResponseResult):
     detectors: TextGuardDetectors
     """Result of the recipe analyzing and input prompt."""
@@ -317,7 +323,7 @@ class TextGuardResult(PangeaResponseResult):
     unredact.
     """
 
-    prompt_messages: Optional[object] = None
+    prompt_messages: Optional[list[PromptMessage]] = None
     """Updated structured prompt, if applicable."""
 
     prompt_text: Optional[str] = None
@@ -404,11 +410,12 @@ class AIGuard(ServiceBase):
     def guard_text(
         self,
         *,
-        messages: Sequence[Message],
+        messages: Sequence[Message | McpToolsMessage],
         recipe: str | None = None,
         debug: bool | None = None,
         overrides: Overrides | None = None,
         log_fields: LogFields | None = None,
+        only_relevant_content: bool = False,
     ) -> PangeaResponse[TextGuardResult]:
         """
         Guard LLM input and output text
@@ -431,6 +438,8 @@ class AIGuard(ServiceBase):
             recipe: Recipe key of a configuration of data types and settings
                 defined in the Pangea User Console. It specifies the rules that
                 are to be applied to the text, such as defang malicious URLs.
+            only_relevant_content: Whether or not to only send relevant content
+                to AI Guard.
 
         Examples:
             response = ai_guard.guard_text(messages=[Message(role="user", content="hello world")])
@@ -440,11 +449,12 @@ class AIGuard(ServiceBase):
         self,
         text: str | None = None,
         *,
-        messages: Sequence[Message] | None = None,
+        messages: Sequence[Message | McpToolsMessage] | None = None,
         debug: bool | None = None,
         log_fields: LogFields | None = None,
         overrides: Overrides | None = None,
         recipe: str | None = None,
+        only_relevant_content: bool = False,
     ) -> PangeaResponse[TextGuardResult]:
         """
         Guard LLM input and output text
@@ -470,6 +480,8 @@ class AIGuard(ServiceBase):
             recipe: Recipe key of a configuration of data types and settings
                 defined in the Pangea User Console. It specifies the rules that
                 are to be applied to the text, such as defang malicious URLs.
+            only_relevant_content: Whether or not to only send relevant content
+                to AI Guard.
 
         Examples:
             response = ai_guard.guard_text("text")
@@ -478,7 +490,11 @@ class AIGuard(ServiceBase):
         if text is not None and messages is not None:
             raise ValueError("Exactly one of `text` or `messages` must be given")
 
-        return self.request.post(
+        if only_relevant_content and messages is not None:
+            original_messages = messages
+            messages, original_indices = get_relevant_content(messages)
+
+        response = self.request.post(
             "v1/text/guard",
             TextGuardResult,
             data={
@@ -490,3 +506,65 @@
                 "log_fields": log_fields,
             },
         )
+
+        if only_relevant_content and response.result and response.result.prompt_messages:
+            response.result.prompt_messages = patch_messages(
+                original_messages, original_indices, response.result.prompt_messages
+            )  # type: ignore[assignment]
+
+        return response
+
+
+def get_relevant_content(
+    messages: Sequence[Message | McpToolsMessage],
+) -> tuple[list[Message | McpToolsMessage], list[int]]:
+    """
+    Returns relevant messages and their indices in the original list.
+
+    1, If last message is "assistant", then the relevant messages are all system
+       messages that come before it, plus that last assistant message.
+    2. Else, find the last "assistant" message. Then the relevant messages are
+       all system messages that come before it, and all messages that come after
+       it.
+    """
+
+    if len(messages) == 0:
+        return [], []
+
+    system_messages = [msg for msg in messages if msg.role == "system"]
+    system_indices = [i for i, msg in enumerate(messages) if msg.role == "system"]
+
+    # If the last message is assistant, then return all system messages and that
+    # assistant message.
+    if messages[-1].role == "assistant":
+        return system_messages + [messages[-1]], system_indices + [len(messages) - 1]
+
+    # Otherwise, work backwards until we find the last assistant message, then
+    # return all messages after that.
+    last_assistant_index = -1
+    for i in range(len(messages) - 1, -1, -1):
+        if messages[i].role == "assistant":
+            last_assistant_index = i
+            break
+
+    relevant_messages = []
+    indices = []
+    for i, msg in enumerate(messages):
+        if msg.role == "system" or i > last_assistant_index:
+            relevant_messages.append(msg)
+            indices.append(i)
+
+    return relevant_messages, indices
+
+
+def patch_messages(
+    original: Sequence[Message | McpToolsMessage],
+    original_indices: list[int],
+    transformed: Sequence[PromptMessage],
+) -> list[Message | McpToolsMessage | PromptMessage]:
+    if len(original) == len(transformed):
+        return list(transformed)
+
+    return [
+        transformed[original_indices.index(i)] if i in original_indices else orig for i, orig in enumerate(original)
+    ]
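
Taken together, the 6.6.0 changes introduce an MCP tool-definition message type (McpToolsMessage), a typed prompt_messages result, and an opt-in only_relevant_content mode that trims the conversation before it is sent to AI Guard and splices the guarded turns back afterwards. Below is a minimal usage sketch rather than code from the package: it assumes these models sit alongside AIGuard in pangea.services.ai_guard, the usual token-plus-PangeaConfig service constructor, and placeholder environment variables and tool dictionary.

import os

from pangea.config import PangeaConfig
from pangea.services.ai_guard import AIGuard, McpToolsMessage, Message

# Placeholder credentials; the domain and token values are illustrative, not from this diff.
config = PangeaConfig(domain=os.getenv("PANGEA_DOMAIN", "aws.us.pangea.cloud"))
ai_guard = AIGuard(token=os.getenv("PANGEA_AI_GUARD_TOKEN", ""), config=config)

response = ai_guard.guard_text(
    messages=[
        Message(role="system", content="You are a support assistant."),
        Message(role="assistant", content="How can I help?"),
        Message(role="user", content="Summarize the last ticket for me."),
        # New in 6.6.0: a "tools" turn; the dict shape here is illustrative only.
        McpToolsMessage(role="tools", content=[{"name": "lookup_ticket"}]),
    ],
    # New in 6.6.0: send only the relevant turns to the service, then patch
    # the guarded turns back into the full conversation.
    only_relevant_content=True,
)

if response.result and response.result.prompt_messages:
    for message in response.result.prompt_messages:
        print(message.role, message.content)

With this input the last turn is not an assistant message, so only the system, user, and tools turns (indices 0, 2, and 3) are posted to the service; the assistant turn at index 1 is carried over unchanged when patch_messages rebuilds prompt_messages.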
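
The relevance rule in get_relevant_content is easier to see on a concrete conversation, so here is a standalone sketch that restates it on plain role strings with no SDK imports; the conversation and the expected indices are invented for illustration.

# Roles of a hypothetical conversation, by position.
conversation = ["system", "user", "assistant", "user", "user"]

# The last turn is not "assistant", so keep every "system" turn plus everything
# after the final "assistant" turn, which is the same rule the helper applies.
last_assistant = max(
    (i for i, role in enumerate(conversation) if role == "assistant"), default=-1
)
relevant_indices = [
    i for i, role in enumerate(conversation) if role == "system" or i > last_assistant
]
assert relevant_indices == [0, 3, 4]

Had the conversation ended on an assistant turn, the helper's first branch would keep only the system turns plus that final assistant turn. Either way, patch_messages inverts the selection: the guarded prompt_messages returned by the service are written back at indices 0, 3, and 4, while turns 1 and 2 are copied from the original list, so callers always get back a conversation with the same shape as their input.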