class TransformerFactory(object):
    """TransformerFactory class following the factory pattern.

    Maps a transformer function name (as given in a TransformerSpec) to the
    concrete callable implementing it, resolving any DataFrame-valued
    arguments (e.g. for joins/unions) from the provided data dict.
    """

    _logger = LoggingHandler(__name__).get_logger()

    # Transformers that cannot run on streaming DataFrames (e.g. they rely on
    # full-dataset ordering/ranking or global id generation).
    UNSUPPORTED_STREAMING_TRANSFORMERS = [
        "condense_record_mode_cdc",
        "group_and_rank",
        "with_auto_increment_id",
        "with_row_id",
    ]

    # Registry of all transformer names -> implementing callables.
    AVAILABLE_TRANSFORMERS = {
        "add_current_date": DateTransformers.add_current_date,
        "cache": Optimizers.cache,
        "cast": ColumnReshapers.cast,
        "coalesce": Repartitioners.coalesce,
        "column_dropper": DataMaskers.column_dropper,
        "column_filter_exp": Filters.column_filter_exp,
        "column_selector": ColumnReshapers.column_selector,
        "condense_record_mode_cdc": Condensers.condense_record_mode_cdc,
        "convert_to_date": DateTransformers.convert_to_date,
        "convert_to_timestamp": DateTransformers.convert_to_timestamp,
        "custom_transformation": CustomTransformers.custom_transformation,
        "drop_duplicate_rows": Filters.drop_duplicate_rows,
        "expression_filter": Filters.expression_filter,
        "format_date": DateTransformers.format_date,
        "flatten_schema": ColumnReshapers.flatten_schema,
        "explode_columns": ColumnReshapers.explode_columns,
        "from_avro": ColumnReshapers.from_avro,
        "from_avro_with_registry": ColumnReshapers.from_avro_with_registry,
        "from_json": ColumnReshapers.from_json,
        "get_date_hierarchy": DateTransformers.get_date_hierarchy,
        "get_max_value": Aggregators.get_max_value,
        "group_and_rank": Condensers.group_and_rank,
        "hash_masker": DataMaskers.hash_masker,
        "incremental_filter": Filters.incremental_filter,
        "join": Joiners.join,
        "persist": Optimizers.persist,
        "rename": ColumnReshapers.rename,
        "repartition": Repartitioners.repartition,
        "replace_nulls": NullHandlers.replace_nulls,
        "sql_transformation": CustomTransformers.sql_transformation,
        "to_json": ColumnReshapers.to_json,
        "union": Unions.union,
        "union_by_name": Unions.union_by_name,
        "with_watermark": Watermarker.with_watermark,
        "unpersist": Optimizers.unpersist,
        "with_auto_increment_id": ColumnCreators.with_auto_increment_id,
        "with_expressions": ColumnReshapers.with_expressions,
        "with_literals": ColumnCreators.with_literals,
        "with_regex_value": RegexTransformers.with_regex_value,
        "with_row_id": ColumnCreators.with_row_id,
    }

    @staticmethod
    def get_transformer(spec: TransformerSpec, data: OrderedDict = None) -> Callable:
        """Get a transformer following the factory pattern.

        Args:
            spec: transformer specification (individual transformation... not to be
                confused with list of all transformations).
            data: ordered dict of dataframes to be transformed. Needed when a
                transformer requires more than one dataframe as input.

        Returns:
            Transformer function to be executed in .transform() spark function.

        Raises:
            NotImplementedError: if `spec.function` is not a registered
                transformer.

        {{get_example(method_name='get_transformer')}}
        """
        # Guard clause: validate the requested transformer before resolving
        # any of its arguments.
        if spec.function not in TransformerFactory.AVAILABLE_TRANSFORMERS:
            raise NotImplementedError(
                f"The requested transformer {spec.function} is not implemented."
            )

        # Always work on a copy of the spec args so the original spec stays
        # immutable (important for Spark task retries, see
        # _get_spec_args_copy).
        args_copy = TransformerFactory._get_spec_args_copy(spec.args)

        # Some transformers take DataFrame arguments: the spec holds input
        # ids, which are resolved here against the `data` dict.
        if spec.function == "incremental_filter":
            # incremental_filter optionally expects a DataFrame as input, so find it.
            if "increment_df" in args_copy:
                args_copy["increment_df"] = data[args_copy["increment_df"]]
        elif spec.function == "join":
            # get the dataframe given the input_id in the input specs of the acon.
            args_copy["join_with"] = data[args_copy["join_with"]]
        elif spec.function in ("union", "union_by_name"):
            # get the list of dataframes given the input_id in the input specs
            # of the acon.
            args_copy["union_with"] = [
                data[union_with_spec_id]
                for union_with_spec_id in spec.args["union_with"]
            ]

        return TransformerFactory.AVAILABLE_TRANSFORMERS[  # type: ignore
            spec.function
        ](**args_copy)

    @staticmethod
    def _get_spec_args_copy(spec_args: dict) -> dict:
        """Returns a shallow copy of `spec_args` to ensure immutability.

        Args:
            spec_args (dict): A dictionary containing the arguments of a
                TransformerSpec.

        Returns:
            dict: A shallow copy of `spec_args`, preventing modifications to the
                original dictionary. This is important in Spark, especially when
                retries of failed attempts occur. For example, if during the first
                run the `join_with` argument (initially a string) is replaced with a
                DataFrame (as done in the `get_transformer` function), then on a retry,
                depending on how Spark handles state, the `join_with` argument may no
                longer be a string but a DataFrame, leading to key error.
        """
        return dict(spec_args)