"""Factories for building tensor expressions.

These factories can return simpler results than the usual class
constructors (for instance ``One``/``Zero`` in trivial cases, or the single
operand itself).  They are also designed to avoid circular imports, hence
the awkward storage of class names together with a lazy import of their
module (see ``Factory``).

An open issue is still how the metric should be set; for now use
``set_metric``.
"""
from numbers import Number
from importlib import import_module
from copy import copy

# Package containing this module; the relative module paths in ``_factory``
# are resolved against it by ``import_module``.
_base_module = __name__.rsplit('.', 1)[0]

# Name of each lazily importable object -> relative path of its module.
_factory = dict(
    One='.value',
    Zero='.value',
    Value='.value',
    Dot='.dot',
    Sum='.sum',
    Pow='.pow',
    Norm='.norm',
    VariableNorm='.norm',
    Constant='.constant',
    Scalar='.scalar',
    DoubleContraction='.double',
    Tensor='.tensor',
    CompoundTensor='.compound',
    Variable='.variable',
    ValueVariable='.variable',
    CompoundVariable='.variable.compound',
    SumVariable='.variable.compound',
    DotVariable='.variable.compound',
    Index='.index',
    Slot='.slot',
    Delta='.special',
    Epsilon='.special',
    EpsilonVariable='.variable.special',
    default_metric='.metric',
    Pipe='.command',
)


def set_metric(metric):
    """Make ``metric`` the object returned by ``factory.metric``."""
    factory._set_metric(metric)


class Factory(object):
    """Lazy registry of the objects listed in ``_factory``.

    Attribute access such as ``factory.Dot`` imports the defining module on
    first use and caches the looked-up object in ``_inventory``.  Deferring
    the import to attribute-access time keeps modules of this package from
    importing each other at load time.
    """

    # Cache of already imported objects, keyed by name.  This is a class
    # attribute, so it is shared by every Factory instance.
    _inventory = dict()

    def __getattr__(self, attr):
        """Return the registered object ``attr``, importing it on first use.

        Only invoked when normal attribute lookup fails.

        Raises:
            AttributeError: if ``attr`` is not a key of ``_factory``.
        """
        if attr not in _factory:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr}'"
            )
        if attr not in self._inventory:
            module = import_module(_factory[attr], _base_module)
            self._inventory[attr] = getattr(module, attr)
        return self._inventory[attr]

    @property
    def metric(self):
        """The current metric; initialised from ``default_metric`` on first access."""
        # ``hasattr`` goes through ``__getattr__``, which raises
        # AttributeError for ``_metric`` because it is not in ``_factory``.
        if not hasattr(self, '_metric'):
            self._metric = self.default_metric
        return self._metric

    def _set_metric(self, metric):
        """Replace the object returned by ``metric``."""
        self._metric = metric


factory = Factory()


class TensorDotFactory(object):
    """Build the product (``Dot``) of tensors, simplifying trivial cases."""

    def __call__(self, *tensors):
        """Return the dot product of ``tensors``.

        With no operand the result is ``One()``; with a single operand it is
        a shallow copy of that operand.  If the constructed ``Dot`` ends up
        holding a single tensor, that tensor is returned instead.

        Raises:
            AttributeError: if called with a single list or tuple (the
                tensors must be passed unpacked).
        """
        if not tensors:
            return factory.One()
        if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
            raise AttributeError(f'[{self.__class__.__name__}] call error')
        if len(tensors) == 1:
            return copy(tensors[0])
        new = factory.Dot(*tensors)
        if new.ntensors == 1:
            return new.tensors[0]
        return new


tensor_dot = TensorDotFactory()


class TensorPowFactory(object):
    """Raise a tensor to a power, simplifying trivial cases."""

    def __call__(self, tensor, power):
        """Return ``tensor`` raised to ``power``.

        * ``power == 0`` gives ``One()``.
        * A plain number is evaluated and wrapped in ``Value``.
        * ``power == 1`` gives a shallow copy of ``tensor``.
        * ``Value`` and ``Scalar`` instances use their own ``**``.
        * Anything else becomes a ``Pow``.
        """
        if power == 0:
            return factory.One()
        if isinstance(tensor, Number):
            return factory.Value(tensor ** power)
        if power == 1:
            return copy(tensor)
        if isinstance(tensor, (factory.Value, factory.Scalar)):
            return tensor ** power
        # Optionally, add code to extract scalar factors?
        return factory.Pow(tensor, power)


tensor_pow = TensorPowFactory()


class TensorSumFactory(object):
    """Build the sum (``Sum``) of tensors, simplifying trivial cases."""

    def __call__(self, *tensors):
        """Return the sum of ``tensors``.

        With no operand the result is ``Zero()``; with a single operand it is
        a shallow copy of that operand.  If the constructed ``Sum`` ends up
        holding a single tensor, that tensor is returned instead.

        Raises:
            AttributeError: if called with a single list or tuple (the
                tensors must be passed unpacked).
        """
        if not tensors:
            return factory.Zero()
        if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
            raise AttributeError(f'[{self.__class__.__name__}] call error')
        if len(tensors) == 1:
            return copy(tensors[0])
        new = factory.Sum(*tensors)
        if len(new.tensors) == 1:
            return new.tensors[0]
        return new


tensor_sum = TensorSumFactory()


class TensorNormFactory(object):
    """Build the norm of a tensor, choosing the most specific representation."""

    def __call__(self, tensor):
        """Return the norm of ``tensor``.

        * ``None`` gives ``Zero()``.
        * Numbers give ``Value(abs(tensor))``.
        * ``Value`` instances use their own ``abs``.
        * Variables become ``VariableNorm``.
        * A fundamental ``Tensor`` whose variable has order 1 becomes the
          ``VariableNorm`` of that variable.
        * Anything else becomes a ``Norm``.
        """
        if tensor is None:
            return factory.Zero()
        if isinstance(tensor, Number):
            return factory.Value(abs(tensor))
        if isinstance(tensor, factory.Value):
            return abs(tensor)
        if isinstance(tensor, factory.Variable):
            return factory.VariableNorm(tensor)
        if (isinstance(tensor, factory.Tensor)
                and tensor.fundamental
                and tensor.variable.order == 1):
            return factory.VariableNorm(tensor.variable)
        return factory.Norm(tensor)


tensor_norm = TensorNormFactory()


class TensorValueFactory(object):
    """Wrap a constant, mapping 0 and 1 to ``Zero``/``One``."""

    def __call__(self, value):
        """Return ``Zero()`` for 0, ``One()`` for 1, otherwise ``Value(value)``."""
        if value == 0:
            return factory.Zero()
        if value == 1:
            return factory.One()
        return factory.Value(value)


tensor_value = TensorValueFactory()


class TensorDoubleFactory(object):
    """Build a double contraction and simplify it."""

    def __call__(self, *tensors, **kwargs):
        """Return the simplified double contraction of ``tensors``.

        Keyword arguments are forwarded to ``DoubleContraction``.  The result
        is passed through ``eliminate_antisymmetry``.  If it is still a
        ``DoubleContraction``, it is also passed through ``eliminate_delta``.

        Raises:
            AttributeError: if called with a single list or tuple (the
                tensors must be passed unpacked).
        """
        if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
            raise AttributeError(f'[{self.__class__.__name__}] call error')
        new = factory.DoubleContraction(*tensors, **kwargs)
        new = new.eliminate_antisymmetry()
        if not isinstance(new, factory.DoubleContraction):
            return new
        return new.eliminate_delta()


tensor_double = TensorDoubleFactory()


class TensorScalarFactory(object):
    """Wrap compound tensors of at least two tensors in ``Scalar``."""

    def __call__(self, tensor):
        """Return ``Scalar(tensor)`` for a ``CompoundTensor`` of >= 2 tensors.

        Any other input is returned as a shallow copy.
        """
        if not isinstance(tensor, factory.CompoundTensor) or tensor.ntensors < 2:
            return copy(tensor)
        return factory.Scalar(tensor)


tensor_scalar = TensorScalarFactory()


class TensorDeltaFactory(object):
    """Build a ``Delta`` over the given slots, simplifying trivial cases."""

    def __call__(self, *slots):
        """Return the delta over ``slots``.

        * Fewer than two slots give ``One()``.
        * Two identical slots evaluate to ``Value`` of the current metric's
          ``dimension``.
        * Otherwise a ``Delta`` is built.
        """
        if len(slots) < 2:
            return factory.One()
        if len(slots) == 2 and slots[0] == slots[1]:
            return factory.Value(factory.metric.dimension)
        return factory.Delta(*slots)


tensor_delta = TensorDeltaFactory()


class TensorEpsilonFactory(object):
    """Build an ``Epsilon`` over the given slots, returning ``Zero`` when it vanishes."""

    def __call__(self, *slots):
        """Return the epsilon over ``slots``.

        The result is ``Zero()`` in any of these cases:

        * there are fewer than two slots;
        * an index repeats (``Slot`` objects are compared by their ``index``,
          other objects by themselves);
        * there are more slots than the current metric's ``dimension``.

        Otherwise an ``Epsilon`` is built.
        """
        if len(slots) < 2:
            return factory.Zero()
        indices = {s.index if isinstance(s, factory.Slot) else s for s in slots}
        if len(indices) < len(slots):
            return factory.Zero()
        if len(slots) > factory.metric.dimension:
            return factory.Zero()
        return factory.Epsilon(*slots)


tensor_epsilon = TensorEpsilonFactory()