Add ability to run query tests

This commit is contained in:
James Campbell 2025-05-18 12:52:32 -04:00
parent 8928bba337
commit 529bef9679
Signed by: james
GPG Key ID: 2287C33A40DC906A

View File

@ -4,6 +4,7 @@ import yaml
import json import json
import time import time
import os import os
import sys
import argparse import argparse
import logging import logging
@ -74,6 +75,10 @@ class UnhappyDBError(Exception):
pass pass
class UnknownMetricError(Exception):
pass
class MetricVersionError(Exception): class MetricVersionError(Exception):
pass pass
@ -466,6 +471,54 @@ def get_cluster_version():
return cluster_version return cluster_version
def sample_metric(dbname, metric_name, args, retry=True):
"""
Run the appropriate query for the named metric against the specified database
"""
# Get the metric definition
try:
metric = config["metrics"][metric_name]
except KeyError:
raise UnknownMetricError("Unknown metric: {}".format(metric_name))
# Get the connection pool for the database, or create one if it doesn't
# already exist.
pool = get_pool(dbname)
# Identify the PostgreSQL version
version = get_cluster_version()
# Get the query version
query = get_query(metric, version)
# Execute the quert
if retry:
return run_query(pool, metric["type"], query, args)
else:
return run_query_no_retry(pool, metric["type"], query, args)
def test_queries():
"""
Run all of the metric queries against a database and check the results
"""
# We just use the default db for tests
dbname = config["dbname"]
# Loop through all defined metrics.
for metric_name in config["metrics"].keys():
# Get the actual metric definition
metric = metrics[metric_name]
# If the metric has arguments to use while testing, grab those
args = metric.get("test_args", {})
# Run the query without the ability to retry.
res = sample_metric(dbname, metric_name, args, retry=False)
# Compare the result to the provided sample results
# TODO
# Return the number of errors
# TODO
return 0
class SimpleHTTPRequestHandler(BaseHTTPRequestHandler): class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
""" """
This is our request handling server. It is responsible for listening for This is our request handling server. It is responsible for listening for
@ -494,10 +547,10 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
""" """
# Parse the URL # Parse the URL
parsed_path = urlparse(self.path) parsed_path = urlparse(self.path)
name = parsed_path.path.strip("/") metric_name = parsed_path.path.strip("/")
parsed_query = parse_qs(parsed_path.query) parsed_query = parse_qs(parsed_path.query)
if name == "agent_version": if metric_name == "agent_version":
self._reply(200, VERSION) self._reply(200, VERSION)
return return
@ -505,60 +558,31 @@ class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
# single values, just grab the first from each. # single values, just grab the first from each.
args = {key: values[0] for key, values in parsed_query.items()} args = {key: values[0] for key, values in parsed_query.items()}
# Get the metric definition
try:
metric = config["metrics"][name]
except KeyError:
log.error("Unknown metric: {}".format(name))
self._reply(404, "Unknown metric")
return
# Get the dbname. If none was provided, use the default from the # Get the dbname. If none was provided, use the default from the
# config. # config.
dbname = args.get("dbname", config["dbname"]) dbname = args.get("dbname", config["dbname"])
# Get the connection pool for the database, or create one if it doesn't # Sample the metric
# already exist.
try: try:
pool = get_pool(dbname) self._reply(200, sample_metric(dbname, metric_name, args))
except UnhappyDBError: return
except UnknownMetricError as e:
log.error("Unknown metric: {}".format(metric_name))
self._reply(404, "Unknown metric")
return
except MetricVersionError as e:
log.error(
"Failed to find a version of {} for {}".format(metric_name, version)
)
self._reply(404, "Unsupported version")
return
except UnhappyDBError as e:
log.info("Database {} is unhappy, please be patient".format(dbname)) log.info("Database {} is unhappy, please be patient".format(dbname))
self._reply(503, "Database unavailable") self._reply(503, "Database unavailable")
return return
# Identify the PostgreSQL version
try:
version = get_cluster_version()
except UnhappyDBError:
return
except Exception as e: except Exception as e:
if dbname in unhappy_cooldown: log.error("Error running query: {}".format(e))
log.info("Database {} is unhappy, please be patient".format(dbname)) self._reply(500, "Unexpected error: {}".format(e))
self._reply(503, "Database unavailable")
else:
log.error("Failed to get PostgreSQL version: {}".format(e))
self._reply(500, "Error getting DB version")
return
# Get the query version
try:
query = get_query(metric, version)
except KeyError:
log.error("Failed to find a version of {} for {}".format(name, version))
self._reply(404, "Unsupported version")
return
# Execute the quert
try:
self._reply(200, run_query(pool, metric["type"], query, args))
return
except Exception as e:
if dbname in unhappy_cooldown:
log.info("Database {} is unhappy, please be patient".format(dbname))
self._reply(503, "Database unavailable")
else:
log.error("Error running query: {}".format(e))
self._reply(500, "Error running query")
return return
def _reply(self, code, content): def _reply(self, code, content):
@ -585,6 +609,8 @@ if __name__ == "__main__":
help="The config file to read (default: %(default)s)", help="The config file to read (default: %(default)s)",
) )
parser.add_argument("test", action="store_true", help="Run query tests and exit")
args = parser.parse_args() args = parser.parse_args()
# Set the config file path # Set the config file path
@ -593,6 +619,14 @@ if __name__ == "__main__":
# Read the config file # Read the config file
read_config(config_file) read_config(config_file)
# Run query tests and exit if test mode is enabled
if args.test:
errors = test_queries()
if errors > 0:
sys.exit(1)
else:
sys.exit(0)
# Set up the http server to receive requests # Set up the http server to receive requests
server_address = ("127.0.0.1", config["port"]) server_address = ("127.0.0.1", config["port"])
httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler) httpd = ThreadingHTTPServer(server_address, SimpleHTTPRequestHandler)