import numpy as np
import math
from scipy.sparse import lil_matrix
import scipy.sparse.linalg as sp
import scipy.sparse as sparse


def get_lines(dss):
    """Split the circuit's Line elements into ordinary lines and switches.

    Iterates ``dss.Lines`` and records each element's name, bus1, bus2 and the
    text returned by querying its ``Switch`` property (``switch_flag``).
    Elements whose flag reads as false are treated as ordinary lines and also
    get ``wires``, ``length``, ``phases``, ``spacing``, ``linecode`` and
    ``normAmp`` (the ``normamps`` of their line code). All other elements are
    collected as switches with only the first four fields.

    Parameters
    ----------
    dss : object exposing the OpenDSS ``Lines`` interface and ``run_command``
        (presumably the ``opendssdirect`` module -- confirm with caller).

    Returns
    -------
    list
        ``[Line, Switch]``: two lists of per-element dicts.
    """
    Line = []
    Switch = []
    lines_iface = dss.Lines  # loop-invariant interface object
    more = lines_iface.First()
    while more:
        datum = {}
        datum["name"] = lines_iface.Name()
        datum["bus1"] = lines_iface.Bus1()
        datum["bus2"] = lines_iface.Bus2()
        datum["switch_flag"] = dss.run_command('? Line.' + datum["name"] + '.Switch')
        # The flag comes back as text. Compare case- and whitespace-insensitively
        # so that 'False', 'false' or padded replies all mean "not a switch".
        if str(datum["switch_flag"]).strip().lower() == 'false':
            datum["wires"] = dss.run_command('? Line.' + datum["name"] + '.Wires')
            datum["length"] = lines_iface.Length()
            datum["phases"] = lines_iface.Phases()
            datum["spacing"] = lines_iface.Spacing()
            datum["linecode"] = lines_iface.LineCode()
            datum["normAmp"] = dss.run_command('? Linecode.' + datum["linecode"] + '.normamps')
            Line.append(datum)
        else:
            Switch.append(datum)
        more = lines_iface.Next()
    return [Line, Switch]

def get_transformer(dss,circuit):
    """Return the raw property values of every transformer in the circuit.

    Makes 'Transformer' the active class. For each element it stores the
    element name plus the replies from ``dss.run_command('? <name>.<prop>')``
    for buses, conns, kVs, kVAs, phases, %loadloss, %noloadloss, %Rs and xhl.

    Returns
    -------
    list of dict
        One dict per transformer with keys name, buses, conns, kVs, kVAs,
        phase, loadloss, noloadloss, Rs, xhl (values are whatever
        ``run_command`` returns, presumably text).
    """
    # (result key, OpenDSS property) pairs, queried in exactly this order.
    queried = (
        ("buses", "buses"),
        ("conns", "conns"),
        ("kVs", "kVs"),
        ("kVAs", "kVAs"),
        ("phase", "phases"),
        ("loadloss", "%loadloss"),
        ("noloadloss", "%noloadloss"),
        ("Rs", "%Rs"),
        ("xhl", "xhl"),
    )
    records = []
    circuit.SetActiveClass('Transformer')
    more = dss.ActiveClass.First()
    while more:
        xfmr_name = dss.CktElement.Name()
        record = {"name": xfmr_name}
        for key, prop in queried:
            record[key] = dss.run_command('? ' + xfmr_name + '.' + prop)
        records.append(record)
        more = dss.ActiveClass.Next()
    return records

def get_loads(dss, circuit):
    """Collect per-load data and the total kW of all loads.

    For each load it records name, kV, kW, PF and delta connection. It derives
    kVar as ``kW / PF * sqrt(1 - PF**2)``. It splits the first bus name into
    ``bus1`` and node list ``phases``; when no nodes are given, it assumes
    three phases '1', '2', '3'. It also records the first two entries of
    ``VoltagesMagAng()`` and of ``Powers()`` for the element.

    Returns
    -------
    list
        ``[data, total_load]``: list of per-load dicts and the sum of their kW.
    """
    records = []
    total_kw = 0
    has_load = dss.Loads.First()
    while has_load:
        loads = dss.Loads
        record = {
            "name": loads.Name(),
            "kV": loads.kV(),
            "kW": loads.kW(),
            "PF": loads.PF(),
            "Delta_conn": loads.IsDelta(),
        }
        circuit.SetActiveElement("Load.%s" % record["name"])
        element = dss.CktElement
        bus_parts = element.BusNames()[0].split(".")
        pf = float(record["PF"])
        record["kVar"] = float(record["kW"]) / pf * math.sqrt(1 - pf * pf)
        record["bus1"] = bus_parts[0]
        node_ids = bus_parts[1:]
        if node_ids:
            record["numPhases"], record["phases"] = len(node_ids), node_ids
        else:
            # No explicit nodes on the bus name: treat as a three-phase load.
            record["numPhases"], record["phases"] = 3, ['1', '2', '3']
        mag_ang = element.VoltagesMagAng()
        record["voltageMag"] = mag_ang[0]
        record["voltageAng"] = mag_ang[1]
        record["power"] = element.Powers()[0:2]

        records.append(record)
        has_load = dss.Loads.Next()
        total_kw += record["kW"]
    return [records, total_kw]


def get_pvSystems(dss):
    """Collect name, bus, ratings and current output of every PV system.

    For each PV system made active by ``dss.PVsystems`` it reads the active
    circuit element's name, phase count and first bus name. It queries the
    Pmpp, pf, kVA and kV properties through ``run_command`` and stores the
    first two entries of ``Powers()``.

    Returns
    -------
    list of dict
        Keys: name, bus, Pmpp, pf, kV, kVA, power, numPhase.
    """
    def _ask(elem_name, prop):
        return dss.run_command('? ' + elem_name + '.' + prop)

    records = []
    more = dss.PVsystems.First()
    while more:
        element = dss.CktElement
        pv_name = element.Name()
        n_phase = element.NumPhases()
        bus_name = element.BusNames()[0]

        # Query order matches the original: Pmpp, pf, kVA, kV.
        pmpp, pf, kva, kv = [_ask(pv_name, p) for p in ('Pmpp', 'pf', 'kVA', 'kV')]

        records.append({
            "name": pv_name,
            "bus": bus_name,
            "Pmpp": pmpp,
            "pf": pf,
            "kV": kv,
            "kVA": kva,
            "power": element.Powers()[0:2],
            "numPhase": n_phase,
        })
        more = dss.PVsystems.Next()
    return records


def get_PVmaxP(dss,circuit,PVsystem):
    """Return the negated first ``Powers()`` entry of each PV system.

    Each ``PV["name"]`` is made the active element in turn. Generation is
    reported as negative by the circuit, so negating it presumably yields the
    PV's current real-power output as a positive number.
    """
    def _output_p(element_name):
        circuit.SetActiveElement(element_name)
        return -float(dss.CktElement.Powers()[0])

    return [_output_p(pv["name"]) for pv in PVsystem]


def _pv_node_names(PV):
    """Return the upper-cased 'BUS.NODE' names that a PV system's phases connect to.

    ``PV["bus"]`` is split on '.'. Explicit node numbers are used in order,
    truncated to ``PV["numPhase"]`` when that key is present, so a trailing
    neutral node such as '.0' is dropped. With no explicit nodes, nodes
    1..numPhase (default 1) are assumed.
    """
    parts = PV["bus"].split('.')
    base, nodes = parts[0], parts[1:]
    n_phase = PV.get("numPhase")
    if nodes:
        if n_phase:
            nodes = nodes[:int(n_phase)]
    else:
        nodes = [str(k) for k in range(1, int(n_phase or 1) + 1)]
    return [(base + '.' + node).upper() for node in nodes]


def get_PQnode(dss, circuit, Load, PVsystem, AllNodeNames):
    """Build per-node complex load, PV and net-injection power vectors.

    Parameters
    ----------
    Load : list of dict
        Load records with 'name', 'bus1' and 'phases' (as built by get_loads).
    PVsystem : list of dict
        PV records with 'name', 'bus' and optionally 'numPhase' (as built by
        get_pvSystems).
    AllNodeNames : list of str
        Upper-case 'BUS.NODE' names that define the vector ordering.

    Returns
    -------
    list
        ``[PQ_load, PQ_PV, PQ_node]``: complex numpy arrays of length
        ``len(AllNodeNames)``, where ``PQ_node = -PQ_load - PQ_PV``.
        Several elements on the same node are summed.

    Raises
    ------
    ValueError
        If a load or PV node name is not in ``AllNodeNames``.
    """
    node_count = len(AllNodeNames)
    # Build the name -> index map once. setdefault keeps the first occurrence,
    # which matches list.index semantics.
    position = {}
    for idx, node_name in enumerate(AllNodeNames):
        position.setdefault(node_name, idx)

    def node_index(name):
        if name not in position:
            raise ValueError('%r is not in AllNodeNames' % (name,))
        return position[name]

    Pload = [0] * node_count
    Qload = [0] * node_count
    for ld in Load:
        # Activate the element once per load; Powers() is read as P, Q pairs
        # per conductor in the order of the load's phases.
        circuit.SetActiveElement('Load.' + ld["name"])
        power = dss.CktElement.Powers()
        for ii, phase in enumerate(ld['phases']):
            index = node_index((ld['bus1'] + '.' + phase).upper())
            # Accumulate: several loads may be connected to the same node.
            Pload[index] += power[2 * ii]
            Qload[index] += power[2 * ii + 1]
    PQ_load = np.array(Pload) + 1j * np.array(Qload)

    Ppv = [0] * node_count
    Qpv = [0] * node_count
    for PV in PVsystem:
        circuit.SetActiveElement(PV["name"])
        power = dss.CktElement.Powers()
        for ii, node_name in enumerate(_pv_node_names(PV)):
            index = node_index(node_name)
            Ppv[index] += power[2 * ii]
            Qpv[index] += power[2 * ii + 1]
    PQ_PV = np.array(Ppv) + 1j * np.array(Qpv)

    PQ_node = - PQ_load - PQ_PV  # power injection
    return [PQ_load, PQ_PV, PQ_node]

def get_subPower_byPhase(dss):
    """Activate the first Line and return entries 0, 2 and 4 of its Powers().

    Assumes the first line is the feeder-head line and that Powers() is laid
    out as [P1, Q1, P2, Q2, ...]. The result would then be the real power on
    its first three conductors. TODO: confirm both assumptions with the caller.
    """
    dss.Lines.First()
    all_powers = dss.CktElement.Powers()
    return all_powers[0:6:2]

def getTotalPowers(circuit, dss, type, names):
    """Sum P and Q over all conductors of every named element of one class.

    Parameters
    ----------
    type : str
        OpenDSS element class prefix, e.g. 'Load' or 'PVsystem'. The name
        shadows the builtin but is kept for call compatibility.
    names : iterable of str
        Element names within that class.

    Returns
    -------
    numpy.ndarray
        ``[P_total, Q_total]``. Each element is activated as ``type.name`` and
        its ``Powers()`` is read as P, Q pairs. For an empty ``names`` this is
        ``[0.0, 0.0]`` (previously the bare int ``0``).
    """
    totals = np.zeros(2)
    for elem_name in names:
        circuit.SetActiveElement(type + '.' + elem_name)
        s = dss.CktElement.Powers()
        n_pairs = len(s) // 2  # ignore a dangling odd entry, as before
        totals[0] += sum(s[0:2 * n_pairs:2])
        totals[1] += sum(s[1:2 * n_pairs:2])
    return totals

def _parse_ysparse_value(text):
    """Parse one admittance value from a Ysparse text line into a complex number.

    Accepts the 'G + jB', 'G + j-B' and 'G - jB' layouts (j before the
    imaginary part) as well as Python's 'G+Bj'. Whitespace is ignored, and
    exponents such as '1e+05' are handled.
    """
    compact = "".join(text.split())
    head, sep, tail = compact.partition("j")
    if sep and tail:
        sign = 1.0
        if head and head[-1] in "+-":
            sign = -1.0 if head[-1] == "-" else 1.0
            head = head[:-1]
        return complex(float(head), sign * float(tail))
    return complex(compact.replace("+-", "-"))


def construct_Ymatrix(Ysparse, slack_no,totalnode_number):
    """Read a sparse Y-matrix text file and partition it around the slack nodes.

    The first 4 lines of ``Ysparse`` are skipped as a header. Every other
    non-blank line is read as ``[row,col] = value``, with 1-based indices.
    Each entry is mirrored to ``[col,row]``, so the file may list only one
    triangle.

    Parameters
    ----------
    Ysparse : str
        Path to the text file.
    slack_no : int
        Number of leading (slack) nodes.
    totalnode_number : int
        Matrix dimension.

    Returns
    -------
    list
        ``[Y00, Y01, Y10, Y11, Y11_sparse, Y11_inv]``: the four ``np.matrix``
        blocks, ``Y11`` as CSR, and the dense inverse of ``Y11`` via sparse LU.
    """
    n = totalnode_number
    Ymatrix = np.matrix(np.zeros((n, n), dtype=complex))
    # 'with' guarantees the file is closed even if a line fails to parse.
    with open(Ysparse, 'r') as fh:
        for line_no, line in enumerate(fh):
            if line_no < 4 or not line.strip():
                continue  # 4-line header, and tolerate blank/trailing lines
            fields = line.split('=')
            position, value = fields[0], fields[1]
            index_parts = position.split(',')
            row_value = int(index_parts[0].replace("[", ""))
            column_value = int(index_parts[1].replace("]", ""))
            y = _parse_ysparse_value(value)
            Ymatrix[row_value - 1, column_value - 1] = y
            Ymatrix[column_value - 1, row_value - 1] = y

    Y00 = Ymatrix[0:slack_no, 0:slack_no]
    Y01 = Ymatrix[0:slack_no, slack_no:]
    Y10 = Ymatrix[slack_no:, 0:slack_no]
    Y11 = Ymatrix[slack_no:, slack_no:]
    Y11_sparse = lil_matrix(Y11).tocsr()
    lu_obj = sp.splu(sparse.csc_matrix(Y11))
    Y11_inv = lu_obj.solve(np.eye(totalnode_number - slack_no))
    return [Y00, Y01, Y10, Y11, Y11_sparse, Y11_inv]

def re_orgnaize_for_volt(V1_temp,AllNodeNames,NewNodeNames):
    """Reorder voltages from ``NewNodeNames`` order into ``AllNodeNames`` order.

    ``V1_temp[k]`` is the value for node ``NewNodeNames[k]``. It is placed at
    that node's position in ``AllNodeNames``; for duplicate names the first
    occurrence is used. The output has ``len(V1_temp)`` entries, and positions
    that receive no value stay ``0j``.

    Raises
    ------
    ValueError
        If a name in ``NewNodeNames`` does not appear in ``AllNodeNames``.
    """
    # Build the lookup once: O(len(All) + len(New)) instead of list.index per node.
    position = {}
    for idx, name in enumerate(AllNodeNames):
        position.setdefault(name, idx)

    V1 = [complex(0, 0)] * len(V1_temp)
    for count, node in enumerate(NewNodeNames):
        if node not in position:
            raise ValueError('%r is not in AllNodeNames' % (node,))
        V1[position[node]] = V1_temp[count]
    return V1


def getCapsPos(dss,capNames):
    """Return one integer state per capacitor, in ``capNames`` order.

    The reply to ``'? capacitor.<name>.states'`` is read as text, and the
    character just before its last character is converted to int. This
    assumes a reply such as '[1]'; only the final step's state is captured.
    TODO: confirm the reply format for multi-step banks.
    """
    positions = []
    for cap in capNames:
        reply = dss.run_command('? capacitor.%s.states' % cap)
        positions.append(int(reply[-2:-1]))
    return positions

def getRegsTap(dss,regNames):
    """Return the tap of the transformer controlled by each regulator, as floats.

    For each name, the regcontrol's ``transformer`` property is queried, and
    then that transformer's ``tap`` property is queried and converted with
    ``float``.
    """
    taps = []
    for reg in regNames:
        controlled = dss.run_command('? regcontrol.%s.transformer' % reg)
        taps.append(float(dss.run_command('? transformer.%s.tap' % controlled)))
    return taps

def result(circuit,dss):
    """Snapshot the solved circuit into a dict of voltages, powers and device states.

    Keys
    ----
    AllVoltage : ``circuit.AllBusMagPu()``
    AllVolt_Yorder : ``circuit.YNodeVArray()`` re-paired into complex numbers
    loss, totalPower : ``circuit.Losses()``, ``circuit.TotalPower()``
    totalLoadPower, totalPVpower : summed [P, Q] via getTotalPowers
        (``[0, 0]`` for PV when no PV systems exist; generation is negative)
    CapState, RegTap : from getCapsPos / getRegsTap, or the string 'nan'
        when there are no capacitors / regulators
    PV_Poutput, PV_Qoutput : numpy arrays of the first two Powers() entries
        of each PV system
    """
    out = {}
    out['AllVoltage'] = circuit.AllBusMagPu()
    flat = circuit.YNodeVArray()
    # The flat array alternates real and imaginary parts.
    out['AllVolt_Yorder'] = [complex(flat[2 * k], flat[2 * k + 1]) for k in range(len(flat) // 2)]
    out['loss'] = circuit.Losses()
    out['totalPower'] = circuit.TotalPower()  # power generated into the circuit
    out['totalLoadPower'] = getTotalPowers(circuit, dss, 'Load', dss.Loads.AllNames())

    pv_names = dss.PVsystems.AllNames()
    out['totalPVpower'] = getTotalPowers(circuit, dss, 'PVsystem', pv_names) if pv_names else [0, 0]

    cap_names = dss.Capacitors.AllNames()
    out['CapState'] = getCapsPos(dss, cap_names) if cap_names else 'nan'

    reg_names = dss.RegControls.AllNames()
    out['RegTap'] = getRegsTap(dss, reg_names) if reg_names else 'nan'

    p_out = np.zeros(len(pv_names))
    q_out = np.zeros(len(pv_names))
    for k, pv in enumerate(pv_names):
        circuit.SetActiveElement('PVsystem.' + pv)
        pq = dss.CktElement.Powers()
        p_out[k], q_out[k] = pq[0], pq[1]
    out['PV_Poutput'] = p_out
    out['PV_Qoutput'] = q_out
    return out