tinygrad 0.9.0__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (57)
  1. tinygrad/codegen/__init__.py +0 -0
  2. tinygrad/codegen/kernel.py +78 -90
  3. tinygrad/codegen/linearizer.py +237 -169
  4. tinygrad/codegen/uops.py +278 -242
  5. tinygrad/device.py +147 -10
  6. tinygrad/dtype.py +7 -7
  7. tinygrad/engine/graph.py +16 -16
  8. tinygrad/engine/jit.py +39 -36
  9. tinygrad/engine/realize.py +6 -5
  10. tinygrad/engine/schedule.py +15 -7
  11. tinygrad/engine/search.py +6 -3
  12. tinygrad/function.py +17 -23
  13. tinygrad/helpers.py +77 -8
  14. tinygrad/lazy.py +26 -26
  15. tinygrad/multi.py +13 -9
  16. tinygrad/nn/__init__.py +1 -1
  17. tinygrad/nn/datasets.py +2 -1
  18. tinygrad/nn/state.py +3 -4
  19. tinygrad/ops.py +49 -16
  20. tinygrad/renderer/__init__.py +8 -4
  21. tinygrad/renderer/assembly.py +93 -100
  22. tinygrad/renderer/cstyle.py +47 -42
  23. tinygrad/renderer/llvmir.py +30 -30
  24. tinygrad/runtime/__init__.py +0 -0
  25. tinygrad/runtime/autogen/amd_gpu.py +11504 -1
  26. tinygrad/runtime/autogen/comgr.py +36 -10
  27. tinygrad/runtime/autogen/hsa.py +146 -14
  28. tinygrad/runtime/autogen/io_uring.py +1486 -0
  29. tinygrad/runtime/autogen/nv_gpu.py +269 -0
  30. tinygrad/runtime/driver/__init__.py +0 -0
  31. tinygrad/runtime/driver/hip_comgr.py +20 -11
  32. tinygrad/runtime/graph/__init__.py +0 -0
  33. tinygrad/runtime/graph/clang.py +3 -2
  34. tinygrad/runtime/graph/cuda.py +2 -2
  35. tinygrad/runtime/graph/hcq.py +122 -78
  36. tinygrad/runtime/ops_amd.py +302 -316
  37. tinygrad/runtime/ops_cuda.py +3 -3
  38. tinygrad/runtime/ops_disk.py +70 -5
  39. tinygrad/runtime/ops_gpu.py +2 -2
  40. tinygrad/runtime/ops_metal.py +5 -6
  41. tinygrad/runtime/ops_npy.py +1 -1
  42. tinygrad/runtime/ops_nv.py +161 -166
  43. tinygrad/runtime/ops_python.py +20 -16
  44. tinygrad/shape/__init__.py +0 -0
  45. tinygrad/shape/shapetracker.py +5 -2
  46. tinygrad/shape/symbolic.py +1 -3
  47. tinygrad/shape/view.py +34 -19
  48. tinygrad/tensor.py +219 -135
  49. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/METADATA +14 -6
  50. tinygrad-0.9.1.dist-info/RECORD +63 -0
  51. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/WHEEL +1 -1
  52. tinygrad/runtime/driver/hsa.py +0 -143
  53. tinygrad/runtime/graph/hsa.py +0 -171
  54. tinygrad/runtime/ops_hsa.py +0 -278
  55. tinygrad-0.9.0.dist-info/RECORD +0 -60
  56. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/LICENSE +0 -0
  57. {tinygrad-0.9.0.dist-info → tinygrad-0.9.1.dist-info}/top_level.txt +0 -0
tinygrad/runtime/autogen/nv_gpu.py
@@ -33326,3 +33326,272 @@ __all__ = \
  'union_NV2080_CTRL_NVLINK_CALLBACK_TYPE_callbackParams',
  'union_NV2080_CTRL_NVLINK_INJECT_TLC_ERROR_TYPE',
  'union_RM_GSP_SPDM_CMD', 'union_c__SA_NVOS32_PARAMETERS_data']
+ nv_status_codes = {}
+ NV_OK = 0x00000000
+ nv_status_codes[NV_OK] = "Success"
+ NV_ERR_GENERIC = 0x0000FFFF
+ nv_status_codes[NV_ERR_GENERIC] = "Failure: Generic Error"
+ NV_ERR_BROKEN_FB = 0x00000001
+ nv_status_codes[NV_ERR_BROKEN_FB] = "Frame-Buffer broken"
+ NV_ERR_BUFFER_TOO_SMALL = 0x00000002
+ nv_status_codes[NV_ERR_BUFFER_TOO_SMALL] = "Buffer passed in is too small"
+ NV_ERR_BUSY_RETRY = 0x00000003
+ nv_status_codes[NV_ERR_BUSY_RETRY] = "System is busy, retry later"
+ NV_ERR_CALLBACK_NOT_SCHEDULED = 0x00000004
+ nv_status_codes[NV_ERR_CALLBACK_NOT_SCHEDULED] = "The requested callback API not scheduled"
+ NV_ERR_CARD_NOT_PRESENT = 0x00000005
+ nv_status_codes[NV_ERR_CARD_NOT_PRESENT] = "Card not detected"
+ NV_ERR_CYCLE_DETECTED = 0x00000006
+ nv_status_codes[NV_ERR_CYCLE_DETECTED] = "Call cycle detected"
+ NV_ERR_DMA_IN_USE = 0x00000007
+ nv_status_codes[NV_ERR_DMA_IN_USE] = "Requested DMA is in use"
+ NV_ERR_DMA_MEM_NOT_LOCKED = 0x00000008
+ nv_status_codes[NV_ERR_DMA_MEM_NOT_LOCKED] = "Requested DMA memory is not locked"
+ NV_ERR_DMA_MEM_NOT_UNLOCKED = 0x00000009
+ nv_status_codes[NV_ERR_DMA_MEM_NOT_UNLOCKED] = "Requested DMA memory is not unlocked"
+ NV_ERR_DUAL_LINK_INUSE = 0x0000000A
+ nv_status_codes[NV_ERR_DUAL_LINK_INUSE] = "Dual-Link is in use"
+ NV_ERR_ECC_ERROR = 0x0000000B
+ nv_status_codes[NV_ERR_ECC_ERROR] = "Generic ECC error"
+ NV_ERR_FIFO_BAD_ACCESS = 0x0000000C
+ nv_status_codes[NV_ERR_FIFO_BAD_ACCESS] = "FIFO: Invalid access"
+ NV_ERR_FREQ_NOT_SUPPORTED = 0x0000000D
+ nv_status_codes[NV_ERR_FREQ_NOT_SUPPORTED] = "Requested frequency is not supported"
+ NV_ERR_GPU_DMA_NOT_INITIALIZED = 0x0000000E
+ nv_status_codes[NV_ERR_GPU_DMA_NOT_INITIALIZED] = "Requested DMA not initialized"
+ NV_ERR_GPU_IS_LOST = 0x0000000F
+ nv_status_codes[NV_ERR_GPU_IS_LOST] = "GPU lost from the bus"
+ NV_ERR_GPU_IN_FULLCHIP_RESET = 0x00000010
+ nv_status_codes[NV_ERR_GPU_IN_FULLCHIP_RESET] = "GPU currently in full-chip reset"
+ NV_ERR_GPU_NOT_FULL_POWER = 0x00000011
+ nv_status_codes[NV_ERR_GPU_NOT_FULL_POWER] = "GPU not in full power"
+ NV_ERR_GPU_UUID_NOT_FOUND = 0x00000012
+ nv_status_codes[NV_ERR_GPU_UUID_NOT_FOUND] = "GPU UUID not found"
+ NV_ERR_HOT_SWITCH = 0x00000013
+ nv_status_codes[NV_ERR_HOT_SWITCH] = "System in hot switch"
+ NV_ERR_I2C_ERROR = 0x00000014
+ nv_status_codes[NV_ERR_I2C_ERROR] = "I2C Error"
+ NV_ERR_I2C_SPEED_TOO_HIGH = 0x00000015
+ nv_status_codes[NV_ERR_I2C_SPEED_TOO_HIGH] = "I2C Error: Speed too high"
+ NV_ERR_ILLEGAL_ACTION = 0x00000016
+ nv_status_codes[NV_ERR_ILLEGAL_ACTION] = "Current action is not allowed"
+ NV_ERR_IN_USE = 0x00000017
+ nv_status_codes[NV_ERR_IN_USE] = "Generic busy error"
+ NV_ERR_INFLATE_COMPRESSED_DATA_FAILED = 0x00000018
+ nv_status_codes[NV_ERR_INFLATE_COMPRESSED_DATA_FAILED] = "Failed to inflate compressed data"
+ NV_ERR_INSERT_DUPLICATE_NAME = 0x00000019
+ nv_status_codes[NV_ERR_INSERT_DUPLICATE_NAME] = "Found a duplicate entry in the requested btree"
+ NV_ERR_INSUFFICIENT_RESOURCES = 0x0000001A
+ nv_status_codes[NV_ERR_INSUFFICIENT_RESOURCES] = "Ran out of a critical resource, other than memory"
+ NV_ERR_INSUFFICIENT_PERMISSIONS = 0x0000001B
+ nv_status_codes[NV_ERR_INSUFFICIENT_PERMISSIONS] = "The requester does not have sufficient permissions"
+ NV_ERR_INSUFFICIENT_POWER = 0x0000001C
+ nv_status_codes[NV_ERR_INSUFFICIENT_POWER] = "Generic Error: Low power"
+ NV_ERR_INVALID_ACCESS_TYPE = 0x0000001D
+ nv_status_codes[NV_ERR_INVALID_ACCESS_TYPE] = "This type of access is not allowed"
+ NV_ERR_INVALID_ADDRESS = 0x0000001E
+ nv_status_codes[NV_ERR_INVALID_ADDRESS] = "Address not valid"
+ NV_ERR_INVALID_ARGUMENT = 0x0000001F
+ nv_status_codes[NV_ERR_INVALID_ARGUMENT] = "Invalid argument to call"
+ NV_ERR_INVALID_BASE = 0x00000020
+ nv_status_codes[NV_ERR_INVALID_BASE] = "Invalid base"
+ NV_ERR_INVALID_CHANNEL = 0x00000021
+ nv_status_codes[NV_ERR_INVALID_CHANNEL] = "Given channel-id not valid"
+ NV_ERR_INVALID_CLASS = 0x00000022
+ nv_status_codes[NV_ERR_INVALID_CLASS] = "Given class-id not valid"
+ NV_ERR_INVALID_CLIENT = 0x00000023
+ nv_status_codes[NV_ERR_INVALID_CLIENT] = "Given client not valid"
+ NV_ERR_INVALID_COMMAND = 0x00000024
+ nv_status_codes[NV_ERR_INVALID_COMMAND] = "Command passed is not valid"
+ NV_ERR_INVALID_DATA = 0x00000025
+ nv_status_codes[NV_ERR_INVALID_DATA] = "Invalid data passed"
+ NV_ERR_INVALID_DEVICE = 0x00000026
+ nv_status_codes[NV_ERR_INVALID_DEVICE] = "Current device is not valid"
+ NV_ERR_INVALID_DMA_SPECIFIER = 0x00000027
+ nv_status_codes[NV_ERR_INVALID_DMA_SPECIFIER] = "The requested DMA specifier is not valid"
+ NV_ERR_INVALID_EVENT = 0x00000028
+ nv_status_codes[NV_ERR_INVALID_EVENT] = "Invalid event occurred"
+ NV_ERR_INVALID_FLAGS = 0x00000029
+ nv_status_codes[NV_ERR_INVALID_FLAGS] = "Invalid flags passed"
+ NV_ERR_INVALID_FUNCTION = 0x0000002A
+ nv_status_codes[NV_ERR_INVALID_FUNCTION] = "Called function is not valid"
+ NV_ERR_INVALID_HEAP = 0x0000002B
+ nv_status_codes[NV_ERR_INVALID_HEAP] = "Heap corrupted"
+ NV_ERR_INVALID_INDEX = 0x0000002C
+ nv_status_codes[NV_ERR_INVALID_INDEX] = "Index invalid"
+ NV_ERR_INVALID_IRQ_LEVEL = 0x0000002D
+ nv_status_codes[NV_ERR_INVALID_IRQ_LEVEL] = "Requested IRQ level is not valid"
+ NV_ERR_INVALID_LIMIT = 0x0000002E
+ nv_status_codes[NV_ERR_INVALID_LIMIT] = "Generic Error: Invalid limit"
+ NV_ERR_INVALID_LOCK_STATE = 0x0000002F
+ nv_status_codes[NV_ERR_INVALID_LOCK_STATE] = "Requested lock state not valid"
+ NV_ERR_INVALID_METHOD = 0x00000030
+ nv_status_codes[NV_ERR_INVALID_METHOD] = "Requested method not valid"
+ NV_ERR_INVALID_OBJECT = 0x00000031
+ nv_status_codes[NV_ERR_INVALID_OBJECT] = "Object not valid"
+ NV_ERR_INVALID_OBJECT_BUFFER = 0x00000032
+ nv_status_codes[NV_ERR_INVALID_OBJECT_BUFFER] = "Object buffer passed is not valid"
+ NV_ERR_INVALID_OBJECT_HANDLE = 0x00000033
+ nv_status_codes[NV_ERR_INVALID_OBJECT_HANDLE] = "Object handle is not valid"
+ NV_ERR_INVALID_OBJECT_NEW = 0x00000034
+ nv_status_codes[NV_ERR_INVALID_OBJECT_NEW] = "New object is not valid"
+ NV_ERR_INVALID_OBJECT_OLD = 0x00000035
+ nv_status_codes[NV_ERR_INVALID_OBJECT_OLD] = "Old object is not valid"
+ NV_ERR_INVALID_OBJECT_PARENT = 0x00000036
+ nv_status_codes[NV_ERR_INVALID_OBJECT_PARENT] = "Object parent is not valid"
+ NV_ERR_INVALID_OFFSET = 0x00000037
+ nv_status_codes[NV_ERR_INVALID_OFFSET] = "The offset passed is not valid"
+ NV_ERR_INVALID_OPERATION = 0x00000038
+ nv_status_codes[NV_ERR_INVALID_OPERATION] = "Requested operation is not valid"
+ NV_ERR_INVALID_OWNER = 0x00000039
+ nv_status_codes[NV_ERR_INVALID_OWNER] = "Owner not valid"
+ NV_ERR_INVALID_PARAM_STRUCT = 0x0000003A
+ nv_status_codes[NV_ERR_INVALID_PARAM_STRUCT] = "Invalid structure parameter"
+ NV_ERR_INVALID_PARAMETER = 0x0000003B
+ nv_status_codes[NV_ERR_INVALID_PARAMETER] = "At least one of the parameters passed is not valid"
+ NV_ERR_INVALID_PATH = 0x0000003C
+ nv_status_codes[NV_ERR_INVALID_PATH] = "The requested path is not valid"
+ NV_ERR_INVALID_POINTER = 0x0000003D
+ nv_status_codes[NV_ERR_INVALID_POINTER] = "Pointer not valid"
+ NV_ERR_INVALID_REGISTRY_KEY = 0x0000003E
+ nv_status_codes[NV_ERR_INVALID_REGISTRY_KEY] = "Found an invalid registry key"
+ NV_ERR_INVALID_REQUEST = 0x0000003F
+ nv_status_codes[NV_ERR_INVALID_REQUEST] = "Generic Error: Invalid request"
+ NV_ERR_INVALID_STATE = 0x00000040
+ nv_status_codes[NV_ERR_INVALID_STATE] = "Generic Error: Invalid state"
+ NV_ERR_INVALID_STRING_LENGTH = 0x00000041
+ nv_status_codes[NV_ERR_INVALID_STRING_LENGTH] = "The string length is not valid"
+ NV_ERR_INVALID_READ = 0x00000042
+ nv_status_codes[NV_ERR_INVALID_READ] = "The requested read operation is not valid"
+ NV_ERR_INVALID_WRITE = 0x00000043
+ nv_status_codes[NV_ERR_INVALID_WRITE] = "The requested write operation is not valid"
+ NV_ERR_INVALID_XLATE = 0x00000044
+ nv_status_codes[NV_ERR_INVALID_XLATE] = "The requested translate operation is not valid"
+ NV_ERR_IRQ_NOT_FIRING = 0x00000045
+ nv_status_codes[NV_ERR_IRQ_NOT_FIRING] = "Requested IRQ is not firing"
+ NV_ERR_IRQ_EDGE_TRIGGERED = 0x00000046
+ nv_status_codes[NV_ERR_IRQ_EDGE_TRIGGERED] = "IRQ is edge triggered"
+ NV_ERR_MEMORY_TRAINING_FAILED = 0x00000047
+ nv_status_codes[NV_ERR_MEMORY_TRAINING_FAILED] = "Failed memory training sequence"
+ NV_ERR_MISMATCHED_SLAVE = 0x00000048
+ nv_status_codes[NV_ERR_MISMATCHED_SLAVE] = "Slave mismatch"
+ NV_ERR_MISMATCHED_TARGET = 0x00000049
+ nv_status_codes[NV_ERR_MISMATCHED_TARGET] = "Target mismatch"
+ NV_ERR_MISSING_TABLE_ENTRY = 0x0000004A
+ nv_status_codes[NV_ERR_MISSING_TABLE_ENTRY] = "Requested entry missing not found in the table"
+ NV_ERR_MODULE_LOAD_FAILED = 0x0000004B
+ nv_status_codes[NV_ERR_MODULE_LOAD_FAILED] = "Failed to load the requested module"
+ NV_ERR_MORE_DATA_AVAILABLE = 0x0000004C
+ nv_status_codes[NV_ERR_MORE_DATA_AVAILABLE] = "There is more data available"
+ NV_ERR_MORE_PROCESSING_REQUIRED = 0x0000004D
+ nv_status_codes[NV_ERR_MORE_PROCESSING_REQUIRED] = "More processing required for the given call"
+ NV_ERR_MULTIPLE_MEMORY_TYPES = 0x0000004E
+ nv_status_codes[NV_ERR_MULTIPLE_MEMORY_TYPES] = "Multiple memory types found"
+ NV_ERR_NO_FREE_FIFOS = 0x0000004F
+ nv_status_codes[NV_ERR_NO_FREE_FIFOS] = "No more free FIFOs found"
+ NV_ERR_NO_INTR_PENDING = 0x00000050
+ nv_status_codes[NV_ERR_NO_INTR_PENDING] = "No interrupt pending"
+ NV_ERR_NO_MEMORY = 0x00000051
+ nv_status_codes[NV_ERR_NO_MEMORY] = "Out of memory"
+ NV_ERR_NO_SUCH_DOMAIN = 0x00000052
+ nv_status_codes[NV_ERR_NO_SUCH_DOMAIN] = "Requested domain does not exist"
+ NV_ERR_NO_VALID_PATH = 0x00000053
+ nv_status_codes[NV_ERR_NO_VALID_PATH] = "Caller did not specify a valid path"
+ NV_ERR_NOT_COMPATIBLE = 0x00000054
+ nv_status_codes[NV_ERR_NOT_COMPATIBLE] = "Generic Error: Incompatible types"
+ NV_ERR_NOT_READY = 0x00000055
+ nv_status_codes[NV_ERR_NOT_READY] = "Generic Error: Not ready"
+ NV_ERR_NOT_SUPPORTED = 0x00000056
+ nv_status_codes[NV_ERR_NOT_SUPPORTED] = "Call not supported"
+ NV_ERR_OBJECT_NOT_FOUND = 0x00000057
+ nv_status_codes[NV_ERR_OBJECT_NOT_FOUND] = "Requested object not found"
+ NV_ERR_OBJECT_TYPE_MISMATCH = 0x00000058
+ nv_status_codes[NV_ERR_OBJECT_TYPE_MISMATCH] = "Specified objects do not match"
+ NV_ERR_OPERATING_SYSTEM = 0x00000059
+ nv_status_codes[NV_ERR_OPERATING_SYSTEM] = "Generic operating system error"
+ NV_ERR_OTHER_DEVICE_FOUND = 0x0000005A
+ nv_status_codes[NV_ERR_OTHER_DEVICE_FOUND] = "Found other device instead of the requested one"
+ NV_ERR_OUT_OF_RANGE = 0x0000005B
+ nv_status_codes[NV_ERR_OUT_OF_RANGE] = "The specified value is out of bounds"
+ NV_ERR_OVERLAPPING_UVM_COMMIT = 0x0000005C
+ nv_status_codes[NV_ERR_OVERLAPPING_UVM_COMMIT] = "Overlapping unified virtual memory commit"
+ NV_ERR_PAGE_TABLE_NOT_AVAIL = 0x0000005D
+ nv_status_codes[NV_ERR_PAGE_TABLE_NOT_AVAIL] = "Requested page table not available"
+ NV_ERR_PID_NOT_FOUND = 0x0000005E
+ nv_status_codes[NV_ERR_PID_NOT_FOUND] = "Process-Id not found"
+ NV_ERR_PROTECTION_FAULT = 0x0000005F
+ nv_status_codes[NV_ERR_PROTECTION_FAULT] = "Protection fault"
+ NV_ERR_RC_ERROR = 0x00000060
+ nv_status_codes[NV_ERR_RC_ERROR] = "Generic RC error"
+ NV_ERR_REJECTED_VBIOS = 0x00000061
+ nv_status_codes[NV_ERR_REJECTED_VBIOS] = "Given Video BIOS rejected/invalid"
+ NV_ERR_RESET_REQUIRED = 0x00000062
+ nv_status_codes[NV_ERR_RESET_REQUIRED] = "Reset required"
+ NV_ERR_STATE_IN_USE = 0x00000063
+ nv_status_codes[NV_ERR_STATE_IN_USE] = "State in use"
+ NV_ERR_SIGNAL_PENDING = 0x00000064
+ nv_status_codes[NV_ERR_SIGNAL_PENDING] = "Signal pending"
+ NV_ERR_TIMEOUT = 0x00000065
+ nv_status_codes[NV_ERR_TIMEOUT] = "Call timed out"
+ NV_ERR_TIMEOUT_RETRY = 0x00000066
+ nv_status_codes[NV_ERR_TIMEOUT_RETRY] = "Call timed out, please retry later"
+ NV_ERR_TOO_MANY_PRIMARIES = 0x00000067
+ nv_status_codes[NV_ERR_TOO_MANY_PRIMARIES] = "Too many primaries"
+ NV_ERR_UVM_ADDRESS_IN_USE = 0x00000068
+ nv_status_codes[NV_ERR_UVM_ADDRESS_IN_USE] = "Unified virtual memory requested address already in use"
+ NV_ERR_MAX_SESSION_LIMIT_REACHED = 0x00000069
+ nv_status_codes[NV_ERR_MAX_SESSION_LIMIT_REACHED] = "Maximum number of sessions reached"
+ NV_ERR_LIB_RM_VERSION_MISMATCH = 0x0000006A
+ nv_status_codes[NV_ERR_LIB_RM_VERSION_MISMATCH] = "Library version doesn't match driver version"
+ NV_ERR_PRIV_SEC_VIOLATION = 0x0000006B
+ nv_status_codes[NV_ERR_PRIV_SEC_VIOLATION] = "Priv security violation"
+ NV_ERR_GPU_IN_DEBUG_MODE = 0x0000006C
+ nv_status_codes[NV_ERR_GPU_IN_DEBUG_MODE] = "GPU currently in debug mode"
+ NV_ERR_FEATURE_NOT_ENABLED = 0x0000006D
+ nv_status_codes[NV_ERR_FEATURE_NOT_ENABLED] = "Requested Feature functionality is not enabled"
+ NV_ERR_RESOURCE_LOST = 0x0000006E
+ nv_status_codes[NV_ERR_RESOURCE_LOST] = "Requested resource has been destroyed"
+ NV_ERR_PMU_NOT_READY = 0x0000006F
+ nv_status_codes[NV_ERR_PMU_NOT_READY] = "PMU is not ready or has not yet been initialized"
+ NV_ERR_FLCN_ERROR = 0x00000070
+ nv_status_codes[NV_ERR_FLCN_ERROR] = "Generic falcon assert or halt"
+ NV_ERR_FATAL_ERROR = 0x00000071
+ nv_status_codes[NV_ERR_FATAL_ERROR] = "Fatal/unrecoverable error"
+ NV_ERR_MEMORY_ERROR = 0x00000072
+ nv_status_codes[NV_ERR_MEMORY_ERROR] = "Generic memory error"
+ NV_ERR_INVALID_LICENSE = 0x00000073
+ nv_status_codes[NV_ERR_INVALID_LICENSE] = "License provided is rejected or invalid"
+ NV_ERR_NVLINK_INIT_ERROR = 0x00000074
+ nv_status_codes[NV_ERR_NVLINK_INIT_ERROR] = "Nvlink Init Error"
+ NV_ERR_NVLINK_MINION_ERROR = 0x00000075
+ nv_status_codes[NV_ERR_NVLINK_MINION_ERROR] = "Nvlink Minion Error"
+ NV_ERR_NVLINK_CLOCK_ERROR = 0x00000076
+ nv_status_codes[NV_ERR_NVLINK_CLOCK_ERROR] = "Nvlink Clock Error"
+ NV_ERR_NVLINK_TRAINING_ERROR = 0x00000077
+ nv_status_codes[NV_ERR_NVLINK_TRAINING_ERROR] = "Nvlink Training Error"
+ NV_ERR_NVLINK_CONFIGURATION_ERROR = 0x00000078
+ nv_status_codes[NV_ERR_NVLINK_CONFIGURATION_ERROR] = "Nvlink Configuration Error"
+ NV_ERR_RISCV_ERROR = 0x00000079
+ nv_status_codes[NV_ERR_RISCV_ERROR] = "Generic RISC-V assert or halt"
+ NV_ERR_FABRIC_MANAGER_NOT_PRESENT = 0x0000007A
+ nv_status_codes[NV_ERR_FABRIC_MANAGER_NOT_PRESENT] = "Fabric Manager is not loaded"
+ NV_ERR_ALREADY_SIGNALLED = 0x0000007B
+ nv_status_codes[NV_ERR_ALREADY_SIGNALLED] = "Semaphore Surface value already >= requested wait value"
+ NV_ERR_QUEUE_TASK_SLOT_NOT_AVAILABLE = 0x0000007C
+ nv_status_codes[NV_ERR_QUEUE_TASK_SLOT_NOT_AVAILABLE] = "PMU RPC error due to no queue slot available for this event"
+ NV_WARN_HOT_SWITCH = 0x00010001
+ nv_status_codes[NV_WARN_HOT_SWITCH] = "WARNING Hot switch"
+ NV_WARN_INCORRECT_PERFMON_DATA = 0x00010002
+ nv_status_codes[NV_WARN_INCORRECT_PERFMON_DATA] = "WARNING Incorrect performance monitor data"
+ NV_WARN_MISMATCHED_SLAVE = 0x00010003
+ nv_status_codes[NV_WARN_MISMATCHED_SLAVE] = "WARNING Slave mismatch"
+ NV_WARN_MISMATCHED_TARGET = 0x00010004
+ nv_status_codes[NV_WARN_MISMATCHED_TARGET] = "WARNING Target mismatch"
+ NV_WARN_MORE_PROCESSING_REQUIRED = 0x00010005
+ nv_status_codes[NV_WARN_MORE_PROCESSING_REQUIRED] = "WARNING More processing required for the call"
+ NV_WARN_NOTHING_TO_DO = 0x00010006
+ nv_status_codes[NV_WARN_NOTHING_TO_DO] = "WARNING Nothing to do"
+ NV_WARN_NULL_OBJECT = 0x00010007
+ nv_status_codes[NV_WARN_NULL_OBJECT] = "WARNING NULL object found"
+ NV_WARN_OUT_OF_RANGE = 0x00010008
+ nv_status_codes[NV_WARN_OUT_OF_RANGE] = "WARNING value out of range"
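
Note: the block above adds a human-readable lookup table for NVIDIA RM status codes to the autogenerated nv_gpu module. A minimal sketch of how a caller might use it to turn a raw status into a Python exception (check_nv_status and the op string are hypothetical helpers, not part of the diff; only nv_status_codes and NV_OK are):

from tinygrad.runtime.autogen.nv_gpu import nv_status_codes, NV_OK

def check_nv_status(status: int, op: str = "rm ioctl") -> None:
  # Map a raw NV status to its message; unknown codes fall back to hex
  # so nothing is silently swallowed.
  if status != NV_OK:
    msg = nv_status_codes.get(status, "unknown status")
    raise RuntimeError(f"{op} failed: {msg} ({status:#010x})")

check_nv_status(NV_OK)          # no-op on success
# check_nv_status(0x00000051)   # would raise: "Out of memory"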
tinygrad/runtime/driver/hip_comgr.py
@@ -14,7 +14,7 @@ def _get_comgr_data(data_set, data_type):
  return bytes(dat)

  # AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
- def compile_hip(prg:str, arch="gfx1100") -> bytes:
+ def compile_hip(prg:str, arch="gfx1100", asm=False) -> bytes:
  check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
  check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
  check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
@@ -27,17 +27,26 @@ def compile_hip(prg:str, arch="gfx1100") -> bytes:

  check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(data_src := comgr.amd_comgr_data_t())))
  check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))
- check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))

- check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
- # -include hiprtc_runtime.h was removed
- check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
- status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
- if status != 0:
- print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
- raise RuntimeError("compile failed")
- check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
- check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
+ if asm:
+ check(comgr.amd_comgr_set_data_name(data_src, b"<null>.s"))
+ check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
+ status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_ASSEMBLE_SOURCE_TO_RELOCATABLE, action_info, data_set_src, data_set_reloc)
+ if status != 0:
+ print(_get_comgr_data(data_set_reloc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
+ raise RuntimeError("assemble failed")
+ else:
+ check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))
+ check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
+ # -include hiprtc_runtime.h was removed
+ check(comgr.amd_comgr_action_info_set_options(action_info, f"-O3 -mcumode --hip-version=6.0.32830 -DHIP_VERSION_MAJOR=6 -DHIP_VERSION_MINOR=0 -DHIP_VERSION_PATCH=32830 -D__HIPCC_RTC__ -std=c++14 -nogpuinc -Wno-gnu-line-marker -Wno-missing-prototypes --offload-arch={arch} -I/opt/rocm/include -Xclang -disable-llvm-passes".encode())) # noqa: E501
+ status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
+ if status != 0:
+ print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
+ raise RuntimeError("compile failed")
+ check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
+ check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
+
  check(comgr.amd_comgr_action_info_set_options(action_info, b""))
  check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
  ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
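
Note: compile_hip now takes an asm flag. When set, the source is treated as AMD GCN assembly and run through comgr's ASSEMBLE_SOURCE_TO_RELOCATABLE action (note the ".s" data name) instead of the HIP C++ compile-to-bitcode path; both paths then share the same link step. A hedged usage sketch, assuming a working ROCm comgr install (the kernel source is illustrative, not from the diff):

from tinygrad.runtime.driver.hip_comgr import compile_hip

# Default path: HIP C++ source -> device libs -> bitcode -> executable.
hip_src = 'extern "C" __global__ void add1(float* x) { x[0] += 1.0f; }'
lib = compile_hip(hip_src, arch="gfx1100")
print(len(lib), "bytes of code object")

# New path: hand-written assembly is assembled, not compiled. A real
# kernel also needs target/metadata directives; this only shows how the
# flag is wired.
# lib = compile_hip(asm_src, arch="gfx1100", asm=True)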
tinygrad/runtime/graph/clang.py
@@ -16,7 +16,7 @@ class ClangGraph(GraphRunner):

  prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
  args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
- args += [f"int {v.expr}" for v in var_vals]
+ args += sorted([f"int {v.expr}" for v in var_vals])
  code = ["void batched("+','.join(args)+") {"]
  for ji in jit_cache:
  args = []
@@ -35,4 +35,5 @@ class ClangGraph(GraphRunner):
  self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers

  def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
- return cpu_time_execution(lambda: self.clprg(*[x._buf for x in rawbufs], *[x for x in var_vals.values()]), enable=wait)
+ return cpu_time_execution(
+ lambda: self.clprg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait)
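
Note: both sides of this ClangGraph change sort by the variable's expr. The batched(...) signature sorts its int parameters, and the call in __call__ sorts var_vals.items() the same way, so values always line up with parameters regardless of dict insertion order. A small sketch of the invariant (plain strings stand in for Variable objects keyed by .expr):

var_vals = {"j": 7, "i": 3}                           # insertion order: j, i

params = sorted(f"int {name}" for name in var_vals)   # definition order
values = [v for _, v in sorted(var_vals.items())]     # call-site order
assert params == ["int i", "int j"] and values == [3, 7]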
tinygrad/runtime/graph/cuda.py
@@ -1,7 +1,7 @@
  import ctypes
  from typing import Any, Optional, Tuple, Dict, List, cast
  import tinygrad.runtime.autogen.cuda as cuda
- from tinygrad.helpers import init_c_var, GraphException
+ from tinygrad.helpers import init_c_var, GraphException, dedup
  from tinygrad.device import Buffer, Device
  from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
  from tinygrad.shape.symbolic import Variable
@@ -15,7 +15,7 @@ class CUDAGraph(MultiGraphRunner):
  # Check all jit items are compatible.
  if not all(isinstance(ji.prg, (CompiledRunner, BufferXfer)) for ji in jit_cache): raise GraphException

- self.jc_idx_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
+ self.jc_idx_with_updatable_rawbufs = dedup([x[0] for x in self.input_replace.keys()])
  self.updatable_nodes: Dict[int, Tuple[Any, Any, Any, bool]] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params, is memcpy)

  self.graph = init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
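
Note: list(set(...)) discards the original order (set iteration order is arbitrary), so the list of jit-cache indices could differ between runs; dedup keeps first-seen order. A sketch of the behaviour, assuming dedup is the usual dict.fromkeys idiom from tinygrad/helpers.py:

def dedup(x):  # order-preserving de-duplication, as in tinygrad.helpers
  return list(dict.fromkeys(x))

jc_idxs = [2, 0, 2, 1, 0]
assert dedup(jc_idxs) == [2, 0, 1]   # stable: first occurrence wins
assert set(jc_idxs) == {0, 1, 2}     # a bare set has no defined order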
tinygrad/runtime/graph/hcq.py
@@ -1,134 +1,160 @@
- import ctypes, collections, array, time
+ import collections, array, time
  from typing import List, Any, Dict, cast, Optional, Tuple, Set
- from tinygrad.helpers import GraphException, round_up, to_mv, init_c_struct_t
+ from tinygrad.helpers import round_up, to_mv, PROFILE
  from tinygrad.device import Buffer, BufferOptions, Compiled, Device
  from tinygrad.shape.symbolic import Variable
  from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
  from tinygrad.engine.jit import MultiGraphRunner

  class HCQGraph(MultiGraphRunner):
- def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
  super().__init__(jit_cache, input_rawbuffers, var_vals)
- self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t
-
- # Check all jit items are compatible.
- self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
- if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException
+ self.devices = list(set(cast(Any, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

  # Allocate kernel args.
  kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
  for ji in self.jit_cache:
  if not isinstance(ji.prg, CompiledRunner): continue
- kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
- kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}
+ kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+ self.kernargs_bufs: Dict[Compiled, Any] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
+ kernargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}

  # Fill initial arguments.
  self.kargs_addrs: Dict[int, int] = {}
- self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
+ self.ji_args_bufs: Dict[int, memoryview] = {}
+ self.ji_args_vars: Dict[int, memoryview] = {}
  for j,ji in enumerate(self.jit_cache):
  if not isinstance(ji.prg, CompiledRunner): continue
  self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
- kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)
+ kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)

- args_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(ji.bufs))] +
- [(f'v{i}', ctypes.c_int) for i in range(len(ji.prg.p.vars))]))
- self.ji_kargs_structs[j] = args_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
- for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
- for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
+ self.ji_args_bufs[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset, len(ji.bufs) * 8).cast('Q')
+ self.ji_args_vars[j] = to_mv(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset + len(ji.bufs) * 8, len(ji.prg.p.vars) * 4).cast('I')
+ for i in range(len(ji.bufs)): self.ji_args_bufs[j][i] = cast(Buffer, ji.bufs[i])._buf.va_addr
+ for i in range(len(ji.prg.p.vars)): self.ji_args_vars[j][i] = var_vals[ji.prg.p.vars[i]]

  # NV needs constbuffer to be set
  if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)

- # Build queues.
- self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
- self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
- self.comp_signal_val = {dev: 0 for dev in self.devices}
+ # Schedule Dependencies.
+ # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
+ # graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
+ # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
+ # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
+ self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
+ self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
+
+ self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, sigval, prof_info)]
+ self.signals: Dict[Any, Any] = {q: self.devices[0]._get_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
+ self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
+ self.kickoff_value = 0
+
+ self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
+ for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)

- self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
- self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
- self.copy_signal_val = {dev: 0 for dev in self.devices}
+ self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
+ self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

- self.kickoff_signal = self.devices[0]._get_signal(value=0)
- self.kickoff_value = 0
- self.graph_timeline = {dev: 0 for dev in self.devices}
+ for j,ji in enumerate(self.jit_cache):
+ enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
+ enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
+ out_signal = self.signals[enqueue_queue]
+ writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
+ deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
+
+ if isinstance(ji.prg, CompiledRunner):
+ # Update signal on compute kernel to depend on the previous kernel.
+ if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]

+ # Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
+ if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
+ deps = [x for x in deps if id(x[0]) != id(out_signal)]
+ elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]
+
+ # Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
+ for sig, val in deps:
+ if id(sig) in [id(x) for x in self.signals.values()]:
+ self.signal_sched[val - 1] = self.signal_sched[val - 1][:1] + (val,) + self.signal_sched[val - 1][2:]
+
+ prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
+ prof_info = ([enqueue_dev._get_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
+ self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
+ self.last_ji[enqueue_queue] = j
+
+ # Build hardware queues.
  self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
  self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
+ self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
+
+ for dev in self.devices:
+ self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
+ .wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)

  for j,ji in enumerate(self.jit_cache):
+ deps, signal_value, prof_info = self.signal_sched[j]
+ enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
+
+ # Encode waits and start profile timestamp (if needed).
+ for sig, val in deps:
+ enqueue_queue.wait(sig, val)
+ if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
+ if prof_info: enqueue_queue.timestamp(prof_info[0])
+
+ # Encode main commands based on ji type.
  if isinstance(ji.prg, CompiledRunner):
- exec_params = {}
- deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
- deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])]
-
- # On NV, to synchronize kernel execution, we must either issue a wait or chain executions to schedule them in order.
- # Chaining executions is preferred when possible, as it is faster.
- if ji.prg.device.dname.startswith("NV"):
- if len(deps) == 0 and self.comp_signal_val[ji.prg.device] > 0:
- exec_params['chain_exec_ptr'] = self.exec_ptrs[self.comp_signal_val[ji.prg.device] - 1][1]
- else: deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
-
- for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
-
- self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
- self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
- signal=self.comp_signal[ji.prg.device], signal_value=sig_val, **exec_params)
- self.comp_signal_val[ji.prg.device] = sig_val
+ enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
+ signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
+ self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
  elif isinstance(ji.prg, BufferXfer):
  dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
  Device[src.device]._gpu_map(dest._buf) #type: ignore
-
- deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
- deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
- self.copy_signal_val[Device[src.device]] = sig_val
-
- for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
- self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
- .signal(self.copy_signal[Device[src.device]], sig_val)
+ enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes).signal(self.signals[enqueue_queue], signal_value)
  self.copy_to_devs[Device[dest.device]].add(Device[src.device])

+ # Encode finish profile timestamp (if needed).
+ if prof_info: enqueue_queue.timestamp(prof_info[1])
+
  for dev in self.devices:
- if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
- for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
+ for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
+ if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
+ self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][1])

+ self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
  if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
- if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
+ if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
  # Wait and restore signals
  self.kickoff_value += 1
  for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
- for dev in self.devices:
- dev._set_signal(self.comp_signal[dev], 0)
- dev._set_signal(self.copy_signal[dev], 0)
- dev._set_signal(self.kickoff_signal, self.kickoff_value)
+ for queue in self.comp_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
+ for queue in self.copy_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
+ self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)
+
+ if PROFILE and self.kickoff_value > 1:
+ for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
+ dev.raw_prof_records += [(dev._read_timestamp(st), dev._read_timestamp(en), desc, is_cp)]

  # Update rawbuffers
- for (j,i),input_idx in self.input_replace.items():
- self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)
+ for (j,i),input_idx in self.input_replace.items(): self.ji_args_bufs[j][i] = input_rawbuffers[input_idx]._buf.va_addr

  # Update var_vals
  for j in self.jc_idx_with_updatable_var_vals:
- for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
- self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
+ for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars): self.ji_args_vars[j][i] = var_vals[v]

  for j in self.jc_idx_with_updatable_launch_dims:
  queue, cmd_ptr = self.exec_ptrs[j]
  queue.update_exec(cmd_ptr, *cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals))

  for dev in self.devices:
- # Submit sync with world and queues.
- self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
- .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
- self.comp_queues[dev].submit(dev)
-
- if self.copy_signal_val[dev] > 0:
- self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
- .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
- self.copy_queues[dev].submit(dev)
-
- # Signal the final value
- self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
+ self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
+ .update_signal(3, value=self.kickoff_value) \
+ .update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)
+
+ if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
+ for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
+ cp_queue.submit(dev)
+
  self.graph_timeline[dev] = dev.timeline_value
  dev.timeline_value += 1

@@ -138,6 +164,24 @@ class HCQGraph(MultiGraphRunner):
  return time.perf_counter() - st
  return None

- def access_resources(self, read, write, new_dependency):
- deps = self._access_resources(read, write, new_dependency)
- return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
+ def access_resources(self, queue, read, write, new_val):
+ deps = self._access_resources(read, write, (queue, new_val))
+
+ sync_signals = []
+ for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
+ for buf in read+write:
+ if buf.device not in self.save_devs[queue]:
+ self.save_devs[queue].add(buf.device)
+ sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]
+
+ return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
+
+ def __del__(self):
+ for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
+
+ # Graph is destructed. No need to keep signals any more, so return them as part of profiling.
+ if PROFILE and self.kickoff_value > 1:
+ for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
+
+ self.devices[0].signals_pool += list(self.dev_kickoff_signal.values()) + list(self.signals.values()) # type: ignore
+ for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
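
Note: the rewritten HCQGraph records each device's compute and copy queues once at build time, then on every replay patches the existing wait/signal slots in place (update_wait/update_signal) and resubmits, instead of re-recording fresh comp_hcq_t()/copy_hcq_t() queues as the old code did. A toy sketch of why in-place patching works (ToyQueue and its methods are illustrative, not the real HCQ API):

class ToyQueue:
  # Commands keep stable indices, so a replay only rewrites the values
  # of existing wait/signal slots rather than rebuilding the queue.
  def __init__(self): self.cmds = []
  def wait(self, sig, val): self.cmds.append(["wait", sig, val]); return self
  def signal(self, sig, val): self.cmds.append(["signal", sig, val]); return self
  def update_wait(self, i, sig=None, value=None):
    if sig is not None: self.cmds[i][1] = sig
    if value is not None: self.cmds[i][2] = value
    return self

q = ToyQueue().wait("timeline", 0).wait("kickoff", 0).signal("kickoff", 0)
q.update_wait(0, value=41).update_wait(1, value=1)   # patch for the next replay
assert q.cmds[0] == ["wait", "timeline", 41]
assert q.cmds[1] == ["wait", "kickoff", 1]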