Distillation of Neural Network Into a Soft Decision Tree

While reading about different methods for increasing the interpretability of neural network decisions, I came across this paper by Nick Frosst and Geoff Hinton. It describes how a simpler, more explainable model (e.g. a binary soft decision tree, albeit with some clever tweaks) can approximate the function learned by a more complex but less explainable model (e.g. a deep neural network) and then serve as a medium for explaining that learned function. I tried to reproduce the results and released my implementation on GitHub. Let me know if you put it to good use!
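The distillation itself boils down to training the tree on the teacher network's predicted class probabilities instead of the hard labels. Below is a minimal sketch of that soft-target/loss pairing in plain NumPy; the function names and the temperature handling are illustrative assumptions on my part, not the paper's exact recipe or my repo's API.

```python
import numpy as np

def distillation_targets(nn_logits, temperature=1.0):
    """Soften the teacher network's logits into training targets for the tree.

    The temperature and the plain softmax here are illustrative assumptions,
    not necessarily the paper's exact setup.
    """
    z = nn_logits / temperature
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    exp_z = np.exp(z)
    return exp_z / exp_z.sum(axis=1, keepdims=True)

def distillation_loss(tree_probs, soft_targets):
    """Cross-entropy between the tree's predicted class distribution and the
    teacher's softened outputs, averaged over the batch."""
    eps = 1e-12
    return -np.mean(np.sum(soft_targets * np.log(tree_probs + eps), axis=1))
```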

Inference on a Binary Soft Decision Tree

The idea is that all you need to do to understand a model's prediction is to follow the path of maximum probability through the tree and examine the filters/activations along the way. This lets model capacity grow exponentially while the complexity of a prediction's explanation (the length of the path from root to leaf) grows only linearly. The approach does have some major drawbacks (nicely described in the paper), but these are very interesting ideas nonetheless!
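As a concrete illustration, here is a minimal sketch of that maximum-probability traversal. It assumes the inner-node filters and biases are stored breadth-first in NumPy arrays and that each leaf holds a learned class distribution; the array layout and function name are my own assumptions, not the paper's or my repo's interface.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def explain_prediction(x, inner_filters, inner_biases, leaf_distributions, depth):
    """Follow the maximum-probability path through a soft decision tree.

    Hypothetical layout: inner nodes are stored breadth-first, so node i has
    children 2*i + 1 and 2*i + 2; leaf_distributions holds one class
    distribution per leaf, indexed left to right.
    """
    node = 0
    path = []
    for _ in range(depth):
        # Probability of branching right at this inner node.
        p_right = sigmoid(x @ inner_filters[node] + inner_biases[node])
        go_right = p_right >= 0.5
        path.append((node, float(p_right)))
        node = 2 * node + 1 + int(go_right)
    # Convert the complete-tree node index into a leaf index.
    leaf = node - (2 ** depth - 1)
    return path, leaf_distributions[leaf]
```

The inner nodes recorded in `path`, together with their branching probabilities, are exactly the filters one would visualize to explain the prediction, and the returned leaf distribution is the model's answer.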