There's a thought experiment I keep coming back to. Imagine you're in a car on a highway, and your self-driving system makes a mistake: it was supposed to brake, but it switched lanes instead. That's one kind of wrong. Now imagine the same situation, but this time it accelerated into the vehicle ahead. That's a very different kind of wrong.

Both are misclassifications. A standard machine learning classifier would treat them identically — one wrong prediction is one wrong prediction. But anyone with a pulse understands that these aren't the same thing at all.

This is the problem we went after in this paper, published at CODS-COMAD 2024. The setting is interpretable reinforcement learning — specifically, learning decision tree policies. And the fix, once you see it, feels almost obvious.

The Setup: Why Trees at All?

Neural networks make excellent RL agents. They handle complex environments, learn subtle patterns, and generalize well. But they're black boxes. If you're deploying an agent in healthcare, autonomous driving, or finance, "trust me, the network figured it out" is not a satisfying answer.

Decision trees are the opposite. They're human-readable. You can trace exactly why a prediction was made, verify the policy's behavior, and audit specific decisions. The catch is that they're hard to train directly with RL — the non-differentiability of trees makes gradient-based methods awkward.

The workaround that's become standard: train a neural network first, then use it as a teacher to distill a tree policy through imitation learning. Methods like VIPER and MoET do exactly this. You roll out the teacher, collect state-action pairs, and train a tree to mimic it. Repeat for a few iterations.
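
In sketch form, the loop is short. Here's a minimal DAgger-style distillation sketch, not VIPER itself: `teacher.act`, the gym-style `env`, and the `TreePolicy` wrapper are names I'm inventing for illustration, and I've omitted VIPER's Q-weighted resampling and its final best-tree selection.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

class TreePolicy:
    """Thin wrapper so a fitted tree exposes the same .act() as the teacher."""
    def __init__(self, tree):
        self.tree = tree

    def act(self, state):
        return int(self.tree.predict(np.asarray(state).reshape(1, -1))[0])

def distill(teacher, env, n_iters=5, episodes_per_iter=20, max_depth=9):
    states, actions = [], []
    policy = teacher  # first iteration rolls out the teacher itself

    for _ in range(n_iters):
        for _ in range(episodes_per_iter):
            state, done = env.reset(), False
            while not done:
                # the current policy drives, but the TEACHER provides the label
                states.append(state)
                actions.append(teacher.act(state))
                state, _, done, _ = env.step(policy.act(state))

        tree = DecisionTreeClassifier(max_depth=max_depth)
        tree.fit(np.array(states), np.array(actions))
        policy = TreePolicy(tree)

    return policy
```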

It works reasonably well. But there's a quiet assumption baked into the whole process that nobody had questioned.

The Hidden Assumption

When you train a tree to imitate a neural network, you're essentially solving a classification problem: given a state, predict the action the teacher would take. Standard classifiers minimize the number of wrong predictions. Every misclassification counts equally — one mistake is one mistake.

But in reinforcement learning, actions have Q-values. The teacher's Q-function tells you the expected cumulative reward of taking each action from each state. Which means you already have a way to measure how bad a wrong prediction actually is: it's the gap between the Q-value of the optimal action and the Q-value of whatever action you predicted instead.

Switching lanes when you should've braked: Q-gap is small, both actions lead to roughly similar outcomes. Accelerating into the vehicle ahead: Q-gap is enormous, you've just chosen the worst possible action.

Standard tree learning doesn't know this. It sees two misclassifications and penalizes them identically. We thought: what if the tree knew?

// the core idea

Define the misclassification cost of predicting action a in state s as:

cost(s, a) = max_a' Q*(s, a') − Q*(s, a)

Zero for the optimal action. Large for catastrophically wrong ones. Use this during tree construction, not just at evaluation time.
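
If the teacher exposes its Q-values, computing this is one line of arithmetic. A minimal sketch, with made-up Q-numbers to echo the highway story:

```python
import numpy as np

def misclassification_cost(q_values: np.ndarray) -> np.ndarray:
    """Per-action cost in a single state: the Q-gap to the best action.
    q_values[a] holds the teacher's Q*(s, a); cost is 0 for the argmax."""
    return q_values.max() - q_values

# illustrative Q-values for [brake, switch_lanes, accelerate]
q = np.array([5.0, 4.8, -20.0])
print(misclassification_cost(q))  # -> [0., 0.2, 25.]
```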

Making Trees Cost-Sensitive

The mechanism we used is called an Example-Dependent Cost-Sensitive Decision Tree (EDCSDT). Unlike standard trees that split on features to minimize misclassification rate, EDCSDT splits to minimize misclassification cost. The impurity criterion changes from Gini/entropy to a cost-weighted measure, and at each leaf, the predicted action is whichever minimizes the total cost across all states in that leaf — not necessarily the plurality label.
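
Here's a toy version of the changed criterion, under the assumption that each training state carries its per-action cost vector. This is an illustration of the idea, not the paper's implementation:

```python
import numpy as np

def leaf_action(costs: np.ndarray) -> int:
    """costs: (n_states, n_actions) per-state misclassification costs.
    The leaf predicts the action minimizing total cost over its states,
    which need not be the plurality teacher label."""
    return int(costs.sum(axis=0).argmin())

def leaf_cost(costs: np.ndarray) -> float:
    """Cost incurred by the best single action; replaces Gini/entropy."""
    return float(costs.sum(axis=0).min())

def split_score(costs: np.ndarray, feature: np.ndarray, threshold: float) -> float:
    """Greedy splitting picks the (feature, threshold) that minimizes
    the summed cost of the two children."""
    left = feature <= threshold
    return leaf_cost(costs[left]) + leaf_cost(costs[~left])
```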

We plugged this into both VIPER and MoET, giving us CS-VIPER and CS-MoET. In CS-MoET, there's an additional wrinkle: each expert tree in the mixture is also weighted by its "responsibility" for each training instance, so the cost-sensitive criterion gets scaled by how much that expert actually owns a given state. The math works out cleanly — the structure of the underlying policy (trees with axis-aligned splits) is identical, so all the verifiability properties carry over too.
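
A sketch of that weighting, with a hypothetical softmax gate standing in for MoET's actual gating model:

```python
import numpy as np

def responsibilities(gate_logits: np.ndarray) -> np.ndarray:
    """Softmax over experts, per state: shape (n_states, n_experts)."""
    z = gate_logits - gate_logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def costs_for_expert(costs: np.ndarray, resp: np.ndarray, k: int) -> np.ndarray:
    """Scale each state's cost row by how much expert k owns that state;
    expert k's cost-sensitive tree is then grown on the scaled costs."""
    return resp[:, k:k + 1] * costs
```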

What Actually Happened

We tested across four environments: LunarLander, Taxi, FourRooms, and a highway driving simulation. The results were cleaner than I expected.

The headline result from LunarLander: the best VIPER policy (at depth 11) hit a cumulative reward of 169.9, still well below the teacher's 207.2. CS-VIPER hit 203.2 at depth 9 — closer to the teacher, with a shallower tree. At depth 7, CS-VIPER already beat VIPER's all-time best.

The highway environment had a striking depth-1 result. A depth-1 tree can only ask one question. VIPER, optimizing for accuracy, learned to mostly predict "accelerate" — because that's the most common action in the dataset. It scored 8.6. CS-VIPER learned to mostly predict "brake" — because the teacher's Q-function assigns high value to cautious behavior. It scored 21.1. Same tree depth, same data, wildly different outcomes, because one method understood what mistakes cost.

There's one counterintuitive finding worth dwelling on. CS-VIPER and CS-MoET consistently achieve lower fidelity than their vanilla counterparts — meaning they match the teacher's exact action less often. But they accumulate higher reward. The lesson: imitating a teacher's actions as accurately as possible is not the same as learning to perform well. The Q-function gives you something richer than labels, and throwing that information away during tree construction is a mistake.

Looking at the Mistakes

We did a deeper analysis in Taxi — a gridworld where an agent picks up and drops off passengers. We ran both depth-7 policies on a shared set of states and looked at the distribution of misclassification costs.
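
In sketch form, the analysis just scores each policy's chosen action against the teacher's Q-values. Here `teacher.q_values` and `policy.act` are hypothetical accessors, as before:

```python
import numpy as np

def policy_costs(policy, teacher, shared_states):
    """For each shared state, the Q-gap of the action the tree policy
    actually picks; the histogram of these values is the distribution
    discussed below."""
    costs = []
    for s in shared_states:
        q = teacher.q_values(s)
        costs.append(q.max() - q[policy.act(s)])
    return np.array(costs)
```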

Over 90% of CS-VIPER's mistakes had a cost below 1 — small navigational errors, taking a slightly suboptimal path, nothing catastrophic. VIPER's mistakes were spread across the range: most had costs of 1–3 (taking two extra steps to backtrack), and a small but painful fraction had costs around 10 — those were cases where the tree predicted an illegal drop-off action, triggering a −10 penalty.

CS-VIPER essentially learned to avoid the expensive mistakes, even at the cost of making more cheap ones. Which is exactly what you'd want.

Why This Matters

The broader point here is about what information you use and when. The Q-function has always been available — it's sitting right there, produced as a byproduct of training the teacher. VIPER uses it to weight which states to sample (states with higher Q-spread get sampled more). But then, during tree construction, it reverts to treating every misclassification equally. We're just saying: don't stop there.
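
Roughly, that per-state weight is a single scalar. A sketch of the idea, not VIPER's exact loss term:

```python
import numpy as np

def viper_state_weight(q_values: np.ndarray) -> float:
    """Spread between the best and worst action: how much this state
    matters at all. One number per state, so it can't distinguish a
    cheap wrong action from a catastrophic one; the per-action cost
    vector used during splitting keeps that distinction."""
    return float(q_values.max() - q_values.min())
```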

In safety-critical domains, the distribution of your errors matters as much as the error rate. A model that gets 95% accuracy but makes catastrophic mistakes 5% of the time is very different from one with 93% accuracy and only benign mistakes in the tail. Cost-sensitivity is a way to push the former toward the latter.

There's plenty left to explore — how does this hold up with a weaker teacher policy? Can you incorporate uncertainty in the Q-estimates? What happens with continuous action spaces? But as a targeted intervention on a clean problem, I'm happy with how this turned out.

// paper

Cost-Sensitive Trees for Interpretable Reinforcement Learning
Siddharth Nishtala & Balaraman Ravindran
CODS-COMAD 2024 · IIT Madras