alita-sdk 0.3.205__py3-none-any.whl → 0.3.206__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -1,259 +0,0 @@
- # Copyright (c) 2023 Artem Rozumenko
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
-
- #
- # This is adoption of https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/openai.py
- #
-
- import logging
- import requests
- from time import sleep
- from traceback import format_exc
-
- from typing import Any, List, Optional, AsyncIterator, Dict, Iterator, Mapping, Type
- from tiktoken import get_encoding, encoding_for_model
- from langchain_core.callbacks import (
-     AsyncCallbackManagerForLLMRun,
-     CallbackManagerForLLMRun,
- )
- from langchain_core.language_models import BaseChatModel, SimpleChatModel
- from langchain_core.messages import (AIMessageChunk, BaseMessage, HumanMessage, HumanMessageChunk, ChatMessageChunk,
-                                      FunctionMessageChunk, SystemMessageChunk, ToolMessageChunk, BaseMessageChunk,
-                                      AIMessage)
- from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from langchain_core.runnables import run_in_executor
- from langchain_community.chat_models.openai import generate_from_stream, _convert_delta_to_message_chunk
- from ..clients.client import AlitaClient
- from pydantic import Field, model_validator, field_validator, ValidationInfo
-
- logger = logging.getLogger(__name__)
-
-
- class MaxRetriesExceededError(Exception):
-     """Raised when the maximum number of retries is exceeded"""
-
-     def __init__(self, message="Maximum number of retries exceeded"):
-         self.message = message
-         super().__init__(self.message)
-
-
- class AlitaChatModel(BaseChatModel):
-     class Config:
-         populate_by_name = True
-
-     client: Any #: :meta private:
-     encoding: Any #: :meta private:
-     deployment: str = Field(default="https://eye.projectalita.ai", alias="base_url")
-     api_token: str = Field(default=None, alias="api_key")
-     project_id: int = None
-     model_name: Optional[str] = Field(default="gpt-35-turbo", alias="model")
-     integration_uid: Optional[str] = None
-     max_tokens: Optional[int] = 512
-     tiktoken_model_name: Optional[str] = None
-     tiktoken_encoding_name: Optional[str] = 'cl100k_base'
-     max_retries: Optional[int] = 2
-     temperature: Optional[float] = 0.7
-     top_p: Optional[float] = 0.9
-     top_k: Optional[int] = 20
-     stream_response: Optional[bool] = Field(default=False, alias="stream")
-     api_extra_headers: Optional[dict] = Field(default_factory=dict)
-     configurations: Optional[list] = Field(default_factory=list)
-
-     @model_validator(mode="before")
-     @classmethod
-     def validate_env(cls, values: dict) -> Dict:
-         values['client'] = AlitaClient(
-             values.get('deployment', values.get('base_url', "https://eye.projectalita.ai")),
-             values['project_id'],
-             values.get('api_token', values.get('api_key')),
-             api_extra_headers=values.get('api_extra_headers', {}),
-             configurations=values.get('configurations', [])
-         )
-         if values.get("tiktoken_model_name"):
-             values["encoding"] = encoding_for_model(values["tiktoken_model_name"])
-         else:
-             values['encoding'] = get_encoding('cl100k_base')
-         return values
-
-     def _generate(
-         self,
-         messages: List[BaseMessage],
-         stop: Optional[List[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> ChatResult:
-
-         # TODO: Implement streaming
-
-         if self.stream_response:
-             stream_iter = self._stream(
-                 messages, stop=stop, run_manager=run_manager, **kwargs
-             )
-             return generate_from_stream(stream_iter)
-         self.stream_response = False
-         response = self.completion_with_retry(messages)
-         return self._create_chat_result(response)
-
-
-     def _stream(
-         self,
-         messages: List[BaseMessage],
-         stop: Optional[List[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> Iterator[ChatGenerationChunk]:
-
-         self.stream_response = True
-         default_chunk_class = AIMessageChunk
-         for chunk in self.completion_with_retry(messages):
-             if not isinstance(chunk, dict):
-                 chunk = chunk.dict()
-             logger.debug(f"Chunk: {chunk}")
-             if "delta" in chunk:
-                 chunk = _convert_delta_to_message_chunk(
-                     chunk["delta"], default_chunk_class
-                 )
-                 finish_reason = chunk.get("z")
-                 generation_info = (
-                     dict(finish_reason=finish_reason) if finish_reason is not None else None
-                 )
-                 default_chunk_class = chunk.__class__
-                 cg_chunk = ChatGenerationChunk(
-                     message=chunk, generation_info=generation_info
-                 )
-                 if run_manager:
-                     run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
-                 yield cg_chunk
-             else:
-                 message = _convert_delta_to_message_chunk(chunk, default_chunk_class)
-                 finish_reason = None
-                 generation_info = dict()
-                 if stop:
-                     for stop_word in stop:
-                         if stop_word in message.content:
-                             finish_reason = "stop"
-                             message.z = finish_reason
-                             break
-                     generation_info = (dict(finish_reason=finish_reason))
-                 logger.debug(f"message before getting to ChatGenerationChunk: {message}")
-                 yield ChatGenerationChunk(message=message, generation_info=generation_info)
-
-     async def _astream(
-         self,
-         messages: List[BaseMessage],
-         stop: Optional[List[str]] = None,
-         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> AsyncIterator[ChatGenerationChunk]:
-         iterator = await run_in_executor(
-             None,
-             self._stream,
-             messages,
-             stop,
-             run_manager.get_sync() if run_manager else None,
-             **kwargs,
-         )
-         done = object()
-         while True:
-             item: ChatGenerationChunk | object = await run_in_executor(
-                 None,
-                 next,
-                 iterator,
-                 done,
-             )
-             if item is done:
-                 break
-             if isinstance(item, ChatGenerationChunk):
-                 yield item
-
-     def _create_chat_result(self, response: list[BaseMessage]) -> ChatResult:
-         token_usage = 0
-         generations = []
-         for message in response:
-             token_usage += len(self.encoding.encode(message.content))
-             generations.append(ChatGeneration(message=message))
-
-         llm_output = {
-             "token_usage": token_usage,
-             "model_name": self.model_name,
-         }
-
-         return ChatResult(
-             generations=generations,
-             llm_output=llm_output,
-         )
-
-     def completion_with_retry(self, messages, retry_count=0):
-         try:
-             return self.client.predict(messages, self._get_model_default_parameters)
-         except requests.exceptions.HTTPError as e:
-             from json import loads
-             logger.error(f"ERROR: HTTPError in completion_with_retry: {e}, retry_count: {retry_count}")
-             sleep(60)
-             if retry_count >= self.max_retries:
-                 logger.error(f"ERROR: Retry count exceeded: {format_exc()}")
-                 raise MaxRetriesExceededError(format_exc())
-             return self.completion_with_retry(messages, retry_count+1)
-         except Exception as e:
-             logger.error(f"ERROR: Exception in completion_with_retry: {e}, retry_count: {retry_count}")
-             if retry_count >= self.max_retries:
-                 logger.error(f"ERROR: Retry count exceeded: {format_exc()}")
-                 raise MaxRetriesExceededError(format_exc())
-             return self.completion_with_retry(messages, retry_count+1)
-
-
-     # def _call(self, prompt:str, **kwargs: Any):
-     #     """
-     #     This is the main method that will be called when we run our LLM.
-     #     """
-     #     return self.client.predict([HumanMessage(content=prompt)], self._get_model_default_parameters)
-
-     @property
-     def _llm_type(self) -> str:
-         """
-         This should return the type of the LLM.
-         """
-         return self.model_name
-
-     @property
-     def _get_model_default_parameters(self):
-         return {
-             "temperature": self.temperature,
-             "top_k": self.top_k,
-             "top_p": self.top_p,
-             "max_tokens": self.max_tokens,
-             "stream": self.stream_response,
-             "model": {
-                 "model_name": self.model_name,
-                 "integration_uid": self.integration_uid,
-             }
-         }
-
-     @property
-     def _identifying_params(self) -> dict:
-         """
-         It should return a dict that provides the information of all the parameters
-         that are used in the LLM. This is useful when we print our llm, it will give use the
-         information of all the parameters.
-         """
-         return {
-             "deployment": self.deployment,
-             "api_token": self.api_token,
-             "project_id": self.project_id,
-             "integration_id": self.integration_uid,
-             "model_settings": self._get_model_default_parameters,
-         }
-
-
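
For context, the removed module defined the LangChain-compatible `AlitaChatModel` shown above. Below is a minimal usage sketch, not taken from the package itself: the import path `alita_sdk.llms.alita` is an assumption, the credential and project values are placeholders, and the keyword names simply follow the field aliases (`base_url`, `api_key`, `model`) declared in the class definition.

    # Minimal usage sketch (assumed import path, placeholder credentials).
    from langchain_core.messages import HumanMessage
    from alita_sdk.llms.alita import AlitaChatModel  # hypothetical import path

    llm = AlitaChatModel(
        base_url="https://eye.projectalita.ai",  # alias for `deployment`
        api_key="<api-token>",                   # alias for `api_token`; placeholder
        project_id=1,                            # placeholder project id
        integration_uid="<integration-uid>",     # placeholder integration id
        model="gpt-35-turbo",                    # alias for `model_name`
        max_tokens=512,
        temperature=0.7,
    )

    # `invoke` comes from BaseChatModel and routes through _generate above,
    # which calls AlitaClient.predict via completion_with_retry.
    result = llm.invoke([HumanMessage(content="Hello!")])
    print(result.content)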