from lxml import etree
import datetime as dt
import numpy as np
import matplotlib.pyplot as plt


 
class SaleDataParser(object):
    """Aggregate revenue and sale counts from a sales XML document.

    Every ``<Sale>`` element is expected to carry a ``<Payable>`` child (the
    sale amount) and a ``<Purchase_Date>`` child whose text starts with a
    ``YYYY-MM-DD`` date, optionally followed by ``T`` and further data.

    The "overall" reports walk the sales in document order and open a new
    reporting period whenever the (month, year) of a sale differs from the
    previous one, so they assume the sales are stored chronologically.
    """

    def __init__(self, filename):
        """Parse the sales XML at *filename* and set the PMS start date."""
        self.saledatafile = filename
        self.xml = etree.parse(filename)
        self.pms_start_date = dt.datetime.strptime("01-05-2011", "%d-%m-%Y")

    def _iter_sales(self):
        """Yield ``(purchase_date, payable_text)`` per sale, in document order."""
        # ".//Sale" is the explicit spelling of the old "//Sale" search; the
        # leading-slash form on a tree is rewritten to this with a FutureWarning.
        for sale in self.xml.findall(".//Sale"):
            timestamp = sale.findtext("Purchase_Date")
            day = timestamp.split('T')[0]  # keep only the date part
            yield dt.datetime.strptime(day, "%Y-%m-%d"), sale.findtext("Payable")

    @staticmethod
    def _emit_period(out, console_format, period, value):
        """Print one period line and append the ``month:year value`` row to *out*."""
        row = period + (value,)
        print(console_format % row)
        out.write("%d:%d %d\n" % row)

    def _write_monthly_report(self, filename, console_format, amount_of):
        """Write one line per reporting period, starting at the PMS start month.

        filename       -- output file, overwritten on every call
        console_format -- %-format taking (month, year, value) for stdout
        amount_of      -- callable turning a sale's Payable text into the
                          amount that sale adds to its period
        """
        period = (self.pms_start_date.month, self.pms_start_date.year)
        running = 0
        with open(filename, 'w') as out:
            for date, payable in self._iter_sales():
                current = (date.month, date.year)
                # Compare month AND year so the same month of two different
                # years is never merged into one period.
                if current != period:
                    self._emit_period(out, console_format, period, running)
                    period = current
                    # The sale that opens a period belongs to that period.
                    running = amount_of(payable)
                else:
                    running += amount_of(payable)
            # Flush the period still being accumulated. With no sales at all
            # this reports the start month with a value of 0.
            self._emit_period(out, console_format, period, running)

    def get_monthly_sales(self, month, year):
        """Return the summed ``Payable`` of all sales in the given month/year.

        Returns 0 when no sale falls in that month.
        """
        total_sales_for_month = 0
        for date, payable in self._iter_sales():
            if date.month == month and date.year == year:
                total_sales_for_month += float(payable)
        return total_sales_for_month

    def get_overall_monthly_sales(self, filename="sales.dat"):
        """Print and write revenue per month since the PMS start date.

        Each line written to *filename* is ``month:year revenue`` with the
        revenue truncated to an integer by the ``%d`` format.
        """
        self._write_monthly_report(filename, "%d:%d %d", float)

    def get_overall_sales_freq(self, filename="sales_freq.dat"):
        """Print and write the number of sales per month since the PMS start date.

        Each line written to *filename* is ``month:year count``.
        """
        self._write_monthly_report(
            filename, "%d:%d Sales freq: %d Sales ", lambda payable: 1)

    def get_monthly_sales_freq(self, month, year):
        """Return how many sales were made in the given month/year."""
        return sum(1 for date, _ in self._iter_sales()
                   if date.month == month and date.year == year)

class ProductData:
    """Plain record holding one product's name and its two counters.

    name -- product name
    qty  -- quantity value for the product
    sold -- sold counter for the product
    """

    def __init__(self, name="", qty=0, sold=0):
        self.name, self.qty, self.sold = name, qty, sold

   
class ProductDataParser:
    """Cross-reference a product XML document with a sales XML document.

    The product document supplies ``<Item Name="...">`` elements with a
    ``<Quantity>`` child; the sales document supplies ``<Sale>`` elements
    containing ``<BearingOrder>`` elements whose ``<Item Name="...">``
    descendants name the products sold.
    """

    def __init__(self, prodDataXML, saleDataXML):
        """Parse both XML files and build ``ProductDataList`` from the products.

        prodDataXML -- path of the product XML file
        saleDataXML -- path of the sales XML file

        Every product starts with a sold count of 0; call
        ``update_product_sales`` to tally the sales.
        """
        self.ProductDataList = []
        self.productdatafile = prodDataXML
        self.saledatafile = saleDataXML
        self.prodXML = etree.parse(prodDataXML)
        self.saleXML = etree.parse(saleDataXML)
        self.pms_start_date = dt.datetime.strptime("01-05-2011", "%d-%m-%Y")
        # ".//Item" is the explicit spelling of the old "//Item" search, which
        # a tree rewrites to this form while emitting a FutureWarning.
        for product in self.prodXML.findall(".//Item"):
            quantity = product.findtext("Quantity")
            self.ProductDataList.append(
                ProductData(product.attrib["Name"], int(quantity), 0))

    def get_all_products_qty(self):
        """Print the name and quantity text of every product in the product XML."""
        for product in self.prodXML.findall(".//Item"):
            quantity = product.findtext("Quantity")
            print("NAME: %s qty: %s" % (product.attrib["Name"], quantity))

    def update_product_sales(self):
        """Add one to ``sold`` for every sold item whose name matches a product.

        Counts are added to the existing ``sold`` values, so calling this
        twice on the same data doubles them. Items without a ``Name``
        attribute, or naming an unknown product, are ignored.
        """
        # Index products by name once instead of scanning the whole list for
        # every sold item; a list per name keeps duplicate names all counted.
        products_by_name = {}
        for product in self.ProductDataList:
            products_by_name.setdefault(product.name, []).append(product)

        for sale in self.saleXML.findall(".//Sale"):
            for order in sale.iter("BearingOrder"):
                for item in order.iter("Item"):
                    for product in products_by_name.get(item.get("Name"), ()):
                        product.sold += 1

    def print_product_sales(self):
        """Print the sold count and quantity of every product."""
        for product in self.ProductDataList:
            print("%s      Qty sold: %d  Stock: %d" % (product.name, product.sold, product.qty))

    def GraphData(self):
        """Plot the current product list with ``ProductSales2D``."""
        Graph = ProductSales2D()
        Graph.plotGraph(self.ProductDataList)
            
class ProductSales2D:
    """Bar chart of the sold count per product, drawn with matplotlib."""

    def __init__(self):
        """Create a new figure with a single subplot to draw on."""
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)

    def plotGraph(self, ProductDataList):
        """Draw one green bar per product for its ``sold`` count and show it.

        ProductDataList -- sequence of objects exposing ``name`` and ``sold``

        Calls ``plt.show()`` after drawing.
        """
        width = 6
        ind = np.arange(len(ProductDataList))
        product_names = [product.name for product in ProductDataList]
        sold_counts = [product.sold for product in ProductDataList]

        self.ax.set_ylabel('Count')

        # Fix the tick positions before assigning labels: set_xticklabels only
        # labels the ticks that exist at the time of the call.
        # NOTE(review): ticks sit at x + 6 relative to the bars' x coordinate;
        # confirm this offset is the intended label alignment.
        self.ax.set_xticks((ind * 2 * width) + (width + 6))
        self.ax.set_xticklabels(product_names)
        # Used here only to rotate the (non-date) tick labels to vertical.
        self.fig.autofmt_xdate(rotation=90)

        self.ax.bar((ind * 2 * width) + width, sold_counts, width, color='g')
        self.ax.yaxis.grid(color='gray')
        # Grid this object's own axes rather than whichever axes is current.
        self.ax.grid(True, which="both")
        plt.show()
               
             
            
            

SaleData = SaleDataParser("SaleData.xml")

# Spot-check a single month: revenue and number of sales for June 2011.
# print(...) with one argument behaves the same on Python 2 and Python 3.
print(SaleData.get_monthly_sales(6, 2011))
print(SaleData.get_monthly_sales_freq(6, 2011))

# Month-by-month reports; each is run once (a second identical call would
# only rewrite the same output file and repeat the console output).
SaleData.get_overall_monthly_sales()
SaleData.get_overall_sales_freq()

# Bound to its own name so the ProductData class is not replaced by an instance.
product_parser = ProductDataParser("ProductData.xml", "SaleData.xml")
product_parser.update_product_sales()
#product_parser.print_product_sales()
product_parser.GraphData()
 
#inFile = "SaleData.xml" 
#xmlData = etree.parse(inFile) #etree.parse() opens and parses the data
#print xmlData 
# find all "Sale" records in the document, regardless of level
#all_sales = xmlData.findall("//Sale")
 
#for sale in all_sales:
    #print sale
#    salePrice= sale.findtext("Payable")
    
#    all_items = sale.findall("BearingOrder/Item")
#    print all_items
#    print "--------------------"
    #findall("BearingOrder/Item")
    #for saleItem in all_items
    #print salePrice
    #first, get the data in the "tags" of the record
 #   msgDate = msg.attrib['Date']
  #  msgTime = msg.attrib['Time']
 
    # then, find the from and to users, and get the "FriendlyName" attribute
    # notice how we drill down though the "From"/"To" level to the "User" record
   # msgFrom = msg.find("From/User").attrib['FriendlyName']
#    msgTo   = msg.find("To/User").attrib['FriendlyName']
 
    # then find the "Text" record, and get the content between the tags
 #   msgText = msg.findtext("Text")
  #  print  '%s, %s: %s to %s: %s\n'%(msgDate, msgTime, msgFrom, msgTo, msgText)
