-
Notifications
You must be signed in to change notification settings - Fork 1.7k
refactor(bigframes): Simplify @udf wrapper object #16556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0e4f808
976586b
4fedd8d
ae221f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,18 +20,18 @@ | |
| import inspect | ||
| import sys | ||
| import threading | ||
| import warnings | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| cast, | ||
| Dict, | ||
| Literal, | ||
| Mapping, | ||
| Optional, | ||
| Sequence, | ||
| TYPE_CHECKING, | ||
| Union, | ||
| cast, | ||
| ) | ||
| import warnings | ||
|
|
||
| import google.api_core.exceptions | ||
| from google.cloud import ( | ||
|
|
@@ -41,9 +41,9 @@ | |
| resourcemanager_v3, | ||
| ) | ||
|
|
||
| from bigframes import clients | ||
| import bigframes.exceptions as bfe | ||
| import bigframes.formatting_helpers as bf_formatting | ||
| from bigframes import clients | ||
| from bigframes.functions import function as bq_functions | ||
| from bigframes.functions import udf_def | ||
|
|
||
|
|
@@ -630,25 +630,15 @@ def wrapper(func): | |
| if udf_sig.is_row_processor: | ||
| msg = bfe.format_message("input_types=Series is in preview.") | ||
| warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning) | ||
| return decorator( | ||
| bq_functions.BigqueryCallableRowRoutine( | ||
| udf_definition, | ||
| session, | ||
| cloud_function_ref=bigframes_cloud_function, | ||
| local_func=func, | ||
| is_managed=False, | ||
| ) | ||
| ) | ||
| else: | ||
| return decorator( | ||
| bq_functions.BigqueryCallableRoutine( | ||
| udf_definition, | ||
| session, | ||
| cloud_function_ref=bigframes_cloud_function, | ||
| local_func=func, | ||
| is_managed=False, | ||
| ) | ||
| return decorator( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Lol. Good catch on these duplicate calls. Thanks for the cleanup! |
||
| bq_functions.BigqueryCallableRoutine( | ||
| udf_definition, | ||
| session, | ||
| cloud_function_ref=bigframes_cloud_function, | ||
| local_func=func, | ||
| is_managed=False, | ||
| ) | ||
| ) | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
@@ -834,8 +824,9 @@ def wrapper(func): | |
| bq_connection_manager, | ||
| session=session, # type: ignore | ||
| ) | ||
| code_def = udf_def.CodeDef.from_func(func) | ||
| config = udf_def.ManagedFunctionConfig( | ||
| code=udf_def.CodeDef.from_func(func), | ||
| code=code_def, | ||
| signature=udf_sig, | ||
| max_batching_rows=max_batching_rows, | ||
| container_cpu=container_cpu, | ||
|
|
@@ -859,28 +850,18 @@ def wrapper(func): | |
| signature=udf_sig, | ||
| ) | ||
|
|
||
| if not name: | ||
| self._update_temp_artifacts(full_rf_name, "") | ||
|
|
||
| decorator = functools.wraps(func) | ||
| if udf_sig.is_row_processor: | ||
| msg = bfe.format_message("input_types=Series is in preview.") | ||
| warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning) | ||
| assert session is not None # appease mypy | ||
| return decorator( | ||
| bq_functions.BigqueryCallableRowRoutine( | ||
| udf_definition, session, local_func=func, is_managed=True | ||
| ) | ||
| ) | ||
|
|
||
| if not name: # session-managed resource | ||
| self._update_temp_artifacts(full_rf_name, "") | ||
| return bq_functions.UdfRoutine(func=func, _udf_def=udf_definition) | ||
|
|
||
| # user-managed permanent resource | ||
| else: | ||
| assert session is not None # appease mypy | ||
| return decorator( | ||
| bq_functions.BigqueryCallableRoutine( | ||
| udf_definition, | ||
| session, | ||
| local_func=func, | ||
| is_managed=True, | ||
| ) | ||
| return bq_functions.BigqueryCallableRoutine( | ||
| udf_definition, session, local_func=func, is_managed=True | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: the [inline text lost in extraction; apparently the "session-managed" / "user-managed" wording in these code comments, which is easy to confuse with `is_managed=True`]. Perhaps "user-owned" and "bigframes session-owned" in the comments would help avoid confusion? Alternatively, something relating to "lifetime" in the comments, as that's really the key difference. |
||
| ) | ||
|
|
||
| return wrapper | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,11 +15,13 @@ | |
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| from typing import TYPE_CHECKING, Callable, Optional | ||
| from typing import Callable, Optional, Protocol, runtime_checkable, TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| import bigframes.series | ||
| from bigframes.session import Session | ||
| import bigframes.series | ||
|
|
||
| import dataclasses | ||
|
|
||
| import google.api_core.exceptions | ||
| from google.cloud import bigquery | ||
|
|
@@ -28,6 +30,9 @@ | |
| from bigframes.functions import _function_session as bff_session | ||
| from bigframes.functions import function_typing, udf_def | ||
|
|
||
| if TYPE_CHECKING: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we combine the two `if TYPE_CHECKING:` blocks? |
||
| import bigframes.core.col | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
|
|
@@ -90,13 +95,13 @@ def _try_import_routine( | |
|
|
||
| def _try_import_row_routine( | ||
| routine: bigquery.Routine, session: bigframes.Session | ||
| ) -> BigqueryCallableRowRoutine: | ||
| ) -> BigqueryCallableRoutine: | ||
| udf_def = _routine_as_udf_def(routine, is_row_processor=True) | ||
|
|
||
| is_remote = ( | ||
| hasattr(routine, "remote_function_options") and routine.remote_function_options | ||
| ) | ||
| return BigqueryCallableRowRoutine(udf_def, session, is_managed=not is_remote) | ||
| return BigqueryCallableRoutine(udf_def, session, is_managed=not is_remote) | ||
|
|
||
|
|
||
| def _routine_as_udf_def( | ||
|
|
@@ -117,7 +122,6 @@ def _routine_as_udf_def( | |
| ) | ||
|
|
||
|
|
||
| # TODO(b/399894805): Support managed function. | ||
| def read_gbq_function( | ||
| function_name: str, | ||
| *, | ||
|
|
@@ -152,6 +156,12 @@ def read_gbq_function( | |
| return _try_import_routine(routine, session) | ||
|
|
||
|
|
||
| @runtime_checkable | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add to the docstring for the `Udf` class the intended usage behind adding `runtime_checkable`? Seems to me that this is an important part of the contract and worth documenting. |
||
| class Udf(Protocol): | ||
| @property | ||
| def udf_def(self) -> udf_def.BigqueryUdf: ... | ||
|
|
||
|
|
||
| class BigqueryCallableRoutine: | ||
| """ | ||
| A reference to a routine in the context of a session. | ||
|
|
@@ -178,8 +188,8 @@ def __call__(self, *args, **kwargs): | |
| if self._local_fun: | ||
| return self._local_fun(*args, **kwargs) | ||
| # avoid circular imports | ||
| import bigframes.session._io.bigquery as bf_io_bigquery | ||
| from bigframes.core.compile.sqlglot import sql as sg_sql | ||
| import bigframes.session._io.bigquery as bf_io_bigquery | ||
|
|
||
| args_string = ", ".join([sg_sql.to_sql(sg_sql.literal(v)) for v in args]) | ||
| sql = f"SELECT `{str(self._udf_def.routine_ref)}`({args_string})" | ||
|
|
@@ -202,7 +212,7 @@ def bigframes_remote_function(self): | |
|
|
||
| @property | ||
| def is_row_processor(self) -> bool: | ||
| return False | ||
| return self.udf_def.signature.is_row_processor | ||
|
|
||
| @property | ||
| def udf_def(self) -> udf_def.BigqueryUdf: | ||
|
|
@@ -225,75 +235,16 @@ def bigframes_bigquery_function_output_dtype(self): | |
| return self.udf_def.signature.output.emulating_type.bf_type | ||
|
|
||
|
|
||
| class BigqueryCallableRowRoutine: | ||
| """ | ||
| A reference to a routine in the context of a session. | ||
|
|
||
| Can be used both directly as a callable, or as an input to dataframe ops that take a callable. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| udf_def: udf_def.BigqueryUdf, | ||
| session: bigframes.Session, | ||
| *, | ||
| local_func: Optional[Callable] = None, | ||
| cloud_function_ref: Optional[str] = None, | ||
| is_managed: bool = False, | ||
| ): | ||
| assert udf_def.signature.is_row_processor | ||
| self._udf_def = udf_def | ||
| self._session = session | ||
| self._local_fun = local_func | ||
| self._cloud_function = cloud_function_ref | ||
| self._is_managed = is_managed | ||
| @dataclasses.dataclass(frozen=True) | ||
| class UdfRoutine: | ||
| func: Callable | ||
| # Try not to depend on this, bq managed function creation will be deferred later | ||
| # And this ref will be replaced with requirements rather to support lazy creation | ||
| _udf_def: udf_def.BigqueryUdf | ||
|
|
||
| def __call__(self, *args, **kwargs): | ||
| if self._local_fun: | ||
| return self._local_fun(*args, **kwargs) | ||
| # avoid circular imports | ||
| import bigframes.session._io.bigquery as bf_io_bigquery | ||
| from bigframes.core.compile.sqlglot import sql as sg_sql | ||
|
|
||
| args_string = ", ".join([sg_sql.to_sql(sg_sql.literal(v)) for v in args]) | ||
| sql = f"SELECT `{str(self._udf_def.routine_ref)}`({args_string})" | ||
| iter, job = bf_io_bigquery.start_query_with_client( | ||
| self._session.bqclient, | ||
| sql=sql, | ||
| query_with_job=True, | ||
| job_config=bigquery.QueryJobConfig(), | ||
| publisher=self._session._publisher, | ||
| ) # type: ignore | ||
| return list(iter.to_arrow().to_pydict().values())[0][0] | ||
|
|
||
| @property | ||
| def bigframes_bigquery_function(self) -> str: | ||
| return str(self._udf_def.routine_ref) | ||
|
|
||
| @property | ||
| def bigframes_remote_function(self): | ||
| return None if self._is_managed else str(self._udf_def.routine_ref) | ||
|
|
||
| @property | ||
| def is_row_processor(self) -> bool: | ||
| return True | ||
| return self.func(*args, **kwargs) | ||
|
|
||
| @property | ||
| def udf_def(self) -> udf_def.BigqueryUdf: | ||
| return self._udf_def | ||
|
|
||
| @property | ||
| def bigframes_cloud_function(self) -> Optional[str]: | ||
| return self._cloud_function | ||
|
|
||
| @property | ||
| def input_dtypes(self): | ||
| return tuple(arg.bf_type for arg in self.udf_def.signature.inputs) | ||
|
|
||
| @property | ||
| def output_dtype(self): | ||
| return self.udf_def.signature.output.bf_type | ||
|
|
||
| @property | ||
| def bigframes_bigquery_function_output_dtype(self): | ||
| return self.udf_def.signature.output.emulating_type.bf_type | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I presume this is the reason for
`runtime_checkable` above?