alita-sdk 0.3.205-py3-none-any.whl → 0.3.207-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/runtime/clients/client.py +314 -11
- alita_sdk/runtime/langchain/assistant.py +22 -21
- alita_sdk/runtime/langchain/interfaces/llm_processor.py +1 -4
- alita_sdk/runtime/langchain/langraph_agent.py +6 -1
- alita_sdk/runtime/langchain/store_manager.py +4 -4
- alita_sdk/runtime/toolkits/application.py +5 -10
- alita_sdk/runtime/toolkits/tools.py +11 -21
- alita_sdk/runtime/tools/vectorstore.py +25 -11
- alita_sdk/runtime/utils/streamlit.py +505 -222
- alita_sdk/runtime/utils/toolkit_runtime.py +147 -0
- alita_sdk/runtime/utils/toolkit_utils.py +157 -0
- alita_sdk/runtime/utils/utils.py +5 -0
- alita_sdk/tools/__init__.py +2 -0
- alita_sdk/tools/ado/repos/repos_wrapper.py +20 -13
- alita_sdk/tools/bitbucket/api_wrapper.py +5 -5
- alita_sdk/tools/bitbucket/cloud_api_wrapper.py +54 -29
- alita_sdk/tools/elitea_base.py +9 -4
- alita_sdk/tools/gitlab/__init__.py +22 -10
- alita_sdk/tools/gitlab/api_wrapper.py +278 -253
- alita_sdk/tools/gitlab/tools.py +354 -376
- alita_sdk/tools/llm/llm_utils.py +0 -6
- alita_sdk/tools/memory/__init__.py +54 -10
- alita_sdk/tools/openapi/__init__.py +14 -3
- alita_sdk/tools/sharepoint/__init__.py +2 -1
- alita_sdk/tools/sharepoint/api_wrapper.py +11 -3
- alita_sdk/tools/testrail/api_wrapper.py +39 -16
- alita_sdk/tools/utils/content_parser.py +77 -13
- {alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/METADATA +1 -1
- {alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/RECORD +32 -40
- alita_sdk/community/analysis/__init__.py +0 -0
- alita_sdk/community/analysis/ado_analyse/__init__.py +0 -103
- alita_sdk/community/analysis/ado_analyse/api_wrapper.py +0 -261
- alita_sdk/community/analysis/github_analyse/__init__.py +0 -98
- alita_sdk/community/analysis/github_analyse/api_wrapper.py +0 -166
- alita_sdk/community/analysis/gitlab_analyse/__init__.py +0 -110
- alita_sdk/community/analysis/gitlab_analyse/api_wrapper.py +0 -172
- alita_sdk/community/analysis/jira_analyse/__init__.py +0 -141
- alita_sdk/community/analysis/jira_analyse/api_wrapper.py +0 -252
- alita_sdk/runtime/llms/alita.py +0 -259
- {alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/top_level.txt +0 -0
alita_sdk/community/analysis/jira_analyse/api_wrapper.py
DELETED
@@ -1,252 +0,0 @@
-import logging
-from io import StringIO
-from typing import Optional, List, Dict, Any
-from langchain_core.callbacks import dispatch_custom_event
-from langchain_core.tools import ToolException
-from pydantic import BaseModel, Field
-from jira import JIRA
-import pandas as pd
-
-
-from elitea_analyse.utils.constants import OUTPUT_MAPPING_FILE, OUTPUT_WORK_ITEMS_FILE
-from elitea_analyse.jira.jira_projects_overview import jira_projects_overview
-from elitea_analyse.jira.jira_statuses import get_all_statuses_list
-from elitea_analyse.jira.jira_issues import JiraIssues
-
-from alita_sdk.tools.elitea_base import BaseToolApiWrapper
-from alita_sdk.runtime.tools.artifact import ArtifactWrapper
-from alita_sdk.runtime.utils.logging import with_streamlit_logs
-
-logger = logging.getLogger(__name__)
-
-
-class GetJiraFieldsArgs(BaseModel):
-    project_keys: Optional[str] = Field(
-        description="One or more projects keys separated with comma.",
-        default=''
-    )
-    after_date: str = Field(description="Date after which issues are considered.")
-
-
-class GetJiraIssuesArgs(BaseModel):
-    project_keys: Optional[str] = Field(
-        description="One or more projects keys separated with comma.", default=''
-    )
-    closed_issues_based_on: int = Field(
-        description=("Define whether issues can be thought as closed based on their status (1) "
-                     "or not empty resolved date (2).")
-    )
-    resolved_after: str = Field(description="Resolved after date (i.e. 2023-01-01).")
-    updated_after: str = Field(description="Updated after date (i.e. 2023-01-01).")
-    created_after: str = Field(description="Created after date (i.e. 2023-01-01).")
-    add_filter: Optional[str] = Field(
-        description=("Additional filter for Jira issues in JQL format like "
-                     "'customfield_10000 = 'value' AND customfield_10001 = 'value'")
-    )
-
-
-class JiraAnalyseWrapper(BaseToolApiWrapper):
-    artifacts_wrapper: ArtifactWrapper
-    jira: JIRA
-    project_keys: str  # Jira project keys
-    closed_status: str  # Jira ticket closed statuses
-    defects_name: str  # Jira ticket defects name
-    custom_fields: dict  # Jira ticket custom fields
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    def get_number_off_all_issues(self, after_date: str, project_keys: Optional[str] = None):
-        """
-        Get projects a user has access to and merge them with issues count.
-        after_date: str
-            date after which issues are considered
-        project_keys: str
-            one or more projects keys separated with comma
-        """
-        project_keys = project_keys or self.project_keys
-
-        dispatch_custom_event(
-            name="thinking_step",
-            data={
-                "message": f"I am extracting number of all issues with initial parameters:\
-                    project keys: {project_keys}, after date: {after_date}",
-                "tool_name": "get_number_off_all_issues",
-                "toolkit": "analyse_jira",
-            },
-        )
-
-        project_df = jira_projects_overview(
-            after_date, project_keys=project_keys, jira=self.jira
-        )
-
-        # Save project_df DataFrame into the bucket
-        self.save_dataframe(
-            project_df,
-            f"projects_overview_{project_keys}.csv",
-            csv_options={"index": False},
-        )
-        return {
-            "projects": project_df["key"].tolist(),
-            "projects_summary": project_df.to_string(),
-        }
-
-    @with_streamlit_logs(tool_name="get_jira_issues")
-    def get_jira_issues(
-        self,
-        closed_issues_based_on: int,
-        resolved_after: str,
-        updated_after: str,
-        created_after: str,
-        add_filter: str = "",
-        project_keys: Optional[str] = None,
-    ):
-        """
-        Extract Jira issues for the specified projects.
-        closed_issues_based_on: int
-            define whether issues can be thought as
-            closed based on their status (1) or not empty resolved date (2)
-        resolved_after: str
-            resolved after date (i.e. 2023-01-01)
-        updated_after: str
-            updated after date (i.e. 2023-01-01)
-        created_after: str
-            created after date (i.e. 2023-01-01)
-        add_filter: str
-            additional filter for Jira issues in JQL format
-            like "customfield_10000 = 'value' AND customfield_10001 = 'value'"
-        project_keys: str
-            one or more projects keys separated with comma
-        """
-
-        if not (
-            (
-                closed_issues_based_on == 1
-                and self.closed_status in get_all_statuses_list(jira=self.jira)
-            )
-            or closed_issues_based_on == 2
-        ):
-            return (
-                f"ERROR: Check input parameters closed_issues_based_on ({closed_issues_based_on}) "
-                f"and closed_status ({self.closed_status}) not in Jira statuses list."
-            )
-
-        project_keys = project_keys or self.project_keys
-
-        dispatch_custom_event(
-            name="thinking_step",
-            data={
-                "message": f"I am extracting Jira issues with initial parameters:\
-                    project keys: {project_keys}, closed status: {self.closed_status},\
-                    defects name: {self.defects_name}, custom fields: {self.custom_fields}, \
-                    closed status based on: {closed_issues_based_on}, resolved after: {resolved_after}, \
-                    updated after: {updated_after}, created after: {created_after}, additional filter:{add_filter}",
-                "tool_name": "jira_issues_extraction_start",
-                "toolkit": "analyse_jira",
-            },
-        )
-
-        jira_issues = JiraIssues(
-            self.jira,
-            project_keys,
-            (closed_issues_based_on, self.closed_status),
-            self.defects_name,
-            add_filter="",
-        )
-
-        df_issues, df_map = jira_issues.extract_issues_from_jira_and_transform(
-            self.custom_fields, (resolved_after, updated_after, created_after)
-        )
-
-        dispatch_custom_event(
-            name="thinking_step",
-            data={
-                "message": f"I am saving the extracted Jira issues to the artifact repository. \
-                    issues count: {len(df_issues)}, mapping rows: {len(df_map)}, \
-                    output file: {OUTPUT_MAPPING_FILE}{jira_issues.projects}.csv",
-                "tool_name": "get_jira_issues",
-                "toolkit": "analyse_jira",
-            },
-        )
-        self.save_dataframe(
-            df_map,
-            f"{OUTPUT_MAPPING_FILE}{jira_issues.projects}.csv",
-            csv_options={"index_label": "id"},
-        )
-
-        if not df_issues.empty:
-            self.save_dataframe(
-                df_issues,
-                f"{OUTPUT_WORK_ITEMS_FILE}{jira_issues.projects}.csv",
-                csv_options={"index_label": "id"},
-            )
-            dispatch_custom_event(
-                name="thinking_step",
-                data={
-                    "message": f"Saving Jira issues to the file . \
-                        output file: {OUTPUT_WORK_ITEMS_FILE}{jira_issues.projects}.csv,\
-                        row count: {len(df_issues)}",
-                    "tool_name": "get_jira_issues",
-                    "toolkit": "analyse_jira",
-                },
-            )
-
-        return f"{jira_issues.projects} Data has been extracted successfully."
-
-    def get_available_tools(self) -> List[Dict[str, Any]]:
-        """Get a list of available tools."""
-        return [
-            {
-                "name": "get_number_off_all_issues",
-                "description": self.get_number_off_all_issues.__doc__,
-                "args_schema": GetJiraFieldsArgs,
-                "ref": self.get_number_off_all_issues,
-            },
-            {
-                "name": "get_jira_issues",
-                "description": self.get_jira_issues.__doc__,
-                "args_schema": GetJiraIssuesArgs,
-                "ref": self.get_jira_issues,
-            },
-        ]
-
-    def save_dataframe(
-        self,
-        df: pd.DataFrame,
-        target_file: str,
-        csv_options: Optional[Dict[str, Any]] = None,
-    ):
-        """
-        Save a pandas DataFrame as a CSV file in the artifact repository using the ArtifactWrapper.
-
-        Args:
-            df (pd.DataFrame): The DataFrame to save.
-            target_file (str): The target file name in the storage (e.g., "file.csv").
-            csv_options: Dictionary of options to pass to Dataframe.to_csv()
-
-        Raises:
-            ValueError: If the DataFrame is empty or the file name is invalid.
-            Exception: If saving to the artifact repository fails.
-        """
-        csv_options = csv_options or {}
-
-        # Use StringIO to save the DataFrame as a string
-        try:
-            buffer = StringIO()
-            df.to_csv(buffer, **csv_options)
-            self.artifacts_wrapper.create_file(target_file, buffer.getvalue())
-            logger.info(
-                f"Successfully saved dataframe to {target_file} in bucket {self.artifacts_wrapper.bucket}"
-            )
-        except Exception as e:
-            logger.exception("Failed to save DataFrame to artifact repository")
-            return ToolException(
-                f"Failed to save DataFrame to artifact repository: {str(e)}"
-            )
-
-    def run(self, mode: str, *args: Any, **kwargs: Any):
-        for tool in self.get_available_tools():
-            if tool["name"] == mode:
-                return tool["ref"](*args, **kwargs)
-
-        raise ValueError(f"Unknown mode: {mode}")
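For orientation, the deleted JiraAnalyseWrapper exposed its two tools through the name-based run() dispatcher shown above. A minimal, hypothetical usage sketch (not part of the package): the Jira server URL, token, and field values below are placeholders, and artifacts_wrapper is assumed to be a pre-built ArtifactWrapper bound to a target bucket.

# Hypothetical caller for the removed JiraAnalyseWrapper; constructor fields
# and tool names are taken verbatim from the deleted module above.
from jira import JIRA

jira_client = JIRA(server="https://example.atlassian.net", token_auth="<api-token>")  # placeholder credentials

wrapper = JiraAnalyseWrapper(
    artifacts_wrapper=artifacts_wrapper,  # assumed existing ArtifactWrapper instance
    jira=jira_client,
    project_keys="PROJ",
    closed_status="Done",
    defects_name="Bug",
    custom_fields={},
)

# run() dispatches to the entry in get_available_tools() whose "name" matches mode
overview = wrapper.run("get_number_off_all_issues", after_date="2023-01-01")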
alita_sdk/runtime/llms/alita.py
DELETED
@@ -1,259 +0,0 @@
-# Copyright (c) 2023 Artem Rozumenko
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-#
-# This is adoption of https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
-#
-
-import logging
-import requests
-from time import sleep
-from traceback import format_exc
-
-from typing import Any, List, Optional, AsyncIterator, Dict, Iterator, Mapping, Type
-from tiktoken import get_encoding, encoding_for_model
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
-from langchain_core.language_models import BaseChatModel, SimpleChatModel
-from langchain_core.messages import (AIMessageChunk, BaseMessage, HumanMessage, HumanMessageChunk, ChatMessageChunk,
-                                     FunctionMessageChunk, SystemMessageChunk, ToolMessageChunk, BaseMessageChunk,
-                                     AIMessage)
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from langchain_core.runnables import run_in_executor
-from langchain_community.chat_models.openai import generate_from_stream, _convert_delta_to_message_chunk
-from ..clients.client import AlitaClient
-from pydantic import Field, model_validator, field_validator, ValidationInfo
-
-logger = logging.getLogger(__name__)
-
-
-class MaxRetriesExceededError(Exception):
-    """Raised when the maximum number of retries is exceeded"""
-
-    def __init__(self, message="Maximum number of retries exceeded"):
-        self.message = message
-        super().__init__(self.message)
-
-
-class AlitaChatModel(BaseChatModel):
-    class Config:
-        populate_by_name = True
-
-    client: Any  #: :meta private:
-    encoding: Any  #: :meta private:
-    deployment: str = Field(default="https://eye.projectalita.ai", alias="base_url")
-    api_token: str = Field(default=None, alias="api_key")
-    project_id: int = None
-    model_name: Optional[str] = Field(default="gpt-35-turbo", alias="model")
-    integration_uid: Optional[str] = None
-    max_tokens: Optional[int] = 512
-    tiktoken_model_name: Optional[str] = None
-    tiktoken_encoding_name: Optional[str] = 'cl100k_base'
-    max_retries: Optional[int] = 2
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 0.9
-    top_k: Optional[int] = 20
-    stream_response: Optional[bool] = Field(default=False, alias="stream")
-    api_extra_headers: Optional[dict] = Field(default_factory=dict)
-    configurations: Optional[list] = Field(default_factory=list)
-
-    @model_validator(mode="before")
-    @classmethod
-    def validate_env(cls, values: dict) -> Dict:
-        values['client'] = AlitaClient(
-            values.get('deployment', values.get('base_url', "https://eye.projectalita.ai")),
-            values['project_id'],
-            values.get('api_token', values.get('api_key')),
-            api_extra_headers=values.get('api_extra_headers', {}),
-            configurations=values.get('configurations', [])
-        )
-        if values.get("tiktoken_model_name"):
-            values["encoding"] = encoding_for_model(values["tiktoken_model_name"])
-        else:
-            values['encoding'] = get_encoding('cl100k_base')
-        return values
-
-    def _generate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-
-        # TODO: Implement streaming
-
-        if self.stream_response:
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return generate_from_stream(stream_iter)
-        self.stream_response = False
-        response = self.completion_with_retry(messages)
-        return self._create_chat_result(response)
-
-
-    def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-
-        self.stream_response = True
-        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(messages):
-            if not isinstance(chunk, dict):
-                chunk = chunk.dict()
-            logger.debug(f"Chunk: {chunk}")
-            if "delta" in chunk:
-                chunk = _convert_delta_to_message_chunk(
-                    chunk["delta"], default_chunk_class
-                )
-                finish_reason = chunk.get("z")
-                generation_info = (
-                    dict(finish_reason=finish_reason) if finish_reason is not None else None
-                )
-                default_chunk_class = chunk.__class__
-                cg_chunk = ChatGenerationChunk(
-                    message=chunk, generation_info=generation_info
-                )
-                if run_manager:
-                    run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
-                yield cg_chunk
-            else:
-                message = _convert_delta_to_message_chunk(chunk, default_chunk_class)
-                finish_reason = None
-                generation_info = dict()
-                if stop:
-                    for stop_word in stop:
-                        if stop_word in message.content:
-                            finish_reason = "stop"
-                            message.z = finish_reason
-                            break
-                generation_info = (dict(finish_reason=finish_reason))
-                logger.debug(f"message before getting to ChatGenerationChunk: {message}")
-                yield ChatGenerationChunk(message=message, generation_info=generation_info)
-
-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        iterator = await run_in_executor(
-            None,
-            self._stream,
-            messages,
-            stop,
-            run_manager.get_sync() if run_manager else None,
-            **kwargs,
-        )
-        done = object()
-        while True:
-            item: ChatGenerationChunk | object = await run_in_executor(
-                None,
-                next,
-                iterator,
-                done,
-            )
-            if item is done:
-                break
-            if isinstance(item, ChatGenerationChunk):
-                yield item
-
-    def _create_chat_result(self, response: list[BaseMessage]) -> ChatResult:
-        token_usage = 0
-        generations = []
-        for message in response:
-            token_usage += len(self.encoding.encode(message.content))
-            generations.append(ChatGeneration(message=message))
-
-        llm_output = {
-            "token_usage": token_usage,
-            "model_name": self.model_name,
-        }
-
-        return ChatResult(
-            generations=generations,
-            llm_output=llm_output,
-        )
-
-    def completion_with_retry(self, messages, retry_count=0):
-        try:
-            return self.client.predict(messages, self._get_model_default_parameters)
-        except requests.exceptions.HTTPError as e:
-            from json import loads
-            logger.error(f"ERROR: HTTPError in completion_with_retry: {e}, retry_count: {retry_count}")
-            sleep(60)
-            if retry_count >= self.max_retries:
-                logger.error(f"ERROR: Retry count exceeded: {format_exc()}")
-                raise MaxRetriesExceededError(format_exc())
-            return self.completion_with_retry(messages, retry_count+1)
-        except Exception as e:
-            logger.error(f"ERROR: Exception in completion_with_retry: {e}, retry_count: {retry_count}")
-            if retry_count >= self.max_retries:
-                logger.error(f"ERROR: Retry count exceeded: {format_exc()}")
-                raise MaxRetriesExceededError(format_exc())
-            return self.completion_with_retry(messages, retry_count+1)
-
-
-    # def _call(self, prompt:str, **kwargs: Any):
-    #     """
-    #     This is the main method that will be called when we run our LLM.
-    #     """
-    #     return self.client.predict([HumanMessage(content=prompt)], self._get_model_default_parameters)
-
-    @property
-    def _llm_type(self) -> str:
-        """
-        This should return the type of the LLM.
-        """
-        return self.model_name
-
-    @property
-    def _get_model_default_parameters(self):
-        return {
-            "temperature": self.temperature,
-            "top_k": self.top_k,
-            "top_p": self.top_p,
-            "max_tokens": self.max_tokens,
-            "stream": self.stream_response,
-            "model": {
-                "model_name": self.model_name,
-                "integration_uid": self.integration_uid,
-            }
-        }
-
-    @property
-    def _identifying_params(self) -> dict:
-        """
-        It should return a dict that provides the information of all the parameters
-        that are used in the LLM. This is useful when we print our llm, it will give use the
-        information of all the parameters.
-        """
-        return {
-            "deployment": self.deployment,
-            "api_token": self.api_token,
-            "project_id": self.project_id,
-            "integration_id": self.integration_uid,
-            "model_settings": self._get_model_default_parameters,
-        }
-
-
{alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/WHEEL
File without changes
{alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/licenses/LICENSE
File without changes
{alita_sdk-0.3.205.dist-info → alita_sdk-0.3.207.dist-info}/top_level.txt
File without changes