Vision2Seq Backend

Vision2SeqBackend

Bases: Backend

Vision2SeqBackend runs locally to generate robot actions.

Beware of the memory requirements of 7B+ parameter models like OpenVLA: in bfloat16, a 7B model needs roughly 14 GB of memory for the weights alone.

Attributes:

    model_id (str): The model to use for the OpenVLA backend.
    device (torch.device): The device to run the model on.
    torch_dtype (torch.dtype): The torch data type to use.
    processor (AutoProcessor): The processor for the model.
    model (AutoModelForVision2Seq): The model for the OpenVLA backend.
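These are set through the constructor's keyword arguments (see the source below). A minimal override sketch; the values shown are illustrative, and the import path follows the source file location given below:

import torch

from mbodied.agents.backends.vision2seq_backend import Vision2SeqBackend

# Illustrative overrides: run on CPU with eager attention and float32,
# trading speed for portability on machines without a CUDA GPU.
backend = Vision2SeqBackend(
    model_id="openvla/openvla-7b",
    attn_implementation="eager",
    torch_dtype=torch.float32,
    device="cpu",
)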

Source code in mbodied/agents/backends/vision2seq_backend.py
class Vision2SeqBackend(Backend):
    """Vision2SeqBackend backend that runs locally to generate robot actions.

    Beware of the memory requirements of 7B+ parameter models like OpenVLA.

    Attributes:
        model_id (str): The model to use for the OpenVLA backend.
        device (torch.device): The device to run the model on.
        torch_dtype (torch.dtype): The torch data type to use.
        processor (AutoProcessor): The processor for the model.
        model (AutoModelForVision2Seq): The model for the OpenVLA backend.
    """

    # Prefer CUDA defaults (flash attention, bfloat16) when a GPU is available.
    DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    ATTN_IMPLEMENTATION = "flash_attention_2" if torch.cuda.is_available() else "eager"
    DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float16

    def __init__(
        self,
        model_id: str = "openvla/openvla-7b",
        attn_implementation: Literal["flash_attention_2", "eager"] = ATTN_IMPLEMENTATION,
        torch_dtype: torch.dtype = DTYPE,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> None:
        # Import transformers lazily so the heavy dependency is only loaded
        # when this backend is actually constructed.
        smart_import("transformers")
        from transformers import AutoModelForVision2Seq, AutoProcessor

        self.model_id = model_id
        self.device = device
        self.torch_dtype = torch_dtype
        # Load Processor & VLA
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            attn_implementation=attn_implementation,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=torch.cuda.is_available(),
            trust_remote_code=True,
            **kwargs,
        ).to(device)

    def predict(self, instruction: str, image: Image, unnorm_key: str = "bridge_orig") -> str:
        # OpenVLA expects this exact prompt template around the instruction.
        prompt = f"In: What action should the robot take to {instruction}?\nOut:"
        inputs = self.processor(prompt, image.pil).to(self.device, dtype=self.torch_dtype)
        # Greedy decoding; unnorm_key selects the dataset statistics used to
        # un-normalize the predicted action.
        response = self.model.predict_action(**inputs, unnorm_key=unnorm_key, do_sample=False)
        return str(response)
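
A minimal usage sketch. The Image import path and construction from a file path are assumptions about mbodied's Image type, and "frame.png" is a hypothetical camera frame; the prompt and un-normalization behavior follow the predict method above.

from mbodied.agents.backends.vision2seq_backend import Vision2SeqBackend
from mbodied.types.sense.vision import Image  # assumed import path for mbodied's Image type

backend = Vision2SeqBackend()  # defaults: openvla/openvla-7b, CUDA if available

# "frame.png" stands in for a camera frame from your robot.
action = backend.predict(
    "pick up the red block",
    Image("frame.png"),
    unnorm_key="bridge_orig",  # un-normalize with BridgeData action statistics
)
print(action)  # stringified action vector, e.g. a 7-DoF end-effector delta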