Skip to content

Commit 33c78f7

Browse files
committed
Adressing [PR review](microsoft#6729 (review)).
- `model` and `tree` parameters made mandatory - fixed linting - combined tree type tests
1 parent ed84ad7 commit 33c78f7

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

R-package/R/lgb.plot.tree.R

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,12 @@
102102
#' @importFrom data.table := fcoalesce fifelse setnames
103103
#' @importFrom DiagrammeR add_global_graph_attrs create_edge_df create_graph create_node_df render_graph
104104
#' @export
105-
lgb.plot.tree <- function(model = NULL,
106-
tree = NULL,
107-
rules = NULL,
108-
render = TRUE,
109-
plot_width = NULL,
110-
plot_height = NULL
105+
lgb.plot.tree <- function(model
106+
, tree
107+
, rules = NULL
108+
, render = TRUE
109+
, plot_width = NULL
110+
, plot_height = NULL
111111
) {
112112
# check model is lgb.Booster
113113
if (!.is_Booster(x = model)) {
@@ -119,12 +119,8 @@ lgb.plot.tree <- function(model = NULL,
119119
call. = FALSE
120120
)
121121
}
122-
# tree must be integer or numeric
123-
if (!inherits(tree, c('integer','numeric'))) {
124-
stop(sprintf("lgb.plot.tree: 'tree' must only contain integers."))
125-
}
126122
# all elements of tree must be integers
127-
if (!all(tree %% 1L == 0L)) {
123+
if (!inherits(tree, c("integer", "numeric")) || !all(tree %% 1L == 0L)) {
128124
stop(sprintf("lgb.plot.tree: 'tree' must only contain integers."))
129125
}
130126
# extract data.table model structure
@@ -140,9 +136,10 @@ lgb.plot.tree <- function(model = NULL,
140136
# filter modelDT to just the rows for the selected trees
141137
modelDT <- modelDT[tree_index %in% tree]
142138
# change some column names to shorter and more diagram friendly versions
143-
data.table::setnames(modelDT
144-
, old = c("tree_index", "split_feature", "threshold", "split_gain")
145-
, new = c("Tree", "Feature", "Split", "Gain")
139+
data.table::setnames(
140+
modelDT
141+
, old = c("tree_index", "split_feature", "threshold", "split_gain")
142+
, new = c("Tree", "Feature", "Split", "Gain")
146143
)
147144
# the output from "lgb.model.dt.tree" follows these rules
148145
# "leaf_value" and "leaf_count" are only populated for leaves (NA for internal splits)

R-package/tests/testthat/test_lgb.plot.tree.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ for (model_name in names(models)){
2727
test_that("lgb.plot.tree fails when a non existing tree is selected", {
2828
expect_error({
2929
lgb.plot.tree(model, -1L)
30-
}, regexp = paste0("lgb.plot.tree: All values of 'tree' should be between 0 and the total number of trees in the model minus one"))
30+
}, regexp = paste0(
31+
"lgb.plot.tree: All values of 'tree' should be between 0 and the total number of trees in the model minus one"))
3132
})
3233
test_that("lgb.plot.tree fails when a non existing tree is selected", {
3334
expect_error({
3435
lgb.plot.tree(model, 999L)
35-
}, regexp = paste0("lgb.plot.tree: All values of 'tree' should be between 0 and the total number of trees in the model minus one"))
36+
}, regexp = paste0(
37+
"lgb.plot.tree: All values of 'tree' should be between 0 and the total number of trees in the model minus one"))
3638
})
3739
test_that("lgb.plot.tree fails when a non numeric tree is selected", {
3840
expect_error({

0 commit comments

Comments
 (0)