claude-music 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md ADDED
@@ -0,0 +1,83 @@
1
+ # claude-music
2
+
3
+ Live AI background music for Claude Code, scored in real time by Claude itself.
4
+ You type `/music`, music starts playing, and as Claude works it silently steers
5
+ the soundtrack to match — calm and ambient while reading, driving and intense
6
+ while grinding through a hard problem, a victory lap when the tests go green.
7
+
8
+ Powered by Google's [Magenta RealTime 2](https://github.com/magenta/magenta-realtime)
9
+ (`mrt2_small`, 230M params) running locally on Apple Silicon.
10
+
11
+ ## Install
12
+
13
+ ```bash
14
+ npm install -g claude-music # or: pnpm add -g claude-music
15
+ claude-music install
16
+ ```
17
+
18
+ `claude-music install` downloads the prebuilt audio daemon and the model, registers
19
+ the MCP server with Claude Code, and installs the `/music` command. **Restart Claude
20
+ Code** afterwards so it loads the MCP server, then:
21
+
22
+ ```
23
+ /music # start — Claude picks an opening vibe
24
+ /music stop # stop
25
+ /music status # what's currently playing
26
+ ```
27
+
28
+ Once running you don't need to do anything — Claude re-vibes on its own as the work
29
+ changes, and music stops when you end the session.
30
+
31
+ ### Requirements
32
+
33
+ - **Apple Silicon Mac** (M-series) — required by MLX/Metal.
34
+ - **Node ≥ 18** and the **`claude`** CLI (Claude Code).
35
+ - ~1.3 GB of disk for the model, downloaded on first install.
36
+
37
+ No Xcode, CMake, or Python needed — the daemon ships as a prebuilt binary. (If a
38
+ prebuilt binary can't be fetched, `install` falls back to building from source,
39
+ which *does* require the Xcode Command Line Tools and CMake.)
40
+
41
+ ## How it works
42
+
43
+ ```
44
+ Claude Code ──set_music_vibe(prompt)──▶ mrt2-mcp (the `claude-music mcp` server)
45
+ ──JSON over ~/.mrt2d.sock──▶ mrt2d (C++ daemon, links magentart::core) ──CoreAudio──▶ ♪
46
+ ```
47
+
48
+ - **`mrt2d`** — a C++ daemon wrapping `magentart::core::RealtimeRunner`, streaming
49
+ audio to CoreAudio and taking prompt commands over a Unix socket. Cached at
50
+ `~/.claude-music/bin/mrt2d`. Source lives in `daemon/`.
51
+ - **MCP server** (`src/server.ts`, run via `claude-music mcp`) — exposes
52
+ `set_music_vibe` / `stop_music`. The vibe tool auto-spawns the daemon, so starting
53
+ music is a single silent tool call.
54
+ - **`templates/music.md`** — the `/music` slash command, installed to
55
+ `~/.claude/commands/music.md`.
56
+ - The model + resources download to `~/Documents/Magenta/magenta-rt-v2/`.
57
+
58
+ ## CLI
59
+
60
+ ```
61
+ claude-music install download daemon + model, register MCP, install /music
62
+ claude-music uninstall remove MCP registration + /music (--purge: also delete cache + model)
63
+ claude-music doctor check that everything is installed
64
+ claude-music start|stop|status
65
+ claude-music mcp run the MCP server (used internally by Claude Code)
66
+ ```
67
+
68
+ Override the daemon binary with `MRT2D_BIN`, or its download URL with
69
+ `CLAUDE_MUSIC_BINARY_URL`.
70
+
71
+ ## Maintainers — cutting a release
72
+
73
+ The prebuilt daemon is produced by `scripts/build-release.sh`, which builds `mrt2d`
74
+ against a fresh `magenta-realtime` checkout, bundles its non-system dylibs with
75
+ `@loader_path` rewrites, and emits `dist-release/mrt2d-darwin-arm64.tar.gz`:
76
+
77
+ ```bash
78
+ bash scripts/build-release.sh
79
+ # test the artifact end-to-end before publishing:
80
+ CLAUDE_MUSIC_BINARY_URL="file://$PWD/dist-release/mrt2d-darwin-arm64.tar.gz" claude-music install
81
+ # publish to a public GitHub release:
82
+ gh release upload vX.Y.Z dist-release/mrt2d-darwin-arm64.tar.gz --repo <owner>/claude-music
83
+ ```
@@ -0,0 +1,26 @@
1
# Built as a subdirectory of magenta-realtime/CMakeLists.txt (add_subdirectory).
# The parent project has already configured every heavy dependency (MLX,
# TFLite, SentencePiece, nlohmann_json, magentart::core), so this file only
# declares targets.

# ---- daemon ------------------------------------------------------------------
set(MRT2D_DAEMON_SOURCES
  src/main.cpp
  src/command.cpp
  src/socket_server.cpp
  src/audio_output.cpp
)
add_executable(mrt2d ${MRT2D_DAEMON_SOURCES})
target_include_directories(mrt2d PRIVATE include)
target_link_libraries(mrt2d
  PRIVATE
    magentart::core
    nlohmann_json::nlohmann_json
    "-framework CoreAudio"
    "-framework AudioUnit"
    "-framework AudioToolbox"
)

# ---- unit test binaries (no model / audio deps) -----------------------------
add_executable(test_command tests/test_command.cpp src/command.cpp)
add_executable(test_socket tests/test_socket.cpp src/socket_server.cpp)

target_include_directories(test_command PRIVATE include)
target_include_directories(test_socket PRIVATE include)

target_link_libraries(test_command PRIVATE nlohmann_json::nlohmann_json)
@@ -0,0 +1,36 @@
1
+ #pragma once
2
+
3
+ #include <AudioUnit/AudioUnit.h>
4
+ #include <cstddef>
5
+ #include <functional>
6
+
7
+ namespace mrt2d {
8
+
9
+ // CoreAudio default output unit. Calls fill(L, R, count) from the I/O thread.
10
+ // 48 kHz stereo float32 non-interleaved — matches RealtimeRunner output.
11
+ class AudioOutput {
12
+ public:
13
+ using FillCallback = std::function<void(float* L, float* R, std::size_t count)>;
14
+
15
+ AudioOutput();
16
+ ~AudioOutput();
17
+
18
+ AudioOutput(const AudioOutput&) = delete;
19
+ AudioOutput& operator=(const AudioOutput&) = delete;
20
+
21
+ void start(FillCallback fill);
22
+ void stop();
23
+
24
+ private:
25
+ static OSStatus render_cb(void* refcon,
26
+ AudioUnitRenderActionFlags* flags,
27
+ const AudioTimeStamp* ts,
28
+ UInt32 bus,
29
+ UInt32 num_frames,
30
+ AudioBufferList* data);
31
+
32
+ AudioUnit unit_{nullptr};
33
+ FillCallback fill_;
34
+ };
35
+
36
+ } // namespace mrt2d
@@ -0,0 +1,24 @@
1
+ #pragma once
2
+
3
+ #include <nlohmann/json.hpp>
4
+ #include <optional>
5
+ #include <string>
6
+ #include <variant>
7
+
8
namespace mrt2d {

// Request to switch the music to a new text prompt (main.cpp forwards it to
// RealtimeRunner::set_text_prompt).
struct VibeCommand {
    std::string prompt;
};

// Request for the daemon's current state; carries no payload.
struct StatusCommand {};

// Every request the daemon understands. Keep the alternative order stable:
// it determines Command::index().
using Command = std::variant<VibeCommand, StatusCommand>;

// Parses one JSON request line; std::nullopt if it is not a known command.
std::optional<Command> parse_command(const std::string& json_line);

// Reply builders. Each returns a single JSON object followed by '\n'.
std::string ack_ok(const std::string& prompt);
std::string ack_status(bool running, const std::string& prompt);
std::string ack_error(const std::string& message);

}  // namespace mrt2d
@@ -0,0 +1,30 @@
1
+ #pragma once
2
+
3
+ #include <atomic>
4
+ #include <functional>
5
+ #include <string>
6
+
7
namespace mrt2d {

// Unix-domain stream socket that answers newline-delimited requests. The
// constructor binds and listens on the socket file; the destructor removes it.
// Not copyable.
class SocketServer {
public:
    // Binds and listens on `socket_path`; throws std::runtime_error on failure.
    explicit SocketServer(const std::string& socket_path);
    ~SocketServer();

    SocketServer(const SocketServer&) = delete;
    SocketServer& operator=(const SocketServer&) = delete;

    // Blocks until stop(). Every newline-terminated line a client sends is
    // passed to `on_line`, and the string it returns is written back to that
    // client.
    void run(std::function<std::string(const std::string&)> on_line);

    // Asks run() to return. Safe to call from another thread.
    void stop();

private:
    std::string path_;
    int server_fd_ = -1;
    std::atomic<bool> running_{false};
};

}  // namespace mrt2d
@@ -0,0 +1,68 @@
1
+ #include "audio_output.h"
2
+
3
+ #include <AudioToolbox/AudioToolbox.h>
4
+ #include <stdexcept>
5
+
6
+ namespace mrt2d {
7
+
8
+ AudioOutput::AudioOutput() {
9
+ AudioComponentDescription desc{};
10
+ desc.componentType = kAudioUnitType_Output;
11
+ desc.componentSubType = kAudioUnitSubType_DefaultOutput;
12
+ desc.componentManufacturer = kAudioUnitManufacturer_Apple;
13
+
14
+ AudioComponent comp = AudioComponentFindNext(nullptr, &desc);
15
+ if (!comp) throw std::runtime_error("AudioComponentFindNext failed");
16
+
17
+ if (AudioComponentInstanceNew(comp, &unit_) != noErr)
18
+ throw std::runtime_error("AudioComponentInstanceNew failed");
19
+
20
+ AudioStreamBasicDescription fmt{};
21
+ fmt.mSampleRate = 48000.0;
22
+ fmt.mFormatID = kAudioFormatLinearPCM;
23
+ fmt.mFormatFlags = kAudioFormatFlagIsFloat | kAudioFormatFlagIsNonInterleaved;
24
+ fmt.mBytesPerPacket = 4;
25
+ fmt.mFramesPerPacket = 1;
26
+ fmt.mBytesPerFrame = 4;
27
+ fmt.mChannelsPerFrame = 2;
28
+ fmt.mBitsPerChannel = 32;
29
+ AudioUnitSetProperty(unit_, kAudioUnitProperty_StreamFormat,
30
+ kAudioUnitScope_Input, 0, &fmt, sizeof(fmt));
31
+
32
+ if (AudioUnitInitialize(unit_) != noErr)
33
+ throw std::runtime_error("AudioUnitInitialize failed");
34
+ }
35
+
36
+ AudioOutput::~AudioOutput() {
37
+ stop();
38
+ if (unit_) { AudioComponentInstanceDispose(unit_); unit_ = nullptr; }
39
+ }
40
+
41
+ void AudioOutput::start(FillCallback fill) {
42
+ fill_ = std::move(fill);
43
+ AURenderCallbackStruct cb{render_cb, this};
44
+ AudioUnitSetProperty(unit_, kAudioUnitProperty_SetRenderCallback,
45
+ kAudioUnitScope_Input, 0, &cb, sizeof(cb));
46
+ if (AudioOutputUnitStart(unit_) != noErr)
47
+ throw std::runtime_error("AudioOutputUnitStart failed");
48
+ }
49
+
50
+ void AudioOutput::stop() {
51
+ if (unit_) AudioOutputUnitStop(unit_);
52
+ fill_ = nullptr;
53
+ }
54
+
55
+ OSStatus AudioOutput::render_cb(void* refcon,
56
+ AudioUnitRenderActionFlags*,
57
+ const AudioTimeStamp*,
58
+ UInt32,
59
+ UInt32 num_frames,
60
+ AudioBufferList* data) {
61
+ auto* self = static_cast<AudioOutput*>(refcon);
62
+ auto* L = static_cast<float*>(data->mBuffers[0].mData);
63
+ auto* R = static_cast<float*>(data->mBuffers[1].mData);
64
+ if (self->fill_) self->fill_(L, R, num_frames);
65
+ return noErr;
66
+ }
67
+
68
+ } // namespace mrt2d
@@ -0,0 +1,35 @@
1
+ #include "command.h"
2
+
3
+ namespace mrt2d {
4
+
5
+ std::optional<Command> parse_command(const std::string& line) {
6
+ try {
7
+ auto j = nlohmann::json::parse(line);
8
+ if (j.contains("prompt")) {
9
+ return VibeCommand{j["prompt"].get<std::string>()};
10
+ }
11
+ if (j.contains("status")) {
12
+ return StatusCommand{};
13
+ }
14
+ return std::nullopt;
15
+ } catch (...) {
16
+ return std::nullopt;
17
+ }
18
+ }
19
+
20
+ std::string ack_ok(const std::string& prompt) {
21
+ return nlohmann::json{{"ok", true}, {"prompt", prompt}}.dump() + "\n";
22
+ }
23
+
24
+ std::string ack_status(bool running, const std::string& prompt) {
25
+ return nlohmann::json{
26
+ {"ok", true},
27
+ {"status", {{"running", running}, {"prompt", prompt}}}
28
+ }.dump() + "\n";
29
+ }
30
+
31
+ std::string ack_error(const std::string& message) {
32
+ return nlohmann::json{{"ok", false}, {"error", message}}.dump() + "\n";
33
+ }
34
+
35
+ } // namespace mrt2d
@@ -0,0 +1,126 @@
1
+ #include "audio_output.h"
2
+ #include "command.h"
3
+ #include "socket_server.h"
4
+
5
+ #include <magentart/realtime_runner.h>
6
+
7
+ #include <atomic>
8
+ #include <csignal>
9
+ #include <cstdio>
10
+ #include <cstdlib>
11
+ #include <filesystem>
12
+ #include <fstream>
13
+ #include <string>
14
+ #include <unistd.h>
15
+
16
+ using namespace magentart::core;
17
+ using namespace mrt2d;
18
+
19
+ static std::atomic<bool> g_running{true};
20
+ static SocketServer* g_socket = nullptr;
21
+
22
+ static void on_signal(int) {
23
+ g_running = false;
24
+ if (g_socket) g_socket->stop();
25
+ }
26
+
27
// Expands a leading "~" or "~/" to $HOME.
// Other tilde forms ("~user/...") are returned unchanged, because looking up
// another user's home directory is not supported. Previously they were
// mangled into "$HOME" + "user/...". Paths are also returned unchanged when
// HOME is unset.
static std::string expand_home(const std::string& p) {
    const bool own_home = p == "~" || p.rfind("~/", 0) == 0;
    if (!own_home) return p;
    const char* home = std::getenv("HOME");
    if (home == nullptr) return p;
    return std::string(home) + p.substr(1);
}
34
+
35
+ int main(int argc, char** argv) {
36
+ std::string model_path = expand_home("~/Documents/Magenta/magenta-rt-v2/models/mrt2_small/mrt2_small.mlxfn");
37
+ std::string resource_dir = expand_home("~/Documents/Magenta/magenta-rt-v2/resources");
38
+ std::string socket_path = expand_home("~/.mrt2d.sock");
39
+ std::string pid_path = expand_home("~/.mrt2d.pid");
40
+
41
+ for (int i = 1; i < argc; ++i) {
42
+ std::string a = argv[i];
43
+ if (a == "--model" && i + 1 < argc) model_path = argv[++i];
44
+ else if (a == "--resources" && i + 1 < argc) resource_dir = argv[++i];
45
+ else if (a == "--socket" && i + 1 < argc) socket_path = argv[++i];
46
+ }
47
+
48
+ std::signal(SIGTERM, on_signal);
49
+ std::signal(SIGINT, on_signal);
50
+
51
+ RealtimeRunner runner;
52
+
53
+ std::printf("Loading assets from %s ...\n", resource_dir.c_str());
54
+ if (!runner.init_assets(resource_dir.c_str())) {
55
+ std::fprintf(stderr, "init_assets failed\n"); return 1;
56
+ }
57
+ std::printf("Loading model %s ...\n", model_path.c_str());
58
+ if (!runner.load_model(model_path.c_str())) {
59
+ std::fprintf(stderr, "load_model failed\n"); return 1;
60
+ }
61
+ // load_model() starts the inference loop internally, generating from the
62
+ // model's baked-in initial conditioning state. We deliberately skip
63
+ // prefill_silence: this small model's SpectroStream encoder is compiled for
64
+ // a fixed 28s window that the silent-prefill path doesn't match, and
65
+ // prefill is only a first-second cleanliness optimisation anyway.
66
+
67
+ // Max out the output ring buffer (default 2048). Background coding music
68
+ // doesn't care about latency, so we trade it for the most underrun headroom
69
+ // the core's lock-free buffer allows — fewer audible glitches if inference
70
+ // momentarily stalls.
71
+ runner.set_buffer_size(RingBuffer::kCapacity);
72
+ std::printf("Ring buffer set to %zu samples (max).\n",
73
+ runner.get_buffer_size());
74
+ std::printf("Inference running.\n");
75
+
76
+ AudioOutput audio;
77
+ audio.start([&runner](float* L, float* R, std::size_t count) {
78
+ runner.read_audio_stereo(L, R, count);
79
+ });
80
+ std::printf("Audio output started.\n");
81
+
82
+ { std::ofstream pf(pid_path); pf << ::getpid() << "\n"; }
83
+
84
+ // Sampling / guidance knobs, matched to Google's shipped MRT2 apps so our
85
+ // output quality tracks their tuned defaults. Set ALL of them explicitly
86
+ // rather than relying on engine defaults — the stock cfg_notes default is
87
+ // 5.0, but Google ships 2.2. These are quality knobs, not musical energy
88
+ // (energy lives in the prompt text). Tune here to experiment.
89
+ constexpr float kTemperature = 1.0f; // best-sounding; also Google's value
90
+ constexpr int kTopK = 100;
91
+ constexpr float kCfgMusiccoca = 3.1f; // "style" guidance
92
+ constexpr float kCfgNotes = 2.2f; // "notes" guidance
93
+ runner.set_temperature(kTemperature);
94
+ runner.set_top_k(kTopK);
95
+ runner.set_cfg_musiccoca(kCfgMusiccoca);
96
+ runner.set_cfg_notes(kCfgNotes);
97
+
98
+ std::string current_prompt = "ambient pads, slow, contemplative, coding session";
99
+ runner.set_text_prompt(current_prompt);
100
+
101
+ SocketServer server(socket_path);
102
+ g_socket = &server;
103
+ std::printf("Listening on %s\n", socket_path.c_str());
104
+
105
+ server.run([&](const std::string& line) -> std::string {
106
+ auto cmd = parse_command(line);
107
+ if (!cmd) return ack_error("parse error");
108
+
109
+ return std::visit([&](auto&& c) -> std::string {
110
+ using T = std::decay_t<decltype(c)>;
111
+ if constexpr (std::is_same_v<T, VibeCommand>) {
112
+ current_prompt = c.prompt;
113
+ runner.set_text_prompt(c.prompt);
114
+ return ack_ok(current_prompt);
115
+ } else {
116
+ return ack_status(true, current_prompt);
117
+ }
118
+ }, *cmd);
119
+ });
120
+
121
+ audio.stop();
122
+ runner.stop();
123
+ std::filesystem::remove(pid_path);
124
+ std::printf("Shutdown complete.\n");
125
+ return 0;
126
+ }
@@ -0,0 +1,78 @@
1
+ #include "socket_server.h"
2
+
3
#include <cerrno>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <exception>
#include <memory>
#include <mutex>
#include <poll.h>
#include <set>
#include <stdexcept>
#include <string>
#include <sys/socket.h>
#include <sys/un.h>
#include <thread>
#include <unistd.h>
10
+
11
+ namespace mrt2d {
12
+
13
+ SocketServer::SocketServer(const std::string& path) : path_(path) {
14
+ server_fd_ = ::socket(AF_UNIX, SOCK_STREAM, 0);
15
+ if (server_fd_ < 0) throw std::runtime_error("socket() failed");
16
+
17
+ ::unlink(path_.c_str());
18
+
19
+ struct sockaddr_un addr{};
20
+ addr.sun_family = AF_UNIX;
21
+ std::strncpy(addr.sun_path, path_.c_str(), sizeof(addr.sun_path) - 1);
22
+
23
+ if (::bind(server_fd_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) < 0)
24
+ throw std::runtime_error("bind() failed on: " + path_);
25
+ if (::listen(server_fd_, 8) < 0)
26
+ throw std::runtime_error("listen() failed");
27
+ }
28
+
29
+ SocketServer::~SocketServer() {
30
+ if (server_fd_ >= 0) ::close(server_fd_);
31
+ ::unlink(path_.c_str());
32
+ }
33
+
34
+ void SocketServer::run(std::function<std::string(const std::string&)> handler) {
35
+ running_.store(true, std::memory_order_relaxed);
36
+ while (running_.load(std::memory_order_relaxed)) {
37
+ // Poll with a timeout instead of a bare blocking accept() so stop()
38
+ // (which may be called from a signal handler) is observed within
39
+ // ~200ms regardless of SA_RESTART / shutdown()-on-listen-socket
40
+ // portability quirks.
41
+ struct pollfd pfd{server_fd_, POLLIN, 0};
42
+ int pr = ::poll(&pfd, 1, 200);
43
+ if (!running_.load(std::memory_order_relaxed)) break;
44
+ if (pr <= 0) continue; // timeout or EINTR — re-check running_
45
+
46
+ int client_fd = ::accept(server_fd_, nullptr, nullptr);
47
+ if (client_fd < 0) continue;
48
+
49
+ std::thread([client_fd, handler]() {
50
+ char buf[4096];
51
+ std::string partial;
52
+ while (true) {
53
+ ssize_t n = ::recv(client_fd, buf, sizeof(buf) - 1, 0);
54
+ if (n <= 0) break;
55
+ buf[n] = '\0';
56
+ partial += buf;
57
+ size_t pos;
58
+ while ((pos = partial.find('\n')) != std::string::npos) {
59
+ std::string line = partial.substr(0, pos);
60
+ partial = partial.substr(pos + 1);
61
+ if (!line.empty()) {
62
+ std::string resp = handler(line);
63
+ ::send(client_fd, resp.c_str(), resp.size(), 0);
64
+ }
65
+ }
66
+ }
67
+ ::close(client_fd);
68
+ }).detach();
69
+ }
70
+ }
71
+
72
+ void SocketServer::stop() {
73
+ // Async-signal-safe: just an atomic store. The poll loop notices within
74
+ // the poll timeout and returns from run().
75
+ running_.store(false, std::memory_order_relaxed);
76
+ }
77
+
78
+ } // namespace mrt2d
@@ -0,0 +1,74 @@
1
+ #include "command.h"
2
+ #include <cassert>
3
+ #include <iostream>
4
+ #include <variant>
5
+
6
+ int main() {
7
+ using namespace mrt2d;
8
+
9
+ // parse vibe command
10
+ {
11
+ auto cmd = parse_command(R"({"prompt":"jazz trio"})");
12
+ assert(cmd.has_value());
13
+ assert(std::get<VibeCommand>(*cmd).prompt == "jazz trio");
14
+ std::cout << "PASS: parse vibe\n";
15
+ }
16
+
17
+ // extra keys (e.g. a legacy intensity field) are ignored
18
+ {
19
+ auto cmd = parse_command(R"({"prompt":"x","intensity":0.7})");
20
+ assert(cmd.has_value());
21
+ assert(std::get<VibeCommand>(*cmd).prompt == "x");
22
+ std::cout << "PASS: ignore extra keys\n";
23
+ }
24
+
25
+ // parse status command
26
+ {
27
+ auto cmd = parse_command(R"({"status":true})");
28
+ assert(cmd.has_value());
29
+ assert(std::holds_alternative<StatusCommand>(*cmd));
30
+ std::cout << "PASS: parse status\n";
31
+ }
32
+
33
+ // reject malformed JSON
34
+ {
35
+ auto cmd = parse_command("not json {{");
36
+ assert(!cmd.has_value());
37
+ std::cout << "PASS: reject malformed\n";
38
+ }
39
+
40
+ // reject empty object (no known keys)
41
+ {
42
+ auto cmd = parse_command("{}");
43
+ assert(!cmd.has_value());
44
+ std::cout << "PASS: reject empty object\n";
45
+ }
46
+
47
+ // ack_ok serialises correctly
48
+ {
49
+ auto j = nlohmann::json::parse(ack_ok("jazz"));
50
+ assert(j["ok"] == true);
51
+ assert(j["prompt"] == "jazz");
52
+ std::cout << "PASS: ack_ok\n";
53
+ }
54
+
55
+ // ack_status serialises correctly
56
+ {
57
+ auto j = nlohmann::json::parse(ack_status(true, "jazz"));
58
+ assert(j["ok"] == true);
59
+ assert(j["status"]["running"] == true);
60
+ assert(j["status"]["prompt"] == "jazz");
61
+ std::cout << "PASS: ack_status\n";
62
+ }
63
+
64
+ // ack_error serialises correctly
65
+ {
66
+ auto j = nlohmann::json::parse(ack_error("parse error"));
67
+ assert(j["ok"] == false);
68
+ assert(j["error"] == "parse error");
69
+ std::cout << "PASS: ack_error\n";
70
+ }
71
+
72
+ std::cout << "\nAll command tests passed.\n";
73
+ return 0;
74
+ }
@@ -0,0 +1,89 @@
1
+ #include "socket_server.h"
2
+
3
+ #include <cassert>
4
+ #include <chrono>
5
+ #include <cstring>
6
+ #include <iostream>
7
+ #include <string>
8
+ #include <sys/socket.h>
9
+ #include <sys/un.h>
10
+ #include <thread>
11
+ #include <unistd.h>
12
+
13
static const char* const SOCK = "/tmp/test_mrt2d_socket.sock";

// Connects a Unix-socket client to SOCK, sends `msg` followed by '\n', and
// returns what it reads until the reply contains a newline (or the peer
// closes). Returns "" if the socket cannot be created or connected, so the
// caller's comparison fails. The old version performed connect() inside
// assert(), which NDEBUG builds remove entirely.
static std::string roundtrip(const std::string& msg) {
    const int fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
    if (fd < 0) return {};

    struct sockaddr_un addr{};
    addr.sun_family = AF_UNIX;
    std::strncpy(addr.sun_path, SOCK, sizeof(addr.sun_path) - 1);
    if (::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != 0) {
        ::close(fd);
        return {};
    }

    const std::string to_send = msg + "\n";
    ::send(fd, to_send.c_str(), to_send.size(), 0);

    char buf[1024] = {};
    std::string result;
    while (result.find('\n') == std::string::npos) {
        const ssize_t n = ::recv(fd, buf, sizeof(buf) - 1, 0);
        if (n <= 0) break;
        result.append(buf, static_cast<std::size_t>(n));
    }
    ::close(fd);
    return result;
}
38
+
39
+ int main() {
40
+ mrt2d::SocketServer server(SOCK);
41
+
42
+ std::thread srv([&server]() {
43
+ server.run([](const std::string& line) -> std::string {
44
+ return "echo:" + line + "\n";
45
+ });
46
+ });
47
+
48
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
49
+
50
+ // Single message round-trip.
51
+ {
52
+ std::string resp = roundtrip("hello");
53
+ assert(resp == "echo:hello\n");
54
+ std::cout << "PASS: single message roundtrip\n";
55
+ }
56
+
57
+ // Multiple messages on the same connection.
58
+ {
59
+ int fd = ::socket(AF_UNIX, SOCK_STREAM, 0);
60
+ assert(fd >= 0);
61
+ struct sockaddr_un addr{};
62
+ addr.sun_family = AF_UNIX;
63
+ std::strncpy(addr.sun_path, SOCK, sizeof(addr.sun_path) - 1);
64
+ assert(::connect(fd, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) == 0);
65
+
66
+ for (int i = 0; i < 3; ++i) {
67
+ std::string msg = "msg" + std::to_string(i) + "\n";
68
+ ::send(fd, msg.c_str(), msg.size(), 0);
69
+ char buf[256] = {};
70
+ std::string resp;
71
+ while (resp.find('\n') == std::string::npos) {
72
+ ssize_t n = ::recv(fd, buf, sizeof(buf) - 1, 0);
73
+ if (n <= 0) break;
74
+ resp.append(buf, n);
75
+ }
76
+ assert(resp == "echo:msg" + std::to_string(i) + "\n");
77
+ }
78
+ ::close(fd);
79
+ std::cout << "PASS: multiple messages same connection\n";
80
+ }
81
+
82
+ // stop() must unblock run() so the thread joins (the SIGTERM-path fix).
83
+ server.stop();
84
+ srv.join();
85
+ std::cout << "PASS: stop() unblocks run()\n";
86
+
87
+ std::cout << "\nAll socket tests passed.\n";
88
+ return 0;
89
+ }