Skip to content

Commit 38df00d

Browse files
Gary Hemming authored and fboudry committed
Fix linting, make changes based on LightGBM maintainer feedback, better Help examples including categorical features
1 parent 2b1458e commit 38df00d

File tree

3 files changed

+291
-145
lines changed

3 files changed

+291
-145
lines changed

R-package/NAMESPACE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,18 @@ importClassesFrom(Matrix,dgCMatrix)
4444
importClassesFrom(Matrix,dgRMatrix)
4545
importClassesFrom(Matrix,dsparseMatrix)
4646
importClassesFrom(Matrix,dsparseVector)
47+
importFrom(DiagrammeR,add_global_graph_attrs)
48+
importFrom(DiagrammeR,create_edge_df)
49+
importFrom(DiagrammeR,create_graph)
50+
importFrom(DiagrammeR,create_node_df)
51+
importFrom(DiagrammeR,render_graph)
4752
importFrom(Matrix,Matrix)
4853
importFrom(R6,R6Class)
4954
importFrom(data.table,":=")
5055
importFrom(data.table,as.data.table)
5156
importFrom(data.table,data.table)
57+
importFrom(data.table,fcoalesce)
58+
importFrom(data.table,fifelse)
5259
importFrom(data.table,rbindlist)
5360
importFrom(data.table,set)
5461
importFrom(data.table,setnames)

R-package/R/lgb.plot.tree.R

Lines changed: 190 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,114 @@
11
#' @name lgb.plot.tree
2-
#' @title Plot a single LightGBM tree.
3-
#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree.
2+
#' @title Plot LightGBM trees.
3+
#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of one or more LightGBM trees.
44
#' @param model a \code{lgb.Booster} object.
5-
#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation).
5+
#' @param tree An integer vector of tree indices that should be visualized. IMPORTANT:
6+
#' the tree index in lightgbm is zero-based, i.e. use tree = 0 for the first tree in a model.
67
#' @param rules a list of rules to replace the split values with feature levels.
8+
#' @param render a logical flag for whether the graph should be rendered (see Value).
9+
#' @param plot_width the width of the diagram in pixels.
10+
#' @param plot_height the height of the diagram in pixels.
711
#'
812
#' @return
9-
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot.
13+
#' When \code{render = TRUE}:
14+
#' returns a rendered graph object which is an \code{htmlwidget} of class \code{grViz}.
15+
#' Similar to ggplot objects, it needs to be printed to see it when not running from command line.
16+
#'
17+
#' When \code{render = FALSE}:
18+
#' silently returns a graph object which is of DiagrammeR's class \code{dgr_graph}.
19+
#' This could be useful if one wants to modify some of the graph attributes
20+
#' before rendering the graph with \code{\link[DiagrammeR]{render_graph}}.
1021
#'
1122
#' @details
12-
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value.
23+
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of one or more LightGBM trees.
24+
#' Each selected tree is extracted from the model and displayed as a directed graph.
25+
#' The nodes are labelled with the feature, count, gain (splits only) and value.
26+
#' The edges are labelled with the decision type and split value.
1327
#'
1428
#' @examples
1529
#' \donttest{
16-
#' # EXAMPLE: use the LightGBM example dataset to build a model with a single tree
30+
#' \dontshow{setLGBMthreads(2L)}
31+
#' \dontshow{data.table::setDTthreads(1L)}
32+
#' # Example One
1733
#' data(agaricus.train, package = "lightgbm")
1834
#' train <- agaricus.train
1935
#' dtrain <- lgb.Dataset(train$data, label = train$label)
20-
#' data(agaricus.test, package = "lightgbm")
21-
#' test <- agaricus.test
22-
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
23-
#' # define model parameters and build a single tree
2436
#' params <- list(
25-
#' objective = "regression",
26-
#' min_data = 1L,
37+
#' objective = "regression"
38+
#' , metric = "l2"
39+
#' , min_data = 1L
40+
#' , learning_rate = 0.3
41+
#' , num_leaves = 5L
2742
#' )
28-
#' valids <- list(test = dtest)
2943
#' model <- lgb.train(
30-
#' params = params,
31-
#' data = dtrain,
32-
#' nrounds = 1L,
33-
#' valids = valids,
34-
#' early_stopping_rounds = 1L
44+
#' params = params
45+
#' , data = dtrain
46+
#' , nrounds = 5L
47+
#' )
48+
#'
49+
#' # Plot the first tree
50+
#' lgb.plot.tree(model, 0L)
51+
#'
52+
#' # Plot the first and fifth trees
53+
#' lgb.plot.tree(model, c(0L,4L))
54+
#'
55+
#' # Example Two - model uses categorical features
56+
#' data(bank, package = "lightgbm")
57+
#'
58+
#' # We are dividing the dataset into two: one train, one validation
59+
#' bank_train <- bank[1L:4000L, ]
60+
#' bank_test <- bank[4001L:4521L, ]
61+
#'
62+
#' # We must now transform the data to fit in LightGBM
63+
#' # For this task, we use lgb.convert_with_rules
64+
#' # The function transforms the data into a format that LightGBM can fit
65+
#' bank_rules <- lgb.convert_with_rules(data = bank_train)
66+
#' bank_train <- bank_rules$data
67+
#'
68+
#' # Subtract 1 from the label because it must be between 0 and 1
69+
#' bank_train$y <- bank_train$y - 1L
70+
#'
71+
#' # Data input to LightGBM must be a matrix, without the label
72+
#' my_data_train <- as.matrix(bank_train[, 1L:16L, with = FALSE])
73+
#'
74+
#' # Creating the LightGBM dataset with categorical features
75+
#' # The categorical features can be passed to lgb.train to not copy and paste a lot
76+
#' dtrain <- lgb.Dataset(
77+
#' data = my_data_train
78+
#' , label = bank_train$y
79+
#' , categorical_feature = c(2L, 3L, 4L, 5L, 7L, 8L, 9L, 11L, 16L)
80+
#' )
81+
#'
82+
#' # Train the model with 5 training rounds
83+
#' params <- list(
84+
#' objective = "binary"
85+
#' , metric = "l2"
86+
#' , learning_rate = 0.1
87+
#' , num_leaves = 5L
88+
#' )
89+
#' model_bank <- lgb.train(
90+
#' params = params
91+
#' , data = dtrain
92+
#' , nrounds = 5L
3593
#' )
36-
#' # plot the tree and compare to the tree table
37-
#' # trees start from 0 in lgb.model.dt.tree
38-
#' tree_table <- lgb.model.dt.tree(model)
39-
#' lgb.plot.tree(model, 0)
40-
#' }
4194
#'
95+
#' # Plot the first two trees in the model without specifying "rules"
96+
#' lgb.plot.tree(model_bank, tree = 0L:1L)
97+
#'
98+
#' # Plot the first two trees in the model specifying "rules"
99+
#' lgb.plot.tree(model_bank, rules = bank_rules$rules, tree = 0L:1L)
100+
#'
101+
#' }
102+
#' @importFrom data.table := fcoalesce fifelse setnames
103+
#' @importFrom DiagrammeR add_global_graph_attrs create_edge_df create_graph create_node_df render_graph
42104
#' @export
43-
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
105+
lgb.plot.tree <- function(model = NULL,
106+
tree = NULL,
107+
rules = NULL,
108+
render = TRUE,
109+
plot_width = NULL,
110+
plot_height = NULL
111+
) {
44112
# check model is lgb.Booster
45113
if (!.is_Booster(x = model)) {
46114
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
@@ -51,74 +119,81 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
51119
call. = FALSE
52120
)
53121
}
54-
# tree must be numeric
55-
if (!inherits(tree, "numeric")) {
56-
stop("lgb.plot.tree: Has to be an integer numeric")
122+
# tree must be integer or numeric
123+
if (!inherits(tree, c('integer','numeric'))) {
124+
stop(sprintf("lgb.plot.tree: 'tree' must only contain integers."))
57125
}
58-
# tree must be integer
59-
if (tree %% 1 != 0) {
60-
stop("lgb.plot.tree: Has to be an integer numeric")
126+
# all elements of tree must be integers
127+
if (!all(tree %% 1L == 0L)) {
128+
stop(sprintf("lgb.plot.tree: 'tree' must only contain integers."))
61129
}
62130
# extract data.table model structure
63131
modelDT <- lgb.model.dt.tree(model)
64-
# check that tree is less than or equal to the maximum tree index in the model
65-
if (tree > max(modelDT$tree_index) || tree < 1) {
66-
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
67-
stop("lgb.plot.tree: Invalid tree number")
132+
# check that all values of tree are greater than or equal to zero and less than or equal to the maximum tree index in the model
133+
if (!all(tree >= 0L & tree <= max(modelDT$tree_index))) {
134+
stop(
135+
"lgb.plot.tree: All values of 'tree' should be between 0 and the total number of trees in the model minus one ("
136+
, max(modelDT$tree_index)
137+
, ")."
138+
)
68139
}
69-
# filter modelDT to just the rows for the selected tree
70-
modelDT <- modelDT[tree_index == tree, ]
71-
# change the column names to shorter more diagram friendly versions
140+
# filter modelDT to just the rows for the selected trees
141+
modelDT <- modelDT[tree_index %in% tree]
142+
# change some column names to shorter and more diagram friendly versions
72143
data.table::setnames(modelDT
73144
, old = c("tree_index", "split_feature", "threshold", "split_gain")
74-
, new = c("Tree", "Feature", "Split", "Gain"))
75-
# assign leaf_value to the Value column in modelDT
76-
modelDT[, Value := leaf_value]
77-
# assign new values if NA
78-
modelDT[is.na(Value), Value := internal_value]
79-
modelDT[is.na(Gain), Gain := leaf_value]
145+
, new = c("Tree", "Feature", "Split", "Gain")
146+
)
147+
# the output from "lgb.model.dt.tree" follows these rules
148+
# "leaf_value" and "leaf_count" are only populated for leaves (NA for internal splits)
149+
# "internal_value" and "internal_count" are only populated for splits (NA for leaves)
150+
# for the diagram, combine leaf_value and internal_value into a single column called "Value"
151+
modelDT[, Value := data.table::fcoalesce(leaf_value, internal_value)]
152+
# for the diagram, combine leaf_count and internal_count into a single column called "Count"
153+
modelDT[, Count := data.table::fcoalesce(leaf_count, internal_count)]
154+
# "Feature" is only present for splits, it is NA for leaves
155+
# Use the text "Leaf" to denote leaves in the diagram
80156
modelDT[is.na(Feature), Feature := "Leaf"]
81-
# assign internal_count to Cover, and if Feature is "Leaf", assign leaf_count to Cover
82-
modelDT[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
83-
# remove unnecessary columns
84-
modelDT[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
85-
# assign split_index to Node
157+
# within each tree, "Node" holds a unique index for each split and leaf
158+
# for splits, Node = split_index (already populated by lgb.model.dt.tree as an integer)
159+
# for leaves, Node = max(split_index) for that tree, plus the leaf_index plus one
160+
# plus one is needed as leaf_index starts at zero within each tree
86161
modelDT[, Node := split_index]
87-
# find the maximum value of Node, if Node is NA, assign max_node + leaf_index + 1 to Node
88-
max_node <- max(modelDT[["Node"]], na.rm = TRUE)
89-
modelDT[is.na(Node), Node := max_node + leaf_index + 1]
90-
# adding ID column
162+
modelDT[, Node := data.table::fifelse(!is.na(Node), Node, max(Node, na.rm = TRUE) + leaf_index + 1L), by = Tree]
163+
# create an ID column to uniquely identify each Node in the diagram (even if there are multiple trees)
164+
# concatenate Tree and Node, e.g. "0-3" is the node with index 3 in the tree with index 0 (both are zero-based)
91165
modelDT[, ID := paste(Tree, Node, sep = "-")]
92-
# remove unnecessary columns
93-
modelDT[, c("depth", "leaf_index") := NULL]
94166
modelDT[, parent := node_parent][is.na(parent), parent := leaf_parent]
95-
modelDT[, c("node_parent", "leaf_parent", "split_index") := NULL]
96-
# assign the IDs of the matching parent nodes to Yes and No
97-
modelDT[, Yes := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
98-
modelDT <- modelDT[nrow(modelDT):1, ]
99-
modelDT[, No := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
167+
# each split node is parent to two "descendent" nodes
168+
# column "Yes" will hold the ID of the first descendent node
169+
# column "No" will hold the ID of the second descendent node
170+
modelDT[, Yes := ID[match(Node, parent)], by = Tree]
171+
# reverse the order of modelDT
172+
# so the match now finds the second descendent node
173+
modelDT <- modelDT[rev(seq_len(.N))]
174+
modelDT[, No := ID[match(Node, parent)], by = Tree]
100175
# which way do the NA's go (this path will get a thicker arrow)
101-
# for categorical features, NA gets put into the zero group
102-
modelDT[default_left == TRUE, Missing := Yes]
103-
modelDT[default_left == FALSE, Missing := No]
104-
modelDT[.zero_present(Split), Missing := Yes]
105-
# create the label text
176+
modelDT[default_left == "TRUE", Missing := Yes]
177+
modelDT[default_left == "FALSE", Missing := No]
178+
# create the label text for each node
179+
# for splits (not leaves) include the Gain, rounded to 6 decimal places for display
180+
# round the Value to 6 decimal places for display
106181
modelDT[, label := paste0(
107182
Feature
108-
, "\nCover: "
109-
, Cover
110-
, ifelse(Feature == "Leaf", "", "\nGain: "), ifelse(Feature == "Leaf"
111-
, ""
112-
, round(Gain, 4))
183+
, "\nCount: "
184+
, Count
185+
, data.table::fifelse(Feature == "Leaf", "", "\nGain: ")
186+
, data.table::fifelse(Feature == "Leaf", "", as.character(round(Gain, 6L)))
113187
, "\nValue: "
114-
, round(Value, 4)
188+
, round(Value, 6L)
115189
)]
116-
# style the nodes - same format as xgboost
117-
modelDT[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
190+
# ensure the initial split in each tree is correctly labelled
191+
modelDT[Node == 0L, label := paste0("Tree ", Tree, "\n", label)]
192+
# style nodes with rectangles for splits and ovals for leaves
118193
modelDT[, shape := "rectangle"][Feature == "Leaf", shape := "oval"]
194+
# style Nodes with the same colours as xgboost's xgb.plot.trees
119195
modelDT[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"]
120-
# in order to draw the first tree on top:
121-
modelDT <- modelDT[order(-Tree)]
196+
# create the diagram nodes
122197
nodes <- DiagrammeR::create_node_df(
123198
n = nrow(modelDT)
124199
, ID = modelDT$ID
@@ -128,29 +203,43 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
128203
, data = modelDT$Feature
129204
, fontcolor = "black"
130205
)
131-
# round the edge labels to 4 s.f. if they are numeric
132-
# as otherwise get too many decimal places and the diagram looks bad
133-
# would rather not use suppressWarnings
134-
numeric_idx <- suppressWarnings(!is.na(as.numeric(modelDT[["Split"]])))
135-
modelDT[numeric_idx, Split := round(as.numeric(Split), 4)]
136-
# replace indices with feature levels if rules supplied
137-
206+
# The Split column might be numeric or character (e.g. if categorical features are used)
207+
# sometimes numeric <=0 splits are reported as <= 1.00000001800251e-35 or similar by lgb.model.dt.tree
208+
# replace these with "0"
209+
if (is.numeric(modelDT[["Split"]])) {
210+
modelDT[abs(Split) < .Machine$double.eps, Split := 0.0]
211+
}
212+
# for categorical features, LightGBM labels the splits as a single integer or
213+
# several integers separated by "||", e.g. "1" or "2||3||5"
214+
# if "rules" supplied, the integers are replaced by their corresponding factor level
215+
# to make the diagram easier to understand
138216
if (!is.null(rules)) {
139-
for (f in names(rules)) {
140-
modelDT[Feature == f & decision_type == "==", Split := .levels.to.names(Split, f, rules)]
141-
}
217+
for (f in names(rules)) {
218+
modelDT[Feature == f & decision_type == "==", Split := unlist(lapply(
219+
Split,
220+
function(x) paste(names(rules[[f]])[as.numeric(unlist(strsplit(x, "||", fixed = TRUE)))], collapse = "\n")
221+
))]
222+
}
142223
}
143-
# replace long split names with a message
144-
modelDT[nchar(Split) > 500, Split := "Split too long to render"]
145-
# create the edge labels
224+
# replace very long splits with a message as otherwise diagram will be very tall
225+
modelDT[nchar(Split) > 500L, Split := "Split too long to render"]
226+
# create the edges
227+
# define edgesDT to filter out leaf nodes
228+
edgesDT <- modelDT[Feature != "Leaf"]
229+
# create the edge data frame using edgesDT
146230
edges <- DiagrammeR::create_edge_df(
147-
from = match(modelDT[Feature != "Leaf", c(ID)] %>% rep(2), modelDT$ID),
148-
to = match(modelDT[Feature != "Leaf", c(Yes, No)], modelDT$ID),
149-
label = modelDT[Feature != "Leaf", paste(decision_type, Split)] %>%
150-
c(rep("", nrow(modelDT[Feature != "Leaf"]))),
151-
style = modelDT[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")] %>%
152-
c(modelDT[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]),
153-
rel = "leading_to"
231+
from = match(rep(edgesDT[, ID], 2L), modelDT$ID),
232+
to = match(edgesDT[, c(Yes, No)], modelDT$ID),
233+
label = c(
234+
edgesDT[, paste(decision_type, Split)],
235+
rep("", nrow(edgesDT))
236+
),
237+
# make the Missing edge bold
238+
style = c(
239+
edgesDT[, data.table::fifelse(Missing == Yes, "bold", "solid")],
240+
edgesDT[, data.table::fifelse(Missing == No, "bold", "solid")]
241+
),
242+
rel = "leading_to"
154243
)
155244
# create the graph
156245
graph <- DiagrammeR::create_graph(
@@ -176,29 +265,11 @@ lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
176265
, attr = c("color", "arrowsize", "arrowhead", "fontname")
177266
, value = c("DimGray", "1.5", "vee", "Helvetica")
178267
)
179-
# render the graph
180-
DiagrammeR::render_graph(graph)
181-
return(invisible(NULL))
182-
}
183-
184-
.zero_present <- function(x) {
185-
sapply(strsplit(as.character(x), "||", fixed = TRUE), function(el) {
186-
any(el == "0")
187-
})
188-
return(invisible(NULL))
189-
}
190-
191-
.levels.to.names <- function(x, feature_name, rules) {
192-
lvls <- sort(rules[[feature_name]])
193-
result <- strsplit(x, "||", fixed = TRUE)
194-
result <- lapply(result, as.numeric)
195-
result <- lapply(result, .levels_to_names)
196-
result <- lapply(result, paste, collapse = "\n")
197-
result <- as.character(result)
198-
return(invisible(NULL))
268+
# if 'render' is FALSE, return the graph object invisibly (without printing it)
269+
if (!render) {
270+
return(invisible(graph))
271+
} else {
272+
# if 'render' is TRUE, display the graph with specified width and height
273+
DiagrammeR::render_graph(graph, width = plot_width, height = plot_height)
274+
}
199275
}
200-
201-
.levels_to_names <- function(x) {
202-
names(lvls)[as.numeric(x)]
203-
return(invisible(NULL))
204-
}

0 commit comments

Comments
 (0)