From c0ba8aa8681c247f50135bee467693ef9e53b31b Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 09:42:20 -0600 Subject: [PATCH 01/30] adding auc from ModelMetrics --- pkg/caret/NAMESPACE | 1 + pkg/caret/R/filterVarImp.R | 47 +++++++++++++++++++++++----------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/pkg/caret/NAMESPACE b/pkg/caret/NAMESPACE index be73a939..606068fc 100644 --- a/pkg/caret/NAMESPACE +++ b/pkg/caret/NAMESPACE @@ -1,4 +1,5 @@ useDynLib(caret) +importFrom(ModelMetrics, auc) import(foreach, methods, plyr, reshape2, ggplot2, lattice, nlme) importFrom(car, powerTransform, yjPower) importFrom(grDevices, extendrange) diff --git a/pkg/caret/R/filterVarImp.R b/pkg/caret/R/filterVarImp.R index 55dc2aa4..44995248 100644 --- a/pkg/caret/R/filterVarImp.R +++ b/pkg/caret/R/filterVarImp.R @@ -7,30 +7,30 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) if(any(notNumber)) { for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } + } } if(is.factor(y)) { classLevels <- levels(y) - + outStat <- matrix(NA, nrow = dim(x)[2], ncol = length(classLevels)) for(i in seq(along = classLevels)) { otherLevels <- classLevels[classLevels != classLevels[i]] - + for(k in seq(along = otherLevels)) { tmpSubset <- as.character(y) %in% c(classLevels[i], otherLevels[k]) tmpY <- factor(as.character(y)[tmpSubset]) - tmpX <- x[tmpSubset,] - + tmpX <- x[tmpSubset,] + rocAuc <- apply( - tmpX, - 2, + tmpX, + 2, function(x, class, pos) { - isMissing <- is.na(x) | is.na(class) + isMissing <- is.na(x) | is.na(class) if(any(isMissing)) { x <- x[!isMissing] @@ -40,16 +40,16 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) else roc(x, class = class, dataGrid = FALSE, positive = pos) aucRoc(outResults) }, - class = tmpY, + class = tmpY, pos = classLevels[i]) - outStat[, i] <- pmax(outStat[, i], rocAuc, na.rm = TRUE) + outStat[, i] <- pmax(outStat[, i], rocAuc, na.rm = TRUE) } if(i ==1 & length(classLevels) == 2) { outStat[, 2] <- outStat[, 1] break() - } - } + } + } colnames(outStat) <- classLevels rownames(outStat) <- dimnames(x)[[2]] outStat <- data.frame(outStat) @@ -60,18 +60,18 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) { meanMod <- sum((y - mean(y, rm.na = TRUE))^2) nzv <- nearZeroVar(x, saveMetrics = TRUE) - + if(nzv$zeroVar) return(NA) if(nzv$percentUnique < 20) { regMod <- lm(y~x, na.action = na.omit, ...) } else { regMod <- try(loess(y~x, na.action = na.omit, ...), silent = TRUE) - + if(class(regMod) == "try-error" | any(is.nan(regMod$residuals))) try(regMod <- lm(y~x, ...)) if(class(regMod) == "try-error") return(NA) } - + pR2 <- 1 - (sum(resid(regMod)^2)/meanMod) if(pR2 < 0) pR2 <- 0 pR2 @@ -79,7 +79,7 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) testFunc <- if(nonpara) nonparaFoo else paraFoo - outStat <- apply(x, 2, testFunc, y = y) + outStat <- apply(x, 2, testFunc, y = y) outStat <- data.frame(Overall = outStat) } outStat @@ -87,8 +87,7 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) rocPerCol <- function(dat, cls) { - loadNamespace("pROC") - pROC::roc(cls, dat, direction = "<")$auc + 1 - ModelMetrics::auc(cls, dat) } filterVarImp <- function(x, y, nonpara = FALSE, ...) @@ -98,14 +97,14 @@ filterVarImp <- function(x, y, nonpara = FALSE, ...) if(any(notNumber)) { for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } + } } if(is.factor(y)) { classLevels <- levels(y) k <- length(classLevels) - + if(k > 2) { counter <- 1 @@ -145,18 +144,18 @@ filterVarImp <- function(x, y, nonpara = FALSE, ...) { meanMod <- sum((y - mean(y, rm.na = TRUE))^2) nzv <- nearZeroVar(x, saveMetrics = TRUE) - + if(nzv$zeroVar) return(NA) if(nzv$percentUnique < 20) { regMod <- lm(y~x, na.action = na.omit, ...) } else { regMod <- try(loess(y~x, na.action = na.omit, ...), silent = TRUE) - + if(class(regMod) == "try-error" | any(is.nan(regMod$residuals))) try(regMod <- lm(y~x, ...)) if(class(regMod) == "try-error") return(NA) } - + pR2 <- 1 - (sum(resid(regMod)^2)/meanMod) if(pR2 < 0) pR2 <- 0 pR2 @@ -164,7 +163,7 @@ filterVarImp <- function(x, y, nonpara = FALSE, ...) testFunc <- if(nonpara) nonparaFoo else paraFoo - outStat <- apply(x, 2, testFunc, y = y) + outStat <- apply(x, 2, testFunc, y = y) outStat <- data.frame(Overall = outStat) } outStat From 671ec86c8395052be6c896dfebb0e12b303a878f Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 09:44:30 -0600 Subject: [PATCH 02/30] less verbose --- pkg/caret/R/filterVarImp.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/caret/R/filterVarImp.R b/pkg/caret/R/filterVarImp.R index 44995248..8a617bff 100644 --- a/pkg/caret/R/filterVarImp.R +++ b/pkg/caret/R/filterVarImp.R @@ -93,7 +93,7 @@ rocPerCol <- function(dat, cls) { filterVarImp <- function(x, y, nonpara = FALSE, ...) { { - notNumber <- unlist(lapply(x, function(x) !is.numeric(x))) + notNumber <- sapply(x, function(x) !is.numeric(x)) if(any(notNumber)) { for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) From 4b56b1f37c158e42f01d3504aa6f6e6a791e2b6d Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 09:48:39 -0600 Subject: [PATCH 03/30] not needed anymore --- pkg/caret/tests/testthat/test_pROC_direction.R | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 pkg/caret/tests/testthat/test_pROC_direction.R diff --git a/pkg/caret/tests/testthat/test_pROC_direction.R b/pkg/caret/tests/testthat/test_pROC_direction.R deleted file mode 100644 index 8b0a7d8f..00000000 --- a/pkg/caret/tests/testthat/test_pROC_direction.R +++ /dev/null @@ -1,24 +0,0 @@ -library(caret) - -context('Testing pROC direction') - -test_that('rocPerCol returns AUC < 0.5 with direction = "<"', { - #skip_on_cran() - skip_if_not_installed('pROC') - - set.seed(42) - dat <- twoClassSim(200, linearVars = 1) - - auto.auc <- as.numeric(pROC::roc(dat$Class, dat$Linear1, direction = "auto")$auc) - fixed.auc <- as.numeric(pROC::roc(dat$Class, dat$Linear1, direction = "<")$auc) - tested.auc <- as.numeric(caret:::rocPerCol(dat$Linear1, dat$Class)) - - # tested.auc should equal tested.auc now - expect_equal(tested.auc, fixed.auc) - - # Also it has been hand-checked to be < 0.5 with this seed (0.4875) - expect_lt(tested.auc, 0.5) - - # And it should be lower than the "auto" auc - expect_lt(tested.auc, auto.auc) -}) From 6fc2917cfcf271a5519c2f56bc776a59d5b4f6c8 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 10:29:24 -0600 Subject: [PATCH 04/30] minor test cleanup --- pkg/caret/tests/testthat/test_glmnet_varImp.R | 6 +++--- pkg/caret/tests/testthat/test_models_bagEarth.R | 1 + pkg/caret/tests/testthat/test_sampling_options.R | 25 ++++++++++++------------ 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/pkg/caret/tests/testthat/test_glmnet_varImp.R b/pkg/caret/tests/testthat/test_glmnet_varImp.R index cef38045..261ec5c5 100644 --- a/pkg/caret/tests/testthat/test_glmnet_varImp.R +++ b/pkg/caret/tests/testthat/test_glmnet_varImp.R @@ -8,17 +8,17 @@ test_that('glmnet varImp returns non-negative values', { skip_if_not_installed('glmnet') set.seed(1) dat <- SLC14_1(200) - + reg <- train(y ~ ., data = dat, method = "glmnet", tuneGrid = data.frame(lambda = .1, alpha = .5), trControl = trainControl(method = "none")) - + # this checks that some coefficients are negative coefs <- predict(reg$finalModel, s=0.1, type="coef") expect_less_than(0, sum(0 > coefs)) # now check that all elements of varImp are nonnegative, # in spite of negative coefficients vis <- varImp(reg, s=0.1, scale=F)$importance - expect_equal(0, sum(0 > vis)) + expect_true(all(vis >= 0)) }) diff --git a/pkg/caret/tests/testthat/test_models_bagEarth.R b/pkg/caret/tests/testthat/test_models_bagEarth.R index 7f8b1f6c..63489e41 100644 --- a/pkg/caret/tests/testthat/test_models_bagEarth.R +++ b/pkg/caret/tests/testthat/test_models_bagEarth.R @@ -2,6 +2,7 @@ # such as the bagEarth() not returning the right kind of object, that one of # the functions (bagEarth, format, predict) crash during normal usage, or that # bagEarth cannot model a simplistic kind of linear equation. +context("earth") test_that('bagEarth simple regression', { skip_on_cran() data <- data.frame(X = 1:100) diff --git a/pkg/caret/tests/testthat/test_sampling_options.R b/pkg/caret/tests/testthat/test_sampling_options.R index 12abe61c..e199c4a0 100644 --- a/pkg/caret/tests/testthat/test_sampling_options.R +++ b/pkg/caret/tests/testthat/test_sampling_options.R @@ -1,18 +1,19 @@ library(caret) library(testthat) -load(system.file("models", "sampling.RData", package = "caret")) +context("sampling options") +load(system.file("models", "sampling.RData", package = "caret")) test_that('check appropriate sampling calls by name', { skip_on_cran() arg_names <- c("up", "down", "rose", "smote") arg_funcs <- sampling_methods arg_first <- c(TRUE, FALSE) - + ## test that calling by string gives the right result for(i in arg_names) { out <- caret:::parse_sampling(i) - expected <- list(name = i, + expected <- list(name = i, func = sampling_methods[[i]], first = TRUE) expect_equivalent(out, expected) @@ -24,11 +25,11 @@ test_that('check appropriate sampling calls by function', { arg_names <- c("up", "down", "rose", "smote") arg_funcs <- sampling_methods arg_first <- c(TRUE, FALSE) - + ## test that calling by function gives the right result for(i in arg_names) { out <- caret:::parse_sampling(sampling_methods[[i]]) - expected <- list(name = "custom", + expected <- list(name = "custom", func = sampling_methods[[i]], first = TRUE) expect_equivalent(out, expected) @@ -39,26 +40,26 @@ test_that('check bad sampling name', { skip_on_cran() expect_error(caret:::parse_sampling("what?")) }) - + test_that('check bad first arg', { skip_on_cran() expect_error(caret:::parse_sampling(list(name = "yep", func = sampling_methods[["up"]], first = 2))) -}) +}) test_that('check bad func arg', { skip_on_cran() expect_error(caret:::parse_sampling(list(name = "yep", func = I, first = 2))) -}) - +}) + test_that('check incomplete list', { skip_on_cran() expect_error(caret:::parse_sampling(list(name = "yep"))) -}) +}) test_that('check call', { skip_on_cran() expect_error(caret:::parse_sampling(14)) -}) +}) ################################################################### ## @@ -82,4 +83,4 @@ test_that('check getting one method', { test_that('check missing method', { skip_on_cran() expect_error(getSamplingInfo("plum")) -}) \ No newline at end of file +}) From 2ab5f5ab651b57d05474fe108cdf1b93e9b29b8d Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 10:29:45 -0600 Subject: [PATCH 05/30] modified tests for mnLogLoss --- pkg/caret/R/aaa.R | 129 +++++++++++++++--------------- pkg/caret/tests/testthat/test_mnLogLoss.R | 29 +++---- 2 files changed, 76 insertions(+), 82 deletions(-) diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index 8512d8e1..fb712e51 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -20,25 +20,25 @@ ################################################################### if(getRversion() >= "2.15.1"){ - + utils::globalVariables(c('Metric', 'Model')) - - + + ## densityplot(~ values|Metric, data = plotData, groups = ind, ## xlab = "", ...) - + utils::globalVariables(c('ind')) - + ## avPerf <- ddply(subset(results, Metric == metric[1] & X2 == "Estimate"), ## .(Model), ## function(x) c(Median = median(x$value, na.rm = TRUE))) - + utils::globalVariables(c('X2')) - + ## x[[i]]$resample <- subset(x[[i]]$resample, Variables == x[[i]]$bestSubset) - + utils::globalVariables(c('Variables')) - + ## calibCalc: no visible binding for global variable 'obs' ## calibCalc: no visible binding for global variable 'bin' ## @@ -47,9 +47,9 @@ if(getRversion() >= "2.15.1"){ ## binData <- data.frame(prob = x$calibProbVar, ## bin = cut(x$calibProbVar, (0:cuts)/cuts, include.lowest = TRUE), ## class = x$calibClassVar) - + utils::globalVariables(c('obs', 'bin')) - + ## ## checkConditionalX: no visible binding for global variable '.outcome' ## checkConditionalX <- function(x, y) @@ -57,9 +57,9 @@ if(getRversion() >= "2.15.1"){ ## x$.outcome <- y ## unique(unlist(dlply(x, .(.outcome), zeroVar))) ## } - + utils::globalVariables(c('.outcome')) - + ## classLevels.splsda: no visible global function definition for 'ilevels' ## ## classLevels.splsda <- function(x, ...) @@ -68,9 +68,9 @@ if(getRversion() >= "2.15.1"){ ## ## same class name, but this works for either ## ilevels(x$y) ## } - + utils::globalVariables(c('ilevels')) - + ## looRfeWorkflow: no visible binding for global variable 'iter' ## looSbfWorkflow: no visible binding for global variable 'iter' ## looTrainWorkflow: no visible binding for global variable 'parm' @@ -93,9 +93,9 @@ if(getRversion() >= "2.15.1"){ ## .errorhandling = "stop") %dopar% ## { ## - + utils::globalVariables(c('iter', 'parm', 'method', 'Resample', 'dat')) - + ## tuneScheme: no visible binding for global variable '.alpha' ## tuneScheme: no visible binding for global variable '.phi' ## tuneScheme: no visible binding for global variable '.lambda' @@ -103,9 +103,9 @@ if(getRversion() >= "2.15.1"){ ## seqParam[[i]] <- data.frame(.lambda = subset(grid, ## subset = .phi == loop$.phi[i] & ## .lambda < loop$.lambda[i])$.lambda) - + utils::globalVariables(c('.alpha', '.phi', '.lambda')) - + ## createGrid : somDims: no visible binding for global variable '.xdim' ## createGrid : somDims: no visible binding for global variable '.ydim' ## createGrid : lvqGrid: no visible binding for global variable '.k' @@ -114,9 +114,9 @@ if(getRversion() >= "2.15.1"){ ## out <- expand.grid(.xdim = 1:x, .ydim = 2:(x+1), ## .xweight = seq(.5, .9, length = len)) ## - + utils::globalVariables(c('.xdim', '.ydim', '.k', '.size')) - + ## createModel: possible error in rda(trainX, trainY, gamma = ## tuneValue$.gamma, lambda = tuneValue$.lambda, ...): unused ## argument(s) (gamma = tuneValue$.gamma, lambda = tuneValue$.lambda) @@ -144,54 +144,54 @@ if(getRversion() >= "2.15.1"){ ## ## $lambda ## [1] NA - + ## predictionFunction: no visible binding for global variable '.alpha' ## ## delta <- subset(param, .alpha == uniqueA[i])$.delta ## - + utils::globalVariables(c('.alpha')) - + ## predictors.gbm: no visible binding for global variable 'rel.inf' ## predictors.sda: no visible binding for global variable 'varIndex' ## predictors.smda: no visible binding for global variable 'varIndex' ## ## varUsed <- as.character(subset(relImp, rel.inf != 0)$var) - + utils::globalVariables(c('rel.inf', 'varIndex')) - + ## plotClassProbs: no visible binding for global variable 'Observed' ## ## out <- densityplot(form, data = stackProbs, groups = Observed, ...) - + utils::globalVariables(c('Observed')) - + ## plot.train: no visible binding for global variable 'parameter' ## ## paramLabs <- subset(modelInfo, parameter %in% params)$label - + utils::globalVariables(c('parameter')) - + ## plot.rfe: no visible binding for global variable 'Selected' ## ## out <- xyplot(plotForm, data = results, groups = Selected, panel = panel.profile, ...) - + utils::globalVariables(c('Selected')) - + ## icr.formula: no visible binding for global variable 'thresh' ## ## res <- icr.default(x, y, weights = w, thresh = thresh, ...) - + utils::globalVariables(c('thresh', 'probValues', 'min_prob', 'groups', 'trainData', 'j', 'x', '.B')) - + utils::globalVariables(c('model_id', 'player1', 'player2', 'playa', 'win1', 'win2', 'name')) - + utils::globalVariables(c('object', 'Iter', 'lvls', 'Mean', 'Estimate')) - - + + ## parse_sampling: no visible binding for global variable 'sampling_methods' utils::globalVariables(c('sampling_methods')) - + ## ggplot.calibration: no visible binding for global variable 'midpoint' ## ggplot.calibration: no visible binding for global variable 'Percent' ## ggplot.calibration: no visible binding for global variable 'Lower' @@ -206,10 +206,10 @@ altTrainWorkflow <- function(x) x best <- function(x, metric, maximize) { - + bestIter <- if(maximize) which.max(x[,metric]) else which.min(x[,metric]) - + bestIter } @@ -242,26 +242,25 @@ mnLogLoss <- function(data, lev = NULL, model = NULL){ stop("'data' should have columns consistent with 'lev'") if(!all(sort(lev) %in% sort(levels(data$obs)))) stop("'data$obs' should have levels consistent with 'lev'") - eps <- 1e-15 - probs <- as.matrix(data[, lev, drop = FALSE]) - probs[probs > 1 - eps] <- 1 - eps - probs[probs < eps] <- eps - inds <- match(data$obs, colnames(probs)) - probs <- probs[cbind(seq_len(nrow(probs)), inds)] - c(logLoss = -mean(log(probs), na.rm = TRUE)) + + dataComplete <- data[complete.cases(data),] + probs <- as.matrix(dataComplete[, lev, drop = FALSE]) + + inds <- match(dataComplete$obs, colnames(probs)) + ModelMetrics::mlogLoss(dataComplete$obs, probs) } multiClassSummary <- function (data, lev = NULL, model = NULL){ #Check data - if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) + if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) stop("levels of observed and predicted data do not match") has_class_probs <- all(lev %in% colnames(data)) if(has_class_probs) { ## Overall multinomial loss lloss <- mnLogLoss(data = data, lev = lev, model = model) - requireNamespaceQuietStop("pROC") + requireNamespaceQuietStop("pROC") #Calculate custom one-vs-all ROC curves for each class - prob_stats <- lapply(levels(data[, "pred"]), + prob_stats <- lapply(levels(data[, "pred"]), function(class){ #Grab one-vs-all data for the class obs <- ifelse(data[, "obs"] == class, 1, 0) @@ -269,14 +268,14 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ rocObject <- try(pROC::roc(obs, data[,class], direction = "<"), silent = TRUE) prob_stats <- if (class(rocObject)[1] == "try-error") NA else rocObject$auc names(prob_stats) <- c('ROC') - return(prob_stats) + return(prob_stats) }) roc_stats <- mean(unlist(prob_stats)) } - + #Calculate confusion matrix-based statistics CM <- confusionMatrix(data[, "pred"], data[, "obs"]) - + #Aggregate and average class-wise stats #Todo: add weights # RES: support two classes here as well @@ -287,32 +286,32 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ class_stats <- colMeans(CM$byClass) names(class_stats) <- paste("Mean", names(class_stats)) } - + # Aggregate overall stats - overall_stats <- if(has_class_probs) + overall_stats <- if(has_class_probs) c(CM$overall, lloss, ROC = roc_stats) else CM$overall - if (length(levels(data[, "pred"])) > 2) + if (length(levels(data[, "pred"])) > 2) names(overall_stats)[names(overall_stats) == "ROC"] <- "Mean_ROC" - - - # Combine overall with class-wise stats and remove some stats we don't want + + + # Combine overall with class-wise stats and remove some stats we don't want stats <- c(overall_stats, class_stats) stats <- stats[! names(stats) %in% c('AccuracyNull', "AccuracyLower", "AccuracyUpper", - "AccuracyPValue", "McnemarPValue", + "AccuracyPValue", "McnemarPValue", 'Mean Prevalence', 'Mean Detection Prevalence')] - + # Clean names names(stats) <- gsub('[[:blank:]]+', '_', names(stats)) - + # Change name ordering to place most useful first # May want to remove some of these eventually - stat_list <- c("Accuracy", "Kappa", "Mean_Sensitivity", "Mean_Specificity", + stat_list <- c("Accuracy", "Kappa", "Mean_Sensitivity", "Mean_Specificity", "Mean_Pos_Pred_Value", "Mean_Neg_Pred_Value", "Mean_Detection_Rate", "Mean_Balanced_Accuracy") if(has_class_probs) stat_list <- c("logLoss", "Mean_ROC", stat_list) if (length(levels(data[, "pred"])) == 2) stat_list <- gsub("^Mean_", "", stat_list) - + stats <- stats[c(stat_list)] - + return(stats) } diff --git a/pkg/caret/tests/testthat/test_mnLogLoss.R b/pkg/caret/tests/testthat/test_mnLogLoss.R index def12606..ae19d1f6 100644 --- a/pkg/caret/tests/testthat/test_mnLogLoss.R +++ b/pkg/caret/tests/testthat/test_mnLogLoss.R @@ -1,7 +1,5 @@ context('mnLogLoss') -eps <- 1e-15 - classes <- LETTERS[1:3] test_dat1 <- data.frame(obs = c("A", "A", "A", "B", "B", "C"), @@ -10,24 +8,21 @@ test_dat1 <- data.frame(obs = c("A", "A", "A", "B", "B", "C"), B = c(0, .05, .29, .8, .6, .3), C = c(0, .15, .20, .1, .2, .4)) -expected1 <- log(1-eps) + log(.8) + log(.51) + log(.8) + log(.6) + log(.4) -expected1 <- c(logLoss = -expected1/nrow(test_dat1)) -result1 <- mnLogLoss(test_dat1, lev = classes) +test_that("Multiclass logloss returns expected values", { + result1 <- mnLogLoss(test_dat1, classes) -test_dat2 <- test_dat1 -test_dat2$A[1] <- NA + test_dat2 <- test_dat1 + test_dat2$A[1] <- NA + result2 <- mnLogLoss(test_dat2, classes) -expected2 <- log(.8) + log(.51) + log(.8) + log(.6) + log(.4) -expected2 <- c(logLoss = -expected2/sum(complete.cases(test_dat2))) -result2 <- mnLogLoss(test_dat2, lev = classes) + test_dat3 <- test_dat1 + test_dat3 <- test_dat3[, rev(1:5)] + result3 <- mnLogLoss(test_dat3, classes) -test_dat3 <- test_dat1 -test_dat3 <- test_dat3[, rev(1:5)] -expected3 <- expected1 -result3 <- mnLogLoss(test_dat3, lev = classes[c(2, 3, 1)]) + expect_equal(result1, 0.424458, tolerance = .000001) + expect_equal(result2, 0.5093496, tolerance = .000001) + expect_equal(result3, 0.424458, tolerance = .000001) -expect_equal(result1, expected1) -expect_equal(result2, expected2) -expect_equal(result3, expected3) +}) From 3bf402c2da5ac63cc948fdce82a3847baecadb85 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 10:51:17 -0600 Subject: [PATCH 06/30] reworked stats to include auc from ModelMetrics and cleaned up --- pkg/caret/R/aaa.R | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index fb712e51..70140412 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -258,17 +258,15 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ if(has_class_probs) { ## Overall multinomial loss lloss <- mnLogLoss(data = data, lev = lev, model = model) - requireNamespaceQuietStop("pROC") + requireNamespaceQuietStop("ModelMetrics") #Calculate custom one-vs-all ROC curves for each class prob_stats <- lapply(levels(data[, "pred"]), - function(class){ + function(x){ #Grab one-vs-all data for the class - obs <- ifelse(data[, "obs"] == class, 1, 0) - prob <- data[,class] - rocObject <- try(pROC::roc(obs, data[,class], direction = "<"), silent = TRUE) - prob_stats <- if (class(rocObject)[1] == "try-error") NA else rocObject$auc - names(prob_stats) <- c('ROC') - return(prob_stats) + obs <- ifelse(data[, "obs"] == x, 1, 0) + prob <- data[,x] + AUCs <- try(ModelMetrics::auc(obs, data[,x]), silent = TRUE) + return(AUCs) }) roc_stats <- mean(unlist(prob_stats)) } @@ -289,9 +287,9 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ # Aggregate overall stats overall_stats <- if(has_class_probs) - c(CM$overall, lloss, ROC = roc_stats) else CM$overall + c(CM$overall, logLoss = lloss, ROC = roc_stats) else CM$overall if (length(levels(data[, "pred"])) > 2) - names(overall_stats)[names(overall_stats) == "ROC"] <- "Mean_ROC" + names(overall_stats)[names(overall_stats) == "ROC"] <- "Mean_AUC" # Combine overall with class-wise stats and remove some stats we don't want @@ -308,7 +306,7 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ stat_list <- c("Accuracy", "Kappa", "Mean_Sensitivity", "Mean_Specificity", "Mean_Pos_Pred_Value", "Mean_Neg_Pred_Value", "Mean_Detection_Rate", "Mean_Balanced_Accuracy") - if(has_class_probs) stat_list <- c("logLoss", "Mean_ROC", stat_list) + if(has_class_probs) stat_list <- c("logLoss", "Mean_AUC", stat_list) if (length(levels(data[, "pred"])) == 2) stat_list <- gsub("^Mean_", "", stat_list) stats <- stats[c(stat_list)] From 871cc85a1c492e02aee17e1e42cff4a0c90d0695 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Fri, 26 Aug 2016 08:43:22 -0600 Subject: [PATCH 07/30] dropping uneeded files --- pkg/caret/R/aucRoc.R | 15 --------------- pkg/caret/R/roc.R | 21 --------------------- pkg/caret/R/rocPoint.R | 19 ------------------- 3 files changed, 55 deletions(-) delete mode 100644 pkg/caret/R/aucRoc.R delete mode 100644 pkg/caret/R/roc.R delete mode 100644 pkg/caret/R/rocPoint.R diff --git a/pkg/caret/R/aucRoc.R b/pkg/caret/R/aucRoc.R deleted file mode 100644 index 2820e846..00000000 --- a/pkg/caret/R/aucRoc.R +++ /dev/null @@ -1,15 +0,0 @@ -aucRoc <- function(object) -{ - warning("This function is deprecated a of 1/3/12. The computations now utilize the pROC package. This function will be removed in a few releases.") - - sens <- object[, "sensitivity"] - omspec <- 1 - object[, "specificity"] - newOrder <- order(omspec) - sens <- sens[newOrder] - omspec <- omspec[newOrder] - - rocArea <- sum(.5 *diff(omspec) * (sens[-1] + sens[-length(sens)])) - rocArea <- max(rocArea, 1 - rocArea) - rocArea -} - diff --git a/pkg/caret/R/roc.R b/pkg/caret/R/roc.R deleted file mode 100644 index 104a67d1..00000000 --- a/pkg/caret/R/roc.R +++ /dev/null @@ -1,21 +0,0 @@ -roc <- function(data, class, dataGrid = TRUE, gridLength = 100, positive = levels(class)[1]) -{ - warning("This function is deprecated a of 1/3/12. The computations now utilize the pROC package. This function will be removed in a few releases.") - - if(!is.character(positive) | length(positive) != 1) stop("positive argument should be a single character value") - - if(!(positive %in% levels(class))) stop("wrong level specified") - if(length(levels(class)) != 2) stop("wrong number of levels") - if(dataGrid) cutoffDF <- data.frame(value = sort(unique(data))) - else cutoffDF <- data.frame(value = seq( - from = min(data, na.rm = TRUE), - to = max(data, na.rm = TRUE), - length = gridLength)) - numCuts <- dim(cutoffDF)[1] - out <- matrix(NA, ncol = 3, nrow = numCuts + 1) - - out[2:(numCuts + 1), ] <- t(apply(cutoffDF, 1, rocPoint, x = data, y = class, positive = positive)) - out[1, ] <- c(NA, 1, 0) - colnames(out) <- c("cutoff", "sensitivity", "specificity") - out -} diff --git a/pkg/caret/R/rocPoint.R b/pkg/caret/R/rocPoint.R deleted file mode 100644 index c4dca43f..00000000 --- a/pkg/caret/R/rocPoint.R +++ /dev/null @@ -1,19 +0,0 @@ -rocPoint <- function(cutoff, x, y, positive) -{ - warning("This function is deprecated a of 1/3/12. The computations now utilize the pROC package. This function will be removed in a few releases.") - classLevels <- levels(y) - negative <- classLevels[positive != classLevels] - newClass <- factor( - ifelse( - x <= cutoff, - negative, - positive), - levels = classLevels) - out <- c( - cutoff, - sensitivity(newClass, y, positive), - specificity(newClass, y, negative)) - names(out) <- c("cutoff", "sensitivity", "specificity") - out -} - From f970bd432b5594e7e0e3d7d27faf2d73b6efbc1f Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Fri, 26 Aug 2016 08:43:43 -0600 Subject: [PATCH 08/30] optimizing filterVarImp --- pkg/caret/R/filterVarImp.R | 148 +++++++++------------------------------------ 1 file changed, 28 insertions(+), 120 deletions(-) diff --git a/pkg/caret/R/filterVarImp.R b/pkg/caret/R/filterVarImp.R index 8a617bff..801b8597 100644 --- a/pkg/caret/R/filterVarImp.R +++ b/pkg/caret/R/filterVarImp.R @@ -1,133 +1,41 @@ -## todo start using foreach here - -oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) -{ - { - notNumber <- unlist(lapply(x, function(x) !is.numeric(x))) - if(any(notNumber)) - { - for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } - } - - if(is.factor(y)) - { - classLevels <- levels(y) - - outStat <- matrix(NA, nrow = dim(x)[2], ncol = length(classLevels)) - for(i in seq(along = classLevels)) - { - otherLevels <- classLevels[classLevels != classLevels[i]] - - for(k in seq(along = otherLevels)) - { - tmpSubset <- as.character(y) %in% c(classLevels[i], otherLevels[k]) - tmpY <- factor(as.character(y)[tmpSubset]) - tmpX <- x[tmpSubset,] - - rocAuc <- apply( - tmpX, - 2, - function(x, class, pos) - { - isMissing <- is.na(x) | is.na(class) - if(any(isMissing)) - { - x <- x[!isMissing] - class <- class[!isMissing] - } - outResults <- if(length(unique(x)) > 200) roc(x, class = class, positive = pos) - else roc(x, class = class, dataGrid = FALSE, positive = pos) - aucRoc(outResults) - }, - class = tmpY, - pos = classLevels[i]) - outStat[, i] <- pmax(outStat[, i], rocAuc, na.rm = TRUE) - } - if(i ==1 & length(classLevels) == 2) - { - outStat[, 2] <- outStat[, 1] - break() - } - } - colnames(outStat) <- classLevels - rownames(outStat) <- dimnames(x)[[2]] - outStat <- data.frame(outStat) - } else { - paraFoo <- function(data, y) abs(coef(summary(lm(y ~ data, na.action = na.omit)))[2, "t value"]) - nonparaFoo <- function(x, y, ...) - { - meanMod <- sum((y - mean(y, rm.na = TRUE))^2) - nzv <- nearZeroVar(x, saveMetrics = TRUE) +rocPerCol <- function(dat, cls){ + ModelMetrics::auc(cls, dat) +} - if(nzv$zeroVar) return(NA) - if(nzv$percentUnique < 20) - { - regMod <- lm(y~x, na.action = na.omit, ...) - } else { - regMod <- try(loess(y~x, na.action = na.omit, ...), silent = TRUE) +asNumeric <- function(data){ + fc <- sapply(data, is.factor) + modifyList(data, lapply(data[, fc], as.numeric)) +} - if(class(regMod) == "try-error" | any(is.nan(regMod$residuals))) try(regMod <- lm(y~x, ...)) - if(class(regMod) == "try-error") return(NA) - } +filterVarImp <- function(x, y, nonpara = FALSE, ...){ + # converting factors to numeric + notNumber <- sapply(x, function(x) !is.numeric(x)) + x = asNumeric(x) - pR2 <- 1 - (sum(resid(regMod)^2)/meanMod) - if(pR2 < 0) pR2 <- 0 - pR2 - } + if(is.factor(y)){ + classLevels <- levels(y) + k <- length(classLevels) - testFunc <- if(nonpara) nonparaFoo else paraFoo + if(k > 2){ - outStat <- apply(x, 2, testFunc, y = y) - outStat <- data.frame(Overall = outStat) - } - outStat -} + Combs <- combn(classLevels, 2) + CombsN <- combn(1:k, 2) + lStat <- lapply(1:ncol(Combs), FUN = function(cc){ + yLevs <- as.character(y) %in% Combs[,cc] + tmpX <- x[yLevs,] + tmpY <- as.numeric(y[yLevs] == Combs[,cc][2]) + apply(tmpX, 2, rocPerCol, cls = tmpY) + }) + Stat = do.call("cbind", lStat) -rocPerCol <- function(dat, cls) { - 1 - ModelMetrics::auc(cls, dat) -} + loutStat <- lapply(1:k, function(j){ + apply(Stat[,CombsN[,j]], 1, max) + }) -filterVarImp <- function(x, y, nonpara = FALSE, ...) -{ - { - notNumber <- sapply(x, function(x) !is.numeric(x)) - if(any(notNumber)) - { - for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } - } - - if(is.factor(y)) - { - classLevels <- levels(y) - k <- length(classLevels) + outStat = do.call("cbind", loutStat) - if(k > 2) - { - counter <- 1 - classIndex <- vector(mode = "list", length = k) - tmpStat <- matrix(NA, nrow = ncol(x), ncol = choose(k, 2)) - for(i in 1:k) - { - for(j in i:k) - { - if(i != j) - { - classIndex[[i]] <- c(classIndex[[i]], counter) - classIndex[[j]] <- c(classIndex[[j]], counter) - index <- which(y %in% c(classLevels[i], classLevels[j])) - tmpX <- x[index,,drop = FALSE] - tmpY <- factor(as.character(y[index]), levels = c(classLevels[i], classLevels[j])) - tmpStat[,counter] <- apply(tmpX, 2, rocPerCol, cls = tmpY) - counter <- counter + 1 - } - } - } - outStat <- matrix(NA, ncol(x), k) - for(i in 1:k) outStat[,i] <- apply(tmpStat[,classIndex[[i]]], 1, max) } else { tmp <- apply(x, 2, rocPerCol, cls = y) outStat <- cbind(tmp, tmp) From a2d2663a7e2cd3d519b71d310ad785643d013bd2 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Fri, 26 Aug 2016 08:43:52 -0600 Subject: [PATCH 09/30] use ModelMetrics --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d88e179c..f020a21e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -44,7 +44,7 @@ before_install: - ./travis-tool.sh r_binary_install nnet - ./travis-tool.sh r_binary_install party - ./travis-tool.sh r_binary_install pls - - ./travis-tool.sh r_binary_install pROC + - ./travis-tool.sh r_binary_install ModelMetrics - ./travis-tool.sh r_binary_install proxy - ./travis-tool.sh r_binary_install randomForest - ./travis-tool.sh r_binary_install RANN From c8236799a6715399d407eb7386664faf2baebbf6 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 09:42:20 -0600 Subject: [PATCH 10/30] adding auc from ModelMetrics --- pkg/caret/NAMESPACE | 1 + pkg/caret/R/filterVarImp.R | 47 +++++++++++++++++++++++----------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/pkg/caret/NAMESPACE b/pkg/caret/NAMESPACE index be73a939..606068fc 100644 --- a/pkg/caret/NAMESPACE +++ b/pkg/caret/NAMESPACE @@ -1,4 +1,5 @@ useDynLib(caret) +importFrom(ModelMetrics, auc) import(foreach, methods, plyr, reshape2, ggplot2, lattice, nlme) importFrom(car, powerTransform, yjPower) importFrom(grDevices, extendrange) diff --git a/pkg/caret/R/filterVarImp.R b/pkg/caret/R/filterVarImp.R index 55dc2aa4..44995248 100644 --- a/pkg/caret/R/filterVarImp.R +++ b/pkg/caret/R/filterVarImp.R @@ -7,30 +7,30 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) if(any(notNumber)) { for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } + } } if(is.factor(y)) { classLevels <- levels(y) - + outStat <- matrix(NA, nrow = dim(x)[2], ncol = length(classLevels)) for(i in seq(along = classLevels)) { otherLevels <- classLevels[classLevels != classLevels[i]] - + for(k in seq(along = otherLevels)) { tmpSubset <- as.character(y) %in% c(classLevels[i], otherLevels[k]) tmpY <- factor(as.character(y)[tmpSubset]) - tmpX <- x[tmpSubset,] - + tmpX <- x[tmpSubset,] + rocAuc <- apply( - tmpX, - 2, + tmpX, + 2, function(x, class, pos) { - isMissing <- is.na(x) | is.na(class) + isMissing <- is.na(x) | is.na(class) if(any(isMissing)) { x <- x[!isMissing] @@ -40,16 +40,16 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) else roc(x, class = class, dataGrid = FALSE, positive = pos) aucRoc(outResults) }, - class = tmpY, + class = tmpY, pos = classLevels[i]) - outStat[, i] <- pmax(outStat[, i], rocAuc, na.rm = TRUE) + outStat[, i] <- pmax(outStat[, i], rocAuc, na.rm = TRUE) } if(i ==1 & length(classLevels) == 2) { outStat[, 2] <- outStat[, 1] break() - } - } + } + } colnames(outStat) <- classLevels rownames(outStat) <- dimnames(x)[[2]] outStat <- data.frame(outStat) @@ -60,18 +60,18 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) { meanMod <- sum((y - mean(y, rm.na = TRUE))^2) nzv <- nearZeroVar(x, saveMetrics = TRUE) - + if(nzv$zeroVar) return(NA) if(nzv$percentUnique < 20) { regMod <- lm(y~x, na.action = na.omit, ...) } else { regMod <- try(loess(y~x, na.action = na.omit, ...), silent = TRUE) - + if(class(regMod) == "try-error" | any(is.nan(regMod$residuals))) try(regMod <- lm(y~x, ...)) if(class(regMod) == "try-error") return(NA) } - + pR2 <- 1 - (sum(resid(regMod)^2)/meanMod) if(pR2 < 0) pR2 <- 0 pR2 @@ -79,7 +79,7 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) testFunc <- if(nonpara) nonparaFoo else paraFoo - outStat <- apply(x, 2, testFunc, y = y) + outStat <- apply(x, 2, testFunc, y = y) outStat <- data.frame(Overall = outStat) } outStat @@ -87,8 +87,7 @@ oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) rocPerCol <- function(dat, cls) { - loadNamespace("pROC") - pROC::roc(cls, dat, direction = "<")$auc + 1 - ModelMetrics::auc(cls, dat) } filterVarImp <- function(x, y, nonpara = FALSE, ...) @@ -98,14 +97,14 @@ filterVarImp <- function(x, y, nonpara = FALSE, ...) if(any(notNumber)) { for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } + } } if(is.factor(y)) { classLevels <- levels(y) k <- length(classLevels) - + if(k > 2) { counter <- 1 @@ -145,18 +144,18 @@ filterVarImp <- function(x, y, nonpara = FALSE, ...) { meanMod <- sum((y - mean(y, rm.na = TRUE))^2) nzv <- nearZeroVar(x, saveMetrics = TRUE) - + if(nzv$zeroVar) return(NA) if(nzv$percentUnique < 20) { regMod <- lm(y~x, na.action = na.omit, ...) } else { regMod <- try(loess(y~x, na.action = na.omit, ...), silent = TRUE) - + if(class(regMod) == "try-error" | any(is.nan(regMod$residuals))) try(regMod <- lm(y~x, ...)) if(class(regMod) == "try-error") return(NA) } - + pR2 <- 1 - (sum(resid(regMod)^2)/meanMod) if(pR2 < 0) pR2 <- 0 pR2 @@ -164,7 +163,7 @@ filterVarImp <- function(x, y, nonpara = FALSE, ...) testFunc <- if(nonpara) nonparaFoo else paraFoo - outStat <- apply(x, 2, testFunc, y = y) + outStat <- apply(x, 2, testFunc, y = y) outStat <- data.frame(Overall = outStat) } outStat From e7f0f05114e8febffed62dc682adc199887f0f89 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 09:44:30 -0600 Subject: [PATCH 11/30] less verbose --- pkg/caret/R/filterVarImp.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/caret/R/filterVarImp.R b/pkg/caret/R/filterVarImp.R index 44995248..8a617bff 100644 --- a/pkg/caret/R/filterVarImp.R +++ b/pkg/caret/R/filterVarImp.R @@ -93,7 +93,7 @@ rocPerCol <- function(dat, cls) { filterVarImp <- function(x, y, nonpara = FALSE, ...) { { - notNumber <- unlist(lapply(x, function(x) !is.numeric(x))) + notNumber <- sapply(x, function(x) !is.numeric(x)) if(any(notNumber)) { for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) From 748023e8224f6d26a12df53c42900e399cb8ed6e Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 09:48:39 -0600 Subject: [PATCH 12/30] not needed anymore --- pkg/caret/tests/testthat/test_pROC_direction.R | 24 ------------------------ 1 file changed, 24 deletions(-) delete mode 100644 pkg/caret/tests/testthat/test_pROC_direction.R diff --git a/pkg/caret/tests/testthat/test_pROC_direction.R b/pkg/caret/tests/testthat/test_pROC_direction.R deleted file mode 100644 index 8b0a7d8f..00000000 --- a/pkg/caret/tests/testthat/test_pROC_direction.R +++ /dev/null @@ -1,24 +0,0 @@ -library(caret) - -context('Testing pROC direction') - -test_that('rocPerCol returns AUC < 0.5 with direction = "<"', { - #skip_on_cran() - skip_if_not_installed('pROC') - - set.seed(42) - dat <- twoClassSim(200, linearVars = 1) - - auto.auc <- as.numeric(pROC::roc(dat$Class, dat$Linear1, direction = "auto")$auc) - fixed.auc <- as.numeric(pROC::roc(dat$Class, dat$Linear1, direction = "<")$auc) - tested.auc <- as.numeric(caret:::rocPerCol(dat$Linear1, dat$Class)) - - # tested.auc should equal tested.auc now - expect_equal(tested.auc, fixed.auc) - - # Also it has been hand-checked to be < 0.5 with this seed (0.4875) - expect_lt(tested.auc, 0.5) - - # And it should be lower than the "auto" auc - expect_lt(tested.auc, auto.auc) -}) From ff15e04c7c9d538b30decc7382767d5d8e760e24 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 10:29:24 -0600 Subject: [PATCH 13/30] minor test cleanup --- pkg/caret/tests/testthat/test_glmnet_varImp.R | 6 +++--- pkg/caret/tests/testthat/test_models_bagEarth.R | 1 + pkg/caret/tests/testthat/test_sampling_options.R | 25 ++++++++++++------------ 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/pkg/caret/tests/testthat/test_glmnet_varImp.R b/pkg/caret/tests/testthat/test_glmnet_varImp.R index cef38045..261ec5c5 100644 --- a/pkg/caret/tests/testthat/test_glmnet_varImp.R +++ b/pkg/caret/tests/testthat/test_glmnet_varImp.R @@ -8,17 +8,17 @@ test_that('glmnet varImp returns non-negative values', { skip_if_not_installed('glmnet') set.seed(1) dat <- SLC14_1(200) - + reg <- train(y ~ ., data = dat, method = "glmnet", tuneGrid = data.frame(lambda = .1, alpha = .5), trControl = trainControl(method = "none")) - + # this checks that some coefficients are negative coefs <- predict(reg$finalModel, s=0.1, type="coef") expect_less_than(0, sum(0 > coefs)) # now check that all elements of varImp are nonnegative, # in spite of negative coefficients vis <- varImp(reg, s=0.1, scale=F)$importance - expect_equal(0, sum(0 > vis)) + expect_true(all(vis >= 0)) }) diff --git a/pkg/caret/tests/testthat/test_models_bagEarth.R b/pkg/caret/tests/testthat/test_models_bagEarth.R index 7f8b1f6c..63489e41 100644 --- a/pkg/caret/tests/testthat/test_models_bagEarth.R +++ b/pkg/caret/tests/testthat/test_models_bagEarth.R @@ -2,6 +2,7 @@ # such as the bagEarth() not returning the right kind of object, that one of # the functions (bagEarth, format, predict) crash during normal usage, or that # bagEarth cannot model a simplistic kind of linear equation. +context("earth") test_that('bagEarth simple regression', { skip_on_cran() data <- data.frame(X = 1:100) diff --git a/pkg/caret/tests/testthat/test_sampling_options.R b/pkg/caret/tests/testthat/test_sampling_options.R index 12abe61c..e199c4a0 100644 --- a/pkg/caret/tests/testthat/test_sampling_options.R +++ b/pkg/caret/tests/testthat/test_sampling_options.R @@ -1,18 +1,19 @@ library(caret) library(testthat) -load(system.file("models", "sampling.RData", package = "caret")) +context("sampling options") +load(system.file("models", "sampling.RData", package = "caret")) test_that('check appropriate sampling calls by name', { skip_on_cran() arg_names <- c("up", "down", "rose", "smote") arg_funcs <- sampling_methods arg_first <- c(TRUE, FALSE) - + ## test that calling by string gives the right result for(i in arg_names) { out <- caret:::parse_sampling(i) - expected <- list(name = i, + expected <- list(name = i, func = sampling_methods[[i]], first = TRUE) expect_equivalent(out, expected) @@ -24,11 +25,11 @@ test_that('check appropriate sampling calls by function', { arg_names <- c("up", "down", "rose", "smote") arg_funcs <- sampling_methods arg_first <- c(TRUE, FALSE) - + ## test that calling by function gives the right result for(i in arg_names) { out <- caret:::parse_sampling(sampling_methods[[i]]) - expected <- list(name = "custom", + expected <- list(name = "custom", func = sampling_methods[[i]], first = TRUE) expect_equivalent(out, expected) @@ -39,26 +40,26 @@ test_that('check bad sampling name', { skip_on_cran() expect_error(caret:::parse_sampling("what?")) }) - + test_that('check bad first arg', { skip_on_cran() expect_error(caret:::parse_sampling(list(name = "yep", func = sampling_methods[["up"]], first = 2))) -}) +}) test_that('check bad func arg', { skip_on_cran() expect_error(caret:::parse_sampling(list(name = "yep", func = I, first = 2))) -}) - +}) + test_that('check incomplete list', { skip_on_cran() expect_error(caret:::parse_sampling(list(name = "yep"))) -}) +}) test_that('check call', { skip_on_cran() expect_error(caret:::parse_sampling(14)) -}) +}) ################################################################### ## @@ -82,4 +83,4 @@ test_that('check getting one method', { test_that('check missing method', { skip_on_cran() expect_error(getSamplingInfo("plum")) -}) \ No newline at end of file +}) From c21c900a08910a12c178bc7620fefe47b1b2bc3b Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 10:29:45 -0600 Subject: [PATCH 14/30] modified tests for mnLogLoss --- pkg/caret/R/aaa.R | 129 +++++++++++++++--------------- pkg/caret/tests/testthat/test_mnLogLoss.R | 29 +++---- 2 files changed, 76 insertions(+), 82 deletions(-) diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index 8512d8e1..fb712e51 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -20,25 +20,25 @@ ################################################################### if(getRversion() >= "2.15.1"){ - + utils::globalVariables(c('Metric', 'Model')) - - + + ## densityplot(~ values|Metric, data = plotData, groups = ind, ## xlab = "", ...) - + utils::globalVariables(c('ind')) - + ## avPerf <- ddply(subset(results, Metric == metric[1] & X2 == "Estimate"), ## .(Model), ## function(x) c(Median = median(x$value, na.rm = TRUE))) - + utils::globalVariables(c('X2')) - + ## x[[i]]$resample <- subset(x[[i]]$resample, Variables == x[[i]]$bestSubset) - + utils::globalVariables(c('Variables')) - + ## calibCalc: no visible binding for global variable 'obs' ## calibCalc: no visible binding for global variable 'bin' ## @@ -47,9 +47,9 @@ if(getRversion() >= "2.15.1"){ ## binData <- data.frame(prob = x$calibProbVar, ## bin = cut(x$calibProbVar, (0:cuts)/cuts, include.lowest = TRUE), ## class = x$calibClassVar) - + utils::globalVariables(c('obs', 'bin')) - + ## ## checkConditionalX: no visible binding for global variable '.outcome' ## checkConditionalX <- function(x, y) @@ -57,9 +57,9 @@ if(getRversion() >= "2.15.1"){ ## x$.outcome <- y ## unique(unlist(dlply(x, .(.outcome), zeroVar))) ## } - + utils::globalVariables(c('.outcome')) - + ## classLevels.splsda: no visible global function definition for 'ilevels' ## ## classLevels.splsda <- function(x, ...) @@ -68,9 +68,9 @@ if(getRversion() >= "2.15.1"){ ## ## same class name, but this works for either ## ilevels(x$y) ## } - + utils::globalVariables(c('ilevels')) - + ## looRfeWorkflow: no visible binding for global variable 'iter' ## looSbfWorkflow: no visible binding for global variable 'iter' ## looTrainWorkflow: no visible binding for global variable 'parm' @@ -93,9 +93,9 @@ if(getRversion() >= "2.15.1"){ ## .errorhandling = "stop") %dopar% ## { ## - + utils::globalVariables(c('iter', 'parm', 'method', 'Resample', 'dat')) - + ## tuneScheme: no visible binding for global variable '.alpha' ## tuneScheme: no visible binding for global variable '.phi' ## tuneScheme: no visible binding for global variable '.lambda' @@ -103,9 +103,9 @@ if(getRversion() >= "2.15.1"){ ## seqParam[[i]] <- data.frame(.lambda = subset(grid, ## subset = .phi == loop$.phi[i] & ## .lambda < loop$.lambda[i])$.lambda) - + utils::globalVariables(c('.alpha', '.phi', '.lambda')) - + ## createGrid : somDims: no visible binding for global variable '.xdim' ## createGrid : somDims: no visible binding for global variable '.ydim' ## createGrid : lvqGrid: no visible binding for global variable '.k' @@ -114,9 +114,9 @@ if(getRversion() >= "2.15.1"){ ## out <- expand.grid(.xdim = 1:x, .ydim = 2:(x+1), ## .xweight = seq(.5, .9, length = len)) ## - + utils::globalVariables(c('.xdim', '.ydim', '.k', '.size')) - + ## createModel: possible error in rda(trainX, trainY, gamma = ## tuneValue$.gamma, lambda = tuneValue$.lambda, ...): unused ## argument(s) (gamma = tuneValue$.gamma, lambda = tuneValue$.lambda) @@ -144,54 +144,54 @@ if(getRversion() >= "2.15.1"){ ## ## $lambda ## [1] NA - + ## predictionFunction: no visible binding for global variable '.alpha' ## ## delta <- subset(param, .alpha == uniqueA[i])$.delta ## - + utils::globalVariables(c('.alpha')) - + ## predictors.gbm: no visible binding for global variable 'rel.inf' ## predictors.sda: no visible binding for global variable 'varIndex' ## predictors.smda: no visible binding for global variable 'varIndex' ## ## varUsed <- as.character(subset(relImp, rel.inf != 0)$var) - + utils::globalVariables(c('rel.inf', 'varIndex')) - + ## plotClassProbs: no visible binding for global variable 'Observed' ## ## out <- densityplot(form, data = stackProbs, groups = Observed, ...) - + utils::globalVariables(c('Observed')) - + ## plot.train: no visible binding for global variable 'parameter' ## ## paramLabs <- subset(modelInfo, parameter %in% params)$label - + utils::globalVariables(c('parameter')) - + ## plot.rfe: no visible binding for global variable 'Selected' ## ## out <- xyplot(plotForm, data = results, groups = Selected, panel = panel.profile, ...) - + utils::globalVariables(c('Selected')) - + ## icr.formula: no visible binding for global variable 'thresh' ## ## res <- icr.default(x, y, weights = w, thresh = thresh, ...) - + utils::globalVariables(c('thresh', 'probValues', 'min_prob', 'groups', 'trainData', 'j', 'x', '.B')) - + utils::globalVariables(c('model_id', 'player1', 'player2', 'playa', 'win1', 'win2', 'name')) - + utils::globalVariables(c('object', 'Iter', 'lvls', 'Mean', 'Estimate')) - - + + ## parse_sampling: no visible binding for global variable 'sampling_methods' utils::globalVariables(c('sampling_methods')) - + ## ggplot.calibration: no visible binding for global variable 'midpoint' ## ggplot.calibration: no visible binding for global variable 'Percent' ## ggplot.calibration: no visible binding for global variable 'Lower' @@ -206,10 +206,10 @@ altTrainWorkflow <- function(x) x best <- function(x, metric, maximize) { - + bestIter <- if(maximize) which.max(x[,metric]) else which.min(x[,metric]) - + bestIter } @@ -242,26 +242,25 @@ mnLogLoss <- function(data, lev = NULL, model = NULL){ stop("'data' should have columns consistent with 'lev'") if(!all(sort(lev) %in% sort(levels(data$obs)))) stop("'data$obs' should have levels consistent with 'lev'") - eps <- 1e-15 - probs <- as.matrix(data[, lev, drop = FALSE]) - probs[probs > 1 - eps] <- 1 - eps - probs[probs < eps] <- eps - inds <- match(data$obs, colnames(probs)) - probs <- probs[cbind(seq_len(nrow(probs)), inds)] - c(logLoss = -mean(log(probs), na.rm = TRUE)) + + dataComplete <- data[complete.cases(data),] + probs <- as.matrix(dataComplete[, lev, drop = FALSE]) + + inds <- match(dataComplete$obs, colnames(probs)) + ModelMetrics::mlogLoss(dataComplete$obs, probs) } multiClassSummary <- function (data, lev = NULL, model = NULL){ #Check data - if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) + if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) stop("levels of observed and predicted data do not match") has_class_probs <- all(lev %in% colnames(data)) if(has_class_probs) { ## Overall multinomial loss lloss <- mnLogLoss(data = data, lev = lev, model = model) - requireNamespaceQuietStop("pROC") + requireNamespaceQuietStop("pROC") #Calculate custom one-vs-all ROC curves for each class - prob_stats <- lapply(levels(data[, "pred"]), + prob_stats <- lapply(levels(data[, "pred"]), function(class){ #Grab one-vs-all data for the class obs <- ifelse(data[, "obs"] == class, 1, 0) @@ -269,14 +268,14 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ rocObject <- try(pROC::roc(obs, data[,class], direction = "<"), silent = TRUE) prob_stats <- if (class(rocObject)[1] == "try-error") NA else rocObject$auc names(prob_stats) <- c('ROC') - return(prob_stats) + return(prob_stats) }) roc_stats <- mean(unlist(prob_stats)) } - + #Calculate confusion matrix-based statistics CM <- confusionMatrix(data[, "pred"], data[, "obs"]) - + #Aggregate and average class-wise stats #Todo: add weights # RES: support two classes here as well @@ -287,32 +286,32 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ class_stats <- colMeans(CM$byClass) names(class_stats) <- paste("Mean", names(class_stats)) } - + # Aggregate overall stats - overall_stats <- if(has_class_probs) + overall_stats <- if(has_class_probs) c(CM$overall, lloss, ROC = roc_stats) else CM$overall - if (length(levels(data[, "pred"])) > 2) + if (length(levels(data[, "pred"])) > 2) names(overall_stats)[names(overall_stats) == "ROC"] <- "Mean_ROC" - - - # Combine overall with class-wise stats and remove some stats we don't want + + + # Combine overall with class-wise stats and remove some stats we don't want stats <- c(overall_stats, class_stats) stats <- stats[! names(stats) %in% c('AccuracyNull', "AccuracyLower", "AccuracyUpper", - "AccuracyPValue", "McnemarPValue", + "AccuracyPValue", "McnemarPValue", 'Mean Prevalence', 'Mean Detection Prevalence')] - + # Clean names names(stats) <- gsub('[[:blank:]]+', '_', names(stats)) - + # Change name ordering to place most useful first # May want to remove some of these eventually - stat_list <- c("Accuracy", "Kappa", "Mean_Sensitivity", "Mean_Specificity", + stat_list <- c("Accuracy", "Kappa", "Mean_Sensitivity", "Mean_Specificity", "Mean_Pos_Pred_Value", "Mean_Neg_Pred_Value", "Mean_Detection_Rate", "Mean_Balanced_Accuracy") if(has_class_probs) stat_list <- c("logLoss", "Mean_ROC", stat_list) if (length(levels(data[, "pred"])) == 2) stat_list <- gsub("^Mean_", "", stat_list) - + stats <- stats[c(stat_list)] - + return(stats) } diff --git a/pkg/caret/tests/testthat/test_mnLogLoss.R b/pkg/caret/tests/testthat/test_mnLogLoss.R index def12606..ae19d1f6 100644 --- a/pkg/caret/tests/testthat/test_mnLogLoss.R +++ b/pkg/caret/tests/testthat/test_mnLogLoss.R @@ -1,7 +1,5 @@ context('mnLogLoss') -eps <- 1e-15 - classes <- LETTERS[1:3] test_dat1 <- data.frame(obs = c("A", "A", "A", "B", "B", "C"), @@ -10,24 +8,21 @@ test_dat1 <- data.frame(obs = c("A", "A", "A", "B", "B", "C"), B = c(0, .05, .29, .8, .6, .3), C = c(0, .15, .20, .1, .2, .4)) -expected1 <- log(1-eps) + log(.8) + log(.51) + log(.8) + log(.6) + log(.4) -expected1 <- c(logLoss = -expected1/nrow(test_dat1)) -result1 <- mnLogLoss(test_dat1, lev = classes) +test_that("Multiclass logloss returns expected values", { + result1 <- mnLogLoss(test_dat1, classes) -test_dat2 <- test_dat1 -test_dat2$A[1] <- NA + test_dat2 <- test_dat1 + test_dat2$A[1] <- NA + result2 <- mnLogLoss(test_dat2, classes) -expected2 <- log(.8) + log(.51) + log(.8) + log(.6) + log(.4) -expected2 <- c(logLoss = -expected2/sum(complete.cases(test_dat2))) -result2 <- mnLogLoss(test_dat2, lev = classes) + test_dat3 <- test_dat1 + test_dat3 <- test_dat3[, rev(1:5)] + result3 <- mnLogLoss(test_dat3, classes) -test_dat3 <- test_dat1 -test_dat3 <- test_dat3[, rev(1:5)] -expected3 <- expected1 -result3 <- mnLogLoss(test_dat3, lev = classes[c(2, 3, 1)]) + expect_equal(result1, 0.424458, tolerance = .000001) + expect_equal(result2, 0.5093496, tolerance = .000001) + expect_equal(result3, 0.424458, tolerance = .000001) -expect_equal(result1, expected1) -expect_equal(result2, expected2) -expect_equal(result3, expected3) +}) From 5171e69f7bbeb4e501dd6c78bd444fc3869178fb Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Tue, 16 Aug 2016 10:51:17 -0600 Subject: [PATCH 15/30] reworked stats to include auc from ModelMetrics and cleaned up --- pkg/caret/R/aaa.R | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index fb712e51..70140412 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -258,17 +258,15 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ if(has_class_probs) { ## Overall multinomial loss lloss <- mnLogLoss(data = data, lev = lev, model = model) - requireNamespaceQuietStop("pROC") + requireNamespaceQuietStop("ModelMetrics") #Calculate custom one-vs-all ROC curves for each class prob_stats <- lapply(levels(data[, "pred"]), - function(class){ + function(x){ #Grab one-vs-all data for the class - obs <- ifelse(data[, "obs"] == class, 1, 0) - prob <- data[,class] - rocObject <- try(pROC::roc(obs, data[,class], direction = "<"), silent = TRUE) - prob_stats <- if (class(rocObject)[1] == "try-error") NA else rocObject$auc - names(prob_stats) <- c('ROC') - return(prob_stats) + obs <- ifelse(data[, "obs"] == x, 1, 0) + prob <- data[,x] + AUCs <- try(ModelMetrics::auc(obs, data[,x]), silent = TRUE) + return(AUCs) }) roc_stats <- mean(unlist(prob_stats)) } @@ -289,9 +287,9 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ # Aggregate overall stats overall_stats <- if(has_class_probs) - c(CM$overall, lloss, ROC = roc_stats) else CM$overall + c(CM$overall, logLoss = lloss, ROC = roc_stats) else CM$overall if (length(levels(data[, "pred"])) > 2) - names(overall_stats)[names(overall_stats) == "ROC"] <- "Mean_ROC" + names(overall_stats)[names(overall_stats) == "ROC"] <- "Mean_AUC" # Combine overall with class-wise stats and remove some stats we don't want @@ -308,7 +306,7 @@ multiClassSummary <- function (data, lev = NULL, model = NULL){ stat_list <- c("Accuracy", "Kappa", "Mean_Sensitivity", "Mean_Specificity", "Mean_Pos_Pred_Value", "Mean_Neg_Pred_Value", "Mean_Detection_Rate", "Mean_Balanced_Accuracy") - if(has_class_probs) stat_list <- c("logLoss", "Mean_ROC", stat_list) + if(has_class_probs) stat_list <- c("logLoss", "Mean_AUC", stat_list) if (length(levels(data[, "pred"])) == 2) stat_list <- gsub("^Mean_", "", stat_list) stats <- stats[c(stat_list)] From bfb282ab13f89dc131b05303a34164884ee52d68 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Fri, 26 Aug 2016 08:43:22 -0600 Subject: [PATCH 16/30] dropping uneeded files --- pkg/caret/R/aucRoc.R | 15 --------------- pkg/caret/R/roc.R | 21 --------------------- pkg/caret/R/rocPoint.R | 19 ------------------- 3 files changed, 55 deletions(-) delete mode 100644 pkg/caret/R/aucRoc.R delete mode 100644 pkg/caret/R/roc.R delete mode 100644 pkg/caret/R/rocPoint.R diff --git a/pkg/caret/R/aucRoc.R b/pkg/caret/R/aucRoc.R deleted file mode 100644 index 2820e846..00000000 --- a/pkg/caret/R/aucRoc.R +++ /dev/null @@ -1,15 +0,0 @@ -aucRoc <- function(object) -{ - warning("This function is deprecated a of 1/3/12. The computations now utilize the pROC package. This function will be removed in a few releases.") - - sens <- object[, "sensitivity"] - omspec <- 1 - object[, "specificity"] - newOrder <- order(omspec) - sens <- sens[newOrder] - omspec <- omspec[newOrder] - - rocArea <- sum(.5 *diff(omspec) * (sens[-1] + sens[-length(sens)])) - rocArea <- max(rocArea, 1 - rocArea) - rocArea -} - diff --git a/pkg/caret/R/roc.R b/pkg/caret/R/roc.R deleted file mode 100644 index 104a67d1..00000000 --- a/pkg/caret/R/roc.R +++ /dev/null @@ -1,21 +0,0 @@ -roc <- function(data, class, dataGrid = TRUE, gridLength = 100, positive = levels(class)[1]) -{ - warning("This function is deprecated a of 1/3/12. The computations now utilize the pROC package. This function will be removed in a few releases.") - - if(!is.character(positive) | length(positive) != 1) stop("positive argument should be a single character value") - - if(!(positive %in% levels(class))) stop("wrong level specified") - if(length(levels(class)) != 2) stop("wrong number of levels") - if(dataGrid) cutoffDF <- data.frame(value = sort(unique(data))) - else cutoffDF <- data.frame(value = seq( - from = min(data, na.rm = TRUE), - to = max(data, na.rm = TRUE), - length = gridLength)) - numCuts <- dim(cutoffDF)[1] - out <- matrix(NA, ncol = 3, nrow = numCuts + 1) - - out[2:(numCuts + 1), ] <- t(apply(cutoffDF, 1, rocPoint, x = data, y = class, positive = positive)) - out[1, ] <- c(NA, 1, 0) - colnames(out) <- c("cutoff", "sensitivity", "specificity") - out -} diff --git a/pkg/caret/R/rocPoint.R b/pkg/caret/R/rocPoint.R deleted file mode 100644 index c4dca43f..00000000 --- a/pkg/caret/R/rocPoint.R +++ /dev/null @@ -1,19 +0,0 @@ -rocPoint <- function(cutoff, x, y, positive) -{ - warning("This function is deprecated a of 1/3/12. The computations now utilize the pROC package. This function will be removed in a few releases.") - classLevels <- levels(y) - negative <- classLevels[positive != classLevels] - newClass <- factor( - ifelse( - x <= cutoff, - negative, - positive), - levels = classLevels) - out <- c( - cutoff, - sensitivity(newClass, y, positive), - specificity(newClass, y, negative)) - names(out) <- c("cutoff", "sensitivity", "specificity") - out -} - From 5de5cbfb136f38cfd2f5081c6c42ec9d4052d22b Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Fri, 26 Aug 2016 08:43:43 -0600 Subject: [PATCH 17/30] optimizing filterVarImp --- pkg/caret/R/filterVarImp.R | 148 +++++++++------------------------------------ 1 file changed, 28 insertions(+), 120 deletions(-) diff --git a/pkg/caret/R/filterVarImp.R b/pkg/caret/R/filterVarImp.R index 8a617bff..801b8597 100644 --- a/pkg/caret/R/filterVarImp.R +++ b/pkg/caret/R/filterVarImp.R @@ -1,133 +1,41 @@ -## todo start using foreach here - -oldfilterVarImp <- function(x, y, nonpara = FALSE, ...) -{ - { - notNumber <- unlist(lapply(x, function(x) !is.numeric(x))) - if(any(notNumber)) - { - for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } - } - - if(is.factor(y)) - { - classLevels <- levels(y) - - outStat <- matrix(NA, nrow = dim(x)[2], ncol = length(classLevels)) - for(i in seq(along = classLevels)) - { - otherLevels <- classLevels[classLevels != classLevels[i]] - - for(k in seq(along = otherLevels)) - { - tmpSubset <- as.character(y) %in% c(classLevels[i], otherLevels[k]) - tmpY <- factor(as.character(y)[tmpSubset]) - tmpX <- x[tmpSubset,] - - rocAuc <- apply( - tmpX, - 2, - function(x, class, pos) - { - isMissing <- is.na(x) | is.na(class) - if(any(isMissing)) - { - x <- x[!isMissing] - class <- class[!isMissing] - } - outResults <- if(length(unique(x)) > 200) roc(x, class = class, positive = pos) - else roc(x, class = class, dataGrid = FALSE, positive = pos) - aucRoc(outResults) - }, - class = tmpY, - pos = classLevels[i]) - outStat[, i] <- pmax(outStat[, i], rocAuc, na.rm = TRUE) - } - if(i ==1 & length(classLevels) == 2) - { - outStat[, 2] <- outStat[, 1] - break() - } - } - colnames(outStat) <- classLevels - rownames(outStat) <- dimnames(x)[[2]] - outStat <- data.frame(outStat) - } else { - paraFoo <- function(data, y) abs(coef(summary(lm(y ~ data, na.action = na.omit)))[2, "t value"]) - nonparaFoo <- function(x, y, ...) - { - meanMod <- sum((y - mean(y, rm.na = TRUE))^2) - nzv <- nearZeroVar(x, saveMetrics = TRUE) +rocPerCol <- function(dat, cls){ + ModelMetrics::auc(cls, dat) +} - if(nzv$zeroVar) return(NA) - if(nzv$percentUnique < 20) - { - regMod <- lm(y~x, na.action = na.omit, ...) - } else { - regMod <- try(loess(y~x, na.action = na.omit, ...), silent = TRUE) +asNumeric <- function(data){ + fc <- sapply(data, is.factor) + modifyList(data, lapply(data[, fc], as.numeric)) +} - if(class(regMod) == "try-error" | any(is.nan(regMod$residuals))) try(regMod <- lm(y~x, ...)) - if(class(regMod) == "try-error") return(NA) - } +filterVarImp <- function(x, y, nonpara = FALSE, ...){ + # converting factors to numeric + notNumber <- sapply(x, function(x) !is.numeric(x)) + x = asNumeric(x) - pR2 <- 1 - (sum(resid(regMod)^2)/meanMod) - if(pR2 < 0) pR2 <- 0 - pR2 - } + if(is.factor(y)){ + classLevels <- levels(y) + k <- length(classLevels) - testFunc <- if(nonpara) nonparaFoo else paraFoo + if(k > 2){ - outStat <- apply(x, 2, testFunc, y = y) - outStat <- data.frame(Overall = outStat) - } - outStat -} + Combs <- combn(classLevels, 2) + CombsN <- combn(1:k, 2) + lStat <- lapply(1:ncol(Combs), FUN = function(cc){ + yLevs <- as.character(y) %in% Combs[,cc] + tmpX <- x[yLevs,] + tmpY <- as.numeric(y[yLevs] == Combs[,cc][2]) + apply(tmpX, 2, rocPerCol, cls = tmpY) + }) + Stat = do.call("cbind", lStat) -rocPerCol <- function(dat, cls) { - 1 - ModelMetrics::auc(cls, dat) -} + loutStat <- lapply(1:k, function(j){ + apply(Stat[,CombsN[,j]], 1, max) + }) -filterVarImp <- function(x, y, nonpara = FALSE, ...) -{ - { - notNumber <- sapply(x, function(x) !is.numeric(x)) - if(any(notNumber)) - { - for(i in which(notNumber)) x[,i] <- as.numeric(x[,i]) - } - } - - if(is.factor(y)) - { - classLevels <- levels(y) - k <- length(classLevels) + outStat = do.call("cbind", loutStat) - if(k > 2) - { - counter <- 1 - classIndex <- vector(mode = "list", length = k) - tmpStat <- matrix(NA, nrow = ncol(x), ncol = choose(k, 2)) - for(i in 1:k) - { - for(j in i:k) - { - if(i != j) - { - classIndex[[i]] <- c(classIndex[[i]], counter) - classIndex[[j]] <- c(classIndex[[j]], counter) - index <- which(y %in% c(classLevels[i], classLevels[j])) - tmpX <- x[index,,drop = FALSE] - tmpY <- factor(as.character(y[index]), levels = c(classLevels[i], classLevels[j])) - tmpStat[,counter] <- apply(tmpX, 2, rocPerCol, cls = tmpY) - counter <- counter + 1 - } - } - } - outStat <- matrix(NA, ncol(x), k) - for(i in 1:k) outStat[,i] <- apply(tmpStat[,classIndex[[i]]], 1, max) } else { tmp <- apply(x, 2, rocPerCol, cls = y) outStat <- cbind(tmp, tmp) From 96fdab07c5dac02ecaafd451e0196d5f27d70cbb Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Fri, 26 Aug 2016 08:43:52 -0600 Subject: [PATCH 18/30] use ModelMetrics --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d88e179c..f020a21e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -44,7 +44,7 @@ before_install: - ./travis-tool.sh r_binary_install nnet - ./travis-tool.sh r_binary_install party - ./travis-tool.sh r_binary_install pls - - ./travis-tool.sh r_binary_install pROC + - ./travis-tool.sh r_binary_install ModelMetrics - ./travis-tool.sh r_binary_install proxy - ./travis-tool.sh r_binary_install randomForest - ./travis-tool.sh r_binary_install RANN From 1a68166f95b4accc403450c68a95d5dc6edabdc0 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sat, 3 Sep 2016 10:47:30 -0600 Subject: [PATCH 19/30] fix for deprecated test --- pkg/caret/tests/testthat/test_glmnet_varImp.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/caret/tests/testthat/test_glmnet_varImp.R b/pkg/caret/tests/testthat/test_glmnet_varImp.R index 261ec5c5..18331e3e 100644 --- a/pkg/caret/tests/testthat/test_glmnet_varImp.R +++ b/pkg/caret/tests/testthat/test_glmnet_varImp.R @@ -16,7 +16,7 @@ test_that('glmnet varImp returns non-negative values', { # this checks that some coefficients are negative coefs <- predict(reg$finalModel, s=0.1, type="coef") - expect_less_than(0, sum(0 > coefs)) + expect_lt(0, sum(0 > coefs)) # now check that all elements of varImp are nonnegative, # in spite of negative coefficients vis <- varImp(reg, s=0.1, scale=F)$importance From 589232c4e7bbadd1276867b957c316b7a44bca68 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sat, 3 Sep 2016 10:47:40 -0600 Subject: [PATCH 20/30] replacing pROC dependency --- pkg/caret/DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/caret/DESCRIPTION b/pkg/caret/DESCRIPTION index 03b460f8..cc2db87e 100644 --- a/pkg/caret/DESCRIPTION +++ b/pkg/caret/DESCRIPTION @@ -44,7 +44,7 @@ Suggests: nnet, party (>= 0.9-99992), pls, - pROC (>= 1.8), + ModelMetrics (>= 1.1.0), proxy, randomForest, RANN, From f1dfd692fcd950c301e672a2cda0f09fe542e2c6 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sat, 3 Sep 2016 13:39:51 -0600 Subject: [PATCH 21/30] temporary github install --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index f020a21e..1cd7c443 100644 --- a/.travis.yml +++ b/.travis.yml @@ -44,7 +44,7 @@ before_install: - ./travis-tool.sh r_binary_install nnet - ./travis-tool.sh r_binary_install party - ./travis-tool.sh r_binary_install pls - - ./travis-tool.sh r_binary_install ModelMetrics + - ./travis-tool.sh install_github ModelMetrics - ./travis-tool.sh r_binary_install proxy - ./travis-tool.sh r_binary_install randomForest - ./travis-tool.sh r_binary_install RANN From 878270143dd60cf25adfa195d4a6e3c37078852a Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sat, 3 Sep 2016 14:12:57 -0600 Subject: [PATCH 22/30] ModelMetrics in twoClassSummary --- pkg/caret/R/aaa.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index 70140412..7b231b06 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -224,11 +224,10 @@ twoClassSummary <- function (data, lev = NULL, model = NULL) if(length(levels(data$obs)) > 2) stop(paste("Your outcome has", length(levels(data$obs)), "levels. The twoClassSummary() function isn't appropriate.")) - requireNamespaceQuietStop('pROC') + requireNamespaceQuietStop('ModelMetrics') if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) stop("levels of observed and predicted data do not match") - rocObject <- try(pROC::roc(data$obs, data[, lev[1]], direction = ">"), silent = TRUE) - rocAUC <- if(class(rocObject)[1] == "try-error") NA else rocObject$auc + rocAUC <- ModelMetrics::auc(data$obs, data$pred) out <- c(rocAUC, sensitivity(data[, "pred"], data[, "obs"], lev[1]), specificity(data[, "pred"], data[, "obs"], lev[2])) From f5cb9c85ca7c0fdcb3e54c43b7774c8a2f2735a5 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sun, 4 Sep 2016 08:48:17 -0600 Subject: [PATCH 23/30] added code and tests twoClassSummary --- pkg/caret/R/aaa.R | 10 +++++--- pkg/caret/tests/testthat/test_twoClassSummary.R | 33 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) create mode 100644 pkg/caret/tests/testthat/test_twoClassSummary.R diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index 7b231b06..5c7aa28f 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -221,13 +221,15 @@ defaultSummary <- function(data, lev = NULL, model = NULL) twoClassSummary <- function (data, lev = NULL, model = NULL) { - if(length(levels(data$obs)) > 2) - stop(paste("Your outcome has", length(levels(data$obs)), + lvls <- levels(data$obs) + if(length(lvls) > 2) + stop(paste("Your outcome has", length(lvls), "levels. The twoClassSummary() function isn't appropriate.")) requireNamespaceQuietStop('ModelMetrics') - if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) + if (!all(levels(data[, "pred"]) == lvls)) stop("levels of observed and predicted data do not match") - rocAUC <- ModelMetrics::auc(data$obs, data$pred) + data$y = as.numeric(data$obs == lvls[2]) + rocAUC <- ModelMetrics::auc(data$y, data$pred) out <- c(rocAUC, sensitivity(data[, "pred"], data[, "obs"], lev[1]), specificity(data[, "pred"], data[, "obs"], lev[2])) diff --git a/pkg/caret/tests/testthat/test_twoClassSummary.R b/pkg/caret/tests/testthat/test_twoClassSummary.R new file mode 100644 index 00000000..074df5a7 --- /dev/null +++ b/pkg/caret/tests/testthat/test_twoClassSummary.R @@ -0,0 +1,33 @@ + + +context('twoClassSummary') + + +test_that("twoClassSummary is calculating correctly", { + +library(caret) + +set.seed(1) +tr_dat <- twoClassSim(500) +te_dat <- tr_dat +tr_dat$Class = factor(tr_dat$Class, levels = rev(levels(te_dat$Class))) + +set.seed(35) +mod1 <- train(Class ~ ., data = tr_dat, + method = "lda", + metric = "ROC", + trControl = trainControl(classProbs = TRUE, + summaryFunction = twoClassSummary)) + +set.seed(35) +mod2 <- train(Class ~ ., data = te_dat, + method = "lda", + metric = "ROC", + trControl = trainControl(classProbs = TRUE, + summaryFunction = twoClassSummary)) + +expect_equal(mod1$resample$ROC, mod2$resample$ROC) +expect_equal(mod1$resample$Sens, mod2$resample$Spec) +expect_equal(mod1$resample$Spec, mod2$resample$Sens) + +}) From 06f1a5a4096de93336f9f25f690930e86f0947af Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sun, 4 Sep 2016 08:49:35 -0600 Subject: [PATCH 24/30] added myself Description as contributor --- pkg/caret/DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/caret/DESCRIPTION b/pkg/caret/DESCRIPTION index cc2db87e..3182c6d5 100644 --- a/pkg/caret/DESCRIPTION +++ b/pkg/caret/DESCRIPTION @@ -5,7 +5,7 @@ Title: Classification and Regression Training Author: Max Kuhn. Contributions from Jed Wing, Steve Weston, Andre Williams, Chris Keefer, Allan Engelhardt, Tony Cooper, Zachary Mayer, Brenton Kenkel, the R Core Team, Michael Benesty, Reynald Lescarbeau, - Andrew Ziem, Luca Scrucca, Yuan Tang, and Can Candan. + Andrew Ziem, Luca Scrucca, Yuan Tang, Can Candan, and Tyler Hunt. Description: Misc functions for training and plotting classification and regression models. Maintainer: Max Kuhn From 38df201aeb70687c266c16aab5c2f34216099280 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sun, 4 Sep 2016 13:59:16 -0600 Subject: [PATCH 25/30] trying cran binary again --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1cd7c443..f020a21e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -44,7 +44,7 @@ before_install: - ./travis-tool.sh r_binary_install nnet - ./travis-tool.sh r_binary_install party - ./travis-tool.sh r_binary_install pls - - ./travis-tool.sh install_github ModelMetrics + - ./travis-tool.sh r_binary_install ModelMetrics - ./travis-tool.sh r_binary_install proxy - ./travis-tool.sh r_binary_install randomForest - ./travis-tool.sh r_binary_install RANN From 1b2a7ebb790daa2d4e0e8b36c9beefa510c0b70d Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sun, 4 Sep 2016 18:45:27 -0600 Subject: [PATCH 26/30] test non-binary install --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index f020a21e..6d6de5b2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -44,7 +44,7 @@ before_install: - ./travis-tool.sh r_binary_install nnet - ./travis-tool.sh r_binary_install party - ./travis-tool.sh r_binary_install pls - - ./travis-tool.sh r_binary_install ModelMetrics + - ./travis-tool.sh r_install ModelMetrics - ./travis-tool.sh r_binary_install proxy - ./travis-tool.sh r_binary_install randomForest - ./travis-tool.sh r_binary_install RANN From e08db3dcd1dc7bfc4477ebcf6161fa99f17f6e5b Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sun, 4 Sep 2016 19:23:21 -0600 Subject: [PATCH 27/30] dropping direct calls to package relying on namespace --- pkg/caret/R/aaa.R | 2 +- pkg/caret/R/filterVarImp.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index 5c7aa28f..7fd166ba 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -229,7 +229,7 @@ twoClassSummary <- function (data, lev = NULL, model = NULL) if (!all(levels(data[, "pred"]) == lvls)) stop("levels of observed and predicted data do not match") data$y = as.numeric(data$obs == lvls[2]) - rocAUC <- ModelMetrics::auc(data$y, data$pred) + rocAUC <- auc(data$y, data$pred) out <- c(rocAUC, sensitivity(data[, "pred"], data[, "obs"], lev[1]), specificity(data[, "pred"], data[, "obs"], lev[2])) diff --git a/pkg/caret/R/filterVarImp.R b/pkg/caret/R/filterVarImp.R index 801b8597..0bb82c08 100644 --- a/pkg/caret/R/filterVarImp.R +++ b/pkg/caret/R/filterVarImp.R @@ -1,6 +1,6 @@ rocPerCol <- function(dat, cls){ - ModelMetrics::auc(cls, dat) + auc(cls, dat) } asNumeric <- function(data){ From 3466da2a1a18632876783638cf7065db19e30c03 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Sun, 4 Sep 2016 20:25:27 -0600 Subject: [PATCH 28/30] fixing issues from check --- pkg/caret/DESCRIPTION | 2 +- pkg/caret/NAMESPACE | 56 +++++++++++++++++++++-------------------- pkg/caret/man/caret-internal.Rd | 1 - 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/pkg/caret/DESCRIPTION b/pkg/caret/DESCRIPTION index 3182c6d5..d788452f 100644 --- a/pkg/caret/DESCRIPTION +++ b/pkg/caret/DESCRIPTION @@ -20,6 +20,7 @@ Imports: foreach, methods, plyr, + ModelMetrics (>= 1.1.0), nlme, reshape2, stats, @@ -44,7 +45,6 @@ Suggests: nnet, party (>= 0.9-99992), pls, - ModelMetrics (>= 1.1.0), proxy, randomForest, RANN, diff --git a/pkg/caret/NAMESPACE b/pkg/caret/NAMESPACE index 606068fc..d44dd11f 100644 --- a/pkg/caret/NAMESPACE +++ b/pkg/caret/NAMESPACE @@ -3,23 +3,25 @@ importFrom(ModelMetrics, auc) import(foreach, methods, plyr, reshape2, ggplot2, lattice, nlme) importFrom(car, powerTransform, yjPower) importFrom(grDevices, extendrange) -importFrom(stats, .checkMFClasses, .getXlevels, aggregate, anova, - approx, as.formula, binom.test, complete.cases, contrasts, - cor, cov, delete.response, dist, - fitted.values, loess, mahalanobis, - mcnemar.test, median, model.frame, model.matrix, - model.response, model.weights, na.fail, na.pass, optim, - predict, qnorm, quantile, rbinom, reshape, resid, - residuals, rnorm, runif, sd, t.test, terms, +importFrom(stats, .checkMFClasses, .getXlevels, aggregate, anova, + approx, as.formula, binom.test, complete.cases, contrasts, + cor, cov, delete.response, dist, + fitted.values, loess, mahalanobis, + mcnemar.test, median, model.frame, model.matrix, + model.response, model.weights, na.fail, na.pass, optim, + predict, qnorm, quantile, rbinom, reshape, resid, + residuals, rnorm, runif, sd, t.test, terms, toeplitz, var, na.omit, p.adjust, fitted, prcomp, hclust, - lm, model.extract, pt, update, binomial) -importFrom(stats4, coef) -importFrom(utils, capture.output, getFromNamespace, head, - install.packages, installed.packages, object.size, flush.console, menu, stack) + lm, model.extract, pt, update, binomial) +importFrom(stats4, coef) +importFrom(utils, capture.output, getFromNamespace, head, + install.packages, installed.packages, object.size, flush.console, menu, stack, + modifyList, combn + ) export(anovaScores, as.data.frame.resamples, - as.matrix.resamples, + as.matrix.resamples, avNNet, avNNet.default, bag, @@ -74,7 +76,7 @@ export(anovaScores, extractProb, F_meas, F_meas.default, - F_meas.table, + F_meas.table, featurePlot, filterVarImp, findCorrelation, @@ -88,7 +90,7 @@ export(anovaScores, gafs_spCrossover, gafs_raMutation, gafs, - gafs.default, + gafs.default, gafsControl, gamFormula, gamFuncs, @@ -144,7 +146,7 @@ export(anovaScores, nullModel, nullModel.default, oneSE, - panel.calibration, + panel.calibration, panel.lift, panel.lift2, panel.needle, @@ -197,7 +199,7 @@ export(anovaScores, R2, recall, recall.default, - recall.table, + recall.table, resampleHist, resamples, resamples.default, @@ -208,16 +210,16 @@ export(anovaScores, rfeControl, rfeIter, rfFuncs, - rfGA, + rfGA, rfSA, rfSBF, rfStats, RMSE, safs_initial, - safs_perturb, + safs_perturb, safs_prob, safs, - safs.default, + safs.default, safsControl, sbf, sbf.default, @@ -340,7 +342,7 @@ S3method(varImp, nnet) S3method(varImp, glmnet) S3method(varImp, gam) S3method(varImp, gafs) -S3method(varImp, safs) +S3method(varImp, safs) S3method(densityplot, train) S3method(histogram, train) @@ -378,7 +380,7 @@ S3method(plot, prcomp.resamples) S3method(plot, lift) S3method(plot, calibration) S3method(plot, gafs) -S3method(plot, safs) +S3method(plot, safs) S3method(confusionMatrix, train) S3method(confusionMatrix, rfe) @@ -419,7 +421,7 @@ S3method(print, lift) S3method(print, calibration) S3method(print, expoTrans) S3method(print, gafs) -S3method(print, safs) +S3method(print, safs) S3method(predict, plsda) S3method(predict, splsda) @@ -442,7 +444,7 @@ S3method(predict, dummyVars) S3method(predict, BoxCoxTrans) S3method(predict, expoTrans) S3method(predict, gafs) -S3method(predict, safs) +S3method(predict, safs) S3method(summary, bagEarth) S3method(summary, bagFDA) @@ -460,7 +462,7 @@ S3method(predictors, default) S3method(predictors, rfe) S3method(predictors, sbf) S3method(predictors, gafs) -S3method(predictors, safs) +S3method(predictors, safs) S3method(confusionMatrix, table) @@ -493,7 +495,7 @@ S3method(summary, diff.resamples) S3method(update, train) S3method(update, rfe) S3method(update, gafs) -S3method(update, safs) +S3method(update, safs) S3method(fitted, train) S3method(residuals, train) @@ -511,7 +513,7 @@ S3method(oob_pred, sbf) S3method(oob_pred, list) S3method(gafs, default) -S3method(safs, default) +S3method(safs, default) S3method(trim, train) diff --git a/pkg/caret/man/caret-internal.Rd b/pkg/caret/man/caret-internal.Rd index 0d3406dc..3b721c82 100644 --- a/pkg/caret/man/caret-internal.Rd +++ b/pkg/caret/man/caret-internal.Rd @@ -33,7 +33,6 @@ MeanSD(x, exclude = NULL) sortImp(object, top) resampleWrapper(x, ind) caretTheme() -rocPoint(cutoff, x, y, positive) ipredStats(x) rfStats(x) bagEarthStats(x) From 707f25194c52365b9f784fcd608701cb532fa1b8 Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Wed, 7 Sep 2016 10:55:21 -0600 Subject: [PATCH 29/30] remove mcc (numeric overflow problems) --- pkg/caret/R/confusionMatrix.R | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/pkg/caret/R/confusionMatrix.R b/pkg/caret/R/confusionMatrix.R index a5ced4cc..308f6fb4 100644 --- a/pkg/caret/R/confusionMatrix.R +++ b/pkg/caret/R/confusionMatrix.R @@ -352,19 +352,3 @@ resampName <- function(x, numbers = TRUE){ out } - -mcc <- function(tab, pos = colnames(tab)[1]){ - if(nrow(tab) != 2 | ncol(tab) != 2) stop("A 2x2 table is needed") - neg <- colnames(tab)[colnames(tab) != pos] - tp <- tab[pos, pos] - tn <- tab[neg, neg] - fp <- tab[pos,neg] - fn <- tab[neg, pos] - d1 <- tp + fp - d2 <- tp + fn - d3 <- tn + fp - d4 <- tn + fn - if(d1 == 0 | d2 == 0 | d3 == 0 | d4 == 0) return(0) - ((tp * tn) - (fp * fn))/sqrt(d1*d2*d3*d4) -} - From 33d8b0d239633ad81bb6d860dbe51b8b13dcaace Mon Sep 17 00:00:00 2001 From: Tyler Hunt Date: Wed, 7 Sep 2016 14:55:48 -0600 Subject: [PATCH 30/30] bugfix and added test code for coverage --- pkg/caret/R/aaa.R | 2 +- pkg/caret/tests/testthat/test_twoClassSummary.R | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pkg/caret/R/aaa.R b/pkg/caret/R/aaa.R index 7fd166ba..97375fd0 100644 --- a/pkg/caret/R/aaa.R +++ b/pkg/caret/R/aaa.R @@ -229,7 +229,7 @@ twoClassSummary <- function (data, lev = NULL, model = NULL) if (!all(levels(data[, "pred"]) == lvls)) stop("levels of observed and predicted data do not match") data$y = as.numeric(data$obs == lvls[2]) - rocAUC <- auc(data$y, data$pred) + rocAUC <- ModelMetrics:::auc(ifelse(data$obs == lev[2], 0, 1), data[, lvls[1]]) out <- c(rocAUC, sensitivity(data[, "pred"], data[, "obs"], lev[1]), specificity(data[, "pred"], data[, "obs"], lev[2])) diff --git a/pkg/caret/tests/testthat/test_twoClassSummary.R b/pkg/caret/tests/testthat/test_twoClassSummary.R index 074df5a7..579bd656 100644 --- a/pkg/caret/tests/testthat/test_twoClassSummary.R +++ b/pkg/caret/tests/testthat/test_twoClassSummary.R @@ -14,14 +14,16 @@ tr_dat$Class = factor(tr_dat$Class, levels = rev(levels(te_dat$Class))) set.seed(35) mod1 <- train(Class ~ ., data = tr_dat, - method = "lda", + method = "fda", + tuneLength = 10, metric = "ROC", trControl = trainControl(classProbs = TRUE, summaryFunction = twoClassSummary)) set.seed(35) mod2 <- train(Class ~ ., data = te_dat, - method = "lda", + method = "fda", + tuneLength = 10, metric = "ROC", trControl = trainControl(classProbs = TRUE, summaryFunction = twoClassSummary))