Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
import string
import subprocess
import threading
import time
from typing import Any

import jax
import jax.extend.backend as jax_backend
import pathwaysutils
from pathwaysutils.experimental.shared_pathways_service import gke_utils
from pathwaysutils.experimental.shared_pathways_service import metrics_collector
from pathwaysutils.experimental.shared_pathways_service import validators


Expand Down Expand Up @@ -128,6 +130,9 @@ def _wait_for_placement(
pod_name: str,
num_slices: int,
stream_logs_func=gke_utils.stream_pod_logs,
metrics_collector_inst: Any = None,
start_time: float | None = None,
total_chips: int = 0,
) -> None:
"""Waits for the placement to be complete by checking proxy logs."""
_logger.info("Streaming proxy logs until the placement is complete...")
Expand All @@ -150,6 +155,8 @@ def _wait_for_placement(
f"STDERR: {stderr}"
)

if metrics_collector_inst:
metrics_collector_inst.record_user_waiting(True)
for line in log_process.stdout:
line_lower = line.lower()
if any(keyword.lower() in line_lower for keyword in keywords):
Expand All @@ -165,6 +172,13 @@ def _wait_for_placement(
)
else:
_logger.info("TPU placement for %d slice(s) complete!", num_slices)
metrics_collector_inst.record_active_user(True)
metrics_collector_inst.record_user_waiting(False)
metrics_collector_inst.record_capacity_in_use(total_chips)
if start_time:
duration = time.time() - start_time
metrics_collector_inst.record_assignment_time(duration)
metrics_collector_inst.record_successful_request()
break


Expand Down Expand Up @@ -195,11 +209,15 @@ class _ISCPathways:
proxy_pod_name: The name of the proxy pod, assigned during deployment.
proxy_server_image: The image to use for the proxy server.
proxy_options: Configuration options for the Pathways proxy.
metrics_collector: The metrics collector instance if enabled.
start_time: The start time of the TPU assignment.
total_chips: The total number of TPU chips expected across all instances.
"""

def __init__(
self,
*, cluster: str,
*,
cluster: str,
project: str,
region: str,
gcs_bucket: str,
Expand All @@ -208,6 +226,7 @@ def __init__(
proxy_job_name: str,
proxy_server_image: str,
proxy_options: ProxyOptions | None = None,
collect_service_metrics: bool = False,
):
"""Initializes the TPU manager."""
self.cluster = cluster
Expand All @@ -223,9 +242,19 @@ def __init__(
self.proxy_server_image = proxy_server_image
self.proxy_options = proxy_options or ProxyOptions()
self._old_jax_platforms = None
raw_collector = (
metrics_collector.MetricsCollector(self.project)
if collect_service_metrics
else None
)
self.metrics_collector = metrics_collector.SafeMetricsCollector(
raw_collector
)
self.start_time = None
self._old_jax_backend_target = None
self._old_jax_platforms_config = None
self._old_jax_backend_target_config = None
self.total_chips = self._get_total_chips()

def __repr__(self):
return (
Expand All @@ -237,8 +266,23 @@ def __repr__(self):
f"proxy_options={self.proxy_options})"
)

def _get_total_chips(self) -> int:
"""Calculates total chips from expected_tpu_instances."""
total_chips = 0
for tpu_type, count in self.expected_tpu_instances.items():
parts = tpu_type.split(":")
topology = parts[1]
dimensions = [int(d) for d in topology.split("x")]
chips_per_instance = 1
for d in dimensions:
chips_per_instance *= d
total_chips += chips_per_instance * count
return total_chips

def __enter__(self):
"""Enters the context manager, ensuring cluster exists."""
self.metrics_collector.record_requested_capacity(self.total_chips)

self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY.upper())
self._old_jax_backend_target = os.environ.get(
_JAX_BACKEND_TARGET_KEY.upper()
Expand All @@ -251,6 +295,7 @@ def __enter__(self):
)

try:
self.start_time = time.time()
_deploy_pathways_proxy_server(
pathways_service=self.pathways_service,
proxy_job_name=self._proxy_job_name,
Expand All @@ -259,7 +304,7 @@ def __enter__(self):
proxy_server_image=self.proxy_server_image,
proxy_options=self.proxy_options,
)
# Print a link to Cloud Logging
self.metrics_collector.record_user_waiting(True)
cloud_logging_link = gke_utils.get_log_link(
cluster=self.cluster,
project=self.project,
Expand Down Expand Up @@ -303,14 +348,14 @@ def __exit__(self, exc_type, exc_value, traceback):

def _cleanup(self) -> None:
"""Cleans up resources created by the ISCPathways context."""
# 1. Clear JAX caches and run garbage collection.
# Clear JAX caches and run garbage collection.
_logger.info("Starting Pathways proxy cleanup.")
jax_backend.clear_backends()
jax.clear_caches()
gc.collect()
_logger.info("Cleared JAX caches and ran garbage collection.")

# 2. Terminate the port forwarding process.
# Terminate the port forwarding process.
if self._port_forward_process:
_logger.info("Terminating port forwarding process...")
self._port_forward_process.terminate()
Expand All @@ -323,12 +368,12 @@ def _cleanup(self) -> None:
e,
)

# 3. Delete the proxy GKE job.
# Delete the proxy GKE job.
_logger.info("Deleting Pathways proxy...")
gke_utils.delete_gke_job(self._proxy_job_name)
_logger.info("Pathways proxy GKE job deletion complete.")

# 4. Restore JAX variables.
# Restore JAX variables.
_logger.info("Restoring JAX env and config variables...")
_restore_env_var(_JAX_PLATFORMS_KEY.upper(), self._old_jax_platforms)
_restore_env_var(
Expand All @@ -353,6 +398,7 @@ def connect(
proxy_job_name: str | None = None,
proxy_server_image: str = DEFAULT_PROXY_IMAGE,
proxy_options: ProxyOptions | None = None,
collect_service_metrics: bool = False,
) -> Iterator["_ISCPathways"]:
"""Connects to a Pathways server if the cluster exists. If not, creates it.

Expand All @@ -370,6 +416,8 @@ def connect(
default will be used.
proxy_options: Configuration options for the Pathways proxy. If not
provided, no extra options will be used.
collect_service_metrics: Whether to collect usage metrics for Shared
Pathways Service.

Yields:
The Pathways manager.
Expand Down Expand Up @@ -399,6 +447,7 @@ def connect(
proxy_job_name=proxy_job_name,
proxy_server_image=proxy_server_image,
proxy_options=proxy_options,
collect_service_metrics=collect_service_metrics,
) as t:
if t.proxy_pod_name:
num_slices = sum(t.expected_tpu_instances.values())
Expand All @@ -407,6 +456,10 @@ def connect(
args=(
t.proxy_pod_name,
num_slices,
gke_utils.stream_pod_logs,
t.metrics_collector,
t.start_time,
t.total_chips,
),
daemon=True,
)
Expand Down
Loading
Loading