Classification and Regression Trees (CART) with rpart
This tutorial uses data from the Dominant Trees of California dataset.
Loading the Libraries
Tree models are the most natural method for modeling categorical response variables (e.g. soil type, land cover type, etc.). They can also be used to model continuous responses, but be careful of overfitting.
I had some errors when installing caret and found out you may need to remove the cli library and then reinstall it to get the latest version. The code below will do this and only needs to be executed once.
# Note: I had to remove an old cli package to have R install the new one that was
# required by caret. Otherwise, confusionMatrix() was undefined.
remove.packages("cli")
install.packages("cli")
CART requires additional libraries to be loaded into R, as shown below. You only need to execute this code once in RStudio. Also, there is a library, rpart.plot, that provides much better looking trees than the standard plot() function. There is a section on this library at the bottom of the page.
library(rpart)
library(caret)
library(rpart.plot)
Setting the Working Directory
For this tutorial, I have also set the working directory at the top of the script. You'll want to change this to point to the folder with your data.
setwd("C:\\Users\\jim\\Desktop\\GSP 570\\Lab 6 CART\\")
Dominant Trees Example Data
The following examples were created using the dominant trees data we extracted from the FIA database.
# Read the data just using the file name
TheData = read.csv("DominantTrees_CA_Temp_Precip.csv")
# Remove any NA (null) values
TheData = na.omit(TheData)
The example data includes the following columns (a quick check of the loaded data follows the list):
- CommonName: Common name of the dominant tree species in the FIA plot
- Genus: Genus of the dominant tree species
- Species: Species of the dominant tree species
- AnnualPrecip: Annual precipitation values
- AnnualTemp: Annual temperature values
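To confirm the file loaded as expected, it can help to inspect the data before modeling. A minimal check using base R functions, assuming the column names above:
str(TheData) # column names and data types
table(TheData$CommonName) # number of plots for each dominant species
summary(TheData$AnnualPrecip) # range of the precipitation values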
Creating and Plotting Trees
Creating a tree with rpart(...) is similar to the other modeling functions we have used; we just need to specify "class", for classification trees, as the method.
# create a classification tree with common names
TheTree=rpart(CommonName~AnnualPrecip+AnnualTemp,data=TheData,method="class")
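For comparison, a regression tree for a continuous response (mentioned at the top of this tutorial) only requires changing the method to "anova". A minimal sketch, using AnnualTemp as an example response:
# create a regression tree for a continuous response
RegTree = rpart(AnnualTemp~AnnualPrecip, data=TheData, method="anova")
printcp(RegTree) # watch the xerror column to guard against overfitting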
The standard plot(...) function does not produce good-looking trees by default, so we'll use rpart.plot(...) to create a figure of the tree. The code below will create a nice, simple, and relatively readable tree. There are many options to change in rpart.plot(...). Take a look at Plotting rpart trees with the rpart.plot package for more options.
# regular plots do not look good so lets use rpart.plot()
rpart.plot(TheTree,type=1,extra=0,box.palette=0) # plot the tree
If you print the tree, you will see a rather complicated version of the tree that includes all of the values that the tree was built with. Below is the result of calling print(...) for a small set of data with just 39 values.
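The call itself is just:
print(TheTree) # print the full tree structure as text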
n= 39

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 39 34 Douglas-fir (0.026 0.077 0.026 0.077 0.1 0.13 0.026 0.077 0.077 0.026 0.1 0.051 0.026 0.026 0.026 0.13)
  2) AnnualTemp>=133.5 11 8 blue oak (0 0.27 0 0.27 0.091 0.091 0 0.27 0 0 0 0 0 0 0 0) *
  3) AnnualTemp< 133.5 28 23 white fir (0.036 0 0.036 0 0.11 0.14 0.036 0 0.11 0.036 0.14 0.071 0.036 0.036 0.036 0.18)
    6) AnnualTemp>=83 14 10 Douglas-fir (0 0 0 0 0.071 0.29 0.071 0 0.071 0 0.21 0.14 0.071 0.071 0 0) *
    7) AnnualTemp< 83 14 9 white fir (0.071 0 0.071 0 0.14 0 0 0 0.14 0.071 0.071 0 0 0 0.071 0.36) *
Here you can see each of the nodes of the tree, the condition that was used to split the values into each branch, the number of values that ended up in each branch, the number of values that were misclassified, and the predicted value. As the number of values in the tree becomes large, these printouts become very hard to interpret.
The printcp(...) function provides a small table like the one below that is quite helpful.
printcp(TheTree) # print the complexity parameter (pruning) table
Root node error: 6855/7715 = 0.88853

n= 7715

        CP nsplit rel error  xerror      xstd
1 0.077899      0   1.00000 1.00000 0.0040325
2 0.023924      2   0.84420 0.84581 0.0055370
3 0.020000      3   0.82028 0.82217 0.0056851
The "Root node error" shows us the number of values that were correctly classified at the first split (branch) divided by the total number of values to give us a proportion of the values that were correctly classified. "n" is just the total number of values.
The table contains the following values.
- CP - complexity parameter
- nsplit - number of splits at each level
- rel error - 1 - R^2
- xerror - cross-validation error
- xstd - standard deviation of the cross-validation
The CP parameter is important because it ties the tree's complexity to the amount of error, and we'll use it to "prune" trees, that is, reduce their complexity. The "rel error" is 1 minus the R-squared value and can be used to see how much of the variability in the data the model explains.
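One way to prune is rpart's prune(...) function, which takes a cp value and removes any splits that do not improve the fit by at least that amount. A minimal sketch that selects the row of the model's cptable with the lowest cross-validation error:
# find the CP value with the lowest cross-validation error (xerror)
BestCP = TheTree$cptable[which.min(TheTree$cptable[,"xerror"]), "CP"]
PrunedTree = prune(TheTree, cp=BestCP) # prune the tree back to that complexity
rpart.plot(PrunedTree, type=1, extra=0, box.palette=0) # plot the pruned tree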
Using Factors
When we create trees using values that are text (e.g. common names), R will convert the text into factors. A factor matches a number to each unique text string, and the model actually uses the numbers internally. The problem is that when we do a prediction, we will get the number back instead of our string. To address this, we want to be able to model based on the integer values of the factors and find out which number matches which text string.
The code below will create a column with factors and then write out the factors to a CSV file. If you examine the CSV file, you'll see there is a number for each text string.
NameFactors = factor(TheData$CommonName) # Convert the common names to factors
TheData = data.frame(TheData, NameFactors) # Add a column with the factors
# Save a table that has the numbers and their factors
write.csv(levels(NameFactors), "Factors.csv")
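If you would rather check the matching without leaving R, the levels and their integer codes can be printed directly; a quick sketch:
levels(NameFactors) # the text strings, in factor order (codes 1, 2, 3, ...)
head(as.integer(NameFactors)) # the integer codes for the first few rows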
From here, we can use the factors for modeling and the results should be the same.
# run the model using factors as the response
TheTree = rpart(NameFactors~AnnualPrecip+AnnualTemp, data=TheData, method="class")
printcp(TheTree) # print the complexity parameter table
rpart.plot(TheTree, type=1, extra=0, box.palette=0) # plot the tree
Controlling Complexity
Each time we make a split we increase the fit of the model, but we also increase the tree's complexity, so we want to control that complexity. We can do this with the rpart.control() function, which can include a variable for the minimum number of values for a split to occur (minsplit) and the complexity parameter (cp). Typically we will allow for more complex trees by decreasing the cp value (e.g. try 0.002).
- minsplit - minimum number of data points in a node before a split is tried
- cp - complexity parameter
# create a control object to change the model parameters
TheControl = rpart.control(minsplit=5, cp=0.02)
# run the model with the parameters
TheTree = rpart(NameFactors~AnnualPrecip+AnnualTemp, data=TheData, method="class", control=TheControl)
printcp(TheTree) # print the complexity parameter table for the new tree
rpart.plot(TheTree, type=1, extra=0, box.palette=0) # plot the tree
Evaluating Trees
The diagrams of the trees are one of the best tools for evaluating them. Another tool is a "confusion matrix". This is a table with the predicted values along one side and the observed response values along the other; each entry shows the number of values with a given observed class that received a given prediction. For a perfect tree, the only counts greater than 0 would be along the diagonal. You can view these tables in R but I recommend using MS-Excel.
# the type parameter specifies we want a classification (categories) rather than continuous
ThePrediction = predict(TheTree, newdata=TheData, type="class")
# create the table showing correct and incorrect matches
TheResults = table(ThePrediction, TheData$CommonName)
# save the table to view in MS-Excel
write.csv(TheResults, file="DominantTrees_ConfusionMatrix.csv")
confusionMatrix(TheResults) # print the confusion matrix and summary statistics
You'll see a variety of measures for each response value in the output. See this Wikipedia page for information on the different measures.
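If you want to pull individual measures out in R, confusionMatrix(...) returns an object whose pieces can be accessed directly. A small sketch (TheCM is just an illustrative name):
TheCM = confusionMatrix(TheResults)
TheCM$overall["Accuracy"] # overall proportion of correct predictions
TheCM$overall["Kappa"] # agreement beyond what chance alone would produce
TheCM$byClass # per-class measures such as sensitivity and specificity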
Predicting into a Study Area to Create a Map
The code above shows that it is easy to create a prediction for the original data. However, typically we'll want to predict into a new dataset that contains a grid of points that we can convert to a raster in a GIS application. The code below shows how to do this by changing the "newdata" parameter when we do a prediction. Note that the names of the columns in NewData must exactly match the columns in the dataset that was used to create the original model for this to work.
NewData = read.csv("DominantTrees_StudySite.csv")
# the type parameter specifies we want a classification (categories) rather than continuous
ThePrediction = predict(TheTree, newdata=NewData, type="class")
# convert the prediction from strings to integers
IntPrediction = as.integer(ThePrediction)
# Create a new data frame with the string and int versions of the prediction
FinalData = data.frame(NewData, ThePrediction)
FinalData = data.frame(FinalData, IntPrediction)
# Save the points for the entire site with the prediction to a CSV
write.csv(FinalData, "DominantTrees_EntireSite_WithPrediction.csv")
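As a sanity check, the integer codes can be mapped back to the species names because each code indexes into the prediction's factor levels:
head(levels(ThePrediction)[IntPrediction]) # first few codes converted back to names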
Variable Importance
We can obtain the relative variable importance and then create a bar chart with the importance as in the code below.
# variable.importance is a named vector, so the bars are labeled with the variable names
barplot(t(TheTree$variable.importance), horiz=TRUE)
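If the bar labels are hard to read, printing the raw values and keeping the axis labels horizontal can help; one variation:
print(TheTree$variable.importance) # the importance values, largest first
barplot(TheTree$variable.importance, horiz=TRUE, las=1) # las=1 keeps the labels horizontal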
See: https://freakonometrics.hypotheses.org/tag/rpart
rpart.plot
rpart.plot provides tree plots that are typically better looking and allow for more customization than the standard plot() function. We need to install and include the library rpart.plot and then we can call rpart.plot() to display the tree.
library(rpart.plot)
# ... create the CART model ...
rpart.plot(TheTree)
The tree may not look great at first but there is hope. Try adding the following parameters to the rpart.plot(...) function call (a sample call follows the list):
- cex - font size for the labels (try values from 0.2 to 10)
- type - changes the layout of the information at each branch of the tree. Try values 0 to 6
- extra - Controls the amount of information at each node in the tree. Try values 0 to 11, 100, and 106.
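One combination to try as a starting point (adjust the values to suit your own tree):
# show more information at each node with slightly smaller text
rpart.plot(TheTree, type=3, extra=100, cex=0.8)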
For more information:
- Plotting rpart trees with the rpart.plot package - detailed and readable documentation on rpart.plot.
- rpart.plot - standard R documentation