From 5b193531e58755fafea857e4c719dbd78c22ce83 Mon Sep 17 00:00:00 2001 From: Aravind Patnam Date: Tue, 13 Aug 2024 10:18:15 -0700 Subject: [PATCH] fix --- .../celeborn/common/util/PbSerDeUtils.scala | 19 +++++---------- .../common/util/PbSerDeUtilsTest.scala | 24 ++++++++++--------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index d649500f030..96f1ed2d53d 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -35,17 +35,10 @@ import org.apache.celeborn.common.util.{CollectionUtils => localCollectionUtils} object PbSerDeUtils { - private var masterPersistWorkerNetworkLocation: Option[Boolean] = None - - def setMasterPersistWorkerNetworkLocation(value: Boolean): Unit = { - masterPersistWorkerNetworkLocation match { - case None => masterPersistWorkerNetworkLocation = Some(value) - case Some(_) => - // this should never happen, but being defensive - throw new IllegalStateException( - s"masterPersistWorkerNetworkLocation has already been set once to" + - s" ${masterPersistWorkerNetworkLocation.get}") - } + private var masterPersistWorkerNetworkLocation: Boolean = false + + def setMasterPersistWorkerNetworkLocation(masterPersistWorkerNetworkLocation: Boolean) = { + this.masterPersistWorkerNetworkLocation = masterPersistWorkerNetworkLocation } @throws[InvalidProtocolBufferException] @@ -249,7 +242,7 @@ object PbSerDeUtils { pbWorkerInfo.getInternalPort, disks, userResourceConsumption) - if (masterPersistWorkerNetworkLocation.getOrElse(false)) { + if (masterPersistWorkerNetworkLocation) { workerInfo.networkLocation_$eq(pbWorkerInfo.getNetworkLocation) } workerInfo @@ -266,7 +259,7 @@ object PbSerDeUtils { .setPushPort(workerInfo.pushPort) .setReplicatePort(workerInfo.replicatePort) .setInternalPort(workerInfo.internalPort) - if (masterPersistWorkerNetworkLocation.getOrElse(false)) { + if (masterPersistWorkerNetworkLocation) { builder.setNetworkLocation(workerInfo.networkLocation) } diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala index 82dc0c2a1b4..184d49d1456 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala @@ -251,17 +251,19 @@ class PbSerDeUtilsTest extends CelebornFunSuite { } test("fromAndToPbWorkerInfo") { - PbSerDeUtils.setMasterPersistWorkerNetworkLocation(true) - val pbWorkerInfo = PbSerDeUtils.toPbWorkerInfo(workerInfo1, false, false) - val pbWorkerInfoWithEmptyResource = PbSerDeUtils.toPbWorkerInfo(workerInfo1, true, false) - val restoredWorkerInfo = PbSerDeUtils.fromPbWorkerInfo(pbWorkerInfo) - val restoredWorkerInfoWithEmptyResource = - PbSerDeUtils.fromPbWorkerInfo(pbWorkerInfoWithEmptyResource) - - assert(restoredWorkerInfo.equals(workerInfo1)) - assert(restoredWorkerInfoWithEmptyResource.userResourceConsumption.equals(new util.HashMap[ - UserIdentifier, - ResourceConsumption]())) + Seq(false, true).foreach { b => + PbSerDeUtils.setMasterPersistWorkerNetworkLocation(b) + val pbWorkerInfo = PbSerDeUtils.toPbWorkerInfo(workerInfo1, false, false) + val pbWorkerInfoWithEmptyResource = PbSerDeUtils.toPbWorkerInfo(workerInfo1, true, false) + val restoredWorkerInfo = PbSerDeUtils.fromPbWorkerInfo(pbWorkerInfo) + val restoredWorkerInfoWithEmptyResource = + PbSerDeUtils.fromPbWorkerInfo(pbWorkerInfoWithEmptyResource) + + assert(restoredWorkerInfo.equals(workerInfo1)) + assert(restoredWorkerInfoWithEmptyResource.userResourceConsumption.equals(new util.HashMap[ + UserIdentifier, + ResourceConsumption]())) + } } test("fromAndToPbPartitionLocation") {