diff --git a/.github/dictionary.txt b/.github/dictionary.txt index 71e5ed28d0..6b5fa94fea 100644 --- a/.github/dictionary.txt +++ b/.github/dictionary.txt @@ -14,3 +14,4 @@ additionals SECG Certicom RSAES +unuse diff --git a/packages/interface-compliance-tests/src/mocks/registrar.ts b/packages/interface-compliance-tests/src/mocks/registrar.ts index 1b2aa21bd8..6f6676f085 100644 --- a/packages/interface-compliance-tests/src/mocks/registrar.ts +++ b/packages/interface-compliance-tests/src/mocks/registrar.ts @@ -1,10 +1,11 @@ import { mergeOptions } from '@libp2p/utils/merge-options' -import type { Connection, PeerId, Topology, IncomingStreamData, StreamHandler, StreamHandlerOptions, StreamHandlerRecord } from '@libp2p/interface' +import type { Connection, PeerId, Topology, IncomingStreamData, StreamHandler, StreamHandlerOptions, StreamHandlerRecord, StreamMiddleware } from '@libp2p/interface' import type { Registrar } from '@libp2p/interface-internal' export class MockRegistrar implements Registrar { private readonly topologies = new Map>() private readonly handlers = new Map() + private readonly middleware = new Map() getProtocols (): string[] { return Array.from(this.handlers.keys()).sort() @@ -69,6 +70,18 @@ export class MockRegistrar implements Registrar { getTopologies (protocol: string): Topology[] { return (this.topologies.get(protocol) ?? []).map(t => t.topology) } + + use (protocol: string, middleware: StreamMiddleware[]): void { + this.middleware.set(protocol, middleware) + } + + unuse (protocol: string): void { + this.middleware.delete(protocol) + } + + getMiddleware (protocol: string): StreamMiddleware[] { + return this.middleware.get(protocol) ?? [] + } } export function mockRegistrar (): Registrar { diff --git a/packages/interface-internal/src/registrar.ts b/packages/interface-internal/src/registrar.ts index dc96e3877f..69332a78ae 100644 --- a/packages/interface-internal/src/registrar.ts +++ b/packages/interface-internal/src/registrar.ts @@ -1,4 +1,4 @@ -import type { StreamHandler, StreamHandlerOptions, StreamHandlerRecord, Topology, IncomingStreamData } from '@libp2p/interface' +import type { StreamHandler, StreamHandlerOptions, StreamHandlerRecord, Topology, IncomingStreamData, StreamMiddleware } from '@libp2p/interface' import type { AbortOptions } from '@multiformats/multiaddr' export type { @@ -69,6 +69,30 @@ export interface Registrar { */ getHandler(protocol: string): StreamHandlerRecord + /** + * Retrieve any registered middleware for a given protocol. + * + * @param protocol - The protocol to fetch middleware for + * @returns A list of `StreamMiddleware` implementations + */ + use(protocol: string, middleware: StreamMiddleware[]): void + + /** + * Retrieve any registered middleware for a given protocol. + * + * @param protocol - The protocol to fetch middleware for + * @returns A list of `StreamMiddleware` implementations + */ + unuse(protocol: string): void + + /** + * Retrieve any registered middleware for a given protocol. + * + * @param protocol - The protocol to fetch middleware for + * @returns A list of `StreamMiddleware` implementations + */ + getMiddleware(protocol: string): StreamMiddleware[] + /** * Register a topology handler for a protocol - the topology will be * invoked when peers are discovered on the network that support the diff --git a/packages/interface/src/index.ts b/packages/interface/src/index.ts index 9a6af78b48..fe2a87655c 100644 --- a/packages/interface/src/index.ts +++ b/packages/interface/src/index.ts @@ -23,7 +23,7 @@ import type { PeerInfo } from './peer-info.js' import type { PeerRouting } from './peer-routing.js' import type { Address, Peer, PeerStore } from './peer-store.js' import type { Startable } from './startable.js' -import type { StreamHandler, StreamHandlerOptions } from './stream-handler.js' +import type { StreamHandler, StreamHandlerOptions, StreamMiddleware } from './stream-handler.js' import type { Topology } from './topology.js' import type { Listener, OutboundConnectionUpgradeEvents } from './transport.js' import type { DNS } from '@multiformats/dns' @@ -744,6 +744,33 @@ export interface Libp2p extends Startable, Ty */ unregister(id: string): void + /** + * Registers one or more middleware implementations that will be invoked for + * incoming and outgoing protocol streams that match the passed protocol. + * + * @example + * + * ```TypeScript + * libp2p.use('/my/protocol/1.0.0', (stream, connection, next) => { + * // do something with stream and/or connection + * next(stream, connection) + * }) + * ``` + */ + use (protocol: string, middleware: StreamMiddleware | StreamMiddleware[]): void + + /** + * Deregisters all middleware for the passed protocol. + * + * @example + * + * ```TypeScript + * libp2p.unuse('/my/protocol/1.0.0') + * // any previously registered middleware will no longer be invoked + * ``` + */ + unuse (protocol: string): void + /** * Returns the public key for the passed PeerId. If the PeerId is of the 'RSA' * type this may mean searching the routing if the peer's key is not present diff --git a/packages/interface/src/stream-handler.ts b/packages/interface/src/stream-handler.ts index 39a087a5f9..38cf2f1ae9 100644 --- a/packages/interface/src/stream-handler.ts +++ b/packages/interface/src/stream-handler.ts @@ -20,6 +20,13 @@ export interface StreamHandler { (data: IncomingStreamData): void } +/** + * Stream middleware allows accessing stream data outside of the stream handler + */ +export interface StreamMiddleware { + (stream: Stream, connection: Connection, next: (stream: Stream, connection: Connection) => void | Promise): void | Promise +} + export interface StreamHandlerOptions extends AbortOptions { /** * How many incoming streams can be open for this protocol at the same time on each connection @@ -46,6 +53,11 @@ export interface StreamHandlerOptions extends AbortOptions { * protocol(s), the existing handler will be discarded. */ force?: true + + /** + * Middleware allows accessing stream data outside of the stream handler + */ + middleware?: StreamMiddleware[] } export interface StreamHandlerRecord { diff --git a/packages/libp2p/src/libp2p.ts b/packages/libp2p/src/libp2p.ts index 5dad2a0520..a4e0235b4b 100644 --- a/packages/libp2p/src/libp2p.ts +++ b/packages/libp2p/src/libp2p.ts @@ -24,7 +24,7 @@ import { userAgent } from './user-agent.js' import * as pkg from './version.js' import type { Components } from './components.js' import type { Libp2p as Libp2pInterface, Libp2pInit } from './index.js' -import type { PeerRouting, ContentRouting, Libp2pEvents, PendingDial, ServiceMap, AbortOptions, ComponentLogger, Logger, Connection, NewStreamOptions, Stream, Metrics, PeerId, PeerInfo, PeerStore, Topology, Libp2pStatus, IsDialableOptions, DialOptions, PublicKey, Ed25519PeerId, Secp256k1PeerId, RSAPublicKey, RSAPeerId, URLPeerId, Ed25519PublicKey, Secp256k1PublicKey, StreamHandler, StreamHandlerOptions } from '@libp2p/interface' +import type { PeerRouting, ContentRouting, Libp2pEvents, PendingDial, ServiceMap, AbortOptions, ComponentLogger, Logger, Connection, NewStreamOptions, Stream, Metrics, PeerId, PeerInfo, PeerStore, Topology, Libp2pStatus, IsDialableOptions, DialOptions, PublicKey, Ed25519PeerId, Secp256k1PeerId, RSAPublicKey, RSAPeerId, URLPeerId, Ed25519PublicKey, Secp256k1PublicKey, StreamHandler, StreamHandlerOptions, StreamMiddleware } from '@libp2p/interface' import type { Multiaddr } from '@multiformats/multiaddr' export class Libp2p extends TypedEventEmitter implements Libp2pInterface { @@ -402,6 +402,14 @@ export class Libp2p extends TypedEventEmitter this.components.registrar.unregister(id) } + use (protocol: string, middleware: StreamMiddleware | StreamMiddleware[]): void { + this.components.registrar.use(protocol, Array.isArray(middleware) ? middleware : [middleware]) + } + + unuse (protocol: string): void { + this.components.registrar.unuse(protocol) + } + async isDialable (multiaddr: Multiaddr, options: IsDialableOptions = {}): Promise { return this.components.connectionManager.isDialable(multiaddr, options) } diff --git a/packages/libp2p/src/registrar.ts b/packages/libp2p/src/registrar.ts index c4f10cb111..d628ee047b 100644 --- a/packages/libp2p/src/registrar.ts +++ b/packages/libp2p/src/registrar.ts @@ -2,7 +2,7 @@ import { InvalidParametersError } from '@libp2p/interface' import { mergeOptions } from '@libp2p/utils/merge-options' import { trackedMap } from '@libp2p/utils/tracked-map' import * as errorsJs from './errors.js' -import type { IdentifyResult, Libp2pEvents, Logger, PeerUpdate, PeerId, PeerStore, Topology, StreamHandler, StreamHandlerRecord, StreamHandlerOptions, AbortOptions, Metrics } from '@libp2p/interface' +import type { IdentifyResult, Libp2pEvents, Logger, PeerUpdate, PeerId, PeerStore, Topology, StreamHandler, StreamHandlerRecord, StreamHandlerOptions, AbortOptions, Metrics, StreamMiddleware } from '@libp2p/interface' import type { Registrar as RegistrarInterface } from '@libp2p/interface-internal' import type { ComponentLogger } from '@libp2p/logger' import type { TypedEventTarget } from 'main-event' @@ -26,10 +26,12 @@ export class Registrar implements RegistrarInterface { private readonly topologies: Map> private readonly handlers: Map private readonly components: RegistrarComponents + private readonly middleware: Map constructor (components: RegistrarComponents) { this.components = components this.log = components.logger.forComponent('libp2p:registrar') + this.middleware = new Map() this.topologies = new Map() components.metrics?.registerMetricGroup('libp2p_registrar_topologies', { calculate: () => { @@ -165,6 +167,18 @@ export class Registrar implements RegistrarInterface { } } + use (protocol: string, middleware: StreamMiddleware[]): void { + this.middleware.set(protocol, middleware) + } + + unuse (protocol: string): void { + this.middleware.delete(protocol) + } + + getMiddleware (protocol: string): StreamMiddleware[] { + return this.middleware.get(protocol) ?? [] + } + /** * Remove a disconnected peer from the record */ diff --git a/packages/libp2p/src/upgrader.ts b/packages/libp2p/src/upgrader.ts index 833e6e6872..37277b0bed 100644 --- a/packages/libp2p/src/upgrader.ts +++ b/packages/libp2p/src/upgrader.ts @@ -10,7 +10,7 @@ import { createConnection } from './connection/index.js' import { PROTOCOL_NEGOTIATION_TIMEOUT, INBOUND_UPGRADE_TIMEOUT } from './connection-manager/constants.js' import { ConnectionDeniedError, ConnectionInterceptedError, EncryptionFailedError, MuxerUnavailableError } from './errors.js' import { DEFAULT_MAX_INBOUND_STREAMS, DEFAULT_MAX_OUTBOUND_STREAMS } from './registrar.js' -import type { Libp2pEvents, AbortOptions, ComponentLogger, MultiaddrConnection, Connection, Stream, ConnectionProtector, NewStreamOptions, ConnectionEncrypter, SecuredConnection, ConnectionGater, Metrics, PeerId, PeerStore, StreamMuxer, StreamMuxerFactory, Upgrader as UpgraderInterface, UpgraderOptions, ConnectionLimits, SecureConnectionOptions, CounterGroup, ClearableSignal } from '@libp2p/interface' +import type { Libp2pEvents, AbortOptions, ComponentLogger, MultiaddrConnection, Connection, Stream, ConnectionProtector, NewStreamOptions, ConnectionEncrypter, SecuredConnection, ConnectionGater, Metrics, PeerId, PeerStore, StreamMuxer, StreamMuxerFactory, Upgrader as UpgraderInterface, UpgraderOptions, ConnectionLimits, SecureConnectionOptions, CounterGroup, ClearableSignal, Logger, StreamMiddleware } from '@libp2p/interface' import type { ConnectionManager, Registrar } from '@libp2p/interface-internal' import type { TypedEventTarget } from 'main-event' @@ -379,6 +379,39 @@ export class Upgrader implements UpgraderInterface { }) } + async _runMiddlewareChain ( + stream: Stream, + connection: Connection, + middleware: StreamMiddleware[], + log?: Logger + ): Promise<{ stream: Stream; connection: Connection }> { + for (let i = 0; i < middleware.length; i++) { + const mw = middleware[i] + log?.trace('running middleware', i, mw) + + // eslint-disable-next-line no-loop-func + await new Promise((resolve, reject) => { + try { + const result = mw(stream, connection, (s, c) => { + stream = s + connection = c + resolve() + }) + + if (result instanceof Promise) { + result.catch(reject) + } + } catch (err) { + reject(err) + } + }) + + log?.trace('ran middleware', i, mw) + } + + return { stream, connection } + } + /** * A convenience method for generating a new `Connection` */ @@ -395,7 +428,7 @@ export class Upgrader implements UpgraderInterface { let muxer: StreamMuxer | undefined let newStream: ((multicodecs: string[], options?: AbortOptions) => Promise) | undefined - let connection: Connection // eslint-disable-line prefer-const + let connection: Connection if (muxerFactory != null) { // Create the muxer @@ -488,7 +521,7 @@ export class Upgrader implements UpgraderInterface { } connection.log.trace('starting new stream for protocols %s', protocols) - const muxedStream = await muxer.newStream() + let muxedStream = await muxer.newStream() connection.log.trace('started new stream %s for protocols %s', muxedStream.id, protocols) try { @@ -556,6 +589,19 @@ export class Upgrader implements UpgraderInterface { this.components.metrics?.trackProtocolStream(muxedStream, connection) + const middleware = this.components.registrar.getMiddleware(protocol) + + middleware.push((stream, connection, next) => { + next(stream, connection) + }) + + ;({ stream: muxedStream, connection } = await this._runMiddlewareChain( + muxedStream, + connection, + middleware, + muxedStream.log + )) + return muxedStream } catch (err: any) { connection.log.error('could not create new outbound stream on connection %s %a for protocols %s - %e', direction === 'inbound' ? 'from' : 'to', opts.maConn.remoteAddr, protocols, err) @@ -659,7 +705,22 @@ export class Upgrader implements UpgraderInterface { throw new LimitedConnectionError('Cannot open protocol stream on limited connection') } - handler({ connection, stream }) + const middleware = this.components.registrar.getMiddleware(protocol) + + if (middleware.length === 0) { + // No middleware, call handler immediately + handler({ stream, connection }) + return + } + + this._runMiddlewareChain(stream, connection, middleware, stream.log) + .then(({ stream: s, connection: c }) => { + handler({ stream: s, connection: c }) + }) + .catch(err => { + connection.log.error('middleware error for inbound stream %s - %e', stream.id, err) + stream.abort(err) + }) } /** diff --git a/packages/libp2p/test/upgrading/upgrader.spec.ts b/packages/libp2p/test/upgrading/upgrader.spec.ts index 5fd14e8cc8..7308555afc 100644 --- a/packages/libp2p/test/upgrading/upgrader.spec.ts +++ b/packages/libp2p/test/upgrading/upgrader.spec.ts @@ -16,15 +16,17 @@ import { fromString as uint8ArrayFromString } from 'uint8arrays/from-string' import { Upgrader } from '../../src/upgrader.js' import { createDefaultUpgraderComponents } from './utils.js' import type { UpgraderComponents, UpgraderInit } from '../../src/upgrader.js' -import type { ConnectionEncrypter, StreamMuxerFactory, MultiaddrConnection, StreamMuxer, ConnectionProtector, PeerId, SecuredConnection, Stream, StreamMuxerInit, Connection } from '@libp2p/interface' +import type { ConnectionEncrypter, StreamMuxerFactory, MultiaddrConnection, StreamMuxer, ConnectionProtector, PeerId, SecuredConnection, Stream, StreamMuxerInit, Connection, AbortOptions } from '@libp2p/interface' import type { ConnectionManager, Registrar } from '@libp2p/interface-internal' import type { Multiaddr } from '@multiformats/multiaddr' +import type { SinonStub } from 'sinon' describe('upgrader', () => { let components: UpgraderComponents let init: UpgraderInit const encrypterProtocol = '/test-encrypter' const muxerProtocol = '/test-muxer' + const streamProtocol = '/test/protocol' let remotePeer: PeerId let remoteAddr: Multiaddr let maConn: MultiaddrConnection @@ -36,6 +38,66 @@ describe('upgrader', () => { async secureOutbound (): Promise { throw new Error('Boom') } } + function stubMuxerFactory (protocol: string = streamProtocol, onInit?: (init: StreamMuxerInit) => void): StreamMuxerFactory { + return stubInterface({ + protocol: muxerProtocol, + createStreamMuxer (init: StreamMuxerInit = {}): StreamMuxer { + onInit?.(init) + + // our “stub” muxer keeps its own streams list + const streams: Stream[] = [] + + const streamMuxer = stubInterface({ + protocol: muxerProtocol, + streams, + sink: async (source) => drain(source), + source: (async function * () {})(), + newStream: () => { + const outgoingStream = stubInterface({ + id: 'stream-id', + log: logger('test-stream'), + direction: 'outbound', + sink: async (source) => drain(source), + source: map((async function * () { + yield '/multistream/1.0.0\n' + yield `${protocol}\n` + })(), str => encode.single(uint8ArrayFromString(str))) + }) + + streams.push(outgoingStream) + + const abortStub = outgoingStream.abort as SinonStub<[Error], void> + abortStub.callsFake((_: Error) => { + const idx = streams.indexOf(outgoingStream) + if (idx !== -1) { + streams.splice(idx, 1) + } + }) + + const closeStub = outgoingStream.close as SinonStub<[AbortOptions?], Promise> + closeStub.callsFake(async (_?: AbortOptions) => { + const idx = streams.indexOf(outgoingStream) + if (idx !== -1) { + streams.splice(idx, 1) + } + }) + + return outgoingStream + } + }) + + // wrap the user’s onIncomingStream callback so we track inbound + const originalHandler = init.onIncomingStream + init.onIncomingStream = (incoming: Stream) => { + streams.push(incoming) + originalHandler?.(incoming) + } + + return streamMuxer + } + }) + } + beforeEach(async () => { remotePeer = peerIdFromPrivateKey(await generateKeyPair('Ed25519')) remoteAddr = multiaddr(`/ip4/123.123.123.123/tcp/1234/p2p/${remotePeer}`) @@ -435,7 +497,8 @@ describe('upgrader', () => { }, handler: Sinon.stub() }), - getProtocols: () => [protocol] + getProtocols: () => [protocol], + getMiddleware: () => [] }) }) const upgrader = new Upgrader(components, { @@ -503,7 +566,8 @@ describe('upgrader', () => { }, handler: Sinon.stub() }), - getProtocols: () => [protocol] + getProtocols: () => [protocol], + getMiddleware: () => [] }) }) const upgrader = new Upgrader(components, { @@ -566,7 +630,8 @@ describe('upgrader', () => { options: {}, handler: Sinon.stub() }), - getProtocols: () => [protocol] + getProtocols: () => [protocol], + getMiddleware: () => [] }) }) const upgrader = new Upgrader(components, { @@ -625,6 +690,239 @@ describe('upgrader', () => { .with.property('name', 'TooManyOutboundProtocolStreamsError') }) + describe('middleware', () => { + it('should support outgoing stream middleware', async () => { + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + const components = await createDefaultUpgraderComponents({ + registrar: stubInterface({ + getHandler: () => ({ + options: {}, + handler: Sinon.stub() + }), + getProtocols: () => [streamProtocol], + getMiddleware: () => middleware + }) + }) + const upgrader = new Upgrader(components, { + ...init, + streamMuxers: [ + stubMuxerFactory() + ] + }) + + const connectionPromise = pEvent<'connection:open', CustomEvent>(components.events, 'connection:open') + + await upgrader.upgradeInbound(maConn, { + signal: AbortSignal.timeout(5_000) + }) + + const event = await connectionPromise + const conn = event.detail + + expect(conn.streams).to.have.lengthOf(0) + + await conn.newStream(streamProtocol) + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.true() + expect(conn.streams).to.have.lengthOf(1) + }) + + it('should support incoming stream middleware', async () => { + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + const streamMuxerInitPromise = Promise.withResolvers() + + const components = await createDefaultUpgraderComponents({ + registrar: stubInterface({ + getHandler: () => ({ + options: {}, + handler: Sinon.stub() + }), + getProtocols: () => [streamProtocol], + getMiddleware: () => middleware + }) + }) + const upgrader = new Upgrader(components, { + ...init, + streamMuxers: [ + stubMuxerFactory(muxerProtocol, (init) => { + streamMuxerInitPromise.resolve(init) + }) + ] + }) + + const conn = await upgrader.upgradeOutbound(maConn, { + signal: AbortSignal.timeout(5_000) + }) + + const { onIncomingStream } = await streamMuxerInitPromise.promise + + expect(conn.streams).to.have.lengthOf(0) + + const incomingStream = stubInterface({ + id: 'stream-id', + log: logger('test-stream'), + direction: 'outbound', + sink: async (source) => drain(source), + source: map((async function * () { + yield '/multistream/1.0.0\n' + yield `${streamProtocol}\n` + })(), str => encode.single(uint8ArrayFromString(str))) + }) + + onIncomingStream?.(incomingStream) + + // incoming stream is opened asynchronously + await delay(100) + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.true() + expect(conn.streams).to.have.lengthOf(1) + }) + + it('should not call outbound middleware if previous middleware errors', async () => { + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + throw new Error('boom') + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + const components = await createDefaultUpgraderComponents({ + registrar: stubInterface({ + getHandler: () => ({ + options: {}, + handler: Sinon.stub() + }), + getProtocols: () => [streamProtocol], + getMiddleware: () => middleware + }) + }) + const upgrader = new Upgrader(components, { + ...init, + streamMuxers: [ + stubMuxerFactory() + ] + }) + + const connectionPromise = pEvent<'connection:open', CustomEvent>(components.events, 'connection:open') + + await upgrader.upgradeInbound(maConn, { + signal: AbortSignal.timeout(5_000) + }) + + const event = await connectionPromise + const conn = event.detail + + expect(conn.streams).to.have.lengthOf(0) + + let err: any + let stream: Stream | undefined + try { + stream = await conn.newStream(streamProtocol) + } catch (e) { + err = e + } + + expect(err).to.be.an('error').with.property('message', 'boom') + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.false() + expect(conn.streams).to.have.lengthOf(0) + expect(stream).to.be.undefined() + }) + + it('should not call inbound middleware if previous middleware errors', async () => { + const middleware1 = Sinon.stub().callsFake((stream, connection, next) => { + throw new Error('boom') + }) + const middleware2 = Sinon.stub().callsFake((stream, connection, next) => { + next(stream, connection) + }) + + const middleware = [ + middleware1, + middleware2 + ] + + const streamMuxerInitPromise = Promise.withResolvers() + + const components = await createDefaultUpgraderComponents({ + registrar: stubInterface({ + getHandler: () => ({ + options: {}, + handler: Sinon.stub() + }), + getProtocols: () => [streamProtocol], + getMiddleware: () => middleware + }) + }) + const upgrader = new Upgrader(components, { + ...init, + streamMuxers: [ + stubMuxerFactory(muxerProtocol, (init) => { + streamMuxerInitPromise.resolve(init) + }) + ] + }) + + const conn = await upgrader.upgradeOutbound(maConn, { + signal: AbortSignal.timeout(100) + }) + + const { onIncomingStream } = await streamMuxerInitPromise.promise + + expect(conn.streams).to.have.lengthOf(0) + + const incomingStream = stubInterface({ + id: 'stream-id', + log: logger('test-stream'), + direction: 'outbound', + sink: async (source) => drain(source), + source: map((async function * () { + yield '/multistream/1.0.0\n' + yield `${streamProtocol}\n` + })(), str => encode.single(uint8ArrayFromString(str))) + }) + + onIncomingStream?.(incomingStream) + + // incoming stream is opened asynchronously + await delay(100) + + expect(middleware1.called).to.be.true() + expect(middleware2.called).to.be.false() + expect(incomingStream).to.have.nested.property('abort.called', true) + }) + }) + describe('early muxer selection', () => { let earlyMuxerProtocol: string let streamMuxerFactory: StreamMuxerFactory diff --git a/packages/pubsub/test/utils/index.ts b/packages/pubsub/test/utils/index.ts index 453981c567..44c7be5747 100644 --- a/packages/pubsub/test/utils/index.ts +++ b/packages/pubsub/test/utils/index.ts @@ -1,7 +1,7 @@ import { duplexPair } from 'it-pair/duplex' import { PubSubBaseProtocol } from '../../src/index.js' import { RPC } from '../message/rpc.js' -import type { Connection, PeerId, PublishResult, PubSubRPC, PubSubRPCMessage, Topology, IncomingStreamData, StreamHandler, StreamHandlerRecord } from '@libp2p/interface' +import type { Connection, PeerId, PublishResult, PubSubRPC, PubSubRPCMessage, Topology, IncomingStreamData, StreamHandler, StreamHandlerRecord, StreamMiddleware } from '@libp2p/interface' import type { Registrar } from '@libp2p/interface-internal' export class PubsubImplementation extends PubSubBaseProtocol { @@ -31,6 +31,7 @@ export class PubsubImplementation extends PubSubBaseProtocol { export class MockRegistrar implements Registrar { private readonly topologies = new Map() private readonly handlers = new Map() + private readonly middleware = new Map() getProtocols (): string[] { const protocols = new Set() @@ -114,6 +115,18 @@ export class MockRegistrar implements Registrar { throw new Error(`No topologies registered for protocol ${protocol}`) } + + use (protocol: string, middleware: StreamMiddleware[]): void { + this.middleware.set(protocol, middleware) + } + + unuse (protocol: string): void { + this.middleware.delete(protocol) + } + + getMiddleware (protocol: string): StreamMiddleware[] { + return this.middleware.get(protocol) ?? [] + } } export const ConnectionPair = (): [Connection, Connection] => {