1
1
#!/usr/bin/env python
2
from __future__ import division
2
from __future__ import division, with_statement
3
5
from contextlib import contextmanager
5
from cogent.util.dict_array import DictArrayTemplate
6
from .setting import Var, ConstVal
7
from .calculation import Calculator
8
from cogent.util import parallel
6
9
from cogent.maths.stats.distribution import chdtri
8
LOG = logging.getLogger('cogent')
10
11
__author__ = "Peter Maxwell"
11
12
__copyright__ = "Copyright 2007-2009, The Cogent Project"
12
13
__credits__ = ["Peter Maxwell", "Gavin Huttley"]
13
14
__license__ = "GPL"
15
16
__maintainer__ = "Peter Maxwell"
16
17
__email__ = "pm67nz@gmail.com"
17
18
__status__ = "Production"
604
610
return self.defn_for[par_name].usedDimensions()
606
612
def getParamValue(self, par_name, *args, **kw):
    """Return the current value of the definition called 'par_name'.

    Any extra positional/keyword arguments are passed through to select
    the scope (the position within the definition's dimensions). Despite
    the method's name, intermediate (derived) values can be looked up
    this way too.
    """
    # No dropoff / p: ask for a plain value callback, not an interval.
    value_of = self._makeValueCallback(None, None)
    definition = self.defn_for[par_name]
    position = definition._getPosnForScope(*args, **kw)
    return value_of(definition, position)
612
620
def getParamInterval(self, par_name, *args, **kw):
621
"""Confidence interval for 'par_name' found by adjusting the
622
single parameter until the final result falls by 'dropoff', which
623
can be specified directly or via 'p' as chdtri(1, p). Additional
624
arguments are taken to specify the scope."""
613
625
dropoff = kw.pop('dropoff', None)
614
626
p = kw.pop('p', None)
615
627
if dropoff is None and p is None:
625
637
def getParamValueDict(self, dimensions, p=None, dropoff=None,
626
638
params=None, xtol=None):
639
"""A dict tree of parameter values, with parameter names as the
640
top level keys, and the various dimensions ('edge', 'bin', etc.)
641
supplying lower level keys: edge names, bin names etc.
642
If 'p' or 'dropoff' is specified returns chi-square intervals instead
627
644
callback = self._makeValueCallback(dropoff, p, xtol)
628
645
if params is None:
629
646
params = self.getParamNames(scalar_only=True)
654
671
def updatesPostponed(self):
672
"Temporarily turn off calculation for faster input setting"
655
673
(old, self._update_suspended) = (self._update_suspended, True)
657
675
self._update_suspended = old
676
self._updateIntermediateValues()
660
def update(self, changed=None):
678
def updateIntermediateValues(self, changed=None):
    """Flag definitions as changed, then refresh intermediate values.

    'changed' is an iterable of definitions. When it is None, every
    definition in self.defns is flagged. Definitions are recorded by
    id() in the self._changed set, and self._updateIntermediateValues()
    is then invoked.
    """
    targets = self.defns if changed is None else changed
    self._changed.update(map(id, targets))
    self._updateIntermediateValues()
684
def _updateIntermediateValues(self):
667
685
if self._update_suspended:
669
687
# use topological sort order
675
693
self._changed.add(id(c))
676
694
self._changed.clear()
681
parts = [defn._local_repr(col_width, max_width) for defn in self.defns
682
if not isinstance(defn, SelectFromDimension)]
683
return '\n'.join(parts)
696
def assignAll(self, par_name, scope_spec=None, value=None,
697
lower=None, upper=None, const=None, independent=None):
699
defn = self.defn_for[par_name]
700
if not isinstance(defn, _LeafDefn):
701
args = ' and '.join(['"%s"' % a.name for a in defn.args])
702
msg = '"%s" is not settable as it is derived from %s.' % (
704
raise ValueError(msg)
707
const = defn.const_by_default
709
for scope in defn.interpretScopes(
710
independent=independent, **(scope_spec or {})):
712
values = defn.getAllDefaultValues(scope)
716
s_value = sum(values) / len(values)
718
if not numpy.all(value==s_value):
719
warnings.warn("Used mean of %s values" % par_name,
723
s_value = defn.unwrapValue(value)
725
setting = ConstVal(s_value)
727
(s_lower, s_upper) = defn.getCurrentBounds(scope)
728
if lower is not None: s_lower = lower
729
if upper is not None: s_upper = upper
730
setting = Var((s_lower, s_value, s_upper))
731
settings.append((scope, setting))
732
defn.assign(settings)
733
self.updateIntermediateValues([defn])
735
def measureEvalsPerSecond(self, *args, **kw):
    """Build a fresh calculator and delegate to its measureEvalsPerSecond.

    All arguments are forwarded unchanged, and the calculator's result
    is returned as-is.
    """
    calculator = self.makeCalculator()
    return calculator.measureEvalsPerSecond(*args, **kw)
738
def setupParallelContext(self, parallel_split=None):
739
comm = parallel.getCommunicator()
740
cpu_count = comm.Get_size()
741
if parallel_split is None:
742
parallel_split = cpu_count
743
with parallel.mpi_split(parallel_split) as parallel_context:
744
self.remaining_parallel_context = parallel.getCommunicator()
745
if 'parallel_context' in self.defn_for:
747
'parallel_context', value=parallel_context, const=True)
748
self.overall_parallel_context = comm
750
def makeCalculator(self, calculatorClass=None, variable=None, **kw):
753
for defn in self.defns:
755
(newcells, outputs) = defn.makeCells(input_soup, variable)
756
cells.extend(newcells)
757
input_soup[id(defn)] = outputs
758
if calculatorClass is None:
759
calculatorClass = Calculator
760
kw['overall_parallel_context'] = self.overall_parallel_context
761
kw['remaining_parallel_context'] = self.remaining_parallel_context
762
return calculatorClass(cells, input_soup, **kw)
764
def updateFromCalculator(self, calc):
766
for defn in self.defn_for.values():
767
if isinstance(defn, _LeafDefn):
768
defn.updateFromCalculator(calc)
770
self.updateIntermediateValues(changed)
772
def getNumFreeParams(self):
    """Sum getNumFreeParams() over the leaf definitions in self.defns.

    Definitions that are not _LeafDefn instances are not counted. The
    result is 0 when there are no leaf definitions.
    """
    total = 0
    for definition in self.defns:
        if isinstance(definition, _LeafDefn):
            total += definition.getNumFreeParams()
    return total
775
def optimise(self, *args, **kw):
776
return_calculator = kw.pop('return_calculator', False)
777
lc = self.makeCalculator()
778
lc.optimise(*args, **kw)
779
self.updateFromCalculator(lc)
780
if return_calculator:
783
def graphviz(self, **kw):
    """Return the graphviz(**kw) output of a newly built calculator."""
    return self.makeCalculator().graphviz(**kw)