gstaichi 0.1.25.dev0__cp311-cp311-macosx_15_0_arm64.whl → 2.0.0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +6 -0
- gstaichi/__init__.py +1 -1
- gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +11 -41
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -1
- gstaichi/examples/minimal.py +1 -1
- gstaichi/lang/__init__.py +1 -1
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_template_mapper.py +16 -20
- gstaichi/lang/_wrap_inspect.py +27 -1
- gstaichi/lang/ast/ast_transformer.py +7 -2
- gstaichi/lang/ast/ast_transformer_utils.py +18 -13
- gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
- gstaichi/lang/field.py +0 -38
- gstaichi/lang/impl.py +25 -24
- gstaichi/lang/kernel_arguments.py +28 -30
- gstaichi/lang/kernel_impl.py +154 -200
- gstaichi/lang/matrix.py +0 -46
- gstaichi/lang/struct.py +0 -45
- gstaichi/lang/util.py +11 -80
- gstaichi/types/annotations.py +10 -5
- gstaichi/types/compound_types.py +1 -20
- gstaichi/types/ndarray_type.py +31 -11
- gstaichi/types/utils.py +0 -2
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/METADATA +2 -1
- gstaichi-2.0.0.dist-info/RECORD +177 -0
- gstaichi/__main__.py +0 -5
- gstaichi/_main.py +0 -545
- gstaichi/lang/argpack.py +0 -411
- gstaichi-0.1.25.dev0.dist-info/RECORD +0 -168
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +0 -2
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/GLFW/glfw3native.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/GLSL.std.450.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cfg.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_common.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cpp.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_c.h +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_containers.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_error_handling.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_cross_util.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_glsl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_hlsl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_msl.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_parser.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/include/spirv_cross/spirv_reflect.hpp +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Config.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +0 -0
- {gstaichi-0.1.25.dev0.data → gstaichi-2.0.0.data}/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/licenses/LICENSE +0 -0
- {gstaichi-0.1.25.dev0.dist-info → gstaichi-2.0.0.dist-info}/top_level.txt +0 -0
gstaichi/CHANGELOG.md
ADDED
gstaichi/__init__.py
CHANGED
@@ -24,7 +24,7 @@ from gstaichi.types.primitive_types import *
|
|
24
24
|
|
25
25
|
def __getattr__(attr):
    # Module-level lazy attribute: `cfg` is resolved on access so that it reflects
    # the currently initialized program (or None before any program exists).
    if attr != "cfg":
        raise AttributeError(f"module '{__name__}' has no attribute '{attr}'")
    runtime = lang.impl.get_runtime()
    if runtime._prog is None:
        return None
    return lang.impl.current_cfg()
|
29
29
|
|
30
30
|
|
Binary file
|
@@ -4,7 +4,7 @@ gstaichi_python
|
|
4
4
|
from __future__ import annotations
|
5
5
|
import numpy
|
6
6
|
import typing
|
7
|
-
__all__ = ['ADJOINT', 'ADJOINT_CHECKBIT', 'AOS', 'ASTBuilder', 'Arch', '
|
7
|
+
__all__: list[str] = ['ADJOINT', 'ADJOINT_CHECKBIT', 'AOS', 'ASTBuilder', 'Arch', 'AutodiffMode', 'Axis', 'Benchmark', 'BinaryOpType', 'BitStructType', 'BitStructTypeBuilder', 'BoundaryMode', 'CC', 'CE', 'CF', 'CGd', 'CGf', 'CLAMP', 'CUCG', 'CV', 'Cell', 'CompileConfig', 'CompiledKernelData', 'Config', 'ConvType', 'CuSparseMatrix', 'CuSparseSolver', 'DUAL', 'DataTypeCxx', 'DataType_f16', 'DataType_f32', 'DataType_f64', 'DataType_gen', 'DataType_i16', 'DataType_i32', 'DataType_i64', 'DataType_i8', 'DataType_u1', 'DataType_u16', 'DataType_u32', 'DataType_u64', 'DataType_u8', 'DataType_unknown', 'DebugInfo', 'DeviceAllocation', 'DeviceCapabilityConfig', 'EC', 'EE', 'EF', 'EV', 'Edge', 'EigenSparseSolverfloat32LDLTAMD', 'EigenSparseSolverfloat32LDLTCOLAMD', 'EigenSparseSolverfloat32LLTAMD', 'EigenSparseSolverfloat32LLTCOLAMD', 'EigenSparseSolverfloat32LUAMD', 'EigenSparseSolverfloat32LUCOLAMD', 'EigenSparseSolverfloat64LDLTAMD', 'EigenSparseSolverfloat64LDLTCOLAMD', 'EigenSparseSolverfloat64LLTAMD', 'EigenSparseSolverfloat64LLTCOLAMD', 'EigenSparseSolverfloat64LUAMD', 'EigenSparseSolverfloat64LUCOLAMD', 'ExprCxx', 'ExprGroup', 'Extension', 'FC', 'FE', 'FF', 'FORWARD', 'FV', 'Face', 'Format', 'Function', 'FunctionKey', 'GsTaichiAssertionError', 'GsTaichiIndexError', 'GsTaichiRuntimeError', 'GsTaichiSyntaxError', 'GsTaichiTypeError', 'HackedSignalRegister', 'InternalOp', 'KernelCxx', 'KernelLaunchContext', 'KernelProfileTracedRecord', 'KernelProfilerQueryResult', 'Layout', 'Mesh', 'MeshElementType', 'MeshPtr', 'MeshRelationType', 'MeshTopology', 'NONE', 'NULL', 'NdarrayCxx', 'Operation', 'PRIMAL', 'Program', 'REVERSE', 'SNodeAccessFlag', 'SNodeCxx', 'SNodeGradType', 'SNodeRegistry', 'SNodeTreeCxx', 'SNodeType', 'SOA', 'SparseMatrix', 'SparseMatrixBuilder', 'SparseSolver', 'Stmt', 'Task', 'Tetrahedron', 'TextureCxx', 'TextureOpType', 'Triangle', 'Type', 'TypeFactory', 'UNSAFE', 'UnaryOpType', 'VALIDATION', 'VC', 'VE', 'VF', 'VV', 'Vector2d', 'Vector2f', 'Vector2i', 
'Vector3d', 'Vector3f', 'Vector3i', 'Vector4d', 'Vector4f', 'Vector4i', 'Vertex', 'abs', 'acos', 'add', 'adstack', 'amdgpu', 'arch_from_name', 'arch_name', 'arch_uses_llvm', 'arm64', 'asin', 'assertion', 'atan2', 'bit_and', 'bit_not', 'bit_or', 'bit_sar', 'bit_shl', 'bit_shr', 'bit_struct', 'bit_xor', 'bitmasked', 'bits_cast', 'block_local', 'bls', 'cast_bits', 'cast_value', 'ceil', 'clear_profile_info', 'clz', 'cmp_eq', 'cmp_ge', 'cmp_gt', 'cmp_le', 'cmp_lt', 'cmp_ne', 'cos', 'create_benchmark', 'create_initialized_benchmark', 'create_initialized_task', 'create_mesh', 'create_task', 'critical', 'cuda', 'cuda_version', 'dColMajor_EigenSparseMatrix', 'dRowMajor_EigenSparseMatrix', 'data64', 'data_type_name', 'data_type_size', 'debug', 'default_compile_config', 'dense', 'div', 'dynamic', 'element_order', 'element_type_name', 'error', 'exp', 'expr_abs', 'expr_acos', 'expr_add', 'expr_asin', 'expr_assume_in_range', 'expr_atan2', 'expr_atomic_add', 'expr_atomic_bit_and', 'expr_atomic_bit_or', 'expr_atomic_bit_xor', 'expr_atomic_max', 'expr_atomic_min', 'expr_atomic_mul', 'expr_atomic_sub', 'expr_bit_and', 'expr_bit_not', 'expr_bit_or', 'expr_bit_sar', 'expr_bit_shl', 'expr_bit_shr', 'expr_bit_xor', 'expr_ceil', 'expr_clz', 'expr_cmp_eq', 'expr_cmp_ge', 'expr_cmp_gt', 'expr_cmp_le', 'expr_cmp_lt', 'expr_cmp_ne', 'expr_cos', 'expr_div', 'expr_exp', 'expr_field', 'expr_floor', 'expr_floordiv', 'expr_frexp', 'expr_ifte', 'expr_inv', 'expr_log', 'expr_logic_not', 'expr_logical_and', 'expr_logical_or', 'expr_loop_unique', 'expr_matrix_field', 'expr_max', 'expr_min', 'expr_mod', 'expr_mul', 'expr_neg', 'expr_popcnt', 'expr_pow', 'expr_rcp', 'expr_round', 'expr_rsqrt', 'expr_select', 'expr_sin', 'expr_sqrt', 'expr_sub', 'expr_tan', 'expr_tanh', 'expr_truediv', 'extfunc', 'fColMajor_EigenSparseMatrix', 'fRowMajor_EigenSparseMatrix', 'finalize_snode_tree', 'floor', 'floordiv', 'flush_log', 'frexp', 'from_end_element_order', 'g2r', 'get_commit_hash', 'get_default_float_size', 
'get_external_tensor_dim', 'get_external_tensor_element_dim', 'get_external_tensor_element_shape', 'get_external_tensor_element_type', 'get_external_tensor_needs_grad', 'get_external_tensor_real_func_args', 'get_external_tensor_shape_along_axis', 'get_llvm_target_support', 'get_max_num_indices', 'get_num_elements', 'get_python_package_dir', 'get_relation_access', 'get_relation_size', 'get_repo_dir', 'get_type_factory_instance', 'get_version_major', 'get_version_minor', 'get_version_patch', 'get_version_string', 'hash', 'host_arch', 'info', 'insert_internal_func_call', 'inv', 'inverse_relation', 'is_extension_supported', 'is_integral', 'is_quant', 'is_real', 'is_signed', 'is_tensor', 'is_unsigned', 'js', 'kFetchTexel', 'kLoad', 'kSampleLod', 'kStore', 'kUndefined', 'l2g', 'l2r', 'libdevice_path', 'log', 'logging_effective', 'logic_not', 'make_arg_load_expr', 'make_binary_op_expr', 'make_const_expr_bool', 'make_const_expr_fp', 'make_const_expr_int', 'make_cucg_solver', 'make_cusparse_solver', 'make_double_cg_solver', 'make_external_tensor_expr', 'make_external_tensor_grad_expr', 'make_float_cg_solver', 'make_frontend_assign_stmt', 'make_get_element_expr', 'make_global_load_stmt', 'make_global_store_stmt', 'make_rand_expr', 'make_reference', 'make_rw_texture_ptr_expr', 'make_sparse_solver', 'make_texture_ptr_expr', 'make_unary_op_expr', 'max', 'mesh', 'mesh_local', 'metal', 'min', 'mod', 'mul', 'neg', 'opencl', 'place', 'pointer', 'pop_python_print_buffer', 'popcnt', 'pow', 'print_all_units', 'print_profile_info', 'promoted_type', 'quant', 'quant_array', 'quant_basic', 'query_int64', 'rcp', 'read_only', 'relation_by_orders', 'reset_default_compile_config', 'root', 'round', 'rsqrt', 'set_core_state_python_imported', 'set_core_trigger_gdb_when_crash', 'set_index_mapping', 'set_lib_dir', 'set_logging_level', 'set_logging_level_default', 'set_num_elements', 'set_num_patches', 'set_owned_offset', 'set_patch_max_element_num', 'set_python_package_dir', 
'set_relation_dynamic', 'set_relation_fixed', 'set_tmp_dir', 'set_total_offset', 'set_vulkan_visible_device', 'sgn', 'sin', 'sparse', 'sqrt', 'start_memory_monitoring', 'sub', 'subscript_with_multiple_indices', 'tan', 'tanh', 'test_cpp_exception', 'test_logging', 'test_printf', 'test_raise_error', 'test_threading', 'test_throw', 'to_end_element_order', 'toggle_python_print_buffer', 'trace', 'trigger_crash', 'trigger_sig_fpe', 'truediv', 'undefined', 'value_cast', 'vulkan', 'wait_for_debugger', 'warn', 'with_amdgpu', 'with_cuda', 'with_metal', 'with_vulkan', 'x64']
|
8
8
|
class ASTBuilder:
|
9
9
|
def begin_frontend_if(self, arg0: ..., arg1: DebugInfo) -> None:
|
10
10
|
...
|
@@ -167,26 +167,6 @@ class Arch:
|
|
167
167
|
@property
|
168
168
|
def value(self) -> int:
|
169
169
|
...
|
170
|
-
class ArgPackCxx:
|
171
|
-
def data_type(self) -> DataTypeCxx:
|
172
|
-
...
|
173
|
-
def device_allocation(self) -> DeviceAllocation:
|
174
|
-
...
|
175
|
-
def device_allocation_ptr(self) -> int:
|
176
|
-
...
|
177
|
-
def nelement(self) -> int:
|
178
|
-
...
|
179
|
-
def set_arg_float(self, arg0: tuple[int, ...], arg1: float) -> None:
|
180
|
-
...
|
181
|
-
def set_arg_int(self, arg0: tuple[int, ...], arg1: int) -> None:
|
182
|
-
...
|
183
|
-
def set_arg_nested_argpack(self, arg0: int, arg1: ArgPackCxx) -> None:
|
184
|
-
...
|
185
|
-
def set_arg_uint(self, arg0: tuple[int, ...], arg1: int) -> None:
|
186
|
-
...
|
187
|
-
@property
|
188
|
-
def dtype(self) -> DataTypeCxx:
|
189
|
-
...
|
190
170
|
class AutodiffMode:
|
191
171
|
"""
|
192
172
|
Members:
|
@@ -1213,8 +1193,6 @@ class KernelCxx:
|
|
1213
1193
|
...
|
1214
1194
|
def finalize_rets(self) -> None:
|
1215
1195
|
...
|
1216
|
-
def insert_argpack_param_and_push(self, arg0: str) -> tuple[int, ...]:
|
1217
|
-
...
|
1218
1196
|
def insert_arr_param(self, arg0: DataTypeCxx, arg1: int, arg2: tuple[int, ...], arg3: str) -> tuple[int, ...]:
|
1219
1197
|
...
|
1220
1198
|
def insert_ndarray_param(self, arg0: DataTypeCxx, arg1: int, arg2: str, arg3: bool) -> tuple[int, ...]:
|
@@ -1233,8 +1211,6 @@ class KernelCxx:
|
|
1233
1211
|
...
|
1234
1212
|
def no_activate(self, arg0: SNodeCxx) -> None:
|
1235
1213
|
...
|
1236
|
-
def pop_argpack_stack(self) -> None:
|
1237
|
-
...
|
1238
1214
|
class KernelLaunchContext:
|
1239
1215
|
def get_struct_ret_float(self, arg0: tuple[int, ...]) -> float:
|
1240
1216
|
...
|
@@ -1242,8 +1218,6 @@ class KernelLaunchContext:
|
|
1242
1218
|
...
|
1243
1219
|
def get_struct_ret_uint(self, arg0: tuple[int, ...]) -> int:
|
1244
1220
|
...
|
1245
|
-
def set_arg_argpack(self, arg0: tuple[int, ...], arg1: ArgPackCxx) -> None:
|
1246
|
-
...
|
1247
1221
|
def set_arg_external_array_with_shape(self, arg0: tuple[int, ...], arg1: int, arg2: int, arg3: tuple[int, ...], arg4: int) -> None:
|
1248
1222
|
...
|
1249
1223
|
def set_arg_float(self, arg0: tuple[int, ...], arg1: float) -> None:
|
@@ -1558,8 +1532,6 @@ class Program:
|
|
1558
1532
|
...
|
1559
1533
|
def config(self) -> CompileConfig:
|
1560
1534
|
...
|
1561
|
-
def create_argpack(self, dt: DataTypeCxx) -> ...:
|
1562
|
-
...
|
1563
1535
|
def create_function(self, arg0: ...) -> Function:
|
1564
1536
|
...
|
1565
1537
|
def create_kernel(self, arg0: typing.Callable[[...], None], arg1: str, arg2: AutodiffMode) -> KernelCxx:
|
@@ -1570,10 +1542,10 @@ class Program:
|
|
1570
1542
|
...
|
1571
1543
|
def create_texture(self, fmt: ..., shape: tuple[int, ...] = ()) -> ...:
|
1572
1544
|
...
|
1573
|
-
def delete_argpack(self, arg0: ...) -> None:
|
1574
|
-
...
|
1575
1545
|
def delete_ndarray(self, arg0: ...) -> None:
|
1576
1546
|
...
|
1547
|
+
def dump_cache_data_to_disk(self) -> None:
|
1548
|
+
...
|
1577
1549
|
def fill_float(self, arg0: ..., arg1: float) -> None:
|
1578
1550
|
...
|
1579
1551
|
def fill_int(self, arg0: ..., arg1: int) -> None:
|
@@ -1604,6 +1576,8 @@ class Program:
|
|
1604
1576
|
...
|
1605
1577
|
def launch_kernel(self, arg0: CompiledKernelData, arg1: ...) -> None:
|
1606
1578
|
...
|
1579
|
+
def load_fast_cache(self, arg0: str, arg1: str, arg2: CompileConfig, arg3: DeviceCapabilityConfig) -> CompiledKernelData:
|
1580
|
+
...
|
1607
1581
|
def make_id_expr(self, arg0: str) -> ...:
|
1608
1582
|
...
|
1609
1583
|
def make_sparse_matrix_from_ndarray(self, arg0: ..., arg1: ...) -> None:
|
@@ -1618,6 +1592,8 @@ class Program:
|
|
1618
1592
|
...
|
1619
1593
|
def set_kernel_profiler_toolkit(self, arg0: str) -> bool:
|
1620
1594
|
...
|
1595
|
+
def store_fast_cache(self, arg0: str, arg1: ..., arg2: CompileConfig, arg3: DeviceCapabilityConfig, arg4: CompiledKernelData) -> None:
|
1596
|
+
...
|
1621
1597
|
def sync_kernel_profiler(self) -> None:
|
1622
1598
|
...
|
1623
1599
|
def synchronize(self) -> None:
|
@@ -2013,8 +1989,6 @@ class Type:
|
|
2013
1989
|
def to_string(self) -> str:
|
2014
1990
|
...
|
2015
1991
|
class TypeFactory:
|
2016
|
-
def get_argpack_type(self, arg0: list[tuple[DataTypeCxx, str]]) -> DataTypeCxx:
|
2017
|
-
...
|
2018
1992
|
def get_ndarray_struct_type(self, dt: DataTypeCxx, ndim: int, needs_grad: bool) -> Type:
|
2019
1993
|
...
|
2020
1994
|
def get_quant_fixed_type(self, digits_type: Type, compute_type: Type, scale: float) -> Type:
|
@@ -2027,8 +2001,6 @@ class TypeFactory:
|
|
2027
2001
|
...
|
2028
2002
|
def get_struct_type(self, arg0: list[tuple[DataTypeCxx, str]]) -> DataTypeCxx:
|
2029
2003
|
...
|
2030
|
-
def get_struct_type_for_argpack_ptr(self, dt: DataTypeCxx, layout: str = 'none') -> Type:
|
2031
|
-
...
|
2032
2004
|
def get_tensor_type(self, shape: tuple[int, ...], element_type: DataType) -> DataTypeCxx:
|
2033
2005
|
...
|
2034
2006
|
class UnaryOpType:
|
@@ -2442,8 +2414,6 @@ def arch_uses_llvm(arg0: Arch) -> bool:
|
|
2442
2414
|
...
|
2443
2415
|
def bits_cast(arg0: ExprCxx, arg1: DataTypeCxx) -> ExprCxx:
|
2444
2416
|
...
|
2445
|
-
def clean_offline_cache_files(arg0: str) -> int:
|
2446
|
-
...
|
2447
2417
|
def clear_profile_info() -> None:
|
2448
2418
|
...
|
2449
2419
|
def create_benchmark(arg0: str) -> ...:
|
@@ -2670,7 +2640,7 @@ def libdevice_path() -> str:
|
|
2670
2640
|
...
|
2671
2641
|
def logging_effective(arg0: str) -> bool:
|
2672
2642
|
...
|
2673
|
-
def make_arg_load_expr(arg_id: tuple[int, ...], dt: DataTypeCxx, is_ptr: bool = False, create_load: bool = True,
|
2643
|
+
def make_arg_load_expr(arg_id: tuple[int, ...], dt: DataTypeCxx, is_ptr: bool = False, create_load: bool = True, dbg_info: DebugInfo = ...) -> ExprCxx:
|
2674
2644
|
...
|
2675
2645
|
def make_binary_op_expr(arg0: BinaryOpType, arg1: ExprCxx, arg2: ExprCxx) -> ExprCxx:
|
2676
2646
|
...
|
@@ -2686,7 +2656,7 @@ def make_cusparse_solver(arg0: DataTypeCxx, arg1: str, arg2: str) -> SparseSolve
|
|
2686
2656
|
...
|
2687
2657
|
def make_double_cg_solver(arg0: SparseMatrix, arg1: int, arg2: float, arg3: bool) -> CGd:
|
2688
2658
|
...
|
2689
|
-
def make_external_tensor_expr(arg0: DataTypeCxx, arg1: int, arg2: tuple[int, ...], arg3: bool, arg4:
|
2659
|
+
def make_external_tensor_expr(arg0: DataTypeCxx, arg1: int, arg2: tuple[int, ...], arg3: bool, arg4: BoundaryMode) -> ExprCxx:
|
2690
2660
|
...
|
2691
2661
|
def make_external_tensor_grad_expr(arg0: ExprCxx) -> ExprCxx:
|
2692
2662
|
...
|
@@ -2704,11 +2674,11 @@ def make_rand_expr(arg0: DataTypeCxx, arg1: DebugInfo) -> ExprCxx:
|
|
2704
2674
|
...
|
2705
2675
|
def make_reference(arg0: ExprCxx, arg1: DebugInfo) -> ExprCxx:
|
2706
2676
|
...
|
2707
|
-
def make_rw_texture_ptr_expr(arg0: tuple[int, ...], arg1: int, arg2:
|
2677
|
+
def make_rw_texture_ptr_expr(arg0: tuple[int, ...], arg1: int, arg2: Format, arg3: int, arg4: DebugInfo) -> ExprCxx:
|
2708
2678
|
...
|
2709
2679
|
def make_sparse_solver(arg0: DataTypeCxx, arg1: str, arg2: str) -> SparseSolver:
|
2710
2680
|
...
|
2711
|
-
def make_texture_ptr_expr(arg0: tuple[int, ...], arg1: int, arg2:
|
2681
|
+
def make_texture_ptr_expr(arg0: tuple[int, ...], arg1: int, arg2: DebugInfo) -> ExprCxx:
|
2712
2682
|
...
|
2713
2683
|
def make_unary_op_expr(arg0: UnaryOpType, arg1: ExprCxx) -> ExprCxx:
|
2714
2684
|
...
|
gstaichi/_test_tools/__init__.py
CHANGED
@@ -0,0 +1,18 @@
|
|
1
|
+
import gstaichi as ti
|
2
|
+
|
3
|
+
from . import textwrap2
|
4
|
+
|
5
|
+
|
6
|
+
def ti_init_same_arch(**options) -> None:
    """
    Re-run ``ti.init`` in tests while keeping the arch that is currently configured.

    All keyword ``options`` are forwarded to ``ti.init``; any ``arch`` entry among
    them is replaced by the arch of the active config, looked up on ``ti`` by name.

    Requires a prior ``ti.init`` (``ti.cfg`` must not be None).
    """
    current_cfg = ti.cfg
    assert current_cfg is not None
    init_kwargs = {**options, "arch": getattr(ti, current_cfg.arch.name)}
    ti.init(**init_kwargs)
|
16
|
+
|
17
|
+
|
18
|
+
__all__ = ["textwrap2"]
|
@@ -0,0 +1,36 @@
|
|
1
|
+
import dataclasses
|
2
|
+
from typing import Any, cast
|
3
|
+
|
4
|
+
import gstaichi as ti
|
5
|
+
|
6
|
+
|
7
|
+
def _make_child_obj(obj_type: Any) -> Any:
    """
    Create a placeholder object for the kernel-parameter annotation ``obj_type``.

    - An instance of ``ti.types.NDArray`` yields a ``ti.ndarray`` with the annotated
      dtype and every dimension of size 10.
    - A dataclass type is instantiated recursively via ``build_struct``.
    - ``ti.Template`` (the class itself or an instance of it) yields a 1-D
      ``ti.i32`` field of size 10.

    Raises:
        TypeError: if ``obj_type`` matches none of the cases above. ``TypeError``
            subclasses ``Exception``, so existing ``except Exception`` handlers
            still catch it.
    """
    if isinstance(obj_type, ti.types.NDArray):
        ndarray_type = cast(ti.types.ndarray, obj_type)
        assert ndarray_type.ndim is not None
        shape = tuple([10] * ndarray_type.ndim)
        return ti.ndarray(ndarray_type.dtype, shape=shape)
    if dataclasses.is_dataclass(obj_type):
        return build_struct(obj_type)
    if isinstance(obj_type, ti.Template) or obj_type == ti.Template:
        return ti.field(ti.i32, (10,))
    # Format the offending type into the message instead of passing it as a
    # second positional arg (which rendered the message as a tuple).
    raise TypeError(f"unknown type {obj_type!r}")
|
20
|
+
|
21
|
+
|
22
|
+
def build_struct(struct_type: Any) -> Any:
    """
    Instantiate the dataclass ``struct_type`` with a placeholder for every field.

    Each field value comes from ``_make_child_obj(field.type)``, so nested
    dataclass fields are built recursively.
    """
    kwargs = {fld.name: _make_child_obj(fld.type) for fld in dataclasses.fields(struct_type)}
    return struct_type(**kwargs)
|
29
|
+
|
30
|
+
|
31
|
+
def build_obj_tuple_from_type_dict(name_to_type: dict[str, Any]) -> tuple[Any, ...]:
    """
    Build one placeholder object per entry of ``name_to_type``.

    The parameter names are ignored; only the annotated types (in dict order) are
    passed to ``_make_child_obj``. Returns the placeholders as a tuple.
    """
    return tuple(_make_child_obj(param_type) for param_type in name_to_type.values())
|
gstaichi/_version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = '0.
|
1
|
+
__version__ = '2.0.0'
|
gstaichi/examples/minimal.py
CHANGED
@@ -2,7 +2,7 @@ import gstaichi as ti
|
|
2
2
|
|
3
3
|
|
4
4
|
@ti.kernel
|
5
|
-
def lcg_ti(B: int, lcg_its: int, a: ti.types.
|
5
|
+
def lcg_ti(B: int, lcg_its: int, a: ti.types.NDArray[ti.i32, 1]) -> None:
|
6
6
|
"""
|
7
7
|
Linear congruential generator https://en.wikipedia.org/wiki/Linear_congruential_generator
|
8
8
|
"""
|
gstaichi/lang/__init__.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1
1
|
# type: ignore
|
2
2
|
|
3
3
|
from gstaichi.lang import impl, simt
|
4
|
+
from gstaichi.lang._fast_caching.function_hasher import pure
|
4
5
|
from gstaichi.lang._ndarray import *
|
5
6
|
from gstaichi.lang._ndrange import ndrange
|
6
7
|
from gstaichi.lang._texture import Texture
|
7
|
-
from gstaichi.lang.argpack import *
|
8
8
|
from gstaichi.lang.exception import *
|
9
9
|
from gstaichi.lang.field import *
|
10
10
|
from gstaichi.lang.impl import *
|
@@ -0,0 +1,31 @@
|
|
1
|
+
def create_flat_name(basename: str, child_name: str) -> str:
    """
    Join ``basename`` and ``child_name`` with the ``__ti_`` delimiter into one flat name.

    The joined string is prefixed with ``__ti_`` unless it already starts with it,
    so repeatedly flattening nested names never stacks a duplicate leading
    delimiter.

    This is used when expanding Python dataclass members. For example, given::

        @dataclasses.dataclass
        class Foo:
            a: int
            b: int

        foo = Foo(a=5, b=3)

    the argument ``foo`` is replaced by the names ``__ti_foo__ti_a`` and
    ``__ti_foo__ti_b``.

    The ``__ti_`` marker keeps generated names from colliding with user-defined
    ones: users must not create fields or variables whose names start with
    ``__ti_``, and under that rule the generated names cannot clash with theirs.
    """
    delimiter = "__ti_"
    joined = f"{basename}{delimiter}{child_name}"
    if joined.startswith(delimiter):
        return joined
    return delimiter + joined
|
@@ -0,0 +1,110 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import enum
|
3
|
+
import numbers
|
4
|
+
import time
|
5
|
+
from typing import Any, Sequence
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from .._ndarray import ScalarNdarray
|
10
|
+
from ..field import ScalarField
|
11
|
+
from ..matrix import MatrixField, MatrixNdarray, VectorNdarray
|
12
|
+
from ..util import is_data_oriented
|
13
|
+
from .hash_utils import hash_iterable_strings
|
14
|
+
|
15
|
+
g_num_calls = 0
|
16
|
+
g_num_args = 0
|
17
|
+
g_hashing_time = 0
|
18
|
+
g_repr_time = 0
|
19
|
+
g_num_ignored_calls = 0
|
20
|
+
|
21
|
+
|
22
|
+
FIELD_METADATA_CACHE_VALUE = "add_value_to_cache_key"
|
23
|
+
|
24
|
+
|
25
|
+
def dataclass_to_repr(path: tuple[str, ...], arg: Any) -> str:
|
26
|
+
repr_l = []
|
27
|
+
for field in dataclasses.fields(arg):
|
28
|
+
child_value = getattr(arg, field.name)
|
29
|
+
_repr = stringify_obj_type(path + (field.name,), child_value)
|
30
|
+
full_repr = f"{field.name}: ({_repr})"
|
31
|
+
if field.metadata.get(FIELD_METADATA_CACHE_VALUE, False):
|
32
|
+
full_repr += f" = {child_value}"
|
33
|
+
repr_l.append(full_repr)
|
34
|
+
return "[" + ",".join(repr_l) + "]"
|
35
|
+
|
36
|
+
|
37
|
+
def stringify_obj_type(path: tuple[str, ...], obj: Any) -> str | None:
    """
    Convert ``obj`` into a string that represents its type, for use in cache keys.

    The string does not have to be hashed, nor be the actual Python type name; it
    only has to be representative of the type and not collide with other
    (allowed) types. For enum members the name and value are included as well.

    Args:
        path: names leading to ``obj`` from the top-level argument; used during
            debugging.
        obj: the object to describe.

    Returns:
        The representation string, or ``None`` if ``obj`` (or a child of a
        data-oriented object) has a type that cannot be represented.
    """
    # TODO: We should have a way of printing this without having to hack the code really. Using logger perhaps?
    # (I have another PR that addreses this https://github.com/Genesis-Embodied-AI/gstaichi/pull/144/files)
    arg_type = type(obj)
    if isinstance(obj, ScalarNdarray):
        return f"[nd-{obj.dtype}-{len(obj.shape)}]"
    if isinstance(obj, VectorNdarray):
        return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}]"
    if isinstance(obj, ScalarField):
        return f"[f-{obj.snode._id}-{obj.dtype}-{obj.shape}]"
    if isinstance(obj, MatrixNdarray):
        return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}]"
    if "torch.Tensor" in str(arg_type):
        return f"[pt-{obj.dtype}-{obj.ndim}]"
    if isinstance(obj, np.ndarray):
        return f"[np-{obj.dtype}-{obj.ndim}]"
    if isinstance(obj, MatrixField):
        return f"[fm-{obj.m}-{obj.n}-{obj.snode._id}-{obj.dtype}-{obj.shape}]"
    if dataclasses.is_dataclass(obj):
        return dataclass_to_repr(path, obj)
    if is_data_oriented(obj):
        child_repr_l = []
        for k, v in obj.__dict__.items():
            _child_repr = stringify_obj_type((*path, k), v)
            if _child_repr is None:
                print("not representable child", k, type(v), "path", path)
                return None
            child_repr_l.append(f"{k}: {_child_repr}")
        return ", ".join(child_repr_l)
    if isinstance(obj, enum.Enum):
        # Must precede the numeric check: IntEnum/IntFlag members are ints
        # (numbers.Number), and would otherwise be keyed by their type alone,
        # so different members would produce the same representation.
        return f"enum-{obj.name}-{obj.value}"
    if issubclass(arg_type, (numbers.Number, np.number)):
        return str(arg_type)
    if arg_type is np.bool_:
        # np is deprecating bool. Treat specially/carefully
        return "np.bool_"
    return None
|
83
|
+
|
84
|
+
|
85
|
+
def hash_args(args: Sequence[Any]) -> str | None:
    """
    Hash a type-describing representation of the positional arguments ``args``.

    Each argument is converted to a string by ``stringify_obj_type`` (with its index as
    the debug path), and the resulting strings are hashed together with
    ``hash_iterable_strings``.

    Returns:
        The hash string, or None as soon as one argument has no representation. In that
        case the call is also counted in ``g_num_ignored_calls``.

    Side effects:
        Updates the module-level counters and timers that ``dump_stats`` prints.
        Durations are measured with ``time.perf_counter`` (monotonic, high resolution)
        rather than ``time.time``. Wall-clock time can jump, e.g. on clock adjustments,
        which would corrupt the accumulated timings.
    """
    global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls
    g_num_calls += 1
    g_num_args += len(args)
    hash_l = []
    for i_arg, arg in enumerate(args):
        start = time.perf_counter()
        _hash = stringify_obj_type((str(i_arg),), arg)
        g_repr_time += time.perf_counter() - start
        # A falsy result means the argument cannot be represented. That covers None and
        # also an empty string, such as one produced by joining zero child reprs.
        if not _hash:
            g_num_ignored_calls += 1
            return None
        hash_l.append(_hash)
    start = time.perf_counter()
    res = hash_iterable_strings(hash_l)
    g_hashing_time += time.perf_counter() - start
    return res
|
102
|
+
|
103
|
+
|
104
|
+
def dump_stats() -> None:
    """Print the call counters and accumulated timings collected by ``hash_args``."""
    stats = (
        ("total calls", g_num_calls),
        ("ignored calls", g_num_ignored_calls),
        ("total args", g_num_args),
        ("hashing time", g_hashing_time),
        ("arg representation time", g_repr_time),
    )
    print("args hasher dump stats")
    for label, value in stats:
        print(label, value)
|
@@ -0,0 +1,30 @@
|
|
1
|
+
from gstaichi.lang import impl
|
2
|
+
|
3
|
+
from .hash_utils import hash_iterable_strings
|
4
|
+
|
5
|
+
# Compile-config attribute-name prefixes that hash_compile_config() leaves out of the
# config hash. "_" covers private and dunder names.
EXCLUDE_PREFIXES = ["_", "offline_cache", "print_", "verbose_"]
|
6
|
+
|
7
|
+
|
8
|
+
def hash_compile_config() -> str:
    """
    Calculate a hash string for the current compiler config.

    Each attribute of the runtime's compile config that ``dir()`` reports contributes a
    ``name=value`` line, so changing any hashed value changes the hash. Two kinds of
    attribute are skipped:

    * attributes whose name starts with one of ``EXCLUDE_PREFIXES``: private/dunder
      names, plus the ``offline_cache*``, ``print_*`` and ``verbose_*`` settings;
    * callable attributes (methods). Their string form typically embeds a memory
      address, which would make the hash differ between processes.

    Returns:
        Hex digest of the ``name=value`` lines, each terminated by a newline.
    """
    config = impl.get_runtime().prog.config()
    excluded_prefixes = tuple(EXCLUDE_PREFIXES)
    config_l = []
    for k in dir(config):
        if not k or k.startswith(excluded_prefixes):
            continue
        v = getattr(config, k)
        if callable(v):
            # e.g. "<bound method ... object at 0x...>" is not stable across runs.
            continue
        config_l.append(f"{k}={v}")
    return hash_iterable_strings(config_l, separator="\n")
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from .._wrap_inspect import FunctionSourceInfo
|
4
|
+
|
5
|
+
|
6
|
+
class HashedFunctionSourceInfo(BaseModel):
    """
    Pairs a FunctionSourceInfo with the hash string computed for that function.

    Keeping the hash outside FunctionSourceInfo avoids making it an optional
    field that callers would have to check for emptiness.

    If you hold a HashedFunctionSourceInfo, the hash string is guaranteed to be
    present.

    A bare FunctionSourceInfo never carries a hash string.
    """

    # Where the hashed function lives (file path and line range that were read).
    function_source_info: FunctionSourceInfo
    # Hash string of the function's source lines at the time it was recorded.
    hash: str
|
@@ -0,0 +1,57 @@
|
|
1
|
+
import os
|
2
|
+
from itertools import islice
|
3
|
+
from typing import TYPE_CHECKING, Iterable
|
4
|
+
|
5
|
+
from .._wrap_inspect import FunctionSourceInfo
|
6
|
+
from .fast_caching_types import HashedFunctionSourceInfo
|
7
|
+
from .hash_utils import hash_iterable_strings
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from gstaichi.lang.kernel_impl import GsTaichiCallable
|
11
|
+
|
12
|
+
|
13
|
+
def pure(fn: "GsTaichiCallable") -> "GsTaichiCallable":
    """
    Flag ``fn`` as pure by setting its ``is_pure`` attribute to True.

    The very same object is returned (no wrapper), so this can be applied as a
    decorator.
    """
    setattr(fn, "is_pure", True)
    return fn
|
16
|
+
|
17
|
+
|
18
|
+
def _read_file(function_info: FunctionSourceInfo) -> list[str]:
|
19
|
+
with open(function_info.filepath) as f:
|
20
|
+
return list(islice(f, function_info.start_lineno, function_info.end_lineno + 1))
|
21
|
+
|
22
|
+
|
23
|
+
def _hash_function(function_info: FunctionSourceInfo) -> str:
    """Return a single hash string covering the source lines of ``function_info``."""
    source_lines = _read_file(function_info)
    return hash_iterable_strings(source_lines)
|
25
|
+
|
26
|
+
|
27
|
+
def hash_functions(function_infos: Iterable[FunctionSourceInfo]) -> list[HashedFunctionSourceInfo]:
    """
    Pair every FunctionSourceInfo in ``function_infos`` with the hash of its source.

    The output list keeps the input order.
    """
    return [
        HashedFunctionSourceInfo(function_source_info=info, hash=_hash_function(info))
        for info in function_infos
    ]
|
33
|
+
|
34
|
+
|
35
|
+
def hash_kernel(kernel_info: FunctionSourceInfo) -> str:
    """Hash a kernel's source lines; kernels are hashed exactly like any other function."""
    kernel_hash = _hash_function(kernel_info)
    return kernel_hash
|
37
|
+
|
38
|
+
|
39
|
+
def dump_stats() -> None:
    """Print the function-hasher stats header (currently the only thing reported)."""
    banner = "function hasher dump stats"
    print(banner)
|
41
|
+
|
42
|
+
|
43
|
+
def _validate_hashed_function_info(hashed_function_info: HashedFunctionSourceInfo) -> bool:
    """
    Check whether a recorded function hash still matches the source on disk.

    Re-reads the function's lines from ``function_source_info.filepath``, re-hashes
    them, and compares the result with ``hashed_function_info.hash``.

    Returns:
        True if the hashes match. False if they differ, or if the file is missing or
        cannot be read (including when it disappears between the existence check and
        the read). A stale entry is reported as invalid rather than raising.
    """
    source_info = hashed_function_info.function_source_info
    if not os.path.isfile(source_info.filepath):
        return False
    try:
        current_hash = _hash_function(source_info)
    except OSError:
        # Deleted, replaced, or permission-denied after the isfile() check: treat as stale.
        return False
    return current_hash == hashed_function_info.hash
|
51
|
+
|
52
|
+
|
53
|
+
def validate_hashed_function_infos(function_infos: Iterable[HashedFunctionSourceInfo]) -> bool:
    """
    Return True only if every recorded function hash still matches its source.

    Stops at the first entry that fails validation; an empty iterable is valid.
    """
    return all(_validate_hashed_function_info(info) for info in function_infos)
|
@@ -0,0 +1,11 @@
|
|
1
|
+
import hashlib
|
2
|
+
from typing import Iterable
|
3
|
+
|
4
|
+
|
5
|
+
def hash_iterable_strings(strings: Iterable[str], separator: str = "_") -> str:
    """
    Return the SHA-256 hex digest of ``strings``, each followed by ``separator``.

    Strings are UTF-8 encoded. The separator is appended after every element,
    including the last, so ``["a", "b"]`` hashes the bytes ``b"a_b_"``.

    NOTE(review): elements are not length-prefixed. Inputs whose elements contain the
    separator can therefore collide, e.g. ``["a_b"]`` and ``["a", "b"]`` give the same
    digest with the default separator.
    """
    sep_bytes = separator.encode("utf-8")
    digest = hashlib.sha256()
    for text in strings:
        digest.update(text.encode("utf-8") + sep_bytes)
    return digest.hexdigest()
|
@@ -0,0 +1,52 @@
|
|
1
|
+
import os
import tempfile

from .. import impl
|
4
|
+
|
5
|
+
|
6
|
+
class PythonSideCache:
|
7
|
+
"""
|
8
|
+
Manages a cache that is managed from the python side (we also have c++-side caches)
|
9
|
+
|
10
|
+
The cache is disk-based. When we create the PythonSideCache object, the cache
|
11
|
+
path is created as a sub-folder of CompileConfig.offline_cache_file_path.
|
12
|
+
|
13
|
+
Note that constructing this object is cheap, so there is no need to maintain some
|
14
|
+
kind of conceptual singleton instance or similar.
|
15
|
+
|
16
|
+
Each cache key value is stored to a single file, with the cache key as the filename.
|
17
|
+
|
18
|
+
No metadata is associated with the file, making management very lightweight.
|
19
|
+
|
20
|
+
We update the file date/time when we read from a particular file, so we can easily
|
21
|
+
implement an LRU cleaning strategy at some point in the future, based on the file
|
22
|
+
date/times.
|
23
|
+
"""
|
24
|
+
|
25
|
+
def __init__(self) -> None:
|
26
|
+
_cache_parent_folder = impl.get_runtime().prog.config().offline_cache_file_path
|
27
|
+
self.cache_folder = os.path.join(_cache_parent_folder, "python_side_cache")
|
28
|
+
os.makedirs(self.cache_folder, exist_ok=True)
|
29
|
+
|
30
|
+
def _get_filepath(self, key: str) -> str:
|
31
|
+
filepath = os.path.join(self.cache_folder, f"{key}.cache.txt")
|
32
|
+
return filepath
|
33
|
+
|
34
|
+
def _touch(self, filepath):
|
35
|
+
"""
|
36
|
+
Updates file date/time.
|
37
|
+
"""
|
38
|
+
with open(filepath, "a"):
|
39
|
+
os.utime(filepath, None)
|
40
|
+
|
41
|
+
def store(self, key: str, value: str) -> None:
|
42
|
+
filepath = self._get_filepath(key)
|
43
|
+
with open(filepath, "w") as f:
|
44
|
+
f.write(value)
|
45
|
+
|
46
|
+
def try_load(self, key: str) -> str | None:
|
47
|
+
filepath = self._get_filepath(key)
|
48
|
+
if not os.path.isfile(filepath):
|
49
|
+
return None
|
50
|
+
self._touch(filepath)
|
51
|
+
with open(filepath) as f:
|
52
|
+
return f.read()
|