From 100293ffc0a9b0d47797d51f738e1da84d99167b Mon Sep 17 00:00:00 2001 From: Yan Bai Date: Mon, 11 May 2026 06:17:52 -0700 Subject: [PATCH] Add VerlBB recipe --- README.md | 2 + verlbb/README.md | 82 ++ verlbb/REQUIRED_VERL.txt | 13 + verlbb/scripts/run_qwen3moe_gsm8k_grpo.sh | 287 +++++++ verlbb/scripts/run_qwen3moe_sft.sh | 252 ++++++ verlbb/verlbb/__init__.py | 1 + verlbb/verlbb/config/__init__.py | 1 + verlbb/verlbb/config/actor/__init__.py | 1 + .../verlbb/config/actor/bumblebee_actor.yaml | 10 + verlbb/verlbb/config/engine/__init__.py | 1 + verlbb/verlbb/config/engine/bumblebee.yaml | 24 + verlbb/verlbb/engine/__init__.py | 5 + verlbb/verlbb/engine/bumblebee_engine.py | 747 ++++++++++++++++++ verlbb/verlbb/engine/config.py | 33 + 14 files changed, 1459 insertions(+) create mode 100644 verlbb/README.md create mode 100644 verlbb/REQUIRED_VERL.txt create mode 100755 verlbb/scripts/run_qwen3moe_gsm8k_grpo.sh create mode 100755 verlbb/scripts/run_qwen3moe_sft.sh create mode 100644 verlbb/verlbb/__init__.py create mode 100644 verlbb/verlbb/config/__init__.py create mode 100644 verlbb/verlbb/config/actor/__init__.py create mode 100644 verlbb/verlbb/config/actor/bumblebee_actor.yaml create mode 100644 verlbb/verlbb/config/engine/__init__.py create mode 100644 verlbb/verlbb/config/engine/bumblebee.yaml create mode 100644 verlbb/verlbb/engine/__init__.py create mode 100644 verlbb/verlbb/engine/bumblebee_engine.py create mode 100644 verlbb/verlbb/engine/config.py diff --git a/README.md b/README.md index 93e1d790..4218a6d2 100644 --- a/README.md +++ b/README.md @@ -90,12 +90,14 @@ The script requires only `bash`, `git`, `awk`, and `pip`/`pip3` on `PATH`. 
It do | spo | [`recipe/spo/REQUIRED_VERL.txt`](spo/REQUIRED_VERL.txt) | | sppo | [`recipe/sppo/REQUIRED_VERL.txt`](sppo/REQUIRED_VERL.txt) | | swe_agent | [`recipe/swe_agent/REQUIRED_VERL.txt`](swe_agent/REQUIRED_VERL.txt) | +| verlbb | [`recipe/verlbb/REQUIRED_VERL.txt`](verlbb/REQUIRED_VERL.txt) | ## Available Recipes (high level) - [retool](https://github.com/verl-project/verl-recipe/tree/main/retool): Reinforcement Learning for Strategic Tool Use in LLMs - [langgraph_agent](https://github.com/verl-project/verl-recipe/tree/main/langgraph_agent): A tiny example to demonstrate multi-turn rollout with [LangGraph ReactAgent](https://langchain-ai.github.io/langgraph/agents/overview/) to solve math expression. - [spo](https://github.com/verl-project/verl-recipe/tree/main/spo): [Single-stream Policy Optimization](https://arxiv.org/abs/2509.13232). +- [verlbb](https://github.com/verl-project/verl-recipe/tree/main/verlbb): Bumblebee as an external verl engine for SFT and GRPO. - TBA... ## Contribution diff --git a/verlbb/README.md b/verlbb/README.md new file mode 100644 index 00000000..65d05bd1 --- /dev/null +++ b/verlbb/README.md @@ -0,0 +1,82 @@ +# VerlBB + +## Required `verl` version + +See [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt) for the upstream repository, install mode, and copy-pastable `pip` / `git` instructions. + +## Overview + +VerlBB is a recipe for running verl SFT and GRPO with Bumblebee as an external training engine. The recipe keeps the integration thin: + +- verl still owns trainers, datasets, rollout orchestration, and algorithm logic. +- Bumblebee owns model construction, parallelism, optimizer/offload, checkpoint, and weight export. +- The adapter package under [`verlbb/`](verlbb/) only registers `strategy=bumblebee` and translates verl batches into Bumblebee's THD/no-padding runtime contract. + +The included scripts also keep `BACKEND=megatron` as a reference path so the same data/model settings can be compared against verl's Megatron backend. 
+ +## Prerequisites + +Install or expose these packages before running the scripts: + +- `verl` at the version described in [`REQUIRED_VERL.txt`](REQUIRED_VERL.txt). +- Bumblebee runtime, either installed in the active environment or exposed with `BUMBLEBEE_ROOT=/path/to/bumblebee`. +- Megatron-LM and mbridge when using `BACKEND=megatron`, exposed with `MEGATRON_ROOT` and `MBRIDGE_ROOT` if they are not installed. + +Optional source-tree overrides: + +```bash +export VERL_ROOT=/path/to/verl +export BUMBLEBEE_ROOT=/path/to/bumblebee +export MEGATRON_ROOT=/path/to/Megatron-LM +export MBRIDGE_ROOT=/path/to/mbridge +``` + +The scripts do not launch through a cluster scheduler. For multi-node runs, start the same script per node with the standard `torchrun` or Ray environment variables for your environment. + +## SFT + +The SFT script expects a messages parquet input and uses verl's native SFT trainer. + +```bash +export MODEL_PATH=/path/to/qwen3-moe-hf +export TRAIN_FILES=/path/to/train.parquet +export VAL_FILES=/path/to/val.parquet + +BACKEND=bumblebee bash verlbb/scripts/run_qwen3moe_sft.sh +``` + +Common knobs: + +- `BACKEND=bumblebee|megatron` +- `TP_SIZE`, `PP_SIZE`, `VPP_SIZE`, `CP_SIZE`, `EP_SIZE`, `ETP_SIZE` +- `TOTAL_STEPS`, `TRAIN_BATCH_SIZE`, `MICRO_BATCH_SIZE`, `MAX_TOKENS_PER_GPU` +- `PARAM_OFFLOAD`, `OPTIMIZER_OFFLOAD`, `GRAD_OFFLOAD` +- `ATTENTION_BACKEND=flash` +- `DRY_RUN=1` to print the resolved command without running it + +## GRPO on GSM8K + +The GRPO script expects GSM8K-style train/validation parquet files with a `prompt` field. 
+ +```bash +export MODEL_PATH=/path/to/qwen3-moe-hf +export TRAIN_FILE=/path/to/gsm8k/train.parquet +export VAL_FILE=/path/to/gsm8k/test.parquet + +BACKEND=bumblebee bash verlbb/scripts/run_qwen3moe_gsm8k_grpo.sh +``` + +Common knobs: + +- `BACKEND=bumblebee|megatron` +- `TP_SIZE`, `PP_SIZE`, `VPP_SIZE`, `CP_SIZE`, `EP_SIZE`, `ETP_SIZE` +- `TOTAL_STEPS`, `TRAIN_BATCH_SIZE`, `ROLLOUT_N` +- `ROLLOUT_GPU_MEMORY_UTILIZATION`, default `0.7` +- `ALL_OFFLOAD`, or the individual `PARAM_OFFLOAD`, `OPTIMIZER_OFFLOAD`, `GRAD_OFFLOAD` +- `USE_FUSED_KERNELS=True` +- `ATTENTION_BACKEND=flash` +- `DRY_RUN=1` to print the resolved command without running it + +## Outputs + +By default the scripts write logs, file-logger JSONL, command snapshots, and checkpoints under `verlbb/outputs/...`. Override `OUTPUT_ROOT`, `LOG_FILE`, `JSONL_FILE`, `CMD_FILE`, or `CKPT_DIR` to redirect artifacts. diff --git a/verlbb/REQUIRED_VERL.txt b/verlbb/REQUIRED_VERL.txt new file mode 100644 index 00000000..cee5a16f --- /dev/null +++ b/verlbb/REQUIRED_VERL.txt @@ -0,0 +1,13 @@ +# verlbb — rolling; requires Bumblebee runtime available in the Python environment +UPSTREAM=https://github.com/verl-project/verl.git +MODE=rolling +BRANCH=main +# Exact upstream verl commit this recipe was last exercised against. +VERL_COMMIT=ca82c4b47148fe2266d4b4626723d0c4654c5bc0 +PIP_INSTALL=pip install verl@git+https://github.com/verl-project/verl.git@ca82c4b47148fe2266d4b4626723d0c4654c5bc0 +GIT_SETUP=git clone https://github.com/verl-project/verl.git && cd verl && git checkout ca82c4b47148fe2266d4b4626723d0c4654c5bc0 && git submodule update --init --recursive recipe +RECIPE_SUBMODULE_COMMIT=34b0d7e08706d1b35d3527ca2e9c39e7c6d9fe3c +RECIPE_FOLDER=verlbb +RECIPE_FOLDER_LAST_COMMIT=new-recipe +NOTES=VERL_COMMIT pins the core library. Bumblebee itself is an external runtime dependency and must be installed or exposed through BUMBLEBEE_ROOT. 
# Prepend a directory to PYTHONPATH and export the result.
#
# Arguments:
#   $1 - directory to prepend. An empty or missing value is a no-op, so
#        optional overrides such as "${VERL_ROOT:-}" can be passed straight in.
#
# The ":" separator is emitted only when PYTHONPATH already has content.
# Unconditionally appending ":${PYTHONPATH:-}" would leave a trailing ":" when
# PYTHONPATH is unset, i.e. an empty search-path entry, which Python resolves
# to the current working directory.
add_pythonpath() {
  local path="${1:-}"
  if [[ -z "${path}" ]]; then
    return 0
  fi
  # ${PYTHONPATH:+...} is safe under `set -u`: it expands to nothing when the
  # variable is unset or empty.
  export PYTHONPATH="${path}${PYTHONPATH:+:${PYTHONPATH}}"
}
"${PYTORCH_ALLOC_CONF}" ]]; then + unset PYTORCH_ALLOC_CONF + fi +fi + +: "${MODEL_PATH:?set MODEL_PATH to a Hugging Face checkpoint directory or model id}" +TRAIN_FILE="${TRAIN_FILE:-${TRAIN_FILES:-}}" +VAL_FILE="${VAL_FILE:-${VAL_FILES:-}}" +: "${TRAIN_FILE:?set TRAIN_FILE or TRAIN_FILES to a GSM8K train parquet path}" +: "${VAL_FILE:?set VAL_FILE or VAL_FILES to a GSM8K validation parquet path}" + +BACKEND="${BACKEND:-bumblebee}" +OUTPUT_ROOT="${OUTPUT_ROOT:-${RECIPE_ROOT}/outputs/qwen3moe_gsm8k_grpo}" +PROJECT_NAME="${PROJECT_NAME:-verlbb-qwen3moe-gsm8k-grpo}" + +NUM_GPUS="${NUM_GPUS:-8}" +NNODES="${NNODES:-1}" + +TOTAL_STEPS="${TOTAL_STEPS:-100}" +TOTAL_EPOCHS="${TOTAL_EPOCHS:-1}" +SAVE_FREQ="${SAVE_FREQ:--1}" +TEST_FREQ="${TEST_FREQ:--1}" +RESUME_MODE="${RESUME_MODE:-disable}" +RESUME_FROM_PATH="${RESUME_FROM_PATH:-null}" +TRAIN_BATCH_SIZE="${TRAIN_BATCH_SIZE:-4}" +TRAIN_MAX_SAMPLES="${TRAIN_MAX_SAMPLES:-null}" +VAL_MAX_SAMPLES="${VAL_MAX_SAMPLES:-null}" +MAX_PROMPT_LENGTH="${MAX_PROMPT_LENGTH:-512}" +MAX_RESPONSE_LENGTH="${MAX_RESPONSE_LENGTH:-128}" +ROLLOUT_N="${ROLLOUT_N:-2}" +PPO_MINI_BATCH_SIZE="${PPO_MINI_BATCH_SIZE:-4}" +PPO_MICRO_BATCH_SIZE_PER_GPU="${PPO_MICRO_BATCH_SIZE_PER_GPU:-1}" +MAX_TOKEN_LEN_PER_GPU="${MAX_TOKEN_LEN_PER_GPU:-2048}" +INFER_MAX_TOKEN_LEN_PER_GPU="${INFER_MAX_TOKEN_LEN_PER_GPU:-2048}" +LOG_PROB_MICRO_BATCH_SIZE_PER_GPU="${LOG_PROB_MICRO_BATCH_SIZE_PER_GPU:-1}" + +TP_SIZE="${TP_SIZE:-2}" +PP_SIZE="${PP_SIZE:-1}" +VPP_SIZE="${VPP_SIZE:-null}" +CP_SIZE="${CP_SIZE:-1}" +EP_SIZE="${EP_SIZE:-8}" +ETP_SIZE="${ETP_SIZE:-1}" +DTYPE="${DTYPE:-bfloat16}" +BB_IMPL="${BB_IMPL:-lite}" +ATTENTION_BACKEND="${ATTENTION_BACKEND:-flash}" +USE_FUSED_KERNELS="${USE_FUSED_KERNELS:-True}" +USE_DYNAMIC_BSZ="${USE_DYNAMIC_BSZ:-True}" +USE_REMOVE_PADDING="${USE_REMOVE_PADDING:-True}" +ALL_OFFLOAD="${ALL_OFFLOAD:-True}" +PARAM_OFFLOAD="${PARAM_OFFLOAD:-${ALL_OFFLOAD}}" +OPTIMIZER_OFFLOAD="${OPTIMIZER_OFFLOAD:-${ALL_OFFLOAD}}" 
+GRAD_OFFLOAD="${GRAD_OFFLOAD:-${ALL_OFFLOAD}}" +OPTIMIZER_STATE_OFFLOAD_FRACTION="${OPTIMIZER_STATE_OFFLOAD_FRACTION:-1.0}" +OPTIMIZER_CPU_OFFLOAD="${OPTIMIZER_CPU_OFFLOAD:-True}" +USE_PRECISION_AWARE_OPTIMIZER="${USE_PRECISION_AWARE_OPTIMIZER:-True}" +DECOUPLED_WEIGHT_DECAY="${DECOUPLED_WEIGHT_DECAY:-True}" + +ROLLOUT_TP_SIZE="${ROLLOUT_TP_SIZE:-2}" +ROLLOUT_DP_SIZE="${ROLLOUT_DP_SIZE:-1}" +ROLLOUT_EP_SIZE="${ROLLOUT_EP_SIZE:-1}" +ROLLOUT_GPU_MEMORY_UTILIZATION="${ROLLOUT_GPU_MEMORY_UTILIZATION:-0.7}" +ROLLOUT_MAX_MODEL_LEN="${ROLLOUT_MAX_MODEL_LEN:-$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))}" +ROLLOUT_MAX_NUM_BATCHED_TOKENS="${ROLLOUT_MAX_NUM_BATCHED_TOKENS:-$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))}" +ROLLOUT_MAX_NUM_SEQS="${ROLLOUT_MAX_NUM_SEQS:-16}" + +LR="${LR:-1e-6}" +WEIGHT_DECAY="${WEIGHT_DECAY:-0.01}" +SEED="${SEED:-1}" +DRY_RUN="${DRY_RUN:-0}" +EXTRA_ARGS=("$@") + +case "${BACKEND}" in + megatron|bumblebee) + ;; + *) + echo "Unsupported BACKEND=${BACKEND}. Expected megatron or bumblebee." 
>&2 + exit 1 + ;; +esac + +BB_VPP_SIZE="${VPP_SIZE}" +if [[ "${BB_VPP_SIZE}" == "null" ]]; then + BB_VPP_SIZE=1 +fi + +RUN_NAME="${RUN_NAME:-qwen3moe_gsm8k_grpo_${BACKEND}_tp${TP_SIZE}_pp${PP_SIZE}_cp${CP_SIZE}_ep${EP_SIZE}_etp${ETP_SIZE}_n${ROLLOUT_N}}" +CKPT_DIR="${CKPT_DIR:-${OUTPUT_ROOT}/checkpoints/${RUN_NAME}}" +LOG_FILE="${LOG_FILE:-${OUTPUT_ROOT}/${RUN_NAME}.log}" +JSONL_FILE="${JSONL_FILE:-${OUTPUT_ROOT}/${RUN_NAME}.jsonl}" +CMD_FILE="${CMD_FILE:-${OUTPUT_ROOT}/${RUN_NAME}.cmd.sh}" + +mkdir -p "${OUTPUT_ROOT}" "${CKPT_DIR}" "$(dirname "${LOG_FILE}")" "$(dirname "${JSONL_FILE}")" "$(dirname "${CMD_FILE}")" +export VERL_FILE_LOGGER_PATH="${JSONL_FILE}" +export VERL_FILE_LOGGER_ROOT="${OUTPUT_ROOT}" + +CACHE_ROOT="${VERLBB_CACHE_ROOT:-${TMPDIR:-/tmp}/verlbb}" +mkdir -p "${CACHE_ROOT}/pycache_${USER:-user}" "${CACHE_ROOT}/torchinductor_${USER:-user}" "${CACHE_ROOT}/triton_${USER:-user}" +export PYTHONPYCACHEPREFIX="${PYTHONPYCACHEPREFIX:-${CACHE_ROOT}/pycache_${USER:-user}}" +export TORCHINDUCTOR_CACHE_DIR="${TORCHINDUCTOR_CACHE_DIR:-${CACHE_ROOT}/torchinductor_${USER:-user}}" +export TRITON_CACHE_DIR="${TRITON_CACHE_DIR:-${CACHE_ROOT}/triton_${USER:-user}}" + +COMMON_ARGS=( + "data.train_files=${TRAIN_FILE}" + "data.val_files=${VAL_FILE}" + "data.prompt_key=prompt" + "data.train_batch_size=${TRAIN_BATCH_SIZE}" + "data.train_max_samples=${TRAIN_MAX_SAMPLES}" + "data.val_max_samples=${VAL_MAX_SAMPLES}" + "data.max_prompt_length=${MAX_PROMPT_LENGTH}" + "data.max_response_length=${MAX_RESPONSE_LENGTH}" + "data.filter_overlong_prompts=True" + "data.truncation=left" + "data.return_raw_chat=True" + "data.trust_remote_code=True" + "data.dataloader_num_workers=0" + "actor_rollout_ref.model.path=${MODEL_PATH}" + "actor_rollout_ref.model.trust_remote_code=True" + "actor_rollout_ref.model.use_remove_padding=${USE_REMOVE_PADDING}" + "actor_rollout_ref.model.use_fused_kernels=${USE_FUSED_KERNELS}" + "actor_rollout_ref.rollout.name=vllm" + 
"actor_rollout_ref.rollout.mode=async" + "actor_rollout_ref.rollout.n=${ROLLOUT_N}" + "actor_rollout_ref.rollout.temperature=1.0" + "actor_rollout_ref.rollout.top_p=1.0" + "actor_rollout_ref.rollout.top_k=-1" + "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${LOG_PROB_MICRO_BATCH_SIZE_PER_GPU}" + "actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${USE_DYNAMIC_BSZ}" + "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${INFER_MAX_TOKEN_LEN_PER_GPU}" + "actor_rollout_ref.rollout.tensor_model_parallel_size=${ROLLOUT_TP_SIZE}" + "actor_rollout_ref.rollout.data_parallel_size=${ROLLOUT_DP_SIZE}" + "actor_rollout_ref.rollout.expert_parallel_size=${ROLLOUT_EP_SIZE}" + "actor_rollout_ref.rollout.gpu_memory_utilization=${ROLLOUT_GPU_MEMORY_UTILIZATION}" + "actor_rollout_ref.rollout.max_model_len=${ROLLOUT_MAX_MODEL_LEN}" + "actor_rollout_ref.rollout.max_num_batched_tokens=${ROLLOUT_MAX_NUM_BATCHED_TOKENS}" + "actor_rollout_ref.rollout.max_num_seqs=${ROLLOUT_MAX_NUM_SEQS}" + "actor_rollout_ref.rollout.enable_chunked_prefill=True" + "actor_rollout_ref.rollout.enable_prefix_caching=True" + "actor_rollout_ref.rollout.enforce_eager=True" + "actor_rollout_ref.rollout.free_cache_engine=True" + "actor_rollout_ref.rollout.val_kwargs.n=1" + "actor_rollout_ref.rollout.val_kwargs.do_sample=True" + "actor_rollout_ref.rollout.val_kwargs.temperature=1.0" + "actor_rollout_ref.actor.use_dynamic_bsz=${USE_DYNAMIC_BSZ}" + "actor_rollout_ref.actor.use_kl_loss=False" + "actor_rollout_ref.actor.entropy_coeff=0" + "actor_rollout_ref.actor.ppo_epochs=1" + "actor_rollout_ref.actor.ppo_mini_batch_size=${PPO_MINI_BATCH_SIZE}" + "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${PPO_MICRO_BATCH_SIZE_PER_GPU}" + "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${MAX_TOKEN_LEN_PER_GPU}" + "actor_rollout_ref.actor.optim.lr=${LR}" + "actor_rollout_ref.actor.optim.weight_decay=${WEIGHT_DECAY}" + "algorithm.adv_estimator=grpo" + "algorithm.use_kl_in_reward=False" + 
"algorithm.kl_ctrl.kl_coef=0.0" + "reward.reward_manager.name=naive" + "trainer.logger=[console,file]" + "trainer.project_name=${PROJECT_NAME}" + "trainer.experiment_name=${RUN_NAME}" + "trainer.default_local_dir=${CKPT_DIR}" + "trainer.total_epochs=${TOTAL_EPOCHS}" + "trainer.total_training_steps=${TOTAL_STEPS}" + "trainer.val_before_train=False" + "trainer.test_freq=${TEST_FREQ}" + "trainer.save_freq=${SAVE_FREQ}" + "trainer.resume_mode=${RESUME_MODE}" + "trainer.resume_from_path=${RESUME_FROM_PATH}" + "trainer.nnodes=${NNODES}" + "trainer.n_gpus_per_node=${NUM_GPUS}" + "trainer.use_legacy_worker_impl=disable" +) + +if [[ "${BACKEND}" == "megatron" ]]; then + BACKEND_ARGS=( + "model_engine=megatron" + "actor_rollout_ref.actor.use_torch_compile=False" + "actor_rollout_ref.actor.megatron.dtype=${DTYPE}" + "actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP_SIZE}" + "actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP_SIZE}" + "actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${VPP_SIZE}" + "actor_rollout_ref.actor.megatron.context_parallel_size=${CP_SIZE}" + "actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP_SIZE}" + "actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP_SIZE}" + "actor_rollout_ref.actor.megatron.param_offload=${PARAM_OFFLOAD}" + "actor_rollout_ref.actor.megatron.optimizer_offload=${OPTIMIZER_OFFLOAD}" + "actor_rollout_ref.actor.megatron.grad_offload=${GRAD_OFFLOAD}" + "+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${OPTIMIZER_STATE_OFFLOAD_FRACTION}" + "+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True" + "+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=${USE_PRECISION_AWARE_OPTIMIZER}" + "+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=${OPTIMIZER_CPU_OFFLOAD}" + 
"+actor_rollout_ref.actor.optim.override_optimizer_config.decoupled_weight_decay=${DECOUPLED_WEIGHT_DECAY}" + "actor_rollout_ref.actor.megatron.override_transformer_config.attention_backend=${ATTENTION_BACKEND}" + ) +else + BACKEND_ARGS=( + "hydra.searchpath=[pkg://verlbb.config]" + "actor@actor_rollout_ref.actor=bumblebee_actor" + "actor_rollout_ref.actor.use_torch_compile=False" + "actor_rollout_ref.actor.engine.dtype=${DTYPE}" + "actor_rollout_ref.actor.engine.impl=${BB_IMPL}" + "actor_rollout_ref.actor.engine.tp=${TP_SIZE}" + "actor_rollout_ref.actor.engine.pp=${PP_SIZE}" + "actor_rollout_ref.actor.engine.vpp=${BB_VPP_SIZE}" + "actor_rollout_ref.actor.engine.cp=${CP_SIZE}" + "actor_rollout_ref.actor.engine.ep=${EP_SIZE}" + "actor_rollout_ref.actor.engine.etp=${ETP_SIZE}" + "actor_rollout_ref.actor.engine.param_offload=${PARAM_OFFLOAD}" + "actor_rollout_ref.actor.engine.optimizer_offload=${OPTIMIZER_OFFLOAD}" + "actor_rollout_ref.actor.engine.grad_offload=${GRAD_OFFLOAD}" + "+actor_rollout_ref.actor.optim.override_optimizer_config.offload_fraction=${OPTIMIZER_STATE_OFFLOAD_FRACTION}" + "+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=${USE_PRECISION_AWARE_OPTIMIZER}" + "+actor_rollout_ref.actor.optim.override_optimizer_config.decoupled_weight_decay=${DECOUPLED_WEIGHT_DECAY}" + "actor_rollout_ref.actor.engine.attention_backend_override=${ATTENTION_BACKEND}" + "actor_rollout_ref.actor.engine.impl_cfg.use_thd=True" + ) +fi + +COMMAND=( + python3 + -m + verl.trainer.main_ppo + "${COMMON_ARGS[@]}" + "${BACKEND_ARGS[@]}" + "${EXTRA_ARGS[@]}" +) + +printf '%q ' "${COMMAND[@]}" > "${CMD_FILE}" +printf '\n' >> "${CMD_FILE}" + +if [[ "${DRY_RUN}" == "1" ]]; then + printf '%q ' "${COMMAND[@]}" + printf '\n' + exit 0 +fi + +echo "[${BACKEND}] output_root=${OUTPUT_ROOT}" +echo "[${BACKEND}] log=${LOG_FILE}" +echo "[${BACKEND}] jsonl=${JSONL_FILE}" +echo "[${BACKEND}] cmd=${CMD_FILE}" + +set +e +"${COMMAND[@]}" 2>&1 | tee "${LOG_FILE}" 
# Prepend a directory to PYTHONPATH and export the result.
#
# Arguments:
#   $1 - directory to prepend. An empty or missing value is a no-op, so
#        optional overrides such as "${BUMBLEBEE_ROOT:-}" can be passed
#        straight in.
#
# The ":" separator is emitted only when PYTHONPATH already has content.
# Unconditionally appending ":${PYTHONPATH:-}" would leave a trailing ":" when
# PYTHONPATH is unset, i.e. an empty search-path entry, which Python resolves
# to the current working directory.
add_pythonpath() {
  local path="${1:-}"
  if [[ -z "${path}" ]]; then
    return 0
  fi
  # ${PYTHONPATH:+...} is safe under `set -u`: it expands to nothing when the
  # variable is unset or empty.
  export PYTHONPATH="${path}${PYTHONPATH:+:${PYTHONPATH}}"
}
+IGNORE_INPUT_IDS_MISMATCH="${IGNORE_INPUT_IDS_MISMATCH:-True}" +TRUST_REMOTE_CODE="${TRUST_REMOTE_CODE:-True}" +NUM_WORKERS="${NUM_WORKERS:-0}" +SEED="${SEED:-1}" + +TP_SIZE="${TP_SIZE:-2}" +PP_SIZE="${PP_SIZE:-1}" +VPP_SIZE="${VPP_SIZE:-null}" +CP_SIZE="${CP_SIZE:-1}" +EP_SIZE="${EP_SIZE:-8}" +ETP_SIZE="${ETP_SIZE:-1}" +DTYPE="${DTYPE:-bfloat16}" +BB_IMPL="${BB_IMPL:-lite}" +ATTENTION_BACKEND="${ATTENTION_BACKEND:-flash}" + +LR="${LR:-1e-5}" +MIN_LR="${MIN_LR:-${LR}}" +WEIGHT_DECAY="${WEIGHT_DECAY:-0.1}" +BETAS="${BETAS:-[0.9,0.95]}" +CLIP_GRAD="${CLIP_GRAD:-1.0}" +LR_WARMUP_STEPS="${LR_WARMUP_STEPS:-0}" +LR_DECAY_STYLE="${LR_DECAY_STYLE:-constant}" + +PARAM_OFFLOAD="${PARAM_OFFLOAD:-False}" +OPTIMIZER_OFFLOAD="${OPTIMIZER_OFFLOAD:-True}" +GRAD_OFFLOAD="${GRAD_OFFLOAD:-False}" +OPTIMIZER_STATE_OFFLOAD_FRACTION="${OPTIMIZER_STATE_OFFLOAD_FRACTION:-1.0}" +OPTIMIZER_CPU_OFFLOAD="${OPTIMIZER_CPU_OFFLOAD:-True}" +USE_PRECISION_AWARE_OPTIMIZER="${USE_PRECISION_AWARE_OPTIMIZER:-True}" +DECOUPLED_WEIGHT_DECAY="${DECOUPLED_WEIGHT_DECAY:-True}" +USE_MBRIDGE="${USE_MBRIDGE:-True}" +DRY_RUN="${DRY_RUN:-0}" +EXTRA_ARGS=("$@") + +case "${BACKEND}" in + megatron|bumblebee) + ;; + *) + echo "Unsupported BACKEND=${BACKEND}. Expected megatron or bumblebee." >&2 + exit 1 + ;; +esac + +if [[ "${PAD_MODE}" != "no_padding" ]]; then + echo "VerlBB currently supports PAD_MODE=no_padding only." 
>&2 + exit 1 +fi + +BB_VPP_SIZE="${VPP_SIZE}" +if [[ "${BB_VPP_SIZE}" == "null" ]]; then + BB_VPP_SIZE=1 +fi + +RUN_NAME="${RUN_NAME:-qwen3moe_sft_${BACKEND}_tp${TP_SIZE}_pp${PP_SIZE}_cp${CP_SIZE}_ep${EP_SIZE}_etp${ETP_SIZE}}" +CKPT_DIR="${CKPT_DIR:-${OUTPUT_ROOT}/checkpoints/${RUN_NAME}}" +LOG_FILE="${LOG_FILE:-${OUTPUT_ROOT}/${RUN_NAME}.log}" +JSONL_FILE="${JSONL_FILE:-${OUTPUT_ROOT}/${RUN_NAME}.jsonl}" +CMD_FILE="${CMD_FILE:-${OUTPUT_ROOT}/${RUN_NAME}.cmd.sh}" + +mkdir -p "${OUTPUT_ROOT}" "${CKPT_DIR}" "$(dirname "${LOG_FILE}")" "$(dirname "${JSONL_FILE}")" "$(dirname "${CMD_FILE}")" +export VERL_FILE_LOGGER_PATH="${JSONL_FILE}" + +CACHE_ROOT="${VERLBB_CACHE_ROOT:-${TMPDIR:-/tmp}/verlbb}" +mkdir -p "${CACHE_ROOT}/pycache_${USER:-user}" "${CACHE_ROOT}/torchinductor_${USER:-user}" "${CACHE_ROOT}/triton_${USER:-user}" +export PYTHONPYCACHEPREFIX="${PYTHONPYCACHEPREFIX:-${CACHE_ROOT}/pycache_${USER:-user}}" +export TORCHINDUCTOR_CACHE_DIR="${TORCHINDUCTOR_CACHE_DIR:-${CACHE_ROOT}/torchinductor_${USER:-user}}" +export TRITON_CACHE_DIR="${TRITON_CACHE_DIR:-${CACHE_ROOT}/triton_${USER:-user}}" + +COMMON_ARGS=( + "data.train_files=${TRAIN_FILES}" + "data.train_batch_size=${TRAIN_BATCH_SIZE}" + "data.micro_batch_size_per_gpu=${MICRO_BATCH_SIZE}" + "data.use_dynamic_bsz=${USE_DYNAMIC_BSZ}" + "data.max_token_len_per_gpu=${MAX_TOKENS_PER_GPU}" + "data.max_length=${MAX_LENGTH}" + "data.pad_mode=${PAD_MODE}" + "data.truncation=error" + "data.messages_key=messages" + "data.ignore_input_ids_mismatch=${IGNORE_INPUT_IDS_MISMATCH}" + "data.num_workers=${NUM_WORKERS}" + "model=hf_model" + "model.path=${MODEL_PATH}" + "model.trust_remote_code=${TRUST_REMOTE_CODE}" + "model.use_remove_padding=${USE_REMOVE_PADDING}" + "optim=megatron" + "optim.lr=${LR}" + "optim.min_lr=${MIN_LR}" + "optim.weight_decay=${WEIGHT_DECAY}" + "optim.betas=${BETAS}" + "optim.clip_grad=${CLIP_GRAD}" + "optim.lr_warmup_steps=${LR_WARMUP_STEPS}" + "optim.lr_warmup_init=0" + 
"optim.lr_decay_style=${LR_DECAY_STYLE}" + "trainer.logger=[console,file]" + "trainer.project_name=${PROJECT_NAME}" + "trainer.experiment_name=${RUN_NAME}" + "trainer.default_local_dir=${CKPT_DIR}" + "trainer.total_epochs=${TOTAL_EPOCHS}" + "trainer.total_training_steps=${TOTAL_STEPS}" + "trainer.save_freq=${SAVE_FREQ}" + "trainer.test_freq=${TEST_FREQ}" + "trainer.seed=${SEED}" + "trainer.resume_mode=${RESUME_MODE}" + "trainer.resume_from_path=${RESUME_FROM_PATH}" + "trainer.nnodes=${NNODES}" + "trainer.n_gpus_per_node=${NPROC_PER_NODE}" + "checkpoint.save_contents=[model,optimizer,extra]" +) + +if [[ -n "${VAL_FILES}" ]]; then + COMMON_ARGS+=("data.val_files=${VAL_FILES}") +fi + +if [[ "${BACKEND}" == "megatron" ]]; then + BACKEND_ARGS=( + "engine=megatron" + "engine.dtype=${DTYPE}" + "engine.tensor_model_parallel_size=${TP_SIZE}" + "engine.pipeline_model_parallel_size=${PP_SIZE}" + "engine.virtual_pipeline_model_parallel_size=${VPP_SIZE}" + "engine.context_parallel_size=${CP_SIZE}" + "engine.expert_model_parallel_size=${EP_SIZE}" + "engine.expert_tensor_parallel_size=${ETP_SIZE}" + "engine.use_mbridge=${USE_MBRIDGE}" + "engine.param_offload=${PARAM_OFFLOAD}" + "engine.optimizer_offload=${OPTIMIZER_OFFLOAD}" + "engine.grad_offload=${GRAD_OFFLOAD}" + "+engine.override_transformer_config.context_parallel_size=${CP_SIZE}" + "+engine.override_transformer_config.attention_backend=${ATTENTION_BACKEND}" + ) + if [[ "${OPTIMIZER_OFFLOAD}" == "True" || "${OPTIMIZER_OFFLOAD}" == "true" || "${OPTIMIZER_OFFLOAD}" == "1" ]]; then + BACKEND_ARGS+=( + "+optim.override_optimizer_config.optimizer_offload_fraction=${OPTIMIZER_STATE_OFFLOAD_FRACTION}" + "+optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True" + "+optim.override_optimizer_config.use_precision_aware_optimizer=${USE_PRECISION_AWARE_OPTIMIZER}" + "+optim.override_optimizer_config.optimizer_cpu_offload=${OPTIMIZER_CPU_OFFLOAD}" + 
"+optim.override_optimizer_config.decoupled_weight_decay=${DECOUPLED_WEIGHT_DECAY}" + ) + fi +else + BACKEND_ARGS=( + "hydra.searchpath=[pkg://verlbb.config]" + "engine=bumblebee" + "engine.dtype=${DTYPE}" + "engine.impl=${BB_IMPL}" + "engine.tp=${TP_SIZE}" + "engine.pp=${PP_SIZE}" + "engine.vpp=${BB_VPP_SIZE}" + "engine.cp=${CP_SIZE}" + "engine.ep=${EP_SIZE}" + "engine.etp=${ETP_SIZE}" + "engine.param_offload=${PARAM_OFFLOAD}" + "engine.optimizer_offload=${OPTIMIZER_OFFLOAD}" + "engine.grad_offload=${GRAD_OFFLOAD}" + "engine.attention_backend_override=${ATTENTION_BACKEND}" + "engine.impl_cfg.use_thd=True" + ) + if [[ "${OPTIMIZER_OFFLOAD}" == "True" || "${OPTIMIZER_OFFLOAD}" == "true" || "${OPTIMIZER_OFFLOAD}" == "1" ]]; then + BACKEND_ARGS+=( + "+optim.override_optimizer_config.offload_fraction=${OPTIMIZER_STATE_OFFLOAD_FRACTION}" + "+optim.override_optimizer_config.use_precision_aware_optimizer=${USE_PRECISION_AWARE_OPTIMIZER}" + "+optim.override_optimizer_config.decoupled_weight_decay=${DECOUPLED_WEIGHT_DECAY}" + ) + fi +fi + +COMMAND=( + torchrun + --nnodes="${NNODES}" + --node_rank="${NODE_RANK}" + --master_addr="${MASTER_ADDR}" + --master_port="${MASTER_PORT}" + --nproc_per_node="${NPROC_PER_NODE}" + -m + verl.trainer.sft_trainer + "${COMMON_ARGS[@]}" + "${BACKEND_ARGS[@]}" + "${EXTRA_ARGS[@]}" +) + +printf '%q ' "${COMMAND[@]}" > "${CMD_FILE}" +printf '\n' >> "${CMD_FILE}" + +if [[ "${DRY_RUN}" == "1" ]]; then + printf '%q ' "${COMMAND[@]}" + printf '\n' + exit 0 +fi + +echo "[${BACKEND}] output_root=${OUTPUT_ROOT}" +echo "[${BACKEND}] log=${LOG_FILE}" +echo "[${BACKEND}] jsonl=${JSONL_FILE}" +echo "[${BACKEND}] cmd=${CMD_FILE}" + +set +e +"${COMMAND[@]}" 2>&1 | tee "${LOG_FILE}" +cmd_rc="${PIPESTATUS[0]}" +set -e +exit "${cmd_rc}" diff --git a/verlbb/verlbb/__init__.py b/verlbb/verlbb/__init__.py new file mode 100644 index 00000000..a22578cc --- /dev/null +++ b/verlbb/verlbb/__init__.py @@ -0,0 +1 @@ +"""VerlBB bridge package.""" diff --git 
a/verlbb/verlbb/config/__init__.py b/verlbb/verlbb/config/__init__.py new file mode 100644 index 00000000..64faf98d --- /dev/null +++ b/verlbb/verlbb/config/__init__.py @@ -0,0 +1 @@ +"""Hydra config package for VerlBB.""" diff --git a/verlbb/verlbb/config/actor/__init__.py b/verlbb/verlbb/config/actor/__init__.py new file mode 100644 index 00000000..e0a31526 --- /dev/null +++ b/verlbb/verlbb/config/actor/__init__.py @@ -0,0 +1 @@ +"""Hydra actor config group for VerlBB.""" diff --git a/verlbb/verlbb/config/actor/bumblebee_actor.yaml b/verlbb/verlbb/config/actor/bumblebee_actor.yaml new file mode 100644 index 00000000..e1bc472e --- /dev/null +++ b/verlbb/verlbb/config/actor/bumblebee_actor.yaml @@ -0,0 +1,10 @@ +# Bumblebee actor config for VERL's new engine worker path. +defaults: + - /optim@optim: megatron + - /engine@engine: bumblebee + - actor + - _self_ + +_target_: verl.workers.config.ActorConfig + +strategy: bumblebee diff --git a/verlbb/verlbb/config/engine/__init__.py b/verlbb/verlbb/config/engine/__init__.py new file mode 100644 index 00000000..421320af --- /dev/null +++ b/verlbb/verlbb/config/engine/__init__.py @@ -0,0 +1 @@ +"""Engine config group for VerlBB.""" diff --git a/verlbb/verlbb/config/engine/bumblebee.yaml b/verlbb/verlbb/config/engine/bumblebee.yaml new file mode 100644 index 00000000..ab8c8a64 --- /dev/null +++ b/verlbb/verlbb/config/engine/bumblebee.yaml @@ -0,0 +1,24 @@ +_target_: verlbb.engine.config.BumblebeeEngineConfig + +strategy: bumblebee +custom_backend_module: verlbb.engine.bumblebee_engine +param_offload: false +optimizer_offload: false +grad_offload: false +forward_only: false +dtype: bfloat16 + +model_name: auto +impl: lite + +tp: 1 +etp: null +ep: 1 +pp: 1 +vpp: 1 +cp: 1 + +attention_backend_override: flash +router_aux_loss_coef: null +impl_cfg: + use_thd: true diff --git a/verlbb/verlbb/engine/__init__.py b/verlbb/verlbb/engine/__init__.py new file mode 100644 index 00000000..d0e2dfbc --- /dev/null +++ 
b/verlbb/verlbb/engine/__init__.py @@ -0,0 +1,5 @@ +"""Engine entrypoints for VerlBB.""" + +from verlbb.engine.bumblebee_engine import BumblebeeEngine + +__all__ = ["BumblebeeEngine"] diff --git a/verlbb/verlbb/engine/bumblebee_engine.py b/verlbb/verlbb/engine/bumblebee_engine.py new file mode 100644 index 00000000..64de25d2 --- /dev/null +++ b/verlbb/verlbb/engine/bumblebee_engine.py @@ -0,0 +1,747 @@ +"""External VERL engine backed by Bumblebee runtime primitives.""" + +from __future__ import annotations + +import os +from typing import Any + +import torch +import torch.distributed as dist +from bumblebee.model import resolve_model_type_from_hf +from bumblebee.primitive.ckpt import load_training_checkpoint, save_training_checkpoint +from bumblebee.primitive.parallel import pack_nested_thd, unpack_packed_thd_to_nested +from bumblebee.primitive.protocols import default_expert_classifier, default_placement_fn +from bumblebee.runtime import create_runtime +from bumblebee.runtime.backends.bb.config import BBConfig +from bumblebee.runtime.contracts.config import ( + OptimizerConfig as BBOptimizerConfig, +) +from bumblebee.runtime.contracts.config import ( + ParallelConfig, + RuntimeConfig, +) +from tensordict import TensorDict + +from verl.trainer.config import CheckpointConfig +from verl.utils import tensordict_utils as tu +from verl.utils.dataset.dataset_utils import DatasetPadMode +from verl.utils.device import get_device_id, get_device_name +from verl.workers.config import HFModelConfig, OptimizerConfig +from verl.workers.engine.base import BaseEngine, BaseEngineCtx, EngineRegistry +from verl.workers.engine.utils import postprocess_batch_func, prepare_micro_batches + +from .config import BumblebeeEngineConfig + +_LR_SCHEDULER_STATE = "lr_scheduler.pt" + + +def _isolate_compile_cache_per_rank() -> None: + """Avoid torchinductor/triton cache races between local torchrun ranks.""" + rank = os.environ.get("LOCAL_RANK") or os.environ.get("RANK") + if rank is None: + 
return + for var in ("TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"): + base = os.environ.get(var) + if not base: + continue + base_var = f"VERLBB_BASE_{var}" + root = os.environ.setdefault(base_var, base) + rank_dir = os.path.join(root, f"rank_{rank}") + os.makedirs(rank_dir, exist_ok=True) + os.environ[var] = rank_dir + + +def _build_lr_scheduler(optimizer, opt: BBOptimizerConfig): + """Build a Megatron-style LR scheduler for Bumblebee's optimizer.""" + total_steps = opt.total_training_steps + if total_steps <= 0: + return None + + from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler + + warmup_steps = opt.lr_warmup_steps if opt.lr_warmup_steps is not None else -1 + if warmup_steps <= 0 and opt.lr_warmup_steps_ratio > 0: + warmup_steps = int(opt.lr_warmup_steps_ratio * total_steps) + warmup_steps = max(warmup_steps, 0) + + decay_steps = opt.lr_decay_steps if opt.lr_decay_steps is not None else total_steps + min_lr = opt.min_lr if opt.min_lr is not None else 0.0 + for param_group in optimizer.param_groups: + if param_group.get("min_lr") is None: + param_group["min_lr"] = min_lr + + return OptimizerParamScheduler( + optimizer, + init_lr=opt.lr_warmup_init, + max_lr=opt.lr, + min_lr=min_lr, + lr_warmup_steps=warmup_steps, + lr_decay_steps=decay_steps, + lr_decay_style=opt.lr_decay_style, + start_wd=opt.weight_decay, + end_wd=opt.weight_decay, + wd_incr_steps=total_steps, + wd_incr_style=opt.weight_decay_incr_style, + use_checkpoint_opt_param_scheduler=opt.use_checkpoint_opt_param_scheduler, + override_opt_param_scheduler=not opt.use_checkpoint_opt_param_scheduler, + wsd_decay_steps=opt.lr_wsd_decay_steps, + lr_wsd_decay_style=opt.lr_wsd_decay_style, + ) + + +class _BumblebeeModeCtx(BaseEngineCtx): + """Wrap Bumblebee runtime contexts with VERL's offload behavior.""" + + def __init__(self, engine: BumblebeeEngine, mode: str, **kwargs): + super().__init__(engine=engine, mode=mode, **kwargs) + self._runtime_ctx = None + + def __enter__(self): + 
super().__enter__() + assert self.engine.runtime is not None and self.engine.handle is not None + if self.mode == "train": + self._runtime_ctx = self.engine.runtime.train_mode(self.engine.handle) + else: + self._runtime_ctx = self.engine.runtime.eval_mode(self.engine.handle) + self._runtime_ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert self._runtime_ctx is not None + self._runtime_ctx.__exit__(exc_type, exc_val, exc_tb) + super().__exit__(exc_type, exc_val, exc_tb) + return False + + +@EngineRegistry.register(model_type="language_model", backend="bumblebee", device="cuda") +class BumblebeeEngine(BaseEngine): + """VERL BaseEngine implementation that delegates model lifecycle to Bumblebee.""" + + def __init__( + self, + model_config: HFModelConfig, + engine_config: BumblebeeEngineConfig, + optimizer_config: OptimizerConfig, + checkpoint_config: CheckpointConfig, + ): + super().__init__() + _isolate_compile_cache_per_rank() + self.model_config = model_config + self.engine_config = engine_config + self.optimizer_config = optimizer_config + self.checkpoint_config = checkpoint_config + + self.mode = None + self.device_name = get_device_name() + self.runtime = None + self.handle = None + self.module = None + self._bb_config = None + self._rank = dist.get_rank() if dist.is_initialized() else 0 + + @property + def is_param_offload_enabled(self) -> bool: + return self.engine_config.param_offload + + @property + def is_optimizer_offload_enabled(self) -> bool: + return self.engine_config.optimizer_offload + + def initialize(self): + if self.engine_config.full_determinism: + from verl.workers.engine.utils import enable_full_determinism + + enable_full_determinism(seed=self.engine_config.seed) + + self._bb_config = self._build_bb_config() + self.runtime = create_runtime( + RuntimeConfig( + backend="bb", + hf_path=self.model_config.local_path, + backend_cfg=self._bb_config, + ) + ) + self.handle = self.runtime.build_model() + 
self.module = self._extract_primary_module() + + if self.handle._optimizer is not None and self.handle._lr_scheduler is None: + self.handle._lr_scheduler = _build_lr_scheduler(self.handle._optimizer, self._bb_config.optimizer) + + self.to( + device="cpu", + model=self.is_param_offload_enabled, + optimizer=self.is_optimizer_offload_enabled, + grad=self.is_param_offload_enabled, + ) + + def train_mode(self, **kwargs): + self._require_initialized() + return _BumblebeeModeCtx(self, mode="train", **kwargs) + + def eval_mode(self, **kwargs): + self._require_initialized() + return _BumblebeeModeCtx(self, mode="eval", **kwargs) + + def optimizer_zero_grad(self): + self._require_initialized() + self.runtime.zero_grad(self.handle) + + def optimizer_step(self): + self._require_initialized() + _, grad_norm, _ = self.runtime.optimizer_step(self.handle) + return grad_norm + + def lr_scheduler_step(self): + self._require_initialized() + if self.handle._lr_scheduler is not None: + self.handle._lr_scheduler.step(1) + return self.handle._optimizer.param_groups[0]["lr"] + return 0.0 + + def forward_backward_batch(self, data: TensorDict, loss_function, forward_only: bool = False) -> dict[str, Any]: + self._require_initialized() + pad_mode = tu.get_non_tensor_data(data=data, key="pad_mode", default=DatasetPadMode.NO_PADDING) + if pad_mode != DatasetPadMode.NO_PADDING: + raise NotImplementedError("BumblebeeEngine only supports pad_mode=no_padding for now.") + + tu.assign_non_tensor(data, sp_size=self.engine_config.cp) + + token_mask = data["loss_mask"] if "loss_mask" in data.keys() else data["response_mask"] + batch_num_tokens = token_mask.sum().to(get_device_id()) + torch.distributed.all_reduce( + batch_num_tokens, + op=torch.distributed.ReduceOp.SUM, + group=self.get_data_parallel_group(), + ) + tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item()) + tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size()) + + micro_batches, indices = prepare_micro_batches( + 
data=data, + dp_group=self.get_data_parallel_group(), + same_micro_num_in_dp=True, + ) + + if self._use_runtime_forward_backward(): + return self._forward_backward_batch_with_runtime( + data=data, + micro_batches=micro_batches, + indices=indices, + loss_function=loss_function, + forward_only=forward_only, + ) + + outputs = [] + num_micro_batches = len(micro_batches) + for micro_idx, micro_batch in enumerate(micro_batches): + tu.assign_non_tensor(micro_batch, micro_batch_idx=micro_idx) + micro_batch = micro_batch.to(get_device_id()) + model_inputs = self._make_model_inputs(micro_batch) + + pre_forward_hook = self.handle._extras.get("pre_forward_hook") + if pre_forward_hook is not None: + pre_forward_hook(torch.tensor(1.0 / num_micro_batches, device=get_device_id())) + + with torch.no_grad() if forward_only else torch.enable_grad(): + raw_output = self.module( + input_ids=model_inputs["input_ids"], + position_ids=model_inputs["position_ids"], + packed_seq_params=model_inputs["packed_seq_params"], + labels=model_inputs["labels"], + loss_mask=model_inputs.get("loss_mask"), + temperature=model_inputs["temperature"], + use_fused_kernels=model_inputs["use_fused_kernels"], + calculate_entropy=model_inputs["calculate_entropy"], + ) + + model_output = self._build_verl_model_output( + raw_output=raw_output, + micro_batch=micro_batch, + inputs=model_inputs, + ) + + if loss_function is not None: + loss, metrics = loss_function( + model_output=model_output, + data=micro_batch, + dp_group=self.get_data_parallel_group(), + ) + else: + loss = torch.zeros((), device=get_device_id(), dtype=torch.float32) + metrics = {} + if raw_output.get("mtp_loss") is not None: + metrics = dict(metrics) + mtp_loss = self._reduce_mtp_metric(raw_output["mtp_loss"]) + metrics["mtp_losses/mtp_1_loss"] = ( + float(mtp_loss.item()) if mtp_loss.numel() == 1 else mtp_loss.cpu().tolist() + ) + + if not forward_only and loss_function is not None: + loss.backward() + + outputs.append( + { + "model_output": 
model_output, + "loss": loss.detach().item(), + "metrics": metrics, + } + ) + + if not forward_only: + finalize_grads = self.handle._extras.get("finalize_grads") + if finalize_grads is not None: + finalize_grads() + + result = postprocess_batch_func(output_lst=outputs, indices=indices, data=data) + return result + + def get_per_tensor_param(self, **kwargs): + self._require_initialized() + if self.is_param_offload_enabled: + self.to("cuda", model=True, optimizer=False, grad=False) + export_kwargs = { + key: kwargs[key] + for key in ("limit", "include_mtp_only", "include_local_prefixes") + if key in kwargs + } + return self.runtime.export_weights(self.handle, **export_kwargs), None + + def get_data_parallel_size(self): + if self.handle is None: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + return world_size // (self.engine_config.tp * self.engine_config.cp * self.engine_config.pp) + return self.handle.dp_size + + def get_data_parallel_rank(self): + if self.handle is None: + rank = dist.get_rank() if dist.is_initialized() else 0 + dense_dp = self.get_data_parallel_size() + return (rank // (self.engine_config.tp * self.engine_config.cp)) % dense_dp + return self.handle.dp_rank + + def get_data_parallel_group(self): + if self.handle is None: + if ( + self.engine_config.tp == 1 + and self.engine_config.cp == 1 + and self.engine_config.pp == 1 + and dist.is_initialized() + ): + return dist.group.WORLD + return None + return self.handle.dp_group + + def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True): + self._require_initialized() + if model or not (optimizer or grad): + super().to(device=device, model=model, optimizer=optimizer, grad=grad) + self.runtime.to(self.handle, device, model=model, optimizer=optimizer, grad=grad) + + def save_checkpoint( + self, + local_path: str, + hdfs_path: str | None = None, + global_step: int = 0, + max_ckpt_to_keep: int | None = None, + **kwargs, + ) -> None: + del hdfs_path, 
max_ckpt_to_keep, kwargs + self._require_initialized() + + save_contents = self.checkpoint_config.get("save_contents", None) + if save_contents is not None and not any(item in save_contents for item in ("model", "optimizer")): + if self._rank == 0: + print(f"Skipping Bumblebee checkpoint save at step {global_step}: save_contents={save_contents}") + if dist.is_initialized(): + dist.barrier() + return + + os.makedirs(local_path, exist_ok=True) + placement_fn, expert_classifier = self._checkpoint_hooks() + reload_params_for_save = self.is_param_offload_enabled + if reload_params_for_save: + self.to(device="cuda", model=True, optimizer=False, grad=False) + torch.cuda.synchronize() + try: + save_training_checkpoint( + self.module, + self.handle._optimizer, + global_step, + local_path, + self.handle._config.parallel, + self.handle._parallel_state, + get_placements=placement_fn, + is_expert=expert_classifier, + ) + if self.handle._lr_scheduler is not None and self._rank == 0: + torch.save(self.handle._lr_scheduler.state_dict(), os.path.join(local_path, _LR_SCHEDULER_STATE)) + if dist.is_initialized(): + dist.barrier() + finally: + if reload_params_for_save: + self.to(device="cpu", model=True, optimizer=False, grad=False) + + def load_checkpoint( + self, + local_path: str, + hdfs_path: str | None = None, + del_local_after_load: bool = True, + **kwargs, + ) -> None: + del hdfs_path, del_local_after_load, kwargs + self._require_initialized() + + placement_fn, expert_classifier = self._checkpoint_hooks() + reload_params_for_load = self.is_param_offload_enabled + if reload_params_for_load: + self.to(device="cuda", model=True, optimizer=False, grad=False) + torch.cuda.synchronize() + try: + load_training_checkpoint( + self.module, + self.handle._optimizer, + local_path, + self.handle._config.parallel, + self.handle._parallel_state, + get_placements=placement_fn, + is_expert=expert_classifier, + ) + scheduler_path = os.path.join(local_path, _LR_SCHEDULER_STATE) + if 
self.handle._lr_scheduler is not None and os.path.exists(scheduler_path): + state = torch.load(scheduler_path, map_location="cpu", weights_only=False) + self.handle._lr_scheduler.load_state_dict(state) + if dist.is_initialized(): + dist.barrier() + finally: + if reload_params_for_load: + self.to(device="cpu", model=True, optimizer=False, grad=False) + + def is_mp_src_rank_with_outputs(self): + if self.handle is None: + rank = dist.get_rank() if dist.is_initialized() else 0 + dense_dp = self.get_data_parallel_size() + tp_rank = rank % self.engine_config.tp + cp_rank = (rank // self.engine_config.tp) % self.engine_config.cp + pp_rank = rank // (self.engine_config.tp * self.engine_config.cp * dense_dp) + return tp_rank == 0 and cp_rank == 0 and pp_rank == self.engine_config.pp - 1 + return self.runtime.is_mp_src_rank_with_outputs(self.handle) + + def _require_initialized(self) -> None: + if self.runtime is None or self.handle is None: + raise RuntimeError("BumblebeeEngine is not initialized yet.") + + def _build_bb_config(self) -> BBConfig: + return BBConfig( + model_name=self._resolve_model_name(), + impl=self.engine_config.impl, + hf_path=self.model_config.local_path, + parallel=ParallelConfig( + tp=self.engine_config.tp, + etp=self.engine_config.etp or 1, + ep=self.engine_config.ep, + pp=self.engine_config.pp, + vpp=self.engine_config.vpp, + cp=self.engine_config.cp, + ), + optimizer=self._build_bb_optimizer_config(), + attention_backend_override=self.engine_config.attention_backend_override, + router_aux_loss_coef=self.engine_config.router_aux_loss_coef, + impl_cfg=self._build_impl_cfg(), + ) + + def _resolve_model_name(self) -> str: + if self.engine_config.model_name != "auto": + return self.engine_config.model_name + return resolve_model_type_from_hf(self.model_config.hf_config) + + def _build_impl_cfg(self) -> dict[str, Any]: + impl_cfg = dict(self.engine_config.impl_cfg) + if impl_cfg.get("use_thd", True) is not True: + raise ValueError("BumblebeeEngine 
supports only THD/no-padding SFT; set engine.impl_cfg.use_thd=True.") + impl_cfg["use_thd"] = True + mtp_cfg = getattr(self.model_config, "mtp", None) + if mtp_cfg is not None: + mtp_enable = bool(getattr(mtp_cfg, "enable", False)) + mtp_enable_train = mtp_enable and bool(getattr(mtp_cfg, "enable_train", False)) + impl_cfg["mtp_enable"] = mtp_enable + impl_cfg["mtp_enable_train"] = mtp_enable_train + impl_cfg["mtp_detach_encoder"] = bool(getattr(mtp_cfg, "detach_encoder", False)) + impl_cfg["mtp_loss_scaling_factor"] = float(getattr(mtp_cfg, "mtp_loss_scaling_factor", 0.1)) + if self.engine_config.full_determinism: + impl_cfg.setdefault("deterministic", True) + if self.engine_config.forward_only: + impl_cfg["optimizer"] = None + return impl_cfg + + def _build_bb_optimizer_config(self) -> BBOptimizerConfig: + optimizer_name = self._normalize_optimizer_name(self.optimizer_config) + betas = tuple(getattr(self.optimizer_config, "betas", (0.9, 0.999))) + override = getattr(self.optimizer_config, "override_optimizer_config", {}) or {} + offload_fraction = override.get("offload_fraction", override.get("optimizer_offload_fraction")) + if offload_fraction is None and override.get("optimizer_cpu_offload"): + offload_fraction = 1.0 + if offload_fraction is None and self.is_optimizer_offload_enabled: + offload_fraction = 1.0 + + min_lr = getattr(self.optimizer_config, "min_lr", None) + min_lr_ratio = getattr(self.optimizer_config, "min_lr_ratio", None) + if min_lr is None: + min_lr = 0.0 if min_lr_ratio is None else self.optimizer_config.lr * min_lr_ratio + + lr_decay_style = getattr(self.optimizer_config, "lr_decay_style", None) + if lr_decay_style is None: + lr_decay_style = getattr(self.optimizer_config, "lr_scheduler_type", "constant") + + return BBOptimizerConfig( + optimizer=optimizer_name, + lr=self.optimizer_config.lr, + min_lr=min_lr, + clip_grad=self.optimizer_config.clip_grad, + weight_decay=self.optimizer_config.weight_decay, + 
lr_warmup_steps_ratio=self.optimizer_config.lr_warmup_steps_ratio, + total_training_steps=self.optimizer_config.total_training_steps, + lr_warmup_steps=self.optimizer_config.lr_warmup_steps, + lr_warmup_init=getattr(self.optimizer_config, "lr_warmup_init", 0.0), + lr_decay_steps=getattr(self.optimizer_config, "lr_decay_steps", None), + lr_decay_style=lr_decay_style, + weight_decay_incr_style=getattr(self.optimizer_config, "weight_decay_incr_style", "constant"), + lr_wsd_decay_style=getattr(self.optimizer_config, "lr_wsd_decay_style", "exponential"), + lr_wsd_decay_steps=getattr(self.optimizer_config, "lr_wsd_decay_steps", None), + use_checkpoint_opt_param_scheduler=getattr( + self.optimizer_config, + "use_checkpoint_opt_param_scheduler", + False, + ), + adam_beta1=betas[0], + adam_beta2=betas[1], + adam_eps=override.get("adam_eps", override.get("eps")), + offload_fraction=offload_fraction, + use_precision_aware_optimizer=override.get("use_precision_aware_optimizer"), + decoupled_weight_decay=override.get("decoupled_weight_decay"), + ) + + @staticmethod + def _normalize_optimizer_name(config: OptimizerConfig) -> str: + optimizer_name = getattr(config, "optimizer", "adam") + lower = str(optimizer_name).lower() + if "adam" in lower: + return "adam" + raise ValueError(f"BumblebeeEngine only supports Adam-style optimizers today, got {optimizer_name!r}") + + def _extract_primary_module(self): + model = self.handle._model + if isinstance(model, list | tuple): + if not model: + raise RuntimeError("Bumblebee runtime returned an empty model chunk list.") + if len(model) > 1: + return torch.nn.ModuleList(model) + return model[0] + return model + + def _use_runtime_forward_backward(self) -> bool: + ps = self.handle._parallel_state + return ps.pp_size > 1 + + def _forward_backward_batch_with_runtime( + self, + *, + data: TensorDict, + micro_batches: list[TensorDict], + indices, + loss_function, + forward_only: bool, + ) -> dict[str, Any]: + runtime_batches = [] + 
num_micro_batches = len(micro_batches) + batch_num_tokens = tu.get_non_tensor_data(data=data, key="batch_num_tokens", default=None) + if batch_num_tokens is None: + raise ValueError("BumblebeeEngine PP/CP SFT requires batch_num_tokens for VERL-compatible loss scaling.") + if batch_num_tokens <= 0: + raise ValueError(f"batch_num_tokens must be positive, got {batch_num_tokens}.") + loss_scale = self.get_data_parallel_size() * num_micro_batches / float(batch_num_tokens) + for micro_idx, micro_batch in enumerate(micro_batches): + tu.assign_non_tensor(micro_batch, micro_batch_idx=micro_idx) + micro_batch = micro_batch.to(get_device_id()) + model_inputs = self._make_model_inputs(micro_batch) + runtime_batches.append( + { + "input_ids": model_inputs["input_ids"], + "position_ids": model_inputs["position_ids"], + "packed_seq_params": model_inputs["packed_seq_params"], + "labels": model_inputs["labels"], + "loss_mask": model_inputs.get("loss_mask"), + "loss_scale": loss_scale, + "temperature": model_inputs["temperature"], + "use_fused_kernels": model_inputs["use_fused_kernels"], + "calculate_entropy": model_inputs["calculate_entropy"], + "_verl_micro_batch": micro_batch, + "_verl_inputs": model_inputs, + } + ) + + runtime_loss_fn = None + if loss_function is not None or forward_only: + runtime_loss_fn = self._make_runtime_loss_fn(loss_function, forward_only=forward_only) + + result = self.runtime.forward_backward( + self.handle, + iter(runtime_batches), + loss_fn=runtime_loss_fn, + num_microbatches=num_micro_batches, + forward_only=forward_only, + ) + metrics = dict(result.metrics) + micro_outputs = metrics.pop("_micro_outputs", None) + if micro_outputs is not None and self.is_mp_src_rank_with_outputs(): + return postprocess_batch_func(output_lst=micro_outputs, indices=indices, data=data) + loss = float(metrics.get("loss", 0.0)) + return { + "model_output": {}, + "loss": [loss], + "metrics": {key: [value] for key, value in metrics.items()}, + } + + def 
_make_model_inputs(self, micro_batch: TensorDict) -> dict[str, torch.Tensor]: + input_ids = micro_batch["input_ids"] + if not getattr(input_ids, "is_nested", False): + raise NotImplementedError("BumblebeeEngine supports only nested no-padding THD batches.") + + ps = self.handle._parallel_state + loss_mask = self._loss_mask_for_packing(micro_batch, input_ids) + packed_batch = pack_nested_thd( + input_ids, + tp_size=ps.tp_size, + cp_size=ps.cp_size, + cp_rank=ps.cp_rank, + cp_group=ps.cp_group if ps.cp_size > 1 else None, + labels=input_ids, + roll_labels=True, + loss_mask=loss_mask, + roll_loss_mask=True, + ) + use_fused_kernels = tu.get_non_tensor_data( + data=micro_batch, + key="use_fused_kernels", + default=self.engine_config.use_fused_kernels, + ) + + return { + "input_ids": packed_batch.input_ids, + "labels": packed_batch.labels, + "loss_mask": packed_batch.loss_mask, + "position_ids": packed_batch.position_ids, + "packed_seq_params": packed_batch.packed_seq_params, + "packed_batch": packed_batch, + "temperature": self._scalar_temperature(micro_batch), + "use_fused_kernels": use_fused_kernels, + "calculate_entropy": tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False), + } + + @staticmethod + def _loss_mask_for_packing(micro_batch: TensorDict, input_ids: torch.Tensor) -> torch.Tensor | None: + if "loss_mask" not in micro_batch.keys(): + return None + + loss_mask = micro_batch["loss_mask"] + if getattr(loss_mask, "is_nested", False): + return loss_mask + + rows = [] + for seq_ids, row_mask in zip(input_ids.unbind(0), loss_mask, strict=True): + seq_len = seq_ids.numel() + response_tokens = int(row_mask.sum().item()) + if response_tokens > seq_len: + raise ValueError( + f"response loss mask has {response_tokens} tokens but packed input sequence has {seq_len} tokens" + ) + full_mask = torch.zeros(seq_len, dtype=row_mask.dtype, device=row_mask.device) + if response_tokens: + full_mask[-response_tokens:] = row_mask[:response_tokens] + 
rows.append(full_mask) + return torch.nested.as_nested_tensor(rows, layout=torch.jagged) + + def _build_verl_model_output( + self, + *, + raw_output: dict[str, torch.Tensor], + micro_batch: TensorDict, + inputs: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + del micro_batch + log_probs = raw_output.get("log_probs") + if log_probs is None: + raise ValueError("Bumblebee THD model output must contain token log_probs.") + nested_log_probs = unpack_packed_thd_to_nested(log_probs, inputs["packed_batch"]) + output = {"log_probs": nested_log_probs} + entropy = raw_output.get("entropy") + if entropy is not None: + output["entropy"] = unpack_packed_thd_to_nested(entropy, inputs["packed_batch"]) + return output + + def _make_runtime_loss_fn(self, loss_function, *, forward_only: bool): + def _loss_fn(raw_output: dict[str, torch.Tensor], runtime_batch: dict[str, Any]): + micro_batch = runtime_batch["_verl_micro_batch"] + inputs = runtime_batch["_verl_inputs"] + model_output = self._build_verl_model_output( + raw_output=raw_output, + micro_batch=micro_batch, + inputs=inputs, + ) + raw_output["_verl_model_output"] = model_output + if loss_function is not None: + loss, metrics = loss_function( + model_output=model_output, + data=micro_batch, + dp_group=self.get_data_parallel_group(), + ) + else: + loss = torch.zeros((), device=get_device_id(), dtype=torch.float32) + metrics = {} + + if raw_output.get("mtp_loss") is not None: + metrics = dict(metrics) + mtp_loss = self._reduce_mtp_metric(raw_output["mtp_loss"]) + metrics["mtp_losses/mtp_1_loss"] = ( + float(mtp_loss.item()) if mtp_loss.numel() == 1 else mtp_loss.cpu().tolist() + ) + + raw_output["_verl_metrics"] = metrics + return loss, metrics + + return _loss_fn + + def _mtp_enable_train(self) -> bool: + mtp_cfg = getattr(self.model_config, "mtp", None) + return bool( + mtp_cfg is not None + and getattr(mtp_cfg, "enable", False) + and getattr(mtp_cfg, "enable_train", False) + ) + + def _reduce_mtp_metric(self, 
mtp_loss: torch.Tensor) -> torch.Tensor: + mtp_loss = mtp_loss.detach().float().clone() + dp_group = self.get_data_parallel_group() + if dist.is_initialized() and dp_group is not None: + dist.all_reduce(mtp_loss, op=dist.ReduceOp.AVG, group=dp_group) + return mtp_loss + + @staticmethod + def _scalar_temperature(micro_batch: TensorDict) -> float: + if "temperature" not in micro_batch.keys(): + return 1.0 + temperature = micro_batch["temperature"] + if not isinstance(temperature, torch.Tensor): + return float(temperature) + values = temperature.values() if getattr(temperature, "is_nested", False) else temperature.reshape(-1) + if values.numel() == 0: + return 1.0 + first = values[0].detach() + if not torch.all(values.detach() == first).item(): + raise NotImplementedError("BumblebeeEngine currently supports scalar temperature only.") + return float(first.float().item()) + + def _checkpoint_hooks(self): + proto = self.handle._extras.get("protocol") + placement_fn = getattr(proto, "PLACEMENT_FN", default_placement_fn) + expert_classifier = getattr(proto, "EXPERT_CLASSIFIER", default_expert_classifier) + return placement_fn, expert_classifier diff --git a/verlbb/verlbb/engine/config.py b/verlbb/verlbb/engine/config.py new file mode 100644 index 00000000..916b96d2 --- /dev/null +++ b/verlbb/verlbb/engine/config.py @@ -0,0 +1,33 @@ +"""Config objects for the VerlBB Bumblebee engine.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from verl.workers.config.engine import EngineConfig + + +@dataclass +class BumblebeeEngineConfig(EngineConfig): + """Minimal VERL-facing config for the external Bumblebee engine.""" + + strategy: str = "bumblebee" + model_name: str = "auto" + impl: str = "lite" + + tp: int = 1 + etp: int | None = None + ep: int = 1 + pp: int = 1 + vpp: int = 1 + cp: int = 1 + + attention_backend_override: str | None = "flash" + router_aux_loss_coef: float | None = None + impl_cfg: dict[str, Any] = 
field(default_factory=dict) + + def __post_init__(self) -> None: + super().__post_init__() + if self.strategy != "bumblebee": + raise ValueError(f"BumblebeeEngineConfig expects strategy='bumblebee', got {self.strategy!r}")