Skip to content

Commit 65011db

Browse files
Add interrupted policy to MessageChannelPartitionHandler polling
Signed-off-by: brian.mcnamara <brian.mcnamara@salesforce.com>
1 parent 088487b commit 65011db

File tree

2 files changed

+98
-12
lines changed

2 files changed

+98
-12
lines changed

spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,17 @@
1515
*/
1616
package org.springframework.batch.integration.partition;
1717

18-
import java.util.Collection;
19-
import java.util.HashSet;
20-
import java.util.List;
21-
import java.util.Set;
22-
import java.util.concurrent.Callable;
23-
import java.util.concurrent.Future;
24-
import java.util.concurrent.TimeUnit;
25-
import java.util.stream.Collectors;
26-
2718
import org.apache.commons.logging.Log;
2819
import org.apache.commons.logging.LogFactory;
29-
3020
import org.springframework.batch.core.job.JobExecution;
31-
import org.springframework.batch.core.step.Step;
32-
import org.springframework.batch.core.step.StepExecution;
3321
import org.springframework.batch.core.partition.PartitionHandler;
3422
import org.springframework.batch.core.partition.StepExecutionSplitter;
3523
import org.springframework.batch.core.partition.support.AbstractPartitionHandler;
3624
import org.springframework.batch.core.repository.JobRepository;
25+
import org.springframework.batch.core.step.Step;
26+
import org.springframework.batch.core.step.StepExecution;
27+
import org.springframework.batch.core.step.StepInterruptionPolicy;
28+
import org.springframework.batch.core.step.ThreadStepInterruptionPolicy;
3729
import org.springframework.batch.infrastructure.poller.DirectPoller;
3830
import org.springframework.batch.infrastructure.poller.Poller;
3931
import org.springframework.beans.factory.InitializingBean;
@@ -50,6 +42,15 @@
5042
import org.springframework.util.Assert;
5143
import org.springframework.util.CollectionUtils;
5244

45+
import java.util.Collection;
46+
import java.util.HashSet;
47+
import java.util.List;
48+
import java.util.Set;
49+
import java.util.concurrent.Callable;
50+
import java.util.concurrent.Future;
51+
import java.util.concurrent.TimeUnit;
52+
import java.util.stream.Collectors;
53+
5354
/**
5455
* A {@link PartitionHandler} that uses {@link MessageChannel} instances to send
5556
* instructions to remote workers and receive their responses. The {@link MessageChannel}
@@ -100,6 +101,8 @@ public class MessageChannelPartitionHandler extends AbstractPartitionHandler imp
100101

101102
private long timeout = -1;
102103

104+
private StepInterruptionPolicy stepInterruptionPolicy = new ThreadStepInterruptionPolicy();
105+
103106
/**
104107
* pollable channel for the replies
105108
*/
@@ -192,6 +195,15 @@ public void setReplyChannel(PollableChannel replyChannel) {
192195
this.replyChannel = replyChannel;
193196
}
194197

198+
/**
199+
* Set the step interrupt policy for the manager step. Policy called during polling
200+
* @param stepInterruptionPolicy policy to use for polling
201+
*/
202+
public void setStepInterruptionPolicy(StepInterruptionPolicy stepInterruptionPolicy) {
203+
Assert.notNull(stepInterruptionPolicy, "StepInterruptionPolicy cannot be null");
204+
this.stepInterruptionPolicy = stepInterruptionPolicy;
205+
}
206+
195207
/**
196208
* Sends {@link StepExecutionRequest} objects to the request channel of the
197209
* {@link MessagingTemplate}, and then receives the result back as a list of
@@ -235,6 +247,9 @@ private Set<StepExecution> pollReplies(StepExecution managerStepExecution, final
235247
Set<Long> partitionStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet());
236248

237249
Callable<Set<StepExecution>> callback = () -> {
250+
251+
stepInterruptionPolicy.checkInterrupted(managerStepExecution);
252+
238253
JobExecution jobExecution = jobRepository.getJobExecution(managerStepExecution.getJobExecutionId());
239254
Set<StepExecution> finishedStepExecutions = jobExecution.getStepExecutions()
240255
.stream()

spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
import org.springframework.batch.core.BatchStatus;
2828
import org.springframework.batch.core.job.JobExecution;
2929
import org.springframework.batch.core.job.JobInstance;
30+
import org.springframework.batch.core.job.JobInterruptedException;
3031
import org.springframework.batch.core.job.parameters.JobParameters;
3132
import org.springframework.batch.core.repository.JobRepository;
3233
import org.springframework.batch.core.step.StepExecution;
3334
import org.springframework.batch.core.partition.StepExecutionSplitter;
35+
import org.springframework.batch.core.step.StepInterruptionPolicy;
3436
import org.springframework.integration.MessageTimeoutException;
3537
import org.springframework.integration.core.MessagingTemplate;
3638
import org.springframework.messaging.Message;
@@ -251,4 +253,73 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception {
251253
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
252254
}
253255

256+
@Test
257+
void testShutdownCancelsHandle() throws Exception {
258+
// execute with no default set
259+
messageChannelPartitionHandler = new MessageChannelPartitionHandler();
260+
// mock
261+
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
262+
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
263+
StepExecutionSplitter stepExecutionSplitter = mock();
264+
MessagingTemplate operations = mock();
265+
JobRepository jobRepository = mock();
266+
// when
267+
HashSet<StepExecution> stepExecutions = new HashSet<>();
268+
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
269+
partition1.setStatus(BatchStatus.STARTED);
270+
stepExecutions.add(partition1);
271+
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
272+
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
273+
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
274+
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);
275+
managerStepExecution.setTerminateOnly();
276+
277+
// set
278+
messageChannelPartitionHandler.setMessagingOperations(operations);
279+
messageChannelPartitionHandler.setJobRepository(jobRepository);
280+
messageChannelPartitionHandler.setStepName("step1");
281+
messageChannelPartitionHandler.afterPropertiesSet();
282+
283+
// execute
284+
assertThrows(JobInterruptedException.class,
285+
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
286+
}
287+
288+
@Test
289+
void testInterruptPolicy() throws Exception {
290+
String testExceptionMessage = "test exception message";
291+
// execute with no default set
292+
messageChannelPartitionHandler = new MessageChannelPartitionHandler();
293+
294+
// mock
295+
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
296+
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
297+
StepExecutionSplitter stepExecutionSplitter = mock();
298+
MessagingTemplate operations = mock();
299+
JobRepository jobRepository = mock();
300+
// when
301+
HashSet<StepExecution> stepExecutions = new HashSet<>();
302+
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
303+
partition1.setStatus(BatchStatus.STARTED);
304+
stepExecutions.add(partition1);
305+
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
306+
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
307+
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
308+
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);
309+
310+
// set
311+
messageChannelPartitionHandler.setMessagingOperations(operations);
312+
messageChannelPartitionHandler.setJobRepository(jobRepository);
313+
messageChannelPartitionHandler.setStepName("step1");
314+
messageChannelPartitionHandler.setStepInterruptionPolicy(stepExecution -> {
315+
throw new JobInterruptedException(testExceptionMessage);
316+
});
317+
messageChannelPartitionHandler.afterPropertiesSet();
318+
319+
// execute
320+
JobInterruptedException exception = assertThrows(JobInterruptedException.class,
321+
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
322+
assertEquals(testExceptionMessage, exception.getMessage());
323+
}
324+
254325
}

0 commit comments

Comments
 (0)