sunholo 0.114.2__py3-none-any.whl → 0.115.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sunholo/__init__.py +0 -2
- sunholo/chunker/message_data.py +1 -4
- sunholo/chunker/pdfs.py +1 -1
- sunholo/chunker/publish.py +1 -5
- sunholo/chunker/splitter.py +6 -0
- sunholo/cli/cli_init.py +3 -1
- sunholo/components/llm.py +1 -1
- sunholo/components/vectorstore.py +1 -1
- sunholo/embedder/embed_chunk.py +3 -0
- sunholo/senses/stream_voice.py +22 -7
- sunholo/streaming/content_buffer.py +6 -13
- sunholo/summarise/summarise.py +5 -5
- sunholo/types.py +52 -0
- sunholo/utils/config.py +4 -3
- sunholo/utils/config_class.py +20 -20
- sunholo/utils/gcp.py +0 -3
- sunholo/vertex/extensions_class.py +4 -4
- {sunholo-0.114.2.dist-info → sunholo-0.115.0.dist-info}/METADATA +14 -10
- {sunholo-0.114.2.dist-info → sunholo-0.115.0.dist-info}/RECORD +23 -26
- sunholo/patches/__init__.py +0 -0
- sunholo/patches/langchain/__init__.py +0 -0
- sunholo/patches/langchain/lancedb.py +0 -219
- sunholo/patches/langchain/vertexai.py +0 -506
- {sunholo-0.114.2.dist-info → sunholo-0.115.0.dist-info}/LICENSE.txt +0 -0
- {sunholo-0.114.2.dist-info → sunholo-0.115.0.dist-info}/WHEEL +0 -0
- {sunholo-0.114.2.dist-info → sunholo-0.115.0.dist-info}/entry_points.txt +0 -0
- {sunholo-0.114.2.dist-info → sunholo-0.115.0.dist-info}/top_level.txt +0 -0
sunholo/patches/langchain/vertexai.py (file deleted)
@@ -1,506 +0,0 @@
-# pull request not merged yet with langchain
-# https://github.com/langchain-ai/langchain/pull/12723
-from __future__ import annotations
-
-from concurrent.futures import Executor, ThreadPoolExecutor
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    ClassVar,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Union,
-)
-
-from langchain.callbacks.manager import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
-from langchain.llms.base import BaseLLM, create_base_retry_decorator
-from langchain.pydantic_v1 import BaseModel, Field, root_validator
-from langchain.schema import (
-    Generation,
-    LLMResult,
-)
-from langchain.schema.output import GenerationChunk
-from langchain.utilities.vertexai import (
-    get_client_info,
-    init_vertexai,
-    raise_vertex_import_error,
-)
-
-if TYPE_CHECKING:
-    from google.cloud.aiplatform.gapic import (
-        PredictionServiceAsyncClient,
-        PredictionServiceClient,
-    )
-    from vertexai.language_models._language_models import (
-        TextGenerationResponse,
-        _LanguageModel,
-    )
-
-
-def _response_to_generation(
-    response: TextGenerationResponse,
-) -> GenerationChunk:
-    """Convert a stream response to a generation chunk."""
-    try:
-        generation_info = {
-            "is_blocked": response.is_blocked,
-            "safety_attributes": response.safety_attributes,
-        }
-    except Exception:
-        generation_info = None
-    return GenerationChunk(text=response.text, generation_info=generation_info)
-
-
-def is_codey_model(model_name: str) -> bool:
-    """Returns True if the model name is a Codey model.
-
-    Args:
-        model_name: The model name to check.
-
-    Returns: True if the model name is a Codey model.
-    """
-    return "code" in model_name
-
-
-def _create_retry_decorator(
-    llm: VertexAI,
-    *,
-    run_manager: Optional[
-        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
-    ] = None,
-) -> Callable[[Any], Any]:
-    import google.api_core
-
-    errors = [
-        google.api_core.exceptions.ResourceExhausted,
-        google.api_core.exceptions.ServiceUnavailable,
-        google.api_core.exceptions.Aborted,
-        google.api_core.exceptions.DeadlineExceeded,
-    ]
-    decorator = create_base_retry_decorator(
-        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
-    )
-    return decorator
-
-
-def completion_with_retry(
-    llm: VertexAI,
-    *args: Any,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
-
-    @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict(*args, **kwargs)
-
-    return _completion_with_retry(*args, **kwargs)
-
-
-def stream_completion_with_retry(
-    llm: VertexAI,
-    *args: Any,
-    run_manager: Optional[CallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
-
-    @retry_decorator
-    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return llm.client.predict_streaming(*args, **kwargs)
-
-    return _completion_with_retry(*args, **kwargs)
-
-
-async def acompletion_with_retry(
-    llm: VertexAI,
-    *args: Any,
-    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-    **kwargs: Any,
-) -> Any:
-    """Use tenacity to retry the completion call."""
-    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
-
-    @retry_decorator
-    async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
-        return await llm.client.predict_async(*args, **kwargs)
-
-    return await _acompletion_with_retry(*args, **kwargs)
-
-
-class _VertexAIBase(BaseModel):
-    project: Optional[str] = None
-    "The default GCP project to use when making Vertex API calls."
-    location: Optional[str] = None
-    "The location to use when making API calls from models in Model Garden"
-    request_parallelism: int = 5
-    "The amount of parallelism allowed for requests issued to VertexAI models. "
-    "Default is 5."
-    max_retries: int = 6
-    """The maximum number of retries to make when generating."""
-    task_executor: ClassVar[Optional[Executor]] = Field(default=None, exclude=True)
-    stop: Optional[List[str]] = None
-    "Optional list of stop words to use when generating."
-    model_name: Optional[str] = None
-    "Underlying model name."
-
-    @classmethod
-    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
-        if cls.task_executor is None:
-            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
-        return cls.task_executor
-
-
-class _VertexAICommon(_VertexAIBase):
-    client: "_LanguageModel" = None  #: :meta private:
-    model_name: str
-    "Underlying model name."
-    temperature: float = 0.0
-    "Sampling temperature, it controls the degree of randomness in token selection."
-    max_output_tokens: int = 128
-    "Token limit determines the maximum amount of text output from one prompt."
-    top_p: float = 0.95
-    "Tokens are selected from most probable to least until the sum of their "
-    "probabilities equals the top-p value. Top-p is ignored for Codey models."
-    top_k: int = 40
-    "How the model selects tokens for output, the next token is selected from "
-    "among the top-k most probable tokens. Top-k is ignored for Codey models."
-    credentials: Any = Field(default=None, exclude=True)
-    "The default custom credentials (google.auth.credentials.Credentials) to use "
-    "when making API calls. If not provided, credentials will be ascertained from "
-    "the environment."
-    n: int = 1
-    """How many completions to generate for each prompt."""
-    streaming: bool = False
-    """Whether to stream the results or not."""
-
-    @property
-    def _llm_type(self) -> str:
-        return "vertexai"
-
-    @property
-    def is_codey_model(self) -> bool:
-        return is_codey_model(self.model_name)
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {**{"model_name": self.model_name}, **self._default_params}
-
-    @property
-    def _default_params(self) -> Dict[str, Any]:
-        if self.is_codey_model:
-            return {
-                "temperature": self.temperature,
-                "max_output_tokens": self.max_output_tokens,
-            }
-        else:
-            return {
-                "temperature": self.temperature,
-                "max_output_tokens": self.max_output_tokens,
-                "top_k": self.top_k,
-                "top_p": self.top_p,
-                "candidate_count": self.n,
-            }
-
-    @classmethod
-    def _try_init_vertexai(cls, values: Dict) -> None:
-        allowed_params = ["project", "location", "credentials"]
-        params = {k: v for k, v in values.items() if k in allowed_params}
-        init_vertexai(**params)
-        return None
-
-    def _prepare_params(
-        self,
-        stop: Optional[List[str]] = None,
-        stream: bool = False,
-        **kwargs: Any,
-    ) -> dict:
-        stop_sequences = stop or self.stop
-        params_mapping = {"n": "candidate_count"}
-        params = {params_mapping.get(k, k): v for k, v in kwargs.items()}
-        params = {**self._default_params, "stop_sequences": stop_sequences, **params}
-        if stream or self.streaming:
-            params.pop("candidate_count")
-        return params
-
-
-class VertexAI(_VertexAICommon, BaseLLM):
-    """Google Vertex AI large language models."""
-
-    model_name: str = "text-bison"
-    "The name of the Vertex AI large language model."
-    tuned_model_name: Optional[str] = None
-    "The name of a tuned model. If provided, model_name is ignored."
-    "tuned_model_name should be in the format: "
-    "'projects/' + PROJECT_ID + '/locations/' + REGION +'/models/'+ model_id"
-
-    @classmethod
-    def is_lc_serializable(self) -> bool:
-        return True
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that the python package exists in environment."""
-        cls._try_init_vertexai(values)
-        tuned_model_name = values.get("tuned_model_name")
-        model_name = values["model_name"]
-        try:
-            if not is_codey_model(model_name):
-                from vertexai.preview.language_models import TextGenerationModel
-
-                if tuned_model_name:
-                    values["client"] = TextGenerationModel.get_tuned_model(
-                        tuned_model_name
-                    )
-                else:
-                    values["client"] = TextGenerationModel.from_pretrained(model_name)
-            else:
-                from vertexai.preview.language_models import CodeGenerationModel
-
-                if tuned_model_name:
-                    values["client"] = CodeGenerationModel.get_tuned_model(
-                        tuned_model_name
-                    )
-                else:
-                    values["client"] = CodeGenerationModel.from_pretrained(model_name)
-        except ImportError:
-            raise_vertex_import_error()
-
-        if values["streaming"] and values["n"] > 1:
-            raise ValueError("Only one candidate can be generated with streaming!")
-        return values
-
-    def get_num_tokens(self, text: str) -> int:
-        """Get the number of tokens present in the text.
-
-        Useful for checking if an input will fit in a model's context window.
-
-        Args:
-            text: The string input to tokenize.
-
-        Returns:
-            The integer number of tokens in the text.
-        """
-        try:
-            result = self.client.count_tokens(text)
-        except AttributeError:
-            raise NotImplementedError(
-                "Your google-cloud-aiplatform version didn't implement count_tokens."
-                "Please, install it with pip install google-cloud-aiplatform>=1.35.0"
-            )
-
-        return result.total_tokens
-
-    def _generate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        should_stream = stream if stream is not None else self.streaming
-        params = self._prepare_params(stop=stop, stream=should_stream, **kwargs)
-        generations = []
-        for prompt in prompts:
-            if should_stream:
-                generation = GenerationChunk(text="")
-                for chunk in self._stream(
-                    prompt, stop=stop, run_manager=run_manager, **kwargs
-                ):
-                    generation += chunk
-                generations.append([generation])
-            else:
-                res = completion_with_retry(
-                    self, prompt, run_manager=run_manager, **params
-                )
-                if self.is_codey_model:
-                    generations.append([_response_to_generation(res)])
-                else:
-                    generations.append(
-                        [_response_to_generation(r) for r in res.candidates]
-                    )
-        return LLMResult(generations=generations)
-
-    async def _agenerate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        params = self._prepare_params(stop=stop, **kwargs)
-        generations = []
-        for prompt in prompts:
-            res = await acompletion_with_retry(
-                self, prompt, run_manager=run_manager, **params
-            )
-            generations.append([_response_to_generation(r) for r in res.candidates])
-        return LLMResult(generations=generations)
-
-    def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        params = self._prepare_params(stop=stop, stream=True, **kwargs)
-        for stream_resp in stream_completion_with_retry(
-            self, prompt, run_manager=run_manager, **params
-        ):
-            chunk = _response_to_generation(stream_resp)
-            yield chunk
-            if run_manager:
-                run_manager.on_llm_new_token(
-                    chunk.text,
-                    chunk=chunk,
-                    verbose=self.verbose,
-                )
-
-
-class VertexAIModelGarden(_VertexAIBase, BaseLLM):
-    """Large language models served from Vertex AI Model Garden."""
-
-    client: "PredictionServiceClient" = None  #: :meta private:
-    async_client: "PredictionServiceAsyncClient" = None  #: :meta private:
-    endpoint_id: str
-    "A name of an endpoint where the model has been deployed."
-    allowed_model_args: Optional[List[str]] = None
-    """Allowed optional args to be passed to the model."""
-    prompt_arg: str = "prompt"
-    result_arg: str = "generated_text"
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that the python package exists in environment."""
-        try:
-            from google.api_core.client_options import ClientOptions
-            from google.cloud.aiplatform.gapic import (
-                PredictionServiceAsyncClient,
-                PredictionServiceClient,
-            )
-        except ImportError:
-            raise_vertex_import_error()
-
-        if values["project"] is None:
-            raise ValueError(
-                "A GCP project should be provided to run inference on Model Garden!"
-            )
-        if values["location"] is None:
-            raise ValueError(
-                "The location of the endpoint must be provided to run inference!"
-            )
-
-        client_options = ClientOptions(
-            api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
-        )
-        client_info = get_client_info(module="vertex-ai-model-garden")
-        values["client"] = PredictionServiceClient(
-            client_options=client_options, client_info=client_info
-        )
-        values["async_client"] = PredictionServiceAsyncClient(
-            client_options=client_options, client_info=client_info
-        )
-        return values
-
-    @property
-    def _llm_type(self) -> str:
-        return "vertexai_model_garden"
-
-    def _generate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                print(f"allowed_model_args: {kwargs.items()}")
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-            print(f"{instance}")
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
-        generations: List[List[Generation]] = []
-        for result in response.predictions:
-            generations.append([Generation(text=result)])
-        return LLMResult(generations=generations)
-
-    async def _agenerate(
-        self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> LLMResult:
-        """Run the LLM on the given prompt and input."""
-        try:
-            from google.protobuf import json_format
-            from google.protobuf.struct_pb2 import Value
-        except ImportError:
-            raise ImportError(
-                "protobuf package not found, please install it with"
-                " `pip install protobuf`"
-            )
-
-        instances = []
-        for prompt in prompts:
-            if self.allowed_model_args:
-                instance = {
-                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
-                }
-            else:
-                instance = {}
-            instance[self.prompt_arg] = prompt
-            instances.append(instance)
-
-
-        predict_instances = [
-            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
-        ]
-
-        endpoint = self.async_client.endpoint_path(
-            project=self.project, location=self.location, endpoint=self.endpoint_id
-        )
-        response = await self.async_client.predict(
-            endpoint=endpoint, instances=predict_instances
-        )
-        generations: List[List[Generation]] = []
-        for result in response.predictions:
-            generations.append([Generation(text=result)])
-        return LLMResult(generations=generations)
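For context on what the deleted patch module provided, below is a minimal usage sketch of the two classes defined in the removed file, based only on the fields and signatures visible in the hunk above. It assumes sunholo <= 0.114.2 (where `sunholo.patches.langchain.vertexai` still exists) plus a compatible `langchain` and `google-cloud-aiplatform` install; the project, endpoint, and prompt values are illustrative placeholders, not values taken from this diff.

```python
# Sketch only: exercises the patched classes that 0.115.0 removes.
# Requires sunholo<=0.114.2; the GCP project/endpoint values are placeholders.
from sunholo.patches.langchain.vertexai import VertexAI, VertexAIModelGarden

# Plain PaLM text model, using the retry plumbing shown in the removed file.
llm = VertexAI(model_name="text-bison", temperature=0.2, max_output_tokens=256)
print(llm("Summarise what a Python wheel file is."))  # BaseLLM call API returns a str

# Model served from a Vertex AI Model Garden endpoint; the request/response keys
# follow the prompt_arg / result_arg defaults defined in the removed class.
garden = VertexAIModelGarden(
    project="my-gcp-project",   # placeholder GCP project id
    location="us-central1",     # placeholder region
    endpoint_id="1234567890",   # placeholder deployed endpoint id
)
print(garden("Hello from Model Garden"))
```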