66
67
:param int maxprocs: Number of processors to be used in the function evaluation (MPI)
67
68
:param bool marshal_f: whether to marshal the function or not
70
Several shape parameters are used by the TensorWrapper in order to keep track of reshaping and slicings, without affecting the underlying shape of the tensor which is always preserved. The following table lists the existing shapes and their meaning.
72
+------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
73
| Shape attribute/function | Applied transformations (ordered) | Description |
74
+================================================+======================================+==================================================================================================================================================================================================================================================================+
75
| :py:meth:`~TensorWrapper.get_global_shape` | None | The original shape of the tensor. This shape can be modified only through a refinement of the grid using the function :py:meth:`~TensorWrapper.refine`. |
76
+------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
77
| :py:meth:`~TensorWrapper.get_view_shape` | VIEW | The particular view of the tensor, defined by the view in :py:attr:`TensorWrapper.maps` set active using :py:meth:`~TensorWrapper.set_active_view`. |
78
+------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
79
| :py:meth:`~TensorWrapper.get_extended_shape` | VIEW, QUANTICS | The shape of the extended tensor in order to allow for the quantics folding with basis :py:attr:`TensorWrapper.Q`. |
80
+------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
81
| :py:meth:`~TensorWrapper.get_ghost_shape` | VIEW, QUANTICS, RESHAPE | The shape of the tensor reshaped using :py:meth:`~TensorWrapper.reshape`. If a Quantics folding is pre-applied, then the reshape is on the extended shape. |
82
+------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
83
| :py:meth:`~TensorWrapper.get_shape` | VIEW, QUANTICS, RESHAPE, FIX_IDXS | The shape of the tensor with :py:meth:`~TensorWrapper.fix_indices` and :py:meth:`~TensorWrapper.release_indices`. This is the view that is always used when the tensor is accessed through the function :py:meth:`~TensorWrapper.__getitem__` (i.e. ``TW[...]``) |
84
+------------------------------------------------+--------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
86
.. document private functions
87
.. automethod:: __getitem__
71
90
logger = logging.getLogger(__name__)
90
111
# List of attributes
91
112
self.f_code = None # Marshal string of the function f
93
115
self.params = None
95
116
self.dtype = object
99
117
self.twtype = None
102
self.serialize_list.extend( ['X', 'dtype', 'params', 'Q', 'shape', 'ndim', 'size', 'twtype', 'f_code'] )
120
self.maps = {} # Multiple views
121
self.view = None # Multiple views
123
self.serialize_list.extend( ['X', 'W', 'dtype', 'params', 'twtype', 'f_code', 'maps', 'view'] )
103
124
self.subserialize_list.extend( [] )
105
126
# Attributes which are not serialized and need to be reset on reload
106
self.__ghost_shape = None
108
128
self.store_object = None
109
129
self.__maxprocs = None # Number of processors to be used (MPI)
114
135
self.stored_keys = None # Set of stored keys (to improve the saving speed)
137
self.active_weights = None
115
138
# End list of attributes
116
139
#################################
141
self.active_weights = False
118
142
self.stored_keys = set()
120
144
self.set_f(f,marshal_f)
121
145
self.params = params
148
self.maps['full'] = { 'X_map': self.X,
149
'idx_map': [ range( len(x) ) for x in self.X ],
154
self.set_active_view('full')
124
155
self.shape = self.get_shape()
125
self.ndim = len(self.shape)
156
self.ndim = self.get_ndim()
126
157
self.size = self.get_size()
127
158
self.twtype = twtype
128
159
self.dtype = dtype
227
276
return TensorWrapper(self.f, self.X, params=self.params, twtype=self.twtype, data=self.data.copy())
231
self.__ghost_shape = self.get_q_shape()
235
""" Always returns the size of the tensor view
237
.. note: use :py:meth:`TensorWrapper.get_global_size` to get the size of the original tensor
239
return reduce(operator.mul, self.get_shape(), 1)
242
""" Always returns the number of dimensions of the tensor view
244
.. note: use :py:meth:`TensorWrapper.get_global_ndim` to get the number of dimensions of the original tensor
246
return len(self.get_shape())
249
""" Always returns the shape of the actual tensor view
251
.. note: use :py:meth:`TensorWrapper.get_global_shape` to get the shape of the original tensor
253
if self.__ghost_shape == None:
254
dim = [ len(coord) for dim,coord in enumerate(self.X) if not (dim in self.fix_dims) ]
256
dim = [ s for dim,s in enumerate(self.__ghost_shape) if not (dim in self.fix_dims) ]
259
def get_full_shape(self):
260
""" Always returns the shape of the reshaped tensor with no fixed indices
262
if self.__ghost_shape == None:
263
dim = self.get_q_shape()
265
dim = self.__ghost_shape
268
def get_full_ndim(self):
    """ Always returns the number of dimensions of the reshaped tensor with no fixed indices.

    :return: number of dimensions of the full shape
    :rtype: int
    """
    full_shape = self.get_full_shape()
    return len(full_shape)
273
def get_full_size(self):
    """ Always returns the size of the reshaped tensor with no fixed indices.

    :return: product of the full-shape extents (1 for a 0-dim shape)
    :rtype: int
    """
    size = 1
    for extent in self.get_full_shape():
        size *= extent
    return size
278
def get_q_shape(self):
279
""" Always returns the shape of the tensor rounded to the next power of Q if Q!=None. Otherwise returns the shape of the underlying tensor.
282
dim = self.get_global_shape()
284
dim = [ self.Q**(int(math.log(s-0.5,self.Q))+1) for s in self.get_global_shape() ]
287
def get_q_size(self):
    """ Always returns the size of the tensor rounded to the next power of Q.

    :return: product of the q-shape extents (1 for a 0-dim shape)
    :rtype: int
    """
    total = 1
    for extent in self.get_q_shape():
        total *= extent
    return total
278
#####################################################
280
#####################################################
292
285
def get_global_shape(self):
293
286
""" Always returns the shape of the underlying tensor
304
297
""" Always returns the size of the underlying tensor
306
299
return reduce(operator.mul, self.get_global_shape(), 1)
308
def get_fill_level(self):
    """ Return the number of entries stored in ``self.data``.

    A 'view' wrapper stores nothing, so its fill level is always 0.
    """
    return 0 if self.twtype == 'view' else len(self.data)
304
def get_view_shape(self):
305
""" Always returns the shape of the current view
307
dim = [ len(coord) for coord in self.maps[self.view]['X_map'] ]
310
def get_view_ndim(self):
    """ Always returns the ndim of the current view.

    :return: number of dimensions of the active view's coordinate map
    :rtype: int
    """
    x_map = self.maps[self.view]['X_map']
    return len(x_map)
315
def get_view_size(self):
    """ Always returns the size of the current view.

    :return: product of the view-shape extents (1 for a 0-dim shape)
    :rtype: int
    """
    size = 1
    for extent in self.get_view_shape():
        size *= extent
    return size
320
def set_active_view(self, view):
321
""" Set a view among the ones in ``self.maps``.
323
:param str view: name of the view to be set as active
328
def set_view(self, view, X_map, tol=None):
329
""" Set or add a view to ``self.maps``. This resets all the existing reshape parameters in existing views.
331
:param str view: name of the view to be added
332
:param list X_map: list of coordinates of the new view
333
:param float tol: tolerance for the matching of coordinates
335
if tol == None: tol = np.spacing(1)
337
for d, x in enumerate(X_map):
338
if any(x[i] > x[i+1] for i in range( len(x)-1 )):
339
raise ValueError("TensorWrapperView: the input coordinates must be sorted")
344
while j < len(self.X[d]):
345
if abs( val - self.X[d][j] ) <= tol :
346
idx_map[-1].append( j )
349
if j == len(self.X[d]):
350
raise ValueError("TensorWrapperView: the input coordinates are not a subset of the full coordinates")
351
self.maps[view] = { 'X_map': X_map,
358
def refine(self, X_new, tol=None):
359
""" Refine the global discretization. The new discretization must contain the old one.
361
This function takes care of updating all the indices in the global view as well in all the other views.
363
:param list X_new: list of coordinates of the new refinement
364
:param float tol: tolerance for the matching of coordinates
366
# Default matching tolerance: one machine epsilon.
if tol == None: tol = np.spacing(1)
367
top_map = [] # Map from the old full coord to X_new
368
for d, x in enumerate(X_new):
369
if any(x[i] > x[i+1] for i in range( len(x)-1 )):
370
raise ValueError("TensorWrapperView: the input coordinates must be sorted")
374
# NOTE(review): several lines are missing from this extraction here (e.g.
# the initialization of the scan index ``j`` and of ``top_map[-1]``) —
# compare with set_view above, which follows the same matching pattern.
for val in self.X[d]:
376
# FIXME(review): ``abs`` takes a single argument — this should almost
# certainly read ``abs( val - x[j] ) <= tol`` (see the analogous
# comparison in set_view).
if abs( val, x[j] ) <= tol:
377
top_map[-1].append( j )
381
raise ValueError("TensorWrapperView: the full coordinates are not a subset of the new coordinates")
383
# Update all keys in data
384
# FIXME(review): this pops from ``self.data`` while iterating over it,
# which mutates the dict during iteration (RuntimeError on Python 3);
# iterate over ``list(self.data)`` instead — TODO confirm against the
# full source.
for old_key in self.data:
385
new_key = tuple( [ top_map[i][k] for i,k in enumerate(old_key) ] )
386
self.data[new_key] = self.data.pop(old_key)
388
# Update the coordinates
391
# Update all the views
392
self.maps['full'] = { 'X_map': self.X,
393
'idx_map': [ range( len(x) ) for x in self.X ],
398
for view in self.maps:
400
self.set_view( view, self.maps[view]['X_map'] )
402
#########################
403
# EXTENDED for Quantics #
404
#########################
405
def get_extended_shape(self):
    """ Return the shape of the tensor extended for quantics folding.

    If the quantics folding has been performed on the current view
    (``maps[view]['Q']`` is set), this returns the view shape with every
    dimension rounded up to the next power of ``Q``. If the folding has
    not been performed, the plain view shape is returned.

    :return: extended (or view) shape
    """
    Q = self.maps[self.view]['Q']
    # Idiomatic identity test against None (was ``== None``).
    if Q is None:
        return self.get_view_shape()
    # Round each extent up to the next power of Q. The ``s-0.5`` guards
    # against floating-point log for extents that are already exact
    # powers of Q (they map to themselves).
    return tuple( [ Q**(int(math.log(s-0.5,Q))+1)
                    for s in self.get_view_shape() ] )
413
def get_extended_ndim(self):
    """ Number of dimensions of the quantics-extended tensor for the current view.

    If no folding has been performed this coincides with the view ndim.
    """
    extended = self.get_extended_shape()
    return len(extended)
418
def get_extended_size(self):
    """ Size of the quantics-extended tensor for the current view.

    If no folding has been performed this coincides with the view size.

    :return: product of the extended-shape extents
    :rtype: int
    """
    # Pass the initial value 1 so a 0-dimensional (empty) shape yields
    # size 1 instead of raising TypeError, consistently with
    # get_view_size/get_full_size which already use the initializer.
    return reduce(operator.mul, self.get_extended_shape(), 1)
424
""" Set the quantics folding base for the current view.
426
This will unset any fixed index for the current view set using :py:meth:`~TensorWrapper.fix_indices`.
428
:param int Q: folding base.
430
self.maps[self.view]['Q'] = Q
431
self.maps[self.view]['ghost_shape'] = [ Q**(int(math.log(s-0.5, Q))+1) for s in self.get_view_shape() ]
432
self.maps[self.view]['fix_idxs'] = []
433
self.maps[self.view]['fix_dims'] = []
436
def reset_shape(self):
    """ Reset the shape of the tensor erasing the reshape and quantics foldings.
    """
    active = self.maps[self.view]
    active['Q'] = None
    active['ghost_shape'] = None
446
def get_ghost_shape(self):
    """ Return the ghost (reshaped) shape of the current view.

    If ``ghost_shape`` is set for this view — by quantics folding
    (:py:meth:`~TensorWrapper.set_Q`) or by reshaping
    (:py:meth:`~TensorWrapper.reshape`) — return it; otherwise fall back
    to the extended shape.
    """
    ghost_shape = self.maps[self.view]['ghost_shape']
    # Idiomatic identity test against None (was ``!= None``).
    if ghost_shape is not None:
        return ghost_shape
    return self.get_extended_shape()
454
def get_ghost_ndim(self):
    """ Number of dimensions of the ghost (reshaped/folded) shape of the current view.
    """
    ghost = self.get_ghost_shape()
    return len(ghost)
459
def get_ghost_size(self):
    """ Size of the ghost (reshaped/folded) tensor of the current view.

    :return: product of the ghost-shape extents
    :rtype: int
    """
    # Pass the initial value 1 so a 0-dimensional (empty) shape yields
    # size 1 instead of raising TypeError, consistently with the other
    # *_size helpers which already use the initializer.
    return reduce(operator.mul, self.get_ghost_shape(), 1)
464
def reset_ghost_shape(self):
    """ Reset the shape of the tensor erasing the reshape and quantics foldings.
    """
    # Dropping the ghost shape makes get_ghost_shape fall back to the
    # extended shape.
    self.maps[self.view]['ghost_shape'] = None
470
def reshape(self,newshape):
471
""" Reshape the tensor. The number of items in the new shape must be consistent with :py:meth:`~TensorWrapper.get_extended_size`, i.e. with the number of items in the extended quantics size or the view size if ``Q`` is not set for this view.
473
This will unset any fixed index for the current view set using :py:meth:`~TensorWrapper.fix_indices`.
475
:param list newshape: new shape to be applied to the tensor.
477
if reduce(operator.mul, newshape, 1) == self.get_extended_size():
478
self.maps[self.view]['ghost_shape'] = tuple(newshape)
479
self.maps[self.view]['fix_idxs'] = []
480
self.maps[self.view]['fix_dims'] = []
488
""" Always returns the shape of the actual tensor view
490
.. note: use :py:meth:`TensorWrapper.get_global_shape` to get the shape of the original tensor
492
return tuple([ s for dim,s in enumerate(self.get_ghost_shape()) if not (dim in self.maps[self.view]['fix_dims']) ])
495
""" Always returns the number of dimensions of the tensor view
497
.. note: use :py:meth:`TensorWrapper.get_global_ndim` to get the number of dimensions of the original tensor
499
return len(self.get_shape())
502
""" Always returns the size of the tensor view
504
.. note: use :py:meth:`TensorWrapper.get_global_size` to get the size of the original tensor
506
return reduce(operator.mul, self.get_shape(), 1)
508
def update_shape(self):
    """ Refresh the cached ``shape``, ``size`` and ``ndim`` attributes from the current view. """
    self.ndim = self.get_ndim()
    self.size = self.get_size()
    self.shape = self.get_shape()
312
513
def fix_indices(self, idxs, dims):
313
514
""" Fix some of the indices in the tensor wrapper and reshape/resize it accordingly. The internal storage of the data is still done with respect to the global indices, but once some indices are fixed, the TensorWrapper can be accessed using just the remaining free indices.
325
526
# Reorder the lists
326
527
i_ord = sorted(range(len(dims)), key=dims.__getitem__)
327
self.fix_idxs = [ idxs[i] for i in i_ord ]
328
self.fix_dims = [ dims[i] for i in i_ord ]
528
self.maps[self.view]['fix_idxs'] = [ idxs[i] for i in i_ord ]
529
self.maps[self.view]['fix_dims'] = [ dims[i] for i in i_ord ]
330
530
# Update shape, ndim and size
331
531
self.update_shape()
333
533
def release_indices(self):
338
def update_shape(self):
    """ Recompute and cache ``shape``, ``size`` and ``ndim`` after any reshape or index change. """
    shape = self.get_shape()
    self.shape = shape
    self.size = self.get_size()
    self.ndim = self.get_ndim()
343
def reshape(self,newshape):
344
# Accept the new shape only when it preserves the number of elements of
# the q-extended tensor (see get_q_size).
if reduce(operator.mul, newshape, 1) == self.get_q_size():
345
self.__ghost_shape = tuple(newshape)
# NOTE(review): the rejection branch (else/raise) appears to be missing
# from this extraction — compare with the newer reshape() which guards
# against get_extended_size().
349
def reset_shape(self):
    """ Reset the tensor shape, erasing any reshape/quantics folding. """
    # Clearing the ghost shape makes get_shape fall back to the
    # underlying coordinate grid.
    self.__ghost_shape = None
353
def full_to_q(self,idxs):
    """ Map an index from the full (reshaped) shape to the q shape by linearizing and refolding. """
    flat = idxunfold( self.get_full_shape(), idxs )
    return idxfold( self.get_q_shape(), flat )
356
def q_to_full(self,idxs):
    """ Map an index from the q shape to the full (reshaped) shape by linearizing and refolding. """
    flat = idxunfold( self.get_q_shape(), idxs )
    return idxfold( self.get_full_shape(), flat )
359
def q_to_global(self,idxs):
    """ This is a non-injective function from the q indices to the global indices.

    Indices falling in the padded region (beyond the global extent of a
    dimension) are clamped to the last global index of that dimension.
    """
    global_shape = self.get_global_shape()
    return tuple( min(i, N-1) for i, N in zip(idxs, global_shape) )
364
def global_to_q(self,idxs):
    """ This operation is undefined because one global idx can point to many q indices.

    :raises NotImplementedError: always.
    """
    # Bug fix: ``NotImplemented`` is a constant, not an exception class —
    # ``raise NotImplemented(...)`` fails with TypeError. Raise the
    # proper exception type instead.
    raise NotImplementedError("This operation is undefined because one global idx can point to many q indices")
372
def global_to_full(self,idxs):
    """ Map a global index to the full (reshaped) shape, going through the q indices. """
    q_idxs = self.global_to_q( idxs )
    return self.q_to_full( q_idxs )
375
def full_to_global(self,idxs):
    """ Map an index in the full (reshaped) shape to the global indices, going through the q indices. """
    q_idxs = self.full_to_q( idxs )
    return self.q_to_global( q_idxs )
534
""" Release all the indices in the tensor wrapper which were fixed using :py:meth:`~TensorWrapper.fix_indices`.
536
self.maps[self.view]['fix_idxs'] = []
537
self.maps[self.view]['fix_dims'] = []
540
#####################################################
541
# INDEX TRANSFORMATIONS #
542
#####################################################
547
def global_to_view(self, idxs):
    """ This maps the index from the global shape to the view shape.

    :param tuple idxs: tuple representing an index to be transformed.

    .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
    .. note:: this returns an error if the ``idxs`` do not belong to the index mapping of the current view.
    """
    # Fix: the original referenced ``self.idx_max``, an attribute defined
    # nowhere; the per-view ``maps[view]['idx_map']`` (the same mapping
    # used by view_to_global) is the per-dimension list of global indices
    # spanned by the view, and ``.index(i)`` inverts view_to_global.
    return tuple( [ self.maps[self.view]['idx_map'][d].index(i) for d,i in enumerate(idxs) ] )
557
def view_to_ghost(self, idxs):
    """ This maps the index from the view to the ghost shape.

    :param list idxs: list of indices to be transformed

    .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
    .. note:: this returns an error if the ghost shape is obtained by quantics folding, because one view index can be pointing to many indices in the folding.

    :raises NotImplementedError: if quantics folding (``Q``) is active for this view.
    """
    # Idiomatic identity test against None (was ``!= None``).
    if self.maps[self.view]['Q'] is not None:
        # Bug fix: ``NotImplemented`` is a constant, not an exception
        # class — raising it would produce a TypeError instead of the
        # intended error.
        raise NotImplementedError("This operation is undefined because one view idx can point to many q indices")
    return idxfold( self.get_ghost_shape(), idxunfold( self.get_view_shape(), idxs ) )
570
def global_to_ghost(self,idxs):
    """ This maps the index from the global shape to the ghost shape.

    :param list idxs: list of indices to be transformed

    .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.

    For :py:meth:`TensorWrapper` ``A``, this corresponds to:

    >>> A.view_to_ghost( A.global_to_view( idxs ) )
    """
    view_idxs = self.global_to_view( idxs )
    return self.view_to_ghost( view_idxs )
587
def shape_to_ghost(self, idxs_in):
588
""" This maps the index from the current shape of the view (fixed indices) to the ghost shape.
590
:param list idxs_in: list of indices to be transformed
592
.. note:: slicing is admitted here.
595
# Insert the fixed indices
596
for i in self.maps[self.view]['fix_dims']:
597
idxs.insert(i, self.maps[self.view]['fix_idxs'][ self.maps[self.view]['fix_dims'].index(i)] )
600
def ghost_to_extended(self, idxs):
    """ Map an index from the ghost shape to the extended shape by linearizing and refolding. """
    flat = idxunfold( self.get_ghost_shape(), idxs )
    return idxfold( self.get_extended_shape(), flat )
603
def ghost_to_view(self, idxs):
604
""" This maps the index from the current ghost shape of the view to the view shape.
606
:param list idxs_in: list of indices to be transformed
608
.. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
610
if self.maps[self.view]['Q'] != None:
611
return tuple( [ ( i if i<N else N-1 ) for i,N in zip(self.ghost_to_extended(idxs),self.get_view_shape()) ] )
613
idxs = idxfold( self.get_view_shape(), idxunfold( self.get_ghost_shape(), idxs ) )
616
def shape_to_view(self, idxs):
    """ This maps the index from the current shape of the view to the view shape.

    :param list idxs: list of indices to be transformed

    .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
    """
    ghost_idxs = self.shape_to_ghost( idxs )
    return self.ghost_to_view( ghost_idxs )
625
def view_to_global(self, idxs):
    """ This maps the index in view to the global indices of the full tensor wrapper.

    :param list idxs: list of indices to be transformed

    .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
    """
    idx_map = self.maps[self.view]['idx_map']
    return tuple( idx_map[d][i] for d,i in enumerate(idxs) )
634
def ghost_to_global(self, idxs):
    """ This maps the index from the current ghost shape of the view to the global shape.

    :param list idxs: list of indices to be transformed

    .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
    """
    view_idxs = self.ghost_to_view( idxs )
    return self.view_to_global( view_idxs )
643
def shape_to_global(self, idxs):
    """ This maps the index from the current shape of the view to the global shape.

    :param list idxs: list of indices to be transformed

    .. note:: no slicing is admitted here. Preprocess ``idxs`` with :py:meth:`expand_idxs` if slicing is required.
    """
    ghost_idxs = self.shape_to_ghost( idxs )
    view_idxs = self.ghost_to_view( ghost_idxs )
    return self.view_to_global( view_idxs )
652
######################
653
# CHECKS ON EXTENDED #
654
######################
655
def extended_is_view(self, idxs):
    """ Check whether an extended index falls inside the current view.

    :return: True if the idxs is in the view shape. False if it is outside
    """
    for i, extent in zip(idxs, self.get_view_shape()):
        if i >= extent:
            return False
    return True
661
def full_is_view(self,idxs):
    """ Check whether a full-shape (ghost) index falls inside the current view. """
    extended = self.ghost_to_extended( idxs )
    return self.extended_is_view( extended )
664
###############################################
665
# DATA AND FUNCTIONS #
666
###############################################
667
def get_fill_level(self):
    """ Number of entries stored in ``self.data`` (always 0 in 'view' mode). """
    if self.twtype == 'view':
        return 0
    return len(self.data)
378
671
def get_fill_idxs(self):
    """ Return the indices of the entries already stored in ``self.data``. """
    return self.data.keys()
417
710
self.store_object = store_object
419
712
def __getitem__(self,idxs_in):
420
# Transform the tuple to a list for convenience
421
idxs_in = list(idxs_in)
423
# Slice notation can be used. Remember: slice(start:stop:step)
424
if len(idxs_in) != len(self.shape):
425
raise IndexError('wrong number of indices')
427
# Check that all the lists are of the same length
430
for i,idx in enumerate(idxs_in):
431
if isinstance(idx, int):
433
if isinstance(idx, list) or isinstance(idx,tuple):
434
idxs_in[i] = list(idx)
437
elif llen != len(idx):
438
raise IndexError('List of indices must have the same length.')
440
if llen == None: llen = 1
442
# Expand single indices in idxs_in to llen
443
for i in int_idx: idxs_in[i] = [idxs_in[i]] * llen
445
# Update input indices of slices and lists
448
for i,idx in enumerate(idxs_in):
449
if isinstance(idx, list) or isinstance(idx,tuple):
450
list_idx_in.append(i)
451
if isinstance(idx, slice):
452
slice_idx_in.append(i)
454
# Insert fixed indices
455
for i in self.fix_dims:
456
idxs_in.insert(i, [self.fix_idxs[self.fix_dims.index(i)]] * llen)
458
# Construct list of indices which are lists and slices
464
for i,idx in enumerate(idxs_in):
465
if isinstance(idx, list) or isinstance(idx,tuple):
467
list_IDXs.append( idx )
468
if isinstance(idx, slice):
470
IDXs = range(idx.start if idx.start != None else 0,
471
idx.stop if idx.stop != None else self.get_full_shape()[i],
472
idx.step if idx.step != None else 1)
473
slice_IDXs.append( IDXs )
474
out_shape.append(len(IDXs))
476
if len(list_idx) == 0: list_IDXs.append( [-1] ) # Ghost element added to make the full slicing work
477
unlistIdxs = itertools.izip(*list_IDXs)
479
transpose_list_shape = False
481
out_shape.insert(0,llen)
482
if len(slice_idx_in) > 0 and len(list_idx_in) > 0 and min(list_idx_in) > max(slice_idx_in):
483
transpose_list_shape = True
485
# Un-slice sliced idxs
486
unslicedIdxs = itertools.product(*slice_IDXs)
488
# Final list of indices (iterator)
489
lidxs = itertools.product(unlistIdxs, unslicedIdxs)
714
(lidxs,out_shape,transpose_list_shape) = expand_idxs(idxs_in, self.shape, self.get_ghost_shape(), self.maps[self.view]['fix_dims'], self.maps[self.view]['fix_idxs'])
491
716
# Allocate output array
492
717
if len(out_shape) > 0:
493
718
out = np.empty(out_shape, dtype=self.dtype)
719
if self.active_weights:
720
out_weights = np.empty(out_shape, dtype=self.dtype)