Source code for fenn.datasets.text_dataset
from typing import Optional, Sequence, Union
import torch
from torch.utils.data import Dataset
[docs]
class TextDataset(Dataset):
    """
    Generic text + binary label dataset.

    Each item is one string from ``X`` run through ``tokenizer`` (padded and
    truncated to ``max_length``), optionally paired with its label from ``y``.

    X: list[str]
    y: list[int|float], or None for an unlabeled dataset
    """

    def __init__(
        self,
        X: Sequence[str],
        y: Optional[Sequence[Union[int, float]]],
        tokenizer,
        max_length: int = 1024,
    ):
        """
        Args:
            X: Input texts, one per sample.
            y: Labels aligned index-by-index with ``X``, or ``None`` for an
                unlabeled dataset. Every label is converted to ``float``.
            tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
                padding="max_length", max_length=..., return_tensors="pt")``
                that must return a mapping with ``"input_ids"`` and
                ``"attention_mask"`` tensors. Presumably a Hugging Face
                tokenizer -- confirm against callers.
            max_length: Token length every sample is padded/truncated to;
                forwarded unchanged to ``tokenizer``.

        Raises:
            ValueError: If ``y`` is given and its length differs from ``X``.
        """
        self.X = list(X)
        self.y = None if y is None else [float(v) for v in y]
        # Fail fast on misaligned inputs: otherwise surplus labels are silently
        # ignored and missing labels only show up later as an IndexError in
        # __getitem__ (often inside a DataLoader worker).
        if self.y is not None and len(self.y) != len(self.X):
            raise ValueError(
                "X and y must have the same length, "
                f"got len(X)={len(self.X)} and len(y)={len(self.y)}"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        """Return the number of samples (texts) in the dataset."""
        return len(self.X)

    def __getitem__(self, idx: int):
        """
        Tokenize the ``idx``-th text on the fly.

        Returns:
            ``(input_ids, attention_mask)`` when the dataset is unlabeled,
            otherwise ``(input_ids, attention_mask, label)``, where ``label``
            is a 0-d ``torch.float32`` tensor.
        """
        enc = self.tokenizer(
            self.X[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # A single string is tokenized, so the returned tensors presumably carry
        # a leading batch dimension of size 1; squeeze(0) drops it so the
        # DataLoader can add its own batch dimension when collating.
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)
        if self.y is None:
            return input_ids, attention_mask
        label = torch.tensor(self.y[idx], dtype=torch.float32)
        return input_ids, attention_mask, label