Spark 2.0 RPC Communication Layer Design Analysis

Overview of the Spark RPC Layer Design

The RPC framework in Spark 2.0 is built on Netty, an excellent network communication framework. To present the RPC design more intuitively, we first sort out the relationships among Spark's RPC-related classes, starting from the class design, as shown in the figure below:



As the left half of the figure above shows, RPC communication centers on three core classes: RpcEnv, RpcEndpoint, and RpcEndpointRef.
An RpcEndpoint is a communication endpoint; for example, the Master and the Workers in a Spark cluster are all RpcEndpoints. To communicate with an RpcEndpoint, however, you must first obtain an RpcEndpointRef for it and talk to the endpoint through that reference, and an RpcEndpointRef can only be obtained from an RpcEnv environment object.

When a client sends a message through an RpcEndpointRef, the RpcEnv first processes the message, works out which endpoint it is addressed to, and then routes it to that RpcEndpoint. Spark uses the more efficient NettyRpcEnv by default. The three classes are described in detail below.

RpcEnv

RpcEnv is the RPC environment object. It manages the whole lifecycle of every RpcEndpoint: registering endpoints by name or URI, dispatching the various kinds of messages, and stopping endpoints. An RpcEnv can only be created through an RpcEnvFactory.
RpcEnv has one core method:
def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef
This method registers an RpcEndpoint with the RpcEnv environment object, which from then on manages the binding between the RpcEndpoint and its RpcEndpointRef. Each RpcEndpoint must be registered under a unique name.
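As a rough illustration of the binding that setupEndpoint maintains, the name-to-endpoint registration can be sketched in plain Scala. MiniRpcEnv, Endpoint, and EndpointRef below are simplified stand-ins, not Spark's actual types:

```scala
import java.util.concurrent.ConcurrentHashMap

// Simplified stand-in types; the real Spark traits carry much more state.
trait Endpoint
case class EndpointRef(name: String)

// A minimal sketch of the name -> endpoint binding behind setupEndpoint:
// each name may be registered exactly once, and registration hands back a ref.
class MiniRpcEnv {
  private val endpoints = new ConcurrentHashMap[String, Endpoint]()

  def setupEndpoint(name: String, endpoint: Endpoint): EndpointRef = {
    // putIfAbsent returns null only if the name was previously unused
    if (endpoints.putIfAbsent(name, endpoint) != null)
      throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
    EndpointRef(name)
  }
}
```

Registering the same name twice fails, mirroring the uniqueness requirement described above.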

RpcEndpoint

RpcEndpoint defines the endpoint object in RPC communication. Besides the operations that manage an endpoint's lifecycle (constructor -> onStart -> receive* -> onStop), it specifies the event-driven behavior an endpoint exhibits during communication (connect, disconnect, network error). From the Spark framework's point of view, an RpcEndpoint's job is essentially to receive and process messages.
RpcEndpoint has two core methods:

def receive: PartialFunction[Any, Unit] = {
  case _ => throw new SparkException(self + " does not implement 'receive'")
}
def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
  case _ => context.sendFailure(new SparkException(self + " won't reply anything"))
}

The receive method handles messages sent with RpcEndpointRef.send; such messages need no reply and are simply processed on the RpcEndpoint side. The receiveAndReply method handles messages sent with RpcEndpointRef.ask; after processing such a message, the RpcEndpoint must send a response back to the endpoint that called RpcEndpointRef.ask.
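The receive contract can be illustrated with plain Scala, without any Spark dependency. Heartbeat and MiniEndpoint are made-up names: a concrete endpoint supplies a PartialFunction for the messages it understands, and applyOrElse routes anything unmatched to a failure handler, just as the default implementations above do.

```scala
// An illustrative message type; not a Spark class.
case class Heartbeat(workerId: String)

class MiniEndpoint {
  var lastWorker: String = ""

  // Only Heartbeat messages are understood
  def receive: PartialFunction[Any, Unit] = {
    case Heartbeat(id) => lastWorker = id
  }

  // Mimics how a delivered message is applied to receive, with a fallback
  // for unsupported messages
  def deliver(msg: Any): Unit =
    receive.applyOrElse[Any, Unit](msg,
      m => throw new IllegalArgumentException(s"Unsupported message $m"))
}
```

Delivering a Heartbeat updates the endpoint's state; delivering any other message hits the fallback and fails.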

RpcEndpointRef

RpcEndpointRef is a remote reference to an RpcEndpoint, through which messages can be sent to the remote endpoint. The RpcEndpointRef abstract class is defined as follows:

private[spark] abstract class RpcEndpointRef(conf: SparkConf)  extends Serializable with Logging { 
  private[this] val maxRetries = RpcUtils.numRetries(conf)
  private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf)
  private[this] val defaultAskTimeout = RpcUtils.askRpcTimeout(conf)
  def address: RpcAddress
  def name: String
  def send(message: Any): Unit
  def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
  def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout)
  def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout)
  def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = {
    ... ...
  }
}

In the code above, send fires off a message without waiting for a response, i.e., send-and-forget. The ask method instead sends a message and waits for the peer's response, which is obtained asynchronously through a Future.
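The ask pattern — return a Future immediately, complete it when the reply arrives — can be sketched in a few lines of plain Scala. The names below are illustrative; in NettyRpcEnv the Promise is completed on a Netty callback thread, while here the "reply" is produced inline for brevity:

```scala
import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

// Sketch of the ask pattern: the caller immediately gets a Future backed by a
// Promise, and the reply path completes the Promise later.
def ask(message: Any): Future[Any] = {
  val promise = Promise[Any]()
  // Stand-in for the remote endpoint's reply arriving
  promise.trySuccess(s"echo: $message")
  promise.future
}
```

A real caller would attach callbacks to the Future rather than block; Await is only appropriate in tests.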

Creating the NettyRpcEnv in the Driver SparkEnv

The Driver SparkEnv is the runtime environment of the Driver in a Spark application. It creates many components, such as SecurityManager, rpcEnv, broadcastManager, mapOutputTracker, memoryManager, blockTransferService, blockManagerMaster, blockManager, and metricsSystem. Since this article is about Spark's RPC mechanism, only the creation of rpcEnv and the startup of its service are covered, starting from the create method of NettyRpcEnvFactory in NettyRpcEnv.scala.

private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Create the serializer instance
    val javaSerializerInstance = new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    // Instantiate a NettyRpcEnv
    val nettyEnv =  new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = {
        actualPort => nettyEnv.startServer(actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        // Start the Driver RPC service on the given host and port
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      }
      catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
  nettyEnv
  }
}

NettyRpcEnvFactory extends RpcEnvFactory and implements its create method, whose main work is to instantiate a NettyRpcEnv and start the service.
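The Utils.startServiceOnPort call above retries the start function on successive ports when binding fails. A simplified sketch of that retry loop, assuming a fixed maxRetries and a simple port increment (the real implementation also treats port 0 specially and reads the retry count from configuration):

```scala
// Simplified retry loop in the spirit of Utils.startServiceOnPort: call
// startService on candidate ports until one bind succeeds or retries run out.
def startServiceOnPort[T](startPort: Int,
                          startService: Int => (T, Int),
                          maxRetries: Int = 3): (T, Int) = {
  for (offset <- 0 to maxRetries) {
    val tryPort = startPort + offset
    try {
      return startService(tryPort)
    } catch {
      // Bind failed: fall through and try the next port
      case e: java.net.BindException if offset < maxRetries =>
    }
  }
  throw new java.net.BindException(s"Failed to start service on port $startPort")
}
```

With this shape, startNettyRpcEnv plays the role of startService: it binds the TransportServer on the attempted port and returns the (NettyRpcEnv, boundPort) pair.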

1. Creating the NettyRpcEnv
private[netty] class NettyRpcEnv(val conf: SparkConf, javaSerializerInstance: JavaSerializerInstance, host: String, securityManager: SecurityManager)  extends RpcEnv(conf) with Logging {  
  // Create the transportConf
  private[netty] val transportConf = SparkTransportConf.fromSparkConf(conf.clone.set("spark.rpc.io.numConnectionsPerPeer", "1"), "rpc", conf.getInt("spark.rpc.io.threads", 0))
  // Create the Dispatcher, mainly used for dispatching and processing messages
  private val dispatcher: Dispatcher = new Dispatcher(this)
  // Create the streamManager
  private val streamManager = new NettyStreamManager(this)  
  // Create a TransportContext, mainly used to create Netty servers and clients. Spark wraps the Netty framework, with TransportContext as the external entry point connecting Spark-side code such as NettyRpcHandler to the underlying communication server and client. Spark's wrapping of Netty is covered in detail later.
  private val transportContext = new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this, streamManager))

  private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = {
    if (securityManager.isAuthenticationEnabled()) {
      java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager,        securityManager.isSaslEncryptionEnabled()))
    } else {
      java.util.Collections.emptyList[TransportClientBootstrap]
    }  
  }
  // Declare a clientFactory, used to create communication clients
  private val clientFactory = transportContext.createClientFactory(createClientBootstraps())  
  /**   
  * A separate client factory for file downloads. This avoids using the same RPC handler as   
  * the main RPC context, so that events caused by these clients are kept isolated from the   
  * main RPC traffic.   
  *   
  * It also allows for different configuration of certain properties, such as the number of   
  * connections per peer.   
  */  
  @volatile private var fileDownloadFactory: TransportClientFactory = _  

  // Create a "netty-rpc-env-timeout" daemon thread
  val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")

  // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool  
  // to implement non-blocking send/ask.  
  // TODO: a non-blocking TransportClientFactory.createClient in future
  private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( "netty-rpc-connection",    conf.getInt("spark.rpc.connect.threads", 64))

  @volatile private var server: TransportServer = _  

  private val stopped = new AtomicBoolean(false)  
  /**   
  * A map for [[RpcAddress]] and [[Outbox]]. When we are connecting to a remote [[RpcAddress]],   
  * we just put messages to its [[Outbox]] to implement a non-blocking `send` method.  
  */  
  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()  
  /**   
  * Remove the address's Outbox and stop it.   
  */  
  private[netty] def removeOutbox(address: RpcAddress): Unit = {    
    val outbox = outboxes.remove(address)    
    if (outbox != null) {      
      outbox.stop()    
    }  
  }  
  
  // Start the TransportServer on the given port
  def startServer(port: Int): Unit = {    
    val bootstraps: java.util.List[TransportServerBootstrap] =
    if(securityManager.isAuthenticationEnabled()) {  
      java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) 
    } else {   
      java.util.Collections.emptyList() 
    }
    // Start the underlying communication server through the transportContext
    server = transportContext.createServer(host, port, bootstraps)
    // Register an RpcEndpointVerifier so clients can check whether an endpoint exists on this server
    dispatcher.registerRpcEndpoint(RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))  
  }

  @Nullable  override lazy val address: RpcAddress = {    
    if (server != null) 
      RpcAddress(host, server.getPort()) 
    else 
      null
  }  
  // Override RpcEnv's setupEndpoint method, used to register an RpcEndpoint with the RpcEnv
  override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
    dispatcher.registerRpcEndpoint(name, endpoint)  
  }  

  def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {    
    val addr = RpcEndpointAddress(uri)    
    val endpointRef = new NettyRpcEndpointRef(conf, addr, this)    
    val verifier = new NettyRpcEndpointRef(conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this)    
    verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find =>  if (find) {  Future.successful(endpointRef) } else { Future.failed(new RpcEndpointNotFoundException(uri)) } }(ThreadUtils.sameThread)  
  }  
  
  override def stop(endpointRef: RpcEndpointRef): Unit = {
    require(endpointRef.isInstanceOf[NettyRpcEndpointRef])    
    dispatcher.stop(endpointRef)  
  }  
  
  private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {    
    if (receiver.client != null) {  
      message.sendWith(receiver.client)    
    } else {      
      require(receiver.address != null, "Cannot send message to client endpoint with no listen address.")      
      val targetOutbox = { 
        val outbox = outboxes.get(receiver.address)        
        if (outbox == null) {          
          val newOutbox = new Outbox(this, receiver.address)          
          val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox)          
          if (oldOutbox == null) { 
            newOutbox
          } else {            
            oldOutbox
          }        
        } else {          
          outbox        
        }      
      }      
      if (stopped.get) {       
       // It's possible that we put `targetOutbox` after stopping. So we need to clean it.
       outboxes.remove(receiver.address)        
       targetOutbox.stop()      
      } else {        
        targetOutbox.send(message)      
      }    
    }  
  }  

  private[netty] def send(message: RequestMessage): Unit = {    
    val remoteAddr = message.receiver.address    
    if (remoteAddr == address) {      
      // Message to a local RPC endpoint.      
      try {        
        dispatcher.postOneWayMessage(message)      
      } 
      catch {        
        case e: RpcEnvStoppedException => logWarning(e.getMessage)      
      }    
    } else {      
      // Message to a remote RPC endpoint.      
      postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))    
    }  
  }  

  private[netty] def createClient(address: RpcAddress): TransportClient = { clientFactory.createClient(address.host, address.port)  }  

  private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
    val promise = Promise[Any]()    
    val remoteAddr = message.receiver.address    
    def onFailure(e: Throwable): Unit = {      
      if (!promise.tryFailure(e)) {        
        logWarning(s"Ignored failure: $e")      
      }    
    }    
    def onSuccess(reply: Any): Unit = reply match {      
      case RpcFailure(e) => onFailure(e)      
      case rpcReply => if (!promise.trySuccess(rpcReply)) { logWarning(s"Ignored message: $reply") }
      }    
      try {      
        if (remoteAddr == address) {        
          val p = Promise[Any]()        
          p.future.onComplete {          
            case Success(response) => onSuccess(response)          
            case Failure(e) => onFailure(e)        
        }(ThreadUtils.sameThread)        
          dispatcher.postLocalMessage(message, p)      
        } else {        
         val rpcMessage = RpcOutboxMessage(serialize(message), onFailure, (client, response) => onSuccess(deserialize[Any](client, response)))
         postToOutbox(message.receiver, rpcMessage)        
         promise.future.onFailure {
           case _: TimeoutException => rpcMessage.onTimeout()          
           case _ =>        
         }(ThreadUtils.sameThread)
       }      
       val timeoutCancelable = timeoutScheduler.schedule(new Runnable { 
         override def run(): Unit = {          
           onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) 
         } 
        }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
        promise.future.onComplete { v =>  
            timeoutCancelable.cancel(true)
          }(ThreadUtils.sameThread)    
        } catch {      
          case NonFatal(e) => onFailure(e)    
        }    
        promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
  }  
  private[netty] def serialize(content: Any): ByteBuffer = {
    javaSerializerInstance.serialize(content)
  }  
  
  private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {    
    NettyRpcEnv.currentClient.withValue(client) {      
      deserialize { 
        () =>  javaSerializerInstance.deserialize[T](bytes)      
      }    
    }  
  }  

  override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
    dispatcher.getRpcEndpointRef(endpoint)  
  }  

  override def shutdown(): Unit = {    
    cleanup()  
  }  
  
  override def awaitTermination(): Unit = {    
    dispatcher.awaitTermination()  
  }  

  private def cleanup(): Unit = {    
    if (!stopped.compareAndSet(false, true)) {      
      return    
    }    
    val iter = outboxes.values().iterator()    
    while (iter.hasNext()) {      
      val outbox = iter.next()      
      outboxes.remove(outbox.address)      
      outbox.stop()    
    }    
    if (timeoutScheduler != null) {      
      timeoutScheduler.shutdownNow()    
    }    
    if (dispatcher != null) {      
      dispatcher.stop()    
    }    
    if (server != null) {     
     server.close()    
    }    
    if (clientFactory != null) {      
      clientFactory.close()    
    }    
    if (clientConnectionExecutor != null) {      
      clientConnectionExecutor.shutdownNow()    
    }    
    if (fileDownloadFactory != null) {      
      fileDownloadFactory.close()    
    }    
  }  

  override def deserialize[T](deserializationAction: () => T): T = {
    NettyRpcEnv.currentEnv.withValue(this) {      
      deserializationAction()    
    }  
  }  

  override def fileServer: RpcEnvFileServer = streamManager  

  override def openChannel(uri: String): ReadableByteChannel = {    
    val parsedUri = new URI(uri)    
    require(parsedUri.getHost() != null, "Host name must be defined.")
    require(parsedUri.getPort() > 0, "Port must be defined.")    
    require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.")    
    val pipe = Pipe.open()    
    val source = new FileDownloadChannel(pipe.source())    
    try {      
      val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())      
      val callback = new FileDownloadCallback(pipe.sink(), source, client)
      client.stream(parsedUri.getPath(), callback)    
    } catch {      
      case e: Exception =>  
        pipe.sink().close()        
        source.close()        
        throw e    
    }    
    source  
  }  

  private def downloadClient(host: String, port: Int): TransportClient = {    
    if (fileDownloadFactory == null) 
      synchronized {      
        if (fileDownloadFactory == null) {        
          val module = "files"        
          val prefix = "spark.rpc.io."        
          val clone = conf.clone()        
          // Copy any RPC configuration that is not overridden in the spark.files namespace.        
          conf.getAll.foreach { 
            case (key, value) => 
              if (key.startsWith(prefix)) {            
                val opt = key.substring(prefix.length())
                clone.setIfMissing(s"spark.$module.io.$opt", value)          
              }        
            }        
          val ioThreads = clone.getInt("spark.files.io.threads", 1)        
          val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads)        
          val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true)        
          fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps())      
      }    
    }    
    fileDownloadFactory.createClient(host, port)  
  }  

  private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {    
    @volatile private var error: Throwable = _    
    def setError(e: Throwable): Unit = {      error = e      source.close()    }    
    override def read(dst: ByteBuffer): Int = {      
      Try(source.read(dst)) match {        
        case Success(bytesRead) => bytesRead        
        case Failure(readErr) =>          
          if (error != null) {            
           throw error          
          } else {            
            throw readErr          
          }      
        }    
      }    
      override def close(): Unit = source.close()    
      override def isOpen(): Boolean = source.isOpen()  
  }  

  private class FileDownloadCallback(sink: WritableByteChannel, source: FileDownloadChannel, client: TransportClient) extends StreamCallback {    
    override def onData(streamId: String, buf: ByteBuffer): Unit = {      
      while (buf.remaining() > 0) {        
        sink.write(buf)      
      }    
    }    
    override def onComplete(streamId: String): Unit = {      
      sink.close()    
    }    
    override def onFailure(streamId: String, cause: Throwable): Unit = {      
      logDebug(s"Error downloading stream $streamId.", cause)
      source.setError(cause)      
      sink.close()
    }  
  }
}

The newly created NettyRpcEnv is mainly used for registering endpoints, starting the TransportServer, obtaining RpcEndpointRefs, creating clients, and so on. Its main members are the dispatcher and the transportContext.

1.1 The Dispatcher

The Dispatcher's main job is to keep track of the registered RpcEndpoints and to dispatch each incoming message to the right RpcEndpoint for processing.

private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
  // Inner class of Dispatcher tying together an endpoint's name, the endpoint itself, its ref, and an Inbox
  private class EndpointData(val name: String,  val endpoint: RpcEndpoint,  val ref:   NettyRpcEndpointRef) {
    val inbox = new Inbox(ref, endpoint)  
  }  

  // A ConcurrentHashMap keeping the mapping from name to EndpointData
  private val endpoints = new ConcurrentHashMap[String, EndpointData]  
  // A ConcurrentHashMap keeping the mapping from RpcEndpoint to RpcEndpointRef
  private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]
 
  // Track the receivers whose inboxes may contain messages.  
  // A BlockingQueue holding the EndpointData instances that may have pending messages.
  // An EndpointData is offered to receivers when an endpoint is registered or unregistered, when a message is posted, and when the RpcEnv is stopped.
  private val receivers = new LinkedBlockingQueue[EndpointData]  

  /**   
  * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced immediately.   
  */  
  @GuardedBy("this")  private var stopped = false  

  // Register the given RpcEndpoint with the RpcEnv under the given name
  def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = {
    // Build the RpcEndpointAddress from the NettyRpcEnv's address and the name
    val addr = RpcEndpointAddress(nettyEnv.address, name)
    // Create the corresponding NettyRpcEndpointRef
    val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv)
    synchronized {      
      if (stopped) {        
        throw new IllegalStateException("RpcEnv has been stopped")      
      }
      // Create a new EndpointData, whose key member is its inbox (covered below),
      // and add it to endpoints under the given name
      if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) {
        throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name")
      }
      val data = endpoints.get(name)
      // Record the endpoint -> endpointRef mapping in endpointRefs
      endpointRefs.put(data.endpoint, data.ref)
      // Offer the newly created EndpointData to receivers
      receivers.offer(data)
      // for the OnStart message
    }
    // Return the corresponding endpointRef
    endpointRef
  }

  // Look up the endpointRef for the given endpoint
  def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint)

  // Remove the given endpoint's entry from endpointRefs
  def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint)  
  
  // Should be idempotent
  // Unregister the endpoint registered under this name from the NettyRpcEnv
  private def unregisterRpcEndpoint(name: String): Unit = {
    // Remove the corresponding EndpointData from endpoints
    val data = endpoints.remove(name)
    if (data != null) {
      // Stop the EndpointData by calling its inbox's stop method
      data.inbox.stop()
      // Offer the EndpointData to receivers so the message-loop threads can process the remaining messages in its inbox
      receivers.offer(data)  
      // for the OnStop message    
    }
    // Don't clean `endpointRefs` here because it's possible that some messages are being processed    
    // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via    
    // `removeRpcEndpointRef`.  
  }  

  def stop(rpcEndpointRef: RpcEndpointRef): Unit = {    
    synchronized {      
      if (stopped) {
        // This endpoint will be stopped by Dispatcher.stop() method.        
        return
      }      
      unregisterRpcEndpoint(rpcEndpointRef.name)    
    }  
  }  

  /**   
  * Send a message to all registered [[RpcEndpoint]]s in this process.   
  *   
  * This can be used to make network events known to all end points (e.g. "a new node connected").   
  */ 
  def postToAll(message: InboxMessage): Unit = {    
    val iter = endpoints.keySet().iterator()    
    while (iter.hasNext) {      
      val name = iter.next      
      postMessage(name, message, (e) => logWarning(s"Message $message dropped. ${e.getMessage}"))    
    }  
  }  

  /** Posts a message sent by a remote endpoint. */
  def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
    val rpcCallContext =  new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)    
    postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))  
  }  

  /** Posts a message sent by a local endpoint. */
  def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {    
    val rpcCallContext = new LocalNettyRpcCallContext(message.senderAddress, p)    
    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)    
    postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))  
  }  

  /** Posts a one-way message. */  
  def postOneWayMessage(message: RequestMessage): Unit = {
    postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),      (e) => throw e)  
  }  

  /**   
  * Posts a message to a specific endpoint.   
  *   
  * @param endpointName name of the endpoint.   
  * @param message the message to post   
  * @param callbackIfStopped callback function if the endpoint is stopped.   
  */ 
  private def postMessage(endpointName: String,  message: InboxMessage,  callbackIfStopped: (Exception) => Unit): Unit = {
    val error = synchronized { 
      // Look up the corresponding EndpointData by endpointName
      val data = endpoints.get(endpointName)
      if (stopped) {
        Some(new RpcEnvStoppedException())
      } else if (data == null) {
        Some(new SparkException(s"Could not find $endpointName."))
      } else {
        // Add the message to that EndpointData's inbox
        data.inbox.post(message)
        // Offer the EndpointData to receivers
        receivers.offer(data)
        None
      }
    }
    // We don't need to call `onStop` in the `synchronized` block
    error.foreach(callbackIfStopped)
  }  

  def stop(): Unit = {    
    synchronized {      
      if (stopped) {        
        return      
      }      
      stopped = true    
    }    
    // Stop all endpoints. This will queue all endpoints for processing by the message loops.    
    endpoints.keySet().asScala.foreach(unregisterRpcEndpoint)    
    // Enqueue a message that tells the message loops to stop.
    receivers.offer(PoisonPill)
    threadpool.shutdown()  
  }  

  def awaitTermination(): Unit = {    
    threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS)  
  }  

  /**   
  * Return if the endpoint exists   
  */
  def verify(name: String): Boolean = {    endpoints.containsKey(name)  }  

  /** Thread pool used for dispatching messages. */
  private val threadpool: ThreadPoolExecutor = {
    // Read the pool size from configuration
    val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",      math.max(2, Runtime.getRuntime.availableProcessors()))
    // Create the thread pool
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
    // Start numThreads threads, each running a MessageLoop
    for (i <- 0 until numThreads) {      
      pool.execute(new MessageLoop)    
    }    
    pool  
  }  

  /** Message loop used for dispatching messages. */  
  private class MessageLoop extends Runnable {    
    override def run(): Unit = {      
      try {        
        while (true) {
          try {
            // Take an EndpointData from receivers; since receivers is a LinkedBlockingQueue, the thread blocks when the queue is empty
            val data = receivers.take()
            // If the element is PoisonPill, stop this thread and put PoisonPill back into receivers so all the other threads stop too
            if (data == PoisonPill) {
            // Put PoisonPill back so that other MessageLoops can see it.
              receivers.offer(PoisonPill)
              return
            }
            // Call the EndpointData's inbox.process method to handle its pending messages
              data.inbox.process(Dispatcher.this)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

  /** A poison endpoint that indicates MessageLoop should exit its message loop. */  
  private val PoisonPill = new EndpointData(null, null, null)
}

As the code above shows, when the Dispatcher dispatches a message to an endpoint, it actually hands the message to the corresponding EndpointData, whose most important member is its inbox. The Inbox is introduced next.
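The MessageLoop / PoisonPill mechanics described above can be reproduced in a small stand-alone sketch. Names mirror the Dispatcher's; everything else is simplified:

```scala
import java.util.concurrent.{ConcurrentLinkedQueue, LinkedBlockingQueue}

// Worker threads block on receivers.take(); the shared PoisonPill sentinel is
// re-offered on receipt so every loop eventually sees it and exits.
val receivers = new LinkedBlockingQueue[String]()
val PoisonPill = "POISON_PILL"
val processed = new ConcurrentLinkedQueue[String]()

def messageLoop(): Unit = {
  while (true) {
    val data = receivers.take()
    if (data == PoisonPill) {
      // Put PoisonPill back so that other MessageLoops can see it
      receivers.offer(PoisonPill)
      return
    }
    processed.add(data)
  }
}

val workers = (1 to 2).map(_ => new Thread(() => messageLoop()))
workers.foreach(_.start())
Seq("msg-a", "msg-b", "msg-c").foreach(m => receivers.offer(m))
receivers.offer(PoisonPill)
workers.foreach(_.join())
```

Because the queue is FIFO, all three messages are drained before the sentinel; each worker that takes the pill re-offers it and exits, so every loop terminates.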

1.2 Inbox
private[netty] class Inbox(val endpointRef: NettyRpcEndpointRef,  val endpoint: RpcEndpoint)  extends Logging {  
inbox =>  
  // Give this an alias so we can use it more clearly in closures.
  // Declare a LinkedList of InboxMessage, named messages
  @GuardedBy("this")  protected val messages = new java.util.LinkedList[InboxMessage]()

  /** True if the inbox (and its associated endpoint) is stopped. */  
  @GuardedBy("this")  private var stopped = false  
  
  /** Allow multiple threads to process messages at the same time. */
  @GuardedBy("this")  private var enableConcurrent = false  

  /** The number of threads processing messages for this inbox. */
  @GuardedBy("this")  private var numActiveThreads = 0  

  // OnStart should be the first message to process
  // OnStart is added to messages as soon as the Inbox is constructed
  inbox.synchronized {
    messages.add(OnStart)  
  }  

  /**   
  * Process stored messages.   
  */
  def process(dispatcher: Dispatcher): Unit = {    
    var message: InboxMessage = null    
    inbox.synchronized {      
      if (!enableConcurrent && numActiveThreads != 0) {        
        return
      }
      // Take the first message from the head of the list
      message = messages.poll()
      // If the message is not null, increment numActiveThreads
      if (message != null) {
        numActiveThreads += 1 
      } else {
        return      
      }
    }
    // Match the message and handle it accordingly
    while (true) {      
      safelyCall(endpoint) {        
        message match {          
          case RpcMessage(_sender, content, context) =>  
            try {              
              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>                
                throw new SparkException(s"Unsupported message $message from ${_sender}")              
              })            
            } catch {              
              case NonFatal(e) =>                
                context.sendFailure(e)                
                // Throw the exception -- this exception will be caught by the safelyCall function.  
                // The endpoint's onError function will be called.                
                  throw e            
            }          
          case OneWayMessage(_sender, content) =>
            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>              
             throw new SparkException(s"Unsupported message $message from ${_sender}")            
            })          
          case OnStart =>            
            endpoint.onStart()            
            if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {              
              inbox.synchronized {                
                if (!stopped) {                  
                  enableConcurrent = true               
                }              
              }            
            }          
          case OnStop =>            
            val activeThreads = inbox.synchronized { inbox.numActiveThreads }
            assert(activeThreads == 1,              
              s"There should be only a single active thread but found $activeThreads threads.")            
            dispatcher.removeRpcEndpointRef(endpoint)            
            endpoint.onStop()            
            assert(isEmpty, "OnStop should be the last message")          
          case RemoteProcessConnected(remoteAddress) =>
            endpoint.onConnected(remoteAddress)          
          case RemoteProcessDisconnected(remoteAddress) =>
            endpoint.onDisconnected(remoteAddress)          
          case RemoteProcessConnectionError(cause, remoteAddress) =>
            endpoint.onNetworkError(cause, remoteAddress)        
        }      
      }      
      inbox.synchronized {        
        // "enableConcurrent" will be set to false after `onStop` is called, so we should check it  every time.        
        if (!enableConcurrent && numActiveThreads != 1) {          
          // If we are not the only one worker, exit          
          numActiveThreads -= 1          
          return        
        }
        // Fetch the next message and continue the matching loop
        message = messages.poll()        
        if (message == null) {          
          numActiveThreads -= 1          
          return        
        }      
      }    
    }  
  }  

  // Add a message to the messages list
  def post(message: InboxMessage): Unit = inbox.synchronized {
    // If the inbox has already been stopped, drop the message (OnStop is already in messages)
    if (stopped) {      
      // We already put "OnStop" into "messages", so we should drop further messages
      onDrop(message)    
    } else {      
      messages.add(message)      
      false    
    }  
  }  

  def stop(): Unit = inbox.synchronized   {    
    // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last    
    // message    
    if (!stopped) {      
      // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only      
      // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources      
      // safely.      
      enableConcurrent = false      
      stopped = true      
      messages.add(OnStop)     
      // Note: The concurrent events in messages will be processed one by one.    
    }  
  }  

  // Check whether the messages queue is empty
  def isEmpty: Boolean = inbox.synchronized { messages.isEmpty }  

  /**
  * Called when we are dropping a message. Test cases override this to test message dropping.   
  * Exposed for testing.   
  */
  protected def onDrop(message: InboxMessage): Unit = {    
    logWarning(s"Drop $message because $endpointRef is stopped")  
  }  
  
  /**   
  * Calls action closure, and calls the endpoint's onError function in the case of exceptions.   
  */  
  private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = {    
    try action catch {      
      case NonFatal(e) =>        
        try endpoint.onError(e) catch {          
          case NonFatal(ee) => logError(s"Ignoring error", ee)        
        }    
      }  
    }
  }

This completes the walkthrough of the Dispatcher inside NettyRpcEnv. Its main flow is:

  1. Create the Dispatcher
  • Start a pool of worker threads that monitors receivers for newly active EndpointData
    • If there is a message and it is not PoisonPill, call the corresponding EndpointData's Inbox.process method to handle it:
      1). Take the first element from that EndpointData's inbox messages queue
      2). Pattern-match the message and invoke the matching handler method on the endpoint
    • If there is no message, block and wait
    • If the message is PoisonPill, re-enqueue PoisonPill into receivers, then stop the current thread
  2. Register an endpoint with the NettyRpcEnv by name
  • Create the corresponding NettyRpcEndpointRef from nettyEnv.conf, the RpcEndpointAddress, and nettyEnv
  • Create a new EndpointData from the name, endpoint, and endpointRef
  • Add name -> EndpointData to endpoints
  • Add endpoint -> endpointRef to endpointRefs
  • Add the newly created EndpointData to receivers
  3. Dispatch an InboxMessage to the corresponding EndpointData for processing
  • Look up the EndpointData by name
  • Add the message to that EndpointData's Inbox messages queue
  • Add the EndpointData to receivers
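The receivers-queue-plus-poison-pill pattern above can be condensed into a small, self-contained sketch. All names here (SketchDispatcher, SimpleInbox, the sentinel PoisonPill inbox) are illustrative stand-ins, not Spark's actual classes; only the control flow mirrors the description:

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.ArrayBuffer

class SketchDispatcher {
  sealed trait InboxMessage
  final case class OneWay(content: Any) extends InboxMessage

  // Each endpoint owns an inbox: a queue of pending messages plus a log of
  // handled payloads (standing in for the endpoint's receive method).
  final class SimpleInbox(val name: String) {
    val messages = new LinkedBlockingQueue[InboxMessage]()
    val handled = ArrayBuffer[Any]()
    // Drain and handle all currently queued messages.
    def process(): Unit = {
      var m = messages.poll()
      while (m != null) {
        m match { case OneWay(c) => handled += c }
        m = messages.poll()
      }
    }
  }

  // The sentinel tells a worker thread to exit; it is re-enqueued so that
  // every worker eventually sees it, as in the flow described above.
  private val PoisonPill = new SimpleInbox("PoisonPill")
  private val receivers = new LinkedBlockingQueue[SimpleInbox]()

  // Posting enqueues the message on the inbox and marks the inbox active.
  def post(inbox: SimpleInbox, msg: InboxMessage): Unit = {
    inbox.messages.put(msg)
    receivers.put(inbox)
  }

  def shutdown(): Unit = receivers.put(PoisonPill)

  // One worker's loop: block until some inbox is active, process it, repeat.
  def workerLoop(): Unit = {
    while (true) {
      val data = receivers.take()
      if (data eq PoisonPill) {
        receivers.put(PoisonPill) // pass the pill on for other workers
        return
      }
      data.process()
    }
  }
}
```

The real Dispatcher additionally deduplicates active inboxes and handles stop/OnStop ordering, which this sketch deliberately omits.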

Next, let's look at how an RpcEndpointRef is produced. When an endpoint is registered with NettyRpcEnv under a name, an RpcEndpointAddress is first created from the name and the NettyRpcEnv's address, and then a NettyRpcEndpointRef is created from that RpcEndpointAddress, NettyRpcEnv.conf, and the NettyRpcEnv itself. In other words, constructing a NettyRpcEndpointRef has no direct link to the actual RpcEndpoint object; the NettyRpcEnv simply builds a ref for a given name. When a client later sends a message through a NettyRpcEndpointRef, the NettyRpcEnv routes the message, based on the name it carries, to the corresponding NettyRpcEndpoint for processing.
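This name-based decoupling can be illustrated with a toy model (ToyEnv, ToyRef, and EndpointAddress are made-up names, not Spark's API): the ref holds only an address and a name, and the env resolves the target endpoint by name at delivery time.

```scala
import scala.collection.mutable

object NameRoutingSketch {
  final case class EndpointAddress(host: String, port: Int, name: String)

  // Stand-in for an RpcEndpoint: just a message handler.
  type Endpoint = Any => Unit

  final class ToyEnv(host: String, port: Int) {
    private val endpoints = mutable.Map[String, Endpoint]()

    // Registration binds a name to an endpoint and hands back a ref
    // that knows nothing about the endpoint object itself.
    def setupEndpoint(name: String, ep: Endpoint): ToyRef = {
      endpoints(name) = ep
      new ToyRef(EndpointAddress(host, port, name), this)
    }

    // Delivery: look the endpoint up by the name carried in the ref.
    def send(addr: EndpointAddress, msg: Any): Unit =
      endpoints.get(addr.name).foreach(ep => ep(msg))
  }

  final class ToyRef(val address: EndpointAddress, env: ToyEnv) {
    def send(msg: Any): Unit = env.send(address, msg)
  }
}
```

Because the ref carries only (address, name), it can be serialized and shipped to a remote client, and the env on the receiving side still knows how to route by name.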

1.3 NettyRpcEndpointRef

private[netty] class NettyRpcEndpointRef(
    @transient private val conf: SparkConf,
    endpointAddress: RpcEndpointAddress,
    @transient @volatile private var nettyEnv: NettyRpcEnv)
  extends RpcEndpointRef(conf) with Serializable with Logging {
  // A TransportClient for the underlying network communication
  @transient @volatile var client: TransportClient = _
  // Keep the endpointAddress only if it actually carries an RpcAddress
  private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null
  // The endpoint's name, taken from the endpointAddress
  private val _name = endpointAddress.name
  
  override def address: RpcAddress = if (_address != null) _address.rpcAddress else null
  // On deserialization, restore the transient fields from the current env and client
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    nettyEnv = NettyRpcEnv.currentEnv.value
    client = NettyRpcEnv.currentClient.value
  }
  // Serialization uses the default mechanism; transient fields are skipped
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
  }
  
  override def name: String = _name  
  // Override RpcEndpointRef.ask: route the request through nettyEnv, expecting a reply
  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
    nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)  
  }  
  // Override RpcEndpointRef.send: fire-and-forget through nettyEnv
  override def send(message: Any): Unit = {    
    require(message != null, "Message is null")
    nettyEnv.send(RequestMessage(nettyEnv.address, this, message))  
  }  

  override def toString: String = s"NettyRpcEndpointRef(${_address})"  

  def toURI: URI = new URI(_address.toString)  

  final override def equals(that: Any): Boolean = that match {    
    case other: NettyRpcEndpointRef => _address == other._address    
    case _ => false  
  }  
  final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode()
}
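To make the ask/send distinction in the listing above concrete, here is a Spark-free sketch using plain Scala Futures; EchoEndpoint and ToyRef are hypothetical names. send is fire-and-forget via receive, while ask returns a Future completed with the endpoint's reply, mirroring how NettyRpcEndpointRef.ask returns a Future[T]:

```scala
import scala.concurrent.Future

object AskSendSketch {
  // A toy endpoint that can both fire-and-forget and reply.
  final class EchoEndpoint {
    def receive(msg: Any): Unit = ()                   // send path: no reply expected
    def receiveAndReply(msg: Any): Any = s"echo: $msg" // ask path: produce a reply
  }

  final class ToyRef(endpoint: EchoEndpoint) {
    def send(msg: Any): Unit = endpoint.receive(msg)
    // ask hands back a Future holding the endpoint's reply. In the real
    // NettyRpcEnv the reply arrives asynchronously over the network; here
    // it is completed immediately for illustration.
    def ask(msg: Any): Future[Any] =
      Future.successful(endpoint.receiveAndReply(msg))
  }
}
```

A caller would typically await or register a callback on the returned Future, just as Spark callers do with RpcEndpointRef.ask.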

With that, the walkthrough of NettyRpcEnv, NettyRpcEndpoint, and NettyRpcEndpointRef in Spark's RPC communication module is complete.
