import base64
import json
import logging
import os
from time import monotonic, sleep, time
from urllib.parse import urlsplit

from gvm.connections import TLSConnection
from gvm.protocols.gmpv208 import AliveTest, Gmp
from gvm.transforms import EtreeTransform
from gvm.xml import pretty_print

# Address of the gvmd endpoint. Every Gmp instance built in this module shares
# the single TLSConnection object created below.
local_ip = "127.0.0.1"
connection = TLSConnection(hostname=local_ip)
# Response transform handed to every Gmp instance; the helpers below call
# .xpath() on the parsed responses it produces.
transform = EtreeTransform()
# Scan config id, passed as config_id by create_task.
config = {'id':"9866edc1-8869-4e80-acac-d15d5647b4d9"}
# Scanner id, passed as scanner_id by create_task.
scanner = {'id': "08b69003-5fc2-4037-a479-93b440211c73"}
# Credential id attached to every target as ssh_credential_id by create_target.
ovs_ssh_credential = {'id': "b9af5845-8b87-4378-bca4-cee39a894c17"}


def get_version_old(auth_name, auth_passwd):
    """Authenticate with the given credentials and pretty-print get_version.

    Legacy helper: it uses ``Gmp`` as a context manager directly instead of
    going through ``create_connection``.
    """
    with Gmp(connection, transform=transform) as protocol:
        protocol.authenticate(auth_name, auth_passwd)
        version_response = protocol.get_version()
        pretty_print(version_response)

def create_connection(auth_name=None, auth_passwd=None, retries=5):
    """Return an authenticated ``Gmp`` instance, retrying on failure.

    Every other helper in this module calls ``create_connection()`` without
    arguments, so the credentials are optional. An omitted ``auth_name`` or
    ``auth_passwd`` is read from the ``GMP_USERNAME`` / ``GMP_PASSWORD``
    environment variables.

    Args:
        auth_name: gvmd user name (falls back to ``$GMP_USERNAME``).
        auth_passwd: gvmd password (falls back to ``$GMP_PASSWORD``).
        retries: maximum number of connect/authenticate attempts.

    Returns:
        A ``Gmp`` object, built on the module-level ``connection`` and
        ``transform``, on which ``authenticate`` has succeeded.

    Raises:
        ValueError: if credentials are neither passed nor found in the
            environment.
        Exception: if every attempt fails; the last error is chained as the
            cause.
    """
    if auth_name is None:
        auth_name = os.environ.get("GMP_USERNAME")
    if auth_passwd is None:
        auth_passwd = os.environ.get("GMP_PASSWORD")
    if auth_name is None or auth_passwd is None:
        raise ValueError(
            "GMP credentials missing: pass auth_name/auth_passwd or set "
            "GMP_USERNAME/GMP_PASSWORD"
        )

    last_error = None
    for attempt in range(1, retries + 1):
        try:
            gmp = Gmp(connection, transform=transform)
            gmp.authenticate(auth_name, auth_passwd)
            return gmp
        except Exception as err:  # transport and auth failures are both retried
            last_error = err
            logging.warning(
                f"Connection error with the gmp endpoint ({err}). "
                f"Remaining {retries - attempt} retries"
            )
            # No point in waiting after the final attempt.
            if attempt < retries:
                sleep(0.5)
    raise Exception(
        f"Impossible connect to the gmp endpoint even after {retries} retries"
    ) from last_error

def get_version():
    """Return the text of the ``version`` element of gvmd's get_version reply."""
    version_nodes = create_connection().get_version().xpath('version/text()')
    return version_nodes[0]

########## PORT LIST ##################################

def create_port_list(port_list_name, ports):
    gmp = create_connection()
    res = gmp.create_port_list(port_list_name, ','.join(ports))
    status = res.xpath('@status')[0]
    status_text = res.xpath('@status_text')[0]
    if status == "201":
        id = res.xpath('@id')[0]
        logging.debug(f'Created port list obj. Name: {port_list_name}, id: {id}, ports: {ports}')
        return {'name': port_list_name, 'id': id}
    else:
        logging.error(f"ERROR during Port list creation. Status code: {status}, msg: {status_text}")
        msg = f"ERROR during Port list creation. Status code: {status}, msg: {status_text}"
        raise Exception(msg) 

def get_port_lists(filter_str="rows=-1"):
    """Return the port lists matching *filter_str* as a list of dicts.

    Each dict has the keys ``name``, ``id`` and ``in_use`` taken from the
    corresponding ``port_list`` element of the response.
    """
    response = create_connection().get_port_lists(filter_string=filter_str)
    return [
        {
            'name': node.xpath('name/text()')[0],
            'id': node.xpath('@id')[0],
            'in_use': node.xpath('in_use/text()')[0],
        }
        for node in response.xpath('port_list')
    ]

def delete_port_list(port_list):
    """Delete *port_list* (a dict with ``id`` and ``name``) and log the outcome.

    Failures are logged, not raised.
    """
    response = create_connection().delete_port_list(port_list['id'])
    status = response.xpath('@status')[0]
    status_text = response.xpath('@status_text')[0]
    if status != "200":
        logging.error(f"ERROR {status}: {status_text}")
        return
    logging.info(f"Port_list with id: {port_list['id']} and name: {port_list['name']} DELETED")

def get_or_create_port_list(port_list_name, ports):
    """Return the port list whose search by name yields one hit, creating it if none.

    If the name search yields several port lists, all of them are returned
    (a list) after logging a warning.
    """
    matches = get_port_lists(port_list_name)
    if not matches:
        created = create_port_list(port_list_name, ports)
        return get_port_lists(created['id'])[0]
    if len(matches) > 1:
        logging.warning(f"Found {len(matches)} results.")
        return matches
    return matches[0]

############## TARGET  ##################################

def create_target(name, ip, port_list):
    """Create a gvmd target for a single host.

    The target uses the module-level SSH credential and the
    ``'Consider Alive'`` alive test.

    Args:
        name: target name.
        ip: host to scan (sent as a one-element ``hosts`` list).
        port_list: dict with at least an ``'id'`` key.

    Returns:
        ``{'name': name, 'id': <new target id>}``.

    Raises:
        Exception: if gvmd does not answer with status ``"201"``.
    """
    gmp = create_connection()
    res = gmp.create_target(
        name=name,
        comment="",
        hosts=[ip],
        port_list_id=port_list['id'],
        ssh_credential_id=ovs_ssh_credential['id'],
        alive_test=AliveTest('Consider Alive'),
    )
    status = res.xpath('@status')[0]
    status_text = res.xpath('@status_text')[0]
    if status != "201":
        msg = f"ERROR during Target creation. Status code: {status}, msg: {status_text}"
        logging.error(msg)
        raise Exception(msg)
    return {'name': name, 'id': res.xpath('@id')[0]}

def get_targets(filter_str):
    """Return the targets matching *filter_str* as a list of dicts.

    Each dict has the keys ``name``, ``hosts``, ``id`` and ``in_use``.
    """
    response = create_connection().get_targets(filter_string=filter_str)
    return [
        {
            'name': node.xpath('name/text()')[0],
            'hosts': node.xpath('hosts/text()')[0],
            'id': node.xpath('@id')[0],
            'in_use': node.xpath('in_use/text()')[0],
        }
        for node in response.xpath('target')
    ]

def delete_target(target):
    """Delete *target* (a dict with ``id`` and ``name``) and log the outcome.

    Failures are logged, not raised.
    """
    gmp = create_connection()
    res = gmp.delete_target(target['id'])
    status = res.xpath('@status')[0]
    status_text = res.xpath('@status_text')[0]
    if status == "200":
        # The original message said "Port_list" here by copy-paste mistake.
        logging.info(f"Target with id: {target['id']} and name: {target['name']} DELETED")
    else:
        logging.error(f"ERROR {status}: {status_text}")

def get_or_create_target(target_name, ip, port_list):
    """Return the target found by *target_name*, creating it when none exists.

    Returns:
        The single matching target dict, or the newly created one. If the
        search yields several targets, the whole list is returned and a
        warning is logged.
    """
    res = get_targets(target_name)
    if len(res) == 0:
        created = create_target(target_name, ip, port_list)
        return get_targets(created['id'])[0]
    if len(res) == 1:
        return res[0]
    # The previous print() claimed "Return None" but the list was returned.
    logging.warning(f"Found {len(res)} targets matching '{target_name}'. Returning all of them")
    return res

def search_and_delete_target(target_name):
    """Delete the single target found by searching for *target_name*.

    Raises:
        Exception: if the search does not yield exactly one target.
    """
    targets = get_targets(target_name)
    if len(targets) != 1:
        raise Exception(
            f"Expected exactly one target matching '{target_name}', found {len(targets)}"
        )
    # delete_target expects the target dict, not its id string.
    delete_target(targets[0])

def search_and_delete_all_targets(target_name):
    """Delete every target returned by searching for *target_name*."""
    for matching_target in get_targets(target_name):
        delete_target(matching_target)

############## TASK ##################################

def create_task(name, target):
    """Create a gvmd task for *target* using the module-level config and scanner.

    Args:
        name: task name.
        target: dict with at least an ``'id'`` key.

    Returns:
        ``{'name': name, 'id': <new task id>}``.

    Raises:
        Exception: if gvmd does not answer with status ``"201"``.
    """
    gmp = create_connection()
    res = gmp.create_task(
        name=name,
        config_id=config['id'],
        target_id=target['id'],
        scanner_id=scanner['id'],
    )
    status = res.xpath('@status')[0]
    status_text = res.xpath('@status_text')[0]
    if status != "201":
        msg = f"ERROR during Task creation. Status code: {status}, msg: {status_text}"
        logging.error(msg)
        raise Exception(msg)
    return {'name': name, 'id': res.xpath('@id')[0]}

def get_tasks(filter_str):
    """Return the tasks matching *filter_str* as a list of dicts.

    Each dict has the keys ``name``, ``id``, ``progress``, ``in_use``,
    ``status`` and ``target_id``. ``report_id`` is present only when the
    task has a ``last_report``.
    """
    res = []
    gmp = create_connection()
    tasks = gmp.get_tasks(filter_string=filter_str)
    for task in tasks.xpath('task'):
        o = {
            'name': task.xpath('name/text()')[0],
            'id': task.xpath('@id')[0],
            'progress': task.xpath('progress/text()')[0],
            'in_use': task.xpath('in_use/text()')[0],
            'status': task.xpath('status/text()')[0],
            'target_id': task.xpath('target/@id')[0],
        }
        # Tasks that never ran have no last_report; check explicitly instead
        # of swallowing every exception with a bare except.
        report_ids = task.xpath('last_report/report/@id')
        if report_ids:
            o['report_id'] = report_ids[0]
        res.append(o)
    return res

def get_or_create_task(task_name, target):
    """Return the task found by *task_name*, creating it for *target* if none exists.

    Returns:
        The single matching task dict, or the newly created one. If the
        search yields several tasks, the whole list is returned and a
        warning is logged.
    """
    res = get_tasks(task_name)
    if len(res) == 0:
        created = create_task(task_name, target)
        return get_tasks(created['id'])[0]
    if len(res) == 1:
        return res[0]
    # The previous print() claimed "Return None" but the list was returned.
    logging.warning(f"Found {len(res)} tasks matching '{task_name}'. Returning all of them")
    return res

def get_all_tasks():
    """Return every task, in the same dict format as ``get_tasks``.

    ``rows=-1`` asks gvmd for all rows instead of a single page. This used to
    duplicate get_tasks line by line, so it now delegates.
    """
    return get_tasks("rows=-1")

def search_and_delete_all_tasks(filter_str):
    """Delete every task returned by ``get_tasks(filter_str)``."""
    for matching_task in get_tasks(filter_str):
        delete_task(matching_task)

def start_task(task):
    """Start *task* and record the resulting report id in it.

    Args:
        task: dict with at least an ``'id'`` key. It is mutated in place:
            ``task['report_id']`` is set.

    Returns:
        The same *task* dict.

    Raises:
        Exception: if the response carries no ``report_id`` (e.g. the start
            was refused). The message includes gvmd's status and status text.
    """
    gmp = create_connection()
    res = gmp.start_task(task['id'])
    report_ids = res.xpath('report_id/text()')
    if not report_ids:
        status = (res.xpath('@status') or [None])[0]
        status_text = (res.xpath('@status_text') or [None])[0]
        msg = f"ERROR during Task start. Status code: {status}, msg: {status_text}"
        logging.error(msg)
        raise Exception(msg)
    task['report_id'] = report_ids[0]
    return task

def stop_task(task):
    """Ask gvmd to stop *task* and pretty-print the raw response."""
    stop_response = create_connection().stop_task(task['id'])
    pretty_print(stop_response)

def delete_task(task):
    """Delete *task* (a dict with ``id`` and ``name``) and log the outcome.

    Failures are logged, not raised.
    """
    gmp = create_connection()
    res = gmp.delete_task(task['id'])
    status = res.xpath('@status')[0]
    status_text = res.xpath('@status_text')[0]
    if status == "200":
        # The original message said "Target" here by copy-paste mistake.
        logging.info(f"Task with id: {task['id']} and name: {task['name']} DELETED")
    else:
        logging.error(f"ERROR {status}: {status_text}")
        
############## REPORTS ##################################

class report_formats:
    """Report-format ids usable as ``report_format_id`` in gvmd report calls.

    Only ``anonymous_xml`` is used by the helpers in this module; the other
    ids are kept for callers.
    """

    anonymous_xml = "5057e5cc-b825-11e4-9d0e-28d24461215b"
    csv_results = "c1645568-627a-11e3-a660-406186ea4fc5"
    itg = "77bd6c4a-1f62-11e1-abf0-406186ea4fc5"
    pdf = "c402cc3e-b531-11e1-9163-406186ea4fc5"
    txt = "a3810a62-1f62-11e1-9219-406186ea4fc5"
    xml = "a994b278-1f62-11e1-96ac-406186ea4fc5"

def get_report_formats():
    """Print the ``id`` and ``name`` of every report format gvmd returns."""
    response = create_connection().get_report_formats()
    for fmt in response.xpath('report_format'):
        fmt_name = fmt.xpath('name/text()')[0]
        fmt_id = fmt.xpath('@id')[0]
        print(fmt_id, fmt_name)

def get_report_format(id):
    """Pretty-print the gvmd response for the report format with the given id.

    Args:
        id: report format UUID (e.g. one of the ``report_formats`` constants).
            The parameter name shadows the builtin but is kept for callers.
    """
    gmp = create_connection()
    # Previously this ignored ``id`` and fetched every report format.
    res = gmp.get_report_format(id)
    pretty_print(res)

def get_progress(task):
    """Return ``(status, progress)`` for *task*, with progress as an int.

    Lifecycle noted by the original author:
    New -> Requested -> Queued -> Running -> Done, with progress 0 before
    Running, 0..100 while running, and -1 once Done.
    """
    info = get_tasks(task['id'])[0]
    return info['status'], int(info['progress'])

def wait_for_task_ending(task, timeout=3600, poll_interval=10):
    """Poll *task* until it finishes, fails, or *timeout* seconds elapse.

    Args:
        task: dict with at least an ``'id'`` key.
        timeout: maximum number of seconds to wait.
        poll_interval: seconds to sleep between status checks.

    Returns:
        True once the status is ``Done`` with progress ``-1``. False when the
        status leaves the expected set (e.g. Interrupted) or the timeout
        expires.
    """
    # monotonic() is immune to wall-clock adjustments, unlike time().
    start_time = monotonic()
    logging.info("Waiting for scans ends the task")
    while True:
        status, progress = get_progress(task)
        if status not in ("New", "Requested", "Queued", "Running", "Done"):  # e.g. "Interrupted"
            logging.warning(f"Waiting for scans ends the task. Status: {status}")
            return False
        if status == "Done" and progress == -1:
            logging.info(f"Waiting for scans ends the task. Status: {status}")
            return True
        elapsed = monotonic() - start_time
        if elapsed > timeout:
            logging.error("TIMEOUT during waiting for task ending")
            return False
        logging.debug(f"Waiting for the task ends. Now {int(elapsed)}s from start. Status: {status}")
        sleep(poll_interval)
    
def save_report(task, report_format_id, report_filename):
    """Download *task*'s report in the given format and write it to a file.

    The text of the response's ``report`` element is base64-decoded and
    written in binary mode to *report_filename*.
    """
    response = create_connection().get_report(
        task['report_id'],
        report_format_id=report_format_id,
        ignore_pagination=True,
        details="1",
    )
    encoded_body = str(response.xpath('report/text()')[0])
    with open(report_filename, "wb") as out_file:
        out_file.write(base64.b64decode(encoded_body))

def save_severity_report(task, severity_filename):
    """Write the highest per-port threat label of *task*'s report to a file.

    Labels are ranked Log < Low < Medium < High; "Log" is written when the
    report lists no ports. An unknown label raises KeyError.
    """
    rank = {"Log": 0, "Low": 1, "Medium": 2, "High": 3}
    response = create_connection().get_report(
        task['report_id'],
        report_format_id=report_formats.anonymous_xml,
        ignore_pagination=True,
        details="1",
    )
    threat_labels = response.xpath('report/report/ports/port/threat/text()')
    # max() keeps the first of equally ranked labels, i.e. the first occurrence.
    worst_label = max(threat_labels, key=lambda label: rank[label], default="Log")
    with open(severity_filename, "w") as out_file:
        out_file.write(worst_label)

def get_report_info(task):
    report = dict()
    gmp = create_connection()
    res = gmp.get_report(task['report_id'],
                        report_format_id=report_formats.anonymous_xml,
                        ignore_pagination=True,
                        details="1")
    threats = res.xpath('report/report/ports/port/threat/text()')
    ports = res.xpath('report/report/ports/port/text()')
    severities = res.xpath('report/report/ports/port/severity/text()')
    severities = list(map(lambda a : float(a), severities))
    for p,t,s in zip(ports, threats, severities):
        report[p] = {'severity': s, 'threat': t}
    glob_severity = -1 # returned severities are null or positive
    glob_threat = 'Log'
    for threat,severity in zip(threats,severities):
        if severity > glob_severity:
            glob_severity = severity
            glob_threat = threat
            glob_severity = severity

    report['global'] = {'threat': glob_threat, 'severity': glob_severity}
    return report

            
def get_reports(filter_str="rows=-1"):
    """Return ``[{'task_name': ..., 'id': ...}, ...]`` for reports matching *filter_str*."""
    response = create_connection().get_reports(filter_string=filter_str)
    return [
        {
            'task_name': node.xpath('task/name/text()')[0],
            'id': node.xpath('@id')[0],
        }
        for node in response.xpath('report')
    ]

def get_numeric_severity(severity):
    """Map a threat label to a rank: Log=0, Low=1, Medium=2, High=3, else 4."""
    for rank, label in enumerate(("Log", "Low", "Medium", "High")):
        if severity == label:
            return rank
    return 4

def get_severity_from_number(num):
    """Map a rank back to its threat label (inverse of get_numeric_severity).

    0 -> "Log", 1 -> "Low", 2 -> "Medium", 3 -> "High"; any other value
    returns "Undefined". 0 previously mapped to "Low", which disagreed with
    get_numeric_severity("Log") == 0.
    """
    for rank, label in enumerate(("Log", "Low", "Medium", "High")):
        if num == rank:
            return label
    return "Undefined"

def process_global_reports_info(reports):
    """Add deployment-wide summary keys to a ``{host: report_info}`` dict.

    ``reports`` is mutated in place and returned. Two keys are added:
    ``'deployment'`` holds the severity/threat of the first host with the
    highest ``['global']['severity']`` (-1/'Log' if none exceeds -1).
    ``'global'`` is "OK" when that severity is below 4, else "NOK".
    """
    worst_severity, worst_threat = -1, 'Log'
    for host_report in reports.values():
        host_summary = host_report['global']
        if host_summary['severity'] > worst_severity:
            worst_severity = host_summary['severity']
            worst_threat = host_summary['threat']
    reports['deployment'] = {'severity': worst_severity,
                             'threat': worst_threat}
    reports['global'] = "OK" if worst_severity < 4 else "NOK"
    return reports

def pretty_json(j):
    """Serialize *j* as JSON with sorted keys and 4-space indentation."""
    return json.dumps(j, indent=4, sort_keys=True)
        
def import_dep_info(file_path, endpoints_to_scan):
    """Collect host -> ports to scan from a deployment JSON file.

    The file must contain an ``outputs`` object. For each output whose key is
    in *endpoints_to_scan*, its value is parsed as a URL. An explicit port is
    used when present; otherwise 443 is used for https and 80 for http. Port
    22 is always added for every host.

    URLs are parsed with ``urllib.parse.urlsplit``, so paths, query strings,
    userinfo and bracketed IPv6 hosts are handled. The previous
    ``split(":")`` parsing mangled these (e.g. it returned port
    ``"8080/api"``). Note that ``urlsplit`` lower-cases the host name.

    Returns:
        ``{host: sorted list of port strings}`` (lexicographic order).

    Raises:
        Exception: if an endpoint has no scheme/host, or has no port and a
            scheme other than http/https.
        ValueError: if an explicit port is not a valid number.
    """
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)
    default_ports = {"https": "443", "http": "80"}
    endpoints = dict()
    for key, value in data['outputs'].items():
        if key not in endpoints_to_scan:
            continue
        endpoint = str(value)
        parsed = urlsplit(endpoint)
        if not parsed.scheme or not parsed.hostname:
            raise Exception(f"Impossible to parse the endpoint host. Endpoint: {endpoint}")
        host = parsed.hostname
        if parsed.port is not None:
            port = str(parsed.port)
        elif parsed.scheme in default_ports:
            port = default_ports[parsed.scheme]
        else:
            raise Exception(f"Impossible to parse the endpoint port. Endpoint: {endpoint}")
        logging.info(f"Endpoint: {host}:{port}")
        # SSH (22) is always scanned in addition to the service port.
        endpoints.setdefault(host, {"22"}).add(port)
    return {host: sorted(ports) for host, ports in endpoints.items()}