RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation 論文及代碼總結(三)

四、數據處理模塊


該模塊需要看兩處代碼, 分別為 hdf5_vla_dataset.py 和 dataset.py

4.1 hdf5_vla_dataset.py

import numpy as np
import h5py

def parse_hdf5_file(self, file_path):
    """Sample one training example (state, action chunk, images) from an HDF5 episode.

    A random step at or after the step just before the arm first moves is
    chosen; the state at that step, the following ``self.CHUNK_SIZE`` actions
    and the last ``self.IMG_HISORY_SIZE`` frames of three cameras are returned.

    Args:
        file_path: Path to the episode's HDF5 file. The instruction file
            ``expanded_instruction_gpt-4-turbo.json`` is read from the same
            directory.

    Returns:
        ``(False, None)`` if the episode has fewer than 128 steps, otherwise
        ``(True, sample)`` where ``sample`` is a dict with the keys ``meta``,
        ``state``, ``state_std``, ``state_mean``, ``state_norm``, ``actions``,
        ``state_indicator`` and, per camera, an image stack plus a validity mask.

    Raises:
        ValueError: If no step differs from the first qpos by more than EPS.
        AssertionError: If the qpos width is odd, zero, or exceeds 10 values
            per arm.
    """
    with h5py.File(file_path, 'r') as f:
        qpos = f['observations']['qpos'][:]
        num_steps = qpos.shape[0]
        # Episodes shorter than 128 steps are filtered out.
        if num_steps < 128:
            return False, None

        # Movement threshold: a step counts as "moved" once any qpos
        # dimension deviates from the first step by more than EPS.
        EPS = 1e-2
        qpos_delta = np.abs(qpos - qpos[0:1])
        indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
        if len(indices) > 0:
            # Index of the first step at which the robot has moved.
            first_idx = indices[0]
        else:
            raise ValueError("Found no qpos that exceeds the threshold.")

        # Step 0 never exceeds the threshold (its delta is zero), so
        # first_idx >= 1 and the lower bound below is never negative.
        step_id = np.random.randint(first_idx-1, num_steps)

        dir_path = os.path.dirname(file_path)
        with open(os.path.join(dir_path, 'expanded_instruction_gpt-4-turbo.json'), 'r') as f_instr:
            instruction_dict = json.load(f_instr)

        # Randomly pick one of the three instruction variants.
        instruction_type = np.random.choice([
            'instruction', 'simplified_instruction', 'expanded_instruction'])
        instruction = instruction_dict[instruction_type]
        if isinstance(instruction, list):
            instruction = np.random.choice(instruction)

        # Assemble the meta
        meta = {
            "dataset_name": self.DATASET_NAME,
            "#steps": num_steps,  # number of frames in this episode
            "step_id": step_id,
            "instruction": instruction
        }

        # The vector is laid out as [left arm ..., left gripper,
        # right arm ..., right gripper], so its width must be even. The
        # original guard used `% 2 <= 10`, which is always true; the bound is
        # applied to the per-arm width as the assertion message describes.
        qpos_dim = qpos.shape[1]
        half_dim = qpos_dim // 2
        single_side_norm_scale_vec_len = (
            half_dim if qpos_dim % 2 == 0 and 1 <= half_dim <= 10 else None)
        assert single_side_norm_scale_vec_len is not None, "qpos cannot divede by 2 and lager than 10"

        # Only the two gripper entries (last value of each arm) are rescaled;
        # every other dimension is divided by 1.
        qpos_norm_vec = np.ones(qpos_dim)
        qpos_norm_vec[single_side_norm_scale_vec_len - 1] = 4.7908
        qpos_norm_vec[-1] = 4.7888
        qpos = qpos / qpos_norm_vec

        action_norm_vec = np.ones(qpos_dim)
        action_norm_vec[single_side_norm_scale_vec_len - 1] = 11.8997
        action_norm_vec[-1] = 13.9231
        # Action chunk starting at the sampled step; it may be shorter than
        # CHUNK_SIZE near the end of the episode.
        actions = f['action'][step_id:step_id + self.CHUNK_SIZE] / action_norm_vec

        state = qpos[step_id:step_id+1]
        # Per-dimension statistics over the whole episode.
        state_std = np.std(qpos, axis=0)
        state_mean = np.mean(qpos, axis=0)
        state_norm = np.sqrt(np.mean(qpos**2, axis=0))
        if actions.shape[0] < self.CHUNK_SIZE:
            # Pad the actions using the last action
            actions = np.concatenate([
                actions,
                np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
            ], axis=0)

        # Fill the state/action into the unified vector
        def fill_in_state(values):
            """Scatter ``values`` into a zero vector of width ``self.STATE_DIM``."""
            width = values.shape[-1]
            half = width // 2
            # Joints per arm = per-arm width minus the gripper entry.
            values_len = half - 1 if width % 2 == 0 and 1 <= half <= 10 else None
            # The original asserted on `values`, which is never None.
            assert values_len is not None, "values cannot divede by 2 and lager than 10"
            # Target indices corresponding to your state space:
            # N joints + 1 gripper for each arm.
            UNI_STATE_INDICES = [
                STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(values_len)
            ] + [
                STATE_VEC_IDX_MAPPING["left_gripper_open"]
            ] + [
                STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(values_len)
            ] + [
                STATE_VEC_IDX_MAPPING["right_gripper_open"]
            ]
            uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
            uni_vec[..., UNI_STATE_INDICES] = values
            return uni_vec

        state = fill_in_state(state)
        # 1 at every unified-vector slot that carries data, 0 elsewhere.
        state_indicator = fill_in_state(np.ones_like(state_std))
        state_std = fill_in_state(state_std)
        state_mean = fill_in_state(state_mean)
        state_norm = fill_in_state(state_norm)
        actions = fill_in_state(actions)

        def parse_img(key):
            """Return the last ``IMG_HISORY_SIZE`` frames of camera ``key`` up to ``step_id``."""
            image_group = f['observations']['images']
            # Fall back to the camera name without the "_wrist" suffix when
            # the file does not contain the suffixed key.
            if key not in image_group:
                key = key.replace("_wrist", "")
            frames = image_group[key]
            imgs = []
            for i in range(max(step_id-self.IMG_HISORY_SIZE+1, 0), step_id+1):
                img = frames[i]
                # Frames that are not already arrays are decoded with OpenCV.
                if not isinstance(img, np.ndarray):
                    img = cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR)
                imgs.append(img)
            imgs = np.stack(imgs)
            if imgs.shape[0] < self.IMG_HISORY_SIZE:
                # Pad the images using the first image
                imgs = np.concatenate([
                    np.tile(imgs[:1], (self.IMG_HISORY_SIZE-imgs.shape[0], 1, 1, 1)),
                    imgs
                ], axis=0)
            return imgs

        cam_high = parse_img('cam_high')
        # Frames before (first_idx - 1) are marked invalid in the mask.
        valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
        cam_high_mask = np.array(
            [False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len)
        cam_left_wrist = parse_img('cam_left_wrist')
        cam_left_wrist_mask = cam_high_mask.copy()
        cam_right_wrist = parse_img('cam_right_wrist')
        cam_right_wrist_mask = cam_high_mask.copy()

        # Shapes noted by the article's author for a 16-dim qpos, STATE_DIM=128,
        # CHUNK_SIZE=64 and IMG_HISORY_SIZE=2: state (1, 128), actions (64, 128),
        # statistics/indicator (128,), images (2, 480, 640, 3).
        return True, {
            "meta": meta,
            "state": state,
            "state_std": state_std,
            "state_mean": state_mean,
            "state_norm": state_norm,
            "actions": actions,
            "state_indicator": state_indicator,
            "cam_high": cam_high,
            "cam_high_mask": cam_high_mask,
            "cam_left_wrist": cam_left_wrist,
            "cam_left_wrist_mask": cam_left_wrist_mask,
            "cam_right_wrist": cam_right_wrist,
            "cam_right_wrist_mask": cam_right_wrist_mask
        }
4.1.1 f(.hdf5)文件內容, 下面的T=300, 代表的是時間步數, 是一個完整動作的周期
{
  "observations":{
    "images":{
      "cam_high":  np.array(300, 480, 640, 3),
      "cam_low":  np.array(300, 480, 640, 3),
      "cam_left": np.array(300, 480, 640, 3),
      "cam_right": np.array(300, 480, 640, 3)
    },
   "effort": np.array(300, 16), # 關節力矩(Joint Efforts)
   "qpos": np.array(300, 16), # 關節位置(Joint Positions), 每只手臂為若干關節 + 1個夾爪
   "qvel": np.array(300, 16), # 關節速度(Joint Velocities)  
  },
  "action": np.array(300, 16)
}

注意: 這里為什么要求出action是一段的呢? 是因為之后會根據一段的action做平滑, 通過插值一步步去做

4.1.2 針對機械臂沒有太大浮動的數據進行過濾

可參考RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation 論文及代碼總結(一) 5.2部分中的其他finetune細節
* 對于數據動作周期小于128, 進行過濾
* 計算之后每一個時間步的qpos與第一個時間步的qpos之間的差值, 當任一維度的差值大于閾值, 則認為有動作發生, 第一個滿足條件的時間索引記作first_idx

num_steps = qpos.shape[0]
if num_steps < 128:
    return False, None
EPS = 1e-2 # 為了證明機(jī)械臂是有移動的
# Get the idx of the first qpos whose delta exceeds the threshold
qpos_delta = np.abs(qpos - qpos[0:1]) # 其他qps值與第一個值之間的差距
indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
if len(indices) > 0:
    first_idx = indices[0] # 代表機(jī)器人開始移動的時間索引
else:
    raise ValueError("Found no qpos that exceeds the threshold.")# 為了證明機(jī)械臂是有移動的
4.1.3 生成meta文件
{
  "dataset_name": "custom",
  "#steps": num_steps, # 該數據有多少幀時長, 這里使用的是300
  "step_id": step_id, # 從上述得到的first_idx - 1 開始, 到num_steps(不含)之間隨機選一個值
  "instruction": instruction # 從['instruction', 'simplified_instruction', 'expanded_instruction']中隨機選一種, 再取對應的文本
}
4.1.4 對原始數據進行歸一化, 并從隨機選取的step_id開始對action(target_qpos)截取self.CHUNK_SIZE長度大小的段, 并根據step_id得到state, 即當前機械臂qpos
# 1??對qpos進(jìn)行歸一化
qpos_norm_vec = [1, 1, 1, 1, 1, 1, 1, 4.7908, 1, 1, 1, 1, 1, 1, 1, 4.7888]
qpos = qpos / np.array([qpos_norm_vec])

# 2??對action進(jìn)行歸一化, ??小于段長度, 對最后的階段進(jìn)行重復(fù)操作
action_norm_vec = [1, 1, 1, 1, 1, 1, 1, 11.8997, 1, 1, 1, 1, 1, 1, 1, 13.9231]
actions = f_action[step_id:step_id + self.CHUNK_SIZE] / np.array([action_norm_vec]) # self.CHUNK_SIZE=64
if actions.shape[0] < self.CHUNK_SIZE:
    actions = np.concatenate([
        actions,
        np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
    ], axis=0)

# 3??并根據(jù)step_id得到state, 當(dāng)前機(jī)械臂qpos
state = qpos[step_id:step_id+1]
state_std = np.std(qpos, axis=0)
state_mean = np.mean(qpos, axis=0)
state_norm = np.sqrt(np.mean(qpos**2, axis=0))
4.1.5 對數據將16的維度映射到128的維度上
def fill_in_state(values):
    """Scatter a two-arm ``[joints..., gripper]`` vector into the unified state vector.

    Args:
        values: Array whose last axis is laid out as left-arm joints, left
            gripper, right-arm joints, right gripper. The last axis must have
            an even, non-zero width of at most 10 values per arm.

    Returns:
        Array with the same leading shape as ``values`` and a last axis of
        width ``self.STATE_DIM``; slots not addressed by the mapping stay 0.

    Raises:
        AssertionError: If the last-axis width is odd, zero, or too large.
    """
    width = values.shape[-1]
    half = width // 2
    # Joints per arm = per-arm width minus the gripper entry. The original
    # guard tested `width % 2 <= 10`, which is always true; the bound is
    # applied to the per-arm width as the assertion message describes.
    values_len = half - 1 if width % 2 == 0 and 1 <= half <= 10 else None
    # The original asserted on `values`, which is never None, so an invalid
    # width fell through to `range(None)` and raised TypeError instead.
    assert values_len is not None, "values cannot divede by 2 and lager than 10"
    # Target indices corresponding to your state space:
    # N joints + 1 gripper for each arm.
    UNI_STATE_INDICES = [
        STATE_VEC_IDX_MAPPING[f"left_arm_joint_{i}_pos"] for i in range(values_len)
    ] + [
        STATE_VEC_IDX_MAPPING["left_gripper_open"]
    ] + [
        STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(values_len)
    ] + [
        STATE_VEC_IDX_MAPPING["right_gripper_open"]
    ]
    uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
    uni_vec[..., UNI_STATE_INDICES] = values
    return uni_vec


state = fill_in_state(state) # 將1,16 轉(zhuǎn)位1,128, 把16維度上的值填到128維度的位置上, 不再對應(yīng)位置上的值默認(rèn)是0
state_indicator = fill_in_state(np.ones_like(state_std))
state_std = fill_in_state(state_std)
state_mean = fill_in_state(state_mean)
state_norm = fill_in_state(state_norm)
actions = fill_in_state(actions)
4.1.6 對圖像數據進行處理, 這里parse_img處理中解析圖像, 如果圖像時間軸上的數量少于預設的self.IMG_HISORY_SIZE, 則重復第一張圖像在前面補充
cam_high = parse_img('cam_high') # 這里包括cam_high, cam_left_wrist, cam_right_wrist, shape: (2, 480, 640, 3); cam_low沒有選取
valid_len = min(step_id - (first_idx - 1) + 1, self.IMG_HISORY_SIZE)
cam_high_mask = np.array([False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len) # 之前一直浮動小的圖像設(shè)置mask為False
cam_left_wrist = parse_img('cam_left_wrist') # shape: (2, 480, 640, 3)
cam_left_wrist_mask = cam_high_mask.copy()
cam_right_wrist = parse_img('cam_right_wrist') # (True, True)
cam_right_wrist_mask = cam_high_mask.copy() # (True, True)

4.2 dataset.py

def __getitem__(self, index):
    """Load, augment and tensorize one training sample.

    The raw sample comes from ``self.hdf5_dataset.get_item()`` when
    ``self.use_hdf5`` is set, otherwise from ``self._safe_load(index)``.
    Control frequency, state, state mask, each image and the language input
    are each independently replaced by a placeholder with probability
    ``self.cond_mask_prob``. If processing raises an ``Exception`` the error
    is printed and the next index is tried.

    Args:
        index: Sample index; used by the non-HDF5 loader and advanced on retry.

    Returns:
        dict with ``dataset_name``, ``data_idx``, ``ctrl_freq``, ``states``,
        ``actions``, ``state_elem_mask``, ``state_norm``, ``images`` (list of
        preprocessed image tensors) and either ``lang_embed`` or
        ``input_ids``. NumPy arrays are converted to torch tensors.
    """
    def expand2square(pil_img, background_color):
        """Pad a PIL image to a square, centring it on ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    # For robustness, we will try to load the data until we succeed
    while True:
        data_dict = None
        try:
            if self.use_hdf5:
                res = self.hdf5_dataset.get_item()
                content = res['meta']
                states = res['state']
                actions = res['actions']
                state_elem_mask = res['state_indicator']
                # Flat list of (images, mask) pairs, paired up further below.
                image_metas = [
                    res['cam_high'], res['cam_high_mask'],
                    res['cam_right_wrist'], res['cam_right_wrist_mask'],
                    res['cam_left_wrist'], res['cam_left_wrist_mask'],
                ]
                state_std = res['state_std']
                state_mean = res['state_mean']
                state_norm = res['state_norm']
            else:
                (content, _, states, _, actions, _, 
                state_elem_mask, *image_metas, 
                state_std, state_mean, state_norm) = self._safe_load(index)
            
            data_dict = {}

            data_dict['dataset_name'] = content['dataset_name']
            data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']]
            # With probability cond_mask_prob the control frequency is set to 0.
            data_dict['ctrl_freq'] = self.control_freq[data_dict['dataset_name']] \
                if random.random() > self.cond_mask_prob else 0
            
            if self.state_noise_snr is not None:
                # Gaussian noise whose std is the per-dimension state std
                # scaled down by the configured signal-to-noise ratio (dB).
                states += np.random.normal(
                    0.0, state_std / np.sqrt(10 ** (self.state_noise_snr / 10)), 
                    states.shape)
            ds_state_mean = np.array(self.dataset_stat[data_dict['dataset_name']]['state_mean'])
            ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
            # Randomly mask the states by the mean state
            data_dict["states"] = states \
                if random.random() > self.cond_mask_prob else ds_state_mean
            data_dict["actions"] = actions
            data_dict["state_elem_mask"] = state_elem_mask \
                if random.random() > self.cond_mask_prob else np.zeros_like(state_elem_mask)
            
            # Stat for the episode that the step belongs to 
            data_dict["state_norm"] = state_norm
            
            # We replace the invalid images with the background image
            # and also randomly mask images by the background image
            background_color = np.array([
                int(x*255) for x in self.image_processor.image_mean
            ], dtype=np.uint8).reshape(1, 1, 3)
            # Solid image filled with the image processor's mean colour.
            background_image = np.ones((
                self.image_processor.size["height"], 
                self.image_processor.size["width"], 3), dtype=np.uint8
            ) * background_color
            
            image_metas = list(self.pairwise(image_metas))
            # Per-camera probability of replacing a frame with the background.
            mask_probs = [self.cond_mask_prob] * self.num_cameras
            if self.cam_ext_mask_prob >= 0.0:
                mask_probs[0] = self.cam_ext_mask_prob
            # Ordered by history step first, then by camera.
            rearranged_images = []
            for i in range(self.img_history_size):
                for j in range(self.num_cameras):
                    images, image_mask = image_metas[j]
                    image, valid = images[i], image_mask[i]
                    if valid and (math.prod(image.shape) > 0) and \
                        (random.random() > mask_probs[j]):
                        rearranged_images.append((image, True))
                    else:
                        rearranged_images.append((background_image.copy(), False))
            
            preprocessed_images = []
            processor = self.image_processor
            # Loop-invariant padding colour for the 'pad' aspect-ratio mode.
            pad_color = tuple(int(x*255) for x in processor.image_mean)
            for image, valid in rearranged_images:
                image = Image.fromarray(image)
                if self.image_size is not None:
                    image = transforms.Resize(self.image_size)(image)
                
                if valid and self.auto_adjust_image_brightness:
                    # Brighten images whose mean intensity is at most 15%.
                    pixel_values = list(image.getdata())
                    average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
                    if average_brightness <= 0.15:
                        image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
                
                # Only apply image augmentation to 50% of the images
                if valid and self.image_aug and (random.random() > 0.5):
                    aug_type = random.choice([
                        "corrput_only", "color_only", "both"])
                    if aug_type != "corrput_only":
                        image = transforms.ColorJitter(
                            brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image)
                    if aug_type != "color_only":
                        image = image_corrupt(image)
                
                if self.image_aspect_ratio == 'pad':
                    image = expand2square(image, pad_color)
                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
                preprocessed_images.append(image)
            data_dict["images"] = preprocessed_images

            if self.use_precomp_lang_embed:
                # Here the instruction is passed to torch.load, i.e. it is
                # used as a path; a trailing "." is stripped first.
                if content["instruction"][-1] == ".":
                    content["instruction"] = content["instruction"][:-1]
                data_dict["lang_embed"] = torch.load(content["instruction"]) \
                    if random.random() > self.cond_mask_prob else self.empty_lang_embed
            else:
                # With probability cond_mask_prob the instruction is dropped.
                instruction = content["instruction"] \
                    if random.random() > self.cond_mask_prob else ""
                data_dict["input_ids"] = self.tokenizer(
                    instruction,
                    return_tensors="pt",
                    padding="longest",
                    truncation=False,
                ).input_ids[0]
            
                assert len(data_dict["input_ids"]) <= self.tokenizer_max_length, \
                    f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."
            
            # Convert every remaining NumPy array to a torch tensor.
            for k, v in data_dict.items():
                if isinstance(v, np.ndarray):
                    data_dict[k] = torch.from_numpy(v)
    
            return data_dict
        except Exception as e:
            # Catch Exception rather than BaseException so that
            # KeyboardInterrupt / SystemExit can still stop this retry loop.
            # Print the error info
            if data_dict is not None:
                print(f"Error catched when processing sample from {data_dict.get('dataset_name')}:", e)
            else:
                print(f"Error catched when processing sample:", e)
            traceback.print_exc()
            # Try incresing the index
            index = (index + 1) % len(self)

上述hdf5_vla_dataset.py主要代碼為dataset.py代碼中self.hdf5_dataset.get_item(), 最終返回data_dict

4.2.1 數據重新整理

res = self.hdf5_dataset.get_item()
content = res['meta']
states = res['state'] # (1, 128)
actions = res['actions'] # (64, 128)
state_elem_mask = res['state_indicator']
image_metas = [
    res['cam_high'], res['cam_high_mask'],
    res['cam_right_wrist'], res['cam_right_wrist_mask'],
    res['cam_left_wrist'], res['cam_left_wrist_mask'],
]
state_std = res['state_std']
state_mean = res['state_mean']
state_norm = res['state_norm']

data_dict['dataset_name'] = content['dataset_name']
data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']]
# With probability cond_mask_prob the control frequency is replaced by 0.
# The article's author notes that the frequency table is configured in
# "configs/custom_configs/custom_dataset_control_freq.json".
data_dict['ctrl_freq'] = self.control_freq[data_dict['dataset_name']] \
    if random.random() > self.cond_mask_prob else 0

4.2.2 數據進行后處理

# 1??對state進(jìn)行加噪聲, 類似關(guān)鍵點(diǎn)檢測生成高斯圖一樣
if self.state_noise_snr is not None:
    states += np.random.normal(
        0.0, state_std / np.sqrt(10 ** (self.state_noise_snr / 10)), 
        states.shape)

# 2??對state進(jìn)行進(jìn)一步的隨機(jī)drop, drop的值用預(yù)先處理的均值和標(biāo)準(zhǔn)差進(jìn)行填充, state_elem_mask也是如此, 一開始全是1, 但是這里會隨機(jī)補(bǔ)充0
ds_state_mean = np.array(self.dataset_stat[data_dict['dataset_name']]['state_mean'])# "configs/custom_configs/custom_dataset_stat.json"
ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1))
data_dict["states"] = states if random.random() > self.cond_mask_prob else ds_state_mean 
data_dict["actions"] = actions
data_dict["state_elem_mask"] = state_elem_mask if random.random() > self.cond_mask_prob else np.zeros_like(state_elem_mask)
data_dict["state_norm"] = state_norm

4.2.3 對圖像進行后處理

有一定概率用基于預訓練圖像均值構建的背景圖片替換原圖, 無效的圖片同樣用背景圖片替換, 處理后的圖片保存在data_dict["images"]

background_color = np.array([
    int(x*255) for x in self.image_processor.image_mean
], dtype=np.uint8).reshape(1, 1, 3)
background_image = np.ones((
    self.image_processor.size["height"], 
    self.image_processor.size["width"], 3), dtype=np.uint8
) * background_color # 基于預(yù)訓(xùn)練的圖像均值構(gòu)建背景圖片

image_metas = list(self.pairwise(image_metas))
mask_probs = [self.cond_mask_prob] * self.num_cameras # 一定概率對圖像加入mask, 提高模型泛化性
if self.cam_ext_mask_prob >= 0.0:
    mask_probs[0] = self.cam_ext_mask_prob
rearranged_images = []
for i in range(self.img_history_size):
    for j in range(self.num_cameras):
        images, image_mask = image_metas[j]
        image, valid = images[i], image_mask[i]
        if valid and (math.prod(image.shape) > 0) and \
            (random.random() > mask_probs[j]):
            rearranged_images.append((image, True))
        else:
            rearranged_images.append((background_image.copy(), False)) # 直接將背景噪音加入進(jìn)去

preprocessed_images = []
processor = self.image_processor
for image, valid in rearranged_images:
    image = Image.fromarray(image)
    if self.image_size is not None:
        image = transforms.Resize(self.image_size)(image) # (1008, 336)
    # assert image.height == 336, "We haven't prepare for training with images of different resolutions."
    
    if valid and self.auto_adjust_image_brightness: # False
        pixel_values = list(image.getdata())
        average_brightness = sum(sum(pixel) for pixel in pixel_values) / (len(pixel_values) * 255.0 * 3)
        if average_brightness <= 0.15:
            image = transforms.ColorJitter(brightness=(1.75,1.75))(image)
    
    # Only apply image augmentation to 50% of the images
    if valid and self.image_aug and (random.random() > 0.5):
        aug_type = random.choice([
            "corrput_only", "color_only", "both"])
        if aug_type != "corrput_only":
            image = transforms.ColorJitter(
                brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image)
        if aug_type != "color_only":
            image = image_corrupt(image)
    
    if self.image_aspect_ratio == 'pad': # True
        def expand2square(pil_img, background_color):
            width, height = pil_img.size
            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result
        image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    preprocessed_images.append(image)
data_dict["images"] = preprocessed_images

4.2.4 對文本進行后處理, 獲得短語token分詞

instruction = content["instruction"] \
    if random.random() > self.cond_mask_prob else "" # 語言有的時候不輸入
data_dict["input_ids"] = self.tokenizer(
    instruction,
    return_tensors="pt",
    padding="longest",
    truncation=False,
).input_ids[0] # 得到分詞的token編碼, (1, 35)

assert len(data_dict["input_ids"]) <= self.tokenizer_max_length, \
    f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}."

4.2.5 對數據進行后處理, 獲得數據字典, data_dict

{
  "dataset_name": "custom",
  "data_idx": 0, # 數據集名稱對應的索引
  "ctrl_freq": 0, # 有一定概率被置為0, 否則為該數據集的控制頻率
  "states": tensor[np.array(1, 128)],
  "actions": tensor[np.array(64, 128)],
  "state_elem_mask": tensor[np.array(128,)],
  "state_norm": tensor[np.array(128,)],
  "images": [6, 3, 384, 384], # 這里的6是3(三個攝像頭)*2(歷史圖片時間序列上是2)
  "input_ids": tensor[np.array(9,)],
  
}

參考

  1. 自然語言處理:第八十六章 Deepspeed各階段配置你了解么?

五、模型訓練


腳本文件: train.py、rdt_runner.py、model.py、blocks.py

5.1 train.py

# Excerpt of one training step: move the batch to the working dtype, encode
# images and language without gradients, then compute the diffusion loss.
# Concrete shapes in the comments below are the article author's observations
# (3 cameras x 2 history frames) and are not established by this excerpt.
images = batch["images"].to(dtype=weight_dtype) # author-observed shape: (B, 6, 3, 384, 384)
states = batch["states"].to(dtype=weight_dtype) # author-observed shape: (B, 1, 128)
states = states[:, -1:, :] # keep only the last state as input
actions = batch["actions"].to(dtype=weight_dtype) # author-observed shape: (B, 64, 128)
state_elem_mask = batch["state_elem_mask"].to(dtype=weight_dtype) # author-observed shape: (B, 128)
ctrl_freqs = batch["ctrl_freqs"] # per-sample control frequency (author saw values 0 and 100); presumably shape (B,) - TODO confirm
    
with torch.no_grad():
    batch_size, _, C, H, W = images.shape
    # Fold the per-sample image axis into the batch axis for the encoder.
    image_embeds = vision_encoder(images.reshape(-1, C, H, W)).detach() # author-observed shape: (B*6, 729, 1152)
    # Unfold again so all image tokens of one sample form a single sequence.
    image_embeds = image_embeds.reshape((batch_size, -1, vision_encoder.hidden_size)) # author-observed shape: (B, 4374, 1152)

    lang_attn_mask = batch["lang_attn_mask"] # attention mask passed to the text encoder and to rdt
    # Use precomputed language embeddings when configured, otherwise run the
    # text encoder on the token ids.
    text_embeds = batch["lang_embeds"].to(dtype=weight_dtype) \
        if args.precomp_lang_embed \
        else text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=lang_attn_mask
        )["last_hidden_state"].detach()
    # author-observed shape: (B, 46 [number of tokens], 4096)
# Insert a length-1 axis at dim 1 before passing the mask as action_mask.
state_elem_mask = state_elem_mask.unsqueeze(1)
loss = rdt(
    lang_tokens=text_embeds,
    lang_attn_mask=lang_attn_mask,
    img_tokens=image_embeds,
    state_tokens=states,
    action_gt=actions,
    action_mask=state_elem_mask,
    ctrl_freqs=ctrl_freqs
)
  • image_embeds 通過視覺編碼得到特征為 shape=(BatchSize*6, 729, 1152), 并轉換成shape=(BatchSize, 4374, 1152)
  • text_embeds 通過語言編碼得到特征為 shape=(BatchSize, 46, 4096); 最后將圖文編碼輸入到 `rdt` 模型中

5.2 rdt_runner.py

(1) 整體pipeline

def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens, 
                 state_tokens, action_gt, action_mask, ctrl_freqs
                ) -> torch.Tensor:
    '''
    Compute the diffusion training loss for one batch.

    lang_tokens: (batch_size, lang_len, lang_token_dim)
    lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
        which should be True-False bool tensor.
    img_tokens: (batch_size, img_len, img_token_dim)
    state_tokens: (batch_size, 1, state_token_dim), states
    action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
    action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
    ctrl_freqs: (batch_size,), control frequency for each sample.

    return: loss_value, a scalar tensor (MSE between prediction and target)
    raises: ValueError if self.prediction_type is neither 'epsilon' nor 'sample'
    '''
    device = lang_tokens.device
    num_samples = lang_tokens.shape[0]

    # Forward diffusion: draw Gaussian noise and one random timestep per
    # sample, then let the scheduler corrupt the clean actions accordingly.
    eps = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device)
    diffusion_steps = torch.randint(
        0, self.num_train_timesteps, (num_samples,), device=device)
    diffusion_steps = diffusion_steps.long()
    corrupted_actions = self.noise_scheduler.add_noise(
        action_gt, eps, diffusion_steps)

    # Sequence axis: the state token followed by the noisy action tokens
    # (length 1 + horizon).
    trajectory = torch.cat([state_tokens, corrupted_actions], dim=1)
    # Feature axis: repeat the mask for every token and append it, doubling
    # the feature width.
    seq_len = trajectory.shape[1]
    tiled_mask = action_mask.expand(-1, seq_len, -1)
    trajectory = torch.cat([trajectory, tiled_mask], dim=2)

    # Adapt all three inputs before they enter the backbone.
    lang_cond, img_cond, trajectory = self.adapt_conditions(
        lang_tokens, img_tokens, trajectory)

    # Predict the denoised result
    pred = self.model(
        trajectory, ctrl_freqs, diffusion_steps,
        lang_cond, img_cond, lang_mask=lang_attn_mask)

    # The supervision target depends on what the model is trained to predict.
    pred_type = self.prediction_type
    targets_by_type = {'epsilon': eps, 'sample': action_gt}
    if pred_type not in targets_by_type:
        raise ValueError(f"Unsupported prediction type {pred_type}")

    return F.mse_loss(pred, targets_by_type[pred_type])

(2) 數據構成

noise = torch.randn(action_gt.shape, dtype=action_gt.dtype, device=device) # shape(batch_size, 64 .128)
timesteps = torch.randint(0, self.num_train_timesteps, (batch_size,), device=device).long() # shape(batch_size,), 例如(226, 57); self.num_train_timesteps = 1000
noisy_action = self.noise_scheduler.add_noise(action_gt, noise, timesteps) # shape (B 64, 128), 根據(jù)每個時間步的噪聲幅度,在清潔動作中添加噪聲
state_action_traj = torch.cat([state_tokens, noisy_action], dim=1) # shape (B, 65, 128) 將狀態(tài)state token(128)以及加噪音的action(label)進(jìn)行合并作為state_action_traj
action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1) # shape (B, 1, 128)
state_action_traj = torch.cat([state_action_traj, action_mask], dim=2) # 將當(dāng)前狀態(tài)token以及加噪音的action(label)以及對應(yīng)mask進(jìn)行合并作為state_action_traj (B, 65, 256)
lang_cond, img_cond, state_action_traj = self.adapt_conditions(lang_tokens, img_tokens, state_action_traj) # 分別加入adapter mlp的推理

5.3 model.py

def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
    """
    Forward pass of RDT.

    x: (B, T, D), state + action token sequence, T = horizon + 1,
        dimension D is assumed to be the same as the hidden size.
    freq: (B,), a scalar indicating control frequency.
    t: (B,) or (1,), diffusion timesteps.
    lang_c: (B, L_lang, D), language condition tokens (variable length),
        dimension D is assumed to be the same as the hidden size.
    img_c: (B, L_img, D), image condition tokens (fixed length),
        dimension D is assumed to be the same as the hidden size.
    lang_mask: (B, L_lang) or None, language condition mask (True for valid).
    img_mask: (B, L_img) or None, image condition mask (True for valid).

    Returns the last ``self.horizon`` tokens of the final layer's output.
    """
    # Embed the diffusion timestep and control frequency as one token each.
    time_token = self.t_embedder(t).unsqueeze(1)        # (B, 1, D) or (1, 1, D)
    freq_token = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
    # A single shared timestep is broadcast over the batch.
    if time_token.shape[0] == 1:
        time_token = time_token.expand(x.shape[0], -1, -1)

    # Prepend both tokens: the sequence grows from T to T + 2.
    tokens = torch.cat([time_token, freq_token, x], dim=1)

    # Add multimodal position embeddings; the language table is sliced
    # because the language sequence is of variable length.
    tokens = tokens + self.x_pos_embed
    lang_len = lang_c.shape[1]
    lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_len]
    img_c = img_c + self.img_cond_pos_embed

    # Blocks alternate conditions: even-indexed blocks receive language,
    # odd-indexed blocks receive images.
    cond_and_mask = ((lang_c, lang_mask), (img_c, img_mask))
    for layer_idx, layer in enumerate(self.blocks):
        cond, cond_mask = cond_and_mask[layer_idx % 2]
        tokens = layer(tokens, cond, cond_mask)         # (B, T+2, D)

    tokens = self.final_layer(tokens)                   # (B, T+2, out_channels)

    # Only preserve the action tokens
    return tokens[:, -self.horizon:]

這里需要注明下-self.horizon:默認為action的長度=64, 可以參考下圖DiT, 其實就是輸出的noise

DIT

5.4 blocks.py

上述的self.blocks代碼如下所示


#################################################################################
#                          Cross Attention Layers                               #
#################################################################################
class CrossAttention(nn.Module):
    """
    Multi-head cross-attention layer.

    Queries are projected from ``x``; keys and values are projected from the
    condition sequence ``c``. When ``use_fused_attn()`` reports support, the
    layer delegates to ``F.scaled_dot_product_attention``; otherwise the
    scaled dot-product attention is computed explicitly.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0,
            proj_drop: float = 0,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Queries come from x; keys and values are packed into one projection of c.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias)
        # Optional per-head normalisation of queries and keys.
        if qk_norm:
            self.q_norm = norm_layer(head_dim)
            self.k_norm = norm_layer(head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, c: torch.Tensor,
                mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Attend from ``x`` to ``c``.

        Args:
            x: (B, N, C) query tokens.
            c: (B, L, C) condition tokens.
            mask: optional (B, L) condition mask; entries that are True are
                attended to, the others are masked out.

        Returns:
            (B, N, C) tensor after the output projection.
        """
        batch, n_query, channels = x.shape
        _, n_cond, _ = c.shape
        heads, head_dim = self.num_heads, self.head_dim

        # (B, heads, N, head_dim)
        queries = self.q(x).reshape(batch, n_query, heads, head_dim).transpose(1, 2)
        # (2, B, heads, L, head_dim) -> keys, values
        packed = self.kv(c).reshape(batch, n_cond, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        keys, values = packed[0], packed[1]
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        # Broadcast the (B, L) condition mask over every query position.
        if mask is not None:
            mask = mask.reshape(batch, 1, 1, n_cond).expand(-1, -1, n_query, -1)

        if self.fused_attn:
            drop_p = self.attn_drop.p if self.training else 0.
            out = F.scaled_dot_product_attention(
                query=queries,
                key=keys,
                value=values,
                dropout_p=drop_p,
                attn_mask=mask,
            )
        else:
            scores = torch.matmul(queries * self.scale, keys.transpose(-2, -1))
            if mask is not None:
                scores = scores.masked_fill(mask.logical_not(), float('-inf'))
            weights = F.softmax(scores, dim=-1)
            if self.attn_drop.p > 0:
                weights = self.attn_drop(weights)
            out = torch.matmul(weights, values)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        out = self.proj(out.transpose(1, 2).reshape(batch, n_query, channels))
        if self.proj_drop.p > 0:
            out = self.proj_drop(out)
        return out


#################################################################################
#                                 RDT Block                                     #
#################################################################################
class RDTBlock(nn.Module):
    """
    One RDT transformer block.

    Three pre-norm residual sub-layers applied in order: self-attention over
    the tokens, cross-attention to the condition sequence, and a feed-forward
    MLP.
    """
    def __init__(self, hidden_size, num_heads, **block_kwargs):
        super().__init__()
        # Settings shared by the self-attention and cross-attention layers.
        shared_attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=True,
            qk_norm=True,
            norm_layer=RmsNorm,
            **block_kwargs,
        )

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.norm1 = RmsNorm(hidden_size, eps=1e-6)
        self.attn = Attention(dim=hidden_size, **shared_attn_kwargs)
        self.cross_attn = CrossAttention(hidden_size, **shared_attn_kwargs)

        self.norm2 = RmsNorm(hidden_size, eps=1e-6)
        self.ffn = Mlp(
            in_features=hidden_size,
            hidden_features=hidden_size,
            act_layer=approx_gelu,
            drop=0,
        )
        self.norm3 = RmsNorm(hidden_size, eps=1e-6)

    def forward(self, x, c, mask=None):
        """
        Args:
            x: input tokens.
            c: condition tokens passed to the cross-attention layer.
            mask: optional condition mask forwarded to the cross-attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention with residual connection.
        x = self.attn(self.norm1(x)) + x
        # Cross-attention to the condition with residual connection.
        x = self.cross_attn(self.norm2(x), c, mask) + x
        # Feed-forward with residual connection.
        x = self.ffn(self.norm3(x)) + x
        return x


class FinalLayer(nn.Module):
    """
    Output head of RDT: an RMS norm followed by an MLP that maps each token
    from ``hidden_size`` to ``out_channels``.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        self.norm_final = RmsNorm(hidden_size, eps=1e-6)
        self.ffn_final = Mlp(
            in_features=hidden_size,
            hidden_features=hidden_size,
            out_features=out_channels,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x):
        """Normalise ``x`` and project it to ``out_channels`` features."""
        return self.ffn_final(self.norm_final(x))

六、 模型推理

agilex_inference.py

#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -- coding: UTF-8
"""
#!/usr/bin/python3
"""

import argparse
import sys
import threading
import time
import yaml
from collections import deque

import numpy as np
import rospy
import torch
from cv_bridge import CvBridge
from geometry_msgs.msg import Twist
from nav_msgs.msg import Odometry
from PIL import Image as PImage
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Header
import cv2

from agilex_model import create_model

# Keys of the per-frame image dict, in the order [front, right wrist, left wrist]
# (update_observation_window stores img_front / img_right / img_left under them).
CAMERA_NAMES = ['cam_high', 'cam_right_wrist', 'cam_left_wrist']

# Buffer of recent observations; created lazily as a deque(maxlen=2) by
# update_observation_window.
observation_window = None

# Language-instruction embeddings; loaded from disk in model_inference.
lang_embeddings = None

# debug: only referenced by the commented-out offline path in inference_fn
preload_images = None


# Initialize the model
def make_policy(args):
    """
    Build the RDT policy.

    Reads the YAML file at ``args.config_path``, stores the parsed result on
    ``args.config`` and passes it to ``create_model`` together with the
    checkpoint path and control frequency taken from ``args``.

    Returns:
        The object returned by ``create_model``.
    """
    with open(args.config_path, "r") as config_file:
        parsed_config = yaml.safe_load(config_file)
    args.config = parsed_config

    # The text encoder is not loaded here; only the vision encoder is named.
    # text encoder (unused): "google/t5-v1_1-xxl"
    vision_encoder = "google/siglip-so400m-patch14-384"
    return create_model(
        args=args.config,
        dtype=torch.bfloat16,
        pretrained=args.pretrained_model_name_or_path,
        pretrained_vision_encoder_name_or_path=vision_encoder,
        control_frequency=args.ctrl_freq,
    )


def set_seed(seed):
    """Seed both the NumPy and the PyTorch global random generators with ``seed``."""
    np.random.seed(seed)
    torch.manual_seed(seed)


# Interpolate the actions to make the robot move smoothly
def interpolate_action(args, prev_action, cur_action):
    """
    Split the move from ``prev_action`` to ``cur_action`` into evenly spaced
    intermediate actions so that no joint moves more than its step limit.

    ``args.arm_steps_length`` holds the per-joint step limit of one arm; it is
    repeated once so that it covers both arms.

    Returns:
        Array of shape (k, D) whose last row is ``cur_action``. ``k`` is 1
        when a single step is already within the limits.
    """
    per_arm_limit = np.array(args.arm_steps_length)
    limits = np.concatenate((per_arm_limit, per_arm_limit), axis=0)

    # Number of sub-steps needed by the joint that has to travel the furthest.
    steps_per_joint = np.ceil(np.abs(cur_action - prev_action) / limits).astype(int)
    n_steps = np.max(steps_per_joint)

    if n_steps > 1:
        # linspace includes prev_action as its first row; drop it.
        return np.linspace(prev_action, cur_action, n_steps + 1)[1:]
    return cur_action[np.newaxis, :]


def get_config(args):
    """
    Assemble the runtime configuration dict used by the inference loop.

    Returns:
        dict with ``episode_len`` (from ``args.max_publish_step``),
        ``state_dim`` (fixed to 14), ``chunk_size`` (from ``args.chunk_size``)
        and ``camera_names`` (the module-level ``CAMERA_NAMES``).
    """
    return dict(
        episode_len=args.max_publish_step,
        state_dim=14,
        chunk_size=args.chunk_size,
        camera_names=CAMERA_NAMES,
    )


# Get the observation from the ROS topic
def get_ros_observation(args,ros_operator):
    """
    Poll ``ros_operator.get_frame()`` until it yields a frame.

    Sleeps at ``args.publish_rate`` between attempts and prints one notice on
    the first failed attempt.

    Returns:
        ``(img_front, img_left, img_right, puppet_arm_left, puppet_arm_right)``;
        the depth images and the robot-base message contained in the frame are
        dropped. Returns ``None`` if ROS shuts down before a frame is obtained.
    """
    rate = rospy.Rate(args.publish_rate)
    already_warned = False

    while not rospy.is_shutdown():
        frame = ros_operator.get_frame()
        if frame:
            (img_front, img_left, img_right,
             img_front_depth, img_left_depth, img_right_depth,
             puppet_arm_left, puppet_arm_right, robot_base) = frame
            return (img_front, img_left, img_right,
                    puppet_arm_left, puppet_arm_right)

        # Not available yet: report once, then wait and retry.
        if not already_warned:
            print("syn fail when get_ros_observation")
            already_warned = True
        rate.sleep()


# Update the observation window buffer
def update_observation_window(args, config, ros_operator):
    """
    Fetch one observation from ROS and append it to the global
    ``observation_window``.

    On the first call the window is created as a ``deque(maxlen=2)`` and
    pre-filled with a placeholder entry whose ``qpos`` and images are ``None``.
    Each entry is a dict with ``'qpos'`` (left and right joint positions
    concatenated, as a float CUDA tensor) and ``'images'`` (a dict keyed by the
    first three names in ``config["camera_names"]``: front, right, left).
    """
    global observation_window

    def jpeg_round_trip(frame):
        # Encode to JPEG and decode again (to align with training, per the
        # original comment).
        encoded = cv2.imencode('.jpg', frame)[1].tobytes()
        return cv2.imdecode(np.frombuffer(encoded, np.uint8), cv2.IMREAD_COLOR)

    if observation_window is None:
        observation_window = deque(maxlen=2)
        # Placeholder "previous" frame so that index -2 exists after the first update.
        placeholder_images = {config["camera_names"][idx]: None for idx in range(3)}
        observation_window.append({'qpos': None, 'images': placeholder_images})

    img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = get_ros_observation(args,ros_operator)
    img_front, img_left, img_right = (
        jpeg_round_trip(img) for img in (img_front, img_left, img_right)
    )

    joint_positions = np.concatenate(
        (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)),
        axis=0,
    )
    qpos = torch.from_numpy(joint_positions).float().cuda()

    current_images = {
        config["camera_names"][0]: img_front,
        config["camera_names"][1]: img_right,
        config["camera_names"][2]: img_left,
    }
    observation_window.append({'qpos': qpos, 'images': current_images})


# RDT inference
def inference_fn(args, config, policy, t):
    """
    Run one policy inference on the two entries of ``observation_window``.

    Args:
        args: parsed command-line arguments (not read here).
        config: dict providing ``'camera_names'``.
        policy: object exposing ``step(proprio=, images=, text_embeds=)``.
        t: current time step (not read here).

    Returns:
        NumPy array of predicted actions with the leading batch axis removed
        ([64, 14] according to the original comment), or ``None`` if ROS has
        already been shut down.
    """
    global observation_window
    global lang_embeddings

    if rospy.is_shutdown():
        return None

    started_at = time.time()

    # Images in the order [front, right, left] for the previous frame, then the
    # same three for the latest frame.
    # (For offline debugging, frames can be taken from `preload_images` instead.)
    image_arrs = [
        observation_window[frame_idx]['images'][config['camera_names'][cam_idx]]
        for frame_idx in (-2, -1)
        for cam_idx in range(3)
    ]
    # The placeholder frame holds None images; pass those through unchanged.
    images = [None if arr is None else PImage.fromarray(arr)
              for arr in image_arrs]

    # Latest joint state with a batch axis added in front.
    proprio = observation_window[-1]['qpos'].unsqueeze(0)

    prediction = policy.step(
        proprio=proprio,
        images=images,
        text_embeds=lang_embeddings
    )
    actions = prediction.squeeze(0).cpu().numpy()

    print(f"Model inference time: {time.time() - started_at} s")
    return actions


# Main loop for the manipulation task
def model_inference(args, config, ros_operator):
    """
    Main control loop: load the policy and language embeddings, move the arms
    to the initial pose, then repeatedly observe, infer an action chunk and
    publish the actions one by one.

    Args:
        args: parsed command-line arguments (checkpoint and embedding paths,
            publish rate, interpolation / arm / base switches).
        config: dict with 'episode_len', 'chunk_size' and 'state_dim'.
        ros_operator: object used to read observations and publish commands.

    The function only returns when ROS is shut down.
    """
    global lang_embeddings
    
    # Load rdt model
    policy = make_policy(args)
    
    # The file is expected to hold a dict with 'instruction', 'name' and 'embeddings'.
    lang_dict = torch.load(args.lang_embeddings_path)
    print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
    lang_embeddings = lang_dict["embeddings"]
    
    max_publish_step = config['episode_len']
    chunk_size = config['chunk_size']

    # Initialize position of the puppet arm
    # left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
    # right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
    # left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
    # right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
    # NOTE(review): marked as "changed" by the article author; these 8-element
    # poses replace the 7-element ones commented out above - confirm that the
    # length matches the robot's joint state and args.arm_steps_length.
    left0 = [0, 0, 0, 0, 0, 0, 0, 1]
    right0 = [0, 0, 0, 0, 0, 0, 0, 1]
    left1 = [0, 0, 0, 0, 0, 0, 0, 1]
    right1 = [0, 0, 0, 0, 0, 0, 0, 1]

    # Move to the first pose, wait for the operator, then move to the second pose.
    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Press enter to continue")
    ros_operator.puppet_arm_publish_continuous(left1, right1)
    # Initialize the previous action to be the initial robot state
    pre_action = np.zeros(config['state_dim'])
    # NOTE(review): also marked as "changed" by the article author. The slice
    # [:16] is clamped to the array length, so this only works when
    # config['state_dim'] is 14 (the length of the right-hand side) - confirm.
    pre_action[:16] = np.array(
        [0, 0, 0, 0 ,0, 0, 0] +
        [0, 0, 0, 0, 0, 0, 0]
    )
    # pre_action[:14] = np.array(
    #     [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] +
    #     [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
    # )
    action = None
    # Inference loop
    with torch.inference_mode():
        # Outer loop: restart an episode whenever max_publish_step is reached.
        while True and not rospy.is_shutdown():
            # The current time step
            t = 0
            rate = rospy.Rate(args.publish_rate)
    
            action_buffer = np.zeros([chunk_size, config['state_dim']]) # state_dim 14
            
            while t < max_publish_step and not rospy.is_shutdown():
                # Update observation window
                update_observation_window(args, config, ros_operator)
                
                # When coming to the end of the action chunk
                if t % chunk_size == 0:
                    # Start inference
                    action_buffer = inference_fn(args, config, policy, t).copy()
                
                # Take the action for this step out of the current chunk.
                raw_action = action_buffer[t % chunk_size]
                action = raw_action
                # Interpolate the original action sequence
                if args.use_actions_interpolation:
                    # print(f"Time {t}, pre {pre_action}, act {action}")
                    interp_actions = interpolate_action(args, pre_action, action)
                else:
                    interp_actions = action[np.newaxis, :]
                # Execute the interpolated actions one by one
                for act in interp_actions:
                    # First 7 values drive the left arm, the next 7 the right arm.
                    left_action = act[:7]
                    right_action = act[7:14]
                    
                    if not args.disable_puppet_arm:
                        ros_operator.puppet_arm_publish(left_action, right_action)  # puppet_arm_publish_continuous_thread
                
                    if args.use_robot_base:
                        # NOTE(review): with a 14-dim action this slice is empty - confirm
                        # the action layout before enabling use_robot_base.
                        vel_action = act[14:16]
                        ros_operator.robot_base_publish(vel_action)
                    rate.sleep()
                    # print(f"doing action: {act}")
                t += 1
                
                print("Published Step", t)
                pre_action = action.copy()


# ROS operator class
class RosOperator:
    """
    Bridge between the policy loop and the robot's ROS topics.

    Incoming camera, joint-state and odometry messages are buffered in
    per-topic deques by the subscriber callbacks; ``get_frame`` pops one
    time-aligned set of messages from them. The ``puppet_arm_publish*`` and
    ``robot_base_publish`` methods send commands to the robot.
    """

    # Maximum number of buffered messages per topic; the oldest one is dropped first.
    _MAX_QUEUE_LEN = 2000
    # Joint names attached to every published JointState command.
    _JOINT_NAMES = ('joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6')

    def __init__(self, args):
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.init()
        self.init_ros()

    def init(self):
        """Create the image bridge, the message buffers and the publish lock."""
        # Converts ROS image messages to OpenCV images.
        self.bridge = CvBridge()

        # Colour image buffers, one per camera.
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()

        # Depth image buffers, one per camera.
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()

        # Joint-state buffers for the two arms.
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()

        # Odometry buffer for the mobile base.
        self.robot_base_deque = deque()

        # The lock is used as a stop signal for puppet_arm_publish_continuous:
        # it starts out held, and releasing it tells a running publish loop to exit.
        self.puppet_arm_publish_lock = threading.Lock()
        self.puppet_arm_publish_lock.acquire()

    # ------------------------------------------------------------------ helpers

    def _publish_joint_targets(self, left, right):
        """Publish ``left`` / ``right`` joint positions with one shared, freshly stamped message."""
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()
        joint_state_msg.name = list(self._JOINT_NAMES)
        joint_state_msg.position = left
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right
        self.puppet_arm_right_publisher.publish(joint_state_msg)

    def _wait_for_arm_state(self, rate):
        """Block until both arms have reported a joint state; return copies of the latest positions.

        Either element of the returned pair is ``None`` if ROS shut down first.
        """
        left_arm = None
        right_arm = None
        while not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is not None and right_arm is not None:
                break
            rate.sleep()
        return left_arm, right_arm

    def _step_towards(self, current, target, signs):
        """Move ``current`` in place one bounded step towards ``target``; return True if any joint is still short."""
        still_moving = False
        for i, goal in enumerate(target):
            max_step = self.args.arm_steps_length[i]
            if abs(goal - current[i]) < max_step:
                # Close enough: snap to the target.
                current[i] = goal
            else:
                # Too far for a single step: advance by the per-joint limit.
                current[i] += signs[i] * max_step
                still_moving = True
        return still_moving

    def _push_bounded(self, queue, msg):
        """Append ``msg`` to ``queue``, dropping the oldest entry once the queue is full."""
        if len(queue) >= self._MAX_QUEUE_LEN:
            queue.popleft()
        queue.append(msg)

    @staticmethod
    def _pop_synced(queue, frame_time):
        """Discard messages older than ``frame_time`` and pop the first one that is not."""
        while queue[0].header.stamp.to_sec() < frame_time:
            queue.popleft()
        return queue.popleft()

    def _pop_image(self, queue, frame_time):
        """Like ``_pop_synced`` but convert the image message to an OpenCV image."""
        return self.bridge.imgmsg_to_cv2(self._pop_synced(queue, frame_time), 'passthrough')

    # --------------------------------------------------------------- publishing

    def puppet_arm_publish(self, left, right):
        """Publish one joint-position command for each arm."""
        self._publish_joint_targets(left, right)

    def robot_base_publish(self, vel):
        """Publish a base velocity command: ``vel[0]`` is linear x, ``vel[1]`` is angular z."""
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)

    def puppet_arm_publish_continuous(self, left, right):
        """
        Drive both arms from their current joint state to the targets
        ``left`` / ``right`` in small steps.

        Each iteration moves every joint by at most its entry in
        ``args.arm_steps_length`` and publishes the intermediate pose at
        ``args.publish_rate``. The loop ends when every joint has reached its
        target, when ROS shuts down, or when the publish lock has been
        released by another caller (stop request).
        """
        rate = rospy.Rate(self.args.publish_rate)

        # left/right are the targets; left_arm/right_arm track the commanded pose,
        # starting from the most recent measured joint state.
        left_arm, right_arm = self._wait_for_arm_state(rate)
        if left_arm is None or right_arm is None:
            # ROS shut down before any joint state arrived: nothing to do.
            return

        # Direction of travel for each joint, fixed at the start of the motion.
        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]

        flag = True
        step = 0
        while flag and not rospy.is_shutdown():
            # A successful non-blocking acquire means the lock was released as a stop signal.
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_moving = self._step_towards(left_arm, left, left_symbol)
            right_moving = self._step_towards(right_arm, right, right_symbol)
            flag = left_moving or right_moving
            self._publish_joint_targets(left_arm, right_arm)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()

    def puppet_arm_publish_linear(self, left, right):
        """
        Drive both arms to ``left`` / ``right`` along a linear joint-space
        trajectory of 100 points published at 200 Hz.

        The last joint value is set to its target on every point instead of
        being interpolated.
        """
        num_step = 100
        rate = rospy.Rate(200)

        left_arm, right_arm = self._wait_for_arm_state(rate)
        if left_arm is None or right_arm is None:
            # ROS shut down before any joint state arrived: nothing to do.
            return

        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)

        for traj_left, traj_right in zip(traj_left_list, traj_right_list):
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            self._publish_joint_targets(traj_left, traj_right)
            rate.sleep()

    def puppet_arm_publish_continuous_thread(self, left, right):
        """Run ``puppet_arm_publish_continuous`` in a background thread, stopping a previous one first."""
        if self.puppet_arm_publish_thread is not None:
            # Releasing the lock asks the running loop to exit; it re-acquires the lock on its way out.
            self.puppet_arm_publish_lock.release()
            self.puppet_arm_publish_thread.join()
            # Make sure the lock is held again even if the thread had already finished.
            self.puppet_arm_publish_lock.acquire(False)
            self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
        self.puppet_arm_publish_thread.start()

    # ------------------------------------------------------------ synchronising

    def get_frame(self):
        """
        Pop one time-aligned set of messages from the buffers.

        The frame time is the smallest "latest timestamp" over all enabled
        image streams (the three colour cameras, plus the three depth streams
        when ``args.use_depth_image`` is set). Every required buffer must hold
        a message at least that recent; older messages are discarded.

        Returns:
            ``False`` when the buffers cannot provide a frame yet, otherwise
            ``(img_front, img_left, img_right, img_front_depth,
            img_left_depth, img_right_depth, puppet_arm_left,
            puppet_arm_right, robot_base)``. Disabled entries are ``None``.
        """
        use_depth = self.args.use_depth_image
        use_base = self.args.use_robot_base

        image_queues = [self.img_left_deque, self.img_right_deque, self.img_front_deque]
        if use_depth:
            image_queues += [self.img_left_depth_deque, self.img_right_depth_deque,
                             self.img_front_depth_deque]
        if any(len(queue) == 0 for queue in image_queues):
            return False

        # Fix: the front camera is now part of the minimum in both branches.
        # Previously the colour-only case used the right camera twice.
        frame_time = min(queue[-1].header.stamp.to_sec() for queue in image_queues)

        required_queues = image_queues + [self.puppet_arm_left_deque, self.puppet_arm_right_deque]
        if use_base:
            required_queues.append(self.robot_base_deque)
        for queue in required_queues:
            if len(queue) == 0 or queue[-1].header.stamp.to_sec() < frame_time:
                return False

        img_left = self._pop_image(self.img_left_deque, frame_time)
        img_right = self._pop_image(self.img_right_deque, frame_time)
        img_front = self._pop_image(self.img_front_deque, frame_time)

        puppet_arm_left = self._pop_synced(self.puppet_arm_left_deque, frame_time)
        puppet_arm_right = self._pop_synced(self.puppet_arm_right_deque, frame_time)

        img_left_depth = None
        img_right_depth = None
        img_front_depth = None
        if use_depth:
            img_left_depth = self._pop_image(self.img_left_depth_deque, frame_time)
            img_right_depth = self._pop_image(self.img_right_depth_deque, frame_time)
            img_front_depth = self._pop_image(self.img_front_depth_deque, frame_time)

        robot_base = None
        if use_base:
            robot_base = self._pop_synced(self.robot_base_deque, frame_time)

        return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
                puppet_arm_left, puppet_arm_right, robot_base)

    # ---------------------------------------------------------------- callbacks

    def img_left_callback(self, msg):
        """Buffer a left colour image."""
        self._push_bounded(self.img_left_deque, msg)

    def img_right_callback(self, msg):
        """Buffer a right colour image."""
        self._push_bounded(self.img_right_deque, msg)

    def img_front_callback(self, msg):
        """Buffer a front colour image."""
        self._push_bounded(self.img_front_deque, msg)

    def img_left_depth_callback(self, msg):
        """Buffer a left depth image."""
        self._push_bounded(self.img_left_depth_deque, msg)

    def img_right_depth_callback(self, msg):
        """Buffer a right depth image."""
        self._push_bounded(self.img_right_depth_deque, msg)

    def img_front_depth_callback(self, msg):
        """Buffer a front depth image."""
        self._push_bounded(self.img_front_depth_deque, msg)

    def puppet_arm_left_callback(self, msg):
        """Buffer a left-arm joint state."""
        self._push_bounded(self.puppet_arm_left_deque, msg)

    def puppet_arm_right_callback(self, msg):
        """Buffer a right-arm joint state."""
        self._push_bounded(self.puppet_arm_right_deque, msg)

    def robot_base_callback(self, msg):
        """Buffer a base odometry message."""
        self._push_bounded(self.robot_base_deque, msg)

    # -------------------------------------------------------------------- setup

    def init_ros(self):
        """
        Initialise the ROS node, the subscribers and the publishers.

        Topic names are taken from ``self.args``. Depth topics are only
        subscribed when ``args.use_depth_image`` is set.
        """
        # Anonymous node, so several instances can run side by side.
        rospy.init_node('joint_state_publisher', anonymous=True)

        # Colour images from the left, right and front cameras.
        rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)

        # Depth images, only when enabled.
        if self.args.use_depth_image:
            rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)

        # Arm joint states and base odometry.
        rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)

        # Command publishers for the arms and the base.
        self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
        self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
        self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)

def get_arguments(argv=None):
    """Build the command-line parser for the inference script and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches
            the behavior of calling this function with no arguments.

    Returns:
        argparse.Namespace: The parsed options.

    Note:
        ``--pretrained_model_name_or_path`` and ``--lang_embeddings_path``
        are required; argparse exits with a usage error if either is missing.
    """
    parser = argparse.ArgumentParser()

    # Upper bound on the number of action publishing steps.
    parser.add_argument('--max_publish_step', type=int, default=10000,
                        help='Maximum number of action publishing steps')

    # Random seed; None (the default) means no seed was requested.
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed')

    # Color image topics for the front / left / right cameras.
    parser.add_argument('--img_front_topic', type=str, help='img_front_topic',
                        default='/camera_f/color/image_raw')
    parser.add_argument('--img_left_topic', type=str, help='img_left_topic',
                        default='/camera_l/color/image_raw')
    parser.add_argument('--img_right_topic', type=str, help='img_right_topic',
                        default='/camera_r/color/image_raw')

    # Depth image topics for the front / left / right cameras.
    parser.add_argument('--img_front_depth_topic', type=str,
                        help='img_front_depth_topic',
                        default='/camera_f/depth/image_raw')
    parser.add_argument('--img_left_depth_topic', type=str,
                        help='img_left_depth_topic',
                        default='/camera_l/depth/image_raw')
    parser.add_argument('--img_right_depth_topic', type=str,
                        help='img_right_depth_topic',
                        default='/camera_r/depth/image_raw')

    # Command topics for the left / right arm.
    parser.add_argument('--puppet_arm_left_cmd_topic', type=str,
                        help='puppet_arm_left_cmd_topic',
                        default='/master/joint_left')
    parser.add_argument('--puppet_arm_right_cmd_topic', type=str,
                        help='puppet_arm_right_cmd_topic',
                        default='/master/joint_right')
    # State topics for the left / right arm.
    parser.add_argument('--puppet_arm_left_topic', type=str,
                        help='puppet_arm_left_topic',
                        default='/puppet/joint_left')
    parser.add_argument('--puppet_arm_right_topic', type=str,
                        help='puppet_arm_right_topic',
                        default='/puppet/joint_right')

    # Robot base state topic.
    parser.add_argument('--robot_base_topic', type=str,
                        help='robot_base_topic',
                        default='/odom_raw')
    # Robot base command topic (help text previously duplicated the
    # state-topic name, which made --help misleading).
    parser.add_argument('--robot_base_cmd_topic', type=str,
                        help='robot_base_cmd_topic',
                        default='/cmd_vel')
    # Whether the robot base is used to move around.
    parser.add_argument('--use_robot_base', action='store_true', default=False,
                        help='Whether to use the robot base to move around')
    # Rate at which actions are published.
    parser.add_argument('--publish_rate', type=int, default=30,
                        help='The rate at which to publish the actions')
    # Control frequency of the robot.
    parser.add_argument('--ctrl_freq', type=int, default=25,
                        help='The control frequency of the robot')

    # Action chunk size.
    parser.add_argument('--chunk_size', type=int, default=64,
                        help='Action chunk size')
    # Per-joint maximum change per timestep. The default is a list, so the
    # option must accept one or more floats; without nargs a command-line
    # override would produce a single scalar instead of a per-joint list.
    parser.add_argument('--arm_steps_length', type=float, nargs='+',
                        help='The maximum change allowed for each joint per timestep',
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2])

    # Whether to interpolate actions when consecutive ones differ too much.
    parser.add_argument('--use_actions_interpolation', action='store_true',
                        default=False,
                        help='Whether to interpolate the actions if the difference is too large')
    # Whether depth images are used.
    parser.add_argument('--use_depth_image', action='store_true', default=False,
                        help='Whether to use depth images')

    # Disable the puppet arm, e.g. for safe debugging.
    parser.add_argument('--disable_puppet_arm', action='store_true', default=False,
                        help='Whether to disable the puppet arm. This is useful for safely debugging')

    # Path to the config file.
    parser.add_argument('--config_path', type=str, default="configs/base.yaml",
                        help='Path to the config file')
    # NOTE(review): the option below is disabled in the original source;
    # presumably unused or not yet implemented -- confirm before removing.
    # parser.add_argument('--cfg_scale', type=float, default=2.0,
    #                     help='the scaling factor used to modify the magnitude of the control features during denoising')
    # Name or path of the pretrained model (required).
    parser.add_argument('--pretrained_model_name_or_path', type=str, required=True,
                        help='Name or path to the pretrained model')

    # Path to the pre-encoded language instruction embeddings (required).
    parser.add_argument('--lang_embeddings_path', type=str, required=True,
                        help='Path to the pre-encoded language instruction embeddings')

    return parser.parse_args(argv)


def main():
    """Script entry point.

    Parses the command-line options, constructs a ``RosOperator`` from them,
    calls ``set_seed`` only when ``--seed`` was supplied, loads the config
    via ``get_config`` and passes everything to ``model_inference``.
    """
    cli_args = get_arguments()
    operator = RosOperator(cli_args)

    # A seed of None means the user did not request seeding.
    seed = cli_args.seed
    if seed is not None:
        set_seed(seed)

    settings = get_config(cli_args)
    model_inference(cli_args, settings, operator)


# Run the entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

©著作權歸作者所有,轉載或內容合作請聯系作者
【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

相關閱讀更多精彩內容

友情鏈接更多精彩內容