Refactor to make pylint happy

James Campbell 2025-09-23 01:12:49 -04:00
parent 29bfd07dad
commit 43cd162313
Signed by: james
GPG Key ID: 2287C33A40DC906A
4 changed files with 767 additions and 408 deletions

Makefile

@ -38,7 +38,7 @@ SUPPORTED := ubuntu-20.04 \
# These targets are the main ones to use for most things.
##
.PHONY: all clean tgz test query-tests install-common install-openrc install-systemd
.PHONY: all clean tgz lint format test query-tests install-common install-openrc install-systemd
all: package-all
@ -80,6 +80,18 @@ tgz:
clean:
rm -rf $(BUILD_DIR)
# Check for lint
lint:
pylint src/pgmon.py
pylint src/test_pgmon.py
black --check --diff src/pgmon.py
black --check --diff src/test_pgmon.py
# Format the code using black
format:
black src/pgmon.py
black src/test_pgmon.py
# Run unit tests for the script
test:
cd src ; python3 -m unittest

pylintrc (new file)

@ -0,0 +1,4 @@
[MASTER]
py-version=3.5
disable=fixme
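For context, py-version=3.5 sets the minimum Python version pylint should assume for its version-dependent checks, and disable=fixme turns off the fixme checker (W0511) that would otherwise flag TODO/FIXME comments. A minimal sketch of what gets suppressed:

# Without disable=fixme, pylint reports W0511 ("fixme") on the comment below.
def sample():
    # TODO: tighten this up later (kept quiet by disable=fixme)
    return None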

src/pgmon.py

@ -1,105 +1,141 @@
#!/usr/bin/env python3
"""
pgmon is a monitoring intermediary that sits between a PostgreSQL cluster and a monitoring system
that is capable of parsing JSON responses over an HTTP connection.
"""
# pylint: disable=too-few-public-methods
import yaml
import json
import time
import os
import sys
import signal
import argparse
import logging
import re
from decimal import Decimal
from urllib.parse import urlparse, parse_qs
from contextlib import contextmanager
from datetime import datetime, timedelta
from http.server import BaseHTTPRequestHandler
from http.server import ThreadingHTTPServer
from threading import Lock
import yaml
import psycopg2
from psycopg2.extras import RealDictCursor
from psycopg2.pool import ThreadedConnectionPool
from contextlib import contextmanager
import signal
from threading import Thread, Lock, Semaphore
from http.server import BaseHTTPRequestHandler, HTTPServer
from http.server import ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs
import requests
import re
from decimal import Decimal
VERSION = "1.0.4"
# Configuration
config = {}
# Dictionary of current PostgreSQL connection pools
connections_lock = Lock()
connections = {}
class Context:
"""
The global context for connections, config, version, and IPC
"""
# Dictionary of unhappy databases. Keys are database names, value is the time
# the database was determined to be unhappy plus the cooldown setting. So,
# basically it's the time when we should try to connect to the database again.
unhappy_cooldown = {}
# Configuration
config = {}
# Version information
cluster_version = None
cluster_version_next_check = None
cluster_version_lock = Lock()
# Dictionary of current PostgreSQL connection pools
connections_lock = Lock()
connections = {}
# PostgreSQL latest version information
latest_version = None
latest_version_next_check = None
latest_version_lock = Lock()
release_supported = None
# Dictionary of unhappy databases. Keys are database names, value is the time
# the database was determined to be unhappy plus the cooldown setting. So,
# basically it's the time when we should try to connect to the database again.
unhappy_cooldown = {}
# Running state (used to gracefully shut down)
running = True
# Version information
cluster_version = None
cluster_version_next_check = None
cluster_version_lock = Lock()
# The http server object
httpd = None
# PostgreSQL latest version information
latest_version = None
latest_version_next_check = None
latest_version_lock = Lock()
release_supported = None
# Where the config file lives
config_file = None
# Running state (used to gracefully shut down)
running = True
# Configure logging
log = logging.getLogger(__name__)
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(filename)s: %(funcName)s() line %(lineno)d: %(message)s"
)
console_log_handler = logging.StreamHandler()
console_log_handler.setFormatter(formatter)
log.addHandler(console_log_handler)
# The http server object
httpd = None
# Where the config file lives
config_file = None
# Configure logging
log = logging.getLogger(__name__)
@classmethod
def init_logging(cls):
"""
Actually initialize the logging framework. Since we don't ever instantiate the Context
class, this provides a way to make a few modifications to the log handler.
"""
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(filename)s: "
"%(funcName)s() line %(lineno)d: %(message)s"
)
console_log_handler = logging.StreamHandler()
console_log_handler.setFormatter(formatter)
cls.log.addHandler(console_log_handler)
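After this refactor the former module-level globals live on Context, so callers go through the class. A short sketch mirroring main() further down (the config path is hypothetical):

Context.init_logging()
Context.log.info("pgmon %s starting", VERSION)
Context.config_file = "/etc/pgmon/pgmon.yml"  # hypothetical path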
# Error types
class ConfigError(Exception):
pass
"""
Error type for all config related errors.
"""
class DisconnectedError(Exception):
pass
"""
Error indicating a previously active connection to the database has been disconnected.
"""
class UnhappyDBError(Exception):
pass
"""
Error indicating that a database the code has been asked to connect to is on the unhappy list.
"""
class UnknownMetricError(Exception):
pass
"""
Error indicating that an undefined metric was requested.
"""
class MetricVersionError(Exception):
pass
"""
Error indicating that there is no suitable query for a metric that was requested for the
version of PostgreSQL being monitored.
"""
class LatestVersionCheckError(Exception):
pass
"""
Error indicating that there was a problem retrieving or parsing the latest version information.
"""
# Default config settings
default_config = {
DEFAULT_CONFIG = {
# The address the agent binds to
"address": "127.0.0.1",
# The port the agent listens on for requests
@ -177,6 +213,60 @@ def update_deep(d1, d2):
return d1
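To make update_deep's merge semantics concrete (these are the behaviors the unit tests in src/test_pgmon.py pin down): scalars from d2 overwrite d1, lists are appended, and nested dicts merge recursively. A small sketch:

base = {"dbport": 5432, "include": ["a.yml"], "metrics": {"m1": {"type": "value"}}}
extra = {"dbport": 5555, "include": ["b.yml"], "metrics": {"m2": {"type": "row"}}}
update_deep(base, extra)
# base == {"dbport": 5555,
#          "include": ["a.yml", "b.yml"],
#          "metrics": {"m1": {"type": "value"}, "m2": {"type": "row"}}}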
def validate_metric(path, name, metric):
"""
Validate a metric definition from a given file. If any query definitions come from external
files, the metric dict will be updated with the actual query.
Params:
path: path to the file which contains this definition
name: name of the metric
metric: the dictionary containing the metric definition
"""
# Validate return types
try:
if metric["type"] not in ["value", "row", "column", "set"]:
raise ConfigError(
"Invalid return type: {} for metric {} in {}".format(
metric["type"], name, path
)
)
except KeyError as e:
raise ConfigError(
"No type specified for metric {} in {}".format(name, path)
) from e
# Ensure queries exist
query_dict = metric.get("query", {})
if not isinstance(query_dict, dict):
raise ConfigError(
"Query definition should be a dictionary, got: {} for metric {} in {}".format(
query_dict, name, path
)
)
if len(query_dict) == 0:
raise ConfigError("Missing queries for metric {} in {}".format(name, path))
# Read external sql files and validate version keys
config_base = os.path.dirname(path)
for vers, query in metric["query"].items():
try:
int(vers)
except Exception as e:
raise ConfigError(
"Invalid version: {} for metric {} in {}".format(vers, name, path)
) from e
# Read in the external query and update the definition in the metric dictionary
if query.startswith("file:"):
query_path = query[5:]
if not query_path.startswith("/"):
query_path = os.path.join(config_base, query_path)
with open(query_path, "r", encoding="utf-8") as f:
metric["query"][vers] = f.read()
def read_config(path, included=False):
"""
Read a config file.
@ -185,61 +275,21 @@ def read_config(path, included=False):
path: path to the file to read
included: is this file included by another file?
"""
# Read config file
log.info("Reading log file: {}".format(path))
with open(path, "r") as f:
Context.log.info("Reading log file: %s", path)
with open(path, "r", encoding="utf-8") as f:
try:
cfg = yaml.safe_load(f)
except yaml.parser.ParserError as e:
raise ConfigError("Inavlid config file: {}: {}".format(path, e))
# Since we use it a few places, get the base directory from the config
config_base = os.path.dirname(path)
raise ConfigError("Inavlid config file: {}: {}".format(path, e)) from e
# Read any external queries and validate metric definitions
for name, metric in cfg.get("metrics", {}).items():
# Validate return types
try:
if metric["type"] not in ["value", "row", "column", "set"]:
raise ConfigError(
"Invalid return type: {} for metric {} in {}".format(
metric["type"], name, path
)
)
except KeyError:
raise ConfigError(
"No type specified for metric {} in {}".format(name, path)
)
# Ensure queries exist
query_dict = metric.get("query", {})
if type(query_dict) is not dict:
raise ConfigError(
"Query definition should be a dictionary, got: {} for metric {} in {}".format(
query_dict, name, path
)
)
if len(query_dict) == 0:
raise ConfigError("Missing queries for metric {} in {}".format(name, path))
# Read external sql files and validate version keys
for vers, query in metric["query"].items():
try:
int(vers)
except:
raise ConfigError(
"Invalid version: {} for metric {} in {}".format(vers, name, path)
)
if query.startswith("file:"):
query_path = query[5:]
if not query_path.startswith("/"):
query_path = os.path.join(config_base, query_path)
with open(query_path, "r") as f:
metric["query"][vers] = f.read()
validate_metric(path, name, metric)
# Read any included config files
config_base = os.path.dirname(path)
for inc in cfg.get("include", []):
# Prefix relative paths with the directory from the current config
if not inc.startswith("/"):
@ -250,34 +300,37 @@ def read_config(path, included=False):
# config
if included:
return cfg
else:
new_config = {}
update_deep(new_config, default_config)
update_deep(new_config, cfg)
# Minor sanity checks
if len(new_config["metrics"]) == 0:
log.error("No metrics are defined")
raise ConfigError("No metrics defined")
new_config = {}
update_deep(new_config, DEFAULT_CONFIG)
update_deep(new_config, cfg)
# Validate the new log level before changing the config
if new_config["log_level"].upper() not in [
"DEBUG",
"INFO",
"WARNING",
"ERROR",
"CRITICAL",
]:
raise ConfigError("Invalid log level: {}".format(new_config["log_level"]))
# Minor sanity checks
if len(new_config["metrics"]) == 0:
Context.log.error("No metrics are defined")
raise ConfigError("No metrics defined")
global config
config = new_config
# Validate the new log level before changing the config
if new_config["log_level"].upper() not in [
"DEBUG",
"INFO",
"WARNING",
"ERROR",
"CRITICAL",
]:
raise ConfigError("Invalid log level: {}".format(new_config["log_level"]))
# Apply changes to log level
log.setLevel(logging.getLevelName(config["log_level"].upper()))
Context.config = new_config
# Apply changes to log level
Context.log.setLevel(logging.getLevelName(Context.config["log_level"].upper()))
# Return the config (mostly to make pylint happy, but also in case I opt to remove the side
# effect and make this more functional.)
return Context.config
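Minimal top-level usage, mirroring main() below (path hypothetical):

Context.config_file = "/etc/pgmon/pgmon.yml"
cfg = read_config(Context.config_file)
assert cfg is Context.config  # the return value is the merged live config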
def signal_handler(sig, frame):
def signal_handler(sig, frame): # pylint: disable=unused-argument
"""
Function for handling signals
@ -288,19 +341,22 @@ def signal_handler(sig, frame):
# Signal everything to shut down
if sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
log.info("Shutting down ...")
global running
running = False
if httpd is not None:
httpd.socket.close()
Context.log.info("Shutting down ...")
Context.running = False
if Context.httpd is not None:
Context.httpd.socket.close()
# Signal a reload
if sig == signal.SIGHUP:
log.warning("Received config reload signal")
read_config(config_file)
Context.log.warning("Received config reload signal")
read_config(Context.config_file)
class ConnectionPool(ThreadedConnectionPool):
"""
Threaded connection pool that has a context manager.
"""
def __init__(self, dbname, minconn, maxconn, *args, **kwargs):
# Make sure dbname isn't different in the kwargs
kwargs["dbname"] = dbname
@ -309,7 +365,14 @@ class ConnectionPool(ThreadedConnectionPool):
self.name = dbname
@contextmanager
def connection(self, timeout=None):
def connection(self, timeout):
"""
Connection context manager for our connection pool. This will attempt to retrieve a
connection until the timeout is reached.
Params:
timeout: how long to keep trying to get a connection before giving up
"""
conn = None
timeout_time = datetime.now() + timedelta(seconds=timeout)
# We will continue to try to get a connection slot until we time out
@ -333,34 +396,37 @@ class ConnectionPool(ThreadedConnectionPool):
def get_pool(dbname):
"""
Get a database connection pool.
Params:
dbname: the name of the database for which a connection pool should be returned.
"""
# Check if the db is unhappy and wants to be left alone
if dbname in unhappy_cooldown:
if unhappy_cooldown[dbname] > datetime.now():
if dbname in Context.unhappy_cooldown:
if Context.unhappy_cooldown[dbname] > datetime.now():
raise UnhappyDBError()
# Create a connection pool if it doesn't already exist
if dbname not in connections:
with connections_lock:
if dbname not in Context.connections:
with Context.connections_lock:
# Make sure nobody created the pool while we were waiting on the
# lock
if dbname not in connections:
log.info("Creating connection pool for: {}".format(dbname))
if dbname not in Context.connections:
Context.log.info("Creating connection pool for: %s", dbname)
# Actually create the connection pool
connections[dbname] = ConnectionPool(
Context.connections[dbname] = ConnectionPool(
dbname,
int(config["min_pool_size"]),
int(config["max_pool_size"]),
int(Context.config["min_pool_size"]),
int(Context.config["max_pool_size"]),
application_name="pgmon",
host=config["dbhost"],
port=config["dbport"],
user=config["dbuser"],
connect_timeout=int(config["connect_timeout"]),
sslmode=config["ssl_mode"],
host=Context.config["dbhost"],
port=Context.config["dbport"],
user=Context.config["dbuser"],
connect_timeout=int(Context.config["connect_timeout"]),
sslmode=Context.config["ssl_mode"],
)
# Clear the unhappy indicator if present
unhappy_cooldown.pop(dbname, None)
return connections[dbname]
Context.unhappy_cooldown.pop(dbname, None)
return Context.connections[dbname]
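Callers pair get_pool() with the pool's connection() context manager; this is the same pattern run_query_no_retry() uses below:

pool = get_pool(Context.config["dbname"])
with pool.connection(float(Context.config["connect_timeout"])) as conn:
    with conn.cursor(cursor_factory=RealDictCursor) as curs:
        curs.execute("SELECT 1")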
def handle_connect_failure(pool):
@ -368,8 +434,8 @@ def handle_connect_failure(pool):
Mark the database as being unhappy so we can leave it alone for a while
"""
dbname = pool.name
unhappy_cooldown[dbname] = datetime.now() + timedelta(
seconds=int(config["reconnect_cooldown"])
Context.unhappy_cooldown[dbname] = datetime.now() + timedelta(
seconds=int(Context.config["reconnect_cooldown"])
)
@ -400,43 +466,48 @@ def json_encode_special(obj):
"""
if isinstance(obj, Decimal):
return float(obj)
raise TypeError(f"Cannot serialize object of {type(obj)}")
raise TypeError("Cannot serialize object of {}".format(type(obj)))
def run_query_no_retry(pool, return_type, query, args):
"""
Run the query with no explicit retry code
"""
with pool.connection(float(config["connect_timeout"])) as conn:
with pool.connection(float(Context.config["connect_timeout"])) as conn:
try:
with conn.cursor(cursor_factory=RealDictCursor) as curs:
output = None
curs.execute(query, args)
res = curs.fetchall()
if return_type == "value":
if len(res) == 0:
return ""
return str(list(res[0].values())[0])
output = ""
output = str(list(res[0].values())[0])
elif return_type == "row":
if len(res) == 0:
return "[]"
return json.dumps(res[0], default=json_encode_special)
# if len(res) == 0:
# return "[]"
output = json.dumps(res[0], default=json_encode_special)
elif return_type == "column":
if len(res) == 0:
return "[]"
return json.dumps(
# if len(res) == 0:
# return "[]"
output = json.dumps(
[list(r.values())[0] for r in res], default=json_encode_special
)
elif return_type == "set":
return json.dumps(res, default=json_encode_special)
except:
output = json.dumps(res, default=json_encode_special)
else:
raise ConfigError(
"Invalid query return type: {}".format(return_type)
)
return output
except Exception as e:
dbname = pool.name
if dbname in unhappy_cooldown:
raise UnhappyDBError()
elif conn.closed != 0:
raise DisconnectedError()
else:
raise
if dbname in Context.unhappy_cooldown:
raise UnhappyDBError() from e
if conn.closed != 0:
raise DisconnectedError() from e
raise
def run_query(pool, return_type, query, args):
@ -457,7 +528,7 @@ def run_query(pool, return_type, query, args):
try:
return run_query_no_retry(pool, return_type, query, args)
except DisconnectedError:
log.warning("Stale PostgreSQL connection found ... trying again")
Context.log.warning("Stale PostgreSQL connection found ... trying again")
# This sleep is an annoying hack to give the pool workers time to
# actually mark the connection, otherwise it can be given back in the
# next connection() call
@ -465,9 +536,9 @@ def run_query(pool, return_type, query, args):
time.sleep(1)
try:
return run_query_no_retry(pool, return_type, query, args)
except:
except Exception as e:
handle_connect_failure(pool)
raise UnhappyDBError()
raise UnhappyDBError() from e
def get_cluster_version():
@ -475,40 +546,39 @@ def get_cluster_version():
Get the PostgreSQL version if we don't already know it, or if it's been
too long since the last time it was checked.
"""
global cluster_version
global cluster_version_next_check
# If we don't know the version or it's past the recheck time, get the
# version from the database. Only one thread needs to do this, so they all
# try to grab the lock, and then make sure nobody else beat them to it.
if (
cluster_version is None
or cluster_version_next_check is None
or cluster_version_next_check < datetime.now()
Context.cluster_version is None
or Context.cluster_version_next_check is None
or Context.cluster_version_next_check < datetime.now()
):
with cluster_version_lock:
with Context.cluster_version_lock:
# Only check if nobody already got the version before us
if (
cluster_version is None
or cluster_version_next_check is None
or cluster_version_next_check < datetime.now()
Context.cluster_version is None
or Context.cluster_version_next_check is None
or Context.cluster_version_next_check < datetime.now()
):
log.info("Checking PostgreSQL cluster version")
pool = get_pool(config["dbname"])
cluster_version = int(
Context.log.info("Checking PostgreSQL cluster version")
pool = get_pool(Context.config["dbname"])
Context.cluster_version = int(
run_query(pool, "value", "SHOW server_version_num", None)
)
cluster_version_next_check = datetime.now() + timedelta(
seconds=int(config["version_check_period"])
Context.cluster_version_next_check = datetime.now() + timedelta(
seconds=int(Context.config["version_check_period"])
)
log.info("Got PostgreSQL cluster version: {}".format(cluster_version))
log.debug(
"Next PostgreSQL cluster version check will be after: {}".format(
cluster_version_next_check
)
Context.log.info(
"Got PostgreSQL cluster version: %s", Context.cluster_version
)
Context.log.debug(
"Next PostgreSQL cluster version check will be after: %s",
Context.cluster_version_next_check,
)
return cluster_version
return Context.cluster_version
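Distilled, this check and get_latest_version() below share the same double-checked locking shape (names here are illustrative only):

if stale():              # cheap, unlocked fast path
    with lock:
        if stale():      # re-check: another thread may have refreshed first
            refresh()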
def version_num_to_release(version_num):
@ -521,8 +591,7 @@ def version_num_to_release(version_num):
"""
if version_num // 10000 < 10:
return version_num // 10000 + (version_num % 10000 // 100 / 10)
else:
return version_num // 10000
return version_num // 10000
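Worked examples of the mapping, matching the unit tests: releases before 10 encode major.minor, later ones just the major number.

version_num_to_release(90602)   # -> 9.6  (PostgreSQL 9.6.2)
version_num_to_release(130002)  # -> 13   (PostgreSQL 13.2)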
def parse_version_rss(raw_rss, release):
@ -530,7 +599,7 @@ def parse_version_rss(raw_rss, release):
Parse the raw RSS from the versions.rss feed to extract the latest version of
PostgreSQL that's available for the cluster being monitored.
This sets these global variables:
This sets these Context variables:
latest_version
release_supported
@ -540,8 +609,6 @@ def parse_version_rss(raw_rss, release):
raw_rss: The raw rss text from versions.rss
release: The PostgreSQL release we care about (ex: 9.2, 14)
"""
global latest_version
global release_supported
# Regular expressions for parsing the RSS document
version_line = re.compile(
@ -561,75 +628,75 @@ def parse_version_rss(raw_rss, release):
version = m.group(1)
parts = list(map(int, version.split(".")))
if parts[0] < 10:
latest_version = int(
Context.latest_version = int(
"{}{:02}{:02}".format(parts[0], parts[1], parts[2])
)
else:
latest_version = int("{}00{:02}".format(parts[0], parts[1]))
Context.latest_version = int("{}00{:02}".format(parts[0], parts[1]))
elif release_found:
# The next line after the version tells if the version is supported
if unsupported_line.match(line):
release_supported = False
Context.release_supported = False
else:
release_supported = True
Context.release_supported = True
break
# Make sure we actually found it
if not release_found:
raise LatestVersionCheckError("Current release ({}) not found".format(release))
log.info(
"Got latest PostgreSQL version: {} supported={}".format(
latest_version, release_supported
)
Context.log.info(
"Got latest PostgreSQL version: %s supported=%s",
Context.latest_version,
Context.release_supported,
)
log.debug(
"Next latest PostgreSQL version check will be after: {}".format(
latest_version_next_check
)
Context.log.debug(
"Next latest PostgreSQL version check will be after: %s",
Context.latest_version_next_check,
)
def get_latest_version():
"""
Get the latest supported version of the major PostgreSQL release running on the server being monitored.
Get the latest supported version of the major PostgreSQL release running on the server being
monitored.
"""
global latest_version_next_check
# If we don't know the latest version or it's past the recheck time, get the
# version from the PostgreSQL RSS feed. Only one thread needs to do this, so
# they all try to grab the lock, and then make sure nobody else beat them to it.
if (
latest_version is None
or latest_version_next_check is None
or latest_version_next_check < datetime.now()
Context.latest_version is None
or Context.latest_version_next_check is None
or Context.latest_version_next_check < datetime.now()
):
# Note: we get the cluster version here before grabbing the latest_version_lock
# lock so it's not held while trying to talk with the DB.
release = version_num_to_release(get_cluster_version())
with latest_version_lock:
with Context.latest_version_lock:
# Only check if nobody already got the version before us
if (
latest_version is None
or latest_version_next_check is None
or latest_version_next_check < datetime.now()
Context.latest_version is None
or Context.latest_version_next_check is None
or Context.latest_version_next_check < datetime.now()
):
log.info("Checking latest PostgreSQL version")
latest_version_next_check = datetime.now() + timedelta(
seconds=int(config["latest_version_check_period"])
Context.log.info("Checking latest PostgreSQL version")
Context.latest_version_next_check = datetime.now() + timedelta(
seconds=int(Context.config["latest_version_check_period"])
)
# Grab the RSS feed
raw_rss = requests.get("https://www.postgresql.org/versions.rss")
raw_rss = requests.get(
"https://www.postgresql.org/versions.rss", timeout=30
)
if raw_rss.status_code != 200:
raise LatestVersionCheckError("code={}".format(r.status_code))
raise LatestVersionCheckError("code={}".format(raw_rss.status_code))
# Parse the RSS body and set global variables
# Parse the RSS body and set Context variables
parse_version_rss(raw_rss.text, release)
return latest_version
return Context.latest_version
def sample_metric(dbname, metric_name, args, retry=True):
@ -638,9 +705,9 @@ def sample_metric(dbname, metric_name, args, retry=True):
"""
# Get the metric definition
try:
metric = config["metrics"][metric_name]
except KeyError:
raise UnknownMetricError("Unknown metric: {}".format(metric_name))
metric = Context.config["metrics"][metric_name]
except KeyError as e:
raise UnknownMetricError("Unknown metric: {}".format(metric_name)) from e
# Get the connection pool for the database, or create one if it doesn't
# already exist.
@ -655,8 +722,7 @@ def sample_metric(dbname, metric_name, args, retry=True):
# Execute the query
if retry:
return run_query(pool, metric["type"], query, args)
else:
return run_query_no_retry(pool, metric["type"], query, args)
return run_query_no_retry(pool, metric["type"], query, args)
def test_queries():
@ -664,9 +730,9 @@ def test_queries():
Run all of the metric queries against a database and check the results
"""
# We just use the default db for tests
dbname = config["dbname"]
dbname = Context.config["dbname"]
# Loop through all defined metrics.
for name, metric in config["metrics"].items():
for name, metric in Context.config["metrics"].items():
# If the metric has arguments to use while testing, grab those
args = metric.get("test_args", {})
print(
@ -711,9 +777,8 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
"""
Override to suppress standard request logging
"""
pass
def do_GET(self):
def do_GET(self): # pylint: disable=invalid-name
"""
Handle a request. This is just a wrapper around the actual handler
code to keep things more readable.
@ -721,7 +786,7 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
try:
self._handle_request()
except BrokenPipeError:
log.error("Client disconnected, exiting handler")
Context.log.error("Client disconnected, exiting handler")
def _handle_request(self):
"""
@ -734,7 +799,6 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
if metric_name == "agent_version":
self._reply(200, VERSION)
return
elif metric_name == "latest_version_info":
try:
get_latest_version()
@ -742,46 +806,40 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
200,
json.dumps(
{
"latest": latest_version,
"supported": 1 if release_supported else 0,
"latest": Context.latest_version,
"supported": 1 if Context.release_supported else 0,
}
),
)
except LatestVersionCheckError as e:
log.error("Failed to retrieve latest version information: {}".format(e))
Context.log.error(
"Failed to retrieve latest version information: %s", e
)
self._reply(503, "Failed to retrieve latest version info")
return
else:
# Note: parse_qs returns the values as a list. Since we always expect
# single values, just grab the first from each.
args = {key: values[0] for key, values in parsed_query.items()}
# Note: parse_qs returns the values as a list. Since we always expect
# single values, just grab the first from each.
args = {key: values[0] for key, values in parsed_query.items()}
# Get the dbname. If none was provided, use the default from the
# config.
dbname = args.get("dbname", Context.config["dbname"])
# Get the dbname. If none was provided, use the default from the
# config.
dbname = args.get("dbname", config["dbname"])
# Sample the metric
try:
self._reply(200, sample_metric(dbname, metric_name, args))
return
except UnknownMetricError as e:
log.error("Unknown metric: {}".format(metric_name))
self._reply(404, "Unknown metric")
return
except MetricVersionError as e:
log.error(
"Failed to find a version of {} for {}".format(metric_name, version)
)
self._reply(404, "Unsupported version")
return
except UnhappyDBError as e:
log.info("Database {} is unhappy, please be patient".format(dbname))
self._reply(503, "Database unavailable")
return
except Exception as e:
log.error("Error running query: {}".format(e))
self._reply(500, "Unexpected error: {}".format(e))
return
# Sample the metric
try:
self._reply(200, sample_metric(dbname, metric_name, args))
except UnknownMetricError:
Context.log.error("Unknown metric: %s", metric_name)
self._reply(404, "Unknown metric")
except MetricVersionError:
Context.log.error("Failed to find an query version for %s", metric_name)
self._reply(404, "Unsupported version")
except UnhappyDBError:
Context.log.info("Database %s is unhappy, please be patient", dbname)
self._reply(503, "Database unavailable")
except Exception as e: # pylint: disable=broad-exception-caught
Context.log.error("Error running query: %s", e)
self._reply(500, "Unexpected error: {}".format(e))
def _reply(self, code, content):
"""
@ -794,7 +852,14 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
self.wfile.write(bytes(content, "utf-8"))
if __name__ == "__main__":
def main():
"""
Main application routine
"""
# Initialize the logging framework
Context.init_logging()
# Handle cli args
parser = argparse.ArgumentParser(
prog="pgmon", description="A PostgreSQL monitoring agent"
@ -815,33 +880,35 @@ if __name__ == "__main__":
args = parser.parse_args()
# Set the config file path
config_file = args.config_file
Context.config_file = args.config_file
# Read the config file
read_config(config_file)
read_config(Context.config_file)
# Run query tests and exit if test mode is enabled
if args.test:
errors = test_queries()
if errors > 0:
if test_queries() > 0:
sys.exit(1)
else:
sys.exit(0)
sys.exit(0)
# Set up the http server to receive requests
server_address = (config["address"], config["port"])
httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler)
server_address = (Context.config["address"], Context.config["port"])
Context.httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler)
# Set up the signal handler
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGHUP, signal_handler)
# Handle requests.
log.info("Listening on port {}...".format(config["port"]))
while running:
httpd.handle_request()
Context.log.info("Listening on port %s...", Context.config["port"])
while Context.running:
Context.httpd.handle_request()
# Clean up PostgreSQL connections
# TODO: Improve this ... not sure it actually closes all the connections cleanly
for pool in connections.values():
for pool in Context.connections.values():
pool.close()
if __name__ == "__main__":
main()
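Once the agent is listening, metrics are fetched with plain HTTP GETs. A hedged sketch of a client call (the endpoint shape is inferred from _handle_request above):

import requests
base = "http://{}:{}".format(Context.config["address"], Context.config["port"])
print(requests.get(base + "/agent_version", timeout=5).text)  # the agent's VERSION string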

src/test_pgmon.py

@ -1,5 +1,12 @@
"""
Unit tests for pgmon
"""
# pylint: disable=too-many-lines
import unittest
import os
from datetime import datetime, timedelta
import tempfile
@ -13,7 +20,7 @@ import pgmon
# Silence most logging output
logging.disable(logging.CRITICAL)
versions_rss = """
VERSIONS_RSS = """
<?xml version="1.0" encoding="utf-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>PostgreSQL latest versions</title><link>https://www.postgresql.org/</link><description>PostgreSQL latest versions</description><atom:link href="https://www.postgresql.org/versions.rss" rel="self"/><language>en-us</language><lastBuildDate>Thu, 08 May 2025 00:00:00 +0000</lastBuildDate><item><title>17.5
</title><link>https://www.postgresql.org/docs/17/release-17-5.html</link><description>17.5 is the latest release in the 17 series.
@ -103,12 +110,18 @@ This version is unsupported!
"""
class TestPgmonMethods(unittest.TestCase):
class TestPgmonMethods(unittest.TestCase): # pylint: disable=too-many-public-methods
"""
Unit test class for pgmon
"""
##
# update_deep
##
def test_update_deep__empty_cases(self):
# Test empty dict cases
"""
Test various empty dictionary permutations
"""
d1 = {}
d2 = {}
pgmon.update_deep(d1, d2)
@ -128,7 +141,9 @@ class TestPgmonMethods(unittest.TestCase):
self.assertEqual(d2, d1)
def test_update_deep__scalars(self):
# Test adding/updating scalar values
"""
Test adding/updating scalar values
"""
d1 = {"foo": 1, "bar": "text", "hello": "world"}
d2 = {"foo": 2, "baz": "blah"}
pgmon.update_deep(d1, d2)
@ -136,7 +151,9 @@ class TestPgmonMethods(unittest.TestCase):
self.assertEqual(d2, {"foo": 2, "baz": "blah"})
def test_update_deep__lists(self):
# Test adding to lists
"""
Test adding to lists
"""
d1 = {"lst1": []}
d2 = {"lst1": [1, 2]}
pgmon.update_deep(d1, d2)
@ -172,7 +189,9 @@ class TestPgmonMethods(unittest.TestCase):
self.assertEqual(d2, {"obj1": {"l1": [3, 4]}})
def test_update_deep__dicts(self):
# Test adding to lists
"""
Test adding to dictionaries
"""
d1 = {"obj1": {}}
d2 = {"obj1": {"a": 1, "b": 2}}
pgmon.update_deep(d1, d2)
@ -199,7 +218,9 @@ class TestPgmonMethods(unittest.TestCase):
self.assertEqual(d2, {"obj1": {"d1": {"a": 5, "c": 12}}})
def test_update_deep__types(self):
# Test mismatched types
"""
Test mismatched types
"""
d1 = {"foo": 5}
d2 = None
self.assertRaises(TypeError, pgmon.update_deep, d1, d2)
@ -218,15 +239,19 @@ class TestPgmonMethods(unittest.TestCase):
##
def test_get_pool__simple(self):
# Just get a pool in a normal case
pgmon.config.update(pgmon.default_config)
"""
Test getting a pool in a normal case
"""
pgmon.Context.config.update(pgmon.DEFAULT_CONFIG)
pool = pgmon.get_pool("postgres")
self.assertIsNotNone(pool)
def test_get_pool__unhappy(self):
# Test getting an unhappy database pool
pgmon.config.update(pgmon.default_config)
pgmon.unhappy_cooldown["postgres"] = datetime.now() + timedelta(60)
"""
Test getting an unhappy database pool
"""
pgmon.Context.config.update(pgmon.DEFAULT_CONFIG)
pgmon.Context.unhappy_cooldown["postgres"] = datetime.now() + timedelta(60)
self.assertRaises(pgmon.UnhappyDBError, pgmon.get_pool, "postgres")
# Test getting a different database when there's an unhappy one
@ -238,30 +263,37 @@ class TestPgmonMethods(unittest.TestCase):
##
def test_handle_connect_failure__simple(self):
# Test adding to an empty unhappy list
pgmon.config.update(pgmon.default_config)
pgmon.unhappy_cooldown = {}
"""
Test adding to an empty unhappy list
"""
pgmon.Context.config.update(pgmon.DEFAULT_CONFIG)
pgmon.Context.unhappy_cooldown = {}
pool = pgmon.get_pool("postgres")
pgmon.handle_connect_failure(pool)
self.assertGreater(pgmon.unhappy_cooldown["postgres"], datetime.now())
self.assertGreater(pgmon.Context.unhappy_cooldown["postgres"], datetime.now())
# Test adding another database
pool = pgmon.get_pool("template0")
pgmon.handle_connect_failure(pool)
self.assertGreater(pgmon.unhappy_cooldown["postgres"], datetime.now())
self.assertGreater(pgmon.unhappy_cooldown["template0"], datetime.now())
self.assertEqual(len(pgmon.unhappy_cooldown), 2)
self.assertGreater(pgmon.Context.unhappy_cooldown["postgres"], datetime.now())
self.assertGreater(pgmon.Context.unhappy_cooldown["template0"], datetime.now())
self.assertEqual(len(pgmon.Context.unhappy_cooldown), 2)
##
# get_query
##
def test_get_query__basic(self):
# Test getting a query with one version
"""
Test getting a query with just a default version.
"""
metric = {"type": "value", "query": {0: "DEFAULT"}}
self.assertEqual(pgmon.get_query(metric, 100000), "DEFAULT")
def test_get_query__versions(self):
"""
Test getting queries when multiple versions are present.
"""
metric = {"type": "value", "query": {0: "DEFAULT", 110000: "NEW"}}
# Test getting the default version of a query with no lower bound and a newer
@ -281,6 +313,9 @@ class TestPgmonMethods(unittest.TestCase):
self.assertEqual(pgmon.get_query(metric, 100000), "OLD")
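For reference, the selection rule these tests pin down (get_query itself is not shown in this diff): the query whose version key is the highest one not exceeding the server version wins, and an empty string marks the metric as unsupported from that version onward.

metric = {"type": "value", "query": {96000: "OLD", 110000: "NEW", 150000: ""}}
# version 100000 -> "OLD"; 110000 -> "NEW"; 150000 and up -> MetricVersionError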
def test_get_query__missing_version(self):
"""
Test trying to get a query that is not defined for the requested version.
"""
metric = {"type": "value", "query": {96000: "OLD", 110000: "NEW", 150000: ""}}
# Test getting a metric that only exists for newer versions
@ -294,11 +329,16 @@ class TestPgmonMethods(unittest.TestCase):
##
def test_read_config__simple(self):
pgmon.config = {}
"""
Test reading a simple config.
"""
pgmon.Context.config = {}
# Test reading just a metric and using the defaults for everything else
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
# This is a comment!
@ -310,18 +350,20 @@ metrics:
"""
)
pgmon.read_config(f"{tmpdirname}/config.yml")
pgmon.read_config(os.path.join(tmpdirname, "config.yml"))
self.assertEqual(
pgmon.config["max_pool_size"], pgmon.default_config["max_pool_size"]
pgmon.Context.config["max_pool_size"], pgmon.DEFAULT_CONFIG["max_pool_size"]
)
self.assertEqual(pgmon.config["dbuser"], pgmon.default_config["dbuser"])
self.assertEqual(pgmon.Context.config["dbuser"], pgmon.DEFAULT_CONFIG["dbuser"])
pgmon.config = {}
pgmon.Context.config = {}
# Test reading a basic config
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
# This is a comment!
@ -357,22 +399,27 @@ metrics:
"""
)
pgmon.read_config(f"{tmpdirname}/config.yml")
pgmon.read_config(os.path.join(tmpdirname, "config.yml"))
self.assertEqual(pgmon.config["dbuser"], "someone")
self.assertEqual(pgmon.config["metrics"]["test1"]["type"], "value")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
self.assertEqual(pgmon.config["metrics"]["test2"]["query"][0], "TEST2")
self.assertEqual(pgmon.Context.config["dbuser"], "someone")
self.assertEqual(pgmon.Context.config["metrics"]["test1"]["type"], "value")
self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1")
self.assertEqual(pgmon.Context.config["metrics"]["test2"]["query"][0], "TEST2")
def test_read_config__include(self):
pgmon.config = {}
"""
Test including one config from another.
"""
pgmon.Context.config = {}
# Test reading a config that includes other files (absolute and relative paths,
# multiple levels)
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
f"""---
"""---
# This is a comment!
min_pool_size: 1
max_pool_size: 2
@ -384,13 +431,17 @@ reconnect_cooldown: 15
version_check_period: 3600
include:
- dbsettings.yml
- {tmpdirname}/metrics.yml
"""
- {}/metrics.yml
""".format(
tmpdirname
)
)
with open(f"{tmpdirname}/dbsettings.yml", "w") as f:
with open(
os.path.join(tmpdirname, "dbsettings.yml"), "w", encoding="utf-8"
) as f:
f.write(
f"""---
"""---
dbuser: someone
dbhost: localhost
dbport: 5555
@ -398,9 +449,11 @@ dbname: template0
"""
)
with open(f"{tmpdirname}/metrics.yml", "w") as f:
with open(
os.path.join(tmpdirname, "metrics.yml"), "w", encoding="utf-8"
) as f:
f.write(
f"""---
"""---
metrics:
test1:
type: value
@ -415,9 +468,11 @@ include:
"""
)
with open(f"{tmpdirname}/more_metrics.yml", "w") as f:
with open(
os.path.join(tmpdirname, "more_metrics.yml"), "w", encoding="utf-8"
) as f:
f.write(
f"""---
"""---
metrics:
test3:
type: value
@ -425,20 +480,25 @@ metrics:
0: TEST3
"""
)
pgmon.read_config(f"{tmpdirname}/config.yml")
pgmon.read_config(os.path.join(tmpdirname, "config.yml"))
self.assertEqual(pgmon.config["max_idle_time"], 10)
self.assertEqual(pgmon.config["dbuser"], "someone")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
self.assertEqual(pgmon.config["metrics"]["test2"]["query"][0], "TEST2")
self.assertEqual(pgmon.config["metrics"]["test3"]["query"][0], "TEST3")
self.assertEqual(pgmon.Context.config["max_idle_time"], 10)
self.assertEqual(pgmon.Context.config["dbuser"], "someone")
self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1")
self.assertEqual(pgmon.Context.config["metrics"]["test2"]["query"][0], "TEST2")
self.assertEqual(pgmon.Context.config["metrics"]["test3"]["query"][0], "TEST3")
def test_read_config__reload(self):
pgmon.config = {}
"""
Test reloading a config.
"""
pgmon.Context.config = {}
# Test rereading a config to update an existing config
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
# This is a comment!
@ -466,12 +526,14 @@ metrics:
"""
)
pgmon.read_config(f"{tmpdirname}/config.yml")
pgmon.read_config(os.path.join(tmpdirname, "config.yml"))
# Just make sure the first config was read
self.assertEqual(len(pgmon.config["metrics"]), 2)
self.assertEqual(len(pgmon.Context.config["metrics"]), 2)
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
# This is a comment!
@ -484,18 +546,23 @@ metrics:
"""
)
pgmon.read_config(f"{tmpdirname}/config.yml")
pgmon.read_config(os.path.join(tmpdirname, "config.yml"))
self.assertEqual(pgmon.config["min_pool_size"], 7)
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "NEW1")
self.assertEqual(len(pgmon.config["metrics"]), 1)
self.assertEqual(pgmon.Context.config["min_pool_size"], 7)
self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "NEW1")
self.assertEqual(len(pgmon.Context.config["metrics"]), 1)
def test_read_config__query_file(self):
pgmon.config = {}
"""
Test reading a query definition from a separate file
"""
pgmon.Context.config = {}
# Read a config file that reads a query from a file
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
metrics:
@ -506,22 +573,30 @@ metrics:
"""
)
with open(f"{tmpdirname}/some_query.sql", "w") as f:
with open(
os.path.join(tmpdirname, "some_query.sql"), "w", encoding="utf-8"
) as f:
f.write("This is a query")
pgmon.read_config(f"{tmpdirname}/config.yml")
pgmon.read_config(os.path.join(tmpdirname, "config.yml"))
self.assertEqual(
pgmon.config["metrics"]["test1"]["query"][0], "This is a query"
pgmon.Context.config["metrics"]["test1"]["query"][0], "This is a query"
)
def test_read_config__invalid(self):
pgmon.config = {}
def init_invalid_config_test(self):
"""
Initialize an invalid config read test. Basically just set up a simple valid config in
order to confirm that an invalid read does not modify the live config.
"""
pgmon.Context.config = {}
# For all of these tests, we start with a valid config and also ensure that
# it is not modified when a new config read fails
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
metrics:
@ -532,20 +607,48 @@ metrics:
"""
)
pgmon.read_config(f"{tmpdirname}/config.yml")
pgmon.read_config(os.path.join(tmpdirname, "config.yml"))
# Just make sure the config was read
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1")
def verify_invalid_config_test(self):
"""
Verify that an invalid read did not modify the live config.
"""
self.assertEqual(pgmon.Context.config["dbuser"], "postgres")
self.assertEqual(pgmon.Context.config["metrics"]["test1"]["query"][0], "TEST1")
def test_read_config__missing(self):
"""
Test reading a nonexistent config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test reading a nonexistent config file
with tempfile.TemporaryDirectory() as tmpdirname:
self.assertRaises(
FileNotFoundError, pgmon.read_config, f"{tmpdirname}/missing.yml"
FileNotFoundError,
pgmon.read_config,
os.path.join(tmpdirname, "missing.yml"),
)
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__invalid(self):
"""
Test reading an invalid config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test reading an invalid config file
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""[default]
This looks a lot like an ini file to me
@ -554,12 +657,26 @@ Or maybe a TOML?
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__invalid_include(self):
"""
Test reading a config file that includes a missing file.
"""
# Set up the test
self.init_invalid_config_test()
# Test reading a config that includes an invalid file
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -573,14 +690,26 @@ include:
"""
)
self.assertRaises(
FileNotFoundError, pgmon.read_config, f"{tmpdirname}/config.yml"
FileNotFoundError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__invalid_log_level(self):
"""
Test reading an invalid log level from a config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test invalid log level
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
log_level: noisy
@ -593,14 +722,26 @@ metrics:
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__invalid_type(self):
"""
Test reading an invalid query result type from a config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test invalid query return type
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -612,32 +753,57 @@ metrics:
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__invalid_query_dict(self):
"""
Test reading an invalid query definition structure from a config file. In other words,
where a dictionary of the form version => query is expected, we give it something else.
"""
# Set up the test
self.init_invalid_config_test()
# Test invalid query dict type
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
metrics:
test1:
type: lots_of_data
type: row
query: EVIL1
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__missing_type(self):
"""
Test reading a metric with a missing result type from a config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test incomplete metric: missing type
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -648,14 +814,26 @@ metrics:
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__missing_queries(self):
"""
Test reading a metric with no queries from a config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test incomplete metric: missing queries
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -665,14 +843,26 @@ metrics:
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__empty_query_dict(self):
"""
Test reading a metric with an empty query dict from a config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test incomplete metric: empty queries
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -683,14 +873,26 @@ metrics:
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__none_query_dict(self):
"""
Test reading a metric where the query dict is None from a config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test incomplete metric: query dict is None
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -701,28 +903,53 @@ metrics:
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__missing_metrics(self):
"""
Test reading a config file with no metrics.
"""
# Set up the test
self.init_invalid_config_test()
# Test reading a config with no metrics
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__missing_query_file(self):
"""
Test reading a metric from a config file where the query definition comes from a missing
file.
"""
# Set up the test
self.init_invalid_config_test()
# Test reading a query defined in a file but the file is missing
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -734,14 +961,26 @@ metrics:
"""
)
self.assertRaises(
FileNotFoundError, pgmon.read_config, f"{tmpdirname}/config.yml"
FileNotFoundError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
def test_read_config__invalid_version(self):
"""
Test reading a metric with an invalid PostgreSQL version from a config file.
"""
# Set up the test
self.init_invalid_config_test()
# Test invalid query versions
with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", "w") as f:
with open(
os.path.join(tmpdirname, "config.yml"), "w", encoding="utf-8"
) as f:
f.write(
"""---
dbuser: evil
@ -753,47 +992,84 @@ metrics:
"""
)
self.assertRaises(
pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
pgmon.ConfigError,
pgmon.read_config,
os.path.join(tmpdirname, "config.yml"),
)
self.assertEqual(pgmon.config["dbuser"], "postgres")
self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
# Confirm nothing changed
self.verify_invalid_config_test()
##
# version_num
##
def test_version_num_to_release__valid(self):
"""
Test converting PostgreSQL versions before and after 10 when the numbering scheme changed.
"""
self.assertEqual(pgmon.version_num_to_release(90602), 9.6)
self.assertEqual(pgmon.version_num_to_release(130002), 13)
def test_parse_version_rss__simple(self):
pgmon.parse_version_rss(versions_rss, 13)
self.assertEqual(pgmon.latest_version, 130021)
self.assertTrue(pgmon.release_supported)
##
# parse_version_rss
##
pgmon.parse_version_rss(versions_rss, 9.6)
self.assertEqual(pgmon.latest_version, 90624)
self.assertFalse(pgmon.release_supported)
def test_parse_version_rss__supported(self):
"""
Test parsing a supported version from the RSS feed
"""
pgmon.parse_version_rss(VERSIONS_RSS, 13)
self.assertEqual(pgmon.Context.latest_version, 130021)
self.assertTrue(pgmon.Context.release_supported)
def test_parse_version_rss__unsupported(self):
"""
Test parsing an unsupported version from the RSS feed
"""
pgmon.parse_version_rss(VERSIONS_RSS, 9.6)
self.assertEqual(pgmon.Context.latest_version, 90624)
self.assertFalse(pgmon.Context.release_supported)
def test_parse_version_rss__missing(self):
# Test asking about versions that don't exist
"""
Test asking about versions that don't exist in the RSS feed
"""
self.assertRaises(
pgmon.LatestVersionCheckError, pgmon.parse_version_rss, versions_rss, 9.7
pgmon.LatestVersionCheckError, pgmon.parse_version_rss, VERSIONS_RSS, 9.7
)
self.assertRaises(
pgmon.LatestVersionCheckError, pgmon.parse_version_rss, versions_rss, 99
pgmon.LatestVersionCheckError, pgmon.parse_version_rss, VERSIONS_RSS, 99
)
##
# get_latest_version
##
def test_get_latest_version(self):
"""
Test getting the latest version from the actual RSS feed
"""
# Define a cluster version here so the test doesn't need a database
pgmon.cluster_version_next_check = datetime.now() + timedelta(hours=1)
pgmon.cluster_version = 90623
pgmon.Context.cluster_version_next_check = datetime.now() + timedelta(hours=1)
pgmon.Context.cluster_version = 90623
# Set up a default config
pgmon.update_deep(pgmon.config, pgmon.default_config)
pgmon.update_deep(pgmon.Context.config, pgmon.DEFAULT_CONFIG)
# Make sure we can pull the RSS file (we assume the 9.6 series won't be getting
# any more updates)
self.assertEqual(pgmon.get_latest_version(), 90624)
##
# json_encode_special
##
def test_json_encode_special(self):
"""
Test encoding Decimal types as JSON
"""
# Confirm that we're getting the right type
self.assertFalse(isinstance(Decimal("0.5"), float))
self.assertTrue(isinstance(pgmon.json_encode_special(Decimal("0.5")), float))