rasa_core: nlg模塊源碼解讀

最近在學習使用rasa構(gòu)建聊天機器人,為了實現(xiàn)一個比較特別的功能,需要搞懂源碼。rasa 的代碼質(zhì)量相當高,注釋完整,函數(shù)定義包含 type hint 讀起來非常舒服。
rasa_core.nlg模塊包含5個py腳本:

  • __init__.py
  • callback.py
  • generator.py
  • interpolator.py
  • template.py

首先看 __init__.py

from rasa.core.nlg.generator import NaturalLanguageGenerator
from rasa.core.nlg.template import TemplatedNaturalLanguageGenerator
from rasa.core.nlg.callback import CallbackNaturalLanguageGenerator

可以看到,nlg模塊主要有三個類,

  • NaturalLanguageGenerator(NLG)
  • TemplatedNaturalLanguageGenerator(TNLG)
  • CallbackNaturalLanguageGenerator(CNLG)

TNLGCNLG都繼承自NLG,所以從NLG開始。

NaturalLanguageGenerator

NLG類包含兩個成員函數(shù):

  • generate
  • create
    generate是抽象函數(shù),沒有具體實現(xiàn),create是靜態(tài)函數(shù)。
generate:
async def generate(
    self,
    template_name: Text,
    tracker: "DialogueStateTracker",
    output_channel: Text,
    **kwargs: Any,
) -> Optional[Dict[Text, Any]]

異步抽象函數(shù),用于對用戶輸入產(chǎn)生回復(fù)。

create
@staticmethod
def create(
    obj: Union["NaturalLanguageGenerator", EndpointConfig, None],
    domain: Optional[Domain],
) -> "NaturalLanguageGenerator":
    """Factory to create a generator."""

    if isinstance(obj, NaturalLanguageGenerator):
        return obj
    else:
        return _create_from_endpoint_config(obj, domain)

靜態(tài)函數(shù),用于產(chǎn)生一個NLG實例。建議的輸入obj是NLG實例或者EndpointConfig對象,domain是Domain對象,如果obj是NLG實例,直接返回obj,否則根據(jù)EndpointConfig和Domain的配置,借助了_create_from_endpoint_config函數(shù),實例化一個NLG。

_create_from_endpoint_config

接下來,我們來看_create_from_endpoint_config這個函數(shù)。

def _create_from_endpoint_config(
    endpoint_config: Optional[EndpointConfig] = None, domain: Optional[Domain] = None,
) -> "NaturalLanguageGenerator":
    """Given an endpoint configuration, create a proper NLG object."""

    domain = domain or Domain.empty()

    if endpoint_config is None:
        from rasa.core.nlg import (  # pytype: disable=pyi-error
            TemplatedNaturalLanguageGenerator,
        )

        # this is the default type if no endpoint config is set
        nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    elif endpoint_config.type is None or endpoint_config.type.lower() == "callback":
        from rasa.core.nlg import (  # pytype: disable=pyi-error
            CallbackNaturalLanguageGenerator,
        )

        # this is the default type if no nlg type is set
        nlg = CallbackNaturalLanguageGenerator(endpoint_config=endpoint_config)
    elif endpoint_config.type.lower() == "template":
        from rasa.core.nlg import (  # pytype: disable=pyi-error
            TemplatedNaturalLanguageGenerator,
        )

        nlg = TemplatedNaturalLanguageGenerator(domain.templates)
    else:
        nlg = _load_from_module_string(endpoint_config, domain)

    logger.debug(f"Instantiated NLG to '{nlg.__class__.__name__}'.")
    return nlg

_create_from_endpoint_config的輸入同樣是EndpointConfig對象和Domain對象。函數(shù)主體是if-else的結(jié)構(gòu),根據(jù)EndpointConfig的狀況決定構(gòu)建怎樣的NLG實例。

_load_from_module_string
def _load_from_module_string(
    endpoint_config: EndpointConfig, domain: Domain
) -> "NaturalLanguageGenerator":
    """Initializes a custom natural language generator.

    Args:
        domain: defines the universe in which the assistant operates
        endpoint_config: the specific natural language generator
    """

    try:
        nlg_class = common.class_from_module_path(endpoint_config.type)
        return nlg_class(endpoint_config=endpoint_config, domain=domain)
    except (AttributeError, ImportError) as e:
        raise Exception(
            f"Could not find a class based on the module path "
            f"'{endpoint_config.type}'. Failed to create a "
            f"`NaturalLanguageGenerator` instance. Error: {e}"
        )

TemplatedNaturalLanguageGenerator

TNLG繼承自NLG,除了NLG的成員函數(shù)之外,還有以下新成員:

  • _templates_for_utter_action
  • _random_template_for
  • generate
  • generate_from_slots
  • _fill_template
  • _template_variables
    首先來看最重要的generate
generate
async def generate(
    self,
    template_name: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Optional[Dict[Text, Any]]:
    """Generate a response for the requested template."""

    filled_slots = tracker.current_slot_values()
    return self.generate_from_slots(
        template_name, filled_slots, output_channel, **kwargs
    )

輸入是模板名和tracker對象,在模板中填充tracker記錄的槽位生成回復(fù)語句。生成語句這里調(diào)用的是generate_from_slots函數(shù)。

generate_from_slots
def generate_from_slots(
    self,
    template_name: Text,
    filled_slots: Dict[Text, Any],
    output_channel: Text,
    **kwargs: Any,
) -> Optional[Dict[Text, Any]]:
    """Generate a response for the requested template."""

    # Fetching a random template for the passed template name
    r = copy.deepcopy(self._random_template_for(template_name, output_channel))
    # Filling the slots in the template and returning the template
    if r is not None:
        return self._fill_template(r, filled_slots, **kwargs)
    else:
        return None

這里調(diào)用_random_template_for隨機選擇模板(一個action可能對應(yīng)多個回復(fù)模板),然后調(diào)用_fill_template填充模板中的槽位。
先來看_random_template_for。

_random_template_for
def _random_template_for(
    self, utter_action: Text, output_channel: Text
) -> Optional[Dict[Text, Any]]:
    """Select random template for the utter action from available ones.

    If channel-specific templates for the current output channel are given,
    only choose from channel-specific ones.
    """
    import numpy as np

    if utter_action in self.templates:
        suitable_templates = self._templates_for_utter_action(
            utter_action, output_channel
        )

        if suitable_templates:
            return np.random.choice(suitable_templates)
        else:
            return None
    else:
        return None

調(diào)用_templates_for_utter_action函數(shù)拿到當前action的所有模板,使用np.random.choice在模板列表中隨機選擇一個。可以看到,輸入是action名,返回的template其實是一個 dict 對象。

_fill_template

_fill_template將對選擇的模板進行槽位填充的工作。

def _fill_template(
    self,
    template: Dict[Text, Any],
    filled_slots: Optional[Dict[Text, Any]] = None,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """"Combine slot values and key word arguments to fill templates."""

    # Getting the slot values in the template variables
    template_vars = self._template_variables(filled_slots, kwargs)

    keys_to_interpolate = [
        "text",
        "image",
        "custom",
        "button",
        "attachment",
        "quick_replies",
    ]
    if template_vars:
        for key in keys_to_interpolate:
            if key in template:
                template[key] = interpolate(template[key], template_vars)
    return template

可以看到,輸入的模板template和填充槽位filled_slots都是dict對象。暫時沒有看到具體的例子,猜測:
filled_slots中的所有key都是template中的槽位名,value是對槽位的填充值value,通過替換template中的槽位填充值,完成回復(fù)語句的生成。

interpolate.py

在實現(xiàn)TNLG的回復(fù)生成階段,調(diào)用了interpolate.py下的兩個模塊 interpolate和interpolate_text。interpolate_text用于對text格式的template進行槽位填充,使用正則表達式替換和str.format()的形式:

def interpolate_text(template: Text, values: Dict[Text, Text]) -> Text:
    # transforming template tags from
    # "{tag_name}" to "{0[tag_name]}"
    # as described here:
    # https://stackoverflow.com/questions/7934620/python-dots-in-the-name-of-variable-in-a-format-string#comment9695339_7934969
    # black list character and make sure to not to allow
    # (a) newline in slot name
    # (b) { or } in slot name
    try:
        text = re.sub(r"{([^\n{}]+?)}", r"{0[\1]}", template)
        text = text.format(values)
        if "0[" in text:
            # regex replaced tag but format did not replace
            # likely cause would be that tag name was enclosed
            # in double curly and format func simply escaped it.
            # we don't want to return {0[SLOTNAME]} thus
            # restoring original value with { being escaped.
            return template.format({})

        return text
    except KeyError as e:
        logger.exception(
            "Failed to fill utterance template '{}'. "
            "Tried to replace '{}' but could not find "
            "a value for it. There is no slot with this "
            "name nor did you pass the value explicitly "
            "when calling the template. Return template "
            "without filling the template. "
            "".format(template, e.args[0])
        )
        return template

CallbackNaturalLanguageGenerator

最后,來看CNLG。CNLG的結(jié)構(gòu)要簡單很多,僅包含兩個成員函數(shù),一個產(chǎn)生回復(fù)的generate,另一個用于檢驗回復(fù)格式是否合法的validate_response。

generate
async def generate(
    self,
    template_name: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """Retrieve a named template from the domain using an endpoint."""

    body = nlg_request_format(template_name, tracker, output_channel, **kwargs)

    logger.debug(
        "Requesting NLG for {} from {}."
        "".format(template_name, self.nlg_endpoint.url)
    )

    response = await self.nlg_endpoint.request(
        method="post", json=body, timeout=DEFAULT_REQUEST_TIMEOUT
    )

    if self.validate_response(response):
        return response
    else:
        raise Exception("NLG web endpoint returned an invalid response.")

輸入是action的名稱,用于記錄的tracker,以及output_channel。首先從nlg_request_format函數(shù)中得到request的body,之后向endpoint上的服務(wù)發(fā)出請求,調(diào)用定義在對應(yīng)Action類中的run函數(shù),得到response,驗證response的合法性,并且返回。

nlg_request_format
def nlg_request_format(
    template_name: Text,
    tracker: DialogueStateTracker,
    output_channel: Text,
    **kwargs: Any,
) -> Dict[Text, Any]:
    """Create the json body for the NLG json body for the request."""

    tracker_state = tracker.current_state(EventVerbosity.ALL)

    return {
        "template": template_name,
        "arguments": kwargs,
        "tracker": tracker_state,
        "channel": {"name": output_channel},
    }

這個函數(shù)處理產(chǎn)生request的主體,用于指定Action的調(diào)用。在寫Action的時候就很好奇,Action類的run函數(shù)一般定義成這樣:def run(self, dispatcher, tracker, domain),后來就很神奇的發(fā)現(xiàn)這里邊的tracker并不是一個rasa_core.trackers,包含的信息比較少。果然,這里產(chǎn)生的tracker,僅僅保留了當前狀態(tài)。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時請結(jié)合常識與多方信息審慎甄別。
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

  • 一、Rasa Rasa是一個開源機器學習框架,用于構(gòu)建上下文AI助手和聊天機器人。Rasa有兩個主要模塊: Ras...
    風玲兒閱讀 52,913評論 1 30
  • [TOC] Rasa學習筆記2--Rasa Core 1. 概念介紹 首先引出Rasa的設(shè)計理念:Learning...
    ColdCoder閱讀 3,312評論 1 4
  • 模板標簽除了幾個常用的,還真心沒有仔細了解一下,看到2.0發(fā)布后,翻譯學習一下。 本文盡量忠實原著,畢竟大神的東西...
    海明_fd17閱讀 2,139評論 0 5
  • Linux基礎(chǔ)入門第二節(jié)實驗報告 1、重要快捷鍵: 【Tab】 好處就是當你忘記某個命令的全稱時可以只輸入它的開頭...
    小公舉凡閱讀 281評論 0 1
  • 從我的世界里逃出來 多想 進入另一個人的世界 期待著 與他 不期而遇 沿著他的路徑 再走一遍 當年的情景 期待著 ...
    青果未熟閱讀 690評論 0 3

友情鏈接更多精彩內(nèi)容