API

Couplings dataclass

A data class for representing orbital couplings.

Attributes:
  • hamiltonian (_SortedTupleDict) –

    A dictionary-like container holding Hamiltonian terms.

  • coulomb (_SortedTupleDict) –

    A dictionary-like container for Coulomb interaction terms.

  • dipole_transitions (_SortedTupleDict) –

    A dictionary-like container for storing dipole transition elements.

Source code in src/granad/orbitals.py
@dataclass
class Couplings:
    """
    A data class for representing orbital couplings.

    Attributes:
        hamiltonian (_SortedTupleDict): A dictionary-like container holding Hamiltonian terms.
        coulomb (_SortedTupleDict): A dictionary-like container for Coulomb interaction terms.
        dipole_transitions (_SortedTupleDict): A dictionary-like container for storing dipole transition elements.
    """

    hamiltonian : _SortedTupleDict = field(default_factory=_SortedTupleDict)
    coulomb : _SortedTupleDict = field(default_factory=_SortedTupleDict)
    dipole_transitions : _SortedTupleDict = field(default_factory=_SortedTupleDict)

    def __str__( self ):
        return pformat(vars(self), sort_dicts=False)

    def __add__( self, other ):
        if isinstance(other, Couplings):
            return Couplings(
                _SortedTupleDict(self.hamiltonian | other.hamiltonian),
                _SortedTupleDict(self.coulomb | other.coulomb),
                _SortedTupleDict(self.dipole_transitions | other.dipole_transitions)
            )
        # defer to Python, which raises TypeError for unsupported operand types
        return NotImplemented
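The merge in __add__ relies on Python's dict union operator, in which the right-hand operand wins whenever both sides hold the same key. The sketch below illustrates that behaviour with a hypothetical SortedTupleDict stand-in. It assumes, going by the name only, that granad's _SortedTupleDict stores every tuple key in sorted order, so (a, b) and (b, a) address the same coupling; plain integers replace orbitals as keys for brevity.

```python
class SortedTupleDict(dict):
    """Hypothetical stand-in for granad's _SortedTupleDict (assumed behaviour):
    tuple keys are stored sorted, so (a, b) and (b, a) refer to the same entry."""

    def __init__(self, data=None):
        super().__init__()
        for key, value in (data or {}).items():
            self[key] = value  # route through __setitem__ so keys get normalized

    def __setitem__(self, key, value):
        super().__setitem__(tuple(sorted(key)), value)

    def __getitem__(self, key):
        return super().__getitem__(tuple(sorted(key)))


def merge(left, right):
    # dict union: on shared keys the value from `right` overrides the one from `left`
    return SortedTupleDict(left | right)


a = SortedTupleDict()
a[(2, 1)] = -2.66   # stored under the key (1, 2)
a[(1, 1)] = 0.0
b = SortedTupleDict()
b[(1, 2)] = -3.0

merged = merge(a, b)
print(merged[(2, 1)])  # -3.0: the right-hand operand wins
```

The same rule applies to Couplings: in c1 + c2, any coupling defined in both summands takes its value from c2.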

Orbital dataclass

Attributes:
  • position (Array) –

    The position of the orbital in space, initialized by default to a zero position. This field is not used in hashing or comparison of instances.

  • layer_index (Optional[int]) –

    An optional index representing the layer of the orbital within its atom; None if not specified.

  • tag (Optional[str]) –

    An optional tag for additional identification or categorization of the orbital, defaults to None.

  • spin (Optional[int]) –

    The spin quantum number of the orbital, indicating its intrinsic angular momentum; optional, may be None. Note: this field is experimental.

  • atom_name (Optional[str]) –

    The name of the atom this orbital belongs to, can be None if not applicable.

  • group_id (int) –

    A group identifier for the orbital, assigned automatically by the default factory of the Watchdog class (_Watchdog.next_value). For example, all pz orbitals in a single graphene flake get the same group_id.
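A minimal sketch of how a counter-based default factory hands out group ids. GroupId, next_group_id and SketchOrbital are hypothetical stand-ins for _watchdog.GroupId, _Watchdog.next_value and Orbital. How granad makes all orbitals of one flake share an id is not shown in this section, so passing one id explicitly is an assumption made for illustration.

```python
import itertools
from dataclasses import dataclass, field
from typing import Optional


@dataclass(frozen=True)
class GroupId:
    """Hypothetical stand-in for granad's _watchdog.GroupId."""
    id: int


_ids = itertools.count()


def next_group_id() -> GroupId:
    """Hypothetical stand-in for _Watchdog.next_value: a fresh id on every call."""
    return GroupId(next(_ids))


@dataclass
class SketchOrbital:
    tag: Optional[str] = None
    group_id: GroupId = field(default_factory=next_group_id)


# Default construction: every orbital lands in its own group.
a, b = SketchOrbital(tag="pz"), SketchOrbital(tag="pz")
assert a.group_id != b.group_id

# Assumption: sharing one id groups several orbitals, e.g. all pz orbitals of one flake.
flake_id = next_group_id()
flake = [SketchOrbital(tag="pz", group_id=flake_id) for _ in range(4)]
assert {orb.group_id for orb in flake} == {flake_id}
```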

Source code in src/granad/orbitals.py
@dataclass
class Orbital:
    """
    Attributes:
        position (jax.Array): The position of the orbital in space, initialized by default to a zero position.
                              This field is not used in hashing or comparison of instances.
        layer_index (Optional[int]): An optional index representing the layer of the orbital within its atom,
                                     may be None if not specified.
        tag (Optional[str]): An optional tag for additional identification or categorization of the orbital,
                             defaults to None.
        spin (Optional[int]): The spin quantum number of the orbital, indicating its intrinsic angular momentum,
                              optional and may be None. *Note* This is experimental.
        atom_name (Optional[str]): The name of the atom this orbital belongs to, can be None if not applicable.
        group_id (int): A group identifier for the orbital, automatically assigned by a Watchdog class
                        default factory method. For example, all pz orbitals in a single graphene flake get the same 
                        group_id.
    """
    position: jax.Array = field(default_factory=lambda : jnp.array([0, 0, 0]), hash=False, compare=False)
    layer_index: Optional[int] = None
    tag: Optional[str] = None
    spin: Optional[int] = None
    atom_name: Optional[str] = None
    group_id: _watchdog.GroupId = field(default_factory=_watchdog._Watchdog.next_value)

    def __post_init__(self):
        object.__setattr__(self, "position", jnp.array(self.position).astype(float))

    def __hash__(self):
        # Hash exactly the fields compared in __eq__, so orbitals that compare equal hash equally
        return hash(
            (
                self.group_id.id,
                self.layer_index,
            )
        )

    def __str__(self):
        return pformat(vars(self), sort_dicts=False)

    def __eq__(self, other):
        if not isinstance(other, Orbital):
            return NotImplemented
        return self.group_id == other.group_id and self.layer_index == other.layer_index

    def __lt__(self, other):
        if not isinstance(other, Orbital):
            return NotImplemented
        return self.group_id < other.group_id

    def __le__(self, other):
        return self < other or self == other

    def __gt__(self, other):
        return not self <= other

    def __ge__(self, other):
        return not self < other

    def __ne__(self, other):
        return not self == other
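With these methods, position (and, in __eq__, also tag, spin and atom_name) plays no part in equality, while ordering uses group_id alone. The hypothetical MiniOrbital below reproduces that scheme with plain tuples instead of jax arrays. Its __hash__ covers exactly the fields that __eq__ compares, which is what Python requires of hashable objects that compare equal.

```python
from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass
class MiniOrbital:
    """Hypothetical reduced Orbital: same comparison rules, tuple positions instead of jax arrays."""
    position: Tuple[float, float, float] = field(default=(0.0, 0.0, 0.0), compare=False, hash=False)
    layer_index: Optional[int] = None
    tag: Optional[str] = None
    group_id: int = 0

    def __eq__(self, other):
        if not isinstance(other, MiniOrbital):
            return NotImplemented
        return (self.group_id, self.layer_index) == (other.group_id, other.layer_index)

    def __hash__(self):
        # hash exactly the fields __eq__ compares, so equal objects share a hash
        return hash((self.group_id, self.layer_index))

    def __lt__(self, other):
        if not isinstance(other, MiniOrbital):
            return NotImplemented
        return self.group_id < other.group_id


a = MiniOrbital(position=(0.0, 0.0, 0.0), tag="A", group_id=1)
b = MiniOrbital(position=(1.0, 2.0, 3.0), tag="B", group_id=1)
c = MiniOrbital(tag="C", group_id=0)

assert a == b                  # position and tag are ignored
assert len({a, b}) == 1        # consistent hashing collapses equal orbitals in a set
assert sorted([a, c])[0] is c  # ordering uses group_id only
```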

OrbitalList

A class that encapsulates a list of orbitals, providing an interface similar to a standard Python list, while also maintaining additional functionalities for coupling orbitals and managing their relationships.

The class stores orbitals in a wrapped Python list and handles the coupling of orbitals using dictionaries, where the keys are tuples of orbital identifiers (orb_id), and the values are the couplings (either a float or a function representing the coupling strength or mechanism between the orbitals).

The class also stores simulation parameters like the number of electrons and temperature in a dataclass.

The class computes physical observables (energies, etc.) lazily, i.e. only when they are first needed. If a quantity can reasonably be associated with a basis (site or energy), the class exposes quantity_x as an attribute for the site basis and quantity_e as an attribute for the energy basis. By default, all quantities are in the site basis, so quantity_x == quantity.
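A toy illustration of lazy evaluation combined with the _x / _e suffix lookup. Everything here is hypothetical: mutates only mimics the idea of flagging a rebuild after a mutation, and transform_to_energy_basis is assumed to compute U^T A U, with the eigenvectors as the columns of U (U^† A U for complex U). Nested Python lists stand in for jax arrays.

```python
import functools
import math


def mutates(method):
    """Hypothetical decorator: after a mutating call, mark cached observables as stale."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        self._recompute = True
        return result
    return wrapper


def matmul(a, b):
    """Product of two square matrices stored as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


class ToyOrbitalList:
    """Hypothetical two-site model with a single hopping t."""

    def __init__(self, hopping):
        self._hopping = hopping
        self._recompute = True
        self.builds = 0  # counts how often the expensive step ran

    @mutates
    def set_hopping(self, hopping):
        self._hopping = hopping

    def _build(self):
        t = self._hopping
        self._hamiltonian = [[0.0, t], [t, 0.0]]
        s = 1 / math.sqrt(2)
        self._eigenvectors = [[s, s], [s, -s]]  # columns diagonalize [[0, t], [t, 0]]
        self.builds += 1
        self._recompute = False

    @property
    def hamiltonian(self):
        if self._recompute:  # rebuild only when something changed
            self._build()
        return self._hamiltonian

    def transform_to_energy_basis(self, matrix):
        if self._recompute:
            self._build()
        u = self._eigenvectors
        u_t = [list(col) for col in zip(*u)]
        return matmul(matmul(u_t, matrix), u)  # U^T A U

    def __getattr__(self, name):
        # only reached when normal lookup fails, e.g. for "hamiltonian_e"
        if name.endswith("_x"):
            return getattr(self, name[:-2])
        if name.endswith("_e"):
            return self.transform_to_energy_basis(getattr(self, name[:-2]))
        raise AttributeError(name)


orbs = ToyOrbitalList(hopping=-1.0)
print(orbs.builds)         # 0: nothing has been computed yet
h_x = orbs.hamiltonian_x   # site basis, same object as orbs.hamiltonian; triggers the first build
h_e = orbs.hamiltonian_e   # energy basis, approximately [[-1, 0], [0, 1]]
orbs.set_hopping(-2.0)     # the mutation marks the cache as stale ...
orbs.hamiltonian           # ... so this access rebuilds
print(orbs.builds)         # 2
```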

The class exposes simulation methods.

Attributes:
  • _list (list) –

    The underlying list that contains the orbitals.

  • params (Params) –

    Simulation parameters like electron count and temperature.

  • couplings (_SortedTupleDict) –

    A (customized) dictionary where keys are tuples of orbital identifiers and values are the couplings (either float values or functions).
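The couplings container can be sketched as a plain dict: each identifier from the first group is paired with each identifier from the second, and every pair receives the same value or function, promoted to complex output in the same way as OrbitalList._ensure_complex in the source. set_coupling and ensure_complex are hypothetical helper names; the real _SortedTupleDict additionally normalizes key order.

```python
from typing import Callable, Dict, Hashable, Iterable, Tuple, Union

Coupling = Union[complex, Callable[[float], complex]]


def ensure_complex(func_or_val) -> Coupling:
    """Promote a number, or the output of a function, to complex."""
    if callable(func_or_val):
        return lambda x: func_or_val(x) + 0.0j
    if isinstance(func_or_val, (int, float, complex)):
        return func_or_val + 0.0j
    raise TypeError(f"unsupported coupling type: {type(func_or_val).__name__}")


def set_coupling(first: Iterable[Hashable], second: Iterable[Hashable], val_or_func,
                 couplings: Dict[Tuple[Hashable, Hashable], Coupling]) -> None:
    """Assign the same coupling to every (o1, o2) pair drawn from the two groups."""
    value = ensure_complex(val_or_func)
    second = list(second)  # iterated once per element of `first`
    for o1 in first:
        for o2 in second:
            couplings[(o1, o2)] = value


couplings = {}
set_coupling(["A0", "A1"], ["B0"], -2.66, couplings)                 # constant hopping
set_coupling(["g0"], ["g1"], lambda d: 1.0 / (d + 1.0), couplings)   # distance-dependent function
print(couplings[("A1", "B0")])        # (-2.66+0j)
print(couplings[("g0", "g1")](1.0))   # (0.5+0j)
```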

Note
  • Orbital Identification: Orbitals can be identified either by their group_id, a direct reference to the orbital object itself, or via a user-defined tag.
  • Index Access: Orbitals can be accessed and managed by their index in the list, allowing for list-like manipulation (addition, removal, access).
  • Coupling Definition: Allows for the definition and adjustment of couplings between pairs of orbitals, identified by a tuple of their respective identifiers. These couplings can dynamically represent the interaction strength or be a computational function that defines the interaction.
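A minimal sketch of identifier resolution, modelled loosely on OrbitalList.filter_orbs(orb_id, Orbital): an index selects one orbital, a tag or group id selects every orbital that carries it, and an orbital reference resolves to itself. Orb, GroupId and resolve are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass(frozen=True)
class GroupId:
    """Hypothetical group identifier, distinct from a plain integer index."""
    id: int


@dataclass(eq=False)
class Orb:
    """Hypothetical minimal orbital; eq=False keeps identity-based equality."""
    tag: Optional[str]
    group_id: GroupId


def resolve(orbs: List[Orb], orb_id: Union[int, str, GroupId, Orb]) -> List[Orb]:
    """Map an index, tag, group id or orbital to the list of matching orbitals."""
    if isinstance(orb_id, Orb):
        return [orb_id]
    if isinstance(orb_id, GroupId):
        return [o for o in orbs if o.group_id == orb_id]
    if isinstance(orb_id, int):
        return [orbs[orb_id]]
    if isinstance(orb_id, str):
        return [o for o in orbs if o.tag == orb_id]
    raise TypeError(f"cannot resolve identifier of type {type(orb_id).__name__}")


flake, adatom = GroupId(0), GroupId(1)
orbs = [Orb("pz", flake), Orb("pz", flake), Orb("adatom", adatom)]
assert resolve(orbs, 2) == [orbs[2]]         # index
assert resolve(orbs, "pz") == orbs[:2]       # tag
assert resolve(orbs, flake) == orbs[:2]      # group id
assert resolve(orbs, orbs[1]) == [orbs[1]]   # orbital reference
```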
Source code in src/granad/orbitals.py
@plotting_methods
class OrbitalList:
    """
    A class that encapsulates a list of orbitals, providing an interface similar to a standard Python list,
    while also maintaining additional functionalities for coupling orbitals and managing their relationships.

    The class stores orbitals in a wrapped Python list and handles the coupling of orbitals using dictionaries,
    where the keys are tuples of orbital identifiers (orb_id), and the values are the couplings (either a float
    or a function representing the coupling strength or mechanism between the orbitals).

    The class also stores simulation parameters like the number of electrons and temperature in a dataclass.

    The class computes physical observables (energies etc) lazily on the fly, when they are needed. If there is 
    a basis (either site or energy) to reasonably associate with a quantity, the class exposes quantity_x as an attribute
    for the site basis and quantity_e as an attribute for the energy basis. By default, all quantities are in site basis, so
    quantity_x == quantity.

    The class exposes simulation methods.

    Attributes:
        _list (list): The underlying list that contains the orbitals.
        params (Params): Simulation parameters like electron count and temperature.
        couplings (_SortedTupleDict): A (customized) dictionary where keys are tuples of orbital identifiers and values are the couplings
                          (either float values or functions).

    Note:
        - **Orbital Identification**: Orbitals can be identified either by their group_id, a direct
          reference to the orbital object itself, or via a user-defined tag.
        - **Index Access**: Orbitals can be accessed and managed by their index in the list, allowing for
          list-like manipulation (addition, removal, access).
        - **Coupling Definition**: Allows for the definition and adjustment of couplings between pairs of orbitals,
          identified by a tuple of their respective identifiers. These couplings can dynamically represent the
          interaction strength or be a computational function that defines the interaction.
    """
    def __init__( self, orbs = None, couplings = None, params = None, recompute = True):
        self._list = orbs if orbs is not None else []
        self.couplings = couplings if couplings is not None else Couplings( )
        self.params = params if params is not None else Params( len(self._list) )
        self._recompute = recompute

    def __getattr__(self, property_name):
        if property_name.endswith("_x"):
            original_name = property_name[:-2]
            try:
                return getattr(self, original_name)
            except AttributeError:
                pass 
        elif property_name.endswith("_e"):
            original_name = property_name[:-2]
            try:
                return self.transform_to_energy_basis(getattr(self, original_name))
            except AttributeError:
                pass            
        raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {property_name!r}")

    def __len__(self):
        return len(self._list)

    # can't mutate, because orbitals are immutable
    def __getitem__(self, position):
        return self._list[position]

    def __repr__(self):
        info = f"List with {len(self)} orbitals, {self.electrons} electrons."
        exc = self.params.excitation
        info += f"\nExcitation: {exc[2]} electrons excited from energy levels {exc[0]} to {exc[1]}."
        info += f"\nIncluded tags with number of orbitals: {dict(Counter(o.tag for o in self))}"
        return info 

    def __iter__(self):
        return iter(self._list)

    def __add__(self, other):
        if not isinstance(other, OrbitalList) or not self._are_orbs(other):
            raise TypeError

        if any(orb in other for orb in self._list):
            raise ValueError

        new_list = (self._list + list(other)).copy()
        new_couplings = self.couplings + other.couplings
        new_params = self.params + other.params

        return OrbitalList( new_list, new_couplings, new_params )

    @mutates
    def __setitem__(self, position, value):
        if not isinstance(value, Orbital):
            raise TypeError
        self._list[position] = value

    def _delete_coupling(self, orb, coupling):
        keys_to_remove = [key for key in coupling if orb in key]
        for key in keys_to_remove:
            del coupling[key]

    @mutates
    def __delitem__(self, position):
        orb = self._list[position]
        self._delete_coupling(orb, self.couplings.hamiltonian)
        self._delete_coupling(orb, self.couplings.coulomb)
        self._delete_coupling(orb, self.couplings.dipole_transitions)
        self.params.electrons -= 1
        del self._list[position]

    @staticmethod
    def _are_orbs(candidate):
        return all(isinstance(orb, Orbital) for orb in candidate)

    @mutates
    def _set_coupling(self, orb1, orb2, val_or_func, coupling):
        for o1 in orb1:
            for o2 in orb2:
                coupling[(o1, o2)] = val_or_func

    def _hamiltonian_coulomb(self):

        def fill_matrix(matrix, coupling_dict):

            # matrix is NxN and hermitian
            # we fill the lower triangle (including the diagonal) with a mask and make the matrix hermitian by adding its conjugate transpose
            dummy = jnp.arange(len(self))
            triangle_mask = dummy[:, None] >= dummy

            # first, we loop over all group_id couplings => interactions between groups
            for key, function in coupling_dict.group_id_items():
                # if it were the other way around, these would be zeroed by the triangle mask
                cols = group_ids == key[0].id
                rows = (group_ids == key[1].id)[:, None] 
                combination_indices = jnp.logical_and(rows, cols)                
                valid_indices = jnp.logical_and(triangle_mask, combination_indices)

                # hotfix
                if valid_indices.sum() == 0:
                    valid_indices = jnp.logical_and(triangle_mask.T, combination_indices)

                function = jax.vmap(function)
                matrix = matrix.at[valid_indices].set(
                    function(distances[valid_indices])
                )

            matrix += matrix.conj().T - jnp.diag(jnp.diag(matrix))

            # we now set single elements
            rows, cols, vals = [], [], []
            for key, val in coupling_dict.orbital_items():
                rows.append(self._list.index(key[0]))
                cols.append(self._list.index(key[1]))
                vals.append(val)

            vals = jnp.array(vals)
            matrix = matrix.at[rows, cols].set(vals)
            matrix = matrix.at[cols, rows].set(vals.conj())

            return matrix

        positions = self.positions
        distances = jnp.round(positions - positions[:, None], 6)
        group_ids = jnp.array( [orb.group_id.id for orb in self._list] )

        hamiltonian = fill_matrix(
            jnp.zeros((len(self), len(self))).astype(complex), self.couplings.hamiltonian
        )
        coulomb = fill_matrix(
            jnp.zeros((len(self), len(self))).astype(complex), self.couplings.coulomb
        )
        return hamiltonian, coulomb

    @mutates
    def set_dipole_element(self, orb1, orb2, arr):
        """
        Sets a dipole transition for specified orbital or index pairs.

        Parameters:
            orb1: Identifier for orbital(s) for the first part of the transition.
            orb2: Identifier for orbital(s) for the second part of the transition.
            arr (jax.Array): The 3-element array containing dipole transition elements.
        """
        self._set_coupling(self.filter_orbs(orb1, Orbital), self.filter_orbs(orb2, Orbital), jnp.array(arr).astype(complex), self.couplings.dipole_transitions)

    def set_hamiltonian_groups(self, orb1, orb2, func):
        """
        Sets the hamiltonian coupling between two groups of orbitals.

        Parameters:
            orb1: Identifier for orbital(s) for the first group.
            orb2: Identifier for orbital(s) for the second group.
            func (callable): Function that defines the hamiltonian interaction.

        Note:
            The function `func` should be complex-valued.
        """
        self._set_coupling(
            self.filter_orbs(orb1, _watchdog.GroupId), self.filter_orbs(orb2, _watchdog.GroupId), self._ensure_complex(func), self.couplings.hamiltonian
        )

    def set_coulomb_groups(self, orb1, orb2, func):
        """
        Sets the Coulomb coupling between two groups of orbitals.

        Parameters:
            orb1: Identifier for orbital(s) for the first group.
            orb2: Identifier for orbital(s) for the second group.
            func (callable): Function that defines the Coulomb interaction.

        Note:
            The function `func` should be complex-valued.
        """
        self._set_coupling(
            self.filter_orbs(orb1, _watchdog.GroupId), self.filter_orbs(orb2, _watchdog.GroupId), self._ensure_complex(func), self.couplings.coulomb
        )

    def set_onsite_hopping(self, orb, val):
        """
        Sets the onsite hopping (diagonal) element of the Hamiltonian matrix for the given orbital(s).

        Parameters:
            orb: Identifier for orbital(s).
            val (real): The value to set for the onsite hopping.
        """
        self.set_hamiltonian_element(orb, orb, val)        

    def set_hamiltonian_element(self, orb1, orb2, val):
        """
        Sets an element of the Hamiltonian matrix between two orbitals or indices.

        Parameters:
            orb1: Identifier for orbital(s) for the first element.
            orb2: Identifier for orbital(s) for the second element.
            val (complex): The complex value to set for the Hamiltonian element.
        """
        self._set_coupling(self.filter_orbs(orb1, Orbital), self.filter_orbs(orb2, Orbital), self._ensure_complex(val), self.couplings.hamiltonian)

    def set_coulomb_element(self, orb1, orb2, val):
        """
        Sets a Coulomb interaction element between two orbitals or indices.

        Parameters:
            orb1: Identifier for orbital(s) for the first element.
            orb2: Identifier for orbital(s) for the second element.
            val (complex): The complex value to set for the Coulomb interaction element.
        """
        self._set_coupling(self.filter_orbs(orb1, Orbital), self.filter_orbs(orb2, Orbital), self._ensure_complex(val), self.couplings.coulomb)

    @property
    def center_index(self):
        """Index of the approximate center orbital of the structure."""
        distances = jnp.round(jnp.linalg.norm(self.positions - self.positions[:, None], axis = -1), 4)
        return jnp.argmin(distances.sum(axis=0))

    def localization(self, neighbor_number : int = 6):
        """Compute edge localization of eigenstates according to

        $$
        \\frac{\\sum_{j \\, edge} |\\phi_{j}|^2}{\\sum_i |\\phi_i|^2 }
        $$

        Edge sites are identified by counting their next-nearest neighbors (nnn): a site with fewer
        than `neighbor_number` next-nearest neighbors is treated as an edge site.

        Args:
            neighbor_number (int): Number of next-nearest neighbors of a bulk site, used to identify edges.
                Depends on the lattice and the number of orbitals per site. Defaults to 6, the value for a
                hexagonal lattice with a single orbital per site. For more orbitals, use nnn * num_orbitals.

        Returns:
            jax.Array: localization, where the i-th entry corresponds to the i-th energy eigenstate
        """

        # edges => neighboring unit cells are incomplete => all points that are not inside a "big hexagon" made up of nearest neighbors
        positions, states, energies = self.positions, self.eigenvectors, self.energies 

        distances = jnp.round(jnp.linalg.norm(positions - positions[:, None], axis = -1), 4)
        nnn = jnp.unique(distances)[2]
        mask = (distances == nnn).sum(axis=0) < neighbor_number

        # localization => weight of each eigenstate on the edge sites
        l = (jnp.abs(states[mask, :])**2).sum(axis = 0) # vectors are normed

        return l

    def _ensure_complex(self, func_or_val):
        if callable(func_or_val):
            return lambda x: func_or_val(x) + 0.0j
        if isinstance(func_or_val, (int, float, complex)):
            return func_or_val + 0.0j
        raise TypeError

    def _build(self):

        assert len(self) > 0

        self._hamiltonian, self._coulomb = self._hamiltonian_coulomb()

        self._eigenvectors, self._energies = jax.lax.linalg.eigh(self._hamiltonian)

        self._initial_density_matrix = _numerics._density_matrix(
            self._energies,
            self.params.electrons,
            self.params.spin_degeneracy,
            self.params.eps,
            self.params.excitation,
            self.params.beta,
        )
        self._stationary_density_matrix = _numerics._density_matrix(
            self._energies,
            self.params.electrons,
            self.params.spin_degeneracy,
            self.params.eps,
            Params(0).excitation,
            self.params.beta,
        )

        if len(self.params.self_consistency_params) != 0:
            (
                self._hamiltonian,
                self._initial_density_matrix,
                self._stationary_density_matrix,
                self._energies,
                self._eigenvectors,
            ) = _numerics._get_self_consistent(
                self._hamiltonian,
                self._coulomb,
                self.positions,
                self.params.excitation,
                self.params.spin_degeneracy,
                self.params.electrons,
                self.params.eps,
                self._eigenvectors,
                self._stationary_density_matrix,
                **self.params.self_consistency_params,
            )

        if len(self.params.mean_field_params) != 0:
            (
                self._hamiltonian,
                self._initial_density_matrix,
                self._stationary_density_matrix,
                self._energies,
                self._eigenvectors,
            ) = _numerics._mf_loop(
                self._hamiltonian,
                self._coulomb,
                self.params.excitation,
                self.params.spin_degeneracy,
                self.params.electrons,
                self.params.eps,
                **self.params.mean_field_params,
            )

        eps = 1e-1
        lower = -eps
        upper = 2 + eps
        if jnp.any(self._initial_density_matrix.diagonal() < lower) or jnp.any(self._initial_density_matrix.diagonal() > upper):
            raise Exception("Occupation numbers in initial density matrix are invalid.")

        if jnp.any(self._stationary_density_matrix.diagonal() < lower) or jnp.any(self._stationary_density_matrix.diagonal() > upper ) :
            raise Exception("Occupation numbers in stationary density matrix are invalid.")

        self._initial_density_matrix = self.transform_to_site_basis( self._initial_density_matrix )

        self._stationary_density_matrix = self.transform_to_site_basis( self._stationary_density_matrix )

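    def set_open_shell( self ):
        """Switches to an open-shell calculation (spin degeneracy 1); requires every orbital to have its spin set."""
        if any( orb.spin is None for orb in self._list ):
            raise ValueError("Open-shell calculations require every orbital to have a spin.")
        self.params.spin_degeneracy = 1.0

    def set_closed_shell( self ):
        """Switches to a closed-shell calculation (spin degeneracy 2)."""
        self.params.spin_degeneracy = 2.0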

    def index(self, orb):
        return self._list.index(orb)

    @mutates
    def append(self, other):
        """
        Appends an orbital to the list, ensuring it is not already present.

        Parameters:
            other (Orbital): The orbital to append.

        Raises:
            TypeError: If `other` is not an instance of Orbital.
            ValueError: If `other` is already in the list.
        """
        if not isinstance(other, Orbital):
            raise TypeError
        if other in self:
            raise ValueError
        self._list.append(other)
        self.params.electrons += 1

    def filter_orbs( self, orb_id, t ):
        """Maps a given orb_id (such as an index, tag, group id or orbital) to a list of objects of the required type t."""
        def filter_single_orb(orb_id, t):
            if type(orb_id) is t:
                return [orb_id]

            # index => group_id / orbital / tag of the orbital at that index
            if isinstance(orb_id, int) and t == _watchdog.GroupId:
                return [self._list[orb_id].group_id]
            if isinstance(orb_id, int) and t == Orbital:
                return [self._list[orb_id]]
            if isinstance(orb_id, int) and t == str:
                return [self._list[orb_id].tag]

            # group_id => tags / orbitals / indices of all orbitals in that group
            if isinstance(orb_id, _watchdog.GroupId) and t == str:
                return [ orb.tag for orb in self if orb.group_id == orb_id ]
            if isinstance(orb_id, _watchdog.GroupId) and t == Orbital:
                return [ orb for orb in self if orb.group_id == orb_id ]
            if isinstance(orb_id, _watchdog.GroupId) and t == int:
                return [ i for i, orb in enumerate(self) if orb.group_id == orb_id ]

            # tag => group_ids / indices / orbitals of all orbitals carrying that tag
            if isinstance(orb_id, str) and t == _watchdog.GroupId:
                return [orb.group_id for orb in self if orb.tag == orb_id]
            if isinstance(orb_id, str) and t == int:
                return [i for i, orb in enumerate(self) if orb.tag == orb_id]
            if isinstance(orb_id, str) and t == Orbital:
                return [orb for orb in self if orb.tag == orb_id]

            # orb to index, group, tag
            if isinstance(orb_id, Orbital) and t == _watchdog.GroupId:
                return [orb_id.group_id]
            if isinstance(orb_id, Orbital) and t == int:
                return [self._list.index(orb_id)]
            if isinstance(orb_id, Orbital) and t == str:
                return [orb_id.tag]

        if not isinstance(orb_id, OrbitalList):
            orb_id = [orb_id]

        return [ x for orb in orb_id for x in filter_single_orb(orb, t) ]



    @mutates
    def shift_by_vector(self, translation_vector, orb_id = None):
        """
        Shifts the orbitals matching orb_id by a given vector.

        Parameters:
            translation_vector (list or jax.Array): The vector by which to translate the orbital positions.
            orb_id: Identifier for the orbital(s) to shift. If None (the default), all orbitals are shifted.

        Note:
            This operation mutates the positions of the matched orbitals.
        """
        filtered_orbs = self.filter_orbs( orb_id, Orbital ) if orb_id is not None else self
        for orb in filtered_orbs:
            orb.position += jnp.array(translation_vector)

    @mutates
    def set_position(self, position, orb_id = None):
        """
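        Sets the position of the selected orbitals (all orbitals if `orb_id` is None).

        Parameters:
            position (list or jax.Array): The new position of the orbitals.
            orb_id (optional): Identifier (index, tag, group id, orbital, or orbital list) of the orbital(s) to move. Defaults to all orbitals.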

        Note:
            This operation mutates the positions of the matched orbitals.
        """
        filtered_orbs = self.filter_orbs( orb_id, Orbital ) if orb_id is not None else self
        for orb in filtered_orbs:
            orb.position = jnp.array(position)

    @mutates
    def rotate(self, x, phi, axis = 'z'):
        """Rotates all orbitals by an angle `phi` (in radians) about an axis through the point `x`.

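        A standalone sketch of the same transformation for a single point (plain JAX, not part of the
        GRANAD API; `rotation_matrix_z` and `rotate_about` are hypothetical helper names): the point is
        translated by `-x`, rotated with the z-axis matrix used below, and translated back.

        ```python
        import jax.numpy as jnp

        def rotation_matrix_z(phi):
            # Same z-axis matrix as in `rotate`
            return jnp.array([
                [jnp.cos(phi), -jnp.sin(phi), 0.0],
                [jnp.sin(phi), jnp.cos(phi), 0.0],
                [0.0, 0.0, 1.0],
            ])

        def rotate_about(point, center, phi):
            # Translate by -center, rotate, translate back by +center
            point = jnp.asarray(point, dtype=jnp.float32)
            center = jnp.asarray(center, dtype=jnp.float32)
            return rotation_matrix_z(phi) @ (point - center) + center

        rotated = rotate_about([2.0, 1.0, 0.0], [1.0, 1.0, 0.0], jnp.pi / 2)
        # rotated is approximately [1.0, 2.0, 0.0]
        ```
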
        Args:
            x (jax.Array): 3D point through which the rotation axis passes.
            phi (float): Angle by which to rotate.
            axis (str): Axis to rotate around ('x', 'y', or 'z'). Default is 'z'.
        """

        # Define the rotation matrix based on the specified axis
        if axis == 'x':
            rotation_matrix = jnp.array([
                [1, 0, 0],
                [0, jnp.cos(phi), -jnp.sin(phi)],
                [0, jnp.sin(phi), jnp.cos(phi)]
            ])
        elif axis == 'y':
            rotation_matrix = jnp.array([
                [jnp.cos(phi), 0, jnp.sin(phi)],
                [0, 1, 0],
                [-jnp.sin(phi), 0, jnp.cos(phi)]
            ])
        elif axis == 'z':
            rotation_matrix = jnp.array([
                [jnp.cos(phi), -jnp.sin(phi), 0],
                [jnp.sin(phi), jnp.cos(phi), 0],
                [0, 0, 1]
            ])
        else:
            raise ValueError("Axis must be 'x', 'y', or 'z'.")

        x = jnp.asarray(x)
        for orb in self._list:
            # Translate by -x, rotate, translate back by +x
            self.set_position(rotation_matrix @ (orb.position - x) + x, orb)

    @mutates
    def set_self_consistent(self, **kwargs):
        """
        Configures the parameters for self-consistent field (SCF) calculations.

        This function sets up the self-consistency parameters used in iterative calculations 
        to update the system's density matrix until convergence is achieved.

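        The roles of `accuracy`, `mix`, and `iterations` can be illustrated with a generic linearly mixed
        fixed-point loop. This is a sketch assuming simple linear mixing, new = (1 - mix) * old + mix * candidate;
        it is not GRANAD's SCF implementation, and `mixed_fixed_point` is a hypothetical helper that solves the
        scalar equation x = cos(x).

        ```python
        import jax.numpy as jnp

        def mixed_fixed_point(update, x0, accuracy=1e-6, mix=0.3, iterations=500):
            # Hypothetical helper: linearly mixed fixed-point iteration x -> update(x)
            x = x0
            for step in range(iterations):
                candidate = update(x)
                x_next = (1 - mix) * x + mix * candidate
                # converged once successive iterates differ by less than `accuracy`
                if jnp.max(jnp.abs(x_next - x)) < accuracy:
                    return x_next, step + 1
                x = x_next
            return x, iterations

        x, n_steps = mixed_fixed_point(jnp.cos, jnp.array(1.0))
        ```

        A larger `mix` takes bigger steps toward each new candidate; a smaller one damps oscillations at the
        cost of more iterations.
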
        Args:
            **kwargs: Keyword arguments to override the default self-consistency parameters. 
                The available parameters are:

                - `accuracy` (float, optional): The convergence criterion for self-consistency. 
                  Specifies the maximum allowed difference between successive density matrices.
                  Default is 1e-6.

                - `mix` (float, optional): The mixing factor for the density matrix during updates.
                  This controls the contribution of the new density matrix to the updated one.
                  Values closer to 1 favor the new density matrix, while smaller values favor 
                  smoother convergence. Default is 0.3.

                - `iterations` (int, optional): The maximum number of iterations allowed in the 
                  self-consistency cycle. Default is 500.

                - `coulomb_strength` (float, optional): A scaling factor for the Coulomb matrix.
                  This allows tuning of the strength of Coulomb interactions in the system. 
                  Default is 1.0.

        Example:
            >>> model.set_self_consistent(accuracy=1e-7, mix=0.5, iterations=1000)
            >>> print(model.params.self_consistency_params)
            {'accuracy': 1e-07, 'mix': 0.5, 'iterations': 1000, 'coulomb_strength': 1.0}
        """
        default = {"accuracy" : 1e-6, "mix" : 0.3, "iterations" : 500, "coulomb_strength" : 1.0}
        self.params.self_consistency_params = default | kwargs

    @mutates
    def set_mean_field(self, **kwargs):
        """
        Configures the parameters for mean-field calculations.
        If no further parameters are passed, a standard Hartree-Fock calculation restricted to the direct channel is performed.
        Note that this procedure differs slightly from the self-consistent field procedure configured by `set_self_consistent`.

        This function sets up the mean field parameters used in iterative calculations 
        to update the system's density matrix until convergence is achieved.

        Args:
            **kwargs: Keyword arguments to override the default mean-field parameters. 
                The available parameters are:

                - `accuracy` (float, optional): The convergence criterion for self-consistency. 
                  Specifies the maximum allowed difference between successive density matrices.
                  Default is 1e-6.

                - `mix` (float, optional): The mixing factor for the density matrix during updates.
                  This controls the contribution of the new density matrix to the updated one.
                  Values closer to 1 favor the new density matrix, while smaller values favor 
                  smoother convergence. Default is 0.3.

                - `iterations` (int, optional): The maximum number of iterations allowed in the 
                  self-consistency cycle. Default is 500.

                - `coulomb_strength` (float, optional): A scaling factor for the Coulomb matrix.
                  This allows tuning of the strength of Coulomb interactions in the system. 
                  Default is 1.0.

                - `f_mean_field` (Callable, optional): A function computing the mean-field term.
                  Its first argument is the density matrix, its second the single-particle Hamiltonian.
                  It can be used, e.g., for full Hartree-Fock by passing a closure over the electron repulsion integrals (ERIs).
                  Default is None.

                - `f_build` (Callable, optional): A function constructing the density matrix from energies and eigenvectors.
                  If None, single-particle energy levels are filled according to the number of electrons.
                  Default is None.

                - `rho_0` (jax.Array, optional): Initial guess for the density matrix. If None, zeros are used.
                  Default is None.

        Example:
            >>> model.set_mean_field(accuracy=1e-7, mix=0.5, iterations=1000)
            >>> print(model.params.mean_field_params)
            {'accuracy': 1e-07, 'mix': 0.5, 'iterations': 1000, 'coulomb_strength': 1.0, 'f_mean_field': None, 'f_build': None, 'rho_0': None}
        """
        default = {"accuracy" : 1e-6, "mix" : 0.3, "iterations" : 500, "coulomb_strength" : 1.0, "f_mean_field" : None, "f_build" : None, "rho_0" : None}
        self.params.mean_field_params = default | kwargs


    @mutates
    def set_excitation(self, from_state, to_state, excited_electrons):
        """
        Sets up an excitation process from one state to another with specified electrons.

        Parameters:
            from_state (int, list, or jax.Array): The initial state index or indices.
            to_state (int, list, or jax.Array): The final state index or indices.
            excited_electrons (int, list, or jax.Array): The indices of electrons to be excited.

        Note:
            The states and electron indices may be specified as scalars, lists, or arrays.
        """
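        def maybe_int_to_arr(maybe_int):
            # Normalize scalars, lists, and arrays to 1D jax.Arrays
            if isinstance(maybe_int, int):
                return jnp.array([maybe_int])
            if isinstance(maybe_int, (list, jax.Array)):
                return jnp.atleast_1d(jnp.array(maybe_int))
            raise TypeError(f"Expected int, list, or jax.Array, got {type(maybe_int).__name__}")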

        self.params.excitation = [maybe_int_to_arr(from_state), maybe_int_to_arr(to_state), maybe_int_to_arr(excited_electrons)]

    @property
    def positions(self):
        return jnp.array([orb.position for orb in self._list])

    @property
    def electrons( self ):
        return self.params.electrons

    @mutates
    def set_electrons( self, val ):
        assert val <= self.params.spin_degeneracy * len(self), "Max electrons exceeded"
        self.params.electrons = val

    @property
    def eps( self ):
        return self.params.eps

    @mutates
    def set_eps( self, val ):
        self.params.eps = val

    @property
    def spin_degeneracy( self ):
        return self.params.spin_degeneracy

    @property
    @recomputes
    def homo(self):
        return (self.electrons * self.stationary_density_matrix_e).real.diagonal().round(2).nonzero()[0][-1].item()

    @property
    @recomputes
    def lumo(self):
        return (self.electrons * self.stationary_density_matrix_e).real.diagonal().round(2).nonzero()[0][-1].item() + 1

    @property
    @recomputes
    def eigenvectors(self):
        return self._eigenvectors

    @property
    @recomputes
    def energies(self):
        return self._energies

    @property
    @recomputes
    def hamiltonian(self):
        return self._hamiltonian

    @property
    @recomputes
    def coulomb(self):
        return self._coulomb

    @property
    @recomputes
    def initial_density_matrix(self):
        return self._initial_density_matrix

    @property
    @recomputes
    def stationary_density_matrix(self):
        return self._stationary_density_matrix

    @property
    @recomputes
    def quadrupole_operator(self):
        """
        Calculates the traceless quadrupole operator Q_ab = 3 P_a P_b - delta_ab sum_c P_c P_c from the dipole operator P, where P_a P_b denotes a matrix product in orbital space.

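        A standalone sketch of the traceless form for a generic (3, N, N) dipole array (plain JAX;
        `traceless_quadrupole` is a hypothetical helper, and the random matrices only stand in for a real
        dipole operator). Summing the diagonal direction components gives zero, a quick consistency check.

        ```python
        import jax
        import jax.numpy as jnp

        def traceless_quadrupole(dip):
            # dip: (3, N, N) dipole operator; returns (3, 3, N, N)
            products = jnp.einsum("aij,bjk->abik", dip, dip)    # P_a @ P_b
            r_squared = jnp.einsum("cij,cjk->ik", dip, dip)     # sum_c P_c @ P_c
            return 3 * products - jnp.einsum("ab,ik->abik", jnp.eye(3), r_squared)

        dip = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 4))
        q = traceless_quadrupole(dip)
        trace_over_directions = q[0, 0] + q[1, 1] + q[2, 2]
        ```
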
        Returns:
           jax.Array: A tensor representing the quadrupole operator.
        """

        dip = self.dipole_operator
        # term[a, b] = P_a @ P_b, shape (3, 3, N, N)
        term = jnp.einsum("aij,bjk->abik", dip, dip)
        # r_squared = sum_c P_c @ P_c, shape (N, N)
        r_squared = jnp.einsum("cij,cjk->ik", dip, dip)
        diag = jnp.einsum("ab,ik->abik", jnp.eye(3), r_squared)
        return 3 * term - diag

    @property
    @recomputes
    def dipole_operator(self):
        """
        Computes the dipole operator from the orbital positions and the dipole transitions. The diagonal holds the position components, and the off-diagonal elements hold the dipole transition values.

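        A standalone sketch of the construction for two orbitals (plain JAX; the positions and the single
        z-transition value of 0.5 are illustrative, not GRANAD defaults). Halving the positions on the
        diagonal before adding the conjugate transpose leaves the diagonal equal to the positions and makes
        the result Hermitian.

        ```python
        import jax.numpy as jnp

        positions = jnp.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]])   # (N, 3)
        n = positions.shape[0]
        dip = jnp.zeros((3, n, n), dtype=jnp.complex64)
        for i in range(3):
            # half of each coordinate on the diagonal; D + D^H restores the full value
            dip = dip.at[i].set(jnp.diag(positions[:, i] / 2))
        # illustrative transition between orbitals 0 and 1 along z
        dip = dip.at[2, 0, 1].set(0.5)
        dip = dip + jnp.transpose(dip, (0, 2, 1)).conj()
        ```
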
        Returns:
           jax.Array: A (3, N, N) Hermitian tensor D + D^H, where D holds half the positions on its diagonal and the transition values off the diagonal.
        """

        N = self.positions.shape[0]
        dipole_operator = jnp.zeros((3, N, N)).astype(complex)
        for i in range(3):
            dipole_operator = dipole_operator.at[i, :, :].set(
                jnp.diag(self.positions[:, i] / 2) #-jnp.average(self.positions[:,i]))
            )
        for orbital_combination, value in self.couplings.dipole_transitions.items():
            i, j = self._list.index(orbital_combination[0]), self._list.index(
                orbital_combination[1]
            )
            k = value.nonzero()[0]
            dipole_operator = dipole_operator.at[k, i, j].set(value[k])

        return dipole_operator + jnp.transpose(dipole_operator, (0, 2, 1)).conj()

    @property
    @recomputes
    def velocity_operator(self):
        """
        Calculates the velocity operator v = i[H, P] as the commutator of the Hamiltonian H with the dipole (position) operator P.

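        A standalone two-site sketch of the same contractions (plain JAX; the hopping `t` and spacing `a`
        are illustrative values). Only the x component of the position operator is nonzero here, and the
        resulting velocity operator is Hermitian.

        ```python
        import jax.numpy as jnp

        t, a = 2.7, 1.42
        hamiltonian = jnp.array([[0.0, -t], [-t, 0.0]])
        # (3, N, N) position operator with only the x component set
        positions = jnp.zeros((3, 2, 2)).at[0].set(jnp.diag(jnp.array([0.0, a])))
        p_times_h = jnp.einsum("kj,Lik->Lij", hamiltonian, positions)   # P_L @ H
        h_times_p = jnp.einsum("ik,Lkj->Lij", hamiltonian, positions)   # H @ P_L
        velocity = -1j * (p_times_h - h_times_p)                        # i [H, P_L]
        ```
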
        Returns:
           jax.Array: A (3, N, N) tensor representing the velocity operator.
        """

        if self.couplings.dipole_transitions is None:
            x_times_h = jnp.einsum("ij,iL->ijL", self._hamiltonian, self.positions)
            h_times = jnp.einsum("ij,jL->ijL", self._hamiltonian, self.positions)
        else:
            positions = self.dipole_operator
            x_times_h = jnp.einsum("kj,Lik->Lij", self._hamiltonian, positions)
            h_times = jnp.einsum("ik,Lkj->Lij", self._hamiltonian, positions)
        return -1j * (x_times_h - h_times)

    @property
    @recomputes
    def oam_operator(self):
        """
        Calculates the orbital angular momentum operator from the dipole operator $P$ and the velocity operator $J$ as $L_i = \\epsilon_{ijk} P_j J_k$.

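        For 1 x 1 'operators' the contraction reduces to the classical cross product r x v, which gives a
        quick check of the Levi-Civita tensor used below (plain JAX sketch; the vectors are arbitrary).

        ```python
        import jax.numpy as jnp

        epsilon = jnp.zeros((3, 3, 3))
        for i, j, k in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]:
            # even permutations +1, odd permutations -1
            epsilon = epsilon.at[i, j, k].set(1.0).at[i, k, j].set(-1.0)
        r = jnp.array([1.0, 2.0, 3.0])
        v = jnp.array([0.5, -1.0, 2.0])
        # 1 x 1 "operators" reduce the contraction to the classical cross product
        angular = jnp.einsum("ijk,jlm,kmn->iln", epsilon, r[:, None, None], v[:, None, None])
        ```
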
        Returns:
           jax.Array: A 3 x N x N tensor representing the orbital angular momentum operator
        """
        epsilon = jnp.array([[[ 0,  0,  0],
                             [ 0,  0,  1],
                             [ 0, -1,  0]],
                            [[ 0,  0, -1],
                             [ 0,  0,  0],
                             [ 1,  0,  0]],
                            [[ 0,  1,  0],
                             [-1,  0,  0],
                             [ 0,  0,  0]]])

        return jnp.einsum('ijk,jlm,kmn->iln', epsilon, self.dipole_operator, self.velocity_operator)

    @property
    @recomputes
    def transition_energies(self):
        """
        Computes the independent-particle transition energies of the tight-binding Hamiltonian.

        Returns:
           jax.Array: `arr[i, j] = E_i - E_j`, the transition energy from state `i` to state `j`.
        """
        return self._energies[:, None] - self._energies

    @property
    @recomputes
    def wigner_weisskopf_transition_rates(self):
        """
        Calculates Wigner-Weisskopf (spontaneous emission) transition rates from the transition energies and the dipole operator in the energy basis.

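        A plain-Python sketch of the unit conversion performed below, for a single transition
        (`wigner_weisskopf_rate_ev` is a hypothetical helper; energies in eV, dipole moments in e * Angstroem).
        The rate scales with the cube of the transition energy. Python floats are used because the
        intermediate (charge / hbar)**3 exceeds the float32 range.

        ```python
        import math

        CHARGE = 1.602e-19      # C
        EPS_0 = 8.85e-12        # F/m
        HBAR = 1.0545718e-34    # J s
        C_LIGHT = 3e8           # m/s
        ANGSTROEM = 1e-10       # m

        def wigner_weisskopf_rate_ev(transition_energy_ev, dipole_e_angstroem):
            # Gamma = omega^3 |d|^2 / (3 pi eps_0 hbar c^3), returned in eV
            omega = transition_energy_ev * CHARGE / HBAR          # rad/s
            dipole = dipole_e_angstroem * CHARGE * ANGSTROEM       # C m
            rate_per_second = omega**3 * dipole**2 / (3 * math.pi * EPS_0 * HBAR * C_LIGHT**3)
            return rate_per_second * HBAR / CHARGE                  # 1/s -> eV

        gamma_1ev = wigner_weisskopf_rate_ev(1.0, 1.0)
        ```
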
        Returns:
           jax.Array: The element `arr[i,j]` contains the transition rate from `i` to `j`.
        """
        charge = 1.602e-19   # C
        eps_0 = 8.85 * 1e-12 # F/m
        hbar = 1.0545718 * 1e-34 # Js
        c = 3e8  # m/s
        angstroem = 1e-10 # m
        # Gamma = omega^3 |d|^2 / (3 pi eps_0 hbar c^3) with omega = E / hbar (E in eV) and d in e * Angstroem
        factor = (charge/hbar)**3 * (charge*angstroem)**2  / (3 * jnp.pi * eps_0 * hbar * c**3)
        te = self.transition_energies
        transition_dipole_moments_squared = jnp.sum(jnp.abs(self.dipole_operator_e)**2, axis = 0)
        factor2 = hbar/charge # convert Gamma from 1/s back to code units (eV)
        return (
            (te * (te > self.eps)) ** 3
            * transition_dipole_moments_squared
            * factor * factor2
        ).real

    @staticmethod
    def _transform_basis(observable, vectors):
        dims_einsum_strings = {2: "ij,jk,lk->il", 3: "ij,mjk,lk->mil"}
        einsum_string = dims_einsum_strings[(observable.ndim)]
        return jnp.einsum(einsum_string, vectors, observable, vectors.conj())

    def transform_to_site_basis(self, observable):
        """
        Transforms an observable from the energy basis to the site basis, O_x = V O V^H, where the columns of V are the eigenvectors of the system.

        Parameters:
           observable (jax.Array): The observable to transform.

        Returns:
           jax.Array: The transformed observable in the site basis.
        """
        return self._transform_basis(observable, self._eigenvectors)

    def transform_to_energy_basis(self, observable):
        """
        Transforms an observable from the site basis to the energy basis, O_e = V^H O V, where the columns of V are the eigenvectors of the system.

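        A round-trip sketch with the same contraction as `_transform_basis` for a 2D observable (plain JAX;
        the three-site chain is illustrative): transforming the Hamiltonian to the energy basis yields the
        diagonal matrix of its eigenvalues, and transforming back recovers it.

        ```python
        import jax.numpy as jnp

        hamiltonian = jnp.array([[0.0, -1.0, 0.0], [-1.0, 0.0, -1.0], [0.0, -1.0, 0.0]])
        energies, vectors = jnp.linalg.eigh(hamiltonian)   # columns are eigenvectors

        def transform(observable, v):
            # Same contraction as _transform_basis for 2D input: v @ O @ v^H
            return jnp.einsum("ij,jk,lk->il", v, observable, v.conj())

        h_energy = transform(hamiltonian, vectors.conj().T)   # site -> energy basis
        h_site = transform(h_energy, vectors)                  # energy -> site basis
        ```
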
        Parameters:
           observable (jax.Array): The observable to transform.

        Returns:
           jax.Array: The transformed observable in the energy basis.
        """

        return self._transform_basis(observable, self._eigenvectors.conj().T)

    @recomputes
    def get_charge(self, density_matrix = None):
        """
        Calculates the charge distribution from a given density matrix or from the initial density matrix if not specified.

        Parameters:
           density_matrix (jax.Array, optional): The density matrix to use for calculating charge. 
                                                 If omitted, the initial density matrix is used.

        Returns:
           jax.Array: The diagonal of the density matrix multiplied by the number of electrons, i.e. the charge at each site.
        """
        density_matrix = self.initial_density_matrix if density_matrix is None else density_matrix
        return jnp.diag(density_matrix * self.electrons)

    @recomputes
    def get_dos(self, omega: float, broadening: float = 0.1):
        """
        Calculates the density of states (DOS) of a nanomaterial stack at a given frequency with broadening.

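        A plain-JAX sketch assuming unit-area Gaussians of standard deviation `broadening` (`gaussian_dos`
        is a hypothetical helper and the levels are illustrative). With that normalization, integrating the
        DOS over energy counts the levels.

        ```python
        import jax
        import jax.numpy as jnp

        def gaussian_dos(energies, omega, broadening=0.1):
            # each level contributes a normalized Gaussian centered at its energy
            prefactor = 1 / (jnp.sqrt(2 * jnp.pi) * broadening)
            return prefactor * jnp.sum(jnp.exp(-((energies - omega) ** 2) / (2 * broadening**2)))

        energies = jnp.array([-1.0, 0.0, 0.5, 2.0])
        omegas = jnp.linspace(-3.0, 4.0, 7001)              # grid spacing 1e-3
        dos = jax.vmap(lambda w: gaussian_dos(energies, w))(omegas)
        states = jnp.sum(dos) * (omegas[1] - omegas[0])     # approximately 4, the number of levels
        ```
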
        Parameters:
           omega (float): The frequency at which to evaluate the DOS.
           broadening (float, optional): Width (standard deviation) of the Gaussians that replace the Dirac deltas.

        Returns:
           float: The density of states at the specified frequency.
        """

        # Normalized Gaussian of standard deviation `broadening` for every energy level
        prefactor = 1 / (jnp.sqrt(2 * jnp.pi) * broadening)
        gaussians = jnp.exp(-((self._energies - omega) ** 2) / (2 * broadening**2))
        return prefactor * jnp.sum(gaussians)

    @recomputes
    def get_ldos(self, omega: float, site_index: int, broadening: float = 0.1):
        """
        Calculates the local density of states (LDOS) at a specific site and frequency within a nanomaterial stack.

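        A plain-JAX sketch checking that the site-resolved weights |V_in|^2 sum to one over the sites, so the
        LDOS summed over all sites reproduces the DOS (assuming normalized Gaussians; the two-site Hamiltonian
        and the helper names `gaussian_ldos` and `gaussian_dos` are illustrative).

        ```python
        import jax.numpy as jnp

        hamiltonian = jnp.array([[0.0, -1.0], [-1.0, 0.5]])
        energies, vectors = jnp.linalg.eigh(hamiltonian)   # columns of `vectors` are eigenstates

        def gaussian_ldos(site_index, omega, broadening=0.1):
            # Gaussians weighted by each eigenstate's probability on the chosen site
            weight = jnp.abs(vectors[site_index, :]) ** 2
            gaussians = jnp.exp(-((energies - omega) ** 2) / (2 * broadening**2))
            return jnp.sum(weight * gaussians) / (jnp.sqrt(2 * jnp.pi) * broadening)

        def gaussian_dos(omega, broadening=0.1):
            gaussians = jnp.exp(-((energies - omega) ** 2) / (2 * broadening**2))
            return jnp.sum(gaussians) / (jnp.sqrt(2 * jnp.pi) * broadening)

        omega = energies[0]
        total_ldos = gaussian_ldos(0, omega) + gaussian_ldos(1, omega)
        ```
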
        Parameters:
           omega (float): The frequency at which to evaluate the LDOS.
           site_index (int): The site index to evaluate the LDOS at.
           broadening (float, optional): Width (standard deviation) of the Gaussians that replace the Dirac deltas.

        Returns:
           float: The local density of states at the specified site and frequency.
        """

        # Normalized Gaussians of standard deviation `broadening`, weighted by each eigenstate's probability on the site
        weight = jnp.abs(self._eigenvectors[site_index, :]) ** 2
        prefactor = 1 / (jnp.sqrt(2 * jnp.pi) * broadening)
        gaussians = jnp.exp(-((self._energies - omega) ** 2) / (2 * broadening**2))
        return prefactor * jnp.sum(weight * gaussians)

    @recomputes
    def get_epi(self, density_matrix_stat: jax.Array, omega: float, epsilon: float = None) -> float:
        """
        Calculates the energy-based plasmonicity index (EPI) for a given density matrix and frequency.

        Parameters:
           density_matrix_stat (jax.Array): The density matrix to consider for EPI calculation.
           omega (float): The frequency to evaluate the EPI at.
           epsilon (float, optional): Small imaginary part that stabilizes the calculation; defaults to `self.eps` if not provided.

        Returns:
           float: The EPI.
        """

        epsilon = epsilon if epsilon is not None else self.eps
        density_matrix_stat_without_diagonal = jnp.abs(density_matrix_stat - jnp.diag(jnp.diag(density_matrix_stat)))
        density_matrix_stat_normalized = density_matrix_stat_without_diagonal / jnp.linalg.norm(density_matrix_stat_without_diagonal)
        te = self.transition_energies
        excitonic_transitions = (
            density_matrix_stat_normalized / (te * (te > self.eps) - omega + 1j * epsilon) ** 2
        )
        return 1 - jnp.sum(jnp.abs(excitonic_transitions * density_matrix_stat_normalized)) / (
            jnp.linalg.norm(density_matrix_stat_normalized) * jnp.linalg.norm(excitonic_transitions)
        )

    @recomputes
    def get_induced_field(self, positions: jax.Array, density_matrix):
        """
        Calculates the induced electric field at specified positions based on a given density matrix.

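        A plain-JAX sketch of the same point-charge sum for a single site at the origin holding one electron
        (`induced_field` is a hypothetical helper that takes occupations already multiplied by the electron
        count; 14.39 eV * Angstroem is e^2 / (4 pi eps_0)). One Angstroem away along x, the field has magnitude
        14.39 and points back toward the site; at the site itself the 0 / 0 term is zeroed by `nan_to_num`.

        ```python
        import jax.numpy as jnp

        COULOMB_CONSTANT = 14.39   # e^2 / (4 pi eps_0) in eV * Angstroem

        def induced_field(source_positions, occupations, eval_positions):
            vec_r = source_positions[:, None] - eval_positions       # (N, M, 3)
            distance_cubed = jnp.linalg.norm(vec_r, axis=2) ** 3     # (N, M)
            # r / r^3 kernel; nan_to_num turns the 0 / 0 self-term into 0 by default
            kernel = jnp.nan_to_num(vec_r / distance_cubed[:, :, None], posinf=0.0, neginf=0.0)
            return COULOMB_CONSTANT * jnp.sum(kernel * occupations[:, None, None], axis=0)

        sources = jnp.array([[0.0, 0.0, 0.0]])
        evaluation_points = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
        field = induced_field(sources, jnp.array([1.0]), evaluation_points)
        ```
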
        Parameters:
           positions (jax.Array): The positions at which to evaluate the induced field, shape (M, 3).
           density_matrix (jax.Array): The density matrix in the site basis, shape (N, N), or its diagonal (the site occupations), shape (N,).

        Returns:
           jax.Array: The induced electric field vector at each position, shape (M, 3).
        """


        # distance vectors from the evaluation points to the field sources, shape (N, M, 3)
        vec_r = self.positions[:, None] - positions

        # cubed scalar distances, shape (N, M)
        denominator = jnp.linalg.norm(vec_r, axis=2) ** 3

        # point-charge kernel \vec{r} / r^3; coincident points (0 / 0 -> nan) are set to zero
        point_charge = jnp.nan_to_num(
            vec_r / denominator[:, :, None], posinf=0.0, neginf=0.0
        )

        # compute charge via occupations in site basis
        occupations = jnp.diagonal(density_matrix) if density_matrix.ndim == 2 else density_matrix
        charge = self.electrons * occupations.real

        # induced field is a sum of point charges, i.e. \vec{r} / r^3, with 14.39 eV Angstroem = e^2 / (4 pi eps_0)
        e_field = 14.39 * jnp.sum(point_charge * charge[:, None, None], axis=0)
        return e_field

    def get_expectation_value(self, *, operator, density_matrix, induced = True):
        """
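        Calculates the expectation value of an operator with respect to a given density matrix, scaled by the number of electrons N_e.
        If `induced` is True, the result is N_e Tr[O (rho_stat - rho)], measured relative to the stationary density matrix;
        otherwise it is -N_e Tr[O rho]. The tensor contraction is chosen according to the dimensionalities of the operator and the density matrix.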

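        A minimal plain-JAX sketch of the `induced=True` contraction for a (3, N, N) operator and a single
        (N, N) density matrix (the matrices and the electron count are illustrative).

        ```python
        import jax.numpy as jnp

        electrons = 2
        # (3, N, N) operator with only the x component set
        operator = jnp.stack([jnp.diag(jnp.array([0.0, 1.0])), jnp.zeros((2, 2)), jnp.zeros((2, 2))])
        rho_stationary = jnp.diag(jnp.array([0.5, 0.5]))
        rho = jnp.diag(jnp.array([0.75, 0.25]))
        # electrons * Tr[O_i (rho_stationary - rho)] for each component i
        value = electrons * jnp.einsum("ijk,kj->i", operator, rho_stationary - rho)
        ```
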
        Parameters:
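           operator (jax.Array): The operator for which the expectation value is calculated, shape (N, N) or (3, N, N).
           density_matrix (jax.Array): The density matrix representing the state of the system, shape (N, N) or batched as (T, N, N).
           induced (bool, optional): Whether to measure relative to the stationary density matrix. Defaults to True.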

        Returns:
           jax.Array: The calculated expectation value(s) depending on the dimensions of the operator and the density matrix.
        """

        dims_einsum_strings = {
            (3, 2): "ijk,kj->i",
            (3, 3): "ijk,lkj->li",
            (2, 3): "ij,kji->k",
            (2, 2): "ij,ji->",
        }
        correction = self.stationary_density_matrix_x if induced else 0
        return self.electrons * jnp.einsum(
            dims_einsum_strings[(operator.ndim, density_matrix.ndim)],
            operator,
            correction - density_matrix,
        )

    def get_args( self, relaxation_rate = 0.0, coulomb_strength = 1.0, propagator = None):
        return TDArgs(
            self.hamiltonian,
            self.energies,
            self.coulomb * coulomb_strength,
            self.initial_density_matrix,
            self.stationary_density_matrix,
            self.eigenvectors,
            self.dipole_operator,
            self.electrons,
            relaxation_rate,
            propagator,
            self.spin_degeneracy,
            self.positions
            )

    @staticmethod
    def get_hamiltonian(illumination = None, use_rwa = False, add_induced = False):
        """Dict holding the terms of the default Hamiltonian: bare Hamiltonian + Coulomb + dipole-gauge coupling to the external field (optionally in the RWA) + (optionally) the induced field."""
        contents = {}
        contents["bare_hamiltonian"] = potentials.BareHamiltonian()
        contents["coulomb"] = potentials.Coulomb()
        if illumination is not None:
            contents["potential"] = potentials.DipoleGauge(illumination, use_rwa)
        if add_induced == True:
            contents["induced"] = potentials.Induced( )
        return contents

    @staticmethod
    def get_dissipator(relaxation_rate = None, saturation = None):
        """Dict holding the term of the default dissipator: no dissipation if both arguments are None; a decoherence-time term if `relaxation_rate` is a float (`saturation` is ignored); otherwise a saturated Lindblad term using the `saturation` function, which defaults to a steep step switching from 1 to 0 at an argument of 2."""
        if relaxation_rate is None and saturation is None:
            return {"no_dissipation" : lambda t, r, args : 0.0}
        if isinstance(relaxation_rate, float):
            return { "decoherence_time" : dissipators.DecoherenceTime() }
        func  = (lambda x: 1 / (1 + jnp.exp(-1e6 * (2.0 - x)))) if saturation is None else saturation
        return {"lindblad" : dissipators.SaturationLindblad(func) }        

    def get_postprocesses( self, expectation_values, density_matrix ):
        postprocesses = {}
        if isinstance(expectation_values, jax.Array):
            expectation_values = [expectation_values]
        if expectation_values is not None:
            ops = jnp.concatenate( expectation_values)
            postprocesses["expectation_values"] = lambda rho, args: self.get_expectation_value(operator=ops,density_matrix=rho)

        if density_matrix is None:
            return postprocesses

        if isinstance(density_matrix, str):
            density_matrix = [density_matrix]
        for option in density_matrix:
            if option == "occ_x":
                postprocesses[option] = lambda rho, args: args.electrons * jnp.diagonal(rho, axis1=-1, axis2=-2) 
            elif option == "occ_e":
                postprocesses[option] = lambda rho, args: args.electrons * jnp.diagonal( args.eigenvectors.conj().T @ rho @ args.eigenvectors, axis1=-1, axis2=-2) 
            elif option == "full":
                postprocesses[option] = lambda rho, args: rho
            elif option == "diag_x":
                postprocesses[option] = lambda rho, args: jnp.diagonal(rho, axis1=-1, axis2=-2)
            elif option == "diag_e":
                postprocesses[option] = lambda rho, args: jnp.diagonal( args.eigenvectors.conj().T @ rho @ args.eigenvectors, axis1=-1, axis2=-2) 


        return postprocesses


    @recomputes
    def master_equation(            
            self,
            *,
            end_time : float,
            start_time : float = 0.0,
            dt : float = 1e-4,
            grid : Union[int, jax.Array] = 100,
            max_mem_gb : float = 0.5,

            initial_density_matrix : Optional[jax.Array] = None,

            coulomb_strength : float = 1.0,

            illumination : Callable = None,

            relaxation_rate : Optional[Union[float, jax.Array]] = None,

            compute_at : Optional[jax.Array] = None,

            expectation_values : Optional[list[jax.Array]] = None,
            density_matrix : Optional[list[str]] = None,

            use_rwa : bool = False,

            solver = diffrax.Dopri5(),
            stepsize_controller = diffrax.PIDController(rtol=1e-10,atol=1e-10),

            hamiltonian : dict = None,
            dissipator : dict = None,
            postprocesses : dict = None,
            rhs_args = None,

    ):
        """
        Simulates the time evolution of the density matrix, computing observables, density matrices or extracting custom information.

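        Custom terms are plain functions. A minimal sketch of one extra Hamiltonian term (signature
        `t, r, args -> jax.Array`) and one postprocess (signature `r, args -> jax.Array`); `Args` is a
        hypothetical stand-in for the `rhs_args` namedtuple holding only the field used here, and the constant
        field strength is illustrative. Assuming `flake` is an existing orbital list, such a term could be passed
        as `hamiltonian={**flake.get_hamiltonian(illumination), "static": static_x_field}`.

        ```python
        from collections import namedtuple
        import jax.numpy as jnp

        # hypothetical stand-in for rhs_args, holding only the field used below
        Args = namedtuple("Args", ["dipole_operator"])

        def static_x_field(t, r, args):
            # time-independent coupling of an illustrative field strength 0.1 to the x dipole
            return 0.1 * args.dipole_operator[0]

        def site_occupations(r, args):
            # diagonal of the (possibly batched) density matrix
            return jnp.diagonal(r, axis1=-1, axis2=-2)

        args = Args(dipole_operator=jnp.zeros((3, 2, 2)).at[0].set(jnp.diag(jnp.array([0.0, 1.0]))))
        rho = jnp.eye(2) / 2
        term = static_x_field(0.0, rho, args)
        occupations = site_occupations(rho, args)
        ```
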
        Args:
            end_time (float): The final time for the simulation.
            start_time (float): The starting time for the simulation. Defaults to 0.0.
            dt (float): The time step size for the simulation. Defaults to 1e-4.
            grid (Union[int, jax.Array]): Determines the output times for the simulation results. If an integer, results
                                          are saved every 'grid'-th time step. If an array, results are saved at the
                                          specified times.
            max_mem_gb (float): Maximum memory in gigabytes allowed for each batch of intermediate density matrices.
            initial_density_matrix (Optional[jax.Array]): The initial state of the density matrix. If not provided,
                                                          `self.initial_density_matrix` is used.
            coulomb_strength (float): Scaling factor for the Coulomb interaction matrix.
            illumination (Callable): Function describing the time-dependent external illumination applied to the system.
            relaxation_rate (Optional[Union[float, jax.Array]]): Specifies the relaxation dynamics. A float indicates a
                                                                 uniform decoherence time, an array provides state-specific
                                                                 rates.
            compute_at (Optional[jax.Array]): The orbitals indexed by this array will experience induced fields.
            expectation_values (Optional[list[jax.Array]]): Expectation values to compute during the simulation.
            density_matrix (Optional[list[str]]): Tags for additional density matrix computations: "full", "occ_x", "occ_e", "diag_x", "diag_e". May be deprecated.
            use_rwa (bool): Whether to use the rotating wave approximation. Defaults to False.
            solver: The numerical solver instance to use for integrating the differential equations.
            stepsize_controller: Controller for adjusting the solver's step size based on error tolerance.
            hamiltonian (dict): Functions representing terms in the Hamiltonian, each with signature `t, r, args -> jax.Array`. The keys are ignored.
            dissipator (dict): Functions representing terms in the dissipator, each with signature `t, r, args -> jax.Array`. The keys are ignored.
            postprocesses (dict): Functions extracting information from the simulation, each with signature `r, args -> jax.Array`. The keys are ignored.
            rhs_args: Arguments (a namedtuple) passed to the Hamiltonian, dissipator, and postprocess functions during the simulation.

        Returns:
            ResultTD
        """


        # arguments to evolution function
        if rhs_args is None:
            rhs_args = self.get_args( relaxation_rate,
                                      coulomb_strength,
                                      _numerics.get_coulomb_field_to_from(self.positions, self.positions, compute_at) )

        if illumination is None:
            illumination = lambda t : jnp.array( [0j, 0j, 0j] )

        # each of these functions is applied to a density matrix batch
        postprocesses = self.get_postprocesses( expectation_values, density_matrix ) if postprocesses is None else postprocesses

        # hermitian rhs
        hamiltonian = self.get_hamiltonian(illumination, use_rwa, compute_at is not None) if hamiltonian is None else hamiltonian

        # non hermitian rhs
        dissipator = self.get_dissipator(relaxation_rate, None) if dissipator is None else dissipator

        # set reasonable default 
        initial_density_matrix = initial_density_matrix if initial_density_matrix is not None else rhs_args.initial_density_matrix

        try:        
            return self._integrate_master_equation( list(hamiltonian.values()), list(dissipator.values()), list(postprocesses.values()), rhs_args, illumination, solver, stepsize_controller, initial_density_matrix, start_time, end_time, grid, max_mem_gb, dt )
        except Exception as e:
            print(f"Simulation crashed with exception {e}. Try increasing the time mesh and make sure your illumination is differentiable. The full diffrax traceback follows below.")
            traceback.print_exc()

    @staticmethod
    def _integrate_master_equation( hamiltonian, dissipator, postprocesses, rhs_args, illumination, solver, stepsize_controller, initial_density_matrix, start_time, end_time, grid, max_mem_gb, dt ):

        # batched time axis to save memory 
        mat_size = initial_density_matrix.size * initial_density_matrix.itemsize / 1e9
        time_axis = _numerics.get_time_axis( mat_size = mat_size, grid = grid, start_time = start_time, end_time = end_time, max_mem_gb = max_mem_gb, dt = dt )

        ## integrate
        final, output = _numerics.td_run(
            initial_density_matrix,
            _numerics.get_integrator(hamiltonian, dissipator, postprocesses, solver, stepsize_controller, dt),
            time_axis,
            rhs_args)

        return TDResult(
            td_illumination = jax.vmap(illumination)(jnp.concatenate(time_axis)) ,
            output = output,
            final_density_matrix = final,
            time_axis = jnp.concatenate( time_axis )
        )


    def get_ip_green_function(self, A, B, omegas, occupations = None, energies = None, mask = None, relaxation_rate = 1e-1):
        """Independent-particle Green's function at the specified frequencies, computed as

        $$
        G_{AB}(\\omega) = \\sum_{nm} \\frac{P_n - P_m}{\\omega + E_n - E_m + i \\gamma} A_{nm} B_{mn}
        $$

        where $P_n$ are the occupations, $E_n$ the energies, and $\\gamma$ is `relaxation_rate`.

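        A plain-JAX sketch checking, for one frequency, that the vectorized trace used below equals the explicit
        double sum over n, m of (P_n - P_m) / (omega + E_n - E_m + i gamma) * A_nm * B_mn (three illustrative
        levels; `A` and `B` stand for operators already in the energy basis).

        ```python
        import jax.numpy as jnp

        energies = jnp.array([-1.0, 0.0, 1.5])
        occupations = jnp.array([2.0, 0.0, 0.0])
        A = jnp.array([[0.0, 1.0, 0.2], [1.0, 0.0, 0.3], [0.2, 0.3, 0.0]])
        B = jnp.array([[0.0, 0.5, 0.0], [0.2, 0.0, 0.1], [0.4, 0.0, 0.0]])
        gamma, omega = 0.1, 1.0

        delta_occ = occupations[:, None] - occupations      # P_n - P_m
        delta_e = energies[:, None] - energies              # E_n - E_m
        # vectorized form: Tr[(dP / (omega + dE + i gamma)) @ (A.T * B)]
        g_trace = jnp.trace((delta_occ / (omega + delta_e + 1j * gamma)) @ (A.T * B))
        # explicit double sum over n, m for comparison
        g_loop = sum(
            (occupations[n] - occupations[m]) / (omega + energies[n] - energies[m] + 1j * gamma) * A[n, m] * B[m, n]
            for n in range(3) for m in range(3)
        )
        ```
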
        Parameters: 
          A, B : operators *in energy basis*, square jax.Array
          omegas (jax.Array) : frequency grid
          occupations (jax.Array) : energy occupations $P_n$; if omitted, the diagonal of the initial density matrix in the energy basis times the number of electrons is used
          energies (jax.Array) : energies $E_n$; if omitted, the current energies are used
          mask (jax.Array): boolean mask excluding transitions (n, m) from the summation
          relaxation_rate (float): broadening parameter $\\gamma$

        Returns:
          jax.Array: Values of the Green's function
        """

        def inner(omega):
            return jnp.trace( (delta_occ / (omega + delta_e + 1j*relaxation_rate)) @ operator_product)

        print("Computing Greens function. Remember we default to site basis")

        operator_product =  A.T * B
        occupations = self.initial_density_matrix_e.diagonal() * self.electrons if occupations is None else occupations
        energies = self.energies if energies is None else energies        
        delta_occ = (occupations[:, None] - occupations)
        if mask is not None:        
            delta_occ = delta_occ.at[mask].set(0) 
        delta_e = energies[:, None] - energies

        return jax.lax.map(jax.jit(inner), omegas)

    def get_polarizability_rpa(
        self,
        omegas,            
        polarization,
        coulomb_strength=1.0,
        relaxation_rate=1/10,
        hungry=0,
        phi_ext=None,
        args = None,
    ):
        """
        Calculates the random phase approximation (RPA) polarizability of the system at given frequencies under specified conditions.

        Parameters:
           omegas (jax.Array): Frequencies at which to calculate polarizability. If given as an n x m array, each of the n rows is evaluated with jax.vmap and the results are concatenated.
           relaxation_rate (float): The relaxation rate.
           polarization (jax.Array): Polarization directions or modes.
           coulomb_strength (float): The strength of the Coulomb interaction in the calculations.
           hungry (int): Speeds up the simulation; higher values (max 2) increase RAM usage.
           phi_ext (Optional[jax.Array]): External potential, if any.
           args (Optional): Numeric representation of an orbital list, as obtained by `get_args`.

        Returns:
           jax.Array: The calculated polarizabilities at the specified frequencies.
        """

        if args is None:
            args = self.get_args(relaxation_rate = relaxation_rate, coulomb_strength = coulomb_strength, propagator = None)
        alpha = _numerics.rpa_polarizability_function(args, polarization, hungry, phi_ext)
        if omegas.ndim == 1:        
            return jax.lax.map(alpha, omegas)
        else:
            return jnp.concatenate( [ jax.vmap(alpha)(omega) for omega in omegas ] )

    def get_susceptibility_rpa(
            self, omegas, relaxation_rate=1/10, coulomb_strength=1.0, hungry=0, args = None,
    ):
        """
        Computes the random phase approximation (RPA) susceptibility of the system over a range of frequencies.

        Parameters:
           omegas (jax.Array): The frequencies at which to compute susceptibility.
           relaxation_rate (float): The relaxation rate.
           coulomb_strength (float): The strength of the Coulomb interaction in the calculations.
           hungry (int): Speeds up the simulation; higher values (max 2) increase RAM usage.
           args (Optional): Numeric representation of an orbital list, as obtained by `get_args`.

        Returns:
           jax.Array: The susceptibility values at the given frequencies.
        """
        if args is None:
            args = self.get_args(relaxation_rate = relaxation_rate, coulomb_strength = coulomb_strength, propagator = None)
        sus = _numerics.rpa_polarizability_function( args, hungry )
        return jax.lax.map(sus, omegas)    

    @property
    def atoms( self ):
        atoms_pos = defaultdict(list)
        for orb in self._list:
            atoms_pos[orb.atom_name] += [[str(x) for x in orb.position]]
        return atoms_pos

    def to_xyz( self, name : str = None ):
        atoms = self.atoms
        number_of_atoms = sum( [len(x) for x in atoms.values()] )
        str_rep = str(number_of_atoms) + "\n\n"

        for atom, positions in atoms.items():
            for pos in positions:
                str_rep += f'{atom} {" ".join(pos)}\n'

        if name is None:
            return str_rep

        with open( name, "w" ) as f:
            f.writelines(str_rep)

    @classmethod
    def from_xyz( cls, name : str ):
        orbs, group_id = [], _watchdog._Watchdog.next_value()        
        with open(name, 'r') as f:
            for line in f:
                processed = line.strip().split()
                if len(processed) <= 1:
                    continue                
                atom_name, x, y, z = processed

                orbs.append( Orbital( group_id = group_id,
                                      atom_name = atom_name,
                                      position = [float(x), float(y), float(z)] )  )
        return cls( orbs )

center_index property

Index of the approximate center orbital of the structure.

dipole_operator property

Computes the dipole operator using positions and transition values. The diagonal is set by position components, and the off-diagonal elements are set by transition matrix values.

Returns:
  • jax.Array: A 3D tensor representing the dipole operator, symmetrized and complex conjugated.

oam_operator property

Calculates the orbital angular momentum operator from the dipole \(P\) and velocity operator \(J\) as \(L_{i} = \epsilon_{ijk} P_j J_k\), summing over the repeated indices \(j\) and \(k\).

Returns:
  • jax.Array: A 3 x N x N tensor representing the orbital angular momentum operator
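The Levi-Civita contraction can be sketched in NumPy as below, assuming P and J are given as 3 x N x N stacks and the operator products are plain matrix products. Whether GRANAD symmetrizes the product (P_j J_k is not Hermitian in general) is not shown here; `levi_civita` and `oam_from_operators` are hypothetical helpers.

```python
import numpy as np


def levi_civita() -> np.ndarray:
    """Rank-3 Levi-Civita tensor eps[i, j, k]."""
    eps = np.zeros((3, 3, 3))
    for i, j, k in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]:
        eps[i, j, k] = 1.0   # cyclic (even) permutations
        eps[i, k, j] = -1.0  # anticyclic (odd) permutations
    return eps


def oam_from_operators(P: np.ndarray, J: np.ndarray) -> np.ndarray:
    """L_i = sum_jk eps_ijk P_j @ J_k for 3 x N x N stacks P and J (hypothetical helper)."""
    # contract the Levi-Civita indices j, k and matrix-multiply P_j @ J_k over the shared index b
    return np.einsum("ijk,jab,kbc->iac", levi_civita(), P, J)


rng = np.random.default_rng(0)
P = rng.standard_normal((3, 4, 4))
J = rng.standard_normal((3, 4, 4))
L = oam_from_operators(P, J)
```

Each component reduces to the familiar cross-product form, e.g. L_z = P_x J_y - P_y J_x.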

quadrupole_operator property

Calculates the quadrupole operator based on the dipole operator terms. It combines products of the dipole terms and their differences from the identity matrix scaled by the diagonal components.

Returns:
  • jax.Array: A tensor representing the quadrupole operator.

transition_energies property

Computes independent-particle transition energies associated with the TB-Hamiltonian of a stack.

Returns:
  • jax.Array: The element arr[i,j] contains the transition energy from i to j.

velocity_operator property

Calculates the velocity operator as the commutator of position with the Hamiltonian using matrix multiplications.

Returns:
  • jax.Array: A tensor representing the velocity operator, computed from the commutator of the position operator and the Hamiltonian.
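A minimal NumPy sketch of the commutator, assuming ħ = 1, the Heisenberg convention v = i[H, x], and a position operator that is diagonal in the site basis (off-diagonal dipole transitions are ignored). `velocity_operator` is a hypothetical helper, and GRANAD's exact prefactor may differ.

```python
import numpy as np


def velocity_operator(hamiltonian: np.ndarray, positions: np.ndarray) -> np.ndarray:
    """v_c = i [H, X_c] for c = x, y, z (hbar = 1), with X_c diagonal in the site basis."""
    n = hamiltonian.shape[0]
    v = np.zeros((3, n, n), dtype=complex)
    for c in range(3):
        x_c = np.diag(positions[:, c])  # position operator component c
        v[c] = 1j * (hamiltonian @ x_c - x_c @ hamiltonian)
    return v


# hypothetical two-site dimer along x with hopping -t
t, a = 2.7, 1.42
H = np.array([[0.0, -t], [-t, 0.0]])
positions = np.array([[0.0, 0.0, 0.0], [a, 0.0, 0.0]])
v = velocity_operator(H, positions)
```

The result is Hermitian, and only its x component is nonzero because the dimer lies along x.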

wigner_weisskopf_transition_rates property

Calculates Wigner-Weisskopf transition rates based on transition energies and dipole moments transformed to the energy basis.

Returns:
  • jax.Array: The element arr[i,j] contains the transition rate from i to j.
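For reference, the textbook Wigner-Weisskopf (spontaneous emission) rate for a transition of angular frequency ω and dipole matrix element d is Γ = ω³|d|² / (3π ε₀ ħ c³). A stdlib sketch in SI units follows; taking the energy in eV and the dipole in e·Å is an assumption for illustration, and GRANAD's internal units and prefactors are not shown in this section.

```python
import math

# SI constants (CODATA 2018)
HBAR = 1.054571817e-34      # J s
EPS0 = 8.8541878128e-12     # F / m
C = 299_792_458.0           # m / s
E_CHARGE = 1.602176634e-19  # C
ANGSTROM = 1e-10            # m


def wigner_weisskopf_rate(energy_ev: float, dipole_e_angstrom: float) -> float:
    """Spontaneous emission rate in 1/s: omega^3 |d|^2 / (3 pi eps0 hbar c^3)."""
    omega = energy_ev * E_CHARGE / HBAR            # angular frequency in rad/s
    d = dipole_e_angstrom * E_CHARGE * ANGSTROM    # dipole matrix element in C m
    return omega**3 * d**2 / (3 * math.pi * EPS0 * HBAR * C**3)


# hydrogen 2p -> 1s: |<1s|z|2p0>| = (128 sqrt(2) / 243) a0, with a0 = 0.529177 Angstrom
rate = wigner_weisskopf_rate(10.2, 128 * math.sqrt(2) / 243 * 0.529177)
```

The hydrogen example gives about 6.3e8 s⁻¹, in line with the textbook Lyman-alpha rate.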

append(other)

Appends an orbital to the list, ensuring it is not already present, and increases the electron count (params.electrons) by one.

Parameters:
  • other (Orbital) –

    The orbital to append.

Raises:
  • TypeError

    If other is not an instance of Orbital.

  • ValueError

    If other is already in the list.

Source code in src/granad/orbitals.py
@mutates
def append(self, other):
    """
    Appends an orbital to the list, ensuring it is not already present.

    Parameters:
        other (Orbital): The orbital to append.

    Raises:
        TypeError: If `other` is not an instance of Orbital.
        ValueError: If `other` is already in the list.
    """
    if not isinstance(other, Orbital):
        raise TypeError
    if other in self:
        raise ValueError
    self._list.append(other)
    self.params.electrons += 1

filter_orbs(orb_id, t)

Maps a given orb_id (an index, tag, group id, or Orbital; an OrbitalList is mapped element-wise) to a list of the required type t.

Source code in src/granad/orbitals.py
def filter_orbs( self, orb_id, t ):
    """Maps a given orb_id (an index, tag, group id or Orbital; an OrbitalList is mapped element-wise) to a list of the required type t"""
    def filter_single_orb(orb_id, t):
        if type(orb_id) == t:
            return [orb_id]

        # index to group, orb, tag => group_id / orb / tag at index,
        if isinstance(orb_id, int) and t == _watchdog.GroupId:
            return [self._list[orb_id].group_id]
        if isinstance(orb_id, int) and t == Orbital:
            return [self._list[orb_id]]
        if isinstance(orb_id, int) and t == str:
            return [self._list[orb_id].tag]

        # group to tag, orb, index => tags / orbs / indices of all orbitals in the group
        if isinstance(orb_id, _watchdog.GroupId) and t == str:
            return [ orb.tag for orb in self if orb.group_id == orb_id ]
        if isinstance(orb_id, _watchdog.GroupId) and t == Orbital:
            return [ orb for orb in self if orb.group_id == orb_id ]
        if isinstance(orb_id, _watchdog.GroupId) and t == int:
            return [ i for i, orb in enumerate(self) if orb.group_id == orb_id ]

        # tag to group, orb, index => group ids / orbs / indices of all orbitals with the tag
        if isinstance(orb_id, str) and t == _watchdog.GroupId:
            return [orb.group_id for orb in self if orb.tag == orb_id]
        if isinstance(orb_id, str) and t == int:
            return [i for i, orb in enumerate(self) if orb.tag == orb_id]
        if isinstance(orb_id, str) and t == Orbital:
            return [orb for orb in self if orb.tag == orb_id]

        # orb to index, group, tag
        if isinstance(orb_id, Orbital) and t == _watchdog.GroupId:
            return [orb_id.group_id]
        if isinstance(orb_id, Orbital) and t == int:
            return [self._list.index(orb_id)]
        if isinstance(orb_id, Orbital) and t == str:
            return [orb_id.tag]

    if not isinstance(orb_id, OrbitalList):
        orb_id = [orb_id]

    return [ x for orb in orb_id for x in filter_single_orb(orb, t) ]

get_charge(density_matrix=None)

Calculates the charge distribution from a given density matrix or from the initial density matrix if not specified.

Parameters:
  • density_matrix (Array, default: None ) –

    The density matrix to use for calculating charge. If omitted, the initial density matrix is used.

Returns:
  • jax.Array: A 1D array with the diagonal of the density matrix times the electron number, i.e. the charge at each site.

Source code in src/granad/orbitals.py
@recomputes
def get_charge(self, density_matrix = None):
    """
    Calculates the charge distribution from a given density matrix or from the initial density matrix if not specified.

    Parameters:
       density_matrix (jax.Array, optional): The density matrix to use for calculating charge. 
                                             If omitted, the initial density matrix is used.

    Returns:
       jax.Array: The diagonal of the density matrix times the electron number, i.e. the charge at each site.
    """
    density_matrix = self.initial_density_matrix if density_matrix is None else density_matrix
    return jnp.diag(density_matrix * self.electrons)

get_dissipator(relaxation_rate=None, saturation=None) staticmethod

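Returns a dict holding the default dissipator term: a no-op term if neither argument is given; a decoherence-time term if relaxation_rate is a float (saturation is ignored); otherwise a Lindblad term with a saturation function, which defaults to a steep logistic step that drops from 1 to 0 at x = 2.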

Source code in src/granad/orbitals.py
@staticmethod
def get_dissipator(relaxation_rate = None, saturation = None):
    """Dict holding the default dissipator term: no dissipation if both arguments are None, a decoherence time if relaxation_rate is a float (saturation is ignored), otherwise a Lindblad term using the given saturation function (default: steep step at 2)"""
    if relaxation_rate is None and saturation is None:
        return {"no_dissipation" : lambda t, r, args : 0.0}
    if isinstance(relaxation_rate, float):
        return { "decoherence_time" : dissipators.DecoherenceTime() }
    func  = (lambda x: 1 / (1 + jnp.exp(-1e6 * (2.0 - x)))) if saturation is None else saturation
    return {"lindblad" : dissipators.SaturationLindblad(func) }        
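The default saturation function in the last branch is a very steep logistic step centred at 2. A stdlib sketch of the same function follows; it uses the identity 1 / (1 + exp(-z)) = 0.5 * (1 + tanh(z / 2)), which avoids overflow in plain Python. `default_saturation` is a hypothetical name, and how SaturationLindblad applies the function is not shown in this section.

```python
import math


def default_saturation(x: float, threshold: float = 2.0, steepness: float = 1e6) -> float:
    """Logistic step 1 / (1 + exp(-steepness * (threshold - x))), written as 0.5 * (1 + tanh(z / 2))."""
    # the tanh form is mathematically identical but cannot overflow for large |z|
    z = steepness * (threshold - x)
    return 0.5 * (1.0 + math.tanh(z / 2.0))
```

Below 2 the function is numerically 1, above 2 it is 0, and at exactly 2 it is 0.5.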

get_dos(omega, broadening=0.1)

Calculates the density of states (DOS) of a nanomaterial stack at a given frequency with broadening.

Parameters:
  • omega (float) –

    The frequency at which to evaluate the DOS.

  • broadening (float, default: 0.1 ) –

    The numerical broadening parameter to replace Dirac Deltas.

Returns:
  • float

    The Gaussian-broadened density of states at the specified frequency.

Source code in src/granad/orbitals.py
@recomputes
def get_dos(self, omega: float, broadening: float = 0.1):
    """
    Calculates the density of states (DOS) of a nanomaterial stack at a given frequency with broadening.

    Parameters:
       omega (float): The frequency at which to evaluate the DOS.
       broadening (float, optional): The numerical broadening parameter to replace Dirac Deltas.

    Returns:
       float: The Gaussian-broadened density of states at the specified frequency.
    """

    # normalized Gaussians of width `broadening`, so the DOS integrates to the number of states
    prefactor = 1 / (jnp.sqrt(2 * jnp.pi) * broadening)
    gaussians = jnp.exp(-((self._energies - omega) ** 2) / (2 * broadening**2))
    return prefactor * jnp.sum(gaussians)
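A stdlib sketch of a Gaussian-broadened DOS, with each Gaussian normalized to unit area so that the DOS integrates to the number of states. `gaussian_dos` is a hypothetical helper, `sigma` plays the role of `broadening`, and the spectrum is made up for illustration.

```python
import math


def gaussian_dos(energies, omega, sigma=0.1):
    """sum_n exp(-(E_n - omega)^2 / (2 sigma^2)) / (sqrt(2 pi) sigma)."""
    prefactor = 1.0 / (math.sqrt(2.0 * math.pi) * sigma)
    return prefactor * sum(math.exp(-((e - omega) ** 2) / (2.0 * sigma**2)) for e in energies)


energies = [-1.0, 0.0, 0.0, 1.5]  # hypothetical spectrum with a doubly degenerate level at 0
step = 1e-3
# Riemann sum over [-3, 4), a window containing all levels
total = sum(gaussian_dos(energies, -3.0 + i * step) * step for i in range(7000))
```

Integrating over a window that contains all levels recovers the number of states (4 here), a quick check of the normalization.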

get_epi(density_matrix_stat, omega, epsilon=None)

Calculates the energy-based plasmonicity index (EPI) for a given density matrix and frequency.

Parameters:
  • density_matrix_stat (Array) –

    The density matrix to consider for EPI calculation.

  • omega (float) –

    The frequency to evaluate the EPI at.

  • epsilon (float, default: None ) –

    The small imaginary part to stabilize the calculation, defaults to internal epsilon if not provided.

Returns:
  • float –

    The EPI.

Source code in src/granad/orbitals.py
@recomputes
def get_epi(self, density_matrix_stat: jax.Array, omega: float, epsilon: float = None) -> float:
    """
    Calculates the energy-based plasmonicity index (EPI) for a given density matrix and frequency.

    Parameters:
       density_matrix_stat (jax.Array): The density matrix to consider for EPI calculation.
       omega (float): The frequency to evaluate the EPI at.
       epsilon (float, optional): The small imaginary part to stabilize the calculation, defaults to internal epsilon if not provided.

    Returns:
       float: The EPI.
    """

    epsilon = epsilon if epsilon is not None else self.eps
    density_matrix_stat_without_diagonal = jnp.abs(density_matrix_stat - jnp.diag(jnp.diag(density_matrix_stat)))
    density_matrix_stat_normalized = density_matrix_stat_without_diagonal / jnp.linalg.norm(density_matrix_stat_without_diagonal)
    te = self.transition_energies
    excitonic_transitions = (
        density_matrix_stat_normalized / (te * (te > self.eps) - omega + 1j * epsilon) ** 2
    )
    return 1 - jnp.sum(jnp.abs(excitonic_transitions * density_matrix_stat_normalized)) / (
        jnp.linalg.norm(density_matrix_stat_normalized) * jnp.linalg.norm(excitonic_transitions)
    )
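Read directly off the source above (a transcription of the code, not an independent definition of the EPI), the returned quantity is the following. Here ρ is density_matrix_stat, ΔE denotes transition_energies, ε the epsilon argument, η the internal threshold self.eps, absolute values are taken elementwise, Θ(x) is 1 for x > 0 and 0 otherwise, and ‖·‖_F is the Frobenius norm (the default of jnp.linalg.norm for matrices):

\[ \tilde\rho = \frac{|\rho - \operatorname{diag}\rho|}{\lVert \rho - \operatorname{diag}\rho \rVert_F}, \qquad X_{ij} = \frac{\tilde\rho_{ij}}{\left(\Delta E_{ij}\,\Theta(\Delta E_{ij} - \eta) - \omega + i\epsilon\right)^2} \]

\[ \mathrm{EPI}(\omega) = 1 - \frac{\sum_{ij} |X_{ij}\,\tilde\rho_{ij}|}{\lVert \tilde\rho \rVert_F \, \lVert X \rVert_F} \]

Since \(\tilde\rho\) is normalized, \(\lVert \tilde\rho \rVert_F = 1\), and by the Cauchy-Schwarz inequality the fraction lies in [0, 1], so the EPI does as well.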

get_expectation_value(*, operator, density_matrix, induced=True)

Calculates the expectation value of an operator with respect to a given density matrix using tensor contractions specified for different dimensionalities of the input arrays.

Parameters:
  • operator (Array) –

    The operator for which the expectation value is calculated.

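  • density_matrix (Array) –

    The density matrix representing the state of the system.

  • induced (bool, default: True ) –

    If True, the result is electrons * Tr[operator (stationary_density_matrix_x - density_matrix)], i.e. it is measured relative to the stationary state; if False, the stationary term is dropped and the result is -electrons * Tr[operator density_matrix].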

Returns:
  • jax.Array: The calculated expectation value(s) depending on the dimensions of the operator and the density matrix.

Source code in src/granad/orbitals.py
def get_expectation_value(self, *, operator, density_matrix, induced = True):
    """
    Calculates the expectation value of an operator with respect to a given density matrix using tensor contractions specified for different dimensionalities of the input arrays.

    Parameters:
       operator (jax.Array): The operator for which the expectation value is calculated.
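       density_matrix (jax.Array): The density matrix representing the state of the system.
       induced (bool, optional): If True (default), contract with `self.stationary_density_matrix_x - density_matrix`; otherwise with `-density_matrix`.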

    Returns:
       jax.Array: The calculated expectation value(s) depending on the dimensions of the operator and the density matrix.
    """

    dims_einsum_strings = {
        (3, 2): "ijk,kj->i",
        (3, 3): "ijk,lkj->li",
        (2, 3): "ij,kji->k",
        (2, 2): "ij,ji->",
    }
    correction = self.stationary_density_matrix_x if induced == True else 0
    return self.electrons * jnp.einsum(
        dims_einsum_strings[(operator.ndim, density_matrix.ndim)],
        operator,
        correction - density_matrix,
    )
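The einsum patterns above all compute traces Tr[O ρ], broadcast over a leading Cartesian index of the operator and/or a leading time index of the density matrix. A NumPy sketch of just this dispatch, leaving out the electrons factor and the induced correction; `trace_expectation` is a hypothetical helper.

```python
import numpy as np

# (operator.ndim, density_matrix.ndim) -> einsum pattern, as in get_expectation_value
PATTERNS = {
    (3, 2): "ijk,kj->i",    # vector operator, one density matrix -> 3 components
    (3, 3): "ijk,lkj->li",  # vector operator, time series of T matrices -> T x 3
    (2, 3): "ij,kji->k",    # scalar operator, time series -> T values
    (2, 2): "ij,ji->",      # scalar operator, one density matrix -> scalar
}


def trace_expectation(operator: np.ndarray, rho: np.ndarray) -> np.ndarray:
    """Tr[O rho], broadcast over a leading Cartesian index of O and/or a leading time index of rho."""
    return np.einsum(PATTERNS[(operator.ndim, rho.ndim)], operator, rho)


rng = np.random.default_rng(1)
O = rng.standard_normal((3, 4, 4))
rho_t = rng.standard_normal((5, 4, 4))
values = trace_expectation(O, rho_t)
```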

get_hamiltonian(illumination=None, use_rwa=False, add_induced=False) staticmethod

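Returns a dict holding the terms of the default Hamiltonian: the bare Hamiltonian, the Coulomb term, the dipole-gauge coupling to the external field if illumination is given (optionally in the rotating wave approximation, RWA), and the induced field if add_induced is True.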

Source code in src/granad/orbitals.py
@staticmethod
def get_hamiltonian(illumination = None, use_rwa = False, add_induced = False):
    """Dict holding terms of the default hamiltonian: bare + coulomb + dipole gauge coupling to external field (if illumination is given; optionally in RWA) + induced field (if add_induced is True)"""
    contents = {}
    contents["bare_hamiltonian"] = potentials.BareHamiltonian()
    contents["coulomb"] = potentials.Coulomb()
    if illumination is not None:
        contents["potential"] = potentials.DipoleGauge(illumination, use_rwa)
    if add_induced == True:
        contents["induced"] = potentials.Induced( )
    return contents

get_induced_field(positions, density_matrix)

Calculates the induced electric field at specified positions based on a given density matrix.

Parameters:
  • positions (Array) –

    The positions at which to evaluate the induced field.

  • density_matrix (Array) –

    The density matrix used to calculate the induced field.

Returns:
  • jax.Array: The resulting electric field vector at each position.

Source code in src/granad/orbitals.py
@recomputes
def get_induced_field(self, positions: jax.Array, density_matrix):
    """
    Calculates the induced electric field at specified positions based on a given density matrix.

    Parameters:
       positions (jax.Array): The positions at which to evaluate the induced field.
       density_matrix (jax.Array): The density matrix used to calculate the induced field.

    Returns:
       jax.Array: The resulting electric field vector at each position.
    """


    # difference vectors r_source - r_target, shape (sources, targets, 3)
    vec_r = self.positions[:, None] - positions

    # cubed distances |r|^3
    denominator = jnp.linalg.norm(vec_r, axis=2) ** 3

    # point-charge kernel r / |r|^3; the singular r = 0 terms are set to zero
    point_charge = jnp.nan_to_num(
        vec_r / denominator[:, :, None], posinf=0.0, neginf=0.0
    )

    # compute charge via occupations in site basis
    charge = self.electrons * density_matrix.real

    # induced field is a sum of point-charge fields r / r^3; 14.39 eV*Angstrom = e^2 / (4 pi eps0)
    e_field = 14.39 * jnp.sum(point_charge * charge[:, None, None], axis=0)
    return e_field        
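The same point-charge sum as a NumPy sketch, assuming the field sources carry per-site electron occupations (a 1D vector) and using 14.39 eV·Å for e²/(4πε₀), which gives the field in V/Å. `induced_field` and the single-electron setup are hypothetical.

```python
import numpy as np

COULOMB_EV_ANGSTROM = 14.39  # e^2 / (4 pi eps0) in eV * Angstrom


def induced_field(sources: np.ndarray, occupations: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Field in V/Angstrom at `targets` (M x 3) from electron occupations on `sources` (N x 3)."""
    vec_r = sources[:, None, :] - targets[None, :, :]   # N x M x 3, source minus target
    dist3 = np.linalg.norm(vec_r, axis=-1) ** 3          # N x M
    with np.errstate(divide="ignore", invalid="ignore"):
        # r / |r|^3, with the singular source == target terms set to zero
        kernel = np.nan_to_num(vec_r / dist3[:, :, None], nan=0.0, posinf=0.0, neginf=0.0)
    # source-minus-target direction: a positive occupation yields the field of a negative charge
    return COULOMB_EV_ANGSTROM * np.sum(kernel * occupations[:, None, None], axis=0)


sources = np.zeros((1, 3))      # one site at the origin
occupations = np.array([1.0])   # hypothetical: one electron on it
field = induced_field(sources, occupations, np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]))
```

For one electron at the origin, the field at (2, 0, 0) Å is (-14.39/4, 0, 0) V/Å, pointing back toward the charge, while the self-term at the origin is dropped.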

get_ip_green_function(A, B, omegas, occupations=None, energies=None, mask=None, relaxation_rate=0.1)

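Independent-particle Green's function at the specified frequencies, according to

\[ G_{AB}(\omega) = \sum_{nm} \frac{P_m - P_n}{\omega + E_m - E_n + i e} A_{nm} B_{mn} \]

where \(P\) denotes the occupations, \(E\) the energies, and \(e\) the broadening relaxation_rate.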
Parameters:
  • A, B (jax.Array) –

    Operators in the energy basis, square jax.Array.

  • omegas (jax.Array) –

    Frequency grid.

  • occupations (jax.Array, default: None ) –

    Energy occupations. If omitted, the diagonal of the initial density matrix in the energy basis, multiplied by the number of electrons, is used.

  • energies (jax.Array, default: None ) –

    Energies. If omitted, the current energies are used.

  • mask (Array, default: None ) –

    Boolean mask excluding energy states from the summation.

  • relaxation_rate (float, default: 0.1 ) –

    Broadening parameter.

Returns:
  • jax.Array: Values of the Green's function

Source code in src/granad/orbitals.py
def get_ip_green_function(self, A, B, omegas, occupations = None, energies = None, mask = None, relaxation_rate = 1e-1):
    """Independent-particle Green's function at the specified frequencies, according to

    $$
    G_{AB}(\omega) = \sum_{nm} \\frac{P_m - P_n}{\omega + E_m - E_n + i e} A_{nm} B_{mn}
    $$

    Parameters:
      A, B : operators *in energy basis*, square jax.Array
      omegas (jax.Array) : frequency grid
      occupations (jax.Array) : energy occupations; if omitted, the diagonal of the initial energy-basis density matrix times the electron number is used
      energies (jax.Array) : energies; if omitted, the current energies are used
      mask (jax.Array): boolean mask excluding energy states from the summation
      relaxation_rate (float): broadening parameter e

    Returns:
      jax.Array: Values of the Green's function
    """

    def inner(omega):
        return jnp.trace( (delta_occ / (omega + delta_e + 1j*relaxation_rate)) @ operator_product)

    print("Computing Greens function. Remember we default to site basis")

    operator_product =  A.T * B
    occupations = self.initial_density_matrix_e.diagonal() * self.electrons if occupations is None else occupations
    energies = self.energies if energies is None else energies        
    delta_occ = (occupations[:, None] - occupations)
    if mask is not None:        
        delta_occ = delta_occ.at[mask].set(0) 
    delta_e = energies[:, None] - energies

    return jax.lax.map(jax.jit(inner), omegas)
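A NumPy transcription of the loop above, useful for checking small cases by hand. `ip_green_function` here is a hypothetical standalone helper that evaluates the sum in the index order the code uses, (P_n - P_m) / (ω + E_n - E_m + iγ) A_nm B_mn with γ = relaxation_rate. Relabelling n and m shows that this coincides with the docstring formula when A and B are real symmetric.

```python
import numpy as np


def ip_green_function(A, B, omegas, occupations, energies, relaxation_rate=0.1):
    """sum_nm (P_n - P_m) / (omega + E_n - E_m + i gamma) * A_nm * B_mn for each omega."""
    delta_occ = occupations[:, None] - occupations   # [n, m] = P_n - P_m
    delta_e = energies[:, None] - energies           # [n, m] = E_n - E_m
    operator_product = A.T * B                       # [m, n] = A_nm * B_mn
    return np.array([
        np.trace((delta_occ / (w + delta_e + 1j * relaxation_rate)) @ operator_product)
        for w in omegas
    ])


sigma_x = np.array([[0.0, 1.0], [1.0, 0.0]])
G = ip_green_function(sigma_x, sigma_x, np.array([0.5]), np.array([1.0, 0.0]), np.array([0.0, 1.0]))
```

For this two-level system (E = (0, 1), P = (1, 0), A = B = σ_x) the sum reduces to 1/(ω - 1 + iγ) - 1/(ω + 1 + iγ).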

get_ldos(omega, site_index, broadening=0.1)

Calculates the local density of states (LDOS) at a specific site and frequency within a nanomaterial stack.

Parameters:
  • omega (float) –

    The frequency at which to evaluate the LDOS.

  • site_index (int) –

    The site index to evaluate the LDOS at.

  • broadening (float, default: 0.1 ) –

    The numerical broadening parameter to replace Dirac Deltas.

Returns:
  • float

    The local density of states at the specified site and frequency.

Source code in src/granad/orbitals.py
@recomputes
def get_ldos(self, omega: float, site_index: int, broadening: float = 0.1):
    """
    Calculates the local density of states (LDOS) at a specific site and frequency within a nanomaterial stack.

    Parameters:
       omega (float): The frequency at which to evaluate the LDOS.
       site_index (int): The site index to evaluate the LDOS at.
       broadening (float, optional): The numerical broadening parameter to replace Dirac Deltas.

    Returns:
       float: The local density of states at the specified site and frequency.
    """

    # normalized Gaussians of width `broadening`, weighted by the site amplitude |<site|n>|^2
    weight = jnp.abs(self._eigenvectors[site_index, :]) ** 2
    prefactor = 1 / (jnp.sqrt(2 * jnp.pi) * broadening)
    gaussians = jnp.exp(-((self._energies - omega) ** 2) / (2 * broadening**2))
    return prefactor * jnp.sum(weight * gaussians)
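Because the eigenvectors are orthonormal, the LDOS summed over all sites reproduces the DOS. A NumPy sketch checking this sum rule on a hypothetical 4-site chain; `gaussian_ldos` is a hypothetical helper using normalized Gaussians of width `sigma`.

```python
import numpy as np


def gaussian_ldos(eigenvectors, energies, omega, site, sigma=0.1):
    """sum_n |psi_n(site)|^2 * exp(-(E_n - omega)^2 / (2 sigma^2)) / (sqrt(2 pi) sigma)."""
    weight = np.abs(eigenvectors[site, :]) ** 2
    gauss = np.exp(-((energies - omega) ** 2) / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma)
    return np.sum(weight * gauss)


# hypothetical 4-site open chain with nearest-neighbor hopping -1
H = -(np.eye(4, k=1) + np.eye(4, k=-1))
energies, vecs = np.linalg.eigh(H)
omega, sigma = 0.3, 0.1
ldos_total = sum(gaussian_ldos(vecs, energies, omega, site, sigma) for site in range(4))
dos = np.sum(np.exp(-((energies - omega) ** 2) / (2 * sigma**2))) / (np.sqrt(2 * np.pi) * sigma)
```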

get_polarizability_rpa(omegas, polarization, coulomb_strength=1.0, relaxation_rate=1 / 10, hungry=0, phi_ext=None, args=None)

Calculates the random phase approximation (RPA) polarizability of the system at given frequencies under specified conditions.

Parameters:
  • omegas (Array) –

    Frequencies at which to calculate polarizability. If given as an n x m array, each of the n rows is evaluated with jax.vmap and the results are concatenated.

  • relaxation_rate (float, default: 1 / 10 ) –

    The relaxation rate.

  • polarization (Array) –

    Polarization directions or modes.

  • coulomb_strength (float, default: 1.0 ) –

    The strength of the Coulomb interaction in the calculations.

  • hungry (int, default: 0 ) –

    Speeds up the simulation; higher values (max 2) increase RAM usage.

  • phi_ext (Optional[Array], default: None ) –

    External potential, if any.

  • args (Optional, default: None ) –

    Numeric representation of an orbital list, as obtained by get_args.

Returns:
  • jax.Array: The calculated polarizabilities at the specified frequencies.

Source code in src/granad/orbitals.py
def get_polarizability_rpa(
    self,
    omegas,            
    polarization,
    coulomb_strength=1.0,
    relaxation_rate=1/10,
    hungry=0,
    phi_ext=None,
    args = None,
):
    """
    Calculates the random phase approximation (RPA) polarizability of the system at given frequencies under specified conditions.

    Parameters:
       omegas (jax.Array): Frequencies at which to calculate polarizability. If given as an n x m array, each of the n rows is evaluated with jax.vmap and the results are concatenated.
       relaxation_rate (float): The relaxation rate.
       polarization (jax.Array): Polarization directions or modes.
       coulomb_strength (float): The strength of the Coulomb interaction in the calculations.
       hungry (int): Speeds up the simulation; higher values (max 2) increase RAM usage.
       phi_ext (Optional[jax.Array]): External potential, if any.
       args (Optional): Numeric representation of an orbital list, as obtained by `get_args`.

    Returns:
       jax.Array: The calculated polarizabilities at the specified frequencies.
    """

    if args is None:
        args = self.get_args(relaxation_rate = relaxation_rate, coulomb_strength = coulomb_strength, propagator = None)
    alpha = _numerics.rpa_polarizability_function(args, polarization, hungry, phi_ext)
    if omegas.ndim == 1:        
        return jax.lax.map(alpha, omegas)
    else:
        return jnp.concatenate( [ jax.vmap(alpha)(omega) for omega in omegas ] )

get_susceptibility_rpa(omegas, relaxation_rate=1 / 10, coulomb_strength=1.0, hungry=0, args=None)

Computes the random phase approximation (RPA) susceptibility of the system over a range of frequencies.

Parameters:
  • omegas (Array) –

    The frequencies at which to compute susceptibility.

  • relaxation_rate (float, default: 1 / 10 ) –

    The relaxation rate.

  • coulomb_strength (float, default: 1.0 ) –

    The strength of the Coulomb interaction in the calculations.

  • hungry (int, default: 0 ) –

    Speeds up the simulation; higher values (max 2) increase RAM usage.

  • args (Optional, default: None ) –

    Numeric representation of an orbital list, as obtained by get_args.

Returns:
  • jax.Array: The susceptibility values at the given frequencies.

Source code in src/granad/orbitals.py
def get_susceptibility_rpa(
        self, omegas, relaxation_rate=1/10, coulomb_strength=1.0, hungry=0, args = None,
):
    """
    Computes the random phase approximation (RPA) susceptibility of the system over a range of frequencies.

    Parameters:
       omegas (jax.Array): The frequencies at which to compute susceptibility.
       relaxation_rate (float): The relaxation rate.
       coulomb_strength (float): The strength of the Coulomb interaction in the calculations.
       hungry (int): Speeds up the simulation; higher values (max 2) increase RAM usage.
       args (Optional): Numeric representation of an orbital list, as obtained by `get_args`.

    Returns:
       jax.Array: The susceptibility values at the given frequencies.
    """
    if args is None:
        args = self.get_args(relaxation_rate = relaxation_rate, coulomb_strength = coulomb_strength, propagator = None)
    sus = _numerics.rpa_polarizability_function( args, hungry )
    return jax.lax.map(sus, omegas)    
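For orientation, the textbook RPA relates the interacting susceptibility χ to the independent-particle one χ₀ through the Dyson equation χ = χ₀ + χ₀ V χ, with V the Coulomb matrix (here scaled by coulomb_strength). A NumPy sketch of that single step at one frequency, taking χ₀ as given. How GRANAD's `_numerics` builds χ₀ and handles `hungry` is not shown in this section, and `rpa_dyson` and the numbers below are hypothetical.

```python
import numpy as np


def rpa_dyson(chi0: np.ndarray, coulomb: np.ndarray, coulomb_strength: float = 1.0) -> np.ndarray:
    """Solve chi = chi0 + chi0 V chi, i.e. (1 - chi0 V) chi = chi0, with V scaled by coulomb_strength."""
    n = chi0.shape[0]
    v = coulomb_strength * coulomb
    # a linear solve is cheaper and more stable than forming (1 - chi0 V)^{-1} explicitly
    return np.linalg.solve(np.eye(n) - chi0 @ v, chi0)


# hypothetical 2-site numbers at a single frequency
chi0 = np.array([[-0.3 + 0.05j, 0.1], [0.1, -0.3 + 0.05j]])
V = np.array([[0.0, 5.0], [5.0, 0.0]])
chi = rpa_dyson(chi0, V)
```

Setting coulomb_strength to 0 switches the interaction off and returns χ₀ unchanged.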

localization(neighbor_number=6)

Compute edge localization of eigenstates according to

\[ \frac{\sum_{j \, edge} |\phi_{j}|^2}{\sum_i |\phi_i|^2 } \]

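Edges are identified by counting next-nearest neighbors (nnn), i.e. sites at the second-smallest nonzero distance: a site with fewer than neighbor_number such neighbors is treated as an edge site.

Parameters:
  • neighbor_number (int, default: 6 ) –

    Number of nnn of a bulk site, used to identify edges. Depends on lattice and orbital number: the default of 6 fits a hexagonal lattice with a single orbital per site; for more orbitals, use nnn * num_orbitals.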

Returns:
  • jax.Array: localization, where i-th entry corresponds to i-th energy eigenstate

Source code in src/granad/orbitals.py
def localization(self, neighbor_number : int = 6):
    """Compute edge localization of eigenstates according to

    $$
    \\frac{\sum_{j \, edge} |\phi_{j}|^2}{\sum_i |\phi_i|^2 }
    $$

    Edges are identified based on the number of next-nearest neighbors (nnn).

    Args:
        neighbor_number (int): nnn used to identify edges. Depends on lattice and orbital number.
        Defaults to nnn = 6 for the case of a hexagonal lattice with a single orbital per site.
        For more orbitals use nnn * num_orbitals.

    Returns:
        jax.Array: localization, where i-th entry corresponds to i-th energy eigenstate
    """

    # edges => neighboring unit cells are incomplete => all points that are not inside a "big hexagon" made up of nearest neighbors
    positions, states, energies = self.positions, self.eigenvectors, self.energies 

    distances = jnp.round(jnp.linalg.norm(positions - positions[:, None], axis = -1), 4)
    nnn = jnp.unique(distances)[2]
    mask = (distances == nnn).sum(axis=0) < neighbor_number

    # localization => weight of each eigenstate on the edge sites
    l = (jnp.abs(states[mask, :])**2).sum(axis = 0) # eigenvectors are normalized, so no denominator is needed

    return l
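The steps of the listing, transcribed to NumPy and applied to a hypothetical 6-site chain with unit spacing. There, bulk sites have two next-nearest neighbors, so neighbor_number = 2 marks the two sites at each end as edge sites. `edge_localization` is a hypothetical helper.

```python
import numpy as np


def edge_localization(positions: np.ndarray, states: np.ndarray, neighbor_number: int) -> np.ndarray:
    """Weight of each eigenstate (column of `states`) on sites with fewer than `neighbor_number` nnn."""
    distances = np.round(np.linalg.norm(positions - positions[:, None], axis=-1), 4)
    nnn = np.unique(distances)[2]                        # sorted: [0, nearest, next-nearest, ...]
    edge = (distances == nnn).sum(axis=0) < neighbor_number
    return (np.abs(states[edge, :]) ** 2).sum(axis=0)    # eigenvectors are normalized


# hypothetical 6-site chain with unit spacing and hopping -1
positions = np.array([[float(i), 0.0, 0.0] for i in range(6)])
H = -(np.eye(6, k=1) + np.eye(6, k=-1))
energies, states = np.linalg.eigh(H)
loc = edge_localization(positions, states, neighbor_number=2)
```

Since every site carries total weight 1 across all eigenstates, the localizations sum to the number of edge sites (4 here).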

master_equation(*, end_time, start_time=0.0, dt=0.0001, grid=100, max_mem_gb=0.5, initial_density_matrix=None, coulomb_strength=1.0, illumination=None, relaxation_rate=None, compute_at=None, expectation_values=None, density_matrix=None, use_rwa=False, solver=diffrax.Dopri5(), stepsize_controller=diffrax.PIDController(rtol=1e-10, atol=1e-10), hamiltonian=None, dissipator=None, postprocesses=None, rhs_args=None)

Simulates the time evolution of the density matrix, computing observables, density matrices or extracting custom information.

Parameters:
  • end_time (float) –

    The final time for the simulation.

  • start_time (float, default: 0.0 ) –

    The starting time for the simulation. Defaults to 0.0.

  • dt (float, default: 0.0001 ) –

    The time step size for the simulation. Defaults to 1e-4.

  • grid (Union[int, Array], default: 100 ) –

    Determines the output times for the simulation results. If an integer, results are saved every 'grid'-th time step. If an array, results are saved at the specified times.

  • max_mem_gb (float, default: 0.5 ) –

    Maximum memory in gigabytes allowed for each batch of intermediate density matrices.

  • initial_density_matrix (Optional[Array], default: None ) –

    The initial state of the density matrix. If not provided, self.initial_density_matrix is used.

  • coulomb_strength (float, default: 1.0 ) –

    Scaling factor for the Coulomb interaction matrix.

  • illumination (Callable, default: None ) –

    Function describing the time-dependent external illumination applied to the system.

  • relaxation_rate (Union[float, Array, Callable], default: None ) –

    Specifies the relaxation dynamics. A float indicates a uniform decoherence time, an array provides state-specific rates.

  • compute_at (Optional[Array], default: None ) –

    The orbitals indexed by this array will experience induced fields.

  • expectation_values (Optional[list[Array]], default: None ) –

    Expectation values to compute during the simulation.

  • density_matrix (Optional[list[str]], default: None ) –

    Tags for additional density matrix computations. "full", "occ_x", "occ_e", "diag_x", "diag_e". May be deprecated.

  • use_rwa (bool, default: False ) –

    Whether to use the rotating wave approximation. Defaults to False.

  • solver (default: diffrax.Dopri5() ) –

    The numerical solver instance to use for integrating the differential equations.

  • stepsize_controller (default: diffrax.PIDController(rtol=1e-10, atol=1e-10) ) –

    Controller for adjusting the solver's step size based on error tolerance.

  • hamiltonian (dict, default: None ) –

    Dict of functions representing terms of the Hamiltonian. Each function must have the signature (t, r, args) -> jax.Array; the keys are arbitrary.

  • dissipator (dict, default: None ) –

    Dict of functions representing terms of the dissipator. Each function must have the signature (t, r, args) -> jax.Array; the keys are arbitrary.

  • postprocesses (dict, default: None ) –

    Dict of functions representing information to extract from the simulation. Each function must have the signature (r, args) -> jax.Array; the keys are arbitrary.

  • rhs_args

    Arguments (a namedtuple) passed to hamiltonian, dissipator, and postprocesses during the simulation.

Returns:
  • TDResult
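
The hamiltonian and dissipator dicts are collections of terms with the signature (t, r, args). A minimal NumPy sketch of how such terms could enter the right-hand side of the master equation, assuming the terms of each dict are summed and using the von Neumann form with ħ = 1. `master_equation_rhs` is hypothetical; GRANAD's actual integrator (built by _numerics.get_integrator) is not shown in this section.

```python
import numpy as np


def master_equation_rhs(t, rho, args, hamiltonian: dict, dissipator: dict):
    """d rho / dt = -i [H(t), rho] + sum of dissipator terms (hbar = 1); hypothetical helper.

    Each value in `hamiltonian` and `dissipator` is a callable f(t, rho, args) -> array;
    the keys are only labels.
    """
    h = sum(term(t, rho, args) for term in hamiltonian.values())
    coherent = -1j * (h @ rho - rho @ h)
    return coherent + sum(term(t, rho, args) for term in dissipator.values())


H0 = np.array([[0.0, 1.0], [1.0, 0.0]])
hamiltonian = {"bare_hamiltonian": lambda t, r, args: H0}
dissipator = {"no_dissipation": lambda t, r, args: 0.0}  # mirrors the no-op term of get_dissipator
rho = np.array([[1.0, 0.0], [0.0, 0.0]], dtype=complex)
drho = master_equation_rhs(0.0, rho, None, hamiltonian, dissipator)
```

The coherent part is traceless, so the total occupation is conserved when the dissipator is a no-op.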

Source code in src/granad/orbitals.py
@recomputes
def master_equation(            
        self,
        *,
        end_time : float,
        start_time : float = 0.0,
        dt : float = 1e-4,
        grid : Union[int, jax.Array] = 100,
        max_mem_gb : float = 0.5,

        initial_density_matrix : Optional[jax.Array] = None,

        coulomb_strength : float = 1.0,

        illumination : Callable = None,

        relaxation_rate : Optional[Union[float, jax.Array]] = None,

        compute_at : Optional[jax.Array] = None,

        expectation_values : Optional[list[jax.Array]] = None,
        density_matrix : Optional[list[str]] = None,

        use_rwa : bool = False,

        solver = diffrax.Dopri5(),
        stepsize_controller = diffrax.PIDController(rtol=1e-10,atol=1e-10),

        hamiltonian : dict = None,
        dissipator : dict = None,
        postprocesses : dict = None,
        rhs_args = None,

):
    """
    Simulates the time evolution of the density matrix, computing observables, density matrices or extracting custom information.

    Args:
        end_time (float): The final time for the simulation.
        start_time (float): The starting time for the simulation. Defaults to 0.0.
        dt (float): The time step size for the simulation. Defaults to 1e-4.
        grid (Union[int, jax.Array]): Determines the output times for the simulation results. If an integer, results
                                      are saved every 'grid'-th time step. If an array, results are saved at the
                                      specified times.
        max_mem_gb (float): Maximum memory in gigabytes allowed for each batch of intermediate density matrices.
        initial_density_matrix (Optional[jax.Array]): The initial state of the density matrix. If not provided,
                                                      `self.initial_density_matrix` is used.
        coulomb_strength (float): Scaling factor for the Coulomb interaction matrix.
        illumination (Callable): Function describing the time-dependent external illumination applied to the system.
        relaxation_rate (Union[float, jax.Array, Callable]): Specifies the relaxation dynamics. A float indicates a
                                                             uniform decoherence time, an array provides state-specific
                                                             rates.
        compute_at (Optional[jax.Array]): The orbitals indexed by this array will experience induced fields.
        expectation_values (Optional[list[jax.Array]]): Expectation values to compute during the simulation.
        density_matrix (Optional[list[str]]): Tags for additional density matrix computations. "full", "occ_x", "occ_e", "diag_x", "diag_e". May be deprecated.
        computation (Optional[Callable]): Additional computation to be performed at each step.
        use_rwa (bool): Whether to use the rotating wave approximation. Defaults to False.
        solver: The numerical solver instance to use for integrating the differential equations.
        stepsize_controller: Controller for adjusting the solver's step size based on error tolerance.
        hamiltonian (dict): Dictionary of functions representing terms in the Hamiltonian. Each function must have the signature `t, r, args -> jax.Array`. The keys are ignored.
        dissipator (dict): Dictionary of functions representing terms in the dissipator. Each function must have the signature `t, r, args -> jax.Array`. The keys are ignored.
        postprocesses (dict): Dictionary of functions representing information to extract from the simulation. Each function must have the signature `r, args -> jax.Array`. The keys are ignored.
        rhs_args: Arguments (a namedtuple) passed to the hamiltonian, dissipator, and postprocesses functions during the simulation.

    Returns:
        TDResult
    """


    # arguments to evolution function
    if rhs_args is None:
        rhs_args = self.get_args( relaxation_rate,
                                  coulomb_strength,
                                  _numerics.get_coulomb_field_to_from(self.positions, self.positions, compute_at) )

    if illumination is None:
        illumination = lambda t : jnp.array( [0j, 0j, 0j] )

    # each of these functions is applied to a density matrix batch
    postprocesses = self.get_postprocesses( expectation_values, density_matrix ) if postprocesses is None else postprocesses

    # hermitian rhs
    hamiltonian = self.get_hamiltonian(illumination, use_rwa, compute_at is not None) if hamiltonian is None else hamiltonian

    # non hermitian rhs
    dissipator = self.get_dissipator(relaxation_rate, None) if dissipator is None else dissipator

    # set reasonable default 
    initial_density_matrix = initial_density_matrix if initial_density_matrix is not None else rhs_args.initial_density_matrix

    try:        
        return self._integrate_master_equation( list(hamiltonian.values()), list(dissipator.values()), list(postprocesses.values()), rhs_args, illumination, solver, stepsize_controller, initial_density_matrix, start_time, end_time, grid, max_mem_gb, dt )
    except Exception as e:
        print(f"Simulation crashed with exception {e}. Try increasing the time mesh and make sure your illumination is differentiable. The full diffrax traceback follows below.")
        traceback.print_exc()
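
When illumination is None, the code substitutes a field that is zero at all times and has three complex components. The crash message recommends a differentiable illumination, and a Gaussian-enveloped cosine pulse is one smooth choice. The sketch below is hypothetical (gaussian_pulse and its parameters are not part of GRANAD) and returns a plain list so that it runs without JAX; for an actual simulation the inner function would return a jax.Array, like the default lambda t : jnp.array( [0j, 0j, 0j] ), and use jnp.exp and jnp.cos in place of the math functions.

```python
import math


def gaussian_pulse(amplitude, omega, peak, sigma):
    """Return a smooth, x-polarized illumination t -> [E_x, E_y, E_z] (hypothetical helper)."""

    def illumination(t):
        # Gaussian envelope centred at `peak` with width `sigma`
        envelope = math.exp(-((t - peak) ** 2) / (2.0 * sigma ** 2))
        e_x = amplitude * envelope * math.cos(omega * (t - peak))
        # complex components, matching the zero-field default
        return [complex(e_x), 0j, 0j]

    return illumination


field = gaussian_pulse(amplitude=1.0, omega=2.0, peak=5.0, sigma=1.0)
print(field(5.0))  # [(1+0j), 0j, 0j]
```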

rotate(x, phi, axis='z')

Rotates all orbitals by an angle phi around the given axis, passing through the point x.

Parameters:
  • x (jnp.ndarray) –

    A 3D point around which to rotate.

  • phi (float) –

    Angle (in radians) by which to rotate.

  • axis (str, default: 'z' ) –

    Axis to rotate around ('x', 'y', or 'z').

Source code in src/granad/orbitals.py
@mutates
def rotate(self, x, phi, axis = 'z'):
    """Rotates all orbitals by an angle phi around the given axis, passing through the point x.

    Args:
        x (jnp.ndarray): A 3D point around which to rotate.
        phi (float): Angle (in radians) by which to rotate.
        axis (str): Axis to rotate around ('x', 'y', or 'z'). Default is 'z'.
    """

    # Define the rotation matrix based on the specified axis
    if axis == 'x':
        rotation_matrix = jnp.array([
            [1, 0, 0],
            [0, jnp.cos(phi), -jnp.sin(phi)],
            [0, jnp.sin(phi), jnp.cos(phi)]
        ])
    elif axis == 'y':
        rotation_matrix = jnp.array([
            [jnp.cos(phi), 0, jnp.sin(phi)],
            [0, 1, 0],
            [-jnp.sin(phi), 0, jnp.cos(phi)]
        ])
    elif axis == 'z':
        rotation_matrix = jnp.array([
            [jnp.cos(phi), -jnp.sin(phi), 0],
            [jnp.sin(phi), jnp.cos(phi), 0],
            [0, 0, 1]
        ])
    else:
        raise ValueError("Axis must be 'x', 'y', or 'z'.")

    for orb in self._list:
        # Translate so that x is at the origin, rotate, then translate back
        self.set_position(rotation_matrix @ (orb.position - x) + x, orb)
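
The loop above applies p -> R (p - x) + x to every orbital position p. The same arithmetic in plain Python, with the rotation matrices copied from the source, makes it easy to check a single point. rotation_matrix and rotate_about are illustrative helpers, not GRANAD functions, and phi is in radians.

```python
import math


def rotation_matrix(phi, axis="z"):
    # same matrices as in OrbitalList.rotate
    c, s = math.cos(phi), math.sin(phi)
    if axis == "x":
        return [[1, 0, 0], [0, c, -s], [0, s, c]]
    if axis == "y":
        return [[c, 0, s], [0, 1, 0], [-s, 0, c]]
    if axis == "z":
        return [[c, -s, 0], [s, c, 0], [0, 0, 1]]
    raise ValueError("Axis must be 'x', 'y', or 'z'.")


def rotate_about(point, center, phi, axis="z"):
    # translate so that center is the origin, rotate, translate back
    rot = rotation_matrix(phi, axis)
    shifted = [p - c for p, c in zip(point, center)]
    rotated = [sum(rot[i][j] * shifted[j] for j in range(3)) for i in range(3)]
    return [r + c for r, c in zip(rotated, center)]


# a quarter turn about z through (1, 1, 0) moves (2, 1, 0) to approximately (1, 2, 0)
print(rotate_about([2.0, 1.0, 0.0], [1.0, 1.0, 0.0], math.pi / 2))
```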

set_coulomb_element(orb1, orb2, val)

Sets a Coulomb interaction element between two orbitals or indices.

Parameters:
  • orb1

    Identifier for orbital(s) for the first element.

  • orb2

    Identifier for orbital(s) for the second element.

  • val (complex) –

    The complex value to set for the Coulomb interaction element.

Source code in src/granad/orbitals.py
def set_coulomb_element(self, orb1, orb2, val):
    """
    Sets a Coulomb interaction element between two orbitals or indices.

    Parameters:
        orb1: Identifier for orbital(s) for the first element.
        orb2: Identifier for orbital(s) for the second element.
        val (complex): The complex value to set for the Coulomb interaction element.
    """
    self._set_coupling(self.filter_orbs(orb1, Orbital), self.filter_orbs(orb2, Orbital), self._ensure_complex(val), self.couplings.coulomb)

set_coulomb_groups(orb1, orb2, func)

Sets the Coulomb coupling between two groups of orbitals.

Parameters:
  • orb1

    Identifier for orbital(s) for the first group.

  • orb2

    Identifier for orbital(s) for the second group.

  • func (callable) –

    Function that defines the Coulomb interaction.

Note

The function func should be complex-valued.

Source code in src/granad/orbitals.py
def set_coulomb_groups(self, orb1, orb2, func):
    """
    Sets the Coulomb coupling between two groups of orbitals.

    Parameters:
        orb1: Identifier for orbital(s) for the first group.
        orb2: Identifier for orbital(s) for the second group.
        func (callable): Function that defines the Coulomb interaction.

    Note:
        The function `func` should be complex-valued.
    """
    self._set_coupling(
        self.filter_orbs(orb1, _watchdog.GroupId), self.filter_orbs(orb2, _watchdog.GroupId), self._ensure_complex(func), self.couplings.coulomb
    )
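
A group function maps the geometry of an orbital pair to a complex coupling. The sketch below builds an Ohno-type potential, U / sqrt(1 + (U eps_r r / 14.3996)^2) with U in eV and r in Angstrom, which equals U at r = 0 and approaches the bare Coulomb energy 14.3996 / (eps_r r) at large separation. Two assumptions are labeled here: that func receives the displacement vector between the two orbitals, and the hypothetical helper name ohno with its parameters. The math module is used for portability; a JAX version would use jnp.linalg.norm and jnp.sqrt.

```python
import math

# e^2 / (4 pi eps_0) expressed in eV * Angstrom
COULOMB_EV_ANGSTROM = 14.3996


def ohno(onsite_u, eps_r=1.0):
    """Return a complex-valued Ohno-type coupling of the orbital separation (hypothetical helper)."""

    def coupling(d):
        # d is assumed to be the displacement vector between the two orbitals, in Angstrom
        r = math.sqrt(sum(x * x for x in d))
        scaled = onsite_u * eps_r * r / COULOMB_EV_ANGSTROM
        return complex(onsite_u / math.sqrt(1.0 + scaled ** 2))

    return coupling


v = ohno(onsite_u=10.0)
print(v([0.0, 0.0, 0.0]))  # (10+0j)
```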

set_dipole_element(orb1, orb2, arr)

Sets a dipole transition for specified orbital or index pairs.

Parameters:
  • orb1

    Identifier for orbital(s) for the first part of the transition.

  • orb2

    Identifier for orbital(s) for the second part of the transition.

  • arr (Array) –

    The 3-element array containing dipole transition elements.

Source code in src/granad/orbitals.py
@mutates
def set_dipole_element(self, orb1, orb2, arr):
    """
    Sets a dipole transition for specified orbital or index pairs.

    Parameters:
        orb1: Identifier for orbital(s) for the first part of the transition.
        orb2: Identifier for orbital(s) for the second part of the transition.
        arr (jax.Array): The 3-element array containing dipole transition elements.
    """
    self._set_coupling(self.filter_orbs(orb1, Orbital), self.filter_orbs(orb2, Orbital), jnp.array(arr).astype(complex), self.couplings.dipole_transitions)

set_excitation(from_state, to_state, excited_electrons)

Sets up an excitation process from one state to another with specified electrons.

Parameters:
  • from_state (int, list, or jax.Array) –

    The initial state index or indices.

  • to_state (int, list, or jax.Array) –

    The final state index or indices.

  • excited_electrons (int, list, or jax.Array) –

    The indices of electrons to be excited.

Note

The states and electron indices may be specified as scalars, lists, or arrays.

Source code in src/granad/orbitals.py
@mutates
def set_excitation(self, from_state, to_state, excited_electrons):
    """
    Sets up an excitation process from one state to another with specified electrons.

    Parameters:
        from_state (int, list, or jax.Array): The initial state index or indices.
        to_state (int, list, or jax.Array): The final state index or indices.
        excited_electrons (int, list, or jax.Array): The indices of electrons to be excited.

    Note:
        The states and electron indices may be specified as scalars, lists, or arrays.
    """
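    def maybe_int_to_arr(maybe_int):
        # normalize scalars, lists, and arrays to 1D jax arrays
        if isinstance(maybe_int, int):
            return jnp.array([maybe_int])
        if isinstance(maybe_int, (list, jax.Array)):
            return jnp.atleast_1d(jnp.array(maybe_int))
        raise TypeError(f"Expected int, list, or jax.Array, got {type(maybe_int).__name__}")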

    self.params.excitation = [maybe_int_to_arr(from_state), maybe_int_to_arr(to_state), maybe_int_to_arr(excited_electrons)]

set_hamiltonian_element(orb1, orb2, val)

Sets an element of the Hamiltonian matrix between two orbitals or indices.

Parameters:
  • orb1

    Identifier for orbital(s) for the first element.

  • orb2

    Identifier for orbital(s) for the second element.

  • val (complex) –

    The complex value to set for the Hamiltonian element.

Source code in src/granad/orbitals.py
def set_hamiltonian_element(self, orb1, orb2, val):
    """
    Sets an element of the Hamiltonian matrix between two orbitals or indices.

    Parameters:
        orb1: Identifier for orbital(s) for the first element.
        orb2: Identifier for orbital(s) for the second element.
        val (complex): The complex value to set for the Hamiltonian element.
    """
    self._set_coupling(self.filter_orbs(orb1, Orbital), self.filter_orbs(orb2, Orbital), self._ensure_complex(val), self.couplings.hamiltonian)
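
The Hamiltonian is Hermitian, H[j][i] = conj(H[i][j]), which is why val is complex and why one call per orbital pair should suffice. The pure-Python sketch below shows that bookkeeping under the assumption that one entry is stored per pair and its mirror is filled with the complex conjugate; hermitian_from_couplings is a hypothetical helper, not the library's _set_coupling. Diagonal (onsite) entries, as set by set_onsite_hopping, must be real.

```python
def hermitian_from_couplings(n, couplings):
    """Build an n x n Hermitian matrix from {(i, j): value} pair couplings (hypothetical helper)."""
    h = [[0j] * n for _ in range(n)]
    for (i, j), value in couplings.items():
        value = complex(value)
        if i == j and value.imag != 0.0:
            raise ValueError("onsite (diagonal) elements must be real")
        h[i][j] = value
        # the mirrored element is the complex conjugate, keeping h Hermitian
        h[j][i] = value.conjugate()
    return h


h = hermitian_from_couplings(3, {(0, 0): 0.5, (0, 1): -2.66, (1, 2): 0.1j})
print(h[2][1])
```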

set_hamiltonian_groups(orb1, orb2, func)

Sets the Hamiltonian coupling between two groups of orbitals.

Parameters:
  • orb1

    Identifier for orbital(s) for the first group.

  • orb2

    Identifier for orbital(s) for the second group.

  • func (callable) –

    Function that defines the Hamiltonian interaction.

Note

The function func should be complex-valued.

Source code in src/granad/orbitals.py
def set_hamiltonian_groups(self, orb1, orb2, func):
    """
    Sets the Hamiltonian coupling between two groups of orbitals.

    Parameters:
        orb1: Identifier for orbital(s) for the first group.
        orb2: Identifier for orbital(s) for the second group.
        func (callable): Function that defines the Hamiltonian interaction.

    Note:
        The function `func` should be complex-valued.
    """
    self._set_coupling(
        self.filter_orbs(orb1, _watchdog.GroupId), self.filter_orbs(orb2, _watchdog.GroupId), self._ensure_complex(func), self.couplings.hamiltonian
    )
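
A common Hamiltonian group function is a distance-cutoff hopping: a constant value for orbitals one bond length apart and zero otherwise. The sketch assumes that func receives the displacement vector between two orbitals; nearest_neighbour_hopping, bond_length, and tol are hypothetical names. The Python if works for plain floats, but inside JAX-traced code the same branch would have to be written with jnp.where.

```python
import math


def nearest_neighbour_hopping(t_hop, bond_length, tol=1e-2):
    """Return a complex hopping that is nonzero only at the bond length (hypothetical helper)."""

    def coupling(d):
        # d is assumed to be the displacement vector between the two orbitals
        r = math.sqrt(sum(x * x for x in d))
        return complex(t_hop) if abs(r - bond_length) < tol else 0j

    return coupling


hop = nearest_neighbour_hopping(t_hop=-1.0, bond_length=1.42)
print(hop([1.42, 0.0, 0.0]), hop([2.46, 0.0, 0.0]))
```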

set_mean_field(**kwargs)

Configures the parameters for mean field calculations. If no other parameters are passed, a standard direct channel Hartree-Fock calculation is performed. Note that this procedure differs slightly from the self-consistent field procedure.

This function sets up the mean field parameters used in iterative calculations to update the system's density matrix until convergence is achieved.

Parameters:
  • **kwargs

    Keyword arguments to override the default mean field parameters. The available parameters are:

    • accuracy (float, optional): The convergence criterion for self-consistency. Specifies the maximum allowed difference between successive density matrices. Default is 1e-6.

    • mix (float, optional): The mixing factor for the density matrix during updates. This controls the contribution of the new density matrix to the updated one. Values closer to 1 favor the new density matrix, while smaller values favor smoother convergence. Default is 0.3.

    • iterations (int, optional): The maximum number of iterations allowed in the self-consistency cycle. Default is 500.

    • coulomb_strength (float, optional): A scaling factor for the Coulomb matrix. This allows tuning of the strength of Coulomb interactions in the system. Default is 1.0.

    • f_mean_field (Callable, optional): A function for computing the mean field term. Its first argument is the density matrix and its second argument is the single-particle Hamiltonian. It can be used, e.g., for full HF by passing a closure containing ERIs. Default is None.

    • f_build (Callable, optional): Construction of the density matrix from energies and eigenvectors. If None, single-particle energy levels are filled according to the number of electrons. Default is None.

    • rho_0 (jax.Array, optional): Initial guess for the density matrix. If None, zeros are used. Default is None.

Example

>>> model.set_mean_field(accuracy=1e-7, mix=0.5, iterations=1000)
>>> print(model.params.mean_field_params)
{'accuracy': 1e-07, 'mix': 0.5, 'iterations': 1000, 'coulomb_strength': 1.0, 'f_mean_field': None, 'f_build': None, 'rho_0': None}

Source code in src/granad/orbitals.py
@mutates
def set_mean_field(self, **kwargs):
    """
    Configures the parameters for mean field calculations.
    If no other parameters are passed, a standard direct channel Hartree-Fock calculation is performed.
    Note that this procedure differs slightly from the self-consistent field procedure.

    This function sets up the mean field parameters used in iterative calculations 
    to update the system's density matrix until convergence is achieved.

    Args:
        **kwargs: Keyword arguments to override the default mean field parameters.
            The available parameters are:

            - `accuracy` (float, optional): The convergence criterion for self-consistency. 
              Specifies the maximum allowed difference between successive density matrices.
              Default is 1e-6.

            - `mix` (float, optional): The mixing factor for the density matrix during updates.
              This controls the contribution of the new density matrix to the updated one.
              Values closer to 1 favor the new density matrix, while smaller values favor 
              smoother convergence. Default is 0.3.

            - `iterations` (int, optional): The maximum number of iterations allowed in the 
              self-consistency cycle. Default is 500.

            - `coulomb_strength` (float, optional): A scaling factor for the Coulomb matrix.
              This allows tuning of the strength of Coulomb interactions in the system. 
              Default is 1.0.

            - `f_mean_field` (Callable, optional): A function for computing the mean field term.
              Its first argument is the density matrix, its second argument is the single-particle Hamiltonian.
              Can be used, e.g., for full HF by passing a closure containing ERIs.       
              Default is None.

            - `f_build` (Callable, optional): Construction of the density matrix from energies and eigenvectors. If None, single-particle energy levels are filled according to the number of electrons.
              Default is None.

            - `rho_0` (jax.Array, optional): Initial guess for the density matrix. If None, zeros are used.
               Default is None.

    Example:
        >>> model.set_mean_field(accuracy=1e-7, mix=0.5, iterations=1000)
        >>> print(model.params.mean_field_params)
        {'accuracy': 1e-07, 'mix': 0.5, 'iterations': 1000, 'coulomb_strength': 1.0, 'f_mean_field': None, 'f_build': None, 'rho_0': None}
    """
    default = {"accuracy" : 1e-6, "mix" : 0.3, "iterations" : 500, "coulomb_strength" : 1.0, "f_mean_field" : None, "f_build" : None, "rho_0" : None}
    self.params.mean_field_params = default | kwargs
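
The merge default | kwargs means any keyword passed to set_mean_field overrides the corresponding default while the others are kept. How the stored mix, accuracy, and iterations values could drive an iteration is sketched below as simple linear mixing, new = (1 - mix) * old + mix * candidate, stopping once successive matrices differ by less than accuracy. This mixing formula is an assumption consistent with the parameter descriptions, not GRANAD's implementation; mix_until_converged is a hypothetical helper, and density matrices are flattened to lists of numbers.

```python
def mix_until_converged(update, rho_0, mix=0.3, accuracy=1e-6, iterations=500):
    """Linear-mixing fixed-point iteration on a flattened density matrix (illustrative)."""
    rho = list(rho_0)
    for _ in range(iterations):
        candidate = update(rho)
        # mix -> 1 favors the new candidate, mix -> 0 keeps more of the old matrix
        new = [(1 - mix) * old + mix * c for old, c in zip(rho, candidate)]
        if max(abs(n - o) for n, o in zip(new, rho)) < accuracy:
            return new
        rho = new
    raise RuntimeError("self-consistency did not converge")


# toy update whose fixed point is 2.0 for every entry
result = mix_until_converged(lambda rho: [0.5 * x + 1.0 for x in rho], [0.0, 0.0])
print(result)
```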

set_onsite_hopping(orb, val)

Sets onsite hopping element of the Hamiltonian matrix.

Parameters:
  • orb

    Identifier for orbital(s).

  • val (real) –

    The value to set for the onsite hopping.

Source code in src/granad/orbitals.py
def set_onsite_hopping(self, orb, val):
    """
    Sets onsite hopping element of the Hamiltonian matrix.

    Parameters:
        orb: Identifier for orbital(s).
        val (real): The value to set for the onsite hopping.
    """
    self.set_hamiltonian_element(orb, orb, val)        

set_position(position, orb_id=None)

Sets the position of all orbitals matching orb_id.

Parameters:
  • position (list or Array) –

    The position to which the orbitals are moved.

  • orb_id

    Identifier for the orbital(s) to move. If None, all orbitals are moved.

Note

This operation mutates the positions of the matched orbitals.

Source code in src/granad/orbitals.py
@mutates
def set_position(self, position, orb_id = None):
    """
    Sets the position of all orbitals matching `orb_id`.

    Parameters:
        position (list or jax.Array): The position to which the orbitals are moved.
        orb_id: Identifier for the orbital(s) to move. If None, all orbitals are moved.

    Note:
        This operation mutates the positions of the matched orbitals.
    """
    filtered_orbs = self.filter_orbs( orb_id, Orbital ) if orb_id is not None else self
    for orb in filtered_orbs:
        orb.position = jnp.array(position)

set_self_consistent(**kwargs)

Configures the parameters for self-consistent field (SCF) calculations.

This function sets up the self-consistency parameters used in iterative calculations to update the system's density matrix until convergence is achieved.

Parameters:
  • **kwargs

    Keyword arguments to override the default self-consistency parameters. The available parameters are:

    • accuracy (float, optional): The convergence criterion for self-consistency. Specifies the maximum allowed difference between successive density matrices. Default is 1e-6.

    • mix (float, optional): The mixing factor for the density matrix during updates. This controls the contribution of the new density matrix to the updated one. Values closer to 1 favor the new density matrix, while smaller values favor smoother convergence. Default is 0.3.

    • iterations (int, optional): The maximum number of iterations allowed in the self-consistency cycle. Default is 500.

    • coulomb_strength (float, optional): A scaling factor for the Coulomb matrix. This allows tuning of the strength of Coulomb interactions in the system. Default is 1.0.

Example

>>> model.set_self_consistent(accuracy=1e-7, mix=0.5, iterations=1000)
>>> print(model.params.self_consistency_params)
{'accuracy': 1e-07, 'mix': 0.5, 'iterations': 1000, 'coulomb_strength': 1.0}

Source code in src/granad/orbitals.py
@mutates
def set_self_consistent(self, **kwargs):
    """
    Configures the parameters for self-consistent field (SCF) calculations.

    This function sets up the self-consistency parameters used in iterative calculations 
    to update the system's density matrix until convergence is achieved.

    Args:
        **kwargs: Keyword arguments to override the default self-consistency parameters. 
            The available parameters are:

            - `accuracy` (float, optional): The convergence criterion for self-consistency. 
              Specifies the maximum allowed difference between successive density matrices.
              Default is 1e-6.

            - `mix` (float, optional): The mixing factor for the density matrix during updates.
              This controls the contribution of the new density matrix to the updated one.
              Values closer to 1 favor the new density matrix, while smaller values favor 
              smoother convergence. Default is 0.3.

            - `iterations` (int, optional): The maximum number of iterations allowed in the 
              self-consistency cycle. Default is 500.

            - `coulomb_strength` (float, optional): A scaling factor for the Coulomb matrix.
              This allows tuning of the strength of Coulomb interactions in the system. 
              Default is 1.0.

    Example:
        >>> model.set_self_consistent(accuracy=1e-7, mix=0.5, iterations=1000)
        >>> print(model.params.self_consistency_params)
        {'accuracy': 1e-07, 'mix': 0.5, 'iterations': 1000, 'coulomb_strength': 1.0}
    """
    default = {"accuracy" : 1e-6, "mix" : 0.3, "iterations" : 500, "coulomb_strength" : 1.0}
    self.params.self_consistency_params = default | kwargs

shift_by_vector(translation_vector, orb_id=None)

Shifts all orbitals matching orb_id by a given vector.

Parameters:
  • translation_vector (list or Array) –

    The vector by which to translate the orbital positions.

  • orb_id

    Identifier for the orbital(s) to shift. If None, all orbitals are shifted.

Note

This operation mutates the positions of the matched orbitals.

Source code in src/granad/orbitals.py
@mutates
def shift_by_vector(self, translation_vector, orb_id = None):
    """
    Shifts all orbitals matching `orb_id` by a given vector.

    Parameters:
        translation_vector (list or jax.Array): The vector by which to translate the orbital positions.
        orb_id: Identifier for the orbital(s) to shift. If None, all orbitals are shifted.

    Note:
        This operation mutates the positions of the matched orbitals.
    """
    filtered_orbs = self.filter_orbs( orb_id, Orbital ) if orb_id is not None else self
    for orb in filtered_orbs:
        orb.position += jnp.array(translation_vector)

transform_to_energy_basis(observable)

Transforms an observable to the energy basis using the conjugate transpose of the system's eigenvectors.

Parameters:
  • observable (Array) –

    The observable to transform.

Returns:
  • jax.Array: The transformed observable in the energy basis.

Source code in src/granad/orbitals.py
def transform_to_energy_basis(self, observable):
    """
    Transforms an observable to the energy basis using the conjugate transpose of the system's eigenvectors.

    Parameters:
       observable (jax.Array): The observable to transform.

    Returns:
       jax.Array: The transformed observable in the energy basis.
    """

    return self._transform_basis(observable, self._eigenvectors.conj().T)

transform_to_site_basis(observable)

Transforms an observable to the site basis using eigenvectors of the system.

Parameters:
  • observable (Array) –

    The observable to transform.

Returns:
  • jax.Array: The transformed observable in the site basis.

Source code in src/granad/orbitals.py
def transform_to_site_basis(self, observable):
    """
    Transforms an observable to the site basis using eigenvectors of the system.

    Parameters:
       observable (jax.Array): The observable to transform.

    Returns:
       jax.Array: The transformed observable in the site basis.
    """
    return self._transform_basis(observable, self._eigenvectors)
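
Both transforms delegate to _transform_basis, whose body is not shown here. Assuming it computes U O U^dagger and that the columns of _eigenvectors are the energy eigenstates, transform_to_energy_basis yields V^dagger O V and transform_to_site_basis yields V O V^dagger, so the two undo each other and the Hamiltonian becomes diagonal in the energy basis. The pure-Python sketch below checks this for a two-site example; matmul, dagger, and transform_basis are hypothetical helpers.

```python
import math


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def dagger(a):
    # conjugate transpose
    return [[complex(a[j][i]).conjugate() for j in range(len(a))] for i in range(len(a[0]))]


def transform_basis(observable, u):
    # assumed convention: O -> U O U^dagger
    return matmul(matmul(u, observable), dagger(u))


s = 1 / math.sqrt(2)
H = [[0.0, 1.0], [1.0, 0.0]]
V = [[s, s], [s, -s]]  # columns are eigenvectors of H with eigenvalues +1 and -1

H_energy = transform_basis(H, dagger(V))  # V^dagger H V, diagonal
H_site = transform_basis(H_energy, V)     # V H_energy V^dagger, back to H
print(H_energy)
```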

Params dataclass

Stores parameters characterizing a given structure.

Attributes:
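  • excitation (list[Array]) –

    A list of three arrays holding the initial state indices, the final state indices, and the excited electron indices, as set by set_excitation.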

  • eps (float) –

    Numerical precision used for identifying degenerate eigenstates. Defaults to 1e-5.

  • beta (float) –

    Inverse temperature parameter (1/kT) used in thermodynamic calculations. Set to jax.numpy.inf by default, implying zero temperature.

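  • self_consistency_params (dict) –

    A dictionary to hold additional parameters required for self-consistency calculations within the simulation. Defaults to an empty dictionary.

  • mean_field_params (dict) –

    A dictionary holding the parameters for mean field calculations, as set by set_mean_field. Defaults to an empty dictionary.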

  • spin_degeneracy (float) –

    Factor to account for the degeneracy of spin states in the simulation. Defaults to 2.0, accounting for spin up and spin down.

  • electrons (Optional[int]) –

    The total number of electrons in the structure. If not provided, it is assumed that the system's electron number needs to be calculated or is managed elsewhere.

Note

This object should not be created directly, but is rather used to encapsulate (ephemeral) internal state of OrbitalList.

Source code in src/granad/orbitals.py
@dataclass
class Params:
    """
    Stores parameters characterizing a given structure.

    Attributes:
        excitation (list[jax.Array]): A list of three arrays holding the initial state indices, the final state
                                      indices, and the excited electron indices, as set by `set_excitation`.
        eps (float): Numerical precision used for identifying degenerate eigenstates. Defaults to 1e-5.
        beta (float): Inverse temperature parameter (1/kT) used in thermodynamic calculations. Set to
                      `jax.numpy.inf` by default, implying zero temperature.
        self_consistency_params (dict): A dictionary to hold additional parameters required for self-consistency
                                        calculations within the simulation. Defaults to an empty dictionary.
        mean_field_params (dict): A dictionary holding the parameters for mean field calculations, as set by
                                  `set_mean_field`. Defaults to an empty dictionary.
        spin_degeneracy (float): Factor to account for the degeneracy of spin states in the simulation. Defaults
                                 to 2.0, accounting for spin up and spin down.
        electrons (Optional[int]): The total number of electrons in the structure. If not provided, it is assumed
                                   that the system's electron number needs to be calculated or is managed elsewhere.

    Note:
        This object should not be created directly, but is rather used to encapsulate (ephemeral) internal state
        of OrbitalList.
    """
    electrons : int 
    excitation : list[jax.Array] = field(default_factory=lambda : [jnp.array([0]), jnp.array([0]), jnp.array([0])])
    eps : float = 1e-5
    beta : float = jnp.inf
    self_consistency_params : dict =  field(default_factory=dict)
    mean_field_params : dict =  field(default_factory=dict)
    spin_degeneracy : float = 2.0

    def __add__( self, other ):
        if isinstance(other, Params):
            return Params(self.electrons + other.electrons)        
        raise ValueError
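
The beta and spin_degeneracy fields correspond to the textbook Fermi-Dirac occupation, spin_degeneracy / (1 + exp(beta (E - mu))), which at beta = inf reduces to a step function that fills all levels below the chemical potential mu. The sketch below is illustrative only: occupation and mu are hypothetical names, half occupation exactly at mu is the limiting convention chosen here, and how GRANAD builds its density matrix from these parameters is not shown in this section.

```python
import math


def occupation(energy, mu, beta=math.inf, spin_degeneracy=2.0):
    """Fermi-Dirac occupation of one level, scaled by spin degeneracy (illustrative)."""
    x = energy - mu
    if math.isinf(beta):
        # zero temperature: step function, half filling exactly at mu
        f = 1.0 if x < 0 else (0.5 if x == 0 else 0.0)
    else:
        # numerically stable evaluation of 1 / (1 + exp(beta * x))
        z = beta * x
        if z > 0:
            e = math.exp(-z)
            f = e / (1.0 + e)
        else:
            f = 1.0 / (1.0 + math.exp(z))
    return spin_degeneracy * f


print(occupation(-1.0, 0.0), occupation(1.0, 0.0))  # 2.0 0.0
```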

TDResult dataclass

A data class for storing the results of time-dependent simulations.

Attributes:
  • td_illumination (Array) –

    An array containing the time-dependent illumination function applied to the system, typically representing an external electromagnetic field.

  • time_axis (Array) –

    An array representing the time points at which the simulation was evaluated.

  • final_density_matrix (Array) –

    The resulting density matrix at the end of the simulation, representing the state of the system.

  • output (list[Array]) –

    A list of arrays containing various output data from the simulation, such as observables over time.

  • extra_attributes (Dict[str, Any]) –

    A dictionary saving any other quantity of interest (e.g. absorption spectra), by default empty.
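
ft_output and ft_illumination delegate to _numerics.get_fourier_transform, whose implementation is not reproduced here. As an illustration of a Fourier transform over a frequency range for a sampled signal, the sketch below evaluates a Riemann-sum approximation of F(w) = integral f(t) exp(i w t) dt on an evenly spaced grid between omega_min and omega_max. The exp(+i w t) sign convention, the uniform time step, and the n_omega parameter are assumptions, and fourier_transform is a hypothetical helper.

```python
import cmath
import math


def fourier_transform(time_axis, signal, omega_max, omega_min, n_omega=200):
    """Riemann-sum approximation of F(w) = integral f(t) exp(i w t) dt (hypothetical helper)."""
    dt = time_axis[1] - time_axis[0]  # assumes a uniform time grid
    step = (omega_max - omega_min) / (n_omega - 1)
    omegas = [omega_min + k * step for k in range(n_omega)]
    spectrum = [
        dt * sum(f * cmath.exp(1j * w * t) for t, f in zip(time_axis, signal))
        for w in omegas
    ]
    return omegas, spectrum


# a cosine at angular frequency 2 should produce a peak near w = 2
times = [0.05 * k for k in range(1001)]
signal = [math.cos(2.0 * t) for t in times]
omegas, spectrum = fourier_transform(times, signal, omega_max=4.0, omega_min=0.5, n_omega=71)
peak = max(range(len(spectrum)), key=lambda k: abs(spectrum[k]))
print(omegas[peak])
```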

Source code in src/granad/orbitals.py
@dataclass
class TDResult:
    """
    A data class for storing the results of time-dependent simulations.

    Attributes:
        td_illumination (jax.Array): An array containing the time-dependent illumination function applied to the system,
                                     typically representing an external electromagnetic field.
        time_axis (jax.Array): An array representing the time points at which the simulation was evaluated.
        final_density_matrix (jax.Array): The resulting density matrix at the end of the simulation, representing the
                                          state of the system.
        output (list[jax.Array]): A list of arrays containing various output data from the simulation, such as observables
                                  over time.
        extra_attributes (Dict[str, Any]): A dictionary saving any other quantity of interest (e.g. absorption spectra), by default empty.

    """

    td_illumination : jax.Array = field(default_factory=lambda: jnp.array([]))
    time_axis : jax.Array = field(default_factory=lambda: jnp.array([]))
    final_density_matrix : jax.Array = field(default_factory=lambda: jnp.array([[]]))
    output : list[jax.Array] = field(default_factory=list)
    extra_attributes: Dict[str, Any] = field(default_factory=dict)  # Stores dynamic attributes

    def ft_output( self, omega_max, omega_min ):
        """
        Computes the Fourier transform of each element in the output data across a specified frequency range.

        Args:
            omega_max (float): The maximum frequency bound for the Fourier transform.
            omega_min (float): The minimum frequency bound for the Fourier transform.

        Returns:
            list[jax.Array]: A list of Fourier transformed arrays corresponding to each element in the `output` attribute,
                              evaluated over the specified frequency range.

        Note:
            This method applies a Fourier transform to each array in the `output` list to analyze the frequency components
            between `omega_min` and `omega_max`.
        """
        ft = lambda o : _numerics.get_fourier_transform(self.time_axis, o, omega_max, omega_min, False)
        return [ft(o) for o in self.output]

    def ft_illumination( self, omega_max, omega_min, return_omega_axis = True ):
        """
        Calculates the Fourier transform of the time-dependent illumination function over a specified frequency range,
        with an option to return the frequency axis.

        Args:
            omega_max (float): The maximum frequency limit for the Fourier transform.
            omega_min (float): The minimum frequency limit for the Fourier transform.
            return_omega_axis (bool): If True, the function also returns the frequency axis along with the Fourier
                                      transformed illumination function. Defaults to True.

        Returns:
            jax.Array, optional[jax.Array]: The Fourier transformed illumination function. If `return_omega_axis` is True,
                                            a tuple containing the Fourier transformed data and the corresponding frequency
                                            axis is returned. Otherwise, only the Fourier transformed data is returned.

        """
        return _numerics.get_fourier_transform(self.time_axis, self.td_illumination, omega_max, omega_min, return_omega_axis)

    def add_extra_attribute(self,name: str,value: Any):
        """
        Dynamically adds an attribute to the 'extra_attributes' field.

        Args:
            name (str): Name of the new attribute to be added.
            value (Any): Value of the attribute.
        """
        self.extra_attributes[name]=value
        print(f"Extra attribute '{name}' is added.")

    def remove_extra_attribute(self,name: str):
        """
        Dynamically deletes an attribute from 'extra_attributes'.

        Args:
            name (str): Name of the attribute to be removed.
        """
        if name not in self.extra_attributes:
            raise KeyError(f"The attribute '{name}' does not exist in 'extra_attributes'. ")
        else: 
            del self.extra_attributes[name]
            print(f"Extra attribute '{name}' is removed.")

    def show_extra_attribute_list(self):
        """
        Displays all available extra attributes. 
        """
        print(list(self.extra_attributes.keys()))

    def get_attribute(self, name: str):
        """
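        Returns the value of the named attribute, whether it is a regular dataclass field or an entry in
        `extra_attributes`. Regular fields take precedence.

        Args:
            name (str): Name of the attribute.

        Returns:
            Value of the attribute.

        Raises:
            KeyError: If no attribute with this name exists.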
        """
        if name in self.__dict__.keys():
            return self.__dict__[name]

        elif name in self.extra_attributes.keys():
            return self.extra_attributes[name]

        else:
            raise KeyError(f"The attribute '{name}' does not exist ")

    def save(self, name, save_only=None):
        """
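        Saves the TDResult to a .npz file.

        Args:
            name (str): The filename prefix for saving; ".npz" is appended.
            save_only (list, optional): Names of the attributes to save. May include keys of `extra_attributes`,
                which are then stored as top-level entries. If omitted, all fields are saved.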
        """
        data = asdict(self) 

        if save_only:
            data.update(self.extra_attributes) # flattened dict with keys from both the original fields and extra_attributes
            data={k:v for k,v in data.items() if k in save_only } # filtered data

        jnp.savez(f"{name}.npz", **data)

    @classmethod
    def load( cls, name ):
        """
        Constructs a TDResult object from saved data.

        Args:
            name (str): The filename (without extension) from which to load the data.

        Returns:
            TDResult: A TDResult object constructed from the saved data.

        Note:
            If the 'save_only' option was used earlier, the TDResult object will be created 
            with only the available data, and missing fields will be filled with empty values 
            of their corresponding types.
        """
        with jnp.load(f'{name}.npz',allow_pickle=True) as data:
            data=dict(**data)
            primary_attribute_list=['td_illumination','time_axis','final_density_matrix','output','extra_attributes']
            dynamic_attributes={k:v for k,v in data.items() if k not in primary_attribute_list}

            return cls(

                td_illumination = jnp.asarray(data.get('td_illumination',[])),

                time_axis = jnp.asarray(data.get('time_axis',[])),

                final_density_matrix = jnp.asarray(data.get('final_density_matrix',[[]])),

                output=[jnp.asarray(arr) for arr in data.get('output', [])],

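                # extras are stored as one pickled dict unless `save_only` flattened them into top-level keys
                extra_attributes=data['extra_attributes'].item() if 'extra_attributes' in data else dynamic_attributes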

            )

add_extra_attribute(name, value)

Dynamically adds an attribute to the 'extra_attributes' field.

Parameters:
  • name (str) –

    Name of the new attribute to be added.

  • value (Any) –

    Value of the attribute.

Source code in src/granad/orbitals.py
def add_extra_attribute(self,name: str,value: Any):
    """
    Dynamically adds an attribute to the 'extra_attributes' field.

    Args:
        name (str): Name of the new attribute to be added.
        value (Any): Value of the attribute.
    """
    self.extra_attributes[name]=value
    print(f"Extra attribute '{name}' is added.")

ft_illumination(omega_max, omega_min, return_omega_axis=True)

Calculates the Fourier transform of the time-dependent illumination function over a specified frequency range, with an option to return the frequency axis.

Parameters:
  • omega_max (float) –

    The maximum frequency limit for the Fourier transform.

  • omega_min (float) –

    The minimum frequency limit for the Fourier transform.

  • return_omega_axis (bool, default: True ) –

    If True, the function also returns the frequency axis along with the Fourier transformed illumination function. Defaults to True.

Returns:
  • jax.Array | tuple[jax.Array, jax.Array]: The Fourier-transformed illumination function. If return_omega_axis is True, a tuple containing the transformed data and the corresponding frequency axis is returned; otherwise, only the transformed data is returned.

Source code in src/granad/orbitals.py
def ft_illumination( self, omega_max, omega_min, return_omega_axis = True ):
    """
    Calculates the Fourier transform of the time-dependent illumination function over a specified frequency range,
    with an option to return the frequency axis.

    Args:
        omega_max (float): The maximum frequency limit for the Fourier transform.
        omega_min (float): The minimum frequency limit for the Fourier transform.
        return_omega_axis (bool): If True, the function also returns the frequency axis along with the Fourier
                                  transformed illumination function. Defaults to True.

    Returns:
        jax.Array | tuple[jax.Array, jax.Array]: The Fourier-transformed illumination function. If
            `return_omega_axis` is True, a tuple containing the transformed data and the corresponding
            frequency axis is returned; otherwise, only the transformed data is returned.

    """
    return _numerics.get_fourier_transform(self.time_axis, self.td_illumination, omega_max, omega_min, return_omega_axis)
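The discretization used by `_numerics.get_fourier_transform` is not shown on this page. As a rough illustration only, the sketch below assumes a Riemann-sum approximation of F(ω) = ∫ f(t) e^{iωt} dt evaluated on a uniform frequency grid between `omega_min` and `omega_max`. The sign convention, normalization, the `(data, omega)` return order, and the helper name `fourier_transform_sketch` are all assumptions, not GRANAD's API. `ft_output` applies the same kind of transform to every array in `output`, without the frequency axis.

```python
import cmath
import math


def fourier_transform_sketch(time_axis, signal, omega_max, omega_min,
                             return_omega_axis=True, n_omega=51):
    """Hypothetical windowed Fourier transform (illustration only).

    Approximates F(omega) = sum_k f(t_k) * exp(i * omega * t_k) * dt on a
    uniform frequency grid in [omega_min, omega_max].
    """
    dt = time_axis[1] - time_axis[0]  # assumes a uniform time grid
    step = (omega_max - omega_min) / (n_omega - 1)
    omegas = [omega_min + i * step for i in range(n_omega)]
    transformed = [
        sum(f * cmath.exp(1j * w * t) for t, f in zip(time_axis, signal)) * dt
        for w in omegas
    ]
    return (transformed, omegas) if return_omega_axis else transformed


# A Gaussian exp(-t^2 / 2) has the analytic transform sqrt(2*pi) * exp(-w^2 / 2).
n_t = 2001
time_axis = [-10 + 20 * k / (n_t - 1) for k in range(n_t)]
signal = [math.exp(-t**2 / 2) for t in time_axis]
ft, omegas = fourier_transform_sketch(time_axis, signal, omega_max=2.0, omega_min=0.0, n_omega=5)
for w, value in zip(omegas, ft):
    exact = math.sqrt(2 * math.pi) * math.exp(-w**2 / 2)
    print(f"omega={w:.1f}  numeric={value.real:.6f}  exact={exact:.6f}")
```

Comparing against a signal with a known analytic transform is a cheap way to check the sign and normalization conventions of whichever transform you end up using.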

ft_output(omega_max, omega_min)

Computes the Fourier transform of each element in the output data across a specified frequency range.

Parameters:
  • omega_max (float) –

    The maximum frequency bound for the Fourier transform.

  • omega_min (float) –

    The minimum frequency bound for the Fourier transform.

Returns:
  • list[jax.Array]: A list of Fourier transformed arrays corresponding to each element in the output attribute, evaluated over the specified frequency range.

Note

This method applies a Fourier transform to each array in the output list to analyze the frequency components between omega_min and omega_max.

Source code in src/granad/orbitals.py
def ft_output( self, omega_max, omega_min ):
    """
    Computes the Fourier transform of each element in the output data across a specified frequency range.

    Args:
        omega_max (float): The maximum frequency bound for the Fourier transform.
        omega_min (float): The minimum frequency bound for the Fourier transform.

    Returns:
        list[jax.Array]: A list of Fourier transformed arrays corresponding to each element in the `output` attribute,
                          evaluated over the specified frequency range.

    Note:
        This method applies a Fourier transform to each array in the `output` list to analyze the frequency components
        between `omega_min` and `omega_max`.
    """
    ft = lambda o : _numerics.get_fourier_transform(self.time_axis, o, omega_max, omega_min, False)
    return [ft(o) for o in self.output]

get_attribute(name)

Returns the value of the named attribute, whether it is a regular dataclass field or an entry in extra_attributes. Regular fields take precedence.

Parameters:
  • name (str) –

    Name of the attribute.

Returns:
  • Value of the attribute.

Raises:
  • KeyError –

    If no attribute with this name exists.

Source code in src/granad/orbitals.py
def get_attribute(self, name: str):
    """
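    Returns the value of the named attribute, whether it is a regular dataclass field or an entry in
    `extra_attributes`. Regular fields take precedence.

    Args:
        name (str): Name of the attribute.

    Returns:
        Value of the attribute.

    Raises:
        KeyError: If no attribute with this name exists.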
    """
    if name in self.__dict__.keys():
        return self.__dict__[name]

    elif name in self.extra_attributes.keys():
        return self.extra_attributes[name]

    else:
        raise KeyError(f"The attribute '{name}' does not exist ")
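The lookup order in `get_attribute` (instance fields first, then `extra_attributes`) has one practical consequence: an extra attribute whose name matches a regular field is shadowed. The stand-alone sketch below mirrors the add/get/remove methods shown above on a hypothetical `ResultSketch` dataclass with a single regular field; it is not the real `TDResult`.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ResultSketch:
    """Hypothetical stand-in for TDResult with one regular field."""
    time_axis: list = field(default_factory=list)
    extra_attributes: dict = field(default_factory=dict)

    def add_extra_attribute(self, name: str, value: Any):
        self.extra_attributes[name] = value

    def remove_extra_attribute(self, name: str):
        if name not in self.extra_attributes:
            raise KeyError(f"The attribute '{name}' does not exist in 'extra_attributes'.")
        del self.extra_attributes[name]

    def get_attribute(self, name: str):
        # Regular dataclass fields live in __dict__ and are checked first ...
        if name in self.__dict__:
            return self.__dict__[name]
        # ... user-defined extras are only consulted afterwards.
        if name in self.extra_attributes:
            return self.extra_attributes[name]
        raise KeyError(f"The attribute '{name}' does not exist")


result = ResultSketch(time_axis=[0.0, 0.1, 0.2])
result.add_extra_attribute("label", "graphene flake")
print(result.get_attribute("time_axis"))  # regular field
print(result.get_attribute("label"))      # extra attribute
result.remove_extra_attribute("label")
```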

load(name) classmethod

Constructs a TDResult object from saved data.

Parameters:
  • name (str) –

    The filename (without extension) from which to load the data.

Returns:
  • TDResult

    A TDResult object constructed from the saved data.

Note

If the 'save_only' option was used earlier, the TDResult object will be created with only the available data, and missing fields will be filled with empty values of their corresponding types.

Source code in src/granad/orbitals.py
@classmethod
def load( cls, name ):
    """
    Constructs a TDResult object from saved data.

    Args:
        name (str): The filename (without extension) from which to load the data.

    Returns:
        TDResult: A TDResult object constructed from the saved data.

    Note:
        If the 'save_only' option was used earlier, the TDResult object will be created 
        with only the available data, and missing fields will be filled with empty values 
        of their corresponding types.
    """
    with jnp.load(f'{name}.npz',allow_pickle=True) as data:
        data=dict(**data)
        primary_attribute_list=['td_illumination','time_axis','final_density_matrix','output','extra_attributes']
        dynamic_attributes={k:v for k,v in data.items() if k not in primary_attribute_list}

        return cls(

            td_illumination = jnp.asarray(data.get('td_illumination',[])),

            time_axis = jnp.asarray(data.get('time_axis',[])),

            final_density_matrix = jnp.asarray(data.get('final_density_matrix',[[]])),

            output=[jnp.asarray(arr) for arr in data.get('output', [])],

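            # extras are stored as one pickled dict unless `save_only` flattened them into top-level keys
            extra_attributes=data['extra_attributes'].item() if 'extra_attributes' in data else dynamic_attributes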

        )

remove_extra_attribute(name)

Dynamically deletes an attribute from 'extra_attributes'.

Parameters:
  • name (str) –

    Name of the attribute to be removed.

Source code in src/granad/orbitals.py
def remove_extra_attribute(self,name: str):
    """
    Dynamically deletes an attribute from 'extra_attributes'.

    Args:
        name (str): Name of the attribute to be removed.
    """
    if name not in self.extra_attributes:
        raise KeyError(f"The attribute '{name}' does not exist in 'extra_attributes'. ")
    else: 
        del self.extra_attributes[name]
        print(f"Extra attribute '{name}' is removed.")

save(name, save_only=None)

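Saves the TDResult to a .npz file.

Parameters:
  • name (str) –

    The filename prefix for saving; ".npz" is appended.

  • save_only (list, default: None ) –

    Names of the attributes to save. May include keys of extra_attributes, which are then stored as top-level entries. If omitted, all fields are saved.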

Source code in src/granad/orbitals.py
def save(self, name, save_only=None):
    """
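    Saves the TDResult to a .npz file.

    Args:
        name (str): The filename prefix for saving; ".npz" is appended.
        save_only (list, optional): Names of the attributes to save. May include keys of `extra_attributes`,
            which are then stored as top-level entries. If omitted, all fields are saved.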
    """
    data = asdict(self) 

    if save_only:
        data.update(self.extra_attributes) # flattened dict with keys from both the original fields and extra_attributes
        data={k:v for k,v in data.items() if k in save_only } # filtered data

    jnp.savez(f"{name}.npz", **data)
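`save` and `load` use two on-disk layouts: without `save_only`, the extras stay nested under the single key `extra_attributes`; with `save_only`, they become top-level keys that `load` collects again. The pure-Python sketch below reproduces that key handling with the `.npz` file replaced by a plain dictionary, so no array conversion or `.item()` unwrapping is needed. `ResultRecord`, `pack`, and `unpack` are hypothetical names for illustration.

```python
from dataclasses import asdict, dataclass, field

PRIMARY = ["td_illumination", "time_axis", "final_density_matrix", "output", "extra_attributes"]


@dataclass
class ResultRecord:
    """Hypothetical, simplified stand-in for TDResult."""
    td_illumination: list = field(default_factory=list)
    time_axis: list = field(default_factory=list)
    final_density_matrix: list = field(default_factory=list)
    output: list = field(default_factory=list)
    extra_attributes: dict = field(default_factory=dict)


def pack(result, save_only=None):
    """Mirror save(): without save_only, extras stay nested under 'extra_attributes'."""
    data = asdict(result)
    if save_only:
        # Extras become top-level keys so that save_only can select them by name.
        data.update(result.extra_attributes)
        data = {k: v for k, v in data.items() if k in save_only}
    return data


def unpack(data):
    """Mirror load(): use nested extras if present, else collect non-primary keys."""
    dynamic = {k: v for k, v in data.items() if k not in PRIMARY}
    extras = data["extra_attributes"] if "extra_attributes" in data else dynamic
    return ResultRecord(
        td_illumination=data.get("td_illumination", []),
        time_axis=data.get("time_axis", []),
        final_density_matrix=data.get("final_density_matrix", [[]]),
        output=data.get("output", []),
        extra_attributes=extras,
    )


result = ResultRecord(time_axis=[0.0, 0.1], extra_attributes={"label": "flake"})
full = pack(result)
partial = pack(result, save_only=["time_axis", "label"])
print(sorted(full))
print(sorted(partial))  # ['label', 'time_axis']
print(unpack(partial).extra_attributes)
```

Fields left out by `save_only` come back as empty containers, which matches the behavior described in the `load` note.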

show_extra_attribute_list()

Displays all available extra attributes.

Source code in src/granad/orbitals.py
def show_extra_attribute_list(self):
    """
    Displays all available extra attributes. 
    """
    print(list(self.extra_attributes.keys()))

Pulse(amplitudes, frequency, peak, fwhm)

Function for computing temporally localized time-harmonic electric fields. The pulse is implemented as a temporal Gaussian.

Parameters:
  • amplitudes (list[float]) –

    electric field amplitudes in xyz-components

  • frequency (float) –

    angular frequency of the electric field

  • peak (float) –

    time at which the pulse reaches its peak

  • fwhm (float) –

    full width at half maximum of the Gaussian field envelope

Returns:
  • Function of time that computes the electric field; it returns the real part unless called with real=False

Source code in src/granad/fields.py
def Pulse(
    amplitudes: list[float],
    frequency: float,
    peak: float,
    fwhm: float,
):
    """Function for computing temporally localized time-harmonic electric fields. The pulse is implemented as a temporal Gaussian.

    Args:
        amplitudes: electric field amplitudes in xyz-components
        frequency: angular frequency of the electric field
        peak: time at which the pulse reaches its peak
        fwhm: full width at half maximum of the Gaussian field envelope

    Returns:
       Function that computes the electric field
    """
    def _field(t, real = True):
        val = (
            static_part
            * jnp.exp(-1j * jnp.pi / 2 + 1j * frequency * (t - peak))
            * jnp.exp(-((t - peak) ** 2) / sigma**2)
        )
        return val.real if real else val
    static_part = jnp.array(amplitudes)
    sigma = fwhm / (2.0 * jnp.sqrt(jnp.log(2)))
    return _field
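The `sigma` in `Pulse` follows from requiring the envelope exp(-(t - peak)²/σ²) to drop to one half at t = peak ± fwhm/2, which gives σ = fwhm / (2√(ln 2)). The plain-Python sketch below re-implements the same formula with `cmath` (the name `pulse_sketch` is hypothetical, not part of GRANAD) and checks that property.

```python
import cmath
import math


def pulse_sketch(amplitudes, frequency, peak, fwhm):
    """Plain-Python mirror of the Pulse formula (illustration only)."""
    # Choose sigma so that exp(-(t - peak)**2 / sigma**2) equals 1/2 at t = peak +/- fwhm/2.
    sigma = fwhm / (2.0 * math.sqrt(math.log(2)))

    def _field(t, real=True):
        # Carrier with a -pi/2 phase: its real part is sin(frequency * (t - peak)).
        carrier = cmath.exp(-1j * math.pi / 2 + 1j * frequency * (t - peak))
        envelope = math.exp(-((t - peak) ** 2) / sigma**2)
        values = [a * carrier * envelope for a in amplitudes]
        return [v.real for v in values] if real else values

    return _field


e_field = pulse_sketch(amplitudes=[1e-5, 0.0, 0.0], frequency=2.3, peak=5.0, fwhm=2.0)

# The magnitude of the complex field removes the carrier and leaves the envelope.
ratio = abs(e_field(6.0, real=False)[0]) / abs(e_field(5.0, real=False)[0])
print(ratio)         # envelope ratio at peak + fwhm / 2, close to 0.5
print(e_field(5.0))  # real field is numerically ~0 at the peak (sine carrier)
```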

Ramp(amplitudes, frequency, ramp_duration, time_ramp)

Function for computing ramping up time-harmonic electric fields.

Parameters:
  • amplitudes (list[float]) –

    electric field amplitudes in xyz-components

  • frequency (float) –

    angular frequency

  • ramp_duration (float) –

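    time over which the field envelope rises from 1% to 99% of its final amplitude

  • time_ramp (float) –

    time at the midpoint of the ramp, where the envelope reaches half of its final amplitude

Returns:
  • Function that computes the electric field as a function of time; it returns the real part unless called with real=False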

Source code in src/granad/fields.py
def Ramp(
    amplitudes: list[float],
    frequency: float,
    ramp_duration: float,
    time_ramp: float,
):
    """Function for computing ramping up time-harmonic electric fields.

    Args:
        amplitudes: electric field amplitudes in xyz-components
        frequency: angular frequency
        ramp_duration: time over which the field envelope rises from 1% to 99% of its final amplitude
        time_ramp: time at the midpoint of the ramp, where the envelope reaches half of its final amplitude

    Returns:
       Function that computes the electric field as a function of time
    """
    def _field(t, real = True):        
        val =  (
            static_part
            * jnp.exp(1j * frequency * t)
            / (1 + 1.0 * jnp.exp(-ramp_constant * (t - time_ramp)))
        )
        return val.real if real else val
    static_part = jnp.array(amplitudes)
    p = 0.99
    ramp_constant = 2 * jnp.log(p / (1 - p)) / ramp_duration
    return _field
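The ramp constant 2 ln(p/(1 - p)) / ramp_duration with p = 0.99 makes the logistic envelope equal 1 - p and p exactly ramp_duration/2 before and after `time_ramp`. The sketch below checks this with a plain-Python copy of the envelope; `ramp_envelope` is a hypothetical helper, not a GRANAD function.

```python
import math


def ramp_envelope(t, ramp_duration, time_ramp, p=0.99):
    """Logistic envelope of Ramp, re-implemented in plain Python for illustration."""
    # Same constant as in the source: the envelope crosses 1 - p and p
    # exactly ramp_duration / 2 before and after time_ramp.
    ramp_constant = 2 * math.log(p / (1 - p)) / ramp_duration
    return 1.0 / (1.0 + math.exp(-ramp_constant * (t - time_ramp)))


ramp_duration, time_ramp = 4.0, 10.0
for t in (8.0, 10.0, 12.0):
    print(t, round(ramp_envelope(t, ramp_duration, time_ramp), 4))
```

For these values the envelope is 0.01, 0.5, and 0.99 at t = 8, 10, and 12, so the field is essentially off before t ≈ 8 and fully on after t ≈ 12.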

Wave(amplitudes, frequency)

Function for computing time-harmonic electric fields.

Parameters:
  • amplitudes (list[float]) –

    electric field amplitudes in xyz-components

  • frequency (float) –

    angular frequency

Returns:
  • Function that computes the electric field as a function of time; it returns the real part unless called with real=False

Source code in src/granad/fields.py
def Wave(
    amplitudes: list[float],
    frequency: float,
):
    """Function for computing time-harmonic electric fields.

    Args:
        amplitudes: electric field amplitudes in xyz-components
        frequency: angular frequency

    Returns:
       Function that computes the electric field as a function of time
    """
    def _field(t, real = True):
        val = (jnp.exp(1j * frequency * t) * static_part)
        return val.real if real else val
    static_part = jnp.array(amplitudes)
    return _field
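Taking the real part of exp(iωt)·A means the physical field of `Wave` is simply E(t) = A cos(ωt), component by component. A plain-Python mirror of the function above (the name `wave_sketch` is hypothetical) makes this easy to check at t = 0 and half a period later.

```python
import cmath
import math


def wave_sketch(amplitudes, frequency):
    """Plain-Python mirror of the Wave formula (illustration only)."""
    def _field(t, real=True):
        values = [cmath.exp(1j * frequency * t) * a for a in amplitudes]
        return [v.real for v in values] if real else values
    return _field


e_field = wave_sketch(amplitudes=[1.0, 0.5, 0.0], frequency=2.0)
print(e_field(0.0))            # equals the amplitudes at t = 0
print(e_field(math.pi / 2.0))  # half a period later every component has flipped sign
```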

Material

Represents a material in a simulation, encapsulating its physical properties and interactions.

Attributes:
  • name (str) –

    The name of the material.

  • species (dict) –

    Dictionary mapping species names to a tuple (s, atom), i.e. the spin quantum number and the name of the atom the orbital belongs to.

  • orbitals (defaultdict[list]) –

    A mapping from species to lists of orbitals. Each orbital is represented as a dictionary containing the orbital's position and an optional tag for further identification.

  • interactions (defaultdict[dict]) –

    Describes the interactions between orbitals within the material. Each interaction is categorized by type (e.g., 'hamiltonian', 'coulomb') and stores, per pair of participating species, the parameters [onsite, nearest_neighbor, next_to_nearest_neighbor, ...] together with an optional mathematical expression that defines the coupling for all distances not covered by the parameters, i.e. from the len(parameters)-th nearest neighbor on (counting the onsite term as the zeroth).

Note

The Material class is used to define a material's structure and properties step-by-step. An example is constructing the material graphene, with specific lattice properties, orbitals corresponding to carbon's p_z orbitals, and defining hamiltonian and Coulomb interactions among these orbitals.

import jax.numpy as jnp
from granad.materials import Material

graphene = (
    Material("graphene")
    .lattice_constant(2.46)
    .lattice_basis([
        [1, 0, 0],
        [-0.5, jnp.sqrt(3)/2, 0]
    ])
    .add_orbital_species("pz", atom='C')
    .add_orbital(position=(0, 0), tag="sublattice_1", species="pz")
    .add_orbital(position=(-1/3, -2/3), tag="sublattice_2", species="pz")
    .add_interaction(
        "hamiltonian",
        participants=("pz", "pz"),
        parameters=[0.0, -2.66],
    )
    .add_interaction(
        "coulomb",
        participants=("pz", "pz"),
        parameters=[16.522, 8.64, 5.333],
        expression=lambda r : 1/r + 0j
    )
)
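Orbital positions passed to `add_orbital` are fractional coordinates in the lattice basis: `_get_positions_in_lattice` multiplies them by the basis vectors and then by the lattice constant. The pure-Python sketch below repeats that arithmetic for the two graphene sites above (the helper `to_cartesian` is hypothetical) and recovers the carbon-carbon distance a/√3 ≈ 1.42 Å, the nearest-neighbor shell that the hopping value -2.66 is assigned to.

```python
import math

lattice_constant = 2.46  # Å, as in the graphene example
basis = [[1.0, 0.0, 0.0], [-0.5, math.sqrt(3) / 2, 0.0]]


def to_cartesian(fractional, basis, lattice_constant):
    """lattice_constant * (fractional @ basis), written out without NumPy/JAX."""
    return [
        lattice_constant * sum(c * vec[axis] for c, vec in zip(fractional, basis))
        for axis in range(3)
    ]


site_1 = to_cartesian((0, 0), basis, lattice_constant)
site_2 = to_cartesian((-1 / 3, -2 / 3), basis, lattice_constant)
bond = math.dist(site_1, site_2)
print(site_2, bond)
```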
Source code in src/granad/materials.py
class Material:
    """
    Represents a material in a simulation, encapsulating its physical properties and interactions.

    Attributes:
        name (str): The name of the material.
        species (dict): Dictionary mapping species names to a tuple (s, atom), i.e. the spin quantum number
                        and the name of the atom the orbital belongs to.
        orbitals (defaultdict[list]): A mapping from species to lists of orbitals. Each orbital is represented
                                      as a dictionary containing the orbital's position and an optional tag
                                      for further identification.
        interactions (defaultdict[dict]): Describes the interactions between orbitals within the material.
                                         Each interaction is categorized by type (e.g., 'hamiltonian', 'coulomb')
                                         and stores, per pair of participating species, the parameters
                                         [onsite, nearest_neighbor, next_to_nearest_neighbor, ...] together with
                                         an optional mathematical expression defining the coupling for all
                                         distances not covered by the parameters, i.e. from the
                                         len(parameters)-th nearest neighbor on (counting onsite as the zeroth).

    Note:
        The `Material` class is used to define a material's structure and properties step-by-step.
        An example is constructing the material graphene, with specific lattice properties,
        orbitals corresponding to carbon's p_z orbitals, and defining hamiltonian and Coulomb interactions
        among these orbitals. 

        ```python
        import jax.numpy as jnp
        from granad.materials import Material

        graphene = (
            Material("graphene")
            .lattice_constant(2.46)
            .lattice_basis([
                [1, 0, 0],
                [-0.5, jnp.sqrt(3)/2, 0]
            ])
            .add_orbital_species("pz", atom='C')
            .add_orbital(position=(0, 0), tag="sublattice_1", species="pz")
            .add_orbital(position=(-1/3, -2/3), tag="sublattice_2", species="pz")
            .add_interaction(
                "hamiltonian",
                participants=("pz", "pz"),
                parameters=[0.0, -2.66],
            )
            .add_interaction(
                "coulomb",
                participants=("pz", "pz"),
                parameters=[16.522, 8.64, 5.333],
                expression=lambda r : 1/r + 0j
            )
        )
        ```
    """
    def __init__(self, name):
        self.name = name
        self.species = {}
        self.orbitals = defaultdict(list)
        self.interactions = defaultdict(dict)
        self._species_to_groups=  {}
        self.dim = None

    def __str__(self):
        description = f"Material: {self.name}\n"
        if not callable(self.lattice_constant):  # still the setter method until a value is assigned
            description += f"  Lattice Constant: {self.lattice_constant} Å\n"
        if hasattr(self, "_lattice_basis"):  # lattice_basis() stores the vectors in _lattice_basis
            description += f"  Lattice Basis: \n{self._lattice_basis}\n"

        if self.species:
            description += "  Orbital Species:\n"
            for species_name, attributes in self.species.items():
                description += f"    {species_name} characterized by (s, atom name) = {attributes}\n"

        if self.orbitals:
            description += "  Orbitals:\n"
            for spec, orbs in self.orbitals.items():
                for orb in orbs:
                    description += f"    Position: {orb['position']}, Tag: {orb['tag']}, Species: {spec}\n"

        if self.interactions:
            description += "  Interactions:\n"
            for type_, interaction in self.interactions.items():
                for participants, coupling in interaction.items():
                    description += f"""   Type: {type_}, Participants: {participants}:
                    NN Couplings: {', '.join(map(str, coupling[0]))}
                    """
                    # Check if there's a docstring on the function
                    if coupling[1].__doc__ is not None:
                        function_description = coupling[1].__doc__
                    else:
                        function_description = "No description available for this function."

                    description += f"Other neighbors: {function_description}\n"

        return description

    def lattice_constant(self, value):
        """
        Sets the lattice constant for the material.

        Parameters:
            value (float): The lattice constant value.

        Returns:
            Material: Returns self to enable method chaining.
        """
        self.lattice_constant = value
        return self

    def lattice_basis(self, values, periodic = None):
        """
        Defines the lattice basis vectors and specifies which dimensions are periodic.

        Parameters:
            values (list of list of float): A list of vectors representing the lattice basis.
            periodic (list of int, optional): Indices of the basis vectors that are periodic. Defaults to all vectors being periodic.

        Returns:
            Material: Returns self to enable method chaining.
        """
        self._lattice_basis = jnp.array(values)
        total = set(range(len(self._lattice_basis)))        
        periodic = set(periodic) if periodic is not None else total
        self.periodic = list(periodic)
        self.finite = list(total - periodic)
        self.dim = len(self.periodic)                                              
        return self

    @_finalize
    def cut_flake( self ):
        """
        Finalizes the material construction by defining a method to cut a flake of the material,
        chosen according to the material's dimension:

        1D material: materials.cut_flake_1d
        2D material: materials.cut_flake_2d
        3D material and higher: materials.cut_flake_generic

        This method is intended to be called after all material properties (lattice constant,
        basis, orbitals, and interactions) have been fully defined.

        Note:
            This method does not take any parameters and does not return any value. Its effect is
            internal to the state of the Material object and is meant to prepare the material for
            simulation by applying the necessary final structural adjustments.
        """
        pass

    def add_orbital(self, position, species, tag = ''):
        """
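        Adds an orbital of a given species to the unit cell.

        Parameters:
            position (tuple of float): Position of the orbital in fractional coordinates of the lattice basis.
            species (str): Name of the orbital species, as registered with `add_orbital_species`.
            tag (str, optional): Tag for further identification of the orbital. Defaults to ''.

        Returns:
            Material: Returns self to enable method chaining.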
        """
        self.orbitals[species].append({'position': position, 'tag': tag})
        return self

    def add_orbital_species( self, name, s = 0, atom  = ''):
        """
        Adds a species definition for orbitals in the material.

        Parameters:
            name (str): The name of the orbital species.
            s (int): Spin quantum number.
            atom (str, optional): Name of the atom the orbital belongs to.

        Returns:
            Material: Returns self to enable method chaining.
        """
        self.species[name] = (s,atom)
        return self

    def add_interaction(self, interaction_type, participants, parameters = None, expression = zero_coupling):
        """
        Adds an interaction between orbitals specified by an interaction type and participants.

        Parameters:
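            interaction_type (str): The type of interaction (e.g., 'hamiltonian', 'coulomb').
            participants (tuple): A pair of orbital species names between which the interaction acts.
            parameters (list, optional): Coupling values, either ordered by distance as [onsite, nearest_neighbor,
                next_to_nearest_neighbor, ...] or given as [x, y, z, value] entries for specific displacement vectors.
                Defaults to an empty list.
            expression (function): A function defining the mathematical form of the interaction for all
                distances not covered by `parameters`. Defaults to `zero_coupling`.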

        Returns:
            Material: Returns self to enable method chaining.
        """            
        self.interactions[interaction_type][participants] =  (parameters if parameters is not None else [], expression)
        return self

    def _get_positions_in_uc( self, species = None ):
        if species is None:
            return jnp.array( [x["position"] for orb in list(self.orbitals.values()) for x in orb] )
        return jnp.array( [orb_group['position'] for s in species for orb_group in self.orbitals[s] ] )

    def _get_positions_in_lattice(self, uc_positions, grid):
        shift = jnp.array(uc_positions) @ self._lattice_basis
        return self.lattice_constant * (
            grid @ self._lattice_basis + shift[:, None, :]
        ).reshape(shift.shape[0] * grid.shape[0], 3)

    def _get_grid(self, ns ):
        grid = [(1,) for i in range( len(self.finite) + len(self.periodic)) ]
        for i, p in enumerate(self.periodic):
            grid[p] = range(*ns[i])
        return jnp.array( list( product( *(x for x in grid) ) ) )

    def _keep_matching_positions(self, positions, candidates):
        idxs = (
            jnp.round(jnp.linalg.norm(positions - candidates[:, None], axis=-1), 4) == 0
        ).nonzero()[0]
        return candidates[idxs]

    def _couplings_to_function(
        self, couplings, outside_fun, species
    ):

        # no couplings
        if len(couplings) == 0:
            return outside_fun

        # vector couplings
        if all(isinstance(i, list) for i in couplings):
            return self._vector_couplings_to_function(couplings, outside_fun, species)

        # distance couplings
        return self._distance_couplings_to_function(couplings, outside_fun, species)


    def _vector_couplings_to_function(self, couplings, outside_fun, species):

        vecs, couplings_vals = jnp.array(couplings).astype(float)[:, :3], jnp.array(couplings).astype(complex)[:, 3]
        distances = jnp.linalg.norm(vecs, axis=1)

        def inner(d):
            return jax.lax.cond(
                jnp.min(jnp.abs(jnp.linalg.norm(d) - distances)) < 1e-5,
                lambda x: couplings_vals[jnp.argmin(jnp.linalg.norm(d - vecs, axis=1))],
                outside_fun,
                d,
            )
        return inner

    def _distance_couplings_to_function(self, couplings, outside_fun, species):

        couplings = jnp.array(couplings).astype(complex)
        grid = self._get_grid( [ (0, len(couplings)) for i in range(self.dim) ] )
        pos_uc_1 = self._get_positions_in_uc( (species[0],) )
        pos_uc_2 = self._get_positions_in_uc( (species[1],) )
        positions_1 = self._get_positions_in_lattice(pos_uc_1, grid )
        positions_2 = self._get_positions_in_lattice(pos_uc_2, grid )

        distances = jnp.unique(
            jnp.round(jnp.linalg.norm(positions_1 - positions_2[:, None, :], axis=2), 5)
        )[: len(couplings)]

        def inner(d):
            d = jnp.linalg.norm(d)
            return jax.lax.cond(
                jnp.min(jnp.abs(d - distances)) < 1e-5,
                lambda x: couplings[jnp.argmin(jnp.abs(x - distances))],
                outside_fun,
                d,
            )

        return inner

    def _set_couplings(self, setter_func, interaction_type):
        interaction_dict = self.interactions[interaction_type]
        for (species_1, species_2), couplings in interaction_dict.items():
            distance_func = self._couplings_to_function(
                *couplings, (species_1, species_2)
            )
            setter_func(self._species_to_groups[species_1], self._species_to_groups[species_2], distance_func)


    def _get_orbital_list(self, allowed_positions, grid):

        raw_list, layer_index = [], 0
        for species, orb_group in self.orbitals.items():

            for orb_uc in orb_group:

                uc_positions = jnp.array( [orb_uc['position']] )

                rs_positions = self._get_positions_in_lattice( uc_positions, grid )

                final_positions = self._keep_matching_positions( allowed_positions, rs_positions )

                for position in final_positions:
                    orb = Orbital(
                        position = position,
                        layer_index = layer_index,
                        tag=orb_uc['tag'],
                        group_id = self._species_to_groups[species],                        
                        spin=self.species[species][0],
                        atom_name=self.species[species][1]
                    )
                    layer_index += 1
                    raw_list.append(orb)

        orbital_list = OrbitalList(raw_list)
        self._set_couplings(orbital_list.set_hamiltonian_groups, "hamiltonian")
        self._set_couplings(orbital_list.set_coulomb_groups, "coulomb")
        return orbital_list

add_interaction(interaction_type, participants, parameters=None, expression=zero_coupling)

Adds an interaction between orbitals specified by an interaction type and participants.

Parameters:
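  • interaction_type (str) –

    The type of interaction (e.g., 'hamiltonian', 'coulomb').

  • participants (tuple) –

    A pair of orbital species names between which the interaction acts.

  • parameters (list, default: None ) –

    Coupling values, either ordered by distance as [onsite, nearest_neighbor, next_to_nearest_neighbor, ...] or given as [x, y, z, value] entries for specific displacement vectors. None is treated as an empty list.

  • expression (function, default: zero_coupling ) –

    A function defining the mathematical form of the interaction for all distances not covered by parameters.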

Returns:
  • Material

    Returns self to enable method chaining.

Source code in src/granad/materials.py
def add_interaction(self, interaction_type, participants, parameters = None, expression = zero_coupling):
    """
    Adds an interaction between orbitals specified by an interaction type and participants.

    Parameters:
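        interaction_type (str): The type of interaction (e.g., 'hamiltonian', 'coulomb').
        participants (tuple): A pair of orbital species names between which the interaction acts.
        parameters (list, optional): Coupling values, either ordered by distance as [onsite, nearest_neighbor,
            next_to_nearest_neighbor, ...] or given as [x, y, z, value] entries for specific displacement vectors.
            Defaults to an empty list.
        expression (function): A function defining the mathematical form of the interaction for all
            distances not covered by `parameters`. Defaults to `zero_coupling`.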

    Returns:
        Material: Returns self to enable method chaining.
    """            
    self.interactions[interaction_type][participants] =  (parameters if parameters is not None else [], expression)
    return self
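For plain number lists, `_distance_couplings_to_function` pairs the entries of `parameters` with the sorted distinct inter-orbital distances (onsite first, then nearest neighbor, and so on) and falls back to `expression` for any other distance, using a 1e-5 tolerance. The simplified sketch below shows that lookup with precomputed graphene distances instead of GRANAD's lattice enumeration; `make_distance_coupling` is hypothetical, and `zero_coupling_sketch` is an assumed stand-in for the default `zero_coupling`, whose body is not shown here.

```python
import math


def zero_coupling_sketch(d):
    """Assumed stand-in for the default expression: no coupling."""
    return 0j


def make_distance_coupling(parameters, distances, expression, tol=1e-5):
    """Map d to parameters[i] if d matches the i-th smallest distance, else to expression(d)."""
    shells = sorted(distances)[: len(parameters)]

    def coupling(d):
        for shell, value in zip(shells, parameters):
            if abs(d - shell) < tol:
                return complex(value)
        return expression(d)

    return coupling


a = 2.46
# onsite, 1st, 2nd and 3rd neighbor distances in graphene
graphene_distances = [0.0, a / math.sqrt(3), a, 2 * a / math.sqrt(3)]
hopping = make_distance_coupling([0.0, -2.66], graphene_distances, zero_coupling_sketch)
coulomb = make_distance_coupling([16.522, 8.64, 5.333], graphene_distances, lambda r: 1 / r + 0j)
print(hopping(a / math.sqrt(3)), hopping(a))  # nearest neighbor, then fallback
print(coulomb(2 * a / math.sqrt(3)))          # beyond the listed shells: 1/r
```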

add_orbital(position, species, tag='')

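Adds an orbital of a given species to the unit cell.

Parameters:
  • position (tuple of float) –

    Position of the orbital in fractional coordinates of the lattice basis.

  • species (str) –

    Name of the orbital species, as registered with add_orbital_species.

  • tag (str, default: '' ) –

    Tag for further identification of the orbital.

Returns:
  • Material

    Returns self to enable method chaining.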

Source code in src/granad/materials.py
def add_orbital(self, position, species, tag = ''):
    """
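    Adds an orbital of a given species to the unit cell.

    Parameters:
        position (tuple of float): Position of the orbital in fractional coordinates of the lattice basis.
        species (str): Name of the orbital species, as registered with `add_orbital_species`.
        tag (str, optional): Tag for further identification of the orbital. Defaults to ''.

    Returns:
        Material: Returns self to enable method chaining.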
    """
    self.orbitals[species].append({'position': position, 'tag': tag})
    return self

add_orbital_species(name, s=0, atom='')

Adds a species definition for orbitals in the material.

Parameters:
  • name (str) –

    The name of the orbital species.

  • s (int, default: 0 ) –

    Spin quantum number.

  • atom (str, default: '' ) –

    Name of the atom the orbital belongs to.

Returns:
  • Material

    Returns self to enable method chaining.

Source code in src/granad/materials.py
def add_orbital_species( self, name, s = 0, atom  = ''):
    """
    Adds a species definition for orbitals in the material.

    Parameters:
        name (str): The name of the orbital species.
        s (int): Spin quantum number.
        atom (str, optional): Name of the atom the orbital belongs to.

    Returns:
        Material: Returns self to enable method chaining.
    """
    self.species[name] = (s,atom)
    return self
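Species and orbitals live in two separate containers: `species` maps a name to `(s, atom)`, and `orbitals` maps a species name to a list of `{'position', 'tag'}` dicts. The sketch below replays the graphene unit cell with a hypothetical `MiniMaterial`. It assumes `orbitals` is a `defaultdict(list)`, because its initialization is not shown in this section.

```python
from collections import defaultdict


class MiniMaterial:
    """Hypothetical stand-in for Material that reproduces the species/orbital bookkeeping."""

    def __init__(self):
        self.species = {}                  # name -> (spin quantum number, atom name)
        self.orbitals = defaultdict(list)  # assumed: species name -> list of orbital dicts

    def add_orbital_species(self, name, s=0, atom=''):
        self.species[name] = (s, atom)
        return self

    def add_orbital(self, position, species, tag=''):
        self.orbitals[species].append({'position': position, 'tag': tag})
        return self


# Mirrors get_graphene: one "pz" species on carbon, one orbital per sublattice
m = (MiniMaterial()
     .add_orbital_species("pz", atom='C')
     .add_orbital(position=(0, 0), tag="sublattice_1", species="pz")
     .add_orbital(position=(-1/3, -2/3), tag="sublattice_2", species="pz"))

print(m.species)                             # {'pz': (0, 'C')}
print([o['tag'] for o in m.orbitals["pz"]])  # ['sublattice_1', 'sublattice_2']
```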

cut_flake()

Finalizes the material construction and cuts a flake from it, using the routine that matches the material's dimension:

  • 1D material: materials.cut_flake_1d
  • 2D material: materials.cut_flake_2d
  • 3D material and higher: materials.cut_flake_generic

Call this method only after all material properties (lattice constant, lattice basis, orbitals, and interactions) have been fully defined.

Note: The method body itself is empty; its behavior is supplied by the @_finalize decorator, which prepares the Material object for simulation and dispatches to the routine listed above. Arguments are therefore those of that routine (e.g. unit_cells for a 1D material, a polygon for a 2D material).

Source code in src/granad/materials.py
@_finalize
def cut_flake( self ):
    """
    Finalizes the material construction and cuts a flake from it, using the routine
    that matches the material's dimension:

    1D material : materials.cut_flake_1d
    2D material : materials.cut_flake_2d
    3D material and higher : materials.cut_flake_generic

    Call this method only after all material properties (lattice constant,
    lattice basis, orbitals, and interactions) have been fully defined.

    Note:
    The method body itself is empty; its behavior is supplied by the @_finalize
    decorator, which prepares the Material object for simulation and dispatches to
    the routine listed above. Arguments are therefore those of that routine
    (e.g. unit_cells for a 1D material, a polygon for a 2D material).
    """
    pass
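The dimension-based dispatch described above can be sketched as a decorator wrapped around an empty method. Everything here is hypothetical: `finalize`, `cut_1d`, `cut_2d` and `cut_generic` are stand-ins (the real `_finalize` is not shown in this section). The sketch assumes the decorator forwards the call's arguments to the selected routine and reads `dim`, which `lattice_basis` sets to the number of periodic directions.

```python
import functools


# Hypothetical stand-ins for cut_flake_1d, cut_flake_2d and cut_flake_generic
def cut_1d(material, unit_cells, plot=False):
    return f"1d cut with {unit_cells} unit cells"


def cut_2d(material, polygon, plot=False, minimum_neighbor_number=2):
    return f"2d cut inside {polygon}"


def cut_generic(material, grid_range):
    return f"generic cut over {grid_range}"


def finalize(method):
    """Hypothetical decorator: ignores the empty method body and dispatches on material.dim."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if self.dim == 1:
            return cut_1d(self, *args, **kwargs)
        if self.dim == 2:
            return cut_2d(self, *args, **kwargs)
        return cut_generic(self, *args, **kwargs)
    return wrapper


class MiniMaterial:
    def __init__(self, dim):
        self.dim = dim  # number of periodic directions

    @finalize
    def cut_flake(self):
        pass


print(MiniMaterial(1).cut_flake(unit_cells=10))  # 1d cut with 10 unit cells
print(MiniMaterial(2).cut_flake("triangle"))     # 2d cut inside triangle
```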

lattice_basis(values, periodic=None)

Defines the lattice basis vectors and specifies which dimensions are periodic.

Parameters:
  • values (list of list of float) –

    A list of vectors representing the lattice basis.

  • periodic (list of int, default: None ) –

    Indices of the basis vectors that are periodic. Defaults to all vectors being periodic.

Returns:
  • Material

    Returns self to enable method chaining.

Source code in src/granad/materials.py
def lattice_basis(self, values, periodic = None):
    """
    Defines the lattice basis vectors and specifies which dimensions are periodic.

    Parameters:
        values (list of list of float): A list of vectors representing the lattice basis.
        periodic (list of int, optional): Indices of the basis vectors that are periodic. Defaults to all vectors being periodic.

    Returns:
        Material: Returns self to enable method chaining.
    """
    self._lattice_basis = jnp.array(values)
    total = set(range(len(self._lattice_basis)))        
    periodic = set(periodic) if periodic is not None else total
    self.periodic = list(periodic)
    self.finite = list(total - periodic)
    self.dim = len(self.periodic)                                              
    return self
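The periodic/finite split is plain set arithmetic on the basis-vector indices, and `dim` counts only the periodic directions. The sketch below replays that bookkeeping in plain Python with a hypothetical `split_periodic` helper. It sorts the index lists, whereas the source converts the sets with `list(...)`. It shows that a three-vector basis with one finite direction, as in `get_mos2`, yields `dim == 2`.

```python
def split_periodic(values, periodic=None):
    """Hypothetical helper mirroring the bookkeeping in lattice_basis (without JAX)."""
    total = set(range(len(values)))
    periodic = set(periodic) if periodic is not None else total
    # (periodic indices, finite indices, dim)
    return sorted(periodic), sorted(total - periodic), len(periodic)


# Graphene-like basis: both vectors periodic by default
print(split_periodic([[1, 0, 0], [-0.5, 3 ** 0.5 / 2, 0]]))  # ([0, 1], [], 2)

# MoS2-like basis: the out-of-plane vector is marked finite
print(split_periodic([[1, 0, 0], [-0.5, 3 ** 0.5 / 2, 0], [0, 0, 1]], periodic=[0, 1]))  # ([0, 1], [2], 2)
```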

lattice_constant(value)

Sets the lattice constant for the material.

Parameters:
  • value (float) –

    The lattice constant value.

Returns:
  • Material

    Returns self to enable method chaining.

Source code in src/granad/materials.py
def lattice_constant(self, value):
    """
    Sets the lattice constant for the material.

    Parameters:
        value (float): The lattice constant value.

    Returns:
        Material: Returns self to enable method chaining.
    """
    self.lattice_constant = value
    return self
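Because the setter assigns to an attribute with its own name, the call creates an instance attribute that shadows the method. Afterwards `material.lattice_constant` is the float itself, which is how `cut_flake_2d` reads it. However, the setter cannot be called a second time on the same instance. The sketch below demonstrates this with a hypothetical `MiniMaterial` that copies the setter from the source.

```python
class MiniMaterial:
    """Hypothetical stand-in that copies the lattice_constant setter from the source."""

    def lattice_constant(self, value):
        # Instance attributes take precedence over plain methods during lookup,
        # so this assignment hides the method on this instance from now on.
        self.lattice_constant = value
        return self


m = MiniMaterial().lattice_constant(2.46)
print(m.lattice_constant)  # 2.46 (now a float, not a method)

try:
    m.lattice_constant(3.0)  # the setter is no longer reachable on this instance
except TypeError as err:
    print("TypeError:", err)
```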

MaterialCatalog

A class to manage and access built-in material properties within a simulation or modeling framework.

This class provides a central repository for predefined materials, allowing for easy retrieval and description of their properties.

Attributes:
  • _materials (dict) –

    A private dictionary that maps material names to the functions that construct them. It is pre-populated with the built-in materials graphene, ssh, chain, and hBN.

Methods:

  • get(material, **kwargs) –

    Retrieves the data object associated with the given material name.

  • describe(material) –

    Prints a description or the data object of the specified material.

  • available() –

    Prints a list of all available materials stored in the catalog.

Source code in src/granad/materials.py
class MaterialCatalog:
    """
    A class to manage and access built-in material properties within a simulation or modeling framework.

    This class provides a central repository for predefined materials, allowing for easy retrieval
    and description of their properties.

    Attributes:
        _materials (dict): A private dictionary that maps material names to the functions that construct them.
                           It is pre-populated with the built-in materials graphene, ssh, chain, and hBN.

    Methods:
        get(material): Retrieves the data object associated with the given material name.
        describe(material): Prints a description or the data object of the specified material.
        available(): Prints a list of all available materials stored in the catalog.
    """
    _materials = {"graphene" : get_graphene, "ssh" : get_ssh, "chain" : get_chain, "hBN" : get_hbn }

    @staticmethod
    def get(material : str, **kwargs):
        """
        Retrieves the material data object for the specified material. Additional keyword arguments are given to the corresponding material function.

        Args:
            material (str): The name of the material to retrieve.

        Returns:
            The data object associated with the specified material.

        Example:
            ```python
            graphene_data = MaterialCatalog.get('graphene')
            ```
        """
        return MaterialCatalog._materials[material](**kwargs)

    @staticmethod
    def describe(material : str):
        """
        Prints a description or the raw data of the specified material from the catalog.

        Args:
            material (str): The name of the material to describe.

        Example:
            ```python
            MaterialCatalog.describe('graphene')
            ```
        """
        print(MaterialCatalog._materials[material]())

    @staticmethod
    def available():
        """
        Prints a list of all materials available in the catalog.

        Example:
            ```python
            MaterialCatalog.available()
            ```
        """
        available_materials = "\n".join(MaterialCatalog._materials.keys())
        print(f"Available materials:\n{available_materials}")
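The catalog is a name-to-factory registry. `get` forwards keyword arguments to the factory, while `describe` calls it with no arguments and therefore always shows the defaults. An unknown name raises `KeyError`; note that `get_mos2` is not registered in `_materials`. The sketch below mirrors this pattern with a hypothetical `MiniCatalog` and a stand-in factory `make_chain` that returns a dict instead of a `Material`.

```python
def make_chain(hopping=-2.66, lattice_const=1.42):
    """Hypothetical stand-in for get_chain that returns a dict instead of a Material."""
    return {"hopping": hopping, "lattice_const": lattice_const}


class MiniCatalog:
    """Hypothetical registry mirroring MaterialCatalog: material names map to factories."""

    _materials = {"chain": make_chain}

    @staticmethod
    def get(material: str, **kwargs):
        # Keyword arguments are forwarded unchanged to the factory function
        return MiniCatalog._materials[material](**kwargs)

    @staticmethod
    def describe(material: str):
        # Like the original, the factory is called without arguments (defaults only)
        print(MiniCatalog._materials[material]())


print(MiniCatalog.get("chain"))                # {'hopping': -2.66, 'lattice_const': 1.42}
print(MiniCatalog.get("chain", hopping=-2.7))  # {'hopping': -2.7, 'lattice_const': 1.42}
```

With the real catalog, the equivalent call would be `MaterialCatalog.get("chain", hopping=-2.7)`, which passes `hopping` on to `get_chain`.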

available() staticmethod

Prints a list of all materials available in the catalog.

Example
MaterialCatalog.available()
Source code in src/granad/materials.py
@staticmethod
def available():
    """
    Prints a list of all materials available in the catalog.

    Example:
        ```python
        MaterialCatalog.available()
        ```
    """
    available_materials = "\n".join(MaterialCatalog._materials.keys())
    print(f"Available materials:\n{available_materials}")

describe(material) staticmethod

Prints a description or the raw data of the specified material from the catalog.

Parameters:
  • material (str) –

    The name of the material to describe.

Example
MaterialCatalog.describe('graphene')
Source code in src/granad/materials.py
@staticmethod
def describe(material : str):
    """
    Prints a description or the raw data of the specified material from the catalog.

    Args:
        material (str): The name of the material to describe.

    Example:
        ```python
        MaterialCatalog.describe('graphene')
        ```
    """
    print(MaterialCatalog._materials[material]())

get(material, **kwargs) staticmethod

Retrieves the material data object for the specified material. Additional keyword arguments are given to the corresponding material function.

Parameters:
  • material (str) –

    The name of the material to retrieve.

Returns:
  • The data object associated with the specified material.

Example
graphene_data = MaterialCatalog.get('graphene')
Source code in src/granad/materials.py
@staticmethod
def get(material : str, **kwargs):
    """
    Retrieves the material data object for the specified material. Additional keyword arguments are given to the corresponding material function.

    Args:
        material (str): The name of the material to retrieve.

    Returns:
        The data object associated with the specified material.

    Example:
        ```python
        graphene_data = MaterialCatalog.get('graphene')
        ```
    """
    return MaterialCatalog._materials[material](**kwargs)

cut_flake_1d(material, unit_cells, plot=False)

Cuts a one-dimensional flake from the material based on the specified number of unit cells and optionally plots the lattice and orbital positions.

Parameters:
  • material (Material) –

    The material instance from which to cut the flake.

  • unit_cells (int) –

    The number of unit cells to include in the flake.

  • plot (bool, default: False ) –

    If True, displays a plot of the orbital positions within the lattice. Default is False.

Returns:
  • list

    A list of orbitals positioned within the specified range of the material's lattice.

Note

The function utilizes internal methods of the Material class to compute positions and retrieve orbital data, ensuring that the positions are unique and correctly mapped to the material's grid.

Source code in src/granad/materials.py
def cut_flake_1d( material, unit_cells, plot=False):
    """
    Cuts a one-dimensional flake from the material based on the specified number of unit cells
    and optionally plots the lattice and orbital positions.

    Parameters:
        material (Material): The material instance from which to cut the flake.
        unit_cells (int): The number of unit cells to include in the flake.
        plot (bool, optional): If True, displays a plot of the orbital positions within the lattice.
                               Default is False.

    Returns:
        list: A list of orbitals positioned within the specified range of the material's lattice.

    Note:
        The function utilizes internal methods of the `Material` class to compute positions and
        retrieve orbital data, ensuring that the positions are unique and correctly mapped to the
        material's grid.
    """

    orbital_positions_uc =  material._get_positions_in_uc()
    grid = material._get_grid( [(0, unit_cells)] )
    orbital_positions = material._get_positions_in_lattice( orbital_positions_uc, grid )
    if plot:
        _display_lattice_cut( orbital_positions, orbital_positions )

    orbital_positions = jnp.unique( orbital_positions, axis = 0)        
    return material._get_orbital_list( orbital_positions, grid )
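A 1D cut replicates the unit-cell orbitals along the single lattice vector and then removes duplicate positions. The NumPy sketch below shows this with a hypothetical `cut_chain` helper. It works directly in Cartesian coordinates and assumes cells `0, ..., unit_cells - 1`, because the exact range produced by `_get_grid` is not shown here.

```python
import numpy as np


def cut_chain(uc_positions, lattice_vector, unit_cells):
    """Hypothetical sketch of a 1D cut: replicate the unit cell along one lattice vector.

    Assumes cells n = 0, ..., unit_cells - 1 (the range used by _get_grid may differ).
    """
    uc = np.asarray(uc_positions, dtype=float)     # (orbitals per cell, 3), Cartesian
    a = np.asarray(lattice_vector, dtype=float)    # (3,), Cartesian lattice vector
    positions = np.concatenate([uc + n * a for n in range(unit_cells)])
    # Remove duplicated positions, analogous to jnp.unique(..., axis=0) in the source
    return np.unique(positions, axis=0)


# Chain with lattice constant 1.42 Å and one orbital per cell at the origin
pos = cut_chain([[0.0, 0.0, 0.0]], [1.42, 0.0, 0.0], unit_cells=5)
print(pos[:, 0])  # five x-positions spaced 1.42 Å apart
```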

cut_flake_2d(material, polygon, plot=False, minimum_neighbor_number=2)

Cuts a two-dimensional flake from the material defined within the bounds of a specified polygon. It further prunes the positions to ensure that each atom has at least the specified minimum number of neighbors. Optionally, the function can plot the initial and final positions of the atoms within the polygon.

Parameters:
  • material (Material) –

    The material instance from which to cut the flake.

  • polygon (Polygon) –

    A polygon object with a vertices property holding an array of coordinates that define the vertices of the polygon within which to cut the flake.

  • plot (bool, default: False ) –

    If True, plots the lattice and the positions of atoms before and after pruning. Default is False.

  • minimum_neighbor_number (int, default: 2 ) –

    The minimum number of neighbors each atom must have to remain in the final positions. Default is 2.

Returns:
  • list

    A list of orbitals positioned within the specified polygon and satisfying the neighbor condition.

Note

The function assumes the underlying lattice to be in the xy-plane.
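For non-graphene materials, the routine first finds a range of unit cells large enough to cover the polygon. It solves P = L C for the lattice coordinates C of each vertex, enlarges them by 10%, and takes the floor and ceil extremes. The NumPy sketch below reproduces that step with a hypothetical `covering_range` helper. It uses `np.linalg.solve` instead of forming the inverse explicitly and assumes the basis rows are the in-plane lattice vectors, as in `material._lattice_basis[material.periodic, :2]`.

```python
import numpy as np


def covering_range(vertices, basis_2d, lattice_constant, margin=1.1):
    """Hypothetical helper: integer lattice-coefficient ranges that cover a polygon.

    vertices: (n_vertices, 2) Cartesian polygon corners
    basis_2d: (2, 2) array whose rows are the in-plane lattice vectors
    """
    L = np.asarray(basis_2d, dtype=float) * lattice_constant
    # Solve L.T @ C = P.T: column j of C holds the lattice coordinates of vertex j
    coeffs = np.linalg.solve(L.T, np.asarray(vertices, dtype=float).T) * margin
    upper = np.ceil(coeffs).max(axis=1)
    lower = np.floor(coeffs).min(axis=1)
    return [(int(lo), int(hi)) for lo, hi in zip(lower, upper)]


# Graphene basis and a 10 Å x 10 Å square centered at the origin
basis = [[1.0, 0.0], [-0.5, np.sqrt(3) / 2]]
square = [[-5, -5], [5, -5], [5, 5], [-5, 5]]
print(covering_range(square, basis, 2.46))  # [(-4, 4), (-3, 3)]
```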

Source code in src/granad/materials.py
def cut_flake_2d( material, polygon, plot=False, minimum_neighbor_number: int = 2):
    """
    Cuts a two-dimensional flake from the material defined within the bounds of a specified polygon.
    It further prunes the positions to ensure that each atom has at least the specified minimum number of neighbors.
    Optionally, the function can plot the initial and final positions of the atoms within the polygon.

    Parameters:
        material (Material): The material instance from which to cut the flake.
        polygon (Polygon): A polygon object with a `vertices` property holding an array of coordinates that define the vertices of the polygon within which to cut the flake.
        plot (bool, optional): If True, plots the lattice and the positions of atoms before and after pruning.
                               Default is False.
        minimum_neighbor_number (int, optional): The minimum number of neighbors each atom must have to remain in the final positions.
                                                 Default is 2.

    Returns:
        list: A list of orbitals positioned within the specified polygon and satisfying the neighbor condition.

    Note:
        The function assumes the underlying lattice to be in the xy-plane.
    """
    def _prune_neighbors(
            positions, minimum_neighbor_number, remaining_old=jnp.inf
    ):
        """
        Recursively prunes positions to ensure each position has a sufficient number of neighboring positions
        based on a minimum distance calculated from the unique set of distances between positions.

        Parameters:
            positions (array-like): Array of positions to prune.
            minimum_neighbor_number (int): Minimum required number of neighbors for a position to be retained.
            remaining_old (int): The count of positions remaining from the previous iteration; used to detect convergence.

        Returns:
            array-like: Array of positions that meet the neighbor count criterion.
        """
        if minimum_neighbor_number <= 0:
            return positions
        distances = jnp.round(
            jnp.linalg.norm(positions[:, material.periodic] - positions[:, None, material.periodic], axis=-1), 4
        )
        minimum = jnp.unique(distances)[1]
        mask = (distances <= minimum).sum(axis=0) > minimum_neighbor_number
        remaining = mask.sum()
        if remaining_old == remaining:
            return positions[mask]
        else:
            return _prune_neighbors(
                positions[mask], minimum_neighbor_number, remaining
            )

    if material.name == 'graphene' and polygon.polygon_id in ["hexagon", "triangle"]:
        n, m, vertices, final_atom_positions, initial_atom_positions, sublattice = _cut_flake_graphene(polygon.polygon_id, polygon.edge_type, polygon.side_length, material.lattice_constant)

        raw_list, layer_index = [], 0
        for i, position in enumerate(final_atom_positions):
            orb = Orbital(
                position = position,
                layer_index = layer_index,
                tag="sublattice_1" if sublattice[i] == "A" else "sublattice_2",
                group_id = material._species_to_groups["pz"],                        
                spin=material.species["pz"][0],
                atom_name=material.species["pz"][1]
                    )
            layer_index += 1
            raw_list.append(orb)

        orbital_list = OrbitalList(raw_list)
        material._set_couplings(orbital_list.set_hamiltonian_groups, "hamiltonian")
        material._set_couplings(orbital_list.set_coulomb_groups, "coulomb")
        orb_list = orbital_list

    else:
        # to cover the plane, we solve the linear equation P = L C, where P are the polygon vertices, L is the lattice basis and C are the coefficients
        vertices = polygon.vertices
        L = material._lattice_basis[material.periodic,:2] * material.lattice_constant
        coeffs = jnp.linalg.inv(L.T) @ vertices.T * 1.1

        # we just take the largest extent of the shape
        u1, u2 = jnp.ceil( coeffs ).max( axis = 1)
        l1, l2 = jnp.floor( coeffs ).min( axis = 1)
        grid = material._get_grid( [ (int(l1), int(u1)), (int(l2), int(u2)) ]  )

        # get atom positions in the unit cell in fractional coordinates
        orbital_positions =  material._get_positions_in_uc()
        unit_cell_fractional_atom_positions = jnp.unique(
            jnp.round(orbital_positions, 6), axis=0
                )

        initial_atom_positions = material._get_positions_in_lattice(
            unit_cell_fractional_atom_positions, grid
        ) 

        polygon_path = Path(vertices)
        flags = polygon_path.contains_points(initial_atom_positions[:, :2])        
        pruned_atom_positions = initial_atom_positions[flags]

        # get atom positions where every atom has at least minimum_neighbor_number neighbors
        final_atom_positions = _prune_neighbors(
            pruned_atom_positions, minimum_neighbor_number
        )
        orb_list = material._get_orbital_list(final_atom_positions, grid)

    if plot == True:
        _display_lattice_cut(
            initial_atom_positions, final_atom_positions, vertices
        )
    return orb_list
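Outside the graphene hexagon/triangle special case, `_prune_neighbors` removes dangling edge atoms. The NumPy sketch below is an iterative rewrite of that recursion, using a hypothetical `prune_neighbors` helper. It stops once a pass removes nothing, which is the condition the recursive version detects by comparing counts. It also makes explicit why the test is `> minimum_neighbor_number`: each count includes the site itself at distance 0. Unlike the source, it uses every coordinate it is given, so pass only the periodic (in-plane) columns to match. It also returns an empty array instead of failing when fewer than two sites remain.

```python
import numpy as np


def prune_neighbors(positions, minimum_neighbor_number):
    """Hypothetical helper: iteratively drop sites with too few nearest neighbors."""
    positions = np.asarray(positions, dtype=float)
    if minimum_neighbor_number <= 0:
        return positions
    while True:
        if len(positions) < 2:
            return positions[:0]  # no site can have a neighbor any more
        d = np.round(np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1), 4)
        nearest = np.unique(d)[1]  # smallest nonzero distance = nearest-neighbor distance
        # Each count includes the site itself (distance 0), hence the strict ">"
        mask = (d <= nearest).sum(axis=0) > minimum_neighbor_number
        if mask.all():
            return positions
        positions = positions[mask]


# Toy example: a 3x3 square grid plus one dangling site bonded only to (2, 0)
grid = [(x, y) for x in range(3) for y in range(3)] + [(3, 0)]
kept = prune_neighbors(grid, minimum_neighbor_number=2)
print(len(kept))  # 9: the dangling site at (3, 0) is removed
```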

cut_flake_generic(material, grid_range)

Cuts a flake from the material using a specified grid range. This method is generic and can be applied to materials of any dimensionality.

The function calculates the positions of orbitals within the unit cell, projects these onto the full lattice based on the provided grid range, and ensures that each position is unique. The result is a list of orbitals that are correctly positioned within the defined grid.

Parameters:
  • material (Material) –

    The material instance from which to cut the flake.

  • grid_range (list of tuples) –

    Each tuple in the list specifies the range for the grid in that dimension. For example, [(0, 10), (0, 5)] defines a grid that extends from 0 to 10 in the first dimension and from 0 to 5 in the second dimension.

Returns:
  • list

    A list of orbitals within the specified grid range, uniquely positioned.

Note

The grid_range parameter should be aligned with the material's dimensions and lattice structure, as mismatches can lead to incorrect or inefficient slicing of the material.

Source code in src/granad/materials.py
def cut_flake_generic( material, grid_range ):
    """
    Cuts a flake from the material using a specified grid range. This method is generic and can be applied
    to materials of any dimensionality.

    The function calculates the positions of orbitals within the unit cell, projects these onto the full
    lattice based on the provided grid range, and ensures that each position is unique. The result is a list
    of orbitals that are correctly positioned within the defined grid.

    Parameters:
        material (Material): The material instance from which to cut the flake.
        grid_range (list of tuples): Each tuple in the list specifies the range for the grid in that dimension.
                                     For example, [(0, 10), (0, 5)] defines a grid that extends from 0 to 10
                                     in the first dimension and from 0 to 5 in the second dimension.

    Returns:
        list: A list of orbitals within the specified grid range, uniquely positioned.

    Note:
        The grid_range parameter should be aligned with the material's dimensions and lattice structure,
        as mismatches can lead to incorrect or inefficient slicing of the material.
    """
    orbital_positions_uc =  material._get_positions_in_uc()
    grid = material._get_grid( grid_range)
    orbital_positions = material._get_positions_in_lattice( orbital_positions_uc, grid )
    orbital_positions = jnp.unique( orbital_positions, axis = 0)        
    return material._get_orbital_list( orbital_positions, grid )
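The generic cut is the N-dimensional version of the same idea: every integer cell index in the Cartesian product of the `grid_range` intervals is combined with every orbital offset in the unit cell. The NumPy sketch below uses a hypothetical `cut_generic` helper. It assumes fractional unit-cell positions, lattice vectors scaled by the lattice constant (as in `cut_flake_2d`), and half-open `(start, stop)` intervals like Python's `range`; `_get_grid` may use a different convention.

```python
import itertools

import numpy as np


def cut_generic(uc_fractional, basis, lattice_constant, grid_range):
    """Hypothetical sketch of a generic cut over an N-dimensional grid of unit cells.

    Assumes half-open (start, stop) intervals; the convention of _get_grid is not shown.
    """
    uc = np.asarray(uc_fractional, dtype=float)            # (orbitals per cell, N), fractional
    B = np.asarray(basis, dtype=float) * lattice_constant  # (N, 3), rows are lattice vectors
    cells = np.array(list(itertools.product(*(range(a, b) for a, b in grid_range))), dtype=float)
    # Every cell offset plus every orbital offset, still in fractional coordinates
    frac = (cells[:, None, :] + uc[None, :, :]).reshape(-1, uc.shape[1])
    return np.unique(frac @ B, axis=0)  # Cartesian positions without duplicates


# The docstring's example range [(0, 10), (0, 5)] applied to a graphene-like unit cell
graphene_basis = [[1.0, 0.0, 0.0], [-0.5, np.sqrt(3) / 2, 0.0]]
pos = cut_generic([[0.0, 0.0], [-1/3, -2/3]], graphene_basis, 2.46, [(0, 10), (0, 5)])
print(pos.shape)  # (100, 3): 10 * 5 cells with 2 orbitals each
```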

get_chain(hopping=-2.66, lattice_const=1.42)

Generates a 1D metallic chain model with a specified nearest-neighbor hopping and a fixed Coulomb interaction (Ohno potential).

Parameters:
  • hopping (float, default: -2.66 ) –

    Nearest-neighbor hopping in eV.

  • lattice_const (float, default: 1.42 ) –

    Lattice constant in Ångström. With one orbital per unit cell, this is also the nearest-neighbor distance.

Returns:
  • Material

    A Material object representing the 1D metallic chain, which includes:
      - Lattice Structure:
          - Lattice constant: lattice_const (1.42 Å by default).
          - Lattice basis: [1, 0, 0] (1D chain along the x-axis).
      - Orbital:
          - Single orbital species: "pz" (associated with carbon atoms).
          - One orbital per unit cell, positioned at [0].

Example

metal_chain = get_chain()
print(metal_chain)

Source code in src/granad/materials.py
def get_chain(hopping = -2.66, lattice_const = 1.42):
    """
    Generates a 1D metallic chain model with a specified nearest-neighbor hopping and a fixed Coulomb interaction (Ohno potential).

    Args:
        hopping (float, optional): Nearest-neighbor hopping in eV. Defaults to -2.66.
        lattice_const (float, optional): Lattice constant in Ångström, equal to the nearest-neighbor distance. Defaults to 1.42.

    Returns:
        Material: A `Material` object representing the 1D metallic chain, which includes:
            - **Lattice Structure**:
                - Lattice constant: `lattice_const` (1.42 Å by default).
                - Lattice basis: [1, 0, 0] (1D chain along the x-axis).
            - **Orbital**:
                - Single orbital species: "pz" (associated with carbon atoms).
                - One orbital per unit cell, positioned at [0].

    Example:
        >>> metal_chain = get_chain()
        >>> print(metal_chain)
    """
    return (Material("chain")
            .lattice_constant(lattice_const)
            .lattice_basis([
                [1, 0, 0],
            ])
            .add_orbital_species("pz", atom='C')
            .add_orbital(position=(0,), tag="", species="pz")
            .add_interaction(
                "hamiltonian",
                participants=("pz", "pz"),
                parameters=[0.0, hopping],
            )
            .add_interaction(
                "coulomb",
                participants=("pz", "pz"),
                parameters=[16.522, 8.64, 5.333],
                expression=ohno_potential(0)
            )
            )
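Assume the Hamiltonian `parameters` list is read by neighbor shell as `[onsite, nearest-neighbor, ...]`, as the `[onsite, nn]` convention in `get_graphene` suggests. Then a finite open chain of N orbitals has a tridiagonal Hamiltonian. The NumPy sketch below builds it with a hypothetical `chain_hamiltonian` helper (not part of granad) and checks the spectrum against the standard closed form E_k = ε + 2t cos(kπ/(N+1)), k = 1, ..., N.

```python
import numpy as np


def chain_hamiltonian(n_sites, onsite=0.0, hopping=-2.66):
    """Hypothetical helper: tridiagonal tight-binding Hamiltonian of an open chain."""
    H = np.diag(np.full(n_sites, onsite))
    off = np.full(n_sites - 1, hopping)
    # Nearest-neighbor hopping on the first super- and sub-diagonal
    return H + np.diag(off, k=1) + np.diag(off, k=-1)


N, t = 10, -2.66
energies = np.linalg.eigvalsh(chain_hamiltonian(N, hopping=t))  # ascending order
k = np.arange(1, N + 1)
analytic = np.sort(2 * t * np.cos(k * np.pi / (N + 1)))
print(np.allclose(energies, analytic))  # True
```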

get_graphene(hoppings=None)

Generates a graphene model based on parameters from David Tománek and Steven G. Louie, Phys. Rev. B 37, 8327 (1988).

Parameters:
  • hoppings (list, default: None ) –

    Hopping parameters for pz-pz interactions. Default is [onsite, nn] = [0, -2.66], as specified in the reference.

Returns:
  • Material

    A Material object representing the graphene model, which includes:
      - Lattice Structure:
          - Lattice constant: 2.46 Å.
          - Hexagonal lattice basis vectors: [1, 0, 0] and [-0.5, sqrt(3)/2, 0].
      - Orbitals:
          - Two sublattices, each with a single "pz" orbital, at fractional positions (0, 0) and (-1/3, -2/3).
      - Hamiltonian Interaction:
          - [onsite, nn] = [0.0, -2.66 eV] by default, overridable via hoppings.
      - Coulomb Interaction:
          - Parameterized by the Ohno potential with parameters [16.522, 8.64, 5.333].

Example

graphene_model = get_graphene(hoppings=[0, -2.7])
print(graphene_model)

Source code in src/granad/materials.py
def get_graphene(hoppings = None):
    """
    Generates a graphene model based on parameters from 
    [David Tománek and Steven G. Louie, Phys. Rev. B 37, 8327 (1988)](https://doi.org/10.1103/PhysRevB.37.8327).

    Args:
        hoppings (list, optional): Hopping parameters for pz-pz interactions. Default is [onsite, nn] = [0, -2.66], as specified in the reference.

    Returns:
        Material: A `Material` object representing the graphene model, which includes:
            - **Lattice Structure**:
                - Lattice constant: 2.46 Å.
                - Hexagonal lattice basis vectors: [1, 0, 0] and [-0.5, sqrt(3)/2, 0].
            - **Orbitals**:
                - Two sublattices, each with a single "pz" orbital, positioned at (0, 0) and (-1/3, -2/3) in fractional coordinates.
            - **Hamiltonian Interaction**:
                - [onsite, nn] = [0.0, -2.66 eV] by default, overridable via `hoppings`.
            - **Coulomb Interaction**:
                - Parameterized by the Ohno potential with parameters [16.522, 8.64, 5.333].

    Example:
        >>> graphene_model = get_graphene(hoppings=[0, -2.7])
        >>> print(graphene_model)
    """
    hoppings = hoppings or [0, -2.66]
    return (Material("graphene")
            .lattice_constant(2.46)
            .lattice_basis([
                [1, 0, 0],
                [-0.5, jnp.sqrt(3)/2, 0]
            ])
            .add_orbital_species("pz",  atom='C')
            .add_orbital(position=(0, 0), tag="sublattice_1", species="pz")
            .add_orbital(position=(-1/3, -2/3), tag="sublattice_2", species="pz")
            .add_interaction(
                "hamiltonian",
                participants=("pz", "pz"),
                parameters=hoppings,
            )
            .add_interaction(
                "coulomb",
                participants=("pz", "pz"),
                parameters=[16.522, 8.64, 5.333],
                expression=ohno_potential(0)
            )
            )
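The two sublattice positions are given in fractional coordinates of the lattice basis. Assuming Cartesian positions follow r = a (f1 a1 + f2 a2), as the scaling `material._lattice_basis[...] * material.lattice_constant` in `cut_flake_2d` suggests, the NumPy sketch below converts them. It recovers the carbon-carbon bond length a/√3 ≈ 1.42 Å, the same value used as the default nearest-neighbor distance in `get_chain`.

```python
import numpy as np

a = 2.46                                           # lattice constant in Å, as in get_graphene
basis = np.array([[1.0, 0.0, 0.0],                 # a1
                  [-0.5, np.sqrt(3) / 2, 0.0]])    # a2
fractional = np.array([[0.0, 0.0],                 # sublattice_1
                       [-1/3, -2/3]])              # sublattice_2

# Assumption: Cartesian position = a * (f1 * a1 + f2 * a2)
cartesian = fractional @ basis * a
bond = np.linalg.norm(cartesian[1] - cartesian[0])
print(round(bond, 3))  # 1.42
```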

get_hbn(lattice_constant=2.5, bb_hoppings=None, nn_hoppings=None, bn_hoppings=None)

Get a material representation for hexagonal boron nitride (hBN).

Parameters:
  • lattice_constant (float, default: 2.5 ) –

    The lattice constant for hBN in Ångström.

  • bb_hoppings (list or None, default: None ) –

    Hopping parameters for B-B interactions. If None, [2.46, -0.04] is used.

  • nn_hoppings (list or None, default: None ) –

    Hopping parameters for N-N interactions. If None, [-2.55, -0.04] is used.

  • bn_hoppings (list or None, default: None ) –

    Hopping parameters for B-N interactions. If None, [-2.16] is used.

The default values are taken from a study of the electronic structure of hexagonal boron nitride (hBN); see Giraud et al. for details.

Returns:
  • Material

    A Material object representing hBN.

Source code in src/granad/materials.py
def get_hbn(lattice_constant = 2.50, bb_hoppings = None, nn_hoppings = None, bn_hoppings = None):
    """
    Get a material representation for hexagonal boron nitride (hBN).

    Parameters:
    - lattice_constant (float): The lattice constant for hBN in Ångström. Default is 2.50.
    - bb_hoppings (list or None): Hopping parameters for B-B interactions. 
                                  Default is [2.46, -0.04].
    - nn_hoppings (list or None): Hopping parameters for N-N interactions. 
                                  Default is [-2.55, -0.04].
    - bn_hoppings (list or None): Hopping parameters for B-N interactions. 
                                  Default is [-2.16].

    Default values are taken from a study of the electronic structure of hexagonal boron nitride (hBN).
    See [Giraud et al.](https://www.semanticscholar.org/paper/Study-of-the-Electronic-Structure-of-hexagonal-on-Unibertsitatea-Thesis/ff1e000bbad5d8e2df5f85cb724b1a9e42a8b0f0) for more details.

    Returns:
    - Material: A `Material` object representing hBN.
    """
    bb_hoppings = [2.46, -0.04] if bb_hoppings is None else bb_hoppings
    bn_hoppings = [-2.16]  if bn_hoppings is None else bn_hoppings
    nn_hoppings = [-2.55, -0.04]  if nn_hoppings is None else nn_hoppings

    return (Material("hBN")
            .lattice_constant(lattice_constant)  # Approximate lattice constant of hBN
            .lattice_basis([
                [1, 0, 0],
                [-0.5, jnp.sqrt(3)/2, 0],  # Hexagonal lattice
            ])
            .add_orbital_species("pz_boron", atom='B')
            .add_orbital_species("pz_nitrogen", atom='N')
            .add_orbital(position=(0, 0), tag="B", species="pz_boron")
            .add_orbital(position=(-1/3, -2/3), tag="N", species="pz_nitrogen")
            .add_interaction(
                "hamiltonian",
                participants=("pz_boron", "pz_boron"),
                parameters=bb_hoppings,  
            )
            .add_interaction(
                "hamiltonian",
                participants=("pz_nitrogen", "pz_nitrogen"),
                parameters=nn_hoppings,  
            )
            .add_interaction(
                "hamiltonian",
                participants=("pz_boron", "pz_nitrogen"),
                parameters=bn_hoppings,
            )
            .add_interaction(
                "coulomb",
                participants=("pz_boron", "pz_boron"),
                expression = ohno_potential(1)
            )
            .add_interaction(
                "coulomb",
                participants=("pz_nitrogen", "pz_nitrogen"),
                expression = ohno_potential(1)
            )
            .add_interaction(
                "coulomb",
                participants=("pz_boron", "pz_nitrogen"),
                expression = ohno_potential(1)
            )
            )
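One way to read the default parameters follows the `[onsite, nn]` convention used for graphene. Under that reading, the first entries of `bb_hoppings` and `nn_hoppings` are the B and N onsite energies, and `bn_hoppings[0]` is the B-N nearest-neighbor hopping. The plain-Python sketch below evaluates the resulting two-band Bloch model with a hypothetical `hbn_bands` helper and ignores the small second entries (-0.04). At the K point the structure factor vanishes, so the gap reduces to the onsite difference, 2.46 - (-2.55) = 5.01 eV in this model.

```python
import cmath
import math


def hbn_bands(k, eps_b=2.46, eps_n=-2.55, t=-2.16, a=2.50):
    """Hypothetical two-band nearest-neighbor model; second-shell entries are ignored."""
    d = a / math.sqrt(3)                                    # B-N bond length
    deltas = [(0.0, -d), (-a / 2, d / 2), (a / 2, d / 2)]   # B -> N nearest-neighbor vectors
    f = sum(cmath.exp(1j * (k[0] * dx + k[1] * dy)) for dx, dy in deltas)
    mean, half_gap = (eps_b + eps_n) / 2, (eps_b - eps_n) / 2
    split = math.sqrt(half_gap ** 2 + abs(t * f) ** 2)      # eigenvalues of the 2x2 Bloch matrix
    return mean - split, mean + split


a = 2.50
gamma = (0.0, 0.0)
K = (4 * math.pi / (3 * a), 0.0)
for name, k in [("Gamma", gamma), ("K", K)]:
    lower, upper = hbn_bands(k, a=a)
    print(name, round(upper - lower, 3))  # at K the gap equals eps_b - eps_n
```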

get_mos2()

Generates a MoS2 model based on parameters from Bert Jorissen, Lucian Covaci, and Bart Partoens, SciPost Phys. Core 7, 004 (2024), taking into account even-parity eigenstates.

Returns:
  • Material

    A Material object representing the MoS2 model.

Example

mos2 = get_mos2()
print(mos2)
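In `get_mos2`, each coupling between sites closer than the 3.4 Å threshold uses one of three C3-related matrices from the selected stack (`nn_list`, `nnn_M_list` or `nnn_X_list`). Those matrices are built with the 120° rotation `R_X_e` and the 240° rotation `R_M_e`. The choice depends on the bond's angle relative to the reference direction [1, 1]. The plain-Python sketch below reproduces only that sector selection from `nn_coupling`; `bond_sector` and `rotate` are hypothetical helpers.

```python
import math


def bond_sector(vec, ref=(1.0, 1.0)):
    """Hypothetical helper: map a bond direction to 0, 1 or 2 like the branch logic in nn_coupling."""
    angle = math.atan2(vec[1], vec[0]) - math.atan2(ref[1], ref[0])
    angle = (angle + math.pi) % (2 * math.pi) - math.pi  # wrap to [-pi, pi), as jnp.mod does
    if angle < -math.pi / 3:
        return 1
    if angle >= math.pi / 3:
        return 2
    return 0


def rotate(vec, degrees):
    """Hypothetical helper: rotate a 2D vector counterclockwise."""
    c, s = math.cos(math.radians(degrees)), math.sin(math.radians(degrees))
    return (c * vec[0] - s * vec[1], s * vec[0] + c * vec[1])


ref = (1.0, 1.0)
print([bond_sector(rotate(ref, deg)) for deg in (0, 120, -120)])  # [0, 2, 1]
```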

Source code in src/granad/materials.py
def get_mos2():
    """
    Generates a MoS2 model based on parameters from [Bert Jorissen, Lucian Covaci, and Bart Partoens, SciPost Phys. Core 7, 004 (2024)](https://scipost.org/SciPostPhysCore.7.1.004), taking into account even-parity eigenstates.

    Returns:
        Material: A `Material` object representing the MoS2 model.

    Example:
        >>> mos2 = get_mos2()
        >>> print(mos2)
    """
    reference_vector = jnp.array([1,1,0])
    ref = reference_vector[:2] / jnp.linalg.norm(reference_vector[:2])            

    # Onsite energies
    epsilon_M_e_0 = -6.475
    epsilon_M_e_1 = -4.891
    epsilon_X_e_0 = -7.907
    epsilon_X_e_1 = -9.470

    # First nearest neighbor hopping parameters (u1)
    u1_e_0 = 0.999
    u1_e_1 = -1.289
    u1_e_2 = 0.795
    u1_e_3 = -0.688
    u1_e_4 = -0.795

    # Second nearest neighbor hopping parameters for metal (u2,Me)
    u2_Me_0 = -0.048
    u2_Me_1 = 0.580
    u2_Me_2 = -0.074
    u2_Me_3 = -0.414
    u2_Me_4 = -0.299
    u2_Me_5 = 0.045

    # Second nearest neighbor hopping parameters for chalcogen (u2,Xe)
    u2_Xe_0 = 0.795
    u2_Xe_1 = -0.248
    u2_Xe_2 = 0.164
    u2_Xe_3 = -0.002
    u2_Xe_4 = -0.283
    u2_Xe_5 = -0.174

    onsite_x = jnp.array([epsilon_X_e_0, epsilon_X_e_0, epsilon_X_e_1])

    # poor man's back transformation since odd orbitals are discarded
    onsite_x /= 2

    onsite_m = jnp.array([epsilon_M_e_0, epsilon_M_e_1, epsilon_M_e_1])

    nn = jnp.array([
        [0, 0, u1_e_0],
        [u1_e_1, u1_e_2, 0],
        [u1_e_3, u1_e_4, 0]
    ])

    # poor man's back transformation since odd orbitals are discarded
    nn /= jnp.sqrt(2)

    nnn_M = jnp.array([
        [u2_Me_0,         u2_Me_1,         u2_Me_2],
        [u2_Me_1,         u2_Me_3,         u2_Me_4],
        [-u2_Me_2,       -u2_Me_4,         u2_Me_5]
    ])

    nnn_X = jnp.array([
        [u2_Xe_0,         u2_Xe_1,         u2_Xe_2],
        [-u2_Xe_1,        u2_Xe_3,         u2_Xe_4],
        [-u2_Xe_2,        u2_Xe_4,         u2_Xe_5] ]
    )
    # poor man's back transformation since odd orbitals are discarded
    nnn_X /= 2

    gamma = 2 * jnp.pi / 3  # 120 degrees in radians

    R_X_e = jnp.array([
        [jnp.cos(gamma), -jnp.sin(gamma), 0],
        [jnp.sin(gamma),  jnp.cos(gamma), 0],
        [0,              0,             1]
    ])

    theta = 2 * gamma      # 240 degrees for d_x2-y2 and d_xy rotation

    R_M_e = jnp.array([
        [1.0,           0.0,           0.0],
        [0.0,  jnp.cos(theta), -jnp.sin(theta)],
        [0.0,  jnp.sin(theta),  jnp.cos(theta)]
    ])

    nn_list = jnp.stack([nn, R_X_e @ nn @ R_M_e.T, R_X_e.T @ nn @ R_M_e])
    nnn_M_list = jnp.stack([nnn_M, R_M_e @ nnn_M @ R_M_e.T, R_M_e.T @ nnn_M @ R_M_e])
    nnn_X_list = jnp.stack([nnn_X, R_X_e @ nnn_X @ R_X_e.T, R_X_e.T @ nnn_X @ R_X_e])

    d_orbs =  ["dz2", "dx2-y2", "dxy"]
    p_orbs = ["px", "py", "pz"]
    orbs = d_orbs + p_orbs

    def generate_coupling(orb1, orb2):
        orb1_idx = orbs.index(orb1) % 3
        orb2_idx = orbs.index(orb2) % 3

        # Select the correct matrix stack
        arr = nn_list
        onsite = onsite_m
        if orb1[0] == "p" and orb2[0] == "p":
            arr = nnn_X_list
            onsite = onsite_x
        elif orb1[0] == "d" and orb2[0] == "d":
            arr = nnn_M_list

        def nn_coupling(vec):
            vec /= jnp.linalg.norm(vec)

            # Compute angle between ref and vec
            angle = jnp.arctan2(vec[1], vec[0]) - jnp.arctan2(ref[1], ref[0])
            angle = jnp.mod(angle + jnp.pi, 2 * jnp.pi) - jnp.pi  # Map to [-π, π]
            branch = 0 * jnp.logical_and(angle >= -jnp.pi / 3, angle <= jnp.pi / 3) + 1 * (angle < -jnp.pi / 3) + 2 * (angle >= jnp.pi/3)
            idx = jax.lax.switch(
                branch,
                [lambda : 0, lambda : 1, lambda : 2],
            )

            return arr[idx][orb1_idx, orb2_idx]

        def coupling(vec):
            length = jnp.linalg.norm(vec[:2])
            thresh = 3.4
            branch = 0 * (length == 0) + 1 * jnp.logical_and(0 < length, length < thresh) + 2 * (length >= thresh)

            return jax.lax.switch(branch,
                                  [lambda x : onsite[orb1_idx],
                                   lambda x : nn_coupling(x),
                                   lambda x : 0. 
                                   ],
                                  vec[:2]
                                  )

        return coupling

    mat = (Material("MoS2")
    .lattice_constant(3.16)  # Approximate lattice constant of monolayer MoS2
    .lattice_basis([
        [1, 0, 0],
        [-0.5, jnp.sqrt(3)/2, 0],  # Hexagonal lattice
        [0, 0, 1],
    ], periodic = [0,1])
    .add_orbital_species("dz2", atom='Mo')
    .add_orbital_species("dx2-y2", atom='Mo')
    .add_orbital_species("dxy", atom='Mo')
    .add_orbital_species("px", atom='S')
    .add_orbital_species("py", atom='S')
    .add_orbital_species("pz", atom='S')
    .add_orbital(position=(0, 0, 0), tag="dz2", species="dz2")
    .add_orbital(position=(0, 0, 0), tag="dx2-y2", species="dx2-y2")
    .add_orbital(position=(0, 0, 0), tag="dxy", species="dxy")
    .add_orbital(position=(1/3, 2/3, 1.5), tag="px_top", species="px")
    .add_orbital(position=(1/3, 2/3, 1.5), tag="py_top", species="py")
    .add_orbital(position=(1/3, 2/3, 1.5), tag="pz_top", species="pz")
    .add_orbital(position=(1/3, 2/3, -1.5), tag="px_bottom", species="px")
    .add_orbital(position=(1/3, 2/3, -1.5), tag="py_bottom", species="py")
    .add_orbital(position=(1/3, 2/3, -1.5), tag="pz_bottom", species="pz")
           )

    for orb1 in orbs:
        for orb2 in orbs:
            mat = (mat.add_interaction("hamiltonian",
                                       participants = (orb1, orb2),
                                       expression = generate_coupling(orb1, orb2)
                                       )
                   .add_interaction("coulomb",
                                    participants = (orb1, orb2),
                                    expression = ohno_potential(1)
                                    )
                   )            
    return mat

get_ssh(delta=0.2, displacement=0.4, base_hopping=-2.66, lattice_const=2.84)

Generates an SSH (Su-Schrieffer-Heeger) model with specified hopping parameters and a 2-atom unit cell.

Parameters:
  • delta (float, default: 0.2 ) –

    Dimerization parameter controlling the alternating hopping amplitudes: the two nearest-neighbor hoppings are base_hopping * (1 + delta) and base_hopping * (1 - delta). Default is 0.2.

  • displacement (float, default: 0.4 ) –

    Position of the second atom in the unit cell along the x-axis, as a fraction of the lattice constant (values between 0 and 1). Default is 0.4.

  • base_hopping (float, default: -2.66 ) –

    Base hopping value (in eV) to which the alternating factors (1 + delta) and (1 - delta) are applied. Default is -2.66 eV.

  • lattice_const (float, default: 2.84 ) –

    Distance between neighboring unit cells. Default is 2 * 1.42 = 2.84 Ångström, since each unit cell contains two sites.

Returns:
  • Material

    An SSH model represented as a Material object, including: - Lattice structure with lattice constant lattice_const (2.84 Å by default). - Two pz orbitals (one per sublattice) placed at [0] and [displacement]. - Nearest-neighbor (NN) hopping amplitudes: [base_hopping * (1 + delta), base_hopping * (1 - delta)]. - Coulomb interactions parameterized by the Ohno potential.

Source code in src/granad/materials.py
def get_ssh(delta = 0.2, displacement = 0.4, base_hopping = -2.66, lattice_const = 2.84):
    """
    Generates an SSH (Su-Schrieffer-Heeger) model with specified hopping parameters and a 2-atom unit cell.

    Args:
        delta (float, optional): Dimerization parameter controlling the alternating hopping amplitudes.
            - The nearest-neighbor hopping amplitudes are base_hopping * (1 + delta) and base_hopping * (1 - delta). Default is 0.2.
        displacement (float, optional): Position of the second atom in the unit cell along the x-axis, as a fraction of the lattice constant.
            - Takes values between 0 and 1. Default is 0.4.
        base_hopping (float, optional): Base hopping value to which the factors (1 + delta) and (1 - delta) are applied. Defaults to -2.66 eV.
        lattice_const (float, optional): Distance between neighboring unit cells. Defaults to 2*1.42 = 2.84 Ångström (since each unit cell contains two sites).

    Returns:
        Material: An SSH model represented as a `Material` object, including:
            - Lattice structure with lattice constant `lattice_const` (2.84 Å by default).
            - Two pz orbitals (one per sublattice) placed at [0] and [displacement].
            - Nearest-neighbor (NN) hopping amplitudes: [base_hopping*(1 + delta), base_hopping*(1 - delta)].
            - Coulomb interactions parameterized by Ohno potential.
    """
    return (Material("ssh")
            .lattice_constant(lattice_const)  # 2 * a_cc = 2.84 Å by default
            .lattice_basis([
                [1, 0, 0],
            ])
            .add_orbital_species("pz", atom='C')
            .add_orbital(position=(0,), tag="sublattice_1", species="pz")
            .add_orbital(position=(displacement,), tag="sublattice_2", species="pz")
            .add_interaction(
                "hamiltonian",
                participants=("pz", "pz"),
                parameters=[0.0, base_hopping * (1 + delta), base_hopping * (1 - delta)],
            )
            .add_interaction(
                "coulomb",
                participants=("pz", "pz"),
                parameters=[16.522, 8.64, 5.333],
                expression=ohno_potential(0)
            )
            )
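The effect of delta can be checked independently of granad. The sketch below is not part of the library: it assumes a textbook infinite SSH chain with zero on-site energy and the two alternating hoppings listed above, builds the off-diagonal element of the 2x2 Bloch Hamiltonian, and confirms that the bulk gap is 2|t1 - t2| = 4|base_hopping * delta|. The helper name `ssh_bands` is hypothetical.

```python
import jax.numpy as jnp


def ssh_bands(k, delta=0.2, base_hopping=-2.66):
    """Bulk bands of an infinite two-site SSH chain (hypothetical helper, not part of granad).

    k is the Bloch momentum in units of 1 / lattice constant, in [-pi, pi].
    """
    t1 = base_hopping * (1 + delta)  # one of the two alternating bonds
    t2 = base_hopping * (1 - delta)  # the other bond
    # The 2x2 Bloch Hamiltonian has zero diagonal and off-diagonal element t1 + t2 * exp(-ik).
    h12 = t1 + t2 * jnp.exp(-1j * k)
    energy = jnp.abs(h12)
    return -energy, energy


k = jnp.linspace(-jnp.pi, jnp.pi, 201)
lower, upper = ssh_bands(k)
gap = jnp.min(upper) - jnp.max(lower)
print(gap)  # approximately 2 * |t1 - t2| = 2.128 eV for the defaults
```

The gap closes at delta = 0, where both bonds are equal; which of the two bonds is intra-cell does not change the bulk gap.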

ohno_potential(offset=0, start=14.399)

Generates a callable that represents a regularized Coulomb-like potential.

The returned function evaluates \(V(d) = \text{start} / (|d| + \text{offset})\). Both the prefactor start and the offset are adjustable; a positive offset avoids the singularity at zero distance.

Parameters:
  • offset (float, default: 0 ) –

    The offset added to the distance to prevent division by zero and to regularize the potential at short distances. Defaults to 0.

  • start (float, default: 14.399 ) –

    Scaling factor (numerator) of the potential. Defaults to 14.399, approximately \(e^2/(4\pi\varepsilon_0)\) in eV·Å, so distances in Å give energies in eV.

Returns:
  • Callable[[float], complex]: A function that takes a distance or displacement vector d and returns start / (|d| + offset) as a complex number.

Note
potential = ohno_potential()
print(potential(1))  # Output: (14.399 + 0j) if default parameters used
Source code in src/granad/materials.py
def ohno_potential( offset = 0, start = 14.399 ):
    """
    Generates a callable that represents a regularized Coulomb-like potential.

    The potential function is parameterized to provide flexibility in adjusting the starting value and an offset,
    which can be used to avoid singularities at zero distance.

    Args:
        offset (float): The offset added to the distance to prevent division by zero and to regularize the potential at short distances. Defaults to 0.
        start (float): Scaling factor (numerator) of the potential. Defaults to 14.399, approximately e^2/(4*pi*eps0) in eV·Å.

    Returns:
        Callable[[float], complex]: A function that takes a distance or displacement vector 'd' and returns start / (|d| + offset) as a complex number.

    Note:
        ```python
        potential = ohno_potential()
        print(potential(1))  # Output: (14.399 + 0j) if default parameters used
        ```
    """
    def inner(d):
        """Coupling with a (regularized) Coulomb-like potential"""
        return start / (jnp.linalg.norm(d) + offset) + 0j
    return inner
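To see what the offset does, here is a standalone copy of the formula from the source above, so it runs without granad. The name `regularized_coulomb` is hypothetical; distances are assumed to be in Å, so the default prefactor yields eV.

```python
import jax.numpy as jnp


def regularized_coulomb(offset=0.0, start=14.399):
    """Standalone copy of the formula above (hypothetical name): start / (|d| + offset) + 0j."""
    def inner(d):
        return start / (jnp.linalg.norm(jnp.asarray(d, dtype=jnp.float32)) + offset) + 0j
    return inner


plain = regularized_coulomb()
softened = regularized_coulomb(offset=1.0)

bond = jnp.array([1.42, 0.0, 0.0])  # displacement vector in Å
print(plain(bond))                   # about 14.399 / 1.42, roughly 10.14
print(softened(bond))                # about 14.399 / 2.42, roughly 5.95
print(softened(jnp.zeros(3)))        # finite at d = 0: start / offset = 14.399
```

With offset = 0 the same call at d = 0 returns an infinite value, which is why a positive offset is useful for on-site terms.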

zero_coupling(d)

Returns a zero coupling constant as a complex number.

Parameters:
  • d (float) –

    A parameter (typically representing distance or some other factor) that is ignored by the function, as the output is always zero.

Returns:
  • complex

    Returns 0.0 as a complex number (0.0j).

Source code in src/granad/materials.py
def zero_coupling(d):
    """
    Returns a zero coupling constant as a complex number.

    Args:
        d (float): A parameter (typically representing distance or some other factor) that is ignored by the function, as the output is always zero.

    Returns:
        complex: Returns 0.0 as a complex number (0.0j).
    """
    return 0.0j

Circle(radius, n_vertices=8)

Generates the vertices of a polygon that approximates a circle, given the radius and the number of vertices.

The circle approximation is created by calculating points along the circumference using the radius provided. The number of vertices specifies how many sides the polygon will have, thus controlling the granularity of the approximation. By default, an octagon is generated.

Parameters:
  • radius (float) –

    The radius of the circle to approximate, in arbitrary units.

  • n_vertices (int, default: 8 ) –

    The number of vertices (or sides) of the approximating polygon. Default is 8.

Returns:
  • Polygon: A Polygon (polygon_id "circle") whose vertices array has shape (n_vertices+1, 2), with the first vertex repeated at the end to close the shape.

Note

The accuracy of the circle approximation improves with an increase in the number of vertices. For a smoother circle, increase the number of vertices.

# Create an approximate circle with a radius of 20 units and default vertices
circle_octagon = Circle(20)

# Create an approximate circle with a radius of 15 units using 12 vertices
circle_dodecagon = Circle(15, 12)
Source code in src/granad/shapes.py
def Circle(radius, n_vertices = 8):
    """
    Generates the vertices of a polygon that approximates a circle, given the radius and the number of vertices.

    The circle approximation is created by calculating points along the circumference using the radius provided. The number of vertices specifies how many sides the polygon will have, thus controlling the granularity of the approximation. By default, an octagon is generated.

    Parameters:
        radius (float): The radius of the circle to approximate, in arbitrary units.
        n_vertices (int): The number of vertices (or sides) of the approximating polygon. Default is 8.

    Returns:
        Polygon: A Polygon whose vertices array has shape (n_vertices+1, 2),
                 including the first vertex repeated at the end to close the shape.

    Note:
        The accuracy of the circle approximation improves with an increase in the number of vertices. For a smoother circle, increase the number of vertices.

        ```python
        # Create an approximate circle with a radius of 20 units and default vertices
        circle_octagon = Circle(20)

        # Create an approximate circle with a radius of 15 units using 12 vertices
        circle_dodecagon = Circle(15, 12)
        ```
    """
    circle = jnp.array([
        (radius * jnp.cos(2 * jnp.pi * i / n_vertices), radius * jnp.sin(2 * jnp.pi * i / n_vertices))
        for i in range(n_vertices)
    ])

    return Polygon(vertices = jnp.vstack([circle, circle[0]]), polygon_id = "circle")
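The vertex construction above can be reproduced with plain arrays to check how the approximation improves with n_vertices. This sketch returns only the closed vertex array rather than a Polygon, and the helper name `circle_vertices` is hypothetical.

```python
import jax.numpy as jnp


def circle_vertices(radius, n_vertices=8):
    """Closed vertex array of a regular polygon inscribed in a circle (hypothetical helper)."""
    angles = 2 * jnp.pi * jnp.arange(n_vertices) / n_vertices
    ring = radius * jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)
    return jnp.vstack([ring, ring[:1]])  # repeat the first vertex to close the shape


for n in (8, 12, 64):
    verts = circle_vertices(20.0, n)
    sides = jnp.linalg.norm(jnp.diff(verts, axis=0), axis=1)
    # Every side has length 2 * radius * sin(pi / n), which shrinks as n grows.
    print(n, verts.shape, float(sides[0]))
```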

Hexagon(length)

Generates the vertices of a regular hexagon given the side length.

Before any adjustment by the @_edge_type decorator, which handles positional adjustments and rotations, the first vertex lies on the positive x-axis, so the top and bottom edges are horizontal.

Parameters:
  • length (float) –

    The length of each side of the hexagon, specified in angstroms.

Returns:
  • Polygon: A Polygon (polygon_id "hexagon") whose vertices array has shape (7, 2), including the starting vertex repeated at the end for drawing closed shapes.

Note
# Hexagon with side length of 1.0 angstrom
hexagon = Hexagon(1.0)
Source code in src/granad/shapes.py
@_edge_type
def Hexagon(length):
    """
    Generates the vertices of a regular hexagon given the side length.

    Before any adjustment by the @_edge_type decorator, which handles positional adjustments
    and rotations, the first vertex lies on the positive x-axis (top and bottom edges are horizontal).

    Parameters:
        length (float): The length of each side of the hexagon, specified in angstroms.

    Returns:
        Polygon: A Polygon whose vertices array has shape (7, 2),
                 including the starting vertex repeated at the end for drawing closed shapes.

    Note:
        ```python
        # Hexagon with side length of 1.0 angstrom
        hexagon = Hexagon(1.0)
        ```
    """
    n = 6
    s = 1
    angle = 2 * jnp.pi / n
    vertices = length * jnp.array(
        [
            (s * jnp.cos(i * angle), s * jnp.sin(i * angle))
            for i in [x for x in range(n)] + [0]
        ]
    )
    return Polygon(vertices = vertices, polygon_id = "hexagon", side_length = length)
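The hexagon vertices place each corner at distance `length` from the origin, and for a regular hexagon the circumradius equals the side length. A sketch mirroring the source (plain arrays instead of a Polygon; the helper name `hexagon_vertices` is hypothetical) checks this and the undecorated orientation:

```python
import jax.numpy as jnp


def hexagon_vertices(length):
    """Closed vertex array of a regular hexagon, mirroring the source above (hypothetical helper)."""
    idx = jnp.array([0, 1, 2, 3, 4, 5, 0])  # six corners, then the first one again
    angles = idx * (2 * jnp.pi / 6)
    return length * jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)


v = hexagon_vertices(1.42)
sides = jnp.linalg.norm(jnp.diff(v, axis=0), axis=1)
print(v[0])   # first vertex on the positive x-axis, at (1.42, 0)
print(sides)  # all six sides equal `length`
```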

Rectangle(length_x, length_y)

Generates the vertices of a rectangle given the lengths along the x and y dimensions.

The rectangle is centered at the origin, and the function is designed to be used with the @_edge_type decorator, allowing for positional shifts and rotations (if specified). As implemented, the vertices span from -length_x to length_x along x and from -length_y/2 to length_y/2 along y.

Parameters:
  • length_x (float) –

    Half-width of the rectangle along the x-axis (the full width is 2 * length_x), specified in angstroms.

  • length_y (float) –

    Full height of the rectangle along the y-axis, specified in angstroms.

Returns:
  • Polygon: A Polygon (polygon_id "rectangle") whose vertices array has shape (5, 2), starting and ending at the same vertex to facilitate drawing closed shapes.

Note
# length_x=2.0, length_y=1.0: spans 4.0 angstroms along x and 1.0 angstrom along y
rectangle = Rectangle(2.0, 1.0)
Source code in src/granad/shapes.py
@_edge_type
def Rectangle(length_x, length_y):
    """
    Generates the vertices of a rectangle given the lengths along the x and y dimensions.

    The rectangle is centered at the origin, and the function is designed to be used with
    the @_edge_type decorator, allowing for positional shifts and rotations (if specified).
    As implemented, the vertices span from -length_x to length_x along x and from
    -length_y/2 to length_y/2 along y.

    Parameters:
        length_x (float): Half-width of the rectangle along the x-axis (full width 2 * length_x), specified in angstroms.
        length_y (float): Full height of the rectangle along the y-axis, specified in angstroms.

    Returns:
        Polygon: A Polygon whose vertices array has shape (5, 2),
                 starting and ending at the same vertex to facilitate drawing closed shapes.

    Note:
        ```python
        # length_x=2.0, length_y=1.0: spans 4.0 angstroms along x and 1.0 angstrom along y
        rectangle = Rectangle(2.0, 1.0)
        ```
    """
    vertices = jnp.array(
        [
            (-1 * length_x, -0.5 * length_y),
            (1 * length_x, -0.5 * length_y),
            (1 * length_x, 0.5 * length_y),
            (-1 * length_x, 0.5 * length_y),
            (-1 * length_x, -0.5 * length_y),
        ]
    )
    return Polygon(vertices = vertices, polygon_id = "rectangle")
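The asymmetry between the two arguments is easiest to see from the bounding box. This sketch copies the vertex list from the source (plain arrays instead of a Polygon; the helper name `rectangle_vertices` is hypothetical) and measures its extent:

```python
import jax.numpy as jnp


def rectangle_vertices(length_x, length_y):
    """Same closed vertex list as the source above (hypothetical helper name)."""
    return jnp.array([
        (-length_x, -0.5 * length_y),
        (length_x, -0.5 * length_y),
        (length_x, 0.5 * length_y),
        (-length_x, 0.5 * length_y),
        (-length_x, -0.5 * length_y),
    ])


v = rectangle_vertices(2.0, 1.0)
extent = v.max(axis=0) - v.min(axis=0)
print(extent)  # [4. 1.]: the x-extent is 2 * length_x, the y-extent is length_y
```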

Rhomboid(base, height)

Generates the vertices of a rhomboid given the base length and height.

The rhomboid starts at the origin with its base along the positive x-axis; the slanted sides are tilted 30 degrees from the vertical, giving interior angles of 60 and 120 degrees. Position and orientation can be adjusted using the @_edge_type decorator.

Parameters:
  • base (float) –

    The length of the base of the rhomboid, specified in angstroms.

  • height (float) –

    The length of the slanted sides, specified in angstroms. The vertical height of the shape is height * cos(30°) ≈ 0.866 * height.

Returns:
  • Polygon: A Polygon (polygon_id "rhomboid") whose vertices array has shape (5, 2), starting and ending at the same vertex to complete the shape.

Note
# Rhomboid with base 2.0 angstroms and height 1.0 angstrom
rhomboid = Rhomboid(2.0, 1.0)
Source code in src/granad/shapes.py
@_edge_type
def Rhomboid(base, height):
    """
    Generates the vertices of a rhomboid given the base length and height.

    The rhomboid starts at the origin with its base along the positive x-axis; the slanted sides
    are tilted 30 degrees from the vertical (interior angles of 60 and 120 degrees). Position and
    orientation can be adjusted using the @_edge_type decorator.

    Parameters:
        base (float): The length of the base of the rhomboid, specified in angstroms.
        height (float): The length of the slanted sides, specified in angstroms; the vertical
                        height of the shape is height * cos(30°).

    Returns:
        Polygon: A Polygon whose vertices array has shape (5, 2),
                 starting and ending at the same vertex to complete the shape.

    Note:
        ```python
        # Rhomboid with base 2.0 angstroms and height 1.0 angstrom
        rhomboid = Rhomboid(2.0, 1.0)
        ```
    """
    angle = jnp.radians(30)
    vertices = jnp.array(
        [
            (0, 0),
            (base, 0),
            (base + height * jnp.sin(angle), height * jnp.cos(angle)),
            (height * jnp.sin(angle), height * jnp.cos(angle)),
            (0, 0),
        ]
    )
    return Polygon(vertices = vertices, polygon_id = "rhomboid")
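The geometry of the rhomboid follows from its slanted side vector (height * sin 30°, height * cos 30°). This sketch copies the vertex list (plain arrays instead of a Polygon; the helper name `rhomboid_vertices` is hypothetical) and measures the interior angle, side length, and vertical extent:

```python
import jax.numpy as jnp


def rhomboid_vertices(base, height):
    """Same closed vertex list as the source above (hypothetical helper name)."""
    angle = jnp.radians(30.0)
    dx, dy = height * jnp.sin(angle), height * jnp.cos(angle)  # slanted side vector
    return jnp.array([(0.0, 0.0), (base, 0.0), (base + dx, dy), (dx, dy), (0.0, 0.0)])


v = rhomboid_vertices(2.0, 1.0)
side = v[3] - v[0]                                     # slanted side starting at the origin
interior = jnp.degrees(jnp.arctan2(side[1], side[0]))  # angle between base and slanted side
print(float(interior), float(jnp.linalg.norm(side)), float(v[:, 1].max()))
# about 60.0 degrees, 1.0 (= height), and 0.866 (= height * cos 30°)
```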

Triangle(side_length)

Generates the vertices of an equilateral triangle given the side length.

The triangle is oriented such that one vertex points upwards and the base is horizontal, with its centroid at the origin. This function is designed to be used with the @_edge_type decorator, which adds functionality to shift the triangle's position or rotate it based on additional 'shift' and 'armchair' parameters passed to the function.

Parameters:
  • side_length (float) –

    The length of each side of the triangle, specified in angstroms.

Returns:
  • Polygon: A Polygon (polygon_id "triangle") whose vertices array has shape (4, 2), including the starting vertex repeated at the end to facilitate drawing closed shapes.

Note
# Create a triangle with side length of 1.0 angstrom, no shift or rotation
triangle = Triangle(1.0)

# Create a triangle with side length of 1.0 angstrom, shifted by [1, 1] units
triangle_shifted = Triangle(1.0, shift=[1, 1])

# Create a triangle with side length of 1.0 angstrom, rotated by 90 degrees (armchair orientation)
triangle_rotated = Triangle(1.0, armchair=True)
Source code in src/granad/shapes.py
@_edge_type
def Triangle(side_length):
    """
    Generates the vertices of an equilateral triangle given the side length.

    The triangle is oriented such that one vertex points upwards and the base is horizontal,
    with its centroid at the origin. This function is designed to be used with the @_edge_type
    decorator, which adds functionality to shift the triangle's position or rotate it based on
    additional 'shift' and 'armchair' parameters passed to the function.

    Parameters:
        side_length (float): The length of each side of the triangle, specified in angstroms.

    Returns:
        Polygon: A Polygon whose vertices array has shape (4, 2),
                 including the starting vertex repeated at the end to facilitate
                 drawing closed shapes.

    Note:
        ```python
        # Create a triangle with side length of 1.0 angstrom, no shift or rotation
        triangle = Triangle(1.0)

        # Create a triangle with side length of 1.0 angstrom, shifted by [1, 1] units
        triangle_shifted = Triangle(1.0, shift=[1, 1])

        # Create a triangle with side length of 1.0 angstrom, rotated by 90 degrees (armchair orientation)
        triangle_rotated = Triangle(1.0, armchair=True)
        ```
    """
    vertices = side_length * jnp.array(
        [
            (0, jnp.sqrt(3) / 3),
            (-0.5, -jnp.sqrt(3) / 6),
            (0.5, -jnp.sqrt(3) / 6),
            (0, jnp.sqrt(3) / 3),
        ]
    )
    return Polygon(vertices = vertices, polygon_id = "triangle", side_length = side_length)
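The hard-coded coordinates (0, √3/3) and (±1/2, -√3/6) describe a unit equilateral triangle centered on its centroid. This sketch copies the vertex list (plain arrays instead of a Polygon; the helper name `triangle_vertices` is hypothetical) and verifies both properties:

```python
import jax.numpy as jnp


def triangle_vertices(side_length):
    """Same closed vertex list as the source above (hypothetical helper name)."""
    r3 = jnp.sqrt(3.0)
    return side_length * jnp.array([
        (0.0, r3 / 3),
        (-0.5, -r3 / 6),
        (0.5, -r3 / 6),
        (0.0, r3 / 3),
    ])


v = triangle_vertices(1.0)
sides = jnp.linalg.norm(jnp.diff(v, axis=0), axis=1)
centroid = v[:3].mean(axis=0)  # drop the repeated closing vertex
print(sides)     # three sides of length side_length
print(centroid)  # approximately [0, 0]
```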

BareHamiltonian()

Represents the unperturbed single-particle tight-binding mean field Hamiltonian, denoted as \(h^{(0)}\).

Returns:
  • Function

    A function of (t, r, args) that returns the bare Hamiltonian matrix args.hamiltonian, independent of the time t and the density matrix r.

Source code in src/granad/potentials.py
def BareHamiltonian():
    """Represents the unperturbed single-particle tight-binding mean field Hamiltonian, denoted as $h^{(0)}$.

    Returns:
        Function: Provides the bare Hamiltonian matrix, representing the unperturbed state of the system.
    """
    return lambda t, r, args: args.hamiltonian

Coulomb()

Calculates the induced Coulomb potential based on deviations from a stationary density matrix, represented as \(\sim \lambda C(\rho-\rho_0)\). Here, \(\lambda\) is a scaling factor.

Returns:
  • Function

    Computes the Coulomb interaction scaled by deviations from the stationary state.

Source code in src/granad/potentials.py
def Coulomb():
    r"""Calculates the induced Coulomb potential based on deviations from a stationary density matrix, represented as $\sim \lambda C(\rho-\rho_0)$. Here, $\lambda$ is a scaling factor.

    Returns:
        Function: Computes the Coulomb interaction scaled by deviations from the stationary state.
    """
    return lambda t, r, args: jnp.diag(args.coulomb_scaled @ (r-args.stationary_density_matrix).diagonal() * args.electrons )
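The potentials in this module return callables with the signature f(t, r, args), where t is the time, r the current density matrix, and args a container of precomputed arrays. A minimal sketch evaluating the two potentials above by hand, assuming a types.SimpleNamespace can stand in for args (in granad this container is built internally) and using made-up two-orbital values:

```python
from types import SimpleNamespace

import jax.numpy as jnp

# Copies of the two potentials above, so the sketch runs without granad.
bare = lambda t, r, args: args.hamiltonian
coulomb = lambda t, r, args: jnp.diag(
    args.coulomb_scaled @ (r - args.stationary_density_matrix).diagonal() * args.electrons
)

# Hypothetical two-orbital system; field names mirror the attributes read above.
args = SimpleNamespace(
    hamiltonian=jnp.array([[0.0, -2.66], [-2.66, 0.0]]),
    coulomb_scaled=jnp.array([[16.5, 8.6], [8.6, 16.5]]),
    stationary_density_matrix=jnp.diag(jnp.array([0.5, 0.5])),
    electrons=2,
)

r = jnp.diag(jnp.array([0.6, 0.4]))  # charge shifted from orbital 1 onto orbital 0
h_total = bare(0.0, r, args) + coulomb(0.0, r, args)
print(h_total)  # diagonal shifted by about +1.58 and -1.58; off-diagonal hopping unchanged
```

The Coulomb term only touches the diagonal, raising the energy of the orbital that gained charge and lowering the other one.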

Diamagnetic(vector_potential)

Diamagnetic Coulomb gauge coupling to an external vector potential represented as \(\sim \vec{A}^2\).

Parameters:
  • vector_potential (callable) –

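    Function that returns the vector potential at a given time, evaluated at each orbital, as an array of shape (N, 3) for N orbitals.

Returns:
  • Function

    Returns the diagonal matrix with entries \(\frac{q^2}{2m} |\vec{A}(\vec{r}_i, t)|^2\) (with \(q = m = 1\)), representing diamagnetic interactions.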

Source code in src/granad/potentials.py
def Diamagnetic(vector_potential):
    r"""Diamagnetic Coulomb gauge coupling to an external vector potential represented as $\sim \vec{A}^2$.

    Args:
        vector_potential (callable): Function that returns the vector potential at a given time, as an array of shape (N, 3) for N orbitals.

    Returns:
        Function: Returns the diagonal matrix $\frac{q^2}{2m} |\vec{A}(\vec{r}_i, t)|^2$ (with $q = m = 1$), representing diamagnetic interactions.
    """
    def inner(t, r, args):
        # ~ A^2
        q = m = 1
        return jnp.diag(q**2 / m * 0.5 * jnp.sum(vector_potential(t)**2, axis=1))
    return inner

DipoleGauge(illumination, use_rwa=False, intra_only=False)

Dipole gauge coupling to an external electric field, represented as \(\vec{E} \cdot \hat{\vec{P}}\). The dipole / polarization operator is defined by \(P^{c}_{ij} = \langle i|\hat{r}_c|j\rangle\), where \(i,j\) correspond to localized (TB) orbitals, such that \(\hat{r}_c|i\rangle = r^{c}_{i}|i\rangle\) in the absence of dipole transitions.

Parameters:
  • illumination (callable) –

    Function that returns the electric field at a given time, either as a uniform field of shape (3,) or per orbital as an array of shape (N, 3).

  • use_rwa (bool, default: False ) –

    If True, uses the rotating wave approximation, which simplifies the calculation by keeping only resonant terms.

  • intra_only (bool, default: False ) –

    If True, subtracts the diagonal of the potential matrix, keeping only the off-diagonal couplings between different orbitals.

Returns:
  • Function

    Computes the electric potential based on the given illumination and options for RWA and intramolecular interactions.

Source code in src/granad/potentials.py
def DipoleGauge(illumination, use_rwa = False, intra_only = False):
    r"""Dipole gauge coupling to an external electric field, represented as $\vec{E} \cdot \hat{\vec{P}}$. The dipole / polarization operator is
    defined by $P^{c}_{ij} = \langle i|\hat{r}_c|j\rangle$, where $i,j$ correspond to localized (TB) orbitals, such that $\hat{r}_c|i\rangle = r^{c}_{i}|i\rangle$ in the absence of dipole transitions.

    Args:
        illumination (callable): Function that returns the electric field at a given time, either uniform with shape (3,) or per orbital with shape (N, 3).
        use_rwa (bool): If True, uses the rotating wave approximation, which simplifies the calculation by keeping only resonant terms.
        intra_only (bool): If True, subtracts the diagonal of the potential matrix, keeping only the off-diagonal couplings between different orbitals.

    Returns:
        Function: Computes the electric potential based on the given illumination and options for RWA and intramolecular interactions.
    """
    def electric_potential(t, r, args):
        return jnp.einsum(einsum_string, args.dipole_operator, illumination(t).real)

    def electric_potential_rwa(t, r, args):
        # the missing real part is crucial here! the RWA (for real dipole moments) makes the fields complex and divides by 2
        total_field_potential = jnp.einsum(einsum_string, args.dipole_operator, illumination(t, real = False))

        # Get the indices for the lower triangle, excluding the diagonal
        lower_indices = jnp.tril_indices(total_field_potential.shape[0], -1)

        # Replace elements in the lower triangle with their complex conjugates    
        tmp = total_field_potential.at[lower_indices].set( jnp.conj(total_field_potential[lower_indices]) )

        # make hermitian again
        return tmp - 1j*jnp.diag(tmp.diagonal().imag)

    maybe_diag = lambda f : f
    if intra_only == True:
        maybe_diag = lambda f : lambda t, r, args : f(t,r,args) - jnp.diag( f(t,r,args).diagonal() )

    einsum_string =  'Kij,K->ij' if illumination(0.).shape == (3,) else 'Kij,iK->ij'    

    if use_rwa == True:
        return maybe_diag(electric_potential_rwa)
    return maybe_diag(electric_potential)
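The einsum in DipoleGauge contracts the Cartesian index K of the dipole operator with the field. A minimal sketch for two orbitals 1.42 Å apart along x, with a hypothetical x-polarized transition element of 0.3 added to the dipole operator and a uniform static field of shape (3,), shows the E · P matrix and what intra_only=True removes:

```python
import jax.numpy as jnp

positions = jnp.array([[0.0, 0.0, 0.0], [1.42, 0.0, 0.0]])  # orbital positions in Å

# Dipole operator P[K, i, j]: orbital positions on the diagonal ...
P = jnp.zeros((3, 2, 2))
P = P.at[:, 0, 0].set(positions[0]).at[:, 1, 1].set(positions[1])
# ... plus a hypothetical x-polarized dipole transition between the two orbitals.
P = P.at[0, 0, 1].set(0.3).at[0, 1, 0].set(0.3)

E = jnp.array([0.1, 0.0, 0.0])               # uniform field, shape (3,)
V = jnp.einsum('Kij,K->ij', P, E)            # E . P, as in the uniform-field branch
V_off_diagonal = V - jnp.diag(V.diagonal())  # the subtraction applied when intra_only=True
print(V)               # approximately [[0, 0.03], [0.03, 0.142]]
print(V_off_diagonal)  # approximately [[0, 0.03], [0.03, 0]]
```

For a field given per orbital with shape (N, 3), the source switches to 'Kij,iK->ij', so each row i is weighted by the field at orbital i.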

DipolePulse(dipole_moment, source_location, omega=None, sigma=None, t0=0.0, kick=False, dt=None)

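Function to compute the potential due to a pulsed dipole. The potential can optionally include a 'kick', which is an instantaneous spike at a specific time. Without a kick, the time dependence is \(\cos(\omega t)\exp(-(t-t_0)^2/\sigma^2)\), normalized by its integral over \(t \ge 0\). If the dipole is placed at a position occupied by orbitals, its contribution there is set to zero.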

Parameters:
  • dipole_moment

    Vector representing the dipole moment in xyz-components.

  • source_location

    Location of the source of the dipole in xyz-coordinates.

  • omega

    Angular frequency of the oscillation (default is None).

  • sigma

    Width \(\sigma\) of the temporal Gaussian envelope (default is None).

  • t0

    Time at which the pulse is centered (default is 0.0).

  • kick

    If True, the dipole acts only at t = t0, as a rectangular kick of height 1/dt (default is False); omega and sigma are ignored.

  • dt

    Duration of the kick, required when kick=True. Provide the minimum time step of your simulation.

Returns:
  • Function that computes the dipole potential at a given time and location, with adjustments for distance and orientation relative to the dipole.

Note

Recommended only with solver=diffrax.Dopri8.

Source code in src/granad/potentials.py
def DipolePulse( dipole_moment, source_location, omega = None, sigma = None, t0 = 0.0, kick = False, dt = None ):
    """Function to compute the potential due to a pulsed dipole. The potential can optionally include a 'kick' which is an instantaneous spike at a specific time.
    If the dipole is placed at a position occupied by orbitals, its contribution will be set to zero.

    Args:
        dipole_moment: Vector representing the dipole moment in xyz-components.
        source_location: Location of the source of the dipole in xyz-coordinates.
        omega: Angular frequency of the oscillation (default is None).
        sigma: Width of the temporal Gaussian envelope exp(-(t - t0)**2 / sigma**2) (default is None).
        t0: Time at which the pulse is centered (default is 0.0).
        kick: If True, the dipole acts only at t = t0, as a rectangular kick of height 1/dt (default is False); omega and sigma are ignored.
        dt: Duration of the kick, required when kick=True. Provide the minimum time step of your simulation.

    Returns:
        Function that computes the dipole potential at a given time and location, with adjustments for distance and orientation relative to the dipole.

    Note:
       Recommended only with solver=diffrax.Dopri8.
    """
    from scipy.integrate import quad
    if kick and dt is None:
        raise ValueError("When kick=True, you must specify dt.")

    loc = jnp.array( source_location )[:,None]
    dip = jnp.array( dipole_moment )

    if kick:
        # On a discrete time axis the kick is a rectangular pulse of width dt. Its height 1/dt keeps
        # the integrated strength of the kick independent of the time-axis discretization.
        f = lambda t: jnp.where(jnp.abs(t - t0) < 1e-10, 1.0 / dt, 0.0)
    else:
        # Gaussian-modulated cosine, normalized by its integral over t >= 0
        # (only evaluated here, so kick=True no longer requires omega and sigma)
        f_unscaled = lambda t : jnp.cos(omega * t) * jnp.exp( -(t-t0)**2 / sigma**2 )
        area, _ = quad(f_unscaled, 0, jnp.inf)
        f = lambda t : f_unscaled(t) / area

    def pot( t, r, args ):
        distances =args.dipole_operator.diagonal(axis1=-1, axis2=-2) - loc
        r_term = 14.39*(dip @ distances) / jnp.linalg.norm( distances, axis = 0 )**3
        return jnp.diag( jnp.nan_to_num(r_term) * f(t) )

    return pot
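The spatial factor in pot is the potential of a point dipole, 14.39 · p·(r - r0)/|r - r0|³, evaluated at each orbital. A minimal sketch with made-up orbital positions (the helper name `dipole_profile` is hypothetical; positions are passed directly instead of being read from the dipole operator's diagonal) shows the zeroing at the source and the 1/dt kick normalization:

```python
import jax.numpy as jnp


def dipole_profile(dipole_moment, source_location, orbital_positions):
    """Spatial factor from `pot` above, per orbital (hypothetical helper name).

    orbital_positions has shape (N, 3) in Å; the source reads the diagonal of the
    dipole operator, which has shape (3, N), hence the transpose.
    """
    loc = jnp.asarray(source_location, dtype=jnp.float32)[:, None]          # (3, 1)
    distances = jnp.asarray(orbital_positions, dtype=jnp.float32).T - loc  # (3, N)
    dip = jnp.asarray(dipole_moment, dtype=jnp.float32)
    r_term = 14.39 * (dip @ distances) / jnp.linalg.norm(distances, axis=0) ** 3
    return jnp.nan_to_num(r_term)  # 0/0 at an orbital sitting on the source becomes 0


positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
profile = dipole_profile([1.0, 0.0, 0.0], [0.0, 0.0, 0.0], positions)
print(profile)  # about [0, 3.5975, 0]: zero on the source, 14.39 * 2 / 2**3 on-axis, zero broadside

# The kick: height 1/dt for one time step, so its integral over that step is 1 for any dt.
dt, t0 = 0.01, 0.0
kick = lambda t: jnp.where(jnp.abs(t - t0) < 1e-10, 1.0 / dt, 0.0)
print(float(kick(t0)) * dt, float(kick(t0 + dt)))  # 1.0 and 0.0
```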

Induced()

Calculates the induced potential, which propagates the Coulomb effect of induced charges in the system according to \(\sim \sum_r q_r/|r-r'|\).

Returns:
  • Function

    Computes the induced potential at a given time and location based on charge propagation.

Source code in src/granad/potentials.py
def Induced():
    r"""Calculates the induced potential, which propagates the Coulomb effect of induced charges in the system according to $\sim \sum_r q_r/|r-r'|$.

    Returns:
        Function: Computes the induced potential at a given time and location based on charge propagation.
    """
    def inner(t, r, args):
        field = jnp.einsum("ijK,j->iK", args.propagator, -args.electrons*r.diagonal())
        return jnp.einsum("Kij,iK->ij", args.dipole_operator, field.real)
    return inner

Paramagnetic(vector_potential)

Paramagnetic Coulomb gauge coupling to an external vector potential represented as \(\sim \vec{A} \hat{\vec{v}}\).

Parameters:
  • vector_potential (callable) –

    Function that returns the vector potential at a given time, evaluated at each orbital, as an array of shape (N, 3) for N orbitals.

Returns:
  • Function

    Computes the interaction of the vector potential with the velocity operator.

Source code in src/granad/potentials.py
def Paramagnetic(vector_potential):
    r"""Paramagnetic Coulomb gauge coupling to an external vector potential represented as $\sim \vec{A} \hat{\vec{v}}$.

    Args:
        vector_potential (callable): Function that returns the vector potential at a given time, as an array of shape (N, 3) for N orbitals.

    Returns:
        Function: Computes the interaction of the vector potential with the velocity operator.
    """
    def inner(t, r, args):
        # ~ A p
        q = 1
        return -q * jnp.einsum("Kij, iK -> ij", args.velocity_operator, vector_potential(t))
    return inner

WavePulse(amplitudes, omega=None, sigma=None, t0=0.0, kick=False)

Function to compute the wave potential using amplitude modulation. This function creates a pulse with temporal Gaussian characteristics and can include an optional 'kick' which introduces an instantaneous amplitude peak.

Parameters:
  • amplitudes

    List of amplitudes for the wave components in xyz-directions.

  • omega

    Angular frequency of the wave oscillation (default is None).

  • sigma

    Width of the temporal Gaussian envelope \(\exp(-(t-t_0)^2/\sigma^2)\) (default is None).

  • t0

    Central time around which the pulse peaks (default is 0.0).

  • kick

    If True, the wave acts only at t = t0 with the given amplitudes (default is False); omega and sigma are ignored.

Returns:
  • Function pot(t, r, args) that returns the dipole operator contracted with the time-dependent amplitudes at time t. The argument r is not used.

Note

This function, when not kicked, computes the same term as Pulse.

Source code in src/granad/potentials.py
def WavePulse( amplitudes, omega = None, sigma = None, t0 = 0.0, kick = False ):
    """Function to compute a wave potential with a Gaussian envelope in time: the field amplitude is amplitudes * cos(omega * t) * exp(-(t - t0)**2 / sigma**2). An optional 'kick' instead applies the field only at the single instant t0.

    Args:
        amplitudes: List of amplitudes for the wave components in xyz-directions.
        omega: Angular frequency of the carrier oscillation cos(omega * t) (default is None). Must be given unless kick is True.
        sigma: Width of the Gaussian envelope exp(-(t - t0)**2 / sigma**2), i.e. sqrt(2) times its standard deviation (default is None). Must be given unless kick is True.
        t0: Central time around which the pulse peaks (default is 0.0).
        kick: If True, the field acts only at t = t0 (numerically, where |t - t0| < 1e-10), and omega and sigma are ignored (default is False).

    Returns:
        Function pot(t, r, args) that returns the dipole operator contracted with the time-dependent amplitudes at time t; the argument r is not used.

    Note:
       This function, when not kicked, computes the same term as `Pulse`.
    """
    amplitudes = jnp.array(amplitudes)

    f = lambda t : amplitudes * jnp.cos(omega * t) * jnp.exp( -(t-t0)**2 / sigma**2 ) 
    if kick:
        f = lambda t : amplitudes * (jnp.abs(t - t0) < 1e-10)

    def pot( t, r, args ):
        return jnp.einsum('Kij,K->ij', args.dipole_operator, f(t))

    return pot
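A NumPy sketch of the time profile `f(t)` built inside `WavePulse`, useful for checking pulse parameters before a simulation. `wave_envelope` is a hypothetical helper, not part of granad; it mirrors the two branches of the listing. Setting `omega=0` isolates the Gaussian envelope: at `t = t0 + sigma` the amplitude falls to `exp(-1)` of its peak, so `sigma` is the 1/e half-width rather than a standard deviation.

```python
import numpy as np

def wave_envelope(amplitudes, omega=None, sigma=None, t0=0.0, kick=False):
    """Return f(t), the time-dependent amplitude vector that pot contracts with the dipole operator."""
    amplitudes = np.asarray(amplitudes, dtype=float)
    if kick:
        # Nonzero only when t coincides with t0 to within 1e-10; omega and sigma are ignored
        return lambda t: amplitudes * (abs(t - t0) < 1e-10)
    # Carrier cos(omega * t) times a Gaussian envelope of width sigma centred on t0
    return lambda t: amplitudes * np.cos(omega * t) * np.exp(-(t - t0) ** 2 / sigma**2)

pulse = wave_envelope([1.0, 0.0, 0.0], omega=0.0, sigma=2.0, t0=5.0)
kick = wave_envelope([0.0, 1.0, 0.0], t0=5.0, kick=True)
```

Because the kick branch compares |t - t0| against 1e-10, it only contributes if t0 coincides with a time at which the solver actually evaluates the potential.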

DecoherenceTime()

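Function for modelling dissipation in the relaxation-time approximation. The returned function evaluates -(r - args.stationary_density_matrix) * args.relaxation_rate, so the density matrix r relaxes exponentially toward the stationary density matrix.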

Source code in src/granad/dissipators.py
def DecoherenceTime():
    """Function for modelling dissipation according to the relaxation approximation.
    """
    return lambda t,r,args: -(r - args.stationary_density_matrix) * args.relaxation_rate
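The relaxation term can be checked against its closed-form solution: with a scalar rate γ, each element of r - r_stationary decays as exp(-γ t). The sketch below integrates the same lambda with small explicit Euler steps. The `SimpleNamespace` holding `stationary_density_matrix` and `relaxation_rate` is a hypothetical stand-in for granad's simulation arguments, and the scalar rate is an assumption; with a matrix-valued rate the `*` is elementwise, so each element relaxes with its own rate.

```python
import numpy as np
from types import SimpleNamespace

# Same right-hand side as the listing: relax r toward the stationary density matrix
decoherence = lambda t, r, args: -(r - args.stationary_density_matrix) * args.relaxation_rate

# Hypothetical stand-in for the simulation arguments (only the two fields used here)
args = SimpleNamespace(
    stationary_density_matrix=np.diag([1.0, 0.0]),
    relaxation_rate=0.5,  # assumed scalar rate
)
r0 = np.array([[0.5, 0.5], [0.5, 0.5]])

# Integrate with small explicit Euler steps up to t = 2
dt, n_steps = 1e-4, 20_000
r = r0.copy()
for n in range(n_steps):
    r = r + dt * decoherence(n * dt, r, args)

# Exact solution: the deviation from the stationary state decays as exp(-rate * t)
deviation0 = r0 - args.stationary_density_matrix
exact = args.stationary_density_matrix + deviation0 * np.exp(-args.relaxation_rate * 2.0)
```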

SaturationLindblad(saturation)

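Function for modelling dissipation according to the saturated Lindblad equation as detailed in Pelc et al. The argument saturation is the saturation functional: it is applied elementwise to the occupations of the energy eigenstates and scales the corresponding columns of the relaxation-rate matrix. If identity is selected as the saturation functional, the model represents the canonical Lindblad relaxation.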

Source code in src/granad/dissipators.py
def SaturationLindblad(saturation):
    """Function for modelling dissipation according to the saturated Lindblad equation as detailed in [Pelc et al.](https://link.aps.org/doi/10.1103/PhysRevA.109.022237).
    The argument stands for the saturation functional. If identity is selected as the saturation functional, the model represents the canonical Lindblad relaxation.
    """
    saturation = jax.vmap(saturation, 0, 0)

    def inner(t, r, args):
        # convert rho to energy basis
        r = args.eigenvectors.conj().T @ r @ args.eigenvectors

        # extract occupations
        diag = jnp.diag(r) * args.electrons

        # apply the saturation functional to turn off elements in the gamma matrix
        gamma = args.relaxation_rate.astype(complex) * saturation(diag)[None, :]

        a = jnp.diag(gamma.T @ jnp.diag(r))
        mat = jnp.diag(jnp.sum(gamma, axis=1))
        b = -1 / 2 * (mat @ r + r @ mat)
        val = a + b

        return args.eigenvectors @ val @ args.eigenvectors.conj().T

    return inner
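Whatever saturation functional is chosen, the gain term (population flowing into level i from every level j at rate gamma[j, i]) and the loss term (the anticommutator with each level's total out-rate) cancel in the trace, so the right-hand side should be traceless and, for a Hermitian input, Hermitian. The NumPy transcription below checks both properties. The logistic `saturation` function, the random rate matrix and `saturated_lindblad_rhs` are hypothetical choices for illustration, not functions shipped with granad.

```python
import numpy as np

def saturated_lindblad_rhs(r, eigenvectors, relaxation_rate, electrons, saturation):
    """NumPy transcription of the listing's inner(t, r, args)."""
    # Rotate the density matrix into the energy eigenbasis
    rho = eigenvectors.conj().T @ r @ eigenvectors
    # Occupations of the energy levels
    occupations = np.diag(rho) * electrons
    # Scale column j of the rate matrix by saturation(occupation of level j)
    gamma = relaxation_rate.astype(complex) * saturation(occupations)[None, :]
    # Gain: level i is fed from every level j at rate gamma[j, i]
    gain = np.diag(gamma.T @ np.diag(rho))
    # Loss: anticommutator with the total out-rate of each level
    out_rates = np.diag(gamma.sum(axis=1))
    loss = -0.5 * (out_rates @ rho + rho @ out_rates)
    # Rotate back to the original basis
    return eigenvectors @ (gain + loss) @ eigenvectors.conj().T

rng = np.random.default_rng(1)
n = 4
h = rng.normal(size=(n, n))
_, vecs = np.linalg.eigh(h + h.T)                 # eigenbasis of a random symmetric "Hamiltonian"
m = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))
r = m @ m.conj().T
r /= np.trace(r)                                  # Hermitian, positive, unit trace
rates = rng.uniform(0.0, 1.0, size=(n, n))
np.fill_diagonal(rates, 0.0)

# Hypothetical saturation functional: smoothly switch off transitions into nearly full levels
saturation = lambda occ: 1.0 / (1.0 + np.exp(10.0 * (occ.real - 1.5)))
drho = saturated_lindblad_rhs(r, vecs, rates, electrons=2, saturation=saturation)
```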

DP54_solver(rhs_func, ts, d_ini, args, postprocesses)

Solves an ODE using the Dormand-Prince 5(4) method with JAX acceleration.

Parameters:
  • rhs_func (callable) –

    Function defining the ODE's right-hand side (dy/dt = rhs_func(t, y, args)).

  • ts (array) –

    Array of time points for the solution.

  • d_ini (array) –

    Initial condition(s) of the dependent variable(s).

  • args (tuple) –

    Additional arguments to pass to rhs_func.

  • postprocesses (list) –

    List of callable post-processing functions to apply to each step's solution.

Returns:
  • tuple

    Final state after the last step and the list of post-processed results, each stacked along a leading time axis by jax.lax.scan.

Notes

Implements a fixed-step Dormand-Prince 5(4) scheme with seven stages and 5th-order accuracy. The last stage is reused as the first stage of the next step (FSAL), and the embedded 4th-order error estimate is not used, so there is no adaptive step-size control. Optimized with JAX's JIT compilation and scan functionality.

Source code in src/granad/_numerics.py
def DP54_solver(rhs_func, ts, d_ini, args, postprocesses):
    """Solves an ODE using the Dormand-Prince 5(4) method with JAX acceleration.

    Args:
        rhs_func (callable): Function defining the ODE's right-hand side (dy/dt = rhs_func(t, y, args)).
        ts (array): Array of time points for the solution.
        d_ini (array): Initial condition(s) of the dependent variable(s).
        args (tuple): Additional arguments to pass to rhs_func.
        postprocesses (list): List of callable post-processing functions to apply to each step's solution.

    Returns:
        tuple: Final state after the last step and the list of post-processed results, each stacked along a leading time axis by jax.lax.scan.

    Notes:
        Implements a fixed-step Dormand-Prince 5(4) scheme with seven stages and 5th-order accuracy.
        The last stage is reused as the first stage of the next step (FSAL); the embedded
        4th-order error estimate is not used, so there is no adaptive step-size control.
        Optimized with JAX's JIT compilation and scan functionality.
    """
    dt = ts[1] - ts[0]
    def rho_nxt_DP54(state, t, rhs_func, args):
        # Coefficients from Dormand-Prince 5(4) tableau
        #Copied from: https://numerary.readthedocs.io/en/latest/dormand-prince-method.html
        a21 = 1/5
        a31, a32 = 3/40, 9/40
        a41, a42, a43 = 44/45, -56/15, 32/9
        a51, a52, a53, a54 = 19372/6561, -25360/2187, 64448/6561, -212/729
        a61, a62, a63, a64, a65 = 9017/3168, -355/33, 46732/5247, 49/176, -5103/18656
        a71, a72, a73, a74, a75, a76 = 35/384, 0, 500/1113, 125/192, -2187/6784, 11/84
        b1, b2, b3, b4, b5, b6, b7 = 35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0


        rho, k1 = state
        #k1 = rhs_func(t, rho, args)
        k2 = rhs_func(t + dt * a21, rho + dt * a21 * k1, args)
        k3 = rhs_func(t + dt * 3/10, rho + dt * (a31 * k1 + a32 * k2), args)
        k4 = rhs_func(t + dt * 4/5, rho + dt * (a41 * k1 + a42 * k2 + a43 * k3), args)
        k5 = rhs_func(t + dt * 8/9, rho + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4), args)
        k6 = rhs_func(t + dt * 1.0, rho + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5), args)
        k7 = rhs_func(t + dt * 1.0, rho + dt * (a71 * k1 + a72 * k2 + a73 * k3 + a74 * k4 + a75 * k5 + a76 * k6), args)

        rho_nxt = rho + dt * (b1 * k1 + b2 * k2 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6 + b7 * k7)
        state=(rho_nxt, k7)

        return state, [p(rho_nxt, args) for p in postprocesses]

    jitted_rho_nxt_DP54 = jax.jit(rho_nxt_DP54, static_argnums=(2,))

    @jax.jit
    def jax_scan_compatible_rho_nxt_DP54(state, t):
        return rho_nxt_DP54(state, t, rhs_func=rhs_func, args=args)
    k1 = rhs_func(ts[0], d_ini, args)
    initial_state=(d_ini, k1 )
    final_state, postprocessed_rhos = jax.lax.scan(jax_scan_compatible_rho_nxt_DP54, initial_state, ts)    
    return final_state[0], postprocessed_rhos
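A plain-NumPy sketch of the same fixed-step Dormand-Prince scheme, written with the nodes c_i = (0, 1/5, 3/10, 4/5, 8/9, 1, 1) of the standard tableau and the first-same-as-last reuse of the seventh stage. Like the listing, it takes one step per entry of `ts`. The second test integrates y' = cos(t), which depends only on time, so it exercises the nodes c_i; with c5 = 1 in place of 8/9 the weights would no longer satisfy sum(b_i c_i) = 1/2, and that test would degrade to first-order accuracy. `dp5_step` and `dp5_solve` are hypothetical names.

```python
import numpy as np

# Dormand-Prince nodes c_i, stage coefficients a_ij and 5th-order weights b_i
C = [0.0, 1/5, 3/10, 4/5, 8/9, 1.0, 1.0]
A = [
    [],
    [1/5],
    [3/40, 9/40],
    [44/45, -56/15, 32/9],
    [19372/6561, -25360/2187, 64448/6561, -212/729],
    [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656],
    [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84],
]
B = [35/384, 0.0, 500/1113, 125/192, -2187/6784, 11/84, 0.0]

def dp5_step(f, t, y, k1, dt):
    """One fixed-size step; returns (y_next, k7), where k7 = f(t + dt, y_next) because A[6] == B[:6]."""
    ks = [k1]
    for i in range(1, 7):
        y_stage = y + dt * sum(a * k for a, k in zip(A[i], ks))
        ks.append(f(t + C[i] * dt, y_stage))
    y_next = y + dt * sum(b * k for b, k in zip(B, ks))
    return y_next, ks[-1]

def dp5_solve(f, ts, y0):
    """Like the listing, take one step per entry of ts, so the result is at ts[-1] + dt."""
    dt = ts[1] - ts[0]
    y, k = y0, f(ts[0], y0)
    for t in ts:
        y, k = dp5_step(f, t, y, k, dt)
    return y

ts = np.arange(20) * 0.05  # 20 steps of 0.05, so the final state is at t = 1.0
decay = dp5_solve(lambda t, y: -y, ts, 1.0)
quadrature = dp5_solve(lambda t, y: np.cos(t), ts, 0.0)
```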

Euler_solver(rhs_func, ts, d_ini, args, postprocesses)

Solves an ODE using the explicit Euler method with JAX acceleration.

Parameters:
  • rhs_func (callable) –

    Function defining the ODE's right-hand side (dy/dt = rhs_func(t, y, args)).

  • ts (array) –

    Array of time points for the solution.

  • d_ini (array) –

    Initial condition(s) of the dependent variable(s).

  • args (tuple) –

    Additional arguments to pass to rhs_func.

  • postprocesses (list) –

    List of callable post-processing functions to apply to each step's solution.

Returns:
  • tuple

    Final state after the last step and the list of post-processed results, each stacked along a leading time axis by jax.lax.scan.

Notes

Implements the explicit Euler method: y[n+1] = y[n] + dt * rhs_func(t[n], y[n], args). Uses JAX's JIT compilation and scan for efficient computation.

Source code in src/granad/_numerics.py
def Euler_solver(rhs_func, ts, d_ini, args, postprocesses):
    """Solves an ODE using the explicit Euler method with JAX acceleration.

    Args:
        rhs_func (callable): Function defining the ODE's right-hand side (dy/dt = rhs_func(t, y, args)).
        ts (array): Array of time points for the solution.
        d_ini (array): Initial condition(s) of the dependent variable(s).
        args (tuple): Additional arguments to pass to rhs_func.
        postprocesses (list): List of callable post-processing functions to apply to each step's solution.

    Returns:
        tuple: Final state after the last step and the list of post-processed results, each stacked along a leading time axis by jax.lax.scan.

    Notes:
        Implements the explicit Euler method: y[n+1] = y[n] + dt * rhs_func(t[n], y[n], args).
        Uses JAX's JIT compilation and scan for efficient computation.
    """
    dt = ts[1] - ts[0]
    def rho_nxt_Euler(rho, t, rhs_func, args):
        rho_nxt = rho + rhs_func(t, rho, args) * dt
        return rho_nxt, [p(rho_nxt, args) for p in postprocesses]

    jitted_rho_nxt_Euler = jax.jit(rho_nxt_Euler, static_argnums=(2,))

    @jax.jit
    def jax_scan_compatible_rho_nxt_Euler(rho, t):
        return rho_nxt_Euler(rho, t, rhs_func=rhs_func, args=args)

    return jax.lax.scan(jax_scan_compatible_rho_nxt_Euler, d_ini, ts)
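`jax.lax.scan(f, init, xs)` returns the final carry together with the per-step outputs stacked along a new leading axis, which is what the solvers' tuple return value holds. The pure-Python sketch below spells this out with a stand-in `scan` loop (following the reference semantics given in the JAX documentation) and a NumPy transcription of the Euler step. Because one step is taken for every entry of `ts`, the final state corresponds to `ts[-1] + dt`, and post-processed entry n to `ts[n] + dt`. `scan` and `euler_solver` here are illustrative re-implementations, not granad or JAX functions.

```python
import numpy as np

def scan(f, init, xs):
    """Pure-Python stand-in for jax.lax.scan: returns (final carry, stacked outputs)."""
    carry, outputs = init, []
    for x in xs:
        carry, out = f(carry, x)
        outputs.append(out)
    # JAX stacks each output leaf along a new leading axis; mimic that for a list of outputs
    stacked = [np.stack(column) for column in zip(*outputs)]
    return carry, stacked

def euler_solver(rhs_func, ts, d_ini, args, postprocesses):
    dt = ts[1] - ts[0]
    def step(rho, t):
        rho_nxt = rho + rhs_func(t, rho, args) * dt
        return rho_nxt, [p(rho_nxt, args) for p in postprocesses]
    return scan(step, d_ini, ts)

ts = np.arange(100) * 0.01  # 100 steps; the last state is at ts[-1] + dt = 1.0
final, (history,) = euler_solver(lambda t, y, args: -y, ts, 1.0, None, [lambda rho, args: rho])
```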

RK45_solver(rhs_func, ts, d_ini, args, postprocesses)

Solves an ODE using the 4th-order Runge-Kutta (RK4) method with JAX acceleration.

Parameters:
  • rhs_func (callable) –

    Function defining the ODE's right-hand side (dy/dt = rhs_func(t, y, args)).

  • ts (array) –

    Array of time points for the solution.

  • d_ini (array) –

    Initial condition(s) of the dependent variable(s).

  • args (tuple) –

    Additional arguments to pass to rhs_func.

  • postprocesses (list) –

    List of callable post-processing functions to apply to each step's solution.

Returns:
  • tuple

    Final state after the last step and the list of post-processed results, each stacked along a leading time axis by jax.lax.scan.

Notes

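Implements the classical RK4 method with four stages (k1, k2, k3, k4). Next state: y[n+1] = y[n] + (dt/6) * (k1 + 2*k2 + 2*k3 + k4). Despite the function name, there is no embedded error estimate and no adaptive step-size control. Optimized with JAX's JIT compilation and scan functionality.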

Source code in src/granad/_numerics.py
def RK45_solver(rhs_func, ts, d_ini, args, postprocesses):
    """Solves an ODE using the 4th-order Runge-Kutta (RK4) method with JAX acceleration.

    Args:
        rhs_func (callable): Function defining the ODE's right-hand side (dy/dt = rhs_func(t, y, args)).
        ts (array): Array of time points for the solution.
        d_ini (array): Initial condition(s) of the dependent variable(s).
        args (tuple): Additional arguments to pass to rhs_func.
        postprocesses (list): List of callable post-processing functions to apply to each step's solution.

    Returns:
        tuple: Final state after the last step and the list of post-processed results, each stacked along a leading time axis by jax.lax.scan.

    Notes:
        Implements the classical RK4 method with four stages (k1, k2, k3, k4).
        Next state: y[n+1] = y[n] + (dt/6) * (k1 + 2*k2 + 2*k3 + k4).
        Despite the name, there is no embedded error estimate or adaptive step-size control.
        Optimized with JAX's JIT compilation and scan functionality.
    """
    dt = ts[1] - ts[0]

    def rho_nxt_RK45(state, t, rhs_func, args):
        rho, k1 = state
        k1 = rhs_func(t, rho, args)
        k2 = rhs_func(t + dt/2, rho + (dt/2) * k1, args)
        k3 = rhs_func(t + dt/2, rho + (dt/2) * k2, args)
        k4 = rhs_func(t + dt, rho + dt * k3, args)
        rho_nxt = rho + dt/6 * (k1 + 2*k2 + 2*k3 + k4)

        state=(rho_nxt, k4)
        return state, [p(rho_nxt, args) for p in postprocesses]

    jitted_rho_nxt_RK45 = jax.jit(rho_nxt_RK45, static_argnums=(2,))

    @jax.jit
    def jax_scan_compatible_rho_nxt_RK45(state, t):
        return rho_nxt_RK45(state, t, rhs_func=rhs_func, args=args)

    k1 = rhs_func(ts[0], d_ini, args)
    initial_state=(d_ini, k1 )

    final_state, postprocessed_rhos=jax.lax.scan(jax_scan_compatible_rho_nxt_RK45, initial_state, ts)
    return final_state[0], postprocessed_rhos
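A quick way to confirm the fourth-order behaviour described in the Notes is to halve the step size and compare errors on y' = -y: the error should shrink by roughly 2^4 = 16. `rk4_step` and `rk4_error` are hypothetical helpers written with NumPy, not part of granad.

```python
import numpy as np

def rk4_step(f, t, y, dt):
    # Classical four-stage Runge-Kutta update
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt / 2 * k1)
    k3 = f(t + dt / 2, y + dt / 2 * k2)
    k4 = f(t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def rk4_error(n_steps):
    """Absolute error of RK4 for y' = -y, y(0) = 1, integrated to t = 1."""
    dt, y = 1.0 / n_steps, 1.0
    for n in range(n_steps):
        y = rk4_step(lambda t, u: -u, n * dt, y, dt)
    return abs(y - np.exp(-1.0))

# Halving dt should shrink the error by about 2**4 = 16 for a 4th-order method
ratio = rk4_error(10) / rk4_error(20)
```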

fraction_periodic(signal, threshold=0.01)

Estimates the fraction of a periodic component in a given signal by analyzing the deviation of the cumulative mean from its median value. The periodicity is inferred based on the constancy of the cumulative mean of the absolute value of the signal.

Parameters:
  • signal (Array) –

    A 1D array representing the signal of interest.

  • threshold (float, default: 0.01 ) –

    A threshold value to determine the significance level of deviation from periodicity. Defaults to 0.01.

Returns:
  • float

    A ratio representing the fraction of the signal that is considered periodic, based on the specified threshold.

Source code in src/granad/_numerics.py
def fraction_periodic(signal, threshold=1e-2):
    """
    Estimates the fraction of a periodic component in a given signal by analyzing the deviation of the cumulative mean from its median value. The periodicity is inferred based on the constancy of the cumulative mean of the absolute value of the signal.

    Parameters:
        signal (jax.Array): A 1D array representing the signal of interest.
        threshold (float, optional): A threshold value to determine the significance level of deviation from periodicity. Defaults to 0.01.

    Returns:
        float: A ratio representing the fraction of the signal that is considered periodic, based on the specified threshold.

    """

    # signal is periodic => abs(signal) is periodic
    cum_sum = jnp.abs(signal).cumsum()

    # cumulative mean of periodic signal is constant
    cum_mean = cum_sum / jnp.arange(1, len(signal) + 1)

    # if cumulative mean doesn't move anymore, we have a lot of periodic signal
    med = jnp.median(cum_mean)
    deviation = jnp.abs(med - cum_mean) / med

    # approximate admixture of periodic signal
    return (deviation < threshold).sum().item() / len(signal)
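A NumPy transcription of the estimator applied to two test signals, assuming NumPy's `cumsum` and `median` behave like their `jax.numpy` counterparts here: a sine sampled over 100 periods, whose cumulative mean of |signal| settles near 2/π after a few periods, and a linear ramp, whose cumulative mean never settles. The first should yield a fraction close to 1, the second close to 0.

```python
import numpy as np

def fraction_periodic(signal, threshold=1e-2):
    # Cumulative mean of |signal|; it becomes constant for a periodic signal
    cum_mean = np.abs(signal).cumsum() / np.arange(1, len(signal) + 1)
    med = np.median(cum_mean)
    # Relative deviation of every cumulative mean from the median value
    deviation = np.abs(med - cum_mean) / med
    return (deviation < threshold).sum() / len(signal)

t = np.linspace(0.0, 100.0, 10_000, endpoint=False)
frac_sine = fraction_periodic(np.sin(2 * np.pi * t))  # 100 full periods
frac_ramp = fraction_periodic(t)                      # steadily growing signal
```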

get_coulomb_field_to_from(source_positions, target_positions, compute_at=None)

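Calculate the Coulomb field, per unit charge and without the Coulomb constant, that point charges located at source_positions produce at target_positions: (r_target - r_source) / |r_target - r_source|^3, with coinciding points contributing zero.

Parameters:
  • source_positions (array) –

    An (n_source, 3) array of source positions.

  • target_positions (array) –

    An (n_target, 3) array of target positions.

  • compute_at (optional) –

    Index expression (e.g. a boolean mask over target-source pairs) selecting the entries to compute; all other entries are zero. If None (the default), the function returns None.

Returns:
  • array

    An (n_target, n_source, 3) array where entry [t, s] is the contribution of source s at target t.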

Source code in src/granad/_numerics.py
def get_coulomb_field_to_from(source_positions, target_positions, compute_at=None):
    """
    Calculate the Coulomb field, per unit charge, of point charges located at
    `source_positions` on points at `target_positions`.

    **Args:**
    - source_positions (array): An (n_source, 3) array of source positions.
    - target_positions (array): An (n_target, 3) array of target positions.
    - compute_at (optional): Index expression selecting the entries to compute;
          all other entries are zero. If None, the function returns None.

    **Returns:**
    - array: An (n_target, n_source, 3) array where entry [t, s] is the contribution
          of source s at target t.
    """
    if compute_at is None:
        return None

    # Calculate vector differences between each pair of source and target positions
    distance_vector = target_positions[:, None, :] - source_positions
    # Compute the norm of these vectors
    norms = jnp.linalg.norm(distance_vector, axis=-1)
    # Safe division by the cube of the norm
    one_over_distance_cubed = jnp.where(norms > 0, 1 / norms**3, 0)
    # Calculate the contributions
    coulomb_field_to_from = distance_vector * one_over_distance_cubed[:, :, None]
    # final array
    coulomb_field_to_from_final = jnp.zeros_like(coulomb_field_to_from)        
    # include only contributions where desired
    return coulomb_field_to_from_final.at[compute_at].set(
        coulomb_field_to_from[compute_at]
    )
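A NumPy sketch of the same geometry factor, illustrating the (n_target, n_source, 3) layout implied by `target_positions[:, None, :] - source_positions`, the zero entry for coinciding points, and the masking by `compute_at`. It assumes `compute_at` is a boolean mask over (target, source) pairs, which is one valid index expression; `coulomb_field_to_from` is a hypothetical name.

```python
import numpy as np

def coulomb_field_to_from(source_positions, target_positions, compute_at=None):
    if compute_at is None:
        return None
    # (n_target, n_source, 3): vector pointing from each source to each target
    distance_vector = target_positions[:, None, :] - source_positions[None, :, :]
    norms = np.linalg.norm(distance_vector, axis=-1)
    # 1/|r|^3 where the distance is nonzero, 0 for coinciding points
    with np.errstate(divide="ignore"):
        inv_cubed = np.where(norms > 0, 1.0 / norms**3, 0.0)
    field = distance_vector * inv_cubed[:, :, None]
    # Keep only the requested entries
    out = np.zeros_like(field)
    out[compute_at] = field[compute_at]
    return out

sources = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
targets = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
mask = np.ones((3, 2), dtype=bool)
out = coulomb_field_to_from(sources, targets, mask)

# Drop the (target 1, source 0) pair
mask_partial = mask.copy()
mask_partial[1, 0] = False
partial = coulomb_field_to_from(sources, targets, mask_partial)
```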

iterate(func)

A decorator that allows a function to iterate over list inputs.

Functionality:

  1. If one or more of the function's keyword arguments is a list, the function is executed for every combination of elements. Only keyword arguments are expanded; list-valued positional arguments are passed through unchanged.
  2. If multiple list inputs are present, the computation follows a Cartesian product pattern:
       - For the 1st element of list A, iterate over all elements of list B.
       - For the 2nd element of list A, iterate over all elements of list B.
       - And so on.
  3. The results are reshaped into a nested structure matching the input lists.
  4. If the function returns multiple values (a tuple), each return value is separately structured into its own nested list.

Parameters:
  • func (callable) –

    The function to be decorated.

Returns:
  • callable

    The decorated function that processes list inputs correctly.

Source code in src/granad/_numerics.py
def iterate(func):
    """
    A decorator that allows a function to iterate over list inputs.

    Functionality:
    1. If one or more of the function’s keyword arguments is a list, the function is executed for every combination of elements (positional arguments are never expanded).
    2. If multiple list inputs are present, the computation follows a Cartesian product pattern:
       - For the 1st element of list A, iterate over all elements of list B.
       - For the 2nd element of list A, iterate over all elements of list B.
       - And so on.
    3. The results are reshaped into a nested structure matching the input lists.
    4. If the function returns multiple values (a tuple), each return value is separately structured into its own nested list.

    Parameters:
        func (callable): The function to be decorated.

    Returns:
        callable: The decorated function that processes list inputs correctly.
    """
    @functools.wraps(func)
    def inner(*args, **kwargs):
        result = []

        # Extract parameters that are lists
        dict_params = {key: values for key, values in kwargs.items() if isinstance(values, list)}
        shape = [len(values) for values in dict_params.values()]

        if dict_params:  # If any parameter is a list, perform Cartesian product iteration
            for combination in itertools.product(*dict_params.values()):
                new_kwargs = kwargs | dict(zip(dict_params.keys(), combination))
                result.append(func(*args, **new_kwargs))  # Call the original function

            # If function returns a tuple, separate results for each return value
            if isinstance(result[0], tuple):
                transposed_results = list(zip(*result))  # Split values across multiple lists
                return tuple(nest_result(list(r), shape) for r in transposed_results)
            else:
                return nest_result(result, shape)  # Single return value case

        return func(*args, **kwargs)  # Direct function call if no list parameters

    return inner
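The Cartesian-product order comes from `itertools.product`, whose rightmost iterable advances on every iteration. The flat result list therefore stores the combination (A[i], B[j]) at position i * len(B) + j, which is exactly the layout that `nest_result` cuts into equal chunks. A minimal check with two illustrative keyword lists:

```python
import itertools

a_values = [1, 2]           # first list-valued keyword argument ("list A")
b_values = ["x", "y", "z"]  # second list-valued keyword argument ("list B")

combos = list(itertools.product(a_values, b_values))
# The last list varies fastest, so (a_values[i], b_values[j]) sits at i * len(b_values) + j
```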

nest_result(result_list, shape)

Recursively reshapes a flat list into a nested list structure matching the Cartesian product shape.

Parameters:
  • result_list (list) –

    The flat list of computed results.

  • shape (list) –

    A list representing the shape of the Cartesian product.

Returns:
  • list

    A nested list following the Cartesian product shape.

Source code in src/granad/_numerics.py
def nest_result(result_list, shape):
    """
    Recursively reshapes a flat list into a nested list structure matching the Cartesian product shape.

    Parameters:
        result_list (list): The flat list of computed results.
        shape (list): A list representing the shape of the Cartesian product.

    Returns:
        list: A nested list following the Cartesian product shape.
    """
    if len(shape) == 1:
        return result_list  # Base case: last dimension

    chunk_size = len(result_list) // shape[0]  # Compute chunk size dynamically
    return [
        nest_result(result_list[i * chunk_size:(i + 1) * chunk_size], shape[1:]) 
        for i in range(shape[0])
    ]
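A condensed, self-contained transcription of `iterate` and `nest_result` (using integer division `//` instead of `int(a / b)`), applied to two hypothetical functions. `power_and_sum` returns a tuple, so each return value gets its own nested list with one level per list-valued keyword; `scaled_length` shows that a list passed positionally is not expanded.

```python
import functools
import itertools

def nest_result(result_list, shape):
    # Split the flat list into shape[0] equal chunks, then recurse on the remaining shape
    if len(shape) == 1:
        return result_list
    chunk = len(result_list) // shape[0]
    return [nest_result(result_list[i * chunk:(i + 1) * chunk], shape[1:]) for i in range(shape[0])]

def iterate(func):
    @functools.wraps(func)
    def inner(*args, **kwargs):
        # Only keyword arguments whose value is a list are expanded
        list_params = {k: v for k, v in kwargs.items() if isinstance(v, list)}
        if not list_params:
            return func(*args, **kwargs)
        shape = [len(v) for v in list_params.values()]
        results = [
            func(*args, **(kwargs | dict(zip(list_params, combo))))
            for combo in itertools.product(*list_params.values())
        ]
        if isinstance(results[0], tuple):
            # One nested structure per return value
            return tuple(nest_result(list(r), shape) for r in zip(*results))
        return nest_result(results, shape)
    return inner

@iterate
def power_and_sum(base, exponent=1, offset=0):
    return base**exponent + offset, base + offset

@iterate
def scaled_length(items, factor=1):
    return len(items) * factor

powers, sums = power_and_sum(2, exponent=[1, 2, 3], offset=[0, 10])
# powers[i][j] uses exponent=[1, 2, 3][i] and offset=[0, 10][j]
```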

show_2d(orbs, show_tags=None, show_index=False, display=None, scale=False, cmap=None, circle_scale=1000.0, title=None, mode=None, indicate_atoms=False, grid=False)

Generates a 2D scatter plot representing the positions of orbitals in the xy-plane, with optional filtering, coloring, and sizing.

Parameters:
  • `orbs` (list) –

    List of orbital objects, each containing attributes such as 'tag' (for labeling) and 'position' (xy-coordinates).

  • `show_tags` (list of str) –

    Filters the orbitals to display based on their tags. Only orbitals with matching tags will be shown. If None, all orbitals are displayed.

  • `show_index` (bool) –

    If True, displays the index of each orbital next to its corresponding point on the plot.

  • `display` (array-like) –

    Data used to color and scale the points (e.g., eigenvector amplitudes). Each value corresponds to an orbital.

  • `scale` (bool) –

    If True, the values in display are normalized and their absolute values are used.

  • `cmap` (optional) –

    Colormap used for the scatter plot when display is provided. If None, a default colormap (bwr) is used.

  • `circle_scale` (float) –

    A scaling factor for the size of the scatter plot points. Larger values result in larger circles. Default is 1000.

  • `title` (str) –

    Custom title for the plot. If None, the default title "Orbital positions in the xy-plane" is used.

  • `mode` (str) –

    Determines the plotting style for orbitals when display is provided:

      - 'two-signed': Displays orbitals with a diverging colormap centered at zero, highlighting both positive and negative values symmetrically. The color limits are set by the maximum absolute value in the display array.
      - 'one-signed': Displays orbitals with the sequential Reds colormap (cmap is ignored). The color limits run from zero to the entry of largest magnitude, so only values sharing that entry's sign are resolved; values of the opposite sign are clipped.
      - None: Colors orbitals by the values in display (normalized if scale is True) and sizes them in proportion to their magnitude.

  • `indicate_atoms` (bool) –

    Show atoms as black dots if display is given, defaults to False.

  • `grid` (bool) –

    Shows grid, False by default.

Notes

If display is provided, the points are colored according to the values in the display array (and, in the default mode, sized by their magnitude; the signed modes use a fixed size of circle_scale / 10), and a color bar is added to the plot. If show_index is True, the indices of the orbitals are annotated next to their corresponding points. The axes are set to equal scaling, and grid lines are drawn only if grid is True.

Source code in src/granad/_plotting.py
@_plot_wrapper
def show_2d(orbs, show_tags=None, show_index=False, display = None, scale = False, cmap = None, circle_scale : float = 1e3, title = None, mode = None, indicate_atoms = False, grid = False):
    """
    Generates a 2D scatter plot representing the positions of orbitals in the xy-plane, with optional filtering, coloring, and sizing.

    Parameters:
        `orbs` (list): List of orbital objects, each containing attributes such as 'tag' (for labeling) and 'position' (xy-coordinates).
        `show_tags` (list of str, optional): Filters the orbitals to display based on their tags. Only orbitals with matching tags will be shown. If `None`, all orbitals are displayed.
        `show_index` (bool, optional): If `True`, displays the index of each orbital next to its corresponding point on the plot.
        `display` (array-like, optional): Data used to color and scale the points (e.g., eigenvector amplitudes). Each value corresponds to an orbital.
        `scale` (bool, optional): If `True`, the values in `display` are normalized and their absolute values are used.
        `cmap` (optional): Colormap used for the scatter plot when `display` is provided. If `None`, a default colormap (`bwr`) is used.
        `circle_scale` (float, optional): A scaling factor for the size of the scatter plot points. Larger values result in larger circles. Default is 1000.
        `title` (str, optional): Custom title for the plot. If `None`, the default title "Orbital positions in the xy-plane" is used.
        `mode` (str, optional): Determines the plotting style for orbitals when `display` is provided.
            - `'two-signed'`: Displays orbitals with a diverging colormap centered at zero, highlighting both positive and negative values symmetrically. 
                              The colormap is scaled such that its limits are set by the maximum absolute value in the `display` array.
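            - `'one-signed'`: Displays orbitals with the sequential `Reds` colormap (`cmap` is ignored). The color limits run from zero to the entry of largest magnitude,
                              so only values sharing that entry's sign are resolved; values of the opposite sign are clipped.
            - `None`: Colors orbitals by the values in `display` (normalized if `scale` is `True`) and sizes them in proportion to their magnitude.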
        `indicate_atoms` (bool, optional): Show atoms as black dots if `display` is given, defaults to False.
        `grid` (bool, optional): Shows grid, `False` by default.


    Notes:
        If `display` is provided, the points are colored according to the values in the `display` array (and, in the default mode, sized by their magnitude), and a color bar is added to the plot.
        If `show_index` is `True`, the indices of the orbitals are annotated next to their corresponding points.
        The axes are set to equal scaling, and grid lines are drawn only if `grid` is `True`.
    """

    # decide whether to take abs val and normalize
    def scale_vals( vals ):
        return jnp.abs(vals) / jnp.abs(vals).max() if scale else vals

    # Determine which tags to display
    if show_tags is None:
        show_tags = {orb.tag for orb in orbs}
    else:
        show_tags = set(show_tags)

    # Prepare data structures for plotting
    tags_to_pos, tags_to_idxs = defaultdict(list), defaultdict(list)
    for orb in orbs:
        if orb.tag in show_tags:
            tags_to_pos[orb.tag].append(orb.position)
            tags_to_idxs[orb.tag].append(orbs.index(orb))

    # Create plot
    fig, ax = plt.subplots()
    if display is not None:        
        cmap = plt.cm.bwr if cmap is None else cmap
        if mode == 'two-signed':
            display = display.real
            dmax = jnp.max(jnp.abs(display))            
            scatter = ax.scatter([orb.position[0] for orb in orbs], [orb.position[1] for orb in orbs], c = display, edgecolor='black', cmap=cmap, s = circle_scale / 10 )
            scatter.set_clim(-dmax, dmax)
        elif mode == 'one-signed':
            cmap = plt.cm.Reds
            display = display.real
            dmax = display[jnp.argmax(jnp.abs(display))]
            scatter = ax.scatter([orb.position[0] for orb in orbs], [orb.position[1] for orb in orbs], c = display, edgecolor='black', cmap=cmap, s = circle_scale / 10 )
            if(dmax<0):
                scatter.set_clim(dmax, 0)
            else:
                scatter.set_clim(0, dmax)
        else:
            colors = scale_vals(display)            
            scatter = ax.scatter([orb.position[0] for orb in orbs], [orb.position[1] for orb in orbs], c=colors, edgecolor='black', cmap=cmap, s = circle_scale*jnp.abs(display) )
        if indicate_atoms == True:
            ax.scatter([orb.position[0] for orb in orbs], [orb.position[1] for orb in orbs], color='black', s=10, marker='o')            
        cbar = fig.colorbar(scatter, ax=ax)

    else:
        # Color by tags if no display array is given
        unique_tags = list(set(orb.tag for orb in orbs))
        color_map = {tag: plt.get_cmap('tab10')(i / len(unique_tags)) for i, tag in enumerate(unique_tags)}
        for tag, positions in tags_to_pos.items():
            positions = jnp.array(positions)
            ax.scatter(positions[:, 0], positions[:, 1], label=tag, color=color_map[tag], edgecolor='white', alpha=0.7)
        plt.legend(title='Orbital Tags')

    # Optionally annotate points with their indexes
    if show_index:
        for orb in [orb for orb in orbs if orb.tag in show_tags]:
            pos = orb.position
            idx = orbs.index(orb)
            ax.annotate(str(idx), (pos[0], pos[1]), textcoords="offset points", xytext=(0,10), ha='center')

    # Finalize plot settings
    if title is not None:
        plt.title(title)
    plt.xlabel('X')
    plt.ylabel('Y')
    ax.grid(grid)
    ax.axis('equal')
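The two signed modes differ only in the color limits passed to `scatter.set_clim`. `color_limits` is a hypothetical plain-Python helper that mirrors those branches, so the limits can be checked without drawing a figure; `None` stands for matplotlib's default autoscaling in the remaining mode.

```python
def color_limits(display, mode):
    """Return the (vmin, vmax) pair passed to set_clim, or None when matplotlib autoscales."""
    values = [complex(v).real for v in display]  # both signed modes use the real part
    if mode == "two-signed":
        # Symmetric limits around zero, set by the largest magnitude
        dmax = max(abs(v) for v in values)
        return (-dmax, dmax)
    if mode == "one-signed":
        # Limits run from zero to the entry of largest magnitude, keeping its sign
        dmax = max(values, key=abs)
        return (dmax, 0.0) if dmax < 0 else (0.0, dmax)
    return None
```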

show_3d(orbs, show_tags=None, show_index=False, display=None, scale=False, cmap=None, circle_scale=1000.0, title=None)

Generates a 3D scatter plot representing the positions of orbitals in 3D space, with optional filtering, coloring, and sizing.

Parameters:
  • `orbs` (list) –

    List of orbital objects, each containing attributes such as 'tag' (for labeling) and 'position' (3D coordinates).

  • `show_tags` (list of str) –

    Filters the orbitals to display based on their tags. Only orbitals with matching tags will be shown. If None, all orbitals are displayed.

  • `show_index` (bool) –

    If True, displays the index of each orbital next to its corresponding point on the plot.

  • `display` (array-like) –

    Data used to color and scale the points (e.g., eigenvector amplitudes). Each value corresponds to an orbital.

  • `scale` (bool) –

    If True, the values in display are normalized and their absolute values are used.

  • `cmap` (optional) –

    Colormap used for the scatter plot when display is provided. If None, a default colormap (bwr) is used.

  • `circle_scale` (float) –

    A scaling factor for the size of the scatter plot points. Larger values result in larger circles. Default is 1000.

  • `title` (str) –

    Custom title for the plot. If None, the default title "Orbital positions in 3D" is used.

Notes

If display is provided, the points are colored and sized according to the values in the display array, and a color bar is added to the plot. If show_index is True, the indices of the orbitals are annotated next to their corresponding points. The plot is automatically adjusted to display grid lines and 3D axes labels for X, Y, and Z.

Source code in src/granad/_plotting.py
@_plot_wrapper
def show_3d(orbs, show_tags=None, show_index=False, display = None, scale = False, cmap = None, circle_scale : float = 1e3, title = None):
    """
    Generates a 3D scatter plot representing the positions of orbitals in 3D space, with optional filtering, coloring, and sizing.

    Parameters:
        `orbs` (list): List of orbital objects, each containing attributes such as 'tag' (for labeling) and 'position' (3D coordinates).
        `show_tags` (list of str, optional): Filters the orbitals to display based on their tags. Only orbitals with matching tags will be shown. If `None`, all orbitals are displayed.
        `show_index` (bool, optional): If `True`, displays the index of each orbital next to its corresponding point on the plot.
        `display` (array-like, optional): Data used to color and scale the points (e.g., eigenvector amplitudes). Each value corresponds to an orbital.
        `scale` (bool, optional): If `True`, the values in `display` are normalized and their absolute values are used.
        `cmap` (optional): Colormap used for the scatter plot when `display` is provided. If `None`, a default colormap (`bwr`) is used.
        `circle_scale` (float, optional): A scaling factor for the size of the scatter plot points. Larger values result in larger circles. Default is 1000.
        `title` (str, optional): Custom title for the plot. If `None`, the default title "Orbital positions in 3D" is used.

    Notes:
        If `display` is provided, the points are colored and sized according to the values in the `display` array, and a color bar is added to the plot.
        If `show_index` is `True`, the indices of the orbitals are annotated next to their corresponding points.
        The plot is automatically adjusted to display grid lines and 3D axes labels for X, Y, and Z.
    """
    # decide whether to take abs val and normalize
    def scale_vals( vals ):
        return jnp.abs(vals) / jnp.abs(vals).max() if scale else vals

    # Determine which tags to display
    if show_tags is None:
        show_tags = {orb.tag for orb in orbs}
    else:
        show_tags = set(show_tags)

    # Prepare data structures for plotting
    tags_to_pos, tags_to_idxs = defaultdict(list), defaultdict(list)
    for orb in orbs:
        if orb.tag in show_tags:
            tags_to_pos[orb.tag].append(orb.position)
            tags_to_idxs[orb.tag].append(orbs.index(orb))

    # Prepare 3D plot
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    if display is not None:
        cmap = plt.cm.bwr if cmap is None else cmap
        colors = scale_vals( display )
        scatter = ax.scatter([orb.position[0] for orb in orbs], [orb.position[1] for orb in orbs], [orb.position[2] for orb in orbs], c=colors, edgecolor='black', cmap=cmap, depthshade=True, s = circle_scale*jnp.abs(display))
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label('Eigenvector Magnitude')
    else:
        # Color by tags if no display values are given
        unique_tags = list(set(orb.tag for orb in orbs))
        color_map = {tag: plt.get_cmap('tab10')(i / len(unique_tags)) for i, tag in enumerate(unique_tags)}
        for tag, positions in tags_to_pos.items():
            positions = jnp.array(positions)
            ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], label=tag, color=color_map[tag], edgecolor='white', alpha=0.7)
        plt.legend(title='Orbital Tags')

    # Optionally annotate points with their indexes
    if show_index:
        for idx, orb in enumerate(orbs):
            if orb.tag in show_tags:
                pos = orb.position
                ax.text(pos[0], pos[1], pos[2], str(idx), color='black', size=10)

    # Finalize plot settings
    ax.set_title(title if title is not None else 'Orbital positions in 3D')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.grid(True)
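
The tag filtering and value scaling in the listing above can be reproduced without any plotting. The sketch below is a minimal NumPy re-implementation, assuming each orbital only needs a `tag` and a `position`; the `SimpleOrbital` class and the helpers `scale_display` and `select_by_tags` are hypothetical and not part of granad. It obtains indices with `enumerate`, which avoids a linear search per orbital and does not rely on how orbitals compare for equality.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleOrbital:
    # Hypothetical stand-in for granad's Orbital: only the fields the plot uses.
    tag: str
    position: tuple


def scale_display(display, scale=False):
    """Mirror of scale_vals: optionally take |values| and normalize to a max of 1."""
    display = np.asarray(display, dtype=float)
    if scale:
        return np.abs(display) / np.abs(display).max()
    return display


def select_by_tags(orbs, show_tags=None):
    """Group positions and indices by tag, keeping only the requested tags."""
    tags = {orb.tag for orb in orbs} if show_tags is None else set(show_tags)
    positions, indices = {}, {}
    for idx, orb in enumerate(orbs):  # enumerate yields the index directly
        if orb.tag in tags:
            positions.setdefault(orb.tag, []).append(orb.position)
            indices.setdefault(orb.tag, []).append(idx)
    return positions, indices


orbs = [
    SimpleOrbital("A", (0.0, 0.0, 0.0)),
    SimpleOrbital("B", (1.0, 0.0, 0.0)),
    SimpleOrbital("A", (0.0, 1.0, 0.0)),
]
positions, indices = select_by_tags(orbs, show_tags=["A"])
print(indices)                                               # {'A': [0, 2]}
print(scale_display([0.5, -2.0, 1.0], scale=True).tolist())  # [0.25, 1.0, 0.5]
```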

show_energies(orbs, display=None, label=None, e_max=None, e_min=None)

Depicts the energy and occupation landscape of a stack, with energies plotted on the y-axis and eigenstates ordered by index (eigenstate number) on the x-axis.

Parameters:
  • `orbs`

    An object containing the orbital data, including energies, electron counts, and initial density matrix.

  • `display` (Array) –

    Array used to color the energy states. If None, the electronic occupation is used.

  • `label` (str) –

    Label for the colorbar. If None, "initial state occupation" is used.

  • `e_max` (float) –

    The upper limit of the energy range to display on the y-axis. If None, the maximum energy is used. This parameter lets you zoom into a specific range of energies for a more focused view.

  • `e_min` (float) –

    The lower limit of the energy range to display on the y-axis. If None, the minimum energy is used. This parameter lets you filter out lower-energy states and focus on the higher-energy range.

Notes

The scatter plot displays the eigenstate number on the x-axis and the corresponding energy (in eV) on the y-axis. By default (display=None), the color of each point represents the initial state occupation, calculated as the product of the electron count and the diagonal element of the initial density matrix for each state. A color bar indicates the magnitude of the plotted quantity for each eigenstate. The displayed energy range is padded by 1% of its width on each side.

Source code in src/granad/_plotting.py
@_plot_wrapper
def show_energies(orbs, display = None, label = None, e_max = None, e_min = None):
    """
    Depicts the energy and occupation landscape of a stack, with energies plotted on the y-axis and eigenstates ordered by index (eigenstate number) on the x-axis.

    Parameters:
        `orbs`: An object containing the orbital data, including energies, electron counts, and initial density matrix.
        `display` (jnp.Array, optional): Array used to color the energy states.
            - If `None`, the electronic occupation is used.
        `label` (str, optional): Label for the colorbar.
            - If `None`, "initial state occupation" is used.
        `e_max` (float, optional): The upper limit of the energy range to display on the y-axis.
            - If `None`, the maximum energy is used by default.
            - This parameter allows you to zoom into a specific range of energies for a more focused view.
        `e_min` (float, optional): The lower limit of the energy range to display on the y-axis.
            - If `None`, the minimum energy is used by default.
            - This parameter allows you to filter out lower-energy states and focus on the higher-energy range.

    Notes:
        The scatter plot displays the eigenstate number on the x-axis and the corresponding energy (in eV) on the y-axis.
        By default, the color of each point represents the initial state occupation, calculated as the product of the electron count and the initial density matrix diagonal element for each state.
        A color bar is added to indicate the magnitude of the plotted quantity for each eigenstate.
    """
    from matplotlib.ticker import MaxNLocator
    # test against None explicitly so that a limit of 0.0 is respected
    e_max = orbs.energies.max() if e_max is None else e_max
    e_min = orbs.energies.min() if e_min is None else e_min
    widening = (e_max - e_min) * 0.01 # 1% larger in each direction
    e_max += widening
    e_min -= widening
    # indices of the states whose energies fall inside the padded window
    state_numbers = jnp.argwhere(jnp.logical_and(orbs.energies <= e_max, orbs.energies >= e_min))[:, 0]
    energies_filtered = orbs.energies[state_numbers]

    if display is None:
        display = jnp.diag(orbs.electrons * orbs.initial_density_matrix_e)
    label = label or "initial state occupation"

    colors = display[state_numbers]

    fig, ax = plt.subplots(1, 1)
    plt.colorbar(
        ax.scatter(
            state_numbers,
            energies_filtered,
            c=colors,
        ),
        label=label,
    )
    ax.set_xlabel("eigenstate number")
    ax.set_ylabel("energy (eV)")
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_ylim(e_min, e_max)
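
The energy window in the listing above pads the requested range by 1% of its width on each side before selecting states. A minimal NumPy sketch of that selection, assuming `energies` is a 1D array of eigenvalues; the helper `energy_window` and the `occupation` values are hypothetical. It checks `is None` rather than using `or`, so an explicit limit of `0.0` is respected.

```python
import numpy as np


def energy_window(energies, e_min=None, e_max=None, pad=0.01):
    """Return the padded limits and the indices of the states inside them.

    Missing limits default to the extreme energies, and the window is
    widened by `pad` times its width on each side.
    """
    energies = np.asarray(energies, dtype=float)
    e_max = energies.max() if e_max is None else e_max
    e_min = energies.min() if e_min is None else e_min
    widening = (e_max - e_min) * pad
    e_max, e_min = e_max + widening, e_min - widening
    idxs = np.flatnonzero((energies >= e_min) & (energies <= e_max))
    return e_min, e_max, idxs


energies = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
occupation = np.array([2.0, 2.0, 1.0, 0.0, 0.0])  # hypothetical per-state color values
e_min, e_max, idxs = energy_window(energies, e_max=0.0)
print(idxs.tolist())              # [0, 1, 2]
print(occupation[idxs].tolist())  # [2.0, 2.0, 1.0]
```

With `e_max = e_max or energies.max()`, the value `0.0` would be treated as missing and all five states would be selected.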

show_induced_field(orbs, x, y, z, component=0, density_matrix=None, scale='log', levels=100)

Displays a 2D contour plot of one component of the induced field, shown either as the normalized absolute value (linear scale) or as the signed logarithm of the normalized absolute value (log scale, the default).

Parameters:
  • `orbs`

    An object containing the orbital data and field information.

  • `x` (array-like) –

    x-coordinates for the 2D grid on which the field is evaluated.

  • `y` (array-like) –

    y-coordinates for the 2D grid on which the field is evaluated.

  • `z` (float) –

    z-coordinate slice at which the field is evaluated in the xy-plane.

  • `component` (int) –

    The field component to display: 0, 1, or 2 for the x, y, or z direction (default is 0).

  • `density_matrix` (optional) –

    The density matrix used to calculate the induced field. If not provided, the initial density matrix will be used.

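  • `scale` (str) –

    Either "log" (default), a signed logarithmic scale, or "linear". Any value other than "log" gives the linear scale.

  • `levels` (int or array-like) –

    Number of contour levels (default 100) or a list of level values, passed to matplotlib's contour.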

Note

With scale="log", the plot shows sign(E) · log(|E| / |E|max), which brings out variations in field strength across several orders of magnitude. The field is normalized by its maximum magnitude before the logarithm is applied, so the plot emphasizes relative rather than absolute field strength.
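
The two scalings can be checked on a handful of sample values. This NumPy sketch mirrors the normalization in the source listing; the helper name `normalize_field` is hypothetical. Because log(|E| / |E|max) is never positive, multiplying by sign(E) maps points with a positive field to values ≤ 0 and points with a negative field to values ≥ 0.

```python
import numpy as np


def normalize_field(field, scale="log"):
    """Reproduce the two scalings used when plotting the induced field.

    "linear": |E| / max|E|
    "log":    sign(E) * log(|E| / max|E|)
    A field value of exactly zero gives nan in log mode (0 * -inf),
    as it would in the original jnp expression.
    """
    field = np.asarray(field, dtype=float)
    ratio = np.abs(field) / np.abs(field).max()
    if scale == "log":
        return np.sign(field) * np.log(ratio)
    return ratio  # any other value of `scale` falls back to linear


field = np.array([1.0, -0.1, 0.01])
print(normalize_field(field, "linear").tolist())           # [1.0, 0.1, 0.01]
print(np.round(normalize_field(field, "log"), 3).tolist())  # [0.0, 2.303, -4.605]
```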

Source code in src/granad/_plotting.py
@_plot_wrapper
def show_induced_field(orbs, x, y, z, component = 0, density_matrix=None, scale = "log", levels = 100):
    """
    Displays a 2D contour plot of one component of the induced field, either as the normalized absolute value ("linear") or as the signed logarithm of the normalized absolute value ("log", the default).

    Parameters:
        `orbs`: An object containing the orbital data and field information.
        `x` (array-like): x-coordinates for the 2D grid on which the field is evaluated.
        `y` (array-like): y-coordinates for the 2D grid on which the field is evaluated.
        `z` (float): z-coordinate slice at which the field is evaluated in the xy-plane.
        `component` (int, optional): The field component to display: 0, 1, or 2 for the x, y, or z direction (default is 0).
        `density_matrix` (optional): The density matrix used to calculate the induced field. If not provided, the initial density matrix is used.
        `scale` (str, optional): "log" (default) for a signed logarithmic scale or "linear" for a linear scale. Any value other than "log" gives the linear scale.
        `levels` (int or array-like, optional): Number of contour levels (default 100) or a list of level values, passed to matplotlib's `contour`.

    Note:
        With `scale="log"`, the plot shows sign(E) * log(|E| / |E|_max), which brings out variations in field strength across several orders of magnitude.
        The field is normalized by its maximum magnitude before the logarithm is applied, so the plot emphasizes relative rather than absolute field strength.
    """

    density_matrix = (
        density_matrix if density_matrix is not None else orbs.initial_density_matrix
    )
    charge = density_matrix.diagonal().real

    X, Y, Z = jnp.meshgrid(x, y, z)
    positions = jnp.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T    

    induced_field = jnp.einsum('jir,i->jr', get_coulomb_field_to_from(orbs.positions, positions, jnp.arange(len(positions))), charge)

    induced_field = induced_field[:, component].reshape(X[:, :, 0].shape)
    label = r"$\dfrac{E}{E_{max}}$"
    plot_field = jnp.abs(induced_field) / jnp.abs(induced_field).max()

    if scale == "log":
        E_sign = jnp.sign(induced_field)
        induced_field = jnp.log(jnp.abs(induced_field) / jnp.abs(induced_field).max() )
        label = r"$sign(E) \cdot \log\left(\dfrac{|E|}{|E|_{max}}\right)$"
        plot_field = E_sign * induced_field


    fig, ax = plt.subplots(1, 1)
    from matplotlib.colors import LinearSegmentedColormap
    CMAP=LinearSegmentedColormap.from_list('custom_cmap', ['White','Blue','Red','White'])
    fig.colorbar(
        ax.contour(
            X[:, :, 0], Y[:, :, 0], plot_field,
            cmap =CMAP,
            levels = levels,
            linewidths = .5
        ),
        label = label
    )
    ax.scatter(*zip(*orbs.positions[:, :2]), color = 'black', s=20, zorder = 10)
    ax.axis('equal')

show_res(orbs, res, plot_only=None, plot_labels=None, show_illumination=False, omega_max=None, omega_min=None, xlabel=None, ylabel=None)

Visualizes the evolution of an expectation value over time or frequency, based on the given simulation results.

Parameters:
  • `orbs`

    Usually passed implicitly, since this function is generally called as a method of a flake object (e.g., flake.show_res).

  • `res`

    A result object containing the simulation data, including the output values and corresponding time or frequency axis.

  • `plot_only` (Array) –

    Indices of specific components to be plotted. If not provided, all components will be plotted.

  • `plot_labels` (list[str]) –

    Labels for each plotted quantity. If not provided, no labels will be added.

  • `show_illumination` (bool) –

    Whether to include illumination data in the plot. If True, illumination components are displayed.

  • `omega_max` (optional) –

    Upper bound for the frequency range, used when plotting in the frequency domain.

  • `omega_min` (optional) –

    Lower bound for the frequency range, used when plotting in the frequency domain.

  • `xlabel` (optional) –

    x-axis label for the plot.

  • `ylabel` (optional) –

    y-axis label for the plot.

Notes

The function plots frequency-domain results when both omega_max and omega_min are provided, and time-domain results otherwise. If show_illumination is enabled, the function plots the illumination components (x, y, z) as additional curves. The x-axis label changes to represent time or frequency accordingly, unless xlabel is given.
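
The way `plot_only` and `plot_labels` pick out curves can be sketched without matplotlib. This minimal NumPy version mirrors the loop in the source listing; the helper `select_curves` is hypothetical, and the layout of each output array as `(n_points, n_components)` is an assumption inferred from the indexing `obs[:, plot_only]`. Note that `plot_labels` is indexed per selected column, so the same labels are reused for every array in the output list.

```python
import numpy as np


def select_curves(outputs, plot_only=None, plot_labels=None):
    """Return (label, curve) pairs in the order the plotting loop draws them."""
    curves = []
    for obs in outputs:
        # keep only the requested components (columns), if any are given
        obs = obs if plot_only is None else obs[:, plot_only]
        for i, column in enumerate(obs.T):
            label = "" if plot_labels is None else plot_labels[i]
            curves.append((label, column))
    return curves


t = np.linspace(0.0, 1.0, 5)
dipole = np.stack([np.sin(t), np.cos(t), t], axis=1)  # hypothetical x, y, z components
curves = select_curves([dipole], plot_only=np.array([0, 2]), plot_labels=["p_x", "p_z"])
print([label for label, _ in curves])  # ['p_x', 'p_z']
```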

Source code in src/granad/_plotting.py
@_plot_wrapper
def show_res(
    orbs,
    res,
    plot_only : jax.Array = None,
    plot_labels : list[str] = None,
    show_illumination = False,
    omega_max = None,
    omega_min = None,
    xlabel = None,
    ylabel = None
):
    """
    Visualizes the evolution of an expectation value over time or frequency, based on the given simulation results.

    Parameters:
        `orbs`: Usually passed implicitly, since this function is generally called as a method of a flake object (e.g., `flake.show_res`).
        `res`: A result object containing the simulation data, including the output values and corresponding time or frequency axis.
        `plot_only` (jax.Array, optional): Indices of specific components to be plotted. If not provided, all components will be plotted.
        `plot_labels` (list[str], optional): Labels for each plotted quantity. If not provided, no labels will be added.
        `show_illumination` (bool, optional): Whether to include illumination data in the plot. If `True`, illumination components are displayed.
        `omega_max` (optional): Upper bound for the frequency range, used when plotting in the frequency domain.
        `omega_min` (optional): Lower bound for the frequency range, used when plotting in the frequency domain.
        `xlabel` (optional): x-axis label for the plot.
        `ylabel` (optional): y-axis label for the plot.

    Notes:
        The function plots frequency-domain results when both `omega_max` and `omega_min` are provided, and time-domain results otherwise.
        If `show_illumination` is enabled, the function plots the illumination components (`x`, `y`, `z`) as additional curves.
        The x-axis label changes to represent time or frequency accordingly, unless `xlabel` is given.
    """
    def _show( obs, name ):
        ax.plot(x_axis, obs, label = name)

    fig, ax = plt.subplots(1, 1)    
    ax.set_xlabel(r"time [$\hbar$/eV]")
    plot_obs = res.output
    illu = res.td_illumination
    x_axis = res.time_axis
    cart_list = ["x", "y", "z"]


    if omega_max is not None and omega_min is not None:
        plot_obs = res.ft_output( omega_max, omega_min )
        x_axis, illu = res.ft_illumination( omega_max, omega_min )
        ax.set_xlabel(r"$\omega$ [eV/$\hbar$]")

    for obs in plot_obs:
        obs = obs if plot_only is None else obs[:, plot_only]
        for i, obs_flat in enumerate(obs.T):
            label = '' if plot_labels is None else plot_labels[i]
            _show( obs_flat, label )
        if show_illumination:
            for component, illu_flat in enumerate(illu.T):
                _show(illu_flat, f'illumination_{cart_list[component]}')

    plt.legend()

    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
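
The time axis is labeled in units of ħ/eV. Since ħ ≈ 6.582119569 × 10⁻¹⁶ eV·s, one such unit is about 0.658 fs, and the angular frequency conjugate to it has units of eV/ħ (so ħω in eV has the same numerical value as ω). A small conversion sketch; the function names are hypothetical helpers, not part of granad.

```python
HBAR_EV_S = 6.582119569e-16  # reduced Planck constant in eV*s (CODATA value, truncated)


def time_to_fs(t):
    """Convert a time given in units of hbar/eV to femtoseconds."""
    return t * HBAR_EV_S * 1e15


def omega_to_rad_per_fs(omega):
    """Convert an angular frequency given in units of eV/hbar to rad/fs."""
    return omega / (HBAR_EV_S * 1e15)


print(round(time_to_fs(1.0), 4))  # 0.6582
```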