versionhq-1.1.4.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- versionhq/__init__.py +33 -0
- versionhq/_utils/__init__.py +0 -0
- versionhq/_utils/cache_handler.py +13 -0
- versionhq/_utils/i18n.py +48 -0
- versionhq/_utils/logger.py +57 -0
- versionhq/_utils/process_config.py +28 -0
- versionhq/_utils/rpm_controller.py +73 -0
- versionhq/_utils/usage_metrics.py +31 -0
- versionhq/agent/__init__.py +0 -0
- versionhq/agent/model.py +472 -0
- versionhq/agent/parser.py +148 -0
- versionhq/cli/__init__.py +0 -0
- versionhq/clients/__init__.py +0 -0
- versionhq/clients/customer/__init__.py +0 -0
- versionhq/clients/customer/model.py +57 -0
- versionhq/clients/product/__init__.py +0 -0
- versionhq/clients/product/model.py +74 -0
- versionhq/clients/workflow/__init__.py +0 -0
- versionhq/clients/workflow/model.py +174 -0
- versionhq/llm/__init__.py +0 -0
- versionhq/llm/llm_vars.py +173 -0
- versionhq/llm/model.py +245 -0
- versionhq/task/__init__.py +9 -0
- versionhq/task/formatter.py +22 -0
- versionhq/task/model.py +430 -0
- versionhq/team/__init__.py +0 -0
- versionhq/team/model.py +585 -0
- versionhq/team/team_planner.py +55 -0
- versionhq/tool/__init__.py +0 -0
- versionhq/tool/composio.py +102 -0
- versionhq/tool/decorator.py +40 -0
- versionhq/tool/model.py +220 -0
- versionhq/tool/tool_handler.py +47 -0
- versionhq-1.1.4.4.dist-info/LICENSE +21 -0
- versionhq-1.1.4.4.dist-info/METADATA +353 -0
- versionhq-1.1.4.4.dist-info/RECORD +38 -0
- versionhq-1.1.4.4.dist-info/WHEEL +5 -0
- versionhq-1.1.4.4.dist-info/top_level.txt +1 -0
versionhq/clients/product/model.py
ADDED
@@ -0,0 +1,74 @@

import uuid
from typing import Any, Dict, List, Callable, Type, Optional, get_args, get_origin
from pydantic import (
    UUID4,
    InstanceOf,
    BaseModel,
    ConfigDict,
    Field,
    create_model,
    field_validator,
    model_validator,
)
from pydantic_core import PydanticCustomError


class ProductProvider(BaseModel):
    """
    Store the minimal client information.
    `data_pipeline` and `destinations` are for the composio plug-in.
    (!REFINEME) Create an Enum list for the options.
    (!REFINEME) Create an Enum list for regions.
    """

    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
    name: Optional[str] = Field(default=None, description="client name")
    region: Optional[str] = Field(
        default=None, description="region of the client's main business operation"
    )
    data_pipeline: Optional[List[str]] = Field(
        default=None, description="store the data pipelines that the client is using"
    )
    destinations: Optional[List[str]] = Field(
        default=None,
        description="store the destination services that the client is using",
    )

    @field_validator("id", mode="before")
    @classmethod
    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
        if v:
            raise PydanticCustomError(
                "may_not_set_field", "This field is not to be set by the user.", {}
            )


class Product(BaseModel):
    """
    Store the product information necessary for the outbound efforts and connect it to the `ProductProvider` instance.
    """

    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
    name: Optional[str] = Field(default=None, description="product name")
    description: Optional[str] = Field(
        default=None,
        max_length=256,
        description="product description scraped from the landing url or client input. cascade to the agent",
    )
    provider: Optional[ProductProvider] = Field(default=None)
    audience: Optional[str] = Field(default=None, description="target audience")
    usp: Optional[str] = Field(default=None)
    landing_url: Optional[str] = Field(
        default=None, description="marketing url of the product if any"
    )
    cohort_timeframe: Optional[int] = Field(
        default=30, description="ideal cohort timeframe of the product in days"
    )

    @field_validator("id", mode="before")
    @classmethod
    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
        if v:
            raise PydanticCustomError(
                "may_not_set_field", "This field is not to be set by the user.", {}
            )
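A minimal usage sketch of these models, assuming the package is importable as `versionhq`; the field values below (including the "braze" destination) are illustrative, not part of the package:

import uuid

from pydantic import ValidationError

from versionhq.clients.product.model import Product, ProductProvider

# Build a provider and attach it to a product; ids are generated automatically.
provider = ProductProvider(name="Acme Data", region="US", destinations=["braze"])
product = Product(name="Acme CDP", provider=provider, landing_url="https://example.com")
print(product.id)  # auto-generated UUID4

# Passing `id` explicitly is rejected by the `_deny_user_set_id` validator.
try:
    Product(id=uuid.uuid4(), name="rejected")
except ValidationError as e:
    print(e.errors()[0]["type"])  # may_not_set_field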
versionhq/clients/workflow/model.py
ADDED
@@ -0,0 +1,174 @@

import uuid
from abc import ABC
from datetime import date, datetime, time, timedelta
from typing import (
    Any,
    Dict,
    List,
    Union,
    Callable,
    Type,
    Optional,
    get_args,
    get_origin,
)
from pydantic import (
    UUID4,
    InstanceOf,
    BaseModel,
    ConfigDict,
    Field,
    create_model,
    field_validator,
    model_validator,
)
from pydantic_core import PydanticCustomError

from versionhq.clients.product.model import Product
from versionhq.clients.customer.model import Customer
from versionhq.agent.model import Agent
from versionhq.team.model import Team


class ScoreFormat:
    def __init__(self, rate: float, weight: int = 1):
        self.rate = rate
        self.weight = weight
        self.aggregate = rate * weight


class Score:
    """
    Evaluate the score on a 0 (no performance) to 1 scale.
    `rate`: Any float from 0.0 to 1.0 given by an agent.
    `weight`: Importance of each factor to the aggregated score.
    """

    def __init__(
        self,
        brand_tone: ScoreFormat,
        audience: ScoreFormat,
        track_record: ScoreFormat,
        *args: ScoreFormat,
    ):
        self.brand_tone = brand_tone
        self.audience = audience
        self.track_record = track_record
        self.args = args

    def result(self) -> float:
        aggregated_score = (
            self.brand_tone.aggregate
            + self.audience.aggregate
            + self.track_record.aggregate
        )
        denominator = (
            self.brand_tone.weight + self.audience.weight + self.track_record.weight
        )
        for item in self.args:
            if isinstance(item, ScoreFormat):
                aggregated_score += item.rate * item.weight
                denominator += item.weight
        return round(aggregated_score / denominator, 2)


class MessagingComponent(ABC, BaseModel):
    layer_id: int = Field(default=0, description="id of the layer: 0, 1, 2")
    message: str = Field(
        default=None, max_length=1024, description="text message content to be sent"
    )
    interval: Optional[str] = Field(
        default=None,
        description="interval before moving on to the next layer. if this is the last layer, set as `None`",
    )
    score: Union[float, InstanceOf[Score]] = Field(default=None)


class MessagingWorkflow(ABC, BaseModel):
    """
    Store 3 layers of the messaging workflow sent to the `customer` on the `product`.
    """

    _created_at: Optional[datetime]
    _updated_at: Optional[datetime]

    model_config = ConfigDict()

    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
    comps: List[MessagingComponent] = Field(
        default_factory=list, description="store at least 3 messaging components"
    )

    # responsible team or agents
    team: Optional[Team] = Field(
        default=None,
        description="store the `Team` instance responsible for autopiloting this workflow",
    )
    agents: Optional[List[Agent]] = Field(
        default=None,
        description="store `Agent` instances responsible for autopiloting this workflow. if the team exists, this field remains `None`",
    )

    # metrics
    destination: Optional[str] = Field(
        default=None, description="destination service to launch this workflow"
    )
    product: InstanceOf[Product] = Field(default=None)
    customer: InstanceOf[Customer] = Field(default=None)

    metrics: Union[List[Dict[str, Any]], List[str]] = Field(
        default=None,
        max_length=256,
        description="store metrics used to predict and track the performance of this workflow.",
    )

    @property
    def name(self):
        if self.customer.id:
            return (
                f"Workflow ID: {self.id} - on {self.product.id} for {self.customer.id}"
            )
        else:
            return f"Workflow ID: {self.id} - on {self.product.id}"

    @field_validator("id", mode="before")
    @classmethod
    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
        if v:
            raise PydanticCustomError(
                "may_not_set_field", "This field is not to be set by the user.", {}
            )

    @model_validator(mode="after")
    def set_up_destination(self):
        """
        Set up the destination service when self.destination is None.
        Prioritize the customer's destination over the product provider's destination list.
        """
        if self.destination is None:
            if self.customer is not None:
                self.destination = self.customer.on

            else:
                destination_list = self.product.provider.destinations
                if destination_list:
                    self.destination = destination_list[0]
        return self

    def reassign_agent_or_team(
        self, agents: Optional[List[Agent]] = None, team: Optional[Team] = None
    ) -> None:
        """
        Remove unresponsive agents/teams and assign new ones.
        """

        if not agents and not team:
            raise ValueError("Need to add at least 1 agent or team.")

        self.agents = agents
        self.team = team
        self._updated_at = datetime.now()
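A quick check of the weighted-average scoring above, with illustrative rates and weights:

from versionhq.clients.workflow.model import Score, ScoreFormat

# aggregate = rate * weight per factor; result() divides by the summed weights.
score = Score(
    brand_tone=ScoreFormat(rate=0.8, weight=2),    # aggregate 1.6
    audience=ScoreFormat(rate=0.6, weight=1),      # aggregate 0.6
    track_record=ScoreFormat(rate=1.0, weight=1),  # aggregate 1.0
)
print(score.result())  # (1.6 + 0.6 + 1.0) / (2 + 1 + 1) = 0.8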
versionhq/llm/llm_vars.py
ADDED
@@ -0,0 +1,173 @@

LLM_CONTEXT_WINDOW_SIZES = {
    "gpt-3.5-turbo": 8192,
    "gpt-4": 8192,
    "gpt-4o": 128000,
    "gpt-4o-mini": 128000,
    "gpt-4-turbo": 128000,
    "o1-preview": 128000,
    "o1-mini": 128000,
    "deepseek-chat": 128000,
    "gemma2-9b-it": 8192,
    "gemma-7b-it": 8192,
    "llama3-groq-70b-8192-tool-use-preview": 8192,
    "llama3-groq-8b-8192-tool-use-preview": 8192,
    "llama-3.1-70b-versatile": 131072,
    "llama-3.1-8b-instant": 131072,
    "llama-3.2-1b-preview": 8192,
    "llama-3.2-3b-preview": 8192,
    "llama-3.2-11b-text-preview": 8192,
    "llama-3.2-90b-text-preview": 8192,
    "llama3-70b-8192": 8192,
    "llama3-8b-8192": 8192,
    "mixtral-8x7b-32768": 32768,
}


LLM_VARS = {
    "openai": [
        {
            "prompt": "Enter your OPENAI API key (press Enter to skip)",
            "key_name": "OPENAI_API_KEY",
        }
    ],
    "anthropic": [
        {
            "prompt": "Enter your ANTHROPIC API key (press Enter to skip)",
            "key_name": "ANTHROPIC_API_KEY",
        }
    ],
    "gemini": [
        {
            "prompt": "Enter your GEMINI API key (press Enter to skip)",
            "key_name": "GEMINI_API_KEY",
        }
    ],
    "watson": [
        {
            "prompt": "Enter your WATSONX URL (press Enter to skip)",
            "key_name": "WATSONX_URL",
        },
        {
            "prompt": "Enter your WATSONX API Key (press Enter to skip)",
            "key_name": "WATSONX_APIKEY",
        },
        {
            "prompt": "Enter your WATSONX Project Id (press Enter to skip)",
            "key_name": "WATSONX_PROJECT_ID",
        },
    ],
    "ollama": [
        {
            "default": True,
            "API_BASE": "http://localhost:11434",
        }
    ],
    "bedrock": [
        {
            "prompt": "Enter your AWS Access Key ID (press Enter to skip)",
            "key_name": "AWS_ACCESS_KEY_ID",
        },
        {
            "prompt": "Enter your AWS Secret Access Key (press Enter to skip)",
            "key_name": "AWS_SECRET_ACCESS_KEY",
        },
        {
            "prompt": "Enter your AWS Region Name (press Enter to skip)",
            "key_name": "AWS_REGION_NAME",
        },
    ],
    "azure": [
        {
            "prompt": "Enter your Azure deployment name (must start with 'azure/')",
            "key_name": "model",
        },
        {
            "prompt": "Enter your AZURE API key (press Enter to skip)",
            "key_name": "AZURE_API_KEY",
        },
        {
            "prompt": "Enter your AZURE API base URL (press Enter to skip)",
            "key_name": "AZURE_API_BASE",
        },
        {
            "prompt": "Enter your AZURE API version (press Enter to skip)",
            "key_name": "AZURE_API_VERSION",
        },
    ],
    "cerebras": [
        {
            "prompt": "Enter your Cerebras model name (must start with 'cerebras/')",
            "key_name": "model",
        },
        {
            "prompt": "Enter your Cerebras API key (press Enter to skip)",
            "key_name": "CEREBRAS_API_KEY",
        },
    ],
}


PROVIDERS = [
    "openai",
    "anthropic",
    "gemini",
    "ollama",
    "watson",
    "bedrock",
    "azure",
    "cerebras",
    "llama",
]

MODELS = {
    "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini", "o1-mini", "o1-preview"],
    "anthropic": [
        "claude-3-5-sonnet-20240620",
        "claude-3-sonnet-20240229",
        "claude-3-opus-20240229",
        "claude-3-haiku-20240307",
    ],
    "gemini": [
        "gemini/gemini-1.5-flash",
        "gemini/gemini-1.5-pro",
        "gemini/gemini-gemma-2-9b-it",
        "gemini/gemini-gemma-2-27b-it",
    ],
    "ollama": ["ollama/llama3.1", "ollama/mixtral"],
    "watson": [
        "watsonx/meta-llama/llama-3-1-70b-instruct",
        "watsonx/meta-llama/llama-3-1-8b-instruct",
        "watsonx/meta-llama/llama-3-2-11b-vision-instruct",
        "watsonx/meta-llama/llama-3-2-1b-instruct",
        "watsonx/meta-llama/llama-3-2-90b-vision-instruct",
        "watsonx/meta-llama/llama-3-405b-instruct",
        "watsonx/mistral/mistral-large",
        "watsonx/ibm/granite-3-8b-instruct",
    ],
    "bedrock": [
        "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
        "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        "bedrock/anthropic.claude-3-opus-20240229-v1:0",
        "bedrock/anthropic.claude-v2:1",
        "bedrock/anthropic.claude-v2",
        "bedrock/anthropic.claude-instant-v1",
        "bedrock/meta.llama3-1-405b-instruct-v1:0",
        "bedrock/meta.llama3-1-70b-instruct-v1:0",
        "bedrock/meta.llama3-1-8b-instruct-v1:0",
        "bedrock/meta.llama3-70b-instruct-v1:0",
        "bedrock/meta.llama3-8b-instruct-v1:0",
        "bedrock/amazon.titan-text-lite-v1",
        "bedrock/amazon.titan-text-express-v1",
        "bedrock/cohere.command-text-v14",
        "bedrock/ai21.j2-mid-v1",
        "bedrock/ai21.j2-ultra-v1",
        "bedrock/ai21.jamba-instruct-v1:0",
        "bedrock/meta.llama2-13b-chat-v1",
        "bedrock/meta.llama2-70b-chat-v1",
        "bedrock/mistral.mistral-7b-instruct-v0:2",
        "bedrock/mistral.mixtral-8x7b-instruct-v0:1",
    ],
}


JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
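A short sketch of consuming these tables; `context_window_for` is a hypothetical helper written for illustration, not part of the package:

from versionhq.llm.llm_vars import LLM_CONTEXT_WINDOW_SIZES, MODELS, PROVIDERS

def context_window_for(model: str, default: int = 8192) -> int:
    """Look up a model's context window, falling back to a conservative default."""
    return LLM_CONTEXT_WINDOW_SIZES.get(model, default)

print(context_window_for("gpt-4o"))         # 128000
print(context_window_for("unknown-model"))  # 8192 (fallback)

# Providers with a curated model list; note that PROVIDERS also names entries
# such as "azure", "cerebras", and "llama" that have no MODELS entry.
for provider in PROVIDERS:
    if provider in MODELS:
        print(provider, MODELS[provider][0])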
versionhq/llm/model.py
ADDED
@@ -0,0 +1,245 @@

import logging
import os
import sys
import threading
import warnings
import litellm
from dotenv import load_dotenv
from litellm import get_supported_openai_params
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union

from versionhq.llm.llm_vars import LLM_CONTEXT_WINDOW_SIZES
from versionhq.task import TaskOutputFormat
from versionhq.task.model import ResponseField

load_dotenv(override=True)
API_KEY_LITELLM = os.environ.get("API_KEY_LITELLM")
DEFAULT_CONTEXT_WINDOW = int(8192 * 0.75)
os.environ["LITELLM_LOG"] = "DEBUG"


class FilteredStream:
    def __init__(self, original_stream):
        self._original_stream = original_stream
        self._lock = threading.Lock()

    def write(self, s) -> int:
        with self._lock:
            # Drop LiteLLM's noisy feedback/debug banners; pass everything else through.
            if (
                "Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
                in s
                or "LiteLLM.Info: If you need to debug this error, use `os.environ['LITELLM_LOG'] = 'DEBUG'`"
                in s
            ):
                return 0
            return self._original_stream.write(s)

    def flush(self):
        with self._lock:
            return self._original_stream.flush()


@contextmanager
def suppress_warnings():
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = FilteredStream(old_stdout)
        sys.stderr = FilteredStream(old_stderr)

        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr


class LLMResponseSchema:
    """
    Use the response schema for the LLM response.
    `field_list` contains the title, value type, and required flag of each field that needs to be returned.
    field_list: [{ title, type, required }]

    i.e., response_schema
        response_type: "array"  *options: "array", "dict"
        properties: { "recipe_name": { "type": "string" } }
        required: ["recipe_name"]
    """

    def __init__(self, response_type: str, field_list: List[ResponseField]):
        self.type = response_type
        self.field_list = field_list

    @property
    def schema(self):
        if len(self.field_list) == 0:
            return

        properties = {field.title: {"type": field.type} for field in self.field_list}
        required = [field.title for field in self.field_list if field.required]
        response_schema = {
            "type": self.type,
            "items": {"type": "object", "properties": properties},
            "required": required,
        }
        return response_schema


class LLM:
    """
    Use LiteLLM to connect with the model of choice.
    (Memo) Response formats will be given at the Task handling.
    """

    def __init__(
        self,
        model: str,
        timeout: Optional[Union[float, int]] = None,
        max_tokens: Optional[int] = None,
        max_completion_tokens: Optional[int] = None,
        context_window_size: Optional[int] = DEFAULT_CONTEXT_WINDOW,
        callbacks: List[Any] = [],
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        n: Optional[int] = None,
        stop: Optional[Union[str, List[str]]] = None,
        presence_penalty: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[Dict[int, float]] = None,
        # response_format: Optional[Dict[str, Any]] = None,
        seed: Optional[int] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        base_url: Optional[str] = None,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs,
    ):
        self.model = model
        self.timeout = timeout
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.context_window_size = context_window_size
        self.callbacks = callbacks

        self.temperature = temperature
        self.top_p = top_p
        self.n = n
        self.stop = stop
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.logit_bias = logit_bias
        # self.response_format = response_format
        self.seed = seed
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs

        self.base_url = base_url
        self.api_version = api_version
        self.api_key = api_key if api_key else API_KEY_LITELLM

        self.kwargs = kwargs

        litellm.drop_params = True
        self.set_callbacks(callbacks)

    def call(
        self,
        output_formats: List[TaskOutputFormat],
        field_list: Optional[List[ResponseField]],
        messages: List[Dict[str, str]],
        callbacks: List[Any] = [],
    ) -> Optional[str]:
        """
        Execute the LLM based on the Agent's controls.
        """

        with suppress_warnings():
            if callbacks and len(callbacks) > 0:
                self.set_callbacks(callbacks)

            try:
                response_format = None

                #! REFINEME
                if TaskOutputFormat.JSON in output_formats:
                    response_format = LLMResponseSchema(
                        response_type="json_object", field_list=field_list
                    )

                params = {
                    "model": self.model,
                    "messages": messages,
                    "timeout": self.timeout,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "n": self.n,
                    "stop": self.stop,
                    "max_tokens": self.max_tokens or self.max_completion_tokens,
                    "presence_penalty": self.presence_penalty,
                    "frequency_penalty": self.frequency_penalty,
                    "logit_bias": self.logit_bias,
                    # "response_format": response_format,
                    "seed": self.seed,
                    "logprobs": self.logprobs,
                    "top_logprobs": self.top_logprobs,
                    "api_base": self.base_url,
                    "api_version": self.api_version,
                    "api_key": self.api_key,
                    "stream": False,
                    **self.kwargs,
                }
                params = {k: v for k, v in params.items() if v is not None}
                res = litellm.completion(**params)
                return res["choices"][0]["message"]["content"]

            except Exception as e:
                logging.error(f"LiteLLM call failed: {str(e)}")
                return None

    def supports_function_calling(self) -> bool:
        try:
            params = get_supported_openai_params(model=self.model)
            return "response_format" in params
        except Exception as e:
            logging.error(f"Failed to get supported params: {str(e)}")
            return False

    def supports_stop_words(self) -> bool:
        try:
            params = get_supported_openai_params(model=self.model)
            return "stop" in params
        except Exception as e:
            logging.error(f"Failed to get supported params: {str(e)}")
            return False

    def get_context_window_size(self) -> int:
        """
        Only use 75% of the context window size to avoid cutting the message in the middle.
        """
        return (
            int(LLM_CONTEXT_WINDOW_SIZES.get(self.model) * 0.75)
            if self.model in LLM_CONTEXT_WINDOW_SIZES
            else DEFAULT_CONTEXT_WINDOW
        )

    def set_callbacks(self, callbacks: List[Any]):
        callback_types = [type(callback) for callback in callbacks]
        for callback in litellm.success_callback[:]:
            if type(callback) in callback_types:
                litellm.success_callback.remove(callback)

        for callback in litellm._async_success_callback[:]:
            if type(callback) in callback_types:
                litellm._async_success_callback.remove(callback)

        litellm.callbacks = callbacks
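Finally, a sketch of driving the `LLM` wrapper end to end. It assumes a valid provider key is set in the environment, and it assumes `ResponseField` takes `title`, `type`, and `required` keyword arguments as the `LLMResponseSchema` docstring implies; check versionhq/task/model.py for the real signature.

from versionhq.llm.model import LLM
from versionhq.task import TaskOutputFormat
from versionhq.task.model import ResponseField

# Assumes OPENAI_API_KEY (or API_KEY_LITELLM) is set in the environment.
llm = LLM(model="gpt-4o-mini", temperature=0.2)
print(llm.get_context_window_size())  # 96000, i.e. 75% of 128000
print(llm.supports_stop_words())      # True for OpenAI chat models

# ResponseField(title=..., type=..., required=...) is an assumption here.
res = llm.call(
    output_formats=[TaskOutputFormat.JSON],
    field_list=[ResponseField(title="recipe_name", type="string", required=True)],
    messages=[{"role": "user", "content": "Suggest one pasta recipe."}],
)
print(res)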