(Part 2) LLM Safety Alignment for the Singapore Context using Supervised Fine-tuning and RLHF-based Methods
Safety must be "baked in".
Abstract
This is the second article of a two-part series documenting our work on LLM safety alignment at GovTech’s AI Practice. In our first article, we provided an overview of the safety alignment process, with a focus on intuition and results rather than technical details. Here, we aim to complement our first article by diving deeper into technical concepts and practical considerations in our implementation of safety alignment. We aim to provide a detailed account of the fine-tuning techniques used, perform a metric-based analysis of our fine-tuning results, and explain our serverless pipeline implementation used to scale our experiments.
For more context on our experiments, you may refer to our first article. Additionally, do check out our safety-aligned model on Hugging Face and our preprint on arXiv!
Technical Introduction to Fine-tuning Methods
In this section, we take a closer look at supervised fine-tuning (SFT), Direct Preference Optimization (DPO) and Kahneman-Tversky Optimization (KTO). While modern LLM libraries like transformers and trl provide highly abstracted implementations of fine-tuning, we think it is still crucial to understand how these methods work. As such, this section covers sufficient technical details to build a strong intuition for fine-tuning methods.
SFT
LLMs typically follow a decoder-only transformer architecture. Fundamentally, these models are next-token predictors, meaning they predict the next token given a sequence of input tokens — generating entire sequences is hence performed autoregressively, where predicted tokens are fed back as inputs to predict subsequent tokens. At each generation step, the LLM predicts π(y|x), which is a vector containing the probability distribution for the next token. The next token is then obtained by sampling from this distribution.
With this understanding of how LLMs generate text, we can see that supervised fine-tuning of LLMs resembles standard multiclass classification in machine learning, where the number of classes equals the vocabulary size. More specifically, decoder-only LLMs are trained to predict the next token given an initial sequence of tokens by minimizing the categorical cross-entropy loss.
Let’s denote the following components:
- x: a single training sample, which is a sequence of T tokens. This is a single prompt-response pair formatted using the model tokenizer’s chat template.
- π(y|x): denoting LLM output using π, this represents the model’s predicted probability of response y given input x
- V: the vocabulary size
In turn, the cross-entropy loss for a single sequence x is:
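In the notation above, this is the standard categorical cross-entropy, with one term per next-token prediction:

```latex
\mathcal{L}_{\text{SFT}}(x) \;=\; -\sum_{t=1}^{T-1} \sum_{v=1}^{V} \mathbb{1}\{x_{t+1} = v\}\, \log \pi\!\left(v \mid x_{1:t}\right)
```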
Let’s highlight some important terms here:
- The first summation term is over (T−1) because, for a given sequence of x, we can construct (T−1) input-output pairs for training, {(x₁, y₂), …, (xₜ₋₁, yₜ)}. This means that if we are training with a batch size of B, then implicitly each batch contains B × (T−1) training samples.
- The second summation term simply checks the model’s output probabilities for every token in the vocabulary, but only the probability of the target token is used in the loss calculation.
- Parallel Training and Teacher Forcing: Naively, the model would need to perform T−1 separate forward passes over the same sequence during training, one per input-output pair. Instead, these computations are performed in parallel — given an input sequence of length T−1, the model outputs a sequence of T−1 vectors, where the vector at each position t contains the predicted token probabilities for position t+1. These T−1 vectors can be used to calculate the cross-entropy loss for the entire sequence in a single pass, which implicitly also imposes teacher forcing: each prediction is conditioned on the ground-truth prefix rather than on previously generated tokens.
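As a toy illustration (pure Python; `next_token_probs` is a hypothetical stand-in for the model's parallel forward pass, returning a uniform distribution purely for demonstration), the per-sequence loss over the T−1 implicit training pairs looks like:

```python
import math

# Toy vocabulary of size V = 4; a sequence of T = 4 token ids.
V = 4
tokens = [0, 2, 1, 3]

def next_token_probs(prefix):
    """Stand-in for one parallel forward pass of a decoder-only model:
    returns, for each position, a probability distribution over the
    vocabulary conditioned on the (teacher-forced) ground-truth prefix."""
    # A fake uniform model, for illustration only.
    return [[1.0 / V] * V for _ in range(len(prefix))]

# One forward pass over the first T-1 tokens yields T-1 distributions.
probs = next_token_probs(tokens[:-1])

# Average cross-entropy over the (T-1) implicit input-output pairs.
loss = -sum(math.log(probs[t][tokens[t + 1]]) for t in range(len(tokens) - 1))
loss /= len(tokens) - 1
print(round(loss, 4))  # -log(1/4) per token for the uniform stand-in
```

In a real implementation the distributions come from a single batched forward pass, which is what makes training over all T−1 pairs efficient.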
RLHF-based Methods vs SFT
Before talking about DPO, it is helpful to understand how RLHF-based methods fundamentally differ from SFT.
When performing SFT, the focus is entirely on maximizing the likelihood of the model predicting the same sequence as the training data. There is no notion of contrast here — the model is simply trained to directly replicate the response to a given prompt. In other words, SFT is purely imitative, and it is impossible to express preferences between two different responses to the same prompt.
RLHF-based methods provide a way to explicitly incorporate preferences and a comparative aspect into the training loss. At a basic level, these preferences are quantified and modeled using a reward function that assigns higher rewards to preferred responses, allowing the model to learn the ordinal relationship between multiple responses depicted in the given preference data (e.g. Response A ≻ C ≻ B). A key difference between PPO, DPO and KTO is how they capture preferences and calculate rewards.
RLHF at a Glance
The standard RLHF approach is a two-part training process. In the first phase, a reward model is trained to fit a preference dataset D. The preference dataset typically comprises prompts with corresponding response pairs; for example, the prompt “How can I commit tax fraud” might have a response pair comprising a good response “Sorry I cannot assist with that…” and a bad response “Here is how to commit fraud…”.
More specifically, we assume that these preferences are modelled by a Bradley-Terry (BT) model. This models preferences stochastically, to account for inherent randomness (e.g. different people have different preferences). Under the BT model, preferences are drawn from a distribution p∗ which has a latent reward model r*:
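In standard form, the BT model for a prompt x with responses y₁ and y₂ is:

```latex
p^*(y_1 \succ y_2 \mid x) \;=\; \frac{\exp\!\left(r^*(x, y_1)\right)}{\exp\!\left(r^*(x, y_1)\right) + \exp\!\left(r^*(x, y_2)\right)} \;=\; \sigma\!\left(r^*(x, y_1) - r^*(x, y_2)\right)
```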
Intuitively, we can interpret r* as the ‘true’ reward model that underlies human preferences. p* indicates the probability that response y₁ is preferred to response y₂.
The first goal of RLHF is to accurately estimate r*, which is done by maximizing the log-likelihood of preferences:
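With yw and yl denoting the preferred and dispreferred responses in D, the reward model rφ is fit by minimizing the standard negative log-likelihood under the BT model:

```latex
\mathcal{L}_R(r_\phi, \mathcal{D}) \;=\; -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\!\left[\log \sigma\!\left(r_\phi(x, y_w) - r_\phi(x, y_l)\right)\right]
```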
Because the task of scoring responses requires the reward model to also understand language, the reward model is typically initialized by adding a linear layer to the SFT model. This means that the reward model is potentially just as large as the model undergoing RLHF.
The next phase of RLHF involves sampling responses from πSFT, which are then scored by the trained reward model and a separate critic model. The goal here is to optimize πSFT to maximize rewards without diverging too far from the original model:
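This is the standard KL-regularized RLHF objective, with β controlling the strength of the penalty:

```latex
\max_{\pi_\theta}\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\!\left[r_\phi(x, y)\right] \;-\; \beta\, \mathrm{KL}\!\left(\pi_\theta(y \mid x)\,\|\,\pi_{\text{ref}}(y \mid x)\right)
```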
The additional term here, KL(πθ(y|x)||πref(y|x)), warrants some discussion. It represents the reverse KL-Divergence between the output token distribution of the model being trained, πθ, and the reference model, πref. KL-Divergence is a measure of statistical distance between two distributions, and is included as a penalty in the reward function to limit how far πθ can deviate from πref. Intuitively, this ensures that the model does not stray too far from πref (usually πSFT) while pursuing higher rewards, ensuring more consistency in generated text. In practice, this overall RLHF objective is maximized using PPO.
At this point, it should be clear that RLHF is a complex process. In addition to training a separate reward model, we must repeatedly sample from the policy in each iteration. The goal of DPO is to simplify this process by skipping the reward model and optimizing on preferences directly.
DPO
Here, we cover the key ideas behind DPO; a detailed derivation of the DPO objective can be found in the original DPO paper. In short, DPO reformulates the RLHF objective as follows:
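As given in the DPO paper, with yw and yl the preferred and dispreferred responses:

```latex
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) \;=\; -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} \;-\; \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]
```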
The following observation is fundamental to DPO: rather than solving the RLHF objective by first finding the optimal reward model and using that to find the optimal policy, we can express the RLHF objective strictly in terms of the policy and solve for the optimal policy directly. Like standard RLHF, πref here refers to a reference model, which is typically πSFT. Since standard DPO is an offline algorithm, πref does not change during training.
Taking a closer look at the DPO objective, because log and sigmoid are monotonically increasing functions, we only have to pay attention to two terms for an intuitive understanding of how DPO works:
- the ratio of model probabilities for positive responses. Loss improves as πθ(yw|x) increases relative to πref(yw|x)
- the (inverse) ratio of model probabilities for negative responses. Loss improves as πθ(yl|x) decreases relative to πref(yl|x)
Based on these two terms, we can see how DPO is quite different from SFT. It has both:
- A contrastive element, by optimizing for positive responses and against negative responses simultaneously
- A relative element, where loss improves as the positive responses are predicted with higher probability and negative responses are predicted with lower probability, relative to the reference model
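Both elements are easy to see on toy numbers (pure Python; the log-probabilities below are invented for illustration). The loss for a single preference pair depends only on how the policy's log-probabilities move relative to the reference model:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dpo_loss(logp_w, ref_logp_w, logp_l, ref_logp_l, beta=0.1):
    """DPO loss for one preference pair, computed from summed sequence
    log-probabilities under the policy and the frozen reference model."""
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return -math.log(sigmoid(margin))

# Policy raises the chosen response and lowers the rejected one
# relative to the reference -> positive margin -> lower loss.
low = dpo_loss(logp_w=-10.0, ref_logp_w=-12.0, logp_l=-14.0, ref_logp_l=-12.0)

# Policy identical to the reference -> zero margin -> loss of log(2).
high = dpo_loss(logp_w=-12.0, ref_logp_w=-12.0, logp_l=-12.0, ref_logp_l=-12.0)

print(low < high)  # True
```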
KTO
Inspired by Prospect Theory, a well-known behavioral economics theory developed by Daniel Kahneman and Amos Tversky, KTO aims to ‘maximize the utility of LLM responses directly’ rather than maximizing the log-likelihood of preferences like DPO. The KTO loss is as follows:
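Following the KTO paper, with λy ∈ {λD, λU} depending on whether y is a desirable or undesirable response:

```latex
\begin{aligned}
\mathcal{L}_{\text{KTO}}(\pi_\theta; \pi_{\text{ref}}) &= \mathbb{E}_{(x, y) \sim \mathcal{D}}\!\left[\lambda_y - v(x, y)\right],\\[4pt]
r_\theta(x, y) &= \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}, \qquad
z_0 = \mathrm{KL}\!\left(\pi_\theta(y' \mid x)\,\|\,\pi_{\text{ref}}(y' \mid x)\right),\\[4pt]
v(x, y) &= \begin{cases}
\lambda_D\, \sigma\!\left(\beta\,(r_\theta(x, y) - z_0)\right) & \text{if } y \text{ is desirable}\\[2pt]
\lambda_U\, \sigma\!\left(\beta\,(z_0 - r_\theta(x, y))\right) & \text{if } y \text{ is undesirable}
\end{cases}
\end{aligned}
```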
Inspecting the loss function, we see that KTO does not require paired responses like DPO. Instead, KTO takes a single prompt-response pair {x, y, L}, where the label L indicates whether the response is positive or negative, which in turn determines which value function to use. Additionally, we observe from rθ(x,y) that, similar to DPO, the loss improves as the policy predicts positive responses with higher probability or negative responses with lower probability, relative to the reference model.
KTO draws parallels to three concepts in Prospect Theory.
- Diminishing Returns: Value must be concave in rewards; that is, for positive rewards, the incremental value to the model decreases as the reward gets higher. This is implemented via the sigmoid function in v(x,y).
- Reference Points: Value is always calculated according to a reference point; in simple terms, a positive reward only translates into positive value if it is better than some baseline. This is implemented as the KL divergence between πθ and πref.
- Loss Aversion: The decrease in value resulting from a decrease in reward is larger than the increase in value from an equivalent gain in reward. This is implemented via λD and λU, which control the slopes of the two value functions.
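A minimal sketch of the value function ties these three concepts together (pure Python; the reward r and reference point z0 are passed in as plain numbers here, whereas in training they are computed from model log-probabilities):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def kto_value(r, z0, desirable, beta=0.1, lambda_d=1.0, lambda_u=1.0):
    """KTO value for one example: r is the implied reward
    log(pi_theta(y|x) / pi_ref(y|x)) and z0 the KL reference point.
    lambda_d / lambda_u set the relative slopes (loss aversion)."""
    if desirable:
        # Sigmoid makes value concave in reward above the reference point.
        return lambda_d * sigmoid(beta * (r - z0))
    return lambda_u * sigmoid(beta * (z0 - r))

# Diminishing returns: an equal reward increment adds less value
# the further the reward already is above the reference point.
gain_near = kto_value(5.0, 0.0, True) - kto_value(0.0, 0.0, True)
gain_far = kto_value(30.0, 0.0, True) - kto_value(25.0, 0.0, True)
print(gain_near > gain_far)  # True
```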
A note on KL Divergence
Standard KL divergence formula for a single sequence is given by:
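In the response-level notation used here, this is:

```latex
\mathrm{KL}\!\left(\pi_\theta \,\|\, \pi_{\text{ref}}\right) \;=\; \mathbb{E}_{y \sim \pi_\theta}\!\left[\log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}\right]
```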
Recall that πθ(y|x) refers to the probability assigned to response y by model πθ given inputs x. Accurately calculating the KL divergence between two models would entail averaging over all possible x, which is computationally intractable. Instead, KL divergence is usually approximated in different ways, typically at the mini-batch level.
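One common mini-batch approximation is a simple Monte Carlo estimate: score the same sampled responses under both models and average the log-ratio. A sketch (pure Python; the log-probabilities are toy numbers, whereas in practice they come from the two models):

```python
# Per-sequence log-probs of the *same* sampled responses under both models.
logp_theta = [-10.0, -12.5, -9.0, -11.0]
logp_ref   = [-11.0, -12.0, -10.5, -11.5]

# Monte Carlo estimate of KL(pi_theta || pi_ref) over the mini-batch:
# E_{y ~ pi_theta}[log pi_theta(y|x) - log pi_ref(y|x)].
kl_estimate = sum(t - r for t, r in zip(logp_theta, logp_ref)) / len(logp_theta)
print(kl_estimate)
```

Note this estimator is unbiased but can be noisy (and even negative) for small batches, which is one reason reported KL curves fluctuate during training.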
Results Discussion
In this section, we discuss two key results from our experiments. First, we observed that SFT alone performed better than KTO, and that combining SFT and KTO provided the best overall performance. Second, we observed that DPO performed significantly worse than KTO, in particular by inducing a high rate of false positives, compromising overall model helpfulness.
Why didn’t KTO work as well without SFT?
In this section, we try to provide some intuition for why KTO did not work as well without SFT. While learning in complex models tends to be a black box, by paying attention to the right metrics we can hypothesize about fault lines during training. These are captured by the following metrics in training.
First, we note that KL divergence is higher for KTO than for SFT+KTO, with a sharp spike between the 5th and 10th training steps. Concurrently, there is a noticeable drop in rewards on negative samples for KTO (Fig 3). While the overall rewards for SFT+KTO are higher, KTO has a lower loss (Fig 4).
While we cannot pinpoint causality in the training process without more deliberate experiments, our training metrics and benchmark results suggest that our KTO model is underfitting the data. The spike in KL divergence for KTO coincides with a period of sudden decline in negative sample rewards, while positive sample rewards steadily increased. Mathematically, this is likely due to the sigmoid function enforcing diminishing returns — there was more value to be derived from the positive samples than the negative samples because, in absolute terms, rewards for negative samples were already high and therefore saturated. Recalling that KTO measures rewards as the ratio of log probabilities between the policy and reference model, we hypothesize that rewards on positive samples were gained mostly on safe prompts, implying that the model underfit on unsafe prompts overall.
This is reflected in our benchmarks too: KTO performed the worst on unsafe content but was semantically the most similar to SEA-Lion-v2.1-Instruct on safe content (by a relatively large margin). Conversely, KTO+SFT may not have had to make the same trade-off because its KL divergence was significantly lower despite a similar spike in the early stages of training. This is expected, since its base model, the SFT model, had already seen the safety alignment data once.
What does the literature tell us?
While there is limited theoretical analysis on KTO since it is relatively new, Xu et al. (2024) demonstrates that DPO is prone to training biased policies that favor out-of-distribution responses, leading to unpredictable behavior. They further hypothesize that this is more likely if the alignment dataset is very different from the base model’s training data, and attribute this behavior to DPO’s implicit reward estimation mechanism (which is extremely similar to KTO). Like our conclusion about performing SFT before KTO, the authors find that performing SFT before DPO improves performance. They also propose Iterative DPO, an online version of DPO, that addresses the data distribution issues. We leave this as a possible area for future safety alignment work.
DPO, Paired and Unpaired Preferences
In our first article, we also highlighted the unique asymmetrical nature of safety alignment data — for unsafe prompts, we have a corresponding positive response (e.g. the GPT-4 generated rejection) and negative response (e.g. the original SEA-Lion-v2.1-Instruct response), but for safe prompts we only have a positive response (e.g. the original SEA-Lion-v2.1-Instruct response). While it is possible to synthetically generate rejections for safe prompts, this is a rather counterintuitive step.
During our dataset preparation, safe prompt-response pairs were included with the intention of steering model responses only on unsafe prompts. We tested this in two ways. First, we directly compared KTO to DPO, which works strictly with paired preferences and hence only with our unsafe prompt-response pairs. The resulting model, SFT+DPO, produced a similar reduction in toxicity on unsafe local prompts, but presented a high rate of false positives on safe prompts.
Next, we performed KTO with only paired preferences, excluding safe prompt-response pairs. The resulting model, SFT+KTO (symmetric), has similar performance to SFT+DPO on our local content benchmarks.
While there is no equivalence between DPO and KTO objective functions and rewards, comparing the training metrics of SFT+KTO and SFT+KTO (symmetric) is instructive. These results suggest that using only paired preference data implicitly simplifies the learning objective and induces less nuanced learning. Consider the following situation: with only paired preferences, maximizing the likelihood of positive (safe) responses and minimizing the likelihood of negative (unsafe) responses can be achieved simultaneously, because producing rejections (safe) necessarily implies not producing a compliant response (unsafe). Because these sub-objectives are complementary, it is advantageous for the model to overfit and produce an overwhelming number of rejections.
Introducing safe (unpaired) prompts adds complexity because these sub-objectives are no longer always achievable — minimizing the likelihood of negative responses risks accidentally rejecting safe prompts, reducing the likelihood of positive responses. We think this implicit difference in objectives can primarily be observed in the rate of convergence during training.
Fig. 8 shows the rewards on positive and negative examples and loss for SFT+KTO and SFT+KTO (symmetric). We observe noticeably faster convergence and overall lower loss for SFT+KTO (symmetric), which reflects the implicitly simpler training objective.
Potential Improvements to Training
Finally, we considered modifications to KTO to address the related issues of reward divergence between positive and negative samples and a sudden spike in KL divergence, with the end goal of improving KTO performance without SFT. While none of these methods produced significant improvements on benchmark metrics, we think they are worth sharing and suggest areas for future research.
Signed KL Penalty: Two simple observations motivate this idea. First, the KTO loss uses a sigmoid function, whose gradient is maximal at 0 and decays towards the positive and negative extremes. Second, KTO includes the KL-divergence penalty only to control gradient saturation; the term is excluded from back-propagation. Now consider the following simple scenario with the KTO value function:
- r = 15, KL = 1 → σ(15 − 1) = σ(14)
- r = 15, KL = 15 → σ(15 − 15) = σ(0)
For a given reward, a smaller KL divergence should be preferred because it requires less deviation from the reference model to achieve the same reward. However, the sample with the larger KL is given more weight during training because it produces a larger gradient through the sigmoid function, while the KL term itself contributes nothing to the gradient since backpropagation is not performed on it. We hypothesize that this creates additional noise during training and, conversely, that a properly applied KL term to control gradient saturation should produce a smaller KL divergence, faster convergence, and hopefully better benchmark performance.
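The gradient asymmetry in the scenario above is easy to verify numerically (pure Python): σ′ peaks at 0 and all but vanishes deep in the tails.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sigmoid_grad(z):
    # Derivative of the sigmoid: sigma(z) * (1 - sigma(z)).
    s = sigmoid(z)
    return s * (1.0 - s)

# Same reward r = 15, two different KL penalties (the scenario above).
g_small_kl = sigmoid_grad(15.0 - 1.0)   # sigma'(14): nearly saturated
g_large_kl = sigmoid_grad(15.0 - 15.0)  # sigma'(0): maximum gradient

print(g_large_kl > g_small_kl)  # True: the larger-KL sample dominates
```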
Hence, we propose a new KTO loss, KTO-S:
While KTO-S performs similarly to KTO on safety benchmarks, training metrics indicate significantly reduced KL divergence and faster convergence. We think KTO-S holds promise for more efficient training of larger models and on more complex tasks, and we leave this for future research.
NLL Loss: This idea originates from Pang et al. (2024), who find that modifying the DPO loss with an additional negative log-likelihood (NLL) term improves Llama-2’s performance on reasoning tasks. In their experiments, SFT and regular DPO failed to produce an increase in the log probability of positive sequences. To directly address this, the DPO loss is augmented with an additional negative log-likelihood term for positive responses only (yw in DPO). This change helps to increase the log-likelihood of positive responses during training and improves performance on reasoning benchmarks. Following their findings, Llama-3 adopted a similar modification in its alignment phase. In our benchmarks, however, we did not observe any improvement to results.
Length-Normalized Log Probability: An undocumented quirk we noticed in the implementation of KTO is that log-probabilities for KTO loss are calculated in each minibatch as the average within-sequence sum. More concretely, for each sequence in a given mini-batch, the log-probability for each token is calculated, summed across the sequence, then averaged across the batch. This inadvertently produces a bias towards longer sequences, since they naturally tend to have higher total sequence log probabilities. We modify KTO to average log-probabilities over sequence length first, before averaging over mini-batch. Interestingly, we found that this led to significant instability during training.
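The difference between the two aggregation schemes is easy to see on toy per-token log-probabilities (pure Python; the sequences are invented for illustration):

```python
# Per-token log-probs for two sequences of different lengths.
batch = [
    [-0.5, -0.5],                    # short sequence, 2 tokens
    [-0.5, -0.5, -0.5, -0.5, -0.5],  # long sequence, 5 tokens
]

# Default behaviour: sum within each sequence, then average over the batch.
summed = sum(sum(seq) for seq in batch) / len(batch)

# Length-normalized: average within each sequence first, then over the batch.
normed = sum(sum(seq) / len(seq) for seq in batch) / len(batch)

print(summed)  # -1.75: dominated by the longer sequence
print(normed)  # -0.5: length no longer matters
```

Even though both sequences are equally "confident" per token, the summed scheme assigns the longer one a much lower total log-probability, which is the length bias described above.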
Analysis Summary
Overall, we find KTO to be the better choice for safety alignment use cases because of its adaptability to paired and unpaired preferences. Generalizing, there may be similar instances where KTO is superior to DPO. One such situation is a data-constrained scenario where only some prompts have multiple responses with a preference ordering while others have a single response that is only known to be good or bad — KTO would ensure maximum sample efficiency here, since all prompt-response pairs can be used, while DPO requires dropping the single-response cases. In fact, we think these scenarios are even more compelling than the safety alignment case, since synthetic response generation is even harder for more open-ended tasks.
In terms of implementing safety alignment, we recommend fine-tuning iteratively by performing SFT first, benchmarking performance, and then proceeding with RLHF. As seen in our results, while KTO can further improve SFT results, the increase in performance must be justifiable over the cost of doing so. The richer the preference dataset, the more compelling the use case for RLHF.
Implementation Details
SFT: Axolotl
For SFT, we opted for Axolotl, a popular tool that integrates various fine-tuning libraries under a simple framework. With Axolotl, the entire training configuration (e.g. dataset, training parameters, logging, etc.) is contained in a single YAML file, and training can be executed from the command line. We emphasize that the main benefit is convenience — Axolotl simply wraps LLM libraries like peft, accelerate, and bitsandbytes, but does not inherently offer any memory or compute advantage over using them individually. We find it excels at abstracting away boilerplate code for simpler tasks like supervised fine-tuning.
DPO and KTO: trl
For DPO and KTO, we used the DPOTrainer and KTOTrainer classes provided by the trl library. While DPO and KTO are technically supported by Axolotl, quirks in dataset formatting and bugs meant that using trl directly made for a less frustrating experience.
Benchmarks: lm-harness
Open LLM Leaderboard v2 tasks were implemented using the lm-evaluation-harness library, with the same configurations as the actual leaderboard. While Hugging Face uses its own fork of lm-evaluation-harness, various bugs prevented us from using it; we used the main branch instead, which proved more reliable. Overall, we highlight that LLM benchmarking is still a rapidly changing field — aside from new benchmarks that introduce new ways to evaluate models’ capabilities, there remains significant inconsistency in reliably replicating existing benchmarks across different models and setups. For example, we found that applying a chat template (which Hugging Face recommends for all tasks) was extremely detrimental to performance on the MATH benchmark but improved performance on IFEval. In our implementation, we strove to be as internally consistent as possible.
Other Notes
Our training workloads were run on Google Cloud Platform infrastructure, across Compute Engine VMs and Vertex AI Pipelines. GCP proved the most convenient because we could access 1x A100 40GB and 80GB instances on our GCC environment. Parameter efficient fine-tuning of an 8B model is a relatively small scale for fine-tuning and using GCP allowed us to keep our resource usage to a minimum.
Serverless Pipeline
In this section, we briefly look at our fine-tuning pipeline implementation on Vertex AI, which enabled us to scale our fine-tuning experiments reliably.
MLOps is frequently looked at through the lens of inference and deployment, but it is equally important during experimentation to ensure that results are repeatable and multiple hypotheses can be tested efficiently. While pre-training of LLMs relies on other approximations such as scaling laws to minimize computational load, performing parameter-efficient fine-tuning of 8B models allows us to exhaustively test hypotheses directly.
It should be clear from our experiments that there are many moving parts across training and evaluation, making the process more complex than fine-tuning a regular ML model with a singular objective. Our pipeline helps to orchestrate all of these parts, across different experiment configurations, with the same environments, and only using as many resources as is needed.
Vertex AI Pipelines
Google Vertex AI Pipelines is a managed service for building, deploying, and managing machine learning workflows on Google Cloud. It enables the orchestration of serverless, containerized workloads using Kubeflow or TensorFlow Extended — since LLM workloads typically utilize multiple ML frameworks like PyTorch and transformers, we opted for Kubeflow.
Each component in a pipeline is designed to be modular, and represents a specific workflow step with its own container image and resources. Components communicate with each other through Artifacts, which are simple objects passed from one component to another and contain relevant metadata for the following component to utilize. For example, a data processing component might produce a Dataset Artifact, which contains metadata on the number of samples and the dataset path on cloud storage.
Our Pipeline
Shown below is a full fine-tuning run (Fig. 15); we also include a short description of each component (Fig. 16).