#!/usr/bin/env python3
#
# Copyright (c) Greenplum Inc 2009. All Rights Reserved.
#
# gpconfig guarantees that all segments have valid conf files if it runs successfully, but it does not handle
# sequence races. Specifically, suppose two simultaneous callers of gpconfig, U1 and U2, modify the same guc and
# both succeed. Let U1 and U2 also denote the conf-file lines each caller writes. In the resulting conf file you
# will see either just U1, just U2, U2U1, or U1U2. In other words, the order of the updates is not guaranteed and
# one might be lost. This holds on a per-segment basis: some segments may see U1 while others see U2, U2U1 or
# U1U2. If a user runs gpconfig simultaneously for the same guc with two different values, both runs might
# complete successfully and yet different segments might end up with different guc settings (though each will be
# one of U1, U2, U1U2, U2U1). This is because we start a separate process per segment to rewrite the file, and we
# do not synchronize those processes.

import os
import sys
import re

try:
    from gppylib.gpparseopts import OptParser, OptChecker
    from gppylib.gparray import GpArray
    from gppylib.operations.detect_unreachable_hosts import get_unreachable_segment_hosts
    from gppylib.gplog import *
    from gppylib.commands.unix import *
    from gppylib.commands.gp import *
    from gppylib.db import dbconn
    from gppylib.userinput import *
    from pg import DatabaseError
    from gpconfig_modules.segment_guc import SegmentGuc
    from gpconfig_modules.database_segment_guc import DatabaseSegmentGuc
    from gpconfig_modules.file_segment_guc import FileSegmentGuc
    from gpconfig_modules.guc_collection import GucCollection
    from gppylib.gpresgroup import GpResGroup
    from gpconfig_modules.parse_guc_metadata import ParseGuc
except ImportError as err:
    sys.exit('Cannot import modules.  Please check that you have sourced '
             'greenplum_path.sh.  Detail: ' + str(err))

# File name of this script, without its directory; passed to setup_tool_logging below.
EXECNAME = os.path.basename(__file__)

# GUCs this tool refuses to change (and hides from --list unless --skipvalidation is given).
PROHIBITED_GUCS = {"port", "listen_addresses"}
# GUCs that may not be given a coordinator-specific value.
SAMEVALUE_GUCS = {"gp_default_storage_options"}
# Names of GUCs that may not be written to a conf file; filled in by _is_guc_writeable().
read_only_gucs = set()  # populated at runtime
LOGGER = get_default_logger()
setup_tool_logging(EXECNAME, getLocalHostname(), getUserName())
# Cluster configuration; assigned by _set_gparray() before any action runs.
gp_array = None


def parseargs():
    """Parse and validate the command line.

    Returns the parsed options object with one extra attribute, ``entry``,
    which validate_four_verbs() sets to the GUC named by --change/--remove.
    Validation failures are raised by the validate_* helpers.
    """
    parser = OptParser(option_class=OptChecker)
    # Drop the stock -h so it can be re-registered together with -? and --help.
    parser.remove_option('-h')

    # (flags, keyword arguments) for every option, in registration order.
    option_specs = (
        (('-h', '-?', '--help'), {'action': 'help'}),
        (('--verbose',), {'action': 'store_true'}),
        (('--skipvalidation',), {'action': 'store_true'}),
        (('--masteronly', '--coordinatoronly'), {'dest': "coordinatoronly", 'action': 'store_true'}),
        (('--debug',), {'action': 'store_true'}),
        (('-c', '--change'), {'type': 'string'}),
        (('-r', '--remove'), {'type': 'string'}),
        (('-s', '--show'), {'type': 'string'}),
        (('-v', '--value'), {'type': 'string'}),
        (('-m', '--mastervalue'), {'dest': "coordinatorvalue", 'type': 'string'}),
        (('--coordinatorvalue',), {'dest': "newcoordinatorvalue", 'type': 'string'}),
        (('-l', '--list'), {'action': 'store_true'}),
        (('-P', '--primaryvalue'), {'type': 'string'}),
        (('-M', '--mirrorvalue'), {'type': 'string'}),
        (('-f', '--file'), {'action': 'store_true'}),
        (('--file-compare',), {'dest': 'file_compare', 'action': 'store_true'}),
    )
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)
    parser.setHelp([])

    parsed, _ = parser.parse_args()

    # --coordinatorvalue argument takes precedence over -m/--mastervalue
    if parsed.newcoordinatorvalue is not None:
        parsed.coordinatorvalue = parsed.newcoordinatorvalue

    parsed.entry = None
    validate_four_verbs(parsed)
    validate_mutual_options(parsed)
    return parsed


def validate_four_verbs(options):
    """Make sure one of the four actions (change, remove, list, show) was requested.

    For --change and --remove the GUC name is copied to ``options.entry``;
    for --remove, ``options.remove`` is additionally replaced by ``True``.
    Raises (via log_and_raise) when no action was given at all.
    """
    if options.change:
        options.entry = options.change
        return

    if options.remove:
        # Keep the GUC name in entry and turn remove into a plain flag.
        options.entry = options.remove
        options.remove = True
        return

    if not (options.list or options.show):
        log_and_raise("No action specified.  See the --help info.")


def validate_mutual_options(options):
    """Reject option combinations that contradict each other.

    Also checks that the USER and GPHOME environment variables are usable.
    As a final step, when --value was given without a coordinator value,
    ``options.coordinatorvalue`` defaults to ``options.value``.

    Raises (via log_and_raise) an Exception describing the first problem found.
    """
    user = os.getenv('USER')
    # Unset, empty and whitespace-only values are all unusable.
    if user is None or not user.strip():
        log_and_raise("USER environment variable must be set.")

    gphome = os.getenv('GPHOME')
    if not gphome:
        log_and_raise("GPHOME environment variable must be set.")

    if options.file and not options.show:
        log_and_raise("'--file' option must accompany '--show' option")
    if options.file and options.file_compare and options.show:
        log_and_raise("'--file' option and '--file-compare' option cannot be used together")
    if options.file:
        try:
            get_coordinatordatadir()
        except Exception:
            # Narrowed from a bare except so Ctrl-C / SystemExit are not swallowed.
            log_and_raise("--file option requires that COORDINATOR_DATA_DIRECTORY be set")
    if options.remove and (options.value is not None
                           or options.primaryvalue is not None
                           or options.mirrorvalue is not None
                           or options.coordinatorvalue is not None):
        log_and_raise("remove action does not take a value, primary value, mirror value or coordinator value parameter")
    if options.change and options.remove:
        log_and_raise("Multiple actions specified.  See the --help info.")
    if options.change and (options.value is None
                           and options.mirrorvalue is None
                           and options.primaryvalue is None):
        log_and_raise("change requested but value not specified")
    if options.change and options.coordinatorvalue is not None and options.coordinatoronly:
        log_and_raise("when changing a parameter on the coordinator only specify the --value not --mastervalue")
    if options.change and (options.value is not None
                           and (options.primaryvalue is not None or options.mirrorvalue is not None)):
        log_and_raise("cannot use both value option and primaryvalue/mirrorvalue option")
    if (options.coordinatoronly or options.coordinatorvalue is not None) and options.entry in SAMEVALUE_GUCS:
        log_and_raise("%s value cannot be different on coordinator and segments" % options.entry)
    # Without an explicit coordinator value the coordinator gets the same value as the segments.
    if options.value is not None and options.coordinatorvalue is None:
        options.coordinatorvalue = options.value


class ToolkitQuery:
    """Builds the gp_toolkit.gp_param_setting() query for one GUC name.

    The finished SQL text is available as the ``query`` attribute.
    """

    def __init__(self, name):
        # The name is embedded in a SQL string literal, so any single quote in
        # it is doubled to keep the literal (and the statement) well-formed.
        escaped_name = name.replace("'", "''")
        self.query = "select * from gp_toolkit.gp_param_setting('%s')" % escaped_name


class GucQuery:
    """Builds the pg_settings query, optionally restricted to one GUC name.

    The finished SQL text is available as the ``query`` attribute. With no
    name (or an empty one) the query selects every row of pg_settings.
    """

    def __init__(self, name=None):
        self.query = 'SELECT name, setting, unit, short_desc, context, vartype, min_val, max_val FROM pg_settings'
        if name:
            # The name is embedded in a SQL string literal, so any single quote
            # in it is doubled to keep the literal (and the statement) well-formed.
            escaped_name = name.replace("'", "''")
            self.query = self.query + " where name = '" + escaped_name + "'"


class Guc:
    """One pg_settings row together with gpconfig's value checks for it."""

    def __init__(self, row):
        """Unpack ``row``, which must hold (in this order) name, setting, unit,
        short_desc, context, vartype, min_val and max_val."""
        self.name = row[0]
        self.setting = row[1]
        self.unit = row[2]
        self.short_desc = row[3]
        self.context = row[4]
        self.vartype = row[5]
        self.min_val = row[6]
        self.max_val = row[7]

    def validate(self, newval, newcoordinatorval, options):
        """Check a proposed new value for this GUC.

        newval            -- proposed value for the segments
        newcoordinatorval -- proposed value for the coordinator
        options           -- parsed command-line options (``coordinatoronly`` is read)

        Returns the string "ok" when the value is acceptable, otherwise a
        human-readable description of the problem. Only a handful of GUCs
        have specific checks; every other GUC is accepted as is.
        """
        # todo add code here...
        # be careful 128KB in postgresql.conf is translated into 32KB units

        if self.name == "max_connections" and (not options.coordinatoronly):
            try:
                seg = int(newval)
                coordinator = int(newcoordinatorval)
            except (TypeError, ValueError):
                # TypeError covers a missing (None) value, ValueError a non-numeric one.
                return "invalid value for max_connections"
            if seg <= coordinator:
                return "the value of max_connections must be greater on the segments than on the coordinator"

        elif self.name == "gp_resource_manager":
            if newval == "'group'":
                msg = GpResGroup().validate()
                if msg is not None:
                    return msg
            elif newval == "'group-v2'":
                msg = GpResGroup().validate_v2()
                if msg is not None:
                    return msg
            elif newval != "'queue'":
                return "the value of gp_resource_manager must be 'group' or 'group-v2' or 'queue'"

        elif self.name == 'unix_socket_permissions':
            # A value without a leading 0 is taken as decimal; tell the user
            # which octal mode that decimal number actually corresponds to.
            if not newval.startswith('0'):
                try:
                    decimal_mode = int(newval)
                except ValueError:
                    return "invalid value for unix_socket_permissions"
                LOGGER.warning(
                    'Permission not entered in octal format.It was interpreted as Decimal.  %s in Octal = 0%o' % (
                        newval, decimal_mode))

        elif self.name == "gp_default_storage_options":
            newval = newval.strip()
            # Value must be enclosed in single quotes else postgres
            # will fail to start due to syntax errors in config file.
            if (not newval.startswith("'")) or (not newval.endswith("'")):
                return "Value must be enclosed in single quotes: '...'"
            newval = newval.strip("'")
            # Ensure that newval is of the form 'name=value,...'.
            for pair in newval.split(","):
                try:
                    # Anything other than exactly one '=' fails to unpack.
                    name, value = pair.strip().split("=")
                except ValueError:
                    return "Valid values are of the form 'name=value,...'."
                name = name.strip()
                value = value.strip()
                if re.match(r"^[a-z][a-z]*$", name) is None:
                    return "Invalid option name \"%s\"" % name
                if re.match(r"^[a-z0-9_][a-z0-9_]*$", value) is None:
                    return ("Invalid value \"%s\" for option %s" %
                            (value, name))
        return "ok"

    def print_info(self):
        """Print the GUC's name, unit, context, type and value range on one line."""
        print("[name: %s] [unit: %s] [context: %s] [vartype: %s] [min_val: %s] [max_val: %s]" % (
            self.name, self.unit, self.context, self.vartype, self.min_val, self.max_val))


def confirm_user_wants_to_continue():
    """Ask whether to go on although some segment hosts are unreachable.

    Returns normally when the user agrees; otherwise logs the abort and
    raises an Exception. The default answer passed to the prompt is 'N'.
    """
    question = ("One or more segment hosts are not currently reachable. If you continue with gpconfig, "
                "GUCs on unreachable segment hosts will not be updated. Do you want to continue?")
    if ask_yesno('', question, 'N'):
        return

    LOGGER.info("User Aborted. Exiting...")
    raise Exception("User Aborted. Exiting.")


def print_verbosely(options, hostname, directory):
    """Log the host and data directory being processed, but only with --verbose."""
    if not options.verbose:
        return
    LOGGER.info("host=%s dir=%s" % (hostname, directory))


def do_list(skipvalidation):
    """Print one info line for every GUC found in pg_settings.

    GUCs named in PROHIBITED_GUCS are left out unless ``skipvalidation`` is
    true. A DatabaseError is logged rather than propagated.
    """
    try:
        conn = dbconn.connect(dbconn.DbURL(), True)
        try:
            for row in dbconn.query(conn, GucQuery().query):
                guc = Guc(row)
                if skipvalidation or (guc.name not in PROHIBITED_GUCS):
                    guc.print_info()
        finally:
            # Release the connection even when the query or printing fails.
            conn.close()

    except DatabaseError:
        LOGGER.error('Failed to connect to database, this script can only be run when the database is up.')


def get_gucs_from_database(gucname):
    """Fetch the current setting of ``gucname`` through gp_toolkit.gp_param_setting().

    Returns a list of DatabaseSegmentGuc objects, one per result row.
    When a DatabaseError occurs the problem is logged and None is returned,
    so callers must check for None before using the result.
    """
    try:
        dburl = dbconn.DbURL()
        # we always want to unset search path except when getting the
        # 'search_path' GUC itself
        unsetSearchPath = gucname != 'search_path'
        conn = dbconn.connect(dburl, False, True, unsetSearchPath=unsetSearchPath)
        try:
            cursor = dbconn.query(conn, ToolkitQuery(gucname).query)
            # we assume that all roles are primary due to the query.
            return [DatabaseSegmentGuc(row) for row in cursor]
        finally:
            # Release the connection even when the query fails.
            conn.close()

    except DatabaseError as ex:
        error_text = str(ex)
        if re.search("unrecognized configuration parameter", error_text):
            LOGGER.error('Failed to retrieve GUC information, guc does not exist: ' + gucname)
        elif re.search("could not connect to server", error_text):
            LOGGER.error('Failed to retrieve GUC information, the database is not accessible')
        else:
            LOGGER.error('Failed to retrieve GUC information: ' + error_text)

    return None


def _print_gucs(gucname, gucs, options):
    """Print the values collected for ``gucname``.

    Either every collected value is listed individually (see
    _show_all_segment_values_always), or the collection's report is printed,
    preceded by a note saying whether the segments agree.
    """
    collection = GucCollection(options)
    collection.update_list(gucs)

    if _show_all_segment_values_always(options):
        print("GUC                 : %s" % gucname)
        for entry in collection.values():
            print("Context: %5s Value: %s" % (entry.context, entry.value))
        return

    if not collection.are_segments_consistent():
        print("WARNING: GUCS ARE OUT OF SYNC: ")
        print(collection.report())
        return

    print("Values on all segments are consistent")
    print("GUC              : %s" % gucname)
    print(collection.report())


def _show_all_segment_values_always(options):
    return options.show == "port"

# FIXME: add value to cmd_name.  We do not just do this due to encoding issues.
def do_add_config_script(pool, segs, value, options):
    """Queue one GpConfigHelper command per segment in ``segs`` on ``pool``.

    Each command targets the segment's data directory on the segment's host
    and carries ``options.entry`` as the parameter, ``value`` as its value and
    ``options.remove`` as the removal flag. Nothing is queued for an empty list.
    """
    for segment in segs:
        host = segment.hostname
        datadir = segment.datadir
        print_verbosely(options, host, datadir)

        description = "add %s parameter on host %s for seg %s" % (options.entry, host, datadir)
        helper = GpConfigHelper(description, datadir, options.entry,
                                value=value,
                                removeParameter=options.remove,
                                ctxt=REMOTE,
                                remoteHost=host)
        pool.addCommand(helper)


def do_change(options):
    """Apply a --change or --remove request to the conf files of the cluster.

    Steps:
      1. Unless --skipvalidation was given, look the GUC up in the database,
         quote string-typed values and validate the request.
      2. Check which hosts are reachable and ask for confirmation if some are not.
      3. Queue one helper command per reachable segment/coordinator on a
         worker pool and wait for them.
      4. Log a success or failure summary that includes the command line.

    Raises an Exception when the database cannot be reached, when validation
    fails or when the user declines to continue. Failures of individual
    helper commands are logged and reflected in the summary only.
    """
    # Local import: shlex.quote is the function pipes.quote was an alias for,
    # and this file never imported pipes itself.
    import shlex

    if options.debug:
        enable_verbose_logging()

    try:
        if not options.skipvalidation:
            conn = dbconn.connect(dbconn.DbURL(), True)
            try:
                guc = get_normal_guc(conn, options)

                # Force the postgresql.conf parser to detect vartype string as GUC_STRING in the guc-file.c/guc-file.l
                options.value = quote_string(guc, options.value)
                options.coordinatorvalue = quote_string(guc, options.coordinatorvalue)

                validate_change_options(options, conn, guc)
            finally:
                # Release the connection even when validation raises.
                conn.close()

    except DatabaseError as ex:
        LOGGER.error(ex.__str__())
        msg = 'Failed to connect to database, exiting without action. ' \
              'This script can only be run when the database is up.'
        LOGGER.error(msg)
        raise Exception(msg)

    hosts = [gp_array.coordinator.hostname]
    if gp_array.standbyCoordinator is not None:
        hosts.append(gp_array.standbyCoordinator.hostname)

    # if --coordinatoronly, we only need to check coordinator and standby. Else check other hosts too.
    if not options.coordinatoronly:
        hosts.extend(gp_array.get_hostlist(includeCoordinator=False))

    unreachable_hosts = get_unreachable_segment_hosts(hosts, len(set(hosts)))
    if len(unreachable_hosts) > 0:
        confirm_user_wants_to_continue()

    pool = WorkerPool()
    failure = False
    try:
        # do the segments
        if not options.coordinatoronly:
            reachable_segments = [seg for seg in gp_array.getSegDbList() if seg.hostname not in unreachable_hosts]
            if options.primaryvalue:
                do_add_config_script(pool, [seg for seg in reachable_segments if seg.isSegmentPrimary()],
                                     options.primaryvalue, options)

            if options.mirrorvalue:
                do_add_config_script(pool, [seg for seg in reachable_segments if seg.isSegmentMirror()],
                                     options.mirrorvalue, options)

            if not options.primaryvalue and not options.mirrorvalue:
                do_add_config_script(pool, list(reachable_segments), options.value, options)

        # do the coordinator and standby
        if options.coordinatorvalue or options.remove:
            coordinator_segs = [seg for seg in [gp_array.coordinator, gp_array.standbyCoordinator]
                                if seg is not None and seg.hostname not in unreachable_hosts]
            do_add_config_script(pool, coordinator_segs, options.coordinatorvalue, options)

        pool.join()
        items = pool.getCompletedItems()
        for i in items:
            if not i.was_successful():
                LOGGER.error('failed updating the postgresql.conf files on host: %s msg:%s' %
                             (i.remoteHost, i.get_results().stderr))
                failure = True

        pool.check_results()
    except Exception as err:
        failure = True
        LOGGER.error('errors in job:')
        LOGGER.error(err.__str__())
        LOGGER.error('exiting early')
    finally:
        pool.haltWork()
        pool.joinWorkers()

    # Replace literal empty strings with empty quotes, or it will look like the
    # user passed in an incorrect argument, which would be misleading
    params = " ".join(shlex.quote(arg) for arg in sys.argv[1:])
    if failure:
        LOGGER.error("finished with errors, parameter string '%s'" % params)
    else:
        LOGGER.info("completed successfully with parameters '%s'" % params)


def quote_string(guc, value):
    """Return ``value`` quoted for a conf file when ``guc`` is string-typed.

    For a string GUC the value is escaped and wrapped in single quotes, as
    per the behavior of the quote_literal() function in Postgres. In every
    other case (no value, no guc, non-string vartype) ``value`` is returned
    unchanged.
    """
    if value is None or not guc or guc.vartype != "string":
        return value

    # Escape single quotes, backslashes, and newlines, in that order.
    escaped = value
    for raw, replacement in (("'", "''"), ("\\", "\\\\"), ("\n", "\\n")):
        escaped = escaped.replace(raw, replacement)

    # Single-quote the whole string
    return "'" + escaped + "'"


def _is_guc_writeable(options):
    """
Return True unless options.entry is listed as a read-only GUC.

FYI, metadata about gucs, like GUC_DISALLOW_IN_FILE, is not available via
sql. We work around that by making use of a file already created during 'make install'.
(That file is created by parsing the C code that defines all GUCS,
 storing all properties with GUC_DISALLOW_IN_FILE into a file.)

Every stripped line of that file is added to the module-level read_only_gucs
set; when the file is missing only a warning is logged.
    """

    modules_dir = os.path.join(os.getenv('GPHOME'), ParseGuc.DESTINATION_DIR)
    disallowed_guc_file = os.path.abspath(os.path.join(modules_dir, ParseGuc.DESTINATION_FILENAME))

    if not os.path.exists(disallowed_guc_file):
        LOGGER.warning("disallowed GUCs file missing: '%s'" % disallowed_guc_file)
    else:
        with open(disallowed_guc_file) as handle:
            listed_names = [line.strip() for line in handle]
        read_only_gucs.update(listed_names)

    return options.entry not in read_only_gucs


def validate_change_options(options, conn, guc):
    """Raise an Exception unless the requested change is allowed.

    Checked in this order: the GUC must be writeable, it must be a normal
    (non-hidden) GUC, it must not be in PROHIBITED_GUCS and, when --value was
    given, the value must pass Guc.validate(). Every rejection is logged at
    fatal level before the Exception is raised.
    """
    def _reject(reason):
        # Log at fatal level, then abort with the same text.
        LOGGER.fatal(reason)
        raise Exception(reason)

    if not _is_guc_writeable(options):
        _reject("not a modifiable GUC: '%s'" % options.entry)

    if not guc:
        # Hidden gucs: a guc is considered hidden if both:
        #     1. It is not present with normal gucs in pg_settings
        #     2. It has a valid return from SHOW <guc_name>;
        try:
            dbconn.querySingleton(conn, "SHOW %s" % options.entry)
        except DatabaseError:
            _reject("not a valid GUC: " + options.entry)

        _reject("GUC Validation Failed: %s cannot be changed under normal conditions. "
                "Please refer to gpconfig documentation." % options.entry)

    if options.entry in PROHIBITED_GUCS:
        _reject("The parameter '%s' is not modifiable with this tool" % options.entry)

    if not options.value:
        return

    verdict = guc.validate(options.value, options.coordinatorvalue, options)
    if verdict != "ok":
        _reject("new GUC value failed validation: " + verdict)


def get_normal_guc(conn, options):
    """Look up ``options.entry`` in pg_settings over ``conn``.

    Returns a Guc for the single matching row, or None when nothing matches.
    More than one match is logged at fatal level and raised as an Exception.
    """
    matches = dbconn.query(conn, GucQuery(options.entry).query).fetchall()
    match_count = len(matches)

    if match_count > 1:
        msg = "more than 1 GUC matches: " + options.entry
        LOGGER.fatal(msg)
        raise Exception(msg)

    if match_count == 1:
        return Guc(matches[0])
    return None


def log_and_raise(err_str):
    """Log ``err_str`` as an error, then raise an Exception carrying the same text."""
    failure = Exception(err_str)
    LOGGER.error(err_str)
    raise failure


def get_gucs_from_files(guc):
    """Read the value of ``guc`` from the conf file of every database in gp_array.

    One GpConfigHelper command per database is run through a worker pool.
    Returns a list with a FileSegmentGuc for each command that succeeded;
    failed commands are logged and left out of the list.
    """
    pool = WorkerPool()
    gucs_found = []
    failedSegs = False

    try:
        for seg in gp_array.getDbList():
            cmd_name = "get %s parameter on host %s" % (guc, seg.hostname)
            pool.addCommand(
                GpConfigHelper(cmd_name, seg.datadir, guc,
                               segInfo=seg, getParameter=True,
                               ctxt=REMOTE,
                               remoteHost=seg.hostname))

        pool.join()

        for cmd in pool.getCompletedItems():
            if not cmd.was_successful():
                LOGGER.error('failed obtaining guc %s from the postgresql.conf files on host: %s msg: %s' %
                             (guc, cmd.remoteHost, cmd.get_results().stderr))
                failedSegs = True
            else:
                gucs_found.append(
                    FileSegmentGuc([cmd.segInfo.getSegmentContentId(), guc, cmd.get_value(),
                                    cmd.segInfo.getSegmentDbId()]))

        pool.check_results()
    finally:
        # Always stop the workers, as do_change does, even when something above raised.
        pool.haltWork()
        pool.joinWorkers()

    if failedSegs:
        LOGGER.error("finished with errors obtaining guc %s from at least one file" % guc)

    return gucs_found


def _set_gparray():
    """Load the cluster configuration from the catalog into the module-level gp_array.

    A DatabaseError is logged and re-raised as a plain Exception telling the
    user that the database must be up.
    """
    global gp_array
    try:
        gp_array = GpArray.initFromCatalog(dbconn.DbURL(), utility=True)
    except DatabaseError as ex:
        LOGGER.error(str(ex))
        msg = ('Failed to connect to database, exiting without action. '
               'This script can only be run when the database is up.')
        LOGGER.error(msg)
        raise Exception(msg)


def do_show(options):
    """Show the value of the GUC named by --show.

    Values come from the conf files (--file), from the database (default),
    or from both (--file-compare), and are printed by _print_gucs().

    Raises an Exception when --skipvalidation is combined with --show, or
    when the values could not be retrieved from the database.
    """
    if options.skipvalidation:
        # log_and_raise always raises, so nothing below runs in this case.
        log_and_raise('--skipvalidation can not be combined with --show')

    gucname = options.show
    gucs = []

    if options.file:
        gucs.extend(get_gucs_from_files(gucname))
    else:
        db_gucs = get_gucs_from_database(gucname)
        if db_gucs is None:
            # The lookup failed and already logged the details; stop with a
            # clear message instead of failing on list.extend(None).
            raise Exception("Failed to retrieve GUC information for '%s'" % gucname)
        gucs.extend(db_gucs)
        if options.file_compare:
            gucs.extend(get_gucs_from_files(gucname))

    _print_gucs(gucname, gucs, options)

def check_gpexpand():
    """Exit with status 1 when conflict_with_gpexpand() reports a conflict.

    The check is made for the tool name "gpconfig" with refuse_phase1=True and
    refuse_phase2=False; the returned message is logged before exiting.
    """
    allowed, reason = conflict_with_gpexpand("gpconfig",
                                             refuse_phase1=True,
                                             refuse_phase2=False)
    if allowed:
        return

    LOGGER.error(reason)
    sys.exit(1)

def do_main():
    """Parse the command line, load the cluster configuration and run the requested action."""
    options = parseargs()
    _set_gparray()

    if options.list:
        do_list(options.skipvalidation)
        return

    if options.show:
        do_show(options)
        return

    if options.remove or options.change:
        # gpconfig should check gpexpand running status when
        # it wants to modify configurations
        check_gpexpand()
        do_change(options)


# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
if __name__ == '__main__':
    # Avoid stack trace; just print exception message and return error code
    try:
        do_main()
    except Exception as failure:
        print(failure)
        sys.exit(1)
