import os
from functools import partial
from itertools import islice
from pathlib import Path
from typing import Literal
from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
from tmallet.obfuscators import (
HierarchicalScrambleConfig,
HierarchicalScrambleObfuscator,
LemmaObfuscator,
LinearScrambleConfig,
LinearScrambleObfuscator,
POSFilter,
POSFilterConfig,
ShannonFilter,
ShannonFilterConfig,
)
from tmallet.obfuscators.base import Obfuscator, SpaCyObfuscator
from tmallet.utils import LangConfig, SpaCyInterface, flatten_dict
ObfuscationTechnique = Literal[
"pos-filter", # retain or remove specific POS tags
"scramble-hier", # dependency-parsing structural obfuscation
"scramble-BoW", # randomly shuffle words at the sentence or document level
"shannon", # filter based on an approximation of word importance
"lemmatize", # word-level lemmatisation - still available, but no longer supported
]
[docs]
class TMallet:
"""A text obfuscation manager that applies transformations to text.
This class applies selected algorithmic obfuscation techniques (such as POS filtering,
bag-of-words scrambling, or information-theoretic filtering) to strings, lists of text,
or entire datasets.
Arguments:
lang (LangConfig): The language configuration code (e.g., "en").
prefer_gpu (bool): Whether spaCy is configured to leverage GPU acceleration.
"""
# apply_spacy_preprocessing: determines whether spacy is used or not
# for the initial text processing
# -> determined automatically based on the configuration selected
apply_spacy_preprocessing: bool = False
is_obfuscation_set_up: bool = False
active_obfuscator = None
active_config: dict | None = None
active_algorithm: str | None = None
def __init__(self, lang: LangConfig = "en", prefer_gpu: bool = False):
"""Initialises the obfuscation pipeline.
Args:
lang (LangConfig, optional): The target language configuration - either "en" (English) or "de" (German). Defaults to "en". Let us know if you'd be interested in support for further languages.
prefer_gpu (bool, optional): If True, attempts to allocate spaCy operations
on the GPU. Defaults to False.
"""
self.spacy_interface: SpaCyInterface = SpaCyInterface(
lang=lang, prefer_gpu=prefer_gpu
)
self.lang: LangConfig = lang
self.prefer_gpu = prefer_gpu
[docs]
def load_obfuscator(
self,
algorithm: ObfuscationTechnique,
config: dict[str, str],
):
"""Validates configuration and dynamically instantiates an obfuscation algorithm.
Args:
algorithm (str): The identifier of the obfuscation technique (e.g., 'pos-filter').
config (Dict): Key-value pairings containing parameters for the specific algorithm.
Returns:
TMallet: The current class instance to allow for method chaining.
"""
self.active_config = self._validate_config(algorithm, config)
self.active_obfuscator = self._get_obfuscator(algorithm)
self.active_obfuscator.set_config(self.active_config)
self.active_algorithm = algorithm
return self
[docs]
def obfuscate(self, text: str) -> dict | str:
"""Obfuscates standalone text strings or lists of strings.
Requires an obfuscator to be loaded via `load_obfuscator` prior to invocation.
Args:
text (Union[List[str], str]): Single text payload or collection of texts to process.
Raises:
RuntimeError: If an obfuscator and configuration have not been loaded yet.
Returns:
Dict: The modified, obfuscated text or collection of texts in the form of a dictionary.
"""
if (
self.active_obfuscator is None
and self.active_config is None
and not self.active_algorithm == "lemmatize"
):
raise RuntimeError(
"Please use `set_obfuscator` to setup the obfuscation details first."
)
if self.apply_spacy_preprocessing:
text = self.spacy_interface.process(text)
return self.active_obfuscator.obfuscate(text)
def _obfuscate_batch(
self,
batch: dict,
column: str,
column_obfuscated: str,
multi: bool = True,
) -> dict:
"""Processes a single dictionary batch extracted from a Dataset pipeline wrapper.
Args:
batch (Dict[str, Any]): A batch slice containing lists mapped to column keys.
column (str): The column key containing the raw text strings.
column_obfuscated (str): Target base column key for saving the output.
multi (bool, optional): If True, flattens a complex nested dictionary output
directly into the batch root elements. Defaults to True.
Raises:
KeyError: If the specified target data column does not exist inside the batch.
Returns:
Dict[str, Any]: The batch containing the obfuscated results.
"""
if column not in batch.keys():
raise KeyError(
f"Invalid column provided. Please choose one of {list(batch.keys())}"
)
texts = batch[column]
if not multi:
batch[column_obfuscated] = [self.obfuscate(text) for text in texts]
else:
obfuscation_output = [flatten_dict(self.obfuscate(text)) for text in texts]
all_keys = obfuscation_output[0].keys()
batch.update(
{
key: [sample[key] for sample in obfuscation_output]
for key in all_keys
}
)
return batch
[docs]
def obfuscate_dataset(
self,
dataset: Dataset,
column: str,
column_obfuscated: str,
batch_size: int = 10,
num_proc: int | None = None,
):
"""Maps obfuscation across an entire HuggingFace/compatible dataset object sequentially.
Args:
dataset (Dataset): The underlying dataset collection containing columns of data.
column (str): Key of the column containing raw target text.
column_obfuscated (str): Target base column key for saving the output.
batch_size (int, optional): Size of chunk arrays processed together. Defaults to 10.
num_proc (Optional[int], optional): CPU core count split handling parallel tasks.
Defaults to None.
Returns:
Dataset: A newly updated copy of the dataset containing obfuscation columns.
"""
obfuscated_dataset = dataset.map(
partial(
self._obfuscate_batch,
column=column,
column_obfuscated=column_obfuscated,
),
batched=True,
batch_size=batch_size,
desc="Obfuscating...",
num_proc=num_proc,
# cache_file_name=None,
load_from_cache_file=False,
)
return obfuscated_dataset
[docs]
def obfuscate_dataset_by_chunk(
self,
dataset_repo: str,
column: str,
column_obfuscated: str,
save_chunks_to_folder: Path,
dataset_config: str | None = None,
dataset_split: str = "train",
chunk_size: int = 5_000,
batch_size: int = 100,
num_proc: int | None = None,
start_index: int = 0,
num_samples: int | None = None,
) -> Dataset:
"""Streams a dataset from the Hub in chunks, obfuscates each chunk,
and saves checkpoints to disk for fault tolerance.
Args:
dataset_repo (str): HuggingFace Hub repo ID or local path for load_dataset.
column (str): Key of the column containing raw target text.
column_obfuscated (str): Target base column key for saving the output.
save_chunks_to_folder (Path): Directory to save/load disk checkpoints.
dataset_config (Optional[str]): Dataset config/subset name passed to load_dataset.
dataset_split (str): Split to stream (e.g. "train", "validation"). Defaults to "train".
chunk_size (int): Number of examples per chunk. Defaults to 5_000.
batch_size (int): Inner batch size passed to .map. Defaults to 100.
num_proc (Optional[int]): CPU parallelism for .map. Defaults to None.
start_index (int): Index of the first example to process. Defaults to 0.
num_samples (Optional[int]): Number of examples to process from start_index.
Defaults to None (process until stream is exhausted).
Returns:
Dataset: Concatenated dataset of all processed chunks.
"""
stream = load_dataset(
dataset_repo,
dataset_config,
split=dataset_split,
streaming=True,
)
stream = stream.skip(start_index)
if num_samples is not None:
stream = stream.take(num_samples)
iterator = iter(stream)
processed_chunks = []
# Align chunk_index to the global position so checkpoint filenames
# remain consistent with absolute dataset offsets.
chunk_index = start_index // chunk_size
# Skip any partial-chunk offset within the first chunk
partial_offset = start_index % chunk_size
if partial_offset:
list(islice(iterator, partial_offset))
while True:
start = chunk_index * chunk_size
end = start + chunk_size
ckpt_path = Path(save_chunks_to_folder) / f"obfuscated_ckpt_{start}_{end}"
if os.path.exists(ckpt_path):
print(f"Loading checkpoint {ckpt_path}")
chunk = load_from_disk(ckpt_path)
list(islice(iterator, chunk_size))
else:
rows = list(islice(iterator, chunk_size))
if not rows:
break
print(f"Processing examples {start}:{end}")
chunk = Dataset.from_list(rows)
chunk = self.obfuscate_dataset(
chunk,
column=column,
column_obfuscated=column_obfuscated,
batch_size=batch_size,
num_proc=num_proc,
)
chunk.save_to_disk(ckpt_path)
processed_chunks.append(chunk)
chunk_index += 1
obfuscated_dataset = concatenate_datasets(processed_chunks)
return obfuscated_dataset
def get_active_obfuscator(self):
return self.active_obfuscator
def _validate_config(self, algorithm: str, config: dict):
match algorithm:
case "pos-filter":
return POSFilterConfig(**config)
case "scramble-BoW":
return LinearScrambleConfig(**config)
case "scramble-hier":
return HierarchicalScrambleConfig(**config)
case "shannon":
return ShannonFilterConfig(**config)
case "lemmatize":
return None
def _get_obfuscator(
self, algorithm: ObfuscationTechnique
) -> Obfuscator | SpaCyObfuscator:
match algorithm:
case "pos-filter":
self.apply_spacy_preprocessing = True
self.spacy_interface.set_pipeline("pos")
return POSFilter()
case "scramble-hier":
self.apply_spacy_preprocessing = True
self.spacy_interface.set_pipeline("full")
return HierarchicalScrambleObfuscator()
case "scramble-BoW":
self.apply_spacy_preprocessing = False
return LinearScrambleObfuscator()
case "shannon":
self.apply_spacy_preprocessing = False
self.spacy_interface.set_pipeline("pos")
return ShannonFilter(
lang=self.lang,
spacy_interface=self.spacy_interface,
prefer_gpu=self.prefer_gpu,
)
case "lemmatize":
self.apply_spacy_preprocessing = True
return LemmaObfuscator()
case _:
raise ValueError(
f"Input {algorithm} invalid. Please provide a valid obfuscation algorithm."
)