lakehouse_engine.utils.schema_utils

Utilities to facilitate dataframe schema management.

  1"""Utilities to facilitate dataframe schema management."""
  2
  3from logging import Logger
  4from typing import Any, List, Optional
  5
  6from pyspark.sql.functions import col
  7from pyspark.sql.types import StructType
  8
  9from lakehouse_engine.core.definitions import InputSpec
 10from lakehouse_engine.core.exec_env import ExecEnv
 11from lakehouse_engine.utils.logging_handler import LoggingHandler
 12from lakehouse_engine.utils.storage.file_storage_functions import FileStorageFunctions
 13
 14
 15class SchemaUtils(object):
 16    """Schema utils that help retrieve and manage schemas of dataframes."""
 17
 18    _logger: Logger = LoggingHandler(__name__).get_logger()
 19
 20    @staticmethod
 21    def from_file(file_path: str, disable_dbfs_retry: bool = False) -> StructType:
 22        """Get a spark schema from a file (spark StructType json file) in a file system.
 23
 24        Args:
 25            file_path: path of the file in a file system. [Check here](
 26                https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
 27            disable_dbfs_retry: optional flag to disable the file storage DBFS retry mechanism.
 28
 29        Returns:
 30            Spark schema struct type.
 31        """
 32        return StructType.fromJson(
 33            FileStorageFunctions.read_json(file_path, disable_dbfs_retry)
 34        )
 35
 36    @staticmethod
 37    def from_file_to_dict(file_path: str, disable_dbfs_retry: bool = False) -> Any:
 38        """Get a dict with the spark schema from a file in a file system.
 39
 40        Args:
 41            file_path: path of the file in a file system. [Check here](
 42                https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
 43            disable_dbfs_retry: optional flag to disable the file storage DBFS retry mechanism.
 44
 45        Returns:
 46             Spark schema in a dict.
 47        """
 48        return FileStorageFunctions.read_json(file_path, disable_dbfs_retry)
 49
 50    @staticmethod
 51    def from_dict(struct_type: dict) -> StructType:
 52        """Get a spark schema from a dict.
 53
 54        Args:
 55            struct_type: dict containing a spark schema structure. [Check here](
 56                https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
 57
 58        Returns:
 59             Spark schema struct type.
 60        """
 61        return StructType.fromJson(struct_type)
 62
 63    @staticmethod
 64    def from_table_schema(table: str) -> StructType:
 65        """Get a spark schema from a table.
 66
 67        Args:
 68            table: table name from which to inherit the schema.
 69
 70        Returns:
 71            Spark schema struct type.
 72        """
 73        return ExecEnv.SESSION.read.table(table).schema
 74
 75    @classmethod
 76    def from_input_spec(cls, input_spec: InputSpec) -> Optional[StructType]:
 77        """Get a spark schema from an input specification.
 78
 79        This covers scenarios where the schema is provided as part of the input
 80        specification of the algorithm. The schema can come from the table specified in
 81        the input specification (enforce_schema_from_table), from a schema file
 82        (schema_path), or from the dict with the spark schema (schema) provided there.
 83
 84        Args:
 85            input_spec: input specification.
 86
 87        Returns:
 88            Spark schema struct type, or None if no schema is provided.
 89        """
 90        if input_spec.enforce_schema_from_table:
 91            cls._logger.info(
 92                f"Reading schema from table: {input_spec.enforce_schema_from_table}"
 93            )
 94            return SchemaUtils.from_table_schema(input_spec.enforce_schema_from_table)
 95        elif input_spec.schema_path:
 96            cls._logger.info(f"Reading schema from file: {input_spec.schema_path}")
 97            return SchemaUtils.from_file(
 98                input_spec.schema_path, input_spec.disable_dbfs_retry
 99            )
100        elif input_spec.schema:
101            cls._logger.info(
102                f"Reading schema from configuration file: {input_spec.schema}"
103            )
104            return SchemaUtils.from_dict(input_spec.schema)
105        else:
106            cls._logger.info("No schema was provided... skipping enforce schema")
107            return None
108
109    @staticmethod
110    def _get_prefix_alias(num_chars: int, prefix: str, shorten_names: bool) -> str:
111        """Get prefix alias for a field."""
112        return (
113            f"""{'_'.join(
114                [item[:num_chars] for item in prefix.split('.')]
115            )}_"""
116            if shorten_names
117            else f"{prefix}_".replace(".", "_")
118        )
119
120    @staticmethod
121    def schema_flattener(
122        schema: StructType,
123        prefix: str = None,
124        level: int = 1,
125        max_level: int = None,
126        shorten_names: bool = False,
127        alias: bool = True,
128        num_chars: int = 7,
129        ignore_cols: List = None,
130    ) -> List:
131        """Recursive method to flatten the schema of the dataframe.
132
133        Args:
134            schema: schema to be flattened.
135            prefix: prefix of the struct to get the value for. Only relevant
136                for being used in the internal recursive logic.
137            level: level of the depth in the schema being flattened. Only relevant
138                for being used in the internal recursive logic.
139            max_level: level until which you want to flatten the schema. Default: None.
140            shorten_names: whether to shorten the names of the prefixes of the fields
141                being flattened or not. Default: False.
142            alias: whether to define alias for the columns being flattened or
143                not. Default: True.
144            num_chars: number of characters to consider when shortening the names of
145                the fields. Default: 7.
146            ignore_cols: columns which you don't want to flatten. Default: None.
147
148        Returns:
149            A function to be called in .transform() spark function.
150        """
151        cols = []
152        ignore_cols = ignore_cols if ignore_cols else []
153        for field in schema.fields:
154            name = prefix + "." + field.name if prefix else field.name
155            field_type = field.dataType
156
157            if (
158                isinstance(field_type, StructType)
159                and name not in ignore_cols
160                and (max_level is None or level <= max_level)
161            ):
162                cols += SchemaUtils.schema_flattener(
163                    schema=field_type,
164                    prefix=name,
165                    level=level + 1,
166                    max_level=max_level,
167                    shorten_names=shorten_names,
168                    alias=alias,
169                    num_chars=num_chars,
170                    ignore_cols=ignore_cols,
171                )
172            else:
173                if alias and prefix:
174                    prefix_alias = SchemaUtils._get_prefix_alias(
175                        num_chars, prefix, shorten_names
176                    )
177                    cols.append(col(name).alias(f"{prefix_alias}{field.name}"))
178                else:
179                    cols.append(col(name))
180        return cols
class SchemaUtils:

Schema utils that help retrieve and manage schemas of dataframes.

@staticmethod
def from_file(file_path: str, disable_dbfs_retry: bool = False) -> pyspark.sql.types.StructType:

Get a spark schema from a file (spark StructType json file) in a file system.

Arguments:
  • file_path: path of the file in a file system (see https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
  • disable_dbfs_retry: optional flag to disable the file storage DBFS retry mechanism.
Returns:

Spark schema struct type.
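
For illustration, a minimal usage sketch (the file paths below are hypothetical; the schema file is expected to contain a JSON-serialised spark StructType, e.g. the output of df.schema.jsonValue()):

    from lakehouse_engine.utils.schema_utils import SchemaUtils

    # Hypothetical path to a JSON file holding a serialised StructType,
    # e.g. previously produced with json.dump(df.schema.jsonValue(), file).
    schema = SchemaUtils.from_file("s3://my-bucket/schemas/orders_schema.json")

    # The resulting StructType can then be used to enforce the schema on a read,
    # assuming `spark` is the active SparkSession.
    df = spark.read.schema(schema).json("s3://my-bucket/raw/orders/")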

@staticmethod
def from_file_to_dict(file_path: str, disable_dbfs_retry: bool = False) -> Any:

Get a dict with the spark schema from a file in a file system.

Arguments:
  • file_path: path of the file in a file system (see https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
  • disable_dbfs_retry: optional flag to disable the file storage DBFS retry mechanism.
Returns:

Spark schema in a dict.
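
A short sketch of how the dict form might be used, for example to inspect field names before building a StructType (the path is hypothetical):

    from lakehouse_engine.utils.schema_utils import SchemaUtils

    # Read the raw schema dict instead of a StructType (hypothetical path).
    schema_dict = SchemaUtils.from_file_to_dict(
        "s3://my-bucket/schemas/orders_schema.json"
    )

    # The dict mirrors StructType.jsonValue(): {"type": "struct", "fields": [...]}.
    field_names = [field["name"] for field in schema_dict["fields"]]
    print(field_names)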

@staticmethod
def from_dict(struct_type: dict) -> pyspark.sql.types.StructType:

Get a spark schema from a dict.

Arguments:
  • struct_type: dict containing a spark schema structure (see https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
Returns:

Spark schema struct type.
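
A minimal sketch with an inline schema dict, following the same JSON structure spark produces via StructType.jsonValue():

    from lakehouse_engine.utils.schema_utils import SchemaUtils

    schema_dict = {
        "type": "struct",
        "fields": [
            {"name": "id", "type": "long", "nullable": False, "metadata": {}},
            {"name": "name", "type": "string", "nullable": True, "metadata": {}},
        ],
    }

    schema = SchemaUtils.from_dict(schema_dict)
    print(schema.simpleString())  # struct<id:bigint,name:string>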

@staticmethod
def from_table_schema(table: str) -> pyspark.sql.types.StructType:

Get a spark schema from a table.

Arguments:
  • table: table name from which to inherit the schema.
Returns:

Spark schema struct type.
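
As a sketch, assuming ExecEnv.SESSION has already been initialised by the engine and the (hypothetical) table exists in the catalog:

    from lakehouse_engine.utils.schema_utils import SchemaUtils

    # Inherit the schema from an existing catalog table (hypothetical name).
    schema = SchemaUtils.from_table_schema("my_database.orders")

    # Reuse it to enforce the same structure when reading raw files,
    # assuming `spark` is the active SparkSession.
    df = spark.read.schema(schema).parquet("s3://my-bucket/raw/orders/")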

@classmethod
def from_input_spec(cls, input_spec: lakehouse_engine.core.definitions.InputSpec) -> Optional[pyspark.sql.types.StructType]:

Get a spark schema from an input specification.

This covers scenarios where the schema is provided as part of the input specification of the algorithm. The schema can come from the table specified in the input specification (enforce_schema_from_table), from a schema file (schema_path), or from the dict with the spark schema (schema) provided in the specification.

Arguments:
  • input_spec: input specification.
Returns:

Spark schema struct type, or None if no schema is provided.
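
The snippet below sketches the precedence applied by this method; a plain dict stands in for the relevant InputSpec attributes purely for illustration (the real InputSpec is defined in lakehouse_engine.core.definitions):

    from lakehouse_engine.utils.schema_utils import SchemaUtils

    # Illustrative stand-in for the schema-related InputSpec attributes.
    spec = {
        "enforce_schema_from_table": None,
        "schema_path": "s3://my-bucket/schemas/orders_schema.json",
        "schema": None,
    }

    # enforce_schema_from_table wins over schema_path, which wins over an
    # inline schema dict; with none of them set, schema enforcement is skipped.
    if spec["enforce_schema_from_table"]:
        schema = SchemaUtils.from_table_schema(spec["enforce_schema_from_table"])
    elif spec["schema_path"]:
        schema = SchemaUtils.from_file(spec["schema_path"])
    elif spec["schema"]:
        schema = SchemaUtils.from_dict(spec["schema"])
    else:
        schema = None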

@staticmethod
def schema_flattener(schema: pyspark.sql.types.StructType, prefix: str = None, level: int = 1, max_level: int = None, shorten_names: bool = False, alias: bool = True, num_chars: int = 7, ignore_cols: List = None) -> List:

Recursive method to flatten the schema of the dataframe.

Arguments:
  • schema: schema to be flattened.
  • prefix: prefix of the struct to get the value for. Only relevant to the internal recursive logic.
  • level: depth level of the schema being flattened. Only relevant to the internal recursive logic.
  • max_level: level until which you want to flatten the schema. Default: None.
  • shorten_names: whether to shorten the names of the prefixes of the fields being flattened or not. Default: False.
  • alias: whether to define alias for the columns being flattened or not. Default: True.
  • num_chars: number of characters to consider when shortening the names of the fields. Default: 7.
  • ignore_cols: columns which you don't want to flatten. Default: None.
Returns:

A list of flattened columns to be used in a select statement.
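
A minimal sketch of flattening a nested dataframe by passing the produced column list to select (data and column names are illustrative):

    from pyspark.sql import SparkSession

    from lakehouse_engine.utils.schema_utils import SchemaUtils

    spark = SparkSession.builder.getOrCreate()

    # Nested sample data: the 'customer' struct will be flattened.
    df = spark.createDataFrame(
        [(1, ("Alice", "Lisbon"))],
        "id long, customer struct<name:string, city:string>",
    )

    flat_df = df.select(*SchemaUtils.schema_flattener(df.schema))
    flat_df.printSchema()  # columns: id, customer_name, customer_city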