From e4b6e981a533d7c4ba337ed296d6da8aef40719e Mon Sep 17 00:00:00 2001
From: James Campbell
Date: Thu, 23 May 2024 00:35:44 -0400
Subject: [PATCH] Add multi-db support and lots of comments

---
 pgmon-metrics.cfg |   3 +-
 pgmon.cfg         |  20 +-
 pgmon.py          | 884 ++++++++++++++++++++++++++++++++++++++--------
 3 files changed, 748 insertions(+), 159 deletions(-)

diff --git a/pgmon-metrics.cfg b/pgmon-metrics.cfg
index 7433b27..8c73b1c 100644
--- a/pgmon-metrics.cfg
+++ b/pgmon-metrics.cfg
@@ -7,7 +7,8 @@ metric=version:value::SHOW server_version_num
 metric=max_frozen_age:value::SELECT max(age(datfrozenxid)) FROM pg_database
 
 # Per-database metrics
-metric=db_stats:row::SELECT * FROM pg_stat_database WHERE datname = '{datname}'
+metric=db_stats:row::SELECT numbackends, xact_commit, xact_rollback, blks_read, blks_hit, tup_returned, tup_fetched, tup_inserted, tup_updated, tup_deleted, conflicts, temp_files, temp_bytes, deadlocks, blk_read_time, blk_write_time, extract('epoch' from stats_reset) FROM pg_stat_database WHERE datname = '{datname}'
+metric=db_stats:row:160000:SELECT numbackends, xact_commit, xact_rollback, blks_read, blks_hit, tup_returned, tup_fetched, tup_inserted, tup_updated, tup_deleted, conflicts, temp_files, temp_bytes, deadlocks, checksum_failures, blk_read_time, blk_write_time, session_time, active_time, idle_in_transaction_time, sessions, sessions_abandoned, sessions_fatal, sessions_killed, extract('epoch' from stats_reset) FROM pg_stat_database WHERE datname = '{datname}'
 
 # Per-replication metrics
 metric=rep_stats:row::SELECT * FROM pg_stat_replication WHERE client_addr || '_' || regexp_replace(application_name, '[ ,]', '_', 'g') = '{repid}'

diff --git a/pgmon.cfg b/pgmon.cfg
index fe2f268..ca4306b 100644
--- a/pgmon.cfg
+++ b/pgmon.cfg
@@ -36,24 +36,26 @@
 ##
 
 # Log level for stderr logging (or 'off')
-#stderr_log_level=info
+stderr_log_level=debug
 
 # Log level for file logging (or 'off')
-#file_log_level=info
+file_log_level=off
 
 # Log file
 #log_file=pgmon.log
 
 ##
 # DB connection settings
+#
+# Each cluster entry is of the form:
+#   name:address:port:dbname:user:password
+#
+# Any element other than the name can be left empty to use the defaults
 ##
-#host=localhost
-host=/var/run/postgresql
-#port=5432
-#user=postgres
-user=zbx_monitor
-#password=None
+#cluster=local:/var/run/postgresql:5432:postgres:zbx_monitor:
+cluster=pg15:localhost:54315:postgres:postgres:
+cluster=pg96:localhost:54396:postgres:postgres:
 
 # Default database to connect to when none is specified for a metric
 #dbname=postgres
@@ -65,4 +67,4 @@ user=zbx_monitor
 
 # Metrics
 #metrics={}
-include=/etc/zabbix/pgmon-metrics.cfg
+include=pgmon-metrics.cfg

diff --git a/pgmon.py b/pgmon.py
index fa89418..d960d3c 100755
--- a/pgmon.py
+++ b/pgmon.py
@@ -12,18 +12,23 @@ import signal
 import json
 import logging
 
+
 #
 # Errors
 #
 class InvalidConfigError(Exception):
     pass
 
+class UnknownClusterError(Exception):
+    pass
+
 class DuplicateMetricVersionError(Exception):
     pass
 
 class UnsupportedMetricVersionError(Exception):
     pass
 
 class RequestTimeoutError(Exception):
     pass
 
+class SQLInjectionError(Exception):
+    pass
+
+class DBError(Exception):
+    pass
 
 #
 # Logging
 #
@@ -44,6 +49,7 @@ def init_logging(config):
     Initialize (or re-initialize/modify) logging
     """
     global logger
+    global current_log_file
     global file_log_handler
     global stderr_log_handler
@@ -73,10 +79,18 @@ def init_logging(config):
 
     # Set up or modify file logging
     if config.file_log_level != 'OFF':
+        # Check if we're switching files
+        if file_log_handler is not None and config.log_file != current_log_file:
+            old_file_logger = file_log_handler
+            file_log_handler = None
+        else:
+            old_file_logger = None
+
         # Create and add the handler if it doesn't exist
         if file_log_handler is None:
             file_log_handler = logging.FileHandler(config.log_file, encoding='utf-8')
             logger.addHandler(file_log_handler)
+            current_log_file = config.log_file
 
         # Set the formatter
         file_log_handler.setFormatter(formatter)
@@ -85,13 +99,32 @@ def init_logging(config):
         level = logging.getLevelName(config.file_log_level)
         file_log_handler.setLevel(level)
 
+        # Remove the old handler if there was one
+        if old_file_logger is not None:
+            logger.removeHandler(old_file_logger)
+
         # Note where logs are being written
-        print("Logging to {}".format(config.log_file))
+        print("Logging to {} ({})".format(config.log_file, config.file_log_level))
     else:
         if file_log_handler is not None:
             logger.removeHandler(file_log_handler)
             file_log_handler = None
 
+    # Set the log level for the logger itself
+    levels = []
+    if stderr_log_handler is not None and stderr_log_handler.level != logging.NOTSET:
+        levels.append(stderr_log_handler.level)
+
+    if file_log_handler is not None and file_log_handler.level != logging.NOTSET:
+        levels.append(file_log_handler.level)
+
+    if len(levels) > 0:
+        logger.setLevel(min(levels))
+    else:
+        # If we have no handlers, just bump the level to the max
+        logger.setLevel(logging.CRITICAL)
+
+
 #
 # PID file handling
 #
@@ -106,16 +139,26 @@ def read_pid_file(pid_file):
 
 def remove_pid_file(pid_file):
     os.unlink(pid_file)
 
+
 #
 # Global flags for signal handler
 #
 running = True
 reload = False
 
+
 #
 # Signal handler
 #
 def signal_handler(sig, frame):
+    """
+    Function for handling signals
+
+    INT  => Shut down
+    TERM => Shut down
+    QUIT => Shut down
+    HUP  => Reload
+    """
     # Restore the original handler
     signal.signal(signal.SIGINT, signal.default_int_handler)
 
@@ -124,11 +167,14 @@ def signal_handler(sig, frame):
         logger.info("Shutting down ...")
         global running
         running = False
+
+    # Signal a reload
     elif sig == signal.SIGHUP:
         logger.info("Reloading config ...")
         global reload
         reload = True
 
+
 #
 # Classes
 #
@@ -136,12 +182,25 @@ def signal_handler(sig, frame):
 class Config:
     """
     Agent configuration
-    """
-    def __init__(self, config_file, read_metrics = True):
-        # Set defaults
-        self.pid_file = '/tmp/pgmon.pid'    # PID file
-        self.ipc_socket = '/tmp/pgmon.sock' # IPC socket
+
+    Note: The config is initially loaded before logging is configured, so be
+          mindful about logging anything in this class.
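+
+    Example (illustrative, mirroring the call in main() below):
+        config = Config('pgmon.cfg', read_metrics = False)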
+
+    Params:
+        config_file:   (str) Path to the config file
+        read_metrics:  (bool) Indicate if metrics should be parsed
+        read_clusters: (bool) Indicate if cluster information should be parsed
+
+    Exceptions:
+        InvalidConfigError: Indicates an issue with the config file
+        OSError: Thrown if there's an issue opening a config file
+        ValueError: Thrown if there is an encoding error
+    """
+    def __init__(self, config_file, read_metrics = True, read_clusters = True):
+        # Set defaults
+        self.pid_file = 'pgmon.pid'    # PID file
+        self.ipc_socket = 'pgmon.sock' # IPC socket
         self.ipc_timeout = 10          # IPC communication timeout (s)
         self.request_timeout = 10      # Request processing timeout (s)
         self.request_queue_size = 100  # Max size of the request queue before it blocks
@@ -152,122 +211,349 @@ class Config:
         self.file_log_level = 'INFO'   # Log level for file logging (or 'off')
         self.log_file = 'pgmon.log'    # Log file
 
-        self.host = 'localhost'        # DB host
-        self.port = 5432               # DB port
-        self.user = 'postgres'         # DB user
-        self.password = None           # DB password
-        self.dbname = 'postgres'       # DB name
-        self.metrics = {}              # Metrics
-
-        # Dynamic values
-        self.pg_version = None         # PostgreSQL version
+        self.clusters = {}             # Known clusters
+        self.metrics = {}              # Metrics
 
         # Read config
-        self.read_file(config_file, read_metrics)
+        self.read_file(config_file, read_metrics, read_clusters)
 
-    def read_file(self, config_file, read_metrics = True):
-        with open(config_file, 'r') as f:
-            for line in f:
-                line = line.strip()
+    def read_file(self, config_file, read_metrics, read_clusters):
+        """
+        Read a config file, possibly skipping metrics and clusters to lighten
+        the load on the agent.
 
-                if line.startswith('#'):
-                    continue
-                elif line == '':
-                    continue
+        Params:
+            config_file:   (str) Path to the config file
+            read_metrics:  (bool) Indicate if metrics should be parsed
+            read_clusters: (bool) Indicate if cluster information should be parsed
 
-                (key, value) = line.split('=', 1)
-                if value is None:
-                    raise InvalidConfigError("{}: {}", config_file, line)
+        Exceptions:
+            InvalidConfigError: Indicates an issue with the config file
+        """
+        try:
+            with open(config_file, 'r') as f:
+                for line in f:
+                    # Clean up any whitespace at either end
+                    line = line.strip()
 
-                if key == 'include':
-                    self.read_file(value, read_metrics)
-                elif key == 'pid_file':
-                    self.pid_file = value
-                elif key == 'ipc_socket':
-                    self.ipc_socket = value
-                elif key == 'ipc_timeout':
-                    self.ipc_timeout = float(value)
-                elif key == 'request_timeout':
-                    self.request_timeout = float(value)
-                elif key == 'request_queue_size':
-                    self.request_queue_size = int(value)
-                elif key == 'request_queue_timeout':
-                    self.request_queue_timeout = float(value)
-                elif key == 'worker_count':
-                    self.worker_count = int(value)
-                elif key == 'stderr_log_level':
-                    self.stderr_log_level = value.upper()
-                elif key == 'file_log_level':
-                    self.file_log_level = value.upper()
-                elif key == 'log_file':
-                    self.log_file = value
-                elif key == 'host':
-                    self.host = value
-                elif key == 'port':
-                    self.port = int(value)
-                elif key == 'user':
-                    self.user = value
-                elif key == 'password':
-                    self.password = value
-                elif key == 'dbname':
-                    self.dbname = value
-                elif key == 'metric':
-                    if read_metrics:
-                        self.add_metric(value)
-                else:
-                    raise InvalidConfigError("WARNING: Unknown config: {}".format(key))
+                    # Skip empty lines and comments
+                    if line.startswith('#'):
+                        continue
+                    elif line == '':
+                        continue
+
+                    # Separate the line into a key-value pair and clean up extra
+                    # white space that may have been around the '='
+                    try:
+                        (key, value) = [x.strip() for x in line.split('=', 1)]
+                    except ValueError:
+                        raise InvalidConfigError("{}: {}".format(config_file, line))
+
+                    # Handle each key appropriately
+                    if key == 'include':
+                        print("Including file: {}".format(value))
+                        self.read_file(value, read_metrics, read_clusters)
+                    elif key == 'pid_file':
+                        self.pid_file = value
+                    elif key == 'ipc_socket':
+                        self.ipc_socket = value
+                    elif key == 'ipc_timeout':
+                        self.ipc_timeout = float(value)
+                    elif key == 'request_timeout':
+                        self.request_timeout = float(value)
+                    elif key == 'request_queue_size':
+                        self.request_queue_size = int(value)
+                    elif key == 'request_queue_timeout':
+                        self.request_queue_timeout = float(value)
+                    elif key == 'worker_count':
+                        self.worker_count = int(value)
+                    elif key == 'stderr_log_level':
+                        self.stderr_log_level = value.upper()
+                    elif key == 'file_log_level':
+                        self.file_log_level = value.upper()
+                    elif key == 'log_file':
+                        self.log_file = value
+                    elif key == 'cluster':
+                        if read_clusters:
+                            self.add_cluster(value)
+                    elif key == 'metric':
+                        if read_metrics:
+                            print("Adding metric: {}".format(value))
+                            self.add_metric(value)
+                    else:
+                        raise InvalidConfigError("Unknown config option: {}".format(key))
+        except OSError as e:
+            raise InvalidConfigError("Failed to open/read config file: {}".format(e))
+        except ValueError as e:
+            raise InvalidConfigError("Encoding error in config file: {}".format(e))
+
+    def add_cluster(self, cluster_def):
+        """
+        Parse and add connection information about a cluster to the config.
+
+        Each cluster line is of the format:
+            name:[address]:[port]:[dbname]:[user]:[password]
+
+        The name is a unique, arbitrary identifier to associate this cluster
+        with a request from the monitoring agent.
+
+        The address can be an IP address, host name, or path to the directory
+        where a PostgreSQL socket exists.
+
+        The dbname field is the default database to connect to for metrics
+        which don't specify a database. This is also used to identify the
+        PostgreSQL version.
+
+        Default values replace empty components, except for the name field:
+            address:  /var/run/postgresql
+            port:     5432
+            dbname:   postgres
+            user:     postgres
+            password: None
+
+        Params:
+            cluster_def: (str) Cluster definition string
+
+        Exceptions:
+            InvalidConfigError: Thrown if the cluster entry is missing any fields
+                                or contains invalid content
+        """
+        # Split up the fields
+        try:
+            (name, address, port, dbname, user, password) = cluster_def.split(':')
+        except ValueError:
+            raise InvalidConfigError("Missing fields in cluster definition: {}".format(cluster_def))
+
+        # Make sure we have a name
+        if name == '':
+            raise InvalidConfigError("Cluster must have a name: {}".format(cluster_def))
+
+        # Set defaults for anything that's blank
+        if address == '':
+            address = '/var/run/postgresql'
+
+        if port == '':
+            port = 5432
+        else:
+            # Convert the port to a number here
+            try:
+                port = int(port)
+            except ValueError:
+                raise InvalidConfigError("Invalid port number: {}".format(port))
+
+        if dbname == '':
+            dbname = 'postgres'
+
+        if user == '':
+            user = 'postgres'
+
+        if password == '':
+            password = None
+
+        # Create and add the cluster object
+        self.clusters[name] = Cluster(name, address, port, dbname, user, password)
 
     def add_metric(self, metric_def):
-        (name, ret_type, version, sql) = metric_def.split(':', 3)
-        if sql is None:
-            raise InvalidConfigError
+        """
+        Parse and add a metric or metric version to the config.
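+
+        Example metric definition, taken from pgmon-metrics.cfg (the fields
+        are described below):
+            version:value::SHOW server_version_num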
-        if sql.startswith('file:'):
-            (_,path) = sql.split(':', 1)
-            with open(path, 'r') as f:
-                sql = f.read()
+        Each metric definition is of the form:
+            name:[return type]:[PostgreSQL version]:sql
 
+        The name is an identifier used to reference this metric from the
+        monitoring server.
+
+        The return type indicates how to format the results. Possible values
+        are:
+            value:  A single value is returned (first column of first record)
+            column: A list is returned containing the first column of all records
+            set:    All records are returned as a list of dicts
+
+        The version field is the first version of PostgreSQL for which this
+        query is valid.
+
+        The sql field is either the SQL to execute, or a string of the form:
+            file:<path>
+        where <path> is the path to a file containing the SQL. In either case,
+        the SQL can contain references to parameters that are to be passed to
+        PostgreSQL. These are substituted using python's format command, so
+        variable names should be enclosed in curly brackets like {foo}
+
+        Params:
+            metric_def: (str) Metric definition string
+
+        Exceptions:
+            InvalidConfigError: Thrown if the metric entry is missing any fields
+                                or contains invalid content
+        """
+        # Split up the fields
+        try:
+            (name, ret_type, version, sql) = metric_def.split(':', 3)
+        except ValueError:
+            raise InvalidConfigError("Missing fields in metric definition: {}".format(metric_def))
+
+        # Make sure we have a name and some SQL
+        if name == '':
+            raise InvalidConfigError("Missing name for metric: {}".format(metric_def))
+
+        # An empty SQL query indicates a metric is not supported after a particular version
+        if sql == '':
+            sql = None
+
+        # If the sql is in a separate file, read it in
+        try:
+            if sql is not None and sql.startswith('file:'):
+                (_, path) = sql.split(':', 1)
+                with open(path, 'r') as f:
+                    sql = f.read()
+        except OSError as e:
+            raise InvalidConfigError("Failed to open/read SQL file: {}".format(e))
+        except ValueError as e:
+            raise InvalidConfigError("Encoding error in SQL file: {}".format(e))
+
+        # If no starting version is given, set it to 0
         if version == '':
             version = 0
 
+        # Find or create the metric
         try:
             metric = self.metrics[name]
         except KeyError:
             metric = Metric(name, ret_type)
             self.metrics[name] = metric
 
+        # Add what was given as a version of the metric
         metric.add_version(int(version), sql)
 
-    def get_pg_version(self):
-        if self.pg_version is None:
-            db = DB(self)
-            self.pg_version = int(db.query('SHOW server_version_num')[0]['server_version_num'])
+    def get_pg_version(self, cluster_name):
+        """
+        Return the version of PostgreSQL running on the specified cluster. The
+        version is cached after the first successful retrieval.
+
+        Params:
+            cluster_name: (str) The identifier for the cluster, as defined in
+                          the config file
+
+        Returns:
+            The version using PostgreSQL's integer format (Mmmpp or MM00mm)
+
+        Exceptions:
+            UnknownClusterError: Thrown if the named cluster is not defined
+            DBError: A database error occurred
+        """
+        # TODO: Move this out of the Config class. Possibly move the code
+        #       to the DB class, while storing the value in the Cluster object.
+
+        # Find the cluster
+        try:
+            cluster = self.clusters[cluster_name]
+        except KeyError:
+            raise UnknownClusterError(cluster_name)
+
+        # Query the cluster if we don't already know the version
+        # TODO: Expire this value at some point to pick up upgrades.
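+        # (server_version_num examples: 90624 is 9.6.24, 150004 is 15.4)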
+        if cluster.pg_version is None:
+            db = DB(self, cluster_name)
+            cluster.pg_version = int(db.query('SHOW server_version_num')[0]['server_version_num'])
             db.close()
-        return self.pg_version
+
+        # Return the version number
+        return cluster.pg_version
+
+
+class Cluster:
+    """
+    Connection information for a PostgreSQL cluster
+
+    Params:
+        name:     A unique, arbitrary identifier to associate this cluster with
+                  a request from the monitoring agent.
+        address:  An IP address, host name, or path to the directory where a
+                  PostgreSQL socket exists.
+        port:     The database port number.
+        dbname:   Default database to connect to for metrics which don't specify
+                  a database. This is also used to identify the PostgreSQL
+                  version.
+        user:     The database user to connect as.
+        password: The password to use when connecting to the database. Leave
+                  this blank to use either no password or a password stored in
+                  the executing user's ~/.pgpass file.
+    """
+    def __init__(self, name, address, port, dbname, user, password):
+        self.name = name
+        self.address = address
+        self.dbname = dbname
+        self.port = port
+        self.user = user
+        self.password = password
+
+        # Dynamically acquired PostgreSQL version
+        self.pg_version = None
 
 class DB:
     """
     Database access
+
+    Params:
+        config:       The agent config
+        cluster_name: The name of the cluster to connect to
+        dbname:       A database to connect to, or None to use the default
+
+    Exceptions:
+        UnknownClusterError: Thrown if the named cluster is not defined
+        DBError: Thrown for any database related error
     """
-    def __init__(self, config, dbname=None):
+    def __init__(self, config, cluster_name, dbname=None):
+        logger.debug("Creating connection to cluster: {}".format(cluster_name))
+
+        # Find the named cluster
+        try:
+            cluster = config.clusters[cluster_name]
+        except KeyError:
+            raise UnknownClusterError(cluster_name)
+
+        # Use the default database if not given one
         if dbname is None:
-            dbname = config.dbname
-        self.conn = psycopg2.connect(
-            host = config.host,
-            port = config.port,
-            user = config.user,
-            password = config.password,
-            dbname = dbname);
-        self.conn.set_session(readonly=True, autocommit=True)
+            logger.debug("Using default database: {}".format(cluster.dbname))
+            dbname = cluster.dbname
+
+        # Connect to the database
+        try:
+            self.conn = psycopg2.connect(
+                host = cluster.address,
+                port = cluster.port,
+                user = cluster.user,
+                password = cluster.password,
+                dbname = dbname)
+        except Exception as e:
+            raise DBError("Failed to connect to the database: {}".format(e))
+
+        # Make the connection readonly and enable autocommit
+        try:
+            self.conn.set_session(readonly=True, autocommit=True)
+        except Exception as e:
+            raise DBError("Failed to set session parameters: {}".format(e))
 
     def query(self, sql, args=[]):
+        """
+        Execute a query in the database.
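+
+        Example (illustrative):
+            rows = db.query('SELECT datname FROM pg_database')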
+
+        Params:
+            sql:  (str) The SQL statement to execute
+            args: (list) List of positional arguments for the query
+
+        Exceptions:
+            DBError: Thrown for any database related error
+        """
+        logger.debug("Executing query: {}".format(sql))
+        try:
+            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
+                cur.execute(sql, args)
+                return(cur.fetchall())
+        except Exception as e:
+            raise DBError("Failed to execute query: {}".format(e))
 
     def close(self):
+        logger.debug("Closing db connection")
         try:
             if self.conn is not None:
                 self.conn.close()
@@ -279,8 +565,14 @@ class DB:
 
 class Request:
     """
     A metric request
+
+    Params:
+        cluster_name: (str) The name of the cluster the request is for
+        metric_name:  (str) The name of the metric to obtain
+        args:         (dict) Dictionary of arguments for the metric
     """
-    def __init__(self, metric_name, args = {}):
+    def __init__(self, cluster_name, metric_name, args = {}):
+        self.cluster_name = cluster_name
         self.metric_name = metric_name
         self.args = args
         self.result = None
@@ -290,74 +582,167 @@ class Request:
         self.complete.acquire()
 
     def set_result(self, result):
+        """
+        Set the result for the metric, and release the lock that allows the
+        result to be returned to the client.
+        """
+        # Set the result value
         self.result = result
 
         # Release the lock
         self.complete.release()
 
     def get_result(self, timeout = -1):
+        """
+        Retrieve the result for the metric. This will wait for a result to be
+        available. If timeout is >0, a RequestTimeoutError exception will be
+        thrown if the specified number of seconds elapses.
+
+        Params:
+            timeout: (float) Number of seconds to wait before timing out
+        """
        # Wait until the request has been completed
         if self.complete.acquire(timeout = timeout):
             return self.result
         else:
-            raise RequestTimeoutError
+            raise RequestTimeoutError()
 
 class Metric:
     """
     A metric
+
+    The ret_type parameter controls how the results will be formatted when
+    returned to the client. Possible values are:
+        value:  Return a single value
+        column: Return a list comprised of the first column of the query results
+        set:    Return a list of dictionaries representing the queried records
+
+    Params:
+        name:     (str) The name to associate with the metric
+        ret_type: (str) The return type of the metric (value, column, or set)
     """
     def __init__(self, name, ret_type):
         self.name = name
         self.versions = {}
         self.type = ret_type
-        self.cached = None
-        self.cached_version = None
 
-    def add_version(self, version, sql):
-        if version in self.versions:
+        # Place holders for the query cache
+        self.cached = {}
+
+    def add_version(self, pg_version, sql):
+        """
+        Add a versioned query for the metric
+
+        Params:
+            pg_version: (int) The first version number for which this query
+                        applies, or 0 for any version
+            sql:        (str) The SQL to execute for the specified version
+        """
+        # Check if we already have SQL for this version
+        if pg_version in self.versions:
             raise DuplicateMetricVersionError
-        self.versions[version] = MetricVersion(sql)
 
-    def get_version(self, version):
-        # Since we're usually going to keep asking for the same version,
-        # we cache the metric version and only search through the versions
-        # again if the PostgreSQL version changes
-        if self.cached is None or self.cached_version != version:
-            self.cached = None
-            self.cached_version = None
+        # Add the versioned SQL
+        self.versions[pg_version] = MetricVersion(sql)
+
+    def get_version(self, pg_version):
+        """
+        Get an appropriate SQL query for the given PostgreSQL version
+
+        Params:
+            pg_version: (int) The version of PostgreSQL for which to find a query
+        """
+        # Since we're usually going to keep asking for the same version(s),
+        # we cache the metric version the first time we need it.
+        if pg_version not in self.cached:
+            self.cached[pg_version] = None
 
+            # Search through the defined versions from highest to lowest until
+            # we find one that applies to the given PostgreSQL version
             for v in reversed(sorted(self.versions.keys())):
-                if version >= v:
-                    self.cached = self.versions[v]
-                    self.cached_version = version
+                if pg_version >= v:
+                    self.cached[pg_version] = self.versions[v]
                     break
 
-        if self.cached is None:
+        # If we didn't find a query, or the query is None (ie: the metric is
+        # no longer supported for this version), throw an exception
+        if pg_version not in self.cached or self.cached[pg_version] is None:
             raise UnsupportedMetricVersionError
 
-        return self.cached
+        # Return the cached version
+        return self.cached[pg_version]
 
 class MetricVersion:
     """
     A version of a metric
+
+    The query is formatted using str.format with a dictionary of variables,
+    so you can use things like {table} in the template, then pass table=foo
+    when retrieving the SQL and '{table}' will be replaced with 'foo'.
+
+    Note: Only minimal SQL injection checking is done on this. Basically we
+          just throw an error if any substitution value contains an apostrophe
+          or a semicolon
+
+    Params:
+        sql: (str) The SQL query template
     """
     def __init__(self, sql):
         self.sql = sql
 
     def get_sql(self, args):
+        """
+        Return the SQL for this version after substituting the provided
+        variables.
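+
+        Example (illustrative, mirroring the db_stats template in
+        pgmon-metrics.cfg):
+            mv = MetricVersion("SELECT 1 FROM pg_database WHERE datname = '{datname}'")
+            sql = mv.get_sql({'datname': 'postgres'})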
+
+        Params:
+            args: (dict) Dictionary of formatting substitutions to apply to the
+                  query template
+
+        Exceptions:
+            SQLInjectionError: if any of the values to be substituted contain
+                               an apostrophe (') or semicolon (;)
+        """
+        # Check for possible SQL injection attacks
+        for v in args.values():
+            if "'" in v or ';' in v:
+                raise SQLInjectionError()
+
+        # Format and return the SQL
         return self.sql.format(**args)
 
 class IPC:
     """
-    IPC handler
+    IPC handler for communication between the agent and server components
+
+    Params:
+        config: (Config) The agent configuration object
+        mode:   (str) Which side of the communication setup this is (agent|server)
+
+    Exceptions:
+        RuntimeError: if the IPC object's mode is invalid
     """
     def __init__(self, config, mode):
+        # Validate the mode
+        if mode not in [ 'agent', 'server' ]:
+            raise RuntimeError("Invalid IPC mode: {}".format(mode))
+
         self.config = config
         self.mode = mode
+
+        # Set up the connection
-        self.reconnect()
+        self.initialize()
 
-    def reconnect(self):
+    def initialize(self):
+        """
+        Initialize the socket
+
+        Exceptions:
+            OSError: if the socket file already exists and could not be removed
+            RuntimeError: if the IPC object's mode is invalid
+        """
         if self.mode == 'server':
             # Try to clean up the socket if it exists
             try:
@@ -368,41 +753,111 @@ class IPC:
 
             # Create the socket
             self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+            # Set a timeout for accepting connections to be able to catch signals
             self.socket.settimeout(1)
+
+            # Bind the socket and start listening
             self.socket.bind(self.config.ipc_socket)
             self.socket.listen(1)
         elif self.mode == 'agent':
             # Connect to the socket
             self.conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+            # Set the timeout for the agent side of the connection
             self.conn.settimeout(self.config.ipc_timeout)
         else:
-            raise RuntimeError
+            raise RuntimeError("Invalid IPC mode: {}".format(self.mode))
 
     def connect(self):
         """
-        Establish a connection to a socket
+        Establish and return a connection to the socket (agent mode)
+
+        Returns:
+            The connected socket object
+
+        Exceptions:
+            TimeoutError: if establishing the connection times out
         """
+        # Connect to the socket
         self.conn.connect(self.config.ipc_socket)
+
+        # Return the socket from this IPC object
         return self.conn
 
     def accept(self):
         """
-        Accept a connection
+        Wait for, accept, and return a connection to the socket (server mode)
+
+        Returns:
+            The accepted socket object
+
+        Exceptions:
+            TimeoutError: if no connection is accepted before the timeout
        """
+        # Accept the connection
         (conn, _) = self.socket.accept()
+
+        # Set the timeout on the new socket object
         conn.settimeout(self.config.ipc_timeout)
+
+        # Return the new socket object
         return conn
 
     def send(self, conn, msg):
+        """
+        Send a message across the IPC channel
+
+        The message is encoded in UTF-8, and prefixed with a 4 byte, big endian
+        unsigned integer
+
+        Params:
+            conn: (Socket) the connection to use
+            msg:  (str) The message to send
+
+        Exceptions:
+            TimeoutError: if the connection times out before the message is
+                          fully sent
+            OverflowError: if the size of the message exceeds what fits in a
+                           four byte, unsigned integer
+            UnicodeError: if the message is not UTF-8 encodable
+        """
+        # Encode the message as bytes
         msg_bytes = msg.encode("utf-8")
+
+        # Get the byte length of the message
         msg_len = len(msg_bytes)
+
+        # Encode the message length
         msg_len_bytes = msg_len.to_bytes(4, byteorder='big')
+
+        # Send the size and message bytes
         conn.sendall(msg_len_bytes + msg_bytes)
 
-    def recv(self, conn, timeout=-1):
+    def recv(self, conn):
+        """
+        Receive a message across the IPC channel
+
+        The message is encoded in UTF-8, and prefixed with a 4 byte, big endian
+        unsigned integer
+
+        Params:
+            conn: (Socket) the connection to use
+
+        Returns:
+            The received message
+
+        Exceptions:
+            TimeoutError: if the connection times out before the whole message
+                          is received
+            UnicodeError: if the message is not UTF-8 decodable
+        """
        # Read at least the length
        buffer = []
        msg_len = -4
@@ -410,24 +865,41 @@ class IPC:
             buffer += conn.recv(1024)
             msg_len = len(buffer) - 4
 
+        # Pull out the bytes for the length
         msg_len_bytes = buffer[:4]
+
+        # Convert the size bytes to an unsigned integer
         msg_len = int.from_bytes(msg_len_bytes, byteorder='big')
 
+        # Finish reading the message if we don't have all of it
         while len(buffer) < msg_len + 4:
             buffer += conn.recv(1024)
 
+        # Decode and return the message
         return bytes(buffer[4:]).decode("utf-8")
 
 class Agent:
     """
     The agent side of the connector
+
+    This mode is intended to be called by the monitoring agent and
+    communicates with the server side of the connector.
+
+    The key is a comma separated string formatted as:
+        cluster,metric[,name=value...]
+
+    Results are printed to stdout.
+
+    Params:
+        config_file: (str) The path to the config file
+        key:         (str) The key indicating the requested metric
     """
     @staticmethod
     def run(config_file, key):
-        config = Config(config_file, read_metrics = False)
+        # Read the agent config
+        config = Config(config_file, read_metrics = False, read_clusters = False)
 
-        # Connect to the socket
+        # Connect to the IPC socket
         ipc = IPC(config, 'agent')
         try:
             conn = ipc.connect()
@@ -449,9 +921,16 @@ class Agent:
 
 class Server:
     """
     The server side of the connector
+
+    Params:
+        config_file: (str) The path to the config file
     """
 
     def __init__(self, config_file):
+        # Note the path to the config file so it can be reloaded
         self.config_file = config_file
 
         # Load config
@@ -470,31 +949,60 @@ class Server:
         signal.signal(signal.SIGHUP, signal_handler)
 
         # Create request queue
+        logger.debug("Creating request queue")
         self.req_queue = queue.Queue(self.config.request_queue_size)
 
         # Spawn worker threads
+        logger.debug("Spawning worker threads")
         self.workers = self.spawn_workers(self.config)
 
     def spawn_workers(self, config):
+        """
+        Spawn all worker threads to process requests
+
+        Params:
+            config: (Config) The agent config object
+        """
         logger.info("Spawning {} workers".format(config.worker_count))
 
+        # Spawn worker threads
         workers = [None] * config.worker_count
         for i in range(config.worker_count):
             workers[i] = Worker(config, self.req_queue)
             workers[i].start()
+            logger.debug("Started thread #{} (tid={})".format(i, workers[i].native_id))
 
+        # Return the list of worker threads
         return workers
 
     def retire_workers(self, workers):
+        """
+        Retire (terminate) all worker threads in the given list of threads
+
+        Params:
+            workers: (list) List of worker threads to part ways with
+        """
         logger.info("Retiring {} workers".format(len(workers)))
 
+        # Inform the workers that their services are no longer required
         for worker in workers:
             worker.active = False
 
+        # Wait for the workers to turn in their badges
         for worker in workers:
             worker.join()
 
     def reload_config(self):
+        """
+        Reload the config, swap in the new config object, and replace the
+        worker threads
+
+        Exceptions:
+            InvalidConfigError: Indicates an issue with the config file
+            OSError: Thrown if there's an issue opening a config file
+            ValueError: Thrown if there is an encoding error when reading the
+                        config
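+
+        Note: A reload is normally triggered by sending SIGHUP to the running
+              server process (see signal_handler and main)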
+        """
         # Clear the reload flag
         global reload
         reload = False
@@ -518,14 +1026,19 @@ class Server:
 
         # Adjust other settings
         # TODO
 
+        # Set the new config as the one the server will use
         self.config = new_config
 
     def run(self):
+        """
+        Run the server's main loop
+        """
         logger.info("Server starting")
 
         # Listen on ipc socket
         ipc = IPC(self.config, 'server')
 
+        # Enter the main loop
         while True:
             # Wait for a request connection
             try:
@@ -549,38 +1062,48 @@ class Server:
                 continue
 
             # Get the request string (csv)
-            key = ipc.recv(conn)
-
-            # Parse ipc request (csv)
             try:
-                parts = key.split(',', 1)
-                metric_name = parts[0]
-                args_dict = {}
-
-                if len(parts) > 1:
-                    for arg in parts[1].split(','):
-                        if arg != '':
-                            (k, v) = arg.split('=', 1)
-                            args_dict[k] = v
+                key = ipc.recv(conn)
             except socket.timeout:
+                # Handle timeouts when receiving the request
                 logger.warning("IPC communication timeout receiving request")
                 conn.close()
                 continue
-            except Exception:
-                logger.warning("Received invalid request: '{}'".format(key))
+
+            # Parse ipc request (csv)
+            logger.debug("Parsing request key: {}".format(key))
+            try:
+                # Split the key into a cluster name, metric name, and list of
+                # metric arguments
+                parts = key.split(',', 2)
+                cluster_name = parts[0]
+                metric_name = parts[1]
+
+                # Parse any metric arguments into a dictionary
+                args_dict = {}
+                if len(parts) > 2:
+                    for arg in parts[2].split(','):
+                        if arg != '':
+                            (k, v) = arg.split('=', 1)
+                            args_dict[k] = v
+            except Exception as e:
+                # Handle problems parsing the request into its elements
+                logger.warning("Received invalid request '{}': {}".format(key, e))
                 ipc.send(conn, "ERROR: Invalid key")
                 conn.close()
                 continue
 
             # Create request object
-            req = Request(metric_name, args_dict)
+            req = Request(cluster_name, metric_name, args_dict)
 
             # Queue the request
             try:
                 self.req_queue.put(req, timeout=self.config.request_queue_timeout)
-            except:
-                logger.warning("Failed to queue request")
-                req.set_result("ERROR: Queue timeout")
+            except queue.Full:
+                # Handle situations where the queue is full and we didn't get
+                # a free slot before the configured timeout
+                logger.warning("Failed to queue request, queue is full")
+                req.set_result("ERROR: Enqueue timeout")
                 continue
 
             # Spawn a thread to wait for the result
@@ -593,12 +1116,20 @@ class Server:
 
         # Clean up the PID file
         remove_pid_file(self.pid_file)
 
+        # Be polite
         logger.info("Good bye")
 
+        # Gracefully shut down logging
+        logging.shutdown()
+
 
 class Worker(threading.Thread):
     """
     Worker thread that processes requests (ie: queries the database)
+
+    Params:
+        config: (Config) The agent config object
+        queue:  (Queue) The request queue the worker should pull requests from
     """
     def __init__(self, config, queue):
         super(Worker, self).__init__()
@@ -609,6 +1140,9 @@ class Worker(threading.Thread):
         self.active = True
 
     def run(self):
+        """
+        Main processing loop for a worker thread
+        """
         while True:
             # Wait for a request
             try:
@@ -618,6 +1152,15 @@ class Worker(threading.Thread):
 
             # Check if we're supposed to exit
             if not self.active:
+                # If we got a request, try to put it back on the queue
+                if req is not None:
+                    try:
+                        self.queue.put(req, timeout=1)
+                        logger.info("Requeued request at worker exit")
+                    except queue.Full:
+                        logger.warning("Failed to requeue request at worker exit")
+                        req.set_result("ERROR: Failed to requeue at worker exit")
+
+                logger.info("Worker exiting: tid={}".format(self.native_id))
                 break
 
@@ -629,42 +1172,47 @@ class Worker(threading.Thread):
             try:
                 metric = self.config.metrics[req.metric_name]
             except KeyError:
-                req.set_result("ERROR: Unknown key '{}'".format(req.metric_name))
+                req.set_result("ERROR: Unknown metric: {}".format(req.metric_name))
                 continue
 
             # Get the DB version
             try:
-                pg_version = self.config.get_pg_version()
+                pg_version = self.config.get_pg_version(req.cluster_name)
             except Exception as e:
-                req.set_result("Failed to retrieve database version")
-                logger.error("Failed to get Postgresql version: {}".format(e))
+                req.set_result("Failed to retrieve database version for cluster: {}".format(req.cluster_name))
+                logger.error("Failed to get PostgreSQL version for cluster {}: {}".format(req.cluster_name, e))
                 continue
 
             # Get the query to use
             try:
                 mv = metric.get_version(pg_version)
             except UnsupportedMetricVersionError:
+                # Handle unsupported metric versions
                 req.set_result("Unsupported PostgreSQL version for metric")
                 continue
 
-            # Query the database
+            # Identify the database to query
             try:
                 dbname = req.args['db']
             except KeyError:
                 dbname = None
 
+            # Query the database
             try:
-                db = DB(self.config, dbname)
+                db = DB(self.config, req.cluster_name, dbname)
                 res = db.query(mv.get_sql(req.args))
                 db.close()
             except Exception as e:
+                # Handle database errors
+                logger.error("Database error: {}".format(e))
+
+                # Make sure the database connection is closed (ie: if the query timed out)
                 try:
                     db.close()
                 except:
                     pass
+
                 req.set_result("Failed to query database")
-                logger.error("Database error: {}".format(e))
                 continue
 
             # Set the result on the request
@@ -679,14 +1227,22 @@ class Worker(threading.Thread):
             elif metric.type == 'set':
                 req.set_result(json.dumps(res))
 
+            # Close the database connection
             try:
                 db.close()
             except Exception as e:
                 logger.debug("Failed to close database connection: {}".format(e))
 
+
 class Responder(threading.Thread):
     """
     Thread responsible for replying to requests
+
+    Params:
+        config: (Config) The agent config object
+        ipc:    (IPC) The IPC object used for communication
+        conn:   (Socket) The connected socket to communicate with
+        req:    (Request) The request object to handle
     """
     def __init__(self, config, ipc, conn, req):
         super(Responder, self).__init__()
@@ -695,9 +1251,14 @@ class Responder(threading.Thread):
         result = req.get_result()
 
         # Send the result back to the client
-        ipc.send(conn, result)
+        try:
+            ipc.send(conn, result)
+        except Exception as e:
+            logger.warning("Failed to reply to agent: {}".format(e))
+
 
 def main():
+    # Set up command line argument parser
     parser = argparse.ArgumentParser(
         prog='pgmon',
         description='A bridge between monitoring tools and PostgreSQL')
@@ -713,22 +1274,46 @@ def main():
 
     # Agent options
     parser.add_argument('key', nargs='?')
 
+    # Parse command line arguments
     args = parser.parse_args()
 
     if args.server:
-        server = Server(args.config)
-        server.run()
+        # Try to start running in server mode
+        try:
+            server = Server(args.config)
+        except Exception as e:
+            sys.exit("Failed to start server: {}".format(e))
+
+        try:
+            server.run()
+        except Exception as e:
+            sys.exit("Caught an unexpected runtime error: {}".format(e))
+
     elif args.reload:
-        config = Config(args.config, read_metrics = False)
-        pid = read_pid_file(config.pid_file)
-        os.kill(pid, signal.SIGHUP)
+        # Try to signal a running server to reload its config
+        try:
+            config = Config(args.config, read_metrics = False)
+        except Exception as e:
+            sys.exit("Failed to read config file: {}".format(e))
+
+        # Read the PID file
+        try:
+            pid = read_pid_file(config.pid_file)
+        except Exception as e:
+            sys.exit("Failed to read PID file: {}".format(e))
+
+        # Signal the server to reload
+        try:
+            os.kill(pid, signal.SIGHUP)
+        except Exception as e:
+            sys.exit("Failed to signal the server: {}".format(e))
     else:
+        # Start running in agent mode
         Agent.run(args.config, args.key)
 
 
 if __name__ == '__main__':
     main()
 
+
 ##
 # Tests
 ##
@@ -741,19 +1326,20 @@ class TestDB:
 
 class TestRequest:
     def test_request_creation(self):
         # Create request with no args
-        req1 = Request('foo', {})
+        req1 = Request('c1', 'foo', {})
+        assert req1.cluster_name == 'c1'
         assert req1.metric_name == 'foo'
         assert len(req1.args) == 0
         assert req1.complete.locked()
 
         # Create request with args
-        req2 = Request('foo', {'arg1': 'value1', 'arg2': 'value2'})
+        req2 = Request('c1', 'foo', {'arg1': 'value1', 'arg2': 'value2'})
         assert req2.metric_name == 'foo'
         assert len(req2.args) == 2
         assert req2.complete.locked()
 
     def test_request_lock(self):
-        req1 = Request('foo', {})
+        req1 = Request('c1', 'foo', {})
         assert req1.complete.locked()
         req1.set_result('blah')
         assert not req1.complete.locked()
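+
+
+class TestMetric:
+    def test_version_selection(self):
+        # A small sketch of Metric version selection: the highest versioned
+        # query at or below the given server version wins. Assumes the Metric
+        # and MetricVersion classes above; 160000 is an arbitrary example cutoff.
+        m = Metric('m1', 'value')
+        m.add_version(0, 'SELECT 1')
+        m.add_version(160000, 'SELECT 2')
+        assert m.get_version(150000).sql == 'SELECT 1'
+        assert m.get_version(160000).sql == 'SELECT 2'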