From 7c26aca534a25cefbbf9d07253e01134452e5f24 Mon Sep 17 00:00:00 2001 From: lucliu1108 Date: Thu, 28 May 2026 16:52:11 -0500 Subject: [PATCH 1/6] add supplier to catch shutdown signal while bootstrapping --- .../internals/GlobalStateManagerImpl.java | 44 +++- .../internals/GlobalStreamThread.java | 27 ++- .../internals/GlobalStateManagerImplTest.java | 211 ++++++++++++++++-- .../internals/GlobalStreamThreadTest.java | 96 ++++++++ .../kafka/streams/TopologyTestDriver.java | 3 +- 5 files changed, 360 insertions(+), 21 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java index 8e7347e6fa472..d8a963a67a9b9 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java @@ -23,6 +23,7 @@ import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.metrics.Sensor; import org.apache.kafka.common.serialization.ByteArrayDeserializer; import org.apache.kafka.common.utils.Time; @@ -64,6 +65,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.BooleanSupplier; import java.util.function.Supplier; import static org.apache.kafka.streams.StreamsConfig.PROCESSING_EXCEPTION_HANDLER_CLASS_CONFIG; @@ -123,6 +125,7 @@ private static class StateStoreMetadata { private DeserializationExceptionHandler deserializationExceptionHandler; private ProcessingExceptionHandler processingExceptionHandler; private Sensor droppedRecordsSensor; + private BooleanSupplier shouldShutDownSupplier; public GlobalStateManagerImpl(final LogContext logContext, final Time time, @@ -130,7 +133,8 @@ public GlobalStateManagerImpl(final LogContext logContext, final Consumer globalConsumer, final StateDirectory stateDirectory, final StateRestoreListener stateRestoreListener, - final StreamsConfig config) { + final StreamsConfig config, + final BooleanSupplier shouldShutDown) { this.time = time; this.topology = topology; this.stateDirectory = stateDirectory; @@ -147,6 +151,7 @@ public GlobalStateManagerImpl(final LogContext logContext, logPrefix = logContext.logPrefix(); this.globalConsumer = globalConsumer; this.stateRestoreListener = stateRestoreListener; + this.shouldShutDownSupplier = shouldShutDown; final Map consumerProps = config.getGlobalConsumerConfigs("dummy"); // need to add mandatory configs; otherwise `QuietConsumerConfig` throws @@ -209,6 +214,10 @@ public Set initialize() { LegacyCheckpointingStateStore.migrateLegacyOffsets(logPrefix, stateDirectory, null, wrappedStores); for (final StateStoreMetadata metadata : storeMetadata.values()) { + if(shouldShutDownSupplier.getAsBoolean()) { + log.info("Global store bootstrap interrupted by shutdown before starting {}", metadata.stateStore.name()); + break; + } // load the committed offsets from the store final StateStore store = metadata.stateStore; if (store.persistent()) { @@ -348,7 +357,22 @@ private void reprocessState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - final ConsumerRecords records = globalConsumer.poll(pollMsPlusRequestTimeout); + if(shouldShutDownSupplier.getAsBoolean()) { + log.info("Global store bootstrap interrupted by shutdown before starting {}", storeMetadata.stateStore.name()); + return; + } + + final ConsumerRecords records; + try { + records = globalConsumer.poll(pollMsPlusRequestTimeout); + } catch (final WakeupException e) { + if (shouldShutDownSupplier.getAsBoolean()) { + log.info("Bootstrap interrupted by shutdown for {}", + storeMetadata.stateStore.name()); + return; + } + throw e; + } if (records.isEmpty()) { currentDeadline = maybeUpdateDeadlineOrThrow(currentDeadline); } else { @@ -493,7 +517,21 @@ private void restoreState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - final ConsumerRecords records = globalConsumer.poll(pollMsPlusRequestTimeout); + if(shouldShutDownSupplier.getAsBoolean()) { + log.info("Global store bootstrap interrupted by shutdown before starting {}", storeMetadata.stateStore.name()); + return; + } + final ConsumerRecords records; + try { + records = globalConsumer.poll(pollMsPlusRequestTimeout); + } catch (final WakeupException e) { + if (shouldShutDownSupplier.getAsBoolean()) { + log.info("Bootstrap interrupted by shutdown for {}", + storeMetadata.stateStore.name()); + return; + } + throw e; + } if (records.isEmpty()) { currentDeadline = maybeUpdateDeadlineOrThrow(currentDeadline); } else { diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java index bede888525ad9..2b19d2f191e06 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java @@ -26,6 +26,7 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.internals.KafkaFutureImpl; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.internals.LogContext; @@ -49,6 +50,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BooleanSupplier; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.PENDING_SHUTDOWN; @@ -299,7 +301,13 @@ public void run() { if (size != -1L) { cache.resize(size); } - stateConsumer.pollAndUpdate(); + try { + stateConsumer.pollAndUpdate(); + } catch (final WakeupException e) { + if (!inErrorState()) { + throw e; + } + } if (fetchDeadlineClientInstanceId != -1) { if (fetchDeadlineClientInstanceId >= time.milliseconds()) { @@ -382,7 +390,8 @@ private StateConsumer initialize() { globalConsumer, stateDirectory, stateRestoreListener, - config + config, + () -> inErrorState() ); final GlobalProcessorContextImpl globalProcessorContext = new GlobalProcessorContextImpl( @@ -428,9 +437,22 @@ private StateConsumer initialize() { recoverableException ); } + + if (inErrorState()) { + closeStateConsumer(stateConsumer, false); + return null; + } setState(RUNNING); return stateConsumer; + } catch (final WakeupException e) { + closeStateConsumer(stateConsumer, false); + if (inErrorState()) { + log.info("Global thread initialization interrupted by shutdown"); + } else { + startupException = new StreamsException( + "Unexpected wakeup during initialization of GlobalStreamThread", e); + } } catch (final StreamsException fatalException) { closeStateConsumer(stateConsumer, false); startupException = fatalException; @@ -477,6 +499,7 @@ public void shutdown() { // if already shutting down or dead setState(PENDING_SHUTDOWN); initializationLatch.countDown(); + globalConsumer.wakeup(); } public Map consumerMetrics() { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java index 4759bdc050841..68737f7fd2504 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java @@ -23,6 +23,7 @@ import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.serialization.Deserializer; import org.apache.kafka.common.utils.LogCaptureAppender; import org.apache.kafka.common.utils.MockTime; @@ -63,6 +64,7 @@ import java.util.Optional; import java.util.Properties; import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static java.util.Arrays.asList; @@ -168,7 +170,8 @@ public void before() { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext = new InternalMockProcessorContext(stateDirectory.globalStateDir(), streamsConfig); stateManager.setGlobalProcessorContext(processorContext); @@ -639,7 +642,8 @@ public synchronized Map endOffsets(final Collection false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -682,7 +686,8 @@ public synchronized Map endOffsets(final Collection false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -723,7 +728,8 @@ public synchronized Map endOffsets(final Collection false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -771,7 +777,8 @@ public synchronized long position(final TopicPartition partition) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -805,7 +812,8 @@ public List partitionsFor(final String topic) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -848,7 +856,8 @@ public List partitionsFor(final String topic) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -889,7 +898,8 @@ public List partitionsFor(final String topic) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -937,7 +947,8 @@ public synchronized long position(final TopicPartition partition) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -971,7 +982,8 @@ public synchronized long position(final TopicPartition partition) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1014,7 +1026,8 @@ public synchronized long position(final TopicPartition partition) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1055,7 +1068,8 @@ public synchronized long position(final TopicPartition partition) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1098,7 +1112,8 @@ public synchronized long position(final TopicPartition partition) { consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1139,7 +1154,8 @@ public synchronized ConsumerRecords poll(final Duration timeout) consumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1214,7 +1230,8 @@ public void shouldWriteDowngradeCheckpointOnCloseWhenUpgradeFromIsPre43() throws consumer, downgradeStateDir, stateRestoreListener, - downgradeConfig + downgradeConfig, + () -> false ); final InternalMockProcessorContext downgradeContext = @@ -1252,6 +1269,170 @@ public void shouldNotWriteDowngradeCheckpointOnCloseWhenUpgradeFromIsNull() { assertFalse(legacyGlobalFile.exists()); } + @Test + public void shouldAbortRestoreWhenSupplierFlipsToShutdown() { + final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + stateManager = new GlobalStateManagerImpl( + new LogContext("test"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig, + shouldShutDown::get + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + initializeConsumer(6, 1, t1); + initializeConsumer(0, 0, t2, t3, t4, t5); + + shouldShutDown.set(true); + + stateManager.initialize(); + + // Nothing should have been restored + assertEquals(0L, stateRestoreListener.totalNumRestored); + } + + @Test + public void shutAbortRestoreWhenSupplierFlipsMidRestore() { + final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + final AtomicInteger restoredCount = new AtomicInteger(0); + final MockStateRestoreListener flippingRestoreListener = new MockStateRestoreListener() { + @Override + public void onBatchRestored(final TopicPartition tp, + final String storeName, + final long batchEndOffset, + final long numRestored) { + super.onBatchRestored(tp, storeName, batchEndOffset, numRestored); + restoredCount.addAndGet((int) numRestored); + if (numRestored > 0) { + shouldShutDown.set(true); + } + } + }; + stateManager = new GlobalStateManagerImpl( + new LogContext("test"), + time, + topology, + consumer, + stateDirectory, + flippingRestoreListener, + streamsConfig, + shouldShutDown::get + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + initializeConsumer(6, 1, t1); + initializeConsumer(0, 0, t2, t3, t4, t5); + consumer.setMaxPollRecords(2L); + + stateManager.initialize(); + + assertEquals(2, restoredCount.get()); + } + + @Test + public void shouldExitCleanlyOnWakeupDuringBootstrapWhenShuttingDown() { + final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + final AtomicInteger pollCount = new AtomicInteger(0); + consumer = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { + @Override + public ConsumerRecords poll(final Duration timeout) { + pollCount.incrementAndGet(); + shouldShutDown.set(true); + throw new WakeupException(); + } + }; + initializeConsumer(6, 1, t1); + initializeConsumer(0, 0, t2, t3, t4, t5); + + stateManager = new GlobalStateManagerImpl( + new LogContext("test"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig, + shouldShutDown::get + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + stateManager.initialize(); + + assertEquals(1, pollCount.get()); + assertTrue(shouldShutDown.get()); + } + + @Test + public void shouldPropagateWakeupDuringBootstrapWhenNotShuttingDown() { + final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + consumer = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { + @Override + public ConsumerRecords poll(final Duration timeout) { + throw new WakeupException(); + } + }; + initializeConsumer(6, 1, t1); + initializeConsumer(0, 0, t2, t3, t4, t5); + + stateManager = new GlobalStateManagerImpl( + new LogContext("test"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig, + shouldShutDown::get + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + assertThrows(WakeupException.class, () -> stateManager.initialize()); + } + + @Test + public void shouldExitCleanlyOnWakeupDuringReprocessingWhenShuttingDown() { + setUpReprocessing(); + + final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + final AtomicInteger pollCount = new AtomicInteger(0); + consumer = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { + @Override + public ConsumerRecords poll(final Duration timeout) { + pollCount.incrementAndGet(); + shouldShutDown.set(true); + throw new WakeupException(); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + initializeConsumer(6, 1, t5); + + stateManager = new GlobalStateManagerImpl( + new LogContext("test"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig, + shouldShutDown::get + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + stateManager.initialize(); + + assertEquals(1, pollCount.get()); + assertTrue(shouldShutDown.get()); + } + private void writeCorruptCheckpoint() throws IOException { final File checkpointFile = new File(stateManager.baseDir(), StateManagerUtil.CHECKPOINT_FILE_NAME); try (final OutputStream stream = Files.newOutputStream(checkpointFile.toPath())) { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java index 7c0216cd07f58..12ecd01a6c7b1 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java @@ -25,6 +25,7 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.serialization.Serdes; import org.apache.kafka.common.utils.Bytes; @@ -57,6 +58,7 @@ import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.RUNNING; @@ -67,6 +69,7 @@ import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -243,6 +246,99 @@ public void shouldTransitionToRunningOnStart() throws Exception { globalStreamThread.shutdown(); } + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void shouldShutdownDuringBootstrap() throws Exception { + initializeConsumer(); + mockConsumer.updateEndOffsets(Collections.singletonMap(topicPartition, 1_000_000L)); + + final Thread shutdownThread = new Thread(() -> { + try { + TestUtils.waitForCondition( + () -> stateRestoreListener.storeNameCalledStates.containsKey(MockStateRestoreListener.RESTORE_START), + 10 * 1000L, + "Bootstrap restore never started."); + } catch (final Exception e) { + throw new RuntimeException(e); + } + globalStreamThread.shutdown(); + }); + shutdownThread.start(); + + startAndSwallowError(); + shutdownThread.join(); + globalStreamThread.join(5_000); + + assertEquals(DEAD, globalStreamThread.state()); + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void shouldNotInvokeUncaughtExceptionHandlerOnCloseAfterStart() throws Exception { + final AtomicReference caughtException = new AtomicReference<>(); + globalStreamThread.setUncaughtExceptionHandler(caughtException::set); + + initializeConsumer(); + startAndSwallowError(); + + TestUtils.waitForCondition( + () -> globalStreamThread.state() == RUNNING, + 10 * 1000, + "Thread never started."); + + mockConsumer.setMaxPollRecords(1L); + mockConsumer.updateEndOffsets(Collections.singletonMap(topicPartition, 50L)); + for (long offset = 0L; offset < 50L; offset++) { + mockConsumer.addRecord(record(GLOBAL_STORE_TOPIC_NAME, 0, offset, "k".getBytes(), "v".getBytes())); + } + + TestUtils.waitForCondition( + () -> mockConsumer.position(topicPartition) >= 1L, + 10 * 1000, + "First record never consumed by the main loop."); + + // Capture position before shutdown + // Else, afterwards the consumer is closed and calling position() throws IllegalStateException. + final long positionBeforeShutdown = mockConsumer.position(topicPartition); + + globalStreamThread.shutdown(); + globalStreamThread.join(); + + assertEquals(DEAD, globalStreamThread.state()); + assertNull(caughtException.get()); + assertTrue(positionBeforeShutdown < 10L, + "Shutdown should have interrupted the main loop before all records were consumed; position was " + + positionBeforeShutdown); + } + + @Test + public void shouldThrowStreamsExceptionOnStartupIfWakeupOccursWithoutShutdown() throws Exception { + final MockConsumer wakeupOnPartitionsFor = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { + @Override + public List partitionsFor(final String topic) { + throw new WakeupException(); + } + }; + globalStreamThread = new GlobalStreamThread( + builder.rewriteTopology(config).buildGlobalStateTopology(), + config, + wakeupOnPartitionsFor, + new StateDirectory(config, time, true, false), + 0, + new StreamsMetricsImpl(new Metrics(), "test-client", time), + time, + "clientId", + stateRestoreListener, + e -> { } + ); + + final StreamsException e = assertThrows(StreamsException.class, () -> globalStreamThread.start()); + assertThat(e.getCause(), instanceOf(WakeupException.class)); + + globalStreamThread.join(); + assertFalse(globalStreamThread.stillRunning()); + } + @Test public void shouldDieOnInvalidOffsetExceptionDuringStartup() throws Exception { final StateStore globalStore = builder.globalStateStores().get(GLOBAL_STORE_NAME); diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java index 2738458062a58..424124bebf919 100644 --- a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java @@ -441,7 +441,8 @@ private void setupGlobalTask(final Time mockWallClockTime, globalConsumer, stateDirectory, stateRestoreListener, - streamsConfig + streamsConfig, + () -> false ); final GlobalProcessorContextImpl globalProcessorContext = From 627150c78f9468aebbd67d6a916ebbd64834d1b5 Mon Sep 17 00:00:00 2001 From: lucliu1108 Date: Fri, 29 May 2026 12:25:08 -0500 Subject: [PATCH 2/6] cleanup --- .../internals/GlobalStateManagerImpl.java | 32 +++++++++-------- .../internals/GlobalStreamThread.java | 5 ++- .../internals/GlobalStateManagerImplTest.java | 34 +++++++++---------- .../internals/GlobalStreamThreadTest.java | 4 +-- 4 files changed, 38 insertions(+), 37 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java index d8a963a67a9b9..99c650c20f03d 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java @@ -125,7 +125,7 @@ private static class StateStoreMetadata { private DeserializationExceptionHandler deserializationExceptionHandler; private ProcessingExceptionHandler processingExceptionHandler; private Sensor droppedRecordsSensor; - private BooleanSupplier shouldShutDownSupplier; + private BooleanSupplier inErrorStateSupplier; public GlobalStateManagerImpl(final LogContext logContext, final Time time, @@ -134,7 +134,7 @@ public GlobalStateManagerImpl(final LogContext logContext, final StateDirectory stateDirectory, final StateRestoreListener stateRestoreListener, final StreamsConfig config, - final BooleanSupplier shouldShutDown) { + final BooleanSupplier inErrorStateSupplier) { this.time = time; this.topology = topology; this.stateDirectory = stateDirectory; @@ -151,7 +151,7 @@ public GlobalStateManagerImpl(final LogContext logContext, logPrefix = logContext.logPrefix(); this.globalConsumer = globalConsumer; this.stateRestoreListener = stateRestoreListener; - this.shouldShutDownSupplier = shouldShutDown; + this.inErrorStateSupplier = inErrorStateSupplier; final Map consumerProps = config.getGlobalConsumerConfigs("dummy"); // need to add mandatory configs; otherwise `QuietConsumerConfig` throws @@ -214,7 +214,7 @@ public Set initialize() { LegacyCheckpointingStateStore.migrateLegacyOffsets(logPrefix, stateDirectory, null, wrappedStores); for (final StateStoreMetadata metadata : storeMetadata.values()) { - if(shouldShutDownSupplier.getAsBoolean()) { + if (inErrorStateSupplier.getAsBoolean()) { log.info("Global store bootstrap interrupted by shutdown before starting {}", metadata.stateStore.name()); break; } @@ -357,8 +357,8 @@ private void reprocessState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - if(shouldShutDownSupplier.getAsBoolean()) { - log.info("Global store bootstrap interrupted by shutdown before starting {}", storeMetadata.stateStore.name()); + if (inErrorStateSupplier.getAsBoolean()) { + logBootstrapInterrupted(storeMetadata); return; } @@ -366,9 +366,8 @@ private void reprocessState(final StateStoreMetadata storeMetadata) { try { records = globalConsumer.poll(pollMsPlusRequestTimeout); } catch (final WakeupException e) { - if (shouldShutDownSupplier.getAsBoolean()) { - log.info("Bootstrap interrupted by shutdown for {}", - storeMetadata.stateStore.name()); + if (inErrorStateSupplier.getAsBoolean()) { + logBootstrapInterrupted(storeMetadata); return; } throw e; @@ -461,7 +460,7 @@ private void reprocessState(final StateStoreMetadata storeMetadata) { fatalUserException ); } - + if (response.result() == ProcessingExceptionHandler.Result.FAIL) { log.error("Processing exception handler is set to fail upon" + " a processing error. If you would rather have the streaming pipeline" + @@ -517,17 +516,16 @@ private void restoreState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - if(shouldShutDownSupplier.getAsBoolean()) { - log.info("Global store bootstrap interrupted by shutdown before starting {}", storeMetadata.stateStore.name()); + if (inErrorStateSupplier.getAsBoolean()) { + logBootstrapInterrupted(storeMetadata); return; } final ConsumerRecords records; try { records = globalConsumer.poll(pollMsPlusRequestTimeout); } catch (final WakeupException e) { - if (shouldShutDownSupplier.getAsBoolean()) { - log.info("Bootstrap interrupted by shutdown for {}", - storeMetadata.stateStore.name()); + if (inErrorStateSupplier.getAsBoolean()) { + logBootstrapInterrupted(storeMetadata); return; } throw e; @@ -556,6 +554,10 @@ private void restoreState(final StateStoreMetadata storeMetadata) { } } + private void logBootstrapInterrupted(final StateStoreMetadata storeMetadata) { + log.info("Bootstrap interrupted by shutdown for {}", storeMetadata.stateStore.name()); + } + private long getGlobalConsumerOffset(final TopicPartition topicPartition) { return retryUntilSuccessOrThrowOnTaskTimeout( () -> globalConsumer.position(topicPartition), diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java index 2b19d2f191e06..7f54085e371fe 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java @@ -50,7 +50,6 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.BooleanSupplier; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.PENDING_SHUTDOWN; @@ -391,7 +390,7 @@ private StateConsumer initialize() { stateDirectory, stateRestoreListener, config, - () -> inErrorState() + this::inErrorState ); final GlobalProcessorContextImpl globalProcessorContext = new GlobalProcessorContextImpl( @@ -437,7 +436,7 @@ private StateConsumer initialize() { recoverableException ); } - + if (inErrorState()) { closeStateConsumer(stateConsumer, false); return null; diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java index 68737f7fd2504..089b41ed80b65 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java @@ -1271,7 +1271,7 @@ public void shouldNotWriteDowngradeCheckpointOnCloseWhenUpgradeFromIsNull() { @Test public void shouldAbortRestoreWhenSupplierFlipsToShutdown() { - final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + final AtomicBoolean inErrorState = new AtomicBoolean(false); stateManager = new GlobalStateManagerImpl( new LogContext("test"), time, @@ -1280,7 +1280,7 @@ public void shouldAbortRestoreWhenSupplierFlipsToShutdown() { stateDirectory, stateRestoreListener, streamsConfig, - shouldShutDown::get + inErrorState::get ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1288,7 +1288,7 @@ public void shouldAbortRestoreWhenSupplierFlipsToShutdown() { initializeConsumer(6, 1, t1); initializeConsumer(0, 0, t2, t3, t4, t5); - shouldShutDown.set(true); + inErrorState.set(true); stateManager.initialize(); @@ -1297,8 +1297,8 @@ public void shouldAbortRestoreWhenSupplierFlipsToShutdown() { } @Test - public void shutAbortRestoreWhenSupplierFlipsMidRestore() { - final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + public void shouldAbortRestoreWhenSupplierFlipsMidRestore() { + final AtomicBoolean inErrorState = new AtomicBoolean(false); final AtomicInteger restoredCount = new AtomicInteger(0); final MockStateRestoreListener flippingRestoreListener = new MockStateRestoreListener() { @Override @@ -1309,7 +1309,7 @@ public void onBatchRestored(final TopicPartition tp, super.onBatchRestored(tp, storeName, batchEndOffset, numRestored); restoredCount.addAndGet((int) numRestored); if (numRestored > 0) { - shouldShutDown.set(true); + inErrorState.set(true); } } }; @@ -1321,7 +1321,7 @@ public void onBatchRestored(final TopicPartition tp, stateDirectory, flippingRestoreListener, streamsConfig, - shouldShutDown::get + inErrorState::get ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1337,13 +1337,13 @@ public void onBatchRestored(final TopicPartition tp, @Test public void shouldExitCleanlyOnWakeupDuringBootstrapWhenShuttingDown() { - final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + final AtomicBoolean inErrorState = new AtomicBoolean(false); final AtomicInteger pollCount = new AtomicInteger(0); consumer = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { @Override public ConsumerRecords poll(final Duration timeout) { pollCount.incrementAndGet(); - shouldShutDown.set(true); + inErrorState.set(true); throw new WakeupException(); } }; @@ -1358,7 +1358,7 @@ public ConsumerRecords poll(final Duration timeout) { stateDirectory, stateRestoreListener, streamsConfig, - shouldShutDown::get + inErrorState::get ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1366,12 +1366,12 @@ public ConsumerRecords poll(final Duration timeout) { stateManager.initialize(); assertEquals(1, pollCount.get()); - assertTrue(shouldShutDown.get()); + assertTrue(inErrorState.get()); } @Test public void shouldPropagateWakeupDuringBootstrapWhenNotShuttingDown() { - final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + final AtomicBoolean inErrorState = new AtomicBoolean(false); consumer = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { @Override public ConsumerRecords poll(final Duration timeout) { @@ -1389,7 +1389,7 @@ public ConsumerRecords poll(final Duration timeout) { stateDirectory, stateRestoreListener, streamsConfig, - shouldShutDown::get + inErrorState::get ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1401,13 +1401,13 @@ public ConsumerRecords poll(final Duration timeout) { public void shouldExitCleanlyOnWakeupDuringReprocessingWhenShuttingDown() { setUpReprocessing(); - final AtomicBoolean shouldShutDown = new AtomicBoolean(false); + final AtomicBoolean inErrorState = new AtomicBoolean(false); final AtomicInteger pollCount = new AtomicInteger(0); consumer = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { @Override public ConsumerRecords poll(final Duration timeout) { pollCount.incrementAndGet(); - shouldShutDown.set(true); + inErrorState.set(true); throw new WakeupException(); } }; @@ -1422,7 +1422,7 @@ public ConsumerRecords poll(final Duration timeout) { stateDirectory, stateRestoreListener, streamsConfig, - shouldShutDown::get + inErrorState::get ); processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); @@ -1430,7 +1430,7 @@ public ConsumerRecords poll(final Duration timeout) { stateManager.initialize(); assertEquals(1, pollCount.get()); - assertTrue(shouldShutDown.get()); + assertTrue(inErrorState.get()); } private void writeCorruptCheckpoint() throws IOException { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java index 12ecd01a6c7b1..73c61a61dab88 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java @@ -306,11 +306,11 @@ public void shouldNotInvokeUncaughtExceptionHandlerOnCloseAfterStart() throws Ex assertEquals(DEAD, globalStreamThread.state()); assertNull(caughtException.get()); - assertTrue(positionBeforeShutdown < 10L, + assertTrue(positionBeforeShutdown < 50L, "Shutdown should have interrupted the main loop before all records were consumed; position was " + positionBeforeShutdown); } - + @Test public void shouldThrowStreamsExceptionOnStartupIfWakeupOccursWithoutShutdown() throws Exception { final MockConsumer wakeupOnPartitionsFor = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { From bf57fefc953619d79f62219d4e33cde70ce96e61 Mon Sep 17 00:00:00 2001 From: lucliu1108 Date: Fri, 29 May 2026 12:37:45 -0500 Subject: [PATCH 3/6] remove trailing spaces --- .../streams/processor/internals/GlobalStateManagerImpl.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java index 99c650c20f03d..32a45316f4893 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java @@ -460,7 +460,7 @@ private void reprocessState(final StateStoreMetadata storeMetadata) { fatalUserException ); } - + if (response.result() == ProcessingExceptionHandler.Result.FAIL) { log.error("Processing exception handler is set to fail upon" + " a processing error. If you would rather have the streaming pipeline" + From 56abb15cdc30b57692820f25f47a9cbe4181e4ca Mon Sep 17 00:00:00 2001 From: lucliu1108 Date: Tue, 2 Jun 2026 16:38:09 -0500 Subject: [PATCH 4/6] revise what state manager return when bootstrapping stopped --- .../internals/GlobalStateManager.java | 6 +- .../internals/GlobalStateManagerImpl.java | 132 ++++++++---------- .../internals/GlobalStateUpdateTask.java | 11 +- .../internals/GlobalStreamThread.java | 27 ++-- .../internals/GlobalStateManagerImplTest.java | 8 +- .../internals/GlobalStateUpdateTaskTest.java | 77 ++++++++++ .../internals/GlobalStreamThreadTest.java | 41 ------ .../kafka/test/GlobalStateManagerStub.java | 5 +- 8 files changed, 168 insertions(+), 139 deletions(-) create mode 100644 streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTaskTest.java diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManager.java index f470254142ede..697ca7da58fb0 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManager.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManager.java @@ -18,6 +18,7 @@ import org.apache.kafka.streams.errors.StreamsException; +import java.util.Optional; import java.util.Set; public interface GlobalStateManager extends StateManager { @@ -25,8 +26,11 @@ public interface GlobalStateManager extends StateManager { void setGlobalProcessorContext(final InternalProcessorContext processorContext); /** + * Bootstraps all global state stores. Returns the set of registered store names on success, + * or {@link Optional#empty()} if bootstrap was interrupted by a shutdown request. + * * @throws IllegalStateException If store gets registered after initialized is already finished * @throws StreamsException if the store's change log does not contain the partition */ - Set initialize(); + Optional> initialize(); } diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java index 32a45316f4893..2abb8cb75b3e8 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java @@ -125,7 +125,7 @@ private static class StateStoreMetadata { private DeserializationExceptionHandler deserializationExceptionHandler; private ProcessingExceptionHandler processingExceptionHandler; private Sensor droppedRecordsSensor; - private BooleanSupplier inErrorStateSupplier; + private BooleanSupplier shouldStopBootstrappingSupplier; public GlobalStateManagerImpl(final LogContext logContext, final Time time, @@ -134,7 +134,7 @@ public GlobalStateManagerImpl(final LogContext logContext, final StateDirectory stateDirectory, final StateRestoreListener stateRestoreListener, final StreamsConfig config, - final BooleanSupplier inErrorStateSupplier) { + final BooleanSupplier shouldStopBootstrappingSupplier) { this.time = time; this.topology = topology; this.stateDirectory = stateDirectory; @@ -151,7 +151,7 @@ public GlobalStateManagerImpl(final LogContext logContext, logPrefix = logContext.logPrefix(); this.globalConsumer = globalConsumer; this.stateRestoreListener = stateRestoreListener; - this.inErrorStateSupplier = inErrorStateSupplier; + this.shouldStopBootstrappingSupplier = shouldStopBootstrappingSupplier; final Map consumerProps = config.getGlobalConsumerConfigs("dummy"); // need to add mandatory configs; otherwise `QuietConsumerConfig` throws @@ -178,70 +178,78 @@ public void setGlobalProcessorContext(final InternalProcessorContext globa } @Override - public Set initialize() { + public Optional> initialize() { droppedRecordsSensor = droppedRecordsSensor( Thread.currentThread().getName(), globalProcessorContext.taskId().toString(), globalProcessorContext.metrics() ); - final Map wrappedStores = new HashMap<>(); - for (final StateStore stateStore : topology.globalStateStores()) { - final List storePartitions = topicPartitionsForStore(stateStore); - final StateStore maybeWrappedStore = LegacyCheckpointingStateStore.maybeWrapStore( - stateStore, eosEnabled, new HashSet<>(storePartitions), stateDirectory, null, logPrefix); - try { - maybeWrappedStore.init(globalProcessorContext, maybeWrappedStore); - } catch (final ProcessorStateException e) { - if (eosEnabled) { - log.warn("{}Detected unclean shutdown for global store {}. " + - "Wiping global state directory.", logPrefix, stateStore.name(), e); - try { - Utils.delete(stateDirectory.globalStateDir().getAbsoluteFile()); - } catch (final IOException ioe) { - e.addSuppressed(ioe); + try { + final Map wrappedStores = new HashMap<>(); + for (final StateStore stateStore : topology.globalStateStores()) { + final List storePartitions = topicPartitionsForStore(stateStore); + final StateStore maybeWrappedStore = LegacyCheckpointingStateStore.maybeWrapStore( + stateStore, eosEnabled, new HashSet<>(storePartitions), stateDirectory, null, logPrefix); + try { + maybeWrappedStore.init(globalProcessorContext, maybeWrappedStore); + } catch (final ProcessorStateException e) { + if (eosEnabled) { + log.warn("{}Detected unclean shutdown for global store {}. " + + "Wiping global state directory.", logPrefix, stateStore.name(), e); + try { + Utils.delete(stateDirectory.globalStateDir().getAbsoluteFile()); + } catch (final IOException ioe) { + e.addSuppressed(ioe); + } } + throw e; } - throw e; - } - for (final TopicPartition storePartition : storePartitions) { - wrappedStores.put(storePartition, maybeWrappedStore); + for (final TopicPartition storePartition : storePartitions) { + wrappedStores.put(storePartition, maybeWrappedStore); + } } - } - // migrate offsets from legacy checkpoint file into the stores - LegacyCheckpointingStateStore.migrateLegacyOffsets(logPrefix, stateDirectory, null, wrappedStores); + // migrate offsets from legacy checkpoint file into the stores + LegacyCheckpointingStateStore.migrateLegacyOffsets(logPrefix, stateDirectory, null, wrappedStores); - for (final StateStoreMetadata metadata : storeMetadata.values()) { - if (inErrorStateSupplier.getAsBoolean()) { - log.info("Global store bootstrap interrupted by shutdown before starting {}", metadata.stateStore.name()); - break; - } - // load the committed offsets from the store - final StateStore store = metadata.stateStore; - if (store.persistent()) { - for (final TopicPartition partition : metadata.changelogPartitions) { - final Long offset = store.committedOffset(partition); - if (offset != null) { - currentOffsets.put(partition, offset); + for (final StateStoreMetadata metadata : storeMetadata.values()) { + if (shouldStopBootstrappingSupplier.getAsBoolean()) { + log.info("Global store bootstrap interrupted by shutdown before starting {}", metadata.stateStore.name()); + return Optional.empty(); + } + // load the committed offsets from the store + final StateStore store = metadata.stateStore; + if (store.persistent()) { + for (final TopicPartition partition : metadata.changelogPartitions) { + final Long offset = store.committedOffset(partition); + if (offset != null) { + currentOffsets.put(partition, offset); + } } } - } - // restore or reprocess each registered store using the now-populated currentOffsets - try { - if (metadata.reprocessFactory.isPresent()) { - reprocessState(metadata); - } else { - restoreState(metadata); + // restore or reprocess each registered store using the now-populated currentOffsets + try { + if (metadata.reprocessFactory.isPresent()) { + reprocessState(metadata); + } else { + restoreState(metadata); + } + } finally { + globalConsumer.unsubscribe(); } - } finally { - globalConsumer.unsubscribe(); } - } - return Collections.unmodifiableSet(globalStoreNames); + return Optional.of(Collections.unmodifiableSet(globalStoreNames)); + } catch (final WakeupException e) { + if (!shouldStopBootstrappingSupplier.getAsBoolean()) { + throw e; + } + log.info("Global store bootstrap interrupted by shutdown"); + return Optional.empty(); + } } public StateStore globalStore(final String name) { @@ -357,21 +365,12 @@ private void reprocessState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - if (inErrorStateSupplier.getAsBoolean()) { + if (shouldStopBootstrappingSupplier.getAsBoolean()) { logBootstrapInterrupted(storeMetadata); return; } - final ConsumerRecords records; - try { - records = globalConsumer.poll(pollMsPlusRequestTimeout); - } catch (final WakeupException e) { - if (inErrorStateSupplier.getAsBoolean()) { - logBootstrapInterrupted(storeMetadata); - return; - } - throw e; - } + final ConsumerRecords records = globalConsumer.poll(pollMsPlusRequestTimeout); if (records.isEmpty()) { currentDeadline = maybeUpdateDeadlineOrThrow(currentDeadline); } else { @@ -516,20 +515,11 @@ private void restoreState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - if (inErrorStateSupplier.getAsBoolean()) { + if (shouldStopBootstrappingSupplier.getAsBoolean()) { logBootstrapInterrupted(storeMetadata); return; } - final ConsumerRecords records; - try { - records = globalConsumer.poll(pollMsPlusRequestTimeout); - } catch (final WakeupException e) { - if (inErrorStateSupplier.getAsBoolean()) { - logBootstrapInterrupted(storeMetadata); - return; - } - throw e; - } + final ConsumerRecords records = globalConsumer.poll(pollMsPlusRequestTimeout); if (records.isEmpty()) { currentDeadline = maybeUpdateDeadlineOrThrow(currentDeadline); } else { diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java index 3717859845155..68c0202db94f0 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java @@ -29,8 +29,10 @@ import org.slf4j.Logger; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.Set; import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; @@ -79,9 +81,14 @@ public GlobalStateUpdateTask(final LogContext logContext, */ @Override public Map initialize() { - final Set storeNames = stateMgr.initialize(); + final Optional> storeNames = stateMgr.initialize(); + if (storeNames.isEmpty()) { + // bootstrap was interrupted by shutdown; skip topology/processor init to avoid + // opening user resources via Processor#init() during a shutdown. + return Collections.emptyMap(); + } final Map storeNameToTopic = topology.storeToChangelogTopic(); - for (final String storeName : storeNames) { + for (final String storeName : storeNames.get()) { final String sourceTopic = storeNameToTopic.get(storeName); final SourceNode source = topology.source(sourceTopic); deserializers.put( diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java index 7f54085e371fe..292d369630db5 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java @@ -26,7 +26,6 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.Uuid; import org.apache.kafka.common.errors.TimeoutException; -import org.apache.kafka.common.errors.WakeupException; import org.apache.kafka.common.internals.KafkaFutureImpl; import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.internals.LogContext; @@ -300,13 +299,7 @@ public void run() { if (size != -1L) { cache.resize(size); } - try { - stateConsumer.pollAndUpdate(); - } catch (final WakeupException e) { - if (!inErrorState()) { - throw e; - } - } + stateConsumer.pollAndUpdate(); if (fetchDeadlineClientInstanceId != -1) { if (fetchDeadlineClientInstanceId >= time.milliseconds()) { @@ -444,14 +437,6 @@ private StateConsumer initialize() { setState(RUNNING); return stateConsumer; - } catch (final WakeupException e) { - closeStateConsumer(stateConsumer, false); - if (inErrorState()) { - log.info("Global thread initialization interrupted by shutdown"); - } else { - startupException = new StreamsException( - "Unexpected wakeup during initialization of GlobalStreamThread", e); - } } catch (final StreamsException fatalException) { closeStateConsumer(stateConsumer, false); startupException = fatalException; @@ -496,9 +481,15 @@ public synchronized void start() { public void shutdown() { // one could call shutdown() multiple times, so ignore subsequent calls // if already shutting down or dead - setState(PENDING_SHUTDOWN); + final boolean wakeupBootstrap; + synchronized (stateLock) { + wakeupBootstrap = (state == State.CREATED); + setState(PENDING_SHUTDOWN); + } initializationLatch.countDown(); - globalConsumer.wakeup(); + if (wakeupBootstrap) { + globalConsumer.wakeup(); + } } public Map consumerMetrics() { diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java index 089b41ed80b65..f141ab2127032 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java @@ -240,8 +240,8 @@ public void shouldInitializeStateStores() { @Test public void shouldReturnInitializedStoreNames() { initializeConsumer(0, 0, t1, t2, t3, t4, t5); - final Set storeNames = stateManager.initialize(); - assertEquals(Set.of(storeName1, storeName2, storeName3, storeName4, storeName5), storeNames); + final Optional> storeNames = stateManager.initialize(); + assertEquals(Optional.of(Set.of(storeName1, storeName2, storeName3, storeName4, storeName5)), storeNames); } @Test @@ -1363,7 +1363,7 @@ public ConsumerRecords poll(final Duration timeout) { processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); - stateManager.initialize(); + assertEquals(Optional.empty(), stateManager.initialize()); assertEquals(1, pollCount.get()); assertTrue(inErrorState.get()); @@ -1427,7 +1427,7 @@ public ConsumerRecords poll(final Duration timeout) { processorContext.setStateManger(stateManager); stateManager.setGlobalProcessorContext(processorContext); - stateManager.initialize(); + assertEquals(Optional.empty(), stateManager.initialize()); assertEquals(1, pollCount.get()); assertTrue(inErrorState.get()); diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTaskTest.java new file mode 100644 index 0000000000000..b3ef500c1d128 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTaskTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.internals.LogContext; +import org.apache.kafka.streams.errors.DeserializationExceptionHandler; +import org.apache.kafka.streams.errors.ProcessingExceptionHandler; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; + +import java.util.Map; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.STRICT_STUBS) +public class GlobalStateUpdateTaskTest { + + @Mock + private ProcessorTopology topology; + @Mock + private InternalProcessorContext processorContext; + @Mock + private GlobalStateManager stateMgr; + @Mock + private DeserializationExceptionHandler deserializationExceptionHandler; + @Mock + private ProcessingExceptionHandler processingExceptionHandler; + + @Test + public void shouldSkipTopologyAndProcessorInitWhenBootstrapInterrupted() { + when(stateMgr.initialize()).thenReturn(Optional.empty()); + + final GlobalStateUpdateTask task = new GlobalStateUpdateTask( + new LogContext("test"), + topology, + processorContext, + stateMgr, + deserializationExceptionHandler, + processingExceptionHandler, + new MockTime(), + 0L + ); + + final Map offsets = task.initialize(); + + verify(topology, never()).processors(); + verify(processorContext, never()).initialize(); + verify(stateMgr, never()).changelogOffsets(); + assertTrue(offsets.isEmpty()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java index 73c61a61dab88..bec27fccd321c 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java @@ -58,7 +58,6 @@ import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.RUNNING; @@ -69,7 +68,6 @@ import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -272,45 +270,6 @@ public void shouldShutdownDuringBootstrap() throws Exception { assertEquals(DEAD, globalStreamThread.state()); } - @Test - @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) - public void shouldNotInvokeUncaughtExceptionHandlerOnCloseAfterStart() throws Exception { - final AtomicReference caughtException = new AtomicReference<>(); - globalStreamThread.setUncaughtExceptionHandler(caughtException::set); - - initializeConsumer(); - startAndSwallowError(); - - TestUtils.waitForCondition( - () -> globalStreamThread.state() == RUNNING, - 10 * 1000, - "Thread never started."); - - mockConsumer.setMaxPollRecords(1L); - mockConsumer.updateEndOffsets(Collections.singletonMap(topicPartition, 50L)); - for (long offset = 0L; offset < 50L; offset++) { - mockConsumer.addRecord(record(GLOBAL_STORE_TOPIC_NAME, 0, offset, "k".getBytes(), "v".getBytes())); - } - - TestUtils.waitForCondition( - () -> mockConsumer.position(topicPartition) >= 1L, - 10 * 1000, - "First record never consumed by the main loop."); - - // Capture position before shutdown - // Else, afterwards the consumer is closed and calling position() throws IllegalStateException. - final long positionBeforeShutdown = mockConsumer.position(topicPartition); - - globalStreamThread.shutdown(); - globalStreamThread.join(); - - assertEquals(DEAD, globalStreamThread.state()); - assertNull(caughtException.get()); - assertTrue(positionBeforeShutdown < 50L, - "Shutdown should have interrupted the main loop before all records were consumed; position was " - + positionBeforeShutdown); - } - @Test public void shouldThrowStreamsExceptionOnStartupIfWakeupOccursWithoutShutdown() throws Exception { final MockConsumer wakeupOnPartitionsFor = new MockConsumer<>(AutoOffsetResetStrategy.NONE.name()) { diff --git a/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java index 1237deacf9b92..af7fe5f0459d9 100644 --- a/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java +++ b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java @@ -26,6 +26,7 @@ import java.io.File; import java.util.Map; +import java.util.Optional; import java.util.Set; public class GlobalStateManagerStub implements GlobalStateManager { @@ -49,9 +50,9 @@ public GlobalStateManagerStub(final Set storeNames, public void setGlobalProcessorContext(final InternalProcessorContext processorContext) {} @Override - public Set initialize() { + public Optional> initialize() { initialized = true; - return storeNames; + return Optional.of(storeNames); } @Override From 7bddbf158cfed38d3f2a36e68a6d2149528b2a04 Mon Sep 17 00:00:00 2001 From: lucliu1108 Date: Tue, 2 Jun 2026 17:05:07 -0500 Subject: [PATCH 5/6] remove check inside restoreState and reprocessState --- .../internals/GlobalStateManagerImpl.java | 13 ------- .../internals/GlobalStateManagerImplTest.java | 39 ------------------- 2 files changed, 52 deletions(-) diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java index 2abb8cb75b3e8..4b5ecd190226c 100644 --- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java @@ -365,11 +365,6 @@ private void reprocessState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - if (shouldStopBootstrappingSupplier.getAsBoolean()) { - logBootstrapInterrupted(storeMetadata); - return; - } - final ConsumerRecords records = globalConsumer.poll(pollMsPlusRequestTimeout); if (records.isEmpty()) { currentDeadline = maybeUpdateDeadlineOrThrow(currentDeadline); @@ -515,10 +510,6 @@ private void restoreState(final StateStoreMetadata storeMetadata) { // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call // `poll(pollMS)` without adding the request timeout and do a more precise // timeout handling - if (shouldStopBootstrappingSupplier.getAsBoolean()) { - logBootstrapInterrupted(storeMetadata); - return; - } final ConsumerRecords records = globalConsumer.poll(pollMsPlusRequestTimeout); if (records.isEmpty()) { currentDeadline = maybeUpdateDeadlineOrThrow(currentDeadline); @@ -544,10 +535,6 @@ private void restoreState(final StateStoreMetadata storeMetadata) { } } - private void logBootstrapInterrupted(final StateStoreMetadata storeMetadata) { - log.info("Bootstrap interrupted by shutdown for {}", storeMetadata.stateStore.name()); - } - private long getGlobalConsumerOffset(final TopicPartition topicPartition) { return retryUntilSuccessOrThrowOnTaskTimeout( () -> globalConsumer.position(topicPartition), diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java index f141ab2127032..3333f5ce444e4 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java @@ -1296,45 +1296,6 @@ public void shouldAbortRestoreWhenSupplierFlipsToShutdown() { assertEquals(0L, stateRestoreListener.totalNumRestored); } - @Test - public void shouldAbortRestoreWhenSupplierFlipsMidRestore() { - final AtomicBoolean inErrorState = new AtomicBoolean(false); - final AtomicInteger restoredCount = new AtomicInteger(0); - final MockStateRestoreListener flippingRestoreListener = new MockStateRestoreListener() { - @Override - public void onBatchRestored(final TopicPartition tp, - final String storeName, - final long batchEndOffset, - final long numRestored) { - super.onBatchRestored(tp, storeName, batchEndOffset, numRestored); - restoredCount.addAndGet((int) numRestored); - if (numRestored > 0) { - inErrorState.set(true); - } - } - }; - stateManager = new GlobalStateManagerImpl( - new LogContext("test"), - time, - topology, - consumer, - stateDirectory, - flippingRestoreListener, - streamsConfig, - inErrorState::get - ); - processorContext.setStateManger(stateManager); - stateManager.setGlobalProcessorContext(processorContext); - - initializeConsumer(6, 1, t1); - initializeConsumer(0, 0, t2, t3, t4, t5); - consumer.setMaxPollRecords(2L); - - stateManager.initialize(); - - assertEquals(2, restoredCount.get()); - } - @Test public void shouldExitCleanlyOnWakeupDuringBootstrapWhenShuttingDown() { final AtomicBoolean inErrorState = new AtomicBoolean(false); From 09935c5f56f2bcb6bafb6e5f45549ed10153f289 Mon Sep 17 00:00:00 2001 From: lucliu1108 Date: Wed, 3 Jun 2026 13:18:10 -0500 Subject: [PATCH 6/6] revise tests --- .../internals/GlobalStreamThreadTest.java | 37 +++++++++++-------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java index bec27fccd321c..95688374b2894 100644 --- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java @@ -57,6 +57,9 @@ import java.util.List; import java.util.Set; import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD; @@ -250,22 +253,26 @@ public void shouldShutdownDuringBootstrap() throws Exception { initializeConsumer(); mockConsumer.updateEndOffsets(Collections.singletonMap(topicPartition, 1_000_000L)); - final Thread shutdownThread = new Thread(() -> { - try { - TestUtils.waitForCondition( - () -> stateRestoreListener.storeNameCalledStates.containsKey(MockStateRestoreListener.RESTORE_START), - 10 * 1000L, - "Bootstrap restore never started."); - } catch (final Exception e) { - throw new RuntimeException(e); - } - globalStreamThread.shutdown(); - }); - shutdownThread.start(); + final ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + final Future shutdownFuture = executor.submit(() -> { + try { + TestUtils.waitForCondition( + () -> stateRestoreListener.storeNameCalledStates.containsKey(MockStateRestoreListener.RESTORE_START), + 10 * 1000L, + "Bootstrap restore never started."); + } catch (final Exception e) { + throw new RuntimeException(e); + } + globalStreamThread.shutdown(); + }); - startAndSwallowError(); - shutdownThread.join(); - globalStreamThread.join(5_000); + startAndSwallowError(); + shutdownFuture.get(); + globalStreamThread.join(); + } finally { + executor.shutdown(); + } assertEquals(DEAD, globalStreamThread.state()); }