#!/usr/bin/env python3
"""
pgmon: a small HTTP agent that runs configured PostgreSQL metric queries on
request and returns the results.
"""

import yaml
import json
import time
import os
import sys
import argparse
import logging
from datetime import datetime, timedelta
from decimal import Decimal
import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool
from contextlib import contextmanager
import signal
from threading import Thread, Lock, Semaphore
from http.server import BaseHTTPRequestHandler, HTTPServer
from http.server import ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs

VERSION = "1.0.1"

# Configuration (installed by read_config())
config = {}

# Dictionary of current PostgreSQL connection pools, keyed by database name
connections_lock = Lock()
connections = {}

# Dictionary of unhappy databases. Keys are database names, values are the
# time the database was determined to be unhappy plus the reconnect cooldown,
# i.e. the time after which we should try to connect to the database again.
unhappy_cooldown = {}

# Cached cluster version information
cluster_version = None
cluster_version_next_check = None
cluster_version_lock = Lock()

# Running state (cleared to gracefully shut down)
running = True

# The http server object
httpd = None

# Where the config file lives
config_file = None

# Valid values for a metric's "type" setting
METRIC_RETURN_TYPES = ("value", "row", "column", "set")

# Valid values for the "log_level" setting (case-insensitive)
LOG_LEVELS = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")

# Configure logging
log = logging.getLogger(__name__)
formatter = logging.Formatter(
    "%(asctime)s - %(levelname)s - %(filename)s: %(funcName)s() line %(lineno)d: %(message)s"
)
console_log_handler = logging.StreamHandler()
console_log_handler.setFormatter(formatter)
log.addHandler(console_log_handler)


# Error types
class ConfigError(Exception):
    """The config file could not be read or failed validation."""


class DisconnectedError(Exception):
    """A query failed because its database connection had been closed."""


class UnhappyDBError(Exception):
    """The database is in its reconnect cooldown period."""


class UnknownMetricError(Exception):
    """The requested metric is not defined in the config."""


class MetricVersionError(Exception):
    """No usable query exists for the metric on this PostgreSQL version."""


# Default config settings
default_config = {
    # The port the agent listens on for requests
    "port": 5400,
    # Min PostgreSQL connection pool size (per database)
    "min_pool_size": 0,
    # Max PostgreSQL connection pool size (per database)
    "max_pool_size": 4,
    # How long a connection can sit idle in the pool before it's removed (seconds)
    # NOTE(review): this setting is not currently referenced anywhere in this file
    "max_idle_time": 30,
    # Log level for stderr logging
    "log_level": "error",
    # Database user to connect as
    "dbuser": "postgres",
    # Database host
    "dbhost": "/var/run/postgresql",
    # Database port
    "dbport": 5432,
    # Default database to connect to when none is specified for a metric
    "dbname": "postgres",
    # Timeout for getting a connection slot from a pool (seconds)
    "pool_slot_timeout": 5,
    # PostgreSQL connection timeout (seconds)
    # Note: It can actually be double this because of retries
    "connect_timeout": 5,
    # Time to wait before trying to reconnect again after a reconnect failure (seconds)
    "reconnect_cooldown": 30,
    # How often to check the version of PostgreSQL (seconds)
    "version_check_period": 300,
    # Metrics
    "metrics": {},
}


def update_deep(d1, d2):
    """
    Recursively update a dict, adding keys to dictionaries and appending to
    lists. Note that this both modifies and returns the first dict.

    Params:
        d1: the dictionary to update
        d2: the dictionary to get new values from

    Returns:
        The new d1

    Raises:
        TypeError: if either argument is not a dict, or if a key holds a dict
            or list in d2 but a different type in d1
    """
    if not isinstance(d1, dict) or not isinstance(d2, dict):
        raise TypeError("Both arguments to update_deep need to be dictionaries")

    for k, v2 in d2.items():
        if isinstance(v2, dict):
            v1 = d1.get(k, {})
            if not isinstance(v1, dict):
                raise TypeError(
                    "Type mismatch between dictionaries: {} is not a dict".format(
                        type(v1).__name__
                    )
                )
            d1[k] = update_deep(v1, v2)
        elif isinstance(v2, list):
            v1 = d1.get(k, [])
            if not isinstance(v1, list):
                raise TypeError(
                    "Type mismatch between dictionaries: {} is not a list".format(
                        type(v1).__name__
                    )
                )
            d1[k] = v1 + v2
        else:
            d1[k] = v2
    return d1


def read_config(path, included=False):
    """
    Read a config file, along with any files it includes.

    Params:
        path: path to the file to read
        included: is this file included by another file?

    Returns:
        The parsed config dict when included is True. Otherwise the defaults
        and the file contents are merged, validated and installed as the
        global config (and the log level is applied); None is returned.

    Raises:
        ConfigError: if the file can't be parsed or fails validation
    """
    global config

    log.info("Reading config file: {}".format(path))
    with open(path, "r") as f:
        try:
            cfg = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigError("Invalid config file: {}: {}".format(path, e)) from e

    # An empty file parses to None; treat it as an empty config
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        raise ConfigError(
            "Config file {} must contain a mapping, got: {}".format(
                path, type(cfg).__name__
            )
        )

    # Since we use it a few places, get the base directory from the config
    config_base = os.path.dirname(path)

    # Read any external queries and validate metric definitions
    for name, metric in cfg.get("metrics", {}).items():
        if not isinstance(metric, dict):
            raise ConfigError(
                "Metric definition should be a dictionary for metric {} in {}".format(
                    name, path
                )
            )

        # Validate return types
        if "type" not in metric:
            raise ConfigError(
                "No type specified for metric {} in {}".format(name, path)
            )
        if metric["type"] not in METRIC_RETURN_TYPES:
            raise ConfigError(
                "Invalid return type: {} for metric {} in {}".format(
                    metric["type"], name, path
                )
            )

        # Ensure queries exist
        query_dict = metric.get("query", {})
        if not isinstance(query_dict, dict):
            raise ConfigError(
                "Query definition should be a dictionary, got: {} for metric {} in {}".format(
                    query_dict, name, path
                )
            )
        if not query_dict:
            raise ConfigError("Missing queries for metric {} in {}".format(name, path))

        # Read external sql files and validate version keys
        for vers, query in query_dict.items():
            try:
                int(vers)
            except (TypeError, ValueError):
                raise ConfigError(
                    "Invalid version: {} for metric {} in {}".format(vers, name, path)
                ) from None

            # A version key with no value (YAML null) is treated like an empty
            # query, which get_query() reports as "no longer applies"
            if query is None:
                query_dict[vers] = ""
                continue
            if not isinstance(query, str):
                raise ConfigError(
                    "Invalid query for version {} of metric {} in {}".format(
                        vers, name, path
                    )
                )

            if query.startswith("file:"):
                query_path = query[5:]
                # Relative paths are relative to this config file's directory
                if not query_path.startswith("/"):
                    query_path = os.path.join(config_base, query_path)
                with open(query_path, "r") as qf:
                    query_dict[vers] = qf.read()

    # Read any included config files
    for inc in cfg.get("include", []):
        # Prefix relative paths with the directory from the current config
        if not inc.startswith("/"):
            inc = os.path.join(config_base, inc)
        update_deep(cfg, read_config(inc, included=True))

    # Return the config we read if this is an include, otherwise set the final
    # config
    if included:
        return cfg

    new_config = {}
    update_deep(new_config, default_config)
    update_deep(new_config, cfg)

    # Minor sanity checks
    if len(new_config["metrics"]) == 0:
        log.error("No metrics are defined")
        raise ConfigError("No metrics defined")

    # Validate the new log level before changing the config
    if new_config["log_level"].upper() not in LOG_LEVELS:
        raise ConfigError("Invalid log level: {}".format(new_config["log_level"]))

    config = new_config

    # Apply changes to log level
    log.setLevel(logging.getLevelName(config["log_level"].upper()))


def signal_handler(sig, frame):
    """
    Function for handling signals

    SIGINT/SIGTERM/SIGQUIT => graceful shutdown
    SIGHUP => reload the config file
    """
    global running

    # Signal everything to shut down
    if sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        log.info("Shutting down ...")
        # Restore the default SIGINT handler so a second Ctrl-C interrupts
        # immediately instead of waiting on the graceful shutdown
        signal.signal(signal.SIGINT, signal.default_int_handler)
        running = False
        # Closing the listening socket wakes the main loop's handle_request()
        if httpd is not None:
            httpd.socket.close()

    # Signal a reload
    elif sig == signal.SIGHUP:
        log.warning("Received config reload signal")
        # read_config() only installs a config that passed validation, so on
        # failure the previous config stays active and the agent keeps running
        try:
            read_config(config_file)
        except Exception as e:
            log.error(
                "Config reload failed, keeping the previous config: {}".format(e)
            )


class ConnectionPool(ThreadedConnectionPool):
    """
    A psycopg2 thread-safe connection pool for a single database, plus a
    context manager that waits for a free connection slot.
    """

    def __init__(self, dbname, minconn, maxconn, *args, **kwargs):
        # Make sure dbname isn't different in the kwargs
        kwargs["dbname"] = dbname

        super().__init__(minconn, maxconn, *args, **kwargs)
        self.name = dbname

    @contextmanager
    def connection(self, timeout=None):
        """
        Borrow a connection from the pool for the duration of a with block.

        Params:
            timeout: seconds to keep retrying while the pool is exhausted, or
                None to wait indefinitely. At least one attempt is always made.

        Raises:
            TimeoutError: if no slot became free before the timeout
            psycopg2.pool.PoolError: if the pool has been closed
        """
        deadline = None if timeout is None else time.monotonic() + float(timeout)

        # Keep trying to get a connection slot until we time out
        while True:
            try:
                conn = self.getconn()
                break
            except psycopg2.pool.PoolError:
                # A closed pool will never hand out a connection again
                if self.closed:
                    raise
                if deadline is not None and time.monotonic() >= deadline:
                    raise TimeoutError(
                        "Timed out waiting for an available connection to {}".format(
                            self.name
                        )
                    ) from None
                # If we failed to get the connection slot, wait a bit and try again
                time.sleep(0.1)

        try:
            yield conn
        finally:
            self.putconn(conn)


def get_pool(dbname):
    """
    Get a database connection pool, creating it on first use.

    Raises:
        UnhappyDBError: if the database is still in its reconnect cooldown
    """
    # Check if the db is unhappy and wants to be left alone
    cooldown_until = unhappy_cooldown.get(dbname)
    if cooldown_until is not None:
        if cooldown_until > datetime.now():
            raise UnhappyDBError()
        # The cooldown has expired; forget it so later query errors are
        # reported as themselves instead of as UnhappyDBError
        unhappy_cooldown.pop(dbname, None)

    # Create a connection pool if it doesn't already exist
    if dbname not in connections:
        with connections_lock:
            # Make sure nobody created the pool while we were waiting on the
            # lock
            if dbname not in connections:
                log.info("Creating connection pool for: {}".format(dbname))
                connections[dbname] = ConnectionPool(
                    dbname,
                    int(config["min_pool_size"]),
                    int(config["max_pool_size"]),
                    application_name="pgmon",
                    host=config["dbhost"],
                    port=config["dbport"],
                    user=config["dbuser"],
                    connect_timeout=int(config["connect_timeout"]),
                    sslmode="require",
                )
                # Clear the unhappy indicator if present
                unhappy_cooldown.pop(dbname, None)
    return connections[dbname]


def handle_connect_failure(pool):
    """
    Mark the database as being unhappy so we can leave it alone for a while
    """
    dbname = pool.name
    unhappy_cooldown[dbname] = datetime.now() + timedelta(
        seconds=int(config["reconnect_cooldown"])
    )


def get_query(metric, version):
    """
    Get the correct metric query for a given version of PostgreSQL.

    The query with the highest version key that is <= version is chosen.
    Version keys are compared numerically whether YAML parsed them as ints or
    strings.

    params:
        metric: The metric definition
        version: The PostgreSQL version number, as given by server_version_num

    Raises:
        MetricVersionError: if no key is <= version, or the selected query is
            empty (the metric no longer applies)
    """
    queries = metric["query"]
    for v in sorted(queries, key=int, reverse=True):
        if version >= int(v):
            query = queries[v]
            if not (query or "").strip():
                raise MetricVersionError(
                    "Metric no longer applies to PostgreSQL {}".format(version)
                )
            return query

    raise MetricVersionError("Missing metric query for PostgreSQL {}".format(version))


def _json_default(obj):
    """
    json.dumps() fallback for column values json can't encode natively.

    psycopg2 maps numeric to Decimal, date/time types to datetime objects and
    interval to timedelta; these become numbers / ISO strings / seconds.
    """
    if isinstance(obj, Decimal):
        # Keep integral values as ints so large counters don't lose precision
        if obj.is_finite() and obj == obj.to_integral_value():
            return int(obj)
        return float(obj)
    if isinstance(obj, timedelta):
        return obj.total_seconds()
    if hasattr(obj, "isoformat"):
        return obj.isoformat()
    raise TypeError(
        "Object of type {} is not JSON serializable".format(type(obj).__name__)
    )


def _format_result(return_type, rows):
    """
    Render fetched rows (dicts from RealDictCursor) for the given return type.

    value  -> str of the first column of the first row
    row    -> JSON object for the first row
    column -> JSON list of the first column of every row
    set    -> JSON list of row objects
    """
    if return_type == "value":
        return str(next(iter(rows[0].values())))
    if return_type == "row":
        return json.dumps(rows[0], default=_json_default)
    if return_type == "column":
        return json.dumps(
            [next(iter(r.values())) for r in rows], default=_json_default
        )
    if return_type == "set":
        return json.dumps(rows, default=_json_default)
    raise ValueError("Unknown return type: {}".format(return_type))


def run_query_no_retry(pool, return_type, query, args):
    """
    Run the query with no explicit retry code

    Raises:
        UnhappyDBError: if the query failed while the database was marked
            unhappy
        DisconnectedError: if the query failed and its connection is closed
        TimeoutError: if no pool slot became free within pool_slot_timeout
    """
    with pool.connection(float(config["pool_slot_timeout"])) as conn:
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as curs:
                curs.execute(query, args)
                res = curs.fetchall()
        except Exception as e:
            if pool.name in unhappy_cooldown:
                raise UnhappyDBError() from e
            # psycopg2 connections expose closed != 0 once closed or broken
            if conn.closed:
                raise DisconnectedError() from e
            raise

    return _format_result(return_type, res)


def run_query(pool, return_type, query, args):
    """
    Run the query, and if we find upon the first attempt that the connection
    had been closed, wait a second and try again. This is because psycopg
    doesn't know if a connection closed (ie: PostgreSQL was restarted or the
    backend was terminated) until you try to execute a query.

    Note that the pool has its own retry mechanism as well, but it only applies
    to new connections being made.

    Also, this will not retry a query if the query itself failed, or if the
    database connection could not be established.

    Raises:
        UnhappyDBError: if the retry also fails; the database is then put into
            its reconnect cooldown
    """
    # putconn() discards the dead connection, so we can just give it another
    # shot.
    try:
        return run_query_no_retry(pool, return_type, query, args)
    except DisconnectedError:
        log.warning("Stale PostgreSQL connection found ... trying again")
        # NOTE(review): psycopg2's pool appears to drop closed connections in
        # putconn() itself, so this pause may not be needed for that; it does
        # give a restarting server a moment before the retry.
        time.sleep(1)
        try:
            return run_query_no_retry(pool, return_type, query, args)
        except Exception as e:
            handle_connect_failure(pool)
            raise UnhappyDBError() from e


def _version_check_due():
    """Return True if the cached cluster version is missing or stale."""
    return (
        cluster_version is None
        or cluster_version_next_check is None
        or cluster_version_next_check < datetime.now()
    )


def get_cluster_version():
    """
    Get the PostgreSQL version if we don't already know it, or if it's been
    too long since the last time it was checked.
    """
    global cluster_version
    global cluster_version_next_check

    # Only one thread needs to refresh the version, so they all try to grab
    # the lock and then make sure nobody else beat them to it.
    if _version_check_due():
        with cluster_version_lock:
            if _version_check_due():
                log.info("Checking PostgreSQL cluster version")
                pool = get_pool(config["dbname"])
                cluster_version = int(
                    run_query(pool, "value", "SHOW server_version_num", None)
                )
                cluster_version_next_check = datetime.now() + timedelta(
                    seconds=int(config["version_check_period"])
                )
                log.info("Got PostgreSQL cluster version: {}".format(cluster_version))
                log.debug(
                    "Next PostgreSQL cluster version check will be after: {}".format(
                        cluster_version_next_check
                    )
                )

    return cluster_version


def sample_metric(dbname, metric_name, args, retry=True):
    """
    Run the appropriate query for the named metric against the specified
    database

    Raises:
        UnknownMetricError: if metric_name isn't defined in the config
        MetricVersionError: if no query fits the cluster version
        UnhappyDBError: if the database is in its reconnect cooldown
    """
    # Get the metric definition
    try:
        metric = config["metrics"][metric_name]
    except KeyError:
        raise UnknownMetricError("Unknown metric: {}".format(metric_name)) from None

    # Get the connection pool for the database, or create one if it doesn't
    # already exist.
    pool = get_pool(dbname)

    # Identify the PostgreSQL version
    version = get_cluster_version()

    # Get the query version
    query = get_query(metric, version)

    # Execute the query
    if retry:
        return run_query(pool, metric["type"], query, args)
    return run_query_no_retry(pool, metric["type"], query, args)


def test_queries():
    """
    Run all of the metric queries against the default database.

    Each metric is sampled once without the stale-connection retry, using its
    optional "test_args" as query arguments. Comparing results against
    expected sample output is not implemented yet.

    Returns:
        The number of metrics whose query raised an error
    """
    # We just use the default db for tests
    dbname = config["dbname"]
    errors = 0

    for metric_name, metric in config["metrics"].items():
        # If the metric has arguments to use while testing, grab those
        args = metric.get("test_args", {})

        try:
            res = sample_metric(dbname, metric_name, args, retry=False)
        except Exception as e:
            errors += 1
            log.error(
                "Metric {} failed: {}: {}".format(metric_name, type(e).__name__, e)
            )
            continue

        log.info("Metric {} returned: {}".format(metric_name, res))
        # TODO: Compare the result to the provided sample results

    return errors


class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
    """
    This is our request handling server. It is responsible for listening for
    requests, processing them, and responding.
    """

    def log_request(self, code="-", size="-"):
        """
        Override to suppress standard request logging
        """
        pass

    def do_GET(self):
        """
        Handle a request. This is just a wrapper around the actual handler
        code to keep things more readable.
        """
        try:
            self._handle_request()
        except (BrokenPipeError, ConnectionResetError):
            log.error("Client disconnected, exiting handler")

    def _handle_request(self):
        """
        Request handler

        The URL path names the metric; query-string parameters become the
        query arguments (only the first value of each is used), and
        "dbname" selects the database (default: config["dbname"]).
        """
        # Parse the URL
        parsed_path = urlparse(self.path)
        metric_name = parsed_path.path.strip("/")
        parsed_query = parse_qs(parsed_path.query)

        if metric_name == "agent_version":
            self._reply(200, VERSION)
            return

        # Note: parse_qs returns the values as a list. Since we always expect
        # single values, just grab the first from each.
        args = {key: values[0] for key, values in parsed_query.items()}

        # Get the dbname. If none was provided, use the default from the
        # config.
        dbname = args.get("dbname", config["dbname"])

        # Sample the metric. The success reply is sent outside the try so a
        # client disconnect isn't misreported as a query error.
        try:
            result = sample_metric(dbname, metric_name, args)
        except UnknownMetricError:
            log.error("Unknown metric: {}".format(metric_name))
            self._reply(404, "Unknown metric")
        except MetricVersionError:
            log.error(
                "Failed to find a version of {} for {}".format(
                    metric_name, cluster_version
                )
            )
            self._reply(404, "Unsupported version")
        except UnhappyDBError:
            log.info("Database {} is unhappy, please be patient".format(dbname))
            self._reply(503, "Database unavailable")
        except Exception as e:
            log.error("Error running query: {}".format(e))
            self._reply(500, "Unexpected error: {}".format(e))
        else:
            self._reply(200, result)

    def _reply(self, code, content):
        """
        Send a reply to the client
        """
        body = content.encode("utf-8")
        self.send_response(code)
        self.send_header("Content-type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


if __name__ == "__main__":
    # Handle cli args
    parser = argparse.ArgumentParser(
        prog="pgmon", description="A PostgreSQL monitoring agent"
    )

    parser.add_argument(
        "config_file",
        default="pgmon.yml",
        nargs="?",
        help="The config file to read (default: %(default)s)",
    )

    # An option rather than a positional: a positional store_true consumes
    # zero arguments and therefore always evaluates to True.
    parser.add_argument(
        "-t", "--test", action="store_true", help="Run query tests and exit"
    )

    cli_args = parser.parse_args()

    # Set the config file path
    config_file = cli_args.config_file

    # Read the config file
    try:
        read_config(config_file)
    except (ConfigError, OSError) as e:
        log.critical("Failed to read config: {}".format(e))
        sys.exit(1)

    # Run query tests and exit if test mode is enabled
    if cli_args.test:
        sys.exit(1 if test_queries() > 0 else 0)

    # Set up the http server to receive requests
    server_address = ("127.0.0.1", int(config["port"]))
    httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler)

    # Set up the signal handlers
    for handled_signal in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT, signal.SIGHUP):
        signal.signal(handled_signal, signal_handler)

    # Handle requests.
    log.info("Listening on port {}...".format(config["port"]))
    try:
        while running:
            httpd.handle_request()
    finally:
        # Clean up PostgreSQL connections. psycopg2 pools are closed with
        # closeall(); they have no close() method.
        for pool in list(connections.values()):
            try:
                pool.closeall()
            except Exception as e:
                log.warning("Failed to close pool {}: {}".format(pool.name, e))