Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion grain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Public API for Grain."""


# pylint: disable=g-importing-member
# pylint: disable=unused-import
# pylint: disable=g-multiple-import
Expand All @@ -32,6 +31,7 @@
transforms,
)

from grain._src.core import monitoring
from grain._src.core.config import config
from grain._src.core.version import __version__, __version_info__
from grain._src.python.data_loader import (
Expand All @@ -46,3 +46,5 @@
from grain._src.python.load import load
from grain._src.python.options import ReadOptions, MultiprocessingOptions
from grain._src.python.record import Record, RecordMetadata

monitoring.setup_telemetry()
35 changes: 34 additions & 1 deletion grain/_src/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,28 @@ py_library(
srcs_version = "PY3",
)

py_library(
name = "monitoring_base",
srcs = ["monitoring_base.py"],
srcs_version = "PY3",
)

py_library(
name = "monitoring",
srcs = ["monitoring.py"],
srcs = [
"monitoring.py",
"prometheus_monitoring.py",
],
srcs_version = "PY3",
tags = [
"ignore_for_dep=grain._src.core.fast_monitoring",
"ignore_for_dep=grain._src.core.google.streamz_monitoring",
"ignore_for_dep=grain._src.core.monitoring_base",
"ignore_for_dep=grain._src.core.prometheus_monitoring",
],
deps = [
":monitoring_base",
"@pypi//prometheus_client:pkg",
],
)

Expand Down Expand Up @@ -255,3 +272,19 @@ py_test(
"@pypi//cloudpickle:pkg",
],
)

py_test(
name = "prometheus_monitoring_test",
srcs = ["prometheus_monitoring_test.py"],
srcs_version = "PY3",
tags = [
"ignore_for_dep=grain._src.core.google_monitoring",
"ignore_for_dep=grain._src.core.monitoring_base",
"ignore_for_dep=grain._src.core.prometheus_monitoring",
],
deps = [
":monitoring",
"@abseil-py//absl/testing:absltest",
"@pypi//prometheus_client:pkg",
],
)
5 changes: 1 addition & 4 deletions grain/_src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from typing import Any

from absl import flags
from grain._src.core import monitoring as grain_monitoring

from grain._src.core import monitoring

# Performance optimisations. Consider most of these experimental. We might
Expand Down Expand Up @@ -148,9 +146,8 @@

_grain_experiment_metric = monitoring.Metric(
"/grain/experiment",
value_type=int,
int,
metadata=monitoring.Metadata(description="Grain experiment opt-in metric."),
root=grain_monitoring.get_monitoring_root(),
fields=[("name", str)],
)

Expand Down
167 changes: 109 additions & 58 deletions grain/_src/core/monitoring.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,109 @@
"""Grain metrics."""

import enum


@enum.unique
class Units(enum.Enum):
"""Grain metric units."""

NANOSECONDS = enum.auto()
MILLISECONDS = enum.auto()


# pylint: disable=invalid-name
class NoOpMetric:
"""Grain metric no-op implementation."""

def __init__(self, *args, **kwargs):
del args, kwargs

def IncrementBy(self, *args, **kwargs):
del args, kwargs

def Increment(self, *args, **kwargs):
self.IncrementBy(1, *args, **kwargs)

def Set(self, *args, **kwargs):
del args, kwargs

def Record(self, *args, **kwargs):
del args, kwargs

def Get(self, *args, **kwargs):
del args, kwargs


class Metadata:
"""Grain metric no-op metadata."""

def __init__(self, *args, **kwargs):
del args, kwargs


# pylint: disable=invalid-name
class Bucketer:
"""Grain metric no-op bucketer."""

def __init__(self, *args, **kwargs):
del args, kwargs

def PowersOf(self, *args, **kwargs):
del args, kwargs


Counter = Metric = EventMetric = NoOpMetric

def get_monitoring_root() -> None:
return None
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grain metrics convenience functions and implementation selection."""

from __future__ import annotations

from grain._src.core import monitoring_base

from grain._src.core import prometheus_monitoring as _impl

Bucketer = monitoring_base.Bucketer
Metadata = monitoring_base.Metadata
Units = monitoring_base.Units

Counter = _impl.Counter
EventMetric = _impl.EventMetric
Metric = _impl.Metric
get_monitoring_root = _impl.get_monitoring_root
setup_telemetry = _impl.SetupTelemetry


def record_autotune_node_throughput(node_name: str, throughput: float):
_grain_autotune_node_throughput.Set(throughput, node_name)


def record_autotune_parameter(node_name: str, name: str, value: float):
_grain_autotune_parameters.Set(value, node_name, name)


def record_autotune_usl_coeff(node_name: str, name: str, value: float):
_grain_autotune_usl_coeffs.Set(value, node_name, name)


def record_autotune_optimization_latency(latency_ms: float):
_grain_autotune_optimization_latency.Record(latency_ms)


def record_framework_type(framework_type: str):
_grain_framework_type_metric.Increment(framework_type)


def set_debug_server_ports(ports: list[int]):
if not ports:
return
_grain_debug_server_ports.Set(','.join(map(str, ports)))


_grain_framework_type_metric = Counter(
'/grain/framework_type',
Metadata(description='The framework type used to build the Grain dataset.'),
fields=[('name', str)],
)

_grain_debug_server_ports = Metric(
'/grain/debug_server_ports',
str,
metadata=Metadata(
description=(
'A comma-separated list of debug server ports. The first port is'
' for the main process, followed by worker process ports.'
)
),
)


_grain_autotune_node_throughput = Metric(
'/grain/autotune/node_throughput',
float,
metadata=Metadata(
description='The observed throughput of the autotune node (elements/s).'
),
fields=[('node_name', str)],
)

_grain_autotune_parameters = Metric(
'/grain/autotune/parameters',
float,
metadata=Metadata(
description='The current values for the autotuned parameters.'
),
fields=[('node_name', str), ('name', str)],
)

_grain_autotune_usl_coeffs = Metric(
'/grain/autotune/usl_coeffs',
float,
metadata=Metadata(
description='The estimated USL coefficients for the autotune node.'
),
fields=[('node_name', str), ('name', str)],
)

_grain_autotune_optimization_latency = EventMetric(
'/grain/autotune/optimization_latency',
metadata=Metadata(
description='The time taken by the autotune optimizer (ms).'
),
)
96 changes: 96 additions & 0 deletions grain/_src/core/monitoring_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grain metrics base definitions and protocols."""

from __future__ import annotations

import enum
from typing import Any, Protocol


@enum.unique
class Units(enum.Enum):
"""Grain metric units."""

SECONDS = enum.auto()
MILLISECONDS = enum.auto()
MICROSECONDS = enum.auto()
NANOSECONDS = enum.auto()
BITS = enum.auto()
BYTES = enum.auto()


class Metadata:
"""Grain metric metadata."""

def __init__(self, description='', **kwargs):
self.description = description
for key, value in kwargs.items():
setattr(self, key, value)
self._kwargs = kwargs


class Bucketer:
"""Grain metric bucketer."""

def __init__(self, *args, bucketer_type=None, **kwargs):
self.args = args
self.kwargs = kwargs
self.type = bucketer_type

@staticmethod
def PowersOf(base: float):
return Bucketer(base, bucketer_type='PowersOf')


class CounterProtocol(Protocol):
"""Protocol for Counter metrics."""

def Increment(self, *args: Any, **kwargs: Any) -> None:
...

def IncrementBy(self, value: float | int, *args: Any, **kwargs: Any) -> None:
...

def Get(self, *args: Any, **kwargs: Any) -> float | int:
...

def ClearAll(self) -> None:
...


class MetricProtocol(Protocol):
"""Protocol for Gauge metrics."""

def Set(self, value: float | int, *args: Any, **kwargs: Any) -> None:
...

def Get(self, *args: Any, **kwargs: Any) -> float | int:
...

def ClearAll(self) -> None:
...


class EventMetricProtocol(Protocol):
"""Protocol for Event/Histogram metrics."""

def Record(self, value: float | int, *args: Any, **kwargs: Any) -> None:
...

def Get(self, *args: Any, **kwargs: Any) -> float | int:
...

def ClearAll(self) -> None:
...
Loading
Loading