From 43cd162313789681951106d3c78a9a1cac122378 Mon Sep 17 00:00:00 2001 From: James Campbell Date: Tue, 23 Sep 2025 01:12:49 -0400 Subject: [PATCH] Refactor to make pylint happy --- Makefile | 14 +- pylintrc | 4 + src/pgmon.py | 623 +++++++++++++++++++++++++--------------------- src/test_pgmon.py | 534 +++++++++++++++++++++++++++++---------- 4 files changed, 767 insertions(+), 408 deletions(-) create mode 100644 pylintrc diff --git a/Makefile b/Makefile index 95ae465..fa046d0 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ SUPPORTED := ubuntu-20.04 \ # These targets are the main ones to use for most things. ## -.PHONY: all clean tgz test query-tests install-common install-openrc install-systemd +.PHONY: all clean tgz lint format test query-tests install-common install-openrc install-systemd all: package-all @@ -80,6 +80,18 @@ tgz: clean: rm -rf $(BUILD_DIR) +# Check for lint +lint: + pylint src/pgmon.py + pylint src/test_pgmon.py + black --check --diff src/pgmon.py + black --check --diff src/test_pgmon.py + +# Format the code using black +format: + black src/pgmon.py + black src/test_pylint.py + # Run unit tests for the script test: cd src ; python3 -m unittest diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..d5f47b8 --- /dev/null +++ b/pylintrc @@ -0,0 +1,4 @@ +[MASTER] +py-version=3.5 + +disable=fixme diff --git a/src/pgmon.py b/src/pgmon.py index e5bd3fb..738ae64 100755 --- a/src/pgmon.py +++ b/src/pgmon.py @@ -1,105 +1,141 @@ #!/usr/bin/env python3 +""" +pgmon is a monitoring intermediary that sits between a PostgreSQL cluster and a monitoring systen +that is capable of parsing JSON responses over an HTTP connection. +""" + +# pylint: disable=too-few-public-methods -import yaml import json import time import os import sys - +import signal import argparse import logging +import re + +from decimal import Decimal + +from urllib.parse import urlparse, parse_qs + +from contextlib import contextmanager from datetime import datetime, timedelta +from http.server import BaseHTTPRequestHandler +from http.server import ThreadingHTTPServer + +from threading import Lock + +import yaml + import psycopg2 from psycopg2.extras import RealDictCursor from psycopg2.pool import ThreadedConnectionPool -from contextlib import contextmanager - -import signal -from threading import Thread, Lock, Semaphore - -from http.server import BaseHTTPRequestHandler, HTTPServer -from http.server import ThreadingHTTPServer -from urllib.parse import urlparse, parse_qs - import requests -import re -from decimal import Decimal VERSION = "1.0.4" -# Configuration -config = {} -# Dictionary of current PostgreSQL connection pools -connections_lock = Lock() -connections = {} +class Context: + """ + The global context for connections, config, version, nad IPC + """ -# Dictionary of unhappy databases. Keys are database names, value is the time -# the database was determined to be unhappy plus the cooldown setting. So, -# basically it's the time when we should try to connect to the database again. -unhappy_cooldown = {} + # Configuration + config = {} -# Version information -cluster_version = None -cluster_version_next_check = None -cluster_version_lock = Lock() + # Dictionary of current PostgreSQL connection pools + connections_lock = Lock() + connections = {} -# PostgreSQL latest version information -latest_version = None -latest_version_next_check = None -latest_version_lock = Lock() -release_supported = None + # Dictionary of unhappy databases. Keys are database names, value is the time + # the database was determined to be unhappy plus the cooldown setting. So, + # basically it's the time when we should try to connect to the database again. + unhappy_cooldown = {} -# Running state (used to gracefully shut down) -running = True + # Version information + cluster_version = None + cluster_version_next_check = None + cluster_version_lock = Lock() -# The http server object -httpd = None + # PostgreSQL latest version information + latest_version = None + latest_version_next_check = None + latest_version_lock = Lock() + release_supported = None -# Where the config file lives -config_file = None + # Running state (used to gracefully shut down) + running = True -# Configure logging -log = logging.getLogger(__name__) -formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(filename)s: %(funcName)s() line %(lineno)d: %(message)s" -) -console_log_handler = logging.StreamHandler() -console_log_handler.setFormatter(formatter) -log.addHandler(console_log_handler) + # The http server object + httpd = None + + # Where the config file lives + config_file = None + + # Configure logging + log = logging.getLogger(__name__) + + @classmethod + def init_logging(cls): + """ + Actually initialize the logging framework. Since we don't ever instantiate the Context + class, this provides a way to make a few modifications to the log handler. + """ + + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(filename)s: " + "%(funcName)s() line %(lineno)d: %(message)s" + ) + console_log_handler = logging.StreamHandler() + console_log_handler.setFormatter(formatter) + cls.log.addHandler(console_log_handler) # Error types class ConfigError(Exception): - pass + """ + Error type for all config related errors. + """ class DisconnectedError(Exception): - pass + """ + Error indicating a previously active connection to the database has been disconnected. + """ class UnhappyDBError(Exception): - pass + """ + Error indicating that a database the code has been asked to connect to is on the unhappy list. + """ class UnknownMetricError(Exception): - pass + """ + Error indicating that an undefined metric was requested. + """ class MetricVersionError(Exception): - pass + """ + Error indicating that there is no suitable query for a metric that was requested for the + version of PostgreSQL being monitored. + """ class LatestVersionCheckError(Exception): - pass + """ + Error indicating that there was a problem retrieving or parsing the latest version information. + """ # Default config settings -default_config = { +DEFAULT_CONFIG = { # The address the agent binds to "address": "127.0.0.1", # The port the agent listens on for requests @@ -177,6 +213,60 @@ def update_deep(d1, d2): return d1 +def validate_metric(path, name, metric): + """ + Validate a metric definition from a given file. If any query definitions come from external + files, the metric dict will be updated with the actual query. + + Params: + path: path to the file which contains this definition + name: name of the metric + metric: the dictionary containing the metric definition + """ + # Validate return types + try: + if metric["type"] not in ["value", "row", "column", "set"]: + raise ConfigError( + "Invalid return type: {} for metric {} in {}".format( + metric["type"], name, path + ) + ) + except KeyError as e: + raise ConfigError( + "No type specified for metric {} in {}".format(name, path) + ) from e + + # Ensure queries exist + query_dict = metric.get("query", {}) + if not isinstance(query_dict, dict): + raise ConfigError( + "Query definition should be a dictionary, got: {} for metric {} in {}".format( + query_dict, name, path + ) + ) + + if len(query_dict) == 0: + raise ConfigError("Missing queries for metric {} in {}".format(name, path)) + + # Read external sql files and validate version keys + config_base = os.path.dirname(path) + for vers, query in metric["query"].items(): + try: + int(vers) + except Exception as e: + raise ConfigError( + "Invalid version: {} for metric {} in {}".format(vers, name, path) + ) from e + + # Read in the external query and update the definition in the metricdictionary + if query.startswith("file:"): + query_path = query[5:] + if not query_path.startswith("/"): + query_path = os.path.join(config_base, query_path) + with open(query_path, "r", encoding="utf-8") as f: + metric["query"][vers] = f.read() + + def read_config(path, included=False): """ Read a config file. @@ -185,61 +275,21 @@ def read_config(path, included=False): path: path to the file to read included: is this file included by another file? """ + # Read config file - log.info("Reading log file: {}".format(path)) - with open(path, "r") as f: + Context.log.info("Reading log file: %s", path) + with open(path, "r", encoding="utf-8") as f: try: cfg = yaml.safe_load(f) except yaml.parser.ParserError as e: - raise ConfigError("Inavlid config file: {}: {}".format(path, e)) - - # Since we use it a few places, get the base directory from the config - config_base = os.path.dirname(path) + raise ConfigError("Inavlid config file: {}: {}".format(path, e)) from e # Read any external queries and validate metric definitions for name, metric in cfg.get("metrics", {}).items(): - # Validate return types - try: - if metric["type"] not in ["value", "row", "column", "set"]: - raise ConfigError( - "Invalid return type: {} for metric {} in {}".format( - metric["type"], name, path - ) - ) - except KeyError: - raise ConfigError( - "No type specified for metric {} in {}".format(name, path) - ) - - # Ensure queries exist - query_dict = metric.get("query", {}) - if type(query_dict) is not dict: - raise ConfigError( - "Query definition should be a dictionary, got: {} for metric {} in {}".format( - query_dict, name, path - ) - ) - - if len(query_dict) == 0: - raise ConfigError("Missing queries for metric {} in {}".format(name, path)) - - # Read external sql files and validate version keys - for vers, query in metric["query"].items(): - try: - int(vers) - except: - raise ConfigError( - "Invalid version: {} for metric {} in {}".format(vers, name, path) - ) - - if query.startswith("file:"): - query_path = query[5:] - if not query_path.startswith("/"): - query_path = os.path.join(config_base, query_path) - with open(query_path, "r") as f: - metric["query"][vers] = f.read() + validate_metric(path, name, metric) # Read any included config files + config_base = os.path.dirname(path) for inc in cfg.get("include", []): # Prefix relative paths with the directory from the current config if not inc.startswith("/"): @@ -250,34 +300,37 @@ def read_config(path, included=False): # config if included: return cfg - else: - new_config = {} - update_deep(new_config, default_config) - update_deep(new_config, cfg) - # Minor sanity checks - if len(new_config["metrics"]) == 0: - log.error("No metrics are defined") - raise ConfigError("No metrics defined") + new_config = {} + update_deep(new_config, DEFAULT_CONFIG) + update_deep(new_config, cfg) - # Validate the new log level before changing the config - if new_config["log_level"].upper() not in [ - "DEBUG", - "INFO", - "WARNING", - "ERROR", - "CRITICAL", - ]: - raise ConfigError("Invalid log level: {}".format(new_config["log_level"])) + # Minor sanity checks + if len(new_config["metrics"]) == 0: + Context.log.error("No metrics are defined") + raise ConfigError("No metrics defined") - global config - config = new_config + # Validate the new log level before changing the config + if new_config["log_level"].upper() not in [ + "DEBUG", + "INFO", + "WARNING", + "ERROR", + "CRITICAL", + ]: + raise ConfigError("Invalid log level: {}".format(new_config["log_level"])) - # Apply changes to log level - log.setLevel(logging.getLevelName(config["log_level"].upper())) + Context.config = new_config + + # Apply changes to log level + Context.log.setLevel(logging.getLevelName(Context.config["log_level"].upper())) + + # Return the config (mostly to make pylint happy, but also in case I opt to remove the side + # effect and make this more functional. + return Context.config -def signal_handler(sig, frame): +def signal_handler(sig, frame): # pylint: disable=unused-argument """ Function for handling signals @@ -288,19 +341,22 @@ def signal_handler(sig, frame): # Signal everything to shut down if sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]: - log.info("Shutting down ...") - global running - running = False - if httpd is not None: - httpd.socket.close() + Context.log.info("Shutting down ...") + Context.running = False + if Context.httpd is not None: + Context.httpd.socket.close() # Signal a reload if sig == signal.SIGHUP: - log.warning("Received config reload signal") - read_config(config_file) + Context.log.warning("Received config reload signal") + read_config(Context.config_file) class ConnectionPool(ThreadedConnectionPool): + """ + Threaded connection pool that has a context manager. + """ + def __init__(self, dbname, minconn, maxconn, *args, **kwargs): # Make sure dbname isn't different in the kwargs kwargs["dbname"] = dbname @@ -309,7 +365,14 @@ class ConnectionPool(ThreadedConnectionPool): self.name = dbname @contextmanager - def connection(self, timeout=None): + def connection(self, timeout): + """ + Connection context manager for our connection pool. This will attempt to retrieve a + connection until the timeout is reached. + + Params: + timeout: how long to keep trying to get a connection bedore giving up + """ conn = None timeout_time = datetime.now() + timedelta(timeout) # We will continue to try to get a connection slot until we time out @@ -333,34 +396,37 @@ class ConnectionPool(ThreadedConnectionPool): def get_pool(dbname): """ Get a database connection pool. + + Params: + dbname: the name of the database for which a connection pool should be returned. """ # Check if the db is unhappy and wants to be left alone - if dbname in unhappy_cooldown: - if unhappy_cooldown[dbname] > datetime.now(): + if dbname in Context.unhappy_cooldown: + if Context.unhappy_cooldown[dbname] > datetime.now(): raise UnhappyDBError() # Create a connection pool if it doesn't already exist - if dbname not in connections: - with connections_lock: + if dbname not in Context.connections: + with Context.connections_lock: # Make sure nobody created the pool while we were waiting on the # lock - if dbname not in connections: - log.info("Creating connection pool for: {}".format(dbname)) + if dbname not in Context.connections: + Context.log.info("Creating connection pool for: %s", dbname) # Actually create the connection pool - connections[dbname] = ConnectionPool( + Context.connections[dbname] = ConnectionPool( dbname, - int(config["min_pool_size"]), - int(config["max_pool_size"]), + int(Context.config["min_pool_size"]), + int(Context.config["max_pool_size"]), application_name="pgmon", - host=config["dbhost"], - port=config["dbport"], - user=config["dbuser"], - connect_timeout=int(config["connect_timeout"]), - sslmode=config["ssl_mode"], + host=Context.config["dbhost"], + port=Context.config["dbport"], + user=Context.config["dbuser"], + connect_timeout=int(Context.config["connect_timeout"]), + sslmode=Context.config["ssl_mode"], ) # Clear the unhappy indicator if present - unhappy_cooldown.pop(dbname, None) - return connections[dbname] + Context.unhappy_cooldown.pop(dbname, None) + return Context.connections[dbname] def handle_connect_failure(pool): @@ -368,8 +434,8 @@ def handle_connect_failure(pool): Mark the database as being unhappy so we can leave it alone for a while """ dbname = pool.name - unhappy_cooldown[dbname] = datetime.now() + timedelta( - seconds=int(config["reconnect_cooldown"]) + Context.unhappy_cooldown[dbname] = datetime.now() + timedelta( + seconds=int(Context.config["reconnect_cooldown"]) ) @@ -400,43 +466,48 @@ def json_encode_special(obj): """ if isinstance(obj, Decimal): return float(obj) - raise TypeError(f"Cannot serialize object of {type(obj)}") + raise TypeError("Cannot serialize object of {}".format(type(obj))) def run_query_no_retry(pool, return_type, query, args): """ Run the query with no explicit retry code """ - with pool.connection(float(config["connect_timeout"])) as conn: + with pool.connection(float(Context.config["connect_timeout"])) as conn: try: with conn.cursor(cursor_factory=RealDictCursor) as curs: + output = None curs.execute(query, args) res = curs.fetchall() if return_type == "value": if len(res) == 0: - return "" - return str(list(res[0].values())[0]) + output = "" + output = str(list(res[0].values())[0]) elif return_type == "row": - if len(res) == 0: - return "[]" - return json.dumps(res[0], default=json_encode_special) + # if len(res) == 0: + # return "[]" + output = json.dumps(res[0], default=json_encode_special) elif return_type == "column": - if len(res) == 0: - return "[]" - return json.dumps( + # if len(res) == 0: + # return "[]" + output = json.dumps( [list(r.values())[0] for r in res], default=json_encode_special ) elif return_type == "set": - return json.dumps(res, default=json_encode_special) - except: + output = json.dumps(res, default=json_encode_special) + else: + raise ConfigError( + "Invalid query return type: {}".format(return_type) + ) + return output + except Exception as e: dbname = pool.name - if dbname in unhappy_cooldown: - raise UnhappyDBError() - elif conn.closed != 0: - raise DisconnectedError() - else: - raise + if dbname in Context.unhappy_cooldown: + raise UnhappyDBError() from e + if conn.closed != 0: + raise DisconnectedError() from e + raise def run_query(pool, return_type, query, args): @@ -457,7 +528,7 @@ def run_query(pool, return_type, query, args): try: return run_query_no_retry(pool, return_type, query, args) except DisconnectedError: - log.warning("Stale PostgreSQL connection found ... trying again") + Context.log.warning("Stale PostgreSQL connection found ... trying again") # This sleep is an annoying hack to give the pool workers time to # actually mark the connection, otherwise it can be given back in the # next connection() call @@ -465,9 +536,9 @@ def run_query(pool, return_type, query, args): time.sleep(1) try: return run_query_no_retry(pool, return_type, query, args) - except: + except Exception as e: handle_connect_failure(pool) - raise UnhappyDBError() + raise UnhappyDBError() from e def get_cluster_version(): @@ -475,40 +546,39 @@ def get_cluster_version(): Get the PostgreSQL version if we don't already know it, or if it's been too long sice the last time it was checked. """ - global cluster_version - global cluster_version_next_check # If we don't know the version or it's past the recheck time, get the # version from the database. Only one thread needs to do this, so they all # try to grab the lock, and then make sure nobody else beat them to it. if ( - cluster_version is None - or cluster_version_next_check is None - or cluster_version_next_check < datetime.now() + Context.cluster_version is None + or Context.cluster_version_next_check is None + or Context.cluster_version_next_check < datetime.now() ): - with cluster_version_lock: + with Context.cluster_version_lock: # Only check if nobody already got the version before us if ( - cluster_version is None - or cluster_version_next_check is None - or cluster_version_next_check < datetime.now() + Context.cluster_version is None + or Context.cluster_version_next_check is None + or Context.cluster_version_next_check < datetime.now() ): - log.info("Checking PostgreSQL cluster version") - pool = get_pool(config["dbname"]) - cluster_version = int( + Context.log.info("Checking PostgreSQL cluster version") + pool = get_pool(Context.config["dbname"]) + Context.cluster_version = int( run_query(pool, "value", "SHOW server_version_num", None) ) - cluster_version_next_check = datetime.now() + timedelta( - seconds=int(config["version_check_period"]) + Context.cluster_version_next_check = datetime.now() + timedelta( + seconds=int(Context.config["version_check_period"]) ) - log.info("Got PostgreSQL cluster version: {}".format(cluster_version)) - log.debug( - "Next PostgreSQL cluster version check will be after: {}".format( - cluster_version_next_check - ) + Context.log.info( + "Got PostgreSQL cluster version: %s", Context.cluster_version + ) + Context.log.debug( + "Next PostgreSQL cluster version check will be after: %s", + Context.cluster_version_next_check, ) - return cluster_version + return Context.cluster_version def version_num_to_release(version_num): @@ -521,8 +591,7 @@ def version_num_to_release(version_num): """ if version_num // 10000 < 10: return version_num // 10000 + (version_num % 10000 // 100 / 10) - else: - return version_num // 10000 + return version_num // 10000 def parse_version_rss(raw_rss, release): @@ -530,7 +599,7 @@ def parse_version_rss(raw_rss, release): Parse the raw RSS from the versions.rss feed to extract the latest version of PostgreSQL that's availabe for the cluster being monitored. - This sets these global variables: + This sets these Context variables: latest_version release_supported @@ -540,8 +609,6 @@ def parse_version_rss(raw_rss, release): raw_rss: The raw rss text from versions.rss release: The PostgreSQL release we care about (ex: 9.2, 14) """ - global latest_version - global release_supported # Regular expressions for parsing the RSS document version_line = re.compile( @@ -561,75 +628,75 @@ def parse_version_rss(raw_rss, release): version = m.group(1) parts = list(map(int, version.split("."))) if parts[0] < 10: - latest_version = int( + Context.latest_version = int( "{}{:02}{:02}".format(parts[0], parts[1], parts[2]) ) else: - latest_version = int("{}00{:02}".format(parts[0], parts[1])) + Context.latest_version = int("{}00{:02}".format(parts[0], parts[1])) elif release_found: # The next line after the version tells if the version is supported if unsupported_line.match(line): - release_supported = False + Context.release_supported = False else: - release_supported = True + Context.release_supported = True break # Make sure we actually found it if not release_found: raise LatestVersionCheckError("Current release ({}) not found".format(release)) - log.info( - "Got latest PostgreSQL version: {} supported={}".format( - latest_version, release_supported - ) + Context.log.info( + "Got latest PostgreSQL version: %s supported=%s", + Context.latest_version, + Context.release_supported, ) - log.debug( - "Next latest PostgreSQL version check will be after: {}".format( - latest_version_next_check - ) + Context.log.debug( + "Next latest PostgreSQL version check will be after: %s", + Context.latest_version_next_check, ) def get_latest_version(): """ - Get the latest supported version of the major PostgreSQL release running on the server being monitored. + Get the latest supported version of the major PostgreSQL release running on the server being + monitored. """ - global latest_version_next_check - # If we don't know the latest version or it's past the recheck time, get the # version from the PostgreSQL RSS feed. Only one thread needs to do this, so # they all try to grab the lock, and then make sure nobody else beat them to it. if ( - latest_version is None - or latest_version_next_check is None - or latest_version_next_check < datetime.now() + Context.latest_version is None + or Context.latest_version_next_check is None + or Context.latest_version_next_check < datetime.now() ): # Note: we get the cluster version here before grabbing the latest_version_lock # lock so it's not held while trying to talk with the DB. release = version_num_to_release(get_cluster_version()) - with latest_version_lock: + with Context.latest_version_lock: # Only check if nobody already got the version before us if ( - latest_version is None - or latest_version_next_check is None - or latest_version_next_check < datetime.now() + Context.latest_version is None + or Context.latest_version_next_check is None + or Context.latest_version_next_check < datetime.now() ): - log.info("Checking latest PostgreSQL version") - latest_version_next_check = datetime.now() + timedelta( - seconds=int(config["latest_version_check_period"]) + Context.log.info("Checking latest PostgreSQL version") + Context.latest_version_next_check = datetime.now() + timedelta( + seconds=int(Context.config["latest_version_check_period"]) ) # Grab the RSS feed - raw_rss = requests.get("https://www.postgresql.org/versions.rss") + raw_rss = requests.get( + "https://www.postgresql.org/versions.rss", timeout=30 + ) if raw_rss.status_code != 200: - raise LatestVersionCheckError("code={}".format(r.status_code)) + raise LatestVersionCheckError("code={}".format(raw_rss.status_code)) - # Parse the RSS body and set global variables + # Parse the RSS body and set Context variables parse_version_rss(raw_rss.text, release) - return latest_version + return Context.latest_version def sample_metric(dbname, metric_name, args, retry=True): @@ -638,9 +705,9 @@ def sample_metric(dbname, metric_name, args, retry=True): """ # Get the metric definition try: - metric = config["metrics"][metric_name] - except KeyError: - raise UnknownMetricError("Unknown metric: {}".format(metric_name)) + metric = Context.config["metrics"][metric_name] + except KeyError as e: + raise UnknownMetricError("Unknown metric: {}".format(metric_name)) from e # Get the connection pool for the database, or create one if it doesn't # already exist. @@ -655,8 +722,7 @@ def sample_metric(dbname, metric_name, args, retry=True): # Execute the quert if retry: return run_query(pool, metric["type"], query, args) - else: - return run_query_no_retry(pool, metric["type"], query, args) + return run_query_no_retry(pool, metric["type"], query, args) def test_queries(): @@ -664,9 +730,9 @@ def test_queries(): Run all of the metric queries against a database and check the results """ # We just use the default db for tests - dbname = config["dbname"] + dbname = Context.config["dbname"] # Loop through all defined metrics. - for name, metric in config["metrics"].items(): + for name, metric in Context.config["metrics"].items(): # If the metric has arguments to use while testing, grab those args = metric.get("test_args", {}) print( @@ -711,9 +777,8 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): """ Override to suppress standard request logging """ - pass - def do_GET(self): + def do_GET(self): # pylint: disable=invalid-name """ Handle a request. This is just a wrapper around the actual handler code to keep things more readable. @@ -721,7 +786,7 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): try: self._handle_request() except BrokenPipeError: - log.error("Client disconnected, exiting handler") + Context.log.error("Client disconnected, exiting handler") def _handle_request(self): """ @@ -734,7 +799,6 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): if metric_name == "agent_version": self._reply(200, VERSION) - return elif metric_name == "latest_version_info": try: get_latest_version() @@ -742,46 +806,40 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): 200, json.dumps( { - "latest": latest_version, - "supported": 1 if release_supported else 0, + "latest": Context.latest_version, + "supported": 1 if Context.release_supported else 0, } ), ) except LatestVersionCheckError as e: - log.error("Failed to retrieve latest version information: {}".format(e)) + Context.log.error( + "Failed to retrieve latest version information: %s", e + ) self._reply(503, "Failed to retrieve latest version info") - return + else: + # Note: parse_qs returns the values as a list. Since we always expect + # single values, just grab the first from each. + args = {key: values[0] for key, values in parsed_query.items()} - # Note: parse_qs returns the values as a list. Since we always expect - # single values, just grab the first from each. - args = {key: values[0] for key, values in parsed_query.items()} + # Get the dbname. If none was provided, use the default from the + # config. + dbname = args.get("dbname", Context.config["dbname"]) - # Get the dbname. If none was provided, use the default from the - # config. - dbname = args.get("dbname", config["dbname"]) - - # Sample the metric - try: - self._reply(200, sample_metric(dbname, metric_name, args)) - return - except UnknownMetricError as e: - log.error("Unknown metric: {}".format(metric_name)) - self._reply(404, "Unknown metric") - return - except MetricVersionError as e: - log.error( - "Failed to find a version of {} for {}".format(metric_name, version) - ) - self._reply(404, "Unsupported version") - return - except UnhappyDBError as e: - log.info("Database {} is unhappy, please be patient".format(dbname)) - self._reply(503, "Database unavailable") - return - except Exception as e: - log.error("Error running query: {}".format(e)) - self._reply(500, "Unexpected error: {}".format(e)) - return + # Sample the metric + try: + self._reply(200, sample_metric(dbname, metric_name, args)) + except UnknownMetricError: + Context.log.error("Unknown metric: %s", metric_name) + self._reply(404, "Unknown metric") + except MetricVersionError: + Context.log.error("Failed to find an query version for %s", metric_name) + self._reply(404, "Unsupported version") + except UnhappyDBError: + Context.log.info("Database %s is unhappy, please be patient", dbname) + self._reply(503, "Database unavailable") + except Exception as e: # pylint: disable=broad-exception-caught + Context.log.error("Error running query: %s", e) + self._reply(500, "Unexpected error: {}".format(e)) def _reply(self, code, content): """ @@ -794,7 +852,14 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): self.wfile.write(bytes(content, "utf-8")) -if __name__ == "__main__": +def main(): + """ + Main application routine + """ + + # Initialize the logging framework + Context.init_logging() + # Handle cli args parser = argparse.ArgumentParser( prog="pgmon", description="A PostgreSQL monitoring agent" @@ -815,33 +880,35 @@ if __name__ == "__main__": args = parser.parse_args() # Set the config file path - config_file = args.config_file + Context.config_file = args.config_file # Read the config file - read_config(config_file) + read_config(Context.config_file) # Run query tests and exit if test mode is enabled if args.test: - errors = test_queries() - if errors > 0: + if test_queries() > 0: sys.exit(1) - else: - sys.exit(0) + sys.exit(0) # Set up the http server to receive requests - server_address = (config["address"], config["port"]) - httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler) + server_address = (Context.config["address"], Context.config["port"]) + Context.httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler) # Set up the signal handler signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGHUP, signal_handler) # Handle requests. - log.info("Listening on port {}...".format(config["port"])) - while running: - httpd.handle_request() + Context.log.info("Listening on port %s...", Context.config["port"]) + while Context.running: + Context.httpd.handle_request() # Clean up PostgreSQL connections # TODO: Improve this ... not sure it actually closes all the connections cleanly - for pool in connections.values(): + for pool in Context.connections.values(): pool.close() + + +if __name__ == "__main__": + main() diff --git a/src/test_pgmon.py b/src/test_pgmon.py index 534f8ba..e75b0aa 100644 --- a/src/test_pgmon.py +++ b/src/test_pgmon.py @@ -1,5 +1,12 @@ +""" +Unit tests for pgmon +""" + +# pylint: disable=too-many-lines + import unittest +import os from datetime import datetime, timedelta import tempfile @@ -13,7 +20,7 @@ import pgmon # Silence most logging output logging.disable(logging.CRITICAL) -versions_rss = """ +VERSIONS_RSS = """ PostgreSQL latest versionshttps://www.postgresql.org/PostgreSQL latest versionsen-usThu, 08 May 2025 00:00:00 +000017.5 https://www.postgresql.org/docs/17/release-17-5.html17.5 is the latest release in the 17 series. @@ -103,12 +110,18 @@ This version is unsupported! """ -class TestPgmonMethods(unittest.TestCase): +class TestPgmonMethods(unittest.TestCase): # pylint: disable=too-many-public-methods + """ + Unit test class for pgmon + """ + ## # update_deep ## def test_update_deep__empty_cases(self): - # Test empty dict cases + """ + Test various empty dictionary permutations + """ d1 = {} d2 = {} pgmon.update_deep(d1, d2) @@ -128,7 +141,9 @@ class TestPgmonMethods(unittest.TestCase): self.assertEqual(d2, d1) def test_update_deep__scalars(self): - # Test adding/updating scalar values + """ + Test adding/updating scalar values + """ d1 = {"foo": 1, "bar": "text", "hello": "world"} d2 = {"foo": 2, "baz": "blah"} pgmon.update_deep(d1, d2) @@ -136,7 +151,9 @@ class TestPgmonMethods(unittest.TestCase): self.assertEqual(d2, {"foo": 2, "baz": "blah"}) def test_update_deep__lists(self): - # Test adding to lists + """ + Test adding to lists + """ d1 = {"lst1": []} d2 = {"lst1": [1, 2]} pgmon.update_deep(d1, d2) @@ -172,7 +189,9 @@ class TestPgmonMethods(unittest.TestCase): self.assertEqual(d2, {"obj1": {"l1": [3, 4]}}) def test_update_deep__dicts(self): - # Test adding to lists + """ + Test adding to dictionaries + """ d1 = {"obj1": {}} d2 = {"obj1": {"a": 1, "b": 2}} pgmon.update_deep(d1, d2) @@ -199,7 +218,9 @@ class TestPgmonMethods(unittest.TestCase): self.assertEqual(d2, {"obj1": {"d1": {"a": 5, "c": 12}}}) def test_update_deep__types(self): - # Test mismatched types + """ + Test mismatched types + """ d1 = {"foo": 5} d2 = None self.assertRaises(TypeError, pgmon.update_deep, d1, d2) @@ -218,15 +239,19 @@ class TestPgmonMethods(unittest.TestCase): ## def test_get_pool__simple(self): - # Just get a pool in a normal case - pgmon.config.update(pgmon.default_config) + """ + Test getting a pool in a normal case + """ + pgmon.Context.config.update(pgmon.DEFAULT_CONFIG) pool = pgmon.get_pool("postgres") self.assertIsNotNone(pool) def test_get_pool__unhappy(self): - # Test getting an unhappy database pool - pgmon.config.update(pgmon.default_config) - pgmon.unhappy_cooldown["postgres"] = datetime.now() + timedelta(60) + """ + Test getting an unhappy database pool + """ + pgmon.Context.config.update(pgmon.DEFAULT_CONFIG) + pgmon.Context.unhappy_cooldown["postgres"] = datetime.now() + timedelta(60) self.assertRaises(pgmon.UnhappyDBError, pgmon.get_pool, "postgres") # Test getting a different database when there's an unhappy one @@ -238,30 +263,37 @@ class TestPgmonMethods(unittest.TestCase): ## def test_handle_connect_failure__simple(self): - # Test adding to an empty unhappy list - pgmon.config.update(pgmon.default_config) - pgmon.unhappy_cooldown = {} + """ + Test adding to an empty unhappy list + """ + pgmon.Context.config.update(pgmon.DEFAULT_CONFIG) + pgmon.Context.unhappy_cooldown = {} pool = pgmon.get_pool("postgres") pgmon.handle_connect_failure(pool) - self.assertGreater(pgmon.unhappy_cooldown["postgres"], datetime.now()) + self.assertGreater(pgmon.Context.unhappy_cooldown["postgres"], datetime.now()) # Test adding another database pool = pgmon.get_pool("template0") pgmon.handle_connect_failure(pool) - self.assertGreater(pgmon.unhappy_cooldown["postgres"], datetime.now()) - self.assertGreater(pgmon.unhappy_cooldown["template0"], datetime.now()) - self.assertEqual(len(pgmon.unhappy_cooldown), 2) + self.assertGreater(pgmon.Context.unhappy_cooldown["postgres"], datetime.now()) + self.assertGreater(pgmon.Context.unhappy_cooldown["template0"], datetime.now()) + self.assertEqual(len(pgmon.Context.unhappy_cooldown), 2) ## # get_query ## def test_get_query__basic(self): - # Test getting a query with one version + """ + Test getting a query with just a default version. + """ metric = {"type": "value", "query": {0: "DEFAULT"}} self.assertEqual(pgmon.get_query(metric, 100000), "DEFAULT") def test_get_query__versions(self): + """ + Test getting queries when multiple versions are present. + """ metric = {"type": "value", "query": {0: "DEFAULT", 110000: "NEW"}} # Test getting the default version of a query with no lower bound and a newer @@ -281,6 +313,9 @@ class TestPgmonMethods(unittest.TestCase): self.assertEqual(pgmon.get_query(metric, 100000), "OLD") def test_get_query__missing_version(self): + """ + Test trying to get a query that is not defined for the requested version. + """ metric = {"type": "value", "query": {96000: "OLD", 110000: "NEW", 150000: ""}} # Test getting a metric that only exists for newer versions @@ -294,11 +329,16 @@ class TestPgmonMethods(unittest.TestCase): ## def test_read_config__simple(self): - pgmon.config = {} + """ + Test reading a simple config. + """ + pgmon.Context.config = {} # Test reading just a metric and using the defaults for everything else with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- # This is a comment! @@ -310,18 +350,20 @@ metrics: """ ) - pgmon.read_config(f"{tmpdirname}/config.yml") + pgmon.read_config(os.path.join(tmpdirname, "config.yml")) self.assertEqual( - pgmon.config["max_pool_size"], pgmon.default_config["max_pool_size"] + pgmon.Context.config["max_pool_size"], pgmon.DEFAULT_CONFIG["max_pool_size"] ) - self.assertEqual(pgmon.config["dbuser"], pgmon.default_config["dbuser"]) + self.assertEqual(pgmon.Context.config["dbuser"], pgmon.DEFAULT_CONFIG["dbuser"]) - pgmon.config = {} + pgmon.Context.config = {} # Test reading a basic config with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- # This is a comment! @@ -357,22 +399,27 @@ metrics: """ ) - pgmon.read_config(f"{tmpdirname}/config.yml") + pgmon.read_config(os.path.join(tmpdirname, "config.yml")) - self.assertEqual(pgmon.config["dbuser"], "someone") - self.assertEqual(pgmon.config["metrics"]["test1"]["type"], "value") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") - self.assertEqual(pgmon.config["metrics"]["test2"]["query"][0], "TEST2") + self.assertEqual(pgmon.Context.config["dbuser"], "someone") + self.assertEqual(pgmon.Context.config["metrics"]["test1"]["type"], "value") + self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1") + self.assertEqual(pgmon.Context.config["metrics"]["test2"]["query"][0], "TEST2") def test_read_config__include(self): - pgmon.config = {} + """ + Test including one config from another. + """ + pgmon.Context.config = {} # Test reading a config that includes other files (absolute and relative paths, # multiple levels) with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( - f"""--- + """--- # This is a comment! min_pool_size: 1 max_pool_size: 2 @@ -384,13 +431,17 @@ reconnect_cooldown: 15 version_check_period: 3600 include: - dbsettings.yml - - {tmpdirname}/metrics.yml -""" + - {}/metrics.yml +""".format( + tmpdirname + ) ) - with open(f"{tmpdirname}/dbsettings.yml", "w") as f: + with open( + os.path.join(tmpdirname, "dbsettings.yml"), "w", encoding="utf-8" + ) as f: f.write( - f"""--- + """--- dbuser: someone dbhost: localhost dbport: 5555 @@ -398,9 +449,11 @@ dbname: template0 """ ) - with open(f"{tmpdirname}/metrics.yml", "w") as f: + with open( + os.path.join(tmpdirname, "metrics.yml"), "w", encoding="utf-8" + ) as f: f.write( - f"""--- + """--- metrics: test1: type: value @@ -415,9 +468,11 @@ include: """ ) - with open(f"{tmpdirname}/more_metrics.yml", "w") as f: + with open( + os.path.join(tmpdirname, "more_metrics.yml"), "w", encoding="utf-8" + ) as f: f.write( - f"""--- + """--- metrics: test3: type: value @@ -425,20 +480,25 @@ metrics: 0: TEST3 """ ) - pgmon.read_config(f"{tmpdirname}/config.yml") + pgmon.read_config(os.path.join(tmpdirname, "config.yml")) - self.assertEqual(pgmon.config["max_idle_time"], 10) - self.assertEqual(pgmon.config["dbuser"], "someone") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") - self.assertEqual(pgmon.config["metrics"]["test2"]["query"][0], "TEST2") - self.assertEqual(pgmon.config["metrics"]["test3"]["query"][0], "TEST3") + self.assertEqual(pgmon.Context.config["max_idle_time"], 10) + self.assertEqual(pgmon.Context.config["dbuser"], "someone") + self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1") + self.assertEqual(pgmon.Context.config["metrics"]["test2"]["query"][0], "TEST2") + self.assertEqual(pgmon.Context.config["metrics"]["test3"]["query"][0], "TEST3") def test_read_config__reload(self): - pgmon.config = {} + """ + Test reloading a config. + """ + pgmon.Context.config = {} # Test rereading a config to update an existing config with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- # This is a comment! @@ -466,12 +526,14 @@ metrics: """ ) - pgmon.read_config(f"{tmpdirname}/config.yml") + pgmon.read_config(os.path.join(tmpdirname, "config.yml")) # Just make sure the first config was read - self.assertEqual(len(pgmon.config["metrics"]), 2) + self.assertEqual(len(pgmon.Context.config["metrics"]), 2) - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- # This is a comment! @@ -484,18 +546,23 @@ metrics: """ ) - pgmon.read_config(f"{tmpdirname}/config.yml") + pgmon.read_config(os.path.join(tmpdirname, "config.yml")) - self.assertEqual(pgmon.config["min_pool_size"], 7) - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "NEW1") - self.assertEqual(len(pgmon.config["metrics"]), 1) + self.assertEqual(pgmon.Context.config["min_pool_size"], 7) + self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "NEW1") + self.assertEqual(len(pgmon.Context.config["metrics"]), 1) def test_read_config__query_file(self): - pgmon.config = {} + """ + Test reading a query definition from a separate file + """ + pgmon.Context.config = {} # Read a config file that reads a query from a file with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- metrics: @@ -506,22 +573,30 @@ metrics: """ ) - with open(f"{tmpdirname}/some_query.sql", "w") as f: + with open( + os.path.join(tmpdirname, "some_query.sql"), "w", encoding="utf-8" + ) as f: f.write("This is a query") - pgmon.read_config(f"{tmpdirname}/config.yml") + pgmon.read_config(os.path.join(tmpdirname, "config.yml")) self.assertEqual( - pgmon.config["metrics"]["test1"]["query"][0], "This is a query" + pgmon.Context.config["metrics"]["test1"]["query"][0], "This is a query" ) - def test_read_config__invalid(self): - pgmon.config = {} + def init_invalid_config_test(self): + """ + Initialize an invalid config read test. Basically just set up a simple valid config in + order to confirm that an invalid read does not modify the live config. + """ + pgmon.Context.config = {} # For all of these tests, we start with a valid config and also ensure that # it is not modified when a new config read fails with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- metrics: @@ -532,20 +607,48 @@ metrics: """ ) - pgmon.read_config(f"{tmpdirname}/config.yml") + pgmon.read_config(os.path.join(tmpdirname, "config.yml")) # Just make sure the config was read - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1") + + def verify_invalid_config_test(self): + """ + Verify that an invalid read did not modify the live config. + """ + self.assertEqual(pgmon.Context.config["dbuser"], "postgres") + self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1") + + def test_read_config__missing(self): + """ + Test reading a nonexistant config file. + """ + # Set up the test + self.init_invalid_config_test() # Test reading a nonexistant config file with tempfile.TemporaryDirectory() as tmpdirname: self.assertRaises( - FileNotFoundError, pgmon.read_config, f"{tmpdirname}/missing.yml" + FileNotFoundError, + pgmon.read_config, + os.path.join(tmpdirname, "missing.yml"), ) + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__invalid(self): + """ + Test reading an invalid config file. + """ + # Set up the test + self.init_invalid_config_test() + # Test reading an invalid config file with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """[default] This looks a lot like an ini file to me @@ -554,12 +657,26 @@ Or maybe a TOML? """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__invalid_include(self): + """ + Test reading an invalid config file. + """ + # Set up the test + self.init_invalid_config_test() + # Test reading a config that includes an invalid file with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -573,14 +690,26 @@ include: """ ) self.assertRaises( - FileNotFoundError, pgmon.read_config, f"{tmpdirname}/config.yml" + FileNotFoundError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__invalid_log_level(self): + """ + Test reading an invalid log level from a config file. + """ + # Set up the test + self.init_invalid_config_test() # Test invalid log level with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- log_level: noisy @@ -593,14 +722,26 @@ metrics: """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__invalid_type(self): + """ + Test reading an invalid query result type form a config file. + """ + # Set up the test + self.init_invalid_config_test() # Test invalid query return type with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -612,32 +753,57 @@ metrics: """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__invalid_query_dict(self): + """ + Test reading an invalid query definition structure type form a config file. In other words + what's supposed to be a dictionary of the form version => query, we give it something else. + """ + # Set up the test + self.init_invalid_config_test() # Test invalid query dict type with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil metrics: test1: - type: lots_of_data + type: row query: EVIL1 """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__missing_type(self): + """ + Test reading a metric with a missing result type from a config file. + """ + # Set up the test + self.init_invalid_config_test() # Test incomplete metric: missing type with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -648,14 +814,26 @@ metrics: """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__missing_queries(self): + """ + Test reading a metric with no queries from a config file. + """ + # Set up the test + self.init_invalid_config_test() # Test incomplete metric: missing queries with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -665,14 +843,26 @@ metrics: """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__empty_query_dict(self): + """ + Test reading a fetric with an empty query dict from a config file. + """ + # Set up the test + self.init_invalid_config_test() # Test incomplete metric: empty queries with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -683,14 +873,26 @@ metrics: """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__none_query_dict(self): + """ + Test reading a metric where the query dict is None from a config file. + """ + # Set up the test + self.init_invalid_config_test() # Test incomplete metric: query dict is None with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -701,28 +903,53 @@ metrics: """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__missing_metrics(self): + """ + Test reading a config file with no metrics. + """ + # Set up the test + self.init_invalid_config_test() # Test reading a config with no metrics with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__missing_query_file(self): + """ + Test reading a metric from a config file where the query definition cones from a missing + file. + """ + # Set up the test + self.init_invalid_config_test() # Test reading a query defined in a file but the file is missing with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -734,14 +961,26 @@ metrics: """ ) self.assertRaises( - FileNotFoundError, pgmon.read_config, f"{tmpdirname}/config.yml" + FileNotFoundError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + def test_read_config__invalid_version(self): + """ + Test reading a metric with an invalid PostgreSQL version from a config file. + """ + # Set up the test + self.init_invalid_config_test() # Test invalid query versions with tempfile.TemporaryDirectory() as tmpdirname: - with open(f"{tmpdirname}/config.yml", "w") as f: + with open( + os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8" + ) as f: f.write( """--- dbuser: evil @@ -753,47 +992,84 @@ metrics: """ ) self.assertRaises( - pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml" + pgmon.ConfigError, + pgmon.read_config, + os.path.join(tmpdirname, "config.yml"), ) - self.assertEqual(pgmon.config["dbuser"], "postgres") - self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1") + + # Confirm nothing changed + self.verify_invalid_config_test() + + ## + # version_num + ## def test_version_num_to_release__valid(self): + """ + Test converting PostgreSQL versions before and after 10 when the numbering scheme changed. + """ self.assertEqual(pgmon.version_num_to_release(90602), 9.6) self.assertEqual(pgmon.version_num_to_release(130002), 13) - def test_parse_version_rss__simple(self): - pgmon.parse_version_rss(versions_rss, 13) - self.assertEqual(pgmon.latest_version, 130021) - self.assertTrue(pgmon.release_supported) + ## + # parse_version_rss + ## - pgmon.parse_version_rss(versions_rss, 9.6) - self.assertEqual(pgmon.latest_version, 90624) - self.assertFalse(pgmon.release_supported) + def test_parse_version_rss__supported(self): + """ + Test parsing a supported version from the RSS feed + """ + pgmon.parse_version_rss(VERSIONS_RSS, 13) + self.assertEqual(pgmon.Context.latest_version, 130021) + self.assertTrue(pgmon.Context.release_supported) + + def test_parse_version_rss__unsupported(self): + """ + Test parsing an unsupported version from the RSS feed + """ + pgmon.parse_version_rss(VERSIONS_RSS, 9.6) + self.assertEqual(pgmon.Context.latest_version, 90624) + self.assertFalse(pgmon.Context.release_supported) def test_parse_version_rss__missing(self): - # Test asking about versions that don't exist + """ + Test asking about versions that don't exist in the RSS feed + """ self.assertRaises( - pgmon.LatestVersionCheckError, pgmon.parse_version_rss, versions_rss, 9.7 + pgmon.LatestVersionCheckError, pgmon.parse_version_rss, VERSIONS_RSS, 9.7 ) self.assertRaises( - pgmon.LatestVersionCheckError, pgmon.parse_version_rss, versions_rss, 99 + pgmon.LatestVersionCheckError, pgmon.parse_version_rss, VERSIONS_RSS, 99 ) + ## + # get_latest_version + ## + def test_get_latest_version(self): + """ + Test getting the latest version from the actual RSS feed + """ # Define a cluster version here so the test doesn't need a database - pgmon.cluster_version_next_check = datetime.now() + timedelta(hours=1) - pgmon.cluster_version = 90623 + pgmon.Context.cluster_version_next_check = datetime.now() + timedelta(hours=1) + pgmon.Context.cluster_version = 90623 # Set up a default config - pgmon.update_deep(pgmon.config, pgmon.default_config) + pgmon.update_deep(pgmon.Context.config, pgmon.DEFAULT_CONFIG) # Make sure we can pull the RSS file (we assume the 9.6 series won't be getting # any more updates) self.assertEqual(pgmon.get_latest_version(), 90624) + ## + # json_encode_special + ## + def test_json_encode_special(self): + """ + Test encoding Decimal types as JSON + """ # Confirm that we're getting the right type self.assertFalse(isinstance(Decimal("0.5"), float)) self.assertTrue(isinstance(pgmon.json_encode_special(Decimal("0.5")), float))