opengradient 0.3.15__tar.gz → 0.3.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. {opengradient-0.3.15 → opengradient-0.3.17}/PKG-INFO +1 -1
  2. {opengradient-0.3.15 → opengradient-0.3.17}/pyproject.toml +1 -1
  3. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/__init__.py +46 -16
  4. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/cli.py +49 -16
  5. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/client.py +75 -128
  6. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/types.py +7 -0
  7. {opengradient-0.3.15 → opengradient-0.3.17}/.gitignore +0 -0
  8. {opengradient-0.3.15 → opengradient-0.3.17}/LICENSE +0 -0
  9. {opengradient-0.3.15 → opengradient-0.3.17}/README.md +0 -0
  10. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/abi/inference.abi +0 -0
  11. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/account.py +0 -0
  12. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/defaults.py +0 -0
  13. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/exceptions.py +0 -0
  14. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/llm/__init__.py +0 -0
  15. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/llm/chat.py +0 -0
  16. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/__init__.py +0 -0
  17. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/infer.proto +0 -0
  18. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/infer_pb2.py +0 -0
  19. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/proto/infer_pb2_grpc.py +0 -0
  20. {opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/utils.py +0 -0
{opengradient-0.3.15 → opengradient-0.3.17}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: opengradient
- Version: 0.3.15
+ Version: 0.3.17
  Summary: Python SDK for OpenGradient decentralized model management & inference services
  Project-URL: Homepage, https://opengradient.ai
  Author-email: OpenGradient <oliver@opengradient.ai>
{opengradient-0.3.15 → opengradient-0.3.17}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "opengradient"
- version = "0.3.15"
+ version = "0.3.17"
  description = "Python SDK for OpenGradient decentralized model management & inference services"
  authors = [{name = "OpenGradient", email = "oliver@opengradient.ai"}]
  license = {file = "LICENSE"}
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/__init__.py
@@ -2,10 +2,10 @@ from typing import Dict, List, Optional, Tuple

  from .client import Client
  from .defaults import DEFAULT_INFERENCE_CONTRACT_ADDRESS, DEFAULT_RPC_URL
- from .types import InferenceMode, LLM
+ from .types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM
  from . import llm

- __version__ = "0.3.15"
+ __version__ = "0.3.17"

  _client = None

@@ -40,30 +40,60 @@ def create_version(model_name, notes=None, is_major=False):
  raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
  return _client.create_version(model_name, notes, is_major)

- def infer(model_cid, inference_mode, model_input):
+ def infer(model_cid, inference_mode, model_input, max_retries: Optional[int] = None):
+ """
+ Perform inference on a model.
+
+ Args:
+ model_cid: Model CID to use for inference
+ inference_mode: Mode of inference (e.g. VANILLA)
+ model_input: Input data for the model
+ max_retries: Optional maximum number of retry attempts for transaction errors
+
+ Returns:
+ Tuple of (transaction hash, model output)
+ """
  if _client is None:
  raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
- return _client.infer(model_cid, inference_mode, model_input)
+ return _client.infer(model_cid, inference_mode, model_input, max_retries=max_retries)

  def llm_completion(model_cid: LLM,
- prompt: str,
- max_tokens: int = 100,
- stop_sequence: Optional[List[str]] = None,
- temperature: float = 0.0) -> Tuple[str, str]:
+ prompt: str,
+ inference_mode: str = LlmInferenceMode.VANILLA,
+ max_tokens: int = 100,
+ stop_sequence: Optional[List[str]] = None,
+ temperature: float = 0.0,
+ max_retries: Optional[int] = None) -> Tuple[str, str]:
  if _client is None:
  raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
- return _client.llm_completion(model_cid, prompt, max_tokens, stop_sequence, temperature)
+ return _client.llm_completion(model_cid=model_cid,
+ inference_mode=inference_mode,
+ prompt=prompt,
+ max_tokens=max_tokens,
+ stop_sequence=stop_sequence,
+ temperature=temperature,
+ max_retries=max_retries)

  def llm_chat(model_cid: LLM,
- messages: List[Dict],
- max_tokens: int = 100,
- stop_sequence: Optional[List[str]] = None,
- temperature: float = 0.0,
- tools: Optional[List[Dict]] = None,
- tool_choice: Optional[str] = None):
+ messages: List[Dict],
+ inference_mode: str = LlmInferenceMode.VANILLA,
+ max_tokens: int = 100,
+ stop_sequence: Optional[List[str]] = None,
+ temperature: float = 0.0,
+ tools: Optional[List[Dict]] = None,
+ tool_choice: Optional[str] = None,
+ max_retries: Optional[int] = None) -> Tuple[str, str, Dict]:
  if _client is None:
  raise RuntimeError("OpenGradient client not initialized. Call og.init() first.")
- return _client.llm_chat(model_cid, messages, max_tokens, stop_sequence, temperature, tools, tool_choice)
+ return _client.llm_chat(model_cid=model_cid,
+ inference_mode=inference_mode,
+ messages=messages,
+ max_tokens=max_tokens,
+ stop_sequence=stop_sequence,
+ temperature=temperature,
+ tools=tools,
+ tool_choice=tool_choice,
+ max_retries=max_retries)

  def login(email: str, password: str):
  if _client is None:
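
The restructured module-level functions above simply forward the new inference_mode and max_retries arguments to the underlying Client. A minimal usage sketch of the 0.3.17 surface follows; the og.init() credentials and the model CID are illustrative assumptions, not values taken from this diff:

    import opengradient as og

    # Assumed initialization; the exact og.init() signature is not shown in this diff.
    og.init(private_key="<PRIVATE_KEY>", email="you@example.com", password="<PASSWORD>")

    # Inference with the new optional retry knob: nonce conflicts are retried,
    # other transaction errors are raised immediately (see run_with_retry in client.py).
    tx_hash, model_output = og.infer(
        model_cid="Qm...",                        # placeholder model CID
        inference_mode=og.InferenceMode.VANILLA,
        model_input={"input": [1.0, 2.0, 3.0]},
        max_retries=3,
    )
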
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/cli.py
@@ -20,7 +20,7 @@ from .defaults import (
  DEFAULT_OG_FAUCET_URL,
  DEFAULT_RPC_URL,
  )
- from .types import InferenceMode
+ from .types import InferenceMode, LlmInferenceMode

  OG_CONFIG_FILE = Path.home() / '.opengradient_config.json'

@@ -65,11 +65,17 @@ InferenceModes = {
  "TEE": InferenceMode.TEE,
  }

+ LlmInferenceModes = {
+ "VANILLA": LlmInferenceMode.VANILLA,
+ "TEE": LlmInferenceMode.TEE,
+ }
+
  # Supported LLMs
  LlmModels = {
  "meta-llama/Meta-Llama-3-8B-Instruct",
  "meta-llama/Llama-3.2-3B-Instruct",
- "mistralai/Mistral-7B-Instruct-v0.3"
+ "mistralai/Mistral-7B-Instruct-v0.3",
+ "meta-llama/Llama-3.1-70B-Instruct",
  }

  def initialize_config(ctx):
@@ -339,13 +345,15 @@ def infer(ctx, model_cid: str, inference_mode: str, input_data, input_file: Path
  click.echo(f"Error running inference: {str(e)}")

  @cli.command()
- @click.option('--model', '-m', 'model_cid', type=click.Choice(LlmModels), required=True, help='CID of the LLM model to run inference on')
+ @click.option('--model', '-m', 'model_cid', type=click.Choice([e.value for e in types.LLM]), required=True, help='CID of the LLM model to run inference on')
+ @click.option('--mode', 'inference_mode', type=click.Choice(LlmInferenceModes.keys()), default="VANILLA",
+ help='Inference mode (default: VANILLA)')
  @click.option('--prompt', '-p', required=True, help='Input prompt for the LLM completion')
  @click.option('--max-tokens', type=int, default=100, help='Maximum number of tokens for LLM completion output')
  @click.option('--stop-sequence', multiple=True, help='Stop sequences for LLM')
  @click.option('--temperature', type=float, default=0.0, help='Temperature for LLM inference (0.0 to 1.0)')
  @click.pass_context
- def completion(ctx, model_cid: str, prompt: str, max_tokens: int, stop_sequence: List[str], temperature: float):
+ def completion(ctx, model_cid: str, inference_mode: str, prompt: str, max_tokens: int, stop_sequence: List[str], temperature: float):
  """
  Run completion inference on an LLM model.

@@ -355,13 +363,14 @@ def completion(ctx, model_cid: str, prompt: str, max_tokens: int, stop_sequence:

  \b
  opengradient completion --model meta-llama/Meta-Llama-3-8B-Instruct --prompt "Hello, how are you?" --max-tokens 50 --temperature 0.7
- opengradient completion -m meta-llama/Meta-Llama-3-8B-Instruct -p "Translate to French: Hello world" --stop-sequence "." --stop-sequence "\n"
+ opengradient completion -m meta-llama/Meta-Llama-3-8B-Instruct -p "Translate to French: Hello world" --stop-sequence "." --stop-sequence "\\n"
  """
  client: Client = ctx.obj['client']
  try:
  click.echo(f"Running LLM completion inference for model \"{model_cid}\"\n")
  tx_hash, llm_output = client.llm_completion(
  model_cid=model_cid,
+ inference_mode=LlmInferenceModes[inference_mode],
  prompt=prompt,
  max_tokens=max_tokens,
  stop_sequence=list(stop_sequence),
@@ -394,6 +403,9 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output):
  type=click.Choice([e.value for e in types.LLM]),
  required=True,
  help='CID of the LLM model to run inference on')
+ @click.option('--mode', 'inference_mode', type=click.Choice(LlmInferenceModes.keys()),
+ default="VANILLA",
+ help='Inference mode (default: VANILLA)')
  @click.option('--messages',
  type=str,
  required=False,
@@ -431,6 +443,7 @@ def print_llm_completion_result(model_cid, tx_hash, llm_output):
  def chat(
  ctx,
  model_cid: str,
+ inference_mode: str,
  messages: Optional[str],
  messages_file: Optional[Path],
  max_tokens: int,
@@ -444,11 +457,13 @@ def chat(

  This command runs a chat inference on the specified LLM model using the provided messages and parameters.

+ Tool call formatting is based on OpenAI documentation tool calls (see here: https://platform.openai.com/docs/guides/function-calling).
+
  Example usage:

  \b
  opengradient chat --model meta-llama/Meta-Llama-3-8B-Instruct --messages '[{"role":"user","content":"hello"}]' --max-tokens 50 --temperature 0.7
- opengradient chat -m mistralai/Mistral-7B-Instruct-v0.3 --messages-file messages.json --stop-sequence "." --stop-sequence "\n"
+ opengradient chat --model mistralai/Mistral-7B-Instruct-v0.3 --messages-file messages.json --tools-file tools.json --max-tokens 200 --stop-sequence "." --stop-sequence "\\n"
  """
  client: Client = ctx.obj['client']
  try:
@@ -458,7 +473,7 @@ def chat(
  ctx.exit(1)
  return
  if messages and messages_file:
- click.echo("Cannot have both messages and messages_file")
+ click.echo("Cannot have both messages and messages-file")
  ctx.exit(1)
  return

@@ -473,9 +488,9 @@ def chat(
  messages = json.load(file)

  # Parse tools if provided
- if tools is not None and tools != "[]" and tools_file:
- click.echo("Cannot have both tools and tools_file")
- ctx.exit(1)
+ if (tools and tools != '[]') and tools_file:
+ click.echo("Cannot have both tools and tools-file")
+ click.exit(1)
  return

  parsed_tools=[]
@@ -509,6 +524,7 @@ def chat(

  tx_hash, finish_reason, llm_chat_output = client.llm_chat(
  model_cid=model_cid,
+ inference_mode=LlmInferenceModes[inference_mode],
  messages=messages,
  max_tokens=max_tokens,
  stop_sequence=list(stop_sequence),
@@ -517,15 +533,32 @@
  tool_choice=tool_choice,
  )

- # TODO (Kyle): Make this prettier
- print("TX Hash: ", tx_hash)
- print("Finish reason: ", finish_reason)
- print("Chat output: ", llm_chat_output)
+ print_llm_chat_result(model_cid, tx_hash, finish_reason, llm_chat_output)
  except Exception as e:
  click.echo(f"Error running LLM chat inference: {str(e)}")

- def print_llm_chat_result():
- pass
+ def print_llm_chat_result(model_cid, tx_hash, finish_reason, chat_output):
+ click.secho("✅ LLM Chat Successful", fg="green", bold=True)
+ click.echo("──────────────────────────────────────")
+ click.echo("Model CID: ", nl=False)
+ click.secho(model_cid, fg="cyan", bold=True)
+ click.echo("Transaction hash: ", nl=False)
+ click.secho(tx_hash, fg="cyan", bold=True)
+ block_explorer_link = f"{DEFAULT_BLOCKCHAIN_EXPLORER}0x{tx_hash}"
+ click.echo("Block explorer link: ", nl=False)
+ click.secho(block_explorer_link, fg="blue", underline=True)
+ click.echo("──────────────────────────────────────")
+ click.secho("Finish Reason: ", fg="yellow", bold=True)
+ click.echo()
+ click.echo(finish_reason)
+ click.echo()
+ click.secho("Chat Output:", fg="yellow", bold=True)
+ click.echo()
+ for key, value in chat_output.items():
+ # If the value doesn't give any information, don't print it
+ if value != None and value != "" and value != '[]' and value != []:
+ click.echo(f"{key}: {value}")
+ click.echo()

  @cli.command()
  def create_account():
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/client.py
@@ -1,6 +1,7 @@
  import json
  import logging
  import os
+ import random
  from typing import Dict, List, Optional, Tuple, Union

  import firebase
@@ -12,7 +13,7 @@ from web3.logs import DISCARD

  from opengradient import utils
  from opengradient.exceptions import OpenGradientError
- from opengradient.types import InferenceMode, LLM
+ from opengradient.types import InferenceMode, LlmInferenceMode, LLM, TEE_LLM

  import grpc
  import time
@@ -23,6 +24,31 @@ from opengradient.proto import infer_pb2
  from opengradient.proto import infer_pb2_grpc
  from .defaults import DEFAULT_IMAGE_GEN_HOST, DEFAULT_IMAGE_GEN_PORT

+ from functools import wraps
+
+ def run_with_retry(txn_function, max_retries=5):
+ """
+ Execute a blockchain transaction with retry logic.
+
+ Args:
+ txn_function: Function that executes the transaction
+ max_retries (int): Maximum number of retry attempts
+ """
+ last_error = None
+ for attempt in range(max_retries):
+ try:
+ return txn_function()
+ except Exception as e:
+ last_error = e
+ if attempt < max_retries - 1:
+ if "nonce too low" in str(e) or "nonce too high" in str(e):
+ time.sleep(1) # Wait before retry
+ continue
+ # If it's not a nonce error, raise immediately
+ raise
+ # If we've exhausted all retries, raise the last error
+ raise OpenGradientError(f"Transaction failed after {max_retries} attempts: {str(last_error)}")
+
  class Client:
  FIREBASE_CONFIG = {
  "apiKey": "AIzaSyDUVckVtfl-hiteBzPopy1pDD8Uvfncs7w",
@@ -311,7 +337,8 @@
  self,
  model_cid: str,
  inference_mode: InferenceMode,
- model_input: Dict[str, Union[str, int, float, List, np.ndarray]]
+ model_input: Dict[str, Union[str, int, float, List, np.ndarray]],
+ max_retries: Optional[int] = None
  ) -> Tuple[str, Dict[str, np.ndarray]]:
  """
  Perform inference on a model.
@@ -320,6 +347,7 @@
  model_cid (str): The unique content identifier for the model from IPFS.
  inference_mode (InferenceMode): The inference mode.
  model_input (Dict[str, Union[str, int, float, List, np.ndarray]]): The input data for the model.
+ max_retries (int, optional): Maximum number of retry attempts. Defaults to 5.

  Returns:
  Tuple[str, Dict[str, np.ndarray]]: The transaction hash and the model output.
@@ -327,46 +355,22 @@
  Raises:
  OpenGradientError: If the inference fails.
  """
- # TODO (Kyle): Add input support for JSON tensors
- try:
- logging.debug("Entering infer method")
+ def execute_transaction():
  self._initialize_web3()
- logging.debug(f"Web3 initialized. Connected: {self._w3.is_connected()}")
-
- logging.debug(f"Creating contract instance. Address: {self.contract_address}")
  contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)
- logging.debug("Contract instance created successfully")
-
- logging.debug(f"Model ID: {model_cid}")
- logging.debug(f"Inference Mode: {inference_mode}")
- logging.debug(f"Model Input: {model_input}")
-
- # Convert InferenceMode to uint8
+
  inference_mode_uint8 = int(inference_mode)
-
- # Prepare ModelInput tuple
  converted_model_input = utils.convert_to_model_input(model_input)
- logging.debug(f"Prepared model input tuple: {converted_model_input}")
-
- logging.debug("Preparing run function")
+
  run_function = contract.functions.run(
  model_cid,
  inference_mode_uint8,
  converted_model_input
  )
- logging.debug("Run function prepared successfully")

- # Build transaction
- nonce = self._w3.eth.get_transaction_count(self.wallet_address)
- logging.debug(f"Nonce: {nonce}")
-
- # Estimate gas
+ nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
  estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
- logging.debug(f"Estimated gas: {estimated_gas}")
-
- # Increase gas limit by 20%
  gas_limit = int(estimated_gas * 3)
- logging.debug(f"Gas limit set to: {gas_limit}")

  transaction = run_function.build_transaction({
  'from': self.wallet_address,
@@ -375,62 +379,36 @@ class Client:
  'gasPrice': self._w3.eth.gas_price,
  })

- logging.debug(f"Transaction built: {transaction}")
-
- # Sign transaction
  signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
- logging.debug("Transaction signed successfully")
-
- # Send transaction
  tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
- logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
- # Wait for transaction receipt
  tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)
- logging.debug(f"Transaction receipt received: {tx_receipt}")

- # Check if the transaction was successful
  if tx_receipt['status'] == 0:
  raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")

- # Process the InferenceResult event
  parsed_logs = contract.events.InferenceResult().process_receipt(tx_receipt, errors=DISCARD)
-
  if len(parsed_logs) < 1:
  raise OpenGradientError("InferenceResult event not found in transaction logs")
- inference_result = parsed_logs[0]
-
- # Extract the ModelOutput from the event
- event_data = inference_result['args']
- logging.debug(f"Raw event data: {event_data}")
-
- try:
- model_output = utils.convert_to_model_output(event_data)
- logging.debug(f"Parsed ModelOutput: {model_output}")
- except Exception as e:
- logging.error(f"Error parsing event data: {str(e)}", exc_info=True)
- raise OpenGradientError(f"Failed to parse event data: {str(e)}")

+ model_output = utils.convert_to_model_output(parsed_logs[0]['args'])
  return tx_hash.hex(), model_output

- except ContractLogicError as e:
- logging.error(f"Contract logic error: {str(e)}", exc_info=True)
- raise OpenGradientError(f"Inference failed due to contract logic error: {str(e)}")
- except Exception as e:
- logging.error(f"Error in infer method: {str(e)}", exc_info=True)
- raise OpenGradientError(f"Inference failed: {str(e)}")
-
+ return run_with_retry(execute_transaction, max_retries or 5)
+
  def llm_completion(self,
  model_cid: LLM,
+ inference_mode: InferenceMode,
  prompt: str,
  max_tokens: int = 100,
  stop_sequence: Optional[List[str]] = None,
- temperature: float = 0.0) -> Tuple[str, str]:
+ temperature: float = 0.0,
+ max_retries: Optional[int] = None) -> Tuple[str, str]:
  """
  Perform inference on an LLM model using completions.

  Args:
  model_cid (LLM): The unique content identifier for the model.
+ inference_mode (InferenceMode): The inference mode.
  prompt (str): The input prompt for the LLM.
  max_tokens (int): Maximum number of tokens for LLM output. Default is 100.
  stop_sequence (List[str], optional): List of stop sequences for LLM. Default is None.
@@ -442,17 +420,20 @@
  Raises:
  OpenGradientError: If the inference fails.
  """
- try:
- self._initialize_web3()
+ def execute_transaction():
+ # Check inference mode and supported model
+ if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
+ raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)

- abi_path = os.path.join(os.path.dirname(__file__), 'abi', 'inference.abi')
- with open(abi_path, 'r') as abi_file:
- llm_abi = json.load(abi_file)
- contract = self._w3.eth.contract(address=self.contract_address, abi=llm_abi)
+ if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
+ raise OpenGradientError("That model CID is not supported yet supported for TEE inference")
+
+ self._initialize_web3()
+ contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)

  # Prepare LLM input
  llm_request = {
- "mode": InferenceMode.VANILLA,
+ "mode": inference_mode,
  "modelCID": model_cid,
  "prompt": prompt,
  "max_tokens": max_tokens,
@@ -461,11 +442,9 @@
  }
  logging.debug(f"Prepared LLM request: {llm_request}")

- # Prepare run function
  run_function = contract.functions.runLLMCompletion(llm_request)

- # Build transaction
- nonce = self._w3.eth.get_transaction_count(self.wallet_address)
+ nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
  estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
  gas_limit = int(estimated_gas * 1.2)

@@ -476,47 +455,38 @@
  'gasPrice': self._w3.eth.gas_price,
  })

- # Sign and send transaction
  signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
  tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
- logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
- # Wait for transaction receipt
  tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)

  if tx_receipt['status'] == 0:
  raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")

- # Process the LLMResult event
  parsed_logs = contract.events.LLMCompletionResult().process_receipt(tx_receipt, errors=DISCARD)
-
  if len(parsed_logs) < 1:
  raise OpenGradientError("LLM completion result event not found in transaction logs")
- llm_result = parsed_logs[0]

- llm_answer = llm_result['args']['response']['answer']
+ llm_answer = parsed_logs[0]['args']['response']['answer']
  return tx_hash.hex(), llm_answer

- except ContractLogicError as e:
- logging.error(f"Contract logic error: {str(e)}", exc_info=True)
- raise OpenGradientError(f"LLM inference failed due to contract logic error: {str(e)}")
- except Exception as e:
- logging.error(f"Error in infer completion method: {str(e)}", exc_info=True)
- raise OpenGradientError(f"LLM inference failed: {str(e)}")
-
+ return run_with_retry(execute_transaction, max_retries or 5)
+
  def llm_chat(self,
  model_cid: str,
+ inference_mode: InferenceMode,
  messages: List[Dict],
  max_tokens: int = 100,
  stop_sequence: Optional[List[str]] = None,
  temperature: float = 0.0,
  tools: Optional[List[Dict]] = [],
- tool_choice: Optional[str] = None) -> Tuple[str, str]:
+ tool_choice: Optional[str] = None,
+ max_retries: Optional[int] = None) -> Tuple[str, str]:
  """
  Perform inference on an LLM model using chat.

  Args:
  model_cid (LLM): The unique content identifier for the model.
+ inference_mode (InferenceMode): The inference mode.
  messages (dict): The messages that will be passed into the chat.
  This should be in OpenAI API format (https://platform.openai.com/docs/api-reference/chat/create)
  Example:
@@ -567,13 +537,16 @@
  Raises:
  OpenGradientError: If the inference fails.
  """
- try:
- self._initialize_web3()
+ def execute_transaction():
+ # Check inference mode and supported model
+ if inference_mode != LlmInferenceMode.VANILLA and inference_mode != LlmInferenceMode.TEE:
+ raise OpenGradientError("Invalid inference mode %s: Inference mode must be VANILLA or TEE" % inference_mode)
+
+ if inference_mode == LlmInferenceMode.TEE and model_cid not in TEE_LLM:
+ raise OpenGradientError("That model CID is not supported yet supported for TEE inference")

- abi_path = os.path.join(os.path.dirname(__file__), 'abi', 'inference.abi')
- with open(abi_path, 'r') as abi_file:
- llm_abi = json.load(abi_file)
- contract = self._w3.eth.contract(address=self.contract_address, abi=llm_abi)
+ self._initialize_web3()
+ contract = self._w3.eth.contract(address=self.contract_address, abi=self.abi)

  # For incoming chat messages, tool_calls can be empty. Add an empty array so that it will fit the ABI.
  for message in messages:
@@ -585,17 +558,10 @@
  message['name'] = ""

  # Create simplified tool structure for smart contract
- #
- # struct ToolDefinition {
- # string description;
- # string name;
- # string parameters; // This must be a JSON
- # }
  converted_tools = []
  if tools is not None:
  for tool in tools:
  function = tool['function']
-
  converted_tool = {}
  converted_tool['name'] = function['name']
  converted_tool['description'] = function['description']
@@ -604,12 +570,11 @@
  converted_tool['parameters'] = json.dumps(parameters)
  except Exception as e:
  raise OpenGradientError("Chat LLM failed to convert parameters into JSON: %s", e)
-
  converted_tools.append(converted_tool)

  # Prepare LLM input
  llm_request = {
- "mode": InferenceMode.VANILLA,
+ "mode": inference_mode,
  "modelCID": model_cid,
  "messages": messages,
  "max_tokens": max_tokens,
@@ -620,11 +585,9 @@
  }
  logging.debug(f"Prepared LLM request: {llm_request}")

- # Prepare run function
  run_function = contract.functions.runLLMChat(llm_request)

- # Build transaction
- nonce = self._w3.eth.get_transaction_count(self.wallet_address)
+ nonce = self._w3.eth.get_transaction_count(self.wallet_address, 'pending')
  estimated_gas = run_function.estimate_gas({'from': self.wallet_address})
  gas_limit = int(estimated_gas * 1.2)

@@ -635,41 +598,25 @@
  'gasPrice': self._w3.eth.gas_price,
  })

- # Sign and send transaction
  signed_tx = self._w3.eth.account.sign_transaction(transaction, self.private_key)
  tx_hash = self._w3.eth.send_raw_transaction(signed_tx.raw_transaction)
- logging.debug(f"Transaction sent. Hash: {tx_hash.hex()}")
-
- # Wait for transaction receipt
  tx_receipt = self._w3.eth.wait_for_transaction_receipt(tx_hash)

  if tx_receipt['status'] == 0:
  raise ContractLogicError(f"Transaction failed. Receipt: {tx_receipt}")

- # Process the LLMResult event
  parsed_logs = contract.events.LLMChatResult().process_receipt(tx_receipt, errors=DISCARD)
-
  if len(parsed_logs) < 1:
  raise OpenGradientError("LLM chat result event not found in transaction logs")
- llm_result = parsed_logs[0]['args']['response']

- # Turn tool calls into normal dicts
+ llm_result = parsed_logs[0]['args']['response']
  message = dict(llm_result['message'])
- if (tool_calls := message.get('tool_calls')) != None:
- new_tool_calls = []
- for tool_call in tool_calls:
- new_tool_calls.append(dict(tool_call))
- message['tool_calls'] = new_tool_calls
-
- return (tx_hash.hex(), llm_result['finish_reason'], message)
+ if (tool_calls := message.get('tool_calls')) is not None:
+ message['tool_calls'] = [dict(tool_call) for tool_call in tool_calls]

- except ContractLogicError as e:
- logging.error(f"Contract logic error: {str(e)}", exc_info=True)
- raise OpenGradientError(f"LLM inference failed due to contract logic error: {str(e)}")
- except Exception as e:
- logging.error(f"Error in infer chat method: {str(e)}", exc_info=True)
- raise OpenGradientError(f"LLM inference failed: {str(e)}")
+ return tx_hash.hex(), llm_result['finish_reason'], message

+ return run_with_retry(execute_transaction, max_retries or 5)

  def list_files(self, model_name: str, version: str) -> List[Dict]:
  """
{opengradient-0.3.15 → opengradient-0.3.17}/src/opengradient/types.py
@@ -27,6 +27,10 @@ class InferenceMode:
  ZKML = 1
  TEE = 2

+ class LlmInferenceMode:
+ VANILLA = 0
+ TEE = 1
+
  @dataclass
  class ModelOutput:
  numbers: List[NumberTensor]
@@ -79,4 +83,7 @@ class LLM(str, Enum):
  LLAMA_3_2_3B_INSTRUCT = "meta-llama/Llama-3.2-3B-Instruct"
  MISTRAL_7B_INSTRUCT_V3 = "mistralai/Mistral-7B-Instruct-v0.3"
  HERMES_3_LLAMA_3_1_70B = "NousResearch/Hermes-3-Llama-3.1-70B"
+ META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
+
+ class TEE_LLM(str, Enum):
  META_LLAMA_3_1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
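
The new LlmInferenceMode and TEE_LLM types back the TEE path added in 0.3.17: the client rejects LlmInferenceMode.TEE for any model CID that is not a member of TEE_LLM. A hedged sketch of selecting the TEE-enabled model through the module-level API (assumes og.init() has already been called; passing the enum member keeps the TEE_LLM membership check straightforward):

    import opengradient as og

    tx_hash, answer = og.llm_completion(
        model_cid=og.TEE_LLM.META_LLAMA_3_1_70B_INSTRUCT,  # str-valued enum member
        inference_mode=og.LlmInferenceMode.TEE,
        prompt="Summarize the OpenGradient SDK in one sentence.",
        max_tokens=50,
    )
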