neuralogic 0.8.1.dev0.tar.gz → 0.8.3.dev0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/PKG-INFO +12 -2
  2. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/README.md +11 -1
  3. neuralogic-0.8.3.dev0/neuralogic/__version__.py +1 -0
  4. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/builder/builder.py +23 -18
  5. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/builder/components.py +18 -15
  6. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/builder/dataset.py +4 -18
  7. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/builder/dataset_builder.py +30 -13
  8. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/java_objects.py +6 -7
  9. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/neural_module.py +44 -90
  10. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/settings/settings_proxy.py +2 -9
  11. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/template.py +10 -9
  12. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/torch/neural_module.py +2 -6
  13. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/torch/tensor.py +2 -4
  14. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/dataset/logic.py +13 -11
  15. neuralogic-0.8.3.dev0/neuralogic/jar/NeuraLogic.jar +0 -0
  16. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic.egg-info/PKG-INFO +12 -2
  17. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/pyproject.toml +1 -1
  18. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_function.py +4 -5
  19. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_general_modules.py +2 -1
  20. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_inference_engine.py +2 -2
  21. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_java_evaluation.py +4 -44
  22. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_quick_start.py +2 -2
  23. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_recurrent_modules.py +7 -7
  24. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_special_predicates.py +9 -19
  25. neuralogic-0.8.1.dev0/neuralogic/__version__.py +0 -1
  26. neuralogic-0.8.1.dev0/neuralogic/jar/NeuraLogic.jar +0 -0
  27. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/LICENSE +0 -0
  28. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/__init__.py +0 -0
  29. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/__init__.py +0 -0
  30. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/builder/__init__.py +0 -0
  31. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/__init__.py +0 -0
  32. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/factories.py +0 -0
  33. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/__init__.py +0 -0
  34. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/concat.py +0 -0
  35. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/enum.py +0 -0
  36. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/function.py +0 -0
  37. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/function_container.py +0 -0
  38. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/function_graph.py +0 -0
  39. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/reshape.py +0 -0
  40. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/slice.py +0 -0
  41. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/function/softmax.py +0 -0
  42. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/metadata.py +0 -0
  43. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/predicate.py +0 -0
  44. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/relation.py +0 -0
  45. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/rule.py +0 -0
  46. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/constructs/term.py +0 -0
  47. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/enums.py +0 -0
  48. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/settings/__init__.py +0 -0
  49. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/sources.py +0 -0
  50. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/torch/__init__.py +0 -0
  51. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/core/torch/network_output.py +0 -0
  52. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/dataset/__init__.py +0 -0
  53. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/dataset/base.py +0 -0
  54. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/dataset/csv.py +0 -0
  55. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/dataset/db.py +0 -0
  56. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/dataset/file.py +0 -0
  57. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/dataset/tensor.py +0 -0
  58. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/experimental/__init__.py +0 -0
  59. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/experimental/db/__init__.py +0 -0
  60. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/experimental/db/converter.py +0 -0
  61. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/experimental/db/pg/__init__.py +0 -0
  62. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/experimental/db/pg/helpers.py +0 -0
  63. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/jar/__init__.py +0 -0
  64. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/logging/__init__.py +0 -0
  65. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/__init__.py +0 -0
  66. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/init.py +0 -0
  67. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/java.py +0 -0
  68. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/loss.py +0 -0
  69. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/__init__.py +0 -0
  70. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/__init__.py +0 -0
  71. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/attention.py +0 -0
  72. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/gru.py +0 -0
  73. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/linear.py +0 -0
  74. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/lstm.py +0 -0
  75. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/mlp.py +0 -0
  76. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/pooling.py +0 -0
  77. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/positional_encoding.py +0 -0
  78. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/rnn.py +0 -0
  79. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/rvnn.py +0 -0
  80. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/general/transformer.py +0 -0
  81. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/__init__.py +0 -0
  82. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/appnp.py +0 -0
  83. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/gatv2.py +0 -0
  84. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/gcn.py +0 -0
  85. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/gen.py +0 -0
  86. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/gin.py +0 -0
  87. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/gine.py +0 -0
  88. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/gsage.py +0 -0
  89. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/res_gated.py +0 -0
  90. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/rgcn.py +0 -0
  91. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/sg.py +0 -0
  92. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/gnn/tag.py +0 -0
  93. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/meta/__init__.py +0 -0
  94. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/meta/magnn.py +0 -0
  95. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/meta/meta.py +0 -0
  96. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/module/module.py +0 -0
  97. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/torch_function.py +0 -0
  98. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/nn/trainer.py +0 -0
  99. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/__init__.py +0 -0
  100. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/adam.py +0 -0
  101. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/lr_scheduler/__init__.py +0 -0
  102. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/lr_scheduler/arithmetic.py +0 -0
  103. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/lr_scheduler/geometric.py +0 -0
  104. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/lr_scheduler/lr_decay.py +0 -0
  105. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/optimizer.py +0 -0
  106. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/optim/sgd.py +0 -0
  107. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/setup.py +0 -0
  108. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/__init__.py +0 -0
  109. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/__init__.py +0 -0
  110. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/atomEmbeddings3.txt +0 -0
  111. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/bondEmbeddings3.txt +0 -0
  112. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/examples.txt +0 -0
  113. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/queries.txt +0 -0
  114. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/embeddings.txt +0 -0
  115. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template.txt +0 -0
  116. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template_crosssum.txt +0 -0
  117. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template_elementProduct.txt +0 -0
  118. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template_gnn.txt +0 -0
  119. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template_graphlets.txt +0 -0
  120. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template_partial.txt +0 -0
  121. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/molecules/mutagenesis/templates/template_product.txt +0 -0
  122. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/nations/embeddings.txt +0 -0
  123. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/nations/examples.txt +0 -0
  124. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/nations/queries.txt +0 -0
  125. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/nations/template.txt +0 -0
  126. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/family/examples.txt +0 -0
  127. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/family/queries.txt +0 -0
  128. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/family/template.txt +0 -0
  129. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/trains/examples.txt +0 -0
  130. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/trains/queries.txt +0 -0
  131. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/trains/template.txt +0 -0
  132. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/generalized/examples.txt +0 -0
  133. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/generalized/template.txt +0 -0
  134. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/generalized/texamples.txt +0 -0
  135. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/generalized/ztexamples.txt +0 -0
  136. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/naive/template.txt +0 -0
  137. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/naive/trainExamples.txt +0 -0
  138. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/solution/template.txt +0 -0
  139. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/solution/testExamples.txt +0 -0
  140. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/vectorized/template.txt +0 -0
  141. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/data/datasets/simple/xor/vectorized/trainExamples.txt +0 -0
  142. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic/utils/visualize/__init__.py +0 -0
  143. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic.egg-info/SOURCES.txt +0 -0
  144. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic.egg-info/dependency_links.txt +0 -0
  145. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic.egg-info/requires.txt +0 -0
  146. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/neuralogic.egg-info/top_level.txt +0 -0
  147. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/setup.cfg +0 -0
  148. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/setup.py +0 -0
  149. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_constructs.py +0 -0
  150. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_csv_datasets.py +0 -0
  151. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_datasets.py +0 -0
  152. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_drawing.py +0 -0
  153. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_gnn_modules.py +0 -0
  154. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_lr_decay.py +0 -0
  155. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_modules.py +0 -0
  156. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_settings.py +0 -0
  157. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_torch_function.py +0 -0
  158. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_transformer.py +0 -0
  159. {neuralogic-0.8.1.dev0 → neuralogic-0.8.3.dev0}/tests/test_xor_generalization.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: neuralogic
3
- Version: 0.8.1.dev0
3
+ Version: 0.8.3.dev0
4
4
  Summary: PyNeuraLogic lets you use Python to create Differentiable Logic Programs.
5
5
  Home-page: https://github.com/LukasZahradnik/PyNeuraLogic
6
6
  Author: Lukáš Zahradník
@@ -155,7 +155,8 @@ $ pip install neuralogic
155
155
  ```
156
156
 
157
157
 
158
- ### Prerequisites
158
+ <details>
159
+ <summary>Prerequisites</summary>
159
160
 
160
161
  To use PyNeuraLogic, you need to install the following prerequisites:
161
162
 
@@ -168,6 +169,9 @@ Java >= 1.8
168
169
  >
169
170
  > In case you want to use visualization provided in the library, it is required to have [Graphviz](https://graphviz.org/download/) installed.
170
171
 
172
+ </details>
173
+
174
+
171
175
  <br />
172
176
 
173
177
  ## 📦 Predefined Modules
@@ -195,6 +199,12 @@ It contains, for example, predefined modules for:
195
199
  <br />
196
200
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LukasZahradnik/PyNeuraLogic/blob/master/examples/DistinguishingNonRegularGraphs.ipynb) [Distinguishing non-regular graphs](https://github.com/LukasZahradnik/PyNeuraLogic/blob/master/examples/DistinguishingNonRegularGraphs.ipynb)
197
201
 
202
+ ## 🤝 Community & Support
203
+
204
+ - [GitHub Issues](https://github.com/LukasZahradnik/PyNeuraLogic/issues) for bug reports and feature requests
205
+ - [Discussions](https://github.com/LukasZahradnik/PyNeuraLogic/discussions) for questions and ideas
206
+ - [Contributing Guide](CONTRIBUTING.md) if you want to help improve PyNeuraLogic
207
+
198
208
  ## 📝 Papers
199
209
 
200
210
  - [Beyond Graph Neural Networks with Lifted Relational Neural Networks](https://arxiv.org/abs/2007.06286) Machine Learning Journal, 2021
@@ -115,7 +115,8 @@ $ pip install neuralogic
115
115
  ```
116
116
 
117
117
 
118
- ### Prerequisites
118
+ <details>
119
+ <summary>Prerequisites</summary>
119
120
 
120
121
  To use PyNeuraLogic, you need to install the following prerequisites:
121
122
 
@@ -128,6 +129,9 @@ Java >= 1.8
128
129
  >
129
130
  > In case you want to use visualization provided in the library, it is required to have [Graphviz](https://graphviz.org/download/) installed.
130
131
 
132
+ </details>
133
+
134
+
131
135
  <br />
132
136
 
133
137
  ## 📦 Predefined Modules
@@ -155,6 +159,12 @@ It contains, for example, predefined modules for:
155
159
  <br />
156
160
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LukasZahradnik/PyNeuraLogic/blob/master/examples/DistinguishingNonRegularGraphs.ipynb) [Distinguishing non-regular graphs](https://github.com/LukasZahradnik/PyNeuraLogic/blob/master/examples/DistinguishingNonRegularGraphs.ipynb)
157
161
 
162
+ ## 🤝 Community & Support
163
+
164
+ - [GitHub Issues](https://github.com/LukasZahradnik/PyNeuraLogic/issues) for bug reports and feature requests
165
+ - [Discussions](https://github.com/LukasZahradnik/PyNeuraLogic/discussions) for questions and ideas
166
+ - [Contributing Guide](CONTRIBUTING.md) if you want to help improve PyNeuraLogic
167
+
158
168
  ## 📝 Papers
159
169
 
160
170
  - [Beyond Graph Neural Networks with Lifted Relational Neural Networks](https://arxiv.org/abs/2007.06286) Machine Learning Journal, 2021
@@ -0,0 +1 @@
1
+ __version__ = "0.8.3.dev"
@@ -1,5 +1,3 @@
1
- from typing import List, Optional
2
-
3
1
  import jpype
4
2
  from tqdm.autonotebook import tqdm
5
3
 
@@ -9,7 +7,7 @@ from neuralogic.core.settings import SettingsProxy
9
7
  from neuralogic.core.sources import Sources
10
8
 
11
9
 
12
- def stream_to_list(stream) -> List:
10
+ def stream_to_list(stream) -> list:
13
11
  return list(stream.collect(jpype.JClass("java.util.stream.Collectors").toList()))
14
12
 
15
13
 
@@ -52,37 +50,44 @@ class Builder:
52
50
 
53
51
  return template
54
52
 
55
- def ground_from_sources(self, parsed_template, sources: Sources):
56
- return self._ground(parsed_template, sources, None)
53
+ def ground_from_sources(self, parsed_template, sources: Sources, progress: bool):
54
+ if not progress:
55
+ return self._ground(parsed_template, sources, None, None)
56
+ with tqdm(total=None, desc="Grounding", unit=" samples", dynamic_ncols=True) as pbar:
57
+ return self._ground(parsed_template, sources, None, self._callback(pbar))
57
58
 
58
- def ground_from_logic_samples(self, parsed_template, logic_samples):
59
- return self._ground(parsed_template, None, logic_samples)
59
+ def ground_from_logic_samples(self, parsed_template, logic_samples, progress: bool):
60
+ if not progress:
61
+ return self._ground(parsed_template, None, logic_samples, None)
62
+ with tqdm(total=len(logic_samples), desc="Grounding", unit=" samples", dynamic_ncols=True) as pbar:
63
+ return self._ground(parsed_template, None, logic_samples, self._callback(pbar))
60
64
 
61
- def _ground(self, parsed_template, sources: Optional[Sources], logic_samples) -> List[NeuralSample]:
65
+ def _ground(self, parsed_template, sources: Sources | None, logic_samples, callback):
62
66
  if sources is not None:
63
- ground_pipeline = self.example_builder.buildGroundings(parsed_template, sources.sources)
67
+ ground_pipeline = self.example_builder.buildGroundings(parsed_template, sources.sources, callback)
64
68
  else:
65
69
  logic_samples = jpype.java.util.ArrayList(logic_samples).stream()
66
- ground_pipeline = self.example_builder.buildGroundings(parsed_template, logic_samples)
70
+ ground_pipeline = self.example_builder.buildGroundings(parsed_template, logic_samples, callback)
67
71
 
68
72
  ground_pipeline.execute(None if sources is None else sources.sources)
73
+ groundings = ground_pipeline.get()
69
74
 
70
- return ground_pipeline.get()
75
+ if callback is not None:
76
+ return groundings.collect(self.collectors.toList())
77
+ return groundings
71
78
 
72
- def neuralize(self, groundings, progress: bool, length: Optional[int]) -> List[NeuralSample]:
79
+ def neuralize(self, groundings, progress: bool, length: int | None) -> list[NeuralSample]:
73
80
  if not progress:
74
81
  return self._neuralize(groundings, None)
75
82
  with tqdm(total=length, desc="Building", unit=" samples", dynamic_ncols=True) as pbar:
76
83
  return self._neuralize(groundings, self._callback(pbar))
77
84
 
78
- def _neuralize(self, groundings, callback) -> List[NeuralSample]:
79
- neuralize_pipeline = self.example_builder.neuralize(groundings.stream(), callback)
85
+ def _neuralize(self, groundings, callback) -> list[NeuralSample]:
86
+ neuralize_pipeline = self.example_builder.neuralize(groundings, callback)
80
87
  neuralize_pipeline.execute(None)
88
+ logic_samples = neuralize_pipeline.get().collect(self.collectors.toList())
81
89
 
82
- samples = neuralize_pipeline.get()
83
- logic_samples = samples.collect(self.collectors.toList())
84
-
85
- return [NeuralSample(sample, grounding) for sample, grounding in zip(logic_samples, groundings)]
90
+ return [NeuralSample(sample) for sample in logic_samples]
86
91
 
87
92
  def build_model(self, parsed_template, settings: SettingsProxy):
88
93
  neural_model = self.neural_model(parsed_template.getAllWeights(), settings.settings)
@@ -1,7 +1,6 @@
1
1
  import enum
2
2
  from typing import Any, Dict, Optional
3
3
 
4
- from neuralogic.core.settings.settings_proxy import SettingsProxy
5
4
  from neuralogic.core.constructs.java_objects import ValueFactory
6
5
  from neuralogic.utils.visualize import draw_sample, draw_grounding
7
6
 
@@ -32,17 +31,9 @@ class Atom:
32
31
  return self._predicate
33
32
 
34
33
  @property
35
- def arityy(self):
34
+ def arity(self):
36
35
  return self._arity
37
36
 
38
- @property
39
- def value(self):
40
- return ValueFactory.from_java(self._atom.getRawState().getValue(), SettingsProxy.number_format())
41
-
42
- @property
43
- def gradient(self):
44
- return ValueFactory.from_java(self._atom.getRawState().getGradient(), SettingsProxy.number_format())
45
-
46
37
  def node_type(self) -> NeuronType:
47
38
  return NeuronType(self._atom.getClass().getSimpleName())
48
39
 
@@ -51,14 +42,26 @@ class Atom:
51
42
 
52
43
 
53
44
  class Neuron(Atom):
54
- pass
45
+ def __init__(self, neuron, substitutions: Dict):
46
+ self.substitutions = substitutions
47
+ self._atom = neuron
48
+
49
+ self._predicate = neuron.getName()
50
+ self._arity = len(substitutions)
51
+
52
+ @property
53
+ def value(self):
54
+ return ValueFactory.from_java(self._atom.getRawState().getValue())
55
+
56
+ @property
57
+ def gradient(self):
58
+ return ValueFactory.from_java(self._atom.getRawState().getGradient())
55
59
 
56
60
 
57
61
  class NeuralSample:
58
- __slots__ = "_java_sample", "grounding", "_neurons"
62
+ __slots__ = "_java_sample", "_neurons"
59
63
 
60
- def __init__(self, sample, grounding):
61
- self.grounding = Grounding(grounding)
64
+ def __init__(self, sample):
62
65
  self._java_sample = sample
63
66
  self._neurons = None
64
67
 
@@ -70,7 +73,7 @@ class NeuralSample:
70
73
 
71
74
  @property
72
75
  def target(self):
73
- return ValueFactory.from_java(self._java_sample.target, SettingsProxy.number_format())
76
+ return ValueFactory.from_java(self._java_sample.target)
74
77
 
75
78
  def get_neurons(self, literal, neuron_type: NeuronType | None = NeuronType.Atom):
76
79
  literal_name = literal.predicate.name
@@ -29,32 +29,18 @@ class GroundedDataset:
29
29
  __slots__ = "_groundings", "_groundings_list", "_builder"
30
30
 
31
31
  def __init__(self, groundings, builder: Builder):
32
- self._groundings = groundings
33
- self._groundings_list = None
34
32
  self._builder = builder
35
-
36
- def _to_list(self):
37
- if self._groundings_list is None:
38
- self._groundings = self._groundings.collect(jpype.JClass("java.util.stream.Collectors").toList())
39
- self._groundings_list = [Grounding(g) for g in self._groundings]
33
+ self._groundings = groundings
34
+ self._groundings_list = [Grounding(g) for g in self._groundings]
40
35
 
41
36
  def __getitem__(self, item) -> Grounding:
42
- self._to_list()
43
- if self._groundings_list is None:
44
- raise ValueError
45
-
46
37
  return self._groundings_list[item]
47
38
 
48
39
  def __len__(self) -> int:
49
- self._to_list()
50
- if self._groundings_list is None:
51
- return 0
52
40
  return len(self._groundings_list)
53
41
 
54
42
  def __iter__(self):
55
- self._to_list()
56
43
  return iter(self._groundings_list)
57
44
 
58
- def neuralize(self, *, progress: bool = False) -> list[NeuralSample]:
59
- self._to_list()
60
- return self._builder.neuralize(self._groundings, progress, len(self))
45
+ def neuralize(self, *, batch_size: int = 1, progress: bool = False) -> BuiltDataset:
46
+ return BuiltDataset(self._builder.neuralize(self._groundings.stream(), progress, len(self)), batch_size)
@@ -1,4 +1,4 @@
1
- from typing import Union, Set, Dict, List
1
+ from typing import Union, List
2
2
 
3
3
  import jpype
4
4
 
@@ -32,8 +32,6 @@ class DatasetBuilder:
32
32
  self.query_counter = 0
33
33
  self.examples_counter = 0
34
34
 
35
- self.hooks: Dict[str, Set] = {}
36
-
37
35
  def build_queries(self, queries, query_builder):
38
36
  logic_samples = []
39
37
  one_query_per_example = True
@@ -114,13 +112,17 @@ class DatasetBuilder:
114
112
  *,
115
113
  batch_size: int = 1,
116
114
  learnable_facts: bool = False,
117
- ) -> GroundedDataset:
115
+ progress: bool = False,
116
+ raw_groundings: bool = False,
117
+ ):
118
118
  """Grounds the dataset
119
119
 
120
120
  :param dataset:
121
121
  :param settings:
122
122
  :param batch_size:
123
123
  :param learnable_facts:
124
+ :param progress:
125
+ :param raw_groundings:
124
126
  :return:
125
127
  """
126
128
  if isinstance(dataset, datasets.ConvertibleDataset):
@@ -129,6 +131,8 @@ class DatasetBuilder:
129
131
  settings,
130
132
  batch_size=batch_size,
131
133
  learnable_facts=learnable_facts,
134
+ progress=progress,
135
+ raw_groundings=raw_groundings,
132
136
  )
133
137
 
134
138
  if batch_size > 1:
@@ -172,7 +176,7 @@ class DatasetBuilder:
172
176
  queries, examples, one_query_per_example, example_queries
173
177
  )
174
178
 
175
- groundings = builder.ground_from_logic_samples(self.parsed_template, logic_samples)
179
+ groundings = builder.ground_from_logic_samples(self.parsed_template, logic_samples, progress)
176
180
 
177
181
  self.java_factory.weight_factory = weight_factory
178
182
  elif isinstance(dataset, datasets.FileDataset):
@@ -186,11 +190,16 @@ class DatasetBuilder:
186
190
  args.extend(["-e", dataset.examples_file])
187
191
  sources = Sources.from_args(args, settings)
188
192
 
189
- groundings = builder.ground_from_sources(self.parsed_template, sources)
193
+ groundings = builder.ground_from_sources(self.parsed_template, sources, progress)
190
194
  else:
191
195
  raise NotImplementedError
192
196
 
193
- return GroundedDataset(groundings, builder)
197
+ if raw_groundings:
198
+ return groundings
199
+ if progress:
200
+ return GroundedDataset(groundings, builder)
201
+
202
+ return GroundedDataset(groundings.collect(builder.collectors.toList()), builder)
194
203
 
195
204
  def build_dataset(
196
205
  self,
@@ -211,12 +220,13 @@ class DatasetBuilder:
211
220
  :return:
212
221
  """
213
222
  if not isinstance(dataset, GroundedDataset):
214
- grounded_dataset = self.ground_dataset(
215
- dataset, settings, batch_size=batch_size, learnable_facts=learnable_facts
223
+ groundings = self.ground_dataset(
224
+ dataset, settings, batch_size=batch_size, learnable_facts=learnable_facts, raw_groundings=True,
216
225
  )
217
- else:
218
- grounded_dataset = dataset
219
- return BuiltDataset(grounded_dataset.neuralize(progress=progress), batch_size)
226
+
227
+ samples = Builder(settings).neuralize(groundings, progress, None)
228
+ return BuiltDataset(samples, batch_size)
229
+ return dataset.neuralize(batch_size=batch_size, progress=progress)
220
230
 
221
231
  @staticmethod
222
232
  def merge_queries_with_examples(queries, examples, one_query_per_example, example_queries=True):
@@ -282,8 +292,15 @@ class DatasetBuilder:
282
292
  idx = id(sample.example)
283
293
 
284
294
  if idx not in example_dict:
285
- queries_dict[idx] = [sample.query]
295
+ if isinstance(sample.query, list):
296
+ queries_dict[idx] = [*sample.query]
297
+ else:
298
+ queries_dict[idx] = [sample.query]
286
299
  example_dict[idx] = sample.example
300
+ continue
301
+
302
+ if isinstance(sample.query, list):
303
+ queries_dict[idx].extend(sample.query)
287
304
  else:
288
305
  queries_dict[idx].append(sample.query)
289
306
  return example_dict.values(), queries_dict.values()
@@ -21,14 +21,13 @@ class ValueFactory:
21
21
  self.matrix_value = jpype.JClass("cz.cvut.fel.ida.algebra.values.MatrixValue")
22
22
 
23
23
  @staticmethod
24
- def from_java(value, number_format):
24
+ def from_java(value):
25
25
  size = list(value.size())
26
-
27
26
  if len(size) == 0 or size[0] == 0:
28
27
  return float(value.get(0))
29
- elif len(size) == 1 or size[0] == 1 or size[1] == 1:
30
- return list(float(x) for x in value.values)
31
- return json.loads(str(value.toString(number_format)))
28
+ if len(size) == 2 and size[1] != 1:
29
+ return np.array(memoryview(value.getAsArray())).reshape(size).tolist()
30
+ return np.array(memoryview(value.getAsArray())).tolist()
32
31
 
33
32
  def get_value(self, weight):
34
33
  if isinstance(weight, (float, int)) or np.ndim(weight) == 0:
@@ -48,7 +47,7 @@ class ValueFactory:
48
47
  else:
49
48
  value = self.matrix_value(weight[0], weight[1])
50
49
  else:
51
- raise NotImplementedError
50
+ raise NotImplementedError(f"dimensions of size {len(weight)} are not supported. If you wanted to provide tensor as weight, wrap it into a list first")
52
51
  return False, value
53
52
 
54
53
  if isinstance(weight, (Sequence, np.ndarray, Iterable)):
@@ -373,7 +372,7 @@ class JavaFactory:
373
372
  body_relation.append(self.get_relation(relation, variable_factory))
374
373
 
375
374
  for index in all_diff_index:
376
- terms = {term for term in body_relation[index].terms}
375
+ terms = {term for term in body_relation[index].terms if term is not Ellipsis}
377
376
  terms.update(term for term in all_variables)
378
377
 
379
378
  body_relation[index] = self.get_relation(R.special.alldiff(terms), variable_factory)
@@ -1,11 +1,9 @@
1
- from typing import Union, Callable, Dict, Any, Set, Collection, List, Tuple
2
- import json
1
+ from typing import Collection
3
2
 
4
3
  import jpype
5
4
 
6
5
  from neuralogic.setup import is_initialized, initialize
7
6
  from neuralogic.core.constructs.java_objects import ValueFactory
8
- from neuralogic.core.constructs.relation import BaseRelation
9
7
  from neuralogic.core.builder import DatasetBuilder
10
8
  from neuralogic.core.builder.dataset import BuiltDataset, GroundedDataset
11
9
  from neuralogic.core.settings.settings_proxy import SettingsProxy
@@ -15,7 +13,7 @@ from neuralogic.dataset.base import BaseDataset
15
13
  from neuralogic.utils.visualize import draw_model
16
14
 
17
15
 
18
- Value = Union[List, float]
16
+ Value = list | float
19
17
 
20
18
 
21
19
  class NeuralModule:
@@ -26,22 +24,6 @@ class NeuralModule:
26
24
  self._need_sync = False
27
25
  self._value_factory = ValueFactory()
28
26
 
29
- @jpype.JImplements(
30
- jpype.JClass("cz.cvut.fel.ida.neural.networks.computation.iteration.actions.PythonHookHandler")
31
- )
32
- class HookHandler:
33
- def __init__(self, module: "NeuralModule"):
34
- self.module = module
35
-
36
- @jpype.JOverride
37
- def handleHook(self, hook, value):
38
- self.module._run_hook(hook, json.loads(value))
39
-
40
- self._hooks: Dict[str, Set[Callable]] = {}
41
- self._hooks_set = False
42
-
43
- self._hook_handler = HookHandler(self)
44
-
45
27
  self._parsed_template = None
46
28
  self._dataset_builder: DatasetBuilder | None = None
47
29
  self._settings: SettingsProxy | None = None
@@ -56,10 +38,7 @@ class NeuralModule:
56
38
 
57
39
  self._weight_updater = None
58
40
  self._tensor_parameters = None
59
-
60
- from neuralogic.core.torch.neural_module import TorchNeuralModule
61
-
62
- self._torch_module = TorchNeuralModule()
41
+ self._torch_module = None
63
42
 
64
43
  def ground(
65
44
  self,
@@ -67,6 +46,7 @@ class NeuralModule:
67
46
  *,
68
47
  batch_size: int = 1,
69
48
  learnable_facts: bool = False,
49
+ progress: bool = False,
70
50
  ) -> GroundedDataset:
71
51
  if self._dataset_builder is None or self._settings is None:
72
52
  raise ValueError("template is not built")
@@ -76,11 +56,12 @@ class NeuralModule:
76
56
  self._settings,
77
57
  batch_size=batch_size,
78
58
  learnable_facts=learnable_facts,
59
+ progress=progress,
79
60
  )
80
61
 
81
62
  def build_dataset(
82
63
  self,
83
- dataset: Union[BaseDataset, GroundedDataset],
64
+ dataset: BaseDataset | GroundedDataset,
84
65
  *,
85
66
  batch_size: int = 1,
86
67
  learnable_facts: bool = False,
@@ -98,12 +79,8 @@ class NeuralModule:
98
79
  )
99
80
 
100
81
  def __call__(self, dataset=None):
101
- self._set_hooks()
102
82
  samples, _ = self._dataset_to_samples(dataset)
103
- sample_collection = samples
104
-
105
- if not isinstance(samples, Collection):
106
- sample_collection = [samples]
83
+ sample_collection = samples if isinstance(samples, Collection) else [samples]
107
84
 
108
85
  for sample in sample_collection:
109
86
  self._trainer.invalidateSample(self._invalidation, sample._java_sample)
@@ -111,7 +88,6 @@ class NeuralModule:
111
88
  results = [
112
89
  self._value_factory.from_java(
113
90
  self._trainer.evaluateSample(self._evaluation, sample._java_sample).getOutput(),
114
- SettingsProxy.number_format(),
115
91
  )
116
92
  for sample in sample_collection
117
93
  ]
@@ -124,47 +100,55 @@ class NeuralModule:
124
100
  def forward(self, dataset):
125
101
  return self(dataset)
126
102
 
127
- def train(self, dataset, epochs: int = 1) -> Tuple[Value, int]:
128
- self._set_hooks()
103
+ def train(self, dataset, epochs: int = 1) -> Value:
129
104
  samples, batch_size = self._dataset_to_samples(dataset)
130
105
 
131
106
  if not isinstance(samples, Collection):
132
107
  result = self._strategy.learnSample(samples._java_sample)
133
- res = json.loads(str(result)), 1
108
+ res = (
109
+ ValueFactory.from_java(result.getTarget()),
110
+ ValueFactory.from_java(result.getOutput()),
111
+ ValueFactory.from_java(result.errorValue()),
112
+ )
134
113
  else:
135
114
  sample_array = jpype.java.util.ArrayList([sample._java_sample for sample in samples])
136
115
  results = self._strategy.learnSamples(sample_array, epochs, batch_size)
137
- res = json.loads(str(results)), len(samples)
116
+ res = [
117
+ (
118
+ ValueFactory.from_java(result.getTarget()),
119
+ ValueFactory.from_java(result.getOutput()),
120
+ ValueFactory.from_java(result.errorValue()),
121
+ ) for result in results
122
+ ]
138
123
 
139
124
  self._update_tensor_parameters()
140
125
  return res
141
126
 
142
127
  def test(self, dataset) -> Value:
143
- self._set_hooks()
144
128
  samples, batch_size = self._dataset_to_samples(dataset)
145
129
 
146
130
  if not isinstance(samples, Collection):
147
- return json.loads(str(self._strategy.evaluateSample(samples._java_sample)))
131
+ return ValueFactory.from_java(self._strategy.evaluateSample(samples._java_sample))
148
132
 
149
133
  sample_array = jpype.java.util.ArrayList([sample._java_sample for sample in samples])
150
-
151
134
  results = self._strategy.evaluateSamples(sample_array, batch_size)
152
- return json.loads(str(results))
135
+
136
+ return [ValueFactory.from_java(result) for result in results]
153
137
 
154
138
  def reset_parameters(self):
155
139
  self._strategy.resetParameters()
156
140
 
157
- def parameters(self) -> Dict:
141
+ def parameters(self) -> dict:
158
142
  return self.state_dict()
159
143
 
160
- def state_dict(self) -> Dict:
144
+ def state_dict(self) -> dict:
161
145
  weights = self._neural_model.getAllWeights()
162
146
  weights_dict = {}
163
147
  weight_names = {}
164
148
 
165
149
  for weight in weights:
166
150
  if weight.isLearnable:
167
- weights_dict[weight.index] = ValueFactory.from_java(weight.value, SettingsProxy.number_format())
151
+ weights_dict[weight.index] = ValueFactory.from_java(weight.value)
168
152
  weight_names[weight.index] = weight.name
169
153
  return {
170
154
  "weights": weights_dict,
@@ -188,7 +172,7 @@ class NeuralModule:
188
172
  if self._torch_module is not None:
189
173
  self._torch_module.update_tensor_parameters(self._tensor_parameters)
190
174
 
191
- def load_state_dict(self, state_dict: Dict):
175
+ def load_state_dict(self, state_dict: dict):
192
176
  self._sync_template(state_dict, self._neural_model.getAllWeights())
193
177
 
194
178
  if self._torch_module is not None:
@@ -204,51 +188,25 @@ class NeuralModule:
204
188
  *args,
205
189
  **kwargs,
206
190
  ):
191
+ if self._dataset_builder is None or self._settings is None:
192
+ raise ValueError("template is not built")
207
193
  return draw_model(self, filename, show, img_type, value_detail, graphviz_path, *args, **kwargs)
208
194
 
209
- def set_hooks(self, hooks):
210
- self._hooks_set = len(hooks) != 0
211
- self._hooks = hooks
212
-
213
- def add_hook(self, relation: Union[BaseRelation, str], callback: Callable[[Any], None]):
214
- """Hooks the callable to be called with the relation's value as an argument when the value of
215
- the relation is being calculated.
216
-
217
- :param relation:
218
- :param callback:
219
- :return:
220
- """
221
- name = str(relation)
222
-
223
- if isinstance(relation, BaseRelation):
224
- name = name[:-1]
225
-
226
- if name not in self._hooks:
227
- self._hooks[name] = {callback}
228
- else:
229
- self._hooks[name].add(callback)
230
-
231
- def remove_hook(self, relation: Union[BaseRelation, str], callback):
232
- """Removes the callable from the relation's hooks
233
-
234
- :param relation:
235
- :param callback:
236
- :return:
237
- """
238
- name = str(relation)
239
-
240
- if isinstance(relation, BaseRelation):
241
- name = name[:-1]
242
-
243
- if name not in self._hooks:
244
- return
245
- self._hooks[name].discard(callback)
246
-
247
- def _initialize_neural_module(self, dataset_builder: DatasetBuilder, settings: SettingsProxy, model):
195
+ def _initialize_neural_module(self, dataset_builder: DatasetBuilder, settings: SettingsProxy, model, torch: bool):
248
196
  self._dataset_builder = dataset_builder
249
197
  self._settings = settings
250
198
  self._neural_model = model
251
199
 
200
+ if torch:
201
+ try:
202
+ import torch
203
+ except:
204
+ raise Exception("torch is not installed in the environment")
205
+
206
+ from neuralogic.core.torch.neural_module import TorchNeuralModule
207
+
208
+ self._torch_module = TorchNeuralModule()
209
+
252
210
  optimizer = self._settings.optimizer.initialize()
253
211
  lr_decay = self._settings.optimizer.get_lr_decay()
254
212
 
@@ -266,23 +224,19 @@ class NeuralModule:
266
224
 
267
225
  self.reset_parameters()
268
226
 
269
- def _run_hook(self, hook: str, value):
270
- for callback in self._hooks[hook]:
271
- callback(value)
272
-
273
227
  def _dataset_to_samples(self, dataset):
274
228
  if isinstance(dataset, Dataset):
275
229
  dataset = self.build_dataset(dataset)
276
230
  return dataset._samples, dataset._batch_size
277
231
 
232
+ if isinstance(dataset, GroundedDataset):
233
+ dataset = dataset.neuralize()
234
+ return dataset._samples, dataset._batch_size
235
+
278
236
  if isinstance(dataset, BuiltDataset):
279
237
  return dataset._samples, dataset._batch_size
280
238
  return dataset, 1
281
239
 
282
- def _set_hooks(self):
283
- if len(self._hooks) != 0:
284
- self._strategy.setHooks(set(self._hooks.keys()), self._hook_handler)
285
-
286
240
  def _sync_template(self, state_dict: dict | None = None, weights=None):
287
241
  state_dict = self.state_dict() if state_dict is None else state_dict
288
242
  weights = self._parsed_template.getAllWeights() if weights is None else weights