lionagi 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as published to their public registries. It is provided for informational purposes only.
- lionagi/__init__.py +60 -5
- lionagi/core/__init__.py +0 -25
- lionagi/core/_setting/_setting.py +59 -0
- lionagi/core/action/__init__.py +14 -0
- lionagi/core/action/function_calling.py +136 -0
- lionagi/core/action/manual.py +1 -0
- lionagi/core/action/node.py +109 -0
- lionagi/core/action/tool.py +114 -0
- lionagi/core/action/tool_manager.py +356 -0
- lionagi/core/agent/base_agent.py +27 -13
- lionagi/core/agent/eval/evaluator.py +1 -0
- lionagi/core/agent/eval/vote.py +40 -0
- lionagi/core/agent/learn/learner.py +59 -0
- lionagi/core/agent/plan/unit_template.py +1 -0
- lionagi/core/collections/__init__.py +17 -0
- lionagi/core/{generic/data_logger.py → collections/_logger.py} +69 -55
- lionagi/core/collections/abc/__init__.py +53 -0
- lionagi/core/collections/abc/component.py +615 -0
- lionagi/core/collections/abc/concepts.py +297 -0
- lionagi/core/collections/abc/exceptions.py +150 -0
- lionagi/core/collections/abc/util.py +45 -0
- lionagi/core/collections/exchange.py +161 -0
- lionagi/core/collections/flow.py +426 -0
- lionagi/core/collections/model.py +419 -0
- lionagi/core/collections/pile.py +913 -0
- lionagi/core/collections/progression.py +236 -0
- lionagi/core/collections/util.py +64 -0
- lionagi/core/director/direct.py +314 -0
- lionagi/core/director/director.py +2 -0
- lionagi/core/{execute/branch_executor.py → engine/branch_engine.py} +134 -97
- lionagi/core/{execute/instruction_map_executor.py → engine/instruction_map_engine.py} +80 -55
- lionagi/{experimental/directive/evaluator → core/engine}/script_engine.py +17 -1
- lionagi/core/executor/base_executor.py +90 -0
- lionagi/core/{execute/structure_executor.py → executor/graph_executor.py} +62 -66
- lionagi/core/{execute → executor}/neo4j_executor.py +70 -67
- lionagi/core/generic/__init__.py +3 -33
- lionagi/core/generic/edge.py +29 -79
- lionagi/core/generic/edge_condition.py +16 -0
- lionagi/core/generic/graph.py +236 -0
- lionagi/core/generic/hyperedge.py +1 -0
- lionagi/core/generic/node.py +156 -221
- lionagi/core/generic/tree.py +48 -0
- lionagi/core/generic/tree_node.py +79 -0
- lionagi/core/mail/__init__.py +12 -0
- lionagi/core/mail/mail.py +25 -0
- lionagi/core/mail/mail_manager.py +139 -58
- lionagi/core/mail/package.py +45 -0
- lionagi/core/mail/start_mail.py +36 -0
- lionagi/core/message/__init__.py +19 -0
- lionagi/core/message/action_request.py +133 -0
- lionagi/core/message/action_response.py +135 -0
- lionagi/core/message/assistant_response.py +95 -0
- lionagi/core/message/instruction.py +234 -0
- lionagi/core/message/message.py +101 -0
- lionagi/core/message/system.py +86 -0
- lionagi/core/message/util.py +283 -0
- lionagi/core/report/__init__.py +4 -0
- lionagi/core/report/base.py +217 -0
- lionagi/core/report/form.py +231 -0
- lionagi/core/report/report.py +166 -0
- lionagi/core/report/util.py +28 -0
- lionagi/core/rule/_default.py +16 -0
- lionagi/core/rule/action.py +99 -0
- lionagi/core/rule/base.py +238 -0
- lionagi/core/rule/boolean.py +56 -0
- lionagi/core/rule/choice.py +47 -0
- lionagi/core/rule/mapping.py +96 -0
- lionagi/core/rule/number.py +71 -0
- lionagi/core/rule/rulebook.py +109 -0
- lionagi/core/rule/string.py +52 -0
- lionagi/core/rule/util.py +35 -0
- lionagi/core/session/branch.py +431 -0
- lionagi/core/session/directive_mixin.py +287 -0
- lionagi/core/session/session.py +229 -903
- lionagi/core/structure/__init__.py +1 -0
- lionagi/core/structure/chain.py +1 -0
- lionagi/core/structure/forest.py +1 -0
- lionagi/core/structure/graph.py +1 -0
- lionagi/core/structure/tree.py +1 -0
- lionagi/core/unit/__init__.py +5 -0
- lionagi/core/unit/parallel_unit.py +245 -0
- lionagi/core/unit/template/action.py +81 -0
- lionagi/core/unit/template/base.py +51 -0
- lionagi/core/unit/template/plan.py +84 -0
- lionagi/core/unit/template/predict.py +109 -0
- lionagi/core/unit/template/score.py +124 -0
- lionagi/core/unit/template/select.py +104 -0
- lionagi/core/unit/unit.py +362 -0
- lionagi/core/unit/unit_form.py +305 -0
- lionagi/core/unit/unit_mixin.py +1168 -0
- lionagi/core/unit/util.py +71 -0
- lionagi/core/validator/validator.py +364 -0
- lionagi/core/work/work.py +74 -0
- lionagi/core/work/work_function.py +92 -0
- lionagi/core/work/work_queue.py +81 -0
- lionagi/core/work/worker.py +195 -0
- lionagi/core/work/worklog.py +124 -0
- lionagi/experimental/compressor/base.py +46 -0
- lionagi/experimental/compressor/llm_compressor.py +247 -0
- lionagi/experimental/compressor/llm_summarizer.py +61 -0
- lionagi/experimental/compressor/util.py +70 -0
- lionagi/experimental/directive/__init__.py +19 -0
- lionagi/experimental/directive/parser/base_parser.py +69 -2
- lionagi/experimental/directive/{template_ → template}/base_template.py +17 -1
- lionagi/{libs/ln_tokenizer.py → experimental/directive/tokenizer.py} +16 -0
- lionagi/experimental/{directive/evaluator → evaluator}/ast_evaluator.py +16 -0
- lionagi/experimental/{directive/evaluator → evaluator}/base_evaluator.py +16 -0
- lionagi/experimental/knowledge/base.py +10 -0
- lionagi/experimental/memory/__init__.py +0 -0
- lionagi/experimental/strategies/__init__.py +0 -0
- lionagi/experimental/strategies/base.py +1 -0
- lionagi/integrations/bridge/langchain_/documents.py +4 -0
- lionagi/integrations/bridge/llamaindex_/index.py +30 -0
- lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +6 -0
- lionagi/integrations/chunker/chunk.py +161 -24
- lionagi/integrations/config/oai_configs.py +34 -3
- lionagi/integrations/config/openrouter_configs.py +14 -2
- lionagi/integrations/loader/load.py +122 -21
- lionagi/integrations/loader/load_util.py +6 -77
- lionagi/integrations/provider/_mapping.py +46 -0
- lionagi/integrations/provider/litellm.py +2 -1
- lionagi/integrations/provider/mlx_service.py +16 -9
- lionagi/integrations/provider/oai.py +91 -4
- lionagi/integrations/provider/ollama.py +6 -5
- lionagi/integrations/provider/openrouter.py +115 -8
- lionagi/integrations/provider/services.py +2 -2
- lionagi/integrations/provider/transformers.py +18 -22
- lionagi/integrations/storage/__init__.py +3 -3
- lionagi/integrations/storage/neo4j.py +52 -60
- lionagi/integrations/storage/storage_util.py +44 -46
- lionagi/integrations/storage/structure_excel.py +43 -26
- lionagi/integrations/storage/to_excel.py +11 -4
- lionagi/libs/__init__.py +22 -1
- lionagi/libs/ln_api.py +75 -20
- lionagi/libs/ln_context.py +37 -0
- lionagi/libs/ln_convert.py +21 -9
- lionagi/libs/ln_func_call.py +69 -28
- lionagi/libs/ln_image.py +107 -0
- lionagi/libs/ln_nested.py +26 -11
- lionagi/libs/ln_parse.py +82 -23
- lionagi/libs/ln_queue.py +16 -0
- lionagi/libs/ln_tokenize.py +164 -0
- lionagi/libs/ln_validate.py +16 -0
- lionagi/libs/special_tokens.py +172 -0
- lionagi/libs/sys_util.py +95 -24
- lionagi/lions/coder/code_form.py +13 -0
- lionagi/lions/coder/coder.py +50 -3
- lionagi/lions/coder/util.py +30 -25
- lionagi/tests/libs/test_func_call.py +23 -21
- lionagi/tests/libs/test_nested.py +36 -21
- lionagi/tests/libs/test_parse.py +1 -1
- lionagi/tests/test_core/collections/__init__.py +0 -0
- lionagi/tests/test_core/collections/test_component.py +206 -0
- lionagi/tests/test_core/collections/test_exchange.py +138 -0
- lionagi/tests/test_core/collections/test_flow.py +145 -0
- lionagi/tests/test_core/collections/test_pile.py +171 -0
- lionagi/tests/test_core/collections/test_progression.py +129 -0
- lionagi/tests/test_core/generic/test_edge.py +67 -0
- lionagi/tests/test_core/generic/test_graph.py +96 -0
- lionagi/tests/test_core/generic/test_node.py +106 -0
- lionagi/tests/test_core/generic/test_tree_node.py +73 -0
- lionagi/tests/test_core/test_branch.py +115 -294
- lionagi/tests/test_core/test_form.py +46 -0
- lionagi/tests/test_core/test_report.py +105 -0
- lionagi/tests/test_core/test_validator.py +111 -0
- lionagi/version.py +1 -1
- lionagi-0.2.0.dist-info/LICENSE +202 -0
- lionagi-0.2.0.dist-info/METADATA +272 -0
- lionagi-0.2.0.dist-info/RECORD +240 -0
- lionagi/core/branch/base.py +0 -653
- lionagi/core/branch/branch.py +0 -474
- lionagi/core/branch/flow_mixin.py +0 -96
- lionagi/core/branch/util.py +0 -323
- lionagi/core/direct/__init__.py +0 -19
- lionagi/core/direct/cot.py +0 -123
- lionagi/core/direct/plan.py +0 -164
- lionagi/core/direct/predict.py +0 -166
- lionagi/core/direct/react.py +0 -171
- lionagi/core/direct/score.py +0 -279
- lionagi/core/direct/select.py +0 -170
- lionagi/core/direct/sentiment.py +0 -1
- lionagi/core/direct/utils.py +0 -110
- lionagi/core/direct/vote.py +0 -64
- lionagi/core/execute/base_executor.py +0 -47
- lionagi/core/flow/baseflow.py +0 -23
- lionagi/core/flow/monoflow/ReAct.py +0 -240
- lionagi/core/flow/monoflow/__init__.py +0 -9
- lionagi/core/flow/monoflow/chat.py +0 -95
- lionagi/core/flow/monoflow/chat_mixin.py +0 -253
- lionagi/core/flow/monoflow/followup.py +0 -215
- lionagi/core/flow/polyflow/__init__.py +0 -1
- lionagi/core/flow/polyflow/chat.py +0 -251
- lionagi/core/form/action_form.py +0 -26
- lionagi/core/form/field_validator.py +0 -287
- lionagi/core/form/form.py +0 -302
- lionagi/core/form/mixin.py +0 -214
- lionagi/core/form/scored_form.py +0 -13
- lionagi/core/generic/action.py +0 -26
- lionagi/core/generic/component.py +0 -532
- lionagi/core/generic/condition.py +0 -46
- lionagi/core/generic/mail.py +0 -90
- lionagi/core/generic/mailbox.py +0 -36
- lionagi/core/generic/relation.py +0 -70
- lionagi/core/generic/signal.py +0 -22
- lionagi/core/generic/structure.py +0 -362
- lionagi/core/generic/transfer.py +0 -20
- lionagi/core/generic/work.py +0 -40
- lionagi/core/graph/graph.py +0 -126
- lionagi/core/graph/tree.py +0 -190
- lionagi/core/mail/schema.py +0 -63
- lionagi/core/messages/schema.py +0 -325
- lionagi/core/tool/__init__.py +0 -5
- lionagi/core/tool/tool.py +0 -28
- lionagi/core/tool/tool_manager.py +0 -283
- lionagi/experimental/report/form.py +0 -64
- lionagi/experimental/report/report.py +0 -138
- lionagi/experimental/report/util.py +0 -47
- lionagi/experimental/tool/function_calling.py +0 -43
- lionagi/experimental/tool/manual.py +0 -66
- lionagi/experimental/tool/schema.py +0 -59
- lionagi/experimental/tool/tool_manager.py +0 -138
- lionagi/experimental/tool/util.py +0 -16
- lionagi/experimental/validator/rule.py +0 -139
- lionagi/experimental/validator/validator.py +0 -56
- lionagi/experimental/work/__init__.py +0 -10
- lionagi/experimental/work/async_queue.py +0 -54
- lionagi/experimental/work/schema.py +0 -73
- lionagi/experimental/work/work_function.py +0 -67
- lionagi/experimental/work/worker.py +0 -56
- lionagi/experimental/work2/form.py +0 -371
- lionagi/experimental/work2/report.py +0 -289
- lionagi/experimental/work2/schema.py +0 -30
- lionagi/experimental/work2/tests.py +0 -72
- lionagi/experimental/work2/work_function.py +0 -89
- lionagi/experimental/work2/worker.py +0 -12
- lionagi/integrations/bridge/llamaindex_/get_index.py +0 -294
- lionagi/tests/test_core/generic/test_component.py +0 -89
- lionagi/tests/test_core/test_base_branch.py +0 -426
- lionagi/tests/test_core/test_chat_flow.py +0 -63
- lionagi/tests/test_core/test_mail_manager.py +0 -75
- lionagi/tests/test_core/test_prompts.py +0 -51
- lionagi/tests/test_core/test_session.py +0 -254
- lionagi/tests/test_core/test_session_base_util.py +0 -313
- lionagi/tests/test_core/test_tool_manager.py +0 -95
- lionagi-0.1.2.dist-info/LICENSE +0 -9
- lionagi-0.1.2.dist-info/METADATA +0 -174
- lionagi-0.1.2.dist-info/RECORD +0 -206
- /lionagi/core/{branch → _setting}/__init__.py +0 -0
- /lionagi/core/{execute → agent/eval}/__init__.py +0 -0
- /lionagi/core/{flow → agent/learn}/__init__.py +0 -0
- /lionagi/core/{form → agent/plan}/__init__.py +0 -0
- /lionagi/core/{branch/executable_branch.py → agent/plan/plan.py} +0 -0
- /lionagi/core/{graph → director}/__init__.py +0 -0
- /lionagi/core/{messages → engine}/__init__.py +0 -0
- /lionagi/{experimental/directive/evaluator → core/engine}/sandbox_.py +0 -0
- /lionagi/{experimental/directive/evaluator → core/executor}/__init__.py +0 -0
- /lionagi/{experimental/directive/template_ → core/rule}/__init__.py +0 -0
- /lionagi/{experimental/report → core/unit/template}/__init__.py +0 -0
- /lionagi/{experimental/tool → core/validator}/__init__.py +0 -0
- /lionagi/{experimental/validator → core/work}/__init__.py +0 -0
- /lionagi/experimental/{work2 → compressor}/__init__.py +0 -0
- /lionagi/{core/flow/mono_chat_mixin.py → experimental/directive/template/__init__.py} +0 -0
- /lionagi/experimental/directive/{schema.py → template/schema.py} +0 -0
- /lionagi/experimental/{work2/util.py → evaluator/__init__.py} +0 -0
- /lionagi/experimental/{work2/work.py → knowledge/__init__.py} +0 -0
- /lionagi/{tests/libs/test_async.py → experimental/knowledge/graph.py} +0 -0
- {lionagi-0.1.2.dist-info → lionagi-0.2.0.dist-info}/WHEEL +0 -0
- {lionagi-0.1.2.dist-info → lionagi-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,419 @@ lionagi/core/collections/model.py (new file)

```python
"""
Copyright 2024 HaiyangLi

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import asyncio
import numpy as np
from dotenv import load_dotenv
from lionagi.libs import SysUtil, BaseService, StatusTracker, APIUtil, to_list, ninsert
from .abc import Component, ModelLimitExceededError

load_dotenv()


_oai_price_map = {
    "gpt-4o": (5, 15),
    "gpt-4-turbo": (10, 30),
    "gpt-3.5-turbo": (0.5, 1.5),
}


class iModel:
    """
    iModel is a class for managing AI model configurations and service
    integrations.

    Attributes:
        ln_id (str): A unique identifier for the model instance.
        timestamp (str): The timestamp when the model instance is created.
        endpoint (str): The API endpoint for the model service.
        provider_schema (dict): The schema for the service provider.
        provider (BaseService): The service provider instance.
        endpoint_schema (dict): The schema for the endpoint configuration.
        api_key (str): The API key for the service provider.
        status_tracker (StatusTracker): Instance of StatusTracker to track
            service status.
        service (BaseService): Configured service instance.
        config (dict): Configuration dictionary for the model.
        iModel_name (str): Name of the model.
    """

    default_model = "gpt-4o"

    def __init__(
        self,
        model: str = None,
        config: dict = None,
        provider: str = None,
        provider_schema: dict = None,
        endpoint: str = "chat/completions",
        token_encoding_name: str = None,
        api_key: str = None,
        api_key_schema: str = None,
        interval_tokens: int = None,
        interval_requests: int = None,
        interval: int = None,
        service: BaseService = None,
        allowed_parameters=[],
        device: str = None,
        costs=None,
        **kwargs,  # additional parameters for the model
    ):
        """
        Initializes an instance of the iModel class.

        Args:
            model (str, optional): Name of the model.
            config (dict, optional): Configuration dictionary.
            provider (str, optional): Name or class of the provider.
            provider_schema (dict, optional): Schema dictionary for the
                provider.
            endpoint (str, optional): Endpoint string, default is
                "chat/completions".
            token_encoding_name (str, optional): Name of the token encoding,
                default is "cl100k_base".
            api_key (str, optional): API key for the provider.
            api_key_schema (str, optional): Schema for the API key.
            interval_tokens (int, optional): Token interval limit, default is
                100,000.
            interval_requests (int, optional): Request interval limit, default
                is 1,000.
            interval (int, optional): Time interval in seconds, default is 60.
            service (BaseService, optional): An instance of BaseService.
            **kwargs: Additional parameters for the model.
        """
        self.ln_id: str = SysUtil.create_id()
        self.timestamp: str = SysUtil.get_timestamp(sep=None)[:-6]
        self.endpoint = endpoint
        self.allowed_parameters = allowed_parameters
        if isinstance(provider, type):
            provider = provider.__name__.replace("Service", "").lower()

        else:
            provider = str(provider).lower() if provider else "openai"

        from lionagi.integrations.provider._mapping import (
            SERVICE_PROVIDERS_MAPPING,
        )

        self.provider_schema = (
            provider_schema or SERVICE_PROVIDERS_MAPPING[provider]["schema"]
        )
        self.provider = SERVICE_PROVIDERS_MAPPING[provider]["service"]
        self.endpoint_schema = self.provider_schema.get(endpoint, {})
        self.token_limit = self.endpoint_schema.get("token_limit", 100_000)

        if api_key is not None:
            self.api_key = api_key

        elif api_key_schema is not None:
            self.api_key = os.getenv(api_key_schema)
        else:
            api_schema = self.provider_schema.get("API_key_schema", None)
            if api_schema:
                self.api_key = os.getenv(
                    self.provider_schema["API_key_schema"][0], None
                )

        self.status_tracker = StatusTracker()

        set_up_kwargs = {
            "api_key": getattr(self, "api_key", None),
            "schema": self.provider_schema or None,
            "endpoint": self.endpoint,
            "token_limit": self.token_limit,
            "token_encoding_name": token_encoding_name
            or self.endpoint_schema.get("token_encoding_name", None),
            "max_tokens": interval_tokens
            or self.endpoint_schema.get("interval_tokens", None),
            "max_requests": interval_requests
            or self.endpoint_schema.get("interval_requests", None),
            "interval": interval or self.endpoint_schema.get("interval", None),
        }

        set_up_kwargs = {
            k: v
            for k, v in set_up_kwargs.items()
            if v is not None and k in self.allowed_parameters
        }

        self.config = self._set_up_params(
            config or self.endpoint_schema.get("config", {}), **kwargs
        )

        if not model:
            if "model" not in self.config:
                model = SERVICE_PROVIDERS_MAPPING[provider]["default_model"]

        if self.config.get("model", None) != model and model is not None:
            self.iModel_name = model
            self.config["model"] = model
            ninsert(self.endpoint_schema, ["config", "model"], model)

        else:
            self.iModel_name = self.config["model"]

        if device:
            set_up_kwargs["device"] = device
        set_up_kwargs["model"] = self.iModel_name
        self.service: BaseService = self._set_up_service(
            service=service,
            provider=self.provider,
            **set_up_kwargs,
        )
        if self.iModel_name in _oai_price_map:
            self.costs = _oai_price_map[self.iModel_name]
        else:
            self.costs = costs or (0, 0)

    def update_config(self, **kwargs):
        """
        Updates the configuration with additional parameters.

        Args:
            **kwargs: Additional parameters to update the configuration.
        """
        self.config = self._set_up_params(self.config, **kwargs)

    def _set_up_config(self, model_config, **kwargs):
        """
        Sets up the model configuration.

        Args:
            model_config (dict): The default configuration dictionary.
            **kwargs: Additional parameters to update the configuration.

        Returns:
            dict: Updated configuration dictionary.
        """
        return {**model_config, **kwargs}

    def _set_up_service(self, service=None, provider=None, **kwargs):
        """
        Sets up the service for the model.

        Args:
            service (BaseService, optional): An instance of BaseService.
            provider (str, optional): Provider name or instance.
            **kwargs: Additional parameters for the service.

        Returns:
            BaseService: Configured service instance.
        """
        if not service:
            provider = provider or self.provider
            a = provider.__name__.replace("Service", "").lower()
            if a in ["openai", "openrouter"]:
                kwargs.pop("model", None)

            return provider(**kwargs)
        return service

    def _set_up_params(self, default_config=None, **kwargs):
        """
        Sets up the parameters for the model.

        Args:
            default_config (dict, optional): The default configuration
                dictionary.
            **kwargs: Additional parameters to update the configuration.

        Returns:
            dict: Updated parameters dictionary.

        Raises:
            ValueError: If any parameter is not allowed.
        """
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        params = {**default_config, **kwargs}
        allowed_params = self.endpoint_schema.get(
            "required", []
        ) + self.endpoint_schema.get("optional", [])

        if allowed_params != []:
            if (
                len(
                    not_allowed := [k for k in params.keys() if k not in allowed_params]
                )
                > 0
            ):
                raise ValueError(f"Not allowed parameters: {not_allowed}")

        return params

    async def call_chat_completion(self, messages, **kwargs):
        """
        Asynchronous method to call the chat completion service.

        Args:
            messages (list): List of messages for the chat completion.
            **kwargs: Additional parameters for the service call.

        Returns:
            dict: Response from the chat completion service.
        """

        num_tokens = APIUtil.calculate_num_token(
            {"messages": messages},
            "chat/completions",
            self.endpoint_schema.get("token_encoding_name", None),
        )

        if num_tokens > self.token_limit:
            raise ModelLimitExceededError(
                f"Number of tokens {num_tokens} exceeds the limit {self.token_limit}"
            )

        return await self.service.serve_chat(
            messages, required_tokens=num_tokens, **kwargs
        )

    async def call_embedding(self, embed_str, **kwargs):
        """
        Asynchronous method to call the embedding service.

        Args:
            input_file (str): Path to the input file.
            **kwargs: Additional parameters for the service call.

        Returns:
            dict: Response from the embedding service.
        """
        return await self.service.serve_embedding(embed_str, **kwargs)

    async def embed_node(self, node, field="content", **kwargs) -> bool:
        """
        if not specify field, we embed node.content
        """
        if not isinstance(node, Component):
            raise ValueError("Node must a lionagi item")
        embed_str = getattr(node, field)

        if isinstance(embed_str, dict) and "images" in embed_str:
            embed_str.pop("images", None)
            embed_str.pop("image_detail", None)

        num_tokens = APIUtil.calculate_num_token(
            {"input": str(embed_str) if isinstance(embed_str, dict) else embed_str},
            "embeddings",
            self.endpoint_schema["token_encoding_name"],
        )

        if self.token_limit and num_tokens > self.token_limit:
            raise ModelLimitExceededError(
                f"Number of tokens {num_tokens} exceeds the limit {self.token_limit}"
            )

        payload, embed = await self.call_embedding(embed_str, **kwargs)
        payload.pop("input")
        node.add_field("embedding", embed["data"][0]["embedding"])
        node._meta_insert("embedding_meta", payload)

    def to_dict(self):
        """
        Converts the model instance to a dictionary representation.

        Returns:
            dict: Dictionary representation of the model instance.
        """
        return {
            "ln_id": self.ln_id,
            "timestamp": self.timestamp,
            "provider": self.provider.__name__.replace("Service", ""),
            "api_key": self.api_key[:4]
            + "*" * (len(self.api_key) - 8)
            + self.api_key[-4:],
            "endpoint": self.endpoint,
            "token_encoding_name": self.service.token_encoding_name,
            **{
                k: v
                for k, v in self.config.items()
                if k in getattr(self.service, "allowed_kwargs", []) and v is not None
            },
            "model_costs": None if self.costs == (0, 0) else self.costs,
        }

    async def compute_perplexity(
        self,
        initial_context: str = None,
        tokens: list[str] = None,
        system_msg: str = None,
        n_samples: int = 1,  # number of samples used for the computation
        use_residue: bool = True,  # whether to use residue for the last sample
        **kwargs,  # additional arguments for the model
    ) -> tuple[list[str], float]:
        tasks = []
        context = initial_context or ""

        n_samples = n_samples or len(tokens)
        sample_token_len, residue = divmod(len(tokens), n_samples)
        samples = []

        if n_samples == 1:
            samples = [tokens]
        else:
            samples = [tokens[: (i + 1) * sample_token_len] for i in range(n_samples)]

            if use_residue and residue != 0:
                samples.append(tokens[-residue:])

        sampless = [context + " ".join(sample) for sample in samples]

        for sample in sampless:
            messages = [{"role": "system", "content": system_msg}] if system_msg else []
            messages.append(
                {"role": "user", "content": sample},
            )

            task = asyncio.create_task(
                self.call_chat_completion(
                    messages=messages,
                    logprobs=True,
                    max_tokens=sample_token_len,
                    **kwargs,
                )
            )
            tasks.append(task)

        results = await asyncio.gather(*tasks)  # result is (payload, response)
        results_ = []
        num_prompt_tokens = 0
        num_completion_tokens = 0

        for idx, result in enumerate(results):
            _dict = {}
            _dict["tokens"] = samples[idx]

            num_prompt_tokens += result[1]["usage"]["prompt_tokens"]
            num_completion_tokens += result[1]["usage"]["completion_tokens"]

            logprobs = result[1]["choices"][0]["logprobs"]["content"]
            logprobs = to_list(logprobs, flatten=True, dropna=True)
            _dict["logprobs"] = [(i["token"], i["logprob"]) for i in logprobs]
            results_.append(_dict)

        logprobs = to_list(
            [[i[1] for i in d["logprobs"]] for d in results_], flatten=True
        )

        return {
            "tokens": tokens,
            "n_samples": n_samples,
            "num_prompt_tokens": num_prompt_tokens,
            "num_completion_tokens": num_completion_tokens,
            "logprobs": logprobs,
            "perplexity": np.exp(np.mean(logprobs)),
        }
```