Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.springframework.batch.core.partition.StepExecutionSplitter;
import org.springframework.batch.core.partition.support.AbstractPartitionHandler;
import org.springframework.batch.core.repository.JobRepository;
import org.springframework.batch.core.step.StepInterruptionPolicy;
import org.springframework.batch.core.step.ThreadStepInterruptionPolicy;
import org.springframework.batch.infrastructure.poller.DirectPoller;
import org.springframework.batch.infrastructure.poller.Poller;
import org.springframework.beans.factory.InitializingBean;
Expand Down Expand Up @@ -100,6 +102,8 @@ public class MessageChannelPartitionHandler extends AbstractPartitionHandler imp

private long timeout = -1;

private StepInterruptionPolicy stepInterruptionPolicy = new ThreadStepInterruptionPolicy();

/**
* pollable channel for the replies
*/
Expand Down Expand Up @@ -192,6 +196,15 @@ public void setReplyChannel(PollableChannel replyChannel) {
this.replyChannel = replyChannel;
}

/**
* Set the step interrupt policy for the manager step. Policy called during polling
* @param stepInterruptionPolicy policy to use for polling
*/
public void setStepInterruptionPolicy(StepInterruptionPolicy stepInterruptionPolicy) {
Assert.notNull(stepInterruptionPolicy, "StepInterruptionPolicy cannot be null");
this.stepInterruptionPolicy = stepInterruptionPolicy;
}

/**
* Sends {@link StepExecutionRequest} objects to the request channel of the
* {@link MessagingTemplate}, and then receives the result back as a list of
Expand Down Expand Up @@ -235,6 +248,9 @@ private Set<StepExecution> pollReplies(StepExecution managerStepExecution, final
Set<Long> partitionStepExecutionIds = split.stream().map(StepExecution::getId).collect(Collectors.toSet());

Callable<Set<StepExecution>> callback = () -> {

stepInterruptionPolicy.checkInterrupted(managerStepExecution);

JobExecution jobExecution = jobRepository.getJobExecution(managerStepExecution.getJobExecutionId());
Set<StepExecution> finishedStepExecutions = jobExecution.getStepExecutions()
.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import org.springframework.batch.core.BatchStatus;
import org.springframework.batch.core.job.JobExecution;
import org.springframework.batch.core.job.JobInstance;
import org.springframework.batch.core.job.JobInterruptedException;
import org.springframework.batch.core.job.parameters.JobParameters;
import org.springframework.batch.core.repository.JobRepository;
import org.springframework.batch.core.step.StepExecution;
import org.springframework.batch.core.partition.StepExecutionSplitter;
import org.springframework.batch.core.step.StepInterruptionPolicy;
import org.springframework.integration.MessageTimeoutException;
import org.springframework.integration.core.MessagingTemplate;
import org.springframework.messaging.Message;
Expand Down Expand Up @@ -251,4 +253,73 @@ void testHandleWithJobRepositoryPollingTimeout() throws Exception {
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
}

@Test
void testShutdownCancelsHandle() throws Exception {
// execute with no default set
messageChannelPartitionHandler = new MessageChannelPartitionHandler();
// mock
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
StepExecutionSplitter stepExecutionSplitter = mock();
MessagingTemplate operations = mock();
JobRepository jobRepository = mock();
// when
HashSet<StepExecution> stepExecutions = new HashSet<>();
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
partition1.setStatus(BatchStatus.STARTED);
stepExecutions.add(partition1);
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);
managerStepExecution.setTerminateOnly();

// set
messageChannelPartitionHandler.setMessagingOperations(operations);
messageChannelPartitionHandler.setJobRepository(jobRepository);
messageChannelPartitionHandler.setStepName("step1");
messageChannelPartitionHandler.afterPropertiesSet();

// execute
assertThrows(JobInterruptedException.class,
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
}

@Test
void testInterruptPolicy() throws Exception {
String testExceptionMessage = "test exception message";
// execute with no default set
messageChannelPartitionHandler = new MessageChannelPartitionHandler();

// mock
JobExecution jobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
StepExecution managerStepExecution = new StepExecution(1L, "step1", jobExecution);
StepExecutionSplitter stepExecutionSplitter = mock();
MessagingTemplate operations = mock();
JobRepository jobRepository = mock();
// when
HashSet<StepExecution> stepExecutions = new HashSet<>();
StepExecution partition1 = new StepExecution(2L, "step1:partition1", jobExecution);
partition1.setStatus(BatchStatus.STARTED);
stepExecutions.add(partition1);
when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions);
JobExecution runningJobExecution = new JobExecution(5L, new JobInstance(1L, "job"), new JobParameters());
runningJobExecution.addStepExecutions(Arrays.asList(partition1));
when(jobRepository.getJobExecution(5L)).thenReturn(runningJobExecution);

// set
messageChannelPartitionHandler.setMessagingOperations(operations);
messageChannelPartitionHandler.setJobRepository(jobRepository);
messageChannelPartitionHandler.setStepName("step1");
messageChannelPartitionHandler.setStepInterruptionPolicy(stepExecution -> {
throw new JobInterruptedException(testExceptionMessage);
});
messageChannelPartitionHandler.afterPropertiesSet();

// execute
JobInterruptedException exception = assertThrows(JobInterruptedException.class,
() -> messageChannelPartitionHandler.handle(stepExecutionSplitter, managerStepExecution));
assertEquals(testExceptionMessage, exception.getMessage());
}

}