diff --git a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java index 0676e2030d..020e7438cc 100644 --- a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java +++ b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java @@ -34,6 +34,8 @@ import org.springframework.batch.core.partition.StepExecutionSplitter; import org.springframework.batch.core.partition.support.AbstractPartitionHandler; import org.springframework.batch.core.repository.JobRepository; +import org.springframework.batch.core.step.StepInterruptionPolicy; +import org.springframework.batch.core.step.ThreadStepInterruptionPolicy; import org.springframework.batch.infrastructure.poller.DirectPoller; import org.springframework.batch.infrastructure.poller.Poller; import org.springframework.beans.factory.InitializingBean; @@ -100,6 +102,8 @@ public class MessageChannelPartitionHandler extends AbstractPartitionHandler imp private long timeout = -1; + private StepInterruptionPolicy stepInterruptionPolicy = new ThreadStepInterruptionPolicy(); + /** * pollable channel for the replies */ @@ -192,6 +196,15 @@ public void setReplyChannel(PollableChannel replyChannel) { this.replyChannel = replyChannel; } + /** + * Set the step interrupt policy for the manager step. Policy called during polling + * @param stepInterruptionPolicy policy to use for polling + */ + public void setStepInterruptionPolicy(StepInterruptionPolicy stepInterruptionPolicy) { + Assert.notNull(stepInterruptionPolicy, "StepInterruptionPolicy cannot be null"); + this.stepInterruptionPolicy = stepInterruptionPolicy; + } + /** * Sends {@link StepExecutionRequest} objects to the request channel of the * {@link MessagingTemplate}, and then receives the result back as a list of @@ -235,6 +248,9 @@ private Set pollReplies(StepExecution managerStepExecution, final Set partitionStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet()); Callable> callback = () -> { + + stepInterruptionPolicy.checkInterrupted(managerStepExecution); + JobExecution jobExecution = jobRepository.getJobExecution(managerStepExecution.getJobExecutionId()); Set finishedStepExecutions = jobExecution.getStepExecutions() .stream() diff --git a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java index 72e9db154e..9a466cf97e 100644 --- a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java +++ b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java @@ -27,10 +27,12 @@ import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.job.JobExecution; import org.springframework.batch.core.job.JobInstance; +import org.springframework.batch.core.job.JobInterruptedException; import org.springframework.batch.core.job.parameters.JobParameters; import org.springframework.batch.core.repository.JobRepository; import org.springframework.batch.core.step.StepExecution; import org.springframework.batch.core.partition.StepExecutionSplitter; +import org.springframework.batch.core.step.StepInterruptionPolicy; import org.springframework.integration.MessageTimeoutException; import org.springframework.integration.core.MessagingTemplate; import org.springframework.messaging.Message; @@ -251,4 +253,73 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception { () -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution)); } + @Test + void testShutdownCancelsHandle() throws Exception { + // execute with no default set + messageChannelPartitionHandler = new MessageChannelPartitionHandler(); + // mock + JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters()); + StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution); + StepExecutionSplitter stepExecutionSplitter = mock(); + MessagingTemplate operations = mock(); + JobRepository jobRepository = mock(); + // when + HashSet stepExecutions = new HashSet<>(); + StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution); + partition1.setStatus(BatchStatus.STARTED); + stepExecutions.add(partition1); + when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions); + JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters()); + runningJobExecution.addStepExecutions(Arrays.asList(partition1)); + when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution); + managerStepExecution.setTerminateOnly(); + + // set + messageChannelPartitionHandler.setMessagingOperations(operations); + messageChannelPartitionHandler.setJobRepository(jobRepository); + messageChannelPartitionHandler.setStepName("step1"); + messageChannelPartitionHandler.afterPropertiesSet(); + + // execute + assertThrows(JobInterruptedException.class, + () -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution)); + } + + @Test + void testInterruptPolicy() throws Exception { + String testExceptionMessage = "test exception message"; + // execute with no default set + messageChannelPartitionHandler = new MessageChannelPartitionHandler(); + + // mock + JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters()); + StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution); + StepExecutionSplitter stepExecutionSplitter = mock(); + MessagingTemplate operations = mock(); + JobRepository jobRepository = mock(); + // when + HashSet stepExecutions = new HashSet<>(); + StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution); + partition1.setStatus(BatchStatus.STARTED); + stepExecutions.add(partition1); + when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions); + JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters()); + runningJobExecution.addStepExecutions(Arrays.asList(partition1)); + when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution); + + // set + messageChannelPartitionHandler.setMessagingOperations(operations); + messageChannelPartitionHandler.setJobRepository(jobRepository); + messageChannelPartitionHandler.setStepName("step1"); + messageChannelPartitionHandler.setStepInterruptionPolicy(stepExecution -> { + throw new JobInterruptedException(testExceptionMessage); + }); + messageChannelPartitionHandler.afterPropertiesSet(); + + // execute + JobInterruptedException exception = assertThrows(JobInterruptedException.class, + () -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution)); + assertEquals(testExceptionMessage, exception.getMessage()); + } + }