from copy import copy
from numbers import Number

from .compound import CompoundScalar
from .factory import factory, tensor_pow, tensor_scalar
from .tensor import Tensor
from .config import (
    _debug_config,
    _latex_config,
    _latex_vec_config,
)


class Scalar(CompoundScalar):
    """
    Group a single order-0 tensor into one scalar factor.

    The same as Pow(tensor, 1) or Dot, but used for grouping and to
    accelerate comparisons.  An instance should hold exactly one scalar
    factor, stored as ``self.tensors[0]``.

    NOTE(review): the original docstring also said it "can be use for
    diplac"; the intended meaning is unclear -- confirm with the author.
    """

    _kind = 'scalar'

    def __init__(self, tensor=None):
        """
        Wrap ``tensor`` (an order-0 ``Tensor``) as a scalar factor.

        With ``tensor=None`` the instance is returned without ``tensors``,
        ``slots`` or ``inner`` being set (presumably so that another
        routine can fill them in afterwards -- TODO confirm).

        A compound tensor is only accepted when it has more than one
        member tensor.
        """
        if tensor is None:
            return
        assert isinstance(tensor, Tensor), f'{tensor.__class__.__mro__}'
        assert tensor.order == 0
        if isinstance(tensor, factory.CompoundTensor):
            assert tensor.ntensors > 1
        # Keep a (shallow) copy so the wrapper owns its own tensor object.
        self.tensors = (copy(tensor),)
        self.slots = ()
        self.inner = ()

    def __pow__(self, power):
        """
        Raise the wrapped tensor to a numeric ``power``.

        No symbolic powers for now: ``power`` must be a ``Number``.
        """
        assert isinstance(power, Number)
        return tensor_pow(self.tensors[0], power)

    def __rtruediv__(self, other):
        """
        Reflected division: compute ``other / self``.

        The result is ``other`` times the inverse of the wrapped tensor.
        A numeric ``1`` numerator returns the bare inverse, exactly as the
        previous implementation did for every numerator.
        """
        inverse = tensor_pow(self.tensors[0], -1)
        if isinstance(other, Number) and other == 1:
            return inverse
        # The numerator used to be dropped here, so ``2 / s`` evaluated to
        # ``1 / s``.  Relies on the inverse supporting left-multiplication
        # by ``other``.
        return other * inverse

    def d(self, var):
        """Delegate ``d(var)`` to the wrapped tensor."""
        return self.tensors[0].d(var)

    def _repr(self, level=0, kind=None):
        """
        Return the text representation of the wrapped tensor.

        A non-compound tensor is shown by its ``name``.  A compound tensor
        is rendered one level deeper with this class's kind, and wrapped
        in parentheses when ``level > 1``, the parent ``kind`` is not
        ``'sum'`` and the tensor is not a ``Norm``.
        """
        tensor = self.tensors[0]
        if not isinstance(tensor, factory.CompoundTensor):
            return tensor.name

        s = tensor._repr(level=level + 1, kind=self._kind)
        # NOTE(review): the indentation of the original was lost; the
        # parenthesising is assumed to apply to compound tensors only --
        # confirm against the upstream file.
        if not isinstance(tensor, factory.Norm) and \
                level > 1 and \
                kind not in ('sum',):
            s = f'({s})'
        return s

    def __copy__(self):
        """Return the copy produced by the parent class, unchanged."""
        new = super().__copy__()
        return new

    def compare(self, other):
        """
        Compare with another ``Scalar`` or a ``Pow`` of power 1.

        Returns 0 when ``other`` is neither a ``Scalar`` nor a ``Pow``, or
        is a ``Pow`` whose power is not 1; otherwise returns the result of
        comparing the two wrapped tensors.
        """
        if not isinstance(other, (Scalar, factory.Pow)):
            return 0
        if isinstance(other, factory.Pow) and other.power != 1:
            return 0
        return self.tensors[0].compare(other.tensors[0])

    def new_tensors(self, *tensors):
        """Build a replacement object from ``tensors`` via ``tensor_scalar``."""
        return tensor_scalar(*tensors)

    def expand(self, factors=None):
        """
        Expand the wrapped tensor, forwarding ``factors``.

        Prints a trace line first when ``_debug_config['expand']`` is set.
        """
        if _debug_config['expand']:
            print(f'[{self.__class__.__name__}] (expand) {self} {factors=}')
        return self.tensors[0].expand(factors)

    def _latex_einstein(self, level=0, kind=None):
        """
        Return the index-notation LaTeX of the wrapped tensor.

        No brackets are added when the parent ``kind`` is a scalar, a sum,
        a power or ``None``.  Otherwise the tensor is rendered one level
        deeper and wrapped with the bracket template selected by ``level``
        from ``_latex_config['brackets']``.
        """
        l = level
        if kind in (self._kind, factory.Sum._kind, factory.Pow._kind, None):
            bracket = False
        else:
            l += 1
            bracket = True

        s = self.tensors[0]._latex_einstein(l, self._kind)
        if bracket:
            brackets = _latex_config['brackets']
            # Cycle through the configured bracket templates by depth.
            brackets = brackets[level % len(brackets)]
            s = brackets.format(s)
        return s

    def _latex_vec(self, valence, level=0, kind=None, order=None,
                   transpose=None):
        """
        Return the vector-notation LaTeX of the wrapped tensor.

        ``valence`` must be ``()`` and ``order`` must be ``None`` or 0.
        ``transpose`` is accepted but not used here.  A ``Dot`` whose
        ``_find_groups()`` yields exactly one group is rendered with
        ``_latex_vec_scalar``; everything else with ``_latex_vec``.
        Bracketing follows the same parent-kind rule as
        ``_latex_einstein``, using ``_latex_vec_config['brackets']``.
        """
        assert valence == ()
        assert order in (None, 0)
        l = level
        if kind in (self._kind, factory.Sum._kind, factory.Pow._kind, None):
            bracket = False
        else:
            l += 1
            bracket = l > 0

        t = self.tensors[0]
        # _find_groups() is only consulted for Dot instances (short-circuit).
        if isinstance(t, factory.Dot) and len(t._find_groups()) == 1:
            s = t._latex_vec_scalar((), level=l, kind=self._kind)
        else:
            s = t._latex_vec((), level=l, kind=self._kind)

        if bracket:
            brackets = _latex_vec_config['brackets']
            # Cycle through the configured bracket templates by depth.
            brackets = brackets[level % len(brackets)]
            s = brackets.format(s)
        return s