22 января 2017

Особенности работы с функцией train() из пакета caret



Автор: Владимир Шитиков

Как обсуждалось нами ранее, пакет caret (сокращение от Classification and Regression Training) был разработан как эффективная надстройка, позволяющая унифицировать и интегрировать использование множества различных функций и методов построения предсказательных моделей, реализованных в других пакетах R. При этом происходит всестороннее тестирование и оптимизация настраиваемых параметров и гиперпараметров (tuning parameters) моделей. Разработанная единая технология настройки моделей основана на использовании полуавтоматических интеллектуальных подходов и ряда широко применяемых критериев качества, рассчитываемых с применением алгоритмов создания повторных выборок (resampling).



Процесс поиска оптимальных значений параметров моделей с использованием функции train() в общем виде реализуется по следующей схеме:


1.   Определение наборов данных и их предварительная обработка (при необходимости)
2.   Определение спецификации параметров модели
3.   Цикл для каждого параметра модели:
4.      | Итеративная оптимизация параметров:
5.      |     | Выделение и предобработка подвыборок для обучения и тестирования модели
6.      |     | Подгонка модели по обучающим объектам
7.      |     | Прогнозирование отклика для тестовых объектов
8.      |   end
9.      | Вычисление показателей средней эффективности прогноза
10.  end
11. Установление оптимальных параметров
12. Подгонка итоговой модели по всей выборке с использованием оптимальных параметров


Перед обучением модели при помощи функции train() необходимо задать соответствующий алгоритм и всю совокупность условий процесса оптимизации, для чего функцией trainControl() создается специальный объект. Вызов этой функции с исходными настройками имеет следующий вид:

trainControl(method = "boot",
             number = ifelse(grepl("cv", method), 10, 25),
             p = 0.75,
             repeats = ifelse(grepl("cv", method), 1, number),
             search = "grid", 
             initialWindow = NULL, 
             horizon = 1,
             fixedWindow = TRUE, 
             verboseIter = FALSE, 
             returnData = TRUE,
             returnResamp = "final", 
             savePredictions = FALSE,
             classProbs = FALSE, 
             summaryFunction = defaultSummary,
             selectionFunction = "best", 
             seeds = NA,
             preProcOptions = list(thresh = 0.95, 
                                   ICAcomp = 3, 
                                   k = 5),
             sampling = NULL, 
             index = NULL, 
             indexOut = NULL,
             timingSamps = 0, 
             predictionBounds = rep(FALSE, 2),
             adaptive = list(min = 5, 
                             alpha = 0.05,
                             method = "gls", 
                             complete = TRUE), 
             trim = FALSE, 
             allowParallel = TRUE)


Остановимся на описании наиболее важных аргументов этой функции:
  • method - метод создания повторных выборок: "boot", "boot632", "cv", "repeatedcv", "LOOCV", "LGOCV" (для повторяющихся разбиений на обучающую и контрольную выборки), "none" (проверка качества модели выполняется только на обучающей выборке), "oob" (для таких алгоритмов, как случайные леса, бэггинг и др.), "adaptive_cv", "adaptive_boot" или "adaptive_LGOCV";
  • number - задает число итераций при создании повторных выборок, в частности - количество таких выборок (k-folds) при перекрестной проверке;
  • repeats - число повторностей для выполнения k-кратной перекрестной проверки;
  • p - доля обучающей выборки от общего объема данных (при выполнении k-кратной перекрестной проверки);
  • verboseIter - TRUE означает, что пользователь будет получать сообщения о ходе вычислений (это удобно для оценки оставшегося времени вычислений, которое зачастую бывает очень большим);
  • search - задает способ перебора параметров модели при ее настройке - по предварительно заданной сетке ("grid") или случайным образом ("random");
  • returnResamp и savePredictions - определяют условия сохранения результатов вычислений и предсказанных значений (возможные варианты: без сохранения - "none", сохранение только для итоговой модели - "final", или сохранение всех результатов - "all");
  • classProbs - при выполнении классификации TRUE означает, что в процессе вычислений будут сохраняться не только конечные метки предсказанного класса, но и значения вероятности принадлежности того или иного наблюдения к каждому из имеющихся классов;
  • summaryFunction - определяет функцию, которая вычисляет сводную метрику качества модели на основе повторных выборок;
  • selectionFunction - определяет функцию выбора оптимального значения настраиваемого параметра;
  • preProcOptions - список опций, который передается на функцию предварительной обработки данных preprocess().
Например, при создании объекта ctrl

ctrl <- trainControl(method = "repeatedcv", number = 10, repeats = 10)

параметры перекрестной проверки будут иметь следующий смысл:
  • method = "repeatedcv" означает, что необходимо выполнить повторную перекрестную проверку (также возможна перекрестная проверка без повторов, проверка по одному наблюдению и др.);
  • number = 10 означает, что в процессе перекрестной проверки исходные данные необходимо разбить на 10 (примерно) равных частей.
  • repeats = 10 означает, что перекрестная проверка будет запущена 10 раз.
Теперь перейдем непосредственно к описанию функции train(), которая имеет следующий формат вызова:


train(x, y, 
      rf",preProcess = NULL,
      weights = NULL,
      metric = ifelse(is.factor(y), "Accuracy", "RMSE"),
      maximize = ifelse(metric %in% 
                        c("RMSE", "logLoss"), FALSE, TRUE),

      trControl = trainControl(), tuneGrid = NULL, tuneLength = 3)


Как всегда, исходные данные задаются либо матрицей предикторов и вектором со значениями отклика, либо объектом formula с одновременным указанием таблицы данных data. Следующий аргумент - method - в сущности, определяет модель классификации или регрессии, которую необходимо построить и протестировать. Если выполнить команду names(getModelInfo()), то можно увидеть список из 233 доступных методов (количество этих методов постоянно растет). Тот же список, но с различными возможностями поиска и сортировки можно найти по следующим ссылкам:


Пакет caret предоставляет пользователю унифицированный интерфейс доступа к пакетам, содержащим все эти функции. Необходимые компоненты автоматически подгружаются по мере их использования (предполагается, что соответствующие пакеты уже инсталлированы).

Например, следующим образом можно ознакомиться со списком всех моделей, имеющих отношение к линейной регрессии:

ls(getModelInfo(model = "lm"))
[1] "bayesglm"   "elm"        "glm"        "glmboost"   "glmnet"
[6] "glmStepAIC" "lm"         "lmStepAIC"  "plsRglm"    "rlm"

С каждым методом связан набор подлежащих оптимизации (гипер-)параметров. Например, легко убедиться в том, что простая линейная регрессия (method = "lm") не имеет таких параметров:

modelLookup("lm")
   model parameter     label forReg forClass probModel
 1    lm parameter parameter   TRUE    FALSE     FALSE

В свою очередь модель rpart (мы познакомимся с ней подробнее в следующем сообщении, в котором продолжим рассмотрение деревьев решений), имеет один параметр - Complexity Parameter (cp в сокращенном виде):

modelLookup("rpart")
   model parameter                label forReg forClass probModel
 1 rpart        cp Complexity Parameter   TRUE     TRUE      TRUE

Из вывода функции modelLookup() можно также увидеть, что линейная регрессия не используется для классификации (forClass = FALSE), тогда как rpart (от "Recursive Partitioning and Regression Trees") можно применять для построения деревьев как регрессии, так и классификации. В последнем случае модель осуществляет не только предсказание класса, но и оценивает апостериорные вероятности (probModel = TRUE).

Метод перекрестной проверки, заданный объектом trControl = trainControl(), хранит список опций, используемых на каждой итерации в ходе настройки параметров модели и оценки ее качества по определенным критериям. При построении каждой частной модели может осуществляться предварительная обработка данных с использованием методов, перечисленных в preProcess (и с учетом опций preProcOptions объекта trControl).

По умолчанию аргумент metric использует в качестве критерия качества точность предсказания ("Accuracy") в случае классификации и корень из среднеквадратичного отклонения прогнозных значений от наблюдаемых ("RMSE") в случае регрессии. Логический аргумент maximize уточняет, должен ли этот критерий быть максимизирован или минимизирован. Другие значения metric в совокупности с различными аргументами summaryFunction и selectionFunction объекта trControl обеспечивают широкие возможности для определения критериев поиска оптимальных моделей.

Количество перебираемых значений того или иного настраиваемого параметра задается аргументом tuneLength. Например, чтобы задать 30 повторов оценки параметра ср модели rpart, необходимо указать tuneLength = 30. Другой вариант - сохранить последовательность этих значений в отдельной таблице данных и затем подать эту таблицу на аргумент tuneGrid (например, tuneGrid = expand.grid(.cp = 0.5^(1:10))). Последний подход особенно полезен, когда диапазон возможных значений настраиваемого параметра известен заранее.

В результате спецификации всех положенных аргументов функции train() пользователь получает объект класса train, соответствующие элементы которого можно извлечь с помощью оператора $:

ls(mytrain)
[1]  "bestTune"     "call"         "coefnames"    "control"     
[5]  "dots"         "finalModel"   "maximize"     "method"      
[9]  "metric"       "modelInfo"    "modelType"    "perfNames"   
[13] "pred"         "preProcess"   "resample"     "resampledCM" 
[17] "results"      "terms"        "times"        "trainingData"
[21] "xlevels"      "yLimits"

Приведем краткий пример, иллюстрирующий нахождение оптимальной степени полинома для модели зависимости электрического сопротивления (Ом) мякоти фруктов киви от процентного содержания в ней сока (эти данные были подробно рассмотрены нами ранее). В одном из предыдущих сообщений мы выяснили, как найти оптимальную степень полинома d = 5 с использованием самостоятельно написанной функции скользящего контроля. К сожалению, функция train() не позволяет выполнить отбор предикторов для метода "lm" и тем самым выбрать оптимальную степень полинома. Однако мы можем выполнить любую из доступных типов перекрестной проверки и найти оптимальную степень полинома опосредованно, оценив характер изменения таких критериев качества, как например, RMSE или среднего коэффициента детерминации RSquared:

library(caret)
library(DAAG) 
data("fruitohms")

set.seed(123)
max.poly <- 7
degree <- 1:max.poly
RSquared <- rep(0, max.poly)
RMSE <- rep(0, max.poly)

# Выполним 10-кратную перекрестную проверку 10 раз
fitControl <- trainControl(method = "repeatedcv",
                           number = 10, repeats = 10)

# Использование функции train() для полиномиальной регрессии
for (d in degree)  {
     
     f <- bquote(juice ~ poly(ohms, .(d)))
     LinearRegressor <- train(as.formula(f),
     data = fruitohms,
     method = "lm", 
     trControl = fitControl)
     
     RSquared[d] <- LinearRegressor$results$Rsquared
     RMSE[d]<- LinearRegressor$results$RMSE
}

library(ggplot2)
Degree.RegParams = data.frame(degree, RSquared, RMSE)
ggplot(aes(x = degree, y = RSquared), data = Degree.RegParams) + geom_line()
ggplot(aes(x = degree, y = RMSE), data = Degree.RegParams) + geom_line()



Как и ранее, минимум ошибки и максимум коэффициента детерминации имеют место при d = 5. Выполним проверку качества итоговой модели, не указывая непосредственно объект trControl:

Poly5 <- train(juice ~ poly(ohms, 5), data = fruitohms, method = "lm")

summary(Poly5$finalModel)
 Coefficients:
                  Estimate Std. Error t value Pr(>|t|)    
 (Intercept)        35.152      0.766  45.893  < 2e-16 ***
 `poly(ohms, 5)1` -148.943      8.666 -17.187  < 2e-16 ***
 `poly(ohms, 5)2`   -5.540      8.666  -0.639  0.52383    
 `poly(ohms, 5)3`   51.078      8.666   5.894 3.43e-08 ***
 `poly(ohms, 5)4`  -13.905      8.666  -1.605  0.11118    
 `poly(ohms, 5)5`  -23.528      8.666  -2.715  0.00759 ** 
 ---
 Signif. codes:  0 С***Т 0.001 С**Т 0.01 С*Т 0.05 С.Т 0.1 С Т 1 
 Residual standard error: 8.666 on 122 degrees of freedom
 Multiple R-squared: 0.7362,     Adjusted R-squared: 0.7254 
 F-statistic:  68.1 on 5 and 122 DF,  p-value: < 2.2e-16 

Poly5
 Resampling: Bootstrap (25 reps) 
   RMSE  Rsquared  RMSE SD  Rsquared SD
   10    0.663     1.87     0.0745   

Мы получили в точности те же коэффициенты модели, что и при использовании обычной функции lm(), однако для критериев качества RMSE и RSquared здесь были найдены стандартные ошибки. Эти статистики оценивались на основе 25 бутстреп-выборок (Шитиков, Розенберг 2014), формируемых функцией train() по умолчанию. Обратите внимание, что несмещенное бутстреп-значение коэффициента детерминации (0.663) несколько меньше полученного по полной модели (0.7362).

В заключение приведем сокращенную таблицу со списком моделей (Khun 2008), доступных в пакете caret при использовании функции train() (указаны также наименования оптимизируемых параметров этих моделей). Эта таблица полезна также тем, что является своеобразным путеводителем по пакетам R и реализованным в них статистическим методам:


Модели Значение method Пакет Оптимизируемые параметры

Деревья на основе рекурсивного деления (recursive partitioning)

rpart

rpart

maxdepth
ctree
party
mincriterion

Деревья на основе бустинга (boosted trees)

gbm

gbm

interaction.depth, n.trees, shrinkage
blackboost
mboost
maxdepth, mstop
ada
ada
maxdepth, iter, nu

Другие модели на основе бустинга

glmboost

mboost

mstop
gamboost
mboost
mstop
logitboost
caTools
nIter

Случайные леса (random forests)

rf

randomForest

mtry
cforest
party
mtry

Деревья на основе бэггинга (bagged trees)

treebag

ipred


Нейронные сети (neural networks)

nnet

nnet

decay, size

Модели на основе частных наименьших квадратов (partial least squares)

pls

pls, caret

ncomp

Машины опорных векторов с RBF ядром (support vector machines, RBF kernel)

svmRadial

kernlab

sigma, C

Машины опорных векторов с полиномиальным ядром (support vector machines, polynomial kernel)

svmPoly

kernlab

scale, degree, C

Гауссовы процессы с RBF ядром (Gaussian processes, RBF kernel)

gaussprRadial

kernlab

sigma

Гауссовы процессы с полиномиальным ядром (Gaussian processes, polynomial kernel)

gaussprPoly

kernlab

scale, degree

Линейные модели наименьших квадратов (linear least squares)

lm

stats


Многомерные адаптивные регрессионные сплайны (multivariate adaptive regression splines, MARS)

earth, mars

earth

degree, nprune

MARS на основе бэггинга (bagged MARS)

bagEarth

caret, earth

degree, nprune

Эластичные сети (elastic net)

enet

elasticnet

lambda, fraction

Модели лассо (lasso)

lasso

elasticnet

fraction

Машины релевантных векторов с RBF ядром (relevance vector machines, RBF kernel)

rvmRadial

kernlab

sigma

Машины релевантных векторов с полиномиальным ядром (relevance vector machines, polynomial kernel)

rvmPoly

kernlab

scale, degree

Линейный дискриминантный анализ (linear discriminant analysis)

lda

MASS


Пошаговый диагональный дискриминантный анализ (stepwise diagonal discriminant analysis)

sddaLDA, sddaQDA

SDDA


Логистическая регрессия для двух или более классов (logistic/multinomial regression)

multinom

nnet

decay

Регуляризованный дискриминантный анализ (regularized discriminant analysis)

rda

klaR

lambda, gamma

Гибкий дискриминантный анализ (flexible discriminant analysis, FDA)

fda

mda, earth

degree, nprune

FDA на основе бэггинга (Bagged FDA)

bagFDA*

caretearth

degree, nprune

Машины опорных векторов на основе метода наименьших квадратов (least squares support vector machines)

ssvmRadial

kernlab

sigma

Метод k-ближайших соседей (k nearest neighbours)

knnЗ

caret

k

Разделение по центроидам (nearest shrunken centroids)

pam

pamr

threshold

Наивный байесовский классификатор (naive Bayes classifier)

nb

klaR

usekernel

Обобщенный метод частных наименьших квадратов (generalized partial least squares)

gpls

gpls

K.prov

Сети с квантованием обучающего вектора (learned vector quantization)

lvq

class

k



В будущем мы планируем описать большинство из перечисленных методов и продемонстрировать процесс оптимизации их параметров с использованием функции train().

Использованные источники:
  • Kuhn M. (2008) Building Predictive Models in R Using the caret Package. Journal of Statistical Software 5:113-142
  • Шитиков В. К., Розенберг Г. С. (2014) Рандомизация и бутстреп: статистический анализ в биологии и экологии с использованием R. Тольятти: Кассандра, 314 с. (PDF, данные и скрипты доступны на сайте авторов ievbras.ru/ecostat)

Комментариев нет :

Отправить комментарий