import os
import sys
import xml.etree.ElementTree as xml
import xml.etree.cElementTree as ET
from collections import defaultdict

# Create and Configure the Cluster's XML Template
# for ElasticHPC Clusters

def create_template(cluster_input):
	"""Build the root ``<cluster>`` element of a cluster description.

	cluster_input is a mapping with three entries:
	'name' and 'nodes' are stored as the attributes of the same name, and
	'data' maps each node name to that node's settings. Every (name,
	settings) pair is handed to _vm_template() and the resulting element is
	appended as a child of the cluster.

	Returns the new Element. A missing entry raises KeyError.
	"""
	cluster = xml.Element('cluster')
	cluster.attrib['name'] = cluster_input['name']
	cluster.attrib['nodes'] = cluster_input['nodes']
	# items() exists on both Python 2 and Python 3 dicts; iteritems() is
	# Python 2 only.
	for key, node in cluster_input['data'].items():
		cluster.append(_vm_template(key, node))
	return cluster

def set_vm_template(xml_file, tag1='virtualmachine', tag2='provider',attribute='zone', input_value='us-centeral1-a', text_input='Google Compute Engine', text_output='Microsoft Windows Azure', output='West US', id='3'):
	"""Edit one attribute value and/or one text value inside selected elements.

	The file is parsed and every ``tag1`` element whose ``id`` attribute
	equals ``id`` is processed:

	* if ``attribute`` is not None, each direct ``tag2`` child whose
	  ``attribute`` currently equals ``input_value`` gets it set to ``output``;
	* if ``text_input`` is not None, every element in the subtree (the
	  ``tag1`` element included) whose text equals ``text_input`` gets its
	  text replaced by ``text_output``.

	Elements that lack the ``id`` attribute, and ``tag2`` children that lack
	``attribute``, are skipped instead of raising KeyError.

	Returns the modified ElementTree; nothing is written back to disk here.
	"""
	tree = ET.ElementTree(file=xml_file)
	# edit attribute:
	if attribute is not None:
		for target in tree.iter(tag=tag1):
			if 'id' not in target.attrib or target.attrib['id'] != id:
				continue
			for child in target.findall(tag2):
				if attribute in child.attrib and child.attrib[attribute] == input_value:
					child.set(attribute, output)
	# edit text
	if text_input is not None:
		for target in tree.iter(tag=tag1):
			if 'id' not in target.attrib or target.attrib['id'] != id:
				continue
			for element in target.iter():
				if element.text == text_input:
					element.text = text_output
	return tree


def get_vm_template(xml_file, tag='virtualmachine',id='3'):
	"""Return the dict form of the root's child whose ``id`` equals ``id``.

	The file is parsed and converted with _etree_to_dict(). The direct
	children of the root element are then scanned for an '@id' entry equal
	to ``id``; children without an id are skipped. If several children
	match, the last one is returned. If none match, the dict for the whole
	root element is returned.

	``tag`` is accepted but not used by this function.
	"""
	root = ET.ElementTree(file=xml_file).getroot()
	dictionary = _etree_to_dict(root)
	# _etree_to_dict() stores the list of children under the element's tag.
	for d in dictionary[root.tag]:
		if '@id' in d and d['@id'] == id:
			dictionary = d
	return dictionary

def _etree_to_dict(t):
	"""Recursively convert Element ``t`` into a plain dict.

	The result has:
	* one entry keyed by the element's tag, holding a list with the
	  converted children in document order;
	* one '@<name>' entry per attribute;
	* a 'text' entry with the element's text (None when it has none).

	NOTE: 'text' is assigned last, so for an element whose tag is literally
	'text' it replaces the children list.
	"""
	# Iterating an Element yields its children; getchildren() is deprecated,
	# and a list (not a lazy map object) is needed so callers can index it.
	dictionary = {t.tag: [_etree_to_dict(child) for child in t]}
	dictionary.update(('@' + k, v) for k, v in t.attrib.items())
	dictionary['text'] = t.text
	return dictionary

def del_vm_template(xml_file, tag='virtualmachine', attr='name',vm_name=None):
	"""Remove matching elements from the file's tree and refresh node counts.

	Every ``tag`` element whose ``attr`` attribute equals ``vm_name`` is
	removed from its parent. Elements without the ``attr`` attribute never
	match. Afterwards the ``nodes`` attribute of each ``cluster`` element is
	set to the number of ``tag`` elements still inside it.

	Returns the modified ElementTree; nothing is written back to disk here.
	"""
	tree = ET.ElementTree(file=xml_file)
	# Work on snapshots: removing children while iterating the live tree
	# would skip siblings. Removal goes through the real parent, so the
	# position of the element among its siblings does not matter.
	for parent in list(tree.iter()):
		for child in list(parent):
			if child.tag == tag and attr in child.attrib and child.attrib[attr] == vm_name:
				parent.remove(child)
	for cluster in tree.iter(tag='cluster'):
		# iter() also yields the element itself, which must not be counted.
		remaining = sum(1 for elem in cluster.iter(tag) if elem is not cluster)
		cluster.set('nodes', str(remaining))
	return tree

def _create_provider(cluster_provider):
	"""Return a ``<provider>`` element built from ``cluster_provider``.

	The 'service' and 'zone' entries become attributes of the element and
	the 'data' entry becomes its text.
	"""
	attributes = {
		'service': cluster_provider['service'],
		'zone': cluster_provider['zone'],
	}
	element = xml.Element('provider', attributes)
	element.text = cluster_provider['data']
	return element
	
def _disks_template(cluster_disks, id):
	"""Return a ``<disks>`` element describing the root and ephemeral disks.

	``id`` is stored as the ``id`` attribute. For each of the 'rootdisk' and
	'ephemeral' entries of ``cluster_disks`` (in that order) a child element
	of the same name is added, with the entry's 'size' and 'device' as
	attributes and its 'name' as text.
	"""
	disks = xml.Element('disks', {'id': id})
	for kind in ('rootdisk', 'ephemeral'):
		spec = cluster_disks[kind]
		entry = xml.SubElement(disks, kind)
		entry.set('size', spec['size'])
		entry.set('device', spec['device'])
		entry.text = spec['name']
	return disks

def _ssh_template(cluster_ssh, id):
	"""Return an ``<ssh>`` element built from ``cluster_ssh``.

	``id`` is stored as the ``id`` attribute. The element gets four
	children, in this order: ``configured``, ``port``, ``internalip`` and
	``externalip``, each holding the text of the entry of the same name.
	An 'externalip' of None is written as an empty string.
	"""
	ssh = xml.Element('ssh')
	ssh.attrib['id'] = id
	for field in ('configured', 'port', 'internalip'):
		xml.SubElement(ssh, field).text = cluster_ssh[field]
	external = cluster_ssh['externalip']
	# Identity test: only the None singleton means "no external address".
	xml.SubElement(ssh, 'externalip').text = external if external is not None else ''
	return ssh

def _firewall_port(fwname,cluster_port):
	"""Return a ``<port>`` element for one firewall rule.

	``fwname`` becomes the ``name`` attribute; the 'protocol',
	'internalport' and 'externalport' entries of ``cluster_port`` become
	attributes of the same names.
	"""
	rule = xml.Element('port')
	rule.set('name', fwname)
	for field in ('protocol', 'internalport', 'externalport'):
		rule.set(field, cluster_port[field])
	return rule

def _firewall_template(cluster_firewall,id):
	"""Return a ``<firewall>`` element with one ``<port>`` child per rule.

	``id`` is stored as the ``id`` attribute. ``cluster_firewall`` maps a
	rule name to that rule's settings; each pair is passed to
	_firewall_port() and the result appended as a child.
	"""
	firewall = xml.Element('firewall')
	firewall.attrib['id'] = id
	# items() exists on both Python 2 and Python 3 dicts; iteritems() is
	# Python 2 only.
	for key, port in cluster_firewall.items():
		firewall.append(_firewall_port(key, port))
	return firewall

def _image_template(cluster_image):
	"""Return an ``<image>`` element.

	The 'name', 'distro' and 'os' entries of ``cluster_image`` become
	attributes of the same names.
	"""
	element = xml.Element('image')
	for field in ('name', 'distro', 'os'):
		element.set(field, cluster_image[field])
	return element

def _sge_template(cluster_sge):
	"""Return an ``<sge>`` element.

	The 'type' entry of ``cluster_sge`` becomes the ``type`` attribute and
	the 'installed' entry becomes the element text.
	"""
	element = xml.Element('sge', {'type': cluster_sge['type']})
	element.text = cluster_sge['installed']
	return element

def _nfs_template(cluster_nfs, id):
	"""Return a ``<networkfilesystem>`` element wrapping one ``<nfs>`` child.

	``id`` is stored as the wrapper's ``id`` attribute. On the child, the
	'mount' entry of ``cluster_nfs`` is written as ``mount_point`` while
	'device' and 'fsid' keep their names.
	"""
	wrapper = xml.Element('networkfilesystem', {'id': id})
	export = xml.SubElement(wrapper, 'nfs')
	export.set('mount_point', cluster_nfs['mount'])
	export.set('device', cluster_nfs['device'])
	export.set('fsid', cluster_nfs['fsid'])
	return wrapper

def _gluster_template(cluster_gluster):
	"""Return a ``<gluster>`` element.

	The 'mount', 'volume', 'stripe', 'replicate', 'device' and 'format'
	entries of ``cluster_gluster`` become attributes of the same names; the
	'installed' entry becomes the element text.
	"""
	element = xml.Element('gluster')
	for field in ('mount', 'volume', 'stripe', 'replicate', 'device', 'format'):
		element.set(field, cluster_gluster[field])
	element.text = cluster_gluster['installed']
	return element
	
def _vm_template(key,node):
	"""Return the ``<virtualmachine>`` element for one node.

	``key`` becomes the ``name`` attribute; the 'type', 'master', 'status'
	and 'id' entries of ``node`` become attributes of the same names. The
	children are appended in this order: provider, disks, ssh, image,
	firewall, networkfilesystem, sge, gluster. Each is built by the
	matching helper from the corresponding entry of ``node``; the disks,
	ssh, firewall and nfs helpers also receive the node's id.
	"""
	vm = xml.Element('virtualmachine')
	vm.set('name', key)
	for field in ('type', 'master', 'status', 'id'):
		vm.set(field, node[field])
	vm.extend([
		_create_provider(node['provider']),
		_disks_template(node['disks'], node['id']),
		_ssh_template(node['ssh'], node['id']),
		_image_template(node['image']),
		_firewall_template(node['firewall'], node['id']),
		_nfs_template(node['nfs'], node['id']),
		_sge_template(node['sge']),
		_gluster_template(node['gluster']),
	])
	return vm
	
	
# Demo driver: describe a five-node cluster, write it to
# cluster_elastichpc.xml, then rewrite the provider of the node with id '3'.
nodes = defaultdict(dict)
firewall = defaultdict(dict)
disks = defaultdict(dict)

# The settings below are the same for every node, so they are built once.
# All nodes reference these same dict objects; they are only read from here on.
provider = {
	# NOTE: this spelling matches the default input_value of
	# set_vm_template(); change both together.
	'zone': 'us-centeral1-a',
	'service': ' ',
	'data': 'Google Compute Engine',
}
disks['rootdisk'] = {'name': 'rootdisk', 'device': '/dev/sda', 'size': '10'}
disks['ephemeral'] = {'name': 'disk2', 'device': '/dev/sdb', 'size': '200'}
ssh = {
	'configured': 'True',
	'port': '22',
	'internalip': '10.2.34.33',
	'externalip': '149.123.142.44',
}
image = {
	'distro': 'Ubuntu 14.04 LTS',
	'os': 'Linux',
	'name': 'genomekey-image',
}
firewall['ehpcd'] = {'protocol': 'tcp', 'internalport': '5000', 'externalport': '5000'}
firewall['ssh'] = {'protocol': 'tcp', 'internalport': '22', 'externalport': '22'}
nfs = {
	'device': '/home',
	'mount': '/home',
	'fsid': '1',
}
sge = {
	'type': 'Sun Grid Engine 11',
	'installed': 'False',
}
gluster = {
	'device': '/dev/sdb',
	'format': 'False',
	'mount': '/gluster/WGA',
	'replicate': '1',
	'stripe': '1',
	'volume': 'gv0',
	'installed': 'False',
}

# Only the node name and id differ between nodes.
for i in range(5):
	nodes['elastic%s' % i] = {
		'status': 'running',
		'master': 'elastic0',
		'type': 'n1-standard-8',
		'id': str(i),
		'provider': provider,
		'disks': disks,
		'firewall': firewall,
		'ssh': ssh,
		'image': image,
		'nfs': nfs,
		'sge': sge,
		'gluster': gluster,
	}

cluster_prefix = 'elastichpc'
cluster = {
	'name': cluster_prefix,
	'nodes': str(len(nodes)),
	'data': nodes,
}
cluster_output = create_template(cluster)
cluster_name = 'cluster_elastichpc'
# Pass the file name rather than an already-open text-mode handle, so
# ElementTree opens and closes the file itself.
xml.ElementTree(cluster_output).write("%s.xml"%(cluster_name), encoding='utf-8', xml_declaration=True)


# Usage examples kept for reference.
# NOTE(review): edit_template is not defined in this file as shown; the two
# edit_template lines need updating before they can be re-enabled.
#tree = edit_template('%s.xml'%(cluster_name),tag='ephemeral',attribute='device', input_value='/dev/sdb', output='/dev/sdc',vm='elastic4')
#tree = edit_template('%s.xml'%(cluster_name), tag='externalip', text='149.123.142.44', output='149.123.10.2',vm='elastic4')
#tree.write("%s.xml"%(cluster_name))
#tree = del_vm_template("%s.xml"%(cluster_name), attr='name', vm_name='elastic4')
#for i in range(2):
#	tree = del_vm_template("%s.xml"%(cluster_name), attr='name', vm_name='elastic%s'%(str(i)))
#	tree.write("%s.xml"%(cluster_name))


#################################################
# Apply set_vm_template() with its default arguments and save the result
# over the file written above.
tree = set_vm_template('%s.xml'%(cluster_name))
tree.write("%s.xml"%(cluster_name))

#print dict(get_vm_template("%s.xml"%(cluster_name)))

