
VLM RL Training

Train Vision-Language Models (VLMs) with Proximal Policy Optimization (PPO) on visual reasoning tasks. Inspired by vlm-gym.

Overview

This workflow demonstrates:

  • Fine-tuning pretrained VLMs (Qwen2-VL) with reinforcement learning
  • Creating visual environments with reward signals
  • PPO training loop in PyTorch
  • Curriculum learning with progressive difficulty stages
  • Checkpointing and evaluation via mixtrain

Architecture

┌─────────────────────────────────────────────────────────────┐
│                     VLM RL Training                         │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐               │
│  │   VLM    │───▶│  Action  │───▶│   Env    │               │
│  │ (Actor)  │    │  (Text)  │    │ (Reward) │               │
│  └──────────┘    └──────────┘    └──────────┘               │
│       │                               │                     │
│       │          ┌──────────┐         │                     │
│       └─────────▶│  Value   │◀────────┘                     │
│                  │  Head    │                               │
│                  └──────────┘                               │
│                       │                                     │
│                  ┌──────────┐                               │
│                  │   PPO    │                               │
│                  │  Update  │                               │
│                  └──────────┘                               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Configuration

from mixtrain import Dataset, Eval, MixFlow, Sandbox


class VLMRLTraining(MixFlow):
    _sandbox = Sandbox(
        image="nvcr.io/nvidia/pytorch:25.01-py3",
        gpu="A100",
        gpu_per_node=4,
        timeout=14400,  # 4 hours
    )

    def run(
        self,
        env_name: str = "geospot",
        model_name: str = "Qwen/Qwen2-VL-7B-Instruct",
        total_steps: int = 10000,
        learning_rate: float = 5e-7,
        curriculum_enabled: bool = True,
        env_dataset: Dataset | None = None,
    ):
        ...  # training logic elided

Environments

GeoSpot (GeoGuessr-style)

Predict location from street view images. The model outputs country, region, city, and coordinates.

# Reward structure
reward = (
    0.3 * country_match +
    0.2 * region_match +
    0.2 * city_match +
    0.3 * coordinate_accuracy  # exponential decay by distance
)
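The reward formula above leaves `country_match`, `region_match`, `city_match`, and `coordinate_accuracy` undefined. The sketch below shows one way they could be computed. It assumes case-insensitive exact matching for the text fields, great-circle (haversine) distance for coordinates, and a decay scale of 25 km (the final curriculum tolerance). `haversine_km`, `field_match`, `coordinate_accuracy`, and `geospot_reward` are hypothetical helpers, not part of mixtrain.

```python
import math

EARTH_RADIUS_KM = 6371.0


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in kilometres between two (lat, lon) points in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    # min() guards against tiny floating-point overshoot above 1.0
    return 2 * EARTH_RADIUS_KM * math.asin(min(1.0, math.sqrt(a)))


def coordinate_accuracy(pred: dict, truth: dict, scale_km: float = 25.0) -> float:
    """Exponential decay with distance: 1.0 at 0 km, about 0.37 at scale_km (assumed scale)."""
    d = haversine_km(pred["lat"], pred["lon"], truth["lat"], truth["lon"])
    return math.exp(-d / scale_km)


def field_match(pred: dict, truth: dict, key: str) -> float:
    """1.0 on a case-insensitive exact match; 0.0 if either side is missing or empty."""
    p, t = pred.get(key), truth.get(key)
    if not p or not t:
        return 0.0
    return float(p.strip().lower() == t.strip().lower())


def geospot_reward(pred: dict, truth: dict, scale_km: float = 25.0) -> float:
    """Weighted sum using the documented weights (0.3 + 0.2 + 0.2 + 0.3 = 1.0)."""
    return (
        0.3 * field_match(pred, truth, "country")
        + 0.2 * field_match(pred, truth, "region")
        + 0.2 * field_match(pred, truth, "city")
        + 0.3 * coordinate_accuracy(pred, truth, scale_km)
    )
```

Because the coordinate term decays exponentially, a guess on the wrong continent contributes essentially nothing, while a guess within a few tolerance radii still earns partial credit.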

Curriculum stages:

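| Stage | Steps     | Task                    | Tolerance |
|-------|-----------|-------------------------|-----------|
| 1     | 0-100     | Country only            | 500 km    |
| 2     | 100-300   | Country (refined)       | 200 km    |
| 3     | 300-600   | Country + Region        | 100 km    |
| 4     | 600-1000  | Country + Region + City | 50 km     |
| 5     | 1000+     | Full (with coordinates) | 25 km     |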
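A curriculum like the one in the table can be driven by a simple step-to-stage lookup. This is a sketch that assumes the step ranges are half-open (a boundary step such as 100 belongs to the later stage) and that "Steps" counts training steps; `Stage`, `STAGES`, and `stage_for_step` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stage:
    number: int
    start_step: int       # inclusive
    end_step: float       # exclusive; inf for the open-ended final stage
    task: str
    tolerance_km: float


# Values taken from the curriculum table; task labels are illustrative.
STAGES = (
    Stage(1, 0, 100, "country", 500.0),
    Stage(2, 100, 300, "country (refined)", 200.0),
    Stage(3, 300, 600, "country + region", 100.0),
    Stage(4, 600, 1000, "country + region + city", 50.0),
    Stage(5, 1000, float("inf"), "full (with coordinates)", 25.0),
)


def stage_for_step(step: int) -> Stage:
    """Return the curriculum stage active at a given training step."""
    if step < 0:
        raise ValueError("step must be non-negative")
    for stage in STAGES:
        if stage.start_step <= step < stage.end_step:
            return stage
    raise RuntimeError("no stage covers this step")  # unreachable: last stage is open-ended
```

The returned `tolerance_km` could then be passed as the decay scale of the coordinate reward, so the reward tightens as the curriculum advances.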

Visual QA

Answer verifiable questions about images (counting, attributes, existence).

# Example questions
{"question": "How many people?", "answer": "3", "type": "counting"}
{"question": "What color is the car?", "answer": "red", "type": "attribute"}
{"question": "Is there a dog?", "answer": "yes", "type": "existence"}
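Because these answers are short and verifiable, a binary reward can be computed by comparing the model's response with the ground-truth `answer`. The page does not say how responses are scored; the sketch below assumes exact match after light normalization (lowercasing, stripping punctuation and articles, mapping number words to digits). `normalize` and `vqa_reward` are hypothetical helpers.

```python
import re

# Map small number words to digits so "three" matches "3" in counting questions.
_WORD_TO_DIGIT = {
    "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4", "five": "5",
    "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10",
}


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.strip().lower()
    text = re.sub(r"[^\w\s]", "", text)          # remove punctuation
    text = re.sub(r"\b(a|an|the)\b", " ", text)  # remove articles
    text = " ".join(text.split())
    return _WORD_TO_DIGIT.get(text, text)


def vqa_reward(response: str, example: dict) -> float:
    """1.0 if the normalized response equals the normalized answer, else 0.0."""
    return 1.0 if normalize(response) == normalize(example["answer"]) else 0.0
```

Exact match is strict: a full-sentence reply such as "The car is red" scores 0.0 against "red", which pushes the policy toward terse answers. A more lenient matcher is a design choice the workflow may or may not make.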

PPO Training

The workflow implements standard PPO with:

  • Actor: VLM generates text responses
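  • Critic: Value head predicts the expected return (discounted future reward)
  • GAE: Generalized Advantage Estimation, which computes advantages from rewards and value estimates
  • Clipping: The probability ratio between the new and old policy is clipped to [1 - clip_epsilon, 1 + clip_epsilon], preventing large policy updates
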
# PPO hyperparameters
ppo_epochs: int = 4          # Update epochs per rollout
clip_epsilon: float = 0.2    # PPO clipping parameter
value_coef: float = 0.5      # Value loss coefficient
entropy_coef: float = 0.01   # Entropy bonus
gamma: float = 0.99          # Discount factor
gae_lambda: float = 0.95     # GAE lambda
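The page lists the PPO hyperparameters but not how they are used. Below is a plain-Python sketch of the standard formulation: GAE over one rollout (using `gamma` and `gae_lambda`) and the per-sample clipped PPO loss (using `clip_epsilon`, `value_coef`, and `entropy_coef`). It is not the workflow's actual code, which would operate on batched PyTorch tensors; for a VLM, `new_logp` and `old_logp` would plausibly be the summed log-probabilities of the generated response tokens, which is an assumption here. `compute_gae` and `ppo_loss` are hypothetical names.

```python
import math


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over one rollout.

    values[t] is V(s_t); last_value bootstraps the value of the state after the
    final step. dones[t] is True if the episode ended at step t.
    Returns (advantages, returns), where returns are the value-function targets.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    next_value = last_value
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t), no bootstrap past episode end
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


def ppo_loss(new_logp, old_logp, advantage, value, ret, entropy,
             clip_epsilon=0.2, value_coef=0.5, entropy_coef=0.01):
    """Per-sample loss to minimize: clipped surrogate + value loss - entropy bonus."""
    ratio = math.exp(new_logp - old_logp)
    clipped_ratio = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
    # Pessimistic (minimum) of the unclipped and clipped objectives, negated for descent
    policy_loss = -min(ratio * advantage, clipped_ratio * advantage)
    value_loss = (value - ret) ** 2
    return policy_loss + value_coef * value_loss - entropy_coef * entropy
```

Taking the minimum of the clipped and unclipped terms means the policy gains nothing from pushing the ratio beyond 1 ± clip_epsilon in the direction the advantage favors, which is what keeps each update small. The `ppo_epochs` setting would control how many times this loss is minimized over the same rollout.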

Custom Environment Dataset

Provide your own images and ground truth:

import pandas as pd

from mixtrain import Dataset, Image  # assumed: Image is exported by mixtrain as an image column type

# Create environment dataset
env_data = Dataset.from_pandas(pd.DataFrame([
    {"image_url": "s3://bucket/image1.jpg", "country": "France", "city": "Paris", "lat": 48.86, "lon": 2.35},
    {"image_url": "s3://bucket/image2.jpg", "country": "Japan", "city": "Tokyo", "lat": 35.68, "lon": 139.65},
    # ...
])).save("geospot-training-data", column_types={"image_url": Image})

Then run training with the custom data:

mixtrain workflow run vlm-rl-training \
  --input '{"env_dataset": "geospot-training-data", "total_steps": 50000}'

Running

Basic training

mixtrain workflow run vlm-rl-training \
  --input '{"env_name": "geospot", "total_steps": 10000}'

With curriculum learning

mixtrain workflow run vlm-rl-training \
  --input '{
    "env_name": "geospot",
    "curriculum_enabled": true,
    "total_steps": 50000,
    "checkpoint_interval": 1000
  }'

Resume from checkpoint

mixtrain workflow run vlm-rl-training \
  --input '{
    "resume_checkpoint": "vlm-rl-checkpoint-step-5000",
    "total_steps": 50000
  }'

Evaluation

Run the companion evaluation workflow:

mixtrain workflow run vlm-rl-eval \
  --input '{
    "env_name": "geospot",
    "trained_model": "my-trained-vlm",
    "baseline_model": "qwen2-vl-7b-instruct",
    "num_episodes": 100
  }'

Outputs

The workflow returns a dictionary like the following (the numeric values are illustrative):

{
    "final_reward": 0.75,           # Mean reward over the last 100 episodes
    "total_steps": 10000,
    "total_episodes": 625,
    "metrics_dataset": Dataset,     # Training metrics over time
    "evaluation": Eval,             # Evaluation results, viewable in the UI
}
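`final_reward` is described as the average over the last 100 episodes. One simple way to track such a windowed mean during training is a fixed-length `collections.deque`; `RollingReward` below is a hypothetical helper for illustration, not part of mixtrain.

```python
from collections import deque


class RollingReward:
    """Mean reward over the most recent `window` completed episodes."""

    def __init__(self, window: int = 100):
        # A deque with maxlen drops the oldest entry once the window is full.
        self.rewards = deque(maxlen=window)

    def add(self, episode_reward: float) -> None:
        self.rewards.append(episode_reward)

    @property
    def mean(self) -> float:
        return sum(self.rewards) / len(self.rewards) if self.rewards else 0.0
```

Calling `add()` once per finished episode and reading `mean` at the end of training yields a number with the same meaning as `final_reward`, under the assumption that rewards are recorded per episode rather than per step.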
