opengradient 0.3.15__tar.gz → 0.3.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {opengradient-0.3.15 → opengradient-0.3.17}/PKG-INFO +1 -1
- {opengradient-0.3.15 → opengradient-0.3.17}/pyproject.toml +1 -1
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/__init__.py +46 -16
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/cli.py +49 -16
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/client.py +75 -128
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/types.py +7 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/.gitignore +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/LICENSE +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/README.md +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/abi/inference.abi +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/account.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/defaults.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/exceptions.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/llm/__init__.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/llm/chat.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/__init__.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/infer.proto +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/infer_pb2.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/infer_pb2_grpc.py +0 -0
- {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/utils.py +0 -0
{opengradient-0.3.15 → opengradient-0.3.17}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: opengradient
-Version: 0.3.15
+Version: 0.3.17
 Summary: Python SDK for OpenGradient decentralized model management & inference services
 Project-URL: Homepage, https://opengradient.ai
 Author-email: OpenGradient <oliver@opengradient.ai>
{opengradient-0.3.15 → opengradient-0.3.17}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "opengradient"
-version = "0.3.15"
+version = "0.3.17"
 description = "Python SDK for OpenGradient decentralized model management & inference services"
 authors = [{name = "OpenGradient", email = "oliver@opengradient.ai"}]
 license = {file = "LICENSE"}
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/__init__.py

@@ -2,10 +2,10 @@ from typing import Dict, List, Optional, Tuple
 
 from .client import Client
 from .defaults import DEFAULT_INFERENCE_CONTRACT_ADDRESS, DEFAULT_RPC_URL
-from .types import InferenceMode, LLM
+from .types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM
 from . import llm
 
-__version__ = "0.3.15"
+__version__ = "0.3.17"
 
 _client = None
 
@@ -40,30 +40,60 @@ def create_version(model_name, notes=None, is_major=False):
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
     return _client.create_version(model_name, notes, is_major)
 
-def infer(model_cid, inference_mode, model_input):
+def infer(model_cid, inference_mode, model_input, max_retries: Optional[int] = None):
+    """
+    Perform inference on a model.
+
+    Args:
+        model_cid: Model CID to use for inference
+        inference_mode: Mode of inference (e.g. VANILLA)
+        model_input: Input data for the model
+        max_retries: Optional maximum number of retry attempts for transaction errors
+
+    Returns:
+        Tuple of (transaction hash, model output)
+    """
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
-    return _client.infer(model_cid, inference_mode, model_input)
+    return _client.infer(model_cid, inference_mode, model_input, max_retries=max_retries)
 
 def llm_completion(model_cid: LLM,
-                   prompt: str,
-                   max_tokens: int = 100,
-                   stop_sequence: Optional[List[str]] = None,
-                   temperature: float = 0.0) -> Tuple[str, str]:
+                   prompt: str,
+                   inference_mode: str = LlmInferenceMode.VANILLA,
+                   max_tokens: int = 100,
+                   stop_sequence: Optional[List[str]] = None,
+                   temperature: float = 0.0,
+                   max_retries: Optional[int] = None) -> Tuple[str, str]:
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
-    return _client.llm_completion(model_cid,
+    return _client.llm_completion(model_cid=model_cid,
+                                  inference_mode=inference_mode,
+                                  prompt=prompt,
+                                  max_tokens=max_tokens,
+                                  stop_sequence=stop_sequence,
+                                  temperature=temperature,
+                                  max_retries=max_retries)
 
 def llm_chat(model_cid: LLM,
-             messages: List[Dict],
-             max_tokens: int = 100,
-             stop_sequence: Optional[List[str]] = None,
-             temperature: float = 0.0,
-             tools: Optional[List[Dict]] = [],
-             tool_choice: Optional[str] = None) -> Tuple[str, str, Dict]:
+             messages: List[Dict],
+             inference_mode: str = LlmInferenceMode.VANILLA,
+             max_tokens: int = 100,
+             stop_sequence: Optional[List[str]] = None,
+             temperature: float = 0.0,
+             tools: Optional[List[Dict]] = None,
+             tool_choice: Optional[str] = None,
+             max_retries: Optional[int] = None) -> Tuple[str, str, Dict]:
     if _client is None:
         raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
-    return _client.llm_chat(model_cid,
+    return _client.llm_chat(model_cid=model_cid,
+                            inference_mode=inference_mode,
+                            messages=messages,
+                            max_tokens=max_tokens,
+                            stop_sequence=stop_sequence,
+                            temperature=temperature,
+                            tools=tools,
+                            tool_choice=tool_choice,
+                            max_retries=max_retries)
 
 def login(email: str, password: str):
     if _client is None:
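Net effect of the __init__.py changes: the module-level helpers gain an optional max_retries, the LLM helpers gain an explicit inference_mode, and the wrappers now forward everything by keyword. A hedged sketch of the 0.3.17 surface (the model CID and input below are placeholders, and og.init's arguments are not part of this diff):

import opengradient as og
from opengradient import InferenceMode, LlmInferenceMode, LLM

def demo():
    # Assumes og.init(...) has already been called; its signature is unchanged and not shown here.

    # Model inference, now with an optional retry budget for transaction errors.
    tx_hash, model_output = og.infer(
        model_cid="Qm...",                         # placeholder model CID
        inference_mode=InferenceMode.VANILLA,
        model_input={"input": [1.0, 2.0, 3.0]},    # illustrative input; shape depends on the model
        max_retries=3,
    )

    # LLM chat, now routed through LlmInferenceMode (VANILLA by default; TEE requires a TEE_LLM model).
    tx_hash, finish_reason, message = og.llm_chat(
        model_cid=LLM.META_LLAMA_3_1_70B_INSTRUCT,
        inference_mode=LlmInferenceMode.VANILLA,
        messages=[{"role": "user", "content": "hello"}],
        max_tokens=50,
        max_retries=3,
    )
    return finish_reason, message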
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/cli.py

@@ -20,7 +20,7 @@ from .defaults import (
     DEFAULT_OG_FAUCET_URL,
     DEFAULT_RPC_URL,
 )
-from .types import InferenceMode
+from .types import InferenceMode, LlmInferenceMode
 
 OG_CONFIG_FILE = Path.home() / '.opengradient_config.json'
 
@@ -65,11 +65,17 @@ InferenceModes = {
     "TEE": InferenceMode.TEE,
 }
 
+LlmInferenceModes = {
+    "VANILLA": LlmInferenceMode.VANILLA,
+    "TEE": LlmInferenceMode.TEE,
+}
+
 # Supported LLMs
 LlmModels = {
     "meta-llama/Meta-Llama-3-8B-Instruct",
     "meta-llama/Llama-3.2-3B-Instruct",
-    "mistralai/Mistral-7B-Instruct-v0.3"
+    "mistralai/Mistral-7B-Instruct-v0.3",
+    "meta-llama/Llama-3.1-70B-Instruct",
 }
 
 def initialize_config(ctx):
@@ -339,13 +345,15 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
         click.echo(f"Error running inference: {str(e)}")
 
 @cli.command()
-@click.option('--model', '-m', 'model_cid', type=click.Choice(
+@click.option('--model', '-m', 'model_cid', type=click.Choice([e.value for e in types.LLM]), required=True, help='CID of the LLM model to run inference on')
+@click.option('--mode', 'inference_mode', type=click.Choice(LlmInferenceModes.keys()), default="VANILLA",
+              help='Inference mode (default: VANILLA)')
 @click.option('--prompt', '-p', required=True, help='Input prompt for the LLM completion')
 @click.option('--max-tokens', type=int, default=100, help='Maximum number of tokens for LLM completion output')
 @click.option('--stop-sequence', multiple=True, help='Stop sequences for LLM')
 @click.option('--temperature', type=float, default=0.0, help='Temperature for LLM inference (0.0 to 1.0)')
 @click.pass_context
-def completion(ctx, model_cid: str, prompt: str, max_tokens: int, stop_sequence: List[str], temperature: float):
+def completion(ctx, model_cid: str, inference_mode: str, prompt: str, max_tokens: int, stop_sequence: List[str], temperature: float):
     """
     Run completion inference on an LLM model.
 
@@ -355,13 +363,14 @@ def completion(ctx, model_cid: str, prompt: str, max_tokens: int, stop_sequence:
 
     \b
     opengradient completion --model meta-llama/Meta-Llama-3-8B-Instruct --prompt "Hello, how are you?" --max-tokens 50 --temperature 0.7
-    opengradient completion -m meta-llama/Meta-Llama-3-8B-Instruct -p "Translate to French: Hello world" --stop-sequence "." --stop-sequence "
+    opengradient completion -m meta-llama/Meta-Llama-3-8B-Instruct -p "Translate to French: Hello world" --stop-sequence "." --stop-sequence "\\n"
     """
     client: Client = ctx.obj['client']
     try:
         click.echo(f"Running LLM completion inference for model \"{model_cid}\"\n")
         tx_hash, llm_output = client.llm_completion(
             model_cid=model_cid,
+            inference_mode=LlmInferenceModes[inference_mode],
             prompt=prompt,
             max_tokens=max_tokens,
             stop_sequence=list(stop_sequence),
@@ -394,6 +403,9 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output):
               type=click.Choice([e.value for e in types.LLM]),
               required=True,
               help='CID of the LLM model to run inference on')
+@click.option('--mode', 'inference_mode', type=click.Choice(LlmInferenceModes.keys()),
+              default="VANILLA",
+              help='Inference mode (default: VANILLA)')
 @click.option('--messages',
               type=str,
               required=False,
@@ -431,6 +443,7 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output):
 def chat(
     ctx,
     model_cid: str,
+    inference_mode: str,
     messages: Optional[str],
     messages_file: Optional[Path],
     max_tokens: int,
@@ -444,11 +457,13 @@ def chat(
 
     This command runs a chat inference on the specified LLM model using the provided messages and parameters.
 
+    Tool call formatting is based on OpenAI documentation tool calls (see here: https://platform.openai.com/docs/guides/function-calling).
+
     Example usage:
 
     \b
     opengradient chat --model meta-llama/Meta-Llama-3-8B-Instruct --messages '[{"role":"user","content":"hello"}]' --max-tokens 50 --temperature 0.7
-    opengradient chat
+    opengradient chat --model mistralai/Mistral-7B-Instruct-v0.3 --messages-file messages.json --tools-file tools.json --max-tokens 200 --stop-sequence "." --stop-sequence "\\n"
     """
     client: Client = ctx.obj['client']
     try:
@@ -458,7 +473,7 @@ def chat(
         ctx.exit(1)
         return
     if messages and messages_file:
-        click.echo("Cannot have both messages and
+        click.echo("Cannot have both messages and messages-file")
         ctx.exit(1)
         return
 
@@ -473,9 +488,9 @@ def chat(
             messages = json.load(file)
 
     # Parse tools if provided
-    if tools
-        click.echo("Cannot have both tools and
-
+    if (tools and tools != '[]') and tools_file:
+        click.echo("Cannot have both tools and tools-file")
+        click.exit(1)
         return
 
     parsed_tools=[]
@@ -509,6 +524,7 @@ def chat(
 
         tx_hash, finish_reason, llm_chat_output = client.llm_chat(
             model_cid=model_cid,
+            inference_mode=LlmInferenceModes[inference_mode],
             messages=messages,
             max_tokens=max_tokens,
             stop_sequence=list(stop_sequence),
@@ -517,15 +533,32 @@ def chat(
             tool_choice=tool_choice,
         )
 
-
-        print("TX Hash: ", tx_hash)
-        print("Finish reason: ", finish_reason)
-        print("Chat output: ", llm_chat_output)
+        print_llm_chat_result(model_cid, tx_hash, finish_reason, llm_chat_output)
     except Exception as e:
         click.echo(f"Error running LLM chat inference: {str(e)}")
 
-def print_llm_chat_result():
-
+def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output):
+    click.secho("✅ LLM Chat Successful", fg="green", bold=True)
+    click.echo("──────────────────────────────────────")
+    click.echo("Model CID: ", nl=False)
+    click.secho(model_cid, fg="cyan", bold=True)
+    click.echo("Transaction hash: ", nl=False)
+    click.secho(tx_hash, fg="cyan", bold=True)
+    block_explorer_link = f"{DEFAULT_BLOCKCHAIN_EXPLORER}0x{tx_hash}"
+    click.echo("Block explorer link: ", nl=False)
+    click.secho(block_explorer_link, fg="blue", underline=True)
+    click.echo("──────────────────────────────────────")
+    click.secho("Finish Reason: ", fg="yellow", bold=True)
+    click.echo()
+    click.echo(finish_reason)
+    click.echo()
+    click.secho("Chat Output:", fg="yellow", bold=True)
+    click.echo()
+    for key, value in chat_output.items():
+        # If the value doesn't give any information, don't print it
+        if value != None and value != "" and value != '[]' and value != []:
+            click.echo(f"{key}: {value}")
+    click.echo()
 
 @cli.command()
 def create_account():
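The new chat example above references a messages.json and a tools.json; per the updated docstring, tools follow the OpenAI function-calling layout, and client.py (further down in this diff) reads tool['function']['name'], ['description'] and ['parameters'] from each entry. A hedged sketch of writing a compatible tools.json (the tool itself is made up for illustration):

import json

tools = [
    {
        "type": "function",                 # OpenAI-style wrapper; the SDK reads the "function" object
        "function": {
            "name": "get_weather",          # hypothetical tool
            "description": "Look up the current weather for a city",
            "parameters": {                 # JSON Schema; serialized to a string for the contract
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

with open("tools.json", "w") as f:
    json.dump(tools, f, indent=2)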
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/client.py

@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import random
 from typing import Dict, List, Optional, Tuple, Union
 
 import firebase
@@ -12,7 +13,7 @@ from web3.logs import DISCARD
 
 from opengradient import utils
 from opengradient.exceptions import OpenGradientError
-from opengradient.types import InferenceMode, LLM
+from opengradient.types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM
 
 import grpc
 import time
@@ -23,6 +24,31 @@ from opengradient.proto import infer_pb2
 from opengradient.proto import infer_pb2_grpc
 from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT
 
+from functools import wraps
+
+def run_with_retry(txn_function, max_retries=5):
+    """
+    Execute a blockchain transaction with retry logic.
+
+    Args:
+        txn_function: Function that executes the transaction
+        max_retries (int): Maximum number of retry attempts
+    """
+    last_error = None
+    for attempt in range(max_retries):
+        try:
+            return txn_function()
+        except Exception as e:
+            last_error = e
+            if attempt < max_retries - 1:
+                if "nonce too low" in str(e) or "nonce too high" in str(e):
+                    time.sleep(1)  # Wait before retry
+                    continue
+            # If it's not a nonce error, raise immediately
+            raise
+    # If we've exhausted all retries, raise the last error
+    raise OpenGradientError(f"Transaction failed after {max_retries} attempts: {str(last_error)}")
+
 class Client:
     FIREBASE_CONFIG = {
         "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
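The max_retries parameters added throughout this release all funnel into this new run_with_retry helper: each client method now wraps its transaction in a zero-argument closure and passes it here, so nonce races ("nonce too low" / "nonce too high") are retried after a short sleep while other errors propagate immediately. A minimal sketch of the call pattern, using a hypothetical stand-in closure instead of a real web3 transaction:

from opengradient.client import run_with_retry

attempts = {"count": 0}

def fake_transaction():
    # Hypothetical stand-in for the execute_transaction closures built by infer/llm_completion/llm_chat.
    attempts["count"] += 1
    if attempts["count"] == 1:
        raise Exception("nonce too low")  # first attempt simulates a stale nonce
    return "0xabc123"                     # the real closures return tx_hash.hex() plus parsed outputs

tx_hash = run_with_retry(fake_transaction, max_retries=5)
print(tx_hash, "after", attempts["count"], "attempts")  # 0xabc123 after 2 attempts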
@@ -311,7 +337,8 @@
         self,
         model_cid: str,
         inference_mode: InferenceMode,
-        model_input: Dict[str, Union[str, int, float, List, np.ndarray]]
+        model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
+        max_retries: Optional[int] = None
     ) -> Tuple[str, Dict[str, np.ndarray]]:
         """
         Perform inference on a model.
@@ -320,6 +347,7 @@
             model_cid (str): The unique content identifier for the model from IPFS.
             inference_mode (InferenceMode): The inference mode.
             model_input (Dict[str, Union[str, int, float, List, np.ndarray]]): The input data for the model.
+            max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.
 
         Returns:
             Tuple[str, Dict[str, np.ndarray]]: The transaction hash and the model output.
@@ -327,46 +355,22 @@
         Raises:
             OpenGradientError: If the inference fails.
         """
-
-        try:
-            logging.debug("Entering infer method")
+        def execute_transaction():
             self._initialize_web3()
-            logging.debug(f"Web3 initialized. Connected: {self._w3.is_connected()}")
-
-            logging.debug(f"Creating contract instance. Address: {self.contract_address}")
             contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
-
-
-            logging.debug(f"Model ID: {model_cid}")
-            logging.debug(f"Inference Mode: {inference_mode}")
-            logging.debug(f"Model Input: {model_input}")
-
-            # Convert InferenceMode to uint8
+
             inference_mode_uint8 = int(inference_mode)
-
-            # Prepare ModelInput tuple
             converted_model_input = utils.convert_to_model_input(model_input)
-
-
-            logging.debug("Preparing run function")
+
             run_function = contract.functions.run(
                 model_cid,
                 inference_mode_uint8,
                 converted_model_input
             )
-            logging.debug("Run function prepared successfully")
 
-
-            nonce = self._w3.eth.get_transaction_count(self.wallet_address)
-            logging.debug(f"Nonce: {nonce}")
-
-            # Estimate gas
+            nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
             estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
-            logging.debug(f"Estimated gas: {estimated_gas}")
-
-            # Increase gas limit by 20%
             gas_limit = int(estimated_gas * 3)
-            logging.debug(f"Gas limit set to: {gas_limit}")
 
             transaction = run_function.build_transaction({
                 'from': self.wallet_address,
@@ -375,62 +379,36 @@
                 'gasPrice': self._w3.eth.gas_price,
             })
 
-            logging.debug(f"Transaction built: {transaction}")
-
-            # Sign transaction
             signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
-            logging.debug("Transaction signed successfully")
-
-            # Send transaction
             tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
-            logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
-            # Wait for transaction receipt
             tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
-            logging.debug(f"Transaction receipt received: {tx_receipt}")
 
-            # Check if the transaction was successful
             if tx_receipt['status'] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
 
-            # Process the InferenceResult event
             parsed_logs = contract.events.InferenceResult().process_receipt(tx_receipt, errors=DISCARD)
-
             if len(parsed_logs) < 1:
                 raise OpenGradientError("InferenceResult event not found in transaction logs")
-            inference_result = parsed_logs[0]
-
-            # Extract the ModelOutput from the event
-            event_data = inference_result['args']
-            logging.debug(f"Raw event data: {event_data}")
-
-            try:
-                model_output = utils.convert_to_model_output(event_data)
-                logging.debug(f"Parsed ModelOutput: {model_output}")
-            except Exception as e:
-                logging.error(f"Error parsing event data: {str(e)}", exc_info=True)
-                raise OpenGradientError(f"Failed to parse event data: {str(e)}")
 
+            model_output = utils.convert_to_model_output(parsed_logs[0]['args'])
             return tx_hash.hex(), model_output
 
-
-
-            raise OpenGradientError(f"Inference failed due to contract logic error: {str(e)}")
-        except Exception as e:
-            logging.error(f"Error in infer method: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"Inference failed: {str(e)}")
-
+        return run_with_retry(execute_transaction, max_retries or 5)
+
     def llm_completion(self,
                        model_cid: LLM,
+                       inference_mode: InferenceMode,
                        prompt: str,
                        max_tokens: int = 100,
                        stop_sequence: Optional[List[str]] = None,
-                       temperature: float = 0.0
+                       temperature: float = 0.0,
+                       max_retries: Optional[int] = None) -> Tuple[str, str]:
         """
         Perform inference on an LLM model using completions.
 
         Args:
             model_cid (LLM): The unique content identifier for the model.
+            inference_mode (InferenceMode): The inference mode.
             prompt (str): The input prompt for the LLM.
             max_tokens (int): Maximum number of tokens for LLM output. Default is 100.
             stop_sequence (List[str], optional): List of stop sequences for LLM. Default is None.
@@ -442,17 +420,20 @@
         Raises:
             OpenGradientError: If the inference fails.
         """
-
-
+        def execute_transaction():
+            # Check inference mode and supported model
+            if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
+                raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
 
-
-
-
-
+            if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
+                raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
+
+            self._initialize_web3()
+            contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
 
             # Prepare LLM input
             llm_request = {
-                "mode":
+                "mode": inference_mode,
                 "modelCID": model_cid,
                 "prompt": prompt,
                 "max_tokens": max_tokens,
@@ -461,11 +442,9 @@
             }
             logging.debug(f"Prepared LLM request: {llm_request}")
 
-            # Prepare run function
             run_function = contract.functions.runLLMCompletion(llm_request)
 
-
-            nonce = self._w3.eth.get_transaction_count(self.wallet_address)
+            nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
             estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
             gas_limit = int(estimated_gas * 1.2)
 
@@ -476,47 +455,38 @@
                 'gasPrice': self._w3.eth.gas_price,
             })
 
-            # Sign and send transaction
             signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
             tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
-            logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
-            # Wait for transaction receipt
             tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
 
             if tx_receipt['status'] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
 
-            # Process the LLMResult event
             parsed_logs = contract.events.LLMCompletionResult().process_receipt(tx_receipt, errors=DISCARD)
-
             if len(parsed_logs) < 1:
                 raise OpenGradientError("LLM completion result event not found in transaction logs")
-            llm_result = parsed_logs[0]
 
-            llm_answer =
+            llm_answer = parsed_logs[0]['args']['response']['answer']
             return tx_hash.hex(), llm_answer
 
-
-
-            raise OpenGradientError(f"LLM inference failed due to contract logic error: {str(e)}")
-        except Exception as e:
-            logging.error(f"Error in infer completion method: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"LLM inference failed: {str(e)}")
-
+        return run_with_retry(execute_transaction, max_retries or 5)
+
     def llm_chat(self,
                  model_cid: str,
+                 inference_mode: InferenceMode,
                  messages: List[Dict],
                  max_tokens: int = 100,
                  stop_sequence: Optional[List[str]] = None,
                  temperature: float = 0.0,
                  tools: Optional[List[Dict]] = [],
-                 tool_choice: Optional[str] = None
+                 tool_choice: Optional[str] = None,
+                 max_retries: Optional[int] = None) -> Tuple[str, str]:
         """
         Perform inference on an LLM model using chat.
 
         Args:
             model_cid (LLM): The unique content identifier for the model.
+            inference_mode (InferenceMode): The inference mode.
             messages (dict): The messages that will be passed into the chat.
                 This should be in OpenAI API format (https://platform.openai.com/docs/api-reference/chat/create)
                 Example:
@@ -567,13 +537,16 @@
         Raises:
             OpenGradientError: If the inference fails.
         """
-
-
+        def execute_transaction():
+            # Check inference mode and supported model
+            if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
+                raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
+
+            if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
+                raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
 
-
-
-            llm_abi = json.load(abi_file)
-            contract = self._w3.eth.contract(address=self.contract_address, abi=llm_abi)
+            self._initialize_web3()
+            contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
 
             # For incoming chat messages, tool_calls can be empty. Add an empty array so that it will fit the ABI.
             for message in messages:
@@ -585,17 +558,10 @@
                     message['name'] = ""
 
             # Create simplified tool structure for smart contract
-            #
-            # struct ToolDefinition {
-            #     string description;
-            #     string name;
-            #     string parameters; // This must be a JSON
-            # }
             converted_tools = []
             if tools is not None:
                 for tool in tools:
                     function = tool['function']
-
                     converted_tool = {}
                     converted_tool['name'] = function['name']
                     converted_tool['description'] = function['description']
@@ -604,12 +570,11 @@
                         converted_tool['parameters'] = json.dumps(parameters)
                     except Exception as e:
                         raise OpenGradientError("Chat LLM failed to convert parameters into JSON: %s", e)
-
                     converted_tools.append(converted_tool)
 
             # Prepare LLM input
             llm_request = {
-                "mode":
+                "mode": inference_mode,
                 "modelCID": model_cid,
                 "messages": messages,
                 "max_tokens": max_tokens,
@@ -620,11 +585,9 @@
             }
             logging.debug(f"Prepared LLM request: {llm_request}")
 
-            # Prepare run function
             run_function = contract.functions.runLLMChat(llm_request)
 
-
-            nonce = self._w3.eth.get_transaction_count(self.wallet_address)
+            nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
             estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
             gas_limit = int(estimated_gas * 1.2)
 
@@ -635,41 +598,25 @@
                 'gasPrice': self._w3.eth.gas_price,
             })
 
-            # Sign and send transaction
             signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
             tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
-            logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
-            # Wait for transaction receipt
             tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
 
             if tx_receipt['status'] == 0:
                 raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")
 
-            # Process the LLMResult event
             parsed_logs = contract.events.LLMChatResult().process_receipt(tx_receipt, errors=DISCARD)
-
             if len(parsed_logs) < 1:
                 raise OpenGradientError("LLM chat result event not found in transaction logs")
-            llm_result = parsed_logs[0]['args']['response']
 
-
+            llm_result = parsed_logs[0]['args']['response']
             message = dict(llm_result['message'])
-            if (tool_calls := message.get('tool_calls'))
-
-                for tool_call in tool_calls:
-                    new_tool_calls.append(dict(tool_call))
-                message['tool_calls'] = new_tool_calls
-
-            return (tx_hash.hex(), llm_result['finish_reason'], message)
+            if (tool_calls := message.get('tool_calls')) is not None:
+                message['tool_calls'] = [dict(tool_call) for tool_call in tool_calls]
 
-
-            logging.error(f"Contract logic error: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"LLM inference failed due to contract logic error: {str(e)}")
-        except Exception as e:
-            logging.error(f"Error in infer chat method: {str(e)}", exc_info=True)
-            raise OpenGradientError(f"LLM inference failed: {str(e)}")
+            return tx_hash.hex(), llm_result['finish_reason'], message
 
+        return run_with_retry(execute_transaction, max_retries or 5)
 
     def list_files(self, model_name: str, version: str) -> List[Dict]:
         """
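With the client.py changes in place, the lower-level Client methods mirror the module-level helpers: llm_completion and llm_chat now take an explicit inference_mode (VANILLA or TEE; TEE additionally requires the model to appear in TEE_LLM) and an optional max_retries that feeds run_with_retry. A hedged sketch against an already-constructed Client (the constructor is unchanged and not shown in this diff; the model and prompt are illustrative):

from opengradient.client import Client
from opengradient.types import LLM, LlmInferenceMode

def quick_completion(client: Client):
    # Sketch only: assumes `client` was built the same way as in 0.3.15.
    tx_hash, answer = client.llm_completion(
        model_cid=LLM.MISTRAL_7B_INSTRUCT_V3,
        inference_mode=LlmInferenceMode.VANILLA,  # anything other than VANILLA/TEE raises OpenGradientError
        prompt="Say hello in French.",
        max_tokens=32,
        max_retries=3,                            # retried via run_with_retry on nonce errors
    )
    return tx_hash, answer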
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/types.py

@@ -27,6 +27,10 @@ class InferenceMode:
     ZKML = 1
     TEE = 2
 
+class LlmInferenceMode:
+    VANILLA = 0
+    TEE = 1
+
 @dataclass
 class ModelOutput:
     numbers: List[NumberTensor]
@@ -79,4 +83,7 @@ class LLM(str, Enum):
     LLAMA_3_2_3B_INSTRUCT = "meta-llama/Llama-3.2-3B-Instruct"
     MISTRAL_7B_INSTRUCT_V3 = "mistralai/Mistral-7B-Instruct-v0.3"
     HERMES_3_LLAMA_3_1_70B = "NousResearch/Hermes-3-Llama-3.1-70B"
+    META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
+
+class TEE_LLM(str, Enum):
     META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
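The new types are small: LlmInferenceMode is a plain class with integer constants (mirroring InferenceMode), LLM gains a Llama-3.1-70B-Instruct entry, and TEE_LLM enumerates the models currently eligible for TEE inference. A quick sketch of what they expose, with the expected values noted in comments:

from opengradient.types import LlmInferenceMode, LLM, TEE_LLM

print(LlmInferenceMode.VANILLA, LlmInferenceMode.TEE)       # 0 1
print(LLM.META_LLAMA_3_1_70B_INSTRUCT.value)                # meta-llama/Llama-3.1-70B-Instruct

# Models currently eligible for TEE inference (just Llama-3.1-70B-Instruct in 0.3.17).
tee_eligible = {model.value for model in TEE_LLM}
print("meta-llama/Llama-3.1-70B-Instruct" in tee_eligible)  # True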
The remaining files listed above with +0 -0 are unchanged between 0.3.15 and 0.3.17.
|