I really like to peek into different ML codebases for distributed training and this is a very short post on some things I found interesting in Torch Titan:
Disabling and controlling Python’s garbage collector (GC): the Titan codebase disables the Python GC and then manually forces a collection at the beginning of every training step. This makes sense, but I’m not sure how large the gains are: collecting every step may be too frequent, and taking manual control of the GC may not be worth it, especially depending on the complexity of your other dependencies, since it could cause unintended behavior that would be difficult to trace back to GC collection.
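For reference, the pattern looks roughly like the sketch below (a minimal sketch, not Titan’s actual code; the step count and the generation argument are illustrative):

import gc

NUM_STEPS = 1000  # hypothetical number of training steps

gc.disable()  # stop automatic collections

for step in range(NUM_STEPS):
    # force a collection at a predictable point at the beginning of the
    # step, instead of letting the GC pause training at arbitrary moments
    gc.collect(1)  # a generation-1 collection is cheaper than a full one
    # ... training step goes here ...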
Custom GPU memory monitoring: Titan has a nice custom class to monitor GPU memory; it resets peak stats and empties the CUDA caching allocator upon initialization. At every step, it then collects the peak stats for both the small and large memory pools, capturing active and reserved memory as well as failed allocation retries and the number of OOMs. It is very common for people to just monitor max GPU usage externally through NVML; however, this ignores the fact that PyTorch uses a caching allocator, so you need to look at the internal memory management stats inside PyTorch. If you don’t, you will certainly be misled by what you get from NVML.
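A minimal sketch of how you can read these allocator-level counters yourself, using the documented torch.cuda.memory_stats() keys (the selection and printing below are illustrative, not Titan’s exact code):

import torch

torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()  # drop cached blocks so peaks start from a clean slate

# ... run a training step ...

stats = torch.cuda.memory_stats()
peak_active_small = stats["active_bytes.small_pool.peak"]  # small pool peak
peak_active_large = stats["active_bytes.large_pool.peak"]  # large pool peak
peak_reserved = stats["reserved_bytes.all.peak"]  # held by the caching allocator
num_retries = stats["num_alloc_retries"]  # cudaMalloc retries after cache flush
num_ooms = stats["num_ooms"]  # allocations that failed with OOM

print(f"active small/large: {peak_active_small / 2**30:.2f}/"
      f"{peak_active_large / 2**30:.2f} GiB, "
      f"reserved: {peak_reserved / 2**30:.2f} GiB, "
      f"retries: {num_retries}, ooms: {num_ooms}")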
Custom profiling context manager: they wrote a context manager for profiling that measures the time it takes to dump the profiling data per rank. Interestingly, there is a barrier at the end, which makes sense, but this kind of synchronization is often the pain point of distributed training with PyTorch + NCCL.
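Something in the spirit of the sketch below (an assumption of how such a context manager could look, not Titan’s actual implementation):

import time
import contextlib
import torch
import torch.distributed as dist

@contextlib.contextmanager
def profiling(trace_dir: str, rank: int):
    # assumes the default process group was already initialized
    with torch.profiler.profile() as prof:
        yield prof
    begin = time.monotonic()
    prof.export_chrome_trace(f"{trace_dir}/rank{rank}_trace.json")
    print(f"rank {rank}: dumped trace in {time.monotonic() - begin:.2f}s")
    # wait until every rank finished dumping its trace before resuming,
    # otherwise fast ranks run ahead and can hit collective timeouts
    dist.barrier()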
Measuring data loading: this is of minor interest, but I liked the idea of not iterating over the data loader in the loop statement itself, but manually calling next() to get the batches; that makes it easy to measure data-loading time, which they average at the end of each epoch.
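A minimal sketch of the pattern (dataloader and num_steps are assumed inputs):

import time

data_iterator = iter(dataloader)
data_loading_times = []

for step in range(num_steps):
    begin = time.monotonic()
    # an explicit next() makes the wait for the batch measurable,
    # which a plain `for batch in dataloader:` loop would hide
    batch = next(data_iterator)
    data_loading_times.append(time.monotonic() - begin)
    # ... forward/backward on batch ...

avg_data_loading = sum(data_loading_times) / len(data_loading_times)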
Logging MFU (model FLOPS utilization): they also compute and log MFU, which is quite helpful.
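MFU is just the ratio between the FLOPS you actually achieve and the accelerator’s peak FLOPS. A minimal sketch (the peak value is the A100 BF16 spec; the per-step FLOP count is an assumed input, e.g. roughly 6 * params * tokens for a transformer forward+backward):

def mfu(flops_per_step: float, step_time_secs: float,
        peak_flops: float = 312e12) -> float:  # 312 TFLOPS: A100 BF16 peak
    achieved_flops = flops_per_step / step_time_secs
    return achieved_flops / peak_flops

# e.g. 1.2e15 FLOPs per step, with each step taking 10 seconds:
print(f"MFU: {mfu(1.2e15, 10.0) * 100:.1f}%")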
Delete predictions before backward: Titan also deletes the model predictions before the backward() call to avoid memory peaks. This can be quite effective, since you really don’t need this tensor anymore and can delete it immediately before the backward pass.
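In code, the pattern is just this (a minimal runnable sketch):

import torch

model = torch.nn.Linear(128, 10)
loss_fn = torch.nn.CrossEntropyLoss()
inputs = torch.randn(32, 128)
labels = torch.randint(0, 10, (32,))

pred = model(inputs)
loss = loss_fn(pred, labels)
del pred  # the predictions are no longer needed; dropping the reference
          # lets the allocator reuse that memory during backward
loss.backward()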
Reduction of the NCCL timeout: after the first training step, they reduce the NCCL timeout from the default 10 minutes to 100 seconds. This is nice if your replicas are well behaved and you don’t need anything more complex, but 100 seconds is a very short timeout that I would be careful with; it might be a good fit for your workload, but if your replicas drift a bit more, you will need to keep adding barriers to avoid timeouts that can be incredibly difficult to debug and cause a lot of headaches.
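The timeout itself is just a parameter of the process group. Titan changes it on the fly after the first step, which relies on internal utilities, so here I only sketch the documented init-time knob:

from datetime import timedelta
import torch.distributed as dist

# a generous timeout for initialization and the first step (compilation,
# allocator warm-up, etc.); the NCCL default is 10 minutes
dist.init_process_group(backend="nccl", timeout=timedelta(seconds=600))
# after the first step, Titan lowers this to ~100s so that a hung
# collective fails fast instead of blocking every rank for 10 minutes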
Distributed checkpointing with mid-epoch checkpoint support: this is a very cool implementation that uses PyTorch’s distributed checkpointing. They create wrappers (e.g. for the optimizer) that implement the Stateful protocol to support checkpointing, and they use the StatefulDataLoader from torchdata to checkpoint mid-epoch data loader state.
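The Stateful protocol is just a state_dict()/load_state_dict() pair that torch.distributed.checkpoint (DCP) knows how to call on save/load. A minimal sketch of such a wrapper (the class and usage below are illustrative, not Titan’s actual code):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful

class OptimizerWrapper(Stateful):
    """A hypothetical wrapper making an optimizer checkpointable by DCP."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

# state = {"optim": OptimizerWrapper(optim), "dataloader": stateful_dataloader}
# dcp.save(state, checkpoint_id="checkpoint-step-1000")  # each rank saves its shard
# dcp.load(state, checkpoint_id="checkpoint-step-1000")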
Misc: there are of course other interesting things, but it is worth mentioning that they also implemented a no-frills LLaMA model without relying on dozens of different libs (it seems fashionable nowadays to keep adding dependencies), so kudos to them for keeping it simple.
This is just a fun experiment to answer the question: how can I share a memory-mapped tensor from PyTorch with NumPy, JAX, and TensorFlow on CPU, without copies, while making sure that changes made in memory by torch are reflected in all these shared tensors?
One approach is shown below:
import torch
import tensorflow as tf
import numpy as np
import jax.numpy as jnp
import jax.dlpack
# Create the tensor and persist
t = torch.randn(10, dtype=torch.float32)
t.numpy().tofile("tensor.pt")
# Memory-map the file with PyTorch
t_mapped = torch.from_file("tensor.pt", shared=True, size=10, dtype=torch.float32)
# Memory-map it with numpy, the same tensor
n_mapped = np.memmap("tensor.pt", dtype="float32", mode="r+", shape=(10,))
# Convert it to Jax, will reuse the same buffer
j_mapped = jnp.asarray(n_mapped)
# Convert it to dlpack capsule and load in TensorFlow
dlcapsule = jax.dlpack.to_dlpack(j_mapped)
tf_mapped = tf.experimental.dlpack.from_dlpack(dlcapsule)
Now the fun part begins: I will change the tensor in PyTorch, and we will check what happens to the NumPy, JAX, and TensorFlow tensors:
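A sketch of that check (the exact mutation is up to you; here I use an in-place fill, which writes through the shared mapping):

# In-place update through PyTorch; shared=True means the write
# goes to the memory-mapped file itself
t_mapped.fill_(42.0)

# Check which of the other frameworks observed the mutation
print("numpy:", n_mapped[:3])
print("jax:", j_mapped[:3])
print("tensorflow:", tf_mapped[:3])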
One of the most interesting, but also obscure and difficult parts of Kant’s critique is schematism. Every time I reflect on generalisation in Machine Learning and how concepts should be grounded, it always leads to the same central problem of schematism. Friedrich H. Jacobi said that schematism was “the most wonderful and most mysterious of all unfathomable mysteries and wonders …” [1], and Schopenhauer also said that it was “famous for its profound darkness, because nobody has yet been able to make sense of it” [1].
It is very rewarding, however, to realize that it is impossible to read Kant without relating much of his revolutionary philosophy to the difficult problems we are facing (and have always faced) in AI, especially regarding generalisation. The first edition of the Critique of Pure Reason (CPR) was published more than 240 years ago, so historical context is often required to understand Kant’s writing; to make things worse, there is a lot of debate and a lack of consensus among Kant scholars. Even with these difficulties, it is still one of the most relevant and rewarding works of philosophy to read today.
Just sharing ~100 slides about PyTorch 2 internals focusing on recent innovations (Dynamo, Inductor, and ExecuTorch). I had a lot of fun preparing this and hope you’ll enjoy it. I’m planning to record it soon.
We are so used to Euclidean geometry that we often overlook the significance of curved geometries and the methods for measuring things that don’t reside on orthonormal bases. Just as understanding physics and the curvature of spacetime requires Riemannian geometry, I believe a profound comprehension of Machine Learning (ML) and data is also not possible without it. There is an increasing body of research that integrates differential geometry into ML. Unfortunately, the term “geometric deep learning” has predominantly become associated with graphs. However, modern geometry offers much more than just graph-related applications in ML.
I was reading the excellent article from Sander Dieleman about different perspectives on diffusion, so I thought it would be cool to try to contribute a bit with a new perspective.
A tale of two scores
Fisher information, metric and score
There are two important quantities that are widely known today and that keep popping up basically everywhere. The first one is the Fisher information matrix \( \mathbf{F}\) (or FIM):
$$\mathbf{F}_\theta = \mathop{\mathbb{E}} \left[ \nabla_\theta \log p_\theta(y \vert x) \, \nabla_\theta \log p_\theta(y \vert x)^T \right] \,$$ with \(y \sim p_\theta (y \vert x)\) and \(x \sim p_{\text{data}}\). Note that where \(y\) comes from is very important and often a source of confusion: \(y\) is sampled from the model’s predictive distribution (which is quite interesting, because it means you don’t even need labels to estimate \( \mathbf{F}\)). The FIM shows up in many places, such as the Cramér-Rao bound, continual learning, posterior approximation, optimization, Bayesian priors, the curvature of the KL divergence, etc. Note that there is a lot of debate about the FIM vs the empirical FIM and their different properties, which I will skip here (I discussed this in the optimization context in this presentation if you are interested).
The Fisher information matrix is also used in information geometry as a Riemannian metric, where it is called the Fisher-Rao metric (there are other names for it as well, which can be quite confusing). On this statistical manifold, where coordinates parametrize probability distributions, the metric (which equips the manifold) induces an inner product and allows us to compute norms and distances between distributions. Information geometry was pioneered by the late C. R. Rao and further developed and popularized by Shun-ichi Amari (who wrote some fine books about it).
We will talk more about the statistical manifold and what the metric actually does, more intuitively, later on; for now, note that the FIM is built from the score, or what we can call the Fisher score:

$$s_\theta = \nabla_\theta \log p_\theta(y \vert x)$$
This score is the gradient of the log-likelihood w.r.t. its parameters \(\theta\), so it tells us the steepness of the likelihood, and the FIM is the variance of this score. Under some regularity conditions, the FIM is also equivalent to the negative expected Hessian of the log-likelihood, \( \mathbf{F}_\theta = -\mathop{\mathbb{E}} \left[ \nabla^2_\theta \log p_\theta(y \vert x) \right] \), which points to its significance as curvature at a parameter point, hence its appearance as a metric tensor as well (to be precise, as a metric tensor field).
The other score, as in score-based models (aka Stein score)
Now, there is another score, the one used in score-based models and score matching, which is often called the Stein score:

$$s(x) = \nabla_x \log p(x)$$
Note that even though it looks similar and has a similar name to the previous score, this is a very different score function: it doesn’t give you gradients w.r.t. the distribution’s parameters, but gradients w.r.t. the data. It has been shown that we can estimate this score function from data even in the absence of ground truth for this quantity. Yang Song has a nice article explaining the motivation and recent developments.
The main point is that once you have this score function, you have a very powerful gradient field that tells you how samples should move in data space. You can then sample from the data distribution using Langevin sampling, which is basically SGD with added noise to avoid collapsing onto the modes.
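Concretely, (unadjusted) Langevin dynamics is just a few lines. Here is a minimal sketch where the closed-form score of a standard Gaussian stands in for a learned score network:

import numpy as np

def score(x):
    # score of a standard Gaussian: grad_x log N(x; 0, I) = -x;
    # in practice this would be a trained score network
    return -x

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 2))  # initial samples
eps = 0.01  # step size

for _ in range(1000):
    # move along the score field, plus injected noise to avoid
    # collapsing all samples onto the distribution's modes
    x = x + (eps / 2.0) * score(x) + np.sqrt(eps) * rng.normal(size=x.shape)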
The missing metric
If the Fisher score is the building block of the metric tensor for the statistical manifold, which metric can we build with this (Stein) score, and which manifold does it belong to? It is surprising that we still don’t seem to have a clear formalization for this; at least I wasn’t able to find much about it. You can find some works about diffusion models on Riemannian manifolds, but not about using the estimated score (through modern deep learning models) to build a Riemannian metric.
There is a nice quote from the physicist John Wheeler about Einstein’s relativity:
Space-time tells matter how to move and matter tells space-time how to curve.
– John Wheeler
It is very interesting that we can build a metric using this estimated score function, within the same mathematical framework used in the theory of relativity, where the quote can be adapted to our case as:
Diffusion models tell data how to move and data tells diffusion models how to curve.
I will start to explore this topic with some examples in a series of posts, but here is a glimpse of a geodesic using the Stein score as the metric tensor, where a Gaussian is curving the data manifold and creating a structure in which the shortest path between two points is no longer a straight line:
This is a very interesting connection: seeing diffusion and score-based models as a metric tensor field can give us very interesting tools to explore distances, geodesics, norms, etc., on the data manifold itself. We are still in the statistical domain, but the manifold is no longer the statistical manifold whose coordinates parametrize distributions; it is a manifold whose coordinates are the samples themselves. I think this connection of the score with the metric tensor field is an unexplored domain that is definitely very fertile; it can give us a much deeper understanding not only of the data but also of our sampling algorithms.
The inner product induced by the score metric is the following, with the (Stein) score outer product playing the same role the Fisher score outer product plays in the FIM:

$$\langle u, v \rangle_x = u^T \, g_x \, v \,, \qquad g_x = s_\theta(x) \, s_\theta(x)^T$$
Note that we are using the (Stein) score as the building block of our metric tensor \(g_x\), and this score is replaced by the estimated one, parametrized by a deep neural network. Notation can become a nightmare here, because the base point where the metric tensor is evaluated is already used as a lower index, so it can become \(g^{\theta}_x\) to denote that this metric tensor is parametrized by \(\theta\) (to make things worse, in differential geometry index positions also carry an important meaning).
I hope you like the idea; please provide feedback and keep an eye on the next posts of this series.
Updates
27 Sept 2023: added more details about the metric tensor definition using the (Stein) score; 3 Jun 2024: changes to improve clarity.
We have been training language models (LMs) for years, but finding valuable resources about the data pipelines commonly used to build their training datasets is paradoxically challenging. It may be because we often take it for granted that these datasets exist (or at least existed, as replicating them is becoming increasingly difficult). However, one must consider the numerous decisions involved in creating such pipelines, as they can significantly impact the final model’s quality, as seen recently in the struggle of models aiming to replicate LLaMA (LLaMA: Open and Efficient Foundation Language Models). It might be tempting to think that now, with large models that scale well, data is becoming more critical than modeling, since model architectures are not changing radically. However, data has always been critical.
The entire pipeline of CCNet (plus some minor modifications made in LLaMA’s paper) can be seen below. It has the following stages: data source, deduplication, language identification, filtering, and the “is-reference” filtering added in LLaMA. I will go through each one of them in the sections below.
I just released Feste, a free and open-source framework with a permissive license that allows scalable composition of NLP tasks using a graph execution model that is optimized and executed by specialized schedulers. The main idea behind Feste is that it builds a graph of execution instead of executing tasks immediately; this graph allows Feste to optimize and parallelize the work. One example of such an optimization: when there are multiple calls to the same backend (e.g. the same API), Feste automatically fuses them into a single batched call, reducing latency and letting the backend better leverage GPU vectorization for inference. Feste also executes independent tasks in parallel in different processes, so the user doesn’t have to care about parallelization, especially when multiple frameworks with different concurrency strategies are involved.
Uncertainty quantification for deep neural networks has recently evolved through many techniques. In this work, we revisit Laplace approximation, a classical approach for posterior approximation that is computationally attractive. However, instead of computing the curvature matrix, we show that, under some regularity conditions, the Laplace approximation can be easily constructed using the gradient second moment. This quantity is already estimated by many exponential moving average variants of Adagrad such as Adam and RMSprop, but is traditionally discarded after training. We show that our method (L2M) does not require changes in models or optimization, can be implemented in a few lines of code to yield reasonable results, and it does not require any extra computational steps besides what is already being computed by optimizers, without introducing any new hyperparameter. We hope our method can open new research directions on using quantities already computed by optimizers for uncertainty estimation in deep neural networks.
The imitation learning of self-driving vehicle policies through behavioral cloning is often carried out in an open-loop fashion, ignoring the effect of actions on future states. Training such policies purely with Empirical Risk Minimization (ERM) can be detrimental to real-world performance, as it biases policy networks towards matching only open-loop behavior, showing poor results when evaluated in closed-loop. In this work, we develop an efficient and simple-to-implement principle called Closed-loop Weighted Empirical Risk Minimization (CW-ERM), in which a closed-loop evaluation procedure is first used to identify training data samples that are important for practical driving performance, and then we use these samples to help debias the policy network. We evaluate CW-ERM on a challenging urban driving dataset and show that this procedure yields a significant reduction in collisions as well as in other non-differentiable closed-loop metrics.
SafePathNet: Safe Real-World Autonomous Driving by Learning to Predict and Plan with a Mixture of Experts
The goal of autonomous vehicles is to navigate public roads safely and comfortably. To enforce safety, traditional planning approaches rely on handcrafted rules to generate trajectories. Machine learning-based systems, on the other hand, scale with data and are able to learn more complex behaviors. However, they often ignore that agents and self-driving vehicle trajectory distributions can be leveraged to improve safety. In this paper, we propose modeling a distribution over multiple future trajectories for both the self-driving vehicle and other road agents, using a unified neural network architecture for prediction and planning. During inference, we select the planning trajectory that minimizes a cost taking into account safety and the predicted probabilities. Our approach does not depend on any rule-based planners for trajectory generation or optimization, improves with more training data and is simple to implement. We extensively evaluate our method through a realistic simulator and show that the predicted trajectory distribution corresponds to different driving profiles. We also successfully deploy it on a self-driving vehicle on urban public roads, confirming that it drives safely without compromising comfort.