diff --git a/.release-manifest.json b/.release-manifest.json index 6f4b10c..606a098 100644 --- a/.release-manifest.json +++ b/.release-manifest.json @@ -1,13 +1,13 @@ { - "crates/rust-mcp-sdk": "0.5.0", - "crates/rust-mcp-macros": "0.5.0", - "crates/rust-mcp-transport": "0.4.0", - "examples/hello-world-mcp-server": "0.1.24", - "examples/hello-world-mcp-server-core": "0.1.15", - "examples/simple-mcp-client": "0.1.24", - "examples/simple-mcp-client-core": "0.1.24", - "examples/hello-world-server-core-sse": "0.1.15", - "examples/hello-world-server-sse": "0.1.24", - "examples/simple-mcp-client-core-sse": "0.1.15", - "examples/simple-mcp-client-sse": "0.1.15" -} \ No newline at end of file + "crates/rust-mcp-sdk": "0.5.0", + "crates/rust-mcp-macros": "0.5.0", + "crates/rust-mcp-transport": "0.4.0", + "examples/hello-world-mcp-server": "0.1.24", + "examples/hello-world-mcp-server-core": "0.1.15", + "examples/simple-mcp-client": "0.1.24", + "examples/simple-mcp-client-core": "0.1.24", + "examples/hello-world-server-core-streamable-http": "0.1.15", + "examples/hello-world-server-streamable-http": "0.1.24", + "examples/simple-mcp-client-core-sse": "0.1.15", + "examples/simple-mcp-client-sse": "0.1.15" +} diff --git a/Cargo.lock b/Cargo.lock index 05ac024..1fbdf25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -84,9 +84,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.13.1" +version = "1.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" +checksum = "5c953fe1ba023e6b7730c0d4b031d06f267f23a46167dcbd40316644b10a17ba" dependencies = [ "aws-lc-sys", "zeroize", @@ -94,9 +94,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.29.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" +checksum = "dbfd150b5dbdb988bcc8fb1fe787eb6b7ee6180ca24da683b61ea5405f3d43ff" dependencies = [ "bindgen", "cc", @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.27" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d487aa071b5f64da6f19a3e848e3578944b726ee5a4854b82172f02aa876bfdc" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "jobserver", "libc", @@ -644,9 +644,9 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "h2" -version = "0.3.26" +version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" dependencies = [ "bytes", "fnv", @@ -713,7 +713,7 @@ dependencies = [ ] [[package]] -name = "hello-world-server-core-sse" +name = "hello-world-server-core-streamable-http" version = "0.1.15" dependencies = [ "async-trait", @@ -727,7 +727,7 @@ dependencies = [ ] [[package]] -name = "hello-world-server-sse" +name = "hello-world-server-streamable-http" version = "0.1.24" dependencies = [ "async-trait", @@ -854,14 +854,14 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.3.26", + "h2 0.3.27", "http 0.2.12", "http-body 0.4.6", "httparse", "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.5.10", "tokio", "tower-service", "tracing", @@ -908,9 +908,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.14" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" dependencies = [ "base64 0.22.1", "bytes", @@ -924,7 +924,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.0", "tokio", "tower-service", "tracing", @@ -1062,6 +1062,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "io-uring" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" +dependencies = [ + "bitflags", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -1155,9 +1166,9 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "litrs" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5" +checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed" [[package]] name = "lock_api" @@ -1374,9 +1385,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.35" +version = "0.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", "syn", @@ -1420,7 +1431,7 @@ dependencies = [ "quinn-udp", "rustc-hash 2.1.1", "rustls", - "socket2", + "socket2 0.5.10", "thiserror 2.0.12", "tokio", "tracing", @@ -1436,7 +1447,7 @@ dependencies = [ "bytes", "getrandom 0.3.3", "lru-slab", - "rand 0.9.1", + "rand 0.9.2", "ring", "rustc-hash 2.1.1", "rustls", @@ -1457,7 +1468,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.5.10", "tracing", "windows-sys 0.59.0", ] @@ -1492,9 +1503,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -1549,9 +1560,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.13" +version = "0.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d04b7d0ee6b4a0207a0a7adb104d23ecb0b47d6beae7152d0fa34b692b29fd6" +checksum = "7251471db004e509f4e75a62cca9435365b5ec7bcdff530d612ac7c87c44a792" dependencies = [ "bitflags", ] @@ -1678,9 +1689,9 @@ dependencies = [ [[package]] name = "rust-mcp-schema" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fec1ff3507a619b9945f60a94dac448541aef8c9803aa6192c30f4a932cb1499" +checksum = "a0e71aee61257cd3d4a78fdc10c92c29e7a55c4f767119ffdafd837bb5e5cb9a" dependencies = [ "serde", "serde_json", @@ -1729,9 +1740,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" [[package]] name = "rustc-hash" @@ -1760,9 +1771,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.28" +version = "0.23.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7160e3e10bf4535308537f3c4e1641468cd0e485175d6163087c0393c7d46643" +checksum = "069a8df149a16b1a12dcc31497c3396a173844be3cac4bd40c9e7671fef96671" dependencies = [ "aws-lc-rs", "once_cell", @@ -1794,9 +1805,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.3" +version = "0.103.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" dependencies = [ "aws-lc-rs", "ring", @@ -1844,9 +1855,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "itoa", "memchr", @@ -1993,6 +2004,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -2143,20 +2164,22 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.45.1" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "43864ed400b6043a4757a25c7a64a8efde741aed79a056a2fb348a406701bb35" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "slab", + "socket2 0.6.0", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2519,9 +2542,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8782dd5a41a24eed3a4f40b606249b3e236ca61adf1f25ea4d45c73de122b502" +checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2" dependencies = [ "rustls-pki-types", ] diff --git a/Cargo.toml b/Cargo.toml index 6a7196d..59af1e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,8 @@ members = [ "examples/simple-mcp-client-core", "examples/hello-world-mcp-server", "examples/hello-world-mcp-server-core", - "examples/hello-world-server-sse", - "examples/hello-world-server-core-sse", + "examples/hello-world-server-streamable-http", + "examples/hello-world-server-core-streamable-http", "examples/simple-mcp-client-sse", "examples/simple-mcp-client-core-sse", ] diff --git a/README.md b/README.md index 1c73330..85f3070 100644 --- a/README.md +++ b/README.md @@ -23,15 +23,26 @@ By default, it uses the **2025-06-18** version, but earlier versions can be enab -This project currently supports following transports: -- **stdio** (Standard Input/Output) -- **sse** (Server-Sent Events). - +This project supports following transports: +- **Stdio** (Standard Input/Output) +- **SSE** (Server-Sent Events). +- **Streamable HTTP**. 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `sse` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. -**⚠️** **Streamable HTTP** transport and authentication still in progress and not yet available. Project is currently under development and should be used at your own risk. + + +**MCP Streamable HTTP Support** +- [x] Streamable HTTP Support for MCP Servers +- [x] DNS Rebinding Protection +- [x] Batch Messages +- [x] Streaming & non-streaming JSON response +- [ ] Streamable HTTP Support for MCP Clients +- [ ] Resumability +- [ ] Authentication / Oauth + +**⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents - [Usage Examples](#usage-examples) @@ -39,6 +50,9 @@ This project currently supports following transports: - [MCP Server (sse)](#mcp-server-sse) - [MCP Client (stdio)](#mcp-client-stdio) - [MCP Client (sse)](#mcp-client-sse) +- [Getting Started](#getting-started) +- [HyperServerOptions](#hyperserveroptions) + - [Security Considerations](#security-considerations) - [Cargo features](#cargo-features) - [Available Features](#available-features) - [MCP protocol versions with corresponding features](#mcp-protocol-versions-with-corresponding-features) @@ -100,10 +114,14 @@ See hello-world-mcp-server example running in [MCP Inspector](https://modelconte ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) -### MCP Server (sse) +### MCP Server (Streamable HTTP) Creating an MCP server in `rust-mcp-sdk` with the `sse` transport allows multiple clients to connect simultaneously with no additional setup. -Simply create a Hyper Server using `hyper_server::create_server()` and pass in the same handler and transform options. +Simply create a Hyper Server using `hyper_server::create_server()` and pass in the same handler and HyperServerOptions. + + +💡 By default, both **Streamable HTTP** and **SSE** transports are enabled for backward compatibility. To disable the SSE transport , set the `sse_support` to false in the `HyperServerOptions`. + ```rust @@ -134,6 +152,7 @@ let server = hyper_server::create_server( handler, HyperServerOptions { host: "127.0.0.1".to_string(), + sse_support: false, ..Default::default() }, ); @@ -188,9 +207,9 @@ impl ServerHandler for MyServerHandler { 👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** -See hello-world-server-sse example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : -![mcp-server in rust](assets/examples/hello-world-server-sse.gif) +![mcp-server in rust](assets/examples/hello-world-server-streamable-http.gif) --- @@ -292,6 +311,103 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) +## HyperServerOptions + +HyperServer is a lightweight Axum-based server that streamlines MCP servers by supporting **Streamable HTTP** and **SSE** transports. It supports simultaneous client connections, internal session management, and includes built-in security features like DNS rebinding protection and more. + +HyperServer is highly customizable through HyperServerOptions provided during initialization. + +A typical example of creating a HyperServer that exposes the MCP server via Streamable HTTP and SSE transports at: + +```rs + +let server = hyper_server::create_server( + server_details, + handler, + HyperServerOptions { + host: "127.0.0.1".to_string(), + enable_ssl: true, + ..Default::default() + }, +); + +server.start().await?; + +``` + +Here is a list of available options with descriptions for configuring the HyperServer: +```rs +pub struct HyperServerOptions { + /// Hostname or IP address the server will bind to (default: "127.0.0.1") + pub host: String, + + /// Hostname or IP address the server will bind to (default: "8080") + pub port: u16, + + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) + pub custom_streamable_http_endpoint: Option, + + /// This setting only applies to streamable HTTP. + /// If true, the server will return JSON responses instead of starting an SSE stream. + /// This can be useful for simple request/response scenarios without streaming. + /// Default is false (SSE streams are preferred). + pub enable_json_response: Option, + + /// Interval between automatic ping messages sent to clients to detect disconnects + pub ping_interval: Duration, + + /// Shared transport configuration used by the server + pub transport_options: Arc, + + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>, + + /// Enables SSL/TLS if set to `true` + pub enable_ssl: bool, + + /// Path to the SSL/TLS certificate file (e.g., "cert.pem"). + /// Required if `enable_ssl` is `true`. + pub ssl_cert_path: Option, + + /// Path to the SSL/TLS private key file (e.g., "key.pem"). + /// Required if `enable_ssl` is `true`. + pub ssl_key_path: Option, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, + + /// List of allowed host header values for DNS rebinding protection. + /// If not specified, host validation is disabled. + pub allowed_hosts: Option>, + + /// List of allowed origin header values for DNS rebinding protection. + /// If not specified, origin validation is disabled. + pub allowed_origins: Option>, + + /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + /// Default is false for backwards compatibility. + pub dns_rebinding_protection: bool, +} + +``` + +### Security Considerations + +When using Streamable HTTP transport, following security best practices are recommended: + +- Enable DNS rebinding protection and provide proper `allowed_hosts` and `allowed_origins` to prevent DNS rebinding attacks. +- When running locally, bind only to localhost (127.0.0.1 / localhost) rather than all network interfaces (0.0.0.0) +- Use TLS/HTTPS for production deployments + + ## Cargo Features The `rust-mcp-sdk` crate provides several features that can be enabled or disabled. By default, all features are enabled to ensure maximum functionality, but you can customize which ones to include based on your project's requirements. diff --git a/assets/examples/hello-world-mcp-server.gif b/assets/examples/hello-world-mcp-server.gif index 5796d45..dcc08ce 100644 Binary files a/assets/examples/hello-world-mcp-server.gif and b/assets/examples/hello-world-mcp-server.gif differ diff --git a/assets/examples/hello-world-server-core-streamable-http.gif b/assets/examples/hello-world-server-core-streamable-http.gif new file mode 100644 index 0000000..6800321 Binary files /dev/null and b/assets/examples/hello-world-server-core-streamable-http.gif differ diff --git a/assets/examples/hello-world-server-streamable-http.gif b/assets/examples/hello-world-server-streamable-http.gif new file mode 100644 index 0000000..a1e66b2 Binary files /dev/null and b/assets/examples/hello-world-server-streamable-http.gif differ diff --git a/crates/rust-mcp-macros/Cargo.toml b/crates/rust-mcp-macros/Cargo.toml index 534c8f5..a581ac6 100644 --- a/crates/rust-mcp-macros/Cargo.toml +++ b/crates/rust-mcp-macros/Cargo.toml @@ -21,6 +21,7 @@ syn = "2.0" quote = "1.0" proc-macro2 = "1.0" + [dev-dependencies] rust-mcp-schema = { workspace = true, default-features = false } diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 1c73330..85f3070 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -23,15 +23,26 @@ By default, it uses the **2025-06-18** version, but earlier versions can be enab -This project currently supports following transports: -- **stdio** (Standard Input/Output) -- **sse** (Server-Sent Events). - +This project supports following transports: +- **Stdio** (Standard Input/Output) +- **SSE** (Server-Sent Events). +- **Streamable HTTP**. 🚀 The **rust-mcp-sdk** includes a lightweight [Axum](https://github.com/tokio-rs/axum) based server that handles all core functionality seamlessly. Switching between `stdio` and `sse` is straightforward, requiring minimal code changes. The server is designed to efficiently handle multiple concurrent client connections and offers built-in support for SSL. -**⚠️** **Streamable HTTP** transport and authentication still in progress and not yet available. Project is currently under development and should be used at your own risk. + + +**MCP Streamable HTTP Support** +- [x] Streamable HTTP Support for MCP Servers +- [x] DNS Rebinding Protection +- [x] Batch Messages +- [x] Streaming & non-streaming JSON response +- [ ] Streamable HTTP Support for MCP Clients +- [ ] Resumability +- [ ] Authentication / Oauth + +**⚠️** Project is currently under development and should be used at your own risk. ## Table of Contents - [Usage Examples](#usage-examples) @@ -39,6 +50,9 @@ This project currently supports following transports: - [MCP Server (sse)](#mcp-server-sse) - [MCP Client (stdio)](#mcp-client-stdio) - [MCP Client (sse)](#mcp-client-sse) +- [Getting Started](#getting-started) +- [HyperServerOptions](#hyperserveroptions) + - [Security Considerations](#security-considerations) - [Cargo features](#cargo-features) - [Available Features](#available-features) - [MCP protocol versions with corresponding features](#mcp-protocol-versions-with-corresponding-features) @@ -100,10 +114,14 @@ See hello-world-mcp-server example running in [MCP Inspector](https://modelconte ![mcp-server in rust](assets/examples/hello-world-mcp-server.gif) -### MCP Server (sse) +### MCP Server (Streamable HTTP) Creating an MCP server in `rust-mcp-sdk` with the `sse` transport allows multiple clients to connect simultaneously with no additional setup. -Simply create a Hyper Server using `hyper_server::create_server()` and pass in the same handler and transform options. +Simply create a Hyper Server using `hyper_server::create_server()` and pass in the same handler and HyperServerOptions. + + +💡 By default, both **Streamable HTTP** and **SSE** transports are enabled for backward compatibility. To disable the SSE transport , set the `sse_support` to false in the `HyperServerOptions`. + ```rust @@ -134,6 +152,7 @@ let server = hyper_server::create_server( handler, HyperServerOptions { host: "127.0.0.1".to_string(), + sse_support: false, ..Default::default() }, ); @@ -188,9 +207,9 @@ impl ServerHandler for MyServerHandler { 👉 For a more detailed example of a [Hello World MCP](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) Server that supports multiple tools and provides more type-safe handling of `CallToolRequest`, check out: **[examples/hello-world-mcp-server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server)** -See hello-world-server-sse example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : +See hello-world-server-streamable-http example running in [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) : -![mcp-server in rust](assets/examples/hello-world-server-sse.gif) +![mcp-server in rust](assets/examples/hello-world-server-streamable-http.gif) --- @@ -292,6 +311,103 @@ Creating an MCP client using the `rust-mcp-sdk` with the SSE transport is almost If you are looking for a step-by-step tutorial on how to get started with `rust-mcp-sdk` , please see : [Getting Started MCP Server](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/doc/getting-started-mcp-server.md) +## HyperServerOptions + +HyperServer is a lightweight Axum-based server that streamlines MCP servers by supporting **Streamable HTTP** and **SSE** transports. It supports simultaneous client connections, internal session management, and includes built-in security features like DNS rebinding protection and more. + +HyperServer is highly customizable through HyperServerOptions provided during initialization. + +A typical example of creating a HyperServer that exposes the MCP server via Streamable HTTP and SSE transports at: + +```rs + +let server = hyper_server::create_server( + server_details, + handler, + HyperServerOptions { + host: "127.0.0.1".to_string(), + enable_ssl: true, + ..Default::default() + }, +); + +server.start().await?; + +``` + +Here is a list of available options with descriptions for configuring the HyperServer: +```rs +pub struct HyperServerOptions { + /// Hostname or IP address the server will bind to (default: "127.0.0.1") + pub host: String, + + /// Hostname or IP address the server will bind to (default: "8080") + pub port: u16, + + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) + pub custom_streamable_http_endpoint: Option, + + /// This setting only applies to streamable HTTP. + /// If true, the server will return JSON responses instead of starting an SSE stream. + /// This can be useful for simple request/response scenarios without streaming. + /// Default is false (SSE streams are preferred). + pub enable_json_response: Option, + + /// Interval between automatic ping messages sent to clients to detect disconnects + pub ping_interval: Duration, + + /// Shared transport configuration used by the server + pub transport_options: Arc, + + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>, + + /// Enables SSL/TLS if set to `true` + pub enable_ssl: bool, + + /// Path to the SSL/TLS certificate file (e.g., "cert.pem"). + /// Required if `enable_ssl` is `true`. + pub ssl_cert_path: Option, + + /// Path to the SSL/TLS private key file (e.g., "key.pem"). + /// Required if `enable_ssl` is `true`. + pub ssl_key_path: Option, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, + + /// List of allowed host header values for DNS rebinding protection. + /// If not specified, host validation is disabled. + pub allowed_hosts: Option>, + + /// List of allowed origin header values for DNS rebinding protection. + /// If not specified, origin validation is disabled. + pub allowed_origins: Option>, + + /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + /// Default is false for backwards compatibility. + pub dns_rebinding_protection: bool, +} + +``` + +### Security Considerations + +When using Streamable HTTP transport, following security best practices are recommended: + +- Enable DNS rebinding protection and provide proper `allowed_hosts` and `allowed_origins` to prevent DNS rebinding attacks. +- When running locally, bind only to localhost (127.0.0.1 / localhost) rather than all network interfaces (0.0.0.0) +- Use TLS/HTTPS for production deployments + + ## Cargo Features The `rust-mcp-sdk` crate provides several features that can be enabled or disabled. By default, all features are enabled to ensure maximum functionality, but you can customize which ones to include based on your project's requirements. diff --git a/crates/rust-mcp-sdk/assets/examples/hello-world-mcp-server.gif b/crates/rust-mcp-sdk/assets/examples/hello-world-mcp-server.gif index 5796d45..dcc08ce 100644 Binary files a/crates/rust-mcp-sdk/assets/examples/hello-world-mcp-server.gif and b/crates/rust-mcp-sdk/assets/examples/hello-world-mcp-server.gif differ diff --git a/crates/rust-mcp-sdk/assets/examples/hello-world-server-core-streamable-http.gif b/crates/rust-mcp-sdk/assets/examples/hello-world-server-core-streamable-http.gif new file mode 100644 index 0000000..6800321 Binary files /dev/null and b/crates/rust-mcp-sdk/assets/examples/hello-world-server-core-streamable-http.gif differ diff --git a/crates/rust-mcp-sdk/assets/examples/hello-world-server-streamable-http.gif b/crates/rust-mcp-sdk/assets/examples/hello-world-server-streamable-http.gif new file mode 100644 index 0000000..a1e66b2 Binary files /dev/null and b/crates/rust-mcp-sdk/assets/examples/hello-world-server-streamable-http.gif differ diff --git a/crates/rust-mcp-sdk/src/error.rs b/crates/rust-mcp-sdk/src/error.rs index 788e31a..2feab67 100644 --- a/crates/rust-mcp-sdk/src/error.rs +++ b/crates/rust-mcp-sdk/src/error.rs @@ -1,6 +1,8 @@ -use crate::schema::RpcError; +use crate::schema::{ParseProtocolVersionError, RpcError}; + use rust_mcp_transport::error::TransportError; use thiserror::Error; +use tokio::task::JoinError; #[cfg(feature = "hyper-server")] use crate::hyper_servers::error::TransportServerError; @@ -16,14 +18,18 @@ pub enum McpSdkError { #[error("{0}")] TransportError(#[from] TransportError), #[error("{0}")] + JoinError(#[from] JoinError), + #[error("{0}")] AnyError(Box<(dyn std::error::Error + Send + Sync)>), #[error("{0}")] SdkError(#[from] crate::schema::schema_utils::SdkError), #[cfg(feature = "hyper-server")] #[error("{0}")] TransportServerError(#[from] TransportServerError), - #[error("Incompatible mcp protocol version: client:{0} server:{1}")] + #[error("Incompatible mcp protocol version: requested:{0} current:{1}")] IncompatibleProtocolVersion(String, String), + #[error("{0}")] + ParseProtocolVersionError(#[from] ParseProtocolVersionError), } impl McpSdkError { diff --git a/crates/rust-mcp-sdk/src/hyper_servers.rs b/crates/rust-mcp-sdk/src/hyper_servers.rs index 9a58b04..f18c428 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers.rs @@ -1,5 +1,6 @@ mod app_state; pub mod error; +pub mod hyper_runtime; pub mod hyper_server; pub mod hyper_server_core; mod middlewares; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs index 65e77f2..0c1dcf3 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs @@ -19,5 +19,23 @@ pub struct AppState { pub handler: Arc, pub ping_interval: Duration, pub sse_message_endpoint: String, + pub http_streamable_endpoint: String, pub transport_options: Arc, + pub enable_json_response: bool, + /// List of allowed host header values for DNS rebinding protection. + /// If not specified, host validation is disabled. + pub allowed_hosts: Option>, + /// List of allowed origin header values for DNS rebinding protection. + /// If not specified, origin validation is disabled. + pub allowed_origins: Option>, + /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + /// Default is false for backwards compatibility. + pub dns_rebinding_protection: bool, +} + +impl AppState { + pub fn needs_dns_protection(&self) -> bool { + self.dns_rebinding_protection + && (self.allowed_hosts.is_some() || self.allowed_origins.is_some()) + } } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/error.rs b/crates/rust-mcp-sdk/src/hyper_servers/error.rs index adcccf4..74cbcd1 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/error.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/error.rs @@ -21,6 +21,8 @@ pub enum TransportServerError { InvalidServerOptions(String), #[error("{0}")] SslCertError(String), + #[error("{0}")] + TransportError(String), } impl IntoResponse for TransportServerError { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs new file mode 100644 index 0000000..30df951 --- /dev/null +++ b/crates/rust-mcp-sdk/src/hyper_servers/hyper_runtime.rs @@ -0,0 +1,198 @@ +use std::{sync::Arc, time::Duration}; + +use crate::{ + mcp_server::HyperServer, + schema::{ + schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}, + CreateMessageRequestParams, CreateMessageResult, LoggingMessageNotificationParams, + PromptListChangedNotificationParams, ResourceListChangedNotificationParams, + ResourceUpdatedNotificationParams, ToolListChangedNotificationParams, + }, + McpServer, +}; + +use axum_server::Handle; +use rust_mcp_transport::SessionId; +use tokio::{sync::Mutex, task::JoinHandle}; + +use crate::{ + error::SdkResult, + hyper_servers::app_state::AppState, + mcp_server::{ + error::{TransportServerError, TransportServerResult}, + ServerRuntime, + }, +}; + +pub struct HyperRuntime { + pub(crate) state: Arc, + pub(crate) server_task: JoinHandle>, + pub(crate) server_handle: Handle, +} + +impl HyperRuntime { + pub async fn create(server: HyperServer) -> SdkResult { + let addr = server.options.resolve_server_address().await?; + let state = server.state(); + + let server_handle = server.server_handle(); + + let server_task = tokio::spawn(async move { + #[cfg(feature = "ssl")] + if server.options.enable_ssl { + server.start_ssl(addr).await + } else { + server.start_http(addr).await + } + + #[cfg(not(feature = "ssl"))] + if server.options.enable_ssl { + panic!("SSL requested but the 'ssl' feature is not enabled"); + } else { + server.start_http(addr).await + } + }); + + Ok(Self { + state, + server_task, + server_handle, + }) + } + + pub fn graceful_shutdown(&self, timeout: Option) { + self.server_handle.graceful_shutdown(timeout); + } + + pub async fn await_server(self) -> SdkResult<()> { + let result = self.server_task.await?; + result.map_err(|err| err.into()) + } + + pub async fn runtime_by_session( + &self, + session_id: &SessionId, + ) -> TransportServerResult>>> { + self.state.session_store.get(session_id).await.ok_or( + TransportServerError::SessionIdInvalid(session_id.to_string()), + ) + } + + pub async fn send_request( + &self, + session_id: &SessionId, + request: RequestFromServer, + timeout: Option, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.request(request, timeout).await + } + + pub async fn send_notification( + &self, + session_id: &SessionId, + notification: NotificationFromServer, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.send_notification(notification).await + } + + pub async fn send_logging_message( + &self, + session_id: &SessionId, + params: LoggingMessageNotificationParams, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.send_logging_message(params).await + } + + /// An optional notification from the server to the client, informing it that + /// the list of prompts it offers has changed. + /// This may be issued by servers without any previous subscription from the client. + pub async fn send_prompt_list_changed( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.send_prompt_list_changed(params).await + } + + /// An optional notification from the server to the client, + /// informing it that the list of resources it can read from has changed. + /// This may be issued by servers without any previous subscription from the client. + pub async fn send_resource_list_changed( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.send_resource_list_changed(params).await + } + + /// A notification from the server to the client, informing it that + /// a resource has changed and may need to be read again. + /// This should only be sent if the client previously sent a resources/subscribe request. + pub async fn send_resource_updated( + &self, + session_id: &SessionId, + params: ResourceUpdatedNotificationParams, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.send_resource_updated(params).await + } + + /// An optional notification from the server to the client, informing it that + /// the list of tools it offers has changed. + /// This may be issued by servers without any previous subscription from the client. + pub async fn send_tool_list_changed( + &self, + session_id: &SessionId, + params: Option, + ) -> SdkResult<()> { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.send_tool_list_changed(params).await + } + + /// A ping request to check that the other party is still alive. + /// The receiver must promptly respond, or else may be disconnected. + /// + /// This function creates a `PingRequest` with no specific parameters, sends the request and awaits the response + /// Once the response is received, it attempts to convert it into the expected + /// result type. + /// + /// # Returns + /// A `SdkResult` containing the `rust_mcp_schema::Result` if the request is successful. + /// If the request or conversion fails, an error is returned. + pub async fn ping( + &self, + session_id: &SessionId, + timeout: Option, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.ping(timeout).await + } + + /// A request from the server to sample an LLM via the client. + /// The client has full discretion over which model to select. + /// The client should also inform the user before beginning sampling, + /// to allow them to inspect the request (human in the loop) + /// and decide whether to approve it. + pub async fn create_message( + &self, + session_id: &SessionId, + params: CreateMessageRequestParams, + ) -> SdkResult { + let runtime = self.runtime_by_session(session_id).await?; + let runtime = runtime.lock().await.to_owned(); + runtime.create_message(params).await + } +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs index 612510e..0222952 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs @@ -1 +1,2 @@ +pub(crate) mod protect_dns_rebinding; pub(crate) mod session_id_gen; diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs new file mode 100644 index 0000000..5674e87 --- /dev/null +++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/protect_dns_rebinding.rs @@ -0,0 +1,66 @@ +use crate::hyper_servers::app_state::AppState; +use crate::schema::schema_utils::SdkError; +use axum::{ + extract::{Request, State}, + middleware::Next, + response::IntoResponse, + Json, +}; +use hyper::{ + header::{HOST, ORIGIN}, + HeaderMap, StatusCode, +}; +use std::sync::Arc; + +// Middleware to protect against DNS rebinding attacks by validating Host and Origin headers. +pub async fn protect_dns_rebinding( + headers: HeaderMap, + State(state): State>, + request: Request, + next: Next, +) -> impl IntoResponse { + if !state.needs_dns_protection() { + // If protection is not needed, pass the request to the next handler + return next.run(request).await.into_response(); + } + + if let Some(allowed_hosts) = state.allowed_hosts.as_ref() { + if !allowed_hosts.is_empty() { + let Some(host) = headers.get(HOST).and_then(|h| h.to_str().ok()) else { + let error = SdkError::bad_request().with_message("Invalid Host header: [unknown] "); + return (StatusCode::FORBIDDEN, Json(error)).into_response(); + }; + + if !allowed_hosts + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(host)) + { + let error = SdkError::bad_request() + .with_message(format!("Invalid Host header: \"{host}\" ").as_str()); + return (StatusCode::FORBIDDEN, Json(error)).into_response(); + } + } + } + + if let Some(allowed_origins) = state.allowed_origins.as_ref() { + if !allowed_origins.is_empty() { + let Some(origin) = headers.get(ORIGIN).and_then(|h| h.to_str().ok()) else { + let error = + SdkError::bad_request().with_message("Invalid Origin header: [unknown] "); + return (StatusCode::FORBIDDEN, Json(error)).into_response(); + }; + + if !allowed_origins + .iter() + .any(|allowed| allowed.eq_ignore_ascii_case(origin)) + { + let error = SdkError::bad_request() + .with_message(format!("Invalid Origin header: \"{origin}\" ").as_str()); + return (StatusCode::FORBIDDEN, Json(error)).into_response(); + } + } + } + + // If all checks pass, proceed to the next handler in the chain + next.run(request).await +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs index 7146880..b1b15fc 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes.rs @@ -1,6 +1,8 @@ pub mod fallback_routes; +mod hyper_utils; pub mod messages_routes; pub mod sse_routes; +pub mod streamable_http_routes; use super::{app_state::AppState, HyperServerOptions}; use axum::Router; @@ -18,12 +20,25 @@ use std::sync::Arc; /// # Returns /// * `Router` - An Axum router configured with all application routes and state pub fn app_routes(state: Arc, server_options: &HyperServerOptions) -> Router { - Router::new() - .merge(sse_routes::routes( + let router: Router = Router::new() + .merge(streamable_http_routes::routes( state.clone(), - server_options.sse_endpoint(), + server_options.streamable_http_endpoint(), )) - .merge(messages_routes::routes(state.clone())) + .merge({ + let mut r = Router::new(); + if server_options.sse_support { + r = r + .merge(sse_routes::routes( + state.clone(), + server_options.sse_endpoint(), + )) + .merge(messages_routes::routes(state.clone())) + } + r + }) .with_state(state) - .merge(fallback_routes::routes()) + .merge(fallback_routes::routes()); + + router } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs index b76d5dc..971ed43 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/fallback_routes.rs @@ -9,7 +9,7 @@ pub fn routes() -> Router { pub async fn not_found(uri: Uri) -> (StatusCode, String) { ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Server Error!\r\n uri: {uri}"), + StatusCode::NOT_FOUND, + format!("The requested uri does not exist:\r\nuri: {uri}"), ) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs new file mode 100644 index 0000000..79bf226 --- /dev/null +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -0,0 +1,398 @@ +use crate::{ + error::SdkResult, + hyper_servers::{ + app_state::AppState, + error::{TransportServerError, TransportServerResult}, + }, + mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, + mcp_server::{server_runtime, ServerRuntime}, + mcp_traits::mcp_handler::McpServerHandler, + utils::validate_mcp_protocol_version, +}; + +use crate::schema::schema_utils::{ClientMessage, SdkError}; + +use axum::{http::HeaderValue, response::IntoResponse}; +use axum::{ + response::{ + sse::{Event, KeepAlive}, + Sse, + }, + Json, +}; +use futures::stream; +use hyper::{header, HeaderMap, StatusCode}; +use rust_mcp_transport::{SessionId, SseTransport}; +use std::{sync::Arc, time::Duration}; +use tokio::io::{duplex, AsyncBufReadExt, BufReader}; + +pub const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id"; +pub const MCP_PROTOCOL_VERSION_HEADER: &str = "Mcp-Protocol-Version"; + +const DUPLEX_BUFFER_SIZE: usize = 8192; + +async fn create_sse_stream( + runtime: Arc, + session_id: SessionId, + state: Arc, + payload: Option<&str>, + standalone: bool, +) -> TransportServerResult> { + let payload_string = payload.map(|p| p.to_string()); + + // TODO: this logic should be moved out after refactoing the mcp_stream.rs + let result = payload_string + .as_ref() + .map(|json_str| contains_request(json_str)) + .unwrap_or(Ok(false)); + let Ok(payload_contains_request) = result else { + return Ok((StatusCode::BAD_REQUEST, Json(SdkError::parse_error())).into_response()); + }; + + // readable stream of string to be used in transport + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + let transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + + let stream_id = if standalone { + DEFAULT_STREAM_ID.to_string() + } else { + state.id_generator.generate() + }; + let ping_interval = state.ping_interval; + let runtime_clone = Arc::clone(&runtime); + + //Start the server runtime + tokio::spawn(async move { + match runtime_clone + .start_stream(transport, &stream_id, ping_interval, payload_string) + .await + { + Ok(_) => tracing::trace!("stream {} exited gracefully.", &stream_id), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + } + let _ = runtime.remove_transport(&stream_id).await; + }); + + // Construct SSE stream + let reader = BufReader::new(write_rx); + + let message_stream = stream::unfold(reader, |mut reader| async move { + let mut line = String::new(); + + match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + Some((Ok(Event::default().data(trimmed_line)), reader)) + } + Err(e) => Some((Err(e), reader)), + } + }); + + let sse_stream = + Sse::new(message_stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(10))); + + // Return SSE response with keep-alive + // Create a Response and set headers + let mut response = sse_stream.into_response(); + response.headers_mut().insert( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(&session_id).unwrap(), + ); + + if !payload_contains_request { + *response.status_mut() = StatusCode::ACCEPTED; + } + Ok(response) +} + +// TODO: this function will be removed after refactoring the readable stream of the transports +// so we would deserialize the string syncronousely and have more control over the flow +// this function could potentially add a 20-250 ns overhead which could be avoided +fn contains_request(json_str: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(json_str)?; + match value { + serde_json::Value::Object(obj) => Ok(obj.contains_key("id") && obj.contains_key("method")), + serde_json::Value::Array(arr) => Ok(arr.iter().all(|item| { + item.as_object() + .map(|obj| obj.contains_key("id") && obj.contains_key("method")) + .unwrap_or(false) + })), + _ => Ok(false), + } +} + +pub async fn create_standalone_stream( + session_id: SessionId, + state: Arc, +) -> TransportServerResult> { + let runtime = state.session_store.get(&session_id).await.ok_or( + TransportServerError::SessionIdInvalid(session_id.to_string()), + )?; + let runtime = runtime.lock().await.to_owned(); + + if runtime.stream_id_exists(DEFAULT_STREAM_ID).await { + let error = + SdkError::bad_request().with_message("Only one SSE stream is allowed per session"); + return Ok((StatusCode::CONFLICT, Json(error)).into_response()); + } + + let mut response = create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + None, + true, + ) + .await?; + *response.status_mut() = StatusCode::OK; + Ok(response) +} + +pub async fn start_new_session( + state: Arc, + payload: &str, +) -> TransportServerResult> { + let session_id: SessionId = state.id_generator.generate(); + + let h: Arc = state.handler.clone(); + // create a new server instance with unique session_id and + let runtime: Arc = Arc::new(server_runtime::create_server_instance( + Arc::clone(&state.server_details), + h, + session_id.to_owned(), + )); + + tracing::info!( + "a new client joined : {}", + runtime.session_id().await.unwrap_or_default().to_owned() + ); + + let response = create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + ) + .await; + + if response.is_ok() { + state + .session_store + .set(session_id.to_owned(), runtime.clone()) + .await; + } + response +} + +async fn single_shot_stream( + runtime: Arc, + session_id: SessionId, + state: Arc, + payload: Option<&str>, + standalone: bool, +) -> TransportServerResult> { + // readable stream of string to be used in transport + let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + // writable stream to deliver message to the client + let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); + + let transport = SseTransport::::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .map_err(|err| TransportServerError::TransportError(err.to_string()))?; + + let stream_id = if standalone { + DEFAULT_STREAM_ID.to_string() + } else { + state.id_generator.generate() + }; + let ping_interval = state.ping_interval; + let runtime_clone = Arc::clone(&runtime); + + let payload_string = payload.map(|p| p.to_string()); + + tokio::spawn(async move { + match runtime_clone + .start_stream(transport, &stream_id, ping_interval, payload_string) + .await + { + Ok(_) => tracing::info!("stream {} exited gracefully.", &stream_id), + Err(err) => tracing::info!("stream {} exited with error : {}", &stream_id, err), + } + let _ = runtime.remove_transport(&stream_id).await; + }); + + // Construct SSE stream + let mut reader = BufReader::new(write_rx); + let mut line = String::new(); + let response = match reader.read_line(&mut line).await { + Ok(0) => None, // EOF + Ok(_) => { + let trimmed_line = line.trim_end_matches('\n').to_owned(); + Some(Ok(trimmed_line)) + } + Err(e) => Some(Err(e)), + }; + + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + headers.insert( + MCP_SESSION_ID_HEADER, + HeaderValue::from_str(&session_id).unwrap(), + ); + + match response { + Some(response_result) => match response_result { + Ok(response_str) => { + Ok((StatusCode::OK, headers, response_str.to_string()).into_response()) + } + Err(err) => Ok(( + StatusCode::INTERNAL_SERVER_ERROR, + headers, + Json(err.to_string()), + ) + .into_response()), + }, + None => Ok(( + StatusCode::UNPROCESSABLE_ENTITY, + headers, + Json("End of the transport stream reached."), + ) + .into_response()), + } +} + +pub async fn process_incoming_message_return( + session_id: SessionId, + state: Arc, + payload: &str, +) -> TransportServerResult { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + + single_shot_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + ) + .await + // Ok(StatusCode::OK.into_response()) + } + None => { + let error = SdkError::session_not_found(); + Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) + } + } +} + +pub async fn process_incoming_message( + session_id: SessionId, + state: Arc, + payload: &str, +) -> TransportServerResult { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + + create_sse_stream( + runtime.clone(), + session_id.clone(), + state.clone(), + Some(payload), + false, + ) + .await + } + None => { + let error = SdkError::session_not_found(); + Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) + } + } +} + +pub async fn delete_session( + session_id: SessionId, + state: Arc, +) -> TransportServerResult { + match state.session_store.get(&session_id).await { + Some(runtime) => { + let runtime = runtime.lock().await.to_owned(); + runtime.shutdown().await; + state.session_store.delete(&session_id).await; + tracing::info!("client disconnected : {}", &session_id); + Ok((StatusCode::OK, Json("ok")).into_response()) + } + None => { + let error = SdkError::session_not_found(); + Ok((StatusCode::NOT_FOUND, Json(error)).into_response()) + } + } +} + +pub fn acceptable_content_type(headers: &HeaderMap) -> bool { + let accept_header = headers + .get("content-type") + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + accept_header + .split(',') + .any(|val| val.trim().starts_with("application/json")) +} + +pub fn validate_mcp_protocol_version_header(headers: &HeaderMap) -> SdkResult<()> { + let protocol_version_header = headers + .get(MCP_PROTOCOL_VERSION_HEADER) + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + + // requests without protocol version header are acceptable + if protocol_version_header.is_empty() { + return Ok(()); + } + + validate_mcp_protocol_version(protocol_version_header) +} + +pub fn accepts_event_stream(headers: &HeaderMap) -> bool { + let accept_header = headers + .get("accept") + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + + accept_header + .split(',') + .any(|val| val.trim().starts_with("text/event-stream")) +} + +pub fn valid_streaming_http_accept_header(headers: &HeaderMap) -> bool { + let accept_header = headers + .get("accept") + .and_then(|val| val.to_str().ok()) + .unwrap_or(""); + + let types: Vec<_> = accept_header.split(',').map(|v| v.trim()).collect(); + + let has_event_stream = types.iter().any(|v| v.starts_with("text/event-stream")); + let has_json = types.iter().any(|v| v.starts_with("application/json")); + has_event_stream && has_json +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs index 55d15b1..44b671f 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs @@ -3,6 +3,7 @@ use crate::{ app_state::AppState, error::{TransportServerError, TransportServerResult}, }, + mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, utils::remove_query_and_hash, }; use axum::{ @@ -12,7 +13,6 @@ use axum::{ Router, }; use std::{collections::HashMap, sync::Arc}; -use tokio::io::AsyncWriteExt; pub fn routes(state: Arc) -> Router> { Router::new().route( @@ -30,6 +30,7 @@ pub async fn handle_messages( .get("sessionId") .ok_or(TransportServerError::SessionIdMissing)?; + // transmit to the readable stream, that transport is reading from let transmit = state .session_store @@ -38,17 +39,16 @@ pub async fn handle_messages( .ok_or(TransportServerError::SessionIdInvalid( session_id.to_string(), ))?; - let mut transmit = transmit.lock().await; - transmit - .write_all(format!("{message}\n").as_bytes()) - .await - .map_err(|err| TransportServerError::StreamIoError(err.to_string()))?; + let transmit = transmit.lock().await; transmit - .flush() + .consume_payload_string(DEFAULT_STREAM_ID, &message) .await - .map_err(|err| TransportServerError::StreamIoError(err.to_string()))?; + .map_err(|err| { + tracing::trace!("{}", err); + TransportServerError::StreamIoError(err.to_string()) + })?; - Ok(axum::http::StatusCode::OK) + Ok(axum::http::StatusCode::ACCEPTED) } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index 4e7cb3a..e1c00f8 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -1,12 +1,15 @@ +use crate::schema::schema_utils::ClientMessage; use crate::{ - error::McpSdkError, hyper_servers::{ - app_state::AppState, error::TransportServerResult, - middlewares::session_id_gen::generate_session_id, + app_state::AppState, + error::TransportServerResult, + middlewares::{ + protect_dns_rebinding::protect_dns_rebinding, session_id_gen::generate_session_id, + }, }, + mcp_runtimes::server_runtime::DEFAULT_STREAM_ID, mcp_server::{server_runtime, ServerRuntime}, mcp_traits::mcp_handler::McpServerHandler, - McpServer, }; use axum::{ extract::State, @@ -19,17 +22,11 @@ use axum::{ Extension, Router, }; use futures::stream::{self}; -use rust_mcp_schema::schema_utils::ClientMessage; -use rust_mcp_transport::{error::TransportError, SessionId, SseTransport}; +use rust_mcp_transport::{SessionId, SseTransport}; use std::{convert::Infallible, sync::Arc, time::Duration}; -use tokio::{ - io::{duplex, AsyncBufReadExt, BufReader}, - time::{self, Interval}, -}; +use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio_stream::StreamExt; -const CLIENT_PING_TIMEOUT: Duration = Duration::from_secs(2); - const DUPLEX_BUFFER_SIZE: usize = 8192; /// Creates an initial SSE event that returns the messages endpoint @@ -62,6 +59,10 @@ pub fn routes(state: Arc, sse_endpoint: &str) -> Router> state.clone(), generate_session_id, )) + .route_layer(middleware::from_fn_with_state( + state.clone(), + protect_dns_rebinding, + )) } /// Handles Server-Sent Events (SSE) connections @@ -82,76 +83,56 @@ pub async fn handle_sse( SseTransport::::message_endpoint(&state.sse_message_endpoint, &session_id); // readable stream of string to be used in transport + // writing string to read_tx will be received as messages inside the transport and messages will be processed let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE); + // writable stream to deliver message to the client let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE); - state - .session_store - .set(session_id.to_owned(), read_tx) - .await; - // create a transport for sending/receiving messages - let transport = - SseTransport::new(read_rx, write_tx, Arc::clone(&state.transport_options)).unwrap(); - let d: Arc = state.handler.clone(); + let transport = SseTransport::new( + read_rx, + write_tx, + read_tx, + Arc::clone(&state.transport_options), + ) + .unwrap(); + let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and let server: Arc = Arc::new(server_runtime::create_server_instance( Arc::clone(&state.server_details), - transport, - d, + h, session_id.to_owned(), )); - // Ping the server periodically to check if the SSE client is still connected - let server_ping = Arc::clone(&server); - tokio::spawn(async move { - let mut interval: Interval = time::interval(state.ping_interval); - loop { - interval.tick().await; // Wait for the next tick (10 seconds) - if !server_ping.is_initialized() { - continue; - } - match server_ping.ping(Some(CLIENT_PING_TIMEOUT)).await { - Ok(_) => {} - Err(McpSdkError::TransportError(TransportError::StdioError(error))) => { - if error.kind() == std::io::ErrorKind::BrokenPipe { - if let Some(session_id) = server_ping.session_id().await { - tracing::info!("Stopping {} server task ...", session_id); - state.session_store.delete(&session_id).await; - break; - } - } - } - _ => {} - } - } - }); + state + .session_store + .set(session_id.to_owned(), server.clone()) + .await; - tracing::info!( - "A new client joined : {}", - server.session_id().await.unwrap_or_default().to_owned() - ); + tracing::info!("A new client joined : {}", session_id.to_owned()); // Start the server tokio::spawn(async move { - match server.start().await { - Ok(_) => tracing::info!( - "server {} exited gracefully.", - server.session_id().await.unwrap_or_default().to_owned() - ), + match server + .start_stream(transport, DEFAULT_STREAM_ID, state.ping_interval, None) + .await + { + Ok(_) => tracing::info!("server {} exited gracefully.", session_id.to_owned()), Err(err) => tracing::info!( "server {} exited with error : {}", - server.session_id().await.unwrap_or_default().to_owned(), + session_id.to_owned(), err ), - } + }; + + state.session_store.delete(&session_id).await; }); // Initial SSE message to inform the client about the server's endpoint let initial_event = stream::once(async move { initial_event(&messages_endpoint) }); - // Construct SSE stream for sending MCP messages to the server + // Construct SSE stream let reader = BufReader::new(write_rx); let message_stream = stream::unfold(reader, |mut reader| async move { diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs new file mode 100644 index 0000000..83cc372 --- /dev/null +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/streamable_http_routes.rs @@ -0,0 +1,158 @@ +use super::hyper_utils::{start_new_session, MCP_SESSION_ID_HEADER}; +use crate::schema::schema_utils::SdkError; +use crate::{ + error::McpSdkError, + hyper_servers::{ + app_state::AppState, + error::TransportServerResult, + middlewares::protect_dns_rebinding::protect_dns_rebinding, + routes::hyper_utils::{ + acceptable_content_type, accepts_event_stream, create_standalone_stream, + delete_session, process_incoming_message, process_incoming_message_return, + valid_streaming_http_accept_header, validate_mcp_protocol_version_header, + }, + }, + utils::valid_initialize_method, +}; +use axum::{ + extract::{Query, State}, + middleware, + response::IntoResponse, + routing::{delete, post}, + Json, Router, +}; +use hyper::{HeaderMap, StatusCode}; +use rust_mcp_transport::SessionId; +use std::{collections::HashMap, sync::Arc}; + +use axum::routing::get; + +pub fn routes(state: Arc, streamable_http_endpoint: &str) -> Router> { + Router::new() + .route(streamable_http_endpoint, get(handle_streamable_http_get)) + .route(streamable_http_endpoint, post(handle_streamable_http_post)) + .route( + streamable_http_endpoint, + delete(handle_streamable_http_delete), + ) + .route_layer(middleware::from_fn_with_state( + state.clone(), + protect_dns_rebinding, + )) +} + +pub async fn handle_streamable_http_get( + headers: HeaderMap, + State(state): State>, +) -> TransportServerResult { + if !accepts_event_stream(&headers) { + let error = SdkError::bad_request().with_message(r#"Client must accept text/event-stream"#); + return Ok((StatusCode::NOT_ACCEPTABLE, Json(error)).into_response()); + } + + if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + let error = + SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + match session_id { + Some(session_id) => { + let res = create_standalone_stream(session_id, state).await?; + Ok(res.into_response()) + } + None => { + let error = SdkError::bad_request().with_message("Bad request: session not found"); + Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) + } + } +} + +pub async fn handle_streamable_http_post( + headers: HeaderMap, + State(state): State>, + Query(_params): Query>, + payload: String, +) -> TransportServerResult { + if !valid_streaming_http_accept_header(&headers) { + let error = SdkError::bad_request() + .with_message(r#"Client must accept both application/json and text/event-stream"#); + return Ok((StatusCode::NOT_ACCEPTABLE, Json(error)).into_response()); + } + + if !acceptable_content_type(&headers) { + let error = SdkError::bad_request() + .with_message(r#"Unsupported Media Type: Content-Type must be application/json"#); + return Ok((StatusCode::UNSUPPORTED_MEDIA_TYPE, Json(error)).into_response()); + } + + if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + let error = + SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + //TODO: validate reconnect after disconnect + + match session_id { + // has session-id => write to the existing stream + Some(id) => { + if state.enable_json_response { + let res = process_incoming_message_return(id, state, &payload).await?; + Ok(res.into_response()) + } else { + let res = process_incoming_message(id, state, &payload).await?; + Ok(res.into_response()) + } + } + None => match valid_initialize_method(&payload) { + Ok(_) => { + return start_new_session(state, &payload).await; + } + Err(McpSdkError::SdkError(error)) => { + Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) + } + Err(error) => { + let error = SdkError::bad_request().with_message(&error.to_string()); + Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) + } + }, + } +} + +pub async fn handle_streamable_http_delete( + headers: HeaderMap, + State(state): State>, +) -> TransportServerResult { + if let Err(parse_error) = validate_mcp_protocol_version_header(&headers) { + let error = + SdkError::bad_request().with_message(format!(r#"Bad Request: {parse_error}"#).as_str()); + return Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()); + } + + let session_id: Option = headers + .get(MCP_SESSION_ID_HEADER) + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + match session_id { + Some(id) => { + let res = delete_session(id, state).await; + Ok(res.into_response()) + } + None => { + let error = SdkError::bad_request().with_message("Bad Request: Session not found"); + Ok((StatusCode::BAD_REQUEST, Json(error)).into_response()) + } + } +} diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs index 1078a6d..a114012 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs @@ -1,4 +1,7 @@ -use crate::mcp_traits::mcp_handler::McpServerHandler; +use crate::{ + error::SdkResult, mcp_server::hyper_runtime::HyperRuntime, + mcp_traits::mcp_handler::McpServerHandler, +}; #[cfg(feature = "ssl")] use axum_server::tls_rustls::RustlsConfig; use axum_server::Handle; @@ -22,37 +25,74 @@ use rust_mcp_transport::TransportOptions; // Default client ping interval (12 seconds) const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12); -const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 30; +const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 5; // Default Server-Sent Events (SSE) endpoint path const DEFAULT_SSE_ENDPOINT: &str = "/sse"; // Default MCP Messages endpoint path const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages"; +// Default Streamable HTTP endpoint path +const DEFAULT_STREAMABLE_HTTP_ENDPOINT: &str = "/mcp"; /// Configuration struct for the Hyper server /// Used to configure the HyperServer instance. pub struct HyperServerOptions { - /// Hostname or IP address the server will bind to (default: "localhost") + /// Hostname or IP address the server will bind to (default: "127.0.0.1") pub host: String, + /// Hostname or IP address the server will bind to (default: "8080") pub port: u16, - /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) - pub custom_sse_endpoint: Option, - /// Optional custom path for the MCP messages endpoint (default: `/messages`) - pub custom_messages_endpoint: Option, + + /// Optional custom path for the Streamable HTTP endpoint (default: `/mcp`) + pub custom_streamable_http_endpoint: Option, + + /// This setting only applies to streamable HTTP. + /// If true, the server will return JSON responses instead of starting an SSE stream. + /// This can be useful for simple request/response scenarios without streaming. + /// Default is false (SSE streams are preferred). + pub enable_json_response: Option, + /// Interval between automatic ping messages sent to clients to detect disconnects pub ping_interval: Duration, + + /// Shared transport configuration used by the server + pub transport_options: Arc, + + /// Optional thread-safe session id generator to generate unique session IDs. + pub session_id_generator: Option>, + /// Enables SSL/TLS if set to `true` pub enable_ssl: bool, + /// Path to the SSL/TLS certificate file (e.g., "cert.pem"). /// Required if `enable_ssl` is `true`. pub ssl_cert_path: Option, + /// Path to the SSL/TLS private key file (e.g., "key.pem"). /// Required if `enable_ssl` is `true`. pub ssl_key_path: Option, - /// Shared transport configuration used by the server - pub transport_options: Arc, - /// Optional thread-safe session id generator to generate unique session IDs. - pub session_id_generator: Option>, + + /// If set to true, the SSE transport will also be supported for backward compatibility (default: true) + pub sse_support: bool, + + /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`) + /// Applicable only if sse_support is true + pub custom_sse_endpoint: Option, + + /// Optional custom path for the MCP messages endpoint for sse (default: `/messages`) + /// Applicable only if sse_support is true + pub custom_messages_endpoint: Option, + + /// List of allowed host header values for DNS rebinding protection. + /// If not specified, host validation is disabled. + pub allowed_hosts: Option>, + + /// List of allowed origin header values for DNS rebinding protection. + /// If not specified, origin validation is disabled. + pub allowed_origins: Option>, + + /// Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + /// Default is false for backwards compatibility. + pub dns_rebinding_protection: bool, } impl HyperServerOptions { @@ -94,7 +134,7 @@ impl HyperServerOptions { /// /// # Returns /// * `TransportServerResult` - The resolved server address or an error - async fn resolve_server_address(&self) -> TransportServerResult { + pub(crate) async fn resolve_server_address(&self) -> TransportServerResult { self.validate()?; let mut host = self.host.to_string(); @@ -123,6 +163,24 @@ impl HyperServerOptions { Ok(addr) } + pub fn base_url(&self) -> String { + format!( + "{}://{}:{}", + if self.enable_ssl { "https" } else { "http" }, + self.host, + self.port + ) + } + pub fn streamable_http_url(&self) -> String { + format!("{}{}", self.base_url(), self.streamable_http_endpoint()) + } + pub fn sse_url(&self) -> String { + format!("{}{}", self.base_url(), self.sse_endpoint()) + } + pub fn sse_message_url(&self) -> String { + format!("{}{}", self.base_url(), self.sse_messages_endpoint()) + } + pub fn sse_endpoint(&self) -> &str { self.custom_sse_endpoint .as_deref() @@ -134,18 +192,25 @@ impl HyperServerOptions { .as_deref() .unwrap_or(DEFAULT_MESSAGES_ENDPOINT) } + + pub fn streamable_http_endpoint(&self) -> &str { + self.custom_messages_endpoint + .as_deref() + .unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT) + } } /// Default implementation for HyperServerOptions /// -/// Provides default values for the server configuration, including localhost address, -/// port 8080, default SSE endpoint, and 12-second ping interval. +/// Provides default values for the server configuration, including 127.0.0.1 address, +/// port 8080, default Streamable HTTP endpoint, and 12-second ping interval. impl Default for HyperServerOptions { fn default() -> Self { Self { host: "127.0.0.1".to_string(), port: 8080, custom_sse_endpoint: None, + custom_streamable_http_endpoint: None, custom_messages_endpoint: None, ping_interval: DEFAULT_CLIENT_PING_INTERVAL, transport_options: Default::default(), @@ -153,6 +218,11 @@ impl Default for HyperServerOptions { ssl_cert_path: None, ssl_key_path: None, session_id_generator: None, + enable_json_response: None, + sse_support: true, + allowed_hosts: None, + allowed_origins: None, + dns_rebinding_protection: false, } } } @@ -161,7 +231,7 @@ impl Default for HyperServerOptions { pub struct HyperServer { app: Router, state: Arc, - options: HyperServerOptions, + pub(crate) options: HyperServerOptions, handle: Handle, } @@ -192,7 +262,12 @@ impl HyperServer { handler, ping_interval: server_options.ping_interval, sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(), + http_streamable_endpoint: server_options.streamable_http_endpoint().to_owned(), transport_options: Arc::clone(&server_options.transport_options), + enable_json_response: server_options.enable_json_response.unwrap_or(false), + allowed_hosts: server_options.allowed_hosts.take(), + allowed_origins: server_options.allowed_origins.take(), + dns_rebinding_protection: server_options.dns_rebinding_protection, }); let app = app_routes(Arc::clone(&state), &server_options); Self { @@ -246,17 +321,32 @@ impl HyperServer { "http" }; - let server_url = format!( - "{} is available at {}://{}{}", + let mut server_url = format!( + "\n• Streamable HTTP {} is available at {}://{}{}", server_type, protocol, addr, - self.options.sse_endpoint() + self.options.streamable_http_endpoint() ); + if self.options.sse_support { + let sse_url = format!( + "\n• SSE {} is available at {}://{}{}", + server_type, + protocol, + addr, + self.options.sse_endpoint() + ); + server_url.push_str(&sse_url); + }; + Ok(server_url) } + pub fn options(&self) -> &HyperServerOptions { + &self.options + } + // pub fn with_layer(mut self, layer: L) -> Self // where // // L: Layer + Clone + Send + Sync + 'static, @@ -274,7 +364,7 @@ impl HyperServer { /// # Returns /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise #[cfg(feature = "ssl")] - async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> { + pub(crate) async fn start_ssl(self, addr: SocketAddr) -> TransportServerResult<()> { let config = RustlsConfig::from_pem_file( self.options.ssl_cert_path.as_deref().unwrap_or_default(), self.options.ssl_key_path.as_deref().unwrap_or_default(), @@ -286,8 +376,9 @@ impl HyperServer { // Spawn a task to trigger shutdown on signal let handle_clone = self.handle.clone(); + let state_clone = self.state().clone(); tokio::spawn(async move { - shutdown_signal(handle_clone).await; + shutdown_signal(handle_clone, state_clone).await; }); let handle_clone = self.handle.clone(); @@ -310,13 +401,13 @@ impl HyperServer { /// /// # Returns /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise - async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> { + pub(crate) async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> { tracing::info!("{}", self.server_info(Some(addr)).await?); // Spawn a task to trigger shutdown on signal let handle_clone = self.handle.clone(); tokio::spawn(async move { - shutdown_signal(handle_clone).await; + shutdown_signal(handle_clone, self.state.clone()).await; }); let handle_clone = self.handle.clone(); @@ -333,28 +424,25 @@ impl HyperServer { /// Panics if SSL is requested but the "ssl" feature is not enabled. /// /// # Returns - /// * `TransportServerResult<()>` - Ok if the server starts successfully, Err otherwise - pub async fn start(self) -> TransportServerResult<()> { - let addr = self.options.resolve_server_address().await?; - - #[cfg(feature = "ssl")] - if self.options.enable_ssl { - self.start_ssl(addr).await - } else { - self.start_http(addr).await - } + /// * `SdkResult<()>` - Ok if the server starts successfully, Err otherwise + pub async fn start(self) -> SdkResult<()> { + let runtime = HyperRuntime::create(self).await?; + runtime.await_server().await + } - #[cfg(not(feature = "ssl"))] - if self.options.enable_ssl { - panic!("SSL requested but the 'ssl' feature is not enabled"); - } else { - self.start_http(addr).await - } + /// Similar to start() , but returns a HyperRuntime after server started + /// + /// HyperRuntime could be used to access sessions and send server initiated messages if needed + /// + /// # Returns + /// * `SdkResult` - Ok if the server starts successfully, Err otherwise + pub async fn start_runtime(self) -> SdkResult { + HyperRuntime::create(self).await } } // Shutdown signal handler -async fn shutdown_signal(handle: Handle) { +async fn shutdown_signal(handle: Handle, state: Arc) { // Wait for a Ctrl+C or SIGTERM signal let ctrl_c = async { signal::ctrl_c() @@ -379,6 +467,7 @@ async fn shutdown_signal(handle: Handle) { } tracing::info!("Signal received, starting graceful shutdown"); + state.session_store.clear().await; // Trigger graceful shutdown with a timeout handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS))); } diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs index b0716b8..95b2158 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs @@ -4,11 +4,13 @@ use std::sync::Arc; use async_trait::async_trait; pub use in_memory::*; use rust_mcp_transport::SessionId; -use tokio::{io::DuplexStream, sync::Mutex}; +use tokio::sync::Mutex; use uuid::Uuid; +use crate::mcp_server::ServerRuntime; + // Type alias for the server-side duplex stream used in sessions -pub type TxServer = DuplexStream; +pub type TxServer = Arc; /// Trait defining the interface for session storage operations /// @@ -40,7 +42,9 @@ pub trait SessionStore: Send + Sync { async fn keys(&self) -> Vec; - async fn values(&self) -> Vec>>; + async fn values(&self) -> Vec>>; + + async fn has(&self, session: &SessionId) -> bool; } /// Trait for generating session identifiers diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs index 342d232..80cba3f 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs @@ -3,7 +3,6 @@ use super::{SessionStore, TxServer}; use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; -use tokio::io::DuplexStream; use tokio::sync::Mutex; use tokio::sync::RwLock; @@ -59,8 +58,12 @@ impl SessionStore for InMemorySessionStore { let store = self.store.read().await; store.keys().cloned().collect::>() } - async fn values(&self) -> Vec>> { + async fn values(&self) -> Vec>> { let store = self.store.read().await; store.values().cloned().collect::>() } + async fn has(&self, session: &SessionId) -> bool { + let store = self.store.read().await; + store.contains_key(session) + } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs index 035756b..8d113c3 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime.rs @@ -1,14 +1,18 @@ pub mod mcp_client_runtime; pub mod mcp_client_runtime_core; -use crate::schema::schema_utils::{self, MessageFromClient, ServerMessage}; use crate::schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, FromMessage, MessageFromClient, ServerMessage, + ServerMessages, + }, InitializeRequest, InitializeRequestParams, InitializeResult, InitializedNotification, RpcError, ServerResult, }; use async_trait::async_trait; -use futures::future::join_all; +use futures::future::{join_all, try_join_all}; use futures::StreamExt; + use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; use std::sync::{Arc, RwLock}; use tokio::io::{AsyncBufReadExt, BufReader}; @@ -21,7 +25,15 @@ use crate::utils::ensure_server_protocole_compatibility; pub struct ClientRuntime { // The transport interface for handling messages between client and server - transport: Box>, + transport: Arc< + dyn Transport< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + >, // The handler for processing MCP messages handler: Box, // // Information about the server @@ -34,11 +46,17 @@ pub struct ClientRuntime { impl ClientRuntime { pub(crate) fn new( client_details: InitializeRequestParams, - transport: impl Transport, + transport: impl Transport< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, handler: Box, ) -> Self { Self { - transport: Box::new(transport), + transport: Arc::new(transport), handler, client_details, server_details: Arc::new(RwLock::new(None)), @@ -68,21 +86,79 @@ impl ClientRuntime { } Ok(()) } + + pub(crate) async fn handle_message( + &self, + message: ServerMessage, + transport: &Arc< + dyn Transport< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, + >, + ) -> SdkResult> { + let response = match message { + ServerMessage::Request(jsonrpc_request) => { + let result = self + .handler + .handle_request(jsonrpc_request.request, self) + .await; + + // create a response to send back to the server + let response: MessageFromClient = match result { + Ok(success_value) => success_value.into(), + Err(error_value) => MessageFromClient::Error(error_value), + }; + + let mcp_message = ClientMessage::from_message(response, Some(jsonrpc_request.id))?; + Some(mcp_message) + } + ServerMessage::Notification(jsonrpc_notification) => { + self.handler + .handle_notification(jsonrpc_notification.notification, self) + .await?; + None + } + ServerMessage::Error(jsonrpc_error) => { + self.handler.handle_error(jsonrpc_error.error, self).await?; + None + } + ServerMessage::Response(response) => { + if let Some(tx_response) = transport.pending_request_tx(&response.id).await { + tx_response + .send(ServerMessage::Response(response)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received response or error without a matching request: {:?}", + &response.id + ); + } + None + } + }; + Ok(response) + } } #[async_trait] impl McpClient for ClientRuntime { - async fn sender(&self) -> &tokio::sync::RwLock>> + fn sender(&self) -> Arc>>> where - MessageDispatcher: McpDispatch, + MessageDispatcher: + McpDispatch, { - (self.transport.message_sender()) as _ + (self.transport.message_sender().clone()) as _ } async fn start(self: Arc) -> SdkResult<()> { + //TODO: improve the flow let mut stream = self.transport.start().await?; - - let mut error_io_stream = self.transport.error_stream().write().await; + let transport = self.transport.clone(); + let mut error_io_stream = transport.error_stream().write().await; let error_io_stream = error_io_stream.take(); let self_clone = Arc::clone(&self); @@ -125,53 +201,58 @@ impl McpClient for ClientRuntime { Ok::<(), McpSdkError>(()) }); - // send initialize request to the MCP server - self_clone.initialize_request().await?; + let transport = self.transport.clone(); let main_task = tokio::spawn(async move { - let sender = self_clone.sender().await.read().await; + let sender = self_clone.sender(); + let sender = sender.read().await; let sender = sender .as_ref() .ok_or(schema_utils::SdkError::connection_closed())?; - while let Some(mcp_message) = stream.next().await { + while let Some(mcp_messages) = stream.next().await { let self_ref = &*self_clone; - match mcp_message { - ServerMessage::Request(jsonrpc_request) => { - let result = self_ref - .handler - .handle_request(jsonrpc_request.request, self_ref) - .await; - - // create a response to send back to the server - let response: MessageFromClient = match result { - Ok(success_value) => success_value.into(), - Err(error_value) => MessageFromClient::Error(error_value), - }; - // send the response back with corresponding request id - sender - .send(response, Some(jsonrpc_request.id), None) - .await?; - } - ServerMessage::Notification(jsonrpc_notification) => { - self_ref - .handler - .handle_notification(jsonrpc_notification.notification, self_ref) - .await?; + match mcp_messages { + ServerMessages::Single(server_message) => { + let result = self_ref.handle_message(server_message, &transport).await; + + match result { + Ok(result) => { + if let Some(result) = result { + sender + .send_message(ClientMessages::Single(result), None) + .await?; + } + } + Err(error) => { + tracing::error!("Error handling message : {}", error) + } + } } - ServerMessage::Error(jsonrpc_error) => { - self_ref - .handler - .handle_error(jsonrpc_error.error, self_ref) - .await?; + ServerMessages::Batch(server_messages) => { + let handling_tasks: Vec<_> = server_messages + .into_iter() + .map(|server_message| { + self_ref.handle_message(server_message, &transport) + }) + .collect(); + let results: Vec<_> = try_join_all(handling_tasks).await?; + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + sender + .send_message(ClientMessages::Batch(results), None) + .await?; + } } - // The response is the result of a request, it is processed at the transport level. - ServerMessage::Response(_) => {} } } Ok::<(), McpSdkError>(()) }); + // send initialize request to the MCP server + self.initialize_request().await?; + let mut lock = self.handlers.lock().await; lock.push(main_task); lock.push(err_task); diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs index f3a1c79..9ccd4d9 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use crate::schema::{ schema_utils::{ - MessageFromClient, NotificationFromServer, RequestFromServer, ResultFromClient, - ServerMessage, + ClientMessage, ClientMessages, MessageFromClient, NotificationFromServer, + RequestFromServer, ResultFromClient, ServerMessage, ServerMessages, }, InitializeRequestParams, RpcError, ServerNotification, ServerRequest, }; @@ -40,7 +40,13 @@ use super::ClientRuntime; /// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport, + transport: impl Transport< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, handler: impl ClientHandler, ) -> Arc { Arc::new(ClientRuntime::new( diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs index d8d2400..3bdc318 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/client_runtime/mcp_client_runtime_core.rs @@ -2,12 +2,13 @@ use std::sync::Arc; use crate::schema::{ schema_utils::{ - MessageFromClient, NotificationFromServer, RequestFromServer, ResultFromClient, - ServerMessage, + ClientMessage, ClientMessages, MessageFromClient, NotificationFromServer, + RequestFromServer, ResultFromClient, ServerMessage, ServerMessages, }, InitializeRequestParams, RpcError, }; use async_trait::async_trait; + use rust_mcp_transport::Transport; use crate::{ @@ -41,7 +42,13 @@ use super::ClientRuntime; /// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/simple-mcp-client-core) pub fn create_client( client_details: InitializeRequestParams, - transport: impl Transport, + transport: impl Transport< + ServerMessages, + MessageFromClient, + ServerMessage, + ClientMessages, + ClientMessage, + >, handler: impl ClientHandlerCore, ) -> Arc { Arc::new(ClientRuntime::new( diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 9c5a7f6..28cdd8c 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -1,25 +1,45 @@ pub mod mcp_server_runtime; pub mod mcp_server_runtime_core; +use crate::schema::{ + schema_utils::{ + ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage, + ServerMessages, + }, + InitializeRequestParams, InitializeResult, RequestId, RpcError, +}; -use crate::schema::schema_utils::{self, MessageFromServer}; -use crate::schema::{InitializeRequestParams, InitializeResult, RpcError}; use async_trait::async_trait; -use futures::StreamExt; -use rust_mcp_transport::{IoStream, McpDispatch, MessageDispatcher, Transport}; -use schema_utils::ClientMessage; +use futures::future::try_join_all; +use futures::{StreamExt, TryFutureExt}; + +use rust_mcp_transport::{IoStream, TransportDispatcher}; + +use std::collections::HashMap; use std::sync::{Arc, RwLock}; +use std::time::Duration; use tokio::io::AsyncWriteExt; +use tokio::sync::oneshot; use crate::error::SdkResult; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; #[cfg(feature = "hyper-server")] use rust_mcp_transport::SessionId; +pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; + +// Define a type alias for the TransportDispatcher trait object +type TransportType = Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, +>; /// Struct representing the runtime core of the MCP server, handling transport and client details pub struct ServerRuntime { - // The transport interface for handling messages between client and server - transport: Box>, // The handler for processing MCP messages handler: Arc, // Information about the server @@ -28,6 +48,7 @@ pub struct ServerRuntime { client_details: Arc>>, #[cfg(feature = "hyper-server")] session_id: Option, + transport_map: tokio::sync::RwLock>, } #[async_trait] @@ -46,6 +67,42 @@ impl McpServer for ServerRuntime { } } + async fn send( + &self, + message: MessageFromServer, + request_id: Option, + request_timeout: Option, + ) -> SdkResult> { + let transport_map = self.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + + let mcp_message = ServerMessage::from_message(message, request_id)?; + transport + .send_message(ServerMessages::Single(mcp_message), request_timeout) + .map_err(|err| err.into()) + .await + } + + async fn send_batch( + &self, + messages: Vec, + request_timeout: Option, + ) -> SdkResult>> { + let transport_map = self.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + + transport + .send_batch(messages, request_timeout) + .map_err(|err| err.into()) + .await + } + /// Returns the server's details, including server capability, /// instructions, protocol_version , server_info and optional meta data fn server_info(&self) -> &InitializeResult { @@ -62,69 +119,67 @@ impl McpServer for ServerRuntime { } } - async fn sender(&self) -> &tokio::sync::RwLock>> - where - MessageDispatcher: McpDispatch, - { - (self.transport.message_sender()) as _ - } - /// Main runtime loop, processes incoming messages and handles requests async fn start(&self) -> SdkResult<()> { - let mut stream = self.transport.start().await?; + let transport_map = self.transport_map.read().await; + + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; - let sender = self.transport.message_sender().read().await; - let sender = sender - .as_ref() - .ok_or(schema_utils::SdkError::connection_closed())?; + let mut stream = transport.start().await?; self.handler.on_server_started(self).await; // Process incoming messages from the client - while let Some(mcp_message) = stream.next().await { - match mcp_message { - // Handle a client request - ClientMessage::Request(client_jsonrpc_request) => { - let result = self - .handler - .handle_request(client_jsonrpc_request.request, self) - .await; - // create a response to send back to the client - let response: MessageFromServer = match result { - Ok(success_value) => success_value.into(), - Err(error_value) => { - // Error occurred during initialization. - // A likely cause could be an unsupported protocol version. - if !self.is_initialized() { - return Err(error_value.into()); + while let Some(mcp_messages) = stream.next().await { + match mcp_messages { + ClientMessages::Single(client_message) => { + let result = self.handle_message(client_message, transport).await; + + match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .await?; } - MessageFromServer::Error(error_value) } - }; - - // send the response back with corresponding request id - sender - .send(response, Some(client_jsonrpc_request.id), None) - .await?; - } - ClientMessage::Notification(client_jsonrpc_notification) => { - self.handler - .handle_notification(client_jsonrpc_notification.notification, self) - .await?; + Err(error) => { + tracing::error!("Error handling message : {}", error) + } + } } - ClientMessage::Error(jsonrpc_error) => { - self.handler.handle_error(jsonrpc_error.error, self).await?; + ClientMessages::Batch(client_messages) => { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self.handle_message(client_message, transport)) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + if !results.is_empty() { + transport + .send_message(ServerMessages::Batch(results), None) + .await?; + } } - // The response is the result of a request, it is processed at the transport level. - ClientMessage::Response(_) => {} } } - return Ok(()); } async fn stderr_message(&self, message: String) -> SdkResult<()> { - let mut lock = self.transport.error_stream().write().await; + let transport_map = self.transport_map.read().await; + let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( + RpcError::internal_error() + .with_message("transport stream does not exists or is closed!".to_string()), + )?; + let mut lock = transport.error_stream().write().await; + if let Some(IoStream::Writable(stderr)) = lock.as_mut() { stderr.write_all(message.as_bytes()).await?; stderr.write_all(b"\n").await?; @@ -135,6 +190,221 @@ impl McpServer for ServerRuntime { } impl ServerRuntime { + pub(crate) async fn consume_payload_string( + &self, + stream_id: &str, + payload: &str, + ) -> SdkResult<()> { + let transport_map = self.transport_map.read().await; + + let transport = transport_map.get(stream_id).ok_or( + RpcError::internal_error() + .with_message("stream id does not exists or is closed!".to_string()), + )?; + + transport.consume_string_payload(payload).await?; + + Ok(()) + } + + pub(crate) async fn handle_message( + &self, + message: ClientMessage, + transport: &Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, + >, + ) -> SdkResult> { + let response = match message { + // Handle a client request + ClientMessage::Request(client_jsonrpc_request) => { + let result = self + .handler + .handle_request(client_jsonrpc_request.request, self) + .await; + // create a response to send back to the client + let response: MessageFromServer = match result { + Ok(success_value) => success_value.into(), + Err(error_value) => { + // Error occurred during initialization. + // A likely cause could be an unsupported protocol version. + if !self.is_initialized() { + return Err(error_value.into()); + } + MessageFromServer::Error(error_value) + } + }; + + let mpc_message: ServerMessage = + ServerMessage::from_message(response, Some(client_jsonrpc_request.id))?; + + Some(mpc_message) + } + ClientMessage::Notification(client_jsonrpc_notification) => { + self.handler + .handle_notification(client_jsonrpc_notification.notification, self) + .await?; + None + } + ClientMessage::Error(jsonrpc_error) => { + self.handler.handle_error(jsonrpc_error.error, self).await?; + None + } + // The response is the result of a request, it is processed at the transport level. + ClientMessage::Response(response) => { + if let Some(tx_response) = transport.pending_request_tx(&response.id).await { + tx_response + .send(ClientMessage::Response(response)) + .map_err(|e| RpcError::internal_error().with_message(e.to_string()))?; + } else { + tracing::warn!( + "Received response or error without a matching request: {:?}", + &response.id + ); + } + None + } + }; + Ok(response) + } + + pub(crate) async fn store_transport( + &self, + stream_id: &str, + transport: Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, + >, + ) -> SdkResult<()> { + let mut transport_map = self.transport_map.write().await; + tracing::trace!("save transport for stream id : {}", stream_id); + transport_map.insert(stream_id.to_string(), transport); + Ok(()) + } + + pub(crate) async fn remove_transport(&self, stream_id: &str) -> SdkResult<()> { + let mut transport_map = self.transport_map.write().await; + tracing::trace!("removing transport for stream id : {}", stream_id); + transport_map.remove(stream_id); + Ok(()) + } + + pub(crate) async fn transport_by_stream( + &self, + stream_id: &str, + ) -> SdkResult< + Arc< + dyn TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, + >, + > { + let transport_map = self.transport_map.read().await; + transport_map.get(stream_id).cloned().ok_or_else(|| { + RpcError::internal_error() + .with_message(format!("Transport for key {stream_id} not found")) + .into() + }) + } + + pub(crate) async fn shutdown(&self) { + let mut transport_map = self.transport_map.write().await; + let items: Vec<_> = transport_map.drain().map(|(_, v)| v).collect(); + drop(transport_map); + for item in items { + let _ = item.shut_down().await; + } + } + + pub(crate) async fn stream_id_exists(&self, stream_id: &str) -> bool { + let transport_map = self.transport_map.read().await; + transport_map.contains_key(stream_id) + } + + pub(crate) async fn start_stream( + self: Arc, + transport: impl TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, + stream_id: &str, + ping_interval: Duration, + payload: Option, + ) -> SdkResult<()> { + let mut stream = transport.start().await?; + + self.store_transport(stream_id, Arc::new(transport)).await?; + + let transport = self.transport_by_stream(stream_id).await?; + + let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); + let _ = transport.keep_alive(ping_interval, disconnect_tx).await; + + // in case there is a payload, we consume it by transport to get processed + if let Some(payload) = payload { + transport.consume_string_payload(&payload).await?; + } + + loop { + tokio::select! { + Some(mcp_messages) = stream.next() =>{ + + match mcp_messages { + ClientMessages::Single(client_message) => { + let result = self.handle_message(client_message, &transport).await?; + if let Some(result) = result { + transport.send_message(ServerMessages::Single(result), None).await?; + } + } + ClientMessages::Batch(client_messages) => { + + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self.handle_message(client_message, &transport)) + .collect(); + + let results: Vec<_> = try_join_all(handling_tasks).await?; + + let results: Vec<_> = results.into_iter().flatten().collect(); + + + if !results.is_empty() { + transport.send_message(ServerMessages::Batch(results), None).await?; + } + } + } + // close the stream after all messages are sent, unless it is a standalone stream + if !stream_id.eq(DEFAULT_STREAM_ID){ + return Ok(()); + } + } + _ = &mut disconnect_rx => { + self.remove_transport(stream_id).await?; + // Disconnection detected by keep-alive task + return Err(SdkError::connection_closed().into()); + + } + } + } + } + #[cfg(feature = "hyper-server")] pub(crate) async fn session_id(&self) -> Option { self.session_id.to_owned() @@ -143,31 +413,38 @@ impl ServerRuntime { #[cfg(feature = "hyper-server")] pub(crate) fn new_instance( server_details: Arc, - transport: impl Transport, handler: Arc, session_id: SessionId, ) -> Self { Self { server_details, client_details: Arc::new(RwLock::new(None)), - transport: Box::new(transport), handler, session_id: Some(session_id), + transport_map: tokio::sync::RwLock::new(HashMap::new()), } } pub(crate) fn new( server_details: InitializeResult, - transport: impl Transport, + transport: impl TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, handler: Arc, ) -> Self { + let mut map: HashMap = HashMap::new(); + map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); Self { server_details: Arc::new(server_details), client_details: Arc::new(RwLock::new(None)), - transport: Box::new(transport), handler, #[cfg(feature = "hyper-server")] session_id: None, + transport_map: tokio::sync::RwLock::new(map), } } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index 5d9c433..26f37e1 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -2,13 +2,14 @@ use std::sync::Arc; use crate::schema::{ schema_utils::{ - self, CallToolError, ClientMessage, MessageFromServer, NotificationFromClient, - RequestFromClient, ResultFromServer, + self, CallToolError, ClientMessage, ClientMessages, MessageFromServer, + NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages, }, CallToolResult, ClientNotification, ClientRequest, InitializeResult, RpcError, }; use async_trait::async_trait; -use rust_mcp_transport::Transport; + +use rust_mcp_transport::TransportDispatcher; use super::ServerRuntime; #[cfg(feature = "hyper-server")] @@ -40,7 +41,13 @@ use crate::{ /// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server) pub fn create_server( server_details: InitializeResult, - transport: impl Transport, + transport: impl TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, handler: impl ServerHandler, ) -> ServerRuntime { ServerRuntime::new( @@ -53,11 +60,10 @@ pub fn create_server( #[cfg(feature = "hyper-server")] pub(crate) fn create_server_instance( server_details: Arc, - transport: impl Transport, handler: Arc, session_id: SessionId, ) -> ServerRuntime { - ServerRuntime::new_instance(server_details, transport, handler, session_id) + ServerRuntime::new_instance(server_details, handler, session_id) } pub(crate) struct ServerRuntimeInternalHandler { diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index b2e021c..27f04df 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -1,19 +1,19 @@ -use std::sync::Arc; - -use crate::schema::schema_utils::{ - self, ClientMessage, MessageFromServer, NotificationFromClient, RequestFromClient, - ResultFromServer, -}; -use crate::schema::{ClientRequest, InitializeResult, RpcError}; -use async_trait::async_trait; -use rust_mcp_transport::Transport; - +use super::ServerRuntime; use crate::error::SdkResult; use crate::mcp_handlers::mcp_server_handler_core::ServerHandlerCore; use crate::mcp_traits::mcp_handler::McpServerHandler; use crate::mcp_traits::mcp_server::McpServer; - -use super::ServerRuntime; +use crate::schema::schema_utils::{ + self, ClientMessage, MessageFromServer, NotificationFromClient, RequestFromClient, + ResultFromServer, ServerMessage, +}; +use crate::schema::{ + schema_utils::{ClientMessages, ServerMessages}, + ClientRequest, InitializeResult, RpcError, +}; +use async_trait::async_trait; +use rust_mcp_transport::TransportDispatcher; +use std::sync::Arc; /// Creates a new MCP server runtime with the specified configuration. /// @@ -35,7 +35,13 @@ use super::ServerRuntime; /// [Repository Example](https://github.com/rust-mcp-stack/rust-mcp-sdk/tree/main/examples/hello-world-mcp-server-core) pub fn create_server( server_details: InitializeResult, - transport: impl Transport, + transport: impl TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + >, handler: impl ServerHandlerCore, ) -> ServerRuntime { ServerRuntime::new( diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs index 49ea75f..8e72c26 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_client.rs @@ -2,8 +2,8 @@ use std::{sync::Arc, time::Duration}; use crate::schema::{ schema_utils::{ - self, McpMessage, MessageFromClient, NotificationFromClient, RequestFromClient, - ResultFromServer, ServerMessage, + self, ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromClient, + NotificationFromClient, RequestFromClient, ResultFromServer, ServerMessage, ServerMessages, }, CallToolRequest, CallToolRequestParams, CallToolResult, CompleteRequest, CompleteRequestParams, CreateMessageRequest, GetPromptRequest, GetPromptRequestParams, Implementation, @@ -15,11 +15,10 @@ use crate::schema::{ SetLevelRequest, SetLevelRequestParams, SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, }; +use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; use rust_mcp_transport::{McpDispatch, MessageDispatcher}; -use crate::{error::SdkResult, utils::format_assertion_message}; - #[async_trait] pub trait McpClient: Sync + Send { async fn start(self: Arc) -> SdkResult<()>; @@ -28,9 +27,10 @@ pub trait McpClient: Sync + Send { async fn shut_down(&self) -> SdkResult<()>; async fn is_shut_down(&self) -> bool; - async fn sender(&self) -> &tokio::sync::RwLock>> + fn sender(&self) -> Arc>>> where - MessageDispatcher: McpDispatch; + MessageDispatcher: + McpDispatch; fn client_info(&self) -> &InitializeRequestParams; fn server_info(&self) -> Option; @@ -175,14 +175,18 @@ pub trait McpClient: Sync + Send { request: RequestFromClient, timeout: Option, ) -> SdkResult { - let sender = self.sender().await.read().await; + let sender = self.sender(); + let sender = sender.read().await; let sender = sender .as_ref() .ok_or(schema_utils::SdkError::connection_closed())?; - // Send the request and receive the response. + let request_id = sender.next_request_id(); + + let mcp_message = + ClientMessage::from_message(MessageFromClient::from(request), Some(request_id))?; let response = sender - .send(MessageFromClient::RequestFromClient(request), None, timeout) + .send_message(ClientMessages::Single(mcp_message), timeout) .await?; let server_message = response.ok_or_else(|| { @@ -190,6 +194,8 @@ pub trait McpClient: Sync + Send { .with_message("An empty response was received from the server.".to_string()) })?; + let server_message = server_message.as_single()?; + if server_message.is_error() { return Err(server_message.as_error()?.error.into()); } @@ -197,20 +203,68 @@ pub trait McpClient: Sync + Send { return Ok(server_message.as_response()?.result); } + async fn send( + &self, + message: ClientMessage, + timeout: Option, + ) -> SdkResult> { + let sender = self.sender(); + let sender = sender.read().await; + let sender = sender + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + + let response = sender + .send_message(ClientMessages::Single(message), timeout) + .await?; + + match response { + Some(res) => { + let server_results = res.as_single()?; + Ok(Some(server_results)) + } + None => Ok(None), + } + } + + async fn send_batch( + &self, + messages: Vec, + timeout: Option, + ) -> SdkResult>> { + let sender = self.sender(); + let sender = sender.read().await; + let sender = sender + .as_ref() + .ok_or(schema_utils::SdkError::connection_closed())?; + + let response = sender + .send_message(ClientMessages::Batch(messages), timeout) + .await?; + + match response { + Some(res) => { + let server_results = res.as_batch()?; + Ok(Some(server_results)) + } + None => Ok(None), + } + } + /// Sends a notification. This is a one-way message that is not expected /// to return any response. The method asynchronously sends the notification using /// the transport layer and does not wait for any acknowledgement or result. async fn send_notification(&self, notification: NotificationFromClient) -> SdkResult<()> { - let sender = self.sender().await.read().await; + let sender = self.sender(); + let sender = sender.read().await; let sender = sender .as_ref() .ok_or(schema_utils::SdkError::connection_closed())?; + + let mcp_message = ClientMessage::from_message(MessageFromClient::from(notification), None)?; + sender - .send( - MessageFromClient::NotificationFromClient(notification), - None, - None, - ) + .send_message(ClientMessages::Single(mcp_message), None) .await?; Ok(()) } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 414bf00..cf0f168 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -2,21 +2,20 @@ use std::time::Duration; use crate::schema::{ schema_utils::{ - ClientMessage, McpMessage, MessageFromServer, NotificationFromServer, RequestFromServer, - ResultFromClient, + ClientMessage, ClientMessages, McpMessage, MessageFromServer, NotificationFromServer, + RequestFromServer, ResultFromClient, ServerMessage, }, CallToolRequest, CreateMessageRequest, CreateMessageRequestParams, CreateMessageResult, GetPromptRequest, Implementation, InitializeRequestParams, InitializeResult, ListPromptsRequest, ListResourceTemplatesRequest, ListResourcesRequest, ListRootsRequest, ListRootsRequestParams, ListRootsResult, ListToolsRequest, LoggingMessageNotification, LoggingMessageNotificationParams, PingRequest, PromptListChangedNotification, - PromptListChangedNotificationParams, ReadResourceRequest, ResourceListChangedNotification, - ResourceListChangedNotificationParams, ResourceUpdatedNotification, - ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, - ToolListChangedNotification, ToolListChangedNotificationParams, + PromptListChangedNotificationParams, ReadResourceRequest, RequestId, + ResourceListChangedNotification, ResourceListChangedNotificationParams, + ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, + SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; use async_trait::async_trait; -use rust_mcp_transport::{McpDispatch, MessageDispatcher}; use crate::{error::SdkResult, utils::format_assertion_message}; @@ -38,9 +37,18 @@ pub trait McpServer: Sync + Send { self.server_info() } - async fn sender(&self) -> &tokio::sync::RwLock>> - where - MessageDispatcher: McpDispatch; + async fn send( + &self, + message: MessageFromServer, + request_id: Option, + request_timeout: Option, + ) -> SdkResult>; + + async fn send_batch( + &self, + messages: Vec, + request_timeout: Option, + ) -> SdkResult>>; /// Checks whether the server has been initialized with client fn is_initialized(&self) -> bool { @@ -69,19 +77,18 @@ pub trait McpServer: Sync + Send { request: RequestFromServer, timeout: Option, ) -> SdkResult { - let sender = self.sender().await; - let sender = sender.read().await; - let sender = sender.as_ref().unwrap(); - // Send the request and receive the response. - let response = sender + let response = self .send(MessageFromServer::RequestFromServer(request), None, timeout) .await?; - let client_message = response.ok_or_else(|| { + + let client_messages = response.ok_or_else(|| { RpcError::internal_error() .with_message("An empty response was received from the client.".to_string()) })?; + let client_message = client_messages.as_single()?; + if client_message.is_error() { return Err(client_message.as_error()?.error.into()); } @@ -93,17 +100,12 @@ pub trait McpServer: Sync + Send { /// to return any response. The method asynchronously sends the notification using /// the transport layer and does not wait for any acknowledgement or result. async fn send_notification(&self, notification: NotificationFromServer) -> SdkResult<()> { - let sender = self.sender().await; - let sender = sender.read().await; - let sender = sender.as_ref().unwrap(); - - sender - .send( - MessageFromServer::NotificationFromServer(notification), - None, - None, - ) - .await?; + self.send( + MessageFromServer::NotificationFromServer(notification), + None, + None, + ) + .await?; Ok(()) } diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs index e5b7206..de92a06 100644 --- a/crates/rust-mcp-sdk/src/utils.rs +++ b/crates/rust-mcp-sdk/src/utils.rs @@ -1,6 +1,8 @@ -use std::cmp::Ordering; +use crate::schema::schema_utils::{ClientMessages, SdkError}; use crate::error::{McpSdkError, SdkResult}; +use crate::schema::ProtocolVersion; +use std::cmp::Ordering; /// Formats an assertion error message for unsupported capabilities. /// @@ -144,6 +146,11 @@ pub fn enforce_compatible_protocol_version( } } +pub fn validate_mcp_protocol_version(mcp_protocol_version: &str) -> SdkResult<()> { + let _mcp_protocol_version = ProtocolVersion::try_from(mcp_protocol_version)?; + Ok(()) +} + /// Removes query string and hash fragment from a URL, returning the base path. /// /// # Arguments @@ -170,6 +177,39 @@ pub(crate) fn remove_query_and_hash(endpoint: &str) -> String { } } +/// Checks if the input string is valid JSON and represents an "initialize" method request. +pub fn valid_initialize_method(json_str: &str) -> SdkResult<()> { + // Attempt to deserialize the input string into ClientMessages + let Ok(request) = serde_json::from_str::(json_str) else { + return Err(SdkError::bad_request() + .with_message("Bad Request: Session not found") + .into()); + }; + + match request { + ClientMessages::Single(client_message) => { + if !client_message.is_initialize_request() { + return Err(SdkError::bad_request() + .with_message("Bad Request: Session not found") + .into()); + } + } + ClientMessages::Batch(client_messages) => { + let count = client_messages + .iter() + .filter(|item| item.is_initialize_request()) + .count(); + if count > 1 { + return Err(SdkError::invalid_request() + .with_message("Bad Request: Only one initialization request is allowed") + .into()); + } + } + }; + + Ok(()) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/rust-mcp-sdk/tests/check_imports.rs b/crates/rust-mcp-sdk/tests/check_imports.rs new file mode 100644 index 0000000..cda7d0c --- /dev/null +++ b/crates/rust-mcp-sdk/tests/check_imports.rs @@ -0,0 +1,88 @@ +#[cfg(test)] +mod tests { + use std::fs::File; + use std::io::{self, Read}; + use std::path::{Path, MAIN_SEPARATOR_STR}; + + // List of files to exclude from the check + const EXCLUDED_FILES: &[&str] = &["src/schema.rs"]; + + // Check all .rs files for incorrect `use rust_mcp_schema` imports + #[test] + fn check_no_rust_mcp_schema_imports() { + let mut errors = Vec::new(); + + // Walk through the src directory + for entry in walk_src_dir("src").expect("Failed to read src directory") { + let entry = entry.unwrap(); + let path = entry.path(); + + // only check files with .rs extension + if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("rs") { + let abs_path = path.to_string_lossy(); + let relative_path = path.strip_prefix("src").unwrap_or(&path); + let path_str = relative_path.to_string_lossy(); + + // Skip excluded files + if EXCLUDED_FILES + .iter() + .any(|&excluded| abs_path.replace(MAIN_SEPARATOR_STR, "/") == excluded) + { + continue; + } + + // Read the file content + match read_file(&path) { + Ok(content) => { + // Check for `use rust_mcp_schema` + if content.contains("use rust_mcp_schema") { + errors.push(format!( + "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", + abs_path + )); + } + } + Err(e) => { + errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + } + } + } + } + + // If there are any errors, fail the test with all error messages + if !errors.is_empty() { + panic!( + "Found {} incorrect imports:\n{}\n\n", + errors.len(), + errors.join("\n") + ); + } + } + + // Helper function to walk the src directory + fn walk_src_dir>( + path: P, + ) -> io::Result>> { + Ok(std::fs::read_dir(path)?.flat_map(|entry| { + let entry = entry.unwrap(); + let path = entry.path(); + if path.is_dir() { + // Recursively walk subdirectories + walk_src_dir(&path) + .into_iter() + .flatten() + .collect::>() + } else { + vec![Ok(entry)] + } + })) + } + + // Helper function to read file content + fn read_file(path: &Path) -> io::Result { + let mut file = File::open(path)?; + let mut content = String::new(); + file.read_to_string(&mut content)?; + Ok(content) + } +} diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs index 046e796..57a3ea8 100644 --- a/crates/rust-mcp-sdk/tests/common/common.rs +++ b/crates/rust-mcp-sdk/tests/common/common.rs @@ -1,8 +1,15 @@ mod test_server; use async_trait::async_trait; +use reqwest::{Client, Response, Url}; +use rust_mcp_macros::{mcp_tool, JsonSchema}; use rust_mcp_schema::ProtocolVersion; use rust_mcp_sdk::mcp_client::ClientHandler; + use rust_mcp_sdk::schema::{ClientCapabilities, Implementation, InitializeRequestParams}; +use std::collections::HashMap; +use std::process; +use std::time::{SystemTime, UNIX_EPOCH}; +use tokio_stream::StreamExt; pub use test_server::*; @@ -11,6 +18,164 @@ pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything #[cfg(unix)] pub const UVX_SERVER_GIT: &str = "mcp-server-git"; +#[mcp_tool( + name = "say_hello", + description = "Accepts a person's name and says a personalized \"Hello\" to that person", + title = "A tool that says hello!", + idempotent_hint = false, + destructive_hint = false, + open_world_hint = false, + read_only_hint = false, + meta = r#"{"version": "1.0"}"# +)] +#[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] +pub struct SayHelloTool { + /// The name of the person to greet with a "Hello". + name: String, +} + +pub async fn send_post_request( + base_url: &str, + message: &str, + session_id: Option<&str>, + post_headers: Option>, +) -> Result { + let client = Client::new(); + let url = Url::parse(base_url).expect("Invalid URL"); + + let mut headers = reqwest::header::HeaderMap::new(); + + let protocol_version = ProtocolVersion::V2025_06_18.to_string(); + let post_headers = post_headers.unwrap_or({ + let mut map: HashMap<&str, &str> = HashMap::new(); + map.insert("Content-Type", "application/json"); + map.insert("Accept", "application/json, text/event-stream"); + map.insert("mcp-protocol-version", protocol_version.as_str()); + map + }); + + if let Some(sid) = session_id { + headers.insert("mcp-session-id", sid.parse().unwrap()); + } + + for (key, value) in post_headers { + headers.insert( + reqwest::header::HeaderName::from_bytes(key.as_bytes()).unwrap(), + value.parse().unwrap(), + ); + } + + let body = message.to_string(); + + client.post(url).headers(headers).body(body).send().await +} + +pub async fn send_delete_request( + base_url: &str, + session_id: Option<&str>, + post_headers: Option>, +) -> Result { + let client = Client::new(); + let url = Url::parse(base_url).expect("Invalid URL"); + + let mut headers = reqwest::header::HeaderMap::new(); + + let protocol_version = ProtocolVersion::V2025_06_18.to_string(); + let post_headers = post_headers.unwrap_or({ + let mut map: HashMap<&str, &str> = HashMap::new(); + map.insert("Content-Type", "application/json"); + map.insert("Accept", "application/json, text/event-stream"); + map.insert("mcp-protocol-version", protocol_version.as_str()); + map + }); + + if let Some(sid) = session_id { + headers.insert("mcp-session-id", sid.parse().unwrap()); + } + + for (key, value) in post_headers { + headers.insert( + reqwest::header::HeaderName::from_bytes(key.as_bytes()).unwrap(), + value.parse().unwrap(), + ); + } + + client.delete(url).headers(headers).send().await +} + +pub async fn send_get_request( + base_url: &str, + extra_headers: Option>, +) -> Result { + let client = Client::new(); + let url = Url::parse(base_url).expect("Invalid URL"); + + let mut headers = reqwest::header::HeaderMap::new(); + + if let Some(extra) = extra_headers { + for (key, value) in extra { + headers.insert( + reqwest::header::HeaderName::from_bytes(key.as_bytes()).unwrap(), + value.parse().unwrap(), + ); + } + } + client.get(url).headers(headers).send().await +} + +use futures::stream::Stream; + +// stream: &mut impl Stream>, +pub async fn read_sse_event_from_stream( + stream: &mut (impl Stream> + Unpin), +) -> Option { + let mut buffer = String::new(); + + while let Some(item) = stream.next().await { + match item { + Ok(chunk) => { + let chunk_str = std::str::from_utf8(&chunk).unwrap(); + buffer.push_str(chunk_str); + + while let Some(pos) = buffer.find("\n\n") { + let data = { + // Scope to limit borrows + let (event_str, rest) = buffer.split_at(pos); + let mut data = None; + + // Process the event string + for line in event_str.lines() { + if line.starts_with("data:") { + data = Some(line.trim_start_matches("data:").trim().to_string()); + break; // Exit loop after finding data + } + } + + // Update buffer after processing + buffer = rest[2..].to_string(); // Skip "\n\n" + data + }; + + // Return if data was found + if let Some(data) = data { + return Some(data); + } + } + } + Err(_e) => { + // return Err(TransportServerError::HyperError(e)); + return None; + } + } + } + None +} + +pub async fn read_sse_event(response: Response) -> Option { + let mut stream = response.bytes_stream(); + read_sse_event_from_stream(&mut stream).await +} + pub fn test_client_info() -> InitializeRequestParams { InitializeRequestParams { capabilities: ClientCapabilities::default(), @@ -37,6 +202,72 @@ pub fn sse_data(sse_raw: &str) -> String { sse_raw.replace("data: ", "") } +// Simple Xorshift PRNG struct +struct Xorshift { + state: u64, +} + +impl Xorshift { + // Initialize with a seed based on system time and process ID + fn new() -> Self { + // Get nanoseconds since UNIX epoch + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("System time error") + .as_nanos() as u64; + // Get process ID for additional entropy + let pid = process::id() as u64; + // Combine nanos and pid with a simple mix + let seed = nanos ^ (pid << 32) ^ (nanos.rotate_left(17)); + Xorshift { state: seed | 1 } // Ensure non-zero seed + } + + // Generate the next random u64 using Xorshift + fn next_u64(&mut self) -> u64 { + let mut x = self.state; + x ^= x << 13; + x ^= x >> 7; + x ^= x << 17; + self.state = x; + x + } + + // Generate a random u16 within a range [min, max] + fn next_u16_range(&mut self, min: u16, max: u16) -> u16 { + assert!(max >= min, "max must be greater than or equal to min"); + let range = (max - min + 1) as u64; + min + (self.next_u64() % range) as u16 + } +} + +// Generate a random port number in the range [8081, 15000] +pub fn random_port() -> u16 { + const MIN_PORT: u16 = 8081; + const MAX_PORT: u16 = 15000; + + let mut rng = Xorshift::new(); + rng.next_u16_range(MIN_PORT, MAX_PORT) +} + +pub fn random_port_old() -> u16 { + let min: u16 = 8081; + let max: u16 = 15000; + let range = max - min + 1; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("systime error!"); + + // Combine seconds and nanoseconds for better entropy + let nanos = now.subsec_nanos() as u64; + let secs = now.as_secs(); + + // Simple hash-like mix + let mixed = (nanos ^ (secs << 16)) ^ (nanos.rotate_left(13)); + + min + ((mixed as u16) % range) +} + pub mod sample_tools { #[cfg(feature = "2025_06_18")] use rust_mcp_sdk::macros::{mcp_tool, JsonSchema}; @@ -56,7 +287,7 @@ pub mod sample_tools { #[derive(Debug, ::serde::Deserialize, ::serde::Serialize, JsonSchema)] pub struct SayHelloTool { /// The name of the person to greet with a "Hello". - name: String, + pub name: String, } impl SayHelloTool { diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index 0ea8d0a..aa8e2fb 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,10 +1,17 @@ #[cfg(feature = "hyper-server")] pub mod test_server_common { use async_trait::async_trait; + use rust_mcp_schema::schema_utils::CallToolError; + use rust_mcp_schema::{ + CallToolRequest, CallToolResult, ListToolsRequest, ListToolsResult, ProtocolVersion, + RpcError, + }; + use rust_mcp_sdk::mcp_server::hyper_runtime::HyperRuntime; use tokio_stream::StreamExt; use rust_mcp_sdk::schema::{ - Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, + ClientCapabilities, Implementation, InitializeRequest, InitializeRequestParams, + InitializeResult, ServerCapabilities, ServerCapabilitiesTools, }; use rust_mcp_sdk::{ mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, @@ -14,9 +21,32 @@ pub mod test_server_common { use std::time::Duration; use tokio::time::timeout; - pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; + use crate::common::sample_tools::SayHelloTool; + + pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + pub struct LaunchedServer { + pub hyper_runtime: HyperRuntime, + pub streamable_url: String, + pub sse_url: String, + pub sse_message_url: String, + } + + pub fn initialize_request() -> InitializeRequest { + InitializeRequest::new(InitializeRequestParams { + capabilities: ClientCapabilities { + ..Default::default() + }, + client_info: Implementation { + name: "test-server".to_string(), + title: None, + version: "0.1.0".to_string(), + }, + protocol_version: ProtocolVersion::V2025_06_18.to_string(), + }) + } + pub fn test_server_details() -> InitializeResult { InitializeResult { // server name and version @@ -33,7 +63,7 @@ pub mod test_server_common { }, meta: None, instructions: Some("server instructions...".to_string()), - protocol_version: "2025-03-26".to_string(), + protocol_version: ProtocolVersion::V2025_06_18.to_string(), } } @@ -46,12 +76,71 @@ pub mod test_server_common { .stderr_message("Server started successfully".into()) .await; } + + async fn handle_list_tools_request( + &self, + request: ListToolsRequest, + runtime: &dyn McpServer, + ) -> std::result::Result { + runtime.assert_server_request_capabilities(request.method())?; + + Ok(ListToolsResult { + meta: None, + next_cursor: None, + tools: vec![SayHelloTool::tool()], + }) + } + + async fn handle_call_tool_request( + &self, + request: CallToolRequest, + runtime: &dyn McpServer, + ) -> std::result::Result { + runtime + .assert_server_request_capabilities(request.method()) + .map_err(CallToolError::new)?; + if request.params.name != "say_hello" { + Ok( + CallToolError::unknown_tool(format!("Unknown tool: {}", request.params.name)) + .into(), + ) + } else { + let tool = SayHelloTool { + name: request.params.arguments.unwrap()["name"] + .as_str() + .unwrap() + .to_string(), + }; + + Ok(tool.call_tool().unwrap()) + } + } } pub fn create_test_server(options: HyperServerOptions) -> HyperServer { hyper_server::create_server(test_server_details(), TestServerHandler {}, options) } + pub async fn create_start_server(options: HyperServerOptions) -> LaunchedServer { + let streamable_url = options.streamable_http_url(); + let sse_url = options.sse_url(); + let sse_message_url = options.sse_message_url(); + + let server = + hyper_server::create_server(test_server_details(), TestServerHandler {}, options); + + let hyper_runtime = HyperRuntime::create(server).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(75)).await; + + LaunchedServer { + hyper_runtime, + streamable_url, + sse_url, + sse_message_url, + } + } + // Tests the session ID generator, ensuring it returns a sequence of predefined session IDs. pub struct TestIdGenerator { constant_ids: Vec, diff --git a/crates/rust-mcp-sdk/tests/test_client_runtime.rs b/crates/rust-mcp-sdk/tests/test_client_runtime.rs index c8b3b17..b9749cb 100644 --- a/crates/rust-mcp-sdk/tests/test_client_runtime.rs +++ b/crates/rust-mcp-sdk/tests/test_client_runtime.rs @@ -43,9 +43,7 @@ async fn tets_client_launch_uvx_server() { .unwrap(); let client = client_runtime::create_client(test_client_info(), transport, TestClientHandler {}); - client.clone().start().await.unwrap(); - let server_capabilities = client.server_capabilities().unwrap(); let server_info = client.server_info().unwrap(); diff --git a/crates/rust-mcp-sdk/tests/test_streamable_http.rs b/crates/rust-mcp-sdk/tests/test_streamable_http.rs new file mode 100644 index 0000000..08c85e8 --- /dev/null +++ b/crates/rust-mcp-sdk/tests/test_streamable_http.rs @@ -0,0 +1,1277 @@ +use std::{collections::HashMap, error::Error, sync::Arc, time::Duration, vec}; + +use hyper::StatusCode; +use rust_mcp_schema::{ + schema_utils::{ + ClientJsonrpcRequest, ClientMessage, ClientMessages, FromMessage, NotificationFromServer, + ResultFromServer, RpcMessage, SdkError, SdkErrorCodes, ServerJsonrpcNotification, + ServerJsonrpcResponse, ServerMessages, + }, + CallToolRequest, CallToolRequestParams, ListToolsRequest, LoggingLevel, + LoggingMessageNotificationParams, RequestId, RootsListChangedNotification, ServerNotification, + ServerResult, +}; +use rust_mcp_sdk::mcp_server::HyperServerOptions; +use serde_json::{json, Map, Value}; +use tokio_stream::StreamExt; + +use crate::common::{ + random_port, read_sse_event, read_sse_event_from_stream, send_delete_request, send_get_request, + send_post_request, + test_server_common::{ + create_start_server, initialize_request, LaunchedServer, TestIdGenerator, + }, +}; + +#[path = "common/common.rs"] +pub mod common; + +const ONE_MILLISECOND: Option = Some(Duration::from_millis(1)); + +async fn initialize_server( + enable_json_response: Option, +) -> Result<(LaunchedServer, String), Box> { + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + enable_json_response, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + tokio::time::sleep(Duration::from_millis(250)).await; + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, + None, + ) + .await + .expect("Request failed"); + + let session_id = response + .headers() + .get("mcp-session-id") + .unwrap() + .to_str() + .unwrap() + .to_owned(); + + Ok((server, session_id)) +} + +// should initialize server and generate session ID +#[tokio::test] +async fn should_initialize_server_and_generate_session_id() { + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + tokio::time::sleep(Duration::from_millis(250)).await; + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + assert!(response.headers().get("mcp-session-id").is_some()); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject batch initialize request +#[tokio::test] +async fn should_reject_batch_initialize_request() { + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + enable_json_response: None, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + tokio::time::sleep(Duration::from_millis(250)).await; + + let first_init_message = ClientJsonrpcRequest::new( + RequestId::String("first-init".to_string()), + initialize_request().into(), + ); + let second_init_message = ClientJsonrpcRequest::new( + RequestId::String("second-init".to_string()), + initialize_request().into(), + ); + + let messages = vec![ + ClientMessage::Request(first_init_message), + ClientMessage::Request(second_init_message), + ]; + let batch_message = ClientMessages::Batch(messages); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&batch_message).unwrap(), + None, + None, + ) + .await + .expect("Request failed"); + + let error_data: SdkError = response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::INVALID_REQUEST as i64); + assert!(error_data + .message + .contains("Only one initialization request is allowed")); +} + +// should handle post requests via sse response correctly +#[tokio::test] +async fn should_handle_post_requests_via_sse_response_correctly() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + let event = read_sse_event(response).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + + assert!(matches!(message.id, RequestId::Integer(1))); + + let ResultFromServer::ServerResult(ServerResult::ListToolsResult(result)) = message.result + else { + panic!("invalid ListToolsResult") + }; + + assert_eq!(result.tools.len(), 1); + + let tool = &result.tools[0]; + assert_eq!(tool.name, "say_hello"); + assert_eq!( + tool.description.as_ref().unwrap(), + r#"Accepts a person's name and says a personalized "Hello" to that person"# + ); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should call a tool and return the result +#[tokio::test] +async fn should_call_a_tool_and_return_the_result() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let mut map = Map::new(); + map.insert("name".to_string(), Value::String("Ali".to_string())); + + let json_rpc_message: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + CallToolRequest::new(CallToolRequestParams { + arguments: Some(map), + name: "say_hello".to_string(), + }) + .into(), + ); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + let event = read_sse_event(response).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + + assert!(matches!(message.id, RequestId::Integer(1))); + + let ResultFromServer::ServerResult(ServerResult::CallToolResult(result)) = message.result + else { + panic!("invalid CallToolResult") + }; + + assert_eq!(result.content.len(), 1); + assert_eq!( + result.content[0].as_text_content().unwrap().text, + "Hello, Ali!" + ); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject requests without a valid session ID +#[tokio::test] +async fn should_reject_requests_without_a_valid_session_id() { + let (server, _session_id) = initialize_server(None).await.unwrap(); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, // pass no session id + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let error_data: SdkError = response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject invalid session ID +#[tokio::test] +async fn should_reject_invalid_session_id() { + let (server, _session_id) = initialize_server(None).await.unwrap(); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some("invalid-session-id"), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let error_data: SdkError = response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::SESSION_NOT_FOUND as i64); // Typescript sdk uses -32001 code + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +async fn get_standalone_stream(streamable_url: &str, session_id: &str) -> reqwest::Response { + let mut headers = HashMap::new(); + headers.insert("Accept", "text/event-stream , application/json"); + headers.insert("mcp-session-id", session_id); + headers.insert("mcp-protocol-version", "2025-03-26"); + + let response = send_get_request(streamable_url, Some(headers)) + .await + .unwrap(); + response +} + +// should establish standalone SSE stream and receive server-initiated messages +#[tokio::test] +async fn should_establish_standalone_stream_and_receive_server_messages() { + let (server, session_id) = initialize_server(None).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id).await; + + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!( + response + .headers() + .get("mcp-session-id") + .unwrap() + .to_str() + .unwrap(), + session_id + ); + + assert_eq!( + response + .headers() + .get("content-type") + .unwrap() + .to_str() + .unwrap(), + "text/event-stream" + ); + + // Send a notification (server-initiated message) that should appear on SSE stream + server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("Test notification"), + level: rust_mcp_schema::LoggingLevel::Info, + logger: None, + }, + ) + .await + .unwrap(); + + let event = read_sse_event(response).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification.params.level, LoggingLevel::Info); + assert_eq!( + notification.params.data.as_str().unwrap(), + "Test notification" + ); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should not close GET SSE stream after sending multiple server notifications +#[tokio::test] +async fn should_not_close_get_sse_stream() { + let (server, session_id) = initialize_server(None).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id).await; + + assert_eq!(response.status(), StatusCode::OK); + + server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("First notification"), + level: rust_mcp_schema::LoggingLevel::Info, + logger: None, + }, + ) + .await + .unwrap(); + + let mut stream = response.bytes_stream(); + let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification.params.level, LoggingLevel::Info); + assert_eq!( + notification.params.data.as_str().unwrap(), + "First notification" + ); + + server + .hyper_runtime + .send_logging_message( + &session_id, + LoggingMessageNotificationParams { + data: json!("Second notification"), + level: rust_mcp_schema::LoggingLevel::Info, + logger: None, + }, + ) + .await + .unwrap(); + + let event = read_sse_event_from_stream(&mut stream).await.unwrap(); + let message: ServerJsonrpcNotification = serde_json::from_str(&event).unwrap(); + + let NotificationFromServer::ServerNotification(ServerNotification::LoggingMessageNotification( + notification_2, + )) = message.notification + else { + panic!("invalid message received!"); + }; + + assert_eq!(notification_2.params.level, LoggingLevel::Info); + assert_eq!( + notification_2.params.data.as_str().unwrap(), + "Second notification" + ); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +//should reject second SSE stream for the same session +#[tokio::test] +async fn should_reject_second_sse_stream_for_the_same_session() { + let (server, session_id) = initialize_server(None).await.unwrap(); + let response = get_standalone_stream(&server.streamable_url, &session_id).await; + assert_eq!(response.status(), StatusCode::OK); + + let second_response = get_standalone_stream(&server.streamable_url, &session_id).await; + assert_eq!(second_response.status(), StatusCode::CONFLICT); + + let error_data: SdkError = second_response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject GET requests without Accept: text/event-stream header +#[tokio::test] +async fn should_reject_get_requests() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Accept", "application/json"); + headers.insert("mcp-session-id", &session_id); + headers.insert("mcp-protocol-version", "2025-03-26"); + + let response = send_get_request(&server.streamable_url, Some(headers)) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); //406 + let error_data: SdkError = response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + assert!(error_data.message.contains("must accept text/event-stream")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject POST requests without proper Accept header +#[tokio::test] +async fn reject_post_requests_without_accept_header() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let mut headers = HashMap::new(); + headers.insert("Accept", "application/json"); + headers.insert("mcp-session-id", &session_id); + headers.insert("mcp-protocol-version", "2025-03-26"); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); //406 + + let error_data: SdkError = response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + assert!(error_data + .message + .contains("must accept both application/json and text/event-stream")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +//should reject unsupported Content-Type +#[tokio::test] +async fn should_reject_unsupported_content_type() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "text/plain"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("mcp-session-id", &session_id); + headers.insert("mcp-protocol-version", "2025-03-26"); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); //415 + + let error_data: SdkError = response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + assert!(error_data + .message + .contains("Content-Type must be application/json")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should handle JSON-RPC batch notification messages with 202 response +#[tokio::test] +async fn should_handle_batch_notification_messages_with_202_response() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let batch_notification = ClientMessages::Batch(vec![ + ClientMessage::from_message(RootsListChangedNotification::new(None), None).unwrap(), + ClientMessage::from_message(RootsListChangedNotification::new(None), None).unwrap(), + ]); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&batch_notification).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + assert_eq!(response.status(), StatusCode::ACCEPTED); +} + +// should properly handle invalid JSON data +#[tokio::test] +async fn should_properly_handle_invalid_json_data() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let response = send_post_request( + &server.streamable_url, + "This is not a valid JSON", + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let error_data: SdkError = response.json().await.unwrap(); + assert_eq!(error_data.code, SdkErrorCodes::PARSE_ERROR as i64); + assert!(error_data.message.contains("Parse Error")); +} + +// should send response messages to the connection that sent the request +#[tokio::test] +async fn should_send_response_messages_to_the_connection_that_sent_the_request() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let json_rpc_message_1: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let mut map = Map::new(); + map.insert("name".to_string(), Value::String("Ali".to_string())); + + let json_rpc_message_2: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::Integer(1), + CallToolRequest::new(CallToolRequestParams { + arguments: Some(map), + name: "say_hello".to_string(), + }) + .into(), + ); + + let response_1 = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message_1).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + let response_2 = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message_2).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response_1.status(), StatusCode::OK); + assert_eq!(response_2.status(), StatusCode::OK); + + let event = read_sse_event(response_2).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + + assert!(matches!(message.id, RequestId::Integer(1))); + + let ResultFromServer::ServerResult(ServerResult::CallToolResult(result)) = message.result + else { + panic!("invalid CallToolResult") + }; + + assert_eq!(result.content.len(), 1); + assert_eq!( + result.content[0].as_text_content().unwrap().text, + "Hello, Ali!" + ); + + let event = read_sse_event(response_1).await.unwrap(); + let message: ServerJsonrpcResponse = serde_json::from_str(&event).unwrap(); + + assert!(matches!(message.id, RequestId::Integer(1))); + + let ResultFromServer::ServerResult(ServerResult::ListToolsResult(result)) = message.result + else { + panic!("invalid ListToolsResult") + }; + + assert_eq!(result.tools.len(), 1); + + let tool = &result.tools[0]; + assert_eq!(tool.name, "say_hello"); + assert_eq!( + tool.description.as_ref().unwrap(), + r#"Accepts a person's name and says a personalized "Hello" to that person"# + ); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should properly handle DELETE requests and close session +#[tokio::test] +async fn should_properly_handle_delete_requests_and_close_session() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "text/plain"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("mcp-session-id", &session_id); + headers.insert("mcp-protocol-version", "2025-03-26"); + + let response = send_delete_request(&server.streamable_url, Some(&session_id), Some(headers)) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject DELETE requests with invalid session ID +#[tokio::test] +async fn should_reject_delete_requests_with_invalid_session_id() { + let (server, _session_id) = initialize_server(None).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "text/plain"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("mcp-session-id", "invalid-session-id"); + headers.insert("mcp-protocol-version", "2025-03-26"); + + let response = send_delete_request( + &server.streamable_url, + Some("invalid-session-id"), + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + let error_data: SdkError = response.json().await.unwrap(); + + assert_eq!(error_data.code, SdkErrorCodes::SESSION_NOT_FOUND as i64); // Typescript sdk uses -32001 code + assert!(error_data.message.contains("Session not found")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +/** + * protocol version header validation + */ + +// should accept requests without protocol version header +#[tokio::test] +async fn should_accept_requests_without_protocol_version_header() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject requests with unsupported protocol version +#[tokio::test] +async fn should_reject_requests_with_unsupported_protocol_version() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("mcp-protocol-version", "1999-15-21"); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let error_data: SdkError = response.json().await.unwrap(); + + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + assert!(error_data.message.contains("Unsupported protocol version")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should handle protocol version validation for get requests +#[tokio::test] +async fn should_handle_protocol_version_validation_for_get_requests() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("mcp-protocol-version", "1999-15-21"); + headers.insert("mcp-session-id", &session_id); + + let response = send_get_request(&server.streamable_url, Some(headers)) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let error_data: SdkError = response.json().await.unwrap(); + + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + assert!(error_data.message.contains("Unsupported protocol version")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should handle protocol version validation for DELETE requests +#[tokio::test] +async fn should_handle_protocol_version_validation_for_delete_requests() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("mcp-protocol-version", "1999-15-21"); + + let response = send_delete_request(&server.streamable_url, Some(&session_id), Some(headers)) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + let error_data: SdkError = response.json().await.unwrap(); + + assert_eq!(error_data.code, SdkErrorCodes::BAD_REQUEST as i64); // Typescript sdk uses -32000 code + assert!(error_data.message.contains("Unsupported protocol version")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +/** + * Test JSON Response Mode + */ + +// should return JSON response for a single request +#[tokio::test] +async fn should_return_json_response_for_a_single_request() { + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(1), ListToolsRequest::new(None).into()); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "application/json" + ); + assert!(response.headers().get("mcp-session-id").is_some()); + + let message = response.json::().await.unwrap(); + + let ResultFromServer::ServerResult(ServerResult::ListToolsResult(result)) = message.result + else { + panic!("invalid ListToolsResult") + }; + + assert_eq!(result.tools.len(), 1); + + let tool = &result.tools[0]; + assert_eq!(tool.name, "say_hello"); + assert_eq!( + tool.description.as_ref().unwrap(), + r#"Accepts a person's name and says a personalized "Hello" to that person"# + ); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should return JSON response for batch requests +#[tokio::test] +async fn should_return_json_response_for_a_batch_request() { + let (server, session_id) = initialize_server(Some(true)).await.unwrap(); + + let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::String("req_1".to_string()), + ListToolsRequest::new(None).into(), + ); + + let mut map = Map::new(); + map.insert("name".to_string(), Value::String("Ali".to_string())); + let json_rpc_message_3: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::String("req_2".to_string()), + CallToolRequest::new(CallToolRequestParams { + arguments: Some(map), + name: "say_hello".to_string(), + }) + .into(), + ); + + let batch_message = ClientMessages::Batch(vec![ + json_rpc_message_1.into(), + ClientMessage::from_message(RootsListChangedNotification::new(None), None).unwrap(), + json_rpc_message_3.into(), + ]); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&batch_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "application/json" + ); + assert!(response.headers().get("mcp-session-id").is_some()); + + let messages = response.json::().await.unwrap(); + + let ServerMessages::Batch(mut messages) = messages else { + panic!("Invalid message type"); + }; + + assert_eq!(messages.len(), 2); + + let mut results = messages.drain(0..); + let result_1 = results.next().unwrap(); + assert_eq!( + result_1.request_id().unwrap(), + RequestId::String("req_1".to_string()) + ); + let ResultFromServer::ServerResult(ServerResult::ListToolsResult(_)) = + result_1.as_response().unwrap().result + else { + panic!("Expected a ListToolsResult"); + }; + + let result_2 = results.next().unwrap(); + assert_eq!( + result_2.request_id().unwrap(), + RequestId::String("req_2".to_string()) + ); + let ResultFromServer::ServerResult(ServerResult::CallToolResult(_)) = + result_2.as_response().unwrap().result + else { + panic!("Expected a CallToolResult"); + }; + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should handle batch request messages with SSE stream for responses +#[tokio::test] +async fn should_handle_batch_request_messages_with_sse_stream_for_responses() { + let (server, session_id) = initialize_server(None).await.unwrap(); + + let json_rpc_message_1: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::String("req_1".to_string()), + ListToolsRequest::new(None).into(), + ); + + let mut map = Map::new(); + map.insert("name".to_string(), Value::String("Ali".to_string())); + let json_rpc_message_2: ClientJsonrpcRequest = ClientJsonrpcRequest::new( + RequestId::String("req_2".to_string()), + CallToolRequest::new(CallToolRequestParams { + arguments: Some(map), + name: "say_hello".to_string(), + }) + .into(), + ); + + let batch_message = + ClientMessages::Batch(vec![json_rpc_message_1.into(), json_rpc_message_2.into()]); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&batch_message).unwrap(), + Some(&session_id), + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + + let event = read_sse_event(response).await.unwrap(); + let message: ServerMessages = serde_json::from_str(&event).unwrap(); + + let ServerMessages::Batch(mut messages) = message else { + panic!("Invalid message type"); + }; + + assert_eq!(messages.len(), 2); + + let mut results = messages.drain(0..); + let result_1 = results.next().unwrap(); + assert_eq!( + result_1.request_id().unwrap(), + RequestId::String("req_1".to_string()) + ); + let ResultFromServer::ServerResult(ServerResult::ListToolsResult(_)) = + result_1.as_response().unwrap().result + else { + panic!("Expected a ListToolsResult"); + }; + + let result_2 = results.next().unwrap(); + assert_eq!( + result_2.request_id().unwrap(), + RequestId::String("req_2".to_string()) + ); + let ResultFromServer::ServerResult(ServerResult::CallToolResult(_)) = + result_2.as_response().unwrap().result + else { + panic!("Expected a CallToolResult"); + }; +} + +// Test DNS rebinding protection + +// should accept requests with allowed host headers +#[tokio::test] +async fn should_accept_requests_with_allowed_host_headers() { + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + + let server_options = HyperServerOptions { + port: 8090, + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + allowed_hosts: Some(vec!["127.0.0.1:8090".to_string()]), + dns_rebinding_protection: true, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + tokio::time::sleep(Duration::from_millis(250)).await; + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + assert!(response.headers().get("mcp-session-id").is_some()); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject requests with disallowed host headers +#[tokio::test] +async fn should_reject_requests_with_disallowed_host_headers() { + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + allowed_hosts: Some(vec!["example.com:3001".to_string()]), + dns_rebinding_protection: true, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + tokio::time::sleep(Duration::from_millis(250)).await; + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, + None, + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::FORBIDDEN); + let error_data: SdkError = response.json().await.unwrap(); + assert!(error_data.message.contains("Invalid Host header")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should reject GET requests with disallowed host headers +#[tokio::test] +async fn should_reject_get_requests_with_disallowed_host_headers() { + let server_options = HyperServerOptions { + port: random_port(), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + allowed_hosts: Some(vec!["example.com:3001".to_string()]), + dns_rebinding_protection: true, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + tokio::time::sleep(Duration::from_millis(250)).await; + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + + let response = send_get_request(&server.streamable_url, Some(headers)) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::FORBIDDEN); + let error_data: SdkError = response.json().await.unwrap(); + assert!(error_data.message.contains("Invalid Host header")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should accept requests with allowed origin headers +#[tokio::test] +async fn should_accept_requests_with_allowed_origin_headers() { + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + + let server_options = HyperServerOptions { + port: 3000, + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + allowed_origins: Some(vec![ + "http://localhost:3000".to_string(), + "https://example.com".to_string(), + ]), + dns_rebinding_protection: true, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + // Origin: "http://localhost:3000", + + tokio::time::sleep(Duration::from_millis(250)).await; + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("Origin", "http://localhost:3000"); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "text/event-stream" + ); + assert!(response.headers().get("mcp-session-id").is_some()); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +//should reject requests with disallowed origin headers +#[tokio::test] +async fn should_reject_requests_with_disallowed_origin_headers() { + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + + let server_options = HyperServerOptions { + port: 3000, + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + allowed_origins: Some(vec!["http://localhost:3000".to_string()]), + dns_rebinding_protection: true, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + tokio::time::sleep(Duration::from_millis(250)).await; + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("Origin", "http://evil.com"); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::FORBIDDEN); + + let error_data: SdkError = response.json().await.unwrap(); + + assert!(error_data.message.contains("Invalid Origin header")); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +// should skip all validations when enableDnsRebindingProtection is false +#[tokio::test] +async fn should_skip_all_validations_when_false() { + let json_rpc_message: ClientJsonrpcRequest = + ClientJsonrpcRequest::new(RequestId::Integer(0), initialize_request().into()); + + let server_options = HyperServerOptions { + port: 3030, + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + allowed_hosts: Some(vec!["localhost".to_string()]), + allowed_origins: Some(vec!["http://localhost:3030".to_string()]), + dns_rebinding_protection: false, + ..Default::default() + }; + + let server = create_start_server(server_options).await; + + tokio::time::sleep(Duration::from_millis(250)).await; + + let mut headers = HashMap::new(); + headers.insert("Content-Type", "application/json"); + headers.insert("Accept", "application/json, text/event-stream"); + headers.insert("Origin", "http://evil.com"); + headers.insert("Host", "evil.com"); + + let response = send_post_request( + &server.streamable_url, + &serde_json::to_string(&json_rpc_message).unwrap(), + None, + Some(headers), + ) + .await + .expect("Request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + server.hyper_runtime.graceful_shutdown(ONE_MILLISECOND); + server.hyper_runtime.await_server().await.unwrap() +} + +//TODO: +// should return 400 error for invalid JSON-RPC messages diff --git a/crates/rust-mcp-transport/README.md b/crates/rust-mcp-transport/README.md index 4aa476b..23b78bf 100644 --- a/crates/rust-mcp-transport/README.md +++ b/crates/rust-mcp-transport/README.md @@ -2,8 +2,6 @@ `rust-mcp-transport` is a part of the [rust-mcp-sdk](https://crates.io/crates/rust-mcp-sdk) ecosystem, offering transport implementations for the MCP (Model Context Protocol). It enables asynchronous data exchange and efficient MCP message handling between MCP Clients and Servers. -**⚠️WARNING**: Currently, only Standard Input/Output (stdio) transport is supported. Server-Sent Events (SSE) transport is under development and will be available soon. - ## Usage Example ### For MCP Server diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index fbc52dd..f201aa0 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -8,11 +8,12 @@ use crate::utils::{ use crate::{IoStream, McpDispatch, TransportOptions}; use async_trait::async_trait; use bytes::Bytes; -use futures::Stream; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::Client; +use tokio::sync::oneshot::Sender; +use tokio::task::JoinHandle; -use crate::schema::schema_utils::{McpMessage, RpcMessage}; +use crate::schema::schema_utils::McpMessage; use crate::schema::RequestId; use std::cmp::Ordering; use std::collections::HashMap; @@ -54,7 +55,7 @@ impl Default for ClientSseTransportOptions { /// Manages SSE connections, HTTP POST requests, and message streaming for client-server communication. pub struct ClientSseTransport where - R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { /// Optional cancellation token source for shutting down the transport shutdown_source: tokio::sync::RwLock>, @@ -76,13 +77,14 @@ where custom_headers: Option, sse_task: tokio::sync::RwLock>>, post_task: tokio::sync::RwLock>>, - message_sender: tokio::sync::RwLock>>, + message_sender: Arc>>>, error_stream: tokio::sync::RwLock>, + pending_requests: Arc>>>, } impl ClientSseTransport where - R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { /// Creates a new ClientSseTransport instance /// @@ -97,8 +99,15 @@ where pub fn new(server_url: &str, options: ClientSseTransportOptions) -> TransportResult { let client = Client::new(); - //TODO: error handling - let base_url = extract_origin(server_url).unwrap(); + let base_url = match extract_origin(server_url) { + Some(url) => url, + None => { + let error_message = + format!("Failed to extract origin from server URL: {server_url}"); + tracing::error!(error_message); + return Err(TransportError::InvalidOptions(error_message)); + } + }; let headers = match &options.custom_headers { Some(h) => Some(Self::validate_headers(h)?), @@ -119,8 +128,9 @@ where custom_headers: headers, sse_task: tokio::sync::RwLock::new(None), post_task: tokio::sync::RwLock::new(None), - message_sender: tokio::sync::RwLock::new(None), + message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), }) } @@ -187,10 +197,13 @@ where } #[async_trait] -impl Transport for ClientSseTransport +impl Transport for ClientSseTransport where - R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { /// Starts the transport, initializing SSE and POST tasks /// @@ -199,18 +212,15 @@ where /// # Returns /// * `TransportResult<(Pin + Send>>, MessageDispatcher, IoStream)>` /// - The message stream, dispatcher, and error stream - async fn start(&self) -> TransportResult + Send>>> + async fn start(&self) -> TransportResult> where - MessageDispatcher: McpDispatch, + MessageDispatcher: McpDispatch, { // Create CancellationTokenSource and token let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); let mut lock = self.shutdown_source.write().await; *lock = Some(cancellation_source); - let pending_requests: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - let (write_tx, mut write_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); let (read_tx, read_rx) = mpsc::channel::(DEFAULT_CHANNEL_CAPACITY); @@ -302,7 +312,7 @@ where readable, writable, IoStream::Writable(Box::pin(tokio::io::stderr())), - pending_requests, + self.pending_requests.clone(), self.request_timeout, cancellation_token, ); @@ -316,14 +326,31 @@ where Ok(stream) } - fn message_sender(&self) -> &tokio::sync::RwLock>> { - &self.message_sender as _ + fn message_sender(&self) -> Arc>>> { + self.message_sender.clone() as _ } fn error_stream(&self) -> &tokio::sync::RwLock> { &self.error_stream as _ } + async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { + Err(TransportError::FromString( + "Invalid invocation of consume_string_payload() function for ClientSseTransport" + .to_string(), + )) + } + + async fn keep_alive( + &self, + _: Duration, + _: oneshot::Sender<()>, + ) -> TransportResult> { + Err(TransportError::FromString( + "Invalid invocation of keep_alive() function for ClientSseTransport".to_string(), + )) + } + /// Checks if the transport has been shut down /// /// # Returns @@ -380,4 +407,9 @@ where } } } + + async fn pending_request_tx(&self, request_id: &RequestId) -> Option> { + let mut pending_requests = self.pending_requests.lock().await; + pending_requests.remove(request_id) + } } diff --git a/crates/rust-mcp-transport/src/error.rs b/crates/rust-mcp-transport/src/error.rs index 462c831..8f8b62f 100644 --- a/crates/rust-mcp-transport/src/error.rs +++ b/crates/rust-mcp-transport/src/error.rs @@ -4,7 +4,7 @@ use thiserror::Error; use crate::utils::CancellationError; use core::fmt; use std::any::Any; -use tokio::sync::broadcast; +use tokio::sync::{broadcast, mpsc}; /// A wrapper around a broadcast send error. This structure allows for generic error handling /// by boxing the underlying error into a type-erased form. @@ -15,7 +15,7 @@ pub struct GenericSendError { #[allow(unused)] impl GenericSendError { - pub fn new(error: broadcast::error::SendError) -> Self { + pub fn new(error: mpsc::error::SendError) -> Self { Self { inner: Box::new(error), } diff --git a/crates/rust-mcp-transport/src/mcp_stream.rs b/crates/rust-mcp-transport/src/mcp_stream.rs index e32de05..2d2a377 100644 --- a/crates/rust-mcp-transport/src/mcp_stream.rs +++ b/crates/rust-mcp-transport/src/mcp_stream.rs @@ -1,12 +1,10 @@ -use crate::schema::schema_utils::RpcMessage; -use crate::schema::{RequestId, RpcError}; +use crate::schema::RequestId; use crate::{ error::{GenericSendError, TransportError}, message_dispatcher::MessageDispatcher, utils::CancellationToken, IoStream, }; -use futures::Stream; use std::{ collections::HashMap, pin::Pin, @@ -16,7 +14,7 @@ use std::{ use tokio::task::JoinHandle; use tokio::{ io::{AsyncBufReadExt, BufReader}, - sync::{broadcast::Sender, oneshot, Mutex}, + sync::Mutex, }; const CHANNEL_CAPACITY: usize = 36; @@ -32,7 +30,7 @@ impl MCPStream { /// - A `Pin + Send>>`: A stream that yields items of type `R`. /// - A `MessageDispatcher`: A sender that can be used to send messages of type `R`. /// - An `IoStream`: An error handling stream for managing error I/O (stderr). - pub fn create( + pub fn create( readable: Pin>, writable: Mutex>>, error_io: IoStream, @@ -40,29 +38,24 @@ impl MCPStream { request_timeout: Duration, cancellation_token: CancellationToken, ) -> ( - Pin + Send>>, + tokio_stream::wrappers::ReceiverStream, MessageDispatcher, IoStream, ) where - R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + X: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { - let (tx, rx) = tokio::sync::broadcast::channel::(CHANNEL_CAPACITY); + let (tx, rx) = tokio::sync::mpsc::channel::(CHANNEL_CAPACITY); + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); // Clone cancellation_token for reader let reader_token = cancellation_token.clone(); #[allow(clippy::let_underscore_future)] - let _ = Self::spawn_reader(readable, tx, pending_requests.clone(), reader_token); + let _ = Self::spawn_reader(readable, tx, reader_token); - let stream = { - Box::pin(futures::stream::unfold(rx, |mut rx| async move { - match rx.recv().await { - Ok(msg) => Some((msg, rx)), - Err(_) => None, - } - })) - }; + // rpc message stream that receives incoming messages let sender = MessageDispatcher::new( pending_requests, @@ -78,14 +71,13 @@ impl MCPStream { /// The received data is deserialized into a JsonrpcMessage. If the deserialization is successful, /// the object is transmitted. If the object is a response or error corresponding to a pending request, /// the associated pending request will ber removed from pending_requests. - fn spawn_reader( + fn spawn_reader( readable: Pin>, - tx: Sender, - pending_requests: Arc>>>, + tx: tokio::sync::mpsc::Sender, cancellation_token: CancellationToken, ) -> JoinHandle> where - R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + X: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { tokio::spawn(async move { let mut lines_stream = BufReader::new(readable).lines(); @@ -100,38 +92,16 @@ impl MCPStream { line = lines_stream.next_line() =>{ match line { Ok(Some(line)) => { + // deserialize and send it to the stream - let message: R = match serde_json::from_str(&line){ + let message: X = match serde_json::from_str(&line){ Ok(mcp_message) => mcp_message, Err(_) => { // continue if malformed message is received continue; }, }; - - if message.is_response() || message.is_error() { - if let Some(request_id) = &message.request_id() { - let mut pending_requests = pending_requests.lock().await; - - if let Some(tx_response) = pending_requests.remove(request_id) { - tx_response.send(message).map_err(|_| { - crate::error::TransportError::JsonrpcError( - RpcError::internal_error(), - ) - })?; - } else if message.is_error() { - //An error that is unrelated to a request. - tx.send(message).map_err(GenericSendError::new)?; - } else { - tracing::warn!( - "Received response or error without a matching request: {:?}", - &message.is_response() - ); - } - } - } else { - tx.send(message).map_err(GenericSendError::new)?; - } + tx.send(message).await.map_err(GenericSendError::new)?; } Ok(None) => { // EOF reached, exit loop diff --git a/crates/rust-mcp-transport/src/message_dispatcher.rs b/crates/rust-mcp-transport/src/message_dispatcher.rs index ee1c050..22d0b58 100644 --- a/crates/rust-mcp-transport/src/message_dispatcher.rs +++ b/crates/rust-mcp-transport/src/message_dispatcher.rs @@ -1,9 +1,12 @@ -use crate::schema::schema_utils::{ - self, ClientMessage, FromMessage, McpMessage, MessageFromClient, MessageFromServer, - ServerMessage, +use crate::schema::{ + schema_utils::{ + self, ClientMessage, ClientMessages, McpMessage, RpcMessage, ServerMessage, ServerMessages, + }, + JsonrpcError, }; use crate::schema::{RequestId, RpcError}; use async_trait::async_trait; +use futures::future::join_all; use std::collections::HashMap; use std::pin::Pin; @@ -11,7 +14,7 @@ use std::sync::atomic::AtomicI64; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; -use tokio::sync::oneshot; +use tokio::sync::oneshot::{self}; use tokio::sync::Mutex; use crate::error::{TransportError, TransportResult}; @@ -68,7 +71,7 @@ impl MessageDispatcher { /// /// # Returns /// An `Option`: `Some` for requests or responses/errors, `None` for notifications. - fn request_id_for_message( + pub fn request_id_for_message( &self, message: &impl McpMessage, request_id: Option, @@ -77,10 +80,7 @@ impl MessageDispatcher { if message.is_request() { // request_id should be None for requests assert!(request_id.is_none()); - Some(RequestId::Integer( - self.message_id_counter - .fetch_add(1, std::sync::atomic::Ordering::Relaxed), - )) + Some(self.next_request_id()) } else if !message.is_notification() { // `request_id` must not be `None` for errors, notifications and responses assert!(request_id.is_some()); @@ -89,146 +89,300 @@ impl MessageDispatcher { None } } + pub fn next_request_id(&self) -> RequestId { + RequestId::Integer( + self.message_id_counter + .fetch_add(1, std::sync::atomic::Ordering::Relaxed), + ) + } + + async fn store_pending_request( + &self, + request_id: RequestId, + ) -> tokio::sync::oneshot::Receiver { + let (tx_response, rx_response) = oneshot::channel::(); + let mut pending_requests = self.pending_requests.lock().await; + // store request id in the hashmap while waiting for a matching response + pending_requests.insert(request_id.clone(), tx_response); + rx_response + } + + async fn store_pending_request_for_message( + &self, + message: &M, + ) -> Option> { + if message.is_request() { + if let Some(request_id) = message.request_id() { + Some(self.store_pending_request(request_id.clone()).await) + } else { + None + } + } else { + None + } + } } +// Client side dispatcher #[async_trait] -impl McpDispatch for MessageDispatcher { +impl McpDispatch + for MessageDispatcher +{ /// Sends a message from the client to the server and awaits a response if applicable. /// - /// Serializes the `MessageFromClient` to JSON, writes it to the transport, and waits for a - /// `ServerMessage` response if the message is a request. Notifications and responses return + /// Serializes the `ClientMessages` to JSON, writes it to the transport, and waits for a + /// `ServerMessages` response if the message is a request. Notifications and responses return /// `Ok(None)`. /// /// # Arguments - /// * `message` - The client message to send. - /// * `request_id` - An optional request ID (used for responses/errors, None for requests). + /// * `messages` - The client message to send, coulld be a single message or batch. /// /// # Returns - /// A `TransportResult` containing `Some(ServerMessage)` for requests with a response, + /// A `TransportResult` containing `Some(ServerMessages)` for requests with a response, /// or `None` for notifications/responses, or an error if the operation fails. /// /// # Errors /// Returns a `TransportError` if serialization, writing, or timeout occurs. - async fn send( + async fn send_message( &self, - message: MessageFromClient, - request_id: Option, + messages: ClientMessages, request_timeout: Option, - ) -> TransportResult> { - let mut writable_std = self.writable_std.lock().await; + ) -> TransportResult> { + match messages { + ClientMessages::Single(message) => { + let rx_response: Option> = + self.store_pending_request_for_message(&message).await; + + //serialize the message and write it to the writable_std + let message_payload = serde_json::to_string(&message).map_err(|_| { + crate::error::TransportError::JsonrpcError(RpcError::parse_error()) + })?; - // returns the request_id to be used to construct the message - // a new requestId will be returned for Requests and Notification - let outgoing_request_id = self.request_id_for_message(&message, request_id); - - let rx_response: Option> = { - // Store the sender in the pending requests map - if message.is_request() { - if let Some(request_id) = &outgoing_request_id { - let (tx_response, rx_response) = oneshot::channel::(); - let mut pending_requests = self.pending_requests.lock().await; - // store request id in the hashmap while waiting for a matching response - pending_requests.insert(request_id.clone(), tx_response); - Some(rx_response) + self.write_str(message_payload.as_str()).await?; + + if let Some(rx) = rx_response { + // Wait for the response with timeout + match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { + Ok(response) => Ok(Some(ServerMessages::Single(response))), + Err(error) => match error { + TransportError::OneshotRecvError(_) => { + Err(schema_utils::SdkError::connection_closed().into()) + } + _ => Err(error), + }, + } } else { - None + Ok(None) } - } else { - None } - }; + ClientMessages::Batch(client_messages) => { + let (request_ids, pending_tasks): (Vec<_>, Vec<_>) = client_messages + .iter() + .filter(|message| message.is_request()) + .map(|message| { + ( + message.request_id().unwrap(), // guaranteed to have request_id + self.store_pending_request_for_message(message), + ) + }) + .unzip(); - let mpc_message: ClientMessage = ClientMessage::from_message(message, outgoing_request_id)?; + // send the batch messages to the server + let message_payload = serde_json::to_string(&client_messages).map_err(|_| { + crate::error::TransportError::JsonrpcError(RpcError::parse_error()) + })?; + self.write_str(message_payload.as_str()).await?; - //serialize the message and write it to the writable_std - let message_str = serde_json::to_string(&mpc_message) - .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?; + // no request in the batch, no need to wait for the result + if pending_tasks.is_empty() { + return Ok(None); + } - writable_std.write_all(message_str.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; + let tasks = join_all(pending_tasks).await; - if let Some(rx) = rx_response { - // Wait for the response with timeout - match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { - Ok(response) => Ok(Some(response)), - Err(error) => match error { - TransportError::OneshotRecvError(_) => { - Err(schema_utils::SdkError::connection_closed().into()) - } - _ => Err(error), - }, + let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| { + rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout))) + }); + + let results: Vec<_> = join_all(timeout_wrapped_futures) + .await + .into_iter() + .zip(request_ids) + .map(|(res, request_id)| match res { + Ok(response) => response, + Err(error) => ServerMessage::Error(JsonrpcError::new( + RpcError::internal_error().with_message(error.to_string()), + request_id.to_owned(), + )), + }) + .collect(); + + Ok(Some(ServerMessages::Batch(results))) } - } else { - Ok(None) } } + + async fn send( + &self, + message: ClientMessage, + request_timeout: Option, + ) -> TransportResult> { + let response = self.send_message(message.into(), request_timeout).await?; + match response { + Some(r) => Ok(Some(r.as_single()?)), + None => Ok(None), + } + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let response = self.send_message(message.into(), request_timeout).await?; + match response { + Some(r) => Ok(Some(r.as_batch()?)), + None => Ok(None), + } + } + + /// Writes a string payload to the underlying asynchronous writable stream, + /// appending a newline character and flushing the stream afterward. + /// + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let mut writable_std = self.writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + Ok(()) + } } +// Server side dispatcher, Sends S and Returns R #[async_trait] -impl McpDispatch for MessageDispatcher { +impl McpDispatch + for MessageDispatcher +{ /// Sends a message from the server to the client and awaits a response if applicable. /// - /// Serializes the `MessageFromServer` to JSON, writes it to the transport, and waits for a - /// `ClientMessage` response if the message is a request. Notifications and responses return + /// Serializes the `ServerMessages` to JSON, writes it to the transport, and waits for a + /// `ClientMessages` response if the message is a request. Notifications and responses return /// `Ok(None)`. /// /// # Arguments - /// * `message` - The server message to send. - /// * `request_id` - An optional request ID (used for responses/errors, None for requests). + /// * `messages` - The client message to send, coulld be a single message or batch. /// /// # Returns - /// A `TransportResult` containing `Some(ClientMessage)` for requests with a response, + /// A `TransportResult` containing `Some(ClientMessages)` for requests with a response, /// or `None` for notifications/responses, or an error if the operation fails. /// /// # Errors /// Returns a `TransportError` if serialization, writing, or timeout occurs. - async fn send( + async fn send_message( &self, - message: MessageFromServer, - request_id: Option, + messages: ServerMessages, request_timeout: Option, - ) -> TransportResult> { - let mut writable_std = self.writable_std.lock().await; + ) -> TransportResult> { + match messages { + ServerMessages::Single(message) => { + let rx_response: Option> = + self.store_pending_request_for_message(&message).await; - // returns the request_id to be used to construct the message - // a new requestId will be returned for Requests and Notification - let outgoing_request_id = self.request_id_for_message(&message, request_id); - - let rx_response: Option> = { - // Store the sender in the pending requests map - if message.is_request() { - if let Some(request_id) = &outgoing_request_id { - let (tx_response, rx_response) = oneshot::channel::(); - let mut pending_requests = self.pending_requests.lock().await; - // store request id in the hashmap while waiting for a matching response - pending_requests.insert(request_id.clone(), tx_response); - Some(rx_response) + let message_payload = serde_json::to_string(&message).map_err(|_| { + crate::error::TransportError::JsonrpcError(RpcError::parse_error()) + })?; + + self.write_str(message_payload.as_str()).await?; + + if let Some(rx) = rx_response { + match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { + Ok(response) => Ok(Some(ClientMessages::Single(response))), + Err(error) => Err(error), + } } else { - None + Ok(None) } - } else { - None } - }; + ServerMessages::Batch(server_messages) => { + let (request_ids, pending_tasks): (Vec<_>, Vec<_>) = server_messages + .iter() + .filter(|message| message.is_request()) + .map(|message| { + ( + message.request_id().unwrap(), // guaranteed to have request_id + self.store_pending_request_for_message(message), + ) + }) + .unzip(); - let mpc_message: ServerMessage = ServerMessage::from_message(message, outgoing_request_id)?; + // send the batch messages to the client + let message_payload = serde_json::to_string(&server_messages).map_err(|_| { + crate::error::TransportError::JsonrpcError(RpcError::parse_error()) + })?; - //serialize the message and write it to the writable_std - let message_str = serde_json::to_string(&mpc_message) - .map_err(|_| crate::error::TransportError::JsonrpcError(RpcError::parse_error()))?; + self.write_str(message_payload.as_str()).await?; - writable_std.write_all(message_str.as_bytes()).await?; - writable_std.write_all(b"\n").await?; // new line - writable_std.flush().await?; + // no request in the batch, no need to wait for the result + if pending_tasks.is_empty() { + return Ok(None); + } + + let tasks = join_all(pending_tasks).await; + + let timeout_wrapped_futures = tasks.into_iter().filter_map(|rx| { + rx.map(|rx| await_timeout(rx, request_timeout.unwrap_or(self.request_timeout))) + }); - if let Some(rx) = rx_response { - match await_timeout(rx, request_timeout.unwrap_or(self.request_timeout)).await { - Ok(response) => Ok(Some(response)), - Err(error) => Err(error), + let results: Vec<_> = join_all(timeout_wrapped_futures) + .await + .into_iter() + .zip(request_ids) + .map(|(res, request_id)| match res { + Ok(response) => response, + Err(error) => ClientMessage::Error(JsonrpcError::new( + RpcError::internal_error().with_message(error.to_string()), + request_id.to_owned(), + )), + }) + .collect(); + + Ok(Some(ClientMessages::Batch(results))) } - } else { - Ok(None) } } + + async fn send( + &self, + message: ServerMessage, + request_timeout: Option, + ) -> TransportResult> { + let response = self.send_message(message.into(), request_timeout).await?; + match response { + Some(r) => Ok(Some(r.as_single()?)), + None => Ok(None), + } + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let response = self.send_message(message.into(), request_timeout).await?; + match response { + Some(r) => Ok(Some(r.as_batch()?)), + None => Ok(None), + } + } + + /// Writes a string payload to the underlying asynchronous writable stream, + /// appending a newline character and flushing the stream afterward. + /// + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let mut writable_std = self.writable_std.lock().await; + writable_std.write_all(payload.as_bytes()).await?; + writable_std.write_all(b"\n").await?; // new line + writable_std.flush().await?; + Ok(()) + } } diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs index 5dbd4cf..50dbb32 100644 --- a/crates/rust-mcp-transport/src/sse.rs +++ b/crates/rust-mcp-transport/src/sse.rs @@ -1,37 +1,44 @@ -use crate::schema::schema_utils::{McpMessage, RpcMessage}; +use crate::schema::schema_utils::{ + ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, +}; use crate::schema::RequestId; use async_trait::async_trait; -use futures::Stream; use serde::de::DeserializeOwned; use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; -use tokio::io::DuplexStream; -use tokio::sync::Mutex; +use std::time::Duration; +use tokio::io::{AsyncWriteExt, DuplexStream}; +use tokio::sync::oneshot::Sender; +use tokio::sync::{oneshot, Mutex}; +use tokio::task::JoinHandle; +use tokio::time::{self, Interval}; use crate::error::{TransportError, TransportResult}; use crate::mcp_stream::MCPStream; use crate::message_dispatcher::MessageDispatcher; use crate::transport::Transport; use crate::utils::{endpoint_with_session_id, CancellationTokenSource}; -use crate::{IoStream, McpDispatch, SessionId, TransportOptions}; +use crate::{IoStream, McpDispatch, SessionId, TransportDispatcher, TransportOptions}; pub struct SseTransport where - R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, + R: Clone + Send + Sync + DeserializeOwned + 'static, { shutdown_source: tokio::sync::RwLock>, is_shut_down: Mutex, read_write_streams: Mutex>, + receiver_tx: Mutex, // receiving string payload options: Arc, - message_sender: tokio::sync::RwLock>>, + message_sender: Arc>>>, error_stream: tokio::sync::RwLock>, + pending_requests: Arc>>>, } /// Server-Sent Events (SSE) transport implementation impl SseTransport where - R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, + R: Clone + Send + Sync + DeserializeOwned + 'static, { /// Creates a new SseTransport instance /// @@ -40,6 +47,7 @@ where /// # Arguments /// * `read_rx` - Duplex stream for receiving messages /// * `write_tx` - Duplex stream for sending messages + /// * `receiver_tx` - Duplex stream for receiving string payload /// * `options` - Shared transport configuration options /// /// # Returns @@ -47,6 +55,7 @@ where pub fn new( read_rx: DuplexStream, write_tx: DuplexStream, + receiver_tx: DuplexStream, options: Arc, ) -> TransportResult { Ok(Self { @@ -54,8 +63,10 @@ where options, shutdown_source: tokio::sync::RwLock::new(None), is_shut_down: Mutex::new(false), - message_sender: tokio::sync::RwLock::new(None), + receiver_tx: Mutex::new(receiver_tx), + message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), }) } @@ -78,10 +89,50 @@ where } #[async_trait] -impl Transport for SseTransport -where - R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, - S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, +impl McpDispatch + for SseTransport +{ + async fn send_message( + &self, + message: ServerMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ServerMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +#[async_trait] //RSMX +impl Transport + for SseTransport { /// Starts the transport, initializing streams and message dispatcher /// @@ -93,18 +144,16 @@ where /// /// # Errors /// * Returns `TransportError` if streams are already taken or not initialized - async fn start(&self) -> TransportResult + Send>>> + async fn start(&self) -> TransportResult> where - MessageDispatcher: McpDispatch, + MessageDispatcher: + McpDispatch, { // Create CancellationTokenSource and token let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); let mut lock = self.shutdown_source.write().await; *lock = Some(cancellation_source); - let pending_requests: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - let mut lock = self.read_write_streams.lock().await; let (read_rx, write_tx) = lock.take().ok_or_else(|| { TransportError::FromString( @@ -112,11 +161,11 @@ where ) })?; - let (stream, sender, error_stream) = MCPStream::create( + let (stream, sender, error_stream) = MCPStream::create::( Box::pin(read_rx), Mutex::new(Box::pin(write_tx)), IoStream::Writable(Box::pin(tokio::io::stderr())), - pending_requests, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); @@ -139,14 +188,23 @@ where *result } - fn message_sender(&self) -> &tokio::sync::RwLock>> { - &self.message_sender as _ + fn message_sender(&self) -> Arc>>> { + self.message_sender.clone() as _ } fn error_stream(&self) -> &tokio::sync::RwLock> { &self.error_stream as _ } + async fn consume_string_payload(&self, payload: &str) -> TransportResult<()> { + let mut transmit = self.receiver_tx.lock().await; + transmit + .write_all(format!("{payload}\n").as_bytes()) + .await?; + transmit.flush().await?; + Ok(()) + } + /// Shuts down the transport, terminating tasks and signaling closure /// /// Cancels any running tasks and clears the cancellation source. @@ -166,4 +224,49 @@ where *is_shut_down_lock = true; Ok(()) } + + async fn keep_alive( + &self, + interval: Duration, + disconnect_tx: oneshot::Sender<()>, + ) -> TransportResult> { + let sender = self.message_sender(); + + let handle = tokio::spawn(async move { + let mut interval: Interval = time::interval(interval); + interval.tick().await; // Skip the first immediate tick + loop { + interval.tick().await; + let sender = sender.read().await; + if let Some(sender) = sender.as_ref() { + match sender.write_str(":\n").await { + Ok(_) => {} + Err(TransportError::StdioError(error)) => { + if error.kind() == std::io::ErrorKind::BrokenPipe { + let _ = disconnect_tx.send(()); + break; + } + } + _ => {} + } + } + } + }); + Ok(handle) + } + async fn pending_request_tx(&self, request_id: &RequestId) -> Option> { + let mut pending_requests = self.pending_requests.lock().await; + pending_requests.remove(request_id) + } +} + +impl + TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + > for SseTransport +{ } diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 09a5d3b..06931d2 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -1,20 +1,24 @@ -use crate::schema::schema_utils::{McpMessage, RpcMessage}; +use crate::schema::schema_utils::{ + ClientMessage, ClientMessages, MessageFromServer, SdkError, ServerMessage, ServerMessages, +}; use crate::schema::RequestId; use async_trait::async_trait; -use futures::Stream; use serde::de::DeserializeOwned; use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; +use std::time::Duration; use tokio::process::Command; -use tokio::sync::Mutex; +use tokio::sync::oneshot::Sender; +use tokio::sync::{oneshot, Mutex}; +use tokio::task::JoinHandle; use crate::error::{TransportError, TransportResult}; use crate::mcp_stream::MCPStream; use crate::message_dispatcher::MessageDispatcher; use crate::transport::Transport; use crate::utils::CancellationTokenSource; -use crate::{IoStream, McpDispatch, TransportOptions}; +use crate::{IoStream, McpDispatch, TransportDispatcher, TransportOptions}; /// Implements a standard I/O transport for MCP communication. /// @@ -25,7 +29,7 @@ use crate::{IoStream, McpDispatch, TransportOptions}; /// operations, integrating with the MCP runtime ecosystem. pub struct StdioTransport where - R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, + R: Clone + Send + Sync + DeserializeOwned + 'static, { command: Option, args: Option>, @@ -33,13 +37,14 @@ where options: TransportOptions, shutdown_source: tokio::sync::RwLock>, is_shut_down: Mutex, - message_sender: tokio::sync::RwLock>>, + message_sender: Arc>>>, error_stream: tokio::sync::RwLock>, + pending_requests: Arc>>>, } impl StdioTransport where - R: RpcMessage + Clone + Send + Sync + DeserializeOwned + 'static, + R: Clone + Send + Sync + DeserializeOwned + 'static, { /// Creates a new `StdioTransport` instance for MCP Server. /// @@ -62,8 +67,9 @@ where options, shutdown_source: tokio::sync::RwLock::new(None), is_shut_down: Mutex::new(false), - message_sender: tokio::sync::RwLock::new(None), + message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), }) } @@ -95,8 +101,9 @@ where options, shutdown_source: tokio::sync::RwLock::new(None), is_shut_down: Mutex::new(false), - message_sender: tokio::sync::RwLock::new(None), + message_sender: Arc::new(tokio::sync::RwLock::new(None)), error_stream: tokio::sync::RwLock::new(None), + pending_requests: Arc::new(Mutex::new(HashMap::new())), }) } @@ -138,10 +145,13 @@ where } #[async_trait] -impl Transport for StdioTransport +impl Transport for StdioTransport where - R: RpcMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, - S: McpMessage + Clone + Send + Sync + serde::Serialize + 'static, + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + S: Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { /// Starts the transport, initializing streams and the message dispatcher. /// @@ -156,9 +166,9 @@ where /// /// # Errors /// Returns a `TransportError` if the subprocess fails to spawn or stdio streams cannot be accessed. - async fn start(&self) -> TransportResult + Send>>> + async fn start(&self) -> TransportResult> where - MessageDispatcher: McpDispatch, + MessageDispatcher: McpDispatch, { // Create CancellationTokenSource and token let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); @@ -200,14 +210,13 @@ where .take() .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?; - let pending_requests: Arc>>> = - Arc::new(Mutex::new(HashMap::new())); - let pending_requests_clone = Arc::clone(&pending_requests); + let pending_requests_clone1 = self.pending_requests.clone(); + let pending_requests_clone2 = self.pending_requests.clone(); tokio::spawn(async move { let _ = process.wait().await; // clean up pending requests to cancel waiting tasks - let mut pending_requests = pending_requests.lock().await; + let mut pending_requests = pending_requests_clone1.lock().await; pending_requests.clear(); }); @@ -215,7 +224,7 @@ where Box::pin(stdout), Mutex::new(Box::pin(stdin)), IoStream::Readable(Box::pin(stderr)), - pending_requests_clone, + pending_requests_clone2, self.options.timeout, cancellation_token, ); @@ -228,7 +237,7 @@ where Ok(stream) } else { - let pending_requests: Arc>>> = + let pending_requests: Arc>>> = Arc::new(Mutex::new(HashMap::new())); let (stream, sender, error_stream) = MCPStream::create( Box::pin(tokio::io::stdin()), @@ -248,20 +257,41 @@ where } } + async fn pending_request_tx(&self, request_id: &RequestId) -> Option> { + let mut pending_requests = self.pending_requests.lock().await; + pending_requests.remove(request_id) + } + /// Checks if the transport has been shut down. async fn is_shut_down(&self) -> bool { let result = self.is_shut_down.lock().await; *result } - fn message_sender(&self) -> &tokio::sync::RwLock>> { - &self.message_sender as _ + fn message_sender(&self) -> Arc>>> { + self.message_sender.clone() as _ } fn error_stream(&self) -> &tokio::sync::RwLock> { &self.error_stream as _ } + async fn consume_string_payload(&self, _payload: &str) -> TransportResult<()> { + Err(TransportError::FromString( + "Invalid invocation of consume_string_payload() function in StdioTransport".to_string(), + )) + } + + async fn keep_alive( + &self, + _interval: Duration, + _disconnect_tx: oneshot::Sender<()>, + ) -> TransportResult> { + Err(TransportError::FromString( + "Invalid invocation of keep_alive() function for StdioTransport".to_string(), + )) + } + // Shuts down the transport, terminating any subprocess and signaling closure. /// /// Sends a shutdown signal via the watch channel and kills the subprocess if present. @@ -285,3 +315,55 @@ where Ok(()) } } + +#[async_trait] +impl McpDispatch + for StdioTransport +{ + async fn send_message( + &self, + message: ServerMessages, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_message(message, request_timeout).await + } + + async fn send( + &self, + message: ServerMessage, + request_timeout: Option, + ) -> TransportResult> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send(message, request_timeout).await + } + + async fn send_batch( + &self, + message: Vec, + request_timeout: Option, + ) -> TransportResult>> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.send_batch(message, request_timeout).await + } + + async fn write_str(&self, payload: &str) -> TransportResult<()> { + let sender = self.message_sender.read().await; + let sender = sender.as_ref().ok_or(SdkError::connection_closed())?; + sender.write_str(payload).await + } +} + +impl + TransportDispatcher< + ClientMessages, + MessageFromServer, + ClientMessage, + ServerMessages, + ServerMessage, + > for StdioTransport +{ +} diff --git a/crates/rust-mcp-transport/src/transport.rs b/crates/rust-mcp-transport/src/transport.rs index f2b4a96..3d17ebd 100644 --- a/crates/rust-mcp-transport/src/transport.rs +++ b/crates/rust-mcp-transport/src/transport.rs @@ -1,9 +1,12 @@ -use std::{pin::Pin, time::Duration}; +use std::{pin::Pin, sync::Arc, time::Duration}; -use crate::schema::{schema_utils::McpMessage, RequestId}; +use crate::schema::RequestId; use async_trait::async_trait; -use futures::Stream; +use tokio::{ + sync::oneshot::{self, Sender}, + task::JoinHandle, +}; use crate::{error::TransportResult, message_dispatcher::MessageDispatcher}; @@ -40,81 +43,120 @@ impl Default for TransportOptions { } } -/// A trait for sending MCP messages. +/// A trait for dispatching MCP (Message Communication Protocol) messages. /// -///It is intended to be implemented by types that send messages in the MCP protocol, such as servers or clients. -/// -/// The `McpDispatch` trait requires two associated types: -/// - `R`: The type of the response, which must implement the `McpMessage` trait and be capable of deserialization. -/// - `S`: The type of the message to send, which must be serializable and cloneable. -/// -/// Both associated types `R` and `S` must be `Send`, `Sync`, and `'static` to ensure they can be used -/// safely in an asynchronous context and across threads. +/// This trait is designed to be implemented by components such as clients, servers, or transports +/// that send and receive messages in the MCP protocol. It defines the interface for transmitting messages, +/// optionally awaiting responses, writing raw payloads, and handling batch communication. /// /// # Associated Types /// -/// - `R`: The response type, which must implement the `McpMessage` trait, be `Clone`, `Send`, `Sync`, and -/// be deserializable (`DeserializeOwned`). -/// - `S`: The type of the message to send, which must be `Clone`, `Send`, `Sync`, and serializable (`Serialize`). -/// -/// # Methods -/// -/// ### `send` -/// -/// Sends a raw message represented by type `S` and optionally includes a `request_id`. -/// The method returns a `TransportResult>`, where: -/// - `Option`: The response, which can be `None` or contain the response of type `R`. -/// - `TransportResult`: Represents the result of the operation, which can include success or failure. -/// -/// # Arguments -/// - `message`: The message to send, of type `S`, which will be serialized before transmission. -/// - `request_id`: An optional `RequestId` to associate with this message. It can be used for tracking -/// or correlating the request with its response. -/// -/// # Example -/// -/// let sender: Box> = ...; -/// let result = sender.send(my_message, Some(request_id)).await; +/// - `R`: The response type expected from a message. This must implement deserialization and be safe +/// for concurrent use in async contexts. +/// - `S`: The type of the outgoing message sent directly to the wire. Must be serializable. +/// - `M`: The internal message type used for responses received from a remote peer. +/// - `OM`: The outgoing message type submitted to the dispatcher. This is the higher-level form of `S` +/// used by clients or services submitting requests. /// #[async_trait] -pub trait McpDispatch: Send + Sync + 'static +pub trait McpDispatch: Send + Sync + 'static where - R: McpMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, S: Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { /// Sends a raw message represented by type `S` and optionally includes a `request_id`. /// The `request_id` is used when sending a message in response to an MCP request. /// It should match the `request_id` of the original request. - async fn send( + async fn send_message( &self, message: S, - request_id: Option, request_timeout: Option, ) -> TransportResult>; + + async fn send(&self, message: OM, timeout: Option) -> TransportResult>; + async fn send_batch( + &self, + message: Vec, + timeout: Option, + ) -> TransportResult>>; + + /// Writes a string payload to the underlying asynchronous writable stream, + /// appending a newline character and flushing the stream afterward. + /// + async fn write_str(&self, payload: &str) -> TransportResult<()>; } -/// A trait representing the transport layer for MCP. +/// A trait representing the transport layer for the MCP (Message Communication Protocol). +/// +/// This trait abstracts the transport layer functionality required to send and receive messages +/// within an MCP-based system. It provides methods to initialize the transport, send and receive +/// messages, handle errors, manage pending requests, and implement keep-alive functionality. /// -/// This trait is designed for handling the transport of messages within an MCP protocol system. It -/// provides a method to start the transport process, which involves setting up a stream, a message sender, -/// and handling I/O operations. +/// # Associated Types /// -/// The `Transport` trait requires three associated types: -/// - `R`: The message type to send, which must implement the `McpMessage` trait. -/// - `S`: The message type to send. -/// - `M`: The type of message that we expect to receive as a response to the sent message. +/// - `R`: The type of message expected to be received from the transport layer. Must be deserializable. +/// - `S`: The type of message to be sent over the transport layer. Must be serializable. +/// - `M`: The internal message type used by the dispatcher. Typically this wraps or transforms `R`. +/// - `OR`: The outbound response type expected to be produced by the dispatcher when handling incoming messages. +/// - `OM`: The outbound message type that the dispatcher expects to send as a reply to received messages. /// #[async_trait] -pub trait Transport: Send + Sync + 'static +pub trait Transport: Send + Sync + 'static where - R: McpMessage + Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, S: Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, { - async fn start(&self) -> TransportResult + Send>>> + async fn start(&self) -> TransportResult> where - MessageDispatcher: McpDispatch; - fn message_sender(&self) -> &tokio::sync::RwLock>>; + MessageDispatcher: McpDispatch; + fn message_sender(&self) -> Arc>>>; fn error_stream(&self) -> &tokio::sync::RwLock>; async fn shut_down(&self) -> TransportResult<()>; async fn is_shut_down(&self) -> bool; + async fn consume_string_payload(&self, payload: &str) -> TransportResult<()>; + async fn pending_request_tx(&self, request_id: &RequestId) -> Option>; + async fn keep_alive( + &self, + interval: Duration, + disconnect_tx: oneshot::Sender<()>, + ) -> TransportResult>; +} + +/// A composite trait that combines both transport and dispatch capabilities for the MCP protocol. +/// +/// `TransportDispatcher` unifies the functionality of [`Transport`] and [`McpDispatch`], allowing implementors +/// to both manage the transport layer and handle message dispatch logic in a single abstraction. +/// +/// This trait applies to components responsible for the following operations: +/// - Handle low-level I/O (stream management, payload parsing, lifecycle control) +/// - Dispatch and route messages, potentially awaiting or sending responses +/// +/// # Supertraits +/// +/// - [`Transport`]: Provides the transport-level operations (starting, shutting down, +/// receiving messages, etc.). +/// - [`McpDispatch`]: Provides message-sending and dispatching capabilities. +/// +/// # Associated Types +/// +/// - `R`: The raw message type expected to be received. Must be deserializable. +/// - `S`: The message type sent over the transport (often serialized directly to wire). +/// - `M`: The internal message type used within the dispatcher. +/// - `OR`: The outbound response type returned from processing a received message. +/// - `OM`: The outbound message type submitted by clients or application code. +/// +pub trait TransportDispatcher: + Transport + McpDispatch +where + R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + S: Clone + Send + Sync + serde::Serialize + 'static, + M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, + OR: Clone + Send + Sync + serde::Serialize + 'static, + OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static, +{ } diff --git a/crates/rust-mcp-transport/src/utils/readable_channel.rs b/crates/rust-mcp-transport/src/utils/readable_channel.rs index d07ca63..547a5fe 100644 --- a/crates/rust-mcp-transport/src/utils/readable_channel.rs +++ b/crates/rust-mcp-transport/src/utils/readable_channel.rs @@ -101,6 +101,7 @@ mod tests { let mut buf2 = vec![0; 6]; reader.read_exact(&mut buf2).await.unwrap(); assert_eq!(&buf2, b" world"); + drop(tx); } #[tokio::test] diff --git a/crates/rust-mcp-transport/tests/check_imports.rs b/crates/rust-mcp-transport/tests/check_imports.rs new file mode 100644 index 0000000..cda7d0c --- /dev/null +++ b/crates/rust-mcp-transport/tests/check_imports.rs @@ -0,0 +1,88 @@ +#[cfg(test)] +mod tests { + use std::fs::File; + use std::io::{self, Read}; + use std::path::{Path, MAIN_SEPARATOR_STR}; + + // List of files to exclude from the check + const EXCLUDED_FILES: &[&str] = &["src/schema.rs"]; + + // Check all .rs files for incorrect `use rust_mcp_schema` imports + #[test] + fn check_no_rust_mcp_schema_imports() { + let mut errors = Vec::new(); + + // Walk through the src directory + for entry in walk_src_dir("src").expect("Failed to read src directory") { + let entry = entry.unwrap(); + let path = entry.path(); + + // only check files with .rs extension + if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("rs") { + let abs_path = path.to_string_lossy(); + let relative_path = path.strip_prefix("src").unwrap_or(&path); + let path_str = relative_path.to_string_lossy(); + + // Skip excluded files + if EXCLUDED_FILES + .iter() + .any(|&excluded| abs_path.replace(MAIN_SEPARATOR_STR, "/") == excluded) + { + continue; + } + + // Read the file content + match read_file(&path) { + Ok(content) => { + // Check for `use rust_mcp_schema` + if content.contains("use rust_mcp_schema") { + errors.push(format!( + "File {} contains `use rust_mcp_schema`. Use `use crate::schema` instead.", + abs_path + )); + } + } + Err(e) => { + errors.push(format!("Failed to read file `{}`: {}", path_str, e)); + } + } + } + } + + // If there are any errors, fail the test with all error messages + if !errors.is_empty() { + panic!( + "Found {} incorrect imports:\n{}\n\n", + errors.len(), + errors.join("\n") + ); + } + } + + // Helper function to walk the src directory + fn walk_src_dir>( + path: P, + ) -> io::Result>> { + Ok(std::fs::read_dir(path)?.flat_map(|entry| { + let entry = entry.unwrap(); + let path = entry.path(); + if path.is_dir() { + // Recursively walk subdirectories + walk_src_dir(&path) + .into_iter() + .flatten() + .collect::>() + } else { + vec![Ok(entry)] + } + })) + } + + // Helper function to read file content + fn read_file(path: &Path) -> io::Result { + let mut file = File::open(path)?; + let mut content = String::new(); + file.read_to_string(&mut content)?; + Ok(content) + } +} diff --git a/examples/hello-world-server-core-sse/README.md b/examples/hello-world-server-core-sse/README.md deleted file mode 100644 index b276ca5..0000000 --- a/examples/hello-world-server-core-sse/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# Hello World MCP Server (Core) - SSE Transport - -A basic MCP server implementation featuring two custom tools: `Say Hello` and `Say Goodbye` , utilizing [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) , using SSE transport - -## Overview - -This project showcases a fundamental MCP server implementation, highlighting the capabilities of -[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) with these features: - -- SSE transport -- Custom server handler -- Basic server capabilities - -## Running the Example - -1. Clone the repository: - -```bash -git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git -cd rust-mcp-sdk -``` - -2. Build and start the server: - -```bash -cargo run -p hello-world-server-sse --release -``` - -By default, the SSE endpoint is accessible at `http://127.0.0.1:8080/sse`. -You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. - -```bash -npx -y @modelcontextprotocol/inspector -``` - -Then visit: http://localhost:6274/?transport=sse&serverUrl=http://localhost:8080/sse - -Here you can see it in action : - -![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-sse.gif) diff --git a/examples/hello-world-server-core-sse/.gitignore b/examples/hello-world-server-core-streamable-http/.gitignore similarity index 100% rename from examples/hello-world-server-core-sse/.gitignore rename to examples/hello-world-server-core-streamable-http/.gitignore diff --git a/examples/hello-world-server-core-sse/Cargo.toml b/examples/hello-world-server-core-streamable-http/Cargo.toml similarity index 91% rename from examples/hello-world-server-core-sse/Cargo.toml rename to examples/hello-world-server-core-streamable-http/Cargo.toml index 5d1c43f..ac3500c 100644 --- a/examples/hello-world-server-core-sse/Cargo.toml +++ b/examples/hello-world-server-core-streamable-http/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "hello-world-server-core-sse" +name = "hello-world-server-core-streamable-http" version = "0.1.15" edition = "2021" publish = false diff --git a/examples/hello-world-server-core-streamable-http/README.md b/examples/hello-world-server-core-streamable-http/README.md new file mode 100644 index 0000000..cd37623 --- /dev/null +++ b/examples/hello-world-server-core-streamable-http/README.md @@ -0,0 +1,68 @@ +# Hello World MCP Server (Core) - Streamable Http + +A simple MCP server implementation with two custom tools `Say Hello` and `Say Goodbye` , utilizing [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). It uses Streamable HTTP as the primary transport, while also supporting SSE for backward compatibility. + +## Overview + +This project showcases a fundamental MCP server implementation, highlighting the capabilities of +[rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) with these features: + +- Streamable HTTP transport +- SSE transport (for backward compatibility) +- Custom server handler +- Basic server capabilities + +💡 By default, both **Streamable HTTP** and **SSE** transports are enabled for backward compatibility. +To disable the SSE transport, set the `sse_support` value in the `HyperServerOptions` accordingly: + +```rs +let server = + hyper_server_core::create_server(server_details, handler, + HyperServerOptions{ + sse_support: false, // Disable SSE support + Default::default() + }); +``` + + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2. Build and start the server: + +```bash +cargo run -p hello-world-server-core-streamable-http --release +``` + +By default, both the Streamable HTTP and SSE endpoints are displayed in the terminal: + +```sh +• Streamable HTTP Server is available at http://127.0.0.1:8080/mcp +• SSE Server is available at http://127.0.0.1:8080/sse +``` + +You can test them out with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. + +```bash +npx -y @modelcontextprotocol/inspector@latest +``` + +That will open the inspector in a browser, + +Then , to test the server, visit one of the following URLs based on the desired transport: + +* Streamable HTTP: + [http://localhost:6274/?transport=streamable-http\&serverUrl=http://localhost:8080/mcp](http://localhost:6274/?transport=streamable-http&serverUrl=http://localhost:8080/mcp) +* SSE: + [http://localhost:6274/?transport=sse\&serverUrl=http://localhost:8080/sse](http://localhost:6274/?transport=sse&serverUrl=http://localhost:8080/sse) + + +Here you can see it in action : + +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) diff --git a/examples/hello-world-server-core-sse/src/handler.rs b/examples/hello-world-server-core-streamable-http/src/handler.rs similarity index 100% rename from examples/hello-world-server-core-sse/src/handler.rs rename to examples/hello-world-server-core-streamable-http/src/handler.rs diff --git a/examples/hello-world-server-core-sse/src/main.rs b/examples/hello-world-server-core-streamable-http/src/main.rs similarity index 78% rename from examples/hello-world-server-core-sse/src/main.rs rename to examples/hello-world-server-core-streamable-http/src/main.rs index 31a5620..7b41c70 100644 --- a/examples/hello-world-server-core-sse/src/main.rs +++ b/examples/hello-world-server-core-streamable-http/src/main.rs @@ -17,8 +17,7 @@ async fn main() -> SdkResult<()> { // initialize tracing tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -26,9 +25,9 @@ async fn main() -> SdkResult<()> { let server_details = InitializeResult { // server name and version server_info: Implementation { - name: "Hello World MCP Server SSE".to_string(), + name: "Hello World MCP Server Streamable HTTP + SSE".to_string(), version: "0.1.0".to_string(), - title: Some("Hello World MCP Server SSE".to_string()), + title: Some("Hello World MCP Server Streamable HTTP + SSE".to_string()), }, capabilities: ServerCapabilities { // indicates that server support mcp tools @@ -44,8 +43,14 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; // STEP 3: create a MCP server - let server = - hyper_server_core::create_server(server_details, handler, HyperServerOptions::default()); + let server = hyper_server_core::create_server( + server_details, + handler, + HyperServerOptions { + sse_support: true, + ..Default::default() + }, + ); // STEP 4: Start the server server.start().await?; diff --git a/examples/hello-world-server-core-sse/src/tools.rs b/examples/hello-world-server-core-streamable-http/src/tools.rs similarity index 100% rename from examples/hello-world-server-core-sse/src/tools.rs rename to examples/hello-world-server-core-streamable-http/src/tools.rs diff --git a/examples/hello-world-server-sse/README.md b/examples/hello-world-server-sse/README.md deleted file mode 100644 index 05c22c1..0000000 --- a/examples/hello-world-server-sse/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# Hello World MCP Server - SSE Transport - -A basic MCP server implementation using SSE transport, featuring two custom tools: `Say Hello` and `Say Goodbye` , utilizing [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) , using SSE transport - -## Overview - -This project showcases a fundamental MCP server implementation, highlighting the capabilities of [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) with these features: - -- SSE transport -- Custom server handler -- Basic server capabilities - -## Running the Example - -1. Clone the repository: - -```bash -git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git -cd rust-mcp-sdk -``` - -2. Build and start the server: - -```bash -cargo run -p hello-world-server-sse --release -``` - -By default, the SSE endpoint is accessible at `http://127.0.0.1:8080/sse`. -You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. - -```bash -npx -y @modelcontextprotocol/inspector -``` - -Then visit: http://localhost:6274/?transport=sse&serverUrl=http://localhost:8080/sse - -Here you can see it in action : - -![hello-world-mcp-server](../../assets/examples/hello-world-server-sse.gif) diff --git a/examples/hello-world-server-sse/Cargo.toml b/examples/hello-world-server-streamable-http/Cargo.toml similarity index 92% rename from examples/hello-world-server-sse/Cargo.toml rename to examples/hello-world-server-streamable-http/Cargo.toml index e4c7692..7e34075 100644 --- a/examples/hello-world-server-sse/Cargo.toml +++ b/examples/hello-world-server-streamable-http/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "hello-world-server-sse" +name = "hello-world-server-streamable-http" version = "0.1.24" edition = "2021" publish = false diff --git a/examples/hello-world-server-streamable-http/README.md b/examples/hello-world-server-streamable-http/README.md new file mode 100644 index 0000000..ac56a86 --- /dev/null +++ b/examples/hello-world-server-streamable-http/README.md @@ -0,0 +1,69 @@ +# Hello World MCP Server - Streamable Http + +A basic MCP server implementation using SSE transport, featuring two custom tools: `Say Hello` and `Say Goodbye` , utilizing [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk). It uses Streamable HTTP as the primary transport, while also supporting SSE for backward compatibility. + +## Overview + +This project showcases a fundamental MCP server implementation, highlighting the capabilities of [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) with these features: + +- Streamable HTTP transport +- SSE transport (for backward compatibility) +- Custom server handler +- Basic server capabilities + +💡 By default, both **Streamable HTTP** and **SSE** transports are enabled for backward compatibility. +To disable the SSE transport, set the `sse_support` value in the `HyperServerOptions` accordingly: + +```rs +let server = + hyper_server_core::create_server(server_details, handler, + HyperServerOptions{ + sse_support: false, // Disable SSE support + Default::default() + }); +``` + + +## Running the Example + +1. Clone the repository: + +```bash +git clone git@github.com:rust-mcp-stack/rust-mcp-sdk.git +cd rust-mcp-sdk +``` + +2. Build and start the server: + +```bash +cargo run -p hello-world-server-streamable-http --release +``` + +By default, both the Streamable HTTP and SSE endpoints are displayed in the terminal: + +```sh +• Streamable HTTP Server is available at http://127.0.0.1:8080/mcp +• SSE Server is available at http://127.0.0.1:8080/sse +``` + +You can test it with [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector), or alternatively, use it with any MCP client you prefer. + +```bash +npx -y @modelcontextprotocol/inspector@latest +``` + + +That will open the inspector in a browser, + +That will open the inspector in a browser, + +Then , to test the server, visit one of the following URLs based on the desired transport: + +* Streamable HTTP: + [http://localhost:6274/?transport=streamable-http\&serverUrl=http://localhost:8080/mcp](http://localhost:6274/?transport=streamable-http&serverUrl=http://localhost:8080/mcp) +* SSE: + [http://localhost:6274/?transport=sse\&serverUrl=http://localhost:8080/sse](http://localhost:6274/?transport=sse&serverUrl=http://localhost:8080/sse) + +Here you can see it in action : + +![hello-world-mcp-server-sse-core](../../assets/examples/hello-world-server-core-streamable-http.gif) diff --git a/examples/hello-world-server-sse/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs similarity index 100% rename from examples/hello-world-server-sse/src/handler.rs rename to examples/hello-world-server-streamable-http/src/handler.rs diff --git a/examples/hello-world-server-sse/src/main.rs b/examples/hello-world-server-streamable-http/src/main.rs similarity index 95% rename from examples/hello-world-server-sse/src/main.rs rename to examples/hello-world-server-streamable-http/src/main.rs index b5064f3..cd8c658 100644 --- a/examples/hello-world-server-sse/src/main.rs +++ b/examples/hello-world-server-streamable-http/src/main.rs @@ -24,8 +24,7 @@ async fn main() -> SdkResult<()> { // initialize tracing tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/hello-world-server-sse/src/tools.rs b/examples/hello-world-server-streamable-http/src/tools.rs similarity index 100% rename from examples/hello-world-server-sse/src/tools.rs rename to examples/hello-world-server-streamable-http/src/tools.rs diff --git a/examples/simple-mcp-client-core-sse/src/main.rs b/examples/simple-mcp-client-core-sse/src/main.rs index 48d2e0d..459f9ba 100644 --- a/examples/simple-mcp-client-core-sse/src/main.rs +++ b/examples/simple-mcp-client-core-sse/src/main.rs @@ -21,8 +21,7 @@ const MCP_SERVER_URL: &str = "http://localhost:3001/sse"; async fn main() -> SdkResult<()> { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/simple-mcp-client-sse/src/main.rs b/examples/simple-mcp-client-sse/src/main.rs index 64886b2..ce8850a 100644 --- a/examples/simple-mcp-client-sse/src/main.rs +++ b/examples/simple-mcp-client-sse/src/main.rs @@ -21,8 +21,7 @@ const MCP_SERVER_URL: &str = "http://localhost:3001/sse"; async fn main() -> SdkResult<()> { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()), ) .with(tracing_subscriber::fmt::layer()) .init();