mindspeed-llm Source Code Walkthrough (1): preprocess_data

mindspeed-llm is the model suite repository for Ascend, formerly known as "ModelLink". This article walks through the data preprocessing script preprocess_data.py (based on the 1.0.0 branch). Data preprocessing is the first step of model training, so you will use it often.

The source code quoted in this article includes added comments; read the code and the comments together.

Let's start with the main function:

def main():
    # Parse the command-line arguments; the code below shows which ones matter
    args = get_args()
    # Validate the arguments
    validate_args(args)
    # Merge datasets that have already been processed
    if args.merge_group_keys is not None:
        merge_datasets(args)
        return

    # Build the splitter, which splits paragraphs into sentences
    splitter = build_splitter(args)
    # Build the tokenizer, which splits sentences into individual tokens
    tokenizer = build_tokenizer(args)

    logger.info("building dataset: %s", args.input)
    # Load the data: read CSV, JSON, TXT, etc. into memory
    raw_data = build_dataset(args)

    # Write everything to a single file
    if args.n_subs == 1:
        # Get a handler for the processed data
        handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
        # Serialize the data to disk
        handler.serialize_to_disk()
    # Write to multiple files using multiprocessing; each subset goes through the same path as the single-file branch above
    else:
        target_prefix = args.output_prefix
        target_prefixname = os.path.basename(target_prefix)
        
        num_samples = len(raw_data)
        start_ends = cut_range_to_subs(num_samples, num_samples // args.n_subs)
        subsets = [raw_data.select(range(x[0], x[1])) for x in start_ends]
        
        # multiprocessing
        params_list = []
        for k, subset in enumerate(subsets):
            args_ = copy.deepcopy(args)
            args_.output_prefix = target_prefix.replace(target_prefixname, f'{str(k).zfill(3)}_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}')
            params = [args_, subset, tokenizer, splitter]
            params_list.append(params)
        pool = multiprocessing.Pool()
        sub_idx_files = pool.map(handle_subset, params_list)
        pool.close()
        pool.join()
        
        for key in sub_idx_files[0].keys():
            idx_files = [x[key] for x in sub_idx_files]
            idx_files.sort()
            target_idx = idx_files[0].replace(f'000_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}', target_prefixname)
            target_bin = target_idx.replace('.idx', '.bin')
            idx = IndexedDatasetBuilder(target_bin)
            for idx_file in idx_files:
                idx.add_index(idx_file.replace('.idx', ''))
            idx.finalize(target_idx)
            
            for idx_file in idx_files:
                os.remove(idx_file)
                os.remove(idx_file.replace('.idx', '.bin'))

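The multi-file branch relies on cut_range_to_subs to carve the sample range into contiguous chunks. To make the slicing concrete, here is an illustrative reimplementation (not the repository's code) that matches how the result is consumed above:

def cut_range_to_subs(n, gap):
    # split [0, n) into contiguous (start, end) pairs of width gap; the last one may be shorter
    starts = list(range(0, n, gap))
    return [(s, min(s + gap, n)) for s in starts]

print(cut_range_to_subs(10, 3))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
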
As you can see, main's processing logic consists mainly of these functions: build_splitter, build_tokenizer, build_dataset, get_dataset_handler, and serialize_to_disk. Let's look at each in turn.

build_splitter

This function splits paragraphs of text into individual sentences. Looking at the source, it mainly relies on the third-party library nltk:

def build_splitter(args):
    if nltk and args.split_sentences:
        nltk.download("punkt", quiet=True)
    if args.split_sentences:
        if not nltk:
            logger.error("NLTK is not available to split sentences.")
            raise Exception("nltk is not available")
        splitter = nltk.load("tokenizers/punkt/english.pickle")
        if args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            final_splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=splitter._params,
                lang_vars=CustomLanguageVars())
        else:
            final_splitter = splitter

    else:
        # Fall back to a no-op splitter that returns the text unchanged
        final_splitter = IdentitySplitter()
    return final_splitter
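
To see what the punkt splitter does in practice, here is a minimal standalone sketch (the sample sentence is made up):

import nltk

nltk.download("punkt", quiet=True)
splitter = nltk.load("tokenizers/punkt/english.pickle")
# punkt knows that the period in "Dr." does not end a sentence
print(splitter.tokenize("Dr. Smith arrived. He sat down."))
# ['Dr. Smith arrived.', 'He sat down.']

Without --split-sentences, IdentitySplitter is used instead and each document passes through unchanged.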

build_tokenizer

This function splits sentences into individual tokens, for example splitting "今天是星期幾" ("what day is it today") into "今天", "是", "星期幾", and then maps each token to an integer ID.

def build_tokenizer(args):
    """Initialize tokenizer."""
    # Build a HuggingFace tokenizer
    if args.tokenizer_type == "PretrainedFromHF":
        if args.rank == 0:
            print(' > building PretrainFromHF tokenizer. Vocab file is un-used, '
                  'loading tokenizer from pre-trained model', flush=True)

        if args.tokenizer_name_or_path is None:
            raise ValueError("Missing tokenizer_name_or_path while building PretrainFromHF tokenizer.")

        hf_tokenizer_kwargs = dict()
        if hasattr(args, "tokenizer_kwargs") and args.tokenizer_kwargs:
            if len(args.tokenizer_kwargs) % 2 != 0:
                raise ValueError("The token name and token value must be entered in pairs.")

            for i in range(0, len(args.tokenizer_kwargs), 2):
                hf_tokenizer_kwargs[args.tokenizer_kwargs[i]] = \
                    args.tokenizer_kwargs[i + 1]
        # _AutoTokenizer is a wrapper class built on MegatronTokenizer
        tokenizer = _AutoTokenizer(
            args.tokenizer_name_or_path,
            vocab_extra_ids=args.vocab_extra_ids,
            model_max_length=args.seq_length,
            use_fast=args.tokenizer_not_use_fast,
            **hf_tokenizer_kwargs
        )

        # Add vocab size (if not already set from a checkpoint).
        if getattr(args, "padded_vocab_size", None) is None:
            args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
                                                              args)
    else:
        # Otherwise wrap Megatron's own tokenizer in an adaptor
        tokenizer = TokenizerAdaptor(megatron_build_tokenizer(args))
    # Patch the tokenizer according to prompt_type
    if hasattr(args, "prompt_type") and args.prompt_type is not None:
        if ("PreTrainedTokenizerBase" not in str(tokenizer.tokenizer._pad.__func__)):
            tokenizer.tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer.tokenizer)
            tokenizer.tokenizer.padding_side = "right"
        fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip())

    return tokenizer

So there are two categories of tokenizer. One is PretrainedFromHF, which uses a pretrained HuggingFace tokenizer; if you are not using a HuggingFace tokenizer, the tokenizer instance is created by wrapping the result of megatron_build_tokenizer in the TokenizerAdaptor class.
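
For intuition about what the HuggingFace path produces, here is a minimal sketch using transformers directly (the model directory "./llama-7b-hf" is a hypothetical local path):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("./llama-7b-hf")
ids = tok("今天是星期幾")["input_ids"]  # integer token IDs
print(tok.convert_ids_to_tokens(ids))    # the corresponding subword strings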

build_dataset

This function loads the data files into memory and returns a DatasetDict or a Dataset, i.e. a Python container. The load_dataset it calls is from HuggingFace's datasets library.

def build_dataset(args):
    """loading dataset by huggingface"""
    raw_datasets = None
    if args.handler_name == "LlamaFactoryInstructionHandler":
        all_datasets = []
        for dataset_attr in get_dataset_list(args):
            # Load a single dataset
            all_datasets.append(load_single_dataset(dataset_attr, args))
        # Merge the datasets into one
        raw_datasets = merge_dataset(all_datasets, args)
    else:
        if args.handler_name == "MOSSInstructionHandler" or args.handler_name == "MOSSMultiTurnHandler":
            # for MOSS, streaming is needed (load the data in streaming mode)
            args.streaming = True
        if args.hf_datasets_params:
            with open(args.hf_datasets_params, 'r') as fin:
                param_dict = json.load(fin)
            return load_dataset(**param_dict)
        cache_dir = args.cache_dir
        split_flag = "train"
        load_from_local = os.path.exists(args.input)
        # Load from local files
        if load_from_local:
            # args.input is the path of a valid Python loading script
            if _has_py_script(args.input):
                logger.info("loading data from a local python script")
                raw_datasets = load_dataset(
                    args.input,
                    data_dir='./' if not args.script_data_dir else args.script_data_dir,
                    split=split_flag,
                    num_proc=None if args.streaming else args.workers,
                    cache_dir=cache_dir,
                    streaming=args.streaming,
                    trust_remote_code=False
                )
            else:
                # args.input is a file or directory path
                data_files = [args.input] if os.path.isfile(args.input) else \
                    glob.glob(os.path.join(args.input, '*'))
                # Detect the data file format
                ext, data_format = _get_data_format(data_files)
                # Keep only files whose extension matches the detected format
                filtered_data_files = list(filter(lambda x: x.split('.')[-1] == ext, data_files))
                if filtered_data_files:
                    logger.info("loading data from local file, format: %s,"
                                " file num: %s", data_format, len(data_files))
                    raw_datasets = load_dataset(
                        data_format,
                        split=split_flag,
                        data_files=filtered_data_files,
                        num_proc=None if args.streaming else args.workers,
                        cache_dir=cache_dir,
                        streaming=args.streaming,
                        trust_remote_code=False
                    )
                else:
                    raise Exception("unknown local data!")
        else:
            logger.info("loading data from remote huggingface")  # 從遠(yuǎn)程 Hugging Face 數(shù)據(jù)集加載
            raw_datasets = load_dataset(
                args.input,
                split=split_flag,
                num_proc=None if args.streaming else args.workers,
                cache_dir=cache_dir,
                streaming=args.streaming,
                trust_remote_code=False
            )
        if raw_datasets is None:
            raise Exception("unknown data!")

        if args.handler_name in [
            "AlpacaStyleInstructionHandler",
            "SharegptStyleInstructionHandler",
            "AlpacaStylePairwiseHandler",
            "SharegptStylePairwiseHandler"
        ]:
            handler_dataset_attr = get_handler_dataset_attr(args, raw_datasets)

            return align_dataset(raw_datasets, handler_dataset_attr, args)

    return raw_datasets
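
For the common local-file case, all of this boils down to a plain datasets call. A minimal sketch (the path "data/train.json" is hypothetical):

from datasets import load_dataset

raw = load_dataset("json", data_files=["data/train.json"], split="train")
print(raw)     # Dataset({features: [...], num_rows: ...})
print(raw[0])  # the first sample as a plain dict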

get_dataset_handler

This function creates the dataset-handler instance; _get_handler_cls picks the handler class according to args.handler_name. The handler base class and its subclasses are all defined in mindspeed_llm/tasks/preprocess/data_handler.py. Looking at BaseDatasetHandler, the class exposes two public methods: get_tokenized_data, which tokenizes the data, and serialize_to_disk, which serializes it to disk.
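
Putting the pieces together, a typical call sequence looks like this sketch (it assumes the script's imports are available; "GeneralPretrainHandler" is one of the handler names defined in data_handler.py):

# args, raw_data, tokenizer and splitter come from the steps shown earlier
args.handler_name = "GeneralPretrainHandler"
handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
handler.serialize_to_disk()  # writes the .bin/.idx pairs under args.output_prefix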

serialize_to_disk

Continuing from the above: serialize_to_disk is a method on the handler that saves the tokenized dataset to disk. Concretely, it serializes each sample (or each sentence) into a binary .bin file and writes a matching .idx index file. The public serialize_to_disk is a thin wrapper around the private _serialize_to_disk shown below:

    def _serialize_to_disk(self, iteration_batch_size=50):
        startup_start = time.time()
        if not self.tokenized_dataset:
            self.tokenized_dataset = self.get_tokenized_data()
        output_bin_files = {}  # paths of the output .bin data files
        output_idx_files = {}  # paths of the output .idx index files
        builders = {}  # one IndexedDatasetBuilder per json key
        level = "document"
        if self.args.split_sentences:
            level = "sentence"

        logger.info("Vocab size: %s", self.tokenizer.vocab_size)
        logger.info("Output prefix: %s", self.args.output_prefix)
        # The dict keys become part of the file names; json_keys are strings such as "input_ids", "attention_mask", "labels"
        for key in self.args.json_keys:
            output_bin_files[key] = f"{self.args.output_prefix}_{key}_{level}.bin"
            output_idx_files[key] = f"{self.args.output_prefix}_{key}_{level}.idx"
            # vocab_size=None: use int32 dtype, since -100 will be used in labels
            # Create an IndexedDatasetBuilder for each output file
            builders[key] = indexed_dataset.IndexedDatasetBuilder(output_bin_files[key])
        self.output_idx_files = output_idx_files
        startup_end = time.time()
        proc_start = time.time()
        total_bytes_processed = 0
        logger.info("Time to startup:%s", startup_end - startup_start)

        skip_num = 0
        # Iterate over the tokenized dataset in batches
        for i, doc in enumerate(self.tokenized_dataset.iter(batch_size=iteration_batch_size), start=1):
            # In post-training stage, we need to drop the data exceeded set sequence-length
            skip_indices = set()
            # First pass: mark the samples that exceed seq_length
            for key in self.args.json_keys:
                batch = [sentences for sentences in doc[key] if len(sentences) > 0]

                if len(batch) == 0:
                    continue

                for j, sentences in enumerate(batch):
                    for k, sentence in enumerate(sentences):
                        if self.args.seq_length is not None and len(sentence) >= self.args.seq_length:
                            skip_indices.add((j, k))
            # Second pass: write out every sentence that survived the filter
            for key in self.args.json_keys:
                batch = [sentences for sentences in doc[key] if len(sentences) > 0]

                if len(batch) == 0:
                    continue

                for j, sentences in enumerate(batch):
                    for k, sentence in enumerate(sentences):
                        if (j, k) in skip_indices:
                            skip_num = skip_num + 1
                            continue
                        # Track the number of bytes processed
                        total_bytes_processed += len(sentence) * np.int32().itemsize
                        # Append the valid sentence to the corresponding builder
                        builders[key].add_item(sentence)
                    builders[key].end_document()

            batch_id = i * iteration_batch_size
            if batch_id % self.args.log_interval == 0:
                current = time.time()
                elapsed = current - proc_start
                mbs = total_bytes_processed / elapsed / 1024 / 1024
                logger.info("Processed %s documents (%s docs/s, %s MB/s).", batch_id, batch_id / elapsed, mbs)

        logger.info("Skip %s sample exceeded seq-length(%s)", skip_num / len(self.args.json_keys), self.args.seq_length)
        for key in self.args.json_keys:
            builders[key].finalize(output_idx_files[key])
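
The result is a Megatron-style indexed dataset: for output_prefix out and json_keys ["input_ids"], you get out_input_ids_document.bin (the concatenated int32 token sequences) and out_input_ids_document.idx (the offsets that let sample i be sliced back out). A minimal read-back sketch, assuming Megatron's indexed_dataset module is importable (the exact import path varies across Megatron versions):

from megatron.core.datasets.indexed_dataset import IndexedDataset

ds = IndexedDataset("out_input_ids_document")  # pass the prefix, without .bin/.idx
print(len(ds))  # number of documents (or sentences, at sentence level)
print(ds[0])    # the first sample as a numpy integer array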

Those are the main functions involved in mindspeed-llm's data preprocessing. Is there anything else you would like to see covered? Feel free to ask in the comments!
