mindspeed-llm是昇騰模型套件代碼倉(cāng),原來(lái)叫"ModelLink"。這篇文章帶大家閱讀一下數(shù)據(jù)處理腳本preprocess_data.py(基于1.0.0分支),數(shù)據(jù)處理是模型訓(xùn)練的第一步,經(jīng)常會(huì)用到。
文章中貼的源碼加了相關(guān)注釋,同學(xué)們可以把源碼和注釋結(jié)合起來(lái)看。
首先來(lái)看一下main函數(shù)
def main():
    """Entry point of preprocess_data.py: load raw data, tokenize it and
    serialize it to .bin/.idx files, optionally sharded across processes."""
    # Parse command-line arguments; the attribute accesses below show which
    # keys are actually used.
    args = get_args()
    # Validate the argument combination before any work starts.
    validate_args(args)
    # Merge-only mode: combine datasets that were already processed, then exit.
    if args.merge_group_keys is not None:
        merge_datasets(args)
        return
    # Build the splitter that cuts paragraphs into sentences.
    splitter = build_splitter(args)
    # Build the tokenizer that turns sentences into token ids.
    tokenizer = build_tokenizer(args)
    logger.info("building dataset: %s", args.input)
    # Load the raw data (CSV/JSON/TXT/...) into memory as a datasets object.
    raw_data = build_dataset(args)
    if args.n_subs == 1:
        # Single-output mode: one handler serializes the whole dataset.
        handler = get_dataset_handler(args, raw_data, tokenizer, splitter)
        # Write the tokenized data to disk.
        handler.serialize_to_disk()
    else:
        # Multi-output mode: split the dataset into n_subs shards processed by
        # a process pool; each shard goes through the same path as the
        # single-file branch above.
        target_prefix = args.output_prefix
        target_prefixname = os.path.basename(target_prefix)
        num_samples = len(raw_data)
        start_ends = cut_range_to_subs(num_samples, num_samples // args.n_subs)
        subsets = [raw_data.select(range(x[0], x[1])) for x in start_ends]
        # multiprocessing: build one parameter list per shard; each shard gets
        # a distinct output prefix of the form kkk_of_NNN_<original prefix>.
        params_list = []
        for k, subset in enumerate(subsets):
            args_ = copy.deepcopy(args)
            args_.output_prefix = target_prefix.replace(target_prefixname, f'{str(k).zfill(3)}_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}')
            params = [args_, subset, tokenizer, splitter]
            params_list.append(params)
        pool = multiprocessing.Pool()
        sub_idx_files = pool.map(handle_subset, params_list)
        pool.close()
        pool.join()
        # Merge the per-shard .idx/.bin files back into a single pair per json
        # key, then delete the shard files.
        for key in sub_idx_files[0].keys():
            idx_files = [x[key] for x in sub_idx_files]
            idx_files.sort()
            # Derive the final output name from the first shard's file name.
            target_idx = idx_files[0].replace(f'000_of_{str(len(subsets)-1).zfill(3)}_{target_prefixname}', target_prefixname)
            target_bin = target_idx.replace('.idx', '.bin')
            idx = IndexedDatasetBuilder(target_bin)
            for idx_file in idx_files:
                idx.add_index(idx_file.replace('.idx', ''))
            idx.finalize(target_idx)
            # Clean up the now-merged shard files.
            for idx_file in idx_files:
                os.remove(idx_file)
                os.remove(idx_file.replace('.idx', '.bin'))
可以看到,main函數(shù)處理邏輯主要由這幾個(gè)函數(shù)組成:build_splitter、build_tokenizer、build_dataset、get_dataset_handler、serialize_to_disk。
build_splitter
這個(gè)函數(shù)的功能是把文字段落分割成單個(gè)句子,查看源碼,主要使用的是三方庫(kù)nltk的函數(shù):
def build_splitter(args):
    """Return the sentence splitter selected by the command-line arguments.

    With --split-sentences, an NLTK punkt tokenizer is returned (optionally
    wrapped so that newlines after sentences are kept); otherwise an
    IdentitySplitter that passes text through unchanged.
    """
    if nltk and args.split_sentences:
        # Make sure the punkt model is present before loading it.
        nltk.download("punkt", quiet=True)

    if not args.split_sentences:
        # No sentence splitting requested: pass paragraphs through as-is.
        return IdentitySplitter()

    if not nltk:
        logger.error("NLTK is not available to split sentences.")
        raise Exception("nltk is not available")

    base_splitter = nltk.load("tokenizers/punkt/english.pickle")
    if not args.keep_newlines:
        return base_splitter

    # this prevents punkt from eating newlines after sentences
    return nltk.tokenize.punkt.PunktSentenceTokenizer(
        train_text=base_splitter._params,
        lang_vars=CustomLanguageVars())
build_tokenizer
這個(gè)函數(shù)的主要功能是把句子切分成單個(gè)的詞,比如說(shuō)把 "今天是星期幾" 切分成 "今天"、"是"、"星期幾",然后轉(zhuǎn)成對(duì)應(yīng)的整數(shù)。
def build_tokenizer(args):
    """Initialize tokenizer."""
    # Two families of tokenizers: a pre-trained HuggingFace tokenizer, or the
    # Megatron tokenizer wrapped in an adaptor.
    if args.tokenizer_type == "PretrainedFromHF":
        if args.rank == 0:
            print(' > building PretrainFromHF tokenizer. Vocab file is un-used, '
                  'loading tokenizer from pre-trained model', flush=True)
        if args.tokenizer_name_or_path is None:
            raise ValueError("Missing tokenizer_name_or_path while building PretrainFromHF tokenizer.")
        # Extra tokenizer kwargs arrive as a flat list of alternating
        # name/value pairs on the command line.
        hf_tokenizer_kwargs = dict()
        if hasattr(args, "tokenizer_kwargs") and args.tokenizer_kwargs:
            if len(args.tokenizer_kwargs) % 2 != 0:
                raise ValueError("The token name and token value must be entered in pairs.")
            for i in range(0, len(args.tokenizer_kwargs), 2):
                hf_tokenizer_kwargs[args.tokenizer_kwargs[i]] = \
                    args.tokenizer_kwargs[i + 1]
        # Wrapper class built on MegatronTokenizer.
        # NOTE(review): `args.tokenizer_not_use_fast` is passed straight into
        # `use_fast` — the flag name suggests an inversion; confirm the
        # intended semantics against the argument parser.
        tokenizer = _AutoTokenizer(
            args.tokenizer_name_or_path,
            vocab_extra_ids=args.vocab_extra_ids,
            model_max_length=args.seq_length,
            use_fast=args.tokenizer_not_use_fast,
            **hf_tokenizer_kwargs
        )
        # Add vocab size (if not already set from a checkpoint).
        if getattr(args, "padded_vocab_size", None) is None:
            args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
                                                              args)
    else:
        # Non-HF path: build the Megatron tokenizer and adapt its interface.
        tokenizer = TokenizerAdaptor(megatron_build_tokenizer(args))
    # If a prompt template is configured, adjust the underlying tokenizer's
    # padding behavior to match it.
    if hasattr(args, "prompt_type") and args.prompt_type is not None:
        if ("PreTrainedTokenizerBase" not in str(tokenizer.tokenizer._pad.__func__)):
            # The tokenizer overrides _pad with a custom implementation;
            # swap in the HuggingFace base-class _pad instead.
            tokenizer.tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer.tokenizer)
        tokenizer.tokenizer.padding_side = "right"
        fix_model_tokenizer(tokenizer.tokenizer, args.prompt_type.strip())
    return tokenizer
分成了2類tokenizer,一類是PretrainedFromHF,也就是使用預(yù)訓(xùn)練的 HuggingFace 分詞器;如果不使用hf的,則使用 TokenizerAdaptor 類和 megatron_build_tokenizer 函數(shù)創(chuàng)建分詞器實(shí)例 tokenizer。
build_dataset
這個(gè)函數(shù)的功能是把數(shù)據(jù)文件加載到內(nèi)存,返回DatasetDict 或Dataset,也就是一個(gè)Python容器。這個(gè)函數(shù)中調(diào)用的load_dataset是huggingface的datasets庫(kù)的函數(shù)。
def build_dataset(args):
    """loading dataset by huggingface"""
    raw_datasets = None
    if args.handler_name == "LlamaFactoryInstructionHandler":
        all_datasets = []
        for dataset_attr in get_dataset_list(args):
            # Load each configured dataset individually.
            all_datasets.append(load_single_dataset(dataset_attr, args))
        # Merge the loaded datasets into one.
        raw_datasets = merge_dataset(all_datasets, args)
    else:
        if args.handler_name == "MOSSInstructionHandler" or args.handler_name == "MOSSMultiTurnHandler":
            # for MOSS, streaming is needed.
            args.streaming = True
        if args.hf_datasets_params:
            # Explicit load_dataset kwargs supplied as a JSON file take
            # precedence over everything below.
            with open(args.hf_datasets_params, 'r') as fin:
                param_dict = json.load(fin)
            return load_dataset(**param_dict)
        cache_dir = args.cache_dir
        split_flag = "train"
        load_from_local = os.path.exists(args.input)
        # Load from the local filesystem when the path exists.
        if load_from_local:
            # args.input is a path to a dataset-loading python script.
            if _has_py_script(args.input):
                logger.info("loading data from a local python script")
                raw_datasets = load_dataset(
                    args.input,
                    data_dir='./' if not args.script_data_dir else args.script_data_dir,
                    split=split_flag,
                    num_proc=None if args.streaming else args.workers,
                    cache_dir=cache_dir,
                    streaming=args.streaming,
                    trust_remote_code=False
                )
            else:
                # args.input is a data file or a directory of data files.
                data_files = [args.input] if os.path.isfile(args.input) else \
                    glob.glob(os.path.join(args.input, '*'))
                # Infer the data format from the file extensions.
                ext, data_format = _get_data_format(data_files)
                # Keep only files whose extension matches the inferred one.
                filtered_data_files = list(filter(lambda x: x.split('.')[-1] == ext, data_files))
                if filtered_data_files:
                    logger.info("loading data from local file, format: %s,"
                                " file num: %s", data_format, len(data_files))
                    raw_datasets = load_dataset(
                        data_format,
                        split=split_flag,
                        data_files=filtered_data_files,
                        num_proc=None if args.streaming else args.workers,
                        cache_dir=cache_dir,
                        streaming=args.streaming,
                        trust_remote_code=False
                    )
                else:
                    raise Exception("unknown local data!")
        else:
            # Otherwise treat args.input as a remote Hugging Face dataset name.
            logger.info("loading data from remote huggingface")
            raw_datasets = load_dataset(
                args.input,
                split=split_flag,
                num_proc=None if args.streaming else args.workers,
                cache_dir=cache_dir,
                streaming=args.streaming,
                trust_remote_code=False
            )
        if raw_datasets is None:
            raise Exception("unknown data!")
        if args.handler_name in [
            "AlpacaStyleInstructionHandler",
            "SharegptStyleInstructionHandler",
            "AlpacaStylePairwiseHandler",
            "SharegptStylePairwiseHandler"
        ]:
            # These handlers expect a normalized schema; align the raw
            # dataset to it before returning.
            handler_dataset_attr = get_handler_dataset_attr(args, raw_datasets)
            return align_dataset(raw_datasets, handler_dataset_attr, args)
    return raw_datasets
get_dataset_handler
這個(gè)函數(shù)的功能是創(chuàng)建數(shù)據(jù)集處理實(shí)例,_get_handler_cls會(huì)根據(jù)args.handler_name選擇對(duì)應(yīng)的handler。handler的基類和子類都在mindspeed_llm/tasks/preprocess/data_handler.py里面定義了,查看BaseDatasetHandler可以知道,這個(gè)類的對(duì)外函數(shù)有這幾個(gè):get_tokenized_data、serialize_to_disk,功能分別是對(duì)數(shù)據(jù)進(jìn)行令牌化、數(shù)據(jù)序列化。
serialize_to_disk
接著上面講,這個(gè)函數(shù)是handler的類函數(shù),用于將分詞后的數(shù)據(jù)集保存到磁盤。具體來(lái)說(shuō),它將數(shù)據(jù)集的每個(gè)樣本(或句子)序列化為二進(jìn)制文件,并生成相應(yīng)的索引文件。代碼如下:
def _serialize_to_disk(self, iteration_batch_size=50):
    """Serialize the tokenized dataset to disk: one .bin data file plus one
    .idx index file per json key, dropping samples that exceed seq_length."""
    startup_start = time.time()
    # Tokenize lazily if it has not been done yet.
    if not self.tokenized_dataset:
        self.tokenized_dataset = self.get_tokenized_data()
    output_bin_files = {}  # per-key path of the data (.bin) file
    output_idx_files = {}  # per-key path of the index (.idx) file
    builders = {}  # per-key IndexedDatasetBuilder
    level = "document"
    if self.args.split_sentences:
        level = "sentence"
    logger.info("Vocab size: %s", self.tokenizer.vocab_size)
    logger.info("Output prefix: %s", self.args.output_prefix)
    # json_keys are field names such as "input_ids"/"attention_mask"/"labels";
    # each key gets its own output file pair and builder.
    for key in self.args.json_keys:
        output_bin_files[key] = f"{self.args.output_prefix}_{key}_{level}.bin"
        output_idx_files[key] = f"{self.args.output_prefix}_{key}_{level}.idx"
        # vocab_size=None : use int32 dtype for -100 will be used in labels
        builders[key] = indexed_dataset.IndexedDatasetBuilder(output_bin_files[key])
    self.output_idx_files = output_idx_files
    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    logger.info("Time to startup:%s", startup_end - startup_start)
    skip_num = 0
    # Iterate the tokenized dataset in batches; i counts batches from 1.
    for i, doc in enumerate(self.tokenized_dataset.iter(batch_size=iteration_batch_size), start=1):
        # In post-training stage, we need to drop the data exceeded set sequence-length
        skip_indices = set()
        # First pass: collect (sample, sentence) positions whose length
        # reaches seq_length for ANY key, so all keys of the same position
        # are skipped together in the write pass.
        for key in self.args.json_keys:
            batch = [sentences for sentences in doc[key] if len(sentences) > 0]
            if len(batch) == 0:
                continue
            for j, sentences in enumerate(batch):
                for k, sentence in enumerate(sentences):
                    if self.args.seq_length is not None and len(sentence) >= self.args.seq_length:
                        skip_indices.add((j, k))
        # Second pass: write every sentence that was not marked for skipping.
        for key in self.args.json_keys:
            batch = [sentences for sentences in doc[key] if len(sentences) > 0]
            if len(batch) == 0:
                continue
            for j, sentences in enumerate(batch):
                for k, sentence in enumerate(sentences):
                    if (j, k) in skip_indices:
                        skip_num = skip_num + 1
                        continue
                    # Track throughput in bytes (tokens are stored as int32).
                    total_bytes_processed += len(sentence) * np.int32().itemsize
                    # Append the accepted sentence to this key's builder.
                    builders[key].add_item(sentence)
                builders[key].end_document()
        # Progress logging in terms of processed document count.
        batch_id = i * iteration_batch_size
        if batch_id % self.args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed / elapsed / 1024 / 1024
            logger.info("Processed %s documents (%s docs/s, %s MB/s).", batch_id, batch_id / elapsed, mbs)
    # skip_num is incremented once per key, so divide by the key count to
    # report the number of skipped samples.
    logger.info("Skip %s sample exceeded seq-length(%s)", skip_num / len(self.args.json_keys), self.args.seq_length)
    # Flush each builder and write the corresponding index file.
    for key in self.args.json_keys:
        builders[key].finalize(output_idx_files[key])
以上就是mindspeed-llm處理數(shù)據(jù)的主要函數(shù)了,大家還有什么想了解的呢?歡迎評(píng)論區(qū)提問(wèn)!
本文由博客一文多發(fā)平臺(tái) OpenWrite 發(fā)布!