#!/usr/bin/env python3
"""
pgmon is a monitoring intermediary that sits between a PostgreSQL cluster and a monitoring system
that is capable of parsing JSON responses over an HTTP connection.
"""

# pylint: disable=too-few-public-methods

import json
import time
import os
import sys
import signal
import argparse
import logging
import re

from decimal import Decimal

from urllib.parse import urlparse, parse_qs

from contextlib import contextmanager

from datetime import datetime, timedelta

from http.server import BaseHTTPRequestHandler
from http.server import ThreadingHTTPServer

from threading import Lock

import yaml

import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool

import requests

VERSION = "1.1.0-rc1"


class Context:
    """
    The global context for connections, config, version, and IPC
    """

    # Configuration
    config = {}

    # Dictionary of current PostgreSQL connection pools
    connections_lock = Lock()
    connections = {}

    # Dictionary of unhappy databases. Keys are database names, value is the time
    # the database was determined to be unhappy plus the cooldown setting. So,
    # basically it's the time when we should try to connect to the database again.
    unhappy_cooldown = {}

    # Version information
    cluster_version = None
    cluster_version_next_check = None
    cluster_version_lock = Lock()

    # PostgreSQL latest version information
    latest_version = None
    latest_version_next_check = None
    latest_version_lock = Lock()
    release_supported = None

    # Running state (used to gracefully shut down)
    running = True

    # The http server object
    httpd = None

    # Where the config file lives
    config_file = None

    # Configure logging
    log = logging.getLogger(__name__)

    @classmethod
    def init_logging(cls):
        """
        Actually initialize the logging framework. Since we don't ever instantiate the Context
        class, this provides a way to make a few modifications to the log handler.
        """
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(filename)s: "
            "%(funcName)s() line %(lineno)d: %(message)s"
        )
        console_log_handler = logging.StreamHandler()
        console_log_handler.setFormatter(formatter)
        cls.log.addHandler(console_log_handler)


# Error types
class ConfigError(Exception):
    """
    Error type for all config related errors.
    """


class DisconnectedError(Exception):
    """
    Error indicating a previously active connection to the database has been disconnected.
    """


class UnhappyDBError(Exception):
    """
    Error indicating that a database the code has been asked to connect to is on the unhappy list.
    """


class UnknownMetricError(Exception):
    """
    Error indicating that an undefined metric was requested.
    """


class MetricVersionError(Exception):
    """
    Error indicating that there is no suitable query for a metric that was requested for the
    version of PostgreSQL being monitored.
    """


class LatestVersionCheckError(Exception):
    """
    Error indicating that there was a problem retrieving or parsing the latest version information.
    """


class InvalidDataError(Exception):
    """
    Error indicating query results were somehow invalid
    """


# Default config settings
DEFAULT_CONFIG = {
    # The address the agent binds to
    "address": "127.0.0.1",
    # The port the agent listens on for requests
    "port": 5400,
    # Min PostgreSQL connection pool size (per database)
    "min_pool_size": 0,
    # Max PostgreSQL connection pool size (per database)
    "max_pool_size": 4,
    # How long a connection can sit idle in the pool before it's removed (seconds)
    # NOTE(review): this setting is not read anywhere else in this module
    "max_idle_time": 30,
    # Log level for stderr logging
    "log_level": "error",
    # Database user to connect as
    "dbuser": "postgres",
    # Database host
    "dbhost": "/var/run/postgresql",
    # Database port
    "dbport": 5432,
    # Default database to connect to when none is specified for a metric
    "dbname": "postgres",
    # SSL connection mode
    "ssl_mode": "require",
    # Timeout for getting a connection slot from a pool (seconds)
    "pool_slot_timeout": 5,
    # PostgreSQL connection timeout (seconds)
    # Note: It can actually be double this because of retries
    "connect_timeout": 5,
    # Time to wait before trying to reconnect again after a reconnect failure (seconds)
    "reconnect_cooldown": 30,
    # How often to check the version of PostgreSQL (seconds)
    "version_check_period": 300,
    # How often to check the latest supported version of PostgreSQL (seconds)
    "latest_version_check_period": 86400,
    # Metrics
    "metrics": {},
}


def update_deep(d1, d2):
    """
    Recursively update a dict, adding keys to dictionaries and appending to
    lists. Note that this both modifies and returns the first dict.

    Params:
        d1: the dictionary to update
        d2: the dictionary to get new values from

    Returns:
        The new d1

    Raises:
        TypeError: when either argument is not a dict, or when a key holds a
            dict/list in d2 but a different type in d1
    """
    if not isinstance(d1, dict) or not isinstance(d2, dict):
        raise TypeError("Both arguments to update_deep need to be dictionaries")

    for k, v2 in d2.items():
        if isinstance(v2, dict):
            v1 = d1.get(k, {})
            if not isinstance(v1, dict):
                raise TypeError(
                    "Type mismatch between dictionaries: {} is not a dict".format(
                        type(v1).__name__
                    )
                )
            d1[k] = update_deep(v1, v2)
        elif isinstance(v2, list):
            v1 = d1.get(k, [])
            if not isinstance(v1, list):
                raise TypeError(
                    "Type mismatch between dictionaries: {} is not a list".format(
                        type(v1).__name__
                    )
                )
            d1[k] = v1 + v2
        else:
            d1[k] = v2
    return d1


def validate_metric(path, name, metric):
    """
    Validate a metric definition from a given file. If any query definitions come from external
    files, the metric dict will be updated with the actual query.

    Params:
        path: path to the file which contains this definition
        name: name of the metric
        metric: the dictionary containing the metric definition

    Raises:
        ConfigError: when the definition has a missing/invalid type, missing queries, a
            non-integer version key, or a non-string query
    """
    # Validate return types
    try:
        if metric["type"] not in ["value", "row", "column", "set"]:
            raise ConfigError(
                "Invalid return type: {} for metric {} in {}".format(
                    metric["type"], name, path
                )
            )
    except KeyError as e:
        raise ConfigError(
            "No type specified for metric {} in {}".format(name, path)
        ) from e

    # Ensure queries exist
    query_dict = metric.get("query", {})
    if not isinstance(query_dict, dict):
        raise ConfigError(
            "Query definition should be a dictionary, got: {} for metric {} in {}".format(
                query_dict, name, path
            )
        )

    if len(query_dict) == 0:
        raise ConfigError("Missing queries for metric {} in {}".format(name, path))

    # Read external sql files and validate version keys
    config_base = os.path.dirname(path)
    for vers, query in metric["query"].items():
        try:
            int(vers)
        except (TypeError, ValueError) as e:
            raise ConfigError(
                "Invalid version: {} for metric {} in {}".format(vers, name, path)
            ) from e

        # Queries must be text (get_query treats an empty string as "no longer
        # applies"). Anything else, e.g. an empty YAML value loaded as None,
        # would otherwise crash below with an AttributeError.
        if not isinstance(query, str):
            raise ConfigError(
                "Query for version {} of metric {} in {} must be a string, got: {}".format(
                    vers, name, path, query
                )
            )

        # Read in the external query and update the definition in the metric dictionary.
        # Replacing the value of an existing key is safe while iterating items().
        if query.startswith("file:"):
            query_path = query[5:]
            if not query_path.startswith("/"):
                query_path = os.path.join(config_base, query_path)
            with open(query_path, "r", encoding="utf-8") as f:
                metric["query"][vers] = f.read()


def read_config(path, included=False):
    """
    Read a config file.

    params:
        path: path to the file to read
        included: is this file included by another file?

    returns:
        If included is True, the raw config from this file merged with its own
        includes. Otherwise the final config (defaults merged with the file),
        which is also stored in Context.config.

    raises:
        ConfigError: when the file can't be parsed or the config is invalid
    """
    # Read config file
    Context.log.info("Reading config file: %s", path)
    with open(path, "r", encoding="utf-8") as f:
        try:
            cfg = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigError("Inavlid config file: {}: {}".format(path, e)) from e

    # An empty file loads as None; treat it as an empty config. Any other
    # non-mapping document can't be a config.
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        raise ConfigError(
            "Invalid config file: {}: top level must be a mapping".format(path)
        )

    # Read any external queries and validate metric definitions
    for name, metric in cfg.get("metrics", {}).items():
        validate_metric(path, name, metric)

    # Read any included config files
    config_base = os.path.dirname(path)
    for inc in cfg.get("include", []):
        # Prefix relative paths with the directory from the current config
        if not inc.startswith("/"):
            inc = os.path.join(config_base, inc)
        update_deep(cfg, read_config(inc, included=True))

    # Return the config we read if this is an include, otherwise set the final
    # config
    if included:
        return cfg

    new_config = {}
    update_deep(new_config, DEFAULT_CONFIG)
    update_deep(new_config, cfg)

    # Minor sanity checks
    if len(new_config["metrics"]) == 0:
        Context.log.error("No metrics are defined")
        raise ConfigError("No metrics defined")

    # Validate the new log level before changing the config
    if new_config["log_level"].upper() not in [
        "DEBUG",
        "INFO",
        "WARNING",
        "ERROR",
        "CRITICAL",
    ]:
        raise ConfigError("Invalid log level: {}".format(new_config["log_level"]))

    # Only swap in the new config once everything above has validated, so a
    # failed read leaves the previous config in place.
    Context.config = new_config

    # Apply changes to log level
    Context.log.setLevel(logging.getLevelName(Context.config["log_level"].upper()))

    # Return the config (mostly to make pylint happy, but also in case I opt to remove the side
    # effect and make this more functional.
    return Context.config


def signal_handler(sig, frame):  # pylint: disable=unused-argument
    """
    Function for handling signals

    INT/TERM/QUIT => Shut down gracefully
    HUP => Reload the config file
    """
    # Signal everything to shut down
    if sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
        # Restore the original SIGINT handler so that a second Ctrl-C raises
        # KeyboardInterrupt if the graceful shutdown gets stuck. This is done
        # only here, so a config reload doesn't disable graceful shutdown.
        signal.signal(signal.SIGINT, signal.default_int_handler)

        Context.log.info("Shutting down ...")
        Context.running = False
        if Context.httpd is not None:
            Context.httpd.socket.close()

    # Signal a reload
    elif sig == signal.SIGHUP:
        Context.log.warning("Received config reload signal")
        # A bad config must not kill a running agent. read_config only replaces
        # Context.config after validation, so the previous config stays active.
        try:
            read_config(Context.config_file)
        except Exception as e:  # pylint: disable=broad-exception-caught
            Context.log.error("Config reload failed, keeping previous config: %s", e)


class ConnectionPool(ThreadedConnectionPool):
    """
    Threaded connection pool that has a context manager.
    """

    def __init__(self, dbname, minconn, maxconn, *args, **kwargs):
        # Make sure dbname isn't different in the kwargs
        kwargs["dbname"] = dbname

        super().__init__(minconn, maxconn, *args, **kwargs)
        self.name = dbname

    @contextmanager
    def connection(self, timeout):
        """
        Connection context manager for our connection pool. This will attempt to retrieve a
        connection until the timeout is reached.

        Params:
            timeout: how long (in seconds) to keep trying to get a connection before giving up

        Raises:
            TimeoutError: when no connection slot became available within the timeout
        """
        # The timeout is in seconds. A monotonic clock is used so wall-clock
        # adjustments can't shorten or stretch the wait.
        deadline = time.monotonic() + timeout

        # Keep trying to get a connection slot until we time out. Acquisition
        # is kept separate from the yield so that an exception raised inside
        # the caller's with-block is never mistaken for a pool error.
        while True:
            try:
                conn = self.getconn()
                break
            except psycopg2.pool.PoolError as e:
                if time.monotonic() >= deadline:
                    raise TimeoutError(
                        "Timed out waiting for an available connection to {}".format(
                            self.name
                        )
                    ) from e
                # If we failed to get the connection slot, wait a bit and try again
                time.sleep(0.1)

        try:
            yield conn
        finally:
            self.putconn(conn)


def get_pool(dbname):
    """
    Get a database connection pool.

    Params:
        dbname: the name of the database for which a connection pool should be returned.

    Raises:
        UnhappyDBError: when the database is still within its reconnect cooldown
    """
    # Check if the db is unhappy and wants to be left alone
    cooldown_until = Context.unhappy_cooldown.get(dbname)
    if cooldown_until is not None and cooldown_until > datetime.now():
        raise UnhappyDBError()

    # Create a connection pool if it doesn't already exist
    if dbname not in Context.connections:
        with Context.connections_lock:
            # Make sure nobody created the pool while we were waiting on the
            # lock
            if dbname not in Context.connections:
                Context.log.info("Creating connection pool for: %s", dbname)
                # Actually create the connection pool
                Context.connections[dbname] = ConnectionPool(
                    dbname,
                    int(Context.config["min_pool_size"]),
                    int(Context.config["max_pool_size"]),
                    application_name="pgmon",
                    host=Context.config["dbhost"],
                    port=Context.config["dbport"],
                    user=Context.config["dbuser"],
                    connect_timeout=int(Context.config["connect_timeout"]),
                    sslmode=Context.config["ssl_mode"],
                )

    # The cooldown (if any) has expired: clear the unhappy indicator so that
    # later query failures aren't misreported as the database being unhappy
    Context.unhappy_cooldown.pop(dbname, None)

    return Context.connections[dbname]


def handle_connect_failure(pool):
    """
    Mark the database as being unhappy so we can leave it alone for a while
    """
    dbname = pool.name
    Context.unhappy_cooldown[dbname] = datetime.now() + timedelta(
        seconds=int(Context.config["reconnect_cooldown"])
    )


def get_query(metric, version):
    """
    Get the correct metric query for a given version of PostgreSQL.

    params:
        metric: The metric definition
        version: The PostgreSQL version number, as given by server_version_num

    returns:
        The query defined for the highest version key that is <= version

    raises:
        MetricVersionError: when no version key applies, or the applicable query is empty
    """
    # Select the correct query. Version keys may have been loaded from YAML as
    # ints or as strings, so order and compare them numerically.
    for v in sorted(metric["query"], key=int, reverse=True):
        if version >= int(v):
            query = metric["query"][v]
            # An empty query marks the metric as no longer applicable
            if len(query.strip()) == 0:
                raise MetricVersionError(
                    "Metric no longer applies to PostgreSQL {}".format(version)
                )
            return query

    raise MetricVersionError("Missing metric query for PostgreSQL {}".format(version))


def json_encode_special(obj):
    """
    Encoder function to handle types the standard JSON package doesn't know what to do with
    """
    if isinstance(obj, Decimal):
        return float(obj)
    raise TypeError("Cannot serialize object of {}".format(type(obj)))


def json_encode_result(return_type, res):
    """
    Return a json string encoding of the results of a query.

    params:
        return_type: the expected structure to return.  One of:
            value, row, column, set
        res: the query results

    returns: a json string form of the results

    raises:
        ConfigError: when an invalid return_type is given
        InvalidDataError: when the query results don't match the return type
    """
    try:
        if return_type == "value":
            if len(res) == 0:
                return ""
            return str(list(res[0].values())[0])

        if return_type == "row":
            return json.dumps(
                res[0] if len(res) > 0 else {}, default=json_encode_special
            )

        if return_type == "column":
            return json.dumps(
                [list(r.values())[0] for r in res], default=json_encode_special
            )

        if return_type == "set":
            return json.dumps(res, default=json_encode_special)
    except IndexError as e:
        raise InvalidDataError(e) from e

    # If we got to this point, the return type is invalid
    raise ConfigError("Invalid query return type: {}".format(return_type))


def run_query_no_retry(pool, return_type, query, args):
    """
    Run the query with no explicit retry code

    raises:
        UnhappyDBError: when the query failed and the database was meanwhile marked unhappy
        DisconnectedError: when the query failed because the connection was closed
        TimeoutError: when no pool slot became available within pool_slot_timeout
    """
    # Waiting for a free slot in the pool is governed by pool_slot_timeout;
    # connect_timeout is applied by the pool when opening new connections.
    with pool.connection(float(Context.config["pool_slot_timeout"])) as conn:
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as curs:
                curs.execute(query, args)
                res = curs.fetchall()

                return json_encode_result(return_type, res)
        except Exception as e:
            dbname = pool.name
            if dbname in Context.unhappy_cooldown:
                raise UnhappyDBError() from e
            if conn.closed != 0:
                raise DisconnectedError() from e
            raise


def run_query(pool, return_type, query, args):
    """
    Run the query, and if we find upon the first attempt that the connection
    had been closed, wait a second and try again. This is because psycopg
    doesn't know if a connection closed (ie: PostgreSQL was restarted or the
    backend was terminated) until you try to execute a query.

    Note that the pool has its own retry mechanism as well, but it only applies
    to new connections being made.

    Also, this will not retry a query if the query itself failed, or if the
    database connection could not be established.
    """
    # If we get disconnected, I think the putconn command will close the dead
    # connection. So we can just give it another shot.
    try:
        return run_query_no_retry(pool, return_type, query, args)
    except DisconnectedError:
        Context.log.warning("Stale PostgreSQL connection found ... trying again")
        # This sleep is an annoying hack to give the pool workers time to
        # actually mark the connection, otherwise it can be given back in the
        # next connection() call
        # TODO: verify this is the case with psycopg2
        time.sleep(1)
        try:
            return run_query_no_retry(pool, return_type, query, args)
        except Exception as e:
            handle_connect_failure(pool)
            raise UnhappyDBError() from e


def get_cluster_version():
    """
    Get the PostgreSQL version if we don't already know it, or if it's been
    too long since the last time it was checked.
    """
    # If we don't know the version or it's past the recheck time, get the
    # version from the database. Only one thread needs to do this, so they all
    # try to grab the lock, and then make sure nobody else beat them to it.
    if (
        Context.cluster_version is None
        or Context.cluster_version_next_check is None
        or Context.cluster_version_next_check < datetime.now()
    ):
        with Context.cluster_version_lock:
            # Only check if nobody already got the version before us
            if (
                Context.cluster_version is None
                or Context.cluster_version_next_check is None
                or Context.cluster_version_next_check < datetime.now()
            ):
                Context.log.info("Checking PostgreSQL cluster version")
                pool = get_pool(Context.config["dbname"])
                Context.cluster_version = int(
                    run_query(pool, "value", "SHOW server_version_num", None)
                )
                Context.cluster_version_next_check = datetime.now() + timedelta(
                    seconds=int(Context.config["version_check_period"])
                )
                Context.log.info(
                    "Got PostgreSQL cluster version: %s", Context.cluster_version
                )
                Context.log.debug(
                    "Next PostgreSQL cluster version check will be after: %s",
                    Context.cluster_version_next_check,
                )

    return Context.cluster_version


def version_num_to_release(version_num):
    """
    Extract the release from a version_num.

    In other words, this converts things like:
        90603 => 9.6
        130010 => 13
    """
    if version_num // 10000 < 10:
        return version_num // 10000 + (version_num % 10000 // 100 / 10)
    return version_num // 10000


def parse_version_rss(raw_rss, release):
    """
    Parse the raw RSS from the versions.rss feed to extract the latest version of
    PostgreSQL that's available for the cluster being monitored.

    This sets these Context variables:
        latest_version
        release_supported

    It is expected that the caller already holds the latest_version_lock lock.

    params:
        raw_rss: The raw rss text from versions.rss
        release: The PostgreSQL release we care about (ex: 9.2, 14)

    raises:
        LatestVersionCheckError: when the release isn't mentioned in the feed
    """

    # Regular expressions for parsing the RSS document. The release is escaped
    # so the "." in releases like 9.6 matches literally.
    version_line = re.compile(
        r".*?([0-9][0-9.]+) is the latest release in the {} series.*".format(
            re.escape(str(release))
        )
    )
    unsupported_line = re.compile(r"^This version is unsupported")

    # Loop through the RSS until we find the current release
    release_found = False
    for line in raw_rss.splitlines():
        m = version_line.match(line)
        if m:
            # Note that we found the version we were looking for
            release_found = True

            # Convert the version to version_num format
            version = m.group(1)
            parts = list(map(int, version.split(".")))
            if parts[0] < 10:
                Context.latest_version = int(
                    "{}{:02}{:02}".format(parts[0], parts[1], parts[2])
                )
            else:
                Context.latest_version = int("{}00{:02}".format(parts[0], parts[1]))
        elif release_found:
            # The next line after the version tells if the version is supported
            if unsupported_line.match(line):
                Context.release_supported = False
            else:
                Context.release_supported = True
            break

    # Make sure we actually found it
    if not release_found:
        raise LatestVersionCheckError("Current release ({}) not found".format(release))

    Context.log.info(
        "Got latest PostgreSQL version: %s supported=%s",
        Context.latest_version,
        Context.release_supported,
    )
    Context.log.debug(
        "Next latest PostgreSQL version check will be after: %s",
        Context.latest_version_next_check,
    )


def get_latest_version():
    """
    Get the latest supported version of the major PostgreSQL release running on the server being
    monitored.

    raises:
        LatestVersionCheckError: when the version feed can't be fetched or parsed
    """

    # If we don't know the latest version or it's past the recheck time, get the
    # version from the PostgreSQL RSS feed. Only one thread needs to do this, so
    # they all try to grab the lock, and then make sure nobody else beat them to it.
    if (
        Context.latest_version is None
        or Context.latest_version_next_check is None
        or Context.latest_version_next_check < datetime.now()
    ):
        # Note: we get the cluster version here before grabbing the latest_version_lock
        # lock so it's not held while trying to talk with the DB.
        release = version_num_to_release(get_cluster_version())

        with Context.latest_version_lock:
            # Only check if nobody already got the version before us
            if (
                Context.latest_version is None
                or Context.latest_version_next_check is None
                or Context.latest_version_next_check < datetime.now()
            ):
                Context.log.info("Checking latest PostgreSQL version")
                Context.latest_version_next_check = datetime.now() + timedelta(
                    seconds=int(Context.config["latest_version_check_period"])
                )

                # Grab the RSS feed. Network-level failures are reported the
                # same way as bad responses so callers only need one handler.
                try:
                    raw_rss = requests.get(
                        "https://www.postgresql.org/versions.rss", timeout=30
                    )
                except requests.RequestException as e:
                    raise LatestVersionCheckError(
                        "request failed: {}".format(e)
                    ) from e
                if raw_rss.status_code != 200:
                    raise LatestVersionCheckError("code={}".format(raw_rss.status_code))

                # Parse the RSS body and set Context variables
                parse_version_rss(raw_rss.text, release)

    return Context.latest_version


def sample_metric(dbname, metric_name, args, retry=True):
    """
    Run the appropriate query for the named metric against the specified database

    raises:
        UnknownMetricError: when the metric isn't defined in the config
        MetricVersionError: when no query applies to the cluster's version
        UnhappyDBError: when the database is in its reconnect cooldown
    """
    # Get the metric definition
    try:
        metric = Context.config["metrics"][metric_name]
    except KeyError as e:
        raise UnknownMetricError("Unknown metric: {}".format(metric_name)) from e

    # Get the connection pool for the database, or create one if it doesn't
    # already exist.
    pool = get_pool(dbname)

    # Identify the PostgreSQL version
    version = get_cluster_version()

    # Get the query version
    query = get_query(metric, version)

    # Execute the query
    if retry:
        return run_query(pool, metric["type"], query, args)
    return run_query_no_retry(pool, metric["type"], query, args)


def test_queries():
    """
    Run all of the metric queries against a database and check the results

    returns:
        The number of errors found (currently always 0)
    """
    # We just use the default db for tests
    dbname = Context.config["dbname"]
    # Loop through all defined metrics.
    for name, metric in Context.config["metrics"].items():
        # If the metric has arguments to use while testing, grab those
        # (an empty "test_args:" entry loads as None, so fall back to {})
        args = metric.get("test_args") or {}
        print(
            "Testing {} [{}]".format(
                name,
                ", ".join(["{}={}".format(key, value) for key, value in args.items()]),
            )
        )
        # When testing against a docker container, we may end up connecting
        # before the service is truly up (it restarts during the initialization
        # phase). To cope with this, we'll allow a few connection failures.
        tries = 5
        while True:
            # Run the query without the ability to retry
            try:
                res = sample_metric(dbname, name, args, retry=False)
                break
            except MetricVersionError:
                res = "Unsupported for this version"
                break
            except psycopg2.OperationalError as e:
                print("Error encountered, {} tries left: {}".format(tries, e))
                if tries <= 0:
                    raise
                time.sleep(1)
                tries -= 1
        # Compare the result to the provided sample results
        # TODO
        print("{} -> {}".format(name, res))
    # Return the number of errors
    # TODO
    return 0


class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
    """
    This is our request handling server.  It is responsible for listening for
    requests, processing them, and responding.
    """

    def log_request(self, code="-", size="-"):
        """
        Override to suppress standard request logging
        """

    def do_GET(self):  # pylint: disable=invalid-name
        """
        Handle a request.  This is just a wrapper around the actual handler
        code to keep things more readable.
        """
        try:
            self._handle_request()
        except BrokenPipeError:
            Context.log.error("Client disconnected, exiting handler")

    def _handle_request(self):
        """
        Request handler
        """
        # Parse the URL
        parsed_path = urlparse(self.path)
        metric_name = parsed_path.path.strip("/")
        parsed_query = parse_qs(parsed_path.query)

        if metric_name == "agent_version":
            self._reply(200, VERSION)
        elif metric_name == "latest_version_info":
            try:
                get_latest_version()
                self._reply(
                    200,
                    json.dumps(
                        {
                            "latest": Context.latest_version,
                            "supported": 1 if Context.release_supported else 0,
                        }
                    ),
                )
            except LatestVersionCheckError as e:
                Context.log.error(
                    "Failed to retrieve latest version information: %s", e
                )
                self._reply(503, "Failed to retrieve latest version info")
            except UnhappyDBError:
                # Looking up the latest version needs the cluster version first
                Context.log.info(
                    "Database %s is unhappy, please be patient",
                    Context.config["dbname"],
                )
                self._reply(503, "Database unavailable")
            except Exception as e:  # pylint: disable=broad-exception-caught
                # Always answer the client instead of dropping the connection
                Context.log.error("Error retrieving latest version info: %s", e)
                self._reply(500, "Unexpected error: {}".format(e))
        else:
            # Note: parse_qs returns the values as a list.  Since we always expect
            # single values, just grab the first from each.
            args = {key: values[0] for key, values in parsed_query.items()}

            # Get the dbname.  If none was provided, use the default from the
            # config.
            dbname = args.get("dbname", Context.config["dbname"])

            # Sample the metric
            try:
                self._reply(200, sample_metric(dbname, metric_name, args))
            except UnknownMetricError:
                Context.log.error("Unknown metric: %s", metric_name)
                self._reply(404, "Unknown metric")
            except MetricVersionError:
                Context.log.error("Failed to find an query version for %s", metric_name)
                self._reply(404, "Unsupported version")
            except UnhappyDBError:
                Context.log.info("Database %s is unhappy, please be patient", dbname)
                self._reply(503, "Database unavailable")
            except Exception as e:  # pylint: disable=broad-exception-caught
                Context.log.error("Error running query: %s", e)
                self._reply(500, "Unexpected error: {}".format(e))

    def _reply(self, code, content):
        """
        Send a reply to the client
        """
        self.send_response(code)
        self.send_header("Content-type", "application/json")
        self.end_headers()

        self.wfile.write(bytes(content, "utf-8"))


def main():
    """
    Main application routine
    """
    # Initialize the logging framework
    Context.init_logging()

    # Handle cli args
    parser = argparse.ArgumentParser(
        prog="pgmon", description="A PostgreSQL monitoring agent"
    )

    parser.add_argument(
        "-c",
        "--config_file",
        default="pgmon.yml",
        nargs="?",
        help="The config file to read (default: %(default)s)",
    )

    parser.add_argument(
        "-t", "--test", action="store_true", help="Run query tests and exit"
    )

    args = parser.parse_args()

    # Set the config file path
    Context.config_file = args.config_file

    # Read the config file
    read_config(Context.config_file)

    # Run query tests and exit if test mode is enabled
    if args.test:
        if test_queries() > 0:
            sys.exit(1)
        sys.exit(0)

    # Set up the http server to receive requests
    server_address = (Context.config["address"], int(Context.config["port"]))
    Context.httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler)

    # Set up the signal handler. SIGTERM is included so that service managers
    # get the same graceful shutdown as an interactive Ctrl-C.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGHUP, signal_handler)

    # Handle requests.
    Context.log.info("Listening on port %s...", Context.config["port"])
    while Context.running:
        Context.httpd.handle_request()

    # Clean up PostgreSQL connections. Snapshot the pools under the lock so a
    # late handler thread creating a pool can't change the dict mid-iteration.
    # TODO: Improve this ... not sure it actually closes all the connections cleanly
    with Context.connections_lock:
        pools = list(Context.connections.values())
    for pool in pools:
        try:
            pool.closeall()
        except psycopg2.pool.PoolError as e:
            # closeall() raises PoolError when the pool is already closed
            Context.log.debug("Could not close pool for %s: %s", pool.name, e)


if __name__ == "__main__":
    main()