Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.pekko.stream

import java.util.concurrent.{ CountDownLatch, TimeUnit }

import scala.concurrent.Await
import scala.concurrent.duration._

import org.apache.pekko.actor.ActorSystem
import org.apache.pekko.remote.artery.{ BenchTestSource, LatchSink }
import org.apache.pekko.stream.scaladsl._

import com.typesafe.config.ConfigFactory

/**
 * Standalone benchmark runner for BroadcastHub consumer wheel performance.
 * Run with: sbt "bench-jmh/runMain org.apache.pekko.stream.BroadcastHubBenchRunner"
 *
 * For every (buffer size, consumer count) combination a dedicated ActorSystem is started,
 * `WarmupRuns + MeasureRuns` broadcasts of `Elements` elements are executed, and the
 * throughput statistics of the measured (non-warmup) runs are printed as one table row.
 */
object BroadcastHubBenchRunner {

  // Kept as `final val` without type ascription (constant value definitions).
  final val Elements = 100000
  final val SmallBuffer = 64
  final val LargeBuffer = 256
  final val WarmupRuns = 2
  final val MeasureRuns = 3

  /** Maximum time a single run may take before the whole benchmark is aborted. */
  private val RunTimeout = 120.seconds

  /** Maximum time to wait for an ActorSystem to terminate. */
  private val TerminationTimeout = 10.seconds

  def main(args: Array[String]): Unit = {
    // The literal below is reproduced verbatim, hence its content starts at column 0.
    val config = ConfigFactory.parseString("""
pekko.actor.default-dispatcher {
executor = "fork-join-executor"
fork-join-executor {
parallelism-factor = 1
}
}
""")

    val consumerCounts = Array(64, 256, 1000, 2000)

    println("=" * 80)
    println("BroadcastHub Consumer Wheel Benchmark")
    println(s"Elements per run: $Elements")
    println(s"Warmup: $WarmupRuns runs, Measure: $MeasureRuns runs")
    println("=" * 80)

    for (bufferSize <- Array(SmallBuffer, LargeBuffer)) {
      println(s"\n--- Buffer size: $bufferSize (wheel slots: ${bufferSize * 2}) ---")
      println(f"${"Consumers"}%-12s ${"Avg (elem/s)"}%16s ${"Min"}%12s ${"Max"}%12s ${"StdDev"}%10s")
      println("-" * 70)

      for (consumerCount <- consumerCounts) {
        implicit val system: ActorSystem = ActorSystem(s"bench-$consumerCount-$bufferSize", config)
        // try/finally so the system is shut down even when a run throws; otherwise its
        // threads would outlive the failure and the remaining output would be misleading.
        try {
          // eager init
          SystemMaterializer(system).materializer

          val results = new Array[Double](WarmupRuns + MeasureRuns)
          var run = 0
          while (run < results.length) {
            results(run) = measureRun(consumerCount, bufferSize, run)
            run += 1
          }

          // Only the runs after the warmup phase contribute to the reported statistics.
          printRow(consumerCount, results.drop(WarmupRuns))
        } finally {
          Await.result(system.terminate(), TerminationTimeout)
        }
      }
    }

    println("\n" + "=" * 80)
    println("Done.")
  }

  /**
   * Executes one broadcast of `Elements` elements to `consumerCount` consumers and returns the
   * throughput in elements per second. The measured interval starts right before the consumers
   * are attached and ends when the latch handed to the consumer sinks has been released.
   *
   * If the latch is not released within `RunTimeout`, the system is terminated and the JVM exits with status 1.
   */
  private def measureRun(consumerCount: Int, bufferSize: Int, run: Int)(implicit system: ActorSystem): Double = {
    val latch = new CountDownLatch(consumerCount)
    val broadcastSink =
      BroadcastHub.sink[java.lang.Integer](bufferSize = bufferSize, startAfterNrOfConsumers = consumerCount)
    val source = Source.fromGraph(new BenchTestSource(Elements)).runWith(broadcastSink)

    val start = System.nanoTime()
    var idx = 0
    while (idx < consumerCount) {
      source.runWith(new LatchSink(Elements, latch))
      idx += 1
    }

    if (!latch.await(RunTimeout.toSeconds, TimeUnit.SECONDS)) {
      println(s" TIMEOUT at consumers=$consumerCount buffer=$bufferSize run=$run")
      // Exit even if termination itself fails or times out.
      try Await.result(system.terminate(), TerminationTimeout)
      finally System.exit(1)
    }

    val elapsedSeconds = (System.nanoTime() - start) / 1e9
    Elements / elapsedSeconds
  }

  /** Prints average, minimum, maximum and (population) standard deviation of the measured throughputs. */
  private def printRow(consumerCount: Int, measured: Array[Double]): Unit = {
    val avg = measured.sum / measured.length
    val min = measured.min
    val max = measured.max
    val variance = measured.map(x => (x - avg) * (x - avg)).sum / measured.length
    val stddev = math.sqrt(variance)

    println(f"$consumerCount%-12d $avg%16.0f $min%12.0f $max%12.0f $stddev%10.0f")
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,24 @@ import org.apache.pekko.stream.testkit.scaladsl.StreamTestKit

import com.typesafe.config.ConfigFactory

/**
* Benchmarks BroadcastHub throughput under high-fan-out lockstep consumer scenarios.
*
* The consumer wheel uses a LongMap per slot for O(1) keyed add/remove without Long boxing.
* In lockstep, all consumers cluster in the same wheel slot, maximizing per-slot contention.
* With a small buffer (64), the wheel has only 128 slots, so `consumerCount / 128` consumers
* share each slot — the old ArrayList.removeIf was O(k) per removal, now O(1).
*
* The `broadcast` benchmark parameterizes over consumer count with a fixed small buffer,
* measuring how throughput scales as wheel slot pressure increases.
*
* The `broadcastLargeBuffer` benchmark uses a larger buffer (256) for comparison,
* showing how the optimization holds up when consumers are spread across more slots.
*/
object BroadcastHubBenchmark {
// These are `final val`s without a type ascription on purpose: that makes them constant value
// definitions, which is what allows them to be used as annotation arguments
// (e.g. `@OperationsPerInvocation(OperationsPerInvocation)` on the benchmark methods).

// Per-invocation operation count reported to JMH; the same value is passed to each LatchSink.
final val OperationsPerInvocation = 100000
// `bufferSize` given to BroadcastHub.sink by the `broadcast` benchmark.
final val SmallBufferSize = 64
// `bufferSize` given to BroadcastHub.sink by the `broadcastLargeBuffer` benchmark.
final val LargeBufferSize = 256
}

@State(Scope.Benchmark)
Expand All @@ -56,7 +72,7 @@ class BroadcastHubBenchmark {

var testSource: Source[java.lang.Integer, NotUsed] = _

@Param(Array("64", "256"))
@Param(Array("64", "256", "1000", "2000"))
var parallelism = 0

@Setup
Expand All @@ -71,12 +87,40 @@ class BroadcastHubBenchmark {
Await.result(system.terminate(), 5.seconds)
}

/**
* Lockstep broadcast with small buffer (64).
* All consumers stay at roughly the same wheel offset, clustering in the same slot.
* With 128 wheel slots and 2000 consumers, ~16 consumers share each slot on average;
* during NeedWakeup bursts, thousands cluster in a single slot.
* This maximizes the O(1) vs O(k) per-removal difference.
*/
@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def broadcast(): Unit = {
val latch = new CountDownLatch(parallelism)
val broadcastSink =
BroadcastHub.sink[java.lang.Integer](bufferSize = parallelism, startAfterNrOfConsumers = parallelism)
BroadcastHub.sink[java.lang.Integer](bufferSize = SmallBufferSize, startAfterNrOfConsumers = parallelism)
val sink = new LatchSink(OperationsPerInvocation, latch)
val source = testSource.runWith(broadcastSink)
var idx = 0
while (idx < parallelism) {
source.runWith(sink)
idx += 1
}
awaitLatch(latch)
}

/**
* Lockstep broadcast with larger buffer (256) for comparison.
* The wheel has 512 slots, so consumers are spread more thinly.
* Shows how the optimization scales when per-slot pressure is lower.
*/
@Benchmark
@OperationsPerInvocation(OperationsPerInvocation)
def broadcastLargeBuffer(): Unit = {
val latch = new CountDownLatch(parallelism)
val broadcastSink =
BroadcastHub.sink[java.lang.Integer](bufferSize = LargeBufferSize, startAfterNrOfConsumers = parallelism)
val sink = new LatchSink(OperationsPerInvocation, latch)
val source = testSource.runWith(broadcastSink)
var idx = 0
Expand All @@ -88,7 +132,7 @@ class BroadcastHubBenchmark {
}

private def awaitLatch(latch: CountDownLatch): Unit = {
if (!latch.await(30, TimeUnit.SECONDS)) {
if (!latch.await(60, TimeUnit.SECONDS)) {
StreamTestKit.printDebugDump(SystemMaterializer(system).materializer.supervisor)
throw new RuntimeException("Latch didn't complete in time")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,51 @@ class HubSpec extends StreamSpec {
in.sendComplete()
sinkProbe2.cancel()
}

"deliver all elements in order to many consumers" in {
  // Each of the attached consumers has to observe the complete sequence, in order.
  val messageCount = 2000
  val consumerCount = 200
  val expected = 0 until messageCount

  val hubSource = Source(expected).runWith(
    BroadcastHub.sink(bufferSize = 256, startAfterNrOfConsumers = consumerCount))

  // The hub only starts emitting once all `consumerCount` consumers have been materialized.
  val pending = Vector.fill(consumerCount)(hubSource.runWith(Sink.seq[Int]))

  val received = Await.result(Future.sequence(pending), 30.seconds)
  for (elements <- received) elements should ===(expected)
}

"handle many consumers when some cancel mid-stream" in {
  // A subset of the consumers stops after `cancelAfter` elements; the others must still
  // receive every element, in order.
  val messageCount = 512
  val cancelAfter = 64
  val totalConsumers = 64
  val cancellingConsumers = 16
  val remainingConsumers = totalConsumers - cancellingConsumers

  val hubSource = Source(0 until messageCount).runWith(
    BroadcastHub.sink(bufferSize = 256, startAfterNrOfConsumers = totalConsumers))

  // Materialization order is kept: the early-cancelling consumers first, then the rest.
  val earlyLeavers = Vector.fill(cancellingConsumers)(hubSource.take(cancelAfter).runWith(Sink.seq[Int]))
  val stayers = Vector.fill(remainingConsumers)(hubSource.runWith(Sink.seq[Int]))

  val earlyResults = Await.result(Future.sequence(earlyLeavers), 30.seconds)
  for (elements <- earlyResults) elements should ===(0 until cancelAfter)

  val fullResults = Await.result(Future.sequence(stayers), 30.seconds)
  for (elements <- fullResults) elements should ===(0 until messageCount)
}
}

"PartitionHub" must {
Expand Down
72 changes: 49 additions & 23 deletions stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala
Original file line number Diff line number Diff line change
Expand Up @@ -536,14 +536,17 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
* of priorities always fall to a range
*
* This wheel tracks the position of Consumers relative to the slowest ones. Every slot
* contains a list of Consumers being known at that location (this might be out of date!).
* contains a map of Consumers being known at that location (this might be out of date!).
* Consumers from time to time send Advance messages to indicate that they have progressed
* by reading from the broadcast queue. Consumers that are blocked (due to reaching tail) request
* a wakeup and update their position at the same time.
*
* Each slot uses a LongMap keyed by Consumer.id for O(1) add/remove without Long boxing.
* Empty slots are null (no backing map allocated), reducing baseline memory and GC pressure.
* When a slot drains to zero consumers, its map is released (set to null).
*/
private[this] val consumerWheel =
Array.fill[java.util.ArrayList[Consumer]](bufferSize * 2)(new util.ArrayList[Consumer]())
new Array[LongMap[Consumer]](bufferSize * 2)
private[this] var activeConsumers = 0

override def preStart(): Unit = {
Expand Down Expand Up @@ -574,15 +577,19 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
val newOffset = previousOffset + DemandThreshold
// Move the consumer from its last known offset to its new one. Check if we are unblocked.
val consumer = findAndRemoveConsumer(id, previousOffset)
addConsumer(consumer, newOffset)
if (consumer ne null) {
addConsumer(consumer, newOffset)
}
checkUnblock(previousOffset)
case NeedWakeup(id, previousOffset, currentOffset) =>
// Move the consumer from its last known offset to its new one. Check if we are unblocked.
val consumer = findAndRemoveConsumer(id, previousOffset)
addConsumer(consumer, currentOffset)
if (consumer ne null) {
addConsumer(consumer, currentOffset)

// Also check if the consumer is now unblocked since we published an element since it went asleep.
if (currentOffset != tail) consumer.callback.invoke(Wakeup)
// Also check if the consumer is now unblocked since we published an element since it went asleep.
if (currentOffset != tail) consumer.callback.invoke(Wakeup)
}
checkUnblock(previousOffset)

case RegistrationPending =>
Expand Down Expand Up @@ -650,10 +657,14 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
consumer.callback.invoke(failMessage)
}

// Notify registered consumers
// Notify registered consumers — skip null (empty) slots
var idx = 0
while (idx < consumerWheel.length) {
consumerWheel(idx).forEach(_.callback.invoke(failMessage))
val bucket = consumerWheel(idx)
if (bucket ne null) {
val itr = bucket.valuesIterator
while (itr.hasNext) itr.next().callback.invoke(failMessage)
}
idx += 1
}
failStage(ex)
Expand All @@ -664,21 +675,19 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
*
* NB: You cannot remove a consumer without knowing its last offset! Consumers on the Source side always must
* track this so this can be a fast operation.
*
* Uses LongMap.getOrNull + -= to avoid Option allocation on the hot path.
*/
private def findAndRemoveConsumer(id: Long, offset: Int): Consumer = {
// TODO: Try to eliminate modulo division somehow...
val wheelSlot = offset & WheelMask
val consumersInSlot = consumerWheel(wheelSlot)
var removedConsumer: Consumer = null
if (consumersInSlot.size() > 0) {
consumersInSlot.removeIf(consumer => {
if (consumer.id == id) {
removedConsumer = consumer
true
} else false
})
val bucket = consumerWheel(wheelSlot)
if (bucket eq null) return null
val consumer = bucket.getOrNull(id)
if (consumer ne null) {
bucket -= id
if (bucket.isEmpty) consumerWheel(wheelSlot) = null
}
removedConsumer
consumer
}

/*
Expand All @@ -697,7 +706,7 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
if (offsetOfConsumerRemoved == head) {
// Try to advance along the wheel. We can skip any wheel slots which have no waiting Consumers, until
// we either find a nonempty one, or we reached the end of the buffer.
while (consumerWheel(head & WheelMask).isEmpty && head != tail) {
while (isConsumerWheelSlotEmpty(head & WheelMask) && head != tail) {
queue(head & Mask) = null
head += 1
unblocked = true
Expand All @@ -706,18 +715,35 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
unblocked
}

// A wheel slot counts as empty when no map has been allocated for it (null)
// or when its map currently holds no consumers.
private def isConsumerWheelSlotEmpty(slot: Int): Boolean = {
  val consumersInSlot = consumerWheel(slot)
  if (consumersInSlot eq null) true
  else consumersInSlot.isEmpty
}

private def addConsumer(consumer: Consumer, offset: Int): Unit = {
val slot = offset & WheelMask
consumerWheel(slot).add(consumer)
val bucket = consumerWheel(slot)
if (bucket ne null) bucket.update(consumer.id, consumer)
else {
val newBucket = LongMap.empty[Consumer]
newBucket.update(consumer.id, consumer)
consumerWheel(slot) = newBucket
}
}

/*
* Send a wakeup signal to all the Consumers at a certain wheel index. Note, this needs the actual index,
* which is the offset masked with WheelMask (i.e. the offset modulo the wheel size, bufferSize * 2).
*
* Enumeration order of the bucket is not semantically significant — every consumer receives the same
* wakeup signal independently.
*/
private def wakeupIdx(idx: Int): Unit = {
val itr = consumerWheel(idx).iterator
while (itr.hasNext) itr.next().callback.invoke(Wakeup)
val bucket = consumerWheel(idx)
if (bucket ne null) {
val itr = bucket.valuesIterator
while (itr.hasNext) itr.next().callback.invoke(Wakeup)
}
}

private def complete(): Unit = {
Expand Down