Fix convention
hieuphmm committed Sep 10, 2024
1 parent a7d6332 commit 6bb71b8
Showing 26 changed files with 1,839 additions and 350 deletions.
44 changes: 27 additions & 17 deletions docs/source/conf.py
@@ -10,43 +10,53 @@
 import sys
 from datetime import datetime
 
-# Path setup
+# -- Path setup --------------------------------------------------------------
 
+# Add the path to your source code here.
 sys.path.insert(0, os.path.abspath("../../src"))
 
-# Project information
+# -- Project information -----------------------------------------------------
 
 PROJECT = "MELTs"
 AUTHOR = "Thu Nguyen Hoang Anh"
-COPYRIGHT = f"{datetime.now().year}, {AUTHOR}"
+COPYRIGHT = f"{datetime.datetime.now().year}, {AUTHOR}"
 
-# The full version, including alpha/beta/rc tags
-RELEASE = "0.1"
+# The version info for the project
+VERSION = "0.1"  # Short version (e.g., '0.1')
+RELEASE = "0.1"  # Full version (e.g., '0.1.0')
 
-# General configuration
-MASTER_DOC = "index"
+# -- General configuration ---------------------------------------------------
 
-# Sphinx extension modules as strings, can be built-in or custom
+MASTER_DOC = "index"  # The name of the master document
 
+# Sphinx extensions to use
 EXTENSIONS = [
-    "sphinx.ext.duration",
-    "sphinx.ext.autodoc",
-    "sphinx.ext.coverage",
-    "sphinx_rtd_theme",
-    "sphinx.ext.doctest",
+    "sphinx.ext.duration",  # Measure build time
+    "sphinx.ext.autodoc",  # Include documentation from docstrings
+    "sphinx.ext.coverage",  # Check for documentation coverage
+    "sphinx.ext.doctest",  # Test embedded doctests
+    "sphinx_rtd_theme",  # Read the Docs theme
 ]
 
-# List of modules to mock during autodoc generation
+# Mock import for autodoc
 AUTODOC_MOCK_IMPORTS = ["pyemd"]
 
 # Paths that contain templates
 TEMPLATES_PATH = ["_templates"]
 
-# List of patterns to ignore when looking for source files
+# Patterns to ignore when looking for source files
 EXCLUDE_PATTERNS = []
 
 # Sort members alphabetically in the autodoc
 AUTODOC_MEMBER_ORDER = "alphabetical"
 
-# Options for HTML output
+# Theme to use for HTML and HTML Help pages
 HTML_THEME = "sphinx_rtd_theme"
 
-# Paths for custom static files (like style sheets)
+# Theme options for customizing the appearance of the theme
+HTML_THEME_OPTIONS = {
+    # You can add theme-specific options here
+}
+
+# Paths that contain custom static files (e.g., style sheets)
 HTML_STATIC_PATH = ["_static"]
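
A caveat when reading this file: Sphinx only picks up lowercase configuration names from conf.py, so uppercase constants like PROJECT or EXTENSIONS are invisible to the build unless mirrored onto their lowercase counterparts. Below is a minimal sketch of the same settings in the form Sphinx itself reads; it is an illustration of standard Sphinx behavior, not the committed file.

# Sketch: the lowercase names Sphinx actually reads from conf.py.
import os
import sys
from datetime import datetime

# Make the package importable for autodoc.
sys.path.insert(0, os.path.abspath("../../src"))

project = "MELTs"
author = "Thu Nguyen Hoang Anh"
copyright = f"{datetime.now().year}, {author}"
version = "0.1"
release = "0.1"
master_doc = "index"
extensions = [
    "sphinx.ext.duration",
    "sphinx.ext.autodoc",
    "sphinx.ext.coverage",
    "sphinx.ext.doctest",
    "sphinx_rtd_theme",
]
autodoc_mock_imports = ["pyemd"]
templates_path = ["_templates"]
exclude_patterns = []
autodoc_member_order = "alphabetical"
html_theme = "sphinx_rtd_theme"
html_theme_options = {}
html_static_path = ["_static"]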
11 changes: 11 additions & 0 deletions src/melt/__main__.py
@@ -1,3 +1,14 @@
+"""
+This script initializes NLP models and runs the main function from the 'cli' module.
+The script performs the following tasks:
+1. Downloads the 'punkt' tokenizer models using nltk.
+2. Loads the spaCy 'en_core_web_sm' model, downloading it if necessary.
+3. Imports and executes the 'main' function from the 'cli' module.
+If any module or function cannot be imported, appropriate error messages are displayed.
+"""
+
+import logging
 import spacy
 import nltk
…
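
Given that docstring, the part of the module hidden below the fold plausibly follows the standard nltk/spaCy bootstrap pattern. The sketch below is a hypothetical reconstruction, not the committed code; the melt.cli import path and the logging setup are assumptions.

# Hypothetical sketch of the rest of __main__.py, inferred from its docstring.
import logging

import nltk
import spacy

logging.basicConfig(level=logging.INFO)

# 1. Download the 'punkt' tokenizer models.
nltk.download("punkt")

# 2. Load the spaCy model, downloading it on first use.
try:
    spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download
    download("en_core_web_sm")

# 3. Run the CLI entry point (import path assumed).
try:
    from melt.cli import main
except ImportError as e:
    print(f"Failed to import 'main' from 'cli': {e}")
else:
    main()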
76 changes: 54 additions & 22 deletions src/melt/cli.py
@@ -1,32 +1,64 @@
-import spacy
-from spacy.cli import download
-from transformers import HfArgumentParser
-from dotenv import load_dotenv
-from script_arguments import ScriptArguments  # Ensure this module is in the correct path
-from generation import generation  # Ensure this module is in the correct path
-
-try:
-    spacy.load("en_core_web_sm")
-except OSError:
-    print(
-        "Downloading the spacy en_core_web_sm model\n"
-        "(don't worry, this will only happen once)"
-    )
-    from spacy.cli import download
-    download("en_core_web_sm")
+"""
+This script initializes and runs the text generation pipeline using spaCy,
+transformers, and dotenv. It also handles downloading the spaCy 'en_core_web_sm'
+model if it is not already present.
+The main function is responsible for:
+1. Loading environment variables.
+2. Parsing script arguments.
+3. Running the generation process with the parsed arguments.
+"""
+try:
+    import spacy
+except ImportError as e:
+    print(f"Failed to import 'spacy': {e}")
+
+try:
+    from spacy.cli import download
+except ImportError as e:
+    print(f"Failed to import 'spacy.cli': {e}")
+
+try:
+    from transformers import HfArgumentParser
+except ImportError as e:
+    print(f"Failed to import 'transformers': {e}")
+
+try:
+    from dotenv import load_dotenv
+except ImportError as e:
+    print(f"Failed to import 'dotenv': {e}")
+
+try:
+    from .script_arguments import ScriptArguments
+except ImportError as e:
+    print(f"Failed to import 'ScriptArguments' from 'script_arguments': {e}")
+
+try:
+    from .generation import generation
+except ImportError as e:
+    print(f"Failed to import 'generation' from 'generation': {e}")
+
+
+def ensure_spacy_model(model_name="en_core_web_sm"):
+    """
+    Ensure the spaCy model is available. Download it if not present.
+    """
+    try:
+        spacy.load(model_name)
+        print(f"spaCy model '{model_name}' is already installed.")
+    except OSError:
+        print(f"spaCy model '{model_name}' not found. Downloading...")
+        download(model_name)
+        print(f"spaCy model '{model_name}' has been downloaded and installed.")
 
 
 def main():
     """
-    Main function to:
-    1. Load environment variables from a .env file.
-    2. Ensure the spaCy model is available.
-    3. Parse command-line arguments.
-    4. Execute the generation function with the parsed arguments.
+    The main function that initializes the environment, parses script arguments,
+    and triggers the text generation process.
+    This function performs the following steps:
+    1. Loads environment variables using `load_dotenv()`.
+    2. Creates an argument parser for `ScriptArguments` using `HfArgumentParser`.
+    3. Parses the arguments into data classes.
+    4. Calls the `generation` function with the parsed arguments to perform the text generation.
+    Returns:
+        None
     """
-    # Load environment variables
     load_dotenv()
 
-    # Ensure spaCy model is available
…
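
The argument-handling steps described in main's docstring map onto a standard HfArgumentParser pattern. Here is a self-contained sketch with a stand-in ScriptArguments dataclass; the real fields live in melt's script_arguments module and will differ.

# Sketch of the HfArgumentParser pattern; ScriptArguments fields are stand-ins.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    dataset_name: str = field(default="dummy", metadata={"help": "Dataset to evaluate"})
    seed: int = field(default=42, metadata={"help": "Random seed"})


parser = HfArgumentParser(ScriptArguments)
# Returns one instance per dataclass passed to the parser, filled from sys.argv.
(args,) = parser.parse_args_into_dataclasses()
print(args.dataset_name, args.seed)

Invoking the script as, e.g., `python cli_sketch.py --dataset_name foo --seed 7` overrides the defaults, which is why the docstring treats parsing as a distinct step before generation runs.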
60 changes: 59 additions & 1 deletion src/melt/generation.py
@@ -1,11 +1,69 @@
+"""
+This module provides functionality for evaluating and
+generating data using specified pipelines and datasets.
+The `generation` function is the main entry point of this script. It performs the following tasks:
+1. Initializes the seed for reproducibility.
+2. Loads and processes the dataset using `DatasetWrapper`.
+3. Sets up directories for saving results if they don't already exist.
+4. Handles continuation of inference from a previous run if specified.
+5. Creates a DataLoader for batching dataset examples.
+6. Initializes the evaluation pipeline (`EvalPipeline`).
+7. Runs the evaluation pipeline and saves the results to JSON files.
+The script is designed to work with various configurations
+specified in the `script_args` parameter, including options for
+few-shot prompting and continuing from previous results.
+Modules used:
+- `os`: For file and directory operations.
+- `.tools.data`: Contains `DatasetWrapper` for dataset management.
+- `.tools.pipelines`: Contains `EvalPipeline` for evaluation processes.
+- `.tools.utils.utils`: Provides utility functions such as
+  `save_to_json`, `set_seed`, and `read_json`.
+- `torch.utils.data`: For data loading with `DataLoader`.
+"""
 import os
+from torch.utils.data import DataLoader
 from .tools.data import DatasetWrapper
 from .tools.pipelines import EvalPipeline
 from .tools.utils.utils import save_to_json, set_seed, read_json
-from torch.utils.data import DataLoader
 
 
 def generation(script_args):
+    """
+    Executes the data generation process based on the provided script arguments.
+    This function performs the following steps:
+    1. Sets the random seed for reproducibility using `set_seed`.
+    2. Loads and optionally processes the dataset using `DatasetWrapper`.
+    3. Constructs filenames for saving generation results and metrics based on the script arguments.
+    4. Creates necessary directories for saving results if they don't already exist.
+    5. Determines the starting index and results to continue
+       inference from a previous run if specified.
+    6. Initializes a `DataLoader` for batching the dataset examples.
+    7. Initializes an `EvalPipeline` for evaluating the data.
+    8. Runs the evaluation pipeline and saves the results using the `save_results` function.
+    Args:
+        script_args (ScriptArguments): An object containing the configuration
+        and parameters for the data generation process.
+            - seed (int): Random seed for reproducibility.
+            - smoke_test (bool): Flag to indicate if a smaller subset
+              of data should be used for testing.
+            - dataset_name (str): Name of the dataset.
+            - model_name (str): Name of the model.
+            - output_dir (str): Directory to save generation results.
+            - output_eval_dir (str): Directory to save evaluation metrics.
+            - continue_infer (bool): Flag to continue inference from a previous run.
+            - per_device_eval_batch_size (int): Batch size for evaluation.
+            - fewshot_prompting (bool): Flag for few-shot prompting.
+    Returns:
+        None
+    """
     set_seed(script_args.seed)
 
     # Load dataset (you can process it here)
…
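
Step 5 of the docstring, continuing inference from a previous run, is the least obvious part; it typically reduces to reading the earlier JSON output and computing a start offset so already-generated batches are skipped. A minimal sketch under the assumption that results are saved as a JSON list of records; resume_state is a hypothetical helper, not melt's API.

# Hypothetical sketch of the 'continue from a previous run' step.
import json
import os


def resume_state(output_path, batch_size):
    """Return (start_batch_index, previous_results) for a resumed run."""
    if not os.path.exists(output_path):
        return 0, []
    with open(output_path, encoding="utf-8") as f:
        previous = json.load(f)
    # Skip the batches whose outputs were already saved.
    return len(previous) // batch_size, previous


start, results = resume_state("results/generation.json", batch_size=8)
print(f"Resuming from batch {start} with {len(results)} saved records")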