This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torchvision import datasets, transforms | |
| from tqdm import tqdm | |
| import numpy as np | |
| # Import the generated predict function | |
| from predict_function import predict | |
| # Load MNIST test dataset | |
| transform = transforms.Compose([transforms.ToTensor()]) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import time | |
| from pynput import keyboard | |
| from datetime import datetime | |
| import subprocess | |
| import threading | |
| import tkinter as tk | |
| import queue | |
| # ML imports |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #include <stdio.h> | |
| #include <math.h> | |
| #include <time.h> | |
| #define PI 3.14159265358979323846 | |
| double s_inv(double x) { | |
| return asin(2.0 * (x - 0.5)) / PI + 0.5; | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def generate_speculative( | |
| model: nn.Module, | |
| draft_model: nn.Module, | |
| tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], | |
| prompt: str, | |
| max_tokens: int = 100, | |
| verbose: bool = False, | |
| formatter: Optional[Callable] = None, | |
| **kwargs, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from mlx_lm import load | |
| import mlx.core as mx | |
| from mlx.utils import tree_flatten, tree_map, tree_unflatten | |
| import numpy as np | |
| # Copyright © 2023-2024 Apple Inc. | |
| import contextlib | |
| import copy | |
| import glob | |
| import importlib |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from mlx_lm import load | |
| import mlx.core as mx | |
| from mlx.utils import tree_flatten, tree_map, tree_unflatten | |
| import numpy as np | |
| # Copyright © 2023-2024 Apple Inc. | |
| import contextlib | |
| import copy | |
| import glob | |
| import importlib |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def generate_batched( | |
| model: nn.Module, | |
| tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], | |
| prompt: str, | |
| batch_size: int, | |
| *, | |
| verbose: bool = False, | |
| formatter: Optional[Callable] = None, | |
| max_tokens: int = 256, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import json | |
| import random | |
| import mlx.optimizers as optim | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| import numpy as np | |
| from tqdm import tqdm | |
| import time | |
| from datetime import datetime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| # Define the differentiable orthonormal linear layer | |
| class OrthonormalLayer(nn.Module): | |
| def __init__(self, n): | |
| """ | |
| Initializes a learnable layer with an orthonormal weight matrix. | |
| :param n: Dimension of the square weight matrix. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import mlx.core as mx | |
| @mx.compile | |
| def _compute_T1(A): | |
| """I + A""" | |
| return mx.eye(A.shape[-1]) + A | |
| @mx.compile | |
| def _compute_T2(A): | |
| """I + A + A^2/2""" | |
| A2 = A @ A | |
| return mx.eye(A.shape[-1]) + A + A2/2 |