Skip to content

Commit

Permalink
[CELEBORN-1600] Support revise lost shuffles
Browse files Browse the repository at this point in the history
  • Loading branch information
FMX committed Sep 23, 2024
1 parent c4c3299 commit e71ac9e
Show file tree
Hide file tree
Showing 28 changed files with 645 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,10 @@ final class MasterOptions {
names = Array("--remove-workers-unavailable-info"),
description = Array("Remove the workers unavailable info from the master."))
private[master] var removeWorkersUnavailableInfo: Boolean = _

// Flag that selects the revise-lost-shuffles action. When set, the subcommand runs
// `reviseLostShuffles`, which reads its arguments from the separate
// ReviseLostShuffleOptions group (--deleteApp, --appId, --shuffleIds).
@Option(
names = Array("--revise-lost-shuffles"),
description = Array("Revise lost shuffles or remove shuffles for an application."))
private[master] var reviseLostShuffles: Boolean = _

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ trait MasterSubcommand extends CliLogging {
@ArgGroup(exclusive = true, multiplicity = "1")
private[master] var masterOptions: MasterOptions = _

// Arguments consumed by `reviseLostShuffles` (--deleteApp, --appId, --shuffleIds).
// NOTE(review): the group is non-exclusive with no explicit multiplicity, so this field
// presumably stays null when none of its options are given — confirm against picocli.
@ArgGroup(exclusive = false)
private[master] var reviseLostShuffleOptions: ReviseLostShuffleOptions = _

@Mixin
private[master] var commonOptions: CommonOptions = _

Expand Down Expand Up @@ -106,4 +109,6 @@ trait MasterSubcommand extends CliLogging {

private[master] def runShowThreadDump: ThreadStackResponse

/**
 * Handles the `--revise-lost-shuffles` action.
 *
 * @return the response of the revise-lost-shuffles request
 */
private[master] def reviseLostShuffles: ReviseLostShuffleResponse

}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class MasterSubcommandImpl extends Runnable with MasterSubcommand {
if (masterOptions.showConf) log(runShowConf)
if (masterOptions.showDynamicConf) log(runShowDynamicConf)
if (masterOptions.showThreadDump) log(runShowThreadDump)
if (masterOptions.reviseLostShuffles) log(reviseLostShuffles)
if (masterOptions.addClusterAlias != null && masterOptions.addClusterAlias.nonEmpty)
runAddClusterAlias
if (masterOptions.removeClusterAlias != null && masterOptions.removeClusterAlias.nonEmpty)
Expand Down Expand Up @@ -206,4 +207,11 @@ class MasterSubcommandImpl extends Runnable with MasterSubcommand {
cliConfigManager.remove(aliasToRemove)
logInfo(s"Cluster alias $aliasToRemove removed.")
}

/**
 * Issues a revise-lost-shuffles request using the values supplied through the
 * `--deleteApp`, `--appId` and `--shuffleIds` options.
 *
 * @return the response produced by `applicationApi.reviseLostShuffles`
 * @throws IllegalArgumentException if none of the companion options was supplied
 */
override private[master] def reviseLostShuffles: ReviseLostShuffleResponse = {
  // picocli leaves an optional @ArgGroup field null when none of its options appear on
  // the command line; fail with an actionable message instead of a bare
  // NullPointerException.
  require(
    reviseLostShuffleOptions != null,
    "--revise-lost-shuffles requires at least one of --deleteApp, --appId or --shuffleIds.")
  applicationApi.reviseLostShuffles(
    reviseLostShuffleOptions.deleteApp,
    reviseLostShuffleOptions.appId,
    reviseLostShuffleOptions.shuffleIds)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.cli.master

import picocli.CommandLine.Option

/**
 * Picocli option holder for the arguments that accompany `--revise-lost-shuffles`.
 * Every value is captured as a raw string and forwarded unchanged by the subcommand.
 */
final class ReviseLostShuffleOptions {

// NOTE(review): declared as String rather than Boolean; the CLI tests pass "true" and
// "false" here — confirm which values the receiving API accepts.
@Option(
names = Array("--deleteApp"),
description = Array("Whether to delete an application's shuffles or not"))
private[master] var deleteApp: String = _

// Identifier of the application whose shuffles are revised or deleted.
@Option(
names = Array("--appId"),
description = Array("The application to manipulate shuffles"))
private[master] var appId: String = _

// Shuffle ids as a single string.
// NOTE(review): the expected format (e.g. the separator) is not visible here — confirm.
@Option(
names = Array("--shuffleIds"),
description = Array("The shuffle ids to manipulate."))
private[master] var shuffleIds: String = _

}
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,26 @@ class TestCelebornCliCommands extends CelebornFunSuite with MiniClusterFeature {
captureOutputAndValidateResponse(args, "success: true")
}

test("master --revise-lost-shuffles case1") {
  // --deleteApp true for app1: the CLI output must report success.
  val baseArgs = prepareMasterArgs()
  val reviseArgs = Array("--revise-lost-shuffles", "--deleteApp", "true", "--appId", "app1")
  val args = Array.concat(baseArgs, reviseArgs)
  captureOutputAndValidateResponse(args, "success: true")
}

test("master --revise-lost-shuffles case2") {
  // --deleteApp false for app1: the CLI output must report success.
  val baseArgs = prepareMasterArgs()
  val reviseArgs = Array("--revise-lost-shuffles", "--deleteApp", "false", "--appId", "app1")
  val args = Array.concat(baseArgs, reviseArgs)
  captureOutputAndValidateResponse(args, "success: true")
}

private def prepareMasterArgs(): Array[String] = {
Array(
"master",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@

package org.apache.celeborn.client

import java.util.concurrent.{ScheduledFuture, TimeUnit}
import java.util
import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit}
import java.util.function.Consumer

import scala.collection.JavaConverters._

import org.apache.commons.lang3.StringUtils

import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, ApplicationLostResponse, HeartbeatFromApplication, HeartbeatFromApplicationResponse, ZERO_UUID}
import org.apache.celeborn.common.protocol.PbReviseLostShufflesResponse
import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, ApplicationLostResponse, HeartbeatFromApplication, HeartbeatFromApplicationResponse, ReviseLostShuffles, ZERO_UUID}
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.{ThreadUtils, Utils}

Expand All @@ -33,7 +38,8 @@ class ApplicationHeartbeater(
conf: CelebornConf,
masterClient: MasterClient,
shuffleMetrics: () => (Long, Long),
workerStatusTracker: WorkerStatusTracker) extends Logging {
workerStatusTracker: WorkerStatusTracker,
registeredShuffles: ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]) extends Logging {

private var stopped = false

Expand Down Expand Up @@ -68,6 +74,27 @@ class ApplicationHeartbeater(
if (response.statusCode == StatusCode.SUCCESS) {
logDebug("Successfully send app heartbeat.")
workerStatusTracker.handleHeartbeatResponse(response)
// revise shuffle id if there are lost shuffles
val masterRecordedShuffleIds = response.registeredShuffles
val localShuffleIds = new util.ArrayList[Integer]()
registeredShuffles.forEach(new Consumer[Int] {
override def accept(key: Int): Unit = {
localShuffleIds.add(key)
}
})
localShuffleIds.removeAll(masterRecordedShuffleIds)
if (!localShuffleIds.isEmpty) {
logWarning(s"There are lost shuffle found ${StringUtils.join(localShuffleIds, ",")}, revise lost shuffles.")
val reviseLostShufflesResponse = masterClient.askSync(
ReviseLostShuffles.apply(appId, localShuffleIds, MasterClient.genRequestId()),
classOf[PbReviseLostShufflesResponse])
if (!reviseLostShufflesResponse.getSuccess) {
logWarning(
s"Revise lost shuffles failed. Error message :${reviseLostShufflesResponse.getMessage}")
} else {
logInfo("Revise lost shuffles succeed.")
}
}
}
} catch {
case it: InterruptedException =>
Expand Down Expand Up @@ -97,6 +124,7 @@ class ApplicationHeartbeater(
StatusCode.REQUEST_FAILED,
List.empty.asJava,
List.empty.asJava,
List.empty.asJava,
List.empty.asJava)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
conf,
masterClient,
() => commitManager.commitMetrics(),
workerStatusTracker)
workerStatusTracker,
registeredShuffle)
private val changePartitionManager = new ChangePartitionManager(conf, this)
private val releasePartitionManager = new ReleasePartitionManager(conf, this)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
StatusCode.SUCCESS,
excludedWorkers,
unknownWorkers,
shuttingWorkers)
shuttingWorkers,
new util.ArrayList[Integer]())
}

private def mockWorkers(workerHosts: Array[String]): util.ArrayList[WorkerInfo] = {
Expand Down
15 changes: 15 additions & 0 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ enum MessageType {
REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE = 84;
SEGMENT_START = 85;
NOTIFY_REQUIRED_SEGMENT = 86;

REVISE_LOST_SHUFFLES = 202;
REVISE_LOST_SHUFFLES_RESPONSE = 203;
}

enum StreamType {
Expand Down Expand Up @@ -433,6 +436,7 @@ message PbHeartbeatFromApplicationResponse {
repeated PbWorkerInfo excludedWorkers = 2;
repeated PbWorkerInfo unknownWorkers = 3;
repeated PbWorkerInfo shuttingWorkers = 4;
repeated int32 registeredShuffles = 6;
}

message PbCheckQuota {
Expand Down Expand Up @@ -842,3 +846,14 @@ message PbReportWorkerDecommission {
repeated PbWorkerInfo workers = 1;
string requestId = 2;
}

// Request asking the master to revise shuffles that an application holds locally
// but that the master did not report back.
message PbReviseLostShuffles {
  // Application that owns the shuffles.
  string appId = 1;
  // Ids of the shuffles to revise.
  repeated int32 lostShuffles = 2;
  // Request id generated by the sender.
  string requestId = 3;
}

// Reply to PbReviseLostShuffles.
message PbReviseLostShufflesResponse {
  // True when the revise request succeeded.
  bool success = 1;
  // Detail text; the client logs it as the error message when success is false.
  string message = 2;
}
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,26 @@ object ControlMessages extends Logging {
.build()
}

object ReviseLostShuffles {

  /**
   * Builds the protobuf request that carries the shuffle ids to revise.
   *
   * @param appId application that owns the shuffles
   * @param lostShuffles ids of the shuffles to revise
   * @param requestId request id generated by the sender
   * @return the populated [[PbReviseLostShuffles]] message
   */
  def apply(
      appId: String,
      lostShuffles: java.util.List[Integer],
      requestId: String): PbReviseLostShuffles = {
    val request = PbReviseLostShuffles.newBuilder()
    request.setAppId(appId)
    request.addAllLostShuffles(lostShuffles)
    request.setRequestId(requestId)
    request.build()
  }
}

object ReviseLostShufflesResponse {

  /**
   * Builds the protobuf reply to a revise-lost-shuffles request.
   *
   * @param success whether the revise request succeeded
   * @param message detail text; `null` is stored as the empty string
   * @return the populated [[PbReviseLostShufflesResponse]] message
   */
  def apply(success: Boolean, message: String): PbReviseLostShufflesResponse = {
    // Generated protobuf string setters throw NullPointerException on null, so a caller
    // forwarding e.g. a null Throwable.getMessage would otherwise fail to build the reply.
    val safeMessage = Option(message).getOrElse("")
    PbReviseLostShufflesResponse.newBuilder()
      .setSuccess(success)
      .setMessage(safeMessage)
      .build()
  }
}

case class StageEnd(shuffleId: Int) extends MasterMessage

case class StageEndResponse(status: StatusCode)
Expand Down Expand Up @@ -376,7 +396,8 @@ object ControlMessages extends Logging {
statusCode: StatusCode,
excludedWorkers: util.List[WorkerInfo],
unknownWorkers: util.List[WorkerInfo],
shuttingWorkers: util.List[WorkerInfo]) extends Message
shuttingWorkers: util.List[WorkerInfo],
registeredShuffles: util.List[Integer]) extends Message

case class CheckQuota(userIdentifier: UserIdentifier) extends Message

Expand Down Expand Up @@ -541,6 +562,12 @@ object ControlMessages extends Logging {
case pb: PbReportShuffleFetchFailureResponse =>
new TransportMessage(MessageType.REPORT_SHUFFLE_FETCH_FAILURE_RESPONSE, pb.toByteArray)

case pb: PbReviseLostShuffles =>
new TransportMessage(MessageType.REVISE_LOST_SHUFFLES, pb.toByteArray)

case pb: PbReviseLostShufflesResponse =>
new TransportMessage(MessageType.REVISE_LOST_SHUFFLES_RESPONSE, pb.toByteArray)

case pb: PbReportBarrierStageAttemptFailure =>
new TransportMessage(MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE, pb.toByteArray)

Expand Down Expand Up @@ -767,7 +794,8 @@ object ControlMessages extends Logging {
statusCode,
excludedWorkers,
unknownWorkers,
shuttingWorkers) =>
shuttingWorkers,
registeredShuffles) =>
val payload = PbHeartbeatFromApplicationResponse.newBuilder()
.setStatus(statusCode.getValue)
.addAllExcludedWorkers(
Expand All @@ -776,6 +804,7 @@ object ControlMessages extends Logging {
unknownWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, true)).toList.asJava)
.addAllShuttingWorkers(
shuttingWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, true)).toList.asJava)
.addAllRegisteredShuffles(registeredShuffles)
.build().toByteArray
new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION_RESPONSE, payload)

Expand Down Expand Up @@ -1152,7 +1181,8 @@ object ControlMessages extends Logging {
pbHeartbeatFromApplicationResponse.getUnknownWorkersList.asScala
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
pbHeartbeatFromApplicationResponse.getShuttingWorkersList.asScala
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava)
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
pbHeartbeatFromApplicationResponse.getRegisteredShufflesList)

case CHECK_QUOTA_VALUE =>
val pbCheckAvailable = PbCheckQuota.parseFrom(message.getPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ object PbSerDeUtils {

def toPbSnapshotMetaInfo(
estimatedPartitionSize: java.lang.Long,
registeredShuffle: java.util.Set[String],
registeredShuffle: java.util.Map[String, java.util.Set[Integer]],
hostnameSet: java.util.Set[String],
excludedWorkers: java.util.Set[WorkerInfo],
manuallyExcludedWorkers: java.util.Set[WorkerInfo],
Expand All @@ -468,7 +468,9 @@ object PbSerDeUtils {
decommissionWorkers: java.util.Set[WorkerInfo]): PbSnapshotMetaInfo = {
val builder = PbSnapshotMetaInfo.newBuilder()
.setEstimatedPartitionSize(estimatedPartitionSize)
.addAllRegisteredShuffle(registeredShuffle)
.addAllRegisteredShuffle(registeredShuffle.asScala.flatMap { appIdAndShuffleId =>
appIdAndShuffleId._2.asScala.map(i => Utils.makeShuffleKey(appIdAndShuffleId._1, i))
}.asJava)
.addAllHostnameSet(hostnameSet)
.addAllExcludedWorkers(excludedWorkers.asScala.map(toPbWorkerInfo(_, true, false)).asJava)
.addAllManuallyExcludedWorkers(manuallyExcludedWorkers.asScala
Expand Down
Loading

0 comments on commit e71ac9e

Please sign in to comment.