Skip to content

Commit bf6cefc

Browse files
Add interrupted policy to MessageChannelPartitionHandler polling
Signed-off-by: brian.mcnamara <brian.mcnamara@salesforce.com>
1 parent 088487b commit bf6cefc

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import org.springframework.batch.core.partition.StepExecutionSplitter;
3535
import org.springframework.batch.core.partition.support.AbstractPartitionHandler;
3636
import org.springframework.batch.core.repository.JobRepository;
37+
import org.springframework.batch.core.step.StepInterruptionPolicy;
38+
import org.springframework.batch.core.step.ThreadStepInterruptionPolicy;
3739
import org.springframework.batch.infrastructure.poller.DirectPoller;
3840
import org.springframework.batch.infrastructure.poller.Poller;
3941
import org.springframework.beans.factory.InitializingBean;
@@ -100,6 +102,8 @@ public class MessageChannelPartitionHandler extends AbstractPartitionHandler imp
100102

101103
private long timeout = -1;
102104

105+
private StepInterruptionPolicy stepInterruptionPolicy = new ThreadStepInterruptionPolicy();
106+
103107
/**
104108
* pollable channel for the replies
105109
*/
@@ -192,6 +196,15 @@ public void setReplyChannel(PollableChannel replyChannel) {
192196
this.replyChannel = replyChannel;
193197
}
194198

199+
/**
200+
* Set the step interrupt policy for the manager step. Policy called during polling
201+
* @param stepInterruptionPolicy policy to use for polling
202+
*/
203+
public void setStepInterruptionPolicy(StepInterruptionPolicy stepInterruptionPolicy) {
204+
Assert.notNull(stepInterruptionPolicy, "StepInterruptionPolicy cannot be null");
205+
this.stepInterruptionPolicy = stepInterruptionPolicy;
206+
}
207+
195208
/**
196209
* Sends {@link StepExecutionRequest} objects to the request channel of the
197210
* {@link MessagingTemplate}, and then receives the result back as a list of
@@ -235,6 +248,9 @@ private Set<StepExecution> pollReplies(StepExecution managerStepExecution, final
235248
Set<Long> partitionStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet());
236249

237250
Callable<Set<StepExecution>> callback = () -> {
251+
252+
stepInterruptionPolicy.checkInterrupted(managerStepExecution);
253+
238254
JobExecution jobExecution = jobRepository.getJobExecution(managerStepExecution.getJobExecutionId());
239255
Set<StepExecution> finishedStepExecutions = jobExecution.getStepExecutions()
240256
.stream()

spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@
2727
import org.springframework.batch.core.BatchStatus;
2828
import org.springframework.batch.core.job.JobExecution;
2929
import org.springframework.batch.core.job.JobInstance;
30+
import org.springframework.batch.core.job.JobInterruptedException;
3031
import org.springframework.batch.core.job.parameters.JobParameters;
3132
import org.springframework.batch.core.repository.JobRepository;
3233
import org.springframework.batch.core.step.StepExecution;
3334
import org.springframework.batch.core.partition.StepExecutionSplitter;
35+
import org.springframework.batch.core.step.StepInterruptionPolicy;
3436
import org.springframework.integration.MessageTimeoutException;
3537
import org.springframework.integration.core.MessagingTemplate;
3638
import org.springframework.messaging.Message;
@@ -251,4 +253,73 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception {
251253
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
252254
}
253255

256+
@Test
257+
void testShutdownCancelsHandle() throws Exception {
258+
// execute with no default set
259+
messageChannelPartitionHandler = new MessageChannelPartitionHandler();
260+
// mock
261+
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
262+
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
263+
StepExecutionSplitter stepExecutionSplitter = mock();
264+
MessagingTemplate operations = mock();
265+
JobRepository jobRepository = mock();
266+
// when
267+
HashSet<StepExecution> stepExecutions = new HashSet<>();
268+
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
269+
partition1.setStatus(BatchStatus.STARTED);
270+
stepExecutions.add(partition1);
271+
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
272+
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
273+
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
274+
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);
275+
managerStepExecution.setTerminateOnly();
276+
277+
// set
278+
messageChannelPartitionHandler.setMessagingOperations(operations);
279+
messageChannelPartitionHandler.setJobRepository(jobRepository);
280+
messageChannelPartitionHandler.setStepName("step1");
281+
messageChannelPartitionHandler.afterPropertiesSet();
282+
283+
// execute
284+
assertThrows(JobInterruptedException.class,
285+
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
286+
}
287+
288+
@Test
289+
void testInterruptPolicy() throws Exception {
290+
String testExceptionMessage = "test exception message";
291+
// execute with no default set
292+
messageChannelPartitionHandler = new MessageChannelPartitionHandler();
293+
294+
// mock
295+
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
296+
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
297+
StepExecutionSplitter stepExecutionSplitter = mock();
298+
MessagingTemplate operations = mock();
299+
JobRepository jobRepository = mock();
300+
// when
301+
HashSet<StepExecution> stepExecutions = new HashSet<>();
302+
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
303+
partition1.setStatus(BatchStatus.STARTED);
304+
stepExecutions.add(partition1);
305+
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
306+
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
307+
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
308+
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);
309+
310+
// set
311+
messageChannelPartitionHandler.setMessagingOperations(operations);
312+
messageChannelPartitionHandler.setJobRepository(jobRepository);
313+
messageChannelPartitionHandler.setStepName("step1");
314+
messageChannelPartitionHandler.setStepInterruptionPolicy(stepExecution -> {
315+
throw new JobInterruptedException(testExceptionMessage);
316+
});
317+
messageChannelPartitionHandler.afterPropertiesSet();
318+
319+
// execute
320+
JobInterruptedException exception = assertThrows(JobInterruptedException.class,
321+
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
322+
assertEquals(testExceptionMessage, exception.getMessage());
323+
}
324+
254325
}

0 commit comments

Comments
 (0)