"""
The :mod:`pyeda.boolalg.expr` module implements
Boolean functions represented as expressions.
Data Types:
abstract syntax tree
A nested tuple of entries that represents an expression.
It is defined recursively::
ast := ('const', bool)
| ('var', names, indices)
| ('not', ast)
| ('implies', ast, ast)
| ('ite', ast, ast, ast)
| (func, ast, ...)
bool := 0 | 1
names := (name, ...)
indices := (index, ...)
func := 'or'
| 'nor'
| 'and'
| 'nand'
| 'xor'
| 'xnor'
| 'equal'
| 'unequal'
| 'onehot0'
| 'onehot'
| 'majority'
| 'achillesheel'
Interface Functions:
* :func:`exprvar` --- Return a unique expression variable
* :func:`expr` --- Convert an arbitrary object into an Expression
* :func:`ast2expr` --- Convert an abstract syntax tree to an Expression
* :func:`expr2dimacscnf` --- Convert an expression into an equivalent DIMACS CNF
* :func:`upoint2exprpoint` --- Convert an untyped point to an Expression point
* :func:`Not` --- Expression negation operator
* :func:`Or` --- Expression disjunction (sum, OR) operator
* :func:`And` --- Expression conjunction (product, AND) operator
* :func:`Nor` --- Expression NOR (not OR) operator
* :func:`Nand` --- Expression NAND (not AND) operator
* :func:`Xor` --- Expression exclusive or (XOR) operator
* :func:`Xnor` --- Expression exclusive nor (XNOR) operator
* :func:`Equal` --- Expression equality operator
* :func:`Unequal` --- Expression inequality operator
* :func:`Implies` --- Expression implication operator
* :func:`ITE` --- Expression If-Then-Else (ITE) operator
* :func:`OneHot0`
* :func:`OneHot`
* :func:`Majority`
* :func:`AchillesHeel`
* :func:`Mux`
Interface Classes:
* :class:`Expression`
* :class:`ExprConstant`
* :class:`ExprLiteral`
* :class:`ExprVariable`
* :class:`ExprComplement`
* :class:`ExprNot`
* :class:`ExprOrAnd`
* :class:`ExprOr`
* :class:`ExprAnd`
* :class:`ExprXor`
* :class:`ExprEqual`
* :class:`ExprImplies`
* :class:`ExprITE`
* :class:`NormalForm`
* :class:`DisjNormalForm`
* :class:`ConjNormalForm`
* :class:`DimacsCNF`
"""
# Disable "redefining name from outer scope"
# pylint: disable=W0621
import collections
import itertools
from warnings import warn
import pyeda.parsing.boolexpr
from pyeda.boolalg import boolfunc, sat
from pyeda.util import bit_on, clog2, parity, cached_property
# FIXME: This is a hack for readthedocs Sphinx autodoc
try:
from pyeda.boolalg import picosat
except ImportError: # pragma: no cover
pass
# Registry of every literal created so far, keyed by uniqid.
# Complements are stored under the negated uniqid of their variable.
_EXPRLITERALS = {}
# Literals currently assumed by satisfy_one.
# A literal joins this set while it is active as a context manager.
_ASSUMPTIONS = set()
def _assume2upoint():
    """Convert global assumptions to an untyped point.

    Returns a pair of frozensets of variable uniqids. The first holds the
    variables whose complement is assumed, the second the variables that
    are assumed directly.
    """
    zeros = set()
    ones = set()
    for lit in _ASSUMPTIONS:
        if isinstance(lit, ExprComplement):
            zeros.add(lit.exprvar.uniqid)
        if isinstance(lit, ExprVariable):
            ones.add(lit.uniqid)
    return frozenset(zeros), frozenset(ones)
def exprvar(name, index=None):
    r"""Return a unique Expression variable.

    A Boolean *variable* is an abstract numerical quantity that may assume any
    value in the set :math:`B = \{0, 1\}`.
    The ``exprvar`` function returns a unique Boolean variable instance
    represented by an expression.
    Variable instances may be used to symbolically construct larger
    expressions.

    A variable is defined by one or more *names*, and zero or more *indices*.
    Multiple names establish hierarchical namespaces,
    and multiple indices group several related variables.

    If the ``name`` parameter is a single ``str``,
    it will be converted to ``(name, )``.
    The ``index`` parameter is optional;
    when empty, it will be converted to an empty tuple ``()``.
    If the ``index`` parameter is a single ``int``,
    it will be converted to ``(index, )``.

    Given identical names and indices, the ``exprvar`` function will always
    return the same variable:

    >>> exprvar('a', 0) is exprvar('a', 0)
    True

    To create several single-letter variables:

    >>> a, b, c, d = map(exprvar, 'abcd')

    To create variables with multiple names (inner-most first):

    >>> fifo_push = exprvar(('push', 'fifo'))
    >>> fifo_pop = exprvar(('pop', 'fifo'))

    .. seealso::
       For creating arrays of variables with incremental indices,
       use the :func:`pyeda.boolalg.bfarray.exprvars` function.
    """
    bvar = boolfunc.var(name, index)
    key = bvar.uniqid
    # Create the variable only on first request; later calls reuse it.
    if key not in _EXPRLITERALS:
        _EXPRLITERALS[key] = ExprVariable(bvar)
    return _EXPRLITERALS[key]
def _exprcomp(exprvar):
    """Return the unique complement literal of an Expression variable.

    Complements are registered under the negated uniqid of their variable.
    """
    key = -exprvar.uniqid
    if key not in _EXPRLITERALS:
        _EXPRLITERALS[key] = ExprComplement(exprvar)
    return _EXPRLITERALS[key]
def expr(obj, simplify=True):
    """Convert an arbitrary object into an Expression.

    * An Expression is returned unchanged.
    * The integers/booleans 0 and 1 map to the constant expressions.
    * A ``str`` is parsed as a Boolean expression, and simplified when
      *simplify* is ``True``.
    * Anything else is converted through its truth value.
    """
    if isinstance(obj, Expression):
        return obj
    # False, True, 0, 1
    if isinstance(obj, int) and obj in {0, 1}:
        return CONSTANTS[obj]
    if type(obj) is not str:
        return CONSTANTS[bool(obj)]
    parsed = ast2expr(pyeda.parsing.boolexpr.parse(obj))
    return parsed.simplify() if simplify else parsed
def ast2expr(ast):
    """Convert an abstract syntax tree to an Expression.

    The first entry of *ast* selects the node kind: ``'const'`` and ``'var'``
    are leaves, any other tag is looked up in ``ASTOPS`` and applied to the
    recursively converted children.
    """
    tag = ast[0]
    if tag == 'const':
        return CONSTANTS[ast[1]]
    if tag == 'var':
        return exprvar(ast[1], ast[2])
    # Children are converted before the operator is looked up.
    operands = [ast2expr(child) for child in ast[1:]]
    return ASTOPS[tag](*operands)
def expr2dimacscnf(expr):
    """Convert an expression into an equivalent DIMACS CNF.

    Returns a ``(litmap, DimacsCNF)`` pair.
    Raises ``TypeError`` if *expr* is not an Expression.
    """
    if isinstance(expr, Expression):
        litmap, nvars, clauses = expr.encode_cnf()
        return litmap, DimacsCNF(nvars, clauses)
    fstr = "expected expr to be an Expression, got {0.__name__}"
    raise TypeError(fstr.format(type(expr)))
def expr2dimacssat(expr): # pragma: no cover
    """Convert an expression into an equivalent DIMACS SAT string.

    Raises ``TypeError`` if *expr* is not an Expression,
    and ``ValueError`` if it is not simplified.
    """
    if not isinstance(expr, Expression):
        fstr = "expected expr to be an Expression, got {0.__name__}"
        raise TypeError(fstr.format(type(expr)))
    if not expr.simplified:
        raise ValueError("expected expr to be simplified")

    litmap, nvars = expr.encode_inputs()
    formula = _expr2sat(expr, litmap)
    # The format name is 'sat', plus 'e' when equality is used,
    # plus 'x' when xor is used: sat, sate, satx, satex.
    fmt = 'sat'
    if '=' in formula:
        fmt += 'e'
    if 'xor' in formula:
        fmt += 'x'
    return "p {} {}\n{}".format(fmt, nvars, formula)
def _expr2sat(expr, litmap): # pragma: no cover
    """Convert an expression to a DIMACS SAT string.

    Literals are written through *litmap*; NOT, OR, AND, XOR and EQUAL nodes
    are written in prefix form. Any other node raises ``ValueError``.
    """
    if isinstance(expr, ExprLiteral):
        return str(litmap[expr])
    if isinstance(expr, ExprNot):
        return "-(" + _expr2sat(expr.x, litmap) + ")"

    # N-ary operators differ only in the opening token.
    openers = (
        (ExprOr, "+("),
        (ExprAnd, "*("),
        (ExprXor, "xor("),
        (ExprEqual, "=("),
    )
    for optype, opener in openers:
        if isinstance(expr, optype):
            inner = " ".join(_expr2sat(x, litmap) for x in expr.xs)
            return opener + inner + ")"

    fstr = ("expected expr to be a Literal or Not/Or/And/Xor/Equal op, "
            "got {0.__name__}")
    raise ValueError(fstr.format(type(expr)))
def upoint2exprpoint(upoint):
    """Convert an untyped point to an Expression point.

    *upoint* is a pair of uniqid collections; ids in the first map their
    literal to 0, ids in the second map their literal to 1.
    """
    point = {_EXPRLITERALS[uniqid]: 0 for uniqid in upoint[0]}
    point.update((_EXPRLITERALS[uniqid], 1) for uniqid in upoint[1])
    return point
# primitive functions
def Not(x, simplify=True):
    """Expression negation operator

    The operand is boxed into an Expression first.

    If *simplify* is ``True``, return a simplified expression.
    """
    negation = ExprNot(Expression.box(x))
    return negation.simplify() if simplify else negation
def Or(*xs, simplify=True):
    """Expression disjunction (sum, OR) operator

    Every operand is boxed into an Expression first.

    If *simplify* is ``True``, return a simplified expression.
    """
    disjunction = ExprOr(*map(Expression.box, xs))
    return disjunction.simplify() if simplify else disjunction
def And(*xs, simplify=True):
    """Expression conjunction (product, AND) operator

    Every operand is boxed into an Expression first.

    If *simplify* is ``True``, return a simplified expression.
    """
    conjunction = ExprAnd(*map(Expression.box, xs))
    return conjunction.simplify() if simplify else conjunction
# secondary functions
def Xor(*xs, simplify=True):
    """Expression exclusive or (XOR) operator

    Every operand is boxed into an Expression first.

    If *simplify* is ``True``, return a simplified expression.
    """
    exclusive = ExprXor(*map(Expression.box, xs))
    return exclusive.simplify() if simplify else exclusive
def Equal(*xs, simplify=True):
    """Expression equality operator

    Every operand is boxed into an Expression first.

    If *simplify* is ``True``, return a simplified expression.
    """
    equality = ExprEqual(*map(Expression.box, xs))
    return equality.simplify() if simplify else equality
def Implies(p, q, simplify=True):
    """Expression implication operator

    *p* is the antecedent and *q* the consequent; both are boxed into
    Expressions first.

    If *simplify* is ``True``, return a simplified expression.
    """
    implication = ExprImplies(Expression.box(p), Expression.box(q))
    return implication.simplify() if simplify else implication
def ITE(s, d1, d0, simplify=True):
    """Expression If-Then-Else (ITE) operator

    *s* is the select input, *d1* and *d0* are the two data inputs;
    all three are boxed into Expressions first.

    If *simplify* is ``True``, return a simplified expression.
    """
    select, then_, else_ = (Expression.box(arg) for arg in (s, d1, d0))
    choice = ExprITE(select, then_, else_)
    return choice.simplify() if simplify else choice
# high order functions
def Nor(*xs, simplify=True):
    """Expression NOR (not OR) operator

    Built as the negation of the disjunction of the boxed operands.

    If *simplify* is ``True``, return a simplified expression.
    """
    disjunction = ExprOr(*map(Expression.box, xs))
    negated = ExprNot(disjunction)
    return negated.simplify() if simplify else negated
def Nand(*xs, simplify=True):
    """Expression NAND (not AND) operator

    Built as the negation of the conjunction of the boxed operands.

    If *simplify* is ``True``, return a simplified expression.
    """
    conjunction = ExprAnd(*map(Expression.box, xs))
    negated = ExprNot(conjunction)
    return negated.simplify() if simplify else negated
def Xnor(*xs, simplify=True):
    """Expression exclusive nor (XNOR) operator

    Built as the negation of the exclusive or of the boxed operands.

    If *simplify* is ``True``, return a simplified expression.
    """
    exclusive = ExprXor(*map(Expression.box, xs))
    negated = ExprNot(exclusive)
    return negated.simplify() if simplify else negated
def Unequal(*xs, simplify=True):
    """Expression inequality operator

    Built as the negation of the equality of the boxed operands.

    If *simplify* is ``True``, return a simplified expression.
    """
    equality = ExprEqual(*map(Expression.box, xs))
    negated = ExprNot(equality)
    return negated.simplify() if simplify else negated
def OneHot0(*xs, simplify=True, conj=True):
    """
    Return an expression that means
    "at most one input function is true".

    If *simplify* is ``True``, return a simplified expression.

    If *conj* is ``True``, return a CNF.
    Otherwise, return a DNF.

    Calling with no inputs is accepted in both forms.
    """
    xs = [Expression.box(x) for x in xs]
    if conj:
        # No two inputs may be true at the same time.
        terms = [ExprOr(ExprNot(x0), ExprNot(x1))
                 for x0, x1 in itertools.combinations(xs, 2)]
        expr = ExprAnd(*terms)
    else:
        # At most one input is true iff some choice of all-but-one inputs
        # is entirely false. Clamp the selection size at zero: with no
        # inputs, combinations(xs, -1) would raise ValueError, whereas a
        # size of zero yields the single empty product, exactly as the
        # one-input case already does.
        nzeros = max(len(xs) - 1, 0)
        terms = [ExprAnd(*[ExprNot(x) for x in group])
                 for group in itertools.combinations(xs, nzeros)]
        expr = ExprOr(*terms)
    if simplify:
        expr = expr.simplify()
    return expr
def OneHot(*xs, simplify=True, conj=True):
    """
    Return an expression that means
    "exactly one input function is true".

    If *simplify* is ``True``, return a simplified expression.

    If *conj* is ``True``, return a CNF.
    Otherwise, return a DNF.
    """
    boxed = [Expression.box(x) for x in xs]
    if conj:
        # No pair is true together, and at least one input is true.
        clauses = [ExprOr(ExprNot(lhs), ExprNot(rhs))
                   for lhs, rhs in itertools.combinations(boxed, 2)]
        clauses.append(ExprOr(*boxed))
        result = ExprAnd(*clauses)
    else:
        # One product per input: that input true, every other input false.
        products = []
        for pos, hot in enumerate(boxed):
            others = boxed[:pos] + boxed[pos+1:]
            products.append(ExprAnd(hot, *[ExprNot(x) for x in others]))
        result = ExprOr(*products)
    return result.simplify() if simplify else result
def Majority(*xs, simplify=True, conj=False):
    """
    Return an expression that means
    "the majority of input functions are true".

    If *simplify* is ``True``, return a simplified expression.

    If *conj* is ``True``, return a CNF.
    Otherwise, return a DNF.
    """
    boxed = [Expression.box(x) for x in xs]
    count = len(boxed)
    # Both forms combine fixed-size groups of inputs; only the two
    # operators and the group size differ.
    if conj:
        outer, inner, size = ExprAnd, ExprOr, (count + 1) // 2
    else:
        outer, inner, size = ExprOr, ExprAnd, count // 2 + 1
    groups = itertools.combinations(boxed, size)
    result = outer(*[inner(*group) for group in groups])
    return result.simplify() if simplify else result
def AchillesHeel(*xs, simplify=True):
    r"""
    Return the Achilles' Heel function, defined as:
    :math:`\prod_{i=0}^{n/2-1}{X_{2i} + X_{2i+1}}`.

    The inputs are paired up in order, so an even number of them is
    required; otherwise ``ValueError`` is raised.

    If *simplify* is ``True``, return a simplified expression.
    """
    count = len(xs)
    if count % 2:
        fstr = "expected an even number of arguments, got {}"
        raise ValueError(fstr.format(count))
    boxed = [Expression.box(x) for x in xs]
    # Pair element 2i with element 2i+1.
    pairs = zip(boxed[0::2], boxed[1::2])
    product = ExprAnd(*[ExprOr(lhs, rhs) for lhs, rhs in pairs])
    return product.simplify() if simplify else product
def Mux(fs, sel, simplify=True):
    """
    Return an expression that multiplexes a sequence of input functions over a
    sequence of select functions.

    A single select Expression is treated as a one-element sequence.
    Raises ``ValueError`` when there are fewer select bits than
    ``clog2(len(fs))``.

    If *simplify* is ``True``, return a simplified expression.
    """
    # convert Mux([a, b], x) to Mux([a, b], [x])
    if isinstance(sel, Expression):
        sel = [sel]

    needed = clog2(len(fs))
    if len(sel) < needed:
        fstr = "expected at least {} select bits, got {}"
        raise ValueError(fstr.format(needed, len(sel)))

    # Each input function is gated by the next term over the select bits.
    terms = boolfunc.iter_terms(sel)
    products = []
    for f in fs:
        products.append(ExprAnd(f, *next(terms)))
    muxed = ExprOr(*products)
    return muxed.simplify() if simplify else muxed
def ForAll(vs, expr): # pragma: no cover
    """
    Return an expression that means
    "for all variables in *vs*, *expr* is true".

    Built as the conjunction of the cofactors of *expr* over *vs*.
    """
    cofactors = expr.cofactors(vs)
    return And(*cofactors)
def Exists(vs, expr): # pragma: no cover
    """
    Return an expression that means
    "there exists a variable in *vs* such that *expr* is true".

    Built as the disjunction of the cofactors of *expr* over *vs*.
    """
    cofactors = expr.cofactors(vs)
    return Or(*cofactors)
class Expression(boolfunc.Function):
    """Boolean function represented by a logic expression

    .. seealso::
       This is a subclass of :class:`pyeda.boolalg.boolfunc.Function`

    An expression is a tree data structure with operators at the branches,
    and constants/literals at the leaves.
    A literal is a variable or its complement.

    There are two ways to construct an expression:

    * Convert another function representation.
    * Use operators on existing expressions.

    Use the ``exprvar`` function to create expression variables,
    and use the Python ``~|&^`` operators for NOT, OR, AND, XOR.
    Additionally, expressions overload ``>>`` for IMPLIES.

    For example::

       >>> a, b, c, d, p, q = map(exprvar, 'abcdpq')
       >>> f = ~a | b & c ^ d
       >>> g = p >> q

    To create unsimplified expressions, use class constructors::

       >>> ExprOr(0, a)
       Or(0, a)

    To create simplified expressions,
    use function form operators with ``simplify=True``,
    or Python ``~|&^`` operators::

       >>> Or(0, a)
       a
       >>> 0 | a
       a
    """

    ASTOP = NotImplemented
    PRECEDENCE = -1

    def __init__(self):
        # Subclasses set this flag when they are simplified by construction.
        self._simplified = False
        # Cache for the inverse expression, filled lazily by ``_inv``.
        self._inverse = None

    def __str__(self):
        raise NotImplementedError()

    def __repr__(self):
        return self.__str__()

    def __lt__(self, other): # pragma: no cover
        # Fallback ordering by object identity; subclasses refine it.
        return id(self) < id(other)

    def to_unicode(self):
        """Return a string representation using Unicode symbols."""
        raise NotImplementedError()

    def to_latex(self):
        """Return a string representation using Latex."""
        raise NotImplementedError()

    # Operators
    def __invert__(self):
        return self.simplify()._inv

    def __or__(self, g):
        return ExprOr(self, self.box(g)).simplify()

    def __and__(self, g):
        return ExprAnd(self, self.box(g)).simplify()

    def __xor__(self, g):
        return ExprXor(self, self.box(g)).simplify()

    def __rshift__(self, g):
        r"""Boolean implication operator

        +-----------+-----------+----------------------+
        | :math:`f` | :math:`g` | :math:`f \implies g` |
        +===========+===========+======================+
        |         0 |         0 |                    1 |
        +-----------+-----------+----------------------+
        |         0 |         1 |                    1 |
        +-----------+-----------+----------------------+
        |         1 |         0 |                    0 |
        +-----------+-----------+----------------------+
        |         1 |         1 |                    1 |
        +-----------+-----------+----------------------+
        """
        return ExprImplies(self, self.box(g)).simplify()

    def __rrshift__(self, g):
        """Reverse Boolean implication operator"""
        return ExprImplies(self.box(g), self).simplify()

    # From Function
    @cached_property
    def inputs(self):
        return tuple(sorted(self.support))

    def urestrict(self, upoint):
        # A uniqid present on both sides of the point is constrained to
        # 0 and 1 at once; report every such variable with its complement.
        intersect = upoint[0] & upoint[1]
        if intersect:
            parts = list()
            for uniqid in intersect:
                v = _EXPRLITERALS[uniqid]
                parts += [str(v), str(~v)]
            raise ValueError("conflicting constraints: " + ", ".join(parts))
        return self._urestrict2(upoint)

    def _urestrict1(self, upoint):
        """Implementation of restrict after error-checking."""
        raise NotImplementedError()

    def _urestrict2(self, upoint):
        """Implementation of restrict after error-checking."""
        return self._urestrict1(upoint).simplify()

    def satisfy_one(self):
        """Return one satisfying point, or None if there is none.

        Literals currently held in the module-level assumption set are
        taken into account.
        """
        if self.is_cnf():
            litmap, cnf = expr2dimacscnf(self)
            assumptions = [litmap[lit] for lit in _ASSUMPTIONS]
            soln = cnf.satisfy_one(assumptions)
            if soln is None:
                return None
            return cnf.soln2point(soln, litmap)

        if _ASSUMPTIONS:
            aupnt = _assume2upoint()
            upoint = sat.backtrack(self.urestrict(aupnt))
            # The search result is None when the restricted expression is
            # unsatisfiable; only merge the assumptions into a real
            # solution instead of subscripting None.
            if upoint is not None:
                upoint = (upoint[0] | aupnt[0], upoint[1] | aupnt[1])
        else:
            upoint = sat.backtrack(self)
        if upoint is None:
            return None
        return upoint2exprpoint(upoint)

    def satisfy_all(self):
        """Iterate through all satisfying points."""
        if self.is_cnf():
            litmap, cnf = expr2dimacscnf(self)
            for soln in cnf.satisfy_all():
                yield cnf.soln2point(soln, litmap)
        else:
            for upoint in sat.iter_backtrack(self):
                yield upoint2exprpoint(upoint)

    def is_zero(self):
        return False

    def is_one(self):
        return False

    @staticmethod
    def box(obj):
        """Convert an arbitrary object into an Expression, unsimplified."""
        if isinstance(obj, Expression):
            return obj
        # False, True, 0, 1
        elif isinstance(obj, int) and obj in {0, 1}:
            return CONSTANTS[obj]
        elif type(obj) is str:
            return ast2expr(pyeda.parsing.boolexpr.parse(obj))
        else:
            return CONSTANTS[bool(obj)]

    # Specific to Expression
    def traverse(self):
        """Iterate through all nodes in this expression in DFS order."""
        yield from self._traverse(set())

    def _traverse(self, visited):
        """Iterate through all nodes in this expression in DFS order."""
        raise NotImplementedError()

    def invert(self):
        """Return the inverse expression."""
        warn("Expression.invert is deprecated")
        return self._inv

    @property
    def _inv(self):
        # Build the inverse on first use and remember it.
        if self._inverse is None:
            self._inverse = ExprNot(self)
        return self._inverse

    def simplify(self):
        """Return a simplified expression."""
        raise NotImplementedError()

    @property
    def simplified(self):
        """Return True if the expression is simplified."""
        return self._simplified

    def factor(self, conj=False):
        """Return a factored expression.

        A factored expression is one and only one of the following:

        * A literal.
        * A disjunction / conjunction of factored expressions.

        Parameters

        conj : bool
            Always choose a conjunctive form when there's a choice
        """
        warn("Expression.factor is deprecated. Use Expression.to_nnf instead")
        # Forward the caller's form preference instead of dropping it.
        return self.to_nnf(conj)

    @property
    def depth(self):
        """Return the depth of the expression tree.

        Expression depth is defined recursively:

        1. A leaf node (constant or literal) has zero depth.
        2. A branch node (operator) has depth equal to the maximum depth of
           its children (arguments) plus one.
        """
        raise NotImplementedError()

    def to_ast(self):
        """Return the expression converted to an abstract syntax tree."""
        raise NotImplementedError()

    def expand(self, vs=None, conj=False):
        """Return the Shannon expansion with respect to a list of variables."""
        vs = self._expect_vars(vs)
        if vs:
            outer, inner = (ExprAnd, ExprOr) if conj else (ExprOr, ExprAnd)
            terms = [inner(self.restrict(p),
                           *boolfunc.point2term(p, conj)).simplify()
                     for p in boolfunc.iter_points(vs)]
            # Drop terms equal to the identity of the outer operator.
            if conj:
                terms = [term for term in terms if term is not EXPRONE]
            else:
                terms = [term for term in terms if term is not EXPRZERO]
            return outer(*terms)
        else:
            return self

    def to_nnf(self, conj=False):
        """Return the expression in negation normal form (NNF).

        An NNF expression is one of the following:

        * A literal.
        * A disjunction / conjunction of NNF expressions.

        Parameters

        conj : bool
            Always choose a conjunctive form when there's a choice
        """
        raise NotImplementedError()

    def _inv_nnf(self, conj=False):
        """Return an inverted factored expression."""
        raise NotImplementedError()

    def _flatten(self, op):
        """Return a factored, flattened expression."""
        raise NotImplementedError()

    def flatten(self, op):
        r"""Return a factored, flattened expression.

        Use the distributive law to flatten all nested expressions:

        .. math::

           a + (b \cdot c) = (a + b) \cdot (a + c)

           a \cdot (b + c) = (a \cdot b) + (a \cdot c)

        op : ExprOr or ExprAnd
            The operator you want to flatten. For example, if you want to
            produce an ExprOr expression, flatten all the nested ExprAnds.
        """
        if op is ExprOr or op is ExprAnd:
            return self.to_nnf()._flatten(op)
        else:
            raise ValueError("expected op to be ExprOr or ExprAnd")

    def to_dnf(self, flatten=True):
        """Return the expression in disjunctive normal form."""
        if flatten:
            return self.to_nnf()._flatten(ExprAnd)
        else:
            # One product term per point yielded by _iter_ones.
            terms = list()
            for upnt in _iter_ones(self):
                lits = [~_EXPRLITERALS[uniqid] for uniqid in upnt[0]]
                lits += [_EXPRLITERALS[uniqid] for uniqid in upnt[1]]
                terms.append(ExprAnd(*lits).simplify())
            return ExprOr(*terms).simplify()

    def to_cdnf(self, flatten=True):
        """Return the expression in canonical disjunctive normal form."""
        return self.to_dnf(flatten)._reduce()

    def is_dnf(self):
        """Return True if this expression is in disjunctive normal form."""
        # pylint: disable=R0201
        return False

    def to_cnf(self, flatten=True):
        """Return the expression in conjunctive normal form."""
        if flatten:
            return self.to_nnf()._flatten(ExprOr)
        else:
            # One sum term per point yielded by _iter_zeros.
            terms = list()
            for upnt in _iter_zeros(self):
                lits = [_EXPRLITERALS[uniqid] for uniqid in upnt[0]]
                lits += [~_EXPRLITERALS[uniqid] for uniqid in upnt[1]]
                terms.append(ExprOr(*lits).simplify())
            return ExprAnd(*terms).simplify()

    def to_ccnf(self, flatten=True):
        """Return the expression in canonical conjunctive normal form."""
        return self.to_cnf(flatten)._reduce()

    def is_cnf(self):
        """Return True if this expression is in conjunctive normal form."""
        # pylint: disable=R0201
        return False

    @cached_property
    def _cover(self):
        """Return the cover representation."""
        raise NotImplementedError()

    @property
    def cover(self):
        """Return the DNF expression as a cover of cubes."""
        if self.is_dnf():
            return self._cover
        else:
            raise ValueError("expected a DNF expression")

    def _absorb(self):
        """Return the DNF/CNF expression after absorption."""
        raise NotImplementedError()

    def absorb(self):
        r"""Return the DNF/CNF expression after absorption.

        .. math::

           a + (a \cdot b) = a

           a \cdot (a + b) = a
        """
        if self.is_dnf() or self.is_cnf():
            return self._absorb()
        else:
            raise ValueError("expected expression to be in normal form")

    def _reduce(self):
        """Reduce the DNF/CNF expression to a canonical form."""
        raise NotImplementedError()

    def reduce(self):
        """Reduce the DNF/CNF expression to a canonical form."""
        if self.is_dnf() or self.is_cnf():
            return self._reduce()
        else:
            raise ValueError("expected expression to be in normal form")

    def _encode_clause(self, litmap):
        """Encode as a compact DNF/CNF clause."""
        raise NotImplementedError()

    def _encode_dnf(self):
        """Encode as a compact DNF."""
        raise NotImplementedError()

    def _encode_cnf(self):
        """Encode as a compact CNF."""
        raise NotImplementedError()

    def encode_dnf(self):
        """Encode as a compact DNF."""
        if self.is_dnf():
            return self._encode_dnf()
        else:
            raise ValueError("expected a DNF expression")

    def encode_cnf(self):
        """Encode as a compact CNF."""
        if self.is_cnf():
            return self._encode_cnf()
        else:
            raise ValueError("expected a CNF expression")

    def tseitin(self, auxvarname='aux'):
        """Convert the expression to Tseitin's encoding."""
        if self.is_cnf():
            return self

        _, constraints = _tseitin(self.to_nnf(), auxvarname)
        # The last constraint is used as-is; every other one is turned
        # into a CNF equality between its two members.
        fst = constraints[-1][1]
        rst = [ExprEqual(f, expr).to_cnf() for f, expr in constraints[:-1]]
        return ExprAnd(fst, *rst).simplify()

    def complete_sum(self):
        """Return a DNF that contains all prime implicants."""
        if self.is_dnf():
            dnf = self
        else:
            dnf = self.to_dnf(flatten=False)
        return _complete_sum(dnf)

    def equivalent(self, other):
        """Return True if this expression is equivalent to other."""
        # Equivalent iff the exclusive or of the two has no satisfying point.
        f = ExprXor(self, self.box(other)).simplify()
        return f.satisfy_one() is None

    def to_dot(self, name='EXPR'): # pragma: no cover
        """Convert to DOT language representation."""
        parts = ['graph', name, '{', 'rankdir=BT;']
        # First pass: one node statement per expression node.
        for expr in self.traverse():
            if expr is EXPRZERO:
                parts += ['n' + str(id(expr)), '[label=0,shape=box]']
            elif expr is EXPRONE:
                parts += ['n' + str(id(expr)), '[label=1,shape=box]']
            elif isinstance(expr, ExprLiteral):
                parts += ['n' + str(id(expr)),
                          '[label="{}",shape=box]'.format(expr)]
            else:
                parts.append('n' + str(id(expr)))
                parts.append("[label={0.ASTOP},shape=circle]".format(expr))
        # Second pass: one edge from each operand to its operator.
        for expr in self.traverse():
            if isinstance(expr, ExprNot):
                parts += ['n' + str(id(expr.x)), '--',
                          'n' + str(id(expr))]
            elif isinstance(expr, ExprImplies):
                parts += ['n' + str(id(expr.p)), '--',
                          'n' + str(id(expr)), '[label=p]']
                parts += ['n' + str(id(expr.q)), '--',
                          'n' + str(id(expr)), '[label=q]']
            elif isinstance(expr, ExprITE):
                parts += ['n' + str(id(expr.s)), '--',
                          'n' + str(id(expr)), '[label=s]']
                parts += ['n' + str(id(expr.d1)), '--',
                          'n' + str(id(expr)), '[label=d1]']
                parts += ['n' + str(id(expr.d0)), '--',
                          'n' + str(id(expr)), '[label=d0]']
            elif isinstance(expr, _NaryOperator):
                for x in expr.xs:
                    parts += ['n' + str(id(x)), '--', 'n' + str(id(expr))]
        parts.append('}')
        return " ".join(parts)
class ExprConstant(Expression):
    """Expression constant

    Constants are leaves: they are always simplified, have empty support,
    and are returned unchanged by every transformation defined here.
    """

    ASTOP = 'const'
    VALUE = NotImplemented

    def __init__(self):
        super().__init__()
        self._simplified = True

    def __bool__(self):
        return bool(self.VALUE)

    def __int__(self):
        return self.VALUE

    def __str__(self):
        return str(self.VALUE)

    # From Function
    @cached_property
    def support(self):
        return frozenset()

    def _urestrict1(self, upoint):
        return self

    def compose(self, mapping):
        return self

    # From Expression
    def _traverse(self, visited):
        # Yield this node at most once per traversal.
        if self in visited:
            return
        visited.add(self)
        yield self

    def simplify(self):
        return self

    @property
    def depth(self):
        return 0

    def to_ast(self):
        return self.ASTOP, self.VALUE

    def to_nnf(self, conj=False):
        return self

    def _inv_nnf(self, conj=False):
        return self._inv

    # FactoredExpression
    def _flatten(self, op):
        return self

    def _absorb(self):
        return self

    def _reduce(self):
        return self
class _ExprZero(ExprConstant):
    """
    Expression zero

    .. note:: Never use this class. Use EXPRZERO singleton instead.
    """

    VALUE = 0

    def __lt__(self, other):
        # Zero sorts before every Expression; otherwise compare identities.
        return isinstance(other, Expression) or id(self) < id(other)

    # From Function
    def is_zero(self):
        return True

    # From Expression
    def is_dnf(self):
        return True

    @cached_property
    def _cover(self):
        return set()

    def _encode_dnf(self):
        # Zero encodes as a DNF with no clauses.
        litmap, nvars = self.encode_inputs()
        return litmap, nvars, set()
class _ExprOne(ExprConstant):
    """
    Expression one

    .. note:: Never use this class. Use EXPRONE singleton instead.
    """

    VALUE = 1

    def __lt__(self, other):
        # One sorts after zero but before every other Expression;
        # otherwise compare identities.
        if other is EXPRZERO:
            return False
        return isinstance(other, Expression) or id(self) < id(other)

    # From Function
    def is_one(self):
        return True

    # From Expression
    def is_cnf(self):
        return True

    def _encode_cnf(self):
        # One encodes as a CNF with no clauses.
        litmap, nvars = self.encode_inputs()
        return litmap, nvars, set()
# Singleton constant expressions; each is stored as the inverse of the other.
EXPRZERO = _ExprZero()
EXPRONE = _ExprOne()
EXPRZERO._inverse, EXPRONE._inverse = EXPRONE, EXPRZERO

# Indexed by value: CONSTANTS[0] is zero and CONSTANTS[1] is one.
CONSTANTS = [EXPRZERO, EXPRONE]
class ExprLiteral(Expression):
    """An instance of a variable or of its complement

    A literal can be used as a context manager: while the ``with`` block is
    active, the literal is a member of the module-level assumption set that
    ``Expression.satisfy_one`` consults.
    """

    def __init__(self):
        super(ExprLiteral, self).__init__()
        self._simplified = True

    def __enter__(self):
        _ASSUMPTIONS.add(self)
        # Return the literal so that ``with lit as x`` binds it
        # instead of None.
        return self

    def __exit__(self, exc_type, exc_val, traceback):
        _ASSUMPTIONS.discard(self)

    # From Expression
    def _traverse(self, visited):
        if self not in visited:
            visited.add(self)
            yield self

    def simplify(self):
        return self

    @property
    def depth(self):
        return 0

    def to_nnf(self, conj=False):
        return self

    def _inv_nnf(self, conj=False):
        return self._inv

    def is_dnf(self):
        return True

    def is_cnf(self):
        return True

    # FactoredExpression
    def _flatten(self, op):
        return self

    # FlattenedExpression
    @cached_property
    def _lits(self):
        """Return a frozenset of literals."""
        return frozenset([self])

    @cached_property
    def _cube(self):
        """Return the cube representation."""
        return frozenset([self])

    @cached_property
    def _cover(self):
        return {self._cube}

    def _absorb(self):
        return self

    def _reduce(self):
        return self

    def _encode_clause(self, litmap):
        return frozenset([litmap[self]])

    def _encode_dnf(self):
        # A single literal is a one-clause DNF.
        litmap, nvars = self.encode_inputs()
        clauses = {self._encode_clause(litmap)}
        return litmap, nvars, clauses

    def _encode_cnf(self):
        # A single literal is a one-clause CNF.
        litmap, nvars = self.encode_inputs()
        clauses = {self._encode_clause(litmap)}
        return litmap, nvars, clauses

    # Placeholder; concrete literal classes are expected to override it.
    minterm_index = NotImplemented
class ExprVariable(boolfunc.Variable, ExprLiteral):
    """Expression variable

    The complement literal is created together with the variable and
    stored as its inverse.
    """

    ASTOP = 'var'

    minterm_index = 1
    maxterm_index = 0

    def __init__(self, bvar):
        boolfunc.Variable.__init__(self, bvar.names, bvar.indices)
        ExprLiteral.__init__(self)
        self._inverse = _exprcomp(self)

    def to_unicode(self):
        return self.__str__()

    def to_latex(self):
        indices = ", ".join(map(str, self.indices))
        return "".join([self.qualname, "_{", indices, "}"])

    def __lt__(self, other):
        # Constants sort first; literals are ordered by their variables;
        # every other Expression sorts after a variable.
        if isinstance(other, ExprConstant):
            return False
        if isinstance(other, (ExprVariable, ExprComplement)):
            if isinstance(other, ExprVariable):
                peer = other
            else:
                peer = other.exprvar
            return boolfunc.Variable.__lt__(self, peer)
        if isinstance(other, Expression):
            return True
        return id(self) < id(other)

    # From Function
    @cached_property
    def support(self):
        return frozenset((self,))

    def _urestrict1(self, upoint):
        if self.uniqid in upoint[0]:
            return EXPRZERO
        if self.uniqid in upoint[1]:
            return EXPRONE
        return self

    def compose(self, mapping):
        try:
            return mapping[self].simplify()
        except KeyError:
            # Not substituted: the variable stands for itself.
            return self

    # From Expression
    def to_ast(self):
        return self.ASTOP, self.names, self.indices

    @property
    def splitvar(self):
        """Return a good splitting variable."""
        return self
class ExprComplement(ExprLiteral):
    """Expression complement

    Wraps the variable it negates; its uniqid is the negation of the
    variable's uniqid.
    """

    # Prime - U2032
    SYMBOL = '′'

    minterm_index = 0
    maxterm_index = 1

    def __init__(self, exprvar):
        super().__init__()
        self._inverse = exprvar
        self.exprvar = exprvar
        self.uniqid = -exprvar.uniqid

    def __str__(self):
        return '~' + str(self.exprvar)

    def to_unicode(self):
        return str(self.exprvar) + self.SYMBOL

    def to_latex(self):
        return "\\bar{" + self.exprvar.to_latex() + "}"

    def __lt__(self, other):
        if isinstance(other, ExprConstant):
            return False
        if isinstance(other, ExprVariable):
            # Same ordering as variables, except that a complement also
            # sorts before a variable with equal names and indices.
            mine = self.exprvar
            return (mine.names < other.names or
                    (mine.names == other.names and
                     mine.indices <= other.indices))
        if isinstance(other, ExprComplement):
            return boolfunc.Variable.__lt__(self.exprvar, other.exprvar)
        if isinstance(other, Expression):
            return True
        return id(self) < id(other)

    # From Function
    @cached_property
    def support(self):
        return frozenset((self.exprvar,))

    def _urestrict1(self, upoint):
        # Inverted with respect to the variable: 0 maps to one, 1 to zero.
        var_id = self.exprvar.uniqid
        if var_id in upoint[0]:
            return EXPRONE
        if var_id in upoint[1]:
            return EXPRZERO
        return self

    def compose(self, mapping):
        try:
            return mapping[self.exprvar].simplify()._inv
        except KeyError:
            # Variable not substituted: the complement stands for itself.
            return self

    # From Expression
    def to_ast(self):
        return ExprNot.ASTOP, self.exprvar.to_ast()

    @property
    def splitvar(self):
        """Return a good splitting variable."""
        return self.exprvar
class ExprNot(Expression):
    """Expression NOT operator

    Constructing a NOT over a constant or literal returns that operand's
    inverse, and constructing a NOT over a NOT returns the inner operand,
    so an ``ExprNot`` instance never directly wraps any of those.
    """

    ASTOP = 'not'

    # Logical NOT - U00AC
    SYMBOL = '¬'

    def __new__(cls, x):
        # Primitives
        if isinstance(x, (ExprConstant, ExprLiteral)):
            return x._inv
        # Auto-eliminate double negatives
        elif isinstance(x, ExprNot):
            return x.x
        else:
            return super(ExprNot, cls).__new__(cls)

    def __init__(self, x):
        super(ExprNot, self).__init__()
        # A NOT is simplified exactly when its operand is.
        self._simplified = x._simplified
        self._inverse = x
        self.x = x

    @property
    def arg(self):
        """Return the operand (alias of ``x``)."""
        return self.x

    def __str__(self):
        return "Not(" + str(self.x) + ")"

    def to_unicode(self):
        return self.SYMBOL + "(" + str(self.x.to_unicode()) + ")"

    def to_latex(self):
        return "\\overline{" + self.x.to_latex() + "}"

    # From Function
    @property
    def support(self):
        return self.x.support

    def _urestrict1(self, upoint):
        _x = self.x._urestrict1(upoint)
        # Rebuild only when the operand actually changed.
        if _x is not self.x:
            return self.__class__(_x)
        else:
            return self

    def compose(self, mapping):
        _x = self.x.compose(mapping)
        if _x is not self.x:
            return self.__class__(_x).simplify()
        else:
            return self.simplify()

    # From Expression
    def _traverse(self, visited):
        # Operand first, then this node (post-order).
        yield from self.x._traverse(visited)
        if self not in visited:
            visited.add(self)
            yield self

    def simplify(self):
        if self._simplified:
            return self
        return self.x.simplify()._inv

    @property
    def depth(self):
        return self.x.depth + 1

    def to_nnf(self, conj=False):
        """Return the NNF of this negation.

        The *conj* preference is forwarded to the operand rather than
        being replaced by the default.
        """
        return self.x._inv_nnf(conj)

    def _inv_nnf(self, conj=False):
        """Return the NNF of the inverse, i.e. the NNF of the operand.

        The *conj* preference is forwarded to the operand rather than
        being replaced by the default.
        """
        return self.x.to_nnf(conj)

    def to_ast(self):
        return (self.ASTOP, self.x.to_ast())

    @property
    def splitvar(self):
        """Return a good splitting variable."""
        return self.x.splitvar
class _NaryOperator(Expression):
    """Common methods for N-ary expression operators.

    The operands are stored, in the order given, as the tuple ``self.xs``.
    """

    def __init__(self, *xs):
        super(_NaryOperator, self).__init__()
        self.xs = xs

    @property
    def args(self):
        """Return the operand tuple."""
        return self.xs

    # From Function
    @cached_property
    def support(self):
        """Return the union of the supports of all operands.

        A zero-operand operator has an empty support.
        """
        # Call union on an *instance*: the unbound form
        # ``frozenset.union(*[])`` raises TypeError when there are no
        # operands (e.g. an unsimplified ``Or()``).
        return frozenset().union(*[x.support for x in self.xs])

    def _urestrict1(self, upoint):
        """Restrict every operand; return ``self`` if none of them changed."""
        modified = False
        xs = list()
        for x in self.xs:
            _x = x._urestrict1(upoint)
            if _x is not x:
                modified = True
            xs.append(_x)
        if modified:
            return self.__class__(*xs)
        else:
            return self

    def compose(self, mapping):
        """Compose every operand with ``mapping`` and simplify the result."""
        modified = False
        xs = list()
        for x in self.xs:
            _x = x.compose(mapping)
            if _x is not x:
                modified = True
            xs.append(_x)
        if modified:
            return self.__class__(*xs).simplify()
        else:
            return self.simplify()

    # From Expression
    def _traverse(self, visited):
        """Yield the nodes of every operand, then this node if unvisited."""
        for x in self.xs:
            yield from x._traverse(visited)
        if self not in visited:
            visited.add(self)
            yield self

    @cached_property
    def depth(self):
        """Return one more than the deepest operand's depth.

        A zero-operand operator has depth 1.
        """
        # max() of an empty sequence raises ValueError, so handle the
        # zero-operand case explicitly.
        depths = [x.depth for x in self.xs]
        if not depths:
            return 1
        return max(depths) + 1

    def to_ast(self):
        """Return the AST tuple ``(ASTOP, <operand AST>, ...)``."""
        return (self.ASTOP, ) + tuple(x.to_ast() for x in self.xs)

    # Specific to _NaryOperator
    def _joinargs(self, sep):
        """Return arguments as a string, joined by a separator."""
        return sep.join(str(x) for x in sorted(self.xs))

    @cached_property
    def splitvar(self):
        """Return a good splitting variable.

        Heuristic: find the variable that appears in the max # of xs.

        Raises IndexError when the support is empty, since there is then no
        variable to return.
        """
        cnt = collections.Counter()
        for x in self.xs:
            for v in self.support:
                # bool is added as 0/1; every support variable gets a key
                cnt[v] += (v in x.support)
        return cnt.most_common(1)[0][0]
class ExprOrAnd(_NaryOperator):
    """Base class for Expression OR/AND expressions.

    Subclasses provide ``IDENTITY``, ``DOMINATOR``, ``get_dual``,
    ``term_index`` and ``_term_expand``.
    """

    IDENTITY = NotImplemented
    DOMINATOR = NotImplemented

    def __lt__(self, other):
        """Order this expression relative to ``other`` (used for sorting)."""
        if isinstance(other, ExprConstant):
            return False
        if isinstance(other, ExprLiteral):
            return self.support < other.support
        same_kind_terms = (isinstance(other, self.__class__)
                           and self.depth == other.depth == 1)
        if same_kind_terms:
            mine, theirs = self.support, other.support
            # min/max term
            if mine == theirs:
                return self.term_index < other.term_index
            # support containment
            if mine < theirs:
                return True
            if theirs < mine:
                return False
            # support disjoint
            first = sorted(mine ^ theirs)[0]
            if first in mine:
                return True
            if first in theirs:
                return False
        return id(self) < id(other)

    # From Function
    def _urestrict1(self, upoint):
        """Restrict every operand, short-circuiting on the dominator."""
        changed = False
        restricted = set()
        for x in self.xs:
            rx = x._urestrict1(upoint)
            # speed hack
            if rx is self.DOMINATOR:
                return self.DOMINATOR
            if rx is not x:
                changed = True
            restricted.add(rx)
        return self.__class__(*restricted) if changed else self

    # From Expression
    def simplify(self):
        """Return a simplified equivalent of this expression."""
        if self._simplified:
            return self

        pending = set(self.xs)
        kept = set()
        while pending:
            x = pending.pop().simplify()
            if x is self.DOMINATOR:
                return x
            if x is self.IDENTITY:
                continue
            # associative
            if isinstance(x, self.__class__):
                pending.update(x.xs)
            # complement
            elif x._inv in kept:
                return self.DOMINATOR
            else:
                kept.add(x)

        # Or() = 0; And() = 1
        if not kept:
            result = self.IDENTITY
        # Or(x) = x, And(x) = x
        elif len(kept) == 1:
            result = kept.pop()
        else:
            result = self.__class__(*kept)
        result._simplified = True
        return result

    def to_nnf(self, conj=False):
        """Return this operator applied to the NNF of every operand."""
        converted = [x.to_nnf() for x in self.xs]
        rebuilt = self.__class__(*converted)
        return rebuilt.simplify()

    def _inv_nnf(self, conj=False):
        """Return the dual operator applied to every operand's inverted NNF."""
        converted = [x._inv_nnf() for x in self.xs]
        rebuilt = self.get_dual()(*converted)
        return rebuilt.simplify()

    # Specific to ExprOrAnd
    @staticmethod
    def get_dual():
        """Return the dual function.

        The dual of Or is And, and the dual of And is Or.
        """
        raise NotImplementedError()

    @property
    def term_index(self):
        """
        Return an integer bitstring that uniquely identifies this min/max term.

        Examples
        --------

        +--------------+-------+
        | term         | index |
        +==============+=======+
        | ~a & ~b & ~c | 000   |
        | a & b & ~c   | 110   |
        | a & b & c    | 111   |
        +--------------+-------+
        | ~a | ~b | ~c | 111   |
        | a | b | ~c   | 001   |
        | a | b | c    | 000   |
        +--------------+-------+
        """
        raise NotImplementedError()

    # FactoredExpression
    def _flatten(self, op):
        """Flatten this expression with respect to the operator class ``op``."""
        if not isinstance(self, op):
            flattened = [x._flatten(op)._absorb() if x.depth > 1 else x
                         for x in self.xs]
            return op.get_dual()(*flattened).simplify()

        self_dual = self.get_dual()
        for i, xi in enumerate(self.xs):
            if isinstance(xi, self_dual):
                # Distribute the remaining operands over the first
                # dual-typed operand, then keep flattening the result.
                others = self.xs[:i] + self.xs[i+1:]
                distributed = [op(x, *others) for x in xi.xs]
                expr = op.get_dual()(*distributed).simplify()
                return expr._flatten(op)._absorb()
        return self

    # FlattenedExpression
    @cached_property
    def _lits(self):
        """Return a frozenset of literals."""
        return frozenset(self.xs)

    def _absorb(self):
        """Drop every term whose literal set contains another term's."""
        dual = self.get_dual()
        # Get rid of all equivalent terms
        pending = {x._lits for x in self.xs}
        survivors = list()
        # Drop all terms that are a superset of other terms
        while pending:
            head = pending.pop()
            head_absorbed = False
            absorbed = set()
            for other in pending:
                if head > other:
                    head_absorbed = True
                elif head < other:
                    absorbed.add(other)
            if not head_absorbed:
                survivors.append(dual(*head).simplify())
            pending -= absorbed
        return self.__class__(*survivors).simplify()

    def _reduce(self):
        """Expand every term over the full support, dropping duplicates."""
        if self.depth == 1:
            return self
        terms = list()
        seen = set()
        for term in self.xs:
            missing = list(self.support - term.support)
            if missing:
                expanded = self._term_expand(term, missing)
            else:
                expanded = (term, )
            for eterm in expanded:
                index = eterm.term_index
                if index not in seen:
                    terms.append(eterm)
                    seen.add(index)
        return self.__class__(*terms).simplify()

    def _encode_clause(self, litmap):
        """Return the frozenset of ``litmap`` entries for the operands."""
        return frozenset(litmap[x] for x in self.xs)

    @staticmethod
    def _term_expand(term, vs):
        """Return a term expanded by a list of variables."""
        raise NotImplementedError()
class ExprOr(ExprOrAnd):
    """Expression OR operator"""

    ASTOP = 'or'
    SYMBOL = '+'
    LATEX_SYMBOL = '+'
    PRECEDENCE = 2

    IDENTITY = EXPRZERO
    DOMINATOR = EXPRONE

    def __str__(self):
        return "".join(("Or(", self._joinargs(", "), ")"))

    def to_unicode(self):
        """Return the sorted operands joined by ``SYMBOL``."""
        sep = " " + self.SYMBOL + " "
        # lower precedence: implies, equal
        return sep.join(
            '(' + x.to_unicode() + ')' if x.PRECEDENCE >= self.PRECEDENCE
            else x.to_unicode()
            for x in sorted(self.xs))

    def to_latex(self):
        """Return the sorted operands joined by ``LATEX_SYMBOL``."""
        sep = " " + self.LATEX_SYMBOL + " "
        # lower precedence: implies, equal
        return sep.join(
            '(' + x.to_latex() + ')' if x.PRECEDENCE >= self.PRECEDENCE
            else x.to_latex()
            for x in sorted(self.xs))

    # From Expression
    def is_dnf(self):
        """Return True if this expression is in disjunctive normal form."""
        depth = self.depth
        # a | b
        if depth == 1:
            return all(isinstance(x, ExprLiteral) for x in self.xs)
        # a | b & c
        if depth == 2:
            return all(isinstance(x, ExprLiteral) or
                       isinstance(x, ExprAnd) and x.is_cnf()
                       for x in self.xs)
        return False

    def is_cnf(self):
        """Return True if this expression is a single clause of literals."""
        # a | b
        return (self.depth == 1 and
                all(isinstance(x, ExprLiteral) for x in self.xs))

    def _encode_dnf(self):
        """Return ``(litmap, nvars, clauses)``, one clause per operand."""
        litmap, nvars = self.encode_inputs()
        clauses = set()
        for x in self.xs:
            clauses.add(x._encode_clause(litmap))
        return litmap, nvars, clauses

    def _encode_cnf(self):
        """Return ``(litmap, nvars, clauses)`` with this single clause."""
        litmap, nvars = self.encode_inputs()
        only_clause = self._encode_clause(litmap)
        return litmap, nvars, {only_clause}

    # Specific to ExprOrAnd
    @staticmethod
    def get_dual():
        """Return the dual operator class, ``ExprAnd``."""
        return ExprAnd

    # FlattenedExpression
    @cached_property
    def _cover(self):
        """Return the set of the operands' ``_cube`` values."""
        cubes = set()
        for x in self.xs:
            cubes.add(x._cube)
        return cubes

    @property
    def term_index(self):
        """Return this term's index (its maxterm index)."""
        return self.maxterm_index

    # Specific to ExprOr
    @cached_property
    def maxterm_index(self):
        """Return this maxterm's unique index."""
        top_bit = self.degree - 1
        index = 0
        for position, variable in enumerate(self.inputs):
            # A complemented input contributes a 1 bit; the first input is
            # the most significant bit.
            if ~variable in self.xs:
                index |= 1 << (top_bit - position)
        return index

    @staticmethod
    def _term_expand(term, vs):
        """Return the operands of ``term`` expanded over ``vs``."""
        expanded = term.expand(vs, conj=False)
        return expanded.xs
class ExprAnd(ExprOrAnd):
    """Expression AND operator.

    An instance whose operands are all literals can be used as a context
    manager: entering adds each literal to the module-level ``_ASSUMPTIONS``
    collection and exiting discards them again.
    """

    ASTOP = 'and'
    # Middle dot - U00B7
    SYMBOL = '·'
    LATEX_SYMBOL = '\\cdot'
    PRECEDENCE = 0

    IDENTITY = EXPRONE
    DOMINATOR = EXPRZERO

    def __str__(self):
        return "And(" + self._joinargs(", ") + ")"

    def __enter__(self):
        """Register every operand as an assumption.

        Raises ValueError if any operand is not a literal; in that case
        ``_ASSUMPTIONS`` is left untouched.
        """
        # Validate everything *before* mutating the shared collection.
        # __exit__ is not invoked when __enter__ raises, so adding literals
        # one by one and failing part-way would leak stale assumptions.
        if not all(isinstance(x, ExprLiteral) for x in self.xs):
            raise ValueError("expected assumption to be a literal")
        for x in self.xs:
            _ASSUMPTIONS.add(x)

    def __exit__(self, exc_type, exc_val, traceback):
        """Discard every literal operand from the assumptions.

        Raises ValueError if any operand is not a literal, but only after
        all literal operands have been discarded.
        """
        all_literals = True
        for x in self.xs:
            if isinstance(x, ExprLiteral):
                _ASSUMPTIONS.discard(x)
            else:
                all_literals = False
        if not all_literals:
            raise ValueError("expected assumption to be a literal")

    def to_unicode(self):
        """Return the sorted operands joined by ``SYMBOL``."""
        parts = list()
        for x in sorted(self.xs):
            # lower precedence: or, xor, implies, equal
            if x.PRECEDENCE >= self.PRECEDENCE:
                parts.append('(' + x.to_unicode() + ')')
            else:
                parts.append(x.to_unicode())
        sep = " " + self.SYMBOL + " "
        return sep.join(parts)

    def to_latex(self):
        """Return the sorted operands joined by ``LATEX_SYMBOL``."""
        parts = list()
        for x in sorted(self.xs):
            # lower precedence: or, xor, implies, equal
            if x.PRECEDENCE >= self.PRECEDENCE:
                parts.append('(' + x.to_latex() + ')')
            else:
                parts.append(x.to_latex())
        sep = " " + self.LATEX_SYMBOL + " "
        return sep.join(parts)

    # From Expression
    def is_dnf(self):
        """Return True if this expression is a single product of literals."""
        # a & b
        if self.depth == 1:
            return all(isinstance(x, ExprLiteral) for x in self.xs)
        else:
            return False

    def is_cnf(self):
        """Return True if this expression is in conjunctive normal form."""
        # a & b
        if self.depth == 1:
            return all(isinstance(x, ExprLiteral) for x in self.xs)
        # a & (b | c)
        elif self.depth == 2:
            return all(isinstance(x, ExprLiteral) or
                       isinstance(x, ExprOr) and x.is_dnf()
                       for x in self.xs)
        else:
            return False

    def _encode_dnf(self):
        """Return ``(litmap, nvars, clauses)`` with this single clause."""
        litmap, nvars = self.encode_inputs()
        clauses = {self._encode_clause(litmap)}
        return litmap, nvars, clauses

    def _encode_cnf(self):
        """Return ``(litmap, nvars, clauses)``, one clause per operand."""
        litmap, nvars = self.encode_inputs()
        clauses = {x._encode_clause(litmap) for x in self.xs}
        return litmap, nvars, clauses

    # From ExprOrAnd
    @staticmethod
    def get_dual():
        """Return the dual operator class, ``ExprOr``."""
        return ExprOr

    # FlattenedExpression
    @cached_property
    def _cube(self):
        """Return the cube representation."""
        return frozenset(self.xs)

    @cached_property
    def _cover(self):
        """Return a set containing only this expression's cube."""
        return {self._cube}

    @property
    def term_index(self):
        """Return this term's index (its minterm index)."""
        return self.minterm_index

    # Specific to ExprAnd
    @cached_property
    def minterm_index(self):
        """Return this minterm's unique index."""
        num = self.degree - 1
        index = 0
        for i, v in enumerate(self.inputs):
            # An uncomplemented input contributes a 1 bit; the first input
            # is the most significant bit.
            if v in self.xs:
                index |= 1 << (num - i)
        return index

    @staticmethod
    def _term_expand(term, vs):
        """Return the operands of ``term`` expanded over ``vs``."""
        return term.expand(vs, conj=True).xs
class ExprXor(_NaryOperator):
    """Expression exclusive OR (XOR) operator"""

    ASTOP = 'xor'
    # Circled plus - U2295
    SYMBOL = '⊕'
    LATEX_SYMBOL = '\\oplus'
    PRECEDENCE = 1

    IDENTITY = EXPRZERO

    def __str__(self):
        return "".join(("Xor(", self._joinargs(", "), ")"))

    def to_unicode(self):
        """Return the sorted operands joined by ``SYMBOL``."""
        sep = " " + self.SYMBOL + " "
        # lower precedence: or, implies, equal
        return sep.join(
            '(' + x.to_unicode() + ')' if x.PRECEDENCE >= self.PRECEDENCE
            else x.to_unicode()
            for x in sorted(self.xs))

    def to_latex(self):
        """Return the sorted operands joined by ``LATEX_SYMBOL``."""
        sep = " " + self.LATEX_SYMBOL + " "
        # lower precedence: or, implies, equal
        return sep.join(
            '(' + x.to_latex() + ')' if x.PRECEDENCE >= self.PRECEDENCE
            else x.to_latex()
            for x in sorted(self.xs))

    # From Expression
    def simplify(self):
        """Return a simplified equivalent of this expression."""
        if self._simplified:
            return self

        # par stays 1 while the result keeps its polarity; every toggle
        # below flips it, and 0 at the end means "return the inverse".
        par = 1
        pending = list(self.xs)
        kept = set()
        while pending:
            x = pending.pop().simplify()
            if isinstance(x, ExprConstant):
                par ^= int(x)
            # associative
            elif isinstance(x, self.__class__):
                pending.extend(x.xs)
            # (x, ~x) is either (0, 1) or (1, 0)
            elif x._inv in kept:
                kept.remove(x._inv)
                par ^= 1
            # (x, x) is either (0, 0) or (1, 1)
            elif x in kept:
                kept.remove(x)
            else:
                kept.add(x)

        # Xor() = 0
        if not kept:
            result = self.IDENTITY
        # Xor(x) = x; Xnor(x) = ~x
        elif len(kept) == 1:
            result = kept.pop()
        else:
            result = ExprXor(*kept)
        result._simplified = True

        if par:
            return result
        return result._inv

    def _parity_nnf(self, conj, wanted):
        """Build the two-level NNF from the truth-table rows of one parity.

        ``conj`` selects an AND of ORs when true and an OR of ANDs otherwise.
        ``wanted`` is the value ``parity(num)`` must equal for row ``num`` to
        contribute a term.
        """
        outer, inner = (ExprAnd, ExprOr) if conj else (ExprOr, ExprAnd)
        terms = list()
        # Walk through all entries in the truth table
        for num in range(1 << len(self.xs)):
            if parity(num) != wanted:
                continue
            term = list()
            for i, xi in enumerate(self.xs):
                # When conj is true a set bit selects the inverted operand;
                # otherwise a clear bit does.
                if conj:
                    negate = bit_on(num, i)
                else:
                    negate = not bit_on(num, i)
                term.append(xi._inv_nnf() if negate else xi.to_nnf())
            terms.append(inner(*term))
        return outer(*terms).simplify()

    def to_nnf(self, conj=False):
        """Return an equivalent expression in negation normal form."""
        return self._parity_nnf(conj, 0 if conj else 1)

    def _inv_nnf(self, conj=False):
        """Return the negation of this expression in negation normal form."""
        return self._parity_nnf(conj, 1 if conj else 0)
class ExprEqual(_NaryOperator):
    """Expression EQUAL operator"""

    ASTOP = 'equal'
    # Left right double arrow - 21D4
    SYMBOL = '⇔'
    LATEX_SYMBOL = '\\Leftrightarrow'
    PRECEDENCE = 4

    IDENTITY = EXPRONE

    def __str__(self):
        return "".join(("Equal(", self._joinargs(", "), ")"))

    def to_unicode(self):
        """Return the operands, in stored order, joined by ``SYMBOL``."""
        sep = " " + self.SYMBOL + " "
        # lower precedence:
        return sep.join(
            '(' + x.to_unicode() + ')' if x.PRECEDENCE >= self.PRECEDENCE
            else x.to_unicode()
            for x in self.xs)

    def to_latex(self):
        """Return the operands, in stored order, joined by ``LATEX_SYMBOL``."""
        sep = " " + self.LATEX_SYMBOL + " "
        # lower precedence:
        return sep.join(
            '(' + x.to_latex() + ')' if x.PRECEDENCE >= self.PRECEDENCE
            else x.to_latex()
            for x in self.xs)

    # From Expression
    def simplify(self):
        """Return a simplified equivalent of this expression."""
        if self._simplified:
            return self

        xs = {x.simplify() for x in self.xs}

        if EXPRZERO in xs:
            # Equal(0, 1, ...) = 0
            if EXPRONE in xs:
                return EXPRZERO
            # Equal(0, x0, x1, ...) = Nor(x0, x1, ...)
            xs.remove(EXPRZERO)
            return ExprOr(*xs).simplify()._inv

        # Equal(1, x0, x1, ...) = And(x0, x1, ...)
        if EXPRONE in xs:
            xs.remove(EXPRONE)
            return ExprAnd(*xs).simplify()

        # no constants; all simplified
        pending, seen = xs, set()
        while pending:
            x = pending.pop()
            # Equal(x, ~x) = 0
            if x._inv in seen:
                return EXPRZERO
            seen.add(x)

        # Equal(x) = Equal() = 1
        if len(seen) <= 1:
            result = self.IDENTITY
        else:
            result = self.__class__(*seen)
        result._simplified = True
        return result

    def to_nnf(self, conj=False):
        """Return an equivalent expression in negation normal form."""
        if not conj:
            # all operands are 0, or all operands are 1
            all0 = ExprAnd(*[x._inv_nnf() for x in self.xs])
            all1 = ExprAnd(*[x.to_nnf() for x in self.xs])
            return ExprOr(all0, all1).simplify()

        # every pair of operands agrees: (~x0 | x1) & (x0 | ~x1)
        clauses = list()
        for x0, x1 in itertools.combinations(self.xs, 2):
            clauses.append(ExprOr(x0._inv_nnf(), x1.to_nnf()))
            clauses.append(ExprOr(x0.to_nnf(), x1._inv_nnf()))
        return ExprAnd(*clauses).simplify()

    def _inv_nnf(self, conj=False):
        """Return the negation of this expression in negation normal form."""
        if conj:
            # some operand is 0, and some operand is 1
            any0 = ExprOr(*[x._inv_nnf() for x in self.xs])
            any1 = ExprOr(*[x.to_nnf() for x in self.xs])
            return ExprAnd(any0, any1).simplify()

        # some pair of operands disagrees: (~x0 & x1) | (x0 & ~x1)
        products = list()
        for x0, x1 in itertools.combinations(self.xs, 2):
            products.append(ExprAnd(x0._inv_nnf(), x1.to_nnf()))
            products.append(ExprAnd(x0.to_nnf(), x1._inv_nnf()))
        return ExprOr(*products).simplify()
class ExprImplies(Expression):
    """Expression implication operator.

    Represents ``p => q``; the antecedent and consequent are stored as
    ``self.p`` and ``self.q``.
    """

    ASTOP = 'implies'
    # Rightwards double arrow - 21D2
    SYMBOL = '⇒'
    LATEX_SYMBOL = '\\Rightarrow'
    PRECEDENCE = 3

    def __init__(self, p, q):
        super(ExprImplies, self).__init__()
        self.p = p
        self.q = q

    @property
    def args(self):
        """Return the operands as the tuple ``(p, q)``."""
        return self.p, self.q

    def __str__(self):
        return "Implies({0.p}, {0.q})".format(self)

    def to_unicode(self):
        """Return ``p SYMBOL q`` using the operands' unicode forms."""
        parts = list()
        for x in (self.p, self.q):
            # lower precedence: equal
            if x.PRECEDENCE >= self.PRECEDENCE:
                parts.append('(' + x.to_unicode() + ')')
            else:
                parts.append(x.to_unicode())
        sep = " " + self.SYMBOL + " "
        return sep.join(parts)

    def to_latex(self):
        """Return ``p LATEX_SYMBOL q`` using the operands' LaTeX forms."""
        parts = list()
        for x in (self.p, self.q):
            # lower precedence: equal
            if x.PRECEDENCE >= self.PRECEDENCE:
                parts.append('(' + x.to_latex() + ')')
            else:
                parts.append(x.to_latex())
        sep = " " + self.LATEX_SYMBOL + " "
        return sep.join(parts)

    # From Function
    @cached_property
    def support(self):
        """Return the union of the supports of ``p`` and ``q``."""
        return frozenset.union(self.p.support, self.q.support)

    def _urestrict1(self, upoint):
        """Restrict both operands; return ``self`` if neither changed."""
        p = self.p._urestrict1(upoint)
        q = self.q._urestrict1(upoint)
        if p is self.p and q is self.q:
            return self
        else:
            return self.__class__(p, q)

    def compose(self, mapping):
        """Compose both operands with ``mapping`` and simplify the result."""
        p = self.p.compose(mapping)
        q = self.q.compose(mapping)
        if p is self.p and q is self.q:
            return self.simplify()
        else:
            return self.__class__(p, q).simplify()

    # From Expression
    def simplify(self):
        """Return a simplified equivalent of this implication.

        An instance already marked as simplified is returned unchanged.
        """
        # Same short-circuit as the other simplify() methods in this module.
        # Without it an already-simplified node is rebuilt on every call,
        # which also defeats the ``is`` identity checks used by compose().
        if self._simplified:
            return self

        p = self.p.simplify()
        q = self.q.simplify()

        # 0 => q = 1; p => 1 = 1
        if p is EXPRZERO or q is EXPRONE:
            return EXPRONE
        # 1 => q = q
        elif p is EXPRONE:
            return q
        # p => 0 = ~p
        elif q is EXPRZERO:
            return p._inv
        # p => p = 1
        elif p is q:
            return EXPRONE
        # p => ~p = ~p  (q is the inverse of p, so returning q gives ~p)
        elif p._inv is q:
            return q

        obj = self.__class__(p, q)
        obj._simplified = True
        return obj

    @cached_property
    def depth(self):
        """Return one more than the deeper operand's depth."""
        return max([self.p.depth, self.q.depth]) + 1

    def to_nnf(self, conj=False):
        """Return ``~p | q`` in negation normal form."""
        return ExprOr(self.p._inv_nnf(), self.q.to_nnf()).simplify()

    def _inv_nnf(self, conj=False):
        """Return ``p & ~q``, the negation, in negation normal form."""
        return ExprAnd(self.p.to_nnf(), self.q._inv_nnf()).simplify()
class ExprITE(Expression):
    """Expression if-then-else ternary operator.

    Represents ``s ? d1 : d0``; the operands are stored as ``self.s``,
    ``self.d1`` and ``self.d0``.
    """

    ASTOP = 'ite'

    def __init__(self, s, d1, d0):
        super(ExprITE, self).__init__()
        self.s = s
        self.d1 = d1
        self.d0 = d0

    @property
    def args(self):
        """Return the operands as the tuple ``(s, d1, d0)``."""
        return self.s, self.d1, self.d0

    def __str__(self):
        return "ITE({0.s}, {0.d1}, {0.d0})".format(self)

    def to_unicode(self):
        """Return ``ite(s, d1, d0)`` using the operands' unicode forms."""
        unicode_args = [self.s.to_unicode(),
                        self.d1.to_unicode(), self.d0.to_unicode()]
        return "ite({}, {}, {})".format(*unicode_args)

    def to_latex(self):
        """Return ``ite(s, d1, d0)`` using the operands' LaTeX forms."""
        latex_args = [self.s.to_latex(), self.d1.to_latex(), self.d0.to_latex()]
        return "ite({}, {}, {})".format(*latex_args)

    # From Function
    @cached_property
    def support(self):
        """Return the union of the supports of ``s``, ``d1`` and ``d0``."""
        return frozenset.union(self.s.support, self.d1.support, self.d0.support)

    def _urestrict1(self, upoint):
        """Restrict all operands; return ``self`` if none of them changed."""
        s = self.s._urestrict1(upoint)
        d1 = self.d1._urestrict1(upoint)
        d0 = self.d0._urestrict1(upoint)
        if s is self.s and d1 is self.d1 and d0 is self.d0:
            return self
        else:
            return self.__class__(s, d1, d0)

    def compose(self, mapping):
        """Compose all operands with ``mapping`` and simplify the result."""
        s = self.s.compose(mapping)
        d1 = self.d1.compose(mapping)
        d0 = self.d0.compose(mapping)
        if s is self.s and d1 is self.d1 and d0 is self.d0:
            return self.simplify()
        else:
            return self.__class__(s, d1, d0).simplify()

    # From Expression
    def simplify(self):
        """Return a simplified equivalent of this if-then-else.

        An instance already marked as simplified is returned unchanged.
        """
        # Same short-circuit as the other simplify() methods in this module.
        # Without it an already-simplified node is rebuilt on every call,
        # which also defeats the ``is`` identity checks used by compose().
        if self._simplified:
            return self

        s = self.s.simplify()
        d1 = self.d1.simplify()
        d0 = self.d0.simplify()

        # 0 ? d1 : d0 = d0
        if s is EXPRZERO:
            return d0
        # 1 ? d1 : d0 = d1
        elif s is EXPRONE:
            return d1
        elif d1 is EXPRZERO:
            # s ? 0 : 0 = 0
            if d0 is EXPRZERO:
                return EXPRZERO
            # s ? 0 : 1 = ~s
            elif d0 is EXPRONE:
                return s._inv
            # s ? 0 : d0 = ~s & d0
            else:
                return ExprAnd(s._inv, d0).simplify()
        elif d1 is EXPRONE:
            # s ? 1 : 0 = s
            if d0 is EXPRZERO:
                return s
            # s ? 1 : 1 = 1
            elif d0 is EXPRONE:
                return EXPRONE
            # s ? 1 : d0 = s | d0
            else:
                return ExprOr(s, d0).simplify()
        # s ? d1 : 0 = s & d1
        elif d0 is EXPRZERO:
            return ExprAnd(s, d1).simplify()
        # s ? d1 : 1 = ~s | d1
        elif d0 is EXPRONE:
            return ExprOr(s._inv, d1).simplify()
        # s ? d1 : d1 = d1
        elif d1 is d0:
            return d1

        obj = self.__class__(s, d1, d0)
        obj._simplified = True
        return obj

    @cached_property
    def depth(self):
        """Return one more than the deepest operand's depth."""
        return max([self.s.depth, self.d1.depth, self.d0.depth]) + 1

    def to_nnf(self, conj=False):
        """Return an equivalent expression in negation normal form.

        With ``conj`` true the result is an AND of two ORs, otherwise an OR
        of two ANDs.
        """
        if conj:
            # (~s | d1) & (s | d0)
            x0 = ExprOr(self.s._inv_nnf(), self.d1.to_nnf())
            x1 = ExprOr(self.s.to_nnf(), self.d0.to_nnf())
            return ExprAnd(x0, x1).simplify()
        else:
            # s & d1 | ~s & d0
            x0 = ExprAnd(self.s.to_nnf(), self.d1.to_nnf())
            x1 = ExprAnd(self.s._inv_nnf(), self.d0.to_nnf())
            return ExprOr(x0, x1).simplify()

    def _inv_nnf(self, conj=False):
        """Return the negation of this expression in negation normal form.

        With ``conj`` true the result is an AND of two ORs, otherwise an OR
        of two ANDs.
        """
        if conj:
            # (~s | ~d1) & (s | ~d0)
            x0 = ExprOr(self.s._inv_nnf(), self.d1._inv_nnf())
            x1 = ExprOr(self.s.to_nnf(), self.d0._inv_nnf())
            return ExprAnd(x0, x1).simplify()
        else:
            # s & ~d1 | ~s & ~d0
            x0 = ExprAnd(self.s.to_nnf(), self.d1._inv_nnf())
            x1 = ExprAnd(self.s._inv_nnf(), self.d0._inv_nnf())
            return ExprOr(x0, x1).simplify()
class DimacsCNF(ConjNormalForm):
    """Wrapper class for a DIMACS CNF representation"""

    def __str__(self):
        """Return the parent's string form preceded by a ``p cnf`` header.

        The header line carries ``self.nvars`` and ``self.nclauses``.
        """
        body = super(DimacsCNF, self).__str__()
        template = "p cnf {0.nvars} {0.nclauses}\n{1}"
        return template.format(self, body)
def _iter_zeros(expr):
    """Iterate through all upoints that map to element zero.

    Each yielded upoint is a pair of frozensets of variable uniqids.
    """
    if expr is EXPRZERO:
        yield frozenset(), frozenset()
        return
    if expr is EXPRONE:
        # A constant one has no zeros.
        return

    v = expr.splitvar
    # Split on v: place its uniqid in the first slot, then in the second.
    branches = [
        (frozenset([v.uniqid]), frozenset()),
        (frozenset(), frozenset([v.uniqid])),
    ]
    for branch in branches:
        restricted = expr._urestrict2(branch)
        for rest in _iter_zeros(restricted):
            yield (branch[0] | rest[0], branch[1] | rest[1])
def _iter_ones(expr):
    """Iterate through all upoints that map to element one.

    Each yielded upoint is a pair of frozensets of variable uniqids.
    """
    if expr is EXPRONE:
        yield frozenset(), frozenset()
        return
    if expr is EXPRZERO:
        # A constant zero has no ones.
        return

    v = expr.splitvar
    # Split on v: place its uniqid in the first slot, then in the second.
    branches = [
        (frozenset([v.uniqid]), frozenset()),
        (frozenset(), frozenset([v.uniqid])),
    ]
    for branch in branches:
        restricted = expr._urestrict2(branch)
        for rest in _iter_ones(restricted):
            yield (branch[0] | rest[0], branch[1] | rest[1])
def _tseitin(expr, auxvarname, auxvars=None):
    """
    Convert a factored expression to a literal, and a list of constraints.

    A literal is returned as-is with no constraints.  Any other expression
    has its operands (``expr.xs``) converted recursively and is then
    replaced by a fresh auxiliary variable named ``auxvarname`` and indexed
    by the number of auxiliary variables created so far.  Each constraint is
    an ``(auxvar, operator-over-literals)`` pair.  ``auxvars`` is the shared
    list of auxiliary variables; a new list is started when it is None.
    """
    if isinstance(expr, ExprLiteral):
        return expr, list()

    if auxvars is None:
        auxvars = list()

    operands = list()
    constraints = list()
    for x in expr.xs:
        lit, subconstraints = _tseitin(x, auxvarname, auxvars)
        operands.append(lit)
        constraints += subconstraints

    # The index is taken after the recursion, so nested operators get the
    # lower-numbered auxiliary variables.
    auxvar = exprvar(auxvarname, len(auxvars))
    auxvars.append(auxvar)

    constraints.append((auxvar, expr.__class__(*operands)))
    return auxvar, constraints
def _complete_sum(dnf):
    """
    Recursive complete_sum function implementation.

    CS(f) = ABS([x1 | CS(0, x2, ..., xn)] & [~x1 | CS(1, x2, ..., xn)])
    """
    if dnf.depth <= 1:
        return dnf

    v = dnf.splitvar
    fv0, fv1 = dnf.cofactors(v)
    low = Or(v, _complete_sum(fv0))
    high = Or(~v, _complete_sum(fv1))
    f = And(low, high)
    if isinstance(f, ExprAnd):
        # Multiply out the two factors, one product per pair of their
        # literal-set members.
        products = list()
        for x in f.xs[0]._lits:
            for y in f.xs[1]._lits:
                products.append(And(x, y))
        f = Or(*products)
    return f._absorb()
# Convenience dictionaries
# ASTOPS maps an AST operator name (the ``func`` / first element of an AST
# tuple described in the module docstring) to the callable that builds the
# matching expression: the operator class itself where one exists, otherwise
# the interface function of the same name.
# NOTE(review): presumably consumed by ast2expr -- confirm against its caller.
ASTOPS = {
ExprNot.ASTOP : ExprNot,
ExprOr.ASTOP : ExprOr,
ExprAnd.ASTOP : ExprAnd,
ExprXor.ASTOP : ExprXor,
ExprEqual.ASTOP : ExprEqual,
ExprImplies.ASTOP : ExprImplies,
ExprITE.ASTOP : ExprITE,
'nor': Nor,
'nand': Nand,
'xnor': Xnor,
'unequal': Unequal,
'onehot0' : OneHot0,
'onehot' : OneHot,
'majority' : Majority,
'achillesheel' : AchillesHeel,
}