Add optional Cloud Monitoring usage metrics to the Shared Pathways Service.

Review revisions folded into this patch (all confined to added lines):
  * isc_pathways._wait_for_placement: the metrics collector argument defaults
    to None, so the success branch now guards its use instead of raising
    AttributeError for callers that rely on the default; start_time is
    compared against None explicitly.
  * metrics_collector: when google-cloud-monitoring is not installed the
    module stays importable, and MetricsCollector() raises a descriptive
    ImportError instead of a NameError.
  * metrics_collector._send_metric: the "Successfully buffered" message is
    logged after the append, outside the lock.

Hunk offsets and line counts are recomputed for these revisions. The "index"
blob ids still name the originally submitted contents.

diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
index 3c3b95a..e2572b9 100644
--- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
+++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
@@ -10,12 +10,14 @@
 import string
 import subprocess
 import threading
+import time
 from typing import Any
 
 import jax
 import jax.extend.backend as jax_backend
 import pathwaysutils
 from pathwaysutils.experimental.shared_pathways_service import gke_utils
+from pathwaysutils.experimental.shared_pathways_service import metrics_collector
 from pathwaysutils.experimental.shared_pathways_service import validators
 
 
@@ -128,6 +130,9 @@ def _wait_for_placement(
     pod_name: str,
     num_slices: int,
     stream_logs_func=gke_utils.stream_pod_logs,
+    metrics_collector_inst: Any = None,
+    start_time: float | None = None,
+    total_chips: int = 0,
 ) -> None:
   """Waits for the placement to be complete by checking proxy logs."""
   _logger.info("Streaming proxy logs until the placement is complete...")
@@ -150,6 +155,8 @@ def _wait_for_placement(
           f"STDERR: {stderr}"
       )
 
+    if metrics_collector_inst:
+      metrics_collector_inst.record_user_waiting(True)
     for line in log_process.stdout:
       line_lower = line.lower()
       if any(keyword.lower() in line_lower for keyword in keywords):
@@ -165,6 +172,15 @@ def _wait_for_placement(
           )
         else:
           _logger.info("TPU placement for %d slice(s) complete!", num_slices)
+          # The collector is optional (defaults to None), so guard every use.
+          if metrics_collector_inst:
+            metrics_collector_inst.record_active_user(True)
+            metrics_collector_inst.record_user_waiting(False)
+            metrics_collector_inst.record_capacity_in_use(total_chips)
+            if start_time is not None:
+              duration = time.time() - start_time
+              metrics_collector_inst.record_assignment_time(duration)
+            metrics_collector_inst.record_successful_request()
         break
 
 
@@ -195,11 +211,15 @@ class _ISCPathways:
     proxy_pod_name: The name of the proxy pod, assigned during deployment.
     proxy_server_image: The image to use for the proxy server.
     proxy_options: Configuration options for the Pathways proxy.
+    metrics_collector: The metrics collector instance if enabled.
+    start_time: The start time of the TPU assignment.
+    total_chips: The total number of TPU chips expected across all instances.
   """
 
   def __init__(
       self,
-      *, cluster: str,
+      *,
+      cluster: str,
       project: str,
       region: str,
       gcs_bucket: str,
@@ -208,6 +228,7 @@ def __init__(
       proxy_job_name: str,
       proxy_server_image: str,
       proxy_options: ProxyOptions | None = None,
+      collect_service_metrics: bool = False,
   ):
     """Initializes the TPU manager."""
     self.cluster = cluster
@@ -223,9 +244,19 @@ def __init__(
     self.proxy_server_image = proxy_server_image
     self.proxy_options = proxy_options or ProxyOptions()
     self._old_jax_platforms = None
+    raw_collector = (
+        metrics_collector.MetricsCollector(self.project)
+        if collect_service_metrics
+        else None
+    )
+    self.metrics_collector = metrics_collector.SafeMetricsCollector(
+        raw_collector
+    )
+    self.start_time = None
     self._old_jax_backend_target = None
     self._old_jax_platforms_config = None
     self._old_jax_backend_target_config = None
+    self.total_chips = self._get_total_chips()
 
   def __repr__(self):
     return (
@@ -237,8 +268,23 @@ def __repr__(self):
         f"proxy_options={self.proxy_options})"
     )
 
+  def _get_total_chips(self) -> int:
+    """Calculates total chips from expected_tpu_instances."""
+    total_chips = 0
+    for tpu_type, count in self.expected_tpu_instances.items():
+      parts = tpu_type.split(":")
+      topology = parts[1]
+      dimensions = [int(d) for d in topology.split("x")]
+      chips_per_instance = 1
+      for d in dimensions:
+        chips_per_instance *= d
+      total_chips += chips_per_instance * count
+    return total_chips
+
   def __enter__(self):
     """Enters the context manager, ensuring cluster exists."""
+    self.metrics_collector.record_requested_capacity(self.total_chips)
+
     self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper())
     self._old_jax_backend_target = os.environ.get(
         _JAX_BACKEND_TARGET_KEY.upper()
@@ -251,6 +297,7 @@ def __enter__(self):
     )
 
     try:
+      self.start_time = time.time()
       _deploy_pathways_proxy_server(
           pathways_service=self.pathways_service,
           proxy_job_name=self._proxy_job_name,
@@ -259,7 +306,7 @@ def __enter__(self):
           proxy_server_image=self.proxy_server_image,
           proxy_options=self.proxy_options,
       )
-      # Print a link to Cloud Logging
+      self.metrics_collector.record_user_waiting(True)
       cloud_logging_link = gke_utils.get_log_link(
           cluster=self.cluster,
           project=self.project,
@@ -303,14 +350,14 @@ def __exit__(self, exc_type, exc_value, traceback):
 
   def _cleanup(self) -> None:
     """Cleans up resources created by the ISCPathways context."""
-    # 1. Clear JAX caches and run garbage collection.
+    # Clear JAX caches and run garbage collection.
     _logger.info("Starting Pathways proxy cleanup.")
     jax_backend.clear_backends()
     jax.clear_caches()
     gc.collect()
     _logger.info("Cleared JAX caches and ran garbage collection.")
 
-    # 2. Terminate the port forwarding process.
+    # Terminate the port forwarding process.
     if self._port_forward_process:
       _logger.info("Terminating port forwarding process...")
       self._port_forward_process.terminate()
@@ -323,12 +370,12 @@ def _cleanup(self) -> None:
             e,
         )
 
-    # 3. Delete the proxy GKE job.
+    # Delete the proxy GKE job.
     _logger.info("Deleting Pathways proxy...")
     gke_utils.delete_gke_job(self._proxy_job_name)
     _logger.info("Pathways proxy GKE job deletion complete.")
 
-    # 4. Restore JAX variables.
+    # Restore JAX variables.
     _logger.info("Restoring JAX env and config variables...")
     _restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms)
     _restore_env_var(
@@ -353,6 +400,7 @@ def connect(
     proxy_job_name: str | None = None,
     proxy_server_image: str = DEFAULT_PROXY_IMAGE,
     proxy_options: ProxyOptions | None = None,
+    collect_service_metrics: bool = False,
 ) -> Iterator["_ISCPathways"]:
   """Connects to a Pathways server if the cluster exists. If not, creates it.
 
@@ -370,6 +418,8 @@ def connect(
       default will be used.
     proxy_options: Configuration options for the Pathways proxy. If not
       provided, no extra options will be used.
+    collect_service_metrics: Whether to collect usage metrics for Shared
+      Pathways Service.
 
   Yields:
     The Pathways manager.
@@ -399,6 +449,7 @@ def connect(
       proxy_job_name=proxy_job_name,
       proxy_server_image=proxy_server_image,
       proxy_options=proxy_options,
+      collect_service_metrics=collect_service_metrics,
   ) as t:
     if t.proxy_pod_name:
       num_slices = sum(t.expected_tpu_instances.values())
@@ -407,6 +458,10 @@ def connect(
           args=(
               t.proxy_pod_name,
               num_slices,
+              gke_utils.stream_pod_logs,
+              t.metrics_collector,
+              t.start_time,
+              t.total_chips,
           ),
           daemon=True,
       )
diff --git a/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py b/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py
new file mode 100644
index 0000000..9ea0cae
--- /dev/null
+++ b/pathwaysutils/experimental/shared_pathways_service/metrics_collector.py
@@ -0,0 +1,305 @@
+"""Metrics collector for Shared Pathways Service."""
+
+import atexit
+import logging
+import threading
+import time
+from typing import Any, Dict
+import uuid
+
+try:
+  # pylint: disable=g-import-not-at-top
+  from google.api_core import exceptions
+  from google.cloud import monitoring_v3
+except ImportError:
+  # Stay importable without the optional dependency; MetricsCollector
+  # raises a descriptive ImportError when it is actually constructed.
+  exceptions = None
+  monitoring_v3 = None
+
+_logger = logging.getLogger(__name__)
+
+
+METRIC_PREFIX = "custom.googleapis.com/shared_pathways_service/"
+
+_METRIC_NUM_ACTIVE_USERS = "num_active_users"
+_METRIC_CAPACITY_IN_USE = "capacity_in_use"
+_METRIC_ASSIGNMENT_TIME = "assignment_time"
+_METRIC_NUM_SUCCESSFUL_REQS = "num_successful_reqs"
+_METRIC_NUM_USERS_WAITING = "num_users_waiting"
+_METRIC_REQUESTED_CAPACITY = "requested_capacity"
+_METRIC_DESCRIPTORS = [
+    {
+        "name": _METRIC_NUM_ACTIVE_USERS,
+        "description": "Number of active users at any given time",
+        "value_type": "INT64",
+        "unit": "1",
+    },
+    {
+        "name": _METRIC_CAPACITY_IN_USE,
+        "description": "Number of chips that are actively running workloads",
+        "value_type": "INT64",
+        "unit": "chips",
+        "display_name": "Capacity (chips) in use",
+    },
+    {
+        "name": _METRIC_ASSIGNMENT_TIME,
+        "description": "Time to assign slice(s) to an incoming client",
+        "value_type": "DOUBLE",
+        "unit": "s",
+        "display_name": "Capacity assignment time",
+    },
+    {
+        "name": _METRIC_NUM_SUCCESSFUL_REQS,
+        "description": (
+            "Number of user requests that got capacity assignment successfully"
+        ),
+        "value_type": "INT64",
+        "unit": "1",
+        "display_name": "Successful capacity assignment requests",
+    },
+    {
+        "name": _METRIC_NUM_USERS_WAITING,
+        "description": "Number of users waiting for capacity",
+        "value_type": "INT64",
+        "unit": "1",
+        "display_name": "Users waiting",
+    },
+    {
+        "name": _METRIC_REQUESTED_CAPACITY,
+        "description": "Number of chips requested by an incoming client",
+        "value_type": "INT64",
+        "unit": "chips",
+        "display_name": "Requested capacity (chips)",
+    },
+]
+
+
+class MetricsCollector:
+  """Collects usage metrics for Shared Pathways Service and reports to Cloud Monitoring."""
+
+  def __init__(self, project_id: str):
+    if monitoring_v3 is None:
+      raise ImportError(
+          "collect_service_metrics requires the google-cloud-monitoring"
+          " package."
+      )
+    self.project_id = project_id
+    self.client = monitoring_v3.MetricServiceClient()
+    self.project_name = f"projects/{self.project_id}"
+    self._lock = threading.Lock()
+    self._buffer: Dict[str, list[tuple[Any, str, Dict[str, str] | None]]] = {}
+    self._last_sent_time: Dict[str, float] = {}
+    self._instance_id = str(uuid.uuid4())
+    self._running = True
+    for descriptor in _METRIC_DESCRIPTORS:
+      self._create_metric_descriptor(**descriptor)
+    self._flusher_thread = threading.Thread(
+        target=self._flush_loop, daemon=True
+    )
+    self._flusher_thread.start()
+    atexit.register(self._shutdown)
+    _logger.info("Metrics collection initialized.")
+
+  def _create_time_series_object(
+      self,
+      metric_type: str,
+      value: Any,
+      value_type: str,
+      metric_labels: Dict[str, str] | None = None,
+      resource_type: str = "global",
+      resource_labels: Dict[str, str] | None = None,
+  ) -> Any:
+    """Creates a TimeSeries object for a single metric."""
+    # Using Any for return type to avoid failing when monitoring_v3 is not
+    # available.
+    series = monitoring_v3.TimeSeries()
+    series.metric.type = METRIC_PREFIX + metric_type
+    series.resource.type = resource_type
+    if resource_labels:
+      series.resource.labels.update(resource_labels)
+    if metric_labels:
+      series.metric.labels.update(metric_labels)
+
+    now = time.time()
+    seconds = int(now)
+    nanos = int((now - seconds) * 10**9)
+
+    point = monitoring_v3.Point(
+        interval=monitoring_v3.TimeInterval(
+            end_time={"seconds": seconds, "nanos": nanos}
+        ),
+        value=monitoring_v3.TypedValue(**{value_type: value}),
+    )
+    series.points.append(point)
+    return series
+
+  def _flush_loop(self):
+    """Runs continuously to flush the metrics buffer."""
+    while self._running:
+      self.flush()
+      time.sleep(1)
+
+  def flush(self):
+    """Sends any eligible buffered metrics to Cloud Monitoring."""
+    with self._lock:
+      now = time.time()
+      to_send = []
+      for metric_type, queue in list(self._buffer.items()):
+        if not queue:
+          del self._buffer[metric_type]
+          continue
+        last_time = self._last_sent_time.get(metric_type, 0)
+        # Add a slight cushion (10.5s) to prevent sub-second drift errors.
+        if now - last_time >= 10.5 or last_time == 0:
+          item = queue.pop(0)
+          to_send.append((metric_type, *item))
+          self._last_sent_time[metric_type] = now
+          if not queue:
+            del self._buffer[metric_type]
+
+    for metric_type, value, value_type, metric_labels in to_send:
+      self._transmit(metric_type, value, value_type, metric_labels)
+
+  def _shutdown(self):
+    """Synchronously drains the final state of the queue before exiting."""
+    self._running = False
+    while True:
+      with self._lock:
+        if not any(self._buffer.values()):
+          break
+        # Wait for the window to open for at least one item
+        now = time.time()
+        min_wait = 0.0
+        for metric_type, queue in self._buffer.items():
+          if queue:
+            last_time = self._last_sent_time.get(metric_type, 0)
+            wait_needed = 10.5 - (now - last_time)
+            if wait_needed > 0:
+              min_wait = max(min_wait, wait_needed)
+      if min_wait > 0:
+        _logger.info(
+            "Waiting %.1fs for Cloud Monitoring sampling window...", min_wait
+        )
+        time.sleep(min_wait)
+      self.flush()
+
+  def _send_metric(
+      self,
+      metric_type: str,
+      value: Any,
+      value_type: str,
+      metric_labels: Dict[str, str] | None = None,
+  ):
+    """Queues a single metric in the buffer."""
+    default_labels = {"client_instance_id": self._instance_id}
+    if metric_labels:
+      default_labels.update(metric_labels)
+    _logger.info(
+        "Buffering metric %s: %s",
+        metric_type,
+        (value, value_type, default_labels),
+    )
+    with self._lock:
+      if metric_type not in self._buffer:
+        self._buffer[metric_type] = []
+      self._buffer[metric_type].append((value, value_type, default_labels))
+    # Report success only once the item is queued, without holding the lock.
+    _logger.info(
+        "Successfully buffered metric %s: %s",
+        metric_type,
+        (value, value_type, default_labels),
+    )
+
+  def _transmit(
+      self,
+      metric_type: str,
+      value: Any,
+      value_type: str,
+      metric_labels: Dict[str, str] | None = None,
+  ):
+    """Physically transmits a TimeSeries to Cloud Monitoring."""
+    series = self._create_time_series_object(
+        metric_type, value, value_type, metric_labels
+    )
+    try:
+      self.client.create_time_series(
+          name=self.project_name, time_series=[series]
+      )
+      _logger.info("Sent metric %s: %s", metric_type, value)
+    except exceptions.GoogleAPICallError as e:
+      _logger.warning("Failed to send metric %s: %s", metric_type, e)
+
+  def _create_metric_descriptor(
+      self,
+      name: str,
+      description: str,
+      value_type: str,
+      unit: str,
+      metric_kind: str = "GAUGE",
+      display_name: str | None = None,
+  ):
+    """Creates a metric descriptor if not already present."""
+    metric_type = METRIC_PREFIX + name
+    display_name = display_name or name
+
+    try:
+      self.client.create_metric_descriptor(
+          name=f"projects/{self.project_id}",
+          metric_descriptor={
+              "type": metric_type,
+              "metric_kind": metric_kind,
+              "value_type": value_type,
+              "description": description,
+              "display_name": display_name,
+              "unit": unit,
+              "labels": [{
+                  "key": "client_instance_id",
+                  "value_type": "STRING",
+                  "description": "Unique execution identifier",
+              }],
+          },
+      )
+      _logger.info("Created metric descriptor: %s", metric_type)
+    except exceptions.AlreadyExists:
+      _logger.debug("Metric descriptor %s already exists.", metric_type)
+
+  def record_active_user(self, is_active: bool):
+    """Records the number of active users (1 for active, 0 for inactive)."""
+    self._send_metric(
+        _METRIC_NUM_ACTIVE_USERS, 1 if is_active else 0, "int64_value"
+    )
+
+  def record_capacity_in_use(self, chips: int):
+    """Records the number of chips in use."""
+    self._send_metric(_METRIC_CAPACITY_IN_USE, chips, "int64_value")
+
+  def record_requested_capacity(self, chips: int):
+    """Records the number of chips requested by the client."""
+    self._send_metric(_METRIC_REQUESTED_CAPACITY, chips, "int64_value")
+
+  def record_assignment_time(self, duration_seconds: float):
+    """Records the time taken to assign slices."""
+    self._send_metric(_METRIC_ASSIGNMENT_TIME, duration_seconds, "double_value")
+
+  def record_successful_request(self):
+    """Records a successful request."""
+    self._send_metric(_METRIC_NUM_SUCCESSFUL_REQS, 1, "int64_value")
+
+  def record_user_waiting(self, is_waiting: bool):
+    """Records a user waiting for capacity."""
+    self._send_metric(
+        _METRIC_NUM_USERS_WAITING, 1 if is_waiting else 0, "int64_value"
+    )
+
+
+class SafeMetricsCollector:
+  """Wrapper for MetricsCollector that safely absorbs calls when metrics are disabled."""
+
+  def __init__(self, collector: MetricsCollector | None):
+    self._collector = collector
+
+  def __getattr__(self, name: str):
+    if self._collector is None:
+      return lambda *args, **kwargs: None
+    return getattr(self._collector, name)
diff --git a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py
index db63c28..f07220e 100644
--- a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py
+++ b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py
@@ -41,6 +41,11 @@
     "Configuration options for the Pathways proxy. Specify entries in the form"
     ' "key:value". For example: --proxy_options=use_insecure_credentials:true',
 )
+flags.DEFINE_bool(
+    "collect_service_metrics",
+    False,
+    "Whether to enable metrics collection for Shared Pathways Service.",
+)
 
 flags.mark_flags_as_required([
     "cluster",
@@ -68,6 +73,7 @@ def main(argv: Sequence[str]) -> None:
       proxy_server_image=FLAGS.proxy_server_image
       or isc_pathways.DEFAULT_PROXY_IMAGE,
       proxy_options=proxy_options,
+      collect_service_metrics=FLAGS.collect_service_metrics,
   ):
     orig_matrix = jnp.zeros(5)
     result_matrix = orig_matrix + 1
diff --git a/pathwaysutils/experimental/shared_pathways_service/run_workload.py b/pathwaysutils/experimental/shared_pathways_service/run_workload.py
index f662c35..c666f52 100644
--- a/pathwaysutils/experimental/shared_pathways_service/run_workload.py
+++ b/pathwaysutils/experimental/shared_pathways_service/run_workload.py
@@ -67,6 +67,14 @@
 _COMMAND = flags.DEFINE_string(
     "command", None, "The command to run on TPUs.", required=True
 )
+_COLLECT_SERVICE_METRICS = flags.DEFINE_bool(
+    "collect_service_metrics",
+    False,
+    "Whether to enable metrics collection for Shared Pathways Service. If"
+    " enabled, the service will collect usage metrics such as TPU assignment"
+    " time, active user count, capacity in use etc. The metrics will be"
+    " stored in Cloud Monitoring.",
+)
 
 flags.register_validator(
     "proxy_options",
@@ -93,6 +101,7 @@ def run_command(
     command: str,
     proxy_server_image: str | None = None,
     proxy_options: Sequence[str] | None = None,
+    collect_service_metrics: bool = False,
     connect_fn: Callable[..., ContextManager[Any]] = isc_pathways.connect,
 ) -> None:
   """Run the TPU workload within a Shared Pathways connection.
@@ -108,6 +117,8 @@ def run_command(
     command: The command to run on TPUs.
     proxy_server_image: The proxy server image to use.
     proxy_options: Configuration options for the Pathways proxy.
+    collect_service_metrics: Whether to collect usage metrics for Shared Pathways
+      Service. Defaults to False.
     connect_fn: The function to use for establishing the connection context,
       expected to be a callable that returns a context manager.
 
@@ -130,6 +141,7 @@ def run_command(
           else isc_pathways.DEFAULT_PROXY_IMAGE
       ),
       proxy_options=parsed_proxy_options,
+      collect_service_metrics=collect_service_metrics,
   ):
     logging.info("Connection established. Running command: %r", command)
     try:
@@ -160,4 +172,5 @@ def main(argv: Sequence[str]) -> None:
       command=_COMMAND.value,
       proxy_server_image=_PROXY_SERVER_IMAGE.value,
       proxy_options=_PROXY_OPTIONS.value,
+      collect_service_metrics=_COLLECT_SERVICE_METRICS.value,
   )