turingpulse-sdk-cohere 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# Virtual environments
|
|
7
|
+
.venv/
|
|
8
|
+
venv/
|
|
9
|
+
ENV/
|
|
10
|
+
|
|
11
|
+
# Distribution / packaging
|
|
12
|
+
dist/
|
|
13
|
+
build/
|
|
14
|
+
*.egg-info/
|
|
15
|
+
|
|
16
|
+
# Database files
|
|
17
|
+
*.db
|
|
18
|
+
*.sqlite3
|
|
19
|
+
|
|
20
|
+
# Environment variables
|
|
21
|
+
.env
|
|
22
|
+
.env.local
|
|
23
|
+
|
|
24
|
+
# IDE
|
|
25
|
+
.idea/
|
|
26
|
+
.vscode/
|
|
27
|
+
*.swp
|
|
28
|
+
*.swo
|
|
29
|
+
|
|
30
|
+
# Testing
|
|
31
|
+
.pytest_cache/
|
|
32
|
+
.coverage
|
|
33
|
+
htmlcov/
|
|
34
|
+
.tox/
|
|
35
|
+
|
|
36
|
+
# Logs
|
|
37
|
+
*.log
|
|
38
|
+
logs/
|
|
39
|
+
|
|
40
|
+
# OS files
|
|
41
|
+
.DS_Store
|
|
42
|
+
Thumbs.db
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: turingpulse-sdk-cohere
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: TuringPulse SDK integration for Cohere
|
|
5
|
+
License-Expression: Apache-2.0
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
|
+
Requires-Dist: cohere>=5.20.0
|
|
8
|
+
Requires-Dist: turingpulse-sdk>=1.0.0
|
|
9
|
+
Provides-Extra: dev
|
|
10
|
+
Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
|
|
11
|
+
Requires-Dist: pytest>=8.0; extra == 'dev'
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "turingpulse-sdk-cohere"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = "TuringPulse SDK integration for Cohere"
|
|
9
|
+
requires-python = ">=3.11"
|
|
10
|
+
license = "Apache-2.0"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"turingpulse-sdk>=1.0.0",
|
|
13
|
+
"cohere>=5.20.0",
|
|
14
|
+
]
|
|
15
|
+
|
|
16
|
+
[project.optional-dependencies]
|
|
17
|
+
dev = ["pytest>=8.0", "pytest-asyncio>=0.23"]
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Cohere monkey-patch instrumentation for TuringPulse."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
from contextvars import ContextVar
|
|
8
|
+
from typing import Any, Optional
|
|
9
|
+
|
|
10
|
+
from turingpulse_sdk import instrument, GovernanceDirective
|
|
11
|
+
from turingpulse_sdk.config import MAX_FIELD_SIZE
|
|
12
|
+
from turingpulse_sdk.context import current_context
|
|
13
|
+
from turingpulse_sdk.exceptions import ConfigurationError
|
|
14
|
+
|
|
15
|
+
# Module logger for this integration.
logger = logging.getLogger("turingpulse.sdk.cohere")

# Re-entrancy flag for the patched ``chat`` wrapper. It is True while a patched
# call is in flight in the current context (ContextVar, so it is tracked per
# thread / asyncio task). Nested patched calls see it set, delegate straight to
# the original method and skip ``_record_cohere_span``.
_INSTRUMENTING: ContextVar[bool] = ContextVar("_tp_cohere_instrumenting", default=False)

# The unpatched ``cohere.ClientV2.chat`` saved by ``patch_cohere``.
# ``None`` means Cohere is not currently patched; ``patch_cohere`` uses a
# non-None value to detect (and skip) a second patch attempt.
_ORIGINAL_CHAT: Any = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def patch_cohere(
    *,
    name: str | None = None,
    governance: Optional[GovernanceDirective] = None,
) -> None:
    """Monkey-patch ``cohere.ClientV2.chat`` with TuringPulse instrumentation.

    After patching, every ``ClientV2.chat`` call runs inside
    ``turingpulse_sdk.instrument``. Once the original method returns, the
    response is recorded on the current TuringPulse context by
    ``_record_cohere_span``. Recording is best-effort: if it fails, a warning
    is logged and the Cohere response is still returned unchanged.

    Calling this while Cohere is already patched logs a warning and does
    nothing.

    Args:
        name: Workflow name passed to ``instrument``. Falls back to the
            ``TP_WORKFLOW_NAME`` environment variable when not given.
        governance: Optional governance directive forwarded to ``instrument``.

    Raises:
        ConfigurationError: If neither ``name`` nor ``TP_WORKFLOW_NAME`` is set.
        ImportError: If the ``cohere`` package cannot be imported.
    """
    global _ORIGINAL_CHAT

    effective_name = name or os.getenv("TP_WORKFLOW_NAME", "")
    if not effective_name:
        raise ConfigurationError(
            "patch_cohere() requires name= or TP_WORKFLOW_NAME to be set."
        )

    try:
        from cohere import ClientV2
    except ImportError as exc:
        raise ImportError("cohere package is required: pip install cohere>=5.20.0") from exc

    if _ORIGINAL_CHAT is not None:
        logger.warning("Cohere is already patched — skipping")
        return

    # Bind the original method in the closure instead of reading the global at
    # call time. unpatch_cohere() resets _ORIGINAL_CHAT to None, and any
    # patched method still referenced elsewhere (e.g. a bound ``client.chat``
    # obtained before unpatching) must keep delegating to the real method.
    original = ClientV2.chat
    _ORIGINAL_CHAT = original

    @instrument(name=effective_name, governance=governance)
    def _patched_chat(self_client, *args: Any, **kwargs: Any) -> Any:
        # Nested patched call in the same context: delegate without recording.
        if _INSTRUMENTING.get(False):
            return original(self_client, *args, **kwargs)
        token = _INSTRUMENTING.set(True)
        try:
            response = original(self_client, *args, **kwargs)
            # Telemetry must never turn a successful Cohere call into a failure.
            try:
                _record_cohere_span(response, kwargs)
            except Exception:
                logger.warning("Failed to record Cohere span", exc_info=True)
            return response
        finally:
            _INSTRUMENTING.reset(token)

    ClientV2.chat = _patched_chat
    logger.info("Cohere patched for TuringPulse instrumentation")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def unpatch_cohere() -> None:
    """Undo ``patch_cohere`` by putting the saved ``ClientV2.chat`` back.

    Does nothing when Cohere is not currently patched. If ``cohere`` can no
    longer be imported, the saved reference is still cleared so a later
    ``patch_cohere`` call can patch again.
    """
    global _ORIGINAL_CHAT

    saved_chat = _ORIGINAL_CHAT
    if saved_chat is not None:
        try:
            from cohere import ClientV2
        except ImportError:
            pass
        else:
            ClientV2.chat = saved_chat
        _ORIGINAL_CHAT = None
        logger.info("Cohere unpatched — original methods restored")
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _record_cohere_span(response: Any, kwargs: dict) -> None:
    """Copy model, token, prompt, output and tool-call data onto the current context.

    Does nothing when there is no active TuringPulse context. Messages,
    content parts and tool calls may be SDK model objects or plain dicts.

    Args:
        response: Value returned by the original ``ClientV2.chat``.
        kwargs: Keyword arguments the caller passed to ``chat``; ``model`` and
            ``messages`` are read from here.
    """
    ctx = current_context()
    if not ctx:
        return
    ctx.framework = "cohere"
    ctx.node_type = "llm"

    counts = _cohere_token_counts(response)
    if counts is not None:
        ctx.set_tokens(*counts)

    ctx.set_model(
        kwargs.get("model") or getattr(response, "model", None) or "unknown",
        "cohere",
    )

    prompt = _last_user_prompt(kwargs.get("messages"))
    if prompt is not None:
        ctx.set_prompt(prompt[:MAX_FIELD_SIZE])

    message = getattr(response, "message", None)
    if not message:
        return

    output = _content_to_text(getattr(message, "content", None))
    if output:
        ctx.set_io(output_data=output[:MAX_FIELD_SIZE])

    for tc in getattr(message, "tool_calls", None) or []:
        tool_name, tool_args = _tool_call_name_and_args(tc)
        ctx.add_tool_call(
            tool_name=tool_name,
            tool_args=tool_args,
            tool_id=_get(tc, "id"),
        )


def _get(obj: Any, key: str, default: Any = None) -> Any:
    """Read ``key`` from a dict, or the same-named attribute from any other object."""
    if isinstance(obj, dict):
        return obj.get(key, default)
    return getattr(obj, key, default)


def _content_to_text(content: Any) -> str:
    """Flatten message content (a string or a list of content parts) to plain text.

    For lists, only string items and parts whose ``type`` is ``"text"`` are
    kept. Their ``text`` values are joined with newlines, so non-text parts
    (and parts with a missing/``None`` text) do not leak a repr into the span.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif _get(part, "type", "") == "text":
                parts.append(str(_get(part, "text", "") or ""))
        return "\n".join(parts)
    return str(content)


def _last_user_prompt(messages: Any) -> str | None:
    """Return the text of the most recent ``user`` message, or ``None`` if there is none."""
    if not messages:
        return None
    for m in reversed(messages):
        if _get(m, "role") == "user":
            return _content_to_text(_get(m, "content"))
    return None


def _cohere_token_counts(response: Any) -> tuple[Any, Any] | None:
    """Return ``(input_tokens, output_tokens)`` for ``response``, or ``None``.

    ``usage.billed_units`` is read first. A ``meta.tokens`` block (object or
    dict) overrides it when it reports any non-zero count.
    """
    counts = None

    usage = getattr(response, "usage", None)
    billed = getattr(usage, "billed_units", None) if usage else None
    if billed:
        counts = (
            getattr(billed, "input_tokens", 0) or 0,
            getattr(billed, "output_tokens", 0) or 0,
        )

    meta_tokens = _get(getattr(response, "meta", None), "tokens")
    if meta_tokens:
        inp = _get(meta_tokens, "input_tokens", 0) or 0
        out = _get(meta_tokens, "output_tokens", 0) or 0
        if inp or out:
            counts = (inp, out)

    return counts


def _tool_call_name_and_args(tc: Any) -> tuple[str, Any]:
    """Return ``(name, arguments)`` for a Cohere tool call.

    Supports the v1 shape (``name`` / ``parameters``) and the v2 shape, where
    both live under ``function`` and ``function.arguments`` is a JSON string.
    Arguments that are not valid JSON are returned as the raw string; missing
    arguments become ``{}``.
    """
    import json

    function = _get(tc, "function")
    tool_name = _get(tc, "name") or _get(function, "name") or "unknown"

    tool_args = _get(tc, "parameters")
    if tool_args is None:
        tool_args = _get(function, "arguments")
        if isinstance(tool_args, str):
            try:
                tool_args = json.loads(tool_args)
            except ValueError:
                pass  # not JSON: keep the raw argument string as-is
    return tool_name, ({} if tool_args is None else tool_args)
|