wafer-core 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/environments/coding.py +0 -4
- wafer_core/lib/trace_compare/__init__.py +32 -0
- wafer_core/lib/trace_compare/analyzer.py +339 -0
- wafer_core/lib/trace_compare/classifier.py +192 -0
- wafer_core/lib/trace_compare/formatter.py +951 -0
- wafer_core/lib/trace_compare/fusion_analyzer.py +890 -0
- wafer_core/lib/trace_compare/loader.py +336 -0
- wafer_core/problem_config.py +3 -3
- wafer_core/rollouts/agent_presets/rlm_01_01.py +2 -2
- wafer_core/rollouts/dtypes.py +18 -3
- wafer_core/rollouts/providers/anthropic.py +35 -3
- wafer_core/rollouts/upload.py +45 -0
- wafer_core/tools/__init__.py +0 -7
- wafer_core/utils/kernel_utils/defense.py +10 -0
- wafer_core/utils/kernel_utils/targets/config.py +10 -0
- {wafer_core-0.1.24.dist-info → wafer_core-0.1.26.dist-info}/METADATA +1 -1
- {wafer_core-0.1.24.dist-info → wafer_core-0.1.26.dist-info}/RECORD +18 -13
- wafer_core/tools/search_docs_tool.py +0 -196
- {wafer_core-0.1.24.dist-info → wafer_core-0.1.26.dist-info}/WHEEL +0 -0
|
@@ -3,7 +3,7 @@ wafer_core/async_ssh.py,sha256=ocw2Gh5p8ltKeoqG_q32DXOBfu5q-IE7jCnzMbQN9WI,28713
|
|
|
3
3
|
wafer_core/auth.py,sha256=JpUkZ3bROIsgexayak5TLiGqUAR5kqGjekwqQRvIXH0,7235
|
|
4
4
|
wafer_core/gpu.py,sha256=ENa92btjXsx6ldpoyKfRrAmfy-LHG2KpA5k7SWd6Q_s,28627
|
|
5
5
|
wafer_core/gpu_detect.py,sha256=kpD8Q_G6GA9j-WnnnTNA3BBPulkGcWnZiogOmjKDao0,13650
|
|
6
|
-
wafer_core/problem_config.py,sha256=
|
|
6
|
+
wafer_core/problem_config.py,sha256=IM4ZRul4306dF7yo8wwyxXYORUZ7nz5wnphG59HN6fo,10907
|
|
7
7
|
wafer_core/remote_env.py,sha256=0ACTL-A_qn2B43qgQakqGaern-pspvwBGB9iebz199k,15354
|
|
8
8
|
wafer_core/remote_jobs.py,sha256=7HdBDCigSxfp32BreWoljzG5xjK6fp25rwC_6D7D04s,8306
|
|
9
9
|
wafer_core/retry.py,sha256=OIvSElJZbSm4-SFBpOFuYtoX2DWGiANomCmb3qxsirM,14821
|
|
@@ -12,7 +12,7 @@ wafer_core/config/__init__.py,sha256=hKywfjA4YXd4lBeBFEcBoMwFoflPHJTiBnkTq7_JYOQ
|
|
|
12
12
|
wafer_core/config/loader.py,sha256=k7JnILmO13TWUzIv9Lm8fvmj3UfYHZDgaFurjQ-GXpY,6623
|
|
13
13
|
wafer_core/config/schema.py,sha256=2WhFlnG0VYYX4T-70BLeJK8Janvi4KEa8KKGZA7331w,3898
|
|
14
14
|
wafer_core/environments/__init__.py,sha256=SIsResVtm22tr_d-oHPeeSxrkhFdmPOFico3DqtRqK8,238
|
|
15
|
-
wafer_core/environments/coding.py,sha256=
|
|
15
|
+
wafer_core/environments/coding.py,sha256=T-_JFU-n5OxPR8xAWp8qar4Y5xyC-TWTIBjRy4PDel8,8418
|
|
16
16
|
wafer_core/environments/gpumode.py,sha256=8Da08nltvN_YloNyYI6-omN2D4n5C7aptKDCtUgT2bQ,17191
|
|
17
17
|
wafer_core/lib/__init__.py,sha256=4-4p3mhwlquejWGglYXU8_nHdA0LoPaa_jGzcm13USA,1325
|
|
18
18
|
wafer_core/lib/kernel_scope/__init__.py,sha256=WW2vu8jUlqOu-MCpgO40lIYacCA9N2u-uuECIs_JO2w,2817
|
|
@@ -318,6 +318,12 @@ wafer_core/lib/rocprofiler/systems/run/analyzer.py,sha256=Qg3M8-kCKdV82ehn6Ta20N
|
|
|
318
318
|
wafer_core/lib/rocprofiler/systems/run/profiler.py,sha256=aiQLsDnfQHSeCM5zLnO4VlbTmREYnAtiuT50Eq6uWfg,8387
|
|
319
319
|
wafer_core/lib/rocprofiler/systems/sample/__init__.py,sha256=31rNmLPQ7OVhvlOEEOwPKgk8_qrCidj6AmzDXexQJ_o,288
|
|
320
320
|
wafer_core/lib/rocprofiler/systems/sample/profiler.py,sha256=CYZPTzNXd48LoCfmY6h_5RSYEdWYccuv3-t4YncHJLE,7384
|
|
321
|
+
wafer_core/lib/trace_compare/__init__.py,sha256=G5vmiQnuweiF9vjK1FC4ZIy-tzuHiaLMs7QBnir8OJw,800
|
|
322
|
+
wafer_core/lib/trace_compare/analyzer.py,sha256=o0SI1PsehpgxeUPQEB9708W_Q_ILiO5apgqVLe2xE8A,14541
|
|
323
|
+
wafer_core/lib/trace_compare/classifier.py,sha256=sE1K007GVk_Up2g59SVUIZ7BThf0yHNjGsZ9AyM_Ah8,6028
|
|
324
|
+
wafer_core/lib/trace_compare/formatter.py,sha256=GNrCZ45ueBN05CEXjOtTuKvTI8z-g-ZZFil-ni3sWVY,37962
|
|
325
|
+
wafer_core/lib/trace_compare/fusion_analyzer.py,sha256=LwYTBjL_gHCvydfgFp-L9f_qfXq3GenJHRemygly4H8,36482
|
|
326
|
+
wafer_core/lib/trace_compare/loader.py,sha256=E7-OS4uMqvJhGLyxKQNnAgK33YECrSjuCssUT_X0LQA,11728
|
|
321
327
|
wafer_core/lib/tracelens/__init__.py,sha256=AkHdmOnKlBO4RpsAqVVGe7MOfv6E6uhEaC_iKrYeMPI,2002
|
|
322
328
|
wafer_core/lib/tracelens/comparator.py,sha256=71YEPfjBi7_24u1oQuPerNtFsN0sDQ5CT_uBi0XLllw,3460
|
|
323
329
|
wafer_core/lib/tracelens/finder.py,sha256=HpbN8TuRNbbBytPYOmkBkfsFVBReQqVgsvFX-mBrln4,2459
|
|
@@ -336,7 +342,7 @@ wafer_core/rollouts/agents.py,sha256=Uv1kjYogUfdPl18YfkVxVqFTbmWfuJQrxem_iHTUgdw
|
|
|
336
342
|
wafer_core/rollouts/cli.py,sha256=2NqgegKdlmxD0eJzGOMB5o_1Hb5t7O5JpP_32uvF2BE,80117
|
|
337
343
|
wafer_core/rollouts/cli_agents.py,sha256=e4qqqYBzWLsbw8FsNnddGApWp_on9Cvzrfd1amiAyvI,20641
|
|
338
344
|
wafer_core/rollouts/deploy.py,sha256=3t88fM_BMyAPkxIl8pS4r5ogHJvrlqWQDuIaltDZBRc,40924
|
|
339
|
-
wafer_core/rollouts/dtypes.py,sha256=
|
|
345
|
+
wafer_core/rollouts/dtypes.py,sha256=oRWjpbUOTf4uyXvnO9QThcSzD1fBrDQnAfRhGbxdgrg,61916
|
|
340
346
|
wafer_core/rollouts/eval_helpers.py,sha256=OE7uQZRcbqQhpFqb4zOj8zafc9Gr6xZJpSrMvxXKVUw,1699
|
|
341
347
|
wafer_core/rollouts/evaluation.py,sha256=fk-pGZ5vpocVmw1iBbHtxMK0K6l8pYTLHCpDNvRY1Xo,69142
|
|
342
348
|
wafer_core/rollouts/events.py,sha256=z85J8kq0LXPj5CiUk4RkiTQg--r9xiO7QeeJwkyUOto,7505
|
|
@@ -359,7 +365,7 @@ wafer_core/rollouts/skills.py,sha256=ATYoG02Cc6_VrtE415TnseBFJrKOMq27z-5YgBgPpZQ
|
|
|
359
365
|
wafer_core/rollouts/slice.py,sha256=darOZO53BuSPfvv_KjOSzulGVSWbL4OuoE3k6xXpBFg,20195
|
|
360
366
|
wafer_core/rollouts/store.py,sha256=UDP9idDOEVs_0Pslx0K_Y8E1i-BeoqVSaxdQiaqtz1E,18051
|
|
361
367
|
wafer_core/rollouts/transform_messages.py,sha256=yldzdLgugNYb5Zxju7myFBel1tmrHXx9M399ImqPLGI,20891
|
|
362
|
-
wafer_core/rollouts/upload.py,sha256=
|
|
368
|
+
wafer_core/rollouts/upload.py,sha256=hEqZfwgb0b4GYbrwRSA3fuqF70pqo6hyaQU59j3vM7E,4890
|
|
363
369
|
wafer_core/rollouts/_logging/__init__.py,sha256=rCXeAssQ3gIrduuMzvKPD-ikt6rXejVL9h5XtDRyIQg,498
|
|
364
370
|
wafer_core/rollouts/_logging/color_formatter.py,sha256=x3qRKwHsUCFkgcIl8x_Ajjw82X2EedbTe14sCxMU4Kc,2267
|
|
365
371
|
wafer_core/rollouts/_logging/json_formatter.py,sha256=jJIa2IZCsu2C_Y1HXQi7hbI33x6L6shN_dqu-hmhxp4,2380
|
|
@@ -371,7 +377,7 @@ wafer_core/rollouts/agent_presets/gpt_5_1_codex_04_04.py,sha256=42NIBBYAnVoy5mbu
|
|
|
371
377
|
wafer_core/rollouts/agent_presets/gpt_5_2_03_03.py,sha256=lEsHRUhhr8UbP5wSVKMOVDVOOtH_bQMRRgZ0dRGZMVc,1166
|
|
372
378
|
wafer_core/rollouts/agent_presets/loader.py,sha256=WSkTbL7QhgMamZR5sXxep1n4cuy8LC3a4MN2phYTm-4,3666
|
|
373
379
|
wafer_core/rollouts/agent_presets/opus_4_01_01.py,sha256=rurZMI-Df7O-Q-uVJj2zfY_DSjdNbMKBDZlRg9MLADc,3568
|
|
374
|
-
wafer_core/rollouts/agent_presets/rlm_01_01.py,sha256=
|
|
380
|
+
wafer_core/rollouts/agent_presets/rlm_01_01.py,sha256=jsjwDgACQxxJj4GYOUCcQvYjcICAaKV3eccQu9oyEcw,4781
|
|
375
381
|
wafer_core/rollouts/agent_presets/sonnet_4_02_02.py,sha256=ZdHNxioki3wsfD6ficgB2r7HkgQDH_trCR-baGFgoHk,1269
|
|
376
382
|
wafer_core/rollouts/agent_presets/sonnet_4_subagent_03_02.py,sha256=nxyjs4HWAPOAYLmPknSQr3viBXhboKC7wQ76LWB-jA0,2165
|
|
377
383
|
wafer_core/rollouts/config/README.md,sha256=i0r0a3sKLkc1Eq3EqqR2Gahsgo-c8O3OZ0cCh7rp8Uw,9899
|
|
@@ -495,7 +501,7 @@ wafer_core/rollouts/prompt_optimization/adapters/system_prompt.py,sha256=CWFox1N
|
|
|
495
501
|
wafer_core/rollouts/prompt_optimization/adapters/system_user_prompt.py,sha256=8JsSirihgZ5gacyYhn31GagyIxG0xQ7f7i4PnEupWz8,12090
|
|
496
502
|
wafer_core/rollouts/prompt_optimization/adapters/terminal_bench.py,sha256=Etswuqf5dBIZQ2x2p26AXz4LT33YxT2qEeHqKXTJy18,12273
|
|
497
503
|
wafer_core/rollouts/providers/__init__.py,sha256=Xu8PPDHOmF97ylMJXfE9JX2FJCanNVh7LXkHMmg0vWs,3121
|
|
498
|
-
wafer_core/rollouts/providers/anthropic.py,sha256=
|
|
504
|
+
wafer_core/rollouts/providers/anthropic.py,sha256=9x1GIL6JE8gutxVrLNiyAkymknIEKtl-98TnIUpFxoI,45223
|
|
499
505
|
wafer_core/rollouts/providers/base.py,sha256=2ADu6pDz6yEcazo4j6-O12rs19bPewAfycjK_N03ZkY,14544
|
|
500
506
|
wafer_core/rollouts/providers/google.py,sha256=IbqdXOpzSuMdI7eOZqRtzni85ysKby13PGe482Fq13w,22073
|
|
501
507
|
wafer_core/rollouts/providers/openai_completions.py,sha256=3vUA74qjrxG-aOjyngtnZp0MzIhnzW5kudwxmOGxXfs,28820
|
|
@@ -586,11 +592,10 @@ wafer_core/sessions/hooks.py,sha256=A-txm6ufnRGQCdtP3vwh7oEOdlLN9Tv0XsjORMihuAI,
|
|
|
586
592
|
wafer_core/targets/__init__.py,sha256=sHndC7AAOaHXlrmDXFLB53a5Y8DBjuyqS6nwsO2nj-Y,1728
|
|
587
593
|
wafer_core/targets/digitalocean.py,sha256=cvoYpYjtSyy5t2lQAPi7ERruuuibronah_ivOiduAHQ,16550
|
|
588
594
|
wafer_core/targets/runpod.py,sha256=LrVmNvA6qjzL5nbGSWvtw7CHrK6bDu7_o3vKIek00Tc,20286
|
|
589
|
-
wafer_core/tools/__init__.py,sha256=
|
|
595
|
+
wafer_core/tools/__init__.py,sha256=wBQD45GdSfkxcT6NHzIv0IMeXCc0enwwkpm3T_9j1X8,3341
|
|
590
596
|
wafer_core/tools/bash_tool.py,sha256=daoKOVGSgL0x9X_3l8Apd6-wFH4VMXMGJwVemw2FIfc,16828
|
|
591
597
|
wafer_core/tools/glob_tool.py,sha256=9X5PdOjQJj7kiVNqqCZC0-1LmnE6wHx3Zc9zfMjtXdc,3533
|
|
592
598
|
wafer_core/tools/grep_tool.py,sha256=cStyDz-J47oDLLZCL83yOvYo8Ijv4qu3D372JKT_ptM,4580
|
|
593
|
-
wafer_core/tools/search_docs_tool.py,sha256=WY4hY83sseX8Fpxvw6DZxiG-F95F2t3-4PyfMD1Lpkg,6809
|
|
594
599
|
wafer_core/tools/skill_tool.py,sha256=JXsT5hBTUH5U4tmzHEywU7eHHt5xCEF79tL2tsuk4-c,2067
|
|
595
600
|
wafer_core/tools/wafer_tool.py,sha256=-dgPTHbWXq3I3wFj0mP7-lj5iZqGRoFvFf9IEEo3plQ,6345
|
|
596
601
|
wafer_core/tools/write_kernel_tool.py,sha256=dJjhr-WBhVNe06hcJQVmBZTbS8mid64KF1MwlE2s2R4,21547
|
|
@@ -656,7 +661,7 @@ wafer_core/utils/remote_execution.py,sha256=z7nLiOgmDiM_VmElLnT2LF-aKNeeKFYjXigT
|
|
|
656
661
|
wafer_core/utils/submission_selection.py,sha256=LucdMTAbkqZA-GitSb3ZJ2pAeJ36wKqt5cTeS8xuAQ4,5655
|
|
657
662
|
wafer_core/utils/kernel_utils/__init__.py,sha256=NsfKpbfpIsfupWIpIjWLGCjGAVqaONiwiWil5zXbrRc,2015
|
|
658
663
|
wafer_core/utils/kernel_utils/backends.py,sha256=t3wY73Y-pVc_wALNu_bPsaFkqJ2dp2pf38KQ5ofP_go,1143
|
|
659
|
-
wafer_core/utils/kernel_utils/defense.py,sha256=
|
|
664
|
+
wafer_core/utils/kernel_utils/defense.py,sha256=8tHVTZlJfFcB_FWjNZfeGHwReSjG191OmFXtWXa07OM,20124
|
|
660
665
|
wafer_core/utils/kernel_utils/deployment.py,sha256=-tMb3qWmAoXHWCmmT7SQBH7KBKyyLP0e5Dk6lOrTPW8,55957
|
|
661
666
|
wafer_core/utils/kernel_utils/evaluate.py,sha256=1kxFNMl9VCXfKfk_BIiuA_zFfvDB1sl_feS2OEIJA1k,72346
|
|
662
667
|
wafer_core/utils/kernel_utils/gpu_validation.py,sha256=LRiDjW_xAK4fXf1Vw1aYHG54B1W0J6b5L0K6PXzM2tI,3759
|
|
@@ -666,7 +671,7 @@ wafer_core/utils/kernel_utils/static_checker.py,sha256=XIQkzAOkGH5xtrOuZM4tNUqVJ
|
|
|
666
671
|
wafer_core/utils/kernel_utils/task.py,sha256=XcmKxKUWh5It6nX3zGqj77tWgA32uPfQMqNOqyD5T48,2682
|
|
667
672
|
wafer_core/utils/kernel_utils/utils.py,sha256=uDZoJDxh07hJeLNlPdKN2vgB15pqIr1LbXf0YIBHU4E,43056
|
|
668
673
|
wafer_core/utils/kernel_utils/targets/__init__.py,sha256=4NwRLsuJ__S4xKAfda4Ag82C5MQ3Qio-4xA5S-mQGlU,2067
|
|
669
|
-
wafer_core/utils/kernel_utils/targets/config.py,sha256=
|
|
674
|
+
wafer_core/utils/kernel_utils/targets/config.py,sha256=V587DYkisEFoWwkmLQUW6I0mXkMEwA52sM7ZINslkK8,20625
|
|
670
675
|
wafer_core/utils/kernel_utils/targets/execution.py,sha256=bZuNXCo0sIdD6hFhetLPrtDC-zMSiIsAx_aml49VVL0,15033
|
|
671
676
|
wafer_core/utils/kernel_utils/targets/selection.py,sha256=5I_RG_7cfhq7uaeR28meC2EeNNKssFsK-Tc3QFG6Ze0,3590
|
|
672
677
|
wafer_core/utils/modal_execution/__init__.py,sha256=jkVqYOLzCT5K73N9Od0UIUsx-99A0m6bpDrxfyXxQZ8,945
|
|
@@ -674,6 +679,6 @@ wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdh
|
|
|
674
679
|
wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
|
|
675
680
|
wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
|
|
676
681
|
wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
|
|
677
|
-
wafer_core-0.1.
|
|
678
|
-
wafer_core-0.1.
|
|
679
|
-
wafer_core-0.1.
|
|
682
|
+
wafer_core-0.1.26.dist-info/METADATA,sha256=xzTIIcsmbJkA06hTdoRb4uXZj2ud1-wnV7EXdLOSOe4,1420
|
|
683
|
+
wafer_core-0.1.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
684
|
+
wafer_core-0.1.26.dist-info/RECORD,,
|
|
@@ -1,196 +0,0 @@
|
|
|
1
|
-
"""Search documentation tool for GPU programming corpora.
|
|
2
|
-
|
|
3
|
-
Provides semantic and keyword search over documentation for CuTeDSL, CUDA, etc.
|
|
4
|
-
|
|
5
|
-
Corpora are downloaded via `wafer corpus download <name>` and stored in ~/.cache/wafer/corpora/.
|
|
6
|
-
"""
|
|
7
|
-
|
|
8
|
-
import re
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
|
|
11
|
-
from wafer_core.rollouts.dtypes import Tool, ToolCall, ToolFunction, ToolFunctionParameter, ToolResult
|
|
12
|
-
|
|
13
|
-
# Cache directory where wafer corpus download stores files
|
|
14
|
-
CACHE_DIR = Path.home() / ".cache" / "wafer" / "corpora"
|
|
15
|
-
|
|
16
|
-
# Available corpora (names match wafer corpus download)
|
|
17
|
-
AVAILABLE_CORPORA = ["cutlass", "cutedsl", "cuda", "hip", "amd"]
|
|
18
|
-
|
|
19
|
-
SEARCH_DOCS_TOOL = Tool(
|
|
20
|
-
type="function",
|
|
21
|
-
function=ToolFunction(
|
|
22
|
-
name="search_docs",
|
|
23
|
-
description="""Search GPU programming documentation for relevant information.
|
|
24
|
-
|
|
25
|
-
Use this tool to find documentation about:
|
|
26
|
-
- CUTLASS C++ (cute:: namespace, gemm tutorials, tensor cores, TMA, Blackwell)
|
|
27
|
-
- CuTeDSL Python API (@cute.kernel, @cute.jit, cute.arch functions)
|
|
28
|
-
- CUDA programming concepts
|
|
29
|
-
- GPU kernel optimization techniques
|
|
30
|
-
- Code examples and patterns
|
|
31
|
-
|
|
32
|
-
Available corpora:
|
|
33
|
-
- 'cutlass' - NVIDIA CUTLASS C++ docs + GitHub examples (gemm, hopper, blackwell)
|
|
34
|
-
- 'cutedsl' - CuTeDSL Python documentation
|
|
35
|
-
- 'cuda' - General CUDA programming docs
|
|
36
|
-
- 'hip' - AMD HIP programming docs
|
|
37
|
-
- 'amd' - AMD GPU kernel development (rocWMMA, CK, etc.)
|
|
38
|
-
|
|
39
|
-
Note: Corpora must be downloaded first with `wafer corpus download <name>`.
|
|
40
|
-
Returns relevant documentation snippets with file paths.""",
|
|
41
|
-
parameters=ToolFunctionParameter(
|
|
42
|
-
type="object",
|
|
43
|
-
properties={
|
|
44
|
-
"query": {
|
|
45
|
-
"type": "string",
|
|
46
|
-
"description": "Search query - describe what you're looking for",
|
|
47
|
-
},
|
|
48
|
-
"corpus": {
|
|
49
|
-
"type": "string",
|
|
50
|
-
"description": "Which docs to search: 'cutlass', 'cutedsl', 'cuda', 'hip', 'amd' (default: cutlass)",
|
|
51
|
-
},
|
|
52
|
-
"max_results": {
|
|
53
|
-
"type": "integer",
|
|
54
|
-
"description": "Maximum number of results to return (default: 5)",
|
|
55
|
-
},
|
|
56
|
-
},
|
|
57
|
-
),
|
|
58
|
-
required=["query"],
|
|
59
|
-
)
|
|
60
|
-
)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def _get_corpus_path(corpus_name: str) -> Path | None:
|
|
64
|
-
"""Get the path to a corpus in the cache directory.
|
|
65
|
-
|
|
66
|
-
Corpora are stored at ~/.cache/wafer/corpora/<corpus_name>/
|
|
67
|
-
"""
|
|
68
|
-
if corpus_name not in AVAILABLE_CORPORA:
|
|
69
|
-
return None
|
|
70
|
-
|
|
71
|
-
corpus_path = CACHE_DIR / corpus_name
|
|
72
|
-
if corpus_path.exists():
|
|
73
|
-
return corpus_path
|
|
74
|
-
|
|
75
|
-
return None
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def _search_files(corpus_path: Path, query: str, max_results: int = 5) -> list[dict]:
|
|
79
|
-
"""Simple keyword search through documentation files."""
|
|
80
|
-
results = []
|
|
81
|
-
query_terms = query.lower().split()
|
|
82
|
-
|
|
83
|
-
# Search .md, .py, .cu, .hpp, and .h files (for CUTLASS examples)
|
|
84
|
-
for pattern in ["**/*.md", "**/*.py", "**/*.cu", "**/*.hpp", "**/*.h", "**/*.cuh"]:
|
|
85
|
-
for file_path in corpus_path.glob(pattern):
|
|
86
|
-
if file_path.is_file():
|
|
87
|
-
try:
|
|
88
|
-
content = file_path.read_text(encoding="utf-8", errors="ignore")
|
|
89
|
-
content_lower = content.lower()
|
|
90
|
-
|
|
91
|
-
# Score based on term matches
|
|
92
|
-
score = sum(content_lower.count(term) for term in query_terms)
|
|
93
|
-
|
|
94
|
-
if score > 0:
|
|
95
|
-
# Extract relevant snippets
|
|
96
|
-
snippets = _extract_snippets(content, query_terms)
|
|
97
|
-
results.append({
|
|
98
|
-
"file": str(file_path), # Return absolute path so read tool can access it
|
|
99
|
-
"score": score,
|
|
100
|
-
"snippets": snippets[:3], # Top 3 snippets
|
|
101
|
-
})
|
|
102
|
-
except Exception:
|
|
103
|
-
continue
|
|
104
|
-
|
|
105
|
-
# Sort by score and return top results
|
|
106
|
-
results.sort(key=lambda x: x["score"], reverse=True)
|
|
107
|
-
return results[:max_results]
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
def _extract_snippets(content: str, terms: list[str], context_lines: int = 5) -> list[str]:
|
|
111
|
-
"""Extract snippets containing search terms."""
|
|
112
|
-
snippets = []
|
|
113
|
-
lines = content.split("\n")
|
|
114
|
-
|
|
115
|
-
for i, line in enumerate(lines):
|
|
116
|
-
line_lower = line.lower()
|
|
117
|
-
if any(term in line_lower for term in terms):
|
|
118
|
-
# Get context around the match
|
|
119
|
-
start = max(0, i - context_lines)
|
|
120
|
-
end = min(len(lines), i + context_lines + 1)
|
|
121
|
-
snippet = "\n".join(lines[start:end])
|
|
122
|
-
|
|
123
|
-
# Skip very short snippets
|
|
124
|
-
if len(snippet.strip()) > 50:
|
|
125
|
-
snippets.append(snippet)
|
|
126
|
-
|
|
127
|
-
return snippets
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
async def exec_search_docs(
|
|
131
|
-
tool_call: ToolCall,
|
|
132
|
-
corpus_override: str | None = None,
|
|
133
|
-
) -> ToolResult:
|
|
134
|
-
"""Execute search_docs tool.
|
|
135
|
-
|
|
136
|
-
Args:
|
|
137
|
-
tool_call: The tool call with query and optional corpus
|
|
138
|
-
corpus_override: Override corpus path (for testing)
|
|
139
|
-
"""
|
|
140
|
-
query = tool_call.args.get("query", "")
|
|
141
|
-
corpus_name = tool_call.args.get("corpus", "cutlass")
|
|
142
|
-
max_results = tool_call.args.get("max_results", 5)
|
|
143
|
-
|
|
144
|
-
if not query:
|
|
145
|
-
return ToolResult(
|
|
146
|
-
tool_call_id=tool_call.id,
|
|
147
|
-
content="",
|
|
148
|
-
error="query parameter is required",
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
# Find corpus path
|
|
152
|
-
if corpus_override:
|
|
153
|
-
corpus_path = Path(corpus_override)
|
|
154
|
-
else:
|
|
155
|
-
corpus_path = _get_corpus_path(corpus_name)
|
|
156
|
-
if corpus_path is None:
|
|
157
|
-
return ToolResult(
|
|
158
|
-
tool_call_id=tool_call.id,
|
|
159
|
-
content="",
|
|
160
|
-
error=f"Unknown corpus: {corpus_name}. Available: {AVAILABLE_CORPORA}",
|
|
161
|
-
)
|
|
162
|
-
|
|
163
|
-
if not corpus_path.exists():
|
|
164
|
-
return ToolResult(
|
|
165
|
-
tool_call_id=tool_call.id,
|
|
166
|
-
content="",
|
|
167
|
-
error=f"Corpus '{corpus_name}' not downloaded. Run: wafer corpus download {corpus_name}",
|
|
168
|
-
)
|
|
169
|
-
|
|
170
|
-
# Search
|
|
171
|
-
results = _search_files(corpus_path, query, max_results)
|
|
172
|
-
|
|
173
|
-
if not results:
|
|
174
|
-
return ToolResult(
|
|
175
|
-
tool_call_id=tool_call.id,
|
|
176
|
-
content=f"No results found for query: {query}",
|
|
177
|
-
error=None,
|
|
178
|
-
)
|
|
179
|
-
|
|
180
|
-
# Format output
|
|
181
|
-
output_parts = [f"Found {len(results)} results for: {query}\n"]
|
|
182
|
-
|
|
183
|
-
for i, result in enumerate(results, 1):
|
|
184
|
-
output_parts.append(f"\n{'='*60}")
|
|
185
|
-
output_parts.append(f"[{i}] {result['file']} (score: {result['score']})")
|
|
186
|
-
output_parts.append("=" * 60)
|
|
187
|
-
|
|
188
|
-
for snippet in result["snippets"]:
|
|
189
|
-
output_parts.append(snippet)
|
|
190
|
-
output_parts.append("-" * 40)
|
|
191
|
-
|
|
192
|
-
return ToolResult(
|
|
193
|
-
tool_call_id=tool_call.id,
|
|
194
|
-
content="\n".join(output_parts),
|
|
195
|
-
error=None,
|
|
196
|
-
)
|
|
File without changes
|