~ubuntu-branches/ubuntu/trusty/cdk/trusty-proposed

« back to all changes in this revision

Viewing changes to src/org/openscience/cdk/qsar/model/R2/CNNRegressionModel.java

  • Committer: Bazaar Package Importer
  • Author(s): Paul Cager
  • Date: 2008-04-09 21:17:53 UTC
  • Revision ID: james.westby@ubuntu.com-20080409211753-46lmjw5z8mx5pd8d
Tags: upstream-1.0.2
ImportĀ upstreamĀ versionĀ 1.0.2

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
package org.openscience.cdk.qsar.model.R2;
 
2
 
 
3
import org.openscience.cdk.qsar.model.QSARModelException;
 
4
import org.openscience.cdk.tools.LoggingTool;
 
5
import org.rosuda.JRI.RBool;
 
6
import org.rosuda.JRI.REXP;
 
7
import org.rosuda.JRI.RList;
 
8
 
 
9
import java.io.File;
 
10
import java.util.HashMap;
 
11
 
 
12
/**
 
13
 * A modeling class that provides a computational neural network regression model.
 
14
 * <p/>
 
15
 * When instantiated this class ensures that the R/Java interface has been
 
16
 * initialized. The response and independent variables can be specified at construction
 
17
 * time or via the <code>setParameters</code> method.
 
18
 * The actual fitting procedure is carried out by <code>build</code> after which
 
19
 * the model may be used to make predictions, via <code>predict</code>. An example of the use
 
20
 * of this class is shown below:
 
21
 * <pre>
 
22
 * double[][] x;
 
23
 * double[] y;
 
24
 * Double[] wts;
 
25
 * Double[][] newx;
 
26
 * ...
 
27
 * try {
 
28
 *     CNNRegressionModel cnnrm = new CNNRegressionModel(x,y,3);
 
29
 *     cnnrm.setParameters("Wts",wts);
 
30
 *     cnnrm.build();
 
31
 * <p/>
 
32
 *     double fitValue = cnnrm.getFitValue();
 
33
 * <p/>
 
34
 *     cnnrm.setParameters("newdata", newx);
 
35
 *     cnnrm.setParameters("type", "raw");
 
36
 *     cnnrm.predict();
 
37
 * <p/>
 
38
 *     double[][] preds = cnnrm.getPredictPredicted();
 
39
 * } catch (QSARModelException qme) {
 
40
 *     System.out.println(qme.toString());
 
41
 * }
 
42
 * </pre>
 
43
 * The above code snippet builds a 3-3-1 CNN model.
 
44
 * Multiple output neurons are easily
 
45
 * specified by supplying a matrix for y (i.e., double[][]) with the output variables
 
46
 * in the columns.
 
47
 * <p/>
 
48
 * Nearly all the arguments to
 
49
 * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet()</a> are
 
50
 * supported via the <code>setParameters</code> method. The table below lists the names of the arguments,
 
51
 * the expected type of the argument and the default setting for the arguments supported by this wrapper class.
 
52
 * <center>
 
53
 * <table border=1 cellpadding=5>
 
54
 * <THEAD>
 
55
 * <tr>
 
56
 * <th>Name</th><th>Java Type</th><th>Default</th><th>Notes</th>
 
57
 * </tr>
 
58
 * </thead>
 
59
 * <tbody>
 
60
 * <tr><td>x</td><td>Double[][]</td><td>None</td><td>This must be set by the caller via the constructors or via <code>setParameters</code></td></tr>
 
61
 * <tr><td>y</td><td>Double[][]</td><td>None</td><td>This must be set by the caller via the constructors or via <code>setParameters</code></td></tr>
 
62
 * <tr><td>weights</td><td>Double[]</td><td>rep(1,nobs)</td><td>The default case weights is a vector of 1's equal in length to the number of observations, nobs</td></tr>
 
63
 * <tr><td>size</td><td>Integer</td><td>None</td><td>This must be set by the caller via the constructors or via <code>setParameters</code></td></tr>
 
64
 * <tr><td>subset</td><td>Integer[]</td><td>1:nobs</td><td>This is supposed to be an index vector specifying which observations are to be used in building the model. The default indicates that all should be used</td></tr>
 
65
 * <tr><td>Wts</td><td>Double[]</td><td>runif(1,nwt)</td><td>The initial weight vector is set to a random vector of length equal to the number of weights if not set by the user</td></tr>
 
66
 * <tr><td>mask</td><td>Boolean[]</td><td>rep(TRUE,nwt)</td><td>All weights are to be optimized unless otherwise specified by the user</td></tr>
 
67
 * <tr><td>linout</td><td>Boolean</td><td>TRUE</td><td>Since this class performs regression this need not be changed</td></tr>
 
68
 * <tr><td>entropy</td><td>Boolean</td><td>FALSE</td><td></td></tr>
 
69
 * <tr><td>softmax</td><td>Boolean</td><td>FALSE</td><td></td></tr>
 
70
 * <tr><td>censored</td><td>Boolean</td><td>FALSE</td><td></td></tr>
 
71
 * <tr><td>skip</td><td>Boolean</td><td>FALSE</td><td></td></tr>
 
72
 * <tr><td>rang</td><td>Double</td><td>0.7</td><td></td></tr>
 
73
 * <tr><td>decay</td><td>Double</td><td>0.0</td><td></td></tr>
 
74
 * <tr><td>maxit</td><td>Integer</td><td>100</td><td></td></tr>
 
75
 * <tr><td>Hess</td><td>Boolean</td><td>FALSE</td><td></td></tr>
 
76
 * <tr><td>trace</td><td>Boolean</td><td>TRUE</td><td></td></tr>
 
77
 * <tr><td>MaxNWts</td><td>Integer</td><td>1000</td><td></td></tr>
 
78
 * <tr><td>abstol</td><td>Double</td><td>1.0e-4</td><td></td></tr>
 
79
 * <tr><td>reltol</td><td>Double</td><td>1.0e-8</td><td></td></tr>
 
80
 * </tbody>
 
81
 * </table>
 
82
 * </center>
 
83
 * <p/>
 
84
 * The values returned correspond to the various
 
85
 * values returned by the <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a> and
 
86
 * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/predict.nnet.html" target="_top">predict.nnet</a> functions
 
87
 * in R
 
88
 * <p/>
 
89
 * See {@link org.openscience.cdk.qsar.model.R.RModel} for details regarding the R and Java environment.
 
90
 *
 
91
 * @author      Rajarshi Guha
 
92
 * @cdk.require r-project
 
93
 * @cdk.require java1.5+
 
94
 * @cdk.module  qsar
 
95
 * @cdk.keyword neural network
 
96
 * @cdk.keyword R
 
97
 */
 
98
 
 
99
public class CNNRegressionModel extends RModel {
 
100
    public static int globalID = 0;
 
101
    private int noutput = 0;
 
102
    private int nvar = 0;
 
103
 
 
104
    private double[][] modelPredict = null;
 
105
 
 
106
    private static LoggingTool logger;
 
107
 
 
108
    private void setDefaults() {
 
109
        // lets set the default values of the arguments that are specified
 
110
        // to have default values in ?nnet
 
111
 
 
112
        // these params are vectors that depend on user defined stuff
 
113
        // so as a default we set them to FALSE so R can check if these
 
114
        // were not set
 
115
        this.params.put("subset", Boolean.FALSE);
 
116
        this.params.put("mask", Boolean.FALSE);
 
117
        this.params.put("Wts", Boolean.FALSE);
 
118
        this.params.put("weights", Boolean.FALSE);
 
119
 
 
120
        this.params.put("linout", Boolean.TRUE); // we want only regression
 
121
        this.params.put("entropy", Boolean.FALSE);
 
122
        this.params.put("softmax", Boolean.FALSE);
 
123
        this.params.put("censored", Boolean.FALSE);
 
124
        this.params.put("skip", Boolean.FALSE);
 
125
        this.params.put("rang", new Double(0.7));
 
126
        this.params.put("decay", new Double(0.0));
 
127
        this.params.put("maxit", new Integer(100));
 
128
        this.params.put("Hess", Boolean.FALSE);
 
129
        this.params.put("trace", Boolean.FALSE); // no need to see output
 
130
        this.params.put("MaxNWts", new Integer(1000));
 
131
        this.params.put("abstol", new Double(1.0e-4));
 
132
        this.params.put("reltol", new Double(1.0e-8));
 
133
    }
 
134
 
 
135
    /**
 
136
     * Constructs a CNNRegressionModel object.
 
137
     * <p/>
 
138
     * This constructor allows the user to simply set up an instance of a CNN
 
139
     * regression modeling class. This constructor simply sets the name for this
 
140
     * instance. It is expected all the relevent parameters for modeling will be
 
141
     * set at a later point.
 
142
     * <p/>
 
143
     * Other parameters that are required to be set should be done via
 
144
     * calls to <code>setParameters</code>. A number of parameters are set to the
 
145
     * defaults as specified in the manpage for
 
146
     * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>.
 
147
     */
 
148
    public CNNRegressionModel() throws QSARModelException {
 
149
        super();
 
150
        logger = new LoggingTool(this);
 
151
 
 
152
        params = new HashMap();
 
153
        int currentID = CNNRegressionModel.globalID;
 
154
        CNNRegressionModel.globalID++;
 
155
        setModelName("cdkCNNModel" + currentID);
 
156
        setDefaults();
 
157
 
 
158
 
 
159
    }
 
160
 
 
161
 
 
162
    /**
 
163
     * Constructs a CNNRegressionModel object.
 
164
     * <p/>
 
165
     * This constructor allows the user to specify the dependent and
 
166
     * independent variables along with the number of hidden layer neurons.
 
167
     * This constructor is suitable for cases when there is a single output
 
168
     * neuron. If the number of rows of the design matrix is not equal to
 
169
     * the number of observations in y an exception will be thrown.
 
170
     * <p/>
 
171
     * Other parameters that are required to be set should be done via
 
172
     * calls to <code>setParameters</code>. A number of parameters are set to the
 
173
     * defaults as specified in the manpage for
 
174
     * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>.
 
175
     *
 
176
     * @param x    An array of independent variables. Observations should be in
 
177
     *             the rows and variables in the columns.
 
178
     * @param y    An array (single column) of observed values
 
179
     * @param size The number of hidden layer neurons
 
180
     * @throws QSARModelException if the number of observations in x and y do not match
 
181
     */
 
182
    public CNNRegressionModel(double[][] x, double[] y, int size) throws QSARModelException {
 
183
        super();
 
184
        logger = new LoggingTool(this);
 
185
 
 
186
        params = new HashMap();
 
187
        int currentID = CNNRegressionModel.globalID;
 
188
        CNNRegressionModel.globalID++;
 
189
        setModelName("cdkCNNModel" + currentID);
 
190
 
 
191
        int nrow = y.length;
 
192
        int ncol = x[0].length;
 
193
 
 
194
        if (nrow != x.length) {
 
195
            throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix");
 
196
        }
 
197
 
 
198
        nvar = ncol;
 
199
        noutput = 1;
 
200
 
 
201
        Double[][] xx = new Double[nrow][ncol];
 
202
        Double[][] yy = new Double[nrow][1];
 
203
 
 
204
        for (int i = 0; i < nrow; i++) {
 
205
            yy[i][0] = new Double(y[i]);
 
206
            for (int j = 0; j < ncol; j++) {
 
207
                xx[i][j] = new Double(x[i][j]);
 
208
            }
 
209
        }
 
210
        params.put("x", xx);
 
211
        params.put("y", yy);
 
212
        params.put("size", new Integer(size));
 
213
        setDefaults();
 
214
    }
 
215
 
 
216
    /**
 
217
     * Constructs a CNNRegressionModel object.
 
218
     * <p/>
 
219
     * This constructor allows the user to specify the dependent and
 
220
     * independent variables along with the number of hidden layer neurons.
 
221
     * This constructor is suitable for cases when there are multiple output
 
222
     * neuron. If the number of rows of the design matrix is not equal to
 
223
     * the number of observations in y an exception will be thrown.
 
224
     * <p/>
 
225
     * Other parameters that are required to be set should be done via
 
226
     * calls to <code>setParameters</code>. A number of parameters are set to the
 
227
     * defaults as specified in the manpage for
 
228
     * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>.
 
229
     *
 
230
     * @param x    An array of independent variables. Observations should be in
 
231
     *             the rows and variables in the columns.
 
232
     * @param y    An array (multiple columns) of observed values
 
233
     * @param size The number of hidden layer neurons
 
234
     * @throws QSARModelException if the number of observations in x and y do not match
 
235
     */
 
236
    public CNNRegressionModel(double[][] x, double[][] y, int size) throws QSARModelException {
 
237
        super();
 
238
        logger = new LoggingTool(this);
 
239
 
 
240
        params = new HashMap();
 
241
        int currentID = CNNRegressionModel.globalID;
 
242
        CNNRegressionModel.globalID++;
 
243
        setModelName("cdkCNNModel" + currentID);
 
244
 
 
245
        int nrow = y.length;
 
246
        int ncol = x[0].length;
 
247
 
 
248
        if (nrow != x.length) {
 
249
            throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix");
 
250
        }
 
251
 
 
252
        nvar = ncol;
 
253
        noutput = y[0].length;
 
254
 
 
255
        Double[][] xx = new Double[nrow][ncol];
 
256
        Double[][] yy = new Double[nrow][noutput];
 
257
 
 
258
        for (int i = 0; i < nrow; i++) {
 
259
            for (int j = 0; j < ncol; j++) {
 
260
                xx[i][j] = new Double(x[i][j]);
 
261
            }
 
262
        }
 
263
        for (int i = 0; i < nrow; i++) {
 
264
            for (int j = 0; j < noutput; j++) {
 
265
                yy[i][j] = new Double(y[i][j]);
 
266
            }
 
267
        }
 
268
        params.put("x", xx);
 
269
        params.put("y", yy);
 
270
        params.put("size", new Integer(size));
 
271
        setDefaults();
 
272
    }
 
273
 
 
274
 
 
275
    /**
 
276
     * Sets parameters required for building a CNN model or using one for prediction.
 
277
     * <p/>
 
278
     * This function allows the caller to set the various parameters available
 
279
     * for the
 
280
     * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>
 
281
     * and
 
282
     * <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/predict.nnet.html" target="_top">predict.nnet</a>
 
283
     * R routines. See the R help pages for the details of the available
 
284
     * parameters.
 
285
     *
 
286
     * @param key A String containing the name of the parameter as described in the
 
287
     *            R help pages
 
288
     * @param obj An Object containing the value of the parameter
 
289
     * @throws QSARModelException if the type of the supplied value does not match the
 
290
     *                            expected type
 
291
     */
 
292
    public void setParameters(String key, Object obj) throws QSARModelException {
 
293
        // since we know the possible values of key we should check the coresponding
 
294
        // objects and throw errors if required. Note that this checking can't really check
 
295
        // for values (such as number of variables in the X matrix to build the model and the
 
296
        // X matrix to make new predictions) - these should be checked in functions that will
 
297
        // use these parameters. The main checking done here is for the class of obj and
 
298
        // some cases where the value of obj is not dependent on what is set before it
 
299
 
 
300
        if (key.equals("y")) {
 
301
            if (!(obj instanceof Double[][])) {
 
302
                throw new QSARModelException("The class of the 'y' object must be Double[][]");
 
303
            } else {
 
304
                noutput = ((Double[][]) obj)[0].length;
 
305
            }
 
306
        }
 
307
        if (key.equals("x")) {
 
308
            if (!(obj instanceof Double[][])) {
 
309
                throw new QSARModelException("The class of the 'x' object must be Double[][]");
 
310
            } else {
 
311
                nvar = ((Double[][]) obj)[0].length;
 
312
            }
 
313
        }
 
314
        if (key.equals("weights")) {
 
315
            if (!(obj instanceof Double[])) {
 
316
                throw new QSARModelException("The class of the 'weights' object must be Double[]");
 
317
            }
 
318
        }
 
319
        if (key.equals("size")) {
 
320
            if (!(obj instanceof Integer)) {
 
321
                throw new QSARModelException("The class of the 'size' object must be Integer");
 
322
            }
 
323
        }
 
324
        if (key.equals("subset")) {
 
325
            if (!(obj instanceof Integer[])) {
 
326
                throw new QSARModelException("The class of the 'size' object must be Integer[]");
 
327
            }
 
328
        }
 
329
        if (key.equals("Wts")) {
 
330
            if (!(obj instanceof Double[])) {
 
331
                throw new QSARModelException("The class of the 'Wts' object must be Double[]");
 
332
            }
 
333
        }
 
334
        if (key.equals("mask")) {
 
335
            if (!(obj instanceof Boolean[])) {
 
336
                throw new QSARModelException("The class of the 'mask' object must be Boolean[]");
 
337
            }
 
338
        }
 
339
        if (key.equals("linout") ||
 
340
                key.equals("entropy") ||
 
341
                key.equals("softmax") ||
 
342
                key.equals("censored") ||
 
343
                key.equals("skip") ||
 
344
                key.equals("Hess") ||
 
345
                key.equals("trace")) {
 
346
            if (!(obj instanceof Boolean)) {
 
347
                throw new QSARModelException("The class of the 'trace|skip|Hess|linout|entropy|softmax|censored' object must be Boolean");
 
348
            }
 
349
        }
 
350
        if (key.equals("rang") ||
 
351
                key.equals("decay") ||
 
352
                key.equals("abstol") ||
 
353
                key.equals("reltol")) {
 
354
            if (!(obj instanceof Double)) {
 
355
                throw new QSARModelException("The class of the 'reltol|abstol|decay|rang' object must be Double");
 
356
            }
 
357
        }
 
358
        if (key.equals("maxit") ||
 
359
                key.equals("MaxNWts")) {
 
360
            if (!(obj instanceof Integer)) {
 
361
                throw new QSARModelException("The class of the 'maxit|MaxNWts' object must be Integer");
 
362
            }
 
363
        }
 
364
 
 
365
        if (key.equals("newdata")) {
 
366
            if (!(obj instanceof Double[][])) {
 
367
                throw new QSARModelException("The class of the 'newdata' object must be Double[][]");
 
368
            }
 
369
        }
 
370
        params.put(key, obj);
 
371
    }
 
372
 
 
373
    /**
 
374
     * Fits a CNN regression model.
 
375
     * <p/>
 
376
     * This method calls the R function to fit a CNN regression model
 
377
     * to the specified dependent and independent variables. If an error
 
378
     * occurs in the R session, an exception is thrown.
 
379
     * <p/>
 
380
     * Note that, this method should be called prior to calling the various get
 
381
     * methods to obtain information regarding the fit.
 
382
     */
 
383
    public void build() throws QSARModelException {
 
384
        Double[][] x;
 
385
        Double[][] y;
 
386
        x = (Double[][]) this.params.get("x");
 
387
        y = (Double[][]) this.params.get("y");
 
388
        if (x.length != y.length)
 
389
            throw new QSARModelException("Number of observations does not match number of rows in the design matrix");
 
390
        if (nvar == 0) nvar = x[0].length;
 
391
 
 
392
        // lets build the model
 
393
        String paramVarName = loadParametersIntoRSession();
 
394
        String cmd = "buildCNN(\"" + getModelName() + "\", " + paramVarName + ")";
 
395
        REXP ret = rengine.eval(cmd);
 
396
        if (ret == null) {
 
397
            CNNRegressionModel.logger.debug("Error in buildCNN");
 
398
            throw new QSARModelException("Error in buildCNN");
 
399
        }
 
400
 
 
401
        // remove the parameter list
 
402
        rengine.eval("rm(" + paramVarName + ")");
 
403
 
 
404
        // save the model object on the Java side
 
405
        modelObject = ret.asList();
 
406
    }
 
407
 
 
408
    /**
 
409
     * Uses a fitted model to predict the response for new observations.
 
410
     * <p/>
 
411
     * This function uses a previously fitted model to obtain predicted values
 
412
     * for a new set of observations. If the model has not been fitted prior to this
 
413
     * call an exception will be thrown. Use <code>setParameters</code>
 
414
     * to set the values of the independent variable for the new observations and the
 
415
     * interval type.
 
416
     *
 
417
     * @throws org.openscience.cdk.qsar.model.QSARModelException
 
418
     *          if the model has not been built prior to a call
 
419
     *          to this method. Also if the number of independent variables specified for prediction
 
420
     *          is not the same as specified during model building
 
421
     */
 
422
    public void predict() throws QSARModelException {
 
423
 
 
424
        if (modelObject == null)
 
425
            throw new QSARModelException("Before calling predict() you must fit the model using build()");
 
426
 
 
427
        Double[][] newx = (Double[][]) params.get("newdata");
 
428
        if (newx[0].length != nvar) {
 
429
            throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting");
 
430
        }
 
431
 
 
432
        String pn = loadParametersIntoRSession();
 
433
        REXP ret = rengine.eval("predicCNN(\"" + getModelName() + "\", " + pn + ")");
 
434
        if (ret == null) throw new QSARModelException("Error occured in prediction");
 
435
 
 
436
        // remove the parameter list
 
437
        rengine.eval("rm(" + pn + ")");
 
438
 
 
439
        modelPredict = ret.asDoubleMatrix();
 
440
    }
 
441
 
 
442
    /**
 
443
     * Get the matrix of predicted values obtained from <code>predict.nnet<code>.
 
444
     *
 
445
     * @return The result of the prediction.
 
446
     */
 
447
    public double[][] getPredictions() {
 
448
        return modelPredict;
 
449
    }
 
450
 
 
451
    /**
 
452
     * Returns an <code>RList</code> object summarizing the nnet regression model.
 
453
     * <p/>
 
454
     * The return object can be queried via the <code>RList</code> methods to extract the
 
455
     * required components.
 
456
     *
 
457
     * @return A summary for the nnet regression model
 
458
     * @throws org.openscience.cdk.qsar.model.QSARModelException
 
459
     *          if the model has not been built prior to a call
 
460
     *          to this method
 
461
     */
 
462
    public RList summary() throws QSARModelException {
 
463
        if (modelObject == null)
 
464
            throw new QSARModelException("Before calling summary() you must fit the model using build()");
 
465
 
 
466
        REXP ret = rengine.eval("summary(" + getModelName() + ")");
 
467
        if (ret == null) {
 
468
            logger.debug("Error in summary()");
 
469
            throw new QSARModelException("Error in summary()");
 
470
        }
 
471
        return ret.asList();
 
472
    }
 
473
 
 
474
 
 
475
    /**
 
476
     * Loads a <code>'nnet'</code> object from disk in to the current session.
 
477
     *
 
478
     * @param fileName The disk file containing the model
 
479
     * @throws org.openscience.cdk.qsar.model.QSARModelException
 
480
     *          if the model being loaded is not a <code>'nnet'</code> model
 
481
     *          object  or the file does not exist
 
482
     */
 
483
    public void loadModel(String fileName) throws QSARModelException {
 
484
        File f = new File(fileName);
 
485
        if (!f.exists()) throw new QSARModelException(fileName + " does not exist");
 
486
 
 
487
        rengine.assign("tmpFileName", fileName);
 
488
        REXP ret = rengine.eval("loadModel(tmpFileName)");
 
489
        if (ret == null) throw new QSARModelException("Model could not be loaded");
 
490
 
 
491
        String name = ret.asList().at("name").asString();
 
492
        if (!isOfClass(name, "nnet")) {
 
493
            removeObject(name);
 
494
            throw new QSARModelException("Loaded object was not of class \'nnet\'");
 
495
        }
 
496
 
 
497
        modelObject = ret.asList().at("model").asList();
 
498
        setModelName(name);
 
499
        nvar = (int) getN()[0];
 
500
        noutput = (int) getN()[2];
 
501
    }
 
502
 
 
503
    /**
 
504
     * Loads a  <code>'nnet'</code> object from a serialized string into the current session.
 
505
     *
 
506
     * @param serializedModel A String containing the serialized version of the model
 
507
     * @param modelName       A String indicating the name of the model in the R session
 
508
     * @throws org.openscience.cdk.qsar.model.QSARModelException
 
509
     *          if the model being loaded is not a <code>'nnet'</code> model
 
510
     *          object
 
511
     */
 
512
    public void loadModel(String serializedModel, String modelName) throws QSARModelException {
 
513
        rengine.assign("tmpSerializedModel", serializedModel);
 
514
        rengine.assign("tmpModelName", modelName);
 
515
        REXP ret = rengine.eval("unserializeModel(tmpSerializedModel, tmpModelName)");
 
516
 
 
517
        if (ret == null) throw new QSARModelException("Model could not be unserialized");
 
518
 
 
519
        String name = ret.asList().at("name").asString();
 
520
        if (!isOfClass(name, "nnet")) {
 
521
            removeObject(name);
 
522
            throw new QSARModelException("Loaded object was not of class \'nnet\'");
 
523
        }
 
524
 
 
525
        modelObject = ret.asList().at("model").asList();
 
526
        setModelName(name);
 
527
        nvar = (int) getN()[0];
 
528
        noutput = (int) getN()[2];
 
529
    }
 
530
 
 
531
// Autogenerated code: assumes that 'modelObject' is
 
532
// a RList object
 
533
 
 
534
 
 
535
    /**
 
536
     * Gets the <code>censored</code> field of an <code>'nnet'</code> object.
 
537
     *
 
538
     * @return The value of the censored field
 
539
     */
 
540
    public RBool getCensored() {
 
541
        return modelObject.at("censored").asBool();
 
542
    }
 
543
 
 
544
    /**
 
545
     * Gets the <code>conn</code> field of an <code>'nnet'</code> object.
 
546
     *
 
547
     * @return The value of the conn field
 
548
     */
 
549
    public double[] getConn() {
 
550
        return modelObject.at("conn").asDoubleArray();
 
551
    }
 
552
 
 
553
    /**
 
554
     * Gets the <code>decay</code> field of an <code>'nnet'</code> object.
 
555
     *
 
556
     * @return The value of the decay field
 
557
     */
 
558
    public double getDecay() {
 
559
        return modelObject.at("decay").asDouble();
 
560
    }
 
561
 
 
562
    /**
 
563
     * Gets the <code>entropy</code> field of an <code>'nnet'</code> object.
 
564
     *
 
565
     * @return The value of the entropy field
 
566
     */
 
567
    public RBool getEntropy() {
 
568
        return modelObject.at("entropy").asBool();
 
569
    }
 
570
 
 
571
    /**
 
572
     * Gets the <code>fitted.values</code> field of an <code>'nnet'</code> object.
 
573
     *
 
574
     * @return The value of the fitted.values field
 
575
     */
 
576
    public double[][] getFittedValues() {
 
577
        return modelObject.at("fitted.values").asDoubleMatrix();
 
578
    }
 
579
 
 
580
    /**
 
581
     * Gets the <code>n</code> field of an <code>'nnet'</code> object.
 
582
     *
 
583
     * @return The value of the n field
 
584
     */
 
585
    public double[] getN() {
 
586
        return modelObject.at("n").asDoubleArray();
 
587
    }
 
588
 
 
589
    /**
 
590
     * Gets the <code>nconn</code> field of an <code>'nnet'</code> object.
 
591
     *
 
592
     * @return The value of the nconn field
 
593
     */
 
594
    public double[] getNconn() {
 
595
        return modelObject.at("nconn").asDoubleArray();
 
596
    }
 
597
 
 
598
    /**
 
599
     * Gets the <code>nsunits</code> field of an <code>'nnet'</code> object.
 
600
     *
 
601
     * @return The value of the nsunits field
 
602
     */
 
603
    public double getNsunits() {
 
604
        return modelObject.at("nsunits").asDouble();
 
605
    }
 
606
 
 
607
    /**
 
608
     * Gets the <code>nunits</code> field of an <code>'nnet'</code> object.
 
609
     *
 
610
     * @return The value of the nunits field
 
611
     */
 
612
    public double getNunits() {
 
613
        return modelObject.at("nunits").asDouble();
 
614
    }
 
615
 
 
616
    /**
 
617
     * Gets the <code>residuals</code> field of an <code>'nnet'</code> object.
 
618
     *
 
619
     * @return The value of the residuals field
 
620
     */
 
621
    public double[][] getResiduals() {
 
622
        return modelObject.at("residuals").asDoubleMatrix();
 
623
    }
 
624
 
 
625
    /**
 
626
     * Gets the <code>softmax</code> field of an <code>'nnet'</code> object.
 
627
     *
 
628
     * @return The value of the softmax field
 
629
     */
 
630
    public RBool getSoftmax() {
 
631
        return modelObject.at("softmax").asBool();
 
632
    }
 
633
 
 
634
    /**
 
635
     * Gets the <code>value</code> field of an <code>'nnet'</code> object.
 
636
     *
 
637
     * @return The value of the value field
 
638
     */
 
639
    public double getValue() {
 
640
        return modelObject.at("value").asDouble();
 
641
    }
 
642
 
 
643
    /**
 
644
     * Gets the <code>wts</code> field of an <code>'nnet'</code> object.
 
645
     *
 
646
     * @return The value of the wts field
 
647
     */
 
648
    public double[] getWts() {
 
649
        return modelObject.at("wts").asDoubleArray();
 
650
    }
 
651
 
 
652
 
 
653
}