Teach json how to serialize decimals

This commit is contained in:
James Campbell 2025-07-03 01:47:06 -04:00
parent 3c39d8aa97
commit 24d1214855
Signed by: james
GPG Key ID: 2287C33A40DC906A
2 changed files with 30 additions and 2 deletions

View File

@ -27,6 +27,8 @@ from urllib.parse import urlparse, parse_qs
import requests import requests
import re import re
from decimal import Decimal
VERSION = "1.0.3" VERSION = "1.0.3"
# Configuration # Configuration
@ -391,6 +393,16 @@ def get_query(metric, version):
raise MetricVersionError("Missing metric query for PostgreSQL {}".format(version)) raise MetricVersionError("Missing metric query for PostgreSQL {}".format(version))
def json_encode_special(obj):
"""
Encoder function to handle types the standard JSON package doesn't know what
to do with
"""
if isinstance(obj, Decimal):
return float(obj)
raise TypeError(f'Cannot serialize object of {type(obj)}')
def run_query_no_retry(pool, return_type, query, args): def run_query_no_retry(pool, return_type, query, args):
""" """
Run the query with no explicit retry code Run the query with no explicit retry code
@ -408,11 +420,11 @@ def run_query_no_retry(pool, return_type, query, args):
elif return_type == "row": elif return_type == "row":
if len(res) == 0: if len(res) == 0:
return "[]" return "[]"
return json.dumps(res[0]) return json.dumps(res[0], default=json_encode_special)
elif return_type == "column": elif return_type == "column":
if len(res) == 0: if len(res) == 0:
return "[]" return "[]"
return json.dumps([list(r.values())[0] for r in res]) return json.dumps([list(r.values())[0] for r in res], default=json_encode_special)
elif return_type == "set": elif return_type == "set":
return json.dumps(res) return json.dumps(res)
except: except:

View File

@ -5,6 +5,8 @@ import tempfile
import logging import logging
from decimal import Decimal
import pgmon import pgmon
# Silence most logging output # Silence most logging output
@ -789,3 +791,17 @@ metrics:
# Make sure we can pull the RSS file (we assume the 9.6 series won't be getting # Make sure we can pull the RSS file (we assume the 9.6 series won't be getting
# any more updates) # any more updates)
self.assertEqual(pgmon.get_latest_version(), 90624) self.assertEqual(pgmon.get_latest_version(), 90624)
def test_json_encode_special(self):
# Confirm that we're getting the right type
self.assertFalse(isinstance(Decimal('0.5'), float))
self.assertTrue(isinstance(pgmon.json_encode_special(Decimal('0.5')), float))
# Make sure we get sane values
self.assertEqual(pgmon.json_encode_special(Decimal('0.5')), 0.5)
self.assertEqual(pgmon.json_encode_special(Decimal('12')), 12.0)
# Make sure we can still fail for other types
self.assertRaises(
TypeError, pgmon.json_encode_special, object
)