lakehouse_engine.utils.schema_utils
Utilities to facilitate dataframe schema management.
1"""Utilities to facilitate dataframe schema management.""" 2 3from logging import Logger 4from typing import Any, List, Optional 5 6from pyspark.sql.functions import col 7from pyspark.sql.types import StructType 8 9from lakehouse_engine.core.definitions import InputSpec 10from lakehouse_engine.core.exec_env import ExecEnv 11from lakehouse_engine.utils.logging_handler import LoggingHandler 12from lakehouse_engine.utils.storage.file_storage_functions import FileStorageFunctions 13 14 15class SchemaUtils(object): 16 """Schema utils that help retrieve and manage schemas of dataframes.""" 17 18 _logger: Logger = LoggingHandler(__name__).get_logger() 19 20 @staticmethod 21 def from_file(file_path: str, disable_dbfs_retry: bool = False) -> StructType: 22 """Get a spark schema from a file (spark StructType json file) in a file system. 23 24 Args: 25 file_path: path of the file in a file system. [Check here]( 26 https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html). 27 disable_dbfs_retry: optional flag to disable file storage dbfs. 28 29 Returns: 30 Spark schema struct type. 31 """ 32 return StructType.fromJson( 33 FileStorageFunctions.read_json(file_path, disable_dbfs_retry) 34 ) 35 36 @staticmethod 37 def from_file_to_dict(file_path: str, disable_dbfs_retry: bool = False) -> Any: 38 """Get a dict with the spark schema from a file in a file system. 39 40 Args: 41 file_path: path of the file in a file system. [Check here]( 42 https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html). 43 disable_dbfs_retry: optional flag to disable file storage dbfs. 44 45 Returns: 46 Spark schema in a dict. 47 """ 48 return FileStorageFunctions.read_json(file_path, disable_dbfs_retry) 49 50 @staticmethod 51 def from_dict(struct_type: dict) -> StructType: 52 """Get a spark schema from a dict. 53 54 Args: 55 struct_type: dict containing a spark schema structure. [Check here]( 56 https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html). 57 58 Returns: 59 Spark schema struct type. 60 """ 61 return StructType.fromJson(struct_type) 62 63 @staticmethod 64 def from_table_schema(table: str) -> StructType: 65 """Get a spark schema from a table. 66 67 Args: 68 table: table name from which to inherit the schema. 69 70 Returns: 71 Spark schema struct type. 72 """ 73 return ExecEnv.SESSION.read.table(table).schema 74 75 @classmethod 76 def from_input_spec(cls, input_spec: InputSpec) -> Optional[StructType]: 77 """Get a spark schema from an input specification. 78 79 This covers scenarios where the schema is provided as part of the input 80 specification of the algorithm. Schema can come from the table specified in the 81 input specification (enforce_schema_from_table) or by the dict with the spark 82 schema provided there also. 83 84 Args: 85 input_spec: input specification. 86 87 Returns: 88 spark schema struct type. 
89 """ 90 if input_spec.enforce_schema_from_table: 91 cls._logger.info( 92 f"Reading schema from table: {input_spec.enforce_schema_from_table}" 93 ) 94 return SchemaUtils.from_table_schema(input_spec.enforce_schema_from_table) 95 elif input_spec.schema_path: 96 cls._logger.info(f"Reading schema from file: {input_spec.schema_path}") 97 return SchemaUtils.from_file( 98 input_spec.schema_path, input_spec.disable_dbfs_retry 99 ) 100 elif input_spec.schema: 101 cls._logger.info( 102 f"Reading schema from configuration file: {input_spec.schema}" 103 ) 104 return SchemaUtils.from_dict(input_spec.schema) 105 else: 106 cls._logger.info("No schema was provided... skipping enforce schema") 107 return None 108 109 @staticmethod 110 def _get_prefix_alias(num_chars: int, prefix: str, shorten_names: bool) -> str: 111 """Get prefix alias for a field.""" 112 return ( 113 f"""{'_'.join( 114 [item[:num_chars] for item in prefix.split('.')] 115 )}_""" 116 if shorten_names 117 else f"{prefix}_".replace(".", "_") 118 ) 119 120 @staticmethod 121 def schema_flattener( 122 schema: StructType, 123 prefix: str = None, 124 level: int = 1, 125 max_level: int = None, 126 shorten_names: bool = False, 127 alias: bool = True, 128 num_chars: int = 7, 129 ignore_cols: List = None, 130 ) -> List: 131 """Recursive method to flatten the schema of the dataframe. 132 133 Args: 134 schema: schema to be flattened. 135 prefix: prefix of the struct to get the value for. Only relevant 136 for being used in the internal recursive logic. 137 level: level of the depth in the schema being flattened. Only relevant 138 for being used in the internal recursive logic. 139 max_level: level until which you want to flatten the schema. Default: None. 140 shorten_names: whether to shorten the names of the prefixes of the fields 141 being flattened or not. Default: False. 142 alias: whether to define alias for the columns being flattened or 143 not. Default: True. 144 num_chars: number of characters to consider when shortening the names of 145 the fields. Default: 7. 146 ignore_cols: columns which you don't want to flatten. Default: None. 147 148 Returns: 149 A function to be called in .transform() spark function. 150 """ 151 cols = [] 152 ignore_cols = ignore_cols if ignore_cols else [] 153 for field in schema.fields: 154 name = prefix + "." + field.name if prefix else field.name 155 field_type = field.dataType 156 157 if ( 158 isinstance(field_type, StructType) 159 and name not in ignore_cols 160 and (max_level is None or level <= max_level) 161 ): 162 cols += SchemaUtils.schema_flattener( 163 schema=field_type, 164 prefix=name, 165 level=level + 1, 166 max_level=max_level, 167 shorten_names=shorten_names, 168 alias=alias, 169 num_chars=num_chars, 170 ignore_cols=ignore_cols, 171 ) 172 else: 173 if alias and prefix: 174 prefix_alias = SchemaUtils._get_prefix_alias( 175 num_chars, prefix, shorten_names 176 ) 177 cols.append(col(name).alias(f"{prefix_alias}{field.name}")) 178 else: 179 cols.append(col(name)) 180 return cols
class SchemaUtils

Schema utils that help retrieve and manage schemas of dataframes.
```python
@staticmethod
def from_file(file_path: str, disable_dbfs_retry: bool = False) -> StructType
```
Get a Spark schema from a file (a Spark StructType JSON file) in a file system.

Arguments:
- file_path: path of the file in a file system. [Check here](https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
- disable_dbfs_retry: optional flag to disable the DBFS retry when reading from file storage.

Returns:
Spark schema struct type.
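For illustration, a minimal sketch of loading a schema file; the bucket path and file contents are hypothetical, not part of the library:

```python
# Hypothetical contents of s3://my-bucket/schemas/orders.json, in the format
# produced by StructType.jsonValue():
# {"type": "struct", "fields": [
#     {"name": "id", "type": "long", "nullable": false, "metadata": {}},
#     {"name": "amount", "type": "double", "nullable": true, "metadata": {}}]}
from lakehouse_engine.utils.schema_utils import SchemaUtils

schema = SchemaUtils.from_file("s3://my-bucket/schemas/orders.json")
print(schema.simpleString())  # struct<id:bigint,amount:double>
```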
```python
@staticmethod
def from_file_to_dict(file_path: str, disable_dbfs_retry: bool = False) -> Any
```
Get a dict with the Spark schema from a file in a file system.

Arguments:
- file_path: path of the file in a file system. [Check here](https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).
- disable_dbfs_retry: optional flag to disable the DBFS retry when reading from file storage.

Returns:
Spark schema in a dict.
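A short sketch of why the dict form can be handy: you can inspect or adjust the raw schema before it is turned into a StructType (the path is hypothetical):

```python
from lakehouse_engine.utils.schema_utils import SchemaUtils

schema_dict = SchemaUtils.from_file_to_dict("s3://my-bucket/schemas/orders.json")
# Inspect the raw schema, e.g., list field names before enforcing it.
print([f["name"] for f in schema_dict["fields"]])
```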
```python
@staticmethod
def from_dict(struct_type: dict) -> StructType
```
Get a Spark schema from a dict.

Arguments:
- struct_type: dict containing a Spark schema structure. [Check here](https://spark.apache.org/docs/latest/api/java/org/apache/spark/sql/types/StructType.html).

Returns:
Spark schema struct type.
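A minimal sketch; the dict mirrors what StructType.jsonValue() produces:

```python
from lakehouse_engine.utils.schema_utils import SchemaUtils

schema = SchemaUtils.from_dict(
    {
        "type": "struct",
        "fields": [
            {"name": "id", "type": "long", "nullable": False, "metadata": {}},
            {"name": "name", "type": "string", "nullable": True, "metadata": {}},
        ],
    }
)
```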
```python
@staticmethod
def from_table_schema(table: str) -> StructType
```
Get a Spark schema from a table.

Arguments:
- table: table name from which to inherit the schema.

Returns:
Spark schema struct type.
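A sketch assuming the engine's ExecEnv session is already initialized and that the table (a hypothetical name here) exists in the catalog:

```python
from lakehouse_engine.utils.schema_utils import SchemaUtils

# Inherit the schema of an existing catalog table (hypothetical name).
schema = SchemaUtils.from_table_schema("my_database.orders")
```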
```python
@classmethod
def from_input_spec(cls, input_spec: InputSpec) -> Optional[StructType]
```
Get a Spark schema from an input specification.

This covers scenarios where the schema is provided as part of the input specification of the algorithm. The schema can come from the table named in the input specification (enforce_schema_from_table), from a schema file (schema_path), or from the Spark schema dict provided in the specification itself (schema).

Arguments:
- input_spec: input specification.

Returns:
Spark schema struct type, or None when no schema source is provided.
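A hedged sketch of the precedence implemented here (table schema first, then schema file, then inline dict), using a stand-in object instead of a fully populated InputSpec, whose required fields are engine-specific; at runtime the method only reads the four attributes shown:

```python
from types import SimpleNamespace

from lakehouse_engine.utils.schema_utils import SchemaUtils

# Stand-in with only the attributes from_input_spec reads; values are hypothetical.
spec = SimpleNamespace(
    enforce_schema_from_table=None,
    schema_path="s3://my-bucket/schemas/orders.json",
    disable_dbfs_retry=False,
    schema=None,
)
schema = SchemaUtils.from_input_spec(spec)  # resolves via from_file in this case
```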
```python
@staticmethod
def schema_flattener(
    schema: StructType,
    prefix: str = None,
    level: int = 1,
    max_level: int = None,
    shorten_names: bool = False,
    alias: bool = True,
    num_chars: int = 7,
    ignore_cols: List = None,
) -> List
```
Recursive method to flatten the schema of the dataframe.

Arguments:
- schema: schema to be flattened.
- prefix: prefix of the struct to get the value for. Only relevant for internal use in the recursive logic.
- level: depth level in the schema being flattened. Only relevant for internal use in the recursive logic.
- max_level: level until which you want to flatten the schema. Default: None.
- shorten_names: whether to shorten the names of the prefixes of the fields being flattened or not. Default: False.
- alias: whether to define an alias for the columns being flattened or not. Default: True.
- num_chars: number of characters to consider when shortening the names of the fields. Default: 7.
- ignore_cols: columns which you don't want to flatten. Default: None.

Returns:
A list of flattened pyspark Column expressions, e.g., to be passed to a select inside a function used with Spark's .transform().
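A minimal, self-contained sketch of flattening a nested dataframe with the returned column list:

```python
from pyspark.sql import SparkSession

from lakehouse_engine.utils.schema_utils import SchemaUtils

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, ("Main St", "Porto"))],
    "id INT, address STRUCT<street: STRING, city: STRING>",
)
# Select the flattened columns: id, address_street, address_city.
flat_df = df.select(*SchemaUtils.schema_flattener(df.schema))
flat_df.show()
```

With shorten_names=True and num_chars=3, the prefix in the aliases above would be shortened from address_ to add_.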