diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 4d070da995b3..4d04c97a1eca 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -178,6 +178,9 @@ public final class StreamingDataflowWorker { // Experiment make the monitor within BoundedQueueExecutor fair public static final String BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT = "windmill_bounded_queue_executor_use_fair_monitor"; + // Don't use. Experiment guarding multi key bundles. The feature is work in progress and + // incomplete. + private static final String UNSTABLE_ENABLE_MULTI_KEY_BUNDLE = "unstable_enable_multi_key_bundle"; private final WindmillStateCache stateCache; private AtomicReference statusPages = new AtomicReference<>(); @@ -1017,6 +1020,8 @@ private static JobHeader createJobHeader(DataflowWorkerHarnessOptions options, l private static BoundedQueueExecutor createWorkUnitExecutor(DataflowWorkerHarnessOptions options) { boolean useFairMonitor = DataflowRunner.hasExperiment(options, BOUNDED_QUEUE_EXECUTOR_USE_FAIR_MONITOR_EXPERIMENT); + boolean useKeyGroupWorkQueue = + DataflowRunner.hasExperiment(options, UNSTABLE_ENABLE_MULTI_KEY_BUNDLE); return new BoundedQueueExecutor( chooseMaxThreads(options), THREAD_EXPIRATION_TIME_SEC, @@ -1024,7 +1029,8 @@ private static BoundedQueueExecutor createWorkUnitExecutor(DataflowWorkerHarness chooseMaxBundlesOutstanding(options), chooseMaxBytesOutstanding(options), new ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(), - useFairMonitor); + useFairMonitor, + useKeyGroupWorkQueue); } public static void main(String[] args) throws Exception { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java index ecaa673f5570..7748a554f0fc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ExecutableWork.java @@ -62,11 +62,11 @@ public void run(BoundedQueueExecutorWorkHandle handle) { } } - public final WorkId id() { + public WorkId id() { return work().id(); } - public final Windmill.WorkItem getWorkItem() { + public Windmill.WorkItem getWorkItem() { return work().getWorkItem(); } @@ -74,4 +74,12 @@ public final Windmill.WorkItem getWorkItem() { public String toString() { return "ExecutableWork{" + id() + "}"; } + + public String getComputationId() { + return work().getComputationId(); + } + + public Work.KeyGroup getKeyGroup() { + return work().getKeyGroup(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index cb01e1e508ce..53ed30fdedbb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -25,6 +25,7 @@ import java.util.IntSummaryStatistics; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -52,6 +53,7 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Duration; import org.joda.time.Instant; @@ -74,6 +76,7 @@ public final class Work implements RefreshableWork { private final Instant startTime; private final Map totalDurationPerState; private final WorkId id; + private final KeyGroup keyGroup; private final String latencyTrackingId; private final long serializedWorkItemSize; private volatile TimedState currentState; @@ -101,6 +104,10 @@ private Work( // keyUniverse inside EnumMap every time. this.totalDurationPerState = new EnumMap<>(EMPTY_ENUM_MAP); this.id = WorkId.of(workItem); + this.keyGroup = + workItem.hasKeyGroup() + ? KeyGroup.create(workItem.getKeyGroup().getHigh(), workItem.getKeyGroup().getLow()) + : KeyGroup.DEFAULT; this.latencyTrackingId = Long.toHexString(workItem.getShardingKey()) + '-' @@ -383,6 +390,14 @@ private boolean isCommitPending() { abstract Instant startTime(); } + public String getComputationId() { + return processingContext.computationId(); + } + + public KeyGroup getKeyGroup() { + return keyGroup; + } + @AutoValue public abstract static class ProcessingContext { @@ -416,4 +431,60 @@ private Optional fetchKeyedState(KeyedGetDataRequest reque return Optional.ofNullable(getDataClient().getStateData(computationId(), request)); } } + + /** + * WorkItems with same key group and computation are eligible to be executed together in a + * multi-key bundle. + */ + public static final class KeyGroup { + + // Work items equaling to the default keyGroup will always be executed + // separately and not in a multi-key bundle + public static final KeyGroup DEFAULT = new KeyGroup(0, 0); + + private final long high; + private final long low; + + private KeyGroup(long high, long low) { + this.high = high; + this.low = low; + } + + public static KeyGroup create(long high, long low) { + if (high == 0 && low == 0) { + return DEFAULT; + } + return new KeyGroup(high, low); + } + + public long high() { + return high; + } + + public long low() { + return low; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof KeyGroup)) { + return false; + } + KeyGroup other = (KeyGroup) o; + return high == other.high && low == other.low; + } + + @Override + public int hashCode() { + return Objects.hash(high, low); + } + + @Override + public String toString() { + return String.format("%016x%016x", high, low); + } + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java index c6fd96e0a4cb..1445492968fc 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutor.java @@ -29,10 +29,13 @@ import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.dataflow.worker.streaming.BoundedQueueExecutorWorkHandle; import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.Monitor.Guard; +import org.checkerframework.checker.nullness.qual.Nullable; /** An executor for executing work on windmill items. */ @SuppressWarnings({ @@ -78,6 +81,9 @@ private static class Budget { @GuardedBy("this") private long totalTimeMaxActiveThreadsUsed; + // If set the keyGroupWorkQueue is used by the underlying executor. + private final @Nullable KeyGroupWorkQueue keyGroupWorkQueue; + public BoundedQueueExecutor( int initialMaximumPoolSize, long keepAliveTime, @@ -85,7 +91,9 @@ public BoundedQueueExecutor( int maximumElementsOutstanding, long maximumBytesOutstanding, ThreadFactory threadFactory, - boolean useFairMonitor) { + boolean useFairMonitor, + boolean useKeyGroupWorkQueue) { + this.keyGroupWorkQueue = useKeyGroupWorkQueue ? new KeyGroupWorkQueue(useFairMonitor) : null; this.maximumPoolSize = initialMaximumPoolSize; monitor = new Monitor(useFairMonitor); executor = @@ -94,7 +102,7 @@ public BoundedQueueExecutor( initialMaximumPoolSize, keepAliveTime, unit, - new LinkedBlockingQueue<>(), + keyGroupWorkQueue != null ? keyGroupWorkQueue : new LinkedBlockingQueue<>(), threadFactory) { @Override protected void beforeExecute(Thread t, Runnable r) { @@ -313,7 +321,7 @@ public synchronized void close() { } } - private static final class QueuedWork implements Runnable { + static final class QueuedWork implements Runnable { private final ExecutableWork work; private final BoundedQueueExecutorWorkHandleImpl handle; @@ -378,6 +386,22 @@ BoundedQueueExecutorWorkHandleImpl createBudgetHandle(int elements, long bytes) return new BoundedQueueExecutorWorkHandleImpl(elements, bytes); } + public @Nullable ExecutableWork pollWork( + String computationId, Work.KeyGroup keyGroup, BoundedQueueExecutorWorkHandle handle) { + checkArgument(handle instanceof BoundedQueueExecutorWorkHandleImpl); + checkArgument(computationId != null && keyGroup != null && !keyGroup.equals(KeyGroup.DEFAULT)); + BoundedQueueExecutorWorkHandleImpl internalHandle = (BoundedQueueExecutorWorkHandleImpl) handle; + if (keyGroupWorkQueue == null) { + return null; + } + QueuedWork queuedWork = keyGroupWorkQueue.pollWork(computationId, keyGroup); + if (queuedWork == null) { + return null; + } + internalHandle.merge(queuedWork.getHandle()); + return queuedWork.getWork(); + } + private void decrementCounters(int elements, long bytes) { // All threads queue decrements and one thread grabs the monitor and updates // counters. We do this to reduce contention on monitor which is locked by diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java new file mode 100644 index 000000000000..823178d96584 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueue.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.util; + +import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; + +import java.util.AbstractQueue; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import javax.annotation.concurrent.GuardedBy; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.streaming.Work.KeyGroup; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.jspecify.annotations.NonNull; + +/** + * A custom, thread-safe doubly-linked BlockingQueue. In addition to global FIFO ordering, the queue + * supports polling work by computation + key group in FIFO order + */ +class KeyGroupWorkQueue extends AbstractQueue implements BlockingQueue { + + public static final Runnable SENTINEL_RUNNABLE = + () -> { + throw new IllegalStateException("sentinel runnable called"); + }; + + static class Node { + // If keyGroup is non-null, task is an instance of QueuedWork + final Runnable task; + final @Nullable String computationId; + final Work.@Nullable KeyGroup keyGroup; + // cached keyGroupList if the Node is part of one. + @Nullable KeyGroupWorkList keyGroupList; + + // prevNode, nextNode are used for the global order across all queued Runnables + @Nullable Node prevNode; + @Nullable Node nextNode; + + // prevKeyGroupNode and nextKeyGroupNode are used for the keyGroup level lists linking + // QueuedWork with same keyGroup + @Nullable Node prevKeyGroupNode; + @Nullable Node nextKeyGroupNode; + + Node(Runnable task) { + this.task = task; + if (task instanceof QueuedWork) { + this.computationId = ((QueuedWork) task).getWork().getComputationId(); + this.keyGroup = ((QueuedWork) task).getWork().getKeyGroup(); + } else { + this.computationId = null; + this.keyGroup = null; + } + } + } + + /** Double linked list implementing key group level queue */ + private static class KeyGroupWorkList { + final Node head = new Node(SENTINEL_RUNNABLE); + final Node tail = new Node(SENTINEL_RUNNABLE); + + KeyGroupWorkList() { + head.nextKeyGroupNode = tail; + tail.prevKeyGroupNode = head; + } + + boolean isEmpty() { + return head.nextKeyGroupNode == tail; + } + + void append(Node node) { + Node last = checkStateNotNull(tail.prevKeyGroupNode); + node.prevKeyGroupNode = last; + node.nextKeyGroupNode = tail; + last.nextKeyGroupNode = node; + tail.prevKeyGroupNode = node; + } + + void remove(Node node) { + @Nullable Node prev = node.prevKeyGroupNode; + @Nullable Node next = node.nextKeyGroupNode; + if (prev != null && next != null) { + prev.nextKeyGroupNode = next; + next.prevKeyGroupNode = prev; + node.prevKeyGroupNode = null; + node.nextKeyGroupNode = null; + } + } + } + + private final ReentrantLock lock; + private final Condition notEmpty; + + // Sentinels for the global list + @GuardedBy("lock") + private final Node globalHead = new Node(SENTINEL_RUNNABLE); + + @GuardedBy("lock") + private final Node globalTail = new Node(SENTINEL_RUNNABLE); + + @GuardedBy("lock") + private final Map keyGroupQueueMap = new HashMap<>(); + + @GuardedBy("lock") + private int size = 0; + + public KeyGroupWorkQueue(boolean fair) { + this.lock = new ReentrantLock(fair); + this.notEmpty = lock.newCondition(); + globalHead.nextNode = globalTail; + globalTail.prevNode = globalHead; + } + + @GuardedBy("lock") + private void unlinkNode(Node node) { + // An existing node should always have previous and next since we have sentinels + // 1. Unlink from global list + Node prevG = checkArgumentNotNull(node.prevNode); + Node nextG = checkArgumentNotNull(node.nextNode); + prevG.nextNode = nextG; + nextG.prevNode = prevG; + node.prevNode = null; + node.nextNode = null; + + // 2. Unlink from key group list + KeyGroupWorkList list = node.keyGroupList; + if (list != null) { + list.remove(node); + if (list.isEmpty()) { + String compId = checkStateNotNull(node.computationId); + Work.KeyGroup keyGroup = checkStateNotNull(node.keyGroup); + QueueKey key = new QueueKey(compId, keyGroup); + keyGroupQueueMap.remove(key); + } + node.keyGroupList = null; + } + --size; + } + + @GuardedBy("lock") + private @Nullable Node removeFirstGlobal() { + Node first = checkStateNotNull(globalHead.nextNode); + if (first == globalTail) { + return null; + } + unlinkNode(first); + return first; + } + + /** + * Remove and Return QueuedWork for the computationId, keyGroup in the FIFO order. Returns null, + * if there are no matches. + * + * @param keyGroup should not be equal to KeyGroup.DEFAULT + */ + public @Nullable QueuedWork pollWork(String computationId, Work.KeyGroup keyGroup) { + checkArgument(computationId != null && keyGroup != null && !keyGroup.equals(KeyGroup.DEFAULT)); + QueueKey key = new QueueKey(computationId, keyGroup); + lock.lock(); + try { + KeyGroupWorkList keyGroupWorkList = keyGroupQueueMap.get(key); + if (keyGroupWorkList == null || keyGroupWorkList.isEmpty()) { + return null; + } + + // Retrieve the first pending task for this computation and keyGroup in O(1) + Node firstNode = checkStateNotNull(keyGroupWorkList.head.nextKeyGroupNode); + if (firstNode == keyGroupWorkList.tail) { + return null; + } + unlinkNode(firstNode); + + return (QueuedWork) firstNode.task; + } finally { + lock.unlock(); + } + } + + @Override + public boolean offer(@NonNull Runnable runnable) { + Node node = new Node(checkStateNotNull(runnable)); + lock.lock(); + try { + // Append to global list tail + Node lastG = checkStateNotNull(globalTail.prevNode); + node.prevNode = lastG; + node.nextNode = globalTail; + lastG.nextNode = node; + globalTail.prevNode = node; + + // Append to key group list if applicable + String compId = node.computationId; + Work.KeyGroup keyGroup = node.keyGroup; + if (compId != null && keyGroup != null && !keyGroup.equals(KeyGroup.DEFAULT)) { + QueueKey key = new QueueKey(compId, keyGroup); + KeyGroupWorkList keyGroupWorkList = + keyGroupQueueMap.computeIfAbsent(key, k -> new KeyGroupWorkList()); + keyGroupWorkList.append(node); + node.keyGroupList = keyGroupWorkList; + } + + ++size; + notEmpty.signal(); + return true; + } finally { + lock.unlock(); + } + } + + @Override + public void put(Runnable e) throws InterruptedException { + offer(e); // Unbounded queue + } + + @Override + public boolean offer(Runnable e, long timeout, TimeUnit unit) throws InterruptedException { + return offer(e); // Unbounded queue + } + + @Override + public @Nullable Runnable poll() { + lock.lock(); + try { + @Nullable Node node = removeFirstGlobal(); + return (node != null) ? node.task : null; + } finally { + lock.unlock(); + } + } + + @Override + public Runnable take() throws InterruptedException { + lock.lockInterruptibly(); + try { + while (size == 0) { + notEmpty.await(); + } + @Nullable Node node = removeFirstGlobal(); + checkStateNotNull(node, "Queue is empty but size was " + size); + return node.task; + } finally { + lock.unlock(); + } + } + + @Override + public @Nullable Runnable poll(long timeout, TimeUnit unit) throws InterruptedException { + long nanos = unit.toNanos(timeout); + lock.lockInterruptibly(); + try { + while (size == 0) { + if (nanos <= 0) { + return null; + } + nanos = notEmpty.awaitNanos(nanos); + } + @Nullable Node node = removeFirstGlobal(); + return (node != null) ? node.task : null; + } finally { + lock.unlock(); + } + } + + @Override + public @Nullable Runnable peek() { + lock.lock(); + try { + Node first = checkStateNotNull(globalHead.nextNode); + if (first == globalTail) { + return null; + } + return first.task; + } finally { + lock.unlock(); + } + } + + @Override + public int size() { + lock.lock(); + try { + return size; + } finally { + lock.unlock(); + } + } + + @Override + public boolean isEmpty() { + lock.lock(); + try { + return size == 0; + } finally { + lock.unlock(); + } + } + + @Override + public boolean remove(Object o) { + if (o == null) return false; + lock.lock(); + try { + // Walk the global queue in O(N) to find and unlink the node + Node curr = checkStateNotNull(globalHead.nextNode); + while (curr != globalTail) { + if (o.equals(curr.task)) { + unlinkNode(curr); + return true; + } + curr = checkStateNotNull(curr.nextNode); + } + return false; + } finally { + lock.unlock(); + } + } + + @Override + public boolean contains(Object o) { + if (o == null) return false; + lock.lock(); + try { + Node curr = checkStateNotNull(globalHead.nextNode); + while (curr != globalTail) { + if (o.equals(curr.task)) { + return true; + } + curr = checkStateNotNull(curr.nextNode); + } + return false; + } finally { + lock.unlock(); + } + } + + @Override + public int remainingCapacity() { + return Integer.MAX_VALUE; + } + + @Override + public int drainTo(Collection c) { + return drainTo(c, Integer.MAX_VALUE); + } + + @Override + public int drainTo(Collection c, int maxElements) { + if (c == null) throw new NullPointerException(); + if (c == this) throw new IllegalArgumentException(); + if (maxElements <= 0) return 0; + lock.lock(); + try { + int added = 0; + Node curr = checkStateNotNull(globalHead.nextNode); + while (curr != globalTail && added < maxElements) { + Node next = checkStateNotNull(curr.nextNode); + unlinkNode(curr); + Runnable task = curr.task; + c.add(task); + ++added; + curr = next; + } + return added; + } finally { + lock.unlock(); + } + } + + @Override + public void clear() { + lock.lock(); + try { + Node curr = checkStateNotNull(globalHead.nextNode); + while (curr != globalTail) { + Node next = checkStateNotNull(curr.nextNode); + unlinkNode(curr); + curr = next; + } + } finally { + lock.unlock(); + } + } + + @Override + public Iterator iterator() { + lock.lock(); + try { + ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(size); + Node curr = checkStateNotNull(globalHead.nextNode); + while (curr != globalTail) { + builder.add(curr.task); + curr = checkStateNotNull(curr.nextNode); + } + return builder.build().iterator(); + } finally { + lock.unlock(); + } + } + + static final class QueueKey { + private final String computationId; + private final Work.KeyGroup keyGroup; + + QueueKey(String computationId, Work.KeyGroup keyGroup) { + this.computationId = computationId; + this.keyGroup = keyGroup; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (!(o instanceof QueueKey)) { + return false; + } + QueueKey other = (QueueKey) o; + return computationId.equals(other.computationId) && keyGroup.equals(other.keyGroup); + } + + @Override + public int hashCode() { + return Objects.hash(computationId, keyGroup); + } + + @Override + public String toString() { + return "QueueKey{" + + "computationId='" + + computationId + + '\'' + + ", keyGroup=" + + keyGroup + + '}'; + } + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index d58f20076994..5bcdffcc2564 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -3036,7 +3036,8 @@ public void testMaxThreadMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( @@ -3097,7 +3098,8 @@ public void testActiveThreadMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( @@ -3167,7 +3169,8 @@ public void testOutstandingBytesMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( @@ -3241,7 +3244,8 @@ public void testOutstandingBundlesMetric() throws Exception { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); ComputationState computationState = new ComputationState( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index 55fe82c7163c..a98102751fb2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -20,6 +20,8 @@ import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; @@ -33,6 +35,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.BoundedQueueExecutorWorkHandleImpl; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -66,13 +69,30 @@ public static Collection useFairMonitor() { @Rule public transient Timeout globalTimeout = Timeout.seconds(300); private BoundedQueueExecutor executor; + private static final Work.KeyGroup DEFAULT_KEY_GROUP = Work.KeyGroup.create(1, 2); + private static ExecutableWork createWork(Consumer executeWorkFn) { + return createWorkWithCompId("computationId", executeWorkFn); + } + + private static ExecutableWork createWorkWithCompId( + String computationId, Consumer executeWorkFn) { + return createWorkWithCompIdAndKeyGroup(computationId, DEFAULT_KEY_GROUP, executeWorkFn); + } + + private static ExecutableWork createWorkWithCompIdAndKeyGroup( + String computationId, Work.KeyGroup keyGroup, Consumer executeWorkFn) { WorkItem workItem = WorkItem.newBuilder() .setKey(ByteString.EMPTY) .setShardingKey(1) .setWorkToken(33) .setCacheToken(1) + .setKeyGroup( + Windmill.Uint128Proto.newBuilder() + .setHigh(keyGroup.high()) + .setLow(keyGroup.low()) + .build()) .build(); return ExecutableWork.create( Work.create( @@ -80,10 +100,7 @@ private static ExecutableWork createWork(Consumer executeWorkFn) { workItem.getSerializedSize(), Watermarks.builder().setInputDataWatermark(Instant.now()).build(), Work.createProcessingContext( - "computationId", - new FakeGetDataClient(), - ignored -> {}, - mock(HeartbeatSender.class)), + computationId, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), false, Instant::now), (work, handle) -> { @@ -116,7 +133,8 @@ public void setUp() { .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - useFairMonitor); + useFairMonitor, + /*useKeyGroupWorkQueue=*/ false); } @Test @@ -413,4 +431,126 @@ public void testRenderSummaryHtml() { + "Work Queue Bytes: 0/10000000
/n"; assertEquals(expectedSummaryHtml, executor.summaryHtml()); } + + @Test + public void testPollWork() throws Exception { + // Create separate BoundedQueueExecutor with 1 thread so we can block it easily + BoundedQueueExecutor testExecutor = + new BoundedQueueExecutor( + 1, + 60, + TimeUnit.SECONDS, + 100, + 10000000, + new ThreadFactoryBuilder().setNameFormat("testStealing-%d").setDaemon(true).build(), + useFairMonitor, + /*useKeyGroupWorkQueue=*/ true); + + // 1. Create blocker task to occupy the worker thread + CountDownLatch blockerStart = new CountDownLatch(1); + CountDownLatch blockerStop = new CountDownLatch(1); + ExecutableWork blockerWork = + createWorkWithCompIdAndKeyGroup( + "blockerComp", + DEFAULT_KEY_GROUP, + ignored -> { + blockerStart.countDown(); + try { + blockerStop.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + testExecutor.execute(blockerWork, 0); + blockerStart.await(); + + // 2. Create two distinct key groups + Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1); + Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2); + + // Create executable tasks + CountDownLatch targetStart = new CountDownLatch(1); + ExecutableWork work1 = createWorkWithCompIdAndKeyGroup("compA", keyGroup1, ignored -> {}); + ExecutableWork work2 = + createWorkWithCompIdAndKeyGroup( + "compA", + keyGroup2, + ignored -> { + targetStart.countDown(); + }); + + // Enqueue tasks (they will wait in the queue because the thread is blocked) + testExecutor.execute(work1, 100); + testExecutor.execute(work2, 150); + + // Total outstanding elements must be 3 (blocker + work1 + work2) + assertEquals(3, testExecutor.elementsOutstanding()); + + // Steal work2 using pollWork with compA and keyGroup2 + try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup2, stealHandle); + assertNotNull(stolen); + assertEquals(work2, stolen); + + // Run the stolen task + stolen.run(stealHandle); + targetStart.await(); + } + + // Steal work1 using pollWork with compA and keyGroup1 + try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup1, stealHandle); + assertNotNull(stolen); + assertEquals(work1, stolen); + } + + // Unblock the blocker and shut down + blockerStop.countDown(); + testExecutor.shutdown(); + } + + @Test + public void testPollWorkWithLinkedBlockingQueue() throws Exception { + BoundedQueueExecutor testExecutor = + new BoundedQueueExecutor( + 1, + 60, + TimeUnit.SECONDS, + 100, + 10000000, + new ThreadFactoryBuilder().setNameFormat("testLinkedQueue-%d").setDaemon(true).build(), + useFairMonitor, + /* useKeyGroupWorkQueue= */ false); + + CountDownLatch blockerStart = new CountDownLatch(1); + CountDownLatch blockerStop = new CountDownLatch(1); + ExecutableWork blockerWork = + createWorkWithCompIdAndKeyGroup( + "blockerComp", + DEFAULT_KEY_GROUP, + ignored -> { + blockerStart.countDown(); + try { + blockerStop.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + testExecutor.execute(blockerWork, 0); + blockerStart.await(); + + Work.KeyGroup keyGroup = Work.KeyGroup.create(1, 1); + ExecutableWork work = createWorkWithCompIdAndKeyGroup("compA", keyGroup, ignored -> {}); + testExecutor.execute(work, 100); + + try (BoundedQueueExecutorWorkHandleImpl stealHandle = testExecutor.createBudgetHandle(0, 0L)) { + ExecutableWork stolen = testExecutor.pollWork("compA", keyGroup, stealHandle); + assertNull(stolen); + } + + blockerStop.countDown(); + testExecutor.shutdown(); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java new file mode 100644 index 000000000000..42079361391c --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/KeyGroupWorkQueueTest.java @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.runners.dataflow.worker.streaming.ExecutableWork; +import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; +import org.apache.beam.runners.dataflow.worker.streaming.Work; +import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor.QueuedWork; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Parameterized.class) +public class KeyGroupWorkQueueTest { + + @Parameters(name = "fairQueue={0}") + public static Iterable data() { + return Arrays.asList(new Object[][] {{true}, {false}}); + } + + @Parameterized.Parameter public boolean fairQueue; + + private BoundedQueueExecutor executor; + + @Before + public void setUp() { + executor = + new BoundedQueueExecutor( + 2, + 60, + TimeUnit.SECONDS, + 100, + 10000000, + new ThreadFactoryBuilder().setNameFormat("Test-%d").setDaemon(true).build(), + fairQueue, + /*useKeyGroupWorkQueue=*/ true); + } + + private static final Work.KeyGroup TEST_KEY_GROUP = Work.KeyGroup.create(1, 2); + + private QueuedWork createQueuedWork(String computationId, long workBytes) { + return createQueuedWork(computationId, TEST_KEY_GROUP, workBytes); + } + + private QueuedWork createQueuedWork( + String computationId, Work.@Nullable KeyGroup keyGroup, long workBytes) { + WorkItem.Builder workItemBuilder = + WorkItem.newBuilder() + .setKey(ByteString.EMPTY) + .setShardingKey(1) + .setWorkToken(33) + .setCacheToken(1); + if (keyGroup != null) { + workItemBuilder.setKeyGroup( + org.apache.beam.runners.dataflow.worker.windmill.Windmill.Uint128Proto.newBuilder() + .setHigh(keyGroup.high()) + .setLow(keyGroup.low()) + .build()); + } + WorkItem workItem = workItemBuilder.build(); + ExecutableWork work = + ExecutableWork.create( + Work.create( + workItem, + workItem.getSerializedSize(), + Watermarks.builder().setInputDataWatermark(Instant.now()).build(), + Work.createProcessingContext( + computationId, + new FakeGetDataClient(), + ignored -> {}, + mock(HeartbeatSender.class)), + false, + Instant::now), + (w, h) -> {}); + return new QueuedWork(work, executor.createBudgetHandle(1, workBytes)); + } + + private static class MockRunnable implements Runnable { + final String id; + + MockRunnable(String id) { + this.id = id; + } + + @Override + public void run() {} + + @Override + public String toString() { + return "MockRunnable(" + id + ")"; + } + } + + @Test + public void testBasicOfferAndPoll() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + assertTrue(queue.isEmpty()); + assertEquals(0, queue.size()); + + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + + assertTrue(queue.offer(task1)); + assertTrue(queue.offer(task2)); + assertEquals(2, queue.size()); + + assertEquals(task1, queue.poll()); + assertEquals(task2, queue.poll()); + assertNull(queue.poll()); + assertTrue(queue.isEmpty()); + } + + @Test + public void testRemove() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + + queue.offer(task1); + queue.offer(task2); + + assertTrue(queue.remove(task1)); + assertEquals(1, queue.size()); + assertEquals(task2, queue.poll()); + assertFalse(queue.remove(task1)); // Already gone + } + + @Test + public void testDrainTo() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + queue.offer(task1); + queue.offer(task2); + + List drained = new ArrayList<>(); + assertEquals(2, queue.drainTo(drained)); + assertEquals(2, drained.size()); + assertEquals(task1, drained.get(0)); + assertEquals(task2, drained.get(1)); + assertTrue(queue.isEmpty()); + } + + @Test + public void testIteratorSafeTraversalAndImmutable() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + MockRunnable task1 = new MockRunnable("1"); + MockRunnable task2 = new MockRunnable("2"); + queue.offer(task1); + queue.offer(task2); + + Iterator it = queue.iterator(); + assertTrue(it.hasNext()); + assertEquals(task1, it.next()); + assertTrue(it.hasNext()); + assertEquals(task2, it.next()); + assertFalse(it.hasNext()); + + // Assert that mutating the iterator throws UnsupportedOperationException + it = queue.iterator(); + assertTrue(it.hasNext()); + it.next(); + try { + it.remove(); + fail("Iterator must be immutable"); + } catch (UnsupportedOperationException e) { + // Expected + } + } + + @Test + public void testPollWorkTargeted() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + + QueuedWork workA1 = createQueuedWork("compA", 100); + QueuedWork workB1 = createQueuedWork("compB", 200); + QueuedWork workA2 = createQueuedWork("compA", 150); + + queue.offer(workA1); + queue.offer(workB1); + queue.offer(workA2); + + assertEquals(3, queue.size()); + + // Targeted poll A + QueuedWork polledA1 = queue.pollWork("compA", TEST_KEY_GROUP); + assertNotNull(polledA1); + assertEquals("compA", polledA1.getWork().getComputationId()); + assertEquals(100, polledA1.getHandle().bytes()); + + // Verify size decremented + assertEquals(2, queue.size()); + + // Poll next should be B1 (since A1 was stolen, B1 is now first global) + assertEquals(workB1, queue.poll()); + assertEquals(1, queue.size()); + + // Last should be A2 + assertEquals(workA2, queue.poll()); + assertTrue(queue.isEmpty()); + } + + @Test + public void testMemoryPruningLeavesZeroLeaks() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + QueuedWork workA1 = createQueuedWork("compA", 100); + queue.offer(workA1); + + // Steal A1 + QueuedWork polled = queue.pollWork("compA", TEST_KEY_GROUP); + assertNotNull(polled); + assertTrue(queue.isEmpty()); + + // Offering another work with different computation ID + QueuedWork workB1 = createQueuedWork("compB", 200); + queue.offer(workB1); + assertEquals(1, queue.size()); + + // Steal B1 + QueuedWork polledB = queue.pollWork("compB", TEST_KEY_GROUP); + assertNotNull(polledB); + assertTrue(queue.isEmpty()); + } + + @Test + public void testConcurrentStress() throws InterruptedException, ExecutionException { + final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + final int producerThreads = 4; + final int consumerThreads = 4; + final int tasksPerProducer = 1000; + final int totalTasks = producerThreads * tasksPerProducer; + + ExecutorService executorService = + Executors.newFixedThreadPool(producerThreads + consumerThreads); + final CountDownLatch startLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(producerThreads + consumerThreads); + final AtomicInteger consumedCount = new AtomicInteger(0); + List> futures = new ArrayList<>(); + + // Start producers + for (int i = 0; i < producerThreads; i++) { + futures.add( + executorService.submit( + () -> { + try { + startLatch.await(); + for (int j = 0; j < tasksPerProducer; j++) { + String compId = "comp-" + (j % 5); + queue.offer(createQueuedWork(compId, 10)); + } + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + doneLatch.countDown(); + } + })); + } + + // Start consumers (mix of poll and pollWork) + for (int i = 0; i < consumerThreads; i++) { + final int consumerId = i; + futures.add( + executorService.submit( + () -> { + try { + startLatch.await(); + while (consumedCount.get() < totalTasks) { + Runnable task; + if (consumerId % 2 == 0) { + // Targeted poll + String compId = "comp-" + (consumedCount.get() % 5); + task = queue.pollWork(compId, TEST_KEY_GROUP); + } else { + // Global poll + task = queue.poll(); + } + if (task != null) { + consumedCount.incrementAndGet(); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + doneLatch.countDown(); + } + })); + } + + startLatch.countDown(); + assertTrue(doneLatch.await(10, TimeUnit.SECONDS)); + + // Check for exceptions in threads + for (Future future : futures) { + future.get(); + } + + executorService.shutdown(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); + + assertEquals(0, queue.size()); + assertTrue(queue.isEmpty()); + } + + @Test + public void testTakeBlocksAndWakesUp() throws InterruptedException { + final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + final MockRunnable task = new MockRunnable("take-task"); + final AtomicReference result = new AtomicReference<>(); + final CountDownLatch started = new CountDownLatch(1); + final CountDownLatch finished = new CountDownLatch(1); + + Thread t = + new Thread( + () -> { + started.countDown(); + try { + result.set(queue.take()); + } catch (InterruptedException e) { + // Ignore + } finally { + finished.countDown(); + } + }); + t.setDaemon(true); + t.start(); + + assertTrue(started.await(2, TimeUnit.SECONDS)); + // Give thread a moment to enter await() + Thread.sleep(100); + assertEquals(Thread.State.WAITING, t.getState()); + + queue.offer(task); + + assertTrue(finished.await(2, TimeUnit.SECONDS)); + assertEquals(task, result.get()); + } + + @Test + public void testPollWithTimeout() throws InterruptedException { + final KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + final MockRunnable task = new MockRunnable("poll-task"); + final AtomicReference result = new AtomicReference<>(); + final CountDownLatch started = new CountDownLatch(1); + final CountDownLatch finished = new CountDownLatch(1); + + // 1. Verify timeout returns null + Thread t1 = + new Thread( + () -> { + started.countDown(); + try { + result.set(queue.poll(500, TimeUnit.MILLISECONDS)); + } catch (InterruptedException e) { + // Ignore + } finally { + finished.countDown(); + } + }); + t1.setDaemon(true); + t1.start(); + + assertTrue(started.await(2, TimeUnit.SECONDS)); + Thread.sleep(100); + assertEquals(Thread.State.TIMED_WAITING, t1.getState()); + + assertTrue(finished.await(2, TimeUnit.SECONDS)); + assertNull(result.get()); + + // 2. Verify timed poll receives task offered concurrently + final CountDownLatch started2 = new CountDownLatch(1); + final CountDownLatch finished2 = new CountDownLatch(1); + final AtomicReference result2 = new AtomicReference<>(); + + Thread t2 = + new Thread( + () -> { + started2.countDown(); + try { + result2.set(queue.poll(2, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + // Ignore + } finally { + finished2.countDown(); + } + }); + t2.setDaemon(true); + t2.start(); + + assertTrue(started2.await(2, TimeUnit.SECONDS)); + Thread.sleep(100); + assertEquals(Thread.State.TIMED_WAITING, t2.getState()); + + queue.offer(task); + + assertTrue(finished2.await(2, TimeUnit.SECONDS)); + assertEquals(task, result2.get()); + } + + @Test + public void testPollWorkWithKeyGroup() { + KeyGroupWorkQueue queue = new KeyGroupWorkQueue(fairQueue); + + Work.KeyGroup keyGroup1 = Work.KeyGroup.create(1, 1); + Work.KeyGroup keyGroup2 = Work.KeyGroup.create(1, 2); + + QueuedWork workA1 = createQueuedWork("compA", keyGroup1, 100); + QueuedWork workA2 = createQueuedWork("compA", keyGroup2, 150); + + queue.offer(workA1); + queue.offer(workA2); + + assertEquals(2, queue.size()); + + // Poll with keyGroup2 first - should return workA2 + QueuedWork polledA2 = queue.pollWork("compA", keyGroup2); + assertNotNull(polledA2); + assertEquals(workA2, polledA2); + assertEquals(1, queue.size()); + + // Poll with keyGroup2 again - should return null + assertNull(queue.pollWork("compA", keyGroup2)); + + // Poll with keyGroup1 - should return workA1 + QueuedWork polledA1 = queue.pollWork("compA", keyGroup1); + assertNotNull(polledA1); + assertEquals(workA1, polledA1); + assertTrue(queue.isEmpty()); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java index 07b4b14fd115..ef0d8e434858 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/StreamingCommitFinalizerTest.java @@ -62,7 +62,8 @@ public void setUp() { .setNameFormat("FinalizationCallback-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); cleanupExecutor = Executors.newScheduledThreadPool( diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java index 51bd4816b031..0610ed44c27f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java @@ -64,7 +64,8 @@ private static WorkFailureProcessor createWorkFailureProcessor( .setNameFormat("DataflowWorkUnits-%d") .setDaemon(true) .build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); return WorkFailureProcessor.forTesting(workExecutor, failureTracker, Optional::empty, clock, 0); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java index 88a82c6f76b6..f32282056e4f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java @@ -80,7 +80,8 @@ private static BoundedQueueExecutor workExecutor() { 1, 10000000, new ThreadFactoryBuilder().setNameFormat("DataflowWorkUnits-%d").setDaemon(true).build(), - /*useFairMonitor=*/ false); + /*useFairMonitor=*/ false, + /*useKeyGroupWorkQueue=*/ false); } private static ComputationState createComputationState(int computationIdSuffix) { diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index 1da7ef9be8bb..aaa09c105fc3 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -421,6 +421,11 @@ message WatermarkHold { optional string state_family = 4; } +message Uint128Proto { + required fixed64 high = 1; + required fixed64 low = 2; +} + // Proto describing a hot key detected on a given WorkItem. message HotKeyInfo { // The age of the hot key measured from when it was first detected. @@ -448,6 +453,8 @@ message WorkItem { // present, this field includes metadata associated with any hot key. optional HotKeyInfo hot_key_info = 11; + optional Uint128Proto key_group = 18; + reserved 12, 13, 14, 15, 16; }