1
## FIXME: data in party
2
## - currently assumed to be a data.frame
4
## - the following are all assumed to work:
5
## dim(data), names(data)
6
## sapply(data, class), lapply(data, levels)
7
## - potentially these need to be modified if data/terms
8
## should be able to deal with data bases
10
## Constructor for objects of class "party".
##
## Arguments:
##   node   - object inheriting from "partynode"
##   data   - data.frame containing all split variables (may have 0 rows)
##   fitted - optional data.frame whose first column is named "(fitted)"
##            and holds terminal node ids; must have as many rows as
##            'data' unless 'data' has no rows
##   terms  - optional object inheriting from "terms"
##   names  - optional vector with exactly one name per node
##   info   - optional, stored unchanged
## Value: a list with elements node, data, fitted, terms, names, info and
##   class "party". The node ids are renumbered via as.partynode(from = 1L)
##   and the "(fitted)" ids are translated to the new numbering.
party <- function(node, data, fitted = NULL, terms = NULL, names = NULL, info = NULL) {

    stopifnot(inherits(node, "partynode"))
    stopifnot(inherits(data, "data.frame"))

    ### make sure all split variables are there
    ids <- nodeids(node)[!nodeids(node) %in% nodeids(node, terminal = TRUE)]
    varids <- unique(unlist(nodeapply(node, ids = ids, FUN = function(x)
        varid_split(split_node(x)))))
    ## seq_len() rather than 1:ncol(data): the latter is c(1, 0) for a
    ## data.frame without columns and would wrongly accept variable id 1
    stopifnot(varids %in% seq_len(ncol(data)))

    if (!is.null(fitted)) {
        stopifnot(inherits(fitted, "data.frame"))
        stopifnot("(fitted)" == names(fitted)[1])
        stopifnot(nrow(data) == 0 || nrow(data) == nrow(fitted))

        nt <- nodeids(node, terminal = TRUE)
        stopifnot(all(fitted[["(fitted)"]] %in% nt))

        ## renumber the nodes and map the fitted ids onto the new numbers
        node <- as.partynode(node, from = 1L)
        nt2 <- nodeids(node, terminal = TRUE)
        fitted[["(fitted)"]] <- nt2[match(fitted[["(fitted)"]], nt)]
    } else {
        node <- as.partynode(node, from = 1L)
    }

    party <- list(node = node, data = data, fitted = fitted,
                  terms = NULL, names = NULL, info = info)
    class(party) <- "party"

    ## terms and names are optional: validate only when supplied
    if (!is.null(terms)) {
        stopifnot(inherits(terms, "terms"))
        party$terms <- terms
    }

    if (!is.null(names)) {
        n <- length(nodeids(party, terminal = FALSE))
        if (length(names) != n)
            stop("invalid", " ", sQuote("names"), " ", "argument")
        party$names <- names
    }

    party
}
54
## length() method for "party" objects.
## NOTE(review): only the function header is present in this copy of the
## file. The body is missing, so the next complete expression in the file
## is parsed as the body -- restore the intended body before use.
length.party <- function(x)
57
## names() method for "party" objects.
## NOTE(review): only the function header is present in this copy of the
## file. The body is missing, so the next complete expression in the file
## is parsed as the body -- restore the intended body before use.
names.party <- function(x)
60
## Replacement method for the node names of a "party" object.
## 'value' must be NULL or provide exactly one name per node (inner and
## terminal); any other length raises an error. Returns the modified object.
"names<-.party" <- function(x, value) {
    n <- length(nodeids(x, terminal = FALSE))
    if (!is.null(value) && length(value) != n)
        stop("invalid", " ", sQuote("names"), " ", "argument")
    ## names are kept in the 'names' element created by party()
    x$names <- value
    x
}
68
## Internal helper working on the node names of a "party" object.
## NOTE(review): the body looks truncated in this copy of the file (no
## return value and no closing brace are visible) -- restore before use.
.names_party <- function(party) {
71
## character version of all node ids (inner and terminal nodes)
names <- as.character(nodeids(party, terminal = FALSE))
75
## Accessor for the 'node' element of a "party" object (the "partynode"
## stored by party()). Fails if 'party' does not inherit from "party".
node_party <- function(party) {
    stopifnot(inherits(party, "party"))
    party$node
}
80
## Does a "party" object carry the information needed for a "constparty"?
## Returns TRUE when party$fitted exists and has both a "(fitted)" and a
## "(response)" column, FALSE otherwise. Fails for non-"party" input.
is.constparty <- function(party) {
    stopifnot(inherits(party, "party"))
    if (is.null(party$fitted))
        return(FALSE)
    all(c("(fitted)", "(response)") %in% names(party$fitted))
}
87
## Coerce a "party" object to class "constparty".
## Succeeds only if is.constparty(obj) is TRUE; otherwise an error is
## raised. '...' is accepted but not used.
as.constparty <- function(obj, ...) {
    if (is.constparty(obj)) {
        ret <- obj
        ## unique() avoids a duplicated "constparty" entry when the object
        ## has been coerced before
        class(ret) <- unique(c("constparty", class(obj)))
        return(ret)
    }
    ## collapse so that objects with several classes give a readable message
    stop("cannot coerce object of class", " ",
         paste(sQuote(class(obj)), collapse = ", "),
         " ", "to", " ", sQuote("constparty"))
}
97
## Subsetting methods "[" and "[[" for "party" objects (both names are bound
## to the same function). 'i' selects a single node, either by number or,
## if the object has names, by name; a new party built from that node, the
## corresponding data and fitted columns is created via party().
## NOTE(review): this copy of the function looks truncated -- there is no
## visible branch defining 'fit' when x$fitted is NULL, no closing braces
## and no return value. Restore the missing lines before relying on it.
"[.party" <- "[[.party" <- function(x, i, ...) {
98
## translate a node name into its position
if (is.character(i) && !is.null(names(x)))
99
i <- which(names(x) %in% i)
100
## exactly one numeric index within 1..length(x) is required
stopifnot(length(i) == 1 & is.numeric(i))
101
stopifnot(i <= length(x) & i >= 1)
103
dat <- data_party(x, i)
104
if (!is.null(x$fitted)) {
105
## everything from the "(fitted)" column onwards is fitted information,
## the columns before it are the data
findx <- which("(fitted)" == names(dat))[1]
106
fit <- dat[,findx:ncol(dat), drop = FALSE]
107
dat <- dat[,-(findx:ncol(dat)), drop = FALSE]
114
## names of node i and of all nodes below it
nam <- names(x)[nodeids(x, from = i, terminal = FALSE)]
116
## descend from the root until the node with id i is reached; at each
## level follow the last kid whose id does not exceed i
recFun <- function(node) {
117
if (id_node(node) == i) return(node)
118
kid <- sapply(kids_node(node), id_node)
119
return(recFun(node[[max(which(kid <= i))]]))
121
node <- recFun(node_party(x))
123
ret <- party(node = node, data = dat, fitted = fit,
124
terms = x$terms, names = nam, info = x$info)
125
## keep the (possibly extended) class vector of the original object
class(ret) <- class(x)
129
## S3 generic returning node ids; methods for "partynode" and "party"
## objects are defined in this file.
nodeids <- function(obj, ...)
    UseMethod("nodeids")
132
## Node ids of a "partynode" tree.
##
## Arguments:
##   obj      - the node at which the traversal starts
##   from     - id of the node whose subtree is reported; defaults to the
##              id of 'obj' itself
##   terminal - if TRUE only ids of terminal nodes are returned
## Value: vector of ids of the node 'from' and the nodes below it, in the
##   order in which the tree is traversed (node first, then its kids).
nodeids.partynode <- function(obj, from = NULL, terminal = FALSE, ...) {

    if (is.null(from)) from <- id_node(obj)

    ## id of a single node, or NULL if the node is not to be reported
    id <- function(node, record = TRUE, terminal = FALSE) {
        if (!record) return(NULL)
        if (!terminal || is.terminal(node)) id_node(node) else NULL
    }

    ## recursive collection; 'record' is switched on once 'from' is reached
    ## and stays on for everything below it
    rid <- function(node, record = TRUE, terminal = FALSE) {
        myid <- id(node, record = record, terminal = terminal)
        if (is.terminal(node)) return(myid)
        kids <- kids_node(node)
        kids_record <- if (record) {
            rep(TRUE, length(kids))
        } else {
            vapply(kids, function(k) id_node(k) == from, logical(1))
        }
        c(myid, unlist(lapply(seq_along(kids), function(i)
            rid(kids[[i]], record = kids_record[i], terminal = terminal))))
    }

    rid(obj, from == id_node(obj), terminal)
}
161
## nodeids() method for "party" objects: applies nodeids() to the node
## extracted with node_party(), passing 'from', 'terminal' and '...' on.
nodeids.party <- function(obj, from = NULL, terminal = FALSE, ...) {
    nodeids(node_party(obj), from = from, terminal = terminal, ...)
}
164
## S3 generic: apply FUN to the nodes with the given ids. Methods for
## "party" and "partynode" objects are defined in this file.
nodeapply <- function(obj, ids = 1, FUN = NULL, ...) {
    UseMethod("nodeapply")
}
167
## nodeapply() method for "party" objects.
##
## Arguments:
##   obj     - a "party" object
##   ids     - whole-number node ids (coerced to integer)
##   FUN     - function applied to each selected node; identity if NULL
##   by_node - if TRUE, FUN receives the partynode of each id; if FALSE,
##             FUN receives the sub-party obj[[id]]
##   ...     - passed on to FUN
## Value: list of results, named by the node names of the selected ids.
nodeapply.party <- function(obj, ids = 1, FUN = NULL, by_node = TRUE, ...) {

    stopifnot(isTRUE(all.equal(ids, round(ids))))
    ids <- as.integer(ids)

    if (is.null(FUN)) FUN <- function(x, ...) x

    ## NOTE(review): the value returned for empty 'ids' was not visible in
    ## the reviewed copy; NULL is assumed -- confirm
    if (length(ids) == 0)
        return(NULL)

    rval <- if (by_node) {
        nodeapply(node_party(obj), ids = ids, FUN = FUN, ...)
    } else {
        lapply(ids, function(i) FUN(obj[[i]], ...))
    }

    names(rval) <- names(obj)[ids]
    rval
}
187
## nodeapply() method for "partynode" objects.
##
## Arguments:
##   obj - root of the (sub)tree that is searched
##   ids - whole-number node ids (coerced to integer)
##   FUN - function applied to each selected node; identity if NULL
##   ... - passed on to FUN for every node
## Value: list with one element per entry of 'ids', in the order of 'ids'.
nodeapply.partynode <- function(obj, ids = 1, FUN = NULL, ...) {

    stopifnot(isTRUE(all.equal(ids, round(ids))))
    ids <- as.integer(ids)

    if (is.null(FUN)) FUN <- function(x, ...) x

    ## NOTE(review): the value returned for empty 'ids' was not visible in
    ## the reviewed copy; NULL is assumed -- confirm
    if (length(ids) == 0)
        return(NULL)

    rval <- vector(mode = "list", length = length(ids))
    rval_id <- rep(0, length(ids))
    i <- 1

    ## walk the tree and store results in the order the nodes are visited
    recFUN <- function(node, ...) {
        if (id_node(node) %in% ids) {
            rval_id[i] <<- id_node(node)
            ## assign via list() so that a NULL result keeps its slot
            ## instead of deleting the list element
            rval[i] <<- list(FUN(node, ...))
            i <<- i + 1
        }
        ## hand '...' down so that FUN gets it for non-root nodes, too
        for (kid in kids_node(node)) recFUN(kid, ...)
        invisible(NULL)
    }
    recFUN(obj, ...)

    ## reorder from visiting order to the order of 'ids':
    ## match(ids, rval_id) gives, for each requested id, the slot holding
    ## its result (match(rval_id, ids) would be the inverse permutation)
    rval[match(ids, rval_id)]
}
218
## predict() method for "party" objects.
##
## Without 'newdata' the stored fitted node ids (object$fitted[["(fitted)"]])
## are used. With 'newdata' the node ids are computed by fitted_node():
## directly on 'newdata' if it contains all split/surrogate variables with
## the same classes as the learning data, otherwise on a model frame built
## from object$terms. The ids are then handed to predict_party() together
## with 'newdata' and '...'.
predict.party <- function(object, newdata = NULL, ...)
{
    ### compute fitted node ids first
    fitted <- if (is.null(newdata)) object$fitted[["(fitted)"]] else {

        terminal <- nodeids(object, terminal = TRUE)
        inner <- seq_len(max(terminal))
        inner <- inner[-terminal]

        primary_vars <- nodeapply(object, ids = inner, by_node = TRUE, FUN = function(node) {
            varid_split(split_node(node))
        })
        surrogate_vars <- nodeapply(object, ids = inner, by_node = TRUE, FUN = function(node) {
            surr <- surrogates_node(node)
            if (is.null(surr)) NULL else sapply(surr, varid_split)
        })
        vnames <- names(object$data)

        ## ## FIXME: the is.na() call takes loooong on large data sets
        ## unames <- if(any(sapply(newdata, is.na)))
        ##     vnames[unique(unlist(c(primary_vars, surrogate_vars)))]
        ## else
        ##     vnames[unique(unlist(primary_vars))]
        unames <- vnames[unique(unlist(c(primary_vars, surrogate_vars)))]

        vclass <- structure(lapply(object$data, class), .Names = vnames)
        ndnames <- names(newdata)
        ndclass <- structure(lapply(newdata, class), .Names = ndnames)
        ## identical() compares whole class vectors; "==" would recycle
        ## class vectors of different length
        same_class <- function(v) identical(vclass[[v]], ndclass[[v]])
        if (all(unames %in% ndnames) &&
            all(vapply(unames, same_class, logical(1)))) {
            vmatch <- match(vnames, ndnames)
            fitted_node(node_party(object), newdata, vmatch)
        } else if (!is.null(object$terms)) {
            mf <- model.frame(delete.response(object$terms), newdata)
            fitted_node(node_party(object), mf, match(vnames, names(mf)))
        } else {
            stop("variables in ", sQuote("newdata"),
                 " do not match the learning data and no ",
                 sQuote("terms"), " are available to set up a model frame")
        }
    }
    ### compute predictions
    predict_party(object, fitted, newdata, ...)
}
262
## S3 generic computing predictions for the node ids 'id' of 'party';
## dispatches on the class of 'party'.
predict_party <- function(party, id, newdata = NULL, ...) {
    UseMethod("predict_party")
}
265
### do nothing except returning the fitted ids
266
## Default predict_party() method: returns the node ids 'id' unchanged,
## named by node names (no 'newdata') or by the row names of 'newdata'.
## Names are dropped when their number does not match length(id).
predict_party.default <- function(party, id, newdata = NULL, ...) {

    ## NOTE(review): the threshold "> 1" (rather than "> 0") is kept from
    ## the original; presumably a single extra argument such as 'type' is
    ## tolerated on purpose -- confirm
    dots <- list(...)
    if (length(dots) > 1)
        warning("argument(s)", " ",
                paste(sQuote(names(dots)), collapse = ", "),
                " ", "have been ignored")

    ## get observation names: either node names or
    ## observation names from newdata
    nam <- if (is.null(newdata)) names(party)[id] else rownames(newdata)
    if (length(nam) != length(id)) nam <- NULL

    ## special case: fitted ids
    structure(id, .Names = nam)
}
280
## predict_party() method for "constparty" objects.
##
## Arguments:
##   party    - object with fitted columns "(fitted)", "(response)" and,
##              optionally, "(weights)" (unit weights if absent)
##   id       - node ids to predict for
##   newdata  - if given, its row names label the predictions
##   type     - "response", "prob" or "node" ("node" returns the ids)
##   FUN      - optional function(y, w) computing the node prediction
##   simplify - if TRUE results are simplified via .simplify_pred()
##   ...      - passed on to .predict_party_constparty()
## A data.frame response is handled column by column.
predict_party.constparty <- function(party, id, newdata = NULL,
    type = c("response", "prob", "node"), FUN = NULL, simplify = TRUE, ...)
{
    ## extract fitted information
    response <- party$fitted[["(response)"]]
    weights <- party$fitted[["(weights)"]]
    fitted <- party$fitted[["(fitted)"]]
    if (is.null(weights)) weights <- rep(1, NROW(response))

    ## get observation names: either node names or
    ## observation names from newdata
    nam <- if (is.null(newdata)) names(party)[id] else rownames(newdata)
    if (length(nam) != length(id)) nam <- NULL

    type <- match.arg(type)

    ## special case: fitted ids
    if (type == "node")
        return(structure(id, .Names = nam))

    ### multivariate response
    if (is.data.frame(response)) {
        ret <- lapply(response, function(r) {
            pred <- .predict_party_constparty(node_party(party), fitted = fitted,
                response = r, weights = weights, id = id, type = type,
                FUN = FUN, ...)
            if (simplify) .simplify_pred(pred, id, nam) else pred
        })
        if (all(vapply(ret, is.atomic, logical(1))))
            ret <- as.data.frame(ret)
        names(ret) <- colnames(response)
        return(ret)
    }

    ### univariate response
    ret <- .predict_party_constparty(node_party(party), fitted = fitted,
        response = response, weights = weights, id = id, type = type,
        FUN = FUN, ...)
    if (simplify) .simplify_pred(ret, id, nam) else ret[as.character(id)]
}
320
### functions for node prediction based on fitted / response
321
## Node prediction for a survival response: an intercept-only survfit()
## of y with case weights w, using only observations with positive weight.
## survfit is exported by the survival package, so "::" is sufficient.
.pred_Surv <- function(y, w) {
    survival::survfit(y ~ 1, weights = w, subset = w > 0)
}
324
## Response-type prediction for a survival outcome: the result of
## .median_survival_time() applied to the fit returned by .pred_Surv().
.pred_Surv_response <- function(y, w) {
    .median_survival_time(.pred_Surv(y, w))
}
327
## Weighted class proportions of a factor response.
##   y - factor, w - case weights of the same length
## Returns one proportion per level of y (named by level); levels without
## observations get 0.
.pred_factor <- function(y, w) {
    sumw <- tapply(w, y, sum)
    ## tapply() yields NA for levels that do not occur
    sumw[is.na(sumw)] <- 0
    prob <- sumw / sum(w)
    prob
}
336
## Response-type prediction for a factor: the level with the largest
## weighted proportion, returned as a factor (ordered if y is ordered).
.pred_factor_response <- function(y, w) {
    prob <- .pred_factor(y, w)
    ## NOTE(review): 'labels = levels(y)' was not visible in the reviewed
    ## copy; it is assumed so that the result carries the levels of y, as
    ## .simplify_pred() expects -- confirm
    factor(which.max(prob), levels = seq_len(nlevels(y)),
           labels = levels(y), ordered = is.ordered(y))
}
344
## Node prediction for a numeric response: the weighted mean of y with
## weights w, ignoring missing values in y.
.pred_numeric <- function(y, w) {
    weighted.mean(x = y, w = w, na.rm = TRUE)
}
346
### workhorse: compute predictions based on fitted / response data
347
## Workhorse for constparty predictions: evaluates a prediction function
## FUN(response, weights) on the observations belonging to each requested
## node id.
## NOTE(review): this copy of the function looks truncated -- the start of
## the FUN dispatch, the "numeric" alternative, the "else" between the two
## 'tab' computations, the closing braces and the return value are not
## visible. Restore the missing lines before relying on it.
## NOTE(review): the default 'id = id' refers to itself; a call without
## 'id' would fail with a recursive default argument error -- confirm that
## all callers supply 'id'.
.predict_party_constparty <- function(node, fitted, response, weights,
348
id = id, type = c("response", "prob"), FUN = NULL, ...) {
352
## the first class of the response selects the prediction function;
## ordered factors are treated as factors, integers as numeric
rtype <- class(response)[1]
353
if (rtype == "ordered") rtype <- "factor"
354
if (rtype == "integer") rtype <- "numeric"
357
"Surv" = if (type == "response") .pred_Surv_response else .pred_Surv,
358
"factor" = if (type == "response") .pred_factor_response else .pred_factor,
361
stop(sQuote("type = \"prob\""), " ", "is not available")
366
## empirical distribution in each leaf
367
if (all(id %in% fitted)) {
368
## all requested ids occur among the fitted ids: one FUN call per
## fitted node id
tab <- tapply(1:NROW(response), fitted,
369
function(i) FUN(response[i], weights[i]), simplify = FALSE)
371
### id may also refer to inner nodes
372
tab <- as.array(lapply(sort(unique(id)), function(i) {
373
## observations whose fitted id is a terminal node below node i
index <- fitted %in% nodeids(node, i, terminal = TRUE)
374
FUN(response[index], weights[index])
376
names(tab) <- as.character(sort(unique(id)))
386
### simplify structure of predictions
387
## Simplify a list of node-wise predictions.
##
## Arguments:
##   tab - list of predictions, named by node id
##   id  - node ids to be looked up in 'tab' (by name)
##   nam - names for the result (NULL for none)
## Value:
##   - a vector (or factor) if every prediction is atomic of length one,
##   - a matrix with one row per id if all predictions are numeric and of
##     equal length,
##   - otherwise the list tab[as.character(id)].
## NOTE(review): the statements applying 'nam' were not visible in the
## reviewed copy; they are assumed because 'nam' is otherwise unused --
## confirm.
.simplify_pred <- function(tab, id, nam) {

    key <- as.character(id)
    len <- vapply(tab, length, integer(1))

    if (all(len == 1L) && all(vapply(tab, is.atomic, logical(1)))) {
        if (is.factor(tab[[1]])) {
            ## work on the integer codes explicitly: c() applied to factors
            ## returns codes in older R but a factor in newer R versions
            lev <- levels(tab[[1]])
            codes <- vapply(tab, as.integer, integer(1))
            ret <- factor(codes[key], levels = seq_along(lev),
                          labels = lev, ordered = is.ordered(tab[[1]]))
        } else {
            ret <- do.call("c", tab)
            names(ret) <- names(tab)
            ret <- ret[key]
        }
        names(ret) <- nam
    } else if (length(unique(len)) == 1L &&
               all(vapply(tab, is.numeric, logical(1)))) {
        ret <- matrix(unlist(tab), nrow = length(tab), byrow = TRUE)
        colnames(ret) <- names(tab[[1]])
        rownames(ret) <- names(tab)
        ret <- ret[key, , drop = FALSE]
        rownames(ret) <- nam
    } else {
        ret <- tab[key]
        names(ret) <- nam
    }
    ret
}
412
#print_party <- function(party, id, ...)
413
# UseMethod("print_party")
416
#print_party.default <- function(party, id, newdata = NULL, ...) {
420
## S3 generic extracting the data belonging to node 'id' of 'party';
## dispatches on the class of 'party'.
data_party <- function(party, id = 1L) {
    UseMethod("data_party")
}
423
## Default data_party() method: observations falling into node 'id'.
##
## For a single id a data.frame is returned: the rows of party$data
## combined with party$fitted (or of party$fitted alone if party$data has
## no rows) whose fitted id is a terminal node below 'id'. For several ids
## a list with one such data.frame per id is returned.
## Without fitted ids the result is NULL if party$data has no rows, and an
## error otherwise.
## NOTE(review): the single-id branch was not visible in the reviewed copy;
## it is assumed because the "[" method in this file treats
## data_party(x, i) as a data.frame -- confirm.
data_party.default <- function(party, id = 1L) {

    extract <- function(id) {
        if (is.null(party$fitted)) {
            if (nrow(party$data) == 0) return(NULL)
            stop("cannot subset data without fitted ids")
        }

        ### which terminal nodes follow node number id?
        nt <- nodeids(party, from = id, terminal = TRUE)
        wi <- party$fitted[["(fitted)"]] %in% nt

        dat <- if (nrow(party$data) == 0) {
            party$fitted
        } else {
            cbind(party$data, party$fitted)
        }
        ## plain row indexing instead of subset(): subset() evaluates its
        ## condition inside the data first, so a data column named "wi"
        ## would shadow the selection vector
        dat[wi, , drop = FALSE]
    }

    if (length(id) > 1)
        return(lapply(id, extract))
    extract(id)
}
447
## width() method for "party" objects: delegates to width() on the node
## extracted with node_party(), passing '...' on.
width.party <- function(x, ...) {
    width(node_party(x), ...)
}
451
## depth() method for "party" objects: delegates to depth() on the node
## extracted with node_party(), passing '...' on.
depth.party <- function(x, ...) {
    depth(node_party(x), ...)
}
455
## Build a textual description of the split conditions leading from the
## root to node i; the conditions of the individual splits are joined by
## " & ". By default (i = NULL) all terminal nodes are used.
## NOTE(review): this copy of the function looks truncated -- among others
## the guard around the vectorised sapply() call, the initialisation of
## 'rule' and 'srule', the "else" between the factor and the numeric case
## and all closing braces are not visible. Restore the missing lines before
## relying on it.
.list.rules.party <- function(x, i = NULL, ...) {
456
if (is.null(i)) i <- nodeids(x, terminal = TRUE)
458
## several nodes: one call per node, results named by node name
ret <- sapply(i, .list.rules.party, x = x)
459
names(ret) <- if (is.character(i)) i else names(x)[i]
462
## translate a node name into its position
if (is.character(i) && !is.null(names(x)))
463
i <- which(names(x) %in% i)
464
## exactly one numeric index within 1..length(x) is required
stopifnot(length(i) == 1 & is.numeric(i))
465
stopifnot(i <= length(x) & i >= 1)
467
dat <- data_party(x, i)
468
if (!is.null(x$fitted)) {
469
## everything from the "(fitted)" column onwards is fitted information,
## the columns before it are the data
findx <- which("(fitted)" == names(dat))[1]
470
fit <- dat[,findx:ncol(dat), drop = FALSE]
471
dat <- dat[,-(findx:ncol(dat)), drop = FALSE]
481
## descend from the root towards node i, appending the condition of each
## split on the way to 'rule' in the enclosing environment
recFun <- function(node) {
482
if (id_node(node) == i) return(NULL)
483
kid <- sapply(kids_node(node), id_node)
484
## follow the last kid whose id does not exceed i
whichkid <- max(which(kid <= i))
485
split <- split_node(node)
486
ivar <- varid_split(split)
487
svar <- names(dat)[ivar]
488
index <- index_split(split)
489
if (is.factor(dat[, svar])) {
490
## factor split: the levels that are sent to the chosen kid
slevels <- levels(dat[, svar])[index == whichkid]
491
srule <- paste(svar, " %in% c(\"",
492
paste(slevels, collapse = "\", \"", sep = ""), "\")",
495
## numeric split: without an index the intervals map to the kids in order
if (is.null(index)) index <- 1:length(kid)
496
## one row per interval: lower bound, upper bound
breaks <- cbind(c(-Inf, breaks_split(split)),
497
c(breaks_split(split), Inf))
498
sbreak <- breaks[index == whichkid,]
499
right <- right_split(split)
501
## lower bound of the interval (omitted if infinite)
if (is.finite(sbreak[1]))
503
paste(svar, ifelse(right, ">", ">="), sbreak[1]))
504
## upper bound of the interval (omitted if infinite)
## NOTE(review): for right = FALSE the operator used here is ">"; for an
## upper bound "<" would be expected -- looks like a typo, confirm
if (is.finite(sbreak[2]))
506
paste(svar, ifelse(right, "<=", ">"), sbreak[2]))
507
srule <- paste(srule, collapse = " & ")
509
rule <<- c(rule, srule)
510
return(recFun(node[[whichkid]]))
512
node <- recFun(node_party(x))
513
paste(rule, collapse = " & ")