diff --git a/src/gateway/auth.ts b/src/gateway/auth.ts index ded563487..60864f581 100644 --- a/src/gateway/auth.ts +++ b/src/gateway/auth.ts @@ -105,7 +105,7 @@ function resolveTailscaleClientIp(req?: IncomingMessage): string | undefined { }); } -function resolveRequestClientIp( +export function resolveRequestClientIp( req?: IncomingMessage, trustedProxies?: string[], allowRealIpFallback = false, diff --git a/src/gateway/hooks-test-helpers.ts b/src/gateway/hooks-test-helpers.ts index ca0988edb..0351b829f 100644 --- a/src/gateway/hooks-test-helpers.ts +++ b/src/gateway/hooks-test-helpers.ts @@ -26,9 +26,11 @@ export function createGatewayRequest(params: { method?: string; remoteAddress?: string; host?: string; + headers?: Record; }): IncomingMessage { const headers: Record = { host: params.host ?? "localhost:18789", + ...params.headers, }; if (params.authorization) { headers.authorization = params.authorization; diff --git a/src/gateway/server-http.hooks-request-timeout.test.ts b/src/gateway/server-http.hooks-request-timeout.test.ts index 0452cab7b..e56432a95 100644 --- a/src/gateway/server-http.hooks-request-timeout.test.ts +++ b/src/gateway/server-http.hooks-request-timeout.test.ts @@ -23,6 +23,7 @@ function createRequest(params?: { authorization?: string; remoteAddress?: string; url?: string; + headers?: Record; }): IncomingMessage { return createGatewayRequest({ method: "POST", @@ -30,6 +31,7 @@ function createRequest(params?: { host: "127.0.0.1:18789", authorization: params?.authorization ?? "Bearer hook-secret", remoteAddress: params?.remoteAddress, + headers: params?.headers, }); } @@ -52,6 +54,7 @@ function createHandler(params?: { dispatchWakeHook?: HooksHandlerDeps["dispatchWakeHook"]; dispatchAgentHook?: HooksHandlerDeps["dispatchAgentHook"]; bindHost?: string; + getClientIpConfig?: HooksHandlerDeps["getClientIpConfig"]; }) { return createHooksRequestHandler({ getHooksConfig: () => createHooksConfig(), @@ -63,6 +66,7 @@ function createHandler(params?: { info: vi.fn(), error: vi.fn(), } as unknown as ReturnType, + getClientIpConfig: params?.getClientIpConfig, dispatchWakeHook: params?.dispatchWakeHook ?? ((() => { @@ -121,6 +125,36 @@ describe("createHooksRequestHandler timeout status mapping", () => { expect(setHeader).toHaveBeenCalledWith("Retry-After", expect.any(String)); }); + test("uses trusted proxy forwarded client ip for hook auth throttling", async () => { + const handler = createHandler({ + getClientIpConfig: () => ({ trustedProxies: ["10.0.0.1"] }), + }); + + for (let i = 0; i < 20; i++) { + const req = createRequest({ + authorization: "Bearer wrong", + remoteAddress: "10.0.0.1", + headers: { "x-forwarded-for": "1.2.3.4" }, + }); + const { res } = createResponse(); + const handled = await handler(req, res); + expect(handled).toBe(true); + expect(res.statusCode).toBe(401); + } + + const forwardedReq = createRequest({ + authorization: "Bearer wrong", + remoteAddress: "10.0.0.1", + headers: { "x-forwarded-for": "1.2.3.4, 10.0.0.1" }, + }); + const { res: forwardedRes, setHeader } = createResponse(); + const handled = await handler(forwardedReq, forwardedRes); + + expect(handled).toBe(true); + expect(forwardedRes.statusCode).toBe(429); + expect(setHeader).toHaveBeenCalledWith("Retry-After", expect.any(String)); + }); + test.each(["0.0.0.0", "::"])( "does not throw when bindHost=%s while parsing non-hook request URL", async (bindHost) => { diff --git a/src/gateway/server-http.ts b/src/gateway/server-http.ts index 89db12bc2..110d64e09 100644 --- a/src/gateway/server-http.ts +++ b/src/gateway/server-http.ts @@ -23,6 +23,7 @@ import { import { authorizeHttpGatewayConnect, isLocalDirectRequest, + resolveRequestClientIp, type GatewayAuthResult, type ResolvedGatewayAuth, } from "./auth.js"; @@ -351,9 +352,13 @@ export function createHooksRequestHandler( bindHost: string; port: number; logHooks: SubsystemLogger; + getClientIpConfig?: () => { + trustedProxies?: string[]; + allowRealIpFallback?: boolean; + }; } & HookDispatchers, ): HooksRequestHandler { - const { getHooksConfig, logHooks, dispatchAgentHook, dispatchWakeHook } = opts; + const { getHooksConfig, logHooks, dispatchAgentHook, dispatchWakeHook, getClientIpConfig } = opts; const hookAuthLimiter = createAuthRateLimiter({ maxAttempts: HOOK_AUTH_FAILURE_LIMIT, windowMs: HOOK_AUTH_FAILURE_WINDOW_MS, @@ -364,7 +369,14 @@ export function createHooksRequestHandler( }); const resolveHookClientKey = (req: IncomingMessage): string => { - return normalizeRateLimitClientIp(req.socket?.remoteAddress); + const clientIpConfig = getClientIpConfig?.(); + const clientIp = + resolveRequestClientIp( + req, + clientIpConfig?.trustedProxies, + clientIpConfig?.allowRealIpFallback === true, + ) ?? req.socket?.remoteAddress; + return normalizeRateLimitClientIp(clientIp); }; return async (req, res) => { diff --git a/src/gateway/server/hooks.ts b/src/gateway/server/hooks.ts index 3b159c680..8630ef008 100644 --- a/src/gateway/server/hooks.ts +++ b/src/gateway/server/hooks.ts @@ -108,6 +108,13 @@ export function createGatewayHooksRequestHandler(params: { bindHost, port, logHooks, + getClientIpConfig: () => { + const cfg = loadConfig(); + return { + trustedProxies: cfg.gateway?.trustedProxies, + allowRealIpFallback: cfg.gateway?.allowRealIpFallback === true, + }; + }, dispatchAgentHook, dispatchWakeHook, });