Reformat python code using black

This commit is contained in:
James Campbell 2025-05-13 01:44:47 -04:00
parent 98ac25743b
commit bffabd9c8f
Signed by: james
GPG Key ID: 2287C33A40DC906A
2 changed files with 574 additions and 467 deletions

View File

@ -23,7 +23,7 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
from http.server import ThreadingHTTPServer from http.server import ThreadingHTTPServer
from urllib.parse import urlparse, parse_qs from urllib.parse import urlparse, parse_qs
VERSION = '0.1.0' VERSION = "0.1.0"
# Configuration # Configuration
config = {} config = {}
@ -53,67 +53,65 @@ config_file = None
# Configure logging # Configure logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(filename)s: %(funcName)s() line %(lineno)d: %(message)s') formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(filename)s: %(funcName)s() line %(lineno)d: %(message)s"
)
console_log_handler = logging.StreamHandler() console_log_handler = logging.StreamHandler()
console_log_handler.setFormatter(formatter) console_log_handler.setFormatter(formatter)
log.addHandler(console_log_handler) log.addHandler(console_log_handler)
# Error types # Error types
class ConfigError(Exception): class ConfigError(Exception):
pass pass
class DisconnectedError(Exception): class DisconnectedError(Exception):
pass pass
class UnhappyDBError(Exception): class UnhappyDBError(Exception):
pass pass
class MetricVersionError(Exception): class MetricVersionError(Exception):
pass pass
# Default config settings # Default config settings
default_config = { default_config = {
# The port the agent listens on for requests # The port the agent listens on for requests
'port': 5400, "port": 5400,
# Min PostgreSQL connection pool size (per database) # Min PostgreSQL connection pool size (per database)
'min_pool_size': 0, "min_pool_size": 0,
# Max PostgreSQL connection pool size (per database) # Max PostgreSQL connection pool size (per database)
'max_pool_size': 4, "max_pool_size": 4,
# How long a connection can sit idle in the pool before it's removed (seconds) # How long a connection can sit idle in the pool before it's removed (seconds)
'max_idle_time': 30, "max_idle_time": 30,
# Log level for stderr logging # Log level for stderr logging
'log_level': 'error', "log_level": "error",
# Database user to connect as # Database user to connect as
'dbuser': 'postgres', "dbuser": "postgres",
# Database host # Database host
'dbhost': '/var/run/postgresql', "dbhost": "/var/run/postgresql",
# Database port # Database port
'dbport': 5432, "dbport": 5432,
# Default database to connect to when none is specified for a metric # Default database to connect to when none is specified for a metric
'dbname': 'postgres', "dbname": "postgres",
# Timeout for getting a connection slot from a pool # Timeout for getting a connection slot from a pool
'pool_slot_timeout': 5, "pool_slot_timeout": 5,
# PostgreSQL connection timeout (seconds) # PostgreSQL connection timeout (seconds)
# Note: It can actually be double this because of retries # Note: It can actually be double this because of retries
'connect_timeout': 5, "connect_timeout": 5,
# Time to wait before trying to reconnect again after a reconnect failure (seconds) # Time to wait before trying to reconnect again after a reconnect failure (seconds)
'reconnect_cooldown': 30, "reconnect_cooldown": 30,
# How often to check the version of PostgreSQL (seconds) # How often to check the version of PostgreSQL (seconds)
'version_check_period': 300, "version_check_period": 300,
# Metrics # Metrics
'metrics': {} "metrics": {},
} }
def update_deep(d1, d2): def update_deep(d1, d2):
""" """
Recursively update a dict, adding keys to dictionaries and appending to Recursively update a dict, adding keys to dictionaries and appending to
@ -127,24 +125,33 @@ def update_deep(d1, d2):
The new d1 The new d1
""" """
if not isinstance(d1, dict) or not isinstance(d2, dict): if not isinstance(d1, dict) or not isinstance(d2, dict):
raise TypeError('Both arguments to update_deep need to be dictionaries') raise TypeError("Both arguments to update_deep need to be dictionaries")
for k, v2 in d2.items(): for k, v2 in d2.items():
if isinstance(v2, dict): if isinstance(v2, dict):
v1 = d1.get(k, {}) v1 = d1.get(k, {})
if not isinstance(v1, dict): if not isinstance(v1, dict):
raise TypeError('Type mismatch between dictionaries: {} is not a dict'.format(type(v1).__name__)) raise TypeError(
"Type mismatch between dictionaries: {} is not a dict".format(
type(v1).__name__
)
)
d1[k] = update_deep(v1, v2) d1[k] = update_deep(v1, v2)
elif isinstance(v2, list): elif isinstance(v2, list):
v1 = d1.get(k, []) v1 = d1.get(k, [])
if not isinstance(v1, list): if not isinstance(v1, list):
raise TypeError('Type mismatch between dictionaries: {} is not a list'.format(type(v1).__name__)) raise TypeError(
"Type mismatch between dictionaries: {} is not a list".format(
type(v1).__name__
)
)
d1[k] = v1 + v2 d1[k] = v1 + v2
else: else:
d1[k] = v2 d1[k] = v2
return d1 return d1
def read_config(path, included = False):
def read_config(path, included=False):
""" """
Read a config file. Read a config file.
@ -154,7 +161,7 @@ def read_config(path, included = False):
""" """
# Read config file # Read config file
log.info("Reading log file: {}".format(path)) log.info("Reading log file: {}".format(path))
with open(path, 'r') as f: with open(path, "r") as f:
try: try:
cfg = yaml.safe_load(f) cfg = yaml.safe_load(f)
except yaml.parser.ParserError as e: except yaml.parser.ParserError as e:
@ -164,41 +171,52 @@ def read_config(path, included = False):
config_base = os.path.dirname(path) config_base = os.path.dirname(path)
# Read any external queries and validate metric definitions # Read any external queries and validate metric definitions
for name, metric in cfg.get('metrics', {}).items(): for name, metric in cfg.get("metrics", {}).items():
# Validate return types # Validate return types
try: try:
if metric['type'] not in ['value', 'row', 'column', 'set']: if metric["type"] not in ["value", "row", "column", "set"]:
raise ConfigError("Invalid return type: {} for metric {} in {}".format(metric['type'], name, path)) raise ConfigError(
"Invalid return type: {} for metric {} in {}".format(
metric["type"], name, path
)
)
except KeyError: except KeyError:
raise ConfigError("No type specified for metric {} in {}".format(name, path)) raise ConfigError(
"No type specified for metric {} in {}".format(name, path)
)
# Ensure queries exist # Ensure queries exist
query_dict = metric.get('query', {}) query_dict = metric.get("query", {})
if type(query_dict) is not dict: if type(query_dict) is not dict:
raise ConfigError("Query definition should be a dictionary, got: {} for metric {} in {}".format(query_dict, name, path)) raise ConfigError(
"Query definition should be a dictionary, got: {} for metric {} in {}".format(
query_dict, name, path
)
)
if len(query_dict) == 0: if len(query_dict) == 0:
raise ConfigError("Missing queries for metric {} in {}".format(name, path)) raise ConfigError("Missing queries for metric {} in {}".format(name, path))
# Read external sql files and validate version keys # Read external sql files and validate version keys
for vers, query in metric['query'].items(): for vers, query in metric["query"].items():
try: try:
int(vers) int(vers)
except: except:
raise ConfigError("Invalid version: {} for metric {} in {}".format(vers, name, path)) raise ConfigError(
"Invalid version: {} for metric {} in {}".format(vers, name, path)
)
if query.startswith('file:'): if query.startswith("file:"):
query_path = query[5:] query_path = query[5:]
if not query_path.startswith('/'): if not query_path.startswith("/"):
query_path = os.path.join(config_base, query_path) query_path = os.path.join(config_base, query_path)
with open(query_path, 'r') as f: with open(query_path, "r") as f:
metric['query'][vers] = f.read() metric["query"][vers] = f.read()
# Read any included config files # Read any included config files
for inc in cfg.get('include', []): for inc in cfg.get("include", []):
# Prefix relative paths with the directory from the current config # Prefix relative paths with the directory from the current config
if not inc.startswith('/'): if not inc.startswith("/"):
inc = os.path.join(config_base, inc) inc = os.path.join(config_base, inc)
update_deep(cfg, read_config(inc, included=True)) update_deep(cfg, read_config(inc, included=True))
@ -212,19 +230,26 @@ def read_config(path, included = False):
update_deep(new_config, cfg) update_deep(new_config, cfg)
# Minor sanity checks # Minor sanity checks
if len(new_config['metrics']) == 0: if len(new_config["metrics"]) == 0:
log.error("No metrics are defined") log.error("No metrics are defined")
raise ConfigError("No metrics defined") raise ConfigError("No metrics defined")
# Validate the new log level before changing the config # Validate the new log level before changing the config
if new_config['log_level'].upper() not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: if new_config["log_level"].upper() not in [
raise ConfigError("Invalid log level: {}".format(new_config['log_level'])) "DEBUG",
"INFO",
"WARNING",
"ERROR",
"CRITICAL",
]:
raise ConfigError("Invalid log level: {}".format(new_config["log_level"]))
global config global config
config = new_config config = new_config
# Apply changes to log level # Apply changes to log level
log.setLevel(logging.getLevelName(config['log_level'].upper())) log.setLevel(logging.getLevelName(config["log_level"].upper()))
def signal_handler(sig, frame): def signal_handler(sig, frame):
""" """
@ -236,7 +261,7 @@ def signal_handler(sig, frame):
signal.signal(signal.SIGINT, signal.default_int_handler) signal.signal(signal.SIGINT, signal.default_int_handler)
# Signal everything to shut down # Signal everything to shut down
if sig in [ signal.SIGINT, signal.SIGTERM, signal.SIGQUIT ]: if sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
log.info("Shutting down ...") log.info("Shutting down ...")
global running global running
running = False running = False
@ -248,10 +273,11 @@ def signal_handler(sig, frame):
log.warning("Received config reload signal") log.warning("Received config reload signal")
read_config(config_file) read_config(config_file)
class ConnectionPool(ThreadedConnectionPool): class ConnectionPool(ThreadedConnectionPool):
def __init__(self, dbname, minconn, maxconn, *args, **kwargs): def __init__(self, dbname, minconn, maxconn, *args, **kwargs):
# Make sure dbname isn't different in the kwargs # Make sure dbname isn't different in the kwargs
kwargs['dbname'] = dbname kwargs["dbname"] = dbname
super().__init__(minconn, maxconn, *args, **kwargs) super().__init__(minconn, maxconn, *args, **kwargs)
self.name = dbname self.name = dbname
@ -273,7 +299,10 @@ class ConnectionPool(ThreadedConnectionPool):
except psycopg2.pool.PoolError: except psycopg2.pool.PoolError:
# If we failed to get the connection slot, wait a bit and try again # If we failed to get the connection slot, wait a bit and try again
time.sleep(0.1) time.sleep(0.1)
raise TimeoutError("Timed out waiting for an available connection to {}".format(self.name)) raise TimeoutError(
"Timed out waiting for an available connection to {}".format(self.name)
)
def get_pool(dbname): def get_pool(dbname):
""" """
@ -293,24 +322,29 @@ def get_pool(dbname):
log.info("Creating connection pool for: {}".format(dbname)) log.info("Creating connection pool for: {}".format(dbname))
connections[dbname] = ConnectionPool( connections[dbname] = ConnectionPool(
dbname, dbname,
int(config['min_pool_size']), int(config["min_pool_size"]),
int(config['max_pool_size']), int(config["max_pool_size"]),
application_name='pgmon', application_name="pgmon",
host=config['dbhost'], host=config["dbhost"],
port=config['dbport'], port=config["dbport"],
user=config['dbuser'], user=config["dbuser"],
connect_timeout=int(config['connect_timeout']), connect_timeout=int(config["connect_timeout"]),
sslmode='require') sslmode="require",
)
# Clear the unhappy indicator if present # Clear the unhappy indicator if present
unhappy_cooldown.pop(dbname, None) unhappy_cooldown.pop(dbname, None)
return connections[dbname] return connections[dbname]
def handle_connect_failure(pool): def handle_connect_failure(pool):
""" """
Mark the database as being unhappy so we can leave it alone for a while Mark the database as being unhappy so we can leave it alone for a while
""" """
dbname = pool.name dbname = pool.name
unhappy_cooldown[dbname] = datetime.now() + timedelta(seconds=int(config['reconnect_cooldown'])) unhappy_cooldown[dbname] = datetime.now() + timedelta(
seconds=int(config["reconnect_cooldown"])
)
def get_query(metric, version): def get_query(metric, version):
""" """
@ -321,32 +355,34 @@ def get_query(metric, version):
version: The PostgreSQL version number, as given by server_version_num version: The PostgreSQL version number, as given by server_version_num
""" """
# Select the correct query # Select the correct query
for v in reversed(sorted(metric['query'].keys())): for v in reversed(sorted(metric["query"].keys())):
if version >= v: if version >= v:
if len(metric['query'][v].strip()) == 0: if len(metric["query"][v].strip()) == 0:
raise MetricVersionError("Metric no longer applies to PostgreSQL {}".format(version)) raise MetricVersionError(
return metric['query'][v] "Metric no longer applies to PostgreSQL {}".format(version)
)
return metric["query"][v]
raise MetricVersionError('Missing metric query for PostgreSQL {}'.format(version)) raise MetricVersionError("Missing metric query for PostgreSQL {}".format(version))
def run_query_no_retry(pool, return_type, query, args): def run_query_no_retry(pool, return_type, query, args):
""" """
Run the query with no explicit retry code Run the query with no explicit retry code
""" """
with pool.connection(float(config['connect_timeout'])) as conn: with pool.connection(float(config["connect_timeout"])) as conn:
try: try:
with conn.cursor(cursor_factory=RealDictCursor) as curs: with conn.cursor(cursor_factory=RealDictCursor) as curs:
curs.execute(query, args) curs.execute(query, args)
res = curs.fetchall() res = curs.fetchall()
if return_type == 'value': if return_type == "value":
return str(list(res[0].values())[0]) return str(list(res[0].values())[0])
elif return_type == 'row': elif return_type == "row":
return json.dumps(res[0]) return json.dumps(res[0])
elif return_type == 'column': elif return_type == "column":
return json.dumps([list(r.values())[0] for r in res]) return json.dumps([list(r.values())[0] for r in res])
elif return_type == 'set': elif return_type == "set":
return json.dumps(res) return json.dumps(res)
except: except:
dbname = pool.name dbname = pool.name
@ -357,6 +393,7 @@ def run_query_no_retry(pool, return_type, query, args):
else: else:
raise raise
def run_query(pool, return_type, query, args): def run_query(pool, return_type, query, args):
""" """
Run the query, and if we find upon the first attempt that the connection Run the query, and if we find upon the first attempt that the connection
@ -387,6 +424,7 @@ def run_query(pool, return_type, query, args):
handle_connect_failure(pool) handle_connect_failure(pool)
raise UnhappyDBError() raise UnhappyDBError()
def get_cluster_version(): def get_cluster_version():
""" """
Get the PostgreSQL version if we don't already know it, or if it's been Get the PostgreSQL version if we don't already know it, or if it's been
@ -398,26 +436,43 @@ def get_cluster_version():
# If we don't know the version or it's past the recheck time, get the # If we don't know the version or it's past the recheck time, get the
# version from the database. Only one thread needs to do this, so they all # version from the database. Only one thread needs to do this, so they all
# try to grab the lock, and then make sure nobody else beat them to it. # try to grab the lock, and then make sure nobody else beat them to it.
if cluster_version is None or cluster_version_next_check is None or cluster_version_next_check < datetime.now(): if (
cluster_version is None
or cluster_version_next_check is None
or cluster_version_next_check < datetime.now()
):
with cluster_version_lock: with cluster_version_lock:
# Only check if nobody already got the version before us # Only check if nobody already got the version before us
if cluster_version is None or cluster_version_next_check is None or cluster_version_next_check < datetime.now(): if (
log.info('Checking PostgreSQL cluster version') cluster_version is None
pool = get_pool(config['dbname']) or cluster_version_next_check is None
cluster_version = int(run_query(pool, 'value', 'SHOW server_version_num', None)) or cluster_version_next_check < datetime.now()
cluster_version_next_check = datetime.now() + timedelta(seconds=int(config['version_check_period'])) ):
log.info("Checking PostgreSQL cluster version")
pool = get_pool(config["dbname"])
cluster_version = int(
run_query(pool, "value", "SHOW server_version_num", None)
)
cluster_version_next_check = datetime.now() + timedelta(
seconds=int(config["version_check_period"])
)
log.info("Got PostgreSQL cluster version: {}".format(cluster_version)) log.info("Got PostgreSQL cluster version: {}".format(cluster_version))
log.debug("Next PostgreSQL cluster version check will be after: {}".format(cluster_version_next_check)) log.debug(
"Next PostgreSQL cluster version check will be after: {}".format(
cluster_version_next_check
)
)
return cluster_version return cluster_version
class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
""" """
This is our request handling server. It is responsible for listening for This is our request handling server. It is responsible for listening for
requests, processing them, and responding. requests, processing them, and responding.
""" """
def log_request(self, code='-', size='-'): def log_request(self, code="-", size="-"):
""" """
Override to suppress standard request logging Override to suppress standard request logging
""" """
@ -439,10 +494,10 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
""" """
# Parse the URL # Parse the URL
parsed_path = urlparse(self.path) parsed_path = urlparse(self.path)
name = parsed_path.path.strip('/') name = parsed_path.path.strip("/")
parsed_query = parse_qs(parsed_path.query) parsed_query = parse_qs(parsed_path.query)
if name == 'agent_version': if name == "agent_version":
self._reply(200, VERSION) self._reply(200, VERSION)
return return
@ -452,15 +507,15 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
# Get the metric definition # Get the metric definition
try: try:
metric = config['metrics'][name] metric = config["metrics"][name]
except KeyError: except KeyError:
log.error("Unknown metric: {}".format(name)) log.error("Unknown metric: {}".format(name))
self._reply(404, 'Unknown metric') self._reply(404, "Unknown metric")
return return
# Get the dbname. If none was provided, use the default from the # Get the dbname. If none was provided, use the default from the
# config. # config.
dbname = args.get('dbname', config['dbname']) dbname = args.get("dbname", config["dbname"])
# Get the connection pool for the database, or create one if it doesn't # Get the connection pool for the database, or create one if it doesn't
# already exist. # already exist.
@ -468,7 +523,7 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
pool = get_pool(dbname) pool = get_pool(dbname)
except UnhappyDBError: except UnhappyDBError:
log.info("Database {} is unhappy, please be patient".format(dbname)) log.info("Database {} is unhappy, please be patient".format(dbname))
self._reply(503, 'Database unavailable') self._reply(503, "Database unavailable")
return return
# Identify the PostgreSQL version # Identify the PostgreSQL version
@ -479,10 +534,10 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
except Exception as e: except Exception as e:
if dbname in unhappy_cooldown: if dbname in unhappy_cooldown:
log.info("Database {} is unhappy, please be patient".format(dbname)) log.info("Database {} is unhappy, please be patient".format(dbname))
self._reply(503, 'Database unavailable') self._reply(503, "Database unavailable")
else: else:
log.error("Failed to get PostgreSQL version: {}".format(e)) log.error("Failed to get PostgreSQL version: {}".format(e))
self._reply(500, 'Error getting DB version') self._reply(500, "Error getting DB version")
return return
# Get the query version # Get the query version
@ -490,17 +545,17 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
query = get_query(metric, version) query = get_query(metric, version)
except KeyError: except KeyError:
log.error("Failed to find a version of {} for {}".format(name, version)) log.error("Failed to find a version of {} for {}".format(name, version))
self._reply(404, 'Unsupported version') self._reply(404, "Unsupported version")
return return
# Execute the quert # Execute the quert
try: try:
self._reply(200, run_query(pool, metric['type'], query, args)) self._reply(200, run_query(pool, metric["type"], query, args))
return return
except Exception as e: except Exception as e:
if dbname in unhappy_cooldown: if dbname in unhappy_cooldown:
log.info("Database {} is unhappy, please be patient".format(dbname)) log.info("Database {} is unhappy, please be patient".format(dbname))
self._reply(503, 'Database unavailable') self._reply(503, "Database unavailable")
else: else:
log.error("Error running query: {}".format(e)) log.error("Error running query: {}".format(e))
self._reply(500, "Error running query") self._reply(500, "Error running query")
@ -511,19 +566,24 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
Send a reply to the client Send a reply to the client
""" """
self.send_response(code) self.send_response(code)
self.send_header('Content-type', 'application/json') self.send_header("Content-type", "application/json")
self.end_headers() self.end_headers()
self.wfile.write(bytes(content, 'utf-8')) self.wfile.write(bytes(content, "utf-8"))
if __name__ == '__main__':
if __name__ == "__main__":
# Handle cli args # Handle cli args
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog = 'pgmon', prog="pgmon", description="A PostgreSQL monitoring agent"
description='A PostgreSQL monitoring agent') )
parser.add_argument('config_file', default='pgmon.yml', nargs='?', parser.add_argument(
help='The config file to read (default: %(default)s)') "config_file",
default="pgmon.yml",
nargs="?",
help="The config file to read (default: %(default)s)",
)
args = parser.parse_args() args = parser.parse_args()
@ -534,7 +594,7 @@ if __name__ == '__main__':
read_config(config_file) read_config(config_file)
# Set up the http server to receive requests # Set up the http server to receive requests
server_address = ('127.0.0.1', config['port']) server_address = ("127.0.0.1", config["port"])
httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler) httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler)
# Set up the signal handler # Set up the signal handler
@ -542,7 +602,7 @@ if __name__ == '__main__':
signal.signal(signal.SIGHUP, signal_handler) signal.signal(signal.SIGHUP, signal_handler)
# Handle requests. # Handle requests.
log.info("Listening on port {}...".format(config['port'])) log.info("Listening on port {}...".format(config["port"]))
while running: while running:
httpd.handle_request() httpd.handle_request()

View File

@ -10,6 +10,7 @@ import pgmon
# Silence most logging output # Silence most logging output
logging.disable(logging.CRITICAL) logging.disable(logging.CRITICAL)
class TestPgmonMethods(unittest.TestCase): class TestPgmonMethods(unittest.TestCase):
## ##
# update_deep # update_deep
@ -22,103 +23,104 @@ class TestPgmonMethods(unittest.TestCase):
self.assertEqual(d1, {}) self.assertEqual(d1, {})
self.assertEqual(d2, {}) self.assertEqual(d2, {})
d1 = {'a': 1} d1 = {"a": 1}
d2 = {} d2 = {}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, { 'a': 1 }) self.assertEqual(d1, {"a": 1})
self.assertEqual(d2, {}) self.assertEqual(d2, {})
d1 = {} d1 = {}
d2 = {'a': 1} d2 = {"a": 1}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, { 'a': 1 }) self.assertEqual(d1, {"a": 1})
self.assertEqual(d2, d1) self.assertEqual(d2, d1)
def test_update_deep__scalars(self): def test_update_deep__scalars(self):
# Test adding/updating scalar values # Test adding/updating scalar values
d1 = {'foo': 1, 'bar': "text", 'hello': "world"} d1 = {"foo": 1, "bar": "text", "hello": "world"}
d2 = {'foo': 2, 'baz': "blah"} d2 = {"foo": 2, "baz": "blah"}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'foo': 2, 'bar': "text", 'baz': "blah", 'hello': "world"}) self.assertEqual(d1, {"foo": 2, "bar": "text", "baz": "blah", "hello": "world"})
self.assertEqual(d2, {'foo': 2, 'baz': "blah"}) self.assertEqual(d2, {"foo": 2, "baz": "blah"})
def test_update_deep__lists(self): def test_update_deep__lists(self):
# Test adding to lists # Test adding to lists
d1 = {'lst1': []} d1 = {"lst1": []}
d2 = {'lst1': [1, 2]} d2 = {"lst1": [1, 2]}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'lst1': [1, 2]}) self.assertEqual(d1, {"lst1": [1, 2]})
self.assertEqual(d2, d1) self.assertEqual(d2, d1)
d1 = {'lst1': [1, 2]} d1 = {"lst1": [1, 2]}
d2 = {'lst1': []} d2 = {"lst1": []}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'lst1': [1, 2]}) self.assertEqual(d1, {"lst1": [1, 2]})
self.assertEqual(d2, {'lst1': []}) self.assertEqual(d2, {"lst1": []})
d1 = {'lst1': [1, 2, 3]} d1 = {"lst1": [1, 2, 3]}
d2 = {'lst1': [3, 4]} d2 = {"lst1": [3, 4]}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'lst1': [1, 2, 3, 3, 4]}) self.assertEqual(d1, {"lst1": [1, 2, 3, 3, 4]})
self.assertEqual(d2, {'lst1': [3, 4]}) self.assertEqual(d2, {"lst1": [3, 4]})
# Lists of objects # Lists of objects
d1 = {'lst1': [{'id': 1}, {'id': 2}, {'id': 3}]} d1 = {"lst1": [{"id": 1}, {"id": 2}, {"id": 3}]}
d2 = {'lst1': [{'id': 3}, {'id': 4}]} d2 = {"lst1": [{"id": 3}, {"id": 4}]}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'lst1': [{'id': 1}, {'id': 2}, {'id': 3}, {'id': 3}, {'id': 4}]}) self.assertEqual(
self.assertEqual(d2, {'lst1': [{'id': 3}, {'id': 4}]}) d1, {"lst1": [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 3}, {"id": 4}]}
)
self.assertEqual(d2, {"lst1": [{"id": 3}, {"id": 4}]})
# Nested lists # Nested lists
d1 = {'obj1': {'l1': [1, 2]}} d1 = {"obj1": {"l1": [1, 2]}}
d2 = {'obj1': {'l1': [3, 4]}} d2 = {"obj1": {"l1": [3, 4]}}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'obj1': {'l1': [1, 2, 3, 4]}}) self.assertEqual(d1, {"obj1": {"l1": [1, 2, 3, 4]}})
self.assertEqual(d2, {'obj1': {'l1': [3, 4]}}) self.assertEqual(d2, {"obj1": {"l1": [3, 4]}})
def test_update_deep__dicts(self): def test_update_deep__dicts(self):
# Test adding to lists # Test adding to lists
d1 = {'obj1': {}} d1 = {"obj1": {}}
d2 = {'obj1': {'a': 1, 'b': 2}} d2 = {"obj1": {"a": 1, "b": 2}}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'obj1': {'a': 1, 'b': 2}}) self.assertEqual(d1, {"obj1": {"a": 1, "b": 2}})
self.assertEqual(d2, d1) self.assertEqual(d2, d1)
d1 = {'obj1': {'a': 1, 'b': 2}} d1 = {"obj1": {"a": 1, "b": 2}}
d2 = {'obj1': {}} d2 = {"obj1": {}}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'obj1': {'a': 1, 'b': 2}}) self.assertEqual(d1, {"obj1": {"a": 1, "b": 2}})
self.assertEqual(d2, {'obj1': {}}) self.assertEqual(d2, {"obj1": {}})
d1 = {'obj1': {'a': 1, 'b': 2}} d1 = {"obj1": {"a": 1, "b": 2}}
d2 = {'obj1': {'a': 5, 'c': 12}} d2 = {"obj1": {"a": 5, "c": 12}}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'obj1': {'a': 5, 'b': 2, 'c': 12}}) self.assertEqual(d1, {"obj1": {"a": 5, "b": 2, "c": 12}})
self.assertEqual(d2, {'obj1': {'a': 5, 'c': 12}}) self.assertEqual(d2, {"obj1": {"a": 5, "c": 12}})
# Nested dicts # Nested dicts
d1 = {'obj1': {'d1': {'a': 1, 'b': 2}}} d1 = {"obj1": {"d1": {"a": 1, "b": 2}}}
d2 = {'obj1': {'d1': {'a': 5, 'c': 12}}} d2 = {"obj1": {"d1": {"a": 5, "c": 12}}}
pgmon.update_deep(d1, d2) pgmon.update_deep(d1, d2)
self.assertEqual(d1, {'obj1': {'d1': {'a': 5, 'b': 2, 'c': 12}}}) self.assertEqual(d1, {"obj1": {"d1": {"a": 5, "b": 2, "c": 12}}})
self.assertEqual(d2, {'obj1': {'d1': {'a': 5, 'c': 12}}}) self.assertEqual(d2, {"obj1": {"d1": {"a": 5, "c": 12}}})
def test_update_deep__types(self): def test_update_deep__types(self):
# Test mismatched types # Test mismatched types
d1 = {'foo': 5} d1 = {"foo": 5}
d2 = None d2 = None
self.assertRaises(TypeError, pgmon.update_deep, d1, d2) self.assertRaises(TypeError, pgmon.update_deep, d1, d2)
d1 = None d1 = None
d2 = {'foo': 5} d2 = {"foo": 5}
self.assertRaises(TypeError, pgmon.update_deep, d1, d2) self.assertRaises(TypeError, pgmon.update_deep, d1, d2)
# Nested mismatched types # Nested mismatched types
d1 = {'foo': [1, 2]} d1 = {"foo": [1, 2]}
d2 = {'foo': {'a': 7}} d2 = {"foo": {"a": 7}}
self.assertRaises(TypeError, pgmon.update_deep, d1, d2) self.assertRaises(TypeError, pgmon.update_deep, d1, d2)
## ##
# get_pool # get_pool
## ##
@ -126,20 +128,19 @@ class TestPgmonMethods(unittest.TestCase):
def test_get_pool__simple(self): def test_get_pool__simple(self):
# Just get a pool in a normal case # Just get a pool in a normal case
pgmon.config.update(pgmon.default_config) pgmon.config.update(pgmon.default_config)
pool = pgmon.get_pool('postgres') pool = pgmon.get_pool("postgres")
self.assertIsNotNone(pool) self.assertIsNotNone(pool)
def test_get_pool__unhappy(self): def test_get_pool__unhappy(self):
# Test getting an unhappy database pool # Test getting an unhappy database pool
pgmon.config.update(pgmon.default_config) pgmon.config.update(pgmon.default_config)
pgmon.unhappy_cooldown['postgres'] = datetime.now() + timedelta(60) pgmon.unhappy_cooldown["postgres"] = datetime.now() + timedelta(60)
self.assertRaises(pgmon.UnhappyDBError, pgmon.get_pool, 'postgres') self.assertRaises(pgmon.UnhappyDBError, pgmon.get_pool, "postgres")
# Test getting a different database when there's an unhappy one # Test getting a different database when there's an unhappy one
pool = pgmon.get_pool('template0') pool = pgmon.get_pool("template0")
self.assertIsNotNone(pool) self.assertIsNotNone(pool)
## ##
# handle_connect_failure # handle_connect_failure
## ##
@ -148,70 +149,44 @@ class TestPgmonMethods(unittest.TestCase):
# Test adding to an empty unhappy list # Test adding to an empty unhappy list
pgmon.config.update(pgmon.default_config) pgmon.config.update(pgmon.default_config)
pgmon.unhappy_cooldown = {} pgmon.unhappy_cooldown = {}
pool = pgmon.get_pool('postgres') pool = pgmon.get_pool("postgres")
pgmon.handle_connect_failure(pool) pgmon.handle_connect_failure(pool)
self.assertGreater(pgmon.unhappy_cooldown['postgres'], datetime.now()) self.assertGreater(pgmon.unhappy_cooldown["postgres"], datetime.now())
# Test adding another database # Test adding another database
pool = pgmon.get_pool('template0') pool = pgmon.get_pool("template0")
pgmon.handle_connect_failure(pool) pgmon.handle_connect_failure(pool)
self.assertGreater(pgmon.unhappy_cooldown['postgres'], datetime.now()) self.assertGreater(pgmon.unhappy_cooldown["postgres"], datetime.now())
self.assertGreater(pgmon.unhappy_cooldown['template0'], datetime.now()) self.assertGreater(pgmon.unhappy_cooldown["template0"], datetime.now())
self.assertEqual(len(pgmon.unhappy_cooldown), 2) self.assertEqual(len(pgmon.unhappy_cooldown), 2)
## ##
# get_query # get_query
## ##
def test_get_query__basic(self):
    """get_query() returns the sole query when only one version is defined."""
    # Test getting a query with one version
    metric = {"type": "value", "query": {0: "DEFAULT"}}
    self.assertEqual(pgmon.get_query(metric, 100000), "DEFAULT")
def test_get_query__versions(self):
    """get_query() selects the query with the highest version bound that is
    still <= the requested PostgreSQL server version."""
    metric = {"type": "value", "query": {0: "DEFAULT", 110000: "NEW"}}

    # Test getting the default version of a query with no lower bound and a
    # newer version
    self.assertEqual(pgmon.get_query(metric, 100000), "DEFAULT")

    # Test getting the newer version of a query with no lower bound and a
    # newer version for the newer version
    self.assertEqual(pgmon.get_query(metric, 110000), "NEW")

    # Test getting the newer version of a query with no lower bound and a
    # newer version for an even newer version
    self.assertEqual(pgmon.get_query(metric, 160000), "NEW")

    # Test getting a version in between two other versions
    metric = {"type": "value", "query": {0: "DEFAULT", 96000: "OLD", 110000: "NEW"}}
    self.assertEqual(pgmon.get_query(metric, 100000), "OLD")
def test_get_query__missing_version(self):
    """get_query() raises MetricVersionError when the server version falls
    outside the metric's defined version range."""
    # NOTE(review): the empty string at 150000 presumably marks the metric as
    # removed from that version on -- confirm against get_query()
    metric = {"type": "value", "query": {96000: "OLD", 110000: "NEW", 150000: ""}}

    # Test getting a metric that only exists for newer versions
    self.assertRaises(pgmon.MetricVersionError, pgmon.get_query, metric, 80000)

    # Test getting a metric that only exists for older versions
    self.assertRaises(pgmon.MetricVersionError, pgmon.get_query, metric, 160000)
##
# read_config
##
@ -229,27 +203,32 @@ class TestPgmonMethods(unittest.TestCase):
# Test reading just a metric and using the defaults for everything else # Test reading just a metric and using the defaults for everything else
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", 'w') as f: with open(f"{tmpdirname}/config.yml", "w") as f:
f.write("""--- f.write(
"""---
# This is a comment! # This is a comment!
metrics: metrics:
test1: test1:
type: value type: value
query: query:
0: TEST1 0: TEST1
""") """
)
pgmon.read_config(f"{tmpdirname}/config.yml") pgmon.read_config(f"{tmpdirname}/config.yml")
self.assertEqual(pgmon.config['max_pool_size'], pgmon.default_config['max_pool_size']) self.assertEqual(
self.assertEqual(pgmon.config['dbuser'], pgmon.default_config['dbuser']) pgmon.config["max_pool_size"], pgmon.default_config["max_pool_size"]
)
self.assertEqual(pgmon.config["dbuser"], pgmon.default_config["dbuser"])
pgmon.config = {} pgmon.config = {}
# Test reading a basic config # Test reading a basic config
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
with open(f"{tmpdirname}/config.yml", 'w') as f: with open(f"{tmpdirname}/config.yml", "w") as f:
f.write("""--- f.write(
"""---
# This is a comment! # This is a comment!
min_pool_size: 1 min_pool_size: 1
max_pool_size: 2 max_pool_size: 2
@ -280,22 +259,24 @@ metrics:
type: column type: column
query: query:
0: TEST4 0: TEST4
""") """
)
pgmon.read_config(f"{tmpdirname}/config.yml") pgmon.read_config(f"{tmpdirname}/config.yml")
self.assertEqual(pgmon.config['dbuser'], 'someone') self.assertEqual(pgmon.config["dbuser"], "someone")
self.assertEqual(pgmon.config['metrics']['test1']['type'], 'value') self.assertEqual(pgmon.config["metrics"]["test1"]["type"], "value")
self.assertEqual(pgmon.config['metrics']['test1']['query'][0], 'TEST1') self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
self.assertEqual(pgmon.config['metrics']['test2']['query'][0], 'TEST2') self.assertEqual(pgmon.config["metrics"]["test2"]["query"][0], "TEST2")
def test_read_config__include(self):
    """read_config() follows 'include' directives (relative and absolute
    paths) across multiple levels of nesting, merging all settings."""
    pgmon.config = {}
    # Test reading a config that includes other files (absolute and relative
    # paths, multiple levels)
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Top-level config: pool settings plus two includes (one relative
        # path, one absolute).
        # NOTE(review): a few settings lines here were elided by the diff
        # this block was recovered from and are reconstructed -- confirm
        # against version history.
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                f"""---
# This is a comment!
min_pool_size: 1
max_pool_size: 2
max_idle_time: 10
reconnect_cooldown: 15
version_check_period: 3600
include:
  - dbsettings.yml
  - {tmpdirname}/metrics.yml
"""
            )
        # First-level include (relative path): connection settings
        with open(f"{tmpdirname}/dbsettings.yml", "w") as f:
            f.write(
                f"""---
dbuser: someone
dbhost: localhost
dbport: 5555
dbname: template0
"""
            )
        # First-level include (absolute path): metrics, plus a nested include
        with open(f"{tmpdirname}/metrics.yml", "w") as f:
            f.write(
                f"""---
metrics:
  test1:
    type: value
    query:
      0: TEST1
  test2:
    type: value
    query:
      0: TEST2
include:
  - more_metrics.yml
"""
            )
        # Second-level include: one more metric
        with open(f"{tmpdirname}/more_metrics.yml", "w") as f:
            f.write(
                f"""---
metrics:
  test3:
    type: value
    query:
      0: TEST3
"""
            )
        pgmon.read_config(f"{tmpdirname}/config.yml")
        self.assertEqual(pgmon.config["max_idle_time"], 10)
        self.assertEqual(pgmon.config["dbuser"], "someone")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")
        self.assertEqual(pgmon.config["metrics"]["test2"]["query"][0], "TEST2")
        self.assertEqual(pgmon.config["metrics"]["test3"]["query"][0], "TEST3")
def test_read_config__reload(self):
    """Re-reading a config file fully replaces the previously loaded
    settings and metrics rather than merging into them."""
    pgmon.config = {}
    # Test rereading a config to update an existing config
    with tempfile.TemporaryDirectory() as tmpdirname:
        # NOTE(review): several settings/metrics lines of this first config
        # were elided by the diff this block was recovered from and are
        # reconstructed (the assertions only require two metrics) -- confirm
        # against version history.
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
# This is a comment!
min_pool_size: 1
max_pool_size: 2
metrics:
  test1:
    type: value
    query:
      0: TEST1
  test2:
    type: value
    query:
      0: TEST2
"""
            )
        pgmon.read_config(f"{tmpdirname}/config.yml")
        # Just make sure the first config was read
        self.assertEqual(len(pgmon.config["metrics"]), 2)

        # Overwrite the file with a smaller config and re-read it
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
# This is a comment!
min_pool_size: 7
metrics:
  test1:
    type: value
    query:
      0: NEW1
"""
            )
        pgmon.read_config(f"{tmpdirname}/config.yml")
        self.assertEqual(pgmon.config["min_pool_size"], 7)
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "NEW1")
        # The old test2 metric must be gone, not merged
        self.assertEqual(len(pgmon.config["metrics"]), 1)
def test_read_config__query_file(self):
    """read_config() resolves 'file:<path>' query values by loading the
    query text from the referenced file."""
    pgmon.config = {}
    # Read a config file that reads a query from a file
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
metrics:
  test1:
    type: value
    query:
      0: file:some_query.sql
"""
            )
        with open(f"{tmpdirname}/some_query.sql", "w") as f:
            f.write("This is a query")
        pgmon.read_config(f"{tmpdirname}/config.yml")
        self.assertEqual(
            pgmon.config["metrics"]["test1"]["query"][0], "This is a query"
        )
def test_read_config__invalid(self):
    """read_config() rejects malformed configs and, on every failure path,
    leaves the previously loaded valid configuration untouched."""
    pgmon.config = {}

    # For all of these tests, we start with a valid config and also ensure
    # that it is not modified when a new config read fails
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
metrics:
  test1:
    type: value
    query:
      0: TEST1
"""
            )
        pgmon.read_config(f"{tmpdirname}/config.yml")
        # Just make sure the config was read
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test reading a nonexistant config file
    with tempfile.TemporaryDirectory() as tmpdirname:
        self.assertRaises(
            FileNotFoundError, pgmon.read_config, f"{tmpdirname}/missing.yml"
        )

    # Test reading an invalid config file
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """[default]
This looks a lot like an ini file to me
Or maybe a TOML?
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )

    # Test reading a config that includes an invalid file
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: value
    query:
      0: EVIL1
include:
  - missing_file.yml
"""
            )
        self.assertRaises(
            FileNotFoundError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        # The failed read must not have clobbered the good config
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test invalid log level
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
log_level: noisy
dbuser: evil
metrics:
  test1:
    type: value
    query:
      0: EVIL1
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test invalid query return type
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: lots_of_data
    query:
      0: EVIL1
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test invalid query dict type
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: lots_of_data
    query: EVIL1
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test incomplete metric: missing type
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    query:
      0: EVIL1
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test incomplete metric: missing queries
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: value
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test incomplete metric: empty queries
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: value
    query: {}
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test incomplete metric: query dict is None
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: value
    query:
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test reading a config with no metrics
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test reading a query defined in a file but the file is missing
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: value
    query:
      0: file:missing.sql
"""
            )
        self.assertRaises(
            FileNotFoundError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")

    # Test invalid query versions
    with tempfile.TemporaryDirectory() as tmpdirname:
        with open(f"{tmpdirname}/config.yml", "w") as f:
            f.write(
                """---
dbuser: evil
metrics:
  test1:
    type: value
    query:
      default: EVIL1
"""
            )
        self.assertRaises(
            pgmon.ConfigError, pgmon.read_config, f"{tmpdirname}/config.yml"
        )
        self.assertEqual(pgmon.config["dbuser"], "postgres")
        self.assertEqual(pgmon.config["metrics"]["test1"]["query"][0], "TEST1")