#***************************************************************************
# *   Copyright (C) 2003-2006 by A Lynch                                    *
# *   aalynch@users.sourceforge.net                                         *
# *                                                                         *
# *   This program is free software; you can redistribute it and/or modify  *
# *   it under the terms of the GNU General Public License version 2 as published by  *
# *   the Free Software Foundation;                                         *
# *                                                                         *
# *   This program is distributed in the hope that it will be useful,       *
# *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
# *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
# *   GNU General Public License for more details.                          *
# *                                                                         *
# *   You should have received a copy of the GNU General Public License     *
# *   along with this program; if not, write to the                         *
# *   Free Software Foundation, Inc.,                                       *
# *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
# ***************************************************************************/

import sys
import constants

import antlr
from maximaparsertokens import tokenValues
from configuration import config

import logging
l = logging.getLogger(__name__)

# Font factory used by Node for fonts and metrics (fontfactory.makeFont(),
# fontfactory.getFontMetrics()).  It is None here and must be assigned before
# any layout happens -- presumably by the application at start-up (TODO confirm).
fontfactory = None

# Result codes returned (with the chosen node) by Node.doLineBreak():
#   BREAK_DONE         - a node to break at was found
#   BREAK_NOT_NEEDED   - the node already ends before the line width
#   BREAK_NOT_POSSIBLE - the node is too wide but is not a separable operator
BREAK_DONE, BREAK_NOT_NEEDED, BREAK_NOT_POSSIBLE = 1, 2, 3

def getScalingFunctions(factory):
    """Install the factory's scaling helpers as module-level functions.

    ``factory.getScaledBracketWidth`` and ``factory.getIntegralWidth`` are
    bound as attributes of the module registered under the name ``node`` in
    ``sys.modules``, so they can be called there as plain functions.

    Raises KeyError if no ``node`` module is loaded, and AttributeError if the
    factory lacks one of the helpers.
    """
    targetModule = sys.modules["node"]
    for helperName in ("getScaledBracketWidth", "getIntegralWidth"):
        setattr(targetModule, helperName, getattr(factory, helperName))

class Node(antlr.CommonAST):
    """Base node of the parsed expression tree, with layout and drawing support.

    Besides the ANTLR AST data (token type and text), each node keeps:
    - a plain Python list of children in ``self.nodes`` (filled by setChildParent);
    - a ``parent`` link;
    - font settings;
    - its layout geometry (x, y, w, h and operatorcentreline).

    Call layoutNode() to size and position a tree, then drawNode() to paint it.
    """

    def __init__(self, token=None):
        antlr.CommonAST.__init__(self, token)

        # layout geometry: x is the left edge and y the operator centre line
        # (see setTextPosition); w and h are the width and height
        self.x = 0
        self.y = 0
        self.w = 0
        self.h = 0
        self.id = ""
        # children as a Python list, mirroring the ANTLR first-child/next-sibling chain
        self.nodes = []
        self.italic = False
        self.parent = None
        # when True, the parent skips this node in drawNodeText
        self.dontDraw = False
        self.font = None
        self.fontSize = 5
        self.fontName = None
        self.metrics = None
        # operators are centered vertically on this line, which is about half a character up from the baseline
        # it is the distance from the operator centre line to the bottom of the item
        self.operatorcentreline = 0
        self.requestReduceVerticalHeight = False

    def clone(self, newNode):
        """Copy this node's layout, font, tree-link and text attributes onto newNode.

        The child list is copied shallowly: newNode gets its own list holding
        the same child objects.  Adding or removing children on one node
        therefore does not affect the other.
        """
        newNode.x = self.x
        newNode.y = self.y
        newNode.w = self.w
        newNode.h = self.h
        newNode.nodes = list(self.nodes)
        newNode.font = self.font
        newNode.fontSize = self.fontSize
        newNode.fontName = self.fontName
        newNode.metrics = self.metrics
        newNode.italic = self.italic
        newNode.parent = self.parent
        newNode.dontDraw = self.dontDraw
        newNode.operatorcentreline = self.operatorcentreline
        newNode.requestReduceVerticalHeight = self.requestReduceVerticalHeight
        # copy this node's id string (not the builtin id function)
        newNode.id = self.id
        newNode.setText(self.getText())

    def postParse(self):
        """Hook run after parsing; the base version only recurses into the children."""
        for node in self.nodes:
            node.postParse()

    def setChildParent(self):
        """Set child.parent and fill self.nodes from the ANTLR child chain, recursively.

        Children are appended, so calling this twice on the same node
        duplicates the entries in self.nodes.
        """
        for child in self.children():
            child.parent = self
            self.nodes.append(child)
            child.setChildParent()

    def children(self):
        """Return the ANTLR children (the first child followed by its siblings) as a list."""
        children = []
        child = self.getFirstChild()
        while child:
            children.append(child)
            child = child.getNextSibling()
        return children

    def __repr__(self):
        # leaf operators are shown as "NC" instead of listing their children
        if self.getType() not in constants.LEAFNODEOPERATORS:
            text = repr(self.nodes)
        else:
            text = "NC"
        return "<" + self.__class__.__name__ + " " + self.getText() + " " + text + ">"

    def getDebugString(self):
        """Return the node text, class name and layout geometry as one line for debugging."""
        s = self.getText()
        s += " " + self.__class__.__name__
        s += " x:%i y:%i w:%i h:%i oc:%i" % (self.x, self.y, self.w, self.h, self.operatorcentreline)
        return s

    def removeNode(self, node):
        """Remove node from this node's child list (ValueError if it is not a child)."""
        self.nodes.remove(node)

    def replaceNode(self, node, newNode):
        """Replace child node with newNode in place (ValueError if node is not a child)."""
        index = self.nodes.index(node)
        self.nodes[index] = newNode

    def getAllVariables(self):
        """Return the distinct variable names mentioned in this subtree, in first-seen order.

        VAR nodes contribute their ``plainVar``, except for %PI, %E and %I.
        """
        if self.getType() == tokenValues.VAR:
            if self.plainVar not in ('%PI', '%E', '%I'):
                return [self.plainVar]
            return []
        if self.getType() == tokenValues.FUNCTION:
            # skip the VAR node (nodes[0], presumably the function name); scan nodes[1]'s children
            nodes = self.nodes[1].nodes
        elif self.getType() == tokenValues.LABEL:
            # skip nodes[0] (presumably the label itself)
            nodes = self.nodes[1:]
        else:
            nodes = self.nodes
        found = []
        for node in nodes:
            for name in node.getAllVariables():
                if name not in found:
                    found.append(name)
        return found

    def getString(self):
        """Return a bracketed prefix string for this subtree (debugging aid).

        Leaf operators give their own text.  Other nodes give
        '[' + text + ' ' + child strings joined by ', ' + ']', e.g. "[+ a, b]".
        """
        try:
            if self.getType() in constants.LEAFNODEOPERATORS:
                text = self.getText()
            else:
                # separators go only between children, never before the first one
                childText = ', '.join([node.getString() for node in self.nodes])
                text = '[' + self.getText() + ' ' + childText + ']'
            return text
        except Exception:
            l.exception("getting string for " + repr(self))
            raise

    def getPowerDepth(self):
        """ calculates what level of power this is so we can later work out the correct font reduction """
        # 1 for this node, plus one more for each POWER found by following nodes[0]
        if self.nodes[0].getType() != tokenValues.POWER:
            return 1
        # add on power depth of child
        return 1 + self.nodes[0].getPowerDepth()

    def _isSimpleTerm(self, node):
        """True if node is a term where the leftmost/rightmost term searches stop.

        Such terms are a number, variable, bracket, square bracket, sqrt or
        one of constants.FUNCTIONS.
        """
        return node.getType() in [tokenValues.NUM, tokenValues.VAR, tokenValues.BRACKET,
                                  tokenValues.SQBRACKET, tokenValues.SQRT] + constants.FUNCTIONS

    def getRightmostTerm(self):
        """ gets the term furthest on the right (on this level, excluding powers) for this term and all its children
        - could be a bracket, number, variable etc """
        # for a POWER follow the base (nodes[0]) so the exponent is excluded
        if len(self.nodes) == 1 or self.getType() == tokenValues.POWER:
            rightmostTerm = self.nodes[0]
        else:
            rightmostTerm = self.nodes[1]
        if self._isSimpleTerm(rightmostTerm):
            return rightmostTerm
        return rightmostTerm.getRightmostTerm()

    def getLeftmostTerm(self):
        """ gets the term furthest on the left (on this level, excluding powers) for this term and all its children
        - could be a bracket, number, variable etc """
        if self._isSimpleTerm(self.nodes[0]):
            return self.nodes[0]
        return self.nodes[0].getLeftmostTerm()

    def containsDivides(self):
        """True if this node or any descendant is a DIVIDE, DIFFERENTIATE, INTEGRATE or MATRIX node."""
        import nodesubclasses
        if isinstance(self, (nodesubclasses.DIVIDENode, nodesubclasses.DIFFERENTIATENode,
                             nodesubclasses.INTEGRATENode, nodesubclasses.MATRIXNode)):
            return True
        for node in self.nodes:
            if node.containsDivides():
                return True
        return False

    def setFont(self, name, size):
        """Set the font family and point size on this node and, recursively, on all its children.

        If the node already has a font object (created by setTextSize), that
        object is updated as well.
        """
        try:
            self.fontName = name
            self.fontSize = size
            if self.font:
                self.font.setFamily(name)
                self.font.setPointSize(size)
            # recurse through child nodes
            for node in self.nodes:
                node.setFont(name, size)
        except Exception:
            l.exception("setting font size for " + repr(self))
            raise

    def getHeightHorizontalTerms(self, nodes):
        """Measure a row of terms placed side by side with aligned operator centre lines.

        Returns (height, operatorcentreline).  operatorcentreline is the
        largest extent of any term below the shared centre line.  height adds
        to it the largest extent above the line.  Both are 0 for an empty row.
        """
        extendsBelowCentre = 0
        extendsAboveCentre = 0
        for node in nodes:
            if node.operatorcentreline > extendsBelowCentre:
                extendsBelowCentre = node.operatorcentreline
            if (node.h - node.operatorcentreline) > extendsAboveCentre:
                extendsAboveCentre = (node.h - node.operatorcentreline)
        height = extendsBelowCentre + extendsAboveCentre
        operatorcentreline = extendsBelowCentre
        return height, operatorcentreline

    def getTextCentreLine(self, font, text):
        """Operator centre line for text in font: plus-bar height plus the text's descender."""
        return self.getPlusBarHeight(font) + self.getTextDescender(font, text)

    def getPlusBarHeight(self, font):
        """Vertical middle of the '+' character's bounding rect in font: y() + height()/2."""
        metrics = fontfactory.getFontMetrics(font)
        plusSize = metrics.charBoundingRect("+")
        return plusSize.y() + plusSize.height()/2

    def getTextDescender(self, font, text):
        """Negated smallest y() among the characters' bounding rects (0 if none is negative)."""
        metrics = fontfactory.getFontMetrics(font)
        descender = 0
        for char in text:
            yOffset = metrics.charBoundingRect(char).y()
            if yOffset < descender:
                descender = yOffset
        return -descender

    def drawNode(self, painter):
        """ draws the node on the painter. You must call layoutNode before this"""
        # PIX and ordinary nodes are both handled by drawNodeText
        self.drawNodeText(painter)

    def layoutNode(self, x, y, breakwidth, fontName, fontSize):
        """Set the sizes and positions of everything before drawNode is called.

        Non-PIX nodes get the font applied, are sized and are then line-broken
        to breakwidth.  Returns the list of resulting nodes, each positioned at
        (x, y).  A PIX node returns [self].
        """
        try:
            if self.getType() == 'PIX':
                # NOTE(review): y is hard-coded to 5 here instead of using y -- confirm intent
                self.setTextPosition(x, 5)
                # callers always receive a list of nodes
                newNodes = [self]
            else:
                self.setFont(fontName, fontSize)
                self.setTextSize()
                newNodes = self.lineBreakAndSize(breakwidth)
                for node in newNodes:
                    # size each resulting node itself before positioning it
                    node.setTextSize()
                    node.setTextPosition(x, y)
            return newNodes
        except Exception:
            l.exception("laying out node of class " + repr(self.__class__.__name__))
            raise

    def lineBreakAndSize(self, breakwidth):
        """ calculates the size of the node including line breaking
        but doesn't set the final text position so that must be called after this function"""
        self.setTextPosition(0, 0)
        # the operator width and a third of a space are subtracted from the available width,
        # presumably to leave room for the operator that ends a broken line
        return self.doLineBreaks(breakwidth - self.operatorsize.width() - self.thirdspacewidth)

    def drawString(self, painter, text, X, Y):
        """Draw text at (X, Y + plus-bar height of self.font).

        The offset presumably maps the operator centre line Y to the text baseline.
        """
        painter.drawText(X, Y + self.getPlusBarHeight(self.font), text)

    def subClassDrawNodeText(self, painter, metrics):
        """Hook for subclasses to draw their own part; called last by drawNodeText. Does nothing here."""
        pass

    def setTextSize(self):
        """ The algorithm:(recursive)
            all terms are deemed to have an 'operatorcentreline'. This is the Y coordinate of the horizontal axis of 'symmetry',
            e.g. the bar of the plus in '1+2' or the divider line in a fraction. Terms are then placed according to their
            height and operatorcentreline so that the operatorcentrelines line up from term to term where appropriate.
            One issue at the moment is how to calculate the distance of the bar of a '+' to the base line, as QFontMetrics does
            not currently seem to provide the ability to discover this.

            The base version builds this node's font and metrics, caches a few
            measurements (operator text size, X height, space and bracket widths)
            and then recurses into the children."""

        self.colour = constants.black

        if self.getType() == 'PIX':
            # size already set
            return

        font = fontfactory.makeFont()
        font.setFamily(self.fontName)
        font.setPointSize(self.fontSize)
        font.setItalic(self.italic)
        self.font = font

        self.metrics = fontfactory.getFontMetrics(font)

        # operatorsize and thirdspacewidth are used by lineBreakAndSize;
        # the other measurements are presumably used by subclasses
        self.operatorsize = self.metrics.boundingRect(self.getText())

        self.Xheight = self.metrics.boundingRect('X').height()
        self.trueXheight = self.metrics.charBoundingRect('X').height()
        self.spacewidth = self.Xheight/6.0
        self.thirdspacewidth = self.spacewidth / 3.0
        self.bracketwidth = self.metrics.boundingRect('(').width()

        for node in self.nodes:
            try:
                node.setTextSize()
            except Exception:
                l.exception("setting text size for " + repr(node) + " of class " + self.__class__.__name__)
                raise

    def setTextPosition(self, X, Y):
        """ Set the text position of all child terms relative to X and Y.
        IMPORTANT: X is the left hand edge,Y is the operatorcentreline (i.e. not the top left hand corner)  """
        # the base version only records the position; subclasses presumably also place their children
        self.x = X
        self.y = Y

    def drawNodeText(self, painter):
        """Draw this node on painter.

        The steps are, in order:
        - children without dontDraw, each between painter.save() and restore();
        - the debug rectangle, if config["drawdebugrect"] is set;
        - the subclass part, via subClassDrawNodeText.

        EMPTY nodes draw nothing.  PIX nodes draw the pixmap held in
        self.nodes[1] at (self.x, self.y).
        """
        import nodesubclasses

        if isinstance(self, nodesubclasses.EMPTYNode):
            return

        if self.getType() == 'PIX':
            # draw the final pic at the position recorded by setTextPosition
            painter.drawPixmap(self.x, self.y, self.nodes[1])
            return

        # redo font settings as they can be changed
        self.font.setFamily(self.fontName)
        self.font.setPointSize(self.fontSize)
        self.font.setItalic(self.italic)
        painter.setFont(self.font)
        painter.setPenColour(self.colour)

        for node in self.nodes:
            if not node.dontDraw:
                painter.save()
                node.drawNodeText(painter)
                painter.restore()

        # debugging
        if config["drawdebugrect"]:
            painter.drawRect(self.x, self.y - self.h + self.operatorcentreline, self.w, self.h)

        self.subClassDrawNodeText(painter, self.metrics)

    ######################### LINE BREAKING CODE #######################################

    def doLineBreak(self, lineWidth):
        """ If the width of the current node is greater than the line width, then cause this node or one of its
        descendents to vertically lay out the child terms so as to fit within the specified width.
        e.g.
        A + B + C           where A is wide
        -->
        A +
        B + C

        it would also be nice to have

        (A + B + C)/D
        -->
        A/D +
        (B+C)/D

        and even

        sqrt(A*B*C)
        -->
        sqrt(A)*
        sqrt(B*C)

        Returns (flag, node).  flag is BREAK_NOT_NEEDED (the node ends before
        lineWidth), BREAK_DONE (node is the separable operator to break at) or
        BREAK_NOT_POSSIBLE.  The search only descends through nodes[0], so a
        returned break node is either self or the first child of its parent.
        """

        if (self.x + self.w) < lineWidth:
            return BREAK_NOT_NEEDED, self

        if self.getType() in constants.SEPERABLE_OPERATORS:
            if (self.nodes[0].x + self.nodes[0].w) > lineWidth:
                # break down terms[0] first (and recursively) if possible
                flag, node = self.nodes[0].doLineBreak(lineWidth)
                if flag == BREAK_NOT_POSSIBLE:
                    # we can't break the child so break this node
                    return BREAK_DONE, self
                return flag, node
            return BREAK_DONE, self

        return BREAK_NOT_POSSIBLE, self

    def doLineBreakIntoNodes(self, width):
        """Perform at most one line break on the tree rooted at this node.

        Returns (ret, nodes), where ret is the doLineBreak result code.

        When ret is BREAK_DONE, nodes is [firstLine, rest]:
        - firstLine is a new PLUSMINUSNode carrying the break operator's text,
          with the break node's left operand and an EMPTYNode as children.
        - rest is the remaining tree, in which the break node has been
          replaced by its right operand.

        Otherwise nodes is [self].  The tree under this node is modified in place.
        """
        ret, breakNode = self.doLineBreak(width)
        if ret != BREAK_DONE:
            # can't break node or not needed
            return ret, [self]

        import nodesubclasses
        newNode = breakNode.nodes[0]
        sisterNode = breakNode.nodes[1]

        # create new node with the operator and nodes[0] = left operand, nodes[1] = emptyNode
        newNodeParent = nodesubclasses.PLUSMINUSNode()
        emptyNode = nodesubclasses.EMPTYNode()
        # copy font etc
        breakNode.clone(emptyNode)
        breakNode.clone(newNodeParent)
        emptyNode.nodes = []
        emptyNode.w = 0
        emptyNode.h = 0
        emptyNode.setText("EMPTY")
        newNodeParent.setText(breakNode.getText())
        newNodeParent.nodes = []
        newNodeParent.parent = None
        newNodeParent.w = 0
        newNodeParent.h = 0
        newNode.parent = newNodeParent
        emptyNode.parent = newNodeParent
        newNodeParent.nodes.append(newNode)
        newNodeParent.nodes.append(emptyNode)
        newNodeParent.setTextSize()

        # splice the right operand into the place the break node occupied,
        # keeping parent links correct so later break passes modify the live tree
        if breakNode is self or breakNode.parent is None:
            # breaking at this node itself: its right operand becomes the rest
            sisterNode.parent = self.parent
            rest = sisterNode
        else:
            # doLineBreak only descends via nodes[0], so breakNode is its parent's first child
            breakNode.parent.nodes[0] = sisterNode
            sisterNode.parent = breakNode.parent
            rest = self
        return ret, [newNodeParent, rest]

    def doLineBreaks(self, width):
        """Keep line-breaking this node until no further break is made.

        Each pass splits off a first line and then re-measures and retries the
        remainder.  Returns the resulting nodes in order, first line first.
        """
        ret, nodes = self.doLineBreakIntoNodes(width)
        if ret != BREAK_DONE:
            return nodes
        firstLine, rest = nodes
        # the remainder has changed shape, so re-measure it before trying to break it again
        rest.setTextSize()
        rest.setTextPosition(0, 0)
        return [firstLine] + rest.doLineBreaks(width)

######################## END LINE BREAKING CODE ###################################

        

