~daniele-bigoni/tensortoolbox/tt-docs

« back to all changes in this revision

Viewing changes to TensorToolbox/core/tensor_wrapper.py

  • Committer: Daniele Bigoni
  • Date: 2015-01-19 11:10:20 UTC
  • Revision ID: dabi@dtu.dk-20150119111020-p0uckg4ab3xqzf47
merged with research

Show diffs side-by-side

added added

removed removed

Lines of Context:
46
46
except ImportError:
47
47
    MPI_SUPPORT = False
48
48
 
49
 
from TensorToolbox.core import idxunfold, idxfold, storable_object
 
49
from TensorToolbox.core import idxunfold, idxfold, expand_idxs, storable_object
50
50
 
51
51
class TensorWrapper(storable_object):
52
52
    """ A tensor wrapper is a data structure W that given a multi-dimensional scalar function f(X,params), and a set of coordinates {{x1}_i1,{x2}_i2,..,{xd}_id} indexed by the multi index {i1,..,id}, let you access f(x1_i1,..,xd_id) by W[i1,..,id]. The function evaluations are performed "as needed" and stored for future accesses.
54
54
    :param f: multi-dimensional scalar function of type f(x,params), x being a list.
55
55
    :param list X: list of arrays with coordinates for each dimension
56
56
    :param tuple params: parameters to be passed to function f
 
57
    :param list W: list of arrays with weights for each dimension
57
58
    :param int Q: power to which round all the dimensions to.
58
59
    :param string twtype: 'array' values are stored whenever computed, 'view' values are never stored and function f is always called
59
60
    :param dict data: initialization data of the Tensor Wrapper (already computed entries)
66
67
    :param int maxprocs: Number of processors to be used in the function evaluation (MPI)
67
68
    :param bool marshal_f: whether to marshal the function or not
68
69
 
 
70
    Several shape parameters are used by the TensorWrapper in order to keep track of reshaping and slicings, without affecting the underlying shape of the tensor which is always preserved. The following table lists the existing shapes and their meaning.
 
71
 
 
72
    +------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 
73
    | Shape attribute/function                       | Applied transformations (ordered)    | Description                                                                                                                                                                                                                                                      |
 
74
    +================================================+======================================+==================================================================================================================================================================================================================================================================+
 
75
    | :py:meth:`~TensorWrapper.get_global_shape`     | None                                 | The original shape of the tensor. This shape can be modified only  through a refinement of the grid using the function :py:meth:`~TensorWrapper.refine`.                                                                                                         |
 
76
    +------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 
77
    | :py:meth:`~TensorWrapper.get_view_shape`       | VIEW                                 | The particular view of the tensor, defined by the view in :py:attr:`TensorWrapper.maps` set active using :py:meth:`~TensorWrapper.set_active_view`.                                                                                                              |
 
78
    +------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 
79
    | :py:meth:`~TensorWrapper.get_extended_shape`   | VIEW, QUANTICS                       | The shape of the extended tensor in order to allow for the quantics folding with basis :py:attr:`TensorWrapper.Q`.                                                                                                                                               |
 
80
    +------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 
81
    | :py:meth:`~TensorWrapper.get_ghost_shape`      | VIEW, QUANTICS, RESHAPE              | The shape of the tensor reshaped using :py:meth:`~TensorWrapper.reshape`. If a Quantics folding is pre-applied, then the reshape is on the extended shape.                                                                                                       |
 
82
    +------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 
83
    | :py:meth:`~TensorWrapper.get_shape`            | VIEW, QUANTICS, RESHAPE, FIX_IDXS    | The shape of the tensor with :py:meth:`~TensorWrapper.fix_indices` and :py:meth:`~TensorWrapper.release_indices`. This is the view that is always used when the tensor is accessed through the function :py:meth:`~TensorWrapper.__getitem__` (i.e. ``TW[...]``) |
 
84
    +------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
 
85
 
 
86
    .. document private functions
 
87
    .. automethod:: __getitem__
69
88
    """
70
89
    
71
90
    logger = logging.getLogger(__name__)
75
94
                                  "%Y-%m-%d %H:%M:%S")
76
95
    ch.setFormatter(formatter)
77
96
    logger.addHandler(ch)
 
97
 
 
98
    # FILL_VALUE = 0.0            # Value used to fill powQ tensors
78
99
    
79
 
    def __init__(self, f, X, params=None, Q=None, twtype='array', data=None, dtype=object,
 
100
    def __init__(self, f, X, params=None, W=None, twtype='array', data=None, dtype=object,
80
101
                 store_file = "", store_object = None, store_freq=None, store_overwrite=False,
81
102
                 empty=False,
82
103
                 maxprocs=None,
90
111
        # List of attributes
91
112
        self.f_code = None               # Marshal string of the function f
92
113
        self.X = None
 
114
        self.W = None
93
115
        self.params = None
94
 
        self.Q = None
95
116
        self.dtype = object
96
 
        self.shape = None
97
 
        self.ndim = None
98
 
        self.size = None
99
117
        self.twtype = None
100
118
        self.data = None
101
119
 
102
 
        self.serialize_list.extend( ['X', 'dtype', 'params', 'Q', 'shape', 'ndim', 'size', 'twtype', 'f_code'] )
 
120
        self.maps = {}          # Multiple views
 
121
        self.view = None        # Multiple views
 
122
 
 
123
        self.serialize_list.extend( ['X', 'W', 'dtype', 'params', 'twtype', 'f_code', 'maps', 'view'] )
103
124
        self.subserialize_list.extend( [] )
104
125
 
105
126
        # Attributes which are not serialized and need to be reset on reload
106
 
        self.__ghost_shape = None
107
127
        self.f = None
108
128
        self.store_object = None
109
129
        self.__maxprocs = None             # Number of processors to be used (MPI)
110
130
 
111
 
        self.fix_idxs = []
112
 
        self.fix_dims = []
113
 
 
 
131
        self.shape = None
 
132
        self.ndim = None
 
133
        self.size = None
 
134
        
114
135
        self.stored_keys = None # Set of stored keys (to improve the saving speed)
 
136
 
 
137
        self.active_weights = None
115
138
        # End list of attributes
116
139
        #################################
117
 
        
 
140
 
 
141
        self.active_weights = False        
118
142
        self.stored_keys = set()
119
143
        if not empty: 
120
144
            self.set_f(f,marshal_f)
121
145
            self.params = params
122
146
            self.X = X
123
 
            self.Q = Q
 
147
            self.W = W
 
148
            self.maps['full'] = { 'X_map': self.X, 
 
149
                                  'idx_map': [ range( len(x) ) for x in self.X ],
 
150
                                  'Q': None,
 
151
                                  'ghost_shape': None,
 
152
                                  'fix_idxs': [],
 
153
                                  'fix_dims': [] }
 
154
            self.set_active_view('full')
124
155
            self.shape = self.get_shape()
125
 
            self.ndim = len(self.shape)
 
156
            self.ndim = self.get_ndim()
126
157
            self.size = self.get_size()
127
158
            self.twtype = twtype
128
159
            self.dtype = dtype
150
181
        self.fix_idxs = []
151
182
        self.fix_dims = []
152
183
        self.__maxprocs = None
153
 
        self.reset_shape()
 
184
        self.reset_ghost_shape()
 
185
        self.active_weights = False
 
186
 
 
187
    def set_weights(self, W):
 
188
        """ Set a new list of weights for the tensor
 
189
        :param list W: list of np.ndarray with weights for each dimension
 
190
        """
 
191
        if len(W) != len(self.get_global_shape()):
 
192
            raise ValueError("The provided set of weights has not the right dimension: len(W)=%d, dim=%d" % (len(W),len(self.get_global_shape())))
 
193
        if any( [ len(wi) != si for wi,si in zip(W,self.get_global_shape()) ] ):
 
194
            raise ValueError("The provided set of weights contains at least one dimension which is not conformal with the tensor grid.")
 
195
        self.W = W
 
196
 
 
197
    def set_active_weights(self,flag):
 
198
        """ Set whether to use the weights or not.
 
199
 
 
200
        :param bool flag: If ``True`` the items returned by the Tensor Wrapper will be weighted according to the weights provided at construction time. If ``False`` the original values of the function will be returned.
 
201
        """
 
202
        self.active_weights = flag
154
203
 
155
204
    def getstate(self):
156
205
        return self.__getstate__();
226
275
    def copy(self):
227
276
        return TensorWrapper(self.f, self.X, params=self.params, twtype=self.twtype, data=self.data.copy())
228
277
 
229
 
    def set_Q(self, Q):
230
 
        self.Q = Q
231
 
        self.__ghost_shape = self.get_q_shape()
232
 
        self.update_shape()
233
 
 
234
 
    def get_size(self):
235
 
        """ Always returns the size of the tensor view
236
 
        
237
 
        .. note: use :py:meth:`TensorWrapper.get_global_size` to get the size of the original tensor
238
 
        """
239
 
        return reduce(operator.mul, self.get_shape(), 1)
240
 
    
241
 
    def get_ndim(self):
242
 
        """ Always returns the number of dimensions of the tensor view
243
 
        
244
 
        .. note: use :py:meth:`TensorWrapper.get_global_ndim` to get the number of dimensions of the original tensor
245
 
        """
246
 
        return len(self.get_shape())
247
 
 
248
 
    def get_shape(self):
249
 
        """ Always returns the shape of the actual tensor view
250
 
        
251
 
        .. note: use :py:meth:`TensorWrapper.get_global_shape` to get the shape of the original tensor
252
 
        """
253
 
        if self.__ghost_shape == None:
254
 
            dim = [ len(coord) for dim,coord in enumerate(self.X) if not (dim in self.fix_dims) ]
255
 
        else: 
256
 
            dim = [ s for dim,s in enumerate(self.__ghost_shape) if not (dim in self.fix_dims) ]
257
 
        return tuple(dim)
258
 
    
259
 
    def get_full_shape(self):
260
 
        """ Always returns the shape of the reshaped tensor tensor with no fixed indices
261
 
        """
262
 
        if self.__ghost_shape == None:
263
 
            dim = self.get_q_shape()
264
 
        else:
265
 
            dim = self.__ghost_shape
266
 
        return tuple(dim)
267
 
 
268
 
    def get_full_ndim(self):
269
 
        """ Always returns the ndim of the reshaped tensor tensor with no fixed indices
270
 
        """
271
 
        return len(self.get_full_shape())
272
 
    
273
 
    def get_full_size(self):
274
 
        """ Always returns the size of the reshaped tensor tensor with no fixed indices
275
 
        """
276
 
        return reduce(operator.mul, self.get_full_shape(), 1)
277
 
 
278
 
    def get_q_shape(self):
279
 
        """ Always returns the shape of the tensor rounded to the next power of Q if Q!=None. Otherwise returns the shape of the underlying tensor.
280
 
        """
281
 
        if self.Q == None:
282
 
            dim = self.get_global_shape()
283
 
        else:
284
 
            dim = [ self.Q**(int(math.log(s-0.5,self.Q))+1) for s in self.get_global_shape() ]
285
 
        return tuple( dim )
286
 
    
287
 
    def get_q_size(self):
288
 
        """Always returns the size of the tensor rounded to the next power of Q
289
 
        """
290
 
        return reduce(operator.mul, self.get_q_shape(), 1)
291
 
    
 
278
    #####################################################
 
279
    #               SHAPES AND VIEWS                    #
 
280
    #####################################################
 
281
 
 
282
    ##########
 
283
    # GLOBAL #
 
284
    ##########
292
285
    def get_global_shape(self):
293
286
        """ Always returns the shape of the underlying tensor
294
287
        """
304
297
        """ Always returns the size of the underlying tensor
305
298
        """
306
299
        return reduce(operator.mul, self.get_global_shape(), 1)
307
 
 
308
 
    def get_fill_level(self):
309
 
        if self.twtype == 'view': return 0
310
 
        else: return len(self.data)
311
 
 
 
300
    
 
301
    #########
 
302
    # VIEWS #
 
303
    #########
 
304
    def get_view_shape(self):
 
305
        """ Always returns the shape of the current view
 
306
        """
 
307
        dim = [ len(coord) for coord in self.maps[self.view]['X_map'] ]
 
308
        return tuple(dim)
 
309
 
 
310
    def get_view_ndim(self):
 
311
        """ Always returns the ndim of the current view
 
312
        """
 
313
        return len(self.maps[self.view]['X_map'])
 
314
    
 
315
    def get_view_size(self):
 
316
        """ Always returns the size of the current view
 
317
        """
 
318
        return reduce(operator.mul, self.get_view_shape(), 1)
 
319
 
 
320
    def set_active_view(self, view):
 
321
        """ Set a view among the ones in ``self.maps``.
 
322
        
 
323
        :param str view: name of the view to be set as active
 
324
        """
 
325
        self.view = view
 
326
        self.update_shape()
 
327
        
 
328
    def set_view(self, view, X_map, tol=None):
 
329
        """ Set or add a view to ``self.maps``. This resest all the existing reshape parameters in existing views.
 
330
        
 
331
        :param str view: name of the view to be added
 
332
        :param list X_map: list of coordinates of the new view
 
333
        :param float tol: tolerance for the matching of coordinates
 
334
        """
 
335
        if tol == None: tol = np.spacing(1)
 
336
        idx_map = []
 
337
        for d, x in enumerate(X_map):
 
338
            if any(x[i] > x[i+1] for i in range( len(x)-1 )):
 
339
                raise ValueError("TensorWrapperView: the input coordinates must be sorted")
 
340
            
 
341
            idx_map.append([])
 
342
            j = 0
 
343
            for val in x:
 
344
                while j < len(self.X[d]):
 
345
                    if abs( val - self.X[d][j] ) <= tol :
 
346
                        idx_map[-1].append( j )
 
347
                        break
 
348
                    j += 1
 
349
                if j == len(self.X[d]):
 
350
                    raise ValueError("TensorWrapperView: the input coordinates are not a subset of the full coordinates")
 
351
        self.maps[view] = { 'X_map': X_map, 
 
352
                            'idx_map': idx_map,
 
353
                            'Q': None,
 
354
                            'ghost_shape': None,
 
355
                            'fix_idxs': [],
 
356
                            'fix_dims': []}
 
357
 
 
358
    def refine(self, X_new, tol=None):
 
359
        """ Refine the global discretization. The new discretization must contain the old one.
 
360
 
 
361
        This function takes care of updating all the indices in the global view as well in all the other views.
 
362
        
 
363
        :param list X_new: list of coordinates of the new refinement
 
364
        :param float tol: tolerance for the matching of coordinates
 
365
        """
 
366
        if tol == None: tol = np.spacing(1)
 
367
        top_map = []            # Map from the old full coord to X_new
 
368
        for d, x in enumerate(X_new):
 
369
            if any(x[i] > x[i+1] for i in range( len(x)-1 )):
 
370
                raise ValueError("TensorWrapperView: the input coordinates must be sorted")
 
371
            
 
372
            top_map.append([])
 
373
            j = 0
 
374
            for val in self.X[d]:
 
375
                while j < len(x):
 
376
                    if abs( val, x[j] ) <= tol:
 
377
                        top_map[-1].append( j )
 
378
                        break
 
379
                    j += 1
 
380
                if j == len(x):
 
381
                    raise ValueError("TensorWrapperView: the full coordinates are not a subset of the new coordinates")
 
382
            
 
383
        # Update all keys in data
 
384
        for old_key in self.data:
 
385
            new_key = tuple( [ top_map[i][k] for i,k in enumerate(old_key) ] )
 
386
            self.data[new_key] = self.data.pop(old_key)
 
387
        
 
388
        # Update the coordinates
 
389
        self.X = X_new
 
390
 
 
391
        # Update all the views
 
392
        self.maps['full'] = { 'X_map': self.X, 
 
393
                              'idx_map': [ range( len(x) ) for x in self.X ],
 
394
                              'Q': None,
 
395
                              'ghost_shape': None,
 
396
                              'fix_idxs': None,
 
397
                              'fix_dims': None }
 
398
        for view in self.maps:
 
399
            if view != 'full':
 
400
                self.set_view( view, self.maps[view]['X_map'] )
 
401
 
 
402
    #########################
 
403
    # EXTENDED for Quantics #
 
404
    #########################
 
405
    def get_extended_shape(self):
 
406
        """ If the quantics folding has been performed on the current view, then this returns the shape of the extended tensor to the next power of Q. If the folding has not been performed, this returns the view shape.
 
407
        """
 
408
        if self.maps[self.view]['Q'] == None:
 
409
            return self.get_view_shape()
 
410
        else:
 
411
            return tuple( [ self.maps[self.view]['Q']**(int(math.log(s-0.5,self.maps[self.view]['Q']))+1) for s in self.get_view_shape() ] )
 
412
 
 
413
    def get_extended_ndim(self):
 
414
        """ If the quantics folding has been performed on the current view, then this returns the number of dimensions of the extended tensor to the next power of Q. If the folding has not been performed, this returns an error.
 
415
        """
 
416
        return len(self.get_extended_shape())
 
417
    
 
418
    def get_extended_size(self):
 
419
        """ If the quantics folding has been performed on the current view, then this returns the size of the extended tensor to the next power of Q. If the folding has not been performed, this returns an error.
 
420
        """
 
421
        return reduce(operator.mul, self.get_extended_shape())
 
422
    
 
423
    def set_Q(self, Q):
 
424
        """ Set the quantics folding base for the current view.
 
425
 
 
426
        This will unset any fixed index for the current view set using :py:meth:`~TensorWrapper.fix_indices`.
 
427
        
 
428
        :param int Q: folding base.
 
429
        """
 
430
        self.maps[self.view]['Q'] = Q
 
431
        self.maps[self.view]['ghost_shape'] = [ Q**(int(math.log(s-0.5, Q))+1) for s in self.get_view_shape() ]
 
432
        self.maps[self.view]['fix_idxs'] = []
 
433
        self.maps[self.view]['fix_dims'] = []
 
434
        self.update_shape()
 
435
 
 
436
    def reset_shape(self):
 
437
        """ Reset the shape of the tensor erasing the reshape and quantics foldings.
 
438
        """
 
439
        self.maps[self.view]['Q'] = None
 
440
        self.maps[self.view]['ghost_shape'] = None
 
441
        self.update_shape()
 
442
 
 
443
    #########
 
444
    # GHOST #
 
445
    #########
 
446
    def get_ghost_shape(self):
 
447
        """ If the ``ghost_shape`` is set for this view, then it returns the shape obtained after quantics folding by the function :py:meth:`~TensorWrapper.set_Q` or after reshaping by the function :py:meth:`~TensorWrapper.reshape`. Otherwise the shape of the extended shape is returned.
 
448
        """
 
449
        if self.maps[self.view]['ghost_shape'] != None:
 
450
            return self.maps[self.view]['ghost_shape']
 
451
        else:
 
452
            return self.get_extended_shape()
 
453
    
 
454
    def get_ghost_ndim(self):
 
455
        """ If the ``ghost_shape`` is set for this view, then it returns the number of dimensions obtained after quantics folding by the function :py:meth:`~TensorWrapper.set_Q` or after reshaping by the function :py:meth:`~TensorWrapper.reshape`. Otherwise the number of dimensions of the view is returned.
 
456
        """
 
457
        return len(self.get_ghost_shape())
 
458
 
 
459
    def get_ghost_size(self):
 
460
        """ If the ``ghost_shape`` is set for this view, then it returns the size obtained after quantics folding by the function :py:meth:`~TensorWrapper.set_Q` or after reshaping by the function :py:meth:`~TensorWrapper.reshape`. Otherwise the size of the view is returned.
 
461
        """
 
462
        return reduce(operator.mul, self.get_ghost_shape())
 
463
 
 
464
    def reset_ghost_shape(self):
 
465
        """ Reset the shape of the tensor erasing the reshape and quantics foldings.
 
466
        """
 
467
        self.maps[self.view]['ghost_shape'] = None
 
468
        self.update_shape()
 
469
    
 
470
    def reshape(self,newshape):
 
471
        """ Reshape the tensor. The number of items in the new shape must be consistent with :py:meth:`~TensorWrapper.get_extended_size`, i.e. with the number of items in the extended quantics size or the view size if ``Q`` is not set for this view.
 
472
 
 
473
        This will unset any fixed index for the current view set using :py:meth:`~TensorWrapper.fix_indices`.
 
474
 
 
475
        :param list newshape: new shape to be applied to the tensor.
 
476
        """
 
477
        if reduce(operator.mul, newshape, 1) == self.get_extended_size():
 
478
            self.maps[self.view]['ghost_shape'] = tuple(newshape)
 
479
            self.maps[self.view]['fix_idxs'] = []
 
480
            self.maps[self.view]['fix_dims'] = []
 
481
        self.update_shape()
 
482
        return self
 
483
    
 
484
    ###########
 
485
    # FIX_IDX #
 
486
    ###########
 
487
    def get_shape(self):
 
488
        """ Always returns the shape of the actual tensor view
 
489
        
 
490
        .. note: use :py:meth:`TensorWrapper.get_global_shape` to get the shape of the original tensor
 
491
        """
 
492
        return tuple([ s for dim,s in enumerate(self.get_ghost_shape()) if not (dim in self.maps[self.view]['fix_dims']) ])
 
493
    
 
494
    def get_ndim(self):
 
495
        """ Always returns the number of dimensions of the tensor view
 
496
        
 
497
        .. note: use :py:meth:`TensorWrapper.get_global_ndim` to get the number of dimensions of the original tensor
 
498
        """
 
499
        return len(self.get_shape())
 
500
 
 
501
    def get_size(self):
 
502
        """ Always returns the size of the tensor view
 
503
        
 
504
        .. note: use :py:meth:`TensorWrapper.get_global_size` to get the size of the original tensor
 
505
        """
 
506
        return reduce(operator.mul, self.get_shape(), 1)
 
507
    
 
508
    def update_shape(self):
 
509
        self.shape = self.get_shape()
 
510
        self.size = self.get_size()
 
511
        self.ndim = self.get_ndim()
 
512
    
312
513
    def fix_indices(self, idxs, dims):
313
514
        """ Fix some of the indices in the tensor wrapper and reshape/resize it accordingly. The internal storage of the data is still done with respect to the global indices, but once some indices are fixed, the TensorWrapper can be accessed using just the remaining free indices.
314
515
        
324
525
        
325
526
        # Reorder the lists
326
527
        i_ord = sorted(range(len(dims)), key=dims.__getitem__)
327
 
        self.fix_idxs = [ idxs[i] for i in i_ord ]
328
 
        self.fix_dims = [ dims[i] for i in i_ord ]
329
 
        
 
528
        self.maps[self.view]['fix_idxs'] = [ idxs[i] for i in i_ord ]
 
529
        self.maps[self.view]['fix_dims'] = [ dims[i] for i in i_ord ]
330
530
        # Update shape, ndim and size
331
531
        self.update_shape()
332
532
    
333
533
    def release_indices(self):
334
 
        self.fix_idxs = []
335
 
        self.fix_dims = []
336
 
        self.update_shape()
337
 
    
338
 
    def update_shape(self):
339
 
        self.shape = self.get_shape()
340
 
        self.size = self.get_size()
341
 
        self.ndim = self.get_ndim()
342
 
    
343
 
    def reshape(self,newshape):
344
 
        if reduce(operator.mul, newshape, 1) == self.get_q_size():
345
 
            self.__ghost_shape = tuple(newshape)
346
 
        self.update_shape()
347
 
        return self
348
 
    
349
 
    def reset_shape(self):
350
 
        self.__ghost_shape = None
351
 
        self.update_shape()
352
 
    
353
 
    def full_to_q(self,idxs):
354
 
        return idxfold( self.get_q_shape(), idxunfold( self.get_full_shape(), idxs ) )
355
 
    
356
 
    def q_to_full(self,idxs):
357
 
        return idxfold( self.get_full_shape(), idxunfold( self.get_q_shape(), idxs ) )
358
 
    
359
 
    def q_to_global(self,idxs):
360
 
        """ This is a non-injective function from the q indices to the global indices
361
 
        """
362
 
        return tuple( [ ( i if i<N else N-1 ) for i,N in zip(idxs,self.get_global_shape()) ] )
363
 
    
364
 
    def global_to_q(self,idxs):
365
 
        """ This operation is undefined because one global idx can point to many q indices
366
 
        """
367
 
        if self.Q != None:
368
 
            raise NotImplemented("This operation is undefined because one global idx can point to many q indices")
369
 
        else:
370
 
            return idxs
371
 
 
372
 
    def global_to_full(self,idxs):
373
 
        return self.q_to_full( self.global_to_q( idxs ) )
374
 
    
375
 
    def full_to_global(self,idxs):
376
 
        return self.q_to_global( self.full_to_q( idxs ) )
377
 
    
 
534
        """ Release all the indices in the tensor wrapper which were fixed using :py:meth:`~TensorWrapper.fix_indices`.
 
535
        """
 
536
        self.maps[self.view]['fix_idxs'] = []
 
537
        self.maps[self.view]['fix_dims'] = []
 
538
        self.update_shape()
 
539
 
 
540
    #####################################################
 
541
    #            INDEX TRANSFORMATIONS                  #
 
542
    #####################################################
 
543
 
 
544
    ###################
 
545
    # GLOBAL to SHAPE #
 
546
    ###################
 
547
    def global_to_view(self, idxs):
 
548
        """ This maps the index from the global shape to the view shape.
 
549
        
 
550
        :param tuple idxs: tuple representing an index to be transformed.
 
551
        
 
552
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
553
        .. note:: this returns an error if the ``idxs`` do not belong to the index mapping of the current view.
 
554
        """
 
555
        return tuple( [ self.idx_max[d].index(i) for d,i in enumerate(idxs) ] )
 
556
    
 
557
    def view_to_ghost(self, idxs):
 
558
        """ This maps the index from the view to the ghost shape.
 
559
        
 
560
        :param list idxs: list of indices to be transformed
 
561
 
 
562
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
563
        .. note:: this returns an error if the ghost shape is obtained by quantics folding, because the one view index can be pointing to many indices in the folding.
 
564
        """
 
565
        if self.maps[self.view]['Q'] != None:
 
566
            raise NotImplemented("This operation is undefined because one view idx can point to many q indices")
 
567
        else:
 
568
            return idxfold( self.get_ghost_shape(), idxunfold( self.get_view_shape(), idxs ) )
 
569
    
 
570
    def global_to_ghost(self,idxs):
 
571
        """ This maps the index from the global shape to the ghost shape.
 
572
        
 
573
        :param list idxs: list of indices to be transformed
 
574
 
 
575
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
576
 
 
577
        For :py:meth:`TensorWrapper` ``A``, this corresponds to:
 
578
        
 
579
        >>> A.view_to_ghost( A.global_to_view( idxs ) )
 
580
 
 
581
        """
 
582
        return self.view_to_ghost( self.global_to_view( idxs ) )
 
583
 
 
584
    ###################
 
585
    # SHAPE to GLOBAL #
 
586
    ###################
 
587
    def shape_to_ghost(self, idxs_in):
 
588
        """ This maps the index from the current shape of the view (fixed indices) to the ghost shape.
 
589
        
 
590
        :param list idxs_in: list of indices to be transformed
 
591
 
 
592
        .. note:: slicing is admitted here.
 
593
        """
 
594
        idxs = idxs_in[:]
 
595
        # Insert the fixed indices
 
596
        for i in self.maps[self.view]['fix_dims']:
 
597
            idxs.insert(i, self.maps[self.view]['fix_idxs'][ self.maps[self.view]['fix_dims'].index(i)] )
 
598
        return idxs
 
599
    
 
600
    def ghost_to_extended(self, idxs):
 
601
        return idxfold( self.get_extended_shape(), idxunfold( self.get_ghost_shape(), idxs ) )
 
602
 
 
603
    def ghost_to_view(self, idxs):
 
604
        """ This maps the index from the current ghost shape of the view to the view shape.
 
605
        
 
606
        :param list idxs_in: list of indices to be transformed
 
607
 
 
608
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
609
        """
 
610
        if self.maps[self.view]['Q'] != None:
 
611
            return tuple( [ ( i if i<N else N-1 ) for i,N in zip(self.ghost_to_extended(idxs),self.get_view_shape()) ] )
 
612
        else:
 
613
            idxs = idxfold( self.get_view_shape(), idxunfold( self.get_ghost_shape(), idxs ) )
 
614
            return tuple( idxs )
 
615
    
 
616
    def shape_to_view(self, idxs):
 
617
        """ This maps the index from the current shape of the view to the view shape.
 
618
        
 
619
        :param list idxs_in: list of indices to be transformed
 
620
 
 
621
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
622
        """
 
623
        return self.ghost_to_view( self.shape_to_ghost( idxs ) )
 
624
 
 
625
    def view_to_global(self, idxs):
 
626
        """ This maps the index in view to the global indices of the full tensor wrapper.
 
627
        
 
628
        :param list idxs_in: list of indices to be transformed
 
629
 
 
630
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
631
        """
 
632
        return tuple( [ self.maps[self.view]['idx_map'][d][i] for d,i in enumerate(idxs) ] )
 
633
 
 
634
    def ghost_to_global(self, idxs):
 
635
        """ This maps the index from the current ghost shape of the view to the global shape.
 
636
        
 
637
        :param list idxs_in: list of indices to be transformed
 
638
 
 
639
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
640
        """
 
641
        return self.view_to_global( self.ghost_to_view( idxs ) )
 
642
    
 
643
    def shape_to_global(self, idxs):
 
644
        """ This maps the index from the current shape of the view to the global shape.
 
645
        
 
646
        :param list idxs_in: list of indices to be transformed
 
647
 
 
648
        .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
 
649
        """
 
650
        return self.view_to_global( self.ghost_to_view( self.shape_to_ghost( idxs ) ) )
 
651
 
 
652
    ######################
 
653
    # CHECKS ON EXTENDED #
 
654
    ######################
 
655
    def extended_is_view(self, idxs):
 
656
        """ 
 
657
        :return: True if the idxs is in the view shape. False if it is outside
 
658
        """
 
659
        return all( [ i<N for i,N in zip(idxs,self.get_view_shape()) ] )
 
660
    
 
661
    def full_is_view(self,idxs):
 
662
        return self.extended_is_view( self.ghost_to_extended( idxs ) )
 
663
 
 
664
    ###############################################
 
665
    #            DATA AND FUNCTIONS               #
 
666
    ###############################################
 
667
    def get_fill_level(self):
 
668
        if self.twtype == 'view': return 0
 
669
        else: return len(self.data)
 
670
        
378
671
    def get_fill_idxs(self):
379
672
        return self.data.keys()
380
673
    
387
680
    def get_params(self):
388
681
        return self.params
389
682
    
390
 
    def set_f(self,f,marshal_f=True):
 
683
    def set_f(self,f, marshal_f):
391
684
        self.f = f
392
685
        if self.f != None and marshal_f:
393
686
            self.f_code = marshal.dumps(self.f.func_code)
417
710
        self.store_object = store_object
418
711
 
419
712
    def __getitem__(self,idxs_in):
420
 
        # Transform the tuple to a list for convinience
421
 
        idxs_in = list(idxs_in)
422
 
        
423
 
        # Slice notation can be used. Remember: slice(start:stop:step)
424
 
        if len(idxs_in) != len(self.shape):
425
 
            raise IndexError('wrong number of indices')
426
 
        
427
 
        # Check that all the lists are of the same length
428
 
        int_idx = []
429
 
        llen = None
430
 
        for i,idx in enumerate(idxs_in):
431
 
            if isinstance(idx, int):
432
 
                int_idx.append(i)
433
 
            if isinstance(idx, list) or isinstance(idx,tuple):
434
 
                idxs_in[i] = list(idx)
435
 
                if llen == None:
436
 
                    llen = len(idx)
437
 
                elif llen != len(idx):
438
 
                    raise IndexError('List of indices must have the same length.')
439
 
        
440
 
        if llen == None: llen = 1
441
 
 
442
 
        # Expand single indices in idxs_in to llen
443
 
        for i in int_idx: idxs_in[i] = [idxs_in[i]] * llen
444
 
 
445
 
        # Update input indices of slices and lists
446
 
        list_idx_in = []
447
 
        slice_idx_in = []
448
 
        for i,idx in enumerate(idxs_in):
449
 
            if isinstance(idx, list) or isinstance(idx,tuple):
450
 
                list_idx_in.append(i)
451
 
            if isinstance(idx, slice):
452
 
                slice_idx_in.append(i)
453
 
        
454
 
        # Insert fixed indices
455
 
        for i in self.fix_dims:
456
 
            idxs_in.insert(i, [self.fix_idxs[self.fix_dims.index(i)]] * llen)
457
 
 
458
 
        # Construct list of indices which are lists and slices
459
 
        list_idx = []
460
 
        list_IDXs = []
461
 
        slice_idx = []
462
 
        slice_IDXs = []
463
 
        out_shape = []
464
 
        for i,idx in enumerate(idxs_in):
465
 
            if isinstance(idx, list) or isinstance(idx,tuple):
466
 
                list_idx.append(i)
467
 
                list_IDXs.append( idx )
468
 
            if isinstance(idx, slice):
469
 
                slice_idx.append(i)
470
 
                IDXs = range(idx.start if idx.start != None else 0,
471
 
                                 idx.stop  if idx.stop  != None else self.get_full_shape()[i],
472
 
                                 idx.step  if idx.step  != None else 1)
473
 
                slice_IDXs.append( IDXs )
474
 
                out_shape.append(len(IDXs))
475
 
        
476
 
        if len(list_idx) == 0: list_IDXs.append( [-1] ) # Ghost element added to make the full slicing work
477
 
        unlistIdxs = itertools.izip(*list_IDXs)
478
 
 
479
 
        transpose_list_shape = False
480
 
        if llen > 1: 
481
 
            out_shape.insert(0,llen)
482
 
            if len(slice_idx_in) > 0 and len(list_idx_in) > 0 and min(list_idx_in) > max(slice_idx_in):
483
 
                transpose_list_shape = True
484
 
        
485
 
        # Un-slice sliced idxs
486
 
        unslicedIdxs = itertools.product(*slice_IDXs)
487
 
 
488
 
        # Final list of indices (iterator)
489
 
        lidxs = itertools.product(unlistIdxs, unslicedIdxs)
 
713
        
 
714
        (lidxs,out_shape,transpose_list_shape) = expand_idxs(idxs_in, self.shape, self.get_ghost_shape(), self.maps[self.view]['fix_dims'], self.maps[self.view]['fix_idxs'])
490
715
        
491
716
        # Allocate output array
492
717
        if len(out_shape) > 0:
493
718
            out = np.empty(out_shape, dtype=self.dtype)
 
719
            if self.active_weights:
 
720
                out_weights = np.empty(out_shape, dtype=self.dtype)
494
721
            
495
722
            # MPI code
496
723
            eval_is =[]
498
725
            eval_xx = []
499
726
            # End MPI code
500
727
 
501
 
            for i,(lidx,sidx) in enumerate(lidxs):
502
 
                # Reorder the idxs
503
 
                idxs = [None for j in range(len(idxs_in))]
504
 
                for j,jj in enumerate(list_idx): idxs[jj] = lidx[j]
505
 
                for j,jj in enumerate(slice_idx): idxs[jj] = sidx[j]
506
 
                idxs = tuple(idxs)
 
728
            for i,idxs in enumerate(lidxs):
507
729
 
508
730
                # Map ghost indices to global indices
509
 
                idxs = self.full_to_global( idxs )
 
731
                idxs = self.ghost_to_global( idxs )
 
732
                
 
733
                # Compute the weight corresponding to idxs
 
734
                if self.active_weights:
 
735
                    out_weights[idxfold(out_shape,i)] = np.prod([self.W[j][jj] for j,jj in enumerate(idxs)])
510
736
                
511
737
                # Separate field idxs from parameter idxs                
512
738
                if self.twtype == 'array':
530
756
                    # Evaluate function
531
757
                    xx = np.array([self.X[ii][idx] for ii,idx in enumerate(idxs)])
532
758
                    out[idxfold(out_shape,i)] = self.f(xx,self.params)
 
759
 
 
760
                # # Check that the idxs belong to the real tensor
 
761
                # isout_flag = not self.full_is_view( idxs )
 
762
                #
 
763
                # if isout_flag:
 
764
                #     out[idxfold(out_shape,i)] = TensorWrapper.FILL_VALUE
 
765
                # else:
 
766
                #     # Map ghost indices to global indices
 
767
                #     idxs = self.full_to_global( idxs )
 
768
                                # 
 
769
                #     # Separate field idxs from parameter idxs                
 
770
                #     if self.twtype == 'array':
 
771
                #         # Check whether the value has already been computed
 
772
                #         try:
 
773
                #             out[idxfold(out_shape,i)] = self.data[idxs]
 
774
                #         except KeyError:
 
775
                #             if idxs not in eval_idxs:
 
776
                #                 # Evaluate function
 
777
                #                 xx = np.array( [self.X[ii][idx] for ii,idx in enumerate(idxs)] )
 
778
                #                 # MPI code
 
779
                #                 eval_is.append([i])
 
780
                #                 eval_idxs.append(idxs)
 
781
                #                 eval_xx.append(xx)
 
782
                #                 # End MPI code
 
783
                #             else:
 
784
                #                 pos = eval_idxs.index(idxs)
 
785
                #                 eval_is[pos].append(i)
 
786
                                # 
 
787
                #     else:
 
788
                #         # Evaluate function
 
789
                #         xx = np.array([self.X[ii][idx] for ii,idx in enumerate(idxs)])
 
790
                #         out[idxfold(out_shape,i)] = self.f(xx,self.params)
533
791
            
534
792
            # Evaluate missing values
535
793
            if len(eval_xx) > 0:
554
812
                stop_eval = time.time()
555
813
                self.logger.debug(" [DONE] Num. of func. eval.: %d - Avg. time of func. eval.: %fs - Tot. time: %s" % (len(eval_xx),(stop_eval-start_eval)/len(eval_xx)*(min(self.__maxprocs,len(eval_xx)) if self.__maxprocs != None else 1), str(datetime.timedelta(seconds=(stop_eval-start_eval))) ))
556
814
            
 
815
            # Apply weights if needed
 
816
            if self.active_weights:
 
817
                out *= out_weights
 
818
            
557
819
            if transpose_list_shape:
558
820
                out = np.transpose( out , tuple( range(1,len(out_shape)) + [0] ) )
559
821
            
560
822
        else:
561
 
            idxs = tuple(itertools.chain(*lidxs.next()))
 
823
            idxs = tuple(itertools.chain(*lidxs))
562
824
            # Map ghost indices to global indices
563
 
            idxs = self.full_to_global( idxs )
 
825
            idxs = self.ghost_to_global( idxs )
 
826
            # Compute weight if necessary
 
827
            if self.active_weights:
 
828
                w = np.prod([self.W[j][jj] for j,jj in enumerate(idxs)])
564
829
            if self.twtype == 'array':
565
830
                try:
566
831
                    out = self.data[idxs]
572
837
                    out = self.data[idxs]
573
838
            else:
574
839
                out = self.f(np.array([self.X[ii][idx] for ii,idx in enumerate(idxs)]),self.params)
575
 
            return out
 
840
            # Apply the weight if necessary
 
841
            if self.active_weights:
 
842
                out *= w
576
843
        
577
844
        return out