opengradient 0.3.16__py3-none-any.whl → 0.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opengradient/__init__.py +195 -25
- opengradient/cli.py +49 -16
- opengradient/client.py +75 -128
- opengradient/llm/__init__.py +36 -3
- opengradient/llm/og_openai.py +121 -0
- opengradient/types.py +11 -0
- {opengradient-0.3.16.dist-info → opengradient-0.3.18.dist-info}/METADATA +4 -30
- opengradient-0.3.18.dist-info/RECORD +21 -0
- {opengradient-0.3.16.dist-info → opengradient-0.3.18.dist-info}/WHEEL +1 -1
- opengradient-0.3.16.dist-info/RECORD +0 -20
- /opengradient/llm/{chat.py → og_langchain.py} +0 -0
- {opengradient-0.3.16.dist-info → opengradient-0.3.18.dist-info}/entry_points.txt +0 -0
- {opengradient-0.3.16.dist-info → opengradient-0.3.18.dist-info}/licenses/LICENSE +0 -0
opengradient/__init__.py
CHANGED
@@ -1,11 +1,15 @@
+"""
+OpenGradient Python SDK for interacting with AI models and infrastructure.
+"""
+
 from typing import Dict, List, Optional, Tuple
 
 from .client import Client
 from .defaults import DEFAULT_INFERENCE_CONTRACT_ADDRESS, DEFAULT_RPC_URL
-from .types import InferenceMode, LLM
+from .types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM
 from . import llm
 
-__version__ = "0.3.16"
+__version__ = "0.3.18"
 
 _client = None
 
@@ -14,15 +18,50 @@ def init(email: str,
          private_key: str,
          rpc_url=DEFAULT_RPC_URL,
          contract_address=DEFAULT_INFERENCE_CONTRACT_ADDRESS):
+    """Initialize the OpenGradient SDK with authentication and network settings.
+
+    Args:
+        email: User's email address for authentication
+        password: User's password for authentication
+        private_key: Ethereum private key for blockchain transactions
+        rpc_url: Optional RPC URL for the blockchain network, defaults to mainnet
+        contract_address: Optional inference contract address
+    """
     global _client
     _client = Client(private_key=private_key, rpc_url=rpc_url, contract_address=contract_address, email=email, password=password)
 
 def upload(model_path, model_name, version):
+    """Upload a model file to OpenGradient.
+
+    Args:
+        model_path: Path to the model file on local filesystem
+        model_name: Name of the model repository
+        version: Version string for this model upload
+
+    Returns:
+        dict: Upload response containing file metadata
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
     return _client.upload(model_path, model_name, version)
 
 def create_model(model_name: str, model_desc: str, model_path: str = None):
+    """Create a new model repository.
+
+    Args:
+        model_name: Name for the new model repository
+        model_desc: Description of the model
+        model_path: Optional path to model file to upload immediately
+
+    Returns:
+        dict: Creation response with model metadata and optional upload results
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
 
@@ -36,62 +75,193 @@ def create_model(model_name: str, model_desc: str, model_path: str = None):
     return result
 
 def create_version(model_name, notes=None, is_major=False):
+    """Create a new version for an existing model.
+
+    Args:
+        model_name: Name of the model repository
+        notes: Optional release notes for this version
+        is_major: If True, creates a major version bump instead of minor
+
+    Returns:
+        dict: Version creation response with version metadata
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
     return _client.create_version(model_name, notes, is_major)
 
-def infer(model_cid, inference_mode, model_input):
+def infer(model_cid, inference_mode, model_input, max_retries: Optional[int] = None):
+    """Run inference on a model.
+
+    Args:
+        model_cid: CID of the model to use
+        inference_mode: Mode of inference (e.g. VANILLA)
+        model_input: Input data for the model
+        max_retries: Maximum number of retries for failed transactions
+
+    Returns:
+        Tuple[str, Any]: Transaction hash and model output
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
-    return _client.infer(model_cid, inference_mode, model_input)
+    return _client.infer(model_cid, inference_mode, model_input, max_retries=max_retries)
 
 def llm_completion(model_cid: LLM,
-                   prompt: str,
-                   max_tokens: int = 100,
-                   stop_sequence: Optional[List[str]] = None,
-                   temperature: float = 0.0) -> Tuple[str, str]:
+                   prompt: str,
+                   inference_mode: str = LlmInferenceMode.VANILLA,
+                   max_tokens: int = 100,
+                   stop_sequence: Optional[List[str]] = None,
+                   temperature: float = 0.0,
+                   max_retries: Optional[int] = None) -> Tuple[str, str]:
+    """Generate text completion using an LLM.
+
+    Args:
+        model_cid: CID of the LLM model to use
+        prompt: Text prompt for completion
+        inference_mode: Mode of inference, defaults to VANILLA
+        max_tokens: Maximum tokens to generate
+        stop_sequence: Optional list of sequences where generation should stop
+        temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
+        max_retries: Maximum number of retries for failed transactions
+
+    Returns:
+        Tuple[str, str]: Transaction hash and generated text
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
-    return _client.llm_completion(model_cid, prompt, max_tokens, stop_sequence, temperature)
+    return _client.llm_completion(model_cid=model_cid,
+                                  inference_mode=inference_mode,
+                                  prompt=prompt,
+                                  max_tokens=max_tokens,
+                                  stop_sequence=stop_sequence,
+                                  temperature=temperature,
+                                  max_retries=max_retries)
 
 def llm_chat(model_cid: LLM,
-             messages: List[Dict],
-             max_tokens: int = 100,
-             stop_sequence: Optional[List[str]] = None,
-             temperature: float = 0.0,
-             tools: Optional[List[Dict]] = None,
-             tool_choice: Optional[str] = None) -> Tuple[str, str, Dict]:
+             messages: List[Dict],
+             inference_mode: str = LlmInferenceMode.VANILLA,
+             max_tokens: int = 100,
+             stop_sequence: Optional[List[str]] = None,
+             temperature: float = 0.0,
+             tools: Optional[List[Dict]] = None,
+             tool_choice: Optional[str] = None,
+             max_retries: Optional[int] = None) -> Tuple[str, str, Dict]:
+    """Have a chat conversation with an LLM.
+
+    Args:
+        model_cid: CID of the LLM model to use
+        messages: List of chat messages, each with 'role' and 'content'
+        inference_mode: Mode of inference, defaults to VANILLA
+        max_tokens: Maximum tokens to generate
+        stop_sequence: Optional list of sequences where generation should stop
+        temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
+        tools: Optional list of tools the model can use
+        tool_choice: Optional specific tool to use
+        max_retries: Maximum number of retries for failed transactions
+
+    Returns:
+        Tuple[str, str, Dict]: Transaction hash, model response, and metadata
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
-    return _client.llm_chat(model_cid, messages, max_tokens, stop_sequence, temperature, tools, tool_choice)
+    return _client.llm_chat(model_cid=model_cid,
+                            inference_mode=inference_mode,
+                            messages=messages,
+                            max_tokens=max_tokens,
+                            stop_sequence=stop_sequence,
+                            temperature=temperature,
+                            tools=tools,
+                            tool_choice=tool_choice,
+                            max_retries=max_retries)
 
 def login(email: str, password: str):
+    """Login to OpenGradient.
+
+    Args:
+        email: User's email address
+        password: User's password
+
+    Returns:
+        dict: Login response with authentication tokens
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
     return _client.login(email, password)
 
 def list_files(model_name: str, version: str) -> List[Dict]:
+    """List files in a model repository version.
+
+    Args:
+        model_name: Name of the model repository
+        version: Version string to list files from
+
+    Returns:
+        List[Dict]: List of file metadata dictionaries
+
+    Raises:
+        RuntimeError: If SDK is not initialized
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
     return _client.list_files(model_name, version)
 
 def generate_image(model: str, prompt: str, height: Optional[int] = None, width: Optional[int] = None) -> bytes:
-    """
-    Generate an image using the specified model and prompt.
+    """Generate an image from a text prompt.
 
     Args:
-        model
-        prompt
-        height
-        width
+        model: Model identifier (e.g. "stabilityai/stable-diffusion-xl-base-1.0")
+        prompt: Text description of the desired image
+        height: Optional height of the generated image in pixels
+        width: Optional width of the generated image in pixels
 
     Returns:
-        bytes:
+        bytes: Raw image data as bytes
 
     Raises:
-        RuntimeError: If
-        OpenGradientError: If
+        RuntimeError: If SDK is not initialized
+        OpenGradientError: If image generation fails
     """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
     return _client.generate_image(model, prompt, height=height, width=width)
+
+__all__ = [
+    'generate_image',
+    'list_files'
+    'login',
+    'llm_chat',
+    'llm_completion',
+    'infer',
+    'create_version',
+    'create_model',
+    'upload',
+    'init',
+    'LLM',
+    'TEE_LLM'
+]
+
+__pdoc__ = {
+    'account': False,
+    'cli': False,
+    'client': False,
+    'defaults': False,
+    'exceptions': False,
+    'llm': True,
+    'proto': False,
+    'types': False,
+    'utils': False
+}
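To make the new 0.3.18 surface concrete, here is a minimal usage sketch assembled from the signatures above. The email, password, private key, and model CID are placeholder values, and the exact shape of the returned model output depends on the model being called.

```python
import opengradient as og
from opengradient.types import InferenceMode, LlmInferenceMode, LLM

# One-time setup; populates the module-level _client used by every helper.
og.init(email="user@example.com", password="...", private_key="0x...")  # placeholders

# Model inference, with the retry knob added in 0.3.18.
tx_hash, model_output = og.infer(
    model_cid="Qm...",                       # placeholder model CID
    inference_mode=InferenceMode.VANILLA,
    model_input={"input": [1.0, 2.0, 3.0]},  # example input dict
    max_retries=3)

# LLM chat now threads inference_mode and max_retries through to the client.
tx_hash, finish_reason, message = og.llm_chat(
    model_cid=LLM.META_LLAMA_3_8B_INSTRUCT,
    messages=[{"role": "user", "content": "hello"}],
    inference_mode=LlmInferenceMode.VANILLA,
    max_retries=3)
print(finish_reason, message)
```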
opengradient/cli.py
CHANGED
@@ -20,7 +20,7 @@ from .defaults import (
     DEFAULT_OG_FAUCET_URL,
     DEFAULT_RPC_URL,
 )
-from .types import InferenceMode
+from .types import InferenceMode, LlmInferenceMode
 
 OG_CONFIG_FILE = Path.home() / '.opengradient_config.json'
 
@@ -65,11 +65,17 @@ InferenceModes = {
     "TEE": InferenceMode.TEE,
 }
 
+LlmInferenceModes = {
+    "VANILLA": LlmInferenceMode.VANILLA,
+    "TEE": LlmInferenceMode.TEE,
+}
+
 # Supported LLMs
 LlmModels = {
     "meta-llama/Meta-Llama-3-8B-Instruct",
     "meta-llama/Llama-3.2-3B-Instruct",
-    "mistralai/Mistral-7B-Instruct-v0.3"
+    "mistralai/Mistral-7B-Instruct-v0.3",
+    "meta-llama/Llama-3.1-70B-Instruct",
 }
 
 def initialize_config(ctx):
@@ -339,13 +345,15 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
         click.echo(f"Error running inference: {str(e)}")
 
 @cli.command()
-@click.option('--model', '-m', 'model_cid', type=click.Choice(
+@click.option('--model', '-m', 'model_cid', type=click.Choice([e.value for e in types.LLM]), required=True, help='CID of the LLM model to run inference on')
+@click.option('--mode', 'inference_mode', type=click.Choice(LlmInferenceModes.keys()), default="VANILLA",
+              help='Inference mode (default: VANILLA)')
 @click.option('--prompt', '-p', required=True, help='Input prompt for the LLM completion')
 @click.option('--max-tokens', type=int, default=100, help='Maximum number of tokens for LLM completion output')
 @click.option('--stop-sequence', multiple=True, help='Stop sequences for LLM')
 @click.option('--temperature', type=float, default=0.0, help='Temperature for LLM inference (0.0 to 1.0)')
 @click.pass_context
-def completion(ctx, model_cid: str, prompt: str, max_tokens: int, stop_sequence: List[str], temperature: float):
+def completion(ctx, model_cid: str, inference_mode: str, prompt: str, max_tokens: int, stop_sequence: List[str], temperature: float):
     """
     Run completion inference on an LLM model.
@@ -355,13 +363,14 @@ def completion(ctx, model_cid: str, prompt: str, max_tokens: int, stop_sequence:
 
     \b
     opengradient completion --model meta-llama/Meta-Llama-3-8B-Instruct --prompt "Hello, how are you?" --max-tokens 50 --temperature 0.7
-    opengradient completion -m meta-llama/Meta-Llama-3-8B-Instruct -p "Translate to French: Hello world" --stop-sequence "." --stop-sequence "
+    opengradient completion -m meta-llama/Meta-Llama-3-8B-Instruct -p "Translate to French: Hello world" --stop-sequence "." --stop-sequence "\\n"
     """
     client: Client = ctx.obj['client']
     try:
         click.echo(f"Running LLM completion inference for model \"{model_cid}\"\n")
         tx_hash, llm_output = client.llm_completion(
             model_cid=model_cid,
+            inference_mode=LlmInferenceModes[inference_mode],
             prompt=prompt,
             max_tokens=max_tokens,
             stop_sequence=list(stop_sequence),
@@ -394,6 +403,9 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output):
               type=click.Choice([e.value for e in types.LLM]),
               required=True,
               help='CID of the LLM model to run inference on')
+@click.option('--mode', 'inference_mode', type=click.Choice(LlmInferenceModes.keys()),
+              default="VANILLA",
+              help='Inference mode (default: VANILLA)')
 @click.option('--messages',
               type=str,
               required=False,
@@ -431,6 +443,7 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output):
 def chat(
     ctx,
     model_cid: str,
+    inference_mode: str,
     messages: Optional[str],
     messages_file: Optional[Path],
     max_tokens: int,
@@ -444,11 +457,13 @@ def chat(
 
     This command runs a chat inference on the specified LLM model using the provided messages and parameters.
 
+    Tool call formatting is based on OpenAI documentation tool calls (see here: https://platform.openai.com/docs/guides/function-calling).
+
     Example usage:
 
     \b
     opengradient chat --model meta-llama/Meta-Llama-3-8B-Instruct --messages '[{"role":"user","content":"hello"}]' --max-tokens 50 --temperature 0.7
-    opengradient chat
+    opengradient chat --model mistralai/Mistral-7B-Instruct-v0.3 --messages-file messages.json --tools-file tools.json --max-tokens 200 --stop-sequence "." --stop-sequence "\\n"
     """
     client: Client = ctx.obj['client']
     try:
@@ -458,7 +473,7 @@ def chat(
         ctx.exit(1)
         return
     if messages and messages_file:
-        click.echo("Cannot have both messages and
+        click.echo("Cannot have both messages and messages-file")
         ctx.exit(1)
         return
@@ -473,9 +488,9 @@ def chat(
             messages = json.load(file)
 
     # Parse tools if provided
-    if tools
-        click.echo("Cannot have both tools and
-
+    if (tools and tools != '[]') and tools_file:
+        click.echo("Cannot have both tools and tools-file")
+        click.exit(1)
         return
 
     parsed_tools=[]
@@ -509,6 +524,7 @@ def chat(
 
         tx_hash, finish_reason, llm_chat_output = client.llm_chat(
             model_cid=model_cid,
+            inference_mode=LlmInferenceModes[inference_mode],
             messages=messages,
             max_tokens=max_tokens,
             stop_sequence=list(stop_sequence),
@@ -517,15 +533,32 @@ def chat(
             tool_choice=tool_choice,
         )
 
-
-        print("TX Hash: ", tx_hash)
-        print("Finish reason: ", finish_reason)
-        print("Chat output: ", llm_chat_output)
+        print_llm_chat_result(model_cid, tx_hash, finish_reason, llm_chat_output)
     except Exception as e:
         click.echo(f"Error running LLM chat inference: {str(e)}")
 
-def print_llm_chat_result():
-
+def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output):
+    click.secho("✅ LLM Chat Successful", fg="green", bold=True)
+    click.echo("──────────────────────────────────────")
+    click.echo("Model CID: ", nl=False)
+    click.secho(model_cid, fg="cyan", bold=True)
+    click.echo("Transaction hash: ", nl=False)
+    click.secho(tx_hash, fg="cyan", bold=True)
+    block_explorer_link = f"{DEFAULT_BLOCKCHAIN_EXPLORER}0x{tx_hash}"
+    click.echo("Block explorer link: ", nl=False)
+    click.secho(block_explorer_link, fg="blue", underline=True)
+    click.echo("──────────────────────────────────────")
+    click.secho("Finish Reason: ", fg="yellow", bold=True)
+    click.echo()
+    click.echo(finish_reason)
+    click.echo()
+    click.secho("Chat Output:", fg="yellow", bold=True)
+    click.echo()
+    for key, value in chat_output.items():
+        # If the value doesn't give any information, don't print it
+        if value != None and value != "" and value != '[]' and value != []:
+            click.echo(f"{key}: {value}")
+    click.echo()
 
 @cli.command()
 def create_account():
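Both LLM commands now take a --mode option that maps "VANILLA" or "TEE" through LlmInferenceModes before calling the client. The sketch below exercises the new flag via click's built-in test runner so it stays in Python; it assumes the CLI configuration (account and wallet) has already been set up, otherwise the invocation simply reports an error in result.output.

```python
from click.testing import CliRunner

from opengradient.cli import cli

runner = CliRunner()
# --mode is new in 0.3.18; it defaults to VANILLA when omitted.
result = runner.invoke(cli, [
    "completion",
    "--model", "meta-llama/Meta-Llama-3-8B-Instruct",
    "--mode", "VANILLA",
    "--prompt", "Hello, how are you?",
    "--max-tokens", "50",
])
print(result.output)
```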
opengradient/client.py
CHANGED
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import random
 from typing import Dict, List, Optional, Tuple, Union
 
 import firebase
@@ -12,7 +13,7 @@ from web3.logs import DISCARD
 
 from opengradient import utils
 from opengradient.exceptions import OpenGradientError
-from opengradient.types import InferenceMode, LLM
+from opengradient.types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM
 
 import grpc
 import time
@@ -23,6 +24,31 @@ from opengradient.proto import infer_pb2
 from opengradient.proto import infer_pb2_grpc
 from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT
 
+from functools import wraps
+
+def run_with_retry(txn_function, max_retries=5):
+    """
+    Execute a blockchain transaction with retry logic.
+
+    Args:
+        txn_function: Function that executes the transaction
+        max_retries (int): Maximum number of retry attempts
+    """
+    last_error = None
+    for attempt in range(max_retries):
+        try:
+            return txn_function()
+        except Exception as e:
+            last_error = e
+            if attempt < max_retries - 1:
+                if "nonce too low" in str(e) or "nonce too high" in str(e):
+                    time.sleep(1)  # Wait before retry
+                    continue
+            # If it's not a nonce error, raise immediately
+            raise
+    # If we've exhausted all retries, raise the last error
+    raise OpenGradientError(f"Transaction failed after {max_retries} attempts: {str(last_error)}")
+
 class Client:
     FIREBASE_CONFIG = {
         "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -311,7 +337,8 @@
         self,
         model_cid: str,
         inference_mode: InferenceMode,
-        model_input: Dict[str, Union[str, int, float, List, np.ndarray]]
+        model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
+        max_retries: Optional[int] = None
     ) -> Tuple[str, Dict[str, np.ndarray]]:
         """
         Perform inference on a model.
@@ -320,6 +347,7 @@
             model_cid (str): The unique content identifier for the model from IPFS.
             inference_mode (InferenceMode): The inference mode.
             model_input (Dict[str, Union[str, int, float, List, np.ndarray]]): The input data for the model.
+            max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.
 
         Returns:
             Tuple[str, Dict[str, np.ndarray]]: The transaction hash and the model output.
@@ -327,46 +355,22 @@
         Raises:
             OpenGradientError: If the inference fails.
         """
-
-        try:
-            logging.debug("Entering infer method")
+        def execute_transaction():
             self._initialize_web3()
-            logging.debug(f"Web3 initialized. Connected: {self._w3.is_connected()}")
-
-            logging.debug(f"Creating contract instance. Address: {self.contract_address}")
             contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
-
-
-            logging.debug(f"Model ID: {model_cid}")
-            logging.debug(f"Inference Mode: {inference_mode}")
-            logging.debug(f"Model Input: {model_input}")
-
-            # Convert InferenceMode to uint8
+
             inference_mode_uint8 = int(inference_mode)
-
-            # Prepare ModelInput tuple
             converted_model_input = utils.convert_to_model_input(model_input)
-
-
-            logging.debug("Preparing run function")
+
             run_function = contract.functions.run(
                 model_cid,
                 inference_mode_uint8,
                 converted_model_input
             )
-            logging.debug("Run function prepared successfully")
 
-
-            nonce = self._w3.eth.get_transaction_count(self.wallet_address)
-            logging.debug(f"Nonce: {nonce}")
-
-            # Estimate gas
+            nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
             estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
-            logging.debug(f"Estimated gas: {estimated_gas}")
-
-            # Increase gas limit by 20%
             gas_limit = int(estimated_gas * 3)
-            logging.debug(f"Gas limit set to: {gas_limit}")
 
             transaction = run_function.build_transaction({
                 'from': self.wallet_address,
@@ -375,62 +379,36 @@
                 'gasPrice': self._w3.eth.gas_price,
             })
 
-            logging.debug(f"Transaction built: {transaction}")
-
-            # Sign transaction
             signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
-            logging.debug("Transaction signed successfully")
-
-            # Send transaction
             tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
-            logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
-            # Wait for transaction receipt
             tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
-            logging.debug(f"Transaction receipt received: {tx_receipt}")
 
-            # Check if the transaction was successful
             if tx_receipt['status'] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
 
-            # Process the InferenceResult event
             parsed_logs = contract.events.InferenceResult().process_receipt(tx_receipt, errors=DISCARD)
-
             if len(parsed_logs) < 1:
                 raise OpenGradientError("InferenceResult event not found in transaction logs")
-            inference_result = parsed_logs[0]
-
-            # Extract the ModelOutput from the event
-            event_data = inference_result['args']
-            logging.debug(f"Raw event data: {event_data}")
-
-            try:
-                model_output = utils.convert_to_model_output(event_data)
-                logging.debug(f"Parsed ModelOutput: {model_output}")
-            except Exception as e:
-                logging.error(f"Error parsing event data: {str(e)}", exc_info=True)
-                raise OpenGradientError(f"Failed to parse event data: {str(e)}")
 
+            model_output = utils.convert_to_model_output(parsed_logs[0]['args'])
             return tx_hash.hex(), model_output
 
-
-        except ContractLogicError as e:
-            raise OpenGradientError(f"Inference failed due to contract logic error: {str(e)}")
-        except Exception as e:
-            logging.error(f"Error in infer method: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"Inference failed: {str(e)}")
-
+        return run_with_retry(execute_transaction, max_retries or 5)
+
     def llm_completion(self,
                        model_cid: LLM,
+                       inference_mode: InferenceMode,
                        prompt: str,
                        max_tokens: int = 100,
                        stop_sequence: Optional[List[str]] = None,
-                       temperature: float = 0.0
+                       temperature: float = 0.0,
+                       max_retries: Optional[int] = None) -> Tuple[str, str]:
         """
         Perform inference on an LLM model using completions.
 
         Args:
             model_cid (LLM): The unique content identifier for the model.
+            inference_mode (InferenceMode): The inference mode.
             prompt (str): The input prompt for the LLM.
             max_tokens (int): Maximum number of tokens for LLM output. Default is 100.
             stop_sequence (List[str], optional): List of stop sequences for LLM. Default is None.
@@ -442,17 +420,20 @@
         Raises:
             OpenGradientError: If the inference fails.
         """
-
-        try:
+        def execute_transaction():
+            # Check inference mode and supported model
+            if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
+                raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
 
-
-
-
-
+            if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
+                raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
+
+            self._initialize_web3()
+            contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
 
             # Prepare LLM input
             llm_request = {
-                "mode":
+                "mode": inference_mode,
                 "modelCID": model_cid,
                 "prompt": prompt,
                 "max_tokens": max_tokens,
@@ -461,11 +442,9 @@
             }
             logging.debug(f"Prepared LLM request: {llm_request}")
 
-            # Prepare run function
             run_function = contract.functions.runLLMCompletion(llm_request)
 
-
-            nonce = self._w3.eth.get_transaction_count(self.wallet_address)
+            nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
             estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
             gas_limit = int(estimated_gas * 1.2)
 
@@ -476,47 +455,38 @@
                 'gasPrice': self._w3.eth.gas_price,
             })
 
-            # Sign and send transaction
             signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
             tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
-            logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
-            # Wait for transaction receipt
             tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
 
             if tx_receipt['status'] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
 
-            # Process the LLMResult event
             parsed_logs = contract.events.LLMCompletionResult().process_receipt(tx_receipt, errors=DISCARD)
-
             if len(parsed_logs) < 1:
                 raise OpenGradientError("LLM completion result event not found in transaction logs")
-            llm_result = parsed_logs[0]
 
-            llm_answer =
+            llm_answer = parsed_logs[0]['args']['response']['answer']
             return tx_hash.hex(), llm_answer
 
-
-        except ContractLogicError as e:
-            raise OpenGradientError(f"LLM inference failed due to contract logic error: {str(e)}")
-        except Exception as e:
-            logging.error(f"Error in infer completion method: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"LLM inference failed: {str(e)}")
-
+        return run_with_retry(execute_transaction, max_retries or 5)
+
     def llm_chat(self,
                  model_cid: str,
+                 inference_mode: InferenceMode,
                  messages: List[Dict],
                  max_tokens: int = 100,
                  stop_sequence: Optional[List[str]] = None,
                  temperature: float = 0.0,
                  tools: Optional[List[Dict]] = [],
-                 tool_choice: Optional[str] = None
+                 tool_choice: Optional[str] = None,
+                 max_retries: Optional[int] = None) -> Tuple[str, str]:
         """
         Perform inference on an LLM model using chat.
 
         Args:
             model_cid (LLM): The unique content identifier for the model.
+            inference_mode (InferenceMode): The inference mode.
             messages (dict): The messages that will be passed into the chat.
                 This should be in OpenAI API format (https://platform.openai.com/docs/api-reference/chat/create)
                 Example:
@@ -567,13 +537,16 @@
         Raises:
             OpenGradientError: If the inference fails.
         """
-
-        try:
+        def execute_transaction():
+            # Check inference mode and supported model
+            if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
+                raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
+
+            if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
+                raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
 
-
-
-            llm_abi = json.load(abi_file)
-            contract = self._w3.eth.contract(address=self.contract_address, abi=llm_abi)
+            self._initialize_web3()
+            contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
 
             # For incoming chat messages, tool_calls can be empty. Add an empty array so that it will fit the ABI.
             for message in messages:
@@ -585,17 +558,10 @@
                     message['name'] = ""
 
             # Create simplified tool structure for smart contract
-            #
-            # struct ToolDefinition {
-            #     string description;
-            #     string name;
-            #     string parameters; // This must be a JSON
-            # }
             converted_tools = []
             if tools is not None:
                 for tool in tools:
                     function = tool['function']
-
                     converted_tool = {}
                     converted_tool['name'] = function['name']
                     converted_tool['description'] = function['description']
@@ -604,12 +570,11 @@
                         converted_tool['parameters'] = json.dumps(parameters)
                     except Exception as e:
                         raise OpenGradientError("Chat LLM failed to convert parameters into JSON: %s", e)
-
                     converted_tools.append(converted_tool)
 
             # Prepare LLM input
             llm_request = {
-                "mode":
+                "mode": inference_mode,
                 "modelCID": model_cid,
                 "messages": messages,
                 "max_tokens": max_tokens,
@@ -620,11 +585,9 @@
             }
             logging.debug(f"Prepared LLM request: {llm_request}")
 
-            # Prepare run function
             run_function = contract.functions.runLLMChat(llm_request)
 
-
-            nonce = self._w3.eth.get_transaction_count(self.wallet_address)
+            nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
            estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
            gas_limit = int(estimated_gas * 1.2)
 
@@ -635,41 +598,25 @@
                 'gasPrice': self._w3.eth.gas_price,
             })
 
-            # Sign and send transaction
             signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
             tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
-            logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
-            # Wait for transaction receipt
             tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
 
             if tx_receipt['status'] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
 
-            # Process the LLMResult event
             parsed_logs = contract.events.LLMChatResult().process_receipt(tx_receipt, errors=DISCARD)
-
             if len(parsed_logs) < 1:
                 raise OpenGradientError("LLM chat result event not found in transaction logs")
-            llm_result = parsed_logs[0]['args']['response']
 
-
+            llm_result = parsed_logs[0]['args']['response']
             message = dict(llm_result['message'])
-            if (tool_calls := message.get('tool_calls'))
-                new_tool_calls = []
-                for tool_call in tool_calls:
-                    new_tool_calls.append(dict(tool_call))
-                message['tool_calls'] = new_tool_calls
-
-            return (tx_hash.hex(), llm_result['finish_reason'], message)
+            if (tool_calls := message.get('tool_calls')) is not None:
+                message['tool_calls'] = [dict(tool_call) for tool_call in tool_calls]
 
-
-        except ContractLogicError as e:
-            logging.error(f"Contract logic error: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"LLM inference failed due to contract logic error: {str(e)}")
-        except Exception as e:
-            logging.error(f"Error in infer chat method: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"LLM inference failed: {str(e)}")
+            return tx_hash.hex(), llm_result['finish_reason'], message
 
+        return run_with_retry(execute_transaction, max_retries or 5)
 
     def list_files(self, model_name: str, version: str) -> List[Dict]:
         """
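The structural change here is that each transaction body now lives in a local execute_transaction closure handed to run_with_retry, so a nonce race re-runs the whole build/sign/send sequence. Below is a self-contained sketch of the same retry behavior; flaky_send is a hypothetical stand-in for a transaction closure, and the local OpenGradientError mirrors the SDK's exception class.

```python
import time


class OpenGradientError(Exception):
    """Stand-in for opengradient.exceptions.OpenGradientError."""


def run_with_retry(txn_function, max_retries=5):
    # Mirrors the helper added in client.py: retry only nonce races,
    # re-raise every other failure immediately.
    last_error = None
    for attempt in range(max_retries):
        try:
            return txn_function()
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1 and (
                    "nonce too low" in str(e) or "nonce too high" in str(e)):
                time.sleep(1)  # give the pending transaction time to settle
                continue
            raise
    raise OpenGradientError(
        f"Transaction failed after {max_retries} attempts: {last_error}")


calls = {"count": 0}

def flaky_send():
    # Hypothetical transaction: fails twice with a nonce race, then succeeds.
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("nonce too low")
    return "0xabc123"

print(run_with_retry(flaky_send))  # prints 0xabc123 after two retries
```

Because each attempt re-enters the closure, which re-reads get_transaction_count(..., 'pending'), a retried transaction is rebuilt with a fresh nonce rather than resent verbatim.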
opengradient/llm/__init__.py
CHANGED
@@ -1,5 +1,38 @@
-from .chat import *
+"""
+OpenGradient LLM Adapters
+
+This module provides adapter interfaces to use OpenGradient LLMs with popular AI frameworks
+like LangChain and OpenAI. These adapters allow seamless integration of OpenGradient models
+into existing applications and agent frameworks.
+"""
+
+from .og_langchain import *
+from .og_openai import *
+
+def langchain_adapter(private_key: str, model_cid: str, max_tokens: int = 300) -> OpenGradientChatModel:
+    """
+    Returns an OpenGradient LLM that implements LangChain's LLM interface
+    and can be plugged into LangChain agents.
+    """
+    return OpenGradientChatModel(
+        private_key=private_key,
+        model_cid=model_cid,
+        max_tokens=max_tokens)
+
+def openai_adapter(private_key: str) -> OpenGradientOpenAIClient:
+    """
+    Returns an generic OpenAI LLM client that can be plugged into Swarm and can
+    be used with any LLM model on OpenGradient. The LLM is usually defined in the
+    agent.
+    """
+    return OpenGradientOpenAIClient(private_key=private_key)
 
 __all__ = [
-    'OpenGradientChatModel'
-]
+    'langchain_adapter',
+    'openai_adapter',
+]
+
+__pdoc__ = {
+    'og_langchain': False,
+    'og_openai': False
+}
opengradient/llm/og_openai.py
ADDED
@@ -0,0 +1,121 @@
+from openai.types.chat import ChatCompletion
+import opengradient as og
+from opengradient.defaults import DEFAULT_RPC_URL, DEFAULT_INFERENCE_CONTRACT_ADDRESS
+
+from typing import List
+import time
+import json
+import uuid
+
+class OGCompletions(object):
+    client: og.Client
+
+    def __init__(self, client: og.Client):
+        self.client = client
+
+    def create(
+            self,
+            model: str,
+            messages: List[object],
+            tools: List[object],
+            tool_choice: str,
+            stream: bool = False,
+            parallel_tool_calls: bool = False) -> ChatCompletion:
+
+        # convert OpenAI message format so it's compatible with the SDK
+        sdk_messages = OGCompletions.convert_to_abi_compatible(messages)
+
+        _, finish_reason, chat_completion = self.client.llm_chat(
+            model_cid=model,
+            messages=sdk_messages,
+            max_tokens=200,
+            tools=tools,
+            tool_choice=tool_choice,
+            temperature=0.25,
+            inference_mode=og.LlmInferenceMode.VANILLA
+        )
+
+        choice = {
+            'index': 0,  # Add missing index field
+            'finish_reason': finish_reason,
+            'message': {
+                'role': chat_completion['role'],
+                'content': chat_completion['content'],
+                'tool_calls': [
+                    {
+                        'id': tool_call['id'],
+                        'type': 'function',  # Add missing type field
+                        'function': {  # Add missing function field
+                            'name': tool_call['name'],
+                            'arguments': tool_call['arguments']
+                        }
+                    }
+                    for tool_call in chat_completion.get('tool_calls', [])
+                ]
+            }
+        }
+
+        return ChatCompletion(
+            id=str(uuid.uuid4()),
+            created=int(time.time()),
+            model=model,
+            object='chat.completion',
+            choices=[choice]
+        )
+
+    @staticmethod
+    @staticmethod
+    def convert_to_abi_compatible(messages):
+        sdk_messages = []
+
+        for message in messages:
+            role = message['role']
+            sdk_message = {
+                'role': role
+            }
+
+            if role == 'system':
+                sdk_message['content'] = message['content']
+            elif role == 'user':
+                sdk_message['content'] = message['content']
+            elif role == 'tool':
+                sdk_message['content'] = message['content']
+                sdk_message['tool_call_id'] = message['tool_call_id']
+            elif role == 'assistant':
+                flattened_calls = []
+                for tool_call in message['tool_calls']:
+                    # OpenAI format
+                    flattened_call = {
+                        'id': tool_call['id'],
+                        'name': tool_call['function']['name'],
+                        'arguments': tool_call['function']['arguments']
+                    }
+                    flattened_calls.append(flattened_call)
+
+                sdk_message['tool_calls'] = flattened_calls
+                sdk_message['content'] = message['content']
+
+            sdk_messages.append(sdk_message)
+
+        return sdk_messages
+
+class OGChat(object):
+    completions: OGCompletions
+
+    def __init__(self, client: og.Client):
+        self.completions = OGCompletions(client)
+
+class OpenGradientOpenAIClient(object):
+    """OpenAI client implementation"""
+    client: og.Client
+    chat: OGChat
+
+    def __init__(self, private_key: str):
+        self.client = og.Client(
+            private_key=private_key,
+            rpc_url=DEFAULT_RPC_URL,
+            contract_address=DEFAULT_INFERENCE_CONTRACT_ADDRESS,
+            email=None,
+            password=None
+        )
+        self.chat = OGChat(self.client)
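Together these give two integration paths into the same on-chain client. A hedged usage sketch follows: the private key is a placeholder, the invoke call assumes OpenGradientChatModel implements LangChain's Runnable interface (og_langchain.py is not shown in this diff), and the OpenAI-style call goes through OGCompletions.create above.

```python
import opengradient as og
from opengradient.llm import langchain_adapter, openai_adapter

# LangChain path: a chat model that can be dropped into LangChain agents.
chat_model = langchain_adapter(
    private_key="0x...",                                 # placeholder
    model_cid="meta-llama/Meta-Llama-3-8B-Instruct",
    max_tokens=300)
print(chat_model.invoke("What is ONNX?"))                # assumes Runnable API

# OpenAI-compatible path, e.g. for Swarm-style agent frameworks.
client = openai_adapter(private_key="0x...")             # placeholder
completion = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "hello"}],
    tools=[],            # create() has no defaults for tools/tool_choice
    tool_choice="")
print(completion.choices[0].message.content)
```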
opengradient/types.py
CHANGED
@@ -27,6 +27,10 @@ class InferenceMode:
     ZKML = 1
     TEE = 2
 
+class LlmInferenceMode:
+    VANILLA = 0
+    TEE = 1
+
 @dataclass
 class ModelOutput:
     numbers: List[NumberTensor]
@@ -75,8 +79,15 @@ class Abi:
         return result
 
 class LLM(str, Enum):
+    """Enum for available LLM models"""
+
     META_LLAMA_3_8B_INSTRUCT = "meta-llama/Meta-Llama-3-8B-Instruct"
     LLAMA_3_2_3B_INSTRUCT = "meta-llama/Llama-3.2-3B-Instruct"
     MISTRAL_7B_INSTRUCT_V3 = "mistralai/Mistral-7B-Instruct-v0.3"
     HERMES_3_LLAMA_3_1_70B = "NousResearch/Hermes-3-Llama-3.1-70B"
+    META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
+
+class TEE_LLM(str, Enum):
+    """Enum for LLM models available for TEE execution"""
+
     META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
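A short sketch of how the new types compose, mirroring the TEE guard that client.py now applies. The value-set comparison is used here because bare enum containment checks by value only on Python 3.12+; the mode and model values come from the enums above.

```python
from opengradient.types import LLM, TEE_LLM, LlmInferenceMode

mode = LlmInferenceMode.TEE  # plain int constants: VANILLA = 0, TEE = 1
model_cid = LLM.META_LLAMA_3_1_70B_INSTRUCT.value  # "meta-llama/Llama-3.1-70B-Instruct"

# Only models listed in TEE_LLM may run in TEE mode.
tee_models = {m.value for m in TEE_LLM}
if mode == LlmInferenceMode.TEE and model_cid not in tee_models:
    raise ValueError(f"{model_cid} is not available for TEE inference")
print("TEE inference allowed for", model_cid)
```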
{opengradient-0.3.16.dist-info → opengradient-0.3.18.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: opengradient
-Version: 0.3.16
+Version: 0.3.18
 Summary: Python SDK for OpenGradient decentralized model management & inference services
 Project-URL: Homepage, https://opengradient.ai
 Author-email: OpenGradient <oliver@opengradient.ai>
@@ -25,6 +25,7 @@ License: MIT License
     LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
     OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
     SOFTWARE.
+License-File: LICENSE
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: MIT License
@@ -85,6 +86,7 @@ Requires-Dist: langchain==0.3.7
 Requires-Dist: more-itertools==10.5.0
 Requires-Dist: msgpack==1.1.0
 Requires-Dist: multidict==6.1.0
+Requires-Dist: openai==1.58.1
 Requires-Dist: packaging==24.1
 Requires-Dist: pandas==2.2.3
 Requires-Dist: parsimonious==0.10.0
@@ -215,18 +217,7 @@ tx_hash, finish_reason, message = og.llm_chat(
 )
 ```
 
-### Image Generation
-```python
-tx_hash, image_data = og.generate_image(
-    model="stabilityai/stable-diffusion-xl-base-1.0",
-    prompt="A beautiful sunset over mountains",
-    width=1024,
-    height=1024
-)
 
-with open("generated_image.png", "wb") as f:
-    f.write(image_data)
-```
 
 ## Using the CLI
 
@@ -281,21 +272,4 @@ Or you can use files instead of text input in order to simplify your command:
 opengradient chat --model "mistralai/Mistral-7B-Instruct-v0.3" --messages-file messages.json --tools-file tools.json --max-tokens 200
 ```
 
-### Image Generation
-```bash
-opengradient generate-image \
-    --model "stabilityai/stable-diffusion-xl-base-1.0" \
-    --prompt "A beautiful sunset over mountains" \
-    --output-path sunset.png \
-    --width 1024 \
-    --height 1024
-```
-
-Options:
-- `--model`, `-m`: Model identifier for image generation (required)
-- `--prompt`, `-p`: Text prompt for generating the image (required)
-- `--output-path`, `-o`: Output file path for the generated image (required)
-- `--width`: Output image width in pixels (default: 1024)
-- `--height`: Output image height in pixels (default: 1024)
-
 For more information read the OpenGradient [documentation](https://docs.opengradient.ai/).
opengradient-0.3.18.dist-info/RECORD
ADDED
@@ -0,0 +1,21 @@
+opengradient/__init__.py,sha256=CRII5thG_vRVtEqyhc48IWgYwjif2S4kOBbBf6MgH8I,9262
+opengradient/account.py,sha256=2B7rtCXQDX-yF4U69h8B9-OUreJU4IqoGXG_1Hn9nWs,1150
+opengradient/cli.py,sha256=niN8tlLaiVEpdtkdWEUbxidG75nxrlb6mMUfUAIjiVw,26400
+opengradient/client.py,sha256=axMbfqAQ6OWq3sM_D4bGVWrpAd-15Ru-bOHJ6R6GruA,35197
+opengradient/defaults.py,sha256=6tsW9Z84z6YtITCsULTTgnN0KRUZjfSoeWJZqdWkYCo,384
+opengradient/exceptions.py,sha256=v4VmUGTvvtjhCZAhR24Ga42z3q-DzR1Y5zSqP_yn2Xk,3366
+opengradient/types.py,sha256=-lGWv_yfXMN48bbvARKIFrj1L0AotIwr2c7GOv1JBcI,2464
+opengradient/utils.py,sha256=lUDPmyPqLwpZI-owyN6Rm3QvUjOn5pLN5G1QyriVm-E,6994
+opengradient/abi/inference.abi,sha256=MR5u9npZ-Yx2EqRW17_M-UnGgFF3mMEMepOwaZ-Bkgc,7040
+opengradient/llm/__init__.py,sha256=J1W_AKPntqlDqLeflhn2x7A0i-dkMT-ol3jlEdFgMWU,1135
+opengradient/llm/og_langchain.py,sha256=F32yN1o8EvRbzZSJkUwI0-FmSVAWMuMTI9ho7wgW5hk,4470
+opengradient/llm/og_openai.py,sha256=GilEkIVDac5Cacennb7YNXD6BKwUj69mfnMkvxkiW5Y,3865
+opengradient/proto/__init__.py,sha256=AhaSmrqV0TXGzCKaoPV8-XUvqs2fGAJBM2aOmDpkNbE,55
+opengradient/proto/infer.proto,sha256=13eaEMcppxkBF8yChptsX9HooWFwJKze7oLZNl-LEb8,1217
+opengradient/proto/infer_pb2.py,sha256=wg2vjLQCNv6HRhYuIqgj9xivi3nO4IPz6E5wh2dhDqY,3446
+opengradient/proto/infer_pb2_grpc.py,sha256=y5GYwD1EdNs892xx58jdfyA0fO5QC7k3uZOtImTHMiE,6891
+opengradient-0.3.18.dist-info/METADATA,sha256=eH1-0CRuheW8Ho8mMbiivn2WIrgAKZf37c4zlFMloes,8744
+opengradient-0.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+opengradient-0.3.18.dist-info/entry_points.txt,sha256=yUKTaJx8RXnybkob0J62wVBiCp_1agVbgw9uzsmaeJc,54
+opengradient-0.3.18.dist-info/licenses/LICENSE,sha256=xEcvQ3AxZOtDkrqkys2Mm6Y9diEnaSeQRKvxi-JGnNA,1069
+opengradient-0.3.18.dist-info/RECORD,,
opengradient-0.3.16.dist-info/RECORD
REMOVED
@@ -1,20 +0,0 @@
-opengradient/__init__.py,sha256=IlVjsDPLaC8Zqaj8_QZCzi6-F20Ta-wyefd1mh4k53Q,4015
-opengradient/account.py,sha256=2B7rtCXQDX-yF4U69h8B9-OUreJU4IqoGXG_1Hn9nWs,1150
-opengradient/cli.py,sha256=kdYR_AFKHV99HtO_son7vHpM5jWVZe8FO0iMWxJ7pJE,24444
-opengradient/client.py,sha256=RdlTz60NJKVJihYY6oVYLfNOg6RGnJbfG-2UIxUk-ws,37069
-opengradient/defaults.py,sha256=6tsW9Z84z6YtITCsULTTgnN0KRUZjfSoeWJZqdWkYCo,384
-opengradient/exceptions.py,sha256=v4VmUGTvvtjhCZAhR24Ga42z3q-DzR1Y5zSqP_yn2Xk,3366
-opengradient/types.py,sha256=QTEsygwT5AnIf8Dg9mexvVUe49nCo9N0pgZOOIp3trc,2214
-opengradient/utils.py,sha256=lUDPmyPqLwpZI-owyN6Rm3QvUjOn5pLN5G1QyriVm-E,6994
-opengradient/abi/inference.abi,sha256=MR5u9npZ-Yx2EqRW17_M-UnGgFF3mMEMepOwaZ-Bkgc,7040
-opengradient/llm/__init__.py,sha256=n_11WFPoU8YtGc6wg9cK6gEy9zBISf1183Loip3dAbI,62
-opengradient/llm/chat.py,sha256=F32yN1o8EvRbzZSJkUwI0-FmSVAWMuMTI9ho7wgW5hk,4470
-opengradient/proto/__init__.py,sha256=AhaSmrqV0TXGzCKaoPV8-XUvqs2fGAJBM2aOmDpkNbE,55
-opengradient/proto/infer.proto,sha256=13eaEMcppxkBF8yChptsX9HooWFwJKze7oLZNl-LEb8,1217
-opengradient/proto/infer_pb2.py,sha256=wg2vjLQCNv6HRhYuIqgj9xivi3nO4IPz6E5wh2dhDqY,3446
-opengradient/proto/infer_pb2_grpc.py,sha256=y5GYwD1EdNs892xx58jdfyA0fO5QC7k3uZOtImTHMiE,6891
-opengradient-0.3.16.dist-info/METADATA,sha256=FrkTwhsj8sHnpaDxKVwCJosvfAwljdJ9yFE9KKHkLMk,9536
-opengradient-0.3.16.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
-opengradient-0.3.16.dist-info/entry_points.txt,sha256=yUKTaJx8RXnybkob0J62wVBiCp_1agVbgw9uzsmaeJc,54
-opengradient-0.3.16.dist-info/licenses/LICENSE,sha256=xEcvQ3AxZOtDkrqkys2Mm6Y9diEnaSeQRKvxi-JGnNA,1069
-opengradient-0.3.16.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|