1
package org.openscience.cdk.qsar.model.R2;
3
import org.openscience.cdk.qsar.model.QSARModelException;
4
import org.openscience.cdk.tools.LoggingTool;
5
import org.rosuda.JRI.RBool;
6
import org.rosuda.JRI.REXP;
7
import org.rosuda.JRI.RList;
10
import java.util.HashMap;
13
* A modeling class that provides a computational neural network regression model.
15
* When instantiated this class ensures that the R/Java interface has been
16
* initialized. The response and independent variables can be specified at construction
17
* time or via the <code>setParameters</code> method.
18
* The actual fitting procedure is carried out by <code>build</code> after which
19
* the model may be used to make predictions, via <code>predict</code>. An example of the use
20
* of this class is shown below:
28
* CNNRegressionModel cnnrm = new CNNRegressionModel(x,y,3);
29
* cnnrm.setParameters("Wts",wts);
32
* double fitValue = cnnrm.getValue();
34
* cnnrm.setParameters("newdata", newx);
35
* cnnrm.setParameters("type", "raw");
38
* double[][] preds = cnnrm.getPredictions();
39
* } catch (QSARModelException qme) {
40
* System.out.println(qme.toString());
43
* The above code snippet builds a 3-3-1 CNN model.
44
* Multiple output neurons are easily
45
* specified by supplying a matrix for y (i.e., double[][]) with the output variables
48
* Nearly all the arguments to
49
* <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet()</a> are
50
* supported via the <code>setParameters</code> method. The table below lists the names of the arguments,
51
* the expected type of the argument and the default setting for the arguments supported by this wrapper class.
53
* <table border=1 cellpadding=5>
56
* <th>Name</th><th>Java Type</th><th>Default</th><th>Notes</th>
60
* <tr><td>x</td><td>Double[][]</td><td>None</td><td>This must be set by the caller via the constructors or via <code>setParameters</code></td></tr>
61
* <tr><td>y</td><td>Double[][]</td><td>None</td><td>This must be set by the caller via the constructors or via <code>setParameters</code></td></tr>
62
* <tr><td>weights</td><td>Double[]</td><td>rep(1,nobs)</td><td>The default case weights is a vector of 1's equal in length to the number of observations, nobs</td></tr>
63
* <tr><td>size</td><td>Integer</td><td>None</td><td>This must be set by the caller via the constructors or via <code>setParameters</code></td></tr>
64
* <tr><td>subset</td><td>Integer[]</td><td>1:nobs</td><td>This is supposed to be an index vector specifying which observations are to be used in building the model. The default indicates that all should be used</td></tr>
65
* <tr><td>Wts</td><td>Double[]</td><td>runif(1,nwt)</td><td>The initial weight vector is set to a random vector of length equal to the number of weights if not set by the user</td></tr>
66
* <tr><td>mask</td><td>Boolean[]</td><td>rep(TRUE,nwt)</td><td>All weights are to be optimized unless otherwise specified by the user</td></tr>
67
* <tr><td>linout</td><td>Boolean</td><td>TRUE</td><td>Since this class performs regression this need not be changed</td></tr>
68
* <tr><td>entropy</td><td>Boolean</td><td>FALSE</td><td></td></tr>
69
* <tr><td>softmax</td><td>Boolean</td><td>FALSE</td><td></td></tr>
70
* <tr><td>censored</td><td>Boolean</td><td>FALSE</td><td></td></tr>
71
* <tr><td>skip</td><td>Boolean</td><td>FALSE</td><td></td></tr>
72
* <tr><td>rang</td><td>Double</td><td>0.7</td><td></td></tr>
73
* <tr><td>decay</td><td>Double</td><td>0.0</td><td></td></tr>
74
* <tr><td>maxit</td><td>Integer</td><td>100</td><td></td></tr>
75
* <tr><td>Hess</td><td>Boolean</td><td>FALSE</td><td></td></tr>
76
* <tr><td>trace</td><td>Boolean</td><td>FALSE</td><td>This wrapper switches tracing off by default</td></tr>
77
* <tr><td>MaxNWts</td><td>Integer</td><td>1000</td><td></td></tr>
78
* <tr><td>abstol</td><td>Double</td><td>1.0e-4</td><td></td></tr>
79
* <tr><td>reltol</td><td>Double</td><td>1.0e-8</td><td></td></tr>
84
* The values returned correspond to the various
85
* values returned by the <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a> and
86
* <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/predict.nnet.html" target="_top">predict.nnet</a> functions
89
* See {@link org.openscience.cdk.qsar.model.R.RModel} for details regarding the R and Java environment.
91
* @author Rajarshi Guha
92
* @cdk.require r-project
93
* @cdk.require java1.5+
95
* @cdk.keyword neural network
99
public class CNNRegressionModel extends RModel {
100
// Class-wide counter; the constructors read it and then increment it to derive the
// "cdkCNNModel<N>" name given to each new instance.
public static int globalID = 0;
101
// Number of output columns; set from the width of the y matrix or from a loaded model.
private int noutput = 0;
102
// Number of independent variables; set from the width of the x matrix or from a loaded model.
private int nvar = 0;
104
// Matrix of predicted values stored by predict(); null until then.
private double[][] modelPredict = null;
106
// Logger shared by all instances; it is (re)assigned by each constructor call.
private static LoggingTool logger;
108
private void setDefaults() {
109
// lets set the default values of the arguments that are specified
110
// to have default values in ?nnet
112
// these params are vectors that depend on user defined stuff
113
// so as a default we set them to FALSE so R can check if these
115
this.params.put("subset", Boolean.FALSE);
116
this.params.put("mask", Boolean.FALSE);
117
this.params.put("Wts", Boolean.FALSE);
118
this.params.put("weights", Boolean.FALSE);
120
this.params.put("linout", Boolean.TRUE); // we want only regression
121
this.params.put("entropy", Boolean.FALSE);
122
this.params.put("softmax", Boolean.FALSE);
123
this.params.put("censored", Boolean.FALSE);
124
this.params.put("skip", Boolean.FALSE);
125
this.params.put("rang", new Double(0.7));
126
this.params.put("decay", new Double(0.0));
127
this.params.put("maxit", new Integer(100));
128
this.params.put("Hess", Boolean.FALSE);
129
this.params.put("trace", Boolean.FALSE); // no need to see output
130
this.params.put("MaxNWts", new Integer(1000));
131
this.params.put("abstol", new Double(1.0e-4));
132
this.params.put("reltol", new Double(1.0e-8));
136
* Constructs a CNNRegressionModel object.
138
* This constructor allows the user to simply set up an instance of a CNN
139
* regression modeling class. This constructor simply sets the name for this
140
* instance. It is expected all the relevant parameters for modeling will be
141
* set at a later point.
143
* Other parameters that are required to be set should be done via
144
* calls to <code>setParameters</code>. A number of parameters are set to the
145
* defaults as specified in the manpage for
146
* <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>.
148
// No-argument constructor: creates the logger and an empty parameter map, then names
// this instance "cdkCNNModel<N>" using the class-wide counter.
// NOTE(review): the remainder of this constructor is not visible in this excerpt.
public CNNRegressionModel() throws QSARModelException {
150
// the logger field is static, so this replaces the logger shared by all instances
logger = new LoggingTool(this);
152
params = new HashMap();
153
// the counter is read and then incremented in two separate, unsynchronized statements
int currentID = CNNRegressionModel.globalID;
154
CNNRegressionModel.globalID++;
155
setModelName("cdkCNNModel" + currentID);
163
* Constructs a CNNRegressionModel object.
165
* This constructor allows the user to specify the dependent and
166
* independent variables along with the number of hidden layer neurons.
167
* This constructor is suitable for cases when there is a single output
168
* neuron. If the number of rows of the design matrix is not equal to
169
* the number of observations in y an exception will be thrown.
171
* Other parameters that are required to be set should be done via
172
* calls to <code>setParameters</code>. A number of parameters are set to the
173
* defaults as specified in the manpage for
174
* <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>.
176
* @param x An array of independent variables. Observations should be in
177
* the rows and variables in the columns.
178
* @param y An array (single column) of observed values
179
* @param size The number of hidden layer neurons
180
* @throws QSARModelException if the number of observations in x and y do not match
182
// Constructor for a single response column: boxes the primitive x matrix and y vector
// into Double[][] matrices and stores the hidden layer size under the "size" key.
// NOTE(review): several statements of this constructor (among them the declaration of
// nrow and whatever follows the copy loops) are not visible in this excerpt.
public CNNRegressionModel(double[][] x, double[] y, int size) throws QSARModelException {
184
// the logger field is static, so this replaces the logger shared by all instances
logger = new LoggingTool(this);
186
params = new HashMap();
187
// derive the model name from the class-wide counter, then advance the counter
int currentID = CNNRegressionModel.globalID;
188
CNNRegressionModel.globalID++;
189
setModelName("cdkCNNModel" + currentID);
192
// the column count is taken from the first row, so x must have at least one row
int ncol = x[0].length;
194
// nrow is declared outside the visible lines; presumably it is y.length - confirm
if (nrow != x.length) {
195
throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix");
201
// boxed copies of the inputs; y becomes a matrix with a single column
Double[][] xx = new Double[nrow][ncol];
202
Double[][] yy = new Double[nrow][1];
204
for (int i = 0; i < nrow; i++) {
205
yy[i][0] = new Double(y[i]);
206
for (int j = 0; j < ncol; j++) {
207
xx[i][j] = new Double(x[i][j]);
212
// number of hidden layer neurons
params.put("size", new Integer(size));
217
* Constructs a CNNRegressionModel object.
219
* This constructor allows the user to specify the dependent and
220
* independent variables along with the number of hidden layer neurons.
221
* This constructor is suitable for cases when there are multiple output
222
* neurons. If the number of rows of the design matrix is not equal to
223
* the number of observations in y an exception will be thrown.
225
* Other parameters that are required to be set should be done via
226
* calls to <code>setParameters</code>. A number of parameters are set to the
227
* defaults as specified in the manpage for
228
* <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>.
230
* @param x An array of independent variables. Observations should be in
231
* the rows and variables in the columns.
232
* @param y An array (multiple columns) of observed values
233
* @param size The number of hidden layer neurons
234
* @throws QSARModelException if the number of observations in x and y do not match
236
// Constructor for one or more response columns: boxes the primitive x and y matrices
// into Double[][] matrices, records the number of response columns and stores the
// hidden layer size under the "size" key.
// NOTE(review): several statements of this constructor (among them the declaration of
// nrow and whatever follows the copy loops) are not visible in this excerpt.
public CNNRegressionModel(double[][] x, double[][] y, int size) throws QSARModelException {
238
// the logger field is static, so this replaces the logger shared by all instances
logger = new LoggingTool(this);
240
params = new HashMap();
241
// derive the model name from the class-wide counter, then advance the counter
int currentID = CNNRegressionModel.globalID;
242
CNNRegressionModel.globalID++;
243
setModelName("cdkCNNModel" + currentID);
246
// the column count is taken from the first row, so x must have at least one row
int ncol = x[0].length;
248
// nrow is declared outside the visible lines; presumably it is y.length - confirm
if (nrow != x.length) {
249
throw new QSARModelException("The number of values for the dependent variable does not match the number of rows of the design matrix");
253
// one output per column of y; read from the first row, so y must have at least one row
noutput = y[0].length;
255
// boxed copies of the inputs
Double[][] xx = new Double[nrow][ncol];
256
Double[][] yy = new Double[nrow][noutput];
258
// copy the design matrix
for (int i = 0; i < nrow; i++) {
259
for (int j = 0; j < ncol; j++) {
260
xx[i][j] = new Double(x[i][j]);
263
// copy the response matrix
for (int i = 0; i < nrow; i++) {
264
for (int j = 0; j < noutput; j++) {
265
yy[i][j] = new Double(y[i][j]);
270
// number of hidden layer neurons
params.put("size", new Integer(size));
276
* Sets parameters required for building a CNN model or using one for prediction.
278
* This function allows the caller to set the various parameters available
280
* <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/nnet.html" target="_top">nnet</a>
282
* <a href="http://www.maths.lth.se/help/R/.R/library/nnet/html/predict.nnet.html" target="_top">predict.nnet</a>
283
* R routines. See the R help pages for the details of the available
286
* @param key A String containing the name of the parameter as described in the
288
* @param obj An Object containing the value of the parameter
289
* @throws QSARModelException if the type of the supplied value does not match the
292
public void setParameters(String key, Object obj) throws QSARModelException {
293
// since we know the possible values of key we should check the coresponding
294
// objects and throw errors if required. Note that this checking can't really check
295
// for values (such as number of variables in the X matrix to build the model and the
296
// X matrix to make new predictions) - these should be checked in functions that will
297
// use these parameters. The main checking done here is for the class of obj and
298
// some cases where the value of obj is not dependent on what is set before it
300
if (key.equals("y")) {
301
if (!(obj instanceof Double[][])) {
302
throw new QSARModelException("The class of the 'y' object must be Double[][]");
304
noutput = ((Double[][]) obj)[0].length;
307
if (key.equals("x")) {
308
if (!(obj instanceof Double[][])) {
309
throw new QSARModelException("The class of the 'x' object must be Double[][]");
311
nvar = ((Double[][]) obj)[0].length;
314
if (key.equals("weights")) {
315
if (!(obj instanceof Double[])) {
316
throw new QSARModelException("The class of the 'weights' object must be Double[]");
319
if (key.equals("size")) {
320
if (!(obj instanceof Integer)) {
321
throw new QSARModelException("The class of the 'size' object must be Integer");
324
if (key.equals("subset")) {
325
if (!(obj instanceof Integer[])) {
326
throw new QSARModelException("The class of the 'size' object must be Integer[]");
329
if (key.equals("Wts")) {
330
if (!(obj instanceof Double[])) {
331
throw new QSARModelException("The class of the 'Wts' object must be Double[]");
334
if (key.equals("mask")) {
335
if (!(obj instanceof Boolean[])) {
336
throw new QSARModelException("The class of the 'mask' object must be Boolean[]");
339
if (key.equals("linout") ||
340
key.equals("entropy") ||
341
key.equals("softmax") ||
342
key.equals("censored") ||
343
key.equals("skip") ||
344
key.equals("Hess") ||
345
key.equals("trace")) {
346
if (!(obj instanceof Boolean)) {
347
throw new QSARModelException("The class of the 'trace|skip|Hess|linout|entropy|softmax|censored' object must be Boolean");
350
if (key.equals("rang") ||
351
key.equals("decay") ||
352
key.equals("abstol") ||
353
key.equals("reltol")) {
354
if (!(obj instanceof Double)) {
355
throw new QSARModelException("The class of the 'reltol|abstol|decay|rang' object must be Double");
358
if (key.equals("maxit") ||
359
key.equals("MaxNWts")) {
360
if (!(obj instanceof Integer)) {
361
throw new QSARModelException("The class of the 'maxit|MaxNWts' object must be Integer");
365
if (key.equals("newdata")) {
366
if (!(obj instanceof Double[][])) {
367
throw new QSARModelException("The class of the 'newdata' object must be Double[][]");
370
params.put(key, obj);
374
* Fits a CNN regression model.
376
* This method calls the R function to fit a CNN regression model
377
* to the specified dependent and independent variables. If an error
378
* occurs in the R session, an exception is thrown.
380
* Note that, this method should be called prior to calling the various get
381
* methods to obtain information regarding the fit.
383
// Fits the model: checks the stored x and y matrices, evaluates buildCNN in R for this
// model name with the current parameters, and keeps the returned list as the model object.
// NOTE(review): some statements of this method (e.g. the declarations of x and y and the
// condition guarding the error branch below) are not visible in this excerpt.
public void build() throws QSARModelException {
386
// fetch the design matrix and the responses from the parameter map
x = (Double[][]) this.params.get("x");
387
y = (Double[][]) this.params.get("y");
388
// both matrices must have the same number of rows
if (x.length != y.length)
389
throw new QSARModelException("Number of observations does not match number of rows in the design matrix");
390
// if no variable count has been recorded yet, take it from the first row of x
if (nvar == 0) nvar = x[0].length;
392
// let's build the model
393
String paramVarName = loadParametersIntoRSession();
394
String cmd = "buildCNN(\"" + getModelName() + "\", " + paramVarName + ")";
395
REXP ret = rengine.eval(cmd);
397
// error branch; presumably taken only when ret is null - confirm against the full source
CNNRegressionModel.logger.debug("Error in buildCNN");
398
throw new QSARModelException("Error in buildCNN");
401
// remove the parameter list
402
rengine.eval("rm(" + paramVarName + ")");
404
// save the model object on the Java side
405
modelObject = ret.asList();
409
* Uses a fitted model to predict the response for new observations.
411
* This function uses a previously fitted model to obtain predicted values
412
* for a new set of observations. If the model has not been fitted prior to this
413
* call an exception will be thrown. Use <code>setParameters</code>
414
* to set the values of the independent variable for the new observations and the
417
* @throws org.openscience.cdk.qsar.model.QSARModelException
418
* if the model has not been built prior to a call
419
* to this method. Also if the number of independent variables specified for prediction
420
* is not the same as specified during model building
422
public void predict() throws QSARModelException {
424
if (modelObject == null)
425
throw new QSARModelException("Before calling predict() you must fit the model using build()");
427
Double[][] newx = (Double[][]) params.get("newdata");
428
if (newx[0].length != nvar) {
429
throw new QSARModelException("Number of independent variables used for prediction must match those used for fitting");
432
String pn = loadParametersIntoRSession();
433
REXP ret = rengine.eval("predicCNN(\"" + getModelName() + "\", " + pn + ")");
434
if (ret == null) throw new QSARModelException("Error occured in prediction");
436
// remove the parameter list
437
rengine.eval("rm(" + pn + ")");
439
modelPredict = ret.asDoubleMatrix();
443
* Get the matrix of predicted values obtained from <code>predict.nnet</code>.
445
* @return The result of the prediction.
447
// NOTE(review): the body of this accessor is not visible in this excerpt; presumably it
// returns the matrix stored by predict() - confirm.
public double[][] getPredictions() {
452
* Returns an <code>RList</code> object summarizing the nnet regression model.
454
* The return object can be queried via the <code>RList</code> methods to extract the
455
* required components.
457
* @return A summary for the nnet regression model
458
* @throws org.openscience.cdk.qsar.model.QSARModelException
459
* if the model has not been built prior to a call
462
// Evaluates R's summary() on the object named after this model.
// NOTE(review): the condition guarding the error branch and the return statement are
// not visible in this excerpt.
public RList summary() throws QSARModelException {
463
// a model object must have been stored before a summary can be requested
if (modelObject == null)
464
throw new QSARModelException("Before calling summary() you must fit the model using build()");
466
// the model name is used as the name of the variable inside the R session
REXP ret = rengine.eval("summary(" + getModelName() + ")");
468
// error branch; presumably taken only when ret is null - confirm against the full source
logger.debug("Error in summary()");
469
throw new QSARModelException("Error in summary()");
476
* Loads a <code>'nnet'</code> object from disk in to the current session.
478
* @param fileName The disk file containing the model
479
* @throws org.openscience.cdk.qsar.model.QSARModelException
480
* if the model being loaded is not a <code>'nnet'</code> model
481
* object or the file does not exist
483
// Loads a model from a file on disk via the R function loadModel, checks that the loaded
// object is of class 'nnet' and stores it as the current model object.
// NOTE(review): some statements of this method are not visible in this excerpt.
public void loadModel(String fileName) throws QSARModelException {
484
// fail early on the Java side when the file is missing
File f = new File(fileName);
485
if (!f.exists()) throw new QSARModelException(fileName + " does not exist");
487
// the file name is passed to R through a temporary variable
rengine.assign("tmpFileName", fileName);
488
REXP ret = rengine.eval("loadModel(tmpFileName)");
489
if (ret == null) throw new QSARModelException("Model could not be loaded");
491
// the returned list carries the name of the loaded object under "name"
String name = ret.asList().at("name").asString();
492
if (!isOfClass(name, "nnet")) {
494
throw new QSARModelException("Loaded object was not of class \'nnet\'");
497
// ... and the model itself under "model"
modelObject = ret.asList().at("model").asList();
499
// element 0 of the model's "n" field is taken as the number of inputs,
// element 2 as the number of outputs
nvar = (int) getN()[0];
500
noutput = (int) getN()[2];
504
* Loads a <code>'nnet'</code> object from a serialized string into the current session.
506
* @param serializedModel A String containing the serialized version of the model
507
* @param modelName A String indicating the name of the model in the R session
508
* @throws org.openscience.cdk.qsar.model.QSARModelException
509
* if the model being loaded is not a <code>'nnet'</code> model
512
// Loads a model from its serialized string form via the R function unserializeModel,
// checks that the loaded object is of class 'nnet' and stores it as the current model object.
// NOTE(review): some statements of this method are not visible in this excerpt.
public void loadModel(String serializedModel, String modelName) throws QSARModelException {
513
// both arguments are passed to R through temporary variables
rengine.assign("tmpSerializedModel", serializedModel);
514
rengine.assign("tmpModelName", modelName);
515
REXP ret = rengine.eval("unserializeModel(tmpSerializedModel, tmpModelName)");
517
if (ret == null) throw new QSARModelException("Model could not be unserialized");
519
// the returned list carries the name of the loaded object under "name"
String name = ret.asList().at("name").asString();
520
if (!isOfClass(name, "nnet")) {
522
throw new QSARModelException("Loaded object was not of class \'nnet\'");
525
// ... and the model itself under "model"
modelObject = ret.asList().at("model").asList();
527
// element 0 of the model's "n" field is taken as the number of inputs,
// element 2 as the number of outputs
nvar = (int) getN()[0];
528
noutput = (int) getN()[2];
531
// Autogenerated code: assumes that 'modelObject' is
536
* Gets the <code>censored</code> field of an <code>'nnet'</code> object.
538
* @return The value of the censored field
540
public RBool getCensored() {
541
return modelObject.at("censored").asBool();
545
* Gets the <code>conn</code> field of an <code>'nnet'</code> object.
547
* @return The value of the conn field
549
public double[] getConn() {
550
return modelObject.at("conn").asDoubleArray();
554
* Gets the <code>decay</code> field of an <code>'nnet'</code> object.
556
* @return The value of the decay field
558
public double getDecay() {
559
return modelObject.at("decay").asDouble();
563
* Gets the <code>entropy</code> field of an <code>'nnet'</code> object.
565
* @return The value of the entropy field
567
public RBool getEntropy() {
568
return modelObject.at("entropy").asBool();
572
* Gets the <code>fitted.values</code> field of an <code>'nnet'</code> object.
574
* @return The value of the fitted.values field
576
public double[][] getFittedValues() {
577
return modelObject.at("fitted.values").asDoubleMatrix();
581
* Gets the <code>n</code> field of an <code>'nnet'</code> object.
583
* @return The value of the n field
585
public double[] getN() {
586
return modelObject.at("n").asDoubleArray();
590
* Gets the <code>nconn</code> field of an <code>'nnet'</code> object.
592
* @return The value of the nconn field
594
public double[] getNconn() {
595
return modelObject.at("nconn").asDoubleArray();
599
* Gets the <code>nsunits</code> field of an <code>'nnet'</code> object.
601
* @return The value of the nsunits field
603
public double getNsunits() {
604
return modelObject.at("nsunits").asDouble();
608
* Gets the <code>nunits</code> field of an <code>'nnet'</code> object.
610
* @return The value of the nunits field
612
public double getNunits() {
613
return modelObject.at("nunits").asDouble();
617
* Gets the <code>residuals</code> field of an <code>'nnet'</code> object.
619
* @return The value of the residuals field
621
public double[][] getResiduals() {
622
return modelObject.at("residuals").asDoubleMatrix();
626
* Gets the <code>softmax</code> field of an <code>'nnet'</code> object.
628
* @return The value of the softmax field
630
public RBool getSoftmax() {
631
return modelObject.at("softmax").asBool();
635
* Gets the <code>value</code> field of an <code>'nnet'</code> object.
637
* @return The value of the value field
639
public double getValue() {
640
return modelObject.at("value").asDouble();
644
* Gets the <code>wts</code> field of an <code>'nnet'</code> object.
646
* @return The value of the wts field
648
// Looks up the "wts" component of the stored model object and returns it as an array of
// doubles. No check is made here that a model object has been stored.
public double[] getWts() {
649
return modelObject.at("wts").asDoubleArray();