refactor(bigframes): Move AI ops to googlesql op framework #17146
TrevorBergeron wants to merge 1 commit into
Code Review
This pull request refactors BigQuery AI operations to use a generic GoogleSqlScalarOp framework and introduces a centralized alignment utility in bigframes.core.align. Key changes include replacing specialized AI classes with GoogleSqlScalarOp instances, adding FrozenDict for non-hashable constants, and updating SQL generation to use STRUCT prompts. Review feedback identifies a TypeError in the AIGenerateOp signature, suggests making FrozenDict immutable by overriding mutation methods, and recommends refactoring AI_CLASSIFY to handle dynamic output types. Additionally, the reviewer advised improving alignment logic for multi-index objects and correcting type hints in apply_op.
    def _ai_classify_output_type(*input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        output_mode = input_types[5] if len(input_types) > 5 else None
        if output_mode is not None:
            return dtypes.list_type(dtypes.STRING_DTYPE)
        return dtypes.STRING_DTYPE
The output type of AI.CLASSIFY depends on the value of the output_mode parameter ('SINGLE' vs 'MULTI'), not just its presence. Since GoogleSqlScalarOp.signature only receives argument types, it cannot distinguish between these modes if both are passed as strings. Consider refactoring AI_CLASSIFY into a class (similar to AIGenerateOp) that captures the output_mode value during instantiation to correctly determine the output type.
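The refactor suggested above could look roughly like the following sketch. It is a hypothetical stand-in, not the bigframes implementation: the class name `AIClassifyOp`, the string dtype labels, and the shape of `output_type` are assumptions made for illustration. Only the 'SINGLE' vs 'MULTI' distinction comes from the comment.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-ins for dtypes.STRING_DTYPE and dtypes.list_type(...).
STRING_DTYPE = "string"
STRING_LIST_DTYPE = "list<string>"


@dataclass(frozen=True)
class AIClassifyOp:
    """Illustrative op that stores the output_mode value at construction time."""

    output_mode: Optional[str] = None  # assumed values: None, "SINGLE", "MULTI"

    def output_type(self, *input_types: str) -> str:
        # Decide from the captured value rather than from the argument types,
        # which look identical for "SINGLE" and "MULTI".
        if self.output_mode is not None and self.output_mode.upper() == "MULTI":
            return STRING_LIST_DTYPE
        return STRING_DTYPE
```

Because the mode is part of the op instance, two ops that receive the same argument types can still report different output types.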
    raise ValueError("At least one input must be a BigFrames Series or Index object.")

    block = ref_obj._block
    series_to_expr = {id(ref_obj): ex.deref(ref_obj._value_column)}
The _value_column property is designed for Series objects. Index objects in BigFrames also provide it for single-column indices, but it raises a ValueError for multi-indices. Since align_n is intended to be a generic alignment utility, it should either handle multi-column objects or raise a more descriptive error when they are not supported.
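A minimal sketch of the "more descriptive error" option is shown below. Everything in it is hypothetical: `FakeIndex`, its `column_ids` field, and the helper `single_value_column` are illustrative names that do not exist in bigframes, and the real check would depend on how the library exposes index columns.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class FakeIndex:
    """Hypothetical stand-in for an object with one or more value columns."""

    column_ids: Tuple[str, ...]


def single_value_column(obj: FakeIndex) -> str:
    """Return the only column id, or fail with an actionable message."""
    if len(obj.column_ids) != 1:
        raise TypeError(
            "align_n supports only single-column inputs; "
            f"got {len(obj.column_ids)} columns: {list(obj.column_ids)}"
        )
    return obj.column_ids[0]
```

Failing early with the column count and names tells the caller what to change, which a bare ValueError from a private property does not.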
    class FrozenDict(dict):
        def __hash__(self) -> int:
            return hash(tuple((k, make_hashable(v)) for k, v in sorted(self.items())))
Inheriting from dict makes FrozenDict technically mutable, which could lead to issues if the object is mutated after being hashed. To ensure true immutability, consider overriding mutation methods like __setitem__, update, and pop to raise a TypeError. Additionally, ensure the dictionary is programmatically sorted during hashing to maintain consistency.
Suggested change:

    - class FrozenDict(dict):
    -     def __hash__(self) -> int:
    -         return hash(tuple((k, make_hashable(v)) for k, v in sorted(self.items())))
    + class FrozenDict(dict):
    +     def __hash__(self) -> int:
    +         return hash(tuple((k, make_hashable(v)) for k, v in sorted(self.items())))
    +     def __setitem__(self, key, value):
    +         raise TypeError("FrozenDict is immutable")
    +     def update(self, *args, **kwargs):
    +         raise TypeError("FrozenDict is immutable")
    +     def pop(self, key, default=None):
    +         raise TypeError("FrozenDict is immutable")
    +     def clear(self):
    +         raise TypeError("FrozenDict is immutable")
References
- To ensure dictionary keys remain sorted without manual effort, programmatically sort the dictionary before returning it (or using it for stable operations like hashing).
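As a self-contained illustration of the immutability point, the sketch below blocks the remaining built-in mutators as well (`__delitem__`, `popitem`, `setdefault`, and in-place `|=`), which the suggested change does not cover. The `make_hashable` helper here is a hypothetical reimplementation written for this example, since its actual definition is not shown in the review, and keys are assumed to be mutually comparable (for example, all strings).

```python
from typing import Any, NoReturn


def make_hashable(value: Any) -> Any:
    """Hypothetical helper: recursively convert containers to hashable forms."""
    if isinstance(value, dict):
        return tuple(sorted((k, make_hashable(v)) for k, v in value.items()))
    if isinstance(value, (list, tuple)):
        return tuple(make_hashable(v) for v in value)
    if isinstance(value, set):
        return frozenset(make_hashable(v) for v in value)
    return value


class FrozenDict(dict):
    """dict subclass that hashes by content and rejects in-place mutation."""

    def __hash__(self) -> int:  # type: ignore[override]
        # Sort by key so insertion order does not change the hash.
        items = sorted(self.items(), key=lambda kv: kv[0])
        return hash(tuple((k, make_hashable(v)) for k, v in items))

    def _immutable(self, *args: Any, **kwargs: Any) -> NoReturn:
        raise TypeError("FrozenDict is immutable")

    # Route every built-in mutator to the same error.
    __setitem__ = _immutable
    __delitem__ = _immutable
    __ior__ = _immutable
    update = _immutable
    pop = _immutable
    popitem = _immutable
    clear = _immutable
    setdefault = _immutable
```

Construction through `FrozenDict({...})` still works because `dict.__init__` fills the mapping internally without calling the overridden methods.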
        ArgSpec(arg_name="model_params", optional=True, const_only=True),
        ArgSpec(arg_name="output_schema", optional=True, const_only=True),
    )
    signature: typing.Callable[..., dtypes.ExpressionType] = lambda: dtypes.STRING_DTYPE
The lambda assigned to signature takes no arguments, but the base class's output_type method calls it with *input_types. This raises a TypeError at runtime whenever the base implementation of output_type is invoked on an AIGenerateOp instance, that is, whenever the override is not used.
Suggested change:

    - signature: typing.Callable[..., dtypes.ExpressionType] = lambda: dtypes.STRING_DTYPE
    + signature: typing.Callable[..., dtypes.ExpressionType] = lambda *args: dtypes.STRING_DTYPE
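The failure mode can be reproduced in isolation with the sketch below. It uses plain functions rather than the actual op classes; `STRING_DTYPE` and the free-standing `output_type` helper are stand-ins assumed for this example, mirroring only the call pattern described in the comment.

```python
from typing import Callable

STRING_DTYPE = "string"  # hypothetical stand-in for dtypes.STRING_DTYPE

# The original form accepts no arguments; the suggested form accepts any.
zero_arg: Callable[..., str] = lambda: STRING_DTYPE
var_arg: Callable[..., str] = lambda *args: STRING_DTYPE


def output_type(signature: Callable[..., str], *input_types: str) -> str:
    # Mirrors the described call pattern: forward every input type.
    return signature(*input_types)
```

Calling `output_type(var_arg, "string", "int64")` succeeds, while `output_type(zero_arg, "string")` raises a TypeError because the lambda receives a positional argument it does not accept. Note that `Callable[..., T]` places no constraint on the argument list, so a type checker does not flag the zero-argument lambda.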
    def apply_op(
        op: ops.NaryOp,
The type hint for op is ops.NaryOp, but GoogleSqlScalarOp (and its subclasses like AIGenerateOp) does not inherit from ops.NaryOp. This will cause type checking errors. Using a less specific type like Any is an acceptable trade-off here to satisfy the type checker.
Suggested change:

    - op: ops.NaryOp,
    + op: typing.Any,
References
- When finding a precise type hint that satisfies both mypy and unit tests is not cost-effective, using a less specific type (e.g., Any) is an acceptable trade-off.