truefoundry 0.5.0rc6__py3-none-any.whl → 0.5.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of truefoundry has been flagged as potentially problematic.

Files changed (53)
  1. truefoundry/common/utils.py +73 -1
  2. truefoundry/deploy/__init__.py +5 -0
  3. truefoundry/deploy/cli/cli.py +2 -0
  4. truefoundry/deploy/cli/commands/__init__.py +1 -0
  5. truefoundry/deploy/cli/commands/deploy_init_command.py +22 -0
  6. truefoundry/deploy/lib/dao/application.py +2 -1
  7. truefoundry/deploy/v2/lib/patched_models.py +8 -0
  8. truefoundry/ml/__init__.py +15 -12
  9. truefoundry/ml/artifact/truefoundry_artifact_repo.py +8 -3
  10. truefoundry/ml/autogen/client/__init__.py +11 -0
  11. truefoundry/ml/autogen/client/api/mlfoundry_artifacts_api.py +161 -0
  12. truefoundry/ml/autogen/client/models/__init__.py +11 -0
  13. truefoundry/ml/autogen/client/models/artifact_version_manifest.py +2 -2
  14. truefoundry/ml/autogen/client/models/export_deployment_files_request_dto.py +82 -0
  15. truefoundry/ml/autogen/client/models/infer_method_name.py +34 -0
  16. truefoundry/ml/autogen/client/models/model_server.py +34 -0
  17. truefoundry/ml/autogen/client/models/model_version_environment.py +97 -0
  18. truefoundry/ml/autogen/client/models/model_version_manifest.py +14 -3
  19. truefoundry/ml/autogen/client/models/serialization_format.py +35 -0
  20. truefoundry/ml/autogen/client/models/sklearn_framework.py +31 -2
  21. truefoundry/ml/autogen/client/models/transformers_framework.py +2 -2
  22. truefoundry/ml/autogen/client/models/xg_boost_framework.py +20 -2
  23. truefoundry/ml/autogen/client_README.md +6 -0
  24. truefoundry/ml/autogen/entities/artifacts.py +65 -6
  25. truefoundry/ml/cli/commands/model_init.py +97 -0
  26. truefoundry/ml/cli/utils.py +34 -0
  27. truefoundry/ml/log_types/artifacts/model.py +48 -24
  28. truefoundry/ml/log_types/artifacts/utils.py +37 -1
  29. truefoundry/ml/mlfoundry_api.py +77 -79
  30. truefoundry/ml/mlfoundry_run.py +3 -31
  31. truefoundry/ml/model_framework.py +257 -3
  32. truefoundry/ml/validation_utils.py +2 -0
  33. {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1rc1.dist-info}/METADATA +2 -6
  34. {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1rc1.dist-info}/RECORD +36 -45
  35. truefoundry/deploy/function_service/__init__.py +0 -3
  36. truefoundry/deploy/function_service/__main__.py +0 -27
  37. truefoundry/deploy/function_service/app.py +0 -92
  38. truefoundry/deploy/function_service/build.py +0 -45
  39. truefoundry/deploy/function_service/remote/__init__.py +0 -6
  40. truefoundry/deploy/function_service/remote/context.py +0 -3
  41. truefoundry/deploy/function_service/remote/method.py +0 -67
  42. truefoundry/deploy/function_service/remote/remote.py +0 -144
  43. truefoundry/deploy/function_service/route.py +0 -137
  44. truefoundry/deploy/function_service/service.py +0 -113
  45. truefoundry/deploy/function_service/utils.py +0 -53
  46. truefoundry/langchain/__init__.py +0 -12
  47. truefoundry/langchain/deprecated.py +0 -302
  48. truefoundry/langchain/truefoundry_chat.py +0 -130
  49. truefoundry/langchain/truefoundry_embeddings.py +0 -171
  50. truefoundry/langchain/truefoundry_llm.py +0 -106
  51. truefoundry/langchain/utils.py +0 -44
  52. {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1rc1.dist-info}/WHEEL +0 -0
  53. {truefoundry-0.5.0rc6.dist-info → truefoundry-0.5.1rc1.dist-info}/entry_points.txt +0 -0
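The bulk of the removals is the deletion of the entire truefoundry.deploy.function_service and truefoundry.langchain packages (items 35-51). Any code still importing the deprecated LangChain wrappers will fail at import time after upgrading; a minimal, hypothetical guard (the fallback behavior is illustrative, not part of the package) could look like this:

# Hypothetical import guard for the upgrade: truefoundry.langchain exists
# only in truefoundry <= 0.5.0rc6 and is deleted in 0.5.1rc1.
try:
    from truefoundry.langchain import TrueFoundryChat  # removed in 0.5.1rc1
except ImportError:
    # Migrate to calling the TrueFoundry LLM Gateway directly
    # (see the sketches after each deleted file below).
    TrueFoundryChat = None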
truefoundry/langchain/deprecated.py (deleted)
@@ -1,302 +0,0 @@
- import enum
- import warnings
- from typing import Any, Dict, List, Mapping, Optional
- from urllib.parse import urljoin
-
- import requests
- from requests.auth import HTTPBasicAuth
-
- from truefoundry.deploy.lib.auth.servicefoundry_session import ServiceFoundrySession
- from truefoundry.pydantic_v1 import root_validator
-
- try:
-     from langchain.callbacks.manager import CallbackManagerForLLMRun
-     from langchain.llms.base import LLM
-     from langchain.llms.utils import enforce_stop_tokens
- except Exception as ex:
-     raise Exception(
-         "Failed to import langchain."
-         " Please install langchain by using `pip install langchain` command"
-     ) from ex
-
-
- class _ModelServerImpl(str, enum.Enum):
-     MLSERVER = "MLSERVER"
-     TGI = "TGI"
-     VLLM = "VLLM"
-
-
- def _get_model_server_and_validate_if_mlserver(endpoint_url, auth, model_name=None):
-     try:
-         response = requests.get(urljoin(endpoint_url, "info"), json={}, auth=auth)
-         if response.status_code == 200:
-             return _ModelServerImpl.TGI, None
-         elif response.status_code == 404:
-             # We are not using TGI, try for mlserver
-             response = requests.post(
-                 urljoin(endpoint_url, "v2/repository/index"), json={}, auth=auth
-             )
-             if response.status_code == 200:
-                 models = response.json()
-                 if len(models) == 0:
-                     raise ValueError("No model is deployed in the model server")
-                 model_names = [m.get("name") for m in models]
-                 if model_name and model_name not in model_names:
-                     raise ValueError(
-                         f"Model {model_name!r} is not available in the model server. "
-                         f"Available models {model_names!r}"
-                     )
-                 if not model_name and len(model_names) > 1:
-                     raise ValueError(
-                         f"Please pass `model_name` while instantiating `TruefoundryLLM`. "
-                         f"Available models are {model_names!r} "
-                     )
-                 if model_name:
-                     return _ModelServerImpl.MLSERVER, model_name
-                 return _ModelServerImpl.MLSERVER, model_names[0]
-             if response.status_code == 404:
-                 return _ModelServerImpl.VLLM, None
-         response.raise_for_status()
-     except Exception as e:
-         raise Exception(f"Error raised by inference API: {e}") from e
-
-
- # TODO (chiragjn): Refactor this into separate implementations for each model server
-
-
- class TruefoundryLLM(LLM):
-     """Wrapper around TFY model deployment.
-     To use this class, you need to have the langchain library installed.
-     Example:
-         .. code-block:: python
-             from truefoundry.langchain import TruefoundryLLM
-             endpoint_url = (
-                 "https://pythia-70m-model-model-catalogue.demo2.truefoundry.tech"
-             )
-             model = TruefoundryLLM(
-                 endpoint_url=endpoint_url,
-                 parameters={
-                     "max_new_tokens": 100,
-                     "temperature": 0.7,
-                     "top_k": 5,
-                     "top_p": 0.9
-                 }
-             )
-     """
-
-     endpoint_url: str
-     model_name: Optional[str] = None
-     auth: Optional[HTTPBasicAuth] = None
-     parameters: Optional[Dict[str, Any]] = None
-     model_server_impl: Optional[_ModelServerImpl] = None
-
-     @root_validator(pre=False)
-     def validate_model_server_and_name(cls, values: Dict):
-         warnings.warn(
-             message=f"{cls.__name__} is deprecated and will be removed soon. Please use `TrueFoundryLLM` or `TrueFoundryChat` to invoke models using the new TrueFoundry LLM Gateway",
-             category=DeprecationWarning,
-             stacklevel=2,
-         )
-         endpoint_url = values["endpoint_url"]
-         model_name = values.get("model_name")
-         auth = values.get("auth")
-         model_server_impl, model_name = _get_model_server_and_validate_if_mlserver(
-             endpoint_url=endpoint_url, model_name=model_name, auth=auth
-         )
-         values["model_server_impl"] = model_server_impl
-         if model_server_impl == _ModelServerImpl.MLSERVER:
-             values["model_name"] = model_name
-         return values
-
-     @property
-     def _identifying_params(self) -> Mapping[str, Any]:
-         """Get the identifying parameters."""
-         return {
-             "endpoint_url": self.endpoint_url,
-             "model_name": self.model_name,
-         }
-
-     @property
-     def _llm_type(self) -> str:
-         """Return type of llm."""
-         return "tfy_model_deployment"
-
-     def _call(  # noqa: C901
-         self,
-         prompt: str,
-         stop: Optional[List[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **params: Any,
-     ) -> str:
-         """Call out to the deployed model
-         Args:
-             prompt: The prompt to pass into the model.
-             stop: Optional list of stop words to use when generating.
-         Returns:
-             The string generated by the model.
-         Example:
-             .. code-block:: python
-                 response = model("Tell me a joke.")
-         """
-         _params_already_set = self.parameters or {}
-         params = {**_params_already_set, **params, "return_full_text": False}
-
-         if self.model_server_impl == _ModelServerImpl.MLSERVER:
-             generate_path = f"v2/models/{self.model_name}/infer/simple"
-             payload = {"inputs": prompt, "parameters": params}
-         elif self.model_server_impl == _ModelServerImpl.TGI:
-             generate_path = "generate"
-             payload = {"inputs": prompt, "parameters": params}
-         elif self.model_server_impl == _ModelServerImpl.VLLM:
-             generate_path = "generate"
-             payload = {**params, "prompt": prompt}
-         else:
-             raise ValueError(f"No known generate path for {self.model_server_impl}")
-         url = urljoin(self.endpoint_url, generate_path)
-
-         try:
-             response = requests.post(url, json=payload, auth=self.auth)
-             response.raise_for_status()
-         except Exception as e:
-             raise Exception(f"Error raised by inference API: {e}") from e
-         response_dict = response.json()
-         if "error" in response_dict:
-             raise ValueError(
-                 f"Error raised by inference API: {response_dict['error']!r}"
-             )
-
-         if self.model_server_impl == _ModelServerImpl.MLSERVER:
-             inference_result = response_dict[0]
-         elif self.model_server_impl == _ModelServerImpl.TGI:
-             inference_result = response_dict
-         elif self.model_server_impl == _ModelServerImpl.VLLM:
-             inference_result = response_dict
-         else:
-             raise ValueError(
-                 f"Unknown model server {self.model_server_impl}, cannot parse response"
-             )
-
-         if "generated_text" in inference_result:
-             text = inference_result["generated_text"]
-         elif "summarization" in inference_result:
-             text = inference_result["summary_text"]
-         elif "text" in inference_result:
-             text = inference_result["text"]
-         else:
-             raise ValueError(f"Could not parse inference response: {response_dict!r}")
-
-         if isinstance(text, list):
-             text = text[0]
-
-         if stop:
-             text = enforce_stop_tokens(text, stop)
-
-         return text
-
-
- class TruefoundryPlaygroundLLM(LLM):
-     """Wrapper around TFY Playground.
-     To use this class, you need to have the langchain library installed.
-     Example:
-         .. code-block:: python
-             from truefoundry.langchain import TruefoundryPlaygroundLLM
-             import os
-             # Note: Login using tfy login --host <https://example-domain.com>
-             model = TruefoundryPlaygroundLLM(
-                 model_name="vicuna-13b",
-                 parameters={
-                     "maximumLength": 100,
-                     "temperature": 0.7,
-                     "topP": 0.9,
-                     "repetitionPenalty": 1
-                 }
-             )
-             response = model.predict("Enter the prompt here")
-     """
-
-     model_name: str
-     parameters: Optional[Dict[str, Any]] = None
-     provider: str = "truefoundry-public"
-
-     @root_validator(pre=False)
-     def validate_model_server_and_name(cls, values: Dict):
-         warnings.warn(
-             message=f"{cls.__name__} is deprecated and will be removed soon. Please use `TrueFoundryLLM` or `TrueFoundryChat` to invoke models using the new TrueFoundry LLM Gateway",
-             category=DeprecationWarning,
-             stacklevel=2,
-         )
-         return values
-
-     @property
-     def _get_model(self) -> str:
-         """returns the model name"""
-         return self.model_name
-
-     @property
-     def _get_provider(self) -> str:
-         """Returns the provider name"""
-         return self.provider
-
-     @property
-     def _llm_type(self) -> str:
-         """Return type of llm."""
-         return "tfy_playground"
-
-     def _call(
-         self,
-         prompt: str,
-         stop: Optional[List[str]] = None,
-         **params: Any,
-     ) -> str:
-         """Call out to the deployed model
-         Args:
-             prompt: The prompt to pass into the model.
-             stop: Optional list of stop words to use when generating.
-         Returns:
-             The string generated by the model.
-         Example:
-             .. code-block:: python
-                 response = model("I have a joke for you...")
-         """
-         _params_already_set = self.parameters or {}
-         params = {**_params_already_set, **params}
-         if stop:
-             params["stopSequences"] = stop
-         session = ServiceFoundrySession()
-
-         if not session:
-             raise Exception(
-                 "Unauthenticated: Please login using tfy login --host <https://example-domain.com>"
-             )
-
-         host = session.base_url
-
-         if host[-1] == "/":
-             host = host[: len(host) - 1]
-
-         url = f"{host}/llm-playground/api/inference/text"
-         headers = {"Authorization": f"Bearer {session.access_token}"}
-
-         json = {
-             "prompt": prompt,
-             "models": [
-                 {
-                     "name": self.model_name,
-                     "provider": self.provider,
-                     "tag": self.model_name,
-                     "parameters": params,
-                 }
-             ],
-         }
-
-         try:
-             response = requests.post(url=url, headers=headers, json=json)
-             response.raise_for_status()
-         except Exception as ex:
-             raise Exception(f"Error inferencing the model: {ex}") from ex
-
-         data = response.json()
-         text = data[0].get("text")
-         if stop:
-             text = enforce_stop_tokens(text, stop)
-         return text
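For reference, the deleted TruefoundryLLM resolved which serving backend it was pointed at by probing well-known routes: a 200 from GET <endpoint>/info meant TGI, otherwise a 200 from POST <endpoint>/v2/repository/index meant MLServer, and a 404 there meant vLLM. A condensed sketch of that probe, using only the routes visible in the diff (model-name validation and error handling trimmed):

import enum
from urllib.parse import urljoin

import requests


class ModelServer(str, enum.Enum):
    MLSERVER = "MLSERVER"
    TGI = "TGI"
    VLLM = "VLLM"


def detect_model_server(endpoint_url: str, auth=None) -> ModelServer:
    # TGI exposes GET /info; a 200 here identifies it directly.
    if requests.get(urljoin(endpoint_url, "info"), auth=auth).status_code == 200:
        return ModelServer.TGI
    # MLServer exposes the V2 repository index; a 200 here identifies it.
    response = requests.post(
        urljoin(endpoint_url, "v2/repository/index"), json={}, auth=auth
    )
    if response.status_code == 200:
        return ModelServer.MLSERVER
    # Neither route exists: the deleted code treated this case as vLLM.
    return ModelServer.VLLM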
truefoundry/langchain/truefoundry_chat.py (deleted)
@@ -1,130 +0,0 @@
- from typing import Any, Dict, List, Optional
-
- from langchain.chat_models.base import SimpleChatModel
- from langchain.pydantic_v1 import Extra, Field, root_validator
- from langchain.schema.messages import (
-     AIMessage,
-     BaseMessage,
-     ChatMessage,
-     HumanMessage,
-     SystemMessage,
- )
-
- from truefoundry.common.request_utils import requests_retry_session
- from truefoundry.langchain.utils import (
-     validate_tfy_environment,
- )
- from truefoundry.logger import logger
-
-
- class TrueFoundryChat(SimpleChatModel):
-     """`TrueFoundry LLM Gateway` chat models API.
-
-     To use, you must have the environment variable ``TFY_API_KEY`` set with your API key and ``TFY_HOST`` set with your host or pass it as a named parameter to the constructor.
-     """
-
-     model: str = Field(description="The model to use for chat.")
-     """The model to use for chat."""
-     tfy_llm_gateway_url: Optional[str] = Field(default=None)
-     """TrueFoundry LLM Gateway endpoint URL. Automatically inferred from env var `TFY_LLM_GATEWAY_URL` if not provided."""
-     tfy_api_key: Optional[str] = Field(default=None)
-     """TrueFoundry API Key. Automatically inferred from env var `TFY_API_KEY` if not provided."""
-     model_parameters: Optional[dict] = Field(default_factory=dict)
-     """Model parameters"""
-     request_timeout: int = Field(default=30)
-     """The timeout for the request in seconds."""
-     max_retries: int = Field(default=5)
-     """The number of retries for HTTP requests."""
-     retry_backoff_factor: float = Field(default=0.3)
-     """The backoff factor for exponential backoff during retries."""
-     system_prompt: str = Field(default="You are a AI assistant")
-
-     class Config:
-         """Configuration for this pydantic object."""
-
-         extra = Extra.forbid
-         allow_population_by_field_name = True
-
-     @root_validator()
-     def validate_environment(cls, values: Dict) -> Dict:
-         values = validate_tfy_environment(values)
-         if not values["tfy_api_key"]:
-             raise ValueError(
-                 "Did not find `tfy_api_key`, please add an environment variable"
-                 " `TFY_API_KEY` which contains it, or pass"
-                 " `tfy_api_key` as a named parameter."
-             )
-         if not values["tfy_llm_gateway_url"]:
-             raise ValueError(
-                 "Did not find `tfy_llm_gateway_url`, please add an environment variable"
-                 " `TFY_LLM_GATEWAY_URL` which contains it, or pass"
-                 " `tfy_llm_gateway_url` as a named parameter."
-             )
-         return values
-
-     @property
-     def _llm_type(self) -> str:
-         """Return type of chat model."""
-         return "truefoundry-chat"
-
-     def _call(
-         self,
-         messages: List[BaseMessage],
-         stop: Optional[List[str]] = None,
-         **kwargs: Any,
-     ) -> str:
-         if len(messages) == 0:
-             raise ValueError("No messages provided to chat.")
-
-         if not isinstance(messages[0], SystemMessage):
-             messages.insert(0, SystemMessage(content=self.system_prompt))
-
-         message_dicts = [
-             TrueFoundryChat._convert_message_to_dict(message) for message in messages
-         ]
-
-         payload = {**self.model_parameters} if self.model_parameters else {}
-
-         if stop:
-             payload["stop_sequences"] = stop
-
-         payload["messages"] = message_dicts
-         payload["model"] = self.model
-
-         session = requests_retry_session(
-             retries=self.max_retries, backoff_factor=self.retry_backoff_factor
-         )
-
-         url = f"{self.tfy_llm_gateway_url}/openai/chat/completions"
-         logger.debug(f"Chat using - model: {self.model} at endpoint: {url}")
-         response = session.post(
-             url=url,
-             json=payload,
-             headers={
-                 "Authorization": f"Bearer {self.tfy_api_key}",
-             },
-             timeout=self.request_timeout,
-         )
-         response.raise_for_status()
-         output = response.json()
-         return output["choices"][0]["message"]["content"]
-
-     @staticmethod
-     def _convert_message_to_dict(message: BaseMessage) -> dict:
-         if isinstance(message, ChatMessage):
-             message_dict = {"role": message.role, "content": message.content}
-         elif isinstance(message, HumanMessage):
-             message_dict = {"role": "user", "content": message.content}
-         elif isinstance(message, AIMessage):
-             message_dict = {"role": "assistant", "content": message.content}
-         elif isinstance(message, SystemMessage):
-             message_dict = {"role": "system", "content": message.content}
-         else:
-             raise ValueError(f"Got unknown message type: {message}")
-         if message.additional_kwargs:
-             logger.debug(
-                 "Additional message arguments are unsupported by TrueFoundry LLM Gateway "
-                 " and will be ignored: %s",
-                 message.additional_kwargs,
-             )
-         return message_dict
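TrueFoundryChat was a thin wrapper: it converted messages to OpenAI-style dicts and POSTed them to {tfy_llm_gateway_url}/openai/chat/completions with a bearer token. The same call without LangChain is a single request; a minimal sketch assuming the TFY_LLM_GATEWAY_URL and TFY_API_KEY environment variables the class documented (the model id is a placeholder):

import os

import requests

gateway_url = os.environ["TFY_LLM_GATEWAY_URL"]
payload = {
    "model": "my-model",  # placeholder model id
    "messages": [
        {"role": "system", "content": "You are an AI assistant"},
        {"role": "user", "content": "Tell me a joke."},
    ],
}
# Same endpoint, auth header, and payload shape as the deleted _call().
response = requests.post(
    url=f"{gateway_url}/openai/chat/completions",
    json=payload,
    headers={"Authorization": f"Bearer {os.environ['TFY_API_KEY']}"},
    timeout=30,
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])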
truefoundry/langchain/truefoundry_embeddings.py (deleted)
@@ -1,171 +0,0 @@
- import concurrent.futures
- import math
- from typing import Dict, List, Optional
-
- import tqdm
- from langchain.embeddings.base import Embeddings
- from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
-
- from truefoundry.common.request_utils import requests_retry_session
- from truefoundry.langchain.utils import (
-     validate_tfy_environment,
- )
- from truefoundry.logger import logger
-
- EMBEDDER_BATCH_SIZE = 32
- PARALLEL_WORKERS = 4
-
-
- class TrueFoundryEmbeddings(BaseModel, Embeddings):
-     """`TrueFoundry LLM Gateway` embedding models API.
-
-     To use, you must have the environment variable ``TFY_API_KEY`` set with your API key and ``TFY_HOST`` set with your host or pass it
-     as a named parameter to the constructor.
-     """
-
-     model: str = Field(description="The model to use for embedding.")
-     """The model to use for embedding."""
-     tfy_llm_gateway_url: Optional[str] = Field(default=None)
-     """TrueFoundry LLM Gateway endpoint URL. Automatically inferred from env var `TFY_LLM_GATEWAY_URL` if not provided."""
-     tfy_api_key: Optional[str] = Field(default=None)
-     """TrueFoundry API Key. Automatically inferred from env var `TFY_API_KEY` if not provided."""
-     model_parameters: Optional[dict] = Field(default_factory=dict)
-     """Model parameters"""
-     request_timeout: int = Field(default=30)
-     """The timeout for the request in seconds."""
-     max_retries: int = Field(default=5)
-     """The number of retries for HTTP requests."""
-     retry_backoff_factor: float = Field(default=0.3)
-     """The backoff factor for exponential backoff during retries."""
-     batch_size: int = Field(default=EMBEDDER_BATCH_SIZE)
-     """The batch size to use for embedding."""
-     parallel_workers: int = Field(default=PARALLEL_WORKERS)
-     """The number of parallel workers to use for embedding."""
-
-     __private_attributes__ = {"_executor"}
-
-     class Config:
-         """Configuration for this pydantic object."""
-
-         extra = Extra.forbid
-         allow_population_by_field_name = True
-
-     @root_validator()
-     def validate_environment(cls, values: Dict) -> Dict:
-         values = validate_tfy_environment(values)
-         if not values["tfy_api_key"]:
-             raise ValueError(
-                 "Did not find `tfy_api_key`, please add an environment variable"
-                 " `TFY_API_KEY` which contains it, or pass"
-                 " `tfy_api_key` as a named parameter."
-             )
-         if not values["tfy_llm_gateway_url"]:
-             raise ValueError(
-                 "Did not find `tfy_llm_gateway_url`, please add an environment variable"
-                 " `TFY_LLM_GATEWAY_URL` which contains it, or pass"
-                 " `tfy_llm_gateway_url` as a named parameter."
-             )
-         return values
-
-     def _init_private_attributes(self):
-         self._executor = concurrent.futures.ThreadPoolExecutor(
-             max_workers=self.parallel_workers
-         )
-
-     @property
-     def _llm_type(self) -> str:
-         """Return type of embedding model."""
-         return "truefoundry-embeddings"
-
-     def __del__(self):
-         """
-         Destructor method to clean up the executor when the object is deleted.
-
-         Returns:
-             None
-         """
-         self._executor.shutdown()
-
-     def _remote_embed(self, texts, query_mode=False):
-         """
-         Perform remote embedding using a HTTP POST request to a designated endpoint.
-
-         Args:
-             texts (List[str]): A list of text strings to be embedded.
-             query_mode (bool): A flag to indicate if running in query mode or in embed mode (indexing).
-         Returns:
-             List[List[float]]: A list of embedded representations of the input texts.
-         """
-         session = requests_retry_session(
-             retries=self.max_retries, backoff_factor=self.retry_backoff_factor
-         )
-
-         payload = {
-             "input": texts,
-             "model": self.model,
-         }
-
-         url = f"{self.tfy_llm_gateway_url}/openai/embeddings"
-         logger.debug(
-             f"Embedding using - model: {self.model} at endpoint: {url}, for {len(texts)} texts"
-         )
-         response = session.post(
-             url=url,
-             json=payload,
-             headers={
-                 "Authorization": f"Bearer {self.tfy_api_key}",
-             },
-             timeout=self.request_timeout,
-         )
-         response.raise_for_status()
-         output = response.json()
-         return [data["embedding"] for data in output["data"]]
-
-     def _embed(self, texts: List[str], query_mode: bool):
-         """
-         Perform embedding on a list of texts using remote embedding in chunks.
-
-         Args:
-             texts (List[str]): A list of text strings to be embedded.
-             query_mode (bool): A flag to indicate if running in query mode or in embed mode (indexing).
-         Returns:
-             List[List[float]]: A list of embedded representations of the input texts.
-         """
-         embeddings = []
-
-         def _feeder():
-             for i in range(0, len(texts), self.batch_size):
-                 chunk = texts[i : i + self.batch_size]
-                 yield chunk
-
-         embeddings = list(
-             tqdm.tqdm(
-                 self._executor.map(self._remote_embed, _feeder()),
-                 total=int(math.ceil(len(texts) / self.batch_size)),
-             )
-         )
-         return [item for batch in embeddings for item in batch]
-
-     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-         """
-         Embed a list of text documents.
-
-         Args:
-             texts (List[str]): A list of text documents to be embedded.
-
-         Returns:
-             List[List[float]]: A list of embedded representations of the input documents.
-         """
-         return self._embed(texts, query_mode=False)
-
-     def embed_query(self, text: str) -> List[float]:
-         """
-         Embed a query text.
-
-         Args:
-             text (str): The query text to be embedded.
-
-         Returns:
-             List[float]: The embedded representation of the input query text.
-         """
-         return self._embed([text], query_mode=True)[0]
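TrueFoundryEmbeddings follows the same pattern for embeddings: inputs are chunked into batches of batch_size (default 32), each batch is POSTed to {tfy_llm_gateway_url}/openai/embeddings as {"input": [...], "model": "..."}, and vectors are read back from data[*].embedding. A minimal sequential sketch of that contract, without the pydantic and thread-pool machinery (same environment variables as above; the model id is a placeholder):

import os
from typing import List

import requests

BATCH_SIZE = 32  # matches EMBEDDER_BATCH_SIZE in the deleted module


def embed_texts(texts: List[str], model: str = "my-embedding-model") -> List[List[float]]:
    gateway_url = os.environ["TFY_LLM_GATEWAY_URL"]
    headers = {"Authorization": f"Bearer {os.environ['TFY_API_KEY']}"}
    vectors: List[List[float]] = []
    # Chunk the inputs the same way the deleted _embed() did, one request per batch.
    for i in range(0, len(texts), BATCH_SIZE):
        batch = texts[i : i + BATCH_SIZE]
        response = requests.post(
            url=f"{gateway_url}/openai/embeddings",
            json={"input": batch, "model": model},
            headers=headers,
            timeout=30,
        )
        response.raise_for_status()
        # Each item in "data" carries one embedding vector.
        vectors.extend(item["embedding"] for item in response.json()["data"])
    return vectors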