From 28179d8937071c1153c5136a00cf2109400573b2 Mon Sep 17 00:00:00 2001 From: Pratik Garg Date: Tue, 12 May 2026 16:11:59 -0700 Subject: [PATCH] Implement Prometheus metrics emission for PyGrain. PiperOrigin-RevId: 914530722 --- grain/__init__.py | 4 +- grain/_src/core/BUILD | 35 +- grain/_src/core/config.py | 5 +- grain/_src/core/monitoring.py | 167 ++++--- grain/_src/core/monitoring_base.py | 96 ++++ grain/_src/core/prometheus_monitoring.py | 412 ++++++++++++++++++ grain/_src/core/prometheus_monitoring_test.py | 326 ++++++++++++++ grain/_src/python/data_loader.py | 6 +- grain/_src/python/data_sources.py | 5 +- grain/_src/python/dataset/dataset.py | 5 +- grain/_src/python/dataset/stats.py | 6 +- grain/_src/python/load.py | 5 +- grain/_src/python/samplers.py | 5 +- test_requirements.in | 1 + test_requirements_lock_3_11.txt | 4 + test_requirements_lock_3_12.txt | 4 + test_requirements_lock_3_13.txt | 4 + test_requirements_lock_3_14.txt | 4 + 18 files changed, 1004 insertions(+), 90 deletions(-) create mode 100644 grain/_src/core/monitoring_base.py create mode 100644 grain/_src/core/prometheus_monitoring.py create mode 100644 grain/_src/core/prometheus_monitoring_test.py diff --git a/grain/__init__.py b/grain/__init__.py index 4ebc6192a..d07ee37d7 100644 --- a/grain/__init__.py +++ b/grain/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. """Public API for Grain.""" - # pylint: disable=g-importing-member # pylint: disable=unused-import # pylint: disable=g-multiple-import @@ -32,6 +31,7 @@ transforms, ) +from grain._src.core import monitoring from grain._src.core.config import config from grain._src.core.version import __version__, __version_info__ from grain._src.python.data_loader import ( @@ -46,3 +46,5 @@ from grain._src.python.load import load from grain._src.python.options import ReadOptions, MultiprocessingOptions from grain._src.python.record import Record, RecordMetadata + +monitoring.setup_telemetry() diff --git a/grain/_src/core/BUILD b/grain/_src/core/BUILD index 0e820b107..dfde8f0a0 100644 --- a/grain/_src/core/BUILD +++ b/grain/_src/core/BUILD @@ -27,11 +27,28 @@ py_library( srcs_version = "PY3", ) +py_library( + name = "monitoring_base", + srcs = ["monitoring_base.py"], + srcs_version = "PY3", +) + py_library( name = "monitoring", - srcs = ["monitoring.py"], + srcs = [ + "monitoring.py", + "prometheus_monitoring.py", + ], srcs_version = "PY3", + tags = [ + "ignore_for_dep=grain._src.core.fast_monitoring", + "ignore_for_dep=grain._src.core.google.streamz_monitoring", + "ignore_for_dep=grain._src.core.monitoring_base", + "ignore_for_dep=grain._src.core.prometheus_monitoring", + ], deps = [ + ":monitoring_base", + "@pypi//prometheus_client:pkg", ], ) @@ -255,3 +272,19 @@ py_test( "@pypi//cloudpickle:pkg", ], ) + +py_test( + name = "prometheus_monitoring_test", + srcs = ["prometheus_monitoring_test.py"], + srcs_version = "PY3", + tags = [ + "ignore_for_dep=grain._src.core.google_monitoring", + "ignore_for_dep=grain._src.core.monitoring_base", + "ignore_for_dep=grain._src.core.prometheus_monitoring", + ], + deps = [ + ":monitoring", + "@abseil-py//absl/testing:absltest", + "@pypi//prometheus_client:pkg", + ], +) diff --git a/grain/_src/core/config.py b/grain/_src/core/config.py index e062073ed..0a6878c8f 100644 --- a/grain/_src/core/config.py +++ b/grain/_src/core/config.py @@ -21,8 +21,6 @@ from typing import Any from absl import flags -from grain._src.core import monitoring as grain_monitoring - from grain._src.core import monitoring # Performance optimisations. Consider most of these experimental. We might @@ -148,9 +146,8 @@ _grain_experiment_metric = monitoring.Metric( "/grain/experiment", - value_type=int, + int, metadata=monitoring.Metadata(description="Grain experiment opt-in metric."), - root=grain_monitoring.get_monitoring_root(), fields=[("name", str)], ) diff --git a/grain/_src/core/monitoring.py b/grain/_src/core/monitoring.py index d98147479..b28f2d775 100644 --- a/grain/_src/core/monitoring.py +++ b/grain/_src/core/monitoring.py @@ -1,58 +1,109 @@ -"""Grain metrics.""" - -import enum - - -@enum.unique -class Units(enum.Enum): - """Grain metric units.""" - - NANOSECONDS = enum.auto() - MILLISECONDS = enum.auto() - - -# pylint: disable=invalid-name -class NoOpMetric: - """Grain metric no-op implementation.""" - - def __init__(self, *args, **kwargs): - del args, kwargs - - def IncrementBy(self, *args, **kwargs): - del args, kwargs - - def Increment(self, *args, **kwargs): - self.IncrementBy(1, *args, **kwargs) - - def Set(self, *args, **kwargs): - del args, kwargs - - def Record(self, *args, **kwargs): - del args, kwargs - - def Get(self, *args, **kwargs): - del args, kwargs - - -class Metadata: - """Grain metric no-op metadata.""" - - def __init__(self, *args, **kwargs): - del args, kwargs - - -# pylint: disable=invalid-name -class Bucketer: - """Grain metric no-op bucketer.""" - - def __init__(self, *args, **kwargs): - del args, kwargs - - def PowersOf(self, *args, **kwargs): - del args, kwargs - - -Counter = Metric = EventMetric = NoOpMetric - -def get_monitoring_root() -> None: - return None +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Grain metrics convenience functions and implementation selection.""" + +from __future__ import annotations + +from grain._src.core import monitoring_base + +from grain._src.core import prometheus_monitoring as _impl + +Bucketer = monitoring_base.Bucketer +Metadata = monitoring_base.Metadata +Units = monitoring_base.Units + +Counter = _impl.Counter +EventMetric = _impl.EventMetric +Metric = _impl.Metric +get_monitoring_root = _impl.get_monitoring_root +setup_telemetry = _impl.SetupTelemetry + + +def record_autotune_node_throughput(node_name: str, throughput: float): + _grain_autotune_node_throughput.Set(throughput, node_name) + + +def record_autotune_parameter(node_name: str, name: str, value: float): + _grain_autotune_parameters.Set(value, node_name, name) + + +def record_autotune_usl_coeff(node_name: str, name: str, value: float): + _grain_autotune_usl_coeffs.Set(value, node_name, name) + + +def record_autotune_optimization_latency(latency_ms: float): + _grain_autotune_optimization_latency.Record(latency_ms) + + +def record_framework_type(framework_type: str): + _grain_framework_type_metric.Increment(framework_type) + + +def set_debug_server_ports(ports: list[int]): + if not ports: + return + _grain_debug_server_ports.Set(','.join(map(str, ports))) + + +_grain_framework_type_metric = Counter( + '/grain/framework_type', + Metadata(description='The framework type used to build the Grain dataset.'), + fields=[('name', str)], +) + +_grain_debug_server_ports = Metric( + '/grain/debug_server_ports', + str, + metadata=Metadata( + description=( + 'A comma-separated list of debug server ports. The first port is' + ' for the main process, followed by worker process ports.' + ) + ), +) + + +_grain_autotune_node_throughput = Metric( + '/grain/autotune/node_throughput', + float, + metadata=Metadata( + description='The observed throughput of the autotune node (elements/s).' + ), + fields=[('node_name', str)], +) + +_grain_autotune_parameters = Metric( + '/grain/autotune/parameters', + float, + metadata=Metadata( + description='The current values for the autotuned parameters.' + ), + fields=[('node_name', str), ('name', str)], +) + +_grain_autotune_usl_coeffs = Metric( + '/grain/autotune/usl_coeffs', + float, + metadata=Metadata( + description='The estimated USL coefficients for the autotune node.' + ), + fields=[('node_name', str), ('name', str)], +) + +_grain_autotune_optimization_latency = EventMetric( + '/grain/autotune/optimization_latency', + metadata=Metadata( + description='The time taken by the autotune optimizer (ms).' + ), +) diff --git a/grain/_src/core/monitoring_base.py b/grain/_src/core/monitoring_base.py new file mode 100644 index 000000000..3bfe19ff4 --- /dev/null +++ b/grain/_src/core/monitoring_base.py @@ -0,0 +1,96 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Grain metrics base definitions and protocols.""" + +from __future__ import annotations + +import enum +from typing import Any, Protocol + + +@enum.unique +class Units(enum.Enum): + """Grain metric units.""" + + SECONDS = enum.auto() + MILLISECONDS = enum.auto() + MICROSECONDS = enum.auto() + NANOSECONDS = enum.auto() + BITS = enum.auto() + BYTES = enum.auto() + + +class Metadata: + """Grain metric metadata.""" + + def __init__(self, description='', **kwargs): + self.description = description + for key, value in kwargs.items(): + setattr(self, key, value) + self._kwargs = kwargs + + +class Bucketer: + """Grain metric bucketer.""" + + def __init__(self, *args, bucketer_type=None, **kwargs): + self.args = args + self.kwargs = kwargs + self.type = bucketer_type + + @staticmethod + def PowersOf(base: float): + return Bucketer(base, bucketer_type='PowersOf') + + +class CounterProtocol(Protocol): + """Protocol for Counter metrics.""" + + def Increment(self, *args: Any, **kwargs: Any) -> None: + ... + + def IncrementBy(self, value: float | int, *args: Any, **kwargs: Any) -> None: + ... + + def Get(self, *args: Any, **kwargs: Any) -> float | int: + ... + + def ClearAll(self) -> None: + ... + + +class MetricProtocol(Protocol): + """Protocol for Gauge metrics.""" + + def Set(self, value: float | int, *args: Any, **kwargs: Any) -> None: + ... + + def Get(self, *args: Any, **kwargs: Any) -> float | int: + ... + + def ClearAll(self) -> None: + ... + + +class EventMetricProtocol(Protocol): + """Protocol for Event/Histogram metrics.""" + + def Record(self, value: float | int, *args: Any, **kwargs: Any) -> None: + ... + + def Get(self, *args: Any, **kwargs: Any) -> float | int: + ... + + def ClearAll(self) -> None: + ... diff --git a/grain/_src/core/prometheus_monitoring.py b/grain/_src/core/prometheus_monitoring.py new file mode 100644 index 000000000..a4f45888c --- /dev/null +++ b/grain/_src/core/prometheus_monitoring.py @@ -0,0 +1,412 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Grain metrics Prometheus implementation.""" + +from __future__ import annotations + +import atexit +import errno +import logging +import multiprocessing +import os +import shutil +import tempfile +import threading +from typing import Any, cast + +from grain._src.core import monitoring_base + +Bucketer = monitoring_base.Bucketer +Metadata = monitoring_base.Metadata +Units = monitoring_base.Units + +_DEFAULT_PROMETHEUS_PORT = 9431 + +# Keep a global reference so the directory is not deleted until the program +# exits. +_prometheus_multiproc_dir = None + +try: + # pylint: disable=g-import-not-at-top + import prometheus_client # pytype: disable=import-error + from prometheus_client import multiprocess # pytype: disable=import-error + + prometheus_client = cast(Any, prometheus_client) + multiprocess = cast(Any, multiprocess) + + _prom_counter = prometheus_client.Counter + _prom_gauge = prometheus_client.Gauge + _prom_histogram = prometheus_client.Histogram +except (ImportError, AttributeError): + prometheus_client = None + _prom_counter = None + _prom_gauge = None + _prom_histogram = None + +_initialized = False +_prometheus_metrics = {} # Maps metric names to their Prometheus objects. +_lock = threading.Lock() + +_PROMETHEUS_ALLOWED_METRICS = { + '/grain/python/dataset/next_duration_ns', + '/grain/python/dataset/prefetch_buffer_ready_count', + '/grain/python/data_sources/bytes_read', + '/grain/python/dataset/source_read_time_ns', +} + + +def _IsAllowed(metric_name: str) -> bool: + """Returns True if the metric is allowed for Prometheus export.""" + return metric_name in _PROMETHEUS_ALLOWED_METRICS + + +# pylint: disable=invalid-name +def get_monitoring_root() -> None: + """Returns None as Prometheus does not have a monitoring root.""" + return None + + +# pylint: enable=invalid-name + + +_METADATA_TYPES = (Metadata,) + + +def _ExtractMetadataAndFields( + metadata_candidate: Any, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> tuple[Any, list[tuple[str, Any]], tuple[Any, ...]]: + """Robustly extracts metadata and fields from positional or keyword args.""" + _metadata = kwargs.pop('metadata', None) + _fields = kwargs.pop('fields', None) + + all_pos = [] + if metadata_candidate is not None: + all_pos.append(metadata_candidate) + all_pos.extend(args) + + remaining = [] + for arg in all_pos: + if _metadata is None and isinstance(arg, _METADATA_TYPES): + _metadata = arg + elif _fields is None and isinstance(arg, (list, tuple)): + _fields = arg + else: + remaining.append(arg) + + return _metadata, (_fields or []), tuple(remaining) + + +class Counter: + """Prometheus Counter wrapper.""" + + def __init__(self, name, metadata=None, *args, **kwargs): + _metadata, _fields, _ = _ExtractMetadataAndFields(metadata, args, kwargs) + + if not _IsAllowed(name) or not _prom_counter: + self._metric = None + return + + prom_name = name.strip('/').replace('/', '_') + labelnames = tuple(f[0] for f in _fields) + description = 'Grain Counter' + if _metadata and hasattr(_metadata, 'description'): + description = _metadata.description + + with _lock: + if prom_name not in _prometheus_metrics: + _prometheus_metrics[prom_name] = _prom_counter( + prom_name, description, labelnames=labelnames + ) + self._metric = _prometheus_metrics[prom_name] + + def Increment(self, *args, **kwargs): + if not self._metric: + return + if not args and not kwargs: + self._metric.inc() + return + try: + self._metric.labels(*args, **kwargs).inc() + except ValueError as e: + logging.warning( + 'Failed to record Prometheus event due to label mismatch: %s.', e + ) + + def IncrementBy(self, value, *args, **kwargs): + if not self._metric: + return + if not args and not kwargs: + self._metric.inc(value) + return + try: + self._metric.labels(*args, **kwargs).inc(value) + except ValueError as e: + logging.warning( + 'Failed to record Prometheus event due to label mismatch: %s.', e + ) + + def Get(self, *args, **kwargs): + """Returns 0.0 as Prometheus histograms do not support Get.""" + del args, kwargs + return 0.0 + + def ClearAll(self): + """No-op for Prometheus histograms.""" + pass + + +class Metric: + """Prometheus Gauge wrapper.""" + + def __init__(self, name, value_type=float, metadata=None, *args, **kwargs): + del value_type + _metadata, _fields, _ = _ExtractMetadataAndFields(metadata, args, kwargs) + + if not _IsAllowed(name) or not _prom_gauge: + self._metric = None + return + + prom_name = name.strip('/').replace('/', '_') + labelnames = tuple(f[0] for f in _fields) + description = 'Grain Gauge' + if _metadata and hasattr(_metadata, 'description'): + description = _metadata.description + + with _lock: + if prom_name not in _prometheus_metrics: + _prometheus_metrics[prom_name] = _prom_gauge( + prom_name, description, labelnames=labelnames + ) + self._metric = _prometheus_metrics[prom_name] + + def Set(self, value, *args, **kwargs): + if not self._metric: + return + if not args and not kwargs: + self._metric.set(value) + return + try: + self._metric.labels(*args, **kwargs).set(value) + except ValueError as e: + logging.warning( + 'Failed to record Prometheus gauge due to label mismatch: %s.', e + ) + + def Get(self, *args, **kwargs): + """Returns 0.0 as Prometheus histograms do not support Get.""" + del args, kwargs + return 0.0 + + def ClearAll(self): + """No-op for Prometheus histograms.""" + pass + + +class EventMetric: + """Prometheus Histogram wrapper.""" + + def __init__(self, name, metadata=None, *args, **kwargs): + _metadata, _fields, _ = _ExtractMetadataAndFields(metadata, args, kwargs) + + if not _IsAllowed(name) or not _prom_histogram: + self._metric = None + return + + prom_name = name.strip('/').replace('/', '_') + labelnames = tuple(f[0] for f in _fields) + description = 'Grain Histogram' + if _metadata and hasattr(_metadata, 'description'): + description = _metadata.description + + # Support both 'buckets' and 'bucketer' keywords. + buckets = kwargs.get('buckets') + bucketer = kwargs.get('bucketer') + + if buckets is None and isinstance(bucketer, Bucketer): + units = getattr(_metadata, 'units', None) + buckets = self._GetPrometheusBuckets(bucketer, units=units) + + with _lock: + if prom_name not in _prometheus_metrics: + construct_kwargs = {'labelnames': labelnames} + if buckets: + construct_kwargs['buckets'] = buckets + _prometheus_metrics[prom_name] = _prom_histogram( + prom_name, description, **construct_kwargs + ) + self._metric = _prometheus_metrics[prom_name] + + def _GetPrometheusBuckets( + self, bucketer: Bucketer, units: Units | None = None + ) -> list[float] | None: + """Generates Prometheus buckets from the bucketer definition.""" + if bucketer.type != 'PowersOf' or not bucketer.args: + return None + + base = bucketer.args[0] + if base <= 1.0: + logging.warning( + 'Prometheus Bucketer.PowersOf requires base > 1.0, got %f', base + ) + return None + + # Determine the maximum value for the buckets. + # Default to a very large number if units are not time-based. + max_val = float('inf') + start_val = 1.0 + if units == Units.SECONDS: + max_val = 3600.0 + # Start at 1 millisecond for seconds-based metrics. + start_val = 0.001 + elif units == Units.MILLISECONDS or units is None: + # Default to milliseconds (3.6M) if units are MILLISECONDS or None. + max_val = 3600000.0 + start_val = 1.0 + elif units == Units.MICROSECONDS: + max_val = 3600000000.0 + start_val = 1.0 + elif units == Units.NANOSECONDS: + max_val = 3600000000000.0 + start_val = 1.0 + + buckets = [] + val = start_val + # Add a safeguard to prevent excessively many buckets, even with inf + # max_val. + # A practical limit for Prometheus histograms is around 1e18. + practical_limit = 1e18 + while val < min(max_val, practical_limit): + buckets.append(val) + val *= base + return buckets + + def Record(self, value, *args, **kwargs): + """Records a value in the Prometheus histogram.""" + if not self._metric: + return + if not args and not kwargs: + self._metric.observe(value) + return + try: + self._metric.labels(*args, **kwargs).observe(value) + except ValueError as e: + logging.warning( + 'Failed to record Prometheus histogram due to label mismatch: %s.', + e, + ) + + def Get(self, *args, **kwargs): + """Returns 0.0 as Prometheus histograms do not support Get.""" + del args, kwargs + return 0.0 + + def ClearAll(self): + """No-op for Prometheus histograms.""" + pass + + +def Initialize(port=None): + """Initializes PyGrain metric reporting.""" + global _initialized + if _initialized: + return + + if not prometheus_client: + return + + if port is None: + env_port = os.environ.get('PYGRAIN_PROMETHEUS_PORT') + port = _DEFAULT_PROMETHEUS_PORT + if env_port: + try: + port = int(env_port) + except ValueError: + logging.warning( + 'Invalid PYGRAIN_PROMETHEUS_PORT "%s". Falling back to default %d.', + env_port, + _DEFAULT_PROMETHEUS_PORT, + ) + if not prometheus_client: + logging.warning( + 'prometheus-client not found. Grain metrics will not be reported to' + ' Prometheus.' + ) + return + + with _lock: + if _initialized: + return + if port <= 0: + _initialized = True + return + + try: + # If multiprocess directory is configured, use MultiProcessCollector + # to aggregate metrics from all worker processes. + multiprocess_started = False + if 'PROMETHEUS_MULTIPROC_DIR' in os.environ: + try: + registry = prometheus_client.CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + prometheus_client.start_http_server(port, registry=registry) + logging.info( + 'Prometheus multiprocess metrics server started on port %s.', + port, + ) + multiprocess_started = True + except (ImportError, AttributeError): + pass + + if not multiprocess_started: + # Standard single-process server + prometheus_client.start_http_server(port) + logging.info('Prometheus metrics server started on port %s.', port) + except ValueError as e: + logging.warning('Failed to start Prometheus server: %s', e) + return + except OSError as e: + if e.errno != errno.EADDRINUSE: + logging.warning('Failed to start Prometheus server: %s', e) + return + logging.info('Prometheus server already active.') + + _initialized = True + + +def SetupTelemetry(): + """Autostarts Prometheus metrics server if enabled via environment variable.""" + enable_telemetry = os.environ.get( + 'ENABLE_PYGRAIN_PROMETHEUS_TELEMETRY', 'false' + ) + if enable_telemetry.lower() == 'true': + global _prometheus_multiproc_dir + + if 'PROMETHEUS_MULTIPROC_DIR' not in os.environ: + # Create a directory for prometheus multiprocessing. + _prometheus_multiproc_dir = tempfile.mkdtemp( + prefix='prometheus_multiproc_' + ) + os.environ['PROMETHEUS_MULTIPROC_DIR'] = _prometheus_multiproc_dir + _creator_pid = os.getpid() + + def _Cleanup(): + if os.getpid() == _creator_pid: + shutil.rmtree(_prometheus_multiproc_dir, ignore_errors=True) + + atexit.register(_Cleanup) + + if multiprocessing.current_process().name == 'MainProcess': + Initialize(port=9431) + else: + Initialize(port=0) diff --git a/grain/_src/core/prometheus_monitoring_test.py b/grain/_src/core/prometheus_monitoring_test.py new file mode 100644 index 000000000..aecfcb741 --- /dev/null +++ b/grain/_src/core/prometheus_monitoring_test.py @@ -0,0 +1,326 @@ +# pytype: skip-file +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import errno +import os +import tempfile +from unittest import mock + +from absl.testing import absltest + +from grain._src.core import monitoring +from grain._src.core import prometheus_monitoring + +# pylint: disable=g-import-not-at-top +try: + import prometheus_client +except ImportError: + prometheus_client = None + + +class MonitoringTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.enter_context( + mock.patch.object(prometheus_monitoring, '_initialized', False) + ) + self.enter_context( + mock.patch.object(prometheus_monitoring, '_prometheus_metrics', {}) + ) + + if prometheus_client: + # Clear registry for hermetic tests. + registry = prometheus_client.REGISTRY + collector_to_names = getattr(registry, '_collector_to_names', None) + if collector_to_names: + for collector in list(collector_to_names): + registry.unregister(collector) + + def test_initialize_prometheus_server_called_once(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + with mock.patch.object( + prometheus_client, 'start_http_server', autospec=True + ) as mock_start_http_server: + prometheus_monitoring.Initialize(9431) + mock_start_http_server.assert_called_once_with(9431) + prometheus_monitoring.Initialize(9431) + # Still once, because _initialized is True in this context. + mock_start_http_server.assert_called_once_with(9431) + + def test_counter_prometheus_routing(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/grain/test/counter'}, + ): + c = prometheus_monitoring.Counter( + '/grain/test/counter', + metadata=monitoring.Metadata(description='My counter'), + ) + c.Increment() + c.IncrementBy(5) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value( + 'grain_test_counter_total' + ), + 6, + ) + + def test_counter_with_labels_prometheus_routing(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/grain/test/counter_labels'}, + ): + c = prometheus_monitoring.Counter( + '/grain/test/counter_labels', + metadata=monitoring.Metadata(description='Labeled counter'), + fields=[('label1', str), ('label2', int)], + ) + c.Increment('a', 1) + c.IncrementBy(2, 'b', 2) + c.Increment('a', 1) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value( + 'grain_test_counter_labels_total', {'label1': 'a', 'label2': '1'} + ), + 2, + ) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value( + 'grain_test_counter_labels_total', {'label1': 'b', 'label2': '2'} + ), + 2, + ) + + def test_gauge_prometheus_routing(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/grain/test/gauge'}, + ): + g = prometheus_monitoring.Metric( + '/grain/test/gauge', + metadata=monitoring.Metadata(description='My gauge'), + ) + g.Set(5) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value('grain_test_gauge'), + 5, + ) + + def test_histogram_prometheus_routing(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/grain/test/histogram'}, + ): + h = prometheus_monitoring.EventMetric( + '/grain/test/histogram', + metadata=monitoring.Metadata(description='My histogram'), + ) + h.Record(10) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value( + 'grain_test_histogram_count' + ), + 1, + ) + self.assertEqual( + prometheus_client.REGISTRY.get_sample_value( + 'grain_test_histogram_sum' + ), + 10, + ) + + def test_ignore_non_allowed_metrics_prometheus(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + # Ensure allowlist is intact and does NOT include the test metric + c = prometheus_monitoring.Counter( + '/not_allowed/counter', metadata=monitoring.Metadata(description='d') + ) + c.Increment() + self.assertIsNone( + prometheus_client.REGISTRY.get_sample_value('not_allowed_counter_total') + ) + + def test_initialize_multiprocess(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with tempfile.TemporaryDirectory() as tmp_dir: + with mock.patch.dict(os.environ, {'PROMETHEUS_MULTIPROC_DIR': tmp_dir}): + with mock.patch.object( + prometheus_client, 'start_http_server' + ) as mock_start: + prometheus_monitoring.Initialize(9431) + mock_start.assert_called_once() + args, kwargs = mock_start.call_args + self.assertEqual(args[0], 9431) + self.assertIn('registry', kwargs) + + def test_setup_telemetry_main_process(self): + mock_process = mock.Mock() + mock_process.name = 'MainProcess' + with mock.patch.dict( + os.environ, {'ENABLE_PYGRAIN_PROMETHEUS_TELEMETRY': 'true'} + ): + with mock.patch( + 'multiprocessing.current_process', return_value=mock_process + ): + with mock.patch.object( + prometheus_monitoring, 'Initialize' + ) as mock_init: + prometheus_monitoring.SetupTelemetry() + mock_init.assert_called_once_with(port=9431) + + def test_setup_telemetry_worker_process(self): + mock_process = mock.Mock() + mock_process.name = 'Worker-1' + with mock.patch.dict( + os.environ, {'ENABLE_PYGRAIN_PROMETHEUS_TELEMETRY': 'true'} + ): + with mock.patch( + 'multiprocessing.current_process', return_value=mock_process + ): + with mock.patch.object( + prometheus_monitoring, 'Initialize' + ) as mock_init: + prometheus_monitoring.SetupTelemetry() + mock_init.assert_called_once_with(port=0) + + def test_initialize_port_already_in_use(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_client, + 'start_http_server', + side_effect=OSError(errno.EADDRINUSE, 'Address already in use'), + ): + with self.assertLogs(level='INFO') as log: + prometheus_monitoring.Initialize(9431) + self.assertTrue(getattr(prometheus_monitoring, '_initialized')) + self.assertTrue( + any('Prometheus server already active' in m for m in log.output) + ) + + def test_initialize_other_oserror(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_client, + 'start_http_server', + side_effect=OSError(errno.EINVAL, 'Some other error'), + ): + with self.assertLogs(level='WARNING') as log: + prometheus_monitoring.Initialize(9431) + self.assertFalse(getattr(prometheus_monitoring, '_initialized')) + self.assertTrue( + any('Failed to start Prometheus server' in m for m in log.output) + ) + + def test_ignore_non_allowed_metrics_prometheus_gauge(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + g = prometheus_monitoring.Metric( + '/not_allowed/gauge', metadata=monitoring.Metadata(description='d') + ) + g.Set(5) + self.assertIsNone( + prometheus_client.REGISTRY.get_sample_value('not_allowed_gauge') + ) + + def test_ignore_non_allowed_metrics_prometheus_event(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + h = prometheus_monitoring.EventMetric( + '/not_allowed/histogram', + metadata=monitoring.Metadata(description='d'), + ) + h.Record(5) + self.assertIsNone( + prometheus_client.REGISTRY.get_sample_value( + 'not_allowed_histogram_count' + ) + ) + + def test_missing_prometheus_metrics(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_monitoring, '_PROMETHEUS_ALLOWED_METRICS', {'/grain/test'} + ): + with mock.patch.object( + prometheus_monitoring, '_prom_counter', None + ), mock.patch.object( + prometheus_monitoring, '_prom_gauge', None + ), mock.patch.object( + prometheus_monitoring, '_prom_histogram', None + ): + c = prometheus_monitoring.Counter('/grain/test') + self.assertIsNone(c._metric) + m = prometheus_monitoring.Metric('/grain/test') + self.assertIsNone(m._metric) + e = prometheus_monitoring.EventMetric('/grain/test') + self.assertIsNone(e._metric) + + def test_histogram_with_bucketer_prometheus_routing(self): + if prometheus_client is None: + self.skipTest('prometheus-client not installed') + + with mock.patch.object( + prometheus_monitoring, + '_PROMETHEUS_ALLOWED_METRICS', + {'/grain/test/histogram_bucketer'}, + ): + bucketer = monitoring.Bucketer.PowersOf(2.0) + h = prometheus_monitoring.EventMetric( + '/grain/test/histogram_bucketer', + metadata=monitoring.Metadata( + description='d', units=monitoring.Units.MILLISECONDS + ), + bucketer=bucketer, + ) + prometheus_metric = getattr(h, '_metric', None) + if prometheus_metric: + upper_bounds = getattr(prometheus_metric, '_upper_bounds', None) + if upper_bounds: + self.assertIn(1.0, upper_bounds) + self.assertIn(2.0, upper_bounds) + self.assertIn(4.0, upper_bounds) + + +if __name__ == '__main__': + absltest.main() diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index 65851a50e..56cfeb136 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -25,7 +25,7 @@ from typing import Any, Awaitable, Callable, Optional, Sequence, TypeVar from etils import epath -from grain._src.core import monitoring as grain_monitoring +from grain._src.core import monitoring from grain._src.core import sharding from grain._src.core import transforms from grain._src.core import tree_lib @@ -41,13 +41,10 @@ from grain._src.python.ipc import shared_memory_array import numpy as np -from grain._src.core import monitoring - _api_usage_counter = monitoring.Counter( "/grain/python/data_loader/api", monitoring.Metadata(description="API initialization counter."), - root=grain_monitoring.get_monitoring_root(), fields=[("name", str)], ) _iterator_get_next_metric = monitoring.EventMetric( @@ -56,7 +53,6 @@ description="Gauge for DataLoaderIterator.__next__() latency.", units=monitoring.Units.NANOSECONDS, ), - root=grain_monitoring.get_monitoring_root(), ) _T = TypeVar("_T") diff --git a/grain/_src/python/data_sources.py b/grain/_src/python/data_sources.py index 3448c7344..f885daeb8 100644 --- a/grain/_src/python/data_sources.py +++ b/grain/_src/python/data_sources.py @@ -32,11 +32,9 @@ from absl import logging from etils import epath - -from grain._src.core import monitoring as grain_monitoring +from grain._src.core import monitoring from grain._src.python.dataset import stats as dataset_stats -from grain._src.core import monitoring # pylint: disable=g-bad-import-order # pylint: disable=g-import-not-at-top, g-importing-member, g-bad-import-order import platform @@ -55,7 +53,6 @@ def __init__(self, *args, **kwargs): _api_usage_counter = monitoring.Counter( "/grain/python/data_sources/api", monitoring.Metadata(description="API initialization counter."), - root=grain_monitoring.get_monitoring_root(), fields=[("name", str)], ) diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index ed49bfe1e..1250a23cd 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -56,7 +56,7 @@ from etils import epath from grain._src.core import config as grain_config # pylint: disable=g-importing-member. -from grain._src.core import monitoring as grain_monitoring +from grain._src.core import monitoring from grain._src.core import traceback_util from grain._src.core import transforms from grain._src.python import options as grain_options @@ -66,15 +66,12 @@ from grain.proto import execution_summary_pb2 import numpy as np -from grain._src.core import monitoring - _api_usage_counter = monitoring.Counter( "/grain/python/lazy_dataset/api", metadata=monitoring.Metadata( description="Lazy Dataset API initialization counter." ), - root=grain_monitoring.get_monitoring_root(), fields=[("name", str)], ) diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py index 472ee0ba6..d350399c4 100644 --- a/grain/_src/python/dataset/stats.py +++ b/grain/_src/python/dataset/stats.py @@ -34,15 +34,13 @@ from absl import logging from grain._src.core import config as grain_config -from grain._src.core import monitoring as grain_monitoring +from grain._src.core import monitoring from grain._src.core import profiler from grain._src.core import tree_lib from grain._src.python.dataset import base from grain._src.python.dataset import stats_utils from grain.proto import execution_summary_pb2 -from grain._src.core import monitoring - # Registry of weak references to output dataset iterators for collecting # execution stats. @@ -67,7 +65,6 @@ class NodeType(enum.Enum): ), units=monitoring.Units.NANOSECONDS, ), - root=grain_monitoring.get_monitoring_root(), fields=[("name", str)], bucketer=monitoring.Bucketer.PowersOf(2.0), ) @@ -82,7 +79,6 @@ class NodeType(enum.Enum): ), units=monitoring.Units.NANOSECONDS, ), - root=grain_monitoring.get_monitoring_root(), bucketer=monitoring.Bucketer.PowersOf(2.0), ) diff --git a/grain/_src/python/load.py b/grain/_src/python/load.py index f78ba3ebd..3bba3f65f 100644 --- a/grain/_src/python/load.py +++ b/grain/_src/python/load.py @@ -2,7 +2,7 @@ from typing import Optional -from grain._src.core import monitoring as grain_monitoring +from grain._src.core import monitoring from grain._src.core import sharding from grain._src.core import transforms from grain._src.python import data_loader @@ -10,13 +10,10 @@ from grain._src.python import samplers from grain._src.python.dataset import base as dataset_base -from grain._src.core import monitoring - _api_usage_counter = monitoring.Counter( "/grain/python/load/api", monitoring.Metadata(description="API initialization counter."), - root=grain_monitoring.get_monitoring_root(), fields=[("name", str)], ) diff --git a/grain/_src/python/samplers.py b/grain/_src/python/samplers.py index b7deb8d6c..761f2b2e5 100644 --- a/grain/_src/python/samplers.py +++ b/grain/_src/python/samplers.py @@ -16,21 +16,18 @@ import sys from typing import Optional, Protocol -from grain._src.core import monitoring as grain_monitoring +from grain._src.core import monitoring from grain._src.core import sharding from grain._src.python import record from grain._src.python.dataset import dataset import numpy as np -from grain._src.core import monitoring - _api_usage_counter = monitoring.Counter( "/grain/python/samplers/api", metadata=monitoring.Metadata( description="Sampler API initialization counter." ), - root=grain_monitoring.get_monitoring_root(), fields=[("name", str)], ) diff --git a/test_requirements.in b/test_requirements.in index fdd357333..8d706b674 100644 --- a/test_requirements.in +++ b/test_requirements.in @@ -22,6 +22,7 @@ etils[epath,epy] cloudpickle jax numpy +prometheus-client attrs pyarrow pytest diff --git a/test_requirements_lock_3_11.txt b/test_requirements_lock_3_11.txt index 3f40bc9c0..3d61ed814 100644 --- a/test_requirements_lock_3_11.txt +++ b/test_requirements_lock_3_11.txt @@ -218,6 +218,10 @@ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r test_requirements.in +prometheus-client==0.21.1 \ + --hash=sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb \ + --hash=sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301 + # via -r test_requirements.in psutil==7.2.2 \ --hash=sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372 \ --hash=sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9 \ diff --git a/test_requirements_lock_3_12.txt b/test_requirements_lock_3_12.txt index f0470c085..06741f589 100644 --- a/test_requirements_lock_3_12.txt +++ b/test_requirements_lock_3_12.txt @@ -218,6 +218,10 @@ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r test_requirements.in +prometheus-client==0.21.1 \ + --hash=sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb \ + --hash=sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301 + # via -r test_requirements.in psutil==7.2.2 \ --hash=sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372 \ --hash=sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9 \ diff --git a/test_requirements_lock_3_13.txt b/test_requirements_lock_3_13.txt index bd6161c78..225245bd7 100644 --- a/test_requirements_lock_3_13.txt +++ b/test_requirements_lock_3_13.txt @@ -218,6 +218,10 @@ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r test_requirements.in +prometheus-client==0.21.1 \ + --hash=sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb \ + --hash=sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301 + # via -r test_requirements.in psutil==7.2.2 \ --hash=sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372 \ --hash=sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9 \ diff --git a/test_requirements_lock_3_14.txt b/test_requirements_lock_3_14.txt index abfab8ef6..9abdfadff 100644 --- a/test_requirements_lock_3_14.txt +++ b/test_requirements_lock_3_14.txt @@ -218,6 +218,10 @@ portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r test_requirements.in +prometheus-client==0.21.1 \ + --hash=sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb \ + --hash=sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301 + # via -r test_requirements.in psutil==7.2.2 \ --hash=sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372 \ --hash=sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9 \