From 9443df80097d10fa228471c3690b42582563c152 Mon Sep 17 00:00:00 2001
From: Iain Lane
Date: Fri, 16 May 2025 10:41:46 +0100
Subject: [PATCH] feat(control-plane): add support for handling multiple
 events in a single invocation

Currently we restrict the `scale-up` Lambda to handling a single event at a
time. In very busy environments this can become a bottleneck: every
invocation makes calls to the GitHub and AWS APIs, and these can take long
enough that we can't process job-queued events as fast as they arrive.

In our environment we are also using a pool, and we have typically responded
to the alerts this generates (the SQS queue length growing) by expanding the
pool. This helps because we then more frequently find that no scale-up is
needed, which lets the Lambdas exit earlier, so we get through the queue
faster. But it makes the environment much less responsive to changes in
usage patterns.

At its core, this Lambda's task is to work out how many instances are needed
and construct a single EC2 `CreateFleet` call to create them. This is a job
that can be batched: we can take any number of events, calculate the
difference between our current state and the number of queued jobs, cap that
at the configured maximum, and issue one call.

The thing to be careful about is partial failure, where EC2 creates some of
the instances we asked for but not all of them. Lambda has a configurable
function response type which can be set to `ReportBatchItemFailures`. In
this mode, we return a list of failed message IDs from our handler and only
those messages are retried. We use this to hand back exactly as many events
as we failed to process.

Now that we are potentially processing multiple events in a single
invocation, one thing we should optimise for is not recreating GitHub API
clients. We need one client for the app itself, which we use to look up
installation IDs, and then one client per installation that appears in the
batch of events we are processing. We create an installation's client the
first time we see an event for it, and reuse it for the rest of the batch.

We also remove the same `batch_size = 1` constraint from the `job-retry`
Lambda and make its batch size configurable instead, using AWS's SQS default
of 10 if not configured. This Lambda is used to retry events that previously
failed. Here, instead of reporting failures to be retried, we keep the
pre-existing fault-tolerant behaviour where errors are logged but explicitly
do not cause message retries, avoiding infinite loops from persistent GitHub
API issues or malformed events.

Tests are added for all of this.
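
To make the new contract concrete, here is a condensed sketch of the handler
shape this change introduces. It is simplified from the real `lambda.ts` diff
below: grouping by event source, logging, and ScaleError handling are
trimmed.

```typescript
import type { Context, SQSBatchResponse, SQSEvent } from 'aws-lambda';
import { scaleUp } from './scale-runners/scale-up';

export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<SQSBatchResponse> {
  // Parse the SQS-originated records and sort them ascending by retry count,
  // so that under a persistent failure we keep retrying the same messages,
  // which then hit their maximum receive count (and get dropped) sooner.
  const messages = event.Records.filter((record) => record.eventSource === 'aws:sqs')
    .map((record) => ({ ...JSON.parse(record.body), messageId: record.messageId }))
    .sort((l, r) => (l.retryCounter ?? 0) - (r.retryCounter ?? 0));

  // scaleUp() turns the whole batch into one CreateFleet request and returns
  // the IDs of the messages it could not satisfy, e.g. because EC2 created
  // only some of the requested instances.
  const rejectedMessageIds = await scaleUp(messages);

  // With ReportBatchItemFailures enabled, only these messages are retried.
  return { batchItemFailures: rejectedMessageIds.map((id) => ({ itemIdentifier: id })) };
}
```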
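
The client reuse is a straightforward memoisation keyed on installation ID.
A minimal sketch of the pattern follows; the real code lives in
`scale-up.ts`, and the helper signatures are assumed from how the tests
exercise them:

```typescript
import { Octokit } from '@octokit/rest';
import { createGithubInstallationAuth, createOctokitClient } from '../github/auth';

// One Octokit client per installation ID, kept for the life of the batch.
const installationClients = new Map<number, Octokit>();

async function getInstallationClient(installationId: number, ghesApiUrl: string): Promise<Octokit> {
  const cached = installationClients.get(installationId);
  if (cached) {
    return cached;
  }

  // Only authenticate and construct a client the first time we see an
  // event for this installation.
  const auth = await createGithubInstallationAuth(installationId, ghesApiUrl);
  const client = await createOctokitClient(auth.token, ghesApiUrl);
  installationClients.set(installationId, client);
  return client;
}
```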
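
For `job-retry`, batching keeps the per-record error isolation. Conceptually
the handler loops over the batch and swallows individual failures; this is a
sketch of that behaviour with a simplified signature, not the literal
handler code:

```typescript
import type { SQSEvent } from 'aws-lambda';
import { logger } from '@aws-github-runner/aws-powertools-util';
import { checkAndRetryJob } from './scale-runners/job-retry';

export async function jobRetryCheck(event: SQSEvent): Promise<void> {
  for (const record of event.Records) {
    try {
      // Assumes checkAndRetryJob accepts the parsed message payload.
      await checkAndRetryJob(JSON.parse(record.body));
    } catch (e) {
      // Deliberately do not rethrow or report the message as failed:
      // a poisoned message must not be retried forever.
      logger.warn(`Ignoring error for message ${record.messageId}: ${e}`);
    }
  }
}
```
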
--- README.md | 2 + .../control-plane/src/lambda.test.ts | 171 +++- lambdas/functions/control-plane/src/lambda.ts | 62 +- lambdas/functions/control-plane/src/local.ts | 42 +- .../control-plane/src/pool/pool.test.ts | 24 +- .../functions/control-plane/src/pool/pool.ts | 2 +- .../src/scale-runners/job-retry.test.ts | 92 ++ .../src/scale-runners/scale-up.test.ts | 944 +++++++++++++++--- .../src/scale-runners/scale-up.ts | 268 +++-- .../aws-powertools-util/src/logger/index.ts | 10 +- main.tf | 46 +- modules/multi-runner/README.md | 2 + modules/multi-runner/runners.tf | 46 +- modules/multi-runner/variables.tf | 12 + modules/runners/README.md | 2 + modules/runners/job-retry.tf | 50 +- modules/runners/job-retry/README.md | 2 +- modules/runners/job-retry/main.tf | 7 +- modules/runners/job-retry/variables.tf | 16 +- modules/runners/scale-up.tf | 8 +- modules/runners/variables.tf | 20 + variables.tf | 16 + 22 files changed, 1468 insertions(+), 376 deletions(-) diff --git a/README.md b/README.md index e264bdc3d5..9e70df6d14 100644 --- a/README.md +++ b/README.md @@ -155,6 +155,8 @@ Join our discord community via [this invite link](https://discord.gg/bxgXW8jJGh) | [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no | | [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. This key must be in the current account. | `string` | `null` | no | | [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no | +| [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no | +| [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no | | [lambda\_principals](#input\_lambda\_principals) | (Optional) add extra principals to the role created for execution of the lambda, e.g. for local testing. |
list(object({
type = string
identifiers = list(string)
}))
| `[]` | no | | [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no | | [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no | diff --git a/lambdas/functions/control-plane/src/lambda.test.ts b/lambdas/functions/control-plane/src/lambda.test.ts index 2c54a4d541..c6f9f24c1d 100644 --- a/lambdas/functions/control-plane/src/lambda.test.ts +++ b/lambdas/functions/control-plane/src/lambda.test.ts @@ -70,19 +70,33 @@ vi.mock('@aws-github-runner/aws-powertools-util'); vi.mock('@aws-github-runner/aws-ssm-util'); describe('Test scale up lambda wrapper.', () => { - it('Do not handle multiple record sets.', async () => { - await testInvalidRecords([sqsRecord, sqsRecord]); + it('Do not handle empty record sets.', async () => { + const sqsEventMultipleRecords: SQSEvent = { + Records: [], + }; + + await expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow(); }); - it('Do not handle empty record sets.', async () => { - await testInvalidRecords([]); + it('Ignores non-sqs event sources.', async () => { + const record = { + ...sqsRecord, + eventSource: 'aws:non-sqs', + }; + + const sqsEventMultipleRecordsNonSQS: SQSEvent = { + Records: [record], + }; + + await expect(scaleUpHandler(sqsEventMultipleRecordsNonSQS, context)).resolves.not.toThrow(); + expect(scaleUp).toHaveBeenCalledWith([]); }); it('Scale without error should resolve.', async () => { const mock = vi.fn(scaleUp); mock.mockImplementation(() => { return new Promise((resolve) => { - resolve(); + resolve([]); }); }); await expect(scaleUpHandler(sqsEvent, context)).resolves.not.toThrow(); @@ -104,28 +118,137 @@ describe('Test scale up lambda wrapper.', () => { vi.mocked(scaleUp).mockImplementation(mock); await expect(scaleUpHandler(sqsEvent, context)).rejects.toThrow(error); }); -}); -async function testInvalidRecords(sqsRecords: SQSRecord[]) { - const mock = vi.fn(scaleUp); - const logWarnSpy = vi.spyOn(logger, 'warn'); - mock.mockImplementation(() => { - return new Promise((resolve) => { - resolve(); + describe('Batch processing', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const createMultipleRecords = (count: number, eventSource = 'aws:sqs'): SQSRecord[] => { + return Array.from({ length: count }, (_, i) => ({ + ...sqsRecord, + eventSource, + messageId: `message-${i}`, + body: JSON.stringify({ + ...body, + id: i + 1, + }), + })); + }; + + it('Should handle multiple SQS records in a single invocation', async () => { + const records = createMultipleRecords(3); + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.resolve([])); + vi.mocked(scaleUp).mockImplementation(mock); + + await expect(scaleUpHandler(multiRecordEvent, context)).resolves.not.toThrow(); + expect(scaleUp).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ messageId: 'message-0' }), + expect.objectContaining({ messageId: 'message-1' }), + expect.objectContaining({ messageId: 'message-2' }), + ]), + ); + }); + + it('Should return batch item failures for rejected messages', async () => { + const records = createMultipleRecords(3); + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.resolve(['message-1', 'message-2'])); + vi.mocked(scaleUp).mockImplementation(mock); + + const result = 
await scaleUpHandler(multiRecordEvent, context); + expect(result).toEqual({ + batchItemFailures: [{ itemIdentifier: 'message-1' }, { itemIdentifier: 'message-2' }], + }); + }); + + it('Should filter out non-SQS event sources', async () => { + const sqsRecords = createMultipleRecords(2, 'aws:sqs'); + const nonSqsRecords = createMultipleRecords(1, 'aws:sns'); + const mixedEvent: SQSEvent = { + Records: [...sqsRecords, ...nonSqsRecords], + }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.resolve([])); + vi.mocked(scaleUp).mockImplementation(mock); + + await scaleUpHandler(mixedEvent, context); + expect(scaleUp).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ messageId: 'message-0' }), + expect.objectContaining({ messageId: 'message-1' }), + ]), + ); + expect(scaleUp).not.toHaveBeenCalledWith( + expect.arrayContaining([expect.objectContaining({ messageId: 'message-2' })]), + ); + }); + + it('Should sort messages by retry count', async () => { + const records = [ + { + ...sqsRecord, + messageId: 'high-retry', + body: JSON.stringify({ ...body, retryCounter: 5 }), + }, + { + ...sqsRecord, + messageId: 'low-retry', + body: JSON.stringify({ ...body, retryCounter: 1 }), + }, + { + ...sqsRecord, + messageId: 'no-retry', + body: JSON.stringify({ ...body }), + }, + ]; + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation((messages) => { + // Verify messages are sorted by retry count (ascending) + expect(messages[0].messageId).toBe('no-retry'); + expect(messages[1].messageId).toBe('low-retry'); + expect(messages[2].messageId).toBe('high-retry'); + return Promise.resolve([]); + }); + vi.mocked(scaleUp).mockImplementation(mock); + + await scaleUpHandler(multiRecordEvent, context); + }); + + it('Should return all failed messages when scaleUp throws non-ScaleError', async () => { + const records = createMultipleRecords(2); + const multiRecordEvent: SQSEvent = { Records: records }; + + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.reject(new Error('Generic error'))); + vi.mocked(scaleUp).mockImplementation(mock); + + const result = await scaleUpHandler(multiRecordEvent, context); + expect(result).toEqual({ batchItemFailures: [] }); + }); + + it('Should throw when scaleUp throws ScaleError', async () => { + const records = createMultipleRecords(2); + const multiRecordEvent: SQSEvent = { Records: records }; + + const error = new ScaleError('Critical scaling error'); + const mock = vi.fn(scaleUp); + mock.mockImplementation(() => Promise.reject(error)); + vi.mocked(scaleUp).mockImplementation(mock); + + await expect(scaleUpHandler(multiRecordEvent, context)).rejects.toThrow(error); }); }); - const sqsEventMultipleRecords: SQSEvent = { - Records: sqsRecords, - }; - - await expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow(); - - expect(logWarnSpy).toHaveBeenCalledWith( - expect.stringContaining( - 'Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.', - ), - ); -} +}); describe('Test scale down lambda wrapper.', () => { it('Scaling down no error.', async () => { diff --git a/lambdas/functions/control-plane/src/lambda.ts b/lambdas/functions/control-plane/src/lambda.ts index 3e3ab90557..266fdfc7a1 100644 --- a/lambdas/functions/control-plane/src/lambda.ts +++ b/lambdas/functions/control-plane/src/lambda.ts @@ -1,34 +1,72 @@ import middy from '@middy/core'; import { logger, setContext } from 
'@aws-github-runner/aws-powertools-util';
 import { captureLambdaHandler, tracer } from '@aws-github-runner/aws-powertools-util';
-import { Context, SQSEvent } from 'aws-lambda';
+import { Context, type SQSBatchItemFailure, type SQSBatchResponse, SQSEvent } from 'aws-lambda';
 import { PoolEvent, adjust } from './pool/pool';
 import ScaleError from './scale-runners/ScaleError';
 import { scaleDown } from './scale-runners/scale-down';
-import { scaleUp } from './scale-runners/scale-up';
+import { type ActionRequestMessage, type ActionRequestMessageSQS, scaleUp } from './scale-runners/scale-up';
 import { SSMCleanupOptions, cleanSSMTokens } from './scale-runners/ssm-housekeeper';
 import { checkAndRetryJob } from './scale-runners/job-retry';

-export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<void> {
+export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<SQSBatchResponse> {
   setContext(context, 'lambda.ts');
   logger.logEventIfEnabled(event);

-  if (event.Records.length !== 1) {
-    logger.warn('Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.');
-    return Promise.resolve();
+  // Group the messages by their event source. We're only interested in
+  // `aws:sqs`-originated messages.
+  const groupedEvents = new Map<string, ActionRequestMessageSQS[]>();
+  for (const { body, eventSource, messageId } of event.Records) {
+    const group = groupedEvents.get(eventSource) || [];
+    const payload = JSON.parse(body) as ActionRequestMessage;
+
+    if (group.length === 0) {
+      groupedEvents.set(eventSource, group);
+    }
+
+    groupedEvents.get(eventSource)?.push({
+      ...payload,
+      messageId,
+    });
+  }
+
+  for (const [eventSource, messages] of groupedEvents.entries()) {
+    if (eventSource === 'aws:sqs') {
+      continue;
+    }
+
+    logger.warn('Ignoring non-sqs event source', { eventSource, messages });
   }

+  const sqsMessages = groupedEvents.get('aws:sqs') ?? [];
+
+  // Sort messages by their retry count, so that we retry the same messages if
+  // there's a persistent failure. This should cause messages to be dropped
+  // quicker than if we retried in an arbitrary order.
+  sqsMessages.sort((l, r) => {
+    return (l.retryCounter ?? 0) - (r.retryCounter ?? 0);
+  });
+
+  const batchItemFailures: SQSBatchItemFailure[] = [];
+
   try {
-    await scaleUp(event.Records[0].eventSource, JSON.parse(event.Records[0].body));
-    return Promise.resolve();
+    const rejectedMessageIds = await scaleUp(sqsMessages);
+
+    for (const messageId of rejectedMessageIds) {
+      batchItemFailures.push({
+        itemIdentifier: messageId,
+      });
+    }
+
+    return { batchItemFailures };
   } catch (e) {
     if (e instanceof ScaleError) {
-      return Promise.reject(e);
-    } else {
-      logger.warn(`Ignoring error: ${e}`);
-      return Promise.resolve();
+      throw e;
     }
+
+    logger.warn(`Ignoring error: ${e}`);
+    return { batchItemFailures };
   }
 }
diff --git a/lambdas/functions/control-plane/src/local.ts b/lambdas/functions/control-plane/src/local.ts
index 2166da58fd..0b06335c8a 100644
--- a/lambdas/functions/control-plane/src/local.ts
+++ b/lambdas/functions/control-plane/src/local.ts
@@ -1,21 +1,21 @@
 import { logger } from '@aws-github-runner/aws-powertools-util';
-import { ActionRequestMessage, scaleUp } from './scale-runners/scale-up';
+import { scaleUpHandler } from './lambda';
+import { Context, SQSEvent } from 'aws-lambda';

-const sqsEvent = {
+const sqsEvent: SQSEvent = {
   Records: [
     {
       messageId: 'e8d74d08-644e-42ca-bf82-a67daa6c4dad',
       receiptHandle:
-        // eslint-disable-next-line max-len
         'AQEBCpLYzDEKq4aKSJyFQCkJduSKZef8SJVOperbYyNhXqqnpFG5k74WygVAJ4O0+9nybRyeOFThvITOaS21/jeHiI5fgaM9YKuI0oGYeWCIzPQsluW5CMDmtvqv1aA8sXQ5n2x0L9MJkzgdIHTC3YWBFLQ2AxSveOyIHwW+cHLIFCAcZlOaaf0YtaLfGHGkAC4IfycmaijV8NSlzYgDuxrC9sIsWJ0bSvk5iT4ru/R4+0cjm7qZtGlc04k9xk5Fu6A+wRxMaIyiFRY+Ya19ykcevQldidmEjEWvN6CRToLgclk=',
-      body: {
+      body: JSON.stringify({
         repositoryName: 'self-hosted',
         repositoryOwner: 'test-runners',
         eventType: 'workflow_job',
         id: 987654,
         installationId: 123456789,
-      },
+      }),
       attributes: {
         ApproximateReceiveCount: '1',
         SentTimestamp: '1626450047230',
@@ -34,12 +34,34 @@ const sqsEvent = {
   ],
 };

+const context: Context = {
+  awsRequestId: '1',
+  callbackWaitsForEmptyEventLoop: false,
+  functionName: '',
+  functionVersion: '',
+  getRemainingTimeInMillis: () => 0,
+  invokedFunctionArn: '',
+  logGroupName: '',
+  logStreamName: '',
+  memoryLimitInMB: '',
+  done: () => {
+    return;
+  },
+  fail: () => {
+    return;
+  },
+  succeed: () => {
+    return;
+  },
+};
+
 export function run(): void {
-  scaleUp(sqsEvent.Records[0].eventSource, sqsEvent.Records[0].body as ActionRequestMessage)
-    .then()
-    .catch((e) => {
-      logger.error(e);
-    });
+  try {
+    scaleUpHandler(sqsEvent, context);
+  } catch (e: unknown) {
+    const message = e instanceof Error ? e.message : `${e}`;
+    logger.error(message, e instanceof Error ?
{ error: e } : {}); + } } run(); diff --git a/lambdas/functions/control-plane/src/pool/pool.test.ts b/lambdas/functions/control-plane/src/pool/pool.test.ts index 3a7ba3ab1c..76c3d61d4c 100644 --- a/lambdas/functions/control-plane/src/pool/pool.test.ts +++ b/lambdas/functions/control-plane/src/pool/pool.test.ts @@ -188,11 +188,7 @@ describe('Test simple pool.', () => { it('Top up pool with pool size 2 registered.', async () => { await adjust({ poolSize: 3 }); expect(createRunners).toHaveBeenCalledTimes(1); - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 1 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 1, expect.anything()); }); it('Should not top up if pool size is reached.', async () => { @@ -268,11 +264,7 @@ describe('Test simple pool.', () => { it('Top up if the pool size is set to 5', async () => { await adjust({ poolSize: 5 }); // 2 idle, top up with 3 to match a pool of 5 - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 3 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 3, expect.anything()); }); }); @@ -287,11 +279,7 @@ describe('Test simple pool.', () => { it('Top up if the pool size is set to 5', async () => { await adjust({ poolSize: 5 }); // 2 idle, top up with 3 to match a pool of 5 - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 3 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 3, expect.anything()); }); }); @@ -341,11 +329,7 @@ describe('Test simple pool.', () => { await adjust({ poolSize: 5 }); // 2 idle, 2 prefixed idle top up with 1 to match a pool of 5 - expect(createRunners).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ numberOfRunners: 1 }), - expect.anything(), - ); + expect(createRunners).toHaveBeenCalledWith(expect.anything(), expect.anything(), 1, expect.anything()); }); }); }); diff --git a/lambdas/functions/control-plane/src/pool/pool.ts b/lambdas/functions/control-plane/src/pool/pool.ts index 162a7d0f6d..a333c8da61 100644 --- a/lambdas/functions/control-plane/src/pool/pool.ts +++ b/lambdas/functions/control-plane/src/pool/pool.ts @@ -92,11 +92,11 @@ export async function adjust(event: PoolEvent): Promise { environment, launchTemplateName, subnets, - numberOfRunners: topUp, amiIdSsmParameterName, tracingEnabled, onDemandFailoverOnError, }, + topUp, githubInstallationClient, ); } else { diff --git a/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts b/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts index 1edfefb69b..3ee24fb6d4 100644 --- a/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts +++ b/lambdas/functions/control-plane/src/scale-runners/job-retry.test.ts @@ -2,9 +2,11 @@ import { publishMessage } from '../aws/sqs'; import { publishRetryMessage, checkAndRetryJob } from './job-retry'; import { ActionRequestMessage, ActionRequestMessageRetry } from './scale-up'; import { getOctokit } from '../github/octokit'; +import { jobRetryCheck } from '../lambda'; import { Octokit } from '@octokit/rest'; import { createSingleMetric } from '@aws-github-runner/aws-powertools-util'; import { describe, it, expect, beforeEach, vi } from 'vitest'; +import type { SQSRecord } from 'aws-lambda'; vi.mock('../aws/sqs', async () 
=> ({ publishMessage: vi.fn(), @@ -267,3 +269,93 @@ describe(`Test job retry check`, () => { expect(publishMessage).not.toHaveBeenCalled(); }); }); + +describe('Test job retry handler (batch processing)', () => { + const context = { + requestId: 'request-id', + functionName: 'function-name', + functionVersion: 'function-version', + invokedFunctionArn: 'invoked-function-arn', + memoryLimitInMB: '128', + awsRequestId: 'aws-request-id', + logGroupName: 'log-group-name', + logStreamName: 'log-stream-name', + remainingTimeInMillis: () => 30000, + done: () => {}, + fail: () => {}, + succeed: () => {}, + getRemainingTimeInMillis: () => 30000, + callbackWaitsForEmptyEventLoop: false, + }; + + function createSQSRecord(messageId: string): SQSRecord { + return { + messageId, + receiptHandle: 'receipt-handle', + body: JSON.stringify({ + eventType: 'workflow_job', + id: 123, + installationId: 456, + repositoryName: 'test-repo', + repositoryOwner: 'test-owner', + repoOwnerType: 'Organization', + retryCounter: 0, + }), + attributes: { + ApproximateReceiveCount: '1', + SentTimestamp: '1234567890', + SenderId: 'sender-id', + ApproximateFirstReceiveTimestamp: '1234567891', + }, + messageAttributes: {}, + md5OfBody: 'md5', + eventSource: 'aws:sqs', + eventSourceARN: 'arn:aws:sqs:region:account:queue', + awsRegion: 'us-east-1', + }; + } + + beforeEach(() => { + vi.clearAllMocks(); + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.JOB_QUEUE_SCALE_UP_URL = 'https://sqs.example.com/queue'; + }); + + it('should handle multiple records in a single batch', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ + data: { + status: 'queued', + }, + headers: {}, + })); + + const event = { + Records: [createSQSRecord('msg-1'), createSQSRecord('msg-2'), createSQSRecord('msg-3')], + }; + + await expect(jobRetryCheck(event, context)).resolves.not.toThrow(); + expect(publishMessage).toHaveBeenCalledTimes(3); + }); + + it('should continue processing other records when one fails', async () => { + mockCreateOctokitClient + .mockResolvedValueOnce(new Octokit()) // First record succeeds + .mockRejectedValueOnce(new Error('API error')) // Second record fails + .mockResolvedValueOnce(new Octokit()); // Third record succeeds + + mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ + data: { + status: 'queued', + }, + headers: {}, + })); + + const event = { + Records: [createSQSRecord('msg-1'), createSQSRecord('msg-2'), createSQSRecord('msg-3')], + }; + + await expect(jobRetryCheck(event, context)).resolves.not.toThrow(); + // There were two successful calls to publishMessage + expect(publishMessage).toHaveBeenCalledTimes(2); + }); +}); diff --git a/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts b/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts index 0611a6e697..3817bf55cb 100644 --- a/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts +++ b/lambdas/functions/control-plane/src/scale-runners/scale-up.test.ts @@ -1,5 +1,4 @@ import { PutParameterCommand, SSMClient } from '@aws-sdk/client-ssm'; -import { Octokit } from '@octokit/rest'; import { mockClient } from 'aws-sdk-client-mock'; import 'aws-sdk-client-mock-jest/vitest'; // Using vi.mocked instead of jest-mock @@ -9,10 +8,10 @@ import { performance } from 'perf_hooks'; import * as ghAuth from '../github/auth'; import { createRunner, listEC2Runners } from './../aws/runners'; import { RunnerInputParameters } from './../aws/runners.d'; -import ScaleError from 
'./ScaleError'; import * as scaleUpModule from './scale-up'; import { getParameter } from '@aws-github-runner/aws-ssm-util'; import { describe, it, expect, beforeEach, vi } from 'vitest'; +import type { Octokit } from '@octokit/rest'; const mockOctokit = { paginate: vi.fn(), @@ -29,6 +28,7 @@ const mockOctokit = { getRepoInstallation: vi.fn(), }, }; + const mockCreateRunner = vi.mocked(createRunner); const mockListRunners = vi.mocked(listEC2Runners); const mockSSMClient = mockClient(SSMClient); @@ -65,26 +65,33 @@ export type RunnerType = 'ephemeral' | 'non-ephemeral'; // for ephemeral and non-ephemeral runners const RUNNER_TYPES: RunnerType[] = ['ephemeral', 'non-ephemeral']; -const mocktokit = Octokit as vi.MockedClass; const mockedAppAuth = vi.mocked(ghAuth.createGithubAppAuth); const mockedInstallationAuth = vi.mocked(ghAuth.createGithubInstallationAuth); const mockCreateClient = vi.mocked(ghAuth.createOctokitClient); -const TEST_DATA: scaleUpModule.ActionRequestMessage = { +const TEST_DATA_SINGLE: scaleUpModule.ActionRequestMessageSQS = { id: 1, eventType: 'workflow_job', repositoryName: 'hello-world', repositoryOwner: 'Codertocat', installationId: 2, repoOwnerType: 'Organization', + messageId: 'foobar', }; +const TEST_DATA: scaleUpModule.ActionRequestMessageSQS[] = [ + { + ...TEST_DATA_SINGLE, + messageId: 'foobar', + }, +]; + const cleanEnv = process.env; const EXPECTED_RUNNER_PARAMS: RunnerInputParameters = { environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, numberOfRunners: 1, launchTemplateName: 'lt-1', ec2instanceCriteria: { @@ -131,14 +138,14 @@ beforeEach(() => { instanceId: 'i-1234', launchTime: new Date(), type: 'Org', - owner: TEST_DATA.repositoryOwner, + owner: TEST_DATA_SINGLE.repositoryOwner, }, ]); mockedAppAuth.mockResolvedValue({ type: 'app', token: 'token', - appId: TEST_DATA.installationId, + appId: TEST_DATA_SINGLE.installationId, expiresAt: 'some-date', }); mockedInstallationAuth.mockResolvedValue({ @@ -152,7 +159,7 @@ beforeEach(() => { installationId: 0, }); - mockCreateClient.mockResolvedValue(new mocktokit()); + mockCreateClient.mockResolvedValue(mockOctokit as unknown as Octokit); }); describe('scaleUp with GHES', () => { @@ -160,17 +167,12 @@ describe('scaleUp with GHES', () => { process.env.GHES_URL = 'https://github.enterprise.something'; }); - it('ignores non-sqs events', async () => { - expect.assertions(1); - await expect(scaleUpModule.scaleUp('aws:s3', TEST_DATA)).rejects.toEqual(Error('Cannot handle non-SQS events!')); - }); - it('checks queued workflows', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalledWith({ - job_id: TEST_DATA.id, - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + job_id: TEST_DATA_SINGLE.id, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); @@ -178,7 +180,7 @@ describe('scaleUp with GHES', () => { mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ data: { total_count: 0 }, })); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toBeCalled(); }); @@ -197,18 +199,18 @@ describe('scaleUp with GHES', () => { }); it('gets the current org level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); 
expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); @@ -216,35 +218,35 @@ describe('scaleUp with GHES', () => { it('does create a runner if maximum is set to -1', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '-1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toHaveBeenCalled(); expect(createRunner).toHaveBeenCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, }); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a runner with correct config', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with labels in a specific group', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with ami id override from ssm parameter', async () => { process.env.AMI_ID_SSM_PARAMETER_NAME = 'my-ami-id-param'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith({ ...expectedRunnerParams, amiIdSsmParameterName: 'my-ami-id-param' }); }); @@ -253,15 +255,15 @@ describe('scaleUp with GHES', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toBeInstanceOf(Error); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toBeInstanceOf(Error); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); }); it('Discards event if it is a User repo and org level runners is enabled', async () => { process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; - const USER_REPO_TEST_DATA = { ...TEST_DATA }; - USER_REPO_TEST_DATA.repoOwnerType = 'User'; - await scaleUpModule.scaleUp('aws:sqs', USER_REPO_TEST_DATA); + const USER_REPO_TEST_DATA = structuredClone(TEST_DATA); + USER_REPO_TEST_DATA[0].repoOwnerType = 'User'; + await scaleUpModule.scaleUp(USER_REPO_TEST_DATA); expect(createRunner).not.toHaveBeenCalled(); }); @@ -269,7 +271,7 @@ describe('scaleUp with GHES', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); 
expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 2); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -280,7 +282,7 @@ describe('scaleUp with GHES', () => { }); it('Does not create SSM parameter for runner group id if it exists', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.paginate).toHaveBeenCalledTimes(0); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 1); }); @@ -288,9 +290,9 @@ describe('scaleUp with GHES', () => { it('create start runner config for ephemeral runners ', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, name: 'unit-test-i-12345', runner_group_id: 1, labels: ['label1', 'label2'], @@ -311,7 +313,7 @@ describe('scaleUp with GHES', () => { it('create start runner config for non-ephemeral runners ', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalled(); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -382,7 +384,7 @@ describe('scaleUp with GHES', () => { 'i-150', 'i-151', ]; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); const endTime = performance.now(); expect(endTime - startTime).toBeGreaterThan(1000); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 40); @@ -396,87 +398,307 @@ describe('scaleUp with GHES', () => { process.env.RUNNER_NAME_PREFIX = 'unit-test'; expectedRunnerParams = { ...EXPECTED_RUNNER_PARAMS }; expectedRunnerParams.runnerType = 'Repo'; - expectedRunnerParams.runnerOwner = `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`; - // `--url https://github.enterprise.something/${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + expectedRunnerParams.runnerOwner = `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`; + // `--url https://github.enterprise.something/${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, // `--token 1234abcd`, // ]; }); it('gets the current repo level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Repo', - runnerOwner: `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + runnerOwner: `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); 
+ await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('uses the default runner max count', async () => { process.env.RUNNERS_MAXIMUM_COUNT = undefined; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('creates a runner with correct config and labels', async () => { process.env.RUNNER_LABELS = 'label1,label2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner and ensure the group argument is ignored', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP_IGNORED'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('Check error is thrown', async () => { const mockCreateRunners = vi.mocked(createRunner); mockCreateRunners.mockRejectedValue(new Error('no retry')); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toThrow('no retry'); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toThrow('no retry'); mockCreateRunners.mockReset(); }); }); -}); -describe('scaleUp with public GH', () => { - it('ignores non-sqs events', async () => { - expect.assertions(1); - await expect(scaleUpModule.scaleUp('aws:s3', TEST_DATA)).rejects.toEqual(Error('Cannot handle non-SQS events!')); + describe('Batch processing', () => { + beforeEach(() => { + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + process.env.RUNNERS_MAXIMUM_COUNT = '10'; + }); + + const createTestMessages = ( + count: number, + overrides: Partial[] = [], + ): scaleUpModule.ActionRequestMessageSQS[] => { + return Array.from({ length: count }, (_, i) => ({ + ...TEST_DATA_SINGLE, + id: i + 1, + messageId: `message-${i}`, + ...overrides[i], + })); + }; + + it('Should handle multiple messages for the same organization', async () => { + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(1); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 3, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, + }), + ); + }); + + it('Should handle multiple messages for different organizations', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'org1' }, + { repositoryOwner: 'org2' }, + { repositoryOwner: 'org1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'org1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'org2', + }), + ); + }); + + it('Should handle multiple messages for different repositories when org-level is disabled', async () => { + 
process.env.ENABLE_ORGANIZATION_RUNNERS = 'false'; + const messages = createTestMessages(3, [ + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + { repositoryOwner: 'owner1', repositoryName: 'repo2' }, + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'owner1/repo1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'owner1/repo2', + }), + ); + }); + + it('Should reject messages when maximum runners limit is reached', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '1'; // Set to 1 so with 1 existing, no new ones can be created + mockListRunners.mockImplementation(async () => [ + { + instanceId: 'i-existing', + launchTime: new Date(), + type: 'Org', + owner: TEST_DATA_SINGLE.repositoryOwner, + }, + ]); + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); // No runners should be created + expect(rejectedMessages).toHaveLength(3); // All 3 messages should be rejected + }); + + it('Should handle partial EC2 instance creation failures', async () => { + mockCreateRunner.mockImplementation(async () => ['i-12345']); // Only creates 1 instead of requested 3 + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(rejectedMessages).toHaveLength(2); // 3 requested - 1 created = 2 failed + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should filter out invalid event types for ephemeral runners', async () => { + const messages = createTestMessages(3, [ + { eventType: 'workflow_job' }, + { eventType: 'check_run' }, + { eventType: 'workflow_job' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only workflow_job events processed + }), + ); + expect(rejectedMessages).toContain('message-1'); // check_run event rejected + }); + + it('Should skip invalid repo owner types but not reject them', async () => { + const messages = createTestMessages(3, [ + { repoOwnerType: 'Organization' }, + { repoOwnerType: 'User' }, // Invalid for org-level runners + { repoOwnerType: 'Organization' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only Organization events processed + }), + ); + expect(rejectedMessages).not.toContain('message-1'); // User repo not rejected, just skipped + }); + + it('Should skip messages when jobs are not queued', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation((params) => { + const isQueued = params.job_id === 1 || params.job_id === 3; // Only jobs 1 and 3 are queued + return { + data: { + status: isQueued ? 
'queued' : 'completed', + }, + }; + }); + + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only queued jobs processed + }), + ); + }); + + it('Should create separate GitHub clients for different installations', async () => { + // Override the default mock to return different installation IDs + mockOctokit.apps.getOrgInstallation.mockReset(); + mockOctokit.apps.getOrgInstallation.mockImplementation((params) => ({ + data: { + id: params.org === 'org1' ? 100 : 200, + }, + })); + + const messages = createTestMessages(2, [ + { repositoryOwner: 'org1', installationId: 0 }, + { repositoryOwner: 'org2', installationId: 0 }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(3); // 1 app client, 2 repo installation clients + expect(mockedInstallationAuth).toHaveBeenCalledWith(100, 'https://github.enterprise.something/api/v3'); + expect(mockedInstallationAuth).toHaveBeenCalledWith(200, 'https://github.enterprise.something/api/v3'); + }); + + it('Should reuse GitHub clients for same installation', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(2); // 1 app client, 1 installation client + expect(mockedInstallationAuth).toHaveBeenCalledTimes(1); + }); + + it('Should return empty array when no valid messages to process', async () => { + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + const messages = createTestMessages(2, [ + { eventType: 'check_run' }, // Invalid for ephemeral + { eventType: 'check_run' }, // Invalid for ephemeral + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should handle unlimited runners configuration', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '-1'; + const messages = createTestMessages(10); + + await scaleUpModule.scaleUp(messages); + + expect(listEC2Runners).not.toHaveBeenCalled(); // No need to check current runners + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 10, // All messages processed + }), + ); + }); }); +}); +describe('scaleUp with public GH', () => { it('checks queued workflows', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalledWith({ - job_id: TEST_DATA.id, - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + job_id: TEST_DATA_SINGLE.id, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('not checking queued workflows', async () => { process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); }); @@ -484,7 +706,7 @@ describe('scaleUp with public GH', () => { mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ data: { status: 'completed' }, })); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toBeCalled(); }); @@ -496,38 +718,38 @@ describe('scaleUp with 
public GH', () => { }); it('gets the current org level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, }); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a runner with correct config', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with labels in s specific group', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); }); @@ -540,44 +762,44 @@ describe('scaleUp with public GH', () => { process.env.RUNNER_NAME_PREFIX = 'unit-test'; expectedRunnerParams = { ...EXPECTED_RUNNER_PARAMS }; expectedRunnerParams.runnerType = 'Repo'; - expectedRunnerParams.runnerOwner = `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`; + expectedRunnerParams.runnerOwner = `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`; }); it('gets the current repo level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Repo', - runnerOwner: `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + runnerOwner: `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('creates a runner with correct config and labels', async () => { process.env.RUNNER_LABELS = 'label1,label2'; - await scaleUpModule.scaleUp('aws:sqs', 
TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with correct config and labels and on demand failover enabled.', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.ENABLE_ON_DEMAND_FAILOVER_FOR_ERRORS = JSON.stringify(['InsufficientInstanceCapacity']); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith({ ...expectedRunnerParams, onDemandFailoverOnError: ['InsufficientInstanceCapacity'], @@ -587,26 +809,25 @@ describe('scaleUp with public GH', () => { it('creates a runner and ensure the group argument is ignored', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP_IGNORED'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('ephemeral runners only run with workflow_job event, others should fail.', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; - await expect( - scaleUpModule.scaleUp('aws:sqs', { - ...TEST_DATA, - eventType: 'check_run', - }), - ).rejects.toBeInstanceOf(Error); + + const USER_REPO_TEST_DATA = structuredClone(TEST_DATA); + USER_REPO_TEST_DATA[0].eventType = 'check_run'; + + await expect(scaleUpModule.scaleUp(USER_REPO_TEST_DATA)).resolves.toEqual(['foobar']); }); it('creates a ephemeral runner with JIT config.', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; process.env.SSM_TOKEN_PATH = '/github-action-runners/default/runners/config'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); @@ -628,7 +849,7 @@ describe('scaleUp with public GH', () => { process.env.ENABLE_JIT_CONFIG = 'false'; process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; process.env.SSM_TOKEN_PATH = '/github-action-runners/default/runners/config'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); @@ -651,7 +872,7 @@ describe('scaleUp with public GH', () => { process.env.ENABLE_JOB_QUEUED_CHECK = 'false'; process.env.RUNNER_LABELS = 'jit'; process.env.SSM_TOKEN_PATH = '/github-action-runners/default/runners/config'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).not.toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); @@ -671,21 +892,247 @@ describe('scaleUp with public GH', () => { it('creates a ephemeral runner after checking job is queued.', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; process.env.ENABLE_JOB_QUEUED_CHECK = 'true'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalled(); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('disable auto update on the runner.', async () => { process.env.DISABLE_RUNNER_AUTOUPDATE = 'true'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); 
expect(createRunner).toBeCalledWith(expectedRunnerParams); }); - it('Scaling error should cause reject so retry can be triggered.', async () => { + it('Scaling error should return failed message IDs so retry can be triggered.', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toBeInstanceOf(ScaleError); + await expect(scaleUpModule.scaleUp(TEST_DATA)).resolves.toEqual(['foobar']); + }); + }); + + describe('Batch processing', () => { + const createTestMessages = ( + count: number, + overrides: Partial[] = [], + ): scaleUpModule.ActionRequestMessageSQS[] => { + return Array.from({ length: count }, (_, i) => ({ + ...TEST_DATA_SINGLE, + id: i + 1, + messageId: `message-${i}`, + ...overrides[i], + })); + }; + + beforeEach(() => { + setDefaults(); + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + process.env.RUNNERS_MAXIMUM_COUNT = '10'; + }); + + it('Should handle multiple messages for the same organization', async () => { + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(1); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 3, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, + }), + ); + }); + + it('Should handle multiple messages for different organizations', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'org1' }, + { repositoryOwner: 'org2' }, + { repositoryOwner: 'org1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'org1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'org2', + }), + ); + }); + + it('Should handle multiple messages for different repositories when org-level is disabled', async () => { + process.env.ENABLE_ORGANIZATION_RUNNERS = 'false'; + const messages = createTestMessages(3, [ + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + { repositoryOwner: 'owner1', repositoryName: 'repo2' }, + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'owner1/repo1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'owner1/repo2', + }), + ); + }); + + it('Should reject messages when maximum runners limit is reached', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '1'; // Set to 1 so with 1 existing, no new ones can be created + mockListRunners.mockImplementation(async () => [ + { + instanceId: 'i-existing', + launchTime: new Date(), + type: 'Org', + owner: TEST_DATA_SINGLE.repositoryOwner, + }, + ]); + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); // No runners should be created + expect(rejectedMessages).toHaveLength(3); // All 3 messages should be rejected + }); + + it('Should handle partial EC2 instance creation failures', async () => { + mockCreateRunner.mockImplementation(async () => ['i-12345']); // Only creates 1 instead 
of requested 3 + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(rejectedMessages).toHaveLength(2); // 3 requested - 1 created = 2 failed + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should filter out invalid event types for ephemeral runners', async () => { + const messages = createTestMessages(3, [ + { eventType: 'workflow_job' }, + { eventType: 'check_run' }, + { eventType: 'workflow_job' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only workflow_job events processed + }), + ); + expect(rejectedMessages).toContain('message-1'); // check_run event rejected + }); + + it('Should skip invalid repo owner types but not reject them', async () => { + const messages = createTestMessages(3, [ + { repoOwnerType: 'Organization' }, + { repoOwnerType: 'User' }, // Invalid for org-level runners + { repoOwnerType: 'Organization' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only Organization events processed + }), + ); + expect(rejectedMessages).not.toContain('message-1'); // User repo not rejected, just skipped + }); + + it('Should skip messages when jobs are not queued', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation((params) => { + const isQueued = params.job_id === 1 || params.job_id === 3; // Only jobs 1 and 3 are queued + return { + data: { + status: isQueued ? 'queued' : 'completed', + }, + }; + }); + + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only queued jobs processed + }), + ); + }); + + it('Should create separate GitHub clients for different installations', async () => { + // Override the default mock to return different installation IDs + mockOctokit.apps.getOrgInstallation.mockReset(); + mockOctokit.apps.getOrgInstallation.mockImplementation((params) => ({ + data: { + id: params.org === 'org1' ? 
100 : 200, + }, + })); + + const messages = createTestMessages(2, [ + { repositoryOwner: 'org1', installationId: 0 }, + { repositoryOwner: 'org2', installationId: 0 }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(3); // 1 app client, 2 repo installation clients + expect(mockedInstallationAuth).toHaveBeenCalledWith(100, ''); + expect(mockedInstallationAuth).toHaveBeenCalledWith(200, ''); + }); + + it('Should reuse GitHub clients for same installation', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(2); // 1 app client, 1 installation client + expect(mockedInstallationAuth).toHaveBeenCalledTimes(1); + }); + + it('Should return empty array when no valid messages to process', async () => { + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + const messages = createTestMessages(2, [ + { eventType: 'check_run' }, // Invalid for ephemeral + { eventType: 'check_run' }, // Invalid for ephemeral + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should handle unlimited runners configuration', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '-1'; + const messages = createTestMessages(10); + + await scaleUpModule.scaleUp(messages); + + expect(listEC2Runners).not.toHaveBeenCalled(); // No need to check current runners + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 10, // All messages processed + }), + ); }); }); }); @@ -695,17 +1142,12 @@ describe('scaleUp with Github Data Residency', () => { process.env.GHES_URL = 'https://companyname.ghe.com'; }); - it('ignores non-sqs events', async () => { - expect.assertions(1); - await expect(scaleUpModule.scaleUp('aws:s3', TEST_DATA)).rejects.toEqual(Error('Cannot handle non-SQS events!')); - }); - it('checks queued workflows', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.getJobForWorkflowRun).toBeCalledWith({ - job_id: TEST_DATA.id, - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + job_id: TEST_DATA_SINGLE.id, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); @@ -713,7 +1155,7 @@ describe('scaleUp with Github Data Residency', () => { mockOctokit.actions.getJobForWorkflowRun.mockImplementation(() => ({ data: { total_count: 0 }, })); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toBeCalled(); }); @@ -732,18 +1174,18 @@ describe('scaleUp with Github Data Residency', () => { }); it('gets the current org level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Org', - runnerOwner: TEST_DATA.repositoryOwner, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); 
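// With the cap already reached, no registration token should be requested for either scope.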
expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); @@ -751,35 +1193,35 @@ describe('scaleUp with Github Data Residency', () => { it('does create a runner if maximum is set to -1', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '-1'; process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).not.toHaveBeenCalled(); expect(createRunner).toHaveBeenCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, }); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a runner with correct config', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with labels in a specific group', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner with ami id override from ssm parameter', async () => { process.env.AMI_ID_SSM_PARAMETER_NAME = 'my-ami-id-param'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith({ ...expectedRunnerParams, amiIdSsmParameterName: 'my-ami-id-param' }); }); @@ -788,15 +1230,15 @@ describe('scaleUp with Github Data Residency', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toBeInstanceOf(Error); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toBeInstanceOf(Error); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); }); it('Discards event if it is a User repo and org level runners is enabled', async () => { process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; - const USER_REPO_TEST_DATA = { ...TEST_DATA }; - USER_REPO_TEST_DATA.repoOwnerType = 'User'; - await scaleUpModule.scaleUp('aws:sqs', USER_REPO_TEST_DATA); + const USER_REPO_TEST_DATA = structuredClone(TEST_DATA); + USER_REPO_TEST_DATA[0].repoOwnerType = 'User'; + await scaleUpModule.scaleUp(USER_REPO_TEST_DATA); expect(createRunner).not.toHaveBeenCalled(); }); @@ -804,7 +1246,7 @@ describe('scaleUp with Github Data Residency', () => { mockSSMgetParameter.mockImplementation(async () => { throw new Error('ParameterNotFound'); }); - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.paginate).toHaveBeenCalledTimes(1); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 2); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -815,7 +1257,7 @@ describe('scaleUp with Github Data Residency', () => { }); it('Does not create SSM parameter for runner group id if it exists', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); 
expect(mockOctokit.paginate).toHaveBeenCalledTimes(0); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 1); }); @@ -823,9 +1265,9 @@ describe('scaleUp with Github Data Residency', () => { it('create start runner config for ephemeral runners ', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).toBeCalledWith({ - org: TEST_DATA.repositoryOwner, + org: TEST_DATA_SINGLE.repositoryOwner, name: 'unit-test-i-12345', runner_group_id: 1, labels: ['label1', 'label2'], @@ -846,7 +1288,7 @@ describe('scaleUp with Github Data Residency', () => { it('create start runner config for non-ephemeral runners ', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; process.env.RUNNERS_MAXIMUM_COUNT = '2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.generateRunnerJitconfigForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForOrg).toBeCalled(); expect(mockSSMClient).toHaveReceivedNthSpecificCommandWith(1, PutParameterCommand, { @@ -917,7 +1359,7 @@ describe('scaleUp with Github Data Residency', () => { 'i-150', 'i-151', ]; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); const endTime = performance.now(); expect(endTime - startTime).toBeGreaterThan(1000); expect(mockSSMClient).toHaveReceivedCommandTimes(PutParameterCommand, 40); @@ -931,67 +1373,295 @@ describe('scaleUp with Github Data Residency', () => { process.env.RUNNER_NAME_PREFIX = 'unit-test'; expectedRunnerParams = { ...EXPECTED_RUNNER_PARAMS }; expectedRunnerParams.runnerType = 'Repo'; - expectedRunnerParams.runnerOwner = `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`; - // `--url https://companyname.ghe.com${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + expectedRunnerParams.runnerOwner = `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`; + // `--url https://companyname.ghe.com${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, // `--token 1234abcd`, // ]; }); it('gets the current repo level runners', async () => { - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(listEC2Runners).toBeCalledWith({ environment: 'unit-test-environment', runnerType: 'Repo', - runnerOwner: `${TEST_DATA.repositoryOwner}/${TEST_DATA.repositoryName}`, + runnerOwner: `${TEST_DATA_SINGLE.repositoryOwner}/${TEST_DATA_SINGLE.repositoryName}`, }); }); it('does not create a token when maximum runners has been reached', async () => { process.env.RUNNERS_MAXIMUM_COUNT = '1'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).not.toBeCalled(); }); it('creates a token when maximum runners has not been reached', async () => { process.env.ENABLE_EPHEMERAL_RUNNERS = 'false'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForOrg).not.toBeCalled(); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: 
TEST_DATA_SINGLE.repositoryName, }); }); it('uses the default runner max count', async () => { process.env.RUNNERS_MAXIMUM_COUNT = undefined; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(mockOctokit.actions.createRegistrationTokenForRepo).toBeCalledWith({ - owner: TEST_DATA.repositoryOwner, - repo: TEST_DATA.repositoryName, + owner: TEST_DATA_SINGLE.repositoryOwner, + repo: TEST_DATA_SINGLE.repositoryName, }); }); it('creates a runner with correct config and labels', async () => { process.env.RUNNER_LABELS = 'label1,label2'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('creates a runner and ensure the group argument is ignored', async () => { process.env.RUNNER_LABELS = 'label1,label2'; process.env.RUNNER_GROUP_NAME = 'TEST_GROUP_IGNORED'; - await scaleUpModule.scaleUp('aws:sqs', TEST_DATA); + await scaleUpModule.scaleUp(TEST_DATA); expect(createRunner).toBeCalledWith(expectedRunnerParams); }); it('Check error is thrown', async () => { const mockCreateRunners = vi.mocked(createRunner); mockCreateRunners.mockRejectedValue(new Error('no retry')); - await expect(scaleUpModule.scaleUp('aws:sqs', TEST_DATA)).rejects.toThrow('no retry'); + await expect(scaleUpModule.scaleUp(TEST_DATA)).rejects.toThrow('no retry'); mockCreateRunners.mockReset(); }); }); + + describe('Batch processing', () => { + const createTestMessages = ( + count: number, + overrides: Partial<scaleUpModule.ActionRequestMessageSQS>[] = [], + ): scaleUpModule.ActionRequestMessageSQS[] => { + return Array.from({ length: count }, (_, i) => ({ + ...TEST_DATA_SINGLE, + id: i + 1, + messageId: `message-${i}`, + ...overrides[i], + })); + }; + + beforeEach(() => { + setDefaults(); + process.env.ENABLE_ORGANIZATION_RUNNERS = 'true'; + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + process.env.RUNNERS_MAXIMUM_COUNT = '10'; + }); + + it('Should handle multiple messages for the same organization', async () => { + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(1); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 3, + runnerOwner: TEST_DATA_SINGLE.repositoryOwner, + }), + ); + }); + + it('Should handle multiple messages for different organizations', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'org1' }, + { repositoryOwner: 'org2' }, + { repositoryOwner: 'org1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'org1', + }), + ); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'org2', + }), + ); + }); + + it('Should handle multiple messages for different repositories when org-level is disabled', async () => { + process.env.ENABLE_ORGANIZATION_RUNNERS = 'false'; + const messages = createTestMessages(3, [ + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + { repositoryOwner: 'owner1', repositoryName: 'repo2' }, + { repositoryOwner: 'owner1', repositoryName: 'repo1' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledTimes(2); + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, + runnerOwner: 'owner1/repo1', + }), + ); +
expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, + runnerOwner: 'owner1/repo2', + }), + ); + }); + + it('Should reject messages when maximum runners limit is reached', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '2'; + mockListRunners.mockImplementation(async () => [ + { + instanceId: 'i-existing', + launchTime: new Date(), + type: 'Org', + owner: TEST_DATA_SINGLE.repositoryOwner, + }, + ]); + + const messages = createTestMessages(5); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 1, // 2 max - 1 existing = 1 new + }), + ); + expect(rejectedMessages).toHaveLength(4); // 5 requested - 1 created = 4 rejected + }); + + it('Should handle partial EC2 instance creation failures', async () => { + mockCreateRunner.mockImplementation(async () => ['i-12345']); // Only creates 1 instead of requested 3 + + const messages = createTestMessages(3); + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(rejectedMessages).toHaveLength(2); // 3 requested - 1 created = 2 failed + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should filter out invalid event types for ephemeral runners', async () => { + const messages = createTestMessages(3, [ + { eventType: 'workflow_job' }, + { eventType: 'check_run' }, + { eventType: 'workflow_job' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only workflow_job events processed + }), + ); + expect(rejectedMessages).toContain('message-1'); // check_run event rejected + }); + + it('Should skip invalid repo owner types but not reject them', async () => { + const messages = createTestMessages(3, [ + { repoOwnerType: 'Organization' }, + { repoOwnerType: 'User' }, // Invalid for org-level runners + { repoOwnerType: 'Organization' }, + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only Organization events processed + }), + ); + expect(rejectedMessages).not.toContain('message-1'); // User repo not rejected, just skipped + }); + + it('Should skip messages when jobs are not queued', async () => { + mockOctokit.actions.getJobForWorkflowRun.mockImplementation((params) => { + const isQueued = params.job_id === 1 || params.job_id === 3; // Only jobs 1 and 3 are queued + return { + data: { + status: isQueued ? 'queued' : 'completed', + }, + }; + }); + + const messages = createTestMessages(3); + await scaleUpModule.scaleUp(messages); + + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 2, // Only queued jobs processed + }), + ); + }); + + it('Should create separate GitHub clients for different installations', async () => { + mockOctokit.apps.getOrgInstallation.mockImplementation((params) => ({ + data: { + id: params.org === 'org1' ? 
100 : 200, + }, + })); + + const messages = createTestMessages(2, [ + { repositoryOwner: 'org1', installationId: 0 }, + { repositoryOwner: 'org2', installationId: 0 }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(3); // 1 app client, 2 repo installation clients + expect(mockedInstallationAuth).toHaveBeenCalledWith(100, ''); + expect(mockedInstallationAuth).toHaveBeenCalledWith(200, ''); + }); + + it('Should reuse GitHub clients for same installation', async () => { + const messages = createTestMessages(3, [ + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + { repositoryOwner: 'same-org' }, + ]); + + await scaleUpModule.scaleUp(messages); + + expect(mockCreateClient).toHaveBeenCalledTimes(2); // 1 app client, 1 installation client + expect(mockedInstallationAuth).toHaveBeenCalledTimes(1); + }); + + it('Should return empty array when no valid messages to process', async () => { + process.env.ENABLE_EPHEMERAL_RUNNERS = 'true'; + const messages = createTestMessages(2, [ + { eventType: 'check_run' }, // Invalid for ephemeral + { eventType: 'check_run' }, // Invalid for ephemeral + ]); + + const rejectedMessages = await scaleUpModule.scaleUp(messages); + + expect(createRunner).not.toHaveBeenCalled(); + expect(rejectedMessages).toEqual(['message-0', 'message-1']); + }); + + it('Should handle unlimited runners configuration', async () => { + process.env.RUNNERS_MAXIMUM_COUNT = '-1'; + const messages = createTestMessages(10); + + await scaleUpModule.scaleUp(messages); + + expect(listEC2Runners).not.toHaveBeenCalled(); // No need to check current runners + expect(createRunner).toHaveBeenCalledWith( + expect.objectContaining({ + numberOfRunners: 10, // All messages processed + }), + ); + }); + }); }); function defaultOctokitMockImpl() { @@ -1029,12 +1699,12 @@ function defaultOctokitMockImpl() { }; const mockInstallationIdReturnValueOrgs = { data: { - id: TEST_DATA.installationId, + id: TEST_DATA_SINGLE.installationId, }, }; const mockInstallationIdReturnValueRepos = { data: { - id: TEST_DATA.installationId, + id: TEST_DATA_SINGLE.installationId, }, }; diff --git a/lambdas/functions/control-plane/src/scale-runners/scale-up.ts b/lambdas/functions/control-plane/src/scale-runners/scale-up.ts index 08d16d682a..417a033121 100644 --- a/lambdas/functions/control-plane/src/scale-runners/scale-up.ts +++ b/lambdas/functions/control-plane/src/scale-runners/scale-up.ts @@ -6,8 +6,6 @@ import yn from 'yn'; import { createGithubAppAuth, createGithubInstallationAuth, createOctokitClient } from '../github/auth'; import { createRunner, listEC2Runners } from './../aws/runners'; import { RunnerInputParameters } from './../aws/runners.d'; -import ScaleError from './ScaleError'; -import { publishRetryMessage } from './job-retry'; import { metricGitHubAppRateLimit } from '../github/rate-limit'; const logger = createChildLogger('scale-up'); @@ -33,6 +31,10 @@ export interface ActionRequestMessage { retryCounter?: number; } +export interface ActionRequestMessageSQS extends ActionRequestMessage { + messageId: string; +} + export interface ActionRequestMessageRetry extends ActionRequestMessage { retryCounter: number; } @@ -114,7 +116,7 @@ function removeTokenFromLogging(config: string[]): string[] { } export async function getInstallationId( - ghesApiUrl: string, + githubAppClient: Octokit, enableOrgLevel: boolean, payload: ActionRequestMessage, ): Promise<number> { @@ -122,16 +124,14 @@ export async function getInstallationId( return payload.installationId;
} - const ghAuth = await createGithubAppAuth(undefined, ghesApiUrl); - const githubClient = await createOctokitClient(ghAuth.token, ghesApiUrl); return enableOrgLevel ? ( - await githubClient.apps.getOrgInstallation({ + await githubAppClient.apps.getOrgInstallation({ org: payload.repositoryOwner, }) ).data.id : ( - await githubClient.apps.getRepoInstallation({ + await githubAppClient.apps.getRepoInstallation({ owner: payload.repositoryOwner, repo: payload.repositoryName, }) @@ -211,23 +211,27 @@ async function getRunnerGroupByName(ghClient: Octokit, githubRunnerConfig: Creat export async function createRunners( githubRunnerConfig: CreateGitHubRunnerConfig, ec2RunnerConfig: CreateEC2RunnerConfig, + numberOfRunners: number, ghClient: Octokit, -): Promise<void> { +): Promise<string[]> { const instances = await createRunner({ runnerType: githubRunnerConfig.runnerType, runnerOwner: githubRunnerConfig.runnerOwner, - numberOfRunners: 1, + numberOfRunners, ...ec2RunnerConfig, }); if (instances.length !== 0) { await createStartRunnerConfig(githubRunnerConfig, instances, ghClient); } + + return instances; } -export async function scaleUp(eventSource: string, payload: ActionRequestMessage): Promise<void> { - logger.info(`Received ${payload.eventType} from ${payload.repositoryOwner}/${payload.repositoryName}`); +export async function scaleUp(payloads: ActionRequestMessageSQS[]): Promise<string[]> { + logger.info('Received scale up requests', { + n_requests: payloads.length, + }); - if (eventSource !== 'aws:sqs') throw Error('Cannot handle non-SQS events!'); const enableOrgLevel = yn(process.env.ENABLE_ORGANIZATION_RUNNERS, { default: true }); const maximumRunners = parseInt(process.env.RUNNERS_MAXIMUM_COUNT || '3'); const runnerLabels = process.env.RUNNER_LABELS || ''; @@ -252,103 +256,195 @@ export async function scaleUp(eventSource: string, payload: ActionRequestMessage ? (JSON.parse(process.env.ENABLE_ON_DEMAND_FAILOVER_FOR_ERRORS) as [string]) : []; - if (ephemeralEnabled && payload.eventType !== 'workflow_job') { - logger.warn(`${payload.eventType} event is not supported in combination with ephemeral runners.`); - throw Error( - `The event type ${payload.eventType} is not supported in combination with ephemeral runners.` + - `Please ensure you have enabled workflow_job events.`, - ); - } + const { ghesApiUrl, ghesBaseUrl } = getGitHubEnterpriseApiUrl(); - if (!isValidRepoOwnerTypeIfOrgLevelEnabled(payload, enableOrgLevel)) { - logger.warn( - `Repository ${payload.repositoryOwner}/${payload.repositoryName} does not belong to a GitHub` + - `organization and organization runners are enabled. This is not supported. Not scaling up for this event.` + - `Not throwing error to prevent re-queueing and just ignoring the event.`, - ); - return; + const ghAuth = await createGithubAppAuth(undefined, ghesApiUrl); + const githubAppClient = await createOctokitClient(ghAuth.token, ghesApiUrl); + + // A map of either owner or owner/repo name to Octokit client, so we use a + // single client per installation (set of messages), depending on how the app + // is installed. This is for a couple of reasons: + // - Sharing clients opens up the possibility of caching API calls. + // - Fetching a client for an installation actually requires a couple of API + // calls itself, which would get expensive if done for every message in a + // batch.
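+  // As an illustration (hypothetical values, not from this patch), with
+  // org-level runners enabled the map might end up as
+  //   'my-org' => { messages: [msgA, msgB], githubInstallationClient: clientForMyOrg }
+  // and with repo-level runners as
+  //   'my-org/my-repo' => { messages: [msgA], githubInstallationClient: clientForMyRepo }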
+ type MessagesWithClient = { + messages: ActionRequestMessageSQS[]; + githubInstallationClient: Octokit; + }; + + const validMessages = new Map<string, MessagesWithClient>(); + const invalidMessages: string[] = []; + for (const payload of payloads) { + const { eventType, messageId, repositoryName, repositoryOwner } = payload; + if (ephemeralEnabled) { + if (eventType !== 'workflow_job') { + logger.warn( + 'Event is not supported in combination with ephemeral runners. Please ensure you have enabled workflow_job events.', + { eventType, messageId }, + ); + + invalidMessages.push(messageId); + + continue; + } + } + + if (!isValidRepoOwnerTypeIfOrgLevelEnabled(payload, enableOrgLevel)) { + logger.warn( + `Repository does not belong to a GitHub organization and organization runners are enabled. This is not supported. Not scaling up for this event. Not throwing error to prevent re-queueing and just ignoring the event.`, + { + repository: `${repositoryOwner}/${repositoryName}`, + messageId, + }, + ); + + continue; + } + + const key = enableOrgLevel ? payload.repositoryOwner : `${payload.repositoryOwner}/${payload.repositoryName}`; + + let entry = validMessages.get(key); + + // If we've not seen this owner/repo before, we'll need to create a GitHub + // client for it. + if (entry === undefined) { + const installationId = await getInstallationId(githubAppClient, enableOrgLevel, payload); + const ghAuth = await createGithubInstallationAuth(installationId, ghesApiUrl); + const githubInstallationClient = await createOctokitClient(ghAuth.token, ghesApiUrl); + + entry = { + messages: [], + githubInstallationClient, + }; + + validMessages.set(key, entry); + } + + entry.messages.push(payload); } - const ephemeral = ephemeralEnabled && payload.eventType === 'workflow_job'; const runnerType = enableOrgLevel ? 'Org' : 'Repo'; - const runnerOwner = enableOrgLevel ? payload.repositoryOwner : `${payload.repositoryOwner}/${payload.repositoryName}`; addPersistentContextToChildLogger({ runner: { + ephemeral: ephemeralEnabled, type: runnerType, - owner: runnerOwner, namePrefix: runnerNamePrefix, - }, - github: { - event: payload.eventType, - workflow_job_id: payload.id.toString(), + n_events: Array.from(validMessages.values()).reduce((acc, group) => acc + group.messages.length, 0), }, }); - logger.info(`Received event`); + logger.info(`Received events`); - const { ghesApiUrl, ghesBaseUrl } = getGitHubEnterpriseApiUrl(); + for (const [group, { githubInstallationClient, messages }] of validMessages.entries()) { + // Work out how much we want to scale up by.
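+    // Illustrative walk-through of the arithmetic below (hypothetical numbers):
+    // 5 messages in this group of which 2 are no longer queued gives scaleUp === 3;
+    // with RUNNERS_MAXIMUM_COUNT=4 and 2 runners already running we request
+    // min(3, 4 - 2) = 2 instances, and 1 message is handed back to SQS for retry.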
+ let scaleUp = 0; - const installationId = await getInstallationId(ghesApiUrl, enableOrgLevel, payload); - const ghAuth = await createGithubInstallationAuth(installationId, ghesApiUrl); - const githubInstallationClient = await createOctokitClient(ghAuth.token, ghesApiUrl); + for (const message of messages) { + const messageLogger = logger.createChild({ + persistentKeys: { + eventType: message.eventType, + group, + messageId: message.messageId, + repository: `${message.repositoryOwner}/${message.repositoryName}`, + }, + }); - if (!enableJobQueuedCheck || (await isJobQueued(githubInstallationClient, payload))) { - let scaleUp = true; - if (maximumRunners !== -1) { - const currentRunners = await listEC2Runners({ - environment, - runnerType, - runnerOwner, + if (enableJobQueuedCheck && !(await isJobQueued(githubInstallationClient, message))) { + messageLogger.info('No runner will be created, job is not queued.'); + + continue; + } + + scaleUp++; + } + + if (scaleUp === 0) { + logger.info('No runners will be created for this group, no valid messages found.'); + + continue; + } + + // Don't call the EC2 API if we can create an unlimited number of runners. + const currentRunners = + maximumRunners === -1 ? 0 : (await listEC2Runners({ environment, runnerType, runnerOwner: group })).length; + + logger.info('Current runners', { + currentRunners, + maximumRunners, + }); + + // Calculate how many runners we want to create. + const newRunners = + maximumRunners === -1 + ? // If we don't have an upper limit, scale up by the number of new jobs. + scaleUp + : // Otherwise, we do have a limit, so work out if `scaleUp` would exceed it. + Math.min(scaleUp, maximumRunners - currentRunners); + + if (newRunners <= 0) { + logger.info('No runners will be created for this group, maximum number of runners reached.', { + desiredNewRunners: scaleUp, }); - logger.info(`Current runners: ${currentRunners.length} of ${maximumRunners}`); - scaleUp = currentRunners.length < maximumRunners; + + invalidMessages.push(...messages.map((message) => message.messageId)); + + continue; } - if (scaleUp) { - logger.info(`Attempting to launch a new runner`); + logger.info(`Attempting to launch new runners`, { + newRunners, + }); - await createRunners( - { - ephemeral, - enableJitConfig, - ghesBaseUrl, - runnerLabels, - runnerGroup, - runnerNamePrefix, - runnerOwner, - runnerType, - disableAutoUpdate, - ssmTokenPath, - ssmConfigPath, - }, - { - ec2instanceCriteria: { - instanceTypes, - targetCapacityType: instanceTargetCapacityType, - maxSpotPrice: instanceMaxSpotPrice, - instanceAllocationStrategy: instanceAllocationStrategy, - }, - environment, - launchTemplateName, - subnets, - amiIdSsmParameterName, - tracingEnabled, - onDemandFailoverOnError, + const instances = await createRunners( + { + ephemeral: ephemeralEnabled, + enableJitConfig, + ghesBaseUrl, + runnerLabels, + runnerGroup, + runnerNamePrefix, + runnerOwner: group, + runnerType, + disableAutoUpdate, + ssmTokenPath, + ssmConfigPath, + }, + { + ec2instanceCriteria: { + instanceTypes, + targetCapacityType: instanceTargetCapacityType, + maxSpotPrice: instanceMaxSpotPrice, + instanceAllocationStrategy: instanceAllocationStrategy, + }, + environment, + launchTemplateName, + subnets, + amiIdSsmParameterName, + tracingEnabled, + onDemandFailoverOnError, + }, + newRunners, + githubInstallationClient, + ); - await publishRetryMessage(payload); - } else { - logger.info('No runner will be created, maximum number of runners reached.'); - if (ephemeral) { -
throw new ScaleError('No runners create: maximum of runners reached.'); - } + // Not all of the runners we wanted were created. Reject that many messages + // so the unfulfilled requests are retried. + if (instances.length !== scaleUp) { + const missingInstanceCount = scaleUp - instances.length; + + logger.warn('Not enough runners were created, rejecting some messages so the requests are retried', { + wanted: newRunners, + got: instances.length, + missingInstanceCount, + }); + + invalidMessages.push(...messages.slice(0, missingInstanceCount).map((message) => message.messageId)); } - } else { - logger.info('No runner will be created, job is not queued.'); } + + return invalidMessages; } export function getGitHubEnterpriseApiUrl() { diff --git a/lambdas/libs/aws-powertools-util/src/logger/index.ts b/lambdas/libs/aws-powertools-util/src/logger/index.ts index 195b552a74..2bad191a83 100644 --- a/lambdas/libs/aws-powertools-util/src/logger/index.ts +++ b/lambdas/libs/aws-powertools-util/src/logger/index.ts @@ -9,7 +9,7 @@ const defaultValues = { }; function setContext(context: Context, module?: string) { - logger.addPersistentLogAttributes({ + logger.appendPersistentKeys({ 'aws-request-id': context.awsRequestId, 'function-name': context.functionName, module: module, @@ -17,7 +17,7 @@ // Add the context to all child loggers childLoggers.forEach((childLogger) => { - childLogger.addPersistentLogAttributes({ + childLogger.appendPersistentKeys({ 'aws-request-id': context.awsRequestId, 'function-name': context.functionName, }); @@ -25,14 +25,14 @@ } const logger = new Logger({ - persistentLogAttributes: { + persistentKeys: { ...defaultValues, }, }); function createChildLogger(module: string): Logger { const childLogger = logger.createChild({ - persistentLogAttributes: { + persistentKeys: { module: module, }, }); @@ -47,7 +47,7 @@ type LogAttributes = { function addPersistentContextToChildLogger(attributes: LogAttributes) { childLoggers.forEach((childLogger) => { - childLogger.addPersistentLogAttributes(attributes); + childLogger.appendPersistentKeys(attributes); }); } diff --git a/main.tf b/main.tf index 759b8169f6..7f6bcfedf5 100644 --- a/main.tf +++ b/main.tf @@ -209,28 +209,30 @@ module "runners" { metadata_options = var.runner_metadata_options credit_specification = var.runner_credit_specification - enable_runner_binaries_syncer = var.enable_runner_binaries_syncer - lambda_s3_bucket = var.lambda_s3_bucket - runners_lambda_s3_key = var.runners_lambda_s3_key - runners_lambda_s3_object_version = var.runners_lambda_s3_object_version - lambda_runtime = var.lambda_runtime - lambda_architecture = var.lambda_architecture - lambda_zip = var.runners_lambda_zip - lambda_scale_up_memory_size = var.runners_scale_up_lambda_memory_size - lambda_scale_down_memory_size = var.runners_scale_down_lambda_memory_size - lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout - lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout - lambda_subnet_ids = var.lambda_subnet_ids - lambda_security_group_ids = var.lambda_security_group_ids - lambda_tags = var.lambda_tags - tracing_config = var.tracing_config - logging_retention_in_days = var.logging_retention_in_days - logging_kms_key_id = var.logging_kms_key_id - enable_cloudwatch_agent = var.enable_cloudwatch_agent - cloudwatch_config = var.cloudwatch_config - runner_log_files = var.runner_log_files - runner_group_name = var.runner_group_name -
runner_name_prefix = var.runner_name_prefix + enable_runner_binaries_syncer = var.enable_runner_binaries_syncer + lambda_s3_bucket = var.lambda_s3_bucket + runners_lambda_s3_key = var.runners_lambda_s3_key + runners_lambda_s3_object_version = var.runners_lambda_s3_object_version + lambda_runtime = var.lambda_runtime + lambda_architecture = var.lambda_architecture + lambda_event_source_mapping_batch_size = var.lambda_event_source_mapping_batch_size + lambda_event_source_mapping_maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds + lambda_zip = var.runners_lambda_zip + lambda_scale_up_memory_size = var.runners_scale_up_lambda_memory_size + lambda_scale_down_memory_size = var.runners_scale_down_lambda_memory_size + lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout + lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout + lambda_subnet_ids = var.lambda_subnet_ids + lambda_security_group_ids = var.lambda_security_group_ids + lambda_tags = var.lambda_tags + tracing_config = var.tracing_config + logging_retention_in_days = var.logging_retention_in_days + logging_kms_key_id = var.logging_kms_key_id + enable_cloudwatch_agent = var.enable_cloudwatch_agent + cloudwatch_config = var.cloudwatch_config + runner_log_files = var.runner_log_files + runner_group_name = var.runner_group_name + runner_name_prefix = var.runner_name_prefix scale_up_reserved_concurrent_executions = var.scale_up_reserved_concurrent_executions diff --git a/modules/multi-runner/README.md b/modules/multi-runner/README.md index c43d00e245..3c32b29406 100644 --- a/modules/multi-runner/README.md +++ b/modules/multi-runner/README.md @@ -137,6 +137,8 @@ module "multi-runner" { | [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no | | [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. | `string` | `null` | no | | [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no | +| [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no | +| [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no | | [lambda\_principals](#input\_lambda\_principals) | (Optional) add extra principals to the role created for execution of the lambda, e.g. for local testing. |
list(object({
type = string
identifiers = list(string)
}))
| `[]` | no | | [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no | | [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no | diff --git a/modules/multi-runner/runners.tf b/modules/multi-runner/runners.tf index 9f5d1bb456..e36ed7324f 100644 --- a/modules/multi-runner/runners.tf +++ b/modules/multi-runner/runners.tf @@ -57,28 +57,30 @@ module "runners" { metadata_options = each.value.runner_config.runner_metadata_options credit_specification = each.value.runner_config.credit_specification - enable_runner_binaries_syncer = each.value.runner_config.enable_runner_binaries_syncer - lambda_s3_bucket = var.lambda_s3_bucket - runners_lambda_s3_key = var.runners_lambda_s3_key - runners_lambda_s3_object_version = var.runners_lambda_s3_object_version - lambda_runtime = var.lambda_runtime - lambda_architecture = var.lambda_architecture - lambda_zip = var.runners_lambda_zip - lambda_scale_up_memory_size = var.scale_up_lambda_memory_size - lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout - lambda_scale_down_memory_size = var.scale_down_lambda_memory_size - lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout - lambda_subnet_ids = var.lambda_subnet_ids - lambda_security_group_ids = var.lambda_security_group_ids - lambda_tags = var.lambda_tags - tracing_config = var.tracing_config - logging_retention_in_days = var.logging_retention_in_days - logging_kms_key_id = var.logging_kms_key_id - enable_cloudwatch_agent = each.value.runner_config.enable_cloudwatch_agent - cloudwatch_config = try(coalesce(each.value.runner_config.cloudwatch_config, var.cloudwatch_config), null) - runner_log_files = each.value.runner_config.runner_log_files - runner_group_name = each.value.runner_config.runner_group_name - runner_name_prefix = each.value.runner_config.runner_name_prefix + enable_runner_binaries_syncer = each.value.runner_config.enable_runner_binaries_syncer + lambda_s3_bucket = var.lambda_s3_bucket + runners_lambda_s3_key = var.runners_lambda_s3_key + runners_lambda_s3_object_version = var.runners_lambda_s3_object_version + lambda_runtime = var.lambda_runtime + lambda_architecture = var.lambda_architecture + lambda_zip = var.runners_lambda_zip + lambda_scale_up_memory_size = var.scale_up_lambda_memory_size + lambda_event_source_mapping_batch_size = var.lambda_event_source_mapping_batch_size + lambda_event_source_mapping_maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds + lambda_timeout_scale_up = var.runners_scale_up_lambda_timeout + lambda_scale_down_memory_size = var.scale_down_lambda_memory_size + lambda_timeout_scale_down = var.runners_scale_down_lambda_timeout + lambda_subnet_ids = var.lambda_subnet_ids + lambda_security_group_ids = var.lambda_security_group_ids + lambda_tags = var.lambda_tags + tracing_config = var.tracing_config + logging_retention_in_days = var.logging_retention_in_days + logging_kms_key_id = var.logging_kms_key_id + enable_cloudwatch_agent = each.value.runner_config.enable_cloudwatch_agent + cloudwatch_config = try(coalesce(each.value.runner_config.cloudwatch_config, var.cloudwatch_config), null) + runner_log_files = each.value.runner_config.runner_log_files + runner_group_name = each.value.runner_config.runner_group_name + runner_name_prefix = each.value.runner_config.runner_name_prefix scale_up_reserved_concurrent_executions = 
each.value.runner_config.scale_up_reserved_concurrent_executions diff --git a/modules/multi-runner/variables.tf b/modules/multi-runner/variables.tf index b138205459..301eabfc0b 100644 --- a/modules/multi-runner/variables.tf +++ b/modules/multi-runner/variables.tf @@ -714,3 +714,15 @@ variable "user_agent" { type = string default = "github-aws-runners" } + +variable "lambda_event_source_mapping_batch_size" { + description = "Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used." + type = number + default = 10 +} + +variable "lambda_event_source_mapping_maximum_batching_window_in_seconds" { + description = "Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10. Defaults to 0." + type = number + default = 0 +} diff --git a/modules/runners/README.md b/modules/runners/README.md index f7dd7ecb88..9eed4ca1b1 100644 --- a/modules/runners/README.md +++ b/modules/runners/README.md @@ -173,6 +173,8 @@ yarn run dist | [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no | | [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. | `string` | `null` | no | | [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no | +| [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no | +| [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no | | [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no | | [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no | | [lambda\_scale\_down\_memory\_size](#input\_lambda\_scale\_down\_memory\_size) | Memory size limit in MB for scale down lambda. | `number` | `512` | no | diff --git a/modules/runners/job-retry.tf b/modules/runners/job-retry.tf index e51c3903d4..130992667f 100644 --- a/modules/runners/job-retry.tf +++ b/modules/runners/job-retry.tf @@ -3,30 +3,32 @@ locals { job_retry_enabled = var.job_retry != null && var.job_retry.enable ? 
true : false job_retry = { - prefix = var.prefix - tags = local.tags - aws_partition = var.aws_partition - architecture = var.lambda_architecture - runtime = var.lambda_runtime - security_group_ids = var.lambda_security_group_ids - subnet_ids = var.lambda_subnet_ids - kms_key_arn = var.kms_key_arn - lambda_tags = var.lambda_tags - log_level = var.log_level - logging_kms_key_id = var.logging_kms_key_id - logging_retention_in_days = var.logging_retention_in_days - metrics = var.metrics - role_path = var.role_path - role_permissions_boundary = var.role_permissions_boundary - s3_bucket = var.lambda_s3_bucket - s3_key = var.runners_lambda_s3_key - s3_object_version = var.runners_lambda_s3_object_version - zip = var.lambda_zip - tracing_config = var.tracing_config - github_app_parameters = var.github_app_parameters - enable_organization_runners = var.enable_organization_runners - sqs_build_queue = var.sqs_build_queue - ghes_url = var.ghes_url + prefix = var.prefix + tags = local.tags + aws_partition = var.aws_partition + architecture = var.lambda_architecture + runtime = var.lambda_runtime + security_group_ids = var.lambda_security_group_ids + subnet_ids = var.lambda_subnet_ids + kms_key_arn = var.kms_key_arn + lambda_tags = var.lambda_tags + log_level = var.log_level + logging_kms_key_id = var.logging_kms_key_id + logging_retention_in_days = var.logging_retention_in_days + metrics = var.metrics + role_path = var.role_path + role_permissions_boundary = var.role_permissions_boundary + s3_bucket = var.lambda_s3_bucket + s3_key = var.runners_lambda_s3_key + s3_object_version = var.runners_lambda_s3_object_version + zip = var.lambda_zip + tracing_config = var.tracing_config + github_app_parameters = var.github_app_parameters + enable_organization_runners = var.enable_organization_runners + sqs_build_queue = var.sqs_build_queue + ghes_url = var.ghes_url + lambda_event_source_mapping_batch_size = var.lambda_event_source_mapping_batch_size + lambda_event_source_mapping_maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds } } diff --git a/modules/runners/job-retry/README.md b/modules/runners/job-retry/README.md index 5276db9d60..f2e078ac52 100644 --- a/modules/runners/job-retry/README.md +++ b/modules/runners/job-retry/README.md @@ -42,7 +42,7 @@ The module is an inner module and used by the runner module when the opt-in feat | Name | Description | Type | Default | Required | |------|-------------|------|---------|:--------:| -| [config](#input\_config) | Configuration for the spot termination watcher lambda function.

`aws_partition`: Partition for the base arn if not 'aws'
`architecture`: AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions.
`environment_variables`: Environment variables for the lambda.
`enable_organization_runners`: Enable organization runners.
`enable_metric`: Enable metric for the lambda. If `spot_warning` is set to true, the lambda will emit a metric when it detects a spot termination warning.
'ghes\_url': Optional GitHub Enterprise Server URL.
'user\_agent': Optional User-Agent header for GitHub API requests.
'github\_app\_parameters': Parameter Store for GitHub App Parameters.
'kms\_key\_arn': Optional CMK Key ARN instead of using the default AWS managed key.
`lambda_principals`: Add extra principals to the role created for execution of the lambda, e.g. for local testing.
`lambda_tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`log_level`: Logging level for lambda logging. Valid values are 'silly', 'trace', 'debug', 'info', 'warn', 'error', 'fatal'.
`logging_kms_key_id`: Specifies the kms key id to encrypt the logs with
`logging_retention_in_days`: Specifies the number of days you want to retain log events for the lambda log group. Possible values are: 0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1827, and 3653.
`memory_size`: Memory size linit in MB of the lambda.
`metrics`: Configuration to enable metrics creation by the lambda.
`prefix`: The prefix used for naming resources.
`role_path`: The path that will be added to the role, if not set the environment name will be used.
`role_permissions_boundary`: Permissions boundary that will be added to the created role for the lambda.
`runtime`: AWS Lambda runtime.
`s3_bucket`: S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly.
`s3_key`: S3 key for syncer lambda function. Required if using S3 bucket to specify lambdas.
`s3_object_version`: S3 object version for syncer lambda function. Useful if S3 versioning is enabled on source bucket.
`security_group_ids`: List of security group IDs associated with the Lambda function.
'sqs\_build\_queue': SQS queue for build events to re-publish job request.
`subnet_ids`: List of subnets in which the action runners will be launched, the subnets needs to be subnets in the `vpc_id`.
`tag_filters`: Map of tags that will be used to filter the resources to be tracked. Only for which all tags are present and starting with the same value as the value in the map will be tracked.
`tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`timeout`: Time out of the lambda in seconds.
`tracing_config`: Configuration for lambda tracing.
`zip`: File location of the lambda zip file. |
object({
aws_partition = optional(string, null)
architecture = optional(string, null)
enable_organization_runners = bool
environment_variables = optional(map(string), {})
ghes_url = optional(string, null)
user_agent = optional(string, null)
github_app_parameters = object({
key_base64 = map(string)
id = map(string)
})
kms_key_arn = optional(string, null)
lambda_tags = optional(map(string), {})
log_level = optional(string, null)
logging_kms_key_id = optional(string, null)
logging_retention_in_days = optional(number, null)
memory_size = optional(number, null)
metrics = optional(object({
enable = optional(bool, false)
namespace = optional(string, null)
metric = optional(object({
enable_github_app_rate_limit = optional(bool, true)
enable_job_retry = optional(bool, true)
}), {})
}), {})
prefix = optional(string, null)
principals = optional(list(object({
type = string
identifiers = list(string)
})), [])
queue_encryption = optional(object({
kms_data_key_reuse_period_seconds = optional(number, null)
kms_master_key_id = optional(string, null)
sqs_managed_sse_enabled = optional(bool, true)
}), {})
role_path = optional(string, null)
role_permissions_boundary = optional(string, null)
runtime = optional(string, null)
security_group_ids = optional(list(string), [])
subnet_ids = optional(list(string), [])
s3_bucket = optional(string, null)
s3_key = optional(string, null)
s3_object_version = optional(string, null)
sqs_build_queue = object({
url = string
arn = string
})
tags = optional(map(string), {})
timeout = optional(number, 30)
tracing_config = optional(object({
mode = optional(string, null)
capture_http_requests = optional(bool, false)
capture_error = optional(bool, false)
}), {})
zip = optional(string, null)
})
| n/a | yes | +| [config](#input\_config) | Configuration for the spot termination watcher lambda function.

`aws_partition`: Partition for the base arn if not 'aws'
`architecture`: AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions.
`environment_variables`: Environment variables for the lambda.
`enable_organization_runners`: Enable organization runners.
`enable_metric`: Enable metric for the lambda. If `spot_warning` is set to true, the lambda will emit a metric when it detects a spot termination warning.
'ghes\_url': Optional GitHub Enterprise Server URL.
'user\_agent': Optional User-Agent header for GitHub API requests.
'github\_app\_parameters': Parameter Store for GitHub App Parameters.
'kms\_key\_arn': Optional CMK Key ARN instead of using the default AWS managed key.
`lambda_event_source_mapping_batch_size`: Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default will be used.
`lambda_event_source_mapping_maximum_batching_window_in_seconds`: Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10.
`lambda_principals`: Add extra principals to the role created for execution of the lambda, e.g. for local testing.
`lambda_tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`log_level`: Logging level for lambda logging. Valid values are 'silly', 'trace', 'debug', 'info', 'warn', 'error', 'fatal'.
`logging_kms_key_id`: Specifies the kms key id to encrypt the logs with
`logging_retention_in_days`: Specifies the number of days you want to retain log events for the lambda log group. Possible values are: 0, 1, 3, 5, 7, 14, 30, 60, 90, 120, 150, 180, 365, 400, 545, 731, 1827, and 3653.
`memory_size`: Memory size limit in MB of the lambda.
`metrics`: Configuration to enable metrics creation by the lambda.
`prefix`: The prefix used for naming resources.
`role_path`: The path that will be added to the role, if not set the environment name will be used.
`role_permissions_boundary`: Permissions boundary that will be added to the created role for the lambda.
`runtime`: AWS Lambda runtime.
`s3_bucket`: S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly.
`s3_key`: S3 key for syncer lambda function. Required if using S3 bucket to specify lambdas.
`s3_object_version`: S3 object version for syncer lambda function. Useful if S3 versioning is enabled on source bucket.
`security_group_ids`: List of security group IDs associated with the Lambda function.
'sqs\_build\_queue': SQS queue for build events to re-publish job request.
`subnet_ids`: List of subnets in which the action runners will be launched, the subnets needs to be subnets in the `vpc_id`.
`tag_filters`: Map of tags that will be used to filter the resources to be tracked. Only resources for which all tags are present and start with the same value as the value in the map will be tracked.
`tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment.
`timeout`: Time out of the lambda in seconds.
`tracing_config`: Configuration for lambda tracing.
`zip`: File location of the lambda zip file. |
object({
aws_partition = optional(string, null)
architecture = optional(string, null)
enable_organization_runners = bool
environment_variables = optional(map(string), {})
ghes_url = optional(string, null)
user_agent = optional(string, null)
github_app_parameters = object({
key_base64 = map(string)
id = map(string)
})
kms_key_arn = optional(string, null)
lambda_event_source_mapping_batch_size = optional(number, 10)
lambda_event_source_mapping_maximum_batching_window_in_seconds = optional(number, 0)
lambda_tags = optional(map(string), {})
log_level = optional(string, null)
logging_kms_key_id = optional(string, null)
logging_retention_in_days = optional(number, null)
memory_size = optional(number, null)
metrics = optional(object({
enable = optional(bool, false)
namespace = optional(string, null)
metric = optional(object({
enable_github_app_rate_limit = optional(bool, true)
enable_job_retry = optional(bool, true)
}), {})
}), {})
prefix = optional(string, null)
principals = optional(list(object({
type = string
identifiers = list(string)
})), [])
queue_encryption = optional(object({
kms_data_key_reuse_period_seconds = optional(number, null)
kms_master_key_id = optional(string, null)
sqs_managed_sse_enabled = optional(bool, true)
}), {})
role_path = optional(string, null)
role_permissions_boundary = optional(string, null)
runtime = optional(string, null)
security_group_ids = optional(list(string), [])
subnet_ids = optional(list(string), [])
s3_bucket = optional(string, null)
s3_key = optional(string, null)
s3_object_version = optional(string, null)
sqs_build_queue = object({
url = string
arn = string
})
tags = optional(map(string), {})
timeout = optional(number, 30)
tracing_config = optional(object({
mode = optional(string, null)
capture_http_requests = optional(bool, false)
capture_error = optional(bool, false)
}), {})
zip = optional(string, null)
})
| n/a | yes | ## Outputs diff --git a/modules/runners/job-retry/main.tf b/modules/runners/job-retry/main.tf index 9561c7db71..612c515f8c 100644 --- a/modules/runners/job-retry/main.tf +++ b/modules/runners/job-retry/main.tf @@ -44,9 +44,10 @@ module "job_retry" { } resource "aws_lambda_event_source_mapping" "job_retry" { - event_source_arn = aws_sqs_queue.job_retry_check_queue.arn - function_name = module.job_retry.lambda.function.arn - batch_size = 1 + event_source_arn = aws_sqs_queue.job_retry_check_queue.arn + function_name = module.job_retry.lambda.function.arn + batch_size = var.config.lambda_event_source_mapping_batch_size + maximum_batching_window_in_seconds = var.config.lambda_event_source_mapping_maximum_batching_window_in_seconds } resource "aws_lambda_permission" "job_retry" { diff --git a/modules/runners/job-retry/variables.tf b/modules/runners/job-retry/variables.tf index 4a8fe19fbf..f40bec1ba7 100644 --- a/modules/runners/job-retry/variables.tf +++ b/modules/runners/job-retry/variables.tf @@ -11,6 +11,8 @@ variable "config" { 'user_agent': Optional User-Agent header for GitHub API requests. 'github_app_parameters': Parameter Store for GitHub App Parameters. 'kms_key_arn': Optional CMK Key ARN instead of using the default AWS managed key. + `lambda_event_source_mapping_batch_size`: Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default will be used. + `lambda_event_source_mapping_maximum_batching_window_in_seconds`: Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10. `lambda_principals`: Add extra principals to the role created for execution of the lambda, e.g. for local testing. `lambda_tags`: Map of tags that will be added to created resources. By default resources will be tagged with name and environment. `log_level`: Logging level for lambda logging. Valid values are 'silly', 'trace', 'debug', 'info', 'warn', 'error', 'fatal'. 
diff --git a/modules/runners/scale-up.tf b/modules/runners/scale-up.tf
index ad96c496a4..fad37af288 100644
--- a/modules/runners/scale-up.tf
+++ b/modules/runners/scale-up.tf
@@ -87,9 +87,11 @@ resource "aws_cloudwatch_log_group" "scale_up" {
 }
 
 resource "aws_lambda_event_source_mapping" "scale_up" {
-  event_source_arn = var.sqs_build_queue.arn
-  function_name    = aws_lambda_function.scale_up.arn
-  batch_size       = 1
+  event_source_arn                   = var.sqs_build_queue.arn
+  function_name                      = aws_lambda_function.scale_up.arn
+  function_response_types            = ["ReportBatchItemFailures"]
+  batch_size                         = var.lambda_event_source_mapping_batch_size
+  maximum_batching_window_in_seconds = var.lambda_event_source_mapping_maximum_batching_window_in_seconds
 }
 
 resource "aws_lambda_permission" "scale_runners_lambda" {
diff --git a/modules/runners/variables.tf b/modules/runners/variables.tf
index f70e80b9cc..1960d51946 100644
--- a/modules/runners/variables.tf
+++ b/modules/runners/variables.tf
@@ -761,3 +761,23 @@ variable "user_agent" {
   type    = string
   default = null
 }
+
+variable "lambda_event_source_mapping_batch_size" {
+  description = "Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used."
+  type        = number
+  default     = 10
+  validation {
+    condition     = var.lambda_event_source_mapping_batch_size >= 1 && var.lambda_event_source_mapping_batch_size <= 1000
+    error_message = "The batch size for the lambda event source mapping must be between 1 and 1000."
+  }
+}
+
+variable "lambda_event_source_mapping_maximum_batching_window_in_seconds" {
+  description = "Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10. Defaults to 0."
+  type        = number
+  default     = 0
+  validation {
+    condition     = var.lambda_event_source_mapping_maximum_batching_window_in_seconds >= 0 && var.lambda_event_source_mapping_maximum_batching_window_in_seconds <= 300
+    error_message = "Maximum batching window must be between 0 and 300 seconds."
+  }
+}
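The two validation blocks above only range-check each variable in isolation, so the pairing rule called out in the descriptions (a window greater than 0 whenever the batch size exceeds 10) is left for AWS to reject at apply time. As a sketch, assuming Terraform >= 1.9 (which allows a validation condition to reference other variables), the rule could instead be caught at plan time:

variable "lambda_event_source_mapping_maximum_batching_window_in_seconds" {
  type    = number
  default = 0

  validation {
    # Cross-variable references in validation conditions need Terraform >= 1.9.
    condition = (
      var.lambda_event_source_mapping_batch_size <= 10 ||
      var.lambda_event_source_mapping_maximum_batching_window_in_seconds > 0
    )
    error_message = "A batching window greater than 0 is required when the batch size exceeds 10."
  }
}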
diff --git a/variables.tf b/variables.tf
index 975aa19b1d..d066702256 100644
--- a/variables.tf
+++ b/variables.tf
@@ -1007,3 +1007,19 @@ variable "user_agent" {
   type    = string
   default = "github-aws-runners"
 }
+
+variable "lambda_event_source_mapping_batch_size" {
+  description = "Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used."
+  type        = number
+  default     = 10
+}
+
+variable "lambda_event_source_mapping_maximum_batching_window_in_seconds" {
+  description = "Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch_size is greater than 10. Defaults to 0."
+  type        = number
+  default     = 0
+  validation {
+    condition     = var.lambda_event_source_mapping_maximum_batching_window_in_seconds >= 0 && var.lambda_event_source_mapping_maximum_batching_window_in_seconds <= 300
+    error_message = "Maximum batching window must be between 0 and 300 seconds."
+  }
+}
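Finally, a hypothetical root-module configuration pulling the two settings together; the module source and the elided required inputs are assumptions, not part of this change. Setting the batch size back to 1 (with a 0-second window) reproduces the previous one-event-per-invocation behaviour:

module "runners" {
  source = "github-aws-runners/github-runner/aws" # assumed registry source

  # ... existing required inputs (github_app, vpc_id, subnet_ids, ...) ...

  # Drain busy queues faster: up to 50 queued-job events per scale-up
  # invocation, waiting at most 10 seconds to fill a batch.
  lambda_event_source_mapping_batch_size                          = 50
  lambda_event_source_mapping_maximum_batching_window_in_seconds  = 10
}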