from copy import copy from .compound import CompoundScalar from .factory import \ factory, \ tensor_double, \ tensor_dot from .variable import Variable, DeltaVariable from .index import next_inner from .tensor import Tensor from .const import \ _double_contraction, \ _double_contraction_asym, \ _double_contraction_same from .config import \ _latex_config, \ _latex_vec_config, \ _tensor_config # Norm and Contraction are not trictly needed. # there could also be partial contractions, which would be a non-scalar unit. # TODO - VariableDouble class DoubleContraction(CompoundScalar): """ Expressiosn of the form P_ij * Q_ij to write as P:Q for anti-symmetric (sym=-1) maybe write P!Q for now - only constant tensors (derivative 0) TODO - allow general expressions for P and Q TODO - generalise (Scalars) """ _kind = "double" def __init__(self, *tensors, sym=None): if len(tensors) == 0: # doing init return assert len(tensors) == 2 if isinstance(tensors[0], Variable): assert isinstance(other, Variable) for t in tensors: assert t.order == 2 ii0 = new_inner('j', 2) if sym is None: sym = 1 if sym == 1: ii1 = copy(ii0) else: ii1 = ii0[::-1] iii = [ii0, ii1] tensors = list(t[*ii] for t,ii in zip(tensors, iii)) for t in tensors: assert t.order == 2 assert t.slots[0].index != t.slots[1].index if (tensors[0].slots[0] == ~tensors[1].slots[0] and tensors[0].slots[1] == ~tensors[1].slots[1]): sym_ = 1 elif (tensors[0].slots[0] == ~tensors[1].slots[1] and tensors[0].slots[1] == ~tensors[1].slots[0]): sym_ = -1 else: raise AttributeError( f'[{self.__class__.__name__}] Slot mismatch: ' + f'{tensors[0]!r} --- {tensors[1]!r}') # only static tensors for now for t in tensors: assert t.fundamental if sym is not None: assert sym == sym_, f'symmetry mismatch: {tensors}: {sym_}, {sym=}' else: sym = sym_ if sym == -1 and _tensor_config['double_use_symmetry']: s0 = tensors[0].symmetry() s1 = tensors[0].symmetry() if s0 == 1 or s1 == 1: if s0 == 1: tensors = [tensors[0].transpose(), tensors[1]] else: tensors = [tensors[0], tensors[1].transpose()] sym = 1 self.sym = sym self.tensors = tuple(tensors) self.slots = () inner, outer = self.get_indices() outer = inner | outer inner = list() ii0 = [next_inner(inner, outer) for _ in range(2)] if self.sym == 1: ii1 = copy(ii0) else: ii1 = ii0[::-1] iii = [ii0, ii1] tensors = list() for t,ii in zip(self.tensors, iii): tensors.append(t.replace_indices((0, 1), ii)) self.tensors = tuple(tensors) self.inner = tuple(ii0) def __copy__(self): new = super().__copy__() new.sym = copy(self.sym) return new def d(self, var): if self.metric is not None: raise NotImplementedError() return factory.Zero() def _repr(self, level=0, kind=None): if self.sym: symb = _double_contraction else: symb = _double_contraction_asym if self.tensors[0].variable == self.tensors[1].variable and self.sym: return ( _double_contraction_same + self.tensors[0].variable._repr() + _double_contraction_same) return symb.join(t.variable._repr() for t in self.tensors) def compare(self, other): if not isinstance(other, self.__class__): assert other.__class__.__name__ != self.__class__.__name__, 'Maybe auto-reload failed.' return 0 # for now, just allow (assume) base tensors arrange = None if ((self.tensors[0].variable == other.tensors[0].variable) and (self.tensors[1].variable == other.tensors[1].variable)): arrange = 1 elif ((self.tensors[0].variable == other.tensors[1].variable) and (self.tensors[1].variable == other.tensors[0].variable)): arrange = -1 if arrange is None: return 0 if self.sym == other.sym: return 1 # if both tensors have opposite symmetry result is 0 s0 = self.tensors[0].symmetry() s1 = self.tensors[1].symmetry() if s0 * s1 == -1: # both contractions actually evaluate to 0 return 1 if s0 == 1 or s1 == 1: return 1 if s0 == -1 or s1 == -1: return -1 return 0 def _replace_inner_slots(self, inner): new = copy(self) assert len(inner) == 2 inner = tuple(i.index for i in inner) tensors = list() for t in self.tensors: t = copy(t) t.replace_indices((0,1), inner) tensors.append(t) new.inner = inner new.tensors = tuple(tensors) return new def eliminate_antisymmetry(self): s0 = self.tensors[0].symmetry() s1 = self.tensors[1].symmetry() if s0 * s1 == -1: return factory.Zero() return copy(self) def eliminate_delta(self): delta = list() for j,t in enumerate(self.tensors): if isinstance(t.variable, DeltaVariable): delta.append(j) if len(delta) == 0: return copy(self) if len(delta) == 2: return factory.Value(self.variable.dimension) i = 1 - delta[0] return self.tensors[i].replace_indices(1, ~self.tensors[i].slots[0]) def new_tensors(self, *tensors): return tensor_double(*tensors) def replace(self, variable, tensor): assert isinstance(variable, Variable), f'{tensor.__class__.__mro__}' assert isinstance(tensor, (Tensor, Variable)), f'{tensor.__class__.__mro__}' # for now, only simple tensor support match = list(t.variable == variable for t in self.tensors) nmatch = sum(match) if nmatch == 0: return copy(self) if isinstance(tensor, Variable): assert tensor.order == 2 expand = False else: if tensor.fundamental: assert tensor.variable.order >= 2 expand = tensor.variable.order > 2 else: expand = True tensors = list() for i,t in enumerate(self.tensors): if not match[i]: tensors.append(t) continue if isinstance(tensor, Variable): tensors.append(tensor(*t.slots)) continue assert tensor.order == 2 tensors.append(tensor.replace_indices((0, 1), t.slots)) if expand: return tensor_dot(*tensors) return tensor_double(*tensors) def _latex_einstein(self, level=0, kind=None): return _latex_config["mult"].join(t._latex_einstein(level=0, kind=None) for t in self.tensors) def _latex_vec(self, valence, /, *, level=0, kind=None, order=None, transpose=None): out = self.tensors[0].variable._latex_vec(order=2) out += _latex_vec_config['double'] s = self.tensors[1].variable._latex_vec(order=2) if self.sym == -1: s = _latex_vec_config['transpose'].format(s) out += s return out