Skip to content

Commit 64f22ef

Browse files
Add interrupted policy to MessageChannelPartitionHandler polling
Signed-off-by: brian.mcnamara <brian.mcnamara@salesforce.com>
1 parent 088487b commit 64f22ef

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121
import java.util.Set;
2222
import java.util.concurrent.Callable;
23+
import java.util.concurrent.ExecutionException;
2324
import java.util.concurrent.Future;
2425
import java.util.concurrent.TimeUnit;
2526
import java.util.stream.Collectors;
@@ -28,12 +29,15 @@
2829
import org.apache.commons.logging.LogFactory;
2930

3031
import org.springframework.batch.core.job.JobExecution;
32+
import org.springframework.batch.core.job.JobInterruptedException;
3133
import org.springframework.batch.core.step.Step;
3234
import org.springframework.batch.core.step.StepExecution;
3335
import org.springframework.batch.core.partition.PartitionHandler;
3436
import org.springframework.batch.core.partition.StepExecutionSplitter;
3537
import org.springframework.batch.core.partition.support.AbstractPartitionHandler;
3638
import org.springframework.batch.core.repository.JobRepository;
39+
import org.springframework.batch.core.step.StepInterruptionPolicy;
40+
import org.springframework.batch.core.step.ThreadStepInterruptionPolicy;
3741
import org.springframework.batch.infrastructure.poller.DirectPoller;
3842
import org.springframework.batch.infrastructure.poller.Poller;
3943
import org.springframework.beans.factory.InitializingBean;
@@ -100,6 +104,8 @@ public class MessageChannelPartitionHandler extends AbstractPartitionHandler imp
100104

101105
private long timeout = -1;
102106

107+
private StepInterruptionPolicy stepInterruptionPolicy = new ThreadStepInterruptionPolicy();
108+
103109
/**
104110
* pollable channel for the replies
105111
*/
@@ -192,6 +198,15 @@ public void setReplyChannel(PollableChannel replyChannel) {
192198
this.replyChannel = replyChannel;
193199
}
194200

201+
/**
202+
* Set the step interrupt policy for the manager step. Policy called during polling
203+
* @param stepInterruptionPolicy policy to use for polling
204+
*/
205+
public void setStepInterruptionPolicy(StepInterruptionPolicy stepInterruptionPolicy) {
206+
Assert.notNull(stepInterruptionPolicy, "StepInterruptionPolicy cannot be null");
207+
this.stepInterruptionPolicy = stepInterruptionPolicy;
208+
}
209+
195210
/**
196211
* Sends {@link StepExecutionRequest} objects to the request channel of the
197212
* {@link MessagingTemplate}, and then receives the result back as a list of
@@ -235,6 +250,9 @@ private Set<StepExecution> pollReplies(StepExecution managerStepExecution, final
235250
Set<Long> partitionStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet());
236251

237252
Callable<Set<StepExecution>> callback = () -> {
253+
254+
stepInterruptionPolicy.checkInterrupted(managerStepExecution);
255+
238256
JobExecution jobExecution = jobRepository.getJobExecution(managerStepExecution.getJobExecutionId());
239257
Set<StepExecution> finishedStepExecutions = jobExecution.getStepExecutions()
240258
.stream()

spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
import org.springframework.batch.core.BatchStatus;
2828
import org.springframework.batch.core.job.JobExecution;
2929
import org.springframework.batch.core.job.JobInstance;
30+
import org.springframework.batch.core.job.JobInterruptedException;
3031
import org.springframework.batch.core.job.parameters.JobParameters;
3132
import org.springframework.batch.core.repository.JobRepository;
3233
import org.springframework.batch.core.step.StepExecution;
3334
import org.springframework.batch.core.partition.StepExecutionSplitter;
35+
import org.springframework.batch.core.step.StepInterruptionPolicy;
3436
import org.springframework.integration.MessageTimeoutException;
3537
import org.springframework.integration.core.MessagingTemplate;
3638
import org.springframework.messaging.Message;
@@ -251,4 +253,73 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception {
251253
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
252254
}
253255

256+
@Test
257+
void testShutdownCancelsHandle() throws Exception {
258+
// execute with no default set
259+
messageChannelPartitionHandler = new MessageChannelPartitionHandler();
260+
// mock
261+
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
262+
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
263+
StepExecutionSplitter stepExecutionSplitter = mock();
264+
MessagingTemplate operations = mock();
265+
JobRepository jobRepository = mock();
266+
// when
267+
HashSet<StepExecution> stepExecutions = new HashSet<>();
268+
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
269+
partition1.setStatus(BatchStatus.STARTED);
270+
stepExecutions.add(partition1);
271+
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
272+
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
273+
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
274+
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);
275+
managerStepExecution.setTerminateOnly();
276+
277+
// set
278+
messageChannelPartitionHandler.setMessagingOperations(operations);
279+
messageChannelPartitionHandler.setJobRepository(jobRepository);
280+
messageChannelPartitionHandler.setStepName("step1");
281+
messageChannelPartitionHandler.afterPropertiesSet();
282+
283+
// execute
284+
assertThrows(JobInterruptedException.class,
285+
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
286+
}
287+
288+
@Test
289+
void testInterruptPolicy() throws Exception {
290+
String testExceptionMessage = "test exception message";
291+
// execute with no default set
292+
messageChannelPartitionHandler = new MessageChannelPartitionHandler();
293+
294+
// mock
295+
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
296+
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
297+
StepExecutionSplitter stepExecutionSplitter = mock();
298+
MessagingTemplate operations = mock();
299+
JobRepository jobRepository = mock();
300+
// when
301+
HashSet<StepExecution> stepExecutions = new HashSet<>();
302+
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
303+
partition1.setStatus(BatchStatus.STARTED);
304+
stepExecutions.add(partition1);
305+
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
306+
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
307+
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
308+
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);
309+
310+
// set
311+
messageChannelPartitionHandler.setMessagingOperations(operations);
312+
messageChannelPartitionHandler.setJobRepository(jobRepository);
313+
messageChannelPartitionHandler.setStepName("step1");
314+
messageChannelPartitionHandler.setStepInterruptionPolicy(stepExecution -> {
315+
throw new JobInterruptedException(testExceptionMessage);
316+
});
317+
messageChannelPartitionHandler.afterPropertiesSet();
318+
319+
// execute
320+
JobInterruptedException exception = assertThrows(JobInterruptedException.class,
321+
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
322+
assertEquals(testExceptionMessage, exception.getMessage());
323+
}
324+
254325
}

0 commit comments

Comments
 (0)