raw.sh

Batched reward model inference and Best-of-N sampling

Published on November 19, 2024
robert washbourne
@rawsh0

Reward models have been a key part of reinforcement learning on top of LLMs, used broadly in techniques like RLHF and as LLM-as-a-judge critics in evals. They have also been used in the data preparation phase of preference optimization methods like SimPO, where a reward model was used to create the preference data used to train models like princeton-nlp/gemma-2-9b-it-SimPO.

I recently had some trouble figuring out how to run high-throughput reward model inference. Offline, you can collect everything you need to score in advance and work through the data in batches. But for many use cases, such as search methods like tree search or MCTS, this is hard to do efficiently, since the inputs to score only become known as the search progresses. Unfortunately, easy-to-set-up inference servers like vLLM and llama.cpp don't support sequence classification models [1] out of the box (yet [2]).

Dynamic batching with batched

Thankfully, Mixedbread has a great generic dynamic batching library, batched, that makes it easy to run inference on reward models [3].
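To make the idea concrete, here is a minimal sketch of dynamic batching in plain asyncio. It is not how batched is implemented; DynamicBatcher, submit, and fake_score are hypothetical names used only for illustration. Concurrent callers each submit one item, and a background worker groups whatever has arrived into a single batched call once the batch is full or the timeout expires.

```python
import asyncio
from typing import Callable, List


class DynamicBatcher:
    """Toy dynamic batcher (hypothetical; not the batched library's internals)."""

    def __init__(self, batch_fn: Callable[[List[str]], List[float]],
                 batch_size: int = 8, timeout_ms: float = 100.0):
        self.batch_fn = batch_fn
        self.batch_size = batch_size
        self.timeout = timeout_ms / 1000.0
        self.queue = None   # created lazily inside the running event loop
        self.worker = None

    async def submit(self, item: str) -> float:
        """Queue one item and wait for its result from a batched call."""
        if self.worker is None:
            self.queue = asyncio.Queue()
            self.worker = asyncio.create_task(self._run())
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            # Block until at least one request arrives.
            pending = [await self.queue.get()]
            deadline = loop.time() + self.timeout
            # Keep collecting until the batch is full or the timeout expires.
            while len(pending) < self.batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    pending.append(
                        await asyncio.wait_for(self.queue.get(), remaining)
                    )
                except asyncio.TimeoutError:
                    break
            # One batched call, then hand each caller its own result.
            results = self.batch_fn([item for item, _ in pending])
            for (_, future), result in zip(pending, results):
                future.set_result(result)


batch_sizes: List[int] = []


def fake_score(texts: List[str]) -> List[float]:
    """Stand-in for a reward model: records the batch size, scores by length."""
    batch_sizes.append(len(texts))
    return [float(len(t)) for t in texts]


async def main() -> List[float]:
    batcher = DynamicBatcher(fake_score, batch_size=4, timeout_ms=50.0)
    # Six concurrent requests get grouped into batches of at most four.
    return await asyncio.gather(*(batcher.submit("x" * i) for i in range(1, 7)))


scores = asyncio.run(main())
```

The two knobs mirror the batch_size and timeout_ms arguments used with batched below: a larger batch improves GPU utilization, while the timeout bounds how long a lone request waits for company.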

Here's how we can host an endpoint on Modal [4]. First, let's add some helper functions and set up the image:

from typing import List, Dict
import modal

# Container image with the inference dependencies and fast Hugging Face downloads
image = modal.Image.debian_slim().pip_install([
    "torch", "transformers", "accelerate", "batched", "hf_transfer"
]).env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})

app = modal.App("reward-api", image=image)

MODEL_NAME = "RLHFlow/ArmoRM-Llama3-8B-v0.1"

# Imports that are only needed inside the container image
with image.imports():
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from batched import inference

def validate_messages(messages: List[Dict[str, str]]):
    if not messages or len(messages) < 2:
        raise ValueError("Messages must contain at least a user and assistant message")
    if not all(isinstance(m, dict) and 'role' in m and 'content' in m for m in messages):
        raise ValueError("Each message must have 'role' and 'content' fields")

class RewardModelHelper:
    def __init__(self, model):
        self.model = model

    # Group concurrent calls into batches of up to 8, waiting at most 100 ms
    @inference.dynamically(batch_size=8, timeout_ms=100.0)
    def score_batch(self, features: dict[str, torch.Tensor]) -> torch.Tensor:
        with torch.no_grad():
            # Move input to same device as model
            inputs = {k: v.to(self.model.device) for k, v in features.items()}
            return self.model(inputs["input_ids"]).score.float()

I'm using RLHFlow/ArmoRM-Llama3-8B-v0.1, the same reward model used by the SimPO team. It uses Mixture-of-Experts aggregation of reward objectives, with each expert corresponding to a human-interpretable objective (such as helpfulness, safety, or coherence).

[Figure: ArmoRM MoE]

Check out the blog post from RLHFlow for more details [5].
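As a rough illustration of what aggregating reward objectives means, the sketch below combines per-objective scores into one scalar using softmax-normalized gating weights. This is a simplified assumption, not ArmoRM's actual implementation: the function names, objective names, and numbers are all made up, and the real model produces both the objective scores and the gating weights with learned layers.

```python
import math
from typing import Dict


def softmax(logits: Dict[str, float]) -> Dict[str, float]:
    """Turn arbitrary gating logits into non-negative weights that sum to 1."""
    largest = max(logits.values())
    exps = {name: math.exp(value - largest) for name, value in logits.items()}
    total = sum(exps.values())
    return {name: value / total for name, value in exps.items()}


def aggregate_reward(objective_scores: Dict[str, float],
                     gating_logits: Dict[str, float]) -> float:
    """Weighted sum of per-objective scores (hypothetical helper)."""
    weights = softmax(gating_logits)
    return sum(weights[name] * score for name, score in objective_scores.items())


# Illustrative values only; these are not outputs of the real model.
objective_scores = {"helpfulness": 0.8, "safety": 0.4, "coherence": 0.6}

# Equal logits weight every objective equally.
uniform_reward = aggregate_reward(
    objective_scores, {"helpfulness": 0.0, "safety": 0.0, "coherence": 0.0}
)

# A large logit for one objective makes it dominate the final reward.
safety_heavy_reward = aggregate_reward(
    objective_scores, {"helpfulness": 0.0, "safety": 20.0, "coherence": 0.0}
)
```

With equal weights the reward is the plain average of the objectives. Skewing the gate toward safety pulls the reward toward the safety score.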

Now we can add the meat of the code: a web endpoint that scores an LLM completion using dynamic batching under the hood.

@app.cls(
    gpu=modal.gpu.A10G(),
    allow_concurrent_inputs=1000,
    container_idle_timeout=120,
)
class Model:
    def load_model(self):
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            device_map="cuda",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        return model

    # Runs at image build time, so the weights are downloaded into the image
    @modal.build()
    def build(self):
        self.load_model()

    # Runs once per container start, before any request is served
    @modal.enter()
    def setup(self):
        self.model = self.load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            use_fast=True,
        )
        self.score_batch = RewardModelHelper(self.model).score_batch

    @modal.web_endpoint(method="POST")
    async def score(self, messages_dict: Dict[str, List[Dict[str, str]]]):
        messages = messages_dict["messages"]
        validate_messages(messages)
        inputs = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Awaiting here lets concurrent requests be grouped into one batch
        score = await self.score_batch.acall({"input_ids": inputs})
        return {"score": score[0].item()}

The allow_concurrent_inputs setting lets a single container accept many requests at once instead of Modal starting additional containers, so concurrent requests land in the same process and can be grouped by dynamic batching. Now we can use this reward model endpoint with a hosted API or an open-source model + vLLM.
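For reference, a client call might look like the sketch below. The request body matches the score endpoint above ({"messages": [...]}) and the response is read from its "score" field. The helper names and the example URL are placeholders, since the real URL is whatever Modal prints when the app is deployed.

```python
import json
import urllib.request
from typing import Dict, List


def build_payload(prompt: str, completion: str) -> Dict[str, List[Dict[str, str]]]:
    """Wrap a prompt and completion in the chat format the endpoint expects."""
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]
    }


def score_completion(endpoint_url: str, prompt: str, completion: str,
                     timeout: float = 60.0) -> float:
    """POST one prompt/completion pair and return the scalar reward."""
    body = json.dumps(build_payload(prompt, completion)).encode("utf-8")
    request = urllib.request.Request(
        endpoint_url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return float(json.loads(response.read())["score"])


# Building the payload needs no network access.
payload = build_payload("What is the capital of France?", "Paris.")

# Hypothetical URL; replace it with the one printed by your own deployment.
# reward = score_completion("https://example.invalid/score",
#                           "What is the capital of France?", "Paris.")
```

Sending many of these requests concurrently (from threads or an async HTTP client) is what gives the server something to batch.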

Evaluating the reward model

Since my cloud GPU bill this month is kind of ridiculous, I decided to evaluate on a random 100-question subset of TruthfulQA, a multiple-choice benchmark for LLMs.

One way to benchmark the reward model on multiple-choice data is to score each possible answer to the question and check where the correct answer ranks [6]. Here are the results:

Average rank of correct answer: 2.10
Times correct answer ranked first: 53/100
Average score difference from best: 0.0163
Average correct answer score: 0.0659
Average best score: 0.0821
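The metrics above can be computed from per-answer scores along the following lines. This is a sketch under assumptions rather than the code in eval_reward.py: summarize_rankings is a hypothetical helper, each question is represented as a list of answer scores plus the index of the correct answer, and rank is defined as one plus the number of answers that score strictly higher.

```python
from typing import Dict, List, Tuple


def summarize_rankings(questions: List[Tuple[List[float], int]]) -> Dict[str, float]:
    """Aggregate ranking metrics over (answer_scores, correct_index) pairs."""
    ranks, gaps, correct_scores, best_scores = [], [], [], []
    for scores, correct_index in questions:
        correct = scores[correct_index]
        best = max(scores)
        # Rank 1 means no other answer scored strictly higher.
        ranks.append(1 + sum(score > correct for score in scores))
        gaps.append(best - correct)
        correct_scores.append(correct)
        best_scores.append(best)
    count = len(questions)
    return {
        "average_rank": sum(ranks) / count,
        "times_first": sum(rank == 1 for rank in ranks),
        "average_gap_to_best": sum(gaps) / count,
        "average_correct_score": sum(correct_scores) / count,
        "average_best_score": sum(best_scores) / count,
    }


# Toy data, not the TruthfulQA results: the correct answer wins the first
# question and comes second in the other.
toy_questions = [
    ([0.1, 0.3, 0.2], 1),
    ([0.5, 0.2], 1),
]
summary = summarize_rankings(toy_questions)
```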

The reward model scores the correct answer the highest 53% of the time. Next we will evaluate Best-of-N sampling, choosing the best LLM completion using the reward model.

Best-of-N-Sampling

Now we can test an LLM with a reward model verifier, a simple way to add more test-time compute. Best-of-N-Sampling doesn't add much latency (we can sample and score hundreds of completions in parallel with batching) and is easy to implement [7].
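The core loop is short enough to sketch here. This is an illustration, not the code in best_of_n.py: best_of_n takes two hypothetical async callables, generate (samples one completion for a prompt) and score (returns the reward for a prompt/completion pair), runs both stages concurrently, and keeps the highest-scoring completion. The stubs at the bottom stand in for the LLM and the reward endpoint.

```python
import asyncio
from typing import Awaitable, Callable, Tuple


async def best_of_n(
    prompt: str,
    generate: Callable[[str], Awaitable[str]],
    score: Callable[[str, str], Awaitable[float]],
    n: int,
) -> Tuple[str, float]:
    """Sample n completions, score them all, and return the best one."""
    if n < 1:
        raise ValueError("n must be at least 1")
    # Sample all completions concurrently.
    completions = await asyncio.gather(*(generate(prompt) for _ in range(n)))
    # Score them concurrently so the reward server can batch the requests.
    scores = await asyncio.gather(*(score(prompt, c) for c in completions))
    best_index = max(range(n), key=lambda i: scores[i])
    return completions[best_index], scores[best_index]


# Stubs standing in for the LLM and the reward model (illustrative only).
candidates = iter(["Paris", "Lyon", "Marseille", "Paris, France"])


async def fake_generate(prompt: str) -> str:
    return next(candidates)


async def fake_score(prompt: str, completion: str) -> float:
    # Pretend longer answers are better, just to make the choice deterministic.
    return float(len(completion))


best_completion, best_score = asyncio.run(
    best_of_n("What is the capital of France?", fake_generate, fake_score, n=4)
)
```

With n=1 this reduces to ordinary sampling, which is the baseline column in the results below.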

I get the following zero-shot accuracy with Llama-3.1-8B-Instruct:

# completions     1    2    4    8   16   32   64
accuracy (%)     58   63   64   63   70   67   67

[Figure: best of n results]

Using the reward model with Best-of-N-Sampling yields a 20.7% relative increase in accuracy from n=1 to n=16 (58% to 70%, or 12 percentage points), with performance plateauing beyond 16 generations.

Footnotes

  1. Reward models often use architectures similar to SequenceClassification, but output a single scalar rather than multiple class outputs.

  2. vLLM supports Qwen2.5-Math-RM-72B (see the model implementation), but a reward model API is still on the roadmap.

  3. https://www.mixedbread.ai/blog/dynamic-batching

  4. Not sponsored, I just really like the development workflow. To deploy this on your own machine you can swap out the Modal code for a local Flask server.

  5. RLHFlow has a lot of great resources. Also check out RLHFlow/RLHF-Reward-Modeling for reward model training recipes.

  6. Reward model benchmark code: eval_reward.py. Note that evaluating answers directly is "harder" than evaluating model completions, since we aren't providing any chain-of-thought data to the reward model, just the final answer.

  7. Best-of-N-Sampling benchmark code: best_of_n.py