~maddevelopers/mg5amcnlo/new_clustering

« back to all changes in this revision

Viewing changes to madgraph/core/diagram_generation.py

  • Committer: Rikkert Frederix
  • Date: 2021-09-09 15:51:40 UTC
  • mfrom: (78.75.502 3.2.1)
  • Revision ID: frederix@physik.uzh.ch-20210909155140-rg6umfq68h6h47cf
merge with 3.2.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
19
19
based on relevant properties.
20
20
"""
21
21
 
 
22
from __future__ import absolute_import
 
23
from six.moves import filter
 
24
#force filter to be a generator # like in py3
 
25
 
22
26
import array
23
27
import copy
24
28
import itertools
27
31
import madgraph.core.base_objects as base_objects
28
32
import madgraph.various.misc as misc
29
33
from madgraph import InvalidCmd, MadGraph5Error
 
34
from six.moves import range
 
35
from six.moves import zip
 
36
from six.moves import filter
30
37
 
31
38
logger = logging.getLogger('madgraph.diagram_generation')
32
39
 
152
159
        leg_vertices = [cls.vertices_from_link(l, model) for l in link.links]
153
160
        # The daughter legs are in the first entry
154
161
        legs = base_objects.LegList(sorted([l for l,v in leg_vertices],
155
 
                                           lambda l1,l2: l2.get('number') - \
156
 
                                           l1.get('number')))
 
162
                                           key= lambda l: l.get('number'), reverse=True))
 
163
 
157
164
        # The daughter vertices are in the second entry
158
165
        vertices = base_objects.VertexList(sum([v for l, v in leg_vertices],
159
166
                                               []))
260
267
        """Reorder a permutation with respect to start_perm. Note that
261
268
        both need to start from 1."""
262
269
        if perm == start_perm:
263
 
            return range(len(perm))
 
270
            return list(range(len(perm)))
264
271
        order = [i for (p,i) in \
265
272
                 sorted([(p,i) for (i,p) in enumerate(perm)])]
266
273
        return [start_perm[i]-1 for i in order]
366
373
            return len(self.links) < len(other.links)
367
374
 
368
375
        if self.vertex_id[0] != other.vertex_id[0]:
369
 
            return self.vertex_id[0] < other.vertex_id[0]
 
376
            if isinstance(self.vertex_id[0], int) and isinstance(other.vertex_id[0], tuple):
 
377
                return True
 
378
            elif isinstance(self.vertex_id[0], tuple) and isinstance(other.vertex_id[0], int):
 
379
                return False
 
380
            elif isinstance(self.vertex_id[0], str) and isinstance(other.vertex_id[0], tuple):
 
381
                return True
 
382
            elif isinstance(self.vertex_id[0], tuple) and isinstance(other.vertex_id[0], str):
 
383
                return False            
 
384
            else:
 
385
                try:
 
386
                    return self.vertex_id[0] < other.vertex_id[0]
 
387
                except TypeError as error:
 
388
                    if error.args == ("'<' not supported between instances of 'tuple' and 'str'",):
 
389
                        return False
 
390
                    elif error.args == ("'<' not supported between instances of 'str' and 'tuple'",):
 
391
                        return True
 
392
                    else:
 
393
                        raise Exception
 
394
                    
370
395
 
371
396
        for i, link in enumerate(self.links):
372
397
            if i > len(other.links) - 1:
439
464
 
440
465
        if name == 'process':
441
466
            if not isinstance(value, base_objects.Process):
442
 
                raise self.PhysicsObjectError, \
443
 
                        "%s is not a valid Process object" % str(value)
 
467
                raise self.PhysicsObjectError("%s is not a valid Process object" % str(value))
444
468
        if name == 'diagrams':
445
469
            if not isinstance(value, base_objects.DiagramList):
446
 
                raise self.PhysicsObjectError, \
447
 
                        "%s is not a valid DiagramList object" % str(value)
 
470
                raise self.PhysicsObjectError("%s is not a valid DiagramList object" % str(value))
448
471
        if name == 'has_mirror_process':
449
472
            if not isinstance(value, bool):
450
 
                raise self.PhysicsObjectError, \
451
 
                        "%s is not a valid boolean" % str(value)
 
473
                raise self.PhysicsObjectError("%s is not a valid boolean" % str(value))
452
474
        return True
453
475
 
454
476
    def get(self, name):
554
576
                                     process.get('overall_orders')[key])
555
577
            except KeyError:
556
578
                process.get('orders')[key] = process.get('overall_orders')[key]
557
 
 
 
579
                
558
580
        assert model.get('particles'), \
559
581
           "particles are missing in model: %s" %  model.get('particles')
560
582
 
564
586
 
565
587
        res = base_objects.DiagramList()
566
588
        # First check that the number of fermions is even
567
 
        if len(filter(lambda leg: model.get('particle_dict')[\
568
 
                        leg.get('id')].is_fermion(), legs)) % 2 == 1:
 
589
        if len([leg for leg in legs if model.get('particle_dict')[\
 
590
                        leg.get('id')].is_fermion()]) % 2 == 1:
569
591
            if not returndiag:
570
592
                self['diagrams'] = res
571
 
                raise InvalidCmd, 'The number of fermion is odd'
 
593
                raise InvalidCmd('The number of fermion is odd')
572
594
            else:
573
595
                return False, res
574
596
 
575
597
        # Then check same number of incoming and outgoing fermions (if
576
598
        # no Majorana particles in model)
577
599
        if not model.get('got_majoranas') and \
578
 
           len(filter(lambda leg: leg.is_incoming_fermion(model), legs)) != \
579
 
           len(filter(lambda leg: leg.is_outgoing_fermion(model), legs)):
 
600
           len([leg for leg in legs if leg.is_incoming_fermion(model)]) != \
 
601
           len([leg for leg in legs if leg.is_outgoing_fermion(model)]):
580
602
            if not returndiag:
581
603
                self['diagrams'] = res
582
 
                raise InvalidCmd, 'The number of of incoming/outcoming fermions are different'
 
604
                raise InvalidCmd('The number of of incoming/outcoming fermions are different')
583
605
            else:
584
606
                return False, res
585
607
 
605
627
            if abs(total) > 1e-10:
606
628
                if not returndiag:
607
629
                    self['diagrams'] = res
608
 
                    raise InvalidCmd, 'No %s conservation for this process ' % charge
 
630
                    raise InvalidCmd('No %s conservation for this process ' % charge)
609
631
                    return res
610
632
                else:
611
 
                    raise InvalidCmd, 'No %s conservation for this process ' % charge
 
633
                    raise InvalidCmd('No %s conservation for this process ' % charge)
612
634
                    return res, res
613
635
 
614
636
        if not returndiag:
699
721
            # extra vertex corresponding to particle 1=1, so we need
700
722
            # to exclude the two last vertexes.
701
723
            if is_decay_proc: lastvx = -2
702
 
            ninitial = len(filter(lambda leg: leg.get('state') == False,
703
 
                                  process.get('legs')))
 
724
            ninitial = len([leg for leg in process.get('legs') if leg.get('state') == False])
704
725
            # Check required s-channels for each list in required_s_channels
705
726
            old_res = res
706
727
            res = base_objects.DiagramList()
707
728
            for id_list in process.get('required_s_channels'):
708
 
                res_diags = filter(lambda diagram: \
709
 
                          all([req_s_channel in \
 
729
                res_diags = [diagram for diagram in old_res if all([req_s_channel in \
710
730
                               [vertex.get_s_channel_id(\
711
731
                               process.get('model'), ninitial) \
712
732
                               for vertex in diagram.get('vertices')[:lastvx]] \
713
733
                               for req_s_channel in \
714
 
                               id_list]), old_res)
 
734
                               id_list])]
715
735
                # Add diagrams only if not already in res
716
736
                res.extend([diag for diag in res_diags if diag not in res])
717
737
 
720
740
        # Note that we shouldn't look at the last vertex in each
721
741
        # diagram, since that is the n->0 vertex
722
742
        if process.get('forbidden_s_channels'):
723
 
            ninitial = len(filter(lambda leg: leg.get('state') == False,
724
 
                                  process.get('legs')))
 
743
            ninitial = len([leg for leg in process.get('legs') if leg.get('state') == False])
725
744
            if ninitial == 2:
726
745
                res = base_objects.DiagramList(\
727
 
                filter(lambda diagram: \
728
 
                       not any([vertex.get_s_channel_id(\
 
746
                [diagram for diagram in res if not any([vertex.get_s_channel_id(\
729
747
                           process.get('model'), ninitial) \
730
748
                                in process.get('forbidden_s_channels')
731
 
                                for vertex in diagram.get('vertices')[:-1]]),
732
 
                       res))
 
749
                                for vertex in diagram.get('vertices')[:-1]])])
733
750
            else:
734
751
                # split since we need to avoid that the initial particle is forbidden 
735
752
                # as well. 
742
759
                    vertex =  diagram.get('vertices')[-1]
743
760
                    if any([l['number'] ==1 for l in vertex.get('legs')]):
744
761
                        leg1 = [l['number'] for l in vertex.get('legs') if l['number'] !=1][0]
745
 
                    to_loop = range(len(diagram.get('vertices'))-1)
 
762
                    to_loop = list(range(len(diagram.get('vertices'))-1))
746
763
                    if leg1 >1:   
747
764
                        to_loop.reverse()
748
765
                    for i in to_loop:
762
779
        # Mark forbidden (onshell) s-channel propagators, to forbid onshell
763
780
        # generation.
764
781
        if process.get('forbidden_onsh_s_channels'):
765
 
            ninitial = len(filter(lambda leg: leg.get('state') == False,
766
 
                              process.get('legs')))
 
782
            ninitial = len([leg for leg in process.get('legs') if leg.get('state') == False])
767
783
            
768
784
            verts = base_objects.VertexList(sum([[vertex for vertex \
769
785
                                                  in diagram.get('vertices')[:-1]
808
824
                    nexttolastvertex = copy.copy(vertices.pop())
809
825
                    legs = copy.copy(nexttolastvertex.get('legs'))
810
826
                    ntlnumber = legs[-1].get('number')
811
 
                    lastleg = filter(lambda leg: leg.get('number') != ntlnumber,
812
 
                                     lastvx.get('legs'))[0]
 
827
                    lastleg = [leg for leg in lastvx.get('legs') if leg.get('number') != ntlnumber][0]
813
828
                    # Reset onshell in case we have forbidden s-channels
814
829
                    if lastleg.get('onshell') == False:
815
830
                        lastleg.set('onshell', None)
896
911
                                             fcts=['remove_diag'])
897
912
        else:
898
913
            #example and simple tests
899
 
            def remove_diag(diag):
 
914
            def remove_diag(diag, model=None):
900
915
                for vertex in diag['vertices']: #last 
901
916
                    if vertex['id'] == 0: #special final vertex
902
917
                        continue 
907
922
 
908
923
        res = diag_list.__class__()                
909
924
        nb_removed = 0 
 
925
        model = self['process']['model'] 
910
926
        for diag in diag_list:
911
 
            if remove_diag(diag):
 
927
            if remove_diag(diag, model):
912
928
                nb_removed +=1
913
929
            else:
914
930
                res.append(diag)
1175
1191
                    number = min([leg.get('number') for leg in entry])
1176
1192
                    # 3) state is final, unless there is exactly one initial 
1177
1193
                    # state particle involved in the combination -> t-channel
1178
 
                    if len(filter(lambda leg: leg.get('state') == False,
1179
 
                                  entry)) == 1:
 
1194
                    if len([leg for leg in entry if leg.get('state') == False]) == 1:
1180
1195
                        state = False
1181
1196
                    else:
1182
1197
                        state = True
1361
1376
 
1362
1377
            for process in argument.get('decay_chains'):
1363
1378
                if process.get('perturbation_couplings'):
1364
 
                    raise MadGraph5Error,\
1365
 
                          "Decay processes can not be perturbed"
 
1379
                    raise MadGraph5Error("Decay processes can not be perturbed")
1366
1380
                process.set('overall_orders', argument.get('overall_orders'))
1367
1381
                if not process.get('is_decay_chain'):
1368
1382
                    process.set('is_decay_chain',True)
1369
1383
                if not process.get_ninitial() == 1:
1370
 
                    raise InvalidCmd,\
1371
 
                          "Decay chain process must have exactly one" + \
1372
 
                          " incoming particle"
 
1384
                    raise InvalidCmd("Decay chain process must have exactly one" + \
 
1385
                          " incoming particle")
1373
1386
                self['decay_chains'].append(\
1374
1387
                    DecayChainAmplitude(process, collect_mirror_procs,
1375
1388
                                        ignore_six_quark_processes,
1442
1455
 
1443
1456
        if name == 'amplitudes':
1444
1457
            if not isinstance(value, AmplitudeList):
1445
 
                raise self.PhysicsObjectError, \
1446
 
                        "%s is not a valid AmplitudeList" % str(value)
 
1458
                raise self.PhysicsObjectError("%s is not a valid AmplitudeList" % str(value))
1447
1459
        if name == 'decay_chains':
1448
1460
            if not isinstance(value, DecayChainAmplitudeList):
1449
 
                raise self.PhysicsObjectError, \
1450
 
                        "%s is not a valid DecayChainAmplitudeList object" % \
1451
 
                        str(value)
 
1461
                raise self.PhysicsObjectError("%s is not a valid DecayChainAmplitudeList object" % \
 
1462
                        str(value))
1452
1463
        return True
1453
1464
 
1454
1465
    def get_sorted_keys(self):
1600
1611
 
1601
1612
        if name == 'process_definitions':
1602
1613
            if not isinstance(value, base_objects.ProcessDefinitionList):
1603
 
                raise self.PhysicsObjectError, \
1604
 
                        "%s is not a valid ProcessDefinitionList object" % str(value)
 
1614
                raise self.PhysicsObjectError("%s is not a valid ProcessDefinitionList object" % str(value))
1605
1615
 
1606
1616
        if name == 'amplitudes':
1607
1617
            if not isinstance(value, AmplitudeList):
1608
 
                raise self.PhysicsObjectError, \
1609
 
                        "%s is not a valid AmplitudeList object" % str(value)
 
1618
                raise self.PhysicsObjectError("%s is not a valid AmplitudeList object" % str(value))
1610
1619
 
1611
1620
        if name in ['collect_mirror_procs']:
1612
1621
            if not isinstance(value, bool):
1613
 
                raise self.PhysicsObjectError, \
1614
 
                        "%s is not a valid boolean" % str(value)
 
1622
                raise self.PhysicsObjectError("%s is not a valid boolean" % str(value))
1615
1623
 
1616
1624
        if name == 'ignore_six_quark_processes':
1617
1625
            if not isinstance(value, list):
1618
 
                raise self.PhysicsObjectError, \
1619
 
                        "%s is not a valid list" % str(value)
 
1626
                raise self.PhysicsObjectError("%s is not a valid list" % str(value))
1620
1627
 
1621
1628
        return True
1622
1629
 
1648
1655
        """Return process property names as a nicely sorted list."""
1649
1656
 
1650
1657
        return ['process_definitions', 'amplitudes']
 
1658
    
 
1659
    def get_model(self):
 
1660
        
 
1661
        return self['process_definitions'][0]['model']
1651
1662
 
1652
1663
    @classmethod
1653
1664
    def generate_multi_amplitudes(cls,process_definition,
1666
1677
                                    "%s not valid ProcessDefinition object" % \
1667
1678
                                    repr(process_definition)
1668
1679
 
1669
 
        # Set automatic coupling orders
1670
 
        process_definition.set('orders', MultiProcess.\
 
1680
        # Set automatic coupling orders if born_sq_orders are not specified
 
1681
        # otherwise skip
 
1682
        if not process_definition['born_sq_orders']:
 
1683
            process_definition.set('orders', MultiProcess.\
1671
1684
                               find_optimal_process_orders(process_definition,
1672
 
                               diagram_filter))
 
1685
                                                           diagram_filter))
1673
1686
        # Check for maximum orders from the model
1674
1687
        process_definition.check_expansion_orders()
1675
1688
 
1839
1852
            if len(failed_procs) == 1 and 'error' in locals():
1840
1853
                raise error
1841
1854
            else:
1842
 
                raise NoDiagramException, \
1843
 
            "No amplitudes generated from process %s. Please enter a valid process" % \
1844
 
                  process_definition.nice_string()
 
1855
                raise NoDiagramException("No amplitudes generated from process %s. Please enter a valid process" % \
 
1856
                  process_definition.nice_string())
1845
1857
        
1846
1858
 
1847
1859
        # Return the produced amplitudes
1923
1935
        
1924
1936
        # Extract the initial and final leg ids
1925
1937
        isids = [leg['ids'] for leg in \
1926
 
                 filter(lambda leg: leg['state'] == False, process_definition['legs'])]
 
1938
                 [leg for leg in process_definition['legs'] if leg['state'] == False]]
1927
1939
        fsids = [leg['ids'] for leg in \
1928
 
                 filter(lambda leg: leg['state'] == True, process_definition['legs'])]
 
1940
                 [leg for leg in process_definition['legs'] if leg['state'] == True]]
1929
1941
 
1930
1942
        max_WEIGHTED_order = \
1931
1943
                        (len(fsids + isids) - 2)*int(model.get_max_WEIGHTED())
1932
1944
        # get the definition of the WEIGHTED
1933
1945
        hierarchydef = process_definition['model'].get('order_hierarchy')
1934
1946
        tmp = []
1935
 
        hierarchy = hierarchydef.items()
 
1947
        hierarchy = list(hierarchydef.items())
1936
1948
        hierarchy.sort()
1937
1949
        for key, value in hierarchydef.items():
1938
1950
            if value>1:
1952
1964
            # based on crossing symmetry
1953
1965
            failed_procs = []
1954
1966
            # Generate all combinations for the initial state        
1955
 
            for prod in apply(itertools.product, isids):
 
1967
            for prod in itertools.product(*isids):
1956
1968
                islegs = [ base_objects.Leg({'id':id, 'state': False}) \
1957
1969
                        for id in prod]
1958
1970
 
1961
1973
 
1962
1974
                red_fsidlist = []
1963
1975
 
1964
 
                for prod in apply(itertools.product, fsids):
 
1976
                for prod in itertools.product(*fsids):
1965
1977
 
1966
1978
                    # Remove double counting between final states
1967
1979
                    if tuple(sorted(prod)) in red_fsidlist:
2030
2042
                    amplitude = Amplitude({'process': process})
2031
2043
                    try:
2032
2044
                        amplitude.generate_diagrams(diagram_filter=diagram_filter)
2033
 
                    except InvalidCmd, error:
 
2045
                    except InvalidCmd as error:
2034
2046
                        failed_procs.append(tuple(sorted_legs))
2035
2047
                    else:
2036
2048
                        if amplitude.get('diagrams'):
2050
2062
    def cross_amplitude(amplitude, process, org_perm, new_perm):
2051
2063
        """Return the amplitude crossed with the permutation new_perm"""
2052
2064
        # Create dict from original leg numbers to new leg numbers
2053
 
        perm_map = dict(zip(org_perm, new_perm))
 
2065
        perm_map = dict(list(zip(org_perm, new_perm)))
2054
2066
        # Initiate new amplitude
2055
2067
        new_amp = copy.copy(amplitude)
2056
2068
        # Number legs
2091
2103
        else:
2092
2104
            tmplist.append([item])
2093
2105
 
2094
 
    for item in apply(itertools.product, tmplist):
 
2106
    for item in itertools.product(*tmplist):
2095
2107
        res.append(list(item))
2096
2108
 
2097
2109
    return res