~elachuni/chaid/partykit

« back to all changes in this revision

Viewing changes to partykit/pkg/R/party.R

  • Committer: Anthony Lenton
  • Date: 2010-12-07 18:42:51 UTC
  • Revision ID: anthony.lenton@canonical.com-20101207184251-ei0pvabzhjly2ez6
Initial import from svn.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
## FIXME: data in party
 
2
##   - currently assumed to be a data.frame
 
3
##   - potentially empty
 
4
##   - the following are all assumed to work:
 
5
##     dim(data), names(data)
 
6
##     sapply(data, class), lapply(data, levels)
 
7
##   - potentially these need to be modified if data/terms
 
8
##     should be able to deal with data bases
 
9
 
 
10
party <- function(node, data, fitted = NULL, terms = NULL, names = NULL, info = NULL) {
 
11
 
 
12
    stopifnot(inherits(node, "partynode"))
 
13
    stopifnot(inherits(data, "data.frame"))
 
14
    ### make sure all split variables are there 
 
15
    ids <- nodeids(node)[!nodeids(node) %in% nodeids(node, terminal = TRUE)]
 
16
    varids <- unique(unlist(nodeapply(node, ids = ids, FUN = function(x) 
 
17
        varid_split(split_node(x)))))
 
18
    stopifnot(varids %in% 1:ncol(data))
 
19
    
 
20
    if(!is.null(fitted)) {
 
21
        stopifnot(inherits(fitted, "data.frame"))
 
22
        stopifnot("(fitted)" == names(fitted)[1])
 
23
        stopifnot(nrow(data) == 0 | nrow(data) == nrow(fitted))
 
24
 
 
25
        nt <- nodeids(node, terminal = TRUE)
 
26
        stopifnot(all(fitted[["(fitted)"]] %in% nt))
 
27
 
 
28
        node <- as.partynode(node, from = 1L)
 
29
        nt2 <- nodeids(node, terminal = TRUE)
 
30
        fitted[["(fitted)"]] <- nt2[match(fitted[["(fitted)"]], nt)]
 
31
    } else {
 
32
        node <- as.partynode(node, from = 1L)
 
33
    }
 
34
    
 
35
    party <- list(node = node, data = data, fitted = fitted, 
 
36
                  terms = NULL, names = NULL, info = info)
 
37
    class(party) <- "party"
 
38
 
 
39
    if(!is.null(terms)) {
 
40
        stopifnot(inherits(terms, "terms"))
 
41
        party$terms <- terms
 
42
    }
 
43
 
 
44
    if (!is.null(names)) {
 
45
        n <- length(nodeids(party, terminal = FALSE))
 
46
        if (length(names) != n)
 
47
            stop("invalid", " ", sQuote("names"), " ", "argument")
 
48
        party$names <- names
 
49
    }
 
50
 
 
51
    party
 
52
}
 
53
 
 
54
length.party <- function(x)
 
55
    length(nodeids(x))
 
56
 
 
57
names.party <- function(x)
 
58
    .names_party(x)
 
59
 
 
60
"names<-.party" <- function(x, value) {
 
61
     n <- length(nodeids(x, terminal = FALSE))
 
62
     if (!is.null(value) && length(value) != n)
 
63
         stop("invalid", " ", sQuote("names"), " ", "argument")
 
64
     x$names <- value
 
65
     x
 
66
}
 
67
 
 
68
.names_party <- function(party) {
 
69
    names <- party$names
 
70
    if (is.null(names))
 
71
        names <- as.character(nodeids(party, terminal = FALSE))
 
72
    names
 
73
}
 
74
 
 
75
node_party <- function(party) {
 
76
    stopifnot(inherits(party, "party"))
 
77
    party$node
 
78
}
 
79
 
 
80
is.constparty <- function(party) {
 
81
    stopifnot(inherits(party, "party"))
 
82
    if (!is.null(party$fitted)) 
 
83
        return(all(c("(fitted)", "(response)") %in% names(party$fitted)))
 
84
    return(FALSE)
 
85
}
 
86
 
 
87
as.constparty <- function(obj, ...) {
 
88
    if (is.constparty(obj)) {
 
89
        ret <- obj
 
90
        class(ret) <- c("constparty", class(obj))
 
91
        return(ret)
 
92
    }
 
93
    stop("cannot coerce object of class", " ", sQuote(class(obj)), 
 
94
          " ", "to", " ", sQuote("constparty"))
 
95
}
 
96
 
 
97
"[.party" <- "[[.party" <- function(x, i, ...) {
 
98
    if (is.character(i) && !is.null(names(x)))
 
99
        i <- which(names(x) %in% i)
 
100
    stopifnot(length(i) == 1 & is.numeric(i))
 
101
    stopifnot(i <= length(x) & i >= 1)
 
102
    i <- as.integer(i)
 
103
    dat <- data_party(x, i)
 
104
    if (!is.null(x$fitted)) {
 
105
        findx <- which("(fitted)" == names(dat))[1]
 
106
        fit <- dat[,findx:ncol(dat), drop = FALSE]
 
107
        dat <- dat[,-(findx:ncol(dat)), drop = FALSE]
 
108
        if (ncol(dat) == 0)
 
109
            dat <- x$data
 
110
    } else {
 
111
        fit <- NULL
 
112
        dat <- x$data
 
113
    }
 
114
    nam <- names(x)[nodeids(x, from = i, terminal = FALSE)]
 
115
 
 
116
    recFun <- function(node) {
 
117
        if (id_node(node) == i) return(node)
 
118
        kid <- sapply(kids_node(node), id_node)
 
119
        return(recFun(node[[max(which(kid <= i))]]))
 
120
    }
 
121
    node <- recFun(node_party(x))
 
122
 
 
123
    ret <- party(node = node, data = dat, fitted = fit, 
 
124
                 terms = x$terms, names = nam, info = x$info)
 
125
    class(ret) <- class(x)
 
126
    ret
 
127
}
 
128
 
 
129
nodeids <- function(obj, ...)
 
130
    UseMethod("nodeids")
 
131
 
 
132
nodeids.partynode <- function(obj, from = NULL, terminal = FALSE, ...) {
 
133
 
 
134
    if(is.null(from)) from <- id_node(obj)
 
135
 
 
136
    id <- function(node, record = TRUE, terminal = FALSE) {
 
137
      if(!record) return(NULL)
 
138
      if(!terminal)
 
139
          return(id_node(node))
 
140
      else
 
141
          if(is.terminal(node)) return(id_node(node)) else return(NULL)
 
142
    }
 
143
 
 
144
    rid <- function(node, record = TRUE, terminal = FALSE) {  
 
145
        myid <- id(node, record = record, terminal = terminal)
 
146
        if(is.terminal(node)) return(myid)
 
147
        kids <- kids_node(node)    
 
148
        kids_record <- if(record)  
 
149
            rep(TRUE, length(kids))
 
150
        else
 
151
            sapply(kids, id_node) == from
 
152
        return(c(myid,
 
153
            unlist(lapply(1:length(kids), function(i)
 
154
                rid(kids[[i]], record = kids_record[i], terminal = terminal)))
 
155
        ))
 
156
    }
 
157
 
 
158
    return(rid(obj, from == id_node(obj), terminal))
 
159
}
 
160
 
 
161
nodeids.party <- function(obj, from = NULL, terminal = FALSE, ...)
 
162
    nodeids(node_party(obj), from = from, terminal = terminal, ...)
 
163
 
 
164
nodeapply <- function(obj, ids = 1, FUN = NULL, ...)
 
165
    UseMethod("nodeapply")
 
166
 
 
167
nodeapply.party <- function(obj, ids = 1, FUN = NULL, by_node = TRUE, ...) {
 
168
 
 
169
    stopifnot(isTRUE(all.equal(ids, round(ids))))
 
170
    ids <- as.integer(ids)
 
171
 
 
172
    if(is.null(FUN)) FUN <- function(x, ...) x
 
173
 
 
174
    if (length(ids) == 0)
 
175
        return(NULL)
 
176
 
 
177
    if (by_node) {
 
178
        rval <- nodeapply(node_party(obj), ids = ids, FUN = FUN, ...)
 
179
    } else {
 
180
        rval <- lapply(ids, function(i) FUN(obj[[i]], ...))
 
181
    }
 
182
 
 
183
    names(rval) <- names(obj)[ids]
 
184
    return(rval)
 
185
}
 
186
 
 
187
nodeapply.partynode <- function(obj, ids = 1, FUN = NULL, ...) {
 
188
 
 
189
    stopifnot(isTRUE(all.equal(ids, round(ids))))
 
190
    ids <- as.integer(ids)
 
191
 
 
192
    if(is.null(FUN)) FUN <- function(x, ...) x
 
193
 
 
194
    if (length(ids) == 0)
 
195
        return(NULL)
 
196
 
 
197
    rval <- vector(mode = "list", length = length(ids))
 
198
    rval_id <- rep(0, length(ids))
 
199
    i <- 1
 
200
        
 
201
    recFUN <- function(node, ...) {
 
202
        if(id_node(node) %in% ids) {
 
203
            rval_id[i] <<- id_node(node)
 
204
            rval[[i]] <<- FUN(node, ...)
 
205
            i <<- i + 1
 
206
        }
 
207
        kids <- kids_node(node)
 
208
        if(length(kids) > 0) {
 
209
            for(j in 1:length(kids)) recFUN(kids[[j]])
 
210
        }
 
211
        invisible(TRUE)
 
212
    }
 
213
    foo <- recFUN(obj)
 
214
    rval <- rval[match(rval_id, ids)]
 
215
    return(rval)
 
216
}
 
217
 
 
218
predict.party <- function(object, newdata = NULL, ...)
 
219
{
 
220
    ### compute fitted node ids first
 
221
    fitted <- if(is.null(newdata)) object$fitted[["(fitted)"]] else {
 
222
 
 
223
        terminal <- nodeids(object, terminal = TRUE)
 
224
        inner <- 1:max(terminal)
 
225
        inner <- inner[-terminal]
 
226
 
 
227
        primary_vars <- nodeapply(object, ids = inner, by_node = TRUE, FUN = function(node) {
 
228
            varid_split(split_node(node))
 
229
        })
 
230
        surrogate_vars <- nodeapply(object, ids = inner, by_node = TRUE, FUN = function(node) {
 
231
            surr <- surrogates_node(node)
 
232
            if(is.null(surr)) return(NULL) else return(sapply(surr, varid_split))
 
233
        })
 
234
        vnames <- names(object$data)
 
235
 
 
236
        ## ## FIXME: the is.na() call takes loooong on large data sets
 
237
        ## unames <- if(any(sapply(newdata, is.na))) 
 
238
        ##     vnames[unique(unlist(c(primary_vars, surrogate_vars)))]
 
239
        ## else 
 
240
        ##     vnames[unique(unlist(primary_vars))]
 
241
        unames <- vnames[unique(unlist(c(primary_vars, surrogate_vars)))]
 
242
        
 
243
        vclass <- structure(lapply(object$data, class), .Names = vnames)
 
244
        ndnames <- names(newdata)
 
245
        ndclass <- structure(lapply(newdata, class), .Names = ndnames)
 
246
        if(all(unames %in% ndnames) &&
 
247
           all(unlist(lapply(unames, function(x) vclass[[x]] == ndclass[[x]])))) {
 
248
            vmatch <- match(vnames, ndnames)
 
249
            fitted_node(node_party(object), newdata, vmatch)
 
250
        } else {
 
251
            if (!is.null(object$terms)) {
 
252
                mf <- model.frame(delete.response(object$terms), newdata)
 
253
                fitted_node(node_party(object), mf, match(vnames, names(mf)))
 
254
            } else
 
255
                stop("") ## FIXME: write error message
 
256
        }
 
257
    }
 
258
    ### compute predictions
 
259
    predict_party(object, fitted, newdata, ...)
 
260
}
 
261
 
 
262
predict_party <- function(party, id, newdata = NULL, ...)
 
263
    UseMethod("predict_party")
 
264
 
 
265
### do nothing expect returning the fitted ids
 
266
predict_party.default <- function(party, id, newdata = NULL, ...) {
 
267
 
 
268
    if (length(list(...)) > 1) 
 
269
        warning("argument(s)", " ", sQuote(names(list(...))), " ", "have been ignored")
 
270
 
 
271
    ## get observation names: either node names or
 
272
    ## observation names from newdata
 
273
    nam <- if(is.null(newdata)) names(party)[id] else rownames(newdata)
 
274
    if(length(nam) != length(id)) nam <- NULL
 
275
 
 
276
    ## special case: fitted ids
 
277
    return(structure(id, .Names = nam))
 
278
}
 
279
 
 
280
predict_party.constparty <- function(party, id, newdata = NULL,
 
281
    type = c("response", "prob", "node"), FUN = NULL, simplify = TRUE, ...)
 
282
{
 
283
    ## extract fitted information
 
284
    response <- party$fitted[["(response)"]]
 
285
    weights <- party$fitted[["(weights)"]]
 
286
    fitted <- party$fitted[["(fitted)"]]
 
287
    if (is.null(weights)) weights <- rep(1, NROW(response))
 
288
 
 
289
    ## get observation names: either node names or
 
290
    ## observation names from newdata
 
291
    nam <- if(is.null(newdata)) names(party)[id] else rownames(newdata)
 
292
    if(length(nam) != length(id)) nam <- NULL
 
293
 
 
294
    ## match type
 
295
    type <- match.arg(type)
 
296
 
 
297
    ## special case: fitted ids
 
298
    if(type == "node")
 
299
      return(structure(id, .Names = nam))
 
300
 
 
301
    ### multivariate response
 
302
    if (is.data.frame(response)) {
 
303
        ret <- lapply(response, function(r) {
 
304
            ret <- .predict_party_constparty(node_party(party), fitted = fitted, 
 
305
                response = r, weights, id = id, type = type, FUN = FUN, ...)
 
306
            if (simplify) .simplify_pred(ret, id, nam) else ret
 
307
        })
 
308
        if (all(sapply(ret, is.atomic)))
 
309
            ret <- as.data.frame(ret)
 
310
        names(ret) <- colnames(response)
 
311
        return(ret)
 
312
    }
 
313
 
 
314
    ### univariate response
 
315
    ret <- .predict_party_constparty(node_party(party), fitted = fitted, response = response, 
 
316
        weights = weights, id = id, type = type, FUN = FUN, ...)
 
317
    if (simplify) .simplify_pred(ret, id, nam) else ret[as.character(id)]
 
318
}
 
319
 
 
320
### functions for node prediction based on fitted / response
 
321
.pred_Surv <- function(y, w)
 
322
    survival:::survfit(y ~ 1, weights = w, subset = w > 0)
 
323
 
 
324
.pred_Surv_response <- function(y, w)
 
325
    .median_survival_time(.pred_Surv(y, w))
 
326
                    
 
327
.pred_factor <- function(y, w) {
 
328
    lev <- levels(y)
 
329
    sumw <- tapply(w, y, sum)
 
330
    sumw[is.na(sumw)] <- 0
 
331
    prob <- sumw / sum(w)
 
332
    names(prob) <- lev
 
333
    return(prob)
 
334
}
 
335
 
 
336
.pred_factor_response <- function(y, w) {
 
337
    prob <- .pred_factor(y, w)
 
338
    return(factor(which.max(prob), levels = 1:nlevels(y),
 
339
                  labels = levels(y), 
 
340
                  ordered = is.ordered(y)))
 
341
    return(prob) 
 
342
}
 
343
                    
 
344
.pred_numeric <- function(y, w) weighted.mean(y, w, na.rm = TRUE)
 
345
 
 
346
### workhorse: compute predictions based on fitted / response data
 
347
.predict_party_constparty <- function(node, fitted, response, weights,
 
348
    id = id, type = c("response", "prob"), FUN = NULL, ...) {
 
349
 
 
350
    if (is.null(FUN)) {
 
351
 
 
352
        rtype <- class(response)[1]
 
353
        if (rtype == "ordered") rtype <- "factor"    
 
354
        if (rtype == "integer") rtype <- "numeric"
 
355
 
 
356
        FUN <- switch(rtype,
 
357
            "Surv" = if (type == "response") .pred_Surv_response else .pred_Surv,
 
358
            "factor" = if (type == "response") .pred_factor_response else .pred_factor,
 
359
            "numeric" = {
 
360
                if (type == "prob")
 
361
                    stop(sQuote("type = \"prob\""), " ", "is not available")
 
362
                .pred_numeric
 
363
           })
 
364
    }
 
365
      
 
366
    ## empirical distribution in each leaf
 
367
    if (all(id %in% fitted)) {
 
368
        tab <- tapply(1:NROW(response), fitted, 
 
369
                      function(i) FUN(response[i], weights[i]), simplify = FALSE)
 
370
    } else {
 
371
        ### id may also refer to inner nodes
 
372
        tab <- as.array(lapply(sort(unique(id)), function(i) {
 
373
            index <- fitted %in% nodeids(node, i, terminal = TRUE)
 
374
            FUN(response[index], weights[index])
 
375
        }))
 
376
        names(tab) <- as.character(sort(unique(id)))
 
377
    }
 
378
    tn <- names(tab)
 
379
    dim(tab) <- NULL
 
380
    names(tab) <- tn
 
381
 
 
382
    tab
 
383
}
 
384
 
 
385
 
 
386
### simplify structure of predictions
 
387
.simplify_pred <- function(tab, id, nam) {
 
388
 
 
389
    if (all(sapply(tab, length) == 1) & all(sapply(tab, is.atomic))) {
 
390
        ret <- do.call("c", tab)
 
391
        names(ret) <- names(tab)
 
392
        ret <- if (is.factor(tab[[1]]))
 
393
            factor(ret[as.character(id)], levels = 1:length(levels(tab[[1]])),
 
394
                   labels = levels(tab[[1]]), ordered = is.ordered(tab[[1]]))
 
395
        else 
 
396
            ret[as.character(id)]
 
397
        names(ret) <- nam
 
398
    } else if (length(unique(sapply(tab, length))) == 1 & 
 
399
               all(sapply(tab, is.numeric))) {
 
400
        ret <- matrix(unlist(tab), nrow = length(tab), byrow = TRUE)
 
401
        colnames(ret) <- names(tab[[1]])
 
402
        rownames(ret) <- names(tab)
 
403
        ret <- ret[as.character(id),, drop = FALSE]
 
404
        rownames(ret) <- nam
 
405
    } else {
 
406
        ret <- tab[as.character(id)]
 
407
        names(ret) <- nam
 
408
    }
 
409
    ret
 
410
}
 
411
 
 
412
#print_party <- function(party, id, ...)
 
413
#    UseMethod("print_party")
 
414
#
 
415
#
 
416
#print_party.default <- function(party, id, newdata = NULL, ...) {
 
417
#
 
418
#}
 
419
 
 
420
data_party <- function(party, id = 1L)
 
421
    UseMethod("data_party")
 
422
 
 
423
data_party.default <- function(party, id = 1L) {
 
424
    
 
425
    extract <- function(id) {
 
426
        if(is.null(party$fitted))
 
427
            if(nrow(party$data) == 0) return(NULL)
 
428
        else
 
429
            stop("cannot subset data without fitted ids")
 
430
 
 
431
        ### which terminal nodes follow node number id?
 
432
        nt <- nodeids(party, id, terminal = TRUE)
 
433
        wi <- party$fitted[["(fitted)"]] %in% nt
 
434
 
 
435
        ret <- if (nrow(party$data) == 0)
 
436
            subset(party$fitted, wi)
 
437
        else
 
438
            subset(cbind(party$data, party$fitted), wi)
 
439
        ret
 
440
    }
 
441
    if (length(id) > 1)
 
442
        return(lapply(id, extract))
 
443
    else 
 
444
        return(extract(id))
 
445
}
 
446
 
 
447
width.party <- function(x, ...) {
 
448
  width(node_party(x), ...)
 
449
}
 
450
 
 
451
depth.party <- function(x, ...) {
 
452
  depth(node_party(x), ...)
 
453
}
 
454
 
 
455
.list.rules.party <- function(x, i = NULL, ...) {
 
456
    if (is.null(i)) i <- nodeids(x, terminal = TRUE)
 
457
    if (length(i) > 1) {
 
458
        ret <- sapply(i, .list.rules.party, x = x)
 
459
        names(ret) <- if (is.character(i)) i else names(x)[i]
 
460
        return(ret)
 
461
    }
 
462
    if (is.character(i) && !is.null(names(x)))
 
463
        i <- which(names(x) %in% i)
 
464
    stopifnot(length(i) == 1 & is.numeric(i))
 
465
    stopifnot(i <= length(x) & i >= 1)
 
466
    i <- as.integer(i)
 
467
    dat <- data_party(x, i)  
 
468
    if (!is.null(x$fitted)) {
 
469
        findx <- which("(fitted)" == names(dat))[1]  
 
470
        fit <- dat[,findx:ncol(dat), drop = FALSE]   
 
471
        dat <- dat[,-(findx:ncol(dat)), drop = FALSE]
 
472
        if (ncol(dat) == 0)
 
473
            dat <- x$data
 
474
    } else {
 
475
        fit <- NULL  
 
476
        dat <- x$data
 
477
    }
 
478
 
 
479
    rule <- c()
 
480
 
 
481
    recFun <- function(node) {
 
482
        if (id_node(node) == i) return(NULL)   
 
483
        kid <- sapply(kids_node(node), id_node)
 
484
        whichkid <- max(which(kid <= i))
 
485
        split <- split_node(node)
 
486
        ivar <- varid_split(split)
 
487
        svar <- names(dat)[ivar]
 
488
        index <- index_split(split)
 
489
        if (is.factor(dat[, svar])) {
 
490
            slevels <- levels(dat[, svar])[index == whichkid]
 
491
            srule <- paste(svar, " %in% c(\"", 
 
492
                paste(slevels, collapse = "\", \"", sep = ""), "\")",
 
493
                sep = "")
 
494
        } else {
 
495
            if (is.null(index)) index <- 1:length(kid)
 
496
            breaks <- cbind(c(-Inf, breaks_split(split)), 
 
497
                            c(breaks_split(split), Inf))
 
498
            sbreak <- breaks[index == whichkid,]
 
499
            right <- right_split(split)
 
500
            srule <- c()
 
501
            if (is.finite(sbreak[1]))
 
502
                srule <- c(srule, 
 
503
                    paste(svar, ifelse(right, ">", ">="), sbreak[1]))
 
504
            if (is.finite(sbreak[2]))
 
505
                srule <- c(srule, 
 
506
                    paste(svar, ifelse(right, "<=", ">"), sbreak[2]))
 
507
            srule <- paste(srule, collapse = " & ")
 
508
        }
 
509
        rule <<- c(rule, srule)
 
510
        return(recFun(node[[whichkid]]))
 
511
    }
 
512
    node <- recFun(node_party(x))
 
513
    paste(rule, collapse = " & ")
 
514
}