diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index a813a9e5015..004eac95273 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -189,6 +189,7 @@ message PbWorkerInfo { int32 internalPort = 8; string networkLocation = 9; int64 nextInterruptionNotice = 10; // Unix timestamp when disruption is expected to be initiated + repeated string tags = 11; } message PbFileGroup { @@ -207,6 +208,7 @@ message PbRegisterWorker { map userResourceConsumption = 8; int32 internalPort = 10; string networkLocation = 11; + repeated string tags = 12; } message PbMetaRegisterWorkerRequest { @@ -219,6 +221,7 @@ message PbMetaRegisterWorkerRequest { map userResourceConsumption = 7; int32 internalPort = 8; string networkLocation = 9; + repeated string tags = 12; } message PbHeartbeatFromWorker { diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 32c4eb73d7e..6ba872657c5 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -1519,6 +1519,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientInputStreamCreationWindow = get(CLIENT_INPUTSTREAM_CREATION_WINDOW) def tagsEnabled: Boolean = get(TAGS_ENABLED) + def tagsWorkerRegistrationEnabled: Boolean = get(TAGS_WORKER_REGISTRATION_ENABLED) + def workerTags: Seq[String] = get(WORKER_TAGS) def tagsExpr: String = get(TAGS_EXPR) def preferClientTagsExpr: Boolean = get(PREFER_CLIENT_TAGS_EXPR) @@ -6845,6 +6847,24 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(true) + val TAGS_WORKER_REGISTRATION_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.tags.worker.registration.enabled") + .categories("master") + .version("0.7.0") + .doc("When true, the master honors tags advertised by workers at registration " + + "(merged with the config-store tags). When false, worker-supplied tags are ignored.") + .booleanConf + .createWithDefault(true) + + val WORKER_TAGS: ConfigEntry[Seq[String]] = + buildConf("celeborn.worker.tags") + .categories("worker") + .version("0.7.0") + .doc("Comma-separated tags this worker supplies to the master at registration.") + .stringConf + .toSequence + .createWithDefault(Seq.empty) + val TAGS_EXPR: ConfigEntry[String] = buildConf("celeborn.tags.tagsExpr") .categories("master", "client") diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala index 0304bd4236a..e92818e9c7c 100644 --- a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala +++ b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala @@ -47,7 +47,8 @@ class WorkerInfo( var nextInterruptionNotice = Long.MaxValue var lastHeartbeat: Long = 0 var workerStatus = WorkerStatus.normalWorkerStatus() - var isHighWorkLoad: Boolean = false; + var isHighWorkLoad: Boolean = false + var tags: util.Set[String] = new util.HashSet[String]() val diskInfos = { if (_diskInfos != null) JavaUtils.newConcurrentHashMap[String, DiskInfo](_diskInfos) else null diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 36f164d697e..e6e407651cb 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -93,6 +93,7 @@ object ControlMessages extends Logging { networkLocation: String, disks: Map[String, DiskInfo], userResourceConsumption: Map[UserIdentifier, ResourceConsumption], + tags: Set[String], requestId: String): PbRegisterWorker = { val pbDisks = disks.values.map(PbSerDeUtils.toPbDiskInfo).asJava val pbUserResourceConsumption = @@ -107,6 +108,7 @@ object ControlMessages extends Logging { .setNetworkLocation(networkLocation) .addAllDisks(pbDisks) .putAllUserResourceConsumption(pbUserResourceConsumption) + .addAllTags(tags.asJava) .setRequestId(requestId) .build() } diff --git a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala index e9c407ce80e..d021981d5e8 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/PbSerDeUtils.scala @@ -296,6 +296,7 @@ object PbSerDeUtils { } else { workerInfo.nextInterruptionNotice = pbWorkerInfo.getNextInterruptionNotice } + workerInfo.tags = new util.HashSet[String](pbWorkerInfo.getTagsList) workerInfo } @@ -310,6 +311,7 @@ object PbSerDeUtils { .setPushPort(workerInfo.pushPort) .setReplicatePort(workerInfo.replicatePort) .setInternalPort(workerInfo.internalPort) + .addAllTags(workerInfo.tags) if (masterPersistWorkerNetworkLocation) { builder.setNetworkLocation(workerInfo.networkLocation) } diff --git a/docs/configuration/master.md b/docs/configuration/master.md index 1f1c900c946..619b48b25ad 100644 --- a/docs/configuration/master.md +++ b/docs/configuration/master.md @@ -107,4 +107,5 @@ license: | | celeborn.tags.enabled | true | false | Whether to enable tags for workers. | 0.6.0 | | | celeborn.tags.preferClientTagsExpr | false | true | When `true`, prefer the tags expression provided by the client over the tags expression provided by the master. | 0.6.0 | | | celeborn.tags.tagsExpr | | true | Expression to filter workers by tags. The expression is a comma-separated list of tags. The expression is evaluated as a logical AND of all tags. For example, `prod,high-io` filters workers that have both the `prod` and `high-io` tags. | 0.6.0 | | +| celeborn.tags.worker.registration.enabled | true | false | When true, the master honors tags advertised by workers at registration (merged with the config-store tags). When false, worker-supplied tags are ignored. | 0.7.0 | | diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md index bb2cec89bc3..439555fd680 100644 --- a/docs/configuration/worker.md +++ b/docs/configuration/worker.md @@ -200,6 +200,7 @@ license: | | celeborn.worker.storage.storagePolicy.createFilePolicy | <undefined> | false | This defined the order for creating files across available storages. Available storages options are: MEMORY,SSD,HDD,HDFS,S3,OSS | 0.5.1 | | | celeborn.worker.storage.storagePolicy.evictPolicy | <undefined> | false | This define the order of evict files if the storages are available. Available storages: MEMORY,SSD,HDD,HDFS,S3,OSS. Definition: StorageTypes|StorageTypes|StorageTypes. Example: MEMORY,SSD|SSD,HDFS. The example means that a MEMORY shuffle file can be evicted to SSD and a SSD shuffle file can be evicted to HDFS. | 0.5.1 | | | celeborn.worker.storage.workingDir | celeborn-worker/shuffle_data | false | Worker's working dir path name. | 0.3.0 | celeborn.worker.workingDir | +| celeborn.worker.tags | | false | Comma-separated tags this worker supplies to the master at registration. | 0.7.0 | | | celeborn.worker.writer.close.timeout | 120s | false | Timeout for a file writer to close | 0.2.0 | | | celeborn.worker.writer.create.maxAttempts | 3 | false | Retry count for a file writer to create if its creation was failed. | 0.2.0 | | | celeborn.worker.writer.create.parallel.enabled | false | false | Whether to parallelize the creation of file writer. | 0.6.3 | | diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java index 11d73c82abb..cf3c8ce9321 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/AbstractMetaManager.java @@ -358,7 +358,8 @@ public void updateRegisterWorkerMeta( int replicatePort, int internalPort, String networkLocation, - Map disks) { + Map disks, + Set tags) { WorkerInfo workerInfo = new WorkerInfo( host, @@ -370,6 +371,7 @@ public void updateRegisterWorkerMeta( disks, new HashMap<>()); workerInfo.lastHeartbeat_$eq(System.currentTimeMillis()); + workerInfo.tags_$eq(new HashSet<>(tags)); if (networkLocation != null && !networkLocation.isEmpty() && !NetworkTopology.DEFAULT_RACK.equals(networkLocation)) { diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java index e3e063348fc..f44f93adfcf 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/IMetadataHandler.java @@ -17,8 +17,10 @@ package org.apache.celeborn.service.deploy.master.clustermeta; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.meta.ApplicationMeta; @@ -77,6 +79,31 @@ void handleWorkerHeartbeat( WorkerStatus workerStatus, String requestId); + default void handleRegisterWorker( + String host, + int rpcPort, + int pushPort, + int fetchPort, + int replicatePort, + int internalPort, + String networkLocation, + Map disks, + Map userResourceConsumption, + String requestId) { + handleRegisterWorker( + host, + rpcPort, + pushPort, + fetchPort, + replicatePort, + internalPort, + networkLocation, + disks, + userResourceConsumption, + Collections.emptySet(), + requestId); + } + void handleRegisterWorker( String host, int rpcPort, @@ -87,6 +114,7 @@ void handleRegisterWorker( String networkLocation, Map disks, Map userResourceConsumption, + Set tags, String requestId); void handleReportWorkerUnavailable(List failedNodes, String requestId); diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java index c4b2e843f23..f8353829232 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/SingleMasterMetaManager.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -162,9 +163,18 @@ public void handleRegisterWorker( String networkLocation, Map disks, Map userResourceConsumption, + Set tags, String requestId) { updateRegisterWorkerMeta( - host, rpcPort, pushPort, fetchPort, replicatePort, internalPort, networkLocation, disks); + host, + rpcPort, + pushPort, + fetchPort, + replicatePort, + internalPort, + networkLocation, + disks, + tags); updateWorkerResourceConsumptions( host, rpcPort, pushPort, fetchPort, replicatePort, userResourceConsumption); } diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java index 3372143aa23..63fde583f66 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HAMasterMetaManager.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -349,6 +350,7 @@ public void handleRegisterWorker( String networkLocation, Map disks, Map userResourceConsumption, + Set tags, String requestId) { try { ratisServer.submitRequest( @@ -365,6 +367,7 @@ public void handleRegisterWorker( .setInternalPort(internalPort) .setNetworkLocation(networkLocation) .putAllDisks(MetaUtil.toPbDiskInfos(disks)) + .addAllTags(tags) .build()) .build()); updateWorkerResourceConsumptions( diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java index 3c8e8fe0d0a..06a3c2c1f69 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/MetaHandler.java @@ -258,6 +258,7 @@ public org.apache.celeborn.common.protocol.PbMetaRequestResponse handleWriteRequ int internalPort = request.getRegisterWorkerRequest().getInternalPort(); Map pbDiskInfo = request.getRegisterWorkerRequest().getDisksMap(); diskInfos = MetaUtil.fromPbDiskInfoMap(pbDiskInfo); + Set tags = new HashSet<>(request.getRegisterWorkerRequest().getTagsList()); LOG.debug( "Handle worker register for {} {} {} {} {} {} {}", host, @@ -275,7 +276,8 @@ public org.apache.celeborn.common.protocol.PbMetaRequestResponse handleWriteRequ replicatePort, internalPort, networkLocation, - diskInfos); + diskInfos, + tags); break; case ReportWorkerUnavailable: diff --git a/master/src/main/proto/Resource.proto b/master/src/main/proto/Resource.proto index d8fceee1564..f9e36e3ae87 100644 --- a/master/src/main/proto/Resource.proto +++ b/master/src/main/proto/Resource.proto @@ -182,6 +182,7 @@ message RegisterWorkerRequest { map userResourceConsumption = 7; int32 internalPort = 8; optional string networkLocation = 9; + repeated string tags = 10; } message ReportWorkerUnavailableRequest { diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala index 81e0e0076ca..45d4fa56f71 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala @@ -509,6 +509,9 @@ private[celeborn] class Master( .toMap.asJava val userResourceConsumption = PbSerDeUtils.fromPbUserResourceConsumption(pbRegisterWorker.getUserResourceConsumptionMap) + val tags = + if (conf.tagsWorkerRegistrationEnabled) pbRegisterWorker.getTagsList.asScala.toSet + else Set.empty[String] logDebug(s"Received RegisterWorker request $requestId, $host:$pushPort:$replicatePort" + s" $disks.") @@ -525,6 +528,7 @@ private[celeborn] class Master( networkLocation, disks, userResourceConsumption, + tags, requestId)) case requestSlots @ RequestSlots(applicationId, _, _, _, _, _, _, _, _, _, _, _, _) => @@ -845,6 +849,7 @@ private[celeborn] class Master( networkLocation: String, disks: util.Map[String, DiskInfo], userResourceConsumption: util.Map[UserIdentifier, ResourceConsumption], + tags: Set[String], requestId: String): Unit = { val workerToRegister = new WorkerInfo( @@ -880,6 +885,7 @@ private[celeborn] class Master( networkLocation, disks, userResourceConsumption, + tags.asJava, requestId) context.reply(RegisterWorkerResponse(true, "Worker in snapshot, re-register.")) } else if (statusSystem.workerLostEvents.contains(workerToRegister)) { @@ -896,6 +902,7 @@ private[celeborn] class Master( networkLocation, disks, userResourceConsumption, + tags.asJava, requestId) context.reply(RegisterWorkerResponse(true, "Worker in workerLostEvents, re-register.")) } else { @@ -909,6 +916,7 @@ private[celeborn] class Master( networkLocation, disks, userResourceConsumption, + tags.asJava, requestId) logInfo(s"Registered worker $workerToRegister.") context.reply(RegisterWorkerResponse(true, "")) diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/tags/TagsManager.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/tags/TagsManager.scala index 74aa69092ca..85db41aaf98 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/tags/TagsManager.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/tags/TagsManager.scala @@ -19,112 +19,64 @@ package org.apache.celeborn.service.deploy.master.tags import java.util import java.util.{Collections, Set => JSet} -import java.util.concurrent.ConcurrentHashMap import java.util.function.Predicate import java.util.stream.Collectors -import scala.collection.JavaConverters.{asScalaIteratorConverter, mapAsScalaConcurrentMapConverter} +import scala.collection.JavaConverters._ import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.WorkerInfo -import org.apache.celeborn.common.util.JavaUtils import org.apache.celeborn.server.common.service.config.ConfigService class TagsManager(configService: Option[ConfigService]) extends Logging { - private val defaultTagStore = JavaUtils.newConcurrentHashMap[String, JSet[String]]() - private val addNewTagFunc = - new util.function.Function[String, ConcurrentHashMap.KeySetView[String, java.lang.Boolean]]() { - override def apply(t: String): ConcurrentHashMap.KeySetView[String, java.lang.Boolean] = - ConcurrentHashMap.newKeySet[String]() - } - - private def getTagStore: ConcurrentHashMap[String, JSet[String]] = { + private def tagStore: Option[util.Map[String, JSet[String]]] = { configService match { - case Some(cs) => - // TODO: Make configStore.getTags return ConcurrentMap - JavaUtils.newConcurrentHashMap(cs.getSystemConfigFromCache.getTags) - case _ => - defaultTagStore + case Some(cs) => Option(cs.getSystemConfigFromCache.getTags) + case _ => None } } + private def resolveTagsExpr(userIdentifier: UserIdentifier, clientTagsExpr: String): String = + configService.map { cs => + val tagsMeta = cs + .getTenantUserConfigFromCache(userIdentifier.tenantId, userIdentifier.name) + .getWorkerTagsMeta + if (tagsMeta.preferClientTagExpr) clientTagsExpr else tagsMeta.tagsExpr + }.getOrElse(clientTagsExpr) + def getTaggedWorkers( userIdentifier: UserIdentifier, clientTagsExpr: String, workers: util.List[WorkerInfo]): util.List[WorkerInfo] = { - val tagsExpr = configService.flatMap { cs => - val config = cs.getTenantUserConfigFromCache(userIdentifier.tenantId, userIdentifier.name) - val tagsMeta = config.getWorkerTagsMeta - if (tagsMeta.preferClientTagExpr) { - Some(clientTagsExpr) - } else { - Some(tagsMeta.tagsExpr) - } - }.getOrElse(clientTagsExpr) + val tags = resolveTagsExpr(userIdentifier, clientTagsExpr) + .split(',').map(_.trim).filter(_.nonEmpty) - if (tagsExpr.isEmpty) { - logWarning("No tags provided") + if (tags.isEmpty) { + logDebug("No tags provided, returning all workers") return workers } - val tags = tagsExpr.split(",").map(_.trim) - - var workersForTags: Option[JSet[String]] = None - tags.foreach { tag => - val taggedWorkers = getTagStore.getOrDefault(tag, Collections.emptySet()) - workersForTags match { - case Some(w) => - w.retainAll(taggedWorkers) - case _ => - workersForTags = Some(new util.HashSet[String](taggedWorkers)) - } - } - - if (workersForTags.isEmpty) { - logWarning(s"No workers for tags: $tagsExpr found in cluster") - return Collections.emptyList() - } - + val store = tagStore val workerTagsPredicate = new Predicate[WorkerInfo] { - override def test(w: WorkerInfo): Boolean = workersForTags.get.contains(w.toUniqueId) + override def test(w: WorkerInfo): Boolean = tags.forall { tag => + w.tags.contains(tag) || + store.flatMap(s => Option(s.get(tag))).exists(_.contains(w.toUniqueId)) + } } workers.stream().filter(workerTagsPredicate).collect(Collectors.toList()) } - def addTagToWorker(tag: String, workerId: String): Unit = { - val workers = defaultTagStore.computeIfAbsent(tag, addNewTagFunc) - logInfo(s"Adding Tag $tag to worker $workerId") - workers.add(workerId) - } - - def removeTagFromWorker(tag: String, workerId: String): Unit = { - val workers = defaultTagStore.get(tag) - - if (workers != null && workers.contains(workerId)) { - logInfo(s"Removing Tag $tag from worker $workerId") - workers.remove(workerId) - } else { - logWarning(s"Tag $tag not found for worker $workerId") - } - } - def getTagsForWorker(worker: WorkerInfo): Set[String] = { - defaultTagStore.asScala.filter(_._2.contains(worker.toUniqueId)).keySet.toSet - } - - def removeTagFromCluster(tag: String): Unit = { - val workers = defaultTagStore.remove(tag) - if (workers != null) { - logInfo(s"Removed Tag $tag from cluster with workers ${workers.toArray.mkString(", ")}") - } else { - logWarning(s"Tag $tag not found in cluster and thus can not be removed") - } + val storeTags = tagStore.map(_.asScala.collect { + case (tag, workerIds) if workerIds.contains(worker.toUniqueId) => tag + }.toSet).getOrElse(Set.empty) + storeTags ++ worker.tags.asScala } def getTagsForCluster: Set[String] = { - defaultTagStore.keySet().iterator().asScala.toSet + tagStore.map(_.keySet.asScala.toSet).getOrElse(Set.empty) } } diff --git a/master/src/test/scala/org/apache/celeborn/service/deploy/master/tags/TagsManagerSuite.scala b/master/src/test/scala/org/apache/celeborn/service/deploy/master/tags/TagsManagerSuite.scala index 369cd3efbd3..cae58d1cc39 100644 --- a/master/src/test/scala/org/apache/celeborn/service/deploy/master/tags/TagsManagerSuite.scala +++ b/master/src/test/scala/org/apache/celeborn/service/deploy/master/tags/TagsManagerSuite.scala @@ -23,7 +23,7 @@ import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.meta.WorkerInfo -import org.apache.celeborn.server.common.service.config.DynamicConfigServiceFactory +import org.apache.celeborn.server.common.service.config.{ConfigService, DynamicConfigServiceFactory} class TagsManagerSuite extends CelebornFunSuite { private var tagsManager: TagsManager = _ @@ -38,127 +38,58 @@ class TagsManagerSuite extends CelebornFunSuite { private val workers = List(WORKER1, WORKER2, WORKER3).asJava private val user = UserIdentifier("tenant_01", "Jerry") + private var configService: ConfigService = _ override def beforeEach(): Unit = { super.beforeEach() DynamicConfigServiceFactory.reset() - } - - test("test tags manager") { - tagsManager = new TagsManager(Option(null)) - - tagsManager.addTagToWorker(TAG1, WORKER1.toUniqueId) - tagsManager.addTagToWorker(TAG1, WORKER2.toUniqueId) - - tagsManager.addTagToWorker(TAG2, WORKER2.toUniqueId) - tagsManager.addTagToWorker(TAG2, WORKER3.toUniqueId) - { - val taggedWorkers = tagsManager.getTaggedWorkers(user, TAG1, workers) - assert(taggedWorkers.size == 2) - assert(taggedWorkers.contains(WORKER1)) - assert(taggedWorkers.contains(WORKER2)) - assert(!taggedWorkers.contains(WORKER3)) - } + val conf = new CelebornConf() + conf.set(CelebornConf.DYNAMIC_CONFIG_STORE_BACKEND, "FS") + conf.set( + CelebornConf.DYNAMIC_CONFIG_STORE_FS_PATH.key, + getTestResourceFile("dynamicConfig-tags.yaml").getPath) + configService = DynamicConfigServiceFactory.getConfigService(conf) + } - { - val taggedWorkers = tagsManager.getTaggedWorkers(user, TAG2, workers) - assert(taggedWorkers.size == 2) - assert(!taggedWorkers.contains(WORKER1)) - assert(taggedWorkers.contains(WORKER2)) - assert(taggedWorkers.contains(WORKER3)) - } + private def workerWithTags(host: String, tags: String*): WorkerInfo = { + val w = new WorkerInfo(host, 111, 112, 113, 114, 115) + w.tags = new java.util.HashSet[String](tags.asJava) + w + } - { - // Test get tags for cluster - val tags = tagsManager.getTagsForCluster - assert(tags.size == 2) - assert(tags.contains(TAG1)) - assert(tags.contains(TAG2)) - } + test("getTaggedWorkers filters by worker self-registered tags (no config service)") { + tagsManager = new TagsManager(None) - { - // Test an unknown tag - val taggedWorkers = tagsManager.getTaggedWorkers(user, "unknown-tag", workers) - assert(taggedWorkers.isEmpty) - } + val w1 = workerWithTags("host1", TAG1) + val w2 = workerWithTags("host2", TAG1, TAG2) + val w3 = workerWithTags("host3", TAG2) + val workers = List(w1, w2, w3).asJava { - // Test get tags for worker - val tagsWorker1 = tagsManager.getTagsForWorker(WORKER1) - assert(tagsWorker1.size == 1) - assert(tagsWorker1.contains(TAG1)) - - val tagsWorker2 = tagsManager.getTagsForWorker(WORKER2) - assert(tagsWorker2.size == 2) - assert(tagsWorker2.contains(TAG1)) - assert(tagsWorker2.contains(TAG2)) - - val tagsWorker3 = tagsManager.getTagsForWorker(WORKER3) - assert(tagsWorker3.size == 1) - assert(tagsWorker3.contains(TAG2)) - - // Untagged worker - val untaggedWorker = new WorkerInfo("host4", 999, 999, 999, 999, 999) - val tagsUntaggedWorker = tagsManager.getTagsForWorker(untaggedWorker) - assert(tagsUntaggedWorker.isEmpty) + val tagged = tagsManager.getTaggedWorkers(user, TAG1, workers) + assert(tagged.size == 2) + assert(tagged.contains(w1) && tagged.contains(w2) && !tagged.contains(w3)) } - { - // Remove tag from worker - tagsManager.removeTagFromWorker(TAG1, WORKER2.toUniqueId) - val taggedWorkers = tagsManager.getTaggedWorkers(user, TAG1, workers) - assert(taggedWorkers.size == 1) - assert(taggedWorkers.contains(WORKER1)) - assert(!taggedWorkers.contains(WORKER2)) - assert(!taggedWorkers.contains(WORKER3)) + val tagged = tagsManager.getTaggedWorkers(user, "tag1,tag2", workers) + assert(tagged.size == 1 && tagged.contains(w2)) } - { - // Remove tag from cluster - tagsManager.removeTagFromCluster(TAG1) - val taggedWorkers = tagsManager.getTaggedWorkers(user, TAG1, workers) - assert(taggedWorkers.isEmpty) - - val tags = tagsManager.getTagsForCluster - assert(tags.size == 1) - assert(tags.contains(TAG2)) + val tagged = tagsManager.getTaggedWorkers(user, "tag1,tag3", workers) + assert(tagged.isEmpty) } - } - - test("test tags expression with multiple tags") { - tagsManager = new TagsManager(Option(null)) - - // Tag1 - tagsManager.addTagToWorker(TAG1, WORKER1.toUniqueId) - tagsManager.addTagToWorker(TAG1, WORKER2.toUniqueId) - - // Tag2 - tagsManager.addTagToWorker(TAG2, WORKER2.toUniqueId) - tagsManager.addTagToWorker(TAG2, WORKER3.toUniqueId) - { - val taggedWorkers = tagsManager.getTaggedWorkers(user, "tag1,tag2", workers) - assert(taggedWorkers.size == 1) - assert(!taggedWorkers.contains(WORKER1)) - assert(taggedWorkers.contains(WORKER2)) - assert(!taggedWorkers.contains(WORKER3)) + assert(tagsManager.getTaggedWorkers(user, "unknown-tag", workers).isEmpty) } - { - val taggedWorkers = tagsManager.getTaggedWorkers(user, "tag1,tag3", workers) - assert(taggedWorkers.size == 0) + assert(tagsManager.getTagsForWorker(w2) == Set(TAG1, TAG2)) + assert(tagsManager.getTagsForWorker(workerWithTags("host4")).isEmpty) + assert(tagsManager.getTagsForCluster.isEmpty) } } test("test tags manager with config service") { - val conf = new CelebornConf() - conf.set(CelebornConf.DYNAMIC_CONFIG_STORE_BACKEND, "FS") - conf.set( - CelebornConf.DYNAMIC_CONFIG_STORE_FS_PATH.key, - getTestResourceFile("dynamicConfig-tags.yaml").getPath) - val configService = DynamicConfigServiceFactory.getConfigService(conf) - tagsManager = new TagsManager(Option(configService)) { @@ -199,4 +130,37 @@ class TagsManagerSuite extends CelebornFunSuite { assert(taggedWorkers.contains(WORKER3)) } } + + test("getTaggedWorkers matches workers tagged via either config store or self-registration") { + tagsManager = new TagsManager(Option(configService)) + val selfTaggedWorker = workerWithTags("host4", TAG1) + val all = List(WORKER1, WORKER2, WORKER3, selfTaggedWorker).asJava + + val tagged = tagsManager.getTaggedWorkers(user, TAG1, all) + assert(tagged.size == 3) // WORKER1, WORKER2 (config store) + selfTagged (self) + assert(tagged.contains(WORKER1)) + assert(tagged.contains(WORKER2)) + assert(!tagged.contains(WORKER3)) + assert(tagged.contains(selfTaggedWorker)) + } + + test("getTaggedWorkers matches a worker whose tags span config store and self-registration") { + tagsManager = new TagsManager(Option(configService)) + // host1 already tagged via config service + val selfTaggedWorker = workerWithTags("host1", TAG2) + val all = List(WORKER1, WORKER2, WORKER3, selfTaggedWorker).asJava + + val tagged = tagsManager.getTaggedWorkers(user, "tag1,tag2", all) + assert(tagged.size == 2) + assert(tagged.contains(WORKER1)) + assert(tagged.contains(WORKER2)) + assert(!tagged.contains(WORKER3)) + } + + test("getTaggedWorkers returns empty when no config service and workers have no matching tags") { + tagsManager = new TagsManager(None) + val untagged = List(workerWithTags("host1")).asJava + assert(tagsManager.getTaggedWorkers(user, "tag1", untagged).isEmpty) + assert(tagsManager.getTagsForCluster.isEmpty) + } } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala index 1c5bc201803..1faf2bc5a15 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala @@ -300,6 +300,10 @@ private[celeborn] class Worker( diskInfos, JavaUtils.newConcurrentHashMap[UserIdentifier, ResourceConsumption]) + // Tags this worker advertises to the master at registration. + private val workerTags = conf.workerTags.toSet + workerInfo.tags = workerTags.asJava + // whether this Worker registered to Master successfully val registered = new AtomicBoolean(false) val shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] = @@ -705,6 +709,7 @@ private[celeborn] class Worker( // StorageManager have update the disk info. workerInfo.diskInfos.asScala.toMap, handleResourceConsumption().asScala.toMap, + workerTags, MasterClient.genRequestId()), classOf[PbRegisterWorkerResponse]) } catch {