Skip to content

Commit 3f9ded7

Browse files
committed
Make sure all steps work with zero- or one-row data in bake()
1 parent 9d17f9d commit 3f9ded7

19 files changed

Lines changed: 131 additions & 0 deletions

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# themis (development version)
22

3+
* All `step_*()` functions now correctly handle 0 and 1 row inputs in `bake()` (#160).
4+
35
* `adasyn()`, `bsmote()`, `nearmiss()`, `smote()`, and `tomek()` now correctly attribute errors from non-numeric columns to the user-facing function (#181).
46

57
* `smotenc()` now only suppresses the specific benign warning from `gower::gower_topn()` about variables with zero range, rather than all warnings (#182).

R/adasyn.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ bake.step_adasyn <- function(object, new_data, ...) {
206206
return(new_data)
207207
}
208208

209+
if (nrow(new_data) <= 1) {
210+
return(new_data)
211+
}
212+
209213
new_data <- as.data.frame(new_data)
210214

211215
predictor_data <- new_data[, col_names]

R/bsmote.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,10 @@ bake.step_bsmote <- function(object, new_data, ...) {
242242
return(new_data)
243243
}
244244

245+
if (nrow(new_data) <= 1) {
246+
return(new_data)
247+
}
248+
245249
new_data <- as.data.frame(new_data)
246250

247251
predictor_data <- new_data[, col_names]

R/downsample.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ prep.step_downsample <- function(x, training, info = NULL, ...) {
228228

229229
subsamp <- function(x, wts, num) {
230230
n <- nrow(x)
231+
if (n == 0) {
232+
return(x)
233+
}
231234
if (nrow(x) == num) {
232235
out <- x
233236
} else {
@@ -247,6 +250,10 @@ bake.step_downsample <- function(object, new_data, ...) {
247250
return(new_data)
248251
}
249252

253+
if (nrow(new_data) <= 1) {
254+
return(new_data)
255+
}
256+
250257
if (isTRUE(object$case_weights)) {
251258
wts_col <- purrr::map_lgl(new_data, hardhat::is_case_weights)
252259
wts <- new_data[[names(which(wts_col))]]

R/nearmiss.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ bake.step_nearmiss <- function(object, new_data, ...) {
212212
return(new_data)
213213
}
214214

215+
if (nrow(new_data) <= 1) {
216+
return(new_data)
217+
}
218+
215219
ignore_vars <- setdiff(names(new_data), col_names)
216220

217221
# nearmiss with seed for reproducibility

R/rose.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ bake.step_rose <- function(object, new_data, ...) {
230230
return(new_data)
231231
}
232232

233+
if (nrow(new_data) <= 1) {
234+
return(new_data)
235+
}
236+
233237
if (any(is.na(new_data[[object$column]]))) {
234238
missing <- new_data[is.na(new_data[[object$column]]), ]
235239
} else {

R/smote.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ bake.step_smote <- function(object, new_data, ...) {
212212
return(new_data)
213213
}
214214

215+
if (nrow(new_data) <= 1) {
216+
return(new_data)
217+
}
218+
215219
new_data <- as.data.frame(new_data)
216220

217221
predictor_data <- new_data[, col_names]

R/smotenc.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ bake.step_smotenc <- function(object, new_data, ...) {
195195
return(new_data)
196196
}
197197

198+
if (nrow(new_data) <= 1) {
199+
return(new_data)
200+
}
201+
198202
new_data <- as.data.frame(new_data)
199203

200204
predictor_data <- new_data[, col_names]

R/tomek.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ bake.step_tomek <- function(object, new_data, ...) {
177177
return(new_data)
178178
}
179179

180+
if (nrow(new_data) <= 1) {
181+
return(new_data)
182+
}
183+
180184
predictor_data <- new_data[, col_names]
181185

182186
# tomek with seed for reproducibility

R/upsample.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ prep.step_upsample <- function(x, training, info = NULL, ...) {
225225

226226
supsamp <- function(x, wts, num) {
227227
n <- nrow(x)
228+
if (n == 0) {
229+
return(x)
230+
}
228231
if (nrow(x) == num) {
229232
out <- x
230233
} else {
@@ -244,6 +247,10 @@ bake.step_upsample <- function(object, new_data, ...) {
244247
return(new_data)
245248
}
246249

250+
if (nrow(new_data) <= 1) {
251+
return(new_data)
252+
}
253+
247254
if (isTRUE(object$case_weights)) {
248255
wts_col <- purrr::map_lgl(new_data, hardhat::is_case_weights)
249256
wts <- new_data[[names(which(wts_col))]]

0 commit comments

Comments
 (0)